#!/usr/bin/python

# Compute temperature-dependent transition probabilities for alanine dipeptide parallel tempering data.

#===================================================================================================
# IMPORTS
#===================================================================================================
import numpy
from numpy import *
from math import *
import MBAR # for MBAR analysis
import timeseries # for timeseries analysis
import commands
import os
import os.path
from numpy.linalg import * # linear algebra methods

#===================================================================================================
# CONSTANTS
#===================================================================================================

# Boltzmann constant in kcal/mol/K, formed as k_B * N_A / (4184 J/kcal) from the mantissas
# 1.3806503 (k_B = 1.3806503e-23 J/K) and 6.0221415 (N_A = 6.0221415e23 /mol); the powers of ten cancel.
kB = 1.3806503 * 6.0221415 / 4184.0 # Boltzmann constant in kcal/mol/K

#===================================================================================================
# PARAMETERS
#===================================================================================================

data_directory = 'parallel-tempering-phipsi-data' # directory containing the parallel tempering data
temperature_list_filename = os.path.join(data_directory, 'temps') # file containing temperatures in K (read from its first line, whitespace-separated)
total_energies_filename = os.path.join(data_directory, 'etot.out') # file containing total energies (in kcal/mol); one line per snapshot, one column per temperature index
trajectory_segment_length = 200 # number of snapshots in each contiguous trajectory segment (T below)
ntrajectories = 501 # number of trajectory segments to use (N below)
tau = 60 # lag time in snapshots; must be smaller than trajectory_segment_length
ninterpolate = 1 # number of interpolated temperatures to add in between parallel tempering temperatures for estimation of rates
counting_method = 'sliding' # How to count transitions to compute Cij -- 'single' (window [0, tau] only) or 'sliding' (every window of length tau, averaged)

#===================================================================================================
# SUBROUTINES
#===================================================================================================

def write_file(filename, contents):
   """Write the specified contents to a file, replacing any existing file.

   ARGUMENTS
     filename (string) - the file to be written
     contents (string or list of strings) - the text to write; list elements are
       written back-to-back with no separator added

   RAISES
     TypeError - if contents is neither a string nor a list (the file is not touched)

   """

   # Validate and assemble the text BEFORE opening the file, so that an unsupported
   # type does not truncate/create the target or leave an open handle behind.
   if isinstance(contents, list):
      text = ''.join(contents)
   elif isinstance(contents, str):
      text = contents
   else:
      # String exceptions are not allowed in Python >= 2.6; raise a real exception.
      raise TypeError("Type for 'contents' not supported: " + repr(type(contents)))

   outfile = open(filename, 'w')
   try:
      outfile.write(text)
   finally:
      outfile.close()

   return

def read_file(filename):
   """Read contents of the specified file.

   ARGUMENTS
     filename (string) - the name of the file to be read

   RETURNS
     lines (list of strings) - the contents of the file, split by line
       (each element keeps its trailing newline)

   """

   infile = open(filename, 'r')
   # Close the handle even if reading fails part-way.
   try:
      lines = infile.readlines()
   finally:
      infile.close()

   return lines

def logSum(log_terms):
   """Compute the log of a sum of terms whose logarithms are provided.

   Evaluates log(sum(exp(log_terms))) without overflow by factoring out the
   largest log term before exponentiating.

   REQUIRED ARGUMENTS
      log_terms is the array (possibly multidimensional, or any array-like) containing
      the logs of the terms to be summed. It must be non-empty.

   RETURN VALUES
      log_sum is the log of the sum of the terms. If every term is zero (all logs are
      -inf) the result is -inf; if any log term is +inf the result is +inf.

   """

   # accept lists/tuples as well as numpy arrays
   log_terms = numpy.asarray(log_terms, dtype=numpy.float64)

   # compute the maximum argument
   max_log_term = log_terms.max()

   # A non-finite maximum (-inf: all terms are zero; +inf; or nan) would make
   # log_terms - max_log_term evaluate inf - inf = nan. The answer is the maximum itself.
   if not numpy.isfinite(max_log_term):
      return max_log_term

   # compute the reduced terms (all <= 1, so exp cannot overflow)
   terms = numpy.exp(log_terms - max_log_term)

   # compute the log sum
   log_sum = log( terms.sum() ) + max_log_term

   # return the log sum
   return log_sum

def kronecker(i,j):
   """Kronecker delta: return 1 when i equals j, otherwise 0."""

   return 1 if i == j else 0

