#!/usr/bin/env python
"""
Copyright (c) 2015 Computational Biomechanics (CoBi) Core, Department of
Biomedical Engineering, Cleveland Clinic

Permission is hereby granted, free of charge, to any person obtaining a
copy of this software and associated documentation files (the
"Software"), to deal in the Software without restriction, including
without limitation the rights to use, copy, modify, merge, publish,
distribute, sublicense, and/or sell copies of the Software, and to permit
persons to whom the Software is furnished to do so, subject to the
following conditions:

The above copyright notice and this permission notice shall be included
in all copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN
NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM,
DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR
OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE
USE OR OTHER DEALINGS IN THE SOFTWARE.
-------------------
"""

import sys
from nptdms import TdmsFile
import numpy as np
import os
import pandas as pd
try: import ConfigParser as cp
except: import configparser as cp
import copy
from lxml import etree as et
import matplotlib.pyplot as plt
import csv


class Group:
    """Container for one TDMS group and its per-channel contents.

    ``channels``, ``data``, ``units`` and ``time`` are parallel sequences indexed
    by channel position (channel name, samples, unit string, x-axis values).
    ``processed`` starts False and is set to True by the processing helpers in
    this module once they have modified the data.
    """

    def __init__(self, name, channels, data, units, time, processed=False):
        (self.name, self.channels, self.data,
         self.units, self.time, self.processed) = (
            name, channels, data, units, time, processed)


def tdms_contents(file):
    """
    Read a TDMS file and return its contents as a list of Group objects.

    One Group is produced per TDMS group. Within a group, channels are read in
    order until one raises KeyError or ValueError while its properties
    (``NI_ChannelName``, ``NI_UnitDescription``), data or time track are read;
    that channel and the remaining channels of the group are skipped. Groups in
    which not even the first channel could be read are omitted.

    Parameters
    ----------
    file : str or file-like
        Anything accepted by ``nptdms.TdmsFile.read``.

    Returns
    -------
    list of Group
        Each Group holds parallel lists of channel names, data arrays, unit
        strings and time tracks.
    """
    tdms_file = TdmsFile.read(file)

    groups = []
    for grp in tdms_file.groups():
        channel_list = []
        channel_data_list = []
        channel_unit_list = []
        channel_time_list = []

        for channel in grp.channels():
            try:
                channel_label = channel.properties["NI_ChannelName"]
                channel_units = channel.properties['NI_UnitDescription']
                channel_data = channel.data
                channel_time = channel.time_track()
            except (KeyError, ValueError):
                # Channel without the expected metadata/time information:
                # best effort - stop reading this group here.
                break

            channel_list.append(channel_label)
            channel_data_list.append(channel_data)
            channel_unit_list.append(channel_units)
            channel_time_list.append(channel_time)

        # Exactly one Group per TDMS group (appending inside the channel loop
        # would add a duplicate entry for every channel).
        if channel_list:
            groups.append(Group(grp.name, channel_list, channel_data_list,
                                channel_unit_list, channel_time_list))

    return groups


def find_average_data(group, channel):
    """Return the mean of one channel of ``group``, ignoring NaN samples.

    Parameters
    ----------
    group : Group
        Group whose ``data`` is indexed by channel position.
    channel : int
        Index of the channel to average.
    """
    return np.nanmean(np.asarray(group.data[channel]))


def crop_index(group, cutoff_value, cutoff_channel, baseline=0.0):
    """Return the indices where a channel stays within ``baseline +/- cutoff_value``.

    Parameters
    ----------
    group : Group
        Group whose ``data`` is indexed by channel position.
    cutoff_value : float
        Allowed deviation from ``baseline`` (inclusive on both sides).
    cutoff_channel : int
        Index of the channel in ``group.data`` to test.
    baseline : float, optional
        Centre of the allowed band (default 0.0).

    Returns
    -------
    numpy.ndarray of int
        Indices of the samples to keep. NaN samples are never kept, because
        every comparison with NaN is False.
    """
    # np.asarray so this also works after sort_data/resample_data, which store
    # each channel as a plain Python list.
    dt = np.asarray(group.data[cutoff_channel])
    lower = baseline - cutoff_value
    upper = baseline + cutoff_value
    return np.nonzero((dt >= lower) & (dt <= upper))[0]


def crop_data(group, cropping_index):
    """Keep only the samples at ``cropping_index`` in every data and time channel.

    ``group.data`` and ``group.time`` are each converted with ``np.asarray``
    (channels are expected to share a length), every channel row is indexed
    with ``cropping_index``, and the results are stored back as lists of
    per-channel numpy arrays. The group is then marked as processed.
    """
    kept_samples = [row[cropping_index] for row in np.asarray(group.data)]
    kept_times = [row[cropping_index] for row in np.asarray(group.time)]

    group.data, group.time = kept_samples, kept_times
    group.processed = True


def extract_idx(group, loading_channel, loading_direction=1):
    """Return the indices of the contiguous loading run around the peak load.

    The run is the maximal stretch of consecutive samples that contains the
    largest value of ``loading_channel`` (after an optional sign flip) and in
    which every sample is strictly greater than zero.

    Parameters
    ----------
    group : Group
        Group whose ``data`` is indexed by channel position.
    loading_channel : int
        Index of the channel carrying the applied load.
    loading_direction : int, optional
        ``1`` (default) looks for positive loading; ``-1`` for negative loading
        (the channel is negated so the same search applies).

    Returns
    -------
    numpy.ndarray of int
        Indices of the run, in ascending order. Empty when the channel is
        empty, all-NaN, or never loaded in the requested direction.
    """
    # Taking every positive (or negative) sample is not enough: that could pick
    # up non-zero data from other loading scenarios. Instead grow a window
    # outwards from the peak until the signal stops being positive.
    dt = np.asarray(group.data[loading_channel])

    # Flip negative loading so the search below is the same for both directions.
    if loading_direction == -1:
        dt = -dt

    if dt.size == 0 or np.all(np.isnan(dt)):
        return np.array([], dtype=int)

    # nanargmax: plain argmax would return the position of the first NaN.
    peak = int(np.nanargmax(dt))
    if not dt[peak] > 0.0:
        return np.array([], dtype=int)

    # The bounds checks stop the walk at the array ends (a negative index would
    # silently wrap around to the end of the array).
    start = peak
    while start > 0 and dt[start - 1] > 0.0:
        start -= 1

    end = peak + 1
    while end < len(dt) and dt[end] > 0.0:
        end += 1

    return np.arange(start, end)


def sorting_index(group, sorting_channel, loading_direction=1):
    """Return the permutation that sorts ``group`` by one channel, plus that sorted channel.

    Parameters
    ----------
    group : Group
        Group whose ``data`` is indexed by channel position.
    sorting_channel : int
        Index of the channel to sort on.
    loading_direction : int, optional
        ``1`` (default) sorts ascending; ``-1`` sorts descending by sorting the
        negated values.

    Returns
    -------
    order : numpy.ndarray of int
        ``argsort`` permutation of the sorting channel.
    sorted_values : numpy.ndarray
        The sorting channel reordered by ``order``.
    """
    all_channels = np.asarray(group.data)
    key_channel = all_channels[sorting_channel, :]

    key = -key_channel if loading_direction == -1 else key_channel
    order = key.argsort()

    return order, all_channels[sorting_channel][order]


