
''' Script to visualize the surgical tool trials, all transformed to the image (model) coordinate system.'''

import tdsmParserMultis
import os
import ConfigParser
import XMLparser
import numpy as np
from sys import stderr
import math
from mayavi import mlab
import tkFileDialog

# To access any VTK object, we use 'tvtk', which is a Python wrapping of
# VTK replacing C++ setters and getters by Python properties and
# converting numpy arrays to VTK arrays when setting data.
from tvtk.tools import visual
from tvtk.api import tvtk
from tvtk.common import configure_input

def Arrow_From_A_to_B(x1, y1, z1, x2, y2, z2, c):
    ''' Draw a tvtk ``visual.arrow`` from (x1, y1, z1) to (x2, y2, z2).

    The arrow is created at the start point with the given colour. Its actor
    is uniformly scaled by the start-to-end distance, and it is aimed along
    the start-to-end vector.

    Args:
        x1, y1, z1: coordinates of the arrow tail.
        x2, y2, z2: coordinates of the arrow head.
        c: colour passed straight to ``visual.arrow``.

    Returns:
        The created ``visual.arrow`` object.
    '''
    dx, dy, dz = x2 - x1, y2 - y1, z2 - z1

    arrow = visual.arrow(x=x1, y=y1, z=z1, color=c)
    arrow.length_cone = 0.4

    length = np.sqrt(dx**2 + dy**2 + dz**2)
    arrow.actor.scale = [length] * 3
    # The position is divided by the same factor as the actor scale.
    # Presumably this is because scaling the actor also scales its position.
    arrow.pos = arrow.pos / length
    arrow.axis = [dx, dy, dz]
    return arrow

def get_transformation_matrix(q1, q2, q3, q4, q5, q6):
    ''' Build the 4x4 homogeneous pose of an Optotrak position sensor.

    The rotation block equals Rz(q6) * Ry(q5) * Rx(q4), with the angles in
    radians. The translation column is (q1, q2, q3). This matrix maps sensor
    coordinates into global coordinates. Callers in this file invert it to
    bring global points into the sensor's coordinate system.

    Returns:
        4x4 float numpy array.
    '''
    c4, s4 = math.cos(q4), math.sin(q4)
    c5, s5 = math.cos(q5), math.sin(q5)
    c6, s6 = math.cos(q6), math.sin(q6)

    return np.array([
        [c6 * c5, c6 * s5 * s4 - s6 * c4, c6 * s5 * c4 + s6 * s4, q1],
        [s6 * c5, s6 * s5 * s4 + c6 * c4, s6 * s5 * c4 - c6 * s4, q2],
        [-s5,     c5 * s4,                c5 * c4,                q3],
        [0.0,     0.0,                    0.0,                    1.0],
    ], dtype=float)

def convertTOarray(data):
    ''' Parse a space-separated numeric config string into a 1-D float array.

    Tokens are split on single spaces. Leading tokens that do not parse as
    floats are skipped, and parsing starts at the first numeric token. The
    last character of the final token is always dropped, because the raw
    config value ends with a closing quote. NOTE(review): confirm that every
    caller's string ends with such a character.

    Args:
        data: raw config string, e.g. the 't_sensor2_rb2' value read in main.

    Returns:
        1-D numpy float array of the parsed values.

    Raises:
        ValueError: if no token parses as a float, or a later token is not numeric.
    '''
    tokens = data.split(" ")

    # Find where the numbers start; anything before that is treated as a header.
    start = None
    for idx, token in enumerate(tokens):
        try:
            float(token)
        except ValueError:
            continue
        start = idx
        break
    if start is None:
        raise ValueError("convertTOarray: no numeric values found in {!r}".format(data))

    values = tokens[start:]
    # Strip the trailing character (the closing quote) from the last value.
    values[-1] = values[-1][:-1]
    # A list comprehension, not map(): np.asarray(map(...)) breaks on Python 3.
    return np.asarray([float(v) for v in values])

