#!/usr/bin/env python
# Compares MBAR with WHAM for an umbrella sampling simulation of the chi torsion of a valine sidechain in lysozyme L99A with benzene bound in the cavity.
# 

import pdb
import numpy # numerical array library
#from math import *
import pymbar # multistate Bennett acceptance ratio
import timeseries # timeseries analysis
import subprocess, sys

# Common settings shared by the WHAM and MBAR analyses.
# The number of histogram bins may be supplied as the first command-line
# argument; when it is absent, 144 bins are used.
nbins = int(sys.argv[1]) if len(sys.argv) >= 2 else 144

chi_min = -180     # lower edge of the torsion range
chi_max = 180      # upper edge of the torsion range
temperature = 300  # used below together with kB to form kT
K = 26 # number of umbrellas

# Width of a single bin, and the center of each of the nbins bins that
# tile the interval from chi_min to chi_max.
delta = (chi_max - chi_min) / float(nbins)
bin_centers = numpy.zeros([nbins], numpy.float64)
bin_centers[:] = chi_min + delta/2 + delta * numpy.arange(nbins)

# first do WHAM
# Storage for the WHAM results: per-bin PMF values and errors, and
# per-umbrella (window) free energies and errors.
wham_binv = numpy.zeros([nbins],float)
wham_bine = numpy.zeros([nbins],float)
wham_winv = numpy.zeros([K],float)
wham_wine = numpy.zeros([K],float)

# Conversion factor used when parsing the WHAM output: values read from
# the file are divided by kcaltokT = kT/kcal.
kB = 0.008314472             # presumably kJ/mol/K, matching the MBAR section below
kT = temperature*kB
kcal = 4.1868                # WHAM has incorrect definition of kcal/kJ conversion
kcaltokT = kT/kcal
name = 'PMF'
wham_outfile =  'WHAM_PMF_' + str(nbins)
randn = 12323                # random seed passed to WHAM

# Command line: wham [P|Ppi|Pval] hist_min hist_max num_bins tol temperature numpad metadatafile freefile [num_MC_trials randSeed]
wham_bin = '../../../misc/wham-grossfield/wham/wham'
args = [wham_bin,'P360.0',str(chi_min),str(chi_max),str(nbins),'0.00000000001',str(temperature),'0','metadata.dat', wham_outfile,'10',str(randn)]
wham_stdout = 'WHAM_PMF_' + str(nbins) + '.stdout'

# Run WHAM with its standard output captured in wham_stdout.  The
# with-block guarantees the capture file is closed even if the program
# cannot be launched.
with open(wham_stdout,'w') as stdoutfile:
    runwham = subprocess.Popen(args,stdout = stdoutfile)
    runwham.wait()

# A failed run would otherwise go unnoticed until the (missing or stale)
# output file is parsed, so report a non-zero exit status.
if runwham.returncode != 0:
    sys.stderr.write("WARNING: %s exited with status %d; %s may be missing or stale\n"
                     % (wham_bin, runwham.returncode, wham_outfile))

# Parse the WHAM output file.
# A header line beginning with '#Coor' starts the PMF section (five
# whitespace-separated columns per data line); a header line beginning
# with '#Wind' starts the per-window free energy section.  Any other
# line beginning with '#' is ignored.  Every value stored is divided by
# kcaltokT.
with open(wham_outfile,'r') as infile:
    lines = infile.readlines()

data = None   # section currently being read: None, 'PMF' or 'DF'
i = 0         # index of the next entry to fill within the current section
for line in lines:
    if not line.strip():
        # Blank lines carry no data; skip them instead of failing to parse.
        continue

    if line[0] == '#':
        if line[0:5] == '#Coor':
            data = 'PMF'
            i = 0
        elif line[0:5] == '#Wind':
            data = 'DF'
            i = 0

    elif data == 'PMF':
        (theta,pmf,err,prob,perr) = line.split()
        wham_binv[i] = float(pmf)/kcaltokT
        wham_bine[i] = float(err)/kcaltokT
        i += 1

    elif data == 'DF':
        vals = line.split()
        wham_winv[i] = float(vals[1])/kcaltokT
        wham_wine[i] = float(vals[2])/kcaltokT
        i += 1
    # Data lines that appear before any recognized header are ignored.