def sort_data(group, sorting_index, sorting_data=False):
    """Reorder every channel of ``group`` by ``sorting_index``.

    ``group.data`` is replaced by a list of lists. When ``sorting_data`` is
    given (anything other than the ``False`` sentinel) and the first time
    channel is non-empty, every time channel is replaced by ``sorting_data``
    so that the sorted quantity becomes the x axis. The group is marked as
    processed.
    """
    all_channels = np.asarray(group.data)
    group.data = all_channels[:, sorting_index].tolist()

    # identity check on purpose: False is the "not supplied" sentinel
    replace_time = sorting_data is not False and len(group.time[0]) > 0
    if replace_time:
        group.time = [sorting_data] * len(all_channels)

    group.processed = True


def resampling_index(group, resampling_channel, increment, range, start=0.0):
    """Find, for each resampling point, the samples that lie within ``range`` of it.

    Resampling points run from ``start`` in steps of ``increment``. For a
    positive ``increment`` they stop just before ``max + increment``, and for a
    negative one just before ``min + increment``, so the last point reaches the
    channel's extreme value. NaN samples are ignored when finding the extreme
    value and never match a resampling point.

    Parameters
    ----------
    group : Group
        Group whose ``data`` is indexed by channel position.
    resampling_channel : int
        Index of the channel to resample on.
    increment : float
        Step between resampling points (sign sets the direction).
    range : float
        Half-width of the averaging window. (This name shadows the builtin
        inside the function; it is kept for keyword-argument compatibility.)
    start : float, optional
        First resampling point (default 0.0).

    Returns
    -------
    indices : list of numpy.ndarray of int
        For each resampling point, the indices of samples with
        ``|x - point| < range``.
    resampling_increments : numpy.ndarray
        The resampling points.
    """
    # np.asarray: after sort_data the channel is a plain Python list.
    chan_data = np.asarray(group.data[resampling_channel], dtype=float)

    # nanmax/nanmin: a single NaN sample would otherwise make np.arange fail.
    if increment > 0.0:
        stop = np.nanmax(chan_data) + increment
    else:
        stop = np.nanmin(chan_data) + increment
    resampling_increments = np.arange(start, stop, increment)

    tolerance = range
    indices = [np.nonzero(np.abs(chan_data - point) < tolerance)[0]
               for point in resampling_increments]

    return indices, resampling_increments


def resample_data(group, resamp_idx, resamp_intervals):
    """Replace each channel of ``group`` by its averages over the given sample windows.

    Every entry of ``resamp_idx`` is an index array selecting the samples that
    contribute to one resampled point; each channel is averaged over those
    samples. ``group.data`` becomes a list of lists (one per channel), every
    time channel is set to ``resamp_intervals``, and the group is marked as
    processed.
    """
    channels = np.asarray(group.data)

    # one row of per-channel averages for each resampling window
    window_means = [np.average(channels[:, window], axis=1) for window in resamp_idx]

    # transpose back to one row per channel, like the input layout
    group.data = np.asarray(window_means).T.tolist()
    group.time = [resamp_intervals] * len(channels)
    group.processed = True


def get_kinetics_kinematics(groups):
    """Return the ``(kinetics, kinematics)`` groups from a sequence of Group objects.

    The kinetics group is the one named ``'State.JCS Load'`` and the kinematics
    group the one named ``'State.Knee JCS'``. If a name occurs more than once,
    the last occurrence wins. A group that is not found is returned as ``None``.
    """
    found = {'State.JCS Load': None, 'State.Knee JCS': None}
    for candidate in groups:
        if candidate.name in found:
            found[candidate.name] = candidate
    return found['State.JCS Load'], found['State.Knee JCS']


def get_desired_kinetics(groups):
    """Return the first group named ``'Kinetics.JCS.Desired'``, or ``None`` if there is none."""
    return next((g for g in groups if g.name == 'Kinetics.JCS.Desired'), None)


def get_desired_kinematics(groups):
    """Return the first group named ``'Kinematics.JCS.Desired'``, or ``None`` if there is none."""
    return next((g for g in groups if g.name == 'Kinematics.JCS.Desired'), None)

def find_hold_indices(group, channel):
    """ find the indices where the desired kinetics load was held steady on the channel

    A pair of consecutive samples is "stable" when the values differ by less
    than 1e-3. Stable pairs that are more than 20 samples apart are treated as
    belonging to separate hold (flat) regions. The hold regions are then split
    in half: the first half is returned as the positive-loading holds and the
    second half as the negative-loading holds. With an odd number of holds, the
    extra one goes to the negative half.
    NOTE(review): the half/half split assumes the protocol applies all positive
    loads before the negative ones - confirm against the test protocol.

    Parameters
    ----------
    group : Group
        Group whose ``data`` and ``time`` are indexed by channel position
        (presumably the desired-kinetics group - verify against callers).
    channel : int
        Index of the channel to analyse.

    Returns
    -------
    pos_end, neg_end : numpy.ndarray of int
        Index of the last flat sample of each hold region, for the first and
        second half of the regions respectively.

    Raises IndexError when fewer than two gaps between stable stretches are
    found (``index_end[1]`` / ``index_start[0]`` below).
    """

    data = np.asarray(group.data[channel])
    time = group.time[channel]  # only used by the commented-out debug plot below
    tol = 1e-3

    # absolute difference between each sample and the next one
    change = np.abs(data[:-1] - data[1:])
    small_changes = np.where(change < tol)[0] # where the data is stable - ie force was held
    increments =np.abs(small_changes[:-1] - small_changes[1:]) # how many indices between the stable points
    large_increments = np.where(increments > 20)[0] # where the increment jumps

    # this gives the index for the data point at the start and end of each flat region
    # small_changes[k] is the first sample of the last stable pair before a gap, so +1 is the
    # last flat sample; small_changes[k+1] is the first stable pair after the gap (+1 as well)
    index_end= small_changes[large_increments] + 1
    index_start = small_changes[large_increments+1] + 1

    # final "start" point is not relevant - it is when the curve returns to zero after loading is completed
    index_start = index_start[:-1]

    # create a "fake" start point for the first end point - this is the zero load
    # check length of other hold period
    # (the first hold has no preceding gap, so its start is estimated by assuming
    # it lasts as long as the second hold)
    other_len = index_end[1] - index_start[0]
    index_start = np.insert(index_start,0,index_end[0]-other_len)

    # # plot to check results
    # fig = plt.figure()
    # plt.plot(time, data, color='blue')
    # plt.plot(time[index_start],data[index_start],'o',color = 'red')
    # plt.plot(time[index_end],data[index_end],'o', color='yellow')
    # plt.show()

    # split the indices in half - positive and negative loading
    idx_split= int(len(index_end)/2)
    pos_end = index_end[:idx_split]
    neg_end  = index_end[idx_split:]
    pos_start = index_start[:idx_split]
    neg_start = index_start[idx_split:]

    # (start, end) pairs per hold region; computed but not returned
    pos_tuples = list(zip(pos_start, pos_end))
    neg_tuples = list(zip(neg_start,neg_end))

    # return just the end points. to return the start and end points for each flat zone, return the tuples instead
    return pos_end, neg_end


