#!/usr/bin/python

# Estimate 1D potential of mean force for psi torsion of alanine dipeptide parallel tempering data using MBAR.
# Gaussian kernels are used for PMF estimate.
#
# PROTOCOL
#
# * Potential energies and (phi, psi) torsions from parallel tempering simulation are read in by temperature
# * WHAM [2] is used to rapidly generate an initial guess for the dimensionless free energies f_k.
# * Replica trajectories of potential energies and torsions are reconstructed to reflect their true temporal
# correlation, and then subsampled to produce statistically independent samples, collecting them again by temperature
# * The MBAR class is initialized with this initial guess at dimensionless free energies f_k, reducing time for
# solution of self-consistent equations
# * Gaussian kernels of fixed bandwidth are centered on a regular grid of psi torsion values
# * The relative free energies and uncertainties of these kernel-biased states at the temperature of interest are estimated
# * The 1D PMF is written out
# 
#
# REFERENCES
#
# [1] Shirts MR and Chodera JD. Statistically optimal analysis of samples from multiple equilibrium states.
# J. Chem. Phys. 129:124105, 2008
# http://dx.doi.org/10.1063/1.2978177
#
# [2] Kumar S, Bouzida D, Swensen RH, Kollman PA, and Rosenberg JM. The weighted histogram analysis method
# for free-energy calculations on biomolecules. I. The Method. J. Comput Chem. 13:1011, 1992.

#===================================================================================================
# IMPORTS
#===================================================================================================

import numpy
from math import *
import pymbar # for MBAR analysis
import timeseries # for timeseries analysis
import commands
import os
import os.path

#===================================================================================================
# CONSTANTS
#===================================================================================================

# kB [J/K] times Avogadro's number [1/mol], divided by 4184 J/kcal.
# The 1e-23 and 1e+23 exponents of the two constants cancel, so only their mantissas appear.
kB = 1.3806503 * 6.0221415 / 4184.0 # Boltzmann constant in kcal/mol/K

#===================================================================================================
# PARAMETERS
#===================================================================================================

data_directory = '../../../../datasets/alanine-dipeptide/parallel-tempering-10ns' # directory containing the parallel tempering data
temperature_list_filename = os.path.join(data_directory, 'temperatures') # file containing temperatures in K (all on the first line)
potential_energies_filename = os.path.join(data_directory, 'energies', 'potential-energies') # file containing potential energies (in kcal/mol); one line per snapshot, one column per temperature
trajectory_segment_length = 20 # number of snapshots in each contiguous trajectory segment (one segment per iteration)
niterations = 500 # number of iterations to use
target_temperature = 302 # target temperature for 1D PMF (in K)
nbins = 180 # number of Gaussian kernel centers spaced evenly along psi
bandwidth = 20 # bandwidth (in degrees) of each Gaussian kernel

#===================================================================================================
# SUBROUTINES
#===================================================================================================

def read_file(filename):
   """Read contents of the specified file.

   ARGUMENTS
     filename (string) - the name of the file to be read

   RETURNS
     lines (list of strings) - the contents of the file, split by line
       (each line keeps its trailing newline, as returned by readlines())

   """

   # The context manager closes the file even if reading raises.
   with open(filename, 'r') as infile:
      return infile.readlines()

def logSum(log_terms):
   """Return the logarithm of a sum, given the logarithms of its terms.

   REQUIRED ARGUMENTS
      log_terms is the array (possibly multidimensional) holding the logs of the terms to be summed.

   RETURN VALUES
      log_sum is the log of the sum of the terms.

   """

   # Shift by the largest entry so that every exponential is at most 1.
   shift = log_terms.max()
   shifted_sum = numpy.exp(log_terms - shift).sum()

   # Undo the shift in log space.
   return log(shifted_sum) + shift
 
#===================================================================================================
# MAIN
#===================================================================================================

#===================================================================================================
# Read temperatures
#===================================================================================================

# The temperature ladder is stored as whitespace-separated values on the first line of the file.
lines = read_file(temperature_list_filename)
temperatures = lines[0].split()
K = len(temperatures) # number of temperatures
# temperature_k[k] is the temperature (in K) of temperature index k
temperature_k = numpy.array([float(entry) for entry in temperatures], numpy.float32)
# beta_k[k] is the inverse temperature 1/(kB T) of temperature index k
beta_k = (kB * temperature_k)**(-1)

# Total number of snapshots per temperature: one fixed-length segment per iteration.
T = trajectory_segment_length * niterations

#===================================================================================================
# Read potential energies
#===================================================================================================

print("Reading potential energies...")
# U_kt[k,t] is the potential energy (in kcal/mol) for snapshot t of temperature index k
U_kt = numpy.zeros([K,T], numpy.float32)
lines = read_file(potential_energies_filename)
print("%d lines read, processing %d snapshots" % (len(lines), T))
for t in range(T):
   # Line t holds the energies of snapshot t, one column per temperature.
   elements = lines[t].split()
   U_kt[:,t] = [float(elements[k]) for k in range(K)]

