#!/usr/bin/env python

"""
Measure drift for constant-energy (NVE) dynamics.

DESCRIPTION

TODO

COPYRIGHT AND LICENSE

@author John D. Chodera <jchodera@gmail.com>

All code in this repository is released under the GNU General Public License.

This program is free software: you can redistribute it and/or modify it under
the terms of the GNU General Public License as published by the Free Software
Foundation, either version 3 of the License, or (at your option) any later
version.

This program is distributed in the hope that it will be useful, but WITHOUT ANY
WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A
PARTICULAR PURPOSE.  See the GNU General Public License for more details.
 
You should have received a copy of the GNU General Public License along with
this program.  If not, see <http://www.gnu.org/licenses/>.

"""

#=============================================================================================
# GLOBAL IMPORTS
#=============================================================================================

import os
import os.path
import sys
import math

import numpy
import scipy.stats
import scipy.integrate

import simtk
import simtk.unit as units
import simtk.openmm as openmm
import simtk.pyopenmm.extras.testsystems 

#=============================================================================================
# Set parameters
#=============================================================================================

NSIGMA_CUTOFF = 6.0 # maximum |drift| / (standard error of drift) tolerated before the test is reported as failed

# Make a list of the names of all platforms OpenMM reports as available.
platforms = [ openmm.Platform.getPlatform(platform_index).getName() for platform_index in range(openmm.Platform.getNumPlatforms()) ]
# NOTE(review): the assignment below discards the list built above, so only 'OpenCL' is ever tested.
# It is marked DEBUG -- confirm whether it should be removed before relying on this test.
platforms = ['OpenCL'] # DEBUG

# Select run parameters
timestep = 1.0 * units.femtosecond # timestep for integration
nsteps = 1000 # number of integrator steps per iteration (energies are sampled once per iteration)
nequiliterations = 20 # number of equilibration iterations
niterations = 1000 # number of production iterations (one stored energy sample each)

# Set temperature and collision rate for equilibration with stochastic (Langevin) thermostat.
# The temperature is also used to express energies in units of kT.
temperature = 298.0 * units.kelvin
collision_rate = 50.0 / units.picosecond 

# Test systems to run.
# 'LennardJonesCutoff', 'LennardJonesSwitched', 'WaterBoxCutoff' and 'WaterBoxSwitched' are
# built with explicit arguments by this script; any other name is looked up as an attribute
# of simtk.pyopenmm.extras.testsystems and called with no arguments.
#testsystems = ['AlanineDipeptideImplicit', 'WaterBox']
#testsystems = ['LennardJonesCutoff', 'LennardJonesSwitched']
testsystems = ['WaterBoxCutoff', 'WaterBoxSwitched']

# Flag to set verbose debug output (per-iteration energies are printed when True)
verbose = True

# Run a local energy minimization before equilibration.
minimize = True

#=============================================================================================
# Initialize statistics.
#=============================================================================================

# Results are accumulated here, keyed first by platform name and then by test system name.
data = {}

# Thermodynamic constants derived from the configured temperature.
kB = units.BOLTZMANN_CONSTANT_kB * units.AVOGADRO_CONSTANT_NA # Boltzmann constant times Avogadro's number
kT = kB * temperature # thermal energy
beta = 1.0 / kT # inverse temperature

#=============================================================================================
# Run NVT equilibration followed by NVE production for each platform and test system
#=============================================================================================


