#!/usr/bin/python

import pdb
import glob,os,re,sys
import subprocess
import numpy,scipy
import pymbar
import scipy.integrate
from numpy import pi,cos

# Four positional arguments are required, so sys.argv must hold five entries
# (the script name included).  Testing for fewer than 4 let a missing
# `totaln` slip through and crash with an IndexError a few lines below.
if len(sys.argv) < 5:
    print("Usage: ./compare_wham_bar.py nbin nwin nrep totaln")
    raise SystemExit

nbin = int(sys.argv[1])    # number of PMF histogram bins
nwin = int(sys.argv[2])    # number of umbrella windows
nrep = int(sys.argv[3])    # number of replicas the data is split into
totaln = int(sys.argv[4])  # samples taken from each window file, shared out over the replicas

metafile = 'meta' + str(nwin) + '.dat'

chi_min = -pi  # min for PMF
chi_max = pi   # max for PMF

delta = (chi_max - chi_min) / float(nbin)  # width of one PMF bin

kB = 0.0019829237
# T_wham is picked so that kB * T_wham == 1.
# NOTE(review): presumably kB matches the constant compiled into wham, so
# that wham reports energies in units of kT -- confirm against the wham build.
T_wham = 1/kB
T_mbar = 1
# Prefix shared by every scratch file this run creates.
label = 'tmp_bin' + str(nbin) + 'nrep' + str(nrep)

with open(metafile, 'r') as infile:
    mlines = infile.readlines()

# Write a scratch copy of the metadata file in which every data file name is
# replaced by its per-replica scratch name (label + '.' + original name).
meta2 = label + '.dat'
filein = list()
fileout = list()
with open(meta2, 'w') as outfile:
    for line in mlines:
        tokens = line.split()
        # Blank lines and '#' comment lines carry no window definition.
        if not tokens or tokens[0].startswith('#'):
            continue
        filein.append(tokens[0])
        fileout.append(label + '.' + tokens[0])
        outfile.write(fileout[-1] + ' ' + tokens[1] + ' ' + tokens[2] + '\n')
nf = len(filein)  # number of window files listed in the metadata

# Per-replica estimates.  Rows index replicas; columns index PMF bins for the
# *_binv arrays and umbrella windows for the *_winv arrays.
wham_binv = numpy.zeros((nrep, nbin), dtype=float)
wham_winv = numpy.zeros((nrep, nwin), dtype=float)
mbar_binv = numpy.zeros((nrep, nbin), dtype=float)
mbar_winv = numpy.zeros((nrep, nwin), dtype=float)

# Reference free energy of every umbrella window (filled from the data files).
truefree = numpy.zeros(nwin, dtype=float)

