import sys
import numpy as np
import time
import os
import inspect
try:				from lxml import etree as et		# not in my standard Salome's python but better
except ImportError:	import xml.etree.cElementTree as et	# functionally equavalent, but no pretty printing
try:
    import salome
    import salome_notebook
    from salome.smesh import smeshBuilder
    import SMESH
except ImportError:
    import inspect
    print('\n\n\nUsage: salome {} args:ConnectivityFile,FeBioFile(optional),GeometryFile(optional),steps(optional)  \n\n\n'.format( inspect.getfile(inspect.currentframe()) ))
    sys.exit()
try:from MEDLoader import MEDLoader
except ImportError: import MEDLoader


# Initialise the Salome environment (presumably attaching to the session the
# script was launched from via `salome ...`) and create module-level handles.
# salome_init() runs first, before salome.myStudy is read.
salome.salome_init()
theStudy = salome.myStudy
notebook = salome_notebook.NoteBook(theStudy)  # not referenced by the code in this file that is shown here
smesh = smeshBuilder.New()  # SMESH builder used by MakeMesh and MakeMed


def CloseSalome():
    """Try to close out of Salome gracefully.

    Calls ``killSalomeWithPort.killMyPort`` on the port stored in the
    ``NSPORT`` environment variable, discarding whatever it prints to stdout.
    This is best-effort. If the helper cannot be imported or raises an
    ``Exception``, a one-line note goes to stderr and the function returns
    normally. Non-``Exception`` exits such as ``SystemExit`` propagate.
    ``sys.stdout`` is always restored afterwards.
    """
    # there is nothing graceful about this
    saved_stdout = sys.stdout
    try:
        with open(os.devnull, 'w') as devnull:
            # Silence the shutdown chatter. A real (null) file is required:
            # assigning a plain string to sys.stdout makes every print()
            # inside killMyPort raise, which aborts the shutdown part-way.
            sys.stdout = devnull
            try:
                from killSalomeWithPort import killMyPort
                killMyPort(os.getenv('NSPORT'))
            finally:
                sys.stdout = saved_stdout
    except Exception as err:  # best effort: report, but never crash the caller
        sys.stderr.write('CloseSalome: could not shut Salome down ({!r})\n'.format(err))


def get_groups(original_mesh, all_node_array):
    """Create a dictionary of group names and node coordinates for all the groups in the original mesh.

    Parameters
    ----------
    original_mesh : smeshBuilder.Mesh
        Mesh whose groups (of every element type, ``SMESH.ALL``) are collected.
    all_node_array : numpy.ndarray
        Node coordinates of the same mesh, one row per node (presumably
        ``(n_nodes, 3)`` as read by MEDLoader). Row ``i`` is assumed to hold
        the node whose SMESH ID is ``i + 1``, i.e. IDs are contiguous from 1.
        TODO confirm for renumbered meshes.

    Returns
    -------
    dict
        ``{group name: coordinates of the group's nodes}``. Empty groups are
        skipped. Groups referencing IDs outside ``all_node_array`` are skipped
        with a printed warning.
    """
    groups_dict = {}
    for g in original_mesh.GetGroups(SMESH.ALL):
        g_name = g.GetName()
        # SMESH node IDs start at 1; shift them to 0-based row indices.
        # The explicit int dtype keeps an empty ID list a valid index array.
        group_nodes_id = np.asarray(g.GetNodeIDs(), dtype=int) - 1
        if group_nodes_id.size == 0:
            continue  # nothing to transfer
        try:
            groups_dict[g_name] = all_node_array[group_nodes_id]
        except IndexError:
            # IDs beyond the coordinate array: numbering is not 1..n as assumed
            print('\tskipping group {}: its node IDs do not match the node array.'.format(g_name))
    return groups_dict


