"""
MinimumWorkingExample.py: A demonstration of how to use the WraptMor software.

Copyright 2020 William Zaylor and Jason P. Halloran

This program is free software: you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.

This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
GNU General Public License for more details.

You should have received a copy of the GNU General Public License
along with this program.  If not, see <https://www.gnu.org/licenses/>.
"""

import numpy as np
import vtk
import sys
sys.path.append('.') # Add the Working Directory to the PYTHONPATH

# Custom modules/functions
import src.WraptMor
import src.Utilities


def minimumWorkingExample(args):
    """
    A minimum working example of a WraptMor model.

    The input for this function is a list of parameters in the following order:
        0) script, string: This parameter is ignored. It is usually the name of this script.
        1) surfaceA_fileName, string: The name of the .stl file that is acting as one of the wrapping surfaces.
        2) surfaceB_fileName, string: The name of the .stl file that is acting as one of the wrapping surfaces.
        3) insertionPointsFileName, string: The name of the file that contains the ligament fiber's insertion points. This is usually an output of 3D Slicer.
        4) clippingDistance, string: This will be converted to a float value. The distance that surfaceA and surfaceB are clipped with respect to their corresponding insertion points.

    The fiber's insertion-to-insertion length and its wrapping points are printed. Then the (unclipped) surfaces are
    displayed together with the fiber by visualization().

    :param args: sequence, [script, surfaceA_fileName, surfaceB_fileName, insertionPointsFileName, clippingDistance]. See the documentation above for a description of the inputs.
    :return: None
    :raises ValueError: If args does not contain exactly five entries, or if clippingDistance cannot be converted to a float.
    """
    # Materialize the arguments so their count can be validated (any iterable was accepted by the original unpacking).
    args = list(args)
    expectedArgs = ['script', 'surfaceA_fileName', 'surfaceB_fileName', 'insertionPointsFileName', 'clippingDistance']
    if len(args) != len(expectedArgs):
        raise ValueError(
            f'Expected {len(expectedArgs)} arguments {expectedArgs}, but got {len(args)}.\n'
            'Usage: python MinimumWorkingExample.py <surfaceA.stl> <surfaceB.stl> <insertionPointsFile> <clippingDistance>')

    # Parse the input arguments into descriptive variable names. The script name (index 0) is not used.
    _script, surfaceA_fileName, surfaceB_fileName, insertionPointsFileName, clippingDistance = args

    # Convert clippingDistance to a float before any file I/O, so that a malformed value fails fast.
    clippingDistance = float(clippingDistance)

    # Define the global coordinates of the insertion points
    insertionPointAGlobal, insertionPointBGlobal = src.Utilities.loadInsertionPoints3DSlicer(insertionPointsFileName)

    # Load the surfaces
    surfaceAPolyData = src.Utilities.loadStl(surfaceA_fileName)
    surfaceBPolyData = src.Utilities.loadStl(surfaceB_fileName)

    # Clip the wrapping surfaces with respect to their corresponding insertion points
    clippedSurfaceA = src.Utilities.getClippedSurface(insertionPointAGlobal, insertionPointBGlobal, surfaceAPolyData, clippingDistance)
    clippedSurfaceB = src.Utilities.getClippedSurface(insertionPointBGlobal, insertionPointAGlobal, surfaceBPolyData, clippingDistance)

    # Get the fiber insertion-to-insertion length and wrapping points data.
    length, wrapPointsAGlobal, wrapPointsBGlobal = _getFiberLengthAndWrapPoints(
        insertionPointAGlobal, insertionPointBGlobal, clippedSurfaceA, clippedSurfaceB)

    print(f'fiber insertion-to-insertion length: {length}')
    print(f'SurfaceA wrapping points: {wrapPointsAGlobal}')
    print(f'SurfaceB wrapping points: {wrapPointsBGlobal}')

    visualization(surfaceAPolyData, surfaceBPolyData, wrapPointsAGlobal, wrapPointsBGlobal)

    return

