"""Compute edge-length statistics (mean, standard deviation, minimum and maximum)
for a tetrahedral mesh in MED format, and plot the mesh nodes and the
edge-length distribution.

Written by Snehal Chokhandre
"""



import meshio
import numpy as np
import math
import itertools
import matplotlib.pyplot as plt 
from mpl_toolkits.mplot3d import Axes3D
import scipy.stats as stats




# Path to the input tetrahedral mesh (MED format). A raw string keeps the
# Windows backslashes literal: in a normal string "\o", "\m", ... are invalid
# escape sequences (SyntaxWarning on Python 3.12+), and a folder name starting
# with e.g. "n" or "t" would silently turn into a newline/tab character.
MESH_PATH = r'C:\oks\oks003_kneehub1\model\MED\oks003_MNS-M_AGS_02_LVTIT.med'

mesh = meshio.read(MESH_PATH)
verts = mesh.points  # node coordinates, indexed by node number

# Node indices of each tetrahedron (one entry of 4 indices per tet).
tetra = mesh.cells_dict['tetra']
num_tetra = len(tetra)




# Create the edge list: every pair of a tetrahedron's 4 nodes is an edge, so
# each tet contributes 6 edges. Edges shared by neighbouring tets appear
# several times here; duplicates are removed in the next step.
# Iterating the tets directly covers all of them (an index range ending at
# num_tetra-1 would silently skip the last tetrahedron).
edges = [list(itertools.combinations(tet, 2)) for tet in tetra]

# Flatten to one long list of [node_a, node_b] pairs (lists, not tuples).
all_edges = [list(pair) for pair in itertools.chain.from_iterable(edges)]


# Remove duplicate edges. The same edge can appear with its nodes in either
# order, so every edge (including the last one) is first normalised to
# (lower node, higher node) and only then compared.
print(len(all_edges))  # edge count before removing duplicates

# dict.fromkeys drops repeats in O(1) per edge while keeping first-occurrence
# order; a linear `elem not in list` scan would be O(n^2) on large meshes.
unique_edges = dict.fromkeys(tuple(sorted(edge)) for edge in all_edges)
all_edges = [list(edge) for edge in unique_edges]
print(len(all_edges))  # number of unique edges


# Find edge lengths: Euclidean distance between the two end nodes of every
# unique edge, computed for all edges at once (no edge is skipped).
# reshape(-1, 2) keeps the (n_edges, 2) shape even when all_edges is empty.
edge_nodes = np.asarray(all_edges, dtype=int).reshape(-1, 2)
edge_vectors = verts[edge_nodes[:, 0]] - verts[edge_nodes[:, 1]]

# Kept as a Python list because the plotting step sorts it with list.sort().
edge_lengths = list(np.sqrt(np.sum(edge_vectors**2, axis=1)))


# Edge-length statistics: mean, standard deviation (np.std with its default
# ddof=0, i.e. population std), shortest edge and longest edge.
average_edge_length, sd_edge_length = np.average(edge_lengths), np.std(edge_lengths)
min_edge_length, max_edge_length = min(edge_lengths), max(edge_lengths)

# Print one value per line, in the order: mean, std, min, max.
for stat_value in (average_edge_length, sd_edge_length, min_edge_length, max_edge_length):
    print(stat_value)


# Plot 1: the mesh nodes as a 3-D point cloud.
fig = plt.figure()
ax = fig.add_subplot(projection='3d')
ax.plot3D(verts[:, 0], verts[:, 1], verts[:, 2], 'b.')

# Plot 2: normalised histogram of the edge lengths, overlaid with the
# one-dimensional Gaussian (normal) pdf that has the same mean and standard
# deviation as the measured edge lengths.
fig2 = plt.figure()
edge_lengths.sort()  # ascending; note this reorders edge_lengths in place
y = stats.norm.pdf(edge_lengths, average_edge_length, sd_edge_length)
plt.plot(edge_lengths, y, 'r.')
plt.hist(edge_lengths, density=True)
plt.xlabel('edge length (mesh units)')
plt.ylabel('probability density')

# Needed when the file is run as a plain script (python file.py); otherwise
# the interpreter exits without ever displaying the figures.
plt.show()




# Possible extension: remove outliers by computing the median and standard deviation (sigma) of the edge lengths and clipping every value outside a chosen range around the median.
