#!/usr/bin/env python

#=============================================================================================
# MODULE DOCSTRING
#=============================================================================================

"""
Parallel tempering driver.
Parallelization using mpi4py.

DESCRIPTION

This script directs parallel tempering simulations of AMBER protein systems in implicit solvent.

COPYRIGHT

@author John D. Chodera <jchodera@gmail.com>

This source file is released under the GNU General Public License.

This program is free software: you can redistribute it and/or modify it under
the terms of the GNU General Public License as published by the Free Software
Foundation, either version 3 of the License, or (at your option) any later
version.

This program is distributed in the hope that it will be useful, but WITHOUT ANY
WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A
PARTICULAR PURPOSE.  See the GNU General Public License for more details.
 
You should have received a copy of the GNU General Public License along with
this program.  If not, see <http://www.gnu.org/licenses/>.

"""

#=============================================================================================
# GLOBAL IMPORTS
#=============================================================================================

import os
import os.path
import sys
import numpy
import math
import copy
import time

import scipy.optimize # THIS MUST BE IMPORTED FIRST?!

import os
import os.path

import numpy
import numpy.linalg

import simtk.unit as units
import simtk.chem.openmm as openmm
import simtk.chem.openmm
import simtk.chem.openmm.extras.amber as amber

#import scipy.io.netcdf as netcdf # scipy pure Python netCDF interface - GIVES US TROUBLE FOR NOW
import netCDF4 as netcdf # netcdf4-python is used in place of scipy.io.netcdf for now
#import tables as hdf5 # HDF5 will be supported in the future

#=============================================================================================
# SOURCE CONTROL
#=============================================================================================

__version__ = "$Id: $" # source-control "$Id$" keyword placeholder

#=============================================================================================
# MODULE CONSTANTS
#=============================================================================================

# Molar Boltzmann constant: k_B * N_A (numerically the gas constant R).
# kB * T therefore carries units of energy per mole, matching the per-mole energies stored in Snapshot objects.
kB = units.BOLTZMANN_CONSTANT_kB * units.AVOGADRO_CONSTANT_NA # molar Boltzmann constant

#=============================================================================================
# Exceptions
#=============================================================================================

class NotImplementedException(Exception):
    """
    Raised when a requested feature is part of the interface but has no implementation yet.

    """

class ParameterException(Exception):
    """
    Raised when an argument with an invalid value has been supplied.

    """
    
#=============================================================================================
# SIMULATION SNAPSHOT 
#=============================================================================================

class Snapshot(object):
    """
    Simulation snapshot.

    Holds coordinates, velocities, periodic box vectors, and potential and kinetic energies, taken either from an
    OpenMM Context or from individually supplied components.  Components that were not supplied are None.

    """
    def __init__(self, context=None, coordinates=None, velocities=None, box_vectors=None, potential_energy=None, kinetic_energy=None):
        """
        Create a simulation snapshot from either an OpenMM context or individually-specified components.

        OPTIONAL ARGUMENTS

        context (simtk.chem.openmm.Context) - if not None, the current state is queried to populate the snapshot and the remaining arguments are ignored; otherwise, the individual components below are used (default: None)
        coordinates (simtk.unit.Quantity wrapping Nx3 numpy array of dimension length) - atomic coordinates (default: None)
        velocities (simtk.unit.Quantity wrapping Nx3 numpy array of dimension length/time) - atomic velocities (default: None)
        box_vectors - periodic box vectors (default: None)
        potential_energy (simtk.unit.Quantity of units energy/mole) - potential energy at current timestep (default: None)
        kinetic_energy (simtk.unit.Quantity of units energy/mole) - kinetic energy at current timestep (default: None)

        RAISES

        Exception - if any coordinate is NaN (the simulation is unstable or buggy)

        NOTES

        Individually-specified components are deep-copied, so later changes to the caller's objects do not affect the snapshot.
        Components that are not supplied are stored as None.

        """
        # Give every attribute a value up front so partially-specified snapshots never raise AttributeError on access.
        self.coordinates = None
        self.velocities = None
        self.box_vectors = None
        self.potential_energy = None
        self.kinetic_energy = None

        if context is not None:
            # Get current state from OpenMM Context object.
            state = context.getState(getPositions=True, getVelocities=True, getEnergy=True)

            # Populate current snapshot data.
            self.coordinates = state.getPositions(asNumpy=True)
            self.velocities = state.getVelocities(asNumpy=True)
            self.box_vectors = state.getPeriodicBoxVectors() # TODO: set asNumpy=True once bug in OpenMM is fixed
            self.potential_energy = state.getPotentialEnergy()
            self.kinetic_energy = state.getKineticEnergy()
        else:
            # copy.deepcopy(None) is None, so unspecified components simply stay None.
            self.coordinates = copy.deepcopy(coordinates)
            self.velocities = copy.deepcopy(velocities)
            self.box_vectors = copy.deepcopy(box_vectors)
            self.potential_energy = copy.deepcopy(potential_energy)
            self.kinetic_energy = copy.deepcopy(kinetic_energy)

        # Check for nans in coordinates (when present), and raise an exception if something is wrong.
        if (self.coordinates is not None) and numpy.any(numpy.isnan(self.coordinates)):
            raise Exception("Some coordinates became 'nan'; simulation is unstable or buggy.")

        return

    @property
    def total_energy(self):
        """
        Total energy of the snapshot: kinetic_energy + potential_energy.

        Both energies must have been set (they are None when not supplied).

        """
        return self.kinetic_energy + self.potential_energy

#=============================================================================================
# SIMULATION TRAJECTORY
#=============================================================================================

class Trajectory(list):
    """
    Simulation trajectory: an ordered list of Snapshot objects.

    """
    def __init__(self, trajectory=None):
        """
        Build an empty trajectory, or an independent deep copy of the snapshots held by `trajectory`.

        OPTIONAL ARGUMENTS

        trajectory (Trajectory or any iterable of Snapshot) - snapshots to deep-copy into the new trajectory (default: None)

        """
        super(Trajectory, self).__init__()

        if trajectory is None:
            return

        # Each element is deep-copied on its own, so the new trajectory shares no snapshot objects with the source.
        self.extend(copy.deepcopy(snapshot) for snapshot in trajectory)

    def reverse(self):
        """
        Reverse the trajectory in place.

        NOTE

        Velocities cannot be reversed meaningfully, so they are left untouched and lose their meaning.
        Every kinetic energy except that of the new last snapshot is recomputed as the total energy of the following
        snapshot minus the snapshot's own potential energy.  This assumes energy is (approximately) conserved over each
        segment, which can be a poor approximation in some cases.

        """
        list.reverse(self)

        # Walk forward through consecutive pairs: `following` has not been updated yet when `current` is recomputed,
        # so its total energy still uses its pre-reversal kinetic energy.
        for current, following in zip(self[:-1], self[1:]):
            current.kinetic_energy = following.total_energy - current.potential_energy

#=============================================================================================
# VELOCITY VERLET DYNAMICS
#=============================================================================================

