#!/usr/bin/env python

from __future__ import division
from numpy import *
from sympy import *
from scipy import * 
from scipy.integrate import *
from scipy.optimize import * 
from scipy.interpolate import * 

import numpy
import sympy
import re
import os
import sys
import commands
import time 
import pdb

# Approximations available for g(r); `method` selects the one that is used
methods = ['g0', 'g1']
method = methods[0]

# Constants entering beta = Cp/(kB*temp).
# NOTE(review): this value was originally labelled "specific heat capacity of
# water under constant pressure"; 4.184 is also the kJ-per-kcal factor, which
# would make betav come out in 1/(kcal/mol) -- confirm the intended meaning.
Cp = 4.184
kB = 1.380658*6.0221367/1000.0  # Boltzmann's constant (kJ/mol/K)
temp = 298
betav = Cp/(kB*temp)
rho_OW = 0.03330001
rho_HW = 2.0*rho_OW

# beta and the combined epsilon / sigma values
betain = betav
epsuv = 0.21146489
siguv = 3.44030500

# Soft-core parameters: A and B stay fixed, C and alpha are the starting
# point of the optimization
Ain = 1.0
Bin = 1.0
Cin = 6.0
alphain = 2**(-Cin/6.0)
alpha0 = 0

# Initial guess [C, alpha] handed to the optimizer
optv0 = [Cin, alphain]


# Helpers that pack the positional argument lists (as Python lists) for the
# property functions; optv is the pair [C, alpha]
def Uinall(optv, epsin, sigin, lamin, rin):
    return [epsin, sigin, Ain, Bin, optv[0], optv[1], lamin, rin]


def Ginall(optv, epsin, sigin, lamin, rin):
    return [betain, epsin, sigin, Ain, Bin, optv[0], optv[1], lamin, rin]


def DGinall(optv, rhoin, epsin, sigin, lamin, rin, ilam):
    return [rhoin, betain, epsin, sigin, Ain, Bin, optv[0], optv[1], lamin, rin, ilam]


# Limits for lambda and rij, plus the lambda step for numerical derivatives
lam_min = 0
lam_max = 1 - lam_min
dlam = 1E-3
r_min = 0
r_max = 10
# Cutoff on r that guards against the LJ singularity at r = 0
rcut = 1E-3

# Number of grid points along lambda and rij
Nlam = 51
Nrij = 1001

# Grids used for the numerical integration
lambda_range = linspace(lam_min, lam_max, Nlam)
rij_dg = linspace(r_min, r_max, Nrij)


## Define x,y and z grids to test arbitrary g(r) approximation method
#dr = 0.5 # Increment in x,y and z direction
#x_range = arange(-r_max,r_max+dr,dr)
#y_range = x_range
#z_range = x_range
#ngrid = len(x_range)*len(y_range)*len(z_range)
#xgrid = numpy.zeros(ngrid,float)
#ygrid = numpy.zeros(ngrid,float)
#zgrid = numpy.zeros(ngrid,float)
#ir = 0
#for x in x_range:
#    for y in y_range:
#        for z in z_range:
#            xgrid[ir] = x  
#            ygrid[ir] = y  
#            zgrid[ir] = z  
#            ir += 1
#Vcell = dr**3
#rgrid = sqrt(xgrid**2+ygrid**2+zgrid**2)
#rgrid = rgrid[rgrid.argsort(),]

#===================================================================================================
# Functions to calculate all the properties for two-body and three-body approximations
#===================================================================================================

