# Normalize molecule with OEChem.
#
# TODO:
# * Use mmtools.moltools.ligandtools equivalents of these functions

import openeye.oechem
import openeye.oeiupac

import tempfile
import commands
from math import *
from openeye.oechem import *
from openeye.oeomega import *
from openeye.oeiupac import *
from openeye.oeshape import *
try:
   from openeye.oequacpac import * #DLM added 2/25/09 for OETyperMolFunction; replacing oeproton
except:
   from openeye.oeproton import * #GJR temporary fix because of old version of openeye tools
from openeye.oeiupac import *
from openeye.oeszybki import *
import os
import re
import shutil

# Script configuration, consumed by the main body at the bottom of this file.
# Path of the input structure handed to oemolistream / OEReadMolecule.
input_filename = 'bosutinib-2.pdb'
# Path the charged molecule is written to with OEWriteMolecule.
output_filename = 'bosutinib-2-oemol.mol2'
# Desired net formal charge; None skips the protonation-state selection step.
charge = None

#=============================================================================================
def expandConformations(molecule, maxconfs = None, threshold = None, include_original = False, torsionlib = None, verbose = False, strictTyping = None):
   """Generate a conformer ensemble for a molecule with OpenEye's Omega.

   Omega is run on a copy, so the molecule passed in is not modified.

   ARGUMENTS
     molecule (OEMol) - molecule to enumerate conformations for

   OPTIONAL ARGUMENTS
     maxconfs (integer) - if given and nonzero, forwarded to Omega's SetMaxConfs (default: None)
     threshold (real) - if given and nonzero, RMSD threshold forwarded to SetRMSThreshold (default: None)
     include_original (boolean) - forwarded to SetIncludeInput; if True, the input conformation is retained (default: False)
     torsionlib (string) - if given, forwarded to SetTorsionLibrary (default: None)
     verbose (boolean) - if True, the number of conformations produced is printed (default: False)
     strictTyping (boolean) - if not None, forwarded to SetStrictAtomTypes (default: None)

   RETURN VALUES
     expanded_molecule (OEMol) - copy of the input molecule after Omega has been applied to it

   EXAMPLES
     # create a new molecule with Omega-expanded conformations
     expanded_molecule = expandConformations(molecule)
   """
   conformer_generator = OEOmega()

   # Only override Omega's defaults for the options the caller actually supplied.
   if strictTyping is not None:
      conformer_generator.SetStrictAtomTypes(strictTyping)
   if maxconfs:
      conformer_generator.SetMaxConfs(maxconfs)

   # The include-input flag is always set explicitly.
   conformer_generator.SetIncludeInput(include_original)

   if threshold:
      conformer_generator.SetRMSThreshold(threshold)
   if torsionlib:
      conformer_generator.SetTorsionLibrary(torsionlib)

   # NOTE: omega.SetVerbose(verbose) used to be called here; per DLM (2/27/09)
   # it appears to be obsolete in current OEOmega, so `verbose` only controls
   # the summary line below.

   # Work on a copy so the caller's molecule keeps its original conformers.
   result = OEMol(molecule)
   conformer_generator(result)

   if verbose:
      print("%d conformation(s) produced." % result.NumConfs())

   return result

