#!/usr/bin/python
#
# $Id: gromacs_to_dg.py,v 1.10 2007/08/02 07:10:25 mrshirts Exp $

"""Python script to load Michael's data and try various approaches to calculating the solvation free energy. This is somewhat preliminary. Data is read in as output from GROMACS, in kJ/mol, and converted to kcal/mol for output. 

REVISIONS:
- MRS, 11/2/09, preparation to include with pymbar
- MRS, 12/22/07, initial draft

DEPENDENCIES:
- numpy
- scipy
- pyMBAR
- BAR
- timeseries
- alphaanalysis

TO DO:
- Add comparison to vanilla WHAM as well, to get a sense of how bias may affect it?
"""

#===================================================================================================
#IMPORT STUFF, GET SOME USER INPUT ABOUT WHAT SORT OF ANALYSIS TO DO ON WHAT DATA FILES; SET SOME CONSTANTS
#===================================================================================================

import pdb # Python debugger
import BAR # for miscellaneous tools
import MBAR # maximum-likelihood wham
from numpy import * # Convenient array library required for arguments to AlchemicalWHAM
import timeseries # import the timeseries tools
import subprocess # for running a commandline lookup of the length
import os # For file management
import os.path 
from optparse import OptionParser # For parsing of command line arguments
from alphaanalysis import alphaanalysis
from numpy.random import randint, seed # import random number generator -- be careful importing everything, since "beta" is a function here!

#========================================================================
# CONSTANTS AND DEFAULT PARAMETERS
#========================================================================

temperature = 298.0 # default temperature, in K
convert_atmnm3_to_kJmol = 1.01325e5*(1e-09)**3 * 6.02214 * (1e23) / 1000 # Convert pV from atmospheres*nm^3 into kJ/mol
convert_kJ_to_kcal = 1./4.184 # convert from kJ to kcal
pressure = 1.0 # pressure, in atm
kB = 1.381*6.02214/1000.0  # Boltzmann's constant times Avogadro's number (the gas constant R), in kJ mol^-1 K^-1
# Default inverse temperature; the GROMACS energies it multiplies are in kJ/mol.
beta = 1./(kB*temperature) # inverse temperature, in 1 / (kJ/mol)
# Multiply a reduced (unitless, in units of kT) free energy by this to get kcal/mol:
# f * kT[kJ/mol] * (1/4.184)[kcal/kJ].
converttokcal = convert_kJ_to_kcal/beta

def generate_true_ij(converttounits):
   """Return the reference matrix of pairwise free energy differences.

   Element [i, j] of the returned numpy matrix is f[j] - f[i], scaled by
   converttounits, where f holds reference reduced free energies for 11
   lambda states (relative to state 0).  Per the original note, these
   reference values come from a full MBAR analysis of all 11 states over
   20 ns each.  A 9-state variant of this table previously used only the
   first nine entries.
   """
   reference_f = array([0., 6.75017168, 8.69543746, 11.92013365, 15.25694981,
                        16.42687156, 15.90581512, 14.29461832, 12.00590787,
                        9.11165552, 8.03658622])
   # Broadcast a row against a column: [i, j] = reference_f[j] - reference_f[i].
   pairwise = reference_f[newaxis, :] - reference_f[:, newaxis]
   return matrix(pairwise)*converttounits

#========================================================================
# SUBROUTINES
#========================================================================

