#!/usr/local/bin/env python

#=============================================================================================
# MODULE DOCSTRING
#=============================================================================================

"""
Simulation of a WCA dimer in dense WCA solvent using GHMC with instantaneous MC moves.

DESCRIPTION

COPYRIGHT

@author John D. Chodera <jchodera@gmail.com>

All code in this repository is released under the GNU General Public License.

This program is free software: you can redistribute it and/or modify it under
the terms of the GNU General Public License as published by the Free Software
Foundation, either version 3 of the License, or (at your option) any later
version.

This program is distributed in the hope that it will be useful, but WITHOUT ANY
WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A
PARTICULAR PURPOSE.  See the GNU General Public License for more details.
 
You should have received a copy of the GNU General Public License along with
this program.  If not, see <http://www.gnu.org/licenses/>.

TODO

"""

#=============================================================================================
# GLOBAL IMPORTS
#=============================================================================================

import os
import os.path
import sys
import math
import copy
import time

import numpy

import simtk
import simtk.unit as units
import simtk.openmm as openmm
    
#import Scientific.IO.NetCDF as netcdf # for netcdf interface in Scientific
import netCDF4 as netcdf # for netcdf interface provided by netCDF4 in enthought python

import wcadimer # WCA dimer system definition
import sampling # sampling utility methods

#=============================================================================================
# SUBROUTINES
#=============================================================================================

def norm(n01):
    """
    Return the Euclidean length of a unit-bearing vector.

    ARGUMENTS

    n01 - vector exposing a `.unit` attribute (in this script, a units.Quantity slice of coordinates)

    RETURNS

    length - scalar length, expressed in the same unit as n01

    """
    unit = n01.unit
    bare = n01 / unit # strip the unit so numpy works on plain numbers
    return unit * numpy.sqrt(numpy.dot(bare, bare))

#=============================================================================================
# MAIN
#=============================================================================================

def _create_netcdf_file(filename, nparticles):
    """
    Create a fresh NetCDF file holding one record per MC iteration.

    ARGUMENTS

    filename - path of the NetCDF file to create; a missing parent directory is created
    nparticles - number of particles in the system (length of the 'nparticles' dimension)

    RETURNS

    ncfile - netCDF4 Dataset opened for writing, already synced to disk

    """
    dirname = os.path.dirname(filename)
    if dirname and not os.path.exists(dirname):
        os.makedirs(dirname)

    ncfile = netcdf.Dataset(filename, 'w') # for netCDF4

    # Create dimensions.
    ncfile.createDimension('iteration', 0) # unlimited number of iterations
    ncfile.createDimension('nparticles', nparticles) # number of particles
    ncfile.createDimension('ndim', 3) # number of dimensions

    # Create variables (one record per iteration).
    for name in ['distance', 'work', 'log_Paccept', 'fraction_accepted', 'initial_distance', 'final_distance']:
        ncfile.createVariable(name, 'd', ('iteration',))
    ncfile.createVariable('accept', 'i', ('iteration',))
    ncfile.createVariable('positions', 'd', ('iteration','nparticles','ndim'))

    # Force sync to disk to avoid data loss.
    ncfile.sync()

    return ncfile

def _choose_final_distance(initial_distance, r0):
    """
    Choose the target dimer separation for the instantaneous MC move.

    The move is a deterministic involution:
      - separations in (0.5*r0, 1.5*r0) are stretched by r0;
      - separations in (1.5*r0, 2.5*r0) are compressed by r0;
      - all other separations are left unchanged (a trivially accepted null move).
    Applying the move twice returns the starting separation. The Metropolis
    criterion used below (with its (final/initial)**2 factor) relies on this.

    ARGUMENTS

    initial_distance - current dimer separation (unit-bearing length)
    r0 - move size (unit-bearing length)

    RETURNS

    final_distance - proposed dimer separation (unit-bearing length)

    """
    if 0.5*r0 < initial_distance < 1.5*r0:
        return initial_distance + r0
    if 1.5*r0 < initial_distance < 2.5*r0:
        return initial_distance - r0
    return initial_distance

