#!/usr/bin/python

# Compute temperature-dependent transition rate for alanine dipeptide as a function of temperature from parallel tempering data.

#=============================================================================================
# REFERENCE
#
# Chodera JD, Dill KA, and Pande VS. Dynamical reweighting: Improved estimates of dynamical properties
# from combining data at multiple temperatures.
#=============================================================================================

#=============================================================================================
# REQUIREMENTS
#
# This code requires the 'pynetcdf' package, containing the Scientific.IO.NetCDF package built for numpy.
#
# http://pypi.python.org/pypi/pynetcdf/
# http://sourceforge.net/project/showfiles.php?group_id=1315&package_id=185504
#
# This code also uses the 'MBAR' package, implementing the multistate Bennett acceptance ratio estimator, available here:
#
# http://www.simtk.org/home/pymbar
#=============================================================================================

#===================================================================================================
# IMPORTS
#===================================================================================================
import numpy
from numpy import *
from math import *
import MBAR # for MBAR analysis
import timeseries # for timeseries analysis
import commands
import os
import os.path
from numpy.linalg import * # linear algebra methods
import datetime # time and date
from pynetcdf import NetCDF # for writing of data objects for plotting in Matlab or Mathematica

#===================================================================================================
# CONSTANTS
#===================================================================================================

# Boltzmann constant in kcal/mol/K: k_B (1.3806503e-23 J/K) times Avogadro's number (6.0221415e23 /mol),
# whose powers of ten cancel, divided by 4184 J/kcal.
kB = 1.3806503 * 6.0221415 / 4184.0 # Boltzmann constant in kcal/mol/K

#===================================================================================================
# PARAMETERS
#===================================================================================================

data_directory = '../../datasets/alanine-dipeptide/parallel-tempering' # directory containing the parallel tempering data
temperature_list_filename = os.path.join(data_directory, 'temps') # file whose first line lists the temperatures in K
total_energies_filename = os.path.join(data_directory, 'etot.out') # total energies (in kcal/mol): one line per snapshot, one column per temperature index
trajectory_segment_length = 200 # number of snapshots in each contiguous trajectory segment
ntrajectories = 501 # number of trajectory segments to use per temperature
Tmax = 100 # maximum lag time (in snapshots) for which transition probabilities are computed; should be less than trajectory_segment_length
ninterpolate = 1 # number of interpolated temperatures to add in between parallel tempering temperatures for estimation of rates
counting_method = 'sliding' # how to count transitions to compute Cij: 'sliding' uses every time origin in a segment, 'single' only the first snapshot
netcdf_output_filename = 'output/alanine-dipeptide-reactive-flux.nc' # netcdf output filename

#===================================================================================================
# SUBROUTINES
#===================================================================================================

def write_file(filename, contents):
   """Write the specified contents to a file.

   ARGUMENTS
     filename (string) - the file to be written
     contents (string, or list/tuple of strings) - the contents of the file to be written;
       sequence elements are written back to back, so each should carry its own newline

   RAISES
     TypeError - if contents is neither a string nor a list/tuple; raised before the file is
       opened, so an existing file of the same name is left untouched

   """

   # Validate before opening, so a bad argument does not truncate an existing file.
   if isinstance(contents, (list, tuple)):
      pieces = contents
   elif isinstance(contents, str):
      pieces = [contents]
   else:
      raise TypeError("Type for 'contents' not supported: " + repr(type(contents)))

   # 'with' guarantees the file is closed even if a write fails.
   with open(filename, 'w') as outfile:
      for piece in pieces:
         outfile.write(piece)

   return

def read_file(filename):
   """Read contents of the specified file.

   ARGUMENTS
     filename (string) - the name of the file to be read

   RETURNS
     lines (list of strings) - the contents of the file, split by line (line terminators are kept)

   """

   # 'with' closes the file even if reading raises.
   with open(filename, 'r') as infile:
      lines = infile.readlines()

   return lines

