# This script compares model kinematics (or kinetics) against experiment kinematics (or kinetics).
# It is intended to be used after the LogPostProcessing.py script has been run on a model.
# See the `if __name__ == '__main__':` block at the end of the script for usage examples.

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from lxml import etree as et
import os


def restructure_model_data(df_model, tm=2.0):
    """Convert a processed model-results data frame into time steps plus a 6-row data array.

    Columns are assigned by keywords in their (case-insensitive) names:

    * a column containing "time_steps" supplies the time axis;
    * columns containing "force", "load" or "translation" fill rows 0-2;
    * columns containing "moment" or "rotation" fill rows 3-5;
    * within each group the row is picked from "extension"/"_x"/"x_" (first),
      "adduction"/"_y"/"y_" (second) or "internal"/"_z"/"z_" (third).

    If several columns map to the same row, the last one in the data frame wins.

    Only samples with time >= ``tm`` are kept, and ``tm`` is subtracted from the
    returned time steps. Rows whose column name carries "[m]"/"(m)" or
    "[Nm]"/"(Nm)" are multiplied by 1000 and rows carrying "[rad]"/"(rad)" are
    converted to degrees, so the output is in mm / Nmm / deg.

    Parameters
    ----------
    df_model : pandas.DataFrame
        Model results with string column names.
    tm : float, optional
        Time from which the data is compared. The default 2.0 is for experimental
        loading (after the model reaches its initial loading conditions); use 1.0
        for the non-experimental case (in situ strain + loading only).

    Returns
    -------
    model_time_steps : numpy.ndarray
        Kept time steps, shifted by ``-tm``.
    model_data : numpy.ndarray
        Array of shape (6, n_kept) ordered as described above.
    data_names : list of str
        The source column name used for each of the six rows.

    Raises
    ------
    ValueError
        If no time column or not all six channels could be identified. Previously
        this case failed later with an unrelated numpy / TypeError message.
    """

    def _axis_offset(lowered):
        # x is tested before y, and y before z, so a name matching several
        # keywords resolves to the earliest axis
        if "extension" in lowered or "_x" in lowered or "x_" in lowered:
            return 0
        if "adduction" in lowered or "_y" in lowered or "y_" in lowered:
            return 1
        if "internal" in lowered or "_z" in lowered or "z_" in lowered:
            return 2
        return None

    # slots that we are looking to fill with the corresponding data
    model_data = [None] * 6
    data_names = [None] * 6
    model_time_steps = None

    for name in df_model.columns:
        lowered = name.lower()

        if "time_steps" in lowered:
            model_time_steps = list(map(float, df_model[name].values))
            continue

        if "force" in lowered or "load" in lowered or "translation" in lowered:
            group_start = 0
        elif "moment" in lowered or "rotation" in lowered:
            group_start = 3
        else:
            continue

        offset = _axis_offset(lowered)
        if offset is None:
            continue

        model_data[group_start + offset] = np.asarray(list(map(float, df_model[name].values)))
        data_names[group_start + offset] = name

    if model_time_steps is None:
        raise ValueError("could not find a 'time_steps' column in the model data")

    missing = [i for i, found in enumerate(data_names) if found is None]
    if missing:
        raise ValueError(
            "could not identify model data column(s) for channel index(es) {} "
            "(0,1,2 = x,y,z translations/forces, 3,4,5 = x,y,z rotations/moments)".format(missing))

    model_time_steps = np.asarray(model_time_steps)
    model_data = np.asarray(model_data)

    # cut the data so we only look at loading from time tm onwards
    keep_idx = np.where(model_time_steps >= tm)[0]
    model_data = model_data[:, keep_idx]
    model_time_steps = model_time_steps[keep_idx] - tm

    # check that all data is in N, Nmm, mm, deg. change if necessary
    for i, name in enumerate(data_names):
        if "[m]" in name or "(m)" in name:
            model_data[i] = 1000 * model_data[i]
        elif "[rad]" in name or "(rad)" in name:
            model_data[i] = np.degrees(model_data[i])
        elif "[Nm]" in name or "(Nm)" in name:
            model_data[i] = 1000 * model_data[i]

    return model_time_steps, model_data, data_names


