import sys
import numpy as np
import matplotlib.pyplot as plt
# Signal-processing functions from scipy; not all of them are used.
from scipy.signal import butter, lfilter, freqz, filtfilt
from scipy.optimize import curve_fit

# these are just for finding file paths
import os
import ntpath
import xmltodict
from scipy import stats
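# Example terminal invocation:
#   /home/klonowe/anaconda3/bin/python /home/klonowe/Documents/NewFolder/Python/tissuecheck_version3.py Documents/NewFolder/Data/TissueTestingInput.xml
# NOTE: the XML path given on the command line is not read by this script;
# the input file path is hard-coded in the open() call below.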


with open ('/Users/klonowe/TissueTesting/app/TissueTesting/TissueTestingInput.xml') as fd:
    Data_dictionary=xmltodict.parse(fd.read())
    Data_dictionary=Data_dictionary["TissueTest"]

    root_elements=Data_dictionary["Test1"] if type(Data_dictionary["Test1"])==list else [Data_dictionary["Test1"]]
    for element in root_elements:
        file1=(element["file"])
        reflength1_1 =float((element["InitialLength"]))
        reflength1_2 = float((element["InitialLength2"]))
        threshold=int((element["Threshold"]))
        width1=float((element["Width"]))
        thickness1=float(element["Thickness"])
        fs = element["Frequency"]

    root_elements = Data_dictionary["Test2"] if type(Data_dictionary["Test2"]) == list else [Data_dictionary["Test2"]]
    for element in root_elements:
        file2=(element["file"])
        reflength2_1 = float((element["InitialLength"]))
        reflength2_2 = float((element["InitialLength2"]))
        threshold=int((element["Threshold"]))
        width2=float((element["Width"]))
        thickness2 = float(element["Thickness"]) #thickness of sample from optical measurement
        fs= float(element["Frequency"])

    root_elements = Data_dictionary["Test3"] if type(Data_dictionary["Test3"]) == list else [Data_dictionary["Test3"]]
    for element in root_elements:
        file3 = (element["file"])
        reflength3_1 = float((element["InitialLength"]))
        reflength3_2 = float((element["InitialLength2"]))
        threshold = int((element["Threshold"]))
        width3 = float((element["Width"]))
        thickness3 = float(element["Thickness"])  # thickness of sample from optical measurement
        fs = float(element["Frequency"])


# Folder that holds the raw data files named in the XML input.
# NOTE: this name shadows the built-in dir(); it is kept because the per-file
# loop further down builds each input path from it.
dir = '/Users/klonowe/TissueTesting/app/TissueTesting/Data/'

# Earlier hard-coded settings, now read from the XML input instead:
#   thickness=2.41, threshold=200, fs=2500

# Data files, in the order Test1, Test2, Test3.
files = [file1, file2, file3]

# Figures shared by every file:
#   fig2 -> ax3 (load vs time) over ax4 (disp vs time)
#   fig3 -> ax5 (load vs disp) beside ax8 (stress vs strain), before and
#           after the ramps
fig2, (ax3, ax4) = plt.subplots(2, 1, dpi=90, figsize=(18, 9))
fig3, (ax5, ax8) = plt.subplots(1, 2, dpi=90, figsize=(18, 8))

# Accumulators filled while the files are processed.
masterdisp, masterforce, moduli = [], [], []
stresses, strains = [], []

# ---------------------------------------------------------------------------
# Per-file processing: parse each raw data file, low-pass filter the ramps,
# plot them, and collect the stress/strain curves used by the cross-test
# comparison that follows this loop.
# Helpers are defined once here instead of on every loop iteration.
# ---------------------------------------------------------------------------


def butter_lowpass(cutoff, fs, order=3):
    """Design a digital low-pass Butterworth filter.

    cutoff -- cutoff frequency, in the same units as `fs`
    fs     -- sampling frequency
    order  -- filter order

    Returns the (b, a) coefficient arrays produced by scipy.signal.butter.
    """
    nyq = 0.5 * fs
    normal_cutoff = cutoff / nyq  # butter() takes the cutoff as a fraction of Nyquist
    b, a = butter(order, normal_cutoff, btype='low', analog=False)
    return b, a


def butter_lowpass_filter(data, cutoff, fs, order):
    """Return `data` low-pass filtered forward and backward (filtfilt, Gustafsson's method)."""
    b, a = butter_lowpass(cutoff, fs, order=order)
    return filtfilt(b, a, data, method="gust")


def concatenateTime(r_time):
    """Chain per-segment time vectors into one running time axis.

    r_time -- list of time sequences, one per segment.

    Returns a list of numpy arrays; every segment after the first is shifted
    by the last time value of the (already shifted) previous segment.
    """
    time = []
    for segment in r_time:
        segment = np.array(segment)
        if time:
            segment = segment + time[-1][-1]
        time.append(segment)
    return time


def _collect_numeric_rows(reader, lookahead, rows):
    """Gather the data rows that follow a command header.

    Reads up to `lookahead` lines from `reader`. Whenever one of them starts
    with a digit, that line and the run of digit-led lines after it are
    appended to `rows`. Reaching the end of the file simply ends the search
    (previously an uncaught StopIteration aborted the script).

    Returns the last line read ('' at end of file) so the caller can keep
    inspecting it.
    """
    line = ''
    for _ in range(lookahead):
        line = next(reader, '')
        while line[:1].isdigit():
            rows.append(line)
            line = next(reader, '')
    return line


def _parse_data_file(path):
    """Read one raw data file and split it into command sections.

    Returns (ramplist, ramps, sins, vel):
      ramplist -- every 'Move' / 'Sin' command header line, in file order
                  (each one is also printed)
      ramps    -- one list of raw data lines per slot; slots are filled by
                  'Move' commands in file order
      sins     -- raw data lines of all 'Sin' commands
      vel      -- characters 16..21 of every line starting with 'Velocity'
    """
    with open(path, 'r') as infile:
        lines = infile.readlines()

    # Pass 1: list (and echo) the command headers.
    ramplist = []
    for line in lines:
        if line[1:5] == 'Move' or line[1:4] == 'Sin':
            ramplist.append(line)
            print(line)

    # Pass 2: collect the numeric rows that follow each header.
    ramps = [[] for _ in ramplist]
    sins = []
    ramp_index = -1
    reader = iter(lines)
    for line in reader:
        if line[1:5] == 'Move':
            ramp_index += 1
            line = _collect_numeric_rows(reader, 6, ramps[ramp_index])
        # Deliberately a second `if`: the line that ended a Move section may
        # itself be a Sin header.
        if line[1:4] == 'Sin':
            line = _collect_numeric_rows(reader, 8, sins)
    # One slot was allocated per Move *and* Sin header; the trailing slot is
    # dropped. NOTE(review): presumably this compensates for a single Sin
    # command per file - confirm against the data format.
    ramps = ramps[:-1]

    # Pass 3: velocities. The line right after a 'Velocity' line is skipped.
    vel = []
    reader = iter(lines)
    for line in reader:
        if line[0:8] == 'Velocity':
            vel.append(line[16:22])
            next(reader, None)

    return ramplist, ramps, sins, vel


# Positions of the analysed commands within the parsed ramp list.
# NOTE(review): these depend on the command sequence stored in the data file.
RAMP1_LOAD, RAMP1_UNLOAD = 2, 3    # first ramp on the material and its ramp down
RAMP2_LOAD, RAMP2_UNLOAD = 9, 10   # ramp after preconditioning and its ramp down
SIN_SLOT = 5                       # where the sinusoidal block is spliced into the full history
GF_TO_N = .0098                    # gram-force to newton

# Specimen geometry per test: (initial length, initial length 2, thickness, width).
specimen_specs = (
    (reflength1_1, reflength1_2, thickness1, width1),
    (reflength2_1, reflength2_2, thickness2, width2),
    (reflength3_1, reflength3_2, thickness3, width3),
)

filecount = 0
for file in files:
    input_filename = dir + file

    # Geometry is chosen by position in `files` (not by comparing file
    # names), so two tests that share a data file still get their own specs.
    # Any position past the last spec falls back to the last one.
    init_length, init_length2, thickness, width = specimen_specs[
        min(filecount, len(specimen_specs) - 1)]
    area = float(thickness * width)  # cross-sectional area

    ramplist, ramps, sins, vel = _parse_data_file(input_filename)
    rampcount = len(ramps)

    print(str("----------"))
    print("Expected rate, mm/s =" + str(vel[-2]))

    # ------------------------------------------------------------------
    # Split every data row into time (column 0), displacement (column 1,
    # absolute value) and load (last column, absolute value).
    loadindex = -1
    r_time, r_disp, r_load = [], [], []
    for ramp in ramps:
        rows = [list(map(float, row.split())) for row in ramp]
        r_time.append([row[0] for row in rows])
        r_disp.append([abs(row[1]) for row in rows])
        r_load.append([abs(row[loadindex]) for row in rows])

    sin_rows = [list(map(float, row.split())) for row in sins]
    sin_time = [row[0] for row in sin_rows]
    sin_disp = [abs(row[1]) for row in sin_rows]
    sin_load = [abs(row[loadindex]) for row in sin_rows]

    # ------------------------------------------------------------------
    # Low-pass Butterworth filter, 3rd order, 20 Hz cutoff.
    order = 3
    cutoff = 20  # desired cutoff frequency of the filter, Hz

    filtered_r_load = [butter_lowpass_filter(load, cutoff, fs, order) for load in r_load]
    filtered_r_disp = [butter_lowpass_filter(disp, cutoff, fs, order) for disp in r_disp]
    filtered_sin_load = butter_lowpass_filter(sin_load, cutoff, fs, order)

    # Sinusoidal displacement relative to the start of ramp 2, minus 2 % of
    # the initial length. Original note: accommodates the 300 micron buffer
    # and will have to change for a new sequence.
    sin_disp_orig = sin_disp
    sin_disp = [(x - r_disp[2][0]) - init_length * .02 for x in sin_disp_orig]

    time = concatenateTime(r_time)

    # ------------------------------------------------------------------
    # Per-file figures.
    fig1, (ax1, ax2) = plt.subplots(nrows=2, ncols=1, figsize=(18, 9), dpi=90)
    fig4, (ax7) = plt.subplots(nrows=1, ncols=1, figsize=(18, 9), dpi=90)  # load v time during preconditioning

    ax3.set_title(file)
    ax4.set_title(file)

    fig1.tight_layout(pad=4, w_pad=.5, h_pad=2)
    fig2.tight_layout(pad=4, w_pad=.5, h_pad=2)

    # fig1
    ax1.set_xlabel('time,s')
    ax1.set_ylabel('load, grams')
    ax2.set_xlabel('time,s')
    ax2.set_ylabel('disp,mm')

    # fig2
    ax3.set_xlabel('time,s')
    ax3.set_ylabel('load,grams')
    ax4.set_xlabel('time,s')
    ax4.set_ylabel('disp,mm')

    # fig3
    ax5.set_xlabel('disp,mm')
    ax5.set_ylabel('load,grams')

    # fig4
    ax7.set_xlabel('time,s')
    ax7.set_ylabel('load,grams')
    ax7.plot(sin_time, filtered_sin_load, label=file)
    ax7.set_title(file)

    # ------------------------------------------------------------------
    # For every ramp, look at the first and last sample whose filtered load
    # is below `threshold` and record [index, load] of the one with the
    # smaller load (the last one on a tie).
    disp_adjustment = []
    for load in filtered_r_load:
        below = np.where(load < threshold)[0]
        if below.size == 0:
            raise ValueError(
                "filtered load of a ramp in %s never drops below the threshold (%s)"
                % (file, threshold))
        maxindex = below.max()
        minindex = below.min()
        if load[maxindex] < load[minindex]:
            disp_adjustment.append([minindex, load[minindex]])
        else:
            disp_adjustment.append([maxindex, load[maxindex]])

    # Filtered displacement of each ramp, shifted so the ramp starts at zero.
    norm_f_disp = [[x - float(disp[0]) for x in disp] for disp in filtered_r_disp]

    # Raw z displacement relative to each of the two reference lengths.
    norm_r_disp = [[x - init_length for x in disp] for disp in r_disp]
    norm_r_disp2 = [[x - init_length2 for x in disp] for disp in r_disp]

    # ------------------------------------------------------------------
    # Strain and nominal stress for the two loading ramps:
    # entry 0 = first ramp, entry 1 = ramp after preconditioning.
    strain = [
        [c / float(init_length) for c in norm_r_disp[RAMP1_LOAD]],
        [c / float(init_length2) for c in norm_r_disp2[RAMP2_LOAD]],
    ]
    stress = [
        [(c * GF_TO_N) / float(area) for c in filtered_r_load[RAMP1_LOAD]],
        [(c * GF_TO_N) / float(area) for c in filtered_r_load[RAMP2_LOAD]],
    ]

    # Append so that stresses[k] / strains[k] belong to files[k]. The former
    # insert(1, ...) stored three tests in the order [1, 3, 2], which swapped
    # the data behind the "Test 2" and "Test 3" results.
    stresses.append(stress)
    strains.append(strain)

    ax5.plot(norm_r_disp[RAMP1_LOAD], filtered_r_load[RAMP1_LOAD], label=file+ " Ramp 1")
    ax5.plot(norm_r_disp[RAMP1_UNLOAD], filtered_r_load[RAMP1_UNLOAD], ls='--', label=file+" Ramp down 1")
    ax5.plot(norm_r_disp2[RAMP2_LOAD], filtered_r_load[RAMP2_LOAD], label=file+" Ramp 2")
    ax5.plot(norm_r_disp2[RAMP2_UNLOAD], filtered_r_load[RAMP2_UNLOAD], ls='--', label=file+" Ramp down 2 down ")
    ax5.set_title('Force vs. Displacement')
    ax5.legend(loc=2)

    # Shared figure (ax3/ax4): every ramp from index 2 on, drawn once. The
    # earlier code drew ramps 2, 3, 6 and 7 twice on ax3, which only shifted
    # the colour cycle away from ax4.
    for i in range(2, len(time)):
        ax3.plot(time[i], filtered_r_load[i])
        ax4.plot(time[i], filtered_r_disp[i])

    # ------------------------------------------------------------------
    # Full history for this file: splice the sinusoidal block between the
    # ramps. Copies are used so the per-ramp lists above stay untouched.
    totaltime = list(r_time)
    totaltime.insert(SIN_SLOT, list(sin_time))
    totaltime = concatenateTime(totaltime)

    # NOTE(review): the sinusoidal load/disp spliced in here are the
    # unfiltered series, while the ramps are filtered - confirm intended.
    allLoads = list(filtered_r_load)
    allLoads.insert(SIN_SLOT, list(sin_load))

    allDisp = list(filtered_r_disp)
    allDisp.insert(SIN_SLOT, list(sin_disp_orig))

    for i in range(len(totaltime)):
        ax1.plot(totaltime[i], allLoads[i])
    for i in range(len(totaltime)):
        ax2.plot(totaltime[i], allDisp[i])
    ax1.set_title(file)

    filecount += 1


# strainarr=np.array(strains[0][1])
# index = np.where((strainarr > .10) & (strainarr > 0))
# newstrains=np.array(strains[index])
# adjust=strains[index]
#index = np.where(np.logical_and(latfemoral_surface[:, 1] > 0, latfemoral_surface[:, 1] < 40.0))

# # #Ramp1 Stress RSMEs
# Stress12= np.sqrt(np.average(np.square(np.subtract(stresses[1][0],stresses[0][0]))))
# Stress23= np.sqrt(np.average(np.square(np.subtract(stresses[2][0],stresses[1][0]))))
# Stress31= np.sqrt(np.average(np.square(np.subtract(stresses[2][0],stresses[0][0]))))
# print(str(Stress12) + " RSME Test 1 to 2 ramp 1 (MPa)")        #R is a measure of fit, RSME is a measure of absolute fit
# print(str(Stress23) + " RSME Test 2 to 3 ramp 1 (MPa)")
# print(str(Stress31) +" RSME Test 3 to 1 ramp 1 (MPa)")

# Ramp2 Stress RSMEs  #revise this!! calculate RSME from same range of points as slope from lin regress line
# ---------------------------------------------------------------------------
# Cross-test comparison on the second ramp (entry [1] of each test).
# NOTE(review): all "Test N" wording below assumes stresses[k] / strains[k]
# hold the data of Test k+1 - verify the order in which the per-file loop
# stores them.
# ---------------------------------------------------------------------------
ramp2_stress = [stresses[k][1] for k in range(3)]
ramp2_strain = [strains[k][1] for k in range(3)]


def _ramp2_rmse(first, second):
    """Root-mean-square difference between the ramp-2 stress curves of two tests."""
    difference = np.subtract(ramp2_stress[first], ramp2_stress[second])
    return np.sqrt(np.average(np.square(difference)))


# Ramp 2 stress RMSEs.
# Original TODO: revise this - the RMSE should be calculated over the same
# range of points as the slope of the regression line.
Stress12_ramp2 = _ramp2_rmse(1, 0)
Stress23_ramp2 = _ramp2_rmse(2, 1)
Stress31_ramp2 = _ramp2_rmse(2, 0)
for rmse, caption in (
        (Stress12_ramp2, " RSME 1 Test to 2 Ramp 2 (MPa)"),
        (Stress23_ramp2, " RSME 2 Test to 3 Ramp 2 (MPa)"),
        (Stress31_ramp2, " RSME 3 Test to 1 Ramp 2 (MPa)")):
    print(str(rmse) + caption)

# Index of the peak ramp-2 stress of each test.
MOD1max, MOD2max, MOD3max = [curve.index(max(curve)) for curve in ramp2_stress]

# Linear regression of stress on strain from sample 1725 up to the peak.
# Each fit unpacks to: slope of the regression line, its intercept, the
# correlation coefficient, the two-sided p-value for a zero-slope null
# hypothesis, and the standard error of the estimated gradient.
fit_start = 1725
fits = [
    stats.linregress(ramp2_strain[k][fit_start:int(peak)],
                     ramp2_stress[k][fit_start:int(peak)])
    for k, peak in enumerate((MOD1max, MOD2max, MOD3max))
]
slope, intercept, r_value, p_value, std_err = fits[0]
slope1, intercept1, r_value1, p_value1, std_err1 = fits[1]
slope2, intercept2, r_value2, p_value2, std_err2 = fits[2]

print("Young's Moduli")
for modulus in (slope, slope1, slope2):
    print(str(modulus) + "(MPa)")

for name in (file1, file2, file3):
    print(name)
for curve in ramp2_stress:
    print(max(curve))

# Stress-strain curves: solid = first ramp, dashed = second ramp, one colour
# per test. All solid lines are drawn before the dashed ones.
colours = ('m', 'c', 'y')
for k, colour in enumerate(colours):
    ax8.plot(strains[k][0], stresses[k][0], c=colour, label="Test %d Ramp1 " % (k + 1))
for k, colour in enumerate(colours):
    ax8.plot(strains[k][1], stresses[k][1], c=colour, linestyle='dashed')
ax8.set_xlabel('strain')
ax8.set_ylabel('stress MPa')
ax8.legend(loc=2)

plt.show()