#!/usr/bin/python

# Process trajectory data to compute torsions and weights.

#=============================================================================================
# REQUIREMENTS
#
# This code uses the 'netCDF4' package for NetCDF input/output. Older versions used the 'pynetcdf' package, containing the Scientific.IO.NetCDF package built for numpy:
#
# http://pypi.python.org/pypi/pynetcdf/
# http://sourceforge.net/project/showfiles.php?group_id=1315&package_id=185504
#
# This code also uses the 'MBAR' package, implementing the multistate Bennett acceptance ratio estimator, available here:
#
# http://www.simtk.org/home/pymbar
#=============================================================================================

#===================================================================================================
# IMPORTS
#===================================================================================================
import numpy
import math
import commands
import os
import os.path
import datetime # time and date

import numpy
import numpy.linalg 

import pymbar # for MBAR analysis
import timeseries # for timeseries analysis

import simtk.unit as units

#from pynetcdf import NetCDF # for writing of data objects for plotting in Matlab or Mathematica
import netCDF4 as netcdf # for writing of data objects for plotting in Matlab or Mathematica

#===================================================================================================
# CONSTANTS
#===================================================================================================

# Molar Boltzmann constant (k_B * N_A, i.e. the gas constant R), so that kB * T is an energy per mole;
# used below to form beta = 1 / (kB * T), which is later expressed in 1/(kcal/mol).
kB = units.BOLTZMANN_CONSTANT_kB * units.AVOGADRO_CONSTANT_NA # Boltzmann constant in energy/temperature units

#===================================================================================================
# PARAMETERS
#===================================================================================================

nequil = 50 # number of initial iterations to discard to equilibration
tau_unit = 0.1 # sampling time in ps (NOTE(review): not referenced elsewhere in this file)
netcdf_input_filename = 'alanine-dipeptide-parallel-tempering.nc' # netcdf input filename
netcdf_output_filename = 'alanine-dipeptide-processed.nc' # netcdf output filename
#use_analytical_momentum = True # if True, will include analytical momentum contribution to partition function in energies
#netcdf_output_filename = 'output/alanine-dipeptide-transition-rate-analytical-momentum.nc' # netcdf output filename
# NOTE(review): ndof is not referenced elsewhere in this file. The formula presumably counts 22 solute atoms
# with 21 constraints, 431 rigid three-site waters (9 coordinates minus 3 constraints each), and removes
# 3 center-of-mass degrees of freedom -- confirm against the simulation setup.
ndof = (3*22-21) + 431*(3*3-3) - 3 # number of degrees of freedom

#===================================================================================================
# SUBROUTINES
#===================================================================================================

def compute_torsion(coordinates, i, j, k, l):
    """
    Compute torsion angle defined by four atoms.

    ARGUMENTS

    coordinates (simtk.unit.Quantity wrapping numpy natoms x 3) - atomic coordinates
    i, j, k, l - four atoms defining torsion angle

    RETURNS

    theta (simtk.unit.Quantity, radians) - signed torsion angle in [-pi, pi]; the sign is
        that of rkj . (t x u), where t = rij x rkj and u = rjk x rlk

    NOTES

    Algorithm of Swope and Ferguson [1] is used.

    If the geometry is degenerate (a zero-length normal t or u), cos(theta) evaluates to NaN;
    diagnostics are printed and the NaN propagates into the returned angle.

    [1] Swope WC and Ferguson DM. Alternative expressions for energies and forces due to angle bending and torsional energy.
    J. Comput. Chem. 13:585, 1992.

    """
    # Swope and Ferguson, Eq. 26 (coordinates stripped to plain values in Angstroms)
    rij = (coordinates[i,:] - coordinates[j,:]) / units.angstroms
    rkj = (coordinates[k,:] - coordinates[j,:]) / units.angstroms
    rlk = (coordinates[l,:] - coordinates[k,:]) / units.angstroms
    rjk = (coordinates[j,:] - coordinates[k,:]) / units.angstroms # JDC: added for clarity

    # Swope and Ferguson, Eq. 27: normals to the (i,j,k) and (j,k,l) planes.
    t = numpy.cross(rij, rkj)
    u = numpy.cross(rjk, rlk) # JDC: fixed because this didn't seem to match diagram in equation in paper

    # Swope and Ferguson, Eq. 28
    t_norm = numpy.sqrt(numpy.dot(t, t))
    u_norm = numpy.sqrt(numpy.dot(u, u))
    cos_theta = numpy.dot(t, u) / (t_norm * u_norm)

    # Clamp round-off excursions outside [-1, 1] so arccos stays defined; NaN fails the test and passes through.
    if (abs(cos_theta) > 1.0):
        cos_theta = 1.0 * numpy.sign(cos_theta)

    # After clamping, arccos(cos_theta) is NaN exactly when cos_theta is NaN.
    if math.isnan(cos_theta):
        print("cos_theta is NaN")
        print("arccos(cos_theta) is NaN")
        print("cos_theta = %f" % cos_theta)
        print(coordinates[i,:])
        print(coordinates[j,:])
        print(coordinates[k,:])
        print(coordinates[l,:])
        # Report the two plane normals, the quantities whose degeneracy produces the NaN.
        print("t")
        print(t)
        print("u")
        print(u)

    theta = numpy.arccos(cos_theta) * numpy.sign(numpy.dot(rkj, numpy.cross(t, u))) * units.radians

    return theta

                                                    