def delta(a,b,i,j):
   """Symmetric pair indicator.

   Return 1 when the unordered pair {a, b} matches {i, j}, i.e. when
   (a, b) == (i, j) or (a, b) == (j, i); otherwise return 0.
   """

   # same ordering
   if (a == i) and (b == j):
      return 1
   # swapped ordering
   if (a == j) and (b == i):
      return 1
   return 0

#===================================================================================================
# MAIN
#===================================================================================================

#===================================================================================================
# Read temperatures
#===================================================================================================

# Read list of temperatures.
lines = read_file(temperature_list_filename)
# Construct list of temperatures (whitespace-separated values on the first line, in K)
temperatures = lines[0].split()
# Create numpy array of temperatures.
# Stored as float64 (not float32) so that beta_k, and hence the reduced potentials
# handed to MBAR, keep full double precision.
K = len(temperatures)
temperature_k = zeros([K], float64) # temperature_k[k] is temperature of temperature index k in K
for k in range(K):
   temperature_k[k] = float(temperatures[k])
# Compute inverse temperatures, in (kcal/mol)^-1
beta_k = (kB * temperature_k)**(-1)

# Define other constants
N = ntrajectories
T = trajectory_segment_length

#===================================================================================================
# Read energies
#===================================================================================================

# Read total energies, averaging total energy over each trajectory segment.

print "Reading total energies..."
etot_kn = zeros([K,N], float64) # etot_kn[k,n] is the mean total energy over trajectory segment n of temperature index k
lines = read_file(total_energies_filename)
for n in range(N):
   # Allocate storage for all energies for one trajectory segment.
   etot_kt = zeros([K,T], float64) # etot_kt[k,t] is the total energy of snapshot t of temperature index k
   # Read total energies for all snapshots of a trajectory segment
   for t in range(T):
      # Get line containing the energies for snapshot t of trajectory segment n
      line = lines[n*T + t]
      # Extract energy values from text
      elements = line.split()
      for k in range(K):
         etot_kt[k,t] = float(elements[k])
   # Compute mean and variance of total energies over trajectory segment
   for k in range(K):
      etot_kn[k,n] = etot_kt[k,:].mean()

#===================================================================================================
# Read phi, psi trajectories
#===================================================================================================

print "Reading phi, psi trajectories..."
phi_knt = zeros([K,N,T], float32) # phi_knt[k,n,t] is phi angle (in degrees) for snapshot t of trajectory segment n of temperature k
psi_knt = zeros([K,N,T], float32) # psi_knt[k,n,t] is psi angle (in degrees) for snapshot t of trajectory segment n of temperature k
for k in range(K):
   filename = os.path.join(data_directory, '%d.tors' % (k))
   lines = read_file(filename)
   print "k = %d, %d lines read" % (k, len(lines))
   for n in range(N):
      for t in range(T):
         # Get line containing snapshot t of trajectory n      
         line = lines[n*T + t + 1]
         # Extract phi and psi
         elements = line.split()
         phi_knt[k,n,t] = float(elements[3])
         psi_knt[k,n,t] = float(elements[4])

#===================================================================================================
# Assign discrete state identities
#===================================================================================================

print "Assigning states..."
nstates = 6;
state_knt = zeros([K,N,T], int16) # state_knt[k,n,t] is the state index (0...5) of snapshot t of trajectory segment n of temperature k

state_knt[((phi_knt >=  117) | (phi_knt < -105)) & ((psi_knt >=   28) | (psi_knt < -124))] = 0
state_knt[((phi_knt >= -105) & (phi_knt <    0)) & ((psi_knt >=   28) | (psi_knt < -124))] = 1
state_knt[((phi_knt >   117) | (phi_knt < -105)) & ((psi_knt >= -124) & (psi_knt <   28))] = 2
state_knt[((phi_knt >= -105) & (phi_knt <    0)) & ((psi_knt >= -124) & (psi_knt <   28))] = 3
state_knt[((phi_knt >=    0) & (phi_knt <  117)) & ((psi_knt >=  111) | (psi_knt <   -5))] = 4
state_knt[((phi_knt >=    0) & (phi_knt <  117)) & ((psi_knt >=  -5)  & (psi_knt <  111))] = 5

#===================================================================================================
# Compute observable for correlation function
#===================================================================================================

