import tdsmParserMultis
import dicom
import numpy as np
import matplotlib.pyplot as plt
import SimpleITK as sitk
import warnings
import xml.etree.ElementTree as ET
from XMLparser import getAcceptance
import peakutils
import os
import math
import fileName
from scipy import stats

# NumPy >= 2.0 removed RankWarning from the top-level namespace (it now lives
# in numpy.exceptions); look it up in whichever place this NumPy provides.
_RankWarning = getattr(np, 'RankWarning', None)
if _RankWarning is None:
    _RankWarning = np.exceptions.RankWarning
warnings.simplefilter('ignore', _RankWarning)
warnings.simplefilter('ignore', RuntimeWarning)

# Resolve the subject's directory layout (helper from the local fileName module).
Directory, SubjectID, tdmsDirectory, dicomDirectory = fileName.setDirectory(0)

xmlDirectory = Directory + "/TimeSynchronization/"
dicomFilesAll = sorted(os.listdir(dicomDirectory))

# Output folder for the per-trial overview figures.
if not os.path.exists(Directory + '/DataOverview'):
    os.makedirs(Directory + '/DataOverview')

def make3digits(input):
    """Left-pad a run number with zeros to three characters.

    ``input`` may be a string or anything ``str()`` accepts (e.g. an int);
    it is converted first, so ``make3digits(7)`` no longer raises TypeError.
    One- and two-character values are zero-padded to width 3; empty or
    longer values are returned unchanged (as strings).

    >>> make3digits(7)
    '007'
    >>> make3digits('42')
    '042'
    """
    text = str(input)
    if 1 <= len(text) < 3:
        return text.rjust(3, '0')
    return text

# timeSynchList collects [run prefix, dT] pairs from the synchronisation XML
# files; cnt is a running counter.
timeSynchList, cnt = [], 0

# Disabled legacy input path: read the pairs from a text file instead.
# file_TDMS = open(Directory + '/TimeSynchronization.txt', 'r')
# lines = file_TDMS.readlines()
# timeSynchList = []
# for i in lines:
#     timeSynchList.append(i.split(' '))
# file_TDMS.close()

def find_file(name, path):
    """Return the full path of the first file called *name* below *path*.

    Folders are visited top-down in ``os.walk`` order, so a match directly
    inside *path* wins over one in a subfolder. Returns ``None`` when no
    file with exactly that name exists.
    """
    matches = (os.path.join(folder, name)
               for folder, _subdirs, filenames in os.walk(path)
               if name in filenames)
    return next(matches, None)

# Collect the DICOM/TDMS time offset (Location/dT) of every run from the
# synchronisation XML files, keyed by the file's three-character run prefix.
xmlFiles = sorted(os.listdir(xmlDirectory))
# xmlFiles = [xmlFiles[17]]
for xmlName in xmlFiles:
    xmlPath = os.path.join(xmlDirectory, xmlName)
    # listdir returns names directly inside xmlDirectory, so no recursive
    # search is needed. Skip hidden files (e.g. .DS_Store) and folders, which
    # are not synchronisation files and would make ET.parse fail.
    if xmlName.startswith('.') or not os.path.isfile(xmlPath):
        continue
    location = ET.parse(xmlPath).getroot().find("Location")
    dT = float(location.find("dT").text)
    timeSynchList.append([xmlName[0:3], dT])

# Pair every DICOM file with every TDMS file and synchronisation entry that
# shares its three-character run prefix.
dicomFilesAll = sorted(os.listdir(dicomDirectory))
tdmsFiles = sorted(os.listdir(tdmsDirectory))
tdmsFilesAll = [name for name in tdmsFiles if name.endswith('.tdms')]

# Index offsets and TDMS files by run prefix (insertion order preserved) so
# each DICOM file needs only dictionary lookups instead of a full
# TDMS x run scan. The output order matches the original triple loop:
# DICOM order, then TDMS order, then synchronisation-entry order.
runsByPrefix = {}
for entry in timeSynchList:
    runsByPrefix.setdefault(entry[0], []).append(entry[1])
tdmsByPrefix = {}
for name in tdmsFilesAll:
    tdmsByPrefix.setdefault(name[0:3], []).append(name)

