import sys
import numpy as np
import matplotlib.pyplot as plt
# Signal-processing helpers from scipy.signal (not all of them are used).
from scipy.signal import butter, lfilter, freqz, filtfilt
from scipy.optimize import curve_fit

# these are just for finding file paths
import os
import ntpath
import xmltodict
from scipy import stats
#/home/klonowe/anaconda3/bin/python /home/klonowe/Documents/NewFolder/Python/tissuecheck_version3.py Documents/NewFolder/Data/TissueTestingInput.xml
        #run this from the terminal

# Read the per-test parameters (data file, initial length, threshold, width,
# thickness, sampling frequency) for Test1..Test3 from the XML input file.
# A <TestN> element may appear once or several times; when it repeats, the
# values of its last occurrence win.  threshold / thickness / fs are shared
# names, so after parsing they hold Test3's values.
XML_INPUT_PATH = '/Users/klonowe/TissueTesting/app/TissueTesting/Data/TissueTestingInput.xml'

with open(XML_INPUT_PATH) as fd:
    Data_dictionary = xmltodict.parse(fd.read())["TissueTest"]

test_specs = []
for test_key in ("Test1", "Test2", "Test3"):
    node = Data_dictionary[test_key]
    # xmltodict yields a list for repeated elements and a single mapping otherwise.
    entries = node if isinstance(node, list) else [node]
    spec = None
    for element in entries:
        spec = {
            "file": element["file"],
            "InitialLength": float(element["InitialLength"]),
            "Threshold": int(element["Threshold"]),
            "Width": float(element["Width"]),
            "Thickness": float(element["Thickness"]),  # thickness of sample from optical measurement
            # Converted for every test (Test1's value used to stay a string).
            "Frequency": float(element["Frequency"]),
        }
    if spec is None:
        # Previously this surfaced much later as a NameError on fileN.
        raise ValueError("<{}> in {} has no entries".format(test_key, XML_INPUT_PATH))
    test_specs.append(spec)

file1, file2, file3 = (spec["file"] for spec in test_specs)
init_length1, init_length2, init_length3 = (spec["InitialLength"] for spec in test_specs)
width1, width2, width3 = (spec["Width"] for spec in test_specs)
# The last test read (Test3) supplies the shared values, as before.
threshold = test_specs[-1]["Threshold"]
thickness = test_specs[-1]["Thickness"]
fs = test_specs[-1]["Frequency"]


#dir = '/home/klonowe/Documents/NewFolder/Data/TissueTestingPractice/'
# Directory that holds the raw test data files named in the XML input
# (file names are appended to it directly, so it keeps its trailing '/').
dir = '/Users/klonowe/TissueTesting/app/TissueTesting/Data/'

# NOTE(review): these hard-coded values replace the Thickness, Threshold and
# Frequency read from the XML input -- confirm that override is intended.
thickness = 2.22   # sample thickness used for the cross-sectional area (presumably mm)
threshold = 200    # load threshold, same units as the recorded load column
files = [file1, file2, file3]  # data files of Test1..Test3, processed in this order
fs = 2500          # sampling frequency passed to the low-pass filter, Hz

# Figure shared by every test: load (top) and displacement (bottom) vs. time.
fig2, (ax3, ax4) = plt.subplots(nrows=2, ncols=1, figsize=(18, 9), dpi=90)
# Load vs. displacement (left) and stress vs. strain (right) for the ramp
# before and after preconditioning.
fig3, (ax5, ax8) = plt.subplots(nrows=1, ncols=2, figsize=(18, 8), dpi=90)

# Accumulators shared across tests (five independent empty lists).
masterdisp, masterforce, moduli, stresses, strains = [], [], [], [], []

def butter_lowpass(cutoff, fs, order=3):
    """Design a digital Butterworth low-pass filter.

    Args:
        cutoff: cutoff frequency in Hz.
        fs: sampling frequency in Hz.
        order: filter order.

    Returns:
        (b, a) coefficients from scipy.signal.butter, with the cutoff
        normalised by the Nyquist frequency fs / 2.
    """
    nyq = 0.5 * fs
    normal_cutoff = cutoff / nyq
    b, a = butter(order, normal_cutoff, btype='low', analog=False)
    return b, a


def butter_lowpass_filter(data, cutoff, fs, order):
    """Return *data* low-pass filtered forward and backward (zero phase).

    Uses scipy.signal.filtfilt with Gustafsson's method for the initial
    conditions, so the filtered trace is not shifted in time.
    """
    b, a = butter_lowpass(cutoff, fs, order=order)
    return filtfilt(b, a, data, method="gust")