def new_groups(original_groups_dict, new_mesh, all_nodes_array, threshhold=0.1):
    """Find the nodes on the new mesh's surface closest to the nodes in the original groups.

    Parameters
    ----------
    original_groups_dict : dict
        ``{group name: (k, dim) array of node coordinates}``, as built by
        ``get_groups`` on the original mesh.
    new_mesh : smeshBuilder.Mesh
        Mesh that must contain a face group whose name includes
        ``'All_Faces'``. The nodes of that group are the candidate surface
        nodes. If several such groups exist, the last one returned by
        ``GetGroups`` is used.
    all_nodes_array : numpy.ndarray
        Node coordinates of ``new_mesh``. Row ``i`` is assumed to be SMESH
        node ID ``i + 1`` (contiguous numbering).
    threshhold : float
        A surface node joins a group when it lies strictly closer than this
        to at least one of the group's original nodes.

    Returns
    -------
    dict
        ``{group name: numpy array of SMESH node IDs}``. Each ID appears at
        most once, in the order of the ``All_Faces`` group's node list.

    Raises
    ------
    ValueError
        If ``new_mesh`` has no ``All_Faces`` face group.
    """
    from scipy.spatial import distance

    surface_ids = None
    for g in new_mesh.GetGroups(SMESH.FACE):  # the existing face groups
        if 'All_Faces' in g.GetName():
            surface_ids = np.asarray(g.GetNodeIDs(), dtype=int)
    if surface_ids is None:
        raise ValueError('new mesh has no "All_Faces" group; cannot locate its surface nodes')

    surface_nodes_array = all_nodes_array[surface_ids - 1]  # -1 for python indexing

    new_groups_dict = {}
    for g_name, g_nodes in original_groups_dict.items():
        dists = distance.cdist(g_nodes, surface_nodes_array)
        # Column j is a hit if ANY original node is within the threshold.
        # Reducing over rows lists each surface node once, whereas
        # np.where(...)[1] repeats it for every nearby original node.
        is_close = (dists < threshhold).any(axis=0)
        new_groups_dict[g_name] = surface_ids[is_close]  # salome IDs
    return new_groups_dict


def MakeAllGroups(mesh, dim='3'):
    """Make mesh groups of type "All" for the specified dimensions.

    mesh is the mesh data. dim is a string of dimension digits for the
    element types wanted: '0' nodes, '1' edges, '2' faces, '3' volumes.
    For example, dim = '01' makes the "All" groups for both nodes and edges.
    Each group is named ``<mesh name>_All_<Nodes|Edges|Faces|Volumes>`` and
    holds every element of that type. There is no return value; the mesh is
    augmented with the groups.
    """
    name = mesh.GetName()
    # (dimension digit, group-name suffix, SMESH element type)
    kinds = (
        ('0', 'Nodes', SMESH.NODE),
        ('1', 'Edges', SMESH.EDGE),
        ('2', 'Faces', SMESH.FACE),
        ('3', 'Volumes', SMESH.VOLUME),
    )
    # Independent checks (not elif), so every requested dimension gets a group.
    for digit, suffix, elem_type in kinds:
        if digit in dim:
            mesh.MakeGroupByIds('{}_All_{}'.format(name, suffix), elem_type,
                                mesh.GetElementsByType(elem_type))


def MakeMesh(meshFile, material='elastic', meshName='', ifSurface=True, ifAll=True):
    """Load an STL mesh file, mesh it in 3D, and set all groups that need no other information.

    Parameters
    ----------
    meshFile : str
        Path of the STL file passed to ``smesh.CreateMeshesFromSTL``.
    material : str
        Unless this is ``'rigid'``, the surface is filled with NETGEN 3D
        tetrahedra. A rigid part keeps only the STL surface.
    meshName : str
        Name given to the mesh. If empty, the loaded mesh's own name is cut
        at its first '.' (presumably the file name without '.stl').
    ifSurface : bool
        Also add an ``All_Faces`` group when the mesh is not 2D.
    ifAll : bool
        Add the "All" group matching the mesh's own dimension.

    Returns
    -------
    smeshBuilder.Mesh
        The mesh. If NETGEN fails even after the retry, a warning is printed
        and the mesh is returned as is.
    """
    mesh = smesh.CreateMeshesFromSTL(meshFile)
    if material != 'rigid':
        # NETGEN size bounds taken from the STL's 2D element lengths
        meshMin, meshMax = mesh.GetMinMax(SMESH.FT_Length2D)
        print('\toptimizing.')
        #
        #	3D Meshing and parameters
        #
        NETGEN_3D = mesh.Tetrahedron(geom=None)
        NETGEN_3D_Parameters = NETGEN_3D.Parameters()
        NETGEN_3D_Parameters.SetMaxSize(meshMax)
        NETGEN_3D_Parameters.SetOptimize(1)
        NETGEN_3D_Parameters.SetMinSize(meshMin)
        NETGEN_3D_Parameters.SetFineness(4)

        if not mesh.Compute():
            # Retrying at a lower fineness (4 -> 3) fixed a failure in one of
            # the author's test cases; why it matters is not known.
            print('\n\n\nChanging NETGEN parameters and remeshing\n\n\n')
            NETGEN_3D_Parameters.SetFineness(3)
            if not mesh.Compute():
                # Do not silently hand on a mesh whose volume meshing failed.
                print('\n\n\nWARNING: NETGEN 3D meshing failed for {}\n\n\n'.format(meshFile))

    if meshName:
        mesh.SetName(meshName)
    else:
        mesh.SetName(mesh.GetName().split('.')[0])  # remove '.stl' from file name

    dim = mesh.MeshDimension()
    if ifAll:
        MakeAllGroups(mesh, str(dim))
    if ifSurface and dim != 2:
        MakeAllGroups(mesh, '2')

    return mesh


