#!/usr/bin/env python

#=============================================================================================
# MODULE DOCSTRING
#=============================================================================================

"""
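Estimate steady-state nonequilibrium free energy of a TIP3P water box as a function of number of water molecules and timestep.

Equilibrium samples are generated with GHMC. Starting from each sample, VVVR dynamics is run for each timestep and number of steps, and the accumulated pseudowork, heat, and total energies are stored in a NetCDF file.

Uses mpi4py to parallelize over water box sizes, with one box size per MPI rank.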

"""

#=============================================================================================
# GLOBAL IMPORTS
#=============================================================================================

import sys
import math
import doctest
import time
import numpy

import simtk.unit as units
import simtk.openmm as openmm
import simtk.pyopenmm.extras.testsystems as testsystems

import netCDF4 as netcdf 

#=============================================================================================
# CONSTANTS
#=============================================================================================

# Molar Boltzmann constant (Boltzmann constant times Avogadro's number); kB * temperature is then a molar thermal energy.
kB = units.BOLTZMANN_CONSTANT_kB * units.AVOGADRO_CONSTANT_NA

#=============================================================================================
# INTEGRATORS
#=============================================================================================

def GHMCIntegrator(timestep, temperature=298.0*units.kelvin, gamma=50.0/units.picoseconds):
    """
    Create a generalized hybrid Monte Carlo (GHMC) integrator.

    ARGUMENTS

    timestep (simtk.unit.Quantity compatible with femtoseconds) - the integration timestep
    temperature (simtk.unit.Quantity compatible with kelvin) - the temperature
    gamma (simtk.unit.Quantity compatible with 1/picoseconds) - the collision rate

    RETURNS

    integrator (simtk.openmm.CustomIntegrator) - a GHMC integrator

    NOTES

    This integrator is equivalent to a Langevin integrator in the velocity Verlet discretization with a
    Metropolization step to ensure sampling from the appropriate distribution.

    Each step performs:
      1. a partial velocity randomization, v <- sqrt(b) v + sqrt(1-b) sigma xi, with b = exp(-gamma*timestep);
      2. a constrained velocity Verlet step, accepted or rejected with the Metropolis criterion on the
         change in total (kinetic + potential) energy. On rejection, positions are restored and the old
         velocities are restored with their sign reversed;
      3. a second partial velocity randomization identical to (1).

    Additional global variables 'ntrials' and 'naccept' keep track of how many trials have been attempted and
    accepted, respectively. Callers reset them to zero to measure acceptance over an interval.

    TODO

    Move initialization of 'sigma' to setting the per-particle variables.

    """

    kT = kB * temperature # thermal energy

    integrator = openmm.CustomIntegrator(timestep)

    integrator.addGlobalVariable("kT", kT) # thermal energy
    integrator.addGlobalVariable("b", numpy.exp(-gamma*timestep)) # velocity mixing parameter
    integrator.addPerDofVariable("sigma", 0) # per-dof velocity standard deviation, sqrt(kT/m)
    integrator.addGlobalVariable("ke", 0) # kinetic energy
    integrator.addPerDofVariable("vold", 0) # old velocities
    integrator.addPerDofVariable("xold", 0) # old positions
    integrator.addGlobalVariable("Eold", 0) # old energy
    integrator.addGlobalVariable("Enew", 0) # new energy
    integrator.addGlobalVariable("accept", 0) # accept or reject
    integrator.addGlobalVariable("naccept", 0) # number accepted
    integrator.addGlobalVariable("ntrials", 0) # number of Metropolization trials
    integrator.addPerDofVariable("x1", 0) # position before application of constraints

    #
    # Pre-computation.
    # This only needs to be done once, but it needs to be done for each degree of freedom.
    # It is recomputed every step here; see TODO above.
    #
    integrator.addComputePerDof("sigma", "sqrt(kT/m)")

    #
    # Velocity perturbation.
    #
    integrator.addComputePerDof("v", "sqrt(b)*v + sqrt(1-b)*sigma*gaussian")
    integrator.addConstrainVelocities()

    #
    # Metropolized symplectic step.
    #
    integrator.addUpdateContextState()
    integrator.addComputeSum("ke", "0.5*m*v*v")
    integrator.addComputeGlobal("Eold", "ke + energy")
    integrator.addComputePerDof("xold", "x")
    integrator.addComputePerDof("vold", "v")
    integrator.addComputePerDof("v", "v + 0.5*dt*f/m")
    integrator.addComputePerDof("x", "x + v*dt")
    integrator.addComputePerDof("x1", "x")
    integrator.addConstrainPositions()
    # Second half-kick plus the velocity correction implied by the position constraint displacement.
    integrator.addComputePerDof("v", "v + 0.5*dt*f/m + (x-x1)/dt")
    integrator.addConstrainVelocities()
    integrator.addComputeSum("ke", "0.5*m*v*v")
    integrator.addComputeGlobal("Enew", "ke + energy")
    # accept = 1 with probability min(1, exp(-(Enew-Eold)/kT)), else 0.
    integrator.addComputeGlobal("accept", "step(exp(-(Enew-Eold)/kT) - uniform)")
    integrator.addComputePerDof("x", "x*accept + xold*(1-accept)")
    # On rejection the old velocities are restored with reversed sign (momentum flip).
    integrator.addComputePerDof("v", "v*accept - vold*(1-accept)")

    #
    # Velocity randomization
    #
    integrator.addComputePerDof("v", "sqrt(b)*v + sqrt(1-b)*sigma*gaussian")
    integrator.addConstrainVelocities()

    #
    # Accumulate statistics.
    #
    integrator.addComputeGlobal("naccept", "naccept + accept")
    integrator.addComputeGlobal("ntrials", "ntrials + 1")

    return integrator

def VVVRIntegrator(timestep, temperature=298.0*units.kelvin, gamma=50.0/units.picoseconds):
    """
    Create a velocity verlet with velocity randomization (VVVR) integrator.

    ARGUMENTS

    timestep (simtk.unit.Quantity compatible with femtoseconds) - the integration timestep
    temperature (simtk.unit.Quantity compatible with kelvin) - the temperature
    gamma (simtk.unit.Quantity compatible with 1/picoseconds) - the collision rate

    RETURNS

    integrator (simtk.openmm.CustomIntegrator) - a VVVR integrator

    NOTES

    This integrator is equivalent to a Langevin integrator in the velocity Verlet discretization with a
    timestep correction to ensure that the field-free diffusion constant is timestep invariant.

    Each step performs a partial velocity randomization, v <- sqrt(b) v + sqrt(1-b) sigma xi with
    b = exp(-gamma*timestep), then a constrained velocity Verlet step (no Metropolis test), then a
    second partial velocity randomization.

    Global variables accumulated across steps (never reset by the integrator itself):
      'pseudowork' - sum over steps of the change in total (kinetic + potential) energy across the
                     velocity Verlet substep; can be used to correct the sampled statistics or in a
                     Metropolization scheme.
      'heat'       - sum of the kinetic energy changes produced by the two velocity randomizations.
      'naccept', 'ntrials' - both incremented by one every step, since there is no acceptance test.
    The global 'accept' is declared but never updated by this integrator.

    TODO

    Move initialization of 'sigma' to setting the per-particle variables.
    We can ditch pseudowork and instead use total energy difference - heat.

    """

    kT = kB * temperature

    integrator = openmm.CustomIntegrator(timestep)

    integrator.addGlobalVariable("kT", kT) # thermal energy
    integrator.addGlobalVariable("b", numpy.exp(-gamma*timestep)) # velocity mixing parameter
    integrator.addPerDofVariable("sigma", 0) # per-dof velocity standard deviation, sqrt(kT/m)
    integrator.addGlobalVariable("ke_old", 0) # kinetic energy before a velocity randomization
    integrator.addGlobalVariable("ke_new", 0) # kinetic energy after a velocity randomization
    integrator.addGlobalVariable("ke", 0) # kinetic energy
    integrator.addGlobalVariable("Eold", 0) # total energy before the velocity Verlet substep
    integrator.addGlobalVariable("Enew", 0) # total energy after the velocity Verlet substep
    integrator.addGlobalVariable("accept", 0) # declared but never updated in this integrator
    integrator.addGlobalVariable("naccept", 0) # incremented every step (no Metropolis test)
    integrator.addGlobalVariable("ntrials", 0) # number of steps taken
    integrator.addPerDofVariable("x1", 0) # position before application of constraints

    integrator.addGlobalVariable("pseudowork", 0) # accumulated pseudowork
    integrator.addGlobalVariable("heat", 0) # accumulated heat

    #
    # Pre-computation.
    # This only needs to be done once, but it needs to be done for each degree of freedom.
    # It is recomputed every step here; see TODO above.
    #
    integrator.addComputePerDof("sigma", "sqrt(kT/m)")

    #
    # Velocity perturbation (first half of the randomization); its kinetic energy change is added to 'heat'.
    #
    integrator.addComputeSum("ke_old", "0.5*m*v*v")
    integrator.addComputePerDof("v", "sqrt(b)*v + sqrt(1-b)*sigma*gaussian")
    integrator.addConstrainVelocities();
    integrator.addComputeSum("ke_new", "0.5*m*v*v")
    integrator.addComputeGlobal("heat", "heat + (ke_new - ke_old)")

    #
    # Symplectic (velocity Verlet) step. Despite the section name, no Metropolis test is applied;
    # the total energy change is recorded as pseudowork instead.
    #
    integrator.addUpdateContextState();
    integrator.addComputeSum("ke", "0.5*m*v*v")
    integrator.addComputeGlobal("Eold", "ke + energy")
    integrator.addComputePerDof("v", "v + 0.5*dt*f/m")
    integrator.addComputePerDof("x", "x + v*dt")
    integrator.addComputePerDof("x1", "x")
    integrator.addConstrainPositions();
    # Second half-kick plus the velocity correction implied by the position constraint displacement.
    integrator.addComputePerDof("v", "v + 0.5*dt*f/m + (x-x1)/dt")
    integrator.addConstrainVelocities();
    integrator.addComputeSum("ke", "0.5*m*v*v")
    integrator.addComputeGlobal("Enew", "ke + energy")

    #
    # Accumulate statistics.
    #
    integrator.addComputeGlobal("pseudowork", "pseudowork + (Enew-Eold)") # accumulate pseudowork
    integrator.addComputeGlobal("naccept", "naccept + 1")
    integrator.addComputeGlobal("ntrials", "ntrials + 1")

    #
    # Velocity randomization (second half); its kinetic energy change is added to 'heat'.
    #
    integrator.addComputeSum("ke_old", "0.5*m*v*v")
    integrator.addComputePerDof("v", "sqrt(b)*v + sqrt(1-b)*sigma*gaussian")
    integrator.addConstrainVelocities();
    integrator.addComputeSum("ke_new", "0.5*m*v*v")
    integrator.addComputeGlobal("heat", "heat + (ke_new - ke_old)")

    return integrator
    
#=============================================================================================
# UTILITY SUBROUTINES
#=============================================================================================

def create_waterbox(box_edge=2.3*units.nanometers, cutoff=0.9*units.nanometers, constraints=True):
    """
    Create a cubic TIP3P water box test system.

    OPTIONAL ARGUMENTS

    box_edge (simtk.unit.Quantity with units compatible with nanometers) - edge length for cubic box [should be greater than 2*cutoff] (default: 2.3 nm)
    cutoff  (simtk.unit.Quantity with units compatible with nanometers) - nonbonded cutoff (default: 0.9 * units.nanometers)
    constraints (bool) - if True, waters are rigid and bonds to hydrogen are constrained;
       if False, waters are fully flexible with no constraints (default: True)

    RETURNS

    system (simtk.openmm.System) - the water box system, with periodic cutoff nonbonded interactions
       and no center-of-mass motion remover
    positions (simtk.unit.Quantity of nparticles x 3 with units compatible with nanometers) - the particle positions

    """
    import simtk.openmm.app as app

    # Load forcefield for solvent model.
    ff = app.ForceField('tip3p.xml')

    # Start from an empty topology/position set and fill the requested box with solvent.
    modeller = app.Modeller(app.Topology(), units.Quantity((), units.angstroms))
    edge_value = box_edge / box_edge.unit # plain number, expressed in box_edge's own unit
    box_size = units.Quantity(numpy.ones([3]) * edge_value, box_edge.unit)
    modeller.addSolvent(ff, boxSize=box_size)

    # Get new topology and coordinates.
    newtop = modeller.getTopology()
    newpos = modeller.getPositions()

    # Convert positions to a numpy (nparticles x 3) array carrying the original units.
    positions = units.Quantity(numpy.array(newpos / newpos.unit), newpos.unit)

    # Choose the constraint model. None (no extra constraints) is the documented
    # "unconstrained" value for ForceField.createSystem.
    if constraints:
        constraint_options = dict(constraints=app.HBonds, rigidWater=True)
    else:
        constraint_options = dict(constraints=None, rigidWater=False)

    # Create OpenMM System.
    system = ff.createSystem(newtop, nonbondedMethod=app.CutoffPeriodic, nonbondedCutoff=cutoff,
                             removeCMMotion=False, **constraint_options)

    return [system, positions]

def generateMaxwellBoltzmannVelocities(system, temperature):
    """Generate Maxwell-Boltzmann velocities.

    ARGUMENTS

    system (simtk.openmm.System) - the system for which velocities are to be assigned
    temperature (simtk.unit.Quantity of temperature) - the temperature at which velocities are to be assigned

    RETURNS

    velocities (simtk.unit.Quantity of numpy Nx3 float32 array, units nm/ps) - particle velocities

    NOTES

    Each Cartesian component of atom i is drawn from a normal distribution with standard deviation
    sqrt(kT/m_i). Particles with zero mass (which OpenMM treats as fixed) are assigned zero velocity
    instead of triggering a division by zero.

    """
    natoms = system.getNumParticles()
    velocity_unit = units.nanometer / units.picosecond

    # Thermal energy, using the module-level molar Boltzmann constant.
    kT = kB * temperature

    # Per-atom standard deviation of each velocity component, as plain numbers in nm/ps.
    sigma = numpy.zeros([natoms], numpy.float64)
    for atom_index in range(natoms):
        mass = system.getParticleMass(atom_index) # atomic mass
        if mass / units.dalton > 0.0:
            sigma[atom_index] = units.sqrt(kT / mass) / velocity_unit

    # Draw all 3*natoms standard normals at once (row-major: atom, then component) and scale per atom.
    values = sigma[:, numpy.newaxis] * numpy.random.normal(size=[natoms, 3])

    return units.Quantity(values.astype(numpy.float32), velocity_unit)

def statisticalInefficiency(A_n, B_n=None, fast=False, mintime=3):
    """
    Compute the (cross) statistical inefficiency of (two) timeseries.

    REQUIRED ARGUMENTS
      A_n (numpy array) - A_n[n] is nth value of timeseries A.  Length is deduced from vector.

    OPTIONAL ARGUMENTS
      B_n (numpy array) - B_n[n] is nth value of timeseries B.  Length is deduced from vector.
         If supplied, the cross-correlation of timeseries A and B will be estimated instead of the
         autocorrelation of timeseries A.
      fast (boolean) - if True, will use faster (but less accurate) method to estimate correlation
         time, described in Ref. [1] (default: False)
      mintime (int) - minimum amount of correlation function to compute (default: 3)
         The algorithm terminates after computing the correlation time out to mintime when the
         correlation function first goes negative.  Note that this time may need to be increased
         if there is a strong initial negative peak in the correlation function.

    RETURNS
      g is the estimated statistical inefficiency (equal to 1 + 2 tau, where tau is the correlation time).
         We enforce g >= 1.0.

    RAISES
      ValueError - if A_n and B_n have different shapes, or if the sample covariance is zero
         (e.g. a constant timeseries), in which case g cannot be computed.

    NOTES
      The same timeseries can be used for both A_n and B_n to get the autocorrelation statistical inefficiency.
      When fast=True, the lag increment grows by one after each evaluation, as described in Ref [1].

    REFERENCES
      [1] J. D. Chodera, W. C. Swope, J. W. Pitera, C. Seok, and K. A. Dill. Use of the weighted
      histogram analysis method for the analysis of simulated and parallel tempering simulations.
      JCTC 3(1):26-41, 2007.

    EXAMPLES

    A timeseries that stays correlated over short lags:

    >>> round(statisticalInefficiency([1, 1, 1, 1, -1, -1, -1, -1]), 6)
    2.5

    """

    # Create numpy copies of input arguments (B defaults to a copy of A: autocorrelation).
    A_n = numpy.array(A_n)
    B_n = numpy.array(A_n) if B_n is None else numpy.array(B_n)

    # Be sure A_n and B_n have the same dimensions.
    if A_n.shape != B_n.shape:
        raise ValueError('A_n and B_n must have same dimensions.')

    # Get the length of the timeseries.
    N = A_n.size

    # Fluctuations from the mean, in double precision.
    dA_n = A_n.astype(numpy.float64) - A_n.mean()
    dB_n = B_n.astype(numpy.float64) - B_n.mean()

    # Covariance estimator normalized so that C(0) = 1.
    sigma2_AB = (dA_n * dB_n).mean()

    # Trap the case where this covariance is zero, and we cannot proceed.
    if sigma2_AB == 0:
        raise ValueError('Sample covariance sigma_AB^2 = 0 -- cannot compute statistical inefficiency')

    # Accumulate the integrated correlation time by computing the normalized correlation function at
    # increasing lags t.  Stop once it goes non-positive past 'mintime', since by then it is
    # most likely dominated by noise and indistinguishable from zero.
    g = 1.0
    t = 1
    increment = 1
    while t < N - 1:
        # Symmetrized, normalized fluctuation correlation function at lag t (vectorized sum).
        C = numpy.sum(dA_n[0:(N-t)] * dB_n[t:N] + dB_n[0:(N-t)] * dA_n[t:N]) / (2.0 * float(N - t) * sigma2_AB)

        if (C <= 0.0) and (t > mintime):
            break

        # Each evaluated lag stands in for 'increment' lags when fast mode skips ahead.
        g += 2.0 * C * (1.0 - float(t) / float(N)) * float(increment)

        t += increment
        if fast:
            increment += 1

    # g must be at least unity.
    return max(g, 1.0)

def initialize_netcdf(ncfile, system, timesteps_to_try, nsteps_to_try):
    """
    Initialize a freshly created NetCDF file for storage of GHMC samples and VVVR switching statistics.

    ARGUMENTS

    ncfile (netCDF4.Dataset) - dataset opened for writing; dimensions and variables are created in it
    system (simtk.openmm.System) - the simulated system; its particle count sets the 'atoms' dimension,
       and its serialized form is stored in the 'system' variable
    timesteps_to_try (simtk.unit.Quantity list compatible with femtoseconds) - VVVR timesteps, stored in 'timesteps' (fs)
    nsteps_to_try (list of int) - VVVR step counts, stored in 'nsteps'

    NOTES

    Per-sample variables are indexed along the unlimited 'samples' dimension. The file is synced to
    disk before returning.

    """

    # Create dimensions.
    ncfile.createDimension('samples', 0) # unlimited number of samples
    ncfile.createDimension('atoms', system.getNumParticles()) # number of particles in system
    ncfile.createDimension('dimensions', 3) # number of spatial dimensions
    ncfile.createDimension('timesteps', len(timesteps_to_try)) # number of timesteps to try
    ncfile.createDimension('steps', len(nsteps_to_try)) # number of step counts to try
    ncfile.createDimension('single', 1)

    # Create variables.
    ncvar_positions = ncfile.createVariable('positions', 'f', ('samples', 'atoms', 'dimensions'))
    ncvar_velocities = ncfile.createVariable('velocities', 'f', ('samples', 'atoms', 'dimensions'))
    ncvar_box_vectors = ncfile.createVariable('box_vectors', 'f', ('samples', 'dimensions', 'dimensions'))
    ncvar_volumes = ncfile.createVariable('volumes', 'f', ('samples',))
    ncvar_timesteps = ncfile.createVariable('timesteps', 'f', ('timesteps',))
    ncvar_nsteps = ncfile.createVariable('nsteps', 'f', ('steps',))
    ncvar_pseudowork = ncfile.createVariable('pseudowork', 'f', ('samples', 'timesteps', 'steps'))
    ncvar_heat = ncfile.createVariable('heat', 'f', ('samples', 'timesteps', 'steps'))
    ncvar_initial_energy = ncfile.createVariable('initial_energy', 'f', ('samples',))
    ncvar_final_energy = ncfile.createVariable('final_energy', 'f', ('samples', 'timesteps', 'steps'))

    # Serialize OpenMM System object.
    ncvar_system = ncfile.createVariable('system', str, ('single',), zlib=True)
    ncvar_system[0] = system.__getstate__()

    # Define units for variables.
    setattr(ncvar_positions, 'units', 'nm')
    setattr(ncvar_velocities, 'units', 'nm/ps')
    setattr(ncvar_box_vectors, 'units', 'nm')
    setattr(ncvar_volumes, 'units', 'nm**3')
    setattr(ncvar_timesteps, 'units', 'fs')
    setattr(ncvar_pseudowork, 'units', 'kT')
    setattr(ncvar_heat, 'units', 'kT')
    setattr(ncvar_initial_energy, 'units', 'kT')
    setattr(ncvar_final_energy, 'units', 'kT')

    # Store the protocol grid (timesteps in fs, step counts).
    for (timestep_index, timestep) in enumerate(timesteps_to_try):
        ncvar_timesteps[timestep_index] = timestep / units.femtoseconds

    for (nsteps_index, nsteps) in enumerate(nsteps_to_try):
        ncvar_nsteps[nsteps_index] = nsteps

    # Define long (human-readable) names for variables.
    setattr(ncvar_positions, "long_name", "positions[sample,particle,dimension] is position of coordinate 'dimension' of particle 'particle' for sample 'sample'.")
    setattr(ncvar_velocities, "long_name", "velocities[sample,particle,dimension] is velocity of coordinate 'dimension' of particle 'particle' for sample 'sample'.")
    setattr(ncvar_box_vectors, "long_name", "box_vectors[sample,i,j] is dimension j of box vector i for sample 'sample'.")
    setattr(ncvar_volumes, "long_name", "volume[sample] is the box volume for sample 'sample'.")
    setattr(ncvar_timesteps, "long_name", "timesteps[timestep_index] is the VVVR timestep used for dataset 'timestep_index'.")
    setattr(ncvar_nsteps, "long_name", "nsteps[nsteps_index] is the number of VVVR steps simulated for dataset 'nsteps_index'.")
    setattr(ncvar_pseudowork, "long_name", "pseudowork[sample,timestep_index,nsteps_index] is the accumulated pseudowork for sample 'sample' after 'nsteps[nsteps_index]' steps of VVVR using timestep 'timesteps[timestep_index]'.")
    setattr(ncvar_heat, "long_name", "heat[sample,timestep_index,nsteps_index] is the accumulated heat for sample 'sample' after 'nsteps[nsteps_index]' steps of VVVR using timestep 'timesteps[timestep_index]'.")
    setattr(ncvar_initial_energy, "long_name", "initial_energy[sample] is the initial total energy for sample 'sample' before VVVR simulation.")
    setattr(ncvar_final_energy, "long_name", "final_energy[sample,timestep_index,nsteps_index] is the final total energy for sample 'sample' after 'nsteps[nsteps_index]' steps of VVVR using timestep 'timesteps[timestep_index]'.")

    # Create timestamp variable (written once each sample completes).
    ncfile.createVariable('timestamp', str, ('samples',))

    # Force sync to disk to avoid data loss.
    ncfile.sync()

#=============================================================================================
# PARAMETERS
#=============================================================================================

# NOTE(review): boxsizes_to_try is not referenced below; the simulated box sizes come from boxedges_to_try.
boxsizes_to_try = [5.0 * units.angstroms, 10.0 * units.angstroms, 15.0 * units.angstroms] # waterbox sizes to try

cutoff = 0.9 * units.nanometers # nonbonded cutoff
# The smallest box edge is 2.35 cutoffs, satisfying create_waterbox's requirement that the edge exceed 2*cutoff.
# Entry n scales that edge by n**(1/3), i.e. a box with n times the smallest volume (n = 1..6).
min_box_edge = 2.35 * cutoff
boxedges_to_try = [ min_box_edge*(n**(1.0/3.0)) for n in range(1,7) ] # indexed by MPI rank below

#nparticles_to_try = [ 2**n for n in range(8, 14) ]
#timesteps_to_try = units.Quantity([0.25, 0.5, 0.75, 1.0, 1.25, 1.5, 1.75, 2.0, 2.5, 3.0, 3.5, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0], units.femtoseconds)
timesteps_to_try = units.Quantity([0.5, 1.0, 1.5, 2.0, 2.5, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0], units.femtoseconds) # VVVR timesteps
nsteps_to_try = [ 2**n for n in range(13) ] # cumulative VVVR step counts at which statistics are recorded: 1, 2, 4, ..., 4096

#print "numbers of particles to try: %s" % str(nparticles_to_try)
#print "timesteps to try: %s" % str(timesteps_to_try)
#print "number of steps to try: %s" % str(nsteps_to_try)

temperature = 298.0 * units.kelvin
pressure = 1.0 * units.atmosphere # pressure for equilibration
gamma = 91.0 / units.picosecond # collision rate (used by both GHMC and VVVR)
ghmc_nsteps = 10000 # number of steps to generate new uncorrelated sample with GHMC
ghmc_timestep = 0.50 * units.femtoseconds
nsamples = 2000 # number of samples to generate
nequil = 10 # number of NPT equilibration iterations, each of ghmc_nsteps GHMC steps

constraints = True # if True, will constrain waters
#constraints = False # uncomment to simulate flexible (unconstrained) waters
verbose = True

kT = kB * temperature # thermal energy
beta = 1.0 / kT # inverse temperature

#=============================================================================================
# Initialize MPI.
#=============================================================================================

try:
    from mpi4py import MPI # MPI wrapper
except ImportError:
    # Fall back to serial execution only when mpi4py is missing; other errors
    # (including KeyboardInterrupt) are no longer silently swallowed.
    print("mpi4py not available---using serial execution.")
    rank = 0
else:
    rank = MPI.COMM_WORLD.rank
    print("Node %d / %d" % (MPI.COMM_WORLD.rank, MPI.COMM_WORLD.size))

# Select the compute platform; each MPI rank uses the OpenCL/CUDA device whose index equals its rank.
platform_name = 'OpenCL'
platform = openmm.Platform.getPlatformByName(platform_name)
deviceid = rank
platform.setPropertyDefaultValue('OpenCLDeviceIndex', '%d' % deviceid) # select OpenCL device index
platform.setPropertyDefaultValue('CudaDeviceIndex', '%d' % deviceid) # select Cuda device index

#=============================================================================================
# Create system to simulate.
#=============================================================================================

# Each MPI rank simulates one box size; fail with a clear message if there are more ranks than sizes.
if rank >= len(boxedges_to_try):
    raise ValueError("Node %d: only %d box sizes are defined in boxedges_to_try; run with at most %d MPI processes." % (rank, len(boxedges_to_try), len(boxedges_to_try)))
box_edge = boxedges_to_try[rank] # use MPI rank to select box size
[system, positions] = create_waterbox(box_edge=box_edge, cutoff=cutoff, constraints=constraints)
velocities = generateMaxwellBoltzmannVelocities(system, temperature)
ndof = 3*system.getNumParticles() - system.getNumConstraints() # degrees of freedom after constraints
nparticles = system.getNumParticles()
nwaters = nparticles // 3 # TIP3P: three particles per water (integer division on Python 2 and 3)

print("Node %d: Box has edge length %.3f nm and %d particles (%d waters)" % (rank, box_edge / units.nanometers, nparticles, nwaters))

# Output file name encodes the water count and whether waters are constrained.
water_model_variant = 'constrained' if constraints else 'flexible'
store_filename = 'data/TIP3P-%d-%s.nc' % (nwaters, water_model_variant)

# Turn off output from all but head node.
if rank != 0: verbose = False

#=============================================================================================
# Open NetCDF file for writing.
#=============================================================================================

import os
import os.path
resume = False

# Make sure the output directory exists. Several MPI ranks may race to create it,
# so only re-raise if it still does not exist afterwards.
store_directory = os.path.dirname(store_filename)
if store_directory and not os.path.isdir(store_directory):
    try:
        os.makedirs(store_directory)
    except OSError:
        if not os.path.isdir(store_directory):
            raise

if os.path.exists(store_filename):
    print("Node %d: Attempting to resume from file '%s'..." % (rank, store_filename))
    ncfile = netcdf.Dataset(store_filename, 'a')
    resume = True
else:
    print("Node %d: Opening '%s' for writing..." % (rank, store_filename))
    ncfile = netcdf.Dataset(store_filename, 'w', version='NETCDF4')
    initialize_netcdf(ncfile, system, timesteps_to_try, nsteps_to_try)

#=============================================================================================
# Sample accumulated work as a function of timestep and nsteps for switching from GHMC to VVVR
#=============================================================================================

# Equilibrate with Monte Carlo barostat.
# The barostat is added to a deep copy, so 'system' itself stays barostat-free; the VVVR contexts
# created later in the sampling loop are built from 'system'.
import copy
system_with_barostat = copy.deepcopy(system)
barostat = openmm.MonteCarloBarostat(pressure, temperature)
system_with_barostat.addForce(barostat)
ghmc_integrator = GHMCIntegrator(ghmc_timestep, temperature=temperature, gamma=gamma)
# Map integrator global variable names to their indices for get/setGlobalVariable.
ghmc_global_variables = { ghmc_integrator.getGlobalVariableName(index) : index for index in range(ghmc_integrator.getNumGlobalVariables()) }
ghmc_context = openmm.Context(system_with_barostat, ghmc_integrator, platform)

# Set positions and velocities.
ghmc_context.setPositions(positions)
ghmc_context.setVelocities(velocities)

first_sample = 0 # sample index to start from
if resume:
    # Get positions and velocities from NetCDF file.
    sample = ncfile.variables['positions'].shape[0] - 1 # back up one sample, in case it was incomplete
    # NOTE(review): if the file holds only one sample (sample == 0) or none (sample == -1), nothing is
    # restored; the run re-equilibrates and restarts at sample 0.
    if (sample > 0):
        print "Resuming from NetCDF file at sample %d" % sample
        positions = units.Quantity(ncfile.variables['positions'][sample,:,:], units.nanometers)
        velocities = units.Quantity(ncfile.variables['velocities'][sample,:,:], units.nanometers / units.picoseconds)
        box_vectors = units.Quantity(ncfile.variables['box_vectors'][sample,:,:], units.nanometers)
        nequil = 0 # don't equilibrate
        first_sample = sample # start at 'sample'

        # Update positions and velocities.
        ghmc_context.setPositions(positions)
        ghmc_context.setVelocities(velocities)
        ghmc_context.setPeriodicBoxVectors(box_vectors[0,:], box_vectors[1,:], box_vectors[2,:])

# Compute initial volume (product of the box-vector diagonal, so this assumes a rectangular box).
state = ghmc_context.getState()
box_vectors = state.getPeriodicBoxVectors(asNumpy=True)
volume = box_vectors[0,0] * box_vectors[1,1] * box_vectors[2,2]
if verbose: print "initial volume %8.3f nm^3" % (volume / units.nanometers**3)

# Equilibrate system with NPT: nequil rounds of ghmc_nsteps GHMC steps with the barostat active.
volume_history = numpy.zeros([nequil], numpy.float64) # volume (nm^3) after each round
for iteration in range(nequil):
    # Reset acceptance counters so the printed statistics cover this round only.
    ghmc_integrator.setGlobalVariable(ghmc_global_variables['naccept'], 0)
    ghmc_integrator.setGlobalVariable(ghmc_global_variables['ntrials'], 0)

    # Generate new sample from equilibrium distribution with GHMC.
    ghmc_integrator.step(ghmc_nsteps)

    # Compute volume.
    state = ghmc_context.getState(getEnergy=True)
    box_vectors = state.getPeriodicBoxVectors(asNumpy=True)
    potential = state.getPotentialEnergy()
    volume = box_vectors[0,0] * box_vectors[1,1] * box_vectors[2,2]
    volume_history[iteration] = volume / units.nanometers**3
    max_radius = box_vectors[0,0] / 2.0 # half the box width

    naccept = ghmc_integrator.getGlobalVariable(ghmc_global_variables['naccept'])
    ntrials = ghmc_integrator.getGlobalVariable(ghmc_global_variables['ntrials'])
    fraction_accepted = float(naccept) / float(ntrials)
    if verbose: print "GHMC equil %5d / %5d | accepted %6d / %6d (%6.3f %%) | volume %8.3f nm^3 | max radius %8.3f nm | potential %8.3f kcal/mol" % (iteration, nequil, naccept, ntrials, fraction_accepted*100.0, volume / units.nanometers**3, max_radius / units.nanometers, potential / units.kilocalories_per_mole)

# The GHMC global-variable index map (ghmc_global_variables) was built when the integrator was created.

try:
    import cPickle as pickle # Python 2
except ImportError:
    import pickle
import os.path

# Each MPI rank runs a different box size, so each writes its own pickle next to its NetCDF file;
# a single shared 'data.pkl' would be overwritten concurrently by every rank.
pickle_filename = os.path.splitext(store_filename)[0] + '.pkl'

# Allocate storage for data.
# data[(timestep/units.femtoseconds, nsteps)][sample] is pseudowork (kT) for sample 'sample' after 'nsteps' of dynamics for timestep 'timestep'
data = dict()
for (timestep_index, timestep) in enumerate(timesteps_to_try):
    for (nsteps_index, nsteps) in enumerate(nsteps_to_try):
        data[(timestep/units.femtoseconds, nsteps)] = numpy.zeros([nsamples], numpy.float64)
        if first_sample > 0:
            # When resuming, recover pseudowork of already-completed samples so the pickle is not zero-filled.
            data[(timestep/units.femtoseconds, nsteps)][0:first_sample] = numpy.asarray(ncfile.variables['pseudowork'][0:first_sample, timestep_index, nsteps_index])

if verbose: print("Generating uncorrelated samples with GHMC...")
for sample in range(first_sample, nsamples):
    ghmc_integrator.setGlobalVariable(ghmc_global_variables['naccept'], 0)
    ghmc_integrator.setGlobalVariable(ghmc_global_variables['ntrials'], 0)

    # Generate new sample from equilibrium distribution with GHMC.
    ghmc_integrator.step(ghmc_nsteps)

    naccept = ghmc_integrator.getGlobalVariable(ghmc_global_variables['naccept'])
    ntrials = ghmc_integrator.getGlobalVariable(ghmc_global_variables['ntrials'])
    fraction_accepted = float(naccept) / float(ntrials)
    if verbose: print("GHMC sample %5d / %5d | accepted %6d / %6d (%6.3f %%)" % (sample, nsamples, naccept, ntrials, fraction_accepted*100.0))

    # Extract coordinates, velocities, box vectors, and energies from a single state query.
    state = ghmc_context.getState(getPositions=True, getVelocities=True, getEnergy=True)
    positions = state.getPositions(asNumpy=True)
    velocities = state.getVelocities(asNumpy=True)
    box_vectors = state.getPeriodicBoxVectors(asNumpy=True)
    volume = box_vectors[0,0] * box_vectors[1,1] * box_vectors[2,2]
    # Initial total energy in kT, kept in double precision (the NetCDF copy is float32).
    initial_energy = (state.getPotentialEnergy() + state.getKineticEnergy()) / kT

    ncfile.variables['positions'][sample,:,:] = positions[:,:] / units.nanometers
    ncfile.variables['velocities'][sample,:,:] = velocities[:,:] / (units.nanometers/units.picoseconds)
    ncfile.variables['box_vectors'][sample,:,:] = box_vectors[:,:] / units.nanometers
    ncfile.variables['volumes'][sample] = volume / (units.nanometers**3)
    ncfile.variables['initial_energy'][sample] = initial_energy

    initial_time = time.time()

    # Switch to VVVR using different timesteps.
    for (timestep_index, timestep) in enumerate(timesteps_to_try):
        initial_timestep_time = time.time()

        # Initialize VVVR integrator and context (no barostat) from the GHMC sample.
        integrator = VVVRIntegrator(timestep, temperature=temperature, gamma=gamma)
        context = openmm.Context(system, integrator, platform)
        context.setPositions(positions)
        context.setVelocities(velocities)
        context.setPeriodicBoxVectors(box_vectors[0,:], box_vectors[1,:], box_vectors[2,:])

        # Make a list of global variables.
        global_variables = { integrator.getGlobalVariableName(index) : index for index in range(integrator.getNumGlobalVariables()) }

        steps_taken = 0
        for (nsteps_index, nsteps) in enumerate(nsteps_to_try):
            # nsteps_to_try is cumulative: only simulate the steps not yet taken.
            integrator.step(nsteps - steps_taken)
            steps_taken = nsteps

            # Compute final energy.
            state = context.getState(getEnergy=True)
            total_energy = (state.getPotentialEnergy() + state.getKineticEnergy()) / kT # in kT
            if numpy.isnan(total_energy): total_energy = numpy.inf
            ncfile.variables['final_energy'][sample,timestep_index,nsteps_index] = total_energy

            # Change in total energy since the GHMC sample.
            delta_energy = total_energy - initial_energy

            # Store heat statistics.
            heat = integrator.getGlobalVariable(global_variables['heat']) * units.kilojoules_per_mole / kT
            ncfile.variables['heat'][sample,timestep_index,nsteps_index] = heat

            # Store pseudowork statistics.
            pseudowork = integrator.getGlobalVariable(global_variables['pseudowork']) * units.kilojoules_per_mole / kT
            if numpy.isnan(pseudowork): pseudowork = numpy.inf
            data[(timestep/units.femtoseconds, nsteps)][sample] = pseudowork
            ncfile.variables['pseudowork'][sample,timestep_index,nsteps_index] = pseudowork

            if verbose: print("timestep = %8.3f fs | nsteps = %8d | pseudowork = %.3f kT | heat = %.3f kT | DeltaE = %.3f" % (timestep / units.femtoseconds, nsteps, pseudowork, heat, delta_energy))

        # Clean up.
        del context, integrator

        final_timestep_time = time.time()
        elapsed_time = final_timestep_time - initial_timestep_time
        if verbose: print("  %12.3f s" % elapsed_time)

    # Store data (last sample index, then the pseudowork dictionary).
    with open(pickle_filename, 'wb') as outfile:
        pickle.dump(sample, outfile)
        pickle.dump(data, outfile)

    # Write data to NetCDF file.
    ncfile.variables['timestamp'][sample] = time.ctime()
    ncfile.sync()

    final_time = time.time()
    elapsed_time = final_time - initial_time
    if verbose: print("  %12.3f s elapsed this iteration" % elapsed_time)

# Close NetCDF file.
ncfile.close()