def compute_torsions(trajectory, i, j, k, l):
    """
    Compute a torsion angle defined by four atoms for every frame of a trajectory.

    ARGUMENTS

    trajectory (simtk.unit.Quantity wrapping numpy nframes x natoms x 3) - atomic coordinates of each frame
    i, j, k, l - four atoms defining torsion angle

    RETURNS

    torsions (simtk.unit.Quantity wrapping numpy float32 array of length nframes, radians) -
        torsions[t] is the signed torsion angle in frame t

    NOTES

    Algorithm of Swope and Ferguson [1] is used, with the same sign convention as compute_torsion().
    If scipy.weave is available the per-frame loop runs as inline C; if weave cannot be imported or
    fails for any reason, the exception is printed and compute_torsion() is called frame by frame.

    [1] Swope WC and Ferguson DM. Alternative expressions for energies and forces due to angle bending and torsional energy.
    J. Comput. Chem. 13:585, 1992.

    """

    T = trajectory.shape[0]

    try:
        from scipy import weave

        weave_context = dict()
        weave_context['T'] = T
        weave_context['i'] = i
        weave_context['j'] = j
        weave_context['k'] = k
        weave_context['l'] = l
        weave_context['coordinates'] = trajectory / units.angstroms
        weave_context['torsions'] = numpy.zeros([T], numpy.float32)

        # Cross products below are written out component-wise as
        #   a x b = (a1 b2 - a2 b1, a2 b0 - a0 b2, a0 b1 - a1 b0);
        # the y-component must be a2 b0 - a0 b2, otherwise the sign test flips every torsion.
        code = """
        for(int frame = 0; frame < T; frame++) {
           // Swope and Ferguson, Eq. 26
           double rij[3];
           double rkj[3];
           double rlk[3];
           double rjk[3];
           for(int d = 0; d < 3; d++) {
              rij[d] = (COORDINATES3(frame,i,d) - COORDINATES3(frame,j,d));
              rkj[d] = (COORDINATES3(frame,k,d) - COORDINATES3(frame,j,d));
              rlk[d] = (COORDINATES3(frame,l,d) - COORDINATES3(frame,k,d));
              rjk[d] = (COORDINATES3(frame,j,d) - COORDINATES3(frame,k,d)); // JDC: added for clarity
              }

           // Swope and Ferguson, Eq. 27
           double t[3]; // t = cross(rij, rkj)
           t[0] = rij[1]*rkj[2] - rij[2]*rkj[1];
           t[1] = rij[2]*rkj[0] - rij[0]*rkj[2];
           t[2] = rij[0]*rkj[1] - rij[1]*rkj[0];
           double u[3]; // u = cross(rjk, rlk) // JDC: fixed because this didn't seem to match diagram in equation in paper
           u[0] = rjk[1]*rlk[2] - rjk[2]*rlk[1];
           u[1] = rjk[2]*rlk[0] - rjk[0]*rlk[2];
           u[2] = rjk[0]*rlk[1] - rjk[1]*rlk[0];

           // Swope and Ferguson, Eq. 28
           double t_norm = 0.0; // t_norm = norm(t)
           double u_norm = 0.0; // u_norm = norm(u)
           double cos_theta = 0.0; // cos_theta = dot(t,u) / (norm(t) * norm(u))
           for (int d = 0; d < 3; d++) {
              t_norm += t[d]*t[d];
              u_norm += u[d]*u[d];
              cos_theta += t[d]*u[d];
              }
           t_norm = sqrt(t_norm);
           u_norm = sqrt(u_norm);
           cos_theta = cos_theta / (t_norm * u_norm);

           // Clamp round-off excursions so acos stays defined.
           if (cos_theta > 1.0)
              cos_theta = 1.0;
           if (cos_theta < -1.0)
              cos_theta = -1.0;

           torsions[frame] = acos(cos_theta);

           // Determine sign from rkj . (t x u).
           double t_cross_u[3];
           t_cross_u[0] = t[1]*u[2] - t[2]*u[1];
           t_cross_u[1] = t[2]*u[0] - t[0]*u[2];
           t_cross_u[2] = t[0]*u[1] - t[1]*u[0];
           double rkj_dot_t_cross_u = 0.0;
           for (int d = 0; d < 3; d++)
              rkj_dot_t_cross_u += rkj[d] * t_cross_u[d];
           if (rkj_dot_t_cross_u < 0.0)
              torsions[frame] = - torsions[frame];
           }

        """

        # Execute inline C code with weave; results are written in place into weave_context['torsions'].
        weave.inline(code, list(weave_context.keys()), local_dict=weave_context, headers=['<math.h>', '<stdlib.h>'], verbose=0)

        # Store results.
        torsions = units.Quantity(weave_context['torsions'], units.radians)

    except Exception as exception:
        # weave is optional (it is absent from modern SciPy); fall back to the pure-Python implementation.
        print(exception)

        torsions = units.Quantity(numpy.zeros([T], numpy.float32), units.radians)
        for frame in range(T):
            torsions[frame] = compute_torsion(trajectory[frame,:,:], i, j, k, l)

    return torsions

