#!/usr/bin/env python

#=============================================================================================
# MODULE DOCSTRING
#=============================================================================================

"""
Example illustrating the use of NCMC switches to swap two ions in water.

DESCRIPTION

COPYRIGHT

@author John D. Chodera <jchodera@gmail.com>

All code in this repository is released under the GNU General Public License.

This program is free software: you can redistribute it and/or modify it under
the terms of the GNU General Public License as published by the Free Software
Foundation, either version 3 of the License, or (at your option) any later
version.

This program is distributed in the hope that it will be useful, but WITHOUT ANY
WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A
PARTICULAR PURPOSE.  See the GNU General Public License for more details.
 
You should have received a copy of the GNU General Public License along with
this program.  If not, see <http://www.gnu.org/licenses/>.

TODO

"""

#=============================================================================================
# GLOBAL IMPORTS
#=============================================================================================

import os
import os.path
import sys
import math
import copy
import time

import numpy

import simtk.unit as units
import simtk.openmm as openmm
    
#import Scientific.IO.NetCDF as netcdf # for netcdf interface in Scientific
import netCDF4 as netcdf # for netcdf interface provided by netCDF4 in enthought python

#=============================================================================================
# INTEGRATORS
#=============================================================================================

def VelocityVerletIntegrator(timestep):
    """
    Build a CustomIntegrator that performs constrained velocity Verlet steps.

    ARGUMENTS

    timestep (simtk.unit.Quantity compatible with femtoseconds) - the integration timestep

    RETURNS

    integrator (simtk.openmm.CustomIntegrator) - the configured velocity Verlet integrator

    NOTES

    The step sequence follows Peter Eastman's velocity Verlet example:
    half kick, drift, position constraint, and a second half kick.
    The second half kick also adds (x - x1)/dt, the velocity implied by the
    displacement the position constraint applied. A velocity constraint
    finishes the step.

    """
    vv = openmm.CustomIntegrator(timestep)

    # Scratch per-DOF storage: positions right after the drift, before constraining.
    vv.addPerDofVariable("x1", 0)

    vv.addUpdateContextState()

    # First half kick, drift, and snapshot of the unconstrained positions.
    pre_constraint_updates = (
        ("v", "v+0.5*dt*f/m"),
        ("x", "x+dt*v"),
        ("x1", "x"),
    )
    for variable, expression in pre_constraint_updates:
        vv.addComputePerDof(variable, expression)

    vv.addConstrainPositions()

    # Second half kick plus the velocity correction from the constraint displacement.
    vv.addComputePerDof("v", "v+0.5*dt*f/m+(x-x1)/dt")
    vv.addConstrainVelocities()

    return vv

#=============================================================================================
# SUBROUTINES
#=============================================================================================

def equilibrate(system, temperature, sqrt_kT_over_m, coordinates, platform):
    """
    Run 10000 steps of Langevin dynamics and return the final configuration.

    ARGUMENTS

    system (simtk.openmm.System) - system to simulate; its default periodic box
        vectors are overwritten with the box at the end of the run
    temperature (simtk.unit.Quantity, temperature units) - Langevin bath temperature
    sqrt_kT_over_m (simtk.unit.Quantity array, velocity units) - per-coordinate
        standard deviation used to draw Maxwell-Boltzmann velocities
    coordinates - starting positions
    platform (simtk.openmm.Platform) - platform on which the Context is created

    RETURNS

    [coordinates, velocities] - final positions and velocities (numpy-backed Quantities)

    """
    nsteps = 10000
    step_size = 1.0 * units.femtoseconds
    friction = 10.0 / units.picoseconds

    duration_ps = (nsteps * step_size) / units.picoseconds
    print("Equilibrating for %.3f ps..." % duration_ps)

    langevin = openmm.LangevinIntegrator(temperature, friction, step_size)
    context = openmm.Context(system, langevin, platform)
    context.setPositions(coordinates)

    # Start from freshly drawn Maxwell-Boltzmann velocities.
    gaussian_noise = numpy.random.standard_normal(size=sqrt_kT_over_m.shape)
    context.setVelocities(sqrt_kT_over_m * gaussian_noise)

    langevin.step(nsteps)

    print("Computing energy.")
    energy_state = context.getState(getEnergy=True)
    potential = energy_state.getPotentialEnergy()
    print("potential energy: %.3f kcal/mol" % (potential / units.kilocalories_per_mole))

    final_state = context.getState(getPositions=True, getVelocities=True)
    final_positions = final_state.getPositions(asNumpy=True)
    final_velocities = final_state.getVelocities(asNumpy=True)

    # Carry the end-of-run box (which changes if a barostat is active) into the System.
    system.setDefaultPeriodicBoxVectors(*final_state.getPeriodicBoxVectors())

    return [final_positions, final_velocities]

