#!/usr/bin/python

# Example illustrating the use of MBAR for computing the hydration free energy of OPLS 3-methylindole
# in TIP3P water through alchemical free energy simulations.

#===================================================================================================
# IMPORTS
#===================================================================================================

from numpy import *
from math import * # must stay after numpy so math's scalar functions (sqrt, ...) take precedence
import commands
import glob
import os
import pdb

import MBAR # multistate Bennett acceptance ratio estimator.
import timeseries # for timeseries analysis
#===================================================================================================
# CONSTANTS
#===================================================================================================

# Exact SI values (2019 redefinition): Avogadro constant and Boltzmann constant.
AVOGADRO = 6.02214076e23 # Avogadro constant (1/mol)
# Convert pV from atmospheres*nm^3 into kJ/mol: 1 atm = 1.01325e5 Pa, 1 nm^3 = 1e-27 m^3,
# Pa*m^3 = J per molecule, times Avogadro's number per mole, divided by 1000 for kJ.
convert_atmnm3_to_kJmol = 1.01325e5*(1e-09)**3 * AVOGADRO / 1000
# Boltzmann's constant (kJ/mol/K), i.e. the molar gas constant R / 1000.
# The previous value 1.381*6.02214/1000 was truncated to 4 significant figures.
kB = 1.380649e-23 * AVOGADRO / 1000.0

#===================================================================================================
# PARAMETERS
#===================================================================================================

# Thermodynamic conditions: temperature (K) and pressure (atm).
temperature, pressure = 283.0, 1.0

# Location of the input data: every file in datafile_directory whose name starts
# with one of these prefixes belongs to that prefix's set (one file per state).
datafile_directory = 'data'
datafile_prefixes = ['trptest']

#===================================================================================================
# MAIN
#===================================================================================================

# Inverse thermal energy beta = 1/(kB*T), in 1/(kJ/mol); multiplying an energy in
# kJ/mol by beta gives a dimensionless (reduced) energy.
kT = kB * temperature
beta = 1.0 / kT

#===================================================================================================
# Read all snapshot data
#===================================================================================================

def _read_data_lines(filename):
   """Return the snapshot records of one datafile as a list of text lines.

   Blank lines and header lines whose first non-blank character is '#' or '@'
   (xvg-style comments and plot directives) are dropped, so position t in the
   returned list is snapshot t of the file.
   """
   with open(filename, 'r') as infile:
      return [line for line in infile
              if line.strip() and line.lstrip()[0] not in '#@']

