#!/usr/bin/env python

#=============================================================================================
# Test use of CustomNonbondedForce to take over the Lennard-Jones interactions of a
# subset of particles in a Lennard-Jones fluid, comparing energies before and after.
#=============================================================================================

import simtk.openmm as openmm
import simtk.unit as units

import numpy

#=============================================================================================
# Periodic box of Lennard-Jones particles
#=============================================================================================

def LennardJonesFluid(mm=None, nx=6, ny=6, nz=6, mass=None, sigma=None, epsilon=None, cutoff=None, dispersion_correction=True):
    """
    Create a periodic rectilinear grid of Lennard-Jones particles.

    DESCRIPTION

    Parameters for argon are used by default.
    Cutoff is set to 2.5 sigma by default.
    Particles are uncharged and are placed on a rectilinear lattice with spacing sigma,
    starting at the origin. The periodic box edge in each dimension is the lattice
    extent in that dimension plus 2 sigma.

    OPTIONAL ARGUMENTS

    mm (implements simtk.openmm) - mm implementation to use (default: simtk.openmm)
    nx, ny, nz (int) - number of atoms to initially place on grid in each dimension (default: 6)
    mass (simtk.unit.Quantity with units of mass) - particle masses (default: 39.9 amu)
    sigma (simtk.unit.Quantity with units of length) - Lennard-Jones sigma parameter (default: 3.4 A)
    epsilon (simtk.unit.Quantity with units of energy) - Lennard-Jones well depth (default: 0.238 kcal/mol)
    cutoff (simtk.unit.Quantity with units of length) - cutoff for nonbonded interactions (default: 2.5 * sigma)
    dispersion_correction (boolean) - if True, enable the NonbondedForce analytical dispersion correction (default: True)

    RETURNS

    system (System from mm) - system with nx*ny*nz particles and a single periodic-cutoff NonbondedForce
    coordinates (simtk.unit.Quantity wrapping a float32 numpy array of shape [nx*ny*nz, 3], in angstroms) - initial positions

    EXAMPLES

    Create default-size Lennard-Jones fluid.

    >>> [system, coordinates] = LennardJonesFluid()

    Create a larger 10x8x5 box of Lennard-Jones particles.

    >>> [system, coordinates] = LennardJonesFluid(nx=10, ny=8, nz=5)

    """

    # Use pyOpenMM by default.
    if mm is None:
        mm = openmm

    # Default parameters
    if mass is None: mass = 39.9 * units.amu # argon
    if sigma is None: sigma = 3.4 * units.angstrom # argon
    if epsilon is None: epsilon = 0.238 * units.kilocalories_per_mole # argon
    charge        = 0.0 * units.elementary_charge
    if cutoff is None: cutoff = 2.5 * sigma

    # Lattice spacing in each dimension, in multiples of sigma.
    scaleStepSizeX = 1.0
    scaleStepSizeY = 1.0
    scaleStepSizeZ = 1.0

    # Determine total number of atoms.
    natoms = nx * ny * nz

    # Create an empty system object.
    system = mm.System()

    # Set up periodic nonbonded interactions with a cutoff.
    nb = mm.NonbondedForce()
    nb.setNonbondedMethod(mm.NonbondedForce.CutoffPeriodic)
    nb.setCutoffDistance(cutoff)
    nb.setUseDispersionCorrection(dispersion_correction)

    coordinates = units.Quantity(numpy.zeros([natoms,3],numpy.float32), units.angstrom)

    # Largest lattice coordinate seen in each dimension; used to size the box.
    maxX = 0.0 * units.angstrom
    maxY = 0.0 * units.angstrom
    maxZ = 0.0 * units.angstrom

    atom_index = 0
    for ii in range(nx):
        for jj in range(ny):
            for kk in range(nz):
                system.addParticle(mass)
                nb.addParticle(charge, sigma, epsilon)
                x = sigma*scaleStepSizeX*ii
                y = sigma*scaleStepSizeY*jj
                z = sigma*scaleStepSizeZ*kk

                coordinates[atom_index,0] = x
                coordinates[atom_index,1] = y
                coordinates[atom_index,2] = z
                atom_index += 1

                # Track the extent of the lattice (no wrapping is performed here).
                if x>maxX: maxX = x
                if y>maxY: maxY = y
                if z>maxZ: maxZ = z

    # Set periodic box vectors: lattice extent plus two lattice steps of padding.
    x = maxX+2*sigma*scaleStepSizeX
    y = maxY+2*sigma*scaleStepSizeY
    z = maxZ+2*sigma*scaleStepSizeZ

    a = units.Quantity((x,                0*units.angstrom, 0*units.angstrom))
    b = units.Quantity((0*units.angstrom,                y, 0*units.angstrom))
    c = units.Quantity((0*units.angstrom, 0*units.angstrom, z))
    system.setDefaultPeriodicBoxVectors(a, b, c)

    # Add the nonbonded force.
    system.addForce(nb)

    return (system, coordinates)

