# This script was used to visualize experimental kinematics for Open Knees, using the bone STL files.
# See the __main__ block at the bottom for examples.
# The script generates one femur STL file for each time point in the given kinematics CSV file
# (written as step_<i>.stl in the same directory as the femur STL file).
# Open the STL files in MeshLab together with the tibia STL file to visualize how the joint is expected to move based on the given kinematics.
# This can help when it is unclear whether the kinematics were properly processed.

import numpy as np
import os
import matplotlib.pyplot as plt
from mayavi import mlab
import xml.etree.ElementTree as ET
from tvtk.tools import visual
from tvtk.api import tvtk
from tvtk.common import configure_input
from traits.api import on_trait_change
import stl
from mayavi.modules.text import Text
import math
from argparse import Namespace
# import moviepy.editor as mpy
# import moviepy.video as mpv
import pandas as pd
import sys
import zipfile
from PIL import Image
from lxml import etree as et

def T_fem_in_tib(a,b,c,alpha,beta,gamma):
    """Build one homogeneous transform per sample: femur pose in the tibia coordinate system.

    The rotation block equals ``Rx(alpha) @ Ry(beta) @ Rz(gamma)``. The
    translation is ``a * x + b * (Rx(alpha) y) + c * (Rx(alpha) Ry(beta) z)``,
    so each displacement is measured along the axis as it stands at that
    point in the rotation sequence.

    Parameters
    ----------
    a, b, c : array_like
        Translations along x, the floating y axis and z. ``len(a)`` sets the
        number of samples.
    alpha, beta, gamma : array_like
        Rotation angles in radians about x, the floating y axis and z.

    Returns
    -------
    numpy.ndarray
        Array of shape ``(len(a), 4, 4)``.
    """
    cos_a, sin_a = np.cos(alpha), np.sin(alpha)
    cos_b, sin_b = np.cos(beta), np.sin(beta)
    cos_g, sin_g = np.cos(gamma), np.sin(gamma)

    # Products shared by several matrix entries.
    sa_sb = sin_a * sin_b
    ca_sb = cos_a * sin_b
    sa_cb = sin_a * cos_b
    ca_cb = cos_a * cos_b

    # (row, column) -> value; every entry not listed stays zero.
    entries = {
        (0, 0): cos_b * cos_g,
        (0, 1): -cos_b * sin_g,
        (0, 2): sin_b,
        (0, 3): c * sin_b + a,
        (1, 0): sa_sb * cos_g + cos_a * sin_g,
        (1, 1): cos_a * cos_g - sa_sb * sin_g,
        (1, 2): -sa_cb,
        (1, 3): b * cos_a - c * sa_cb,
        (2, 0): sin_a * sin_g - ca_sb * cos_g,
        (2, 1): ca_sb * sin_g + sin_a * cos_g,
        (2, 2): ca_cb,
        (2, 3): c * ca_cb + b * sin_a,
        (3, 3): 1.0,
    }

    transforms = np.zeros((len(a), 4, 4))
    for (row, col), value in entries.items():
        transforms[:, row, col] = value
    return transforms

def transformSTL(stl_name, A, i):
    """Write a rigidly transformed copy of an STL mesh as ``step_<i>.stl``.

    The output is written to the same directory as ``stl_name``.

    Parameters
    ----------
    stl_name : str
        Path of the source STL file.
    A : array_like
        4x4 homogeneous transform with its translation column in metres. The
        translation is multiplied by 1000 before the transform is applied
        (NOTE(review): presumably because the STL geometry is in mm — confirm).
        ``A`` itself is not modified.
    i : int
        Index used in the output file name.

    Returns
    -------
    str
        Path of the written STL file.
    """
    directory = os.path.dirname(stl_name)
    stl_file = stl.Mesh.from_file(stl_name, calculate_normals=True)

    # Scale a copy: scaling the caller's matrix in place would silently change
    # it, and reusing the same matrix would compound the factor of 1000.
    transform = np.array(A, dtype=float)
    transform[:3, 3] *= 1000

    stl_file.transform(transform)
    new_stl_name = os.path.join(directory, 'step_{}.stl'.format(i))
    stl_file.save(new_stl_name)
    return new_stl_name