def read_pdb(filename, natoms=None):
    """
    Read the ATOM records of a PDB file.

    ARGUMENTS

    filename (string) - name of the file to be read

    OPTIONAL ARGUMENTS

    natoms (int) - if not None, stop reading once this many ATOM records have been collected

    RETURNS

    atoms (list of dict) - atoms[index] is a dict of the raw (unstripped) fixed-column fields
        of the index-th ATOM record, in file order

    """

    # Extract the ATOM entries.
    # Format described here: http://bmerc-www.bu.edu/needle-doc/latest/atom-format.html
    atoms = list()

    # 'with' guarantees the file is closed even if reading or parsing raises.
    with open(filename, 'r') as pdbfile:
        for line in pdbfile:
            if line[0:6] == "ATOM  ":
                # Slice the fixed-width columns into fields (whitespace is preserved).
                atom = dict()
                atom["serial"] = line[6:11]
                atom["atom"] = line[12:16]
                atom["altLoc"] = line[16:17]
                atom["resName"] = line[17:20]
                atom["chainID"] = line[21:22]
                atom["Seqno"] = line[22:26]
                atom["iCode"] = line[26:27]
                atom["x"] = line[30:38]
                atom["y"] = line[38:46]
                atom["z"] = line[46:54]
                atom["occupancy"] = line[54:60]
                atom["tempFactor"] = line[60:66]
                atoms.append(atom)
            if (natoms is not None) and (len(atoms) == natoms):
                break

    # Return list of ATOM records.
    return atoms

def write_file(filename, contents):
    """Write the specified contents to a file, replacing any existing file.

    ARGUMENTS
      filename (string) - the file to be written
      contents (string or list of strings) - the contents of the file to be written;
        list elements are written back-to-back with no separator added

    RAISES
      TypeError - if contents is neither a list nor a string (the file is left untouched)

    """

    # Validate before opening so an unsupported type does not truncate an existing file.
    if isinstance(contents, list):
        chunks = contents
    elif isinstance(contents, str):
        chunks = [contents]
    else:
        raise TypeError("Type for 'contents' not supported: " + repr(type(contents)))

    with open(filename, 'w') as outfile:
        for chunk in chunks:
            outfile.write(chunk)

    return

def read_file(filename):
    """Read contents of the specified file.

    ARGUMENTS
      filename (string) - the name of the file to be read

    RETURNS
      lines (list of strings) - the contents of the file, split by line (line endings retained)

    """

    # 'with' closes the file even if reading raises.
    with open(filename, 'r') as infile:
        lines = infile.readlines()

    return lines

def logSum(log_terms):
    """Compute the log of a sum of terms whose logarithms are provided.

    REQUIRED ARGUMENTS
      log_terms is the array (possibly multidimensional, or any array-like) containing the logs
      of the terms to be summed.

    RETURN VALUES
      log_sum is the log of the sum of the terms, i.e. log(sum(exp(log_terms))), computed
      without overflow by factoring out the largest term.
      If every term is -inf (all terms zero) the result is -inf; if any term is +inf it is +inf.

    """

    log_terms = numpy.asarray(log_terms)

    # compute the maximum argument
    max_log_term = log_terms.max()

    # Infinite maximum: exp(log_terms - max) would be exp(inf - inf) = NaN, but the answer is the max itself.
    if numpy.isinf(max_log_term):
        return max_log_term

    # compute the reduced terms (all <= 1, so no overflow)
    terms = numpy.exp(log_terms - max_log_term)

    # compute the log sum
    log_sum = numpy.log(terms.sum()) + max_log_term

    # return the log sum
    return log_sum

