# for processing DU02 data for the calibration phase
import os
import pandas as pd
import numpy as np
# import matplotlib.pyplot as plt
from scipy import signal
# from itertools import groupby
# from operator import itemgetter
import tdms_processing_oks as dp
import copy
import matplotlib.pyplot as plt

class Group:
    """container for a named set of data channels, along with their units, time arrays and a processed flag"""

    def __init__(self, name, channels, data, units, time, processed=False):
        self.name = name
        self.channels = channels
        self.data = data
        self.units = units
        self.time = time
        self.processed = processed


def butter_lowpass_filter(data, cutoff, fs, order):
    nyq = 0.5 * fs
    normal_cutoff = cutoff / nyq
    # Get the filter coefficients
    b, a = signal.butter(order, normal_cutoff)
    y = signal.filtfilt(b, a, data)

    return y
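
# Example usage (illustrative sketch; the signal and sampling rate are made-up values, not experiment data):
#   t = np.arange(0, 10, 1.0/50.0)                                     # 10 s of data sampled at 50 Hz
#   noisy = np.sin(2*np.pi*0.2*t) + 0.3*np.random.randn(len(t))        # slow sine plus noise
#   smooth = butter_lowpass_filter(noisy, cutoff=1, fs=50.0, order=2)  # same settings used in apply_time_shift below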


def apply_time_shift(kinetics_group, kinematics_group, kinetics_channel, kinematics_channel, sample_start_percent=0.0, cut_percent=0.1):
    """ find the time shift between the kinetics and kinematics data and crop both groups to align them. by default
     the first 10% of the data is used, but the starting point (sample_start_percent) and the amount of data used
     (cut_percent) can be customized to avoid bad data"""

    # e.g. for passive flexion, the SI force is used for kinetics and the flexion angle for kinematics;
    # for the laxity tests the applied load and the matching kinematic channel are used
    kinetics = kinetics_group.data[kinetics_channel]
    kinematics = kinematics_group.data[kinematics_channel]

    # filter the data to remove noise
    kinetics_filtered = butter_lowpass_filter(kinetics, cutoff=1, fs=50.0, order=2)
    kinematics_filtered  = butter_lowpass_filter(kinematics, cutoff=1, fs=50.0, order=2)

    # take only 10% (default) of data to compare extrema, starting from the start percent
    cut_idx = int(cut_percent*len(kinetics_filtered))
    start_idx = int(sample_start_percent*len(kinetics_filtered))
    kinetics_filtered = kinetics_filtered[start_idx:cut_idx+start_idx]
    kinematics_filtered = kinematics_filtered[start_idx:cut_idx+start_idx]

    # plt.plot(kinematics_filtered)
    # plt.plot(kinetics_filtered)
    # plt.show()

    # find the local extrema for both data
    idx_locmax_kinematics = signal.argrelextrema(kinematics_filtered, np.greater, order=100)[0]
    idx_locmin_kinematics = signal.argrelextrema(kinematics_filtered, np.less, order=100)[0]
    idx_locmin_kinetics = signal.argrelextrema(kinetics_filtered, np.less, order=100)[0]
    idx_locmax_kinetics = signal.argrelextrema(kinetics_filtered, np.greater, order=100)[0]

    idx_extrema_kinetics = np.sort(np.concatenate((idx_locmax_kinetics, idx_locmin_kinetics)))
    idx_extrema_kinematics = np.sort(np.concatenate((idx_locmax_kinematics, idx_locmin_kinematics)))

    # note that they may not be the same length; an endpoint may be included in one and not the other

    # create a 2D array of the index differences between all pairs of local extrema (kinematics x kinetics),
    # where entry [i, j] = idx_extrema_kinetics[j] - idx_extrema_kinematics[i]
    index_difs = idx_extrema_kinetics[np.newaxis, :] - idx_extrema_kinematics[:, np.newaxis]

    min_cols = np.argmin(np.abs(index_difs), axis=1)

    min_difs = []
    for i in range(len(idx_extrema_kinematics)):
        dif = index_difs[i, min_cols[i]]
        min_difs.append(dif)

    min_difs = np.asarray(min_difs)
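
    # Illustrative example (assumed numbers, not taken from the data): if the kinetics extrema fall at samples
    # [100, 250, 400] and the matching kinematics extrema at [103, 253, 403], each nearest-pair difference is -3,
    # i.e. the kinematics signal lags the kinetics signal by about 3 samples, so below the first 3 kinematics
    # samples are dropped (and the last 3 kinetics samples removed) to align the two groups.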
    # remove any outliers - likely they are poorly matched pairs, or "ends"

    # start by removing ones that we know are in the wrong direction.

    where_pos = np.where(min_difs >=0)[0]
    where_neg = np.where(min_difs < 0)[0]
    pos = min_difs[where_pos]
    neg = min_difs[where_neg]
    avg_pos = np.average(pos)
    avg_neg = np.average(neg)
    num_pos = len(pos)
    num_neg = len(neg)

    # if the average shift is too close to zero, skip this step because the shift may be straddling zero
    if num_pos > num_neg and avg_pos > 5:
        min_difs = np.delete(min_difs, where_neg)
    elif num_neg > num_pos and avg_neg < -5:
        min_difs = np.delete(min_difs, where_pos)

    # filter out outliers (keep values within m standard deviations of the mean)
    m = 2
    std = np.std(min_difs)
    mean = np.mean(min_difs)
    min_difs_reduced = min_difs[abs(min_difs - mean) < (m * std)]

    # get the average shift
    time_shift = int(np.average(min_difs_reduced))

    if time_shift > 0:
        # kinetics lags kinematics: drop the first samples of kinetics and the trailing samples of kinematics
        kinematics_shifted = kinematics_group.data[:, 0:len(kinematics) - time_shift]
        kinetics_shifted = kinetics_group.data[:, time_shift:]
    else:
        # kinematics lags kinetics: drop the first samples of kinematics and the trailing samples of kinetics
        time_shift = -time_shift
        kinetics_shifted = kinetics_group.data[:, 0:len(kinetics) - time_shift]
        kinematics_shifted = kinematics_group.data[:, time_shift:]

    # plt.plot(kinematics_shifted[kinematics_channel])
    # plt.plot(kinetics_shifted[kinetics_channel])
    # plt.show()

    kinetics_group.data = kinetics_shifted
    kinetics_group.time = np.array([range(len(kinetics_shifted[0]))]*6)
    kinetics_group.processed = True

    kinematics_group.data = kinematics_shifted
    kinematics_group.time  = np.array([range(len(kinematics_shifted[0]))]*6)
    kinematics_group.processed = True

    return time_shift
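
# Example usage (illustrative sketch; 'du02_AP_laxity.csv' is a hypothetical file name):
#   kinetics_group, kinematics_group = extract_kinetics_kinematics('du02_AP_laxity.csv')
#   shift = apply_time_shift(kinetics_group, kinematics_group, kinetics_channel=1, kinematics_channel=1)
#   print('AP time shift (samples):', shift)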


def extract_kinetics_kinematics(file):
    """read the kinetics and kinematics channels from a calibration csv file and return them as two Group objects"""

    df = pd.read_csv(file, encoding='utf7')

    kinetics_channels = ["Force TF ML (N)","Force TF AP (N)","Force TF SI (N)","Torque TF FE (Nmm)","Torque TF VV (Nmm)","Torque TF IE (Nmm)"]
    kinetics_units = ['N','N','N','Nmm','Nmm','Nmm']
    kinetics_data = []

    for i in kinetics_channels:
        kinetics_data.append(list(map(float, df[i].values)))

    kinetics_data = np.asarray(kinetics_data)

    Kinetics_Group = Group("Kinetics", kinetics_channels, kinetics_data, kinetics_units, time=np.array([range(len(kinetics_data[0]))]*6))


    kinematics_channels = ["TF ML (mm)","TF AP (mm)","TF SI (mm)","TF FE (deg)","TF VV (deg)", "TF IE (deg)"]
    kinematics_units = ['mm','mm','mm','deg','deg','deg']
    kinematics_data = []

    for i in kinematics_channels:
        kinematics_data.append(list(map(float,df[i].values)))

    kinematics_data = np.asarray(kinematics_data)

    Kinematics_Group = Group("Kinematics", kinematics_channels, kinematics_data, kinematics_units,
                           time=np.array([range(len(kinematics_data[0]))]*6))

    return Kinetics_Group, Kinematics_Group
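
# Example (illustrative; 'du02_passive_flexion.csv' is a hypothetical file name):
#   kinetics, kinematics = extract_kinetics_kinematics('du02_passive_flexion.csv')
#   print(kinetics.channels[2], kinetics.units[2])      # 'Force TF SI (N)', 'N'
#   print(kinematics.channels[3], kinematics.units[3])  # 'TF FE (deg)', 'deg'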


def change_kinematics_reporting(group):
    """report kinematics as cylindrical joint translations and rotations- they are initially given as clinical joint translations """

    # in a right knee we want the data to be represented as cylindrical joint rotations and translations with the following
    # positive directions: lateral, anterior, superior, extension, varus/adduction, internal rotation


    data = np.asarray(group.data)
    channels = group.channels

    updated_data = copy.deepcopy(data)
    updated_channels = copy.deepcopy(channels)

    # data is initially given using Grood and Suntay clinical translations
    q1 = data[0]  # +lat
    q2 = data[1]  # +ant tibia
    q3 = -data[2]  # negated because Don said this is +sup tibia, and Grood and Suntay describe joint distraction as q3
    # alpha = data[3]  # +flex
    beta = np.pi/2 - np.radians(data[4])  # Don described this as +val; Grood and Suntay describe beta as pi/2 + adduction
    # gamma = data[5]  # external

    # convert these to joint translations using equation 5 from Grood and Suntay
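    # The expressions below invert the relation q1 = S1 + S3*cos(beta), q2 = S2, q3 = -(S3 + S1*cos(beta)):
    # solving the first and third equations for S1 and S3 gives
    #   S1 = (q1 + q3*cos(beta)) / (1 - cos(beta)**2)  and  S3 = -q3 - S1*cos(beta)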
    S1 = np.divide(q1 + np.multiply(q3,np.cos(beta)),(1-np.square(np.cos(beta))))
    S2 = q2
    S3 = -q3 - np.multiply(S1, np.cos(beta))

    # S1, S2 and S3 should now be described in the same directions as our model CS, because G&S uses those.

    # assigning the names and directions of the channels to align with our JCS definitions
    updated_data[0] = S1
    updated_channels[0] = 'Knee JCS Lateral Translation'

    updated_data[1] = S2
    updated_channels[1] = 'Knee JCS Anterior Translation'

    updated_data[2] = S3
    updated_channels[2] = 'Knee JCS Superior Translation'

    updated_data[3] = -data[3]  # extension
    updated_channels[3] = 'Knee JCS Extension Rotation'

    updated_data[4] = -data[4] #adduction/varus
    updated_channels[4] = 'Knee JCS Adduction Rotation'

    updated_data[5] = -data[5] #internal rotation
    updated_channels[5] = 'Knee JCS Internal Rotation'

    group.data = updated_data
    group.channels = updated_channels
    group.processed = True


def change_kinetics_reporting(group):
    """flip the channels which are reported opposite to the model."""

    # in a right knee we want the data to be represented as external forces and moments applied to the tibia in the tibia CS,
    # with the following positive directions: lateral (x), anterior (y), superior (z), extension, varus/adduction, internal rotation

    data = np.asarray(group.data)
    channels = group.channels

    updated_data = copy.deepcopy(data)
    updated_channels = copy.deepcopy(channels)

    # changing the names and directions of the channels to align with our JCS definitions. everything is labeled as tibia motion relative to femur

    # experiment - lateral
    # model - lateral
    updated_channels[0] = 'External Tibia_x Load'

    # experiment - anterior
    # model - anterior
    updated_channels[1] = 'External Tibia_y Load'

    # experiment - Superior
    # model - superior
    updated_channels[2] = 'External Tibia_z Load'

    # experiment - flexion
    # model - Extension
    updated_channels[3] = 'External Tibia_x Moment'
    updated_data[3] = data[3] # from visual inspection of data, it appears x moment is not inverted


    # experiment Valgus
    # model- varus
    updated_channels[4]  = 'External Tibia_y Moment'
    updated_data[4] = -data[4]

    # Experiment - external
    # model- internal
    updated_channels[5] = 'External Tibia_z Moment'
    updated_data[5] = -data[5]

    group.data = updated_data
    group.channels = updated_channels
    group.processed = True


def extract_idx(group, loading_channel, loading_direction=1):
    """return the indices where the loading channel is positive (loading_direction = 1) or negative (loading_direction = -1)"""

    dt = np.asarray(group.data[loading_channel])

    if loading_direction == 1:
        keep_idx = np.where(dt > 0.0)[0]
    else:
        keep_idx = np.where(dt < 0.0)[0]

    return keep_idx
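
# Example (illustrative): keep only the samples where the SI force (kinetics channel 2) is positive:
#   idx_keep = extract_idx(kinetics_group, 2, loading_direction=1)
#   dp.crop_data(kinetics_group, idx_keep)
#   dp.crop_data(kinematics_group, idx_keep)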


def resampling_index(group, resampling_channel, increment, sample_range, start=0):
    """find resampling points on resampling_channel at approximately the given increment, choosing within each
    increment the value that has the most data points within the given range"""

    # chan_idx = group.channels.index(resampling_channel)
    chan_data = group.data[resampling_channel]

    # find the approximate resampling increments
    if increment > 0.0:
        resampling_approx = np.arange(start, np.max(chan_data)+increment, increment)
        b = 1.0
    else:
        resampling_approx = np.arange(start, np.min(chan_data) + increment, increment)
        b = -1.0

    resampling_increments = []
    indices = []

    # within each increment, find the point with the most data points surrounding it within the given range
    for inc in resampling_approx:
        num_points = []
        cropping_indexes = []
        for i in np.arange(inc, inc + increment, b):
            cropping_index = dp.crop_index(group, sample_range, resampling_channel, baseline=i)
            cropping_indexes.append(cropping_index)
            num_points.append(len(cropping_index))

        max_idx = np.argmax(num_points)
        resamp_point = inc + max_idx * b  # offset from inc in the direction of the increment (b is +1.0 or -1.0)
        resampling_increments.append(resamp_point)
        indices.append(cropping_indexes[max_idx])

    return indices, resampling_increments


def idealized_loading(group, loading_channel):
    """ generate a group representing the idealized loading of the loading channel -
    ie zero load on all other channels"""

    # create a copy of the group and zero the loads on all channels other than the loading channel
    group_copy = copy.deepcopy(group)
    mask_array = np.zeros(np.shape(group_copy.data), dtype=bool)
    mask_array[loading_channel] = True
    group_copy.data[~mask_array] = 0

    return group_copy
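
# Example (illustrative): keep only the applied AP load (channel 1) and zero every other kinetics channel:
#   kinetics_group_idealized = idealized_loading(kinetics_group, loading_channel=1)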


def passive_flexion_processing(kinetics_group, kinematics_group, file_directory, model_offsets):
    """process the passive flexion data"""

    # save the raw data as csv and png file
    dp.plot_groups("Passive_Flexion_Kinematics_raw", kinematics_group, 'Time', file_directory, show_plot=False)
    dp.plot_groups("Passive_Flexion_Kinetics_raw", kinetics_group, 'Time', file_directory, show_plot=False)

    # experiment with SI force cropping
    kinetics_group_copy_temp = copy.deepcopy(kinetics_group)
    kinematics_group_copy_temp = copy.deepcopy(kinematics_group)

    # take only data where SI force is positive -
    idx_keep = extract_idx(kinetics_group_copy_temp, 2, 1)
    dp.crop_data(kinetics_group_copy_temp, idx_keep)
    dp.crop_data(kinematics_group_copy_temp, idx_keep)

    # save the raw data as csv and png file
    dp.plot_groups("Passive_Flexion_Kinematics_SI_raw", kinematics_group_copy_temp, 'Time', file_directory, show_plot=False)
    dp.plot_groups("Passive_Flexion_Kinetics__SI_raw", kinetics_group_copy_temp, 'Time', file_directory, show_plot=False)

    # Crop the raw data on the cutoff channels by the force and torque cutoff values
    force_cutoff = 8.0 #N
    torque_cutoff = 1500 #Nmm

    cutoff_channels = [0, 1, 4, 5]
    cutoff_values = [force_cutoff, force_cutoff, torque_cutoff, torque_cutoff]

    for idx, chan in enumerate(cutoff_channels):
        cropping_idx = dp.crop_index(kinetics_group, cutoff_values[idx], chan)
        dp.crop_data(kinetics_group, cropping_idx)
        dp.crop_data(kinematics_group, cropping_idx)

    # assume that the knee is at full extension (ie zero degrees flexion) at the beginning of the experiment
    zero_flex = kinematics_group.data[3][0]

    # sort the data by ascending flexion angle
    sorting_channel = 3
    srt_index, srt_data = dp.sorting_index(kinematics_group, sorting_channel)
    dp.sort_data(kinetics_group, srt_index, srt_data)
    dp.sort_data(kinematics_group, srt_index, srt_data)

    # resample at approximately 5 degree increments, or wherever the most data is available

    # resample at 5 degree increments by averaging each channel where the flexion angle is within 0.1 degrees
    resampling_channel = 3
    increments = 5.0
    angle_range = 0.1

    # resample at 5 degree increments starting at the zero flexion angle
    resamp_idx, resamp_intervals = dp.resampling_index(kinematics_group, resampling_channel, increments, angle_range, start=zero_flex)

    dp.resample_data(kinematics_group, resamp_idx, resamp_intervals)
    dp.resample_data(kinetics_group, resamp_idx, resamp_intervals)

    # save in experiment coordinate system for comparison before changing to model CS
    dp.plot_groups("Passive_Flexion_Kinematics_in_JCS_experiment", kinematics_group, 'Flexion Angle (deg)', file_directory,
                   show_plot=False)
    dp.plot_groups("Passive_Flexion_TibiaKinetics_in_TibiaCS_experiment", kinetics_group, 'Flexion Angle (deg)', file_directory,
                   show_plot=False)

    # report the axes in the same direction as the model reporting:
    # positive directions are - extension, adduction (varus), internal
    change_kinematics_reporting(kinematics_group)

    change_kinetics_reporting(kinetics_group) # figure out how it is reported in the experiment...

    # report the kinetics as external loads applied to the femur, reported in the tibia coordinate system
    dp.kinetics_tibia_to_femur(kinetics_group, kinematics_group)

    # apply model offsets to kinematics
    dp.apply_offsets(kinematics_group, -model_offsets)

    # save the data as csv and png file
    dp.plot_groups("Passive_Flexion_Kinematics_in_JCS", kinematics_group, 'Flexion Angle (deg)',file_directory,show_plot=False)
    dp.plot_groups("Passive_Flexion_Kinetics_in_TibiaCS", kinetics_group, 'Flexion Angle (deg)', file_directory,show_plot=False)


def laxity_processing(kinetics_group, kinematics_group, file_directory, loading_channel, model_offsets):
    """process the laxity data"""

    kinetics_channel_names = ['LM','AP','SI','FE','VV','EI']

    # save the raw data as csv and png file
    dp.plot_groups("Laxity_{}_Kinematics_raw".format(kinetics_channel_names[loading_channel]), kinematics_group, 'Time', file_directory,show_plot=False)
    dp.plot_groups("Laxity_{}_Kinetics_raw".format(kinetics_channel_names[loading_channel]), kinetics_group, 'Tim', file_directory,show_plot=False)

    # experiment with SI force cropping
    kinetics_group_copy_temp = copy.deepcopy(kinetics_group)
    kinematics_group_copy_temp = copy.deepcopy(kinematics_group)

    # take only data where SI force is positive -
    idx_keep = extract_idx(kinetics_group_copy_temp, 2, 1)
    dp.crop_data(kinetics_group_copy_temp, idx_keep)
    dp.crop_data(kinematics_group_copy_temp, idx_keep)

    # save the raw data as csv and png file
    dp.plot_groups("Laxity_{}_Kinematics_SI_raw".format(kinetics_channel_names[loading_channel]), kinematics_group_copy_temp, 'Time',
                   file_directory, show_plot=False)
    dp.plot_groups("Laxity_{}_Kinetics__SI_raw".format(kinetics_channel_names[loading_channel]), kinetics_group_copy_temp, 'Tim',
                   file_directory, show_plot=False)

    loading_directions = [1, -1]
    for i in loading_directions:
        ang = 0
        while ang < 120:

            # make copies of the data so we don't make changes to the original data at every loop
            kinetics_group_copy = copy.deepcopy(kinetics_group)
            kinematics_group_copy = copy.deepcopy(kinematics_group)

            # extract only the data in the loading direction
            idx_keep = extract_idx(kinetics_group_copy, loading_channel, i)
            dp.crop_data(kinetics_group_copy, idx_keep)
            dp.crop_data(kinematics_group_copy, idx_keep)

            # find the flexion angles at which to resample the data so that there is the largest range of data
            angle_cutoff = 0.2
            cropping_indexes = []
            ranges = []

            loading_data = kinetics_group_copy.data[loading_channel]

            # find the angle within ang + 10 that has the largest range of data

            for a in np.arange(ang, ang+10, 1):
                cropping_index = dp.crop_index(kinematics_group_copy, angle_cutoff, 3, baseline=a)
                cropping_indexes.append(cropping_index)
                load_data = loading_data[cropping_index]
                if len(load_data) > 0:
                    rng = max(load_data) - min(load_data)
                    ranges.append(rng)
                else:
                    ranges.append(np.nan)

            try:
                max_idx = np.nanargmax(ranges)  # to look at the largest range
            except ValueError:  # if the ranges are all nans, i.e. no data at any of the sampled angles, try the next 10 degrees
                ang += 10
                continue

            flexion_angle  = ang + max_idx
            cropping_index_1 = cropping_indexes[max_idx]
            range_1 = ranges[max_idx]

            dp.crop_data(kinematics_group_copy, cropping_index_1)
            dp.crop_data(kinetics_group_copy, cropping_index_1)

            # sort by loading channel
            sort_idx, srt_data = dp.sorting_index(kinetics_group_copy, loading_channel, i)
            dp.sort_data(kinetics_group_copy, sort_idx, srt_data)
            dp.sort_data(kinematics_group_copy, sort_idx, srt_data)

            # select points from the data set to use - this is instead of averaging the data
            loading_data = kinetics_group_copy.data[loading_channel]
            keep_idx = [0]
            inc = range_1/15  # take about 15 approximately evenly spaced data points (assuming there are no "missing regions" in the data)
            check_idx = 0
            while check_idx < len(loading_data) - 1:
                check_idx += 1
                dif = abs(loading_data[check_idx] - loading_data[keep_idx[-1]])
                if dif >= inc:
                    keep_idx.append(check_idx)
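
            # For example (assumed numbers): if the load range at this flexion angle is 90 N, inc is 6 N, so a
            # sample is kept each time the load has changed by at least 6 N since the previously kept sample.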

            dp.crop_data(kinetics_group_copy, keep_idx)
            dp.crop_data(kinematics_group_copy, keep_idx)

            # for labeling graphs
            if i == 1:
                num = '1'
            else:
                num = '2'

            # plot the results
            channel_units = kinetics_group_copy.units[loading_channel]

            # plot the kinematics in the experiment cs before offsets, etc.
            dp.plot_groups('Laxity_{}deg_'.format(flexion_angle) + kinetics_channel_names[
                loading_channel] + num + '_kinematics_in_JCS_experiment', kinematics_group_copy,
                           'Applied Load (' + channel_units + ')', file_directory, show_plot=False)
            dp.plot_groups('Laxity_{}deg_'.format(flexion_angle) + kinetics_channel_names[loading_channel] + num +'_Tibiakinetics_in_TibiaCS_experiment', kinetics_group_copy,
                        'Applied Load (' + channel_units +')', file_directory,show_plot=False)

            # report the axes in the same direction as the model reporting :
            # positive directions in the model are - extension, adduction (varus), internal
            change_kinematics_reporting(kinematics_group_copy)

            change_kinetics_reporting(kinetics_group_copy)

            # # for in situ strain calibration, generate kinetics csv with idealized loading on the loading channel,
            # # zero loads on all other axes using tibia loads
            # kinetics_group_idealized_tibia = idealized_loading(kinetics_group_copy, loading_channel)

            # plot the tibia loads in tibia CS for comparison with model later -
            # the time axis needs to be the same as the one applied in the model, so will save a copy of this to plot later
            Tibiakinetics_group_copy = copy.deepcopy(kinetics_group_copy)

            # report kinetics as forces applied to femur in tibia coordinate system
            dp.kinetics_tibia_to_femur(kinetics_group_copy, kinematics_group_copy, loading_channel)

            # # transfer to femur for the idealized group
            # dp.kinetics_tibia_to_femur(kinetics_group_idealized_tibia, kinematics_group_copy, loading_channel)

            # set the time in the TibiaKinetics group the same as the femur kinetics group, and plot
            Tibiakinetics_group_copy.time = kinetics_group_copy.time
            dp.plot_groups('Laxity_{}deg_'.format(flexion_angle) + kinetics_channel_names[loading_channel] + num + '_TibiaKinetics_in_TibiaCS',
                           Tibiakinetics_group_copy,
                           'Applied Load (' + channel_units + ')', file_directory, show_plot=False)

            # # create another idealized loading case using femur kinetics
            # kinetics_group_idealized_femur = idealized_loading(kinetics_group_copy, loading_channel)

            # apply model offsets to kinematics
            dp.apply_offsets(kinematics_group_copy, -model_offsets)

            dp.plot_groups('Laxity_{}deg_'.format(flexion_angle) + kinetics_channel_names[loading_channel] + num +'_kinetics_in_TibiaCS', kinetics_group_copy,
                        'Applied Load (' + channel_units +')', file_directory,show_plot=False)
            dp.plot_groups('Laxity_{}deg_'.format(flexion_angle) + kinetics_channel_names[loading_channel] + num + '_kinematics_in_JCS', kinematics_group_copy,
                        'Applied Load (' + channel_units + ')',file_directory ,show_plot=False)


            # # plot the idealized loading groups
            # dp.plot_groups()



            # increase the flexion angle to find the next data set
            ang = flexion_angle + 15


def AP_processing(kinetics_group, kinematics_group, file_directory, loading_channel, model_offsets):
    """process the AP loading data"""

    kinetics_channel_names = ['LM', 'AP', 'SI', 'FE', 'VV', 'EI']

    loading_directions = [1, -1]
    for i in loading_directions:
        # make a copy of the data
        kinetics_group_copy = copy.deepcopy(kinetics_group)
        kinematics_group_copy = copy.deepcopy(kinematics_group)

        # extract only the data in the loading direction
        idx_keep = extract_idx(kinetics_group_copy, loading_channel, i)
        dp.crop_data(kinetics_group_copy, idx_keep)
        dp.crop_data(kinematics_group_copy, idx_keep)

        # set up cropping variables
        force_cutoff = 1.5 #N
        torque_cutoff = 1000 #Nmm
        cutoff_channels = [4, 5, 0]  # VV torque,EI torque, ML force
        cutoff_values = [torque_cutoff, torque_cutoff, force_cutoff]

        # crop any other channels by force and torque cutoffs
        for idx, chan in enumerate(cutoff_channels):
            cropping_idx = dp.crop_index(kinetics_group_copy, cutoff_values[idx], chan)
            dp.crop_data(kinetics_group_copy, cropping_idx)
            dp.crop_data(kinematics_group_copy, cropping_idx)

        # sort the data by ascending flexion angle
        sorting_channel = 3
        srt_index, srt_data = dp.sorting_index(kinematics_group_copy, sorting_channel)
        dp.sort_data(kinetics_group_copy, srt_index, srt_data)
        dp.sort_data(kinematics_group_copy, srt_index, srt_data)

        # select points where the load is always increasing/decreasing depending on loading direction
        # select points from the data set to use - this is instead of averaging the data
        loading_data = kinetics_group_copy.data[loading_channel]
        keep_idx = [0]
        check_idx = 0
        inc = loading_data[-1]/10
        while check_idx < len(loading_data) - 1:
            check_idx += 1
            dif = loading_data[check_idx] - loading_data[keep_idx[-1]]
            if dif*i > 0:
                if abs(dif) > inc:
                    keep_idx.append(check_idx)

        dp.crop_data(kinetics_group_copy, keep_idx)
        dp.crop_data(kinematics_group_copy, keep_idx)

        # set the time as the loading channel
        loading_data = kinetics_group_copy.data[loading_channel]
        kinetics_group_copy.time = [loading_data]*6
        kinematics_group_copy.time = [loading_data]*6

        channel_units = kinetics_group_copy.units[loading_channel]
        if i==1:
            num='1'
        else:
            num='2'

        # plot the kinematics in the experiment cs before offsets, etc.
        dp.plot_groups('AP' + num + '_kinematics_in_JCS_experiment', kinematics_group_copy,
                       'Applied Load (' + channel_units + ')', file_directory, show_plot=False)
        dp.plot_groups('AP' + num + '_Tibiakinetics_in_TibiaCS_experiment', kinetics_group_copy,
                       'Applied Load (' + channel_units + ')', file_directory, show_plot=False)

        # change reporting directions
        change_kinematics_reporting(kinematics_group_copy)
        change_kinetics_reporting(kinetics_group_copy)


        # move kinetics to femur loads
        dp.kinetics_tibia_to_femur(kinetics_group_copy, kinematics_group_copy, loading_channel)


        # apply model offsets to kinematics
        dp.apply_offsets(kinematics_group_copy, -model_offsets)

        dp.plot_groups('AP' + num + '_kinetics_in_TibiaCS', kinetics_group_copy,
                       'Applied Load (' + channel_units + ')', file_directory, show_plot=False)
        dp.plot_groups('AP' + num + '_kinematics_in_JCS', kinematics_group_copy,
                       'Applied Load (' + channel_units + ')', file_directory, show_plot=False)


def all_kinetics_kinematics(all_groups, file_directory):
    """plot kinematics vs kinetics from all loading cases on one 6x6 grid"""

    fig = plt.figure(figsize=(12, 12))
    fig.suptitle("kinematics vs kinetics", fontsize=16)
    colors = ['r','b','y','g','c']

    for i in range(6):
        for j in range(6):
            ax = plt.subplot2grid((6, 6), (i, j))
            for n,(kinetics_group, kinematics_group) in enumerate(all_groups):
                ax.plot(kinematics_group.data[j], kinetics_group.data[i], marker='o', color=colors[n], markersize=0.1)
                if n==0:
                    if i == 5:
                        ax.set_xlabel(kinematics_group.channels[j])
                    if j == 0:
                        ax.set_ylabel(kinetics_group.channels[i])
    try:
        os.chdir('Processed_Data')
    except FileNotFoundError:
        os.mkdir('Processed_Data')
        os.chdir('Processed_Data')

    # plt.show()
    # Saves files as .png
    fig.savefig("kinematics_vs_kinetics_all.png",
                format="png", dpi=100)

    plt.close(fig)

    # go back to the directory containing the files
    os.chdir(file_directory)

def kinetics_kinematics_analysis(kinetics_group, kinematics_group, file_directory, loading_name):
    """ create a 6X6 plot of kinetics vs kinematics for each DOF. to check directionality of kinetics channels"""

    fig = plt.figure(figsize=(12, 12))
    fig.suptitle("kinematics vs kinetics", fontsize=16)

    for i in range(6):
        for j in range(6):
            ax= plt.subplot2grid((6, 6), (i, j))
            ax.plot(kinematics_group.data[j], kinetics_group.data[i])
            if i==5:
                ax.set_xlabel(kinematics_group.channels[j])
            if j==0:
                ax.set_ylabel(kinetics_group.channels[i])
    try:
        os.chdir('Processed_Data')
    except FileNotFoundError:
        os.mkdir('Processed_Data')
        os.chdir('Processed_Data')

    # plt.show()
    # Saves files as .png
    fig.savefig("kinematics_vs_kinetics_" +loading_name +".png",
                    format="png", dpi=100)


    plt.close(fig)

    # go back to the directory containing the files
    os.chdir(file_directory)

def passive_flexion_all(all_groups, file_directory, loading_channel, model_offsets):
    """combine the kinetics and kinematics from all loading cases and process them together"""

    # combine kinematics and kinetics from all groups:
    first_group = all_groups[0]
    kinetics_group = first_group[0]
    kinematics_group= first_group[1]

    for data in all_groups[1:]:
        next_kinetics = data[0]
        next_kinematics= data[1]
        kinetics_group.data = np.concatenate((kinetics_group.data ,next_kinetics.data),axis=1)
        kinematics_group.data = np.concatenate((kinematics_group.data, next_kinematics.data), axis=1)
        next_kinetics_time = next_kinetics.time + np.reshape(kinetics_group.time[:,-1], (6,1)) + np.ones((6,1))
        next_kinematics_time= next_kinematics.time + np.reshape(kinematics_group.time[:,-1],(6,1)) + np.ones((6,1))
        kinetics_group.time = np.concatenate((kinetics_group.time ,next_kinetics_time),axis=1)
        kinematics_group.time = np.concatenate((kinematics_group.time, next_kinematics_time), axis=1)

    kinetics_channel_names = ['LM','AP','SI','FE','VV','EI']

    # set up cropping variables
    force_cutoff = 5
    torque_cutoff = 1000
    cutoff_channels = [0, 1, 2, 4, 5]  # ML force, AP force, SI force, VV torque, IE torque
    cutoff_values = [force_cutoff, force_cutoff, force_cutoff, torque_cutoff, torque_cutoff]

    # make copies of the data so we don't make changes to the original data at every loop
    kinetics_group_copy2 = copy.deepcopy(kinetics_group)
    kinematics_group_copy2 = copy.deepcopy(kinematics_group)

    # crop any other channels by force and torque cutoffs
    for idx, chan in enumerate(cutoff_channels):
        if chan == loading_channel:  # don't crop if it is the current loading channel
            continue
        cropping_idx = dp.crop_index(kinetics_group_copy2, cutoff_values[idx], chan)
        dp.crop_data(kinetics_group_copy2, cropping_idx)
        dp.crop_data(kinematics_group_copy2, cropping_idx)

    # sort the data in order of ascending/descending load
    srt_index, srt_data = dp.sorting_index(kinetics_group_copy2, loading_channel)

    dp.sort_data(kinetics_group_copy2, srt_index, srt_data)
    dp.sort_data(kinematics_group_copy2, srt_index, srt_data)

    channel_units = kinetics_group_copy2.units[loading_channel]

    # plot the kinematics in the experiment cs before offsets, etc.
    dp.plot_groups('pf_kinematics_in_JCS_experiment', kinematics_group_copy2,
                   'Applied Load (' + channel_units + ')', file_directory, show_plot=False)
    dp.plot_groups('pf_Tibiakinetics_in_TibiaCS_experiment',
                   kinetics_group_copy2, 'Applied Load (' + channel_units + ')', file_directory,
                   show_plot=False)

def processing_all(all_groups, file_directory, loading_channel, model_offsets):
    """process using all loading data"""

    # combine kinematics and kinetics from all groups:
    first_group = all_groups[0]
    kinetics_group = first_group[0]
    kinematics_group= first_group[1]

    for data in all_groups[1:]:
        next_kinetics = data[0]
        next_kinematics= data[1]
        kinetics_group.data = np.concatenate((kinetics_group.data ,next_kinetics.data),axis=1)
        kinematics_group.data = np.concatenate((kinematics_group.data, next_kinematics.data), axis=1)
        next_kinetics_time = next_kinetics.time + np.reshape(kinetics_group.time[:,-1], (6,1)) + np.ones((6,1))
        next_kinematics_time= next_kinematics.time + np.reshape(kinematics_group.time[:,-1],(6,1)) + np.ones((6,1))
        kinetics_group.time = np.concatenate((kinetics_group.time ,next_kinetics_time),axis=1)
        kinematics_group.time = np.concatenate((kinematics_group.time, next_kinematics_time), axis=1)

    kinetics_channel_names = ['LM','AP','SI','FE','VV','EI']



    for flexion_angle in np.arange(0,120,10):

        # make copies of the data so we don't make changes to the original data at every loop
        kinetics_group_copy = copy.deepcopy(kinetics_group)
        kinematics_group_copy = copy.deepcopy(kinematics_group)

        # crop any data that falls outside the flexion angle cutoff
        angle_cutoff = 2.0
        cropping_index_1 = dp.crop_index(kinematics_group_copy, angle_cutoff, 3, baseline=flexion_angle)
        dp.crop_data(kinematics_group_copy, cropping_index_1)
        dp.crop_data(kinetics_group_copy, cropping_index_1)

        num_points = np.shape(kinetics_group_copy.data)[1]

        print("after cropping at {} deg, by +-2 deg there are {} data points left".format(flexion_angle,num_points))

        loading_directions = [1, -1]

        # set up cropping variables
        force_cutoff = 5
        torque_cutoff = 1000
        cutoff_channels = [0, 1, 2, 4, 5]  # ML force, AP force, SI force, VV torque, IE torque
        cutoff_values = [force_cutoff, force_cutoff, force_cutoff, torque_cutoff, torque_cutoff]

        for i in loading_directions:

            # make copies of the data so we don't make changes to the original data at every loop
            kinetics_group_copy2 = copy.deepcopy(kinetics_group_copy)
            kinematics_group_copy2 = copy.deepcopy(kinematics_group_copy)

            # extract the data for the loading channel and loading direction
            idx_keep = extract_idx(kinetics_group_copy2, loading_channel, loading_direction=i)
            dp.crop_data(kinetics_group_copy2, idx_keep)
            dp.crop_data(kinematics_group_copy2, idx_keep)

            num_points = np.shape(kinetics_group_copy2.data)[1]
            print("after extractig for loading direction {}, there are {} data points left".format(i,num_points))

            # crop any other channels by force and torque cutoffs
            for idx, chan in enumerate(cutoff_channels):
                if chan == loading_channel:  # don't crop if it is the current loading channel
                    continue
                cropping_idx = dp.crop_index(kinetics_group_copy2, cutoff_values[idx], chan)
                dp.crop_data(kinetics_group_copy2, cropping_idx)
                dp.crop_data(kinematics_group_copy2, cropping_idx)

            num_points = np.shape(kinetics_group_copy2.data)[1]
            print("after cropping forces, there are {} data points left".format(num_points))

            if num_points>1:

                # sort the data in order of ascending/descending load
                srt_index, srt_data = dp.sorting_index(kinetics_group_copy2, loading_channel, loading_direction=i)

                dp.sort_data(kinetics_group_copy2, srt_index, srt_data)
                dp.sort_data(kinematics_group_copy2, srt_index, srt_data)

                channel_units = kinetics_group_copy2.units[loading_channel]

                if i == 1:
                    num = '1'
                else:
                    num = '2'

                # plot the kinematics in the experiment cs before offsets, etc.
                dp.plot_groups('{}deg_'.format(flexion_angle) + kinetics_channel_names[
                    loading_channel] + num + '_kinematics_in_JCS_experiment', kinematics_group_copy2,
                               'Applied Load (' + channel_units + ')', file_directory, show_plot=False)
                dp.plot_groups('{}deg_'.format(flexion_angle) + kinetics_channel_names[loading_channel] + num +'_Tibiakinetics_in_TibiaCS_experiment',
                               kinetics_group_copy2,'Applied Load (' + channel_units +')', file_directory,show_plot=False)


def process_csv_files(file_directory, Model_Properties):
    """process each calibration csv file in file_directory and save the processed data and plots"""

    csv_files = []
    os.chdir(file_directory)
    for file in os.listdir(file_directory):
        if file.endswith('.csv'):
            csv_files.append(file)
        else:
            pass

    # calculate model offsets
    model_offsets = dp.find_model_offsets(Model_Properties)

    #store all the data together for analysis purpose
    all_data = []

    # process the data in each of the csv files
    for file in csv_files:

        kinetics_group, kinematics_group = extract_kinetics_kinematics(file)

        if 'passive' in file.lower():
            # dp.plot_groups("Passive_Flexion_Kinematics_raw_no_shift", kinematics_group, 'Time', file_directory, show_plot=False)
            # dp.plot_groups("Passive_Flexion_Kinetics_raw_no_shift", kinetics_group, 'Time', file_directory, show_plot=False)
            ts = apply_time_shift(kinetics_group, kinematics_group, 2, 3, cut_percent=0.5)
            print('PassiveFlexion_time_shift:{}'.format(ts))
            all_data.append((kinetics_group, kinematics_group))
            # kinetics_kinematics_analysis(kinetics_group, kinematics_group, file_directory, 'PF')
            passive_flexion_processing(kinetics_group, kinematics_group, file_directory, model_offsets)
        elif 'AP' in file:
            ts = apply_time_shift(kinetics_group, kinematics_group, 1, 1)
            print('AP_time_shift:{}'.format(ts))
            all_data.append((kinetics_group, kinematics_group))
            # kinetics_kinematics_analysis(kinetics_group, kinematics_group, file_directory, 'AP')
            laxity_processing(kinetics_group, kinematics_group, file_directory, 1, model_offsets)
            # AP_processing(kinetics_group, kinematics_group, file_directory, 1, model_offsets)
        elif 'VV' in file:
            # first delete the first 1260 data points, as the test only really begins after that point
            kinematics_group.data = kinematics_group.data[:, 1260:]
            kinetics_group.data = kinetics_group.data[:,1260:]
            ts = apply_time_shift(kinetics_group, kinematics_group, 4, 4, sample_start_percent=0.2)
            print('VV_time_shift:{}'.format(ts))
            all_data.append((kinetics_group, kinematics_group))
            # kinetics_kinematics_analysis(kinetics_group, kinematics_group, file_directory, 'VV')
            laxity_processing(kinetics_group, kinematics_group, file_directory, 4, model_offsets)
        elif 'IE' in file:
            ts = apply_time_shift(kinetics_group, kinematics_group, 5, 5, cut_percent=1.0)
            print('IE_time_shift:{}'.format(ts))
            all_data.append((kinetics_group, kinematics_group))
            # kinetics_kinematics_analysis(kinetics_group, kinematics_group, file_directory, 'IE')
            laxity_processing(kinetics_group, kinematics_group, file_directory, 5, model_offsets)
        else:
            pass

    # all_kinetics_kinematics(all_data, file_directory)
    # processing_all(all_data, file_directory, 2, model_offsets)
    # processing_all(all_data, file_directory, 0, model_offsets)
    # passive_flexion_all(all_data, file_directory, 3, model_offsets)

    # Kinetics Data
    # 0 - ML force
    # 1-  AP force
    # 2 - SI force
    # 3 - FE torque
    # 4 - VV torque
    # 5 - IE torque

    # Kinematics Data
    # 0 - ML trans
    # 1- AP trans
    # 2 - SI trans
    # 3 - FE rot
    # 4 - VV rot
    # 5 - IE rot



if __name__ == "__main__":

    # process_csv_files(sys.argv[-1])

    # csv_directory = '/home/schwara2/Documents/Open_Knees/knee_hub/DU02/calibration/DataProcessing/'
    # Model_Properties = '/home/schwara2/Documents/Open_Knees/knee_hub/DU02/calibration/Registration/model/ModelProperties.xml'

    csv_directory = "C:\\Users\schwara2\Documents\Open_Knees\du02_calibration\DataProcessing"
    Model_Properties = "C:\\Users\schwara2\Documents\Open_Knees\du02_calibration\Registration\model5\ModelProperties.xml"

    process_csv_files(csv_directory, Model_Properties)

