#!/usr/local/bin/env python

#=============================================================================================
# MODULE DOCSTRING
#=============================================================================================

"""
Parallel tempering driver.

DESCRIPTION

This script directs parallel tempering simulations of AMBER protein systems in implicit solvent.

COPYRIGHT

@author John D. Chodera <jchodera@gmail.com>

This source file is released under the GNU General Public License.

This program is free software: you can redistribute it and/or modify it under
the terms of the GNU General Public License as published by the Free Software
Foundation, either version 3 of the License, or (at your option) any later
version.

This program is distributed in the hope that it will be useful, but WITHOUT ANY
WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A
PARTICULAR PURPOSE.  See the GNU General Public License for more details.
 
You should have received a copy of the GNU General Public License along with
this program.  If not, see <http://www.gnu.org/licenses/>.

"""

#=============================================================================================
# GLOBAL IMPORTS
#=============================================================================================

import pdb
import os
import os.path

import numpy

import simtk.openmm as openmm
import simtk.pyopenmm.amber.amber_file_parser as amber
import simtk.unit as units
import netCDF4 as netcdf

#=============================================================================================
# SOURCE CONTROL
#=============================================================================================

# Revision-control keyword placeholder; the "$Id: $" keyword is unexpanded here.
__version__ = "$Id: $"

#=============================================================================================
# MAIN AND TESTS
#=============================================================================================
    
