#!/usr/bin/env python

"""
Test analytical dispersion correction.

DESCRIPTION

To test the analytical dispersion correction, we check that enlarging the Lennard-Jones
cutoff (here from 3 sigma to 6 sigma) only minimally perturbs the potential energy of an
isotropic Lennard-Jones fluid.

TODO

* Have test fail if computed density not within NSIGMA of acceptable energy difference.

COPYRIGHT AND LICENSE

@author John D. Chodera <jchodera@gmail.com>

All code in this repository is released under the GNU General Public License.

This program is free software: you can redistribute it and/or modify it under
the terms of the GNU General Public License as published by the Free Software
Foundation, either version 3 of the License, or (at your option) any later
version.

This program is distributed in the hope that it will be useful, but WITHOUT ANY
WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A
PARTICULAR PURPOSE.  See the GNU General Public License for more details.
 
You should have received a copy of the GNU General Public License along with
this program.  If not, see <http://www.gnu.org/licenses/>.

"""

#=============================================================================================
# GLOBAL IMPORTS
#=============================================================================================

import os
import os.path
import sys
import math

import numpy

import simtk
import simtk.unit as units
import simtk.openmm as openmm
import simtk.pyopenmm.extras.testsystems as testsystems

#=============================================================================================
# Set parameters
#=============================================================================================

# Test pass/fail criteria
MAX_ALLOWED_POTENTIAL_DIFFERENCE = 0.6 * units.kilocalories_per_mole
NSIGMA_CUTOFF = 6.0 # maximum number of standard deviations away from true value before test fails

# Constants
kB = units.BOLTZMANN_CONSTANT_kB * units.AVOGADRO_CONSTANT_NA

# Select fastest platform to test.
nplatforms = openmm.Platform.getNumPlatforms()
platform_speeds = numpy.zeros([nplatforms], numpy.float64)
for platform_index in range(nplatforms):
    platform = openmm.Platform.getPlatform(platform_index)
    platform_speeds[platform_index] = platform.getSpeed()
platform_index = int(numpy.argmax(platform_speeds))
platform = openmm.Platform.getPlatform(platform_index)

# Select run parameters
timestep = 2.0 * units.femtosecond # timestep for integrtion
nsteps = 2500 # number of steps per data record
nequiliterations = 10 # number of equilibration iterations
niterations = 25 # number of iterations to collect data for

# Lennard-Jones fluid parameters (argon).
mass     = 39.9 * units.amu
sigma    = 3.4 * units.angstrom
epsilon  = 120.0 * units.kelvin * kB

# Set temperature, pressure, and collision rate for stochastic thermostats.
temperature = 0.9 / (kB / epsilon)
pressure = 4.0 / (sigma**3 / epsilon) / units.AVOGADRO_CONSTANT_NA
barostat_frequency = 25 # number of steps between MC volume adjustments
collision_rate = 5.0 / units.picosecond # collision rate for Langevin integrator
print "temperature = %.1f K, pressure = %.1f atm" % (temperature / units.kelvin, pressure / units.atmospheres)

# Determine whether dispersion correction should be used (should be True).
use_dispersion_correction = True

# Small and large cutoffs.
short_cutoff = 3.0 * sigma
long_cutoff = 6.0 * sigma

# Number of particles per dimension.
nx = 12
ny = 12
nz = 12

# Flag to set verbose debug output
verbose = False

#=============================================================================================
# SUBROUTINES
#=============================================================================================