def get_distance(A, B):
    ''' Return the length of the vector A - B and its angles to the coordinate axes.

    Args:
        A, B: 3-D points, as sequences or numpy arrays of length 3.

    Returns:
        (length, x_angle, y_angle, z_angle) as Python floats. The angles are in
        degrees, measured between A - B and the +x, +y and +z axes. When A and
        B coincide the direction is undefined, so the three angles are NaN.
    '''
    vector = np.asarray(A, dtype=float) - np.asarray(B, dtype=float)
    length = math.sqrt(float(np.dot(vector, vector)))
    if length == 0.0:
        # Coincident points: return NaN angles instead of dividing by zero.
        nan = float('nan')
        return 0.0, nan, nan, nan
    # Direction cosines. Clip them because rounding can push a cosine slightly
    # outside [-1, 1], where math.acos raises "math domain error".
    cosines = np.clip(vector / length, -1.0, 1.0)
    x_angle, y_angle, z_angle = [math.degrees(math.acos(c)) for c in cosines]
    return float(length), float(x_angle), float(y_angle), float(z_angle)

def get_xyzrpw(T):
    ''' Split a 4x4 homogeneous transform into translation and roll/pitch/yaw.

    The angle formulas assume the rotation block is Rz(yaw) * Ry(pitch) * Rx(roll),
    the same layout get_transformation_matrix builds. Angles are in radians.
    With pitch strictly inside (-pi/2, pi/2) they recover that function's
    q4/q5/q6.

    Returns:
        (x, y, z, roll, pitch, yaw)
    '''
    yaw = math.atan2(T[1, 0], T[0, 0])
    pitch = math.atan2(-T[2, 0], math.sqrt(T[2, 1]**2 + T[2, 2]**2))
    roll = math.atan2(T[2, 1], T[2, 2])
    x, y, z = T[0, 3], T[1, 3], T[2, 3]
    return x, y, z, roll, pitch, yaw

