'''
This script uses a registration XML to locate and read bone and ligament STLs, then visualizes the oriented bounding box and signed-distance calculations used in kinematics/kinetics analysis.

Inputs:
Registration XML giving the locations of the segmented STLs

USAGE:
python visualization.py <Registration.xml> <text file of csv data to analyze>

Outputs:
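Interactive VTK window showing the femur, tibia and fibula, the ligament's (PCL by default) signed distances to the femur, its oriented bounding box, and the centres of the box's two end faces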

Written By:
Rodrigo Lopez-Navarro
Department of Biomedical Engineering, Cleveland Clinic
'''
import os
import sys
import xml.etree.ElementTree as ET

import numpy as np
import vtk

# Bones drawn as translucent context surfaces, in drawing order.
_DISPLAYED_BONES = ("Femur", "Tibia", "Fibula")


def _xml_text(root, path):
    """Return the stripped text of the element at ElementPath *path* under *root*.

    Raises:
        ValueError: if the element is missing or has no text, so a malformed
            registration XML fails with a readable message instead of an
            ``AttributeError`` on ``None``.
    """
    node = root.find(path)
    if node is None or node.text is None or not node.text.strip():
        raise ValueError("registration XML has no text at <{}>".format(path))
    return node.text.strip()


def _read_stl(path):
    """Read the STL at *path* with ``vtkSTLReader`` and return the updated reader."""
    reader = vtk.vtkSTLReader()
    reader.SetFileName(path)
    reader.Update()
    return reader


def _surface_actor(source, opacity=.3, color=(1, 1, 1)):
    """Return an actor drawing *source*'s output in a single colour (scalars hidden)."""
    mapper = vtk.vtkPolyDataMapper()
    mapper.SetInputConnection(source.GetOutputPort())
    mapper.ScalarVisibilityOff()
    actor = vtk.vtkActor()
    actor.SetMapper(mapper)
    actor.GetProperty().SetOpacity(opacity)
    actor.GetProperty().SetColor(*color)
    return actor


def _signed_distance_actor(points, surface):
    """Return a point-cloud actor coloured by the signed distance of *points* to *surface*.

    The distances are evaluated with ``vtkImplicitPolyDataDistance`` and stored
    as the point scalars (array name ``lig_SignedDistances``), which the mapper
    uses for colouring.
    """
    distance_fn = vtk.vtkImplicitPolyDataDistance()
    distance_fn.SetInput(surface)

    distances = vtk.vtkFloatArray()
    distances.SetNumberOfComponents(1)
    distances.SetName("lig_SignedDistances")
    for point_id in range(points.GetNumberOfPoints()):
        distances.InsertNextValue(distance_fn.EvaluateFunction(points.GetPoint(point_id)))

    cloud = vtk.vtkPolyData()
    cloud.SetPoints(points)
    cloud.GetPointData().SetScalars(distances)

    # Points alone have no cells to draw; the glyph filter adds one vertex per point.
    glyphs = vtk.vtkVertexGlyphFilter()
    glyphs.SetInputData(cloud)
    glyphs.Update()

    mapper = vtk.vtkPolyDataMapper()
    mapper.SetInputConnection(glyphs.GetOutputPort())
    mapper.ScalarVisibilityOn()
    actor = vtk.vtkActor()
    actor.SetMapper(mapper)
    return actor


def _obb_end_midpoints(corner, max_axis, mid_axis, min_axis):
    """Return the centres of the two OBB faces spanned by *mid_axis* and *min_axis*.

    ``vtkOBBTree.ComputeOBB`` describes the box by one corner plus three edge
    vectors. The face touching *corner* spans *mid_axis* and *min_axis*, so its
    centre is ``corner + (mid_axis + min_axis) / 2``. The opposite face is
    the same face translated by *max_axis*.
    """
    corner = np.asarray(corner, dtype=float)
    face_centre = corner + (np.asarray(mid_axis, dtype=float) + np.asarray(min_axis, dtype=float)) / 2.0
    return face_centre, face_centre + np.asarray(max_axis, dtype=float)


def main(args, ligament="PCL", reference_bone="Femur"):
    """Show a ligament's signed distance to a bone and its oriented bounding box.

    Reads the registration XML and loads the femur, tibia and fibula STLs as
    translucent context. It then colours every ligament vertex by its signed
    distance to *reference_bone*, draws the ligament's oriented bounding box
    as a wireframe, and marks the centres of the box's two end faces with
    green spheres. Finally it opens an interactive VTK window.

    Args:
        args: ``sys.argv``-style list. ``args[1]`` is the registration XML.
            Any further arguments (e.g. the CSV from the usage line) are not
            read here.
        ligament: tag under ``<Ligaments>`` to analyze (default ``"PCL"``, the
            ligament this script always used).
        reference_bone: tag under ``<Bones>`` to measure distances to
            (default ``"Femur"``).

    Raises:
        ValueError: too few arguments, a required XML element is missing or
            empty, or the ligament STL contains no points.
        FileNotFoundError: an STL referenced by the XML does not exist.
    """
    if len(args) < 2:
        raise ValueError("usage: python visualization.py <Registration.xml> <csv data>")
    # args[1], not args[-1]: with the documented usage the last argument is the CSV.
    root = ET.parse(args[1]).getroot()

    knee_dir = _xml_text(root, "Knee_directory")
    # normpath drops a trailing separator; os.path.split('/x/oks003/')[1] would be ''.
    knee_id = os.path.basename(os.path.normpath(knee_dir))
    stl_dir = os.path.join(knee_dir, 'mri-' + knee_id)

    bones = list(_DISPLAYED_BONES)
    if reference_bone not in bones:
        bones.append(reference_bone)
    bone_paths = dict(
        (bone, os.path.join(stl_dir, _xml_text(root, "Bones/{}/file".format(bone))))
        for bone in bones
    )
    ligament_path = os.path.join(stl_dir, _xml_text(root, "Ligaments/{}/file".format(ligament)))

    # vtkSTLReader does not raise a Python exception for a missing file, so check up front.
    for path in [bone_paths[bone] for bone in bones] + [ligament_path]:
        if not os.path.isfile(path):
            raise FileNotFoundError("STL file not found: {}".format(path))

    bone_readers = dict((bone, _read_stl(bone_paths[bone])) for bone in bones)
    ligament_reader = _read_stl(ligament_path)
    lig_data = ligament_reader.GetOutput()
    lig_points = lig_data.GetPoints()
    if lig_points is None or lig_points.GetNumberOfPoints() == 0:
        raise ValueError("ligament STL has no points: {}".format(ligament_path))

    # Oriented bounding box of the ligament.
    obb_tree = vtk.vtkOBBTree()
    obb_tree.SetDataSet(lig_data)
    obb_tree.BuildLocator()
    # Render the box into its own polydata; the ligament mesh must not be overwritten.
    obb_polydata = vtk.vtkPolyData()
    obb_tree.GenerateRepresentation(0, obb_polydata)
    obb_mapper = vtk.vtkPolyDataMapper()
    obb_mapper.SetInputData(obb_polydata)
    obb_actor = vtk.vtkActor()
    obb_actor.SetMapper(obb_mapper)
    obb_actor.GetProperty().SetInterpolationToFlat()
    obb_actor.GetProperty().SetRepresentationToWireframe()

    # ComputeOBB fills these lists in place (corner point, three axis vectors, sizes).
    corner = [0.0, 0.0, 0.0]
    max_axis = [0.0, 0.0, 0.0]
    mid_axis = [0.0, 0.0, 0.0]
    min_axis = [0.0, 0.0, 0.0]
    size = [0.0, 0.0, 0.0]
    obb_tree.ComputeOBB(lig_points, corner, max_axis, mid_axis, min_axis, size)
    end_centres = _obb_end_midpoints(corner, max_axis, mid_axis, min_axis)

    renderer = vtk.vtkRenderer()
    for bone in _DISPLAYED_BONES:
        renderer.AddViewProp(_surface_actor(bone_readers[bone]))
    renderer.AddViewProp(_signed_distance_actor(lig_points, bone_readers[reference_bone].GetOutput()))
    renderer.AddViewProp(obb_actor)
    for centre in end_centres:
        marker = vtk.vtkSphereSource()
        marker.SetCenter(centre[0], centre[1], centre[2])
        marker.SetRadius(1)
        marker.Update()
        renderer.AddViewProp(_surface_actor(marker, opacity=.7, color=(0, 1, 0)))

    render_window = vtk.vtkRenderWindow()
    render_window.AddRenderer(renderer)

    interactor = vtk.vtkRenderWindowInteractor()
    interactor.SetRenderWindow(render_window)

    render_window.Render()
    interactor.Start()

if __name__ == "__main__":
    # Use the command-line arguments when given (see USAGE in the module docstring);
    # otherwise fall back to the developer's default registration file.
    DEFAULT_REGISTRATION = '/home/lopezr3/Documents/CC/Registration/oks003_registration_01.xml'
    main(sys.argv if len(sys.argv) > 1 else ['callfunction', DEFAULT_REGISTRATION])