# This file is a part of the Lineshaper package
# Stephen Fried, January 2012

# BE SURE TO IMPORT KRYLOV FIRST!!
try:
    import Krylov
    krylov_enabled = True
except:
    print "Warning: Could not load 'Krylov' library"
    krylov_enabled = False

import numpy as np

from msmbuilder import Serializer
from scipy import linalg
from scipy import fftpack
from scipy import io
from scipy import sparse

import sys
import os
import pickle

from argparse import ArgumentParser

try:
    from deap import dtm
    deap_enabled = True
except:
    deap_enabled = False


def getTStep(atomType):
    """Determines an appropriate time-step given the atom type (either H, C, or N).

    Arguments:
    atomType - atom name: HA, HN, HA2 or HA3 (protons), C, CA or CB (carbons),
               or N (nitrogen)

    Returns:
    TS       - the sampling time-step for that nucleus (used as seconds by
               simulate_spec)

    Raises:
    Exception - if the atom type is not one of the names listed above
    """

    if atomType in ('HA', 'HN', 'HA2', 'HA3'):
        TS = .0001
    elif atomType in ('C', 'CA', 'CB'):
        TS = .0004
    elif atomType == 'N':
        TS = .001
    else:
        # Fail loudly (and like getScale does) instead of falling through to
        # an UnboundLocalError on the return statement.
        raise Exception("Could not understand atom type: %s" % atomType)
    return TS

    
def getScale(atomType,B0):
    """Returns the reference Larmor frequency for a nucleus.

    Arguments:
    atomType - atom name: HA, HN, HA2 or HA3 (protons), C, CA or CB (carbons),
               or N (nitrogen)
    B0       - static magnetic field, expressed as the proton Larmor frequency

    Returns:
    B0 itself for protons; B0 rescaled by 10.705/42.576 for carbons and by
    4.316/42.576 for nitrogen.

    Raises:
    Exception - if the atom type is not recognised
    """

    protons = ('HA', 'HN', 'HA2', 'HA3')
    carbons = ('C', 'CA', 'CB')

    if atomType in protons:
        return B0
    if atomType in carbons:
        return B0*10.705/42.576
    if atomType == 'N':
        return B0*4.316/42.576

    raise Exception("Could not understand atom type: %s" % atomType)


def getR2Not(atomType):
    """Determines R_2^0, which is the intrinsic transverse relaxation rate in the absence
    of chemical exchange.  In general, this is an approximate catch-all to incorporate
    all relaxation effects other than chemical exchange, like CSA, DD, etc.  To a first
    approximation, these effects are empirically related to the protein's tau_c
    (rotational correlation time).  The values below reflect average parameters for a
    very small protein, like WW (35 a.a.s), where tau_c ~ 2 ns.

    Arguments:
    atomType - atom name: HA, HA2, HA3, HN, N, CA, CB or C

    Returns:
    R2Not    - R_2^0 in units of s^-1

    Raises:
    Exception - if the atom type is not one of the names listed above
    """

    # These are R_2^0 in units of s^-1.  Delta nu_FWHM = R_2^0 / pi
    if atomType in ('HA', 'HA2', 'HA3'):
        R2Not = 22.0
    elif atomType == 'HN':
        R2Not = 12.0
    elif atomType == 'N':
        R2Not = 6.0
    elif atomType in ('CA', 'CB'):
        R2Not = 12.0
    elif atomType == 'C':
        R2Not = 11.0
    else:
        # Fail loudly (and like getScale does) instead of falling through to
        # an UnboundLocalError on the return statement.
        raise Exception("Could not understand atom type: %s" % atomType)
    return R2Not


