#!/usr/bin/env python

 
"""Script to run simple argon simulation."""
__author__ = "Randall J. Radmer"
__version__ = "1.0"

import sys

# Import the OpenMM library (aliased as "mm") and its unit-handling module
import simtk.openmm.openmm as mm
import simtk.unit as unit

# Query the OpenMM version; getOpenMMVersion() is missing from very old
# releases, in which case we refuse to run.
try:
    openMM_version=mm.Platform.getOpenMMVersion()
except AttributeError:
    raise Exception("must use OpenMM version 2.0 or better")
# sys.stdout.write (rather than the Python 2 print statement) keeps this
# file parseable under both Python 2 and Python 3; the output is identical.
sys.stdout.write("Using OpenMM Version: %s\n" % openMM_version)

# simple data dump/write module
import dumpOutput

# Integration time step (a unit.Quantity with time dimension)
stepSize =  1.0 * unit.femtosecond   # Same as unit.Quantity(1.0, unit.femtosecond)

# Total simulated time
totalTime = 0.5 * unit.picosecond    # Same as unit.Quantity(0.5, unit.picosecond)

# Number of integration steps run between reports. After each batch the
# current state is fetched from the Platform (e.g. the GPU) and written out.
stepsPerReport = 10                  # Plain integer, no units

# Atoms farther apart than the cutoff distance do not interact.
# NOTE(review): 1 meter is effectively no cutoff for three atoms a few
# angstroms apart; confirm that nanometers was not intended.
cutoffDistance = 1.0 * unit.meters  # Same as unit.Quantity(1.0, unit.meters)

# Argon particle parameters: mass, the sigma/epsilon values passed to
# NonbondedForce.addParticle, and a zero charge.
mass_Ar    = 39.9  * unit.amu        # Same as unit.Quantity(39.9, unit.amu)
sigma_Ar   = 3.350 * unit.angstrom   # Same as unit.Quantity(3.350, unit.angstrom)
epsilon_Ar = 0.996 * unit.kilojoule_per_mole
q_Ar       = 0.0   * unit.elementary_charge

# Name of the preferred OpenMM Platform; it must match Platform.getName()
# exactly. Set to None to use the fastest available Platform.
#preferredPlatformName='Reference'
#preferredPlatformName='OpenCL'
#preferredPlatformName='Cuda'
preferredPlatformName=None

# Preferred GPU device (a string) for the Cuda or OpenCL Platform.
# Set to None to let OpenMM pick.
#preferredGPU='1'
preferredGPU=None

# Write platform details and progress to stdout when True
verbose=True

def _selectPlatform():
    """Choose, report, and configure the OpenMM Platform to run on.

    Lists the available Platforms (when verbose) in order of increasing
    reported speed. Returns the Platform named preferredPlatformName if one
    exists, otherwise the fastest one. When preferredGPU is set it is
    installed as the default device property of a Cuda or OpenCL Platform.

    Raises:
        Exception: if OpenMM reports no Platforms at all.
    """
    # A list sorted by speed (rather than a dict keyed by speed) keeps
    # Platforms that report the same speed from overwriting each other.
    # sorted() is stable, so ties keep OpenMM's index order.
    platforms = sorted([mm.Platform.getPlatform(ii)
                        for ii in range(mm.Platform.getNumPlatforms())],
                       key=lambda plat: plat.getSpeed())
    if not platforms:
        raise Exception("no OpenMM Platforms available")

    platform = None
    if verbose:
        sys.stdout.write("Available Platforms:\n")
    for p in platforms:
        pName = p.getName()
        if pName == preferredPlatformName:
            platform = p
        if verbose:
            sys.stdout.write("  %s (Speed=%s)\n" % (pName, p.getSpeed()))

    if platform is None:
        if preferredPlatformName is not None:
            sys.stdout.write("WARNING: preferredPlatformName (%s) not available\n" %
                             preferredPlatformName)
        # Fall back to the fastest Platform (last after the ascending sort).
        platform = platforms[-1]
    platformName = platform.getName()
    if verbose:
        sys.stdout.write("Using %s Platform\n" % platformName)

    # Request a specific GPU; must happen before a Context is created.
    if preferredGPU is not None:
        if 'Cuda' == platformName:
            platform.setPropertyDefaultValue('CudaDevice', preferredGPU)
        if 'OpenCL' == platformName:
            platform.setPropertyDefaultValue('OpenCLDeviceIndex', preferredGPU)
    return platform


