#!/usr/bin/env python

#=============================================================================================
# MODULE DOCSTRING
#=============================================================================================

"""
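Preparation step for parallel tempering: minimize and equilibrate a starting structure.
(This script runs serially; it does not itself use mpi4py.)

DESCRIPTION

This script prepares AMBER systems for parallel tempering simulations. It reads an AMBER prmtop/crd
pair (alanine dipeptide), minimizes the potential energy, equilibrates at the minimum temperature in
the NPT ensemble (periodic box with PME electrostatics, Langevin dynamics, Monte Carlo barostat), and
writes the equilibrated coordinates and box vectors to a new .crd file.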

COPYRIGHT

@author John D. Chodera <jchodera@gmail.com>

This source file is released under the GNU General Public License.

This program is free software: you can redistribute it and/or modify it under
the terms of the GNU General Public License as published by the Free Software
Foundation, either version 3 of the License, or (at your option) any later
version.

This program is distributed in the hope that it will be useful, but WITHOUT ANY
WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A
PARTICULAR PURPOSE.  See the GNU General Public License for more details.
 
You should have received a copy of the GNU General Public License along with
this program.  If not, see <http://www.gnu.org/licenses/>.

"""

#=============================================================================================
# GLOBAL IMPORTS
#=============================================================================================

import os
import os.path
import sys
import numpy
import math
import copy
import time

import scipy.optimize # THIS MUST BE IMPORTED FIRST?!

import os
import os.path

import numpy
import numpy.linalg

import simtk.unit as units
import simtk.chem.openmm as openmm
import simtk.chem.openmm.extras.amber as amber

#=============================================================================================
# SOURCE CONTROL
#=============================================================================================

# Source-control revision tag (expanded by the VCS keyword mechanism, if enabled).
__version__ = "$Id: $"

#=============================================================================================
# MODULE CONSTANTS
#=============================================================================================

# Molar Boltzmann constant (k_B * N_A, i.e. the gas constant). Because OpenMM reports energies
# per mole, dividing a molar energy by this constant yields a temperature; the equilibration
# loop uses it to turn kinetic energy into an instantaneous temperature.
kB = units.BOLTZMANN_CONSTANT_kB * units.AVOGADRO_CONSTANT_NA

#=============================================================================================
# MAIN AND TESTS
#=============================================================================================
    
if __name__ == "__main__":

    temperature = 300.0 * units.kelvin # minimum temperature

    timestep = 2.0 * units.femtoseconds # timestep for simulation
    nsteps_per_iteration = 5000 # number of timesteps between report
    nequiliterations = 100 # number of equilibration iterations at Tmin
     
    verbose = True # verbose output
    minimize = True # minimize
    equilibrate = True # equilibrate

    platform = openmm.Platform.getPlatformByName('Cuda') # fast GPU platform
    #platform = openmm.Platform.getPlatformByName('OpenCL') # fast GPU platform
    
    equilibrated_crd_filename = 'alanine-dipeptide.equilibrated.crd'
    
    # Create system.
    nonbondedCutoff = 9.0 * units.angstroms
    prmtop_filename = 'alanine-dipeptide.prmtop'
    crd_filename = 'alanine-dipeptide.crd'
    if verbose: print "Creating system from AMBER prmtop..."
    system = amber.readAmberSystem(prmtop_filename, mm=openmm, nonbondedCutoff=nonbondedCutoff, nonbondedMethod='PME', shake='h-bonds')
    if verbose: print "Reading coordinates..."
    [coordinates, box_vectors] = amber.readAmberCoordinates(crd_filename, read_box=True)
    system.setDefaultPeriodicBoxVectors(box_vectors[0], box_vectors[1], box_vectors[2])
    
    # Minimize system (if not resuming).
    if (minimize):
        if verbose: print "Minimizing..."

        # Create a Context with arbitrary parameters.
        integrator = openmm.VerletIntegrator(timestep)
        context = openmm.Context(system, integrator, platform)

        # Set coordinates.
        context.setPositions(coordinates)

        # Compute initial energy.
        state = context.getState(getEnergy=True)
        initial_potential = state.getPotentialEnergy()
        if verbose: print "initial potential : %s" % (str(initial_potential))
        
        # Minimize.
        openmm.LocalEnergyMinimizer.minimize(context)    

        # Compute final energy.
        state = context.getState(getEnergy=True)
        final_potential = state.getPotentialEnergy()
        if verbose: print "final potential   : %s" % (str(final_potential))

        # Get minimized coordinates.
        state = context.getState(getPositions=True)
        coordinates = state.getPositions(asNumpy=True)

        # Clean up.
        del state, context, integrator
        
    # Equilibrate (if not resuming).
    if (equilibrate):
        if verbose: print "Equilibrating at %.1f K for %.3f ps with %.1f fs timestep..." % (temperature / units.kelvin, nequiliterations * nsteps_per_iteration * timestep / units.picoseconds, timestep/units.femtoseconds)
        # Add Monte Carlo barostat.
        pressure = 1.0 * units.atmospheres
        barostat_frequency = 25 # steps between barostat updates
        barostat = openmm.MonteCarloBarostat(pressure, temperature, barostat_frequency)
        system.addForce(barostat)
        # Create integrator and context.
        collision_rate = 5.0 / units.picosecond
        integrator = openmm.LangevinIntegrator(temperature, collision_rate, timestep)    
        context = openmm.Context(system, integrator, platform)
        context.setPositions(coordinates)
        for iteration in range(nequiliterations):
            # Integrate equations of motion.
            integrator.step(nsteps_per_iteration)
            # Get potential energy and box dimensions.
            state = context.getState(getEnergy=True, getPositions=True)
            [a,b,c] = state.getPeriodicBoxVectors()
            volume = a[0]*b[1]*c[2]
            # Compute instantaneous temperature.
            # NkT/2 = KE
            ndof = 3*system.getNumParticles() - system.getNumConstraints()
            kinetic_temperature = state.getKineticEnergy() / ndof * 2 / kB
            # Update box dimensions.
            system.setDefaultPeriodicBoxVectors(a,b,c)
            # Get updated positions.
            coordinates = state.getPositions(asNumpy=True)            
            # Report on progress.
            if verbose: print "iteration %8d | %12.3f ns | potential: %16.3f kcal/mol | temperature: %8.1f K | volume: %16.1f A^3" % (iteration, state.getTime() / units.nanosecond, state.getPotentialEnergy() / units.kilocalories_per_mole, kinetic_temperature / units.kelvin, volume / units.angstroms**3)

        # Clean up.
        del state, context, integrator        
        
    # Write final configuration.
    if verbose: print "Reading coordinates..."    
    box_vectors = system.getDefaultPeriodicBoxVectors()
    amber.writeAmberCoordinates(equilibrated_crd_filename, system, coordinates, box_vectors)
