
import tdsmParserMultis
import os
import ConfigParser
import XMLparser
import numpy as np
from sys import stderr
import math
from mayavi import mlab
import stl
import peakutils
import xml.etree.ElementTree as ET
import tkFileDialog
import matplotlib.pyplot as plt

# To access any VTK object, we use 'tvtk', which is a Python wrapping of
# VTK replacing C++ setters and getters by Python properties and
# converting numpy arrays to VTK arrays when setting data.
from tvtk.api import tvtk
from tvtk.common import configure_input


def get_RB_transformation_matrix(q1, q2, q3, q4, q5, q6):
    """Build a 4x4 homogeneous transform from a translation and three rotation angles.

    Args:
        q1, q2, q3: translation components placed in the last column.
        q4, q5, q6: rotation angles (radians) about x, y and z respectively; the
            rotation block equals Rz(q6) . Ry(q5) . Rx(q4).

    Returns:
        4x4 numpy array [[R, t], [0, 0, 0, 1]].
    """
    cos_x, sin_x = math.cos(q4), math.sin(q4)
    cos_y, sin_y = math.cos(q5), math.sin(q5)
    cos_z, sin_z = math.cos(q6), math.sin(q6)

    transform = np.zeros((4, 4))
    transform[0, 0:3] = [cos_z * cos_y,
                         cos_z * sin_y * sin_x - sin_z * cos_x,
                         cos_z * sin_y * cos_x + sin_z * sin_x]
    transform[1, 0:3] = [sin_z * cos_y,
                         sin_z * sin_y * sin_x + cos_z * cos_x,
                         sin_z * sin_y * cos_x - cos_z * sin_x]
    transform[2, 0:3] = [-sin_y,
                         cos_y * sin_x,
                         cos_y * cos_x]
    transform[0:3, 3] = [q1, q2, q3]
    transform[3, 3] = 1
    return transform

def get_transformation_matrix(q1, q2, q3, q4, q5, q6):
    ''' Build the 4x4 homogeneous pose of an Optotrak position sensor.

    q1, q2, q3 are the sensor position and q4, q5, q6 its rotations (radians) about
    x, y and z. The returned matrix maps sensor-frame points into Optotrak global
    coordinates; transform_sphere_points applies its inverse to bring global points
    into the sensor coordinate system.

    The math is identical to get_RB_transformation_matrix, so this delegates to it
    rather than keeping a second copy that could drift out of sync.
    '''
    return get_RB_transformation_matrix(q1, q2, q3, q4, q5, q6)

def convertTOarray(data):
    """Parse a (possibly double-quoted) space-separated number string into a float array.

    The first two whitespace-separated tokens are treated as a header and dropped;
    every remaining token is converted with float().

    The original implementation called ``data.replace("", '')`` (a no-op) and then
    chopped the last character of the final token to remove the closing quote, which
    corrupted the last number whenever no closing quote was present. Quotes are now
    removed explicitly, as transform_sphere_points does.

    Args:
        data: string such as '"<h1> <h2> v1 v2 ... vn"'.

    Returns:
        1-D numpy array of the n values (empty if only the header is present).
    """
    tokens = data.replace('"', '').split()
    # A list (not map) keeps np.asarray producing a numeric array on Python 2 and 3.
    return np.asarray([float(token) for token in tokens[2:]])

def get_distance(A, B):
    """Return the distance between two 3-D points and the direction angles of A - B.

    Args:
        A, B: 3-element sequences or numpy arrays.

    Returns:
        (length, x_angle, y_angle, z_angle) as floats, where the angles (degrees) are
        between the vector A - B and the +x, +y and +z axes. If A and B coincide the
        direction is undefined and the three angles are NaN.
    """
    vector = np.asarray(A, dtype=float) - np.asarray(B, dtype=float)
    length = math.sqrt(vector[0] ** 2 + vector[1] ** 2 + vector[2] ** 2)
    if length == 0:
        nan = float('nan')
        return 0.0, nan, nan, nan

    def _axis_angle(component):
        # Clamp: rounding in length can push |component / length| just past 1,
        # which would make math.acos raise ValueError.
        cosine = max(-1.0, min(1.0, component / length))
        return math.acos(cosine) * 180 / math.pi

    return (float(length), float(_axis_angle(vector[0])),
            float(_axis_angle(vector[1])), float(_axis_angle(vector[2])))