def plot_groups(title, group, x_label, tdms_directory, show_plot=True):
    """Plot the six channels of ``group`` and save the figure and the data.

    Channels 0-2 are drawn on the left subplot and channels 3-5 on the right.
    Processed groups are drawn as markers against ``group.time`` with
    ``x_label`` on the x axis; that axis is inverted when the mean x value is
    negative. Raw groups are drawn as lines with a 'Time (ms)' x label. A
    channel with no data gets an empty axis. Non-finite samples are not drawn.

    The figure is written to ``Processed_Data/<title>.png``. The data (an x
    column from ``group.time[0]`` plus one ``<channel> [<unit>]`` column per
    channel) is written to ``Processed_Data/<title>.csv``. ``Processed_Data`` is
    created in the current working directory if needed. Afterwards the working
    directory is changed to ``tdms_directory``, also when saving fails.

    Parameters
    ----------
    title : str
        Figure title and base name of the output files.
    group : Group
        Group with at least six channels.
    x_label : str
        X-axis label (and CSV column name for the x values).
    tdms_directory : str
        Directory to change to once the files are written.
    show_plot : bool, optional
        Also show the figure interactively (default True).
    """
    point = ['s', 'D', 'o', 's', 'D', 'o']  # marker per channel for processed data

    fig = plt.figure(figsize=(12, 8))
    fig.suptitle(title, fontsize=16)

    # Two side-by-side graphs: channels 0-2 on the left, 3-5 on the right.
    ax1 = plt.subplot(121)
    ax2 = plt.subplot(122)
    subplots = [ax1, ax2]
    axes = [ax1, ax1, ax1, ax2, ax2, ax2]

    for i in range(6):
        ax = axes[i]
        ax.set_ylabel(group.units[i], fontsize=12)

        dt = np.asarray(group.data[i]).astype(np.double)
        dt_mask = np.isfinite(dt)  # ignore missing (NaN/inf) data points

        if dt.size == 0:
            # no data: leave the axis empty so it is obvious nothing is there
            ax.set_xlabel(x_label, fontsize=12)
            continue

        # np.asarray so boolean masking also works when time is a plain list
        t = np.asarray(group.time[i])
        if group.processed:
            ax.plot(t[dt_mask], dt[dt_mask], point[i], label=group.channels[i])
            ax.set_xlabel(x_label, fontsize=12)
            if np.nanmean(t) < 0.0:
                # negative applied load: invert so the plot starts at 0 on the x axis
                ax.invert_xaxis()
        else:
            ax.plot(t[dt_mask], dt[dt_mask], label=group.channels[i])
            ax.set_xlabel('Time (ms)', fontsize=12)

    for subplot in subplots:
        subplot.legend(loc='best')
        subplot.set_title(group.name, fontsize=12)
        subplot.grid(True)

    output_dir = 'Processed_Data'
    if not os.path.isdir(output_dir):
        os.mkdir(output_dir)
    os.chdir(output_dir)
    try:
        try:
            # Save before the (possibly blocking) interactive show so the file
            # never depends on the GUI window's lifetime.
            fig.savefig(title + '.png', format="png", dpi=100)
            if show_plot:
                plt.show()
        finally:
            plt.close(fig)

        # save data in csv
        df = pd.DataFrame()
        df[x_label] = group.time[0]
        for i in range(6):
            df[group.channels[i] + ' [' + group.units[i] + ']'] = group.data[i]
        df.to_csv(title + '.csv')
    finally:
        # go back to the directory containing the files, even if saving failed
        os.chdir(tdms_directory)


def get_offsets(state_config, file_directory):
    """Read the knee JCS position offsets from a state configuration file.

    The ``Position Offset (m,rad)`` option of the ``[Knee JCS]`` section is
    parsed into floats. Translations (entries 0-2) are converted from m to mm
    and rotations (entries 3-5) from rad to deg. The converted offsets are also
    passed to ``save_experiment_offsets`` together with one column header per
    channel, built from the section's ``Channel Names <i>`` and
    ``Channel Units <i>`` options.

    Parameters
    ----------
    state_config : str
        Path of the configuration file.
    file_directory : str
        Passed through to ``save_experiment_offsets``.

    Returns
    -------
    numpy.ndarray
        The converted offsets.

    Raises
    ------
    IOError
        If the configuration file cannot be read.
    """
    config = cp.RawConfigParser()
    if not config.read(state_config):
        raise IOError("Cannot load configuration file... Check path.")

    def knee_option(option):
        # option value from the [Knee JCS] section with double quotes stripped
        return config.get('Knee JCS', option).replace('"', '')

    # Space-separated numbers; the first token is skipped.
    # NOTE(review): presumably a prefix (e.g. array size) rather than an offset - confirm.
    raw_tokens = knee_option('Position Offset (m,rad)').split(" ")[1:]
    offsets = np.asarray([float(token) for token in raw_tokens])

    offsets[0:3] = offsets[0:3] * 1000            # m -> mm
    offsets[3:6] = offsets[3:6] * 180.0 / np.pi   # rad -> deg

    # Headers mirror the kinematics channel names/units stored in the state file.
    headers = ['Knee JCS ' + knee_option('Channel Names {}'.format(i))
               + ' [' + knee_option('Channel Units {}'.format(i)) + ']'
               for i in range(6)]

    save_experiment_offsets(offsets, headers, file_directory)

    return offsets

def save_experiment_offsets(experiment_offsets, headers, file_directory):
    """Save the experiment offsets to ``Processed_Data/kinematic_offsets.csv``.

    ``Processed_Data`` is created in the current working directory if it does
    not exist yet. The offsets are written as a single-row table whose columns
    are ``headers``. Afterwards the working directory is changed to
    ``file_directory``, also when writing the file fails.

    Parameters
    ----------
    experiment_offsets : sequence of float
        One value per column.
    headers : sequence of str
        Column names; must match the number of offsets.
    file_directory : str
        Directory to change to once the file is written.
    """
    output_dir = 'Processed_Data'
    if not os.path.isdir(output_dir):
        os.mkdir(output_dir)
    os.chdir(output_dir)

    try:
        df = pd.DataFrame([experiment_offsets], columns=headers)
        df.to_csv('kinematic_offsets.csv')
    finally:
        # go back to the directory containing the files, even on failure
        os.chdir(file_directory)


def apply_offsets(group, offsets):
    """Add a per-channel offset to every sample of ``group``.

    ``offsets[k]`` is added to every sample of channel ``k``: ``offsets``
    broadcasts along the channel axis. ``group.data`` becomes a 2-D numpy array
    (channels x samples) and the group is marked as processed.
    """
    samples_by_channel = np.asarray(group.data).T  # samples x channels
    group.data = (samples_by_channel + offsets).T
    group.processed = True


def Tranform_axis_in_world(origin, axes):
    """Build the 4x4 homogeneous pose of a coordinate system in world coordinates.

    ``axes`` holds the system's X, Y and Z axis vectors (in world coordinates)
    as its rows. They become the columns of the rotation block, and ``origin``
    becomes the translation column. The resulting matrix maps coordinates
    expressed in that system to world coordinates; its inverse maps world
    coordinates into the system.
    """
    axis_rows = np.asarray(axes)

    pose = np.identity(4)
    pose[:3, 3] = origin
    pose[:3, :3] = axis_rows.T
    return pose


def Transform_A_in_B(origin_a, axes_a, origin_b, axes_b):
    """Return the 4x4 transform of coordinate system A expressed in system B.

    Both systems are given by an origin and three axis vectors (the rows of
    ``axes_*``) in world coordinates. The result is
    ``inv(B_in_world) @ A_in_world``, which maps coordinates expressed in A to
    coordinates expressed in B.
    """
    a_pose = Tranform_axis_in_world(origin_a, axes_a)
    b_pose = Tranform_axis_in_world(origin_b, axes_b)

    # world -> B, applied after A -> world
    return np.linalg.inv(b_pose) @ a_pose


