import sys
import numpy as np
from lxml import etree as et
from mpl_toolkits.mplot3d import axes3d
import matplotlib.pyplot as plt
import math
import copy
import os

class rigid_body:
    """A rigid-body material found in the FEBio log.

    Attributes:
        name: material name as it appears in the log.
        id: material id, always stored as an int.
        center_of_mass: initial center of mass, stored exactly as passed in.
    """

    def __init__(self, mat_name, mat_id, center_of_mass):
        # the id is coerced to int because the data dictionaries are keyed by int body ids
        self.name, self.id, self.center_of_mass = mat_name, int(mat_id), center_of_mass


class Cylindrical_Joint:
    """A "rigid cylindrical joint" constraint read from the FEBio input file.

    Attributes:
        name: constraint name.
        axis: joint axis, as passed by the caller.
        origin: joint origin, as passed by the caller.
        body_a, body_b: ids of the two rigid bodies connected by the joint.
        joint_number: constraint number assigned by the caller
            (GetConstraintInfo numbers constraints starting at 1).
    """

    def __init__(self, name, axis, origin, body_a, body_b, joint_number):
        self.name = name
        self.axis, self.origin = axis, origin
        self.body_a, self.body_b = body_a, body_b
        self.joint_number = joint_number


def find_rigid_bodies(log_filename):
    """Find the rigid bodies in an FEBio log and store them by material name.

    The MATERIAL DATA section is used rather than the rigid body section because
    the rigid body section does not include the material names.

    Args:
        log_filename: path to the FEBio log file.

    Returns:
        dict {material_name: rigid_body}. It is empty when the log has no
        MATERIAL DATA section or that section lists no rigid bodies.

    Raises:
        ValueError: if the file ends before a rigid body's center_of_mass line.
    """
    rigid_bodies = {}
    with open(log_filename) as f:
        for line in f:
            if 'MATERIAL DATA' not in line:
                continue
            next(f, '')  # separator line under the section header
            l = next(f, '')
            # the section ends at the first blank line (or at end of file)
            while l.strip():
                if 'rigid body' in l:
                    # parsed as "<id> - <name> (...)"; split only at the first '-'
                    # so hyphens inside the material name are kept
                    id_str, name_str = l.split('-', 1)
                    mat_id = int(id_str)
                    mat_name = name_str.split('(')[0].strip()

                    # advance to this material's center_of_mass line
                    l = next(f, '')
                    while l and 'center_of_mass' not in l:
                        l = next(f, '')
                    if not l:
                        raise ValueError("no center_of_mass found for rigid body '{}' in {}".format(
                            mat_name, log_filename))

                    com = [float(x) for x in l.split(':')[-1].split(',')]
                    rigid_bodies[mat_name] = rigid_body(mat_name, mat_id, com)
                l = next(f, '')
            break

    return rigid_bodies


def find_input_filename(log_filename):
    """Return the path of the FEBio input file referenced by a log file.

    The input filename is read from the log's FILES USED section. Only its base
    name is kept, and it is joined to the log file's directory: the input file is
    assumed to sit next to the log.

    Args:
        log_filename: path to the FEBio log file.

    Returns:
        path of the input file in the log's directory.

    Raises:
        ValueError: if the log has no FILES USED section.
    """
    import ntpath  # its basename splits on both POSIX and Windows separators

    input_filename = None
    with open(log_filename) as f:
        for line in f:
            if 'FILES USED' in line:
                next(f)  # separator line
                # presumably "Input file : <path>"; split at the first ':' only so a
                # Windows drive letter does not cut the path
                input_filename = next(f).split(':', 1)[-1].strip()
                break

    if input_filename is None:
        raise ValueError("no 'FILES USED' section found in {}".format(log_filename))

    # the log may have been written on another OS, so strip directories of either style
    base_file = ntpath.basename(input_filename)
    return os.path.join(os.path.dirname(log_filename), base_file)


def collect_data(log_filename):
    """Read every 'Data Record' block of an FEBio log file.

    Args:
        log_filename: path to the FEBio log file.

    Returns:
        (all_data, time_steps):
          all_data: nested dict {data_name: {body_id: rows}}. rows is a 2-D numpy
            array with one row per record. It stays a list of lists if the rows
            have different lengths.
          time_steps: 1-D array holding the time of every 'Data Record #1', i.e. one
            entry per solved time step.
    """
    time_steps = []
    all_data = {}

    with open(log_filename) as f:
        for line in f:
            if 'Data Record' not in line:
                continue
            data_record_num = int(line.split('#')[-1])
            next(f)  # ===== line
            next(f)  # step number line
            time = float(next(f).split('=')[-1])
            # record #1 starts the set of records for a new solved time step
            if data_record_num == 1:
                time_steps.append(time)
            data_name = next(f).split('=')[-1].strip()
            records = all_data.setdefault(data_name, {})

            # rows "<body_id> <v1> <v2> ..." run until a blank line or the end of file
            l = next(f, '')
            while l.strip():
                fields = l.split()  # any run of whitespace separates the fields
                try:
                    body_id = int(fields[0])
                except ValueError:
                    break  # in case there's a line of text
                records.setdefault(body_id, []).append([float(x) for x in fields[1:]])
                l = next(f, '')

    # turn the per-body row lists into arrays; ragged data is left as a list
    for data_dict in all_data.values():
        for body_id, rows in data_dict.items():
            try:
                data_dict[body_id] = np.asarray(rows)
            except ValueError:
                pass

    return all_data, np.asarray(time_steps)


def add_initial_values(all_data, time_steps, rigid_bodies):
    """Prepend the t=0 state to the time steps and to every data series.

    Initial values used:
      * center_of_mass: each rigid body's initial center of mass (rb.center_of_mass)
      * rotation_quaternion: [0, 0, 0, 1] for every body
      * every other data series (forces, moments, ...): [0.0, 0.0, 0.0]

    The inner dictionaries of all_data are modified in place. The same all_data
    object is returned, together with a new time array that starts with 0.
    """
    new_time_steps = np.insert(time_steps, 0, [0])

    com_data = all_data['center_of_mass']
    for body in rigid_bodies.values():
        com_data[body.id] = np.insert(com_data[body.id], 0, body.center_of_mass, axis=0)

    quat_data = all_data['rotation_quaternion']
    identity_quaternion = [0, 0, 0, 1]
    for body_id in list(quat_data):
        quat_data[body_id] = np.insert(quat_data[body_id], 0, identity_quaternion, axis=0)

    zero_vector = [0.0, 0.0, 0.0]
    for data_name, series in all_data.items():
        if data_name in ('rotation_quaternion', 'center_of_mass'):
            continue  # initial values already added above
        for body_id in list(series):
            series[body_id] = np.insert(series[body_id], 0, zero_vector, axis=0)

    return all_data, new_time_steps