def kronecker(i,j):
    """Kronecker delta.

    RETURNS
      1 when i == j, otherwise 0.
    """
    return 1 if i == j else 0

def delta(a,b,i,j):
    """Symmetric pair indicator.

    RETURNS
      1 when the unordered pair {a, b} matches the unordered pair {i, j}
      (either a == i and b == j, or a == j and b == i), otherwise 0.
    """
    # Checked in the same order, with the same short-circuiting, as a single or-expression.
    if (a == i) and (b == j):
        return 1
    if (a == j) and (b == i):
        return 1
    return 0

def show_mixing_statistics(ncfile, show_transition_matrix=False):
    """
    Compute mixing statistics among thermodynamic states.

    ARGUMENTS

    ncfile (NetCDF file handle) - dataset whose 'states' variable is indexed [iteration, replica]
        and holds the state index occupied by each replica at each iteration

    OPTIONAL ARGUMENTS

    show_transition_matrix (boolean) - if True, the transition matrix will be printed

    RETURN VALUES

    Tij (numpy array of dimension [nstates,nstates]) - Tij[i,j] is the symmetrized, row-normalized
        fraction of replica moves from state i to state j between consecutive iterations.
        Rows for states with no observed transitions (e.g. a single stored iteration) are left zero.

    """

    print("Computing mixing statistics...")

    states = ncfile.variables['states'][:,:].copy()

    # Determine number of iterations and states.
    [niterations, nstates] = states.shape

    # Count transitions; each move adds half a count in both directions, so Nij is symmetric.
    Nij = numpy.zeros([nstates,nstates], numpy.float64)
    for iteration in range(niterations-1):
        for ireplica in range(nstates):
            istate = states[iteration,ireplica]
            jstate = states[iteration+1,ireplica]
            Nij[istate,jstate] += 0.5
            Nij[jstate,istate] += 0.5

    # Row-normalize, skipping empty rows to avoid 0/0 = NaN.
    Tij = numpy.zeros([nstates,nstates], numpy.float64)
    row_sums = Nij.sum(axis=1)
    for istate in range(nstates):
        if row_sums[istate] > 0.0:
            Tij[istate,:] = Nij[istate,:] / row_sums[istate]

    if show_transition_matrix:
        # Print observed transition probabilities.
        PRINT_CUTOFF = 0.001 # Cutoff for displaying fraction of accepted swaps.
        print("Cumulative symmetrized state mixing transition matrix:")
        print("%6s" % "" + "".join(" %6d" % jstate for jstate in range(nstates)))
        for istate in range(nstates):
            cells = []
            for jstate in range(nstates):
                P = Tij[istate,jstate]
                cells.append("%6.3f" % P if P >= PRINT_CUTOFF else "%6s" % "")
            print("%-6d" % istate + "".join(" " + cell for cell in cells))

    # A second eigenvalue only exists with at least two states.
    if nstates < 2:
        print("Only one state; state equilibration timescale is undefined.")
        return Tij

    # Estimate second eigenvalue and equilibration time.
    # Tij is a row-normalized symmetric matrix, so its eigenvalues are real in exact arithmetic;
    # take real parts before sorting so round-off imaginary components cannot reorder them.
    mu = numpy.sort(numpy.linalg.eigvals(Tij).real)[::-1] # sort in descending order
    if (mu[1] >= 1):
        print("Perron eigenvalue is unity; Markov chain is decomposable.")
    else:
        print("Perron eigenvalue is %9.5f; state equilibration timescale is ~ %.1f iterations" % (mu[1], 1.0 / (1.0 - mu[1])))

    return Tij

def show_acceptance_statistics(ncfile):
    """
    Print summary of exchange acceptance statistics.

    ARGUMENTS
       ncfile (NetCDF file handle) - the parallel tempering datafile to be analyzed; its 'proposed' and
          'accepted' variables are indexed [iteration, i, j]

    RETURNS
       fraction_accepted (numpy array of [nstates-1]) - fraction_accepted[separation] is fraction of swaps
          attempted between state i and i+separation that were accepted. Entry 0 (separation 0) is fixed
          at 1.0. Separations at which no swaps were proposed are NaN.

    """

    print("Computing acceptance statistics...")

    # Total counts over all iterations, read once instead of once per (i, j) pair.
    proposed = ncfile.variables['proposed'][:,:,:].sum(axis=0)
    accepted = ncfile.variables['accepted'][:,:,:].sum(axis=0)
    nstates = proposed.shape[0]

    # Aggregated proposed and accepted by how far from the diagonal we are:
    # the offset-`separation` diagonal holds the counts for pairs (i, i+separation).
    fraction_accepted = numpy.ones([nstates-1], numpy.float64)
    for separation in range(1,nstates-1):
        nproposed = proposed.diagonal(separation).sum()
        naccepted = accepted.diagonal(separation).sum()
        if nproposed > 0:
            fraction_accepted[separation] = float(naccepted) / float(nproposed)
        else:
            # No swaps attempted at this separation (e.g. neighbor-only exchange): fraction is undefined.
            fraction_accepted[separation] = numpy.nan
        print("%5d : %10d %10d : %8.5f" % (separation, nproposed, naccepted, fraction_accepted[separation]))

    return fraction_accepted