def restructure_experiment_data(df_experiment):
    """Split an experiment data frame into normalized time steps and channel data.

    Every column is converted to a float array. The second column is taken as
    the applied load; dividing it by its largest-magnitude entry (NaNs ignored
    when locating that entry) gives the time steps. The first column is
    converted but not otherwise used. All columns from the third onward become
    the rows of the returned data array, in data-frame order.

    Rows whose column name carries "[m]"/"(m)" or "[Nm]"/"(Nm)" are multiplied
    by 1000 and rows carrying "[rad]"/"(rad)" are converted to degrees.

    Returns a tuple ``(experiment_time_steps, experiment_data, data_names)``
    where ``data_names`` is the slice of ``df_experiment.columns`` matching the
    data rows.
    """
    headers = df_experiment.columns
    numeric_columns = [
        np.asarray(list(map(float, df_experiment[header].values)))
        for header in headers
    ]

    applied_load = numeric_columns[1]
    data_names = headers[2:]

    # the time steps are based on the normalized applied load
    peak_load = applied_load[np.nanargmax(np.abs(applied_load))]
    experiment_time_steps = applied_load / peak_load

    experiment_data = np.asarray(numeric_columns[2:])

    def times_thousand(values):
        return 1000 * values

    # unit tags are checked in this order and only the first matching rule applies,
    # so that everything ends up in N, Nmm, mm, deg
    unit_rules = (
        (("[m]", "(m)"), times_thousand),
        (("[rad]", "(rad)"), np.degrees),
        (("[Nm]", "(Nm)"), times_thousand),
    )
    for row, label in enumerate(data_names):
        for tags, convert in unit_rules:
            if tags[0] in label or tags[1] in label:
                experiment_data[row] = convert(experiment_data[row])
                break

    return experiment_time_steps, experiment_data, data_names


def plot_comaprison(csv_list, all_data, data_names, title, save_directory, column_names):
    """Plot every file's six channels on one two-panel figure and save it as a PNG.

    The left panel shows rows 0-2 of each data array (labelled with
    ``column_names[0:3]``) and the right panel rows 3-5 (labelled with
    ``column_names[3:]``). Colour identifies the channel within a panel and the
    line style identifies the file.

    Parameters
    ----------
    csv_list : list of str
        File paths, used as keys into ``all_data``. The first path decides the
        panel titles/units: "kinematics" in its name gives Translations [mm] /
        Rotations [deg], "kinetics" gives Force [N] / Moment [Nmm]; otherwise
        the panels are left untitled.
    all_data : dict
        Maps each path to a tuple whose item 0 is the time steps and item 1 the
        data array.
    data_names : list of str
        One legend suffix per entry of ``csv_list``.
    title : str
        Figure title; also the file name (without extension) of the saved image.
    save_directory : str
        Directory the ``<title>.png`` file is written to.
    column_names : list of str
        Six channel labels.
    """
    fig = plt.figure(figsize=(12.8, 7.2))
    fig.suptitle(title)

    line_styles = ['-', '--', '-.', ':']
    line_colors = ['r', 'b', 'g']

    # csv files should be all kinematics or all kinetics so just check the first
    first_name = csv_list[0].lower()
    if "kinematics" in first_name:
        panel_labels = [('Translations', 'mm'), ('Rotations', 'deg')]
    elif "kinetics" in first_name:
        panel_labels = [('Force', 'N'), ('Moment', 'Nmm')]
    else:
        panel_labels = [None, None]

    def draw_panel(position, first_row, channel_labels, title_and_unit):
        # draw one subplot: three channels per file, starting at data row first_row
        plt.subplot(1, 2, position)
        for e, c in enumerate(csv_list):
            time_steps = all_data[c][0]
            data = all_data[c][1]
            # cycle through the styles so a fifth (or later) file does not raise IndexError
            style = line_styles[e % len(line_styles)]
            for i, n in enumerate(channel_labels):
                plt.plot(time_steps, data[first_row + i], line_colors[i] + style,
                         label=n + '_' + data_names[e])
        plt.legend(loc='best')
        plt.xlabel('Time')
        if title_and_unit is not None:
            plt.title(title_and_unit[0])
            plt.ylabel(title_and_unit[1])

    draw_panel(1, 0, column_names[0:3], panel_labels[0])
    draw_panel(2, 3, column_names[3:], panel_labels[1])

    # plt.show()
    save_as = os.path.join(save_directory, title)
    plt.savefig(save_as + '.png')
    # plt.savefig(save_as+'.svg', format='svg')
    plt.close()

