
import tdsmParserMultis
import os
import ConfigParser
import XMLparser
import numpy as np
from sys import stderr
import math
from mayavi import mlab
# import stl
# import peakutils
import xml.etree.ElementTree as ET
import tkFileDialog
import matplotlib.pyplot as plt

# To access any VTK object, we use 'tvtk', which is a Python wrapping of
# VTK replacing C++ setters and getters by Python properties and
# converting numpy arrays to VTK arrays when setting data.
from tvtk.api import tvtk
from tvtk.common import configure_input


def get_RB_transformation_matrix(q1, q2, q3, q4, q5, q6):
    """Build a 4x4 homogeneous transformation matrix from a relative pose.

    q1, q2, q3 are copied into the translation column. q4, q5, q6 are angles
    in radians (they are passed straight to math.sin / math.cos) and fill the
    upper-left 3x3 block with the product Rz(q6) * Ry(q5) * Rx(q4).
    """
    sin4, cos4 = math.sin(q4), math.cos(q4)
    sin5, cos5 = math.sin(q5), math.cos(q5)
    sin6, cos6 = math.sin(q6), math.cos(q6)

    T = np.zeros((4, 4))

    # Rotation block, written row by row.
    T[0:3, 0:3] = [
        [cos6 * cos5, cos6 * sin5 * sin4 - sin6 * cos4, cos6 * sin5 * cos4 + sin6 * sin4],
        [sin6 * cos5, sin6 * sin5 * sin4 + cos6 * cos4, sin6 * sin5 * cos4 - cos6 * sin4],
        [-sin5, cos5 * sin4, cos5 * cos4],
    ]

    # Translation column and homogeneous corner.
    T[0, 3], T[1, 3], T[2, 3] = q1, q2, q3
    T[3, 3] = 1

    return T

def get_transformation_matrix(q1, q2, q3, q4, q5, q6):
    """Build the 4x4 homogeneous matrix for an Optotrak sensor pose.

    q1, q2, q3 form the translation column; q4, q5, q6 are angles in radians
    combined as Rz(q6) * Ry(q5) * Rx(q4). transform_sphere_points inverts the
    returned matrix to express global points in the sensor coordinate system.
    """
    c4, s4 = math.cos(q4), math.sin(q4)
    c5, s5 = math.cos(q5), math.sin(q5)
    c6, s6 = math.cos(q6), math.sin(q6)

    T = np.zeros((4, 4))

    # Fill the rotation block one column at a time.
    T[0:3, 0] = (c6 * c5, s6 * c5, -s5)
    T[0:3, 1] = (c6 * s5 * s4 - s6 * c4, s6 * s5 * s4 + c6 * c4, c5 * s4)
    T[0:3, 2] = (c6 * s5 * c4 + s6 * s4, s6 * s5 * c4 - c6 * s4, c5 * c4)

    # Translation column and homogeneous corner.
    T[0, 3], T[1, 3], T[2, 3] = q1, q2, q3
    T[3, 3] = 1

    return T

def convertTOarray(data):
    """Convert a space-separated string of numbers into a 1-D float array.

    The first two space-separated tokens are discarded and the final
    character of the last token is dropped before conversion (presumably a
    two-token header and a trailing quote character -- confirm against the
    data source).

    Raises:
        IndexError: if fewer than three tokens are present.
        ValueError: if a remaining token is not a valid float.
    """
    tokens = data.split(" ")[2:]
    tokens[-1] = tokens[-1][:-1]

    # Build a real list: a lazy map object would turn into a 0-d object array.
    return np.asarray([float(token) for token in tokens])

def get_distance(A, B):
    """Return the distance between two 3-D points and the direction angles of A - B.

    Parameters:
        A, B: sequences or arrays holding x, y, z coordinates.

    Returns:
        (length, x_angle, y_angle, z_angle) as Python floats. The angles are
        in degrees, measured between the vector A - B and the x, y and z
        axes. When the points coincide the direction is undefined and the
        three angles are NaN.
    """
    vector = np.asarray(A, dtype=float) - np.asarray(B, dtype=float)
    length = math.sqrt(vector[0] ** 2 + vector[1] ** 2 + vector[2] ** 2)

    if length == 0:
        nan = float('nan')
        return 0.0, nan, nan, nan

    angles = []
    for component in (vector[0], vector[1], vector[2]):
        # Rounding can push the cosine marginally outside [-1, 1], which
        # would make math.acos raise a domain error.
        cosine = max(-1.0, min(1.0, component / length))
        angles.append(math.acos(cosine) * 180 / math.pi)

    return float(length), float(angles[0]), float(angles[1]), float(angles[2])

