#!/usr/bin/python

# Example illustrating the use of MBAR for computing the hydration free energy of OPLS 3-methylindole
# in TIP3P water through alchemical free energy simulations.

#===================================================================================================
# IMPORTS
#===================================================================================================

import commands
import glob
import os
import pdb
import re

import numpy

import pymbar # multistate Bennett acceptance ratio estimator.
import timeseries # for timeseries analysis 

#===================================================================================================
# CONSTANTS
#===================================================================================================

# Convert pV from atmospheres*nm^3 into kJ/mol (1 atm = 1.01325e5 Pa, 1 nm^3 = 1e-27 m^3).
convert_atmnm3_to_kJmol = 1.01325e5*(1e-09)**3 * 6.02214076 * (1e23) / 1000
# Boltzmann's constant (kJ/mol/K) from the exact SI values k_B = 1.380649e-23 J/K and
# N_A = 6.02214076e23 /mol, i.e. R/1000 = 0.0083144626 kJ/mol/K.  The rounded form
# 1.381*6.02214/1000 is ~0.025% too large, which biases beta and hence every estimate
# below that depends on reduced energies (EXP, BAR, MBAR).
kB = 1.380649*6.02214076/1000.0
relative_tolerance = 1e-10 # convergence tolerance passed to pymbar's BAR and MBAR solvers
verbose = True
# NOTE(review): 'optional' and 'bExpanded' are not read anywhere else in this script.
optional = False  # set this to False if we only want the results for the paper
#bEnergy = True # whether or not the energy is being printed in the dhdl file
bEnergy = False # whether or not the total energy is printed as a column in the dhdl file
# Is this an expanded ensemble file, such that all the states are in one file?  (The parser
# always takes the state index from the second data column, independent of this flag.)
bExpanded = True

# Free energy estimators to report, in the column order used for the output table.
methods = ['TI','DEXP','IEXP','BAR','MBAR']

#===================================================================================================
# PARAMETERS
#===================================================================================================

temperature = 298.0 # temperature (K)
pressure = 1.0 # pressure (atm)

nequil = 100 # number of samples from each state assumed to be equilibrating, and thus omitted.
#datafile_directory = 'datadirectory' # directory in which datafiles are stored
datafile_directory = '/bigtmp/mrs5ptstore/' # directory in which datafiles are stored
#datafile_prefix  = 'wang-landau' # prefixes for datafile sets 
#datafile_prefix  = 'tbut_p5_p005.dhdl' # prefixes for datafile sets
#datafile_prefix  = 'tbut_p25_c10v12.dhdl' # prefixes for datafile sets
datafile_prefix  = 'CONH2_c10v10.dhdl' # prefixes for datafile sets 

#===================================================================================================
# HELPER FUNCTIONS
#===================================================================================================
def sortbynum(item):
   """Return the numerical sort key of a datafile name.

   The key is the last '.'-separated field of the file's base name that consists only
   of digits, e.g. 'CONH2_c10v10.dhdl.12.xvg' -> 12.  Only the base name is examined,
   so digits in directory names (e.g. '/data.2/...') are never used.

   If no such field exists, an error message is printed and -1 is returned.  -1 sorts
   before every valid (non-negative) key, and unlike None it is comparable with ints
   under Python 3 as well.  A name without digits occurs, for example, when all states
   are stored in a single file, so this case must not abort the sort.
   """
   fields = os.path.basename(item).split('.')
   for field in reversed(fields):
      if field.isdigit():
         return int(field)
   print("Error: No digits found in filename, can't sort %s" % item)
   return -1


#===================================================================================================
# MAIN
#===================================================================================================

# Compute inverse temperature in 1/(kJ/mol)
beta = 1./(kB*temperature)

#===================================================================================================
# Read all snapshot data
#===================================================================================================

# Generate a list of all datafiles with this prefix, using a single listing so that n_files
# always matches the list that is actually read.  glob is used instead of parsing `ls`
# output: when `ls` fails, its error message would otherwise be split into bogus
# "filenames".  Directories that happen to match the pattern are skipped.
# The alphabetical pre-sort mimics `ls`, so files that sortbynum cannot rank keep that order.
filenames = sorted(f for f in glob.glob(os.path.join(datafile_directory, datafile_prefix + '*'))
                   if os.path.isfile(f))
if len(filenames) == 0:
   raise IOError("No datafiles matching '%s*' found in directory '%s'" % (datafile_prefix, datafile_directory))

# NOTE: The numerical file order (see sortbynum) determines the order of the states.
filenames.sort(key=sortbynum)