# Function to calculate the LJ potential
def getulj(epsinv,siginv,Ainv,Binv,Cinv,alphainv,laminv,rinv,rcutv=None):
    """Return the soft-core Lennard-Jones potential at separation(s) rinv.

    With reff**C = alpha*(1-lambda)**B*sigma**C + r**C the value returned is

        U = 4*eps*lambda**A * (sigma**12*(reff**C)**(-12/C)
                               - sigma**6*(reff**C)**(-6/C))

    Parameters
    ----------
    epsinv, siginv : LJ epsilon and sigma.
    Ainv, Binv, Cinv, alphainv : soft-core parameters A, B, C and alpha.
    laminv : scalar coupling parameter lambda.
    rinv : scalar or numpy array of separations; it is not modified.
    rcutv : optional cutoff on r used to cap reff**C when lambda is 1;
        defaults to the module-level ``rcut``.

    Returns
    -------
    U with the same shape as ``rinv``.
    """
    # Sub-components required to calculate the potential
    neg6byC = -6.0/Cinv
    neg12byC = -12.0/Cinv
    epslampAx4 = 4*epsinv*laminv**Ainv
    sigpC = siginv**Cinv
    sigp6 = siginv**6.0
    sigp12 = siginv**12.0
    rlampB = (1 - laminv)**Binv
    reffpC = alphainv*rlampB*sigpC + rinv**Cinv

    # At lambda = 1 (to 8 decimals) the soft core vanishes and reff**C -> r**C,
    # which is singular at r = 0. Cap every value below rcut**C; an
    # element-wise maximum works for scalars, unsorted samples and the case
    # where all samples lie below the cutoff.
    if abs(laminv - 1.0) < 0.5E-8:
        if rcutv is None:
            rcutv = rcut
        reffpC = numpy.maximum(reffpC, rcutv**Cinv)

    reff6inv = sigp6*reffpC**neg6byC
    reff12inv = sigp12*reffpC**neg12byC

    Ulj = epslampAx4*(reff12inv - reff6inv)

    return Ulj