def compute_energy(system, coordinates, platform_name='OpenCL'):
    """
    Compute, report, and return the potential energy of a system on one platform.

    ARGUMENTS

    system (simtk.openmm.System) - system for which energy is to be computed
    coordinates - the coordinates to use in computing the energy

    OPTIONAL ARGUMENTS

    platform_name (string) - the name of the OpenMM platform to use (default: 'OpenCL')

    RETURNS

    potential (simtk.unit.Quantity with units of energy) - the potential energy,
    which is also printed in kcal/mol

    """

    # Look up the requested OpenMM platform by name.
    platform = openmm.Platform.getPlatformByName(platform_name)

    # A Context requires an integrator; it is never stepped here, so the
    # time step does not influence the computed energy.
    timestep = 1.0 * units.femtoseconds
    integrator = openmm.VerletIntegrator(timestep)
    context = openmm.Context(system, integrator, platform)
    context.setPositions(coordinates)
    state = context.getState(getEnergy=True)
    potential = state.getPotentialEnergy()
    print("platform %12s : %24.8f kcal/mol" % (platform_name, potential / units.kilocalories_per_mole))
    return potential

if __name__ == "__main__":

    # Platforms whose energies are compared. Only platforms registered with this
    # OpenMM installation are used, because getPlatformByName() fails for any
    # other name and would abort the whole script.
    # NOTE(review): newer OpenMM releases may name the CUDA platform 'CUDA'
    # rather than 'Cuda' -- confirm against the installed version.
    requested_platform_names = ['Reference', 'OpenCL', 'Cuda']
    registered_platform_names = set(
        openmm.Platform.getPlatform(platform_index).getName()
        for platform_index in range(openmm.Platform.getNumPlatforms()))
    platform_names = []
    for platform_name in requested_platform_names:
        if platform_name in registered_platform_names:
            platform_names.append(platform_name)
        else:
            print("platform %12s : not available, skipping" % platform_name)

    print("Lennard-Jones fluid system without dispersion correction...")
    [system, coordinates] = LennardJonesFluid(dispersion_correction=False)
    for platform_name in platform_names:
        compute_energy(system, coordinates, platform_name)
    print("")

    # Move the Lennard-Jones interactions of the particles listed below out of the
    # NonbondedForce and into an equivalent CustomNonbondedForce. The energies
    # printed afterwards should match those printed above.
    alchemical_atom_indices = range(0,10) # indices of particles to replace
    alchemical_atom_set = set(alchemical_atom_indices) # built once for fast membership tests
    energy_expression = "4*lambda*epsilon*((sigma/r)^12 - (sigma/r)^6);"
    energy_expression += "epsilon = sqrt(epsilon1*epsilon2);" # mixing rule for epsilon
    energy_expression += "sigma = 0.5*(sigma1 + sigma2);" # mixing rule for sigma
    # lambda is 1 when at least one particle of the pair is alchemical and 0 otherwise,
    # so pairs of unmodified particles contribute nothing to the custom force.
    energy_expression += "lambda = alchemical1*(1-alchemical2) + alchemical2*(1-alchemical1) + alchemical1*alchemical2;"
    custom_nonbonded_force = openmm.CustomNonbondedForce(energy_expression)
    custom_nonbonded_force.addPerParticleParameter("sigma") # Lennard-Jones sigma
    custom_nonbonded_force.addPerParticleParameter("epsilon") # Lennard-Jones epsilon
    custom_nonbonded_force.addPerParticleParameter("alchemical") # alchemical flag: 1 if this particle is alchemically modified, 0 otherwise
    system.addForce(custom_nonbonded_force)
    # Copy particle parameters. LennardJonesFluid() adds its NonbondedForce as the
    # only force, so it is force 0 (the custom force just added is force 1).
    nonbonded_force = system.getForce(0)
    for particle_index in range(nonbonded_force.getNumParticles()):
        # Retrieve parameters.
        [charge, sigma, epsilon] = nonbonded_force.getParticleParameters(particle_index)
        # Add corresponding particle to the custom interactions.
        if particle_index in alchemical_atom_set:
            # Turn off the NonbondedForce Lennard-Jones contribution of this particle
            # (relies on the sqrt(epsilon1*epsilon2) combining rule, mirrored above).
            nonbonded_force.setParticleParameters(particle_index, charge, sigma, epsilon*0.0)
            # Add the original contribution back through the custom force.
            custom_nonbonded_force.addParticle([sigma, epsilon, 1])
        else:
            custom_nonbonded_force.addParticle([sigma, epsilon, 0])
    # Set periodicity and cutoff parameters corresponding to reference Force.
    if nonbonded_force.getNonbondedMethod() in [openmm.NonbondedForce.Ewald, openmm.NonbondedForce.PME]:
        # Convert Ewald and PME to CutoffPeriodic.
        custom_nonbonded_force.setNonbondedMethod( openmm.CustomNonbondedForce.CutoffPeriodic )
    else:
        custom_nonbonded_force.setNonbondedMethod( nonbonded_force.getNonbondedMethod() )
    custom_nonbonded_force.setCutoffDistance( nonbonded_force.getCutoffDistance() )

    print("After replacing interactions of %d particles..." % len(alchemical_atom_indices))
    for platform_name in platform_names:
        compute_energy(system, coordinates, platform_name)
    print("")