def storegrid(Delta,dDelta,Deltai,dDeltai,K,k,name,converttounits):
   """Add the estimate for interval k -> k+1 to cumulative per-state sums.

   Arguments:
      Delta, dDelta - free energy difference for interval k -> k+1 and its
                      uncertainty (a standard deviation), in the units they
                      were estimated in; both are multiplied by
                      converttounits before being stored or printed
      Deltai, dDeltai - length-K arrays updated in place: Deltai[k+1] becomes
                      the running sum of converted free energies, dDeltai[k+1]
                      the running sum of squared converted uncertainties
      K - number of states; k - index of the interval's lower state
      name - label used in the final summary line
      converttounits - multiplicative unit conversion factor

   Returns (Deltaij, dDeltaij, Deltaij_error).  Once the last interval
   (k == K-2) has been added, these are the K x K matrices of pairwise
   differences, their uncertainties, and their deviation from
   generate_true_ij().  For earlier intervals they are K x K zero arrays.
   """
   # Convert this interval's estimate once, so the printed per-interval value
   # is in the same units as the running sum.
   delta_conv = Delta*converttounits
   ddelta_conv = dDelta*converttounits

   Deltai[k+1] = Deltai[k] + delta_conv
   # Accumulate variances; assumes the intervals are independent.
   dDeltai[k+1] = dDeltai[k] + ddelta_conv**2

   # dDelta is already a standard deviation, so it is printed directly (not sqrt'ed);
   # only the accumulated sum is a variance.
   print("Interval %d,  DeltaF=%.4f +/- %.4f, Sum=%.4f +/- %.4f" % (k+1,around(delta_conv,4),around(ddelta_conv,4),around(Deltai[k+1],4),around(sqrt(dDeltai[k+1]),4)))

   Deltaij  = zeros([K,K],float)
   dDeltaij = zeros([K,K],float)
   Deltaij_error = zeros([K,K],float)

   if (k==K-2):
      print("%20s free energy: %.4f +/- %.4f" % (name,around(Deltai[K-1],4), around(sqrt(dDeltai[K-1]),4)))
      print("")

      # Only needed for the final summary, so built here rather than on every interval.
      onevector = matrix(ones(K,float))
      f_true_ij = generate_true_ij(converttounits)

      # [i, j] = Deltai[j] - Deltai[i]; variances of the differences likewise.
      Deltaij = onevector.T*matrix(Deltai) - matrix(Deltai).T*onevector
      dDeltaij = sqrt(abs(onevector.T*matrix(dDeltai)- matrix(dDeltai).T*onevector))
      Deltaij_error = Deltaij - f_true_ij

   return Deltaij,dDeltaij,Deltaij_error

#============================================================================================================================
# Perform free energy analysis on an entire dataset.
#============================================================================================================================