def find_model_offsets(ModelPropertiesXml):
    """Find the initial tibiofemoral joint offsets of the model.

    The femur (``FMO``, ``Xf_axis``, ``Yf_axis``, ``Zf_axis``) and tibia
    (``TBO``, ``Xt_axis``, ``Yt_axis``, ``Zt_axis``) landmarks are read from the
    ``<Landmarks>`` element of the model-properties XML. Each landmark is a
    comma-separated list of numbers. The tibia coordinate system is then
    expressed in the femur coordinate system, and that transform is decomposed
    into joint coordinates.

    Parameters
    ----------
    ModelPropertiesXml : str or file-like
        Anything accepted by ``lxml.etree.parse``.

    Returns
    -------
    numpy.ndarray, shape (6,)
        ``[a, b, c, alpha, beta, gamma]``: translations in the same units as
        the landmark coordinates, rotations in degrees.

    Raises
    ------
    ValueError
        If the ``Landmarks`` element or one of the landmarks is missing or empty.
    """
    model_properties = et.parse(ModelPropertiesXml)
    landmarks = model_properties.getroot().find("Landmarks")
    if landmarks is None:
        raise ValueError("No <Landmarks> element in {}".format(ModelPropertiesXml))

    def _vector(tag):
        # Builtin float instead of np.float, which was removed in NumPy 1.24.
        element = landmarks.find(tag)
        if element is None or element.text is None:
            raise ValueError("Landmark '{}' missing in {}".format(tag, ModelPropertiesXml))
        return np.array(element.text.split(",")).astype(float)

    FMO = _vector("FMO")
    femur_axes = [_vector("Xf_axis"), _vector("Yf_axis"), _vector("Zf_axis")]
    TBO = _vector("TBO")
    tibia_axes = [_vector("Xt_axis"), _vector("Yt_axis"), _vector("Zt_axis")]

    # Transform of the tibia coordinate system expressed in the femur coordinate system.
    T_in_F = Transform_A_in_B(TBO, tibia_axes, FMO, femur_axes)

    # Decompose T_in_F into joint coordinates (inverse of the matrix built by
    # T_tib_in_fem). Reference, page 6:
    # https://simtk.org/plugins/moinmoin/openknee/Infrastructure/ExperimentationMechanics?action=AttachFile&do=view&target=Knee+Coordinate+Systems.pdf
    #   T[0,2] = sin(beta)
    #   T[1,2] = -sin(alpha)cos(beta),  T[2,2] = cos(alpha)cos(beta)
    #   T[0,1] = -cos(beta)sin(gamma),  T[0,0] = cos(beta)cos(gamma)
    beta = np.arcsin(T_in_F[0, 2])
    alpha = np.arctan2(-T_in_F[1, 2], T_in_F[2, 2])
    gamma = np.arctan2(-T_in_F[0, 1], T_in_F[0, 0])

    ca = np.cos(alpha)
    sa = np.sin(alpha)
    cb = np.cos(beta)
    sb = np.sin(beta)

    # Translation column:
    #   T[0,3] = a + c sin(beta)
    #   T[1,3] = b cos(alpha) - c sin(alpha)cos(beta)
    #   T[2,3] = b sin(alpha) + c cos(alpha)cos(beta)
    b = (T_in_F[1, 3] * ca) + (T_in_F[2, 3] * sa)
    c = ((T_in_F[2, 3] * ca) - (T_in_F[1, 3] * sa)) / cb
    a = T_in_F[0, 3] - (c * sb)

    model_offsets = np.array([a, b, c, alpha, beta, gamma])
    model_offsets[3:6] = np.degrees(model_offsets[3:6])  # rotations rad -> deg

    return model_offsets


def change_kinematics_reporting(group):
    """Flip and rename the kinematics channels to match the model's JCS conventions.

    Channel ``k`` is renamed, and negated where the experiment reports the
    opposite direction to the model:

    ==  ===================  ===============================  ========
    k   experiment reports   new channel name                 sign
    ==  ===================  ===============================  ========
    0   Medial               'Knee JCS Medial Translation'    kept
    1   Posterior            'Knee JCS Anterior Translation'  flipped
    2   Superior             'Knee JCS Superior Translation'  kept
    3   Flexion              'Knee JCS Extension Rotation'    flipped
    4   Valgus (abduction)   'Knee JCS Abduction Rotation'    kept
    5   Internal Rotation    'Knee JCS External Rotation'     flipped
    ==  ===================  ===============================  ========

    ``group.data`` is converted with ``np.asarray`` first, so lists of lists
    (as produced by ``resample_data``) are accepted, and it is stored back as a
    numpy array. The original data and channel list are not modified in place.
    The group is marked as processed.
    """
    # (new channel name, negate?) per channel index
    conventions = [
        ('Knee JCS Medial Translation', False),
        ('Knee JCS Anterior Translation', True),
        ('Knee JCS Superior Translation', False),
        ('Knee JCS Extension Rotation', True),
        ('Knee JCS Abduction Rotation', False),
        ('Knee JCS External Rotation', True),
    ]

    # np.asarray (as in change_kinetics_reporting): unary minus on a plain list fails.
    data = np.asarray(group.data)
    updated_data = data.copy()
    updated_channels = list(group.channels)

    for k, (name, negate) in enumerate(conventions):
        updated_channels[k] = name
        if negate:
            updated_data[k] = -data[k]

    group.data = updated_data
    group.channels = updated_channels
    group.processed = True


def change_kinetics_reporting(group, loading_channel=None):
    """Flip and rename the kinetics channels to match the model's conventions.

    The six channels are renamed to external tibia loads along the tibia axes
    ('External Tibia_x/y/z Load', 'External Tibia_x/y/z Moment'). Channels 0,
    2 and 4 are negated because the experiment reports them in the opposite
    direction. ``group.data`` is stored back as a numpy array and the group is
    marked as processed. If ``loading_channel`` is given, the (updated) data of
    that channel becomes the x axis (``time``) of all six channels.
    """
    data = np.asarray(group.data)

    # index: experiment channel                 -> model direction
    #   0: "Lateral Drawer"                     -> medial    (negate)
    #   1: "Anterior Drawer"                    -> anterior
    #   2: "Distraction" (an inferior force)    -> superior  (negate)
    #   3: "Extension Torque"                   -> extension
    #   4: "Varus Torque"                       -> valgus    (negate)
    #   5: "External Rotation Torque"           -> external
    relabel = (
        ('External Tibia_x Load', True),
        ('External Tibia_y Load', False),
        ('External Tibia_z Load', True),
        ('External Tibia_x Moment', False),
        ('External Tibia_y Moment', True),
        ('External Tibia_z Moment', False),
    )

    updated_data = copy.deepcopy(data)
    updated_channels = copy.deepcopy(group.channels)
    for idx, (new_name, negate) in enumerate(relabel):
        updated_channels[idx] = new_name
        if negate:
            updated_data[idx] = -data[idx]

    group.data = updated_data
    group.channels = updated_channels
    group.processed = True

    if loading_channel is not None:
        # the applied load becomes the x axis for every channel
        group.time = [updated_data[loading_channel]] * 6


def T_tib_in_fem(a, b, c, alpha, beta, gamma):
    """Build the tibia-in-femur homogeneous transform for every data point.

    Parameters
    ----------
    a, b, c : array_like, shape (n,)
        Joint translations per data point (their units carry over to the
        translation column).
    alpha, beta, gamma : array_like, shape (n,)
        Joint rotations per data point, in radians.

    Returns
    -------
    numpy.ndarray, shape (n, 4, 4)
        One transform per data point giving the position of the tibia in the
        femur coordinate system.
    """
    n_points = len(a)

    ca, sa = np.cos(alpha), np.sin(alpha)
    cb, sb = np.cos(beta), np.sin(beta)
    cg, sg = np.cos(gamma), np.sin(gamma)

    # shared products (same evaluation order as the closed-form entries)
    sa_sb = sa * sb
    ca_sb = ca * sb
    sa_cb = sa * cb
    ca_cb = ca * cb

    entries = {
        (0, 0): cb * cg,
        (0, 1): -cb * sg,
        (0, 2): sb,
        (0, 3): c * sb + a,
        (1, 0): sa_sb * cg + ca * sg,
        (1, 1): -(sa_sb * sg) + ca * cg,
        (1, 2): -sa_cb,
        (1, 3): -(c * sa_cb) + b * ca,
        (2, 0): -(ca_sb * cg) + sa * sg,
        (2, 1): ca_sb * sg + sa * cg,
        (2, 2): ca_cb,
        (2, 3): c * ca_cb + b * sa,
        (3, 3): 1.0,
    }

    transforms = np.zeros((n_points, 4, 4))
    for (row, col), value in entries.items():
        transforms[:, row, col] = value
    return transforms


