#!/usr/local/bin/env python

#=============================================================================================
# MODULE DOCSTRING
#=============================================================================================

"""
Estimate steady-state nonequilibrium free energy of a TIP3P water box as a function of number of water molecules and timestep.

"""

#=============================================================================================
# GLOBAL IMPORTS
#=============================================================================================

import sys
import math
import doctest
import numpy

import simtk.unit as units
# import simtk.openmm as openmm
# import simtk.pyopenmm.extras.testsystems as testsystems

#=============================================================================================
# CONSTANTS
#=============================================================================================
    
#=============================================================================================
# UTILITY SUBROUTINES
#=============================================================================================

class ParameterError(ValueError):
  """Raised when statisticalInefficiency() is given input it cannot analyze."""
  pass


def statisticalInefficiency(A_n, B_n=None, fast=False, mintime=3):
  """
  Compute the (cross) statistical inefficiency of (two) timeseries.

  REQUIRED ARGUMENTS
    A_n (numpy array) - A_n[n] is nth value of timeseries A.  Length is deduced from vector.
       Expected to be one-dimensional.

  OPTIONAL ARGUMENTS
    B_n (numpy array) - B_n[n] is nth value of timeseries B.  Length is deduced from vector.
       If supplied, the cross-correlation of timeseries A and B will be estimated instead of the
       autocorrelation of timeseries A.  Must have the same shape as A_n.
    fast (boolean) - if True, the lag increment grows by one after every evaluated lag (each lag's
       contribution is weighted by the current increment), a faster but less accurate estimate
       described in Ref. [1] (default: False)
    mintime (int) - lag out to which the correlation function is always accumulated (default: 3).
       Accumulation stops at the first lag t > mintime at which the correlation function is <= 0.
       Note that this may need to be increased if there is a strong initial negative peak in the
       correlation function.

  RETURNS
    g is the estimated statistical inefficiency (equal to 1 + 2 tau, where tau is the correlation time).
       We enforce g >= 1.0.

  RAISES
    ParameterError (a ValueError) - if A_n and B_n have different shapes, or if the sample
       covariance of A_n and B_n is exactly zero (e.g. a constant timeseries).

  NOTES
    The same timeseries can be used for both A_n and B_n to get the autocorrelation statistical inefficiency.
    The fast method described in Ref. [1] is used only when fast=True.

  REFERENCES
    [1] J. D. Chodera, W. C. Swope, J. W. Pitera, C. Seok, and K. A. Dill. Use of the weighted
    histogram analysis method for the analysis of simulated and parallel tempering simulations.
    JCTC 3(1):26-41, 2007.

  EXAMPLES

  Statistical inefficiency of a short, positively correlated series.

  >>> g = statisticalInefficiency([1.0, 1.0, -1.0, -1.0, 1.0, 1.0, -1.0, -1.0], mintime=0)
  >>> round(g, 6)
  1.25

  """

  # Create numpy copies of input arguments; without B_n we compute the autocorrelation of A_n.
  A_n = numpy.array(A_n)
  if B_n is not None:
    B_n = numpy.array(B_n)
  else:
    B_n = numpy.array(A_n)

  # Be sure A_n and B_n have the same dimensions.
  if A_n.shape != B_n.shape:
    raise ParameterError('A_n and B_n must have same dimensions.')

  # Get the length of the timeseries.
  N = A_n.size

  # Initialize statistical inefficiency estimate with uncorrelated value.
  g = 1.0

  # Fluctuations about the mean, in double precision.
  dA_n = A_n.astype(numpy.float64) - A_n.mean()
  dB_n = B_n.astype(numpy.float64) - B_n.mean()

  # Covariance estimator chosen so that the normalized correlation function satisfies C(0) = 1.
  sigma2_AB = (dA_n * dB_n).mean()

  # Trap the case where this covariance is zero, and we cannot normalize.
  if sigma2_AB == 0:
    raise ParameterError('Sample covariance sigma_AB^2 = 0 -- cannot compute statistical inefficiency')

  # Accumulate the integrated correlation time by computing the normalized correlation function at
  # increasing lags t.  Stop once it goes non-positive (past 'mintime'), since by then it is
  # dominated by noise and indistinguishable from zero.
  t = 1
  increment = 1
  while t < N - 1:

    # Symmetrized, normalized fluctuation correlation at lag t (vectorized dot products instead
    # of the element-by-element Python builtin sum).
    C = (numpy.dot(dA_n[0:(N-t)], dB_n[t:N]) + numpy.dot(dB_n[0:(N-t)], dA_n[t:N])) \
        / (2.0 * float(N - t) * sigma2_AB)

    # Terminate if the correlation function has crossed zero and we've computed it at least out
    # to 'mintime'.
    if (C <= 0.0) and (t > mintime):
      break

    # Accumulate contribution to the statistical inefficiency, weighting by the lag spacing.
    g += 2.0 * C * (1.0 - float(t) / float(N)) * float(increment)

    # Advance the lag; in fast mode the spacing between evaluated lags grows each iteration.
    t += increment
    if fast:
      increment += 1

  # g must be at least unity.
  if g < 1.0:
    g = 1.0

  return g

#=============================================================================================
# PARAMETERS
#=============================================================================================