def write_file(xml_tree, new_filename):
    """Write ``xml_tree`` to ``new_filename`` with an XML declaration, pretty printed.

    The file is written once, parsed back with blank text removed, and written
    again to the same path so that the pretty printing comes out properly.
    """
    write_options = {"xml_declaration": True, "pretty_print": True}

    # first pass: write the new file
    xml_tree.write(new_filename, **write_options)

    # second pass: re-parse without the blank text and re-write
    blank_stripping_parser = et.XMLParser(remove_blank_text=True)
    reparsed_tree = et.parse(new_filename, blank_stripping_parser)
    reparsed_tree.write(new_filename, **write_options)

def get_rms_errors(csv_list, all_data):
    """Return the RMS error of every channel for each csv compared with the first csv.

    For each file after the first, time points shared with the first file are
    located (within 0.0001, widened to 0.005 when fewer than three are found;
    if there are still few points it may simply be a small data set, so the
    comparison continues). The channel data at those points is differenced and
    the root mean square is taken per channel.

    NaN differences are excluded from both the sum and the sample count, so a
    missing sample is not counted as a zero error. A channel with no valid
    common samples (including the case of no common time points at all) gets
    an RMS error of NaN.

    Parameters
    ----------
    csv_list : list of str
        Keys into ``all_data``; the first entry is the reference.
    all_data : dict
        Maps each key to a tuple whose item 0 is a 1-D array of time points and
        item 1 a 2-D array of shape (channels, samples).

    Returns
    -------
    list of numpy.ndarray
        One array of per-channel RMS errors for each entry of ``csv_list[1:]``.
    """

    def find_common_points(list1, list2, error):
        """ find the indices of the common points in two lists, within error. return the indices for list1 and
        list2 of the common points. for each value in list1 the first match in list2 is used."""

        indices_1 = []
        indices_2 = []

        for idx1, val in enumerate(list1):
            matches = np.where(np.abs(np.array(list2 - val)) < error)[0]
            if matches.size:  # otherwise no value found
                indices_1.append(idx1)
                indices_2.append(matches[0])

        return indices_1, indices_2

    rms_errors = []

    first_csv_time_points = all_data[csv_list[0]][0]
    first_csv_data = all_data[csv_list[0]][1]
    for other_csv in csv_list[1:]:
        second_csv_time_points = all_data[other_csv][0]
        second_csv_data = all_data[other_csv][1]

        # try first with small error, if less than 3 points come up, increase the error allowance.
        idx1, idx2 = find_common_points(first_csv_time_points, second_csv_time_points, 0.0001)
        if len(idx1) < 3:
            idx1, idx2 = find_common_points(first_csv_time_points, second_csv_time_points, 0.005)

        # dtype=int keeps an empty match list usable as an index array
        first_data = first_csv_data[:, np.array(idx1, dtype=int)]
        second_data = second_csv_data[:, np.array(idx2, dtype=int)]
        squared_difs = np.square(first_data - second_data)

        # average over the samples that are actually present in each channel
        valid_counts = np.sum(~np.isnan(squared_difs), axis=1)
        sum_squares = np.nansum(squared_difs, axis=1)
        rms_error = np.full(sum_squares.shape, np.nan)
        has_samples = valid_counts > 0
        rms_error[has_samples] = np.sqrt(sum_squares[has_samples] / valid_counts[has_samples])
        rms_errors.append(rms_error)

    return rms_errors


