#=============================================================================================
# model.py
#
# Use MODELLER to complete PDB files that are missing atoms and residues.
#
# Written 2007-05-20 by
# Gregory R. Bowman <gbowman@stanford.edu>
# Biophysics Program, Stanford University
# Pande group
#
# Modifications 2012-08-08 by
# John D. Chodera <jchodera@berkeley.edu>
# California Institute for Quantitative Biosciences (QB3)
# University of California, Berkeley
#=============================================================================================
# TODO:
# * Add the ability to fetch a PDB file from the RCSB for source files and/or templates.
#=============================================================================================
# CHANGELOG:
#=============================================================================================


#=============================================================================================
# GLOBAL IMPORTS
#=============================================================================================

import os
import os.path
import shutil
import tempfile
import commands

import modeller

#=============================================================================================
# SUBROUTINES
#=============================================================================================

# Map from three-letter PDB residue names (as they appear in SEQRES records) to the
# one-letter codes written into PIR alignment sequences.
# NLE maps to a lowercase 'n'; UNK maps to 'X'.
one_letter_code = {
    'ALA': 'A',
    'ASN': 'N',
    'ASP': 'D',
    'ARG': 'R',
    'CYS': 'C',
    'GLN': 'Q',
    'GLU': 'E',
    'GLY': 'G',
    'HIS': 'H',
    'ILE': 'I',
    'LEU': 'L',
    'NLE': 'n',
    'LYS': 'K',
    'MET': 'M',
    'PRO': 'P',
    'PHE': 'F',
    'SER': 'S',
    'THR': 'T',
    'TRP': 'W',
    'TYR': 'Y',
    'VAL': 'V',
    'UNK': 'X',
}

# Inverse map: one-letter code -> three-letter residue name.
# Built with a generator expression so no loop variables leak into the module namespace.
three_letter_code = dict((code, name) for (name, code) in one_letter_code.items())

def read_file(filename):
    """
    Read a text file and return its contents as a list of lines.

    ARGUMENTS
      filename (string) - path of the file to read

    RETURNS
      lines (list of string) - the lines of the file, each keeping its trailing newline (if any)
    """
    # The context manager closes the file even if reading raises.
    with open(filename, 'r') as infile:
        return infile.readlines()

def write_file(filename, contents):
    """
    Write contents to a text file, overwriting any existing file.

    ARGUMENTS
      filename (string) - path of the file to write
      contents (list or tuple of string, or string) - if a list or tuple, each element is
        written as one line, and a newline is appended to every element that does not
        already end with one (an empty element therefore produces an empty line);
        otherwise contents is written verbatim.
    """
    with open(filename, 'w') as outfile:
        if isinstance(contents, (list, tuple)):
            for line in contents:
                outfile.write(line)
                # endswith() is safe on empty strings, unlike indexing line[-1].
                if not line.endswith('\n'):
                    outfile.write('\n')
        else:
            outfile.write(contents)

