from lxml import etree as et
from math import sin, cos
import numpy as np
import pygem as pg
import os
import re
from stl import mesh as stl
from string import whitespace
import subprocess
import sys
from tvtk.api import tvtk
from tvtk.api import colors

sys.path.append('/home/landisb/PycharmProjects/PythonScripts/Project')

import Utility_tvtk
import ReadSubjectXml		# this only works if sys.path append is right


# Limb segment to process; the 'Leg' / 'Arm' substring selects the data files
limb = 'Upper_Arm'

probeWidth = 22.5	# ultrasound probe width (mm)
# Optional extras: limb-length check output and click-to-highlight picking
lengthOutput = clickHighlight = False

def CreateNestedDictionary():
	"""Return a fresh place -> orient -> layer dictionary of zero 3-vectors.

	Places are Proximal/Central/Distal, orients Anterior/Posterior/Medial/Lateral,
	layers skin/fat/muscle/bone.  Every leaf is its own np.zeros(3), so writing
	to one entry never affects another.  It is meant to be walked as
		for place, slice in nestedDict.items():
			for orient, strata in slice.items():
				for layer, point in strata.items():
	"""
	places = ('Proximal', 'Central', 'Distal')
	orients = ('Anterior', 'Posterior', 'Medial', 'Lateral')
	layers = ('skin', 'fat', 'muscle', 'bone')
	return {place: {orient: {layer: np.zeros(3) for layer in layers}
					for orient in orients}
			for place in places}
def dist(p1, p2):
	"""Return the Euclidean distance between p1 and p2 (first three components only)."""
	dx, dy, dz = [p1[axis] - p2[axis] for axis in (0, 1, 2)]
	return (dx ** 2 + dy ** 2 + dz ** 2) ** 0.5
def planeFit(points):
	"""Least-squares fit of a hyperplane to a point cloud.

	p, n = planeFit(points)

	points has shape (d, ...): axis 0 is the coordinate axis (d dimensions),
	every trailing axis is flattened into the list of points.  Returns the
	centroid p of the cloud (a point on the plane) and a unit normal n of
	the best-fit (d-1)-dimensional plane.  The sign of n is arbitrary
	(it is a singular vector).  Raises AssertionError when there are fewer
	points than dimensions.
	"""
	from numpy.linalg import svd
	cloud = np.reshape(points, (np.shape(points)[0], -1))	# collapse trailing dimensions
	nDims, nPoints = cloud.shape
	assert nDims <= nPoints, "There are only {} points in {} dimensions.".format(nPoints, nDims)
	centroid = cloud.mean(axis=1)
	centered = cloud - centroid[:, np.newaxis]
	scatter = centered.dot(centered.T)	# d x d scatter matrix (unnormalized covariance)
	leftVectors = svd(scatter)[0]
	# the direction of least spread is the singular vector with the smallest singular value
	return centroid, leftVectors[:, -1]
def _ReadAverageThickness(thicknessPath):
	"""Return {tissue tag: mean thickness} read from one ultrasound thickness xml.

	Walks root/Subject/Source; for every child (frame) of Source, each child
	of its <Thickness> element contributes float(text) under its tag.  The
	readings for each tag are averaged over all frames.
	"""
	source = et.parse(thicknessPath).getroot().find("Subject").find("Source")
	readings = {}
	for frame in source:
		for value in frame.find('Thickness'):
			readings.setdefault(value.tag, []).append(float(value.text))
	return {tag: sum(values) / len(values) for tag, values in readings.items()}


def _RollPitchYawMatrix(roll, pitch, yaw):
	"""Return the 3x3 rotation matrix Rz(yaw) . Ry(pitch) . Rx(roll) (angles in radians)."""
	r, p, y = roll, pitch, yaw
	return np.array([[ cos(p)*cos(y), sin(r)*sin(p)*cos(y)-cos(r)*sin(y), cos(r)*sin(p)*cos(y)+sin(r)*sin(y)],
					 [ cos(p)*sin(y), sin(r)*sin(p)*sin(y)+cos(r)*cos(y), cos(r)*sin(p)*sin(y)-sin(r)*cos(y)],
					 [-sin(p),        sin(r)*cos(p),                      cos(r)*cos(p)                     ]])


