import sys
import numpy as np
from lxml import etree as et
from mpl_toolkits.mplot3d import axes3d
import matplotlib.pyplot as plt
import math
import copy
import os
import matplotlib.font_manager as font_manager

class rigid_body:
	"""Rigid body material read from the MATERIAL DATA section of a FEBio log file."""

	def __init__(self, mat_name, mat_id, center_of_mass):
		# the material id is always stored as an int, whatever type was passed in
		self.id = int(mat_id)
		self.name = mat_name
		# initial center of mass, as listed in the log file
		self.center_of_mass = center_of_mass
class Cylindrical_Joint:
	"""A 'rigid cylindrical joint' constraint read from the FEBio input file."""

	def __init__(self, name, axis, origin, body_a, body_b, joint_number):
		self.name = name
		# ids of the two rigid bodies connected by the joint
		self.body_a, self.body_b = body_a, body_b
		# joint axis direction and joint origin, as listed in the input file
		self.axis, self.origin = axis, origin
		# constraint number, used to index the rigid connector data parsed from the log
		self.joint_number = joint_number


def find_rigid_bodies(log_filename):
	"""Collect the rigid bodies listed in the MATERIAL DATA section of a FEBio log file.

	The material data section is used instead of the rigid body section because
	only the material data section includes the material names.

	Parameters
	----------
	log_filename : str
		Path to the FEBio log file.

	Returns
	-------
	dict
		{material_name: rigid_body} for every material whose header line contains
		'rigid body'.  Empty if the log has no MATERIAL DATA section.

	Raises
	------
	ValueError
		If a rigid body has no 'center_of_mass' line before the end of the file.
	"""
	rigid_bodies = {}
	with open(log_filename) as f:
		for line in f:
			if 'MATERIAL DATA' not in line:
				continue
			next(f, '')  # skip the separator line under the section title
			current = next(f, '')
			# the section ends at the first blank line (or at the end of the file)
			while current.strip():
				if 'rigid body' in current:
					# header line: "<id> - <name> (...)"; split on the first '-' only
					# so that material names containing '-' are kept whole
					id_text, name_text = current.split('-', 1)
					mat_id = int(id_text)
					mat_name = name_text.split('(')[0].strip()
					# scan forward to this body's center of mass line
					current = next(f, '')
					while current and 'center_of_mass' not in current:
						current = next(f, '')
					if not current:
						raise ValueError('no center_of_mass found for rigid body {!r}'.format(mat_name))
					com = [float(v) for v in current.split(':')[-1].split(',')]
					rigid_bodies[mat_name] = rigid_body(mat_name, mat_id, com)
				current = next(f, '')
			break
	return rigid_bodies

def find_input_filename(log_filename):
	"""Return the path of the FEBio input file named in a log file's FILES USED section.

	Only the base name recorded in the log is kept; it is resolved relative to the
	directory containing the log file, so the input file must sit next to the log.

	Parameters
	----------
	log_filename : str
		Path to the FEBio log file.

	Returns
	-------
	str
		``os.path.join(dirname(log_filename), basename(recorded input file))``.

	Raises
	------
	ValueError
		If the log file has no FILES USED section.
	"""
	recorded_name = None
	with open(log_filename) as f:
		for line in f:
			if 'FILES USED' in line:
				next(f, '')  # skip the separator line
				entry = next(f, '')
				# split on the first ':' only, so a ':' inside the path itself is kept
				recorded_name = entry.split(':', 1)[-1].strip()
				break
	if recorded_name is None:
		raise ValueError("no 'FILES USED' section found in {}".format(log_filename))
	return os.path.join(os.path.dirname(log_filename), os.path.basename(recorded_name))

def collect_data(log_filename):
	"""Parse every 'Data Record' section of a FEBio log file.

	Each record is: a "Data Record #<k>" line, a separator line, a step line,
	a "Time = <t>" line, a "Data = <name>" line, then rows of
	"<body_id> <value> <value> ..." until a blank line (or the end of the file).

	Parameters
	----------
	log_filename : str
		Path to the FEBio log file.

	Returns
	-------
	all_data : dict
		Nested dictionaries {data_name: {body_id: data}}.  ``data`` holds one row
		per record; it is a numpy array, or stays a list of rows if the rows have
		different lengths.
	time_steps : numpy.ndarray
		The time of each solved step, taken from the record numbered 1 of each step.
	"""
	time_steps = []
	all_data = {}
	with open(log_filename) as f:
		for line in f:
			if 'Data Record' not in line:
				continue
			data_record_num = int(line.split('#')[-1])
			next(f, '')  # ===== line
			next(f, '')  # step number line
			time = float(next(f, '').split('=')[-1])
			# record #1 starts a new solved time step
			if data_record_num == 1:
				time_steps.append(time)
			data_name = next(f, '').split('=')[-1].strip()
			record = all_data.setdefault(data_name, {})
			row = next(f, '')
			# keep going until the blank line (or end of file) that closes this record
			while row.strip():
				# split() tolerates repeated and trailing whitespace
				fields = row.split()
				try:
					body_id = int(fields[0])
				except ValueError:
					break  # a line of text rather than data
				record.setdefault(body_id, []).append([float(x) for x in fields[1:]])
				row = next(f, '')
	# turn the per-body lists of rows into arrays
	for data_dict in all_data.values():
		for body_id, rows in data_dict.items():
			try:
				data_dict[body_id] = np.asarray(rows)
			except ValueError:
				pass  # ragged rows cannot form a regular array; keep the list
	return all_data, np.asarray(time_steps)

def add_initial_values(all_data, time_steps, rigid_bodies):
	"""Prepend the t = 0 state to the time steps and to every data set.

	Initial values used:
	* 'center_of_mass': each rigid body's initial ``center_of_mass``
	* 'rotation_quaternion': the identity quaternion [0, 0, 0, 1]
	* every other data set (forces and moments): [0.0, 0.0, 0.0]

	The arrays inside ``all_data`` are replaced in place; ``all_data`` is returned
	together with the extended time steps.
	"""
	extended_times = np.insert(time_steps, 0, [0])

	centers = all_data['center_of_mass']
	for body in rigid_bodies.values():
		centers[body.id] = np.insert(centers[body.id], 0, body.center_of_mass, axis=0)

	quaternions = all_data['rotation_quaternion']
	for body_id in quaternions:
		quaternions[body_id] = np.insert(quaternions[body_id], 0, [0, 0, 0, 1], axis=0)

	handled = ('rotation_quaternion', 'center_of_mass')
	for data_name, per_body in all_data.items():
		if data_name in handled:
			continue
		for body_id in per_body:
			per_body[body_id] = np.insert(per_body[body_id], 0, [0.0, 0.0, 0.0], axis=0)

	return all_data, extended_times