def logSum(log_terms):
   """Compute the log of a sum of terms whose logarithms are provided.

   Uses the max-shift identity log(sum(exp(x))) = m + log(sum(exp(x - m))), with m = max(x), so that
   large log terms do not overflow exp().

   REQUIRED ARGUMENTS
      log_terms is the array (possibly multidimensional, or any sequence numpy.asarray accepts)
      containing the logs of the terms to be summed.

   RETURN VALUES
      log_sum is the log of the sum of the terms. An empty input, or one whose terms are all -inf,
      yields -inf (the log of a zero sum). If the largest log term is +inf or nan, it is returned as is.

   """

   log_terms = numpy.asarray(log_terms, dtype=numpy.float64)

   # The sum of no terms is zero, whose log is -inf.
   if log_terms.size == 0:
      return -numpy.inf

   # compute the maximum argument
   max_log_term = log_terms.max()

   # Shifting by a non-finite maximum would produce inf - inf = nan; the maximum is already the answer
   # (-inf: every term is zero; +inf: the sum diverges; nan: propagate).
   if not numpy.isfinite(max_log_term):
      return max_log_term

   # compute the reduced terms (all <= 1, so exp cannot overflow)
   terms = numpy.exp(log_terms - max_log_term)

   # compute and return the log sum
   return numpy.log(terms.sum()) + max_log_term

def kronecker(i,j):
   """Kronecker delta: return 1 when i equals j, and 0 otherwise."""

   return 1 if i == j else 0

def delta(a,b,i,j):
   """Symmetric (unordered-pair) Kronecker delta.

   Return 1 if the unordered pair {a, b} matches {i, j} -- that is, (a, b) equals either (i, j)
   or (j, i) -- and 0 otherwise.
   """

   # Same orientation.
   if a == i and b == j:
      return 1
   # Reversed orientation.
   if a == j and b == i:
      return 1
   return 0


#===================================================================================================
# MAIN
#===================================================================================================

#===================================================================================================
# Initialize the NetCDF file for output of computed data objects.
#===================================================================================================

# Open the NetCDF output file for writing, creating its directory first if it does not exist.
print("Opening NetCDF file for writing...")
netcdf_output_directory = os.path.dirname(netcdf_output_filename)
if netcdf_output_directory and not os.path.exists(netcdf_output_directory):
   os.makedirs(netcdf_output_directory)
netcdf_file = NetCDF.NetCDFFile(netcdf_output_filename, 'w')

# Set global attributes.
setattr(netcdf_file, 'title', "Analysis data produced at %s" % datetime.datetime.now().ctime())
setattr(netcdf_file, 'application', 'temperature-dependent-transition-probabilities.py')

#===================================================================================================
# Read temperatures
#===================================================================================================

# Read list of temperatures: the first line of the file holds whitespace-separated temperatures in K.
lines = read_file(temperature_list_filename)
temperatures = lines[0].split()
K = len(temperatures)
# temperature_k[k] is temperature of temperature index k in K.
# Kept in double precision: it is stored as a 'd' netcdf variable, and beta_k derived from it enters
# every reduced energy handed to MBAR, so float32 rounding here would propagate into those energies.
temperature_k = numpy.array([float(temperature) for temperature in temperatures], numpy.float64)
# Compute inverse temperatures (in 1/(kcal/mol)).
beta_k = 1.0 / (kB * temperature_k)

# Define other constants
N = ntrajectories
T = trajectory_segment_length

# Store in netcdf.
netcdf_file.createDimension('K', K)             # number of temperatures
netcdf_file.createDimension('N', N)             # number of trajectories per temperature
netcdf_file.createDimension('T', T)             # number of snapshots per trajectory
netcdf_file.createDimension('Tmax', Tmax)       # number of snapshot intervals at which transition probabilities are evaluated

variable = netcdf_file.createVariable('temperature_k', 'd', ('K',))
setattr(variable, 'units', 'Kelvin')
setattr(variable, 'description', 'temperature_k[k] is the temperature of temperature index k')
netcdf_file.variables['temperature_k'][:] = temperature_k

# Stored as double to match the precision of beta_k itself.
variable = netcdf_file.createVariable('beta_k', 'd', ('K',))
setattr(variable, 'units', '1/(kcal/mol)')
setattr(variable, 'description', 'beta_k[k] is the inverse temperature of temperature index k')
netcdf_file.variables['beta_k'][:] = beta_k