def ReadAndApplyUltrasoundThickness(limbSegment, file, probePoints, probeDirections, reverse=False):
	"""Read averaged ultrasound tissue thicknesses and apply them to probePoints in place.

	limbSegment     -- segment name such as 'Upper_Arm'; underscores are removed
	                   to find the matching child of the xml root.
	file            -- index xml.  Each child of the segment element is named
	                   <Place><Orient> (e.g. ProximalAnterior).  Its 'Anatomical'
	                   attribute names a thickness xml, relative to the index
	                   file's directory.
	probePoints     -- nested dict [place][orient][layer] -> 3-vector; updated in place.
	probeDirections -- nested dict [place][orient] -> (roll, pitch, yaw) in radians.
	reverse         -- False: start at 'skin' and step along the probe's +Z axis
	                   through Skin, Fat, Muscle thickness to place fat/muscle/bone.
	                   True: start at 'bone' and step along -Z through Muscle, Fat,
	                   Skin thickness to place muscle/fat/skin.

	Raises ValueError if the segment element is missing, and KeyError if a
	thickness file lacks a 'Skin', 'Fat' or 'Muscle' reading.  Returns None.
	"""
	directory = os.path.dirname(file)
	segmentName = limbSegment.replace('_', '')
	bodyPart = et.parse(file).getroot().find(segmentName)
	if bodyPart is None:
		raise ValueError('No <{}> element found in {}'.format(segmentName, file))
	for loc in bodyPart:
		place, orient = re.findall('[A-Z][a-z]*', loc.tag)  # Split tag on capitalization
		thickness = _ReadAverageThickness(os.path.join(directory, loc.get('Anatomical')))
		direction = probeDirections[place][orient]
		rotation = _RollPitchYawMatrix(direction[0], direction[1], direction[2])
		points = probePoints[place][orient]
		if reverse:	# uses thickness to move outward from bone
			dirVector = np.dot(rotation, [0, 0, -1])	# opposite to the probe's +Z direction
			origin = points['bone']
			steps = (('muscle', 'Muscle'), ('fat', 'Fat'), ('skin', 'Skin'))
		else:	# Skin towards bone
			dirVector = np.dot(rotation, [0, 0, 1])	# ultrasound probe points in +Z dir
			origin = points['skin']
			steps = (('fat', 'Skin'), ('muscle', 'Fat'), ('bone', 'Muscle'))
		# each layer boundary sits at the cumulative thickness measured from the origin
		depth = 0.0
		for layer, tissue in steps:
			depth += thickness[tissue]
			points[layer] = origin + dirVector * depth

def NEW_ReadAndApplyUltrasoundThickness(limbSegment, file, probePoints, probeDirections, reverse=False):
	"""Read ultrasound thicknesses from `file` and apply them to probePoints in place.

	Same contract as ReadAndApplyUltrasoundThickness.  This function delegates
	to it so the two implementations cannot silently drift apart.  Give this
	function its own body only when its behaviour is meant to differ.
	Returns None.
	"""
	return ReadAndApplyUltrasoundThickness(limbSegment, file, probePoints, probeDirections, reverse=reverse)

def ReadInVitroPositions(limbSegment, file):
	"""Return (basePoints, baseDirections) parsed from the in vitro positions xml.

	Each child of the limb-segment element (segment name without underscores)
	has a tag such as ProximalAnterior.  The tag is split on capital letters
	into place and orient.  Its Anatomical/USPosition element then supplies:
	  x, y, z          -> basePoints[place][orient]['skin'] as a list, each value
	                      multiplied by 1000 (presumably metres -> mm; confirm)
	  roll, pitch, yaw -> baseDirections[place][orient] as a list
	All other entries keep their CreateNestedDictionary defaults.
	"""
	def readValue(position, name):
		return float(position.find(name).get('value'))

	basePoints, baseDirections = CreateNestedDictionary(), CreateNestedDictionary()
	segment = et.parse(file).getroot().find(limbSegment.replace('_', ''))
	for loc in list(segment):
		place, orient = re.findall('[A-Z][a-z]*', loc.tag)	# Split on capitalization
		position = loc.find("Anatomical").find("USPosition")
		basePoints[place][orient]['skin'] = [readValue(position, axis) * 1000 for axis in ('x', 'y', 'z')]
		baseDirections[place][orient] = [readValue(position, angle) for angle in ('roll', 'pitch', 'yaw')]
	return basePoints, baseDirections
