#!/usr/bin/python

# Example illustrating the use of MBAR for computing the hydration free energy of OPLS 3-methylindole
# in TIP3P water through alchemical free energy simulations.

#===================================================================================================
# IMPORTS
#===================================================================================================

import numpy
import pymbar # multistate Bennett acceptance ratio estimator.
import timeseries # for timeseries analysis 
import commands
import re
import pdb
import scipy
from scipy import interpolate

#===================================================================================================
# CONSTANTS
#===================================================================================================

convert_atmnm3_to_kJmol = 1.01325e5*(1e-09)**3 * 6.02214 * (1e23) / 1000 # Convert pV from atmospheres*nm^3 into kJ/mol
kB = 1.381*6.02214/1000.0  # Boltzmann's constant (kJ/mol/K)
relative_tolerance = 1e-10 # convergence tolerance passed to the MBAR solver
verbose = False # if True, print extra diagnostics (e.g. full Deltaf_ij matrices)
optional = False  # set this to false if we only want the results for the paper
NR = 1 # number of replicates the uncorrelated samples are split into
Nboot = 40 # number of bootstrap resamplings used for uncertainty estimates

# Estimators to run in the main loop below.
methods = ['MBAR','ENDPOINT'];

############vvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvv#######################
#####THIS SECTION OF METHODS CAN BE TAKEN OUT OF THE DISTRIBUTION########

# expanded list of methods
#methods = ['TI','FEXP','REXP','UBAR','RBAR','BAR','PMBAR','MBAR']; 
#methods = ['TI','FEXP','REXP','UBAR','RBAR','BAR','PMBAR','MBAR','UMBAR']; 

#####THIS SECTION OF METHODS CAN BE TAKEN OUT OF THE DISTRIBUTION########
############^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^#######################


#===================================================================================================
# PARAMETERS
#===================================================================================================

temperature = 298.0 # temperature (K)
pressure = 1.0 # pressure (atm) -- NOTE(review): the PV term is currently disabled below (bPV is never set), so this appears unused
mpow = 3 # maximum TI fit parameter -- NOTE(review): appears unused by the MBAR/ENDPOINT paths in this script
nequil = 50; # number of samples assumed to be equilibration
# Previously used dataset locations, kept for reference:
#datafile_directory = '../../large-datasets/trp50ns/' # directory in which datafiles are store
#datafile_directory = '../../large-datasets/trp38/' # directory in which datafiles are stored
#datafile_directory = '../../trunk/examples/alchemical-free-energy/data/3-methylindole-38steps/' # directory in which datafiles are stored
#datafile_directory = '../../trunk/examples/alchemical-free-energy/data/3-methylindole-11steps/' # directory in which datafiles are stored
datafile_directory = './' # directory in which datafiles are stored
datafile_prefix  = 'BUAM.dgdl' # prefixes for datafile sets 

#===================================================================================================
# HELPER FUNCTIONS
#===================================================================================================
def sortbynum(item):
   """Return the last all-digit dot-separated field of *item* as an int.

   Used as the key for sorting data filenames numerically, so that e.g.
   'prefix.dgdl.10.xvg' sorts after 'prefix.dgdl.2.xvg' instead of
   lexicographically.

   Raises ValueError if no all-digit field is present.  (The original code
   printed an error and fell through to return None, which then made
   list.sort(key=sortbynum) compare ints against None.)
   """
   vals = item.split('.')
   # scan from the right so a trailing numeric field (before any extension
   # like '.xvg') wins over numbers appearing earlier in the name
   for v in reversed(vals):
      if v.isdigit():
         return int(v)
   raise ValueError("No digits found in filename, can't sort %r" % item)

#===================================================================================================
# MAIN
#===================================================================================================

# Compute inverse temperature in 1/(kJ/mol)
beta = 1./(kB*temperature)

#===================================================================================================
# Read all snapshot data
#===================================================================================================

# Generate a list of all datafiles with this prefix.
# NOTE: This uses the file ordering provided by 'ls' to determine which order the states are in.
filenames = commands.getoutput('ls %(datafile_directory)s/%(datafile_prefix)s.*' % vars()).split()