def compare_results(csv_list, title, save_directory):
    """Compare csv files containing processed kinematics or kinetics data.

    Each file can be a model or an experiment file. Which one is decided from
    the name of the directory that directly contains it: "Data" in the name
    means experiment, "Results" means model; otherwise the user is asked.

    Every file is restructured into an array of time steps and an array of
    data whose rows are
        0,1,2 = x,y,z translations/forces
        3,4,5 = x,y,z rotations/moments
    All files are then plotted on one figure (``<title>.png``) and the RMS
    error of every channel, comparing the first csv with each other csv, is
    written to ``<title>.xml``; both go into ``save_directory``.

    Parameters
    ----------
    csv_list : list of str
        Paths of the csv files; the first one is the reference for the errors.
    title : str
        Figure title and base name of the two output files.
    save_directory : str
        Directory the outputs are written to.

    Raises
    ------
    ValueError
        If the user's answer for the data type (m/e) or the knee side (l/r) is
        not recognised. Previously these cases failed later with a KeyError /
        NameError.
    """

    def read_csv_any_encoding(path):
        # the encoding sometimes changes on windows/linux depending on the file:
        # try utf7 first and fall back to utf8
        try:
            return pd.read_csv(path, encoding='utf7', na_values='nan')
        except Exception:
            return pd.read_csv(path, encoding='utf8', na_values='nan')

    all_data = {}
    data_names = []

    for i, c in enumerate(csv_list):
        directory = os.path.dirname(c)
        immediate_directory = os.path.basename(os.path.normpath(directory))
        if 'Data' in immediate_directory:
            data_kind = "e"
        elif 'Results' in immediate_directory:
            data_kind = "m"
        else:
            print("could not recognize if this file contains experimental data or model data.")
            print("file name: {}".format(c))
            data_kind = input(" is this model or experiment data? type m for model, e for experiment:")
            data_kind = data_kind.strip().lower()

        if data_kind == "e":
            df_experiment = read_csv_any_encoding(c)
            print("experiment data read")
            all_data[c] = restructure_experiment_data(df_experiment)
            data_names.append("experiment_{}".format(i))
        elif data_kind == "m":
            df_model = read_csv_any_encoding(c)
            print("model data read")
            all_data[c] = restructure_model_data(df_model)
            data_names.append("model_{}".format(i))
        else:
            raise ValueError("expected 'm' or 'e' for file {}, got {!r}".format(c, data_kind))

    # get all the strings in the csv list
    all_strings = "-".join(csv_list).lower()

    # if plotting bone kinetics, column names will be:
    if 'kinetics' in all_strings:
        # NOTE: forces and moments share the labels x,y,z, so in the xml written below
        # they are only distinguished by their order (forces first, then moments)
        column_names = ["x", "y", "z", "x", "y", "z"]
    else:
        # search for the specimen name to decide if right or left knee, if not found ask the user
        if 'oks003' in all_strings:
            left_or_right = "l"
        elif 'du02' in all_strings:
            left_or_right = "r"
        else:
            print("couldn't find name oks003 or du02 in file names, to check for left or right knee")
            left_or_right = input("is this a left or right knee? type l for left, r for right")
            left_or_right = left_or_right.strip().lower()

        if left_or_right == "l":
            column_names = ["medial", "anterior", "superior", "extension", "valgus", "external"]
        elif left_or_right == "r":
            column_names = ["lateral", "anterior", "superior", "extension", "varus", "internal"]
        else:
            raise ValueError("could not understand user input")

    # plot the data on the same figure
    plot_comaprison(csv_list, all_data, data_names, title, save_directory, column_names)

    # get the rms error along each channel, comparing the first csv in the list with each other csv in the list
    rms_errors = get_rms_errors(csv_list, all_data)

    # store the rms errors in an xml file
    ModelPredictionError = et.Element("ModelPredictionError")
    for other_csv, rms_error in zip(csv_list[1:], rms_errors):
        Compare = et.SubElement(ModelPredictionError, "Compare")
        file_1 = et.SubElement(Compare, "file_1")
        file_1.text = str(csv_list[0])
        file_2 = et.SubElement(Compare, "file_2")
        file_2.text = str(other_csv)

        RMS_Error = et.SubElement(Compare, "RMS_Error")
        for i, channel in enumerate(column_names):
            err = et.SubElement(RMS_Error, channel)
            err.text = str(rms_error[i])

    # save the xml in the save directory
    ModelPredictionError_root = et.ElementTree(ModelPredictionError)
    filename = os.path.join(save_directory, title + ".xml")
    write_file(ModelPredictionError_root, filename)