def transformSTL_get_actor(v, stl_name, A, prop, i):
    '''Return a tvtk actor for an STL mesh moved by a rigid transform.

    Transformation matrix needs to be in m and rad: the translation column of
    ``A`` is multiplied by 1000 before the transform is applied to the mesh
    (NOTE(review): presumably because the STL geometry is in mm — confirm).

    Parameters
    ----------
    v : object
        Not used by this function; kept so existing calls keep working.
    stl_name : str
        Path of the source STL file.
    A : array_like
        4x4 homogeneous transform. It is copied, not modified in place.
    prop : tvtk.Property
        Rendering property assigned to the returned actor.
    i : int
        Not used by this function; kept so existing calls keep working.

    Returns
    -------
    tvtk.Actor
        Actor rendering the transformed mesh. It is not added to any scene.

    Notes
    -----
    The transformed mesh is round-tripped through ``test_1.stl`` in the
    current working directory, which is overwritten on every call.
    '''
    stl_file = stl.Mesh.from_file(stl_name, calculate_normals=True)

    # Scale a copy so the caller's matrix is left untouched; scaling in place
    # would compound the factor of 1000 if the same matrix is reused.
    transform = np.array(A, dtype=float)
    transform[:3, 3] *= 1000

    stl_file.transform(transform)
    stl_file.save('test_1.stl')
    reader2 = tvtk.STLReader()
    reader2.file_name = 'test_1.stl'
    reader2.update()  # read the file now, before a later call overwrites it

    mapper2 = tvtk.PolyDataMapper()
    configure_input(mapper2, reader2.output)

    # Wrap the mesh in an actor; the caller decides which scene to add it to.
    actor2 = tvtk.Actor(mapper=mapper2, property=prop)
    return actor2


def _landmark_vector(landmarks, tag):
    """Parse the comma-separated numbers stored in ``<tag>`` under Landmarks."""
    node = landmarks.find(tag)
    if node is None or node.text is None:
        raise ValueError("Landmark '{}' missing from model properties file".format(tag))
    # Python float() tolerates surrounding whitespace such as "1.0, 2.0".
    return np.array([float(value) for value in node.text.split(",")])


def _frame_in_image(landmarks, origin_tag, axis_tags):
    """Build a 4x4 transform whose columns are the given axes and whose origin is converted to m."""
    axes = np.asarray([_landmark_vector(landmarks, tag) for tag in axis_tags])
    frame = np.eye(4)
    frame[:3, :3] = axes.T  # axes become the columns; unitless
    frame[:3, 3] = _landmark_vector(landmarks, origin_tag) / 1000  # convert to m
    return frame


def T_tib_in_image(ModelPropertiesXml):
    """Find the initial tibia and femur frames of the model in image coordinates.

    The origins (``TBO``, ``FMO``) and the axes (``X/Y/Z t_axis``,
    ``X/Y/Z f_axis``) are read from the ``Landmarks`` element of the model
    properties XML. Each frame is returned as a 4x4 homogeneous matrix whose
    rotation columns are the landmark axes and whose translation is the
    landmark origin divided by 1000 (converted to m).

    Parameters
    ----------
    ModelPropertiesXml : str or file-like
        Path of (or open handle to) the model properties XML file.

    Returns
    -------
    tuple of numpy.ndarray
        ``(T_tib_image, T_fem_image)``, each of shape ``(4, 4)``.

    Raises
    ------
    ValueError
        If the ``Landmarks`` element or any required landmark is missing.
    """
    root = ET.parse(ModelPropertiesXml).getroot()
    landmarks = root.find("Landmarks")
    if landmarks is None:
        raise ValueError("No 'Landmarks' element in model properties file")

    T_fem_image = _frame_in_image(landmarks, "FMO", ("Xf_axis", "Yf_axis", "Zf_axis"))
    T_tib_image = _frame_in_image(landmarks, "TBO", ("Xt_axis", "Yt_axis", "Zt_axis"))

    return T_tib_image, T_fem_image

