#!/usr/local/bin/env python

#=============================================================================================
# MODULE DOCSTRING
#=============================================================================================

"""
Implicit Ligand Theory

Stage 3: Free energy calculations on fixed receptors.

@author John D. Chodera <jchodera@gmail.com>
@author David D. L. Minh <daveminh@gmail.com>

All code in this repository is released under the GNU General Public License.

This program is free software: you can redistribute it and/or modify it under
the terms of the GNU General Public License as published by the Free Software
Foundation, either version 3 of the License, or (at your option) any later
version.

This program is distributed in the hope that it will be useful, but WITHOUT ANY
WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A
PARTICULAR PURPOSE.  See the GNU General Public License for more details.
 
You should have received a copy of the GNU General Public License along with
this program.  If not, see <http://www.gnu.org/licenses/>.

TODO

"""

#=============================================================================================
# GLOBAL IMPORTS
#=============================================================================================

import os
import os.path
import sys
import copy
import time

import numpy
import numpy.random 

import simtk.unit as units
import simtk.openmm as openmm

#=============================================================================================
# PARAMETERS
#=============================================================================================

# Define alchemical protocol.
# Each row lists the four positional arguments handed to AlchemicalState, ordered
# from the fully interacting state to the fully decoupled one.  Judging from the
# endpoint labels, the second column scales ligand electrostatics and the third
# ligand Lennard-Jones interactions (NOTE(review): confirm against alchemy.py).
from alchemy import AlchemicalState
_alchemical_protocol_parameters = [
    (0.00, 1.00, 1.00, 1.), # fully interacting
    (0.00, 0.75, 1.00, 1.),
    (0.00, 0.50, 1.00, 1.),
    (0.00, 0.25, 1.00, 1.),
    (0.00, 0.00, 0.75, 1.),
    (0.00, 0.00, 0.50, 1.),
    (0.00, 0.00, 0.25, 1.),
    (0.00, 0.00, 0.00, 1.), # discharged, LJ decoupled
    ]
alchemical_protocol = [AlchemicalState(*parameters) for parameters in _alchemical_protocol_parameters]
# Switch every state to decoupling (instead of annihilating) the ligand's
# electrostatic and Lennard-Jones interactions.
for alchemical_state in alchemical_protocol:
    alchemical_state.annihilateElectrostatics = False
    alchemical_state.annihilateLennardJones = False

#=============================================================================================
# MAIN
#=============================================================================================

