#!/usr/bin/env python

"""
Test thermostats built into OpenMM for concordance on a simple system.

DESCRIPTION

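Each thermostat listed in `thermostats` (Maxwell-Boltzmann velocity sampling, Brownian,
Langevin, Andersen, and massive Andersen) is run on the selected test system. The average
kinetic energy is compared with the equipartition value ndof*kT/2. The average potential
energy is compared with numerical quadrature when the system is a harmonic oscillator.
The linear drift of the potential energy is also tested for significance. The script
exits with status 1 if any checked quantity deviates by more than NSIGMA_CUTOFF standard
errors, and with status 0 otherwise.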

COPYRIGHT AND LICENSE

@author John D. Chodera <jchodera@gmail.com>

All code in this repository is released under the GNU General Public License.

This program is free software: you can redistribute it and/or modify it under
the terms of the GNU General Public License as published by the Free Software
Foundation, either version 3 of the License, or (at your option) any later
version.

This program is distributed in the hope that it will be useful, but WITHOUT ANY
WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A
PARTICULAR PURPOSE.  See the GNU General Public License for more details.
 
You should have received a copy of the GNU General Public License along with
this program.  If not, see <http://www.gnu.org/licenses/>.

"""

#=============================================================================================
# GLOBAL IMPORTS
#=============================================================================================

import os
import os.path
import sys
import math

import numpy
import scipy.stats
import scipy.integrate

import simtk
import simtk.unit as units
import simtk.openmm as openmm
import simtk.pyopenmm.extras.testsystems as testsystems

#=============================================================================================
# UTILITY SUBROUTINES
#=============================================================================================

def generateMaxwellBoltzmannVelocities(system, temperature):
    """Generate Maxwell-Boltzmann velocities.

    Each Cartesian velocity component of particle i is drawn independently from a normal
    distribution with zero mean and standard deviation sqrt(kT / m_i).

    ARGUMENTS

    system (simtk.openmm.System) - the system for which velocities are to be assigned
    temperature (simtk.unit.Quantity of temperature) - the temperature at which velocities are to be assigned

    RETURNS

    velocities (simtk.unit.Quantity of numpy Nx3 float32 array, units nanometer/picosecond) - particle velocities

    NOTES

    Constraints are not taken into account: every particle receives three independent
    velocity components. Particles with zero mass are left with zero velocity, since
    sqrt(kT / m) would otherwise divide by zero.

    """

    velocity_unit = units.nanometer / units.picosecond
    natoms = system.getNumParticles()

    # Accumulate plain floats (in nm/ps) and wrap them in a Quantity once at the end.
    # This avoids a unit conversion on every single element assignment.
    values = numpy.zeros([natoms, 3], numpy.float32)

    # Thermal energy from the molar Boltzmann constant.
    kB = units.BOLTZMANN_CONSTANT_kB * units.AVOGADRO_CONSTANT_NA
    kT = kB * temperature

    for atom_index in range(natoms):
        mass = system.getParticleMass(atom_index)
        if mass.value_in_unit(units.amu) <= 0.0:
            # Massless particle: kT / mass is undefined, so keep zero velocity.
            continue
        # Standard deviation of each velocity component for this particle, in nm/ps.
        sigma = units.sqrt(kT / mass).value_in_unit(velocity_unit)
        values[atom_index, :] = sigma * numpy.random.normal(size=3)

    return units.Quantity(values, velocity_unit)

def computeHarmonicOscillatorExpectations(K, mass, temperature):
    """
    Compute mean and standard deviation of potential and kinetic energies for a 3D isotropic harmonic oscillator.

    The potential-energy moments are computed by numerical quadrature over the radial
    coordinate r in [0, 10 sigma], where sigma = 1/sqrt(beta K). The kinetic-energy moments
    are the analytical values for three translational degrees of freedom.

    ARGUMENTS

    K (simtk.unit.Quantity of energy/length**2) - spring constant, with V(r) = (K/2) r**2
    mass (simtk.unit.Quantity of mass) - mass of particle (unused: neither expectation depends on it)
    temperature (simtk.unit.Quantity of temperature) - temperature

    RETURNS

    values (dict) - values['potential'] and values['kinetic'] are dicts with keys 'mean' and
       'stddev', each a simtk.unit.Quantity with energy units.

    """

    values = dict()

    # Compute thermal energy and inverse temperature from specified temperature.
    kB = units.BOLTZMANN_CONSTANT_kB * units.AVOGADRO_CONSTANT_NA
    kT = kB * temperature
    beta = 1.0 / kT

    # Standard deviation of the Boltzmann distribution along one Cartesian dimension.
    sigma = 1.0 / units.sqrt(beta * K)

    # Integration limits along r, as plain numbers in nm. The integrand decays as
    # exp(-r**2 / (2 sigma**2)), so the tail beyond 10 sigma is negligible.
    r_min = 0.0
    r_max = (10.0 * sigma) / units.nanometers

    # Potential in kJ/mol, where r is in nm.
    V = lambda r: (K / 2.0) * (r * units.nanometers)**2 / units.kilojoules_per_mole
    # Unnormalized radial Boltzmann weight 4 pi r**2 exp(-beta V), where r is in nm.
    q = lambda r: 4.0 * math.pi * r**2 * math.exp(-beta * (K / 2.0) * (r * units.nanometers)**2)

    Iq = scipy.integrate.quad(q, r_min, r_max)[0]
    IqV = scipy.integrate.quad(lambda r: q(r) * V(r), r_min, r_max)[0]
    IqV2 = scipy.integrate.quad(lambda r: q(r) * V(r)**2, r_min, r_max)[0]

    # Var(V) = <V^2> - <V>^2. Clamp tiny negative values caused by quadrature round-off.
    mean_V = IqV / Iq
    var_V = max(IqV2 / Iq - mean_V**2, 0.0)

    values['potential'] = dict()
    values['potential']['mean'] = mean_V * units.kilojoules_per_mole
    values['potential']['stddev'] = math.sqrt(var_V) * units.kilojoules_per_mole

    # Kinetic energy of 3 translational DOF: KE/kT ~ Gamma(3/2, 1), so mean = var = 3/2 (in units of kT).
    values['kinetic'] = dict()
    values['kinetic']['mean'] = (3./2.) * kT
    values['kinetic']['stddev'] = math.sqrt(3./2.) * kT

    return values
   
def statisticalInefficiency(A_n, B_n=None, fast=False, mintime=3):
    """
    Compute the (cross) statistical inefficiency of (two) timeseries.

    REQUIRED ARGUMENTS
      A_n (numpy array) - A_n[n] is nth value of timeseries A.  Length is deduced from vector.

    OPTIONAL ARGUMENTS
      B_n (numpy array) - B_n[n] is nth value of timeseries B.  Length is deduced from vector.
         If supplied, the cross-correlation of timeseries A and B will be estimated instead of the
         autocorrelation of timeseries A.
      fast (boolean) - if True, use the faster (but less accurate) method of Ref. [1] to estimate
         the correlation time: the lag increment grows by one after each evaluated lag (default: False)
      mintime (int) - minimum amount of correlation function to compute (default: 3)
         The algorithm terminates at the first lag t > mintime at which the correlation function
         is zero or negative.  Note that this time may need to be increased if there is a strong
         initial negative peak in the correlation function.

    RETURNS
      g is the estimated statistical inefficiency (equal to 1 + 2 tau, where tau is the correlation time).
         We enforce g >= 1.0.

    RAISES
      ValueError - if A_n and B_n do not have the same shape, or if the sample covariance of A and B
         is zero (e.g. a constant timeseries).

    NOTES
      A_n and B_n are expected to be one-dimensional.
      Passing the same timeseries as A_n and B_n gives the autocorrelation statistical inefficiency.

    REFERENCES
      [1] J. D. Chodera, W. C. Swope, J. W. Pitera, C. Seok, and K. A. Dill. Use of the weighted
      histogram analysis method for the analysis of simulated and parallel tempering simulations.
      JCTC 3(1):26-41, 2007.

    EXAMPLES

    Compute statistical inefficiency of an uncorrelated timeseries (g should be close to 1).

    >>> A_n = numpy.random.normal(size=10000)
    >>> g = statisticalInefficiency(A_n, fast=True)

    """

    # Work on float64 copies so the caller's data is never modified.
    A_n = numpy.array(A_n, dtype=numpy.float64)
    if B_n is None:
        B_n = A_n
    else:
        B_n = numpy.array(B_n, dtype=numpy.float64)

    if A_n.shape != B_n.shape:
        raise ValueError('A_n and B_n must have same dimensions.')

    N = A_n.size

    # Fluctuations about each mean.
    dA_n = A_n - A_n.mean()
    dB_n = B_n - B_n.mean()

    # Covariance estimator chosen so that the normalized correlation function has C(0) = 1.
    sigma2_AB = (dA_n * dB_n).mean()
    if sigma2_AB == 0:
        raise ValueError('Sample covariance sigma_AB^2 = 0 -- cannot compute statistical inefficiency')

    # Accumulate the integrated correlation time over increasing lags t. Stop once the correlation
    # function goes non-positive (past mintime), since by then it is dominated by noise.
    g = 1.0
    t = 1
    increment = 1
    while t < N - 1:
        # Symmetrized, normalized fluctuation correlation function at lag t.
        C = (numpy.dot(dA_n[:N - t], dB_n[t:]) + numpy.dot(dB_n[:N - t], dA_n[t:])) / (2.0 * float(N - t) * sigma2_AB)

        if C <= 0.0 and t > mintime:
            break

        # The (1 - t/N) weight is the finite-length correction. `increment` accounts for lags
        # skipped in fast mode.
        g += 2.0 * C * (1.0 - float(t) / float(N)) * float(increment)

        t += increment
        if fast:
            increment += 1

    # g must be at least unity.
    return max(g, 1.0)

#=============================================================================================
# Set parameters
#=============================================================================================

NSIGMA_CUTOFF = 6.0 # maximum number of standard deviations away from true value before test fails

# Select fastest platform for tests.
nplatforms = openmm.Platform.getNumPlatforms()
platform_speeds = numpy.zeros([nplatforms], numpy.float64)
for platform_index in range(nplatforms):
    platform = openmm.Platform.getPlatform(platform_index)
    platform_speeds[platform_index] = platform.getSpeed()
platform_index = int(numpy.argmax(platform_speeds))
platform = openmm.Platform.getPlatform(platform_index)

# Prefer the CUDA platform when it is installed. Otherwise keep the fastest platform selected
# above, because getPlatformByName('CUDA') fails when no CUDA platform is available.
# To force a specific platform instead, use e.g.:
#   platform = openmm.Platform.getPlatformByName('Reference')
platform_names = [openmm.Platform.getPlatform(index).getName() for index in range(nplatforms)]
if 'CUDA' in platform_names:
    platform = openmm.Platform.getPlatformByName('CUDA')

# Select run parameters
timestep = 2.0 * units.femtosecond # timestep for integration
nsteps = 250 # number of steps per iteration
nequiliterations = 1000 # number of equilibration iterations
niterations = 5000 # number of integration periods to run

# Set temperature and collision rate for stochastic thermostats.
temperature = 298.0 * units.kelvin
collision_rate = 1.0 / units.picosecond 

# Choose system to test (from testsystems.py).
# TODO: Allow this test to run on multiple systems.
#testsystem = 'HarmonicOscillator'
#testsystem = 'Diatom'
#testsystem = 'ConstraintCoupledHarmonicOscillator'
#testsystem = 'LennardJonesFluid'
#testsystem = 'WaterBox'
testsystem = 'AlanineDipeptideImplicit'

# List of thermostats to test.
thermostats = ['Maxwell-Boltzmann', 'Brownian', 'Langevin', 'Andersen', 'Andersen-massive']
# Maxwell-Boltzmann: generation of velocities from Maxwell-Boltzmann distribution (no dynamics)
# Brownian: BrownianIntegrator
# Langevin: LangevinIntegrator
# Andersen: AndersenThermostat added to system, VerletIntegrator
# Andersen-massive: Maxwell-Boltzmann velocities assigned every iteration, VerletIntegrator used in between

# Flag to set verbose debug output
verbose = True

#=============================================================================================
# Initialize statistics.
#=============================================================================================

# Per-thermostat results, keyed by thermostat name.
data = {}

#=============================================================================================
# Test thermostats
#=============================================================================================

# Run every thermostat on its own freshly constructed copy of the test system. After each
# production iteration, record the simulation time, potential energy, and kinetic energy.
for thermostat in thermostats:
    # Create a fresh copy of the test system for this thermostat. The Andersen branch below adds
    # a force to the system, so systems are not reused between thermostats.
    constructor = getattr(testsystems, testsystem)
    [system, coordinates] = constructor()

    # Compute number of degrees of freedom: 3 per particle, minus one per constraint.
    # NOTE(review): this does not account for removed center-of-mass motion if the test system
    # contains a CMMotionRemover -- confirm against testsystems.
    ndof = 3*system.getNumParticles() - system.getNumConstraints()
    
    # Compute analytically expected average kinetic energy (kT/2 per degree of freedom).
    kB = units.BOLTZMANN_CONSTANT_kB * units.AVOGADRO_CONSTANT_NA
    kT = kB * temperature # thermal energy
    EKE_analytical = ndof * 0.5 * kT

    # Compute expectations numerically, if we can (only implemented for the harmonic oscillator).
    # NOTE(review): K and mass are assumed to match the parameters used by
    # testsystems.HarmonicOscillator -- TODO confirm.
    numerical_expectations = None
    if testsystem == 'HarmonicOscillator':
       K = 100.0 * units.kilojoules_per_mole / units.angstrom**2
       mass = 39.948 * units.amu
       numerical_expectations = computeHarmonicOscillatorExpectations(K, mass, temperature)
       # Oscillation period; only used by the commented-out diagnostic print below.
       period = 2*math.pi*units.sqrt(mass/K)
       #print "Period = %f ps" % (period / units.picoseconds)

    # Create the integrator for this thermostat.
    # - 'Andersen' adds an AndersenThermostat force to the system and integrates with Verlet.
    # - 'Andersen-massive' and 'Maxwell-Boltzmann' use a plain VerletIntegrator and
    #   reassign velocities in the production loop below.
    if thermostat == 'Brownian':
        integrator = openmm.BrownianIntegrator(temperature, collision_rate, timestep)
    elif thermostat == 'Langevin':
        integrator = openmm.LangevinIntegrator(temperature, collision_rate, timestep)
    elif thermostat == 'Andersen':
        # add Andersen thermostat
        force = openmm.AndersenThermostat(temperature, collision_rate)
        system.addForce(force)
        integrator = openmm.VerletIntegrator(timestep)
    elif thermostat in ['Andersen-massive', 'Maxwell-Boltzmann']:
        # Andersen with periodic massive collisions
        integrator = openmm.VerletIntegrator(timestep)        
    else:
        raise Exception("Unknown thermostat '%s'" % thermostat)

    # Create context on the platform selected in the run parameters.
    context = openmm.Context(system, integrator, platform)

    # Set initial positions.
    context.setPositions(coordinates)

    # Minimize energy.
    if verbose: print "Minimizing energy..."
    openmm.LocalEnergyMinimizer.minimize(context)
    
    # Generate initial velocities from Maxwell-Boltzmann distribution.
    velocities = generateMaxwellBoltzmannVelocities(system, temperature)
    context.setVelocities(velocities)
            
    # Initialize per-iteration storage: time in ps, energies in kcal/mol.
    data[thermostat] = dict()
    data[thermostat]['time'] = units.Quantity(numpy.zeros([niterations], numpy.float64), units.picoseconds)
    data[thermostat]['potential'] = units.Quantity(numpy.zeros([niterations], numpy.float64), units.kilocalories_per_mole)
    data[thermostat]['kinetic'] = units.Quantity(numpy.zeros([niterations], numpy.float64), units.kilocalories_per_mole)

    # Equilibrate. No velocity reassignment happens here, so for 'Andersen-massive' and
    # 'Maxwell-Boltzmann' this is unthermostatted Verlet dynamics from the initial velocities.
    if verbose: print "Equilibrating..."
    integrator.step(nequiliterations * nsteps)

    # Production
    if verbose: print "Production..."
    for iteration in range(niterations):
        if thermostat in ['Andersen-massive', 'Maxwell-Boltzmann']:
            # Reassign velocities from Maxwell-Boltzmann distribution.
            velocities = generateMaxwellBoltzmannVelocities(system, temperature)
            context.setVelocities(velocities)
        
        # Propagate dynamics. 'Maxwell-Boltzmann' takes only a single step after each reassignment
        # (the "no dynamics" case described with the thermostat list).
        # NOTE(review): presumably the single step is there so constraints act on the freshly
        # assigned velocities before the kinetic energy is read -- confirm.
        if thermostat not in ['Maxwell-Boltzmann']:
           integrator.step(nsteps)
        else:
           integrator.step(1)
        
        # Compute energies.
        state = context.getState(getEnergy=True)
        kinetic = state.getKineticEnergy()
        potential = state.getPotentialEnergy()
        # The kinetic energy column is printed as 'N/A' for Brownian dynamics.
        if verbose:
            if thermostat != 'Brownian':
                print "%6d %9.3f %16.3f %16.3f" % (iteration, state.getTime() / units.picoseconds, kinetic / units.kilocalories_per_mole, potential / units.kilocalories_per_mole)
            else:
                print "%6d %9.3f %16s %16.3f" % (iteration, state.getTime() / units.picoseconds, 'N/A', potential / units.kilocalories_per_mole)                
        
        # Store energies into the unit-bearing arrays created above.
        data[thermostat]['time'][iteration] = state.getTime() 
        data[thermostat]['potential'][iteration] = potential 
        data[thermostat]['kinetic'][iteration] = kinetic 

    # Clean up before the next thermostat.
    del system, integrator, context   

#=============================================================================================
# Compute statistical inefficiencies to determine effective number of uncorrelated samples.
#=============================================================================================

# The statistical inefficiency g of each energy timeseries gives the effective number of
# uncorrelated samples, niterations / g.
for thermostat in thermostats:
    series = data[thermostat]
    series['g_potential'] = statisticalInefficiency(series['potential'] / units.kilocalories_per_mole)
    # Kinetic energy is not analyzed for Brownian dynamics.
    if thermostat != 'Brownian':
        series['g_kinetic'] = statisticalInefficiency(series['kinetic'] / units.kilocalories_per_mole)

#=============================================================================================
# Compute expectations and uncertainties for kinetic and potential energies.
#=============================================================================================

# For each thermostat, estimate:
# - mean energies, with standard errors inflated by the statistical inefficiency g;
# - the number of standard errors separating each mean from its reference value;
# - the linear drift of the potential energy.
for thermostat in thermostats:
    d = data[thermostat]

    # Kinetic energy (not analyzed for Brownian dynamics).
    if thermostat != 'Brownian':
        ke_values = d['kinetic'] / units.kilocalories_per_mole
        n_effective = niterations / d['g_kinetic']
        d['KE'] = ke_values.mean() * units.kilocalories_per_mole
        d['dKE'] = ke_values.std() / numpy.sqrt(n_effective) * units.kilocalories_per_mole
        # Statistical inefficiency expressed in ps (g times the sampling interval).
        d['gKE'] = d['g_kinetic'] * nsteps * timestep / units.picoseconds
        d['KE-nsigma'] = abs(d['KE'] - EKE_analytical) / d['dKE']

    # Potential-energy statistics and drift are skipped for pure velocity sampling.
    if thermostat == 'Maxwell-Boltzmann':
        continue

    pe_values = d['potential'] / units.kilocalories_per_mole
    d['PE'] = pe_values.mean() * units.kilocalories_per_mole
    d['dPE'] = pe_values.std() / numpy.sqrt(niterations / d['g_potential']) * units.kilocalories_per_mole
    d['gPE'] = d['g_potential'] * nsteps * timestep / units.picoseconds
    if numerical_expectations is not None:
        d['PE-nsigma'] = abs(d['PE'] - numerical_expectations['potential']['mean']) / d['dPE']

    # Drift: least-squares slope of potential energy (kcal/mol) against time (ns). The slope's
    # standard error is inflated by sqrt(g) to account for correlation between samples.
    times_ns = d['time'] / units.nanoseconds
    fit_slope, _intercept, _rvalue, _pvalue, fit_stderr = scipy.stats.linregress(times_ns, pe_values)
    fit_stderr *= numpy.sqrt(d['g_potential'])
    d['drift'] = fit_slope * units.kilocalories_per_mole / units.nanoseconds
    d['ddrift'] = fit_stderr * units.kilocalories_per_mole / units.nanoseconds
    d['drift-nsigma'] = abs(fit_slope / fit_stderr)

#=============================================================================================
# Print summary statistics
#=============================================================================================

# Print the summary tables. Any checked quantity more than NSIGMA_CUTOFF standard errors from
# its reference is marked with '***' and clears test_pass.
test_pass = True # this gets set to False if test fails

print "Summary statistics for system '%s' platform '%s' (%.3f ns equil, %.3f ns production)" % (testsystem, platform.getName(), nequiliterations * nsteps * timestep / units.nanoseconds, niterations * nsteps * timestep / units.nanoseconds)

print ""

# Kinetic energies: each thermostat's mean vs. the equipartition value ndof*kT/2.
# Each row is built from several Python 2 'print ...,' statements (trailing comma = no newline)
# and finished by "print ''".
print "average kinetic energies:"
print "%32s  %12.3f    %12s  kcal/mol" % ("analytical (from ndof)", EKE_analytical / units.kilocalories_per_mole, "") # analytical
for thermostat in thermostats:
   if thermostat == 'Brownian': continue # Brownian doesn't use kinetic energy
   d = data[thermostat]
   # 'g' column: statistical inefficiency expressed in ps (g times the sampling interval).
   print "%32s  %12.3f +- %12.3f  kcal/mol  (g = %8.1f ps)" % (thermostat, d['KE'] / units.kilocalories_per_mole, d['dKE'] / units.kilocalories_per_mole, d['gKE']),
   print '  %5.1f sigma' % d['KE-nsigma'],
   if (d['KE-nsigma'] > NSIGMA_CUTOFF):
      print ' ***',
      test_pass = False
   print ''
      
# Show numerical results if we are able to compute them.
if testsystem == 'HarmonicOscillator':
   print "%32s  %12.3f    %12s  kcal/mol" % ("numerical", numerical_expectations['kinetic']['mean'] / units.kilocalories_per_mole, "")
print ""

# Potential energies. These are only checked against NSIGMA_CUTOFF when numerical expectations exist.
print "average potential energies:"
for thermostat in thermostats:
   if thermostat not in ['Maxwell-Boltzmann']:
      d = data[thermostat]    
      print "%32s  %12.3f +- %12.3f  kcal/mol  (g = %8.1f ps)" % (thermostat, d['PE'] / units.kilocalories_per_mole, d['dPE'] / units.kilocalories_per_mole, d['gPE']),
      if numerical_expectations is not None:
         print '  %5.1f sigma' % d['PE-nsigma'],
         if (d['PE-nsigma'] > NSIGMA_CUTOFF):
            print ' ***',
            test_pass = False
      print ''

# Show numerical results if we are able to compute them.
if testsystem == 'HarmonicOscillator':
   print "%32s  %12.3f    %12s  kcal/mol" % ("numerical", numerical_expectations['potential']['mean'] / units.kilocalories_per_mole, "")
print ""

# TODO: When no numerical reference exists, potential energies are not checked at all.
# Compare them across thermostats (consensus value) and flag deviations beyond NSIGMA_CUTOFF.

# Drift in potential energy: a slope significantly different from zero fails the test.
print "drift in potential energies:"
for thermostat in thermostats:
   if thermostat not in ['Maxwell-Boltzmann', 'Brownian']:
      d = data[thermostat]    
      print "%32s  %12.5f +- %12.5f  kcal/mol/ns  %5.1f sigma" % (thermostat, d['drift'] / (units.kilocalories_per_mole / units.nanosecond), d['ddrift'] / (units.kilocalories_per_mole / units.nanosecond), d['drift-nsigma']),
      if (d['drift-nsigma'] > NSIGMA_CUTOFF):
         print ' ***',
         test_pass = False
      print ''
            
print ""

#=============================================================================================
# Report pass or fail in exit code
#=============================================================================================

# Report the overall outcome through the process exit status: 0 = pass, 1 = fail.
sys.exit(0 if test_pass else 1)

   