def fit_hypersphere(data, method="Hyper"):
    """Fit a hypersphere to a collection of points and return (radius, center).

    FitHypersphere.py -- Kevin Karplus, Mon Apr 23 04:08:05 PDT 2012,
    generalized to n dimensions. Based on the Hyper (hyperaccurate) algorithm of
    Ali Al-Sharadqah and Nikolai Chernov, "Error analysis for circle fitting
    algorithms", Electronic Journal of Statistics, Vol. 3 (2009) 886-911,
    DOI: 10.1214/09-EJS419.

    Creative Commons Attribution-ShareAlike 3.0 Unported License.
    http://creativecommons.org/licenses/by-sa/3.0/

    Available "algebraic" fitting methods, all computed via SVD:
        Hyper   Al-Sharadqah and Chernov's Hyperfit algorithm (default; usually the best choice)
        Pratt   Vaughn Pratt's algorithm
        Taubin  G. Taubin's algorithm
    Kasa's algorithm is not implemented. Its constraint matrix N would be
    singular, so N_inv cannot be computed.

    Args:
        data: sequence of points. Each point is a tuple, list or array of the
            same length, which is the dimension of the hypersphere.
        method: "Hyper", "Pratt" or "Taubin".

    Returns:
        (radius, center). With 0 points returns (0, None). With 1 point returns
        (0, data[0]). Otherwise center is a 1-D numpy array.

    Raises:
        ValueError: unknown method, or fewer than dimension + 1 points.
    """
    if method not in ("Hyper", "Pratt", "Taubin"):
        raise ValueError(
            "Error: unknown method: {} should be 'Hyper', 'Pratt', or 'Taubin'".format(method))

    num_points = len(data)
    if num_points == 0:
        return (0, None)
    if num_points == 1:
        return (0, data[0])
    dimen = len(data[0])  # dimensionality of hypersphere

    if num_points < dimen + 1:
        raise ValueError(
            "Error: fit_hypersphere needs at least {} points to fit {}-dimensional sphere, but only given {}".format(
                dimen + 1, dimen, num_points))

    # central: a float copy of the data with the centroid removed (num_points x dimen).
    central = np.matrix(data, dtype=float)
    centroid = np.mean(central, axis=0)
    central -= centroid

    # Squared magnitude of each centred point, as a num_points x 1 column vector.
    square_mag = np.sum(np.multiply(central, central), axis=1)

    if method == "Taubin":
        # Matrix of normalized squared magnitudes and data.
        mean_square = square_mag.mean()
        data_Z = np.bmat([[(square_mag - mean_square) / (2 * math.sqrt(mean_square)), central]])
        u, s, v = np.linalg.svd(data_Z, full_matrices=False)
        param_vect = v[-1, :]
        params = list(np.asarray(param_vect)[0])  # (dimen+1)-element row -> list
        params[0] /= 2 * math.sqrt(mean_square)
        params.append(-mean_square * params[0])
        params = np.array(params)
    else:
        # Matrix of squared magnitudes, data, and a column of 1s.
        data_Z = np.bmat([[square_mag, central, np.ones((num_points, 1))]])

        # numpy's svd returns data_Z = u * diag(s) * v, where v is already the
        # conjugate transpose. MATLAB's convention differs.
        u, s, v = np.linalg.svd(data_Z, full_matrices=False)

        if s[-1] / s[0] < 1e-12:
            # Singular case: the points satisfy a sphere (or plane) equation
            # exactly, so the fit is the null vector of data_Z. That is the last
            # ROW of v here; Chernov's MATLAB code uses the last COLUMN.
            param_vect = v[-1, :]
        else:
            Y = v.H * np.diag(s) * v
            Y_inv = v.H * np.diag([1. / x for x in s]) * v
            # Ninv is the inverse of the constraint matrix, after the centroid has been removed.
            Ninv = np.asmatrix(np.identity(dimen + 2, dtype=float))
            if method == "Hyper":
                Ninv[0, 0] = 0
                Ninv[0, -1] = 0.5
                Ninv[-1, 0] = 0.5
                Ninv[-1, -1] = -2 * square_mag.mean()
            else:  # "Pratt" (method was validated on entry)
                Ninv[0, 0] = 0
                Ninv[0, -1] = -0.5
                Ninv[-1, 0] = -0.5
                Ninv[-1, -1] = 0

            # Take the eigenvector for the smallest positive eigenvalue.
            matrix_for_eigen = Y * Ninv * Y
            eigen_vals, eigen_vects = np.linalg.eigh(matrix_for_eigen)

            positives = [x for x in eigen_vals if x > 0]
            if len(positives) + 1 != len(eigen_vals):
                stderr.write(
                    "Warning: for method {} exactly one eigenvalue should be negative: {}\n".format(
                        method, eigen_vals))
            smallest_positive = min(positives)
            A_colvect = eigen_vects[:, list(eigen_vals).index(smallest_positive)]
            # Map back through Y inverse to get the algebraic sphere parameters.
            param_vect = (Y_inv * A_colvect).transpose()

        # Both branches leave param_vect as a 1 x (dimen+2) matrix. The flatten
        # must run for the singular branch too, or `params` is unbound below.
        params = np.asarray(param_vect)[0]

    radius = 0.5 * math.sqrt(sum(a * a for a in params[1:-1]) - 4 * params[0] * params[-1]) / abs(params[0])
    center = -0.5 * params[1:-1] / params[0]
    center += np.asarray(centroid)[0]
    return (radius, center)

def transform_sphere_points(B1m, B1mr):
    ''' Fit the centers of the three registration-marker spheres in the bone Optotrak sensor coordinate system.

    B1m holds 30 digitized global points (x, y, z in m). B1mr holds the 30
    matching sensor poses (x, y, z in m; roll, pitch, yaw in rad). Both are
    quoted, space-separated config strings whose first two tokens are skipped.
    Points 0-9, 10-19 and 20-29 are treated as three separate spheres (A, B
    and C). Each sphere is fitted with the Pratt method.

    Returns:
        (B1mSphereA, B1mSphereB, B1mSphereC), each a (radius, center) pair.
        A sphere whose fit fails (e.g. the points were not digitized) is
        reported as (0, (nan, nan, nan)).
    '''
    # 30 points = 3 spheres x 10 points. transform_digitized_points does the
    # identical parse-and-transform into sensor coordinates for any point count.
    coords = transform_digitized_points(B1m, B1mr, 30)

    NAN = float('nan')
    fits = []
    for start in (0, 10, 20):
        try:
            fits.append(fit_hypersphere(coords[start:start + 10, 0:3], method="Pratt"))
        except Exception:
            # Best effort: a missing or degenerate sphere becomes NaN instead of
            # aborting the other fits. Using Exception (not a bare except) lets
            # KeyboardInterrupt/SystemExit still propagate.
            fits.append((0, (NAN, NAN, NAN)))

    B1mSphereA, B1mSphereB, B1mSphereC = fits
    return B1mSphereA, B1mSphereB, B1mSphereC