filesAll = []
cnt = 0
trialSynch = []
for dicomName in dicomFilesAll:
    prefix = dicomName[0:3]
    for tdmsName in tdmsByPrefix.get(prefix, []):
        for runDT in runsByPrefix.get(prefix, []):
            filesAll.append([dicomName, tdmsName, prefix, float(runDT)])
            trialSynch.append(prefix)
            cnt += 1

# Disabled sanity check: list synchronised runs that have no matching files.
# if cnt != len(timeSynchList):
#     print 'Missing files'
#     for trials in timeSynchList:
#         if trials[0] not in trialSynch:
#             print trials[0]
#     ans = raw_input("Continue?: y/n")
#     if ans == 'y':
#         donothing = 0
#     else:
#         exit()
trial_acceptance_list, Total_Accepted = getAcceptance(Directory)

# Run prefixes (first three characters) of every trial flagged accepted.
accList = [files[0:3] for files, acc in trial_acceptance_list if acc == 1]

# Print each matched trial that was not accepted, followed by how many.
cntt = 0
for d, t, trials, dt in filesAll:
    if trials in accList:
        continue
    cntt += 1
    print(trials)
print(cntt)

# import collections
# print [item for item, count in collections.Counter(timeSynchList).items() if count > 1]

# Number of matched trials, then one ". " per trial as a progress template.
print(cnt)
print(". " * cnt)

def inprogress():
    """Write one ". " progress marker to stdout without a newline.

    stdout is flushed right away. Without the flush the marker sits in the
    buffer until a newline is printed, so no progress would be visible. The
    ". " spacing matches the progress template printed before the loop.
    """
    import sys
    sys.stdout.write(". ")
    sys.stdout.flush()
