import sys
import numpy as np
import os
import pandas as pd
import copy
from lxml import etree as et
import matplotlib.pyplot as plt
import csv
from matplotlib import gridspec
import LogPostProcessing

def experiment_data(experiment_csv):
    """Load a processed experiment CSV and return its time base, data and channel names.

    The file is read with UTF-7 first and, if that fails, with UTF-8. Every
    column is converted to float.

    * If "ImageCS" appears in the path, column 0 is dropped and the time base
      is the row index normalized to [0, 1].
    * Otherwise column 0 is dropped, column 1 is returned as the time base
      (for laxity files this is the applied load) and the remaining columns
      are the data channels.

    Channels are converted so that plotting uses N, Nm, mm and deg: headers
    containing "[m]"/"(m)" are scaled by 1000, "[rad]"/"(rad)" are converted
    to degrees and "[Nmm]"/"(Nmm)" are divided by 1000.

    Parameters
    ----------
    experiment_csv : str
        Path to the CSV file.

    Returns
    -------
    experiment_time_steps : numpy.ndarray
        1-D time base (see above).
    experiment_data : numpy.ndarray
        2-D float array with one row per data channel.
    names : list of str
        Six fixed channel names: loads/moments when "kinetics" is in the path
        (case-insensitive), translations/rotations otherwise.
    """
    try:
        df_experiment = pd.read_csv(experiment_csv, encoding='utf7', na_values='nan')
    except (ValueError, LookupError):
        # UnicodeDecodeError and pandas parser errors are ValueError subclasses;
        # retry the read as UTF-8 in that case.
        df_experiment = pd.read_csv(experiment_csv, encoding='utf8', na_values='nan')
    print("experiment data read")

    columns = [np.asarray(list(map(float, df_experiment[col].values)))
               for col in df_experiment.columns]

    if "ImageCS" in experiment_csv:
        columns = columns[1:]
        data_names = df_experiment.columns[1:]

        # Sample index normalized to [0, 1]. With a single row there is
        # nothing to normalize by (0 / 0 would give NaN), so it stays 0.
        n_rows = len(columns[0])
        experiment_time_steps = np.arange(n_rows, dtype=float)
        if n_rows > 1:
            experiment_time_steps /= n_rows - 1
    else:
        # Column 1 is used directly as the time base (the applied load).
        experiment_time_steps = columns[1]
        columns = columns[2:]
        data_names = df_experiment.columns[2:]

    data = np.asarray(columns)

    # Bring every channel to N, Nm, mm or deg for plotting.
    for i, name in enumerate(data_names):
        if "[m]" in name or "(m)" in name:
            data[i] = 1000 * data[i]
        elif "[rad]" in name or "(rad)" in name:
            data[i] = np.degrees(data[i])
        elif "[Nmm]" in name or "(Nmm)" in name:
            data[i] = data[i] / 1000

    if "kinetics" in experiment_csv.lower():
        names = ["Medial_Load", "Anterior_Load", "Superior_Load", "Extension_Moment",
                 "Valgus_Moment", "External_Moment"]
    else:
        names = ["Medial_Translation", "Anterior_Translation", "Superior_Translation",
                 "Extension_Rotation", "Valgus_Rotation", "External_Rotation"]

    return experiment_time_steps, data, names

def read_raw_data(csv):
    """Load a raw-data CSV and split it into time base, channels and channel names.

    The file is read with UTF-7 first and, if that fails, with UTF-8. Every
    column is converted to float. Column 0 is ignored, column 1 is the time
    base and all remaining columns are data channels.

    Parameters
    ----------
    csv : str
        Path to the CSV file. (The name shadows the ``csv`` module inside this
        function; it is kept for compatibility with keyword callers.)

    Returns
    -------
    time_steps : numpy.ndarray
        1-D float array taken from column 1.
    raw_data : numpy.ndarray
        2-D float array with one row per data column (columns 2 onward).
    data_names : pandas.Index
        Headers of the data columns.
    """
    try:
        df_experiment = pd.read_csv(csv, encoding='utf7', na_values='nan')
    except (ValueError, LookupError):
        # UnicodeDecodeError and pandas parser errors are ValueError subclasses;
        # retry the read as UTF-8 in that case.
        df_experiment = pd.read_csv(csv, encoding='utf8', na_values='nan')
    print("experiment data read")

    columns = [np.asarray(list(map(float, df_experiment[col].values)))
               for col in df_experiment.columns]

    time_steps = columns[1]
    raw_data = np.asarray(columns[2:])
    data_names = df_experiment.columns[2:]

    return time_steps, raw_data, data_names

