# This file is a part of the Lineshaper package
# Stephen Fried, January 2012

"""analysisLib contains functions for performing analyses on data from Lineshaper

Notes:
1.  The inputs come from both predict_chemical_shifts and calculate_lineshapes
2.  Assumes the default file output names from those scripts
3.  Functions below are generally written in a logical order
"""

import numpy as np

from msmbuilder import Serializer

from scipy import linalg
from scipy import io


import matplotlib.pyplot as plt

import sys
import os
import pickle

##---------------------------------------------------------
##Some housekeeping functions that the other functions call
##---------------------------------------------------------

def SettingUpPaths(path_names):
    """Prepare the shared ./Analysis output folder and refuse to overwrite.

    The folder is created (relative to the current working directory) the
    first time any analysis function runs; every analysis output goes there.

    Parameters
    ----------
    path_names : iterable of str
        File names, relative to ./Analysis, that the calling function writes.

    Raises
    ------
    Exception
        If any entry of path_names already exists inside ./Analysis; the
        message names the first offending file.
    """
    if not os.path.exists('Analysis'):
        os.mkdir('Analysis')

    # Collect every clash, then report the first one (same order as given).
    clashes = [name for name in path_names
               if os.path.exists('./Analysis/' + name)]
    if clashes:
        raise Exception("Error. %s already exists. Exiting." % clashes[0])
            
def Voigt(x, y0, A, fracLorentzian, fwhm, peak):
    """Evaluate a pseudo-Voigt lineshape sitting on a constant baseline.

    The profile is a linear mix of an area-normalised Lorentzian and an
    area-normalised Gaussian that share the same full-width at half-max.

    Parameters
    ----------
    x : float or ndarray
        Position(s) at which to evaluate the lineshape.
    y0 : float
        Constant baseline offset.
    A : float
        Amplitude (area) scaling of the mixed profile.
    fracLorentzian : float
        Weight of the Lorentzian component; the Gaussian gets 1 - fracLorentzian.
    fwhm : float
        Full-width at half-max shared by both components.
    peak : float
        Centre of the line.
    """
    delta = x - peak
    lorentzTerm = fracLorentzian*(2/np.pi)*(fwhm/(4*delta**2 + fwhm**2))
    gaussTerm = (1-fracLorentzian)*(np.sqrt(4*np.log(2))/(np.sqrt(np.pi)*fwhm)) * np.exp(-(4*np.log(2)/fwhm**2)*delta**2)
    return y0 + A*(lorentzTerm + gaussTerm)

def getNatLW(atomType):
    """Return an approximate natural linewidth (FWHM, in ppm) for an atom type.

    This is a catch-all for every relaxation effect other than chemical
    exchange (CSA, dipole-dipole, ...). To a first approximation those effects
    track the protein's rotational correlation time tau_c; the values below
    are averages for a very small protein such as WW (35 residues,
    tau_c ~ 2 ns).

    Parameters
    ----------
    atomType : str
        One of 'HA', 'HA2', 'HA3', 'HN', 'N', 'CA', 'CB' or 'C'.

    Returns
    -------
    float
        Natural FWHM in ppm.

    Raises
    ------
    ValueError
        If atomType is not one of the recognised names.
    """
    # FWHMs in ppm at a 600 MHz magnet; Delta nu_FWHM (Hz) = R_2^0 / pi.
    # Lines are technically narrower at higher B0, but this is good enough
    # to get the fitting started.
    naturalLinewidths = {
        'HA': 0.0118, 'HA2': 0.0118, 'HA3': 0.0118,
        'HN': 0.0064,
        'N': 0.0318,
        'CA': 0.0254, 'CB': 0.0254,
        'C': 0.0233,
    }
    try:
        return naturalLinewidths[atomType]
    except KeyError:
        # Previously fell through and raised UnboundLocalError on `LW`.
        raise ValueError("No natural linewidth defined for atom type %r" % (atomType,))
    