def add_new_groups(new_mesh, new_groups_dict):
    """Create the groups described by ``new_groups_dict`` on ``new_mesh``.

    ``new_groups_dict`` maps a group name to a numpy array of SMESH node IDs.
    Entries whose name contains 'All_Faces' or 'All_Volumes' are skipped.

    Every other entry first becomes a node group. If the name contains
    'Faces', that node group is turned into a face group with
    ``CreateDimGroup(..., SMESH.FACE, name, SMESH.ALL_NODES, 0)`` and then
    removed. The face group is presumably the faces whose nodes all lie in
    the set; confirm against the SMESH docs.

    Finally, every 'TiesFaces' face is removed from every 'ContactFaces'
    group. Returns None.
    """
    skip_tags = ('All_Faces', 'All_Volumes')
    contact_groups, tie_groups = [], []

    for name, ids in new_groups_dict.items():
        if any(tag in name for tag in skip_tags):
            continue
        # node group first; for face groups it is only a stepping stone
        nodes_grp = new_mesh.MakeGroupByIds(name, SMESH.NODE, ids.tolist())
        if 'Faces' not in name:
            continue
        faces_grp = new_mesh.CreateDimGroup([nodes_grp], SMESH.FACE, name, SMESH.ALL_NODES, 0)
        new_mesh.RemoveGroup(nodes_grp)
        if 'ContactFaces' in name:
            contact_groups.append(faces_grp)
        elif 'TiesFaces' in name:
            tie_groups.append(faces_grp)

    # a face may not be both a contact face and a tie face: ties win
    for contact in contact_groups:
        for tie in tie_groups:
            contact.Remove(tie.GetIDs())


def _characteristic_length(mesh):
    """Approximate a characteristic length: midpoint of the mesh's min and max 2D element lengths."""
    minSize, maxSize = mesh.GetMinMax(SMESH.FT_Length2D)
    return (maxSize + minSize) / 2.0


def _mesh_stl_to_med(stl_file, mesh_name):
    """Volume-mesh ``stl_file`` and export it to ``<stl dir>/MED/<stem>.med``. Returns ``(med path, mesh)``."""
    # create the med directory if it doesn't exist
    medDir = os.path.join(os.path.dirname(stl_file), 'MED')
    if not os.path.isdir(medDir):
        os.mkdir(medDir)
    print("making med for {}".format(os.path.basename(stl_file)))
    stem = os.path.splitext(os.path.basename(stl_file))[0]
    med_filename = os.path.join(medDir, stem + '.med')
    new_mesh = MakeMesh(stl_file, meshName=mesh_name)
    # save now: the node coordinates are read back from this file
    new_mesh.ExportMED(med_filename)
    return med_filename, new_mesh


def _open_med_keeping_all_groups(med_file):
    """Load the first mesh of ``med_file`` and remove (in memory) every group whose name lacks "All"."""
    meshes, _status = smesh.CreateMeshesFromMED(med_file)
    mesh = meshes[0]
    print('\tremoving existing groups.')
    for g in mesh.GetGroups(SMESH.ALL):
        if "All" not in g.GetName():
            mesh.RemoveGroup(g)
    return mesh


