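# Usage: run this script from the directory that contains it:
#   python tissuerepeatabilityT_plots.py <path to data xml>
# The data are loaded from the pickle file that sits next to the XML file and
# shares its name (".xml" replaced by ".pickle").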

import itertools
import os
import pickle
import sys

import matplotlib.pyplot as plt
import numpy as np
from scipy import stats

def get_cmap(n, name='tab20'):
    """Return the Matplotlib colormap *name* resampled to *n* discrete colours.

    The script calls the result with integer indices (e.g. ``cmap(2*f_itr)``)
    to pick one colour per curve.

    ``plt.cm.get_cmap`` was deprecated in Matplotlib 3.7 and removed in 3.9;
    ``pyplot.get_cmap(name, lut)`` is the supported equivalent.
    """
    return plt.get_cmap(name, n)



# Default test XML used when no path is supplied on the command line (e.g. when
# the script is run from an IDE).
DEFAULT_XML = '/Users/klonowe/TissueTesting/app/TissueTesting/Data/TMULTIS013/TMULTIS013_UL_S_L2D_001/TMULTIS013_UL_S_L2D_001.xml'

# Honour the documented usage `python <script> <path to data xml>`. Only the
# last argument is considered, and only if it looks like an XML path, so extra
# launcher arguments (e.g. a notebook kernel's "-f <file>.json") are ignored.
if len(sys.argv) > 1 and sys.argv[-1].lower().endswith('.xml'):
    xmlname = sys.argv[-1]
else:
    xmlname = DEFAULT_XML

# The processed data live next to the XML with the same base name and a
# ".pickle" extension.
# NOTE: pickle.load can execute arbitrary code; only open pickles produced by
# this project's own processing step.
with open(os.path.splitext(xmlname)[0] + '.pickle', 'rb') as handle:
    MachData = pickle.load(handle)

print()

# ---------------------------------------------------------------------------
# Make plots
# ---------------------------------------------------------------------------
# fig2: load (ax3) and displacement (ax4) vs. time, all files on shared axes.
fig2, (ax3, ax4) = plt.subplots(nrows=2, ncols=1, figsize=(18, 9), dpi=90)
# fig3: load vs. displacement (ax5) and stress vs. strain (ax8) for the ramps
# before and after.
fig3, (ax5, ax8) = plt.subplots(nrows=1, ncols=2, figsize=(18, 8), dpi=90)

filename_list = list(MachData)
file_RI = filename_list[0]  # test used as the reference when reindexing
n_tests = len(filename_list)
# Two colours per test: the loop indexes cmap with 2*f_itr and 2*f_itr + 1.
cmap = get_cmap(n_tests * 2)
print(n_tests)

# Sample position within a ramp's data where the linear-region fits start;
# each fit runs from here up to that ramp's peak stress.
FIT_START_INDEX = 1900


def _fit_linear_region(strain, stress, stop, label):
    """Fit a least-squares line to stress vs. strain over [FIT_START_INDEX, stop).

    Returns the ``scipy.stats.linregress`` result (slope, intercept, rvalue,
    pvalue, stderr). Raises ValueError naming *label* when *stop* does not
    come after FIT_START_INDEX (the window would be empty), instead of
    failing inside scipy.
    """
    if stop <= FIT_START_INDEX:
        raise ValueError(
            f"{label}: peak stress index {stop} is not after the fit start "
            f"index {FIT_START_INDEX}; cannot fit a modulus")
    return stats.linregress(strain[FIT_START_INDEX:stop],
                            stress[FIT_START_INDEX:stop])


for f_itr, file in enumerate(MachData.keys()):
    segments = MachData[file]

    # Per-file figures: fig1 = every segment's load (ax1) and position (ax2)
    # vs. total time; fig4 = load vs. time for segment '5'.
    fig1, (ax1, ax2) = plt.subplots(nrows=2, ncols=1, figsize=(18, 9), dpi=90)
    fig4, ax7 = plt.subplots(nrows=1, ncols=1, figsize=(18, 9), dpi=90)  # load v time during preconditioning

    # ax3/ax4 are shared by all files, so after the loop these titles name
    # only the last file.
    ax3.set_title(file)
    ax4.set_title(file)

    fig1.tight_layout(pad=4, w_pad=.5, h_pad=2)
    fig2.tight_layout(pad=4, w_pad=.5, h_pad=2)

    # fig1
    ax1.set_xlabel('time,s')
    ax1.set_ylabel('load, grams')
    ax2.set_xlabel('time,s')
    ax2.set_ylabel('disp,mm')

    # fig2
    ax3.set_xlabel('time,s')
    ax3.set_ylabel('load,grams')
    ax4.set_xlabel('time,s')
    ax4.set_ylabel('disp,mm')

    # fig3, left panel
    ax5.set_xlabel('disp,mm')
    ax5.set_ylabel('load,grams')

    # fig4
    ax7.set_xlabel('time,s')
    ax7.set_ylabel('load,grams')
    precond = segments['5']['Data']
    ax7.plot(precond['Time, s'], precond['Filtered Load'], label=file)
    ax7.set_title(file)

    # Each file owns the colour pair cmap(2*f_itr) / cmap(2*f_itr + 1); the
    # first colour is also used for the file on the stress-strain panel so a
    # file looks the same in both panels of fig3.
    file_colour = cmap(2 * f_itr)
    ramp2_colour = cmap(2 * f_itr + 1)

    # Load vs. displacement: ramp 1 up/down = segments '2'/'3',
    # ramp 2 up/down = segments '10'/'11'.
    ax5.plot(segments['2']['Data']['Filtered Displacement1'], segments['2']['Data']['Filtered Load'],
             c=file_colour, label=file + " Ramp 1")
    ax5.plot(segments['3']['Data']['Filtered Displacement1'], segments['3']['Data']['Filtered Load'],
             c=file_colour, ls='--', label=file + " Ramp down 1")
    ax5.plot(segments['10']['Data']['Filtered Displacement2'], segments['10']['Data']['Filtered Load'],
             c=ramp2_colour, label=file + " Ramp 2")
    ax5.plot(segments['11']['Data']['Filtered Displacement2'], segments['11']['Data']['Filtered Load'],
             c=ramp2_colour, ls='--', label=file + " Ramp down 2")
    ax5.set_title('Force vs. Displacement')
    ax5.legend(loc=2)

    # fig2: load and position of every segment from '2' onward. Each segment
    # is drawn exactly once (segments 2-3 and 6-7 used to be plotted twice).
    for key, seg in segments.items():
        if int(key) >= 2:
            ax3.plot(seg['Data']['Total Time'], seg['Data']['Filtered Load'])
            ax4.plot(seg['Data']['Total Time'], seg['Data']['Filtered Position'])

    # Modulus = slope of the regression line of stress on strain from
    # FIT_START_INDEX up to the peak stress of the ramp (linear region).
    ramp2 = segments['10']['Data']
    maxLoadindex = ramp2['Stress2'].idxmax()
    slope, _, r_value, _, _ = _fit_linear_region(
        ramp2['Strain2'], ramp2['Stress2'], maxLoadindex, f"{file} ramp 2")
    print(maxLoadindex)

    ramp10 = segments['46']['Data']
    maxLoadindex2 = ramp10['Stress3'].idxmax()
    slope2 = _fit_linear_region(
        ramp10['Strain3'], ramp10['Stress3'], maxLoadindex2, f"{file} ramp 10").slope
    print("last ramp slope")
    print(slope2)

    print(file)
    print(str(round(slope, 4)) + " MPa")
    segments['10']['Calculated'] = {'Modulus': slope, 'rvalue': r_value}

    # fig1: every segment of this file.
    for seg in segments.values():
        ax1.plot(seg['Data']['Total Time'], seg['Data']['Filtered Load'])
        ax2.plot(seg['Data']['Total Time'], seg['Data']['Filtered Position'])
    ax1.set_title(file)

    # Disabled: reindex ramp 2 onto the strain axis of the reference test file_RI.
    # MachData[file]['10']['Data_RI'] = MachData[file]['10']['Data'][:maxLoadindex]
    # MachData[file]['10']['Data_RI'].index = MachData[file]['10']['Data_RI']['Strain2']
    # MachData[file]['10']['Data_RI'] = MachData[file]['10']['Data'].reindex(MachData[file_RI]['10']['Data_RI']['Strain2'])

    # fig3, right panel: stress vs. strain for ramp 1, ramp 2 and the last ramp.
    ax8.plot(segments['2']['Data']['Strain1'], segments['2']['Data']['Stress1'],
             c=file_colour, label=file + " Ramp 1")
    ax8.plot(ramp2['Strain2'], ramp2['Stress2'],
             c=file_colour, linestyle='dashed', label=file + " Ramp 2")
    ax8.plot(ramp10['Strain3'], ramp10['Stress3'],
             c=file_colour, linestyle='dashdot', label=file + " Ramp 10")

    # Mark where the ramp-2 modulus fit starts.
    ax8.scatter(ramp2['Strain2'][FIT_START_INDEX], ramp2['Stress2'][FIT_START_INDEX])
    ax8.set_title('Stress vs Strain')
    ax8.set_xlabel('strain')
    ax8.set_ylabel('stress MPa')
    ax8.legend(loc=2)

# Save the stress/strain figure next to the input XML as <name>_StressStrain.png.
stress_strain_png = os.path.join(
    os.path.dirname(xmlname),
    os.path.basename(xmlname)[:-4] + '_StressStrain.png',
)
fig3.savefig(stress_strain_png)

filelist = list(MachData)
print(MachData.keys())



# Pairwise comparison of ramp 2 (segment '10') between every two tests of the
# sample. Stress curves are compared sample-by-sample over the window
# [RMSE_START_INDEX, maxLoadindex), where maxLoadindex is the peak-stress index
# of the last test processed in the per-file loop.
RMSE_START_INDEX = 1900

for name_a, name_b in itertools.combinations(filelist, 2):
    ramp_a = MachData[name_a]['10']
    ramp_b = MachData[name_b]['10']

    # Compare by position (np.asarray) rather than by pandas index label:
    # arithmetic on two Series aligns on labels and yields NaN wherever the
    # labels differ.
    stress_a = np.asarray(ramp_a['Data']['Stress2'])[RMSE_START_INDEX:maxLoadindex]
    stress_b = np.asarray(ramp_b['Data']['Stress2'])[RMSE_START_INDEX:maxLoadindex]
    # A ramp that ends before maxLoadindex gives a shorter window; compare
    # only the overlapping samples.
    n_common = min(len(stress_a), len(stress_b))
    residual = stress_a[:n_common] - stress_b[:n_common]
    RMSE = np.sqrt(np.mean(np.square(residual)))
    print("RMSE  " + name_a, name_b, str(RMSE) + " MPa")

    # Percent difference of the two moduli relative to their mean.
    modulus_a = ramp_a['Calculated']['Modulus']
    modulus_b = ramp_b['Calculated']['Modulus']
    percent_diff = abs(modulus_a - modulus_b) / ((modulus_a + modulus_b) / 2) * 100
    print("Percent Diff: ", str(percent_diff), " %")


plt.show()

#


# arg where, strain -.1

# # Ramp2 Stress RMSEs  #revise this!! calculate RMSE from same range of points as slope from lin regress line
# Stress12_ramp2= np.sqrt(np.average(np.square(np.subtract(stresses[1][1],stresses[0][1]))))
# Stress23_ramp2= np.sqrt(np.average(np.square(np.subtract(stresses[2][1],stresses[1][1]))))
# Stress31_ramp2= np.sqrt(np.average(np.square(np.subtract(stresses[2][1],stresses[0][1]))))