def fixed_axis(rb_rotation_quaternion, joint_info):
	"""Direction of a joint axis that is fixed to a rigid body, at every time step.

	The joint's initial axis (``joint_info.axis``) is rotated by the rigid body's
	rotation matrix at each step, and each resulting row is normalized.

	Returns
	-------
	ndarray, shape (n_steps, 3)
	"""
	start_axis = np.array(joint_info.axis)  # np.array copies, so joint_info is never modified
	rotations = rotation_matrix_from_quaternion(rb_rotation_quaternion)
	rotated = np.matmul(rotations, start_axis)
	# normalize, just in case
	lengths = np.linalg.norm(rotated, axis=1)
	return rotated / lengths[:, np.newaxis]

def get_joint_axes(rigid_bodies, rot_quat_data, constraint_info):
	"""Return the direction of every joint axis at each time step.

	Axes fixed to a bone are the initial axis rotated with that bone:
	* tibiofemoral "Extension" axis -> femur ('FMB')
	* tibiofemoral "Internal" axis -> tibia ('TBB')
	* "Patellar" "Flexion" axis -> femur ('FMB')
	* "Patellar" "Tilt" axis -> patella ('PTB')
	The floating axes (tibiofemoral "Abduction" and "Patellar" "Rotation") are the
	cross product of the two fixed axes of the same joint, computed only when all
	three constraints are present.  Names are matched by substring so both right-
	and left-knee naming conventions work.

	Returns
	-------
	dict
		{constraint_name: (n_steps, 3) array of axis directions}
	"""
	fixed_joints = {}  # {constraint name: rigid body the axis is fixed to}
	pat_flex_name = pat_tilt_name = pat_rot_name = None
	flex_ext_name = ext_int_name = abd_add_name = None
	for name in constraint_info:
		if "Patellar" in name:		# patellofemoral axes
			if "Flexion" in name:
				fixed_joints[name] = 'FMB'
				pat_flex_name = name
			elif "Tilt" in name:
				fixed_joints[name] = 'PTB'
				pat_tilt_name = name
			elif "Rotation" in name:
				pat_rot_name = name
		else:						# tibiofemoral axes
			if "Extension" in name:
				fixed_joints[name] = 'FMB'
				flex_ext_name = name
			elif "Internal" in name:
				fixed_joints[name] = 'TBB'
				ext_int_name = name
			elif "Abduction" in name:
				abd_add_name = name

	joint_axes = {}  # {axis name: (n_steps, 3) axis direction at every time point}
	for ax_name, rb_name in fixed_joints.items():
		rot_quat = rot_quat_data[rigid_bodies[rb_name].id]
		joint_axes[ax_name] = fixed_axis(rot_quat, constraint_info[ax_name])

	# floating axes need both fixed axes of their joint
	if None not in (abd_add_name, ext_int_name, flex_ext_name):
		joint_axes[abd_add_name] = np.cross(joint_axes[ext_int_name], joint_axes[flex_ext_name])
	if None not in (pat_rot_name, pat_tilt_name, pat_flex_name):
		joint_axes[pat_rot_name] = np.cross(joint_axes[pat_tilt_name], joint_axes[pat_flex_name])
	return joint_axes

def kinematics_from_transform(T):
	"""Extract the rotations and translations along the joint axes from 4x4 transforms.

	Follows page 6 of:
	https://simtk.org/plugins/moinmoin/openknee/Infrastructure/ExperimentationMechanics?action=AttachFile&do=view&target=Knee+Coordinate+Systems.pdf
	The angle formulas correspond to a rotation block R = Rx(alpha) Ry(beta) Rz(gamma).

	Parameters
	----------
	T : array_like, shape (n, 4, 4)
		Stacked homogeneous transformation matrices.

	Returns
	-------
	a, b, c : ndarray, shape (n,)
		Translations along the joint axes.
	alpha, beta, gamma : ndarray, shape (n,)
		Rotations in radians.
	"""
	T = np.asarray(T, dtype=float)
	# round-off can push |T[0, 2]| slightly above 1, where arcsin would return NaN
	beta = np.arcsin(np.clip(T[:, 0, 2], -1.0, 1.0))
	alpha = np.arctan2(-T[:, 1, 2], T[:, 2, 2])
	gamma = np.arctan2(-T[:, 0, 1], T[:, 0, 0])
	ca, sa = np.cos(alpha), np.sin(alpha)
	cb, sb = np.cos(beta), np.sin(beta)
	b = T[:, 1, 3] * ca + T[:, 2, 3] * sa
	c = (T[:, 2, 3] * ca - T[:, 1, 3] * sa) / cb
	a = T[:, 0, 3] - c * sb
	return a, b, c, alpha, beta, gamma

def joint_kinematics_2(bone_axes, rigid_bodies, rotation_quaternion_data, com_data, constraint_info):
	"""Joint kinematics extracted from the tibia-in-femur and patella-in-femur transforms.

	The bone transforms come from BoneinWorld_Transform and are decomposed with
	kinematics_from_transform.  Every value is the change since the first time step.

	Returns
	-------
	dict
		{constraint_name: (translation, rotation_in_degrees)}.  Constraints whose
		names match none of the handled axes are left out, and so are patellar
		constraints when the patella ('PTB') is missing from the model data.
	"""
	def relative_to_start(translation, angle):
		# change since the first time step; the angle is converted to degrees
		return (translation - translation[0], np.degrees(angle - angle[0]))

	world_in_fem = np.linalg.inv(
		BoneinWorld_Transform(bone_axes, rotation_quaternion_data, rigid_bodies, 'FMB', com_data))

	def in_femur(bone_name):
		# transform of the given bone expressed in the femur coordinate system
		bone_in_world = BoneinWorld_Transform(bone_axes, rotation_quaternion_data, rigid_bodies, bone_name, com_data)
		return np.matmul(world_in_fem, bone_in_world)

	a, b, c, alpha, beta, gamma = kinematics_from_transform(in_femur('TBB'))
	try:
		patella = kinematics_from_transform(in_femur('PTB'))
	except KeyError:
		patella = None  # no 'PTB' entry in the bone axes / rigid bodies / rotation data

	all_kinematics = {}
	for joint_info in constraint_info.values():
		name = joint_info.name
		if 'Patellar' in name:
			if patella is None:
				continue
			a_pat, b_pat, c_pat, alpha_pat, beta_pat, gamma_pat = patella
			# NOTE(review): get_joint_axes identifies the patellar flexion axis by
			# "Flexion", while this branch matches "Extension" - confirm the names
			if 'Extension' in name:
				all_kinematics[name] = relative_to_start(a_pat, alpha_pat)
			elif 'Rotation' in name:
				all_kinematics[name] = relative_to_start(b_pat, beta_pat)
			elif 'Tilt' in name:
				all_kinematics[name] = relative_to_start(c_pat, gamma_pat)
		elif 'Extension' in name:
			all_kinematics[name] = relative_to_start(a, alpha)
		elif 'Abduction' in name:
			all_kinematics[name] = relative_to_start(b, beta)
		elif 'Internal' in name:
			all_kinematics[name] = relative_to_start(c, gamma)
	return all_kinematics

