#!/usr/bin/env python

"""
Test of Andersen thermostat.

"""

#=============================================================================================
# GLOBAL IMPORTS
#=============================================================================================

import os
import os.path
import numpy
import math

import simtk
import simtk.chem.openmm as openmm
import simtk.unit as units

#=============================================================================================
# Set parameters
#=============================================================================================

# Select platform
platform = openmm.Platform.getPlatformByName('Reference')

# Select run parameters
timestep = 1.0 * units.femtosecond # timestep for integration
nsteps = 1000 # number of integration steps per iteration
nequiliterations = 1000 # number of equilibration iterations
niterations = 10000 # number of production iterations (one energy sample is stored per iteration)

# Set temperature and collision rate for stochastic thermostats.
temperature = 298.0 * units.kelvin
collision_rate = 5.0 / units.picosecond

# Choose system to test (name of a constructor in testsystems.py)
testsystem = 'HarmonicOscillator'
#testsystem = 'Diatom'
#testsystem = 'ConstraintCoupledHarmonicOscillator'
#testsystem = 'LennardJonesFluid'
#testsystem = 'WaterBox'
#testsystem = 'AlanineDipeptideImplicit'

# List of thermostats to test
thermostats = ['Maxwell-Boltzmann', 'Brownian', 'Langevin', 'Andersen', 'Andersen-massive']
#thermostats = ['Brownian', 'Langevin', 'Andersen', 'Andersen-massive']
#thermostats = ['Langevin', 'Andersen', 'Andersen-massive']


# Flag to set verbose debug output
verbose = False

#=============================================================================================
# SUBROUTINES
#=============================================================================================

def generateMaxwellBoltzmannVelocities(system, temperature):
   """Generate Maxwell-Boltzmann velocities.

   ARGUMENTS

   system (simtk.chem.openmm.System) - the system for which velocities are to be assigned
   temperature (simtk.unit.Quantity of temperature) - the temperature at which velocities are to be assigned

   RETURNS

   velocities (simtk.unit.Quantity of numpy Nx3 float64 array, units nanometer/picosecond) - particle velocities.
      Each Cartesian component of particle i is drawn independently from a normal
      distribution with zero mean and standard deviation sqrt(kT / m_i).
      Particles with zero mass are given zero velocity.

   """

   # Velocities are computed as plain floats in this unit and wrapped in a Quantity on return.
   velocity_unit = units.nanometer / units.picosecond

   # Get number of particles.
   natoms = system.getNumParticles()

   # Compute the (molar) thermal energy from the specified temperature.
   kB = units.BOLTZMANN_CONSTANT_kB * units.AVOGADRO_CONSTANT_NA
   kT = kB * temperature # thermal energy

   # Per-particle standard deviation of each velocity component, sqrt(kT/m), in velocity_unit.
   sigmas = numpy.zeros([natoms], numpy.float64)
   for atom_index in range(natoms):
      mass = system.getParticleMass(atom_index) # atomic mass
      # A massless particle would make kT/m divide by zero; leave its velocity at zero.
      if mass / units.amu > 0.0:
         sigmas[atom_index] = units.sqrt(kT / mass).value_in_unit(velocity_unit)

   # Draw all 3N standard normal deviates at once and scale each row by its particle's sigma.
   velocities = numpy.random.normal(size=[natoms, 3]) * sigmas[:, numpy.newaxis]

   # Return velocities
   return units.Quantity(velocities, velocity_unit)

