#!/usr/bin/env python

#=============================================================================================
# MODULE DOCSTRING
#=============================================================================================

"""
Simulation of WCA dimer in dense WCA solvent using GHMC with NCMC moves.

DESCRIPTION

COPYRIGHT

@author John D. Chodera <jchodera@gmail.com>

All code in this repository is released under the GNU General Public License.

This program is free software: you can redistribute it and/or modify it under
the terms of the GNU General Public License as published by the Free Software
Foundation, either version 3 of the License, or (at your option) any later
version.

This program is distributed in the hope that it will be useful, but WITHOUT ANY
WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A
PARTICULAR PURPOSE.  See the GNU General Public License for more details.
 
You should have received a copy of the GNU General Public License along with
this program.  If not, see <http://www.gnu.org/licenses/>.

TODO

"""

#=============================================================================================
# GLOBAL IMPORTS
#=============================================================================================

import os
import os.path
import sys
import math
import copy
import time

import numpy

import simtk
import simtk.unit as units
import simtk.openmm as openmm
    
#import Scientific.IO.NetCDF as netcdf # for netcdf interface in Scientific
import netCDF4 as netcdf # for netcdf interface provided by netCDF4 in enthought python
#import scipy.io.netcdf as netcdf # for netcdf interface in scipy

import wcadimer # import WCA dimer system
import sampling # import sampling algorithms

#=============================================================================================
# SUBROUTINES
#=============================================================================================

def norm(n01):
    """Return the Euclidean norm of a vector Quantity, preserving its units."""
    unitless = n01 / n01.unit
    return n01.unit * numpy.linalg.norm(unitless)

#=============================================================================================
# MAIN AND TESTS
#=============================================================================================