def concatenateTime(r_time):
    """Stitch per-segment time vectors into one continuous time axis.

    Segment 0 is kept as-is; each later segment is shifted by the last
    (already shifted) time value of the segment before it.

    Args:
        r_time: sequence of per-segment time sequences.

    Returns:
        list of numpy arrays, one per segment.
    """
    time = []
    for seg_times in r_time:
        if not time:
            time.append(np.array(seg_times))
        else:
            time.append(np.array(seg_times) + np.array(time[-1][-1]))
    return time


filecount = 0
for file in files:
    input_filename = dir + file

    # Specimen geometry for this test, chosen by position in `files` rather
    # than by comparing names (two tests sharing a file name used to both get
    # Test1's initial length).
    if filecount == 0:
        init_length = init_length1
    elif filecount == 1:
        init_length = init_length2
    else:
        init_length = init_length3
    # NOTE(review): width is fixed at 5.0 for every test; the XML Width values
    # (width1..width3) are not used here -- confirm this is intended.
    width = 5.0
    area = float(thickness * width)  # cross-sectional area used for stress

    # --------------------
    # Pass 1: echo the 'Move' and 'Sin' command header lines and count them.
    rampcount = 0
    with open(input_filename, 'r') as infile:
        for line in infile:
            if line[1:5] == 'Move':
                rampcount += 1
                print(line)
            if line[1:4] == 'Sin':
                rampcount += 1
                print(line)

    # Pass 2: collect numeric data rows.  After a 'Move' header the next 6
    # sections are read, after a 'Sin' header the next 8; a section is a run
    # of lines starting with a digit, and the non-digit line ending one
    # section is discarded when the next section starts.  next(infile, '')
    # returns '' at end of file (''[:1] is not a digit) instead of raising
    # StopIteration, so a file ending inside a section no longer crashes.
    ramps = [[] for _ in range(rampcount)]
    sins = []
    ramp_idx = -1
    with open(input_filename, 'r') as infile:
        for line in infile:
            if line[1:5] == 'Move':
                ramp_idx += 1
                for _ in range(6):
                    line = next(infile, '')
                    while line[:1].isdigit():
                        ramps[ramp_idx].append(line)
                        line = next(infile, '')
            # Not an elif: after a Move block this tests the line that
            # terminated its last section.
            if line[1:4] == 'Sin':
                for _ in range(8):
                    line = next(infile, '')
                    while line[:1].isdigit():
                        sins.append(line)
                        line = next(infile, '')
    # `ramps` was sized by Move + Sin headers but only Move blocks fill it;
    # the last slot is dropped (with exactly one Sin block that is the empty one).
    ramps = ramps[0:-1]
    rampcount = len(ramps)

    # Pass 3: commanded velocities (characters 16-21 of each 'Velocity' line).
    vel = []
    with open(input_filename, 'r') as infile:
        for line in infile:
            if line[0:8] == 'Velocity':
                vel.append(line[16:22])
                next(infile, None)  # the line after each 'Velocity' line is skipped
    print(str("----------"))
    print("Expected rate, mm/s =" + str(vel[-2]))

    # Split every data row into time, |displacement| (column 1) and |load|
    # (last column).
    loadindex = -1
    r_time = [[] for _ in range(rampcount)]
    r_disp = [[] for _ in range(rampcount)]
    r_load = [[] for _ in range(rampcount)]
    for j in range(rampcount):
        for row in ramps[j]:
            values = list(map(float, row.split()))
            r_time[j].append(values[0])
            r_disp[j].append(abs(values[1]))
            r_load[j].append(abs(values[loadindex]))

    sin_time = []
    sin_disp = []
    sin_load = []
    for row in sins:
        values = list(map(float, row.split()))
        sin_time.append(values[0])
        sin_disp.append(abs(values[1]))
        sin_load.append(abs(values[loadindex]))

    # -------------------------------------------
    # Low-pass Butterworth filter: 3rd order, 20 Hz cutoff, zero phase.
    order = 3
    cutoff = 20  # desired cutoff frequency of the filter, Hz

    filtered_r_load = [butter_lowpass_filter(r_load[r], cutoff, fs, order) for r in range(rampcount)]
    filtered_r_disp = [butter_lowpass_filter(r_disp[r], cutoff, fs, order) for r in range(rampcount)]
    filtered_sin_load = butter_lowpass_filter(sin_load, cutoff, fs, order)

    sin_disp_orig = sin_disp
    # Sine displacement relative to the start of ramp index 2, minus the
    # 0.3 mm (300 micron) buffer.
    sin_disp = [((x - r_disp[2][0]) - 0.3) for x in sin_disp]

    time = concatenateTime(r_time)

    fig1, (ax1, ax2) = plt.subplots(nrows=2, ncols=1, figsize=(18, 9), dpi=90)
    fig4, ax7 = plt.subplots(nrows=1, ncols=1, figsize=(18, 9), dpi=90)  # load v time during preconditioning

    ax3.set_title(file)
    ax4.set_title(file)
    ax5.set_title(file)

    fig1.tight_layout(pad=4, w_pad=.5, h_pad=2)
    fig2.tight_layout(pad=4, w_pad=.5, h_pad=2)

    # fig1: whole test vs. time
    ax1.set_xlabel('time,s')
    ax1.set_ylabel('load, grams')
    ax2.set_xlabel('time,s')
    ax2.set_ylabel('disp,mm')

    # fig2: ramps vs. time (shared across tests)
    ax3.set_xlabel('time,s')
    ax3.set_ylabel('load,grams')
    ax4.set_xlabel('time,s')
    ax4.set_ylabel('disp,mm')

    # fig3: load vs. displacement
    ax5.set_xlabel('disp,mm')
    ax5.set_ylabel('load,grams')

    # fig4: filtered load during the sine (preconditioning) block
    ax7.set_xlabel('time,s')
    ax7.set_ylabel('load,grams')
    ax7.plot(sin_time, filtered_sin_load, label=file)
    ax7.set_title(file)

    # For each ramp: the end of the below-threshold region (first or last
    # sample with filtered load < threshold) whose load is higher, ties going
    # to the last one, as [index, load].  NOTE(review): not used further down.
    disp_adjustment = []
    for seg_load in filtered_r_load:
        below = np.where(seg_load < threshold)[0]
        maxindex = max(below)
        minindex = min(below)
        if seg_load[maxindex] < seg_load[minindex]:
            disp_adjustment.append([minindex, seg_load[minindex]])
        else:
            disp_adjustment.append([maxindex, seg_load[maxindex]])

    # Filtered displacement of each ramp re-zeroed at its own first sample.
    norm_f_disp = [[x - float(seg[0]) for x in seg] for seg in filtered_r_disp]
    # Raw displacement relative to the start of ramp index 2 plus the 0.3 mm buffer.
    norm_r_disp = [[x - float(r_disp[2][0] + .3) for x in seg] for seg in r_disp]

    # Ramp indices: 2 = first ramp on the material, aa = ramp after
    # preconditioning; 3 and bb are plotted dashed alongside them.
    aa = 6
    bb = 7
    # NOTE(review): ramp 1 strain uses norm_r_disp (with the 0.3 mm buffer)
    # while ramp 2 strain uses norm_f_disp -- confirm this is intended.
    strain = [
        [c / float(init_length) for c in norm_r_disp[2]],
        [c / float(init_length) for c in norm_f_disp[aa]],
    ]
    # Load in grams-force * 0.0098 -> N, divided by the area.
    stress = [
        [(c * .0098) / float(area) for c in filtered_r_load[2]],
        [(c * .0098) / float(area) for c in filtered_r_load[aa]],
    ]
    # Append so stresses[k] / strains[k] follow the order of `files`;
    # insert(1, ...) used to put test 3 ahead of test 2.
    stresses.append(list(stress))
    strains.append(list(strain))

    print("Max Load Ramp 1 (g)" + str(file))
    print(max(filtered_r_load[2]))
    print("Max Load Ramp 2 (g)")
    print(max(filtered_r_load[aa]))

    ax5.plot(norm_r_disp[2], filtered_r_load[2], label=file)
    ax5.plot(norm_r_disp[3], filtered_r_load[3], ls='--')
    ax5.plot(norm_r_disp[aa], filtered_r_load[aa], label=file)
    ax5.plot(norm_r_disp[bb], filtered_r_load[bb], ls='--')
    ax5.set_title('Force vs. Displacement')
    ax5.legend(loc=2)

    # fig2: every segment from ramp index 2 on, each drawn once (ramps 2, 3,
    # 6 and 7 used to be plotted twice on ax3).
    for i in range(2, len(time)):
        ax3.plot(time[i], filtered_r_load[i])
        ax4.plot(time[i], filtered_r_disp[i])

    # fig1: whole test with the sine block spliced in at position A of the
    # segment list.  New lists are built so r_time / filtered_r_load /
    # filtered_r_disp are not mutated through aliases.
    A = 5
    totaltime = concatenateTime(r_time[:A] + [list(sin_time)] + r_time[A:])
    # NOTE(review): the spliced-in sine load is the unfiltered sin_load while
    # the ramp loads are filtered -- confirm.
    allLoads = filtered_r_load[:A] + [list(sin_load)] + filtered_r_load[A:]
    allDisp = filtered_r_disp[:A] + [list(sin_disp_orig)] + filtered_r_disp[A:]

    for seg_time, seg_load in zip(totaltime, allLoads):
        ax1.plot(seg_time, seg_load)
    for seg_time, seg_disp in zip(totaltime, allDisp):
        ax2.plot(seg_time, seg_disp)
    ax1.set_title(file)

    filecount += 1

