#!/usr/local/bin/env python

#=============================================================================================
# MODULE DOCSTRING
#=============================================================================================

"""
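Parallel tempering of trajectories for a Kob-Andersen system.

DESCRIPTION

Builds the Kob-Andersen two-component Lennard-Jones mixture as an OpenMM System, and provides
a constant-temperature trajectory sampler (Verlet dynamics with Maxwell-Boltzmann velocity
redraws at the start of each trajectory segment) together with a parallel tempering
(replica-exchange) sampler over a list of such trajectory ensembles.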

REFERENCES

[1] Hedges LO, Jack RL, Garrahan JP, and Chandler D. Dynamic order-disorder in atomic models
of structural glass-formers. Science 323:1309, 2009.

[2] Minh DDL and Chodera JD. Optimal estimators and asymptotic variances for nonequilibrium
path-ensemble averages. JCP 131:134110, 2009.

COPYRIGHT

@author John D. Chodera <jchodera@gmail.com>

This source file is released under the GNU General Public License.

This program is free software: you can redistribute it and/or modify it under
the terms of the GNU General Public License as published by the Free Software
Foundation, either version 3 of the License, or (at your option) any later
version.

This program is distributed in the hope that it will be useful, but WITHOUT ANY
WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A
PARTICULAR PURPOSE.  See the GNU General Public License for more details.
 
You should have received a copy of the GNU General Public License along with
this program.  If not, see <http://www.gnu.org/licenses/>.

"""

#=============================================================================================
# GLOBAL IMPORTS
#=============================================================================================

import os
import os.path
import sys
import numpy
import math
import copy
import time
import gc

import simtk
import simtk.chem.openmm as openmm
import simtk.unit as units

import Scientific.IO.NetCDF as netcdf

#=============================================================================================
# SOURCE CONTROL
#=============================================================================================

# Source-control revision identifier.
__version__ = "$Id: KobAndersen.py 524 2010-01-05 07:47:29Z jchodera $"

#=============================================================================================
# GLOBAL CONSTANTS
#=============================================================================================

# Molar Boltzmann constant (Boltzmann constant times Avogadro's number), so that kB * T is an
# energy per mole, consistent with the per-mole energy scale 'epsilon' used throughout this file.
kB = units.BOLTZMANN_CONSTANT_kB * units.AVOGADRO_CONSTANT_NA

#=============================================================================================
# Kob-Andersen two-component mixture of Lennard-Jones particles.
#=============================================================================================

def KobAndersen(N=150, NA=None, A_fraction=0.8, principal_component_density=0.96, mm=None, mass=None, epsilon=None, sigma=None, softcore=False, alpha=0.5, lambda_=1.0):
    """
    Create a test system containing a Kob-Andersen two-component mixture.

    A soft-core Lennard-Jones potential is used if 'softcore' is set to True.

    OPTIONAL ARGUMENTS

    N (int) - total number of atoms (default: 150)
    NA (int) - number of A atoms; if None, computed as floor(A_fraction * N) (default: None)
    A_fraction (float) - fraction of A component, used only when NA is None (default: 0.8)
    principal_component_density (float) - NA sigma^3 / V (default: 0.96)
    mm (module) - OpenMM module used to construct the System and force (default: simtk.chem.openmm)
    mass (simtk.unit) - mass assigned to every particle (default: 39.948 amu)
    epsilon (simtk.unit) - fundamental energy scale (default: 119.8 K * kB * NA, argon)
    sigma (simtk.unit) - fundamental length scale (default: 0.3405 nm, argon)
    softcore (bool) - soft-core Lennard Jones (Eq. 4 of Shirts and Pande, JCP 122:134508, 2005) is used if True (default: False)
    alpha (float) - soft-core parameter (default: 0.5)
    lambda_ (float) - alchemical parameter, where 1.0 is fully interacting, 0.0 is non-interacting (default: 1.0)

    RETURNS

    system (System) - system of N particles; the first NA are type A, the remaining N - NA are type B
    coordinates (simtk.unit.Quantity wrapping an N x 3 numpy float32 array) - initial coordinates
        filling the cubic periodic box

    RAISES

    ParameterException - if the number of A atoms is not between 0 and N

    NOTES

    Pair parameters relative to epsilon and sigma are epsilon_AA = 1.0, epsilon_AB = 1.5,
    epsilon_BB = 0.5 and sigma_AA = 1.0, sigma_AB = 0.8, sigma_BB = 0.88.
    The standard potential is truncated at 2.5 sigma_ij for each pair; the soft-core potential
    has no per-pair truncation and is only cut off at the global 2.5 sigma cutoff distance.

    EXAMPLES

    Create a Kob-Andersen two-component mixture.

    >>> epsilon     = 119.8 * units.kelvin * units.BOLTZMANN_CONSTANT_kB * units.AVOGADRO_CONSTANT_NA # arbitrary reference energy
    >>> [system, coordinates] = KobAndersen(epsilon=epsilon)

    Create softcore Kob-Andersen two-component mixture with alchemical perturbation.

    >>> [system, coordinates] = KobAndersen(epsilon=epsilon, softcore=True, lambda_=0.0)

    Test the energy

    >>> # Create a Context.
    >>> kB = units.BOLTZMANN_CONSTANT_kB
    >>> NA = units.AVOGADRO_CONSTANT_NA
    >>> temperature = 0.6 * epsilon / kB / NA
    >>> collision_rate = 90.0 / units.picosecond
    >>> timestep = 1.0 * units.femtosecond
    >>> integrator = openmm.LangevinIntegrator(temperature, collision_rate, timestep)
    >>> platform = openmm.Platform.getPlatformByName("OpenCL")
    >>> context = openmm.Context(system, integrator, platform)
    >>> # Set positions
    >>> context.setPositions(coordinates)
    >>> # Evaluate the potential energy.
    >>> state = context.getState(getEnergy=True)
    >>> reduced_potential = (state.getPotentialEnergy() / epsilon)
    >>> print reduced_potential

    Integrate dynamics

    >>> nsteps = 1000 # number of steps to integrate
    >>> integrator.step(nsteps)
    >>> # Retrieve configuration to make sure no coordinates are nan
    >>> state = context.getState(getPositions=True)
    >>> coordinates = state.getPositions(asNumpy=True)
    >>> if numpy.any(numpy.isnan(coordinates / units.nanometers)): raise Exception('some coordinates are nan after integration: %s' % str(coordinates))

    """

    # Choose OpenMM package.
    if mm is None:
        mm = simtk.chem.openmm

    # Set unit system based on Rowley, Nicholson, and Parsonage argon parameters.
    if mass    is None:  mass        = 39.948 * units.amu # arbitrary reference mass
    if epsilon is None:  epsilon     = 119.8 * units.kelvin * units.BOLTZMANN_CONSTANT_kB * units.AVOGADRO_CONSTANT_NA # arbitrary reference energy
    if sigma   is None:  sigma       = 0.3405 * units.nanometers # arbitrary reference lengthscale

    # Determine number of atoms of each component.
    if NA is None:
        NA = int(math.floor(A_fraction * N))
    if not (0 <= NA <= N):
        # A negative NB would silently create fewer particles than coordinates.
        raise ParameterException("number of A atoms (%s) must be between 0 and N (%s)" % (str(NA), str(N)))
    NB = N - NA

    # Create system
    system = mm.System()

    # Compute total system volume from the density of the principal (A) component.
    volume = NA * sigma**3 / principal_component_density

    # Make system cubic in dimension.
    length = volume**(1./3.)
    # TODO: Can we change this to use tuples or 3x3 array?
    a = units.Quantity(numpy.array([1.0, 0.0, 0.0], numpy.float32), units.nanometer) * length/units.nanometer
    b = units.Quantity(numpy.array([0.0, 1.0, 0.0], numpy.float32), units.nanometer) * length/units.nanometer
    c = units.Quantity(numpy.array([0.0, 0.0, 1.0], numpy.float32), units.nanometer) * length/units.nanometer
    system.setPeriodicBoxVectors(a, b, c)

    # Add particles to system: A particles first, then B particles (all share the same mass).
    for n in range(NA):
        system.addParticle(mass)
    for n in range(NB):
        system.addParticle(mass)

    # Create nonbonded force term implementing Kob-Andersen two-component Lennard-Jones interaction.
    energy_expression = ""
    if not softcore:
        # Standard Lennard-Jones, truncated at 2.5 sigma_ij for each pair.
        energy_expression += '4.0*epsilon*((sigma/r)^12 - (sigma/r)^6) * step(2.5*sigma - r);'
    else:
        # Soft-core Lennard-Jones from Eq. 4 of Shirts and Pande, JCP 122:134508, 2005.
        # U_LJ(r) = lambda 4 epsilon_ij ( [alpha (1-lambda) + (r/sigma_ij)^6]^-2 - [alpha (1-lambda) + (r/sigma_ij)^6]^-1 )
        # NOTE: unlike the standard form, no per-pair step() truncation is applied here.
        energy_expression += '4.0*epsilon*lambda*inv*(inv - 1.0);'
        energy_expression += 'inv = (alpha*(1.0-lambda) + (r/sigma)^6)^(-1);'

    # Mixing rules for the two types. The coefficients encode the Kob-Andersen pair parameters
    # (epsilon_AA, epsilon_AB, epsilon_BB) = (1.0, 1.5, 0.5) epsilon and
    # (sigma_AA, sigma_AB, sigma_BB) = (1.0, 0.8, 0.88) sigma.
    # A1/A2 and B1/B2 are the per-particle 'A' and 'B' flags of the two interacting particles.
    energy_expression += "epsilon = epsilon0*(1.0*AA + 1.5*AB + 0.5*BB);"
    energy_expression += "sigma = sigma0*(1.0*AA + 0.8*AB + 0.88*BB);"
    energy_expression += "AB = 1.0 - AA - BB;"
    energy_expression += "AA = A1*A2;"
    energy_expression += "BB = B1*B2;"

    force = mm.CustomNonbondedForce(energy_expression)

    # Add alchemical global parameters.
    if softcore:
        force.addGlobalParameter('alpha', alpha)
        force.addGlobalParameter('lambda', lambda_)

    # Set epsilon0 and sigma0 global parameters.
    force.addGlobalParameter('epsilon0', epsilon)
    force.addGlobalParameter('sigma0', sigma)

    # Add per-particle parameters to indicate whether each particle is type A or B.
    force.addPerParticleParameter('A')
    force.addPerParticleParameter('B')

    # Add A and B particle identifications, in the same order the particles were added.
    for n in range(NA):
        force.addParticle((1.0, 0.0))
    for n in range(NB):
        force.addParticle((0.0, 1.0))

    # Set periodic boundary conditions with a cutoff of 2.5 sigma_AA, the largest
    # per-pair truncation radius (sigma_AA = sigma is the largest sigma_ij).
    force.setNonbondedMethod(mm.CustomNonbondedForce.CutoffPeriodic)
    force.setCutoffDistance(2.5 * sigma)

    # Add nonbonded force term to the system.
    system.addForce(force)

    # Create initial coordinates using a Sobol' subrandom sequence in three dimensions,
    # scaled from the unit cube to the simulation box.
    coordinates = numpy.zeros([N,3], numpy.float32)
    from QuasiRandom import SobolSequence
    sobol = SobolSequence(3)
    for n in range(N):
        coordinates[n,:] = sobol()
    coordinates = units.Quantity(coordinates, units.nanometer) * (length / units.nanometer)

    # Return system and coordinates.
    return (system, coordinates)


#=============================================================================================
# Exceptions
#=============================================================================================

class NotImplementedException(Exception):
    """Raised when a requested feature is not available yet.

    Carries no extra state; behaves exactly like a plain Exception.
    """

class ParameterException(Exception):
    """Raised when an argument was supplied with an invalid value.

    Carries no extra state; behaves exactly like a plain Exception.
    """


#=============================================================================================
# UTILITY FUNCTIONS
#=============================================================================================

def write_pdb(filename, trajectory, atoms):
    """Write a trajectory out as a multi-model PDB file, one MODEL per frame.

    ARGUMENTS
       filename (string) - name of PDB file to be written
       trajectory (Trajectory) - sequence of snapshots; each snapshot's 'coordinates' must be
          indexable as coordinates[atom_index, k] and carry length units
       atoms (list of dict) - ATOM record fields (e.g. as built by construct_atom_list());
          the 'x', 'y', and 'z' entries of each dict are overwritten - WILL BE CHANGED

    NOTES
       The output file is closed even if writing fails partway through.
    """
    # Fixed-column PDB ATOM record; x, y, z are pre-formatted strings in Angstroms.
    atom_record = 'ATOM  %(serial)5s %(atom)4s%(altLoc)c%(resName)3s %(chainID)c%(Seqno)5s   %(x)8s%(y)8s%(z)8s\n'

    with open(filename, 'w') as outfile:
        # Write trajectory as models (PDB model serial numbers start at 1).
        for (frame_index, snapshot) in enumerate(trajectory):
            outfile.write("MODEL     %4d\n" % (frame_index+1))
            coordinates = snapshot.coordinates
            # Write ATOM records.
            for (index, atom) in enumerate(atoms):
                atom["x"] = "%8.3f" % (coordinates[index,0] / units.angstroms)
                atom["y"] = "%8.3f" % (coordinates[index,1] / units.angstroms)
                atom["z"] = "%8.3f" % (coordinates[index,2] / units.angstroms)
                outfile.write(atom_record % atom)
            outfile.write("ENDMDL\n")

    return

def construct_atom_list(N, NA):
    """Build PDB ATOM record fields for a two-component mixture.

    The first NA atoms are labeled argon ('Ar', component A) and the remaining N - NA atoms
    helium ('He', component B), matching the particle ordering used by KobAndersen().

    ARGUMENTS
       N (int) - total number of atoms
       NA (int) - number of A atoms

    RETURNS
       atoms (list of dict) - one new dict per atom with keys 'serial', 'atom', 'altLoc',
          'resName', 'chainID', and 'Seqno' (serials start at 1), suitable for write_pdb()
    """

    def _atom_record(serial, element):
        # Build the fields of one ATOM record; 'element' is a two-letter label used for
        # both the atom name and the residue name.
        return {
            'serial': serial,
            'atom': ' %s ' % element,
            'altLoc': ' ',
            'resName': '%s ' % element,
            'chainID': ' ',
            'Seqno': '%5d' % serial,
        }

    NB = N - NA
    elements = ['Ar'] * NA + ['He'] * NB
    return [_atom_record(serial, element) for (serial, element) in enumerate(elements, 1)]

def testKobAndersen():
    """
    Test the energy of the Kob-Andersen system.

    Builds a 150-particle Kob-Andersen mixture (80% A), minimizes it with the local 'optimize'
    module's LBFGSMinimizer, evaluates the potential energy and forces with OpenMM on the
    "OpenCL" platform, then recomputes both directly in Python with a pairwise truncated
    Lennard-Jones sum under the minimum-image convention.

    Nothing is asserted: the reduced potentials and a per-particle side-by-side listing of the
    OpenMM forces and the directly computed forces are printed for visual comparison.

    """
    # Set unit system based on Rowley, Nicholson, and Parsonage argon parameters.
    mass        = 39.948 * units.amu # arbitrary reference mass
    epsilon     = 119.8 * units.kelvin * units.BOLTZMANN_CONSTANT_kB * units.AVOGADRO_CONSTANT_NA # arbitrary reference energy
    sigma       = 0.3405 * units.nanometers # arbitrary reference lengthscale
    print "mass = %s\nepsilon = %s\nsigma = %s\n" % (str(mass), str(epsilon), str(sigma))

    # Set mixture properties.
    N = 150
    A_fraction = 0.8
    NA = int(math.floor(A_fraction * N))
    print "N = %d, NA = %d" % (N, NA)

    # Compute total system volume (same formula as KobAndersen(), so 'length' matches its box).
    principal_component_density = 0.96
    volume = NA * sigma**3 / principal_component_density
    print "volume = %s" % str(volume)

    # Make system cubic in dimension.
    length = volume**(1./3.)
    print "length = %s" % str(length)

    # Set the temperature to 0.7 in reduced units (kB T = 0.7 epsilon).
    # This local kB has the same value as the module-level kB (molar Boltzmann constant).
    kB = units.BOLTZMANN_CONSTANT_kB * units.AVOGADRO_CONSTANT_NA
    temperature = 0.7 * epsilon / kB
    print "temperature = %s" % str(temperature)

    # Choose a platform.
    platform = openmm.Platform.getPlatformByName("OpenCL")

    # Create a system.
    [system, coordinates] = KobAndersen(mass=mass, epsilon=epsilon, sigma=sigma, N=N, A_fraction=A_fraction, principal_component_density=principal_component_density)

    # Optimize the coordinates so that the comparison is done at a low-energy configuration.
    import optimize
    print "Minimizing..."
    minimizer = optimize.LBFGSMinimizer(system, verbose=True, platform=platform)
    coordinates = minimizer.minimize(coordinates)
    del minimizer

    # Compute the potential energy via OpenMM.
    # Create a Context.
    collision_rate = 90.0 / units.picosecond
    timestep = 1.0 * units.femtosecond
    integrator = openmm.LangevinIntegrator(temperature, collision_rate, timestep)
    context = openmm.Context(system, integrator, platform)
    # Set positions
    context.setPositions(coordinates)
    # Evaluate the potential energy and forces.
    state = context.getState(getEnergy=True, getForces=True)
    openmm_reduced_potential = (state.getPotentialEnergy() / epsilon)
    print "OpenMM reduced potential: %s" % str(openmm_reduced_potential)
    openmm_force = state.getForces(asNumpy=True)

    # Compute the potential energy directly.

    # Set LJ mixture parameters (must agree with the mixing rules in KobAndersen()).
    epsilon_AA  = 1.0 * epsilon
    epsilon_AB  = 1.5 * epsilon
    epsilon_BB  = 0.5 * epsilon

    sigma_AA    = 1.0 * sigma
    sigma_AB    = 0.8 * sigma
    sigma_BB    = 0.88 * sigma

    # Per-pair parameter matrices; particles 0..NA-1 are type A, NA..N-1 are type B.
    epsilon_ij = units.Quantity(numpy.zeros([N,N]), units.kilojoules_per_mole)
    epsilon_ij[0:NA,0:NA] = epsilon_AA
    epsilon_ij[NA:N,0:NA] = epsilon_AB
    epsilon_ij[0:NA,NA:N] = epsilon_AB
    epsilon_ij[NA:N,NA:N] = epsilon_BB

    sigma_ij = units.Quantity(numpy.zeros([N,N]), units.nanometers)
    sigma_ij[0:NA,0:NA] = sigma_AA
    sigma_ij[NA:N,0:NA] = sigma_AB
    sigma_ij[0:NA,NA:N] = sigma_AB
    sigma_ij[NA:N,NA:N] = sigma_BB

    # Per-pair truncation radius, mirroring step(2.5*sigma - r) in the standard energy expression.
    cutoff_ij = 2.5 * sigma_ij

    U = 0.0 * units.kilojoules_per_mole
    gradient = units.Quantity(numpy.zeros([N,3]), units.kilojoules_per_mole / units.nanometer)
    for i in range(N):
        for j in range(i+1,N):
            # Compute minimum-image distance between particles i and j.
            delta = coordinates[j,:] - coordinates[i,:]
            for k in range(3):
                while (delta[k] >= length / 2.0): delta[k] -= length
                while (delta[k] < -length / 2.0): delta[k] += length
            r = units.sqrt(delta[0]**2 + delta[1]**2 + delta[2]**2)
            # Unit vector pointing from particle i to particle j.
            nij = delta[:] / r
            # Accumulate contribution to potential energy if within cutoff.
            if (r < cutoff_ij[i,j]):
                U += 4.0 * epsilon_ij[i,j] * ( (sigma_ij[i,j]/r)**12 - (sigma_ij[i,j]/r)**6 )
                # With r = |x[j] - x[i]|:
                #   dU/dx[j,k] = dU/dr * dr/dx[j,k] = dU/dr * (x[j,k] - x[i,k]) / r = dU/dr * nij[k]
                #   dU/dx[i,k] = - dU/dr * nij[k]
                dU_dr = 4.0 * epsilon_ij[i,j] * ( - 12*(sigma_ij[i,j]/r)**12  + 6*(sigma_ij[i,j]/r)**6 ) / r
                for k in range(3):
                    gradient[i,k] += - float(nij[k]) * dU_dr
                    gradient[j,k] += + float(nij[k]) * dU_dr
    force = - gradient

    direct_reduced_potential = (U / epsilon)
    print "direct reduced potential: %s" % str(direct_reduced_potential)

    # Compare forces: each line shows the OpenMM force, then the directly computed force,
    # both in kJ/mol/nm.
    for i in range(N):
        print "%5d : " % i,
        for k in range(3):
            print "%12.3f" % (openmm_force[i,k] / (units.kilojoules_per_mole / units.nanometer)),
        print " : ",
        for k in range(3):
            print "%12.3f" % (force[i,k] / (units.kilojoules_per_mole / units.nanometer)),
        print ""

    return

#=============================================================================================
# SIMULATION SNAPSHOT 
#=============================================================================================

class Snapshot(object):
    """
    Simulation snapshot.

    Holds the coordinates, velocities, box vectors, and potential/kinetic energies of a single
    instant of a simulation, either captured from an OpenMM Context or supplied directly.

    """
    def __init__(self, context=None, coordinates=None, velocities=None, box_vectors=None, potential_energy=None, kinetic_energy=None):
        """
        Initialize a simulation snapshot.

        ARGUMENTS

        context (Context) - if not None, state is read from this OpenMM Context and all other
           arguments are ignored (default: None)
        coordinates, velocities, box_vectors, potential_energy, kinetic_energy - stored as deep
           copies when context is None (each defaults to None)

        RAISES

        Exception - if any coordinate is nan

        """

        if context is not None:
            # Get current state from OpenMM Context object.
            # TODO: Also get box vectors when these are available through interface extensions.
            state = context.getState(getPositions=True, getVelocities=True, getEnergy=True)

            # Populate data structures
            self.coordinates = state.getPositions(asNumpy=True)
            self.velocities = state.getVelocities(asNumpy=True)
            self.box_vectors = None
            self.potential_energy = state.getPotentialEnergy()
            self.kinetic_energy = state.getKineticEnergy()
        else:
            # Deep copies so later changes to the caller's arrays do not alter this snapshot.
            self.coordinates = copy.deepcopy(coordinates)
            self.velocities = copy.deepcopy(velocities)
            self.box_vectors = copy.deepcopy(box_vectors)
            self.potential_energy = copy.deepcopy(potential_energy)
            self.kinetic_energy = copy.deepcopy(kinetic_energy)

        # Check for nans in coordinates (numpy.isnan cannot be applied to None, so an
        # empty snapshot skips the check).
        if self.coordinates is not None and numpy.any(numpy.isnan(self.coordinates)):
            raise Exception("Some coordinates became 'nan'; simulation is unstable or buggy.")

        return

    @property
    def total_energy(self):
        """Total energy of the snapshot: kinetic plus potential energy."""
        return self.kinetic_energy + self.potential_energy

#=============================================================================================
# SIMULATION TRAJECTORY
#=============================================================================================

class Trajectory(list):
    """
    Simulation trajectory: a list of snapshots.

    Snapshots are stored in the list itself, so len(), indexing, iteration, append(), and
    reverse() all operate on the trajectory contents. The legacy 'trajectory' attribute is
    kept as an alias for the trajectory itself.

    """
    def __init__(self, trajectory=None):
        """
        Initialize a simulation trajectory.

        ARGUMENTS

        trajectory (optional) - an iterable of snapshots, each of which is deep-copied, or a
           single non-iterable snapshot, which is stored as-is (default: None, empty trajectory)

        """
        super(Trajectory, self).__init__()

        if trajectory is None:
            return

        try:
            snapshots = iter(trajectory)
        except TypeError:
            # We were provided with a single (non-iterable) snapshot.
            self.append(trajectory)
        else:
            # Copy whatever container we were provided.
            # TODO: Check that each snapshot is a Snapshot object (or supports the same interface)?
            self.extend(copy.deepcopy(snapshot) for snapshot in snapshots)

        return

    @property
    def trajectory(self):
        """Backward-compatible alias: the trajectory is itself the list of snapshots."""
        return self

    @trajectory.setter
    def trajectory(self, snapshots):
        # Replace the contents in place so the object identity is preserved.
        self[:] = list(snapshots)

    def reverse(self):
        """
        Reverse the trajectory in place.

        After reversal, the kinetic energy of every snapshot except the last is recomputed as
        the total energy of the following snapshot minus its own potential energy. This makes
        use of the fact that the energy is (approximately) conserved over each trajectory
        segment, in between velocity randomizations. Snapshot objects are modified in place.

        """
        # Reverse the stored snapshots (list.reverse, not this override).
        super(Trajectory, self).reverse()

        # Recalculate kinetic energies for the *beginning* of each trajectory segment.
        # Iterating forward means the 'following' snapshot has not been modified yet.
        for (current, following) in zip(self[:-1], self[1:]):
            current.kinetic_energy = following.total_energy - current.potential_energy

        return
    
#=============================================================================================
# DYNAMICS FOR KOB-ANDERSEN
#=============================================================================================

class TrajectorySampler(object):
    """
    Constant-temperature trajectory sampling using Verlet dynamics with Andersen thermostat.

    A trajectory consists of segments of nsteps_per_frame Verlet steps; velocities are redrawn
    from the Maxwell-Boltzmann distribution at the start of every segment.

    """

    def __init__(self, system, N, NA, mass, epsilon, sigma, timestep, nsteps_per_frame, nframes, temperature):
        """
        Initialize a constant-temperature trajectory sampler.

        ARGUMENTS

        system - the system
        N - the number of particles
        NA - the number of particles of type A
        mass - mass scale
        epsilon - per-particle energy scale
        sigma - length scale
        timestep (units.Quantity) - timestep to use
        nsteps_per_frame - number of steps per frame
        nframes - number of frames (trajectory segments) per trajectory; used by sampleTrajectory()
        temperature (units.Quantity) - temperature of this sampler; kT is computed as kB * temperature

        """

        # Store local copy of System.
        self.system = system

        self.mass = mass
        self.epsilon = epsilon
        self.sigma = sigma

        self.timestep = timestep
        self.nsteps_per_frame = nsteps_per_frame
        self.delta_t = timestep * nsteps_per_frame

        # Number of segments per generated trajectory.
        self.nframes = nframes

        self.temperature = temperature

        # Compute thermal energy and inverse temperature from specified temperature.
        self.kT = kB * self.temperature # thermal energy
        self.beta = 1.0 / self.kT # inverse temperature

        # Store number of A particles.
        self.N = N
        self.NA = NA

        # Create a Verlet integrator.
        self.integrator = openmm.VerletIntegrator(self.timestep)

        # Create a Context for integration.
        self.platform = openmm.Platform.getPlatformByName("OpenCL")
        self.context = openmm.Context(self.system, self.integrator, self.platform)

        # Store reduced units
        self.t_obs = nframes * self.delta_t
        self.s_reduced_unit = 1.0 /  (self.sigma**2 * self.delta_t)
        self.K_reduced_unit = (self.N * self.t_obs * self.sigma**2)

        return

    def assignMaxwellBoltzmannVelocities(self, remove_com_velocity=False):
        """Generate Maxwell-Boltzmann velocities at this sampler's temperature.

        ARGUMENTS

        remove_com_velocity (bool) - if True, subtract the mass-weighted center-of-mass velocity
           from every particle (default: False)

        RETURNS

        velocities (units.Quantity wrapping a natoms x 3 numpy float32 array, nm/ps) - each
           Cartesian component of particle i is drawn from a normal distribution with
           standard deviation sqrt(kT / m_i)

        TODO

        This could be sped up by introducing vector operations.

        """

        # Get number of atoms
        natoms = self.system.getNumParticles()

        # Create storage for velocities.
        velocities = units.Quantity(numpy.zeros([natoms, 3], numpy.float32), units.nanometer / units.picosecond) # velocities[i,k] is the kth component of the velocity of atom i

        # Particle masses in amu, needed for the center-of-mass correction.
        masses = numpy.zeros([natoms], numpy.float64)

        # Assign velocities from the Maxwell-Boltzmann distribution.
        for atom_index in range(natoms):
            mass = self.system.getParticleMass(atom_index) # atomic mass
            masses[atom_index] = mass / units.amu
            sigma = units.sqrt(self.kT / mass) # standard deviation of velocity distribution for each coordinate for this atom
            for k in range(3):
                velocities[atom_index,k] = sigma * numpy.random.normal()

        if remove_com_velocity:
            # Remove center of mass velocity: sum_i m_i v_i / sum_i m_i.
            total_mass = masses.sum()
            if total_mass > 0.0:
                unitless_velocities = velocities / (units.nanometers / units.picoseconds)
                com = (masses[:, numpy.newaxis] * unitless_velocities).sum(0) / total_mass
                com_velocity = units.Quantity(com, units.nanometers/units.picoseconds)
                for atom_index in range(natoms):
                    velocities[atom_index,:] -= com_velocity

        # Return velocities
        return velocities

    def generateTrajectory(self, x0, nframes):
        """
        Generate a trajectory of nframes segments of nsteps_per_frame Verlet steps each.

        Velocities are redrawn from the Maxwell-Boltzmann distribution (with center-of-mass
        motion removed) at the start of every segment.

        ARGUMENTS

        x0 (coordinate set) - initial coordinates
        nframes (int) - number of trajectory segments to generate

        RETURNS

        trajectory (Trajectory) - nframes + 1 snapshots: the initial state of each segment
           (after velocity randomization), followed by the final state of the last segment

        """

        # Set initial positions
        self.context.setPositions(x0)

        # Store initial state for each trajectory segment in trajectory.
        trajectory = Trajectory()

        # Generate trajectory segments.
        for frame_index in range(nframes):
            # Assign velocities from Maxwell-Boltzmann distribution
            velocities = self.assignMaxwellBoltzmannVelocities(remove_com_velocity=True)
            self.context.setVelocities(velocities)

            # Store initial snapshot of trajectory segment.
            snapshot = Snapshot(context=self.context)
            trajectory.append(snapshot)

            # Propagate dynamics.
            self.integrator.step(self.nsteps_per_frame)

        # Store final snapshot of trajectory (its velocities are not re-randomized).
        snapshot = Snapshot(self.context)
        trajectory.append(snapshot)

        return trajectory

    def logEquilibriumTrajectoryProbability(self, trajectory):
        """
        Compute the log equilibrium probability (up to an unknown additive constant) of an unbiased trajectory evolved according to Verlet dynamics with Andersen thermostatting.

        The initial snapshot contributes its full Boltzmann weight -beta * (total energy); each
        intermediate snapshot contributes -beta * (kinetic energy) for its randomized
        velocities. The final snapshot's velocities come from deterministic dynamics (see
        generateTrajectory), so it contributes nothing.

        ARGUMENTS

        trajectory (Trajectory) - the trajectory

        RETURNS

        log_q (float) - the log equilibrium probability of the trajectory

        """

        nsnapshots = len(trajectory)
        log_q = - self.beta * trajectory[0].total_energy
        for snapshot_index in range(1, nsnapshots-1):
            log_q += - self.beta * trajectory[snapshot_index].kinetic_energy

        return log_q

    def computeActivity(self, trajectory):
        """
        Compute the activity of a given trajectory, defined in Ref. [1] as

        K[x(t)] = delta_t sum_{t=0}^{t_obs} sum_{j=1}^N [r_j(t+delta_t) - r_j(t)]^2

        NOTE: only the first NA particles (component A) are included in the sum over j here.
        Displacements are taken directly from the stored coordinates, without any
        periodic-image correction.

        RETURNS

        K (simtk.unit) - activity K[x(t)] for the specified trajectory

        """

        # Determine number of frames in trajectory.
        nframes = len(trajectory)

        # Compute activity of component A.
        K = 0.0 * self.delta_t * units.nanometers**2
        for frame_index in range(nframes-1):
            # Compute displacement of all atoms.
            delta_r = trajectory[frame_index+1].coordinates - trajectory[frame_index].coordinates
            # Compute contribution to activity K.
            K += self.delta_t * ((delta_r[0:self.NA,:] / units.nanometers)**2).sum() * (units.nanometers**2)

        return K

    def sampleTrajectory(self, trajectory):
        """
        Generate a new trajectory sample starting from the last frame of provided trajectory.

        ARGUMENTS

        trajectory - a previous sample from the state

        RETURN VALUES

        trajectory - new sampled trajectory of self.nframes segments (correlated with previous
           trajectory sample)

        """

        # TODO: Run a bit in between to decorrelate trajectories?

        # Generate a new trajectory with the number of segments this sampler was configured for.
        new_trajectory = self.generateTrajectory(trajectory[-1].coordinates, self.nframes)

        return new_trajectory

#=============================================================================================
# PARALLEL TEMPERING TRAJECTORY SAMPLER
#=============================================================================================

class ParallelTemperingTrajectorySampler(object):
    """
    Parallel tempering trajectory sampler.

    """

    def __init__(self, system, ensembles, trajectories, ncfilename):
        """
        Initialize replica-exchange transition path sampling.

        ARGUMENTS
          system (simtk.chem.openmm.State) - system object
          ensemble (list of TransitionPathSampling objects)
          trajectories - trajectory or list of trajectories to initialize TPS with
          ncfilename - name of NetCDF file to create or resume from

        TODO
          * Implement a more general method where we have generateTrajectory(x0) report P(x0,x1,...,xT) and divide by P(x0) as appropriate.

        """
        # Store state information.
        self.system = system
        self.natoms = system.getNumParticles()

        # Store TransitionPathSampling objects
        self.ensembles = ensembles
        self.nstates = len(ensembles)

        # Distribute trajectories (and segment energies).
        self.trajectories = list()
        try:
            # Test if trajectories is an indexable set of Trajectory objects.
            shape = trajectories[0][0].coordinates.shape # size of first snapshot of first trajectory
            self.trajectories = [ copy.deepcopy(trajectories[index % len(trajectories)]) for index in range(self.nstates) ]
        except:
            try:
                # Test if trajectories is a single Trajectory object.
                shape = trajectories[0].coordinates.shape
                self.trajectories = [ copy.deepcopy(trajectories) for index in range(self.nstates) ]
            except:
                raise Exception("Don't know what to do with 'trajectories' object; doesn't seem to be a Trajectory object or list of Trajectories.")

        # Store number of frames per trajectory.
        self.nframes = len(self.trajectories[0])

        # Set default options.
        self.number_of_iterations = 50
        self.title = 'Replica-exchange simulation'
        self.store_filename = ncfilename
        self.verbose = True
        
        # Flag as uninitialized.
        self._initialized = False

        return

    def _initialize(self):
        """
        Initialize the simulation, and bind to a storage file.

        """
        if self._initialized:
            print "Replica-exchange simulation has already been initialized."
            raise Error

        print "Initializing..."

        # Allocate storage.
        self.replica_states     = numpy.zeros([self.nstates], numpy.int32) # replica_states[i] is the state that replica i is currently at
        for state_index in range(self.nstates):
            self.replica_states[state_index] = state_index
        
        self.log_P_kl           = numpy.zeros([self.nstates,self.nstates], numpy.float32) # log_P_kl[k,l] is log bias probability of replica k in state l
        self.swap_Pij_accepted  = numpy.zeros([self.nstates, self.nstates], numpy.float32) # swap_Pij_accepted[i,j] is fraction of swap attempts between states i and j that were accepted during last mixing phase
        self.log_q_k            = numpy.zeros([self.nstates], numpy.float64) # log_q_k[k] is the log equilibrium probability of trajectory k at zero field s, up to an unknown additive constant
        self.activities         = [ None for replica_index in range(self.nstates) ] # activities[i] is activity of replica i

        # Check if netcdf file extists.
        if os.path.exists(self.store_filename) and (os.path.getsize(self.store_filename) > 0):
            # Resume from NetCDF file.
            self._resume_from_netcdf()
        else:
            # Initialize current iteration counter.
            self.iteration = 0
            
            # Compute energies of all alchemical replicas
            self._compute_activities()
            
            # Initialize NetCDF file.
            self._initialize_netcdf()

            # Store initial state.
            self._write_iteration_netcdf()
  
        # Signal that the class has been initialized.
        self._initialized = True

        return

    def _finalize(self):
        """
        Do anything necessary to clean up.

        """

        self.ncfile.close()

        return

    def run(self):
        """
        Run the replica-exchange simulation.

        Any parameter changes that were made between object creation and calling this method become locked in at this point,
        and the object will create and bind to the store file.

        """

        # Make sure we've initialized everything and bound to a storage file before we begin execution.
        if not self._initialized:
            self._initialize()

        # Main loop
        while (self.iteration < self.number_of_iterations):
            start_time = time.time()
            if self.verbose: print "\nIteration %d / %d" % (self.iteration+1, self.number_of_iterations)

            # Attempt replica swaps to sample from equilibrium permuation of states associated with replicas.
            self._mix_replicas()

            # Propagate replicas.
            self._propagate_replicas()

            # Compute energies of all replicas at all states.
            self._compute_activities()

            # Write to storage file.
            self._write_iteration_netcdf()
            
            # Increment iteration counter.
            self.iteration += 1
            end_time = time.time()
            if self.verbose: print "Iteration took %.3f s" % (end_time - start_time)

        # Clean up and close storage files.
        self._finalize()

        return

    def _propagate_replicas(self):
        """
        Propagate all replicas.

        TODO

        * Parallel implementation

        """

        # Propagate all replicas.
        for replica_index in range(self.nstates):
            state_index = self.replica_states[replica_index]
            trajectory = self.trajectories[replica_index]
            trajectory = self.ensembles[state_index].sampleTrajectory(trajectory)
            self.trajectories[replica_index] = trajectory

        return

    def _compute_activities(self):
        """
        Compute activities of all replicas.
        
        """

        print "Computing activities..."

        print "Computing trajectory probabilities..."

        for replica_index in range(self.nstates):
            trajectory = self.trajectories[replica_index]
            self.log_q_k[replica_index] = self.ensembles[replica_index].logEquilibriumTrajectoryProbability(trajectory)

        # Compute activities for all replicas.
        for replica_index in range(self.nstates):
            trajectory = self.trajectories[replica_index]
            self.activities[replica_index] = self.ensembles[replica_index].computeActivity(trajectory)
            #print "replica %5d, x0 = %42s, K = %32s" % (replica_index, str(trajectory[0].coordinates[0,:]), str(self.activities[replica_index] / self.ensembles[replica_index].K_reduced_unit))

        # Compute log biasing probabilities for all replicas in all states.
        for replica_index in range(self.nstates):
            K = self.activities[replica_index]
            for state_index in range(self.nstates):
                s = self.ensembles[state_index].s
                self.log_P_kl[replica_index,state_index] = self.log_q_k[replica_index] - s * K

        if self.verbose:
            print "states = "
            for replica_index in range(self.nstates):
                print "%6d" % self.replica_states[replica_index],
            print ""
            print "activities = "
            for replica_index in range(self.nstates):
                print "%6.3f" % (self.activities[replica_index] / self.ensembles[replica_index].K_reduced_unit),
            print ""
                
        return

    def _mix_replicas(self):
        """
        Attempt exchanges between all replicas to enhance mixing.

        NOTES

        This function might be slow in pure Python, so it may be necessary to re-code this in a compiled language.
        We certainly don't want this function to take a substantial fraction of the iteration time.
        
        """

        start_time = time.time()

        print "self.nstates = %d" % self.nstates

        # Determine number of swaps to attempt to ensure thorough mixing.
        # TODO: Replace this with analytical result computed to guarantee sufficient mixing.
        nswap_attempts = self.nstates**5 # number of swaps to attempt
        
        # Allocate storage to keep track of mixing.
        Nij_proposed = numpy.zeros([self.nstates,self.nstates], numpy.float32) # Nij_proposed[i][j] is the number of swaps proposed between states i and j, prior of 1
        Nij_accepted = numpy.zeros([self.nstates,self.nstates], numpy.float32) # Nij_proposed[i][j] is the number of swaps proposed between states i and j, prior of 1

        # Show log P
        if self.verbose:
            print "log_P[replica,state] ="
            print "%6s" % "",
            for jstate in range(self.nstates):
                print "%6d" % jstate,
            print ""
            for ireplica in range(self.nstates):
                print "%-6d" % ireplica,
                for jstate in range(self.nstates):
                    log_P = self.log_P_kl[ireplica,jstate]
                    print "%8.3f" % log_P,
                print ""

        # Attempt swaps to mix replicas.
        nswaps_accepted = 0
        for swap_attempt in range(nswap_attempts):
            # Choose replicas to attempt to swap.
            i = numpy.random.randint(self.nstates) # Choose replica i uniformly from set of replicas.
            j = numpy.random.randint(self.nstates) # Choose replica j uniformly from set of replicas.

            # Determine which states these resplicas correspond to.
            istate = self.replica_states[i] # state in replica slot i
            jstate = self.replica_states[j] # state in replica slot j

            # Compute log probability of swap.
            log_P_accept = (self.log_P_kl[i,jstate] + self.log_P_kl[j,istate]) - (self.log_P_kl[i,istate] + self.log_P_kl[j,jstate])

            #print "replica (%3d,%3d) states (%3d,%3d) energies (%8.1f,%8.1f) %8.1f -> (%8.1f,%8.1f) %8.1f : log_P_accept %8.1f" % (i,j,istate,jstate,self.u_kl[i,istate],self.u_kl[j,jstate],self.u_kl[i,istate]+self.u_kl[j,jstate],self.u_kl[i,jstate],self.u_kl[j,istate],self.u_kl[i,jstate]+self.u_kl[j,istate],log_P_accept)
            
            # Record that this move has been proposed.
            Nij_proposed[istate,jstate] += 0.5
            Nij_proposed[jstate,istate] += 0.5

            # Accept or reject.
            if (log_P_accept >= 0.0 or (numpy.random.rand() < math.exp(log_P_accept))):
                # Swap states in replica slots i and j.
                (self.replica_states[i], self.replica_states[j]) = (self.replica_states[j], self.replica_states[i])
                # Accumulate statistics
                Nij_accepted[istate,jstate] += 0.5
                Nij_accepted[jstate,istate] += 0.5
                nswaps_accepted += 1

        # Report statistics of acceptance.
        swap_fraction_accepted = float(nswaps_accepted) / float(nswap_attempts);
  
        # Estimate transition probabilities between all states.
        swap_Pij_accepted = numpy.zeros([self.nstates,self.nstates], numpy.float32)
        for istate in range(self.nstates):
            for jstate in range(self.nstates):
                if (Nij_proposed[istate,jstate] > 0.0):
                    swap_Pij_accepted[istate,jstate] = Nij_accepted[istate,jstate] / Nij_proposed[istate,jstate]
                else:
                    swap_Pij_accepted[istate,jstate] = 0.0
        self.swap_Pij_accepted = swap_Pij_accepted

        end_time = time.time()

        # Report on mixing.
        # TODO: Add this behind a verbose flag.
        if self.verbose:
            PRINT_CUTOFF = 0.001 # Cutoff for displaying fraction of accepted swaps.
            print "Fraction of accepted swaps between states:"
            print "%6s" % "",
            for jstate in range(self.nstates):
                print "%6d" % jstate,
            print ""
            for istate in range(self.nstates):
                print "%-6d" % istate,
                for jstate in range(self.nstates):
                    P = self.swap_Pij_accepted[istate,jstate]
                    if (P >= PRINT_CUTOFF):
                        print "%6.3f" % P,
                    else:
                        print "%6s" % "",
                print ""
            print "Mixing of replicas took %.3f s" % (end_time - start_time)

        return


    def _initialize_netcdf(self):
        """
        Initialize NetCDF file for storage.
        
        """    

        # Open NetCDF file for writing
        ncfile = netcdf.NetCDFFile(self.store_filename, 'w')

        # Create dimensions.
        ncfile.createDimension('iteration', 0) # unlimited number of iterations
        ncfile.createDimension('replica', self.nstates) # number of replicas
        ncfile.createDimension('frame', self.nframes) # number of frames per trajectory
        ncfile.createDimension('atom', self.natoms) # number of atoms in system
        ncfile.createDimension('spatial', 3) # number of spatial dimensions

        # Set global attributes.
        setattr(ncfile, 'tile', self.title)
        setattr(ncfile, 'application', 'KobAndersen')
        setattr(ncfile, 'program', 'KobAndersen.py')
        setattr(ncfile, 'programVersion', __version__)
        setattr(ncfile, 'Conventions', 'ReplicaExchangeTPS')
        setattr(ncfile, 'ConventionVersion', '0.1')

        ensemble = self.ensembles[0]
        setattr(ncfile, 'sKfactor', ensemble.N * ensemble.t_obs / ensemble.delta_t)
        
        # Create variables.
        ncvar_fields      = ncfile.createVariable('fields', 'f', ('replica',))

        ncvar_trajectory_coordinates = ncfile.createVariable('trajectory_coordinates', 'f', ('replica','frame','atom','spatial'))
        ncvar_trajectory_velocities  = ncfile.createVariable('trajectory_velocities',  'f', ('replica','frame','atom','spatial'))
        ncvar_trajectory_potential   = ncfile.createVariable('trajectory_potential',   'f', ('replica','frame'))
        ncvar_trajectory_kinetic     = ncfile.createVariable('trajectory_kinetic',     'f', ('replica','frame'))

        ncvar_states      = ncfile.createVariable('states', 'i', ('iteration','replica'))
        ncvar_activities  = ncfile.createVariable('activities', 'f', ('iteration','replica'))
        ncvar_log_probabilities = ncfile.createVariable('log_probabilities', 'f', ('iteration','replica','replica'))
        ncvar_mixing      = ncfile.createVariable('mixing', 'f', ('iteration','replica','replica'))
        
        # Define units for variables.
        setattr(ncvar_trajectory_coordinates, 'units', 'nm')
        setattr(ncvar_trajectory_velocities,  'units', 'nm/ps')
        setattr(ncvar_trajectory_potential,   'units', 'kJ/mol')
        setattr(ncvar_trajectory_kinetic,     'units', 'kJ/mol')
        
        setattr(ncvar_states,    'units', 'none')
        setattr(ncvar_mixing,    'units', 'none')
        # TODO: fields and activities

        # Set display formatting attributes.
        #setattr(ncvar_trajectories, 'C_format', r'%9.4f')
        #setattr(ncvar_activities, 'C_format', r'%5.3f')
        
        # Define long (human-readable) names for variables.
        setattr(ncvar_states,    "long_name", "states[iteration][replica] is the state index (0..nstates-1) of replica 'replica' of iteration 'iteration'.")
        setattr(ncvar_mixing,    "long_name", "mixing[iteration][i][j] is the fraction of proposed transitions between states i and j that were accepted during mixing using the coordinates from iteration 'iteration-1'.")
        # TODO

        # Store fields.
        for state_index in range(self.nstates):
            s = self.ensembles[state_index].s
            s_reduced_unit = self.ensembles[state_index].s_reduced_unit
            ncfile.variables['fields'][state_index] = s / s_reduced_unit

        # Force sync to disk to avoid data loss.
        ncfile.sync()

        # Store netcdf file handle.
        self.ncfile = ncfile
        
        return
    
    def _write_iteration_netcdf(self):
        """
        Write positions, states, and energies of current iteration to NetCDF file.
        
        """

        # DEBUG
        ensemble = self.ensembles[0]
        setattr(self.ncfile, 'sKfactor', ensemble.N * ensemble.t_obs / ensemble.delta_t)        

        # Store trajectories.
        for replica_index in range(self.nstates):
            trajectory = self.trajectories[replica_index]
            for frame_index in range(self.nframes):                
                self.ncfile.variables['trajectory_coordinates'][replica_index,frame_index,:,:] = (trajectory[frame_index].coordinates / units.nanometers).astype(numpy.float32)
                self.ncfile.variables['trajectory_velocities'][replica_index,frame_index,:,:] = (trajectory[frame_index].velocities / (units.nanometers / units.picoseconds)).astype(numpy.float32)
                self.ncfile.variables['trajectory_potential'][replica_index,frame_index] = trajectory[frame_index].potential_energy / units.kilojoules_per_mole                                
                self.ncfile.variables['trajectory_kinetic'][replica_index,frame_index] = trajectory[frame_index].kinetic_energy / units.kilojoules_per_mole

        # Store state information.
        self.ncfile.variables['states'][self.iteration,:] = self.replica_states[:]

        # Store activities.
        for replica_index in range(self.nstates):
            state_index = self.replica_states[replica_index]
            K = self.activities[replica_index]
            K_reduced_unit = self.ensembles[state_index].K_reduced_unit
            self.ncfile.variables['activities'][self.iteration,replica_index] = K / K_reduced_unit

        # Store log probabilities.
        self.ncfile.variables['log_probabilities'][self.iteration,:,:] = self.log_P_kl

        # Store mixing statistics.
        self.ncfile.variables['mixing'][self.iteration,:,:] = self.swap_Pij_accepted[:,:]

        # Force sync to disk to avoid data loss.
        self.ncfile.sync()

        return

    def _resume_from_netcdf(self):
        """
        Resume execution by reading current positions and energies from a NetCDF file.
        
        """

        # Open NetCDF file for reading
        ncfile = netcdf.NetCDFFile(self.store_filename, 'r')
        
        # TODO: Perform sanity check on file before resuming

        # Get current dimensions.
        self.iteration = ncfile.variables['activities'].shape[0] - 1
        self.nstates = ncfile.variables['activities'].shape[1]
        self.nframes = ncfile.variables['trajectory_coordinates'].shape[1]
        self.natoms = ncfile.variables['trajectory_coordinates'].shape[2]

        print "iteration = %d, nstates = %d, natoms = %d" % (self.iteration, self.nstates, self.natoms)

        # Restore trajectories.
        self.trajectories = list()
        for replica_index in range(self.nstates):
            trajectory = Trajectory()
            for frame_index in range(self.nframes):                
                x = ncfile.variables['trajectory_coordinates'][replica_index,frame_index,:,:].astype(numpy.float32).copy()
                coordinates = units.Quantity(x, units.nanometers)                
                v = ncfile.variables['trajectory_velocities'][replica_index,frame_index,:,:].astype(numpy.float32).copy()
                velocities = units.Quantity(v, units.nanometers / units.picoseconds)                
                V = ncfile.variables['trajectory_potential'][replica_index,frame_index]
                potential_energy = units.Quantity(V, units.kilojoules_per_mole)
                T = ncfile.variables['trajectory_kinetic'][replica_index,frame_index]
                kinetic_energy = units.Quantity(T, units.kilojoules_per_mole)
                snapshot = Snapshot(coordinates=coordinates, velocities=velocities, kinetic_energy=kinetic_energy, potential_energy=potential_energy)
                trajectory.append(snapshot)
            self.trajectories.append(trajectory)

        # Restore state information.
        self.replica_states = ncfile.variables['states'][self.iteration,:].copy()

        # Restore log probabilities.
        self.log_P_kl = ncfile.variables['log_probabilities'][self.iteration,:,:] 

        # Restore activities
        for replica_index in range(self.nstates):
            state_index = self.replica_states[replica_index]
            K_reduced_unit = self.ensembles[state_index].K_reduced_unit
            K = ncfile.variables['activities'][self.iteration,replica_index]
            self.activities[replica_index] = K * K_reduced_unit

        # Close NetCDF file.
        ncfile.close()        
        
        # Reopen NetCDF file for appending, and maintain handle.
        self.ncfile = netcdf.NetCDFFile(self.store_filename, 'a')

        # DEBUG: Set number of iterations to be a bit more than we've done.
        self.number_of_iterations = self.iteration + 50
        
        return

#=============================================================================================
# MAIN AND TESTS
#=============================================================================================

def run():
    """
    Run a parallel tempering simulation of the Kob-Andersen system.

    """
    # Get OpenCL platform.
    platform = openmm.Platform.getPlatformByName("OpenCL")    

    # Constant
    kB = units.BOLTZMANN_CONSTANT_kB * units.AVOGADRO_CONSTANT_NA

    # Create a Kob-Andersen two-component mixture.
    N = 150 # number of particles
    #N = 800 # number of particles    
    A_fraction = 0.8 # fraction of A component
    NA = int(math.floor(A_fraction * N)) # number of A component
    mass        = 39.948 * units.amu # arbitrary reference mass        
    epsilon     = 119.8 * units.kelvin * units.BOLTZMANN_CONSTANT_kB * units.AVOGADRO_CONSTANT_NA # arbitrary reference energy    
    sigma       = 0.3405 * units.nanometers # arbitrary reference lengthscale    
    [system, coordinates] = KobAndersen(N=N, NA=NA, mass=mass, epsilon=epsilon, sigma=sigma)    

    mode = 'test'
    # mode = 'production'

    # Relevant times and timescales
    reduced_time = units.sqrt((mass * sigma**2) / (48.0 * epsilon)) # factor on which all reduced times are based
    timestep = 0.035 * reduced_time # velocity Verlet timestep from Ref. [1], Supplementary Section 1.1
    delta_t = (40.0 / 3.0) * reduced_time # number of timesteps per Delta t in Lester's paper, Ref. [1] Supplementary Section 1.1, specified exactly by Lester in private communication

    if mode == 'test':
        quenched_temperature = 0.6 * epsilon / kB # temperature corresponding to left column of Fig. 2 from [1]
        tau = 200.0 * reduced_time # private communication from Lester Hedges (tau = 60 * reduced time for Tred = 0.7, 200 * reduced_time for Tred = 0.6)
        t_obs = 1 * tau # number of intervals of tau, from Fig. 2 of [1] (red dashed line, right column,) # should be 33 * tau for Tred = 0.7, 20*tau for Tred = 0.6
    elif mode == 'production':
        quenched_temperature = 0.7 * epsilon / kB # temperature corresponding to left column of Fig. 2 from [1]
        tau = 60.0 * reduced_time # private communication from Lester Hedges (tau = 60 * reduced time for Tred = 0.7, 200 * reduced_time for Tred = 0.6)
        t_obs = 33 * tau # number of intervals of tau, from Fig. 2 of [1] (red dashed line, right column,) # should be 33 * tau for Tred = 0.7, 20*tau for Tred = 0.6        
    else:
        raise Exception("Unknown mode: %s" % mode)

    print "UNCORRECTED TIMES"
    print "timestep = %s" % str(timestep.in_units_of(units.femtosecond))
    print "delta_t = %s (%f steps)" % (str(delta_t.in_units_of(units.picosecond)), delta_t / timestep)
    print "tau = %s (%f steps)" % (str(tau.in_units_of(units.picosecond)), tau / timestep)
    print "t_obs = %s (%f steps)" % (str(t_obs.in_units_of(units.picosecond)), t_obs / timestep)

    # Determine integral number of steps per velocity randomization and number of trajectories
    nsteps_per_frame = int(round(delta_t / timestep)) # number of steps per trajectory frame and velocity randomization
    nframes = int(round(t_obs / delta_t)) # number of frames per trajectory
    print "number of steps per trajectory frame and velocity randomization = %d" % nsteps_per_frame
    print "number of frames per trajectory = %d" % nframes
    # Correct delta_t and t_obs to be integral
    delta_t = nsteps_per_frame * timestep 
    t_obs = nframes * delta_t
    print "CORRECTED TIMES"
    print "timestep = %s" % str(timestep.in_units_of(units.femtosecond))
    print "delta_t = %s (%f steps)" % (str(delta_t.in_units_of(units.picosecond)), delta_t / timestep)
    print "tau = %s (%f steps)" % (str(tau.in_units_of(units.picosecond)), tau / timestep)
    print "t_obs = %s (%f steps)" % (str(t_obs.in_units_of(units.picosecond)), t_obs / timestep)

    # Decide whether we are initializing or resuming the simulation.
    minimize = True
    equilibrate = True
    quench = True
    seed = True

    # Minimize
    if minimize:
        print "Minimizing with L-BFGS..."
        # Initialize a minimizer with default options.
        import optimize
        minimizer = optimize.LBFGSMinimizer(system, verbose=True, platform=platform)
        # Minimize the initial coordinates.
        coordinates = minimizer.minimize(coordinates)
        # Clean up to release the Context.
        del minimizer

    # Set temperature for simulations.
    elevated_temperature = 2.0 * epsilon / kB

    if equilibrate:
        # Equilibrate at high temperature.
        collision_rate = 1.0 / delta_t
        nsteps = int(math.floor(t_obs / timestep))
        print "Equilibrating at %s for %d steps..." % (str(elevated_temperature), nsteps)
        integrator = openmm.LangevinIntegrator(elevated_temperature, collision_rate, timestep)
        context = openmm.Context(system, integrator, platform)
        context.setPositions(coordinates)
        integrator.step(nsteps)
        state = context.getState(getPositions=True)
        coordinates = state.getPositions(asNumpy=True)

    if quench:
        # Quench to final temperature
        collision_rate = 1.0 / delta_t
        nsteps = int(math.floor(t_obs / timestep))
        print "Quenching to %s for %d steps..." % (str(quenched_temperature), nsteps)
        integrator = openmm.LangevinIntegrator(quenched_temperature, collision_rate, timestep)
        context = openmm.Context(system, integrator, platform)
        context.setPositions(coordinates)
        integrator.step(nsteps)
        state = context.getState(getPositions=True)
        coordinates = state.getPositions(asNumpy=True)           

    # Specify reduced temperatures for replica-exchange.
    temperatures = [0.6, 0.7, 0.8, 0.9, 1.0, 1.5, 2.0] # temperatures are in reduced units

    # Create a number of transition path sampling ensembles at different values of the field parameter s.
    s_reduced_unit = 1.0 / (sigma**2 * delta_t)
    svalues = [0.0, 0.0]    
    ensembles = [ TransitionPathSampling(system, N, NA, mass, epsilon, sigma, timestep, nsteps_per_frame, nframes, quenched_temperature, s * s_reduced_unit) for s in svalues ]

    trajectory = None
    if seed:
        # Generate an initial trajectory at zero field
        print "Generating seed trajectory for TPS..."
        trajectory = ensembles[0].generateTrajectory(coordinates, nframes)

    # Initialize replica-exchange TPS simulation.
    print "Initializing replica-exchange TPS..."
    ncfilename = 'szero.nc'
    #ncfilename = 'szerolarge.nc'    
    simulation = ReplicaExchangeTPS(system, ensembles, trajectory, ncfilename)

    # Run simulation
    print "Running replica-exchange TPS..."
    simulation.run()
        
    print "Done."    
    return

def driver():
    #import doctest
    #doctest.testmod()

    # Get OpenCL platform.
    platform = openmm.Platform.getPlatformByName("OpenCL")    

    # Constant
    kB = units.BOLTZMANN_CONSTANT_kB * units.AVOGADRO_CONSTANT_NA

    # Create a Kob-Andersen two-component mixture.
    N = 150 # number of particles
    A_fraction = 0.8 # fraction of A component
    NA = int(math.floor(A_fraction * N)) # number of A component
    mass        = 39.948 * units.amu # arbitrary reference mass        
    epsilon     = 119.8 * units.kelvin * units.BOLTZMANN_CONSTANT_kB * units.AVOGADRO_CONSTANT_NA # arbitrary reference energy    
    sigma       = 0.3405 * units.nanometers # arbitrary reference lengthscale    
    [system, coordinates] = KobAndersen(N=N, NA=NA, mass=mass, epsilon=epsilon, sigma=sigma)    

    # mode = 'test'
    mode = 'production'

    # Relevant times and timescales
    reduced_time = units.sqrt((mass * sigma**2) / (48.0 * epsilon)) # factor on which all reduced times are based
    timestep = 0.035 * reduced_time # velocity Verlet timestep from Ref. [1], Supplementary Section 1.1
    delta_t = (40.0 / 3.0) * reduced_time # number of timesteps per Delta t in Lester's paper, Ref. [1] Supplementary Section 1.1, specified exactly by Lester in private communication

    if mode == 'test':
        quenched_temperature = 0.6 * epsilon / kB # temperature corresponding to left column of Fig. 2 from [1]
        tau = 200.0 * reduced_time # private communication from Lester Hedges (tau = 60 * reduced time for Tred = 0.7, 200 * reduced_time for Tred = 0.6)
        t_obs = 1 * tau # number of intervals of tau, from Fig. 2 of [1] (red dashed line, right column,) # should be 33 * tau for Tred = 0.7, 20*tau for Tred = 0.6
    elif mode == 'production':
        quenched_temperature = 0.7 * epsilon / kB # temperature corresponding to left column of Fig. 2 from [1]
        tau = 60.0 * reduced_time # private communication from Lester Hedges (tau = 60 * reduced time for Tred = 0.7, 200 * reduced_time for Tred = 0.6)
        t_obs = 33 * tau # number of intervals of tau, from Fig. 2 of [1] (red dashed line, right column,) # should be 33 * tau for Tred = 0.7, 20*tau for Tred = 0.6        
    else:
        raise Exception("Unknown mode: %s" % mode)

    print "UNCORRECTED TIMES"
    print "timestep = %s" % str(timestep.in_units_of(units.femtosecond))
    print "delta_t = %s (%f steps)" % (str(delta_t.in_units_of(units.picosecond)), delta_t / timestep)
    print "tau = %s (%f steps)" % (str(tau.in_units_of(units.picosecond)), tau / timestep)
    print "t_obs = %s (%f steps)" % (str(t_obs.in_units_of(units.picosecond)), t_obs / timestep)

    # Determine integral number of steps per velocity randomization and number of trajectories
    nsteps_per_frame = int(round(delta_t / timestep)) # number of steps per trajectory frame and velocity randomization
    nframes = int(round(t_obs / delta_t)) # number of frames per trajectory
    print "number of steps per trajectory frame and velocity randomization = %d" % nsteps_per_frame
    print "number of frames per trajectory = %d" % nframes
    # Correct delta_t and t_obs to be integral
    delta_t = nsteps_per_frame * timestep 
    t_obs = nframes * delta_t
    print "CORRECTED TIMES"
    print "timestep = %s" % str(timestep.in_units_of(units.femtosecond))
    print "delta_t = %s (%f steps)" % (str(delta_t.in_units_of(units.picosecond)), delta_t / timestep)
    print "tau = %s (%f steps)" % (str(tau.in_units_of(units.picosecond)), tau / timestep)
    print "t_obs = %s (%f steps)" % (str(t_obs.in_units_of(units.picosecond)), t_obs / timestep)

    # Replica-exchange filename
    ncfilename = 'repex.nc'

    # Decide whether we are initializing or resuming the simulation.
    minimize = True
    equilibrate = True
    quench = True
    seed = True
    if os.path.exists(ncfilename):
        # No need to do these things if we are resuming
        minimize = True
        equilibrate = False
        quench = False
        seed = True
                
    # Minimize
    if minimize:
        print "Minimizing with L-BFGS..."
        # Initialize a minimizer with default options.
        import optimize
        minimizer = optimize.LBFGSMinimizer(system, verbose=True, platform=platform)
        # Minimize the initial coordinates.
        coordinates = minimizer.minimize(coordinates)
        # Clean up to release the Context.
        del minimizer

    # Set temperature for simulations.
    elevated_temperature = 2.0 * epsilon / kB

    if equilibrate:
        # Equilibrate at high temperature.
        collision_rate = 1.0 / delta_t
        nsteps = int(math.floor(t_obs / timestep))
        print "Equilibrating at %s for %d steps..." % (str(elevated_temperature), nsteps)
        integrator = openmm.LangevinIntegrator(elevated_temperature, collision_rate, timestep)
        context = openmm.Context(system, integrator, platform)
        context.setPositions(coordinates)
        integrator.step(nsteps)
        state = context.getState(getPositions=True)
        coordinates = state.getPositions(asNumpy=True)

    if quench:
        # Quench to final temperature
        collision_rate = 1.0 / delta_t
        nsteps = int(math.floor(t_obs / timestep))
        print "Quenching to %s for %d steps..." % (str(quenched_temperature), nsteps)
        integrator = openmm.LangevinIntegrator(quenched_temperature, collision_rate, timestep)
        context = openmm.Context(system, integrator, platform)
        context.setPositions(coordinates)
        integrator.step(nsteps)
        state = context.getState(getPositions=True)
        coordinates = state.getPositions(asNumpy=True)           

    # Create a number of transition path sampling ensembles at different values of the field parameter s.
    s_reduced_unit = 1.0 / (sigma**2 * delta_t)
    #svalues = [0.0, 0.01, 0.02, 0.03, 0.04, 0.05, 0.06, 0.07] # repex-uniform.nc
    #svalues = [0.0, 0.002, 0.01, 0.07] # repex-four.nc
    #svalues = [0.0, 0.002, 0.01, 0.02, 0.04, 0.06, -0.01, -0.02, -0.04, -0.08] # 
    #svalues = [0.0, 0.002, 0.01, 0.02, 0.04, 0.06, 0.08] #
    #svalues = [0.0, 0.002, 0.01, 0.02, 0.04, 0.06, -0.01, -0.02, -0.04, -0.08] #    
    svalues = [0.00, 0.01, 0.02, 0.03, 0.04, 0.06]
    svalues = [0.00, 0.01, 0.02, 0.03, 0.04, 0.06]
    ensembles = [ TransitionPathSampling(system, N, NA, mass, epsilon, sigma, timestep, nsteps_per_frame, nframes, quenched_temperature, s * s_reduced_unit) for s in svalues ]

    trajectory = None
    if seed:
        # Generate an initial trajectory at zero field
        print "Generating seed trajectory for TPS..."
        trajectory = ensembles[0].generateTrajectory(coordinates, nframes)

    # Initialize replica-exchange TPS simulation.
    print "Initializing replica-exchange TPS..."
    simulation = ReplicaExchangeTPS(system, ensembles, trajectory, ncfilename)

    # Run simulation
    print "Running replica-exchange TPS..."
    simulation.run()
        
    print "Done."

def test_stability():
    """
    Test stability / energy conservation.

    """

    # Get OpenCL platform.
    platform = openmm.Platform.getPlatformByName("OpenCL")    

    # Constant
    kB = units.BOLTZMANN_CONSTANT_kB * units.AVOGADRO_CONSTANT_NA

    # Create a Kob-Andersen two-component mixture.
    N = 150 # number of particles
    A_fraction = 0.8 # fraction of A component
    NA = int(math.floor(A_fraction * N)) # number of A component
    mass        = 39.948 * units.amu # arbitrary reference mass        
    epsilon     = 119.8 * units.kelvin * units.BOLTZMANN_CONSTANT_kB * units.AVOGADRO_CONSTANT_NA # arbitrary reference energy    
    sigma       = 0.3405 * units.nanometers # arbitrary reference lengthscale    
    [system, coordinates] = KobAndersen(N=N, NA=NA, mass=mass, epsilon=epsilon, sigma=sigma)    

    # mode = 'test'
    mode = 'production'

    # Relevant times and timescales
    reduced_time = units.sqrt((mass * sigma**2) / (48.0 * epsilon)) # factor on which all reduced times are based
    timestep = 0.035 * reduced_time # velocity Verlet timestep from Ref. [1], Supplementary Section 1.1
    delta_t = (40.0 / 3.0) * reduced_time # number of timesteps per Delta t in Lester's paper, Ref. [1] Supplementary Section 1.1, specified exactly by Lester in private communication

    if mode == 'test':
        quenched_temperature = 0.6 * epsilon / kB # temperature corresponding to left column of Fig. 2 from [1]
        tau = 200.0 * reduced_time # private communication from Lester Hedges (tau = 60 * reduced time for Tred = 0.7, 200 * reduced_time for Tred = 0.6)
        t_obs = 1 * tau # number of intervals of tau, from Fig. 2 of [1] (red dashed line, right column,) # should be 33 * tau for Tred = 0.7, 20*tau for Tred = 0.6
    elif mode == 'production':
        quenched_temperature = 0.7 * epsilon / kB # temperature corresponding to left column of Fig. 2 from [1]
        tau = 60.0 * reduced_time # private communication from Lester Hedges (tau = 60 * reduced time for Tred = 0.7, 200 * reduced_time for Tred = 0.6)
        t_obs = 33 * tau # number of intervals of tau, from Fig. 2 of [1] (red dashed line, right column,) # should be 33 * tau for Tred = 0.7, 20*tau for Tred = 0.6        
    else:
        raise Exception("Unknown mode: %s" % mode)

    print "UNCORRECTED TIMES"
    print "timestep = %s" % str(timestep.in_units_of(units.femtosecond))
    print "delta_t = %s (%f steps)" % (str(delta_t.in_units_of(units.picosecond)), delta_t / timestep)
    print "tau = %s (%f steps)" % (str(tau.in_units_of(units.picosecond)), tau / timestep)
    print "t_obs = %s (%f steps)" % (str(t_obs.in_units_of(units.picosecond)), t_obs / timestep)

    # Determine integral number of steps per velocity randomization and number of trajectories
    nsteps_per_frame = int(round(delta_t / timestep)) # number of steps per trajectory frame and velocity randomization
    nframes = int(round(t_obs / delta_t)) # number of frames per trajectory
    print "number of steps per trajectory frame and velocity randomization = %d" % nsteps_per_frame
    print "number of frames per trajectory = %d" % nframes
    # Correct delta_t and t_obs to be integral
    delta_t = nsteps_per_frame * timestep 
    t_obs = nframes * delta_t
    print "CORRECTED TIMES"
    print "timestep = %s" % str(timestep.in_units_of(units.femtosecond))
    print "delta_t = %s (%f steps)" % (str(delta_t.in_units_of(units.picosecond)), delta_t / timestep)
    print "tau = %s (%f steps)" % (str(tau.in_units_of(units.picosecond)), tau / timestep)
    print "t_obs = %s (%f steps)" % (str(t_obs.in_units_of(units.picosecond)), t_obs / timestep)

    # Decide whether we are initializing or resuming the simulation.
    minimize = True
    equilibrate = True
    quench = True

    # Minimize
    if minimize:
        print "Minimizing with L-BFGS..."
        # Initialize a minimizer with default options.
        import optimize
        minimizer = optimize.LBFGSMinimizer(system, verbose=True, platform=platform)
        # Minimize the initial coordinates.
        coordinates = minimizer.minimize(coordinates)
        # Clean up to release the Context.
        del minimizer

    # Set temperature for simulations.
    elevated_temperature = 2.0 * epsilon / kB

    if equilibrate:
        # Equilibrate at high temperature.
        collision_rate = 1.0 / delta_t
        nsteps = int(math.floor(t_obs / timestep))
        print "Equilibrating at %s for %d steps..." % (str(elevated_temperature), nsteps)
        integrator = openmm.LangevinIntegrator(elevated_temperature, collision_rate, timestep)
        context = openmm.Context(system, integrator, platform)
        context.setPositions(coordinates)
        integrator.step(nsteps)
        state = context.getState(getPositions=True)
        coordinates = state.getPositions(asNumpy=True)

    if quench:
        # Quench to final temperature
        collision_rate = 1.0 / delta_t
        nsteps = int(math.floor(t_obs / timestep))
        print "Quenching to %s for %d steps..." % (str(quenched_temperature), nsteps)
        integrator = openmm.LangevinIntegrator(quenched_temperature, collision_rate, timestep)
        context = openmm.Context(system, integrator, platform)
        context.setPositions(coordinates)
        integrator.step(nsteps)
        state = context.getState(getPositions=True, getVelocities=True)
        coordinates = state.getPositions(asNumpy=True)
        velocities = state.getVelocities(asNumpy=True)                   

    # Create a Verlet integrator.
    integrator = openmm.VerletIntegrator(timestep)

    # Create a Context for integration.
    platform = openmm.Platform.getPlatformByName("OpenCL")
    context = openmm.Context(system, integrator, platform)
    
    # Set initial positions
    context.setPositions(coordinates)
    context.setVelocities(velocities)

    # Store initial state for each trajectory segment in trajectory.
    trajectory = Trajectory()

    # Generate trajectory segments.
    reassign_velocities = True
    remove_com_velocity = True
    for frame_index in range(nframes):
        #print "tau segment %d / %d" % (frame_index, nframes)

        if reassign_velocities:
            # Assign velocities from the Maxwell-Boltzmann distribution.
            kT = kB * quenched_temperature # thermal energy
            for atom_index in range(N):
                mass = system.getParticleMass(atom_index) # atomic mass
                sigma = units.sqrt(kT / mass) # standard deviation of velocity distribution for each coordinate for this atom
                for k in range(3):
                    velocities[atom_index,k] = sigma * numpy.random.normal()
            if remove_com_velocity:
                # Remove center of mass velocity
                com_velocity = units.Quantity((velocities / (units.nanometers / units.picoseconds)).mean(0), units.nanometers/units.picoseconds)
                for atom_index in range(N):
                    velocities[atom_index,:] -= com_velocity
                    
            context.setVelocities(velocities)
    
        # Store initial snapshot of trajectory segment.
        snapshot = Snapshot(context=context)
        trajectory.append(snapshot)
        # DEBUG
        print "frame %5d : Etot = %12.3f kcal/mol" % (frame_index, (snapshot.potential_energy + snapshot.kinetic_energy) / units.kilocalories_per_mole)

        # Propagate dynamics.
        integrator.step(nsteps_per_frame)
        
    # Store final snapshot of trajectory.
    snapshot = Snapshot(context=context)
    trajectory.append(snapshot)    

    atoms = construct_atom_list(N, NA)
    filename = "trajectory.pdb"
    write_pdb(filename, trajectory, atoms)
        
    print "Done."
    
if __name__ == "__main__":
    # Script entry point: run the zero-field replica-exchange TPS simulation.
    # Other entry points previously selected here by editing this block:
    #   test_stability()   -- Verlet energy-conservation check
    #   testKobAndersen()
    #   driver()           -- replica-exchange TPS over a range of s values
    szero()
    