def get_xyzrpw(T):
    """Decompose a 4x4 homogeneous transform into translation and three angles.

    The angles are read from the rotation block laid out as in
    get_RB_transformation_matrix (Rz(w) . Ry(p) . Rx(r)). Pitch comes from atan2 with a
    non-negative second argument, so it lies in [-pi/2, pi/2].

    Returns:
        (x, y, z, r, p, w): translation followed by roll, pitch, yaw in radians.
    """
    yaw = math.atan2(T[1, 0], T[0, 0])
    pitch = math.atan2(-T[2, 0], math.sqrt(T[2, 1] ** 2 + T[2, 2] ** 2))
    roll = math.atan2(T[2, 1], T[2, 2])
    translation = (T[0, 3], T[1, 3], T[2, 3])
    return translation + (roll, pitch, yaw)

def fit_hypersphere(data, method="Hyper"):
    """Fit a hypersphere to a collection of points.

    FitHypersphere.py

    fit_hypersphere(collection of tuples or lists of real numbers) returns a
    hypersphere of the same dimension as the tuples: (radius, center).

    Uses the Hyper (hyperaccurate) algorithm of
    Ali Al-Sharadqah and Nikolai Chernov,
    "Error analysis for circle fitting algorithms",
    Electronic Journal of Statistics, Vol. 3 (2009) 886-911,
    DOI: 10.1214/09-EJS419, generalized to n dimensions.

    Mon Apr 23 04:08:05 PDT 2012 Kevin Karplus
    Creative Commons Attribution-ShareAlike 3.0 Unported License.
    http://creativecommons.org/licenses/by-sa/3.0/

    Available "algebraic" fitting methods (this SVD-based version supports all three):
        Hyper   Al-Sharadqah and Chernov's Hyperfit algorithm (usually the best choice)
        Pratt   Vaughn Pratt's algorithm
        Taubin  G. Taubin's algorithm
    Kasa's algorithm is not implemented: its constraint matrix N is singular, so the
    N_inv computation is not possible.

    Args:
        data: sequence of points (tuples, lists or array rows) of equal dimension.
        method: "Hyper" (default), "Pratt" or "Taubin".

    Returns:
        (radius, center). With no points returns (0, None); with one point returns
        (0, data[0]); otherwise center is a 1-D numpy array.

    Raises:
        ValueError: if fewer than dimen + 1 points are given, or method is unknown.
    """
    num_points = len(data)
    if num_points == 0:
        return (0, None)
    if num_points == 1:
        return (0, data[0])
    dimen = len(data[0])  # dimensionality of hypersphere

    if num_points < dimen + 1:
        raise ValueError(
            "Error: fit_hypersphere needs at least {} points to fit {}-dimensional sphere, but only given {}".format(
                dimen + 1, dimen, num_points))
    if method not in ("Hyper", "Pratt", "Taubin"):
        # Validate up front so an unknown method is rejected on every code path
        # (previously the singular path never checked it).
        raise ValueError("Error: unknown method: {} should be 'Hyper', 'Pratt', or 'Taubin'".format(method))

    # Work about the centroid; it is added back to the fitted center at the end.
    central = np.matrix(data, dtype=float)  # copy of the data
    centroid = np.mean(central, axis=0)
    central -= centroid

    # squared magnitude for each centered point, as a column vector
    square_mag = [sum(a * a for a in row.flat) for row in central]
    square_mag = np.matrix(square_mag).transpose()

    if method == "Taubin":
        # matrix of normalized squared magnitudes, data
        mean_square = square_mag.mean()
        data_Z = np.bmat([[(square_mag - mean_square) / (2 * math.sqrt(mean_square)), central]])
        _, s, v = np.linalg.svd(data_Z, full_matrices=False)
        param_vect = v[-1, :]
        params = [x for x in np.asarray(param_vect)[0]]  # 1 x (dimen+1) matrix -> list
        params[0] /= 2 * math.sqrt(mean_square)
        params.append(-mean_square * params[0])
        params = np.array(params)
    else:
        # matrix of squared magnitudes, data, 1s
        data_Z = np.bmat([[square_mag, central, np.ones((num_points, 1))]])

        # Note: numpy's linalg.svd returns data_Z = u * s * v (v is already V^H).
        _, s, v = np.linalg.svd(data_Z, full_matrices=False)

        if s[-1] / s[0] < 1e-12:
            # Singular case (points lie exactly on a sphere): the parameters are the
            # null vector of data_Z. This is the last ROW of numpy's v, where Chernov
            # takes the last COLUMN, because MATLAB and numpy define the SVD differently.
            param_vect = v[-1, :]
        else:
            Y = v.H * np.diag(s) * v
            Y_inv = v.H * np.diag([1. / x for x in s]) * v
            # Ninv is the inverse of the constraint matrix, after centroid has been removed
            Ninv = np.asmatrix(np.identity(dimen + 2, dtype=float))
            if method == "Hyper":
                Ninv[0, 0] = 0
                Ninv[0, -1] = 0.5
                Ninv[-1, 0] = 0.5
                Ninv[-1, -1] = -2 * square_mag.mean()
            else:  # "Pratt"
                Ninv[0, 0] = 0
                Ninv[0, -1] = -0.5
                Ninv[-1, 0] = -0.5
                Ninv[-1, -1] = 0

            # get the eigenvector for the smallest positive eigenvalue
            matrix_for_eigen = Y * Ninv * Y
            eigen_vals, eigen_vects = np.linalg.eigh(matrix_for_eigen)

            positives = [x for x in eigen_vals if x > 0]
            if len(positives) + 1 != len(eigen_vals):
                stderr.write("Warning: for method {} exactly one eigenvalue should be negative: {}\n".format(
                    method, eigen_vals))
            smallest_positive = min(positives)
            # chosen eigenvector as (dimen+2) x 1 matrix
            A_colvect = eigen_vects[:, list(eigen_vals).index(smallest_positive)]
            # now have to multiply by Y inverse
            param_vect = (Y_inv * A_colvect).transpose()

        # Both branches leave param_vect as a 1 x (dimen+2) matrix. Previously the
        # singular branch never assigned params, raising UnboundLocalError below.
        params = np.asarray(param_vect)[0]

    radius = 0.5 * math.sqrt(sum(a * a for a in params[1:-1]) - 4 * params[0] * params[-1]) / abs(params[0])
    center = -0.5 * params[1:-1] / params[0]
    center += np.asarray(centroid)[0]
    return (radius, center)

