#!/usr/bin/env python

#=============================================================================================
# MODULE DOCSTRING
#=============================================================================================

"""
Test how energy conservation of VerletIntegrator depends on system, platform, timestep, and
number of timesteps.

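DESCRIPTION

For each system, the fastest available platform is selected. For each of `nsamples` samples, the
system is first equilibrated with a LangevinIntegrator and a MonteCarloBarostat, and then integrated
with a VerletIntegrator at each timestep in `timesteps`. The change in total energy (in units of kT)
is recorded after each cumulative number of steps in `nsteps`. Summary statistics (mean, standard
deviation, and mean Metropolis acceptance probability) are then printed.

TODO

* Add failure conditions (e.g. based on NSIGMA_CUTOFF); currently the script always exits with status 0.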

COPYRIGHT AND LICENSE

@author John D. Chodera <jchodera@gmail.com>

All code in this repository is released under the GNU General Public License.

This program is free software: you can redistribute it and/or modify it under
the terms of the GNU General Public License as published by the Free Software
Foundation, either version 3 of the License, or (at your option) any later
version.

This program is distributed in the hope that it will be useful, but WITHOUT ANY
WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A
PARTICULAR PURPOSE.  See the GNU General Public License for more details.
 
You should have received a copy of the GNU General Public License along with
this program.  If not, see <http://www.gnu.org/licenses/>.

"""

#=============================================================================================
# GLOBAL IMPORTS
#=============================================================================================

import sys
import math
import doctest
import numpy
import os.path

import simtk.unit as units
import simtk.openmm as openmm
import simtk.pyopenmm.extras.testsystems as testsystems

#=============================================================================================
# TEST PARAMETERS
#=============================================================================================

# Failure threshold: the maximum number of standard deviations a statistic may deviate
# from its true value before the test fails.
NSIGMA_CUTOFF = 6.0

# What to scan: systems, Verlet timesteps, and cumulative step counts at which energy is measured.
systems = ['WaterBox']
timesteps = [
   1.0 * units.femtosecond,
   0.5 * units.femtosecond,
   0.1 * units.femtosecond,
   0.01 * units.femtosecond,
]
nsteps = [1, 10, 100, 1000]
nsamples = 50 # number of sampled trajectories per (system, platform)

# Thermodynamic state used for velocity assignment and energy normalization.
temperature = 300.0 * units.kelvin
kB = units.BOLTZMANN_CONSTANT_kB * units.AVOGADRO_CONSTANT_NA # molar Boltzmann constant
kT = kB * temperature # thermal energy
beta = 1.0 / kT # inverse temperature

# Tolerance applied to every integrator (None would leave the OpenMM default in place).
constraint_tolerance = 1.0e-8

# Equilibration settings (Langevin friction and barostat pressure).
collision_rate = 20.0 / units.picosecond
pressure = 1.0 * units.atmospheres

# Set to True for per-sample / per-timestep progress output.
verbose = False

#=============================================================================================
# SUBROUTINES
#=============================================================================================

def generateMaxwellBoltzmannVelocities(system, temperature):
   """Generate Maxwell-Boltzmann velocities.
   
   ARGUMENTS
   
   system (simtk.openmm.System) - the system for which velocities are to be assigned
   temperature (simtk.unit.Quantity of temperature) - the temperature at which velocities are to be assigned
   
   RETURNS
   
   velocities (simtk.unit.Quantity of numpy Nx3 array, units nanometer/picosecond) - particle velocities

   NOTES

   Each Cartesian component of atom i is drawn from a normal distribution with
   standard deviation sqrt(kT / m_i). Particles with zero mass (which OpenMM treats
   as immobile) are left at zero velocity instead of causing a division by zero.
   
   TODO

   This could be sped up by introducing vector operations.
   
   """
   
   # Get number of atoms.
   natoms = system.getNumParticles()
   
   # Create storage for velocities; velocities[i,k] is the kth component of the velocity of atom i.
   velocities = units.Quantity(numpy.zeros([natoms, 3], numpy.float32), units.nanometer / units.picosecond)
   
   # Compute thermal energy from the specified temperature.
   kB = units.BOLTZMANN_CONSTANT_kB * units.AVOGADRO_CONSTANT_NA
   kT = kB * temperature # thermal energy
   
   # Assign velocities from the Maxwell-Boltzmann distribution.
   for atom_index in range(natoms):
      mass = system.getParticleMass(atom_index) # atomic mass
      if mass / units.amu == 0.0:
         # Massless (fixed) particles keep zero velocity; kT / mass would divide by zero.
         continue
      sigma = units.sqrt(kT / mass) # standard deviation of velocity distribution for each coordinate for this atom
      for k in range(3):
         velocities[atom_index,k] = sigma * numpy.random.normal()

   # Return velocities.
   return velocities

def metropolis(E):
   """Return the Metropolis acceptance probability min(1, exp(-E)).

   ARGUMENTS

   E (float) - dimensionless energy change (e.g. in units of kT)

   RETURNS

   probability (float) - acceptance probability in [0, 1]

   NOTES

   exp(-E) is only evaluated when E > 0, so large negative energy changes (which
   would overflow math.exp) are accepted with probability 1.0. A NaN input also
   yields 1.0, matching min(1.0, exp(nan)).

   """
   if E > 0.0:
      return math.exp(-E)
   return 1.0
# Apply elementwise to numpy arrays; otypes is given explicitly so that size-0
# arrays are accepted (numpy.vectorize cannot infer the output type from them).
metropolis = numpy.vectorize(metropolis, otypes=[numpy.float64])

#=============================================================================================
# MAIN
#=============================================================================================

# Energy changes (in kT), keyed by (system_name, platform_name, sample, timestep, nsteps).
data = dict()

# Overall result, reported through the exit code at the end of the script.
test_pass = True 

# Only the fastest platform (largest Platform.getSpeed(); first one wins on ties) is tested.
# To test every platform instead, collect openmm.Platform.getPlatform(i).getName() for all
# i in range(openmm.Platform.getNumPlatforms()) into platform_names.
nplatforms = openmm.Platform.getNumPlatforms()
platform_speeds = numpy.array([openmm.Platform.getPlatform(index).getSpeed() for index in range(nplatforms)], numpy.float64)
platform_index = int(platform_speeds.argmax())
platform = openmm.Platform.getPlatform(platform_index)
platform_names = [platform.getName()]

for system_name in systems:
   # Create system.
   print "Constructing system '%s'..." % system_name
   constructor = getattr(testsystems, system_name)
   if system_name == 'WaterBox':
      filename = os.path.join(os.getenv('PYOPENMM_SOURCE_DIR'), 'test', 'additional-tests', 'systems', 'waterbox', 'tip3p900.pdb')       
      [system, coordinates] = constructor(filename=filename, cutoff = 9.0 * units.angstroms, nonbonded_method=openmm.NonbondedForce.PME, constrain=False, flexible=True)
      #[system, coordinates] = constructor(constrain=True, flexible=False, cutoff=9.0*units.angstroms, charges=True)
      #[system, coordinates] = constructor(constrain=False, flexible=True, cutoff=9.0*units.angstroms, charges=True)      
   elif system_name == 'LennardJonesFluid':
      sigma = 3.4 * units.angstrom # argon
      cutoff = 2.5 * sigma
      #[system, coordinates] = constructor(nx=10, ny=10, nz=10, sigma=sigma, cutoff=cutoff)
      [system, coordinates] = constructor(nx=10, ny=10, nz=10, sigma=sigma, cutoff=cutoff, switch=2.0*sigma)      
   elif system_name == 'Diatom':
      [system, coordinates] = constructor(constraint=False)
   else:
      [system, coordinates] = constructor()      
      
   # Add barostat (which will only be used during equilibration.
   barostat_frequency = 25 # number of steps between MC volume adjustments
   barostat = openmm.MonteCarloBarostat(pressure, temperature, barostat_frequency)
   system.addForce(barostat)    

   for platform_name in platform_names:
      # Select platform
      if verbose: print platform_name
      platform = openmm.Platform.getPlatformByName(platform_name)

      for sample in range(nsamples):
         if verbose: print sample

         # Equilibrate with barostat.
         barostat.setFrequency(barostat_frequency)         
         timestep = 1.0 * units.femtosecond
         nequilsteps = 500
         integrator = openmm.LangevinIntegrator(temperature, collision_rate, timestep)
         context = openmm.Context(system, integrator, platform)
         if constraint_tolerance is not None:
            integrator.setConstraintTolerance(constraint_tolerance)
         context.setPositions(coordinates)
         velocities = generateMaxwellBoltzmannVelocities(system, temperature)    
         integrator.step(nequilsteps)
         state = context.getState(getPositions=True, getVelocities=True)    
         coordinates = state.getPositions(asNumpy=True)
         velocities = state.getVelocities(asNumpy=True)
         box_vectors = state.getPeriodicBoxVectors()
         system.setDefaultPeriodicBoxVectors(*box_vectors)
         del state, context, integrator
         
         # Disable barostat.
         barostat.setFrequency(32766)
         
         if verbose:
            for i in range(len(nsteps)):
               print "%6d steps" % nsteps[i],
            print ""

         for timestep in timesteps:
            if verbose: print str(timestep) + " timestep"
            
            # Create integrator and context.
            integrator = openmm.VerletIntegrator(timestep)
            context = openmm.Context(system, integrator, platform)
            if constraint_tolerance is not None:
               integrator.setConstraintTolerance(constraint_tolerance)
            
            # Set positions and velocities
            context.setPositions(coordinates)
            context.setVelocities(velocities)
            
            # Take a step to engage constraints.
            # TODO: Replace this with context.applyConstraints(tol)
            integrator.step(2)
            
            # Compute initial energy.
            state = context.getState(getEnergy=True)
            initial_energy = state.getPotentialEnergy() + state.getKineticEnergy()
            
            for i in range(len(nsteps)):                   
               if i == 0:
                  integrator.step(nsteps[i])
               else:
                  integrator.step(nsteps[i] - nsteps[i-1])

               state = context.getState(getEnergy=True)
               final_energy = state.getPotentialEnergy() + state.getKineticEnergy()
               
               # Compute energy differences.
               delta_energy = (final_energy - initial_energy) / kT
               
               key = (system_name, platform_name, sample, timestep, nsteps[i])
               data[key] = delta_energy

               if verbose: print "%12.5f" % (delta_energy),

            if verbose: print ""
            
            # Clean up
            del state, context, integrator
            
      # Print statistics.
      print "summary statistics for system '%s' on platform '%s' from %d sampled trajectories" % (system_name, platform_name, nsamples)
      if constraint_tolerance is not None:
         print "constraint tolerance is %e" % constraint_tolerance
      else:
         print "constraint tolerance left at default value"
      for timestep in timesteps:
         print timestep
         for i in range(len(nsteps)):
            # Extract samples.
            dE = numpy.zeros([nsamples], numpy.float64)
            for sample in range(nsamples):
               key = (system_name, platform_name, sample, timestep, nsteps[i])           
               dE[sample] = data[key]
            mean = dE.mean()
            dmean = dE.std() / math.sqrt(nsamples) # TODO: Use statistical inefficiency
            stddev = dE.std()
            dstddev = stddev * math.sqrt(1.0 / 2.0 / (nsamples-1))
            accept = numpy.mean(metropolis(dE))
            print "%10d steps : mean %12.5f +- %12.5f kT  std %12.5f +- %12.5f kT  accept %10.6f (%12.8e)" % (nsteps[i], mean, dmean, stddev, dstddev, accept, accept)

#=============================================================================================
# Report pass or fail in exit code
#=============================================================================================

# Exit status 0 reports success; 1 reports that at least one check failed.
sys.exit(0 if test_pass else 1)

   