def computeHarmonicOscillatorExpectations(K, mass, temperature):
   """
   Compute mean and standard deviation of potential and kinetic energies for a 3D isotropic harmonic oscillator.

   The potential-energy moments are obtained by numerical quadrature over the radial
   coordinate r with Boltzmann weight q(r) = 4 pi r^2 exp(-K r^2 / (2 kT)). The
   kinetic-energy moments use the closed-form result for three degrees of freedom:
   mean (3/2) kT, standard deviation sqrt(3/2) kT.

   ARGUMENTS

   K (simtk.unit.Quantity of energy/length**2) - spring constant
   mass (simtk.unit.Quantity of mass) - mass of particle (unused: the statistics computed
      here do not depend on it; kept for interface compatibility)
   temperature (simtk.unit.Quantity of temperature) - temperature

   RETURNS

   values (dict) - values[kind]['mean'] and values[kind]['stddev'] as simtk.unit.Quantity
      energies, for kind in ('potential', 'kinetic')

   """
   import scipy.integrate

   values = dict()

   # Compute thermal energy from specified temperature.
   kB = units.BOLTZMANN_CONSTANT_kB * units.AVOGADRO_CONSTANT_NA
   kT = kB * temperature # thermal energy

   # Strip units once (energies in kJ/mol, lengths in nm) so the integrands are plain floats
   # and no unit arithmetic happens on every quadrature evaluation.
   energy_unit = units.kilojoules_per_mole
   kT_value = kT.value_in_unit(energy_unit)
   K_value = K.value_in_unit(energy_unit / units.nanometers**2)

   # Standard deviation along one dimension, sqrt(kT/K), in nm.
   sigma = math.sqrt(kT_value / K_value)

   # Limits of integration along r (nm); the weight is negligible beyond 10 sigma.
   r_min = 0.0
   r_max = 10.0 * sigma

   V = lambda r : (K_value / 2.0) * r**2 # potential in kJ/mol, r in nm
   q = lambda r : 4.0 * math.pi * r**2 * math.exp(-V(r) / kT_value) # unnormalized radial density
   (IqV2, dIqV2) = scipy.integrate.quad(lambda r : q(r) * V(r)**2, r_min, r_max)
   (IqV, dIqV)   = scipy.integrate.quad(lambda r : q(r) * V(r), r_min, r_max)
   (Iq, dIq)     = scipy.integrate.quad(lambda r : q(r), r_min, r_max)

   # Standard deviation is sqrt(<V^2> - <V>^2); clamp tiny negative round-off before the sqrt.
   mean_V = IqV / Iq
   variance_V = max(IqV2 / Iq - mean_V**2, 0.0)
   values['potential'] = dict()
   values['potential']['mean'] = mean_V * energy_unit
   values['potential']['stddev'] = math.sqrt(variance_V) * energy_unit

   # Kinetic energy for 3 degrees of freedom: mean (3/2) kT, variance (3/2) (kT)^2.
   values['kinetic'] = dict()
   values['kinetic']['mean'] = (3./2.) * kT
   values['kinetic']['stddev'] = math.sqrt(3./2.) * kT

   return values
   
#=============================================================================================
# Initialize statistics.
#=============================================================================================

data = {} # per-thermostat statistics, keyed by thermostat name

#=============================================================================================
# Test thermostats
#=============================================================================================

# Project-local helper modules; import them once rather than on every pass of the loop.
import testsystems
import optimize

# Reference expectations depend only on the test system and temperature, not on the
# thermostat, so compute them once before the loop.
if testsystem == 'HarmonicOscillator':
    # NOTE(review): K and mass are assumed to match the parameters used by
    # testsystems.HarmonicOscillator -- verify if that constructor changes.
    K = 1.0 * units.kilojoules_per_mole / units.nanometer**2
    mass = 39.948 * units.amu
    numerical_expectations = computeHarmonicOscillatorExpectations(K, mass, temperature)

for thermostat in thermostats:
    # Build a fresh copy of the test system for each thermostat; the 'Andersen'
    # branch below adds a force to it, which must not carry over to later runs.
    constructor = getattr(testsystems, testsystem)
    [system, coordinates] = constructor()

    # Minimize coordinates.
    if verbose: print("Minimizing energy...")
    minimizer = optimize.LBFGSMinimizer(system, verbose=verbose, platform=platform)
    coordinates = minimizer.minimize(coordinates, constrain=True)

    # Create the integrator (and, for 'Andersen', attach the thermostat force).
    if thermostat == 'Brownian':
        integrator = openmm.BrownianIntegrator(temperature, collision_rate, timestep)
    elif thermostat == 'Langevin':
        integrator = openmm.LangevinIntegrator(temperature, collision_rate, timestep)
    elif thermostat == 'Andersen':
        # The Andersen thermostat is added to the system as a force; dynamics use Verlet.
        force = openmm.AndersenThermostat(temperature, collision_rate)
        system.addForce(force)
        integrator = openmm.VerletIntegrator(timestep)
    elif thermostat in ['Andersen-massive', 'Maxwell-Boltzmann']:
        # Plain Verlet dynamics; all velocities are resampled by hand in the production loop.
        integrator = openmm.VerletIntegrator(timestep)
    else:
        raise Exception("Unknown thermostat '%s'" % thermostat)

    # Create context.
    context = openmm.Context(system, integrator, platform)

    # Set initial positions.
    context.setPositions(coordinates)

    # Generate initial velocities from Maxwell-Boltzmann distribution.
    velocities = generateMaxwellBoltzmannVelocities(system, temperature)
    context.setVelocities(velocities)

    # Initialize statistics: one sample per production iteration.
    data[thermostat] = dict()
    data[thermostat]['time'] = numpy.zeros([niterations], numpy.float64)
    data[thermostat]['potential'] = numpy.zeros([niterations], numpy.float64)
    data[thermostat]['kinetic'] = numpy.zeros([niterations], numpy.float64)

    # Equilibrate.
    if verbose: print("Equilibrating...")
    integrator.step(nequiliterations * nsteps)

    # Production
    if verbose: print("Production...")
    for iteration in range(niterations):
        if thermostat in ['Andersen-massive', 'Maxwell-Boltzmann']:
            # Reassign velocities from Maxwell-Boltzmann distribution.
            velocities = generateMaxwellBoltzmannVelocities(system, temperature)
            context.setVelocities(velocities)

        # Propagate dynamics ('Maxwell-Boltzmann' only resamples velocities; positions stay fixed).
        if thermostat not in ['Maxwell-Boltzmann']:
            integrator.step(nsteps)

        # Compute energies.
        state = context.getState(getEnergy=True)
        kinetic = state.getKineticEnergy()
        potential = state.getPotentialEnergy()
        if verbose:
            if thermostat != 'Brownian':
                print("%6d %9.3f %16.3f %16.3f" % (iteration, state.getTime() / units.picoseconds, kinetic / units.kilocalories_per_mole, potential / units.kilocalories_per_mole))
            else:
                print("%6d %9.3f %16s %16.3f" % (iteration, state.getTime() / units.picoseconds, 'N/A', potential / units.kilocalories_per_mole))

        # Store energies (time in ps, energies in kcal/mol).
        data[thermostat]['time'][iteration] = state.getTime() / units.picoseconds
        data[thermostat]['potential'][iteration] = potential / units.kilocalories_per_mole
        data[thermostat]['kinetic'][iteration] = kinetic / units.kilocalories_per_mole

    # Clean up.
    del system, integrator, context