def fixed_axis(rb_rotation_quaternion, joint_info):
    """Direction of a body-fixed joint axis at every time step.

    The joint's initial axis (joint_info.axis) is rotated by the rigid body's
    rotation at each step and then normalised to unit length.

    Args:
        rb_rotation_quaternion: (n, 4) quaternions of the body the axis is fixed to,
            in the column order used by rotation_matrix_from_quaternion.
        joint_info: object with an ``axis`` attribute (3 components). It is not modified.

    Returns:
        (n, 3) array of unit axis vectors.
    """
    start_axis = np.array(joint_info.axis)  # independent copy of the initial axis
    rotations = rotation_matrix_from_quaternion(rb_rotation_quaternion)

    rotated = np.matmul(rotations, start_axis)
    lengths = np.linalg.norm(rotated, axis=1)
    return np.divide(rotated, lengths[:, np.newaxis])  # normalise, just in case


def get_joint_axes(rigid_bodies, rot_quat_data, constraint_info):
    """Get the direction of each joint axis at every time step.

    Axes are identified by keywords in the constraint names. The names themselves
    are kept exactly as they appear in the constraint definitions, so both left and
    right knee naming conventions work. A constraint is patellar when its name
    contains "Patellar".

    Body-fixed axes are rotated with their rigid body:
      * patellar "Flexion" -> FMB, patellar "Tilt" -> PTB
      * tibiofemoral "Extension" -> FMB, tibiofemoral "Internal" -> TBB

    Floating axes are cross products, added only when all the names involved exist:
      * "Abduction" = Internal axis x Extension axis
      * patellar "Rotation" = Tilt axis x patellar Flexion axis

    Returns:
        dict {constraint name: (n, 3) array of axis directions}.
    """
    fixed_joints = {}
    pat_flex_name = pat_tilt_name = pat_rot_name = None
    flex_ext_name = ext_int_name = abd_add_name = None

    for name in constraint_info:
        if "Patellar" in name:  # patellofemoral axes
            if "Flexion" in name:
                fixed_joints[name] = 'FMB'
                pat_flex_name = name
            elif "Tilt" in name:
                fixed_joints[name] = 'PTB'
                pat_tilt_name = name
            elif "Rotation" in name:
                pat_rot_name = name
        else:  # tibiofemoral axes
            if "Extension" in name:
                fixed_joints[name] = 'FMB'
                flex_ext_name = name
            elif "Internal" in name:
                fixed_joints[name] = 'TBB'
                ext_int_name = name
            elif "Abduction" in name:
                abd_add_name = name

    # the joint axes by name, each a matrix holding the axis at every time point
    joint_axes = {}
    for ax_name, rb_name in fixed_joints.items():
        joint_info = constraint_info[ax_name]
        rot_quat = rot_quat_data[rigid_bodies[rb_name].id]
        joint_axes[ax_name] = fixed_axis(rot_quat, joint_info)

    # floating axes: computed explicitly only when every defining axis is present,
    # so unrelated errors are no longer hidden by a catch-all except
    if None not in (abd_add_name, ext_int_name, flex_ext_name):
        joint_axes[abd_add_name] = np.cross(joint_axes[ext_int_name], joint_axes[flex_ext_name])
    if None not in (pat_rot_name, pat_tilt_name, pat_flex_name):
        joint_axes[pat_rot_name] = np.cross(joint_axes[pat_tilt_name], joint_axes[pat_flex_name])

    return joint_axes

def kinematics_from_transform(T):
    """Extract joint rotations and translations from relative transformation matrices.

    Follows page 6 of the Open Knee(s) "Knee Coordinate Systems" document:
    https://simtk.org/plugins/moinmoin/openknee/Infrastructure/ExperimentationMechanics?action=AttachFile&do=view&target=Knee+Coordinate+Systems.pdf

    Args:
        T: (n, 4, 4) homogeneous transforms. joint_kinematics_2 passes the tibia
            or patella transform expressed in the femur frame.

    Returns:
        a, b, c: (n,) translations along the three joint axes.
        alpha, beta, gamma: (n,) rotations in radians about the three joint axes.
        c is divided by cos(beta), so it is singular at beta = +/-90 degrees.
    """
    # clip protects arcsin from round-off pushing |T[0, 2]| slightly past 1 (NaN otherwise)
    beta = np.arcsin(np.clip(T[:, 0, 2], -1.0, 1.0))
    alpha = np.arctan2(-T[:, 1, 2], T[:, 2, 2])
    gamma = np.arctan2(-T[:, 0, 1], T[:, 0, 0])

    ca, sa = np.cos(alpha), np.sin(alpha)
    cb, sb = np.cos(beta), np.sin(beta)

    b = T[:, 1, 3] * ca + T[:, 2, 3] * sa
    c = (T[:, 2, 3] * ca - T[:, 1, 3] * sa) / cb
    a = T[:, 0, 3] - c * sb

    return a, b, c, alpha, beta, gamma