def getWeight(shiftType, weights=None):
    """Return the chemical-shift weighting factor of Schumann FH, et al. (2007).

    Parameters
    ----------
    shiftType : sequence
        Shift descriptor; element [1] is the residue type and element [2]
        the atom name. The residue-type spelling must match the keys of the
        weight table.
    weights : mapping, optional
        Weighting factors keyed by residue type + nucleus class, where the
        class is 'H', 'N', 'CA' or 'C' (e.g. 'ALAH'). Defaults to a
        module-level ``weightDict``.

    Returns
    -------
    The weighting factor stored in the table.

    Raises
    ------
    NameError
        If no table is passed and no module-level ``weightDict`` exists.
    ValueError
        If the atom name is not a proton, N, CA, CB or C.
    KeyError
        If the table has no entry for the computed key.
    """
    if weights is None:
        # NOTE(review): weightDict is not defined anywhere in this module as
        # shown; it must be supplied at module level before relying on it.
        weights = globals().get('weightDict')
        if weights is None:
            raise NameError("weightDict is not defined; pass a weights table to getWeight")

    resType = shiftType[1]
    atomType = shiftType[2]

    if atomType.startswith('H'):
        nucleus = 'H'
    elif atomType == 'N':
        nucleus = 'N'
    elif atomType in ('CA', 'CB'):
        # CB shifts share the CA weighting factor.
        nucleus = 'CA'
    elif atomType == 'C':
        nucleus = 'C'
    else:
        raise ValueError("No weighting class for atom type %r" % (atomType,))

    return weights[resType + nucleus]
    
def GetCominedDeltaShift(shiftSet1, shiftSet2, shiftTypes):
    """Return the combined chemical-shift difference between two shift sets.

    shiftSet1 and shiftSet2 are complete, identically ordered sets of chemical
    shifts for the same protein. They differ because they represent two
    states of the protein, two ways of calculating shifts, or similar.
    shiftTypes[k] describes shift k and is passed to getWeight.

    The result is Delta delta_12^combined, an overall distance between the two
    sets in multi-dimensional chemical-shift space. It uses the weighted
    Euclidean (RMS) metric of Schumann FH, et al. J. Biomol. NMR 39, 275-289
    (2007):

        sqrt( (1/N) * sum_k (w_k * (shiftSet1[k] - shiftSet2[k]))**2 )

    Raises
    ------
    ValueError
        If the three sequences differ in length, or are empty.
    """
    numShifts = len(shiftSet1)
    # Without this check a shorter shiftSet2/shiftTypes raised IndexError and
    # a longer one was silently truncated.
    if len(shiftSet2) != numShifts or len(shiftTypes) != numShifts:
        raise ValueError("Shift sets and shift types must have equal lengths "
                         "(got %d, %d, %d)" % (numShifts, len(shiftSet2), len(shiftTypes)))
    if numShifts == 0:
        raise ValueError("Cannot compute a combined shift difference from zero shifts")

    runningTotal = 0.0
    for shift1, shift2, shiftType in zip(shiftSet1, shiftSet2, shiftTypes):
        runningTotal += ((shift1 - shift2) * getWeight(shiftType))**2

    return np.sqrt(runningTotal / numShifts)
    
##----------------------------------------------------
##These functions run basic controls over the data-set
##----------------------------------------------------