def _quoted_floats(text, count):
    """Strip double quotes, split on single spaces, drop the two header tokens and return the next *count* values as floats."""
    tokens = text.replace('"', '').split(" ")
    # A list (not map) keeps np.asarray numeric on both Python 2 and 3.
    return np.asarray([float(token) for token in tokens[2:2 + count]])


def transform_sphere_points(B1m, B1mr):
    ''' Function to extract the centers of each registration marker in the respective bone Optotrak sensor coordinate
    system.

    Args:
        B1m: quoted string with a two-token header followed by x, y, z (m) of 30
            digitised points in Optotrak global coordinates.
        B1mr: quoted string with a two-token header followed by x, y, z (m) and
            roll, pitch, yaw (rad) of the bone position sensor at each of those 30 points.

    Returns:
        Three (radius, center) sphere fits (Pratt method), one per group of 10
        consecutive points. A sphere whose fit fails (e.g. points not digitised) is
        returned as [0, (nan, nan, nan)].
    '''
    num_points = 30
    points_per_sphere = 10

    # 30 x 3 matrix of digitised points (m)
    B1mCoord = _quoted_floats(B1m, 3 * num_points).reshape(num_points, -1)
    # 30 x 6 matrix of sensor poses: x, y, z (m), roll, pitch, yaw (rad)
    B1mrCoord = _quoted_floats(B1mr, 6 * num_points).reshape(num_points, -1)

    P1 = np.ones((4, 1))  # homogeneous point, last entry stays 1
    Coord1 = np.zeros((num_points, 3))
    for i in range(num_points):
        T1 = get_transformation_matrix(*B1mrCoord[i, 0:6])
        P1[0:3, 0] = B1mCoord[i, 0:3]
        # Transform to bone Optotrak sensor coordinate system
        A = np.dot(np.linalg.inv(T1), P1)
        Coord1[i, :] = A[0:3, 0]

    # Sphere fit for each marker; set to NaN if points were not digitised.
    NAN = float('nan')
    fits = []
    for start in range(0, num_points, points_per_sphere):
        try:
            fits.append(fit_hypersphere(Coord1[start:start + points_per_sphere, 0:3], method="Pratt"))
        except Exception:
            # Deliberate best-effort: a missing marker must not abort the other two.
            fits.append([0, (NAN, NAN, NAN)])

    B1mSphereA, B1mSphereB, B1mSphereC = fits
    return B1mSphereA, B1mSphereB, B1mSphereC


