#!/usr/local/bin/env python

#=============================================================================================
# MODULE DOCSTRING
#=============================================================================================

"""
Analyze WCA dimer in dense WCA solvent GHMC simulation.

DESCRIPTION

COPYRIGHT

@author John D. Chodera <jchodera@gmail.com>

All code in this repository is released under the GNU General Public License.

This program is free software: you can redistribute it and/or modify it under
the terms of the GNU General Public License as published by the Free Software
Foundation, either version 3 of the License, or (at your option) any later
version.

This program is distributed in the hope that it will be useful, but WITHOUT ANY
WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A
PARTICULAR PURPOSE.  See the GNU General Public License for more details.
 
You should have received a copy of the GNU General Public License along with
this program.  If not, see <http://www.gnu.org/licenses/>.

TODO

"""

#=============================================================================================
# GLOBAL IMPORTS
#=============================================================================================

import os
import os.path
import sys
import math
import copy
import time

import numpy

import simtk.unit as units
    
#import Scientific.IO.NetCDF as netcdf # for netcdf interface in Scientific
import netCDF4 as netcdf # for netcdf interface provided by netCDF4 in enthought python

#=============================================================================================
# SUBROUTINES
#=============================================================================================

def logsum(a_n):
    """
    Compute log(sum(exp(a_n))) in a numerically stable manner.

    ARGUMENTS

    a_n (numpy array or other array-like of numbers) - the exponents to be combined;
        must contain at least one element (numpy.max raises ValueError otherwise)

    RETURNS

    log(sum(exp(a_n))) as a numpy scalar.  If the largest entry is not finite
    (-inf, +inf or nan) that entry is returned directly.

    """

    # asarray avoids an unnecessary copy (the input is never modified) and lets
    # plain Python sequences be passed as well as numpy arrays.
    a_n = numpy.asarray(a_n)
    max_arg = numpy.max(a_n)

    # Subtracting an infinite maximum would evaluate inf - inf = nan, so the
    # limiting value is returned instead: -inf when every term is exp(-inf) = 0,
    # +inf when some term is infinite, nan when the input contains nan.
    if not numpy.isfinite(max_arg):
        return max_arg

    # Factor out the largest term so that no exponential can overflow.
    return numpy.log(numpy.sum(numpy.exp(a_n - max_arg))) + max_arg

def eliminate_nans(u_n):
    """
    Return only the entries of u_n that are not nan.

    ARGUMENTS

    u_n (numpy array) - values that may include nan entries

    RETURNS

    A new array holding the non-nan entries of u_n in their original order
    (boolean-mask indexing, so the input itself is left untouched).

    """
    is_a_number = numpy.logical_not(numpy.isnan(u_n))
    return u_n[is_a_number]

def log_mean_metropolis(u_n):
    """
    Compute the log of the mean Metropolis acceptance probability min(1, exp(-u_n)).

    ARGUMENTS

    u_n (numpy array) - values whose negation is the log acceptance ratio;
        nan entries are discarded before averaging

    RETURNS

    log of the average of min(1, exp(-u_n)) over the non-nan entries.

    """

    valid_u_n = eliminate_nans(u_n)

    # log(min(1, exp(-u))) equals min(0, -u): cap the log ratio at zero.
    log_ratio = - valid_u_n
    capped_log_ratio = numpy.where(log_ratio > 0.0, 0.0, log_ratio)

    # log(mean) = log(sum) - log(number of samples), with the sum taken in log space.
    nsamples = float(valid_u_n.size)
    return logsum(capped_log_ratio) - numpy.log(nsamples)