def MakeMed(original_med_file, new_file):
    """Mesh the new file and transfer the groups from ``original_med_file`` onto it.

    Parameters
    ----------
    original_med_file : str
        MED file whose groups are copied.
    new_file : str
        Either an STL surface or an existing MED file. The extension is
        matched case-insensitively.
        - STL: the surface is meshed with ``MakeMesh`` and the result is
          written to a ``MED`` directory next to the STL.
        - MED: its non-"All" groups are removed and the file is overwritten
          in place.

    Each original group is recreated from the new mesh's ``All_Faces`` nodes
    that lie within a characteristic mesh length of the group's original
    nodes. That length is the larger of the two meshes' values from
    ``_characteristic_length``.

    Raises
    ------
    ValueError
        If ``new_file`` is neither an ``.stl`` nor a ``.med`` file. This is
        checked before anything is loaded or created.
    """
    new_ext = os.path.splitext(new_file)[1].lower()
    if new_ext not in ('.stl', '.med'):
        raise ValueError('MakeMed: {} is not an .stl or .med file'.format(new_file))

    # all the nodes of the original med file as a coordinate array
    AllNodeArray_original = MEDLoader.ReadUMeshFromFile(original_med_file, 0).getCoords().toNumPyArray()

    # open the med file of the original mesh
    original_meshes, _status = smesh.CreateMeshesFromMED(original_med_file)
    original_mesh = original_meshes[0]
    original_meshSize = _characteristic_length(original_mesh)

    # node coordinates of every group in the original mesh
    original_groups_dict = get_groups(original_mesh, AllNodeArray_original)

    if new_ext == '.stl':
        med_filename, new_mesh = _mesh_stl_to_med(new_file, original_mesh.GetName().split('.')[0])
    else:
        med_filename = new_file
        new_mesh = _open_med_keeping_all_groups(med_filename)

    new_meshSize = _characteristic_length(new_mesh)

    # Nodes of the new mesh, read from disk. Groups removed in memory do not
    # change the coordinates.
    AllNodeArray_new = MEDLoader.ReadUMeshFromFile(med_filename, 0).getCoords().toNumPyArray()

    print("finding new groups..")
    new_groups_dict = new_groups(original_groups_dict, new_mesh, AllNodeArray_new,
                                 threshhold=max(original_meshSize, new_meshSize))

    print("adding new groups..")
    add_new_groups(new_mesh, new_groups_dict)

    # export the new mesh, groups included, as a med file
    new_mesh.ExportMED(med_filename)



def all_in_directory(stl_dir, med_dir):
    """Convert all the .stl files in stl_dir to med files, transferring the groups from the .med files in med_dir.

    Files are paired using the naming conventions of the Open Knee(s)
    project. After dropping the extension and splitting on '_', a file name
    must contain a part name such as 'FMB' or 'MNS-M' (see ``knee_parts``).

    If several MED files name the same part, the alphabetically last one is
    used. An STL whose part has no MED file is skipped with a message.
    Directories may be given with or without a trailing separator. Salome is
    closed at the end.
    """
    knee_parts = ('FMB', 'PTB', 'TBB', 'FBB', 'ACL', 'PCL', 'MCL', 'LCL', 'PTL', 'QAT', 'MNS-L', 'MNS-M',
                  'TBC-L', 'TBC-M', 'PTC', 'FMC')

    def which_knee_part(filename):
        """Return the first knee part named in ``filename``, or None if there is none."""
        # strip the extension so a trailing token like 'FMB.stl' still matches
        stem = os.path.splitext(filename)[0]
        return next((s for s in stem.split('_') if s in knee_parts), None)

    # find the original med file for each knee part
    original_meds_dict = {}
    for m in sorted(os.listdir(med_dir)):
        if not m.endswith('.med'):
            continue
        pn = which_knee_part(m)
        if pn:
            original_meds_dict[pn] = os.path.join(med_dir, m)

    # go through the stls, find the matching med file, then transfer the groups
    for s in sorted(os.listdir(stl_dir)):
        if not s.endswith('.stl'):
            continue
        pn = which_knee_part(s)
        if not pn:
            continue
        orig_med = original_meds_dict.get(pn)
        if orig_med is None:
            print('no original med file for part {}; skipping {}'.format(pn, s))
            continue
        MakeMed(orig_med, os.path.join(stl_dir, s))

    CloseSalome()

def test():
    """Re-create the SR6/SR8/SR10 ACL meshes' groups from the SR4 mesh.

    This is a hard-coded example for one workstation. Each target is a .med
    file, so MakeMed overwrites it in place. Salome is closed afterwards.
    """
    orig_med_file = '/home/schwara2/Documents/Open_Knees/ACL_modeling/MeshDensity/MED/oks009_MRC_ACL_AGS_01_LVTIT_SR4_E.med'
    targets = (
        '/home/schwara2/Documents/Open_Knees/ACL_modeling/MeshDensity/MED/oks009_MRC_ACL_AGS_01_LVTIT_SR6_E.med',
        '/home/schwara2/Documents/Open_Knees/ACL_modeling/MeshDensity/MED/oks009_MRC_ACL_AGS_01_LVTIT_SR8_E.med',
        '/home/schwara2/Documents/Open_Knees/ACL_modeling/MeshDensity/MED/oks009_MRC_ACL_AGS_01_LVTIT_SR10_E.med',
    )
    for new_file in targets:
        MakeMed(orig_med_file, new_file)

    CloseSalome()


if __name__ == '__main__':
    # Default action: the hard-coded mesh-density example (see test()).
    # For a whole-directory conversion, run this instead:
    # all_in_directory('/home/schwara2/Documents/Open_Knees/ACL_modeling/', '/home/schwara2/Documents/Open_Knees/knee_hub/CC-NKD-MD-outcomes/CC-NKD-MD-outputs/model/MED/')
    test()