def find_file(name, path):
    ''' Return the path of the first file called `name` found under directory `path`.

    The tree is searched top-down in os.walk order. Returns None when no such
    file exists, including when `path` itself does not exist.
    '''
    matches = (os.path.join(folder, name)
               for folder, _subdirs, filenames in os.walk(path)
               if name in filenames)
    return next(matches, None)

def get_stl_points(file):
    ''' Load an STL file with tvtk and return its vertex coordinates.

    Args:
        file: path of the STL file.

    Returns:
        The reader output's points converted with ``to_array()``. This is
        presumably an (n, 3) numpy array; TODO confirm against tvtk.
    '''
    stl_reader = tvtk.STLReader()
    stl_reader.file_name = file
    stl_reader.update()
    return stl_reader.output.points.data.to_array()

def get_stl_dist(subj_path, modality, segment):
    ''' Fit spheres to user-selected fiducial-marker STL files and measure the distances between their centers.

    Opens a Tk multi-file dialog starting in ``subj_path/modality``, fits a
    sphere (fit_hypersphere, default method) to every selected STL, and
    measures every unordered pair of fitted centers with get_distance.

    Args:
        subj_path: subject directory.
        modality: sub-folder of subj_path the dialog starts in (e.g. 'CT').
        segment: segment name shown in the dialog title.

    Returns:
        center_dist: list of [file_name_1, file_name_2, distance, x_angle,
            y_angle, z_angle], one per pair, in selection order.
        spheres: the paths returned by the file dialog.
        sphere_fit: list of [file_name, radius, center], one per selected file.
    '''
    import itertools

    spheres = tkFileDialog.askopenfilenames(
        title="Select all " + segment + " fiducial marker STL files",
        initialdir=os.path.join(subj_path, modality))

    sphere_fit = []
    for path in spheres:
        radius, center = fit_hypersphere(get_stl_points(path))
        sphere_fit.append([os.path.split(path)[1], radius, center])

    # combinations() yields (0,1), (0,2), ..., (1,2), ... -- the same order as
    # the nested index loops it replaces.
    center_dist = []
    for (name_1, _, center_1), (name_2, _, center_2) in itertools.combinations(sphere_fit, 2):
        d, x_a, y_a, z_a = get_distance(center_1, center_2)
        center_dist.append([name_1, name_2, d, x_a, y_a, z_a])

    return center_dist, spheres, sphere_fit

def _parse_config_floats(text, count):
    ''' Return the `count` floats that follow the first two tokens of a quoted, space-separated config string.'''
    # Quotes are removed, then the string is split on single spaces. The first
    # two tokens are skipped (presumably array dimensions -- TODO confirm).
    tokens = text.replace('"', '').split(" ")[2:2 + count]
    if len(tokens) < count:
        raise ValueError("expected {} values in config string, found {}".format(count, len(tokens)))
    return np.array([float(t) for t in tokens])


def transform_digitized_points(B1m, B1mr, N):
    ''' Express N digitized global points in the Optotrak position-sensor coordinate system.

    Args:
        B1m: quoted, space-separated config string. After two skipped leading
            tokens it holds N consecutive (x, y, z) points in m (global).
        B1mr: config string in the same layout. It holds N sensor poses
            (x, y, z in m; roll, pitch, yaw in rad), one per digitized point.
        N: number of points to read.

    Returns:
        (N, 3) numpy array. Row i is point i transformed by the inverse of the
        i-th sensor pose (get_transformation_matrix), i.e. expressed in the
        sensor's coordinate system.

    Raises:
        ValueError: if either string holds fewer values than N points need.
    '''
    points = _parse_config_floats(B1m, 3 * N).reshape(N, 3)
    poses = _parse_config_floats(B1mr, 6 * N).reshape(N, 6)

    coords = np.zeros((N, 3))
    for i in range(N):
        T = get_transformation_matrix(*poses[i])
        # The pose maps sensor -> global, so invert it to bring the global
        # point into the sensor's coordinate system.
        homogeneous = np.append(points[i], 1.0)
        coords[i] = np.dot(np.linalg.inv(T), homogeneous)[0:3]
    return coords