def GetChemicalShiftSpreads():
    """Finds the chemical shift stdDev for each nucleus for each microstate.
    And finds the chemical shift stdDev between microstates for each nucleus.
    
    Writes:
    microstateSpread.txt	A matrix showing CS spread for each nucleus for each microstate
    atomSpread.txt			A vector showing CS spread b/w microstates for each nucleus
    ChemicalShiftSpreads.pdf	An image visualizing this information
    """
    
    #Names of output files that are created. 
    path_names = ['microstateSpread.txt', 'atomSpread.txt', 'ChemicalShiftSpreads.pdf']
    #Check if they already exist.
    SettingUpPaths(path_names)
    
    #Load up the shift arrays, assuming the conventional filename
    SA = Serializer.Serializer.LoadFromHDF('shiftArray.h5')['Data']
    SAA = Serializer.Serializer.LoadFromHDF('shiftArrayAved.h5')['Data'] 

    microstateSpread = SA.std(axis=1)  #a 2nd rank matrix, NState by NShifts
    atomSpread = np.array([SAA.std(axis=0) , SAA.std(axis=0)]) #a vector, NShifts
    
    #Saves results
    print "Saving results..."
    np.savetxt('./Analysis/microstateSpread.txt', microstateSpread, delimiter=' ', newline='\n')
    np.savetxt('./Analysis/atomSpread.txt', atomSpread, delimiter=' ', newline='\n')
    
    #Visualize results
    print "Plotting results..."
    fig, ax = plt.subplots( nrows = 2)
    plt.subplots_adjust(bottom=0.3, top=0.7, wspace=0.3)
    
    ax[0].matshow(microstateSpread)
    #ax[0].colorbar()
    ax[0].set_xlabel('atom index')		#TO TJL: STILL WANT TO FIGURE OUT HOW TO LABEL ATOM-TYPES ACROSS THIS AXIS
    ax[0].set_ylabel('microstate index')
    
    ax[1].matshow(atomSpread)
    #ax[1].colorbar()
    ax[1].set_xlabel('atom index')
    
    plt.savefig('./Analysis/ChemicalShiftSpreads.pdf')
    print "Saved ChemicalShiftSpreads.pdf"
    
    return 0


def CompareRMSDs():
    """Finds the overall chemical shift RMSD betwen microstates.
    Also finds the atomic coordinate RMSD between generators.
    These measures of chemical shift similarity and geometric similarity can also
    be compared to the rate matrix, to see if they correlate to kinetic similarity.
    
    Writes:
    chemshiftRMSD.txt	A matrix showing CS RMSD between microstates
    geometricRMSD.txt	A matrix showing atomic coordinate RMSD between microstates 
    """
    
    #Names of output files that are created
    path_names = ['chemshiftRMSD.h5', 'geometricRMSD.h5', 'ShiftVsGeometricRMSD.pdf']
    #Check if they already exist.
    SettingUpPaths(path_names)
    
    #Load up the shift array, assuming the conventional filename
    SAA = Serializer.Serializer.LoadFromHDF('shiftArrayAved.h5')['Data'] 
    
    #Lod up the shift types
    pkl = open('shiftTypes.pkl', 'rb')
    ST = pickle.load(pkl) 
    pkl.close()
    
    #Create a microstate-microstate chemical shift RMSD matrix
    NumStates = SAA.shape[0]
    NumShifts = SAA.shape[1]
    chemshiftRMSD = np.zeros((NumStates,NumStates))
    for i in range(NumStates):
        for j in range(NumStates):
            if j > i:
                chemshiftRMSD[i,j] = GetCominedDeltaShift(SAA[i,:], SAA[j,:], ST) 
            
    #Create a microstate-microstate atomic coordinate RMSD matrix
    geometricRMSD = np.zeros((NumStates,NumStates))
    #TALK TO ROBERT AND TJL ABOUT THIS
    
    #Saves results
    print "Saving results..."
    Serializer.SaveData('./Analysis/chemshiftRMSD.h5', chemshiftRMSD)
    Serializer.SaveData('./Analysis/geometricRMSD.h5', geometricRMSD)

    
    #Visualize results
    print "Plotting results..."
    
    plt.scatter(chemshiftRMSD.flatten(), geometricRMSD.flatten())
    #STILL NEED TO BEAUTIFY THIS FIGURE UP
    plt.savefig('./Analysis/ShiftVsGeometricRMSD.pdf')
    print "Saved ChemicalShiftSpreads.pdf"
    
    return
    
    
