#!/usr/bin/env python

#=============================================================================================
# Analyze electric field history to compute Stark effect.
#=============================================================================================

#=============================================================================================
# REQUIREMENTS
#
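# The netcdf4-python module is used to provide netCDF v4 support:
# https://github.com/Unidata/netcdf4-python
#
# This requires NetCDF version 4 with multithreading support, as well as HDF5.
#
# Other imports used below: numpy, pymbar, a `timeseries` module (statistical inefficiency),
# simtk.unit, matplotlib (comparison plot), and the local `stark_mutations` module (mutant list).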
#=============================================================================================

#=============================================================================================
# TODO
#=============================================================================================

#=============================================================================================
# CHANGELOG
#=============================================================================================

#=============================================================================================
# VERSION CONTROL INFORMATION
#=============================================================================================

#=============================================================================================
# IMPORTS
#=============================================================================================

import os
import os.path
import sys
import math

import numpy

import netCDF4 as netcdf # netcdf4-python

from pymbar import MBAR # multistate Bennett acceptance ratio
import timeseries # for statistical inefficiency analysis

import simtk.unit as units

#=============================================================================================
# PARAMETERS
#=============================================================================================

# Molar Boltzmann constant (kB * N_A, i.e. the gas constant, carrying simtk units).
# NOTE(review): not referenced anywhere else in this file as shown — confirm before relying on it.
kB = units.BOLTZMANN_CONSTANT_kB * units.AVOGADRO_CONSTANT_NA

#=============================================================================================
# REFERENCE DATA
#=============================================================================================

# Measured Stark peak positions (cm^-1) per mutant; plotted against the computed projected
# efield in the comparison charts. Lookups are exact dict lookups on the mutant name
# (whitespace-separated mutations, or 'WT'), so 'A403S T338M' would NOT match 'T338M A403S'.
# TODO: Make sure mutants are matched irrespective of the order in which mutations are listed.
reference_stark_peaks = {
    'T338M A403S' : 2213.1, 
    'T338M A403T V323T' : 2213.1,
    'T338M A403T' : 2213.4,
    'T338L' : 2213.7,
    'T338M M314L' : 2213.8,
    'T338M V323T' : 2213.9,
    'T338M' : 2214.6,
    'V323C A403T' : 2215.7,
    'T338N' : 2215.8,
    'T338M A403C' : 2215.8,
    'T338I' : 2215.9,
    'T338F' : 2216,
    'V323S A403T' : 2216.9,
    'T338Q' : 2217,
    'A403T' : 2217.3,
    'M314L L325Y T338M' : 2217.8,
    'T338V' : 2217.9,
    'V323S' : 2218.4,
    'V323C' : 2219.2,
    'V323T' : 2219.8,
    'A403V' : 2220,
    'A403C' : 2221.5,
    'M341H' : 2222.5,
    'A403S' : 2223.2,
    'V281C' : 2223.3,
    'WT' : 2223.6,
    'M341C' : 2224.4,
    'L325Y' : 2225.4,
}

#=============================================================================================
# SUBROUTINES
#=============================================================================================
def write_file(filename, contents):
    """Write the specified contents to a file.

    ARGUMENTS
      filename (string) - the file to be written (created, or truncated if it exists)
      contents (string, or list/tuple of strings) - the contents of the file to be written;
        sequence elements are written back-to-back with no separator added

    RAISES
      TypeError - if contents is neither a string nor a list/tuple of strings; the file is
        left untouched in that case

    """

    # Validate before opening, so an unsupported type does not truncate an existing file.
    if isinstance(contents, (list, tuple)):
        chunks = contents
    elif isinstance(contents, str):
        chunks = [contents]
    else:
        # A bare string is not a raisable exception (Python >= 2.6); use a real exception class.
        raise TypeError("Type for 'contents' not supported: " + repr(type(contents)))

    # The context manager closes the file even if a write fails.
    with open(filename, 'w') as outfile:
        for chunk in chunks:
            outfile.write(chunk)

    return

def read_file(filename):
    """Read contents of the specified file.

    ARGUMENTS
      filename (string) - the name of the file to be read

    RETURNS
      lines (list of strings) - the contents of the file, split by line (line terminators kept)
    """

    # The context manager closes the handle even if reading raises.
    with open(filename, 'r') as infile:
        return infile.readlines()

def detect_equilibration(A_t, maxeval=100):
    """
    Automatically detect equilibrated region.

    For each candidate origin t, the data before t is discarded and the number of effectively
    uncorrelated samples in A_t[t:] is estimated as (number of samples) / (statistical inefficiency).
    The origin that maximizes this estimate is returned.

    ARGUMENTS

    A_t (numpy.array) - timeseries

    OPTIONAL ARGUMENTS 

    maxeval (int) - maximum number of origins to evaluate, or no limit if None

    RETURNS

    t (int) - start of equilibrated data
    g (float) - statistical inefficiency of equilibrated data
    Neff_max (float) - number of uncorrelated samples   

    RAISES

    ValueError - if A_t is empty

    """
    T = A_t.size

    if T == 0:
        raise ValueError("Cannot detect equilibration of an empty timeseries.")

    # Special case if timeseries is constant.
    if A_t.std() == 0.0:
        return (0, 1, T)

    # Determine list of origins to evaluate. Every origin must leave at least two samples
    # for the statistical inefficiency estimate, i.e. origin < T-1.
    if (maxeval is None) or (T < maxeval):
        origins = list(range(0, T-1))
    else:
        # float() spreads the origins evenly over the whole series; with Python 2 integer
        # division, T/maxeval truncates and the origins covered only the start of the data.
        origins = [int(i * T / float(maxeval)) for i in range(0, maxeval)]
        origins = [origin for origin in origins if origin < T-1]
    norigins = len(origins)

    g_t = numpy.ones([norigins], numpy.float32)
    Neff_t = numpy.ones([norigins], numpy.float32)
    for (origin_index, origin) in enumerate(origins):
        g_t[origin_index] = timeseries.statisticalInefficiency(A_t[origin:T], fast=True)
        # A_t[origin:T] holds exactly T - origin samples.
        Neff_t[origin_index] = (T - origin) / g_t[origin_index]

    origins_index = Neff_t.argmax()
    t = origins[origins_index]
    g = g_t[origins_index]
    Neff_max = Neff_t[origins_index]

    return (t, g, Neff_max)

def vector_length(vector):
    """
    Return the Euclidean length of a unit-bearing vector.

    ARGUMENTS

    vector (simtk.unit.Quantity wrapping a numpy array) - the vector to measure

    RETURNS

    length (simtk.unit.Quantity) - square root of the sum of the squared components,
      expressed in the vector's own unit

    """
    unit = vector.unit
    # Strip the unit so numpy operates on the bare component values.
    components = vector / unit
    squared_sum = numpy.sum(components**2)
    return numpy.sqrt(squared_sum) * unit

def recompute_stark_shift(ncfile):
    """
    Recompute Stark shift from group dipole and group efield.

    For every stored iteration, the group efield is projected onto the unit vector along the
    group dipole and multiplied by the linear Stark tuning rate alpha.

    ARGUMENTS

    ncfile (netCDF4.Dataset) - dataset with 'shift', 'dipole' and 'efield' variables; 'dipole'
      is interpreted in e*nm and 'efield' in MV/cm, both indexed as [iteration, :]

    RETURNS

    shift_t (numpy.array) - recomputed Stark shift for each iteration, in 1/cm

    """

    dipole_unit = units.elementary_charge * units.nanometer
    efield_unit = units.mega*units.volt/units.centimeter
    shift_unit = units.dimensionless / units.centimeter

    # Linear Stark tuning rate of bosutinib; loop-invariant, so it is built once.
    alpha = 0.87 * (units.centimeters**-1) / (units.mega*units.volts/units.centimeter)

    # The stored 'shift' array supplies the output shape/dtype; every entry is overwritten below.
    shift_t = numpy.array(ncfile.variables['shift'][:])
    niterations = shift_t.shape[0]

    # Look the netCDF variables up once rather than on every iteration.
    dipole_variable = ncfile.variables['dipole']
    efield_variable = ncfile.variables['efield']

    for iteration in range(niterations):
        group_dipole = units.Quantity(dipole_variable[iteration,:], dipole_unit)

        # Compute unit vector aligned with group dipole (undefined if the dipole is exactly zero).
        group_dipole_norm = units.sqrt(units.sum(group_dipole**2))
        group_dipole_unit_vector = group_dipole / group_dipole_norm

        group_efield = units.Quantity(efield_variable[iteration,:], efield_unit)

        # Effective linear Stark shift: efield impinging on the group dotted with the dipole unit vector.
        stark_shift = alpha * units.sum(group_efield * group_dipole_unit_vector)
        stark_shift = stark_shift.in_units_of(shift_unit)

        shift_t[iteration] = stark_shift / shift_unit

    return shift_t

def compute_efield_from_shift(shift_t):
    """
    Convert a Stark shift timeseries into the projected electric field.

    Inverts the linear Stark relation shift = alpha * efield, where alpha = 0.87 (1/cm)/(MV/cm)
    is the linear Stark tuning rate of bosutinib.

    ARGUMENTS

    shift_t (numpy.array) - Stark shift as bare numbers in 1/cm

    RETURNS

    efield_t (numpy.array) - projected efield as bare numbers in MV/cm
      (assumes simtk.unit reduces the dimensionless scale factor below to a plain float — TODO confirm)

    """
    efield_unit = units.mega*units.volt/units.centimeter
    shift_unit = units.dimensionless / units.centimeter

    alpha = 0.87 * (units.centimeters**-1) / (units.mega*units.volts/units.centimeter) # linear Stark tuning rate of bosutinib

    # Collapse all unit algebra into one dimensionless scale factor first, then multiply the bare
    # array. The direct form (shift_t * shift_unit / alpha) / efield_unit multiplies a numpy array
    # by a Unit, which can let numpy broadcast element-wise into an object array of Quantities —
    # presumably why that form "sometimes breaks".
    efield_t = shift_t * (shift_unit / alpha / efield_unit)

    return efield_t

def generate_chart(basedir, mutant, ncfile, nskip=10):
    """
    Return an HTML/JavaScript snippet plotting the projected efield timeseries of one run.

    ARGUMENTS

    basedir (string) - base directory of the run
    mutant (string) - mutant name
    ncfile (netCDF4.Dataset) - open dataset containing a 'shift' variable (1/cm)

    OPTIONAL ARGUMENTS

    nskip (int) - plot every nskip-th sample (default 10)

    RETURNS

    block (string) - <script> elements that draw a Google Charts scatter plot into the element
      whose id is "<basedir>-<mutant>"

    """
    # numpy.ma.compressed drops masked (missing) samples. numpy.array() on a masked array would
    # keep the netCDF fill values, corrupting both the plotted series and the y-axis range.
    shift_t = numpy.ma.compressed(ncfile.variables['shift'][:])
    efield_t = compute_efield_from_shift(shift_t)
    interval = 0.1 * units.picoseconds # assumed time between stored samples — TODO confirm

    rows = ["['time', 'Stark shift'],"]
    for t in range(0, len(efield_t), nskip):
        rows.append("\n[%.3f, %.3f]," % (t*interval / units.picoseconds, efield_t[t]))
    data = "".join(rows)

    block = """\
    <script type="text/javascript" src="https://www.google.com/jsapi"></script>
    <script type="text/javascript">
      google.load("visualization", "1", {packages:["corechart"]});
      google.setOnLoadCallback(drawChart);
      function drawChart() {
         var data = google.visualization.arrayToDataTable([         
            %(data)s
         ]);
            
         var options = {
            enableInteractivity: false,
            hAxis: {title: 'simulation time (ps)', minValue: %(xmin)f, maxValue: %(xmax)f},
            vAxis: {title: 'projected efield (MV/cm)', minValue: %(ymin)f, maxValue: %(ymax)f},
            pointSize: 1,
            legend: 'none'
         };

         var chart = new google.visualization.ScatterChart(document.getElementById("%(basedir)s-%(mutant)s"));
         chart.draw(data, options);
      }
    </script>
    """ % {
        'data' : data,
        'xmin' : 0.0,
        'xmax' : 1000.0,
        'ymin' : efield_t.min(),
        'ymax' : efield_t.max(),
        'basedir' : basedir,
        'mutant' : mutant,
        }
    return block

def generate_stark_comparison(html_outfile, computed_stark_efields, reference_stark_peaks):
    """
    Write an interactive Google Charts scatter plot of computed efield vs. reference Stark peak.

    ARGUMENTS

    html_outfile (file-like) - open text stream; the chart script and its <div> container are
      written to it and it is flushed
    computed_stark_efields (list of dict) - records with 'mutant' and 'Eefield' keys
    reference_stark_peaks (dict) - maps mutant name to reference Stark peak (cm^-1)

    Records whose mutant has no reference peak, or that lack one of the keys above, are skipped.

    """
    rows = list()
    for record in computed_stark_efields:
        try:
            Eefield = record['Eefield']
            mutant = record['mutant']
            stark_peak = reference_stark_peaks[mutant]
        except KeyError:
            # No reference measurement for this mutant (or an incomplete record): nothing to plot.
            continue
        rows.append("[%.3f, %.3f, '%s'],\n" % (Eefield, stark_peak, mutant))
    data = "".join(rows)

    block = """\
    <script type="text/javascript" src="https://www.google.com/jsapi"></script>
    <script type="text/javascript">
      google.load("visualization", "1", {packages:["corechart"]});
      google.setOnLoadCallback(drawChart);
      function drawChart() {
         var data = new google.visualization.DataTable();
         data.addColumn('number', 'Efield');
         data.addColumn('number', 'stark peak');
         data.addColumn({type:'string', role:'tooltip'});
         data.addRows([
            %(data)s
         ]);
            
         var options = {
            enableInteractivity: true,
            hAxis: {title: 'computed Efield (MV/cm)'},
            vAxis: {title: 'Stark peak (cm^-1)'},
            pointSize: 10,
            legend: 'none'
         };

         var chart = new google.visualization.ScatterChart(document.getElementById("stark-comparison"));
         chart.draw(data, options);
      }
    </script>
    """ % {'data' : data}
    html_outfile.write(block)

    # Container element the chart script above draws into.
    block = """
   <div id="stark-comparison" style="width: 800px; height: 800px;"></div>
   """
    html_outfile.write(block)
    html_outfile.flush()

    return

def generate_stark_comparison_image(html_outfile, computed_stark_efields, reference_stark_peaks):
    """
    Plot measured Stark peak (x) against computed efield (y) and embed the image in the HTML page.

    The figure is saved as 'stark-comparison.png' in the current working directory, and an <img>
    element referencing it is written to html_outfile.

    ARGUMENTS

    html_outfile (file-like) - open text stream receiving the <img> block; flushed afterwards
    computed_stark_efields (list of dict) - records with 'mutant', 'Eefield' and 'dEefield' keys
    reference_stark_peaks (dict) - maps mutant name to measured Stark peak (cm^-1)

    Records whose mutant has no reference peak, or that lack one of the keys above, are skipped.

    """

    # Select the non-interactive backend so the plot can be rendered without a display.
    import matplotlib
    matplotlib.use('Agg')
    import matplotlib.pyplot as pyplot

    x = list()      # measured Stark peak (cm^-1)
    dx = list()
    y = list()      # computed projected efield (MV/cm)
    dy = list()     # standard error of the computed efield
    labels = list()
    for record in computed_stark_efields:
        try:
            Eefield = record['Eefield']
            dEefield = record['dEefield']
            mutant = record['mutant']
            stark_peak = reference_stark_peaks[mutant]
        except KeyError:
            # No reference measurement for this mutant (or an incomplete record): nothing to plot.
            continue
        y.append(Eefield)
        dy.append(dEefield)
        x.append(stark_peak)
        dx.append(0.1) # fixed uncertainty assumed for the measured peak — TODO confirm
        labels.append(mutant)

    x = numpy.array(x)
    dx = numpy.array(dx)
    y = numpy.array(y)
    dy = numpy.array(dy)

    pyplot.clf()
    # Error bars span two uncertainty units in each direction.
    pyplot.errorbar(x, y, fmt='k.', xerr=2*dx, yerr=2*dy)
    pyplot.xlabel('measured Stark shift (cm$^{-1}$)')
    pyplot.ylabel('computed E field (MV/cm)')

    for (x_i, y_i, label) in zip(x, y, labels):
        pyplot.text(x_i, y_i, label)

    pyplot.savefig('stark-comparison.png', dpi=300)

    block = """\
    <center>
    <img src="stark-comparison.png" alt="Stark comparison plot" width="800px"/> 
    </center>
    """
    html_outfile.write(block)
    html_outfile.flush()
    return

#=============================================================================================
# MAIN
#=============================================================================================

# TODO: Convert this to use optparse to take command-line arguments.
print "Starting..."

# Base directories; each is expected to contain one subdirectory per mutant holding a 'stark.nc'.
# TODO: Autodetect basedirs.
basedirs = ['src_tbosutinib_c4+d2_refine_waters_tls_38', 'a403t_bosut_refine_nl5_30']

# Former (disabled) approach: list the mutant subdirectories of the first basedir.
#import commands
#output = commands.getoutput('ls -1tr %s' % basedirs[0])
#mutants = output.split('\n')

# Mutant names come from the project-local stark_mutations module; presumably each name is also
# the mutant's subdirectory name under every basedir — confirm against stark_mutations.
import stark_mutations
mutants = stark_mutations.mutation_list
print mutants

# Generate output text file (fixed-width columns, header row first). It is written under a
# '-new' name and only copied over summary.txt at the very end of the script, so an aborted
# run leaves the previously published summary intact.
text_output_filename = 'summary-new.txt'
outfile = open(text_output_filename, 'w')
outfile.write('%-12s %-48s %11s %8s %8s %8s %8s %8s\n' % ('mutant', 'source structure', 'efield', 'std err', 'std dev', 't0', 'g', 'Neff'))

# Generate HTML page: document preamble plus the results table header. Rows are appended
# later; like the text summary, it is written under a '-new' name and copied over index.html
# at the end.
html_output_filename = 'index-new.html'
html_outfile = open(html_output_filename, 'w')
header = """\
<!DOCTYPE html PUBLIC "-//W3C//DTD XHTML 1.0 Transitional//EN"
"http://www.w3.org/TR/xhtml1/DTD/xhtml1-transitional.dtd">
<style type="text/css">
<!--
@import url("style.css");
-->
</style>
<html xmlns="http://www.w3.org/1999/xhtml" xml:lang="en" lang="en">
<head>
<meta http-equiv="Content-Type" content="text/html; charset=utf-8" />
<title>Stark shift computation summary</title>
</head>
<body>
<center>
<table id="hor-minimalist-a" border="0" cellspacing="0" cellpadding="0" width="95%">
<tr>
  <th>mutant</th>
  <th>timeseries</th>
  <th>projected efield (MV/cm)</th>
  <th>std dev (MV/cm)</th>
  <th>t0</th>
  <th>g</th>
  <th>Neff</th>
</tr>
"""
html_outfile.write(header)

# One {'mutant', 'Eefield', 'dEefield'} record per successfully processed run, for the final comparison plot.
computed_stark_efields = list()

# Process all mutants.
for mutant in mutants:
    print("========================================================================================")
    print(mutant)
    print("========================================================================================")

    # The mutant's header row is prepended only to the first table row written for it.
    first = True

    for basedir in basedirs:

        directory = os.path.join(basedir, mutant)
        filename = os.path.join(directory, 'stark.nc')

        # Not every mutant has data under every base structure.
        if not os.path.exists(filename):
            continue

        ncfile = None
        try:
            # Open NetCDF file for reading.
            ncfile = netcdf.Dataset(filename, 'r')

            # Extract computed Stark shift (1/cm), dropping masked (missing) samples.
            # numpy.ma.compressed also handles plain ndarrays and masked arrays whose mask is the
            # scalar `nomask`; `numpy.where(shift_t.mask==False)` collapsed those to one sample.
            shift_t = numpy.ma.compressed(ncfile.variables['shift'][:])

            # Detect equilibrated region.
            [t0, g, Neff] = detect_equilibration(shift_t)
            print("Automated equilibration detection: t = %.1f, g = %.1f, Neff_max = %.1f" % (t0, g, Neff))

            # Discard non-equilibrated data.
            shift_t = shift_t[t0:]
            T = len(shift_t)

            # Compute efield.
            efield_t = compute_efield_from_shift(shift_t)

            # Statistics; standard errors use the T/g effectively uncorrelated samples.
            Eshift = shift_t.mean() # mean
            dEshift = shift_t.std() / numpy.sqrt(T/g) # standard error
            std_shift = shift_t.std()

            Eefield = efield_t.mean() # mean
            dEefield = efield_t.std() / numpy.sqrt(T/g) # standard error
            std_efield = efield_t.std()

            print("Average Stark shift: %.3f +- %.3f (1/cm)" % (Eshift, dEshift))
            print("stddev Stark shift: %.3f" % (std_shift))

            print("Average projected efield: %.3f +- %.3f (MV/cm)" % (Eefield, dEefield))
            print("stddev projected efield: %.3f" % (std_efield))

            computed_stark_efields.append({'mutant' : mutant, 'Eefield' : Eefield, 'dEefield' : dEefield})

            # Write to text summary file.
            outfile.write('%-12s %-48s %11.1f %8.1f %8.1f %8.1f %8.1f %8.1f\n' % (mutant, basedir, Eefield, dEefield, std_efield, t0, g, Neff))

            # Write to HTML file: chart script, then the table row containing its target <div>.
            initial_model_filename = os.path.join(directory, 'leap.complex.pdb')
            pdb_final_filename = os.path.join(directory, 'final.pdb')
            pdb_trajectory_filename = os.path.join(directory, 'trajectory.pdb')
            html_outfile.write(generate_chart(basedir, mutant, ncfile))
            block = """\
 <tr>
   <td>
     <b>%(basedir)s</b>
     <br /><a href="%(initial_model_filename)s" target="_blank">initial model</a>
     <br /><a href="%(pdb_final_filename)s">final snapshot</a>
     <br /><a href="%(pdb_trajectory_filename)s">trajectory</a>
   </td>
   <td><div id="%(basedir)s-%(mutant)s" style="width: 500px; height: 150px;"></div></td>
   <td>%(Eefield).3f +- %(dEefield).3f</td>
   <td>%(std_efield).3f</td>
   <td>%(t0).1f</td>
   <td>%(g).1f</td>
   <td>%(Neff).1f</td>      
 </tr>
 """ % vars()
            if first:
                # Add mutant header.
                block = "<tr><td><h3>%(mutant)s</h3></td></tr>\n" % vars() + block
                first = False
            html_outfile.write(block)
            html_outfile.flush()
        except Exception as e:
            # Best effort: report the failure for this run and continue with the next one.
            print(e)
            continue
        finally:
            # Close the dataset whether or not processing succeeded.
            if ncfile is not None:
                ncfile.close()

# Clean up.
outfile.close()

# Generate scatter plot comparison.
generate_stark_comparison_image(html_outfile, computed_stark_efields, reference_stark_peaks)

# Finish HTML file.
import time
current_time = time.ctime()
footer = """\
</table>
<h5 id="hor-minimalist-a">Download text-only <a href="summary.txt">summary</a> of this data</h5>
<h5 id="hor-minimalist-a">Last updated %(current_time)s</h5>
</center>
</body>
""" % vars()
html_outfile.write(footer)
html_outfile.close()

# Publish the freshly written files under their served names. shutil.copyfile needs no shell
# and raises on failure; the Python-2-only `commands.getoutput('cp ...')` silently ignored errors.
import shutil
shutil.copyfile(html_output_filename, 'index.html')
shutil.copyfile(text_output_filename, 'summary.txt')
