# Python 3 version of the FebCustomization script. As of 02/06/2020, any
# future changes will be made in this version of the script only (FebCustomization_p3.py).

from lxml import etree as et
import sys
import datetime
import copy
import os

class Mooney_Rivlin:
    """Parameter holder for a FEBio "Mooney-Rivlin" material.

    Values are stored exactly as given (``get_material_properties`` passes raw XML text).
    Attribute assignment order matters: ``update_materials`` writes the entries of
    ``__dict__`` out as XML subelements in insertion order, so ``type`` comes first,
    followed by density, c1, c2 and k.
    """

    def __init__(self, density, c1, c2, k):
        self.type = "Mooney-Rivlin"
        # tuple assignment binds left to right, keeping the attribute order above
        self.density, self.c1, self.c2, self.k = density, c1, c2, k


class Transversely_Isotropic:
    """Parameter holder for a FEBio "trans iso Mooney-Rivlin" material.

    ``initial_stretch`` is optional; ``update_materials`` treats both ``None`` and the
    string ``'None'`` as "no prestrain". Attribute assignment order matters because
    ``update_materials`` serialises ``__dict__`` in insertion order.
    """

    def __init__(self, density, c1, c2, k, c3, c4, c5, lam_max, initial_stretch=None):
        self.type = "trans iso Mooney-Rivlin"
        self.density, self.c1, self.c2, self.k = density, c1, c2, k
        self.c3, self.c4, self.c5, self.lam_max = c3, c4, c5, lam_max
        self.initial_stretch = initial_stretch


class Cylindrical_Joint:
    """Description of a FEBio rigid cylindrical joint between two rigid bodies.

    ``body_a``/``body_b`` are material names (looked up in the material info dict when the
    joint is written). ``prescribed_translation``/``prescribed_rotation`` stay ``None``
    when the corresponding degree of freedom is free.
    """

    def __init__(self, name, origin, axis, body_a, body_b, prescribed_translation=None, prescribed_rotation=None):
        self.name, self.origin = name, origin
        self.body_a, self.body_b = body_a, body_b
        self.axis = axis
        self.prescribed_translation, self.prescribed_rotation = prescribed_translation, prescribed_rotation


class My_Material:
    """Name, numeric id and FEBio type of a material written to the model.

    ``mat_id`` may be given as a string (as read from an XML attribute); it is stored as int.
    """

    def __init__(self, mat_name, mat_id, mat_type):
        self.name, self.id, self.type = mat_name, int(mat_id), mat_type


class Linear_Spring:
    """A spring between two bodies with stiffness ``k``.

    ``insertion_a``/``insertion_b`` are the insertion values taken from the anatomical
    landmarks (``add_discrete_elements`` writes them as ``"a,b"`` discrete element text).
    """

    def __init__(self, name, body_a, body_b, insertion_a, insertion_b, k):
        self.name = name
        self.body_a, self.body_b = body_a, body_b
        self.insertion_a, self.insertion_b = insertion_a, insertion_b
        self.k = k


def myElement(tag, parent=None, text=None, attrib=None, **extra):
    """Create an XML element, optionally set its text and attach it to a parent.

    Args:
        tag: element tag.
        parent: element to append the new element to; ``None`` creates a detached element.
        text: element text; only set when truthy (``None``/``''`` leave it unset).
        attrib: optional dict of attributes (not modified).
        **extra: further attributes given as keyword arguments.

    Returns:
        The new element.
    """
    # a fresh dict per call instead of a shared mutable default argument
    e = et.Element(tag, {} if attrib is None else attrib, **extra)
    if text:
        e.text = text
    if parent is not None:
        parent.append(e)
    return e


def add_subelements_from_dict(parent, dictionary, clear_out=False, exclude_these=None):
    """Append one subelement per ``{tag: value}`` entry of ``dictionary`` to ``parent``.

    Each value is converted with ``str()`` and used as the subelement's text.

    Args:
        parent: element receiving the new subelements.
        dictionary: mapping of tag -> value; insertion order determines element order.
        clear_out: if True, remove every existing child of ``parent`` first.
        exclude_these: iterable of tags to skip (default: none).
    """
    excluded = () if exclude_these is None else exclude_these

    if clear_out:
        # iterate over a snapshot: removing children from the live element while
        # iterating it is not safe (xml.etree skips every other child)
        for subelement in list(parent):
            parent.remove(subelement)

    for tag, text in dictionary.items():
        if tag not in excluded:
            myElement(tag, parent, text=str(text))


def get_joint_info(AL, Right_or_Left):
    """Build the knee's rigid cylindrical joints and the positions of the imaginary bodies.

    Args:
        AL: dict of anatomical landmarks/axes. The tibiofemoral joints need 'TBO', 'FMO',
            'Xf_axis', 'Zt_axis', 'Ftf_Xf_intersect', 'Ftf_axis'; the patellofemoral joints
            need 'PTO', 'FMO', 'Xf_axis', 'Zp_axis', 'Fpf_Xf_intersect', 'Fpf_axis'.
        Right_or_Left: 'R' for a right knee; any other value uses left-knee joint names.

    Returns:
        (Joints, imaginary_bodies_data): a dict of joint name -> ``Cylindrical_Joint`` and
        a dict of imaginary body name -> position. 'QSO' is added at AL['QAT'] when present.
    """
    flex_ext_name = "Extension_Flexion"
    pat_flex_name = "Patellar_Extension_Flexion"
    # the joint names depend on the knee side
    if Right_or_Left == 'R':
        abd_add_name, int_ext_name = "Adduction_Abduction", "Internal_External"
        pat_rot_name, pat_tilt_name = "Patellar_Medial_Rotation", "Patellar_Medial_Tilt"
    else:
        abd_add_name, int_ext_name = "Abduction_Adduction", "External_Internal"
        pat_rot_name, pat_tilt_name = "Patellar_Lateral_Rotation", "Patellar_Lateral_Tilt"

    Joints = {}
    imaginary_bodies_data = {}

    # Tibiofemoral joint: femur -> TFFO -> TFTO -> tibia
    if 'TBO' in AL and 'FMO' in AL:
        imaginary_bodies_data['TFTO'] = AL['TBO']
        imaginary_bodies_data['TFFO'] = AL['FMO']
        tibiofemoral = (
            Cylindrical_Joint(flex_ext_name, origin=AL['FMO'], axis=AL['Xf_axis'],
                              body_a='FMB', body_b='TFFO'),
            Cylindrical_Joint(int_ext_name, origin=AL['TBO'], axis=AL['Zt_axis'],
                              body_a='TFTO', body_b='TBB'),
            Cylindrical_Joint(abd_add_name, origin=AL['Ftf_Xf_intersect'], axis=AL['Ftf_axis'],
                              body_a='TFFO', body_b='TFTO'),
        )
        for joint in tibiofemoral:
            Joints[joint.name] = joint
    else:
        print('\n Tibiofemoral Joint not included in the model')

    # Patellofemoral joint: femur -> PFFO -> PFPO -> patella
    if 'PTO' in AL and 'FMO' in AL:
        imaginary_bodies_data['PFPO'] = AL['PTO']
        imaginary_bodies_data['PFFO'] = AL['FMO']
        patellofemoral = (
            Cylindrical_Joint(pat_flex_name, origin=AL['FMO'], axis=AL['Xf_axis'],
                              body_a='FMB', body_b='PFFO'),
            Cylindrical_Joint(pat_tilt_name, origin=AL['PTO'], axis=AL['Zp_axis'],
                              body_a='PFPO', body_b='PTB'),
            Cylindrical_Joint(pat_rot_name, origin=AL['Fpf_Xf_intersect'], axis=AL['Fpf_axis'],
                              body_a='PFFO', body_b='PFPO'),
        )
        for joint in patellofemoral:
            Joints[joint.name] = joint
    else:
        print('\n Patellofemoral Joint not included in the model')

    # the quadriceps origin becomes an imaginary body for the quadriceps spring
    if 'QAT' in AL:
        imaginary_bodies_data['QSO'] = AL['QAT']

    return Joints, imaginary_bodies_data