def run_free_analysis(datafile, temperature, rseed, nlam, nfiles, nblock, relative_tolerance = 1.0e-8, verbose=False,recycling=False):

   """
   Run EXP, BAR, pairwise MBAR and all-state MBAR analyses on a set of GROMACS files.

   usage:

   run_free_analysis(datafile, temperature, rseed, nlam, nfiles, nblock, relative_tolerance = 1.0e-8, verbose=False, recycling=False)

   required arguments:
      datafile - prefix of the data files; file n (n = 1..nfiles) is read from
                 '<datafile>.<n>.xvg', or from '<datafile>.0<n>.xvg' if the former does not exist
      temperature - the temperature of the simulations, in K (used to compute beta)
      rseed - random number seed
      nlam - the number of lambda values
      nfiles - the number of files; file n supplies the samples for lambda state n-1,
               so nfiles may not exceed nlam
      nblock - number of consecutive blocks each state's samples are split into;
               each block is analyzed as a separate replicate
   optional arguments:
      relative_tolerance - convergence tolerance passed to BAR and MBAR (default 1.0e-8)
      verbose - if True, will spit out lots of information (default False)
      recycling - use recycling of uncorrelated values, by averaging the offset-subsampled data together. In other words,
      if [1,4,7,10] are independent samples, the overall value and variance estimate will be reported as the averages from
      [1,4,7,10], [2,5,8,11], and [3,6,9,12].  This can reduce the variance, since there is some information contained
      in the partially correlated samples.

   Results are printed, and the per-replicate results for each method are passed
   to alphaanalysis(); nothing is returned.

   Raises:
      ValueError - if nfiles > nlam
      IOError - if a data file cannot be found
   """

   # Compute beta from the temperature argument; the module-level beta is fixed
   # at the 298 K default, so using it would silently ignore --temperature.
   beta = 1./(kB*temperature) # inverse temperature, in (kJ/mol)^-1
   converttokcal = convert_kJ_to_kcal/beta # reduced (unitless) free energy -> kcal/mol

   K = nlam
   NF = nfiles
   if NF > K:
      # File nf is stored as the samples of state nf, so there can be at most K files.
      raise ValueError("nfiles (%d) must not exceed nlam (%d)" % (NF, K))

   #========================================================================
   # READ DATA FILE
   #========================================================================

   seed(rseed) # set random number seed (for reproducible bootstrap results)
   print("Random seed is %d" % rseed)
   print("Number of lambda values is %d" % K)
   print("Temperature is %.1f K" % temperature)
   print("Inverse temperature is %.5f (kJ/mol)^-1" % beta)
   print("1/ Inverse temperature is %.5f (kJ/mol)" % (1.0/ beta))
   if verbose:
      print("Output is verbose")
   if recycling:
      print("Recycling all correlated data")
   print("Number of blocks is %d" % nblock)
   print("")

   # Name the files and count their lines; the counts are only used to size storage.
   datafilearray = [None]*NF
   linecounts = zeros((K),int)
   for n in range(NF):
      datafilearray[n] = datafile + '.' + str(n+1) + '.xvg'
      if not os.path.isfile(datafilearray[n]):
         datafilearray[n] = datafile + '.' + '0' + str(n+1) + '.xvg'
      if not os.path.isfile(datafilearray[n]):
         raise IOError("Cannot find data file %s" % datafilearray[n])
      text = subprocess.Popen([r"wc","-l", datafilearray[n]], stdout=subprocess.PIPE).communicate()[0]
      # wc prints "<count> <filename>", with or without leading padding depending on
      # the platform, so take the first field rather than a fixed-width slice.
      linecounts[n] = int(text.split()[0])

   # wc -l counts newline characters, so a last line without a trailing newline is
   # not counted; leave room for one extra snapshot.
   snapmax = linecounts.max() + 1

   # Allocate storage.
   if verbose: print("Allocating storage for energies...")
   pe = zeros((K),float)

   # Read text files containing data and store temporarily
   print("Reading data from files...")
   AllU_kmt = zeros((K,K,snapmax),float) # AllU_kmt[k,m,t] is the reduced energy for snapshot t from simulation k at potential m
   snapshots = zeros((K),int) # number of data lines read for each state

   for nf in range(NF):
      if verbose: print("Reading datafile %s into memory..." % datafilearray[nf])
      with open(datafilearray[nf],'r') as infile:
         text = infile.readlines()

      # Parse stored datafile contents.
      if verbose: print("Parsing energies from datafile...")
      k = nf # file nf holds the samples of lambda state nf
      for line in text:
         # DATAFILE FORMAT (whitespace separated), as read below. Example line:
         #       0.0000  20 -8.045388905043521e+03 -8.052663173781861e+03 -8.052412537126758e+03
         # column 0: time (not used)
         # column 1: lambda state (not used)  #MRS note -- may not be present in sample files!!!!
         # columns 2 .. K+1: energy at each of the K lambda states
         # column K+2: volume / pV column (see note below)
         # Lines starting with '#' or '@' are xvg header lines; blank lines are skipped too.
         if (not line.strip()) or line[0] == '#' or line[0] == '@':
            continue

         elements = line.split()
         for m in range(K):
            pe[m] = float(elements[2+m])

         # NOTE(review): the original labels this column "box volume (nm^3)" but uses it
         # directly as pV in kJ/mol, with no pressure or unit conversion applied -- confirm
         # which quantity the input files actually contain.
         pV = float(elements[2+K])

         t = snapshots[k]
         snapshots[k] += 1
         AllU_kmt[k,:,t] = beta*(-pe+pV)

   # Split each state's samples into nblock equal, consecutive blocks.
   samples = snapshots // nblock
   sampmax = samples.max()

   gsums = []
   gkeys = []

   for i in range(nblock):

      print("Replica %d / %d" % (i+1,nblock))

      U_kmt = zeros((K,K,sampmax),float32) # U_kmt[k,m,t]: energy for snapshot t of this block from simulation k at potential m
      u_kln = zeros((K,K,sampmax),float32) # u_kln[k,l,n]: reduced potential for uncorrelated snapshot n from state k at state l
      N_k = zeros(K, int32)

      # slice this block's subsample of the data
      for k in range(K):
         U_kmt[k,:,0:samples[k]] = AllU_kmt[k,:,i*samples[k]:(i+1)*samples[k]]

      #====================================================================================
      # Statistical inefficiency
      #===================================================================================
      # Compute statistical inefficiencies of the potential energy for each simulation to verify statistical independence of snapshots.

      if verbose: print("Computing statistical inefficiencies:")
      g = zeros((K),float)  # g[k] is the statistical inefficiency of simulation k
      for k in range(K):
         # Only the first samples[k] entries of U_kmt[k] hold data for this block.
         g[k] = timeseries.statisticalInefficiency(U_kmt[k,k,0:samples[k]], U_kmt[k,k,0:samples[k]])
         if verbose: print("lambda %d: g = %.3f" % (k,around(g[k],3)))
      gmin = g.min()
      increment = g/gmin # g[k]/gmin, the per-state step used to offset recycled uncorrelated sets
      gmin_int = int(floor(gmin))

      #====================================================================================
      # Subsample the data to obtain uncorrelated samples
      #===================================================================================

      print("Subsampling data to produce uncorrelated samples...")

      indices_all = []
      for k in range(K):
         # Uncorrelated indices from the timeseries of state k's snapshots evaluated
         # at a neighboring state l.
         l = k + 1
         if l >= K:
            l = k - 1
         indices_all.append(timeseries.subsampleCorrelatedData(U_kmt[k,l,0:samples[k]]))

      if recycling:
         recycling_limit = gmin_int
      else:
         recycling_limit = 1

      gsum = {} # running sum over recycled sets, one (3,K,K) array per method
      for gk in range(recycling_limit):

         for k in range(K):
            # Shift this state's uncorrelated indices by its own offset to get the gk-th
            # interleaved set; build a new array so indices_all[k] is not modified.
            offset = int(floor(gk*increment[k]))
            indices = array(indices_all[k], int) + offset
            # Drop shifted indices that run past the end of this block's data.
            indices = indices[indices < samples[k]]
            N_k[k] = len(indices)

            for l in range(K):
               u_kln[k,l,0:N_k[k]] = U_kmt[k,l,indices]

         print("number of samples per lambda:")
         print(N_k)
         print("")

         if recycling:
            print("Uncorrelated Set %d / %d of Replica %d / %d" % (gk+1, recycling_limit, i+1,nblock))

         #====================================================================================
         # Scheme 1: do EXP across all states, forwards, reverse, and averaged
         #===================================================================================

         print("Method: EXP")

         # EXP is applied to the subsampled (uncorrelated) data in u_kln, so the
         # uncertainties are computed assuming independent samples.

         # Forward EXP
         print("")
         print("Forward EXP:")

         DeltaFi = zeros(K,float)
         dDeltaFi = zeros(K,float)

         for k in range(K-1):
            # Compute forward reduced work values
            w_F = (u_kln[k,k+1,0:N_k[k]] - u_kln[k,k,0:N_k[k]])
            # Estimate forward free energy differences and uncertainties.
            (DeltaF1, dDeltaF1) = BAR.EXP(w_F, compute_uncertainty = True, independent = True)
            (DeltaFij,dDeltaFij,DeltaFij_error) = storegrid(DeltaF1,dDeltaF1,DeltaFi,dDeltaFi,K,k,"Forward EXP",converttokcal)

         # Reverse EXP
         print("")
         print("Reverse EXP:")

         DeltaRi = zeros(K,float)
         dDeltaRi = zeros(K,float)

         for k in range(K-1):
            # Compute reverse work values
            w_R = (u_kln[k+1,k,0:N_k[k+1]] - u_kln[k+1,k+1,0:N_k[k+1]])
            # Estimate reverse free energy differences and uncertainties.
            (DeltaR1, dDeltaR1) = BAR.EXP(w_R, compute_uncertainty = True, independent = True)
            DeltaR1 *= -1
            (DeltaRij,dDeltaRij,DeltaRij_error) = storegrid(DeltaR1,dDeltaR1,DeltaRi,dDeltaRi,K,k,"Reverse EXP",converttokcal)

         DeltaAi = zeros(K,float)
         dDeltaAi = zeros(K,float)

         # Average of forward and reverse EXP
         print("")
         print("Averge of forward and reverse EXP:")
         for k in range(K-1):
            # DeltaFi/DeltaRi hold cumulative values in kcal/mol and dDeltaFi/dDeltaRi cumulative
            # variances in (kcal/mol)^2; differences recover the per-interval quantities.
            varF = dDeltaFi[k+1] - dDeltaFi[k]
            varR = dDeltaRi[k+1] - dDeltaRi[k]
            DeltaA1 = 0.5*(DeltaRi[k+1] - DeltaRi[k] +  DeltaFi[k+1] - DeltaFi[k])/converttokcal
            # Standard deviation of the mean of two independent estimates, converted back
            # to reduced units because storegrid expects a reduced standard deviation.
            dDeltaA1 = 0.5*sqrt(varF + varR)/converttokcal

            (DeltaAij,dDeltaAij,DeltaAij_error) = storegrid(DeltaA1,dDeltaA1,DeltaAi,dDeltaAi,K,k,"Average EXP",converttokcal)

         DeltaWi = zeros(K,float)
         dDeltaWi = zeros(K,float)

         # Double-wide EXP: alternate forward (odd k) and reverse (even k) interval estimates.
         for k in range(K-1):
            if k%2 != 0:
               DeltaW1 = (DeltaFi[k+1]-DeltaFi[k])/converttokcal
               dDeltaW1 = sqrt(dDeltaFi[k+1]-dDeltaFi[k])/converttokcal
            else:
               DeltaW1 = (DeltaRi[k+1]-DeltaRi[k])/converttokcal
               dDeltaW1 = sqrt(dDeltaRi[k+1]-dDeltaRi[k])/converttokcal

            (DeltaWij,dDeltaWij,DeltaWij_error) = storegrid(DeltaW1,dDeltaW1,DeltaWi,dDeltaWi,K,k,"Double-Wide EXP",converttokcal)

         #====================================================================================
         # Scheme 2: do BAR across all the states, summing individually
         #===================================================================================

         print("Method: BAR")

         DeltaBi   = zeros(K,float)
         dDeltaBi  = zeros(K,float)

         for k in range(K-1):
            # Compute forward and backward work.
            w_F = (u_kln[k, k+1, 0:N_k[k]] - u_kln[k, k, 0:N_k[k]]) # forward work
            w_R = (u_kln[k+1, k, 0:N_k[k+1]] - u_kln[k+1, k+1, 0:N_k[k+1]]) # reverse work

            # Estimate free energy difference with BAR
            (DeltaB1, dDeltaB1) = BAR.BAR(w_F, w_R, tolerance=relative_tolerance, verbose = False, compute_uncertainty = True)

            (DeltaBij,dDeltaBij,DeltaBij_error) = storegrid(DeltaB1,dDeltaB1,DeltaBi,dDeltaBi,K,k,"BAR",converttokcal)

         #====================================================================================
         # Scheme 2a: unoptimized pairwise BAR, assuming initial free energies are zero
         #===================================================================================

         print("Method: Unoptimized BAR -- assume initial free energies are zeroed")

         DeltaUBi   = zeros(K,float)
         dDeltaUBi  = zeros(K,float)

         for k in range(K-1):
            # Compute forward and backward work.
            w_F = (u_kln[k, k+1, 0:N_k[k]] - u_kln[k, k, 0:N_k[k]]) # forward work
            w_R = (u_kln[k+1, k, 0:N_k[k+1]] - u_kln[k+1, k+1, 0:N_k[k+1]]) # reverse work

            # Estimate free energy difference with BAR
            (DeltaUB1, dDeltaUB1) = BAR.BAR(w_F, w_R, DeltaF=0.0, tolerance=relative_tolerance, verbose = False, compute_uncertainty = True, unoptimized= True)

            (DeltaUBij,dDeltaUBij,DeltaUBij_error) = storegrid(DeltaUB1,dDeltaUB1,DeltaUBi,dDeltaUBi,K,k,"Unopt. BAR",converttokcal)

         #====================================================================================
         # Scheme 2b: post-optimized pairwise BAR over a grid of initial guesses
         #===================================================================================

         print("Method: Postoptimized BAR -- compute for a range of initial dF, and choose")
         print("the result where the input dF is closest to the output dF. Actually may")
         print("require more iterations, but does not require self consistent solutions.")

         DeltaPBi   = zeros(K,float)
         dDeltaPBi  = zeros(K,float)

         for k in range(K-1):
            # Compute forward and backward work.
            w_F = (u_kln[k, k+1, 0:N_k[k]] - u_kln[k, k, 0:N_k[k]]) # forward work
            w_R = (u_kln[k+1, k, 0:N_k[k+1]] - u_kln[k+1, k+1, 0:N_k[k+1]]) # reverse work

            # Start from nan so that, if no guess is accepted, the result is visibly
            # missing instead of silently reusing the previous interval's value.
            DeltaPB1 = nan
            dDeltaPB1 = nan
            minD = 1E6
            for DeltaFest in range(-10,10,1):
               # Estimate free energy difference with BAR
               (DeltaEst, dDeltaEst) = BAR.BAR(w_F, w_R, DeltaF=DeltaFest, tolerance=relative_tolerance, verbose = False, compute_uncertainty = True, unoptimized= True)
               if abs(DeltaEst - DeltaFest) < minD:
                  DeltaPB1 = DeltaEst
                  dDeltaPB1 = dDeltaEst
                  minD = abs(DeltaEst - DeltaFest)
            (DeltaPBij,dDeltaPBij,DeltaPBij_error) = storegrid(DeltaPB1,dDeltaPB1,DeltaPBi,dDeltaPBi,K,k,"Postopt. BAR",converttokcal)

         #=====================================================================================================
         # Scheme 3: do two state MBAR across all the states, summing individually - should be identical to BAR
         #=====================================================================================================

         print("Method: Pairwise MBAR.  Should give identical results to BAR.")

         DeltaPi   = zeros(K,float)
         dDeltaPi  = zeros(K,float)

         for k in range(K-1):
            mbar = MBAR.MBAR(u_kln[k:(k+2),k:(k+2),:], N_k[k:(k+2)], relative_tolerance=relative_tolerance, verbose = True, method='Newton-Raphson')
            (Deltaf_ij_estimated, dDeltaf_ij_estimated) = mbar.getFreeEnergyDifferences()
            DeltaP1 = float(Deltaf_ij_estimated[0,1])
            dDeltaP1 = float(dDeltaf_ij_estimated[0,1])

            (DeltaPij,dDeltaPij,DeltaPij_error) = storegrid(DeltaP1,dDeltaP1,DeltaPi,dDeltaPi,K,k,"Pairwise MBAR",converttokcal)

         #========================================================================================
         # SCHEME 4:  Do all-state MBAR, PRINT RESULTS
         #========================================================================================

         print("Method: Full-state MBAR")

         mbar = MBAR.MBAR(u_kln, N_k, relative_tolerance=relative_tolerance, verbose = True, method='Newton-Raphson', BARinitialize = True)
         (DeltaMij, dDeltaMij) = mbar.getFreeEnergyDifferences()
         DeltaMij *= converttokcal
         dDeltaMij *= converttokcal
         DeltaMij_error = DeltaMij - generate_true_ij(converttokcal)

         print("DeltaMij:")
         print(DeltaMij)
         print("dDeltaMij:")
         print(dDeltaMij)

         gcopy = {}
         gcopy['Forward EXP']  = (DeltaFij,dDeltaFij,DeltaFij_error)
         gcopy['Reverse EXP']  = (DeltaRij,dDeltaRij,DeltaRij_error)
         gcopy['Average EXP']  = (DeltaAij,dDeltaAij,DeltaAij_error)
         gcopy['Double-Wide']  = (DeltaWij,dDeltaWij,DeltaWij_error)
         gcopy['BAR']          = (DeltaBij,dDeltaBij,DeltaBij_error)
         gcopy['Unopt. BAR']   = (DeltaUBij,dDeltaUBij,DeltaUBij_error)
         gcopy['Postopt. BAR'] = (DeltaPBij,dDeltaPBij,DeltaPBij_error)
         gcopy['Pairwise BAR'] = (DeltaPij,dDeltaPij,DeltaPij_error)
         gcopy['MBAR']         = (DeltaMij,dDeltaMij,DeltaMij_error)

         gkeys = list(gcopy.keys())

         # gsum lives outside the gk loop so that every recycled set adds into the same sums.
         for key in gkeys:
            if key not in gsum:
               gsum[key] = zeros([3,K,K],float)
            gsum[key] += gcopy[key]

      # Average over the recycled sets.
      for key in gkeys:
         gsum[key] /= recycling_limit

      gsums.append(gsum)

   for key in gkeys:
      replicates = []
      for gs in gsums:
         replicate = dict()
         replicate['Deltaf_ij'] = gs[key][0]
         replicate['dDeltaf_ij_estimated'] = gs[key][1]
         replicate['Deltaf_ij_error'] = gs[key][2]
         replicates.append(replicate)

      print("========== %20s ==========" % (str(key)))
      alphaanalysis(replicates,K)