phases = list() # storage for overall estimate of free energy difference (in kJ/mol)
for datafile_prefix in datafile_prefixes:
   # Collect all datafiles of this set; each file holds the snapshots sampled in one
   # alchemical state. States are ordered by a plain lexicographic sort of the
   # filenames (independent of the shell locale, unlike 'ls'), so numeric state labels
   # in the filenames should be zero-padded (l01, l02, ..., l10); otherwise 'l10'
   # sorts before 'l2'.
   pattern = os.path.join(datafile_directory, datafile_prefix + '*')
   filenames = sorted(glob.glob(pattern))

   # Determine number of alchemical intermediates.
   K = len(filenames)
   if K == 0:
      raise IOError('no datafiles match %s' % pattern)

   # Determine number of snapshots (data lines only, headers excluded) in each file.
   nsnapshots = zeros(K, int32) # nsnapshots[k] is the number of snapshots for state k
   for k, filename in enumerate(filenames):
      nsnapshots[k] = len(_read_data_lines(filename))
      if nsnapshots[k] == 0:
         raise ValueError('%s contains no snapshot lines' % filename)

   # Load all of the data.
   #
   # DATAFILE FORMAT: one snapshot per line, whitespace-separated columns:
   #   column 0        time (ps)
   #   column 1        fep-lambda value of the sampled state
   #   columns 2..K+1  potential energy (kJ/mol) of the snapshot evaluated at each of the K states
   #   column K+2      box volume (nm^3) -- see the pV note below
   # Any further columns are ignored.
   #
   # Example line:
   #    1000.000 -3.5852624038e+04 -3.5862709533e+04 -3.5890918437e+04 -3.5896703874e+04 -3.5908245048e+04   26.976760
   # NOTE(review): the example shows no lambda column between the time and the
   # energies, yet the parser skips one; confirm the layout of the actual data.
   u_klt = zeros([K,K,max(nsnapshots)], float64) # u_klt[k,m,t] is the reduced potential energy of snapshot t of state k evaluated at state m
   for k, filename in enumerate(filenames):
      print("Reading %s..." % filename)
      for t, line in enumerate(_read_data_lines(filename)):
         elements = line.split()
         if len(elements) < K + 3:
            raise ValueError('%s: snapshot %d has %d columns, expected at least %d'
                             % (filename, t, len(elements), K + 3))

         # Columns 0 (time) and 1 (lambda) are not needed for the estimate.
         # Potential evaluated at all states (kJ/mol).
         U_l = array([float(value) for value in elements[2:2+K]], float64)

         # NOTE(review): the last column is used unchanged as the pV term in kJ/mol.
         # That is only right if the data already holds pV in kJ/mol; if it holds a
         # box volume in nm^3 it must be converted with
         # pressure * volume * convert_atmnm3_to_kJmol. Confirm against the data source.
         pV = float(elements[2+K])

         # Reduced potential energy of snapshot t at every state.
         # NOTE(review): the energies are negated before scaling by beta; confirm this
         # matches the sign convention of the energy columns.
         u_klt[k,:,t] = beta * (-U_l + pV)

   #===================================================================================================
   # Subsample data to obtain uncorrelated samples
   #===================================================================================================

   u_kln = zeros([K,K,max(nsnapshots)], float64) # u_kln[k,m,n] is the reduced potential energy of uncorrelated sample index n from state k evaluated at state m
   N_k = zeros(K, int32) # N_k[k] is the number of uncorrelated samples from state k
   for k in range(K):
      # Determine indices of uncorrelated samples from reduced potential energy autocorrelation analysis at state k.
      indices = timeseries.subsampleCorrelatedData(u_klt[k,k,0:nsnapshots[k]]) # indices of uncorrelated samples
      N = len(indices) # number of uncorrelated samples
      N_k[k] = N
      # u_klt[k] is the (K, T) slab of state k; selecting along its last axis keeps the
      # (state, sample) axis order. (u_klt[k,:,indices] would put the sample axis first.)
      u_kln[k,:,0:N] = u_klt[k][:,indices]

   print("number of uncorrelated samples:")
   print(N_k)
   print("")

   #===================================================================================================
   # Estimate free energy difference with MBAR.
   #===================================================================================================

   # Initialize MBAR (computing free energy estimates, which may take a while).
   # Newton-Raphson is used because it is faster than the default 'self-consistent-iteration'.
   print("Computing free energy differences...")
   mbar = MBAR.MBAR(u_kln, N_k, verbose = True, method = 'Newton-Raphson')

   # Get matrix of dimensionless free energy differences and uncertainty estimate.
   print("Computing covariance matrix...")
   (Deltaf_ij, dDeltaf_ij) = mbar.getFreeEnergyDifferences()

   # Matrix of free energy differences
   print("Deltaf_ij:")
   print(Deltaf_ij)
   # Matrix of uncertainties in free energy difference (expectations standard deviations of the estimator about the true free energy)
   print("dDeltaf_ij:")
   print(dDeltaf_ij)

   # Record the difference between the first and last states, converted from
   # dimensionless units back to kJ/mol.
   phases.append({
      'DeltaF' : Deltaf_ij[0,K-1] / beta,
      'dDeltaF' : dDeltaf_ij[0,K-1] / beta,
      'phase' : datafile_prefix,
   })

# Report each phase's free energy difference, then the total over all phases.
# Totals add the differences directly; uncertainties are combined in quadrature
# (square root of the sum of squared per-phase uncertainties).
# (The accumulation stays an explicit loop: 'sum' here is numpy's, via the star import.)
line_format = '%32s %16.8f +- %16.8f kJ/mol'
DeltaF, d2DeltaF = 0.0, 0.0
for phase in phases:
   print(line_format % (phase['phase'], phase['DeltaF'], phase['dDeltaF']))
   DeltaF = DeltaF + phase['DeltaF']
   d2DeltaF = d2DeltaF + phase['dDeltaF'] ** 2
dDeltaF = sqrt(d2DeltaF)
print("")
print(line_format % ('TOTAL', DeltaF, dDeltaF))

