from lxml import etree as et
#import subprocess32
#import sys
import numpy as np
#import time
import math
import matplotlib.pyplot as plt
import os

'''Script to generate some plots from the calibrations on stiffness'''

def lin_fit(self, x, y):
    '''Fit y = m*x (a line forced through the origin), plot it and save the plot.

    Adapted from erica's subject_force_displacement_plots.py.

    :param self: file path used only to name the output image (this is not an
        instance despite the name). The figure is written next to it as
        ``<path without extension>_InverseFEA.png`` at 300 dpi.
    :param x: displacement samples; any array-like holding n values, e.g. with
        shape (n,) or (n, 1).
    :param y: force samples; array-like holding the same n values as ``x``.
    :returns: tuple ``(m, R_sqr)`` of floats -- the fitted slope and the
        coefficient of determination of the through-origin fit.
    '''
    # Flatten to a single design column / response vector so both (n,) and
    # (n, 1) inputs give the same fit.
    x_col = np.asarray(x, dtype=float).reshape(-1, 1)
    y_flat = np.asarray(y, dtype=float).reshape(-1)

    # Least squares with no intercept column, so the line passes through (0, 0).
    m, _, _, _ = np.linalg.lstsq(x_col, y_flat, rcond=None)
    slope = float(m[0])

    y_fit = slope * x_col.reshape(-1)
    y_bar = np.average(y_flat)
    SS_tot = np.sum((y_flat - y_bar) ** 2)
    SS_res = np.sum((y_flat - y_fit) ** 2)
    R_sqr = 1 - SS_res / SS_tot

    plt.rcParams.update({'font.size': 12, 'lines.linewidth': 1})
    plt.figure(figsize=(9, 11))
    plt.plot(x_col.reshape(-1), y_flat, 'o', markersize=5)
    ax = plt.gca()
    # .values() works for both the dict used by old matplotlib and the Spines
    # mapping of newer releases; .itervalues() only exists on Python 2 dicts.
    for spine in ax.spines.values():
        spine.set_linewidth(1)
    ax.tick_params(length=1)

    plt.xlabel('Displacement (mm)')
    plt.ylabel('Force (N)')
    textstr = r'$R^2=%.2f$' % (R_sqr,) + '\n' + r'$m=%.2f$' % (slope,)
    props = dict(boxstyle='round', facecolor='wheat', alpha=0.5)
    ax.text(0.05, 0.95, textstr, transform=ax.transAxes, fontsize=15,
            verticalalignment='top', bbox=props)
    plt.tight_layout()
    # splitext strips only the final extension, so dots elsewhere in the path
    # (e.g. in directory names) are preserved.
    plt.savefig(os.path.splitext(self)[0] + '_InverseFEA.png', dpi=300)
    plt.close()

    return slope, float(R_sqr)

def get_Disp_XMLData(RegistrationXML):
    '''Read probe positions from a registration XML file and return the
    magnitude of each position's offset from the first position.

    The limb segment is taken from the third-from-last underscore-separated
    token of the file name (e.g. ``..._UL_US_CT.xml`` -> ``UL``) and mapped to
    the XML element name. Positions are read from
    ``<Limb>/CentralAnterior/Indentation/USPosition`` elements, each with
    ``x``/``y``/``z`` children carrying a ``value`` attribute.

    :param RegistrationXML: path to the registration XML file.
    :returns: numpy array of shape (n, 1) with the offset magnitudes scaled
        by 1000 (presumably m -> mm; confirm against the file's units).
        An (0, 1) array is returned when there are no positions.
    :raises ValueError: if the limb code in the file name is not recognised.
    '''
    with open(RegistrationXML, 'r') as f:
        root2 = et.parse(f).getroot()

    limbDict = {'UL': 'UpperLeg', 'LL': 'LowerLeg', 'UA': 'UpperArm', 'LA': 'LowerArm'}
    tokens = os.path.basename(RegistrationXML).split('_')
    limb_code = tokens[-3] if len(tokens) >= 3 else None
    if limb_code not in limbDict:
        raise ValueError('Cannot determine limb (UL/LL/UA/LA) from file name: %s'
                         % RegistrationXML)

    data = (root2.find(limbDict[limb_code]).find("CentralAnterior")
            .find("Indentation").findall("USPosition"))

    # One row per USPosition: [x, y, z]
    positions = np.array(
        [[float(p.find(axis).get('value')) for axis in ('x', 'y', 'z')] for p in data],
        dtype=float).reshape(-1, 3)
    if positions.shape[0] == 0:
        return np.ones((0, 1))

    # Offsets relative to the first position, scaled by 1000.
    offsets = (positions - positions[0]) * 1000
    return np.sqrt(np.sum(offsets ** 2, axis=1)).reshape(-1, 1)