# Determine number of alchemical intermediates.
K = len(filenames)
# Determine length of files.
# NOTE(review): this second listing uses the glob '<prefix>*' (no dot), which can match a
# different set of files than the '<prefix>.*' glob above -- confirm both patterns are intended.
filenames = commands.getoutput('ls %(datafile_directory)s/%(datafile_prefix)s*' % vars()).split()   

# sort the files numerically, if possible.
filenames.sort(key=sortbynum)

# NOTE(review): K is overwritten with a hard-coded 18 here, discarding len(filenames) computed
# above -- looks like a dataset-specific override left in; verify it matches the actual file count.
K = 18
nsnapshots = numpy.zeros(K, int) # nsnapshots[k] is the number of snapshots for state k; 
# no larger than number of lines starting out.
maxn = 0 # maximum number of data (non-header) lines over ALL files
for k in range(len(filenames)):
   # Temporarily read the file into memory.
   infile = open(filenames[k], 'r')
   lines = infile.readlines()
   infile.close()
   
   # Count snapshots by quickly parsing the file, ignoring xvg header ('#') and legend ('@') lines.
   nlines = 0
   for line in lines:
      if ((line[0] == '#') or (line[0] == '@')):
         continue
      nlines += 1
   # BUGFIX: 'maxn' was previously reset to 0 inside the loop for every file, so after the
   # loop it held the line count of the LAST file only; if an earlier file were longer,
   # u_klt (and dhdlt) would be allocated too small and indexing would fail.  Keep the
   # maximum over all files instead.
   maxn = max(maxn, nlines)

# Load all of the data
u_klt = numpy.zeros([K,K,maxn], numpy.float64) # u_klt[k,m,t] is the reduced potential energy of snapshot t of state k evaluated at state m
# Parse every data file: read the xvg legend to learn the lambda components/states,
# then read each data row into dhdlt (dH/dlambda) and u_klt (reduced potentials).
for k in range(len(filenames)):
   # File to be read
   filename = filenames[k]   
   
   # Read contents of file into memory.
   print "Reading %s..." % filename
   infile = open(filename, 'r')
   lines = infile.readlines()
   infile.close()

   t = 0              # snapshot (row) index within this file
   n_components = 0   # number of lambda components found in the xvg legend
   n_states = 0       # number of alchemical states found in the xvg legend
   bPV = False        # whether a pressure-volume column is present (parsing currently disabled below)

   # Parse the file.
   pe = numpy.zeros(K,numpy.float64); # temporary storage for energies
   for line in lines:

      # split into elements
      elements = line.split()
      if ((line[0] == '#') or (line[0] == '@')):
         if (line[0] == '@'):
            # it's an xvg legend entry -- load in the information
            if (line[2] == 's'):  
               # it's a xvg entry, and a lambda component, so note it down
               if (re.search('-lambda',line)):     
                  # it's a listing of the lambdas; grow lv by one component.
                  # NOTE(review): lv is reallocated (and zeroed) on every '-lambda' legend
                  # line, so this assumes all component legends precede the state legends.
                  n_components +=1
                  lv = numpy.zeros([K,n_components],float)
               elif (re.search("\\\\8D\\\\4H \\\\8\\l\\\\4",line)): 
                  # legend line carrying the lambda values for one state (xvg-escaped
                  # Greek 'DH/Dlambda' label); parse the "(l1,l2,...)" list into lv.
                  lambda_string = elements[5]
                  lambda_list = re.sub('[()\"]','',lambda_string)
                  lambdas = lambda_list.split(',');
                  for i in range(n_components):
                     lv[n_states,i] = lambdas[i]
                  n_states+=1;   
               #elif (re.search("pv",line)):     
               #    bPV = 1;   # for testing now, eliminated PV term
      else:                           
         if (t==0) and (k==0):     # we don't know the number of components until here.
            # NOTE(review): n_components is forced to 1 here, discarding any count parsed
            # from the legend above -- confirm this is intended for this dataset.
            n_components = 1
            dhdlt = numpy.zeros([K,n_components,maxn],float) 

         # first column is the time; read (and discard) it to advance the element list
         time = float(elements.pop(0))
         #
         # If the total energy is also printed in the dhdl file, uncomment the next line;
         # otherwise leave it commented out.
         #
         #energy = float(elements.pop(0))            
         
         # second column is the (1-based) state index this snapshot was sampled from
         state = int(elements.pop(0))-1            
         # now record the derivative with respect to lambda
      
         for nl in range(n_components):
            dhdlt[state,nl,t] = float(elements.pop(0))
         # now record the potential energy differences.   
         for l in range(K):
            #pe[l] = float(elements.pop(0)) + energy
            pe[l] = float(elements.pop(0))
                  
         # pressure-volume contribution
         if (bPV):
            pv = float(elements.pop(0))
         else:
            pv = 0
         # compute and store reduced potential energy at each state
         for l in range(K):
            u_klt[state,l,nsnapshots[state]] = beta * (pe[l] + pv)
               
         t += 1   
         nsnapshots[state] += 1   

