#!/usr/local/bin/env python -d

#=============================================================================================
# MODULE DOCSTRING
#=============================================================================================

"""
Run halide scan.

DESCRIPTION

This script sets up a 'halide scan' in implicit solvent.

COPYRIGHT

@author John D. Chodera <jchodera@gmail.com>

This source file is released under the GNU General Public License.

This program is free software: you can redistribute it and/or modify it under
the terms of the GNU General Public License as published by the Free Software
Foundation, either version 3 of the License, or (at your option) any later
version.

This program is distributed in the hope that it will be useful, but WITHOUT ANY
WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A
PARTICULAR PURPOSE.  See the GNU General Public License for more details.
 
You should have received a copy of the GNU General Public License along with
this program.  If not, see <http://www.gnu.org/licenses/>.

"""

#=============================================================================================
# GLOBAL IMPORTS
#=============================================================================================

import os
import os.path
import sys
import pdb
import numpy

import simtk.openmm as openmm
import simtk.pyopenmm.amber.amber_file_parser as amber
import simtk.unit as units
import netCDF4 as netcdf

#=============================================================================================
# SOURCE CONTROL
#=============================================================================================

__version__ = "$Id: $"

#=============================================================================================
# DRIVER CLASS AND SCRIPT ENTRY POINT
#=============================================================================================