if __name__ == '__main__':
    # Stage 3 driver: for each docked complex stored by the previous stage in the
    # output NetCDF file, run an alchemical Hamiltonian-exchange calculation with
    # the receptor held rigid (zero mass) and store the resulting free energy.

    # Initialize command-line argument parser.

    usage = """
    USAGE

    python %prog --sourcedir source-directory [-v | --verbose] [-i | --iterations ITERATIONS] [--restraints restraint-type] --output output-netcdf

    EXAMPLES

    # Specify directory containing prmtop and crd files (complex, ligand, receptor) ending in .prmtop and .crd
    python %prog --sourcedir ../../src/examples/benzene-toluene --iterations 1000 --output output.nc --verbose

    NOTES

    In atom ordering, receptor comes before ligand atoms.

    """

    # Parse command-line arguments.
    from optparse import OptionParser
    parser = OptionParser(usage=usage)
    parser.add_option("--sourcedir", dest="source_directory", default=None, help="source directory containing complex, ligand, and receptor prmtop and crd files", metavar="SOURCE_DIRECTORY")
    parser.add_option("-v", "--verbose", action="store_true", dest="verbose", default=False, help="verbosity flag")
    parser.add_option("-i", "--iterations", dest="niterations", type="int", default=None, help="number of docked complexes to process (default: all stored in the output file)", metavar="ITERATIONS")
    parser.add_option("--restraints", dest="restraint_type", default=None, help="specify ligand restraint type: 'harmonic' or 'flat-bottom' (default: 'flat-bottom')")
    # Default matches the help text; previously the default was None, which crashed when opening the file.
    parser.add_option("--output", dest="store_filename", default="output.nc", help="specify output NetCDF file---must be unique for each calculation (default: 'output.nc')")

    # Parse command-line arguments.
    (options, args) = parser.parse_args()

    # Check arguments for validity before doing any expensive setup.
    if not options.source_directory:
        parser.error("source directory containing ligand, receptor, and complex crd and prmtop files must be specified")
    if not os.path.exists(options.source_directory):
        parser.error("source directory '%s' does not exist" % options.source_directory)
    # The output file is read (docked complexes) as well as written, so it must already exist.
    if not os.path.exists(options.store_filename):
        parser.error("output NetCDF file '%s' does not exist; it must contain the 'docked_complexes' group written by the previous stage" % options.store_filename)
    if options.niterations is not None and options.niterations < 0:
        parser.error("number of iterations must be non-negative (got %d)" % options.niterations)
    if options.restraint_type is not None:
        # NOTE: the restraint type is not used anywhere in this stage yet.
        sys.stderr.write("WARNING: --restraints is not used by this stage and will be ignored.\n")

    # Create System objects from AMBER prmtop files and load coordinates.
    if options.verbose: print("Reading AMBER prmtop and inpcrd files from directory '%s'..." % options.source_directory)
    # NOTE: Hard-coded options for now.
    # TODO: Handle explicit solvent systems and different kinds of GB / nonbonded / constraint treatments.
    # TODO: Can add constraints to hydrogens in ligand only later.
    import simtk.openmm.app as app
    nonbondedMethod = app.NoCutoff
    implicitSolvent = app.OBC2
    constraints = None
    removeCMMotion = False
    systems = dict()
    initial_positions = dict()
    for name in ['complex', 'receptor', 'ligand']:
        # Read prmtop.
        prmtop_filename = os.path.join(options.source_directory, '%s.prmtop' % name)
        if options.verbose: print("Reading %s..." % prmtop_filename)
        systems[name] = app.AmberPrmtopFile(prmtop_filename).createSystem(nonbondedMethod=nonbondedMethod, implicitSolvent=implicitSolvent, constraints=constraints, removeCMMotion=removeCMMotion)
        # Read coordinates.
        inpcrd_filename = os.path.join(options.source_directory, '%s.crd' % name)
        if options.verbose: print("Reading %s..." % inpcrd_filename)
        initial_positions[name] = app.AmberInpcrdFile(inpcrd_filename).getPositions(asNumpy=True)

    # Compute index ranges of receptor and ligand atoms (receptor atoms come first in the complex).
    receptor_atoms = range(0, systems['receptor'].getNumParticles())
    ligand_atoms = range(systems['receptor'].getNumParticles(), systems['complex'].getNumParticles())

    # Set all receptor atom masses to zero in complex so that receptor will be rigid.
    if options.verbose: print("Zeroing masses in receptor...")
    for atom_index in receptor_atoms:
        systems['complex'].setParticleMass(atom_index, 0.0 * units.amu)

    #
    # Create alchemically-modified states.
    #

    # TODO: Make sure decoupling works in alchemy.py.

    if options.verbose: print("Creating alchemically-modified states...")
    from alchemy import AbsoluteAlchemicalFactory

    # Create alchemical factory.
    alchemical_factory = AbsoluteAlchemicalFactory(systems['complex'], ligand_atoms=ligand_atoms)

    # Create alchemically-modified systems.
    alchemical_systems = alchemical_factory.createPerturbedSystems(alchemical_protocol)

    # Loop-invariant simulation setup, hoisted out of the per-complex loop.
    from thermodynamics import ThermodynamicState
    from repex import HamiltonianExchange
    temperature = 300.0 * units.kelvin
    reference_state = ThermodynamicState(temperature=temperature)
    repex_filename = 'repex.nc' # scratch file, overwritten for every complex

    #
    # Read receptor conformations from NetCDF file.
    #

    import netCDF4 as netcdf # netcdf4-python
    ncfile = netcdf.Dataset(options.store_filename, 'a')
    try:
        if 'docked_complexes' not in ncfile.groups:
            parser.error("output NetCDF file '%s' contains no 'docked_complexes' group" % options.store_filename)
        ncvar_docked_positions = ncfile.groups['docked_complexes'].variables['positions']
        ncomplexes = ncvar_docked_positions.shape[0]

        # Default to processing every stored complex; never index past the end.
        niterations = options.niterations
        if niterations is None:
            niterations = ncomplexes
        elif niterations > ncomplexes:
            parser.error("requested %d iterations, but only %d docked complexes are stored in '%s'" % (niterations, ncomplexes, options.store_filename))

        # Create group to store free energies.
        if 'free_energies' not in ncfile.groups:
            ncgrp = ncfile.createGroup('free_energies')
            # netCDF4 resolves dimensions in the current group and its parents only,
            # so an 'iteration' dimension in a sibling group would not be found.
            if 'iteration' not in ncfile.dimensions:
                ncgrp.createDimension('iteration', None)
            ncvar_free_energies = ncgrp.createVariable('free_energies', 'f', ('iteration',))
            setattr(ncvar_free_energies, 'units', 'kT')
            setattr(ncvar_free_energies, "long_name", "free_energies[iteration] is free energy of complex 'iteration' with rigid receptor.")
        ncvar_free_energies = ncfile.groups['free_energies'].variables['free_energies']

        #
        # Run free energy calculation with fixed receptor.
        #

        for iteration in range(niterations):
            if options.verbose: print("performing free energy calculation for complex %8d / %8d" % (iteration, niterations))

            # Read coordinates.
            positions = units.Quantity(ncvar_docked_positions[iteration,:,:], units.nanometers)

            # Set up alchemical calculation.
            if os.path.exists(repex_filename): os.remove(repex_filename) # remove old copy of NetCDF file
            simulation = HamiltonianExchange(reference_state, alchemical_systems, positions, repex_filename)

            # Set simulation options.
            simulation.verbose = options.verbose
            simulation.number_of_iterations = 10 # set the simulation to only run 10 iterations
            simulation.timestep = 1.0 * units.femtoseconds # set the timestep for integration
            simulation.minimize = True
            simulation.nsteps_per_iteration = 1000

            # Run alchemical calculation.
            simulation.run()

            # Analyze data and write the result into the NetCDF variable (not into the
            # Python 'variables' dict, which would silently discard it).
            # Stores the difference between the first (fully interacting) and last
            # (decoupled) protocol states, in kT; assumes Delta_f_ij is a 2-D
            # numpy array indexed [from_state, to_state] -- TODO confirm in repex.py.
            analysis = simulation.analyze()
            ncvar_free_energies[iteration] = analysis['Delta_f_ij'][0, -1]
            ncfile.sync() # keep completed results even if a later complex fails
    finally:
        # Close.
        ncfile.close()

    
    