#===================================================================================================
# Preliminaries: Subsample data to obtain uncorrelated samples
#===================================================================================================   
Nequilibrated = max(nsnapshots) - nequil
ur_kln = numpy.zeros([K,K,Nequilibrated], numpy.float64) # ur_kln[k,m,n] is the reduced potential energy of uncorrelated sample index n from state k evaluated at state m
dhdlr = numpy.zeros([K,n_components,Nequilibrated], float) #dhdl is value for dhdl for each component in the file at each time.
NR_k = numpy.zeros(K, int) # NR_k[k] is the number of uncorrelated samples from state k over all replicates
g = numpy.zeros(K,float) # autocorrelation times for the data
for k in range(K):
   # (BUGFIX: removed a leftover 'pdb.set_trace()' debugging breakpoint here -- it
   # dropped into the interactive debugger at every state and halted unattended runs.)
   # Determine indices of uncorrelated samples from potential autocorrelation analysis at state k.
   # alternatively, could use the energy differences -- here, we will use total dhdl
   g[k] = timeseries.statisticalInefficiency(u_klt[k,k,nequil:nsnapshots[k]])
   indices = numpy.array(timeseries.subsampleCorrelatedData(u_klt[k,k,nequil:nsnapshots[k]],g=g[k])) # indices of uncorrelated samples
   N = len(indices) # number of uncorrelated samples
   NR_k[k] = N      
   # shift indices so they address the full (pre-equilibration-trimmed) trajectory
   indices += nequil
   #for n in range(n_components):
   #   dhdlr[k,n,0:N] = dhdlt[k,n,indices]
   for l in range(K):         
      ur_kln[k,l,0:N] = u_klt[k,l,indices]

print("Correlation times:")
print(g)
print("")
print("number of uncorrelated samples:")
print(NR_k)
print("")

#Now, figure out how many replicates we have:
# SN_k[r,k]:SN_k[r+1,k] are the uncorrelated sample indices assigned to replicate r at state k.
SN_k = numpy.zeros([NR+1,K],int)
for k in range(K):
   for r in range(NR):
      SN_k[r+1,k] += SN_k[r,k] + NR_k[k]/NR
   # last one might have a few more based on divisibilities
   SN_k[NR,k] = NR_k[k]
 
#===================================================================================================
# Preliminaries: Calculate average TI values
#===================================================================================================   


# Per-replicate accumulators: each entry will be a {method-name: value} mapping for one
# replicate, covering free energy (F), enthalpy (U) and entropy (S) point estimates,
# their predicted uncertainties (dd*), and their bootstrap uncertainties (dd*boot).
replicate_dF = []
replicate_ddF = []
replicate_ddFboot = []
replicate_dU = []
replicate_ddU = []
replicate_ddUboot = []
replicate_dS = []
replicate_ddS = []
replicate_ddSboot = []

