import os
import sys

import numpy as np
from lxml import etree as et


def remove_boundaries(Boundary_section, parts_to_keep, new_mat_ids):
    """Remove boundaries that are not between parts being kept; renumber the rest.

    Each named child of ``Boundary_section`` is expected to carry a ``name`` of
    the form ``<part>_With_<part>``. A child is removed unless every part named
    there is in ``parts_to_keep``. For kept children, the ``rb`` attribute is
    rewritten through ``new_mat_ids`` (old material id -> new material id).

    Left untouched:
        - XML comments;
        - children without a ``name`` attribute;
        - kept children whose ``rb`` is missing or has no entry in ``new_mat_ids``.
    """
    # Iterate over a snapshot: removing children from the element that is being
    # iterated skips siblings in some ElementTree implementations.
    for boundary in list(Boundary_section):
        if not isinstance(boundary.tag, str):  # comment / processing instruction
            continue
        boundary_name = boundary.get('name')
        if boundary_name is None:
            continue

        parts = boundary_name.split('_')
        if 'With' in parts:
            parts.remove('With')
        if not set(parts).issubset(parts_to_keep):
            Boundary_section.remove(boundary)
            continue

        old_mat_id = boundary.get('rb')
        # NOTE(review): an rb id with no entry in new_mat_ids refers to a removed
        # material; it is left unchanged here, matching the previous behaviour.
        if old_mat_id in new_mat_ids:
            boundary.set('rb', new_mat_ids[old_mat_id])


def remove_contacts(Contact_section, parts_to_keep):
    """Remove every contact that is not between two parts being kept.

    Each contact child is expected to carry a ``surface_pair`` attribute of the
    form ``<part>_To_<part>``. The contact is removed unless every part named
    there is in ``parts_to_keep``. XML comments and children without
    ``surface_pair`` are left untouched.
    """
    # Iterate over a snapshot so removing children cannot skip siblings.
    for contact in list(Contact_section):
        if not isinstance(contact.tag, str):  # comment / processing instruction
            continue
        surface_pair = contact.get('surface_pair')
        if surface_pair is None:
            continue

        parts = surface_pair.split('_')
        if 'To' in parts:
            parts.remove('To')
        if not set(parts).issubset(parts_to_keep):
            Contact_section.remove(contact)


def remove_materials(Material_section, parts_to_keep):
    """Remove every material whose name is not in ``parts_to_keep``; renumber the rest.

    Kept materials receive consecutive ids "1", "2", ... in document order.
    XML comments inside the section are skipped.

    Returns:
        dict mapping each kept material's old ``id`` (str) to its new ``id`` (str).
    """
    new_mat_ids = {}
    count = 0

    # Iterate over a snapshot so removing children cannot skip siblings.
    for material in list(Material_section):
        if not isinstance(material.tag, str):  # comment / processing instruction
            continue
        if material.attrib['name'] not in parts_to_keep:
            Material_section.remove(material)
            continue
        count += 1
        new_mat_ids[material.attrib['id']] = str(count)
        material.attrib['id'] = str(count)

    return new_mat_ids


def _child_sections(parent, tag):
    """Return the direct children of ``parent`` with this tag (always a list; never creates one)."""
    return [child for child in parent if child.tag == tag]


def _strip_markers(name, *marker_groups):
    """Split ``name`` on '_' and drop one marker token per group; return the remaining parts as a set.

    For each group, the first marker (in group order) present in the name is
    removed once. Raises ValueError naming ``name`` if no marker of a group is present.
    """
    parts = name.split('_')
    for group in marker_groups:
        for marker in group:
            if marker in parts:
                parts.remove(marker)
                break
        else:
            raise ValueError('%r contains none of the expected markers %r' % (name, group))
    return set(parts)