def joint_kinematics(center_of_mass_data, rotation_quaternion_data, constraint_info, joint_axes):
	"""Kinematics of the rigid cylindrical joints from center-of-mass and rotation data.

	Translation: displacement since the first time step of body b minus that of
	body a, both projected on the joint axis.
	Rotation: angle (degrees) of the rotation of body b relative to body a, signed
	by whether its Euler axis points along (+) or against (-) the joint axis.
	NaNs (zero rotation, where the Euler axis is undefined) become 0.

	Returns
	-------
	dict
		{constraint_name: (translation, rotation)}, keyed like the constraints.
	"""
	results = {}
	for joint in constraint_info.values():
		com_a = center_of_mass_data[joint.body_a]
		com_b = center_of_mass_data[joint.body_b]
		axis = joint_axes[joint.name]
		# relative translation of a and b along the joint axis
		translation = np.sum(axis * (com_b - com_b[0]), axis=1) - np.sum(axis * (com_a - com_a[0]), axis=1)

		# rotation of body b with respect to body a: R_ab = R_wa' * R_wb
		rot_a = rotation_matrix_from_quaternion(rotation_quaternion_data[joint.body_a])
		rot_b = rotation_matrix_from_quaternion(rotation_quaternion_data[joint.body_b])
		relative = np.matmul(rot_a.transpose(0, 2, 1), rot_b)
		angle, euler_axis = euler_axis_angle_from_rotation_matrix(relative)

		# cosine between the Euler axis and the joint axis, rounded to give the sign
		sign = np.around(np.sum(axis * euler_axis, axis=1) / np.linalg.norm(euler_axis, axis=1))
		results[joint.name] = (translation, np.nan_to_num(np.degrees(angle) * sign))
	return results

def euler_axis_angle_from_quaternion(q):
	"""Return (angle, axis) of the rotation(s) described by the quaternion rows ``q``."""
	return euler_axis_angle_from_rotation_matrix(rotation_matrix_from_quaternion(q))

def euler_axis_angle_from_rotation_matrix(R):
	"""Euler axis and angle of stacked rotation matrices.

	Parameters
	----------
	R : array_like, shape (n, 3, 3)

	Returns
	-------
	angle : ndarray, shape (n,)
		Rotation angle in radians, in [0, pi].
	axis : ndarray, shape (n, 3)
		Euler axis.  Only meaningful when the angle is not a multiple of pi: at a
		zero angle the components are 0/0 and come out as NaN.
	"""
	R = np.asarray(R, dtype=float)
	# round-off can push the cosine slightly outside [-1, 1], where arccos returns NaN
	cos_angle = np.clip((R[:, 0, 0] + R[:, 1, 1] + R[:, 2, 2] - 1) / 2, -1.0, 1.0)
	angle = np.arccos(cos_angle)
	two_sin = 2 * np.sin(angle)
	# the zero-angle NaNs are expected (callers convert them), so silence the warnings
	with np.errstate(divide='ignore', invalid='ignore'):
		axis = np.stack((
			(R[:, 2, 1] - R[:, 1, 2]) / two_sin,
			(R[:, 0, 2] - R[:, 2, 0]) / two_sin,
			(R[:, 1, 0] - R[:, 0, 1]) / two_sin,
		), axis=1)
	return angle, axis

def euler_angles_from_quaternion(q):
	"""Euler angles (radians) of quaternion rows ``q`` stored as [x, y, z, w].

	The formulas are the standard z-y-x (yaw-pitch-roll) decomposition.

	Returns
	-------
	theta1, theta2, theta3 : ndarray, shape (n,)
		Rotations about x, y and z respectively.
	"""
	q = np.asarray(q, dtype=float)
	qx, qy, qz, qw = q[:, 0], q[:, 1], q[:, 2], q[:, 3]
	theta1 = np.arctan2(2 * (qw * qx + qy * qz), 1 - 2 * (qx ** 2 + qy ** 2))
	# round-off can push the argument just past +/-1 (e.g. at 90 degrees), giving NaN
	theta2 = np.arcsin(np.clip(2 * (qw * qy - qz * qx), -1.0, 1.0))
	theta3 = np.arctan2(2 * (qw * qz + qx * qy), 1 - 2 * (qy ** 2 + qz ** 2))
	return theta1, theta2, theta3

def rotation_matrix_from_quaternion(q):
	"""Rotation matrices from quaternions stored as [x, y, z, w] (scalar last).

	Parameters
	----------
	q : array_like, shape (n, 4) or (4,)
		Quaternions; they do not need to be normalized.

	Returns
	-------
	ndarray, shape (n, 3, 3)
		One rotation matrix per quaternion (a single (4,) quaternion gives n = 1).
		An all-zero quaternion gives the identity matrix.
	"""
	q = np.atleast_2d(np.asarray(q, dtype=float))
	qi, qj, qk, qr = q[:, 0], q[:, 1], q[:, 2], q[:, 3]
	# s = 1 / |q|^2 makes the formula valid for non-unit quaternions
	norm_sq = qi ** 2 + qj ** 2 + qk ** 2 + qr ** 2
	with np.errstate(divide='ignore'):
		s = np.where(norm_sq > 0, 1.0 / norm_sq, 0.0)
	R = np.empty((len(q), 3, 3))
	R[:, 0, 0] = 1 - 2 * s * (qj ** 2 + qk ** 2)
	R[:, 0, 1] = 2 * s * (qi * qj - qk * qr)
	R[:, 0, 2] = 2 * s * (qi * qk + qj * qr)
	R[:, 1, 0] = 2 * s * (qi * qj + qk * qr)
	R[:, 1, 1] = 1 - 2 * s * (qi ** 2 + qk ** 2)
	R[:, 1, 2] = 2 * s * (qj * qk - qi * qr)
	R[:, 2, 0] = 2 * s * (qi * qk - qj * qr)
	R[:, 2, 1] = 2 * s * (qj * qk + qi * qr)
	R[:, 2, 2] = 1 - 2 * s * (qi ** 2 + qj ** 2)
	return R

