#=============================================================================================
# MODULE DOCSTRING
#=============================================================================================

"""
GBVI_validation.py

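Test to compute GB/VI hydration free energies for a set of small molecules
and compare them with the experimental values.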

For each sdf entry in the sdf file, parse and construct the 3D coordinates
from the 2D topology information.  I believe the experimental hydration free
energy is also listed there.  Compute AM1-BCC charges.

Create an OpenMM system for each molecule, and add a GBVIForce term.  To the
force term, you should add all the atoms, with GBVI parameters corresponding
to types deduced by SMARTS.  (There are luckily only a few types, and they
are simple.).  Be sure to also explicitly tell the GBVIForce object about
the connectivity information.

Compute the single-point potential energy by constructing a Context with
some integrator and just getting the State, telling it you want the Energy.
 (Look at Randy's examples for how to do this.). This should correspond to
Paul's listed hydration energy.

"""
#=============================================================================================
# GLOBAL IMPORTS
#=============================================================================================

import os
import math
import numpy

import simtk
import simtk.chem.openmm as openmm
import simtk.unit as units

from openeye.oechem import *
from openeye.oequacpac import *
from openeye.oeszybki import *
from openeye.oeomega import *

# Load OpenMM plugins from the install's 'lib' directory and from its
# 'plugins' subdirectory.  OPENMM_INSTALL_DIR must be set in the environment.
print("Loading plugins...")
openmm_lib_dir = os.path.join(os.environ['OPENMM_INSTALL_DIR'], 'lib')
for plugin_dir in (openmm_lib_dir, os.path.join(openmm_lib_dir, 'plugins')):
    openmm.Platform.loadPluginsFromDirectory(plugin_dir)

#=============================================================================================
# GBVI parameter table
#=============================================================================================

# GB/VI parameter table, indexed as
#
#    GBVIParameters[charge model][atomic number][key] -> [radius, gamma]
#
# * For hydrogen, `key` is the atomic number of the atom the hydrogen is
#   bonded to (e.g. H bonded to O is [OEElemNo_H][OEElemNo_O]).
# * For every other element, `key` is 'a' or 'x'.  The lookup code further
#   down selects 'a' for atoms flagged aromatic or whose explicit valence
#   exceeds their explicit degree, and 'x' otherwise; elements that list
#   only 'x' always use that entry.
#
# The consuming code attaches angstroms to `radius` and kilojoules_per_mole
# to `gamma`.

GBVIParameters = {
    'AM1-BCC': {
        OEElemNo_H: {
            OEElemNo_N: [1.25, 1.0461],
            OEElemNo_O: [1.00, 2.2465],
            OEElemNo_P: [1.00, 1.4322],
            OEElemNo_S: [1.25, 0.8739],
            OEElemNo_C: [1.25, 0.2437],
        },
        OEElemNo_C: {
            'a': [2.00, -0.1199],
            'x': [1.80, -0.2863],
        },
        OEElemNo_N: {
            'a': [1.65, -0.6263],
            'x': [1.65, -4.0443],
        },
        OEElemNo_O: {
            'a': [1.40, 5.0707],
            'x': [1.35, 0.6859],
        },
        OEElemNo_P: {
            'x': [2.15, -4.1181],
        },
        OEElemNo_S: {
            'x': [1.95, -1.0215],
            'a': [1.95, -1.0215],
        },
        OEElemNo_F: {
            'x': [1.50, 1.9309],
        },
        OEElemNo_Cl: {
            'x': [1.80, 0.0131],
        },
        OEElemNo_Br: {
            'x': [2.40, -1.1648],
        },
        OEElemNo_I: {
            'x': [2.60, 0.0869],
        },
    },
    'MMFF94': {
        OEElemNo_H: {
            OEElemNo_N: [1.50, -1.1686],
            OEElemNo_O: [1.15, -1.9987],
            OEElemNo_P: [1.25, -7.7321],
            OEElemNo_S: [1.50, -1.1690],
            OEElemNo_C: [1.50, 0.1237],
        },
        OEElemNo_C: {
            'a': [2.30, -0.2048],
            'x': [2.15, -1.1087],
        },
        OEElemNo_N: {
            'a': [2.20, -2.0826],
            'x': [2.20, -5.4887],
        },
        OEElemNo_O: {
            'a': [1.56, 3.9466],
            'x': [1.85, -1.6709],
        },
        OEElemNo_P: {
            'x': [1.91, -6.1361],
        },
        OEElemNo_S: {
            'x': [2.10, -0.3279],
            'a': [2.10, -0.3279],
        },
        OEElemNo_Cl: {
            'x': [2.10, 0.8911],
        },
        OEElemNo_F: {
            'x': [1.50, 2.3697],
        },
        OEElemNo_Br: {
            'x': [2.30, -0.5673],
        },
        OEElemNo_I: {
            'x': [2.30, -4.4705],
        },
    },
}

#=============================================================================================
# GB/VI comparison
#=============================================================================================

# Charge model to validate; must be a key of GBVIParameters.
#mode = 'AM1-BCC'
mode = 'MMFF94'

# Input: stream over the parameterized molecule database.
ifs = oemolistream('labute_3D_charged.oeb')

# Output: one summary line per molecule (name, experimental, calculated).
outfile = open('gbvi.out', 'w')

# Output: per-molecule particle and bond parameters for the chosen mode.
molecules_outfile = open('molecules.%s.out' % mode, 'w')

# Experimental and calculated hydration free energies, keyed by molecule name.
experiment = {}
calculated = {}

# Iterate over molecules.  For each one: assign charges, build an OpenMM
# System containing a single GBVIForce, dump the parameters that were used,
# evaluate the single-point potential energy, and report it next to the
# experimental value.
for molecule in ifs.GetOEGraphMols():
    # Get molecule name.
    name = molecule.GetTitle()

    # Get experimental free energy (stored on the molecule under the 'dg' tag).
    dg_exp = molecule.GetFloatData('dg') * units.kilocalories_per_mole

    # Assign charges according to the selected mode.
    noHydrogen = False  # use explicit hydrogens
    if mode == 'AM1-BCC':
        # Assign AM1-BCC charges.
        OEAssignPartialCharges(molecule, OECharges_AM1BCC, noHydrogen)
    elif mode == 'MMFF94':
        # Assign MMFF94 charges.
        OEAssignFormalCharges(molecule)
        OEAssignPartialCharges(molecule, OECharges_MMFF94, noHydrogen)

        # Check that formal and partial charges sum up to the same amount.
        total_formal_charge = 0.0
        total_partial_charge = 0.0
        for atom in molecule.GetAtoms():
            total_formal_charge += atom.GetFormalCharge()
            total_partial_charge += atom.GetPartialCharge()
        if abs(total_formal_charge - total_partial_charge) > 0.01:
            # Totals disagree: report both and let OEFormalPartialCharges
            # reset the partial charges.
            print("total formal charge: %f" % total_formal_charge)
            print("total partial charge: %f" % total_partial_charge)
            OEFormalPartialCharges(molecule)
    else:
        # Interpolate the offending mode into the message.
        raise Exception("mode '%s' not recognized." % mode)

    # Assign aromaticity.
    OEAssignAromaticFlags(molecule)

    # Create OpenMM System with one particle per atom.
    system = openmm.System()
    for atom in molecule.GetAtoms():
        mass = OEGetDefaultMass(atom.GetAtomicNum())
        system.addParticle(mass * units.amu)

    # Add nonbonded term (currently disabled).
#    nonbonded_force = openmm.NonbondedSoftcoreForce()
#    nonbonded_force.setNonbondedMethod(openmm.NonbondedForce.NoCutoff)
#    for atom in molecule.GetAtoms():
#        charge = 0.0 * units.elementary_charge
#        sigma = 1.0 * units.angstrom
#        epsilon = 0.0 * units.kilocalories_per_mole
#        nonbonded_force.addParticle(charge, sigma, epsilon)
#    system.addForce(nonbonded_force)

    # Add GBVI term.
    #gbvi_force = openmm.GBVISoftcoreForce()
    gbvi_force = openmm.GBVIForce()
    gbvi_force.setNonbondedMethod(openmm.GBVIForce.NoCutoff)  # set no cutoff
    gbvi_force.setSoluteDielectric(1.0)
    gbvi_force.setSolventDielectric(78)

    # Use scaling method (GBVISoftcoreForce only; currently disabled).
    #gbvi_force.setBornRadiusScalingMethod(openmm.GBVISoftcoreForce.QuinticSpline)
    #gbvi_force.setQuinticLowerLimitFactor(0.75)
    #gbvi_force.setQuinticUpperBornRadiusLimit(50.0*units.nanometers)

    # Create list of OpenEye atoms in molecule.  Bond begin/end indices are
    # used below to index into this list.
    atoms = [atom for atom in molecule.GetAtoms()]

    # Add atomic parameters.
    for atom in atoms:
        charge = atom.GetPartialCharge() * units.elementary_charge

        # Look up the GB/VI radius and gamma for this atom in the table for
        # the selected charge model.
        # NOTE: the parameter source is cited in this file only as
        # "Table 1 from DOI 10.1002/jcc" (the DOI is incomplete).
        element_parameters = GBVIParameters[mode][atom.GetAtomicNum()]
        if atom.IsHydrogen():
            # Atom is hydrogen; parameters depend on what it is bonded to.
            neighbors = [neighbor for neighbor in atom.GetAtoms()]
            if not neighbors:
                raise Exception("Hydrogen atom in molecule '%s' has no bonded neighbor; cannot assign GB/VI parameters." % name)
            radius, gamma = element_parameters[neighbors[0].GetAtomicNum()]
        else:
            # Use the 'a' entry for aromatic atoms and for atoms whose
            # explicit valence exceeds their explicit degree; otherwise 'x'.
            aromatic = 'x'
            if atom.IsAromatic() or atom.GetExplicitValence() > atom.GetExplicitDegree():
                aromatic = 'a'
            # Elements with no 'a' entry fall back to their 'x' entry.
            if aromatic not in element_parameters:
                aromatic = 'x'
            radius, gamma = element_parameters[aromatic]

        # Assign units for radius and gamma.
        # NOTE(review): gamma is tagged as kJ/mol here, while a disabled debug
        # print in the original labelled it kcal/mol -- confirm the units of
        # the tabulated values against the parameter source.
        radius = units.Quantity(radius, units.angstroms)
        gamma = units.Quantity(gamma, units.kilojoules_per_mole)

        # For GBVISoftcoreForce, pass an additional lambda_ (1.0 = fully
        # interacting): gbvi_force.addParticle(charge, radius, gamma, lambda_)
        gbvi_force.addParticle(charge, radius, gamma)

    # Add bonds.
    for bond in molecule.GetBonds():
        # Get atom indices.
        iatom = bond.GetBgnIdx()
        jatom = bond.GetEndIdx()
        # Get bond length from the current 3D coordinates.
        (xi, yi, zi) = molecule.GetCoords(atoms[iatom])
        (xj, yj, zj) = molecule.GetCoords(atoms[jatom])
        distance = math.sqrt((xi-xj)**2 + (yi-yj)**2 + (zi-zj)**2) * units.angstroms
        # Diagnostic output for molecules whose name contains 'amine'.
        if 'amine' in name:
            print(distance)
        # Identify bonded atoms to GBVI.
        gbvi_force.addBond(iatom, jatom, distance)

    # Add the force to the system.
    system.addForce(gbvi_force)

    # Build coordinate array.
    natoms = len(atoms)
    coordinates = units.Quantity(numpy.zeros([natoms, 3]), units.angstroms)
    for (index, atom) in enumerate(atoms):
        (x, y, z) = molecule.GetCoords(atom)
        coordinates[index,:] = units.Quantity(numpy.array([x, y, z]), units.angstroms)

    # Write molecule parameters: name, particle count, one line per particle,
    # bond count, one line per bond, then a blank separator line.
    molecules_outfile.write('%s\n' % name)
    molecules_outfile.write('%d\n' % system.getNumParticles())
    for index in range(gbvi_force.getNumParticles()):
        mass = system.getParticleMass(index)
        #[charge, radius, gamma, lambda_] = gbvi_force.getParticleParameters(index) # GBVISoftcoreForce
        [charge, radius, gamma] = gbvi_force.getParticleParameters(index)
        molecules_outfile.write(
            '%5d %8.3f %8.5f %8.5f %12.6f %10.5f %10.5f %10.5f\n' % (
                index,
                mass / units.amu,
                charge / units.elementary_charge,
                radius / units.nanometers,
                gamma / units.kilojoules_per_mole,
                coordinates[index,0] / units.nanometers,
                coordinates[index,1] / units.nanometers,
                coordinates[index,2] / units.nanometers))
    molecules_outfile.write('%d\n' % gbvi_force.getNumBonds())
    for index in range(gbvi_force.getNumBonds()):
        [iatom, jatom, distance] = gbvi_force.getBondParameters(index)
        molecules_outfile.write('%5d %5d %5d %10.5f\n' % (index, iatom, jatom, distance / units.nanometers))
    molecules_outfile.write('\n')

    # Create OpenMM Context.  The Reference platform is passed explicitly so
    # that the platform looked up here is the one actually used.
    platform = openmm.Platform.getPlatformByName("Reference")
    timestep = 1.0 * units.femtosecond  # arbitrary; no dynamics are run
    integrator = openmm.VerletIntegrator(timestep)
    context = openmm.Context(system, integrator, platform)

    # Set the coordinates.
    context.setPositions(coordinates)

    # Get the single-point potential energy.
    state = context.getState(getEnergy=True)
    dg_gbvi = state.getPotentialEnergy()

    if math.isnan(dg_gbvi / units.kilocalories_per_mole):
        raise Exception("GBVI energy is nan.")

    # Store energies.
    experiment[name] = dg_exp
    calculated[name] = dg_gbvi

    # Clean up per-molecule OpenMM objects.
    del system, context, integrator, coordinates

    # Report name, experimental and calculated values (kcal/mol).
    outstring = "%48s %8.3f %8.3f" % (name, dg_exp / units.kilocalories_per_mole, dg_gbvi / units.kilocalories_per_mole)
    print(outstring)
    outfile.write(outstring + '\n')

# Close the input stream and both output files so buffered data is flushed.
ifs.close()
outfile.close()
molecules_outfile.close()
    