# Function to calculate dH/dL, ddH/dL2 and (dH/dL)^2 for dilute gas limit
def getdgfunc(rhoinv,betainv,epsinv,siginv,Ainv,Binv,Cinv,alphainv,laminv,rinv,ilam,rcutv=None):
    """Evaluate dU/dlambda based averages for one lambda in the dilute gas limit.

    The soft-core LJ potential U and its first and second lambda derivatives
    are evaluated on the grid ``rinv``, weighted by g(r)*rho*4*pi*r**2 and
    integrated over r with Simpson's rule. g(r) is chosen by the module-level
    ``method``: 'g0' uses exp(-beta*U), 'g1' multiplies that by (1 + rho*g1).

    Parameters
    ----------
    rhoinv : number density.
    betainv : inverse temperature factor multiplying U in the exponent.
    epsinv, siginv : LJ epsilon and sigma.
    Ainv, Binv, Cinv, alphainv : soft-core parameters A, B, C and alpha.
    laminv : scalar coupling parameter lambda.
    rinv : 1-D numpy array of separations (the 'g1' branch indexes it as an
        evenly spaced grid); it is not modified.
    ilam : index of the lambda value; accepted for call compatibility and
        not used in the computation.
    rcutv : optional cutoff on r used to cap reff**C when lambda is 1;
        defaults to the module-level ``rcut``.

    Returns
    -------
    (dHdL_lam, dHdLsq_lam, dHdL, dHdLsq, ddHdL2, gr) where the ``*_lam``
    entries and ``gr`` are arrays over ``rinv`` and the remaining three are
    the Simpson integrals over r.

    Raises
    ------
    ValueError if the module-level ``method`` is neither 'g0' nor 'g1'.
    """
    # lambda, A and B are compared with 1 to 8 decimals
    tol = 0.5E-8
    lam_is_one = abs(laminv - 1.0) < tol
    A_is_one = abs(Ainv - 1.0) < tol
    B_is_one = abs(Binv - 1.0) < tol

    # Sub-components required to calculate various function in an efficient manner
    neg6byC = -6.0/Cinv
    neg6byCm1 = -6.0/Cinv - 1.0
    neg6byCm2 = -6.0/Cinv - 2.0
    neg12byC = -12.0/Cinv
    neg12byCm1 = -12.0/Cinv - 1.0
    neg12byCm2 = -12.0/Cinv - 2.0
    epsAx4 = 4*epsinv*Ainv
    epsAAm1x4 = epsAx4*(Ainv-1)
    epslampAx4 = 4*epsinv*laminv**Ainv
    depslampAx4_dl = epsAx4*laminv**(Ainv-1.0)
    # If A = 1 the second derivative of lambda**A is exactly zero; evaluating
    # lambda**(A-2) would be wrong (and singular at lambda = 0)
    if A_is_one:
        ddepslampAx4_dl2 = 0
    else:
        ddepslampAx4_dl2 = epsAAm1x4*laminv**(Ainv-2.0)
    sigpC = siginv**Cinv
    sigp6 = siginv**6.0
    sigp12 = siginv**12.0
    rlamp1 = 1 - laminv
    rlampB = rlamp1**Binv
    ralphalampB = alphainv*rlampB*sigpC
    reffpC = ralphalampB + rinv**Cinv
    # At lambda = 1 reff**C reduces to r**C, which is singular at r = 0.
    # Cap every value below rcut**C (element-wise, so it cannot run off the
    # end of the array and does not rely on the ordering of rinv).
    if lam_is_one:
        if rcutv is None:
            rcutv = rcut
        reffpC_cut = rcutv**Cinv
        reffpC = numpy.maximum(reffpC, reffpC_cut)
    dreffpC_dl = -Binv*alphainv*sigpC*rlamp1**(Binv-1.0)
    # If B = 1 the second derivative of (1-lambda)**B is exactly zero
    if B_is_one:
        ddreffpC_dl2 = 0
    else:
        ddreffpC_dl2 = Binv*(Binv-1.0)*alphainv*sigpC*rlamp1**(Binv-2.0)
    dreffpC_dlp2 = dreffpC_dl*dreffpC_dl
    reff6inv = sigp6*reffpC**neg6byC
    reff6invm1 = sigp6*reffpC**neg6byCm1
    reff6invm2 = sigp6*reffpC**neg6byCm2
    reff12inv = sigp12*reffpC**neg12byC
    reff12invm1 = sigp12*reffpC**neg12byCm1
    reff12invm2 = sigp12*reffpC**neg12byCm2
    dreff6inv_dl = neg6byC*dreffpC_dl*reff6invm1
    ddreff6inv_dl2 = neg6byC*neg6byCm1*dreffpC_dlp2*reff6invm2 + neg6byC*ddreffpC_dl2*reff6invm1
    dreff12inv_dl = neg12byC*dreffpC_dl*reff12invm1
    ddreff12inv_dl2 = neg12byC*neg12byCm1*dreffpC_dlp2*reff12invm2 + neg12byC*ddreffpC_dl2*reff12invm1

    # Potential and its lambda derivatives (product rule; the cross term
    # appears twice in the second derivative)
    Ulj = epslampAx4*(reff12inv - reff6inv)
    dUdL_lam = epslampAx4*(dreff12inv_dl - dreff6inv_dl) + depslampAx4_dl*(reff12inv - reff6inv)
    ddUdL2_lam = epslampAx4*(ddreff12inv_dl2 - ddreff6inv_dl2) + depslampAx4_dl*(dreff12inv_dl - dreff6inv_dl) + depslampAx4_dl*(dreff12inv_dl - dreff6inv_dl) + ddepslampAx4_dl2*(reff12inv - reff6inv)
    dUdLsq_lam = dUdL_lam**2.0

    if method == 'g0':
        gr = numpy.exp(-betainv*Ulj)

    elif method == 'g1':
        nsam = len(rinv)
        # Refine the r grid so the sub-interval integrals in g1 are accurate
        nfold = 4 # Increase by 4 fold
        nsam_int = nfold*(nsam-1) + 1
        rinv_int = numpy.linspace(rinv[0],rinv[-1],nsam_int)
        reffpC_int = ralphalampB + rinv_int**Cinv
        # Same lambda = 1 cap as above, on the refined grid
        if lam_is_one:
            reffpC_int = numpy.maximum(reffpC_int, reffpC_cut)
        reff6inv_int = sigp6*reffpC_int**neg6byC
        reff12inv_int = sigp12*reffpC_int**neg12byC
        Ulj_int = epslampAx4*(reff12inv_int - reff6inv_int)

        # Mayer-type function times r on the coarse and the refined grid
        fr = (numpy.exp(-betainv*Ulj) - 1)*rinv
        fr_int = (numpy.exp(-betainv*Ulj_int) - 1)*rinv_int
        fs_int = numpy.zeros([nsam,nsam],float)
        phi = numpy.zeros(nsam-1,float)
        g1 = numpy.zeros(nsam,float)

        # Integral of fr over each coarse sub-interval, using the refined samples
        for irp in range(nsam-1):
            irp_ini = nfold*irp
            irp_fin = nfold*(irp+1)+1
            phi[irp] = simps(fr_int[irp_ini:irp_fin],rinv_int[irp_ini:irp_fin])

        for icount in range(1,nsam):
            for ir in range(icount,nsam):
                rupper = icount+ir # Upper limit for the first integral
                rlower = abs(icount-ir) # Lower limit for the first integral
                # Inner integral as a function of R and r; the slice silently
                # truncates when rupper runs past the end of phi
                fs_int[icount,ir] = phi[rlower:rupper].sum()
                fs_int[ir,icount] = fs_int[icount,ir]

            # Outer integral, evaluated with Simpson's rule
            fs = fs_int[icount,:]
            g1[icount] = 2*numpy.pi*simps(fs*fr,rinv)/rinv[icount]

        # Extrapolate g1 to the first grid point with a cubic fitted to
        # samples 1 .. nfit-1
        nfit = 30
        polycoeff = numpy.polyfit(rinv[1:nfit],g1[1:nfit],3)
        g1[0] = numpy.polyval(polycoeff,rinv[0])

        # g(r) with g1 correction term
        gr = numpy.exp(-betainv*Ulj)*(1+rhoinv*g1)

    else:
        raise ValueError('Unknown g(r) method: %r' % (method,))

    # dHdL, ddHdL2 and dHdL^2 as a function of rij at this value of lambda
    dHdL_lam = dUdL_lam*gr*rhoinv*4*numpy.pi*rinv**2.0
    dHdLsq_lam = dUdLsq_lam*gr*rhoinv*4*numpy.pi*rinv**2.0
    ddHdL2_lam = ddUdL2_lam*gr*rhoinv*4*numpy.pi*rinv**2.0

    # For the pure hard core at lambda = 1, g(r=0) = 0, so every product with
    # g(r) vanishes near r = 0; zero the first five samples to avoid overflow
    if lam_is_one:
        dHdL_lam[0:5] = 0
        dHdLsq_lam[0:5] = 0
        ddHdL2_lam[0:5] = 0

    # Integrate over r to get the averages of dHdL, dHdL^2 and ddHdL2
    dHdL = simps(dHdL_lam,rinv)
    dHdLsq = simps(dHdLsq_lam,rinv)
    ddHdL2 = simps(ddHdL2_lam,rinv)

    return dHdL_lam, dHdLsq_lam, dHdL, dHdLsq, ddHdL2, gr