def joint_kinematics_2(bone_axes, rigid_bodies, rotation_quaternion_data, com_data, constraint_info):
    """Compute joint kinematics from the tibia-in-femur and patella-in-femur transforms.

    Each bone's transform relative to the femur is decomposed with
    kinematics_from_transform. Translations (a, b, c) and rotations (alpha, beta,
    gamma, in degrees) are reported relative to the first time step and assigned
    to constraints by keywords in their names:
      * tibiofemoral: "Extension" -> (a, alpha), "Abduction" -> (b, beta),
        "Internal" -> (c, gamma)
      * patellofemoral (name contains "Patellar"): "Extension" or "Flexion" ->
        (a, alpha), "Rotation" -> (b, beta), "Tilt" -> (c, gamma).
        "Flexion" is accepted because get_joint_axes uses it for the same axis.

    Patellar constraints are skipped when the patella transform cannot be built
    (KeyError, e.g. no PTB in the model). Other names are ignored.

    Returns:
        dict {constraint name: (translation, rotation in degrees)}.
    """
    world_in_fem = np.linalg.inv(
        BoneinWorld_Transform(bone_axes, rotation_quaternion_data, rigid_bodies, 'FMB', com_data))

    def kinematics_relative_to_femur(bone_name):
        # (translation, rotation) for each of the three joint axes, zeroed at step 0
        bone_in_world = BoneinWorld_Transform(bone_axes, rotation_quaternion_data, rigid_bodies,
                                              bone_name, com_data)
        a, b, c, alpha, beta, gamma = kinematics_from_transform(np.matmul(world_in_fem, bone_in_world))
        return {
            'a': (a - a[0], np.degrees(alpha - alpha[0])),
            'b': (b - b[0], np.degrees(beta - beta[0])),
            'c': (c - c[0], np.degrees(gamma - gamma[0])),
        }

    tibiofemoral = kinematics_relative_to_femur('TBB')
    try:
        patellofemoral = kinematics_relative_to_femur('PTB')
    except KeyError:
        patellofemoral = None  # patella not present in this model

    all_kinematics = {}
    for joint_info in constraint_info.values():
        name = joint_info.name
        if 'Patellar' in name:
            if patellofemoral is None:
                continue
            if 'Extension' in name or 'Flexion' in name:
                all_kinematics[name] = patellofemoral['a']
            elif 'Rotation' in name:
                all_kinematics[name] = patellofemoral['b']
            elif 'Tilt' in name:
                all_kinematics[name] = patellofemoral['c']
        elif 'Extension' in name:
            all_kinematics[name] = tibiofemoral['a']
        elif 'Abduction' in name:
            all_kinematics[name] = tibiofemoral['b']
        elif 'Internal' in name:
            all_kinematics[name] = tibiofemoral['c']

    return all_kinematics


def joint_kinematics(center_of_mass_data, rotation_quaternion_data, constraint_info, joint_axes):
    """Compute the kinematics of each rigid cylindrical joint from rigid-body motion.

    For each joint, with bodies a and b:
      * translation: displacement of body b's center of mass minus that of body a
        (each measured from its first sample), projected onto the joint axis.
      * rotation: angle in degrees of the relative rotation R_a^T R_b. It is
        multiplied by the rounded cosine between that rotation's Euler axis and the
        joint axis, which gives the sign (+1 or -1 when the axes are aligned).
        NaNs, e.g. from zero rotations, are replaced by 0.

    Returns:
        dict {constraint name: (translation, rotation)}.
    """
    def projected_displacement(axis, com):
        # displacement from the first sample, dotted with the axis row by row
        return np.sum(axis * (com - com[0]), axis=1)

    kinematics = {}
    for joint in constraint_info.values():
        com_a = center_of_mass_data[joint.body_a]
        com_b = center_of_mass_data[joint.body_b]
        axis = joint_axes[joint.name]

        translation = projected_displacement(axis, com_b) - projected_displacement(axis, com_a)

        quat_a = rotation_quaternion_data[joint.body_a]
        quat_b = rotation_quaternion_data[joint.body_b]
        rot_a = rotation_matrix_from_quaternion(quat_a)
        rot_b = rotation_matrix_from_quaternion(quat_b)

        # rotation of body b relative to body a: R_ab = R_wa^T R_wb
        relative = np.matmul(np.swapaxes(rot_a, 1, 2), rot_b)
        angle, euler_axis = euler_axis_angle_from_rotation_matrix(relative)

        cos_theta = np.divide(np.sum(axis * euler_axis, axis=1), np.linalg.norm(euler_axis, axis=1))
        signed_angle = np.multiply(np.degrees(angle), np.around(cos_theta))

        kinematics[joint.name] = (translation, np.nan_to_num(signed_angle))

    return kinematics


def euler_axis_angle_from_quaternion(q):
    """Return the Euler angle and axis of each rotation quaternion in q.

    q is converted to rotation matrices with rotation_matrix_from_quaternion, and
    those matrices are decomposed with euler_axis_angle_from_rotation_matrix.

    Returns:
        (angle, axis), exactly as returned by euler_axis_angle_from_rotation_matrix.
    """
    return euler_axis_angle_from_rotation_matrix(rotation_matrix_from_quaternion(q))


def euler_axis_angle_from_rotation_matrix(R):
    """Calculate the Euler axis and angle of each rotation matrix in R.

    Args:
        R: (n, 3, 3) rotation matrices.

    Returns:
        angle: (n,) rotation angles in radians, in [0, pi].
        axis: (n, 3) rotation axes (unit length for exact rotation matrices).

    Note:
        The axis is only reliable when the angle is not a multiple of pi. At angle 0
        the division is 0/0 and gives NaN (or inf). Near pi it is ill-conditioned.
    """
    # clip so round-off (trace slightly above 3 or below -1) cannot make arccos return NaN
    cos_angle = np.clip((R[:, 0, 0] + R[:, 1, 1] + R[:, 2, 2] - 1) / 2, -1.0, 1.0)
    angle = np.arccos(cos_angle)

    two_sin = 2 * np.sin(angle)
    axis = np.zeros((len(angle), 3))
    axis[:, 0] = np.divide(R[:, 2, 1] - R[:, 1, 2], two_sin)
    axis[:, 1] = np.divide(R[:, 0, 2] - R[:, 2, 0], two_sin)
    axis[:, 2] = np.divide(R[:, 1, 0] - R[:, 0, 1], two_sin)

    return angle, axis


