#!/opt/local/bin/python2.5

#=============================================================================================
# Render a replica trajectory in PyMOL
#=============================================================================================

#=============================================================================================
# REQUIREMENTS
#
# This code requires PyMOL and the netCDF4 Python module (imported below as `netcdf`).
# Earlier versions used the NetCDF module available either as Scientific.IO.NetCDF or standalone through pynetcdf:
# http://pypi.python.org/pypi/pynetcdf/
# http://sourceforge.net/project/showfiles.php?group_id=1315&package_id=185504
#=============================================================================================

#=============================================================================================
# TODO
#=============================================================================================

#=============================================================================================
# CHANGELOG
#=============================================================================================

#=============================================================================================
# VERSION CONTROL INFORMATION
# * 2009-08-01 JDC
# Created file.
#=============================================================================================

#=============================================================================================
# IMPORTS
#=============================================================================================

import numpy
from numpy import *
#import Scientific.IO.NetCDF # not installed
#import scipy.io.netcdf as netcdf
import netCDF4 as netcdf
import os
import os.path
import pymol
from pymol import cmd
from pymol import util

#=============================================================================================
# PARAMETERS
#=============================================================================================

#=============================================================================================
# SUBROUTINES
#=============================================================================================

def readAtomsFromPDB(pdbfilename, chain=None):
    """Read ATOM records from a PDB file and return them in a list.

    REQUIRED ARGUMENTS
      pdbfilename - the filename of the PDB file to import from

    OPTIONAL ARGUMENTS
      chain - if not None, only ATOM records whose one-character chain ID
              (column 22) equals this value are returned
              (default: None, meaning all chains)

    RETURN VALUES
      atoms - a list of atom{} dictionaries, in file order, with keys
              'serial', 'name', 'altLoc', 'resName', 'chainID', 'resSeq',
              'iCode', 'x', 'y', 'z', 'occupancy', 'tempFactor', 'segID',
              'element', 'charge'

    Only records starting with "ATOM " are parsed; HETATM and all other record
    types are skipped. Fields are sliced by fixed column position and string
    fields are returned unstripped. A blank occupancy defaults to 1.0 and a
    blank temperature factor to 0.0.
    """

    # Read the PDB file into memory, closing it even if reading fails.
    pdbfile = open(pdbfilename, 'r')
    try:
        lines = pdbfile.readlines()
    finally:
        pdbfile.close()

    # Read atoms.
    atoms = []
    for line in lines:
        if not line.startswith("ATOM "):
            continue

        # Drop the line terminator so it cannot leak into the trailing
        # fixed-column fields (segID/element/charge) of short lines.
        line = line.rstrip('\r\n')

        if (chain is not None) and (line[21:22] != chain):
            continue

        # Parse line into fields.
        atom = { }
        atom["serial"] = int(line[5:11])
        atom["name"] = line[12:16]
        atom["altLoc"] = line[16:17]
        atom["resName"] = line[17:21]
        atom["chainID"] = line[21:22]
        atom["resSeq"] = int(line[22:26])
        atom["iCode"] = line[26:27]
        atom["x"] = float(line[30:38])
        atom["y"] = float(line[38:46])
        atom["z"] = float(line[46:54])

        # Occupancy and temperature factor are optional columns.
        atom["occupancy"] = 1.0
        if line[54:60].strip() != '':
            atom["occupancy"] = float(line[54:60])

        atom["tempFactor"] = 0.0
        if line[60:66].strip() != '':
            atom["tempFactor"] = float(line[60:66])

        atom["segID"] = line[72:76]
        atom["element"] = line[76:78]
        atom["charge"] = line[78:80]

        atoms.append(atom)

    # Return list of atoms.
    return atoms