def _read_jcs_kinematics(experiment_kinematics_csv):
    """Read the six JCS kinematic channels from a CSV, converted to m and rad.

    The sign flips follow the conventions noted inline: anterior, extension
    and external rotation are made positive.

    Returns
    -------
    tuple of numpy.ndarray
        ``(ML, AP, SI, EF, VV, EI)``, one value per CSV row.
    """
    df = pd.read_csv(experiment_kinematics_csv)
    ML = df['Knee JCS Medial [mm]'].values / 1000
    AP = -df['Knee JCS Posterior [mm]'].values / 1000  # right handed cs, anterior positive so flip the data
    SI = df['Knee JCS Superior [mm]'].values / 1000
    EF = -np.radians(df['Knee JCS Flexion [deg]'].values)  # extension should be positive so flip data
    VV = np.radians(df['Knee JCS Valgus [deg]'].values)  # valgus positive for left knee
    EI = -np.radians(df['Knee JCS Internal Rotation [deg]'].values)  # left knee external positive so flip data
    return ML, AP, SI, EF, VV, EI


def visualize_kinematics(experiment_kinematics_csv, ModelPropertiesXml, tibia_stl_file, femur_stl_file):
    """Write one transformed femur STL per row of a JCS kinematics CSV.

    ``transformSTL`` writes the files next to ``femur_stl_file`` as
    ``step_<i>.stl``. Open them together with the tibia STL to inspect the
    motion.

    Parameters
    ----------
    experiment_kinematics_csv : str
        CSV file with the 'Knee JCS ...' columns (mm and deg).
    ModelPropertiesXml : str
        Registered model properties file with the tibia/femur landmarks.
    tibia_stl_file : str
        Not used by the active code (only by the commented-out animation).
    femur_stl_file : str
        Femur STL that is transformed for every sample.

    Returns
    -------
    numpy.ndarray
        The ``(n_samples, 4, 4)`` transforms applied to the femur STL.
    """
    ML, AP, SI, EF, VV, EI = _read_jcs_kinematics(experiment_kinematics_csv)

    # input those kinematics to get rb transformation matrix to get the position of the femur in the tibia cs
    T_Fem_in_Tib = T_fem_in_tib(ML, AP, SI, EF, VV, EI)

    # transformation matrix tibia and femur in image cs
    T_Tib_in_Image, T_Fem_in_Image = T_tib_in_image(ModelPropertiesXml)

    # transform femur to origin, then apply transformation in image coordinate system
    # NOTE(review): if T_fem_in_tib maps femur coordinates into tibia coordinates
    # (as its name suggests), the pose would be T_tib_image @ T_fem_in_tib @ inv(T_fem_image);
    # the inverse used here is kept as-is — confirm against the kinematics convention.
    T_fem = np.matmul(T_Tib_in_Image, np.matmul(np.linalg.inv(T_Fem_in_Tib), np.linalg.inv(T_Fem_in_Image)))

    for i, A in enumerate(T_fem):
        # Pass a copy so T_fem stays intact even if the callee scales in place.
        transformSTL(femur_stl_file, A.copy(), i)

    # Unused Mayavi animation prototype, kept for reference.
    # NOTE(review): it needs an imaginary time sequence, e.g. time = range(len(ML)).
    # It also iterates over T_Fem_in_Image, a single 4x4 matrix (i.e. its rows);
    # T_fem was probably intended.
    # v = mlab.figure()
    # v.scene._lift()
    #
    # actors = []
    #
    # # plot the tibia
    # act = transformSTL_get_actor(v, tibia_stl_file, np.identity(4), tvtk.Property(opacity=1, color=(1, 1, 1)), 0)
    # v.scene.add_actor(act)
    #
    # for i, A in enumerate(T_Fem_in_Image):
    #     act = transformSTL_get_actor(v, femur_stl_file, A, tvtk.Property(opacity=1, color=(1, 1, 1)), i)
    #     actors.append([act])
    #
    #     # ---------------------------------------------------------------------
    #     # Use this section to visualize animations (repeated)
    #     # ---------------------------------------------------------------------
    # @mlab.show
    # @mlab.animate(delay=100)
    # @on_trait_change('scene.activated')
    # def anim():
    #     """Animate the b1 box."""
    #     t = 0
    #     engine = mlab.get_engine()
    #     scene = engine.current_scene
    #     timetext = Text()
    #     timetext.text = 'Time = %f sec' % (float(time[t]))
    #     v.scene.add_actor(timetext.actor)
    #     # timetext = mlab.text(0.7, 0.01, 'Time = %f sec' % (float(time[t]) / 1000.0), figure = scene)
    #     while 1:
    #         global actor_old
    #         if 'actor_old' in globals() and len(actor_old) > 0:
    #             v.scene.remove_actor(actor_old[0])
    #
    #         try:
    #             for act_i in actors[t]:
    #                 v.scene.add_actor(act_i)
    #
    #             actor_old = [actors[t], timetext.actor]
    #         except:
    #             pass
    #         timetext.text = 'Time = %f sec' % (float(time[t]))
    #         t = (t + 1) % len(time)
    #
    #         yield
    #
    # # anim()

    return T_fem