def get_xyzrpw(T):
    """Extract translation and roll/pitch/yaw angles from a 4x4 transform.

    Returns (x, y, z, r, p, w): x, y, z are read from the last column of T,
    and r, p, w are angles in radians recovered with atan2 from the first
    column and the last row of the rotation block.
    """
    translation = (T[0, 3], T[1, 3], T[2, 3])

    yaw = math.atan2(T[1, 0], T[0, 0])
    cos_pitch = math.sqrt(T[2, 1] ** 2 + T[2, 2] ** 2)
    pitch = math.atan2(-T[2, 0], cos_pitch)
    roll = math.atan2(T[2, 1], T[2, 2])

    return translation + (roll, pitch, yaw)

def fit_hypersphere(data, method="Hyper"):
    """Fit a hypersphere to a collection of points and return (radius, center).

    Adapted from FitHypersphere.py by Kevin Karplus (Mon Apr 23 04:08:05 PDT
    2012), which generalizes to n dimensions the Hyper (hyperaccurate)
    algorithm of
        Ali Al-Sharadqah and Nikolai Chernov
        Error analysis for circle fitting algorithms
        Electronic Journal of Statistics
        Vol. 3 (2009) 886-911
        DOI: 10.1214/09-EJS419

    Creative Commons Attribution-ShareAlike 3.0 Unported License.
    http://creativecommons.org/licenses/by-sa/3.0/

    Parameters:
        data: collection of equal-length tuples or lists of real numbers,
            one entry per point.
        method: algebraic fitting method, one of
            "Hyper"   Al-Sharadqah and Chernov's Hyperfit algorithm
            "Pratt"   Vaughn Pratt's algorithm
            "Taubin"  G. Taubin's algorithm
            Kasa's algorithm is not implemented because its constraint
            matrix N would be singular, so N_inv cannot be computed.
            If you are not familiar with them, Hyper is probably your best
            choice.

    Returns:
        (radius, center), where center is an array with the same dimension
        as the input points. No points give (0, None) and a single point
        gives (0, data[0]).

    Raises:
        ValueError: if fewer than dimen + 1 points are supplied, if method
            is not one of the names above, or if the points lie exactly on
            more than one hypersphere (e.g. all points identical), so that
            no unique fit exists.
    """
    num_points = len(data)

    if num_points == 0:
        return (0, None)
    if num_points == 1:
        return (0, data[0])
    dimen = len(data[0])  # dimensionality of hypersphere

    if num_points < dimen + 1:
        raise ValueError(
            "Error: fit_hypersphere needs at least {} points to fit {}-dimensional sphere, but only given {}".format(
                dimen + 1, dimen, num_points))
    if method not in ("Hyper", "Pratt", "Taubin"):
        raise ValueError(
            "Error: unknown method: {} should be 'Hyper', 'Pratt', or 'Taubin'".format(method))

    # central dimen columns of matrix  (data - centroid)
    central = np.matrix(data, dtype=float)  # copy the data
    centroid = np.mean(central, axis=0)
    for row in central:
        row -= centroid

    # squared magnitude for each centered point, as a column vector
    square_mag = [sum(a * a for a in row.flat) for row in central]
    square_mag = np.matrix(square_mag).transpose()

    if method == "Taubin":
        # matrix of normalized squared magnitudes, data
        mean_square = square_mag.mean()
        data_Z = np.bmat([[(square_mag - mean_square) / (2 * math.sqrt(mean_square)), central]])
        u, s, v = np.linalg.svd(data_Z, full_matrices=False)
        param_vect = v[-1, :]
        params = [x for x in np.asarray(param_vect)[0]]  # convert from (dimen+1) x 1 matrix to list
        params[0] /= 2 * math.sqrt(mean_square)
        params.append(-mean_square * params[0])
        params = np.array(params)

    else:
        # matrix of squared magnitudes, data, 1s
        data_Z = np.bmat([[square_mag, central, np.ones((num_points, 1))]])

        # SVD of data_Z
        # Note: numpy's linalg.svd returns data_Z = u * s * v
        #         not u*s*v.H as the Release 1.4.1 documentation claims.
        #         Newer documentation is correct.
        u, s, v = np.linalg.svd(data_Z, full_matrices=False)

        if s[-1] / s[0] < 1e-12:
            # Singular case: the points satisfy a sphere equation exactly.
            if len(s) > 1 and s[-2] / s[0] < 1e-12:
                # More than one null direction means the sphere is not
                # uniquely determined (e.g. all points coincide).
                raise ValueError(
                    "Error: fit_hypersphere cannot determine a unique sphere from degenerate points")
            # Note: I get last ROW of v, while Chernov claims last COLUMN,
            # because of difference in definition of SVD for MATLAB and numpy
            param_vect = v[-1, :]
        else:
            Y = v.H * np.diag(s) * v
            Y_inv = v.H * np.diag([1. / x for x in s]) * v
            # Ninv is the inverse of the constraint matrix, after centroid has been removed
            Ninv = np.asmatrix(np.identity(dimen + 2, dtype=float))
            if method == "Hyper":
                Ninv[0, 0] = 0
                Ninv[0, -1] = 0.5
                Ninv[-1, 0] = 0.5
                Ninv[-1, -1] = -2 * square_mag.mean()
            else:  # "Pratt" (method was validated above)
                Ninv[0, 0] = 0
                Ninv[0, -1] = -0.5
                Ninv[-1, 0] = -0.5
                Ninv[-1, -1] = 0

            # get the eigenvector for the smallest positive eigenvalue
            matrix_for_eigen = Y * Ninv * Y
            eigen_vals, eigen_vects = np.linalg.eigh(matrix_for_eigen)

            positives = [x for x in eigen_vals if x > 0]
            if len(positives) + 1 != len(eigen_vals):
                stderr.write(
                    "Warning: for method {} exactly one eigenvalue should be negative: {}\n".format(
                        method, eigen_vals))
            smallest_positive = min(positives)
            # chosen eigenvector as (dimen+2) x 1 matrix
            A_colvect = eigen_vects[:, list(eigen_vals).index(smallest_positive)]
            # now have to multiply by Y inverse
            param_vect = (Y_inv * A_colvect).transpose()

        # Both branches leave param_vect as a single row of dimen+2 values;
        # flatten it to a plain 1-D array.
        params = np.asarray(param_vect).ravel()

    radius = 0.5 * math.sqrt(sum(a * a for a in params[1:-1]) - 4 * params[0] * params[-1]) / abs(params[0])
    center = -0.5 * params[1:-1] / params[0]
    # the fit was done on centered data, so shift the center back
    center += np.asarray(centroid)[0]
    return (radius, center)