def write_netcdf_replica_trajectories(directory, prefix, title, ncfile):
    """Write out replica trajectories in AMBER NetCDF format.

    One output file named '<prefix>-NNN.nc' (NNN = zero-padded replica index)
    is written into `directory` for each index along the second axis of the
    input 'positions' variable. Frame i of the output holds iteration i of that
    replica, scaled by 10 (nm -> Angstrom), with time 1.0 * i.

    ARGUMENTS
       directory (string) - the directory to write files to
       prefix (string) - prefix for replica trajectory files
       title (string) - the title to give each NetCDF file
       ncfile (NetCDF) - NetCDF file object for input file

    NOTE(review): initialize_netcdf() and write_netcdf_frame() are not defined
    in this file; they must be provided in the calling namespace, otherwise
    this function raises NameError.
    """
    # Get current dimensions; 'positions' is indexed [iteration, replica, atom, spatial].
    positions = ncfile.variables['positions']
    (niterations, nstates, natoms) = positions.shape[0:3]

    # Write out each replica to a separate file.
    for replica in range(nstates):
        output_filename = os.path.join(directory, '%s-%03d.nc' % (prefix, replica))
        ncoutfile = netcdf.Dataset(output_filename, 'w')
        try:
            initialize_netcdf(ncoutfile, title + " (replica %d)" % replica, natoms)
            for iteration in range(niterations):
                # Copy the frame into a NumPy array before scaling it in place.
                coordinates = numpy.array(positions[iteration, replica, :, :])
                coordinates *= 10.0 # convert nm to angstroms
                write_netcdf_frame(ncoutfile, iteration, time = 1.0 * iteration, coordinates = coordinates)
        finally:
            # Close even on failure so the output handle is not leaked.
            ncoutfile.close()

    return

def compute_torsion_trajectories(ncfile, filename, atoms=None):
    """Write out torsion-angle trajectories (by default for Val 111).

    ARGUMENTS
       ncfile (NetCDF) - NetCDF file object for input file; its 'positions'
                         variable is indexed [iteration, replica, atom, spatial]
       filename (string) - name of file to be written

    OPTIONAL ARGUMENTS
       atoms (sequence of 4 ints) - zero-based indices of the atoms A, B, C, D
                                    defining the torsion (default
                                    [1735, 1737, 1739, 1741], N-CA-CB-CG1 of Val 111)

    The output file has one line per iteration and one '%8.1f' column per
    replica. Angles are in degrees and unsigned, in [0, 180], because they are
    obtained with arccos.
    """
    if atoms is None:
        atoms = [1735, 1737, 1739, 1741] # N-CA-CB-CG1 of Val 111

    # Get current dimensions.
    positions = ncfile.variables['positions']
    niterations = positions.shape[0]
    nstates = positions.shape[1]

    def compute_torsion(coordinates):
        """Return the unsigned A-B-C-D torsion (degrees) for one frame."""
        vBA = coordinates[atoms[0],:] - coordinates[atoms[1],:]
        vBC = coordinates[atoms[2],:] - coordinates[atoms[1],:]
        vCB = coordinates[atoms[1],:] - coordinates[atoms[2],:]
        vCD = coordinates[atoms[3],:] - coordinates[atoms[2],:]
        # Normals of the (A,B,C) and (B,C,D) planes.
        v1 = numpy.cross(vBA, vBC)
        v2 = numpy.cross(vCB, vCD)
        cos_theta = numpy.dot(v1, v2) / numpy.sqrt(numpy.dot(v1, v1) * numpy.dot(v2, v2))
        # Round-off can push |cos| slightly past 1, which would make arccos return nan.
        cos_theta = numpy.clip(cos_theta, -1.0, 1.0)
        return numpy.degrees(numpy.arccos(cos_theta))

    # Compute torsion angles for each replica; one output line per iteration.
    lines = []
    for iteration in range(niterations):
        fields = []
        for replica in range(nstates):
            torsion = compute_torsion(numpy.array(positions[iteration, replica, :, :]))
            fields.append("%8.1f" % torsion)
        lines.append("".join(fields) + "\n")

    # Write contents.
    outfile = open(filename, 'w')
    try:
        outfile.write("".join(lines))
    finally:
        outfile.close()

    return

#=============================================================================================
# MAIN
#=============================================================================================