def WritePygemParameterFile(pygemParamFile, controlPoints, adjustedPoints):
	"""Write a pygem radial-basis-function parameter file.

	pygemParamFile -- output path (overwritten).
	controlPoints  -- array-like of original control points, one point per row
	                  (presumably (n, 3) for pygem).
	adjustedPoints -- array-like of the matching deformed control points.

	Each point goes on its own line, with its coordinates separated by tabs.
	Every point after the first is indented by a tab, so an INI-style
	(ConfigParser) reader treats the block as one multi-line value.
	Coordinates are written with repr(float): no precision is lost and the
	output does not depend on numpy's print options.
	"""
	rbfTemplate = """[Radial Basis Functions]
basis function: thin_plate_spline
radius: 10
[Control points]
original control points:{}
deformed control points:{}"""

	def formatPoints(points):
		# The leading tab follows the option name; '\n\t' starts each continuation line.
		rows = np.atleast_2d(np.asarray(points, dtype=float))
		return '\t' + '\n\t'.join('\t'.join(repr(float(c)) for c in row) for row in rows)

	text = rbfTemplate.format(formatPoints(controlPoints), formatPoints(adjustedPoints))
	with open(pygemParamFile, 'w') as pygemParams:
		pygemParams.write(text)

# Locations of all the files that will be used.
# Each segmentation directory holds one stl per tissue layer, named <layer>.stl.
dataDir = '/home/landisb/PycharmProjects/Data/Morphing/'
if 'Leg' in limb:
	segmentationDir = dataDir + 'CMULTIS002-2/Segmentation/'
	meshes = {layer: segmentationDir + layer + '.stl' for layer in ('bone', 'fat', 'muscle', 'skin')}
	#
	inVitroPositionsFile = dataDir + 'FALSE_CADAVER_DATA/FALSE_CADAVER_UL_US_CT.xml'
	# inVitroPositionsFile = dataDir + 'CMULTIS002-2/CMULTIS002-2_UL_US_CT.xml'
	#
	cadaverFile = dataDir + 'CMULTIS002-2/Configuration/CMULTIS002-2.xml'
	subjectFile = dataDir + 'MULTIS008-1/Configuration/MULTIS008-1.xml'
	#
	cadaverUltrasound = dataDir + 'CMULTIS002-2/CMULTIS002-2_TA_inclusion.xml'
	# cadaverUltrasound = dataDir + 'FALSE_CADAVER_DATA/FALSE_TA_inclusion.xml'
	subjectUltrasound = dataDir + 'MULTIS008-1/MULTIS008-1_TA_inclusion.xml'
elif 'Arm' in limb:
	segmentationDir = dataDir + 'CMULTIS002-3/Segmentation/'
	meshes = {layer: segmentationDir + layer + '.stl' for layer in ('bone', 'fat', 'muscle', 'skin')}
	#
	inVitroPositionsFile = dataDir + 'CMULTIS002-3/CMULTIS002-3_UA_US_CT.xml'
	#
	cadaverFile = dataDir + 'CMULTIS002-3/Configuration/CMULTIS002-3.xml'
	subjectFile = dataDir + 'MULTIS008-1/Configuration/MULTIS008-1.xml'
	#
	cadaverUltrasound = dataDir + 'CMULTIS002-3/TissueThickness/UltrasoundManual/CMULTIS002-3_TA_inclusion.xml'
	subjectUltrasound = dataDir + 'MULTIS008-1/MULTIS008-1_TA_inclusion.xml'
else:
	# Everything below needs these paths.  Stop here with a clear message
	# instead of printing and then failing later with a NameError.
	raise ValueError('limb value {} is unexpected.'.format(limb))

# First get probe positions from inVitro, ie cadaver test, model is based off this
probePoints, probeDirections = ReadInVitroPositions(limb, inVitroPositionsFile)

# Read ultrasound thicknesses and apply them to the probePositions
NEW_ReadAndApplyUltrasoundThickness(limb, cadaverUltrasound, probePoints, probeDirections)

# Read subject file xml to find experimentally measured sizes
cadaverMeasurements = ReadSubjectXml.GetNewData(cadaverFile)
subjectMeasurements = ReadSubjectXml.GetData(subjectFile)