def get_connector_info(AL):
    """Create the MPFL/LPFL ``Linear_Spring`` objects from the anatomical landmarks.

    Returns a dict ``{"MPFL_01", "MPFL_02", "LPFL_01", "LPFL_02"} -> Linear_Spring``
    (femur 'FMB' to patella 'PTB'), or an empty dict when AL has no 'MPFL-P_01'.
    """
    if 'MPFL-P_01' not in AL:
        print('\n MPFL, LPFL connectors are not included in the model')
        return {}

    stiffness = {'MPFL': 100.0 / 2, 'LPFL': 16.0 / 2}
    connectors = {}
    for ligament in ('MPFL', 'LPFL'):
        for index in ('01', '02'):
            spring_name = '{}_{}'.format(ligament, index)
            femoral_insertion = AL['{}-F_{}'.format(ligament, index)]
            patellar_insertion = AL['{}-P_{}'.format(ligament, index)]
            connectors[spring_name] = Linear_Spring(spring_name, 'FMB', 'PTB', femoral_insertion,
                                                    patellar_insertion, stiffness[ligament])
    return connectors


def get_material_properties(ModelProperties_xml):
    """Read material parameters from a model-properties XML file.

    Every ``<material>`` directly under the file's ``<Material>`` section whose ``type``
    attribute is ``"trans iso Mooney-Rivlin"`` or ``"Mooney-Rivlin"`` becomes a
    ``Transversely_Isotropic`` or ``Mooney_Rivlin`` object. Parameter values are the raw
    element text (a string, or ``None`` when the parameter element is absent).

    Args:
        ModelProperties_xml: path or file object accepted by ``et.parse``.

    Returns:
        dict mapping each material's ``name`` attribute to its parameter object. Materials
        with no ``name`` attribute, or with a parameter element repeated, are skipped.
    """
    # NOTE: the parameter classes are a leftover from when properties were hard coded;
    # a plain dict per material would work just as well.
    ModelProperties_root = et.parse(ModelProperties_xml).getroot()
    Material = get_section("Material", ModelProperties_root)
    material_properties = {}

    def _materials_of_type(material_type):
        # get_section returns None / one element / a list; normalise to a list.
        # create_new=False avoids inserting a placeholder <material> when a type is absent.
        found = get_section("material", Material, use_other_attribute="type",
                            attribute_value=material_type, create_new=False)
        if found is None:
            return []
        return found if isinstance(found, list) else [found]

    def _param(mat, tag):
        # a missing parameter is created empty by get_section (text None); a repeated one
        # comes back as a list and raises AttributeError on .text
        return get_section(tag, mat).text

    def store_to_trans_class(mat):
        try:
            mat_name = mat.attrib["name"]
            density = _param(mat, "density")
            c1 = _param(mat, "c1")
            c2 = _param(mat, "c2")
            c3 = _param(mat, "c3")
            c4 = _param(mat, "c4")
            c5 = _param(mat, "c5")
            k = _param(mat, "k")
            lam_max = _param(mat, "lam_max")
            initial_stretch = _param(mat, "initial_stretch")
        except (KeyError, AttributeError):  # unnamed material or repeated parameter
            return
        material_properties[mat_name] = Transversely_Isotropic(density, c1, c2, k, c3, c4, c5,
                                                               lam_max, initial_stretch)

    def store_to_mooney_class(mat):
        try:
            mat_name = mat.attrib["name"]
            density = _param(mat, "density")
            c1 = _param(mat, "c1")
            c2 = _param(mat, "c2")
            k = _param(mat, "k")
        except (KeyError, AttributeError):  # unnamed material or repeated parameter
            return
        material_properties[mat_name] = Mooney_Rivlin(density, c1, c2, k)

    for mat in _materials_of_type("trans iso Mooney-Rivlin"):
        store_to_trans_class(mat)
    for mat in _materials_of_type("Mooney-Rivlin"):
        store_to_mooney_class(mat)

    return material_properties


