#!/usr/bin/env python

"""

Version 2.0.2

Description:
This script was written to manually extract tissue thicknesses for the MULTIS in vivo and in vitro experimentation.
Anatomical and indentation trials are analyzed differently.
    Anatomical - Frames between the first and last pulse, ordered from minimum to maximum force magnitude; the
    frames with the minimum, middle, and maximum force are analyzed
    Indentation - Frames from the start of indentation to the peak force of the indentation

Getting started:
The user may either hard code the path to the subjectID folder or choose to open a file dialog to browse for the
folder.
The accepted trials are then loaded into a listbox.
Select the trial you would like to analyze by 'double-clicking' the name.
The first frame will appear for analysis.
Move the four red dots to the appropriate locations
    Superficial skin
    Skin/Fat boundary
    Fat/Muscle boundary
    Muscle/Bone boundary
To save results, press 'Next Image' (Forces, moments, and thicknesses for that frame are saved to an xml
file) and the next frame will appear for analysis.

When you are finished analyzing that image, you can close the program in one of two ways:
    'Done' button in the lower right corner will pretty print the newly created xml file and open the list of images
    for that subject, so that you can double-click on the next image you want to analyze
    The 'X' in the upper right corner will close all windows (without pretty printing the xml file), a warning box
    will appear to make sure this is what you want to do.

    Original Author:
        Erica Morrill
        Department of Biomedical Engineering
        Lerner Research Institute
        Cleveland Clinic
        Cleveland, OH
        morrile2@ccf.org

"""

import matplotlib

matplotlib.use("TkAgg")
from matplotlib.backends.backend_tkagg import FigureCanvasTkAgg, NavigationToolbar2Tk
import matplotlib.pyplot as plt
from matplotlib import patches
import os
import dicom
from SimpleITK import ImageFileReader, GetArrayFromImage
import tkFileDialog
from Tkinter import *
import Tkinter as tk
import tkMessageBox
import tdsmParserMultis
import numpy as np
import peakutils
import xml.etree.ElementTree as ET
from lxml import etree
import XMLparser
import math
import Plot_thicknessForce
import getpass
import time
from scipy import signal
import ConfigParser

def get_transformation_matrix(q1, q2, q3, q4, q5, q6):
    '''Build the 4x4 homogeneous transform for an Optotrak sensor pose.

    The translation column is (q1, q2, q3). The rotation block is R = Rz(q6) * Ry(q5) * Rx(q4), so q4, q5 and q6
    are roll, pitch and yaw angles in radians. transform_sphere_points applies the inverse of this matrix to map
    global points into the sensor's frame.

    Returns:
        4x4 numpy float array [[R, t], [0, 0, 0, 1]].
    '''
    roll_c, roll_s = math.cos(q4), math.sin(q4)
    pitch_c, pitch_s = math.cos(q5), math.sin(q5)
    yaw_c, yaw_s = math.cos(q6), math.sin(q6)

    T = np.zeros((4, 4))

    # Row-by-row expansion of Rz(yaw) * Ry(pitch) * Rx(roll), followed by the translation entry.
    T[0, 0:3] = [yaw_c * pitch_c,
                 yaw_c * pitch_s * roll_s - yaw_s * roll_c,
                 yaw_c * pitch_s * roll_c + yaw_s * roll_s]
    T[1, 0:3] = [yaw_s * pitch_c,
                 yaw_s * pitch_s * roll_s + yaw_c * roll_c,
                 yaw_s * pitch_s * roll_c - yaw_c * roll_s]
    T[2, 0:3] = [-pitch_s,
                 pitch_c * roll_s,
                 pitch_c * roll_c]

    T[0:3, 3] = [q1, q2, q3]
    T[3, 3] = 1

    return T

def convertTOarray(data):
    """Parse a space-separated string of numbers into a 1-D float array.

    The first two space-separated tokens are skipped (presumably a header such as a count -- TODO confirm against
    the data source), and the last character of the final token is dropped (presumably a closing double quote --
    confirm).

    Returns:
        numpy float array of the remaining values; empty if there are no values after the first two tokens.
    """
    values = data.split(" ")[2:]
    if not values:
        return np.asarray([], dtype=float)
    values[-1] = values[-1][:-1]
    # A list (not a lazy map object) so np.asarray builds a numeric array on both Python 2 and 3.
    return np.asarray([float(v) for v in values])

def get_distance(A, B):
    """Return the Euclidean distance between points A and B using only their first three coordinates."""
    dx = A[0] - B[0]
    dy = A[1] - B[1]
    dz = A[2] - B[2]
    return math.sqrt(dx ** 2 + dy ** 2 + dz ** 2)

def get_xyzrpw(T):
    """Decompose a 4x4 homogeneous transform into its translation and roll/pitch/yaw angles.

    The rotation block is read as R = Rz(yaw) * Ry(pitch) * Rx(roll), the same convention as
    get_transformation_matrix, and the angles are recovered for pitch in (-pi/2, pi/2).

    Returns:
        (x, y, z, roll, pitch, yaw): translation in the units of T, angles in radians.
    """
    translation = (T[0, 3], T[1, 3], T[2, 3])
    yaw = math.atan2(T[1, 0], T[0, 0])
    pitch = math.atan2(-T[2, 0], math.sqrt(T[2, 1] ** 2 + T[2, 2] ** 2))
    roll = math.atan2(T[2, 1], T[2, 2])
    return translation + (roll, pitch, yaw)

