import simtk.openmm as openmm
import simtk.unit as units
import numpy
import time

#=============================================================================================
# SUBROUTINES
#=============================================================================================

def create_waterbox(box_edge=2.3*units.nanometers, cutoff=0.9*units.nanometers):
   """
   Build a cubic, periodic water box by solvating an empty system.

   An empty topology is handed to an app.Modeller, solvent is added to fill a
   cube of the requested edge length, and an OpenMM System is created from the
   'tip3p.xml' force field.

   OPTIONAL ARGUMENTS

   box_edge (simtk.unit.Quantity with units compatible with nanometers) - edge length of the cubic box [should be greater than 2*cutoff] (default: 2.3 nm)
   cutoff (simtk.unit.Quantity with units compatible with nanometers) - nonbonded cutoff passed to createSystem (default: 0.9 nm)

   RETURNS

   A two-element list [system, positions]:

   system (simtk.openmm.System) - the water box system
   positions (simtk.unit.Quantity wrapping a numpy array, nparticles x 3, in the units reported by the Modeller) - the particle positions

   """
   import simtk.openmm.app as app

   # Solvent model parameters.
   forcefield = app.ForceField('tip3p.xml')

   # Start from nothing: an empty topology paired with an empty coordinate set.
   modeller = app.Modeller(app.Topology(), units.Quantity((), units.angstroms))

   # Fill a cube of the requested edge length with solvent.
   edge_unit = box_edge.unit
   box_dimensions = units.Quantity(numpy.ones([3]) * box_edge/edge_unit, edge_unit)
   modeller.addSolvent(forcefield, boxSize=box_dimensions)

   # Retrieve the solvated topology and coordinates.
   solvated_topology = modeller.getTopology()
   solvated_positions = modeller.getPositions()

   # Repackage the coordinates as a numpy array carrying the same unit.
   position_unit = solvated_positions.unit
   positions = units.Quantity(numpy.array(solvated_positions / position_unit), position_unit)

   # Build the System: periodic cutoff nonbonded treatment, H-bond constraints,
   # rigid water, and no center-of-mass motion remover.
   system = forcefield.createSystem(
      solvated_topology,
      nonbondedMethod=app.CutoffPeriodic,
      nonbondedCutoff=cutoff,
      constraints=app.HBonds,
      rigidWater=True,
      removeCMMotion=False)

   return [system, positions]

#=============================================================================================
# MAIN
#=============================================================================================

# Nonbonded cutoff shared by every test system.
cutoff = 0.9 * units.nanometers

# Smallest box edge tested; 2.3*cutoff keeps the edge above twice the cutoff.
min_box_edge = 2.3 * cutoff

# Scaling the edge by n**(1/3) makes the box volume grow linearly with n.
box_edges = [ min_box_edge*(n**(1.0/3.0)) for n in range(1,7) ]

# Dynamics parameters are identical for every box size, so define them once.
temperature = 298.0 * units.kelvin
collision_rate = 91.0 / units.picosecond
timestep = 2.0 * units.femtosecond
nsteps = 1000

for box_edge in box_edges:
   # Time the construction of the solvated system.
   initial_time = time.time()
   system, positions = create_waterbox(box_edge=box_edge, cutoff=cutoff)
   final_time = time.time()
   setup_time = final_time - initial_time

   # DEBUG
   # NOTE: print is written as a single parenthesized expression so the script
   # runs unchanged under both Python 2 and Python 3.
   print("box edge: %8.3f nm | number of particles: %12d | setup time %8.3f s" % (box_edge / units.nanometers, system.getNumParticles(), setup_time))

   # Test with MD.
   integrator = openmm.LangevinIntegrator(temperature, collision_rate, timestep)
   context = openmm.Context(system, integrator)
   context.setPositions(positions)

   # Report the energy of the freshly built configuration.
   state = context.getState(getEnergy=True)
   potential_energy = state.getPotentialEnergy()
   print("   potential energy: %8.3f kcal/mol" % (potential_energy / units.kilocalories_per_mole))

   # Time the dynamics.  The state is retrieved *before* the timer is stopped
   # because step() can return while work is still queued on asynchronous
   # (GPU) platforms; fetching the state forces that work to complete.
   initial_time = time.time()
   integrator.step(nsteps)
   state = context.getState(getEnergy=True)
   final_time = time.time()
   elapsed_time = final_time - initial_time

   potential_energy = state.getPotentialEnergy()
   print("   potential energy: %8.3f kcal/mol" % (potential_energy / units.kilocalories_per_mole))
   print("   elapsed dynamics time: %8.3f s" % elapsed_time)

   # Release the context and integrator now so their resources are not held
   # while the next (larger) system is being built and simulated.
   del context, integrator
   