def euler_angles_from_quaternion(q):
    """Return the Euler angles in radians of each quaternion in q.

    Args:
        q: (n, 4) quaternions ordered (x, y, z, w), i.e. the scalar part is last.

    Returns:
        theta1, theta2, theta3: (n,) arrays. These are roll (about x), pitch
        (about y) and yaw (about z) in the usual yaw-pitch-roll (z-y'-x'')
        quaternion-to-Euler conversion.
    """
    qx, qy, qz, qw = q[:, 0], q[:, 1], q[:, 2], q[:, 3]

    theta1 = np.arctan2(2 * (qw * qx + qy * qz), 1 - 2 * (qx ** 2 + qy ** 2))
    # clip: round-off or a slightly non-unit quaternion near gimbal lock can push the
    # arcsin argument just past +/-1, which would otherwise give NaN
    theta2 = np.arcsin(np.clip(2 * (qw * qy - qz * qx), -1.0, 1.0))
    theta3 = np.arctan2(2 * (qw * qz + qx * qy), 1 - 2 * (qy ** 2 + qz ** 2))

    return theta1, theta2, theta3


def rotation_matrix_from_quaternion(q):
    """Calculate rotation matrices from quaternion data.

    Args:
        q: (n, 4) quaternions ordered (x, y, z, w), i.e. the scalar part is last.

    Returns:
        (n, 3, 3) rotation matrices. Quaternions do not have to be unit length:
        the scale s = 1 / |q|^2 makes the result a proper rotation for any
        non-zero q. An all-zero quaternion gives the identity matrix.
    """
    R = np.zeros((len(q), 3, 3))

    qi = q[:, 0]
    qj = q[:, 1]
    qk = q[:, 2]
    qr = q[:, 3]

    # s must be 1/|q|^2 (the code previously used |q|; the two agree only for unit
    # quaternions). Zero quaternions keep s = 0 instead of dividing by zero.
    norm_sq = qi ** 2 + qj ** 2 + qk ** 2 + qr ** 2
    s = np.divide(1.0, norm_sq, out=np.zeros_like(norm_sq, dtype=float), where=norm_sq > 0)

    R[:, 0, 0] = 1 - 2 * s * (qj ** 2 + qk ** 2)
    R[:, 0, 1] = 2 * s * (qi * qj - qk * qr)
    R[:, 0, 2] = 2 * s * (qi * qk + qj * qr)
    R[:, 1, 0] = 2 * s * (qi * qj + qk * qr)
    R[:, 1, 1] = 1 - 2 * s * (qi ** 2 + qk ** 2)
    R[:, 1, 2] = 2 * s * (qj * qk - qi * qr)
    R[:, 2, 0] = 2 * s * (qi * qk - qj * qr)
    R[:, 2, 1] = 2 * s * (qj * qk + qi * qr)
    R[:, 2, 2] = 1 - 2 * s * (qi ** 2 + qj ** 2)

    return R


def GetConstraintInfo(feb_filename):
    """Extract the axis, origin and rigid bodies of each cylindrical joint in the FEBio file.

    The <Constraints> element of the first <Step> is scanned. Every child with a
    "type" attribute is numbered 1, 2, ... in document order; comments and untyped
    elements are not numbered. That number is stored as the joint's joint_number.
    Only "rigid cylindrical joint" constraints that have a name, joint_axis,
    joint_origin, body_a and body_b are returned. An incomplete or unnamed joint is
    skipped but still uses up its number, so the numbering of later joints is not
    shifted.

    Returns:
        dict {constraint name: Cylindrical_Joint}.
    """
    Febio_spec_root = et.parse(feb_filename).getroot()
    LoadingStep_Section = Febio_spec_root.find("Step")
    Constraint_Section = LoadingStep_Section.find("Constraints")

    def parse_vector(constraint, tag):
        # comma-separated floats, e.g. "1,0,0"
        return [float(x) for x in constraint.find(tag).text.split(',')]

    counter = 0
    constraint_info = {}

    for constraint in Constraint_Section:
        constraint_type = constraint.get("type")  # None for comments (they have no attributes)
        if constraint_type is None:
            continue
        counter += 1  # count every typed constraint, including those skipped below

        if constraint_type != "rigid cylindrical joint":
            continue

        try:
            constraint_name = constraint.attrib["name"]
            axis = parse_vector(constraint, "joint_axis")
            origin = parse_vector(constraint, "joint_origin")
            body_a_id = int(constraint.find("body_a").text)
            body_b_id = int(constraint.find("body_b").text)
        except (KeyError, AttributeError, ValueError):
            continue  # unnamed or incomplete joint: cannot be stored, but keeps its number

        constraint_info[constraint_name] = Cylindrical_Joint(
            constraint_name, axis, origin, body_a_id, body_b_id, counter)

    return constraint_info


def get_bone_axes(model_properties_xml):
    """Find the bone coordinate-system axes in the model properties file.

    The <Landmarks> element is read for X?_axis, Y?_axis and Z?_axis entries, each
    a comma-separated vector, with ? = 't' (tibia), 'f' (femur) or 'p' (patella).
    A bone is left out when any of its axes is missing or cannot be parsed.

    Returns:
        dict {'TBB' | 'FMB' | 'PTB': [x_axis, y_axis, z_axis]}, each axis a list of floats.
    """
    landmarks = et.parse(model_properties_xml).getroot().find('Landmarks')

    def extract_axes(bone_first_letter):
        axes = [landmarks.find('{}{}_axis'.format(component, bone_first_letter)).text.split(',')
                for component in ('X', 'Y', 'Z')]
        return [[float(i) for i in a] for a in axes]

    bone_axes = {}
    for rb_name, bone_first_letter in (('TBB', 't'), ('FMB', 'f'), ('PTB', 'p')):
        try:
            bone_axes[rb_name] = extract_axes(bone_first_letter)
        except (AttributeError, ValueError):
            # find() returned None (bone or Landmarks absent) or an axis is not numeric
            pass

    return bone_axes


def BoneinWorld_Transform(all_bone_axes, rotation_quaternion_data, rigid_bodies, bone_name, com_data):
    """Build the homogeneous transform of a bone's coordinate system in world coordinates.

    The rows of all_bone_axes[bone_name] are the bone's initial x, y and z axes in
    world coordinates. At each time step they are rotated by the rigid body's
    rotation and placed as the columns of the 3x3 rotation block. The rigid body's
    center of mass is the translation column. The result maps bone coordinates to
    world coordinates; callers invert it to map world coordinates to bone coordinates.

    Returns:
        (n, 4, 4) array with one transform per time step.
    """
    initial_axes = all_bone_axes[bone_name]
    body_id = rigid_bodies[bone_name].id
    origin = com_data[body_id]
    rotations = rotation_matrix_from_quaternion(rotation_quaternion_data[body_id])

    # transposing puts the bone's axes in the columns before rotating them
    axes_in_world = np.matmul(rotations, np.transpose(initial_axes))

    transforms = np.tile(np.eye(4), (len(axes_in_world), 1, 1))
    transforms[:, 0:3, 0:3] = axes_in_world
    transforms[:, 0:3, 3] = origin
    return transforms