def compute_forces(platform, system, positions):
    """
    Compute forces for given positions.

    A throwaway Context is built on `platform` solely to evaluate forces.
    Its 1 fs Verlet integrator is never stepped.

    RETURNS

    forces - numpy-backed Quantity of per-particle forces

    """
    # Keep the integrator bound to a name for as long as the Context is in use.
    dummy_integrator = openmm.VerletIntegrator(1.0 * units.femtoseconds)
    scratch_context = openmm.Context(system, dummy_integrator, platform)
    scratch_context.setPositions(positions)
    force_state = scratch_context.getState(getForces=True)
    return force_state.getForces(asNumpy=True)

def compute_energy(platform, system, positions, velocities):
    """
    Compute total energy for positions and velocities.

    A throwaway Context is built on `platform` with a 1 fs Verlet integrator
    that is never stepped.

    RETURNS

    total_energy (simtk.unit.Quantity) - potential plus kinetic energy

    """
    dummy_integrator = openmm.VerletIntegrator(1.0 * units.femtoseconds)
    scratch_context = openmm.Context(system, dummy_integrator, platform)
    scratch_context.setPositions(positions)
    scratch_context.setVelocities(velocities)

    energy_state = scratch_context.getState(getEnergy=True)
    potential = energy_state.getPotentialEnergy()
    kinetic = energy_state.getKineticEnergy()
    return potential + kinetic

def set_ion_charges(system, ion_charge, ion_atom_indices=(0,), context=None):
    """
    Set the ion charge in the given system.

    The charge of each listed particle is replaced in the System's
    NonbondedForce. Sigma and epsilon are left unchanged.

    ARGUMENTS

    system (simtk.openmm.System) - system whose NonbondedForce is modified in place
    ion_charge (float or simtk.unit.Quantity) - new charge; a bare number is
        interpreted in units of elementary charge, a Quantity is used as-is
    ion_atom_indices (iterable of int, optional, default (0,)) - particles to modify
    context (simtk.openmm.Context, optional) - if given, the new parameters are
        also pushed into this Context via updateParametersInContext

    RAISES

    ValueError - if the system contains no NonbondedForce

    """
    # Find the NonbondedForce by type. Matching on getParticleParameters could
    # also hit CustomNonbondedForce or other forces that define that method.
    nonbonded_force = None
    for force_index in range(system.getNumForces()):
        force = system.getForce(force_index)
        if isinstance(force, openmm.NonbondedForce):
            nonbonded_force = force
            break
    if nonbonded_force is None:
        raise ValueError("System contains no NonbondedForce")

    # Attach units only to bare numbers. Multiplying a Quantity that already
    # has charge units would yield units of e^2.
    if not isinstance(ion_charge, units.Quantity):
        ion_charge = ion_charge * units.elementary_charge

    for atom_index in ion_atom_indices:
        [old_charge, sigma, epsilon] = nonbonded_force.getParticleParameters(atom_index)
        nonbonded_force.setParticleParameters(atom_index, ion_charge, sigma, epsilon)

    # Propagate the change to a live Context, if one was supplied.
    if context is not None:
        nonbonded_force.updateParametersInContext(context)