#=============================================================================================
def enumerateStates(molecules, enumerate = "protonation", consider_aromaticity = True, maxstates = 200, verbose = True):
    """Enumerate protonation or tautomer states for one or more molecules.

    ARGUMENTS
      molecules (OEMol or list/tuple of OEMol) - molecules for which states are to be enumerated;
        a single molecule is promoted to a one-element list

    OPTIONAL ARGUMENTS
      enumerate (string) - type of states to expand -- 'protonation' or 'tautomer' (default: 'protonation')
      consider_aromaticity (boolean) - forwarded to the OETyperMolFunction / OETautomerMolFunction functor (default: True)
      maxstates (integer) - forwarded to the functor as its state limit (default: 200)
      verbose (boolean) - if True, progress messages are printed and the flag is forwarded to the
        OpenEye enumeration call (default: True)

    RETURNS
      states (list of OEMol) - molecules in different protonation or tautomeric states
        (empty if nothing was enumerated)

    RAISES
      ValueError - if 'enumerate' is neither 'protonation' nor 'tautomer'

    TODO
      Apply some regularization to molecule before enumerating states?
      Pick the most likely state?
      Add more optional arguments to control behavior.
    """

    # Check input arguments before doing any work.  (The parameter name shadows
    # the builtin enumerate(), so the builtin must not be used in this function.)
    if enumerate not in ("protonation", "tautomer"):
        raise ValueError("'enumerate' argument must be either 'protonation' or 'tautomer' -- instead got '%s'" % (enumerate,))

    # If 'molecules' is not a sequence of molecules, promote it to a list.
    if not isinstance(molecules, (list, tuple)):
        molecules = [molecules]

    # Create an internal (in-memory) SDF output stream to expand states into.
    ostream = oemolostream()
    ostream.openstring()
    ostream.SetFormat(OEFormat_SDF)

    # Enumerate states for each molecule in the input list.
    states_enumerated = 0
    for molecule in molecules:
        if verbose: print("Enumerating states for molecule %s." % molecule.GetTitle())

        # Dump enumerated states to output stream (ostream).
        if enumerate == "protonation":
            # Create a functor associated with the output stream.
            functor = OETyperMolFunction(ostream, consider_aromaticity, False, maxstates)
            if verbose: print("Enumerating protonation states...")
            states_enumerated += OEEnumerateFormalCharges(molecule, functor, verbose)
        else:
            # Create a functor associated with the output stream.
            functor = OETautomerMolFunction(ostream, consider_aromaticity, False, maxstates)
            if verbose: print("Enumerating tautomer states...")
            states_enumerated += OEEnumerateTautomers(molecule, functor, verbose)
    if verbose: print("Enumerated a total of %d states." % states_enumerated)

    # Collect molecules from the output stream into a Python list.
    states = list()
    if states_enumerated > 0:
        state = OEMol()
        istream = oemolistream()
        istream.openstring(ostream.GetString())
        istream.SetFormat(OEFormat_SDF)
        while OEReadMolecule(istream, state):
            # 'state' is reused as the read buffer, so store a copy.
            states.append(OEMol(state))

    # Return the list of expanded states as a Python list of OEMol() molecules.
    return states

def assignPartialCharges(molecule, charge_model = 'am1bcc', multiconformer = False, minimize_contacts = False, verbose = False, maxconfs = 10):
   """Assign conformer-averaged partial charges to a molecule using the OpenEye toolkits.

   ARGUMENTS
     molecule (OEMol) - molecule for which charges are to be assigned.  If any atom has an
       empty name, OETriposAtomNames is applied to this molecule IN PLACE, because atom
       names are used as keys when accumulating charges.

   OPTIONAL ARGUMENTS
     charge_model (string) - partial charge model, one of ['am1bcc'] (default: 'am1bcc')
     multiconformer (boolean) - if True, multiple conformations are enumerated and the resulting charges averaged (default: False)
     minimize_contacts (boolean) - if True, each conformation is minimized with OESzybki using the absolute values of the charges, then recharged (default: False)
     verbose (boolean) - if True, information about the current calculation is printed (default: False)
     maxconfs (int) - maximum number of conformations passed to expandConformations when multiconformer is True (default: 10)

   RETURNS
     charged_molecule (OEMol) - copy of the input molecule carrying the averaged partial charges

   RAISES
     ValueError - if charge_model is not supported

   NOTES
     multiconformer and minimize_contacts can be combined, but this can be slow.
     Charges are accumulated per atom NAME, so atom names are assumed to be unique
     within the molecule -- TODO confirm for inputs that arrive already named.

   EXAMPLES
     # create a molecule
     molecule = createMoleculeFromIUPAC('phenol')
     # assign am1bcc charges
     assignPartialCharges(molecule, charge_model = 'am1bcc')

   TODO
     * Keep track of partial charge std over conformers to monitor how variable charges are.

   """

   # Check input parameters first, so a bad request fails before the molecule is modified.
   supported_charge_models = ['am1bcc']
   if charge_model not in supported_charge_models:
      raise ValueError("Charge model %s not in supported set of %s" % (charge_model, supported_charge_models))

   # Atom names key the charge accumulator below; assign them if any are missing.
   if any(atom.GetName() == '' for atom in molecule.GetAtoms()):
      if verbose: print("Assigning TRIPOS names to atoms")
      OETriposAtomNames(molecule)

   # Expand conformations if desired.
   if multiconformer:
      expanded_molecule = expandConformations(molecule, maxconfs=maxconfs)
   else:
      expanded_molecule = OEMol(molecule)
   nconformers = expanded_molecule.NumConfs()
   if verbose: print('assignPartialCharges: %d conformations will be used in charge determination.' % nconformers)

   # Set up storage for the per-atom running sum of partial charges.
   partial_charges = dict((atom.GetName(), 0.0) for atom in molecule.GetAtoms())

   # Assign partial charges for each conformation.
   for conformer_index, conformation in enumerate(expanded_molecule.GetConfs()):
      if verbose and multiconformer: print("assignPartialCharges: conformer %d / %d" % (conformer_index + 1, nconformers))

      # Assign partial charges to a single-conformer copy of the molecule.
      if verbose: print("assignPartialCharges: determining partial charges...")
      charged_molecule = OEMol(conformation)
      if charge_model == 'am1bcc':
         OEAssignPartialCharges(charged_molecule, OECharges_AM1BCC)

      # Minimize with positive charges to splay out fragments, if desired.
      if minimize_contacts:
         if verbose: print("assignPartialCharges: Minimizing conformation with MMFF and absolute value charges...")
         # Set partial charges to absolute value.
         for atom in charged_molecule.GetAtoms():
            atom.SetPartialCharge(abs(atom.GetPartialCharge()))
         # Minimize in Cartesian space to splay out substructures.
         szybki = OESzybki() # create an instance of OESzybki
         szybki.SetRunType(OERunType_CartesiansOpt) # set minimization
         szybki.SetUseCurrentCharges(True) # use charges for minimization
         results = szybki(charged_molecule)
         for result in results: result.Print(oeout)
         # The absolute-value charges were only a minimization device; recompute real charges.
         if verbose: print("assignPartialCharges: redetermining partial charges...")
         OEAssignPartialCharges(charged_molecule, OECharges_AM1BCC)

      # Accumulate partial charges.
      for atom in charged_molecule.GetAtoms():
         name = atom.GetName()
         partial_charges[name] += atom.GetPartialCharge()
         if verbose: print("%8s %10.5f" % (name, atom.GetPartialCharge()))

   # Compute and store average partial charges in a copy of the original molecule.
   charged_molecule = OEMol(molecule)
   for atom in charged_molecule.GetAtoms():
      atom.SetPartialCharge(partial_charges[atom.GetName()] / nconformers)

   # Return the charged molecule
   return charged_molecule

