#!/usr/bin/python

import pdb
import numpy
import matplotlib.pylab as plt
import subprocess
import os
from os import path
# Run the comparison script for two different numbers of bins.

# Number of umbrella windows; sizes the per-umbrella result arrays.
numbrellas = 26

# The two bin counts whose results are compared against each other.
smallbin, largebin = 18, 72

# Bin counts to process. An output file is generated for each one
# if it does not already exist.
binvals = [smallbin, largebin]

# For each bin count: make sure the output file 'PMF_<nbins>.txt' exists
# (running compare_wham_mbar_lysozyme.py to create it if necessary), parse
# it into per-umbrella and per-bin arrays for WHAM and MBAR, and keep a copy
# of those arrays under the small_* / large_* names.
for nbins in binvals:

    # Per-umbrella results, two columns per row (filled from fields 1-2 of
    # each umbrella data line).
    win_mbar = numpy.zeros([numbrellas,2],float)
    win_wham = numpy.zeros([numbrellas,2],float)

    # Per-bin results, three columns per row (filled from fields 0-2 of
    # each bin data line).
    bin_mbar = numpy.zeros([nbins,3],float)
    bin_wham = numpy.zeros([nbins,3],float)

    bin_output = 'PMF_' + str(nbins) + '.txt'

    if not os.path.isfile(bin_output):
        print('Running compare_wham_mbar_lysozyme.py with %d bins' % (nbins))
        argv = ['./compare_wham_mbar_lysozyme.py',str(nbins)]
        try:
            with open(bin_output,'w') as plotfile:
                status = subprocess.call(argv,stdout=plotfile)
            if status != 0:
                raise RuntimeError('%s exited with status %d' % (argv[0],status))
        except (OSError, RuntimeError):
            # Do not leave an empty or partial file behind: the existence
            # check above would treat it as a valid cached result next time.
            if os.path.isfile(bin_output):
                os.remove(bin_output)
            raise

    with open(bin_output,'r') as plotfile:
        lines = plotfile.readlines()

    # Parser state. Header lines start with '#', carry the method name
    # ('WHAM' or 'MBAR') in columns 1-4 and the section type ('umbrella' or
    # 'bins') starting at column 15. Data lines seen before any header
    # match no section and are ignored.
    method = None
    data = None
    ibin = 0

    for line in lines:
        if line.startswith('#'):
            if line[1:5] == 'WHAM':
                method = 'WHAM'
            if line[1:5] == 'MBAR':
                method = 'MBAR'
            if line[15:23] == 'umbrella':
                data = 'umbrella'
            if line[15:19] == 'bins':
                data = 'bins'
                # Each bins section is stored starting from row 0.
                ibin = 0
        else:
            fields = line.split()
            if not fields:
                # Blank line: nothing to store.
                continue
            # Build a new list of floats; rebinding a loop variable would
            # leave the original strings in place.
            vals = [float(v) for v in fields]
            if method == 'WHAM':
                if data == 'umbrella':
                    # The first field is used directly as the row index, so
                    # it must be an integer.
                    # NOTE(review): assumes the file numbers umbrellas from
                    # 0 -- confirm against compare_wham_mbar_lysozyme.py.
                    iwin = int(vals[0])
                    win_wham[iwin,:] = numpy.array(vals[1:3])
                if data == 'bins':
                    bin_wham[ibin,:] = numpy.array(vals[0:3])
                    ibin+=1
            if method == 'MBAR':
                if data == 'umbrella':
                    iwin = int(vals[0])
                    win_mbar[iwin,:] = numpy.array(vals[1:3])
                if data == 'bins':
                    bin_mbar[ibin,:] = numpy.array(vals[0:3])
                    ibin+=1

    # Keep copies, since the working arrays are re-created on the next pass.
    if nbins == smallbin:
        small_win_mbar = win_mbar.copy()
        small_win_wham = win_wham.copy()
        small_bin_mbar = bin_mbar.copy()
        small_bin_wham = bin_wham.copy()
    if nbins == largebin:
        large_win_mbar = win_mbar.copy()
        large_win_wham = win_wham.copy()
        large_bin_mbar = bin_mbar.copy()
        large_bin_wham = bin_wham.copy()

# set definitions:

# Text-property settings (font family, alignment, weight, size).
# NOTE(review): nothing in the visible part of this script reads this
# dictionary -- confirm whether it is still needed.
override = dict(
    family='sans-serif',
    verticalalignment='bottom',
    horizontalalignment='center',
    weight='bold',
    size=18,
)


