# Usage: Run this script from the directory that it is in.
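# python tissuerepeatabilityT_plots.py <path to data xml>
# The data itself is read from the pickle next to the XML (same path with 'xml' replaced by 'pickle').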

import itertools
import os
import pickle
import sys

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from scipy import stats
plt.rcParams["font.family"] = "Times New Roman"


def get_cmap(n, name='tab20'):
    """Return the matplotlib colormap ``name`` resampled to ``n`` colors.

    Calling the returned colormap with an integer ``i`` in ``range(n)`` gives
    the i-th color.

    ``pyplot.get_cmap`` is used because ``matplotlib.cm.get_cmap`` (what
    ``plt.cm.get_cmap`` resolves to) was deprecated in Matplotlib 3.7 and
    removed in 3.9.
    """
    return plt.get_cmap(name, n)



# Path to the test's XML file. The data actually loaded is the pickle stored
# beside it: same path with the trailing 'xml' replaced by 'pickle'.
if len(sys.argv) < 2:
    # Without an argument sys.argv[-1] would be this script's own path.
    sys.exit("Usage: python tissuerepeatabilityT_plots.py <path to data xml>")
xmlname = sys.argv[-1]

# NOTE(review): pickle.load can execute arbitrary code -- only point this
# script at pickles you produced yourself.
with open(xmlname[0:-3] + 'pickle', 'rb') as handle:
    MachData = pickle.load(handle)

print()
##############
# Make plots #
##############

# Figures shared by all files:
#   fig2 -- load (ax3) and position (ax4) vs. total time
#   fig3 -- stress vs. strain for ramp 1 and ramp 10 (ax8)
#   fig5 -- fitted modulus (ax9) and peak-load change (ax10) per file
fig2, (ax3, ax4) = plt.subplots(2, 1, figsize=(18, 9), dpi=90)
fig3, ax8 = plt.subplots(1, 1, figsize=(6, 4), dpi=90)
fig5, (ax9, ax10) = plt.subplots(2, 1, figsize=(7, 7), dpi=90)

# Per-file results filled in by the loop below.
ramp1mod, ramp10mod, maxloadchangeList = [], [], []

filename_list = list(MachData)
file_RI = filename_list[0]  # test used as the reference for reindexing
cmap = get_cmap(len(filename_list) * 2)
print(len(filename_list))

FIT_START = 1900  # first sample (positional) of the linear-region fit window


def _fit_linear_region(strain, stress, start=FIT_START):
    """Fit a least-squares line to stress vs. strain from ``start`` to the stress peak.

    ``strain`` and ``stress`` are presumably equal-length pandas Series (columns
    of a step's 'Data' frame) -- TODO confirm. The window is positional: from
    sample ``start`` (inclusive) to the position of maximum stress (exclusive).

    Returns ``(slope, r_value, peak_pos)``, where ``peak_pos`` is the positional
    index of the maximum stress.

    Raises ValueError if the peak leaves fewer than two samples to fit.
    """
    # argmax gives a *position*; idxmax would give an index label, which only
    # matches the positional slice when the index happens to start at 0.
    peak_pos = int(stress.argmax())
    if peak_pos - start < 2:
        raise ValueError(
            f"stress peak at sample {peak_pos} leaves fewer than two samples "
            f"after the fit start ({start}); cannot fit a modulus")
    slope, _intercept, r_value, _p_value, _std_err = stats.linregress(
        strain.iloc[start:peak_pos], stress.iloc[start:peak_pos])
    return slope, r_value, peak_pos


for f_itr, file in enumerate(MachData.keys()):
    # Per-file figures: fig1 (ax1/ax2, labelled but currently not plotted on)
    # and fig4 (load vs. time for step '5', preconditioning).
    fig1, (ax1, ax2) = plt.subplots(nrows=2, ncols=1, figsize=(18, 9), dpi=90)
    fig4, ax7 = plt.subplots(nrows=1, ncols=1, figsize=(18, 9), dpi=90)

    # The shared fig2 axes end up titled with the last file processed.
    ax3.set_title(file)
    ax4.set_title(file)

    fig1.tight_layout(pad=4, w_pad=.5, h_pad=2)
    fig2.tight_layout(pad=4, w_pad=.5, h_pad=2)

    # fig1
    ax1.set_xlabel('time,s')
    ax1.set_ylabel('load, grams')
    ax2.set_xlabel('time,s')
    ax2.set_ylabel('disp,mm')

    # fig2
    ax3.set_xlabel('time,s')
    ax3.set_ylabel('load,grams')
    ax4.set_xlabel('time,s')
    ax4.set_ylabel('disp,mm')

    # fig4: load vs. time during preconditioning (step '5').
    precond = MachData[file]['5']['Data']
    ax7.set_xlabel('time,s')
    ax7.set_ylabel('load,grams')
    ax7.plot(precond['Time, s'], precond['Filtered Load'], label=file)
    ax7.set_title(file)

    # fig2: load and position vs. total time for every step from '2' on,
    # each step plotted exactly once.
    for step_key in MachData[file]:
        if int(step_key) >= 2:
            step_data = MachData[file][step_key]['Data']
            ax3.plot(step_data['Total Time'], step_data['Filtered Load'])
            ax4.plot(step_data['Total Time'], step_data['Filtered Position'])

    # Modulus = slope of the linear region of stress vs. strain (from sample
    # FIT_START up to the stress peak) for the first ramp after
    # preconditioning (step '10') and the last ramp (step '46').
    ramp1 = MachData[file]['10']
    slope, r_value, maxLoadindex = _fit_linear_region(
        ramp1['Data']['Strain2'], ramp1['Data']['Stress2'])
    ramp1['Calculated'] = {'Modulus': slope, 'rvalue': r_value}
    print(maxLoadindex)
    print("first ramp modulus after pc")
    print(file)
    print(str(round(slope, 4)) + " MPa")

    ramp10 = MachData[file]['46']
    slope2, r_value2, maxLoadindex2 = _fit_linear_region(
        ramp10['Data']['Strain3'], ramp10['Data']['Stress3'])
    print(maxLoadindex2)
    print("last ramp modulus after pc")
    print(str(round(slope2, 4)) + " MPa")
    ramp10['Calculated'] = {'Modulus': slope2, 'rvalue': r_value2}

    # Per-file load/position plot on fig1, left disabled:
    # for step_key in MachData[file]:
    #     step_data = MachData[file][step_key]['Data']
    #     ax1.plot(step_data['Total Time'], step_data['Filtered Load'])
    #     ax2.plot(step_data['Total Time'], step_data['Filtered Position'])

    # fig3: stress-strain of ramp 1 (solid) and ramp 10 (dash-dot), same color per file.
    color = cmap(f_itr)
    ax8.plot(ramp1['Data']['Strain2'], ramp1['Data']['Stress2'],
             c=color, label=file + " Ramp 1")
    ax8.plot(ramp10['Data']['Strain3'], ramp10['Data']['Stress3'],
             c=color, linestyle='dashdot', label=file + " Ramp 10")

    ramp1mod.append(round(slope, 2))
    ramp10mod.append(round(slope2, 2))
    # Change in load at the stress peak from the first of the ten ramps to the last.
    maxloadchange = (ramp1['Data']['Filtered Load'].iloc[maxLoadindex]
                     - ramp10['Data']['Filtered Load'].iloc[maxLoadindex2])
    maxloadchangeList.append(round(maxloadchange, 2))

ax8.set_title('Stress vs Strain')
ax8.set_xlabel('strain')
ax8.set_ylabel('stress MPa')
ax8.legend(loc=2)

# Write the stress-strain figure next to the input XML ('<stem>_StressStrain.png')
# and gather the per-file results.
fig3.savefig(os.path.join(os.path.dirname(xmlname),
                          os.path.basename(xmlname)[:-4] + '_StressStrain.png'))
filelist = list(MachData)
output = [filelist, ramp1mod, ramp10mod, maxloadchangeList]
print(MachData.keys())


# Pairwise comparison of the first post-preconditioning ramp (step '10')
# across all files: RMSE between the stress curves and percent difference of
# the fitted moduli. The RMSE window runs from sample 1900 (the fit start) to
# maxLoadindex, which still holds the value from the LAST file processed in
# the loop above, so every pair is compared over that file's window.
percent_diff_list = []
for name_a, name_b in itertools.combinations(filelist, 2):
    ramp_a = MachData[name_a]['10']
    ramp_b = MachData[name_b]['10']

    # Compare sample-by-sample (positionally, not by pandas index label) over
    # the samples both curves actually have in the window.
    stress_a = ramp_a['Data']['Stress2'].iloc[1900:maxLoadindex].to_numpy()
    stress_b = ramp_b['Data']['Stress2'].iloc[1900:maxLoadindex].to_numpy()
    n_common = min(len(stress_a), len(stress_b))
    diff = stress_a[:n_common] - stress_b[:n_common]
    rmse = np.sqrt(np.mean(np.square(diff)))
    print("RMSE  " + name_a, name_b, str(rmse) + " MPa")

    mod_a = ramp_a['Calculated']['Modulus']
    mod_b = ramp_b['Calculated']['Modulus']
    # Percent difference relative to the mean of the two moduli.
    percent_diff = round(abs(mod_a - mod_b) / ((mod_a + mod_b) / 2) * 100, 2)
    percent_diff_list.append([name_a, name_b, percent_diff])
    print("Percent Diff: ", str(percent_diff), " %")


# Pairwise comparison of the last ramp (step '46') across all files: RMSE
# between the stress curves and percent difference of the fitted moduli.
# The RMSE window is the ramp-10 fit window, samples 1900 to maxLoadindex2,
# where maxLoadindex2 still holds the value from the LAST file processed in
# the loop above.
percent_diff_10list = []
for name_a, name_b in itertools.combinations(filelist, 2):
    ramp_a = MachData[name_a]['46']
    ramp_b = MachData[name_b]['46']

    # Compare sample-by-sample (positionally, not by pandas index label) over
    # the samples both curves actually have in the window.
    stress_a = ramp_a['Data']['Stress3'].iloc[1900:maxLoadindex2].to_numpy()
    stress_b = ramp_b['Data']['Stress3'].iloc[1900:maxLoadindex2].to_numpy()
    n_common = min(len(stress_a), len(stress_b))
    diff = stress_a[:n_common] - stress_b[:n_common]
    rmse = np.sqrt(np.mean(np.square(diff)))
    print("RMSE  " + name_a, name_b, str(rmse) + " MPa")

    mod_a = ramp_a['Calculated']['Modulus']
    mod_b = ramp_b['Calculated']['Modulus']
    # Percent difference relative to the mean of the two moduli.
    percent_diff_10 = round(abs(mod_a - mod_b) / ((mod_a + mod_b) / 2) * 100, 2)
    percent_diff_10list.append([name_a, name_b, percent_diff_10])

    print("percent difference among last ramp slopes among tests")
    print("Percent Diff: ", str(percent_diff_10), " %")

# Summary figure (fig5): fitted modulus of ramp 1 and ramp 10 per file (ax9)
# and the change in peak load between them (ax10). Files are placed at
# x = 1, 2, ... in processing order, so any number of files works.
test_days = list(range(1, len(ramp1mod) + 1))

ax9.set_title('Longitudinal Frozen')
ax9.set_xlabel('Test Day')
ax9.set_ylabel('Modulus (MPa)')
ax9.plot(test_days, ramp1mod, label='Ramp 1')
ax9.plot(test_days, ramp10mod, label='Ramp 10')
ax9.legend(loc=0)

ax10.set_xlabel('Test Day')
ax10.set_ylabel('Load (g)')
ax10.plot(test_days, maxloadchangeList, label='Max load change (Ramp 1 - Ramp 10)')

# Percent-difference tables for both ramps, each preceded by a one-cell
# title row, written next to the input XML.
result = pd.concat([
    pd.DataFrame(['Percent Difference All Files Ramp 1 ']),
    pd.DataFrame(percent_diff_list),
    pd.DataFrame(['Percent Difference All Files Ramp 10 ']),
    pd.DataFrame(percent_diff_10list),
])
out_stem = os.path.join(os.path.dirname(xmlname), os.path.split(xmlname)[1][:-4])
result.to_csv(out_stem + '_PercentChange.csv', index=False, header=False)
fig5.savefig(out_stem + '_Compare.png')
plt.show()

#


# arg where, strain -.1

# # Ramp2 Stress RMSEs  #revise this!! calculate RMSE from same range of points as slope from lin regress line
# Stress12_ramp2= np.sqrt(np.average(np.square(np.subtract(stresses[1][1],stresses[0][1]))))
# Stress23_ramp2= np.sqrt(np.average(np.square(np.subtract(stresses[2][1],stresses[1][1]))))
# Stress31_ramp2= np.sqrt(np.average(np.square(np.subtract(stresses[2][1],stresses[0][1]))))