if __name__ == "__main__":
    # Parallel tempering driver: read an AMBER system, optionally minimize and
    # equilibrate it at Tmin, then hand it to repex.ParallelTempering to run
    # (or resume, when the netCDF store file already exists).
    #
    # NOTE: print is always called with ONE parenthesized argument, so every
    # message parses and prints identically under Python 2 and Python 3.

    #-----------------------------------------------------------------------------------------
    # Input / output files
    #-----------------------------------------------------------------------------------------

    # Determine prmtop and crd filenames in test directory.
    # TODO: Make these command-line arguments?
    # expanduser('~') uses $HOME when it is set and falls back to the account
    # database otherwise, so an unset HOME no longer crashes os.path.join.
    home_directory = os.path.expanduser('~')
    #prmtop_filename = os.path.join(home_directory, 'lactalbumin', 'inputs','la_allcys.prmtop') # input prmtop file
    #crd_filename    = os.path.join(home_directory, 'lactalbumin', 'inputs','la_allcys.crd') # input coordinate file
    prmtop_filename = os.path.join(home_directory, 'testMPI', 'inputs','blg.prmtop') # input prmtop file
    crd_filename    = os.path.join(home_directory, 'testMPI', 'inputs','blg.crd') # input coordinate file
    #prmtop_filename = os.path.join(home_directory, 'lactalbumin', 'inputs','blg.prmtop') # input prmtop file
    #crd_filename    = os.path.join(home_directory, 'lactalbumin', 'inputs','blg.crd') # input coordinate file

    # Uncomment this for smaller test system (alanine dipeptide in implicit solvent)
    #prmtop_filename = os.path.join(os.getenv('YANK_INSTALL_DIR'), 'test', 'systems', 'alanine-dipeptide-gbsa', 'alanine-dipeptide.prmtop') # input prmtop file
    #crd_filename    = os.path.join(os.getenv('YANK_INSTALL_DIR'), 'test', 'systems', 'alanine-dipeptide-gbsa', 'alanine-dipeptide.crd') # input coordinate file

    #store_filename  = 'blg.nc' # output netCDF filename
    store_filename  = 'compare_step.nc' # output netCDF filename

    #-----------------------------------------------------------------------------------------
    # Simulation parameters
    #-----------------------------------------------------------------------------------------

    Tmin = 270.0 * units.kelvin # minimum temperature
    Tmax = 400.0 * units.kelvin # maximum temperature
    ntemps = 16 # number of replicas

    #Tmin = 270.0 * units.kelvin # minimum temperature
    #Tmax = 450.0 * units.kelvin # maximum temperature
    #ntemps = 20 # number of replicas

    timestep = 2.0 * units.femtoseconds # timestep for simulation
    nsteps = 2500 # number of timesteps per iteration (exchange attempt)
    niterations = 10 # number of iterations
    nequiliterations = 10 # number of equilibration iterations at Tmin with timestep/2 timestep, for nsteps*2
    # NOTE(review): tolerance units are kept as in the original (kJ / nm**2);
    # confirm these are the units LocalEnergyMinimizer.minimize expects.
    minimize_tolerance = 1.0 * units.kilojoules / units.nanometers**2
    minimize_maximum_evaluations = 10000 # max number of minimization evaluations
    verbose = True # verbose output
    minimize = True # minimize
    equilibrate = True # equilibrate

    # Select platform: one of 'Reference' (CPU-only), 'Cuda' (NVIDIA Cuda), or 'OpenCL' (for OS X 10.6 with OpenCL OpenMM compiled)
    platform = openmm.Platform.getPlatformByName("Cuda")

    #-----------------------------------------------------------------------------------------
    # Build the system
    #-----------------------------------------------------------------------------------------

    if verbose: print("Reading AMBER prmtop...")
    system = amber.readAmberSystem(prmtop_filename, shake="h-bonds", gbmodel="OBC GBSA", nonbondedCutoff=None)
    if verbose: print("Reading AMBER coordinates...")
    coordinates = amber.readAmberCoordinates(crd_filename)
    if verbose: print("prmtop and coordinate files read.\n")

    # Determine whether we will resume or create new simulation:
    # an existing store file means a previous run is being continued.
    resume = os.path.exists(store_filename)
    if resume:
        if verbose: print("Store filename '%s' found, resuming existing run..." % store_filename)

    #-----------------------------------------------------------------------------------------
    # Prepare starting coordinates (new runs only)
    #-----------------------------------------------------------------------------------------

    if not resume:
        # Preparation runs at Tmin with half the production timestep.
        collision_rate = 5.0 / units.picosecond
        integrator = openmm.LangevinIntegrator(Tmin, collision_rate, timestep / 2.0)
        context = openmm.Context(system, integrator, platform)
        context.setPositions(coordinates)

        # Minimize system (if not resuming).
        if minimize:
            if verbose: print("Minimizing...")
            openmm.LocalEnergyMinimizer.minimize(context, minimize_tolerance, minimize_maximum_evaluations)
            openmm_state = context.getState(getPositions=True)
            coordinates = openmm_state.getPositions(asNumpy=True)

        # Equilibrate (if not resuming).
        if equilibrate:
            # Reported duration: nequiliterations * (2*nsteps) steps of timestep/2
            # equals nequiliterations * nsteps * timestep.
            if verbose: print("Equilibrating at %.1f K for %.3f ps with %.1f fs timestep..." % (Tmin / units.kelvin, nequiliterations * nsteps * timestep / units.picoseconds, (timestep/2.0)/units.femtoseconds))

            context.setPositions(coordinates)
            for iteration in range(nequiliterations):
                # Twice as many steps per iteration to compensate for the halved timestep.
                integrator.step(nsteps * 2)
                state = context.getState(getEnergy=True)
                if verbose: print("iteration %8d %12.3f ns %16.3f kcal/mol" % (iteration, state.getTime() / units.nanosecond, state.getPotentialEnergy() / units.kilocalories_per_mole))

            openmm_state = context.getState(getPositions=True)
            coordinates = openmm_state.getPositions(asNumpy=True)

        # Clean up: drop the preparation context and integrator before the
        # parallel tempering simulation is set up.
        del context
        del integrator

    #-----------------------------------------------------------------------------------------
    # Parallel tempering
    #-----------------------------------------------------------------------------------------

    # Initialize parallel tempering simulation.
    #import simtk.pyopenmm.extras.repex as repex
    import repex as repex
    if verbose: print("Initializing parallel tempering simulation...")
    simulation = repex.ParallelTempering(system, coordinates, store_filename, Tmin=Tmin, Tmax=Tmax, ntemps=ntemps)
    simulation.verbose = True # write debug output
    simulation.platform = platform # use the platform selected above
    simulation.number_of_equilibration_iterations = nequiliterations
    simulation.number_of_iterations = niterations # number of iterations (exchange attempts)
    simulation.timestep = timestep # timestep
    simulation.nsteps_per_iteration = nsteps # number of timesteps per iteration
    simulation.minimize = False # minimization, if requested, is handled above

    # Run or resume simulation.
    if verbose: print("Running...")
    simulation.run()

