# This script takes the CSV of manually chosen anatomical landmarks, calculates additional anatomical landmarks,
# and saves the result as a new CSV.

# Python 3 version of the AnatomicalLandmarks script. As of 02/06/2020, all
# future changes will be made only in this version of the script (AnatomicalLandmarks_p3.py).

import numpy as np
from scipy import spatial
import math
import pylab as pl
import FebCustomization_p3
from lxml import etree as et


class Geometry_Part:
    """A named mesh part read from the Geometry section of a FEBio input file.

    Attributes
    ----------
    name : str
        Part name (the ``name`` attribute of its <Nodes> section).
    nodes : array
        Node coordinates, one row per node.
    node_ids : array
        FEBio node ids, matching the rows of ``nodes``.
    elems : array or None
        Element connectivity (node ids), one row per element.
    elem_ids : array or None
        FEBio element ids, matching the rows of ``elems``.
    nodesets : dict
        Nodeset name -> array of node ids.
    surfaces : dict
        Surface name -> array of surface-element connectivity.
    """

    def __init__(self, name, nodes, node_ids, elems = None, elem_ids = None, nodesets = None, surfaces = None):
        self.name = name
        self.nodes = nodes
        self.node_ids = node_ids
        self.elems = elems
        self.elem_ids = elem_ids
        # Keep caller-supplied sets (they used to be silently dropped). Each
        # instance gets its own dict so parts never share mutable state.
        self.nodesets = {} if nodesets is None else dict(nodesets)
        self.surfaces = {} if surfaces is None else dict(surfaces)


def element_normals(nodes, elems, node_ids):
    """Compute unit normals and centroids of (surface) elements.

    Parameters
    ----------
    nodes : (n_nodes, 3) array
        Node coordinates.
    elems : (n_elems, k) array, k >= 3
        Element connectivity given as node *ids*. It may be a float array, as
        produced when surfaces are read. Only the first three nodes of each
        element are used.
    node_ids : (n_nodes,) array
        Node id of each row of ``nodes``.

    Returns
    -------
    normals : (n_elems, 3) array
        Unit normals, oriented as (P2 - P1) x (P3 - P1).
    midpoints : (n_elems, 3) array
        Centroid of the first three nodes of each element.

    Raises
    ------
    ValueError
        If an element references a node id that is not in ``node_ids``.
    """
    nodes = np.asarray(nodes)
    elems = np.asarray(elems)
    node_ids = np.asarray(node_ids)

    # Map node ids -> row indices of `nodes` with a sorted lookup. This replaces
    # a tiled (n_elem_nodes x n_nodes) difference matrix, which grew
    # quadratically in memory and time with mesh size.
    order = np.argsort(node_ids, kind='stable')
    sorted_ids = node_ids[order]
    elems_flat = elems.ravel()
    pos = np.searchsorted(sorted_ids, elems_flat)

    found = pos < sorted_ids.size
    found[found] = sorted_ids[pos[found]] == elems_flat[found]
    if not np.all(found):
        missing = np.unique(elems_flat[~found])
        raise ValueError('elements reference unknown node ids: {}'.format(missing[:10]))

    # back to the (n_elems, k) shape, now holding row indices into `nodes`
    new_elems = order[pos].reshape(elems.shape)

    points = nodes[new_elems]
    P1 = points[:, 0]
    P2 = points[:, 1]
    P3 = points[:, 2]

    midpoints = (P1 + P2 + P3) / 3

    normals = np.cross(P2 - P1, P3 - P1)
    norms = np.linalg.norm(normals, axis=1)
    normals = normals / norms[:, None]

    return normals, midpoints

def oriented_bounding_box(Nodes):
    """Compute a PCA-aligned (oriented) bounding box of a point cloud.

    The box axes are the eigenvectors of the points' covariance matrix.

    Parameters
    ----------
    Nodes : (n, 3) array-like
        Point coordinates.

    Returns
    -------
    edge_lengths : (3,) array
        Lengths of the three box sides.
    edge_vectors : (3, 3) array
        Unit vector along each side, one per row, matching ``edge_lengths``.
        A zero-length side (e.g. a planar point set) yields a NaN row.
    center : (3,) array
        Box centre in the original coordinate system.
    """
    Nodes = np.asarray(Nodes, dtype=float)

    ca = np.cov(Nodes, y=None, rowvar=False, bias=True)

    # The covariance matrix is symmetric, so use eigh. It always returns real,
    # orthonormal eigenvectors. np.linalg.eig gives no orthogonality guarantee
    # (e.g. with repeated eigenvalues), which would turn the rotation below
    # into a shear and distort the box.
    _, vect = np.linalg.eigh(ca)
    tvect = vect.T  # rows are the principal axes

    # Rotate the points into the principal frame. For an orthonormal matrix the
    # inverse is the transpose, so inv(tvect) == vect.
    ar = Nodes @ vect

    # min / max along each principal axis
    mina = ar.min(axis=0)
    maxa = ar.max(axis=0)
    diff = (maxa - mina) * 0.5

    # the centre is half way between the min and max
    center = mina + diff

    # the 8 corners: centre +/- half the box extent along each axis
    corners = np.array([center + [-diff[0], -diff[1], -diff[2]], center + [diff[0], -diff[1], -diff[2]],
                        center + [diff[0], diff[1], -diff[2]], center + [-diff[0], diff[1], -diff[2]],
                        center + [-diff[0], -diff[1], diff[2]], center + [diff[0], -diff[1], diff[2]],
                        center + [diff[0], diff[1], diff[2]], center + [-diff[0], diff[1], diff[2]]])

    # rotate the corners and the centre back into the original frame
    corners = corners @ tvect
    center = center @ tvect

    # the three edges that meet at corner 0
    edges = np.array([corners[1] - corners[0], corners[4] - corners[0], corners[3] - corners[0]])
    edge_lengths = np.linalg.norm(edges, axis=1)
    edge_vectors = edges / edge_lengths[:, None]

    return edge_lengths, edge_vectors, center