def generateMaxwellBoltzmannVelocities(system, temperature):
    """Generate Maxwell-Boltzmann velocities.

    ARGUMENTS

    system (simtk.openmm.System) - the system for which velocities are to be assigned
    temperature (simtk.unit.Quantity of temperature) - the temperature at which velocities are to be assigned

    RETURNS

    velocities (simtk.unit.Quantity of numpy Nx3 float32 array, units nanometer/picosecond) - particle velocities

    NOTES

    Each Cartesian component of the velocity of particle i is drawn independently from a normal
    distribution with zero mean and standard deviation sqrt(kT / m_i), using numpy's global
    random number generator.

    """
    velocity_unit = units.nanometer / units.picosecond

    # Get number of atoms.
    natoms = system.getNumParticles()

    # Thermal energy, using the molar Boltzmann constant (kB * N_A) so kT is per mole.
    kB = units.BOLTZMANN_CONSTANT_kB * units.AVOGADRO_CONSTANT_NA
    kT = kB * temperature

    # Per-atom standard deviation sqrt(kT / m), stripped to plain floats in nm/ps once per atom.
    # This lets the random draw below be a single vectorized operation without per-element
    # unit conversions.
    sigmas = numpy.zeros([natoms], numpy.float64)
    for atom_index in range(natoms):
        atom_mass = system.getParticleMass(atom_index)
        sigmas[atom_index] = units.sqrt(kT / atom_mass).value_in_unit(velocity_unit)

    # One standard-normal deviate per (atom, component), scaled by that atom's sigma.
    normals = numpy.random.normal(size=(natoms, 3))
    values = (sigmas[:, numpy.newaxis] * normals).astype(numpy.float32)

    return units.Quantity(values, velocity_unit)
   
def statisticalInefficiency(A_n, B_n=None, fast=False, mintime=3):
  """
  Compute the (cross) statistical inefficiency of (two) timeseries.

  REQUIRED ARGUMENTS  
    A_n (numpy array) - A_n[n] is nth value of timeseries A.  Length is deduced from vector.

  OPTIONAL ARGUMENTS
    B_n (numpy array) - B_n[n] is nth value of timeseries B.  Length is deduced from vector.
       If supplied, the cross-correlation of timeseries A and B will be estimated instead of the
       autocorrelation of timeseries A.  
    fast (boolean) - if True, will use faster (but less accurate) method to estimate correlation
       time, described in Ref. [1] (default: False)
    mintime (int) - minimum amount of correlation function to compute (default: 3)
       The algorithm terminates after computing the correlation time out to mintime when the
       correlation function furst goes negative.  Note that this time may need to be increased
       if there is a strong initial negative peak in the correlation function.

  RETURNS
    g is the estimated statistical inefficiency (equal to 1 + 2 tau, where tau is the correlation time).
       We enforce g >= 1.0.

  NOTES 
    The same timeseries can be used for both A_n and B_n to get the autocorrelation statistical inefficiency.
    The fast method described in Ref [1] is used to compute g.

  REFERENCES  
    [1] J. D. Chodera, W. C. Swope, J. W. Pitera, C. Seok, and K. A. Dill. Use of the weighted
    histogram analysis method for the analysis of simulated and parallel tempering simulations.
    JCTC 3(1):26-41, 2007.

  EXAMPLES

  Compute statistical inefficiency of timeseries data with known correlation time.  

  >>> import timeseries
  >>> A_n = timeseries.generateCorrelatedTimeseries(N=100000, tau=5.0)
  >>> g = statisticalInefficiency(A_n, fast=True)
  
  """

  # Create numpy copies of input arguments.
  A_n = numpy.array(A_n)
  if B_n is not None:  
    B_n = numpy.array(B_n)
  else:
    B_n = numpy.array(A_n) 
  
  # Get the length of the timeseries.
  N = A_n.size

  # Be sure A_n and B_n have the same dimensions.
  if(A_n.shape != B_n.shape):
    raise ParameterError('A_n and B_n must have same dimensions.')

  # Initialize statistical inefficiency estimate with uncorrelated value.
  g = 1.0
    
  # Compute mean of each timeseries.
  mu_A = A_n.mean()
  mu_B = B_n.mean()

  # Make temporary copies of fluctuation from mean.
  dA_n = A_n.astype(numpy.float64) - mu_A
  dB_n = B_n.astype(numpy.float64) - mu_B

  # Compute estimator of covariance of (A,B) using estimator that will ensure C(0) = 1.
  sigma2_AB = (dA_n * dB_n).mean() # standard estimator to ensure C(0) = 1

  # Trap the case where this covariance is zero, and we cannot proceed.
  if(sigma2_AB == 0):
    raise ParameterException('Sample covariance sigma_AB^2 = 0 -- cannot compute statistical inefficiency')

  # Accumulate the integrated correlation time by computing the normalized correlation time at
  # increasing values of t.  Stop accumulating if the correlation function goes negative, since
  # this is unlikely to occur unless the correlation function has decayed to the point where it
  # is dominated by noise and indistinguishable from zero.
  t = 1
  increment = 1
  while (t < N-1):

    # compute normalized fluctuation correlation function at time t
    C = sum( dA_n[0:(N-t)]*dB_n[t:N] + dB_n[0:(N-t)]*dA_n[t:N] ) / (2.0 * float(N-t) * sigma2_AB)
    
    # Terminate if the correlation function has crossed zero and we've computed the correlation
    # function at least out to 'mintime'.
    if (C <= 0.0) and (t > mintime):
      break
    
    # Accumulate contribution to the statistical inefficiency.
    g += 2.0 * C * (1.0 - float(t)/float(N)) * float(increment)

    # Increment t and the amount by which we increment t.
    t += increment

    # Increase the interval if "fast mode" is on.
    if fast: increment += 1

  # g must be at least unity
  if (g < 1.0): g = 1.0
   
  # Return the computed statistical inefficiency.
  return g