def kinetics_tibia_to_femur(kinetics_group, kinematics_group, loading_channel=None):
    """Convert external tibia loads into external femur loads at the femur origin.

    On input, ``kinetics_group.data`` holds three force rows followed by three
    moment rows: external loads on the tibia along the tibia axes, at the tibia
    origin. ``kinematics_group.data`` holds the joint kinematics for the same
    samples (rows 0-2 translations in mm, rows 3-5 rotations in degrees).

    On output, the kinetics group holds the equal-and-opposite loads on the
    femur, still along the tibia axes, with the moments transferred to the
    femur origin: ``M_femur = M_tibia + r x F``, where ``r`` points from the
    femur origin to the tibia origin. The lever arm is in m, or in mm when the
    torque unit (``kinetics_group.units[3]``) contains ``'Nmm'``. Channel names
    are updated and the group is marked as processed. If ``loading_channel`` is
    given, that output row becomes the x axis (``time``) of both groups.
    """
    # Loads on the femur are equal and opposite to those on the tibia.
    femur_loads_at_tibia = -np.asarray(kinetics_group.data)

    # Joint coordinates: mm -> m, deg -> rad.
    kinematics = np.asarray(kinematics_group.data)
    a, b, c = kinematics[0] / 1000.0, kinematics[1] / 1000.0, kinematics[2] / 1000.0
    alpha, beta, gamma = (np.radians(kinematics[k]) for k in (3, 4, 5))

    # Pose of the femur in the tibia coordinate system for each sample; its
    # translation column is the tibia-origin -> femur-origin vector (metres).
    femur_in_tibia = np.linalg.inv(T_tib_in_fem(a, b, c, alpha, beta, gamma))
    tibia_to_femur = femur_in_tibia[:, 0:3, 3]  # (samples, 3)

    forces = femur_loads_at_tibia[0:3, :]            # N
    moments_at_tibia = femur_loads_at_tibia[3:6, :]  # Nm or Nmm

    if 'Nmm' in kinetics_group.units[3]:
        tibia_to_femur = tibia_to_femur * 1000  # lever arm in mm -> moments in Nmm

    # r x F with r = femur origin -> tibia origin = -(tibia origin -> femur origin)
    transferred = np.cross(-tibia_to_femur, forces.T).T
    moments_at_femur = moments_at_tibia + transferred

    result = np.zeros(np.shape(femur_loads_at_tibia))
    result[0:3, :] = forces
    result[3:6, :] = moments_at_femur

    kinetics_group.data = result
    kinetics_group.channels = ['External Femur_x Load','External Femur_y Load','External Femur_z Load','External Femur_x Moment','External Femur_y Moment','External Femur_z Moment']
    kinetics_group.processed = True

    if loading_channel is not None:
        x_axis = result[loading_channel]
        kinetics_group.time = [x_axis] * 6  # repeat for every channel
        kinematics_group.time = [x_axis] * 6


def passive_flexion_processing(groups, experiment_offsets, model_offsets, tdms_directory):
    """Process a passive flexion TDMS recording.

    Steps:

    1. Save the raw kinematics and kinetics.
    2. Keep only samples where the off-axis loads stay near zero (|force| <= 3.0
       on kinetics channels 0 and 1, |torque| <= 0.3 on channels 4 and 5).
    3. Sort by the flexion channel (kinematics channel 3).
    4. Resample every 5 degrees by averaging the samples within 1.25 degrees.
    5. Add the experiment offsets and save the experiment-convention results.
    6. Convert to the model's sign conventions and to external femur loads.
    7. Subtract the model offsets and save the final results.

    Every output is written by ``plot_groups`` into ``Processed_Data``.
    """
    # get the kinetics and kinematics groups
    kinetics_group, kinematics_group = get_kinetics_kinematics(groups)

    # save the raw data as csv and png file
    plot_groups("Passive_Flexion_Kinematics_raw", kinematics_group, 'Time (ms)', tdms_directory, show_plot=False)
    plot_groups("Passive_Flexion_Kinetics_raw", kinetics_group, 'Time (ms)', tdms_directory, show_plot=False)

    # Crop on the off-axis channels: Lateral Drawer, Anterior Drawer,
    # Varus Torque, External Rotation Torque.
    force_cutoff = 3.0
    torque_cutoff = 0.3
    cutoff_channels = [0, 1, 4, 5]
    cutoff_values = [force_cutoff, force_cutoff, torque_cutoff, torque_cutoff]

    for chan, cutoff in zip(cutoff_channels, cutoff_values):
        cropping_idx = crop_index(kinetics_group, cutoff, chan)
        crop_data(kinetics_group, cropping_idx)
        crop_data(kinematics_group, cropping_idx)

    # sort the data by ascending flexion ('Knee JCS Flexion')
    flexion_channel = 3
    srt_index, srt_data = sorting_index(kinematics_group, flexion_channel)
    sort_data(kinetics_group, srt_index, srt_data)
    sort_data(kinematics_group, srt_index, srt_data)

    # resample at 5 degree increments by averaging each channel where the
    # flexion angle is within 1.25 degrees of the resampling point
    increments = 5.0
    resampling_range = 1.25

    resamp_idx, resamp_intervals = resampling_index(kinematics_group, flexion_channel, increments, resampling_range)
    resample_data(kinematics_group, resamp_idx, resamp_intervals)
    resample_data(kinetics_group, resamp_idx, resamp_intervals)

    # apply experiment offsets to the kinematics data
    apply_offsets(kinematics_group, experiment_offsets)

    # save the results in the experiment's conventions (before model offsets etc.)
    plot_groups('Passive_Flexion_kinematics_in_JCS_experiment',
                kinematics_group,
                'Flexion Angle', tdms_directory, show_plot=False)
    plot_groups(
        'Passive_Flexion_TibiaKinetics_in_TibiaCS_experiment',
        kinetics_group,
        'Flexion Angle', tdms_directory, show_plot=False)

    # report the axes in the same directions as the model reporting:
    # positive directions are extension, abduction (valgus), external
    change_kinematics_reporting(kinematics_group)
    change_kinetics_reporting(kinetics_group)

    # report the kinetics as external loads applied to femur in tibia coordinate system
    kinetics_tibia_to_femur(kinetics_group, kinematics_group)

    # apply model offsets - Note this is done AFTER changing the signs of the data to register with model outputs.
    apply_offsets(kinematics_group, -model_offsets)

    # save the data as csv and png file
    plot_groups("Passive_Flexion_Kinematics_in_JCS", kinematics_group, 'Flexion Angle (deg)', tdms_directory, show_plot=False)
    plot_groups("Passive_Flexion_Kinetics_in_TibiaCS", kinetics_group, 'Flexion Angle (deg)', tdms_directory, show_plot=False)