def fit_hypersphere(data, method="Hyper"):
    """
            FitHypersphere.py

            fit_hypersphere(collection of tuples or lists of real numbers)
            will return a hypersphere of the same dimension as the tuples:
                    (radius, (center))

            using the Hyper (hyperaccurate) algorithm of
            Ali Al-Sharadqah and Nikolai Chernov
            Error analysis for circle fitting algorithms
            Electronic Journal of Statistics
            Vol. 3 (2009) 886-911
            DOI: 10.1214/09-EJS419

            generalized to n dimensions

            Mon Apr 23 04:08:05 PDT 2012 Kevin Karplus

            Note: this version using SVD works with Hyper, Pratt, and Taubin methods.
            If you are not familiar with them, Hyper is probably your best choice.


            Creative Commons Attribution-ShareAlike 3.0 Unported License.
            http://creativecommons.org/licenses/by-sa/3.0/

    Methods available for fitting are "algebraic" fitting methods:
        Hyper   Al-Sharadqah and Chernov's Hyperfit algorithm
        Pratt   Vaughn Pratt's algorithm
        Taubin  G. Taubin's algorithm
    Kasa's algorithm, though very similar, is not implemented: its constraint matrix N would be singular, so the
    N_inv computation is not doable.

    Args:
        data: sequence of points, each a tuple/list (or array row) of `dimen` real numbers.
        method: "Hyper" (default), "Pratt" or "Taubin".

    Returns:
        (radius, center): center is a 1-D numpy array of length dimen. With no points returns (0, None); with a
        single point returns (0, data[0]).

    Raises:
        ValueError: if `method` is unknown, or fewer than dimen + 1 points are given.
    """
    import sys

    if method not in ("Hyper", "Pratt", "Taubin"):
        raise ValueError("Error: unknown method: {} should be 'Hyper', 'Pratt', or 'Taubin'".format(method))

    num_points = len(data)
    if num_points == 0:
        return (0, None)
    if num_points == 1:
        return (0, data[0])
    dimen = len(data[0])  # dimensionality of hypersphere

    if num_points < dimen + 1:
        raise ValueError(
            "Error: fit_hypersphere needs at least {} points to fit {}-dimensional sphere, but only given {}".format(
                dimen + 1, dimen, num_points))

    # central: the data with the centroid subtracted (num_points x dimen matrix)
    central = np.matrix(data, dtype=float)
    centroid = np.mean(central, axis=0)
    central = central - centroid

    # squared magnitude for each centered point, as a column vector
    square_mag = np.matrix([sum(a * a for a in row.flat) for row in central]).transpose()

    if method == "Taubin":
        # matrix of normalized squared magnitudes, data
        mean_square = square_mag.mean()
        data_Z = np.bmat([[(square_mag - mean_square) / (2 * math.sqrt(mean_square)), central]])
        u, s, v = np.linalg.svd(data_Z, full_matrices=False)
        param_vect = v[-1, :]
        params = [x for x in np.asarray(param_vect)[0]]  # convert from (dimen+1) x 1 matrix to list
        params[0] /= 2 * math.sqrt(mean_square)
        params.append(-mean_square * params[0])
        params = np.array(params)

    else:
        # matrix of squared magnitudes, data, 1s
        data_Z = np.bmat([[square_mag, central, np.ones((num_points, 1))]])

        # Note: numpy's linalg.svd returns data_Z = u * s * v (not u * s * v.H as the Release 1.4.1 documentation
        # claimed; newer documentation is correct).
        u, s, v = np.linalg.svd(data_Z, full_matrices=False)

        if s[-1] / s[0] < 1e-12:
            # Singular case: data_Z is numerically rank-deficient (e.g. the points lie exactly on a sphere), so the
            # last right-singular vector is the solution. This takes the last ROW of v, while Chernov uses the last
            # COLUMN, because of the difference in SVD definition between MATLAB and numpy.
            param_vect = v[-1, :]
        else:
            Y = v.H * np.diag(s) * v
            Y_inv = v.H * np.diag([1. / x for x in s]) * v
            # Ninv is the inverse of the constraint matrix, after centroid has been removed
            Ninv = np.asmatrix(np.identity(dimen + 2, dtype=float))
            if method == "Hyper":
                Ninv[0, 0] = 0
                Ninv[0, -1] = 0.5
                Ninv[-1, 0] = 0.5
                Ninv[-1, -1] = -2 * square_mag.mean()
            else:  # "Pratt" (the method was validated at the top)
                Ninv[0, 0] = 0
                Ninv[0, -1] = -0.5
                Ninv[-1, 0] = -0.5
                Ninv[-1, -1] = 0

            # get the eigenvector for the smallest positive eigenvalue
            matrix_for_eigen = Y * Ninv * Y
            eigen_vals, eigen_vects = np.linalg.eigh(matrix_for_eigen)

            positives = [x for x in eigen_vals if x > 0]
            if len(positives) + 1 != len(eigen_vals):
                sys.stderr.write("Warning: for method {} exactly one eigenvalue should be negative: {}\n".format(
                    method, eigen_vals))
            smallest_positive = min(positives)
            # chosen eigenvector as a (dimen+2) x 1 column matrix
            A_colvect = eigen_vects[:, list(eigen_vals).index(smallest_positive)]
            # now have to multiply by Y inverse
            param_vect = (Y_inv * A_colvect).transpose()

        # Both branches leave param_vect as a 1 x (dimen+2) matrix; flatten it to the array [A, B_1..B_dimen, C]
        # of the algebraic sphere A*|x|^2 + B.x + C = 0 (in centered coordinates).
        params = np.asarray(param_vect)[0]

    radius = 0.5 * math.sqrt(sum(a * a for a in params[1:-1]) - 4 * params[0] * params[-1]) / abs(params[0])
    center = -0.5 * params[1:-1] / params[0]
    center += np.asarray(centroid)[0]  # undo the centering
    return (radius, center)

def transform_sphere_points(B1m, B1mr):
    ''' Function to extract the centers of each registration marker in the respective bone Optotrak sensor coordinate
    system.

    Args:
        B1m: string of 30 digitized points (x, y, z each, in m). After removing double quotes and splitting on single
            spaces, the first two tokens are skipped and the next 90 are parsed.
        B1mr: string of the 30 matching sensor poses (x, y, z in m; roll, pitch, yaw in rad). The first two tokens
            are skipped and the next 180 are parsed.

    Returns:
        3x3 numpy array whose rows are the fitted sphere centers (Pratt fit) of points 0-9, 10-19 and 20-29, each in
        the sensor frame. A row is NaN if its sphere fit fails (e.g. the points were not digitized).
    '''

    def _parse_floats(text, count):
        """Strip double quotes, split on single spaces, skip two leading tokens and parse the next `count`."""
        tokens = text.replace('"', '').split(" ")
        return np.asarray([float(tok) for tok in tokens[2:2 + count]])

    points = _parse_floats(B1m, 90).reshape(30, -1)  # 30 x (x, y, z)
    poses = _parse_floats(B1mr, 180).reshape(30, -1)  # 30 x (x, y, z, roll, pitch, yaw)

    # Transform every point into the Optotrak sensor frame recorded with it.
    P1 = np.ones((4, 1))  # homogeneous point (x, y, z, 1)
    Coord1 = np.zeros((30, 3))
    for i in range(30):
        T1 = get_transformation_matrix(*poses[i, 0:6])
        P1[0:3, 0] = points[i, 0:3]
        A = np.dot(np.linalg.inv(T1), P1)  # Transform to bone Optotrak sensor coordinate system
        Coord1[i, :] = A[0:3, 0]

    # Sphere fit for each marker's 10 collected points (m). Set to NAN if points were not digitized, i.e. the fit
    # fails. Exception (not a bare except) so Ctrl-C / SystemExit still propagate.
    NAN = float('nan')
    centers = []
    for start in (0, 10, 20):
        try:
            _radius, center = fit_hypersphere(Coord1[start:start + 10, 0:3], method="Pratt")
        except Exception:
            center = (NAN, NAN, NAN)
        centers.append((center[0], center[1], center[2]))

    return np.array(centers)