# Find limb normal.
# planeFit returns each slice normal with an arbitrary sign (it is an SVD
# singular vector).  Averaging unaligned normals can cancel them out, so
# every normal is first flipped to point from the Proximal slice toward the
# Distal slice.  Later code treats +norm / +normals['Distal'] as the distal
# direction.  The mean is re-normalized so that norm*diff moves a point by
# exactly diff.
sliceCenters = {}
normals = {}
for place, section in probePoints.items():
	pointArray = np.array([point for strata in section.values() for point in strata.values()])
	center, normal = planeFit(pointArray.T)				# First calculate the origin and normal
	sliceCenters[place] = center
	normals[place] = normal
limbAxis = sliceCenters['Distal'] - sliceCenters['Proximal']
normals = {place: (n if np.dot(n, limbAxis) >= 0 else -n) for place, n in normals.items()}
norm = sum(normals.values()) / len(normals)
norm = norm / np.linalg.norm(norm)

# Assign projected probe BONE positions.
# Each cadaver bone point is shifted along the averaged limb normal `norm`.
# The shift is the subject's change in landmark-to-slice distance, measured
# relative to the Central slice, compared with the cadaver's.
projectedPoints = CreateNestedDictionary()	# probe points in morphed positions
central = 'landmarktoCentralCircumference'
for place, section in probePoints.items():
	landmark = 'landmarkto{}Circumference'.format(place)
	morphedLength  = subjectMeasurements[limb][landmark] - subjectMeasurements[limb][central]
	originalLength = cadaverMeasurements[limb][landmark] - cadaverMeasurements[limb][central]
	diff = (morphedLength - originalLength) * 10		# to put it in mm not cm
	for orient, strata in section.items():
		projectedPoints[place][orient]['bone'] = strata['bone'] + norm * diff

# Read subject ultrasound thicknesses and apply them outward from the bone to projectedPositions
ReadAndApplyUltrasoundThickness(limb, subjectUltrasound, projectedPoints, probeDirections, reverse=True)

# Turn dictionaries of points in numpy arrays for use with pygem.
# Both arrays are built from ONE explicit key list, so row i of controlArray
# and row i of morphedArray always refer to the same place/orient/layer.
# The two arrays are written out as matching original/deformed control points.
controlKeys = [(pl, ori, lay)
	for pl, section in probePoints.items()
	for ori, strata in section.items()
	for lay in strata]
controlArray = np.array([probePoints[pl][ori][lay] for pl, ori, lay in controlKeys])
morphedArray = np.array([projectedPoints[pl][ori][lay] for pl, ori, lay in controlKeys])

# # Morph meshes with subprocesses to speed things up.
# print 'Morphing...',
# WritePygemParameterFile('morphing.dat', controlArray,morphedArray)
# processes = [subprocess.Popen(['python', 'RunPygem.py', mesh, 'morphing.dat']) for layer, mesh in meshes.items()]
# print 'this takes a while...',
# for proc in processes:	proc.wait()
# print 'finished.'

# Draw it all to the screen
# Colour lookups: light colours per slice for the in vitro (cadaver) data,
# dark colours per slice for in vivo data, and one colour per tissue layer.
planeColors = dict(Proximal=colors.pink, Central=colors.green_pale, Distal=colors.cyan)
inVivo_Colors = dict(Proximal=colors.red, Central=colors.green_dark, Distal=colors.blue)
legColors = dict(skin=colors.sandy_brown, fat=colors.lemon_chiffon, muscle=colors.salmon, bone=colors.blanched_almond)
# Actor containers, filled below and handed to the renderer at the end
stlActors, pointActors, cutActors, planeActors, arrowActors = [], [], [], [], []


# Create graphics stuff: a white renderer in an 800x800 window
ren = tvtk.Renderer(background=colors.white)
renWin = tvtk.RenderWindow(size=(800,800))
renWin.add_renderer(ren)
# The interactor uses the Utility_tvtk click-highlight style only when requested
if not clickHighlight:
	iren = tvtk.RenderWindowInteractor(render_window=renWin)
else:
	style = Utility_tvtk.MouseHighLightActor(ren)
	iren = tvtk.RenderWindowInteractor(render_window=renWin, interactor_style=style.style)