# Analyse every replica independently: cut its slice out of each window
# trajectory, run the external wham program on the slices, then run MBAR on
# the same slices, storing both sets of estimates in row i of the result arrays.
for i in range(nrep):
    print("Replica #%-4d" % (i+1))
    # Samples per window for one replica.  Floor division keeps this an
    # integer (usable as a slice bound) under both Python 2 and Python 3.
    maxn = totaln // nrep
    for j in range(len(filein)):
        with open(filein[j], 'r') as infile:
            lines = infile.readlines()
        # Third field of the second line is stored as the window's reference
        # free energy; values are reported relative to the first window.
        truefree[j] = float(lines[1].split()[2])
        if j == 0:
            ref_free = truefree[j]
        truefree[j] -= ref_free
        # The +5 offset skips the first five lines of the input file
        # (NOTE(review): presumably a fixed-size header -- confirm).
        istart = i*maxn + 5
        iend = (i+1)*maxn + 5
        with open(fileout[j], 'w') as outfile:
            outfile.writelines(lines[istart:iend])

    ####### compute WHAM data

    fbatch = label + '.batch'
    fout = label + '.outstuff'
    # Command line: wham [P|Ppi|Pval] hist_min hist_max num_bins tol temperature numpad metadatafile freefile [num_MC_trials randSeed]
    args = ['../../../misc/wham-grossfield/wham/wham', 'Ppi',
            str(chi_min), str(chi_max), str(nbin), '0.0000000001',
            str(T_wham), '0', meta2, fbatch, '100', '-1']
    # wham's stdout is captured to a file because the per-window free
    # energies are parsed from it below.
    with open(fout, 'w') as outstuff:
        runwham = subprocess.Popen(args, stdout=outstuff)
        runwham.wait()

    with open(fbatch, 'r') as infile1:
        blines = infile1.readlines()

    # Second column of every non-comment line of the free-energy file is the
    # binned PMF value.
    j = 0
    for line in blines:
        if not line.strip() or line[0] == '#':
            continue
        vals = line.split()
        wham_binv[i,j] = float(vals[1])
        j += 1

    with open(fout, 'r') as infile2:
        olines = infile2.readlines()

    # The window table of interest follows the LAST line whose characters
    # 2..7 read 'Window'.
    lastwin = None
    for idx, oline in enumerate(olines):
        if oline[2:8] == 'Window':
            lastwin = idx
    if lastwin is None:
        # Without this check a failed wham run produced a NameError on the
        # first replica and silently reused a stale index on later ones.
        raise RuntimeError("no 'Window' table found in wham output " + fout)

    for k in range(nwin):
        vals = olines[lastwin + 1 + k].split()
        wham_winv[i,k] = float(vals[2])

    ##### Compute MBAR data

    # Parameters
    K = nwin # number of umbrellas
    beta = 1.0/T_mbar # inverse temperature of simulations (in 1/kT); float division on purpose

    # Allocate storage for simulation data.  Angles are handled in radians
    # throughout (the PMF range is [-pi, pi] and the wrap below uses pi).
    N_k = numpy.zeros([K], numpy.int32) # N_k[k] is the number of snapshots from umbrella simulation k
    K_k = numpy.zeros([K], numpy.float64) # K_k[k] is the spring constant for umbrella simulation k, as given in the metadata file
    chi0_k = numpy.zeros([K], numpy.float64) # chi0_k[k] is the spring center location for umbrella simulation k
    chi_kn = numpy.zeros([K,maxn], numpy.float64) # chi_kn[k,n] is the torsion angle for snapshot n from umbrella simulation k
    u_kn = numpy.zeros([K,maxn], numpy.float64) # u_kn[k,n] is the reduced potential energy with umbrella restraints of snapshot n of umbrella simulation k (read but not used below)
    unbiased_kn = numpy.zeros([K,maxn], numpy.float64) # unbiased_kn[k,n] is the reduced potential energy without umbrella restraints of snapshot n of umbrella simulation k
    u_kln = numpy.zeros([K,K,maxn], numpy.float64) # u_kln[k,l,n] is the reduced potential energy of snapshot n from umbrella simulation k evaluated at umbrella l

    # Read in umbrella spring constants and centers.
    with open(meta2, 'r') as infile:
        lines = infile.readlines()
    filename = list()
    for k in range(K):
        tokens = lines[k].split()
        filename.append(tokens[0])
        chi0_k[k] = float(tokens[1]) # spring center location
        K_k[k] = float(tokens[2])    # spring constant

    # Read the simulation data
    for k in range(K):
        with open(filename[k], 'r') as infile:
            lines = infile.readlines()
        n = 0
        for line in lines:
            if line[0] != '#' and line[0] != '@':
                tokens = line.split()
                chi_kn[k,n] = float(tokens[1])      # torsion angle
                u_kn[k,n] = float(tokens[2])        # biased energy
                unbiased_kn[k,n] = float(tokens[3]) # unbiased energy
                # Count only parsed samples; counting skipped comment lines
                # would leave zero-filled gaps and overstate N_k.
                n += 1
        N_k[k] = n

    # Construct torsion bins: bin index = floor offset from chi_min in units
    # of the bin width (truncation, as int() does).
    bin_kn = numpy.zeros([K,maxn], numpy.int32)
    for k in range(K):
        nk = N_k[k]
        bin_kn[k,:nk] = ((chi_kn[k,:nk] - chi_min) / delta).astype(numpy.int32)

    # Evaluate reduced energies in all umbrellas
    for k in range(K):
        nk = N_k[k]
        # dchi[l,n]: deviation of snapshot n from the center of umbrella l
        dchi = chi_kn[k,:nk] - chi0_k[:,numpy.newaxis]
        # Minimum-image convention for a 2*pi-periodic angle; the sign is
        # irrelevant because only dchi**2 is used.
        dchi = numpy.where(numpy.abs(dchi) > pi, 2*pi - numpy.abs(dchi), dchi)
        # Unbiased energy plus the harmonic restraint of umbrella l
        u_kln[k,:,:nk] = unbiased_kn[k,:nk] + beta * (K_k[:,numpy.newaxis]/2.0) * dchi**2

    # Initialize MBAR.
    mbar = pymbar.MBAR(u_kln, N_k, verbose = False, method = 'Newton-Raphson',initialize='BAR')
    mbar_winv[i,:] = mbar.f_k

    # Compute PMF in unbiased potential (in units of kT).
    (f_i, df_i) = mbar.computePMF(unbiased_kn, bin_kn, nbin)
    mbar_binv[i,:] = f_i