def Plot3DMotion(xyz_data, time_steps, figure_title, png_file_name, units=None):
    """Plot the X, Y and Z components of a trajectory against time and save a PNG.

    The three components are drawn as stacked subplots. If plotting fails, the PNG
    is skipped but the data is still returned so the caller can write a CSV. The
    figure is always closed.

    Args:
        xyz_data: (n, 3) array of positions.
        time_steps: (n,) times.
        figure_title: figure super-title.
        png_file_name: output image path.
        units: unit label for axis labels and column names; defaults to 'mm'.

    Returns:
        list of (column name, values) pairs: time_steps, then X, Y and Z.
    """
    u = units if units is not None else 'mm'

    data_list = [("time_steps", time_steps)]
    for column, label in enumerate(('X', 'Y', 'Z')):
        data_list.append((label + " [" + u + "]", xyz_data[:, column]))

    fig = None
    try:
        fig = plt.figure(figsize=(6.4, 6.4))
        fig.suptitle(figure_title)
        for column, (name, values) in enumerate(data_list[1:]):
            plt.subplot(311 + column)
            plt.plot(time_steps, values)
            plt.xlabel('Time')
            plt.ylabel(name)
        plt.savefig(png_file_name)
    except Exception:
        print("failed creating images, trying to create csv")
    finally:
        # close the figure even when plotting failed, so repeated failures do not leak figures
        if fig is not None:
            plt.close(fig)

    return data_list


def _plot_joint_kinematics(title, translations, rotations, time_steps, png_file_name, csv_file_name):
    """Plot one joint's translations and rotations side by side, save the PNG, and always save the CSV."""
    data_list = [("time_steps", time_steps)]
    for label, trans in translations.items():
        data_list.append((label + '_Translation [mm]', trans))
    for label, rot in rotations.items():
        data_list.append((label + '_Rotation [deg]', rot))

    fig = None
    try:
        fig = plt.figure(figsize=(12.8, 7.2))
        fig.suptitle(title)
        panels = ((translations, 'Translations', 'mm'), (rotations, 'Rotations', 'deg'))
        for position, (series, panel_title, y_label) in enumerate(panels, start=1):
            plt.subplot(1, 2, position)
            for label, values in series.items():
                plt.plot(time_steps, values, label=label + '_axis')
            plt.legend(loc='upper left')
            plt.title(panel_title)
            plt.xlabel('Time')
            plt.ylabel(y_label)
        plt.savefig(png_file_name)
    except Exception:
        print('failed creating images, trying to create csv file only')
    finally:
        # always release the figure, including when plotting failed
        if fig is not None:
            plt.close(fig)

    save_to_csv(data_list, csv_file_name)


def ProcessAndPlotJointKinematics(com_data, rot_quat_data, time_steps, constraint_info, joint_axes, bone_axes, rigid_bodies):
    """Compute the joint kinematics of all constraints and save them as PNG and CSV files.

    Kinematics come from joint_kinematics_2, which uses the femur-to-bone transforms.
    joint_kinematics, based on projections onto the joint axes, is an alternative that
    would use joint_axes; that parameter is currently unused. Constraints whose name
    contains "Patellar" go to Patellofemoral_Kinematics.png/.csv, all others to
    Tibiofemoral_Kinematics.png/.csv. A CSV is written even if plotting fails.
    """
    all_kinematics = joint_kinematics_2(bone_axes, rigid_bodies, rot_quat_data, com_data, constraint_info)

    tibiofemoral_translations, tibiofemoral_rotations = {}, {}
    patellofemoral_translations, patellofemoral_rotations = {}, {}
    for name, (translation, rotation) in all_kinematics.items():
        if "Patellar" in name:
            patellofemoral_translations[name] = translation
            patellofemoral_rotations[name] = rotation
        else:
            tibiofemoral_translations[name] = translation
            tibiofemoral_rotations[name] = rotation

    _plot_joint_kinematics("Kinematics of Tibiofemoral cylindrical joints",
                           tibiofemoral_translations, tibiofemoral_rotations, time_steps,
                           'Tibiofemoral_Kinematics.png', "Tibiofemoral_Kinematics.csv")
    _plot_joint_kinematics("Kinematics of patellofemoral cylindrical joints",
                           patellofemoral_translations, patellofemoral_rotations, time_steps,
                           'Patellofemoral_Kinematics.png', "Patellofemoral_Kinematics.csv")


def ProcessAndPlotTranslations(rigid_bodies, com_data, rotation_quaternion_data, time_steps, bone_axes):
    """Plot and save the translation of the tibia and patella origins relative to the femur.

    Each bone's origin is expressed in the femoral coordinate system and measured
    from the first time step. The result is passed to Plot3DMotion and then written
    with save_to_csv. A bone that raises KeyError (not in the model) is skipped.
    """
    world_in_femur = np.linalg.inv(
        BoneinWorld_Transform(bone_axes, rotation_quaternion_data, rigid_bodies, 'FMB', com_data))

    outputs = (
        ('TBB',
         'Translation of tibia origin relative to femur origin \n in femoral coordinate system',
         'Tibia_Translation.png',
         "Tibia_Translation.csv"),
        ('PTB',
         'Translation of patella origin relative to femur origin \n in femoral coordinate system',
         'Patella_Translation.png',
         "Patella_Translation.csv"),
    )

    for bone_name, title, png_file, csv_file in outputs:
        try:
            bone_in_world = BoneinWorld_Transform(bone_axes, rotation_quaternion_data, rigid_bodies,
                                                  bone_name, com_data)
            # the origin position is the translation column of the bone-in-femur transform
            origin = np.matmul(world_in_femur, bone_in_world)[:, 0:3, 3]
            data_list = Plot3DMotion(origin - origin[0], time_steps, title, png_file)
            save_to_csv(data_list, csv_file)
        except KeyError:
            pass


