import sys
import numpy as np
import matplotlib.pyplot as plt
# Signal-processing functions from scipy.signal (not all of them are used).
from scipy.signal import butter, lfilter, freqz, filtfilt
import pandas as pd
# these are just for finding file paths
import os
import ntpath

# sys.argv[0] = script.py; example values previously used for sys.argv[1]:
# argv1 = 'C:\Users\User\Documents\Multis\CMULTIS002-1_UL_AC_S_0_T\oksdata.txt'
# argv1 = 'C:\Users\User\Documents\Multis\CurrentData\data.txt'


# dir = '/Users/schimmt/Multis/app/TissueTesting/Data/'


# input_filename = 'C:\Users\User\Documents\Multis\RubberSuccess\test.txt'
# Data files to process, in order.
files = [
    'data.txt',
    'data1.txt',
]

# Load level (gf) below which a ramp sample counts as unloaded.
# Adjust according to the sample type.
threshold = 50.0

# Figures shared by every file (created once, drawn into on each iteration):
#   fig2 - load and displacement versus time (two stacked axes)
#   fig3 - load v disp and stress v strain at the ramps before and after
fig2, (ax3, ax4) = plt.subplots(2, 1, figsize=(18, 9), dpi=90)
fig3, (ax5, ax6) = plt.subplots(1, 2, figsize=(18, 8), dpi=90)

filecount = 0  # number of files processed so far
# ---------------------------------------------------------------------------
# Helpers (defined once at module level instead of on every loop iteration)
# ---------------------------------------------------------------------------

def _read_lines(path):
    """Return all lines of the text file at *path*, closing it promptly."""
    with open(path, 'r') as handle:
        return handle.readlines()


def _consume_block(stream, count, sink):
    """Consume *count* non-numeric lines from *stream*, collecting numeric rows.

    Lines whose first character is a digit are appended to *sink*; every other
    line ends one of the *count* rounds. Returns the last line read, or '' once
    the stream is exhausted (instead of raising StopIteration at end of file).
    """
    line = ''
    for _ in range(count):
        line = next(stream, '')
        while line[:1].isdigit():
            sink.append(line)
            line = next(stream, '')
    return line


def _parse_sections(lines):
    """Split the raw log into per-'Move' data blocks and the 'Sin' data rows.

    A header is a line with 'Move' at characters 1-4 or 'Sin' at characters
    1-3. After a Move header the next 6 non-numeric lines are consumed, after a
    Sin header the next 8, collecting the numeric rows in between. The last
    line consumed for a Move block is re-checked for a Sin header, exactly as
    the original single-pass reader did.

    Returns (ramps, sins): one list of row strings per Move header, and a flat
    list of row strings from all Sin sections. (The original pre-sized `ramps`
    by the Move+Sin header count and then dropped the trailing slot; for a log
    with a single Sin section that equals one entry per Move header.)
    """
    ramps = []
    sins = []
    stream = iter(lines)
    for line in stream:
        if line[1:5] == 'Move':
            ramps.append([])
            line = _consume_block(stream, 6, ramps[-1])
        if line[1:4] == 'Sin':
            _consume_block(stream, 8, sins)
    return ramps, sins


def _read_velocities(lines):
    """Return characters 16-21 of every line that starts with 'Velocity'.

    As in the original reader, the line right after each 'Velocity' line is
    skipped.
    """
    velocities = []
    stream = iter(lines)
    for line in stream:
        if line[0:8] == 'Velocity':
            velocities.append(line[16:22])
            next(stream, None)
    return velocities


def _split_columns(rows, load_col=-1):
    """Parse whitespace-separated rows into (time, |disp|, |load|) lists.

    Column 0 is time, column 1 displacement and column *load_col* (the last
    one by default) load; displacement and load are stored as absolute values.
    """
    times, disps, loads = [], [], []
    for row in rows:
        values = [float(v) for v in row.split()]
        times.append(values[0])
        disps.append(abs(values[1]))
        loads.append(abs(values[load_col]))
    return times, disps, loads


def butter_lowpass(cutoff, fs, order=3):
    """Design a digital low-pass Butterworth filter; return its (b, a) coefficients."""
    nyq = 0.5 * fs
    normal_cutoff = cutoff / nyq
    b, a = butter(order, normal_cutoff, btype='low', analog=False)
    return b, a