def _parse_quoted_floats(text, count):
    """Strip double quotes from text, skip its first two tokens, and return the next count tokens as floats."""
    tokens = text.replace('"', '').split(" ")
    return np.asarray([float(token) for token in tokens[2:2 + count]])


def _fit_marker_sphere(points):
    """Pratt sphere fit of points; [0, (nan, nan, nan)] when the fit fails."""
    try:
        return fit_hypersphere(points, method="Pratt")
    except Exception:
        # A failed fit means the marker was not digitized. Catch Exception
        # rather than everything so Ctrl-C / SystemExit still propagate.
        nan = float('nan')
        return [(0), (nan, nan, nan)]


def transform_sphere_points(B1m, B1mr):
    """Extract the centers of each registration marker in the respective bone
    Optotrak sensor coordinate system.

    Parameters:
        B1m: string holding 30 digitized points (x, y, z each) after two
            leading tokens; double quotes are ignored. Per the original
            author's notes the units are m.
        B1mr: string holding the 30 matching sensor poses (x, y, z, roll,
            pitch, yaw each) after two leading tokens. Per the original
            author's notes the units are m and rad.

    Returns:
        (sphereA, sphereB, sphereC): sphere fits of points 0-9, 10-19 and
        20-29 respectively. Each is the (radius, center) result of
        fit_hypersphere, or [0, (nan, nan, nan)] if that marker could not be
        fitted (points were not digitized).

    Raises:
        ValueError: if a token is not numeric or fewer values than expected
            are present (the reshape to 30 rows fails).
    """
    B1mCoord = _parse_quoted_floats(B1m, 90).reshape(30, -1)
    B1mrCoord = _parse_quoted_floats(B1mr, 180).reshape(30, -1)

    ## Define coordinate transformations of data
    P1 = np.ones((4, 1))  # homogeneous point, last entry stays 1
    Coord1 = np.zeros((30, 3))

    for i in range(30):
        T1 = get_transformation_matrix(*B1mrCoord[i, 0:6])
        P1[0:3, 0] = B1mCoord[i, 0:3]

        # Transform to bone Optotrak sensor coordinate system
        A = np.dot(np.linalg.inv(T1), P1)
        Coord1[i, :] = A[0:3, 0]

    # Sphere fit for Rigid Body Collected Points (m), all three spheres.
    # Set to NAN if points were not digitized.
    B1mSphereA = _fit_marker_sphere(Coord1[0:10, 0:3])
    B1mSphereB = _fit_marker_sphere(Coord1[10:20, 0:3])
    B1mSphereC = _fit_marker_sphere(Coord1[20:30, 0:3])

    return B1mSphereA, B1mSphereB, B1mSphereC