##-------------------------------------------------------------------------------
##These functions look at lineshapes and chemical shifts of the simulated spectra
##-------------------------------------------------------------------------------

def GetSpectralParams():
    """For each atom, determines the chemical shift (peak frequency), full-width at half-max, 
    and the fraction Lorentzian by fitting to a Lorentzian.
    
    Writes
    lineshapeParams.txt		Matrix giving the fit parameters for each atom
    lineshapeParams.h5
    """
    
    #Names of output files that are created
    path_names = ['lineshapeParams.txt', 'lineshapeParams.h5']
    #Check if they already exist.
    SettingUpPaths(path_names)
    
    #Load up the spectra and the shiftTypes dictionary
    spectra = Serializer.Serializer.LoadFromHDF('spectra.h5')['Data'] 
    pkl = open('shiftTypes.pkl', 'rb')
    ST = pickle.load(pkl) 
    pkl.close()
    
    NumAtoms = spectra.shape[0]
    
    #Create matrix to be filled with fitting parameters
    lineshapeParams = np.zeros((NumAtoms,4))
    
    for i in range(NumAtoms):
        #An atom index which matches the metadata goes in the first column
        lineshapeParams[i,0] = int(i)
        
        #Find the peak position
        peakPoint = spectra[i,2,:].argmax()
        height = spectra[i,2,:].max()
        peakPosition = spectra[i,1,peakPoint]
        
        #Find the FWHM
        halfHeight = height/2
        leftInterval = range( (peakPoint - 100), peakPoint)
        closest = leftInterval[0]
        for j in leftInterval:
            closestDiff = abs( halfHeight - closest )
            currentDiff = abs( halfHeight - spectra[i,2,j] )
            if currentDiff < closestDiff:
                closestIndexL = j
        
        rightInterval = range( peakPoint, (peakPoint + 100) )
        closest = rightInterval[0]
        for j in rightInterval:
            closestDiff = abs( halfHeight - closest )
            currentDiff = abs( halfHeight - spectra[i,2,j] )
            if currentDiff < closestDiff:
                closestIndexR = j
        
        fwhm = spectra[i,1,closestIndexR] - spectra[i,1,closestIndexL]
             
        print "Got the parameters for ",ST[i][2],"in",ST[i][1],ST[i][0]
        
        #The fit parameters are stored up:
        lineshapeParams[i,1:] = [height, fwhm, peakPosition]
    
    #Saves results
    print "Saving results..."
    np.savetxt('./Analysis/lineshapeParams.txt', lineshapeParams, delimiter=' ', newline='\n')
    Serializer.SaveData('./Analysis/lineshapeParams.h5', lineshapeParams)
    print "Completed."
    
    return
    
    
def CompareShifts():
    """Compare the lineshape shifts (from above) to static chemical shifts to microstate 
    population-averaged (but not dynamically-averaged) chemical shifts.  Compare to experiment.
    Visualize deviations from experiment for the three calculation schemes.
    """
    #Load up the chemical shifts from the lineshapes
    lineshapeShifts = Serializer.Serializer.LoadFromHDF('lineshapeParams.h5')['Data'][:,3]
    
    #Load up the observed chemical shifts
    observedShifts = np.loadtxt("FILE_WITH_OBSERVED_SHIFTS")
    
    #Load up the shiftTypes dictionary
    pkl = open('shiftTypes.pkl', 'rb')
    ST = pickle.load(pkl) 
    pkl.close()
    
    #Calculate the population-averaged shifts
    SAA = Serializer.Serializer.LoadFromHDF('shiftArrayAved.h5')['Data']
    P = np.loadtxt('Data/Populations.dat')
    populationShifts = np.dot(P,SAA)
    
    #Overall difference between lineshape-based shifts and static population-averaged shifts
    lineshapeVsPops = GetCominedDeltaShift(lineshapeShifts, populationShifts, ST) 
    print "Combined delta shift between lineshapes versus population-averaged is",lineshapeVsPops
    	##TO TJL: MIGHT ALSO BE INTERESTING TO GIVE A FIGURE THAT COMPARES ATOM-BY-ATOM ??
     
    #Overall error in population-averaged shifts
    popsVsObsvd = GetCominedDeltaShift(populationShifts, observedShifts, ST) 
    print "Overall error of the calculated population-averaged shifts is",popsVsObsvd
    
    #Overall error in lineshape-derived shifts
    lineshapeVsObsvd = GetCominedDeltaShift(lineshapeShifts, observedShifts, ST) 
    print "Overall error of the calculated lineshape-derived shifts is",lineshapeVsObsvd
    
    return
    
    