def ReadGeometry(febfile):
    """Read the Geometry section of a FEBio input file into Geometry_Parts.

    A Geometry_Part is created for every named <Nodes> section. A named
    <Elements> section is attached to the part with the same name. <NodeSet>
    and <Surface> sections are attached to the part whose name is the text
    before the first underscore of the set name (e.g. 'FMB_All_Faces' ->
    'FMB'). Unnamed sections and XML comments are skipped. Sets that refer to
    a part that does not exist are skipped with a printed warning.

    Returns
    -------
    dict
        Part name -> Geometry_Part.
    """

    geo_tree, _ = FebCustomization_p3.find_geometry_tree(febfile)
    febio_spec_root_geo = geo_tree.getroot()
    geometry_section = FebCustomization_p3.get_section('Geometry', febio_spec_root_geo)

    All_Geometry = {}

    def get_data(geom):
        """Return {id: [floats]} for every '<x id="..">a,b,c</x>' child of geom.

        Children without an integer id or with non-numeric text (e.g. XML
        comments) are skipped.
        """
        data_dict = {}
        for point in geom:
            try:
                point_id = int(point.attrib['id'])
                point_data = [float(x) for x in point.text.split(',')]
            except (KeyError, AttributeError, TypeError, ValueError):
                continue  # comment or non-data child
            data_dict[point_id] = point_data
        return data_dict

    def owning_part(set_name, kind):
        """Return the part a NodeSet/Surface belongs to, or None (with a warning)."""
        part_name = set_name.split('_')[0]
        part = All_Geometry.get(part_name)
        if part is None:
            print('\n Warning: {} "{}" skipped, no part named "{}"'.format(kind, set_name, part_name))
        return part

    Nodes_sections = FebCustomization_p3.get_section('Nodes', geometry_section)
    Element_sections = FebCustomization_p3.get_section('Elements', geometry_section)
    Nodeset_sections = FebCustomization_p3.get_section('NodeSet', geometry_section)
    Surface_sections = FebCustomization_p3.get_section('Surface', geometry_section)

    # first create parts from all the Nodes sections
    for node_data in Nodes_sections:
        nodeset_name = node_data.get('name')
        if nodeset_name is None:  # XML comment or unnamed section
            continue
        node_dict = get_data(node_data)
        node_ids = np.asarray(list(node_dict.keys()))
        nodes = np.asarray(list(node_dict.values()))
        All_Geometry[nodeset_name] = Geometry_Part(nodeset_name, nodes, node_ids)

    # then add the elements to those parts
    for elem_data in Element_sections:
        elemset_name = elem_data.get('name')
        if elemset_name is None:
            continue
        part = All_Geometry.get(elemset_name)
        if part is None:
            print('\n Warning: Elements "{}" skipped, no Nodes section with that name'.format(elemset_name))
            continue
        elem_dict = get_data(elem_data)
        part.elems = np.asarray(list(elem_dict.values())).astype(int)
        part.elem_ids = np.asarray(list(elem_dict.keys())).astype(int)

    # add the nodesets to the parts
    for nodeset in Nodeset_sections:
        nodeSet_name = nodeset.get('name')
        if nodeSet_name is None:
            continue
        part = owning_part(nodeSet_name, 'NodeSet')
        if part is None:
            continue
        node_ids = []
        for node in nodeset:
            try:
                node_ids.append(int(node.attrib['id']))
            except (KeyError, TypeError, ValueError):
                pass  # comment or non-node child
        part.nodesets[nodeSet_name] = np.asarray(node_ids)

    # add the surfaces to the parts
    for surface in Surface_sections:
        surface_name = surface.get('name')
        if surface_name is None:
            continue
        part = owning_part(surface_name, 'Surface')
        if part is None:
            continue
        elem_dict = get_data(surface)
        part.surfaces[surface_name] = np.asarray(list(elem_dict.values()))

    return All_Geometry


def closest_point_index(node, array):
    """Return the row index of ``array`` whose point is nearest (Euclidean) to ``node``.

    NaN distances are ignored. ``node`` must have as many coordinates as each
    row of ``array``.
    """
    distances = spatial.distance.cdist([node], array)[0]
    return np.nanargmin(distances)


def normalize(v):
    """Return ``v`` scaled to unit Euclidean length.

    A zero-length ``v`` is returned unchanged (the same object).
    """
    length = np.linalg.norm(v)
    return v if length == 0 else v / length


def AddJointLandmarks_oks003_registered(AL, Right_or_Left, all_parts):
    """Add the oks003 joint coordinate systems, built from registered experimental data.

    Hard-coded stand-in for AddJointLandmarks. The tibia and femur frames are
    derived from digitized points that were already registered to the model
    coordinate system, not from the landmarks in ``AL``. The patella frame is
    still built from AL['MPR'] and AL['LPR'] with the same rules as
    AddJointLandmarks. ``all_parts`` is accepted for interface compatibility
    and is not used.

    Returns ``AL``, updated in place.
    """

    def unit(vec):
        # plain normalisation, used for the tibia and femur axes
        return vec / np.linalg.norm(vec)

    def unit_or_zero(vec):
        # zero-safe normalisation (same rule as the module-level normalize())
        length = np.linalg.norm(vec)
        if length == 0:
            return vec
        return vec / length

    # digitized tibia points: T1/T2 give the origin and the x direction,
    # T3..T6 are averaged into the ankle centre
    T1, T2, T3, T4, T5, T6 = np.array([
        [-41.06072145, -12.59595175, -33.09560265],
        [30.08807098, 1.86686417, -42.6212507],
        [-14.29465061, -16.63438655, -387.88052684],
        [-14.30838607, -16.81189603, -388.15705736],
        [-73.14836013, -48.72339693, -399.00929519],
        [-72.72670309, -49.23909687, -399.03084317],
    ])
    # digitized femur points: F1/F2 give the origin and x axis,
    # F3..F6 are averaged into the hip centre
    F1, F2, F3, F4, F5, F6 = np.array([
        [-44.834302, 5.35961468, -8.55682216],
        [36.87319067, 1.81609413, -13.98810562],
        [16.30850255, 18.10361746, 407.9059709],
        [16.30369863, 17.95336429, 408.15742557],
        [1.66932492, -26.76421061, 402.95294433],
        [1.89036519, -25.685348, 402.84165032],
    ])

    # ---- tibia: z from ankle centre to tibial origin ----
    tibial_origin = (T1 + T2) / 2.0
    ankle_center = (T3 + T4 + T5 + T6) / 4.0
    z_tib = unit(tibial_origin - ankle_center)
    AL['TBO'] = tibial_origin
    AL['Zt_axis'] = z_tib

    y_tib = unit(np.cross(z_tib, unit(T2 - T1)))
    x_tib = unit(np.cross(y_tib, z_tib))
    AL['Yt_axis'] = y_tib
    AL['Xt_axis'] = x_tib

    # ---- femur: x from F1 to F2, z towards the hip ----
    femoral_origin = (F1 + F2) / 2.0
    hip_center = (F3 + F4 + F5 + F6) / 4.0
    x_fem = unit(F2 - F1)
    y_fem = unit(np.cross(unit(hip_center - femoral_origin), x_fem))
    z_fem = np.cross(x_fem, y_fem)  # x_fem and y_fem are orthonormal, so no rescale
    AL['FMO'] = femoral_origin
    AL['Xf_axis'] = x_fem
    AL['Yf_axis'] = y_fem
    AL['Zf_axis'] = z_fem

    # ---- tibiofemoral floating axis and its intersection with the femoral x axis ----
    tf_axis = unit(np.cross(z_tib, x_fem))
    tf_plane_normal = np.cross(z_tib, tf_axis)  # plane through TBO containing z_tib and tf_axis
    tf_offset = np.dot(tf_plane_normal, tibial_origin)
    tf_param = (tf_offset - np.dot(tf_plane_normal, femoral_origin)) / np.dot(tf_plane_normal, x_fem)
    AL['Ftf_axis'] = tf_axis
    AL['Ftf_Xf_intersect'] = femoral_origin + tf_param * x_fem

    # ---- patella: same rules as AddJointLandmarks (not from experimental data) ----
    patella_origin = 0.5 * (AL['MPR'] + AL['LPR'])
    x_pat = AL['MPR'] - AL['LPR']
    # x_pat must point laterally for a right knee and medially for a left knee
    if Right_or_Left == 'R':
        x_pat = -x_pat
    x_pat = unit_or_zero(x_pat)
    z_pat = unit_or_zero(np.array([-x_pat[2], 0, x_pat[0]]))
    y_pat = unit_or_zero(np.cross(z_pat, x_pat))

    AL['PTO'] = patella_origin
    AL['Xp_axis'] = x_pat
    AL['Yp_axis'] = y_pat
    AL['Zp_axis'] = z_pat

    print('\n Patella coordinate system found')

    # ---- patellofemoral floating axis and its intersection with the femoral x axis ----
    pf_axis = unit_or_zero(np.cross(z_pat, x_fem))
    pf_plane_normal = np.cross(z_pat, pf_axis)  # plane through PTO containing z_pat and pf_axis
    pf_offset = np.dot(pf_plane_normal, patella_origin)
    pf_param = (pf_offset - np.dot(pf_plane_normal, femoral_origin)) / np.dot(pf_plane_normal, x_fem)
    AL['Fpf_axis'] = pf_axis
    AL['Fpf_Xf_intersect'] = femoral_origin + pf_param * x_fem

    return AL


