#!/usr/bin/env python

"""
Measure total-energy drift in constant-energy (NVE) dynamics as a function of system size for a Lennard-Jones fluid.

DESCRIPTION

TODO

COPYRIGHT AND LICENSE

@author John D. Chodera <jchodera@gmail.com>

All code in this repository is released under the GNU General Public License.

This program is free software: you can redistribute it and/or modify it under
the terms of the GNU General Public License as published by the Free Software
Foundation, either version 3 of the License, or (at your option) any later
version.

This program is distributed in the hope that it will be useful, but WITHOUT ANY
WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A
PARTICULAR PURPOSE.  See the GNU General Public License for more details.
 
You should have received a copy of the GNU General Public License along with
this program.  If not, see <http://www.gnu.org/licenses/>.

"""

#=============================================================================================
# GLOBAL IMPORTS
#=============================================================================================

import os
import os.path
import sys
import math

import numpy
import scipy.stats
import scipy.integrate

import simtk
import simtk.unit as units
import simtk.openmm as openmm
import simtk.pyopenmm.extras.testsystems as testsystems

#=============================================================================================
# Set parameters
#=============================================================================================

NSIGMA_CUTOFF = 6.0 # maximum |drift| / stderr tolerated before the test is reported as failed

# Select the OpenMM platform to run on by name (e.g. 'OpenCL' or 'Cuda').
platform_name = 'Cuda'
platform = openmm.Platform.getPlatformByName(platform_name)

# Select run parameters
timestep = 1.0 * units.femtosecond # timestep for integration
nsteps = 1000 # number of integrator steps per iteration
nequiliterations = 50 # number of equilibration iterations
niterations = 1000 # number of production iterations (one energy sample is recorded per iteration)

# Set temperature and collision rate for equilibration with the Langevin integrator.
temperature = 298.0 * units.kelvin
collision_rate = 50.0 / units.picosecond

# Sizes of the Lennard-Jones test system (from simtk.pyopenmm.extras.testsystems) to run.
nx_to_try = [10, 11, 12, 13, 14, 15] # value passed as nx, ny and nz when building each system

# Flag to set verbose debug output
verbose = True

# Minimize before equilibration.
minimize = True

# Write per-iteration energies to files.
write_data = True

#=============================================================================================
# Initialize statistics.
#=============================================================================================

# Results for every system size, keyed by nx.
data = {}

# Thermal energy scale at the target temperature.
# kB is taken per mole (Boltzmann constant times Avogadro's number) so that
# kT can be compared with the per-mole energies used in the rest of the script.
kB = units.BOLTZMANN_CONSTANT_kB * units.AVOGADRO_CONSTANT_NA
kT = kB * temperature # thermal energy
beta = 1.0 / kT # inverse temperature

#=============================================================================================
# Run equilibration and constant-energy production for each system size
#=============================================================================================

def _format_energy_report(label, iteration, total_iterations, potential, kinetic, total):
    """Return one verbose progress line with energies shown in kcal/mol.

    label is the phase prefix placed before the word 'iteration'
    (e.g. 'NVT equilibration'); potential, kinetic and total are the
    unit-bearing energies, divided here by units.kilocalories_per_mole.
    """
    return "%s iteration %5d / %5d | potential %8.3f kcal/mol | kinetic %8.3f kcal/mol | total %8.3f kcal/mol" % (
        label, iteration, total_iterations,
        potential / units.kilocalories_per_mole,
        kinetic / units.kilocalories_per_mole,
        total / units.kilocalories_per_mole)