# Determine number of files
n_files = len(filenames)

# nsnapshots[nf] is the number of data (non-header) lines in file nf.  It is used only as
# an upper bound (maxn) for sizing the snapshot arrays.
nsnapshots = numpy.zeros(n_files, int)
for nf in range(n_files):
   with open(filenames[nf], 'r') as infile:
      # xvg comment lines start with '#', legend/metadata lines with '@'.
      nsnapshots[nf] = sum(1 for line in infile if not line.startswith(('#', '@')))

# Determine maximum number of snapshots from any file
maxn = max(nsnapshots)

# Parse the xvg header of the FIRST file to determine:
#   n_components -- number of dH/dlambda components (one '-lambda' legend entry each),
#   lv[c,k]      -- value of lambda component c at state k (one '\xD\f' legend entry per state),
#   n_states     -- number of states at which energies are listed,
#   bPV          -- whether a pV column is present.
# All other files are assumed to share this header layout.
filename = filenames[0]
# Read contents of file into memory.
print "Reading metadata from %s..." % filename
infile = open(filename, 'r')
lines = infile.readlines()
infile.close()

n_components = 0
n_states = 0
bPV = False
maxlambda = 1000 # capacity of lv's state axis; more than 1000 state entries would overflow lv
for line in lines:
   # Split line into elements.
   elements = line.split()
   
   # This section automatically parses the header to count the number
   # of dhdl components, and the number of states at which energies
   # are calculated, and should be modified for different file input formats.
   #                                                                          
   if ((line[0] == '#') or (line[0] == '@')):
      if (line[0] == '@'):
         # it's an xvg legend entry -- load in the information
         # (presumably of the form '@ s<N> legend "..."'; only those are inspected)
         if (line[2] == 's'):  
            # it's a xvg entry, and a lambda component, so note it down
            if (re.search('-lambda',line)):     
               #it's a listing of the lambdas
               n_components +=1
               # NOTE: lv is re-allocated (zeroed) every time a component is found, so all
               # component entries must precede the per-state entries in the header.
               lv = numpy.zeros([n_components,maxlambda],float)  
            elif (re.search("\\\\xD\\\\f",line)): 
               # Entry containing the literal text '\xD\f': lists the lambda vector of one state.
               # NOTE(review): assumes the lambda tuple, e.g. '(0.0000,0.1000)"', is the sixth
               # whitespace-separated token with no spaces inside it -- confirm against the format.
               lambda_string = elements[5]
               # Strip parentheses and quotes, then split into per-component values.
               lambda_list = re.sub('[()\"]','',lambda_string)
               lambdas = lambda_list.split(',');
               for i in range(n_components):
                  lv[i,n_states] = float(lambdas[i])
               n_states+=1;   
            elif (re.search("pv",line)):     
               # Any other legend entry mentioning 'pv' marks a pressure-volume column.
               bPV = 1;   

# NOTE(review): if no '-lambda' entry was found, lv is never created and the code that
# follows fails with a NameError.
print "Done reading metadata from %s..." % filename

K = n_states
# Keep only the part of the lambda table filled in by the header: lv[c,k] is the value of
# lambda component c at state k.
lv = lv[0:n_components,0:K]
# dhdlt[k,c,t] is dH/dlambda for component c of snapshot t sampled from state k.
dhdlt = numpy.zeros([K,n_components,maxn],float) 
u_klt = numpy.zeros([K,K,maxn], numpy.float64) # u_klt[k,m,t] is the reduced potential energy of snapshot t of state k evaluated at state m
pe = numpy.zeros(K,numpy.float64) # temporary storage for energies
nsnapshots = numpy.zeros(K, int) # nsnapshots[k] is the number of snapshots read so far for state k