def findFrame(dT, frameTimeVector, initial_time):
    """Find the frame corresponding to the specified tdms time and return the
    adjusted tdms time to match the selected frame.

    Parameters:
        dT: offset added to the tdms time to place it on the frame timeline.
        frameTimeVector: per-frame durations; frame boundaries are their
            cumulative sums.
        initial_time: tdms time to look up.

    Returns:
        (frame, tdms_time): the index of the frame boundary closest to the
        requested time (0 means the start of the first frame), and that
        boundary expressed as a tdms time, truncated to int.

    Raises:
        ValueError: if the requested time lies beyond the last frame.
    """
    adjusted_time = dT + initial_time

    previous_end = 0  # cumulative time at the start of the current frame
    for index, duration in enumerate(frameTimeVector):
        frame_end = previous_end + duration
        if adjusted_time <= frame_end:
            # Snap to whichever boundary of this frame is closer; ties go to
            # the earlier boundary.
            if frame_end - adjusted_time < adjusted_time - previous_end:
                return index + 1, int(frame_end - dT)
            return index, int(previous_end - dT)
        previous_end = frame_end

    raise ValueError(
        "findFrame: time {} lies beyond the last frame (total frame time {})".format(
            adjusted_time, previous_end))


def find_file(name, path):
    """Return the full path of the first file called name found while walking
    path top-down, or None when no directory under path contains it."""
    matches = (
        os.path.join(folder, name)
        for folder, _subdirs, filenames in os.walk(path)
        if name in filenames
    )
    return next(matches, None)