def _getFiberLengthAndWrapPoints(insertionPointAGlobal, insertionPointBGlobal, clippedSurfaceA, clippedSurfaceB):
    """
    Compute the fiber insertion-to-insertion length and the points the fiber passes through on each surface.

    The straight line between the two insertion points is tested against each clipped surface. The result determines
    which method is used:
        - multibody wrapping, if the line intersects both surfaces;
        - single-body wrapping, if it intersects only one surface;
        - a straight-line distance, if it intersects neither.

    :param insertionPointAGlobal: The fiber's insertion point on surfaceA, in global coordinates.
    :param insertionPointBGlobal: The fiber's insertion point on surfaceB, in global coordinates.
    :param clippedSurfaceA: vtkPolyData, surfaceA clipped with respect to its insertion point.
    :param clippedSurfaceB: vtkPolyData, surfaceB clipped with respect to its insertion point.
    :return: tuple (length, wrapPointsAGlobal, wrapPointsBGlobal).
        - wrapPointsAGlobal has the order [insertionPointAGlobal, wrapPoint0, ..., wrapPointN].
        - wrapPointsBGlobal has the order [wrapPoint0, ..., wrapPointN, insertionPointBGlobal].
        - When a surface is not wrapped, its array holds only that surface's insertion point.
    :raises ValueError: If the intersection checks return something other than True/False.
    """
    # Check if the line connecting 'insertionPointAGlobal' and 'insertionPointBGlobal' intersects with clippedSurfaceA
    intersectBodyA = src.Utilities.checkIntersection(insertionPointAGlobal, insertionPointBGlobal, clippedSurfaceA, obbTree=None)
    # Check if the line connecting 'insertionPointAGlobal' and 'insertionPointBGlobal' intersects with clippedSurfaceB
    intersectBodyB = src.Utilities.checkIntersection(insertionPointAGlobal, insertionPointBGlobal, clippedSurfaceB, obbTree=None)

    # Assume that the line connecting 'insertionPointAGlobal' and 'insertionPointBGlobal' is not parallel to this vector.
    nominalNormal = np.array([1., 0., 0.])

    # The explicit 'is True' / 'is False' checks are deliberate: any other return value falls through to the ValueError below.
    if intersectBodyA is True and intersectBodyB is True:  # Multibody wrapping is needed
        # Get the length from pointA to pointB, and the wrapping points on surfaceA and surfaceB.
        length, wrapPointsAGlobal, wrapPointsBGlobal = src.WraptMor.getLengthMultibody(insertionPointAGlobal, insertionPointBGlobal, clippedSurfaceA, clippedSurfaceB, nominalNormal, boundingBoxA=None, boundingBoxB=None, locatorA=None, locatorB=None)
    elif intersectBodyA is True:  # There is only wrapping around 'clippedSurfaceA'
        length, wrapPointsAGlobal = src.WraptMor.getLength(insertionPointBGlobal, insertionPointAGlobal, clippedSurfaceA, nominalNormal)
        # Reverse the order of points, because the function returned them ordered as [insertionB, point_n, point_n-1, ..., point_0, insertionA].
        wrapPointsAGlobal = np.flip(wrapPointsAGlobal, axis=0)
        # Remove the point that lies on bodyB. There is only one point on bodyB (the fiber's insertion point), and that is the last point in the array.
        wrapPointsAGlobal = wrapPointsAGlobal[:-1]
        # There is only one point on surfaceB, the fiber's insertion point; wrap it to get an array of shape 1x3.
        wrapPointsBGlobal = np.array([insertionPointBGlobal])
    elif intersectBodyB is True:  # There is only wrapping around 'clippedSurfaceB'
        # wrapPointsBGlobal is returned in the order [insertionPointAGlobal, wrapPoint0, wrapPoint1, ..., wrapPointN, insertionPointBGlobal].
        length, wrapPointsBGlobal = src.WraptMor.getLength(insertionPointAGlobal, insertionPointBGlobal, clippedSurfaceB, nominalNormal)
        # Remove the point that lies on bodyA. There is only one point on bodyA (the fiber's insertion point), and that is the first point in the array.
        wrapPointsBGlobal = wrapPointsBGlobal[1:]
        # There is only one point on surfaceA, the fiber's insertion point; wrap it to get an array of shape 1x3.
        wrapPointsAGlobal = np.array([insertionPointAGlobal])
    elif intersectBodyA is False and intersectBodyB is False:  # No wrapping: straight line between the insertions
        length = np.linalg.norm(insertionPointBGlobal - insertionPointAGlobal)
        wrapPointsAGlobal = np.array([insertionPointAGlobal])
        wrapPointsBGlobal = np.array([insertionPointBGlobal])
    else:
        raise ValueError(f'**WARNING** intersectBodyA = {intersectBodyA}, intersectBodyB = {intersectBodyB}.\nThis statement should not logically be reached.')

    return length, wrapPointsAGlobal, wrapPointsBGlobal

