#!/usr/bin/env python

#=============================================================================================
# MODULE DOCSTRING
#=============================================================================================

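"""
Context caching test.

Reads an AMBER prmtop/inpcrd pair into an OpenMM System, times repeated
creation and deletion of a LangevinIntegrator + Context pair, and then tries to
create and keep alive up to `ncopies` Contexts at once, stopping at the first
failure.

"""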

#=============================================================================================
# GLOBAL IMPORTS
#=============================================================================================

import os
import os.path
import sys
import copy
import time

import numpy

import simtk.unit as units
import simtk.openmm as openmm

#=============================================================================================
# MAIN
#=============================================================================================

# Parameters.
temperature = 300.0 * units.kelvin
timestep = 2.0 * units.femtoseconds
collision_rate = 9.1 / units.picoseconds
niterations = 10 # number of iterations (not currently used in this script)
nsteps = 500 # number of timesteps per iteration (not currently used in this script)
noverhead = 3 # number of times to test context creation/deletion
ncopies = 500 # number of copies to cache

prmtop_filename = 'complex.prmtop'
inpcrd_filename = 'complex.crd'

# Read OpenMM System from AMBER prmtop / inpcrd.
# print(...) with a single pre-formatted string behaves identically under
# Python 2 and Python 3.
print("Reading AMBER prmtop and inpcrd files...")
import simtk.openmm.app as app
nonbondedMethod = app.NoCutoff
implicitSolvent = app.OBC2
constraints = app.HBonds
removeCMMotion = False

initial_time = time.time()
system = app.AmberPrmtopFile(prmtop_filename).createSystem(
    nonbondedMethod=nonbondedMethod,
    implicitSolvent=implicitSolvent,
    constraints=constraints,
    removeCMMotion=removeCMMotion)
elapsed_time = time.time() - initial_time
print("   %.3f s elapsed reading prmtop file and creating System..." % elapsed_time)

initial_time = time.time()
positions = app.AmberInpcrdFile(inpcrd_filename).getPositions(asNumpy=True)
elapsed_time = time.time() - initial_time
print("   %.3f s elapsed reading inpcrd file and creating coordinates..." % elapsed_time)

# Zero initial velocities with the same array shape as the positions.
# Multiplying the positions by zero alone would leave length units attached;
# dividing by a time unit gives the length/time units that velocities require.
velocities = positions * 0.0 / units.picoseconds

# Test context creation/deletion overhead.
# Each trial builds a fresh LangevinIntegrator + Context, pushes positions and
# velocities into it, and releases both objects, timing the whole cycle.
for overhead_test in range(noverhead):
    # Report 1-based trial numbers so the final trial reads "N / N".
    print("Testing Context creation + deletion overhead try %d / %d..." % (overhead_test + 1, noverhead))
    initial_time = time.time()
    integrator = openmm.LangevinIntegrator(temperature, collision_rate, timestep)
    context = openmm.Context(system, integrator)
    context.setPositions(positions)
    context.setVelocities(velocities)
    # Release both objects inside the timed region so cleanup cost is measured
    # (assumes CPython reference counting frees them here -- TODO confirm).
    del context, integrator
    final_time = time.time()
    elapsed_time = final_time - initial_time
    print("   %.3f s elapsed to create integrator and context, push coordinates and velocities, and clean up." % elapsed_time)

# Create and cache Context objects.
# Keeps creating LangevinIntegrator + Context pairs, holding references to all
# of them, until `ncopies` exist or creation raises an exception.
integrators = []
contexts = []
for copy_index in range(ncopies):
    # Report 1-based copy numbers so the final copy reads "N / N".
    print("Creating and caching Context %d / %d..." % (copy_index + 1, ncopies))

    initial_time = time.time()
    try:
        integrator = openmm.LangevinIntegrator(temperature, collision_rate, timestep)
        context = openmm.Context(system, integrator)
    except Exception as e:
        # Deliberate best-effort: any failure to create another Context ends
        # the caching loop instead of aborting the script.
        print("Caught exception: " + str(e))
        print("Recovering from exception---giving up on caching.")
        break

    # Hold references to both objects so every cached Context stays alive.
    integrators.append(integrator)
    contexts.append(context)

    elapsed_time = time.time() - initial_time
    print("   %.3f s elapsed." % elapsed_time)

print("Completed caching test.")
print("Cached %d / %d Context objects." % (len(contexts), ncopies))