def GetConstraintInfo(feb_filename):
	"""Extract the rigid cylindrical joints from <Step><Constraints> of a FEBio input file.

	Constraint elements that have a ``type`` attribute are numbered in document
	order, starting at 1, so that ``joint_number`` indexes the rigid connector
	data parsed from the log file.  Comments and elements without a ``type`` are
	skipped and not numbered.  A cylindrical joint that cannot be parsed is
	reported, skipped and not numbered.

	Parameters
	----------
	feb_filename : str
		Path to the FEBio input (.feb) file.

	Returns
	-------
	dict
		{constraint_name: Cylindrical_Joint}
	"""
	def parse_floats(text):
		# "x, y, z" -> [x, y, z]
		return [float(x) for x in text.split(',')]

	Febio_spec_root = et.parse(feb_filename).getroot()
	Constraint_Section = Febio_spec_root.find("Step").find("Constraints")
	counter = 0
	constraint_info = {}
	for constraint in Constraint_Section:
		if not isinstance(constraint.tag, str):
			continue  # comment or processing instruction
		constraint_type = constraint.get("type")
		if constraint_type is None:
			continue
		if constraint_type != "rigid cylindrical joint":
			counter += 1  # count the constraint number but ignore otherwise
			continue
		try:
			constraint_name = constraint.attrib["name"]
			axis = parse_floats(constraint.find("joint_axis").text)
			origin = parse_floats(constraint.find("joint_origin").text)
			body_a_id = int(constraint.find("body_a").text)
			body_b_id = int(constraint.find("body_b").text)
		except (KeyError, AttributeError, TypeError, ValueError) as err:
			# missing attribute/child element, empty text, or non-numeric text
			print('skipping unreadable rigid cylindrical joint: {!r}'.format(err))
			continue
		counter += 1
		constraint_info[constraint_name] = Cylindrical_Joint(
			constraint_name, axis, origin, body_a_id, body_b_id, counter)
	return constraint_info

def get_bone_axes(model_properties_xml):
	"""Read the bone coordinate system axes from the Landmarks of a model properties file.

	For bone letter 't' (tibia), 'f' (femur) and 'p' (patella) the landmarks
	X<letter>_axis, Y<letter>_axis and Z<letter>_axis hold comma-separated components.

	Parameters
	----------
	model_properties_xml : str
		Path to the model properties XML file.

	Returns
	-------
	dict
		{'TBB': ..., 'FMB': ..., 'PTB': ...}, each value [x_axis, y_axis, z_axis] as
		lists of floats.  A bone whose axes are missing or not numeric is omitted.
	"""
	landmarks = et.parse(model_properties_xml).getroot().find('Landmarks')

	def extract_axes(bone_first_letter):
		# [[x components], [y components], [z components]] for one bone
		axes = []
		for axis_letter in 'XYZ':
			text = landmarks.find('{}{}_axis'.format(axis_letter, bone_first_letter)).text
			axes.append([float(i) for i in text.split(',')])
		return axes

	bone_axes = {}
	for rb_name, letter in (('TBB', 't'), ('FMB', 'f'), ('PTB', 'p')):
		try:
			bone_axes[rb_name] = extract_axes(letter)
		except (AttributeError, ValueError):
			# missing Landmarks section or axis element, empty text, or non-numeric text
			pass
	return bone_axes

def BoneinWorld_Transform(all_bone_axes, rotation_quaternion_data, rigid_bodies, bone_name, com_data):
	"""Homogeneous transform of a bone's coordinate system in world coordinates, per step.

	Rotation part: the bone's initial axes (the rows of ``all_bone_axes[bone_name]``,
	placed as columns) rotated by the rigid body's rotation at each step.
	Translation part: the rigid body's center of mass at each step.
	The result maps bone coordinates to world coordinates; its inverse maps world
	coordinates to bone coordinates.

	Returns
	-------
	ndarray, shape (n_steps, 4, 4)

	Raises
	------
	KeyError
		If ``bone_name`` is missing from the bone axes, rigid bodies or data.
	"""
	initial_axes_as_columns = np.transpose(all_bone_axes[bone_name])
	body_id = rigid_bodies[bone_name].id
	bone_com = com_data[body_id]
	rotation_matrix = rotation_matrix_from_quaternion(rotation_quaternion_data[body_id])
	rotated_axes = np.matmul(rotation_matrix, initial_axes_as_columns)
	bone_in_world = np.zeros((len(rotated_axes), 4, 4))
	bone_in_world[:, 0:3, 0:3] = rotated_axes
	bone_in_world[:, 0:3, 3] = bone_com
	bone_in_world[:, 3, 3] = 1
	return bone_in_world

def Plot3DMotion(xyz_data, time_steps, figure_title, png_file_name, units=None):
	"""Plot the x, y and z columns of ``xyz_data`` against time; return them as CSV columns.

	The figure is saved to ``png_file_name``.  If plotting fails a message is
	printed and the data is still returned.

	Parameters
	----------
	xyz_data : ndarray, shape (n, 3)
	time_steps : array_like, shape (n,)
	figure_title : str
	png_file_name : str
	units : str, optional
		Unit label used in the axis labels and headers (default 'mm').

	Returns
	-------
	list of (header, column) tuples, ready for save_to_csv.
	"""
	u = units if units is not None else 'mm'
	data_list = [("time_steps", time_steps),
		("X [" + u + "]", xyz_data[:, 0]),
		("Y [" + u + "]", xyz_data[:, 1]),
		("Z [" + u + "]", xyz_data[:, 2])]
	fig = None
	try:
		fig = plt.figure(figsize=(6.4, 6.4))
		fig.suptitle(figure_title)
		for position, (header, column) in zip((311, 312, 313), data_list[1:]):
			plt.subplot(position)
			plt.plot(time_steps, column)
			plt.xlabel('Time')
			plt.ylabel(header)
		plt.savefig(png_file_name)
	except Exception:
		print("failed creating images, trying to create csv")
	finally:
		# always release the figure, even when saving failed
		if fig is not None:
			plt.close(fig)
	return data_list