# DEBUG: ANALYSIS PATH IS HARD-CODED FOR NOW
# Other systems this script has been pointed at. Presumably these used the
# ligand residue name 'MOL' -- confirm before switching back:
#source_directory = 'indole'
#source_directory = 'p-xylene'
#source_directory = '../test-systems/chk1/amber/amber-gbsa/molec000001'
#ligand = 'MOL' # name of ligand residue in PDB file

# DEBUG
source_directory = '../test-systems/T4-lysozyme-L99A/amber-gbsa/amber-gbsa/1-methylpyrrole'
ligand = 'TMP' # name of ligand residue in PDB file

# Reference structure; its ATOM order is used as the atom order of the trajectory.
reference_pdbfile = os.path.join(source_directory, 'complex.pdb')
phase = 'complex' # trajectory is read from <source_directory>/<phase>.nc
replica = 0 # replica index to render

# Launch pymol.
pymol.finish_launching()

# Load the reference PDB file into a fresh session as object 'complex' and
# define the selections used for display:
#   receptor - every atom whose residue name is not the ligand's
#   ligand   - non-hydrogen atoms of the ligand residue
#   pocket   - receptor non-hydrogen atoms within 10.0 of the ligand selection
cmd.rewind()
cmd.delete('all')
cmd.reset()
cmd.load(reference_pdbfile, 'complex')
cmd.select('receptor', 'not resn %s' % ligand)
#cmd.select('ligand', 'resn TMP')
cmd.select('ligand', 'resn %s and not hydrogen' % ligand)
cmd.select('pocket', '(receptor and not hydrogen) within 10.0 of ligand')
cmd.deselect()
# Initial representation: lines for everything, with ligand/pocket recolored
# via util.cbay/util.cbag (presumably PyMOL's color-by-atom with yellow/green
# carbons) and the receptor green. The representation is reset again after the
# trajectory frames have been loaded.
cmd.hide('all')
#cmd.show('cartoon', 'receptor')
#cmd.show('sticks', 'ligand')
cmd.show('lines', 'all')
util.cbay('ligand')
cmd.color('green', 'receptor')
# NOTE(review): 'pocket' is already shown as lines by the show('lines', 'all') call above.
cmd.show('lines', 'pocket')
util.cbag('pocket')

# speed up builds
cmd.set('defer_builds_mode', 3)
cmd.set('cache_frames', 0)

# Fetch the loaded object as a model; its atom list is reused below as the
# coordinate container for every trajectory frame. The loop is a debug dump of
# index, atom name, residue name, residue number and x/y/z for every atom.
model = cmd.get_model('complex')
for atom in model.atom:
    print "%8d %4s %3s %5d %8.3f %8.3f %8.3f" % (atom.index, atom.name, atom.resn, int(atom.resi), atom.coord[0], atom.coord[1], atom.coord[2])

# Read atoms from PDB
pdbatoms = readAtomsFromPDB(reference_pdbfile)

# Build mappings.
# PDB atoms and PyMOL model atoms are matched on the key
# (residue number, stripped atom name). model_mapping[i] is the index into
# model.atom of the i-th PDB ATOM record. The frame loader treats i as the atom
# index of the trajectory.
pdb_indices = dict()
for (index, atom) in enumerate(pdbatoms):
    pdb_indices[(int(atom['resSeq']), atom['name'].strip())] = index

model_indices = dict()
duplicate_model_keys = set()
for (index, atom) in enumerate(model.atom):
    key = (int(atom.resi), atom.name)
    if key in model_indices:
        duplicate_model_keys.add(key)
    model_indices[key] = index

model_mapping = list()
for atom in pdbatoms:
    key = (int(atom['resSeq']), atom['name'].strip())
    # An ambiguous key would silently send two trajectory atoms to one model atom.
    if key in duplicate_model_keys:
        raise ValueError("Atom (resSeq %d, name '%s') occurs more than once in the PyMOL model; cannot map it uniquely." % key)
    try:
        model_mapping.append(model_indices[key])
    except KeyError:
        raise KeyError("PDB atom (resSeq %d, name '%s') has no counterpart in the PyMOL model." % key)
    