for nx in nx_to_try:
    if verbose: print("nx = %d" % nx)

    # Initialize storage: one sample per production iteration.
    data[nx] = dict()
    data[nx]['time'] = units.Quantity(numpy.zeros([niterations], numpy.float64), units.picoseconds)
    data[nx]['potential'] = units.Quantity(numpy.zeros([niterations], numpy.float64), units.kilocalories_per_mole)
    data[nx]['kinetic'] = units.Quantity(numpy.zeros([niterations], numpy.float64), units.kilocalories_per_mole)
    data[nx]['total'] = units.Quantity(numpy.zeros([niterations], numpy.float64), units.kilocalories_per_mole)

    # Create the test system.
    [system, coordinates] = testsystems.LennardJonesFluid(nx=nx, ny=nx, nz=nx, switch=None, cutoff=9.0*units.angstroms, dispersion_correction=False)

    # Compute number of degrees of freedom.
    data[nx]['ndof'] = 3 * system.getNumParticles() - system.getNumConstraints()

    # Open the per-system output file (named after the particle count), if requested.
    outfile = None
    if write_data:
        nparticles = system.getNumParticles()
        outfile = open('lennard-jones-energies-%d.out' % nparticles, 'w')

    try:
        # Equilibrate.
        # TODO: Equilibrate with NPT if periodic.
        if verbose: print("Equilibrating...")
        integrator = openmm.LangevinIntegrator(temperature, collision_rate, timestep)
        context = openmm.Context(system, integrator, platform)
        context.setPositions(coordinates)
        if minimize: openmm.LocalEnergyMinimizer.minimize(context)
        for iteration in range(nequiliterations):
            integrator.step(nsteps)
            state = context.getState(getEnergy=True)
            kinetic = state.getKineticEnergy()
            potential = state.getPotentialEnergy()
            total = kinetic + potential
            if verbose: print(_format_energy_report("NVT equilibration", iteration, nequiliterations, potential, kinetic, total))

        # Capture the equilibrated phase-space point to seed the production run.
        state = context.getState(getPositions=True, getVelocities=True)
        positions = state.getPositions(asNumpy=True)
        velocities = state.getVelocities(asNumpy=True)
        del context, integrator

        # Production constant-energy (NVE) run.
        if verbose: print("Running production simulation...")
        integrator = openmm.VerletIntegrator(timestep)
        context = openmm.Context(system, integrator, platform)
        # Start from the equilibrated positions (not the initial coordinates) so that
        # positions and velocities come from the same equilibrated state.
        context.setPositions(positions)
        context.setVelocities(velocities)
        for iteration in range(niterations):
            integrator.step(nsteps)
            state = context.getState(getEnergy=True)
            kinetic = state.getKineticEnergy()
            potential = state.getPotentialEnergy()
            total = kinetic + potential
            if verbose: print(_format_energy_report("NVE production", iteration, niterations, potential, kinetic, total))

            # Store energies.
            data[nx]['time'][iteration] = state.getTime()
            data[nx]['potential'][iteration] = potential
            data[nx]['kinetic'][iteration] = kinetic
            data[nx]['total'][iteration] = total

            if outfile is not None:
                # Energies are written in units of kT; flush every iteration so
                # partial results are on disk while the run is still in progress.
                outfile.write('%16.8f %16.8f %16.8f\n' % (potential/kT, kinetic/kT, total/kT))
                outfile.flush()

        del context, integrator
    finally:
        # Close the output file even if the simulation raised.
        if outfile is not None:
            outfile.close()

#=============================================================================================
# Compute drift.
#=============================================================================================

# TODO: Statistical inefficiency is currently disregarded.

for nx in nx_to_try:
    record = data[nx]

    # Regress the total energy (in kT per degree of freedom) against
    # simulation time (in ns); the fitted slope is the drift.
    elapsed_ns = record['time'] / units.nanoseconds
    energy_per_dof = record['total'] / kT / record['ndof']
    (drift, _intercept, _rvalue, _pvalue, drift_error) = scipy.stats.linregress(elapsed_ns, energy_per_dof)

    # Store the drift, its standard error, and how many standard errors
    # the drift lies away from zero.
    record['drift'] = drift
    record['ddrift'] = drift_error
    record['drift-nsigma'] = abs(drift / drift_error)
        
#=============================================================================================
# Print summary statistics
#=============================================================================================

test_pass = True # this gets set to False if any system size fails

print("")

print("drift in total energy from %d samples over %.3f ps of simulation with %.3f fs timestep:" % (niterations, (niterations * nsteps * timestep) / units.picoseconds, timestep / units.femtoseconds))

# Human-readable table: one row per system size, failures flagged with '***'.
for nx in nx_to_try:
    d = data[nx]

    row = "%8d dof | drift %12.5e +- %12.5e  kT/dof/ns  %5.1f sigma" % (d['ndof'], d['drift'], d['ddrift'], d['drift-nsigma'])
    # Written as "not (x <= cutoff)" so that a NaN significance (e.g. from NaN
    # energies) counts as a failure instead of silently passing.
    if not (d['drift-nsigma'] <= NSIGMA_CUTOFF):
        row += '  ***'
        test_pass = False
    print(row)
print("")

# Same numbers as bare whitespace-separated columns:
# ndof, drift, drift standard error, |drift| / standard error.
for nx in nx_to_try:
    d = data[nx]

    print("%8d %12.5e %12.5e %5.1f" % (d['ndof'], d['drift'], d['ddrift'], d['drift-nsigma']))
print("")


#=============================================================================================
# Report pass or fail in exit code
#=============================================================================================

# Exit status 0 signals that every system size passed; 1 signals a failure.
sys.exit(0 if test_pass else 1)

   