# Get actors for stls.
# Opacity per tissue layer: only the bone is visible (0.8); the other
# surfaces are drawn fully transparent.
stlOpacity = {'fat': 0.0, 'skin': 0.0, 'bone': 0.8, 'muscle': 0.0}
for layer, mesh in meshes.items():
	stlActors.append(Utility_tvtk.MakeStl(mesh, legColors[layer], opacity=stlOpacity[layer]))
	# stlActors.append(  Utility_tvtk.MakeStl(mesh, colors.cyan, opacity=stlOpacity[layer])  )
	# newMesh = mesh.replace('.', '_morphed.')
	# stlActors.append(  Utility_tvtk.MakeStl(newMesh, colors.sandy_brown, opacity=stlOpacity[layer])  )


# Get actors for probePoints
print('In Vitro (light colors):')
for place, section in probePoints.items():
	# Posterior points are excluded from the drawn points and from the fitted cut plane
	pointArray = np.array([point for orient, strata in section.items() if orient != 'Posterior'
		for point in strata.values()])
	pointActors.append(Utility_tvtk.MakePoints(pointArray, planeColors[place], size=5))
	#
	# One arrow per probe, from its skin point to its bone point
	for orient, strata in section.items():
		arrowActors.append(Utility_tvtk.MakeArrow(strata['skin'], strata['bone'], planeColors[place], opacity=0.5))
	#
	center, normal = planeFit(pointArray.T)					# First calculate the origin and normal
	# size = 5 * dist(center, pointArray[0])
	# planeActors.append(Utility_tvtk.DrawPlane(center, normal, size, planeColors[place]))
	# NOTE: offsetting the cut plane by one probe width along the normal was
	# tried (Leg: Proximal -probeWidth, Distal +probeWidth; Arm: Proximal
	# +probeWidth, Distal -probeWidth; Central unchanged), but it is disabled.
	# The plane passes through the fitted centroid.
	centerAdjustment = center
	plane = tvtk.Plane(origin=centerAdjustment, normal=normal)  # Create the plane as a function, that will be used to cut the stl.
	#
	cut, circum = Utility_tvtk.MakeStlCut(meshes['skin'], plane, planeColors[place])
	cutActors.append(cut)
	# showCut, dontcare = Utility_tvtk.MakeStlCut(meshes['fat'], plane, colors.lemon_chiffon)
	# cutActors.append(showCut)
	# showCut, dontcare = Utility_tvtk.MakeStlCut(meshes['muscle'], plane, colors.pink)
	# cutActors.append(showCut)
	showCut, dontcare = Utility_tvtk.MakeStlCut(meshes['bone'], plane, colors.yellow)
	cutActors.append(showCut)
	measured = cadaverMeasurements[limb][place+'Circumference']*10	# cm -> mm
	percent = (measured-circum)/measured*100
	print('	{:8s} circumference (mm)	measured {:3.0f} calculated {:6.3f}	error {: 4.2f}.'.format(place, measured, circum, percent))
	#
	# Find hip and knee mesh locations (done once, on the Proximal pass)
	if lengthOutput and place == 'Proximal':
		landmark = 'landmarktoCentralCircumference'
		centralSkin = probePoints['Central']['Lateral']['skin']
		# hip = centralSkin - normals['Proximal']*cadaverMeasurements[limb][landmark]*10
		knee = centralSkin + normals['Distal']*(cadaverMeasurements[limb]['Length']-cadaverMeasurements[limb][landmark])*10
		hip = centralSkin - norm*cadaverMeasurements[limb][landmark]*10
		# knee = centralSkin + norm*(cadaverMeasurements[limb]['Length']-cadaverMeasurements[limb][landmark])*10
		#
		# Snap the estimated hip and knee to the nearest triangle vertex of the skin stl
		stlPoints = stl.Mesh.from_file(meshes['skin']).vectors.reshape(-1, 3)
		hipIndex = np.argmin(np.linalg.norm(stlPoints - hip, axis=1))
		kneeIndex = np.argmin(np.linalg.norm(stlPoints - knee, axis=1))
		#
		length = dist(stlPoints[hipIndex], stlPoints[kneeIndex])
		measured = cadaverMeasurements[limb]['Length']*10
		percent = (measured-length)/measured*100
		print('	Length (mm)	 				measured {:3.0f} calculated {:6.3f}	error {: 4.2f}'.format(measured, length, percent))
		#
		# Highlight hip and knee points
		pointActors.append(Utility_tvtk.MakePoints(stlPoints[hipIndex].reshape(1, 3), planeColors['Proximal'], size=10))
		pointActors.append(Utility_tvtk.MakePoints(stlPoints[kneeIndex].reshape(1, 3), planeColors['Distal'], size=10))