def butter_lowpass_filter(data, cutoff, fs, order):
    """Zero-phase low-pass filter *data* (forward-backward, Gustafsson initial conditions)."""
    b, a = butter_lowpass(cutoff, fs, order=order)
    return filtfilt(b, a, data, method="gust")


def concatenate_time(segments):
    """Chain per-segment time vectors onto one continuous time axis.

    Segment 0 is kept as is; every later segment is shifted by the last time
    value of the (already shifted) segment before it. Empty segments are passed
    through and carry the previous offset forward instead of crashing.
    """
    shifted = []
    offset = 0.0
    for index, segment in enumerate(segments):
        values = np.asarray(segment, dtype=float)
        if index > 0:
            values = values + offset
        shifted.append(values)
        if values.size:
            offset = values[-1]
    return shifted


# Backward-compatible alias for the original (misspelled) helper name.
concotenateTime = concatenate_time


def _baseline_index(load, limit):
    """Return the index where a ramp is taken to start, based on *limit*.

    Of the first and the last samples whose load is below *limit*, return the
    one with the larger load (the last one on ties).

    Raises ValueError if no sample is below *limit*.
    """
    below = np.where(np.asarray(load) < limit)[0]
    if below.size == 0:
        raise ValueError('no load sample below threshold %r; adjust `threshold`' % (limit,))
    first, last = below[0], below[-1]  # np.where returns indices in ascending order
    return first if load[last] < load[first] else last


# ---------------------------------------------------------------------------
# Settings
# ---------------------------------------------------------------------------

# Directory with the data files: first command-line argument, otherwise the
# current directory. (The original used `dir`, which, with its assignment
# commented out, is the builtin function, so `dir + file` raised TypeError.)
data_dir = sys.argv[1] if len(sys.argv) > 1 else os.curdir

# Initial specimen length per file, in the order of `files`; files beyond the
# listed entries use the last value. NOTE(review): units not stated here.
init_lengths = [25.4935, 27.4065, 27.767]

# The original assigned `width` only for the first file and silently reused it
# for the others; it is now set explicitly for every file (same value).
width = 4

# Low-pass Butterworth filter: 3rd order, 20 Hz cutoff, 2.5 kHz sampling.
order = 3
fs = 2.5 * 1000  # sample rate, Hz
cutoff = 20  # cutoff frequency, Hz

# Ramp indices highlighted in the comparison plots (ramps 2 and 3 are also
# drawn); `aa` is the ramp used for the stress-strain curve.
aa = 6
bb = 7
# Position at which the sine section is spliced into the ramp sequence for the
# combined time plots. NOTE(review): assumes it ran after ramp 4 — confirm.
A = 5

# ---------------------------------------------------------------------------
# Per-file processing
# ---------------------------------------------------------------------------