# def getFrames(dir, tdms, frameTimeVector):
#     """Get the frames to be analyzed and save the data for those frames. Different for indentation and anatomical
#     trials. Indentation contains frames that start at indentation and go through the peak force of the indentation,
#     while anatomical analyzes all frames between start and end pulses from minimum to maximum force"""
#
#     # Find the xml file with delta_t
#     analysis_path = os.path.join(os.path.split(dir)[0], 'TimeSynchronization')
#     split_name = tdms
#     tail = split_name[0:-5] + '_dT.xml'
#
#     delta_t_file = find_file(tail, analysis_path)
#     # print(delta_t_file)
#
#     if delta_t_file == None:
#         return
#     else:
#
#         # Extract force information from TDMS file
#         data = tdsmParserMultis.parseTDMSfile(os.path.join(dir, tdms))
#
#         Fx = np.array(data[u'State.6-DOF Load'][u'6-DOF Load Fx'])
#         Fy = np.array(data[u'State.6-DOF Load'][u'6-DOF Load Fy'])
#         Fz = np.array(data[u'State.6-DOF Load'][u'6-DOF Load Fz'])
#         Mx = np.array(data[u'State.6-DOF Load'][u'6-DOF Load Mx'])
#         My = np.array(data[u'State.6-DOF Load'][u'6-DOF Load My'])
#         Mz = np.array(data[u'State.6-DOF Load'][u'6-DOF Load Mz'])
#
#         F_mag = []
#         for f in range(len(Fx)):
#             F_mag.append(math.sqrt(Fx[f] ** 2 + Fy[f] ** 2 + Fz[f] ** 2))
#
#         pulse = np.array(data[u'Sensor.Run Number Pulse Train'][u'Run Number Pulse Train'])
#         pulse = pulse[:]
#
#         # Peaks = peakutils.indexes(pulse, thres=0.5 * max(pulse), min_dist=100)
#
#         doc1 = ET.parse(delta_t_file)
#         root1 = doc1.getroot()
#         loc1 = root1.find("Location")
#         dT_str = loc1.find("dT").text
#         dT = float(dT_str)
#
#         time = np.arange(len(F_mag))
#         data_i = zip(time, Fx, Fy, Fz, Mx, My, Mz, F_mag)
#
#         # Anatomical frames from minimum to maximum normal force (Fx)
#         time_preIndent_tdms = Peaks[0]  # 230 ms is location first pulse.
#         time_postIndent_tdms = Peaks[-1]
#
#         # Sort the data by the magnitude (F_mag) to get anatomical frame list from minimum to maximum
#         data_a = zip(time, Fx, Fy, Fz, Mx, My, Mz, F_mag)
#         data_a.sort(key=lambda t: abs(t[7]))
#
#         frame_lst_a = []
#         data_a_sort = []
#
#         for t in data_a:
#             try:
#                 min_frame, min_frame_time_tdms = findFrame(dT, frameTimeVector, t[0])
#             except:
#                 continue
#             if min_frame not in frame_lst_a:
#                 if min_frame_time_tdms > time_preIndent_tdms and min_frame_time_tdms < time_postIndent_tdms:
#                     frame_lst_a.append(min_frame)
#                     data_a_sort.append(data_i[min_frame_time_tdms])
#
#         final_a_sort = zip(frame_lst_a, data_a_sort)
#         final_a_sort.sort(key=lambda t: abs(t[1][7]))
#
#         # Return Minimum Force
#         min_frame = final_a_sort[0][0]
#         DATA = final_a_sort[0][1]
#
#         return min_frame, DATA

def get_stl_points(file):
    """Read the STL file at path file with tvtk's STLReader and return the
    point data of the resulting mesh converted to an array."""
    reader = tvtk.STLReader()
    reader.file_name = file
    reader.update()

    return reader.output.points.data.to_array()

def get_stl_dist(subj_path, modality, segment, spheres):
    """Fit a sphere to each fiducial-marker STL file and measure the distances
    between every pair of fitted centers.

    Parameters:
        subj_path, modality, segment: not used by the current
            implementation; kept so existing callers keep working.
        spheres: paths of the marker STL files.

    Returns:
        (center_dist, spheres, sphere_fit) where
        center_dist is a list of [file_a, file_b, distance, x_angle,
            y_angle, z_angle] for every unordered pair of markers (values as
            returned by get_distance),
        spheres is the input, returned unchanged, and
        sphere_fit is a list of [file name, radius, center] per STL file.
    """
    sphere_fit = []
    for stl_path in spheres:
        radius, center = fit_hypersphere(get_stl_points(stl_path))
        sphere_fit.append([os.path.split(stl_path)[1], radius, center])

    center_dist = []
    for first, sphere_a in enumerate(sphere_fit):
        # Only pair with later entries so each pair is reported once.
        for sphere_b in sphere_fit[first + 1:]:
            d, x_a, y_a, z_a = get_distance(sphere_a[2], sphere_b[2])
            center_dist.append([sphere_a[0], sphere_b[0], d, x_a, y_a, z_a])

    return center_dist, spheres, sphere_fit

def adjust_angles(ang_list):
    """Shift angles (radians) by one full turn towards the [-pi, pi] range.

    Values above pi have 2*pi subtracted and values below -pi have 2*pi
    added, once each. ang_list is modified in place and also returned.
    """
    full_turn = np.pi*2
    for idx, angle in enumerate(ang_list):
        if angle > np.pi:
            ang_list[idx] = angle - full_turn
        elif angle < -np.pi:
            ang_list[idx] = angle + full_turn
    return ang_list

def adjust_loads(L):
    """Digitally tare a load signal.

    Subtracts the mean of the first 200 samples (or of all samples when
    there are fewer) from L in place and returns L.
    """
    baseline = np.mean(L[0:200])
    L -= baseline
    return L

def zero_pos(p):
    """Express a signal relative to its first sample.

    Subtracts p[0] from every element of p in place and returns p.
    """
    origin = p[0]
    p -= origin
    return p