class FileSelectionApp(tk.Tk):
    """Window listing the subject's trials (one entry per TDMS file); double-click a trial to analyze it."""

    def __init__(self):
        tk.Tk.__init__(self)
        self.main()

    def main(self):
        """Create the analysis App, load the subject's files and fill the trial listbox.

        If anything fails (for example, the chosen directory is not laid out as expected), an error dialog is shown
        and both windows are torn down.
        """
        main_app = None  # stays None if App(None) itself fails, so the cleanup below never touches a missing attr
        try:
            main_app = App(None)
            self.mainApp = main_app
            main_app.getFiles()

            # The analysis window re-shows this list through ListBox when the user must pick another trial.
            main_app.ListBox = self

            self.columnconfigure(0, weight=1)
            self.rowconfigure(0, weight=1)

            self.title('Select Location for Analysis')

            # Explicit master: there are several Tk roots in this program, so do not rely on the default root.
            self.lb = tk.Listbox(self)
            self.sb = tk.Scrollbar(self, orient=VERTICAL, command=self.lb.yview)
            self.lb.config(yscrollcommand=self.sb.set)

            self.lb.grid(column=0, row=0, sticky=(N, W, E, S))
            self.sb.grid(column=1, row=0, sticky=(N, S))
            self.lb.config(width=40, height=30)

            for item in main_app.lstFilesTDMS:
                # Show the file name without its ".tdms" extension
                self.lb.insert("end", os.path.split(item)[1][0:-5])

            self.lb.bind("<Double-Button-1>", self.OnDouble)
            self.protocol("WM_DELETE_WINDOW", self.on_exit)

        except Exception:
            tkMessageBox.showerror("Error",
                                   "Directory is not setup in the expected format, check the wiki page for correct format")
            self.destroy()
            if main_app is not None:
                main_app.destroy()
            self.quit()
            if main_app is not None:
                main_app.quit()

    def OnDouble(self, event):
        """When a location name is specified using a double-click, the execute function begins analysis.

        The list window is hidden while the analysis window is shown; double-clicks that select nothing are ignored.
        """
        selection = event.widget.curselection()
        if not selection:
            # Double-click on an empty part of the list: nothing to analyze.
            return
        index = int(selection[0])
        self.mainApp.main()
        self.withdraw()
        self.mainApp.execute(index)

    def yview(self, *args):
        """Scroll the trial listbox (forwarded to the listbox; previously this recursed into itself forever)."""
        return self.lb.yview(*args)

    def on_exit(self):
        """Window-close handler: stop the Tk main loop."""
        self.quit()


class DraggablePatch:
    """Make a matplotlib patch draggable along the y axis only (its x position never changes).

    Modified slightly from the draggable-rectangle example in the matplotlib Advanced User's Guide
    (http://matplotlib.org/users/event_handling.html). It uses the animation blit technique
    (http://www.scipy.org/Cookbook/Matplotlib/Animations), so while dragging only the moved patch is redrawn.
    The wrapped object must provide ``center``, ``contains``, ``set_animated``, ``axes`` and ``figure``.
    """

    # The patch currently being dragged, shared by all instances so only one can be animated at a time.
    lock = None

    def __init__(self, obj):
        self.obj = obj
        self.press = None  # (center_x, center_y, press_x, press_y) while the button is held on this patch
        self.background = None  # snapshot of the axes without this patch, used while dragging

    def connect(self):
        """Subscribe to button press/release and mouse-motion events on the patch's canvas."""
        mpl_connect = self.obj.figure.canvas.mpl_connect
        self.cidpress = mpl_connect('button_press_event', self.on_press)
        self.cidrelease = mpl_connect('button_release_event', self.on_release)
        self.cidmotion = mpl_connect('motion_notify_event', self.on_motion)

    def _blit_patch(self, canvas):
        """Redraw just this patch and blit the axes region to the screen."""
        axes = self.obj.axes
        axes.draw_artist(self.obj)
        canvas.blit(axes.bbox)

    def on_press(self, event):
        """Start a drag when the press hits this patch and no other patch is being dragged."""
        if event.inaxes != self.obj.axes or DraggablePatch.lock is not None:
            return
        hit, _details = self.obj.contains(event)
        if not hit:
            return

        center_x, center_y = self.obj.center
        self.press = center_x, center_y, event.xdata, event.ydata
        DraggablePatch.lock = self

        # Render everything except this (now animated) patch, keep that pixel buffer, then draw the patch on top.
        canvas = self.obj.figure.canvas
        self.obj.set_animated(True)
        canvas.draw()
        self.background = canvas.copy_from_bbox(self.obj.axes.bbox)
        self._blit_patch(canvas)

    def on_motion(self, event):
        """While this patch is being dragged inside its axes, follow the cursor vertically."""
        if DraggablePatch.lock is not self or event.inaxes != self.obj.axes:
            return

        self.obj.center = (self.obj.center[0], event.ydata)

        canvas = self.obj.figure.canvas
        canvas.restore_region(self.background)
        self._blit_patch(canvas)

    def on_release(self, event):
        """End a drag: release the lock, stop animating and redraw the whole figure."""
        if DraggablePatch.lock is not self:
            return

        self.press = None
        DraggablePatch.lock = None

        self.obj.set_animated(False)
        self.background = None
        self.obj.figure.canvas.draw()

    def disconnect(self):
        """Drop the three event subscriptions made by connect()."""
        canvas = self.obj.figure.canvas
        for cid in (self.cidpress, self.cidrelease, self.cidmotion):
            canvas.mpl_disconnect(cid)

    def getlocation(self):
        """Return the patch's current y (vertical) position in data coordinates."""
        return self.obj.center[1]

