#!/usr/bin/python

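# Compute temperature-dependent autocorrelation functions <A(0)A(t)> of a conformational indicator as a function of lag time
# from alanine dipeptide parallel tempering data, both directly from each temperature and by reweighting all data with MBAR.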

#=============================================================================================
# REQUIREMENTS
#
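# This code was written for the 'pynetcdf' package, containing the Scientific.IO.NetCDF package built for numpy.
# It now imports the 'netCDF4' package instead; the pynetcdf import is kept below, commented out, for reference.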
#
# http://pypi.python.org/pypi/pynetcdf/
# http://sourceforge.net/project/showfiles.php?group_id=1315&package_id=185504
#
# This code also uses the 'MBAR' package, implementing the multistate Bennett acceptance ratio estimator, available here:
#
# http://www.simtk.org/home/pymbar
#=============================================================================================

#===================================================================================================
# IMPORTS
#===================================================================================================
import numpy
import math
import pymbar # for MBAR analysis
import timeseries # for timeseries analysis
import commands
import os
import os.path
from numpy.linalg import * # linear algebra methods
import datetime # time and date
#from pynetcdf import NetCDF # for writing of data objects for plotting in Matlab or Mathematica
import netCDF4 as netcdf # for writing of data objects for plotting in Matlab or Mathematica

#===================================================================================================
# CONSTANTS
#===================================================================================================

# 1.3806503 is the Boltzmann constant in units of 1e-23 J/K and 6.0221415 is Avogadro's number in units of 1e23 /mol;
# the powers of ten cancel, giving the gas constant in J/mol/K, and dividing by 4184 J/kcal converts it to kcal/mol/K.
kB = 1.3806503 * 6.0221415 / 4184.0 # Boltzmann constant in kcal/mol/K

#===================================================================================================
# PARAMETERS
#===================================================================================================

# Input locations (relative to the current working directory).
data_directory = '../../datasets/alanine-dipeptide/parallel-tempering' # directory containing the parallel tempering data
temperature_list_filename = os.path.join(data_directory, 'temps') # file containing temperatures in K (all on its first line)
total_energies_filename = os.path.join(data_directory, 'etot.out') # file containing total energies (in kcal/mol) for each temperature and snapshot
# The data is cut into 'ntrajectories' consecutive segments of 'trajectory_segment_length' snapshots each.
trajectory_segment_length = 200 # number of snapshots in each contiguous trajectory segment
ntrajectories = 501 # number of trajectories to use
tau_unit = 0.1 # sampling time in ps (not referenced elsewhere in this script)
netcdf_output_filename = 'output/alanine-dipeptide-autocorrelation.nc' # netcdf output filename; NOTE(review): nothing in this script creates 'output/' -- presumably it must already exist
#use_analytical_momentum = True # if True, will include analytical momentum contribution to partition function in energies
#netcdf_output_filename = 'output/alanine-dipeptide-transition-rate-analytical-momentum.nc' # netcdf output filename
# NOTE(review): presumably 22 solute atoms with 21 constraints, plus 431 rigid three-site waters, minus 3 for center-of-mass motion -- TODO confirm.
ndof = (3*22-21) + 431*(3*3-3) - 3 # number of degrees of freedom (not referenced elsewhere in this script)

#===================================================================================================
# SUBROUTINES
#===================================================================================================

def write_file(filename, contents):
   """Write the specified contents to a file, replacing any existing file.

   ARGUMENTS
     filename (string) - the file to be written
     contents (string or list of strings) - the contents of the file to be written;
       a list is written item by item with no separators added

   RAISES
     TypeError if contents is neither a string nor a list.  The type is checked
     before the file is opened, so an unsupported argument never truncates an
     existing file.

   """

   # Validate first: opening with 'w' would already have destroyed the old contents.
   if isinstance(contents, str):
      chunks = [contents]
   elif isinstance(contents, list):
      chunks = contents
   else:
      # String exceptions are not valid; raise a real exception carrying the same message.
      raise TypeError("Type for 'contents' not supported: " + repr(type(contents)))

   # The context manager closes the file even if a write fails.
   with open(filename, 'w') as outfile:
      outfile.writelines(chunks)

   return

def read_file(filename):
   """Read contents of the specified file.

   ARGUMENTS
     filename (string) - the name of the file to be read

   RETURNS
     lines (list of strings) - the contents of the file, split by line
       (each entry keeps its trailing newline, as returned by readlines())

   """

   # The context manager guarantees the handle is closed even if reading fails.
   with open(filename, 'r') as infile:
      lines = infile.readlines()

   return lines

def logSum(log_terms):
   """Compute the log of a sum of terms whose logarithms are provided.

   The largest log term is factored out before exponentiating, so the result
   does not overflow or underflow even when the terms themselves would.

   REQUIRED ARGUMENTS  
      log_terms is the array (possibly multidimensional) containing the logs of the terms to be summed.
      Anything numpy.asarray accepts (e.g. a list) may be given.

   RETURN VALUES
      log_sum is the log of the sum of the terms.
      If the largest log term is not finite (all terms are -inf, or a term is +inf or nan),
      that largest term is returned directly.

   """

   log_terms = numpy.asarray(log_terms)

   # compute the maximum argument
   max_log_term = log_terms.max()

   # subtracting a non-finite maximum would produce nan, so short-circuit
   if not numpy.isfinite(max_log_term):
      return max_log_term

   # compute the reduced terms
   terms = numpy.exp(log_terms - max_log_term)

   # compute the log sum (numpy.log: a bare 'log' is not in scope in this module)
   log_sum = numpy.log( terms.sum() ) + max_log_term

   # return the log sum
   return log_sum

def kronecker(i,j):
   """Kronecker delta: return 1 when i equals j, and 0 otherwise.
   """

   return 1 if (i == j) else 0

def delta(a,b,i,j):
   """Symmetric pair indicator.

   Return 1 when the pair (a,b) matches the pair (i,j) in either order,
   i.e. (a,b) == (i,j) or (a,b) == (j,i), and 0 otherwise.
   """

   # same order
   if (a == i) and (b == j):
      return 1
   # swapped order
   if (a == j) and (b == i):
      return 1
   return 0


#===================================================================================================
# MAIN
#===================================================================================================

#===================================================================================================
# Initialize the NetCDF file for output of computed data objects.
#===================================================================================================

# Create the NetCDF file that receives every data object computed below.
print("Opening NetCDF file for writing...")
netcdf_file = netcdf.Dataset(netcdf_output_filename, 'w', format='NETCDF3_CLASSIC')

# Record when and by what the file was produced as global attributes.
netcdf_file.title = "Analysis data produced at %s" % datetime.datetime.now().ctime()
netcdf_file.application = 'temperature-dependent-transition-probabilities.py'

#===================================================================================================
# Read temperatures
#===================================================================================================

print("Reading temperatures and other dimensions from parallel tempering dataset...")
# All temperatures are taken from the first line of the file, separated by whitespace.
lines = read_file(temperature_list_filename)
temperatures = lines[0].split()
K = len(temperatures)
# temperature_k[k] is temperature of temperature index k in K
temperature_k = numpy.array([float(entry) for entry in temperatures], numpy.float32)
# beta_k[k] is the inverse temperature 1/(kB T) of temperature index k
beta_k = (kB * temperature_k)**(-1)

# Short names for the remaining dimensions.
N = ntrajectories
T = trajectory_segment_length

# Declare the dimensions used by the NetCDF variables:
#   K - number of temperatures
#   N - number of trajectories per temperature
#   T - number of snapshots per trajectory
for (dimension_name, dimension_size) in [('K', K), ('N', N), ('T', T)]:
   netcdf_file.createDimension(dimension_name, dimension_size)

# Store the temperatures.
variable = netcdf_file.createVariable('temperature_k', 'd', ('K',))
variable.units = 'Kelvin'
variable.description = 'temperature_k[k] is the temperature of temperature index k'
variable[:] = temperature_k

# Store the inverse temperatures.
variable = netcdf_file.createVariable('beta_k', 'f', ('K',))
variable.units = '1/(kcal/mol)'
variable.description = 'beta_k[k] is the inverse temperature of temperature index k'
variable[:] = beta_k

#===================================================================================================
# Read energies
#===================================================================================================

# Read the total energy of every snapshot and average it over each trajectory segment.
# Because the kinetic energy reported by the integrator is inaccurate, the segment-averaged
# total energy is used as a better estimator of the true Hamiltonian.

print("Reading total energies...")
lines = read_file(total_energies_filename)
# E_kn[k,n] is the total energy for trajectory segment n of temperature index k
E_kn = numpy.zeros([K,N], numpy.float64)
for n in range(N):
   # E_kt[k,t] is the total energy of snapshot t of temperature index k, for segment n only
   E_kt = numpy.zeros([K,T], numpy.float64)
   for t in range(T):
      # Line n*T + t holds one energy per temperature for snapshot t of segment n.
      elements = lines[n*T + t].split()
      E_kt[:,t] = [float(elements[k]) for k in range(K)]
   # Average the total energy over the snapshots of this segment, one temperature at a time.
   for k in range(K):
      E_kn[k,n] = E_kt[k,:].mean()

# Store in netcdf.
variable = netcdf_file.createVariable('E_kn', 'd', ('K', 'N',))
variable.units = 'kcal/mol'
variable.description = 'E_kn[k,n] is the total energy of trajectory segment n from temperature index k'
variable[:,:] = E_kn

#===================================================================================================
# Compute reduced potential energy of each trajectory segment in each condition.
#===================================================================================================

print("Computing reduced total energies for MBAR...")
# u_kln[k,l,n] is reduced dimensionless total energy of trajectory segment n of temperature k evaluated at temperature l,
# i.e. beta_k[l] * E_kn[k,n]; broadcasting forms all (k,l,n) products at once.
u_kln = numpy.zeros([K,K,N], numpy.float64)
u_kln[:,:,:] = beta_k[numpy.newaxis,:,numpy.newaxis] * E_kn[:,numpy.newaxis,:]

# Every temperature contributes the same number of trajectory segments.
N_k = numpy.empty([K], numpy.int32)
N_k.fill(N)

#===================================================================================================
# Read in guess of path ensemble 'free energies' to save time.
#===================================================================================================

# Use free energies saved in the working directory as the initial guess, if the file exists.
free_energy_filename = 'f_k.dat'
f_k_initial = None
if os.path.exists(free_energy_filename):
   print("Reading path ensemble 'free energies' from file %s..." % free_energy_filename)
   lines = read_file(free_energy_filename)
   # Gather all whitespace-separated values, however they are spread over lines.
   elements = ' '.join([line.strip() for line in lines]).split()
   # Only the first K values are used, one per temperature.
   f_k_initial = numpy.array([float(elements[k]) for k in range(K)], numpy.float64)
   print(f_k_initial)

#===================================================================================================
# Initialize MBAR for dynamical reweighting.
#===================================================================================================

# Construct the MBAR estimator from the reduced energies, starting from the saved guess when there is one.
print("Initializing MBAR (will estimate free energy differences first time)...")
mbar = pymbar.MBAR(u_kln, N_k, method = 'self-consistent-iteration', verbose = True, initial_f_k = f_k_initial, relative_tolerance = 1.0e-8)

print("Converged path ensemble 'free energies' = ")
print(mbar.f_k)

# Save the converged free energies, one per line, so that future executions can start from them.
with open(free_energy_filename, 'w') as outfile:
   outfile.writelines(['%30.20e\n' % mbar.f_k[k] for k in range(K)])

# Store converged free energies in netcdf.
variable = netcdf_file.createVariable('f_k', 'd', ('K',))
variable.units = 'none'
variable.description = 'f_k[k] is the dimensionless free energy of state k, up to an irrelevant additive constant'
variable[:] = mbar.f_k

#===================================================================================================
# Compute total contribution weight to each transition from various temperatures.
#===================================================================================================

print("Computing contribution to each transition from various temperatures...")

# log_w_kl[k,l] is the contribution from temperature l to the estimation of any expectations at temperature k
log_w_kl = numpy.zeros([K, K], numpy.float64)

for k in range(K):
   # Trajectory log weights for target temperature k, from the energies reduced at that temperature.
   reduced_energies_kn = beta_k[k] * E_kn
   log_w_kn = mbar._computeUnnormalizedLogWeights(reduced_energies_kn)
   # Combine, in log space, the weights of all segments that were sampled at each temperature l.
   log_w_kl[k,:] = [pymbar.logsum(log_w_kn[l,:]) for l in range(K)]

# Store in netcdf.
variable = netcdf_file.createVariable('log_w_kl', 'd', ('K', 'K',))
variable.units = 'none'
variable.description = 'log_w_kl[k,l] is the contribution from temperature l to the estimation of any expectations at temperature k'
variable[:,:] = log_w_kl

#===================================================================================================
# Read phi, psi trajectories from parallel tempering dataset.
#===================================================================================================

print("Reading phi, psi trajectories from parallel tempering dataset...")
# phi_knt[k,n,t] is phi angle (in degrees) for snapshot t of trajectory segment n of temperature k
phi_knt = numpy.zeros([K,N,T], numpy.float32)
# psi_knt[k,n,t] is psi angle (in degrees) for snapshot t of trajectory segment n of temperature k
psi_knt = numpy.zeros([K,N,T], numpy.float32)
for k in range(K):
   # One torsion file per temperature index, named '<k>.tors'.
   filename = os.path.join(data_directory, '%d.tors' % (k))
   lines = read_file(filename)
   print("k = %d, %d lines read" % (k, len(lines)))
   for n in range(N):
      # The first line of the file is skipped (presumably a header), hence the offset of one.
      segment_start = n*T + 1
      for t in range(T):
         elements = lines[segment_start + t].split()
         # phi and psi are taken from the fourth and fifth whitespace-separated fields.
         phi_knt[k,n,t] = float(elements[3])
         psi_knt[k,n,t] = float(elements[4])

#===================================================================================================
# Compute observables A(x) over all trajectories.
#===================================================================================================

print("Computing observables...")

# alpha_R indicator function: A_knt[k,n,t] is 1 where -124 <= psi < 28 (degrees) and 0 elsewhere.
# Only psi enters the definition; phi is not used.
in_alpha_R = (psi_knt >= -124) & (psi_knt <   28)
A_knt = in_alpha_R.astype(numpy.float32)

#===================================================================================================
# Compute autocorrelation functions <A(0)A(t)> from each temperature separately.
#===================================================================================================

print("Computing autocorrelation function for observable from each temperature separately...")
# C_kt[k,t] is the estimate of the correlation function <A(0)A(t)> from temperature index k
C_kt = numpy.zeros([K, T], numpy.float32)
# dC_kt[k,t] is the standard error in C_kt[k,t]
dC_kt = numpy.zeros([K, T], numpy.float32)

# The standard error is the standard deviation over segments divided by sqrt(N).
sqrt_N = numpy.sqrt(N)
for k in range(K):
   # Value of the observable at the first snapshot of every segment of temperature k.
   A0 = A_knt[k,:,0]
   for t in range(T):
      A0At = A0 * A_knt[k,:,t]
      C_kt[k,t] = A0At.mean()
      dC_kt[k,t] = A0At.std() / sqrt_N

# Store in netcdf.
variable = netcdf_file.createVariable('C_kt_direct', 'd', ('K', 'T',))
variable.units = 'none'
variable.description = 'C_kt_direct[k,t] is the estimate of the correlation function <A(0)A(t)>_k from only data collected temperature index k'
variable[:,:] = C_kt

variable = netcdf_file.createVariable('dC_kt_direct', 'd', ('K', 'T',))
variable.units = 'none'
variable.description = 'dC_kt_direct[k,t] is the standard error in C_kt_direct[k,t]'
variable[:,:] = dC_kt

#===================================================================================================
# Compute autocorrelation functions <A(0)A(t)> by dynamical reweighting.
#===================================================================================================

# A0Atkn[t,k,n] is A(0)A(t) from trajectory segment n of temperature index k
A0Atkn = numpy.zeros([T, K, N], numpy.float64)

# Multiply every snapshot by the first snapshot of its segment (broadcast over t),
# then move the time axis to the front: (k,n,t) -> (t,k,n).
A0Atkn[:,:,:] = (A_knt[:,:,0:1] * A_knt).transpose(2, 0, 1)
         
# Estimate <A(0)A(t)> at each temperature from the data of all temperatures by reweighting.
# C_kt[k,t] is the estimate of the correlation function <A(0)A(t)> from temperature index k
C_kt = numpy.zeros([K, T], numpy.float32)
# dC_kt[k,t] is the standard error in C_kt[k,t]
dC_kt = numpy.zeros([K, T], numpy.float32)
for k in range(K):
   # Reduced energies of all trajectory segments evaluated at temperature k.
   u_kn = beta_k[k] * E_kn

   # One expectation per lag time t, plus the matrix whose diagonal gives their squared uncertainties.
   (A_i, d2A_ij) = mbar.computeMultipleExpectations(A0Atkn, u_kn)

   # Keep the estimates and the square roots of the diagonal as standard errors.
   C_kt[k,:] = A_i
   dC_kt[k,:] = numpy.sqrt(numpy.diag(d2A_ij))

# Store the reweighted correlation function and its standard error in netcdf.
variable = netcdf_file.createVariable('C_kt_reweighting', 'd', ('K', 'T',))
setattr(variable, 'units', 'none')
setattr(variable, 'description', 'C_kt_reweighting[k,t] is the estimate of the correlation function <A(0)A(t)>_k from all data by reweighting')
netcdf_file.variables['C_kt_reweighting'][:,:] = C_kt

variable = netcdf_file.createVariable('dC_kt_reweighting', 'd', ('K', 'T',))
setattr(variable, 'units', 'none')
# The description must name the variable exactly as it is stored ('dC_kt_reweighting').
setattr(variable, 'description', 'dC_kt_reweighting[k,t] is the standard error in C_kt_reweighting[k,t]')
netcdf_file.variables['dC_kt_reweighting'][:,:] = dC_kt

#===================================================================================================
# Close NetCDF file.
#===================================================================================================

# Announce completion before closing, so that any message printed afterwards is known to be harmless.
print("Done.\nErrors after this point can be ignored.\n\n")
# Closing the dataset finishes the output file.
netcdf_file.close()