def mean_metropolis(u_n):
    """
    Compute the mean Metropolis acceptance probability and its standard error.

    ARGUMENTS

    u_n (numpy array) - values whose negation is the log acceptance ratio;
        nan entries are discarded before averaging

    RETURNS

    [mean_acceptance, mean_acceptance_error] - the average of min(1, exp(-u_n))
    over the non-nan entries, and the standard deviation of those probabilities
    divided by the square root of the number of entries (samples are treated
    as uncorrelated).

    """

    # TODO: Include Jacobian.
    # TODO: Include statistical inefficiency in estimate of error.

    valid_u_n = eliminate_nans(u_n)

    # min(1, exp(-u)) is evaluated as exp(min(0, -u)).
    log_ratio = - valid_u_n
    p_accept = numpy.exp(numpy.where(log_ratio > 0.0, 0.0, log_ratio))

    nsamples = p_accept.size
    standard_error = p_accept.std() / numpy.sqrt(nsamples)
    return [p_accept.mean(), standard_error]

def _debug_row(cells):
    """
    Join preformatted cells into one line of debug output.

    Cells are separated by single spaces and the line ends with one trailing
    space, matching what the Python 2 trailing-comma print idiom produced.

    """
    return " ".join(cells) + " "

def compute_statistics(nsteps_to_try, log_Paccept, debug=False):
    """
    Compute acceptance and relative-efficiency statistics for each row of log_Paccept.

    ARGUMENTS

    nsteps_to_try (sequence of K numbers) - step count associated with each row
    log_Paccept (K x niterations numpy array) - log acceptance probabilities;
        values above zero are clamped to zero (probability one)

    OPTIONAL ARGUMENTS

    debug (boolean) - if True, tables of the computed quantities are printed

    RETURNS

    statistics (dict) - arrays of length K stored under the keys
        'log_mean_acceptance' - log of the mean of min(1, exp(log_Paccept[k,:]))
        'log_mean_efficiency' - log_mean_acceptance[k] - log(nsteps_to_try[k] + 1) - log_mean_acceptance[0]
        'mean_acceptance'     - mean of min(1, exp(log_Paccept[k,:]))

    NOTES

    Row 0 serves as the reference for the relative efficiency.
    Printing uses the single-argument print(...) form, which behaves the same
    under Python 2 and Python 3.

    """
    # Get dimensions.
    [K, niterations] = log_Paccept.shape
    if debug: print("%d iterations read" % niterations)

    # Log mean acceptance
    mean_acceptance = numpy.zeros([K], numpy.float64)
    mean_acceptance_error = numpy.zeros([K], numpy.float64)
    log_mean_acceptance = numpy.zeros([K], numpy.float64)
    log_mean_acceptance_error = numpy.zeros([K], numpy.float64)
    for k in range(K):
        # Clamp before exponentiating: min(1, exp(x)) == exp(min(0, x)), and the
        # clamped form cannot overflow for large positive x.
        x_n = numpy.minimum(numpy.squeeze(log_Paccept[k,0:niterations]), 0.0)

        max_arg = x_n.max()
        if numpy.isfinite(max_arg):
            # Stable log-mean-exp: factor out the largest term.
            log_mean_acceptance[k] = numpy.log(numpy.mean(numpy.exp(x_n - max_arg))) + max_arg
        else:
            # Every sample is -inf (or nan is present); subtracting would give inf - inf.
            log_mean_acceptance[k] = max_arg

        p_n = numpy.exp(x_n)
        mean_acceptance[k] = numpy.mean(p_n)
        mean_acceptance_error[k] = numpy.std(p_n) / numpy.sqrt(niterations)
        if mean_acceptance[k] > 0.0:
            # First-order error propagation: d(log x) = dx / x.
            log_mean_acceptance_error[k] = (1.0/mean_acceptance[k]) * mean_acceptance_error[k]
        else:
            # Mean underflowed to zero (or is nan): the relative error is undefined.
            log_mean_acceptance_error[k] = numpy.nan
    log_instantaneous_acceptance = log_mean_acceptance[0]

    blank_cell = "%8s" % ""

    if debug:
        print("LOG MEAN ACCEPTANCE")
        print(_debug_row([blank_cell] + ["%8d" % (nsteps_to_try[k]) for k in range(K)]))
        print(_debug_row([blank_cell] + ["%8.1f" % (log_mean_acceptance[k]) for k in range(K)]))
        print("LOG MEAN ACCEPTANCE ERROR")
        print(_debug_row(["%8.3f" % (log_mean_acceptance_error[k]) for k in range(K)]))

    # Log relative efficiency.
    log_mean_efficiency = numpy.zeros([K], numpy.float64)
    for k in range(K):
        nsteps = nsteps_to_try[k]
        log_mean_efficiency[k] = log_mean_acceptance[k] - numpy.log(nsteps+1) - log_instantaneous_acceptance

    if debug:
        print("LOG RELATIVE EFFICIENCY")
        print(_debug_row([blank_cell] + ["%8d" % (nsteps_to_try[k]) for k in range(K)]))
        print(_debug_row([blank_cell] + ["%8.1f" % (log_mean_efficiency[k]) for k in range(K)]))

    # Log10 mean acceptance.
    if debug:
        log_of_ten = math.log(10.0)

        print("LOG10 MEAN ACCEPTANCE")
        for k in range(K):
            print("%8d %12.3f +- %12.3f" % (nsteps_to_try[k], log_mean_acceptance[k] / log_of_ten, log_mean_acceptance_error[k] / log_of_ten))
        print("")
        print("")

        print("LOG10 RELATIVE_EFFICIENCY")
        for k in range(K):
            print("%8d %12.3f" % (nsteps_to_try[k], log_mean_efficiency[k] / log_of_ten))
        print("")
        print("")

    # Store statistics.
    statistics = dict()

    statistics['log_mean_acceptance'] = log_mean_acceptance
    statistics['log_mean_efficiency'] = log_mean_efficiency
    statistics['mean_acceptance'] = mean_acceptance

    return statistics