def build_manual_alignment(template_pdb_filename, alignment_filename, retain_hetatm=False, retain_water=False, debug=False):
    """
    Build a manual alignment between the residues present in the template PDB
    file and the full sequence defined in the SEQRES records in the file.

    REQUIRED ARGUMENTS
      template_pdb_filename (string) - filename of the PDB template
      alignment_filename (string) - filename for the alignment file to be written in PIR format

    OPTIONAL ARGUMENTS
      retain_hetatm (boolean) - if True, non-water HETATMs will be included
      retain_water (boolean) - if True, waters will be included
      debug (boolean) - print additional information if True (default False)

    RETURNS
      [chain_list, renumber_residues]
        chain_list (list of string) - chain IDs in the order in which they appear in the
          retained ATOM/HETATM records
        renumber_residues (list of int) - renumber_residues[i] is the residue number of the
          first target residue of chain chain_list[i]

    RAISES
      KeyError - if a SEQRES residue name is not listed in one_letter_code, or if a chain
        with SEQRES records has no DBREF record to define its first residue number
      ValueError - if no ATOM (or retained HETATM) records are present

    NOTES

    This function uses the DBREF and SEQRES fields which are required by the RCSB in deposited
    files but may not be present in PDB files from other sources.

    """
    import sys

    # Read the template PDB file into memory.
    lines = read_file(template_pdb_filename)

    # TODO: Throw an exception if required fields are not found.

    water_residue_names = ['HOH', 'WAT', 'TIP', 'TP3']

    # Get the complete sequence for all present chains from the SEQRES fields.
    if debug: print("Processing SEQRES fields...")
    target_sequences = dict() # target_sequences[chainid] is a string of one-letter codes for chain 'chainid'
    for line in lines:
        if line[0:6] == 'SEQRES':
            # Parse line into fields: chainID (column 12) and residue names (columns 20-70).
            chainID = line[11:12]
            resNames = line[19:70].split()
            # Append one-letter-code residue names, creating a new entry if needed.
            codes = "".join([one_letter_code[resName] for resName in resNames])
            target_sequences[chainID] = target_sequences.get(chainID, "") + codes
    if debug:
        print("From SEQRES read:")
        print(target_sequences)

    # Extract the first residue for each chain from the DBREF field.
    first_residue_indices = dict() # first_residue_indices[chainid] is the first residue index of chain 'chainid'
    dbref = dict() # dbref[chainid] is the DBREF entry for chain 'chainid'
    for line in lines:
        if line[0:6] == "DBREF ":
            # Parse line into fields.
            field = dict()
            field["idCode"] = line[7:11]
            field["chainID"] = line[12:13]
            field["seqBegin"] = int(line[14:18])
            field["insertBegin"] = line[18:19]
            field["seqEnd"] = int(line[20:24])
            field["insertEnd"] = line[24:25]
            field["database"] = line[26:32]
            field["dbAccession"] = line[33:41]
            field["dbIdCode"] = line[42:54]
            field["dbseqBegin"] = line[55:60]
            field["idbnsBeg"] = line[60:61]

            dbref[field['chainID']] = field

            # Guess that first residue is given by seqBegin.
            first_residue_indices[field['chainID']] = field['seqBegin']

    # Process SEQADV records, if present.
    # COLUMNS       DATA TYPE       FIELD      DEFINITION
    # -----------------------------------------------------------------
    #  1 -  6       Record name     "SEQADV"
    #  8 - 11       IDcode          idCode    ID code of this entry.
    # 13 - 15       Residue name    resName   Name of the PDB residue in conflict.
    # 17            Character       chainID   PDB chain identifier.
    # 19 - 22       Integer         seqNum    PDB sequence number.
    # 23            AChar           iCode     PDB insertion code.
    # 25 - 28       LString         database
    # 30 - 38       LString         dbIdCode  Sequence database accession number.
    # 40 - 42       Residue name    dbRes     Sequence database residue name.
    # 44 - 48       Integer         dbSeq     Sequence database sequence number.
    # 50 - 70       LString         conflict  Conflict comment.
    for line in lines:
        if line[0:6] != 'SEQADV':
            continue
        if debug: sys.stdout.write(line)
        seqNum = line[18:22].strip()
        if not seqNum:
            # No PDB residue number (e.g. a deletion relative to the database sequence):
            # nothing to compare against the first residue, and int('') would fail.
            continue
        # Parse line into fields.
        field = dict()
        field['idCode'] = line[7:11]
        field['resName'] = line[12:15]
        field['chainID'] = line[16:17]
        field['seqNum'] = int(seqNum)
        field['iCode'] = line[22:23]
        field['database'] = line[24:28]
        field['dbIdCode'] = line[29:38]
        field['dbRes'] = line[39:42]
        field['dbSeq'] = line[43:48]
        field['conflict'] = line[49:70]

        # If SEQADV has an earlier-numbered residue than DBREF, change the first residue number.
        # This is often due to the addition of N-terminal "CLONING ARTIFACT" residues.
        # Chains without a DBREF entry have no first residue to refine here.
        chainid = field['chainID']
        if (chainid in first_residue_indices) and (field['seqNum'] < first_residue_indices[chainid]):
            first_residue_indices[chainid] = field['seqNum']

    # Extract atoms from PDB file: every ATOM record, plus HETATM records only when
    # requested via retain_water (waters) or retain_hetatm (everything else).
    atoms = list()
    for line in lines:
        recordtype = line[0:6]
        if recordtype not in ("ATOM  ", "HETATM"):
            continue
        resName = line[17:20]
        if recordtype == "HETATM":
            is_water = resName in water_residue_names
            if not ((retain_water and is_water) or (retain_hetatm and not is_water)):
                continue
        # Parse line into fields.
        atom = dict()
        atom["name"] = line[12:16]
        atom["altLoc"] = line[16:17]
        atom["resName"] = resName
        atom["chainID"] = line[21:22]
        atom["resSeq"] = int(line[22:26])
        atom["iCode"] = line[26:27]
        atom["recordtype"] = recordtype
        atoms.append(atom)
    if not atoms:
        raise ValueError("No ATOM (or retained HETATM) records found in %s." % template_pdb_filename)

    # Build a list of residues present in the PDB file, including HETATM residues.
    # Consecutive atoms with identical residue identity collapse into one residue.
    residues = list()
    for atom in atoms:
        residue = dict()
        residue['resName'] = atom['resName']
        residue['chainID'] = atom['chainID']
        residue['resSeq'] = atom['resSeq']
        residue['recordtype'] = atom['recordtype']
        if not residues or residue != residues[-1]:
            residues.append(residue)
    if debug:
        print("Residues appearing in ATOM and HETATM records:")
        for residue in residues:
            print("%c %8d %3s %6s" % (residue['chainID'], residue['resSeq'], residue['resName'], residue['recordtype']))

    # Build a list of the order in which chains appear in the PDB file ATOM and HETATM records.
    # NOTE(review): a chain ID whose records are not contiguous (e.g. chain A HETATMs listed
    # after chain B) is appended again and therefore appears more than once in chain_list.
    chain_list = list()
    for residue in residues:
        if not chain_list or residue['chainID'] != chain_list[-1]:
            chain_list.append(residue['chainID'])
    if debug: print(chain_list)

    # Chains without SEQRES records get one placeholder per residue present:
    # 'w' for waters and '.' (BLK residue) for everything else.
    for chainid in chain_list:
        if chainid in target_sequences:
            continue
        chain_residues = [residue for residue in residues if residue['chainID'] == chainid]
        target_sequences[chainid] = "".join(
            ["w" if residue['resName'] in water_residue_names else "." for residue in chain_residues])
        if chainid not in first_residue_indices:
            first_residue_indices[chainid] = int(chain_residues[0]['resSeq'])
    if debug:
        print("target_sequences = ")
        print(target_sequences)

    # Build the template sequence for each chain: residues present in the coordinates
    # keep their code, residues absent from the coordinates become gaps ('-').
    template_sequences = dict()
    for chainid in chain_list:
        sequence = target_sequences[chainid]
        if debug: print(chainid + " : " + sequence)
        if sequence[:1] in ('.', 'w'):
            # Copy BLK residue spans for HETATMs.
            template_sequences[chainid] = sequence
        else:
            residues_in_chain = set([residue['resSeq'] for residue in residues if residue['chainID'] == chainid])
            first_residue_index = int(first_residue_indices[chainid])
            template_sequences[chainid] = "".join([
                residue_code if (first_residue_index + position_index) in residues_in_chain else "-"
                for (position_index, residue_code) in enumerate(sequence)])

    # Build the full PIR sequence strings: chains separated by '/', terminated by '*'.
    template_sequence = "/".join([template_sequences[chainid] for chainid in chain_list]) + "*"
    target_sequence = "/".join([target_sequences[chainid] for chainid in chain_list]) + "*"

    # Compute last residue indices.
    last_residue_indices = dict()
    for chainid in chain_list:
        last_residue_indices[chainid] = first_residue_indices[chainid] + len(template_sequences[chainid]) - 1

    # Extract first and last residues present in template.
    first_template_residue = atoms[0]['resSeq']
    last_template_residue = atoms[-1]['resSeq']
    first_chainid = atoms[0]['chainID']
    last_chainid = atoms[-1]['chainID']
    first_target_residue = first_residue_indices[first_chainid]
    last_target_residue = last_residue_indices[last_chainid]

    # Write the manual alignment in PIR format.
    with open(alignment_filename, 'w') as alignment_file:
        alignment_file.write(">P1;%s\n" % "template")
        alignment_file.write("%s:%s:%s:%s:%s:%s:%s:%s:%s:%s\n" % ("structure", "template", first_template_residue, first_chainid, last_template_residue, last_chainid, " ", " ", " ", " "))
        alignment_file.write("%s\n" % template_sequence)
        alignment_file.write("\n")
        alignment_file.write(">P1;%s\n" % "target")
        alignment_file.write("%s:%s:%s:%s:%s:%s:%s:%s:%s:%s\n" % ("sequence", "target", first_target_residue, first_chainid, last_target_residue, last_chainid, " ", " ", " ", " "))
        alignment_file.write("%s\n" % target_sequence)

    # Echo the alignment file.
    if debug:
        for line in read_file(alignment_filename):
            sys.stdout.write(line)

    renumber_residues = [first_residue_indices[chainid] for chainid in chain_list]
    return [chain_list, renumber_residues]