def ProcessAndPlotBoneKinetics(rigid_bodies, rigid_moments, rigid_forces, time_steps, bone_axes, com_data):
    """Plot the bone forces and moments and save them as PNG and CSV files.

    Four cases are written, each as "<case>.png" (when plotting succeeds) and "<case>.csv":
      * Tibia_and_Fibula_Kinetics_in_TibiaCS / _in_ImageCS: the tibia load plus the
        fibula load transferred to the tibia center of mass
      * Femur_Kinetics_in_TibiaCS / _in_ImageCS: the femur load
    The "ImageCS" cases use the loads as read (presumably image/world coordinates).
    The "TibiaCS" cases express the same loads along the tibia axes in bone_axes["TBB"].

    Requires rigid bodies named "TBB", "FMB" and "FBB".
    """
    tibia_id = rigid_bodies["TBB"].id
    femur_id = rigid_bodies["FMB"].id
    fibula_id = rigid_bodies["FBB"].id

    force_tibia, moments_tibia = rigid_forces[tibia_id], rigid_moments[tibia_id]
    force_fibula, moments_fibula = rigid_forces[fibula_id], rigid_moments[fibula_id]
    force_femur, moments_femur = rigid_forces[femur_id], rigid_moments[femur_id]

    # transfer the fibula load to the tibia center of mass: M += r x F, where r points
    # from the tibia to the fibula COM. Tibia and fibula are fixed together, so r should
    # not change, but it is evaluated per step in case a future model lets them move.
    vec_fbo = com_data[fibula_id] - com_data[tibia_id]
    moments_fibula_tibia = np.cross(vec_fbo, force_fibula) + moments_fibula + moments_tibia
    force_fibula_tibia = force_fibula + force_tibia

    # converts image-CS vectors to tibia-CS components (inverse of the matrix whose columns are the tibia axes)
    T_tib_in_image = np.linalg.inv(np.asarray(bone_axes["TBB"]).T)

    def to_tibia_cs(vectors):
        # apply T_tib_in_image to every row of an (n, 3) array
        return np.matmul(vectors, T_tib_in_image.T)

    model_names = {"Tibia_and_Fibula_Kinetics_in_TibiaCS": (to_tibia_cs(force_fibula_tibia),
                                                            to_tibia_cs(moments_fibula_tibia)),
                   "Tibia_and_Fibula_Kinetics_in_ImageCS": (force_fibula_tibia, moments_fibula_tibia),
                   "Femur_Kinetics_in_TibiaCS": (to_tibia_cs(force_femur), to_tibia_cs(moments_femur)),
                   "Femur_Kinetics_in_ImageCS": (force_femur, moments_femur)}

    force_axes = {'x_load': 0, 'y_load': 1, 'z_load': 2}
    moment_axes = {'x_moment': 0, 'y_moment': 1, 'z_moment': 2}

    for title, (forces, moments) in model_names.items():
        data_list = [("time_steps", time_steps)]
        data_list += [(ax_name, forces[:, idx]) for ax_name, idx in force_axes.items()]
        data_list += [(ax_name, moments[:, idx]) for ax_name, idx in moment_axes.items()]

        fig = None
        try:
            fig = plt.figure(figsize=(12.8, 7.2))
            fig.suptitle(title)
            panels = ((forces, force_axes, 'Force', 'Force [N]'),
                      (moments, moment_axes, 'Moment', 'Moment [Nmm]'))
            for position, (values, axes_map, panel_title, y_label) in enumerate(panels, start=1):
                ax = fig.add_subplot(1, 2, position)
                plt.title(panel_title)
                plt.xlabel('Time')
                plt.ylabel(y_label)
                for ax_name, idx in axes_map.items():
                    ax.plot(time_steps, values[:, idx], label=ax_name)
                plt.legend(loc="upper left")
            plt.savefig(title + '.png')
        except Exception:
            print('failed creating images, trying to create csv file only')
        finally:
            # always release the figure, including when plotting failed
            if fig is not None:
                plt.close(fig)

        save_to_csv(data_list, title + ".csv")


def _plot_joint_kinetics(axes_forces, axes_moments, time_steps):
    """Draw the axis-projected forces and moments side by side and save Tibiofemoral_Kinetics.png.

    The figure is closed even when drawing or saving fails, so a failed
    attempt does not leave an open figure behind.
    """
    fig = plt.figure(figsize=(12.8, 7.2))
    try:
        fig.suptitle("Tibiofemoral_Constraint_Kinetics")

        ax = fig.add_subplot(121)
        plt.title('Force')
        plt.xlabel('Time')
        plt.ylabel('Force [N]')
        for ax_name, force in axes_forces.items():
            ax.plot(time_steps, force, label=ax_name + '_axis')
        plt.legend(loc="upper left")

        ax = fig.add_subplot(122)
        plt.title('Moment')
        plt.xlabel('Time')
        plt.ylabel('Moment [Nmm]')
        for ax_name, moment in axes_moments.items():
            ax.plot(time_steps, moment, label=ax_name + '_axis')
        plt.legend(loc="upper left")

        plt.savefig('Tibiofemoral_Kinetics.png')
    finally:
        plt.close(fig)