# now do MBAR

# Constants.
# NOTE(review): this rebinds kB (about 0.0083164) to a value slightly
# different from the 0.008314472 used for the WHAM unit conversion.
kB = 1.381e-23 * 6.022e23 / 1000.0 # Boltzmann constant in kJ/mol/K

# Parameters
N_max = 501 # maximum number of snapshots/simulation
T_k = temperature * numpy.ones(K, float) # initial temperatures are all equal
beta = 1.0 / (kB * temperature) # inverse temperature of simulations (in 1/(kJ/mol))

# Allocate storage for simulation data
# N_k[k] is the number of snapshots from umbrella simulation k
N_k = numpy.zeros(K, dtype=numpy.int32)
# K_k[k] is the spring constant (in kJ/mol/deg**2) for umbrella simulation k
K_k = numpy.zeros(K, dtype=numpy.float64)
# chi0_k[k] is the spring center location (in deg) for umbrella simulation k
chi0_k = numpy.zeros(K, dtype=numpy.float64)
# chi_kn[k,n] is the torsion angle (in deg) for snapshot n from umbrella simulation k
chi_kn = numpy.zeros((K, N_max), dtype=numpy.float64)
# u_kn[k,n] is the reduced potential energy without umbrella restraints
# of snapshot n of umbrella simulation k
u_kn = numpy.zeros((K, N_max), dtype=numpy.float64)
# u_kln[k,l,n] is the reduced potential energy of snapshot n from
# umbrella simulation k evaluated at umbrella l
u_kln = numpy.zeros((K, K, N_max), dtype=numpy.float64)

# Load the umbrella definitions from 'centers.dat'.
# Line k describes umbrella k: column 0 is the spring center, column 1
# the spring constant, and an optional column 2 the temperature at which
# that simulation was run (overriding the default already in T_k).
with open('centers.dat', 'r') as infile:
    lines = infile.readlines()

# Factor converting a spring constant from kJ/mol/rad**2 to kJ/mol/deg**2.
per_rad2_to_per_deg2 = (numpy.pi/180)**2

for k in range(K):
    tokens = lines[k].split()
    chi0_k[k] = float(tokens[0]) # spring center location (in deg)
    K_k[k] = float(tokens[1]) * per_rad2_to_per_deg2 # spring constant, now in kJ/mol/deg**2
    if len(tokens) > 2:
        T_k[k] = float(tokens[2]) # temperature the kth simulation was run at

# Inverse temperature of every umbrella simulation.
beta_k = 1.0/(kB*T_k)
# Read the simulation data.
# For every umbrella k the file 'data/prod<k>_dihed.xvg' is read; the
# second column of each data line is taken as the torsion angle, wrapped
# into [chi_min, chi_max) and stored in chi_kn[k,n].  N_k[k] records how
# many snapshots were read.
chi_period = chi_max - chi_min  # width of the periodic torsion range
for k in range(K):
    # Read torsion angle data.
    filename = 'data/prod%d_dihed.xvg' % k
    with open(filename, 'r') as infile:
        lines = infile.readlines()

    # Parse data.
    n = 0
    for line in lines:
        # Skip blank lines as well as lines starting with '#' or '@'.
        if not line.strip() or line[0] == '#' or line[0] == '@':
            continue

        # Fail with a clear message instead of an opaque array index error.
        if n >= N_max:
            raise IndexError("%s contains more than N_max = %d snapshots"
                             % (filename, N_max))

        tokens = line.split()
        chi = float(tokens[1]) # torsion angle

        # wrap chi to be within [chi_min, chi_max)
        while chi < chi_min:
            chi += chi_period
        while chi >= chi_max:
            chi -= chi_period
        chi_kn[k,n] = chi

        n += 1
    N_k[k] = n