def update_materials(MaterialSection, LoadDataSection,  material_properties, AL):
    """Rewrite the Material section from ``material_properties``, adding prestrain where needed.

    For each ``<material>`` in ``MaterialSection``:

    * ``type="rigid body"``: recorded as-is. Its ``center_of_mass`` is set to the anatomical
      landmark mapped in ``centers_of_mass`` (FMB->FMO, TBB->TBO, PTB->PTO), or to ``0,0,0``
      when the body is not mapped or the landmark is missing or malformed.
    * otherwise the parameters come from ``material_properties[name]``. If that object has an
      ``initial_stretch`` other than ``None``/``'None'``, the material is rewritten as
      ``uncoupled prestrain elastic``: a ``k`` child, an ``elastic`` child holding the
      remaining parameters, and an ``in-situ stretch`` prestrain whose stretch follows a new
      load curve (1.0 at t=0, initial_stretch at t=1 and t=2). Otherwise the material's
      children are replaced by its parameters.

    Args:
        MaterialSection: ``<Material>`` element, modified in place.
        LoadDataSection: ``<LoadData>`` element that receives the prestrain load curves.
        material_properties: dict of material name -> parameter object.
        AL: dict of anatomical landmark name -> coordinates (indexed [0], [1], [2]).

    Returns:
        (materials_info, is_there_prestrain, load_curve_num): dict of name -> ``My_Material``,
        whether any prestrain material was written, and the last load-curve id used (0 if none).

    Raises:
        KeyError: a non-rigid material has no entry in ``material_properties``.
    """
    # rigid bodies whose centre of mass is placed at an anatomical landmark
    centers_of_mass = {'FMB': 'FMO', 'TBB': 'TBO', 'PTB': 'PTO'}

    materials_info = {}
    is_there_prestrain = False
    load_curve_num = 0

    for material in MaterialSection:

        try:
            tissue_name = material.attrib["name"]
            tissue_id = material.attrib["id"]
        except KeyError:  # comments (and materials without name/id) are skipped
            continue

        if material.attrib["type"] == 'rigid body':
            materials_info[tissue_name] = My_Material(tissue_name, tissue_id, 'rigid body')

            com_section = get_section('center_of_mass', material)
            try:
                com = AL[centers_of_mass[tissue_name]]
                com_section.text = ','.join(str(com[i]) for i in range(3))
            except (KeyError, IndexError, TypeError):
                # not FMB/TBB/PTB, landmark missing, or malformed coordinates
                com_section.text = '0,0,0'
            continue

        tissue_properties = material_properties[tissue_name]
        materials_info[tissue_name] = My_Material(tissue_name, tissue_id, tissue_properties.type)
        # parameter objects are written out attribute by attribute, in insertion order
        tissue_properties_dict = vars(tissue_properties)

        initial_stretch = getattr(tissue_properties, 'initial_stretch', None)
        if initial_stretch is None or initial_stretch == 'None':
            # no prestrain: replace the children with the material parameters
            material.attrib["type"] = tissue_properties.type
            add_subelements_from_dict(material, tissue_properties_dict, clear_out=True,
                                      exclude_these=['initial_stretch', 'type'])
            continue

        # prestrain material: must be defined through the prestrain element
        is_there_prestrain = True
        load_curve_num += 1
        material.attrib["type"] = "uncoupled prestrain elastic"

        # remove any existing subelements (iterate a snapshot while removing)
        for prop in list(material):
            material.remove(prop)

        # bulk modulus lives on the prestrain material itself
        myElement('k', material, text=str(tissue_properties.k))

        # the elastic element holds the remaining parameters
        elastic = myElement('elastic', material, type=tissue_properties.type)
        add_subelements_from_dict(elastic, tissue_properties_dict,
                                  exclude_these=['initial_stretch', 'type', 'k'])

        # stretch of 1 scaled by this material's load curve
        prestrain = myElement('prestrain', material, type='in-situ stretch')
        myElement('stretch', prestrain, text="1", lc=str(load_curve_num))
        myElement('isochoric', prestrain, text=str(1))

        # with prestrain the run time is 2: ramp to initial_stretch over 0..1, then hold
        add_load_curve(LoadDataSection, load_curve_num,
                       points=['0,1.0', '1,{}'.format(initial_stretch), '2,{}'.format(initial_stretch)],
                       name="{}_prestrain".format(tissue_name))

    return materials_info, is_there_prestrain, load_curve_num


def add_imaginary_bodies(MaterialSection, material_info, imaginary_body_data):
    """Append a rigid-body material for every imaginary body and register it.

    New material ids continue after the largest id in ``material_info`` (starting at 1
    when it is empty). Each body gets its ``center_of_mass`` at the given position and a
    density of 1.

    Args:
        MaterialSection: ``<Material>`` element receiving the new materials.
        material_info: dict of name -> ``My_Material``; updated in place.
        imaginary_body_data: dict of body name -> position (indexed [0], [1], [2]).

    Returns:
        The updated ``material_info`` dict.
    """
    highest_id = max([0] + [existing.id for existing in material_info.values()])

    for offset, (body_name, pos) in enumerate(imaginary_body_data.items(), start=1):
        body_id = highest_id + offset
        rigid = myElement('material', MaterialSection, id=str(body_id), type='rigid body', name=body_name)
        myElement('center_of_mass', rigid, text=','.join(str(pos[axis]) for axis in range(3)))
        myElement('density', rigid, text=str(1))
        material_info[body_name] = My_Material(body_name, str(body_id), 'rigid body')

    return material_info


def add_nodeset(GeometrySection, node_ids, nodeset_name):
    """Append ``<NodeSet name=nodeset_name>`` with one ``<node id=...>`` per id to the geometry section."""
    node_set = myElement('NodeSet', GeometrySection, name=nodeset_name)
    for node_id in map(str, node_ids):
        myElement('node', node_set, id=node_id)


def add_rigid_nodes(BoundarySection, material_info, nodeset_name, rigid_boundary_name):
    """Tie the named node set (the top of the QAT) to the 'QSO' imaginary rigid body.

    Raises KeyError if 'QSO' is not in ``material_info``.
    """
    qso_id = str(material_info['QSO'].id)
    myElement('rigid', BoundarySection, rb=qso_id, node_set=nodeset_name, name=rigid_boundary_name)


def add_outputs(OutputSection, prestrain):
    """Request rigid-body, rigid-connector and stretch outputs in the log and plot files.

    ``prestrain`` adds the 'prestrain stretch' plot variable (only meaningful when the
    model contains prestrain materials).
    """
    logfile_element = get_section("logfile", OutputSection)
    plotfile_element = get_section("plotfile", OutputSection)

    # logfile: rigid body kinematics/loads, then rigid joint forces and moments
    # (joint rotation/translation is recovered from the imaginary rigid bodies)
    log_requests = (
        ('rigid_body_data', 'center_of_mass', 'x;y;z'),
        ('rigid_body_data', 'rotation_quaternion', 'qx;qy;qz;qw'),
        ('rigid_body_data', 'Reaction_Forces', 'Fx;Fy;Fz'),
        ('rigid_body_data', 'Reaction_Torques', 'Mx;My;Mz'),
        ('rigid_connector_data', 'Rigid_Connector_Force', 'RCFx;RCFy;RCFz'),
        ('rigid_connector_data', 'Rigid_Connector_Moment', 'RCMx;RCMy;RCMz'),
    )
    for tag, output_name, fields in log_requests:
        myElement(tag, logfile_element, name=output_name, data=fields)

    # plotfile: no rigid connector variable was found in the FEBio manual
    plot_variables = ['rigid position', 'rigid angular position', 'rigid force', 'rigid torque']
    if prestrain:
        plot_variables.append('prestrain stretch')
    plot_variables.append('fiber stretch')
    for variable in plot_variables:
        myElement('var', plotfile_element, type=variable)


def add_prestrain_rule(ConstraintsSection):
    """Append a ``prestrain`` constraint (update=1, tolerance=0.03, min/max_iters=0)."""
    rule = myElement('constraint', ConstraintsSection, type='prestrain')
    add_subelements_from_dict(rule, {
        'update': 1,
        'tolerance': 0.03,
        'min_iters': 0,
        'max_iters': 0,
    })


