from lxml import etree as et
import sys
import os

def comment_element(element, parent):
    """Swap ``element`` inside ``parent`` for an XML comment holding its serialized markup."""
    serialized = et.tostring(element)
    placeholder = et.Comment(serialized)
    parent.replace(element, placeholder)

def remove_boundaries(Boundary_section, parts_to_keep):
    """Comment out any boundaries that are not between two parts we are keeping.

    The ``name`` attribute of each child is split on ``_`` and the ``With``
    token is dropped; the remaining tokens are treated as part names. If any
    of them is missing from ``parts_to_keep`` the child is replaced by a
    comment (see ``comment_element``).

    Children without a ``name`` attribute (e.g. comments) and children whose
    name does not contain the ``With`` token are left untouched.

    :param Boundary_section: parent element whose direct children are checked
    :param parts_to_keep: collection of part names that stay in the model
    """
    separator = 'With'

    # walk a snapshot: comment_element replaces children of the section as we go
    for boundary in list(Boundary_section):
        try:
            boundary_name = boundary.attrib["name"]
        except KeyError:  # no name attribute, e.g. it is a comment
            continue

        parts = boundary_name.split('_')
        if separator not in parts:
            # name does not follow the <part>_With_<part> pattern; leave it alone
            # instead of failing with ValueError on parts.remove()
            continue
        parts.remove(separator)

        # comment it out unless every named part is one we are keeping
        if not set(parts).issubset(parts_to_keep):
            comment_element(boundary, Boundary_section)


def remove_contacts(Contact_section, parts_to_keep):
    """Comment out any contacts that are not between two parts we are keeping.

    The ``surface_pair`` attribute of each child is split on ``_`` and the
    ``To`` token is dropped; the remaining tokens are treated as part names.
    If any of them is missing from ``parts_to_keep`` the child is replaced by
    a comment (see ``comment_element``).

    Children without a ``surface_pair`` attribute (e.g. comments) and children
    whose pair name does not contain the ``To`` token are left untouched.

    :param Contact_section: parent element whose direct children are checked
    :param parts_to_keep: collection of part names that stay in the model
    """
    separator = 'To'

    # walk a snapshot: comment_element replaces children of the section as we go
    for contact in list(Contact_section):
        try:
            contact_name = contact.attrib["surface_pair"]
        except KeyError:  # no surface_pair attribute, e.g. it is a comment
            continue

        parts = contact_name.split('_')
        if separator not in parts:
            # pair name does not follow the <part>_To_<part> pattern; leave it alone
            # instead of failing with ValueError on parts.remove()
            continue
        parts.remove(separator)

        # comment it out unless every named part is one we are keeping
        if not set(parts).issubset(parts_to_keep):
            comment_element(contact, Contact_section)


def change_unwanted_materials(Material_section, parts_to_keep):
    """Turn every material that is not one we are keeping into a rigid body.

    For each material whose ``name`` is not in ``parts_to_keep``: if its
    ``type`` is not already ``'rigid body'``, the type is switched, all of its
    child properties are removed and a single ``density`` child with the text
    ``1e-9`` is added.

    Comments and other non-element children of the section are skipped.

    :param Material_section: parent element holding the material elements
    :param parts_to_keep: collection of material names that stay untouched
    :return: dict mapping the name of every unwanted material to its ``id``
    """
    unwanted_material_ids = {}

    for material in Material_section:
        # comments / processing instructions have a non-string tag and carry no attributes
        if not isinstance(material.tag, str):
            continue

        material_name = material.attrib['name']
        material_id = material.attrib['id']
        material_type = material.attrib['type']

        if material_name in parts_to_keep:  # a material we are keeping
            continue

        # if it's not already a rigid body, make it one
        if material_type != 'rigid body':
            material.attrib['type'] = 'rigid body'
            # remove from a snapshot so we never mutate the element we are iterating
            for prop in list(material):
                material.remove(prop)
            myElement('density', material, text='1e-9')

        # store all the unwanted material ids
        unwanted_material_ids[material_name] = material_id

    return unwanted_material_ids