def _buildArgonSystem():
    """Build the three-atom argon System.

    Returns:
        (system, nb, initPositions): the mm.System, its mm.NonbondedForce,
        and the initial coordinates as a Quantity list stored in angstroms.
    """
    # The System holds every particle and force created below.
    system = mm.System()

    # Argon needs only nonbonded interactions (no bonds, angles, etc.).
    nb = mm.NonbondedForce()
    nb.setNonbondedMethod(mm.NonbondedForce.CutoffNonPeriodic)
    nb.setCutoffDistance(cutoffDistance)
    system.addForce(nb)

    # Centre-of-mass motion remover, added to the System like a Force.
    cm = mm.CMMotionRemover(1)
    system.addForce(cm)

    # Quantity list in angstroms: anything appended is converted to
    # angstroms, and OpenMM converts to nanometers when it is passed in.
    initPositions = [] * unit.angstrom
    atomCoords = [(0, 0, 0) * unit.angstrom,
                  (5.1, 0, 0) * unit.angstrom,
                  (0, 0.4, 0) * unit.nanometer]  # notice different unit
    for xyz in atomCoords:
        system.addParticle(mass_Ar)
        nb.addParticle(q_Ar, sigma_Ar, epsilon_Ar)
        initPositions.append(xyz)
    return system, nb, initPositions


def _reportGpuDevice(platform, context):
    """Print the GPU device a Cuda/OpenCL context uses; warn if not preferredGPU."""
    platformName = platform.getName()
    if 'Cuda' != platformName and 'OpenCL' != platformName:
        return
    gpuDevice = None
    propNames = platform.getPropertyNames()
    if 'CudaDevice' in propNames:
        gpuDevice = platform.getPropertyValue(context, 'CudaDevice')
        if verbose:
            sys.stdout.write("Cuda Device: %s\n" % gpuDevice)
    if 'OpenCLDeviceIndex' in propNames:
        gpuDevice = platform.getPropertyValue(context, 'OpenCLDeviceIndex')
        if verbose:
            sys.stdout.write("OpenCL Device: %s\n" % gpuDevice)
    if preferredGPU is not None and preferredGPU != gpuDevice:
        sys.stdout.write("WARNING: preferredGPU (%s) not used\n" %
                         preferredGPU)


def _writeReport(fOut, context, step, pdbDumpCount):
    """Fetch the context's current state and write its energies and coordinates."""
    state = context.getState(getPositions=True,
                             getVelocities=False,
                             getForces=False,
                             getEnergy=True)
    dumpOutput.writeEnergyAndCoords(fOut, step,
                                    state.getKineticEnergy(),
                                    state.getPotentialEnergy(),
                                    state.getPositions(),
                                    state.getTime(),
                                    pdbDumpCount, verbose)


def main():
    """Build the argon System, run it, and write energies and PDB coordinates.

    Output goes to the file named by sys.argv[1], or to standard output when
    no argument is given. A report is written for the initial state and then
    after every stepsPerReport integration steps until totalTime is reached.
    """
    # Write the PDB output to the named file, or to standard output.
    if len(sys.argv) > 1:
        fOut = open(sys.argv[1], 'w')
    else:
        fOut = sys.stdout

    try:
        platform = _selectPlatform()
        system, nb, initPositions = _buildArgonSystem()

        # A Context holds state-type variables (positions, velocities,
        # time) for a System; one System may have several Contexts.
        integrator = mm.VerletIntegrator(stepSize)
        context = mm.Context(system, integrator, platform)
        _reportGpuDevice(platform, context)

        context.setPositions(initPositions)
        if verbose:
            sys.stdout.write("Num Particles: %d\n" % (nb.getNumParticles()))
            sys.stdout.write("Run to %s\n" % totalTime)

        # Use an integer step budget: comparing accumulated floating-point
        # simulation time against totalTime can produce one extra report.
        totalSteps = int(round(totalTime.value_in_unit(unit.femtosecond) /
                               stepSize.value_in_unit(unit.femtosecond)))

        # Report the initial state (coordinates are those just set).
        step = 0
        pdbDumpCount = 0
        _writeReport(fOut, context, step, pdbDumpCount)
        sys.stdout.flush()

        # Advance in batches of stepsPerReport (on the GPU if there is one),
        # reporting after each batch.
        while step < totalSteps:
            integrator.step(stepsPerReport)
            step += stepsPerReport
            pdbDumpCount += 1
            _writeReport(fOut, context, step, pdbDumpCount)

        dumpOutput.closePDB(fOut)
    finally:
        # Close a file we opened even if the run fails part-way; closing an
        # already-closed file object is a no-op.
        if fOut is not sys.stdout:
            fOut.close()



if __name__ == '__main__':
    # Run the simulation as a script and report its total wall-clock time.
    import time
    startTime = time.time()
    main()
    sys.stdout.write("Total run time = %.1f Sec\n" % (time.time() - startTime))