def add_rigid_joints(ConstraintsSection, Joints, material_info, AL):
    """Append a ``rigid cylindrical joint`` constraint for every joint in ``Joints``.

    Args:
        ConstraintsSection: ``<Constraints>`` element receiving the constraints.
        Joints: dict of name -> ``Cylindrical_Joint``.
        material_info: dict of material name -> ``My_Material`` (used for body ids).
        AL: anatomical landmarks; AL['QAT'] is used for the quadriceps spring.

    If a material named "QAT" exists, a ``rigid spring`` (k=0.1) is also added between the
    'QSO' imaginary body and the femur 'FMB', with both insertions at AL['QAT'].
    """
    def _xyz(vector):
        # comma-separated "x,y,z" text, as used throughout this file
        return ','.join(str(vector[i]) for i in range(3))

    for J in Joints.values():
        constraint = myElement('constraint', ConstraintsSection, type='rigid cylindrical joint', name=J.name)
        parameters = {'tolerance': 0, 'gaptol': 0.01, 'angtol': 0.0001, 'force_penalty': 1e4,
                      'moment_penalty': 3e6,
                      'body_a': material_info[J.body_a].id, 'body_b': material_info[J.body_b].id,
                      'minaug': 0, 'maxaug': 0,
                      'joint_origin': _xyz(J.origin), 'joint_axis': _xyz(J.axis)}

        # an unprescribed DOF is written as prescribed_*=0 with a zero amount
        if J.prescribed_rotation is None:
            parameters['prescribed_rotation'] = 0
            parameters['rotation'] = 0
        else:
            parameters['prescribed_rotation'] = 1
            parameters['rotation'] = J.prescribed_rotation

        if J.prescribed_translation is None:
            parameters['prescribed_translation'] = 0
            parameters['translation'] = 0
        else:
            parameters['prescribed_translation'] = 1
            parameters['translation'] = J.prescribed_translation

        add_subelements_from_dict(constraint, parameters)

    # spring between the top of the quadriceps tendon (QSO body) and the femur
    if "QAT" in material_info:
        origin = _xyz(AL['QAT'])
        constraint = myElement('constraint', ConstraintsSection, type='rigid spring')
        parameters = {'body_a': material_info['QSO'].id, 'body_b': material_info['FMB'].id,
                      'insertion_a': origin, 'insertion_b': origin, 'k': 0.1}
        add_subelements_from_dict(constraint, parameters)


def prescribe_joint_motions(ConstraintsSection, rotations_dictionary=None, translations_dictionary=None):
    """Prescribe rotations/translations on joints already present in the Constraints section.

    Args:
        ConstraintsSection: ``<Constraints>`` element containing the joint constraints.
        rotations_dictionary: dict of joint name -> (rotation amount, load curve id).
        translations_dictionary: dict of joint name -> (translation amount, load curve id).

    For each entry, ``prescribed_rotation``/``prescribed_translation`` is set to 1 and the
    ``rotation``/``translation`` element gets the amount as text and the load curve as ``lc``.
    NOTE(review): get_section creates the constraint if no joint has that name, which yields
    an incomplete ``<constraint>``; callers should pass only existing joint names.
    """
    def _prescribe(joint_name, flag_tag, value_tag, amount, lc_id):
        my_constraint = get_section('constraint', ConstraintsSection,
                                    use_other_attribute='name', attribute_value=joint_name)
        get_section(flag_tag, my_constraint).text = str(1)
        value_element = get_section(value_tag, my_constraint)
        value_element.attrib["lc"] = str(lc_id)
        value_element.text = str(amount)

    for name, values in (rotations_dictionary or {}).items():
        _prescribe(name, 'prescribed_rotation', 'rotation', values[0], values[1])

    for name, values in (translations_dictionary or {}).items():
        _prescribe(name, 'prescribed_translation', 'translation', values[0], values[1])


def add_discrete_elements(DiscreteSection, LoadDataSection, GeometrySection, connector_info, load_curve_num):
    """Model the MPFL and LPFL as nonlinear discrete springs.

    Adds the "MPFL"/"LPFL" discrete sets to the geometry, two ``nonlinear spring`` discrete
    materials (ids 1 and 2) whose force follows a linear force-displacement load curve
    (0 for displacement <= 0, k at displacement 1), and assigns each material to its set.

    Args:
        DiscreteSection: ``<Discrete>`` element.
        LoadDataSection: ``<LoadData>`` element receiving the two load curves.
        GeometrySection: geometry element receiving the discrete sets.
        connector_info: dict with 'MPFL_01', 'MPFL_02', 'LPFL_01', 'LPFL_02' ``Linear_Spring``s.
        load_curve_num: id of the last load curve already defined.

    Returns:
        The id of the last load curve defined; unchanged when the connectors are missing.
    """
    try:
        MPFL_01 = connector_info['MPFL_01']
        MPFL_02 = connector_info['MPFL_02']
        LPFL_01 = connector_info['LPFL_01']
        LPFL_02 = connector_info['LPFL_02']
    except KeyError:
        # no MPFL/LPFL in this model: keep the caller's load-curve counter intact
        # (returning None here would break the numbering of later load curves)
        return load_curve_num

    def _add_delem(discrete_set, spring):
        myElement('delem', discrete_set, text=str(spring.insertion_a) + ',' + str(spring.insertion_b))

    # discrete sets in the geometry section
    mpfl_set = myElement('DiscreteSet', GeometrySection, name="MPFL")
    _add_delem(mpfl_set, MPFL_01)
    _add_delem(mpfl_set, MPFL_02)
    lpfl_set = myElement('DiscreteSet', GeometrySection, name="LPFL")
    _add_delem(lpfl_set, LPFL_01)
    _add_delem(lpfl_set, LPFL_02)

    # MPFL material and its force-displacement curve
    dmat1 = myElement('discrete_material', DiscreteSection, id="1", type="nonlinear spring")
    load_curve_num += 1
    myElement('force', dmat1, lc=str(load_curve_num), text=str(1.0))
    add_load_curve(LoadDataSection, load_curve_num=load_curve_num,
                   points=['-1,0', '0,0', '1,{}'.format(MPFL_01.k)], name="MPFL_force_displacement")

    # LPFL material and its force-displacement curve
    load_curve_num += 1
    dmat2 = myElement('discrete_material', DiscreteSection, id="2", type="nonlinear spring")
    myElement('force', dmat2, lc=str(load_curve_num), text=str(1.0))
    add_load_curve(LoadDataSection, load_curve_num=load_curve_num,
                   points=['-1,0', '0,0', '1,{}'.format(LPFL_01.k)], name="LPFL_force_displacement")

    # assign the discrete materials to the discrete sets
    myElement('discrete', DiscreteSection, discrete_set="MPFL", dmat="1")
    myElement('discrete', DiscreteSection, discrete_set="LPFL", dmat="2")

    return load_curve_num


