#!/usr/local/bin/env python -d

#=============================================================================================
# MODULE DOCSTRING
#=============================================================================================

"""
Parallel tempering driver.

DESCRIPTION

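This script directs parallel tempering simulations of AMBER protein systems in implicit solvent.
The first command-line argument is appended to 'fs21' to name the output netCDF store
(fs21<argument>.nc). If that file already exists, the existing run is resumed.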

COPYRIGHT

@author John D. Chodera <jchodera@gmail.com>

This source file is released under the GNU General Public License.

This program is free software: you can redistribute it and/or modify it under
the terms of the GNU General Public License as published by the Free Software
Foundation, either version 3 of the License, or (at your option) any later
version.

This program is distributed in the hope that it will be useful, but WITHOUT ANY
WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A
PARTICULAR PURPOSE.  See the GNU General Public License for more details.
 
You should have received a copy of the GNU General Public License along with
this program.  If not, see <http://www.gnu.org/licenses/>.

"""

#=============================================================================================
# GLOBAL IMPORTS
#=============================================================================================

import os
import os.path
import sys
import pdb
import numpy

#import simtk.unit as units
import simtk.openmm as openmm
import simtk.pyopenmm.amber.amber_file_parser as amber
import simtk.unit as units
import netCDF4 as netcdf

#=============================================================================================
# SOURCE CONTROL
#=============================================================================================

__version__ = "$Id: $"

#=============================================================================================
# MAIN AND TESTS
#=============================================================================================

class ParallelTemperingDriver(object):
    """
    Driver for an MPI parallel tempering (temperature replica-exchange) simulation
    of an AMBER system in OBC GBSA implicit solvent.

    Input files, temperatures and run lengths are hard-coded in run(). The output
    netCDF store is named 'fs21<suffix>.nc', where <suffix> is the first
    command-line argument; if that file already exists the run is resumed.
    """

    def __init__(self):
        pass

    @staticmethod
    def _store_filename(argv):
        """
        Return the output netCDF filename 'fs21<argv[1]>.nc'.

        Parameters
        ----------
        argv : list of str
            Command-line argument vector (normally sys.argv).

        Raises
        ------
        ValueError
            If argv does not contain a suffix argument (argv[1]).
        """
        if len(argv) < 2:
            program = argv[0] if argv else 'script'
            raise ValueError("usage: %s <store-suffix>  (output is written to fs21<store-suffix>.nc)" % program)
        return 'fs21' + argv[1] + '.nc'

    @staticmethod
    def _draw_base_seed(nranks):
        """
        Draw a random base seed such that base + rank is a valid numpy seed for
        every rank in [0, nranks).

        numpy.random.seed only accepts values in [0, 2**32 - 1]. The draw is kept
        below 2**31 so that numpy.random.randint also works where a C long is 32 bits.
        """
        return int(numpy.random.randint(0, 2**31 - nranks))

    def run(self):
        """
        Prepare the system and run (or resume) the parallel tempering simulation.

        Every MPI rank executes this method. In order, it:
          1. binds this rank to a GPU device, before any OpenMM Context exists;
          2. reads the AMBER prmtop/crd files into an OpenMM System and coordinates;
          3. for a new run, minimizes and then equilibrates at Tmin;
          4. gives each rank a distinct numpy RNG seed;
          5. builds a repexmpi.ParallelTempering simulation and runs it.

        The output store is 'fs21<sys.argv[1]>.nc'. If rank 0 finds that it already
        exists, all ranks resume the existing run and step 3 is skipped.

        Raises
        ------
        ValueError
            If no store-filename suffix was given on the command line.
        """
        from mpi4py import MPI # MPI wrapper
        # Imported up front so a missing module fails before the expensive setup work.
        import repexmpi as repex

        # Determine prmtop and crd filenames.
        # TODO: Make these command-line arguments?
        # Lactalbumin test system:
        #prmtop_filename = os.path.join(os.getenv('YANK_INSTALL_DIR'), 'test', 'systems', 'lactalbumin', 'la_allcys.prmtop') # input prmtop file
        #crd_filename    = os.path.join(os.getenv('YANK_INSTALL_DIR'), 'test', 'systems', 'lactalbumin', 'la_allcys.crd') # input coordinate file
        # Smaller test system (alanine dipeptide in implicit solvent):
        #prmtop_filename = os.path.join(os.getenv('YANK_INSTALL_DIR'), 'test', 'systems', 'alanine-dipeptide-gbsa', 'alanine-dipeptide.prmtop') # input prmtop file
        #crd_filename    = os.path.join(os.getenv('YANK_INSTALL_DIR'), 'test', 'systems', 'alanine-dipeptide-gbsa', 'alanine-dipeptide.crd') # input coordinate file
        prmtop_filename = os.path.join(os.getenv('HOME'), 'testjohn', 'fs21.prmtop') # input prmtop file
        crd_filename    = os.path.join(os.getenv('HOME'), 'testjohn', 'fs21.crd') # input coordinate file

        store_filename = self._store_filename(sys.argv) # output netCDF filename

        Tmin = 270.0 * units.kelvin # minimum temperature
        Tmax = 650.0 * units.kelvin # maximum temperature
        ntemps = 20 # number of replicas

        timestep = 2.0 * units.femtoseconds # timestep for simulation
        nsteps = 2500 # number of timesteps per iteration (exchange attempt); values of 500-10000 have also been used
        niterations = 10000 # number of iterations
        nequiliterations = 50 # number of equilibration iterations at Tmin, each 2*nsteps steps at timestep/2
        minimize_tolerance = 1.0 * units.kilojoules / units.nanometers**2
        minimize_maximum_evaluations = 10000 # max number of minimization evaluations
        verbose = True # verbose output
        minimize = True # minimize
        equilibrate = True # equilibrate

        # Select platform: one of 'Reference' (CPU-only), 'Cuda' (NVIDIA Cuda), or 'OpenCL' (for OS X 10.6 with OpenCL OpenMM compiled)
        platform = openmm.Platform.getPlatformByName("Cuda")

        # Bind this rank to a GPU *before* any Context is created. Otherwise the
        # minimization/equilibration Context below would run on whatever device the
        # platform picks by default, which is the same for every rank on a node.
        print("Selecting MPI communicator and selecting a GPU device...")
        hostname = os.uname()[1]
        ngpus = 2 # number of GPUs per node
        comm = MPI.COMM_WORLD # MPI communicator
        deviceid = comm.rank % ngpus # select a unique GPU for this node assuming block allocation (not round-robin)
        platform.setPropertyDefaultValue('CudaDeviceIndex', '%d' % deviceid) # select Cuda device index
        platform.setPropertyDefaultValue('OpenCLDeviceIndex', '%d' % deviceid) # select OpenCL device index
        print("node '%s' deviceid %d / %d, MPI rank %d / %d" % (hostname, deviceid, ngpus, comm.rank, comm.size))

        # Create system.
        if verbose: print("Reading AMBER prmtop...")
        system = amber.readAmberSystem(prmtop_filename, shake="h-bonds", gbmodel="OBC GBSA", nonbondedCutoff=None)
        if verbose: print("Reading AMBER coordinates...")
        coordinates = amber.readAmberCoordinates(crd_filename)
        if verbose: print("prmtop and coordinate files read.\n")

        # Decide whether to resume or create a new simulation. Rank 0's view is
        # broadcast so every rank takes the same branch.
        resume = comm.bcast(os.path.exists(store_filename), root=0)
        if resume and verbose: print("Store filename '%s' found, resuming existing run..." % store_filename)

        if not resume:
            # NOTE(review): every rank minimizes/equilibrates independently and passes
            # its own coordinates to ParallelTempering; confirm repexmpi only uses
            # the root rank's coordinates (or that per-rank starts are intended).
            collision_rate = 5.0 / units.picosecond
            integrator = openmm.LangevinIntegrator(Tmin, collision_rate, timestep / 2.0)
            context = openmm.Context(system, integrator, platform)
            context.setPositions(coordinates)

            # Minimize system.
            if minimize:
                if verbose: print("Minimizing...")
                openmm.LocalEnergyMinimizer.minimize(context, minimize_tolerance, minimize_maximum_evaluations)
                openmm_state = context.getState(getPositions=True)
                coordinates = openmm_state.getPositions(asNumpy=True)

            # Equilibrate at Tmin: 2*nsteps steps at timestep/2 per iteration, i.e. the
            # same simulated time per iteration as a production iteration.
            if equilibrate:
                if verbose: print("Equilibrating at %.1f K for %.3f ps with %.1f fs timestep..." % (Tmin / units.kelvin, nequiliterations * nsteps * timestep / units.picoseconds, (timestep/2.0)/units.femtoseconds))

                context.setPositions(coordinates)
                for iteration in range(nequiliterations):
                    integrator.step(nsteps * 2)
                    state = context.getState(getEnergy=True)
                    if verbose: print("iteration %8d %12.3f ns %16.3f kcal/mol" % (iteration, state.getTime() / units.nanosecond, state.getPotentialEnergy() / units.kilocalories_per_mole))

                openmm_state = context.getState(getPositions=True)
                coordinates = openmm_state.getPositions(asNumpy=True)

            # Release the preparation Context/integrator before the replicas are built.
            del context
            del integrator

        # Make sure random number generators have unique seeds. Rank 0 draws a single
        # base seed and shares it, so base + rank is distinct on every rank and stays
        # within the range numpy.random.seed accepts.
        base_seed = comm.bcast(self._draw_base_seed(comm.size) if comm.rank == 0 else None, root=0)
        numpy.random.seed(base_seed + comm.rank)

        # Initialize parallel tempering simulation.
        if verbose: print("Initializing parallel tempering simulation...")
        simulation = repex.ParallelTempering(system, coordinates, store_filename, Tmin=Tmin, Tmax=Tmax, ntemps=ntemps, comm=comm)
        simulation.verbose = verbose # write debug output
        simulation.platform = platform # use Cuda platform
        simulation.number_of_equilibration_iterations = nequiliterations
        simulation.number_of_iterations = niterations # number of iterations (exchange attempts)
        simulation.timestep = timestep # timestep
        simulation.nsteps_per_iteration = nsteps # number of timesteps per iteration
        simulation.replica_mixing_scheme = 'swap-all' # 'swap-neighbors' or 'swap-all'
        # Minimization (if any) was already done above, before replicas were created.
        simulation.minimize = False

        # Run or resume simulation.
        if verbose: print("Running...")
        simulation.run()
    
if __name__ == "__main__":
    # Script entry point: report MPI start-up for this process, then hand control
    # to the driver. mpi4py initializes MPI when mpi4py.MPI is first imported.
    print("Initializing MPI...")
    from mpi4py import MPI # MPI wrapper
    world = MPI.COMM_WORLD
    print("Initialized on node %d / %d" % (world.rank, world.size))
    ParallelTemperingDriver().run()