def AddJointLandmarks_du02_registered(AL, Right_or_Left, all_parts):
    """Add the du02 joint coordinate systems, built from registered experimental landmarks.

    Hard-coded stand-in for AddJointLandmarks. Every origin and axis is
    computed from digitized du02 landmarks that were already registered to the
    model coordinate system, not from the contents of ``AL``. The y-axis signs
    are fixed for a right knee (see the negations below). ``Right_or_Left``
    and ``all_parts`` are accepted for interface compatibility and are not used.

    Returns ``AL``, updated in place with 'TBO', 'Zt_axis', 'Yt_axis',
    'Xt_axis', 'FMO', 'Xf_axis', 'Yf_axis', 'Zf_axis', 'Ftf_axis',
    'Ftf_Xf_intersect', 'PTO', 'Xp_axis', 'Yp_axis', 'Zp_axis', 'Fpf_axis'
    and 'Fpf_Xf_intersect'.
    """

    # digitized landmarks registered to model coordinate system
    tibial_origin =  np.array([-22.093, 21.462, 1.838])
    ankle_center = np.array([-2.256091975587048637e+01, 1.187577442259493665e+01, -1.870634473018965593e+02])
    mtp = np.array([-4.004545533035823013e+01, 2.867889244720712227e+01, -6.045751094052334906e+00])
    ltp = np.array([9.070382386969129129e-01, 2.050263033284607417e+01 ,-1.519371932541432102e+00])
    femoral_origin = np.array([-2.142702978987773577e+01, 4.162980331390022570e+01, 7.701877078606315763e+00])
    hip_center = np.array([-47.57672293,-25.9340359,455.35360123])
    mfc = np.array([-5.162462998188532026e+01 ,4.604576916742635362e+00 ,1.454273615352403226e+01])
    lfc = np.array([-2.732911862043877704e+00 ,6.638116981659738514e-01 ,1.983900548032593747e+01])
    patella_origin = np.array([-1.666189656189223456e+01, 6.287427726552794383e+01, 2.931991595433390785e+01])
    mpr = np.array([-3.205826086325663482e+01, 7.182270720248122586e+01, 2.958480831570104996e+01])
    lpr = np.array([2.395935168577793206e+00, 6.182879669854872162e+01, 2.981003850753829099e+01])
    ptt = np.array([-1.229866733054356587e+01, 6.344522084107743609e+01 ,4.680293497589433116e+01])
    ptb = np.array([-1.314024991833808542e+01 ,6.255110847976774124e+01, 1.501489002447784138e+01])

    # TIBIA: z from the ankle centre to the tibial origin, x = y cross z
    zt = tibial_origin - ankle_center
    zt = zt / np.linalg.norm(zt)

    yt = np.cross(zt, (mtp - ltp))
    yt = -yt / np.linalg.norm(yt)  # negative for right knee

    xt = np.cross(yt, zt)
    xt = xt / np.linalg.norm(xt)

    AL['TBO'] = tibial_origin
    AL['Zt_axis'] = zt
    AL['Yt_axis'] = yt
    AL['Xt_axis'] = xt

    # FEMUR: z from the femoral origin to the hip centre, x = y cross z
    zf = hip_center - femoral_origin
    zf = zf / np.linalg.norm(zf)

    yf = np.cross(zf, (mfc - lfc))
    yf = -yf / np.linalg.norm(yf)  # negative for right knee

    xf = np.cross(yf, zf)
    xf = xf / np.linalg.norm(xf)

    AL['FMO'] = femoral_origin
    AL['Xf_axis'] = xf
    AL['Yf_axis'] = yf
    AL['Zf_axis'] = zf

    # TIBIOFEMORAL floating axis and its intersection with the femoral x axis.
    # The plane through the tibial origin that contains zt and Ftf is
    # intersected with the line femoral_origin + t * xf.
    Ftf = np.cross(zt, xf)
    Ftf = Ftf / np.linalg.norm(Ftf)

    normal = np.cross(zt, Ftf)
    D = np.dot(normal, tibial_origin)
    t = (D - np.dot(normal, femoral_origin)) / np.dot(normal, xf)

    AL['Ftf_axis'] = Ftf
    AL['Ftf_Xf_intersect'] = femoral_origin + (t * xf)

    # PATELLA: x from the medial to the lateral ridge, y from (top - bottom) x xp
    xp = lpr - mpr
    xp = xp / np.linalg.norm(xp)

    # Only yp is assigned here. The femoral axis yf was already stored above
    # and must not be rebound to this unnormalized patella vector.
    yp = np.cross((ptt - ptb), xp)
    yp = yp / np.linalg.norm(yp)

    zp = np.cross(xp, yp)
    zp = zp / np.linalg.norm(zp)

    AL['PTO'] = patella_origin
    AL['Xp_axis'] = xp
    AL['Yp_axis'] = yp
    AL['Zp_axis'] = zp

    # PATELLOFEMORAL floating axis and its intersection with the femoral x axis
    Fpf = np.cross(zp, xf)
    Fpf = Fpf / np.linalg.norm(Fpf)

    normal = np.cross(zp, Fpf)
    D = np.dot(normal, patella_origin)
    t = (D - np.dot(normal, femoral_origin)) / np.dot(normal, xf)

    AL['Fpf_axis'] = Fpf
    AL['Fpf_Xf_intersect'] = femoral_origin + (t * xf)

    return AL