# Load all of the data.  Each data line is consumed left to right as:
#   time, state index, [total energy if bEnergy], one dH/dlambda value per component,
#   one potential energy difference per state, [pV term if bPV].
for nf in range(n_files):
   # File to be read
   filename = filenames[nf]

   print("Reading %s..." % filename)
   with open(filename, 'r') as infile:
      lines = infile.readlines()

   for line in lines:
      # Skip xvg comment ('#') and legend/metadata ('@') lines.
      if line.startswith(('#', '@')):
         continue
      elements = line.split()
      # Skip blank lines; otherwise pop(0) on an empty list raises IndexError.
      if not elements:
         continue

      elements.pop(0) # time of the sample (not used)

      # Index of the state this snapshot was sampled from.  Reject out-of-range values:
      # a negative index would otherwise silently write into another state's slot.
      state = int(elements.pop(0))
      if not (0 <= state < K):
         raise ValueError("State index %d in %s is outside the %d states listed in the header" % (state, filename, K))

      if (bEnergy):
         energy = float(elements.pop(0))
      else:
         energy = 0

      t = nsnapshots[state] # slot for this snapshot within its state

      # Store the derivative with respect to each lambda component.
      for nl in range(n_components):
         dhdlt[state,nl,t] = float(elements.pop(0))
      # Record the potential energy at each state (energy difference plus total energy, if present).
      for l in range(K):
         pe[l] = float(elements.pop(0)) + energy

      # pressure-volume contribution
      if (bPV):
         pv = float(elements.pop(0))
      else:
         pv = 0

      # Compute and store the reduced potential energy at every state at once.
      u_klt[state,:,t] = beta * (pe + pv)

      nsnapshots[state] += 1

#===================================================================================================
# Preliminaries: Subsample data to obtain uncorrelated samples
#===================================================================================================   

# Every state needs data left over after discarding its first nequil snapshots as
# equilibration.  Otherwise the series passed to timeseries is empty, and the array sizes
# below can even become negative.  Fail early with a clear message.
for k in range(K):
   if nsnapshots[k] <= nequil:
      raise ValueError("State %d has only %d snapshots, but the first %d are discarded as equilibration" % (k, nsnapshots[k], nequil))

maxn = numpy.max(nsnapshots)
nequilibrated = maxn - nequil
u_kln = numpy.zeros([K,K,nequilibrated], numpy.float64) # u_kln[k,m,n] is the reduced potential energy of uncorrelated sample index n from state k evaluated at state m
dhdl = numpy.zeros([K,n_components,nequilibrated], float) # dhdl[k,c,n] is dH/dlambda component c of uncorrelated sample n from state k
N_k = numpy.zeros(K, int) # N_k[k] is the number of uncorrelated samples from state k
g = numpy.zeros(K,float) # statistical inefficiencies of the total dH/dlambda at each state
for k in range(K):
   # Determine indices of uncorrelated samples from the correlation analysis of the total
   # dH/dlambda (summed over components) at state k; the energy differences could be used instead.
   dhdl_sum = numpy.sum(dhdlt[k,:,nequil:nsnapshots[k]],axis=0)
   g[k] = timeseries.statisticalInefficiency(dhdl_sum)
   indices = numpy.array(timeseries.subsampleCorrelatedData(dhdl_sum,g=g[k])) # indices of uncorrelated samples
   N = len(indices) # number of uncorrelated samples
   N_k[k] = N
   # The indices refer to the post-equilibration series; shift them back to raw snapshot indices.
   indices += nequil

   # Per-component / per-state copies (mixing a scalar, a slice and an index array in one
   # subscript would move the indexed axis to the front, so keep the explicit loops).
   for n in range(n_components):
      dhdl[k,n,0:N] = dhdlt[k,n,indices]
   for l in range(K):
      u_kln[k,l,0:N] = u_klt[k,l,indices]
print("Correlation times:")
print(g)
print("")
print("number of uncorrelated samples:")
print(N_k)
print("")

#===================================================================================================
# Preliminaries: Calculate average TI values
#===================================================================================================   

# For every state and every lambda component: the mean of dH/dlambda over the uncorrelated
# samples, and its standard error (population std divided by sqrt(N - 1)).
ave_dhdl = numpy.zeros([K,n_components],float)
std_dhdl = numpy.zeros([K,n_components],float)
for state_index in range(K):
   n_uncorrelated = N_k[state_index]
   samples = dhdl[state_index,:,0:n_uncorrelated] # shape (n_components, n_uncorrelated)
   ave_dhdl[state_index,:] = samples.mean(axis=1)
   std_dhdl[state_index,:] = samples.std(axis=1) / numpy.sqrt(n_uncorrelated - 1)

#===================================================================================================
# Calculate free energies with different methods
#===================================================================================================    

# df_allk[k] / ddf_allk[k] map a method name to the dimensionless free energy difference
# (reduced units) between states k and k+1, and to its uncertainty.
df_allk = list()
ddf_allk = list()