# Shift u_kn so that its smallest entry is zero -- this is arbitrary (is it?)
u_kn -= u_kn.min()

# Assign every snapshot to a torsion bin.
# bin_kn[k,n] is the index of the bin of width delta, counted from
# chi_min, that contains chi_kn[k,n].  Entries beyond the N_k[k]
# snapshots actually read for simulation k stay at zero.
bin_kn = numpy.zeros([K,N_max], numpy.int32)
for k in range(K):
    n_snapshots = N_k[k]
    bin_kn[k, :n_snapshots] = [int((chi - chi_min) / delta)
                               for chi in chi_kn[k, :n_snapshots]]

# Evaluate reduced energies in all umbrellas.
# For each simulation k, the harmonic umbrella energy of each of its
# snapshots is computed in every umbrella l at once using NumPy
# broadcasting rather than nested Python loops.
chi_period = float(chi_max - chi_min)  # full period of the torsion
half_period = chi_period / 2.0
for k in range(K):
    n_snapshots = N_k[k]

    # dchi[l,n]: deviation of snapshot n of simulation k from umbrella center l
    dchi = chi_kn[k, :n_snapshots][numpy.newaxis, :] - chi0_k[:, numpy.newaxis]

    # Minimum-image convention: deviations larger than half a period are
    # folded back.  Only the magnitude matters because dchi is squared.
    abs_dchi = numpy.abs(dchi)
    dchi = numpy.where(abs_dchi > half_period, chi_period - abs_dchi, dchi)

    # Energy of every snapshot from simulation k in every umbrella potential l
    u_kln[k, :, :n_snapshots] = (u_kn[k, :n_snapshots]
                                 + beta_k[k] * (K_k[:, numpy.newaxis]/2.0) * dchi**2)

# Set up MBAR from the reduced energies and per-umbrella sample counts.
mbar = pymbar.MBAR(u_kln, N_k, verbose=False, method='Newton-Raphson')

# Per-umbrella free energies as reported by MBAR.
mbar_winv = mbar.f_k

# Per the original author's note: element [1] of the result holds the
# uncertainties, and row 0 is taken assuming a zeroed first state.
free_energy_results = mbar.getFreeEnergyDifferences()
mbar_wine = free_energy_results[1][0, :]

# Compute PMF in unbiased potential (in units of kT).
f_i, df_i = mbar.computePMF(u_kn, bin_kn, nbins, uncertainties='from-normalization')
mbar_binv = f_i
mbar_bine = df_i

# Shift the PMF so that its lowest bin is zero.  mbar_binv refers to the
# same object as f_i, so the in-place subtraction affects both names.
minf = numpy.min(mbar_binv)
mbar_binv -= minf

# Write out the results of both methods to standard output.
# For each method, two tables are printed: the per-umbrella free
# energies with their errors, then the per-bin PMF values with their
# errors.  print is called with a single parenthesized argument so the
# statements produce the same output under Python 2 and are valid
# syntax under Python 3.
names = ['WHAM','MBAR']

# Result arrays for each method: (window values, window errors,
# bin values, bin errors).
results_by_name = {
    'WHAM': (wham_winv, wham_wine, wham_binv, wham_bine),
    'MBAR': (mbar_winv, mbar_wine, mbar_binv, mbar_bine),
}

for name in names:

    # Work on copies so the stored results are never modified here.
    winv, wine, binv, bine = [values.copy() for values in results_by_name[name]]

    print("#%s Averaged umbrella free energies" % name)
    print("#Umbr. # | Estimate |  Std. Err")
    for i in range(K):
        print("%6d%12.6f%12.6f" % (i,winv[i],wine[i]))

    print("#%s Averaged bins" % name)
    print("#Bin center  | Estimate | Std. Err")

    for i in range(nbins):
        print("%12.6f%12.6f%12.6f" % (bin_centers[i],binv[i],bine[i]))



