#!/usr/bin/env python

#=============================================================================================
# MODULE DOCSTRING
#=============================================================================================

"""
Replica-exchange transition path sampling (TPS) simulation of a Kob-Andersen system.

DESCRIPTION


REFERENCES


COPYRIGHT

@author John D. Chodera <jchodera@gmail.com>

This source file is released under the GNU General Public License.

This program is free software: you can redistribute it and/or modify it under
the terms of the GNU General Public License as published by the Free Software
Foundation, either version 3 of the License, or (at your option) any later
version.

This program is distributed in the hope that it will be useful, but WITHOUT ANY
WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A
PARTICULAR PURPOSE.  See the GNU General Public License for more details.
 
You should have received a copy of the GNU General Public License along with
this program.  If not, see <http://www.gnu.org/licenses/>.

"""

#=============================================================================================
# GLOBAL IMPORTS
#=============================================================================================

import os
import os.path
import sys
import math
import copy
import time

import numpy

import simtk
import simtk.chem.openmm as openmm
import simtk.unit as units

#import Scientific.IO.NetCDF as netcdf
import netCDF4 as netcdf

#=============================================================================================
# MAIN AND TESTS
#=============================================================================================

def _report_times(header, timestep, delta_t, tau, t_obs):
    """Print the simulation timescales in physical units and in numbers of timesteps.

    Parameters
    ----------
    header : str
        Title line printed before the table.
    timestep, delta_t, tau, t_obs : simtk.unit.Quantity
        Times (with time units); each is shown via in_units_of(), and delta_t,
        tau and t_obs are also shown divided by timestep.
    """
    print(header)
    print("timestep = %s" % str(timestep.in_units_of(units.femtosecond)))
    for (label, value) in (("delta_t", delta_t), ("tau", tau), ("t_obs", t_obs)):
        print("%s = %s (%f steps)" % (label, str(value.in_units_of(units.picosecond)), value / timestep))


def _run_langevin(system, platform, coordinates, temperature, collision_rate, timestep, nsteps):
    """Run `nsteps` of Langevin dynamics starting from `coordinates` and return the final positions.

    A fresh LangevinIntegrator and Context are created for this run and deleted
    before returning, so no Context outlives the call (the same cleanup main()
    performs for the minimizer).

    Returns
    -------
    positions
        Final positions, as returned by State.getPositions(asNumpy=True).
    """
    integrator = openmm.LangevinIntegrator(temperature, collision_rate, timestep)
    context = openmm.Context(system, integrator, platform)
    context.setPositions(coordinates)
    integrator.step(nsteps)
    positions = context.getState(getPositions=True).getPositions(asNumpy=True)
    # Release the Context first, then the Integrator it was bound to.
    del context
    del integrator
    return positions


