# program 'tissue-data_analysis'
# python 3 version
# INPUT: one tissue mechanical-testing data file, e.g. 'data.txt'
# - plus 2 numbers from the command line: length (tension) or thickness (compression), and cross-sectional area
# - processes the data into stress-relaxation and preconditioning results (with low-pass filtering), displacement vs. time, force vs. time and other calculations
# - plots the above data
# OUTPUT: two data files: combined force-displacement stress-relaxation data ./<sample name>_ramp_hold.txt; pre- and post-preconditioning load-unload vs. displacement ./<sample name>_pc_ramps.txt

#python /Users/klonowe/TissueTesting/app/TissueTesting/Python/tissue-data_analysis_version3.py /Users/klonowe/TissueTesting/app/TissueTesting/Data/Full_SR_8_23.txt 23.1 5
# Developed by Snehal Chokhandre, adapted for python 3 by Ellen Klonowski
# input needed= data file name, thickness/length and area of cross section

import sys
import numpy as np
import matplotlib.pyplot as plt
#from scipy import signal processing function. Not all used.
from scipy.signal import butter, lfilter, freqz, filtfilt
#these are just for finding file paths
import os
import ntpath



def USAGE(argv):
    """Print a one-line usage message naming the script (argv[0]).

    The complete message is built inside print(); print() itself returns
    None, so nothing may be concatenated onto its result.
    """
    print('USAGE: ' + argv[0] + ' <Insufficient arguments>')

    
def _collect_rows(infile, n_lines):
    """Read up to n_lines lines after a command header, gathering data rows.

    Each of the n_lines reads that yields a digit-led line starts a run: that
    line and every immediately following digit-led line are collected.  The
    non-digit line that ends a run is consumed and the next read continues
    after it.  End of file reads as '' (never digit-led), so a file that ends
    inside a run no longer raises StopIteration.

    Returns (rows, line): the collected raw lines and the last line read, so
    the caller can still test that line for a following command header.
    """
    rows = []
    line = ''
    for _ in range(n_lines):
        line = next(infile, '')
        while line[:1].isdigit():
            rows.append(line)
            line = next(infile, '')
    return rows, line


def _read_command_data(input_filename, n_ramps=15, n_waits=14):
    """Collect the raw data lines recorded under each test command.

    A line whose characters 1-4 are 'Move' starts a move-relative (ramp)
    command, 'Wait' starts a hold and 'Sin' (characters 1-3) a sinusoid.
    The following 6 / 4 / 8 lines respectively are scanned for data rows
    (see _collect_rows).

    Returns (ramps, waits, sins): ramps and waits hold one list of raw lines
    per command in file order (n_ramps / n_waits of them); sins is a single
    flat list with the rows of every sinusoid command.
    """
    ramps = [[] for _ in range(n_ramps)]
    waits = [[] for _ in range(n_waits)]
    sins = []
    i = -1
    f = -1
    with open(input_filename, 'r') as infile:
        for line in infile:
            # The checks run in sequence on purpose: after a Move block,
            # `line` is the last line read, which may be a Wait/Sin header.
            if line[1:5] == 'Move':
                i += 1
                rows, line = _collect_rows(infile, 6)
                if rows:
                    ramps[i].extend(rows)
            if line[1:5] == 'Wait':
                f += 1
                rows, line = _collect_rows(infile, 4)
                if rows:
                    waits[f].extend(rows)
            if line[1:4] == 'Sin':
                rows, line = _collect_rows(infile, 8)
                sins.extend(rows)
    return ramps, waits, sins


def _read_velocities(input_filename):
    """Return characters 16-21 of every line starting with 'Velocity'.

    The line immediately following each 'Velocity' line is skipped.
    """
    vel = []
    with open(input_filename, 'r') as infile:
        for line in infile:
            if line[0:8] == 'Velocity':
                vel.append(line[16:22])
                next(infile, None)
    return vel


def _split_columns(rows):
    """Parse whitespace-separated rows into (time, |disp|, |load|) float lists.

    Every field of a row must parse as a float; only the first three are kept.
    """
    times, disps, loads = [], [], []
    for row in rows:
        values = list(map(float, row.split()))
        times.append(values[0])
        disps.append(abs(values[1]))
        loads.append(abs(values[2]))
    return times, disps, loads


def _split_blocks(blocks):
    """Apply _split_columns to each block; return three parallel lists of lists."""
    times, disps, loads = [], [], []
    for rows in blocks:
        t_vals, d_vals, l_vals = _split_columns(rows)
        times.append(t_vals)
        disps.append(d_vals)
        loads.append(l_vals)
    return times, disps, loads


