import tdsmParserMultis
import dicom
import numpy as np
import matplotlib.pyplot as plt
import SimpleITK as sitk
import warnings
import xml.etree.ElementTree as ET
from XMLparser import getAcceptance
import peakutils
import os
import math
import fileName
from scipy import stats

# Silence NumPy's RankWarning (issued by np.polyfit on rank-deficient fits)
# and every RuntimeWarning for the whole run of this script.
try:
    # NumPy >= 1.25 exposes the class in np.exceptions; NumPy 2.0 removed the
    # top-level np.RankWarning alias, so referencing it directly would crash.
    _RankWarning = np.exceptions.RankWarning
except AttributeError:
    # Older NumPy only provides the top-level name.
    _RankWarning = np.RankWarning

warnings.simplefilter('ignore', _RankWarning)
warnings.simplefilter('ignore', RuntimeWarning)


# Subject folders (under ../MULTIS_trials/) processed by the loop below.
ids = ['OptoTapTest']

# Tap test: for each subject, load every trial file whose name starts with
# TRIAL_PREFIX, locate three tap onsets in the probe's x-position (Xp) and in
# the x-force (Fx), print the force-minus-opto onset differences, and plot
# both signals with the detected onsets marked.
#
# The code stays valid on both Python 2 and Python 3: single-argument
# print(...) calls and '%' formatting only.

# Trial files are selected by this filename prefix.
TRIAL_PREFIX = '002'

# Trial-specific search windows, one per tap, each an open interval (lo, hi).
# Opto windows are compared against the TDMS 'Time' channel (zeroed);
# force windows are compared against Fx sample indices.
OPTO_WINDOWS = [[950, 1250], [2700, 2820], [4400, 4575]]
FORCE_WINDOWS = [[1000, 1150], [2814, 2820], [4560, 4575]]

# Minimum sample-to-sample rise counted as an onset.
OPTO_THRESH = .0003   # in the units of Xp (plotted as "OptoMotion (m)")
FORCE_THRESH = .2     # in the units of Fx (plotted as "Force (N)")