#===================================================================================================
# Read energies
#===================================================================================================

# Read total energies, averaging total energy over each trajectory segment.
# The energies file holds one line per snapshot, with segments stored back to back (T lines each);
# the first K whitespace-separated columns are the total energies at each temperature index.

print("Reading total energies...")
E_kn = zeros([K,N], float64) # E_kn[k,n] is the mean total energy over trajectory segment n of temperature index k
lines = read_file(total_energies_filename)
for n in range(N):
   first_line = n*T # index of the first line belonging to segment n
   # E_kt[k,t] is the total energy of snapshot t of temperature index k
   E_kt = zeros([K,T], float64)
   for t in range(T):
      elements = lines[first_line + t].split()
      E_kt[:,t] = [float(elements[k]) for k in range(K)]
   # Mean total energy over the segment, for every temperature index at once.
   E_kn[:,n] = E_kt.mean(axis=1)

# Store in netcdf.
variable = netcdf_file.createVariable('E_kn', 'd', ('K', 'N',))
setattr(variable, 'units', 'kcal/mol')
setattr(variable, 'description', 'E_kn[k,n] is the total energy of trajectory segment n from temperature index k')
netcdf_file.variables['E_kn'][:,:] = E_kn

#===================================================================================================
# Read phi, psi trajectories
#===================================================================================================

print("Reading phi, psi trajectories...")
# phi_knt[k,n,t] is phi angle (in degrees) for snapshot t of trajectory segment n of temperature k
phi_knt = zeros([K,N,T], float32)
# psi_knt[k,n,t] is psi angle (in degrees) for snapshot t of trajectory segment n of temperature k
psi_knt = zeros([K,N,T], float32)
for k in range(K):
   # One '<k>.tors' file per temperature index; segments are stored back to back, T lines each,
   # and columns 4 and 5 (indices 3 and 4) hold phi and psi.
   filename = os.path.join(data_directory, '%d.tors' % (k))
   lines = read_file(filename)
   print("k = %d, %d lines read" % (k, len(lines)))
   for n in range(N):
      # +1 skips the first line of the file (presumably a header -- TODO confirm against the data).
      segment_start = n*T + 1
      for t in range(T):
         elements = lines[segment_start + t].split()
         phi_knt[k,n,t] = float(elements[3])
         psi_knt[k,n,t] = float(elements[4])

#===================================================================================================
# Assign discrete state identities
#===================================================================================================

print("Assigning states...")
nstates = 2 # number of discrete conformational states
# state_knt[k,n,t] is the state index (0...1) of snapshot t of trajectory segment n of temperature k,
# assigned from the psi angle alone.
state_knt = zeros([K,N,T], int16)
in_C7_eq = (psi_knt >= 28) | (psi_knt < -124)      # state 0: C7_eq
in_alpha_R = (psi_knt >= -124) & (psi_knt < 28)    # state 1: alpha_R
state_knt[in_C7_eq] = 0
state_knt[in_alpha_R] = 1

# Create netcdf dimension.
netcdf_file.createDimension('nstates', nstates) # number of states

#===================================================================================================
# Compute reduced total energy of each trajectory segment in each condition
#===================================================================================================

print("Computing energies...")
# u_kln[k,l,n] is reduced total energy of trajectory segment n of temperature k evaluated at temperature l,
# i.e. beta_k[l] * E_kn[k,n]. Formed by broadcasting rather than a K*K*N pure-Python loop.
u_kln = zeros([K,K,N], float64)
u_kln[:,:,:] = beta_k[numpy.newaxis,:,numpy.newaxis] * E_kn[:,numpy.newaxis,:]

# N_k[k] is the number of trajectory segments from temperature index k (N for every k).
N_k = zeros([K], int32)
N_k[:] = N

#===================================================================================================
# Read in guess of free energies
#===================================================================================================