def createAverageFit(F, avgThres):
    """Return a smoothed copy of *F* built from forward-looking window means.

    For every index whose value differs from the value ``avgThres`` samples
    earlier, the mean of ``F[i:i + avgThres]`` is collected. The collected
    means are then shifted right: ``avgThres // 2`` copies of the first mean
    are prepended and ``-((-avgThres) // 2)`` trailing means are dropped.

    Raises IndexError if no index qualifies (e.g. empty or constant input).
    """
    avgeragelist = []
    for items in range(len(F)):
        # For the first avgThres samples the negative index wraps to the end.
        if F[items] != F[items - avgThres]:
            avgeragelist.append(np.average(F[items:items + avgThres]))
    # // keeps the Python 2 floor-division semantics of the original slices.
    return [avgeragelist[0]] * (avgThres // 2) + avgeragelist[0:(-avgThres) // 2]


def findFrame(initial_time, dT, frameTimeVector):
    """Map a TDMS sample time to the nearest DICOM frame boundary.

    Frame boundaries are the running sums of ``frameTimeVector``. The target
    ``initial_time + dT`` is located in the first frame that ends at or after
    it, and the closer of that frame's two boundaries is chosen (ties go to
    the start boundary):

    * end boundary   -> index ``k`` (the frame's 1-based position)
    * start boundary -> index ``k - 1``

    If the target lies beyond the last frame, the start of the last frame is
    used. Returns ``(index, int(boundary - dT))``, the second value being the
    boundary mapped back onto the TDMS timeline.

    Raises ValueError if ``frameTimeVector`` is empty.
    """
    adjusted_time = dT + initial_time
    frame_frame = None
    readjusted_time_tdms = None
    frame_start = 0  # running sum: same addition order as sum(vector[0:k])
    for frames, duration in enumerate(frameTimeVector, 1):
        frame_end = frame_start + duration
        if adjusted_time <= frame_end:
            if frame_end - adjusted_time < adjusted_time - frame_start:
                return frames, int(frame_end - dT)
            return frames - 1, int(frame_start - dT)
        # Fallback if the target lies after every frame: start of this frame.
        frame_frame, readjusted_time_tdms = frames - 1, frame_start - dT
        frame_start = frame_end
    if frame_frame is None:
        raise ValueError("frameTimeVector is empty")
    return frame_frame, int(readjusted_time_tdms)


def _findIndentationStart(avg, time_max):
    """Estimate where loading starts by scanning back from the force peak.

    Windows of *avg* ending at *time_max* grow from 500 samples in steps of
    100. Once the rise across a window (last minus first value) exceeds the
    previous window's rise by at most 0.1, the point 100 samples into that
    window is returned. Returns 230 if a window would reach before sample 0.
    """
    r_sq_old = 0
    r_sq_diff = 1
    window = 500
    while r_sq_diff > .1:
        if window > time_max:
            return 230
        x = np.arange(time_max - window, time_max, 1)
        y = np.array(avg[time_max - window:time_max])
        rise = y[-1] - y[0]
        r_sq_diff = rise - r_sq_old
        r_sq_old = rise
        time_start = x[100]
        window += 100
    return time_start


# Filename tag -> bone; checked in this order, first match wins.
_BONE_TAGS = (('_UL_', 'Femur'), ('_LL_', 'Tibia'),
              ('_UA_', 'Humerus'), ('_LA_', 'Radius'))


def _boneFromFilename(tdms_filename):
    """Return the bone named by the filename's segment tag, or None."""
    for tag, boneName in _BONE_TAGS:
        if tag in tdms_filename:
            return boneName
    return None


def _loadProbePose(data, bone):
    """Return (X, Y, Z, roll, pitch, yaw) probe arrays for *bone*, or None.

    Positions are divided by 1000 and angles converted with np.radians. None
    is returned when the bone is unknown, a channel is missing (KeyError) or
    a channel is empty, so the caller can skip the optotrak plots.
    NOTE(review): assumes the parsed TDMS object raises KeyError for missing
    groups/channels like a dict does -- confirm against tdsmParserMultis.
    """
    if bone is None:
        return None
    group = u'State.Probe-' + bone + u' Position'
    prefix = u'Probe-' + bone + u' Position '
    try:
        channels = data[group]
        X = np.asarray(channels[prefix + u'x'], dtype=float) / 1000.0
        Y = np.asarray(channels[prefix + u'y'], dtype=float) / 1000.0
        Z = np.asarray(channels[prefix + u'z'], dtype=float) / 1000.0
        r = np.radians(np.asarray(channels[prefix + u'roll'], dtype=float))
        p = np.radians(np.asarray(channels[prefix + u'pitch'], dtype=float))
        w = np.radians(np.asarray(channels[prefix + u'yaw'], dtype=float))
    except KeyError:
        return None
    pose = (X, Y, Z, r, p, w)
    if not all(len(channel) for channel in pose):
        return None
    return pose


count = 0
completeList = []
for dicom_filename, tdms_filename, runNum, dT in filesAll:

    # The 8th-from-last character of the TDMS name encodes the trial type.
    trialType = tdms_filename[-8:-7]
    indentation = trialType == 'I'
    anatomical = trialType == 'A'
    if not (indentation or anatomical):
        # Previously the flags of the preceding trial were silently reused.
        print("Unknown trial type, skipping: " + tdms_filename)
        continue

    dicom_files = dicomDirectory + dicom_filename
    tdms_files = tdmsDirectory + tdms_filename
    dataset = dicom.read_file(dicom_files, stop_before_pixels=True)
    frameTimeVector = np.array(dataset.FrameTimeVector)
    totalFrames = len(frameTimeVector)

    data = tdsmParserMultis.parseTDMSfile(tdms_files)
    load6 = data[u'State.6-DOF Load']
    loadCell = data[u'Sensor.Load Cell']
    Fx = np.array(load6[u'6-DOF Load Fx'])
    Fy = np.array(load6[u'6-DOF Load Fy'])
    Fz = np.array(load6[u'6-DOF Load Fz'])
    Fxx = np.array(loadCell[u'Load Cell_Fx'])  # Fx and Fz reversed
    Fyy = np.array(loadCell[u'Load Cell_Fy'])
    Fzz = np.array(loadCell[u'Load Cell_Fz'])
    Mx = np.array(load6[u'6-DOF Load Mx'])
    My = np.array(load6[u'6-DOF Load My'])
    Mz = np.array(load6[u'6-DOF Load Mz'])
    pulse = np.array(data[u'Sensor.Run Number Pulse Train'][u'Run Number Pulse Train'])
    time = data[u'Time'][u'Time']

    bone = _boneFromFilename(tdms_filename)

    # Magnitude of the 6-DOF force vector, one value per sample.
    F_mag = [math.sqrt(fx ** 2 + fy ** 2 + fz ** 2) for fx, fy, fz in zip(Fx, Fy, Fz)]

    Peaks = peakutils.indexes(pulse, thres=0.5 * max(pulse), min_dist=100)
    if len(Peaks) == 0:
        print("No run-number pulses found, skipping: " + tdms_filename)
        continue

    # Indentation start on the TDMS timeline (suffix _tdms = TDMS samples).
    if indentation:
        avg = createAverageFit(F_mag, 300)
        time_start_tdms = _findIndentationStart(avg, avg.index(max(avg)))
    else:
        time_start_tdms = 230

    force_list = list(F_mag)

    time_preIndent_tdms = Peaks[0]  # 230 ms is location first pulse.
    time_postIndent_tdms = Peaks[-1]
    time_max_tdms = force_list.index(max(force_list))   # note max and min reversed as Fx is negative force

    time_min_tdms = force_list.index(min(force_list))
    if time_min_tdms < time_preIndent_tdms:
        # Search only after sample 230; offset the tail index back so an
        # equal value earlier in the signal cannot be returned instead.
        tail = force_list[230:]
        time_min_tdms = 230 + tail.index(min(tail))

    # // keeps the integer midpoint the original computed under Python 2.
    time_middle_tdms = time_start_tdms + (time_max_tdms - time_start_tdms) // 2

    pre_frame, pre_frame_time_tdms = findFrame(time_preIndent_tdms, dT, frameTimeVector)
    post_frame, post_frame_time_tdms = findFrame(time_postIndent_tdms, dT, frameTimeVector)
    start_frame, start_frame_time_tdms = findFrame(time_start_tdms, dT, frameTimeVector)
    middle_frame, middle_frame_time_tdms = findFrame(time_middle_tdms, dT, frameTimeVector)
    max_frame, max_frame_time_tdms = findFrame(time_max_tdms, dT, frameTimeVector)
    min_frame, min_frame_time_tdms = findFrame(time_min_tdms, dT, frameTimeVector)

    try:
        force_start_tdms = force_list[start_frame_time_tdms]
        force_preIndent_tdms = force_list[pre_frame_time_tdms]
        force_postIndent_tdms = force_list[post_frame_time_tdms]
        force_middle_tdms = force_list[middle_frame_time_tdms]
        force_max_tdms = force_list[max_frame_time_tdms]
        force_min_tdms = force_list[min_frame_time_tdms]
    except IndexError:
        print("")
        print("Index Error")
        print(dicom_filename)
        print(tdms_filename)
        print(min_frame_time_tdms)
        # Skip the trial: plotting would use undefined or previous-trial values.
        continue

    reader = sitk.ImageFileReader()
    reader.SetFileName(dicom_files)
    img = reader.Execute()
    image1 = sitk.GetArrayFromImage(img[:, :, pre_frame])
    if indentation:
        image2 = sitk.GetArrayFromImage(img[:, :, start_frame])
        image3 = sitk.GetArrayFromImage(img[:, :, middle_frame])
        image4 = sitk.GetArrayFromImage(img[:, :, max_frame])
    else:
        image2 = sitk.GetArrayFromImage(img[:, :, min_frame])
        image3 = sitk.GetArrayFromImage(img[:, :, post_frame])

    # Build every axis on a fresh figure with subplot2grid. plt.subplots is
    # not used because its axes would be left overlapping on newer matplotlib.
    fig = plt.figure(figsize=(18, 18), dpi=200)
    ax2 = plt.subplot2grid((21, 4), (6, 0), rowspan=5, colspan=4)
    # Rows 11-13, full width: the layout legacy matplotlib derived from the
    # former out-of-range loc (10, 4), which current versions reject.
    ax3 = plt.subplot2grid((21, 4), (11, 0), rowspan=3, colspan=4)
    axOptoPos = plt.subplot2grid((21, 4), (14, 0), rowspan=3, colspan=4)
    axOptoAng = plt.subplot2grid((21, 4), (17, 0), rowspan=3, colspan=4)
    axPulse = plt.subplot2grid((21, 4), (20, 0), rowspan=1, colspan=4)

    ax14 = None
    if indentation:
        ax11 = plt.subplot2grid((21, 4), (0, 0), rowspan=5, colspan=1)
        ax12 = plt.subplot2grid((21, 4), (0, 1), rowspan=5, colspan=1)
        ax13 = plt.subplot2grid((21, 4), (0, 2), rowspan=5, colspan=1)
        ax14 = plt.subplot2grid((21, 4), (0, 3), rowspan=5, colspan=1)

        ax14.imshow(image4[:, :, 0], interpolation=None, cmap='gray')

        ax2.scatter(pre_frame_time_tdms, force_preIndent_tdms, s=25, color='black', marker='o')
        ax2.scatter(max_frame_time_tdms, force_max_tdms, s=25, color='black', marker='o')
        ax2.scatter(middle_frame_time_tdms, force_middle_tdms, s=25, color='black', marker='o')
        ax2.scatter(start_frame_time_tdms, force_start_tdms, s=25, color='black', marker='o')
    else:
        ax11 = plt.subplot2grid((21, 8), (0, 0), rowspan=5, colspan=2)
        ax12 = plt.subplot2grid((21, 8), (0, 3), rowspan=5, colspan=2)
        ax13 = plt.subplot2grid((21, 8), (0, 6), rowspan=5, colspan=2)
        ax2.scatter(pre_frame_time_tdms, force_preIndent_tdms, s=25, color='black', marker='o')
        ax2.scatter(min_frame_time_tdms, force_min_tdms, s=25, color='black', marker='o')
        ax2.scatter(post_frame_time_tdms, force_postIndent_tdms, s=25, color='black', marker='o')

    ax11.imshow(image1[:, :, 0], interpolation=None, cmap='gray')
    ax12.imshow(image2[:, :, 0], interpolation=None, cmap='gray')
    ax13.imshow(image3[:, :, 0], interpolation=None, cmap='gray')

    ax2.plot(Fx, label="$Fx-6DOF$", color='blue')
    ax2.plot(Fy, label="$Fy-6DOF$", color='red')
    ax2.plot(Fz, label="$Fz-6DOF$", color='green')
    ax2.plot(Fxx, label="$Fx$", color='blue', linestyle='--', alpha=.5)
    ax2.plot(Fyy, label="$Fy$", color='red', linestyle='--', alpha=.5)
    ax2.plot(Fzz, label="$Fz$", color='green', linestyle='--', alpha=.5)
    ax3.plot(Mx, label="$Mx-6DOF$", color='blue')
    ax3.plot(My, label="$My-6DOF$", color='red')
    ax3.plot(Mz, label="$Mz-6DOF$", color='green')
    axPulse.plot(pulse, color='black', label='$Pulse$')

    # Probe pose relative to the first sample. The variables were previously
    # used while their definitions were commented out (NameError).
    pose = _loadProbePose(data, bone)
    if pose is not None:
        X, Y, Z, r, p, w = pose
        X = X - X[0]
        Y = Y - Y[0]
        Z = Z - Z[0]
        axOptoPos.plot(time, X * 1000, label="$x$")
        axOptoPos.plot(time, Y * 1000, label="$y$")
        axOptoPos.plot(time, Z * 1000, label="$z$")

        r = r - r[0]
        p = p - p[0]
        w = w - w[0]
        axOptoAng.plot(time, r, label="$x$")
        axOptoAng.plot(time, p, label="$y$")
        axOptoAng.plot(time, w, label="$z$")

        axOptoPos.legend(loc='upper right', prop={'size': 10}, borderpad=0.2, handlelength=3)
        axOptoAng.legend(loc='upper right', prop={'size': 10}, borderpad=0.2, handlelength=3)
    axOptoPos.set_ylabel('Tool Position')
    axOptoAng.set_ylabel('Tool Orientation')

    if anatomical or force_max_tdms < 22:
        ax2.set_ylim(-10, 22)
    else:
        ax2.set_ylim(-10, round(force_max_tdms) + 5)

    ax3.set_ylim(-.14, .14)
    ax2.set_xlim(0, 8000)
    ax3.set_xlim(0, 8000)
    axPulse.set_xlim(0, 8000)
    axPulse.set_ylim(-.1, .55)

    imageAxes = [ax11, ax12, ax13] + ([ax14] if ax14 is not None else [])
    for ax in imageAxes:
        ax.set_xticklabels([])
        ax.set_yticklabels([])
        ax.grid(color='red', ls='solid', lw=.5, alpha=0.5)

    # Booleans rather than 'off'/'true' strings: newer matplotlib treats any
    # non-empty string as True.
    for ax in (ax2, ax3):
        ax.tick_params(axis='both', which='both', bottom=False, top=False,
                       gridOn=True, labelbottom=False)
    axPulse.tick_params(axis='y', which='both', bottom=False, top=False,
                        left=False, right=False, labelleft=False)

    leg1 = ax2.legend(loc='upper right', prop={'size': 10}, borderpad=0.2, handlelength=3)
    ax3.legend(loc='upper right', prop={'size': 10}, borderpad=0.2, handlelength=3)
    leg = axPulse.legend(loc='upper right', prop={'size': 10}, borderpad=0.2, handlelength=3)
    leg1.get_frame().set_alpha(0.8)
    leg.get_frame().set_alpha(0.8)

    if indentation:
        ax11.set_title("Pre-Indentation: Frame " + str(pre_frame) + "/" + str(totalFrames))
        ax12.set_title("Start-Indentation: Frame " + str(start_frame) + "/" + str(totalFrames))
        ax13.set_title("Middle-Indentation: Frame " + str(middle_frame) + "/" + str(totalFrames))
        ax14.set_title("Max-Indentation: Frame " + str(max_frame) + "/" + str(totalFrames))
    else:
        ax11.set_title("First Pulse: Frame " + str(pre_frame) + "/" + str(totalFrames))
        ax12.set_title("At Minimum Force: Frame " + str(min_frame) + "/" + str(totalFrames))
        ax13.set_title("Last Pulse: Frame " + str(post_frame) + "/" + str(totalFrames))
    ax2.set_title("Load Cell Data ")
    ax2.set_ylabel('Force (N)')
    ax3.set_ylabel('Moment (Nm)')
    axPulse.set_xlabel("time(ms)")

    fig.savefig(Directory + '/DataOverview/' + tdms_filename[0:-5] + '_analysis' + '.png')
    # Release the figure; otherwise every trial's figure stays in memory.
    plt.close(fig)
    completeList.append(tdms_filename[0:3])
    count += 1
    inprogress()


# Overview figures are saved to <Directory>/DataOverview/ by the processing
# loop, so that folder (not Directory itself) is the one to inventory.
pngList = sorted(os.listdir(Directory + '/DataOverview'))

def findMissingFiles(trial_acceptance_list, pngList):
    """Print the accepted trials that have no overview PNG.

    Parameters
    ----------
    trial_acceptance_list : iterable of (name, acc) pairs
        A trial counts as accepted when ``acc == 1``; it is identified by the
        first three characters of ``name``.
    pngList : iterable of str
        File names to search. Only names ending in ``.png`` count, so other
        files sharing a run prefix (e.g. TDMS data) are not mistaken for a
        rendered figure. Each is identified by its first three characters.

    Prints the number of accepted trials, then each accepted run prefix that
    has no matching PNG. Returns None.
    """
    # make list of accepted trials (3 digits)
    accList = [files[0:3] for files, acc in trial_acceptance_list if acc == 1]
    print("{} accepted trials".format(len(accList)))

    # set of run prefixes that have a PNG (O(1) membership tests)
    png = {name[0:3] for name in pngList if name.endswith('.png')}

    # find missing png's
    print("Missing: ")
    for files in accList:
        if files not in png:
            print("      " + files)

# Uncomment to list accepted trials that have no overview figure:
# findMissingFiles(trial_acceptance_list, pngList)

# Terminate the progress-dot line, then announce completion.
print("")
print("Complete")