def ProcessAndPlotTranslations(rigid_bodies, com_data, rotation_quaternion_data, time_steps, bone_axes):
	"""Plot and save the tibia and patella origin translations relative to the femur origin.

	Translations are expressed in the femoral coordinate system and measured from
	the first time step.  Writes a PNG (via Plot3DMotion) and a CSV per bone; a
	KeyError while handling a bone (e.g. the bone is not in the model) skips it.
	"""
	world_in_femur = np.linalg.inv(
		BoneinWorld_Transform(bone_axes, rotation_quaternion_data, rigid_bodies, 'FMB', com_data))

	outputs = (
		('TBB', 'Translation of tibia origin relative to femur origin \n in femoral coordinate system',
			'Tibia_Translation.png', "Tibia_Translation.csv"),
		('PTB', 'Translation of patella origin relative to femur origin \n in femoral coordinate system',
			'Patella_Translation.png', "Patella_Translation.csv"),
	)
	for bone_name, title, png_name, csv_name in outputs:
		try:
			bone_in_world = BoneinWorld_Transform(bone_axes, rotation_quaternion_data, rigid_bodies, bone_name, com_data)
			# origin position = translation column of the bone-in-femur transform
			origin = np.matmul(world_in_femur, bone_in_world)[:, 0:3, 3]
			columns = Plot3DMotion(origin - origin[0], time_steps, title, png_name)
			save_to_csv(columns, csv_name)
		except KeyError:
			pass

def ProcessAndPlotBoneKinetics(rigid_bodies, rigid_moments, rigid_forces, time_steps, bone_axes, com_data):
	"""Plot and save the reaction forces and moments on the tibia+fibula and on the femur.

	The fibula load is moved to the tibia center of mass (adding r x F, with r the
	vector from the tibia to the fibula center of mass) and summed with the tibia
	load.  The tibia+fibula load and the femur load are each reported in the image
	coordinate system and in the tibia coordinate system given by bone_axes["TBB"].
	For each of the four cases a '<case>.png' figure and a '<case>.csv' file are
	written to the current directory; if plotting fails only the CSV is written.
	"""
	tibia_id = rigid_bodies["TBB"].id
	femur_id = rigid_bodies["FMB"].id
	fibula_id = rigid_bodies["FBB"].id

	force_tibia, moments_tibia = rigid_forces[tibia_id], rigid_moments[tibia_id]
	force_fibula, moments_fibula = rigid_forces[fibula_id], rigid_moments[fibula_id]
	force_femur, moments_femur = rigid_forces[femur_id], rigid_moments[femur_id]

	# vector from tibia to fibula origin - fibula and tibia are fixed so this shouldn't
	# change, but a moving center of mass is handled anyway
	vec_fbo = com_data[fibula_id] - com_data[tibia_id]
	moments_fibula_tibia = np.cross(vec_fbo, force_fibula) + moments_fibula + moments_tibia
	force_fibula_tibia = force_fibula + force_tibia

	# rows of the tibia axes are its x, y, z axes in image coordinates, so the inverse
	# of their transpose maps image-CS vectors into the tibia CS
	image_to_tibia = np.linalg.inv(np.asarray(bone_axes["TBB"]).T)

	def in_tibia_cs(vectors):
		# apply image_to_tibia to every row of an (n, 3) array
		return np.dot(np.asarray(vectors), image_to_tibia.T)

	model_names = {"Tibia_and_Fibula_Kinetics_in_TibiaCS": (in_tibia_cs(force_fibula_tibia), in_tibia_cs(moments_fibula_tibia)),
		"Tibia_and_Fibula_Kinetics_in_ImageCS": (force_fibula_tibia, moments_fibula_tibia),
		"Femur_Kinetics_in_TibiaCS": (in_tibia_cs(force_femur), in_tibia_cs(moments_femur)),
		"Femur_Kinetics_in_ImageCS": (force_femur, moments_femur)}
	force_axes = {'x_load': 0, 'y_load': 1, 'z_load': 2}
	moment_axes = {'x_moment': 0, 'y_moment': 1, 'z_moment': 2}

	for title, (forces, moments) in model_names.items():
		data_list = [("time_steps", time_steps)]
		data_list += [(ax_name, forces[:, idx]) for ax_name, idx in force_axes.items()]
		data_list += [(ax_name, moments[:, idx]) for ax_name, idx in moment_axes.items()]
		fig = None
		try:
			fig = plt.figure(figsize=(12.8, 7.2))
			fig.suptitle(title)
			panels = ((121, 'Force', 'Force [N]', forces, force_axes),
				(122, 'Moment', 'Moment [Nmm]', moments, moment_axes))
			for position, subtitle, ylabel, values, axes_map in panels:
				ax = fig.add_subplot(position)
				plt.title(subtitle)
				plt.xlabel('Time')
				plt.ylabel(ylabel)
				for ax_name, idx in axes_map.items():
					ax.plot(time_steps, values[:, idx], label=ax_name)
				plt.legend(loc="upper left")
			plt.savefig(title + '.png')
		except Exception:
			pass  # the image is optional; the CSV below is still written
		finally:
			if fig is not None:
				plt.close(fig)
		save_to_csv(data_list, title + ".csv")