if __name__ == "__main__":
    # PARAMETERS
    netcdf_filename = 'data/mc-solvent.nc'

    # WCA dimer parameters used by this script.
    r0 = wcadimer.r0 # size of the instantaneous bond-length move
    w = wcadimer.w # (r0 + w) is reported as the barrier location

    # Compute characteristic timescale.
    tau = wcadimer.tau
    print("tau = %.3f ps" % (tau / units.picoseconds))

    # Compute timesteps.
    equilibrate_timestep = 2 * wcadimer.stable_timestep
    switching_timestep = 5 * wcadimer.stable_timestep
    print("equilibrate timestep = %.1f fs" % (equilibrate_timestep / units.femtoseconds))

    # Set temperature and collision rate for the GHMC thermostat.
    temperature = wcadimer.temperature
    print("temperature = %.1f K" % (temperature / units.kelvin))
    kT = wcadimer.kT
    collision_rate = 1.0 / tau

    niterations = 10000 # number of work samples to collect

    # Create system.
    [system, coordinates] = wcadimer.WCADimer()

    # Form vectors of masses and sqrt(kT/m) for force propagation and velocity randomization.
    print("Creating masses array...")
    nparticles = system.getNumParticles()
    masses = units.Quantity(numpy.zeros([nparticles,3], numpy.float64), units.amu)
    sqrt_kT_over_m = units.Quantity(numpy.zeros([nparticles,3], numpy.float64), units.nanometers / units.picosecond)
    for particle_index in range(nparticles):
        masses[particle_index,:] = system.getParticleMass(particle_index)
        # standard deviation of velocity distribution for each coordinate for this atom
        sqrt_kT_over_m[particle_index,:] = units.sqrt(kT / masses[particle_index,0])

    # List all available platforms
    print("Available platforms:")
    for platform_index in range(openmm.Platform.getNumPlatforms()):
        platform = openmm.Platform.getPlatform(platform_index)
        print("%5d %s" % (platform_index, platform.getName()))
    print("")

    # Select platform.
    platform = openmm.Platform.getPlatformByName("OpenCL")
    deviceid = 4
    platform.setPropertyDefaultValue('OpenCLDeviceIndex', '%d' % deviceid)
    platform.setPropertyDefaultValue('CudaDeviceIndex', '%d' % deviceid)

    # Open an existing NetCDF file for appending, or create a new one.
    if os.path.exists(netcdf_filename):
        ncfile = netcdf.Dataset(netcdf_filename, 'a') # for netCDF4
        # All variables share the unlimited 'iteration' dimension, and
        # 'fraction_accepted' is written at the start of every iteration, so this
        # count includes the last, possibly incomplete, iteration.
        nrecorded = ncfile.variables['distance'][:].size
    else:
        ncfile = _create_netcdf_file(netcdf_filename, nparticles)
        nrecorded = 0

    try:
        if nrecorded >= 2:
            # Redo the last (possibly incomplete) iteration. Start from the positions
            # stored by the iteration before it, which must have finished (and synced)
            # before the later iteration began.
            iteration = nrecorded - 1
            coordinates = units.Quantity(ncfile.variables['positions'][iteration-1,:,:], units.angstroms)
        else:
            # New file, or no iteration is guaranteed complete (stored positions may
            # still be fill values): prepare a starting configuration from scratch.
            print("Minimizing energy...")
            coordinates = sampling.minimize(platform, system, coordinates)

            print("Equilibrating...")
            [coordinates, velocities, fraction_accepted] = sampling.equilibrate_ghmc(system, equilibrate_timestep, collision_rate, temperature, masses, sqrt_kT_over_m, coordinates, platform)
            print("%.3f %% accepted" % (fraction_accepted * 100.0))
            iteration = 0

        # The context is only passed to sampling.compute_potential (it is never
        # stepped), so a single context can be reused for every iteration.
        print("Creating new context...")
        integrator = openmm.VerletIntegrator(switching_timestep)
        context = openmm.Context(system, integrator, platform)

        while iteration < niterations:
            print("iteration %5d / %5d" % (iteration, niterations))
            initial_time = time.time()

            # Generate a new configuration with GHMC dynamics.
            initial_distance = norm(coordinates[1,:] - coordinates[0,:])
            [coordinates, velocities, fraction_accepted] = sampling.equilibrate_ghmc(system, equilibrate_timestep, collision_rate, temperature, masses, sqrt_kT_over_m, coordinates, platform)
            print("%.3f %% accepted" % (fraction_accepted * 100.0))
            final_distance = norm(coordinates[1,:] - coordinates[0,:])
            print("Dynamics %.1f A -> %.1f A (barrier at %.1f A)" % (initial_distance / units.angstroms, final_distance / units.angstroms, (r0+w)/units.angstroms))
            ncfile.variables['fraction_accepted'][iteration] = fraction_accepted

            # Choose the instantaneous bond-length move.
            initial_distance = norm(coordinates[1,:] - coordinates[0,:])
            print("initial distance = %s" % str(initial_distance))
            final_distance = _choose_final_distance(initial_distance, r0)
            print("Proposing %.1f A -> %.1f A (barrier at %.1f A)" % (initial_distance / units.angstroms, final_distance / units.angstroms, (r0+w)/units.angstroms))

            # Record attempt
            ncfile.variables['initial_distance'][iteration] = initial_distance / units.angstroms
            ncfile.variables['final_distance'][iteration] = final_distance / units.angstroms

            # Potential energy before the move.
            proposed_coordinates = copy.deepcopy(coordinates)
            initial_energy = sampling.compute_potential(context, proposed_coordinates)

            # Rescale the bond to final_distance, keeping its midpoint and direction fixed.
            bond_midpoint = (proposed_coordinates[0,:] + proposed_coordinates[1,:]) / 2.0
            n01 = proposed_coordinates[1,:] - proposed_coordinates[0,:]
            n01 /= norm(n01) # unit vector along the bond
            for k in range(3):
                proposed_coordinates[0,k] = bond_midpoint[k] - float(n01[k]) * final_distance/2
                proposed_coordinates[1,k] = bond_midpoint[k] + float(n01[k]) * final_distance/2

            # For an instantaneous move, the work is just the potential energy change.
            final_energy = sampling.compute_potential(context, proposed_coordinates)
            work = final_energy - initial_energy
            print("work = %8.1f kT" % (work / kT))

            # Record data.
            ncfile.variables['work'][iteration] = work / kT
            ncfile.sync()

            # Accept or reject. The (final/initial)**2 factor accounts for the r**2
            # radial volume element when the bond is rescaled at fixed direction.
            log_Paccept = -work/kT + numpy.log((final_distance/initial_distance)**2)
            ncfile.variables['log_Paccept'][iteration] = log_Paccept
            if (log_Paccept >= 0.0) or (numpy.random.rand() < numpy.exp(log_Paccept)):
                print("Accepted.")
                coordinates = proposed_coordinates
                ncfile.variables['accept'][iteration] = 1
            else:
                print("Rejected.")
                ncfile.variables['accept'][iteration] = 0

            # Record results.
            distance = norm(coordinates[1,:] - coordinates[0,:])
            ncfile.variables['distance'][iteration] = distance / units.angstroms
            ncfile.variables['positions'][iteration,:,:] = coordinates[:,:] / units.angstroms
            ncfile.sync()

            # Debug.
            elapsed_time = time.time() - initial_time
            print("%12.3f s elapsed" % elapsed_time)

            # Increment iteration counter.
            iteration += 1
    finally:
        # Close netcdf file even if the run is interrupted or fails.
        ncfile.close()

