#!/usr/bin/python

# Example illustrating the use of MBAR for computing the hydration free energy of OPLS 3-methylindole
# in TIP3P water through alchemical free energy simulations.

#===================================================================================================
# IMPORTS
#===================================================================================================

import commands
import glob
import os
import re

import numpy

import pymbar # multistate Bennett acceptance ratio estimator.
import timeseries # for timeseries analysis 

#===================================================================================================
# CONSTANTS
#===================================================================================================

# Factor converting a pV term from atm*nm^3 to kJ/mol:
# (Pa per atm) * (m^3 per nm^3) * (Avogadro's number) / (J per kJ).
convert_atmnm3_to_kJmol = 1.01325e5*(1e-09)**3 * 6.02214 * (1e23) / 1000

# Boltzmann's constant in kJ/mol/K, i.e. k_B * N_A / 1000 with k_B ~ 1.381e-23 J/K and
# N_A ~ 6.02214e23 /mol (the powers of ten cancel).
kB = 1.381*6.02214/1000.0

relative_tolerance = 1e-10  # convergence tolerance handed to pymbar's BAR and MBAR solvers
verbose = False             # when True, pymbar and the MBAR section print extra diagnostics

# Set to False if only the results for the paper are wanted (not referenced elsewhere in this script).
optional = False

# Estimators to run; this order is also the column order of the printed table.
methods = ['TI', 'FEXP', 'REXP', 'BAR', 'MBAR']


#===================================================================================================
# PARAMETERS
#===================================================================================================

temperature = 298.0  # simulation temperature (K)
pressure = 1.0       # simulation pressure (atm); unused while the pV term is disabled

nequil = 100  # leading samples of every trajectory treated as equilibration and discarded
datafile_directory = 'data/3-methylindole-38steps/'  # where the datafiles live
datafile_prefix = 'dhdl'  # filename prefix shared by all datafiles

#===================================================================================================
# MAIN
#===================================================================================================

# Inverse temperature, in 1/(kJ/mol).
beta = 1./(kB*temperature)

#===================================================================================================
# Read all snapshot data
#===================================================================================================

# Matches xvg legend entries of the form "\8D\4H \8l\4 ..." that list the lambda values of each
# state at which energy differences were evaluated. Written as a raw string so it is valid on both
# Python 2 and Python 3 (the former "\\l" escape only meant a literal 'l' in Python 2 and is an
# error in Python >= 3.7's re module).
dH_legend_pattern = re.compile(r'\\8D\\4H \\8l\\4')

# Generate a list of all datafiles with this prefix.
# NOTE: The state index k of each file is its position in this list, so the files must sort into
# alchemical-state order. The list is sorted lexicographically (independent of the shell's
# locale-dependent `ls` collation), so numeric suffixes should be zero-padded.
filenames = sorted(glob.glob(os.path.join(datafile_directory, datafile_prefix + '.*')))

# Determine number of alchemical intermediates.
K = len(filenames)
if K == 0:
    raise IOError("no files matching '%s.*' found in '%s'" % (datafile_prefix, datafile_directory))

# First pass: each snapshot occupies one line, so a file's line count is an upper bound on its
# number of snapshots. nsnapshots[k] is overwritten with the exact count during the second pass.
nsnapshots = numpy.zeros(K, int)
for k in range(K):
    with open(filenames[k], 'r') as infile:
        nsnapshots[k] = sum(1 for _ in infile)
maxn = max(nsnapshots)