def get_ForceTime_XMLData(manThickXML):
    '''Read force magnitude and time for every frame of a manThick XML file.

    Frames are the ``Subject/Source/Frame`` elements. Each frame's force
    magnitude is computed from the text of its ``Forces/Fx``, ``Forces/Fy``
    and ``Forces/Fz`` children. Its time is the ``value`` attribute of ``Time``.

    :param manThickXML: path to the manThick XML file.
    :returns: tuple ``(Force, Time)``. Each is a float numpy array of shape
        (n, 1), one row per frame in document order.
    '''
    with open(manThickXML, 'r') as f:
        root1 = et.parse(f).getroot()
    frames = root1.find('Subject').find('Source').findall('Frame')

    magnitudes = []
    times = []
    for frame in frames:
        forces = frame.find('Forces')
        Fx = float(forces.find('Fx').text)
        Fy = float(forces.find('Fy').text)
        Fz = float(forces.find('Fz').text)
        magnitudes.append(math.sqrt(Fx ** 2 + Fy ** 2 + Fz ** 2))
        times.append(float(frame.find('Time').get('value')))

    # Build the arrays once instead of re-allocating with np.append per frame.
    Force = np.array(magnitudes, dtype=float).reshape(-1, 1)
    Time = np.array(times, dtype=float).reshape(-1, 1)
    return Force, Time

def get_ForceDisp_LogData(logFile, MaxDispMagnitude):
    '''Return force, displacement and time arrays parsed from a FEBio log file.

    Each record is a line reading exactly ``Data = Fx;Fy;Fz`` (trailing
    whitespace ignored). The time is the last space-separated token of the
    preceding line. The following line is split on whitespace, and columns
    1-3 are taken as the force components; column 0 is presumably a
    node/body id (unverified). Only records with time >= 1 are kept.

    Displacement is ``|time - 1| * MaxDispMagnitude``. This assumes the
    prescribed displacement ramps linearly from 0 at time 1 to
    ``MaxDispMagnitude`` at time 2 (TODO confirm against the load curve).

    All but one zero-force sample are dropped from the front, so repeated
    zero-force points do not bias the line fit. Time and displacement are
    then shifted to start at 0.

    :param logFile: path to the FEBio log file.
    :param MaxDispMagnitude: magnitude of the total prescribed displacement.
    :returns: tuple ``(ForceArray, DispArray, TimeArray)`` of (k, 1) arrays.
    :raises ValueError: if the log contains no force record with time >= 1.
    '''
    with open(logFile, 'r') as log:
        lines = log.readlines()

    times = []
    forces = []
    for i, line in enumerate(lines):
        # rstrip copes with '\r\n' endings and a final line lacking a newline,
        # both of which a line[:-1] comparison mishandles.
        if line.rstrip() != 'Data = Fx;Fy;Fz':
            continue
        if i == 0 or i + 1 >= len(lines):
            # No preceding time line or no following data line: incomplete record.
            continue
        time = float(lines[i - 1].split(' ')[-1])
        values = [float(v) for v in lines[i + 1].split()]
        force = (values[1] ** 2 + values[2] ** 2 + values[3] ** 2) ** 0.5
        if time >= 1:
            times.append(time)
            forces.append(force)

    if not forces:
        raise ValueError('No "Data = Fx;Fy;Fz" records with time >= 1 in %s' % logFile)

    ForceArray = np.array(forces, dtype=float).reshape(-1, 1)
    TimeArray = np.array(times, dtype=float).reshape(-1, 1)
    DispArray = np.abs((TimeArray - 1) * MaxDispMagnitude)

    # Keep only the last of the zero-force samples so duplicated zero points
    # do not influence the line fit.
    # NOTE(review): this counts zeros anywhere but trims from the front, so it
    # assumes every zero-force sample is leading.
    NumZeros = len(ForceArray) - np.count_nonzero(ForceArray)
    start = max(NumZeros - 1, 0)
    ForceArray = ForceArray[start:]
    TimeArray = TimeArray[start:]
    DispArray = DispArray[start:]

    TimeArray = TimeArray - TimeArray[0]
    DispArray = DispArray - DispArray[0]
    return ForceArray, DispArray, TimeArray