class VelocityVerletDynamics(object):
    """
    Hamiltonian trajectories with canonical initial momenta, propagated using hybrid Velocity Verlet / Verlet dynamics.

    Velocities are drawn from the Maxwell-Boltzmann distribution once at the start of each trajectory.  OpenMM's
    VerletIntegrator is wrapped in two half-kicks so that stored snapshots carry positions and velocities at the same
    time, and velocity components along constraints are removed with RATTLE.

    """

    def __init__(self, system, timestep, nsteps_per_snapshot, nsnapshots, platform=None, verbose=False):
        """
        Set up the integrator, the Context, and the per-particle data used for dynamics and RATTLE.

        ARGUMENTS

        system (simtk.chem.openmm.System) - the system to integrate
        timestep (simtk.unit.Quantity with units of time) - integration timestep
        nsteps_per_snapshot (int) - number of integrator steps per trajectory segment (between stored snapshots)
        nsnapshots (int) - number of trajectory segments; generateTrajectory() stores nsnapshots+1 snapshots

        OPTIONAL ARGUMENTS

        platform (simtk.chem.openmm.Platform) - platform for the Context; if None, the Context is created without an explicit platform (default: None)
        verbose (boolean) - verbosity flag, stored as self.verbose (default: False)

        """

        self.system = system
        self.timestep = timestep
        self.nsteps_per_snapshot = nsteps_per_snapshot
        self.nsnapshots = nsnapshots

        # Create a Verlet integrator.
        self.integrator = openmm.VerletIntegrator(self.timestep)

        # Create a Context for integration.
        if platform is not None:
            self.context = openmm.Context(self.system, self.integrator, platform)
        else:
            self.context = openmm.Context(self.system, self.integrator)

        # Per-particle masses replicated over the three Cartesian components (nparticles x 3), so they line up with force arrays.
        nparticles = self.system.getNumParticles()
        self.mass = units.Quantity(numpy.zeros([nparticles,3], numpy.float64), units.amu)
        for particle_index in range(nparticles):
            self.mass[particle_index,:] = self.system.getParticleMass(particle_index)

        self.maximum_rattle_iterations = 1000 # maximum number of RATTLE iterations to allow
        self.relative_rattle_tolerance = 1.0e-16 # relative RATTLE tolerance
        self.verbose = verbose

        # Unitless data handed to the inline C RATTLE code (masses in amu, distances in nm).
        weave_context = dict()
        weave_context['max_iterations'] = self.maximum_rattle_iterations
        weave_context['relative_rattle_tolerance'] = self.relative_rattle_tolerance
        weave_context['nparticles'] = nparticles
        weave_context['masses'] = numpy.array([self.system.getParticleMass(index) / units.amu for index in range(nparticles)])
        # Constraints: row 0 / row 1 of 'constraints_atoms' hold the two atom indices of each constraint.
        nconstraints = self.system.getNumConstraints()
        weave_context['nconstraints'] = nconstraints
        constraint_parameters = [ self.system.getConstraintParameters(index) for index in range(nconstraints) ]
        constraints_atoms_i = [ parameters[0] for parameters in constraint_parameters ]
        constraints_atoms_j = [ parameters[1] for parameters in constraint_parameters ]
        constraints_distances = [ parameters[2] / units.nanometers for parameters in constraint_parameters ]
        weave_context['constraints_atoms'] = numpy.array([constraints_atoms_i, constraints_atoms_j])
        weave_context['constraints_distances'] = numpy.array(constraints_distances)

        self.weave_context = weave_context

        return

    def assignMaxwellBoltzmannVelocities(self, temperature, remove_com_velocity=False):
        """
        Generate Maxwell-Boltzmann velocities at the specified temperature.

        ARGUMENTS

        temperature (simtk.unit.Quantity with units of temperature) - temperature of the velocity distribution

        OPTIONAL ARGUMENTS

        remove_com_velocity (boolean) - if True, the mean velocity (unweighted by mass) is subtracted after the velocities are drawn (default: False)

        RETURNS

        velocities (simtk.unit.Quantity wrapping nparticles x 3 numpy array, units nanometers/picosecond) - the drawn velocities

        TODO

        This could be sped up by introducing vector operations.

        """

        # Get number of atoms
        nparticles = self.system.getNumParticles()

        # Standard deviation sqrt(kT/m) of each Cartesian velocity component, per atom.
        sqrt_kT_over_m = units.Quantity(numpy.zeros([nparticles,3], numpy.float64), (units.nanometers / units.picoseconds))
        for particle_index in range(nparticles):
            sqrt_kT_over_m[particle_index,:] = units.sqrt(kB * temperature / self.mass[particle_index,0])
        velocities = sqrt_kT_over_m * numpy.random.standard_normal(size=(nparticles,3))

        if remove_com_velocity:
            # Subtract the (unweighted) mean velocity from every atom.
            velocity_units = sqrt_kT_over_m.unit
            com_velocity = units.Quantity(numpy.reshape((velocities / velocity_units).mean(0), (1,3)), velocity_units)
            velocities -= units.Quantity(numpy.repeat(com_velocity / velocity_units, nparticles, axis=0),velocity_units)

        return velocities

    def rattleVelocities(self, positions, velocities):
        """
        Apply the RATTLE algorithm to remove velocity components along constraints.

        ARGUMENTS

        positions (simtk.unit.Quantity wrapping Nx3 numpy array, units of length) - positions at which the constraints are evaluated
        velocities (simtk.unit.Quantity wrapping Nx3 numpy array, units of velocity) - velocities to correct

        RETURNS

        velocities (simtk.unit.Quantity wrapping Nx3 numpy array) - corrected velocities.  Callers must use the return
        value: the compiled path returns a new object instead of modifying the argument.

        NOTES

        An inline C implementation (scipy.weave) is tried first.  If it raises for any reason, the exception is printed
        and a pure-Python implementation is used instead.  Both stop as soon as a full sweep over the constraints makes
        no correction, or after self.maximum_rattle_iterations sweeps.

        """
        # Unconstrained systems need no correction (and this avoids compiling the inline C code).
        if self.weave_context['nconstraints'] == 0:
            return velocities

        try:
            return self._rattle_velocities_weave(positions, velocities)
        except Exception as exception:
            # Best-effort fallback: scipy.weave may be missing or fail to compile.
            print(exception)
            return self._rattle_velocities_python(positions, velocities)

    def _rattle_velocities_weave(self, positions, velocities):
        """
        RATTLE via inline C code compiled by scipy.weave.

        Does not modify `velocities`; returns a new Quantity in nanometers/picosecond.

        """
        from scipy import weave
        from scipy.weave import converters

        weave_context = self.weave_context
        # Refresh tolerances so changes to these attributes after construction are honored.
        weave_context['max_iterations'] = self.maximum_rattle_iterations
        weave_context['relative_rattle_tolerance'] = self.relative_rattle_tolerance
        weave_context['positions'] = numpy.array(positions / units.nanometers)
        weave_context['velocities'] = numpy.array(velocities / (units.nanometers/units.picosecond))

        code = """
        // RATTLE algorithm
        for(int iteration = 0; iteration < max_iterations; iteration++) {
          int nconstraints_updated = 0;
          for(int index = 0; index < nconstraints; index++) {
            // Retrieve constraints.
            int i = constraints_atoms(0,index);
            int j = constraints_atoms(1,index);
            double dij = constraints_distances(index);
            // Compute difference in distances and velocities.
            double rij[3];
            double vij[3];
            for(int k = 0; k < 3; k++) {
               rij[k] = (positions(i,k) - positions(j,k));
               vij[k] = (velocities(i,k) - velocities(j,k));
               }
            // Check if constraint is satisfied to within tolerance.
            // The magnitude of r.v is tested: a negative r.v (atoms approaching) is a violation too.
            double rvdot = rij[0]*vij[0] + rij[1]*vij[1] + rij[2]*vij[2];
            double rnorm = rij[0]*rij[0] + rij[1]*rij[1] + rij[2]*rij[2];
            double vnorm = vij[0]*vij[0] + vij[1]*vij[1] + vij[2]*vij[2];
            if (fabs(rvdot) <= sqrt(rnorm*vnorm) * relative_rattle_tolerance)
                continue;
            // Correct velocities if not.
            double mi = masses(i);
            double mj = masses(j);
            double kval = rvdot / (dij*dij) / (1./mi + 1./mj);
            for(int k = 0; k < 3; k++) {
               velocities(i,k) -= kval*rij[k]/mi;
               velocities(j,k) += kval*rij[k]/mj;
               }
            nconstraints_updated++;
            }
          // Converged: a full sweep made no correction.
          if (nconstraints_updated == 0)
            break;
        }
        """

        # Execute inline C code with weave; it updates weave_context['velocities'] in place.
        weave.inline(code, list(weave_context.keys()), local_dict=weave_context, headers=['<math.h>', '<stdlib.h>'], type_converters=converters.blitz, verbose=2)

        return units.Quantity(weave_context['velocities'], units.nanometers / units.picosecond)

    def _rattle_velocities_python(self, positions, velocities):
        """
        Pure-Python RATTLE; modifies `velocities` in place and returns it.

        """
        system = self.system
        for iteration in range(self.maximum_rattle_iterations):
            nconstraints_updated = 0
            for index in range(system.getNumConstraints()):
                # Get constraints.
                [i, j, dij] = system.getConstraintParameters(index)
                # Compute difference in distances and velocities.
                rij = (positions[i,:] - positions[j,:])
                vij = (velocities[i,:] - velocities[j,:])
                # Check if constraint is satisfied to within tolerance (magnitude, so approaching atoms count too).
                rvdot = rij[0]*vij[0] + rij[1]*vij[1] + rij[2]*vij[2]
                rnorm = rij[0]*rij[0] + rij[1]*rij[1] + rij[2]*rij[2]
                vnorm = vij[0]*vij[0] + vij[1]*vij[1] + vij[2]*vij[2]
                if (abs(rvdot) <= units.sqrt(rnorm*vnorm) * self.relative_rattle_tolerance):
                    continue
                # Correct velocities if not.
                mi = system.getParticleMass(i)
                mj = system.getParticleMass(j)
                k = rvdot / dij**2 / (1./mi + 1./mj)
                velocities[i,:] -= k*rij/mi
                velocities[j,:] += k*rij/mj
                nconstraints_updated += 1
            # Converged: a full sweep needed no correction (mirrors the C implementation).
            if nconstraints_updated == 0:
                break

        return velocities

    def rattleVelocitiesOld(self, positions, velocities):
        """
        Legacy pure-Python RATTLE that prints per-iteration diagnostics.

        Always performs self.maximum_rattle_iterations sweeps (no early exit) and modifies `velocities` in place.
        generateTrajectory() uses rattleVelocities() instead.

        RETURNS

        velocities - the object passed in, after correction

        """
        system = self.system
        for iteration in range(self.maximum_rattle_iterations):
            print("RATTLE iteration %d" % iteration)
            max_relative_violation = 0.0
            for index in range(system.getNumConstraints()):
                # Get constraints.
                [i, j, dij] = system.getConstraintParameters(index)
                # Compute difference in distances and velocities.
                rij = (positions[i,:] - positions[j,:])
                vij = (velocities[i,:] - velocities[j,:])
                # Check if constraint is satisfied to within tolerance.
                rvdot = rij[0]*vij[0] + rij[1]*vij[1] + rij[2]*vij[2]
                rnorm = rij[0]*rij[0] + rij[1]*rij[1] + rij[2]*rij[2]
                vnorm = vij[0]*vij[0] + vij[1]*vij[1] + vij[2]*vij[2]
                relative_violation = abs(rvdot / units.sqrt(rnorm*vnorm))
                if (relative_violation > self.relative_rattle_tolerance):
                    mi = system.getParticleMass(i)
                    mj = system.getParticleMass(j)
                    k = rvdot / dij**2 / (1./mi + 1./mj)
                    velocities[i,:] -= k*rij/mi
                    velocities[j,:] += k*rij/mj
                    if relative_violation > max_relative_violation:
                        max_relative_violation = relative_violation
            print("max relative violation: %e" % max_relative_violation)

        return velocities

    def generateTrajectory(self, temperature, positions):
        """
        Generate a velocity Verlet trajectory of self.nsnapshots segments of self.nsteps_per_snapshot steps each.

        ARGUMENTS

        temperature (simtk.unit.Quantity with units of temperature) - temperature of the initial Maxwell-Boltzmann velocity draw
        positions (coordinate set) - initial coordinates; constraints are applied to them, and velocities are drawn
            from the Maxwell-Boltzmann distribution once, at the start of the trajectory

        RETURNS

        trajectory (Trajectory) - nsnapshots+1 Snapshots (the state at the start of each segment plus the final state);
            trajectory.path_hamiltonian is set to the total energy of the first snapshot

        NOTES

        OpenMM's VerletIntegrator is a leapfrog integrator, so each segment is wrapped in a backward half-kick
        (v(t) -> v(t - dt/2)) before stepping and a forward half-kick afterwards.  This keeps velocities synchronized
        with positions at every stored snapshot.  RATTLE removes velocity components along constraints after the initial
        velocity draw and after every forward half-kick, using the positions at that moment.

        """

        # Set initial positions
        self.context.setPositions(positions)

        # SHAKE coordinates.
        self.context.applyConstraints(self.relative_rattle_tolerance)
        state = self.context.getState(getPositions=True)
        positions = state.getPositions(asNumpy=True)

        # Store initial state for each trajectory segment in trajectory.
        trajectory = Trajectory()

        # Assign velocities from Maxwell-Boltzmann distribution; RATTLE returns the corrected velocities.
        velocities = self.assignMaxwellBoltzmannVelocities(temperature, remove_com_velocity=True)
        velocities = self.rattleVelocities(positions, velocities)
        self.context.setVelocities(velocities)

        # Generate trajectory segments.
        for snapshot_index in range(self.nsnapshots):
            # Store initial snapshot of trajectory segment.
            snapshot = Snapshot(context=self.context)
            trajectory.append(snapshot)

            # Back-kick by half a timestep to get ready for leapfrog integration.
            state = self.context.getState(getForces=True, getVelocities=True)
            force = state.getForces(asNumpy=True)
            velocities = state.getVelocities(asNumpy=True)
            velocities -= 0.5 * force/self.mass * self.timestep
            self.context.setVelocities(velocities)
            # Step using leapfrog.
            self.integrator.step(self.nsteps_per_snapshot)
            # Forward-kick by half a timestep to bring velocities into sync with positions.
            # Positions are re-read here so RATTLE sees the post-step geometry.
            state = self.context.getState(getForces=True, getVelocities=True, getPositions=True)
            force = state.getForces(asNumpy=True)
            positions = state.getPositions(asNumpy=True)
            velocities = state.getVelocities(asNumpy=True)
            velocities += 0.5 * force/self.mass * self.timestep
            velocities = self.rattleVelocities(positions, velocities)
            self.context.setVelocities(velocities)

        # Store final snapshot of trajectory.
        snapshot = Snapshot(self.context)
        trajectory.append(snapshot)

        # Store trajectory's path Hamiltonian.
        trajectory.path_hamiltonian = trajectory[0].total_energy

        return trajectory

