#!/usr/bin/env python

#=============================================================================================
# MODULE DOCSTRING
#=============================================================================================

"""
Parallel tempering driver.

DESCRIPTION

This script directs parallel tempering simulations of AMBER protein systems in explicit solvent.
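
EXAMPLES

Run serially (hypothetical invocation; substitute this file's actual name):

  python parallel_tempering_driver.py

Run under MPI, one process per GPU listed in the cpuid:gpuid mapping below
(the launcher name and process count depend on your MPI stack and cluster):

  mpirun -np 6 python parallel_tempering_driver.py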

COPYRIGHT

@author John D. Chodera <jchodera@gmail.com>

This source file is released under the GNU General Public License.

This program is free software: you can redistribute it and/or modify it under
the terms of the GNU General Public License as published by the Free Software
Foundation, either version 3 of the License, or (at your option) any later
version.

This program is distributed in the hope that it will be useful, but WITHOUT ANY
WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A
PARTICULAR PURPOSE.  See the GNU General Public License for more details.
 
You should have received a copy of the GNU General Public License along with
this program.  If not, see <http://www.gnu.org/licenses/>.

"""

#=============================================================================================
# GLOBAL IMPORTS
#=============================================================================================

import os
import os.path
import sys
import numpy

import simtk.openmm as openmm
import simtk.pyopenmm.amber.amber_file_parser as amber
import simtk.unit as units
import netCDF4 as netcdf

#=============================================================================================
# SOURCE CONTROL
#=============================================================================================

__version__ = "$Id: $"

#=============================================================================================
# MAIN AND TESTS
#=============================================================================================

class ParallelTemperingDriver(object):
    def __init__(self):
        pass

    def run(self):        
        # Determine prmtop and crd filenames in test directory.
        prmtop_filename = 'hp35.prmtop' # input AMBER prmtop file
        crd_filename    = 'hp35.inpcrd' # input AMBER coordinate file
        store_filename = 'repex.nc' # output NetCDF filename
        
        # Parallel tempering options.
        Tmin = 300.0 * units.kelvin # minimum temperature
        Tmax = 400.0 * units.kelvin # maximum temperature
        ntemps = 24 # number of replicas
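        # NOTE: repex.ParallelTempering chooses the ladder of intermediate temperatures
        # between Tmin and Tmax; geometric spacing is typical for parallel tempering,
        # so that exchange acceptance rates between neighbors are roughly uniform.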

        timestep = 2.0 * units.femtoseconds # timestep for simulation
        nsteps = 500 # number of timesteps per iteration (exchange attempt)
        niterations = 10000 # number of iterations to complete
        nequiliterations = 1 # number of equilibration iterations at Tmin (run with timestep/2 for nsteps*2 steps)
        verbose = True # verbose output (set to False for less output)
        minimize = False # minimize configurations before production

        # Specify which CPUs should be attached to specific GPUs for maximum performance.
        cpu_platform_name = 'Reference'
        gpu_platform_name = 'OpenCL'
        cpuid_gpuid_mapping = { 0:0, 1:1, 8:2, 9:3, 10:4, 11:5 } # cpuid:gpuid for NCSA Forge 
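        # NOTE: MPI ranks absent from this mapping fall back to the (slow) CPU
        # platform below and do not propagate replicas; edit the mapping to match
        # your node layout (e.g. a single 4-GPU node might use { 0:0, 1:1, 2:2, 3:3 }).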

        # Initialize MPI, if available.    
        try:
            # Initialize MPI. 
            # Set up device to bind to.
            from mpi4py import MPI # MPI wrapper
            hostname = os.uname()[1]

            # Turn off output from non-root nodes:
            if not (MPI.COMM_WORLD.rank==0):
                verbose = False

            # Make sure random number generators have unique seeds.
            # numpy.random.seed() requires a seed below 2**32, so avoid sys.maxint here.
            seed = numpy.random.randint(2**31 - MPI.COMM_WORLD.size) + MPI.COMM_WORLD.rank
            numpy.random.seed(seed)

            # Choose appropriate platform for each device.
            cpuid = MPI.COMM_WORLD.rank # use default rank as CPUID (TODO: Improve this)
            #print "node '%s' MPI_WORLD rank %d/%d" % (hostname, MPI.COMM_WORLD.rank, MPI.COMM_WORLD.size)
            if cpuid in cpuid_gpuid_mapping:
                platform = openmm.Platform.getPlatformByName(gpu_platform_name)
                deviceid = cpuid_gpuid_mapping[cpuid]
                platform.setPropertyDefaultValue('OpenCLDeviceIndex', '%d' % deviceid) # select OpenCL device index
                platform.setPropertyDefaultValue('CudaDeviceIndex', '%d' % deviceid) # select Cuda device index
                print "node '%s' MPI_WORLD rank %d/%d cpuid %d platform %s deviceid %d" % (hostname, MPI.COMM_WORLD.rank, MPI.COMM_WORLD.size, cpuid, gpu_platform_name, deviceid)
            else:
                platform = openmm.Platform.getPlatformByName(cpu_platform_name)

            # Set up CPU and GPU communicators.
            gpu_process_list = [rank for rank in cpuid_gpuid_mapping.keys() if rank < MPI.COMM_WORLD.size]
            if cpuid in gpu_process_list:
                color = 0 # GPU
            else:
                color = 1 # CPU    
            comm = MPI.COMM_WORLD.Split(color=color)
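            # MPI.Comm.Split places ranks with the same color into a common
            # communicator: GPU ranks (color 0) form one group and CPU ranks
            # (color 1) another, so replica exchange spans only the GPU ranks.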

            if color == 0:
                print "node '%s' MPI_WORLD rank %d/%d comm rank %d/%d : GPU process, will run replicas" % (hostname, MPI.COMM_WORLD.rank, MPI.COMM_WORLD.size, comm.rank, comm.size)
            else:
                print "node '%s' MPI_WORLD rank %d/%d comm rank %d/%d : no GPU, so will not participate" % (hostname, MPI.COMM_WORLD.rank, MPI.COMM_WORLD.size, comm.rank, comm.size)

        except Exception as e:
            print e
            print "WARNING: Could not initialize MPI; falling back to serial execution."
            platform = openmm.Platform.getPlatformByName(gpu_platform_name)
            comm = None
            verbose = True

        # Create system.
        if verbose: print "Reading AMBER prmtop..."
        system = amber.readAmberSystem(prmtop_filename, shake="h-bonds", nonbondedMethod='CutoffPeriodic', nonbondedCutoff=9.0*units.angstroms)
        #system = amber.readAmberSystem(prmtop_filename, shake="h-bonds", nonbondedMethod='PME')
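        # NOTE: constraining bonds to hydrogen (shake="h-bonds") is what permits
        # the 2 fs timestep chosen above.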
        if verbose: print "System has %d atoms." % system.getNumParticles()
        if verbose: print "Reading AMBER coordinates..."
        [coordinates, box_vectors] = amber.readAmberCoordinates(crd_filename, read_box=True)
        system.setDefaultPeriodicBoxVectors(*box_vectors)
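        # Installing the box vectors makes the system explicitly periodic,
        # consistent with the CutoffPeriodic nonbonded method above.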
        if verbose: print "prmtop and coordinate files read.\n"

        # Initialize parallel tempering simulation.
        import repex
        if verbose: print "Initializing parallel tempering simulation..."
        simulation = repex.ParallelTempering(system, coordinates, store_filename, Tmin=Tmin, Tmax=Tmax, ntemps=ntemps, mpicomm=comm)
        simulation.verbose = verbose # write debug output (suppressed on non-root MPI ranks)
        simulation.platform = platform # specify platform
        simulation.number_of_equilibration_iterations = nequiliterations
        simulation.number_of_iterations = niterations # number of iterations (exchange attempts)
        simulation.timestep = timestep # timestep
        simulation.nsteps_per_iteration = nsteps # number of timesteps per iteration
        #simulation.replica_mixing_scheme = 'swap-neighbors' # traditional neighbor-swap exchange
        simulation.replica_mixing_scheme = 'swap-all' # better mixing scheme for exchange step
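        # 'swap-all' attempts swaps among all pairs of replicas rather than only
        # neighboring temperatures, which typically mixes the temperature ladder faster.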
        simulation.minimize = minimize

        # Run or resume simulation.
        if verbose: print "Running..."
        if comm:
            # Only GPU nodes run simulation.
            if cpuid in gpu_process_list:   
                simulation.run()
            # Wait for all nodes to finish
            MPI.COMM_WORLD.Barrier()            
        else:
            simulation.run() # run the simulation serially
    
if __name__ == "__main__":
    driver = ParallelTemperingDriver()
    driver.run()