def add_spring_tie(DiscreteSection, GeometrySection, node_pairs):
    """Tie node pairs (MCL to medial meniscus) together with linear springs.

    Args:
        DiscreteSection: ``<Discrete>`` element receiving the material and assignment.
        GeometrySection: geometry element receiving the "MCL_MNS-M_tie" discrete set.
        node_pairs: dict of node id -> node id; each pair becomes one discrete element.

    The new ``linear spring`` material (E=1000) gets the next free discrete-material id.
    """
    discrete_set = myElement('DiscreteSet', GeometrySection, name="MCL_MNS-M_tie")
    for n1, n2 in node_pairs.items():
        myElement('delem', discrete_set, text=str(n1) + ',' + str(n2))

    # next free id; scanning the children directly works with zero, one or many existing
    # discrete materials (indexing get_section()[-1] only worked with two or more)
    existing_ids = [int(dmat.get('id')) for dmat in DiscreteSection
                    if dmat.tag == 'discrete_material' and dmat.get('id') is not None]
    new_id = str(max(existing_ids, default=0) + 1)

    dmat = myElement('discrete_material', DiscreteSection, id=new_id, type="linear spring")
    myElement('E', dmat, text=str(1000))  # spring constant

    myElement('discrete', DiscreteSection, discrete_set="MCL_MNS-M_tie", dmat=new_id)


def fix_rigid_bodies(BoundarySection, material_info, list_of_bodies, load_curve_num):
    """Hold all six DOF of each listed rigid body at zero.

    Uses ``prescribed`` (value 0.0, scaled by ``load_curve_num``) rather than ``fixed`` so
    the reaction loads are available for post-processing. Bodies missing from
    ``material_info`` are skipped.
    """
    for body in list_of_bodies:
        try:
            body_id = material_info[body].id
        except KeyError:  # body not present in this model
            continue

        rigid_body = myElement("rigid_body", BoundarySection, mat=str(body_id))
        for dof in ('x', 'y', 'z', 'Rx', 'Ry', 'Rz'):
            myElement('prescribed', rigid_body, bc=dof, lc=str(load_curve_num), text="0.0")


def update_control_parameters(ControlSection, control_params, time_stepper_params):
    """Replace the Control section's contents.

    The section is cleared and filled from ``control_params``, then a static ``analysis``
    element and a ``time_stepper`` element holding ``time_stepper_params`` are appended.
    """
    add_subelements_from_dict(ControlSection, control_params, clear_out=True)
    myElement('analysis', ControlSection, type='static')
    add_subelements_from_dict(myElement('time_stepper', ControlSection), time_stepper_params)


def add_fiber_orienations(MeshDataSection, MaterialSection, fiber_orientations):
    """Write ligament fiber directions into the model.

    Args:
        MeshDataSection: ``<MeshData>`` element, used for per-element directions.
        MaterialSection: ``<Material>`` element, used for a single direction per ligament.
        fiber_orientations: dict of ligament/material name -> orientation array (uses
            ``.shape``). A 2-D array holds one vector per element; a 1-D array is one vector.

    Per-element vectors go into an ``ElementData`` block (lid starts at 1). A single vector
    is added as a ``fiber`` element to the material, or to its ``elastic`` child for
    prestrain materials, since that child carries the elastic parameters.
    """
    # update_materials writes "uncoupled prestrain elastic"; the plain type is kept as well
    prestrain_types = ('prestrain elastic', 'uncoupled prestrain elastic')

    def _xyz(vector):
        return ','.join(str(vector[i]) for i in range(3))

    for ligament_name, orientation in fiber_orientations.items():

        if len(orientation.shape) > 1:  # one fiber direction per element
            ElementData = myElement('ElementData', MeshDataSection, var='fiber', elem_set=ligament_name)
            # NOTE(review): no FEBio example of this layout was found; assumed format
            for local_id, vec in enumerate(orientation, start=1):
                myElement('elem', ElementData, lid=str(local_id), text=_xyz(vec))

        else:  # one fiber direction for all elements in the ligament
            this_material = get_section('material', MaterialSection, use_other_attribute='name',
                                        attribute_value=ligament_name)
            if this_material.attrib['type'] in prestrain_types:
                target = get_section('elastic', this_material)
            else:
                target = this_material
            myElement('fiber', target, type='vector', text=_xyz(orientation))


def change_contact_type(ContactSection, old_type, new_type):
    """Set ``type="new_type"`` on every contact whose type is ``old_type``.

    Comments, processing instructions and contacts without a ``type`` attribute are
    ignored (indexing ``attrib["type"]`` on a comment raised KeyError).
    """
    for contact in ContactSection:
        if not isinstance(contact.tag, str):  # comment / processing instruction
            continue
        if contact.get("type") == old_type:
            contact.attrib["type"] = new_type


def update_section_parameters(ParentSection, new_params_dict):
    """Replace the children of ``ParentSection`` with one element per parameter.

    ``ParentSection`` may also be a list of elements (as returned by get_section when
    several sections match), in which case each one is updated.
    """
    sections = ParentSection if isinstance(ParentSection, list) else [ParentSection]
    for section in sections:
        add_subelements_from_dict(section, new_params_dict, clear_out=True)


def add_load_curve(LoadDataSection, load_curve_num, points , load_curve_type="linear", name=None):
    """Append a ``loadcurve`` element with one ``point`` child per entry of ``points``.

    Args:
        LoadDataSection: ``<LoadData>`` element.
        load_curve_num: id of the new load curve.
        points: list of "t,value" strings, e.g. ``['0,0', '1,1']``.
        load_curve_type: interpolation type attribute (default "linear").
        name: optional ``name`` attribute.
    """
    curve_attributes = {'id': str(load_curve_num), 'type': load_curve_type}
    if name is not None:
        curve_attributes['name'] = name
    curve = myElement('loadcurve', LoadDataSection, **curve_attributes)

    for point_text in points:
        myElement('point', curve, text=point_text)


def check_master_slave(contact_section, geometry_section):
    """Flip surface pairs whose master/slave order is not in the known-good list.

    Surface-pair names have the form ``<first>_To_<second>``. For every contact whose pair
    is not in ``correct_orders``, the pair's ``master``/``slave`` children in the geometry
    swap tags, the pair is renamed ``<second>_To_<first>``, and the contact is pointed at
    the new name. Pairs that are missing from, or duplicated in, the geometry are reported
    and left unchanged.

    Args:
        contact_section: element whose children are the ``<contact>`` definitions.
        geometry_section: ``<Geometry>`` element holding the ``<SurfacePair>`` elements.
    """
    # master/slave orders known to work; any other pair is flipped.
    # NOTE(review): 'QAT__To_FMC' has a double underscore unlike every other entry; if the
    # geometry names the pair 'QAT_To_FMC' it will be flipped - confirm which is intended.
    correct_orders = {'QAT_To_FMB', 'QAT__To_FMC', 'TBC-L_To_FMC', 'TBC-L_To_MNS-L', 'PCL_To_TBB',
                      'PCL_To_FMB', 'PCL_To_MNS-M', 'PCL_To_ACL', 'PTC_To_FMC', 'ACL_To_TBB',
                      'ACL_To_FMB', 'MCL_To_TBB', 'MCL_To_FMB', 'MNS-L_To_FMC', 'MNS-M_To_TBC-M',
                      'MNS-M_To_FMC', 'LCL_To_FMB', 'TBC-M_To_FMC', 'MCL_To_MNS-M'}
    swapped_tag = {'master': 'slave', 'slave': 'master'}

    for contact in contact_section:
        # comments have no attributes, so .get returns None for them
        pair_name = contact.get('surface_pair')
        if pair_name is None or pair_name in correct_orders:
            continue

        first_part = pair_name.split('_')[0]
        second_part = pair_name.split('_')[-1]

        surface_pair = get_section('SurfacePair', geometry_section, use_other_attribute='name',
                                   attribute_value=pair_name, create_new=False)
        if surface_pair is None or isinstance(surface_pair, list):
            # get_section would otherwise create an empty pair, or return several
            print('\n Surface pair {} not found exactly once in the geometry; '
                  'master/slave order not changed'.format(pair_name))
            continue

        # swap master and slave (other children, e.g. comments, are left alone)
        for surface in surface_pair:
            if surface.tag in swapped_tag:
                surface.tag = swapped_tag[surface.tag]

        # rename the pair in the geometry and reference the new name in the contact
        new_pair_name = second_part + '_To_' + first_part
        surface_pair.attrib['name'] = new_pair_name
        contact.attrib['surface_pair'] = new_pair_name