def from_xml(xml_file):
    """Run the kinematics comparison for every model listed in an xml file.

    The xml root must contain a ``general_files`` element with ``febio_file``
    and ``processed_data_directory`` children, and a ``Models`` element whose
    children each carry a ``name`` attribute and a ``kinematics_csv`` child.
    Comment nodes under ``Models`` are skipped.

    For each model the experiment csv is looked up in the processed data
    directory, the model csv is expected at
    ``<directory of febio_file>/Processed_Results_<name>/Tibiofemoral_Kinematics.csv``,
    and the outputs of ``compare_results`` are written to that same results
    directory. A failing comparison is reported and the next model is processed.
    """

    file_tree = et.parse(xml_file)
    file_info = file_tree.getroot()

    gen_files = file_info.find("general_files")
    feb_file = gen_files.find("febio_file").text
    processed_data_folder = gen_files.find("processed_data_directory").text

    models_dirname = os.path.dirname(feb_file)

    models = file_info.find("Models")
    for mod in models:
        if mod.tag is et.Comment:
            continue
        model_name = mod.attrib["name"]

        # get the experiment csv files

        # exp_femur_kinetics_csv = os.path.join(processed_data_folder, mod.find("kinetics_csv").text)
        exp_kinematics_csv = os.path.join(processed_data_folder, mod.find("kinematics_csv").text)
        # exp_tibia_kinetcs_csv = exp_femur_kinetics_csv.replace("kinetics_in_TibiaCS", "kinetics_in_ImageCS")

        model_results_dir = os.path.join(models_dirname, 'Processed_Results_' + model_name)

        model_kinematics_csv = os.path.join(model_results_dir, "Tibiofemoral_Kinematics.csv")

        # model_tibia_kinetics_csv = os.path.join(model_results_dir, "Femur Kinetics in Image CS.csv")

        try:
            compare_results([exp_kinematics_csv, model_kinematics_csv], "JCS_Kinematics_Prediction_Errors", model_results_dir)
        except Exception as err:
            # keep going with the remaining models, but say why this one failed.
            # Exception (rather than a bare except) lets Ctrl-C / SystemExit stop the run.
            print("failed comparison of kinematics for model " + model_name)
            print("    reason: {}: {}".format(type(err).__name__, err))

        # try:
        #     compare_results([exp_tibia_kinetcs_csv, model_tibia_kinetics_csv], "Tibia_Kinetics_Prediction_Errors", model_results_dir)
        # except:
        #     print("failed comparison of kinetics for model " + model_name)

def make_image(c, save_directory=r"C:\Users\schwara2\Documents\calibration_paper\presentation\GIF"):
    """Save a sequence of model-kinematics figures, each highlighting one time point.

    The model csv ``c`` is read and restructured with ``restructure_model_data``.
    For each of six fixed simulation times (2, 2.22, 2.44, 2.67, 2.89, 3.0,
    shifted by -2.0 to match the restructured time axis) a two-panel figure of
    translations and rotations is drawn with a marker on every curve at that
    time, and saved as ``Model_Kinematics_<k>.png``.

    Parameters
    ----------
    c : str
        Path of the model kinematics csv.
    save_directory : str, optional
        Directory the images are written to. Defaults to the location that was
        previously hard-coded.
    """
    # the encoding sometimes differs between files: try utf7 first, then utf8
    try:
        df_model = pd.read_csv(c, encoding='utf7', na_values='nan')
        print("model data read")
    except Exception:
        df_model = pd.read_csv(c, encoding='utf8', na_values='nan')
        print("model data read")

    model_time_steps, model_data, model_data_names = restructure_model_data(df_model)
    column_names = ["medial", "anterior", "superior", "extension", "valgus", "external"]

    points_list = np.array([2, 2.22, 2.44, 2.67, 2.89, 3.0])
    points_list -= 2.0

    line_colors = ['r', 'b', 'g']
    time_steps = model_time_steps
    data = model_data

    for e, p in enumerate(points_list):

        # Plot the data on the same graph
        fig = plt.figure(figsize=(12.8, 7.2))
        fig.suptitle("Model Kinematics", fontsize=18)

        # tolerant comparison instead of exact float equality; using the matched
        # time values as x keeps x and y the same length even when nothing matches
        idx = np.where(np.isclose(time_steps, p))[0]
        marker_times = time_steps[idx]

        plt.subplot(1, 2, 1)
        for i, n in enumerate(column_names[0:3]):
            plt.plot(time_steps, data[i], line_colors[i] + '-', label=n)
            plt.plot(marker_times, data[i][idx], 'o', color=line_colors[i])
        plt.legend(loc='best')
        plt.xlabel('Simulation Time', fontsize=14)
        plt.title('Translations', fontsize=14)
        plt.ylabel('mm', fontsize=14)

        plt.subplot(1, 2, 2)
        for i, n in enumerate(column_names[3:]):
            plt.plot(time_steps, data[i + 3], line_colors[i] + '-', label=n)
            plt.plot(marker_times, data[i + 3][idx], 'o', color=line_colors[i])
        plt.legend(loc='best')
        plt.xlabel('Simulation Time', fontsize=14)
        plt.title('Rotations', fontsize=14)
        plt.ylabel('deg', fontsize=14)

        # plt.show()
        save_as = os.path.join(save_directory, "Model_Kinematics_{}".format(e))
        plt.savefig(save_as + '.png')
        # plt.savefig(save_as+'.svg', format='svg')
        plt.close()