def _plot_tool_trial(data, tool, load_name, png_path):
    """Plot position, orientation, force and moment traces of one tool for one
    trial and save the figure to png_path."""
    # Get the raw data for tool position and load.
    B_pos = data[u'State.' + tool]
    time = data[u'Time'][u'Time']
    # Positions are scaled by 1/1000 here and multiplied back by 1000 for the
    # "(mm)" plot below. Float divisor so integer-valued channels are not
    # truncated under Python 2.
    x = B_pos[u'' + tool + '_x'] / 1000.0
    y = B_pos[u'' + tool + '_y'] / 1000.0
    z = B_pos[u'' + tool + '_z'] / 1000.0
    r = np.radians(B_pos[u'' + tool + '_roll'])
    p = np.radians(B_pos[u'' + tool + '_pitch'])
    w = np.radians(B_pos[u'' + tool + '_yaw'])

    # Rotation part of the tool pose for every position sample.
    T_FEM_T = np.zeros([len(x), 3, 3])
    for ii in range(len(x)):
        T_FEM_T[ii] = get_RB_transformation_matrix(x[ii], y[ii], z[ii], r[ii], p[ii], w[ii])[0:3, 0:3]

    missing = np.sum(np.isnan(x))
    if missing > 0:
        print("Missing {} tool data: {} data points".format(tool, missing))

    load_group = data[u'State.' + load_name]
    Fx = load_group[u'' + load_name + ' Fx'][0:]
    Fy = load_group[u'' + load_name + ' Fy'][0:]
    Fz = load_group[u'' + load_name + ' Fz'][0:]
    Mx = load_group[u'' + load_name + ' Mx'][0:]
    My = load_group[u'' + load_name + ' My'][0:]
    Mz = load_group[u'' + load_name + ' Mz'][0:]

    force_len = len(Fx)
    print("Time Diff {}".format(force_len - time[-1]))
    print("{} {}".format(force_len, time[-1]))
    t = np.arange(0, len(Fx), 1)  # load sample index, compared against time below
    print(time[-1])
    print(t[-1])

    # Rotate each load sample with the pose sample closest to it in time.
    # The absolute difference is required: the signed difference is smallest
    # at the earliest sample, which would always select the first pose.
    for f in range(len(Fx)):
        rotation = T_FEM_T[np.abs(time - t[f]).argmin()]
        Fx[f], Fy[f], Fz[f] = np.dot(rotation, [Fx[f], Fy[f], Fz[f]])
        Mx[f], My[f], Mz[f] = np.dot(rotation, [Mx[f], My[f], Mz[f]])

    # Zero positions and angles at beginning of trial
    x = zero_pos(x)
    y = zero_pos(y)
    z = zero_pos(z)
    r = zero_pos(r)
    p = zero_pos(p)
    w = zero_pos(w)

    # Adjust the angles so that they lay between -180 and 180 degrees
    r = adjust_angles(r)
    p = adjust_angles(p)
    w = adjust_angles(w)

    # Digitally tare loads and moments
    Fx = adjust_loads(Fx)
    Fy = adjust_loads(Fy)
    Fz = adjust_loads(Fz)
    Mx = adjust_loads(Mx)
    My = adjust_loads(My)
    Mz = adjust_loads(Mz)

    fig, (ax1, ax2, ax3, ax4) = plt.subplots(4, 1, figsize=(15, 15))

    ax1.plot(time, x*1000, c='red', ls='--')
    ax1.plot(time, y*1000, c='green', ls=':')
    ax1.plot(time, z*1000, c='blue')
    ax1.set_xlim([0, time[-1]])
    ax1.set_ylabel('Tool Position (mm)')
    ax1.legend(['X', 'Y', 'Z'], loc='center left', bbox_to_anchor=(1, 0.5))

    ax2.plot(time, np.degrees(r), c='red', ls='--')
    ax2.plot(time, np.degrees(p), c='green', ls=':')
    ax2.plot(time, np.degrees(w), c='blue')
    ax2.set_ylabel('Tool Orientation (deg)')
    ax2.set_xlim([0, time[-1]])
    ax2.legend(['roll', 'pitch', 'yaw'], loc='center left', bbox_to_anchor=(1, 0.5))

    ax3.plot(Fx, c='red', ls='--')
    ax3.plot(Fy, c='green', ls=':')
    ax3.plot(Fz, c='blue')
    ax3.set_ylabel('Tool Forces (N)')
    ax3.set_xlim([0, time[-1]])
    ax3.legend(['Fx', 'Fy', 'Fz'], loc='center left', bbox_to_anchor=(1, 0.5))

    ax4.plot(Mx, c='red', ls='--')
    ax4.plot(My, c='green', ls=':')
    ax4.plot(Mz, c='blue')
    ax4.set_xlim([0, time[-1]])
    ax4.set_ylabel('Tool Moments (Nm)')
    ax4.set_xlabel('time (ms)')
    ax4.legend(['Mx', 'My', 'Mz'], loc='center left', bbox_to_anchor=(1, 0.5))

    plt.suptitle(tool)
    plt.tight_layout()
    plt.subplots_adjust(right=0.85, top=0.95)

    out_dir = os.path.dirname(png_path)
    if not os.path.exists(out_dir):
        os.makedirs(out_dir)
    fig.savefig(png_path)
    plt.close(fig)