def Bloch_McConnell(X, points, timeArray, P, method="krylov", k=0):
    """ Solves the Bloch-McConnell equation, essentially generating an artificial
    FID. Does this by taking matrix exponentials over time.

    Calculation:
    Signal(t) = sum( Expm[ X * t ] * Populations ), for every t in timeArray

    Since this process can be quite expensive, several methods are available:

    "brute"              - a dense matrix exponential (linalg.expm) at every
                           time point.
    "eigendecomposition" - decompose X = U diag(l) U^-1 once, with all
                           eigenvectors, and take the (trivial) exponential
                           of the diagonal matrix at every time point.
    "low-rank"           - as above, but with only the 'k' smallest-magnitude
                           eigenpairs of a sparse X; the populations are
                           projected onto those eigenvectors in the
                           least-squares sense, so the result is approximate.
    "krylov"             - delegate to Krylov.krylov_expm (requires the
                           optional 'Krylov' library).

    Arguments:
    X         - the matrix  X = K + i*Omega - R2not
    points    - the number of time points to calculate
    timeArray - the time values; the first 'points' entries are used
    P         - the population vector
    method    - one of: brute, eigendecomposition, low-rank, krylov
    k         - number of eigenpairs, only used by the low-rank method

    Returns:
    signalArray - the FID signal at all the time-points

    Raises:
    Exception - if the method is not understood, if 'low-rank' is given a
                dense matrix or a non-positive k, or if 'krylov' is requested
                while the Krylov library could not be loaded
    """

    signalArray = np.zeros((points),dtype='complex')

    # brute force method
    if method == "brute":
        if sparse.issparse(X): # if sparse cast to dense
            X = X.toarray()

        for j in range(points):
            signalArray[j] = np.sum(np.dot(linalg.expm(X*timeArray[j]), P))

    # eigendecomposition, all eigenvectors
    elif method == "eigendecomposition":

        if sparse.issparse(X): # if sparse cast to dense
            X = X.toarray()

        # linalg.eig returns the right eigenvectors as the columns of U, so
        # X = U diag(l) U^-1 and Expm[X t] = U diag(exp(l t)) U^-1.
        l, U = linalg.eig(X)
        Uinv = linalg.inv(U)

        # sum( U diag(e) Uinv P ) = sum_k colsum(U)_k * e_k * (Uinv P)_k, so the
        # time-independent factors are folded into one weight per eigenvalue.
        weights = U.sum(axis=0) * np.dot(Uinv, P)

        for j in range(points):
            signalArray[j] = np.dot(np.exp( l * timeArray[j] ), weights)

    # low-rank approximation
    elif method == "low-rank":

        if not sparse.issparse(X):
            raise Exception("Error. You probably shouldn't be using low-rank approximations on dense mats!")
        if k <= 0:
            raise Exception("Error. The low-rank method needs a positive number of eigenvectors 'k'.")

        # eigs lives in scipy.sparse.linalg, not in scipy.sparse itself
        from scipy.sparse import linalg as sparse_linalg
        l, U = sparse_linalg.eigs(X, k=k, which='SM')

        # U is (NumStates x k) and cannot be inverted; express P in the basis
        # of the k retained eigenvectors by least squares instead.
        coefficients = linalg.lstsq(U, P)[0]
        weights = U.sum(axis=0) * coefficients

        for j in range(points):
            signalArray[j] = np.dot(np.exp( l * timeArray[j] ), weights)

    elif method == "krylov":

        if not sparse.issparse(X):
            print("Warning: passed X is not sparse in Bloch_McConnell()")

        # explicit check: an assert would vanish under 'python -O'
        if not krylov_enabled:
            raise Exception("Error. The 'Krylov' library is required for the krylov method!")
        signalArray = np.sum(Krylov.krylov_expm(X, P, list(timeArray), debug=0), axis=1)

    else:
        print("Method not understood: %s" % method)
        print("Must be one of: brute, eigendecomposition, low-rank, krylov")
        raise Exception("Cannot understand method")

    return signalArray


def apodizer(signalArray):
    """Window-function hook for the time-domain signal (pre-Fourier transform).

    No apodization is implemented at the moment, so the time-domain signal
    that was passed in is handed back untouched.
    """

    # identity for now: a window function would be applied here
    unwindowed = signalArray
    return unwindowed

  
