#!/usr/bin/env python

#=============================================================================================
# Test the effect of NonbondedForce.setUseDispersionCorrection() on the potential energy computed by each OpenMM platform
#=============================================================================================

import simtk.openmm as openmm
import simtk.unit as units

import numpy

#=============================================================================================
# Periodic box of Lennard-Jones particles
#=============================================================================================

def LennardJonesFluid(mm=None, nx=6, ny=6, nz=6, mass=None, sigma=None, epsilon=None, cutoff=None, switch=False, dispersion_correction=True):
    """
    Create a periodic rectilinear grid of Lennard-Jones particles.

    DESCRIPTION

    Parameters for argon are used by default.
    Cutoff is set to 2.5 sigma by default.

    Particles are placed on a rectilinear grid with spacing sigma, starting at the
    origin. The (rectangular) periodic box extends two grid spacings beyond the
    outermost particle along each axis. Nonbonded interactions use a periodic cutoff.

    OPTIONAL ARGUMENTS

    mm (implements simtk.openmm) - mm implementation to use (default: simtk.openmm)
    nx, ny, nz (int) - number of atoms to initially place on grid in each dimension (default: 6)
    mass (simtk.unit.Quantity with units of mass) - particle masses (default: 39.9 amu)
    sigma (simtk.unit.Quantity with units of length) - Lennard-Jones sigma parameter (default: 3.4 A)
    epsilon (simtk.unit.Quantity with units of energy) - Lennard-Jones well depth (default: 0.238 kcal/mol)
    cutoff (simtk.unit.Quantity with units of length) - cutoff for nonbonded interactions (default: 2.5 * sigma)
    switch (boolean) - request a switching function; not supported by this implementation, must be False (default: False)
    dispersion_correction (boolean) - if True, will use analytical dispersion correction (default: True)

    RETURNS

    system (simtk.openmm.System) - the periodic Lennard-Jones system
    coordinates (simtk.unit.Quantity wrapping an (nx*ny*nz, 3) numpy float32 array, in angstroms) - initial particle positions

    RAISES

    NotImplementedError - if switch is True

    EXAMPLES

    Create default-size Lennard-Jones fluid.

    >>> [system, coordinates] = LennardJonesFluid()

    Create a larger 10x8x5 box of Lennard-Jones particles.

    >>> [system, coordinates] = LennardJonesFluid(nx=10, ny=8, nz=5)

    """

    # The switching path previously called NonbondedForce.addParticle([]), which is not a
    # valid signature (NonbondedForce.addParticle takes charge, sigma, epsilon) and so
    # always failed partway through building the system. Fail fast with a clear message.
    if switch:
        raise NotImplementedError("switch=True is not supported: NonbondedForce is used without a switching function")

    # Use pyOpenMM by default.
    if mm is None:
        mm = openmm

    # Default parameters (argon).
    if mass is None: mass = 39.9 * units.amu
    if sigma is None: sigma = 3.4 * units.angstrom
    if epsilon is None: epsilon = 0.238 * units.kilocalories_per_mole
    charge = 0.0 * units.elementary_charge
    if cutoff is None: cutoff = 2.5 * sigma

    # Grid spacing along each axis, as a multiple of sigma.
    scaleStepSizeX = 1.0
    scaleStepSizeY = 1.0
    scaleStepSizeZ = 1.0

    # Determine total number of atoms.
    natoms = nx * ny * nz

    # Create an empty system object.
    system = mm.System()

    # Set up periodic nonbonded interactions with a cutoff.
    nb = mm.NonbondedForce()
    nb.setNonbondedMethod(mm.NonbondedForce.CutoffPeriodic)
    nb.setCutoffDistance(cutoff)
    nb.setUseDispersionCorrection(dispersion_correction)

    coordinates = units.Quantity(numpy.zeros([natoms,3],numpy.float32), units.angstrom)

    # Largest coordinate placed along each axis; used below to size the periodic box.
    maxX = 0.0 * units.angstrom
    maxY = 0.0 * units.angstrom
    maxZ = 0.0 * units.angstrom

    atom_index = 0
    for ii in range(nx):
        for jj in range(ny):
            for kk in range(nz):
                system.addParticle(mass)
                nb.addParticle(charge, sigma, epsilon)
                x = sigma*scaleStepSizeX*ii
                y = sigma*scaleStepSizeY*jj
                z = sigma*scaleStepSizeZ*kk

                coordinates[atom_index,0] = x
                coordinates[atom_index,1] = y
                coordinates[atom_index,2] = z
                atom_index += 1

                # Track the extent of the grid (no wrapping is performed here).
                if x>maxX: maxX = x
                if y>maxY: maxY = y
                if z>maxZ: maxZ = z

    # Set periodic box vectors: pad two grid spacings past the outermost particle.
    x = maxX+2*sigma*scaleStepSizeX
    y = maxY+2*sigma*scaleStepSizeY
    z = maxZ+2*sigma*scaleStepSizeZ

    a = units.Quantity((x,                0*units.angstrom, 0*units.angstrom))
    b = units.Quantity((0*units.angstrom,                y, 0*units.angstrom))
    c = units.Quantity((0*units.angstrom, 0*units.angstrom, z))
    system.setDefaultPeriodicBoxVectors(a, b, c)

    # Add the nonbonded force.
    system.addForce(nb)

    return (system, coordinates)

