#!/opt/local/bin/python2.5

#=============================================================================================
# Render a replica trajectory in PyMOL
#=============================================================================================

#=============================================================================================
# REQUIREMENTS
#
# This code requires the NetCDF module, available either as Scientific.IO.NetCDF or standalone through pynetcdf:
# http://pypi.python.org/pypi/pynetcdf/
# http://sourceforge.net/project/showfiles.php?group_id=1315&package_id=185504
#=============================================================================================

#=============================================================================================
# TODO
#=============================================================================================

#=============================================================================================
# CHAGELOG
#=============================================================================================

#=============================================================================================
# VERSION CONTROL INFORMATION
# * 2009-08-01 JDC
# Created file.
#=============================================================================================

#=============================================================================================
# IMPORTS
#=============================================================================================

import numpy
from numpy import *

#import Scientific.IO.NetCDF
import netCDF4 as netcdf # for writing of data objects for plotting in Matlab or Mathematica

import os
import os.path
from pymol import cmd
from pymol import util

#=============================================================================================
# PARAMETERS
#=============================================================================================

#=============================================================================================
# SUBROUTINES
#=============================================================================================

def readAtomsFromPDB(pdbfilename):
    """Read atom records from the PDB and return them in a list.

    present_sequence = getPresentSequence(pdbfilename, chain=' ')
    contents of protein.seqfile
    REQUIRED ARGUMENTS
      pdbfilename - the filename of the PDB file to import from

    OPTIONAL ARGUMENTS
      chain - the one-character chain ID of the chain to import (default ' ')

    RETURN VALUES
      atoms - a list of atom{} dictionaries

    The ATOM records are read, and the sequence for which there are atomic coordinates is stored.

    """

    # Read the PDB file into memory.
    pdbfile = open(pdbfilename, 'r')
    lines = pdbfile.readlines()
    pdbfile.close()


    # Read atoms.
    atoms = []
    for line in lines:
        if line[0:5] == "ATOM ":
            # Parse line into fields.
            atom = { }
            atom["serial"] = int(line[5:11])
            atom["name"] = line[12:16]
            atom["altLoc"] = line[16:17]
            atom["resName"] = line[17:21]
            atom["chainID"] = line[21:22]
            atom["resSeq"] = int(line[22:26])
            atom["iCode"] = line[26:27]
            atom["x"] = float(line[30:38])
            atom["y"] = float(line[38:46])
            atom["z"] = float(line[46:54])

            atom["occupancy"] = 1.0
            if (line[54:60].strip() != ''):
              atom["occupancy"] = float(line[54:60])
              
            atom["tempFactor"] = 0.0
            if (line[60:66].strip() != ''):
              atom["tempFactor"] = float(line[60:66])
            
            atom["segID"] = line[72:76]
            atom["element"] = line[76:78]
            atom["charge"] = line[78:80]
            
            atoms.append(atom)

    # Return list of atoms.
    return atoms

def write_netcdf_replica_trajectories(directory, prefix, title, ncfile):
    """Write out replica trajectories in AMBER NetCDF format.

    ARGUMENTS
       directory (string) - the directory to write files to
       prefix (string) - prefix for replica trajectory files
       title (string) - the title to give each NetCDF file
       ncfile (NetCDF) - NetCDF file object for input file       
    """
    # Get current dimensions.
    niterations = ncfile.variables['positions'].shape[0]
    nstates = ncfile.variables['positions'].shape[1]
    natoms = ncfile.variables['positions'].shape[2]

    # Write out each replica to a separate file.
    for replica in range(nstates):
        # Create a new replica file.
        output_filename = os.path.join(directory, '%s-%03d.nc' % (prefix, replica))
        ncoutfile = NetCDF.NetCDFFile(output_filename, 'w')
        initialize_netcdf(ncoutfile, title + " (replica %d)" % replica, natoms)
        for iteration in range(niterations):
            coordinates = array(ncfile.variables['positions'][iteration,replica,:,:])
            coordinates *= 10.0 # convert nm to angstroms
            write_netcdf_frame(ncoutfile, iteration, time = 1.0 * iteration, coordinates = coordinates)
        ncoutfile.close()

    return

def compute_torsion_trajectories(ncfile, filename):
    """Write out torsion trajectories for Val 111.

    ARGUMENTS
       ncfile (NetCDF) - NetCDF file object for input file
       filename (string) - name of file to be written
    """
    atoms = [1735, 1737, 1739, 1741] # N-CA-CB-CG1 of Val 111        

    # Get current dimensions.
    niterations = ncfile.variables['positions'].shape[0]
    nstates = ncfile.variables['positions'].shape[1]
    natoms = ncfile.variables['positions'].shape[2]

    # Compute torsion angle
    def compute_torsion(positions, atoms):
        # Compute vectors from cross products        
        vBA = positions[atoms[0],:] - positions[atoms[1],:]
        vBC = positions[atoms[2],:] - positions[atoms[1],:]
        vCB = positions[atoms[1],:] - positions[atoms[2],:]
        vCD = positions[atoms[3],:] - positions[atoms[2],:]
        v1 = cross(vBA,vBC)
        v2 = cross(vCB,vCD)
        cos_theta = dot(v1,v2) / sqrt(dot(v1,v1) * dot(v2,v2))
        theta = arccos(cos_theta) * 180.0 / math.pi
        return theta
                
    # Compute torsion angles for each replica
    contents = ""
    for iteration in range(niterations):
        for replica in range(nstates):
            # Compute torsion
            torsion = compute_torsion(array(ncfile.variables['positions'][iteration,replica,:,:]), atoms)
            # Write torsion
            contents += "%8.1f" % torsion
        contents += "\n"

    # Write contents.
    write_file(filename, contents)

    return