#### remove unneeded files to reduce clutter
# Every scratch file written above has a name starting with `label`, so a
# single glob collects them all.  The loop variable is deliberately not
# called `file`, which would shadow the Python 2 builtin of that name.
filelist = glob.glob(label + '*')
for scratch_path in filelist:
    os.remove(scratch_path)

#### Define the underlying function  #####

def func(X):
    """Return the model potential 2*(3 + cos(x) + cos(2x) + cos(4x)), x = X - pi/2.

    X may be a scalar or a numpy array; the result has the same shape.
    """
    shifted = X - pi/2
    cosine_sum = 3 + cos(shifted) + cos(2*shifted) + cos(4*shifted)
    return 2*cosine_sum

def expfunc(X):
    """Return exp(-func(X)), the integrand used for the binned reference PMF."""
    potential = func(X)
    return numpy.exp(-potential)
 
#### Build the analytical reference curves on the bin grid ######

# Analytical reference values for every PMF bin, computed from func():
#   expect_delta - -log of the integral of exp(-func) over the bin
#   v_mid        - func evaluated at the bin center
#   v_ave        - func averaged over the bin
# scipy.integrate.quad is used for the integrals; the older
# scipy.integrate.quadrature routine is deprecated in current SciPy.
bin_centers = numpy.zeros([nbin], float)
expect_delta = numpy.zeros([nbin], float)
v_mid = numpy.zeros([nbin], float)
v_ave = numpy.zeros([nbin], float)
for i in range(nbin):
    bin_min = chi_min + delta*i
    bin_max = chi_min + delta*(i+1)
    bin_centers[i] = (bin_max + bin_min)/2
    # expectation value of the delta function approximated by the bin histogram
    expect_delta[i] = -numpy.log(scipy.integrate.quad(expfunc, bin_min, bin_max)[0])

    # potential of the bin center
    v_mid[i] = func(bin_centers[i])

    # average potential over the bin
    v_ave[i] = scipy.integrate.quad(func, bin_min, bin_max)[0]/delta

# Shift each curve so that its minimum is zero; only differences matter.
expect_delta = expect_delta - numpy.min(expect_delta)
v_mid = v_mid - numpy.min(v_mid)
v_ave = v_ave - numpy.min(v_ave)

# Print, for each estimator, the replica-averaged window free energies and
# binned PMF next to the analytical reference values.
names = ['WHAM','MBAR']
per_method = {
    'WHAM': (wham_winv, wham_binv),
    'MBAR': (mbar_winv, mbar_binv),
}
# Population std divided by sqrt(nrep-1) is the standard error of the mean.
sem_scale = numpy.sqrt(nrep-1)
for name in names:
    winv, binv = per_method[name]

    ave_winv = numpy.average(winv, axis=0)
    std_winv = numpy.std(winv, axis=0)/sem_scale
    ave_binv = numpy.average(binv, axis=0)
    std_binv = numpy.std(binv, axis=0)/sem_scale

    print("%s Averaged umbrella free energies" % name)
    print("Umbr. # | Analyt. G  | Estimate |    Bias    |  Std. Err")
    for i in range(nwin):
        window_row = (i, truefree[i], ave_winv[i],
                      truefree[i]-ave_winv[i], std_winv[i])
        print("%6d%12.6f%12.6f%12.6f%12.6f" % window_row)

    print("%s Averaged bins" % name)
    print("Bin center   | Binned PMF | Est. PMF | Ave. Pot. | Midp. Pot.| Bin Error | Est. Bias |  Std. Err")

    for i in range(nbin):
        bin_row = (bin_centers[i], expect_delta[i], ave_binv[i], v_ave[i],
                   v_mid[i], v_ave[i]-expect_delta[i],
                   expect_delta[i]-ave_binv[i], std_binv[i])
        print("%12.6f%12.6f%12.6f%12.6f%12.6f%12.6f%12.6f%12.6f" % bin_row)
