#!/usr/bin/env python

"""
Example script illustrating Stark effect calculation using StarkShiftReporter with the OpenMM app.

TODO

* Add ability to resume simulations.

"""

#=============================================================================================
# GLOBAL IMPORTS
#=============================================================================================

import os
import os.path
import sys
import math
import time

# Import OpenMM, units package, and application layer.
import simtk.unit as units
import simtk.openmm as openmm
import simtk.openmm.app as app

#=============================================================================================
# SUBROUTINES
#=============================================================================================

def readPdbAtoms(pdbfilename):
    """
    Extract the ATOM and HETATM information from a PDB file.

    Only the identification fields of each record are parsed (fixed-column slices);
    coordinates, occupancy, and B-factors are ignored.

    ARGUMENTS
      pdbfilename (string) - the filename of the PDB file to be read

    RETURNS
      atoms (list of dict) - atoms[index] is a dict of information about atom 'index',
        in file order, with keys:
          "serial" (int), "name" (string), "altLoc" (string), "resName" (string),
          "chainID" (string), "resSeq" (int), "iCode" (string), "recordtype" (string)
        String fields are the raw column slices and are NOT stripped of whitespace.

    RAISES
      ValueError if the serial or resSeq columns of an ATOM/HETATM record are not integers.

    """

    atoms = list()

    # Use a context manager so the file is closed even if parsing raises.
    with open(pdbfilename, 'r') as pdbfile:
        for line in pdbfile:
            recordtype = line[0:6]
            if recordtype not in ('ATOM  ', 'HETATM'):
                # Skip everything that does not carry atom information.
                continue

            # Parse line into fields.
            atom = dict()
            atom["serial"] = int(line[6:11])
            atom["name"] = line[12:16]
            atom["altLoc"] = line[16:17]
            atom["resName"] = line[17:20]
            atom["chainID"] = line[21:22]
            atom["resSeq"] = int(line[22:26])
            atom["iCode"] = line[26:27]

            atom["recordtype"] = recordtype

            atoms.append(atom)

    return atoms

def findAtoms(atoms, resName=None, name=None, chainID=None, resSeq=None):
    """
    Find the list of atom indices that match the given query criteria.

    A criterion that is not specified places no restriction on the match; an atom
    must satisfy every specified criterion to be included.

    ARGUMENTS
      atoms (list of dict) - atom records with "resName", "name", "chainID", and "resSeq" keys

    OPTIONAL ARGUMENTS
      resName (string) - residue name (compared exactly)
      name (string) - atom name (compared after stripping surrounding whitespace)
      chainID (string) - chain ID (compared exactly)
      resSeq (int) - residue sequence number; 0 is a valid query value

    RETURNS
      atom_indices (list) - list of atom indices (positions in 'atoms') that match query criteria

    """
    atom_indices = list() # matches
    for (atom_index, atom) in enumerate(atoms):
        if resName and (resName != atom["resName"]):
            continue
        if name and (name.strip() != atom["name"].strip()):
            continue
        if chainID and (chainID != atom["chainID"]):
            continue
        # Compare residue numbers as normalized text so the query matches whether the
        # record stores resSeq as an int or as a (possibly padded) string.
        # Test against None explicitly so that residue number 0 can be queried.
        if (resSeq is not None) and (str(resSeq).strip() != str(atom["resSeq"]).strip()):
            continue
        atom_indices.append(atom_index)

    return atom_indices