def write_pdb(filename, trajectory, atoms):
    """Write out a replica trajectory as a multi-model PDB file.

    ARGUMENTS
       filename (string) - name of PDB file to be written
       trajectory (array indexed [frame, atom, xyz]) - plain numeric coordinates, written as-is with %8.3f
          (the caller in this file scales nanometers by 10 to get Angstroms before calling)
       atoms (list of dict) - parsed PDB file ATOM entries from read_pdb() - WILL BE CHANGED: the
          'x', 'y', 'z' entries are overwritten, ending with the last frame's formatted coordinates
    """

    nframes = trajectory.shape[0]

    # Create file; 'with' closes it even if formatting raises.
    with open(filename, 'w') as outfile:
        # Write trajectory as models
        for frame_index in range(nframes):
            outfile.write("MODEL     %4d\n" % (frame_index+1))

            # Write ATOM records.
            for (index, atom) in enumerate(atoms):
                atom["x"] = "%8.3f" % trajectory[frame_index,index,0]
                atom["y"] = "%8.3f" % trajectory[frame_index,index,1]
                atom["z"] = "%8.3f" % trajectory[frame_index,index,2]
                record = dict(atom)
                record.setdefault("iCode", " ")
                # PDB columns 23-26 hold resSeq and column 27 the insertion code; Seqno from read_pdb
                # is 4 characters wide, so it must be written %4s (not %5s) to stay in its columns.
                outfile.write('ATOM  %(serial)5s %(atom)4s%(altLoc)c%(resName)3s %(chainID)c%(Seqno)4s%(iCode)1s   %(x)8s%(y)8s%(z)8s\n' % record)

            outfile.write("ENDMDL\n")

    return

def unitsum(x):
    """
    Sum all elements of a simtk.unit.Quantity array.

    ARGUMENTS
       x (simtk.unit.Quantity wrapping a numpy array) - values to sum

    RETURNS
       total (simtk.unit.Quantity) - sum of all elements, expressed in x's unit
    """
    unit = x.unit
    magnitude_total = (x / unit).sum()
    return units.Quantity(magnitude_total, unit)

#===================================================================================================
# MAIN
#===================================================================================================

#===================================================================================================
# Open parallel tempering NetCDF file for reading.
#===================================================================================================

# Open the NetCDF trajectory file.
print "Opening parallel tempering NetCDF file for reading..."
# Legacy pynetcdf call kept for reference (NOTE(review): it names the *output* filename); netCDF4 is used instead.
#repex_ncfile = NetCDF.NetCDFFile(netcdf_output_filename, 'r')
repex_ncfile = netcdf.Dataset(netcdf_input_filename, 'r')
    
#===================================================================================================
# Read temperatures
#===================================================================================================

print "Reading temperatures and other dimensions from parallel tempering dataset..."

# Determine dimensions.
# 'trajectories' is indexed [iteration, replica, snapshot, atom, dimension]. N is initially the number of
# stored iterations; it is later reassigned to the number of uncorrelated samples.
print "Reading statistics..."
[N, K, T, natoms, ndim] = repex_ncfile.variables['trajectories'].shape
print "%d trajectories" % N
print "%d replicas" % K
print "%d snapshots/trajectory" % T
print "%d atoms in trajectories" % natoms
print "%d dimensions/atom" % ndim
print ""

# Read temperatures.
temperature_k = repex_ncfile.variables['temperatures'][:].copy()
temperature_k = units.Quantity(temperature_k, units.kelvin)
# Inverse temperatures beta_k[k] = 1 / (kB * T_k), in inverse molar-energy units.
beta_k = 1.0 / (kB * temperature_k)

#===================================================================================================
# Write trajectories.
#===================================================================================================

# Debugging aid (disabled): writes the last iteration's trajectory of every replica to replica-NNNNN.pdb.
write_trajectories = False
if write_trajectories:
    print "Writing trajectories..."
    pdbfilename = 'setup/alanine-dipeptide.pdb'
    # NOTE(review): this rebinds the module-level natoms that was read from the dataset above.
    natoms = 22
    atoms = read_pdb(pdbfilename, natoms=natoms) # read PDB
    iteration = N-1
    for replica in range(K):
        filename = 'replica-%05d.pdb' % replica
        # Stored coordinates are in nanometers; scale by 10 so the PDB holds Angstroms.
        trajectory = repex_ncfile.variables['trajectories'][iteration,replica,:,:,:].copy() * 10.0
        write_pdb(filename, trajectory, atoms)