def ProcessAndPlotJointKinetics(rigid_connector_moments, rigid_connector_forces, constraint_info, time_steps, joint_axes):
    """Plot and save the tibiofemoral constraint forces and moments projected on each joint axis.

    For every constraint in ``constraint_info`` whose name does not contain
    "Patellar", the rigid-connector force and moment stored under that joint's
    ``joint_number`` are projected onto the joint's axis at every time step
    (row-wise dot product). The results are always written to
    ``Tibiofemoral_Kinetics.csv`` in the current directory. When plotting
    succeeds, they are also drawn to ``Tibiofemoral_Kinetics.png``.

    Parameters
    ----------
    rigid_connector_moments, rigid_connector_forces : indexable by joint number
        Per-connector arrays, presumably (n_time_steps, 3); TODO confirm.
    constraint_info : dict
        {constraint name: Cylindrical_Joint}; ``joint_number`` selects the connector.
    time_steps : array-like
        Time value of each row of the connector data.
    joint_axes : dict
        {constraint name: axis array}, which must broadcast against the force/moment arrays.
    """
    axes_forces = {}
    axes_moments = {}

    # Only consumed by the disabled image-coordinate-system export (commented out below).
    connector_forces = {}
    connector_moments = {}

    for joint_name, joint_info in constraint_info.items():
        if "Patellar" in joint_name:
            # patellar constraints are not part of the tibiofemoral kinetics
            continue

        joint_num = joint_info.joint_number
        force = rigid_connector_forces[joint_num]
        moment = rigid_connector_moments[joint_num]
        joint_axis = joint_axes[joint_name]

        # Row-wise dot product: the component along the joint axis at each time step.
        axes_forces[joint_name] = np.sum(force * joint_axis, axis=1)
        axes_moments[joint_name] = np.sum(moment * joint_axis, axis=1)

        # # net forces and moments in image coordinate system
        # connector_forces[joint_name] = force
        # connector_moments[joint_name] = moment

    # The CSV columns do not depend on whether plotting works, so build them once.
    data_list = [("time_steps", time_steps)]
    for ax_name, force in axes_forces.items():
        data_list.append((ax_name + '_Force [N]', force))
    for ax_name, moment in axes_moments.items():
        data_list.append((ax_name + '_Moment [Nmm]', moment))

    try:
        _plot_joint_kinetics(axes_forces, axes_moments, time_steps)
    except Exception as err:
        # Plotting is best-effort (e.g. no usable matplotlib backend); the CSV is still written.
        print('failed creating images, trying to create csv (%s)' % err)

    save_to_csv(data_list, "Tibiofemoral_Kinetics.csv")

    # forces and moments on each connector in image coordinate system

    # for joint_name, force in iter(connector_forces.items()):
    #     moment = connector_moments[joint_name]
    #     axes = ['x', 'y', 'z']
    #     try:
    #         # prepare the figure
    #         fig = plt.figure(figsize=(12.8,7.2))
    #         fig.suptitle(joint_name+"_Connector_Kinetics_ImageCS")
    #         data_list = [("time_steps", time_steps)]
    #
    #         ax = fig.add_subplot(121)
    #         plt.title('Force')
    #         plt.xlabel('Time')
    #         plt.ylabel('Force [N]')
    #
    #         for i in range(3):
    #             ax.plot(time_steps, force[:,i], label=axes[i])
    #             data_list.append((axes[i] + '_Force [N]', force[:,i]))
    #         plt.legend(loc="upper left")
    #
    #         ax = fig.add_subplot(122)
    #         plt.title('Moment')
    #         plt.xlabel('Time')
    #         plt.ylabel('Moment [Nmm]')
    #
    #         for i in range(3):
    #             ax.plot(time_steps, moment[:,i], label = axes[i])
    #             data_list.append((axes[i] + '_Moment [Nmm]', moment[:,i]))
    #         plt.legend(loc="upper left")
    #
    #         plt.savefig(joint_name+'_Connector_Kinetics_ImageCS.png')
    #         #plt.show()
    #
    #         save_to_csv(data_list, joint_name+"_Connector_Kinetics_ImageCS.csv")
    #     except:
    #         print('failed creating images, trying to create csv')
    #         data_list = [("time_steps", time_steps)]
    #         for i in range(3):
    #             data_list.append((axes[i] + '_Force [N]', force[:,i]))
    #         for i in range(3):
    #             data_list.append((axes[i] + '_Moment [Nmm]', moment[:,i]))
    #         save_to_csv(data_list, joint_name+"_Connector_Kinetics_ImageCS.csv")

def ProcessAndPlotFemurKinematics(rigid_bodies, rotation_quaternion_data, time_steps, bone_axes, com_data):
    """Plot and save the femur's motion in the image coordinate system.

    Writes two figure/CSV pairs to the current directory:

    * ``Femur_in_Image_Translation``: displacement of the femur origin
      relative to its position at the first time step.
    * ``Femur_in_Image_Rotation``: the three Euler angles obtained from the
      femur rotation quaternion, plotted with units 'rad'.
    """
    femur_key = 'FMB'

    # transform of the femur in the image coordinate system at each time step
    femur_to_image = BoneinWorld_Transform(bone_axes, rotation_quaternion_data, rigid_bodies, femur_key, com_data)

    # the quaternion holds the rotation relative to the initial position
    femur_quaternion = rotation_quaternion_data[rigid_bodies[femur_key].id]
    angle_a, angle_b, angle_c = euler_angles_from_quaternion(femur_quaternion)

    # translation column of each transform, shifted so the first time step is the origin
    origin_track = femur_to_image[:, 0:3, 3]
    displacement = origin_track - origin_track[0]

    translation_columns = Plot3DMotion(
        displacement,
        time_steps,
        'Translation of femur origin in image coordinate system',
        'Femur_in_Image_Translation.png',
    )
    save_to_csv(translation_columns, "Femur_in_Image_Translation.csv")

    angle_matrix = np.array([angle_a, angle_b, angle_c]).T
    rotation_columns = Plot3DMotion(
        angle_matrix,
        time_steps,
        'Rotations of Femur in image coordinate system',
        'Femur_in_Image_Rotation.png',
        units='rad',
    )
    save_to_csv(rotation_columns, "Femur_in_Image_Rotation.csv")

def save_to_csv(data_list, title):
    """Write named columns to a comma-separated file with a single header line.

    Parameters
    ----------
    data_list : list of (str, array-like)
        ``(header, column_values)`` pairs, one per column, in output order.
        Every column must fit the length of the first one.
    title : str
        Path of the file to write.

    The header is written through ``np.savetxt`` and therefore starts with
    ``'# '``, which lets ``np.loadtxt`` skip it when reading the file back.

    Raises
    ------
    ValueError
        If ``data_list`` is empty, or a column cannot be stored in a column of
        the first column's length.
    """
    if not data_list:
        raise ValueError('save_to_csv needs at least one column to write ' + str(title))

    headers = [name for name, _ in data_list]
    n_rows = len(data_list[0][1])
    all_data = np.zeros((n_rows, len(data_list)))
    for i, (_, column) in enumerate(data_list):
        all_data[:, i] = column

    # Join instead of appending ',' after every name, so the header has exactly
    # as many fields as each data row (no trailing empty column).
    np.savetxt(title, all_data, delimiter=",", header=",".join(headers))