print "Computing observables Akn..."
Aijkn = zeros([nstates, nstates, K, N], float32)
for i in range(nstates):
   for j in range(nstates):
      if counting_method == 'sliding':
         Aijkn[i,j,:,:] = mean((state_knt[:,:,0:(T-tau)]==i) & (state_knt[:,:,tau:T]==j) | (state_knt[:,:,0:(T-tau)]==j) & (state_knt[:,:,tau:T]==i), axis=2)
      elif counting_method == 'single':
         Aijkn[i,j,:,:] = ((state_knt[:,:,0]==i) & (state_knt[:,:,tau]==j) | (state_knt[:,:,0]==j) & (state_knt[:,:,tau]==i))
      else:
         raise "counting_method '%s' unrecognized." % counting_method
      
      if (i != j): Aijkn[i,j,:,:] /= 2      

#===================================================================================================
# Estimate maximum likelihood transition matrix at each temperature separately.
#===================================================================================================

print "Estimating transition probabilities..."
# Estimate correlation function Cji(tau) individually for each temperature
Cijk = mean(Aijkn, 3) # Cjik[j,i,k] is the estimated correlation function <chi_i(0) chi_j(tau)> from data only at temperature index k
# Estimate population p_i for each state individually for each temperature
Pik = sum(Cijk, 1)
# Estimate transition matrix Tji for each state individually for each temperature
Tijk = zeros([nstates, nstates, K], float64)
for k in range(K):
   for i in range(nstates):      
      Tijk[i,:,k] = Cijk[i,:,k] / Pik[i,k]

# DEBUG
for k in range(K):
   print "Temperature %d : %f K" % (k, temperature_k[k])
   print Tijk[:,:,k]

#===================================================================================================
# Compute reduced potential energy of each trajectory segment in each condition
#===================================================================================================

print "Computing energies..."
u_kln = zeros([K,K,N], float64) # u_kln[k,l,n] is reduced potential energy of trajectory segment n of temperature k evaluated at temperature l
for k in range(K):
   for l in range(K):
      for n in range(N):
         u_kln[k,l,n] = beta_k[l] * etot_kn[k,n]
   
N_k = zeros([K], int32)
N_k[:] = N

#===================================================================================================
# Read in guess of free energies
#===================================================================================================

# Read free energies, if they exist.
free_energy_filename = 'f_k.dat'
if os.path.exists(free_energy_filename):
   print "Reading free energies from file %s..." % free_energy_filename
   lines = read_file(free_energy_filename)
   contents = ""
   for line in lines:
      contents += line.strip() + ' '
   elements = contents.split()
   f_k_initial = zeros([K], float64)
   for k in range(K):
      f_k_initial[k] = float(elements[k])
   print f_k_initial
else:
   f_k_initial = None

#===================================================================================================
# Initialize MBAR to compute free energy estimates
#===================================================================================================

# Initialize MBAR 
print "Initializing MBAR (will estimate free energy differences first time)..."
mbar = MBAR.MBAR(u_kln, N_k, method = 'self-consistent-iteration', verbose = True, initial_f_k = f_k_initial, relative_tolerance = 1.0e-8)

print "Converged free energies = "
print mbar.f_k

# Write converged free energies
outfile = open(free_energy_filename, 'w')
for k in range(K):
   outfile.write('%30.20e\n' % mbar.f_k[k])
outfile.close()

#===================================================================================================
# Compute optimal estimates of unnormalized correlation functions at multiple temperatures.
#===================================================================================================

# Define list of temperatures at which we want to evaluate rates.
interpolated_temperatures = zeros([K + ninterpolate*(K-1)], float64)
for k in range(K-1):
   # fix temperatures at parallel tempering temperatures
   interpolated_temperatures[k*(ninterpolate+1)] = temperature_k[k]
   # linear interpolation of intermediate temperatures
   for i in range(ninterpolate):
      interpolated_temperatures[k*(ninterpolate+1) + (i+1)] = (temperature_k[k+1] - temperature_k[k]) / (ninterpolate+1) * (i+1) + temperature_k[k]
# add last achored parallel tempering temperature
interpolated_temperatures[(K-1)*(ninterpolate+1)] = temperature_k[K-1]

# DEBUG
print "interpolated_temperatures = "
print interpolated_temperatures

# Write
Tij_outfile = open('Tij.dat', 'w')
dTij_outfile = open('dTij.dat', 'w')
outfile = open('combined.dat', 'w')