# Function to calculate the variance via numerical integration in rij and lambda using Simpson's rule
def getdgvariance(optv,epsuv,siguv,lam_range,rij_range):
    """Return var(deltaF) for the soft-core parameters optv = [C, alpha].

    For every lambda in ``lam_range`` the r-integrated dHdL and ddHdL2 are
    obtained from ``getdgfunc``; the result is

        (simps(ddHdL2, lam_range) - dHdL[-1] + dHdL[0]) / betain

    The per-lambda values are stored in arrays allocated here and sized to
    ``lam_range``, so any lambda grid can be passed in.

    Parameters
    ----------
    optv : pair [C, alpha].
    epsuv, siguv : combined LJ epsilon and sigma.
    lam_range : sequence of lambda values.
    rij_range : grid of separations handed to ``getdgfunc``.

    Returns
    -------
    var_deltaF, which is also printed together with ``optv``.
    """
    nlam = len(lam_range)
    dHdL = numpy.zeros(nlam,float)
    ddHdL2 = numpy.zeros(nlam,float)

    # For pure TIP3P, there is no LJ term for H so only properties for O are taken into account
    rho = rho_OW
    for ilam, lamv in enumerate(lam_range):
        # Input for the predefined function to calculate dHdL, ddHdL2, dHdLsq and g(r)
        DGin = DGinall(optv,rho,epsuv,siguv,lamv,rij_range,ilam)
        dHdL_lam, dHdLsq_lam, dHdL[ilam], dHdLsq_val, ddHdL2[ilam], gr_dg = getdgfunc(*DGin)

    # Variance of the free energy
    var_deltaF = (simps(ddHdL2,lam_range) - dHdL[-1] + dHdL[0])/ betain
    print('C, alpha =  %s ;   var(deltaF) =  %s' % (optv, var_deltaF))

    return var_deltaF