def AddJointLandmarks(AL, Right_or_Left, all_parts):
    """Find the landmarks needed to define the knee joint coordinate systems.

    Builds the tibia, femur and patella coordinate systems from landmarks in
    ``AL``. When the bodies on both sides of a joint are available, it also
    builds the tibiofemoral and patellofemoral floating axes.

    Parameters
    ----------
    AL : dict
        Anatomical landmarks (name -> xyz array). Reads 'MTS', 'LTS', 'MTP'
        and 'LTP' (tibia), 'DFP', 'MFC' and 'LFC' (femur), and 'MPR' and
        'LPR' (patella). Missing landmarks are reported, and the coordinate
        system that needs them is skipped.
    Right_or_Left : str
        'R' for a right knee. Any other value is handled as a left knee.
    all_parts : container of part names
        A body's coordinate system is only built when its part ('TBB', 'FMB',
        'PTB') is in ``all_parts``.

    Returns
    -------
    dict
        ``AL``, updated in place with origins ('TBO', 'FMO', 'PTO'), axes
        ('Xt_axis' ... 'Zp_axis'), floating axes ('Ftf_axis', 'Fpf_axis') and
        their intersections with the femoral x axis ('Ftf_Xf_intersect',
        'Fpf_Xf_intersect').
    """

    def landmarks_present(names, body):
        """Print every name missing from AL; return True when none are missing."""
        missing = [lm for lm in names if lm not in AL]
        for lm in missing:
            print('\n the following landmark was not found for the ' + body + ': ' + lm)
        return not missing

    all_tibia_there = landmarks_present(['MTS', 'LTS', 'MTP', 'LTP'], 'tibia')
    all_femur_there = landmarks_present(['DFP', 'MFC', 'LFC'], 'femur')
    all_patella_there = landmarks_present(['MPR', 'LPR'], 'patella')

    # a coordinate system is built only when its landmarks AND its part exist
    tibia_included = all_tibia_there and 'TBB' in all_parts
    femur_included = all_femur_there and 'FMB' in all_parts
    patella_included = all_patella_there and 'PTB' in all_parts

    # Check if the knee is oriented femur-up or femur-down. If either set of
    # landmarks is missing, femur-up is assumed.
    femur_up = True
    if all_femur_there and all_tibia_there:
        print("Medial Tibial spine")
        print(AL['MTS'])
        print("Distal Femur Point")
        print(AL['DFP'])
        if AL['MTS'][2] > AL['DFP'][2]:  # the tibial landmark is above the femoral one
            femur_up = False

    # ---- tibia ----
    if tibia_included:
        tibial_origin = 0.5 * (AL['MTS'] + AL['LTS'])
        # long axis is the world z axis, flipped for a femur-down knee
        zt = np.array([0.0, 0.0, 1.0]) if femur_up else np.array([0.0, 0.0, -1.0])
        yt = np.cross(zt, (AL['MTP'] - AL['LTP']))

        AL['TBO'] = tibial_origin
        AL['Zt_axis'] = zt

        # for a right knee, swap the direction so yt points anteriorly
        if Right_or_Left == 'R':
            yt = -yt

        yt = normalize(yt)
        xt = normalize(np.cross(yt, zt))

        AL['Yt_axis'] = yt
        AL['Xt_axis'] = xt

        print('\n Tibia coordinate system found')
    else:
        print('\n Tibia coordinate system will not be included in the model')

    # ---- femur ----
    if femur_included:
        femoral_origin = AL['DFP']
        zf = np.array([0.0, 0.0, 1.0]) if femur_up else np.array([0.0, 0.0, -1.0])
        yf = np.cross(zf, (AL['MFC'] - AL['LFC']))

        # femur y axis should point anteriorly
        if Right_or_Left == 'R':
            yf = -yf

        yf = normalize(yf)
        xf = normalize(np.cross(yf, zf))

        AL['FMO'] = femoral_origin
        AL['Xf_axis'] = xf
        AL['Yf_axis'] = yf
        AL['Zf_axis'] = zf

        print('\n Femur coordinate system found')

        # tibiofemoral floating axis, if the tibia frame exists too
        if tibia_included:
            Ftf = normalize(np.cross(zt, xf))

            # Intersect the plane through the tibial origin that contains zt
            # and Ftf with the femoral x axis (femoral_origin + t * xf).
            normal = np.cross(zt, Ftf)
            D = np.dot(normal, tibial_origin)
            t = (D - np.dot(normal, femoral_origin)) / np.dot(normal, xf)

            AL['Ftf_axis'] = Ftf
            AL['Ftf_Xf_intersect'] = femoral_origin + (t * xf)

            print('\n Tibiofemoral joint coordinate system found')
    else:
        print('\n Femur coordinate system will not be included in the model')

    # ---- patella ----
    if patella_included:
        patella_origin = 0.5 * (AL['MPR'] + AL['LPR'])
        xp = AL['MPR'] - AL['LPR']

        # xp should point laterally for the right knee, medially for the left knee
        if Right_or_Left == 'R':
            xp = -xp

        xp = normalize(xp)
        # zp: perpendicular to xp with no y component.
        # NOTE(review): unlike zt/zf this does not depend on femur_up -- confirm
        # that is intended for femur-down models.
        zp = normalize(np.array([-xp[2], 0, xp[0]]))
        yp = normalize(np.cross(zp, xp))

        AL['PTO'] = patella_origin
        AL['Xp_axis'] = xp
        AL['Yp_axis'] = yp
        AL['Zp_axis'] = zp

        print('\n Patella coordinate system found')

        # patellofemoral floating axis, if the femur frame exists too
        if femur_included:
            Fpf = normalize(np.cross(zp, xf))

            # Intersect the plane through the patella origin that contains zp
            # and Fpf with the femoral x axis.
            normal = np.cross(zp, Fpf)
            D = np.dot(normal, patella_origin)
            t = (D - np.dot(normal, femoral_origin)) / np.dot(normal, xf)

            AL['Fpf_axis'] = Fpf
            AL['Fpf_Xf_intersect'] = femoral_origin + (t * xf)

            print('\n Patellofemoral joint coordinate system found')
    else:
        print('\n Patella coordinate system will not be included in the model')

    return AL


def TransformToAlignAxes(origin, u_axis, v_axis):
    """Return the 4x4 homogeneous matrix that maps world (x, y, z) coordinates into a local frame.

    The local frame sits at ``origin`` with axes u, v and w = u x v. Its pose
    matrix has columns [u, v, w, origin]. The inverse of that pose is returned,
    so applying the result to a world point gives the point's (u, v, w)
    coordinates.
    """
    pose = np.identity(4)
    pose[:3, 0] = u_axis
    pose[:3, 1] = v_axis
    pose[:3, 2] = np.cross(u_axis, v_axis)
    pose[:3, 3] = origin

    return np.linalg.inv(pose)


def transformation_matrix(q1, q2, q3, q4, q5, q6):
    """Build a 4x4 homogeneous transform from a translation and roll/pitch/yaw angles.

    The rotation applies roll ``q4`` (about x), then pitch ``q5`` (about y),
    then yaw ``q6`` (about z), i.e. R = Rz(q6) @ Ry(q5) @ Rx(q4). The
    translation is (q1, q2, q3). Angles are in radians.
    """
    cr, sr = math.cos(q4), math.sin(q4)   # roll
    cp, sp = math.cos(q5), math.sin(q5)   # pitch
    cy, sy = math.cos(q6), math.sin(q6)   # yaw

    return np.array([
        [cy * cp, cy * sp * sr - sy * cr, cy * sp * cr + sy * sr, q1],
        [sy * cp, sy * sp * sr + cy * cr, sy * sp * cr - cy * sr, q2],
        [-sp,     cp * sr,                cp * cr,                q3],
        [0.0,     0.0,                    0.0,                    1.0],
    ], dtype=float)


def transform_points(T, Points):
    """Apply the 4x4 homogeneous transform ``T`` to one point or to many.

    Points : a single point of shape (3,) or an (n, 3) array / nested list.
    Returns the transformed coordinates in the same shape as ``Points``.
    """
    pts = np.asarray(Points)

    if pts.ndim == 1:
        # single point: append w = 1, transform, then drop w
        return np.matmul(T, np.append(pts, 1))[0:-1]

    # n points: build the (4, n) homogeneous array column by column
    homogeneous = np.vstack((pts.T, np.ones((1, len(pts)))))
    return np.matmul(T, homogeneous)[0:-1].T