def findFrame(dT, frameTimeVector, initial_time):
    """Find the frame corresponding to the specified tdms time and return the adjusted tdms time to match the
    selected frame.

    Args:
        dT: time offset added to initial_time to obtain the frame-clock time.
        frameTimeVector: duration of each frame, in order.
        initial_time: tdms time to look up.

    Returns:
        (frame, tdms_time): the frame boundary nearest to dT + initial_time (ties go to
        the earlier boundary) and that boundary's time expressed back in tdms time
        (boundary - dT), truncated to int.

    Raises:
        ValueError: if dT + initial_time lies beyond the end of the last frame.
    """
    adjusted_time = dT + initial_time
    frame_start = 0  # cumulative time at the start of the current frame
    for frame, duration in enumerate(frameTimeVector, 1):
        frame_end = frame_start + duration
        if adjusted_time <= frame_end:
            if frame_end - adjusted_time < adjusted_time - frame_start:
                return frame, int(frame_end - dT)
            return frame - 1, int(frame_start - dT)
        # Running sum replaces re-summing the prefix on every iteration (was O(n^2)).
        frame_start = frame_end

    # Previously this fell through to an UnboundLocalError.
    raise ValueError("time {} is beyond the last frame ({})".format(adjusted_time, frame_start))


def find_file(name, path):
    """Return the full path of the first file called *name* under *path*, or None.

    Directories are visited in os.walk's default top-down order.
    """
    hits = (os.path.join(folder, name)
            for folder, _subdirs, filenames in os.walk(path)
            if name in filenames)
    return next(hits, None)

def getFrames(dir, tdms, frameTimeVector):
    """Select the frame with the smallest load magnitude between the first and last sync pulse.

    Only the anatomical selection is implemented. The load samples of the TDMS
    trial are ordered by |F|. Each sample is mapped to its nearest frame with
    findFrame using the trial's dT. Frames whose re-adjusted tdms time lies strictly
    between the first and last detected pulse are kept, and the one with the smallest
    |F| is returned.

    Args:
        dir: folder holding the TDMS file; '<parent of dir>/TimeSynchronization' is
            searched for '<tdms without extension>_dT.xml'.
        tdms: TDMS file name (extension of 5 characters, e.g. '.tdms', is stripped).
        frameTimeVector: frame durations passed to findFrame.

    Returns:
        (min_frame, DATA), where DATA is the tuple (time, Fx, Fy, Fz, Mx, My, Mz, F_mag)
        of the sample at that frame's re-adjusted tdms time. Returns None if no dT
        file is found.

    Raises:
        IndexError: if no frame lies between the first and last pulse.
    """
    # Find the xml file with delta_t
    analysis_path = os.path.join(os.path.split(dir)[0], 'TimeSynchronization')
    delta_t_file = find_file(tdms[0:-5] + '_dT.xml', analysis_path)
    if delta_t_file is None:
        return

    # Extract force information from TDMS file
    data = tdsmParserMultis.parseTDMSfile(os.path.join(dir, tdms))
    load = data[u'State.6-DOF Load']
    Fx = np.array(load[u'6-DOF Load Fx'])
    Fy = np.array(load[u'6-DOF Load Fy'])
    Fz = np.array(load[u'6-DOF Load Fz'])
    Mx = np.array(load[u'6-DOF Load Mx'])
    My = np.array(load[u'6-DOF Load My'])
    Mz = np.array(load[u'6-DOF Load Mz'])
    F_mag = np.sqrt(Fx ** 2 + Fy ** 2 + Fz ** 2)

    pulse = np.array(data[u'Sensor.Run Number Pulse Train'][u'Run Number Pulse Train'])
    # NOTE(review): peakutils.indexes treats `thres` as a normalised (0-1) fraction of the
    # signal range unless thres_abs=True, so 0.5 * max(pulse) is a 50 % threshold only when
    # the pulse amplitude is about 1 - confirm against the recorded signal.
    Peaks = peakutils.indexes(pulse, thres=0.5 * max(pulse), min_dist=100)

    dT = float(ET.parse(delta_t_file).getroot().find("Location").find("dT").text)

    time = np.arange(len(F_mag))
    data_i = list(zip(time, Fx, Fy, Fz, Mx, My, Mz, F_mag))

    # Anatomical frames lie between the first and last pulse
    # (original note: 230 ms is location first pulse - presumably trial-specific).
    time_preIndent_tdms = Peaks[0]
    time_postIndent_tdms = Peaks[-1]

    # Visit samples from minimum to maximum load magnitude (stable sort).
    data_a = sorted(data_i, key=lambda sample: abs(sample[7]))

    frame_lst_a = []
    data_a_sort = []
    seen_frames = set()  # O(1) membership instead of scanning frame_lst_a
    for sample in data_a:
        try:
            frame, frame_time_tdms = findFrame(dT, frameTimeVector, sample[0])
        except Exception:
            # Sample falls outside the recorded frames.
            continue
        if frame in seen_frames:
            continue
        if time_preIndent_tdms < frame_time_tdms < time_postIndent_tdms:
            seen_frames.add(frame)
            frame_lst_a.append(frame)
            data_a_sort.append(data_i[frame_time_tdms])

    final_a_sort = sorted(zip(frame_lst_a, data_a_sort), key=lambda pair: abs(pair[1][7]))

    # Return Minimum Force
    min_frame, DATA = final_a_sort[0]
    return min_frame, DATA