def ClipFebioForceDisp(FebioDisp, FebioForce, ExpDisp, ExpForce):
    '''Clip a FEBio force-displacement curve to the experimental data range.

    1. Keep FEBio forces >= the minimum experimental force.
    2. Drop the same number of samples from the front of the displacement
       array. This assumes the discarded low forces are leading, as on a
       loading curve.
    3. Shift displacements so the clipped curve starts at 0.
    4. Keep displacements up to and including the first one that exceeds the
       maximum experimental displacement, and trim forces to match.

    :param FebioDisp: FEBio displacements, shape (n,) or (n, 1).
    :param FebioForce: FEBio forces, same number of samples as ``FebioDisp``.
    :param ExpDisp: experimental displacements (only the maximum is used).
    :param ExpForce: experimental forces (only the minimum is used).
    :returns: tuple ``(FebioDispClip, FebioForceClip)`` of (k, 1) arrays.
        Both are (0, 1) if no FEBio force reaches the experimental minimum.
    '''
    febio_disp = np.asarray(FebioDisp, dtype=float).reshape(-1)
    febio_force = np.asarray(FebioForce, dtype=float).reshape(-1)
    # Loop-invariant bounds, computed once.
    force_floor = np.min(ExpForce)
    disp_ceiling = np.max(ExpDisp)

    force_clip = febio_force[febio_force >= force_floor]
    n_drop = max(len(febio_disp) - len(force_clip), 0)
    disp_clip = febio_disp[n_drop:]
    if disp_clip.size == 0 or force_clip.size == 0:
        return np.zeros((0, 1)), np.zeros((0, 1))
    disp_clip = disp_clip - disp_clip[0]

    # Index just past the first displacement beyond the experimental maximum
    # (that sample is kept), or the whole array if none exceeds it.
    beyond = np.nonzero(~(disp_clip <= disp_ceiling))[0]
    stop = beyond[0] + 1 if beyond.size else len(disp_clip)

    FebioDispClip = disp_clip[:stop].reshape(-1, 1)
    FebioForceClip = force_clip[:stop].reshape(-1, 1)
    return FebioDispClip, FebioForceClip