def get_section(section_tag, parent, use_other_attribute = None, attribute_value = None, clear_out = False, create_new= True):
    """Find child element(s) of ``parent`` by tag, optionally filtered by an attribute.

    By default children are matched on their tag. When ``use_other_attribute`` is given,
    only children with tag ``section_tag`` whose ``use_other_attribute`` attribute equals
    ``attribute_value`` match (a child lacking the attribute counts as having ``None``).

    Returns:
        * the matching element when exactly one child matches;
        * a list of the matching elements when several match;
        * when nothing matches: a new element appended to ``parent`` (carrying
          ``use_other_attribute=attribute_value`` for attribute searches) if
          ``create_new`` is true, otherwise ``None``.

    If ``clear_out`` is true, every matched existing element is emptied of its children.
    Unlike lxml's ``find``/``findall``, this creates the element when it is missing.
    """
    def _clear(element):
        # iterate a snapshot: removing children from the live element while iterating
        # it is not safe (xml.etree skips every other child)
        for sub_section in list(element):
            element.remove(sub_section)

    candidates = [child for child in parent if child.tag == section_tag]
    if use_other_attribute is None:
        matches = candidates
    else:
        matches = [child for child in candidates
                   if child.attrib.get(use_other_attribute) == attribute_value]

    if not matches:  # create it if it doesn't exist
        if not create_new:
            return None
        if use_other_attribute is None:
            return myElement(section_tag, parent)
        return myElement(section_tag, parent, **{use_other_attribute: attribute_value})

    if clear_out:
        for match in matches:
            _clear(match)

    return matches[0] if len(matches) == 1 else matches



def customize_file(febfile, rigid_connector_info, joint_data, imaginary_body_data, qat_nodeset, fiber_orientations,
                   material_properties, anatomical_landmarks, mcl_mns_nodepairs):
    """Make the model-specific changes to the FEBio xml tree (and its geometry tree).

    Steps performed, in order:
      * clear the LoadData section and rebuild the load curves;
      * update the materials and add the imaginary bodies;
      * add the fiber orientations to MeshData;
      * add the MPFL/LPFL discrete elements when ``'MPFL_01'`` is in ``rigid_connector_info``;
      * attach the QAT nodeset to the quadriceps slider body when ``qat_nodeset`` is not None;
      * remove the ``rigid_body`` fixes the template puts in the global Boundary section;
      * convert facet-to-facet sliding contacts to sliding-elastic and set their parameters;
      * replace an ``MCL_To_MNS-M`` contact tie, if present, with linear springs;
      * move Contact into the ``LoadingStep`` step and delete the ``Initialize`` step;
      * configure the loading step (load curve, control parameters, joints, fixed TBB/FBB);
      * put all sections in the order FEBio expects.

    Returns:
        tuple: ``(feb_tree, geo_tree)``, the modified lxml trees of the FEBio file and of
        its geometry file (the same tree when the geometry is inline in the FEBio file).
    """

    feb_tree = et.parse(febfile)
    febio_spec_root = feb_tree.getroot()
    geo_tree, geofile = find_geometry_tree(febfile)  # find the geometry file that it points to
    if geofile is None:
        # The geometry is inline: edit it inside feb_tree. Otherwise the changes would be made
        # in a separately parsed copy of the same file and never written out.
        geo_tree = feb_tree
    febio_spec_root_geo = geo_tree.getroot()

    # get the load data section, and clear whatever loadcurves were there before
    LoadData_section = get_section('LoadData', febio_spec_root, clear_out=True)

    # update the material section, and add the imaginary bodies
    Material_section = get_section('Material', febio_spec_root)
    material_info, prestrain, load_curve_num = update_materials(Material_section, LoadData_section, material_properties, anatomical_landmarks)
    material_info = add_imaginary_bodies(Material_section, material_info, imaginary_body_data)

    geometry_section = get_section('Geometry', febio_spec_root_geo)

    # add the fiber orientations
    MeshData_section = get_section('MeshData', febio_spec_root)
    add_fiber_orienations(MeshData_section, Material_section, fiber_orientations)

    # add the Discrete elements to define the MPFL, LPFL
    if 'MPFL_01' in rigid_connector_info:
        Discrete_section = get_section('Discrete', febio_spec_root)
        load_curve_num = add_discrete_elements(Discrete_section, LoadData_section, geometry_section, rigid_connector_info, load_curve_num)

    Boundary_section = get_section('Boundary', febio_spec_root)

    # tie the quadriceps tendon nodes to the quadriceps slider imaginary body (only if the QAT is in the model)
    if qat_nodeset is not None:
        add_nodeset(geometry_section, qat_nodeset, "QAT_@_QSO_TiesNodes")
        add_rigid_nodes(Boundary_section, material_info, "QAT_@_QSO_TiesNodes", "QSO_With_QAT")

    # Remove the fixed bodies from the boundary section (template model fixes every rigid body created)
    rigid_bodies = get_section('rigid_body', Boundary_section)
    try:
        Boundary_section.remove(rigid_bodies)  # in case there's only one rigid body in the model
    except TypeError:
        # several rigid bodies: get_section returned a list of elements
        for rb in rigid_bodies:
            Boundary_section.remove(rb)

    # change the contact definition to the latest febio version
    Contact_section = get_section('Contact', febio_spec_root)
    contact_type = "sliding-elastic"
    change_contact_type(Contact_section, old_type="facet-to-facet sliding", new_type=contact_type)
    sliding_contacts = get_section('contact', Contact_section, use_other_attribute="type", attribute_value=contact_type)
    contact_params = {'laugon':0, 'tolerance':0,'gaptol':1e-2,'penalty':0.1,'two_pass':1,'auto_penalty':1,
                      'fric_coeff':0,'search_tol':0.01,'search_radius':0.005,'minaug':0,'maxaug':10,'seg_up':0}
    update_section_parameters(sliding_contacts, contact_params)

    # check the contacts to make sure that the rigid bodies are the slaves
    check_master_slave(Contact_section, geometry_section)

    # replace the mcl-mns-m tie (if there) with linear springs
    mcl_mns_tie = get_section('contact', Contact_section, use_other_attribute='surface_pair',
                              attribute_value="MCL_To_MNS-M", create_new=False)
    if mcl_mns_tie is not None:  # None means the model has no such tie
        # get_section returns a list when several contacts share the surface pair
        old_ties = mcl_mns_tie if isinstance(mcl_mns_tie, list) else [mcl_mns_tie]
        for old_tie in old_ties:
            Contact_section.remove(old_tie)

        Discrete_section = get_section('Discrete', febio_spec_root)
        add_spring_tie(Discrete_section, geometry_section, mcl_mns_nodepairs)

    # add the outputs required for the logfile and plotfile
    Output_section = get_section('Output', febio_spec_root)
    add_outputs(Output_section, prestrain)

    # find the the prestrain and loading steps
    InitializeStep = get_section("Step", febio_spec_root, use_other_attribute='name', attribute_value="Initialize")
    LoadingStep = get_section("Step", febio_spec_root, use_other_attribute='name', attribute_value="LoadingStep")

    # move the contact section to the loading step - this is important for prestrain cases,
    # contact seems to interfere with febio's prestrain solver
    febio_spec_root.remove(Contact_section)
    LoadingStep.append(Contact_section)

    # remove the initialize step - we switched to a one step model
    febio_spec_root.remove(InitializeStep)

    # Loading step settings
    LoadingStepConstraints = get_section("Constraints", LoadingStep)

    # create the loading step load curve
    load_curve_num += 1

    if prestrain:
        points = ['0,0','1,0','2,1'] # all the load curves go from time 0 to 2. prestrain is applied from 0 to 1, loading from 1 to 2
        time_steps = 40
        step_size = 0.05
        # add the prestrain constraint
        add_prestrain_rule(LoadingStepConstraints)
    else:
        time_steps = 20
        step_size = 0.05
        points = ['0,0', '1,1'] # the loading is from 0 to 1

    # add the load curve for the loading step
    add_load_curve(LoadData_section, load_curve_num, points = points, name="Loading_Curve")

    # update the loading control parameters
    LoadingStepControl = get_section("Control", LoadingStep)
    control_params_load = {'time_steps': time_steps, 'step_size': step_size, 'max_refs': 25, 'max_ups': 0, 'diverge_reform': 1,
     'reform_each_time_step': 1, 'dtol': 0.01, 'etol': 0.1, 'rtol': 0, 'lstol': 0.9, 'min_residual':0.001,'qnmethod': 1,
     'symmetric_stiffness': 0, 'plot_level': 'PLOT_MAJOR_ITRS'}
    time_stepper_params_load = {'dtmin':0.001, 'dtmax':0.05, 'max_retries':5, 'opt_iter':50, 'aggressiveness':1}
    update_control_parameters(LoadingStepControl, control_params_load, time_stepper_params_load)

    # assuming there are joints in the model, add the joint constraint
    if joint_data:
        add_rigid_joints(LoadingStepConstraints, joint_data, material_info, anatomical_landmarks)
        if "Extension_Flexion" in joint_data: # if there is a flexion extension joint in the model
            prescribed_rotations = {"Extension_Flexion": (-1.57, load_curve_num)}
            prescribe_joint_motions(LoadingStepConstraints, rotations_dictionary=prescribed_rotations)

    # fix the tibia and fibula in the loading step
    LoadingStepBoundary = get_section("Boundary", LoadingStep)
    bodies_to_fix = ["TBB", "FBB"]
    fix_rigid_bodies(LoadingStepBoundary, material_info, bodies_to_fix, load_curve_num)

    # make sure that all the sections are in the correct order in the febio file
    # (the Initialize step was removed above, so only the loading step needs ordering)
    OrderFebioSections(febio_spec_root)
    OrderFebioSections(LoadingStep)
    OrderGeometrySections(geometry_section)

    return feb_tree, geo_tree