def laxity_processing(groups, experiment_offsets, model_offsets, tdms_directory):
    """Process a laxity TDMS recording.

    The recording has loading on three kinetics channels: varus/valgus torque
    (channel 4, 'VV'), external/internal rotation torque (channel 5, 'EI') and
    anterior/posterior drawer (channel 1, 'AP'). Each is handled in both
    directions, around an average flexion angle.

    After the raw data are saved and samples more than 1 degree from the
    average flexion are dropped, each channel/direction pair goes through these
    steps:

    1. Extract the contiguous loading run around the peak (``extract_idx``).
    2. Drop samples where the other cutoff channels exceed their cutoffs.
    3. Sort and resample by the applied load.
    4. Save the results in the experiment's conventions.
    5. Convert to the model's conventions (tibia and femur loads) and save
       again.

    Output names start with ``Laxity_<flexion>deg_<nickname><1|2>``. The
    flexion is rounded to the nearest 10 degrees, and ``1``/``2`` mark the
    positive/negative loading direction.
    """
    kinetics_group, kinematics_group = get_kinetics_kinematics(groups)

    # find the average flexion angle in the data
    flexion_angle = find_average_data(kinematics_group, 3)
    # round to the nearest 10 - this is the 'intended' flexion, used for naming files
    rounded_flexion = int(round(flexion_angle, -1))
    file_root = 'Laxity_{}deg_'.format(rounded_flexion)

    # save the raw data as csv and png file
    plot_groups(file_root + 'Kinematics_raw', kinematics_group, 'Time (ms)', tdms_directory, show_plot=False)
    plot_groups(file_root + 'Kinetics_raw', kinetics_group, 'Time (ms)', tdms_directory, show_plot=False)

    # crop any data that falls outside the flexion angle cutoff
    angle_cutoff = 1.0
    flexion_idx = crop_index(kinematics_group, angle_cutoff, 3, baseline=flexion_angle)
    crop_data(kinematics_group, flexion_idx)
    crop_data(kinetics_group, flexion_idx)

    loading_channels = [4, 5, 1]  # Varus Torque, External Rotation Torque, Anterior Drawer
    loading_directions = [1, -1]
    channel_nickname = ['VV', 'EI', 'AP']

    # resampling step and averaging half-width per loading channel
    resampling_increments = [2.5, 1.0, 10.0]
    force_range = 1.0
    torque_range = 0.1
    resampling_ranges = [torque_range, torque_range, force_range]

    # off-axis cropping: Varus Torque, External Rotation Torque, Anterior Drawer, Lateral Drawer
    force_cutoff = 1.5
    torque_cutoff = 0.1
    cutoff_channels = [4, 5, 1, 0]
    cutoff_values = [torque_cutoff, torque_cutoff, force_cutoff, force_cutoff]

    for n, channel in enumerate(loading_channels):
        x_label = 'Applied Load (' + kinetics_group.units[channel] + ')'

        for direction in loading_directions:
            # work on copies so every channel/direction starts from the same data
            kinetics_copy = copy.deepcopy(kinetics_group)
            kinematics_copy = copy.deepcopy(kinematics_group)

            # keep only the loading run for this channel and direction
            idx_keep = extract_idx(kinetics_group, channel, loading_direction=direction)
            crop_data(kinetics_copy, idx_keep)
            crop_data(kinematics_copy, idx_keep)

            # crop by the other channels' force and torque cutoffs
            for cutoff_chan, cutoff in zip(cutoff_channels, cutoff_values):
                if cutoff_chan == channel:  # don't crop on the channel being loaded
                    continue
                cropping_idx = crop_index(kinetics_copy, cutoff, cutoff_chan)
                crop_data(kinetics_copy, cropping_idx)
                crop_data(kinematics_copy, cropping_idx)

            # sort in order of ascending/descending load
            srt_index, srt_data = sorting_index(kinetics_copy, channel, loading_direction=direction)
            sort_data(kinetics_copy, srt_index, srt_data)
            sort_data(kinematics_copy, srt_index, srt_data)

            # resample at loading increments (negative steps for negative loading)
            increments = resampling_increments[n] * direction
            resamp_idx, resamp_intervals = resampling_index(kinetics_copy, channel, increments, resampling_ranges[n])
            resample_data(kinematics_copy, resamp_idx, resamp_intervals)
            resample_data(kinetics_copy, resamp_idx, resamp_intervals)

            # '1' = positive loading direction, '2' = negative
            prefix = file_root + channel_nickname[n] + ('1' if direction == 1 else '2')

            # apply experiment offsets to the kinematics data
            apply_offsets(kinematics_copy, experiment_offsets)

            # save the results in the experiment's conventions (before model offsets etc.)
            plot_groups(prefix + '_kinematics_in_JCS_experiment', kinematics_copy,
                        x_label, tdms_directory, show_plot=False)
            plot_groups(prefix + '_TibiaKinetics_in_TibiaCS_experiment', kinetics_copy,
                        x_label, tdms_directory, show_plot=False)

            # report the axes in the same direction as the model reporting
            change_kinematics_reporting(kinematics_copy)
            change_kinetics_reporting(kinetics_copy)  # external loads on tibia in tibia CS

            # Keep the tibia loads: they are saved below with the same x axis
            # (applied load) that kinetics_tibia_to_femur computes for the femur loads.
            tibia_kinetics_copy = copy.deepcopy(kinetics_copy)

            # report the kinetics as external loads applied to femur in tibia coordinate system
            kinetics_tibia_to_femur(kinetics_copy, kinematics_copy, channel)

            tibia_kinetics_copy.time = kinetics_copy.time
            plot_groups(prefix + '_TibiaKinetics_in_TibiaCS', tibia_kinetics_copy,
                        x_label, tdms_directory, show_plot=False)

            # apply model offsets - Note this is done AFTER changing the signs of the data to register with model outputs.
            apply_offsets(kinematics_copy, -model_offsets)

            plot_groups(prefix + '_kinetics_in_TibiaCS', kinetics_copy,
                        x_label, tdms_directory, show_plot=False)
            plot_groups(prefix + '_kinematics_in_JCS', kinematics_copy,
                        x_label, tdms_directory, show_plot=False)


def laxity_processing_2(groups, experiment_offsets, model_offsets, tdms_directory):
    """
    Process a laxity TDMS recording using the desired kinetics to locate hold points.

    For each loading channel (4 -> 'VV', 5 -> 'EI', 1 -> 'AP') the desired kinetics
    give the indices at the end of the positive and negative "hold" regions. The
    actual kinetics/kinematics are cropped to those indices and plotted against the
    measured load. They are then run through the offset / reporting-convention
    pipeline and plotted again.

    Parameters
    ----------
    groups : list of Group
        Groups read from the TDMS file (see tdms_contents).
    experiment_offsets :
        Offsets added to the kinematics before changing the reporting convention.
    model_offsets :
        Offsets applied negated (i.e. subtracted) after the reporting convention change.
    tdms_directory : str
        Output directory handed to plot_groups.
    """
    kinetics_raw, kinematics_raw = get_kinetics_kinematics(groups)
    desired_kinetics = get_desired_kinetics(groups)

    # Average of kinematics channel 3, rounded to the nearest 10. This is the
    # 'intended' flexion angle and is only used to name the output files.
    mean_flexion = find_average_data(kinematics_raw, 3)
    flexion_tag = int(round(mean_flexion, -1))

    # raw data snapshot (csv + png)
    plot_groups("Laxity_{}deg_Kinematics_raw".format(flexion_tag), kinematics_raw, 'Time (ms)',
                tdms_directory, show_plot=False)
    plot_groups("Laxity_{}deg_Kinetics_raw".format(flexion_tag), kinetics_raw, 'Time (ms)',
                tdms_directory, show_plot=False)

    # (kinetics channel, short label used in output file names)
    loading = [(4, 'VV'), (5, 'EI'), (1, 'AP')]

    for load_chan, nickname in loading:
        # indices at the end of the "flat" regions where the desired load was held steady
        pos_index, neg_index = find_hold_indices(desired_kinetics, load_chan)

        for trial_number, hold_index in enumerate((pos_index, neg_index), start=1):
            # fresh copies each pass so the raw groups are never modified
            kinetics = copy.deepcopy(kinetics_raw)
            kinematics = copy.deepcopy(kinematics_raw)

            crop_data(kinetics, hold_index)
            crop_data(kinematics, hold_index)

            # The x axis of every channel in both groups becomes the measured load.
            # Each entry references kinetics.data[load_chan] itself, not a copy.
            kinetics.time = [kinetics.data[load_chan]] * len(kinetics.data)
            kinematics.time = [kinetics.data[load_chan]] * len(kinematics.data)

            load_label = 'Applied Load (' + kinetics.units[load_chan] + ')'
            file_prefix = 'Laxity_{}deg_'.format(flexion_tag) + nickname + str(trial_number)

            # "experiment" representation published for the data representation step
            plot_groups(file_prefix + '_TibiaKinetics_in_TibiaCS_experiment', kinetics,
                        load_label, tdms_directory, show_plot=False)
            plot_groups(file_prefix + '_kinematics_in_JCS_experiment', kinematics,
                        load_label, tdms_directory, show_plot=False)

            # team workflow: experiment offsets -> reporting convention -> femur loads -> model offsets
            apply_offsets(kinematics, experiment_offsets)
            change_kinematics_reporting(kinematics)
            change_kinetics_reporting(kinetics)  # external loads on tibia, in tibia CS
            kinetics_tibia_to_femur(kinetics, kinematics, load_chan)
            # model offsets go AFTER the sign changes so they register with the model outputs
            apply_offsets(kinematics, -model_offsets)

            # processed data used to build models that replicate the experiment
            plot_groups(file_prefix + '_kinetics_in_TibiaCS', kinetics,
                        load_label, tdms_directory, show_plot=False)
            plot_groups(file_prefix + '_kinematics_in_JCS', kinematics,
                        load_label, tdms_directory, show_plot=False)


