# This file is a part of the Lineshaper package
# Stephen Fried, January 2012

import numpy as np

from msmbuilder import Serializer
from msmbuilder import Project
from msmbuilder.Scripts import SavePDBs, GetRandomConfs

import sys
import glob
import os
import re
import time
import subprocess
import pickle

from argparse import ArgumentParser


def getFileName(ConfNo, StateNo):
    """ Builds the path of the PDB file holding conformation `ConfNo` of
    state `StateNo`: './PDBs/State<StateNo>-<ConfNo>.pdb'.

    Note that the argument order (conformation first) is the reverse of the
    order in which the two indices appear in the filename. """
    return './PDBs/State{0}-{1}.pdb'.format(StateNo, ConfNo)


def getNumShiftsAndMeta(file):
    """ Gets the number of shifts and the metadata for each shift.

    Reads a SPARTA+ prediction table. Every line that starts with a space is
    treated as a data row (the intro/header lines are skipped). For each data
    row, columns 0-2 and column 5 of the whitespace-split line are kept: per
    the original author, the residue no., residue name, atom type and
    random-coil (RC) shift.

    Returns (count, meta). `meta` maps 0..count-1, in file order, to that
    4-item list of strings. The same data-row order is used for the shift
    index when main() fills its shift arrays.
    """
    X = {}
    # 'with' guarantees the table is closed even if parsing raises
    with open(file, 'r') as f:
        for line in f:
            if line.startswith(' '):   # data rows; the intro does not start with a space
                fields = line.split()
                X[len(X)] = fields[0:3] + fields[5:6]
    return len(X), X


def getStateAndConf(file):
    """ Parses a SPARTA+ output filename of the form StateXX-YY_pred.tab
    (optionally preceded by a directory path) and returns (StateNo, ConfNo)
    = (int(XX), int(YY)).

    YY is used as-is. It is the 0-based conformation index that
    getFileName() embeds in the PDB name, so no offset is applied.
    """
    # Work on the basename so a directory whose name contains 'State' or
    # '_pred' cannot confuse the search.
    base = os.path.basename(file)
    start = base.rfind('State') + len('State')
    stop = base.rfind('_pred')
    parts = base[start:stop].split('-')
    StateNo = int(parts[0])
    ConfNo = int(parts[1])
    return StateNo, ConfNo


def getShift(line):
    """ Parses one SPARTA+ data row and returns (shift, secondaryShift).

    `shift` is column 4 as a float. `secondaryShift` is column 4 minus
    column 5, i.e. the shift with the random-coil value subtracted out. """
    columns = line.split()
    predicted = float(columns[4])
    return predicted, predicted - float(columns[5])


def generate_PBS_script(commands, f_num, ppn):
    """ Writes a PBS job script named lsh_pbs_<f_num>.sh in the current
    directory and returns the queue it should be submitted to.

    Each entry of `commands` is launched in the background ('&'), and the
    script then `wait`s for all of them. The job asks for one node with `ppn`
    processors and a 24:00:00 walltime. Its stdout/stderr go to
    lineshaper_<f_num>.out / .err in the current directory.

    The queue is hard-wired to 'default' for now.
    """
    queue = 'default'
    workdir = os.path.abspath(".")
    body = "".join("%s &\n" % cmd for cmd in commands) + "wait\n"

    template = """#!/bin/bash

### THIS SCRIPT WAS GENERATED BY GetNMRData.py, A PART OF lineshaper

#PBS -N lineshaper_%(num)d
#PBS -e %(dir)s/lineshaper_%(num)d.err
#PBS -o %(dir)s/lineshaper_%(num)d.out
#PBS -l nodes=1:ppn=%(ppn)d
#PBS -l walltime=24:00:00
#PBS -V


PBS_O_WORKDIR='%(dir)s'
export PBS_O_WORKDIR
### ---------------------------------------
### BEGINNING OF EXECUTION
### ---------------------------------------

        echo The master node of this job is `hostname`
        echo The working directory is `echo $PBS_O_WORKDIR`
        echo This job runs on the following nodes:
        echo `cat $PBS_NODEFILE`


### end of information preamble
cd $PBS_O_WORKDIR

# execute commands
%(body)s
"""
    script = template % {'num': f_num, 'dir': workdir, 'ppn': ppn, 'body': body}

    with open('lsh_pbs_%d.sh' % f_num, 'w') as out:
        out.write(script)

    return queue