def ProcessAndPlotJointKinetics(rigid_connector_moments, rigid_connector_forces, constraint_info, time_steps, joint_axes):
	"""Plot and save the tibiofemoral constraint forces and moments along each joint axis.

	For every constraint whose name does not contain "Patellar", the rigid connector
	force and moment (looked up by the constraint's ``joint_number``) are projected
	on the joint axis at each time step.  Writes Tibiofemoral_Kinetics.png and
	Tibiofemoral_Kinetics.csv to the current directory; if plotting fails only the
	CSV is written.
	"""
	axes_forces = {}
	axes_moments = {}
	for joint_name, joint_info in constraint_info.items():
		if "Patellar" in joint_name:
			continue  # only the tibiofemoral axes are reported here
		joint_num = joint_info.joint_number
		joint_axis = joint_axes[joint_name]
		# projection of the connector load on the joint axis at each time step
		axes_forces[joint_name] = np.sum(rigid_connector_forces[joint_num] * joint_axis, axis=1)
		axes_moments[joint_name] = np.sum(rigid_connector_moments[joint_num] * joint_axis, axis=1)

	data_list = [("time_steps", time_steps)]
	data_list += [(ax_name + '_Force [N]', force) for ax_name, force in axes_forces.items()]
	data_list += [(ax_name + '_Moment [Nmm]', moment) for ax_name, moment in axes_moments.items()]

	fig = None
	try:
		fig = plt.figure(figsize=(12.8, 7.2))
		fig.suptitle("Tibiofemoral_Constraint_Kinetics")
		panels = ((121, 'Force', 'Force [N]', axes_forces),
			(122, 'Moment', 'Moment [Nmm]', axes_moments))
		for position, subtitle, ylabel, series in panels:
			ax = fig.add_subplot(position)
			plt.title(subtitle)
			plt.xlabel('Time')
			plt.ylabel(ylabel)
			for ax_name, values in series.items():
				ax.plot(time_steps, values, label=ax_name + '_axis')
			plt.legend(loc="upper left")
		plt.savefig('Tibiofemoral_Kinetics.png')
	except Exception:
		print('failed creating images, trying to create csv')
	finally:
		if fig is not None:
			plt.close(fig)
	save_to_csv(data_list, "Tibiofemoral_Kinetics.csv")

def ProcessAndPlotFemurKinematics(rigid_bodies, rotation_quaternion_data, time_steps, bone_axes, com_data):
	"""Plot and save the femur translation and rotation in the image coordinate system.

	Translation: femur origin position relative to the first time step.
	Rotation: Euler angles (radians) from euler_angles_from_quaternion applied to
	the femur rotation quaternions.
	Writes Femur_in_Image_Translation.{png,csv} and Femur_in_Image_Rotation.{png,csv}.
	"""
	femur_in_world = BoneinWorld_Transform(bone_axes, rotation_quaternion_data, rigid_bodies, 'FMB', com_data)
	angles = euler_angles_from_quaternion(rotation_quaternion_data[rigid_bodies['FMB'].id])

	origin = femur_in_world[:, 0:3, 3]
	translation_columns = Plot3DMotion(origin - origin[0], time_steps,
		'Translation of femur origin in image coordinate system',
		'Femur_in_Image_Translation.png')
	save_to_csv(translation_columns, "Femur_in_Image_Translation.csv")

	rotation_columns = Plot3DMotion(np.column_stack(angles), time_steps,
		'Rotations of Femur in image coordinate system', 'Femur_in_Image_Rotation.png', units='rad')
	save_to_csv(rotation_columns, "Femur_in_Image_Rotation.csv")

def save_to_csv(data_list, title):
	"""Write columns of data to a comma-separated file.

	Parameters
	----------
	data_list : list of (header, values) tuples
		One tuple per column; every ``values`` must have the same length and be
		convertible to float.
	title : str
		Output file name.

	The header line is written by numpy.savetxt and therefore starts with '# '.
	"""
	headers = [header for header, _ in data_list]
	columns = np.column_stack([np.asarray(values, dtype=float) for _, values in data_list])
	# join, rather than append ',' after every name, so the header has exactly one
	# name per data column (no trailing empty column)
	np.savetxt(title, columns, delimiter=",", header=",".join(headers))

def run_all_in_file(xml_file):
	"""Run MakeGraphs for every model listed in a batch XML file.

	The XML file gives <general_files><febio_file> and <model_properties_file>, and
	one element per model, with a ``name`` attribute, under <Models>.  Each
	model's log file is '<name>.log' in the FEBio file's directory.

	NOTE(review): MakeGraphs in this module is declared as MakeGraphs(dir1, dir2),
	but it is called here with three arguments, which raises TypeError.  This
	function looks older than the current MakeGraphs - confirm the intended call.
	"""
	file_info = et.parse(xml_file).getroot()
	gen_files = file_info.find("general_files")
	feb_file = gen_files.find("febio_file").text
	mod_props_file = gen_files.find("model_properties_file").text
	# absolute path: MakeGraphs changes the working directory, so a relative (or
	# empty) directory would not resolve to the same place on later iterations
	dirname = os.path.abspath(os.path.dirname(feb_file))
	# collect the names of all the models to run
	all_log_files = []
	for mod in file_info.find("Models"):
		if not isinstance(mod.tag, str):
			continue  # comment or processing instruction
		all_log_files.append(mod.attrib["name"] + '.log')
	# run the models
	for lf in all_log_files:
		os.chdir(dirname)  # re-enter this directory each time, as MakeGraphs changes it
		# splitext keeps dots inside the model name
		folder_name = 'Processed_Results_' + os.path.splitext(lf)[0]
		MakeGraphs(lf, mod_props_file, folder_name)