for file_index, file in enumerate(files):
    input_filename = os.path.join(data_dir, file)
    init_length = init_lengths[min(file_index, len(init_lengths) - 1)]
    # NOTE(review): pi * width**2 is the area of a circle of radius `width`;
    # if width is a diameter this should be pi * (width / 2)**2 — confirm.
    area = float(np.pi * width ** 2)

    lines = _read_lines(input_filename)

    # Echo every Move / Sin header line.
    for line in lines:
        if line[1:5] == 'Move' or line[1:4] == 'Sin':
            print(line)

    ramps, sins = _parse_sections(lines)
    rampcount = len(ramps)

    vel = _read_velocities(lines)
    print('----------')
    print('Expected rate, mm/s =', vel[-2])

    # Separate time, disp and load for every ramp and for the sine section.
    r_time, r_disp, r_load = [], [], []
    for rows in ramps:
        ramp_time, ramp_disp, ramp_load = _split_columns(rows)
        r_time.append(ramp_time)
        r_disp.append(ramp_disp)
        r_load.append(ramp_load)
    sin_time, sin_disp, sin_load = _split_columns(sins)

    # Apply the low-pass filter.
    print(rampcount)
    filtered_r_load = []
    filtered_r_disp = []
    for r in range(rampcount):
        print(r)
        filtered_r_load.append(butter_lowpass_filter(r_load[r], cutoff, fs, order))
        filtered_r_disp.append(butter_lowpass_filter(r_disp[r], cutoff, fs, order))
    filtered_sin_load = butter_lowpass_filter(sin_load, cutoff, fs, order)

    time = concatenate_time(r_time)

    fig1, (ax1, ax2) = plt.subplots(nrows=2, ncols=1, figsize=(18, 9), dpi=90)
    fig4, ax7 = plt.subplots(nrows=1, ncols=1, figsize=(18, 9), dpi=90)  # load v time during preconditioning
    ax1.set_title(file)
    ax3.set_title(file)
    ax7.set_title(file)

    fig1.tight_layout(pad=4, w_pad=.5, h_pad=2)
    fig2.tight_layout(pad=4, w_pad=.5, h_pad=2)

    # fig1
    ax1.set_xlabel('time,s')
    ax1.set_ylabel('load,gf')
    ax2.set_xlabel('time,s')
    ax2.set_ylabel('disp,mm')

    # fig2
    ax3.set_xlabel('time,s')
    ax3.set_ylabel('load,gf')
    ax4.set_xlabel('time,s')
    ax4.set_ylabel('disp,mm')

    # fig3
    ax5.set_xlabel('disp,mm')
    ax5.set_ylabel('load,gf')  # was the typo 'load,fg'
    ax6.set_xlabel('strain')
    ax6.set_ylabel('stress')

    # fig4
    ax7.set_xlabel('time,s')
    ax7.set_ylabel('load,gf')

    ax7.plot(sin_time, filtered_sin_load)

    # Re-zero each filtered displacement at the ramp's start index, which is
    # chosen from where the filtered load drops below `threshold`.
    start_indices = [_baseline_index(load, threshold) for load in filtered_r_load]
    norm_f_disp = []
    for r, startindex in enumerate(start_indices):
        disp = filtered_r_disp[r]
        norm_f_disp.append(disp - disp[startindex])
        print(startindex, len(disp))

    # Strain / stress for ramp `aa` (the only ramp the original plotted;
    # 0.0098 converts gf to N).
    strain = norm_f_disp[aa] / float(init_length)
    stress = (filtered_r_load[aa] * .0098) / float(area)

    # Raw displacement re-zeroed at the first sample of ramp 2 (left panel of fig3).
    disp_origin = float(r_disp[2][0])
    norm_r_disp = [[x - disp_origin for x in disp] for disp in r_disp]

    # The first file is drawn solid, later files dashed.
    line_style = '-' if file_index == 0 else '--'
    for ramp, colour in ((2, 'r'), (3, 'orange'), (aa, 'g'), (bb, 'b')):
        ax5.plot(norm_r_disp[ramp], filtered_r_load[ramp], c=colour, ls=line_style, label=file)
    ax6.plot(strain, stress, ls=line_style, label=file)

    # From ramp 2 on: highlight ramps 2-3 and 6-7 with thick lines, then draw
    # every ramp's load and displacement.
    for i in range(2, len(time)):
        if i < 4 or 5 < i <= 7:
            ax3.plot(time[i], filtered_r_load[i], lw=4)
        ax3.plot(time[i], filtered_r_load[i])
        ax4.plot(time[i], filtered_r_disp[i])

    # Combined sequence with the sine section spliced in at position A. New
    # lists are built so r_time / filtered_r_* are not mutated in place.
    # NOTE(review): the sine load is the unfiltered signal while the ramps are
    # filtered; confirm whether filtered_sin_load was intended.
    totaltime = concatenate_time(r_time[:A] + [list(sin_time)] + r_time[A:])
    allLoads = filtered_r_load[:A] + [list(sin_load)] + filtered_r_load[A:]
    allDisp = filtered_r_disp[:A] + [list(sin_disp)] + filtered_r_disp[A:]

    leg1 = ax6.legend(loc=2)
    leg2 = ax5.legend(loc=2)
    for seg_time, seg_load in zip(totaltime, allLoads):
        ax1.plot(seg_time, seg_load)
    for seg_time, seg_disp in zip(totaltime, allDisp):
        ax2.plot(seg_time, seg_disp)
    filecount += 1

plt.show()