def check_finish(NumStates, NumConfs):
    """ Decides whether the PBS-spawned SPARTA+ jobs are done.

    First tidies the working directory: deletes any *_struct.tab files and
    moves any *_pred.tab files into 'tabs'. Then it runs `qstat -u <user>` and
    looks for the text 'lineshaper' in the listing.

    - A lineshaper job is still listed: returns False (caller keeps waiting).
    - No such job, and 'tabs' holds NumStates * NumConfs .tab files:
      returns True.
    - No such job, but tab files are missing: prints the missing filenames
      and exits the program with status 1, since the PBS run failed.
    """
    # Tidy up whatever has finished so far (no-ops when nothing matches)
    for path in glob.glob("*_struct.tab"):
        os.remove(path)
    for path in glob.glob("*_pred.tab"):
        os.rename(path, os.path.join("tabs", path))

    total_tab_files = NumStates * NumConfs

    # universal_newlines=True yields text rather than bytes on Python 3, so the
    # username formats cleanly and re.search gets a str.
    username = subprocess.check_output("whoami", shell=True,
                                       universal_newlines=True).strip()
    qstat = subprocess.check_output("qstat -u %s" % username, shell=True,
                                    universal_newlines=True)

    if re.search('lineshaper', qstat) is not None:
        print("Found at least one lineshaper job in queue, sleeping for 5 min...")
        return False

    print("Found no lineshaper jobs in queue for user: %s" % username)
    num_tabs = len(glob.glob("tabs/*.tab"))
    if num_tabs == total_tab_files:
        print("Found all tab files expected (%d)" % total_tab_files)
        return True

    print("WARNING: Found insufficient number of tab files: %d, expected %d" % (num_tabs, total_tab_files))
    print("Expect PBS job failure. Missing prediction files:")
    # Expected names follow the State<StateNo>-<ConfNo> stem used for the PDBs
    for StateNo in range(NumStates):
        for ConfNo in range(NumConfs):
            name = 'State%d-%d_pred.tab' % (StateNo, ConfNo)
            if not os.path.exists(os.path.join('tabs', name)):
                print("  tabs/%s" % name)
    print("\nExiting gracefully...")
    # Non-zero status so shells and pipelines see that the run failed
    sys.exit(1)


def runSparta(NumStates, NumConfs, ppn=1, nodes=0):
    """ Runs SPARTA+ on all NumStates * NumConfs PDB files in directory 'PDBs'.

    SPARTA+ writes an X_struct.tab and an X_pred.tab for each PDB into the
    working directory. Afterwards the *_struct.tab files are deleted and the
    *_pred.tab files are moved into 'tabs'.

    Assumes SPARTA+ is on your PATH as `sparta+`; otherwise raises Exception.

    Parallelism:
      nodes == 0 -> run locally with `ppn` concurrent SPARTA+ processes.
      nodes >= 1 -> write and qsub one PBS script per node, each running `ppn`
                    SPARTA+ processes, then poll until check_finish() is True.

    Raises Exception if nodes < 0 or ppn < 1 (checked before any work starts),
    or if the SPARTA+ binary cannot be run.
    """
    # Validate first so bad arguments fail before anything is launched.
    if nodes < 0:
        raise Exception("Error: %d not a valid input for kwarg 'nodes'. Exiting." % nodes)
    if ppn < 1:
        raise Exception("Error: %d not a valid input for kwarg 'ppn'. Exiting." % ppn)

    print("\nPredicting chemical shifts with SPARTA+...")

    # check for SPARTA+. os.system runs /bin/sh, so use POSIX redirection
    # (the bash/csh-only '>&' form is a syntax error in e.g. dash).
    ret = os.system("sparta+ > /dev/null 2>&1")
    if ret == 0:
        print("Located SPARTA+ binary")
    elif ret == 32512:   # 127 << 8: the shell could not find the command
        raise Exception("Error. Cannot find SPARTA+. Please ensure that your installation is correct.")
    else:
        raise Exception("Unknown error (%d) attempting to run SPARTA+. Exiting." % ret)

    num_procs = nodes * ppn if nodes > 0 else ppn
    print("Distributing SPARTA+ jobs over %d total procs on %d nodes..." % (num_procs, nodes))
    commands = _build_sparta_commands(NumStates, NumConfs, num_procs)

    if nodes == 0:
        _run_sparta_locally(commands)
    else:
        _run_sparta_on_pbs(commands, NumStates, NumConfs, ppn, nodes)

    print("Moving all predictions to tabs, removing structure tab files")
    _collect_tab_files()

    return