#========================================================================
# COMMAND-LINE DRIVER
#
# This command-line driver runs the analysis routine on a given datafile.
#========================================================================

# Create command-line argument options.
usage_string = "usage: %prog [--verbose] [--recycling] [--temperature TEMPERATURE] [--seed RANDOM_SEED] [--nblock NUMBER_OF_BLOCKS] --data DATAFILE --nlam NUMBER_OF_LAMBDAS --nfiles NUMBER_OF_FILES\n\nexample: ./gromacs_to_dg.py --verbose --recycling --temperature 298 --seed 1 --nblock 1 --nlam 9 --nfiles 9 --data vdw"
version_string = "%prog $Version: $"

parser = OptionParser(usage=usage_string, version=version_string)

parser.add_option("-d", "--data", metavar='DATAFILE',
                  action="store", type="string", dest='datafile', default='',
                  help="prefix of column-formatted datafile containing potential energies")
parser.add_option('-v', "--verbose",
                  action="store_true", dest="verbose",
                  help="Displays debug information")
parser.add_option('-T', "--temperature", metavar='TEMPERATURE',
                  action="store", type="float", dest='temperature', default=298.0,
                  help='temperature simulations are run at - for now, assumes only one')
parser.add_option('-S', "--seed", metavar='SEED',
                  action="store", type="int", dest='rseed', default=0,
                  help='random number for seeding')