def modify_system(system, ion_atom_indices=(0, 1), epsilon_solvent=None):
    """
    Modify the System object to allow alchemical swapping of ion identities.

    In the NonbondedForce the ion charges are set to zero. Their LJ parameters
    are left in place. Every electrostatic pair interaction involving at least
    one ion (ion-ion and ion-other) is then recomputed by a new
    CustomNonbondedForce, using the reaction-field functional form

        C*q1*q2*(1/r + k_rf*r^2 - c_rf)

    with
        k_rf = cutoff^-3 * (eps - 1) / (2*eps + 1)
        c_rf = cutoff^-1 * 3*eps / (2*eps + 1).

    ARGUMENTS

    system (simtk.openmm.System) - system to modify in place; must contain a
        NonbondedForce (a cutoff-based nonbonded method is assumed)
    ion_atom_indices (iterable of int, optional, default (0, 1)) - particles
        treated as alchemical ions
    epsilon_solvent (float, optional) - reaction-field dielectric; if None,
        the NonbondedForce's reaction-field dielectric is used

    RETURNS

    custom_force (simtk.openmm.CustomNonbondedForce) - the force added to system

    RAISES

    ValueError - if the system contains no NonbondedForce

    NOTES

    Contexts created from `system` before this call do not include the new
    force, so create a fresh Context afterwards.
    NOTE(review): the global parameter 'lambda' (default 1.0) is registered as
    in the original draft but is not yet used in the energy expression. The
    intended alchemical coupling still needs to be defined.

    """
    # Find NonbondedForce.
    nonbonded_force = None
    for force_index in range(system.getNumForces()):
        force = system.getForce(force_index)
        if isinstance(force, openmm.NonbondedForce):
            nonbonded_force = force
            break
    if nonbonded_force is None:
        raise ValueError("System contains no NonbondedForce")

    alchemical_atoms = set(ion_atom_indices)
    if epsilon_solvent is None:
        epsilon_solvent = nonbonded_force.getReactionFieldDielectric()
    cutoff = nonbonded_force.getCutoffDistance()

    # Coulomb constant 1/(4*pi*epsilon0) in kJ nm / (mol e^2), as used by OpenMM.
    coulomb_constant = 138.935456

    # compute_flag is 1 whenever at least one particle of the pair is alchemical,
    # so only ion-ion and ion-other interactions are computed here.
    energy_expression  = "compute_flag*C*charge1*charge2*(r^(-1) + k_rf*r^2 - c_rf);"
    energy_expression += "k_rf = cutoff^(-3)*(epsilon_solvent - 1)/(2*epsilon_solvent+1);"
    energy_expression += "c_rf = cutoff^(-1)*(3*epsilon_solvent)/(2*epsilon_solvent+1);"
    energy_expression += "compute_flag = alchemical1*(1-alchemical2) + alchemical2*(1-alchemical1) + alchemical1*alchemical2;"
    custom_force = openmm.CustomNonbondedForce(energy_expression)
    custom_force.addGlobalParameter('C', coulomb_constant)
    custom_force.addGlobalParameter('cutoff', cutoff / units.nanometers)
    custom_force.addGlobalParameter('epsilon_solvent', epsilon_solvent)
    custom_force.addGlobalParameter('lambda', 1.0)
    custom_force.addPerParticleParameter('charge')
    custom_force.addPerParticleParameter('alchemical')

    # Match the NonbondedForce cutoff treatment.
    if nonbonded_force.getNonbondedMethod() == openmm.NonbondedForce.CutoffNonPeriodic:
        custom_force.setNonbondedMethod(openmm.CustomNonbondedForce.CutoffNonPeriodic)
    else:
        custom_force.setNonbondedMethod(openmm.CustomNonbondedForce.CutoffPeriodic)
    custom_force.setCutoffDistance(cutoff)

    # Every particle needs an entry in the custom force. The ions keep their
    # real charge there and have it zeroed in the NonbondedForce.
    for atom_index in range(system.getNumParticles()):
        [charge, sigma, epsilon] = nonbonded_force.getParticleParameters(atom_index)
        is_ion = atom_index in alchemical_atoms
        custom_force.addParticle([charge / units.elementary_charge, 1.0 if is_ion else 0.0])
        if is_ion:
            nonbonded_force.setParticleParameters(atom_index, 0.0 * units.elementary_charge, sigma, epsilon)

    # Mirror the NonbondedForce exceptions as exclusions so bonded pairs are
    # treated consistently.
    for exception_index in range(nonbonded_force.getNumExceptions()):
        [particle1, particle2, chargeprod, sigma, epsilon] = nonbonded_force.getExceptionParameters(exception_index)
        custom_force.addExclusion(particle1, particle2)

    system.addForce(custom_force)

    return custom_force