def fix_bodies(Boundary_section, mat_ids):
    """Add a fully fixed ``rigid_body`` entry to the boundary section for each material.

    ``mat_ids`` maps material names to material ids. Every name except the
    imaginary rigid bodies of the cylindrical joint gets a ``rigid_body``
    child (``mat`` set to the id) holding one ``fixed`` child per degree of
    freedom.
    """
    # imaginary links of the cylindrical joint are never fixed here
    skipped_names = ("TFFO", "TFTO", "PFFO", "PFPO", "QSO")
    degrees_of_freedom = ('x', 'y', 'z', 'Rx', 'Ry', 'Rz')

    for body_name, body_id in mat_ids.items():
        if body_name in skipped_names:
            continue
        rigid_body = myElement('rigid_body', Boundary_section, mat=str(body_id))
        for dof in degrees_of_freedom:
            myElement('fixed', rigid_body, bc=dof)


def find_mat_id(Material_section,mat_name):
    """Return the ``id`` attribute of the ``material`` child whose ``name`` is ``mat_name``.

    The lookup goes through ``get_section`` with the ``name`` attribute as the search key.
    """
    return get_section(
        'material',
        Material_section,
        use_other_attribute="name",
        attribute_value=mat_name,
    ).attrib['id']

def _parse_parts_to_keep(parts_to_keep):
    """Normalize the parts specification into a list of clean part names.

    Accepts either the command-line form (one string such as ``"[PTB,PTL]"``)
    or an actual list/tuple of names. Surrounding brackets, whitespace and
    quote characters around each name are stripped; empty entries are dropped.
    """
    if isinstance(parts_to_keep, str):
        text = parts_to_keep.strip()
        # only drop the outer characters when they really are a bracket pair
        if len(text) >= 2 and text[0] + text[-1] in ('[]', '()', '{}'):
            text = text[1:-1]
        tokens = text.split(',')
    else:
        tokens = list(parts_to_keep)

    cleaned = [str(token).strip().strip('\'"') for token in tokens]
    return [token for token in cleaned if token]


def ReduceModel(febfile, parts_to_keep):
    """Remove anything in the febio file that does not pertain to these parts, and make changes as necessary.

    The geometry is left in place; materials that are not kept are turned into
    fixed rigid bodies and boundaries/contacts involving them are commented out.
    The result is written next to ``febfile`` as ``<part>_<part>_....feb``.

    :param febfile: path of the FEBio file to reduce
    :param parts_to_keep: part names to keep, either as the string that arrives
        through ``sys.argv`` (e.g. ``"[PTB,PTL,TBB]"``) or as a list of names
    """
    # sys.argv hands the list over as a single string; turn it back into a list
    parts_to_keep = _parse_parts_to_keep(parts_to_keep)

    parser = et.XMLParser(remove_blank_text=True)
    feb_tree = et.parse(febfile, parser)
    febio_spec_root = feb_tree.getroot()

    Material = get_section('Material', febio_spec_root)
    unwanted_mat_ids = change_unwanted_materials(Material, parts_to_keep)

    Boundary = get_section('Boundary', febio_spec_root)
    remove_boundaries(Boundary, parts_to_keep)

    # fix all unwanted materials (rigid bodies)
    fix_bodies(Boundary, unwanted_mat_ids)

    # remove the MPFL and LPFL discrete sets unless both PTB and FMB are kept
    Discrete = get_section('Discrete', febio_spec_root)
    if not ('PTB' in parts_to_keep and 'FMB' in parts_to_keep):
        discrete_MPFL = get_section('discrete', Discrete, use_other_attribute="discrete_set", attribute_value="MPFL")
        discrete_LPFL = get_section('discrete', Discrete, use_other_attribute="discrete_set", attribute_value="LPFL")
        Discrete.remove(discrete_LPFL)
        Discrete.remove(discrete_MPFL)

    # contacts defined directly under the root
    Contact = get_section('Contact', febio_spec_root)
    remove_contacts(Contact, parts_to_keep)

    # remove the irrelevant contacts in the loading step
    # NOTE(review): the step is looked up by tag only (an earlier revision used
    # name="LoadingStep"); get_section returns a list when several <Step>
    # elements exist, which the code below does not handle -- confirm the
    # input files always contain exactly one step.
    LoadingStep = get_section("Step", febio_spec_root)
    Contact = get_section('Contact', LoadingStep)
    remove_contacts(Contact, parts_to_keep)

    # if the QAT is not included in the model, fix the QSO and comment out the QAT slider constraint.
    if 'QAT' not in parts_to_keep:
        # fix the QSO (fix_bodies deliberately skips it, so it is done here)
        QSO_id = find_mat_id(Material, 'QSO')
        rigid_body = myElement('rigid_body', Boundary, mat=str(QSO_id))
        for bc in ('x', 'y', 'z', 'Rx', 'Ry', 'Rz'):
            myElement('fixed', rigid_body, bc=bc)

        # comment out the constraint; it is identified by type="rigid spring"
        # (an earlier revision looked it up by name="Quadriceps_Slider")
        Constraints = get_section('Constraints', LoadingStep)
        QSO_constraint = get_section('constraint', Constraints, use_other_attribute="type", attribute_value="rigid spring")
        comment_element(QSO_constraint, Constraints)

    # the reduced model goes into the same directory as the input file
    directory = os.path.dirname(febfile)
    new_filename = os.path.join(directory, '_'.join(parts_to_keep) + '.feb')

    write_file(feb_tree, new_filename)