# Second pass: load all of the data.
pe = numpy.zeros(K, numpy.float64)  # pe[l]: energy column for state l on the current data line
u_klt = numpy.zeros([K, K, maxn], numpy.float64)  # u_klt[k,m,t] is the reduced potential energy of snapshot t of state k evaluated at state m
dhdlt = None  # dhdlt[k,c,t]: dH/dl component c of snapshot t of state k; allocated once the first header is parsed
for k in range(K):
    filename = filenames[k]
    print("Reading %s..." % filename)
    with open(filename, 'r') as infile:
        lines = infile.readlines()

    t = 0
    n_components = 0  # number of dH/dl components, counted from the '-lambda' legend entries
    n_states = 0      # number of per-state lambda legend entries seen so far
    bPV = False       # pV column parsing is disabled (PV term eliminated for now)

    for line in lines:
        elements = line.split()
        if not elements:
            continue  # blank line: nothing to parse

        # Header lines: count the dhdl components and record the lambda vector of every state.
        # This should be modified for different file input formats.
        if line[0] in ('#', '@'):
            # Only xvg data-set legend entries ("@ s<N> ...") carry information we need.
            if line[0] == '@' and line[2:3] == 's':
                if '-lambda' in line:
                    # A dH/dl component: reallocate the lambda table for the new component count.
                    n_components += 1
                    lv = numpy.zeros([n_components, K], float)
                elif dH_legend_pattern.search(line):
                    # Lambda vector of one state; presumably written like '(0.05,0.10)"'.
                    lambdas = re.sub('[()"]', '', elements[5]).split(',')
                    for i in range(n_components):
                        lv[i, n_states] = float(lambdas[i])
                    n_states += 1
                # A "pv" legend entry would set bPV = True; disabled while testing.
            continue

        # Data line: the number of components is only known once the header has been read.
        if dhdlt is None:
            dhdlt = numpy.zeros([K, n_components, maxn], float)

        elements.pop(0)  # time column (unused)

        # Derivatives of the energy with respect to each lambda component.
        for nl in range(n_components):
            dhdlt[k, nl, t] = float(elements.pop(0))
        # Energies evaluated at each of the K states.
        for l in range(K):
            pe[l] = float(elements.pop(0))
        # Pressure-volume contribution.
        pv = float(elements.pop(0)) if bPV else 0.0

        # Reduced potential energy at every state.
        u_klt[k, :, t] = beta * (pe + pv)
        t += 1
    nsnapshots[k] = t

#===================================================================================================
# Preliminaries: Subsample data to obtain uncorrelated samples
#===================================================================================================

# The arrays are sized for the longest post-equilibration trajectory; state k fills only the first
# N_k[k] sample slots.
Nequilibrated = max(nsnapshots) - nequil
u_kln = numpy.zeros([K, K, Nequilibrated], numpy.float64)  # u_kln[k,m,n]: reduced potential of uncorrelated sample n from state k, evaluated at state m
dhdl = numpy.zeros([K, n_components, Nequilibrated], float)  # dhdl[k,c,n]: dH/dl component c of uncorrelated sample n from state k
N_k = numpy.zeros(K, int)  # N_k[k]: number of uncorrelated samples kept from state k
g = numpy.zeros(K, float)  # g[k]: statistical inefficiency of state k's data (printed as "correlation times")

for k in range(K):
    # The autocorrelation analysis uses the total dH/dl (summed over components) after the first
    # `nequil` snapshots; the energy differences could be used instead.
    production = numpy.sum(dhdlt[k, :, :], axis=0)[nequil:nsnapshots[k]]
    g[k] = timeseries.statisticalInefficiency(production)
    # Indices relative to the production window, shifted back to absolute snapshot indices.
    kept = numpy.array(timeseries.subsampleCorrelatedData(production, g=g[k])) + nequil
    N_k[k] = len(kept)
    # dhdlt[k] / u_klt[k] are 2-D, so [:, kept] keeps the component/state axis first.
    dhdl[k, :, 0:N_k[k]] = dhdlt[k][:, kept]
    u_kln[k, :, 0:N_k[k]] = u_klt[k][:, kept]

print("Correlation times:")
print(g)
print("")
print("number of uncorrelated samples:")
print(N_k)
print("")

#===================================================================================================
# Preliminaries: Calculate average TI values
#===================================================================================================

# ave_dhdl[k,c]: mean of dH/dl component c over the uncorrelated samples of state k.
# std_dhdl[k,c]: its standard error, std / sqrt(N_k[k] - 1).
ave_dhdl = numpy.zeros([K, n_components], float)
std_dhdl = numpy.zeros([K, n_components], float)

for k in range(K):
    samples = dhdl[k, :, 0:N_k[k]]  # (n_components, N_k[k]) block of valid samples
    ave_dhdl[k, :] = samples.mean(axis=1)
    std_dhdl[k, :] = samples.std(axis=1) / numpy.sqrt(N_k[k] - 1)

#===================================================================================================
# Calculate free energies with different methods
#===================================================================================================

# df_allk[k][name] / ddf_allk[k][name]: dimensionless free energy difference between states k and
# k+1, and its uncertainty, as estimated by method `name`. ('MBAR' entries are filled in later.)
df_allk = list()
ddf_allk = list()