#=============================================================================================
# METHODS FOR INTERROGATING MOLECULES
#=============================================================================================
def formalCharge(molecule):
   """Return the net formal charge of a molecule as an integer.

   The calculation is carried out on a copy, so the molecule passed in is
   not modified.

   ARGUMENTS
     molecule (OEMol) - the molecule whose formal charge is to be determined

   RETURN VALUES
     formal_charge (integer) - the net formal charge of the molecule

   EXAMPLE
     net_charge = formalCharge(molecule)
   """
   # OEFormalPartialCharges modifies its argument, so operate on a scratch copy.
   scratch = OEMol(molecule)
   OEFormalPartialCharges(scratch)

   # Round the net charge to the nearest integer before returning it.
   net_charge = OENetCharge(scratch)
   return int(round(net_charge))

#======================================================================================================
# MAIN BODY
#======================================================================================================

# Read the molecule from the input file named by `input_filename`.
molecule = openeye.oechem.OEMol()
ifs = openeye.oechem.oemolistream(input_filename)
try:
    # OEReadMolecule returns a false value when nothing could be read; stop
    # here rather than charging an empty molecule.
    if not openeye.oechem.OEReadMolecule(ifs, molecule):
        raise IOError("Could not read a molecule from '%s'." % input_filename)
finally:
    ifs.close()

# Assign aromaticity/bonds, do naming.
openeye.oechem.OEAssignAromaticFlags(molecule) # check aromaticity
openeye.oechem.OEAddExplicitHydrogens(molecule) # add hydrogens
name = openeye.oeiupac.OECreateIUPACName(molecule) # attempt to determine IUPAC name
molecule.SetTitle(name) # Set title to IUPAC name

# Select appropriate formal charge state, if specified.
if charge is not None:
    # Enumerate protonation states and select the first one with the desired charge.
    protonation_states = enumerateStates(molecule, enumerate = "protonation", verbose = True)
    for state in protonation_states:
        if formalCharge(state) == charge:
            molecule = state
            break
    # If no enumerated state matched, `molecule` is still the input molecule;
    # it is accepted only if it already carries the requested formal charge.
    if formalCharge(molecule) != charge:
        print("enumerateStates did not enumerate a molecule with desired formal charge.")
        print("Options are:")
        for state in protonation_states:
            print("%s, formal charge %d" % (state.GetTitle(), formalCharge(state)))
        raise RuntimeError("Could not find desired formal charge.")

# Assign AM1-BCC charges, averaged over multiple conformers and without the
# contact-minimization step.
molecule = assignPartialCharges(molecule, charge_model='am1bcc', multiconformer=True, minimize_contacts=False, verbose=True)

# Write the charged molecule to `output_filename`.
ostream = openeye.oechem.oemolostream()
ostream.open(output_filename)
openeye.oechem.OEWriteMolecule(ostream, molecule)
ostream.close()

