#!/usr/bin/env python

#=============================================================================================
# Analyze electric field history to compute Stark effect.
#=============================================================================================

#=============================================================================================
# REQUIREMENTS
#
# The netcdf4-python module is used to provide netCDF v4 support:
# https://github.com/Unidata/netcdf4-python (formerly http://code.google.com/p/netcdf4-python/)
#
# This requires the NetCDF 4 library built with multithreading support, as well as HDF5.
#=============================================================================================

#=============================================================================================
# TODO
#=============================================================================================

#=============================================================================================
# CHANGELOG
#=============================================================================================

#=============================================================================================
# VERSION CONTROL INFORMATION
#=============================================================================================

#=============================================================================================
# IMPORTS
#=============================================================================================

import os
import os.path
import sys
import math

import numpy

import netCDF4 as netcdf # netcdf4-python

from pymbar import MBAR # multistate Bennett acceptance ratio
import timeseries # for statistical inefficiency analysis

import simtk.unit as units

#=============================================================================================
# PARAMETERS
#=============================================================================================

kB = units.BOLTZMANN_CONSTANT_kB * units.AVOGADRO_CONSTANT_NA

#=============================================================================================
# SUBROUTINES
#=============================================================================================
def write_file(filename, contents):
    """Write the specified contents to a file.

    ARGUMENTS
      filename (string) - the file to be written
      contents (string or list of strings) - the contents of the file to be written;
        list elements are written one after another with no separator added

    RAISES
      TypeError - if contents is neither a string nor a list; the file is left
        untouched in that case

    """

    # Validate the argument before opening the file, so that a bad argument does
    # not truncate an existing file.  Raise a real exception object: raising a
    # bare string is itself a TypeError and would hide this message.
    if isinstance(contents, str):
        pieces = [contents]
    elif isinstance(contents, list):
        pieces = contents
    else:
        raise TypeError("Type for 'contents' not supported: " + repr(type(contents)))

    # The context manager closes the file even if a write fails.
    with open(filename, 'w') as outfile:
        outfile.writelines(pieces)

    return

def read_file(filename):
    """Return the lines of a text file.

    ARGUMENTS
      filename (string) - path of the file to read

    RETURNS
      lines (list of strings) - every line of the file, each keeping its
        trailing newline exactly as readlines() returns it
    """

    with open(filename, 'r') as handle:
        contents = handle.readlines()
    return contents

def detect_equilibration(A_t, maxeval=100):
    """
    Automatically detect equilibrated region.

    The timeseries is truncated at a set of candidate origins.  For each
    truncated series the statistical inefficiency g is computed, and the
    origin that maximizes the effective number of uncorrelated samples
    (number of retained samples / g) is selected.

    ARGUMENTS

    A_t (numpy.array) - timeseries

    OPTIONAL ARGUMENTS

    maxeval (int) - maximum number of origins to evaluate, or no limit if None

    RETURNS

    t (int) - start of equilibrated data
    g (float) - statistical inefficiency of equilibrated data
    Neff_max (float) - number of uncorrelated samples

    """
    T = A_t.size

    # With fewer than two samples, or a constant timeseries (zero variance), the
    # statistical inefficiency cannot be computed: treat all data as equilibrated.
    if (T < 2) or (A_t.std() == 0.0):
        return (0, 1, T)

    # Determine list of origins to evaluate.  Every origin must leave at least
    # two samples, because a single sample has zero variance.
    if (maxeval is None) or (T <= maxeval):
        origins = list(range(0, T-1))
    else:
        # Spread maxeval origins evenly over the timeseries.  Integer arithmetic
        # on (i * T) avoids flooring T/maxeval under Python 2 division, which
        # would confine every origin to the first maxeval samples.
        origins = [(i * T) // maxeval for i in range(0, maxeval)]
    norigins = len(origins)

    g_t = numpy.ones([norigins], numpy.float64)
    Neff_t = numpy.ones([norigins], numpy.float64)
    for (origin_index, origin) in enumerate(origins):
        g_t[origin_index] = timeseries.statisticalInefficiency(A_t[origin:T], fast=True)
        # A_t[origin:T] contains exactly T - origin samples.
        Neff_t[origin_index] = (T - origin) / g_t[origin_index]

    Neff_max = Neff_t.max()
    origins_index = Neff_t.argmax()
    t = origins[origins_index]
    g = g_t[origins_index]

    return (t, g, Neff_max)

def vector_length(vector):
    """
    Return the Euclidean length of a unit-bearing vector.

    ARGUMENTS

    vector (simtk.unit.Quantity of numpy) - vector to measure

    RETURNS

    length (simtk.unit.Quantity) - square root of the sum of the squared
        components, expressed in the same unit as the input

    """
    unit = vector.unit
    # Strip the unit so numpy works on bare numbers, then reattach it at the end.
    bare = vector / unit
    squared_sum = numpy.sum(bare ** 2)
    return numpy.sqrt(squared_sum) * unit

#=============================================================================================
# MAIN
#=============================================================================================

# TODO: Convert this to use optparse to take command-line arguments.

# Name of Stark shift output NetCDF file to analyze.
filename = 'stark.nc'
shift_unit = units.dimensionless / units.centimeter # TODO: Have these units automatically read from NetCDF file.

# Open NetCDF file for reading.
print("Opening NetCDF trajectory file '%(filename)s' for reading..." % vars())
ncfile = netcdf.Dataset(filename, 'r')
try:
    # DEBUG: list the dimensions stored in the file.
    print("dimensions:")
    for (dimension_name, dimension) in ncfile.dimensions.items():
        print("%16s %8d" % (dimension_name, len(dimension)))

    # Extract computed Stark shift.  Slicing the variable with [:] reads the
    # data into an in-memory array, so the file is no longer needed afterwards.
    print("Extracting Stark shift trajectory...")
    shift_t = ncfile.variables['shift'][:]
finally:
    # Close the file even if reading it fails.
    ncfile.close()
print(shift_t)

# Detect equilibrated region.
print("Detecting equilibrated region...")
(t, g, Neff_max) = detect_equilibration(shift_t)
print("Automated equilibration detection: t = %.1f, g = %.1f, Neff_max = %.1f" % (t, g, Neff_max))

# Discard non-equilibrated data.
shift_t = shift_t[t:]
T = len(shift_t)

# TODO: Write statistics.
std_shift = shift_t.std()
Eshift = shift_t.mean() # mean
# Standard error of the mean, using T/g effectively uncorrelated samples.
dEshift = std_shift / numpy.sqrt(T/g)
print("Average Stark shift: %.3f +- %.3f (1/cm)" % (Eshift, dEshift))
print("stddev Stark shift: %.3f" % (std_shift))

# Write out equilibrated data, one value per line.  A separate loop variable
# keeps the equilibration index 't' intact.
with open('shift.out', 'w') as outfile:
    for value in shift_t:
        outfile.write('%24.8f\n' % value)