def VisualizeLinewidths():
    """Sort the linewidths by atomType, and prepare an atom-by-atom line graph plus 
    make histograms."""
    
    #Load up the linewidths
    linewidths = Serializer.Serializer.LoadFromHDF('lineshapeParams.h5')['Data'][:,4]
    
    #Load up the shiftTypes dictionary
    pkl = open('shiftTypes.pkl', 'rb')
    ST = pickle.load(pkl) 
    pkl.close()
    
    #Open up dictionaries that will store data for each atomType
    linewidthHA = {}
    linewidthHN = {}
    linewidthN = {}
    linewidthC = {}
    linewidthCA = {}
    linewidthCB = {}
    
    #Sort the linewidths by atom type
    for i in range(linewidths.shape[0]):
        atomType = ST[i][2][0:2]	#Will either be HA, HN, N, C, CA, CB
        residueNo = int(ST[i][0]) + 1	#Find out the residue number and start indexing at 1
        if atomType == 'HA':
            linewidthHA[residueNo] = linewidths[i]
        elif atomType == 'HN':
            linewidthHN[residueNo] = linewidths[i]
        elif atomType == 'N':
            linewidthN[residueNo] = linewidths[i]
        elif atomType == 'C':
            linewidthC[residueNo] = linewidths[i]
        elif atomType == 'CA':
            linewidthCA[residueNo] = linewidths[i]
        elif atomType == 'CB':
            linewidthCB[residueNo] = linewidths[i]
        
    #Draw the histogram figure
    print "Histogramming results..."
    fig, ax = plt.subplots( nrows = 6)
    plt.subplots_adjust(bottom=0.3, top=0.7, wspace=0.3)
    ax[0].hist(linewidthHA.values())
    ax[1].hist(linewidthHN.values())
    ax[2].hist(linewidthN.values())
    ax[3].hist(linewidthC.values())
    ax[4].hist(linewidthCA.values())
    ax[5].hist(linewidthCB.values())
    plt.savefig('./Analysis/LinewidthHistograms.pdf')
    print "Saved LinewidthHistograms.pdf"
    
    #Draw the atom-by-atom linegraph
    print "Plotting results..."
    plt.plot(linewidthHA.keys(),linewidthHA.values(),'r^',linewidthHN.keys(),linewidthHN.values(),'ro',linewidthN.keys(),linewidthN.values(),'go',linewidthC.keys(),linewidthC.values(),'bo',linewidthCA.keys(),linewidthCA.values(),'b^',linewidthCB.keys(),linewidthCB.values(),'b')
    plt.savefig('./Analysis/LinewidthAtombyAtom.pdf')
    
    
    print "The average linewidth for HAs is",sum(linewidthHA.values())/len(linewidthHA.values())
    print "The average linewidth for HNs is",sum(linewidthHN.values())/len(linewidthHN.values())
    print "The average linewidth for Ns is",sum(linewidthN.values())/len(linewidthN.values())
    print "The average linewidth for CAs is",sum(linewidthCA.values())/len(linewidthCA.values())
    print "The average linewidth for CBs is",sum(linewidthCB.values())/len(linewidthCB.values())
    print "The average linewidth for Cs is",sum(linewidthC.values())/len(linewidthC.values())
        
    return    
    
    