# Function to evaluate all properties for the optimal pathway for two-body g0 and three-body g0 + \rho g1
def getdgoptimal(optv,epsuv,siguv,lam_range,rij_range):
    """Evaluate the per-lambda properties for soft-core parameters optv.

    ``getdgfunc`` is called for every lambda; the derivative of dHdL with
    respect to lambda is then taken by finite differences, assuming that
    ``lam_range`` is evenly spaced (the step is lam_range[1] - lam_range[0]).
    The per-lambda values are stored in arrays allocated here and sized to
    ``lam_range``.

    Parameters
    ----------
    optv : pair [C, alpha].
    epsuv, siguv : combined LJ epsilon and sigma.
    lam_range : sequence of at least 5 lambda values.
    rij_range : grid of separations handed to ``getdgfunc``.

    Returns
    -------
    (dHdL_lam, dHdLsq_lam, dHdL, ddHdL2, ddLdHdL, dHdLsq, var_dHdL,
    var_dHdL_new, var_deltaF, var_deltaF_new). The first two are the
    r-profiles of the last lambda value; ``var_dHdL`` is the same array
    object as ``dHdLsq``.

    Raises
    ------
    ValueError if fewer than 5 lambda values are given (the one-sided
    stencils would index past the end of the data).
    """
    nlam = len(lam_range)
    if nlam < 5:
        raise ValueError('getdgoptimal needs at least 5 lambda values, got %d' % nlam)

    dHdL = numpy.zeros(nlam,float)
    dHdLsq = numpy.zeros(nlam,float)
    ddHdL2 = numpy.zeros(nlam,float)
    ddLdHdL = numpy.zeros(nlam,float)

    # For pure TIP3P, there is no LJ term for H so only properties for O are taken into account
    rho = rho_OW
    for ilam, lamv in enumerate(lam_range):
        # Input for the predefined function to calculate dHdL, ddHdL2, dHdLsq and g(r)
        DGin = DGinall(optv,rho,epsuv,siguv,lamv,rij_range,ilam)
        dHdL_lam, dHdLsq_lam, dHdL[ilam], dHdLsq[ilam], ddHdL2[ilam], gr_dg = getdgfunc(*DGin)

    # Constants required for forward, central and backward finite difference derivatives
    dlamv = lam_range[1] - lam_range[0]
    dlamvx2 = 2*dlamv
    dlamvx12 = 12*dlamv
    FDIL = 2 # Index up to and including which forward difference is used
    BDIL = nlam - 3 # Index from which backward difference is used
    for ilam in range(nlam):
        if ilam <= FDIL:
            # Three-point forward difference
            ddLdHdL[ilam] = (-dHdL[ilam+2]+4*dHdL[ilam+1]-3*dHdL[ilam])/dlamvx2
        elif ilam >= BDIL:
            # Three-point backward difference
            ddLdHdL[ilam] = (dHdL[ilam-2]-4*dHdL[ilam-1]+3*dHdL[ilam])/dlamvx2
        else:
            # Five-point central difference
            ddLdHdL[ilam] = (dHdL[ilam-2]-8*dHdL[ilam-1]+8*dHdL[ilam+1]-dHdL[ilam+2])/dlamvx12

    # Variance of the dHdL
    var_dHdL = dHdLsq
    var_dHdL_new = (ddHdL2 - ddLdHdL)/betain

    # Variance of the free energy
    var_deltaF = simps(dHdLsq,lam_range)
    var_deltaF_new = (simps(ddHdL2,lam_range) - dHdL[-1] + dHdL[0])/ betain
    print('var(deltaF) = ' + repr(var_deltaF))
    print('C, alpha =  %s ;   var(deltaF)_new =  %s' % (optv, var_deltaF_new))

    return dHdL_lam, dHdLsq_lam, dHdL, ddHdL2, ddLdHdL, dHdLsq, var_dHdL, var_dHdL_new, var_deltaF, var_deltaF_new

#===================================================================================================
# Initialize array for all variables 
#===================================================================================================
# Zero-filled result buffers, one slot per lambda value, for the two-body and
# three-body approximations. Every name is bound to its own array.
dHdL_dg = numpy.zeros(lambda_range.shape, dtype=float)
dHdLsq_dg = numpy.zeros(lambda_range.shape, dtype=float)
ddHdL2_dg = numpy.zeros(lambda_range.shape, dtype=float)
ddLdHdL_dg = numpy.zeros(lambda_range.shape, dtype=float)
OdHdL_dg = numpy.zeros(lambda_range.shape, dtype=float)
OdHdLsq_dg = numpy.zeros(lambda_range.shape, dtype=float)
OddHdL2_dg = numpy.zeros(lambda_range.shape, dtype=float)
OddLdHdL_dg = numpy.zeros(lambda_range.shape, dtype=float)
# The H counterparts are currently disabled:
#HdHdL_dg = numpy.zeros(len(lambda_range),float) 
#HdHdLsq_dg = numpy.zeros(len(dHdL_dg),float) 
#HddHdL2_dg = numpy.zeros(len(dHdL_dg),float) 
#HddLdHdL_dg = numpy.zeros(len(dHdL_dg),float) 