def main(dir, segment, modality):
    """Create data-quality plots for the accepted surgical-tool trials of the
    first donor folder found under dir.

    For every accepted TDMS file in <dir>/<donor>/Data (except files whose
    name starts with '100') one PNG per tool is written to
    <dir>/<donor>/DataQuality, showing tool position, orientation, forces
    and moments over the trial.

    Parameters:
        dir: folder containing one sub-folder per donor. Only the first
            sub-folder in sorted order ('.DS_Store' ignored) is processed.
        segment, modality: not used by the current implementation; kept so
            existing callers keep working.

    Raises:
        IOError: if the '<trial>_State.cfg' file of an accepted trial cannot
            be read from <dir>/<donor>/Configuration.
        IndexError: if dir contains no donor folder.
    """
    # Tool and load channel names, keyed by the character at index 17 of the
    # TDMS file name.
    tools_by_code = {
        'R': (['Retractor'], ['Retractor Load']),
        'S': (['Scalpel'], ['Scalpel Load']),
        'F': (['Forceps Left', 'Forceps Right'],
              ['Forceps Left Load', 'Forceps Right Load']),
    }

    donorIDlist = sorted(item for item in os.listdir(dir) if item != '.DS_Store')
    donorID = donorIDlist[0]

    # Trailing separator kept on purpose: the path is handed to
    # XMLparser.getAcceptance, which previously received it with one.
    dir = os.path.join(dir, donorID, '')

    # Get accepted trials from experiment
    masterList, num_accept = XMLparser.getAcceptance(dir)
    accepted_list = [x[0] for x in masterList if x[1] == 1]
    print(accepted_list)

    dir_data = os.path.join(dir, 'Data')

    # Plot tool pose and loads for each accepted trial
    for file in sorted(os.listdir(dir_data)):
        if not any(file[0:3] in s for s in accepted_list):
            continue

        if file[0:3] != '100':
            data = tdsmParserMultis.parseTDMSfile(os.path.join(dir_data, file))

            config = ConfigParser.RawConfigParser()
            if not config.read(os.path.join(dir, 'Configuration', file[0:-5] + '_State.cfg')):
                raise IOError("Cannot load configuration file... Check path.")

            print('=====================')
            print(file)

            tool_entry = tools_by_code.get(file[17:18])
            if tool_entry is None:
                # Without this guard the tool list would be undefined, or
                # silently reused from the previous file.
                print("Unknown tool code in file name, skipping: {}".format(file))
            else:
                tools, loads = tool_entry
                for tool, load_name in zip(tools, loads):
                    png_path = os.path.join(
                        dir, 'DataQuality', file[0:-5] + '_' + tool.replace(" ", "") + '.png')
                    _plot_tool_trial(data, tool, load_name, png_path)

        print("")

if __name__ == "__main__":
    # Folder holding the donor sub-folders (cadaver specimen data).
    # Locations used previously, kept for reference:
    #   '/Users/schimmt/Downloads/SMULTIS009-1'
    #   '/Users/schimmt/Multis/app/FileAssociation/MULTIS_trials/Surgical Tools/'
    #   '/Users/schimmt/Downloads/SMULTIS006-1'
    dir = '/Users/schimmt/Multis/studies/SurgicalToolsRefData/dat/'

    segment = 'UpperLeg'  # Change this to the desired segment
    modality = 'CT'
    main(dir, segment, modality)
    # plt.show()