if __name__ == "__main__":
    # PARAMETERS
    netcdf_filename = 'data/ncmc-solvent.nc'

    verbose = False
    
    # WCA fluid parameters (argon).
    mass     = wcadimer.mass
    sigma    = wcadimer.sigma
    epsilon  = wcadimer.epsilon
    r_WCA    = wcadimer.r_WCA
    r0       = wcadimer.r0
    h        = wcadimer.h
    w        = wcadimer.w
    
    # Compute characteristic timescale.
    tau = wcadimer.tau
    print "tau = %.3f ps" % (tau / units.picoseconds)

    # Compute timestep.
    equilibrate_timestep = 2 * wcadimer.stable_timestep
    timestep = 5 * wcadimer.stable_timestep
    print "equilibrate timestep = %.1f fs, switch timestep = %.1f fs" % (equilibrate_timestep / units.femtoseconds, timestep / units.femtoseconds)

    # Set temperature, pressure, and collision rate for stochastic thermostats.
    temperature = wcadimer.temperature
    print "temperature = %.1f K" % (temperature / units.kelvin)
    kT = wcadimer.kT
    beta = 1.0 / kT # inverse temperature    
    collision_rate = 1.0 / tau # collision rate for Langevin integrator

    niterations = 10000 # number of work samples to collect

    # Number of steps for switching.
    nsteps = 2048

    # Create system.     
    [system, coordinates] = wcadimer.WCADimer()

    # Form vectors of masses and sqrt(kT/m) for force propagation and velocity randomization.
    print "Creating masses array..."
    nparticles = system.getNumParticles()
    masses = units.Quantity(numpy.zeros([nparticles,3], numpy.float64), units.amu)
    for particle_index in range(nparticles):
        masses[particle_index,:] = system.getParticleMass(particle_index)
    sqrt_kT_over_m = units.Quantity(numpy.zeros([nparticles,3], numpy.float64), units.nanometers / units.picosecond)
    for particle_index in range(nparticles):
        sqrt_kT_over_m[particle_index,:] = units.sqrt(kT / masses[particle_index,0]) # standard deviation of velocity distribution for each coordinate for this atom

    # List all available platforms
    print "Available platforms:"
    for platform_index in range(openmm.Platform.getNumPlatforms()):
        platform = openmm.Platform.getPlatform(platform_index)
        print "%5d %s" % (platform_index, platform.getName())
    print ""

    # Select platform.
    platform = openmm.Platform.getPlatformByName("OpenCL")
    deviceid = 5
    platform.setPropertyDefaultValue('OpenCLDeviceIndex', '%d' % deviceid)
    platform.setPropertyDefaultValue('CudaDeviceIndex', '%d' % deviceid)         

    # Initialize netcdf file.
    if not os.path.exists(netcdf_filename):
        # Open NetCDF file for writing
        ncfile = netcdf.Dataset(netcdf_filename, 'w') # for netCDF4
        #ncfile = netcdf.netcdf_file(netcdf_filename, 'w') # for scipy
        
        # Create dimensions.
        ncfile.createDimension('iteration', 0) # unlimited number of iterations
        ncfile.createDimension('nparticles', nparticles) # number of particles
        ncfile.createDimension('ndim', 3) # number of dimensions    
        ncfile.createDimension('nsteps', nsteps) # number of steps        

        # Create variables.
        ncfile.createVariable('fraction_accepted', 'd', ('iteration',))
        ncfile.createVariable('distance', 'd', ('iteration',))
        ncfile.createVariable('work', 'd', ('iteration',))
        ncfile.createVariable('heat', 'd', ('iteration',))
        ncfile.createVariable('lechner_work', 'd', ('iteration'))
        ncfile.createVariable('work_history', 'd', ('iteration','nsteps'))
        ncfile.createVariable('heat_history', 'd', ('iteration','nsteps'))
        ncfile.createVariable('lechner_work_history', 'd', ('iteration','nsteps'))                                
        ncfile.createVariable('initial_distance', 'd', ('iteration',))
        ncfile.createVariable('final_distance', 'd', ('iteration',))
        ncfile.createVariable('accept', 'i', ('iteration',))                
        ncfile.createVariable('positions', 'd', ('iteration','nparticles','ndim'))        
                
        # Force sync to disk to avoid data loss.
        ncfile.sync()

        # Minimize.
        print "Minimizing energy..."
        coordinates = sampling.minimize(platform, system, coordinates)
    
        # Equilibrate.
        print "Equilibrating..."
        [coordinates, velocities, fraction_accepted] = sampling.equilibrate_ghmc(system, equilibrate_timestep, collision_rate, temperature, masses, sqrt_kT_over_m, coordinates, platform)
        print "%.3f %% accepted." % (fraction_accepted*100.)

        ncfile.sync()        
        iteration = 0
    else:
        # Open NetCDF file for reading.
        ncfile = netcdf.Dataset(netcdf_filename, 'a') # for netCDF4
        #ncfile = netcdf.netcdf_file(netcdf_filename, 'a') # for scipy

        # Read iteration and coordinates.
        iteration = ncfile.variables['distance'][:].size
        iteration -= 1 # Back up an iteration in case restarting due to NaN
        coordinates = units.Quantity(ncfile.variables['positions'][iteration-1,:,:], units.angstroms)

    # Continue
    while (iteration < niterations):
        print "\niteration %5d / %5d" % (iteration, niterations)
        initial_time = time.time()
        
        # Generate a new configuration.
        initial_distance = norm(coordinates[1,:] - coordinates[0,:])
        [coordinates, velocities, fraction_accepted] = sampling.equilibrate_ghmc(system, equilibrate_timestep, collision_rate, temperature, masses, sqrt_kT_over_m, coordinates, platform) 
        print "%.3f %% accepted." % (fraction_accepted*100.)
        final_distance = norm(coordinates[1,:] - coordinates[0,:])            
        print "Dynamics %.1f A -> %.1f A (barrier at %.1f A)" % (initial_distance / units.angstroms, final_distance / units.angstroms, (r0+w)/units.angstroms)
        ncfile.variables['fraction_accepted'][iteration] = fraction_accepted
        
        # Create a context.
        print "Creating new context..."
        integrator = openmm.VerletIntegrator(timestep)            
        context = openmm.Context(system, integrator, platform)

        # Compute initial total energy.
        print "Randomizing velocities..."
        velocities = sqrt_kT_over_m * numpy.random.standard_normal(size=sqrt_kT_over_m.shape)
        total_energy = sampling.compute_energy(context, coordinates, velocities)
        
        # Choose neq move.
        initial_distance = norm(coordinates[1,:] - coordinates[0,:])
        print "initial distance = %s" % str(initial_distance)
        if (initial_distance < 1.5*r0):
            final_distance = initial_distance + r0
        elif (initial_distance > 1.5*r0 and initial_distance < 3.0*r0):
            final_distance = initial_distance - r0
        else:
            final_distance = initial_distance
        print "Proposing %.1f A -> %.1f A (barrier at %.1f A)" % (initial_distance / units.angstroms, final_distance / units.angstroms, (r0+w)/units.angstroms)
            
        # Accumulate work and heat.
        work = 0.0 * units.kilocalories_per_mole
        heat = 0.0 * units.kilocalories_per_mole
        
        # Generate initial coordinates and velocities.
        proposed_coordinates = copy.deepcopy(coordinates)
        proposed_velocities = copy.deepcopy(velocities)
        
        # Allocate storage for work histories.
        work_history = numpy.zeros([nsteps], numpy.float32)
        heat_history = numpy.zeros([nsteps], numpy.float32)
        lechner_work_history = numpy.zeros([nsteps], numpy.float32)
        
        # Compute initial total energy.        
        context.setPositions(coordinates)
        context.setVelocities(velocities)        
        state = context.getState(getForces=False,getEnergy=True)
        total_energy = state.getKineticEnergy() + state.getPotentialEnergy()
        initial_energy = total_energy
        
        # Integrate nonequilibrium switching trajectory.
        for step in range(nsteps):

            #
            # Apply perturbation kernel to dimer, accumulating work contribution.
            #
            
            last_total_energy = total_energy

            distance = (initial_distance + (final_distance - initial_distance)*float(step+1)/float(nsteps))
            bond_midpoint = (proposed_coordinates[0,:] + proposed_coordinates[1,:]) / 2.0
            n01 = (proposed_coordinates[1,:] - proposed_coordinates[0,:]); n01 /= norm(n01);
            for k in range(3):
                proposed_coordinates[0,k] = bond_midpoint[k] - float(n01[k]) * distance/2
                proposed_coordinates[1,k] = bond_midpoint[k] + float(n01[k]) * distance/2

            context.setPositions(proposed_coordinates)
            state = context.getState(getForces=True,getEnergy=True)
            kinetic_energy = state.getKineticEnergy()
            potential_energy = state.getPotentialEnergy()           
            total_energy = kinetic_energy + potential_energy

            work += total_energy - last_total_energy
                
            #
            # Apply propagation kernel (Velocity Verlet) to bath, accumulating heat contribution.
            # 
            
            last_total_energy = kinetic_energy = potential_energy

            # Half-kick velocities.
            forces = state.getForces(asNumpy=True)
            proposed_velocities[2:,:] += 0.5 * forces[2:,:]/masses[2:,:] * timestep 

            # Full-kick positions.
            proposed_coordinates[2:,:] += proposed_velocities[2:,:] * timestep 

            # Update force and potential energy.
            context.setPositions(proposed_coordinates)
            state = context.getState(getForces=True, getEnergy=True)
            forces = state.getForces(asNumpy=True)            
            potential_energy = state.getPotentialEnergy()

            # Half-kick velocities and update kinetic energy.
            proposed_velocities[2:,:] += 0.5 * forces[2:,:]/masses[2:,:] * timestep
            kinetic_energy = 0.5 * (masses*proposed_velocities**2).in_units_of(potential_energy.unit).sum() * potential_energy.unit

            # Accumulate heat contribution.
            total_energy = kinetic_energy + potential_energy
            heat += total_energy - last_total_energy
            
            # Store
            work_history[step] = work / kT
            heat_history[step] = heat / kT
            lechner_work_history[step] = (total_energy - initial_energy) / kT
            
            if (verbose): print "step %5d / %5d : energy = %8.1f kT, accumulated work = %8.1f kT, accumulated heat = %8.1f kT" % (step+1, nsteps, total_energy / kT, work / kT, heat / kT)
            
        # Compute Lechner work.
        final_energy = total_energy
        lechner_work = final_energy - initial_energy
        print "%5d steps : lechner_work = %8.1f kT, work+heat = %8.1f kT, work = %8.1f kT, heat = %8.1f kT" % (nsteps, lechner_work / kT, (work+heat) / kT, work/kT, heat/kT)
        
        # Record data.
        ncfile.variables['initial_distance'][iteration] = initial_distance / units.angstroms
        ncfile.variables['final_distance'][iteration] = final_distance / units.angstroms            
        ncfile.variables['work'][iteration] = work / kT
        ncfile.variables['heat'][iteration] = heat / kT
        ncfile.variables['lechner_work'][iteration] = lechner_work / kT
        ncfile.variables['work_history'][iteration,0:nsteps] = work_history
        ncfile.variables['heat_history'][iteration,0:nsteps] = heat_history
        ncfile.variables['lechner_work_history'][iteration,0:nsteps] = lechner_work_history            
        
        # Accept or reject.
        log_Paccept = -lechner_work/kT + numpy.log((final_distance/initial_distance)**2)
        if (log_Paccept >= 0.0) or (numpy.random.rand() < numpy.exp(log_Paccept)):
            print "Accepted."
            coordinates = proposed_coordinates
            ncfile.variables['accept'][iteration] = 1
        else:
            print "Rejected."
            ncfile.variables['accept'][iteration] = 0

        final_time = time.time()
        elapsed_time = final_time - initial_time
        print "%12.3f s elapsed" % elapsed_time

        # Record results.
        distance = norm(coordinates[1,:] - coordinates[0,:])                
        ncfile.variables['distance'][iteration] = distance / units.angstroms
        ncfile.variables['positions'][iteration,:,:] = coordinates[:,:] / units.angstroms
        ncfile.sync()

        # Debug.
        final_time = time.time()
        elapsed_time = final_time - initial_time
        print "%12.3f s elapsed" % elapsed_time

        # Increment iteration counter.
        iteration += 1

    # Close netcdf file.
    ncfile.close()