class SimulationStructure:
	"""Parsed results of one FEBio simulation directory.

	The directory is expected to contain 'FeBio_custom.log' and
	'ModelProperties.xml'; the FEBio input file named in the log is looked up in
	the same directory.
	"""

	def __init__(self, dir, name, append=True):
		"""Parse all the data of the simulation stored in ``dir``.

		Parameters
		----------
		dir : str
			Simulation directory (the parameter name is kept for keyword callers,
			although it shadows the builtin).
		name : str
			Label combined with every joint name in the kinematics results.
		append : bool, optional
			If True the label is appended to the joint name ('<joint>_<name>'),
			otherwise it is prepended ('<name>_<joint>').
		"""
		self.name = name
		self.append = append
		self.log_filename = os.path.join(dir, 'FeBio_custom.log')
		self.model_properties_xml = os.path.join(dir, 'ModelProperties.xml')
		self.febio_input_filename = find_input_filename(self.log_filename)
		self.bone_axes = get_bone_axes(self.model_properties_xml)
		# constraint names, numbered to match the rigid connector data in the log file
		self.constraint_info = GetConstraintInfo(self.febio_input_filename)  # {name: Cylindrical_Joint}
		self.rigid_bodies = find_rigid_bodies(self.log_filename)
		# time steps and data records, with the t = 0 state prepended
		self.all_data, self.time_steps = collect_data(self.log_filename)
		self.all_data, self.time_steps = add_initial_values(self.all_data, self.time_steps, self.rigid_bodies)
		self.com_data = self.all_data['center_of_mass']
		self.rot_quat_data = self.all_data['rotation_quaternion']
		self.rigid_connector_moments = self.all_data['Rigid_Connector_Moment']
		self.rigid_connector_forces = self.all_data['Rigid_Connector_Force']
		self.rigid_moments = self.all_data['Reaction_Torques']
		self.rigid_forces = self.all_data['Reaction_Forces']
		self.joint_axes = get_joint_axes(self.rigid_bodies, self.rot_quat_data, self.constraint_info)

	def ProcessJointKinematics(self):
		"""Compute the joint kinematics and split them into tibiofemoral and patellofemoral sets.

		Uses joint_kinematics_2.  Fills ``tibiofemoral_translations``,
		``tibiofemoral_rotations``, ``patellofemoral_translations`` and
		``patellofemoral_rotations``, keyed by the joint name combined with ``self.name``.
		"""
		self.all_kinematics = joint_kinematics_2(self.bone_axes, self.rigid_bodies, self.rot_quat_data,
			self.com_data, self.constraint_info)
		self.tibiofemoral_translations = {}
		self.tibiofemoral_rotations = {}
		self.patellofemoral_translations = {}
		self.patellofemoral_rotations = {}
		for joint_name, (translation, rotation) in self.all_kinematics.items():
			if self.append:
				label = '{}_{}'.format(joint_name, self.name)
			else:
				label = '{}_{}'.format(self.name, joint_name)
			# classify on the joint name alone, so a simulation label containing
			# "Patellar" cannot misroute tibiofemoral joints
			if "Patellar" in joint_name:
				self.patellofemoral_translations[label] = translation
				self.patellofemoral_rotations[label] = rotation
			else:
				self.tibiofemoral_translations[label] = translation
				self.tibiofemoral_rotations[label] = rotation
	
def MakeGraphs(dir1, dir2, folder_name=None):
	"""Compare the joint kinematics of two simulations and save the comparison figures.

	dir1: directory of the first simulation (labelled 'Modeler A', solid lines).
	dir2: directory of the second simulation (labelled 'Modeler B', dashed lines);
		the output folder is created inside this directory.
	folder_name: name of the output folder; defaults to 'Results_Processed'.

	Side effect: the process working directory is left at dir2/<folder_name>, which
	is where DualPlotJointKinematics writes its png files.
	"""
	first = SimulationStructure(dir1, 'Modeler A', False)
	second = SimulationStructure(dir2, "Modeler B", False)
	os.chdir(dir2)
	# Folder that will contain all the graphs and files associated with them.
	out_name = folder_name if folder_name is not None else 'Results_Processed'
	# exist_ok tolerates only an already-existing directory; any other failure
	# (e.g. missing permissions) is reported here instead of as a later chdir error.
	os.makedirs(out_name, exist_ok=True)
	os.chdir(out_name)
	print(os.getcwd())
	# Joint kinematics
	first.ProcessJointKinematics()
	second.ProcessJointKinematics()
	DualPlotJointKinematics('Tibiofemoral'  , first, second)
	DualPlotJointKinematics('Patellofemoral', first, second)
	# TODO: the single-simulation analyses ProcessAndPlotTranslations,
	# ProcessAndPlotFemurKinematics, ProcessAndPlotJointKinetics and
	# ProcessAndPlotBoneKinetics have not been ported to this two-simulation comparison.

def _kinematics_columns(time_steps, translations, rotations):
	"""Return the [(column_name, values), ...] list written to one kinematics csv file.

	Columns are the time steps, then every translation series ('<joint>_Translation [mm]'),
	then every rotation series ('<joint>_Rotation [deg]').
	"""
	columns = [("time_steps", time_steps)]
	columns.extend((label + '_Translation [mm]', trans) for label, trans in translations.items())
	columns.extend((label + '_Rotation [deg]', rot) for label, rot in rotations.items())
	return columns


def _plot_joint_kinematics(title, time_steps, translations, rotations, png_name, text_kwargs):
	"""Plot translations (left) and rotations (right) against time and save them to png_name.

	text_kwargs are forwarded to the subplot title and axis-label calls. The figure is
	closed even if plotting or saving raises.
	"""
	fig = plt.figure(figsize=(12.8, 7.2))
	try:
		fig.suptitle(title)
		plt.subplot(1, 2, 1)
		for label, trans in translations.items():
			plt.plot(time_steps, trans, label=label + '_axis')
		plt.legend(loc='upper left')
		plt.title('Translations', **text_kwargs)
		plt.xlabel('Time', **text_kwargs)
		plt.ylabel('mm', **text_kwargs)
		#
		plt.subplot(1, 2, 2)
		for label, rot in rotations.items():
			plt.plot(time_steps, rot, label=label + '_axis')
		plt.legend(loc='upper left')
		plt.title('Rotations', **text_kwargs)
		plt.xlabel('Time', **text_kwargs)
		plt.ylabel('deg', **text_kwargs)
		#
		plt.savefig(png_name)
	finally:
		plt.close(fig)