# Read free energies cached by a previous run, if they exist, to use as the initial guess for MBAR.
# The file holds whitespace-separated values, one per temperature index.
free_energy_filename = 'f_k.dat'
f_k_initial = None
if os.path.exists(free_energy_filename):
   print("Reading free energies from file %s..." % free_energy_filename)
   elements = " ".join(read_file(free_energy_filename)).split()
   if len(elements) == K:
      f_k_initial = numpy.array([float(element) for element in elements], numpy.float64)
      print(f_k_initial)
   else:
      # A file left over from a run with a different set of temperatures cannot serve as a guess:
      # too few entries would crash, and too many would be silently truncated.
      print("Ignoring %s: expected %d free energies, found %d." % (free_energy_filename, K, len(elements)))

#===================================================================================================
# Initialize MBAR to compute free energy estimates
#===================================================================================================

# Initialize MBAR from the reduced energies u_kln and segment counts N_k; f_k_initial (None when no
# cached estimate was read) seeds the free energy iteration.
print("Initializing MBAR (will estimate free energy differences first time)...")
mbar = MBAR.MBAR(u_kln, N_k, method = 'self-consistent-iteration', verbose = True, initial_f_k = f_k_initial, relative_tolerance = 1.0e-8)

print("Converged free energies = ")
print(mbar.f_k)

# Cache the converged free energies, one per line, so a later run can start from them.
write_file(free_energy_filename, ['%30.20e\n' % mbar.f_k[k] for k in range(K)])

# Store converged free energies in netcdf.
variable = netcdf_file.createVariable('f_k', 'd', ('K',))
setattr(variable, 'units', 'none')
setattr(variable, 'description', 'f_k[k] is the dimensionless free energy of state k, up to an irrelevant additive constant')
netcdf_file.variables['f_k'][:] = mbar.f_k

#===================================================================================================
# Compute estimate of transition probabilities at each temperature from counting (including time-reverse).
#===================================================================================================

print("Computing estimate of transition probabilities by counting.")

# Validate settings up front: an unknown counting method is a configuration error, and a lag time
# of T or more would leave no snapshot pairs inside a trajectory segment.
if counting_method not in ('sliding', 'single'):
   raise ValueError("counting_method '%s' unrecognized." % counting_method)
if Tmax >= T:
   raise ValueError("Tmax (%d) must be less than trajectory_segment_length (%d)." % (Tmax, T))

C_ijkt = zeros([nstates, nstates, K, Tmax], float64) # C_ijkt[i,j,k,t] is the probability of observing a trajectory that starts in state i and ends in state j some (t+1) time intervals later at temperature index k
T_ijkt = zeros([nstates, nstates, K, Tmax], float64) # T_ijkt[i,j,k,t] is the probability of observing a trajectory end in state j given it was in state i (t+1) time intervals earlier, at temperature index k
for t in range(Tmax):
   # Define lag time
   tau = t + 1
   # Snapshot pairs separated by tau: every time origin for 'sliding', only the first snapshot for 'single'.
   if counting_method == 'sliding':
      start_knt = state_knt[:,:,0:(T-tau)]
      end_knt = state_knt[:,:,tau:T]
   else:
      start_knt = state_knt[:,:,0:1]
      end_knt = state_knt[:,:,tau:(tau+1)]
   # Accumulate estimate of unnormalized correlation function Cij(t) = <chi_i(0) chi_j(t)>,
   # symmetrized over time reversal and averaged over all trajectory segments at each temperature.
   for i in range(nstates):
      for j in range(nstates):
         observed = ((start_knt==i) & (end_knt==j)) | ((start_knt==j) & (end_knt==i))
         C_ijkt[i,j,:,t] = observed.mean(axis=2).mean(axis=1)
         # The symmetrized indicator counts both i->j and j->i, so every off-diagonal element is halved.
         # This has to happen per (i,j), inside the j loop; otherwise only the last j of each row is
         # halved and C[1,0] ends up twice C[0,1].
         if (i != j): C_ijkt[i,j,:,t] /= 2
   # Compute transition probabilities by normalizing each row.
   for i in range(nstates):
      for k in range(K):
         T_ijkt[i,:,k,t] = C_ijkt[i,:,k,t] / C_ijkt[i,:,k,t].sum()