if __name__ == '__main__':

    # compare_results(*sys.argv[1:])

    # if you want to compare a bunch of models that were created using an exp_to_mod xml file, first run log post processing on all
    # and then just use this function and give the xml as input
    # (raw strings are used for the windows paths so the backslashes are not read as escape sequences)
    from_xml(r"C:\Users\schwara2\Documents\Open_Knees\du02_calibration\CustomizedFullModels\ExperimentalLoading03\Exp_to_Mod.xml")

    # # make image for presentation
    # mod = r"C:\Users\schwara2\Documents\Open_Knees\oks003_calibration\CustomizedFullModels\ExperimentalLoading\F90_Processed_Results\Tibiofemoral_Kinematics.csv"
    # make_image(mod)

    # below are examples of how to call the compare_results function.

    # #ACL
    # exp_p = r"C:\Users\schwara2\Documents\Open_Knees\du02_calibration\DataProcessing\Processed_Data\Laxity_9deg_AP1_kinematics_in_JCS.csv"
    # mod_p = r"C:\Users\schwara2\Documents\Open_Knees\du02_calibration\InSituStrain\Spring_Ties7\AnteriorLaxity\Processed_Results\Tibiofemoral_Kinematics.csv"
    # save_dir_p  = r"C:\Users\schwara2\Documents\Open_Knees\du02_calibration\InSituStrain\Spring_Ties7\AnteriorLaxity\Processed_Results"
    # compare_results([exp_p, mod_p], "predicted_kinematics_error", save_dir_p)

    # #PCL
    # exp_p = r"C:\Users\schwara2\Documents\Open_Knees\du02_calibration\DataProcessing\Processed_Data\Laxity_7deg_AP2_kinematics_in_JCS.csv"
    # mod_p = r"C:\Users\schwara2\Documents\Open_Knees\du02_calibration\InSituStrain\Spring_Ties7\PosteriorLaxity\Processed_Results\Tibiofemoral_Kinematics.csv"
    # save_dir_p  = r"C:\Users\schwara2\Documents\Open_Knees\du02_calibration\InSituStrain\Spring_Ties7\PosteriorLaxity\Processed_Results"
    # compare_results([exp_p, mod_p], "predicted_kinematics_error", save_dir_p)

    # #LCL
    # exp_p = r"C:\Users\schwara2\Documents\Open_Knees\du02_calibration\DataProcessing\Processed_Data\Laxity_11deg_VV2_kinematics_in_JCS.csv"
    # mod_p = r"C:\Users\schwara2\Documents\Open_Knees\du02_calibration\InSituStrain\Spring_Ties7\VarusLaxity\Processed_Results\Tibiofemoral_Kinematics.csv"
    # save_dir_p  = r"C:\Users\schwara2\Documents\Open_Knees\du02_calibration\InSituStrain\Spring_Ties7\VarusLaxity\Processed_Results"
    # compare_results([exp_p, mod_p], "predicted_kinematics_error", save_dir_p)

    # #MCL
    # exp_p = r"C:\Users\schwara2\Documents\Open_Knees\du02_calibration\DataProcessing\Processed_Data\Laxity_9deg_VV1_kinematics_in_JCS.csv"
    # mod_p = r"C:\Users\schwara2\Documents\Open_Knees\du02_calibration\InSituStrain\Spring_Ties7\ValgusLaxity\Processed_Results\Tibiofemoral_Kinematics.csv"
    # save_dir_p  = r"C:\Users\schwara2\Documents\Open_Knees\du02_calibration\InSituStrain\Spring_Ties7\ValgusLaxity\Processed_Results"
    # compare_results([exp_p, mod_p], "predicted_kinematics_error", save_dir_p)