def phaser(signalF):
    """Phases the spectrum by a brute-force scan over a constant phase.

    Every phase from -pi up to (but excluding) pi, in steps of 0.01, is applied
    to the spectrum; the phase that gives the tallest point in the real part
    wins. Returns the real part of the spectrum multiplied by that phase.
    """

    trialPhases = np.arange(-np.pi, np.pi, .01)

    # tallest real-part point reached with each trial phase
    peakHeights = np.array(
        [(signalF*np.exp(1j*phi)).real.max() for phi in trialPhases],
        dtype='float64')

    winner = trialPhases[peakHeights.argmax()]
    return (signalF*np.exp(1j*winner)).real


def simulate_spec_wrapper( wrapper_args):
    """ Adapter for the DTM map: unpacks one 7-item job tuple
    (i, secondaryShiftArrayAved, shiftTypes, K, P, B0, points) and forwards
    it to simulate_spec, returning whatever simulate_spec returns. """
    atomIndex, shiftArray, typeTable, rateMatrix, populations, field, numPoints = wrapper_args
    return simulate_spec(atomIndex, shiftArray, typeTable, rateMatrix,
                         populations, field, points=numPoints)


def simulate_spec(i, secondaryShiftArrayAved, shiftTypes, K, P, B0, points=4000):
    """ Simulates the FID signal for a single atomic nucleus and Fourier
    transforms it into a spectrum.

    Arguments:
    i                       - atom index
    secondaryShiftArrayAved - array of secondary shifts; column i holds the
                              shifts of atom i in all microstates
    shiftTypes              - per-atom metadata; entry [2] is the atom type and
                              entry [3] the random coil shift (entries [0] and
                              [1] only appear in the progress message)
    K                       - MSM rate matrix
    P                       - MSM equilibrium populations
    B0                      - magnetic field, given as the proton Larmor
                              frequency (MHz?)
    points                  - number of time domain points

    Returns:
    freqArray               - frequency axis converted to ppm, centred on zero
    freqArray + RCshift     - the same axis offset by the random coil shift
    signalFp                - magnitude of the Fourier transformed FID
    """

    atomShifts = secondaryShiftArrayAved[:,i]   #Gets a vector of shifts in all microstates
    atomType   = shiftTypes[i][2]               #Finds out whether the atom is HA, HN, N, C, CA, CB
    RCshift    = float(shiftTypes[i][3])        #Gets that atom's random coil shift

    # Initialize variables for the signal processing
    tstep = getTStep(atomType)                  # atom specific time step
    Ws    = 2*np.pi/tstep                       # sampling frequency in rad/s
    scale = getScale(atomType, B0)              # reference Larmor frequency

    # Time points tstep, 2*tstep, ..., points*tstep.  Built from an integer
    # range because np.arange with a float step can come out one sample longer
    # or shorter than intended, depending on rounding.
    timeArray = tstep * np.arange(1, points+1, dtype='float64')

    # Construct Omega -- the offset frequency matrix in units of s^-1.
    Omega = sparse.dia_matrix( ( (atomShifts * scale), 0), shape=K.shape, dtype='complex')

    #Construct R2not -- the transverse relaxation rate in the absence of exchange in s^-1
    R2not = sparse.dia_matrix( (np.ones(K.shape[0]) * getR2Not(atomType), 0), shape=K.shape, dtype='complex' )

    #let the matrix X = K + i*Omega - R2not
    X = K + 2.0*np.pi*Omega*1j - R2not
    assert sparse.isspmatrix(X)

    signalArray = Bloch_McConnell(X, points, timeArray, P, method='krylov')
    signalArray = apodizer(signalArray)
    signalF     = (fftpack.fftshift(fftpack.fft(signalArray)))*tstep # take the fourier transform and shift
    signalFp    = abs(signalF)                                    # magnitude spectrum (no phasing is applied)

    # build the frequency axis from the Nyquist frequency Ws/2; then convert to ppm
    # (floor division keeps the bounds integral whichever division rule is active)
    freqArray = Ws*np.arange(-points//2,points//2,dtype='float64')/points
    freqArray = freqArray/(2*np.pi*scale)

    print("Calculated lineshape for %s in %s%s" % (shiftTypes[i][2], shiftTypes[i][1], shiftTypes[i][0]))
    return freqArray, freqArray + RCshift, signalFp


def reduce_results(spectra, map_results):
    """ Reduces FID simulations distributed via a DTM map.

    Arguments:
    spectra     - pre-allocated array indexed [atom, row, point]; filled in place
    map_results - iterable of per-atom results, in atom order; items [0], [1]
                  and [2] of each result are copied into rows 0, 1 and 2

    Returns:
    spectra     - the same array object that was passed in
    """

    # single-argument call form prints identically under Python 2 and 3
    print("\nReducing results...\n")

    for i,result in enumerate(map_results):
        for row in range(3):
            spectra[i,row,:] = result[row]

    return spectra

    
def calculate_lineshapes( secondaryShiftArrayAved_file, shiftTypes_file,
    B0, ratematrix_fn, pops_fn, nodes=0 ):
    """
    --- LineShaper Step (2): Calculates lineshapes for the nuclear resonances

    Solves the Bloch-McConnell equations for free precession of transverse
    magnetization in the rotating frame.

    Prepares a synthetic FID (free induction decay) for each nucleus. Each
    nuclear lineshape is calculated independently, which assumes scalar coupling is
    ignored.

    This script requires the shiftArrayAved.h5 and shiftTypes.pkl files produced
    from predict_chemical_shifts.py. It also requires a rate matrix and a population
    vector from the MSM (normally Populations.dat, found in the Data folder).

    Arguments:
    secondaryShiftArrayAved_file - HDF file whose 'Data' entry holds the shifts,
                                   indexed [state, atom]
    shiftTypes_file              - pickle file with the per-atom metadata
    B0                           - magnetic field (proton Larmor frequency)
    ratematrix_fn                - rate matrix, in a format io.mmread can load
    pops_fn                      - populations, in a format np.loadtxt can load
    nodes                        - 0 to run locally; anything else distributes
                                   the atoms with the DEAP dtm map

    Writes:
    spectra.h5:     matrix containing spectra for all atoms, indexed
                    [atom, (frequency axis, shifted frequency axis, intensity), point].
                    The process exits with status 1 if the file already exists.
    """

    if os.path.exists("spectra.h5"):
        print("Error! File: spectra.h5 already exists!")
        sys.exit(1)

    #Load up the secondaryShiftArrayAved into memory.
    secondaryShiftArrayAved = Serializer.Serializer.LoadFromHDF(secondaryShiftArrayAved_file)['Data']
    NumShifts = secondaryShiftArrayAved.shape[1]

    #Load up the shiftTypes metadata into memory
    # NOTE: unpickling executes arbitrary code; only load files you produced yourself.
    with open(shiftTypes_file, 'rb') as pkl:
        shiftTypes = pickle.load(pkl)

    #Load rate matrix and populations
    K = io.mmread(ratematrix_fn)
    P = np.loadtxt(pops_fn)

    #Make arrays that will store spectra
    points = 4000
    spectra = np.zeros((NumShifts,3,points), dtype='float64')

    # if running locally
    if nodes == 0:
        for i in range(NumShifts): # For each atom
            spectra[i,0,:], spectra[i,1,:], spectra[i,2,:] = \
                simulate_spec(i, secondaryShiftArrayAved, shiftTypes, K, P,
                             B0, points = points)

    # Running on MPI
    else:

        # Without DEAP the name 'dtm' was never imported; say so explicitly.
        if not deap_enabled:
            raise Exception("Error. You need the DEAP & mpi4py modules to run in parallel mode!")

        # Generate list of jobs, via their arguments (packaged in tuples)
        wrapper_args = [ (i, secondaryShiftArrayAved, shiftTypes, K, P, B0, points)
                         for i in range(NumShifts) ]

        map_results = dtm.map( simulate_spec_wrapper, wrapper_args )
        spectra = reduce_results(spectra, map_results)

    #Rescale the intensities to be between [0,1]
    # SDF 20120301: peak heights might be important, so I'm putting this on hold.
    #for i in range(NumShifts):
    #    maxIntensity = spectra[i,2,:].max()
    #    spectra[i,2,:] = spectra[i,2,:]/maxIntensity

    #Write the data to files
    Serializer.SaveData("spectra.h5", spectra)
    print("Saved all data to spectra.h5 \n")
    print("Lineshaper terminated succesfully.")

    return


def parse(print_title=True):
    """ Parses the command line (sys.argv) for LineShaper step (2).

    Arguments:
    print_title - when true, print the title banner before parsing

    Returns:
    A 6-tuple: (secondary shift array file, shift types file, magnetic field,
    rate matrix file, populations file, number of nodes).
    """

    # single-argument call form prints identically under Python 2 and 3
    if print_title: print("""

                               ~~~ LineShaper ~~~
          Markov State Model Approach to Protein NMR Lineshape Analysis
                    by Stephen Fried <sdfried@stanford.edu>
             
    --- LineShaper Step (2): Calculates lineshapes for the nuclear resonances

    Solves the Bloch-McConnell equations for free precession of transverse
    magnetization in the rotating frame.

    Prepares a synthetic FID (frequency induction decay) for each nucleus. Each
    nuclear lineshape is calculated independently, which assumes scalar coupling is
    ignored.

    This script requires the shiftArrayAved.h5 and shiftTypes.pkl files produced
    from predict_chemical_shifts.py. It also requires a rate matrix and a population
    vector from the MSM (normally Populations.dat, found in the Data folder).   

    Writes:
    spectra.h5:		matrix containing spectra for all atoms. \n""")


    parser = ArgumentParser()
    parser.add_argument('magnetic_field', type=int,
                        help='Static magnetic field given as proton Larmor frequency')
    parser.add_argument('-s', dest='secondary_shift_array_aved', type=str,
                        help='''Averaged secondary chemical shifts for each
                        microstate and atom. Default: secondaryShiftArrayAved.h5''',
                        default='secondaryShiftArrayAved.h5')
    parser.add_argument('-t', dest='shift_types', type=str,
                        help='Dictionary containting residue no., residue name, and atom type',
                        default='shiftTypes.pkl')
    parser.add_argument('-K', dest='rate_matrix', type=str, help='Rate Matrix for the MSM',
                        default='Data/K.mtx')
    parser.add_argument('-P', dest='populations', type=str, help='Static populations of the MSM',
                        default='Data/Populations.dat')
    parser.add_argument('-D', dest='nodes', type=int,
                        help='''Number of nodes to parallelize with. Submit this script as an
                        MPI job. Pass 0 to disable this and run locally. Default: 0 (local).
                        Note: You must have the DEAP package installed to use this feature.''',
                        default=0)

    args = parser.parse_args()

    
    return( args.secondary_shift_array_aved, args.shift_types,
            args.magnetic_field, args.rate_matrix, args.populations, args.nodes )

def main():
    """ Entry point for external use: parses the command line (without the
    title banner) and hands the parsed values to calculate_lineshapes. """

    parsed = parse(print_title=False)
    shiftArrayFile, shiftTypesFile, field, rateMatrixFile, populationsFile, numNodes = parsed

    calculate_lineshapes(shiftArrayFile, shiftTypesFile, field,
                         rateMatrixFile, populationsFile, nodes=numNodes)

if __name__ == '__main__':

    # if help is requested (or no arguments were given at all), don't boot MPI
    if len(sys.argv) < 2 or sys.argv[1] in ('-h', '--help'):
        parse()
        sys.exit(0)

    # Only the node count is needed here; main() parses the command line
    # again for itself.
    nodes = parse()[-1]

    if nodes > 0:
        print("\nMPI job requested...")
        if deap_enabled:
            print("DEAP enabled. Booting the MPI interface...\n")
            dtm.start(main) # start with MPI
        else:
            print("Error. You need the DEAP & mpi4py modules to run in parallel mode!")
            sys.exit(1)

    else:
        main() # start w/o MPI