for r in range(NR):
   print "Replicate: ",
   print r+1

   # for bootstrap: first is the normal arrangement, then the rest are random indices.

   N_k = SN_k[r+1,:] - SN_k[r,:]

   dhdl_original = numpy.zeros([K,n_components,maxn], float)
   u_kln_original = numpy.zeros([K,K,maxn],numpy.float64)

   for k in range(K):
      u_kln_original[k,:,0:N_k[k]] = ur_kln[k,:,SN_k[r,k]:SN_k[r+1,k]] 
      dhdl_original[k,:,0:N_k[k]] = dhdlr[k,:,SN_k[r,k]:SN_k[r+1,k]]        

   u_kln = numpy.zeros([K,K,max(N_k)],numpy.float64)
   dhdl = numpy.zeros([K,n_components,max(N_k)],float)
   random_indices = numpy.zeros([K,Nboot+1,maxn],int);
   for k in range(K):
      random_indices[k,0,0:N_k[k]] = range(N_k[k]);
      
   for b in range(1,Nboot+1):
      numpy.random.seed()
      for k in range(K):
         random_indices[k,b,0:N_k[k]]=numpy.random.random_integers(0,high=N_k[k]-1,size=N_k[k])

   bootf_list = list()
   bootu_list = list()
   boots_list = list()

   for b in range(Nboot+1):
      MBAR = None
      print "Bootstrap: ",
      print b
      for k in range(K):
         for l in range(K):
            u_kln[k,l,0:N_k[k]] = u_kln_original[k,l,random_indices[k,b,0:N_k[k]]]
         for n in range(n_components):
            dhdl[k,n,0:N_k[k]] = dhdl_original[k,n,random_indices[k,b,0:N_k[k]]]

   #===================================================================================================
   # Calculate free energies, entropies, and enthalpies with different methods
   #===================================================================================================    
            
      df_allk = list()
      ddf_allk = list()
      du_allk = list()
      ddu_allk = list()
      ds_allk = list()
      dds_allk = list()

      for k in range(K-1):
         df_allk = numpy.append(df_allk,dict())
         ddf_allk = numpy.append(ddf_allk,dict())
         du_allk = numpy.append(du_allk,dict())
         ddu_allk = numpy.append(ddu_allk,dict())
         ds_allk = numpy.append(ds_allk,dict())
         dds_allk = numpy.append(dds_allk,dict())

      ave_dhdl = numpy.zeros([K,n_components],float)
      std_dhdl = numpy.zeros([K,n_components],float)

      for name in methods:
         # Initialize MBAR (computing free energy estimates, which may take a while)
         #print "Computing free energy differences..."
         if name == 'MBAR' or name == 'ENDPOINT':
            if (MBAR == None):
               MBAR = pymbar.MBAR(u_kln, N_k, verbose = verbose, method = 'Newton-Raphson', relative_tolerance = relative_tolerance) # use faster Newton-Raphson solver

         if name == 'MBAR':

            #===================================================================================================
            # Estimate free energy difference with MBAR -- all states at once
            #===================================================================================================   

            (Deltaf_ij, dDeltaf_ij, Deltau_ij, dDeltau_ij, Deltas_ij, dDeltas_ij) = MBAR.computeEntropyAndEnthalpy(uncertainty_method='svd-ew')

            if (verbose):
               # Matrix of free energy differences
               print "Deltaf_ij:"		
               print Deltaf_ij
               # Matrix of uncertainties in free energy difference (expectations standard deviations of the estimator about the true free energy)
               print "dDeltaf_ij:"
               print dDeltaf_ij
     	
            for k in range(K-1):
               df_allk[k]['MBAR'] = Deltaf_ij[k,k+1]                      
               ddf_allk[k]['MBAR'] = dDeltaf_ij[k,k+1]
               du_allk[k]['MBAR'] = Deltau_ij[k,k+1]                      
               ddu_allk[k]['MBAR'] = dDeltau_ij[k,k+1]
               ds_allk[k]['MBAR'] = Deltas_ij[k,k+1]                      
               dds_allk[k]['MBAR'] = dDeltas_ij[k,k+1]
               

         if name == 'ENDPOINT':
            #===================================================================================================
            # Estimate free energy difference with MBAR -- one step, matrix inversion method
            #===================================================================================================   
            # Initialize MBAR (computing free energy estimates, which may take a while)
            print "Computing free energy differences..."

            (Delta_f_ij, dDelta_f_ij) = MBAR.getFreeEnergyDifferences(uncertainty_method='svd-ew')
            # estimate enthalpy by endpoint method: <U>= \sum_{i=1}^N U / N

            u_k = numpy.zeros([K],float)
            du_k = numpy.zeros([K],float)
            for k in range(K):
               
               u_k[k] = numpy.average(u_kln[k,k,0:N_k[k]])
               du_k[k] = numpy.sqrt(numpy.var(u_kln[k,k,0:N_k[k]])/(N_k[k]))
               print "ENDPOINT: u_k[%d] %10.5f +/- %10.5f" % (k,u_k[k],du_k[k])

            for k in range(K-1):
               
               df_allk[k]['ENDPOINT'] = Deltaf_ij[k,k+1]                      
               ddf_allk[k]['ENDPOINT'] = dDeltaf_ij[k,k+1]
               du_allk[k]['ENDPOINT'] = u_k[k+1]-u_k[k]                      
               ddu_allk[k]['ENDPOINT'] = numpy.sqrt(du_k[k+1]**2 + du_k[k]**2)
               ds_allk[k]['ENDPOINT'] = (u_k[k+1]-u_k[k]) - Deltaf_ij[k,k+1] 
               dds_allk[k]['ENDPOINT'] = numpy.sqrt(dDeltaf_ij[k,k+1]**2 + du_k[k+1]**2 + du_k[k]**2)

   #All done with calculations, now summarize and print stats

      if (b==0):           
         dF = dict();
         ddF = dict();
         ddFboot = dict();

         dU = dict();
         ddU = dict();
         ddUboot = dict();

         dS = dict();
         ddS = dict();
         ddSboot = dict();

         for name in methods:
            if name == 'MBAR': 
               dF['MBAR'] = Deltaf_ij[0,K-1]                            
               ddF['MBAR'] = dDeltaf_ij[0,K-1]
               dU['MBAR'] = Deltau_ij[0,K-1]                            
               ddU['MBAR'] = dDeltau_ij[0,K-1]
               dS['MBAR'] = Deltas_ij[0,K-1]                            
               ddS['MBAR'] = dDeltas_ij[0,K-1]
               
            elif name == 'ENDPOINT':
               dF['ENDPOINT'] = Deltaf_ij[0,K-1]
               ddF['ENDPOINT'] = dDeltaf_ij[0,K-1]
               dU['ENDPOINT'] = u_k[K-1]-u_k[0]
               ddU['ENDPOINT'] = numpy.sqrt(du_k[0]**2 + du_k[K-1]**2)
               dS['ENDPOINT'] = (u_k[K-1]-u_k[0]) - Deltaf_ij[0,K-1] 
               ddS['ENDPOINT'] = numpy.sqrt(dDeltaf_ij[0,K-1]**2 + du_k[0]**2 + du_k[K-1]**2)
            else:
               #holding
               dF[name] = 0
               ddF[name] = 0
               dU[name] = 0
               ddU[k][name] = 0 
               dS[name] = 0
               ddS[name] = 0
      else:             
         bootf = dict();
         bootu = dict();
         boots = dict();
         for name in methods:
            if name == 'MBAR': 
               bootf['MBAR'] = Deltaf_ij[0,K-1]                            
               bootu['MBAR'] = Deltau_ij[0,K-1]                            
               boots['MBAR'] = Deltas_ij[0,K-1]                            
            elif name == 'ENDPOINT':
               bootf['ENDPOINT'] = Deltaf_ij[0,K-1]
               bootu['ENDPOINT'] = u_k[K-1]-u_k[0]
               boots['ENDPOINT'] = (u_k[K-1]-u_k[0]) - Deltaf_ij[0,K-1]
            else:
               bootf[name] = 0
               bootu[name] = 0
               boots[name] = 0

         bootf_list.append(bootf)
         bootu_list.append(bootu)
         boots_list.append(boots)
     
      if (b==Nboot):
         for name in methods:
            fvals = numpy.zeros([Nboot],float)
            uvals = numpy.zeros([Nboot],float)
            svals = numpy.zeros([Nboot],float)
            for i in range(Nboot):
               fvals[i] = bootf_list[i][name]
               uvals[i] = bootu_list[i][name]
               svals[i] = boots_list[i][name]
            ddFboot[name] = numpy.std(fvals)
            ddUboot[name] = numpy.std(uvals)
            ddSboot[name] = numpy.std(svals)

   replicate_dF = numpy.append(replicate_dF,dF)
   replicate_ddF = numpy.append(replicate_ddF,ddF)
   replicate_ddFboot = numpy.append(replicate_ddFboot,ddFboot)
   replicate_dU = numpy.append(replicate_dU,dU)
   replicate_ddU = numpy.append(replicate_ddU,ddU)
   replicate_ddUboot = numpy.append(replicate_ddUboot,ddUboot)
   replicate_dS = numpy.append(replicate_dS,dS)
   replicate_ddS = numpy.append(replicate_ddS,ddS)
   replicate_ddSboot = numpy.append(replicate_ddSboot,ddSboot)