# Store correlation functions in netcdf.
variable = netcdf_file.createVariable('C_ijkt', 'd', ('nstates', 'nstates', 'K', 'Tmax',))
setattr(variable, 'units', 'none')
setattr(variable, 'description', 'C_ijkt[i,j,k,t] is the probability of observing a trajectory that starts in state i and ends in state j some (t+1) time intervals later at temperature index k')
netcdf_file.variables['C_ijkt'][:,:,:,:] = C_ijkt

# Store estimated transition probabilities in netcdf.
variable = netcdf_file.createVariable('T_ijkt', 'd', ('nstates', 'nstates', 'K', 'Tmax',))
setattr(variable, 'units', 'none')
setattr(variable, 'description', 'T_ijkt[i,j,k,t] is the probability of observing a trajectory end in state j given it was in state i (t+1) time intervals earlier, at temperature index k')
netcdf_file.variables['T_ijkt'][:,:,:,:] = T_ijkt

#===================================================================================================
# Compute optimal estimates of unnormalized correlation functions at multiple temperatures.
#===================================================================================================

# Temperatures at which rates are evaluated: every parallel tempering temperature, plus ninterpolate
# linearly spaced temperatures inside each interval between consecutive ones.
Kinterpolated = K + ninterpolate*(K-1)
interpolated_temperatures = zeros([Kinterpolated], float64)
stride = ninterpolate + 1 # spacing, in interpolated indices, between anchored temperatures
for k in range(K-1):
   low = temperature_k[k]
   high = temperature_k[k+1]
   # anchored at a parallel tempering temperature
   interpolated_temperatures[k*stride] = low
   # linearly interpolated temperatures strictly between low and high
   for i in range(ninterpolate):
      interpolated_temperatures[k*stride + (i+1)] = (high - low) / stride * (i+1) + low
# the last parallel tempering temperature closes the list
interpolated_temperatures[(K-1)*stride] = temperature_k[K-1]

print("interpolated_temperatures = ")
print(interpolated_temperatures)

# Create netcdf variable for storage of transition matrices.
netcdf_file.createDimension('Kinterpolated', Kinterpolated)             # number of interpolated temperatures at which transition matrix is estimated

variable = netcdf_file.createVariable('interpolated_temperature_k', 'd', ('Kinterpolated', ))
setattr(variable, 'units', 'Kelvin')
setattr(variable, 'description', 'interpolated_temperature_k[k] is the set of interpolated temperatures at which the transition matrix Tr_ijk is estimated')
netcdf_file.variables['interpolated_temperature_k'][:] = interpolated_temperatures

variable = netcdf_file.createVariable('Tr_ijkt', 'd', ('nstates', 'nstates', 'Kinterpolated', 'Tmax', ))
setattr(variable, 'units', 'none')
setattr(variable, 'description', 'Tr[i,j,k,t] is the conditional transition probability from i to j at temperature index k for (t+1) time intervals estimated by reweighting all data')

variable = netcdf_file.createVariable('dTr_ijkt', 'd', ('nstates', 'nstates', 'Kinterpolated', 'Tmax', ))
setattr(variable, 'units', 'none')
setattr(variable, 'description', 'dTr[i,j,k,t] is the estimated statistical uncertainty (one standard deviation) of Tr_ijk[i,j,k,t]')

# Compute reweighted transition matrices, and their uncertainties, at every interpolated temperature.
ninterpolated_temperatures = len(interpolated_temperatures)
nobservables = nstates*nstates # (i,j) pairs, flattened as index i*nstates + j

if counting_method not in ('sliding', 'single'):
   raise ValueError("counting_method '%s' unrecognized." % counting_method)
if Tmax >= T:
   raise ValueError("Tmax (%d) must be less than trajectory_segment_length (%d)." % (Tmax, T))