def OrderGeometrySections(geom_root):
    """Reorder the children of a Geometry element into the sequence FEBio expects.

    Children whose tag is in the expected sequence are packed at the front, grouped
    by tag in sequence order; within one tag the original relative order is kept.
    All other children (comments, unrecognised tags) keep their relative order and
    end up after the recognised sections. ``geom_root`` is modified in place and
    nothing is returned.
    """
    expected_sequence = ('Nodes', 'Elements', 'NodeSet', 'Edge', 'Surface',
                         'DiscreteSet', 'ElementSet', 'SurfacePair')

    # mirror of the children's tags, kept in sync with every move below
    tags = [child.tag for child in geom_root]
    next_slot = 0  # every position before this already holds a correctly placed section

    for wanted in expected_sequence:
        positions = [pos for pos, tag in enumerate(tags) if tag == wanted]
        for pos in positions:
            if pos != next_slot:
                # pull the element forward into the next free slot; elements after `pos`
                # keep their indices, so the remaining `positions` stay valid
                element = geom_root[pos]
                geom_root.remove(element)
                geom_root.insert(next_slot, element)
                tags.insert(next_slot, tags.pop(pos))
            next_slot += 1


def OrderFebioSections(root_element):
    """Put the child sections of ``root_element`` in the order given by the FEBio manual.

    Works in place (e.g. on the ``febio_spec`` root or on a ``Step`` element) and
    returns nothing. Only children whose tag is a string listed in ``correct_order``
    are moved. They are packed at the front in manual order, and several children
    with the same tag keep their original relative order. Comments (whose tag is not
    a string) and unlisted tags are not moved themselves, so they end up after the
    moved sections. ``Step`` children are never moved explicitly; because ``Step``
    is last in the list, they follow every other recognised section.
    """

    correct_order = ['Module', 'Control', 'Globals', 'Material', 'Geometry', 'MeshData', 'Initial', 'Boundary', 'Loads',
                     'Contact', 'Constraints', 'Discrete', 'LoadData', 'Output', 'Step']
    rank = {name: position for position, name in enumerate(correct_order)}

    original_order = [c.tag for c in root_element]

    # Target order built from recognised string tags only. Filtering with a
    # comprehension cannot skip adjacent comments (as removing items while iterating
    # would), and leaving out unlisted tags avoids sorting None keys against ints,
    # which raises TypeError in Python 3.
    new_order = sorted((tag for tag in original_order if isinstance(tag, str) and tag in rank),
                       key=rank.__getitem__)

    # move the sections that need to be moved
    for counter, section_name in enumerate(new_order):
        if section_name == 'Step':
            continue
        # Search from `counter`: every earlier position already holds a placed section,
        # so this finds the first not-yet-placed element with this tag (duplicates keep their order).
        orig_idx = original_order.index(section_name, counter)
        if orig_idx != counter:
            xml_element = root_element[orig_idx]  # get the xml element
            root_element.remove(xml_element)  # remove it
            root_element.insert(counter, xml_element)  # add it where it should be
            original_order.insert(counter, original_order.pop(orig_idx))  # keep the tag mirror in sync