# Construct full path to NetCDF file.
fullpath = os.path.join(source_directory, phase + '.nc')

# Open NetCDF file for reading.
print("Opening NetCDF trajectory file '%(fullpath)s' for reading..." % vars())
#ncfile = Scientific.IO.NetCDF.NetCDFFile(fullpath, 'r')
ncfile = netcdf.Dataset(fullpath, 'r')

# Get variables
print("variables:")
print(ncfile.variables.keys())

# 'positions' is indexed [iteration, replica, atom, spatial], in nm.
[niterations, nstates, natoms, nspatial] = ncfile.variables['positions'].shape
print("Read %(niterations)d iterations, %(nstates)d states, %(natoms)d atoms" % vars())

# The mapping was built from the reference PDB; it must cover exactly the trajectory atoms.
if natoms != len(model_mapping):
    ncfile.close()
    raise ValueError("Trajectory has %d atoms but the reference PDB file has %d ATOM records." % (natoms, len(model_mapping)))

# DEBUG
#niterations = 10

# Load frames
positions = ncfile.variables['positions']
print("Loading all frames...")
for iteration in range(niterations):
    print(" frame %d / %d" % (iteration, niterations))
    # Read the whole frame in one NetCDF access instead of one per coordinate.
    # Scaling is done in double precision.
    frame = numpy.array(positions[iteration, replica, :, :], dtype=numpy.float64) * 10.0 # convert to angstroms
    # Set coordinates
    for pdb_index in range(natoms):
        coord = model.atom[model_mapping[pdb_index]].coord
        for k in range(3):
            coord[k] = frame[pdb_index, k]
    cmd.load_model(model, 'complex', state=iteration+1)
print("Done.")

# Final representation: receptor as cartoon, ligand as sticks, pocket as lines.
cmd.hide('all')
#cmd.show('lines', 'all')
cmd.show('cartoon', 'receptor')
cmd.show('sticks', 'ligand')
cmd.show('lines', 'pocket')

# Build a movie over every loaded state ("1 -N" = states 1 through N).
cmd.mset("1 -%d" % cmd.count_states())

# Align all states
cmd.intra_fit('all')

# Zoom viewport
# NOTE(review): the view is re-oriented and re-zoomed on the ligand just below,
# so these two calls on 'complex' likely have no lasting effect -- confirm.
cmd.zoom('complex')
cmd.orient('complex')

# Center the view on the ligand; 10.0 is passed as PyMOL's zoom buffer.
cmd.orient('ligand')
cmd.zoom('ligand', 10.0)

# OVERRIDE
#cmd.set_view((\
#    -0.665100634,   -0.085599251,    0.741833150,\
#    -0.517284989,    0.769275129,   -0.375011891,\
#    -0.538573802,   -0.633158743,   -0.555924177,\
#    0.000000000,    0.000000000,  -80.684349060,\
#    -3.434448242,    9.531448364,  -18.679140091,\
#    63.612155914,   97.756546021,  -20.000000000 ))

# Render movie
print("Rendering movie...")
frame_prefix = 'frames/frame'

# cmd.mpng writes into the prefix's directory; make sure it exists.
frame_directory = os.path.dirname(frame_prefix)
if frame_directory and not os.path.isdir(frame_directory):
    os.makedirs(frame_directory)

states = ncfile.variables['states']
cmd.set('ray_trace_frames', 1)
try:
    for iteration in range(niterations):
        print(" frame %d / %d" % (iteration, niterations))
        # Stick transparency encodes this replica's 'states' value:
        # state 0 -> 0.0, state nstates-1 -> 1.0.
        if nstates > 1:
            transparency = float(states[iteration, replica]) / float(nstates-1)
        else:
            transparency = 0.0 # single state: avoid dividing by zero
        cmd.set('stick_transparency', transparency)
        cmd.mpng(frame_prefix, iteration+1, iteration+1)
finally:
    # Restore the setting and release the trajectory even if rendering fails.
    cmd.set('ray_trace_frames', 0)

    # Close file
    ncfile.close()

