#!/usr/local/bin/env python

#=============================================================================================
# MODULE DOCSTRING
#=============================================================================================

"""
Compute the work of switching off an ion's charge in an AMBER system simulated with OpenMM.

DESCRIPTION

COPYRIGHT

@author John D. Chodera <jchodera@gmail.com>

All code in this repository is released under the GNU General Public License.

This program is free software: you can redistribute it and/or modify it under
the terms of the GNU General Public License as published by the Free Software
Foundation, either version 3 of the License, or (at your option) any later
version.

This program is distributed in the hope that it will be useful, but WITHOUT ANY
WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A
PARTICULAR PURPOSE.  See the GNU General Public License for more details.
 
You should have received a copy of the GNU General Public License along with
this program.  If not, see <http://www.gnu.org/licenses/>.

TODO

"""

#=============================================================================================
# GLOBAL IMPORTS
#=============================================================================================

import os
import os.path
import sys
import math

import simtk.unit as units
import simtk.chem.openmm as openmm
#import simtk.chem.openmm.extras.amber as amber
import amber

#=============================================================================================
# SUBROUTINES
#=============================================================================================

def setIonCharge(system, atom_index, new_charge):
    """
    Overwrite the charge of one particle, leaving its other two parameters untouched.

    Forces are scanned in index order and only the first one that exposes a
    'getParticleParameters' method is modified (presumably the NonbondedForce;
    confirm for systems carrying other per-particle forces).  If no force
    qualifies, the system is left unchanged.

    ARGUMENTS

    system - object providing getNumForces() and getForce(index)
    atom_index - index of the particle whose charge is replaced
    new_charge - charge value written in place of the existing one

    """
    force_count = system.getNumForces()
    for position in range(force_count):
        candidate = system.getForce(position)

        # Skip forces that carry no per-particle parameters.
        if not hasattr(candidate, 'getParticleParameters'):
            continue

        # Keep sigma and epsilon exactly as stored; only the charge is replaced.
        _, sigma, epsilon = candidate.getParticleParameters(atom_index)
        candidate.setParticleParameters(atom_index, new_charge, sigma, epsilon)
        return

    return

def equilibrate(system, coordinates, platform, nsteps=5000):
    """
    Run Langevin dynamics starting from the given coordinates and return the final positions.

    The integrator is constructed from the module-level names 'temperature',
    'collision_rate' and 'timestep'; they must be defined before this function
    is called (the script's main section assigns them).

    ARGUMENTS

    system - the System to simulate
    coordinates - initial positions passed to Context.setPositions
    platform - the Platform on which the Context is created

    OPTIONAL ARGUMENTS

    nsteps - number of integrator steps to take (default: 5000)

    RETURNS

    coordinates - positions after the run, as given by State.getPositions(asNumpy=True)

    """
    print("Equilibrating...")

    # Create integrator and context (thermostat settings come from module globals).
    integrator = openmm.LangevinIntegrator(temperature, collision_rate, timestep)
    context = openmm.Context(system, integrator, platform)

    # Set coordinates.
    context.setPositions(coordinates)

    # Equilibrate.
    integrator.step(nsteps)

    # Fetch energy and positions in a single State query; the context does not
    # change between the two reads, so one request is sufficient.
    print("Computing energy.")
    state = context.getState(getEnergy=True, getPositions=True)
    potential_energy = state.getPotentialEnergy()
    print("potential energy: %.3f kcal/mol" % (potential_energy / units.kilocalories_per_mole))

    return state.getPositions(asNumpy=True)

#=============================================================================================
# MAIN AND TESTS
#=============================================================================================