def compute_volume(box_vectors):
    """
    Return the volume of the periodic cell spanned by three box vectors.

    ARGUMENTS

    box_vectors (sequence of three simtk.unit.Quantity vectors, or None) - the periodic box vectors
       a, b, c, e.g. as returned by State.getPeriodicBoxVectors()

    RETURNS

    volume (simtk.unit.Quantity) - the volume of the system (in units of length^3, using the unit of
       the first box vector), or None if no box vectors are defined (box_vectors is None)

    """
    # Honor the documented contract instead of failing while unpacking None.
    if box_vectors is None:
        return None

    # Volume of the parallelepiped is the determinant of the 3x3 matrix of box vectors, with all
    # vectors expressed in the unit of the first one.
    a, b, c = box_vectors
    length_unit = a.unit
    matrix = numpy.array([a / length_unit, b / length_unit, c / length_unit])
    return numpy.linalg.det(matrix) * length_unit**3

def compute_mass(system):
    """
    Return the total mass of all particles in a System.

    ARGUMENTS

    system (simtk.openmm.System) - the system whose particle masses are summed

    RETURNS

    mass (simtk.unit.Quantity) - the sum of all particle masses, accumulated starting from 0.0 amu

    """
    particle_masses = (system.getParticleMass(index) for index in range(system.getNumParticles()))
    return sum(particle_masses, 0.0 * units.amu)

#=============================================================================================
# Test thermostats
#=============================================================================================

# Create the test system to simulate (with short cutoff).
if verbose: print "Constructing system..."
[system, coordinates] = testsystems.LennardJonesFluid(nx=nx, ny=ny, nz=nz, mass=mass, sigma=sigma, epsilon=epsilon, cutoff=short_cutoff, dispersion_correction=use_dispersion_correction)

# Determine number of degrees of freedom.
kB = units.BOLTZMANN_CONSTANT_kB * units.AVOGADRO_CONSTANT_NA
ndof = 3*system.getNumParticles() - system.getNumConstraints()
nparticles = system.getNumParticles()

# Compute total mass.
system_mass = compute_mass(system).in_units_of(units.gram / units.mole) / units.AVOGADRO_CONSTANT_NA # total system mass in g

# Add Monte Carlo barostat.
barostat = openmm.MonteCarloBarostat(pressure, temperature, barostat_frequency)
system.addForce(barostat)
    
# Create integrator.
integrator = openmm.LangevinIntegrator(temperature, collision_rate, timestep)        

# Create context.
context = openmm.Context(system, integrator, platform)

# Set initial positions.
context.setPositions(coordinates)

# Assign velocities.
velocities = generateMaxwellBoltzmannVelocities(system, temperature)
context.setVelocities(velocities)

