#!/usr/bin/env python

#=============================================================================================
# MODULE DOCSTRING
#=============================================================================================

"""
Analyze VVVR data for kinetic properties.

"""

#=============================================================================================
# GLOBAL IMPORTS
#=============================================================================================

import os, os.path
import sys
import math
import doctest
import numpy

import simtk.unit as units
import simtk.openmm as openmm

import netCDF4 as netcdf 

#=============================================================================================
# CONSTANTS
#=============================================================================================

nreduce = 1 # factor by which to reduce the number of samples analyzed (debugging only)
ndt = 100 # number of time separations (lags) used for the MSD and velocity autocorrelation
nregression = ndt/2 # number of lags used in the linear fit to the long-time MSD
timeSample = 1. * units.picoseconds # time between stored samples
natomsInWater = 3 # number of atoms per water molecule
#kT_in_amu_nm2_ps2 = 2.49

#=============================================================================================
# UTILITY SUBROUTINES
#=============================================================================================

def analyze_netcdf_file(netcdf_filename, label):
  """Analyze kinetic properties of a VVVR simulation stored in a NetCDF file.

  ARGUMENTS
    netcdf_filename (string) - NetCDF file containing the simulation data
    label (string) - prefix for the '.msd', '.vac', and '.out' output files

  """
  # DEBUG: Only analyze WaterBox datasets for now.
  if 'WaterBox' not in netcdf_filename:
    return

  # Open NetCDF file for reading.
  print "Opening %s for reading..." % netcdf_filename
  ncfile = netcdf.Dataset(netcdf_filename, 'r')

  # Compute kT.
  temperature = ncfile.variables['temperature'][0] * units.kelvin
  kT_in_amu_nm2_ps2 = units.AVOGADRO_CONSTANT_NA * units.BOLTZMANN_CONSTANT_kB * temperature / (units.amu * units.nanometers**2 / units.picoseconds**2)
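  # Note: 1 amu*nm^2/ps^2 = 1 kJ/mol, so at e.g. 300 K this evaluates to ~2.49 (cf. the commented-out constant above).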
  
  # Deserialize OpenMM System object.
  serialized_system = str(ncfile.variables['system'][0])
  system = openmm.XmlSerializer.deserialize(serialized_system)
  
  # Extract masses.
  masses = units.Quantity( numpy.array([ system.getParticleMass(index) / units.amu for index in range(system.getNumParticles()) ]), units.amu )
  # print masses

  # Get dimensions.
  [nsamples, natoms, ndim] = ncfile.variables['positions'].shape
  #nsamples /= nreduce # FOR DEBUGGING
  print "%d samples loaded" % nsamples
  # print "%d atoms in system" % natoms

  # Get timestep and gamma.
  timestep_unit = units.femtoseconds
  gamma_unit = 1.0 / units.picoseconds
  timestep = float(ncfile.variables['timestep'][0]) * timestep_unit
  gamma = float(ncfile.variables['gamma'][0]) * gamma_unit

  # Use one less sample in case last sample is incomplete.
  #nsamples -= 1
  #print "Using %d samples" % nsamples

  # Units
  displacement_unit = units.nanometers
  mass_unit = units.amu
  velocity_unit = units.nanometers/units.picoseconds
  diffCoef_unit = units.nanometers*units.nanometers/units.picoseconds
  sd_unit = displacement_unit**2
  sv_unit = velocity_unit**2

  # Check to make sure stability limit of integrator has not been reached.
  potential_energies = ncfile.variables['potential_energy'][0:nsamples] # energies in kT
  if numpy.any(numpy.isnan(potential_energies)):
    print "stability limit exceeded"
    return

  # Compute mean-squared displacement
  print "computing msd..."
  msd_dt = numpy.zeros([ndt], numpy.float64) # mean-squared displacement as function of time separation
  for dt in range(ndt):
    for t in range(dt,nsamples):
      msd_dt[dt] += ((ncfile.variables['positions'][t,:,:]-ncfile.variables['positions'][t-dt,:,:])**2).sum()
    msd_dt[dt] /= (nsamples-dt)*natoms
  # Check that MSD is finite.
  if numpy.any(numpy.isnan(msd_dt)):
    print "NaN MSDs detected."
    return
  # Write msd to msd file
  outfile = open('%s.msd' % label, 'w')
  for dt in range(ndt):
    outfile.write('%8.3f %.3e\n' % (dt, msd_dt[dt]))
  # Clean up.
  outfile.close()

  # Compute diffusion coefficient via Einstein formula involving time-derivative of mean-squared displacement at long times
  print "computing Einstein diffusion coefficient..."
  from numpy import arange,array,ones,linalg
  # A = array([ arange(ndt), ones(ndt) ])
  # w = linalg.lstsq(A.T,msd_dt[:])[0]
  A = array([ arange(nregression,2*nregression), ones(nregression) ])
  w = linalg.lstsq(A.T,msd_dt[nregression:])[0]
  diffCoefE = w[0] * sd_unit / timeSample / (2. * ndim) # w[0] is the MSD slope per sample interval
  msdIntercept = w[1] * sd_unit

  # Compute diffusion coefficient via Green-Kubo formula involving time-integral of velocity autocorrelation
  print "computing Green-Kubo diffusion coefficient..."
  vac_dt = numpy.zeros([ndt], numpy.float64) # velocity autocorrelation as function of time separation
  for dt in range(ndt):
    for t in range(dt, nsamples):
      vac_dt[dt] += (ncfile.variables['velocities'][t,:,:]*ncfile.variables['velocities'][t-dt,:,:]).sum()
    vac_dt[dt] /= (nsamples-dt)*natoms
  # Check that VAC is finite.
  if numpy.any(numpy.isnan(vac_dt)):
    print "NaN VACs detected."
    return
  # Write vac to vac file
  outfile = open('%s.vac' % label, 'w')
  for dt in range(ndt):
    outfile.write('%8.3f %.3e\n' % (dt, vac_dt[dt]))
  # Clean up.
  outfile.close()
  diffCoefGK = vac_dt.sum() / ndim * sv_unit * timeSample
  diffCoefExpected = kT_in_amu_nm2_ps2 / (masses[0] * gamma) # free-particle Langevin prediction D = kT / (m * gamma)
  # print diffCoefGK, diffCoefExpected, diffCoefE

  # Compute mean-squared velocity.
  print "computing msv..."
  msv_t = numpy.zeros([nsamples], numpy.float64) # mean-squared velocity trajectory
  # print ncfile.variables['velocities'][:,:,:].max(), ncfile.variables['velocities'][:,:,:].mean(), ncfile.variables['velocities'][:,:,:].min() 
  # DEBUG: compare velocity statistics near the start and end of the trajectory.
  print ncfile.variables['velocities'][1:nsamples/50,:,:].max(), ncfile.variables['velocities'][1:nsamples/50,:,:].mean(), ncfile.variables['velocities'][1:nsamples/50,:,:].min()
  print ncfile.variables['velocities'][9*nsamples/10:nsamples,:,:].max(), ncfile.variables['velocities'][9*nsamples/10:nsamples,:,:].mean(), ncfile.variables['velocities'][9*nsamples/10:nsamples,:,:].min()
  if 'WaterBox' not in netcdf_filename:
    # Mean-squared atomic velocity.
    for t in range(nsamples):
      msv_t[t] = (ncfile.variables['velocities'][t,:,:]**2).sum() / natoms
  else:
    # Mean-squared center-of-mass velocity of the water molecules.
    masses_in_amu = masses / units.amu # unitless masses for array arithmetic
    massWater = masses_in_amu[0:natomsInWater].sum() # mass of one water molecule (amu)
    nwaters = natoms/natomsInWater
    for t in range(nsamples):
      massWeightedVelocities = ncfile.variables['velocities'][t,:,:]
      for k in range(3):
        massWeightedVelocities[:,k] *= masses_in_amu[:]
      for i in range(nwaters):
        # Center-of-mass velocity of water i: sum of mass-weighted atomic velocities over its atoms, divided by the molecular mass.
        vCOM = massWeightedVelocities[natomsInWater*i:natomsInWater*(i+1),:].sum(axis=0) / massWater
        msv_t[t] += (vCOM[:]**2).sum()
    msv_t[:] /= nwaters
  # Check that MSV is finite.
  if numpy.any(numpy.isnan(msv_t)):
    print "NaN MSVs detected."
    return

  # Subsample to generate uncorrelated subset.
  print "Subsampling data..."
  # print msv_t
  # print msv_t.max()
  import timeseries
  [t, g, Neff_max, indices] = timeseries.detectEquilibration(msv_t, nskip=100)
  msv_n = msv_t[indices]
  msv = msv_n.mean()
  dmsv = msv_n.std() / numpy.sqrt(Neff_max)
  msvExpected = 3 * kT_in_amu_nm2_ps2 / masses[0]
  # print msvExpected

  # Compute instantaneous kinetic temperature
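  # Equipartition: <KE> = (ndof/2) kB T, so T_kin = 2 KE / (ndof kB); the stored kinetic
  # energies are in units of kT.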
  kT = units.AVOGADRO_CONSTANT_NA * units.BOLTZMANN_CONSTANT_kB * temperature
  ndof = 3*system.getNumParticles() - system.getNumConstraints() # number of degrees of freedom
  kinetic_energies = ncfile.variables['kinetic_energy'][0:nsamples] # energies in kT
  [t, g, Neff_max, indices] = timeseries.detectEquilibration(kinetic_energies, nskip=100)
  # Use the equilibrated, decorrelated subset of kinetic energies.
  kinetic_temperatures = kinetic_energies[indices] * kT / ((ndof/2.) * units.AVOGADRO_CONSTANT_NA * units.BOLTZMANN_CONSTANT_kB)
  kinetic_temperatures_in_kelvin = numpy.array(kinetic_temperatures / units.kelvin, dtype=numpy.float64)
  mean_kinetic_temperature_in_kelvin = numpy.mean(kinetic_temperatures_in_kelvin)
  mean_kinetic_temperature_uncertainty = numpy.std(kinetic_temperatures_in_kelvin) / numpy.sqrt(Neff_max)
  print "Mean kinetic temperature = %8.1f +- %.1f K" % (mean_kinetic_temperature_in_kelvin, mean_kinetic_temperature_uncertainty)

  # Select output filename.
  outfile = open('%s.out' % label, 'w')
  # Write data to analysis file
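  # Columns: timestep (fs), gamma (1/ps), <v^2> (nm^2/ps^2), d<v^2>, expected <v^2>,
  #   D_GreenKubo (nm^2/ps), D_Einstein (nm^2/ps), expected D, MSD intercept (nm^2),
  #   mean kinetic temperature (K), kinetic temperature uncertainty (K).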
  outfile.write('%.3e %.3e %.3e %.3e %.3e %.3e %.3e %.3e %.3e %.3e %.3e\n' % (timestep / timestep_unit, gamma / gamma_unit, msv, dmsv, msvExpected * mass_unit, diffCoefGK / diffCoef_unit, diffCoefE / diffCoef_unit, diffCoefExpected * (mass_unit * gamma_unit), msdIntercept / sd_unit, mean_kinetic_temperature_in_kelvin, mean_kinetic_temperature_uncertainty ))
  # Clean up.
  outfile.close()
  ncfile.close()

  return
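
# The MSD and VAC loops above re-read frames from the NetCDF file inside a double loop,
# which costs O(ndt * nsamples) reads. A minimal vectorized sketch of the same MSD
# estimator (illustrative only; not called anywhere; assumes the full positions array
# fits in memory):
def _msd_vectorized_sketch(positions, ndt):
  """Mean-squared displacement for lags 0..ndt-1 from an in-memory (nsamples, natoms, 3) array."""
  [nsamples, natoms, ndim] = positions.shape
  msd_dt = numpy.zeros([ndt], numpy.float64)
  for dt in range(ndt):
    # Differences between all frame pairs separated by 'dt' samples.
    displacements = positions[dt:nsamples,:,:] - positions[0:nsamples-dt,:,:]
    msd_dt[dt] = (displacements**2).sum() / ((nsamples-dt)*natoms)
  return msd_dt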



#=============================================================================================
# PARAMETERS
#=============================================================================================

# Sets of parameters to regress over.
timestep_correction_flags_to_try = [True, False]
systems_to_try = ['IdealGas', 'LennardJonesFluid', 'WaterBox'] 
timesteps_to_try = units.Quantity([0.5, 1.0, 2.0, 4.0, 8.0, 16.0], units.femtoseconds) # MD timesteps
gammas_to_try = units.Quantity([0.01, 0.1, 1.0, 10.0, 100.0, 1000.0], units.picoseconds**-1) # collision rates

#=============================================================================================
# Initialize MPI.
#=============================================================================================

try:
    from mpi4py import MPI # MPI wrapper
    rank = MPI.COMM_WORLD.rank
    size = MPI.COMM_WORLD.size
    print "Node %d / %d" % (MPI.COMM_WORLD.rank, MPI.COMM_WORLD.size)
except ImportError:
    print "mpi4py not available---using serial execution."
    rank = 0
    size = 1
    
#=============================================================================================
# Analyze data.
#=============================================================================================

noptionsets = len(timestep_correction_flags_to_try) * len(systems_to_try) * len(timesteps_to_try) * len(gammas_to_try)
for index in range(rank, noptionsets, size):
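    # Round-robin task assignment: MPI rank 'rank' handles flat option indices rank, rank+size, rank+2*size, ...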

    #=============================================================================================
    # Select problem to analyze.
    #=============================================================================================

    def select_options(options_list, index):
        selected_options = list()
        for option in options_list:
            noptions = len(option)
            selected_options.append(option[index % noptions])
            index = int(index/noptions)
        return selected_options
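    # Example of select_options' mixed-radix decoding: with option-list lengths [2, 3, 6, 6]
    # (the parameter sets above), flat index 7 selects element 7%2=1 of the first list,
    # element (7//2)%3=0 of the second, element 1 of the third, and element 0 of the fourth.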

    try:
        [use_timestep_correction, system_name, timestep, gamma] = select_options([timestep_correction_flags_to_try, systems_to_try, timesteps_to_try, gammas_to_try], index)
    except:
        continue

    # Create filename to store data in.
    label = 'vvvr-%s-%s-%s-%s' % (str(use_timestep_correction), system_name, str(timestep/units.femtoseconds), str(gamma*units.picoseconds))
    store_filename = 'data/%s.nc' % label

    if os.path.exists(store_filename):
      # Analyze.
      analyze_netcdf_file(store_filename, 'analysis/' + label)

    