if __name__ == "__main__":
    # Script outline:
    #   1. Load an AMBER system and zero the charge of particle 0 (the ion) in the
    #      first force that carries per-particle parameters.
    #   2. Re-introduce the ion's electrostatic interaction with every other particle
    #      through a CustomBondForce scaled by the global parameter 'ion_charge'.
    #   3. Equilibrate, then ramp 'ion_charge' linearly from 1 to 0 over nsteps,
    #      accumulating the potential energy change of each parameter update as work.

    # Parameters
    prmtop_filename = 'system.prmtop'
    crd_filename = 'system.crd'

    # These three names are read as module globals by equilibrate().
    temperature = 298.0 * units.kelvin
    collision_rate = 5.0 / units.picoseconds
    timestep = 2.0 * units.femtoseconds

    # List all available platforms
    print("Available platforms:")
    for platform_index in range(openmm.Platform.getNumPlatforms()):
        platform = openmm.Platform.getPlatform(platform_index)
        print("%5d %s" % (platform_index, platform.getName()))
    print("")

    # Select platform.
    platform = openmm.Platform.getPlatformByName("Cuda")

    # Create system.
    print("Reading system...")
    cutoff = 9.0 * units.angstroms
    system = amber.readAmberSystem(prmtop_filename, nonbondedMethod='reaction-field', nonbondedCutoff=cutoff, shake='h-bonds')
    [coordinates, box_vectors] = amber.readAmberCoordinates(crd_filename, read_box=True)
    system.setPeriodicBoxVectors(box_vectors[0], box_vectors[1], box_vectors[2])

    # Zero ion charge.
    setIonCharge(system, 0, 0.0 * units.elementary_charge)

    # Find the first force with per-particle parameters (expected to be the NonbondedForce).
    nonbonded_force = None
    for force_index in range(system.getNumForces()):
        force = system.getForce(force_index)
        if hasattr(force, 'getParticleParameters'):
            nonbonded_force = force
            break
    if nonbonded_force is None:
        # Fail here with a clear message rather than with an AttributeError below.
        raise RuntimeError("No force with per-particle parameters was found in the system.")

    # Add charge interaction back in with CustomBondForce.
    # step(x) is 0 for x < 0 and 1 otherwise, so (1.0-step(r-cutoff)) keeps the
    # interaction for r < cutoff and removes it beyond.  The former argument
    # r/cutoff is never negative, which made the whole term identically zero.
    energy_expression = '(C*ion_charge*charge/r)*(1.0-step(r-cutoff))'
    force = openmm.CustomBondForce(energy_expression)
    force.addGlobalParameter('C', 332.0 * units.kilocalories_per_mole * units.angstroms / units.elementary_charge**2)
    force.addGlobalParameter('ion_charge', 1.0 * units.elementary_charge)
    force.addGlobalParameter('cutoff', cutoff)
    force.addPerBondParameter('charge')
    # One "bond" between the ion (particle 0) and each other particle, carrying that particle's charge.
    for atom_index in range(1, system.getNumParticles()):
        [charge, sigma, epsilon] = nonbonded_force.getParticleParameters(atom_index)
        force.addBond(0, atom_index, [charge])
    system.addForce(force)

    # Equilibrate.
    coordinates = equilibrate(system, coordinates, platform)

    # Create context.
    integrator = openmm.LangevinIntegrator(temperature, collision_rate, timestep)
    context = openmm.Context(system, integrator, platform)
    context.setPositions(coordinates)

    nsteps = 50
    work = 0.0 * units.kilocalories_per_mole
    for step in range(nsteps):
        # Ion charge before and after this step's parameter update.
        old_ion_charge = (1.0 - float(step) / float(nsteps)) * units.elementary_charge
        new_ion_charge = (1.0 - float(step + 1) / float(nsteps)) * units.elementary_charge

        # Compute work: energy change from the parameter update at fixed positions.
        context.setParameter('ion_charge', old_ion_charge)
        state = context.getState(getEnergy=True)
        old_energy = state.getPotentialEnergy()

        context.setParameter('ion_charge', new_ion_charge)
        state = context.getState(getEnergy=True)
        new_energy = state.getPotentialEnergy()

        # Evolve.
        integrator.step(1)

        # Accumulate work
        delta_energy = new_energy - old_energy
        work += delta_energy
        print("step %5d / %5d : energy = %8.1f kcal/mol, delta_energy = %8.1f kcal/mol, accumulated work = %8.1f kcal/mol" % (
            step + 1,
            nsteps,
            new_energy / units.kilocalories_per_mole,
            delta_energy / units.kilocalories_per_mole,
            work / units.kilocalories_per_mole,
        ))


    
