#!/usr/bin/env python

#=============================================================================================
# MODULE DOCSTRING
#=============================================================================================

"""
Estimate steady-state nonequilibrium free energy of a TIP3P water box as a function of number of water molecules and timestep.

This script analyzes the NetCDF file produced by vvvr_waterbox_steadystate.py.

"""

#=============================================================================================
# GLOBAL IMPORTS
#=============================================================================================

import sys
import math
import doctest
import numpy

import simtk.unit as units
# import simtk.openmm as openmm
# import simtk.pyopenmm.extras.testsystems as testsystems

import netCDF4 as netcdf 

#=============================================================================================
# CONSTANTS
#=============================================================================================
    
#=============================================================================================
# UTILITY SUBROUTINES
#=============================================================================================

def statisticalInefficiency(A_n, B_n=None, fast=False, mintime=3):
    """
    Compute the (cross) statistical inefficiency of (two) timeseries.

    REQUIRED ARGUMENTS
      A_n (numpy array) - A_n[n] is nth value of timeseries A.  Length is deduced from vector.

    OPTIONAL ARGUMENTS
      B_n (numpy array) - B_n[n] is nth value of timeseries B.  Length is deduced from vector.
         If supplied, the cross-correlation of timeseries A and B will be estimated instead of the
         autocorrelation of timeseries A.
      fast (boolean) - if True, the lag increment grows by one after every evaluated lag, giving the
         faster (but less accurate) estimate of the correlation time described in Ref. [1]
         (default: False)
      mintime (int) - minimum lag (in samples) out to which the correlation function is always
         accumulated (default: 3).  Beyond this lag, the algorithm terminates the first time the
         correlation function is zero or negative.  Note that this may need to be increased if
         there is a strong initial negative peak in the correlation function.

    RETURNS
      g is the estimated statistical inefficiency (equal to 1 + 2 tau, where tau is the correlation time).
         We enforce g >= 1.0.

    RAISES
      ValueError - if A_n and B_n have different shapes, or if the sample covariance of A and B is
         exactly zero (e.g. a constant timeseries), in which case g cannot be computed.

    NOTES
      The same timeseries can be used for both A_n and B_n to get the autocorrelation statistical inefficiency.
      The inputs are copied; the caller's arrays are never modified.
      Intended for 1D timeseries.

    REFERENCES
      [1] J. D. Chodera, W. C. Swope, J. W. Pitera, C. Seok, and K. A. Dill. Use of the weighted
      histogram analysis method for the analysis of simulated and parallel tempering simulations.
      JCTC 3(1):26-41, 2007.

    EXAMPLES

    A step function is positively correlated at lag 1 and uncorrelated at lag 2.

    >>> round(statisticalInefficiency([0, 0, 0, 1, 1, 1], mintime=0), 6)
    2.0

    """
    # Work on numpy copies so the caller's data are never modified.
    A_n = numpy.array(A_n)
    B_n = numpy.array(A_n) if B_n is None else numpy.array(B_n)

    # Be sure A_n and B_n have the same dimensions.
    if A_n.shape != B_n.shape:
        raise ValueError('A_n and B_n must have same dimensions.')

    # Length of the timeseries.
    N = A_n.size

    # Fluctuations about each mean, in double precision.
    dA_n = A_n.astype(numpy.float64) - A_n.mean()
    dB_n = B_n.astype(numpy.float64) - B_n.mean()

    # Covariance of (A,B); normalizing by it ensures C(0) = 1.
    sigma2_AB = (dA_n * dB_n).mean()

    # A zero covariance would divide by zero below, so refuse to proceed.
    if sigma2_AB == 0:
        raise ValueError('Sample covariance sigma_AB^2 = 0 -- cannot compute statistical inefficiency')

    # Accumulate the integrated correlation time by computing the normalized correlation function at
    # increasing lags t.  Stop once it goes non-positive (past 'mintime'), since that is unlikely unless
    # the correlation function has decayed to the point where it is dominated by noise.
    g = 1.0
    t = 1
    increment = 1
    while t < N - 1:
        # Symmetrized, normalized fluctuation correlation function at lag t.
        lagged_products = dA_n[0:(N - t)] * dB_n[t:N] + dB_n[0:(N - t)] * dA_n[t:N]
        C = numpy.sum(lagged_products) / (2.0 * float(N - t) * sigma2_AB)

        if C <= 0.0 and t > mintime:
            break

        # Each evaluated lag stands in for 'increment' lags when fast mode skips ahead.
        g += 2.0 * C * (1.0 - float(t) / float(N)) * float(increment)

        t += increment
        if fast:
            increment += 1

    # g must be at least unity.
    return max(g, 1.0)

#=============================================================================================
# PARAMETERS
#=============================================================================================

# NetCDF file to analyze (one dataset per run).
# Alternative dataset previously analyzed: 'TIP3P-216-flexible.nc'.
netcdf_filename = 'TIP3P-205-constrained.nc'

#=============================================================================================
# Analyze data.
#=============================================================================================

def _has_nonfinite(values):
    """Return a truthy value if any entry of `values` is NaN or +/-inf."""
    return numpy.any(numpy.isnan(values)) or numpy.any(numpy.isinf(values))


def _stderr(samples, Neff):
    """Return the standard error of the mean of `samples`, given `Neff` effectively uncorrelated samples."""
    return samples.std() / numpy.sqrt(Neff)


def _mean_and_stderr(samples):
    """Return (mean, stderr, g, Neff) for a 1D array of correlated samples.

    g is the statistical inefficiency of `samples`; Neff = N / g is the effective number of
    uncorrelated samples used to compute the standard error of the mean.
    """
    g = statisticalInefficiency(samples)  # statistical inefficiency
    Neff = len(samples) / g  # effective number of samples
    return samples.mean(), _stderr(samples, Neff), g, Neff