def passive_flexion_processing_2(groups, experiment_offsets, model_offsets, tdms_directory,
                                 increments=5.0, angle_window=0.5):
    """
    Process a passive flexion TDMS recording into experiment and model-ready data sets.

    Steps:
      1. save the raw kinematics/kinetics (csv + png via plot_groups),
      2. keep only the flexion phase: samples before the peak of kinematics channel 3,
      3. resample kinematics and kinetics at fixed flexion increments,
      4. save the "experiment" representation,
      5. add experiment offsets, switch to the model reporting convention, express
         the loads as external loads on the femur in the tibia CS, subtract model
         offsets, and save the result.

    Unlike laxity_processing_2, no deep copies are made here. The groups returned
    by get_kinetics_kinematics are cropped and modified directly.

    Parameters
    ----------
    groups : list of Group
        Groups read from the TDMS file (see tdms_contents).
    experiment_offsets :
        Offsets added to the kinematics before the reporting-convention change.
    model_offsets :
        Offsets subtracted from the kinematics after the reporting-convention change.
    tdms_directory : str
        Output directory handed to plot_groups.
    increments : float, optional
        Flexion step passed to resampling_index (default 5.0, i.e. 5 degree steps).
    angle_window : float, optional
        Window around each increment passed to resampling_index (default 0.5).
        This value used to live in a local named ``range``, which shadowed the builtin.
    """
    kinetics_group, kinematics_group = get_kinetics_kinematics(groups)

    # NOTE(review): these raw plots are labelled 'Flexion Angle (deg)', whereas the raw
    # plots in laxity/combined processing use 'Time (ms)'. Confirm which axis the raw
    # groups actually carry at this point.
    plot_groups("Passive_Flexion_Kinematics_Raw", kinematics_group, 'Flexion Angle (deg)', tdms_directory, show_plot=False)
    plot_groups("Passive_Flexion_Kinetics_Raw", kinetics_group, 'Flexion Angle (deg)', tdms_directory, show_plot=False)

    # Separate flexion from extension around the peak of channel 3 (flexion).
    # Samples [0, max_flex_idx) are the flexion phase. The peak sample itself is
    # not kept; it would start the extension phase.
    flexion_kinematics = kinematics_group.data[3]
    max_flex_idx = np.argmax(flexion_kinematics)
    flex_crop_idx = np.arange(0, max_flex_idx)

    # use only the flexion data
    crop_data(kinetics_group, flex_crop_idx)
    crop_data(kinematics_group, flex_crop_idx)

    # Resample at `increments` steps of the flexion angle, averaging each channel
    # where the flexion angle is within `angle_window` of a step.
    resampling_channel = 3
    resamp_idx, resamp_intervals = resampling_index(kinematics_group, resampling_channel, increments, angle_window)
    resample_data(kinematics_group, resamp_idx, resamp_intervals)
    resample_data(kinetics_group, resamp_idx, resamp_intervals)

    # files for the data representation step
    plot_groups("Passive_Flexion_Kinematics_in_JCS_experiment", kinematics_group, 'Flexion Angle (deg)',
                tdms_directory, show_plot=False)
    plot_groups("Passive_Flexion_TibiaKinetics_in_TibiaCS_experiment", kinetics_group, 'Flexion Angle (deg)',
                tdms_directory, show_plot=False)

    # continue processing for our workflow:
    # apply experiment offsets to the kinematics data (add)
    apply_offsets(kinematics_group, experiment_offsets)

    # Report the axes in the same directions as the model reporting.
    # Positive directions are extension, adduction (varus), external.
    change_kinematics_reporting(kinematics_group)
    change_kinetics_reporting(kinetics_group)

    # report the kinetics as external loads applied to the femur in the tibia coordinate system
    kinetics_tibia_to_femur(kinetics_group, kinematics_group)

    # Apply model offsets (subtract). This is done AFTER changing the signs of the
    # data, so the offsets register with the model outputs.
    apply_offsets(kinematics_group, -model_offsets)

    # save the data as csv and png files - this is what is used to model the experimental conditions
    plot_groups("Passive_Flexion_Kinematics_in_JCS", kinematics_group, 'Flexion Angle (deg)', tdms_directory,
                show_plot=False)
    plot_groups("Passive_Flexion_Kinetics_in_TibiaCS", kinetics_group, 'Flexion Angle (deg)', tdms_directory,
                show_plot=False)

def combined_processing(groups, tdms_directory):
    """
    Reduce one combined-loading TDMS file to the single sample at the target loads.

    Using the desired kinetics, the data are successively cropped to where:
      - channel 4 (varus torque) is -10, i.e. valgus = 10 Nm,
      - channel 5 (external torque) is -5, i.e. internal = 5 Nm,
      - channel 1 (anterior drawer) is 0.
    The last remaining sample is then kept. Raw and cropped data are saved via plot_groups.

    Parameters
    ----------
    groups : list of Group
        Groups read from the TDMS file (see tdms_contents).
    tdms_directory : str
        Output directory handed to plot_groups.

    Returns
    -------
    tuple
        (kinetics_group, kinematics_group) after cropping, so results at different
        flexion angles can be merged (see merge_combined_data).
    """
    kinetics, kinematics = get_kinetics_kinematics(groups)
    desired = get_desired_kinetics(groups)

    # Average of kinematics channel 3, rounded to the nearest 10, is the intended
    # flexion angle. It is used for file naming only.
    flexion_tag = int(round(find_average_data(kinematics, 3), -1))
    file_stem = "Combined_{}deg_".format(flexion_tag)

    # raw data snapshots (csv + png)
    raw_outputs = (
        ("Kinetics_desired_raw", desired),
        ("Kinematics_raw", kinematics),
        ("Kinetics_raw", kinetics),
    )
    for suffix, group in raw_outputs:
        plot_groups(file_stem + suffix, group, 'Time (ms)', tdms_directory, show_plot=False)

    # (desired-kinetics channel, target value) pairs applied in order
    targets = ((4, -10), (5, -5), (1, 0))
    tolerance = 1.0e-08

    for chan, target in targets:
        keep = crop_index(desired, tolerance, chan, baseline=target)
        for group in (desired, kinetics, kinematics):
            crop_data(group, keep)

    # keep only the final sample of whatever survived the cropping
    last_sample = [-1]
    crop_data(kinetics, last_sample)
    crop_data(kinematics, last_sample)

    plot_groups(file_stem + "Kinematics_cropped", kinematics, 'Time (ms)',
                tdms_directory, show_plot=False)
    plot_groups(file_stem + "Kinetics_cropped", kinetics, 'Time (ms)', tdms_directory,
                show_plot=False)

    return kinetics, kinematics


