#!/usr/local/bin/env python

#=============================================================================================
# MODULE DOCSTRING
#=============================================================================================

"""
Parallel tempering driver.

DESCRIPTION

This script directs parallel tempering simulations of AMBER protein systems in implicit solvent.

COPYRIGHT

@author John D. Chodera <jchodera@gmail.com>

This source file is released under the GNU General Public License.

This program is free software: you can redistribute it and/or modify it under
the terms of the GNU General Public License as published by the Free Software
Foundation, either version 3 of the License, or (at your option) any later
version.

This program is distributed in the hope that it will be useful, but WITHOUT ANY
WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A
PARTICULAR PURPOSE.  See the GNU General Public License for more details.
 
You should have received a copy of the GNU General Public License along with
this program.  If not, see <http://www.gnu.org/licenses/>.

"""

#=============================================================================================
# GLOBAL IMPORTS
#=============================================================================================

import scipy.optimize # THIS MUST BE IMPORTED FIRST?!

import os
import os.path
import sys

import numpy

import simtk.unit as units
import simtk.chem.openmm as openmm
import simtk.chem.openmm.extras.amber as amber
import simtk.chem.openmm.extras.optimize as optimize

#=============================================================================================
# SOURCE CONTROL
#=============================================================================================

__version__ = "$Id: $"

#=============================================================================================
# MAIN AND TESTS
#=============================================================================================