def MakeGraphs(log_filename, model_properties_xml, folder_name=None):
    """Post-process one FEBio run into kinematics/kinetics graphs and CSV files.

    Changes the working directory to the folder containing ``log_filename``.
    ``model_properties_xml`` and the FEBio input file found via the log are
    therefore resolved relative to that folder. The function then creates the
    output folder there if needed and enters it. All results are written into
    the output folder, and the working directory is left there on return.

    Parameters
    ----------
    log_filename : str
        FEBio log file, either absolute or relative to the current directory.
    model_properties_xml : str
        Model properties file providing the bone axes.
    folder_name : str, optional
        Name of the output folder; defaults to ``'Processed_Results'``.
    """
    log_dir = os.path.dirname(log_filename)
    if log_dir:  # an empty dirname means we are already in the log's directory
        os.chdir(log_dir)
        # Every reader below opens the log from inside log_dir. Keeping a relative
        # path like 'run/model.log' would make them look for 'run/run/model.log'.
        log_filename = os.path.basename(log_filename)

    # get the febio input filename
    febio_input_filename = find_input_filename(log_filename)

    # get the axes of the bones from the model properties file
    bone_axes = get_bone_axes(model_properties_xml)

    # constraint names matching the numbering in the log file: {name: Cylindrical_Joint}
    constraint_info = GetConstraintInfo(febio_input_filename)

    # parse the logfile for the rigid body info, time steps and data
    rigid_bodies = find_rigid_bodies(log_filename)
    all_data, time_steps = collect_data(log_filename)
    all_data, time_steps = add_initial_values(all_data, time_steps, rigid_bodies)

    com_data = all_data['center_of_mass']
    rot_quat_data = all_data['rotation_quaternion']
    rigid_connector_moments = all_data['Rigid_Connector_Moment']
    rigid_connector_forces = all_data['Rigid_Connector_Force']
    # only needed by the disabled bone-kinetics step below
    rigid_moments = all_data['Reaction_Torques']
    rigid_forces = all_data['Reaction_Forces']

    joint_axes = get_joint_axes(rigid_bodies, rot_quat_data, constraint_info)

    # folder that will contain all the graphs and csv files
    output_dir = folder_name if folder_name is not None else 'Processed_Results'
    if not os.path.isdir(output_dir):
        os.mkdir(output_dir)
    os.chdir(output_dir)

    # Joint Kinematics
    ProcessAndPlotJointKinematics(com_data, rot_quat_data, time_steps, constraint_info, joint_axes, bone_axes, rigid_bodies)

    # Tibia and Patella Translations
    ProcessAndPlotTranslations(rigid_bodies, com_data, rot_quat_data, time_steps, bone_axes)

    # # kinematics as rotations and translations of femur
    # ProcessAndPlotFemurKinematics(rigid_bodies, rot_quat_data, time_steps, bone_axes, com_data)

    # Constraint Kinetics
    ProcessAndPlotJointKinetics(rigid_connector_moments, rigid_connector_forces, constraint_info, time_steps, joint_axes)

    # # Tibia and Femur kinetics
    # ProcessAndPlotBoneKinetics(rigid_bodies, rigid_moments, rigid_forces, time_steps, bone_axes, com_data)

    print('\n')
    # We are inside the output folder now, so the cwd is its real location
    # (joining an empty log_dir with os.sep used to print a root path).
    print('Graphs were created in ' + os.getcwd())

def run_all_in_file(xml_file):
    """Run MakeGraphs for every model listed in an experiment-to-model XML file.

    The file paths come from ``general_files/febio_file`` and
    ``general_files/model_properties_file``. Each model comes from the ``name``
    attribute of an element under ``Models``; XML comments are skipped. For
    each model, ``<name>.log`` is processed from the directory of the FEBio
    file, and the results go to ``Processed_Results_<name>``.
    """
    file_info = et.parse(xml_file).getroot()

    gen_files = file_info.find("general_files")
    feb_file = gen_files.find("febio_file").text
    mod_props_file = gen_files.find("model_properties_file").text

    # Made absolute for two reasons: it stays valid after MakeGraphs moves the cwd
    # into its output folder, and a bare file name (empty dirname) becomes the
    # current directory instead of making os.chdir('') fail.
    model_dir = os.path.abspath(os.path.dirname(feb_file))

    # Collect every model name first, so a malformed entry fails before any processing.
    model_names = [mod.attrib["name"] for mod in file_info.find("Models")
                   if mod.tag is not et.Comment]

    for model_name in model_names:
        os.chdir(model_dir)  # re-enter each time; MakeGraphs leaves us in its output folder
        # Use the model name directly (not '<log>'.split('.')[0]) so names with dots stay intact.
        MakeGraphs(model_name + '.log', mod_props_file, 'Processed_Results_' + model_name)

if __name__ == '__main__':
    # usage: python <this script> LOG_FILE MODEL_PROPERTIES_XML [OUTPUT_FOLDER]
    cli_args = sys.argv[1:]
    if len(cli_args) not in (2, 3):
        # Clear usage message instead of a TypeError traceback from MakeGraphs(*args).
        sys.exit('usage: %s LOG_FILE MODEL_PROPERTIES_XML [OUTPUT_FOLDER]'
                 % os.path.basename(sys.argv[0]))

    MakeGraphs(*cli_args)

    # using the exp to mod xml file to run on a bunch of models:
    # exp_to_mod = "C:\\Users\schwara2\Documents\Open_Knees\du02_calibration\CustomizedFullModels\ExperimentalLoading03\Exp_to_Mod.xml"
    # run_all_in_file(exp_to_mod)

    # log_filename = "C:\\Users\schwara2\Documents\Open_Knees\du02_calibration\InSituStrain\\test\FMB_ACL_TBB_01.log"
    # model_properties_xml = "C:\\Users\schwara2\Documents\Open_Knees\du02_calibration\InSituStrain\\test\ModelProperties.xml"
    #
    # MakeGraphs(log_filename, model_properties_xml)