def get_stl_points(file):
    """Read an STL file with tvtk and return its point coordinates.

    The points are returned as given by ``reader.output.points.data.to_array()``
    (presumably an (n_points, 3) numpy array - confirm).

    NOTE(review): an identical get_stl_points is defined again immediately below this
    one; that later definition replaces this one when the module is imported, so this
    copy is dead code and one of the two can be removed.
    """
    reader = tvtk.STLReader()
    reader.file_name = file
    reader.update()

    stl_points = reader.output.points.data.to_array()

    return stl_points

def get_stl_points(file):
    """Load an STL mesh with tvtk and return the coordinates of its points.

    Args:
        file: path to the STL file.

    Returns:
        The reader output's point data converted with ``to_array()``
        (presumably an (n_points, 3) numpy array - confirm).
    """
    stl_reader = tvtk.STLReader(file_name=file)
    stl_reader.update()
    return stl_reader.output.points.data.to_array()

def get_stl_dist(subj_path, modality, segment, spheres):
    """Fit a sphere to each fiducial-marker STL and compute pairwise center distances.

    Args:
        subj_path, modality, segment: not used by this function (kept for interface
            compatibility).
        spheres: iterable of STL file paths, one per fiducial marker.

    Returns:
        center_dist: list of [name_1, name_2, distance, x_angle, y_angle, z_angle] for
            every unordered pair of markers, in input order (angles in degrees, as
            returned by get_distance).
        spheres: the *spheres* argument, unchanged.
        sphere_fit: list of [file name, radius, center] for each marker (Hyper fit).
    """
    from itertools import combinations

    sphere_fit = []
    for stl_path in spheres:
        radius, center = fit_hypersphere(get_stl_points(stl_path))
        sphere_fit.append([os.path.split(stl_path)[1], radius, center])

    # combinations() yields (i, j) with i < j in the same order as the nested index loops.
    center_dist = []
    for first, second in combinations(sphere_fit, 2):
        d, x_a, y_a, z_a = get_distance(first[2], second[2])
        center_dist.append([first[0], second[0], d, x_a, y_a, z_a])

    return center_dist, spheres, sphere_fit

def adjust_angles(ang_list):
    """Shift angles (radians) by one full turn toward [-pi, pi], in place.

    Values above pi get 2*pi subtracted once, values below -pi get 2*pi added once;
    anything further out than one extra turn stays outside the range, and NaN is
    left unchanged. Returns the same (mutated) sequence.
    """
    full_turn = np.pi * 2
    for idx, angle in enumerate(ang_list):
        if angle > np.pi:
            ang_list[idx] = angle - full_turn
        elif angle < -np.pi:
            ang_list[idx] = angle + full_turn
    return ang_list