class App(tk.Tk):

    def __init__(self, parent):
        """Create the analysis window, hidden until a trial is loaded (showPlot deiconifies it).

        Args:
            parent: forwarded to tk.Tk as its first parameter, which is Tk's ``screenName``
                (FileSelectionApp passes None).
        """
        tk.Tk.__init__(self, screenName=parent)
        self.withdraw()
        self.parent = parent

        # Closing the window via the title-bar 'X' is routed through on_exit.
        self.protocol("WM_DELETE_WINDOW", self.on_exit)


    def main(self):
        """Lay out the analysis window: a matplotlib canvas (grid row 0) above its navigation toolbar (row 1).

        Also resets the per-trial counters ``num_saved`` and ``mm`` (the index into self.frames of the frame being
        shown); FileSelectionApp.OnDouble calls this before every new trial.
        """
        self.num_saved = 0
        self.mm = 0

        self.columnconfigure(0, weight=1)
        for grid_row in (0, 1, 2, 3):
            self.rowconfigure(grid_row, weight=1)
        self.minsize(600, 550)

        # Figure size (inches) scales with the screen height in pixels.
        self.h = self.winfo_screenheight()
        # NOTE(review): figure 1 is created here but never used by this method; only the figure created next is
        # embedded. Confirm nothing else relies on figure 1 before removing this call.
        plt.figure(1)
        fig_width, fig_height = self.h / 100, self.h / 120
        self.fig = plt.figure(figsize=(fig_width, fig_height))

        self.frame = tk.Frame(self)
        self.frame.grid(row=0, column=0)

        self.title("Drag points to tissue boundaries and hit 'Next Image' to save")

        # Embed the figure; get_tk_widget() and _tkcanvas refer to the same Tk canvas widget.
        self.canvas = FigureCanvasTkAgg(self.fig, master=self.frame)
        plot_widget = self.canvas.get_tk_widget()
        plot_widget.grid(row=0, column=0)
        self.canvas._tkcanvas.grid(row=0, column=0)

        toolbar_holder = Frame(self)
        self.toolbar = NavigationToolbar2Tk(self.canvas, toolbar_holder)
        self.toolbar.update()
        toolbar_holder.grid(row=1, column=0, sticky='ewns')

    def getFiles(self):
        """Get accepted trials for the selected subject: the Ultrasound dicom images and the Data tdms files.

        The program opens a File Dialog box to search for the directory; alternatively, the user can hard code the
        directory (see the commented line below). The dialog starts in the last directory found under the user's
        home whose name contains "MULTIS_invitro", or in the home directory if there is none.

        Sets:
            self.directoryname: whatever tkFileDialog.askdirectory returned.
            self.masterList, self.num_accept: result of XMLparser.getAcceptance(self.directoryname).
            self.lstFilesDCM: sorted paths of files ending in ".ima" (case-insensitive) under the directory.
            self.lstFilesTDMS: sorted paths of files ending in ".tdms" (case-insensitive) under the directory.
        """
        # Hard code Subject Directory
        # self.directoryname = '/home/morrile2/Documents/MULTIS_test/MULTIS004-1'

        home = os.path.expanduser('~')
        multis_dir = home
        # NOTE(review): this walks the whole home directory tree, which can be slow for large home directories.
        for dirname, subdirs, _files in os.walk(home):
            for subdir in subdirs:
                if "MULTIS_invitro" in subdir:
                    multis_dir = os.path.join(dirname, subdir)
                    print(multis_dir)

        self.directoryname = tkFileDialog.askdirectory(initialdir=multis_dir)
        print(self.directoryname)

        self.lstFilesDCM = []  # create an empty list for DICOM files
        self.lstFilesTDMS = []  # create an empty list for TDMS files

        self.masterList, self.num_accept = XMLparser.getAcceptance(self.directoryname)

        # Find DICOM and TDMS files. Match on the extension (endswith) rather than a substring, so that names which
        # merely contain ".tdms"/".ima" (e.g. a "*.tdms_index" file) are not picked up as trials.
        for dirName, _subdirs, fileList in os.walk(self.directoryname):
            for filename in fileList:
                lower_name = filename.lower()
                if lower_name.endswith(".ima"):
                    self.lstFilesDCM.append(os.path.join(dirName, filename))
                elif lower_name.endswith(".tdms"):
                    self.lstFilesTDMS.append(os.path.join(dirName, filename))

        # Sort variables by trial number
        self.lstFilesTDMS.sort()
        self.lstFilesDCM.sort()


    def execute(self, kk):
        """Initiate the main application to show the Ultrasound images and interactive GUI for the thickness
        measurements.

        Args:
            kk: index of the selected trial in self.lstFilesTDMS.

        The trial is looked up in self.masterList by the first 3 characters of its file name. For an accepted trial
        (x[1] == 1), each DICOM file sharing that prefix is loaded through getFrames() and the first frame is shown.
        If the trial is not accepted, or it has no deltaT XML, a message box reports it and the trial list is
        shown again so another trial can be picked.
        """
        self.location = kk
        trial_path = self.lstFilesTDMS[self.location]
        trial_prefix = os.path.split(trial_path)[1][0:3]

        for x in self.masterList:
            if x[0] != trial_prefix:
                continue
            if x[1] != 1:
                tkMessageBox.showerror("Error", "The trial you have selected is not an accepted trial")
                self.ListBox.deiconify()
                continue
            if self.mm != 0:
                continue

            for DCM in self.lstFilesDCM:
                if os.path.split(DCM)[1][0:3] != trial_prefix:
                    continue
                self.DCM_img = DCM

                print("Verify that file names correlate")
                print(self.DCM_img)
                print(trial_path)

                self.getFrames()
                if self.delta_t_file is None:
                    # getFrames() returns early without selecting frames when there is no deltaT XML; without this
                    # check len(self.frames) below fails and the (withdrawn) trial list never comes back.
                    tkMessageBox.showerror("Error",
                                           "The trial you have selected does not have an associated "
                                           "deltaT XML in the TimeSynchronization folder")
                    self.ListBox.deiconify()
                    return

                if self.mm < len(self.frames):
                    self.showPlot()

    def getFrames(self):
        """Get the frames to be analyzed and the TDMS data for those frames.

        Indentation trials ('I'): the frames from the estimated start of indentation up to the peak of the smoothed
        force magnitude. Anatomical trials ('A'): of the frames between the first and last run-number pulse, the ones
        with the minimum, middle and maximum force magnitude.

        Returns early, without selecting frames, when no "<trial>_dT.xml" file is found in the TimeSynchronization
        folder (self.delta_t_file is then None). Otherwise sets, among others:
            self.frames: DICOM frame indices to analyze, in analysis order.
            self.DATA: for each frame, the TDMS row (sample, Fx, Fy, Fz, Mx, My, Mz, F_mag) at that frame's time.
            self.tdmsTimes: for each frame, the TDMS sample index of that row.

        Raises:
            ValueError: if the trial file name marks it as neither indentation ('I') nor anatomical ('A').
        """
        trial_path = self.lstFilesTDMS[self.location]

        # Find the xml file with delta_t
        self.analysis_path = os.path.dirname(os.path.dirname(trial_path)) + '/TimeSynchronization'
        self.split_name = os.path.split(trial_path)
        self.tail = self.split_name[1][0:-5] + '_dT.xml'

        self.delta_t_file = self.find_file(self.tail, self.analysis_path)
        print(self.delta_t_file)

        if self.delta_t_file is None:
            return

        # Extract force information from TDMS file
        self.data = tdsmParserMultis.parseTDMSfile(trial_path)
        load = self.data[u'State.6-DOF Load']
        Fx = np.array(load[u'6-DOF Load Fx'])
        Fy = np.array(load[u'6-DOF Load Fy'])
        Fz = np.array(load[u'6-DOF Load Fz'])
        Mx = np.array(load[u'6-DOF Load Mx'])
        My = np.array(load[u'6-DOF Load My'])
        Mz = np.array(load[u'6-DOF Load Mz'])

        F_mag = [math.sqrt(fx ** 2 + fy ** 2 + fz ** 2) for fx, fy, fz in zip(Fx, Fy, Fz)]

        pulse = np.array(self.data[u'Sensor.Run Number Pulse Train'][u'Run Number Pulse Train'])
        # NOTE(review): peakutils.indexes documents `thres` as a normalized threshold in [0, 1]; 0.5 * max(pulse)
        # only means "half of the pulse height" if the pulse amplitude is 1 -- confirm.
        Peaks = peakutils.indexes(pulse, thres=0.5 * max(pulse), min_dist=100)

        doc1 = ET.parse(self.delta_t_file)
        self.dT = float(doc1.getroot().find("Location").find("dT").text)

        # Read metadata from the dicom image
        self.RefDs = dicom.read_file(self.DCM_img)
        self.frameTimeVector = self.RefDs.FrameTimeVector
        seq = self.RefDs.SequenceofUltrasoundRegions
        # Extract conversion factor from dicom metadata (original is in cm, convert to mm)
        self.conv_fact = seq[0].PhysicalDeltaY * 10

        # The trial type is the third character before ".tdms" in the file name.
        trial_type = trial_path[-8:-7]
        if trial_type not in ('I', 'A'):
            raise ValueError("Cannot tell whether trial {} is an indentation ('I') or anatomical ('A') "
                             "trial".format(trial_path))

        samples = np.arange(len(F_mag))  # TDMS sample indices, used as the time axis
        # One row per sample, in time order: rows[t] is the data at sample t.
        rows = list(zip(samples, Fx, Fy, Fz, Mx, My, Mz, F_mag))

        if trial_type == 'I':
            # Indentation frames from start of indentation to peak force during indentation
            force_list = self.createAverageFit(F_mag, 300)
            time_max_tdms = force_list.index(max(force_list))

            # Estimate the indentation start: grow a window ending at the peak by 100 samples at a time until the
            # increase in force rise across the window (r_sq_diff) drops to 0.1 or less. The start is taken 100
            # samples into the last window; if the window would reach before sample 0, fall back to sample 230.
            r_sq_old = 0
            r_sq_diff = 1
            i = 500
            while r_sq_diff > .1:
                if i > time_max_tdms:
                    time_start_tdms = 230
                    break
                x = np.arange(time_max_tdms - i, time_max_tdms, 1)
                y = np.array(force_list[time_max_tdms - i:time_max_tdms])
                r_sq_diff = -r_sq_old + (y[-1] - y[0])
                r_sq_old = (y[-1] - y[0])
                time_start_tdms = x[100]
                i += 100

            start_frame, start_frame_time_tdms = self.findFrame(time_start_tdms)
            _max_frame, max_frame_time_tdms = self.findFrame(time_max_tdms)

            frame_lst_i = [start_frame]
            data_i_final = [rows[start_frame_time_tdms]]
            tdmsTime = [start_frame_time_tdms]

            # Every distinct frame between the start and the peak, with the data row at that frame's time.
            for t in rows[start_frame_time_tdms + 1:max_frame_time_tdms]:
                inc_frame, inc_frame_time_tdms = self.findFrame(t[0])
                if inc_frame not in frame_lst_i:
                    frame_lst_i.append(inc_frame)
                    data_i_final.append(rows[inc_frame_time_tdms])
                    tdmsTime.append(inc_frame_time_tdms)

            self.frames = frame_lst_i
            self.DATA = data_i_final
            self.tdmsTimes = tdmsTime

        else:
            # Anatomical frames between the first and last run-number pulse
            time_preIndent_tdms = Peaks[0]
            time_postIndent_tdms = Peaks[-1]

            # Visit samples from minimum to maximum force magnitude; `rows` keeps time order so it can still be
            # indexed by sample time below (indexing the force-sorted list by time returned the wrong rows).
            by_force = sorted(rows, key=lambda t: abs(t[7]))

            frame_lst_a = []
            data_a_sort = []
            tdmsTime = []
            seen_frames = set()

            for t in by_force:
                try:
                    min_frame, min_frame_time_tdms = self.findFrame(t[0])
                except Exception:
                    # findFrame fails for samples that fall after the last ultrasound frame; skip those.
                    continue
                if min_frame in seen_frames:
                    continue
                if time_preIndent_tdms < min_frame_time_tdms < time_postIndent_tdms:
                    seen_frames.add(min_frame)
                    frame_lst_a.append(min_frame)
                    data_a_sort.append(rows[min_frame_time_tdms])
                    tdmsTime.append(min_frame_time_tdms)

            final_a_sort = sorted(zip(frame_lst_a, data_a_sort, tdmsTime), key=lambda t: abs(t[1][7]))

            # Analyze the min, middle, and max force frames
            picks = [final_a_sort[0], final_a_sort[len(final_a_sort) // 2], final_a_sort[-1]]
            self.frames = [p[0] for p in picks]
            self.DATA = [p[1] for p in picks]
            self.tdmsTimes = [p[2] for p in picks]



    def createAverageFit(self, F, avgThres):
        """Filter the tdms normal force data with a forward moving average, shifted to re-center the window.

        For each sample i, the mean of F[i:i + avgThres] is taken (windows near the end are shorter). Samples equal
        to the sample avgThres positions earlier are skipped; for i < avgThres that comparison wraps around to the
        end of F, because F[i - avgThres] is a negative index. The averages are then shifted right by
        avgThres // 2: the first average is repeated avgThres // 2 times at the front and the same number of
        trailing averages is dropped, so the output length equals the number of averages computed.

        Args:
            F: sequence of force samples (list or 1-D array).
            avgThres: moving-average window length, in samples.

        Returns:
            list of averaged values; empty if every sample was skipped.
        """
        averages = []
        for i in range(len(F)):
            if F[i] != F[i - avgThres]:
                averages.append(np.average(F[i:i + avgThres]))

        if not averages:
            return averages

        half = avgThres // 2  # integer division on both Python 2 and 3
        # Slice by explicit length so that half == 0 keeps everything (averages[0:-0] would be empty).
        return [averages[0]] * half + averages[:len(averages) - half]


    def find_file(self, name, path):
        for root, dirs, files in os.walk(path):
            if name in files:
                return os.path.join(root, name)


    def findFrame(self, initial_time):
        """Find the frame corresponding to the specified tdms time and return the adjusted tdms time to match the
        selected frame.

        The tdms time is shifted by self.dT. It is then located between two consecutive cumulative frame times,
        sum(self.frameTimeVector[0:f - 1]) and sum(self.frameTimeVector[0:f]). Whichever of the two is nearer wins;
        a tie goes to the earlier one.

        Args:
            initial_time: time on the tdms clock.

        Returns:
            (frame number f or f - 1, that frame's cumulative time shifted back by self.dT and truncated to int)

        Raises:
            ValueError: if the shifted time lies beyond the last cumulative frame time (including an empty
                self.frameTimeVector).
        """
        adjusted_time = self.dT + initial_time
        previous_end = 0
        # A running sum performs the same additions, in the same order, as sum(frameTimeVector[0:f]),
        # but in O(n) instead of O(n^2).
        for f, duration in enumerate(self.frameTimeVector, 1):
            frame_time = previous_end + duration
            if adjusted_time <= frame_time:
                if frame_time - adjusted_time < adjusted_time - previous_end:
                    return f, int(frame_time - self.dT)
                return f - 1, int(previous_end - self.dT)
            previous_end = frame_time

        # Previously this fell through to an UnboundLocalError on frame_frame
        raise ValueError("tdms time %s (shifted by dT=%s) lies beyond the last frame time %s"
                         % (initial_time, self.dT, previous_end))


    def showPlot(self):
        """Read and show the appropriate frame of the dicom image in tkinter GUI

        On the first frame (self.mm == 0) the whole dicom file self.DCM_img is read with SimpleITK and the window
        is shown. On later frames the figure is cleared instead. The frame self.frames[self.mm] is drawn in grey
        scale. Four draggable red circles are placed in column x = 512: at fixed rows on the first frame, and at
        the rows returned by find_minimum() (tracked from the previous frame) afterwards. The 'Next Image' and
        'Done' buttons are then (re)created.
        """

        self.columnconfigure(0, weight=1)
        self.rowconfigure(0, weight=1)
        self.minsize(600, 550)

        # Read dicom image
        if self.mm == 0:
            reader = ImageFileReader()
            reader.SetFileName(self.DCM_img)
            self.img_all = reader.Execute()
            self.deiconify()

        else:
            self.fig.clear()


        # Read in the appropriate dicom frame and set up plot title
        img = self.img_all[:, :, self.frames[self.mm]]
        self.nda = GetArrayFromImage(img)
        # Only channel 0 of the last axis is displayed. NOTE(review): presumably the frame has several pixel
        # components (e.g. RGB) - confirm.
        plt.imshow(self.nda[:, :, 0], cmap="gray")
        split_name = os.path.split(self.lstFilesTDMS[self.location])
        plt.title('\n'+split_name[1][0:-5] + '  :  Frame = ' + str(self.frames[self.mm]) + '\n\n' + str(self.mm + 1) + '  of  ' + str(
            len(self.DATA)) + '\n')

        # Create four circles used to indicate tissue interfaces (Top of skin, skin/fat, fat/muscle, bottom of muscle)
        r = 7
        if self.mm == 0:
            # Default starting rows for the first frame
            y = [60, 90, 160, 360]
        else:
            # Track the dots from the previous frame by template matching
            y = self.find_minimum()

        circs = [matplotlib.patches.Circle((512, y[0]), radius=r, alpha=0.5, fc='r'),
                 matplotlib.patches.Circle((512, y[1]), radius=r, alpha=0.5, fc='r'),
                 matplotlib.patches.Circle((512, y[2]), radius=r, alpha=0.5, facecolor='r'),
                 matplotlib.patches.Circle((512, y[3]), radius=r, alpha=0.5, facecolor='r')]

        self.drs = []

        # Create main GUI

        self.title("Drag points to tissue boundaries and click 'Next Image' to continue")

        self.ax = plt.gca()
        # Plot the four circles and make them draggable using the DraggablePatch Class
        for circ in circs:
            self.ax.add_patch(circ)
            dr = DraggablePatch(circ)
            dr.connect()
            self.drs.append(dr)

        # NOTE(review): new Button widgets are created and gridded on every call, stacking on top of the previous
        # ones - consider creating them once.
        self.btn1 = tk.Button(self, text="Next Image", command=self.next_image)
        self.btn1.grid(row=2, column=0)
        self.btn2 = tk.Button(self, text="Done", command=self.end_program)
        self.btn2.grid(row=2, column=0, sticky='e')

        self.canvas.draw()

    def find_minimum(self):
        """Track the boundary dots from the previous frame into the current frame by template matching.

        Each template self.template[i] (an intensity profile of column 512 from the previous frame, centred on
        that dot) is slid along a strip of up to 120 rows of column 512 of the current frame (self.nda). The strip
        is centred on the dot's previous row self.y_new[i]. Both profiles are smoothed with a Savitzky-Golay filter
        (window 5, order 2). The offset with the smallest sum of squared differences gives the dot's new row,
        which is the row under the template's centre.

        Returns:
            list of the new dot rows (also stored in self.y_new).
        """
        self.y_new = list(self.y_new)
        # Keep the top dot's search strip inside the image
        if self.y_new[0] < 60:
            self.y_new[0] = 60

        for i, j in enumerate(self.y_new):
            # Clamp at 0: a negative slice start would wrap around to the bottom of the image
            start = max(int(j - 60), 0)
            new_img = signal.savgol_filter(self.nda[start:int(j + 60), 512, 0], 5, 2, mode='nearest').astype(float)
            old_img = signal.savgol_filter(self.template[i], 5, 2, mode='nearest').astype(float)

            n_old = len(old_img)
            # +1 so that the last full overlap (template flush with the end of the strip) is also tested
            n_offsets = len(new_img) - n_old + 1
            if n_offsets < 1:
                # Strip shorter than the template (dot at the image edge): keep the previous position
                continue

            error = [np.sum((old_img - new_img[k:k + n_old]) ** 2) for k in range(n_offsets)]
            # Image row under the template centre at the best offset
            self.y_new[i] = start + int(np.argmin(error)) + n_old // 2

        return self.y_new


    def next_image(self):
        """Saves data and gets next image for analysis - Triggered by 'Next Image' button

        The positions of the dots are read, thicknesses are calculated and written to the xml file, and for the
        first saved frame the figure is also written to a png. For tracking in the next frame, each dot's row and a
        strip of column 512 of up to 60 rows around it are stored, ordered from top to bottom (self.y_new,
        self.template). The next frame is then shown. After the last frame the button stays disabled.
        """
        # Disabled while this frame is processed; re-enabled only when another frame is shown
        self.btn1.config(state="disabled")

        if self.mm < len(self.frames):

            # Get locations of red dots and send to be calculated and saved in xml document
            self.coords = [dr.getlocation() for dr in self.drs]
            self.calc_thicknesses()  # Calculate the skin, fat and muscle thicknesses
            self.create_xml()
            if self.num_saved == 0:
                self.save_first_image()
            self.num_saved += 1

            # Gets next image
            self.mm += 1

            # Remember each dot's row and the intensity profile around it for matching in the next frame
            tracked = []
            for dr in self.drs:
                loc = dr.getlocation()
                # Clamp at 0: a negative slice start would wrap around to the bottom of the image
                profile = self.nda[max(int(loc - 30), 0):int(loc + 30), 512, 0]
                tracked.append((int(loc), profile))
            # Top-to-bottom order. sorted() works whether zip() returns a list or an iterator.
            tracked = sorted(tracked, key=lambda t: t[0])
            self.y_new, self.template = zip(*tracked)

            if self.mm == len(self.frames):
                self.btn1.config(state="disabled")
            else:
                self.fig.clear()
                self.showPlot()
                self.btn1.config(state="active")


    def end_program(self):
        """Finish the current trial - Triggered by the 'Done' button.

        If an xml file was written for this trial, it is pretty printed and its summary plots are built via
        prettyPrintXml. A failure there is reported on the console but does not stop the window from closing.
        The figure is then closed and the trial list box is shown again.
        """
        # xml_name only exists after create_xml() has run for this trial
        xml_name = getattr(self, 'xml_name', None)
        if xml_name is not None and os.path.exists(xml_name):
            try:
                self.prettyPrintXml(xml_name)
            except Exception as err:
                # Best-effort: report instead of silently swallowing (the old bare except also hid Ctrl+C)
                print('Could not pretty-print %s: %s' % (xml_name, err))

        plt.close(self.fig)
        self.ListBox.deiconify()


    def on_exit(self):
        """Ask for confirmation, then close this window and stop the application.

        Nothing is saved on this path: the xml file is not pretty-printed.
        """
        confirmed = tkMessageBox.askokcancel("Quit", "Do you want to quit without pretty-printing the xml file and saving summary figure?")
        if not confirmed:
            return
        self.destroy()
        # `app` is the module-level instance created in the __main__ guard
        app.quit()


    def prettyPrintXml(self, xmlFilePathToPrettyPrint):
        """Pretty print the xml file in place after all frames have been analyzed, then build the summary plots.

        Args:
            xmlFilePathToPrettyPrint: path of the xml file to re-indent; it is overwritten.

        Raises:
            ValueError: if no path is given.
        """
        # Explicit check: an assert would be stripped when Python runs with -O
        if xmlFilePathToPrettyPrint is None:
            raise ValueError("No xml file path given to pretty print.")
        parser = etree.XMLParser(resolve_entities=False, remove_blank_text=True)
        document = etree.parse(xmlFilePathToPrettyPrint, parser)
        document.write(xmlFilePathToPrettyPrint, pretty_print=True, encoding='utf-8')
        # Summary plots of thickness vs. force are built from the rewritten file
        Plot_thicknessForce.PlotSum(xmlFilePathToPrettyPrint)


    def calc_thicknesses(self):
        """Turn the dot positions (pixels) into skin, fat and muscle thicknesses (mm) and print them.

        self.coords is sorted in place first, so dots that were dragged past one another still give positive
        layer thicknesses. Each thickness is the gap between neighbouring dots times self.conv_fact. The muscle
        thickness is NaN when the deepest dot is at row 700 or lower.
        """
        self.coords.sort()
        top, skin_fat, fat_muscle, deepest = self.coords[0], self.coords[1], self.coords[2], self.coords[3]
        scale = self.conv_fact

        self.skin = (skin_fat - top) * scale
        self.fat = (fat_muscle - skin_fat) * scale
        # NOTE(review): 700 appears to mark "no muscle/bone boundary in view" - confirm against the image height
        self.muscle = (deepest - fat_muscle) * scale if deepest < 700 else float('nan')

        print(' ')
        print('*****Thickness Calculations*****')
        print('Skin_thickness = %f mm' % self.skin)
        print('Fat_thickness = %f mm' % self.fat)
        print('Muscle_thickness = %f mm' % self.muscle)
        print(' ')

    def get_positions(self, dir, seg, tdms_frame):
        """Return the ultrasound probe pose in the bone Optotrak sensor coordinate system at a given tdms time.

        Args:
            dir: directory containing the 'Configuration' folder. The name is kept for caller compatibility even
                though it shadows the builtin.
            seg: segment prefix: 'UL_', 'LL_', 'UA_' or 'LA_'.
            tdms_frame: time offset from the first tdms sample; the sample whose offset is closest is used.

        Returns:
            get_xyzrpw(T_BOS_US): (x, y, z, roll, pitch, yaw), in m for xyz and rad for rpw.

        Raises:
            ValueError: for an unknown segment prefix.
            IOError: if the trial's '_State.cfg' configuration file cannot be read.
        """
        # Bone whose Optotrak sensor is the reference for each segment
        bones = {'UL_': 'Femur', 'LL_': 'Tibia', 'UA_': 'Humerus', 'LA_': 'Radius'}
        if seg not in bones:
            # Previously an unknown prefix surfaced later as an UnboundLocalError on `bone`
            raise ValueError("Unknown segment %r, expected one of %s" % (seg, ', '.join(sorted(bones))))
        bone = bones[seg]

        file1 = os.path.split(self.lstFilesTDMS[self.location])[1]

        # Read the trial configuration once (it holds the US sensor-to-tip transform)
        config_file = os.path.join(dir, 'Configuration', file1[0:-5] + '_State.cfg')
        config = ConfigParser.RawConfigParser()
        if not config.read(config_file):
            raise IOError("Cannot load configuration file... Check path.")

        # Get the raw sensor data for bone position (TDMS)
        B_pos = self.data[u'Sensor.' + bone]

        # Indices of the tdms sample(s) whose time relative to the first sample is closest to tdms_frame.
        # np.where returns an index array, so the pose values below are arrays indexed by it, as before.
        times = self.data[u'Time'][u'Time']
        p_int = times[0]
        closest = min(list(times), key=lambda t: abs((t - p_int) - tdms_frame))
        frame = np.where(np.array(times) == closest)[0]

        # Positions are scaled by 1/1000 (presumably mm -> m) and angles converted to radians
        bone_key = u'' + bone + '_smart_02.'
        q1 = B_pos[bone_key + 'x'][frame] / 1000
        q2 = B_pos[bone_key + 'y'][frame] / 1000
        q3 = B_pos[bone_key + 'z'][frame] / 1000
        q4 = np.radians(B_pos[bone_key + 'r'])[frame]
        q5 = np.radians(B_pos[bone_key + 'p'])[frame]
        q6 = np.radians(B_pos[bone_key + 'w'])[frame]

        # Calculate the transformation matrix global to bone Optotrak sensor coordinate system
        T_W_BOS = get_transformation_matrix(q1, q2, q3, q4, q5, q6)

        # US Optotrak sensor to US tip coordinate system
        T_USOS_US = convertTOarray(config.get('Probe-' + bone + ' Position', 't_sensor2_rb2')).reshape(4, -1)

        # Get raw ultrasound position (TDMS)
        us_probe = self.data[u'Sensor.US Probe']
        x = np.array(us_probe[u'US Probe_smart_02.x'])[frame] / 1000
        y = np.array(us_probe[u'US Probe_smart_02.y'])[frame] / 1000
        z = np.array(us_probe[u'US Probe_smart_02.z'])[frame] / 1000
        r = np.radians(np.array(us_probe[u'US Probe_smart_02.r'])[frame])
        p = np.radians(np.array(us_probe[u'US Probe_smart_02.p'])[frame])
        w = np.radians(np.array(us_probe[u'US Probe_smart_02.w'])[frame])

        # Chain global->US sensor->US tip, then express the tip in the bone sensor frame
        T_W_USOS = get_transformation_matrix(x, y, z, r, p, w)
        T_W_US = np.dot(T_W_USOS, T_USOS_US)
        T_BOS_US = np.dot(np.linalg.inv(T_W_BOS), T_W_US)

        return get_xyzrpw(T_BOS_US)  # m for xyz, rad for rpw

    def create_xml(self):
        """Create xml file with forces, moments, and thicknesses"""

        #Extract data for specific frame (i.e. 'mm')
        data = self.DATA[self.mm]
        fx = data[1]
        fy = data[2]
        fz = data[3]
        mx = data[4]
        my = data[5]
        mz = data[6]

        x, y, z, r, p, w = self.get_positions(self.directoryname, self.lstFilesTDMS[self.location][-14:-11], self.tdmsTimes[self.mm])

        #Define path for the analysis folder
        split_name = os.path.split(self.lstFilesTDMS[self.location])
        analysis_path = os.path.join(os.path.join(os.path.dirname(os.path.dirname(self.lstFilesTDMS[self.location])), 'TissueThickness'), 'UltrasoundManual')

        # Check for Analysis directory
        if not os.path.isdir(analysis_path):
            os.makedirs(analysis_path)

        if self.num_saved == 0:
            self.TIME = time.strftime("%Y%m%d%H%M")

        # Name of the xml file
        self.xml_name = os.path.join(analysis_path, split_name[1][0:26] + '_manThick' + self.TIME + '.xml')

        # Check if an xml file exists for this image. If TRUE, edit the values and if FALSE, create a new one.
        if os.path.exists(self.xml_name):
            # print('Edit File')
            doc = ET.parse(self.xml_name)
            root = doc.getroot()
            loc_lst = []

            subj = root.find('Subject')
            src = subj.find('Source')

            for loc in src.findall('Frame'):
                loc_lst.append(loc.get('value'))
            tail = str(self.frames[self.mm])

            if tail in loc_lst:
                # Over-write thicknesses: Forces and moments should be the same since the data for this frame has already
                # been recorded
                frm = src.find('Frame[@value="%s"]' %tail)
                Thick = frm.find("Thickness")
                Thick.find("Skin").text = str(self.skin)
                Thick.find("Fat").text = str(self.fat)
                Thick.find("Muscle").text = str(self.muscle)

                tree = ET.ElementTree(root)
                tree.write(self.xml_name, xml_declaration=True)

            else:
                #Add new child, because information for this frame is not yet in the xml file
                frm = ET.SubElement(src, 'Frame', value=str(self.frames[self.mm]))
                ET.SubElement(frm, "Time", value=str(sum(self.frameTimeVector[0:self.frames[self.mm]])), units="ms")
                Forces = ET.SubElement(frm, "Forces")
                ET.SubElement(Forces, "Fx", units="N").text = str(fx)
                ET.SubElement(Forces, "Fy", units="N").text = str(fy)
                ET.SubElement(Forces, "Fz", units="N").text = str(fz)

                Moments = ET.SubElement(frm, "Moments")
                ET.SubElement(Moments, "Mx", units="Nm").text = str(mx)
                ET.SubElement(Moments, "My", units="Nm").text = str(my)
                ET.SubElement(Moments, "Mz", units="Nm").text = str(mz)

                Thick = ET.SubElement(frm, "Thickness")
                ET.SubElement(Thick, "Skin", units="mm").text = str(self.skin)
                ET.SubElement(Thick, "Fat", units='mm').text = str(self.fat)
                ET.SubElement(Thick, "Muscle", units='mm').text = str(self.muscle)

                Position = ET.SubElement(frm, "USPosition", attrib={"CoordinateSys": "BoneOptotrakSensor"})
                ET.SubElement(Position, "x", units='m').text = str(x)
                ET.SubElement(Position, "y", units='m').text = str(y)
                ET.SubElement(Position, "z", units='m').text = str(z)
                ET.SubElement(Position, "roll", units='rad').text = str(r)
                ET.SubElement(Position, "pitch", units='rad').text = str(p)
                ET.SubElement(Position, "yaw", units='rad').text = str(w)

                tree = ET.ElementTree(root)
                tree.write(self.xml_name, xml_declaration=True)

        else:
            # print('Make New File')

            TDMS_src = os.path.split(self.lstFilesTDMS[self.location])
            subID = os.path.split(os.path.dirname(os.path.dirname(self.lstFilesTDMS[self.location])))
            root = ET.Element('TA_sfm', attrib={"version":'2.0.2', "modality":'Ultrasound'})
            ET.SubElement(root, 'FileInfo', attrib={"user":getpass.getuser(), "timestamp":str(self.TIME)})
            subj = ET.SubElement(root, 'Subject', attrib={"ID": subID[1]})
            src = ET.SubElement(subj, 'Source', attrib={"Filename":TDMS_src[1]})
            frm = ET.SubElement(src, 'Frame', value=str(self.frames[self.mm]))
            ET.SubElement(frm, "Time", value=str(sum(self.frameTimeVector[0:self.frames[self.mm]])), units="ms")

            Forces = ET.SubElement(frm, "Forces")
            ET.SubElement(Forces, "Fx", units="N").text = str(fx)
            ET.SubElement(Forces, "Fy", units="N").text = str(fy)
            ET.SubElement(Forces, "Fz", units="N").text = str(fz)

            Moments = ET.SubElement(frm, "Moments")
            ET.SubElement(Moments, "Mx", units="Nm").text = str(mx)
            ET.SubElement(Moments, "My", units="Nm").text = str(my)
            ET.SubElement(Moments, "Mz", units="Nm").text = str(mz)

            Thick = ET.SubElement(frm, "Thickness")
            ET.SubElement(Thick, "Skin", units="mm").text = str(self.skin)
            ET.SubElement(Thick, "Fat", units='mm').text = str(self.fat)
            ET.SubElement(Thick, "Muscle", units='mm').text = str(self.muscle)

            Position = ET.SubElement(frm, "USPosition", attrib={"CoordinateSys":"BoneOptotrakSensor"})
            ET.SubElement(Position, "x", units='m').text = str(x)
            ET.SubElement(Position, "y", units='m').text = str(y)
            ET.SubElement(Position, "z", units='m').text = str(z)
            ET.SubElement(Position, "roll", units='rad').text = str(r)
            ET.SubElement(Position, "pitch", units='rad').text = str(p)
            ET.SubElement(Position, "yaw", units='rad').text = str(w)

            tree = ET.ElementTree(root)
            tree.write(self.xml_name, xml_declaration=True)


    def save_first_image(self):
        """Save the current figure as a png in a 'ThicknessPNG' folder next to the xml file.

        The png name is the first 26 characters of the xml file name, followed by '_manThick' and self.TIME.
        """
        xml_dir, xml_file = os.path.split(self.xml_name)
        png_dir = os.path.join(xml_dir, 'ThicknessPNG')

        # Create the png folder on first use
        if not os.path.isdir(png_dir):
            os.makedirs(png_dir)

        png_file = ''.join([xml_file[0:26], '_manThick', self.TIME, '.png'])
        self.fig.savefig(os.path.join(png_dir, png_file))


if __name__ == "__main__":
    # Start the GUI. The instance must stay bound to the module-level name `app`,
    # because on_exit() calls app.quit() to stop this main loop.
    app = FileSelectionApp()
    app.mainloop()