def AddLandmarksToXml(ModelProperties_xml_file, all_landmarks):
    """Store the calculated landmarks in the ModelProperties xml file (rewritten in place).

    Appends a timestamp comment to the ``Landmarks`` section, then, for every
    ``name -> coords`` item of ``all_landmarks``, sets the text of the ``Landmarks``
    child named ``name`` (obtained with ``get_section``, which the original comment
    relies on to add it when missing). Values that can be indexed with 0, 1 and 2
    are written as ``"x,y,z"``; anything else is written with ``str()``.
    """

    ModelProperties_tree = et.parse(ModelProperties_xml_file)
    ModelProperties = ModelProperties_tree.getroot()
    Landmarks = get_section("Landmarks", ModelProperties)

    # add the new comment
    Landmarks.append(et.Comment("These landmarks were last calculated on {}".format(datetime.datetime.now())))

    # if the landmark already is in the file, overwrite it, if not add it.
    for name, coords in all_landmarks.items():
        this_landmark = get_section(name, Landmarks)
        try:
            text = str(coords[0]) + ',' + str(coords[1]) + ',' + str(coords[2])
        except (TypeError, IndexError, KeyError):
            # not an indexable 3-vector (e.g. a scalar): store the value as-is
            text = str(coords)
        this_landmark.text = text

    write_file(ModelProperties_tree, ModelProperties_xml_file)


def write_file(xml_tree, new_filename):
    """Write ``xml_tree`` to ``new_filename`` with an XML declaration and indentation.

    The file is written once, re-parsed with a parser that drops blank text,
    and written again. Elements added to or moved within a parsed tree keep
    their old whitespace, which stops lxml's pretty printer from re-indenting
    them (see the lxml FAQ on ``pretty_print``); the second pass normalises this.
    """
    save_kwargs = {'xml_declaration': True, 'pretty_print': True}

    xml_tree.write(new_filename, **save_kwargs)

    blank_stripping_parser = et.XMLParser(remove_blank_text=True)
    cleaned_tree = et.parse(new_filename, blank_stripping_parser)
    cleaned_tree.write(new_filename, **save_kwargs)


def find_geometry_tree(febfile):
    """Find the tree that holds the Geometry section of ``febfile``.

    If the ``Geometry`` element has a ``from`` attribute, the referenced file is
    parsed. A bare file name (no directory part) is resolved relative to the
    directory of ``febfile``; a name that includes a directory is used as given.
    Without a ``from`` attribute the geometry is inline, and the parsed ``febfile``
    tree itself is returned.

    Returns:
        tuple: ``(geo_tree, geofile)``, the lxml tree containing the geometry and the
        path of the geometry file, or ``None`` for the path when the geometry is inline.
    """

    feb_tree = et.parse(febfile)
    febio_spec_root = feb_tree.getroot()

    geo_sect_in_feb = get_section('Geometry', febio_spec_root)

    # a missing 'from' attribute means the geometry is written inside the febio file
    geofile = geo_sect_in_feb.attrib.get('from')
    if geofile is None:
        return feb_tree, None

    if not os.path.dirname(geofile):
        # no pathway is given: assume the geometry file sits next to the febio file
        geofile = os.path.join(os.path.dirname(febfile), geofile)

    return et.parse(geofile), geofile


def MakeFeb(febfile, ModelProperties_xml_file):
    """Create the customized FEBio model for ``febfile``.

    Computes the anatomical landmarks, QAT nodeset, fiber orientations, knee side and
    MCL-MNS node pairs with ``AnatomicalLandmarks_p3.DoCalculations``, collects the
    joint/connector/material information, saves the landmarks into
    ``ModelProperties_xml_file``, customizes the model and writes
    ``<febfile without extension>_custom.feb``. When the geometry lives in a separate
    file, ``<geometry file without extension>_custom.feb`` is written as well and the
    Geometry ``from`` attribute of the new febio file points at it (file name only).
    """

    # calculate the other anatomical landmarks, nodeset for quadriceps tendon, fiber orientations
    import AnatomicalLandmarks_p3
    my_anatomical_landmarks, qat_nodeset, fiber_orientations, right_or_left, mcl_mns_nodepairs = AnatomicalLandmarks_p3.DoCalculations(ModelProperties_xml_file, febfile)

    # collect all the info needed to add to the febfile
    my_joint_info, my_imaginary_body_info = get_joint_info(my_anatomical_landmarks, right_or_left)
    my_connector_info = get_connector_info(my_anatomical_landmarks)
    my_material_properties = get_material_properties(ModelProperties_xml_file)

    print('\n Adding calculated landmarks to Model Properties xml')
    AddLandmarksToXml(ModelProperties_xml_file, my_anatomical_landmarks)

    print('\n Starting the Febio file Customization')
    feb_xml_tree, geo_xml_tree = customize_file(febfile, my_connector_info, my_joint_info, my_imaginary_body_info, qat_nodeset, fiber_orientations, my_material_properties, my_anatomical_landmarks, mcl_mns_nodepairs)

    _, geofile = find_geometry_tree(febfile)

    # new filenames: splitext strips only the final extension, so dots in directory
    # names (e.g. '../model/FeBio.feb') do not truncate the path
    new_febfilename = os.path.splitext(febfile)[0] + '_custom.feb'

    if geofile is None:
        # geometry is inline in the febio file: there is no separate geometry file to write
        print('\n Writing the Customized Files \n')
        write_file(feb_xml_tree, new_febfilename)
        return

    new_geofilename = os.path.splitext(geofile)[0] + '_custom.feb'

    # update the geometry tag in the febfile
    geometry_section = get_section('Geometry', feb_xml_tree.getroot())
    geometry_section.attrib['from'] = os.path.basename(new_geofilename)  # remove path from filename

    print('\n Writing the Customized Files \n')
    # write the customized febio and geometry files
    write_file(feb_xml_tree, new_febfilename)
    write_file(geo_xml_tree, new_geofilename)


def test():
    """Developer smoke test: run MakeFeb on a local oks003 calibration model (hard-coded paths)."""
    model_properties_path = '/home/schwara2/Documents/Open_Knees/knee_hub/oks003/calibration/Registration/model/ModelProperties.xml'
    feb_path = '/home/schwara2/Documents/Open_Knees/knee_hub/oks003/calibration/Registration/model/Febio/FeBio.feb'
    MakeFeb(feb_path, model_properties_path)

if __name__ == '__main__':
    # command line: <febio file> <ModelProperties xml file>
    # (call test() instead to run on the developer's local oks003 model)
    if len(sys.argv) != 3:
        sys.exit('Usage: python {} <febio_file.feb> <ModelProperties.xml>'.format(os.path.basename(sys.argv[0])))
    MakeFeb(*sys.argv[1:])