def LigamentInsertions(anatomical_landmarks, all_parts):
    """Add MPFL and LPFL insertion nodes on the femur and patella to the landmarks dict.

    Adds the keys "MPFL-F_01", "MPFL-F_02", "LPFL-F_01", "LPFL-F_02",
    "MPFL-P_01", "MPFL-P_02", "LPFL-P_01" and "LPFL-P_02". Each value is the
    FEBio node id (taken from the part's ``node_ids``) of the chosen insertion
    node, not its coordinates.

    Parameters
    ----------
    anatomical_landmarks : dict
        Must already contain the femur frame ('FMO', 'Xf_axis', 'Yf_axis'),
        the patella frame ('PTO', 'Xp_axis', 'Yp_axis') and the landmarks
        'MFC', 'MPR' and 'LPR'. Updated in place.
    all_parts : dict of Geometry_Part
        Must contain the femur part 'FMB' and the patella part 'PTB'.

    Returns
    -------
    dict
        ``anatomical_landmarks`` with the insertion node ids added.
    """

    femur = all_parts['FMB']
    femur_nodes = femur.nodes
    femur_node_ids = femur.node_ids

    patella = all_parts['PTB']
    patella_nodes = patella.nodes
    patella_node_ids = patella.node_ids

    # _________________femoral insertions for MPFL_____________________________
    # Work in the femoral frame: TransformToAlignAxes maps world points to
    # coordinates along Xf (x), Yf (y) and Xf x Yf (z), with FMO at the origin.
    femoral_transform = TransformToAlignAxes(anatomical_landmarks['FMO'], anatomical_landmarks['Xf_axis'], anatomical_landmarks['Yf_axis'])

    medial_fc = anatomical_landmarks['MFC']

    transformed_femur_nodes = transform_points(femoral_transform, femur_nodes)
    transformed_mfc = transform_points(femoral_transform, medial_fc)

    # account for left vs right knee: the medial side is the half-space
    # (in femoral x) that contains the medial femoral condyle landmark
    if transformed_mfc[0] > 0:
        medial_idx = np.where(transformed_femur_nodes[:,0]>0)[0]
        most_medial_f = np.max(transformed_femur_nodes[:,0])
        # nodes whose x is beyond 90% of the most medial x
        medial_farthest_idx = np.where(transformed_femur_nodes[:,0] > 0.9*most_medial_f)[0]
        lateral_epicondyle_idx = np.argmin(transformed_femur_nodes[:,0])
        lateral_epicondyle = transformed_femur_nodes[lateral_epicondyle_idx]
        most_lateral = lateral_epicondyle[0]
        lateral_farthest_idx = np.where(transformed_femur_nodes[:,0] < 0.9*most_lateral)[0]

    else:
        # mirror image of the branch above (medial side is negative x)
        medial_idx = np.where(transformed_femur_nodes[:,0]<0)[0]
        most_medial_f = np.min(transformed_femur_nodes[:,0])
        medial_farthest_idx = np.where(transformed_femur_nodes[:, 0] < 0.9 * most_medial_f)[0]
        lateral_epicondyle_idx = np.argmax(transformed_femur_nodes[:, 0])
        lateral_epicondyle = transformed_femur_nodes[lateral_epicondyle_idx]
        most_lateral = lateral_epicondyle[0]
        lateral_farthest_idx = np.where(transformed_femur_nodes[:,0] > 0.9*most_lateral)[0]

    # NOTE(review): medial_nodes is every femur node on the medial side of the
    # femoral origin, not only the condyle, so CD is the AP extent of that half.
    medial_nodes = transformed_femur_nodes[medial_idx]
    max_AP = np.max(medial_nodes[:,1])
    min_AP = np.min(medial_nodes[:,1])
    CD = max_AP - min_AP   #CD is the anterior-posterior size of the medial condyle

    distal_z = np.min(medial_nodes[:,2]) # find the distal end of medial condyle
    post_y = min_AP
    zm = distal_z+0.5*CD # 0.5xCD from the distal side of the medial condyle
    ym = post_y+0.4*CD # 0.4xCD from the posterior side of the medial condyle

    # among the farthest-medial nodes, pick the one whose (y, z) is closest to (ym, zm)
    idx_in_farthest = closest_point_index([ym,zm], transformed_femur_nodes[medial_farthest_idx][:,1:])
    idx_insertion = medial_farthest_idx[idx_in_farthest]

    # both MPFL femoral ends share the same node
    # anatomical_landmarks["MPFL-F_01"] = anatomical_landmarks["MPFL-F_02"] = femur_nodes[idx_insertion] # insertion coordinates
    anatomical_landmarks["MPFL-F_01"] = anatomical_landmarks["MPFL-F_02"] = femur_node_ids[idx_insertion] # the node number in febio of the insertion origin


    # ____________________________Femoral Insertions for LPFL___________________

    # (same value as already assigned in the branch above)
    lateral_epicondyle = transformed_femur_nodes[lateral_epicondyle_idx]
    yl = lateral_epicondyle[1]+ 10.6 # 10.6 mm anterior to lateral epicondyle
    zl = lateral_epicondyle[2]-2.6   # 2.6 mm distal to the lateral epicondyle
    idx_in_farthest = closest_point_index([yl,zl], transformed_femur_nodes[lateral_farthest_idx][:,1:])
    idx_mid_insertion = lateral_farthest_idx[idx_in_farthest]
    # back in world (untransformed) coordinates from here on
    mid_insertion = femur_nodes[idx_mid_insertion]

    # insertion coordinates
    # anatomical_landmarks["LPFL-F_01"] = mid_insertion + [0,0,11.7/2] # average width 11.7 mm
    # anatomical_landmarks["LPFL-F_02"] = mid_insertion - [0,0,11.7/2]

    # NOTE(review): the +/- 11.7/2 mm width offset is applied along the WORLD z
    # axis, while the 10.6 / 2.6 mm offsets above are in the femoral frame. The
    # two agree only when Zf is aligned with world z -- confirm for registered
    # (tilted) femur frames.
    idx_1 = closest_point_index(mid_insertion + [0,0,11.7/2], femur_nodes)
    idx_2 = closest_point_index(mid_insertion - [0,0,11.7/2], femur_nodes)

    anatomical_landmarks["LPFL-F_01"] = femur_node_ids[idx_1] #febio node number
    anatomical_landmarks["LPFL-F_02"] = femur_node_ids[idx_2]


    # _________________patellar insertions for MPFL_______________________
    # superomedial aspect of patella (~ top 1/3)
    # Work in the patellar frame: x along Xp, y along Yp, z along Xp x Yp, origin at PTO.
    patella_transform = TransformToAlignAxes(anatomical_landmarks["PTO"], anatomical_landmarks["Xp_axis"], anatomical_landmarks["Yp_axis"])

    transformed_patella_nodes = transform_points(patella_transform, patella_nodes)

    medial_patella_ridge = anatomical_landmarks['MPR']
    lateral_patella_ridge = anatomical_landmarks['LPR']

    transformed_mpr = transform_points(patella_transform, medial_patella_ridge)
    transformed_lpr = transform_points(patella_transform, lateral_patella_ridge)

    # patellar extent along its z axis
    inferior_pole = np.min(transformed_patella_nodes[:,2])
    superior_pole = np.max(transformed_patella_nodes[:,2])
    height = superior_pole-inferior_pole

    # "quarter": nodes whose x lies beyond 75% of the ridge landmark's x, on each side
    if transformed_mpr[0] > 0: #left vs right knee
        most_medial_quart_idx = np.where(transformed_patella_nodes[:,0]>0.75*transformed_mpr[0])[0]
        most_lateral_quart_idx = np.where(transformed_patella_nodes[:,0]<0.75*transformed_lpr[0])[0]
    else:
        most_medial_quart_idx = np.where(transformed_patella_nodes[:,0]<0.75*transformed_mpr[0])[0]
        most_lateral_quart_idx = np.where(transformed_patella_nodes[:, 0]>0.75 * transformed_lpr[0])[0]

    # top insertion
    most_medial_quart = transformed_patella_nodes[most_medial_quart_idx]
    zm1 = np.max(most_medial_quart[:,2]) # find the highest point in the most medial quarter
    ym1 = np.average(most_medial_quart[:, 1])  # find the middle depth of the most medial quarter
    idx_in_quart_1 = closest_point_index([ym1, zm1], most_medial_quart[:, 1:])
    idx_insertion_1 = most_medial_quart_idx[idx_in_quart_1]

    #bottom insertion: one third of the patellar height below the superior pole
    zm2 = superior_pole - ((1.0 / 3.0) * height)
    ym2 = ym1
    idx_in_quart_2 = closest_point_index([ym2, zm2], most_medial_quart[:,1:])
    idx_insertion_2 = most_medial_quart_idx[idx_in_quart_2]

    # anatomical_landmarks["MPFL-P_01"] = patella_nodes[idx_insertion_1]
    # anatomical_landmarks["MPFL-P_02"] = patella_nodes[idx_insertion_2]
    anatomical_landmarks["MPFL-P_01"] = patella_node_ids[idx_insertion_1]# febio node numbering
    anatomical_landmarks["MPFL-P_02"] = patella_node_ids[idx_insertion_2]

    #_______________________ Patellar Insertion for LPFL
    zl1 = superior_pole - 8.0 # 8 mm from superior pole to upper insertion,
    zl2 = zl1-(0.45*height) # insertion width ~ 45% of articular surface.
    most_lateral_quart = transformed_patella_nodes[most_lateral_quart_idx]
    yl1 = yl2 =  np.average(most_lateral_quart[:, 1])  # find the middle depth of the most lateral quarter
    idx_in_quart_top = closest_point_index([yl1, zl1], most_lateral_quart[:, 1:])
    idx_in_quart_bot = closest_point_index([yl2, zl2], most_lateral_quart[:, 1:])

    # map indices within the lateral quarter back to indices into the whole patella
    idx_insertion_top = most_lateral_quart_idx[idx_in_quart_top]
    idx_insertion_bot = most_lateral_quart_idx[idx_in_quart_bot]

    # anatomical_landmarks["LPFL-P_01"] = patella_nodes[idx_insertion_top]
    # anatomical_landmarks["LPFL-P_02"] = patella_nodes[idx_insertion_bot]

    anatomical_landmarks["LPFL-P_01"] = patella_node_ids[idx_insertion_top]
    anatomical_landmarks["LPFL-P_02"] = patella_node_ids[idx_insertion_bot]

    return anatomical_landmarks