#=============================================================================================
# Modified parallel tempering simulation
#=============================================================================================

class ParallelTempering(object):
    """
    Modified parallel tempering simulation.
    
    """    

    def __init__(self, system, coordinates, store_filename, Tmin=None, Tmax=None, ntemps=None, temperatures=None, mm=None, mpi=None):
        """
        Initialize a parallel tempering simulation object.

        ARGUMENTS

        system (simtk.chem.openmm.System) - the system to simulate
        coordinates (simtk.unit.Quantity of numpy natoms x 3 array of units length, or list thereof) - coordinate set(s) for one or more replicas, assigned in a round-robin fashion
        store_filename (string) - name of NetCDF file to bind to for simulation output and checkpointing

        OPTIONAL ARGUMENTS

        Tmin, Tmax, ntemps - min and max temperatures and number of temperatures.  Temperature i (0-based) is
            Tmin + (Tmax - Tmin) * (exp(i/(ntemps-1)) - 1) / (e - 1), running from Tmin to Tmax with exponentially
            growing spacing; if ntemps is 1, the single temperature is Tmin (default: None)
        temperatures (list of simtk.unit.Quantity with units of temperature) - if specified, this list of temperatures will be used instead of (Tmin, Tmax, ntemps) (default: None)
        mm (module) - OpenMM implementation, stored as self.mm; simtk.chem.openmm is used if None (default: None)
        mpi - MPI communicator (presumably mpi4py-style: rank, size, bcast, barrier, allgather are used); if None, replicas are propagated serially (default: None)

        RAISES

        ValueError - if neither 'temperatures' nor all of (Tmin, Tmax, ntemps) are provided

        NOTES

        Either (Tmin, Tmax, ntemps) must all be specified or the list of 'temperatures' must be specified.
        Simulation options (timestep, nsteps_per_snapshot, number_of_iterations, ...) receive defaults here and may be
        changed as attributes until run() initializes the simulation.

        """
        # Create thermodynamic states from temperatures.
        if temperatures is not None:
            pass
        elif (Tmin is not None) and (Tmax is not None) and (ntemps is not None):
            if ntemps == 1:
                # The spacing formula divides by (ntemps - 1); a single state sits at Tmin.
                temperatures = [ Tmin ]
            else:
                temperatures = [ Tmin + (Tmax - Tmin) * (math.exp(float(i) / float(ntemps-1)) - 1.0) / (math.e - 1.0) for i in range(ntemps) ]
        else:
            raise ValueError("Either 'temperatures' or 'Tmin', 'Tmax', and 'ntemps' must be provided.")

        # Make sure 'temperatures' is a units.Quantity wrapping a numpy array.
        self.temperatures = units.Quantity(numpy.zeros([len(temperatures)], numpy.float64), units.kelvin)
        for index in range(len(temperatures)):
            self.temperatures[index] = temperatures[index]

        # Select default OpenMM implementation if not specified.
        self.mm = mm if mm is not None else simtk.chem.openmm

        # Determine number of replicas from the number of specified thermodynamic states.
        self.nreplicas = len(self.temperatures)

        # Store system.
        self.system = system

        # Record store file filename
        self.store_filename = store_filename

        # Distribute coordinate information to replicas in a round-robin fashion.
        # Containers are checked explicitly because numpy 2D arrays (a single coordinate set) are iterable as well.
        if isinstance(coordinates, (list, set)):
            self.provided_coordinates = [ copy.deepcopy(coordinate_set) for coordinate_set in coordinates ]
        else:
            self.provided_coordinates = [ copy.deepcopy(coordinates) ]

        # Set default options.
        # These can be changed externally until object is initialized.
        self.constraint_tolerance = 1.0e-6
        self.timestep = 2.0 * units.femtosecond
        self.nsteps_per_snapshot = 500
        self.nsnapshots = 20 # number of snapshots per trajectory
        self.number_of_iterations = 10
        self.platform = None
        self.mpi = mpi # MPI communicator
        self.replica_mixing_scheme = 'swap-all'
        self.nsaveatoms = system.getNumParticles() # write all atoms
        self.title = 'Parallel tempering'
        self.ncfile = None

        # To allow for parameters to be modified after object creation, class is not initialized until a call to self._initialize().
        self._initialized = False

        # Set verbosity.
        self.verbose = False
        self.show_energies = True
        self.show_mixing_statistics = True

        return

    def run(self):
        """
        Run the replica-exchange simulation.

        Any parameter changes (via object attributes) that were made between object creation and calling this method become locked in
        at this point, and the object will create and bind to the store file.  If the store file already exists, the run will be resumed
        if possible; otherwise, an exception will be raised.

        """

        # Make sure we've initialized everything and bound to a storage file before we begin execution.
        if not self._initialized:
            print "rank %d / %d : _initialize" % (self.mpi.rank, self.mpi.size)        
            self._initialize()
            print "rank %d / %d : _initialize complete" % (self.mpi.rank, self.mpi.size)                    

        # Main loop
        print "rank %d / %d : main loop" % (self.mpi.rank, self.mpi.size)        
        while (self.iteration < self.number_of_iterations):
            if self.verbose: print "\nIteration %d / %d" % (self.iteration+1, self.number_of_iterations)

            # Attempt replica swaps to sample from equilibrium permuation of states associated with replicas.
            self._mix_replicas()

            # Propagate replicas.
            self._propagate_replicas()

            # Compute energies of all replicas at all states.
            self._compute_energies()

            # Show energies.
            if self.verbose and self.show_energies:
                self._show_energies()

            # Write to storage file.
            self._write_iteration_netcdf()
            
            # Increment iteration counter.
            self.iteration += 1

            # Show mixing statistics.
            if self.verbose:
                self._show_mixing_statistics()

        # Clean up and close storage files.
        self._finalize()

        return

    def _initialize(self):
        """
        Initialize the simulation, and bind to a storage file.

        """
        if self._initialized:
            print "Simulation has already been initialized."
            raise Error

        # Turn off verbosity if not master node.
        if self.mpi is not None:
            if self.mpi.rank != 0: self.verbose = False

        # Select OpenMM Platform.
        if self.platform is None:
            self.platform = simtk.chem.openmm.Platform.getPlatformByName("Reference")                                    

        self.dynamics = VelocityVerletDynamics(self.system, self.timestep, self.nsteps_per_snapshot, self.nsnapshots, platform=self.platform, verbose=self.verbose)

        # Determine number of states.
        self.nstates = len(self.temperatures)

        # Determine number of atoms in systems.
        self.natoms = self.system.getNumParticles()
  
        # Allocate storage.
        self.replica_states     = numpy.zeros([self.nstates], numpy.int32) # replica_states[i] is the state that replica i is currently at
        self.u_kl               = numpy.zeros([self.nstates, self.nstates], numpy.float32) # path hamiltonians
        self.swap_Pij_accepted  = numpy.zeros([self.nstates, self.nstates], numpy.float32) 
        self.Nij_proposed       = numpy.zeros([self.nstates,self.nstates], numpy.int64) # Nij_proposed[i][j] is the number of swaps proposed between states i and j, prior of 1
        self.Nij_accepted       = numpy.zeros([self.nstates,self.nstates], numpy.int64) # Nij_proposed[i][j] is the number of swaps proposed between states i and j, prior of 1

        # Distribute coordinate information to replicas in a round-robin fashion.
        self.replica_coordinates = [ copy.deepcopy(self.provided_coordinates[replica_index % len(self.provided_coordinates)]) for replica_index in range(self.nstates) ]
        #self.replica_coordinates = [ self.provided_coordinates[replica_index % len(self.provided_coordinates)] for replica_index in range(self.nstates) ]

        # Dummy trajectories.
        self.replica_trajectories = [ Trajectory() for index in range(self.nstates) ]
        
        # Assign initial replica states.
        for replica_index in range(self.nstates):
            self.replica_states[replica_index] = replica_index

        # Check if netcdf file exists.
        resume = os.path.exists(self.store_filename) and (os.path.getsize(self.store_filename) > 0)
        if (self.mpi): resume = self.mpi.bcast(resume)
        
        if resume:
            # Resume from NetCDF file.
            self._resume_from_netcdf()

            # Propagate replicas.
            self._propagate_replicas()

            # Show energies.
            if self.verbose and self.show_energies:
                self._show_energies()            
        else:
            # Initialize current iteration counter.
            self.iteration = 0

            # Propagate replicas.
            self._propagate_replicas()
            
            # Show energies.
            if self.verbose and self.show_energies:
                self._show_energies()

            # Initialize NetCDF file.
            self._initialize_netcdf()

            # Store initial state.
            self._write_iteration_netcdf()
  
        # Signal that the class has been initialized.
        self._initialized = True

        return

    def _finalize(self):
        """
        Do anything necessary to clean up.

        """

        if (self.mpi is None) or (self.mpi.rank == 0):
            self.ncfile.close()

        return

    def _propagate_replicas(self):
        """
        Propagate all replicas.

        Each replica gets a new trajectory from self.dynamics.generateTrajectory, run at the
        temperature of the thermodynamic state it is currently assigned to and started from its
        current coordinates.  On return, self.replica_trajectories[i] holds the new trajectory of
        replica i and self.replica_coordinates[i] the coordinates of its last snapshot.

        In serial mode (self.mpi is None) this process runs every replica.  Under MPI, rank r runs
        replicas r, r + size, r + 2*size, ...; the trajectories are then exchanged with allgather so
        that every rank ends up with all of them.

        """

        start_time = time.time()

        # Number of processes sharing the replicas; a serial run has no communicator and counts as one.
        nnodes = 1 if self.mpi is None else self.mpi.size

        if self.verbose: print("Propagating all replicas for %.3f ps..." % (self.nsteps_per_snapshot * self.nsnapshots * self.timestep / units.picoseconds))

        if self.mpi is None:
            # Serial version: run every replica on this process.
            for replica_index in range(self.nstates):
                trajectory = self._propagate_single_replica(replica_index)
                self.replica_coordinates[replica_index] = trajectory[-1].coordinates
        else:
            # Parallel version: run just this node's share of replicas (strided by rank).
            if self.mpi.rank == 0: print("Running trajectories....")
            initial_time = time.time()
            for replica_index in range(self.mpi.rank, self.nstates, self.mpi.size):
                self._propagate_single_replica(replica_index)
            final_time = time.time()
            self.mpi.barrier()
            if self.mpi.rank == 0: print("Running trajectories: elapsed time %.3f s" % (final_time - initial_time))

            # Exchange trajectories so that every node has all of them.
            if self.mpi.rank == 0: print("Synchronizing trajectories...")
            initial_time = time.time()
            local_batch = self.replica_trajectories[self.mpi.rank:self.nstates:self.mpi.size]
            gather = self.mpi.allgather(local_batch)
            for replica_index in range(self.nstates):
                # Replica r was run by rank r % size and is entry r // size of that rank's batch.
                source = replica_index % self.mpi.size
                index = replica_index // self.mpi.size
                trajectory = gather[source][index]
                self.replica_trajectories[replica_index] = trajectory
                self.replica_coordinates[replica_index] = trajectory[-1].coordinates
            final_time = time.time()
            if self.mpi.rank == 0: print("Synchronizing trajectories: elapsed time %.3f s" % (final_time - initial_time))

        # Report throughput; each process handles roughly nstates / nnodes replicas.
        elapsed_time = time.time() - start_time
        time_per_replica = elapsed_time / (float(self.nstates) / float(nnodes))
        seconds_per_day = 24*60*60
        ns_per_day = self.timestep * self.nsteps_per_snapshot * self.nsnapshots / time_per_replica * seconds_per_day / units.nanoseconds
        if self.verbose: print("Time to propagate all replicas %.3f s (%.3f s per replica, %.3f ns/day/replica, %.3f ns/day aggregate).\n" % (elapsed_time, time_per_replica, ns_per_day, ns_per_day * self.nstates))

        return

    def _propagate_single_replica(self, replica_index):
        """
        Generate a new trajectory for one replica and store it in self.replica_trajectories.

        The trajectory is run at the temperature of the state the replica currently occupies,
        starting from self.replica_coordinates[replica_index].  Returns the trajectory; the caller
        is responsible for updating self.replica_coordinates.

        """

        state_index = self.replica_states[replica_index] # thermodynamic state this replica occupies
        temperature = self.temperatures[state_index]
        coordinates = self.replica_coordinates[replica_index]
        trajectory = self.dynamics.generateTrajectory(temperature, coordinates)
        self.replica_trajectories[replica_index] = trajectory
        return trajectory


    def _compute_energies(self):
        """
        Compute reduced potentials of all replicas at all states (temperatures).

        Fills self.u_kl[k,l] = beta_l * H_k, where H_k is the path Hamiltonian of the current
        trajectory of replica k and beta_l = 1 / (kB * T_l) is the inverse temperature of state l.

        NOTES

        Because only the temperatures differ among states, each reduced potential is just a
        replica's path Hamiltonian scaled by an inverse temperature, so no additional energy
        evaluations are required.  The nstates inverse temperatures are computed once up front;
        filling u_kl still takes O(N^2) multiplications.

        TODO

        * Parallel implementation of energy calculation.

        """

        # Inverse temperatures depend only on the state; computing them once avoids nstates**2
        # unit-aware divisions.
        betas = [1.0 / (kB * self.temperatures[l]) for l in range(self.nstates)]

        # Compute reduced potentials for all configurations in all states.
        for k in range(self.nstates):
            path_hamiltonian = self.replica_trajectories[k].path_hamiltonian
            for l in range(self.nstates):
                self.u_kl[k,l] = betas[l] * path_hamiltonian

        return

    def _mix_all_replicas(self):
        """
        Attempt exchanges between all replicas to enhance mixing.

        TODO

        * Adjust nswap_attempts based on how many we can afford to do and not have mixing take a substantial fraction of iteration time.
        
        """

        # Determine number of swaps to attempt to ensure thorough mixing.
        # TODO: Replace this with analytical result computed to guarantee sufficient mixing.
        nswap_attempts = self.nstates**5 # number of swaps to attempt (ideal, but too slow!)
        nswap_attempts = self.nstates**3 # best compromise for pure Python?
        
        if self.verbose: print "Will attempt to swap all pairs of replicas, using a total of %d attempts." % nswap_attempts

        # Attempt swaps to mix replicas.
        for swap_attempt in range(nswap_attempts):
            # Choose replicas to attempt to swap.
            i = numpy.random.randint(self.nstates) # Choose replica i uniformly from set of replicas.
            j = numpy.random.randint(self.nstates) # Choose replica j uniformly from set of replicas.

            # Determine which states these resplicas correspond to.
            istate = self.replica_states[i] # state in replica slot i
            jstate = self.replica_states[j] # state in replica slot j

            # Reject swap attempt if any energies are nan.
            if (numpy.isnan(self.u_kl[i,jstate]) or numpy.isnan(self.u_kl[j,istate]) or numpy.isnan(self.u_kl[i,istate]) or numpy.isnan(self.u_kl[j,jstate])):
                continue

            # Compute log probability of swap.
            log_P_accept = - (self.u_kl[i,jstate] + self.u_kl[j,istate]) + (self.u_kl[i,istate] + self.u_kl[j,jstate])

            #print "replica (%3d,%3d) states (%3d,%3d) energies (%8.1f,%8.1f) %8.1f -> (%8.1f,%8.1f) %8.1f : log_P_accept %8.1f" % (i,j,istate,jstate,self.u_kl[i,istate],self.u_kl[j,jstate],self.u_kl[i,istate]+self.u_kl[j,jstate],self.u_kl[i,jstate],self.u_kl[j,istate],self.u_kl[i,jstate]+self.u_kl[j,istate],log_P_accept)

            # Record that this move has been proposed.
            self.Nij_proposed[istate,jstate] += 1
            self.Nij_proposed[jstate,istate] += 1

            # Accept or reject.
            if (log_P_accept >= 0.0 or (numpy.random.rand() < math.exp(log_P_accept))):
                # Swap states in replica slots i and j.
                (self.replica_states[i], self.replica_states[j]) = (self.replica_states[j], self.replica_states[i])
                # Accumulate statistics
                self.Nij_accepted[istate,jstate] += 1
                self.Nij_accepted[jstate,istate] += 1

        return

    def _mix_all_replicas_weave(self):
        """
        Attempt exchanges between all replicas to enhance mixing.
        Acceleration by 'weave' from scipy is used to speed up mixing by ~ 400x.

        Performs nstates**3 Metropolis swap attempts between uniformly chosen replica slots inside
        inline C code and updates self.replica_states, self.Nij_proposed and self.Nij_accepted.
        The C code draws random numbers with drand48().  It is reseeded on every call with a value
        drawn from numpy.random, so the swap sequence follows the numpy seed.  Otherwise every process
        (and every resumed run) would replay the same libc default stream.

        If scipy.weave cannot be imported, ImportError propagates to the caller.  Presumably the
        same holds for failures while building the inline code.

        """

        from scipy import weave

        # Determine number of swaps to attempt to ensure thorough mixing.
        # TODO: Replace this with analytical result computed to guarantee sufficient mixing.
        nswap_attempts = int(self.nstates**3) # passed to the C code below

        if self.verbose: print("Will attempt to swap all pairs of replicas using weave-accelerated code, using a total of %d attempts." % nswap_attempts)

        # Seed for drand48(), kept in the signed 32-bit range so it converts cleanly to a C int.
        # TODO: Replace drand48 with numpy random generator.
        seed = int(numpy.random.randint(2**31 - 1))

        code = """
        // Seed the C random number generator from the seed drawn in Python.
        srand48((long)seed);

        // Attempt swaps.
        for(long swap_attempt = 0; swap_attempt < nswap_attempts; swap_attempt++) {
            // Choose replicas to attempt to swap.
            int i = (long)(drand48() * nstates);
            int j = (long)(drand48() * nstates);

            // Determine which states these replicas correspond to.
            int istate = REPLICA_STATES1(i); // state in replica slot i
            int jstate = REPLICA_STATES1(j); // state in replica slot j

            // Reject swap attempt if any energies are nan.
            if ((std::isnan(U_KL2(i,jstate)) || std::isnan(U_KL2(j,istate)) || std::isnan(U_KL2(i,istate)) || std::isnan(U_KL2(j,jstate))))
               continue;

            // Compute log probability of swap.
            double log_P_accept = - (U_KL2(i,jstate) + U_KL2(j,istate)) + (U_KL2(i,istate) + U_KL2(j,jstate));

            // Record that this move has been proposed.
            NIJ_PROPOSED2(istate,jstate) += 1;
            NIJ_PROPOSED2(jstate,istate) += 1;

            // Accept or reject.
            if (log_P_accept >= 0.0 || (drand48() < exp(log_P_accept))) {
                // Swap states in replica slots i and j.
                int tmp = REPLICA_STATES1(i);
                REPLICA_STATES1(i) = REPLICA_STATES1(j);
                REPLICA_STATES1(j) = tmp;
                // Accumulate statistics
                NIJ_ACCEPTED2(istate,jstate) += 1;
                NIJ_ACCEPTED2(jstate,istate) += 1;
            }

        }
        """

        # Stage input under the names the C code refers to.
        nstates = self.nstates
        replica_states = self.replica_states
        u_kl = self.u_kl
        Nij_proposed = self.Nij_proposed
        Nij_accepted = self.Nij_accepted

        # Execute inline C code with weave (arrays are modified in place).
        weave.inline(code, ['nstates', 'nswap_attempts', 'seed', 'replica_states', 'u_kl', 'Nij_proposed', 'Nij_accepted'], headers=['<math.h>', '<stdlib.h>'], verbose=2)

        # Store results.
        self.replica_states = replica_states
        self.Nij_proposed = Nij_proposed
        self.Nij_accepted = Nij_accepted

        return

    def _mix_neighboring_replicas(self):
        """
        Attempt exchanges between neighboring replicas only.

        """

        if self.verbose: print "Will attempt to swap only neighboring replicas."

        # Attempt swaps of pairs of replicas using traditional scheme (e.g. [0,1], [2,3], ...)
        offset = numpy.random.randint(2) # offset is 0 or 1
        for istate in range(offset, self.nstates-1, 2):
            jstate = istate + 1 # second state to attempt to swap with i

            # Determine which replicas these states correspond to.
            i = None
            j = None
            for index in range(self.nstates):
                if self.replica_states[index] == istate: i = index
                if self.replica_states[index] == jstate: j = index                

            # Reject swap attempt if any energies are nan.
            if (numpy.isnan(self.u_kl[i,jstate]) or numpy.isnan(self.u_kl[j,istate]) or numpy.isnan(self.u_kl[i,istate]) or numpy.isnan(self.u_kl[j,jstate])):
                continue

            # Compute log probability of swap.
            log_P_accept = - (self.u_kl[i,jstate] + self.u_kl[j,istate]) + (self.u_kl[i,istate] + self.u_kl[j,jstate])

            #print "replica (%3d,%3d) states (%3d,%3d) energies (%8.1f,%8.1f) %8.1f -> (%8.1f,%8.1f) %8.1f : log_P_accept %8.1f" % (i,j,istate,jstate,self.u_kl[i,istate],self.u_kl[j,jstate],self.u_kl[i,istate]+self.u_kl[j,jstate],self.u_kl[i,jstate],self.u_kl[j,istate],self.u_kl[i,jstate]+self.u_kl[j,istate],log_P_accept)

            # Record that this move has been proposed.
            self.Nij_proposed[istate,jstate] += 1
            self.Nij_proposed[jstate,istate] += 1

            # Accept or reject.
            if (log_P_accept >= 0.0 or (numpy.random.rand() < math.exp(log_P_accept))):
                # Swap states in replica slots i and j.
                (self.replica_states[i], self.replica_states[j]) = (self.replica_states[j], self.replica_states[i])
                # Accumulate statistics
                self.Nij_accepted[istate,jstate] += 1
                self.Nij_accepted[jstate,istate] += 1

        return

    def _mix_replicas(self):
        """
        Attempt to swap replicas according to user-specified scheme.
        
        """

        if self.verbose: print "Mixing replicas..."

        if (self.mpi is not None) and (self.mpi.rank != 0):
            #print "rank %d / %d : scattering replica states (expecting %d)." % (self.mpi.rank, self.mpi.size, self.replica_states.size)
            self.replica_states = self.mpi.bcast(self.replica_states)
            return

        # Reset storage to keep track of swap attempts this iteration.
        self.Nij_proposed[:,:] = 0
        self.Nij_accepted[:,:] = 0

        # Perform swap attempts according to requested scheme.
        start_time = time.time()                    
        if self.replica_mixing_scheme == 'swap-neighbors':
            self._mix_neighboring_replicas()        
        elif self.replica_mixing_scheme == 'swap-all':
            # Try to use weave-accelerated mixing code if possible, otherwise fall back to Python-accelerated code.            
            try:
                self._mix_all_replicas_weave()            
            except:
                self._mix_all_replicas()
        else:
            raise ParameterException("Replica mixing scheme '%s' unknown.  Choose valid 'replica_mixing_scheme' parameter." % self.replica_mixing_scheme)
        end_time = time.time()

        # Determine fraction of swaps accepted this iteration.        
        nswaps_attempted = self.Nij_proposed.sum()
        nswaps_accepted = self.Nij_accepted.sum()
        swap_fraction_accepted = float(nswaps_accepted) / float(nswaps_attempted);
        if self.verbose: print "Accepted %d / %d attempted swaps (%.1f %%)" % (nswaps_accepted, nswaps_attempted, swap_fraction_accepted * 100.0)

        # Estimate cumulative transition probabilities between all states.
        print "rank %d / %d : accumulating mixing statistics..." % (self.mpi.rank, self.mpi.size)
        Nij_accepted = self.ncfile.variables['accepted'][:,:,:].sum(0) + self.Nij_accepted
        Nij_proposed = self.ncfile.variables['proposed'][:,:,:].sum(0) + self.Nij_proposed
        swap_Pij_accepted = numpy.zeros([self.nstates,self.nstates], numpy.float64)
        for istate in range(self.nstates):
            Ni = Nij_proposed[istate,:].sum()
            if (Ni == 0):
                swap_Pij_accepted[istate,istate] = 1.0
            else:
                swap_Pij_accepted[istate,istate] = 1.0 - float(Nij_accepted[istate,:].sum() - Nij_accepted[istate,istate]) / float(Ni)
                for jstate in range(self.nstates):
                    if istate != jstate:
                        swap_Pij_accepted[istate,jstate] = float(Nij_accepted[istate,jstate]) / float(Ni)

        # Report on mixing.
        if self.verbose:
            print "Mixing of replicas took %.3f s" % (end_time - start_time)

        if (self.mpi is not None):
            #print "rank %d / %d : scattering replica states (expecting %d)." % (self.mpi.rank, self.mpi.size, self.replica_states.size)            
            self.replica_states = self.mpi.bcast(self.replica_states)

        return

    def _show_mixing_statistics(self):
        """
        Print summary of mixing statistics.

        """
        
        # Don't print anything until we've accumulated some statistics.
        if self.iteration < 2:
            return
        
        # Compute statistics of transitions.
        Nij = numpy.zeros([self.nstates,self.nstates], numpy.float64)
        for iteration in range(self.iteration - 1):
            for ireplica in range(self.nstates):
                istate = self.ncfile.variables['states'][iteration,ireplica]
                jstate = self.ncfile.variables['states'][iteration+1,ireplica]
                Nij[istate,jstate] += 0.5
                Nij[jstate,istate] += 0.5
        Tij = numpy.zeros([self.nstates,self.nstates], numpy.float64)
        for istate in range(self.nstates):
            Tij[istate,:] = Nij[istate,:] / Nij[istate,:].sum()

        if self.show_mixing_statistics:
            # Print observed transition probabilities.
            PRINT_CUTOFF = 0.001 # Cutoff for displaying fraction of accepted swaps.
            print "Cumulative symmetrized state mixing transition matrix:"
            print "%6s" % "",
            for jstate in range(self.nstates):
                print "%6d" % jstate,
            print ""
            for istate in range(self.nstates):
                print "%-6d" % istate,
                for jstate in range(self.nstates):
                    P = Tij[istate,jstate]
                    if (P >= PRINT_CUTOFF):
                        print "%6.3f" % P,
                    else:
                        print "%6s" % "",
                print ""

        # Estimate second eigenvalue and equilibration time.
        mu = numpy.linalg.eigvals(Tij)
        mu = -numpy.sort(-mu) # sort in descending order
        if (mu[1] >= 1):
            print "Perron eigenvalue is unity; Markov chain is decomposable."
        else:
            print "Perron eigenvalue is %9.5f; state equilibration timescale is ~ %.1f iterations" % (mu[1], 1.0 / (1.0 - mu[1]))

        return

    def _initialize_netcdf(self):
        """
        Initialize NetCDF file for storage.

        Creates self.store_filename in write mode and defines the dimensions and variables filled by
        _write_iteration_netcdf.  Also writes global metadata and the state temperatures (in kelvin),
        syncs the file to disk, and keeps the open handle in self.ncfile.  Only the root MPI process
        (or a serial run) creates the file; other ranks return without setting self.ncfile.

        """

        if (self.mpi is not None) and (self.mpi.rank != 0):
            return

        # Open NetCDF 4 file for writing.
        ncfile = netcdf.Dataset(self.store_filename, 'w', version=2)

        # Create dimensions.
        ncfile.createDimension('iteration', 0) # unlimited number of iterations
        ncfile.createDimension('replica', self.nreplicas) # number of replicas
        ncfile.createDimension('atom', self.natoms) # number of atoms in system
        ncfile.createDimension('snapshot', self.nsnapshots+1) # snapshots stored per trajectory
        ncfile.createDimension('saveatom', self.nsaveatoms) # number of atoms in system to save
        ncfile.createDimension('spatial', 3) # number of spatial dimensions

        # Set global attributes.
        setattr(ncfile, 'title', self.title)
        setattr(ncfile, 'application', 'YANK')
        setattr(ncfile, 'program', 'yank.py')
        setattr(ncfile, 'programVersion', __version__)
        setattr(ncfile, 'Conventions', 'YANK')
        setattr(ncfile, 'ConventionVersion', '0.1')

        # Create variables.
        ncvar_temperatures = ncfile.createVariable('temperatures', 'f', ('replica',))
        ncvar_positions = ncfile.createVariable('positions', 'f', ('iteration','replica','atom','spatial'))
        ncvar_trajectories = ncfile.createVariable('trajectories', 'f', ('iteration','replica','snapshot','saveatom','spatial'))
        ncvar_path_hamiltonians = ncfile.createVariable('path_hamiltonians', 'f', ('iteration','replica'))
        ncvar_states    = ncfile.createVariable('states', 'i', ('iteration','replica'))
        ncvar_energies  = ncfile.createVariable('energies', 'f', ('iteration','replica','replica'))
        ncvar_proposed  = ncfile.createVariable('proposed', 'l', ('iteration','replica','replica'))
        ncvar_accepted  = ncfile.createVariable('accepted', 'l', ('iteration','replica','replica'))

        # Define units for variables.
        setattr(ncvar_temperatures, 'units', 'kelvin')
        setattr(ncvar_positions, 'units', 'nanometers')
        setattr(ncvar_trajectories, 'units', 'nanometers')
        setattr(ncvar_path_hamiltonians, 'units', 'kilojoules_per_mole')

        # Define long (human-readable) names for variables.
        setattr(ncvar_temperatures, "long_name", "temperatures[state] is the temperature of thermodynamic state 'state'.")
        setattr(ncvar_positions, "long_name", "positions[iteration][replica][atom][spatial] is position of coordinate 'spatial' of atom 'atom' from replica 'replica' for iteration 'iteration'.")
        setattr(ncvar_trajectories, "long_name", "trajectories[iteration][replica][snapshot][saveatom][spatial] is position of coordinate 'spatial' of saved atom 'saveatom' in snapshot 'snapshot' of the trajectory of replica 'replica' for iteration 'iteration'.")
        setattr(ncvar_path_hamiltonians, "long_name", "path_hamiltonians[iteration][replica] is the path Hamiltonian of the trajectory of replica 'replica' for iteration 'iteration'.")
        setattr(ncvar_states,    "long_name", "states[iteration][replica] is the state index (0..nstates-1) of replica 'replica' of iteration 'iteration'.")
        setattr(ncvar_energies,  "long_name", "energies[iteration][replica][state] is the reduced (unitless) energy of replica 'replica' from iteration 'iteration' evaluated at state 'state'.")
        setattr(ncvar_proposed,  "long_name", "proposed[iteration][i][j] is the number of proposed transitions between states i and j from iteration 'iteration-1'.")
        setattr(ncvar_accepted,  "long_name", "accepted[iteration][i][j] is the number of accepted transitions between states i and j from iteration 'iteration-1'.")

        # Write temperatures.
        # TODO: Also write inverse temperatures.
        ncvar_temperatures[:] = self.temperatures / units.kelvin

        # Force sync to disk to avoid data loss.
        ncfile.sync()

        # Store netcdf file handle.
        self.ncfile = ncfile

        return
    
    def _write_iteration_netcdf(self):
        """
        Store the data of iteration self.iteration in the open NetCDF file.

        Per replica it writes the current positions and the first self.nsaveatoms atoms of every
        trajectory snapshot (both in nanometers), plus the trajectory's path Hamiltonian (in kJ/mol).
        It also writes each replica's state index, the reduced potential matrix u_kl, and this
        iteration's proposed/accepted swap counts, then syncs the file.  Non-root MPI ranks return
        without writing.

        """

        # Only the root process holds the open NetCDF handle.
        if self.mpi is not None and self.mpi.rank != 0:
            return

        ncvars = self.ncfile.variables
        iteration = self.iteration

        for replica_index in range(self.nstates):
            # Current coordinates of this replica, in nanometers.
            positions = self.replica_coordinates[replica_index] / units.nanometers
            ncvars['positions'][iteration,replica_index,:,:] = positions[:,:]

            # Every snapshot of this replica's trajectory, restricted to the saved atoms.
            trajectory = self.replica_trajectories[replica_index]
            for snapshot_index, snapshot in enumerate(trajectory):
                saved = snapshot.coordinates[0:self.nsaveatoms,:] / units.nanometers
                ncvars['trajectories'][iteration,replica_index,snapshot_index,:,:] = saved[:,:]

            # Path Hamiltonian of the trajectory, in kJ/mol.
            ncvars['path_hamiltonians'][iteration,replica_index] = trajectory.path_hamiltonian / units.kilojoules_per_mole

        # TODO: Store box vectors

        # State assignments, reduced energies and mixing statistics for this iteration.
        ncvars['states'][iteration,:] = self.replica_states[:]
        ncvars['energies'][iteration,:,:] = self.u_kl[:,:]
        ncvars['proposed'][iteration,:,:] = self.Nij_proposed[:,:]
        ncvars['accepted'][iteration,:,:] = self.Nij_accepted[:,:]

        # Force sync to disk to avoid data loss.
        self.ncfile.sync()

    def _resume_from_netcdf(self):
        """
        Resume execution by reading current positions and energies from a NetCDF file.

        Reads the last stored iteration of self.store_filename and restores self.nstates,
        self.natoms, self.replica_coordinates (as Quantities in nanometers), self.replica_states and
        self.u_kl; self.iteration is set to the index of the next iteration to run.  The file is then
        reopened for appending on the root process (or in a serial run) and kept in self.ncfile.
        Other MPI ranks get self.ncfile = None.

        """

        # Open NetCDF file for reading; make sure it is closed even if restoring fails.
        ncfile = netcdf.Dataset(self.store_filename, 'r') # netCDF4
        try:
            # TODO: Perform sanity check on file before resuming

            # 'energies' is (iteration, replica, replica) and 'positions' is
            # (iteration, replica, atom, spatial), as created by _initialize_netcdf.
            energies = ncfile.variables['energies']
            self.iteration = energies.shape[0] - 1 # index of the last stored iteration
            self.nstates = energies.shape[1]
            self.natoms = ncfile.variables['positions'].shape[2]

            # Restore positions.
            self.replica_coordinates = list()
            for replica_index in range(self.nstates):
                x = ncfile.variables['positions'][self.iteration,replica_index,:,:].astype(numpy.float64).copy()
                coordinates = units.Quantity(x, units.nanometers)
                self.replica_coordinates.append(coordinates)

            # Restore state information.
            self.replica_states = ncfile.variables['states'][self.iteration,:].copy()

            # Restore energies.
            self.u_kl = energies[self.iteration,:,:].copy()
        finally:
            ncfile.close()

        # We will work on the next iteration.
        self.iteration += 1

        # Reopen NetCDF file for appending; only the root process (or a serial run) keeps a handle.
        if (self.mpi is None) or (self.mpi.rank == 0):
            self.ncfile = netcdf.Dataset(self.store_filename, 'a') # for netCDF4
        else:
            self.ncfile = None

        return

    def _show_energies(self):
        """
        Show energies (in units of kT) for all replicas at all states.

        """

        # print header
        print "%-24s %16s" % ("reduced potential (kT)", "current state"),
        for state_index in range(self.nstates):
            print " state %3d" % state_index,
        print ""

        # print energies in kT
        for replica_index in range(self.nstates):
            print "replica %-16d %16d" % (replica_index, self.replica_states[replica_index]),
            for state_index in range(self.nstates):
                print "%10.1f" % (self.u_kl[replica_index,state_index]),
            print ""

        return