#===================================================================================================
# Read phi, psi trajectories
#===================================================================================================

print("Reading phi, psi trajectories...")
torsion_directory = os.path.join(data_directory, 'backbone-torsions')
# phi_kt[k,t] and psi_kt[k,t] are the phi and psi angles (in degrees) for snapshot t of temperature k
phi_kt = numpy.zeros([K,T], numpy.float32)
psi_kt = numpy.zeros([K,T], numpy.float32)
for k in range(K):
   # Each temperature has its own pair of files, named by temperature index, with one angle per line.
   phi_filename = os.path.join(torsion_directory, '%d.phi' % k)
   psi_filename = os.path.join(torsion_directory, '%d.psi' % k)
   phi_lines = read_file(phi_filename)
   psi_lines = read_file(psi_filename)
   print("k = %d, %d phi lines read, %d psi lines read" % (k, len(phi_lines), len(psi_lines)))
   # Only the first T lines of each file are used.
   phi_kt[k,:] = [float(phi_lines[t]) for t in range(T)]
   psi_kt[k,:] = [float(psi_lines[t]) for t in range(T)]

#===================================================================================================
# Read replica indices
#===================================================================================================

print("Reading replica indices...")
filename = os.path.join(data_directory, 'replica-indices')
lines = read_file(filename)
# replica_ik[i,k] is the replica index of temperature k for iteration i
replica_ik = numpy.zeros([niterations,K], numpy.int32)
for i in range(niterations):
   # Line i holds one replica index per temperature for iteration i.
   elements = lines[i].split()
   replica_ik[i,:] = [int(elements[k]) for k in range(K)]
print("Replica indices for %d iterations processed." % niterations)

#===================================================================================================
# Permute data by replica and subsample to generate an uncorrelated subset of data by temperature
#===================================================================================================

# Produce U_kn, phi_kn, psi_kn (samples collected by temperature), N_k (samples per temperature)
# and N_max (largest N_k), either from all snapshots or from a decorrelated subset.
assume_uncorrelated = False
if assume_uncorrelated:
   # DEBUG - use all data, assuming it is uncorrelated
   print("Using all data, assuming it is uncorrelated...")
   U_kn = U_kt.copy()
   phi_kn = phi_kt.copy()
   psi_kn = psi_kt.copy()
   N_k = numpy.zeros([K], numpy.int32)
   N_k[:] = T
   N_max = T
else:
   # Permute data by replica
   print("Permuting data by replica...")
   U_kt_replica = U_kt.copy()
   phi_kt_replica = phi_kt.copy() # start from the phi data (previously a copy of psi_kt by mistake)
   psi_kt_replica = psi_kt.copy()
   for iteration in range(niterations):
      # Determine which snapshot indices are associated with this iteration
      snapshot_indices = iteration*trajectory_segment_length + numpy.arange(0,trajectory_segment_length)
      for k in range(K):
         # Determine which replica generated the data from temperature k at this iteration
         replica_index = replica_ik[iteration,k]
         # Reconstruct portion of replica trajectory.
         U_kt_replica[replica_index,snapshot_indices] = U_kt[k,snapshot_indices]
         phi_kt_replica[replica_index,snapshot_indices] = phi_kt[k,snapshot_indices]
         psi_kt_replica[replica_index,snapshot_indices] = psi_kt[k,snapshot_indices]
   # Estimate the statistical inefficiency of the simulation by analyzing the timeseries of interest.
   # We use the max over cos and sin of the phi and psi timeseries because they are periodic angles.
   print("Computing statistical inefficiencies...")
   g_cosphi = timeseries.statisticalInefficiencyMultiple(numpy.cos(phi_kt_replica * numpy.pi / 180.0))
   print("g_cos(phi) = %.1f" % g_cosphi)
   g_sinphi = timeseries.statisticalInefficiencyMultiple(numpy.sin(phi_kt_replica * numpy.pi / 180.0))
   print("g_sin(phi) = %.1f" % g_sinphi)
   g_cospsi = timeseries.statisticalInefficiencyMultiple(numpy.cos(psi_kt_replica * numpy.pi / 180.0))
   print("g_cos(psi) = %.1f" % g_cospsi)
   g_sinpsi = timeseries.statisticalInefficiencyMultiple(numpy.sin(psi_kt_replica * numpy.pi / 180.0))
   print("g_sin(psi) = %.1f" % g_sinpsi)
   # Subsample data with maximum of all correlation times.
   print("Subsampling data...")
   g = numpy.max(numpy.array([g_cosphi, g_sinphi, g_cospsi, g_sinpsi]))
   # Every row of U_kt has length T, so row 0 is passed explicitly instead of relying on a
   # leftover loop variable. NOTE(review): with g supplied, the series presumably only fixes
   # the number of snapshots to choose from - confirm against the timeseries module.
   indices = timeseries.subsampleCorrelatedData(U_kt[0,:], g = g)
   print("Using g = %.1f to obtain %d uncorrelated samples per temperature" % (g, len(indices)))
   # Size the arrays from the indices actually returned; ceil(T / g) can differ from
   # len(indices), which would make the row assignments below fail on a shape mismatch.
   N_max = len(indices) # max number of samples per temperature
   U_kn = numpy.zeros([K, N_max], numpy.float64)
   phi_kn = numpy.zeros([K, N_max], numpy.float64)
   psi_kn = numpy.zeros([K, N_max], numpy.float64)
   N_k = N_max * numpy.ones([K], numpy.int32)
   # The same snapshot indices are applied to the temperature-ordered data of every temperature.
   for k in range(K):
      U_kn[k,:] = U_kt[k,indices]
      phi_kn[k,:] = phi_kt[k,indices]
      psi_kn[k,:] = psi_kt[k,indices]
   print("%d uncorrelated samples per temperature" % N_max)
         