# Sweep settings.  Box sizes and the thermodynamic/GHMC settings below describe how the data were
# generated; the analysis in this script does not reference them.
boxsizes_to_try = [edge * units.angstroms for edge in (5.0, 10.0, 15.0)] # waterbox sizes to try
timesteps_to_try = [0.25 * k * units.femtoseconds for k in range(1, 9)] # 0.25 fs .. 2.0 fs in 0.25 fs steps
nsteps_to_try = [1 << exponent for exponent in range(13)] # 1, 2, 4, ..., 4096

print("timesteps to try: %s" % str(timesteps_to_try))
print("number of steps to try: %s" % str(nsteps_to_try))

temperature = 298.0 * units.kelvin
pressure = 1.0 * units.atmosphere # pressure for equilibration
gamma = 50.0 / units.picosecond # collision rate
ghmc_nsteps = 100 # number of steps to generate new uncorrelated sample with GHMC
ghmc_timestep = 1.0 * units.femtoseconds
nsamples = 100 # number of samples to generate (replaced by the value stored in data.pkl)
nequil = 10 # number of NPT equilibration iterations

#=============================================================================================
# Analyze data.
#=============================================================================================

# Load data: the pickle holds the number of samples followed by a dict keyed by
# (timestep in fs, nsteps).
# NOTE(review): pickle.load executes arbitrary code from the file; only use a trusted data.pkl.
import cPickle as pickle
with open('data.pkl', 'rb') as infile:
    nsamples = pickle.load(infile)
    data = pickle.load(infile)
print("%d samples loaded" % nsamples)

import timeseries


def _pseudowork_statistics(timestep, nsteps):
    """Return (pseudowork, mean, stderr, g, Neff) for the samples stored under (timestep, nsteps).

    g is the statistical inefficiency of the samples and Neff = N / g the effective number of
    samples used to inflate the standard error of the mean.
    """
    key = (timestep / units.femtosecond, nsteps)
    pseudowork = data[key][0:nsamples]
    g = timeseries.statisticalInefficiency(pseudowork) # statistical inefficiency
    Neff = len(pseudowork) / g # effective number of samples
    stderr = pseudowork.std() / numpy.sqrt(Neff)
    return pseudowork, pseudowork.mean(), stderr, g, Neff


for timestep in timesteps_to_try:
    print("timestep: %s" % (str(timestep)))

    with open('timestep-%04.2f.out' % (timestep/units.femtoseconds), 'w') as outfile:
        # The last nsteps is skipped because the steady-state estimate needs data at 2*nsteps.
        for nsteps in nsteps_to_try[:-1]:
            # Average pseudowork over first 'nsteps' from equilibrium initial sample.
            pseudowork, average_pseudowork, average_pseudowork_stderr, g, Neff = _pseudowork_statistics(timestep, nsteps)
            print("  nsteps: %12d | average pseudowork from equilibrium: %12.3f +/- %12.3f kT | g = %5.1f samples" % (nsteps, average_pseudowork, average_pseudowork_stderr, g))

            # Steady-state pseudowork accumulated over the next 'nsteps' steps.
            key = (timestep / units.femtosecond, 2*nsteps)
            further_pseudowork = (data[key][0:nsamples] - pseudowork)
            steadystate_pseudowork = further_pseudowork.mean()
            steadystate_pseudowork_stderr = further_pseudowork.std() / numpy.sqrt(Neff)
            print("          %12s | steady-state power integrated over next 'nsteps' steps: %12.3f +/- %12.3f kT (integrated over %12d steps)" % ('', steadystate_pseudowork, steadystate_pseudowork_stderr, nsteps))

            # Nonequilibrium free energy estimate: half the difference of the two averages.
            neq_free_energy = (average_pseudowork - steadystate_pseudowork) / 2.0
            # Both averages come from the same trajectories and are correlated, so the error must
            # include their covariance: Var[(a-b)/2] = (Var[a] + Var[b] - 2 Cov[a,b]) / (4 Neff).
            # Computing it from the per-sample difference gives exactly that (ddof=0, matching
            # .std() above) and can never go negative.
            difference = pseudowork - further_pseudowork
            neq_free_energy_stderr = difference.std() / numpy.sqrt(Neff) / 2.0
            print("          %12s | estimate of nonequilibrium free energy:                 %12.3f +/- %12.3f" % ('', neq_free_energy, neq_free_energy_stderr))

            # Write estimates to a file.
            outfile.write('%12d %6.1f %24.8f %24.8f %24.8f %24.8f %24.8f %24.8f %24.8f %24.8f\n' % (nsteps, g, average_pseudowork, average_pseudowork_stderr, steadystate_pseudowork, steadystate_pseudowork_stderr, steadystate_pseudowork / nsteps, steadystate_pseudowork_stderr / nsteps, neq_free_energy, neq_free_energy_stderr))

    # Write just average cumulative pseudowork, for every nsteps (including the last).
    with open('cumulative-%04.2f.out' % (timestep/units.femtoseconds), 'w') as outfile:
        for nsteps in nsteps_to_try:
            _, average_pseudowork, average_pseudowork_stderr, _, _ = _pseudowork_statistics(timestep, nsteps)
            outfile.write('%12d %24.8f %24.8f\n' % (nsteps, average_pseudowork, average_pseudowork_stderr))