# ---------------------------------------------------------------------------
# Cross-test comparison.  stresses[k] / strains[k] hold [ramp 1, ramp 2] for
# the k-th entry of `files`.
# ---------------------------------------------------------------------------
TEST_COLOURS = ('b', 'r', 'g')

# Solid: ramp 1 (first ramp on the material); dashed: ramp 2 (after
# preconditioning).  Each test is labelled with its own file name (all three
# used to carry the last file's name).
for k, colour in enumerate(TEST_COLOURS):
    ax8.plot(strains[k][0], stresses[k][0], c=colour, label=files[k])
for k, colour in enumerate(TEST_COLOURS):
    ax8.plot(strains[k][1], stresses[k][1], c=colour, ls='--')
ax8.set_xlabel('strain')
ax8.set_ylabel('stress MPa')
ax8.legend(loc=2)


def _rmse(x, y):
    """Root-mean-square of the sample-by-sample difference x - y.

    The two sequences must have the same length; samples are paired by
    index, not by strain.
    """
    return np.sqrt(np.average(np.square(np.subtract(x, y))))


# Ramp 1 stress RMSEs (RMSE is an absolute measure of fit)
Stress12 = _rmse(stresses[1][0], stresses[0][0])
Stress23 = _rmse(stresses[2][0], stresses[1][0])
Stress31 = _rmse(stresses[2][0], stresses[0][0])
print(str(Stress12) + " RSME Test 1 to 2 ramp 1 (MPa)")
print(str(Stress23) + " RSME Test 2 to 3 ramp 1 (MPa)")
print(str(Stress31) +" RSME Test 3 to 1 ramp 1 (MPa)")

