"""
Modeling of point mutations with MODELLER.

Based on mutate_model.py from the MODELLER wiki (written by Ben Webb?):

http://salilab.org/modeller/wiki/Mutate%20model

TODO

* Allow multiple simultaneous point mutations.
* Allow inclusion of ligand during modelling.

"""

#===============================================================================
# IMPORTS
#===============================================================================

import os
import os.path
import sys

from modeller import *
from modeller.optimizers import molecular_dynamics, conjugate_gradients
from modeller.automodel import autosched

#===============================================================================
# DATA
#===============================================================================

# One-letter amino acid code -> three-letter residue name for the twenty
# standard amino acids.
three_letter_code = {
    'A': 'ALA', 'C': 'CYS', 'D': 'ASP', 'E': 'GLU', 'F': 'PHE',
    'G': 'GLY', 'H': 'HIS', 'I': 'ILE', 'K': 'LYS', 'L': 'LEU',
    'M': 'MET', 'N': 'ASN', 'P': 'PRO', 'Q': 'GLN', 'R': 'ARG',
    'S': 'SER', 'T': 'THR', 'V': 'VAL', 'W': 'TRP', 'Y': 'TYR',
}

# Inverse table: three-letter residue name -> one-letter amino acid code.
one_letter_code = {}
for one_letter, three_letter in three_letter_code.items():
    one_letter_code[three_letter] = one_letter

#===============================================================================
# SUBROUTINES
#===============================================================================

def optimize(atmsel, sched):
    """
    Relax the selected atoms in three phases: every step of the schedule,
    then molecular dynamics, then one more conjugate-gradient pass.

    ARGUMENTS

    atmsel - selection of atoms to optimize; passed unchanged to every optimizer
    sched - iterable of schedule steps, each providing an optimize() method

    """
    # Settings shared by the scheduled steps and the final minimization.
    minimizer_settings = dict(max_iterations=200, min_atom_shift=0.001)

    # Conjugate gradients, one run per schedule step.
    for stage in sched:
        stage.optimize(atmsel, **minimizer_settings)

    # Molecular dynamics.
    refine(atmsel)

    # Final conjugate-gradient minimization.
    minimizer = conjugate_gradients()
    minimizer.optimize(atmsel, **minimizer_settings)


def refine(atmsel):
    """
    Run molecular dynamics on the selected atoms: a heating phase through
    rising temperatures followed by a cooling phase through falling ones.

    ARGUMENTS

    atmsel - selection of atoms to run the dynamics on

    """
    # at T=1000, max_atom_shift for 4fs is cca 0.15 A.
    dynamics = molecular_dynamics(cap_atom_shift=0.39, md_time_step=4.0,
                                  md_return='FINAL')

    # Each phase is (max_iterations, equilibrate, temperatures).
    heating = (200, 20, (150.0, 250.0, 400.0, 700.0, 1000.0))
    cooling = (200, 600, (1000.0, 800.0, 600.0, 500.0, 400.0, 300.0))

    # Velocities are initialized for the very first run only.
    first_run = True
    for (iterations, equilibration, temperatures) in (heating, cooling):
        for temperature in temperatures:
            dynamics.optimize(atmsel,
                              init_velocities=first_run,
                              temperature=temperature,
                              max_iterations=iterations,
                              equilibrate=equilibration)
            first_run = False


def make_restraints(mdl1, aln):
    """
    Rebuild the restraints of a model from scratch, using homologs and the
    dihedral library for the dihedral angle restraints.

    ARGUMENTS

    mdl1 - model whose restraints are cleared and regenerated
    aln - alignment passed to every restraint-building call

    """
    restraints = mdl1.restraints
    restraints.clear()
    all_atoms = selection(mdl1)

    # Stereochemistry and backbone phi/psi restraints.
    for restraint_type in ('stereo', 'phi-psi_binormal'):
        restraints.make(all_atoms, restraint_type=restraint_type, aln=aln,
                        spline_on_site=True)

    # Omega and side-chain dihedral restraints.
    for dihedral in ('omega', 'chi1', 'chi2', 'chi3', 'chi4'):
        restraints.make(all_atoms,
                        restraint_type=dihedral + '_dihedral',
                        spline_range=4.0,
                        spline_dx=0.3,
                        spline_min_points=5,
                        aln=aln,
                        spline_on_site=True)

def _parse_mutations(mutations):
    """
    Break a mutation string into the pieces needed to apply each mutation.

    ARGUMENTS

    mutations (string) - zero or more mutations (e.g. 'M314L'), separated by whitespace

    RETURNS

    parsed_mutations (list of tuples) - one (mutation, original_residue_name, residue_index, mutated_residue_name)
       tuple per mutation, in input order; residue names are three-letter codes and residue_index is
       the string of digits taken from the mutation

    RAISES

    ValueError - if a mutation is not of the form <one-letter code><residue number><one-letter code>,
       or if it uses a one-letter code that is not a key of three_letter_code

    """
    import re

    # Anchored at both ends so that trailing characters (e.g. 'M314LX') are rejected instead of ignored.
    mutation_pattern = re.compile(r'^(\D)(\d+)(\D)$')

    parsed_mutations = list()
    for mutation in mutations.split():
        match = mutation_pattern.match(mutation)
        if match is None:
            raise ValueError("Mutation '%s' is not of the form <one-letter code><residue number><one-letter code> (e.g. 'M314L')." % mutation)
        (original_code, residue_index, mutated_code) = match.groups()
        for code in (original_code, mutated_code):
            if code not in three_letter_code:
                raise ValueError("Unknown one-letter residue code '%s' in mutation '%s'." % (code, mutation))
        parsed_mutations.append((mutation, three_letter_code[original_code], residue_index, three_letter_code[mutated_code]))

    return parsed_mutations