def main(mode='test'):
    """Set up, minimize, equilibrate and quench a Kob-Andersen system, then run replica-exchange TPS.

    Parameters
    ----------
    mode : str, optional, default='test'
        'test' quenches to Tred = 0.6 and uses an observation time of 1 * tau;
        'production' quenches to Tred = 0.7 and uses 33 * tau.

    Raises
    ------
    ValueError
        If `mode` is neither 'test' nor 'production'.

    Notes
    -----
    NOTE(review): KobAndersen, NA, epsilon, sigma, TransitionPathSampling and
    ReplicaExchangeTPS are neither defined nor imported in this file; they must
    be provided (e.g. imported) before this script can run.
    """
    # Get platform.
    platform = openmm.Platform.getPlatformByName("OpenCL")

    # Molar Boltzmann constant (k_B * N_A); used below to turn the energy epsilon into a temperature.
    kB = units.BOLTZMANN_CONSTANT_kB * units.AVOGADRO_CONSTANT_NA

    # Create the Kob-Andersen system.
    # NOTE(review): NA is presumably the number of type-A particles -- confirm against KobAndersen.
    N = 2  # number of particles
    mass = 12.0 * units.amu  # arbitrary reference mass
    [system, coordinates] = KobAndersen(N=N, NA=NA, mass=mass, epsilon=epsilon, sigma=sigma)

    # Relevant times and timescales
    reduced_time = units.sqrt((mass * sigma**2) / (48.0 * epsilon))  # factor on which all reduced times are based
    timestep = 0.035 * reduced_time  # velocity Verlet timestep from Ref. [1], Supplementary Section 1.1
    delta_t = (40.0 / 3.0) * reduced_time  # Delta t in Lester's paper, Ref. [1] Supplementary Section 1.1, specified exactly by Lester in private communication

    if mode == 'test':
        quenched_temperature = 0.6 * epsilon / kB  # temperature corresponding to left column of Fig. 2 from [1]
        tau = 200.0 * reduced_time  # private communication from Lester Hedges (tau = 60 * reduced time for Tred = 0.7, 200 * reduced_time for Tred = 0.6)
        t_obs = 1 * tau  # from Fig. 2 of [1] (red dashed line, right column); should be 33 * tau for Tred = 0.7, 20 * tau for Tred = 0.6
    elif mode == 'production':
        quenched_temperature = 0.7 * epsilon / kB  # temperature corresponding to left column of Fig. 2 from [1]
        tau = 60.0 * reduced_time  # private communication from Lester Hedges (tau = 60 * reduced time for Tred = 0.7, 200 * reduced_time for Tred = 0.6)
        t_obs = 33 * tau  # from Fig. 2 of [1] (red dashed line, right column); should be 33 * tau for Tred = 0.7, 20 * tau for Tred = 0.6
    else:
        raise ValueError("Unknown mode: %s" % mode)

    _report_times("UNCORRECTED TIMES", timestep, delta_t, tau, t_obs)

    # Determine integral number of steps per velocity randomization and number of frames.
    nsteps_per_frame = int(round(delta_t / timestep))  # number of steps per trajectory frame and velocity randomization
    nframes = int(round(t_obs / delta_t))  # number of frames per trajectory
    print("number of steps per trajectory frame and velocity randomization = %d" % nsteps_per_frame)
    print("number of frames per trajectory = %d" % nframes)

    # Correct delta_t and t_obs to be integral multiples of the timestep.
    delta_t = nsteps_per_frame * timestep
    t_obs = nframes * delta_t
    _report_times("CORRECTED TIMES", timestep, delta_t, tau, t_obs)

    # Number of steps spanning t_obs. Computed exactly in integers; floor(t_obs / timestep)
    # could drop a step when the quantity division lands just below the integer.
    nsteps_total = nframes * nsteps_per_frame

    # Stages to run (all enabled: initialize rather than resume the simulation).
    minimize = True
    equilibrate = True
    quench = True
    seed = True

    if minimize:
        print("Minimizing with L-BFGS...")
        # Initialize a minimizer with default options (project-local module).
        import optimize
        minimizer = optimize.LBFGSMinimizer(system, verbose=True, platform=platform)
        coordinates = minimizer.minimize(coordinates)
        # Clean up to release the Context.
        del minimizer

    # Langevin collision rate used for both equilibration and quench.
    collision_rate = 1.0 / delta_t

    if equilibrate:
        # Equilibrate at high temperature.
        elevated_temperature = 2.0 * epsilon / kB
        print("Equilibrating at %s for %d steps..." % (str(elevated_temperature), nsteps_total))
        coordinates = _run_langevin(system, platform, coordinates, elevated_temperature,
                                    collision_rate, timestep, nsteps_total)

    if quench:
        # Quench to final temperature.
        print("Quenching to %s for %d steps..." % (str(quenched_temperature), nsteps_total))
        coordinates = _run_langevin(system, platform, coordinates, quenched_temperature,
                                    collision_rate, timestep, nsteps_total)

    # Create transition path sampling ensembles at different values of the field parameter s,
    # given in reduced units of 1 / (sigma**2 * delta_t). Both entries are currently s = 0.
    s_reduced_unit = 1.0 / (sigma**2 * delta_t)
    svalues = [0.0, 0.0]
    ensembles = [TransitionPathSampling(system, N, NA, mass, epsilon, sigma, timestep, nsteps_per_frame,
                                        nframes, quenched_temperature, s * s_reduced_unit)
                 for s in svalues]

    trajectory = None
    if seed:
        # Generate an initial trajectory at zero field.
        print("Generating seed trajectory for TPS...")
        trajectory = ensembles[0].generateTrajectory(coordinates, nframes)

    # Initialize replica-exchange TPS simulation.
    print("Initializing replica-exchange TPS...")
    ncfilename = 'szero.nc'
    simulation = ReplicaExchangeTPS(system, ensembles, trajectory, ncfilename)

    # Run simulation.
    print("Running replica-exchange TPS...")
    simulation.run()

    print("Done.")


if __name__ == "__main__":
    main()