def Raw_Data(csv_file, title):
    """Load a raw csv file and plot it as two side-by-side panels.

    Channels 0-2 returned by ``read_raw_data`` go on the left panel and
    channels 3-5 on the right panel, each against the file's time column
    (assumes at least six data channels). The figure is saved to ``title``
    (used as the output file name) and then shown.

    Parameters
    ----------
    csv_file : str
        Path to the raw CSV. "kinetics" in the path (case-insensitive) selects
        load/torque panel titles, otherwise translation/rotation.
    title : str
        Figure title and file name passed to ``fig.savefig``.
    """
    time, data, names = read_raw_data(csv_file)

    if "kinetics" in csv_file.lower():
        panel_titles = ('Load [N]', 'Torque [Nm]')
    else:
        panel_titles = ('Translation [mm]', 'Rotation [deg]')

    fig = plt.figure(figsize=(12, 8))
    fig.suptitle(title, fontsize=16)

    trans_ax = plt.subplot2grid((1, 2), (0, 0))
    rot_ax = plt.subplot2grid((1, 2), (0, 1))

    colors = ['r', 'b', 'g', 'c', 'm', 'y']

    panels = ((trans_ax, range(3), panel_titles[0]),
              (rot_ax, range(3, 6), panel_titles[1]))
    for ax, channels, panel_title in panels:
        for a in channels:
            ax.plot(time, data[a], label=names[a], color=colors[a])
        ax.grid(True)
        ax.set_xlabel('Time (ms)')
        ax.set_title(panel_title)
        ax.legend(loc='best')

    fig.savefig(title)

    plt.show()

def Processed_PF(csv_file, title):
    """Load a processed passive-flexion csv file and plot it as two side-by-side panels.

    Data come from ``experiment_data``. Samples where the first channel is
    NaN are removed from the time base and from every channel. Channels 0-2
    go on the left panel and channels 3-5 on the right panel, each drawn as
    markers joined by a line. X tick labels are hidden. The figure is saved to
    ``title`` (used as the output file name) and then shown.

    Parameters
    ----------
    csv_file : str
        Path to the processed CSV. "kinetics" in the path (case-insensitive)
        selects load/torque panel titles, otherwise translation/rotation.
    title : str
        Figure title and file name passed to ``fig.savefig``.
    """
    time, data, names = experiment_data(csv_file)

    # Keep only samples where the first channel is defined, in time and in all channels.
    valid = ~np.isnan(data[0])
    data = data[:, valid]
    time = time[valid]

    if "kinetics" in csv_file.lower():
        panel_titles = ('Load [N]', 'Torque [Nm]')
    else:
        panel_titles = ('Translation [mm]', 'Rotation [deg]')

    fig = plt.figure(figsize=(12, 8))
    fig.suptitle(title, fontsize=16)

    trans_ax = plt.subplot2grid((1, 2), (0, 0))
    rot_ax = plt.subplot2grid((1, 2), (0, 1))

    colors = ['r', 'b', 'g', 'c', 'm', 'y']

    panels = ((trans_ax, range(3), panel_titles[0]),
              (rot_ax, range(3, 6), panel_titles[1]))
    for ax, channels, panel_title in panels:
        for a in channels:
            # Labelled markers (used by the legend) plus an unlabelled connecting line.
            ax.plot(time, data[a], 'o', label=names[a], color=colors[a])
            ax.plot(time, data[a], color=colors[a])
        ax.grid(True)
        ax.set_xticklabels([])
        ax.set_title(panel_title)
        ax.legend(loc='best')

    fig.savefig(title)

    plt.show()

def Processed_Laxity(files, title):
    """Load the csv files from oks processed laxity data and make graphs.

    files = [A, P, VR, VL, IR, ER]

    The figure is a 2 x 6 grid with one column per loading direction, in the
    order of ``files``. The top row shows channels 0-2 and the bottom row
    channels 3-5 returned by ``experiment_data``. Each is plotted against
    that file's applied-load column, which is sign-flipped when its mean is
    negative because the direction is already named in the x label. A single
    legend, built from the last column, is added. The figure is saved to
    ``title`` with a tight bounding box and then shown.

    Parameters
    ----------
    files : sequence of str
        Six processed CSV paths, ordered A, P, VR, VL, IR, ER. The y-axis
        labels are chosen from ``files[0]``: "kinematics" in its path
        (case-insensitive) selects translation/rotation, otherwise load/moment.
    title : str
        Figure title and file name passed to ``fig.savefig``.
    """
    fig = plt.figure(figsize=(20, 5))

    # (top, bottom) axes for each grid column, created top-then-bottom and
    # column by column, so the last axes created is the bottom-right one.
    all_axes = [(plt.subplot2grid((2, 6), (0, col)), plt.subplot2grid((2, 6), (1, col)))
                for col in range(6)]

    first_trans_ax, first_rot_ax = all_axes[0]
    if "kinematics" in files[0].lower():
        first_trans_ax.set_ylabel('Translation [mm]', fontsize=12)
        first_rot_ax.set_ylabel('Rotation [deg]', fontsize=12)
    else:
        first_trans_ax.set_ylabel('Load [N]', fontsize=12)
        first_rot_ax.set_ylabel('Moment [Nm]', fontsize=12)

    loading = ['Anterior Load [N]', 'Posterior Load [N]', 'Varus Torque [Nm]',
               'Valgus Torque [Nm]', 'Internal Torque [Nm]', 'External Torque [Nm]']

    colors = ['r', 'b', 'g', 'c', 'm', 'y']

    for i, (trans_ax, rot_ax) in enumerate(all_axes):  # one loading direction per column
        applied_load, data, names = experiment_data(files[i])

        if np.nanmean(applied_load) < 0:
            # The direction is stated in the x label, so plot the load as positive.
            applied_load = -applied_load

        for a in range(6):
            ax = trans_ax if a < 3 else rot_ax
            ax.plot(applied_load, data[a], 'o', label=names[a], markersize=3, color=colors[a])
            ax.plot(applied_load, data[a], color=colors[a])

        trans_ax.grid(True)
        rot_ax.grid(True)
        trans_ax.set_xticklabels([])  # the bottom panel carries the x labels
        rot_ax.set_xlabel(loading[i], fontsize=12)

    # One legend for the whole figure: entries from the last column's two axes,
    # attached to the bottom-right axes and anchored outside it.
    last_trans_ax, last_rot_ax = all_axes[-1]
    t_handles, t_labels = last_trans_ax.get_legend_handles_labels()
    r_handles, r_labels = last_rot_ax.get_legend_handles_labels()
    lgd = last_rot_ax.legend(t_handles + r_handles, t_labels + r_labels,
                             loc='upper right', bbox_to_anchor=(2.0, 2.0))

    ttl = fig.suptitle(title, fontsize=16)

    # The legend and title lie outside the axes; include them in the saved image.
    fig.savefig(title, bbox_extra_artists=(lgd, ttl), bbox_inches='tight')

    plt.show()