# Initialize statistics.
data = dict()
data['time'] = units.Quantity(numpy.zeros([niterations], numpy.float64), units.picoseconds)
data['potential'] = units.Quantity(numpy.zeros([niterations], numpy.float64), units.kilocalories_per_mole)
data['kinetic'] = units.Quantity(numpy.zeros([niterations], numpy.float64), units.kilocalories_per_mole)
data['volume'] = units.Quantity(numpy.zeros([niterations], numpy.float64), units.angstroms**3)
data['density'] = units.Quantity(numpy.zeros([niterations], numpy.float64), units.gram / units.centimeters**3)
data['kinetic_temperature'] = units.Quantity(numpy.zeros([niterations], numpy.float64), units.kelvin)
data['delta_potential'] = units.Quantity(numpy.zeros([niterations], numpy.float64), units.kilocalories_per_mole)
data['reduced_density'] = numpy.zeros([niterations], numpy.float64)

# Equilibrate: run nequiliterations blocks of nsteps each; nothing is recorded in this phase.
if verbose: print "Equilibrating..."
for iteration in range(nequiliterations):
   # Integrate.
   integrator.step(nsteps)

   # Compute properties (only needed for the verbose progress line).
   if verbose:
      state = context.getState(getEnergy=True)
      kinetic = state.getKineticEnergy()
      potential = state.getPotentialEnergy()
      box_vectors = state.getPeriodicBoxVectors()
      volume = compute_volume(box_vectors)
      density = (system_mass / volume).in_units_of(units.gram / units.centimeter**3)
      kinetic_temperature = 2.0 * kinetic / kB / ndof # (1/2) ndof * kB * T = KE
      print "%6d %9.3f %16.3f %16.3f %16.1f %10.3f" % (iteration, state.getTime() / units.picoseconds, kinetic_temperature / units.kelvin, potential / units.kilocalories_per_mole, volume / units.angstroms**3, density / (units.gram / units.centimeter**3))

# Create a copy of the system with long cutoff, and set up integrator/context on Reference system to evaluate energy.
# This context is used only for single-point energy evaluations; its integrator is never stepped.
# NOTE: this rebinds 'coordinates' to the new system's initial coordinates. They are not used:
# the production loop overwrites 'coordinates' before passing them to setPositions.
[system_long_cutoff, coordinates] = testsystems.LennardJonesFluid(nx=nx, ny=ny, nz=nz, mass=mass, sigma=sigma, epsilon=epsilon, cutoff=long_cutoff, dispersion_correction=use_dispersion_correction)
integrator_long_cutoff = openmm.VerletIntegrator(timestep)
reference_platform = openmm.Platform.getPlatformByName('Reference')
context_long_cutoff = openmm.Context(system_long_cutoff, integrator_long_cutoff, reference_platform)

# Collect production data: one record per block of nsteps.
if verbose: print "Production..."
for iteration in range(niterations):
   # Propagate dynamics.
   integrator.step(nsteps)

   # Compute properties of the short-cutoff simulation.
   state = context.getState(getEnergy=True, getPositions=True)
   kinetic = state.getKineticEnergy()
   potential = state.getPotentialEnergy()
   box_vectors = state.getPeriodicBoxVectors()
   volume = compute_volume(box_vectors)
   density = (system_mass / volume).in_units_of(units.gram / units.centimeter**3)
   reduced_density = nparticles * sigma**3 / volume # reduced (dimensionless) density
   kinetic_temperature = 2.0 * kinetic / kB / ndof
   coordinates = state.getPositions(asNumpy=True)

   # Evaluate the same configuration with the long cutoff. The current box vectors are copied
   # as well (the barostat may have changed the volume), so both energies refer to the same
   # periodic cell.
   context_long_cutoff.setPositions(coordinates)
   context_long_cutoff.setPeriodicBoxVectors(*box_vectors)
   state_long_cutoff = context_long_cutoff.getState(getEnergy=True)
   potential_long_cutoff = state_long_cutoff.getPotentialEnergy()
   # Positive values mean the long-cutoff energy is higher than the short-cutoff energy.
   delta_potential = potential_long_cutoff - potential

   if verbose:
      print "%6d %9.3f %16.3f %16.3f %16.1f %10.3f : %10.6e" % (iteration, state.getTime() / units.picoseconds, kinetic_temperature / units.kelvin, potential / units.kilocalories_per_mole, volume / units.angstroms**3, density / (units.gram / units.centimeter**3), delta_potential / units.kilocalories_per_mole)

   # Store properties.
   data['time'][iteration] = state.getTime()
   data['potential'][iteration] = potential
   data['kinetic'][iteration] = kinetic
   data['volume'][iteration] = volume
   data['density'][iteration] = density
   data['kinetic_temperature'][iteration] = kinetic_temperature
   data['delta_potential'][iteration] = delta_potential
   data['reduced_density'][iteration] = reduced_density
   
