#!/usr/bin/env python

#=============================================================================================
# MODULE DOCSTRING
#=============================================================================================

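"""
Factory functions for small OpenMM test systems.

Each factory returns a (system, positions) pair: an OpenMM System and a
simtk.unit.Quantity holding the particle positions. Systems defined here include
a periodic Lennard-Jones fluid (LennardJonesFluid), a periodic TIP3P water box
(WaterBox), and a non-interacting ideal gas (IdealGas).

Running this module as a script executes the docstring examples with doctest.
"""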

#=============================================================================================
# GLOBAL IMPORTS
#=============================================================================================

import os
import os.path
import numpy
import math

import simtk
import simtk.openmm
import simtk.unit as units

#=============================================================================================
# Periodic box of Lennard-Jones particles
#=============================================================================================

def LennardJonesFluid(mm=None, nx=6, ny=6, nz=6, mass=None, sigma=None, epsilon=None, cutoff=None, switch=True, dispersion_correction=True):
    """
    Create a periodic rectilinear grid of Lennard-Jones particles.

    DESCRIPTION

    Parameters for argon are used by default.
    Cutoff is set to 2.5 sigma by default.

    Particles are placed on a simple cubic grid with spacing sigma; the periodic
    box extends 2 sigma beyond the last grid point in each dimension.

    If switching is enabled, a CustomNonbondedForce is used and the Lennard-Jones
    energy is multiplied by

        S(r) = 1                                                            for r < switch
        S(r) = (cutoff^2 - r^2)^2 (cutoff^2 + 2 r^2 - 3 switch^2) / (cutoff^2 - switch^2)^3
                                                                            for switch <= r <= cutoff

    which equals 1 at r = switch and 0 at r = cutoff. Otherwise a NonbondedForce
    with a periodic cutoff is used.

    OPTIONAL ARGUMENTS

    mm (implements simtk.openmm) - mm implementation to use (default: simtk.openmm)
    nx, ny, nz (int) - number of atoms to initially place on grid in each dimension (default: 6)
    mass (simtk.unit.Quantity with units of mass) - particle masses (default: 39.9 amu)
    sigma (simtk.unit.Quantity with units of length) - Lennard-Jones sigma parameter, also the grid spacing (default: 3.4 A)
    epsilon (simtk.unit.Quantity with units of energy) - Lennard-Jones well depth (default: 0.238 kcal/mol)
    cutoff (simtk.unit.Quantity with units of length) - cutoff for nonbonded interactions (default: 2.5 * sigma)
    switch (simtk.unit.Quantity with units of length, or bool) - distance at which the switching function is turned on;
        True selects cutoff - sigma; a false value (False, None) disables switching (default: True)
    dispersion_correction (boolean) - if True, will use analytical dispersion correction (only used when switching is disabled) (default: True)

    RETURNS

    system (simtk.openmm.System) - the Lennard-Jones fluid system
    coordinates (simtk.unit.Quantity of natoms x 3 with units of angstroms) - initial particle positions (float32 storage)

    RAISES

    ValueError - if the switching distance is not smaller than the cutoff

    EXAMPLES

    Create default-size Lennard-Jones fluid.

    >>> [system, coordinates] = LennardJonesFluid()

    Create a larger 10x8x5 box of Lennard-Jones particles.

    >>> [system, coordinates] = LennardJonesFluid(nx=10, ny=8, nz=5)

    Create Lennard-Jones fluid using switched particle interactions (switched off between 7 and 9 A) and more particles.

    >>> [system, coordinates] = LennardJonesFluid(nx=10, ny=10, nz=10, switch=7.0*units.angstroms, cutoff=9.0*units.angstroms)

    Create Lennard-Jones fluid without switching (NonbondedForce with dispersion correction).

    >>> [system, coordinates] = LennardJonesFluid(switch=False)

    """

    # Use pyOpenMM by default.
    if mm is None:
        mm = simtk.openmm

    # Default parameters (argon).
    if mass is None: mass = 39.9 * units.amu
    if sigma is None: sigma = 3.4 * units.angstrom
    if epsilon is None: epsilon = 0.238 * units.kilocalories_per_mole
    charge = 0.0 * units.elementary_charge
    if cutoff is None: cutoff = 2.5 * sigma

    # Decide whether to switch using the caller's value (same truthiness rule as before),
    # then resolve the boolean shorthand to a real distance. Passing True straight through
    # would make OpenMM read it as 1.0 nm, which is beyond the default cutoff.
    use_switch = bool(switch)
    if switch is True:
        switch = cutoff - sigma
    if use_switch and units.is_quantity(switch) and units.is_quantity(cutoff) and not (switch < cutoff):
        # cutoff^2 - switch^2 would be zero or negative, making S(r) undefined or meaningless.
        raise ValueError("switching distance (%s) must be smaller than the cutoff (%s)" % (str(switch), str(cutoff)))

    # Determine total number of atoms.
    natoms = nx * ny * nz

    # Create an empty system object.
    system = mm.System()

    # Set up periodic nonbonded interactions with a cutoff.
    if use_switch:
        # step(x) is 0 for x < 0 and 1 otherwise, so Sw only applies for r >= switch
        # and S is exactly 1 inside the switching distance.
        energy_expression = "LJ * S;"
        energy_expression += "LJ = 4*epsilon*((sigma/r)^12 - (sigma/r)^6);"
        energy_expression += "S = 1 - step(r - switch) * (1 - Sw);"
        energy_expression += "Sw = (cutoff^2 - r^2)^2 * (cutoff^2 + 2*r^2 - 3*switch^2) / (cutoff^2 - switch^2)^3;"
        nb = mm.CustomNonbondedForce(energy_expression)
        nb.addGlobalParameter('switch', switch)
        nb.addGlobalParameter('cutoff', cutoff)
        nb.addGlobalParameter('sigma', sigma)
        nb.addGlobalParameter('epsilon', epsilon)
        nb.setNonbondedMethod(mm.CustomNonbondedForce.CutoffPeriodic)
        nb.setCutoffDistance(cutoff)
    else:
        nb = mm.NonbondedForce()
        nb.setNonbondedMethod(mm.NonbondedForce.CutoffPeriodic)
        nb.setCutoffDistance(cutoff)
        nb.setUseDispersionCorrection(dispersion_correction)

    coordinates = units.Quantity(numpy.zeros([natoms, 3], numpy.float32), units.angstrom)

    # Grid spacing; the largest grid coordinate per dimension is tracked to size the box.
    spacing = sigma
    maxX = 0.0 * units.angstrom
    maxY = 0.0 * units.angstrom
    maxZ = 0.0 * units.angstrom

    atom_index = 0
    for ii in range(nx):
        for jj in range(ny):
            for kk in range(nz):
                system.addParticle(mass)
                if use_switch:
                    # All LJ parameters are global, so particles carry no per-particle parameters.
                    nb.addParticle([])
                else:
                    nb.addParticle(charge, sigma, epsilon)
                x = spacing * ii
                y = spacing * jj
                z = spacing * kk

                coordinates[atom_index, 0] = x
                coordinates[atom_index, 1] = y
                coordinates[atom_index, 2] = z
                atom_index += 1

                if x > maxX: maxX = x
                if y > maxY: maxY = y
                if z > maxZ: maxZ = z

    # Set periodic box vectors: 2 grid spacings of padding beyond the last grid point.
    x = maxX + 2 * spacing
    y = maxY + 2 * spacing
    z = maxZ + 2 * spacing

    a = units.Quantity((x,                0*units.angstrom, 0*units.angstrom))
    b = units.Quantity((0*units.angstrom,                y, 0*units.angstrom))
    c = units.Quantity((0*units.angstrom, 0*units.angstrom, z))
    system.setDefaultPeriodicBoxVectors(a, b, c)

    # Add the nonbonded force.
    system.addForce(nb)

    return (system, coordinates)

#=============================================================================================
# Periodic water box
#=============================================================================================

def WaterBox(box_edge=2.5*units.nanometers, cutoff=0.9*units.nanometers):
    """
    Create a periodic cubic box of TIP3P water.

    DESCRIPTION

    Water is added with app.Modeller.addSolvent using the 'tip3p.xml' force field.
    Nonbonded interactions use a periodic cutoff with a switching function turned
    on 0.5 A inside the cutoff; bonds to hydrogen are constrained, water is rigid,
    and no center-of-mass motion remover is added.

    OPTIONAL ARGUMENTS

    box_edge (simtk.unit.Quantity with units compatible with nanometers) - edge length for cubic box [must be at least 2*cutoff] (default: 2.5 nm)
    cutoff  (simtk.unit.Quantity with units compatible with nanometers) - nonbonded cutoff (default: 0.9 nm)

    RETURNS

    system (simtk.openmm.System) - the water box system
    positions (simtk.unit.Quantity of nparticles x 3 with units compatible with nanometers) - the particle positions

    The two values are returned as a list [system, positions].

    RAISES

    ValueError - if box_edge is smaller than twice the cutoff

    """

    import simtk.openmm.app as app

    # A periodic cutoff longer than half the box edge is not allowed, so fail early
    # instead of returning a System that cannot be simulated.
    if 2 * cutoff > box_edge:
        raise ValueError("box_edge (%s) must be at least twice the cutoff (%s)" % (str(box_edge), str(cutoff)))

    # Load forcefield for solvent model.
    ff = app.ForceField('tip3p.xml')

    # Start from an empty topology/coordinate set and fill the requested cubic box with water.
    modeller = app.Modeller(app.Topology(), units.Quantity((), units.angstroms))
    boxSize = units.Quantity(numpy.ones([3]) * box_edge/box_edge.unit, box_edge.unit)
    modeller.addSolvent(ff, boxSize=boxSize)

    # Get new topology and coordinates.
    newtop = modeller.getTopology()
    newpos = modeller.getPositions()

    # Convert positions to a numpy array, keeping the original unit.
    positions = units.Quantity(numpy.array(newpos / newpos.unit), newpos.unit)

    # Create OpenMM System.
    system = ff.createSystem(newtop, nonbondedMethod=app.CutoffPeriodic, nonbondedCutoff=cutoff,
                             constraints=app.HBonds, rigidWater=True, removeCMMotion=False)

    # Turn on switching function for the NonbondedForce.
    forces = { system.getForce(index).__class__.__name__ : system.getForce(index) for index in range(system.getNumForces()) }
    forces['NonbondedForce'].setUseSwitchingFunction(True)
    forces['NonbondedForce'].setSwitchingDistance(cutoff - 0.5 * units.angstroms)

    return [system, positions]

def IdealGas(mm=None, nx=6, ny=6, nz=6, mass=None, sigma=None, epsilon=None):
    """
    Create a periodic rectilinear grid of non-interacting particles (an ideal gas).

    DESCRIPTION

    The system is built like an unswitched LennardJonesFluid: particles sit on a
    simple cubic grid with spacing sigma, the periodic box extends 2 sigma beyond
    the last grid point, and a NonbondedForce with a periodic cutoff of 2.5 sigma
    is added. Every particle has zero charge and a Lennard-Jones well depth of
    zero, so the NonbondedForce contributes no energy.
    Argon mass and sigma are used by default.

    OPTIONAL ARGUMENTS

    mm (implements simtk.openmm) - mm implementation to use (default: simtk.openmm)
    nx, ny, nz (int) - number of atoms to initially place on grid in each dimension (default: 6)
    mass (simtk.unit.Quantity with units of mass) - particle masses (default: 39.9 amu)
    sigma (simtk.unit.Quantity with units of length) - grid spacing and Lennard-Jones sigma (default: 3.4 A)
    epsilon - ignored; the well depth is always 0 kcal/mol (kept for signature compatibility)

    RETURNS

    system (simtk.openmm.System) - the ideal gas system
    coordinates (simtk.unit.Quantity of natoms x 3 with units of angstroms) - initial particle positions

    EXAMPLES

    Create default-size ideal gas.

    >>> [system, coordinates] = IdealGas()

    Create a larger 10x8x5 box of ideal gas particles.

    >>> [system, coordinates] = IdealGas(nx=10, ny=8, nz=5)

    """

    # sigma is needed here to derive the cutoff; LennardJonesFluid applies the same default.
    if sigma is None: sigma = 3.4 * units.angstrom

    # Zero well depth (and zero charge inside LennardJonesFluid) makes particles non-interacting;
    # the caller's epsilon is deliberately ignored.
    return LennardJonesFluid(mm=mm, nx=nx, ny=ny, nz=nz, mass=mass, sigma=sigma,
                             epsilon=0.0 * units.kilocalories_per_mole,
                             cutoff=2.5 * sigma, switch=False)

#=============================================================================================
# MAIN AND TESTS
#=============================================================================================

if __name__ == "__main__":
    import doctest

    # Run the docstring examples. They only construct Systems, so no Platform is required.
    # Exit non-zero when any example fails so scripted runs can detect the failure.
    results = doctest.testmod()
    raise SystemExit(1 if results.failed else 0)