for k in range(K-1):
   df = dict()
   ddf = dict()

   # Reduced work values for both directions of this interval.  They are computed once, up
   # front, because the EXP estimators and BAR all use them.  This keeps BAR correct no matter
   # whether or where 'DEXP' / 'IEXP' appear in methods.
   # Forward direction (in this case, deletion from solvent), sampled from state k:
   w_F = u_kln[k,k+1,0:N_k[k]] - u_kln[k,k,0:N_k[k]]
   # Reverse direction (in this case, insertion into solvent), sampled from state k+1:
   w_R = u_kln[k+1,k,0:N_k[k+1]] - u_kln[k+1,k+1,0:N_k[k+1]]

   for name in methods:
      if name == 'TI':
         # Trapezoid rule: average <dH/dlambda> of the two end states, dotted with the lambda
         # step of each component.  Multiply by beta to get it dimensionless like the other
         # estimates.  Errors of the two states are combined in quadrature.
         dlam = lv[:,k+1]-lv[:,k]
         df['TI'] = 0.5*beta*numpy.dot(dlam,(ave_dhdl[k]+ave_dhdl[k+1]))
         ddf['TI'] = 0.5*beta*numpy.sqrt(numpy.dot(dlam**2,std_dhdl[k]**2+std_dhdl[k+1]**2))
      elif name == 'DEXP':
         # Forward-direction EXP
         (df['DEXP'], ddf['DEXP']) = pymbar.EXP(w_F)
      elif name == 'IEXP':
         # Reverse-direction EXP; negate so the result refers to the k -> k+1 direction.
         (rdf,rddf) = pymbar.EXP(w_R)
         (df['IEXP'], ddf['IEXP']) = (-rdf,rddf)
      elif name == 'BAR':
         (df['BAR'], ddf['BAR']) = pymbar.BAR(w_F, w_R, relative_tolerance=relative_tolerance, verbose = verbose)
      # 'MBAR' is computed afterwards, using all states at once.

   df_allk.append(df)
   ddf_allk.append(ddf)

if 'MBAR' in methods:
   #===================================================================================================
   # Estimate free energy differences with MBAR -- all states at once
   #===================================================================================================   

   # Initialize MBAR (computing free energy estimates, which may take a while).
   # Alternative solver: method = 'Newton-Raphson' (confirm the installed pymbar supports it).
   if (verbose):
      print("Computing free energy differences...")
   mbar = pymbar.MBAR(u_kln, N_k, verbose = verbose, method = 'adaptive', relative_tolerance = relative_tolerance, initialize = 'zeroes')

   # Get matrix of dimensionless free energy differences and uncertainty estimate.
   if (verbose):
      print("Computing covariance matrix...")
   (Deltaf_ij, dDeltaf_ij) = mbar.getFreeEnergyDifferences(uncertainty_method='svd-ew')

   if (verbose):
      # Matrix of free energy differences
      print("Deltaf_ij:")
      print(Deltaf_ij)
      # Matrix of uncertainties in free energy difference (expectations standard deviations of the estimator about the true free energy)
      print("dDeltaf_ij:")
      print(dDeltaf_ij)

   # Store the MBAR result for each adjacent pair of states next to the other methods.
   for k in range(K-1):
      df_allk[k]['MBAR'] = Deltaf_ij[k,k+1]
      ddf_allk[k]['MBAR'] = dDeltaf_ij[k,k+1]

# All done with calculations; now accumulate the totals over the whole lambda path.
# MBAR supplies the end-to-end estimate (first state -> last state) directly.  Every other
# method sums its per-interval estimates and combines the per-interval uncertainties in
# quadrature (i.e. treating the intervals as independent).
dF = {}
ddF = {}
for name in methods:
   if name != 'MBAR':
      total = 0
      variance = 0
      for k in range(K-1):
         total += df_allk[k][name]
         variance += (ddf_allk[k][name])**2
      dF[name] = total
      ddF[name] = numpy.sqrt(variance)
   else:
      dF['MBAR'] = Deltaf_ij[0,K-1]
      ddF['MBAR'] = dDeltaf_ij[0,K-1]

# Print the results table in kJ/mol: a header with one column per method, one row per
# interval between adjacent states k and k+1, a separator, and the totals.  The stored
# values are dimensionless (reduced units), so they are divided by beta, in 1/(kJ/mol),
# for output.  The trailing commas on these Python 2 print statements keep all columns
# of a row on one line; each `print ''` ends the row.
for name in methods:
   print '%10s (kJ/mol)' % name,
print ''
for k in range(K-1):
   print '%5d: ' % k,
   for name in methods:
      print '%8.3f +- %6.3f' % (df_allk[k][name]/beta,ddf_allk[k][name]/beta),
   print ''
for name in methods:
   print '-------------------', 
print ''
print 'TOTAL: ',
for name in methods:
   print '%8.3f +- %6.3f' % (dF[name]/beta,ddF[name]/beta),
print ''  
