#!/bin/env python

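"""
Benchmark OpenMM Langevin dynamics on AMBER systems.

Each system is read from <name>.prmtop and <name>.crd in the current
working directory, a short trajectory is run, and its wall-clock cost is
reported on stdout.
"""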

import os, sys, getopt, time

import simtk.chem.openmm as mm 

import simtk.unit as unit
import simtk.chem.amber.systemManager as systemManager
import simtk.chem.amber.crdLoader as crdLoader


def main():
    """Benchmark OpenMM Langevin dynamics on a fixed set of AMBER systems.

    For each name in ``systems`` this builds an OpenMM System from
    ``<name>.prmtop`` (via SystemManager) and loads starting coordinates
    from ``<name>.crd`` (via CrdLoader), both relative to the current
    directory.  It then times ``nsteps`` integrator steps and writes one
    line per system to stdout.  That line gives the wall-clock seconds per
    simulated picosecond and the equivalent throughput in ns/day.
    """
    systems = ['T4-lysozyme-L99A', 'FKBP']  # systems to benchmark

    # Parameters for benchmarking
    timestep = unit.Quantity(2.0, unit.femtosecond)
    nsteps = 500
    collision_frequency = 90.0 / unit.picosecond
    temperature = 300.0 * unit.kelvin
    useGBSA_OBC = True

    # Simulated time covered by one timed run, as a plain number of ps.
    # It is derived from nsteps and timestep so the report stays correct if
    # either parameter changes (previously exactly 1 ps was assumed).
    simulated_ps = (timestep * nsteps).value_in_unit(unit.picosecond)
    seconds_per_day = 24.0 * 60.0 * 60.0

    for systemName in systems:
        crdFilename = '%s.crd' % systemName
        prmtopFilename = '%s.prmtop' % systemName

        # Parse prmtop file, and make OpenMM System.
        sManager = systemManager.SystemManager(mm,
                                               prmtopFilename,
                                               shakeBondsWithH=True,
                                               andersenTemperature=None,
                                               nonbondedCutoff=None,
                                               cMMotionRemoverFrequency=None,
                                               randomSeed=None,
                                               useGBSA_OBC=useGBSA_OBC,
                                               fDumpInfo=None)
        # Use a distinct name so the OpenMM System does not overwrite the
        # system name that is printed in the report below.
        mmSystem = sManager.getSystem()

        # Select simple integrator.
        integrator = mm.LangevinIntegrator(temperature, collision_frequency,
                                           timestep)

        # Make a context for this system and report which platform it chose.
        context = mm.Context(mmSystem, integrator)
        platform = context.getPlatform()
        sys.stdout.write('Using %s platform\n' % platform.getName())

        # Set starting coords.
        crd = crdLoader.CrdLoader(crdFilename)
        context.setPositions(crd.getCoords())

        # NOTE(review): the timed region includes any one-time cost of the
        # first step.  If the platform runs steps asynchronously, it may also
        # stop the clock before the work finishes.  Consider a warm-up step
        # and a state query before reading end time — confirm for the
        # target platform.
        start_time = time.time()
        integrator.step(nsteps)
        elapsed = time.time() - start_time

        sys.stdout.write("%s : %.3f s/ps : %.3f ns/day\n" % (
            systemName,
            elapsed / simulated_ps,
            simulated_ps * 1.0e-3 / elapsed * seconds_per_day))

        # Drop references to this system's objects before building the next.
        del context, integrator, crd, mmSystem, sManager

if __name__ == '__main__':
    # Run the benchmarks only when executed as a script, not on import.
    main()