#===================================================================================================
# Show acceptance probabilities, broken down by separation in number of temperatures.
#===================================================================================================

fraction_accepted = show_acceptance_statistics(repex_ncfile)
# One value per line, indexed by separation; entry 0 is the separation-0 placeholder left at 1.0.
outfile = open('fraction-accepted.txt', 'w')
for i in range(K-1):
    outfile.write("%8.5f" % fraction_accepted[i])
    outfile.write("\n")
outfile.close()

#===================================================================================================
# Show mixing statistics.
#===================================================================================================

Tij = show_mixing_statistics(repex_ncfile, show_transition_matrix=True)
# Write the K x K state transition matrix, one row per line.
outfile = open('swap-statistics.txt', 'w')
for i in range(K):
    for j in range(K):
        outfile.write("%24e" % Tij[i,j])
    outfile.write("\n")
outfile.close()

#===================================================================================================
# Estimate statistical inefficiency and determine subset of effectively uncorrelated samples.
#===================================================================================================

# Compute negative log-probability of product space of all replicas.
# u_n[iteration] sums, over replicas, the reduced energy stored for each replica in the state it occupied
# at that iteration ('energies' is indexed [iteration, replica, state]).
print "Computing log-probability history..."
u_n = numpy.zeros([N], numpy.float64)
for iteration in range(N):
   u_n[iteration] = 0.0
   for replica in range(K):
      state = repex_ncfile.variables['states'][iteration,replica]
      u_n[iteration] += repex_ncfile.variables['energies'][iteration,replica,state]

# Compute statistical inefficiency.
print "Estimating statistical inefficiency after discarding first %d iterations to equilibration" % nequil
g_u = timeseries.statisticalInefficiency(u_n[nequil:])
print "g_u = %8.1f iterations" % g_u

# Determine indices of effectively uncorrelated trajectories.
# Indices are relative to the truncated series u_n[nequil:]; shift them back to absolute iteration numbers.
indices = timeseries.subsampleCorrelatedData(u_n[nequil:], g=g_u)
indices = numpy.array(indices) + nequil

# DEBUG: Use all samples.
#indices = numpy.arange(N)

# Reduce number of samples.
# From here on, N is the number of uncorrelated samples rather than the number of stored iterations.
N = indices.size
print "There are %d uncorrelated samples (after discarding initial %d and subsampling by %.1f)" % (N, nequil, g_u)
    
#===================================================================================================
# Initialize the NetCDF file for output of computed data objects.
#===================================================================================================

# Open the NetCDF analysis output file.
print "Opening analysis NetCDF file for writing..."
#output_ncfile = NetCDF.NetCDFFile(netcdf_output_filename, 'w')
output_ncfile = netcdf.Dataset(netcdf_output_filename, 'w', format='NETCDF3_CLASSIC')
    
# Set global attributes.
# NOTE(review): confirm the 'application' name matches this script.
setattr(output_ncfile, 'title', "Analysis data produced at %s" % datetime.datetime.now().ctime())
setattr(output_ncfile, 'application', 'analyze-correlation-function.py')

# Store dimensions in netcdf.
output_ncfile.createDimension('K', K)             # number of temperatures
output_ncfile.createDimension('N', N)             # number of uncorrelated trajectories per temperature
output_ncfile.createDimension('T', T)             # number of snapshots per trajectory
  
variable = output_ncfile.createVariable('temperature_k', 'd', ('K',))
setattr(variable, 'units', 'Kelvin')
setattr(variable, 'description', 'temperature_k[k] is the temperature of temperature index k')
output_ncfile.variables['temperature_k'][:] = temperature_k / units.kelvin

# beta_k is stored in single precision ('f'), unlike temperature_k ('d').
variable = output_ncfile.createVariable('beta_k', 'f', ('K',))
setattr(variable, 'units', '1/(kcal/mol)')
setattr(variable, 'description', 'beta_k[k] is the inverse temperature of temperature index k')
output_ncfile.variables['beta_k'][:] = beta_k / (1.0 / units.kilocalories_per_mole)

#===================================================================================================
# Read path Hamiltonians for uncorrelated trajectories.
#===================================================================================================