def merge_combined_data(combined_data, tdms_directory=None):
    """
    Merge the combined-loading results collected at each flexion angle.

    The data and time of every (kinetics, kinematics) pair are concatenated onto
    the first pair. The result is sorted by increasing flexion (kinematics
    channel 3) and saved via plot_groups.

    Parameters
    ----------
    combined_data : list of tuple
        (kinetics_group, kinematics_group) pairs, as returned by combined_processing.
        The groups of the first pair are modified in place.
    tdms_directory : str, optional
        Output directory handed to plot_groups. Defaults to the current working
        directory; process_tdms_combined chdir()s into the data directory before
        calling this function.

    Returns
    -------
    tuple or None
        The merged (kinetics_group, kinematics_group), or None when there is nothing to merge.
    """
    if not combined_data:
        print("no combined data to merge")
        return None

    # This function used to read an undefined/global `tdms_directory`, which raised
    # NameError unless the script's __main__ block happened to define it.
    if tdms_directory is None:
        tdms_directory = os.getcwd()

    # start from the first group pair and append the others to it
    kinetics_group, kinematics_group = combined_data[0]

    for kin_group, kinem_group in combined_data[1:]:
        kinetics_group.data = np.hstack((kinetics_group.data, kin_group.data))
        kinematics_group.data = np.hstack((kinematics_group.data, kinem_group.data))
        kinetics_group.time = np.hstack((kinetics_group.time, kin_group.time))
        kinematics_group.time = np.hstack((kinematics_group.time, kinem_group.time))

    # sort the data by increasing flexion angle (kinematics channel 3)
    sort_idx, sorting_data = sorting_index(kinematics_group, 3)
    sort_data(kinetics_group, sort_idx, sorting_data)
    sort_data(kinematics_group, sort_idx, sorting_data)

    # plot and save the merged data
    plot_groups("Combined_Kinematics", kinematics_group, 'Time (ms)',
                tdms_directory, show_plot=False)
    plot_groups("Combined_Kinetics", kinetics_group, 'Time (ms)', tdms_directory,
                show_plot=False)

    # For further processing (add experiment offsets, subtract model offsets, change
    # reporting conventions) see the passive flexion and laxity processing functions.
    return kinetics_group, kinematics_group

def process_tdms_files(file_directory, ModelProperties):
    """
    Run the calibration processing (passive flexion and laxity) on a directory of TDMS files.

    All .tdms files and the state (.cfg) file must sit directly in file_directory
    (no subfolders). The working directory is changed to file_directory and the
    TDMS files are opened by bare file name.

    Files whose lower-cased name contains 'passive flexion' go to
    passive_flexion_processing_2. Files containing 'laxity' go to laxity_processing_2.
    Any other .tdms file is skipped and not read at all.

    Parameters
    ----------
    file_directory : str
        Directory holding the .tdms files and the state .cfg file.
    ModelProperties : str
        Path handed to find_model_offsets (a ModelProperties.xml file in __main__).

    Channel layout of the oks003 data:
        Kinetics:   0 lateral drawer, 1 anterior drawer, 2 distraction,
                    3 extension torque, 4 varus torque, 5 external torque
        Kinematics: 0 medial, 1 posterior, 2 superior,
                    3 flexion, 4 valgus, 5 internal
    """
    # sort through the files: state (.cfg) file vs tdms files
    tdms_files = []
    State_file = None
    os.chdir(file_directory)
    for file in os.listdir(file_directory):
        if file.endswith('.cfg'):
            # If several .cfg files exist, the last one listed wins. If there is none,
            # None is passed to get_offsets.
            State_file = file
        elif file.endswith('.tdms'):
            tdms_files.append(file)

    # experiment offsets and model offsets - both returned in mm, deg
    experiment_offsets = get_offsets(State_file, file_directory)
    model_offsets = find_model_offsets(ModelProperties)

    for file in tdms_files:
        name = file.lower()

        # Only parse files that are actually processed. Reading every TDMS file up
        # front wasted time on files (e.g. combined loading) that were then ignored.
        # passive_flexion_processing / laxity_processing (the original versions) were
        # used in the initial knee hub calibration; the _2 versions replace them.
        if 'passive flexion' in name:
            groups = tdms_contents(file)
            passive_flexion_processing_2(groups, experiment_offsets, model_offsets, file_directory)
        elif 'laxity' in name:
            groups = tdms_contents(file)
            laxity_processing_2(groups, experiment_offsets, model_offsets, file_directory)


def process_tdms_combined(file_directory):
    """
    Process every combined-loading TDMS file in a directory and merge the results.

    The working directory is changed to file_directory. Only .tdms files whose
    lower-cased name contains 'combined' are read, each via combined_processing.
    The per-file results are then merged by merge_combined_data.

    Parameters
    ----------
    file_directory : str
        Directory holding the .tdms files and the state (.cfg) file (no subfolders).
    """
    # sort through the files: state (.cfg) file vs tdms files
    tdms_files = []
    State_file = None
    os.chdir(file_directory)
    for file in os.listdir(file_directory):
        if file.endswith('.cfg'):
            # the last .cfg listed wins if several exist
            State_file = file
        elif file.endswith('.tdms'):
            tdms_files.append(file)

    # Called only so the kinematic offsets file is saved in the directory; the
    # returned offsets are not needed for combined processing.
    get_offsets(State_file, file_directory)

    # (kinetics_group, kinematics_group) per combined-loading file, merged afterwards
    combined_data = []

    for file in tdms_files:
        # skip non-combined files before parsing them, instead of reading every TDMS file
        if 'combined' not in file.lower():
            continue
        groups = tdms_contents(file)
        kinetics_group, kinematics_group = combined_processing(groups, file_directory)
        combined_data.append((kinetics_group, kinematics_group))

    # further processing (adding offsets, etc.) belongs in merge_combined_data,
    # after everything has been merged into one group
    merge_combined_data(combined_data)


if __name__ == "__main__":

    # Usage: python <this script> [tdms_directory [ModelProperties.xml]]
    # With no arguments, the hard-coded test paths below are used (unchanged defaults).
    # If only tdms_directory is given, ModelProperties.xml is looked up inside it.
    # All tdms files and the state (.cfg) file must be directly in tdms_directory,
    # not in subfolders.

    # laxity and passive flexion processing for calibration
    if len(sys.argv) > 1:
        tdms_directory = sys.argv[1]
        if len(sys.argv) > 2:
            Model_Properties = sys.argv[2]
        else:
            Model_Properties = os.path.join(tdms_directory, "ModelProperties.xml")
    else:
        tdms_directory = "C:/oks/app/KneeHub/test/tdms"
        Model_Properties = "C:/oks/app/KneeHub/test/tdms/ModelProperties.xml"
    process_tdms_files(tdms_directory, Model_Properties)

    # combined loading for benchmarking:
    # process_tdms_combined(tdms_directory)