def Quadriceps_Slider(anatomical_landmarks, all_parts):
    """Find the slider joint origin for the quadriceps tendon and the node ids attached to its rigid body.

    The tendon end is taken at the extreme QAT z value in the direction of the
    femoral z axis (+z when there is no 'FMB' part). The end nodes are those
    within 5% of that extreme value, measured relative to the origin
    (0.05 * |z_end|). Their x/y centroid, together with the extreme z, is
    stored as anatomical_landmarks['QAT'].

    Returns
    -------
    (anatomical_landmarks, top_nodes_ids)
        The updated dict and the node ids of the end nodes.
    """

    qat_nodes = all_parts['QAT'].nodes
    qat_node_ids = all_parts['QAT'].node_ids

    # check the direction of the femur axis, if no femur assume > 0
    if 'FMB' in all_parts:
        zf_axis = anatomical_landmarks['Zf_axis']
    else:
        zf_axis = [0, 0, 1]

    # The band is measured inward from the extreme by 5% of |z_end|. The
    # previous `0.95 * z_end` threshold lay OUTSIDE the node range whenever
    # z_end had the "wrong" sign (e.g. z_max < 0), so no nodes were selected
    # and the averages below became NaN. For the usual sign the band is
    # unchanged. The inclusive comparison always keeps the extreme node itself.
    if zf_axis[2] > 0:
        # nodes near the maximum z
        z_qat = np.max(qat_nodes[:, 2])
        end_nodes_idx = np.where(qat_nodes[:, 2] >= z_qat - 0.05 * abs(z_qat))[0]
    else:
        # nodes near the minimum z
        z_qat = np.min(qat_nodes[:, 2])
        end_nodes_idx = np.where(qat_nodes[:, 2] <= z_qat + 0.05 * abs(z_qat))[0]

    end_nodes = qat_nodes[end_nodes_idx]

    # locate the slider joint in the approximate center of the end
    x_qat = np.average(end_nodes[:, 0])
    y_qat = np.average(end_nodes[:, 1])

    anatomical_landmarks['QAT'] = [x_qat, y_qat, z_qat]

    # get the ids of the nodes that will be attached to the rigid body
    top_nodes_ids = qat_node_ids[end_nodes_idx]

    return anatomical_landmarks, top_nodes_ids


def ElemWiseMaterials(ModelProperties_xml):
    """Return the names of materials whose fiber directions are set element-wise.

    Reads the <Material> section of the model-properties XML. It collects the
    ``name`` attribute of every child element whose <fibers> text is
    'element' (surrounding whitespace ignored). Children with no <fibers> tag
    or no name, and XML comments, are skipped. Returns [] when there is no
    <Material> section.

    ModelProperties_xml : path or file object accepted by ``et.parse``.
    """
    ModelProperties = et.parse(ModelProperties_xml).getroot()
    Material = ModelProperties.find("Material")
    if Material is None:
        return []

    elem_wise_fibers = []
    for mat in Material:
        if not isinstance(mat.tag, str):  # XML comment / processing instruction
            continue
        fibers = mat.find('fibers')
        name = mat.get('name')
        if fibers is None or name is None:
            continue
        if (fibers.text or '').strip() == 'element':
            elem_wise_fibers.append(name)

    return elem_wise_fibers


def element_wise_fibers(all_parts, L):
    """Find one fiber direction per element of part ``L``.

    The fiber direction of each element is the cross product of two vectors:

    * the normal of the surface face (from surface ``'<L>_All_Faces'``) whose
      midpoint is closest to the element's first node, as returned by
      ``closest_point_index``;
    * the shortest edge direction of the part's oriented bounding box.

    Interior elements therefore use the nearest surface element's normal.

    Args:
        all_parts: dict of Geometry_Part objects keyed by part name.
        L: name of the part (e.g. a ligament) to process.

    Returns:
        numpy array with one fiber direction (length-3 row) per element, in
        ``all_parts[L].elems`` order. NOTE(review): rows are not re-normalized here.
    """
    part = all_parts[L]
    Nodes_L = part.nodes
    Node_ids_L = part.node_ids
    edge_length_P, edge_vectors_P, _ = oriented_bounding_box(Nodes_L)

    # find the shortest edge of the bounding box
    shortest_vector = edge_vectors_P[np.argmin(edge_length_P)]

    # normals and midpoints of the surface elements
    face_elems = part.surfaces['{}_All_Faces'.format(L)]
    face_normals, face_midpoints = element_normals(Nodes_L, face_elems, Node_ids_L)

    # Build the node id -> row map once, instead of scanning every node id per
    # element. setdefault keeps the first occurrence, matching np.where(...)[0][0].
    row_of_id = {}
    for row, node_id in enumerate(Node_ids_L):
        row_of_id.setdefault(node_id, row)

    fiber_orientations = []
    for elem in part.elems:
        first_node = Nodes_L[row_of_id[elem[0]]]
        close_idx = closest_point_index(first_node, face_midpoints)
        fiber_orientations.append(np.cross(face_normals[close_idx], shortest_vector))

    return np.asarray(fiber_orientations)