def simulate_stark_shift(prmtop_filename, inpcrd_filename, pdb_filename, alpha, atom_indices, output_directory, gpuid=None, target_simulation_length=10.0 * units.picoseconds, verbose=False):
    """
    Run (or resume) an NPT Langevin dynamics simulation of an AMBER system while
    recording output through PDB, NetCDF, and Stark shift reporters.

    If 'trajectory.nc' already exists in the output directory, the simulation is
    resumed from its last frame; otherwise the system is minimized briefly and
    started from the inpcrd coordinates.

    ARGUMENTS
      prmtop_filename (string) - AMBER prmtop file used to create the system and topology
      inpcrd_filename (string) - AMBER inpcrd file providing the initial positions
      pdb_filename (string) - currently unused by this function; kept for interface compatibility
      alpha - passed through unchanged to StarkShiftReporter (the example caller in this
        file supplies a linear Stark tuning rate)
      atom_indices (list) - atom indices passed through to StarkShiftReporter
      output_directory (string) - directory for all output files; created if it does not exist.
        Files written: initial.pdb, minimized.pdb (fresh starts only), trajectory.pdb,
        final.pdb, trajectory.nc, stark.nc

    OPTIONAL ARGUMENTS
      gpuid (int) - if not None, set as the default device index for both the 'Cuda' and
        'OpenCL' platforms (0 is a valid device index)
      target_simulation_length (simtk.unit.Quantity with units compatible with simtk.unit.picoseconds) - target simulation length
      verbose (bool) - if True, will write state data to stdout

    RETURNS
      None.  Returns early without running if the resumed trajectory has already
      reached the target number of steps.

    """

    # Set the CUDA and OpenCL device IDs if instructed.
    # Compare against None explicitly: a plain truth test would ignore device index 0.
    if gpuid is not None:
        openmm.Platform.getPlatformByName('Cuda').setPropertyDefaultValue('CudaDeviceIndex', '%d' % gpuid)
        openmm.Platform.getPlatformByName('OpenCL').setPropertyDefaultValue('OpenCLDeviceIndex', '%d' % gpuid)

    # Make output directory if it doesn't exist.
    if not os.path.exists(output_directory):
        os.makedirs(output_directory)

    # Set fixed simulation parameters.
    # TODO: Make these user-tunable when this is transformed into a class.
    pressure = 1.0 * units.atmospheres
    temperature = 298.0 * units.kelvin
    collision_rate = 9.1 / units.picosecond
    timestep = 2.0 * units.femtoseconds
    barostat_frequency = 50 # number of steps between Monte Carlo volume moves
    minimization_steps = 20 # number of minimization steps
    nsteps_trajectory_pdb = 5000 # number of steps in between writing to PDB (10 ps)
    nsteps_final_pdb = 500 # number of steps in between overwriting current structure to PDB (1 ps)
    nsteps_netcdf = 500 # number of steps in between writing to NetCDF (1 ps)
    nsteps_efield = 50 # number of steps between reporting E field (0.1 ps)
    nsteps_report = 50 # number of steps between reporting energy and temperature (0.1 ps)
    targetStep = int(target_simulation_length / timestep) # total number of steps to run

    # Load the AMBER system.
    print("Creating AMBER system...")
    inpcrd = app.AmberInpcrdFile(inpcrd_filename)
    prmtop = app.AmberPrmtopFile(prmtop_filename)
    #system = prmtop.createSystem(nonbondedMethod=app.CutoffPeriodic, constraints=app.HBonds)
    system = prmtop.createSystem(nonbondedMethod=app.PME, constraints=app.HBonds) # Doesn't work on rickhouse?
    positions = inpcrd.getPositions()

    # Add a barostat to the system.
    system.addForce(openmm.MonteCarloBarostat(pressure, temperature, barostat_frequency))

    # Create Langevin integrator.
    print("Creating integrator...")
    integrator = openmm.LangevinIntegrator(temperature, collision_rate, timestep)

    # Create OpenMM app Simulation object.
    print("Creating Simulation object...")
    simulation = app.Simulation(prmtop.topology, system, integrator)

    # If a simulation exists, try to resume.
    filename = os.path.join(output_directory, 'trajectory.nc')
    resume = os.path.exists(filename)

    if resume:
        # NOTE: Any failure while resuming is allowed to propagate to the caller.
        # Falling back to a fresh start here would run with append=False and
        # overwrite the existing trajectory files.
        print("Attempting to resume from NetCDF trajectory...")
        from netcdfreporter import NetCDFReporter
        [positions, box_vectors, last_step] = NetCDFReporter.getLastFrame(filename)
        simulation.context.setPositions(positions)
        print(box_vectors)
        simulation.context.setPeriodicBoxVectors(box_vectors[0,:], box_vectors[1,:], box_vectors[2,:])
        simulation.currentStep = last_step
        # Return if we've already reached target simulation time.
        if targetStep <= simulation.currentStep:
            return
    else:
        # Set positions in Context.
        print("Setting positions...")
        simulation.context.setPositions(positions)

        # Write initial positions.
        print("Writing initial positions to PDB file...")
        filename = os.path.join(output_directory, 'initial.pdb')
        with open(filename, 'w') as outfile:
            app.PDBFile.writeFile(prmtop.topology, positions, outfile)

        # Compute initial energy.
        state = simulation.context.getState(getEnergy=True)
        potential = state.getPotentialEnergy()
        print("Initial energy is %.3f kcal/mol" % (potential / units.kilocalories_per_mole))

        # Minimize energy.
        print("Minimizing energy...")
        simulation.minimizeEnergy(maxIterations=minimization_steps)

        # Compute minimized energy.
        state = simulation.context.getState(getEnergy=True)
        potential = state.getPotentialEnergy()
        print("Minimized is %.3f kcal/mol" % (potential / units.kilocalories_per_mole))

        # Write PDB file.
        print("Writing minimized positions to PDB file...")
        positions = simulation.context.getState(getPositions=True).getPositions()
        filename = os.path.join(output_directory, 'minimized.pdb')
        with open(filename, 'w') as outfile:
            app.PDBFile.writeFile(prmtop.topology, positions, outfile)

    if verbose:
        # TODO: Have this written to a file instead.
        # Add state data reporter to print energy and temperature to the console.
        from statedatareporter import StateDataReporter
        simulation.reporters.append(StateDataReporter(sys.stdout, nsteps_report, step=True, potentialEnergy=True, temperature=True, volume=True))

    # Add PDB reporter to write trajectory frames.
    from pdbreporter import PDBReporter
    filename = os.path.join(output_directory, 'trajectory.pdb')
    simulation.reporters.append(PDBReporter(filename, nsteps_trajectory_pdb, append=resume))

    # Add PDB reporter to write only final frame.
    filename = os.path.join(output_directory, 'final.pdb')
    simulation.reporters.append(PDBReporter(filename, nsteps_final_pdb, finalFrameOnly=True))

    # Add NetCDF reporter.
    from netcdfreporter import NetCDFReporter
    filename = os.path.join(output_directory, 'trajectory.nc')
    simulation.reporters.append( NetCDFReporter(system, filename, nsteps_netcdf, writePositions=True, writeVelocities=False, writeEnergies=True, append=resume) )

    # Add a Stark shift reporter.
    from starkshiftreporter import StarkShiftReporter
    filename = os.path.join(output_directory, 'stark.nc')
    simulation.reporters.append( StarkShiftReporter(system, atom_indices, alpha, interval=nsteps_efield, filename=filename, format='netcdf', append=resume) )

    # Run simulation for however many steps remain (all of them on a fresh start).
    print("Running simulation...")
    simulation.step(targetStep - simulation.currentStep)

    # Completed!
    print("Done.")

    # TODO: Analyze?

    return