def createAverageFit(F, avgThres):
    """Return a shifted moving average of F.

    For every index i whose value differs from F[i - avgThres] (negative
    indices wrap to the end of F), the mean of F[i:i + avgThres] is
    collected. The collected means are then shifted: avgThres // 2 copies of
    the first mean are prepended and the last ceil(avgThres / 2) means are
    dropped.

    Returns an empty list if no index qualifies. The previous code raised
    IndexError in that case.

    Not used by the tap-test comparison below. The previous version of this
    script computed it for every file and discarded the result.
    """
    averages = []
    for i in range(len(F)):
        if F[i] != F[i - avgThres]:
            averages.append(np.average(F[i:i + avgThres]))
    if not averages:
        return []
    # Floor division reproduces the Python 2 integer '/' on Python 3 too
    # (a float slice bound or repeat count would raise TypeError there).
    return [averages[0]] * (avgThres // 2) + averages[0:-avgThres // 2]


def findmovement(vec, ranges, thresh, time):
    """Return the first time at which vec steps up by at least thresh.

    Scans consecutive sample pairs (k, k + 1) in order. Returns time[k] for
    the first k with ranges[0] < time[k] < ranges[1] and
    vec[k + 1] - vec[k] >= thresh. Only rises are detected; a drop of any
    size is ignored.

    Args:
        vec: the signal to scan (any sequence; converted with np.asarray).
        ranges: [lo, hi] open window, in the same units as time.
        thresh: minimum per-sample rise.
        time: timestamp for each sample of vec (e.g. range(len(vec)) to
            search by sample index).

    Returns:
        The matching time[k], or None if the window holds no such step.
        Callers must handle None.
    """
    vec = np.asarray(vec)
    lo, hi = ranges[0], ranges[1]
    # Stop one short of the shorter sequence so vec[k + 1] always exists.
    # The previous version tested time[k - 1] against the window but used
    # index k for the step, and could read past the end of vec.
    last = min(len(vec), len(time)) - 1
    for k in range(last):
        t = time[k]
        if lo < t < hi and vec[k + 1] - vec[k] >= thresh:
            return t
    return None


def _load_tap_trial(tdms_file):
    """Parse one TDMS file; return (Fx, zeroed time, zeroed probe x)."""
    data = tdsmParserMultis.parseTDMSfile(tdms_file)
    Fx = np.array(data[u'State.6-DOF Load'][u'6-DOF Load Fx'])

    time = np.asarray(data[u'Time'][u'Time'])
    time = time - time[0]

    # The group also has Fy/Fz/Mx/My/Mz. The probe also has y/z/r/p/w
    # channels. None of them are used by this analysis.
    bone = 'US Probe'
    # Divide by 1000.0, not 1000, so integer samples are not floor-divided
    # on Python 2.
    Xp = np.asarray(data[u'Sensor.' + bone][u'' + bone + '_smart_02.x']) / 1000.0
    Xp = Xp - Xp[0]
    return Fx, time, Xp


def _report_offsets(x, o):
    """Print the onsets and the mean/std of force-minus-opto differences."""
    print(o)
    print(x)
    pairs = [(xi, oi) for xi, oi in zip(x, o)
             if xi is not None and oi is not None]
    if len(pairs) < min(len(x), len(o)):
        # findmovement returns None for a window without an onset. The old
        # code crashed on None arithmetic here instead of reporting it.
        missing = [k for k, (xi, oi) in enumerate(zip(x, o))
                   if xi is None or oi is None]
        print('warning: no onset detected in window(s) %s' % (missing,))
    if not pairs:
        return
    diff = np.array([xi - oi for xi, oi in pairs], dtype=float)
    print('diff= %s' % (diff,))
    print('mean= %s' % (np.average(diff),))
    print('std= %s' % (np.std(diff),))


def _plot_tap_test(Fx, time, Xp, x, o):
    """Plot Fx by sample (top) and Xp by time (bottom), marking onsets."""
    fig, (ax1, ax2) = plt.subplots(nrows=2, ncols=1, dpi=100)
    ax1.set_title("Tap Test")
    ax2.set_ylabel("OptoMotion (m)")
    ax1.set_ylabel("Force (N)")
    ax2.set_xlabel("time (ms)")
    ax1.plot(Fx)
    ax2.plot(time, Xp)
    ax2.scatter(time, Xp, marker='o')
    for oo in o:
        if oo is not None:
            ax2.axvline(oo, linestyle='--', color='r')
    for xx in x:
        if xx is not None:
            ax1.axvline(xx, linestyle='--', color='r')
    plt.show()


for SubjectID in ids:
    print(SubjectID)
    tdmsDirectory = os.path.join("..", "MULTIS_trials", SubjectID, "Data")
    # Iterate only the '.tdms' files. The old loop built this filtered list
    # but then walked the raw listing, so any other file sharing the trial
    # prefix (e.g. a '.tdms_index' companion -- NOTE: confirm) reached the
    # parser.
    tdmsFilesAll = [name for name in sorted(os.listdir(tdmsDirectory))
                    if name.endswith('.tdms')]

    for tdms_filename in tdmsFilesAll:
        if not tdms_filename.startswith(TRIAL_PREFIX):
            continue
        print(tdms_filename)
        Fx, time, Xp = _load_tap_trial(os.path.join(tdmsDirectory, tdms_filename))

        o = [findmovement(Xp, window, OPTO_THRESH, time)
             for window in OPTO_WINDOWS]
        # NOTE(review): force onsets are located by Fx sample index, while
        # opto onsets are in 'Time' channel units, and the two are
        # subtracted in _report_offsets. That is only meaningful if one
        # sample equals one time unit -- confirm the sampling rate.
        x = [findmovement(Fx, window, FORCE_THRESH, range(len(Fx)))
             for window in FORCE_WINDOWS]

        _report_offsets(x, o)
        _plot_tap_test(Fx, time, Xp, x, o)