def MakePlots(ManThickXML, RegistrationXML, FebioFile, InputFile, LastRun):
    '''Plot FEBio-simulated against experimental force-displacement data.

    Shows a 2x2 figure (FEBio fit, experimental fit, overlay, all data) and a
    second overlay figure. Through ``lin_fit``, it also saves a
    through-origin linear-fit image for the FEBio and the experimental data.

    :param ManThickXML: manThick XML with the experimental force frames. Its
        path also names the saved FEBio fit image.
    :param RegistrationXML: registration XML with the probe positions. Its
        path also names the saved experimental fit image.
    :param FebioFile: path of the run-1 FEBio log (its name contains 'run1').
    :param InputFile: FEBio input file. The first three
        ``Boundary/rigid_body/prescribed`` values are the prescribed
        displacement components.
    :param LastRun: number of the last run. If the log obtained by replacing
        'run1' with 'run<LastRun>' exists, run 1 is plotted as the initial
        guess alongside it. Otherwise only ``FebioFile`` is plotted.
    '''
    febio_root = et.parse(InputFile).getroot()
    BCs = febio_root.find('Boundary').find('rigid_body').findall('prescribed')
    xFebDisp = float(BCs[0].text)
    yFebDisp = float(BCs[1].text)
    zFebDisp = float(BCs[2].text)
    MaxFebioDisplacement = np.sqrt(xFebDisp ** 2 + yFebDisp ** 2 + zFebDisp ** 2)

    ExpForce, _ = get_ForceTime_XMLData(ManThickXML)
    ExpDisp = get_Disp_XMLData(RegistrationXML)

    LastRunFile = FebioFile.replace('run1', 'run' + str(LastRun))
    MultipleRuns = os.path.isfile(LastRunFile)
    if MultipleRuns:
        # Keep run 1 (initial guess) for comparison, then switch to the last
        # run's log for the calibrated curve. The run-1 variables only exist
        # on this branch, so every later use is guarded by MultipleRuns.
        RunOneForce, RunOneDisp, _ = get_ForceDisp_LogData(FebioFile, MaxFebioDisplacement)
        RunOneDispClip, RunOneForceClip = ClipFebioForceDisp(RunOneDisp, RunOneForce, ExpDisp, ExpForce)
        FebioFile = LastRunFile
    FebioForce, FebioDisp, _ = get_ForceDisp_LogData(FebioFile, MaxFebioDisplacement)
    FebioDispClip, FebioForceClip = ClipFebioForceDisp(FebioDisp, FebioForce, ExpDisp, ExpForce)

    plt.figure(1)
    ax1 = plt.subplot(221)
    ax2 = plt.subplot(222)
    ax3 = plt.subplot(223)
    ax4 = plt.subplot(224)

    # lin_fit draws on (and closes) its own figure; the axes above keep figure 1.
    Febslope, FebRTwo = lin_fit(ManThickXML, FebioDispClip, FebioForceClip)
    Expslope, ExpRTwo = lin_fit(RegistrationXML, ExpDisp, ExpForce)

    props = dict(boxstyle='round', facecolor='wheat', alpha=0.5)

    ax1.plot(FebioDispClip, FebioForceClip, 'r-', FebioDispClip, Febslope * FebioDispClip)
    ax1.set_title('FeBio Force vs Disp with Linear Fit')
    textstr = r'$R^2=%.4f$' % (FebRTwo,) + '\n' + r'$m=%.4f$' % (Febslope,)
    ax1.text(0.05, 0.95, textstr, transform=ax1.transAxes, fontsize=15,
             verticalalignment='top', bbox=props)

    ax2.plot(ExpDisp, ExpForce, 'r-', ExpDisp, Expslope * ExpDisp)
    ax2.set_title('Exp Force vs Disp with Linear Fit')
    textstr = r'$R^2=%.4f$' % (ExpRTwo,) + '\n' + r'$m=%.4f$' % (Expslope,)
    ax2.text(0.05, 0.95, textstr, transform=ax2.transAxes, fontsize=15,
             verticalalignment='top', bbox=props)

    ax3.plot(ExpDisp, ExpForce, 'k.', FebioDispClip, FebioForceClip, 'rx')
    ax3.set_title('Calibrated Febio Run and Experimental Data')
    ax4.set_title('All Data')

    ax1.set_xlim([min(FebioDisp), max(FebioDisp) + 1])
    ax1.set_ylim([min(FebioForce) - 1, max(FebioForce) + 1])
    ax2.set_xlim([min(ExpDisp), max(ExpDisp) + 1])
    ax2.set_ylim([min(ExpForce) - 1, max(ExpForce) + 1])

    ax3.set_xlim([min(FebioDisp), max(FebioDisp) + 1])
    ax3.set_ylim([min(FebioForce) - 1, max(max(ExpForce), max(FebioForce)) + 1])
    ax3.legend(['Experimental', 'Febio Converged Run'])

    if MultipleRuns:
        ax4.plot(ExpDisp, ExpForce, 'k.', RunOneDispClip, RunOneForceClip, 'rx',
                 FebioDispClip, FebioForceClip, 'b+')
        ax4.legend(['Experimental', 'Febio Initial Guess', 'Febio Calibrated Run'])
    else:
        ax4.plot(ExpDisp, ExpForce, 'k.', FebioDisp, FebioForce, 'rx')
        ax4.legend(['Experimental', 'Febio Converged Run'])

    plt.show()

    # Second figure: experimental points with the clipped FEBio curve(s).
    plt.figure()
    if MultipleRuns:
        plt.plot(ExpDisp, ExpForce, 'k.', RunOneDispClip, RunOneForceClip, 'r-',
                 FebioDispClip, FebioForceClip, 'b--', markersize=20, linewidth=5)
    else:
        plt.plot(ExpDisp, ExpForce, 'k.', FebioDispClip, FebioForceClip, 'b--',
                 markersize=20, linewidth=5)
    plt.xlim([0, 15])
    plt.ylim([0, 25])

    plt.show()

if __name__ == '__main__':
    # Directory holding one calibration's input files; change as needed.
    path = '/home/doherts/Documents/MULTIS/Calibrations/LinearFitCalibration/006UA_Unclean/'
    # Number of the last FEBio run. MakePlots swaps 'run1' in the log name
    # for 'run<LastRunVal>' to find the calibrated run's log.
    LastRunVal = 4

    ManThickXML = RegistrationXML = FebioLog = InputFile = None
    # sorted() makes the choice deterministic if several files match a pattern.
    for fname in sorted(os.listdir(path)):
        full_path = os.path.join(path, fname)
        if not os.path.isfile(full_path):  # skip directory entries
            continue
        if 'manThick' in fname and '.xml' in fname:
            ManThickXML = full_path
        elif 'US_CT.xml' in fname:
            RegistrationXML = full_path
        elif 'run1.log' in fname:
            FebioLog = full_path
        elif '.feb' in fname and '_run' in fname:
            InputFile = full_path

    missing = [label for label, value in (('manThick XML', ManThickXML),
                                          ('US_CT registration XML', RegistrationXML),
                                          ('run1 .log', FebioLog),
                                          ('_run .feb input', InputFile))
               if value is None]
    if missing:
        raise SystemExit('Missing required input file(s) in %s: %s'
                         % (path, ', '.join(missing)))

    MakePlots(ManThickXML, RegistrationXML, FebioLog, InputFile, LastRunVal)