# Ramp 2 stress RMSEs
Stress12_ramp2 = _rmse(stresses[1][1], stresses[0][1])
Stress23_ramp2 = _rmse(stresses[2][1], stresses[1][1])
Stress31_ramp2 = _rmse(stresses[2][1], stresses[0][1])
print(str(Stress12_ramp2) + " RSME 1 Test to 2 Ramp 2 (MPa)")
print(str(Stress23_ramp2) + " RSME 2 Test to 3 Ramp 2 (MPa)")
print(str(Stress31_ramp2) + " RSME 3 Test to 1 Ramp 2 (MPa)")

# Modulus of ramp 2: least-squares line of stress on strain over the samples
# from LINEAR_REGION_START up to (not including) the sample of peak stress.
LINEAR_REGION_START = 1725  # hard-coded start of the fit window; adjust per data set

# Index of the maximum stress in ramp 2 of each test.
MOD1max = stresses[0][1].index(max(stresses[0][1]))
MOD2max = stresses[1][1].index(max(stresses[1][1]))
MOD3max = stresses[2][1].index(max(stresses[2][1]))

# stats.linregress(x, y) fits y = intercept + slope * x, so strain is x and
# stress is y for the slope to be a modulus in MPa (the arguments used to be
# swapped, giving 1/MPa).  Each test uses its own peak index (tests 2 and 3
# used to reuse MOD1max).
win1 = slice(LINEAR_REGION_START, MOD1max)
win2 = slice(LINEAR_REGION_START, MOD2max)
win3 = slice(LINEAR_REGION_START, MOD3max)
slope, intercept, r_value, p_value, std_err = stats.linregress(strains[0][1][win1], stresses[0][1][win1])
slope1, intercept1, r_value1, p_value1, std_err1 = stats.linregress(strains[1][1][win2], stresses[1][1][win2])
slope2, intercept2, r_value2, p_value2, std_err2 = stats.linregress(strains[2][1][win3], stresses[2][1][win3])

print(str(slope) + "(MPa)")
print(str(slope1) + "(MPa)")
print(str(slope2) + "(MPa)")

plt.show()