def visualization(surfaceA, surfaceB, wrapPointsA, wrapPointsB):
    """
    Visualize the wrapping surfaces and the ligament fiber in an interactive VTK window.

    Both surfaces are drawn opaque in a tan color (RGB 222, 202, 176). The fiber is drawn with its segments rendered as
    tubes and its insertion and wrapping points rendered as spheres. The interactor event loop is started at the end,
    so the call returns only when that loop ends.

    :param surfaceA: vtkPolyData, One of the wrapping surfaces
    :param surfaceB: vtkPolyData, One of the wrapping surfaces
    :param wrapPointsA: array mx3, The insertion and wrapping points that lie on surfaceA. If there are no wrapping points, m=1, and the only point in the array is the fiber's insertion point. Otherwise wrapPointsA has the order [insertionPointAGlobal, wrapPoint0, wrapPoint1, ..., wrapPointN] where wrapPointN is closest to surfaceB.
    :param wrapPointsB: array nx3, The insertion and wrapping points that lie on surfaceB. If there are no wrapping points, n=1, and the only point in the array is the fiber's insertion point. Otherwise wrapPointsB has the order [wrapPoint0, wrapPoint1, ..., wrapPointN, insertionPointBGlobal] where wrapPoint0 is closest to surfaceA.
    :return: None
    """
    ren = vtk.vtkRenderer()
    renWin = vtk.vtkRenderWindow()
    renWin.AddRenderer(ren)
    renWin.SetSize(600, 750)
    iren = vtk.vtkRenderWindowInteractor()
    iren.SetRenderWindow(renWin)
    iren.SetInteractorStyle(vtk.vtkInteractorStyleTrackballCamera())

    # The same color is used for both surfaces, so compute it once (VTK copies the values into each actor's property).
    surfaceColor = np.array([222., 202., 176.]) / 255.
    for surface in (surfaceA, surfaceB):
        surfMapper = vtk.vtkPolyDataMapper()
        surfMapper.SetInputData(surface)
        surfActor = vtk.vtkActor()
        surfActor.SetMapper(surfMapper)
        surfActor.GetProperty().SetColor(surfaceColor)
        surfActor.GetProperty().SetOpacity(1.)
        ren.AddActor(surfActor)

    # The fiber polydata carries per-point RGB scalars, so no explicit actor color is set here.
    fiberPolyData = getLigamentFiberPolydata(wrapPointsA, wrapPointsB)
    fiberMapper = vtk.vtkPolyDataMapper()
    fiberMapper.SetInputData(fiberPolyData)
    fiberActor = vtk.vtkActor()
    fiberActor.SetMapper(fiberMapper)
    fiberProperty = fiberActor.GetProperty()
    fiberProperty.SetOpacity(1.)
    fiberProperty.SetLineWidth(4)
    fiberProperty.SetPointSize(8)
    fiberProperty.SetRenderPointsAsSpheres(True)
    fiberProperty.SetRenderLinesAsTubes(True)
    ren.AddActor(fiberActor)

    ren.SetBackground(1, 1, 1)  # White background

    iren.Initialize()
    renWin.Render()
    iren.Start()

    return

def getLigamentFiberPolydata(wrapPointsA, wrapPointsB):
    """
    Generate the vtkPolyData object that is used to visualize the ligament fiber.

    The points of wrapPointsA followed by those of wrapPointsB are joined, in order, by straight line segments.
    Each point also gets its own vertex cell, so it can be rendered individually. The points are colored with
    per-point RGB scalars named "Colors": (213, 43, 30) for points from wrapPointsA and (190, 214, 0) for points
    from wrapPointsB.

    :param wrapPointsA: array mx3, The insertion and wrapping points that lie on surfaceA. If there are no wrapping points, m=1, and the only point in the array is the fiber's insertion point. Otherwise wrapPointsA has the order [insertionPointAGlobal, wrapPoint0, wrapPoint1, ..., wrapPointN] where wrapPointN is closest to surfaceB.
    :param wrapPointsB: array nx3, The insertion and wrapping points that lie on surfaceB. If there are no wrapping points, n=1, and the only point in the array is the fiber's insertion point. Otherwise wrapPointsB has the order [wrapPoint0, wrapPoint1, ..., wrapPointN, insertionPointBGlobal] where wrapPoint0 is closest to surfaceA.
    :return: vtkPolyData, The polydata that is used to visualize the ligament fiber.
    """
    fiberPoints = vtk.vtkPoints()      # Coordinates of every fiber point
    lines = vtk.vtkCellArray()         # Line segments between consecutive fiber points
    verts = vtk.vtkCellArray()         # One vertex cell per point, so the points themselves are rendered
    pointColors = vtk.vtkUnsignedCharArray()
    pointColors.SetNumberOfComponents(3)
    pointColors.SetName("Colors")

    # Points on surfaceA come first, then points on surfaceB, so the fiber runs from insertion A to insertion B.
    pointGroups = ((wrapPointsA, (213., 43., 30.)),
                   (wrapPointsB, (190., 214., 0.)))

    fiberPointIds = []
    for groupPoints, groupColor in pointGroups:
        for point in groupPoints:
            pointId = fiberPoints.InsertNextPoint(point)
            fiberPointIds.append(pointId)
            verts.InsertNextCell(1)
            verts.InsertCellPoint(pointId)
            pointColors.InsertNextTuple3(*groupColor)

    # Connect each point to the next one with a two-point line cell.
    for startId, endId in zip(fiberPointIds[:-1], fiberPointIds[1:]):
        line = vtk.vtkLine()
        line.GetPointIds().SetId(0, startId)
        line.GetPointIds().SetId(1, endId)
        lines.InsertNextCell(line)

    fiberPolydata = vtk.vtkPolyData()
    fiberPolydata.SetPoints(fiberPoints)
    fiberPolydata.SetLines(lines)
    fiberPolydata.SetVerts(verts)
    fiberPolydata.GetPointData().SetScalars(pointColors)
    return fiberPolydata



if __name__ == "__main__":
    # Pass the full command line to the example driver. Index 0 (the script name) is ignored by it.
    commandLineArgs = sys.argv
    minimumWorkingExample(commandLineArgs)