def ProcessAndPlotJointKinematics(com_data, rot_quat_data, time_steps, constraint_info, joint_axes, bone_axes, rigid_bodies):
	"""Calculate the kinematics of all constraints, then save them as png plots and csv files.

	Kinematics come from joint_kinematics_2, which uses the femur to tibia transformation
	to extract them. An alternative, joint_kinematics(com_data, rot_quat_data,
	constraint_info, joint_axes), was previously left here commented out; joint_axes is
	only needed by that alternative and is unused now, but stays in the signature for callers.

	Joints whose name contains "Patellar" are grouped as patellofemoral, all others as
	tibiofemoral. For each group this writes, in the current working directory,
	'<Group>_Kinematics.png' and '<Group>_Kinematics.csv' (via save_to_csv). If any step
	of that sequence raises, the reason is printed and both csv files are (re)written
	without images.
	"""
	all_kinematics = joint_kinematics_2(bone_axes, rigid_bodies, rot_quat_data, com_data, constraint_info)
	tibiofemoral_translations = {}
	tibiofemoral_rotations = {}
	patellofemoral_translations = {}
	patellofemoral_rotations = {}
	for name, kin in all_kinematics.items():
		if "Patellar" in name:
			patellofemoral_translations[name] = kin[0]
			patellofemoral_rotations[name] = kin[1]
		else:
			tibiofemoral_translations[name] = kin[0]
			tibiofemoral_rotations[name] = kin[1]
	tibiofemoral_columns = _kinematics_columns(time_steps, tibiofemoral_translations, tibiofemoral_rotations)
	patellofemoral_columns = _kinematics_columns(time_steps, patellofemoral_translations, patellofemoral_rotations)
	# try creating both csv and png; if that fails, just csv
	try:
		_plot_joint_kinematics("Kinematics of Tibiofemoral cylindrical joints", time_steps,
							   tibiofemoral_translations, tibiofemoral_rotations,
							   'Tibiofemoral_Kinematics.png', {})
		save_to_csv(tibiofemoral_columns, "Tibiofemoral_Kinematics.csv")
		# NOTE(review): csfont is assumed to be a module-level dict of text keyword
		# arguments defined elsewhere in this file (not visible here) - confirm. If it is
		# missing, the NameError lands in the csv-only fallback below, as before.
		_plot_joint_kinematics("Kinematics of patellofemoral cylindrical joints", time_steps,
							   patellofemoral_translations, patellofemoral_rotations,
							   'Patellofemoral_Kinematics.png', csfont)
		save_to_csv(patellofemoral_columns, "Patellofemoral_Kinematics.csv")
	except Exception as err:
		# Exception (not a bare except) so Ctrl-C / sys.exit still propagate.
		print('failed creating images, trying to create csv file only')
		print('  reason: {!r}'.format(err))
		save_to_csv(tibiofemoral_columns, "Tibiofemoral_Kinematics.csv")
		save_to_csv(patellofemoral_columns, "Patellofemoral_Kinematics.csv")
# def ProcessJointKinematics(stuff):
# 	"""	Bens quick and dirty hack of ProcessAndPlotJointKinematics"""
# 	all_kinematics = joint_kinematics_2(stuff.bone_axes, stuff.rigid_bodies, stuff.rot_quat_data, stuff.com_data, stuff.constraint_info)
# 	tibiofemoral_translations = {}
# 	tibiofemoral_rotations = {}
# 	patellofemoral_translations = {}
# 	patellofemoral_rotations = {}
# 	for name, kin in iter(all_kinematics.items()):
# 		# if new: name='{}_new'.format(name)
# 		if "Patellar" in name:
# 			patellofemoral_translations[name] = kin[0]
# 			patellofemoral_rotations[name] = kin[1]
# 		else:
# 			tibiofemoral_translations[name] = kin[0]
# 			tibiofemoral_rotations[name] = kin[1]
# 	return tibiofemoral_translations, tibiofemoral_rotations, patellofemoral_translations, patellofemoral_rotations
def DualPlotJointKinematics(name, first, second, first_range=(20, 141), second_range=(20, 80), show=True):
	"""Overlay one joint group's kinematics from two simulations and save '<name>_Kinematics.png'.

	name: joint group; must contain "Tibio" (tibiofemoral) or "Patello" (patellofemoral).
		It is also used in the figure title and the output file name.
	first, second: objects exposing time_steps and the tibiofemoral_/patellofemoral_
		translations/rotations dicts (as filled by ProcessJointKinematics). first is
		drawn with solid lines, second with dashed lines.
	first_range, second_range: (start, stop) slice bounds applied to each simulation's
		time steps and series; the defaults reproduce the previously hard-coded windows.
	show: when true (default), display the figure with plt.show() after saving it.

	Exits through sys.exit with a non-zero status if name matches neither group.
	Sets the global matplotlib rcParams font family to "Times New Roman".
	"""
	if "Tibio" in name:
		translations1 = first.tibiofemoral_translations
		rotations1 = first.tibiofemoral_rotations
		translations2 = second.tibiofemoral_translations
		rotations2 = second.tibiofemoral_rotations
	elif "Patello" in name:
		translations1 = first.patellofemoral_translations
		rotations1 = first.patellofemoral_rotations
		translations2 = second.patellofemoral_translations
		rotations2 = second.patellofemoral_rotations
	else:
		# Passing the message makes the exit status non-zero; a bare sys.exit()
		# would report success to the calling shell.
		sys.exit('Bad Name provided to DualPlotJointKinematics()')
	window1 = slice(*first_range)
	window2 = slice(*second_range)
	time_steps1 = first.time_steps[window1]
	time_steps2 = second.time_steps[window2]
	#
	plt.rcParams["font.family"] = "Times New Roman"
	fig = plt.figure(figsize=(12.8, 7.2))
	try:
		fig.suptitle("Kinematics of {} Cylindrical joints".format(name))
		plt.subplot(1, 2, 1)
		for label, trans in translations1.items():
			plt.plot(time_steps1, trans[window1], label=label + '_axis')
		for label, trans in translations2.items():
			plt.plot(time_steps2, trans[window2], ls="--", label=label + '_axis')
		plt.legend(loc='upper left')
		plt.title('Translations')
		plt.xlabel('Time')
		plt.ylabel('mm')
		#
		plt.subplot(1, 2, 2)
		for label, rot in rotations1.items():
			plt.plot(time_steps1, rot[window1], label=label + '_axis')
		for label, rot in rotations2.items():
			plt.plot(time_steps2, rot[window2], ls="--", label=label + '_axis')
		plt.legend(loc='lower left')
		plt.title('Rotations')
		plt.xlabel('Time')
		plt.ylabel('deg')
		#
		plt.savefig('{}_Kinematics.png'.format(name))
		if show:
			plt.show()
	finally:
		plt.close(fig)

if __name__ == '__main__':
	# Usage: python <script> [ORIGINAL_DIR COMPARE_DIR]
	# Without arguments the original author's local simulation directories are used.
	originalDir = r"/Users/klonowe/oks/oks003/Model/FebioAGS/"
	compareDir  = r"/Users/klonowe/oks/oks003/Model/FebioEMK/35/"
	if len(sys.argv) == 3:
		originalDir, compareDir = sys.argv[1], sys.argv[2]
	elif len(sys.argv) != 1:
		sys.exit('usage: {} [ORIGINAL_DIR COMPARE_DIR]'.format(os.path.basename(sys.argv[0])))
	#
	MakeGraphs(originalDir, compareDir)