#===============================================================================
# MAIN
#===============================================================================

if __name__ == '__main__':
    # Example run on a small test system: bosutinib in solvent.
    prmtop_filename = 'bosutinib.prmtop'
    inpcrd_filename = 'bosutinib.crd'
    pdb_filename = 'bosutinib.pdb'

    # Alternative input set, kept for reference:
    #basedir = "src_tbosutinib_c4+d2_refine_waters_tls_38/WT/"
    #prmtop_filename = os.path.join(basedir, 'leap.complex.prmtop')
    #inpcrd_filename = os.path.join(basedir, 'leap.complex.inpcrd')
    #pdb_filename = os.path.join(basedir, 'leap.complex.pdb')

    # Linear Stark tuning rate of bosutinib, expressed in cm^-1 per (MV/cm).
    wavenumber_unit = units.centimeters**-1
    field_unit = units.mega*units.volts/units.centimeter
    alpha = 0.87 * wavenumber_unit / field_unit

    # Select the probe atoms (C3 then N1 of residue DB8) by simple match criteria.
    atoms = readPdbAtoms(pdb_filename)
    atom_indices = list()
    for probe_atom_name in ('C3', 'N1'):
        atom_indices.extend(findAtoms(atoms, resName='DB8', name=probe_atom_name))
    print(atom_indices)

    # Run the simulation with default length and settings.
    output_directory = 'stark-test'
    simulate_stark_shift(prmtop_filename, inpcrd_filename, pdb_filename, alpha, atom_indices, output_directory)