def _build_sparta_commands(NumStates, NumConfs, num_procs):
    """ Returns at most one SPARTA+ shell command per processor.

    The NumStates*NumConfs PDBs are dealt round-robin across `num_procs`.
    Processors left with no PDB get no command.
    """
    total_pdbs = NumStates * NumConfs
    commands = []
    for proc in range(num_procs):
        # flat index i -> (ConfNo = i // NumStates, StateNo = i % NumStates);
        # '//' keeps integer division on both Python 2 and 3
        names = [getFileName(i // NumStates, i % NumStates)
                 for i in range(proc, total_pdbs, num_procs)]
        if not names:
            continue
        if len(names) == 1:
            # Per the original author, a single input needs different
            # arguments: give an explicit -out so the prediction lands in the
            # working directory as <stem>_pred.tab like the multi-file case.
            stem = os.path.basename(names[0])[:-4]
            commands.append("sparta+ -in %s -out %s_pred.tab" % (names[0], stem))
        else:
            commands.append("sparta+ -in %s" % ' '.join(names))
    return commands


def _run_sparta_locally(commands):
    """ Starts every command as a concurrent shell child process with output
    discarded, then waits for all of them. """
    with open(os.devnull, 'w') as null:
        pipes = []
        for n, command in enumerate(commands):
            p = subprocess.Popen(command, bufsize=-1, shell=True, stdout=null, stderr=null)
            print("Shelling child process %d / PID : %d" % (n, p.pid))
            pipes.append(p)
        for pipe in pipes:
            pipe.wait()


def _run_sparta_on_pbs(commands, NumStates, NumConfs, ppn, nodes):
    """ Writes and submits one PBS script per node and polls every 300 s until
    check_finish() reports completion. Then it removes the PBS scripts and
    their .out/.err files. """
    for n in range(nodes):
        queue = generate_PBS_script(commands[n::nodes], n, ppn)   # creates 'lsh_pbs_N.sh'
        qsub_str = "qsub -mae -q %s %s" % (queue, 'lsh_pbs_%d.sh' % n)
        print("Submitting job # %d [ %s ]" % (n, qsub_str))
        os.system(qsub_str)

    print("\nSleeping until work finishes...")
    while not check_finish(NumStates, NumConfs):   # True once all jobs are done
        time.sleep(300)

    print("Removing PBS scripts, std err and std out files")
    for pattern in ("lsh_pbs_*.sh", "lineshaper_*.out", "lineshaper_*.err"):
        for path in glob.glob(pattern):
            os.remove(path)


def _collect_tab_files():
    """ Deletes *_struct.tab files and moves *_pred.tab files from the working
    directory into 'tabs'. Does nothing when no such files exist. """
    for path in glob.glob("*_struct.tab"):
        os.remove(path)
    for path in glob.glob("*_pred.tab"):
        os.rename(path, os.path.join("tabs", path))


def main(ass_fn, proj_fn, NumConfs, ppn, nodes):
    """
    --- LineShaper Step (1): Predict Chemical Shifts

    Calculates the predicted SPARTA+ chemical shifts for a set of states from an
    MSM, for use in lineshape prediction. This script samples a number of conformations
    from each state in an MSM, predicts the chemical shifts for protons in each, and
    writes these values to disk.

    Takes a number of arguments:
    ass_fn:   The path to an Assignments.h5 file  (str)
    proj_fn:  Path to a ProjectInfo.h5 file       (str)
    NumConfs: The number of conformations to sample from each state in the MSM  (int)
    ppn:      The number of processors to use per node                          (int)
    nodes:    The number of PBS-scheduled nodes to run on, if == 0, run locally (int)

    Writes:
    shiftArrayAved.h5:           serialized array, averaged chemical shifts for each microstate and atom
    secondaryShiftArrayAved.h5:  same as above, but filters out RC shift for each atom
    shiftArray.h5:               serialized array, all chemical shifts for each conformation sampled
    secondaryShiftArray.h5:      same as above, but filters out RC shift for each atom
    shiftTypes.pkl:              dictionary containing residue no., residue name, atom type and RC shift

    Raises Exception if any output file or the 'PDBs'/'tabs' directories
    already exist. Also raises if the number of states recovered from the
    sampled PDB disagrees with the assignments.
    """

    # pull information from the MSM
    assignments = Serializer.LoadData(ass_fn)
    P1 = Project.Project.LoadFromHDF(proj_fn)

    # Refuse to clobber any output (or scratch directory) of a previous run,
    # before hours of SPARTA+ work are spent.
    path_names = ["shiftArrayAved.h5", "secondaryShiftArrayAved.h5",
                  "shiftArray.h5", "secondaryShiftArray.h5",
                  "shiftTypes.pkl", "PDBs", "tabs"]
    for path in path_names:
        if os.path.exists(path):
            raise Exception("Error. %s already exists. Exiting." % path)
    print("Creating directories: 'tabs' and 'PDBs'")
    os.mkdir("tabs")
    os.mkdir("PDBs")

    # Sample NumConfs random conformations from each state.
    # NOTE(review): the code below expects this call to leave a concatenated
    # multi-model PDB at 'AllConfs.lh5.pdb' -- confirm against GetRandomConfs.
    GetRandomConfs.run(P1, assignments, NumConfs, 'AllConfs.lh5')

    # slice out the pdb's
    print("\nConverting Random Confs LH5 to PDBs...")
    NumStates = _split_conformations('AllConfs.lh5.pdb', NumConfs)
    for scratch in ('AllConfs.lh5', 'AllConfs.lh5.pdb'):
        if os.path.exists(scratch):
            os.remove(scratch)
    print("Successfully converted %d files" % (NumConfs * NumStates,))

    # Explicit check rather than 'assert', which is stripped under 'python -O'
    expected_states = int(np.max(assignments.flatten())) + 1
    if NumStates != expected_states:
        raise Exception("Error. Extracted %d states from the sampled PDB but the assignments contain %d. Exiting."
                        % (NumStates, expected_states))

    # Pipe saved PDBs into SPARTA+
    runSparta(NumStates, NumConfs, ppn=ppn, nodes=nodes)

    # Find out the number of shifts and fill metadata into dict
    NumShifts, shiftTypes = getNumShiftsAndMeta('./tabs/State0-0_pred.tab')

    shiftArray, secondaryShiftArray = _read_shift_tables(NumStates, NumConfs, NumShifts)

    print("\n --- OUTPUT ---")

    # Export the full arrays, shape (NumStates, NumConfs, NumShifts)
    Serializer.SaveData("shiftArray.h5", shiftArray)
    Serializer.SaveData("secondaryShiftArray.h5", secondaryShiftArray)
    print("Wrote: (secondary)shiftArray.h5,     chemical shifts for all %d conformations pulled in each microstate" % NumConfs)

    # Average over the conformation axis -> shape (NumStates, NumShifts)
    shiftArray = shiftArray.mean(axis=1)
    secondaryShiftArray = secondaryShiftArray.mean(axis=1)
    Serializer.SaveData("shiftArrayAved.h5", shiftArray)
    Serializer.SaveData("secondaryShiftArrayAved.h5", secondaryShiftArray)
    print("Wrote: (secondary)shiftArrayAved.h5, averaged chemical shifts for %d microstates and %d atoms" % (NumStates, NumShifts))

    # Export the dictionary with metaData
    with open("shiftTypes.pkl", 'wb') as shiftTypesOut:
        pickle.dump(shiftTypes, shiftTypesOut)
    print("Wrote: shiftTypes.pkl,    dictionary containting residue no., residue name, and atom type")

    return


def _split_conformations(pdb_fn, NumConfs):
    """ Splits the concatenated multi-model PDB `pdb_fn` into one file per
    conformation, named via getFileName(ConfNo, StateNo).

    A new output file is started at every ENDMDL record, and ENDMDL lines
    themselves are not copied. Models are assumed to be grouped by state,
    NumConfs consecutive models per state. The file opened after the last
    ENDMDL holds only trailing records (if any) and is deleted.

    Returns the number of complete states written.
    """
    StateNo, ConfNo = 0, 0
    with open(pdb_fn, 'r') as f:
        pdbOut = open(getFileName(ConfNo, StateNo), 'w')
        try:
            for line in f:
                if not line.startswith('ENDMDL'):
                    pdbOut.write(line)
                    continue
                ConfNo += 1
                if ConfNo == NumConfs:   # last conformation of this state done
                    ConfNo = 0
                    StateNo += 1
                pdbOut.close()
                pdbOut = open(getFileName(ConfNo, StateNo), 'w')
        finally:
            pdbOut.close()
    os.remove(getFileName(ConfNo, StateNo))   # drop the trailing (empty) file
    return StateNo


def _read_shift_tables(NumStates, NumConfs, NumShifts):
    """ Fills (NumStates, NumConfs, NumShifts) arrays of shifts and secondary
    shifts from every ./tabs/*_pred.tab file.

    Each file's (state, conf) slot is taken from its name, and its n-th data
    row (line starting with a space) goes into shift index n. Warns if the
    number of files differs from NumStates * NumConfs; missing slots stay zero.
    """
    shiftArray = np.zeros((int(NumStates), int(NumConfs), int(NumShifts)), dtype=np.float64)
    secondaryShiftArray = np.zeros_like(shiftArray)

    tab_files = glob.glob('./tabs/*_pred.tab')
    for tab_fn in tab_files:
        StateNo, ConfNo = getStateAndConf(tab_fn)
        with open(tab_fn, 'r') as f:
            data_rows = (line for line in f if line.startswith(' '))
            for shiftNo, line in enumerate(data_rows):
                shiftArray[StateNo, ConfNo, shiftNo], secondaryShiftArray[StateNo, ConfNo, shiftNo] = getShift(line)

    expected = int(NumStates) * int(NumConfs)
    if len(tab_files) != expected:
        print("WARNING: parsed %d SPARTA+ prediction files, expected %d; missing entries are left as zero."
              % (len(tab_files), expected))
    return shiftArray, secondaryShiftArray


if __name__ == '__main__':
    # Banner text is printed unchanged; print(...) with a single argument
    # behaves the same on Python 2 and 3.
    print("""

                                  ~~~ LineShaper ~~~
                Markov State Model Approach to Protein NMR Lineshape Analysis
                         by Stephen Fried <sdfried@stanford.edu>
             

    --- LineShaper Step (1): Predict Chemical Shifts

    Calculates the predicted SPARTA+ chemical shifts for a set of states from an
    MSM, for use in lineshape prediction. This script samples a number of conformations
    from each state in an MSM, predicts the chemical shifts for protons in each, and
    writes these values to disk. 

    Writes:
    shiftArrayAved.h5: 			 serialized array, averaged chemical shifts for each microstate and atom
    secondaryShiftArrayAved.h5:  same as above, but filters out RC shift for each atom
    shiftArray.h5:     			 serialized array, all chemical shifts for each conformation sampled
    secondaryShiftArray.h5:      same as above, but filters out RC shift for each atom
    shiftTypes.pkl:    			 dictionary containting residue no., residue name, and atom type\n""")

    parser = ArgumentParser()
    # Help text now advertises the real default (.h5, not .lh5)
    parser.add_argument('-a', dest='assignments_file', type=str, help='Assignments file from MSM. Default: Data/Assignments.Fixed.h5',
                        default='Data/Assignments.Fixed.h5')
    parser.add_argument('-p', dest='project_file', type=str, help='ProjectInfo file from MSM. Default: ProjectInfo.h5',
                        default='ProjectInfo.h5')
    parser.add_argument('-n', dest='num_confs', type=int, help='Number of conformations to sample from each state. Default: 5.',
                        default=5)
    parser.add_argument('-P', dest='processors', type=int, help='Number of processors to run on each node. Default: 1.',
                        default=1)
    parser.add_argument('-D', dest='nodes', type=int, help='''Number of nodes to run on - i.e. PBS jobs to schedule. Assumes that you are
running on a PBS enabled cluster. Recommend ensuring PBS commands are appropriate for your system before use.
Pass 0 to turn off this feature and run locally. Default: 0 (off).''', default=0)

    args = parser.parse_args()

    main(args.assignments_file, args.project_file, args.num_confs, args.processors, args.nodes)