def remove_geometries(Geometry_section, parts_to_keep, new_mat_ids):
    """Remove, in place, every geometry section that does not belong to the parts being kept.

    - ``Nodes`` / ``Elements`` are kept only if their ``name`` is in ``parts_to_keep``.
      Kept ``Elements`` have their ``mat`` id rewritten through ``new_mat_ids``.
    - ``NodeSet`` names look like ``<part>_@_<part>_{TiesNodes|ContactNodes}``.
    - ``Surface`` names look like ``<part>_{@|All}_..._{ContactFaces|TiesFaces|Faces}``.
    - ``SurfacePair`` names look like ``<part>_To_<part>``.

    After the marker tokens are stripped, a NodeSet, Surface or SurfacePair is
    kept only if every remaining part is in ``parts_to_keep``. All ``DiscreteSet``
    sections are removed.

    Raises:
        KeyError: a kept ``Elements`` section's ``mat`` id is missing from ``new_mat_ids``.
        ValueError: a NodeSet, Surface or SurfacePair name lacks its expected markers.
    """
    # Collect each kind as a plain list up front. get_section() is avoided here:
    # it returns a bare element (not a list) when exactly one section exists,
    # and it appends an empty section when none exists.
    nodes_sections = _child_sections(Geometry_section, 'Nodes')
    elements_sections = _child_sections(Geometry_section, 'Elements')
    nodeset_sections = _child_sections(Geometry_section, 'NodeSet')
    surface_sections = _child_sections(Geometry_section, 'Surface')
    surface_pair_sections = _child_sections(Geometry_section, 'SurfacePair')
    discrete_sets = _child_sections(Geometry_section, 'DiscreteSet')

    for nodes in nodes_sections:
        if nodes.attrib['name'] not in parts_to_keep:
            Geometry_section.remove(nodes)

    for elements in elements_sections:
        if elements.attrib['name'] not in parts_to_keep:
            Geometry_section.remove(elements)
        else:
            elements.attrib['mat'] = new_mat_ids[elements.attrib['mat']]

    for nodeset in nodeset_sections:
        parts = _strip_markers(nodeset.attrib['name'],
                               ('@',), ('TiesNodes', 'ContactNodes'))
        if not parts.issubset(parts_to_keep):
            Geometry_section.remove(nodeset)

    for surface in surface_sections:
        parts = _strip_markers(surface.attrib['name'],
                               ('@', 'All'), ('ContactFaces', 'TiesFaces', 'Faces'))
        if not parts.issubset(parts_to_keep):
            Geometry_section.remove(surface)

    for surface_pair in surface_pair_sections:
        parts = _strip_markers(surface_pair.attrib['name'], ('To',))
        if not parts.issubset(parts_to_keep):
            Geometry_section.remove(surface_pair)

    for discrete in discrete_sets:
        Geometry_section.remove(discrete)


def remove_mesh_data(MeshData_section, parts_to_keep):
    """Remove mesh-data entries whose ``elem_set`` is not one of ``parts_to_keep``.

    XML comments and children without an ``elem_set`` attribute are left in place.
    """
    # Iterate over a snapshot so removing children cannot skip siblings.
    for element_data in list(MeshData_section):
        if not isinstance(element_data.tag, str):  # comment / processing instruction
            continue
        elem_set_name = element_data.get('elem_set')
        if elem_set_name is not None and elem_set_name not in parts_to_keep:
            MeshData_section.remove(element_data)


def edit_step_boundary(Step_sections, new_mat_ids):
    """Renumber or drop material-referencing boundary conditions inside Step sections.

    ``Step_sections`` may be a single Step element or an iterable of Step
    elements. get_section() returns either form, depending on how many Step
    sections exist.

    For every child of each Step's ``Boundary`` section(s) that has a ``mat``
    attribute:
        - if the id is a key of ``new_mat_ids``, it is rewritten to the new id;
        - otherwise the child is removed.

    Children without ``mat`` and XML comments are left untouched. No sections are created.
    """
    # A lone element (it has a .tag) is one Step, not a collection of Steps;
    # iterating it directly would walk the Step's own children instead.
    if hasattr(Step_sections, 'tag'):
        steps = [Step_sections]
    else:
        steps = list(Step_sections)

    for step in steps:
        boundary_sections = [child for child in step if child.tag == 'Boundary']
        for boundary_section in boundary_sections:
            # Snapshot so removing children cannot skip siblings.
            for boundary in list(boundary_section):
                if not isinstance(boundary.tag, str):  # comment / processing instruction
                    continue
                old_mat_id = boundary.get('mat')
                if old_mat_id is None:
                    continue
                if old_mat_id in new_mat_ids:
                    boundary.set('mat', new_mat_ids[old_mat_id])
                else:
                    boundary_section.remove(boundary)


def _parse_parts_to_keep(parts_to_keep):
    """Turn ``parts_to_keep`` into a list of part names.

    Accepts a list/tuple of names, or the string form a list arrives in from
    the command line, e.g. ``"[FMB,ACL,TBB]"`` or ``"['FMB', 'ACL']"``.
    Whitespace and quotes around each name are stripped and empty names dropped.
    Raises ValueError if no names remain.
    """
    raw = parts_to_keep
    if isinstance(raw, str):
        text = raw.strip()
        if text.startswith('[') and text.endswith(']'):
            text = text[1:-1]
        candidates = text.split(',')
    else:
        candidates = raw
    names = [name.strip().strip('\'"') for name in candidates]
    names = [name for name in names if name]
    if not names:
        raise ValueError('no part names given in %r' % (raw,))
    return names