def VisualizeHeights():
    """Sort the heights by atomType, and prepare a atom-by-atom line graph plus 
    make histograms."""
    
    #Load up the fractions Lorentzian
    heights = Serializer.Serializer.LoadFromHDF('lineshapeParams.h5')['Data'][:,1]
    
    #Load up the shiftTypes dictionary
    pkl = open('shiftTypes.pkl', 'rb')
    ST = pickle.load(pkl) 
    pkl.close()
    
    #Open up dictionaries that will store data for each atomType
    heightHA = {}
    heightHN = {}
    heightN = {}
    heightC = {}
    heightCA = {}
    heightCB = {}
    
    #Sort the linewidths by atom type
    for i in range(linewidths.shape[0]):
        atomType = ST[i][2][0:2]	#Will either be HA, HN, N, C, CA, CB
        residueNo = int(ST[i][0]) + 1	#Find out the residue number and start indexing at 1
        if atomType == 'HA':
            heightHA[residueNo] = heights[i]
        elif atomType == 'HN':
            heightHN[residueNo] = heights[i]
        elif atomType == 'N':
            heightN[residueNo] = heights[i]
        elif atomType == 'C':
            heightC[residueNo] = heights[i]
        elif atomType == 'CA':
            heightCA[residueNo] = heights[i]
        elif atomType == 'CB':
            heightCB[residueNo] = heights[i]
        
    #Draw the histogram figure
    print "Histogramming results..."
    fig, ax = plt.subplots( nrows = 6)
    plt.subplots_adjust(bottom=0.3, top=0.7, wspace=0.3)
    ax[0].hist(heightHA.values())
    ax[1].hist(heightHN.values())
    ax[2].hist(heightN.values())
    ax[3].hist(heightC.values())
    ax[4].hist(heightCA.values())
    ax[5].hist(heightCB.values())
    plt.savefig('./Analysis/HeightHistograms.pdf')
    print "Saved HeightHistograms.pdf"
    
    #Draw the atom-by-atom linegraph
    print "Plotting results..."
    plt.plot(heightHA.keys(),heightHA.values(),'r^',heightHN.keys(),heightHN.values(),'ro',heightN.keys(),heightN.values(),'go',heightC.keys(),heightC.values(),'bo',heightCA.keys(),heightCA.values(),'b^',heightCB.keys(),heightCB.values(),'b')
    plt.savefig('./Analysis/HeightAtombyAtom.pdf')
    
    
    print "The average height for HAs is",sum(heightHA.values())/len(heightHA.values())
    print "The average height for HNs is",sum(heightHN.values())/len(heightHN.values())
    print "The average height for Ns is",sum(heightN.values())/len(heightN.values())
    print "The average height for CAs is",sum(heightCA.values())/len(heightCA.values())
    print "The average height for CBs is",sum(heightCB.values())/len(heightCB.values())
    print "The average height for Cs is",sum(heightC.values())/len(heightC.values())
        
    return
    

def CompareLinewidths():
    """Compare the simulated linewidths with experiment, atom by atom.

    Intended to visualise per-atom deviations of the computed linewidths
    from the experimental ones. Not implemented yet: the function currently
    does nothing and returns None.
    """
    return None
    

##-----------------------------------------
##These functions examine dynamical effects
##-----------------------------------------

def CompareDynamicRanges():
    """Compare three per-atom measures of dynamic range.

    Intended to compare the atomic-coordinate RMSD, the chemical-shift spread
    and the linewidth of each atom. Not implemented yet: the function
    currently does nothing and returns None.
    """
    return None
    

#def RandomizeRateMatrix():


##---------------------------------------------
##Testing Kubo's stochastic theory of lineshape
##---------------------------------------------
        
        
    
    
    
    
    
    
    
    
    
    
    