def _select_residues(mdl, chain, residue_indices):
    """
    Build a selection containing the given residues of one chain of a model.

    ARGUMENTS

    mdl - model to select from
    chain (string) - chain identifier used to index mdl.chains
    residue_indices (list of strings) - keys used to index the residues of that chain

    RETURNS

    picked - selection holding all of the requested residues

    """
    picked = selection()
    for residue_index in residue_indices:
        picked.add(selection(mdl.chains[chain].residues[residue_index]))
    return picked


def mutate(original_model_filename, mutations, chain='', verbose=False, seed=-49837, outfile=None):
    """
    Make a (multi)point mutation to a specified model, relaxing the local neighborhood of the mutation.

    ARGUMENTS

    original_model_filename (string) - filename of initial model; must be free of missing atoms or chain breaks
    mutations (string) - mutation to build as one-letter code, residue number, one-letter code; multiple mutations may be separated by spaces
       examples: '', 'M314L', 'M314L L325Y', 'M314L L325Y T338M'

    OPTIONAL ARGUMENTS

    chain (string) - chain identifier for source chain to mutate, or blank if no chain (default: '')
    verbose (boolean) - if True, will give verbose output (default: False)
    seed (int) - random seed; change if different models are desired (default: -49837)
    outfile (string) - filename to write the mutated model to (default: None, in which case the name is the
       original filename with its extension replaced by '_', the mutations joined by '_', and '.pdb',
       e.g. 'basemodel_M314L_L325Y.pdb')

    RAISES

    ValueError - if a mutation string is malformed or uses an unknown one-letter code
    Exception - if the residue found in the model does not match the original residue named in a mutation

    NOTES

    If no mutation is specified (mutations = ''), the initial model will be copied to outfile if specified; otherwise, no file will be generated.
    In that case MODELLER is not invoked at all.

    The mutations are validated before any file is touched.

    Modelling is carried out in a freshly created temporary working directory; the caller's working directory
    is restored on return, including when an error is raised.  The temporary directory itself is not removed,
    so the DEBUG files 'stage1.pdb' and 'stage2.pdb' written there remain available for inspection.

    TODO

    * Build missing atoms for 'modelname' if possible?
    * Remove the temporary working directory once the DEBUG output is no longer needed.
    * Allow user-specified topology and parameter files.

    """
    import os
    import os.path
    import shutil
    import tempfile

    if verbose: log.verbose()

    # Validate the request up front so that bad input fails before any side effect.
    parsed_mutations = _parse_mutations(mutations)

    # Resolve paths while we are still in the caller's working directory.
    original_model_fullpath = os.path.abspath(original_model_filename)

    # If no mutations are to be made, copy the initial model over (only if an output file was requested).
    if not parsed_mutations:
        if outfile is not None:
            mutated_model_fullpath = os.path.abspath(outfile)
            # copyfile refuses to copy a file onto itself; there is nothing to do in that case anyway.
            if mutated_model_fullpath != original_model_fullpath:
                shutil.copyfile(original_model_fullpath, mutated_model_fullpath)
        return

    if outfile is None:
        (root, extension) = os.path.splitext(original_model_filename)
        outfile = root + '_' + '_'.join(mutations.split()) + '.pdb'
    mutated_model_fullpath = os.path.abspath(outfile)

    # Change to a temporary working directory.
    original_directory = os.getcwd()
    temporary_directory = tempfile.mkdtemp()
    os.chdir(temporary_directory)
    try:
        if verbose: print("Using temporary working directory: %s" % temporary_directory)

        # Copy original model into temporary directory.
        original_modelname = 'original.pdb'
        shutil.copyfile(original_model_fullpath, original_modelname)

        # Set a different value for rand_seed to get a different final model
        env = environ(rand_seed=seed)

        # Read any heteroatoms in source PDB file.
        env.io.hetatm = True

        # Select potential to use for the first stage.
        env.edat.dynamic_sphere = True # soft sphere
        env.edat.dynamic_lennard = False # Lennard-Jones
        env.edat.contact_shell = 4.0
        env.edat.update_dynamic = 0.39

        # Read customized topology file with phosphoserines (or standard one)
        env.libs.topology.read(file='$(LIB)/top_heav.lib')

        # Read customized CHARMM parameter library with phosphoserines (or standard one)
        env.libs.parameters.read(file='$(LIB)/par.lib')

        # Read the original PDB file and copy its sequence to the alignment array:
        mdl1 = model(env, file=original_modelname)
        ali = alignment(env)
        ali.append_model(mdl1, atom_files=original_modelname, align_codes=original_modelname)

        # Apply the mutations to the sequence, building a list of all mutated residues.
        # residue_index is a string; presumably MODELLER resolves string keys by the residue
        # number in the PDB file rather than by position -- confirm against the MODELLER manual.
        mutated_residue_indices = list()
        for (mutation, original_residue_name, residue_index, mutated_residue_name) in parsed_mutations:
            residue = mdl1.chains[chain].residues[residue_index]

            # Check the original residue name corresponds to our source model.
            current_residue_name = residue.name
            if current_residue_name != original_residue_name:
                raise Exception("Original residue type is '%s', was supposed to be '%s' for mutation '%s'." % (current_residue_name, original_residue_name, mutation))

            # Mutate the residue.
            selection(residue).mutate(residue_type=mutated_residue_name)

            # Add mutated residue to list of mutated residues.
            mutated_residue_indices.append(residue_index)

        # Append the mutated sequence to the alignment.
        ali.append_model(mdl1, align_codes=original_modelname)

        # Generate molecular topology for mutant.
        mdl1.clear_topology()
        mdl1.generate_topology(ali[-1])

        # Transfer all the coordinates you can from the template native structure
        # to the mutant (this works even if the order of atoms in the native PDB
        # file is not standard):
        # here we are generating the model by reading the template coordinates
        mdl1.transfer_xyz(ali)

        # Build the remaining unknown coordinates from internal coordinate tables.
        mdl1.build(initialize_xyz=False, build_method='INTERNAL_COORDINATES')

        # Note that model2 is the same file as model1.
        mdl2 = model(env, file=original_modelname)

        # Transfer residue numbering from "model 2" to "model 1".
        mdl1.res_num_from(mdl2, ali)

        # It is usually necessary to write the mutated sequence out and read it in
        # before proceeding, because not all sequence related information about MODEL
        # is changed by this command (e.g., internal coordinates, charges, and atom
        # types and radii are not updated).
        temporary_file = tempfile.NamedTemporaryFile(delete=False)
        temporary_filename = temporary_file.name
        # Only the name is needed; close our handle before MODELLER writes to the file.
        temporary_file.close()
        try:
            mdl1.write(file=temporary_filename)
            mdl1.read(file=temporary_filename)
        finally:
            os.unlink(temporary_filename)

        # Set up restraints before computing energy.
        # This is done after the round trip above because the model has been
        # written out and read in, clearing the previously set restraints.
        make_restraints(mdl1, ali)

        # a non-bonded pair has to have at least as many selected atoms
        mdl1.env.edat.nonbonded_sel_atoms = 1

        # Create annealing schedule.
        sched = autosched.loop.make_for_model(mdl1)

        #
        # First stage of refinement
        #
        # Only optimize the mutated residues (in first pass, just atoms in the
        # selected residues; in second pass, include nonbonded neighboring atoms).
        refinement_selection = _select_residues(mdl1, chain, mutated_residue_indices)
        mdl1.restraints.unpick_all()
        mdl1.restraints.pick(refinement_selection)

        # Compute energy.
        refinement_selection.energy()

        # Minimize energy.
        mdl1.env.edat.nonbonded_sel_atoms = 2
        optimize(refinement_selection, sched)

        # feels environment (energy computed on pairs that have at least one member
        # in the selected)
        mdl1.env.edat.nonbonded_sel_atoms = 1
        optimize(refinement_selection, sched)

        # Compute final energy.
        refinement_selection.energy()

        # DEBUG
        mdl1.write(file='stage1.pdb')

        #
        # Second stage of refinement
        #
        # NOTE(review): these switches are set on `env`, whereas nonbonded_sel_atoms above is
        # set on `mdl1.env`; confirm that changes to `env` made after mdl1 was created are
        # actually seen by the optimizer.
        env.edat.dynamic_sphere = False # soft sphere
        env.edat.dynamic_lennard = True # Lennard-Jones

        refinement_selection = _select_residues(mdl1, chain, mutated_residue_indices)
        mdl1.restraints.unpick_all()
        mdl1.restraints.pick(refinement_selection)

        optimize(refinement_selection, sched)

        refinement_selection.energy()

        # DEBUG
        mdl1.write(file='stage2.pdb')

        # Write file.
        mdl1.write(file='mutated.pdb')
        shutil.copyfile('mutated.pdb', mutated_model_fullpath)
    finally:
        # Change back to original directory, whether or not modelling succeeded.
        os.chdir(original_directory)

    return

#===============================================================================
# MAIN
#===============================================================================

# TODO: Fix this command-line driver to use optparse.

if __name__ == '__main__':
    # Test on some data: mutate chain B of 'basemodel.pdb' and write the
    # result to 'final.pdb'.
    # kinases that bind bosutinib tightly but do not look like Src have these three mutations
    mutate(
        'basemodel.pdb',
        'M314L L325Y T338M',
        chain='B',
        verbose=True,
        seed=-49837,
        outfile='final.pdb')