# H_kn[k,n] comes from the replica that occupied state k at uncorrelated iteration indices[n];
# the stored reduced energy u is converted back to an energy via H = u / beta_k.
print "Computing path Hamiltonians..."
H_kn = units.Quantity(numpy.zeros([K,N], numpy.float64), units.kilocalories_per_mole)
for n in range(N):
   # Get index into original iteration.
   iteration = indices[n]

   # Compute path Hamiltonians.
   for replica in range(K):
      state = repex_ncfile.variables['states'][iteration,replica]
      u = float(repex_ncfile.variables['energies'][iteration,replica,state])
      H_kn[state,n] = u / beta_k[state]

variable = output_ncfile.createVariable('H_kn', 'd', ('K','N'))
setattr(variable, 'units', 'kcal/mol')
setattr(variable, 'description', 'H_kn[k,n] is the path Hamiltonian of trajectory n from state k')
output_ncfile.variables['H_kn'][:,:] = H_kn[:,:] / units.kilocalories_per_mole

#===================================================================================================
# Compute reduced potentials in all states for MBAR.
#===================================================================================================

# u_kln[k,l,n] = beta_l * H_kn[k,n]: reduced potential of sample n from state k evaluated in state l.
print "Computing reduced potentials..."
u_kln = numpy.zeros([K,K,N], numpy.float64)
for n in range(N):
   for k in range(K):
      u_kln[k,:,n] = beta_k[:] * H_kn[k,n]

#===================================================================================================
# Initialize MBAR.
#===================================================================================================

print "Initiaizing MBAR..."
N_k = N * numpy.ones([K], numpy.int32) # N_k[k] is the number of uncorrelated samples from thermodynamic state k
mbar = pymbar.MBAR(u_kln, N_k, verbose=True, method='Newton-Raphson', initialize='BAR', relative_tolerance = 1.0e-10)

#===================================================================================================
# Compute weights at each temperature.
#===================================================================================================

# Choose temperatures for reweighting to be the simulation temperatures plus their midpoints.
# Even indices 2k hold simulation temperatures; odd indices 2k+1 hold midpoints between k and k+1.
L = 2*K-1 # number of temperatures for reweighting
reweighted_temperature_l = units.Quantity(numpy.zeros([L], numpy.float64), units.kelvin)
for k in range(K):
   reweighted_temperature_l[2*k] = temperature_k[k] # simulation temperatures
for k in range(K-1):   
   reweighted_temperature_l[2*k+1] = (temperature_k[k] + temperature_k[k+1]) / 2.0 # midpoint temperatures
print "Temperatures for reweighting:"
print reweighted_temperature_l

print "Computing trajectory weights for reweighting temperatures..."
log_w_lkn = numpy.zeros([L,K,N], numpy.float64) # log_w_lkn[l,k,n] is the unnormalized log weight of snapshot n from simulation k at reweighted temperature l
w_lkn = numpy.zeros([L,K,N], numpy.float64) # w_lkn[l,k,n] is the normalized weight of snapshot n from simulation k at reweighted temperature l

# alternate: first compute just denominators    
# NOTE(review): relies on the private pymbar method _computeUnnormalizedLogWeights; called with zero reduced
# potentials it presumably returns the per-sample log of the inverse MBAR mixture denominator -- confirm against
# the installed pymbar version, since private APIs may change between releases.
all_log_denom = mbar._computeUnnormalizedLogWeights(numpy.zeros([mbar.K,mbar.N_max],dtype=numpy.float64))
for l in range(L):
   temperature = reweighted_temperature_l[l]
   beta = 1.0 / (kB * temperature)
   u_kn = beta * H_kn
   # Unnormalized log weights; subtract the maximum before exponentiating to avoid overflow, then normalize to sum 1.
   log_w_kn = -u_kn+all_log_denom
   w_kn = numpy.exp(log_w_kn - log_w_kn.max())
   w_kn = w_kn / w_kn.sum()
   w_lkn[l,:,:] = w_kn
   log_w_lkn[l,:,:] = log_w_kn   

# Store weights.
output_ncfile.createDimension('L', L)             # number of temperatures for reweighting

variable = output_ncfile.createVariable('reweighted_temperature_l', 'd', ('L',))
setattr(variable, 'units', 'Kelvin')
setattr(variable, 'description', 'reweighted_temperature_l[k] is the temperature of reweighted temperature index k')
output_ncfile.variables['reweighted_temperature_l'][:] = reweighted_temperature_l / units.kelvin

variable = output_ncfile.createVariable('log_w_lkn', 'd', ('L','K','N'))
setattr(variable, 'units', 'dimensionless')
setattr(variable, 'description', 'log_w_lkn[l,k,n] is the unnormalized log weight of trajectory n from temperature k at reweighted temperature l')
output_ncfile.variables['log_w_lkn'][:,:,:] = log_w_lkn