def _ellipse_points(a, b, step=0.1):
    """Sample the ellipse (x/a)^2 + (y/b)^2 = 1 in the z = 0 plane.

    x runs from -a in increments of ``step``. The upper half (y >= 0) is listed
    first, then the lower half. Returns an (N, 3) array.
    """
    # np.arange can overshoot +a for a non-integer step. That would make the
    # radicand negative and produce NaN points, so x is clipped to [-a, a].
    oval_x = np.clip(np.arange(-a, a + step, step), -a, a)
    oval_y_upper = np.sqrt((b**2) * (1 - (oval_x / a)**2))
    oval_y_lower = -oval_y_upper
    oval = np.zeros((len(oval_x) * 2, 3))
    oval[:, 0] = np.concatenate((oval_x, oval_x), axis=0)
    oval[:, 1] = np.concatenate((oval_y_upper, oval_y_lower), axis=0)
    return oval


def _meniscus_fiber_orientations(part):
    """Compute an ellipse-tangent (circumferential) fiber direction for every element of a meniscus part."""
    Nodes_M = part.nodes
    edge_length_M, edge_vectors_M, center_M = oriented_bounding_box(Nodes_M)

    # shortest edge of box is height, we will fit an oval to the width, length
    mask = np.ones(3, dtype=bool)  # np.bool was removed in NumPy 1.24
    mask[np.argmin(edge_length_M)] = False
    edge_directions = edge_vectors_M[mask]
    side_lengths = edge_length_M[mask]
    a, b = side_lengths[0], side_lengths[1]

    # ellipse at the origin in the x, y plane, sized by the longer box edges
    oval = _ellipse_points(a, b)

    # transform the nodes so that the oriented bounding box sits at the origin
    T = TransformToAlignAxes(center_M, edge_directions[0], edge_directions[1])
    transformed_nodes = transform_points(T, Nodes_M)

    # node id -> row map, built once (first occurrence wins, as with np.where)
    row_of_id = {}
    for row, node_id in enumerate(part.node_ids):
        row_of_id.setdefault(node_id, row)

    # For each element, find the closest point on the oval. The fiber
    # orientation is the tangent to the oval at that point.
    meniscus_orientations = []
    for elem in part.elems:
        first_node = transformed_nodes[row_of_id[elem[0]]]
        oval_idx = closest_point_index(first_node, oval)
        x, y = oval[oval_idx, 0], oval[oval_idx, 1]
        if y == 0:  # vertical tangent; avoids division by zero
            meniscus_orientations.append([0, 1, 0])
        else:
            # implicit derivative of the ellipse: dy/dx = -x b^2 / (a^2 y)
            slope_of_tangent = -(x * (b**2)) / ((a**2) * y)
            meniscus_orientations.append([1, slope_of_tangent, 0])

    # Transform the vectors back to the meniscus coordinate system.
    # transform_points treats them as points, so center_M is subtracted to remove
    # the translation. NOTE(review): assumes inv(T) translates by +center_M.
    meniscus_orientations = transform_points(np.linalg.inv(T), meniscus_orientations)
    return meniscus_orientations - center_M


def FiberOrientations(all_parts, element_wise_materials):
    """Get the fiber orientations for the tissues that require them.

    Ligaments and tendons ('PCL', 'ACL', 'MCL', 'LCL', 'PTL', 'QAT'):

    * if the part name is in ``element_wise_materials``, one direction per
      element is computed with ``element_wise_fibers``;
    * otherwise a single direction is used: the longest edge of the part's
      oriented bounding box.

    Menisci ('MNS-M', 'MNS-L'): one direction per element, tangent to an ellipse
    fitted to the two longest bounding-box edges.

    Parts missing from ``all_parts`` are skipped.

    Args:
        all_parts: dict of Geometry_Part objects keyed by part name.
        element_wise_materials: names of parts that need element-wise fibers.

    Returns:
        dict mapping part name to either one direction vector or an array with
        one direction per element.
    """
    fiber_orientations = {}

    ligaments = ['PCL', 'ACL', 'MCL', 'LCL', 'PTL', 'QAT']
    for L in ligaments:
        if L not in all_parts:  # the ligament isn't included
            continue

        if L in element_wise_materials:
            print("finding element wise fiber directions for the " + L)
            fiber_orientations[L] = element_wise_fibers(all_parts, L)
        else:  # single fiber direction: longest edge of the bounding box
            edge_length_L, edge_vectors_L, _ = oriented_bounding_box(all_parts[L].nodes)
            fiber_orientations[L] = edge_vectors_L[np.argmax(edge_length_L)]

    meniscus = ['MNS-M', 'MNS-L']
    for M in meniscus:
        if M not in all_parts:  # meniscus not included
            continue
        print("finding element-wise fiber directions for the " + M)
        fiber_orientations[M] = _meniscus_fiber_orientations(all_parts[M])

    return fiber_orientations



def XmlLandmarksToDict(ModelProperties_xml):
    """Read the manually chosen landmarks from the ModelProperties XML file.

    Each element in the "Landmarks" section whose text is a comma-separated list
    of numbers becomes an entry ``{element tag: numpy array of floats}``. These
    are skipped: XML comments, elements with no text, and elements whose text is
    not numeric.

    Args:
        ModelProperties_xml: path of the ModelProperties XML file.

    Returns:
        dict mapping landmark name to a numpy array of coordinates.
    """
    ModelProperties = et.parse(ModelProperties_xml).getroot()
    Landmarks = FebCustomization_p3.get_section("Landmarks", ModelProperties)

    landmarks = {}
    for subelement in Landmarks:
        # comments have a non-string tag; empty elements have text None
        if not isinstance(subelement.tag, str) or subelement.text is None:
            continue
        try:
            float_coord = [float(c) for c in subelement.text.split(',')]
        except ValueError:  # non-numeric text
            continue
        landmarks[subelement.tag] = np.asarray(float_coord)

    return landmarks