def compute_energy(system, coordinates, platform_name='OpenCL'):
    """
    Compute, report, and return the potential energy of a system on one platform.

    ARGUMENTS

    system (simtk.openmm.System) - system for which energy is to be computed
    coordinates (simtk.unit.Quantity with units of length) - the particle positions to use in computing the energy

    OPTIONAL ARGUMENTS

    platform_name (string) - the name of the OpenMM platform to use (default: 'OpenCL')

    RETURNS

    potential (simtk.unit.Quantity with units of energy) - the potential energy, also printed in kcal/mol

    RAISES

    An exception from openmm.Platform.getPlatformByName if no platform named
    platform_name is registered.

    """

    # Look up the requested compute platform.
    platform = openmm.Platform.getPlatformByName(platform_name)

    # A Context requires an integrator; it is never stepped here, so the timestep is immaterial.
    timestep = 1.0 * units.femtoseconds
    integrator = openmm.VerletIntegrator(timestep)
    context = openmm.Context(system, integrator, platform)
    context.setPositions(coordinates)
    state = context.getState(getEnergy=True)
    potential = state.getPotentialEnergy()
    # Single-argument print(...) produces identical output under Python 2 and Python 3.
    print("platform %12s : %24.8f kcal/mol" % (platform_name, potential / units.kilocalories_per_mole))
    return potential

if __name__ == "__main__":

    # Platforms to compare.
    # NOTE(review): newer OpenMM releases register the CUDA platform as 'CUDA' rather
    # than 'Cuda' -- confirm against the installed OpenMM version.
    platform_names = ['Reference', 'OpenCL', 'Cuda']

    # Names of the platforms actually registered with this OpenMM installation, so that a
    # missing (e.g. GPU) platform is skipped instead of aborting the whole comparison.
    available_platforms = set(openmm.Platform.getPlatform(index).getName()
                              for index in range(openmm.Platform.getNumPlatforms()))

    # Each case: (dispersion correction flag, banner printed before computing energies).
    cases = [
        (False, "Creating Lennard-Jones fluid system without dispersion correction..."),
        (True, "Creating Lennard-Jones fluid system with dispersion correction..."),
        ]

    for dispersion_correction, banner in cases:
        print(banner)
        [system, coordinates] = LennardJonesFluid(dispersion_correction=dispersion_correction)
        for platform_name in platform_names:
            if platform_name not in available_platforms:
                print("platform %12s : not available, skipping" % platform_name)
                continue
            compute_energy(system, coordinates, platform_name)
        print("")