#=============================================================================================
# MAIN AND TESTS
#=============================================================================================
    
if __name__ == "__main__":

    store_filename  = 'alanine-dipeptide-parallel-tempering.nc' # output netCDF filename

    Tmin = 300.0 * units.kelvin # minimum temperature
    Tmax = 600.0 * units.kelvin # maximum temperature
    ntemps = 40 # number of replicas

    timestep = 2.0 * units.femtoseconds # timestep for simulation
    nsteps_per_snapshot = 250 # number of timesteps between recording positions
    nsnapshots = 20 # 20 snapshots * 250 steps * 2 fs = 10 ps
    niterations = 10000 # number of iterations
    nequiliterations = 100 # number of equilibration iterations at Tmin

    verbose = True # verbose output

    # Every MPI rank executes this script.
    print("importing mpi4py.MPI...")
    from mpi4py import MPI # MPI wrapper
    comm = MPI.COMM_WORLD # MPI communicator
    print("Started node %d / %d" % (comm.rank, comm.size))

    # Set up device to bind to: each rank picks GPU (rank % ngpus).
    # NOTE(review): assumes ngpus GPUs per host and a rank layout that spreads ranks evenly over them.
    print("MPI...")
    hostname = os.uname()[1]
    ngpus = 2 # number of GPUs per host (hard-coded)
    deviceid = comm.rank % ngpus
    #platform = openmm.Platform.getPlatformByName('Cuda') # fast GPU platform
    platform = openmm.Platform.getPlatformByName('OpenCL') # fast GPU platform
    platform.setPropertyDefaultValue('OpenCLDeviceIndex', '%d' % deviceid)
    platform.setPropertyDefaultValue('CudaDeviceIndex', '%d' % deviceid)
    print("node '%s' deviceid %d / %d, MPI rank %d / %d" % (hostname, deviceid, ngpus, comm.rank, comm.size))

    # Make sure random number generators have different seeds on each rank.
    # numpy.random.seed() only accepts 0 <= seed < 2**32, and randint's bound must fit in a C long
    # (32 bits on some platforms), so draw from the signed 32-bit range rather than up to sys.maxint.
    seed = numpy.random.randint(2**31 - 1 - comm.size) + comm.rank
    numpy.random.seed(seed)
    print("node %d / %d : seed %d" % (comm.rank, comm.size, seed))

    # Create system (the amber module is imported at the top of the file).
    nonbondedCutoff = 9.0 * units.angstroms
    prmtop_filename = os.path.join('setup', 'alanine-dipeptide.prmtop')
    crd_filename = os.path.join('setup', 'alanine-dipeptide.equilibrated.crd')
    if verbose: print("Creating system from AMBER prmtop...")
    system = amber.readAmberSystem(prmtop_filename, mm=openmm, nonbondedCutoff=nonbondedCutoff, nonbondedMethod='PME', shake='h-bonds')
    if verbose: print("Reading coordinates...")
    [coordinates, box_vectors] = amber.readAmberCoordinates(crd_filename, read_box=True)
    system.setDefaultPeriodicBoxVectors(box_vectors[0], box_vectors[1], box_vectors[2])

    # Initialize parallel tempering simulation.
    if verbose and comm.rank==0: print("Initializing parallel tempering simulation...")
    simulation = ParallelTempering(system, coordinates, store_filename, Tmin=Tmin, Tmax=Tmax, ntemps=ntemps, mpi=comm)
    simulation.verbose = verbose # write debug output
    simulation.platform = platform # use specified platform
    simulation.number_of_equilibration_iterations = nequiliterations
    simulation.number_of_iterations = niterations # number of iterations (exchange attempts)
    simulation.timestep = timestep # timestep
    simulation.nsteps_per_snapshot = nsteps_per_snapshot # number of timesteps per snapshot
    simulation.nsnapshots = nsnapshots
    simulation.nsaveatoms = 22 # number of leading atoms stored in 'trajectories'

    # Run or resume simulation.
    if verbose: print("Running...")
    simulation.run()