for k in range(K-1):
    df = dict()
    ddf = dict()

    # Reduced work for the forward (k -> k+1) and reverse (k+1 -> k) perturbations, computed once
    # per interval so that BAR no longer depends on FEXP/REXP having run earlier in `methods`.
    # Taking the difference with the sampled state's own reduced potential gives the actual work.
    # It equals u_kln[k,k+1] when that self-energy column is zero, and cancels any per-snapshot
    # constant such as the pV term.
    w_F = u_kln[k, k+1, 0:N_k[k]] - u_kln[k, k, 0:N_k[k]]
    w_R = u_kln[k+1, k, 0:N_k[k+1]] - u_kln[k+1, k+1, 0:N_k[k+1]]

    for name in methods:
        if name == 'TI':
            # Trapezoid rule along the lambda path. Multiply by beta so the result is
            # dimensionless like the other estimators; uncertainties are added in quadrature.
            dlam = lv[:, k+1] - lv[:, k]
            df['TI'] = 0.5*beta*numpy.dot(dlam, (ave_dhdl[k] + ave_dhdl[k+1]))
            ddf['TI'] = 0.5*beta*numpy.sqrt(numpy.dot(dlam**2, std_dhdl[k]**2 + std_dhdl[k+1]**2))

        elif name == 'FEXP':
            # Forward-direction exponential averaging (EXP / Zwanzig).
            (df['FEXP'], ddf['FEXP']) = pymbar.EXP(w_F)

        elif name == 'REXP':
            # Reverse-direction EXP; negate to express it as the k -> k+1 difference.
            (rdf, rddf) = pymbar.EXP(w_R)
            (df['REXP'], ddf['REXP']) = (-rdf, rddf)

        elif name == 'BAR':
            # Bennett acceptance ratio from both work distributions.
            (df['BAR'], ddf['BAR']) = pymbar.BAR(w_F, w_R, relative_tolerance=relative_tolerance, verbose=verbose)

        # 'MBAR' is computed for all states at once in a separate step.

    df_allk.append(df)
    ddf_allk.append(ddf)

for name in methods:
    if name != 'MBAR':
        continue

    #===================================================================================================
    # Estimate free energy difference with MBAR -- all states at once
    #===================================================================================================

    # Solve the MBAR equations for every state simultaneously (may take a while); the
    # Newton-Raphson solver is requested because it is faster.
    print("Computing free energy differences...")
    mbar = pymbar.MBAR(u_kln, N_k, verbose=verbose, method='Newton-Raphson',
                       relative_tolerance=relative_tolerance)

    if verbose:
        print("Computing covariance matrix...")

    # Matrices of dimensionless free energy differences and their uncertainties; entry [i,j] is
    # used as the difference from state i to state j.
    (Deltaf_ij, dDeltaf_ij) = mbar.getFreeEnergyDifferences(uncertainty_method='svd-ew')

    if verbose:
        print("Deltaf_ij:")
        print(Deltaf_ij)
        print("dDeltaf_ij:")
        print(dDeltaf_ij)

    # Record the neighbouring-state (k -> k+1) values next to the pairwise estimators.
    adjacent_f = numpy.diagonal(Deltaf_ij, offset=1)
    adjacent_df = numpy.diagonal(dDeltaf_ij, offset=1)
    for k in range(K-1):
        df_allk[k]['MBAR'] = adjacent_f[k]
        ddf_allk[k]['MBAR'] = adjacent_df[k]


#All done with calculations, now summarize and print stats

# Totals over the whole path. MBAR supplies the end-to-end difference directly; for the pairwise
# methods the per-interval estimates are summed and their uncertainties combined in quadrature.
dF = dict()
ddF = dict()
for name in methods:
    if name == 'MBAR':
        dF['MBAR'] = Deltaf_ij[0, K-1]
        ddF['MBAR'] = dDeltaf_ij[0, K-1]
        continue
    dF[name] = sum(df_allk[k][name] for k in range(K-1))
    ddF[name] = numpy.sqrt(sum(ddf_allk[k][name]**2 for k in range(K-1)))


def _print_row(items):
    """Print `items` separated by single spaces, plus a trailing space when non-empty.

    Reproduces the exact output of a chain of Python 2 `print item,` statements followed by
    `print ''`.
    """
    line = ' '.join(items)
    if items:
        line += ' '
    print(line)


# Table of per-interval and total free energies, converted from kT units to kJ/mol.
_print_row(['%10s (kJ/mol)' % name for name in methods])
for k in range(K-1):
    _print_row(['%5d: ' % k] +
               ['%8.3f +- %6.3f' % (df_allk[k][name]/beta, ddf_allk[k][name]/beta) for name in methods])
_print_row(['-------------------' for name in methods])
_print_row(['TOTAL: '] +
           ['%8.3f +- %6.3f' % (dF[name]/beta, ddF[name]/beta) for name in methods])