#=============================================================================================
# MAIN
#=============================================================================================

# DEBUG: ANALYSIS PATH IS HARD-CODED FOR NOW
#source_directory = 'indole'
source_directory = 'p-xylene'

reference_pdbfile = os.path.join(source_directory, 'complex.pdb')
phase = 'complex'
replica = 0 # replica index to render

# Load PDB file.
cmd.rewind()
cmd.delete('all')
cmd.reset()
cmd.load(reference_pdbfile, 'complex')
cmd.select('receptor', 'not resn TMP')
#cmd.select('ligand', 'resn TMP')
cmd.select('ligand', 'resn TMP and not hydrogen')
cmd.deselect()
cmd.hide('all')
#cmd.show('cartoon', 'receptor')
#cmd.show('sticks', 'ligand')
cmd.show('lines', 'all')
util.cbay('ligand')
cmd.color('green', 'receptor')

# speed up builds
cmd.set('defer_builds_mode', 3)
cmd.set('cache_frames', 0)

model = cmd.get_model('complex')
for atom in model.atom:
    print "%8d %4s %3s %5d %8.3f %8.3f %8.3f" % (atom.index, atom.name, atom.resn, int(atom.resi), atom.coord[0], atom.coord[1], atom.coord[2])

# Read atoms from PDB
pdbatoms = readAtomsFromPDB(reference_pdbfile)

# Build mappings.
pdb_indices = dict()
for (index, atom) in enumerate(pdbatoms):
    key = (int(atom['resSeq']), atom['name'].strip())
    value = index
    pdb_indices[key] = value
    
model_indices = dict()
for (index, atom) in enumerate(model.atom):
    key = (int(atom.resi), atom.name)
    value = index
    model_indices[key] = value

model_mapping = list()
for (pdb_index, atom) in enumerate(pdbatoms):
    key = (int(atom['resSeq']), atom['name'].strip())
    model_index = model_indices[key]
    model_mapping.append(model_index)
    
# Construct full path to NetCDF file.
fullpath = os.path.join(source_directory, phase + '.nc')

# Open NetCDF file for reading.
print "Opening NetCDF trajectory file '%(fullpath)s' for reading..." % vars()
#ncfile = Scientific.IO.NetCDF.NetCDFFile(fullpath, 'r')
ncfile = netcdf.Dataset(fullpath, 'r')

# DEBUG
print "dimensions:"
print ncfile.dimensions
    
# Read dimensions.
nstates = ncfile.dimensions['replica']
natoms = ncfile.dimensions['atom']

# Get variables
print "variables:"
print ncfile.variables.keys()
niterations = ncfile.variables['states'].shape[0]
print "Read %(niterations)d iterations, %(nstates)d states" % vars()

# DEBUG
#niterations = 20

# Load frames
for iteration in range(niterations):
    # Set coordinates
    for pdb_index in range(natoms):
        model_index = model_mapping[pdb_index]
        for k in range(3):
            model.atom[model_index].coord[k] = ncfile.variables['positions'][iteration, replica, pdb_index, k] * 10.0 # convert to angstroms
    cmd.load_model(model, 'complex', state=iteration+1)
    #cmd.load_model(model, 'complex')    

cmd.hide('all')
#cmd.show('lines', 'all')
cmd.show('cartoon', 'receptor')
cmd.show('sticks', 'ligand')

cmd.mset("1 -%d" % cmd.count_states())

# Align all states
cmd.intra_fit('all')

# Zoom viewport
cmd.zoom('complex')
cmd.orient('complex')
cmd.orient('ligand')

# Render movie
frame_prefix = 'frames/frame'
cmd.set('ray_trace_frames', 1)
for iteration in range(niterations):
#    # Set coordinates
#    for pdb_index in range(natoms):
#        model_index = model_mapping[pdb_index]
#        for k in range(3):
#            model.atom[model_index].coord[k] = ncfile.variables['positions'][iteration, replica, pdb_index, k] * 10.0 # convert to angstroms
#    cmd.load_model(model, 'complex', state=iteration+1)
    cmd.set('stick_transparency', float(ncfile.variables['states'][iteration, replica]) / float(nstates-1))
    cmd.mpng(frame_prefix, iteration+1, iteration+1)
    #cmd.load_model(model, 'complex')    

cmd.set('ray_trace_frames', 0)

# Close file
ncfile.close()

