#!/usr/bin/python

# Compute temperature-dependent transition probabilities and rate for a specific lag time from alanine dipeptide parallel tempering data.

#=============================================================================================
# REQUIREMENTS
#
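# This code requires the 'netCDF4' package (netcdf4-python), which is used to write the output data objects.
# (Earlier versions instead used the 'pynetcdf' package, containing the Scientific.IO.NetCDF package built for numpy:)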
#
# http://pypi.python.org/pypi/pynetcdf/
# http://sourceforge.net/project/showfiles.php?group_id=1315&package_id=185504
#
# This code also uses the 'MBAR' package, implementing the multistate Bennett acceptance ratio estimator, available here:
#
# http://www.simtk.org/home/pymbar
#=============================================================================================

#===================================================================================================
# IMPORTS
#===================================================================================================
import numpy
from numpy import *
from math import *
import pymbar # for MBAR analysis
import timeseries # for timeseries analysis
import commands
import os
import os.path
from numpy.linalg import * # linear algebra methods
import datetime # time and date
#from pynetcdf import NetCDF # for writing of data objects for plotting in Matlab or Mathematica
import netCDF4 as netcdf # for writing of data objects for plotting in Matlab or Mathematica

#===================================================================================================
# CONSTANTS
#===================================================================================================

# Boltzmann constant in kcal/mol/K:
#   k_B [1e-23 J/K] * N_A [1e23 1/mol] / (4184 J/kcal); the powers of ten cancel.
kB = 1.3806503 * 6.0221415 / 4184.0

#===================================================================================================
# PARAMETERS
#===================================================================================================

# Input data locations.
data_directory = '../../datasets/alanine-dipeptide/parallel-tempering' # directory containing the parallel tempering data
temperature_list_filename = os.path.join(data_directory, 'temps')      # temperatures, in K
total_energies_filename = os.path.join(data_directory, 'etot.out')     # total energies (kcal/mol) for each temperature and snapshot

# Trajectory layout.
trajectory_segment_length = 200 # snapshots per contiguous trajectory segment
ntrajectories = 501             # trajectory segments to use

# Lag time and estimation options.
tau = 60                    # lag time, in snapshots
tau_unit = 0.1              # sampling time (time between snapshots), in ps
ninterpolate = 1            # interpolated temperatures inserted between each adjacent pair of parallel tempering temperatures
counting_method = 'sliding' # how transitions are counted to compute Cij: 'single' or 'sliding'

# Output.
netcdf_output_filename = 'output/alanine-dipeptide-transition-rate.nc'
# Alternative configuration (disabled): include the analytical momentum contribution to the partition function.
#use_analytical_momentum = True
#netcdf_output_filename = 'output/alanine-dipeptide-transition-rate-analytical-momentum.nc'

# Number of degrees of freedom.
# NOTE(review): presumably 22 solute atoms with 21 constraints, plus 431 rigid 3-site waters
# (3 constraints each), minus 3 for centre-of-mass motion -- confirm. Not referenced elsewhere in this file.
ndof = 3*22 - 21 + 431*(3*3 - 3) - 3

#===================================================================================================
# SUBROUTINES
#===================================================================================================

def write_file(filename, contents):
   """Write the specified contents to a file, replacing any existing file.

   ARGUMENTS
     filename (string) - the file to be written
     contents (string or list of strings) - the contents of the file to be written;
       list elements are written one after another with no separator added

   RAISES
     TypeError - if contents is neither a string nor a list

   """

   # Validate before opening, so an unsupported type does not truncate an existing file.
   # (The original raised a bare string, which is itself an error in modern Python.)
   if not isinstance(contents, (list, str)):
      raise TypeError("Type for 'contents' not supported: " + repr(type(contents)))

   # The context manager closes the file even if a write fails.
   with open(filename, 'w') as outfile:
      if isinstance(contents, list):
         outfile.writelines(contents)
      else:
         outfile.write(contents)

   return

def read_file(filename):
   """Read contents of the specified file.

   ARGUMENTS
     filename (string) - the name of the file to be read

   RETURNS
     lines (list of strings) - the contents of the file, split by line (line terminators retained)

   """

   # The context manager closes the file even if reading raises.
   with open(filename, 'r') as infile:
      return infile.readlines()

def logSum(log_terms):
   """Compute the log of a sum of terms whose logarithms are provided.

   The largest log term is factored out before exponentiating, so very large or very
   small terms neither overflow nor underflow.

   REQUIRED ARGUMENTS
      log_terms is the array_like (possibly multidimensional) containing the logs of the terms to be summed.

   RETURN VALUES
      log_sum is the log of the sum of the terms; -inf when every term is zero (all logs are -inf).

   """

   log_terms = numpy.asarray(log_terms, dtype=numpy.float64)

   # compute the maximum argument
   max_log_term = log_terms.max()

   # If every term is zero, log_terms - max_log_term would be (-inf) - (-inf) = nan;
   # the sum is exactly zero, so its log is -inf.
   if numpy.isneginf(max_log_term):
      return max_log_term

   # compute the reduced terms and the log sum
   terms = numpy.exp(log_terms - max_log_term)
   return numpy.log(terms.sum()) + max_log_term

def kronecker(i,j):
   """Kronecker delta: return 1 when i == j and 0 otherwise."""

   return 1 if i == j else 0

def delta(a,b,i,j):
   """Symmetric pair indicator.

   Return 1 when the unordered pair {a, b} equals the unordered pair {i, j},
   i.e. (a, b) == (i, j) or (a, b) == (j, i); otherwise return 0.
   """

   # Same order.
   if (a == i) and (b == j):
      return 1
   # Swapped order.
   return 1 if (a == j) and (b == i) else 0


#===================================================================================================
# MAIN
#===================================================================================================

#===================================================================================================
# Initialize the NetCDF file for output of computed data objects.
#===================================================================================================

# Open the NetCDF output file that collects the computed data objects.
print("Opening NetCDF file for writing...")
netcdf_file = netcdf.Dataset(netcdf_output_filename, 'w', format='NETCDF3_CLASSIC')

# Global attributes: when the data were produced and by which script.
netcdf_file.title = "Analysis data produced at %s" % datetime.datetime.now().ctime()
netcdf_file.application = 'temperature-dependent-transition-probabilities.py'

#===================================================================================================
# Read temperatures
#===================================================================================================

# Read list of temperatures (whitespace-separated values, in K, on the first line of the file).
lines = read_file(temperature_list_filename)
# Construct list of temperatures
temperatures = lines[0].split()
# Create numpy array of temperatures.
# Double precision is used because beta_k enters the reduced potentials u_kln = beta_k * E_kn handed to MBAR,
# and so that beta_k agrees with the float64 inverse temperatures computed later during reweighting.
K = len(temperatures)
temperature_k = zeros([K], float64) # temperature_k[k] is temperature of temperature index k in K
for k in range(K):
   temperature_k[k] = float(temperatures[k])
# Compute inverse temperatures, in 1/(kcal/mol).
beta_k = 1.0 / (kB * temperature_k)

# Define other constants
N = ntrajectories              # number of trajectory segments per temperature
T = trajectory_segment_length  # number of snapshots per trajectory segment

# Store in netcdf.
netcdf_file.createDimension('K', K)             # number of temperatures
netcdf_file.createDimension('N', N)             # number of trajectories per temperature
netcdf_file.createDimension('T', T)             # number of snapshots per trajectory

variable = netcdf_file.createVariable('temperature_k', 'd', ('K',))
setattr(variable, 'units', 'Kelvin')
setattr(variable, 'description', 'temperature_k[k] is the temperature of temperature index k')
netcdf_file.variables['temperature_k'][:] = temperature_k

variable = netcdf_file.createVariable('beta_k', 'f', ('K',))
setattr(variable, 'units', '1/(kcal/mol)')
setattr(variable, 'description', 'beta_k[k] is the inverse temperature of temperature index k')
netcdf_file.variables['beta_k'][:] = beta_k

#===================================================================================================
# Read energies
#===================================================================================================

# Read total energies and average them over each trajectory segment.
# Line n*T + t of the energy file holds snapshot t of trajectory segment n; its first K columns are
# the total energies (kcal/mol) at each temperature index.

print("Reading total energies...")
lines = read_file(total_energies_filename)
E_kn = zeros([K,N], float64) # E_kn[k,n] is the mean total energy over trajectory segment n of temperature index k
for n in range(N):
   # E_kt[k,t] is the total energy of snapshot t of this segment at temperature index k.
   E_kt = zeros([K,T], float64)
   for t in range(T):
      fields = lines[n*T + t].split()
      E_kt[:,t] = [float(fields[k]) for k in range(K)]
   # Only the mean over the segment is kept.
   E_kn[:,n] = [E_kt[k,:].mean() for k in range(K)]

# Store in netcdf.
variable = netcdf_file.createVariable('E_kn', 'd', ('K', 'N',))
variable.units = 'kcal/mol'
variable.description = 'E_kn[k,n] is the total energy of trajectory segment n from temperature index k'
netcdf_file.variables['E_kn'][:,:] = E_kn

#===================================================================================================
# Read phi, psi trajectories
#===================================================================================================

print("Reading phi, psi trajectories...")
# phi_knt[k,n,t] / psi_knt[k,n,t] are the phi / psi angles (in degrees) for snapshot t of trajectory segment n of temperature k
phi_knt = zeros([K,N,T], float32)
psi_knt = zeros([K,N,T], float32)
for k in range(K):
   filename = os.path.join(data_directory, '%d.tors' % (k))
   lines = read_file(filename)
   print("k = %d, %d lines read" % (k, len(lines)))
   # The first line of each .tors file is skipped, so snapshot t of segment n is on line n*T + t + 1;
   # columns 3 and 4 (0-based) hold phi and psi.
   for n in range(N):
      for t in range(T):
         fields = lines[1 + n*T + t].split()
         phi_knt[k,n,t], psi_knt[k,n,t] = float(fields[3]), float(fields[4])

#===================================================================================================
# Assign discrete state identities
#===================================================================================================

print("Assigning states...")
nstates = 2 # number of discrete states

# state_knt[k,n,t] is the state index (0...1) of snapshot t of trajectory segment n of temperature k,
# assigned from psi alone:
#   1 = alpha_R : -124 <= psi < 28
#   0 = C7_eq   : every other snapshot (psi >= 28 or psi < -124)
state_knt = ((psi_knt >= -124) & (psi_knt < 28)).astype(int16)

# Alternative state definitions (disabled), using both phi and psi:
#state_knt[((phi_knt >= 0  ) & (phi_knt <  117)) & ((psi_knt >=   -5) & (psi_knt < 111))] = 0 # alpha_L (state 6)
#state_knt[((phi_knt >= 0  ) & (phi_knt <  117)) & ((psi_knt >=  111) | (psi_knt <  -5))] = 0 # C_7^ax (state 5)
#state_knt[((phi_knt >= 117) | (phi_knt < -105)) & ((psi_knt >= -124) & (psi_knt <  28))] = 1 # alpha_P (state 3)
#state_knt[((phi_knt >= 117) | (phi_knt < -105)) & ((psi_knt >=  28) | (psi_knt < -124))] = 1 # C_5 (state 1)

#state_knt[((phi_knt >= 0  ) & (phi_knt <  117))] = 0 # 5 + 6
#state_knt[((phi_knt >= 117) | (phi_knt < 0))] = 1 # 1 + 2 + 3 + 4

# Create netcdf dimension.
netcdf_file.createDimension('nstates', nstates) # number of states

#===================================================================================================
# Compute observable for correlation function
#===================================================================================================

print("Computing observables Akn...")

# Validate the counting options up front instead of failing part-way through the loop
# (the original raised a bare string, which is itself an error in modern Python).
if counting_method not in ('sliding', 'single'):
   raise ValueError("counting_method '%s' unrecognized." % counting_method)
if not (0 <= tau < T):
   raise ValueError("lag time tau = %d snapshots must satisfy 0 <= tau < T = %d" % (tau, T))

# A_ijkn[i,j,k,n] is the fraction of (t, t+tau) snapshot pairs in trajectory segment n of temperature index k
# that start in state i and end in state j, or vice versa. Off-diagonal entries are halved, so
# A_ijkn[i,j] == A_ijkn[j,i] and the entries for each (k,n) sum to one.
A_ijkn = zeros([nstates, nstates, K, N], float32)
start_knt = state_knt[:,:,0:(T-tau)] # time origins used by the 'sliding' method
end_knt = state_knt[:,:,tau:T]       # the snapshots one lag time later
for i in range(nstates):
   for j in range(nstates):
      if counting_method == 'sliding':
         # Average over every time origin t = 0, ..., T-tau-1 of the segment.
         A_ijkn[i,j,:,:] = mean(((start_knt==i) & (end_knt==j)) | ((start_knt==j) & (end_knt==i)), axis=2)
      else:
         # 'single': only the first snapshot of the segment is used as the time origin.
         A_ijkn[i,j,:,:] = ((state_knt[:,:,0]==i) & (state_knt[:,:,tau]==j)) | ((state_knt[:,:,0]==j) & (state_knt[:,:,tau]==i))
      # The either-direction indicator counts an i<->j pair in both A_ijkn[i,j] and A_ijkn[j,i]; halve it.
      if (i != j): A_ijkn[i,j,:,:] /= 2

# Store in netcdf.
variable = netcdf_file.createVariable('A_ijkn', 'd', ('nstates', 'nstates', 'K', 'N',))
setattr(variable, 'units', 'none')
setattr(variable, 'description', 'A_ijkn[i,j,k,n] is the fractional count of transitions observed from state i to state j in either direction in trajectory segment n of temperature index k')
netcdf_file.variables['A_ijkn'][:,:,:,:] = A_ijkn

#===================================================================================================
# Estimate maximum likelihood transition matrix at each temperature separately.
#===================================================================================================

print("Estimating transition probabilities...")
# C_ijk[i,j,k] is the estimated correlation function <chi_i(0) chi_j(tau)> (symmetrized over direction):
# the average of A_ijkn over the trajectory segments at temperature index k only.
C_ijk = mean(A_ijkn, 3)
# P_ik[i,k] is the estimated population of state i at temperature index k (row sums of C_ijk).
P_ik = sum(C_ijk, 1)
# T_ijk[i,j,k] = C_ijk[i,j,k] / P_ik[i,k] is the row-normalized transition matrix at temperature index k.
T_ijk = zeros([nstates, nstates, K], float64)
T_ijk[:,:,:] = C_ijk / P_ik[:,numpy.newaxis,:]

# DEBUG
for k in range(K):
   print("Temperature %d : %f K" % (k, temperature_k[k]))
   print(T_ijk[:,:,k])

# Store in netcdf.
for (name, values, description) in [
      ('C_ijk', C_ijk, 'C_ijk[i,j,k] is the probability of observing a transition from i to j in either direction from data only at temperature index k'),
      ('T_ijk', T_ijk, 'T_ijk[i,j,k] is the conditional probability of observing a transition to state j from state i estimated from data only at temperature index k'),
      ]:
   variable = netcdf_file.createVariable(name, 'd', ('nstates', 'nstates', 'K',))
   setattr(variable, 'units', 'none')
   setattr(variable, 'description', description)
   netcdf_file.variables[name][:,:,:] = values

#===================================================================================================
# Compute estimate of transition rates from counting.
#===================================================================================================

if (nstates == 2):
   # THE FOLLOWING IS SPECIFIC TO TWO-STATE CASES
   # Estimate the phenomenological rate at each temperature from that temperature's data alone.

   variable = netcdf_file.createVariable('rate_single_k', 'd', ('K', ))
   setattr(variable, 'units', '1/ps')
   setattr(variable, 'description', 'rate_single_k[k] is the phenomenological rate constant at temperature index k estimated from only data at that temperature')

   variable = netcdf_file.createVariable('drate_single_k', 'd', ('K', ))
   setattr(variable, 'units', '1/ps')
   setattr(variable, 'description', 'drate_single_k[k] is the uncertainty in rate_single_k[k]')

   for k in range(K):
      temperature = temperature_k[k]

      # \alpha = <\chi_1>                 estimated per segment by alpha_n = A_00 + A_01
      # \beta  = <\chi_1(0) \chi_2(\tau)>  estimated per segment by beta_n  = A_01
      alpha_n = A_ijkn[0,0,k,:]+A_ijkn[0,1,k,:]
      beta_n = A_ijkn[0,1,k,:]

      # Means, plus squared standard errors and covariance of the means
      # (this treats the N segments as independent samples).
      alpha = alpha_n.mean()
      beta = beta_n.mean()
      covariance = numpy.cov(alpha_n, beta_n)
      d2alpha = covariance[0,0] / float(N)
      d2beta = covariance[1,1] / float(N)
      dalphadbeta = covariance[0,1] / float(N)

      # Perron eigenvalue of the 2x2 transition matrix:
      #   mu = 1 - (T_12 + T_21) = 1 - \beta \alpha^{-1} (1-\alpha)^{-1}
      mu = 1 - beta / alpha / (1-alpha)
      # First-order propagation of the (alpha, beta) uncertainty into mu:
      #   dmu/dalpha = +\beta (1 - 2\alpha) \alpha^{-2} (1-\alpha)^{-2}
      #   dmu/dbeta  = -\alpha^{-1} (1-\alpha)^{-1}   (mu decreases as beta grows; the sign matters in the cross term)
      dmu_dalpha = beta * (1-2*alpha) / alpha**2 / (1-alpha)**2
      dmu_dbeta = - 1 / alpha / (1-alpha)
      d2mu = dmu_dalpha**2 * d2alpha + dmu_dbeta**2 * d2beta + 2*dmu_dalpha*dmu_dbeta * dalphadbeta

      # Phenomenological transition rate k = - log(mu) / tau, in 1/ps; only defined for mu > 0.
      if (mu > 0.0):
         rate = - log(mu) / (tau * tau_unit) # phenomenological rate constant
         d2rate = d2mu / (mu**2) / (tau * tau_unit)**2
         print("rate at %6.1f K : %24.12f +- %24.12f" % (temperature, rate, sqrt(d2rate)))
         netcdf_file.variables['rate_single_k'][k] = rate
         netcdf_file.variables['drate_single_k'][k] = sqrt(d2rate)
      else:
         # Cannot estimate rate.
         print("rate at %6.1f K : cannot be estimated (mu = %f)" % (temperature, mu))
         netcdf_file.variables['rate_single_k'][k] = 0.0
         netcdf_file.variables['drate_single_k'][k] = 0.0

#===================================================================================================
# Compute reduced potential energy of each trajectory segment in each condition
#===================================================================================================

print("Computing energies...")
# u_kln[k,l,n] is the reduced potential energy of trajectory segment n of temperature k evaluated at
# temperature l, i.e. beta_k[l] * E_kn[k,n]; broadcasting over (k, l, n) replaces the explicit triple loop.
u_kln = zeros([K,K,N], float64)
u_kln[:,:,:] = beta_k[numpy.newaxis,:,numpy.newaxis] * E_kn[:,numpy.newaxis,:]

# Every temperature contributes the same number N of trajectory segments.
N_k = zeros([K], int32)
N_k.fill(N)

#===================================================================================================
# Read in guess of free energies
#===================================================================================================

# Reuse free energies from a previous run as the MBAR starting guess, if the file exists.
free_energy_filename = 'f_k.dat'
f_k_initial = None
if os.path.exists(free_energy_filename):
   print("Reading free energies from file %s..." % free_energy_filename)
   lines = read_file(free_energy_filename)
   # Treat the whole file as one whitespace-separated stream of values, regardless of line layout.
   elements = " ".join([line.strip() for line in lines]).split()
   f_k_initial = zeros([K], float64)
   for k in range(K):
      f_k_initial[k] = float(elements[k])
   print(f_k_initial)

#===================================================================================================
# Initialize MBAR to compute free energy estimates
#===================================================================================================

# Initialize MBAR
print("Initializing MBAR (will estimate free energy differences first time)...")
mbar = pymbar.MBAR(u_kln, N_k, method = 'self-consistent-iteration', verbose = True, initial_f_k = f_k_initial, relative_tolerance = 1.0e-8)

print("Converged free energies = ")
print(mbar.f_k)

# Write converged free energies, one per line, so a later run can read them back as its initial guess.
# The context manager closes the file even if a write fails.
with open(free_energy_filename, 'w') as outfile:
   for k in range(K):
      outfile.write('%30.20e\n' % mbar.f_k[k])

# Store converged free energies in netcdf.
variable = netcdf_file.createVariable('f_k', 'd', ('K',))
setattr(variable, 'units', 'none')
setattr(variable, 'description', 'f_k[k] is the dimensionless free energy of state k, up to an irrelevant additive constant')
netcdf_file.variables['f_k'][:] = mbar.f_k

#===================================================================================================
# Compute transition probabilities as a function of energy.
#===================================================================================================

print("Constructing energy histograms for productive trajectories...")

M = 50 # total energy histogram bins
Emin = E_kn.min() # minimum total energy
Emax = E_kn.max() # maximum total energy
SMALL = 1.0e-3 # widens the histogram range so that Emax falls strictly inside the last (half-open) bin
delta_E = (Emax + SMALL - Emin) / float(M) # width of energy bin
print("Energy bin width for %d histogram bins spanning %.3f to %.3f kcal/mol is %.3f kcal/mol" % (M, Emin, Emax, delta_E))

# Construct energy bin centers: bin m covers [Emin + m*delta_E, Emin + (m+1)*delta_E) and E_m[m] is its midpoint.
# Centers must be spaced exactly delta_E apart. linspace(Emin + delta_E/2, Emax - delta_E/2, M) spaced them
# slightly closer, so adjacent bins overlapped and the last bin ended at Emax, excluding the highest-energy segment.
E_m = Emin + delta_E * (numpy.arange(M) + 0.5)

# Compute histogram
A_ijm = zeros([nstates, nstates, M])
for m in range(M):
   # Select all trajectories in this bin using boolean selection criteria.
   selection = (E_m[m] - delta_E/2 <= E_kn) & (E_kn < E_m[m] + delta_E/2)
   nselected = selection.sum()
   print("m = %5d, nselected = %d" % (m, nselected))
   # Accumulate histogram statistics.
   for i in range(nstates):
      for j in range(nstates):
         # Compute mean probability of observing (i,j) transition in either direction for trajectory energies in this bin.
         if nselected > 0:
            A_ijm[i,j,m] = A_ijkn[i,j,:,:][selection].sum() / nselected
         else:
            # Empty bin: the mean is undefined.
            A_ijm[i,j,m] = numpy.nan

# Store in netcdf.
# netCDF dimensions carry only a name and a length; the meaning of M is documented on E_m.
dimension = netcdf_file.createDimension('M', M) # number of histogram bins

variable = netcdf_file.createVariable('E_m', 'd', ('M',))
setattr(variable, 'units', 'kcal/mol')
setattr(variable, 'description', 'E_m[m] is the midpoint of energy bin m')
netcdf_file.variables['E_m'][:] = E_m

variable = netcdf_file.createVariable('A_ijm', 'd', ('nstates','nstates','M'))
setattr(variable, 'units', 'none')
setattr(variable, 'description', 'A_ijm[i,j,m] is the fraction of trajectories observed to initiate in state i and terminate in state j (or vice-versa) in total energy bin m')
netcdf_file.variables['A_ijm'][:,:,:] = A_ijm

#===================================================================================================
# Compute total contribution weight to each transition from various temperatures.
#===================================================================================================

print "Computing contribution to each transition from various temperatures..."

# Log-space bookkeeping of how much the segments sampled at each temperature l contribute, after
# reweighting, to estimates at each temperature k:
#   log_w_ijkl : per-segment log weights combined with log A_ijkn[i,j,l,n] and log-summed over segments n
#   log_w_ikl  : log_w_ijkl log-summed over j
#   log_w_kl   : log_w_ikl log-summed over i
log_w_ijkl = zeros([nstates, nstates, K, K], float64) # log_w_ijkl[i,j,k,l] is the contribution from temperature l to the estimation of transitions between i and j at temperature k
log_w_ikl = zeros([nstates, K, K], float64) # log_w_ikl[i,k,l] is the contribution from temperature l to the estimation of occupation probability of state i at temperature k
log_w_kl = zeros([K, K], float64) # log_w_kl[k,l] is the contribution from temperature l to the estimation of any expectations at temperature k

for k in range(K):
   # Compute trajectory log weights.
   # NOTE(review): _computeUnnormalizedLogWeights is a private pymbar method; it presumably returns the
   # unnormalized log weight of every stored segment (indexed [l,n] like E_kn, as used below) for the
   # reduced potential beta_k[k] * E_kn -- confirm against the installed pymbar version.
   log_w_kn = mbar._computeUnnormalizedLogWeights(beta_k[k] * E_kn)
   # Stand-in for log(0): 1000 below every actual log weight, so it is negligible inside a logsum.
   LOG_ZERO = log_w_kn.min() - 1000
   # Sum up log weights by transition (i,j) from each temperature l.
   for l in range(K):
      for i in range(nstates):
         for j in range(nstates):
            # Only segments with a nonzero observable are included (the log of 0 would be -inf).
            selection = (A_ijkn[i,j,l,:] > 0)
            if (selection.sum() > 0):
               # Assuming pymbar.logsum computes log(sum(exp(x))): log sum_n w_n * A_ijkn[i,j,l,n].
               log_w_ijkl[i,j,k,l] = pymbar.logsum(log_w_kn[l,selection] + numpy.log(A_ijkn[i,j,l,selection]))
            else:
               log_w_ijkl[i,j,k,l] = LOG_ZERO
         log_w_ikl[i,k,l] = pymbar.logsum(log_w_ijkl[i,:,k,l])
      log_w_kl[k,l] = pymbar.logsum(log_w_ikl[:,k,l])

# Store in netcdf.
variable = netcdf_file.createVariable('log_w_ijkl', 'd', ('nstates', 'nstates', 'K', 'K',))
setattr(variable, 'units', 'none')
setattr(variable, 'description', 'log_w_ijkl[i,j,k,l] is the contribution from temperature l to the estimation of transitions between i and j at temperature k')
netcdf_file.variables['log_w_ijkl'][:,:,:,:] = log_w_ijkl

variable = netcdf_file.createVariable('log_w_ikl', 'd', ('nstates', 'K', 'K',))
setattr(variable, 'units', 'none')
setattr(variable, 'description', 'log_w_ikl[i,k,l] is the contribution from temperature l to the estimation of occupation probability of state i at temperature k')
netcdf_file.variables['log_w_ikl'][:,:,:] = log_w_ikl

variable = netcdf_file.createVariable('log_w_kl', 'd', ('K', 'K',))
setattr(variable, 'units', 'none')
setattr(variable, 'description', 'log_w_kl[k,l] is the contribution from temperature l to the estimation of any expectations at temperature k')
netcdf_file.variables['log_w_kl'][:,:] = log_w_kl

#===================================================================================================
# Compute probability distribution for finding the system in a given energy bin as a function of temperature.
#===================================================================================================

print("Constructing energy probability histograms at various temperatures...")

P_km = zeros([K, M], float64) # P_km[k,m] is the probability of finding the system in energy bin m at temperature \beta_k.
for k in range(K):
   # Compute trajectory log weights (private pymbar API) for the reduced potential beta_k[k] * E_kn.
   log_w_kn = mbar._computeUnnormalizedLogWeights(beta_k[k] * E_kn)
   # Sum up log weights in each bin.
   log_P_m = zeros([M], float64) # log_P_m[m] is the unnormalized log probability of finding the system in energy bin m
   for m in range(M):
      # Select all trajectories in this energy bin.
      selection = (E_m[m] - delta_E/2 <= E_kn) & (E_kn < E_m[m] + delta_E/2)
      if selection.any():
         # Accumulate log probability weight.
         log_P_m[m] = pymbar.logsum(log_w_kn[selection])
      else:
         # Empty bin: probability zero; pymbar.logsum is not called on an empty selection.
         log_P_m[m] = -numpy.inf
   # Normalize (subtract the maximum before exponentiating to avoid overflow).
   P_km[k,:] = numpy.exp(log_P_m - log_P_m.max())
   P_km[k,:] /= sum(P_km[k,:])

# Store in netcdf.
variable = netcdf_file.createVariable('P_km', 'd', ('K', 'M',))
setattr(variable, 'units', 'none')
setattr(variable, 'description', 'P_km[k,m] is the probability of finding the system in total energy bin m at temperature index k, estimated by reweighting all data')
netcdf_file.variables['P_km'][:,:] = P_km

#===================================================================================================
# Compute optimal estimates of unnormalized correlation functions at multiple temperatures.
#===================================================================================================

# Define list of temperatures at which we want to evaluate rates: every parallel tempering temperature,
# plus ninterpolate linearly spaced temperatures inserted between each adjacent pair.
Kinterpolated = K + ninterpolate*(K-1)
stride = ninterpolate + 1 # index distance between consecutive parallel tempering temperatures
interpolated_temperatures = zeros([Kinterpolated], float64)
for k in range(K-1):
   lower = temperature_k[k]
   upper = temperature_k[k+1]
   # anchor at the parallel tempering temperature
   interpolated_temperatures[k*stride] = lower
   # linear interpolation of intermediate temperatures
   for i in range(1, stride):
      interpolated_temperatures[k*stride + i] = (upper - lower) / stride * i + lower
# add last anchored parallel tempering temperature
interpolated_temperatures[(K-1)*stride] = temperature_k[K-1]

# DEBUG
print("interpolated_temperatures = ")
print(interpolated_temperatures)

# Create netcdf storage for the estimates obtained by reweighting to the interpolated temperatures.
netcdf_file.createDimension('Kinterpolated', Kinterpolated)             # number of interpolated temperatures at which transition matrix is estimated

variable = netcdf_file.createVariable('interpolated_temperature_k', 'd', ('Kinterpolated', ))
variable.units = 'Kelvin'
variable.description = 'interpolated_temperature_k[k] is the set of interpolated temperatures at which the transition matrix Tr_ijk is estimated'
netcdf_file.variables['interpolated_temperature_k'][:] = interpolated_temperatures

# (name, dimensions, units, description) of each output filled in per interpolated temperature.
output_specs = [
   ('Tr_ijk', ('nstates', 'nstates', 'Kinterpolated', ), 'none', 'Tr[i,j,k] is the conditional transition probability from i to j at temperature index k estimated by reweighting all data'),
   ('dTr_ijk', ('nstates', 'nstates', 'Kinterpolated', ), 'none', 'dTr[i,j,k] is the estimated statistical uncertainty (one standard deviation) of Tr_ijk[i,j,k]'),
   ]
if (nstates == 2):
   # THE FOLLOWING ARE SPECIFIC TO TWO-STATE CASE
   # NOTE(review): K_ijk is filled with rate * Pi, so its units are presumably 1/ps rather than 'none'.
   output_specs += [
      ('rate_k', ('Kinterpolated', ), '1/ps', 'rate_k[k] is the phenomenological rate constant at interpolated temperature index k estimated from reweighting all data'),
      ('drate_k', ('Kinterpolated', ), '1/ps', 'drate_k[k] is the uncertainty in rate_k[k]'),
      ('mu_k', ('Kinterpolated', ), 'none', 'mu_k[k] is Perron eigenvalue for interpolated temperature index k'),
      ('dmu_k', ('Kinterpolated', ), 'none', 'dmu_k[k] is the uncertainty in mu_k[k]'),
      ('K_ijk', ('nstates', 'nstates', 'Kinterpolated', ), 'none', 'K_ijk[i,j,k] is the estimated rate constant from state i to state j at interpolated temperature index k'),
      ]
for (name, dimensions, units, description) in output_specs:
   variable = netcdf_file.createVariable(name, 'd', dimensions)
   setattr(variable, 'units', units)
   setattr(variable, 'description', description)

# Compute expectations by reweighting all data to each interpolated temperature.
ninterpolated_temperatures = len(interpolated_temperatures)

# All (i,j) observables, flattened row-major: observable index i*nstates + j holds A_ijkn[i,j,:,:].
# They do not depend on the target temperature, so they are formed once, outside the loop.
A_flat_ikn = numpy.array(A_ijkn.reshape(nstates*nstates, K, N), float64)

for temperature_index in range(ninterpolated_temperatures):
   temperature = interpolated_temperatures[temperature_index]
   # DEBUG
   print("Computing expectations for temperature %d/%d (%.1f K) by reweighting..." % (temperature_index, ninterpolated_temperatures, temperature))

   # Reduced potential of every stored segment at this temperature.
   beta_target = 1.0 / (kB * temperature)
   u_kn = beta_target * E_kn

   # Compute expectations of all Cij by reweighting; d2A_ij is used below as their covariance matrix.
   (A_i, d2A_ij) = mbar.computeMultipleExpectations(A_flat_ikn, u_kn)

   # Unflatten: Cij[i,j] is the estimate of <A_ij> at this temperature, dCijdCkl[i,j,k,l] = cov(Cij, Ckl).
   Cij = numpy.array(A_i, float64).reshape(nstates, nstates)
   dCijdCkl = numpy.array(d2A_ij, float64).reshape(nstates, nstates, nstates, nstates)

   # Stationary probabilities Pi[i] = sum_j Cij, and the row-normalized transition matrix Tij = Cij / Pi[i].
   Pi = Cij.sum(1)
   Tij = Cij / Pi[:, numpy.newaxis]
   netcdf_file.variables['Tr_ijk'][:,:,temperature_index] = Tij
   print("Pi = ")
   print(Pi)

   # dTabdCij[a,b,i,j] is the derivative of Tab with respect to Cij:
   #   dTab/dCij = delta(a,b,i,j) / Pi[a] - Tab / Pi[a] * sum_g delta(a,g,i,j)
   # NOTE(review): delta() is symmetric in (i,j), and A_ijkn (hence Cij) is symmetric, so C_01 and C_10 are the
   # same quantity, yet each receives the full derivative and both enter the sum below. This may double-count
   # off-diagonal contributions to dTij. TODO confirm.
   dTabdCij = zeros([nstates, nstates, nstates, nstates], float64)
   for a in range(nstates):
      for b in range(nstates):
         for i in range(nstates):
            for j in range(nstates):
               dTabdCij[a,b,i,j] = delta(a,b,i,j) / Pi[a]
               for g in range(nstates):
                  dTabdCij[a,b,i,j] += - Tij[a,b]/Pi[a] * delta(a,g,i,j)

   # First-order propagation: d2Tij[a,b] = sum_{ijkl} dTab/dCij * dTab/dCkl * cov(Cij, Ckl).
   d2Tij = numpy.einsum('abij,abkl,ijkl->ab', dTabdCij, dTabdCij, dCijdCkl)
   dTij = numpy.sqrt(d2Tij)
   netcdf_file.variables['dTr_ijk'][:,:,temperature_index] = dTij

   if (nstates == 2):
      # THE FOLLOWING IS SPECIFIC TO TWO-STATE CASES

      # \alpha = <\chi_1>                 observable A_00 + A_01
      # \beta  = <\chi_1(0) \chi_2(\tau)>  observable A_01
      A_ab_ikn = zeros([2, K, N], float64)
      A_ab_ikn[0,:,:] = sum(A_ijkn[0,:,:,:],axis=0) # alpha
      A_ab_ikn[1,:,:] = A_ijkn[0,1,:,:] # beta

      # Compute expectations by reweighting.
      (A_ab, d2A_ab) = mbar.computeMultipleExpectations(A_ab_ikn, u_kn)

      # Extract alpha, beta, and their uncertainties
      alpha = A_ab[0]
      beta = A_ab[1]
      d2alpha = d2A_ab[0,0]
      d2beta = d2A_ab[1,1]
      dalphadbeta = d2A_ab[0,1]

      # Perron eigenvalue of the 2x2 transition matrix:
      #   mu = 1 - (T_12 + T_21) = 1 - \beta \alpha^{-1} (1-\alpha)^{-1}
      mu = 1 - beta / alpha / (1-alpha)
      # First-order propagation of the (alpha, beta) uncertainty into mu:
      #   dmu/dalpha = \beta (1 - 2\alpha) \alpha^{-2} (1-\alpha)^{-2}
      #   dmu/dbeta  = -\alpha^{-1} (1-\alpha)^{-1}
      #   d2mu = (dmu/dalpha)^2 var(alpha) + (dmu/dbeta)^2 var(beta) + 2 (dmu/dalpha)(dmu/dbeta) cov(alpha, beta)
      dmu_dalpha = beta * (1-2*alpha) / alpha**2 / (1-alpha)**2
      dmu_dbeta = - 1 / alpha / (1-alpha)
      d2mu = dmu_dalpha**2 * d2alpha + dmu_dbeta**2 * d2beta + 2*dmu_dalpha*dmu_dbeta * dalphadbeta
      netcdf_file.variables['mu_k'][temperature_index] = mu
      netcdf_file.variables['dmu_k'][temperature_index] = sqrt(d2mu)

      # Phenomenological transition rate k = - log(mu) / tau, in 1/ps; only defined for mu > 0.
      if (mu > 0.0):
         rate = - log(mu) / (tau * tau_unit) # phenomenological rate constant
         d2rate = d2mu / (mu**2) / (tau * tau_unit)**2
         print("rate at %6.1f K : %24.12f +- %24.12f" % (temperature, rate, sqrt(d2rate)))
      else:
         # Cannot estimate rate; store zeros, as for the single-temperature estimates.
         print("rate at %6.1f K : cannot be estimated (mu = %f)" % (temperature, mu))
         rate = 0.0
         d2rate = 0.0
      netcdf_file.variables['rate_k'][temperature_index] = rate
      netcdf_file.variables['drate_k'][temperature_index] = sqrt(d2rate)

      # Partitioned transition rates: k_01 = rate * Pi[1] and k_10 = rate * Pi[0].
      netcdf_file.variables['K_ijk'][0,1,temperature_index] = rate * Pi[1]
      netcdf_file.variables['K_ijk'][1,0,temperature_index] = rate * Pi[0]
      # TODO: Compute uncertainties in partitioned transition rates.

#===================================================================================================
# Write Matlab-friendly summary of rates.
#===================================================================================================
# Matlab-friendly text summaries: one "temperature rate uncertainty" row per temperature.
# The rate variables exist in the netcdf file only for the two-state case.
if (nstates == 2):
   rates_filename = 'output/alanine-dipeptide-observed-rates.dat'
   with open(rates_filename, 'w') as outfile:
      for k in range(K):
         outfile.write('%8.3f %24e %24e\n' % (temperature_k[k], netcdf_file.variables['rate_single_k'][k], netcdf_file.variables['drate_single_k'][k]))

   rates_filename = 'output/alanine-dipeptide-interpolated-rates.dat'
   with open(rates_filename, 'w') as outfile:
      for k in range(ninterpolated_temperatures):
         outfile.write('%8.3f %24e %24e\n' % (interpolated_temperatures[k], netcdf_file.variables['rate_k'][k], netcdf_file.variables['drate_k'][k]))

#===================================================================================================
# Close NetCDF file.
#===================================================================================================

print("Done.\nErrors after this point can be ignored.\n\n")
netcdf_file.close()