def make_graphs(data_dir=None):
    """Produce the oks003 calibration plots.

    Only the passive-flexion kinetics (TibiaCS) plot is currently enabled; the
    other plotting jobs are kept below as commented-out templates and can be
    enabled by uncommenting them.

    Parameters
    ----------
    data_dir : str, optional
        Folder holding the processed CSV files. Defaults to the original
        author's local Processed_Data folder.
    """
    if data_dir is None:
        # Raw string so the Windows backslashes are not read as escape sequences.
        data_dir = r"C:\Users\schwara2\Documents\Open_Knees\oks003_calibration\DataProcessing\Processed_Data"

    def data_file(name):
        # Full path of a CSV file inside data_dir.
        return os.path.join(data_dir, name)

    # # 0 deg laxity kinematics; order is [A, P, VR, VL, IR, ER]
    # files = [data_file(n) for n in ("Laxity_0deg_AP1_kinematics_in_JCS.csv",
    #                                 "Laxity_0deg_AP2_kinematics_in_JCS.csv",
    #                                 "Laxity_0deg_VV1_kinematics_in_JCS.csv",
    #                                 "Laxity_0deg_VV2_kinematics_in_JCS.csv",
    #                                 "Laxity_0deg_EI2_kinematics_in_JCS.csv",
    #                                 "Laxity_0deg_EI1_kinematics_in_JCS.csv")]
    # Processed_Laxity(files, "Processed 0 deg Laxity Kinematics")

    # # 0 deg laxity kinetics; order is [A, P, VR, VL, IR, ER]
    # files = [data_file(n) for n in ("Laxity_0deg_AP1_TibiaKinetics_in_TibiaCS.csv",
    #                                 "Laxity_0deg_AP2_TibiaKinetics_in_TibiaCS.csv",
    #                                 "Laxity_0deg_VV1_TibiaKinetics_in_TibiaCS.csv",
    #                                 "Laxity_0deg_VV2_TibiaKinetics_in_TibiaCS.csv",
    #                                 "Laxity_0deg_EI2_TibiaKinetics_in_TibiaCS.csv",
    #                                 "Laxity_0deg_EI1_TibiaKinetics_in_TibiaCS.csv")]
    # Processed_Laxity(files, "Processed 0 deg Laxity Kinetics")

    # # 0 deg laxity raw kinematics
    # Raw_Data(data_file("Laxity_0deg_Kinematics_raw.csv"), "Raw 0 deg Laxity Kinematics")

    # # 0 deg laxity raw kinetics
    # Raw_Data(data_file("Laxity_0deg_Kinetics_raw.csv"), "Raw 0 deg Laxity Kinetics")

    # # passive flexion raw kinematics
    # Raw_Data(data_file("Passive_Flexion_Kinematics_raw.csv"), "Raw Passive Flexion Kinematics")

    # # passive flexion processed kinematics
    # Processed_PF(data_file("Passive_Flexion_Kinematics_in_JCS.csv"),
    #              "Processed Passive Flexion Kinematics")

    # # passive flexion raw kinetics
    # Raw_Data(data_file("Passive_Flexion_Kinetics_raw.csv"), "Raw Passive Flexion Kinetics")

    # # passive flexion processed kinetics - image CS
    # Processed_PF(data_file("Passive_Flexion_Kinetics_in_ImageCS.csv"),
    #              "Processed Passive Flexion Kinetics")

    # passive flexion processed kinetics - tibia CS
    processed_pf = data_file("Passive_Flexion_Kinetics_in_TibiaCS.csv")
    Processed_PF(processed_pf, "Processed Passive Flexion Kinetics")


if __name__ == "__main__":
    # Script entry point: generate whichever plots make_graphs() has enabled.
    make_graphs()