def main(dir, segment):
    ''' Register each surgical tool's digitized divot points to its CAD model and visualize the result.

    Trial files come from ``<dir>/Data``. Only files whose first three
    characters occur in an accepted trial name are used, and they are further
    restricted to names containing '005' in those three characters. For each
    such file this loads ``<dir>/Configuration/<trial>_State.cfg`` and, for
    each tool in TOOLS:

    * reads the tool's CAD points from ``divot_points_in_world/<Tool>Points2.dat.txt``;
    * computes the rigid transform T_CAD_TOS (tool Optotrak sensor -> CAD)
      from the handle divot points, using an SVD (Kabsch-style) fit;
    * prints the distance between the CAD tool origin and the digitized
      origin, then each point's registration distance and the RMSE (mm);
    * draws the points, the tool axes and the tool STL in a new Mayavi figure.

    ``mlab.show()`` is called once per trial file, after all tools are drawn.

    Args:
        dir: subject (cadaver specimen) directory containing ``Data`` and ``Configuration``.
        segment: segment name. Only 'UpperLeg' is recognised, and the values it
            sets are not used (see note below).

    Raises:
        IOError: if the trial's configuration file cannot be read.
    '''

    # Assign seg and bone parameters for later calculations.
    # NOTE(review): seg, bone and markers are not used anywhere below. They are
    # also only defined when segment == 'UpperLeg'.
    if segment == 'UpperLeg':
        seg = 'UL_'
        bone = 'Femur'
        markers = ["F1", "F2", "F3", "F4", "F5", "F6"]

    # Get accepted trials from the surgical tools experiment. masterList entries
    # appear to be (trial name, acceptance flag) pairs; keep the names flagged 1.
    masterList, num_accept = XMLparser.getAcceptance(dir)
    accepted_list = []

    for x in masterList:
        if x[1] == 1:
            accepted_list.append(x[0])
    dir_data = os.path.join(dir, 'Data')
    files = os.listdir(dir_data)
    files.sort()

    for file in files:

        # Accepted trials only. NOTE(review): the '005' test hard-codes a single
        # trial; it looks like a debugging filter, so confirm before batch use.
        if any(file[0:3] in s for s in accepted_list) and '005' in file[0:3]:
            print file
            # NOTE(review): `data` is not used further in this function.
            data = tdsmParserMultis.parseTDMSfile(os.path.join(dir_data, file))

            # The config name drops the last 5 characters of the data file name
            # (presumably the '.tdms' extension).
            config = ConfigParser.RawConfigParser()
            if not config.read(os.path.join(os.path.join(dir, 'Configuration'), file[0:-5]+'_State.cfg')):
                raise IOError, "Cannot load configuration file... Check path."

            # Get the transformation matrix between CAD points and the tool sensor coordinate system for ALL tools.
            # NOTE(review): T_CAD_TOS_dict is filled but never read in this function.
            TOOLS = ['Forceps Left', 'Forceps Right', 'Retractor', 'Scalpel', 'Probe']
            T_CAD_TOS_dict = {}
            for t in TOOLS:
                v = mlab.figure()  # Create Mayavi figure (one per tool)
                visual.set_viewer(v)

                # Columns 1-3 are the point coordinates; column 0 is the point name.
                CAD_data = np.genfromtxt(os.path.join('divot_points_in_world', t.replace(" ", "")+'Points2.dat.txt'), delimiter='\t', skip_header=1, usecols=(1,2,3))
                CAD_data_rNames = np.genfromtxt(os.path.join('divot_points_in_world', t.replace(" ", "")+'Points2.dat.txt'), delimiter='\t', skip_header=1, usecols=(0), dtype='S')
                # Find all of the load cell (handle) divot points in the CAD CS for registration to the digitized points.
                CAD_pts = CAD_data[[i for i, s in enumerate(CAD_data_rNames) if 'Handle' in s and 'Div' in s]]

                # Tool-specific points sit at the end of the CAD file. Which row is the
                # tool origin (compared with the digitized origin below) differs per tool.
                if t == 'Forceps Left':
                    CAD_pts_2 = CAD_data[-3:]
                    CAD_origin = CAD_pts_2[2]
                elif t == 'Forceps Right':
                    CAD_pts_2 = CAD_data[-4:-1]
                    CAD_origin = CAD_pts_2[1]
                elif t == 'Retractor':
                    # NOTE(review): the original author flagged this Retractor slice/origin as incorrect:
                    # THIS IS WRONG RIGHT NOW!!!
                    CAD_pts_2 = CAD_data[-4:]
                    CAD_origin = CAD_data[-2]
                elif t == 'Scalpel':
                    CAD_pts_2 = CAD_data[-4:]
                    CAD_origin = CAD_pts_2[-1]
                elif t == 'Probe':
                    CAD_pts_2 = CAD_data[-3:]
                    CAD_origin = CAD_pts_2[1]

                for pt in CAD_pts_2:
                    mlab.points3d(*pt, color=(1,0,0))

                # Digitized handle divot points from the experiment ('<tool> LC2REF' section)
                B1m = config.get(t+' LC2REF', 'Collected Points Rigid Body 2 (m)')  # Digitized points (global)
                B1mr = config.get(t+' LC2REF',
                                  'Collected Points Rigid Body 2 Position Sensor (m,rad)')  # Pose of the tool Optotrak sensor for each point

                # Assumes one digitized point per CAD handle divot, in the same order.
                digitized_points = transform_digitized_points(B1m, B1mr, len(CAD_pts))  # digitized points in tool optotrak sensor coordinate system

                # Additional tool points from the '<tool>' section (10 for the Probe, otherwise 3)
                B1m = config.get(t, 'Collected Points Rigid Body 2 (m)')
                B1mr = config.get(t, 'Collected Points Rigid Body 2 Position Sensor (m,rad)')

                if t == 'Probe':
                    len_pts_CS = 10
                else:
                    len_pts_CS = 3

                digitized_points_CS = transform_digitized_points(B1m, B1mr, len_pts_CS)

                # Create the transformation matrix from CAD points and digitized points (SVD / Kabsch method).
                # The CAD values are divided by 1000, so they appear to be in mm while the digitized points are in m.
                x = np.array(digitized_points)  # digitized points (tool sensor CS)
                y = np.array(CAD_pts) / 1000  # CAD points

                x_bar = np.mean(x, axis=0)
                y_bar = np.mean(y, axis=0)

                A = (x - x_bar).transpose()
                B = (y - y_bar).transpose()

                # Cross-covariance of the centred point sets
                C = np.dot(B, A.transpose())

                P, D, QT = np.linalg.svd(C, full_matrices=True)

                # det term forces a proper rotation (no reflection); Rot and d satisfy y ~= Rot x + d
                Rot = np.dot(np.dot(P, np.diag([1, 1, np.linalg.det(np.dot(P, QT))])), QT)
                d = y_bar.transpose() - np.dot(Rot, x_bar.transpose())

                # Define the transformation matrix of the Tool Optotrak Sensor in the CAD CS
                T_CAD_TOS = np.zeros((4, 4))
                T_CAD_TOS[0:3, 0:3] = Rot
                T_CAD_TOS[3, 3] = 1
                T_CAD_TOS[0:3, 3] = d

                # 't_sensor2_rb2' from the config, reshaped to 4x4. T_CAD_TOS * T_TOS_T maps
                # tool-frame points into the CAD CS (m). Presumably T_TOS_T is the tool frame
                # in tool-sensor coordinates -- TODO confirm.

                print
                print t
                T_TOS_T = config.get(t, 't_sensor2_rb2')
                T_TOS_T = convertTOarray(T_TOS_T).reshape(4, -1)

                # Digitized tool origin in the CAD CS, scaled from m to mm
                origin = np.dot(np.dot(T_CAD_TOS, T_TOS_T), [0,0,0,1])[0:3]*1000
                mlab.points3d(*origin, color=(0,1,0), scale_factor=2)

                # Distance between the tool tip (origin) from CAD and from the digitized points
                print "Distance between origin of Digitized and CAD: {} mm".format(np.sqrt(np.sum((origin-CAD_origin)**2)))

                # Tool axes: points 0.01 m (10 mm) along each tool axis, in mm
                xaxis = np.dot(np.dot(T_CAD_TOS, T_TOS_T), [.01,0,0,1])[0:3]*1000
                yaxis = np.dot(np.dot(T_CAD_TOS, T_TOS_T), [0, .01, 0, 1])[0:3]*1000
                zaxis = np.dot(np.dot(T_CAD_TOS, T_TOS_T), [0, 0, .01, 1])[0:3]*1000

                Arrow_From_A_to_B(*(list(origin)+list(xaxis)+list([(1,0,0)])))
                Arrow_From_A_to_B(*(list(origin) + list(yaxis) + list([(0, 1, 0)])))
                Arrow_From_A_to_B(*(list(origin) + list(zaxis) + list([(0, 0, 1)])))


                # Extra tool points mapped into the CAD CS (mm)
                for i in digitized_points_CS:
                    pt = np.dot(T_CAD_TOS, np.append(i, 1))[0:3]*1000
                    mlab.points3d(*pt, color = (0,1,0))

                # Calculate the RMS error over the handle divot points (mm).
                # colors = [(1,1,0), (0,1,1), (1,0,1), (1,0,0), (0,1,0), (0,0,1), (.8, 1, .2), (0,0,0) ] #Used if you want to plot each point as a different color
                SUM = 0
                for i, pt in enumerate(zip(CAD_pts,digitized_points)):
                    pt_CAD = pt[0]
                    pt_DIG = np.dot(T_CAD_TOS, np.append(pt[1],1))[0:3]*1000
                    print np.sqrt(np.sum((pt_CAD - pt_DIG)**2)) # This is the distance between corresponding points
                    SUM += np.sum((pt_CAD - pt_DIG)**2) #Distance between points squared
                    mlab.points3d(*pt_CAD, color = (1,0,0), scale_factor=2)
                    mlab.points3d(*pt_DIG, color = (0,1,0), scale_factor=2)
                print "RMSE = {} mm".format(np.sqrt(SUM/len(CAD_pts)))
                T_CAD_TOS_dict[t] = T_CAD_TOS
                print '+++++++++++++++++++++++'

                # Render the tool STL(s) semi-transparently. The meshes are added without any
                # transform, i.e. in their own file coordinates (presumably the CAD CS).
                if "Forceps" in t:

                    # NOTE(review): both forceps halves are drawn for either forceps tool.
                    filename2 = os.path.join('ToolSTLs', 'forcepsRight.stl')
                    filename2_left = os.path.join('ToolSTLs', 'forcepsLeft.stl')

                    reader = tvtk.STLReader()
                    reader.file_name = filename2
                    reader.update()

                    mapper = tvtk.PolyDataMapper()
                    configure_input(mapper, reader.output)

                    prop = tvtk.Property(opacity=.8, color = (1,1,1))

                    v.scene.add_actor(tvtk.Actor(mapper=mapper, property=prop))

                    reader_L = tvtk.STLReader()
                    reader_L.file_name = filename2_left
                    reader_L.update()

                    mapper_L = tvtk.PolyDataMapper()
                    configure_input(mapper_L, reader_L.output)

                    prop_L = tvtk.Property(opacity=.8, color=(1,1,1))

                    v.scene.add_actor(tvtk.Actor(mapper=mapper_L, property=prop_L))

                else:
                    filename2 = os.path.join('ToolSTLs', t.lower() + '.stl')

                    reader = tvtk.STLReader()
                    reader.file_name = filename2
                    reader.update()

                    mapper = tvtk.PolyDataMapper()
                    configure_input(mapper, reader.output)

                    prop = tvtk.Property(opacity=.8, color=(1,1,1))

                    v.scene.add_actor(tvtk.Actor(mapper=mapper, property=prop))

            # Show all figures created for this trial file
            mlab.show()

if __name__ == "__main__":
    # Cadaver specimen folder. The commented-out alternatives are other users' machines.
    # Named subject_dir (not dir) so the builtin dir() is not shadowed at module level.
    subject_dir = '/home/morrile2/Documents/MULTIS Data/MULTIS_surgtools/SMULTIS004-1'
    # subject_dir = '/Users/schimmt/Documents/SurgicalData/SMULTIS030-1'
    # subject_dir = '/Users/schimmt/Multis/app/InstrumentedSurgicalTools/SMULTIS031-1'
    segment = 'UpperLeg'  # Change this to the desired segment
    modality = 'CT'  # Imaging modality folder; only used if get_stl_dist is called
    main(subject_dir, segment)