#=============================================================================================
# Compute statistical inefficiencies to determine effective number of uncorrelated samples.
#=============================================================================================

# Each entry: (key for g, key of the timeseries, unit stripped before the analysis).
for g_key, series_key, series_unit in [
      ('g_potential', 'potential', units.kilocalories_per_mole),
      ('g_kinetic', 'kinetic', units.kilocalories_per_mole),
      ('g_volume', 'volume', units.angstroms**3),
      ('g_density', 'density', units.gram / units.centimeter**3),
      ('g_kinetic_temperature', 'kinetic_temperature', units.kelvin),
      ('g_delta_potential', 'delta_potential', units.kilocalories_per_mole),
      ]:
   data[g_key] = statisticalInefficiency(data[series_key] / series_unit)

#=============================================================================================
# Compute statistics.
#=============================================================================================

# Standard error = std / sqrt(niterations / g), where niterations / g is the effective number of
# uncorrelated samples. g * nsteps * timestep expresses the inefficiency as simulation time.
delta_potential_values = data['delta_potential'] / units.kilocalories_per_mole
statistics = {
   # Long-cutoff minus short-cutoff potential energy.
   'delta_potential': delta_potential_values.mean() * units.kilocalories_per_mole,
   'ddelta_potential': delta_potential_values.std() / numpy.sqrt(niterations / data['g_delta_potential']) * units.kilocalories_per_mole,
   'g_delta_potential': data['g_delta_potential'] * nsteps * timestep,
   # Reduced density. It uses the inefficiency of the mass-density series; both are proportional to 1/volume.
   'density': data['reduced_density'].mean(),
   'ddensity': data['reduced_density'].std() / numpy.sqrt(niterations / data['g_density']),
   'g_density': data['g_density'] * nsteps * timestep,
   }

#=============================================================================================
# Print summary statistics
#=============================================================================================

test_pass = True # this gets set to False if test fails

print "Summary statistics (%.1f ns equil, %.1f ns production)" % (nequiliterations * nsteps * timestep / units.nanoseconds, niterations * nsteps * timestep / units.nanoseconds)

print ""

# Potential energy difference in enlarging cutoff
print "average potential energy difference in enlarging cutoff from %.3f sigma to %.3f sigma:" % (short_cutoff / sigma, long_cutoff / sigma)
print "%12.3f +- %10.3f  kcal/mol  (g = %8.1f ps)" % (statistics['delta_potential'] / units.kilocalories_per_mole, statistics['ddelta_potential'] / units.kilocalories_per_mole, statistics['g_delta_potential'] / units.picoseconds)

# Average density.
print ""
print "reduced density (N sigma**3 / V):"
print "%12.5f +- %10.5f            (g = %8.1f ps)" % (statistics['density'], statistics['ddensity'], statistics['g_density'] / units.picoseconds)

#=============================================================================================
# Check pass or fail conditions.
#=============================================================================================

if abs(statistics['delta_potential']) > (MAX_ALLOWED_POTENTIAL_DIFFERENCE + NSIGMA_CUTOFF * statistics['ddelta_potential']):
   # Test has failed.
   print ""
   print "***** Test exceeds allowed potential difference %s by more than %.1f sigma -- test failed" % (str(MAX_ALLOWED_POTENTIAL_DIFFERENCE), NSIGMA_CUTOFF)
   test_pass = False
   

#=============================================================================================
# Report pass or fail in exit code
#=============================================================================================

if test_pass:
   sys.exit(0)
else:
   sys.exit(1)

   