parser.add_option('-l', "--nlam", metavar='NUMBER_OF_LAMBDAS',
                  action="store", type="int", dest='nlam', default=0,
                  help='number of lambdas')
parser.add_option('-f', "--nfiles", metavar='NUMBER_OF_FILES',
                  action="store", type="int", dest='nfiles', default=0,
                  help='number of data files')
parser.add_option('-b', "--nblock", metavar='NUMBER_OF_BLOCKS',
                  action="store", type="int", dest='nblock', default=1,
                  help='number blocks for sampling')
parser.add_option('-r', "--recycling", metavar='RECYCLING OF CORRELATED DATA',
                  action="store_true", dest='recycling',
                  help='Use recycling of correlated data to determine better improved averages')

# Only parse arguments and run when executed as a script, so the module can be
# imported (e.g. to reuse run_free_analysis) without consuming sys.argv.
if __name__ == "__main__":
   (options,args)=parser.parse_args()

   # Perform minimal error checking; parser.error() prints the message and exits.
   if not options.datafile:
      parser.print_help()
      parser.error("Please enter the prefix of the data files to analyze.")
   if not options.nlam:
      parser.print_help()
      parser.error("Please enter the number of lambda values.")
   if not options.nfiles:
      parser.print_help()
      parser.error("Please enter the number of data files.")

   # Run analysis
   results = run_free_analysis(options.datafile, options.temperature,options.rseed,options.nlam,options.nfiles,options.nblock,verbose=options.verbose, recycling=options.recycling)