# Open NetCDF file for reading; the `finally` clause closes it even if the analysis fails.
ncfile = netcdf.Dataset(netcdf_filename, 'r')
try:
    pseudowork_variable = ncfile.variables['pseudowork']
    final_energy_variable = ncfile.variables['final_energy']

    # Get dimensions.
    timesteps_to_try = ncfile.variables['timesteps'][:]
    nsamples = pseudowork_variable.shape[0]
    nsteps_to_try = ncfile.variables['nsteps'][:]
    print("%d samples loaded" % nsamples)

    # Use one less sample.
    nsamples -= 1
    print("Using %d samples" % nsamples)

    # Initial energies do not depend on timestep or nsteps, so read them once.
    initial_energies = ncfile.variables['initial_energy'][0:nsamples]

    for (timestep_index, timestep) in enumerate(timesteps_to_try):
        print("timestep: %.3f fs" % timestep)

        # Only examine data up to stability limit: skip this timestep entirely if any sample has a
        # non-finite final energy at any nsteps.
        if _has_nonfinite(final_energy_variable[0:nsamples, timestep_index, :]):
            print(" stability limit exceeded")
            continue

        with open('timestep-%04.2f.out' % timestep, 'w') as outfile:
            for (nsteps_index, nsteps) in enumerate(nsteps_to_try[:-1]):
                # Each estimate also uses the next nsteps entry; make sure its energies are finite.
                if _has_nonfinite(final_energy_variable[0:nsamples, timestep_index, nsteps_index + 1]):
                    print("  nsteps: %12d | stability limit reached" % nsteps)
                    continue

                # Average pseudowork (in kT) over first 'nsteps' from equilibrium initial sample.
                pseudowork = pseudowork_variable[0:nsamples, timestep_index, nsteps_index]
                (average_pseudowork, average_pseudowork_stderr, g, Neff) = _mean_and_stderr(pseudowork)
                print("  nsteps: %12d | average pseudowork from equilibrium: %12.3f +/- %12.3f kT | g = %5.1f samples" % (nsteps, average_pseudowork, average_pseudowork_stderr, g))

                # Steady-state pseudowork accumulated over the next 'nsteps' steps.
                further_pseudowork = pseudowork_variable[0:nsamples, timestep_index, nsteps_index + 1] - pseudowork
                steadystate_pseudowork = further_pseudowork.mean()
                steadystate_pseudowork_stderr = _stderr(further_pseudowork, Neff)
                print("          %12s | steady-state power integrated over next 'nsteps' steps: %12.3f +/- %12.3f kT (integrated over %12d steps)" % ('', steadystate_pseudowork, steadystate_pseudowork_stderr, nsteps))

                # Nonequilibrium free energy.  Both averages come from the same samples, so the
                # stderr is taken from the per-sample differences, which includes their correlation.
                neq_free_energy_n = (pseudowork - further_pseudowork) / 2.0
                neq_free_energy = (average_pseudowork - steadystate_pseudowork) / 2.0
                neq_free_energy_stderr = _stderr(neq_free_energy_n, Neff)
                print("          %12s | estimate of nonequilibrium free energy:                 %12.3f +/- %12.3f" % ('', neq_free_energy, neq_free_energy_stderr))

                # Total energy difference (enthalpy); stderr from paired per-sample differences.
                # NOTE(review): assumes 'initial_energy' is 1D with one entry per sample -- confirm.
                final_energies = final_energy_variable[0:nsamples, timestep_index, nsteps_index]
                enthalpy_change_n = final_energies - initial_energies
                enthalpy_change = final_energies.mean() - initial_energies.mean()
                enthalpy_change_stderr = _stderr(enthalpy_change_n, Neff)
                print("          %12s | estimate of nonequilibrium enthalpy change:             %12.3f +/- %12.3f" % ('', enthalpy_change, enthalpy_change_stderr))

                # Entropy difference; stderr again from paired per-sample differences.
                entropy_change = enthalpy_change - neq_free_energy
                entropy_change_stderr = _stderr(enthalpy_change_n - neq_free_energy_n, Neff)
                print("          %12s | estimate of nonequilibrium entropy change:              %12.3f +/- %12.3f" % ('', entropy_change, entropy_change_stderr))

                # Write estimates to a file.
                outfile.write('%12d %6.1f %24.8f %24.8f %24.8f %24.8f %24.8f %24.8f %24.8f %24.8f %24.8f %24.8f %24.8f %24.8f\n' % (nsteps, g, average_pseudowork, average_pseudowork_stderr, steadystate_pseudowork, steadystate_pseudowork_stderr, steadystate_pseudowork / nsteps, steadystate_pseudowork_stderr / nsteps, neq_free_energy, neq_free_energy_stderr, enthalpy_change, enthalpy_change_stderr, entropy_change, entropy_change_stderr))

        # Write just average cumulative pseudowork.
        with open('cumulative-%04.2f.out' % timestep, 'w') as outfile:
            for (nsteps_index, nsteps) in enumerate(nsteps_to_try):
                # Check to make sure all energies are finite.
                if _has_nonfinite(final_energy_variable[0:nsamples, timestep_index, nsteps_index]):
                    continue

                # Average pseudowork over first 'nsteps' from equilibrium initial sample.
                pseudowork = pseudowork_variable[0:nsamples, timestep_index, nsteps_index]
                (average_pseudowork, average_pseudowork_stderr, g, Neff) = _mean_and_stderr(pseudowork)
                outfile.write('%12d %24.8f %24.8f\n' % (nsteps, average_pseudowork, average_pseudowork_stderr))
finally:
    ncfile.close()