# Compute expectations
ninterpolated_temperatures = len(interpolated_temperatures)
for temperature_index in range(ninterpolated_temperatures):
   # DEBUG
   print "Computing expectations for temperature %d/%d (%.1f K) by reweighting..." % (temperature_index, ninterpolated_temperatures, interpolated_temperatures[temperature_index])
   
   # Compute u_kn at this temperature
   beta = 1.0 / (kB * interpolated_temperatures[temperature_index])
   u_kn = beta * etot_kn

   # Form elements into observables
   A_ikn = zeros([nstates*nstates, K, N], float64)
   index_ij = 0
   for i in range(nstates):
      for j in range(nstates):
         A_ikn[index_ij,:,:] = Aijkn[i,j,:,:]
         index_ij += 1

   # Compute expectations of all Cij by reweighting.
   (A_i, d2A_ij) = mbar.computeMultipleExpectations(A_ikn, u_kn)
   
   # Form estimates of correlation functions
   Cij = zeros([nstates, nstates], float64)
   index = 0
   for i in range(nstates):
      for j in range(nstates):
         Cij[i,j] = A_i[index]
         index += 1
         
   # Extract uncertainties in Cij.
   dCijdCkl = zeros([nstates,nstates,nstates,nstates], float64)
   index_ij = 0
   for i in range(nstates):
      for j in range(nstates):
         index_kl = 0
         for k in range(nstates):
            for l in range(nstates):
               dCijdCkl[i,j,k,l] = d2A_ij[index_ij, index_kl]
               index_kl += 1
         index_ij += 1

   # Form transition matrix estimates
   Tij = zeros([nstates, nstates], float64)
   for i in range(nstates):
      for j in range(nstates):
         Tij[i,j] = Cij[i,j] / Cij[i,:].sum()

   # Write out transition probabilities
   Tij_outfile.write('%8.3f ' % interpolated_temperatures[temperature_index])
   for i in range(nstates):
      for j in range(nstates):
         Tij_outfile.write('%24.16e ' % Tij[i,j])
   Tij_outfile.write('\n')

   # Compute estimate of stationary probabilities
   print "Pi = "
   Pi = Cij.sum(1)
   print Pi
   print "sum_i Pi = %f" % Pi.sum()
   
   # Compute derivates of Tij in terms of Cij
   dTabdCij = zeros([nstates, nstates, nstates, nstates], float64) # dTabdCij is derivative of Tab with respect to Cij
   for a in range(nstates):
      for b in range(nstates):
         for i in range(nstates):
            for j in range(nstates):
               dTabdCij[a,b,i,j] = delta(a,b,i,j) / Pi[a]
               for g in range(nstates):
                  dTabdCij[a,b,i,j] += - Tij[a,b]/Pi[a] * delta(a,g,i,j)

   # Propagate uncertainty into dTijdTkl
   print "Computing dTijdTkl..."
#   dTijdTkl = zeros([nstates, nstates, nstates, nstates], float64)
#   for a in range(nstates):
#      for b in range(nstates):
#         for c in range(nstates):
#            for d in range(nstates):
#               for i in range(nstates):
#                  for j in range(nstates):
#                     for k in range(nstates):
#                        for l in range(nstates):
#                           dTijdTkl[a,b,c,d] += dTabdCij[a,b,i,j] * dTabdCij[c,d,k,l] * dCijdCkl[i,j,k,l]


   # Propagate uncertainty into d2Tij
   d2Tij = zeros([nstates, nstates], float64)
   for a in range(nstates):
      for b in range(nstates):
         for i in range(nstates):
            for j in range(nstates):
               for k in range(nstates):
                  for l in range(nstates):
                     d2Tij[a,b] += dTabdCij[a,b,i,j] * dTabdCij[a,b,k,l] * dCijdCkl[i,j,k,l]

   # Write out uncertainty in transition probabilities
   dTij_outfile.write('%8.3f ' % interpolated_temperatures[temperature_index])
   for i in range(nstates):
      for j in range(nstates):
         dTij_outfile.write('%24.16e ' % sqrt(d2Tij[i,j]))
   dTij_outfile.write('\n')

   # Write combined
   outfile.write('%8.3f ' % interpolated_temperatures[temperature_index])
   for i in range(nstates):
      for j in range(nstates):
         outfile.write('%24.16e %24.16e %24.16e ' % (Tij[i,j], sqrt(d2Tij[i,j]), 2 * sqrt(d2Tij[i,j])))
   outfile.write('\n')
   
# Close files.
Tij_outfile.close()
dTij_outfile.close()
outfile.close()