def compute_confidence_interval(x_n, interval):
    """
    Compute a central percentile interval from a set of samples.

    ARGUMENTS

    x_n (numpy array or other one-dimensional array-like) - the samples; not modified
    interval (float) - fraction of the samples the interval should span, e.g. 0.95

    RETURNS

    (lower, upper) - the sorted samples found at the rounded positions
        (N-1) * (0.5 - interval/2) and (N-1) * (0.5 + interval/2)

    """
    # numpy.sort returns a sorted copy, leaving the caller's data untouched.
    sorted_x_n = numpy.sort(numpy.asarray(x_n))

    # Take the length from the array built above rather than from the argument,
    # so that plain Python sequences (which have no .size) are accepted too.
    N = sorted_x_n.size

    # Compute indices.
    lower_index = int(numpy.round(float(N-1) * (0.5 - interval/2.0)))
    upper_index = int(numpy.round(float(N-1) * (0.5 + interval/2.0)))

    # Return interval.
    return (sorted_x_n[lower_index], sorted_x_n[upper_index])

#=============================================================================================
# MAIN AND TESTS
#=============================================================================================

if __name__ == "__main__":
    # Script entry point: read the stored statistics, compute acceptance and
    # relative-efficiency estimates with bootstrap confidence intervals, and
    # print the results as Matlab assignments.
    #
    # All output goes through the single-argument print(...) form, which behaves
    # the same under Python 2 and Python 3.

    filename = 'data/ncmc-statistics.nc'
    nbootstrap = 1000

    def format_matlab_row(name, fmt, values):
        # Render "name = [ v1 v2 ... ];" with every token separated by one space.
        cells = [fmt % value for value in values]
        return " ".join(["%s = [" % name] + cells + ["];"])

    # Open NetCDF file for reading
    # ncfile = netcdf.NetCDFFile(filename, 'r') # for Scientific.IO.NetCDF
    ncfile = netcdf.Dataset(filename, 'r') # for netCDF4

    try:
        # Load data.
        # NOTE: 'work' and 'lechner_work' are read but not used below; 'heat' only supplies the dimensions.
        heat = ncfile.variables['heat'][:,:].copy()
        work = ncfile.variables['work'][:,:].copy()
        lechner_work = ncfile.variables['lechner_work'][:,:].copy()
        log_Paccept = ncfile.variables['log_Paccept'][:,:].copy()
        nsteps_to_try = ncfile.variables['nsteps_to_try'][:].copy()
    finally:
        # Close NetCDF, even if one of the variables could not be read.
        ncfile.close()

    # Correct nan log acceptance probabilities to LOG_ZERO.
    LOG_ZERO = -1000
    log_Paccept[numpy.isnan(log_Paccept)] = LOG_ZERO

    # Get dimensions.
    [K, niterations] = heat.shape
    #niterations -= 1 # subtract one incomplete iteration
    print("%d iterations read" % niterations)

    # Compute maximum-likelihood statistics.
    statistics = compute_statistics(nsteps_to_try, log_Paccept)
    log_mean_acceptance = statistics['log_mean_acceptance']
    log_mean_efficiency = statistics['log_mean_efficiency']
    mean_acceptance = statistics['mean_acceptance']

    # Bootstrap iterations.
    log_mean_acceptance_trials = numpy.zeros([nbootstrap, K], numpy.float64)
    log_mean_efficiency_trials = numpy.zeros([nbootstrap, K], numpy.float64)
    mean_acceptance_trials = numpy.zeros([nbootstrap, K], numpy.float64)
    for trial in range(nbootstrap):
        print("Bootstrap iteration %d / %d" % (trial, nbootstrap))

        # Resample iterations (columns) with replacement.
        indices = numpy.random.randint(niterations, size=[niterations])

        # Compute statistics of bootstrap sample.
        statistics = compute_statistics(nsteps_to_try, log_Paccept[:,indices])

        # Store data.
        log_mean_acceptance_trials[trial,:] = statistics['log_mean_acceptance']
        log_mean_efficiency_trials[trial,:] = statistics['log_mean_efficiency']
        mean_acceptance_trials[trial,:] = statistics['mean_acceptance']

    # Compute 95% confidence interval.
    interval = 0.95
    log_mean_acceptance_lower = numpy.zeros([K], numpy.float64)
    log_mean_acceptance_upper = numpy.zeros([K], numpy.float64)
    log_mean_efficiency_lower = numpy.zeros([K], numpy.float64)
    log_mean_efficiency_upper = numpy.zeros([K], numpy.float64)
    mean_acceptance_lower = numpy.zeros([K], numpy.float64)
    mean_acceptance_upper = numpy.zeros([K], numpy.float64)
    for k in range(K):
        [lower, upper] = compute_confidence_interval(log_mean_acceptance_trials[:,k], interval)
        log_mean_acceptance_lower[k] = lower
        log_mean_acceptance_upper[k] = upper

        [lower, upper] = compute_confidence_interval(log_mean_efficiency_trials[:,k], interval)
        log_mean_efficiency_lower[k] = lower
        log_mean_efficiency_upper[k] = upper

        [lower, upper] = compute_confidence_interval(mean_acceptance_trials[:,k], interval)
        mean_acceptance_lower[k] = lower
        mean_acceptance_upper[k] = upper

    # Generate Matlab-formatted output.
    # Row 0 is reported separately as the instantaneous value, so the arrays start at row 1.
    print("%% %d iterations" % niterations)
    print("log_instantaneous_acceptance = %6.2f;" % log_mean_acceptance[0])
    print(format_matlab_row("nsteps_list", "%6d", nsteps_to_try[1:K]))

    print(format_matlab_row("log_mean_acceptance", "%6.2f", log_mean_acceptance[1:K]))
    print(format_matlab_row("log_mean_acceptance_lower", "%6.2f", log_mean_acceptance_lower[1:K]))
    print(format_matlab_row("log_mean_acceptance_upper", "%6.2f", log_mean_acceptance_upper[1:K]))

    print(format_matlab_row("log_relative_efficiency", "%6.2f", log_mean_efficiency[1:K]))
    print(format_matlab_row("log_relative_efficiency_lower", "%6.2f", log_mean_efficiency_lower[1:K]))
    print(format_matlab_row("log_relative_efficiency_upper", "%6.2f", log_mean_efficiency_upper[1:K]))

