#!/usr/bin/env python

"""
Example script illustrating a Stark effect calculation using StarkShiftReporter with the OpenMM application layer.

"""

#=============================================================================================
# GLOBAL IMPORTS
#=============================================================================================

import sys
import math
import doctest
import time
import numpy

# Import OpenMM, units package, and application layer.
import simtk.unit as units
import simtk.openmm as openmm
import simtk.openmm.app as app

#=============================================================================================
# SUBROUTINES
#=============================================================================================

def readPdbAtoms(pdbfilename):
    """
    Extract the ATOM and HETATM information from a PDB file.

    Only the fixed-column identification fields of each record are parsed;
    coordinates and all later columns are ignored.

    ARGUMENTS
      pdbfilename (string) - the filename of the PDB file to be read

    RETURNS
      atoms (list of dict) - atoms[index] is a dict of information about atom 'index',
        in file order, with keys:
          "serial" (int), "name" (string, columns 13-16, not stripped),
          "altLoc" (string), "resName" (string), "chainID" (string),
          "resSeq" (int), "iCode" (string),
          "recordtype" (string, either 'ATOM  ' or 'HETATM')

    RAISES
      ValueError if the serial or resSeq columns of an ATOM/HETATM record are not integers.

    """
    atoms = list()

    # The 'with' block guarantees the file is closed even if int() raises on a
    # malformed record part-way through the file.
    with open(pdbfilename, 'r') as pdbfile:
        for line in pdbfile:
            # Record name occupies columns 1-6 and is space padded.
            recordtype = line[0:6]
            if recordtype not in ('ATOM  ', 'HETATM'):
                continue

            # Parse fixed-column fields (0-based slices of the 1-based PDB columns).
            atom = dict()
            atom["serial"] = int(line[6:11])
            atom["name"] = line[12:16]
            atom["altLoc"] = line[16:17]
            atom["resName"] = line[17:20]
            atom["chainID"] = line[21:22]
            atom["resSeq"] = int(line[22:26])
            atom["iCode"] = line[26:27]
            atom["recordtype"] = recordtype

            atoms.append(atom)

    return atoms

def findAtoms(atoms, resName=None, name=None, chainID=None, resSeq=None):
    """
    Find the list of atom indices that match the given query criteria.

    An atom matches when it satisfies every criterion that is supplied.
    A string criterion left as None or '' is not applied; resSeq is applied
    whenever it is not None (so residue number 0 can be selected).

    ARGUMENTS
      atoms (list of dict) - atom records, as returned by readPdbAtoms()

    OPTIONAL ARGUMENTS
      resName (string) - residue name (compared ignoring surrounding whitespace)
      name (string) - atom name (compared ignoring surrounding whitespace)
      chainID (string) - chain ID
      resSeq (int or string) - residue sequence number

    RETURNS
      atom_indices (list) - list of positions in 'atoms' that match query criteria

    """
    # readPdbAtoms stores resSeq as an int, so normalize the query to int once.
    # (Comparing str(resSeq) against the stored int never matched anything.)
    if resSeq is not None:
        resSeq = int(resSeq)

    atom_indices = list() # matches
    for (atom_index, atom) in enumerate(atoms):
        # PDB right-justifies short residue names (e.g. ' NA'), so strip both sides.
        if resName and (resName.strip() != atom["resName"].strip()):
            continue
        if name and (name.strip() != atom["name"].strip()):
            continue
        if chainID and (chainID != atom["chainID"]):
            continue
        # Test against None rather than truthiness so that resSeq=0 still filters.
        if (resSeq is not None) and (resSeq != int(atom["resSeq"])):
            continue
        atom_indices.append(atom_index)

    return atom_indices

#=============================================================================================
# MAIN
#=============================================================================================

# System-specific information.
# bosutinib in solvent
prmtop_filename = 'bosutinib.prmtop'
inpcrd_filename = 'bosutinib.crd'
pdb_filename = 'bosutinib.pdb'
alpha = 0.87 * (units.centimeters**-1) / (units.mega*units.volts/units.centimeter) # linear Stark tuning rate of bosutinib

