#!/usr/local/bin/env python

#=============================================================================================
# MODULE DOCSTRING
#=============================================================================================

"""
Analyze alanine dipeptide 2D PMF via replica exchange.

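DESCRIPTION

This script reads the NetCDF file written by a replica-exchange simulation
(variables 'positions', 'states' and 'energies') and prints replica mixing
statistics, the statistical inefficiencies of the state index, of the summed
replica energies, and of cos/sin of the phi and psi torsions, and the
relaxation times for transitions among phi or psi bins.
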
COPYRIGHT

@author John D. Chodera <jchodera@gmail.com>

This source file is released under the GNU General Public License.

This program is free software: you can redistribute it and/or modify it under
the terms of the GNU General Public License as published by the Free Software
Foundation, either version 3 of the License, or (at your option) any later
version.

This program is distributed in the hope that it will be useful, but WITHOUT ANY
WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A
PARTICULAR PURPOSE.  See the GNU General Public License for more details.
 
You should have received a copy of the GNU General Public License along with
this program.  If not, see <http://www.gnu.org/licenses/>.

"""

#=============================================================================================
# GLOBAL IMPORTS
#=============================================================================================

import scipy.optimize # THIS MUST BE IMPORTED FIRST?!

import os
import os.path

import numpy
import math
import time

import simtk.unit as units
import simtk.chem.openmm as openmm
import simtk.chem.openmm.extras.amber as amber
import simtk.chem.openmm.extras.optimize as optimize

import netCDF4 as netcdf # netcdf4-python is used in place of scipy.io.netcdf for now

import timeseries

#=============================================================================================
# SOURCE CONTROL
#=============================================================================================

__version__ = "$Id: $"

#=============================================================================================
# SUBROUTINES
#=============================================================================================

def compute_torsion(coordinates, i, j, k, l):
    """
    Return the signed torsion (dihedral) angle defined by four atoms.

    ARGUMENTS

    coordinates (simtk.unit.Quantity wrapping numpy natoms x 3) - atomic coordinates
    i, j, k, l - indices of the four atoms defining the torsion angle

    RETURNS

    theta (simtk.unit.Quantity in radians) - the torsion angle; the magnitude
        comes from arccos and the sign from the orientation test below

    NOTES

    Algorithm of Swope is used.
    Diagnostics are printed (and the NaN angle is still returned) if the
    cosine or the angle evaluates to NaN.

    """
    def unit_normal(a, b):
        # Normalized cross product a x b.
        normal = numpy.cross(a, b)
        return normal / numpy.sqrt(numpy.dot(normal, normal))

    # Bond vectors as plain (unitless) arrays, expressed in Angstroms.
    bond_ji = (coordinates[i,:] - coordinates[j,:]) / units.angstroms
    bond_jk = (coordinates[k,:] - coordinates[j,:]) / units.angstroms
    bond_kj = (coordinates[j,:] - coordinates[k,:]) / units.angstroms
    bond_kl = (coordinates[l,:] - coordinates[k,:]) / units.angstroms

    # Unit normals of the planes (i,j,k) and (j,k,l).
    n1 = unit_normal(bond_ji, bond_jk)
    n2 = unit_normal(bond_kl, bond_kj)

    # Cosine of the angle between the planes, clamped to [-1, +1] so that
    # round-off cannot push it outside the domain of arccos.
    cos_theta = numpy.dot(n1, n2)
    if (abs(cos_theta) > 1.0):
        cos_theta = 1.0 * numpy.sign(cos_theta)
    if math.isnan(cos_theta):
        print("cos_theta is NaN")

    angle = numpy.arccos(cos_theta)
    if math.isnan(angle):
        # Dump everything needed to diagnose the degenerate geometry.
        print("arccos(cos_theta) is NaN")
        print("cos_theta = %f" % cos_theta)
        for atom in (i, j, k, l):
            print(coordinates[atom,:])
        print("n1")
        print(n1)
        print("n2")
        print(n2)
    theta = angle * units.radians

    # The sign is fixed by the orientation of n1 x n2 relative to the j->k bond.
    flip_sign = numpy.dot(bond_jk, numpy.cross(n1, n2)) < 0.0
    return -theta if flip_sign else theta

def show_mixing_statistics(ncfile, show_transition_matrix=False):
    """
    Print summary of mixing statistics.

    ARGUMENTS

    ncfile - open dataset exposing ncfile.variables['states'], a 2D integer
        array indexed as [iteration, replica] that holds the state index
        occupied by each replica at each iteration

    OPTIONAL ARGUMENTS

    show_transition_matrix (boolean) - if True, also print the empirical
        state-to-state transition matrix (default: False)

    NOTES

    Transition counts are symmetrized (each observed i -> j transition adds
    0.5 to both N[i,j] and N[j,i]) and row-normalized to give T.  Because N is
    symmetric, T has the same eigenvalues as the symmetric matrix
    N[i,j] / sqrt(n_i n_j), where n_i are the row sums of N; the eigenvalues
    are computed from that form so that they are guaranteed to be real.
    The value printed as "Perron eigenvalue" is the second-largest eigenvalue.

    If fewer than two states have any recorded transition (for example when
    the file holds a single iteration), a message is printed in place of the
    timescale estimate.

    """

    print("Computing mixing statistics...")

    # Read the state indices from the file once; the dimensions follow from the array.
    states = ncfile.variables['states'][:,:].copy()
    [niterations, nstates] = states.shape

    # Compute statistics of transitions.
    Nij = numpy.zeros([nstates,nstates], numpy.float64)
    for iteration in range(niterations-1):
        for ireplica in range(nstates):
            istate = states[iteration,ireplica]
            jstate = states[iteration+1,ireplica]
            Nij[istate,jstate] += 0.5
            Nij[jstate,istate] += 0.5

    # Row-normalize; rows without any counts are left as zeros instead of 0/0 = NaN.
    row_totals = Nij.sum(axis=1)
    Tij = numpy.zeros([nstates,nstates], numpy.float64)
    for istate in range(nstates):
        if row_totals[istate] > 0.0:
            Tij[istate,:] = Nij[istate,:] / row_totals[istate]

    if show_transition_matrix:
        # Print observed transition probabilities.
        PRINT_CUTOFF = 0.001 # Cutoff for displaying fraction of accepted swaps.
        print("Cumulative symmetrized state mixing transition matrix:")
        # Each line is assembled from space-separated fields and ends with an
        # empty field, i.e. with one trailing space.
        header = ["%6s" % ""] + ["%6d" % jstate for jstate in range(nstates)] + [""]
        print(" ".join(header))
        for istate in range(nstates):
            fields = ["%-6d" % istate]
            for jstate in range(nstates):
                P = Tij[istate,jstate]
                if (P >= PRINT_CUTOFF):
                    fields.append("%6.3f" % P)
                else:
                    fields.append("%6s" % "")
            fields.append("")
            print(" ".join(fields))

    # Estimate second eigenvalue and equilibration time.
    visited = numpy.flatnonzero(row_totals > 0.0)
    if len(visited) < 2:
        print("Too few states with observed transitions to estimate the state equilibration timescale.")
        return
    scale = 1.0 / numpy.sqrt(row_totals[visited])
    Sij = Nij[numpy.ix_(visited, visited)] * numpy.outer(scale, scale)
    mu = numpy.linalg.eigvalsh(Sij)[::-1] # eigvalsh returns ascending order; reverse to descending
    if (mu[1] >= 1):
        print("Perron eigenvalue is unity; Markov chain is decomposable.")
    else:
        print("Perron eigenvalue is %9.5f; state equilibration timescale is ~ %.1f iterations" % (mu[1], 1.0 / (1.0 - mu[1])))

    return

def compute_relaxation_time(bin_it, nbins):
    """
    Compute relaxation time from empirical transition matrix of binned coordinate trajectories.

    ARGUMENTS

    bin_it (2D integer array, ntrajectories x niterations) - bin_it[i,t] is the
        bin index occupied by trajectory i at iteration t
    nbins (int) - total number of bins; entries of bin_it are used as indices
        into an nbins x nbins count matrix

    RETURNS

    tau (float) - 1 / (1 - mu_2) in units of iterations, where mu_2 is the
        second-largest eigenvalue of the row-normalized, symmetrized
        transition matrix among the bins that were actually visited

    RAISES

    ValueError if fewer than two bins have any recorded transition, in which
        case there is no second eigenvalue to estimate a relaxation time from.

    NOTES

    Each observed i -> j transition adds 0.5 to both N[i,j] and N[j,i], so N
    is symmetric.  Bins that were never visited have an all-zero row and
    column; they are removed before normalization, since 0/0 would otherwise
    fill the transition matrix with NaN.  The row-normalized matrix T has the
    same eigenvalues as the symmetric matrix N[i,j] / sqrt(n_i n_j), where
    n_i are the row sums of N; the eigenvalues are computed from that form so
    that they are guaranteed to be real.

    """

    [nstates, niterations] = bin_it.shape

    # Compute statistics of transitions.
    Nij = numpy.zeros([nbins,nbins], numpy.float64)
    for ireplica in range(nstates):
        for iteration in range(niterations-1):
            ibin = bin_it[ireplica, iteration]
            jbin = bin_it[ireplica, iteration+1]
            Nij[ibin,jbin] += 0.5
            Nij[jbin,ibin] += 0.5

    # Keep only bins with at least one recorded transition.
    row_totals = Nij.sum(axis=1)
    visited = numpy.flatnonzero(row_totals > 0.0)
    if len(visited) < 2:
        raise ValueError("Need transitions among at least two bins to estimate a relaxation time; found %d visited bin(s)." % len(visited))

    # Symmetric matrix sharing its spectrum with the row-normalized transition matrix.
    scale = 1.0 / numpy.sqrt(row_totals[visited])
    Sij = Nij[numpy.ix_(visited, visited)] * numpy.outer(scale, scale)

    # eigvalsh returns real eigenvalues in ascending order, so the
    # second-largest one is the next-to-last entry.
    mu = numpy.linalg.eigvalsh(Sij)
    tau = 1.0 / (1.0 - mu[-2])

    return tau

#=============================================================================================
# MAIN AND TESTS
#=============================================================================================
    
if __name__ == "__main__":

    # NetCDF file to analyze (opened read-only below).
    # Alternative input previously selected here: '2d-pmf-neighbor-swap.nc'
    store_filename  = '2d-pmf-all-swap.nc'

    temperature = 300.0 * units.kelvin # temperature (not used by the analysis below)

    # Open NetCDF file.
    # NOTE(review): 'version' looks like a leftover from scipy.io.netcdf -- confirm netCDF4.Dataset accepts it.
    ncfile = netcdf.Dataset(store_filename, 'r', version=2)
    try:
        # Get dimensions from the variable's shape, without reading the whole trajectory.
        [niterations, nstates, natoms, ndim] = ncfile.variables['positions'].shape
        print("%d iterations, %d states, %d atoms" % (niterations, nstates, natoms))

        # Print summary statistics about mixing in state space.
        show_mixing_statistics(ncfile)

        # Compute statistical inefficiency of state index.
        states = ncfile.variables['states'][:,:].copy()
        A_kn = [ states[:,k].copy() for k in range(nstates) ]
        g_states = timeseries.statisticalInefficiencyMultiple(A_kn)
        print("g_states = %.1f iterations" % g_states)
        del A_kn

        # Compute statistical inefficiency for reduced potential:
        # at each iteration, sum over replicas the energy of each replica in the state it occupies.
        energies = ncfile.variables['energies'][:,:,:].copy()
        u_n = numpy.zeros([niterations], numpy.float64)
        for iteration in range(niterations):
            for replica in range(nstates):
                state = states[iteration,replica]
                u_n[iteration] += energies[iteration,replica,state]
        del energies, states
        g_u = timeseries.statisticalInefficiency(u_n)
        print("g_u = %8.1f iterations" % g_u)

        # Compute x and y umbrellas.
        print("Computing torsions...")
        positions = ncfile.variables['positions'][:,:,:,:]
    finally:
        # Everything needed from the file is in memory now (or an error occurred): release the file.
        ncfile.close()

    # Compute the phi and psi torsions of every replica at every iteration.
    coordinates = units.Quantity(numpy.zeros([natoms,ndim], numpy.float32), units.angstroms)
    phi_it = units.Quantity(numpy.zeros([nstates,niterations], numpy.float32), units.radians)
    psi_it = units.Quantity(numpy.zeros([nstates,niterations], numpy.float32), units.radians)
    for iteration in range(niterations):
        for replica in range(nstates):
            coordinates[:,:] = units.Quantity(positions[iteration,replica,:,:].copy(), units.angstroms)
            phi_it[replica,iteration] = compute_torsion(coordinates, 4, 6, 8, 14)
            psi_it[replica,iteration] = compute_torsion(coordinates, 6, 8, 14, 16)

    # Compute statistical inefficiencies of various functions of the timeseries data.
    # NOTE(review): replica 0 is excluded here (range starts at 1) although it is
    # included in the relaxation-time analysis below -- confirm this is intentional.
    print("Computing statistical infficiencies of cos(phi), sin(phi), cos(psi), sin(psi)...")
    cosphi_kn = [ numpy.cos(phi_it[replica,:] / units.radians).copy() for replica in range(1,nstates) ]
    sinphi_kn = [ numpy.sin(phi_it[replica,:] / units.radians).copy() for replica in range(1,nstates) ]
    cospsi_kn = [ numpy.cos(psi_it[replica,:] / units.radians).copy() for replica in range(1,nstates) ]
    sinpsi_kn = [ numpy.sin(psi_it[replica,:] / units.radians).copy() for replica in range(1,nstates) ]
    g_cosphi = timeseries.statisticalInefficiencyMultiple(cosphi_kn)
    g_sinphi = timeseries.statisticalInefficiencyMultiple(sinphi_kn)
    g_cospsi = timeseries.statisticalInefficiencyMultiple(cospsi_kn)
    g_sinpsi = timeseries.statisticalInefficiencyMultiple(sinpsi_kn)
    del cosphi_kn, sinphi_kn, cospsi_kn, sinpsi_kn
    print("g_cosphi = %8.1f iterations" % g_cosphi)
    print("g_sinphi = %8.1f iterations" % g_sinphi)
    print("g_cospsi = %8.1f iterations" % g_cospsi)
    print("g_sinpsi = %8.1f iterations" % g_sinpsi)

    # Compute relaxation times in each torsion.
    print("Relaxation times for transitions among phi or psi bins alone:")
    nbins = 50 # number of bins per torsion
    # The bin width is made slightly larger than 360/nbins so that an angle of
    # exactly +180 degrees falls in bin nbins-1 rather than in bin nbins.
    delta = 360.0 / (nbins - 0.01)
    phibin_it = ((phi_it / units.degrees + 180.0) / delta).astype(numpy.int16)
    tau_phi = compute_relaxation_time(phibin_it, nbins)
    psibin_it = ((psi_it / units.degrees + 180.0) / delta).astype(numpy.int16)
    tau_psi = compute_relaxation_time(psibin_it, nbins)
    print("tau_phi = %8.1f iteration" % tau_phi)
    print("tau_psi = %8.1f iteration" % tau_psi)