# Matching buffers for the xyz approximation (total, O and H contributions)
dHdL = numpy.zeros(lambda_range.shape, dtype=float)
dHdLsq = numpy.zeros(lambda_range.shape, dtype=float)
ddHdL2 = numpy.zeros(lambda_range.shape, dtype=float)
ddLdHdL = numpy.zeros(lambda_range.shape, dtype=float)
OdHdL = numpy.zeros(lambda_range.shape, dtype=float)
OdHdLsq = numpy.zeros(lambda_range.shape, dtype=float)
OddHdL2 = numpy.zeros(lambda_range.shape, dtype=float)
OddLdHdL = numpy.zeros(lambda_range.shape, dtype=float)
HdHdL = numpy.zeros(lambda_range.shape, dtype=float)
HdHdLsq = numpy.zeros(lambda_range.shape, dtype=float)
HddHdL2 = numpy.zeros(lambda_range.shape, dtype=float)
HddLdHdL = numpy.zeros(lambda_range.shape, dtype=float)

#===================================================================================================
# Start performing optimization and evaluation of the results here 
#===================================================================================================
# Box constraints for the optimization: first pair bounds C, second pair bounds alpha
bounds = ([1, 200],[1E-5,100])

# Minimize var(deltaF) over [C, alpha] with L-BFGS-B, using a numerically
# approximated gradient
print('start DG optimize ' + time.ctime())
DGoptv, var_deltaF_dg_opt, DG_dict = fmin_l_bfgs_b(getdgvariance, optv0, args=(epsuv,siguv,lambda_range,rij_dg), approx_grad=1, bounds=bounds, m=10, factr=1e07, pgtol=1e-05, epsilon=1e-05, maxfun=1000)
print('Optimal Parameters for ' + '%s' %(method))
print(DGoptv)
print(DG_dict)
print('end DG optimize ' + time.ctime())

# Evaluate optimal properties for dilute gas limit
DGopt = (DGoptv,epsuv,siguv,lambda_range,rij_dg)
dHdL_dg_lam, dHdLsq_dg_lam, dHdL_dg, ddHdL2_dg, ddLdHdL_dg, dHdLsq_dg, var_dHdL_dg, var_dHdL_dg_new, var_deltaF_dg, var_deltaF_dg_new = getdgoptimal(*DGopt)
#print 'dHdL_dg = ' + repr(dHdL_dg)
#print 'ddLdHdL_dg = ' + repr(ddLdHdL_dg)
#print 'dHdLsq_dg = ' + repr(dHdLsq_dg)
print('end DG evaluation ' + time.ctime())

# Write the two-body or three-body approximate results to a file: one header
# row with the optimized parameters, then one row per lambda value. The
# '+- error' columns are written as 0. The context manager closes the file
# even if a write fails.
dhdl_dg = 'Optimal-A%i-B%i-%i_stages.dg'%(Ain,Bin,len(lambda_range))
with open(dhdl_dg, 'w') as datafile:
    datafile.write("%16s %16s %16s %16s %16s %16s %16s %16s\n" %('epsilon','sigma','A','B','C_opt','alpha_opt','alpha1','var(deltaF)'))
    datafile.write("%16.8e %16.8e %16.8e %16.8e %16.8e %16.8e %16.8e %16.8e\n"  %(epsuv,siguv,Ain,Bin,DGoptv[0],DGoptv[1],alpha0,var_deltaF_dg_new))
    datafile.write("%16s %16s %16s %16s %16s %16s\n" %('lambda', '<dHdL>', '+- error', '<dHdLsq>', '+- error', 'var(dHdL)'))
    for k in range(len(lambda_range)):
        datafile.write("%16.8e %16.8e %16.8e %16.8e %16.8e %16.8e\n"  %(lambda_range[k], dHdL_dg[k], 0, dHdLsq_dg[k], 0, var_dHdL_dg_new[k]))