def ReduceModel(febfile, parts_to_keep):
    """Write a reduced copy of a FEBio model that contains only ``parts_to_keep``.

    Args:
        febfile: path to the ``.feb`` file. Its ``Geometry`` section must have a
            ``from`` attribute naming the geometry file. That path is resolved
            relative to the directory of ``febfile``.
        parts_to_keep: the part names to keep. Either a list/tuple of strings,
            or a bracketed, comma-separated string such as ``"[FMB,ACL,TBB]"``
            (the form it has when passed on the command line).

    The reduced model, with the geometry inlined, is written into the directory
    of ``febfile`` as ``<part1>_<part2>_....feb``.

    Returns:
        The path of the file written.
    """
    parts_to_keep = _parse_parts_to_keep(parts_to_keep)

    feb_tree = et.parse(febfile)
    febio_spec_root = feb_tree.getroot()
    directory = os.path.dirname(febfile)

    def top_level_sections(tag):
        # Direct children of the root with this tag. Unlike get_section this
        # always returns a list and never appends a missing section.
        return [child for child in febio_spec_root if child.tag == tag]

    # The geometry lives in a separate file named by <Geometry from="...">.
    geosection = get_section('Geometry', febio_spec_root)
    geofile = geosection.attrib['from']
    geo_tree = et.parse(os.path.join(directory, geofile))
    geo_spec_root = geo_tree.getroot()

    # Materials first: the old->new material id map is used by the later steps.
    Material = get_section('Material', febio_spec_root)
    new_mat_ids = remove_materials(Material, parts_to_keep)

    Geometry = get_section('Geometry', geo_spec_root)
    remove_geometries(Geometry, parts_to_keep, new_mat_ids)

    # Replace the file-referencing Geometry section with the reduced, inline one.
    section_tags = [c.tag for c in febio_spec_root]
    febio_spec_root[section_tags.index('Geometry')] = Geometry

    for mesh_data in top_level_sections('MeshData'):
        remove_mesh_data(mesh_data, parts_to_keep)
    for boundary in top_level_sections('Boundary'):
        remove_boundaries(boundary, parts_to_keep, new_mat_ids)
    for contact in top_level_sections('Contact'):
        remove_contacts(contact, parts_to_keep)
    # Discrete sections are removed entirely.
    for discrete in top_level_sections('Discrete'):
        febio_spec_root.remove(discrete)

    edit_step_boundary(top_level_sections('Step'), new_mat_ids)

    new_filename = os.path.join(directory, '_'.join(parts_to_keep) + '.feb')

    # Write the new file.
    feb_tree.write(new_filename, xml_declaration=True, pretty_print=True)

    # lxml's pretty_print does not re-indent content that already contains
    # whitespace-only text, so re-parse with blank text removed and write again.
    parser = et.XMLParser(remove_blank_text=True)
    new_feb_tree = et.parse(new_filename, parser)
    new_feb_tree.write(new_filename, xml_declaration=True, pretty_print=True)

    return new_filename



def myElement(tag, parent=None, text=None, attrib=None, **extra):
    """Create an XML element, optionally appending it to ``parent``.

    Args:
        tag: the element's tag name.
        parent: if given, the new element is appended as its last child.
        text: element text. It is only set when truthy, so ``None`` and ``""``
            both leave it unset.
        attrib: optional mapping of attributes.
        **extra: further attributes given as keyword arguments.

    Returns:
        The new element.
    """
    # ``None`` rather than ``{}`` as the default: a mutable default object
    # would be shared by every call.
    e = et.Element(tag, {} if attrib is None else attrib, **extra)
    if text:
        e.text = text
    if parent is not None:
        parent.append(e)
    return e


def get_section(section_tag, parent, use_other_attribute = None, attribute_value = None):
    """Find the child section(s) of ``parent`` tagged ``section_tag``, creating one if absent.

    By default, sections are matched on their tag alone. When
    ``use_other_attribute`` is given, a section must also have that attribute
    equal to ``attribute_value``. A section lacking the attribute counts as
    having the value ``None``.

    Returns:
        - the single matching element, when exactly one matches;
        - a list of the matching elements in document order, when several match;
        - otherwise a new element appended to ``parent``. It has tag
          ``section_tag`` and, in attribute mode, the requested attribute.
    """
    matches = [child for child in parent if child.tag == section_tag]
    if use_other_attribute is not None:
        matches = [child for child in matches
                   if child.attrib.get(use_other_attribute) == attribute_value]

    if len(matches) == 1:
        return matches[0]
    if matches:
        return matches

    # Nothing matched: create the section so the caller always gets one back.
    if use_other_attribute is None:
        return myElement(section_tag, parent)
    return myElement(section_tag, parent, **{use_other_attribute: attribute_value})


if __name__ == '__main__':
    # Command-line use:
    #   python <script> path/to/model.feb "[FMB,ACL,TBB]"
    # The second argument is a bracketed, comma-separated list of part names.
    if len(sys.argv) != 3:
        sys.exit('usage: %s FEB_FILE "[PART1,PART2,...]"' % sys.argv[0])
    ReduceModel(sys.argv[1], sys.argv[2])