# # Get actors for projectedPoints
# print 'In Vivo (dark colors):'
# for place, slice in projectedPoints.iteritems():
# 	pointArray = np.array([point for strata in slice.values() for point in strata.values()])
# 	pointActor = Utility_tvtk.MakePoints(pointArray, inVivo_Colors[place], size=5)
# 	pointActors.append(pointActor)
# 	#
# 	center, normal = planeFit(pointArray.T)
# 	if 'Leg' in limb:
# 		if place is 'Proximal':		centerAdjustment = center - probeWidth * normal
# 		elif place is 'Central':	centerAdjustment = center  # no adjustment
# 		elif place is 'Distal':		centerAdjustment = center + probeWidth * normal
# 	elif 'Arm' in limb:
# 		if   place is 'Proximal':	centerAdjustment = center+probeWidth*normal
# 		elif place is 'Central':	centerAdjustment = center		# no adjustment
# 		elif place is 'Distal':		centerAdjustment = center-probeWidth*normal
# 	else:	centerAdjustment = center
# 	centerAdjustment = center
# 	plane = tvtk.Plane(origin=centerAdjustment,normal=normal)  # Create the plane as a function, that will be used to cut the stl.
# 	#
# 	newMesh = meshes['skin'].replace('.', '_morphed.')
# 	cut,  circum = Utility_tvtk.MakeStlCut(newMesh, plane, inVivo_Colors[place])
# 	cutActors.append(cut)
# 	measured = subjectMeasurements[limb][place+'Circumference']*10
# 	percent = (measured-circum)/measured*100
# 	print '	{:8s} circumference (mm)	measured {:3.0f} calculated {:6.3f}	error {: 4.2f}.'.format(place, measured, circum, percent)
# 	#
# 	# Find morphed hip and knee mesh locations
# 	if lengthOutput and place is 'Proximal':		# proximal so it prints after all 3 circumferences
# 		newMesh = meshes['skin'].replace('.', '_morphed.')
# 		stlMesh = stl.Mesh.from_file(newMesh)  # read stl file
# 		newPoints = stlMesh.vectors.reshape(-1,3)
# 		#
# 		length = dist(newPoints[hipIndex], newPoints[kneeIndex])
# 		measured = subjectMeasurements[limb]['Length']*10
# 		percent = (measured-length)/measured*100
# 		print '	Length (mm)	 				measured {:3.0f} calculated {:6.3f}	error {: 4.2f}'.format(measured, length, percent)
# 		#
# 		# Highlight morphed hip and knee
# 		pointActor = Utility_tvtk.MakePoints(newPoints[hipIndex].reshape(1,3), inVivo_Colors['Proximal'], size=10)
# 		pointActors.append(pointActor)
# 		pointActor = Utility_tvtk.MakePoints(newPoints[kneeIndex].reshape(1,3), inVivo_Colors['Distal'], size=10)
# 		pointActors.append(pointActor)

# Add all the actors that have been made.
# A single neutral loop variable is used so the imported `stl` module (and
# other module-level names such as `plane`/`cut`) are not rebound.
for actor in stlActors + pointActors + planeActors + arrowActors + cutActors:
	ren.add_actor(actor)

# Reset camera and viewing area
ren.reset_camera()
cam = ren.active_camera
cam.azimuth(30)
cam.elevation(30)
cam.dolly(1.5)
ren.reset_camera_clipping_range()

iren.initialize()
renWin.render()
# Title the window with this script's file name minus its extension.
# splitext only strips the final extension, so names like 'x.pyc' stay intact.
renWin.window_name = os.path.splitext(os.path.basename(__file__))[0]
iren.start()

#
# #
# # #	print TVTK_Object.class_editable_traits()
# #
#
