#!/usr/bin/env python

#=============================================================================================
# MODULE DOCSTRING
#=============================================================================================

"""
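Compare energy and gradient for standard Force terms, SoftcoreForce terms, and CustomForce terms
implementing alchemical annihilation and decoupling.

DESCRIPTION

Provides builders for alchemically-modified copies of an OpenMM System, using either the *SoftcoreForce
classes (build_softcore_system) or Custom*Force classes (build_custom_system). When run as a script, it
reads an AMBER prmtop/crd pair, builds a series of alchemical states with the Custom*Force builder, and
runs a Hamiltonian replica-exchange simulation over them (distributed with MPI when mpi4py is available).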

COPYRIGHT

@author John D. Chodera <jchodera@gmail.com>

All code in this repository is released under the GNU General Public License.

This program is free software: you can redistribute it and/or modify it under
the terms of the GNU General Public License as published by the Free Software
Foundation, either version 3 of the License, or (at your option) any later
version.

This program is distributed in the hope that it will be useful, but WITHOUT ANY
WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A
PARTICULAR PURPOSE.  See the GNU General Public License for more details.
 
You should have received a copy of the GNU General Public License along with
this program.  If not, see <http://www.gnu.org/licenses/>.

TODO

"""

#=============================================================================================
# GLOBAL IMPORTS
#=============================================================================================

import os
import os.path
import sys
import math
import copy
import time

import numpy

import repex

import simtk
import simtk.unit as units
import simtk.openmm as openmm

#=============================================================================================
# SUBROUTINES
#=============================================================================================

def norm(n01):
    """
    Return the Euclidean length of a unit-bearing vector, expressed in the vector's own unit.

    n01 must expose a `.unit` attribute and support division by that unit
    (e.g. a simtk.unit.Quantity wrapping a numpy array).
    """
    unit = n01.unit
    # Strip the unit so numpy works on plain numbers, then reattach it to the result.
    magnitudes = n01 / unit
    return unit * numpy.sqrt(numpy.dot(magnitudes, magnitudes))

#=============================================================================================
# ALCHEMICAL MODIFICATIONS
#=============================================================================================

def createCustomSoftcoreGBOBC(solventDielectric, soluteDielectric, igb):
    """
    Create a CustomGBForce implementing the OBC generalized Born model with per-particle alchemical scaling.

    Per-particle parameters, in the order the caller must supply them to addParticle():
    'q' (charge), 'radius', 'scale' (OBC scaling factor), 'lambda' (alchemical parameter).
    Pairwise contributions to the Born-radius integral and the pairwise polarization energy are
    multiplied by lambda1*lambda2; the single-particle nonpolar and self-polarization terms are
    multiplied by lambda.

    ARGUMENTS

    solventDielectric (float) - solvent dielectric constant (stored as global parameter 'solventDielectric')
    soluteDielectric (float) - solute dielectric constant (stored as global parameter 'soluteDielectric')
    igb (int) - Born-radius rescaling variant, using AMBER numbering:
                2 = OBC I  (tanh(0.8*psi + 2.909125*psi^3))
                5 = OBC II (tanh(psi - 0.8*psi^2 + 4.85*psi^3))

    RETURNS

    custom (openmm.CustomGBForce) - force with parameters and expressions defined but no particles added

    RAISES

    ValueError - if igb is not 2 or 5 (checked before any OpenMM object is created)
    """

    # Effective Born radius B as a function of the descreening integral I, keyed by igb.
    born_radius_expressions = {
        # igb=2: OBC I rescaling
        2 : "1/(1/or-tanh(0.8*psi+2.909125*psi^3)/radius);"
            "psi=I*or; or=radius-offset",
        # igb=5: OBC II rescaling
        5 : "1/(1/or-tanh(psi-0.8*psi^2+4.85*psi^3)/radius);"
            "psi=I*or; or=radius-offset",
        }
    if igb not in born_radius_expressions:
        # Raise rather than sys.exit(): this is a library helper, so let the caller decide how to fail.
        raise ValueError("Incorrect igb# input for 'createCustomSoftcoreGBOBC': %r (expected 2 or 5)" % (igb,))

    custom = openmm.CustomGBForce()

    custom.addPerParticleParameter("q")
    custom.addPerParticleParameter("radius")
    custom.addPerParticleParameter("scale")
    custom.addPerParticleParameter("lambda")
    custom.addGlobalParameter("solventDielectric", solventDielectric)
    custom.addGlobalParameter("soluteDielectric", soluteDielectric)
    # Offset subtracted from atomic radii before computing the descreening integral.
    custom.addGlobalParameter("offset", 0.009)

    # Descreening integral I; each pair's contribution is scaled by lambda1*lambda2.
    custom.addComputedValue("I", "lambda1*lambda2*step(r+sr2-or1)*0.5*(1/L-1/U+0.25*(r-sr2^2/r)*(1/(U^2)-1/(L^2))+0.5*log(L/U)/r);"
                                 "U=r+sr2;"
                                 "L=max(or1, D);"
                                 "D=abs(r-sr2);"
                                 "sr2 = scale2*or2;"
                                 "or1 = radius1-offset; or2 = radius2-offset", openmm.CustomGBForce.ParticlePairNoExclusions)

    # Effective Born radius from I, using the selected OBC rescaling.
    custom.addComputedValue("B", born_radius_expressions[igb], openmm.CustomGBForce.SingleParticle)

    # Single-particle terms: nonpolar surface-area term plus Born self-energy, both scaled by lambda.
    # 138.935485 is the Coulomb constant in kJ nm/(mol e^2).
    custom.addEnergyTerm("lambda*28.3919551*(radius+0.14)^2*(radius/B)^6-lambda*0.5*138.935485*(1/soluteDielectric-1/solventDielectric)*q^2/B", openmm.CustomGBForce.SingleParticle)
    # Pairwise polarization energy with the Still f_GB interpolation, scaled by lambda1*lambda2.
    custom.addEnergyTerm("-138.935485*lambda1*lambda2*(1/soluteDielectric-1/solventDielectric)*q1*q2/f;"
                         "f=sqrt(r^2+B1*B2*exp(-r^2/(4*B1*B2)))", openmm.CustomGBForce.ParticlePairNoExclusions)

    return custom

def build_softcore_system(reference_system, receptor_atoms, ligand_atoms, valence_lambda, coulomb_lambda, vdw_lambda, annihilate=False):
    """
    Build alchemically-modified system where ligand is decoupled or annihilated using *SoftcoreForce classes.

    ARGUMENTS

    reference_system (openmm.System) - system to copy; it is only read, never modified
    receptor_atoms (list of int) - receptor atom indices (accepted for interface symmetry; not used here)
    ligand_atoms (iterable of int) - indices of the atoms to be alchemically modified
    valence_lambda (float) - scales force constants of bonds, angles, and torsions whose atoms are all
                             ligand atoms; applied only when annihilate is True
    coulomb_lambda (float) - scales ligand charges in NonbondedSoftcoreForce and GBSAOBCSoftcoreForce, and is
                             the per-particle softcore parameter of ligand atoms in GBSAOBCSoftcoreForce
    vdw_lambda (float) - scales ligand Lennard-Jones epsilons, and is the softcore lambda of ligand atoms and
                         of intra-ligand exceptions in NonbondedSoftcoreForce
    annihilate (bool) - if True, intra-ligand valence terms and intra-ligand exception chargeprod/epsilon are
                        also scaled (annihilation); if False they keep their reference values (decoupling)

    RETURNS

    system (openmm.System) - new system with the masses, constraints, and default box vectors of
                             reference_system. Only HarmonicBondForce, HarmonicAngleForce, PeriodicTorsionForce,
                             NonbondedForce (converted to NonbondedSoftcoreForce), and GBSAOBCForce (converted to
                             GBSAOBCSoftcoreForce) are carried over; any other force is dropped.
    """

    # Membership in ligand_atoms is tested for every particle, bonded term, and exception below;
    # a set makes each test O(1) instead of a scan over the whole ligand list.
    ligand_atoms = set(ligand_atoms)

    # Create new system.
    system = openmm.System()

    # Set periodic box vectors.
    [a,b,c] = reference_system.getDefaultPeriodicBoxVectors()
    system.setDefaultPeriodicBoxVectors(a,b,c)

    # Add atoms.
    for atom_index in range(reference_system.getNumParticles()):
        mass = reference_system.getParticleMass(atom_index)
        system.addParticle(mass)

    # Add constraints
    for constraint_index in range(reference_system.getNumConstraints()):
        [iatom, jatom, r0] = reference_system.getConstraintParameters(constraint_index)
        system.addConstraint(iatom, jatom, r0)

    # Perturb force terms.
    for force_index in range(reference_system.getNumForces()):
        reference_force = reference_system.getForce(force_index)
        # Dispatch forces
        if isinstance(reference_force, openmm.HarmonicBondForce):
            # HarmonicBondForce
            force = openmm.HarmonicBondForce()
            for bond_index in range(reference_force.getNumBonds()):
                # Retrieve parameters.
                [iatom, jatom, r0, K] = reference_force.getBondParameters(bond_index)
                # Annihilate if directed.
                if annihilate and (iatom in ligand_atoms) and (jatom in ligand_atoms):
                    K *= valence_lambda
                # Add bond parameters.
                force.addBond(iatom, jatom, r0, K)
            # Add force to new system.
            system.addForce(force)
        elif isinstance(reference_force, openmm.HarmonicAngleForce):
            # HarmonicAngleForce
            force = openmm.HarmonicAngleForce()
            for angle_index in range(reference_force.getNumAngles()):
                # Retrieve parameters.
                [iatom, jatom, katom, theta0, Ktheta] = reference_force.getAngleParameters(angle_index)
                # Annihilate if directed:
                if annihilate and (iatom in ligand_atoms) and (jatom in ligand_atoms) and (katom in ligand_atoms):
                    Ktheta *= valence_lambda
                # Add parameters.
                force.addAngle(iatom, jatom, katom, theta0, Ktheta)
            # Add force to system.
            system.addForce(force)
        elif isinstance(reference_force, openmm.PeriodicTorsionForce):
            # PeriodicTorsionForce
            force = openmm.PeriodicTorsionForce()
            for torsion_index in range(reference_force.getNumTorsions()):
                # Retrieve parameters.
                [particle1, particle2, particle3, particle4, periodicity, phase, k] = reference_force.getTorsionParameters(torsion_index)
                # Annihilate if directed:
                if annihilate and (particle1 in ligand_atoms) and (particle2 in ligand_atoms) and (particle3 in ligand_atoms) and (particle4 in ligand_atoms):
                    k *= valence_lambda
                # Add parameters.
                force.addTorsion(particle1, particle2, particle3, particle4, periodicity, phase, k)
            # Add force to system.
            system.addForce(force)
        elif isinstance(reference_force, openmm.NonbondedForce):
            # NonbondedForce
            force = openmm.NonbondedSoftcoreForce()
            for particle_index in range(reference_force.getNumParticles()):
                # Retrieve parameters.
                [charge, sigma, epsilon] = reference_force.getParticleParameters(particle_index)
                # Alchemically modify parameters.
                if particle_index in ligand_atoms:
                    charge *= coulomb_lambda
                    epsilon *= vdw_lambda
                    # Add modified particle parameters.
                    force.addParticle(charge, sigma, epsilon, vdw_lambda)
                else:
                    # Add unmodified particle parameters.
                    force.addParticle(charge, sigma, epsilon, 1.0)
            for exception_index in range(reference_force.getNumExceptions()):
                # Retrieve parameters.
                [iatom, jatom, chargeprod, sigma, epsilon] = reference_force.getExceptionParameters(exception_index)
                # Alchemically modify epsilon and chargeprod.
                if (iatom in ligand_atoms) and (jatom in ligand_atoms):
                    if annihilate:
                        epsilon *= vdw_lambda
                        chargeprod *= coulomb_lambda
                    # Intra-ligand exceptions always get softcore lambda vdw_lambda; only the
                    # chargeprod/epsilon scaling above depends on annihilate.
                    force.addException(iatom, jatom, chargeprod, sigma, epsilon, vdw_lambda)
                else:
                    # Add unmodified exception parameters.
                    force.addException(iatom, jatom, chargeprod, sigma, epsilon, 1.0)
            # Set parameters.
            force.setNonbondedMethod( reference_force.getNonbondedMethod() )
            force.setCutoffDistance( reference_force.getCutoffDistance() )
            force.setReactionFieldDielectric( reference_force.getReactionFieldDielectric() )
            force.setEwaldErrorTolerance( reference_force.getEwaldErrorTolerance() )
            # Add force to new system.
            system.addForce(force)

        elif isinstance(reference_force, openmm.GBSAOBCForce):
            # GBSAOBCForce
            force = openmm.GBSAOBCSoftcoreForce()

            force.setSolventDielectric( reference_force.getSolventDielectric() )
            force.setSoluteDielectric( reference_force.getSoluteDielectric() )

            for particle_index in range(reference_force.getNumParticles()):
                # Retrieve parameters.
                [charge, radius, scaling_factor] = reference_force.getParticleParameters(particle_index)
                # Alchemically modify parameters.
                if particle_index in ligand_atoms:
                    # Scale charge and contribution to GB integrals.
                    force.addParticle(charge*coulomb_lambda, radius, scaling_factor, coulomb_lambda)
                else:
                    # Don't modulate GB.
                    force.addParticle(charge, radius, scaling_factor, 1.0)

            # Add force to new system.
            system.addForce(force)
        else:
            # Don't add unrecognized forces.
            pass

    return system

def build_custom_system(reference_system, receptor_atoms, ligand_atoms, valence_lambda, coulomb_lambda, vdw_lambda, annihilate=False):
    """
    Build alchemically-modified system where ligand is decoupled or annihilated using Custom*Force classes.

    Electrostatics and all exceptions stay in a standard NonbondedForce (with particle Lennard-Jones
    epsilons zeroed). Particle-particle Lennard-Jones interactions move to a CustomNonbondedForce
    with a softcore potential. GBSAOBCForce is replaced by createCustomSoftcoreGBOBC(..., igb=5).

    ARGUMENTS

    reference_system (openmm.System) - system to copy; it is only read, never modified
    receptor_atoms (list of int) - receptor atom indices (accepted for interface symmetry; not used here)
    ligand_atoms (iterable of int) - indices of the atoms to be alchemically modified
    valence_lambda (float) - scales force constants of bonds, angles, and torsions whose atoms are all
                             ligand atoms; applied only when annihilate is True
    coulomb_lambda (float) - scales ligand charges in NonbondedForce and the CustomGBForce, and is the GB
                             per-particle 'lambda' of ligand atoms; also scales intra-ligand exception
                             chargeprod when annihilate is True
    vdw_lambda (float) - per-particle 'lambda' of ligand atoms in the softcore CustomNonbondedForce
                         (pair lambda = lambda1*lambda2; ligand epsilons are not rescaled there); also
                         scales intra-ligand exception epsilon when annihilate is True
    annihilate (bool) - if True, intra-ligand valence terms and exceptions are also scaled (annihilation);
                        if False they keep their reference values (decoupling)

    RETURNS

    system (openmm.System) - new system with the masses, constraints, and default box vectors of
                             reference_system. Only harmonic bond/angle, periodic torsion, nonbonded, and
                             GBSAOBC forces are carried over; any other force is dropped.
    """

    # Membership in ligand_atoms is tested for every particle, bonded term, and exception below;
    # a set makes each test O(1) instead of a scan over the whole ligand list.
    ligand_atoms = set(ligand_atoms)

    # Create new system.
    system = openmm.System()

    # Set periodic box vectors.
    [a,b,c] = reference_system.getDefaultPeriodicBoxVectors()
    system.setDefaultPeriodicBoxVectors(a,b,c)

    # Add atoms.
    for atom_index in range(reference_system.getNumParticles()):
        mass = reference_system.getParticleMass(atom_index)
        system.addParticle(mass)

    # Add constraints
    for constraint_index in range(reference_system.getNumConstraints()):
        [iatom, jatom, r0] = reference_system.getConstraintParameters(constraint_index)
        system.addConstraint(iatom, jatom, r0)

    # Perturb force terms.
    for force_index in range(reference_system.getNumForces()):
        reference_force = reference_system.getForce(force_index)
        # Dispatch forces
        if isinstance(reference_force, openmm.HarmonicBondForce):
            # HarmonicBondForce
            force = openmm.HarmonicBondForce()
            for bond_index in range(reference_force.getNumBonds()):
                # Retrieve parameters.
                [iatom, jatom, r0, K] = reference_force.getBondParameters(bond_index)
                # Annihilate if directed.
                if annihilate and (iatom in ligand_atoms) and (jatom in ligand_atoms):
                    K *= valence_lambda
                # Add bond parameters.
                force.addBond(iatom, jatom, r0, K)
            # Add force to new system.
            system.addForce(force)
        elif isinstance(reference_force, openmm.HarmonicAngleForce):
            # HarmonicAngleForce
            force = openmm.HarmonicAngleForce()
            for angle_index in range(reference_force.getNumAngles()):
                # Retrieve parameters.
                [iatom, jatom, katom, theta0, Ktheta] = reference_force.getAngleParameters(angle_index)
                # Annihilate if directed:
                if annihilate and (iatom in ligand_atoms) and (jatom in ligand_atoms) and (katom in ligand_atoms):
                    Ktheta *= valence_lambda
                # Add parameters.
                force.addAngle(iatom, jatom, katom, theta0, Ktheta)
            # Add force to system.
            system.addForce(force)
        elif isinstance(reference_force, openmm.PeriodicTorsionForce):
            # PeriodicTorsionForce
            force = openmm.PeriodicTorsionForce()
            for torsion_index in range(reference_force.getNumTorsions()):
                # Retrieve parameters.
                [particle1, particle2, particle3, particle4, periodicity, phase, k] = reference_force.getTorsionParameters(torsion_index)
                # Annihilate if directed:
                if annihilate and (particle1 in ligand_atoms) and (particle2 in ligand_atoms) and (particle3 in ligand_atoms) and (particle4 in ligand_atoms):
                    k *= valence_lambda
                # Add parameters.
                force.addTorsion(particle1, particle2, particle3, particle4, periodicity, phase, k)
            # Add force to system.
            system.addForce(force)
        elif isinstance(reference_force, openmm.NonbondedForce):
            # NonbondedForce will handle charges and exception interactions.
            nonbonded_force = openmm.NonbondedForce()
            for particle_index in range(reference_force.getNumParticles()):
                # Retrieve parameters.
                [charge, sigma, epsilon] = reference_force.getParticleParameters(particle_index)
                # Remove Lennard-Jones interactions, which will be handled by CustomNonbondedForce.
                epsilon *= 0.0
                # Alchemically modify charges.
                if particle_index in ligand_atoms:
                    charge *= coulomb_lambda
                # Add modified particle parameters.
                nonbonded_force.addParticle(charge, sigma, epsilon)
            for exception_index in range(reference_force.getNumExceptions()):
                # Retrieve parameters.
                [iatom, jatom, chargeprod, sigma, epsilon] = reference_force.getExceptionParameters(exception_index)
                # Alchemically modify epsilon and chargeprod.
                # Note that exceptions are handled by NonbondedForce and not CustomNonbondedForce.
                if (iatom in ligand_atoms) and (jatom in ligand_atoms):
                    if annihilate:
                        epsilon *= vdw_lambda
                        chargeprod *= coulomb_lambda
                # Add modified exception parameters.
                nonbonded_force.addException(iatom, jatom, chargeprod, sigma, epsilon)
            # Set parameters.
            nonbonded_force.setNonbondedMethod( reference_force.getNonbondedMethod() )
            nonbonded_force.setCutoffDistance( reference_force.getCutoffDistance() )
            nonbonded_force.setReactionFieldDielectric( reference_force.getReactionFieldDielectric() )
            nonbonded_force.setEwaldErrorTolerance( reference_force.getEwaldErrorTolerance() )
            # Add force to new system.
            system.addForce(nonbonded_force)

            # CustomNonbondedForce
            # Softcore Lennard-Jones potential; pair lambda is the product of per-particle lambdas.
            energy_expression = "4*epsilon*lambda*x*(x-1.0);"
            energy_expression += "x = 1.0/(alpha*(1.0-lambda) + (r/sigma)^6);"
            energy_expression += "epsilon = sqrt(epsilon1*epsilon2);"
            energy_expression += "sigma = 0.5*(sigma1 + sigma2);"
            energy_expression += "lambda = lambda1*lambda2;"

            softcore_force = openmm.CustomNonbondedForce(energy_expression)
            alpha = 0.5 # softcore parameter
            softcore_force.addGlobalParameter("alpha", alpha)
            softcore_force.addPerParticleParameter("sigma")
            softcore_force.addPerParticleParameter("epsilon")
            softcore_force.addPerParticleParameter("lambda")
            for particle_index in range(reference_force.getNumParticles()):
                # Retrieve parameters.
                [charge, sigma, epsilon] = reference_force.getParticleParameters(particle_index)
                # Alchemically modify parameters (only lambda changes; epsilon is left unscaled).
                if particle_index in ligand_atoms:
                    softcore_force.addParticle([sigma, epsilon, vdw_lambda])
                else:
                    softcore_force.addParticle([sigma, epsilon, 1.0])
            for exception_index in range(reference_force.getNumExceptions()):
                # Retrieve parameters.
                [iatom, jatom, chargeprod, sigma, epsilon] = reference_force.getExceptionParameters(exception_index)
                # All exceptions are handled by NonbondedForce, so we exclude all these here.
                softcore_force.addExclusion(iatom, jatom)
            # CustomNonbondedForce has no Ewald/PME; fall back to a periodic cutoff for the LJ part.
            if reference_force.getNonbondedMethod() in [openmm.NonbondedForce.Ewald, openmm.NonbondedForce.PME]:
                softcore_force.setNonbondedMethod( openmm.CustomNonbondedForce.CutoffPeriodic )
            else:
                softcore_force.setNonbondedMethod( reference_force.getNonbondedMethod() )
            softcore_force.setCutoffDistance( reference_force.getCutoffDistance() )
            system.addForce(softcore_force)

        elif isinstance(reference_force, openmm.GBSAOBCForce):
            # GBSAOBCForce
            solvent_dielectric = reference_force.getSolventDielectric()
            solute_dielectric = reference_force.getSoluteDielectric()
            force = createCustomSoftcoreGBOBC(solvent_dielectric, solute_dielectric, igb=5)
            for particle_index in range(reference_force.getNumParticles()):
                # Retrieve parameters.
                [charge, radius, scaling_factor] = reference_force.getParticleParameters(particle_index)
                # Alchemically modify parameters.
                if particle_index in ligand_atoms:
                    # Scale charge and contribution to GB integrals.
                    force.addParticle([charge*coulomb_lambda, radius, scaling_factor, coulomb_lambda])
                else:
                    # Don't modulate GB.
                    force.addParticle([charge, radius, scaling_factor, 1.0])

            # Add force to new system.
            system.addForce(force)
        else:
            # Don't add unrecognized forces.
            pass

    return system

#=============================================================================================
# MAIN AND TESTS
#=============================================================================================

if __name__ == "__main__":
    verbose = True

    # PARAMETERS
    complex_prmtop_filename = 'system.prmtop'
    complex_crd_filename = 'system.crd'

    pressure = 1.0 * units.atmosphere
    temperature = 300.0 * units.kelvin
    barostat_frequency = 10
    nsteps = 2500

    # Specify which CPUs should be attached to specific GPUs for maximum performance.
    cpu_platform_name = 'Reference'
    gpu_platform_name = 'OpenCL'
    #cpuid_gpuid_mapping = { 0:0, 1:1, 2:0, 3:1, 8:2, 9:3, 10:4, 11:5, 12:2, 13:3, 14:5, 15:5 } # cpuid:gpuid for NCSA Forge (doubled up)
    cpuid_gpuid_mapping = { 0:0, 1:1, 8:2, 9:3, 10:4, 11:5 } # cpuid:gpuid for NCSA Forge

    # Initialize MPI, if available.    
    try:
        # Initialize MPI.
        # Set up device to bind to.
        from mpi4py import MPI # MPI wrapper
        hostname = os.uname()[1]

        # Turn off output from non-root nodes:
        if not (MPI.COMM_WORLD.rank==0):
            verbose = False

        # Make sure random number generators have unique seeds.
        seed = numpy.random.randint(sys.maxint - MPI.COMM_WORLD.size) + MPI.COMM_WORLD.rank
        numpy.random.seed(seed)

        # Choose appropriate platform for each device.
        cpuid = MPI.COMM_WORLD.rank # use default rank as CPUID (TODO: Improve this)
        #print "node '%s' MPI_WORLD rank %d/%d" % (hostname, MPI.COMM_WORLD.rank, MPI.COMM_WORLD.size)
        if cpuid in cpuid_gpuid_mapping.keys():
            platform = openmm.Platform.getPlatformByName(gpu_platform_name)
            deviceid = cpuid_gpuid_mapping[cpuid]
            platform.setPropertyDefaultValue('OpenCLDeviceIndex', '%d' % deviceid) # select OpenCL device index
            platform.setPropertyDefaultValue('CudaDeviceIndex', '%d' % deviceid) # select Cuda device index
            print "node '%s' MPI_WORLD rank %d/%d cpuid %d platform %s deviceid %d" % (hostname, MPI.COMM_WORLD.rank, MPI.COMM_WORLD.size, cpuid, gpu_platform_name, deviceid)
        else:
            platform = openmm.Platform.getPlatformByName(cpu_platform_name)

        # Set up CPU and GPU communicators.
        gpu_process_list = filter(lambda x : x < MPI.COMM_WORLD.size, cpuid_gpuid_mapping.keys())
        #gpu_process_list = gpu_process_list[0:2] # DEBUG (limit to specified number of GPUs)
        print gpu_process_list
        all_group = MPI.COMM_WORLD.Get_group()         
        gpu_group = MPI.Group.Incl(all_group, gpu_process_list)
        cpu_group = MPI.Group.Excl(all_group, gpu_process_list)
        # gpu_comm = MPI.COMM_WORLD.Create(gpu_group)
        # cpu_comm = MPI.COMM_WORLD.Create(cpu_group)

        if cpuid in gpu_process_list:
            color = 0 # GPU
        else:
            color = 1 # CPU    
        comm = MPI.COMM_WORLD.Split(color=color)

        # DEBUG
        #print "node '%s' MPI_WORLD rank %d/%d gpu_comm rank %d/%d cpu_comm rank %d/%d" % (hostname, MPI.COMM_WORLD.rank, MPI.COMM_WORLD.size, gpu_comm.rank, gpu_comm.size, cpu_comm.rank, cpu_comm.size)
        print "node '%s' MPI_WORLD rank %d/%d comm rank %d/%d" % (hostname, MPI.COMM_WORLD.rank, MPI.COMM_WORLD.size, comm.rank, comm.size)

        # Use just the GPU communicator.
        # DEBUG: Does this work?  Or do we have to take two branches?
        #comm = gpu_comm
        
    except Exception as e:
        print e
        print "WARNING: Could not initialize MPI; falling back to serial execution."
        platform = openmm.Platform.getPlatformByName(gpu_platform_name)
        comm = None
        verbose = True

    # Load standard systems.
    import simtk.pyopenmm.amber.amber_file_parser as amber
    cutoff = 7.0 * units.angstrom
    reference_system = amber.readAmberSystem(complex_prmtop_filename, mm=openmm, gbmodel=None, nonbondedCutoff=cutoff, nonbondedMethod='CutoffPeriodic', shake='h-bonds')
    coordinates = amber.readAmberCoordinates(complex_crd_filename)
    coordinates = units.Quantity(numpy.array(coordinates / coordinates.unit), coordinates.unit) # make into numpy

    receptor_atoms = range(12,reference_system.getNumParticles()) # solvent
    ligand_atoms = range(0,12) # solute
    
    # Construct alchemical systems.
    if verbose: print "Constructing alchemical states..."
    systems = list() # alchemically-modified systems
    valence_lambda = numpy.array([1.0, 1.00, 1.0, 1.00, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0])
    coulomb_lambda = numpy.array([1.0, 0.75, 0.5, 0.25, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0])    
    vdw_lambda     = numpy.array([1.0, 1.00, 1.0, 1.00, 1.0, 0.8, 0.6, 0.4, 0.3, 0.2, 0.1, 0.0])
    nlambda = len(coulomb_lambda)
    for lambda_index in range(nlambda):
        # Create alchemically-modified state.
        system = build_custom_system(reference_system, receptor_atoms, ligand_atoms, valence_lambda[lambda_index], coulomb_lambda[lambda_index], vdw_lambda[lambda_index], annihilate=True)

        # Append system.
        systems.append(system)

    # Set up reference thermodynamic state.
    import thermodynamics
    reference_state = thermodynamics.ThermodynamicState(systems[0], temperature, pressure)

    # Create replica-exchange simulation.
    if verbose: print "Setting up replica-exchange simulation..."
    output_filename = 'repex.nc'
    #simulation = repex.ReplicaExchange(states, [coordinates], output_filename, mpicomm=comm) # initialize the replica-exchange simulation
    simulation = repex.HamiltonianExchange(reference_state, systems, [coordinates], output_filename, mpicomm=comm) # initialize the replica-exchange simulation
    simulation.verbose = True
    simulation.number_of_iterations = 1000
    simulation.timestep = 2.0 * units.femtoseconds # set the timestep for integration
    simulation.nsteps_per_iteration = nsteps
    simulation.minimize = False
    simulation.show_mixing_statistics = True
    simulation.number_of_equilibration_iterations = 0
    simulation.platform = platform
    simulation.replica_mixing_scheme = 'none'
    #simulation.replica_mixing_scheme = 'swap-all'    
    
    if verbose: print "Running..."
    if comm:
        # Only GPU nodes run simulation.
        if cpuid in gpu_process_list:   
            simulation.run()
        # Wait for all nodes to finish
        MPI.COMM_WORLD.Barrier()            
    else:
        simulation.run() # run the simulation
                                