# Summarize over replicates.  All quantities are converted from reduced (dimensionless)
# units back to kJ/mol by dividing by beta.  For each method we report, across
# replicates, the mean +/- sample standard deviation of the estimate itself, of its
# predicted (analytic) uncertainty, and of its bootstrap uncertainty.
# NOTE(review): numpy.std(..., ddof=1) over an array of length NR is NaN when NR == 1
# (the default above) -- the '+/-' spreads are only meaningful for NR > 1; confirm.
for name in methods:
   print '%10s (kJ/mol):' % name,

   # per-replicate values for this method (free energy F, enthalpy U, entropy S)
   rep_dF = numpy.zeros(NR,float)   
   rep_ddF = numpy.zeros(NR,float);
   rep_ddFboot = numpy.zeros(NR,float);

   rep_dU = numpy.zeros(NR,float)   
   rep_ddU = numpy.zeros(NR,float);
   rep_ddUboot = numpy.zeros(NR,float);

   rep_dS = numpy.zeros(NR,float)   
   rep_ddS = numpy.zeros(NR,float);
   rep_ddSboot = numpy.zeros(NR,float);

   for r in range(NR):
      # divide by beta to convert reduced free energies back to kJ/mol
      rep_dF[r] = replicate_dF[r][name]/beta
      rep_ddF[r] = replicate_ddF[r][name]/beta
      rep_ddFboot[r] = replicate_ddFboot[r][name]/beta

      rep_dU[r] = replicate_dU[r][name]/beta
      rep_ddU[r] = replicate_ddU[r][name]/beta
      rep_ddUboot[r] = replicate_ddUboot[r][name]/beta

      rep_dS[r] = replicate_dS[r][name]/beta
      rep_ddS[r] = replicate_ddS[r][name]/beta
      rep_ddSboot[r] = replicate_ddSboot[r][name]/beta

   # replicate averages and spreads of the estimates and their predicted uncertainties
   ave_dF = numpy.average(rep_dF)
   std_dF = numpy.std(rep_dF,ddof=1)
   ave_ddF = numpy.average(rep_ddF)
   std_ddF = numpy.std(rep_ddF,ddof=1)

   ave_dU = numpy.average(rep_dU)
   std_dU = numpy.std(rep_dU,ddof=1)
   ave_ddU = numpy.average(rep_ddU)
   std_ddU = numpy.std(rep_ddU,ddof=1)

   ave_dS = numpy.average(rep_dS)
   std_dS = numpy.std(rep_dS,ddof=1)
   ave_ddS = numpy.average(rep_ddS)
   std_ddS = numpy.std(rep_ddS,ddof=1)

   # replicate averages and spreads of the bootstrap uncertainty estimates
   boot_dF = numpy.average(rep_ddFboot)
   std_boot_dF = numpy.std(rep_ddFboot,ddof=1)

   boot_dU = numpy.average(rep_ddUboot)
   std_boot_dU = numpy.std(rep_ddUboot,ddof=1)

   boot_dS = numpy.average(rep_ddSboot)
   std_boot_dS = numpy.std(rep_ddSboot,ddof=1)

   # trailing commas keep the F, U and S summaries for a method on a single output line
   print '%8.3f +/- %6.3f (predicted %6.3f +/- %6.3f, bootstrap %6.3f +/- %6.3f)' % (ave_dF,std_dF,ave_ddF,std_ddF,boot_dF,std_boot_dF),
   print '%8.3f +/- %6.3f (predicted %6.3f +/- %6.3f, bootstrap %6.3f +/- %6.3f)' % (ave_dU,std_dU,ave_ddU,std_ddU,boot_dU,std_boot_dU),
   print '%8.3f +/- %6.3f (predicted %6.3f +/- %6.3f, bootstrap %6.3f +/- %6.3f)' % (ave_dS,std_dS,ave_ddS,std_ddS,boot_dS,std_boot_dS)

   # raw per-replicate values for inspection
   print 'dF'
   print rep_dF,
   print rep_ddF,
   print rep_ddFboot

   print 'dU'
   print rep_dU,
   print rep_ddU,
   print rep_ddUboot

   print 'dS'
   print rep_dS,
   print rep_ddS,
   print rep_ddSboot