variable = output_ncfile.createVariable('w_lkn', 'd', ('L','K','N'))
setattr(variable, 'units', 'dimensionless')
setattr(variable, 'description', 'w_lkn[l,k,n] is the normalized weight of trajectory n from temperature k at reweighted temperature l')
output_ncfile.variables['w_lkn'][:,:,:] = w_lkn

#===================================================================================================
# Compute generalized heat capacity as a function of temperature.
#===================================================================================================

print("Computing generalized heat capacity as a function of temperature...")
heat_capacity_units = units.kilocalories_per_mole / units.kelvin
Cv_l = units.Quantity(numpy.zeros([L], numpy.float64), heat_capacity_units)
for l in range(L):
   temperature = reweighted_temperature_l[l]
   beta = 1.0 / (kB * temperature)
   # Weighted mean and variance of the path Hamiltonian under the normalized weights w_lkn[l,:,:].
   EH = numpy.sum(w_lkn[l,:,:] * (H_kn/H_kn.unit)) * H_kn.unit
   varH = numpy.sum(w_lkn[l,:,:] * ((H_kn - EH)/H_kn.unit)**2) * H_kn.unit**2
   # Fluctuation formula: Cv = kB * beta^2 * Var(H).
   Cv_l[l] = kB * beta**2 * varH

variable = output_ncfile.createVariable('Cv_l', 'd', ('L',))
setattr(variable, 'units', 'kcal/mol/kelvin')
setattr(variable, 'description', 'Cv_l[l] is the generalized heat capacity of reweighted temperature l')
output_ncfile.variables['Cv_l'][:] = Cv_l[:] / heat_capacity_units

# One record per line (previously no newline was written, so all records ran together);
# 'with' closes the file even if formatting raises.
with open('generalized-heat-capacity.txt', 'w') as outfile:
    for l in range(L):
        outfile.write('%8.1f K : %16.3f kcal/mol/K\n' % (reweighted_temperature_l[l] / units.kelvin, Cv_l[l] / heat_capacity_units))

#===================================================================================================
# Compute peptide torsions.
#===================================================================================================

class Torsion(object):
   """
   A named torsion together with zero-initialized storage for its values.

   torsion_knt is a simtk.unit.Quantity in degrees with shape [K, N, T], taken from the
   module-level K (number of replicas/states), N (number of uncorrelated samples) and
   T (snapshots per trajectory) at the time the object is constructed.
   """

   def __init__(self, name, atom_indices):
      """
      ARGUMENTS
         name (string) - name of torsion
         atom_indices (list) - atom indices that specify the torsion
      """
      self.name, self.atom_indices = name, atom_indices
      zeros_knt = numpy.zeros([K, N, T], numpy.float32)
      self.torsion_knt = units.Quantity(zeros_knt, units.degrees)

print "Computing torsions..."

# Make a dictionary of torsions to compute.
# Atom indices are zero-based positions in the stored trajectories.
torsions = dict()
torsions['phi'] = Torsion('phi', [4, 6, 8, 14])
torsions['psi'] = Torsion('psi', [6, 8, 14, 16])
torsions['omega1'] = Torsion('omega1', [1, 4, 6, 8])
torsions['omega2'] = Torsion('omega2', [8, 14, 16, 18])

for n in range(N):
   print " sample %d / %d" % (n, N)
   
   # Get index to original iteration.
   iteration = indices[n]

   # Extract trajectories for all replicas for this iteration (stored coordinates are in nanometers).
   trajectories = units.Quantity(numpy.zeros([K,T,natoms,ndim], numpy.float32), units.nanometers)   
   trajectories[:,:,:,:] = units.Quantity(repex_ncfile.variables['trajectories'][iteration,:,:,:,:], units.nanometers)

   # Process all trajectories with all torsions.
   # Results are filed under the state the replica occupied, so torsion_knt[k] is indexed by state, not replica.
   for replica in range(K):
      trajectory = trajectories[replica,:,:,:] # trajectory for replica
      state = repex_ncfile.variables['states'][iteration,replica] # state for replica
      for (torsion_name, torsion) in torsions.iteritems():
          torsion.torsion_knt[state,n,:] = compute_torsions(trajectory, *(torsion.atom_indices)) 

# Write data to NetCDF file.
for (torsion_name, torsion) in torsions.iteritems():
   variable_name = torsion.name + "_knt"
   variable = output_ncfile.createVariable(variable_name, 'f', ('K','N','T'))
   setattr(variable, 'units', 'degrees')
   setattr(variable, 'description', '%s[k,n,t] is the %s torsion of snapshot t of uncorrelated trajectory n of state k' % (variable_name, torsion_name))
   output_ncfile.variables[variable_name][:,:,:] = torsion.torsion_knt / units.degrees

#===================================================================================================
# Close NetCDF files.
#===================================================================================================

repex_ncfile.close()
output_ncfile.close()
print "Done."