for t in range(Tmax):
   # Compute lag time
   tau = t+1
   print("Lag time tau = %d" % tau)

   # Snapshot pairs separated by tau: every time origin for 'sliding', only the first snapshot for 'single'.
   print("Computing observables Akn...")
   if counting_method == 'sliding':
      start_knt = state_knt[:,:,0:(T-tau)]
      end_knt = state_knt[:,:,tau:T]
   else:
      start_knt = state_knt[:,:,0:1]
      end_knt = state_knt[:,:,tau:(tau+1)]

   # A_ijkn[i,j,k,n] is the fraction of tau-separated snapshot pairs of segment n from temperature index k
   # that go i->j or j->i, halved for i != j so that it estimates one direction only.
   # Kept in double precision (these are averages passed to MBAR).
   A_ijkn = zeros([nstates, nstates, K, N], float64)
   for i in range(nstates):
      for j in range(nstates):
         observed = ((start_knt==i) & (end_knt==j)) | ((start_knt==j) & (end_knt==i))
         A_ijkn[i,j,:,:] = observed.mean(axis=2)
         if (i != j): A_ijkn[i,j,:,:] /= 2

   # Flatten (i,j) into one observable index (C order: i*nstates + j). This does not depend on the
   # target temperature, so it is formed once per lag time.
   A_ikn = A_ijkn.reshape([nobservables, K, N])

   for temperature_index in range(ninterpolated_temperatures):
      print("Computing expectations for temperature %d/%d (%.1f K) by reweighting..." % (temperature_index, ninterpolated_temperatures, interpolated_temperatures[temperature_index]))

      # Reduced energies of every segment at this (possibly interpolated) temperature.
      beta = 1.0 / (kB * interpolated_temperatures[temperature_index])
      u_kn = beta * E_kn

      # Compute expectations of all Cij, and their covariance, by reweighting.
      (A_i, d2A_ij) = mbar.computeMultipleExpectations(A_ikn, u_kn)

      # Cij[i,j] is the reweighted estimate of the (symmetrized) correlation function.
      Cij = numpy.asarray(A_i, numpy.float64).reshape([nstates, nstates])
      # dCijdCkl[i*nstates+j, k*nstates+l] is the covariance of the estimates of Cij and Ckl.
      dCijdCkl = numpy.asarray(d2A_ij, numpy.float64).reshape([nobservables, nobservables])

      # Stationary probabilities are the row sums; the transition matrix is each row normalized.
      Pi = Cij.sum(1)
      Tij = Cij / Pi[:,numpy.newaxis]

      # Write to netcdf
      netcdf_file.variables['Tr_ijkt'][:,:,temperature_index,t] = Tij

      print("Pi = ")
      print(Pi)
      print("sum_i Pi = %f" % Pi.sum())

      # Jacobian of the flattened Tab with respect to the flattened Cij. dCijdCkl is the covariance over
      # all nstates**2 entries (both Cij and Cji appear), so each entry is differentiated as its own
      # variable:  dTab/dCij = kronecker(a,i) * (kronecker(b,j) - Tab) / Pi[a].
      # Differentiating with respect to the unordered pair {i,j} instead, while still summing over both
      # orderings against this covariance, would inflate each off-diagonal variance contribution fourfold.
      dTabdCij = zeros([nobservables, nobservables], float64)
      for a in range(nstates):
         for b in range(nstates):
            for i in range(nstates):
               for j in range(nstates):
                  dTabdCij[a*nstates+b, i*nstates+j] = kronecker(a,i) * (kronecker(b,j) - Tij[a,b]) / Pi[a]

      # Propagate uncertainty: the covariance of the Tab is J Cov J^T, and its diagonal holds their variances.
      print("Computing dTijdTkl...")
      dTijdTkl = numpy.dot(numpy.dot(dTabdCij, dCijdCkl), dTabdCij.T)
      d2Tij = numpy.diag(dTijdTkl).reshape([nstates, nstates])

      # Compute std dev of elements.
      dTij = numpy.sqrt(d2Tij)

      # Write to netcdf
      netcdf_file.variables['dTr_ijkt'][:,:,temperature_index,t] = dTij
   
#===================================================================================================
# Close NetCDF file.
#===================================================================================================

# Close the NetCDF output file opened at the start of the script, after all variables have been written.
netcdf_file.close()