class ParallelTemperingDriver(object):
    """
    Drive an MPI-parallel tempering simulation of an AMBER system in implicit solvent.

    All run parameters (input files, temperature ladder, timestep, iteration counts,
    compute platform) are hard-coded in run(); edit them there.
    """

    # numpy.random.seed() only accepts integer seeds in [0, 2**32).
    _MAX_SEED = 2**32

    def __init__(self):
        pass

    @staticmethod
    def _rank_seed(base_seed, rank, size):
        """
        Derive the RNG seed for one MPI rank from a base seed shared by all ranks.

        Seeds for ranks 0..size-1 are pairwise distinct (for size <= 2**32) and always
        lie in [0, 2**32), the range accepted by numpy.random.seed().

        ARGUMENTS

        base_seed (int) - seed common to all ranks (e.g. broadcast from rank 0)
        rank (int) - MPI rank of the calling process, 0 <= rank < size
        size (int) - number of MPI processes in the communicator

        RETURNS

        seed (int) - seed to pass to numpy.random.seed() on this rank

        RAISES

        ValueError if rank is not in [0, size)
        """
        if not 0 <= rank < size:
            raise ValueError("rank %d out of range for communicator size %d" % (rank, size))
        return (int(base_seed) + int(rank)) % ParallelTemperingDriver._MAX_SEED

    def run(self):
        """
        Set up and run (or resume) the parallel tempering simulation.

        Every MPI rank executes this method. Each rank:
        - binds to a GPU chosen from its rank;
        - optionally minimizes and equilibrates the starting structure, but only when
          not resuming from an existing store file;
        - seeds numpy's RNG with a rank-unique seed;
        - hands control to repexmpi.ParallelTempering.
        """
        # PARAMETERS

        # Beta-lactalbumin in implicit solvent test system.
        #prmtop_filename = os.path.join('systems', 'beta-lactalbumin', 'la_allcys.prmtop') # input Amber prmtop file
        #crd_filename    = os.path.join('systems', 'beta-lactalbumin', 'la_allcys.crd') # input Amber coordinate file

        # Alanine dipeptide in implicit solvent test system.
        prmtop_filename = os.path.join('systems', 'alanine-dipeptide-gbsa', 'alanine-dipeptide.prmtop') # input prmtop file
        crd_filename    = os.path.join('systems', 'alanine-dipeptide-gbsa', 'alanine-dipeptide.crd') # input coordinate file

        store_filename  = 'parallel-tempering-output.nc' # output netCDF filename

        Tmin = 270.0 * units.kelvin # minimum temperature
        Tmax = 450.0 * units.kelvin # maximum temperature
        ntemps = 20 # number of replicas

        timestep = 2.0 * units.femtoseconds # timestep for simulation
        nsteps = 2500 # number of timesteps per iteration (exchange attempt)
        #nsteps = 500 # number of timesteps per iteration (exchange attempt)
        niterations = 2000 # number of iterations
        nequiliterations = 0 # number of equilibration iterations at Tmin with timestep/2 timestep, for nsteps*2

        verbose = True # verbose output
        minimize = True # minimize
        equilibrate = False # equilibrate

        # Select platform: one of 'Reference' (CPU-only), 'Cuda' (NVIDIA Cuda), or 'OpenCL' (for OS X 10.6 with OpenCL OpenMM compiled)
        platform = openmm.Platform.getPlatformByName("Cuda")

        # Set up device to bind to. This happens before any Context is created
        # (minimization / equilibration below), so that those also run on this
        # rank's GPU instead of on the platform's default device.
        if verbose: print("Selecting MPI communicator and selecting a GPU device...")
        from mpi4py import MPI # MPI wrapper
        hostname = os.uname()[1]
        ngpus = 2 # number of GPUs per system
        comm = MPI.COMM_WORLD # MPI communicator
        deviceid = comm.rank % ngpus # select a unique GPU for this node assuming block allocation (not round-robin)
        platform.setPropertyDefaultValue('CudaDeviceIndex', '%d' % deviceid) # select Cuda device index
        platform.setPropertyDefaultValue('OpenCLDeviceIndex', '%d' % deviceid) # select OpenCL device index
        if verbose: print("node '%s' deviceid %d / %d, MPI rank %d / %d" % (hostname, deviceid, ngpus, comm.rank, comm.size))

        # Create system.
        if verbose: print("Reading AMBER prmtop...")
        system = amber.readAmberSystem(prmtop_filename, shake="h-bonds", gbmodel="OBC GBSA", nonbondedCutoff=None)
        if verbose: print("Reading AMBER coordinates...")
        coordinates = amber.readAmberCoordinates(crd_filename)
        if verbose: print("prmtop and coordinate files read.\n")

        # Determine whether we will resume or create new simulation.
        resume = os.path.exists(store_filename)
        if resume and verbose: print("Store filename '%s' found, resuming existing run..." % store_filename)

        # Minimize system (if not resuming).
        if minimize and not resume:
            if verbose: print("Minimizing...")
            minimizer = optimize.LBFGSMinimizer(system, verbose=verbose, platform=platform)
            nminiterations = 1
            for iteration in range(nminiterations):
                if verbose: print("iteration %d / %d" % (iteration, nminiterations))
                coordinates = minimizer.minimize(coordinates, constrain=True)
            # Release the minimizer (and presumably any Context it holds) before production.
            del minimizer

        # Equilibrate (if not resuming).
        if equilibrate and not resume:
            if verbose: print("Equilibrating at %.1f K for %.3f ps with %.1f fs timestep..." % (Tmin / units.kelvin, nequiliterations * nsteps * timestep / units.picoseconds, (timestep/2.0)/units.femtoseconds))
            collision_rate = 5.0 / units.picosecond
            integrator = openmm.LangevinIntegrator(Tmin, collision_rate, timestep / 2.0)
            context = openmm.Context(system, integrator, platform)
            context.setPositions(coordinates)
            for iteration in range(nequiliterations):
                integrator.step(nsteps * 2)
                if verbose:
                    state = context.getState(getEnergy=True)
                    print("iteration %8d %12.3f ns %16.3f kcal/mol" % (iteration, state.getTime() / units.nanosecond, state.getPotentialEnergy() / units.kilocalories_per_mole))
            # Carry the equilibrated positions forward so parallel tempering starts from them.
            state = context.getState(getPositions=True)
            coordinates = state.getPositions(asNumpy=True)
            # Free the equilibration Context and integrator before the replicas are created.
            del context, integrator

        # Make sure random number generators have unique seeds. Rank 0 draws one base
        # seed, every rank receives it, and each rank offsets it by its own rank number,
        # which guarantees distinct seeds. The draw stays below 2**31 - 1 so the bound
        # is valid even where numpy's default integer type is 32 bits.
        base_seed = numpy.random.randint(0, 2**31 - 1) if comm.rank == 0 else None
        base_seed = comm.bcast(base_seed, root=0)
        numpy.random.seed(self._rank_seed(base_seed, comm.rank, comm.size))

        # Initialize parallel tempering simulation.
        #import simtk.chem.openmm.extras.repex as repex
        import repexmpi as repex
        if verbose: print("Initializing parallel tempering simulation...")
        simulation = repex.ParallelTempering(system, coordinates, store_filename, Tmin=Tmin, Tmax=Tmax, ntemps=ntemps, comm=comm)
        simulation.verbose = verbose # write debug output
        simulation.platform = platform # use the platform (and device) selected above
        simulation.number_of_equilibration_iterations = nequiliterations
        simulation.number_of_iterations = niterations # number of iterations (exchange attempts)
        simulation.timestep = timestep # timestep
        simulation.nsteps_per_iteration = nsteps # number of timesteps per iteration
        #simulation.replica_mixing_scheme = 'swap-neighbors' # 'swap-neighbors' or 'swap-all'
        simulation.replica_mixing_scheme = 'swap-all' # 'swap-neighbors' or 'swap-all'
        simulation.minimize = False

        # Run or resume simulation.
        if verbose: print("Running...")
        simulation.run()
    
if __name__ == "__main__":
    # Script entry point: importing mpi4py brings up MPI for this process,
    # report this process's rank within COMM_WORLD, then launch the driver.
    print("Initializing MPI...")
    from mpi4py import MPI # MPI wrapper
    world = MPI.COMM_WORLD
    print("Initialized on node %d / %d" % (world.rank, world.size))
    ParallelTemperingDriver().run()