if __name__ == '__main__':

    # Example runs. Raw strings keep the Windows backslashes literal; in a normal
    # string, sequences such as "\s" or "\D" are invalid escapes (SyntaxWarning
    # on Python 3.12+) and "\U" or "\v" would be interpreted as escapes.

    # must use registered model properties file
    ModelPropertiesXml = r"C:\Users\schwara2\Documents\Open_Knees\oks003_calibration\Registration\model\ModelProperties.xml"
    tibia_stl_file = r"C:\Users\schwara2\Documents\Open_Knees\visualization\oks003\oks003_TBB_AGS_LVTIT.stl"
    femur_stl_file = r"C:\Users\schwara2\Documents\Open_Knees\visualization\oks003\oks003_FMB_AGS_LVTIT.stl"

    # #Anterior
    # experiment_kinematics_csv = r"C:\Users\schwara2\Documents\Open_Knees\oks003_calibration\DataProcessing\Processed_Data\Laxity_0deg_AP1_kinematics_in_JCS_experiment.csv"
    # visualize_kinematics(experiment_kinematics_csv, ModelPropertiesXml, tibia_stl_file, femur_stl_file)

    # # Posterior
    # experiment_kinematics_csv = r"C:\Users\schwara2\Documents\Open_Knees\oks003_calibration\DataProcessing\Processed_Data\Laxity_0deg_AP2_kinematics_in_JCS_experiment.csv"
    # visualize_kinematics(experiment_kinematics_csv, ModelPropertiesXml, tibia_stl_file, femur_stl_file)

    # #Varus
    # experiment_kinematics_csv = r"C:\Users\schwara2\Documents\Open_Knees\oks003_calibration\DataProcessing\Processed_Data\Laxity_0deg_VV1_kinematics_in_JCS_experiment.csv"
    # visualize_kinematics(experiment_kinematics_csv, ModelPropertiesXml, tibia_stl_file, femur_stl_file)

    # #Valgus
    # experiment_kinematics_csv = r"C:\Users\schwara2\Documents\Open_Knees\oks003_calibration\DataProcessing\Processed_Data\Laxity_0deg_VV2_kinematics_in_JCS_experiment.csv"
    # visualize_kinematics(experiment_kinematics_csv, ModelPropertiesXml, tibia_stl_file, femur_stl_file)
    #
    # NOTE(review): the Internal and External examples below point at the VV2
    # (valgus) file, which looks like a copy-paste slip — confirm the intended CSVs.
    # #Internal
    # experiment_kinematics_csv = r"C:\Users\schwara2\Documents\Open_Knees\oks003_calibration\DataProcessing\Processed_Data\Laxity_0deg_VV2_kinematics_in_JCS_experiment.csv"
    # visualize_kinematics(experiment_kinematics_csv, ModelPropertiesXml, tibia_stl_file, femur_stl_file)
    #
    # #External
    # experiment_kinematics_csv = r"C:\Users\schwara2\Documents\Open_Knees\oks003_calibration\DataProcessing\Processed_Data\Laxity_0deg_VV2_kinematics_in_JCS_experiment.csv"
    # visualize_kinematics(experiment_kinematics_csv, ModelPropertiesXml, tibia_stl_file, femur_stl_file)

    # passive flexion
    experiment_kinematics_csv = r"C:\Users\schwara2\Documents\Open_Knees\oks003_calibration\DataProcessing\Processed_Data\Passive_Flexion_Kinematics_in_JCS_experiment.csv"
    visualize_kinematics(experiment_kinematics_csv, ModelPropertiesXml, tibia_stl_file, femur_stl_file)