def model_missing_atoms(pdbFilename, outputFilename, debug=False, allHydrogen=False, templates=None, retainOriginalNumbering=True, retainHETATM=True, retainWater=True):
    """
    Model missing atoms/residues in a specified PDB file using MODELLER.

    REQUIRED ARGUMENTS
    pdbFilename - the filename of the PDB file to model missing atoms and residues for
    outputFilename - the filename for the desired final model

    OPTIONAL ARGUMENTS
    debug - flag to print extra debug output and leave temporary directory (default False)
    allHydrogen (boolean) - build an all-hydrogen model; not implemented, raises if True (default False)
    templates (list of string) - names of additional templates to use in modeling missing atoms and residues;
      each name refers to the file <name>.pdb in the current working directory (default None, no extra templates)
    retainOriginalNumbering (boolean) - if True, will retain original numbering; otherwise will number from 1 (default True)
    retainHETATM (boolean) - if True, will include non-water HETATMs (default True)
    retainWater (boolean) - if True, will include waters (default True)

    RAISES
    IOError - if pdbFilename does not exist
    Exception - if allHydrogen is True (not implemented)

    NOTES

    All chains of pdbFilename are processed through MODELLER to build missing
    atoms and residues specified in the SEQRES entry of the PDB file but not present in
    the PDB file.

    This procedure is loosely based on the protocol appearing at

    http://salilab.org/modeller/wiki/Missing_residues

    The complete sequence is read from the SEQRES fields, and the DBREF field used to
    determine the span of residues described in the SEQRES fields.  A heavy-atom topology
    is constructed in MODELLER for the complete sequence, coordinates present in the PDB file
    transferred, and the remaining heavy-atom coordinates built from ideal geometry.
    Finally, a single standard simulated-annealing-based modeling step is performed using
    the standard automodel protocol but allowing only the atoms and residues that were undefined in
    the PDB file to move.

    MODELLER is run inside a temporary directory.  The original working directory is always
    restored, also when an exception is raised, and the temporary directory is removed
    unless debug is True.

    """
    import sys

    # Ensure specified PDB file exists.
    if not os.path.exists(pdbFilename):
        raise IOError("Specified PDB file %s not found." % pdbFilename)

    # A default of None means "no additional templates".
    if templates is None:
        templates = []

    # Append full path to input and output filenames.
    pdbFilename = os.path.abspath(pdbFilename)
    outputFilename = os.path.abspath(outputFilename)

    def show_file(filename):
        # Echo a file's contents to stdout (debug output).
        with open(filename, 'r') as infile:
            sys.stdout.write(infile.read())

    # Store original directory name.
    original_directory = os.getcwd()

    # Create a temporary directory for running MODELLER.
    tmpdir = tempfile.mkdtemp()
    if debug: print("tmpdir = %s" % tmpdir)
    os.chdir(tmpdir)
    try:
        # Copy template PDB.
        template_pdb_filename = 'template.pdb'
        print("Copying %s to %s" % (pdbFilename, template_pdb_filename))
        shutil.copy(pdbFilename, template_pdb_filename)
        if debug:
            # Show the first 20 lines of the copied template.
            with open(template_pdb_filename, 'r') as infile:
                for (line_number, line) in enumerate(infile):
                    if line_number >= 20:
                        break
                    sys.stdout.write(line)

        # Copy any additional specified templates to the temporary directory,
        # stripping out HETATM lines.
        # TODO: Use MODELLER to read/write PDB file?
        # TODO: What if template_name contains a pathname?
        for template_name in templates:
            original_template_pdb_filename = os.path.join(original_directory, template_name + '.pdb')
            local_template_pdb_filename = template_name + '.pdb'
            with open(original_template_pdb_filename, 'r') as infile:
                with open(local_template_pdb_filename, 'w') as outfile:
                    for line in infile:
                        if not line.startswith('HETATM'):
                            outfile.write(line)

        # Build a manual alignment between template and target.
        alignment_filename = "transfer.ali"
        [segment_ids, renumber_residues] = build_manual_alignment(template_pdb_filename, alignment_filename, debug=False, retain_hetatm=retainHETATM, retain_water=retainWater)

        # Call MODELLER to generate topology, transfer coordinates, and build from internal coordinates.
        import modeller
        import modeller.automodel

        # Create a new environment.
        env = modeller.environ()

        # Set up MODELLER paths.
        env.io.atom_files_directory = [tmpdir]

        # Configure HETATM/water reading on this environment's I/O settings
        # (rather than assigning attributes on the io_data class itself).
        if retainHETATM:
            env.io.hetatm = True
        if retainWater:
            env.io.water = True

        # Specify the topology and parameters to use.
        if allHydrogen:
            raise Exception("Not implemented.")
        env.libs.topology.read(file='$(LIB)/top_heav.lib')
        env.libs.parameters.read(file='$(LIB)/par.lib')

        # Read in alignment.
        aln = modeller.alignment(env)
        aln.append(file=alignment_filename, align_codes='all')

        # Create a model.
        model = modeller.model(env)

        # Generate the topology from the target sequence.
        model.generate_topology(aln['target'])

        # Transfer defined coordinates from template.
        model.transfer_xyz(aln)

        # Record the atoms that are still undefined (x == -999) because they are missing in the template.
        missing_atom_indices = []
        for atom_index in range(len(model.atoms)):
            if model.atoms[atom_index].x == -999:
                missing_atom_indices.append(atom_index)

        if retainOriginalNumbering:
            # Rename segments.
            model.rename_segments(segment_ids=segment_ids)
            model.rename_segments(segment_ids=segment_ids, renumber_residues=renumber_residues)

        # Write model coordinates to a PDB file.
        transferred_pdb_filename = 'transferred.pdb'
        model.write(file=transferred_pdb_filename)

        # Build the remaining undefined atomic coordinates from ideal internal coordinates stored in residue topology files.
        model.build(initialize_xyz=False, build_method='INTERNAL_COORDINATES')

        if retainOriginalNumbering:
            # Rename segments.
            model.rename_segments(segment_ids=segment_ids)
            model.rename_segments(segment_ids=segment_ids, renumber_residues=renumber_residues)

        # Write model coordinates to a PDB file.
        built_pdb_filename = 'built.pdb'
        model.write(file=built_pdb_filename)

        # Add additional templates to the alignment, if specified.
        knowns = ['template']
        aln = modeller.alignment(env)
        aln.append(file=alignment_filename, align_codes=knowns)
        if templates:
            # TODO: Because we're modeling multiple chains, do we have to try aligning templates to each chain?
            for template_name in templates:
                print(template_name)
                # Load the local (HETATM-stripped) copy of the template made above.
                template_pdb_filename = template_name + '.pdb'
                template_model = modeller.model(env, file=template_pdb_filename)
                aln.append_model(template_model, align_codes=template_name)
                knowns.append(template_name)

            for (weights, write_fit, whole) in (((1., 0., 0., 0., 1., 0.), False, True),
                                                ((1., 0.5, 1., 1., 1., 0.), False, True),
                                                ((1., 1., 1., 1., 1., 0.), True, False)):
                aln.salign(rms_cutoff=3.5, normalize_pp_scores=False,
                           rr_file='$(LIB)/as1.sim.mat', overhang=30,
                           gap_penalties_1d=(-9000, -50),
                           gap_penalties_3d=(0, 3), gap_gap_score=0, gap_residue_score=0,
                           alignment_type='tree', # If 'progressive', the tree is not
                                    # computed and all structures will be
                                    # aligned sequentially to the first
                           feature_weights=weights, # For a multiple sequence alignment only
                                    # the first feature needs to be non-zero
                           improve_alignment=True, fit=True, write_fit=write_fit,
                           write_whole_pdb=whole, output='ALIGNMENT QUALITY')

            aln.salign(rms_cutoff=1.0, normalize_pp_scores=False,
                       rr_file='$(LIB)/as1.sim.mat', overhang=30,
                       gap_penalties_1d=(-450, -50), gap_penalties_3d=(0, 3),
                       gap_gap_score=0, gap_residue_score=0,
                       alignment_type='progressive', feature_weights=[0]*6,
                       improve_alignment=False, fit=False, write_fit=True,
                       write_whole_pdb=False, output='QUALITY')

        # Write template alignment.
        alignment_filename = 'templates.ali'
        aln.write(file=alignment_filename, alignment_format='PIR')
        if debug:
            print("templates.ali:")
            show_file(alignment_filename)

        # Align target to templates.
        aln = modeller.alignment(env)
        aln.append(file='transfer.ali', align_codes='all')
        aln_block = len(aln)
        aln.append(file='templates.ali', align_codes='all')

        # Structure sensitive variable gap penalty sequence-sequence alignment:
        aln.salign(output='', max_gap_length=20,
                   gap_function=True,   # to use structure-dependent gap penalty
                   alignment_type='PAIRWISE', align_block=aln_block,
                   feature_weights=(1., 0., 0., 0., 0., 0.), overhang=0,
                   gap_penalties_1d=(-450, 0),
                   gap_penalties_2d=(0.35, 1.2, 0.9, 1.2, 0.6, 8.6, 1.2, 0., 0.),
                   similarity_flag=True)

        alignment_filename = 'aligned.ali'
        aln.write(file=alignment_filename, alignment_format='PIR')
        if debug:
            print("aligned.ali:")
            show_file(alignment_filename)

        # Subclass automodel (or allhmodel) so that only atoms with undefined coordinates
        # in the template PDB are selected to move.
        if allHydrogen:
            base_model_class = modeller.automodel.allhmodel
        else:
            base_model_class = modeller.automodel.automodel

        class mymodel(base_model_class):
            def select_atoms(self):
                missing_atoms = modeller.selection()
                for atom_index in missing_atom_indices:
                    missing_atoms.add(self.atoms[atom_index])
                return missing_atoms
            def special_patches(self, aln):
                if retainOriginalNumbering:
                    self.rename_segments(segment_ids=segment_ids, renumber_residues=renumber_residues)
                return

        # Ensure selected atoms feel all nonbonded interactions.
        env.edat.nonbonded_sel_atoms = 1

        # Set up automodel, starting from the built coordinates.
        # NOTE(review): the original author marked this inifile-based call "This doesn't work"
        # and the variant without inifile "This works"; confirm which is intended.
        a = mymodel(env, inifile=built_pdb_filename, alnfile=alignment_filename, knowns=knowns, sequence='target')

        # Build only one model.
        # TODO: Have more models built by default (perhaps 50?)
        a.starting_model = 1
        a.ending_model = 1

        # Generate model(s).
        a.make()

        # TODO: Rescore models and select the best one.
        # For now, we only use the first model.
        final_model_summary = a.outputs[0]

        # Copy resulting model to desired output PDB filename.
        shutil.copy(final_model_summary['name'], outputFilename)
    finally:
        # Always restore the working directory, even if MODELLER failed.
        os.chdir(original_directory)

        # Clean up temporary directory (kept for inspection in debug mode).
        if not debug:
            shutil.rmtree(tmpdir, ignore_errors=True)

    return

#=============================================================================================
# MAIN
#=============================================================================================

if __name__ == '__main__':
    import sys
    import doctest
    doctest.testmod()

    # Attempt modeling.
    # Usage: python model.py [source.pdb [output.pdb [template_name ...]]]
    # Template names are given without the '.pdb' extension (e.g. 3LCK); each refers to
    # <name>.pdb in the current working directory.  Without arguments, the example
    # files below are used and no additional templates.
    source_pdb_filename = 'src_tbosutinib_c4+d2_refine_waters_tls_38-seqres.pdb'
    output_pdb_filename = 'src_tbosutinib_c4+d2_refine_waters_tls_38.modeller.pdb'
    templates = []

    arguments = sys.argv[1:]
    if len(arguments) > 0:
        source_pdb_filename = arguments[0]
    if len(arguments) > 1:
        output_pdb_filename = arguments[1]
    if len(arguments) > 2:
        templates = arguments[2:]

    model_missing_atoms(source_pdb_filename, output_pdb_filename, templates=templates, debug=True)
