#!/usr/local/bin/env python

#=============================================================================================
# MODULE DOCSTRING
#=============================================================================================

"""
Parallel tempering driver.

DESCRIPTION

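This script runs a parallel tempering (temperature replica-exchange) simulation of an AMBER protein system in implicit solvent.
If the output NetCDF store file already exists, the existing run is resumed instead of a new one being started.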

COPYRIGHT

@author John D. Chodera <jchodera@gmail.com>

This source file is released under the GNU General Public License.

This program is free software: you can redistribute it and/or modify it under
the terms of the GNU General Public License as published by the Free Software
Foundation, either version 3 of the License, or (at your option) any later
version.

This program is distributed in the hope that it will be useful, but WITHOUT ANY
WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A
PARTICULAR PURPOSE.  See the GNU General Public License for more details.
 
You should have received a copy of the GNU General Public License along with
this program.  If not, see <http://www.gnu.org/licenses/>.

"""

#=============================================================================================
# GLOBAL IMPORTS
#=============================================================================================

import scipy.optimize # THIS MUST BE IMPORTED FIRST?!

import os
import os.path

import numpy

import simtk.unit as units
import simtk.chem.openmm as openmm
import simtk.chem.openmm.extras.amber as amber
import simtk.chem.openmm.extras.optimize as optimize

#=============================================================================================
# SOURCE CONTROL
#=============================================================================================

# Version string for this module; "$Id: $" is presumably an RCS/SVN-style keyword
# meant to be expanded by the version-control system (only if keyword expansion is enabled).
__version__ = "$Id: $"

#=============================================================================================
# MAIN AND TESTS
#=============================================================================================
    
def _test_system_file(system_name, filename):
    """Return the path of a file belonging to one of the YANK test systems.

    The path is ``$YANK_INSTALL_DIR/test/systems/<system_name>/<filename>``.

    Parameters
    ----------
    system_name : str
        Name of the test-system subdirectory (e.g. 'lactalbumin').
    filename : str
        Name of the file inside that subdirectory.

    Returns
    -------
    str
        The joined path (existence is not checked).

    Raises
    ------
    EnvironmentError
        If the YANK_INSTALL_DIR environment variable is not set.
    """
    install_dir = os.getenv('YANK_INSTALL_DIR')
    if install_dir is None:
        # Without this check os.path.join(None, ...) fails with an unhelpful error.
        raise EnvironmentError("The YANK_INSTALL_DIR environment variable must be set to locate test system '%s'." % system_name)
    return os.path.join(install_dir, 'test', 'systems', system_name, filename)

if __name__ == "__main__":

    # Determine prmtop and crd filenames in test directory.
    # TODO: Make these command-line arguments?
    prmtop_filename = _test_system_file('lactalbumin', 'la_allcys.prmtop') # input prmtop file
    crd_filename    = _test_system_file('lactalbumin', 'la_allcys.crd') # input coordinate file

    # Uncomment this for smaller test system (alanine dipeptide in implicit solvent)
    #prmtop_filename = _test_system_file('alanine-dipeptide-gbsa', 'alanine-dipeptide.prmtop') # input prmtop file
    #crd_filename    = _test_system_file('alanine-dipeptide-gbsa', 'alanine-dipeptide.crd') # input coordinate file

    store_filename  = 'parallel-tempering-neighbor-swaps-lactalbumin.nc' # output netCDF filename

    Tmin = 270.0 * units.kelvin # minimum temperature
    Tmax = 450.0 * units.kelvin # maximum temperature
    ntemps = 20 # number of replicas

    timestep = 2.0 * units.femtoseconds # timestep for simulation
    nsteps = 2500 # number of timesteps per iteration (exchange attempt)
    #nsteps = 500 # number of timesteps per iteration (exchange attempt)
    niterations = 2000 # number of iterations
    nequiliterations = 0 # number of equilibration iterations at Tmin; each runs nsteps*2 steps of timestep/2

    verbose = True # verbose output
    minimize = True # minimize
    equilibrate = False # equilibrate

    # Select platform: one of 'Reference' (CPU-only), 'Cuda' (NVIDIA Cuda), or 'OpenCL' (for OS X 10.6 with OpenCL OpenMM compiled)
    platform = openmm.Platform.getPlatformByName("Cuda")

    # Create system.
    if verbose: print("Reading AMBER prmtop...")
    system = amber.readAmberSystem(prmtop_filename, shake="h-bonds", gbmodel="OBC GBSA", nonbondedCutoff=None)
    if verbose: print("Reading AMBER coordinates...")
    coordinates = amber.readAmberCoordinates(crd_filename)
    if verbose: print("prmtop and coordinate files read.\n")

    # Determine whether we will resume or create new simulation.
    # An existing store file means a previous run is continued; minimization and
    # equilibration below are then skipped.
    resume = os.path.exists(store_filename)
    if resume and verbose: print("Store filename '%s' found, resuming existing run..." % store_filename)

    # Minimize system (if not resuming).
    if minimize and not resume:
        if verbose: print("Minimizing...")
        minimizer = optimize.LBFGSMinimizer(system, verbose=verbose, platform=platform)
        nminiterations = 1
        for iteration in range(nminiterations):
            if verbose: print("iteration %d / %d" % (iteration, nminiterations))
            coordinates = minimizer.minimize(coordinates, constrain=True)
        del minimizer

    # Equilibrate (if not resuming).
    if equilibrate and not resume:
        # Total simulated time is nequiliterations * (2*nsteps) * (timestep/2) = nequiliterations * nsteps * timestep.
        if verbose: print("Equilibrating at %.1f K for %.3f ps with %.1f fs timestep..." % (Tmin / units.kelvin, nequiliterations * nsteps * timestep / units.picoseconds, (timestep/2.0)/units.femtoseconds))
        collision_rate = 5.0 / units.picosecond
        integrator = openmm.LangevinIntegrator(Tmin, collision_rate, timestep / 2.0)
        context = openmm.Context(system, integrator, platform)
        context.setPositions(coordinates)
        for iteration in range(nequiliterations):
            integrator.step(nsteps * 2)
            state = context.getState(getEnergy=True)
            if verbose: print("iteration %8d %12.3f ns %16.3f kcal/mol" % (iteration, state.getTime() / units.nanosecond, state.getPotentialEnergy() / units.kilocalories_per_mole))
        # Carry the equilibrated configuration forward; otherwise the replicas would
        # start from the pre-equilibration coordinates and equilibration would be wasted.
        # NOTE(review): asNumpy=True is intended to return an array (not a Python list of
        # Vec3) like the coordinates read/minimized above -- confirm against the installed
        # simtk.chem.openmm State API.
        coordinates = context.getState(getPositions=True).getPositions(asNumpy=True)
        # Release the equilibration context before the replicas are set up (mirrors `del minimizer` above).
        del context, integrator

    # Initialize parallel tempering simulation.
    #import simtk.chem.openmm.extras.repex as repex
    import repex # DEBUG
    if verbose: print("Initializing parallel tempering simulation...")
    simulation = repex.ParallelTempering(system, coordinates, store_filename, Tmin=Tmin, Tmax=Tmax, ntemps=ntemps)
    simulation.verbose = verbose # write debug output
    simulation.platform = platform # use the platform selected above
    simulation.number_of_equilibration_iterations = nequiliterations
    simulation.number_of_iterations = niterations # number of iterations (exchange attempts)
    simulation.timestep = timestep # timestep
    simulation.nsteps_per_iteration = nsteps # number of timesteps per iteration
    simulation.replica_mixing_scheme = 'swap-neighbors' # 'swap-neighbors' or 'swap-all'
    #simulation.replica_mixing_scheme = 'swap-all' # 'swap-neighbors' or 'swap-all'
    simulation.minimize = False # any minimization was already done above

    # Run or resume simulation.
    if verbose: print("Running...")
    simulation.run()