def write_file(xml_tree, new_filename):
    """Write ``xml_tree`` to ``new_filename``, then parse that file back and write it once more."""
    output_options = dict(xml_declaration=True, pretty_print=True)

    # first pass: write the new file
    xml_tree.write(new_filename, **output_options)

    # second pass: make sure that it's pretty printing properly, by re-reading
    # the file with blank text removed and writing it again
    reparsed_tree = et.parse(new_filename, et.XMLParser(remove_blank_text=True))
    reparsed_tree.write(new_filename, **output_options)

def myElement(tag, parent=None, text=None, attrib=None, **extra):
    """Shortcut to create an xml element (root or child), add text and other attributes.

    :param tag: tag of the new element
    :param parent: element (or any object with ``append``) that receives the
        new element; ``None`` creates a free-standing element
    :param text: text content; only set when truthy
    :param attrib: optional mapping of attributes for the new element
    :param extra: further attributes given as keyword arguments
    :return: the newly created element
    """
    # hand the element factory a private copy: a ``{}`` default would be one
    # dict shared by every call, and the caller's own dict should not be
    # exposed to modification by the callee either
    attributes = dict(attrib) if attrib else {}

    e = et.Element(tag, attributes, **extra)
    if text:
        e.text = text
    if parent is not None:
        parent.append(e)
    return e


def get_section(section_tag, parent, use_other_attribute = None, attribute_value = None):
    """Get the section(s) tagged ``section_tag`` from the direct children of ``parent``.

    By default the children are matched on their xml tag only. If
    ``use_other_attribute`` is specified, a child must additionally carry that
    attribute with the value ``attribute_value`` (a child lacking the
    attribute counts as having the value ``None``).

    Exactly one match is returned as that element; several matches are
    returned as a list. If nothing matches, the section is created in
    ``parent`` (with the requested attribute set, when searching by
    attribute) and the new element is returned.
    """
    # every direct child with the requested tag, in document order
    tagged_children = [child for child in parent if child.tag == section_tag]

    if use_other_attribute is None:
        # searching by the section tag alone
        matches = tagged_children
        creation_attributes = {}
    else:
        # narrow the tagged children down to the ones with the requested attribute value
        matches = [
            child for child in tagged_children
            if child.attrib.get(use_other_attribute) == attribute_value
        ]
        creation_attributes = {use_other_attribute: attribute_value}

    if not matches:  # create it if it doesn't exist
        return myElement(section_tag, parent, **creation_attributes)
    if len(matches) == 1:  # a single hit is returned bare
        return matches[0]
    return matches  # multiple hits are returned as a list


if __name__ == '__main__':
    # usage: python <this script> <path/to/model.feb> "[PTB,PTL,TBB]"
    # the command-line arguments are forwarded unchanged; the parts list
    # arrives as one bracketed string that ReduceModel splits on commas
    cli_arguments = sys.argv[1:]
    ReduceModel(*cli_arguments)