#=============================================================================================
# Compute statistical inefficiencies to determine effective number of uncorrelated samples.
#=============================================================================================

import timeseries
for thermostat in thermostats:
    stats = data[thermostat]
    # 'Maxwell-Boltzmann' never propagates positions, so its potential-energy series is
    # constant; a zero-variance series has no defined statistical inefficiency, and the
    # summary does not report it.
    if thermostat != 'Maxwell-Boltzmann':
        stats['g_potential'] = timeseries.statisticalInefficiency(stats['potential'])
    # Brownian dynamics has no meaningful kinetic energy, so it is not analyzed.
    if thermostat != 'Brownian':
        stats['g_kinetic'] = timeseries.statisticalInefficiency(stats['kinetic'])

#=============================================================================================
# Compute drift
#=============================================================================================

import scipy.stats
for thermostat in thermostats:
    # 'Maxwell-Boltzmann' never advances time, so every sample has the same time stamp and a
    # regression against time is undefined; its drift is not reported in the summary.
    if thermostat == 'Maxwell-Boltzmann':
        continue
    time = data[thermostat]['time'] / 1000.0 # ps -> ns, so the slope is in kcal/mol/ns
    potential = data[thermostat]['potential'] # kcal/mol
    (slope, intercept, r, tt, stderr) = scipy.stats.linregress(time, potential)
    data[thermostat]['drift'] = slope
    data[thermostat]['ddrift'] = stderr

#=============================================================================================
# Summary statistics
#=============================================================================================

print "Summary statistics for system '%s' (%.3f ns equil, %.3f ns production) for '%s' platform" % (testsystem, nequiliterations * nsteps * timestep / units.nanoseconds, niterations * nsteps * timestep / units.nanoseconds, platform.getName())

print ""

print "average kinetic energies:"
for thermostat in thermostats:
   if thermostat == 'Brownian': continue # Brownian doesn't use kinetic energy
   d = data[thermostat]
   print "%32s  %12.3f +- %12.3f  kcal/mol  (g = %8.1f ps)" % (thermostat, d['kinetic'].mean(), d['kinetic'].std() / numpy.sqrt(niterations / d['g_kinetic']), d['g_kinetic'] * nsteps * timestep / units.picoseconds)
# Show numerical results if we are able to compute them.
if testsystem == 'HarmonicOscillator':
   print "%32s  %12.3f    %12s  kcal/mol" % ("numerical", numerical_expectations['kinetic']['mean'] / units.kilocalories_per_mole, "")
print ""

# TODO: Check whether these values differ from expected or consensus values by more than, say, six sigma.

print "average potential energies:"
for thermostat in thermostats:
   if thermostat not in ['Maxwell-Boltzmann']:
      d = data[thermostat]    
      print "%32s  %12.3f +- %12.3f  kcal/mol  (g = %8.1f ps)" % (thermostat, d['potential'].mean(), d['potential'].std() / numpy.sqrt(niterations / d['g_potential']), d['g_potential'] * nsteps * timestep / units.picoseconds)
# Show numerical results if we are able to compute them.
if testsystem == 'HarmonicOscillator':
   print "%32s  %12.3f    %12s  kcal/mol" % ("numerical", numerical_expectations['potential']['mean'] / units.kilocalories_per_mole, "")
print ""

# TODO: Check whether these values differ from expected or consensus values by more than, say, six sigma.

print "drift in potential energies:"
for thermostat in thermostats:
   if thermostat not in ['Maxwell-Boltzmann']:
      d = data[thermostat]    
      print "%32s  %12.3f +- %12.3f  kcal/mol/ns" % (thermostat, d['drift'], d['ddrift'])
print ""


