#=============================================================================================
# MODULE DOCSTRING
#=============================================================================================

"""
GBVI_validation.py

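Test to compute GB/VI hydration free energies for a set of small molecules
and compare them with the experimental values.

For each SDF entry in the SDF file, parse the molecule and read its 3D
coordinates.  The experimental hydration free energy is also listed there
(SD data field 'dG(exp)').  Compute AM1-BCC charges (or MMFF94 charges,
depending on the selected mode).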

Create an OpenMM system for each molecule, and add a GBVIForce term.  To the
force term, you should add all the atoms, with GBVI parameters corresponding
to types deduced by SMARTS.  (There are luckily only a few types, and they
are simple.).  Be sure to also explicitly tell the GBVIForce object about
the connectivity information.

Compute the single-point potential energy by constructing a Context with
some integrator and just getting the State, telling it you want the Energy.
 (Look at Randy's examples for how to do this.). This should correspond to
Paul's listed hydration energy.

"""
#=============================================================================================
# GLOBAL IMPORTS
#=============================================================================================

import os
import math
import numpy

import simtk
import simtk.chem.openmm as openmm
import simtk.unit as units

from openeye.oechem import *
from openeye.oequacpac import *
from openeye.oeszybki import *

# Load plugins.
print "Loading plugins..."
openmm.Platform.loadPluginsFromDirectory(os.path.join(os.environ['OPENMM_INSTALL_DIR'], 'lib'))
openmm.Platform.loadPluginsFromDirectory(os.path.join(os.environ['OPENMM_INSTALL_DIR'], 'lib', 'plugins'))

#=============================================================================================
# GBVI Parameter class
#=============================================================================================

# GB/VI parameter table, keyed first by charge model ('AM1-BCC' or 'MMFF94') and then
# by atomic number (OEElemNo_* constants).  Every leaf is a [radius, gamma] pair of
# plain floats; the main loop attaches units of Angstroms to radius and kcal/mol to gamma.
#
# Hydrogen entries are keyed by the atomic number of the atom the hydrogen is bonded to
# (i.e. H bonded to O is GBVIParameters[mode][OEElemNo_H][OEElemNo_O] -> [radius, gamma]).
#
# Heavy-atom entries are keyed by an atom class:
#   'a' - used for aromatic atoms and atoms whose explicit valence exceeds their explicit
#         degree (i.e. with at least one bond of order > 1);
#   'x' - used otherwise, and as the fallback when an element has no 'a' entry
#         (e.g. P and the halogens below).

GBVIParameters = dict()
# AM1-BCC charge model parameters.
GBVIParameters['AM1-BCC'] = {OEElemNo_H:{OEElemNo_N:[1.25, 1.0461],\
                                        OEElemNo_O:[1.00, 2.2465],\
                                        OEElemNo_P:[1.00, 1.4322],\
                                        OEElemNo_S:[1.25, 0.8739],\
                                        OEElemNo_C:[1.25, 0.2437]},\
                             OEElemNo_C:{'a':[2.00, -0.1199],\
                                         'x':[1.80, -0.2863]},\
                             OEElemNo_N:{'a':[1.65, -0.6263],\
                                         'x':[1.65, -4.0443]},\
                             OEElemNo_O:{'a':[1.40, 5.0707],\
                                         'x':[1.35, 0.6859]},\
                             OEElemNo_P:{'x':[2.15, -4.1181]},\
                             OEElemNo_S:{'x':[1.95, -1.0215],\
                                         'a':[1.95, -1.0215]},\
                             OEElemNo_F:{'x':[1.50, 1.9309]},\
                             OEElemNo_Cl:{'x':[1.80, 0.0131]},\
                             OEElemNo_Br:{'x':[2.40, -1.1648]},\
                             OEElemNo_I:{'x':[2.60, 0.0869]}}

# MMFF94 charge model parameters (same layout as the AM1-BCC table).
GBVIParameters['MMFF94'] = {OEElemNo_H:{OEElemNo_N:[1.50, -1.1686 ],\
                                         OEElemNo_O:[1.15, -1.9987 ],\
                                         OEElemNo_P:[1.25, -7.7321],\
                                         OEElemNo_S:[1.50, -1.1690],\
                                         OEElemNo_C:[1.50, 0.1237 ]},\
                             OEElemNo_C:{'a':[2.30, -0.2048 ],\
                                         'x':[2.15, -1.1087]},\
                             OEElemNo_N:{'a':[2.20, -2.0826],\
                                         'x':[2.20, -5.4887 ]},\
                             OEElemNo_O:{'a':[1.56, 3.9466],\
                                         'x':[1.85, -1.6709]},\
                             OEElemNo_P:{'x':[1.91, -6.1361 ]},\
                             OEElemNo_S:{'x':[2.10, -0.3279],\
                                         'a':[2.10, -0.3279]},\
                             OEElemNo_Cl:{'x':[2.10, 0.8911]},\
                             OEElemNo_F:{'x':[1.50, 2.3697]},\
                             OEElemNo_Br:{'x':[2.30, -0.5673]},\
                             OEElemNo_I:{'x':[2.30, -4.4705]}}

#=============================================================================================
# GB/VI comparison
#=============================================================================================

# Charge model / GB/VI parameter set to compare: 'AM1-BCC' or 'MMFF94'.
mode = 'AM1-BCC'
#mode = 'MMFF94'

# Input: SDF database with molecules and their experimental / MOE solvation energies.
ifs = oemolistream('solvation.sdf')

# Output: one summary line per molecule.
outfile = open('gbvi.out', 'w')

# Output: per-molecule particle and bond parameters for the selected mode.
molecules_outfile = open(''.join(['molecules.', mode, '.out']), 'w')

# Experimental and calculated hydration free energies, keyed by molecule name.
experiment, calculated = {}, {}

# Iterate over molecules
for molecule in ifs.GetOEGraphMols():
   # DEBUG
   if molecule.NumAtoms() != 1:
      continue

   # Get molecule name
   name = molecule.GetTitle()

   # Get metadata.
   name = OEGetSDData(molecule, 'name').strip()
   idx = int(OEGetSDData(molecule, 'idx'))
   dg_exp = float(OEGetSDData(molecule, 'dG(exp)')) * units.kilocalories_per_mole
   dg_moe = float(OEGetSDData(molecule, 'E_sol')) * units.kilocalories_per_mole   

   # Assign charges.
   if (mode == 'AM1-BCC'):
      # Assign AM1-BCC charges.
      if molecule.NumAtoms() == 1:
         # Use formal charges for ions.
         OEFormalPartialCharges(molecule)         
      else:
         # Assign AM1-BCC charges for multiatom molecules.
         OEAssignPartialCharges(molecule, OECharges_AM1BCC, False) # use explicit hydrogens      
      
   elif (mode == 'MMFF94'):
      # Assign MMFF94 charges.
      noHydrogen = False # use explicit hydrogens
      OEAssignFormalCharges(molecule)
      OEAssignPartialCharges(molecule, OECharges_MMFF94, noHydrogen)

      # Check that formal and partial charges sum up to the same amount.
      total_formal_charge = 0.0
      total_partial_charge = 0.0
      for atom in molecule.GetAtoms():
         total_formal_charge += atom.GetFormalCharge()
         total_partial_charge += atom.GetPartialCharge()
      if abs(total_formal_charge - total_partial_charge) > 0.01:
         print "total formal charge: %f" % total_formal_charge
         print "total partial charge: %f" % total_partial_charge         
         OEFormalPartialCharges(molecule)
      
   else:
      raise Exception("mode '%s' not recognized.")

   # Assign atomaticity.
   OEAssignAromaticFlags(molecule)         
   
   # Create OpenMM System.
   system = openmm.System()
   for atom in molecule.GetAtoms():
      mass = OEGetDefaultMass(atom.GetAtomicNum())
      system.addParticle(mass * units.amu)

   # Add nonbonded term.
#   nonbonded_force = openmm.NonbondedSoftcoreForce()
#   nonbonded_force.setNonbondedMethod(openmm.NonbondedForce.NoCutoff)
#   for atom in molecule.GetAtoms():
#      charge = 0.0 * units.elementary_charge
#      sigma = 1.0 * units.angstrom
#      epsilon = 0.0 * units.kilocalories_per_mole
#      nonbonded_force.addParticle(charge, sigma, epsilon)
#   system.addForce(nonbonded_force)

   # Add GBVI term
   gbvi_force = openmm.GBVISoftcoreForce()
   #gbvi_force = openmm.GBVIForce()   
   gbvi_force.setNonbondedMethod(openmm.GBVIForce.NoCutoff) # set no cutoff
   gbvi_force.setSoluteDielectric(1)
   gbvi_force.setSolventDielectric(78)

   # Use scaling method.
   gbvi_force.setBornRadiusScalingMethod(openmm.GBVISoftcoreForce.QuinticSpline)
   gbvi_force.setQuinticLowerLimitFactor(0.75)
   gbvi_force.setQuinticUpperBornRadiusLimit(50.0*units.nanometers)
   
   # Create list of OpenEye atoms in molecule.
   atoms = [atom for atom in molecule.GetAtoms()]

   # Add atomic parameters
   for atom in atoms:      
      charge = atom.GetPartialCharge() * units.elementary_charge
   
      # Assign GB/VI parameters based on AM1-BCC parameters of Table 1 from DOI 10.1002/jcc
      radius = None
      gamma = None      
      if atom.IsHydrogen():
         # Atom is hydrogen; determine what it is bonded to.
         neighbors = [ neighbor for neighbor in atom.GetAtoms()] 
         neighbor_type = neighbors[0].GetType()
         [radius, gamma] = GBVIParameters[mode][atom.GetAtomicNum()][neighbors[0].GetAtomicNum()]         
      else:
         aromatic = 'x'
         if atom.IsAromatic(): aromatic = 'a'
         if atom.GetExplicitValence() > atom.GetExplicitDegree():
            aromatic = 'a'

         try:
            [radius, gamma] = GBVIParameters[mode][atom.GetAtomicNum()][aromatic]
         except:
            [radius, gamma] = GBVIParameters[mode][atom.GetAtomicNum()]['x']

      # Assign units for radius and gamma.
      radius = units.Quantity(radius, units.angstroms)
      gamma = units.Quantity(gamma, units.kilocalories_per_mole)

      charge *= 0.0 # DEBUG
      #gamma *= -1.0 # DEBUG

      #if name.find('amine') > -1:
      print "name %5s type %5s charge %8.3f radius %8.3f A : gamma %8.3f kcal/mol" % (atom.GetName(), atom.GetType(), charge / units.elementary_charge, radius / units.angstroms, gamma / units.kilocalories_per_mole)
      
      lambda_ = 1.0 # fully interacting
      gbvi_force.addParticle(charge, radius, gamma, lambda_) # for GBVISoftcoreForce
      # gbvi_force.addParticle(charge, radius, gamma) # for GBVIForce

   # Add bonds.
   for bond in molecule.GetBonds():
      # Get atom indices.
      iatom = bond.GetBgnIdx()
      jatom = bond.GetEndIdx()
      # Get bond length.
      (xi, yi, zi) = molecule.GetCoords(atoms[iatom])
      (xj, yj, zj) = molecule.GetCoords(atoms[jatom])
      distance = math.sqrt((xi-xj)**2 + (yi-yj)**2 + (zi-zj)**2) * units.angstroms
      # Identify bonded atoms to GBVI.
      gbvi_force.addBond(iatom, jatom, distance)

   # Add the force to the system.
   system.addForce(gbvi_force)

   # Build coordinate array.
   natoms = len(atoms)
   coordinates = units.Quantity(numpy.zeros([natoms, 3]), units.angstroms)
   for (index,atom) in enumerate(atoms):
      (x,y,z) = molecule.GetCoords(atom)
      coordinates[index,:] = units.Quantity(numpy.array([x,y,z]),units.angstroms)

   # Write molecule parameters.
   molecules_outfile.write('%s %8.3f %8.3f\n' % (name, dg_exp / units.kilojoules_per_mole, dg_moe / units.kilojoules_per_mole))
   molecules_outfile.write('%d\n' % system.getNumParticles())   
   for index in range(gbvi_force.getNumParticles()):
      mass = system.getParticleMass(index)
      [charge, radius, gamma, lambda_] = gbvi_force.getParticleParameters(index) # GBVISoftcoreForce
      #[charge, radius, gamma] = gbvi_force.getParticleParameters(index) # GBVIForce
      molecules_outfile.write('%5d %8.3f %8.5f %8.5f %12.6f %10.5f %10.5f %10.5f\n' % (index, mass / units.amu, charge / units.elementary_charge, radius / units.nanometers, gamma / units.kilojoules_per_mole, coordinates[index,0] / units.nanometers, coordinates[index,1] / units.nanometers, coordinates[index,2] / units.nanometers))
   molecules_outfile.write('%d\n' % gbvi_force.getNumBonds())
   for index in range(gbvi_force.getNumBonds()):
      [iatom, jatom, distance] = gbvi_force.getBondParameters(index)
      molecules_outfile.write('%5d %5d %5d %10.5f\n' % (index, iatom, jatom, distance / units.nanometers))
   molecules_outfile.write('\n')
   

   # Create OpenMM Context.
   platform = openmm.Platform.getPlatformByName("Reference")
   timestep = 1.0 * units.femtosecond # arbitrary
   integrator = openmm.VerletIntegrator(timestep)
   context = openmm.Context(system, integrator, platform)

   # Set the coordinates.
   context.setPositions(coordinates)

   # Get the energy
   state = context.getState(getEnergy=True)
   dg_gbvi = state.getPotentialEnergy()

   if math.isnan(dg_gbvi / units.kilocalories_per_mole):
      raise Exception("GBVI energy is nan.")

   # Store energies.
   experiment[name] = dg_exp
   calculated[name] = dg_gbvi
   
   # Clean up
   del system, context, integrator, coordinates

   outstring = "%48s %8.3f %8.3f %8.3f" % (name, dg_exp / units.kilocalories_per_mole, dg_moe / units.kilocalories_per_mole, dg_gbvi / units.kilocalories_per_mole)
   print outstring
   outfile.write(outstring + '\n')

# Close the input stream and both output files (molecules_outfile was previously left
# open, leaking the handle and risking unflushed per-molecule parameter output).
ifs.close()
outfile.close()
molecules_outfile.close()
    