# We want to first compare the _free energies_ of the umbrellas for small and large # of states:

# x values for the umbrella plots: 1 .. numbrellas
winvals = range(1,numbrellas+1)

# Shared y-axis limits for both umbrella figures: 1.15 times the largest and
# smallest free energy found in any of the four data sets.
all_win_values = numpy.concatenate([small_win_wham[:,0],small_win_mbar[:,0],large_win_wham[:,0],large_win_mbar[:,0]])
max_axis = 1.15*numpy.max(all_win_values)
min_axis = 1.15*numpy.min(all_win_values)


def _plot_window_comparison(fignum, bincount, wham, mbar, outfile):
    """Plot WHAM and MBAR umbrella free energies with error bars and save them.

    fignum   -- matplotlib figure number to draw into
    bincount -- number of bins, shown in the title
    wham     -- array whose column 0 is plotted with column 1 as the error bar
    mbar     -- array with the same layout as wham
    outfile  -- file name passed to plt.savefig

    Returns the figure object created by plt.figure.
    """
    fig = plt.figure(fignum)
    plt.title('WHAM and MBAR umbrella free energies \n for leucine side chain, using ' + str(bincount) + ' bins',size=18)
    plt.xlabel('Bin Number',size=16)
    # Raw string: '\D' is not a valid escape sequence in a normal literal.
    plt.ylabel(r'$\Delta$G (kT)',size=16)
    # Both figures use the same limits (the second one previously started
    # at 0, which would clip any negative free energies).
    plt.axis([1,numbrellas+1,min_axis,max_axis])
    plt.errorbar(winvals,wham[:,0],yerr=wham[:,1])
    plt.errorbar(winvals,mbar[:,0],yerr=mbar[:,1])
    plt.savefig(outfile)
    return fig


fig1 = _plot_window_comparison(1,smallbin,small_win_wham,small_win_mbar,'small_win_mbar_comparison.pdf')
fig2 = _plot_window_comparison(2,largebin,large_win_wham,large_win_mbar,'large_win_mbar_comparison.pdf')


# Then we want to compare the bin PMFs 

# set the maximum and minima for the axis
# Shared y-axis limits for all PMF figures: 1.15 times the largest and
# smallest PMF value found in any of the four data sets.
all_bin_values = numpy.concatenate([small_bin_wham[:,1],small_bin_mbar[:,1],large_bin_wham[:,1],large_bin_mbar[:,1]])
max_axis = 1.15*numpy.max(all_bin_values)
min_axis = 1.15*numpy.min(all_bin_values)


def _plot_pmf_comparison(fignum, title, first, second, outfile):
    """Plot two PMF curves with error bars in one figure and save it.

    fignum  -- matplotlib figure number to draw into
    title   -- figure title
    first   -- array plotted as column 1 against column 0, with column 2 as
               the error bar
    second  -- array with the same layout as first
    outfile -- file name passed to plt.savefig
    """
    plt.figure(fignum)
    plt.title(title,size=18)
    # Raw strings: '\c' and '\D' are not valid escape sequences in a
    # normal literal.
    plt.xlabel(r'$\chi$(degrees)',size=16)
    plt.ylabel(r'$\Delta$G (kT)',size=16)
    plt.axis([-185.0,185.0,min_axis,max_axis])
    plt.errorbar(first[:,0],first[:,1],yerr=first[:,2])
    plt.errorbar(second[:,0],second[:,1],yerr=second[:,2])
    plt.savefig(outfile)


# WHAM against MBAR at each bin count
_plot_pmf_comparison(
    3,
    'WHAM and MBAR PMF \n for leucine side chain, using ' + str(smallbin) + ' bins',
    small_bin_wham,small_bin_mbar,
    'small_bin_mbar_comparison.pdf')

_plot_pmf_comparison(
    4,
    'WHAM and MBAR PMF\n for leucine side chain, using ' + str(largebin) + ' bins',
    large_bin_wham,large_bin_mbar,
    'large_bin_mbar_comparison.pdf')

# Small against large bin count for each method
_plot_pmf_comparison(
    5,
    'WHAM PMF for leucine side chain,\n comparing ' + str(smallbin) + ' and ' + str(largebin) + ' bins',
    small_bin_wham,large_bin_wham,
    'wham_large_small_comparison.pdf')

_plot_pmf_comparison(
    6,
    'MBAR PMF for leucine side chain,\n comparing ' + str(smallbin) + ' and ' + str(largebin) + ' bins',
    small_bin_mbar,large_bin_mbar,
    'mbar_large_small_comparison.pdf')

plt.show()
