import numpy as np
import os
import sys
import MitralValveLibrary as mvl
from scipy.spatial.distance import cdist


def get_midpoints_and_normals(mesh_data):
    """Return per-element centroids and unit normals, in element order.

    Each element is treated as the triangle formed by its first three node
    indices. The normal follows the right-hand rule on (P1->P2) x (P1->P3),
    so its direction depends on the element's node winding.

    Assumes ``mesh_data.nodes`` is an (n_nodes, 3) coordinate array and
    ``mesh_data.elems`` an (n_elems, >=3) array of node indices -- TODO confirm
    against ``MitralValveLibrary.stl_to_mesh_data``.

    Returns a ``(midpoints, normals)`` tuple of (n_elems, 3) arrays.
    Degenerate (zero-area) triangles yield NaN normals, because their cross
    product has zero length.
    """
    corners = mesh_data.nodes[mesh_data.elems]
    first, second, third = corners[:, 0], corners[:, 1], corners[:, 2]

    centroids = (first + second + third) / 3

    # area-weighted normal, then scale each row to unit length
    raw_normals = np.cross(second - first, third - first)
    lengths = np.linalg.norm(raw_normals, axis=1)
    unit_normals = np.divide(raw_normals, lengths[:, np.newaxis])

    return centroids, unit_normals


def find_opposites(mesh_data, midpoints, normals, *, search_distance=2.5, opposite_cos=-0.9):
    """For each element, find the element across the wall from it and collect
    the point halfway between the two.

    Element j is a candidate "opposite" of element i when both hold:
      * their normals are roughly anti-parallel (``dot(n_i, n_j) < opposite_cos``);
      * the normals point away from each other, i.e. stepping each midpoint one
        unit along its own normal increases their separation (the two faces of
        a thin wall with outward normals behave this way).
    The closest candidate is kept. If it lies within ``search_distance``, the
    point halfway between the two element midpoints is added to the cloud.
    Elements with no candidate, including ones with NaN normals, are skipped
    rather than raising.

    Parameters
    ----------
    mesh_data : object
        Unused; kept so existing callers keep working.
    midpoints : array_like, shape (n, d)
        Element midpoints, e.g. from ``get_midpoints_and_normals``.
    normals : array_like, shape (n, d)
        Element normals in the same order. The ``opposite_cos`` threshold
        assumes they are unit length.
    search_distance : float, optional
        Maximum Euclidean distance between an element's midpoint and its
        opposite's midpoint (same units as ``midpoints``).
    opposite_cos : float, optional
        Normals whose dot product is below this count as opposite.

    Returns
    -------
    numpy.ndarray, shape (m, d)
        Unique center points (rows sorted by ``np.unique``). The shape is
        ``(0, d)`` when no pair is found.
    """
    midpoints = np.asarray(midpoints, dtype=float)
    normals = np.asarray(normals, dtype=float)

    # loop-invariant: each midpoint pushed one unit along its own normal
    tips = midpoints + normals

    centers = []
    for idx, normal in enumerate(normals):
        cos = normals @ normal

        # True Euclidean distances. An earlier version took np.sqrt of the
        # norm, which silently turned the search_distance cutoff into
        # search_distance**2.
        dist = np.linalg.norm(midpoints - midpoints[idx], axis=1)
        tip_dist = np.linalg.norm(tips - tips[idx], axis=1)

        # opposite-facing AND pointing away from each other; the element itself
        # is excluded automatically because its cos is +1
        candidates = np.flatnonzero((cos < opposite_cos) & (tip_dist > dist))
        if candidates.size == 0:
            continue

        opposite_elem = candidates[np.argmin(dist[candidates])]
        if dist[opposite_elem] < search_distance:
            centers.append((midpoints[idx] + midpoints[opposite_elem]) / 2)

    if not centers:
        return np.empty((0, midpoints.shape[1]))

    # pairs found from both sides yield bit-identical centers; drop the doubles
    return np.unique(np.asarray(centers), axis=0)


def center_cloud_path(stl_file):
    """Return the point-cloud output path for *stl_file*: same directory and
    stem, with the extension replaced by ``_center.xyz``.

    Only the file's own extension is replaced, so a name without a lowercase
    ``.stl`` suffix can never map back onto the input file itself.
    """
    return os.path.splitext(stl_file)[0] + '_center.xyz'


def find_midsurface(stl_file):
    """Write a mid-surface point cloud for the mesh in *stl_file*.

    The STL is loaded with ``mvl.stl_to_mesh_data``. Each element is paired
    with its opposite element via ``find_opposites``, and the halfway points
    are saved as space-delimited XYZ text next to the input (see
    ``center_cloud_path``).

    Returns the path of the written ``.xyz`` file.
    """
    mesh_data = mvl.stl_to_mesh_data(stl_file)

    # find each element's 'opposite' and collect the points halfway between them
    midpoints, normals = get_midpoints_and_normals(mesh_data)
    point_cloud = find_opposites(mesh_data, midpoints, normals)

    # save as space-delimited xyz; never derive the output path by substring
    # replacement, which could overwrite the input STL
    out_path = center_cloud_path(stl_file)
    np.savetxt(out_path, point_cloud, delimiter=" ")
    return out_path


if __name__ == '__main__':

    # Usage: python <script> [path/to/mesh.stl]
    # Without an argument, fall back to the original hard-coded MV62 mesh.
    default_stl = '/home/schwara2/Documents/MitralValve/Project_Files/source_code/geometry/MV62/MV62_WHOLE_SYS_0930_smooth.stl'
    input_stl = sys.argv[1] if len(sys.argv) > 1 else default_stl
    find_midsurface(input_stl)