def minimize(platform, system, positions):
    """
    Locally minimize the potential energy starting from `positions`.

    The potential energy is printed before and after minimization. The
    minimization runs in a throwaway Context on `platform`.

    RETURNS

    positions - minimized positions (numpy-backed Quantity)

    """
    verlet = openmm.VerletIntegrator(1.0 * units.femtoseconds)
    context = openmm.Context(system, verlet, platform)
    context.setPositions(positions)

    energy_before = context.getState(getEnergy=True).getPotentialEnergy()
    print("initial potential: %12.3f kcal/mol" % (energy_before / units.kilocalories_per_mole))

    openmm.LocalEnergyMinimizer.minimize(context)

    minimized_state = context.getState(getEnergy=True, getPositions=True)
    energy_after = minimized_state.getPotentialEnergy()
    minimized_positions = minimized_state.getPositions(asNumpy=True)
    print("final potential  : %12.3f kcal/mol" % (energy_after / units.kilocalories_per_mole))

    return minimized_positions

def initialize_netcdf_file(filename, nsteps_to_try):
    """
    Create a NetCDF file laid out for storing switching work values.

    ARGUMENTS

    filename (str) - path of the file to create; opened in 'w' mode, so an
        existing file is overwritten
    nsteps_to_try (list of int) - switching lengths; stored in the variable
        'nsteps_to_try'

    RETURNS

    ncfile (netCDF4.Dataset) - open dataset with dimensions 'nsteps_index' and
        'iteration' (unlimited), and a double variable 'work' of shape
        (nsteps_index, iteration)

    """
    ncfile = netcdf.Dataset(filename, 'w')

    ncfile.createDimension('nsteps_index', len(nsteps_to_try))
    ncfile.createDimension('iteration', 0) # size 0 makes this dimension unlimited

    nsteps_variable = ncfile.createVariable('nsteps_to_try', 'i', ('nsteps_index',))
    nsteps_variable[:] = numpy.array(nsteps_to_try)

    ncfile.createVariable('work', 'd', ('nsteps_index', 'iteration'))

    # Flush the header to disk right away so a crash loses no layout information.
    ncfile.sync()

    return ncfile
    
#=============================================================================================
# MAIN AND TESTS
#=============================================================================================