def plot_fiber_orientations(fiber_orientations, all_parts):
    """Show a 3D quiver plot of element-wise fiber directions (debugging aid).

    A part is drawn when its entry in ``fiber_orientations`` has more than one
    dimension (one direction per element). Arrows are anchored at the first node
    of each element, for the first 100 elements only. Parts with a single fiber
    direction are not drawn. The function ends by calling ``plt.show()``.
    """
    import matplotlib.pyplot as plt
    # Kept from the original; presumably needed to register the '3d' projection
    # on older matplotlib versions.
    from mpl_toolkits.mplot3d import Axes3D

    fig = plt.figure()
    ax = fig.add_subplot(111, projection='3d')

    for part_name, directions in fiber_orientations.items():
        if len(directions.shape) <= 1:  # one fiber direction for the whole part
            continue

        part = all_parts[part_name]
        anchor_ids = part.elems[:, 0]
        anchors = np.zeros((len(anchor_ids), 3))
        for row, node_id in enumerate(anchor_ids):
            anchors[row] = part.nodes[np.where(part.node_ids == node_id)[0]]

        # draw only the first 100 elements, to keep the plot readable
        shown = slice(None, 100)
        ax.quiver(anchors[shown, 0], anchors[shown, 1], anchors[shown, 2],
                  directions[shown, 0], directions[shown, 1], directions[shown, 2],
                  length=1)

    plt.show()


def MCL_MNS_tie(all_parts):
    """Find the node pairs used to connect the MCL and medial meniscus tie nodes via springs.

    This uses nodeset 'MCL_@_MNS-M_TiesNodes' of part "MCL" and nodeset
    'MNS-M_@_MCL_TiesNodes' of part "MNS-M". For every MCL tie node, the closest
    MNS-M tie node (Euclidean distance) is chosen. Several MCL nodes may share
    the same MNS-M partner.

    Args:
        all_parts: dict of Geometry_Part objects. It must contain "MCL" and "MNS-M".

    Returns:
        dict {MCL node id: MNS-M node id} of node pairs to connect via springs.
        The dict is empty if either nodeset is empty.
    """

    def get_nodeset(part_name, nodeset_name):
        """Return (coordinates, ids) of a part's nodeset, with the ids as a numpy array."""
        part = all_parts[part_name]
        # asarray so that fancy indexing below also works when the nodeset is a list
        tie_ids = np.asarray(part.nodesets[nodeset_name])

        # Build the id -> row map once (first occurrence wins), instead of
        # searching all node ids for every tie node.
        row_of_id = {}
        for row, node_id in enumerate(part.node_ids):
            row_of_id.setdefault(node_id, row)

        rows = np.asarray([row_of_id[node_id] for node_id in tie_ids], dtype=int)
        return part.nodes[rows], tie_ids

    # get the mcl nodeset
    MCL_tiesnodes, MCL_tiesnode_ids = get_nodeset("MCL", 'MCL_@_MNS-M_TiesNodes')

    # get the mns-m nodeset
    MNSM_tiesnodes, MNSM_tiesnode_ids = get_nodeset("MNS-M", 'MNS-M_@_MCL_TiesNodes')

    # nothing to pair (nanargmin would raise on an empty axis)
    if len(MCL_tiesnode_ids) == 0 or len(MNSM_tiesnode_ids) == 0:
        return {}

    # for each mcl node, idx of the closest node in the mns-m nodeset
    distances = spatial.distance.cdist(MCL_tiesnodes, MNSM_tiesnodes)
    closest_node_idx = np.nanargmin(distances, axis=1)

    # global ids of the MNS-M partner nodes
    MNSM_pair_ids = MNSM_tiesnode_ids[closest_node_idx]

    # create a dictionary of node pairs to connect using their global node ids
    return dict(zip(MCL_tiesnode_ids, MNSM_pair_ids))


def DoCalculations(ModelProperties_xml, febio_file):
    """Run all landmark, tie and fiber-orientation calculations for one knee model.

    Steps:

    1. Ask on stdin whether the knee is a right or left knee.
    2. Read the manual landmarks from ``ModelProperties_xml`` and the geometry
       from ``febio_file``.
    3. Compute the joint landmarks, the MCL / medial meniscus tie node pairs,
       the ligament insertions, the quadriceps slider joint, and the material
       fiber orientations.

    Args:
        ModelProperties_xml: path of the ModelProperties XML file.
        febio_file: path of the FEBio (.feb) model file.

    Returns:
        (all_landmarks, qat_nodeset, fiber_orientations, Right_or_Left, mcl_mns_nodepairs).
        Right_or_Left is 'R' or 'L'. qat_nodeset is None when the model has no
        'QAT' part. mcl_mns_nodepairs is None when 'MCL' or 'MNS-M' is missing.

    Raises:
        SystemExit: with a non-zero status if the knee side answer is not understood.
    """
    # get user input for whether the knee is a right knee or a left knee;
    # surrounding whitespace and letter case are ignored
    answer = input("Is this a right knee or a left knee? Type R for right, and L for left: ").strip().lower()
    if answer in ('r', 'right'):
        Right_or_Left = 'R'
    elif answer in ('l', 'left'):
        Right_or_Left = 'L'
    else:
        import sys
        # Passing a message makes sys.exit print it to stderr and exit with
        # status 1. A bare sys.exit() would report success (status 0).
        sys.exit("user input not understood")

    # extract the manually chosen landmarks from the xml file
    manual_landmarks = XmlLandmarksToDict(ModelProperties_xml)

    print('\n Reading the Geometry')
    all_parts = ReadGeometry(febio_file)

    # _________KNEE JCS CALCULATION_________________
    print('\n Calculating Joint Axes')
    landmarks = AddJointLandmarks(manual_landmarks, Right_or_Left, all_parts)

    # Alternatives for registered experimental data (swap in for the call above):
    # landmarks = AddJointLandmarks_oks003_registered(manual_landmarks, Right_or_Left, all_parts)
    # landmarks = AddJointLandmarks_du02_registered(manual_landmarks, Right_or_Left, all_parts)

    # _________MCL-MNS-M TIE_________________
    if 'MCL' in all_parts and 'MNS-M' in all_parts:
        mcl_mns_nodepairs = MCL_MNS_tie(all_parts)
    else:
        mcl_mns_nodepairs = None

    # _________LIGAMENT_INSERTIONS_________________
    if 'FMB' in all_parts and 'PTB' in all_parts:
        print('\n Finding Ligament Insertion Sites')
        landmarks = LigamentInsertions(landmarks, all_parts)

    if 'QAT' in all_parts:
        print('\n Finding Quadriceps Slider Joint')
        all_landmarks, qat_nodeset = Quadriceps_Slider(landmarks, all_parts)
    else:
        all_landmarks = landmarks
        qat_nodeset = None

    # _________FIBER ORIENTATIONS_________________
    element_wise_materials = ElemWiseMaterials(ModelProperties_xml)

    print('\n Calculating the Material Fiber Orientations')
    fiber_orientations = FiberOrientations(all_parts, element_wise_materials)

    # use this function to plot the fiber directions for element wise ligaments:
    # plot_fiber_orientations(fiber_orientations, all_parts)

    return all_landmarks, qat_nodeset, fiber_orientations, Right_or_Left, mcl_mns_nodepairs


if __name__ == '__main__':
    import sys

    # Usage: python AnatomicalLandmarks_p3.py [ModelProperties.xml [FeBio.feb]]
    # If an argument is not given, the corresponding test file below is used.
    default_ModelProperties_xml = 'C:\\Users\\Owner\\Documents\\CCF\\test\\ModelProperties.xml'
    default_febio_file = 'C:\\Users\\Owner\\Documents\\CCF\\test\\FeBio.feb'

    cli_args = sys.argv[1:]
    ModelProperties_xml = cli_args[0] if len(cli_args) > 0 else default_ModelProperties_xml
    febio_file = cli_args[1] if len(cli_args) > 1 else default_febio_file

    DoCalculations(ModelProperties_xml, febio_file)