def adjust_loads(L, baseline_samples=200):
    """Digitally tare a load channel in place by removing its initial offset.

    Args:
        L: numpy array of load samples; modified in place.
        baseline_samples: number of leading samples averaged to estimate the offset
            (default 200, the value previously hard-coded).

    Returns:
        L, the same (now tared) array.
    """
    L -= np.mean(L[0:baseline_samples])
    return L

def zero_pos(p):
    """Subtract the first sample from every sample of *p* in place and return *p*."""
    origin = p[0]
    p -= origin
    return p

def _plot_tool_trial(data, tool, load, specimen_dir, file):
    """Rotate one tool's loads by its pose, zero the pose and save a 4-panel data-quality figure.

    Args:
        data: parsed TDMS trial from tdsmParserMultis.parseTDMSfile (presumably a
            mapping group -> channel -> samples; confirm).
        tool: tool channel-group name, e.g. 'Scalpel'.
        load: matching load channel-group name, e.g. 'Scalpel Load'.
        specimen_dir: specimen folder; the figure is written to <specimen_dir>/DataQuality.
        file: TDMS file name of the trial, used to name the figure.
    """
    # Raw tool pose: positions /1000 (mm -> m; they are plotted back as mm below),
    # angles converted from degrees to radians.
    pose = data[u'State.' + tool]
    time = data[u'Time'][u'Time']
    x = (pose[u'' + tool + '_x']) / 1000
    y = (pose[u'' + tool + '_y']) / 1000
    z = (pose[u'' + tool + '_z']) / 1000
    r = np.radians(pose[u'' + tool + '_roll'])
    p = np.radians(pose[u'' + tool + '_pitch'])
    w = np.radians(pose[u'' + tool + '_yaw'])

    # Rotation block of the tool pose at every pose sample.
    rotations = np.zeros([len(x), 3, 3])
    for ii in range(len(x)):
        rotations[ii] = get_RB_transformation_matrix(x[ii], y[ii], z[ii], r[ii], p[ii], w[ii])[0:3, 0:3]

    missing = np.sum(np.isnan(x))
    if missing > 0:
        print("Missing {} tool data: {} data points".format(tool, missing))

    channels = data[u'State.' + load]
    Fx = channels[u'' + load + ' Fx']
    Fy = channels[u'' + load + ' Fy']
    Fz = channels[u'' + load + ' Fz']
    Mx = channels[u'' + load + ' Mx']
    My = channels[u'' + load + ' My']
    Mz = channels[u'' + load + ' Mz']
    t = np.arange(0, len(Fx), 1)
    print('{} {}'.format(time[-1], t[-1]))

    # Rotate each load sample (in place) with the pose sampled nearest in time.
    # The absolute difference is essential: argmin of the signed difference always
    # picked the smallest time stamp, i.e. the first pose of the trial.
    time_arr = np.asarray(time)
    for k, sample_time in enumerate(t):
        R = rotations[np.abs(time_arr - sample_time).argmin()]
        Fx[k], Fy[k], Fz[k] = np.dot(R, [Fx[k], Fy[k], Fz[k]])
        Mx[k], My[k], Mz[k] = np.dot(R, [Mx[k], My[k], Mz[k]])

    # Zero positions and angles at beginning of trial
    x = zero_pos(x)
    y = zero_pos(y)
    z = zero_pos(z)
    r = zero_pos(r)
    p = zero_pos(p)
    w = zero_pos(w)

    # Adjust the angles so that they lay between -180 and 180 degrees
    r = adjust_angles(r)
    p = adjust_angles(p)
    w = adjust_angles(w)

    # Loads are not digitally tared here; adjust_loads() can be applied to Fx..Mz if needed.

    fig, (ax1, ax2, ax3, ax4) = plt.subplots(4, 1, figsize=(15, 15))

    ax1.plot(time, x * 1000)
    ax1.plot(time, y * 1000)
    ax1.plot(time, z * 1000)
    ax1.scatter(time, x * 1000)
    ax1.scatter(time, y * 1000)
    ax1.scatter(time, z * 1000)
    ax1.set_xlim([0, time[-1]])
    ax1.set_ylabel('Tool Position')
    ax1.legend(['X (mm)', 'Y (mm)', 'Z (mm)'], loc='center left', bbox_to_anchor=(1, 0.5))

    ax2.plot(time, np.degrees(r))
    ax2.plot(time, np.degrees(p))
    ax2.plot(time, np.degrees(w))
    ax2.scatter(time, np.degrees(r))
    ax2.scatter(time, np.degrees(p))
    ax2.scatter(time, np.degrees(w))
    ax2.set_ylabel('Tool Orientation')
    ax2.set_xlim([0, time[-1]])
    ax2.legend(['roll (deg)', 'pitch (deg)', 'yaw (deg)'], loc='center left', bbox_to_anchor=(1, 0.5))

    ax3.plot(t, Fx)
    ax3.plot(t, Fy)
    ax3.plot(t, Fz)
    ax3.set_ylabel('Tool Forces')
    ax3.set_xlim([0, time[-1]])
    ax3.legend(['Fx (N)', 'Fy (N)', 'Fz (N)'], loc='center left', bbox_to_anchor=(1, 0.5))

    ax4.plot(t, Mx)
    ax4.plot(t, My)
    ax4.plot(t, Mz)
    ax4.set_xlim([0, time[-1]])
    ax4.set_ylabel('Tool Moments')
    ax4.legend(['Mx (Nm)', 'My (Nm)', 'Mz (Nm)'], loc='center left', bbox_to_anchor=(1, 0.5))

    plt.suptitle(tool)
    plt.tight_layout()
    plt.subplots_adjust(right=0.85, top=0.95)
    out_dir = os.path.join(specimen_dir, 'DataQuality')
    if not os.path.exists(out_dir):
        os.makedirs(out_dir)
    fig.savefig(os.path.join(out_dir, file[0:-5] + '_' + tool.replace(" ", "") + '.png'))
    plt.close(fig)