class HalideScanDriver(object):
    """Set up and run a replica-exchange 'halide scan' in implicit solvent.

    One thermodynamic state is created for every AMBER prmtop/crd pair whose
    names match ``<prefix>*-<phase>.prmtop`` and ``<prefix>*-<phase>.crd``.
    All states share the same temperature and are handed, together with their
    starting coordinates, to ``repex.ReplicaExchange``.
    """

    def __init__(self, prefix='example/complex', phase='ligand'):
        """Store the file-selection options.

        Parameters
        ----------
        prefix : str
            Path prefix of the prmtop/crd files to load.
        phase : str
            Phase label; selects files ending in ``-<phase>.prmtop`` /
            ``-<phase>.crd`` and names the output store ``<phase>.nc``.
        """
        self.prefix = prefix
        self.phase = phase

    @staticmethod
    def _find_input_files(prefix, phase):
        """Return a sorted list of (prmtop_filename, crd_filename) pairs.

        Files are paired by their common stem (the path without extension),
        so a missing partner is reported instead of silently shifting the
        pairing of every file that follows it.

        Raises
        ------
        IOError
            If neither prmtop nor crd files match the pattern.
        ValueError
            If a prmtop file has no crd partner, or vice versa.
        """
        import glob  # local import, following the file's existing style

        prmtop_pattern = '%s*-%s.prmtop' % (prefix, phase)
        crd_pattern = '%s*-%s.crd' % (prefix, phase)
        prmtop_stems = set(os.path.splitext(name)[0] for name in glob.glob(prmtop_pattern))
        crd_stems = set(os.path.splitext(name)[0] for name in glob.glob(crd_pattern))

        if not prmtop_stems and not crd_stems:
            raise IOError("No input files match '%s' or '%s'." % (prmtop_pattern, crd_pattern))

        unpaired = sorted(prmtop_stems.symmetric_difference(crd_stems))
        if unpaired:
            raise ValueError("prmtop/crd files are not paired for: %s" % ', '.join(unpaired))

        # glob returns names in arbitrary order; sort for a reproducible state order.
        return [(stem + '.prmtop', stem + '.crd') for stem in sorted(prmtop_stems)]

    @staticmethod
    def _initialize_mpi(platform, ngpus=2):
        """Try to set up MPI and bind this process to a GPU.

        Parameters
        ----------
        platform : openmm.Platform
            Platform whose default device-index properties are set.
        ngpus : int
            Number of GPUs per system.

        Returns
        -------
        The ``MPI.COMM_WORLD`` communicator, or ``None`` if anything in the
        setup failed (the caller then runs serially).
        """
        try:
            # Set up device to bind to.
            print("Selecting MPI communicator and selecting a GPU device...")
            from mpi4py import MPI  # MPI wrapper
            hostname = os.uname()[1]
            comm = MPI.COMM_WORLD  # MPI communicator
            # Select a unique GPU for this node assuming block allocation (not round-robin).
            deviceid = comm.rank % ngpus
            platform.setPropertyDefaultValue('CudaDeviceIndex', '%d' % deviceid)  # select Cuda device index
            platform.setPropertyDefaultValue('OpenCLDeviceIndex', '%d' % deviceid)  # select OpenCL device index
            print("node '%s' deviceid %d / %d, MPI rank %d / %d" % (hostname, deviceid, ngpus, comm.rank, comm.size))
            # Offset a random base by the rank so ranks get different seeds even
            # if they start from the same generator state.  The base is kept
            # below 2**31 because numpy.random.seed only accepts values that fit
            # in a 32-bit unsigned integer (sys.maxint is 2**63 - 1 on 64-bit builds).
            seed = numpy.random.randint(2**31 - 1 - comm.size) + comm.rank
            numpy.random.seed(seed)
        except Exception as error:
            # Best-effort: fall back to a serial run, but say why.  Unlike a bare
            # ``except:``, this lets KeyboardInterrupt / SystemExit propagate.
            print("WARNING: Could not initialize MPI.  Running serially...")
            print("(reason: %s: %s)" % (type(error).__name__, error))
            return None
        return comm

    def run(self):
        """Build the thermodynamic states and run the replica-exchange simulation.

        Raises
        ------
        IOError
            If no prmtop/crd files match the configured prefix and phase.
        ValueError
            If a prmtop file has no matching crd file, or vice versa.
        """
        prefix = self.prefix  # prefix for files to load
        phase = self.phase
        store_filename = phase + '.nc'

        # Locate the inputs first so that a missing or incomplete file set
        # fails before any system is built or any device is claimed.
        file_pairs = self._find_input_files(prefix, phase)

        # Simulation options.
        temperature = 300.0 * units.kelvin  # temperature
        timestep = 2.0 * units.femtoseconds  # timestep for simulation
        nsteps = 500  # number of timesteps per iteration (exchange attempt)
        niterations = 10000  # number of iterations to complete
        # Number of equilibration iterations.  Original note: "at Tmin with
        # timestep/2 timestep, for nsteps*2" -- TODO confirm against repex.
        nequiliterations = 1
        verbose = True  # verbose output (set to False for less output)
        minimize = True  # minimize

        # Create thermodynamic states.
        from thermodynamics import ThermodynamicState
        coordinates = []  # starting positions, one entry per state
        states = []  # thermodynamic states
        for prmtop_filename, crd_filename in file_pairs:
            print(prmtop_filename + ' ' + crd_filename)

            # Create system: implicit solvent (OBC GB model) without a cutoff.
            if verbose: print("Reading AMBER prmtop...")
            system = amber.readAmberSystem(prmtop_filename, shake="h-bonds", nonbondedMethod='NoCutoff', gbmodel='OBC')
            if verbose: print("System has %d atoms." % system.getNumParticles())

            # No periodic box is read, consistent with the 'NoCutoff' setup above.
            if verbose: print("Reading AMBER coordinates...")
            positions = amber.readAmberCoordinates(crd_filename, read_box=False)
            if verbose: print("prmtop and coordinate files read.\n")
            coordinates.append(positions)

            # Create thermodynamic state.
            states.append(ThermodynamicState(system=system, temperature=temperature))

        # Select platform: one of 'Reference' (CPU-only), 'Cuda' (NVIDIA Cuda),
        # or 'OpenCL' (for OS X 10.6 with OpenCL OpenMM compiled).
        platform = openmm.Platform.getPlatformByName("OpenCL")

        # MPI communicator, or None to run serially.
        comm = self._initialize_mpi(platform)

        # Initialize replica-exchange simulation.
        import repex
        if verbose: print("Initializing parallel tempering simulation...")
        simulation = repex.ReplicaExchange(states, coordinates, store_filename, mpicomm=comm)
        simulation.verbose = True  # write debug output
        simulation.platform = platform  # specify platform
        simulation.number_of_equilibration_iterations = nequiliterations
        simulation.number_of_iterations = niterations  # number of iterations (exchange attempts)
        simulation.timestep = timestep  # timestep
        simulation.nsteps_per_iteration = nsteps  # number of timesteps per iteration
        # 'swap-all' mixing; the alternative is traditional 'swap-neighbors' exchange.
        simulation.replica_mixing_scheme = 'swap-all'
        simulation.minimize = minimize

        # Run the simulation (presumably repex resumes from store_filename if it exists).
        if verbose: print("Running...")
        simulation.run()
    
if __name__ == "__main__":
    # Script entry point: build a driver with its default options and run it.
    HalideScanDriver().run()