# For every (platform, test system) pair: build the system, equilibrate it with a Langevin
# thermostat, then run constant-energy (NVE) dynamics and record one energy sample per iteration.
for platform in platforms:
    # Initialize storage.
    data[platform] = dict()

    # Look up the platform object once; it is reused for both contexts of every test system.
    platform_object = openmm.Platform.getPlatformByName(platform)

    for testsystem in testsystems:
        if verbose: print("Platform %s system %s" % (platform, testsystem))

        # Initialize storage: one slot per production iteration for each recorded observable.
        record = dict()
        data[platform][testsystem] = record
        record['time'] = units.Quantity(numpy.zeros([niterations], numpy.float64), units.picoseconds)
        for energy_key in ('potential', 'kinetic', 'total'):
            record[energy_key] = units.Quantity(numpy.zeros([niterations], numpy.float64), units.kilocalories_per_mole)

        # Create the test system.
        if testsystem == 'LennardJonesCutoff':
            [system, coordinates] = simtk.pyopenmm.extras.testsystems.LennardJonesFluid(nx=10, ny=10, nz=10, switch=None, cutoff=9.0*units.angstroms, dispersion_correction=False)
        elif testsystem == 'LennardJonesSwitched':
            [system, coordinates] = simtk.pyopenmm.extras.testsystems.LennardJonesFluid(nx=10, ny=10, nz=10, switch=7.0*units.angstroms, cutoff=9.0*units.angstroms, dispersion_correction=False)
        elif testsystem == 'WaterBoxCutoff':
            [system, coordinates] = simtk.pyopenmm.extras.testsystems.WaterBox(switch=None, cutoff=9.0*units.angstroms, nonbonded_method=openmm.NonbondedForce.CutoffPeriodic, constrain=False, flexible=True)
        elif testsystem == 'WaterBoxSwitched':
            [system, coordinates] = simtk.pyopenmm.extras.testsystems.WaterBox(switch=7.0*units.angstroms, cutoff=9.0*units.angstroms, nonbonded_method=openmm.NonbondedForce.CutoffPeriodic, constrain=False, flexible=True)
        else:
            # Any other name is resolved to a constructor in the testsystems module and called with defaults.
            constructor = getattr(simtk.pyopenmm.extras.testsystems, testsystem)
            [system, coordinates] = constructor()

        # Compute number of degrees of freedom (three per particle, minus one per constraint).
        record['ndof'] = 3 * system.getNumParticles() - system.getNumConstraints()

        # Equilibrate at constant temperature.
        # TODO: Equilibrate with NPT if periodic.
        if verbose: print("Equilibrating...")
        integrator = openmm.LangevinIntegrator(temperature, collision_rate, timestep)
        context = openmm.Context(system, integrator, platform_object)
        context.setPositions(coordinates)
        if minimize: openmm.LocalEnergyMinimizer.minimize(context)
        for iteration in range(nequiliterations):
            integrator.step(nsteps)
            state = context.getState(getEnergy=True)
            kinetic = state.getKineticEnergy()
            potential = state.getPotentialEnergy()
            total = kinetic + potential
            if verbose: print("NVT equilibration iteration %5d / %5d | potential %8.3f kcal/mol | kinetic %8.3f kcal/mol | total %8.3f kcal/mol" % (iteration, nequiliterations, potential / units.kilocalories_per_mole, kinetic / units.kilocalories_per_mole, total / units.kilocalories_per_mole))

        # Capture the equilibrated phase-space point before discarding the thermostatted context.
        state = context.getState(getPositions=True, getVelocities=True)
        positions = state.getPositions(asNumpy=True)
        velocities = state.getVelocities(asNumpy=True)
        del context, integrator

        # Production constant-energy (NVE) run.
        if verbose: print("Running production simulation...")
        integrator = openmm.VerletIntegrator(timestep)
        context = openmm.Context(system, integrator, platform_object)
        # Start from the equilibrated positions, not the initial coordinates: pairing the
        # equilibrated velocities with the unminimized, unequilibrated starting structure
        # would throw away the minimization and equilibration performed above.
        context.setPositions(positions)
        context.setVelocities(velocities)
        for iteration in range(niterations):
            integrator.step(nsteps)
            state = context.getState(getEnergy=True)
            kinetic = state.getKineticEnergy()
            potential = state.getPotentialEnergy()
            total = kinetic + potential
            if verbose: print("NVE production iteration %5d / %5d | potential %8.3f kcal/mol | kinetic %8.3f kcal/mol | total %8.3f kcal/mol" % (iteration, niterations, potential / units.kilocalories_per_mole, kinetic / units.kilocalories_per_mole, total / units.kilocalories_per_mole))

            # Store simulation time and energies for this sample.
            record['time'][iteration] = state.getTime()
            record['potential'][iteration] = potential
            record['kinetic'][iteration] = kinetic
            record['total'][iteration] = total

        del context, integrator
    
#=============================================================================================
# Compute drift.
#=============================================================================================

# TODO: Statistical inefficiency is currently disregarded.

# Fit a straight line to the reduced total energy versus time for every run; the slope is the drift.
for platform in platforms:
    for testsystem in testsystems:
        record = data[platform][testsystem]

        # Express time in ns and total energy in kT per degree of freedom.
        elapsed_ns = record['time'] / units.nanoseconds
        reduced_energy = record['total'] / kT / record['ndof']

        # linregress yields (slope, intercept, r, p-value, stderr); only slope and stderr are kept.
        fit = scipy.stats.linregress(elapsed_ns, reduced_energy)
        drift_rate = fit[0]
        drift_error = fit[4]

        record['drift'] = drift_rate
        record['ddrift'] = drift_error
        # Number of standard errors by which the fitted drift differs from zero.
        record['drift-nsigma'] = abs(drift_rate / drift_error)
        
#=============================================================================================
# Print summary statistics
#=============================================================================================

# Overall verdict; flipped to False as soon as any run exceeds NSIGMA_CUTOFF.
test_pass = True

print("")

print("drift in total energy from %d samples over %.3f ps of simulation with %.3f fs timestep:" % (niterations, (niterations * nsteps * timestep) / units.picoseconds, timestep / units.femtoseconds))
for testsystem in testsystems:
    for platform in platforms:
        record = data[platform][testsystem]

        # One row per (platform, test system); rows that fail the cutoff are flagged with ' ***'.
        row = "%32s  %32s  %12.5e +- %12.5e  kT/dof/ns  %5.1f sigma" % (platform, testsystem, record['drift'], record['ddrift'], record['drift-nsigma'])
        if record['drift-nsigma'] > NSIGMA_CUTOFF:
            row = row + ' ' + ' ***'
            test_pass = False
        # Every row ends with a single trailing space (established output format).
        print(row + ' ')
    # Blank line after each test system's group of rows.
    print('')
print('')

#=============================================================================================
# Report pass or fail in exit code
#=============================================================================================

# Exit status 0 signals that every run stayed within the drift cutoff; 1 signals a failure.
sys.exit(0 if test_pass else 1)

   