if __name__ == "__main__":

    # Driver outline:
    #   1. Solvate nacl.pdb with TIP3P, minimize, and add a barostat.
    #   2. For each iteration, equilibrate at NPT, then switch the ion charge
    #      from +1 to -1 under velocity Verlet with several switch lengths.
    #   3. Record the work (in kT) in ion-inversion.nc.

    from simtk.openmm import app

    temperature = 298.0 * units.kelvin
    timestep = 1.0 * units.femtoseconds
    pressure = 1.0 * units.atmospheres

    # Ion charge at the start and end of each switch, in units of elementary charge.
    initial_ion_charge = +1.0
    final_ion_charge = -1.0

    netcdf_filename = 'ion-inversion.nc'

    #
    # Set up system.
    #

    # Read ion.
    pdbfile = app.PDBFile('nacl.pdb')

    # Load forcefield.
    forcefields_to_use = ['tip3p.xml', 'amber96.xml']
    forcefield = app.ForceField(*forcefields_to_use)

    # Solvate.
    modeller = app.Modeller(pdbfile.topology, pdbfile.positions)
    modeller.addSolvent(forcefield, model='tip3p', padding=15.0*units.angstroms)

    # Create System.
    cutoff = 9.0 * units.angstroms
    system = forcefield.createSystem(modeller.topology, nonbondedMethod=app.CutoffPeriodic, nonbondedCutoff=cutoff, rigidWater=True)

    # Extract positions.
    positions = modeller.positions

    print("system has %d atoms" % system.getNumParticles())

    #
    # Set up simulation.
    #

    kB = units.BOLTZMANN_CONSTANT_kB * units.AVOGADRO_CONSTANT_NA
    kT = kB * temperature # thermal energy

    niterations = 1000 # number of work samples to collect
    nsteps_to_try = [2**n for n in range(0,16)] # number of steps per switch
    print(nsteps_to_try)

    # Initialize netcdf file; it is closed in the finally clause even on error.
    ncfile = initialize_netcdf_file(netcdf_filename, nsteps_to_try)
    try:
        # List all available platforms
        print("Available platforms:")
        for platform_index in range(openmm.Platform.getNumPlatforms()):
            platform = openmm.Platform.getPlatform(platform_index)
            print("%5d %s" % (platform_index, platform.getName()))
        print("")

        # Select platform.
        platform = openmm.Platform.getPlatformByName("CUDA")

        # Set initial ion charge.
        set_ion_charges(system, initial_ion_charge)

        # Minimize energy.
        print("Minimizing energy...")
        print(system.getDefaultPeriodicBoxVectors())
        positions = minimize(platform, system, positions)

        # Add barostat (which will only be used during equilibration).
        barostat_frequency = 25 # number of steps between MC volume adjustments
        barostat = openmm.MonteCarloBarostat(pressure, temperature, barostat_frequency)
        system.addForce(barostat)

        # Form vectors of masses and sqrt(kT/m) for velocity randomization.
        print("Creating masses array...")
        nparticles = system.getNumParticles()
        masses = units.Quantity(numpy.zeros([nparticles,3], numpy.float64), units.amu)
        for particle_index in range(nparticles):
            masses[particle_index,:] = system.getParticleMass(particle_index)
        sqrt_kT_over_m = units.Quantity(numpy.zeros([nparticles,3], numpy.float64), units.nanometers / units.picosecond)
        for particle_index in range(nparticles):
            # Standard deviation of each velocity component for this atom.
            sqrt_kT_over_m[particle_index,:] = units.sqrt(kT / masses[particle_index,0])

        # Generate and store work samples for switching from +1 charge to -1 charge.
        print("Opening work file...")
        # NOTE(review): work.out and lechner_work.out are created but never
        # written; work values are stored only in the NetCDF file.
        with open('work.out', 'w') as outfile, open('lechner_work.out', 'w') as lechner_outfile:
            for iteration in range(niterations):
                print("iteration %5d / %5d" % (iteration, niterations))

                # Generate a new configuration at the initial charge, with the barostat active.
                set_ion_charges(system, initial_ion_charge)
                barostat.setFrequency(barostat_frequency)
                [positions, velocities] = equilibrate(system, temperature, sqrt_kT_over_m, positions, platform)

                # DEBUG: snapshot of the equilibrated configuration.
                with open('test.pdb', 'w') as pdb_outfile:
                    app.PDBFile.writeFile(modeller.topology, positions, file=pdb_outfile)

                # Turn off barostat for the constant-volume switching trajectories.
                barostat.setFrequency(0)

                # Draw new Maxwell-Boltzmann velocities (shared by all switch lengths).
                velocities = sqrt_kT_over_m * numpy.random.standard_normal(size=sqrt_kT_over_m.shape)

                for (nsteps_index, nsteps) in enumerate(nsteps_to_try):
                    integrator = VelocityVerletIntegrator(timestep)
                    context = openmm.Context(system, integrator, platform)

                    # Set positions and velocities, then apply constraints.
                    context.setPositions(positions)
                    context.setVelocities(velocities)
                    tol = integrator.getConstraintTolerance()
                    context.applyConstraints(tol)
                    context.applyVelocityConstraints(tol)

                    # The System still carries the final charge from the previous
                    # switch, so reset it in both the System and this Context.
                    set_ion_charges(system, initial_ion_charge, context=context)

                    state = context.getState(getEnergy=True)
                    initial_total_energy = state.getPotentialEnergy() + state.getKineticEnergy()

                    # Integrate the nonequilibrium switching trajectory: linearly
                    # interpolate the charge (a bare number in units of e), then step.
                    for step in range(nsteps):
                        fraction = float(step+1) / float(nsteps)
                        charge = initial_ion_charge + (final_ion_charge - initial_ion_charge) * fraction
                        set_ion_charges(system, charge, context=context)
                        integrator.step(1)

                    state = context.getState(getEnergy=True)
                    final_total_energy = state.getPotentialEnergy() + state.getKineticEnergy()

                    # With an energy-conserving integrator, the change in total energy is the work.
                    work = final_total_energy - initial_total_energy
                    print("%8d steps : work = %8.1f kT" % (nsteps, work / kT))

                    # Record data.
                    ncfile.variables['work'][nsteps_index,iteration] = work / kT
                    ncfile.sync()

                    # Release this Context before building the next one.
                    del context, integrator
    finally:
        # Close netcdf file.
        ncfile.close()