# Atoms defining the Stark vibrational group; pymol selections: ['resn DB8 and name C3', 'resn DB8 and name N1'].
# These must be OpenMM System indices (starting at 0), not PDB serial numbers (starting at 1);
# for the original input files they were [3-1, 19-1].
# findAtoms() returns positions in the PDB ATOM/HETATM list. Those equal System indices only if
# pdb_filename lists atoms in the same order as the prmtop -- confirm this whenever the inputs change.
atoms = readPdbAtoms(pdb_filename)
atom_indices = list()
for stark_atom_name in ['C3', 'N1']:
    matches = findAtoms(atoms, resName='DB8', name=stark_atom_name)
    # Fail loudly instead of handing the reporter an empty or ambiguous atom selection.
    if len(matches) != 1:
        raise ValueError("Expected exactly one DB8 atom named %s in %s, found %d" % (stark_atom_name, pdb_filename, len(matches)))
    atom_indices += matches
print(atom_indices)

# Simulation parameters.
pressure = 1.0 * units.atmospheres
temperature = 298.0 * units.kelvin
collision_rate = 9.1 / units.picosecond
timestep = 2.0 * units.femtoseconds
nsteps_per_picosecond = 500
barostat_frequency = 50 # number of steps between Monte Carlo volume moves
minimization_steps = 20 # maximum number of minimization iterations
nsteps_pdb = 5000 # number of steps in between writing to PDB (10 ps)
nsteps_netcdf = 5000 # number of steps in between writing to NetCDF (10 ps)
nsteps_efield = 50 # number of steps between reporting E field / Stark shift (0.1 ps)
nsteps_report = 50 # number of steps between reporting energy and temperature (0.1 ps)
nsteps = 1000 * nsteps_per_picosecond # total number of steps to run (1 ns)

# Load the AMBER system.
print("Creating AMBER system...")
inpcrd = app.AmberInpcrdFile(inpcrd_filename)
prmtop = app.AmberPrmtopFile(prmtop_filename)
# Alternative: nonbondedMethod=app.CutoffPeriodic.
system = prmtop.createSystem(nonbondedMethod=app.PME, constraints=app.HBonds)
positions = inpcrd.getPositions()

# Add a barostat to the system (must happen before the Simulation/Context is created).
system.addForce(openmm.MonteCarloBarostat(pressure, temperature, barostat_frequency))

# Create Langevin integrator.
print("Creating integrator...")
integrator = openmm.LangevinIntegrator(temperature, collision_rate, timestep)

# Create OpenMM app Simulation object.
print("Creating Simulation object...")
simulation = app.Simulation(prmtop.topology, system, integrator)

# Set positions in Context.
print("Setting positions...")
simulation.context.setPositions(positions)

# Write initial positions; 'with' ensures the file is flushed and closed.
print("Writing initial positions to PDB file...")
with open('initial.pdb', 'w') as outfile:
    app.PDBFile.writeFile(prmtop.topology, positions, outfile)

# Minimize energy.
print("Minimizing energy...")
simulation.minimizeEnergy(maxIterations=minimization_steps)

# Write minimized positions to PDB file.
print("Writing minimized positions to PDB file...")
positions = simulation.context.getState(getPositions=True).getPositions()
with open('minimized.pdb', 'w') as outfile:
    app.PDBFile.writeFile(prmtop.topology, positions, outfile)

# Add PDB reporter to write a frame every nsteps_pdb steps.
simulation.reporters.append(app.PDBReporter('trajectory.pdb', nsteps_pdb))

# Add NetCDF reporter.
from netcdfreporter import NetCDFReporter
simulation.reporters.append( NetCDFReporter(system, 'trajectory.nc', nsteps_netcdf, writePositions=True, writeVelocities=False, writeEnergies=True) )

# Add state data reporter to print energy and temperature.
simulation.reporters.append(app.StateDataReporter(sys.stdout, nsteps_report, step=True, potentialEnergy=True, temperature=True))

# Add a Stark shift reporter, reporting every nsteps_efield steps (previously a hard-coded 10,
# which left nsteps_efield unused).
from starkshiftreporter import StarkShiftReporter
simulation.reporters.append( StarkShiftReporter(system, atom_indices, alpha, interval=nsteps_efield, filename='stark.nc', format='netcdf') )

# Run simulation.
print("Running simulation...")
simulation.step(nsteps)

# Completed!
print("Done.")