#===================================================================================================
# Generate a list of indices of all configurations in kn-indexing
#===================================================================================================

# Build a boolean mask over the [K,N_max] sample arrays: mask_kn[k,n] is True for the first N_k[k] entries of row k.
# The builtin bool is used as the dtype because the numpy.bool alias was deprecated in NumPy 1.20 and removed in 1.24.
mask_kn = numpy.zeros([K,N_max], dtype=bool)
for k in range(0,K):
   mask_kn[k,0:N_k[k]] = True
# Convert the mask into index arrays (numpy.where returns one array per dimension).
indices = numpy.where(mask_kn)

#===================================================================================================
# Compute reduced potential energy of all snapshots at all temperatures
#===================================================================================================

print("Computing reduced potential energies...")
# u_kln[k,l,n] is the reduced potential energy of sample n from temperature k evaluated at temperature l
u_kln = numpy.zeros([K,K,N_max], numpy.float32)
for k in range(K):
   n_valid = N_k[k]
   # Broadcast the column of inverse temperatures against the row of energies sampled at temperature k.
   u_kln[k,:,0:n_valid] = beta_k[:,numpy.newaxis] * U_kn[k,numpy.newaxis,0:n_valid]

#===================================================================================================
# Initialize MBAR.
#===================================================================================================

# Set up the MBAR estimator using the Newton-Raphson solver.
# NOTE(review): initialize='BAR' presumably seeds the free energies from BAR estimates - confirm against the pymbar documentation.
print("Initializing MBAR (will estimate free energy differences first time)...")
mbar = pymbar.MBAR(
   u_kln,
   N_k,
   method='Newton-Raphson',
   verbose=True,
   initialize='BAR',
)

#===================================================================================================
# Compute PMF at the desired temperature.
#===================================================================================================

print("Computing potential of mean force...")

# Construct Gaussian kernels.
target_beta = 1.0 / (kB * target_temperature)
# bin_centers[i] is the center (in degrees) of kernel i; centers are the midpoints of nbins equal intervals covering -180 to 180
bin_centers = numpy.zeros([nbins], numpy.float32)
# u_kln[k,i,n] is the reduced potential at the target temperature plus the harmonic bias of kernel i, for sample n of temperature k
u_kln = numpy.zeros([K,nbins,N_max], numpy.float32)
# Force constant of the harmonic bias (in degrees**-2), so that exp(-bias) is a Gaussian of width bandwidth.
# It is kept under its own name so that K, the number of temperatures, is not overwritten.
kernel_force_constant = float(bandwidth)**(-2)
for i in range(nbins):
   center = -180 + (i+0.5)*(360.0 / nbins) # bin center
   bin_centers[i] = center
   delta = abs(psi_kn[:,:] - center)
   # Wrap around: psi is periodic, so use the shorter way around the circle.
   delta = numpy.where(delta > 180.0, 360.0 - delta, delta)
   u_kln[:,i,:] = target_beta * U_kn[:,:] + (kernel_force_constant/2)*delta**2
# f_ij and d2f_ij are indexed [i,j] over the kernels.
[f_ij, d2f_ij] = mbar.computePerturbedFreeEnergies(u_kln)

# Find index of bin with lowest free energy.
imin = f_ij[0,:].argmin()

# Format each row once so that the screen output and the file contents cannot drift apart.
# Columns: kernel index, kernel center psi (degrees), free energy relative to kernel imin, and its uncertainty.
# NOTE(review): the sqrt treats d2f_ij as a variance; confirm that computePerturbedFreeEnergies
# returns variances rather than standard deviations in the pymbar version in use.
pmf_rows = ['%8d %6.1f %10.3f %10.3f' % (i, bin_centers[i], f_ij[imin,i], sqrt(d2f_ij[imin,i])) for i in range(nbins)]

# Show free energy and uncertainty of each bin relative to lowest free energy
print("PMF")
print("")
print("%8s %6s %10s %10s" % ('bin', 'psi', 'f', 'df'))
for row in pmf_rows:
   print(row)

# Write the same rows to pmf.out; the context manager closes the file even if a write fails.
with open('pmf.out', 'w') as outfile:
   for row in pmf_rows:
      outfile.write(row + '\n')