def _butter_lowpass(cutoff, fs, order=3):
    """Design a digital low-pass Butterworth filter; return its (b, a) coefficients.

    cutoff and fs share units (Hz here); the cutoff is normalised by the
    Nyquist frequency fs / 2, as scipy.signal.butter expects.
    """
    nyq = 0.5 * fs
    normal_cutoff = cutoff / nyq
    b, a = butter(order, normal_cutoff, btype='low', analog=False)
    return b, a


def _butter_lowpass_filter(data, cutoff, fs, order):
    """Low-pass filter *data* forwards and backwards (zero phase, Gustafsson's method)."""
    b, a = _butter_lowpass(cutoff, fs, order=order)
    return filtfilt(b, a, data, method="gust")


def _first_motion_index(disp):
    """Return the index of the last sample before *disp* first changes value.

    Raises IndexError if the displacement never changes.
    """
    idx = 0
    while disp[idx + 1] - disp[idx] == 0:
        idx += 1
    return idx


def _secant_modulus(stress, strain, i, j):
    """Slope of the straight line through points i and j of a stress-strain curve."""
    return (stress[j] - stress[i]) / (strain[j] - strain[i])


def main(argv):
    """Analyse one tissue stress-relaxation / preconditioning test file.

    argv is the full command line: [script, data file, length or thickness,
    cross-sectional area].  Displacement is divided by the length/thickness
    to give strain; load (gf, converted with 0.0098 N/gf) is divided by the
    area to give stress, reported as MPa -- which presumes the area is in
    mm^2 (confirm).

    The hard-coded indices assume 15 move-relative (ramp) commands, 14 wait
    commands and at least one sinusoid command.  Ramps 11-13 are analysed as
    the stress-relaxation ramps.  Ramp 9's first sample is the zero-strain
    (5 gf contact) reference.  Ramps 2, 3, 6 and 7 are saved and plotted as
    the preconditioning load-unload curves.

    Prints the peak ("instantaneous") and equilibrium points, applied ramp
    rates, ramp durations, stresses/strains and secant moduli.  Writes
    <folder>_ramp_hold.txt and <folder>_pc_ramps.txt to the current
    directory, where <folder> is the name of the directory containing the
    data file, then shows six figures.
    """
    if len(argv) != 4:
        USAGE(argv)
        sys.exit(1)

    input_filename = argv[1]
    length = float(argv[2])  # gauge length (tension) or thickness (compression)
    area = float(argv[3])    # cross-sectional area

    n_ramps, n_waits = 15, 14
    sr_ramps = range(11, 14)  # stress-relaxation ramps

    # --------------------
    # raw data rows of every move-relative, wait and sinusoid command
    ramps, waits, sins = _read_command_data(input_filename, n_ramps, n_waits)

    # expected rate
    vel = _read_velocities(input_filename)
    print('----------')
    print('Expected rate, mm/s =', vel[13])

    # -------------------------------------
    # separate time, |disp| and |load| of every command
    r_time, r_disp, r_load = _split_blocks(ramps)
    w_time, w_disp, w_load = _split_blocks(waits)
    sin_time, sin_disp, sin_load = _split_columns(sins)

    # zero-strain reference: first displacement of ramp 9 (5 gf contact load)
    contact_disp = r_disp[9][0]

    # -------------------------------------------
    # low-pass Butterworth filter, 3rd order, 20 Hz cutoff, zero phase
    order = 3
    fs = 2.5 * 1000  # sample rate, Hz
    cutoff = 20      # cutoff frequency of the filter, Hz

    filtered_r_load = [_butter_lowpass_filter(load, cutoff, fs, order) for load in r_load]
    filtered_sin_load = _butter_lowpass_filter(sin_load, cutoff, fs, order)
    # Sinusoid displacement relative to the start of ramp 2, minus the
    # 300 micron buffer.
    # NOTE(review): the stress-relaxation data use ramp 9 (contact_disp) as
    # reference instead -- confirm which reference is intended here.
    sin_disp = [((x - r_disp[2][0]) - 0.3) for x in sin_disp]

    # -------------------------------------------
    # peak ("instantaneous") points: first sample at each stress-relaxation
    # ramp's maximum displacement
    ramp_low_limit = []  # last sample before each ramp starts moving
    ramp_up_limit = []   # first sample at each ramp's peak displacement
    inst_disp = []
    inst_load = []
    for r in sr_ramps:
        ramp_low_limit.append(_first_motion_index(r_disp[r]))
        peak_idx = r_disp[r].index(max(r_disp[r]))
        ramp_up_limit.append(peak_idx)
        inst_disp.append(r_disp[r][peak_idx] - contact_disp)
        inst_load.append(filtered_r_load[r][peak_idx])

    # equilibrium points: last sample before each of ramps 12-14 starts
    # moving (presumably the relaxed end of the preceding hold)
    eqbm_disp = []
    eqbm_load = []
    for t in range(12, 15):
        g = _first_motion_index(r_disp[t])
        eqbm_disp.append(r_disp[t][g] - contact_disp)
        eqbm_load.append(filtered_r_load[t][g])

    print('inst_disp=', inst_disp)
    print('inst_load=', inst_load)
    print('eqbm_disp=', eqbm_disp)
    print('eqbm_load=', eqbm_load)

    # ----------------------------------------
    # applied rate: slope of a linear fit of disp vs time over each ramp's
    # moving section (from its start up to, not including, its peak)
    actual_ramp = []
    actual_time = []
    for k, r in enumerate(sr_ramps):
        lo, hi = ramp_low_limit[k], ramp_up_limit[k]
        if hi <= lo:
            raise ValueError(f'ramp {r}: peak displacement (sample {hi}) does not '
                             f'follow the ramp start (sample {lo})')
        actual_ramp.append(r_disp[r][lo:hi])
        actual_time.append(r_time[r][lo:hi])

    applied_rates = [np.polyfit(t_vals, d_vals, 1)[0]
                     for t_vals, d_vals in zip(actual_time, actual_ramp)]
    print('Applied rates at each ramp, mm/s =', applied_rates)

    # duration of each stress-relaxation ramp (the limits above belong to
    # ramps 11-13, so the times must come from the same ramps)
    time_taken = [r_time[r][ramp_up_limit[k]] - r_time[r][ramp_low_limit[k]]
                  for k, r in enumerate(sr_ramps)]
    print('Time taken for each ramp, s =', time_taken)

    # -------------------------------
    # instantaneous and equilibrium stress/strain; the leading 0 is the
    # zero-load, zero-displacement origin
    inst_strain = [0] + [c / length for c in inst_disp]
    inst_stress = [0] + [(c / area) * 0.0098 for c in inst_load]  # gf to N
    eqbm_strain = [0] + [c / length for c in eqbm_disp]
    eqbm_stress = [0] + [(c / area) * 0.0098 for c in eqbm_load]  # gf to N

    print('inst_strain=', inst_strain)
    print('inst_stress=', inst_stress)
    print('eqbm_strain=', eqbm_strain)
    print('eqbm_stress=', eqbm_stress)

    # ------------------------
    # secant moduli from the origin to each point, and between points 2 and 3
    inst_mods = [_secant_modulus(inst_stress, inst_strain, 0, j) for j in (1, 2, 3)]
    eqbm_mods = [_secant_modulus(eqbm_stress, eqbm_strain, 0, j) for j in (1, 2, 3)]
    average_inst_mod = (inst_mods[0] + inst_mods[1] + inst_mods[2]) / 3
    average_eqbm_mod = (eqbm_mods[0] + eqbm_mods[1] + eqbm_mods[2]) / 3
    inst_mod_3_highstrain = _secant_modulus(inst_stress, inst_strain, 2, 3)
    eqbm_mod_3_highstrain = _secant_modulus(eqbm_stress, eqbm_strain, 2, 3)

    print('average instantaneous modulus, MPa =', average_inst_mod)
    print('average equilibrium modulus, MPa =', average_eqbm_mod)
    print('low strain instantaneous modulus,MPa=', inst_mods[0])
    print('low strain equilibrium modulus,MPa=', eqbm_mods[0])
    print('high strain instantaneous modulus,MPa=', inst_mod_3_highstrain)
    print('high strain equilibrium modulus,MPa=', eqbm_mod_3_highstrain)

    # -------------------------------
    # Concatenate ramps 11, 12 and 13 with waits 2-5, 6-9 and 10-13
    # respectively (presumably the holds following each ramp -- confirm
    # against the protocol).  Each segment's time is shifted by the last time
    # already collected, so the combined record is continuous.
    # NOTE(review): the "filtered" series uses filtered ramp loads but raw
    # hold loads -- confirm that is intended.
    time_stress_relax = []
    disp_stress_relax = []
    load_stress_relax = []
    filtered_load_stress_relax = []
    for k, h in enumerate(sr_ramps):
        filtered_r_load[h] = filtered_r_load[h].tolist()  # array -> list of floats
        segments = [(r_time[h], r_disp[h], r_load[h], filtered_r_load[h])]
        for w in range(4 * k + 2, 4 * k + 6):
            segments.append((w_time[w], w_disp[w], w_load[w], w_load[w]))
        for seg_time, seg_disp, seg_load, seg_filtered in segments:
            offset = time_stress_relax[-1] if time_stress_relax else 0
            time_stress_relax += [x + offset for x in seg_time]
            disp_stress_relax += seg_disp
            load_stress_relax += seg_load
            filtered_load_stress_relax += seg_filtered

    allstress_stress_relax = [(c * 0.0098) / area for c in load_stress_relax]
    allstrain_stress_relax = [(c - contact_disp) / length for c in disp_stress_relax]

    # ------------------------------
    # max peak force in the filtered data (comparing it with the force at the
    # displacement peaks indicates any phase shift introduced by filtering)
    mx_filtered_load = [max(filtered_r_load[i]) for i in sr_ramps]

    print('Actual peak/instantaneous filtered loads, gf = ', mx_filtered_load)
    print('Filtered peak/instantaneous load in sync with displacement peaks,gf =', inst_load)
    print('----------')

    # combined (unfiltered) stress values, filtered as one series for plotting
    filtered_stress_stress_relax_combined = _butter_lowpass_filter(
        allstress_stress_relax, cutoff, fs, order)

    # ------------------------------
    # PLOT DATA
    disp_stress_relax_norm = [x - contact_disp for x in disp_stress_relax]

    plt.figure(1)
    ax1 = plt.subplot2grid((2, 1), (0, 0))
    ax2 = plt.subplot2grid((2, 1), (1, 0))
    ax1.plot(time_stress_relax, disp_stress_relax_norm, 'r')
    ax1.set_ylabel('disp,mm')
    ax1.set_xlabel('time,s')
    ax2.plot(time_stress_relax, load_stress_relax, '-ro')
    ax2.plot(time_stress_relax, filtered_load_stress_relax, '-go')
    ax2.set_ylabel('load,gf')
    ax2.set_xlabel('time,s')

    plt.figure(2)
    ax1 = plt.subplot2grid((2, 1), (0, 0))
    ax2 = plt.subplot2grid((2, 1), (1, 0))
    ax1.plot(disp_stress_relax_norm, load_stress_relax)
    ax1.plot(inst_disp, inst_load, 'go')
    ax1.plot(eqbm_disp, eqbm_load, 'ro')
    ax1.set_xlabel('disp,mm')
    ax1.set_ylabel('load,gf')
    ax1.plot(disp_stress_relax_norm, filtered_load_stress_relax, 'y')
    ax2.plot(allstrain_stress_relax, filtered_stress_stress_relax_combined)
    ax2.set_xlabel('strain')
    ax2.set_ylabel('stress,Mpa')

    plt.figure(3)
    plt.plot(sin_disp, filtered_sin_load)
    plt.xlabel('disp,mm')
    plt.ylabel('load,gf')

    plt.figure(4)
    plt.plot(sin_time, filtered_sin_load)
    plt.xlabel('time,s')
    plt.ylabel('load,gf')

    plt.figure(5)
    plt.plot(disp_stress_relax_norm, filtered_load_stress_relax, 'y')
    plt.xlabel('disp,mm')
    plt.ylabel('load,gf')

    # displacement of every ramp relative to the first sample of ramp 0
    start_disp = float(r_disp[0][0])
    norm_r_disp = [[x - start_disp for x in disp] for disp in r_disp]

    plt.figure(6)
    plt.plot(norm_r_disp[2], filtered_r_load[2], 'ro')
    plt.plot(norm_r_disp[2], r_load[2], 'b-o')
    plt.plot(norm_r_disp[3], filtered_r_load[3])
    plt.plot(norm_r_disp[6], filtered_r_load[6])
    plt.plot(norm_r_disp[7], filtered_r_load[7])
    plt.xlabel('disp,mm')
    plt.ylabel('load,gf')

    # -------------------------------
    # save results; file names start with the name of the folder that
    # contains the data file
    folder_name = os.path.basename(os.path.dirname(os.path.abspath(input_filename)))

    fd_data = np.array([time_stress_relax, disp_stress_relax_norm,
                        filtered_load_stress_relax, allstrain_stress_relax,
                        filtered_stress_stress_relax_combined]).T
    with open(folder_name + '_ramp_hold' + '.txt', 'w') as datafile_id:
        np.savetxt(datafile_id, fd_data)

    # one (disp, filtered load) column pair per preconditioning ramp; ramps
    # with fewer samples are padded with NaN so unequal lengths still save
    pc_columns = []
    for r in (2, 3, 6, 7):
        pc_columns += [norm_r_disp[r], filtered_r_load[r]]
    n_rows = max(len(col) for col in pc_columns)
    pc_ramps_fd_data = np.full((n_rows, len(pc_columns)), np.nan)
    for j, col in enumerate(pc_columns):
        pc_ramps_fd_data[:len(col), j] = col
    with open(folder_name + '_pc_ramps' + '.txt', 'w') as pc_ramps_datafile_id:
        np.savetxt(pc_ramps_datafile_id, pc_ramps_fd_data)

    # display/plot figures
    plt.show()

if __name__ == '__main__':
    # Script entry point: pass the full command line (script path first) to main.
    main(argv=sys.argv)
#