def main(dir, segment, modality):
    """Write tool pose/load data-quality plots for every accepted surgical-tool trial.

    For each TDMS file in <dir>/Data whose first three characters occur in an accepted
    trial name, the matching <dir>/Configuration/<trial>_State.cfg must exist. Each
    tool named by the tool code at file-name position 17 (R, S or F) is then plotted
    to <dir>/DataQuality. Files with any other tool code are reported and skipped.

    Args:
        dir: specimen folder containing 'Data' and 'Configuration'.
        segment: body segment name (not used by this function).
        modality: imaging modality (not used by this function).

    Raises:
        IOError: if a trial's configuration file cannot be read.
    """
    # Tool and load channel groups keyed by the tool code in the file name.
    tool_codes = {
        'R': (['Retractor'], ['Retractor Load']),
        'S': (['Scalpel'], ['Scalpel Load']),
        'F': (['Forceps Left', 'Forceps Right'], ['Forceps Left Load', 'Forceps Right Load']),
    }

    # Get accepted trials from the ultrasound experiment
    masterList, _num_accept = XMLparser.getAcceptance(dir)
    accepted_list = [entry[0] for entry in masterList if entry[1] == 1]
    print(accepted_list)

    dir_data = os.path.join(dir, 'Data')
    for file in sorted(os.listdir(dir_data)):
        if not any(file[0:3] in s for s in accepted_list):
            continue
        data = tdsmParserMultis.parseTDMSfile(os.path.join(dir_data, file))

        config = ConfigParser.RawConfigParser()
        if not config.read(os.path.join(dir, 'Configuration', file[0:-5] + '_State.cfg')):
            raise IOError("Cannot load configuration file... Check path.")

        print('=====================')
        print(file)

        code = file[17:18]
        if code not in tool_codes:
            # Skip rather than silently reusing the previous trial's tool list.
            print("Unknown tool code in file name: {}".format(file))
            print("")
            continue

        tools, loads = tool_codes[code]
        for tool, load in zip(tools, loads):
            _plot_tool_trial(data, tool, load, dir, file)

        print("")




if __name__ == "__main__":
    # Cadaver specimen folder to process; change segment / modality below as needed.
    specimen_dir = '/home/morrile2/Documents/Multis/app/InstrumentedSurgicalTools/SMULTIS008-1'
    main(specimen_dir, segment='UpperLeg', modality='CT')