#!/usr/bin/env python
"""
Copyright (c) 2015 Computational Biomechanics (CoBi) Core, Department of
Biomedical Engineering, Cleveland Clinic

Permission is hereby granted, free of charge, to any person obtaining a
copy of this software and associated documentation files (the
"Software"), to deal in the Software without restriction, including
without limitation the rights to use, copy, modify, merge, publish,
distribute, sublicense, and/or sell copies of the Software, and to permit
persons to whom the Software is furnished to do so, subject to the
following conditions:

The above copyright notice and this permission notice shall be included
in all copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN
NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM,
DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR
OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE
USE OR OTHER DEALINGS IN THE SOFTWARE.
-------------------

tdms_plotting.py

DESCRIPTION:

Python script to read a tdms file and create plots for data. Places plots in 
new directory with the path of the tdms file in svg and png formats. Processes
kinetic and kinematic data, which is also plotted. One must place the names of
the tdms files in a text file and pass that text file as the last argument for
the script to run properly.
Finally, the script can compare AP laxity data for reproducibility testing. 
If performing this analysis, place all relevant files in text file, include as 
argument when running script, and enter the value '1' when prompted. This will
produce plots comparing all three AP laxity tests and a table of RMSD values 
with minimums, maximums, and means for each axis for kinematic and kinetic data 
for the three AP laxity tests.

REQUIREMENTS:

Python (http://www.python.org)
nptdms (https://pypi.python.org/pypi/npTDMS/)
matplotlib (http://matplotlib.org/)
NumPy (http://www.numpy.org/)
pandas (https://pandas.pydata.org/)

DEVELOPERS:

Omar M. Gad & Ahmet Erdemir
Computational Biomodeling (CoBi) Core
Department of Biomedical Engineering
Lerner Research Institute
Cleveland Clinic
Cleveland, OH
gado@ccf.org
erdemira@ccf.org

Modified by Erica E. Neumann to be used with animation/plotting script.
Computational Biomodeling (CoBi) Core
Department of Biomedical Engineering
Lerner Research Institute
Cleveland Clinic
Cleveland, OH
morrile2@ccf.org
"""

import sys
from nptdms import TdmsFile
import numpy as np
import os
import pandas as pd

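"""
This function extracts all the information and data from the tdms files listed
in the text file given as the last argument. The raw data of the desired groups
is plotted by the function tdms_plot. The time indices of interest are
determined by the function find_indices (from the desired kinetics data, or
from the desired kinematics data for passive flexion files). These indices are
used by the function extract_data to create processed plots and a csv file for
the actual kinetic and kinematic data from the tdms file.
"""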
# Groups whose six channels are plotted raw and extracted to the csv file.
_PLOTTED_GROUPS = ('State.JCS Load', 'State.Knee JCS', 'Kinetics.JCS.Desired',
                   'Kinematics.JCS.Desired', 'State.JCS', 'State.Knee PTFJ')


def _read_group_channels(tdms_file, group):
    """Read label, data, unit and time track of the channels of one group.

    Reading stops at the first channel that cannot be read completely (for
    example one without the "NI_ChannelName" or 'NI_UnitDescription'
    property), so the four returned lists always have the same length.

    Returns the tuple (labels, data, units, times).
    """
    labels, data, units, times = [], [], [], []
    for channel in tdms_file.group_channels(group):
        try:
            label = channel.properties["NI_ChannelName"]
            unit = channel.properties['NI_UnitDescription']
            samples = channel.data
            time_track = channel.time_track()
        except Exception:
            # Best effort, as before: keep what was read so far. Unlike a
            # bare "except", this lets KeyboardInterrupt/SystemExit through.
            break
        labels.append(label)
        data.append(samples)
        units.append(unit)
        times.append(time_track)
    return labels, data, units, times


def tdms_contents(argv):
    """Plot and extract the data of every tdms file listed in a text file.

    argv -- argument list; its last item is the path of a text file holding
            one tdms file path per line (blank lines are ignored).

    For each tdms file:
      1. every group in _PLOTTED_GROUPS that has exactly six channels is
         plotted raw with tdms_plot;
      2. the time indices of interest are computed with find_indices from
         'Kinematics.JCS.Desired' when the file path contains
         "passive flexion" (case-insensitive), otherwise from
         'Kinetics.JCS.Desired';
      3. the samples at those indices are extracted (and plotted) with
         extract_data for each six-channel group in _PLOTTED_GROUPS and
         written side by side to a csv file next to the tdms file.

    A file that lacks the group needed in step 2 is reported and skipped
    instead of silently reusing the indices of the previous file.

    Returns the index list of the last file for which one was computed, or
    None when no file produced one.
    """
    # Context manager so the list file is closed even if reading fails.
    with open(argv[-1]) as list_file:
        tdms_paths = [line for line in list_file.read().splitlines()
                      if line.strip()]

    processed = False
    index_list = None

    for tdms_path in tdms_paths:
        tdms_file = TdmsFile(tdms_path)

        title = os.path.splitext(os.path.split(tdms_path)[1])[0]
        # Output folder for the plots: absolute tdms path minus extension.
        root = os.path.splitext(os.path.abspath(tdms_path))[0]

        passive_flexion = "passive flexion" in tdms_path.lower()
        if passive_flexion:
            index_group = 'Kinematics.JCS.Desired'
        else:
            index_group = 'Kinetics.JCS.Desired'

        # Read each relevant group once; both passes below reuse the result.
        group_records = [(group,) + _read_group_channels(tdms_file, group)
                         for group in tdms_file.groups()
                         if group in _PLOTTED_GROUPS]

        # Pass 1: raw plots and determination of the indices of interest.
        file_index_list = None
        for group, labels, data, units, times in group_records:
            if len(data) == 6:
                tdms_plot(title, group, units, data, times, labels, root,
                          processed)

            if group == index_group:
                if not passive_flexion:
                    print(tdms_path.lower())
                file_index_list = find_indices(labels, data, times, title,
                                               group, units, root)

        if file_index_list is None:
            print("Skipping extraction for %s: group '%s' not found"
                  % (tdms_path, index_group))
            continue
        index_list = file_index_list

        # Pass 2: processed plots and csv export.
        df_master = pd.DataFrame()
        for group, labels, data, units, times in group_records:
            if len(data) != 6:
                # extract_data/tdms_plot index exactly six channels; any
                # other count used to end in an IndexError.
                print("Skipping group '%s' of %s: expected 6 channels, "
                      "found %d" % (group, tdms_path, len(data)))
                continue
            df = extract_data(file_index_list, title, group, units, data,
                              times, labels, root, processed)
            df_master = pd.concat([df_master, df], axis=1)

        # splitext instead of slicing off a fixed number of characters, so
        # the csv name is right whatever the length of the extension.
        df_master.to_csv(os.path.splitext(tdms_path)[0] + '.csv')

    return index_list

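"""
Plots the six channels of one group of the tdms file: the first three channels
on the left axes and the last three on the right axes. Saves the plot in png
and svg formats in a folder named after the tdms file (the path of the tdms
file without its extension), which is created if it does not exist. If data is
processed, file for plot will include '_Extracted' in name.
"""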
import matplotlib.pyplot as plt

def tdms_plot(title, groups, units, data, time, channels, root, processed):
    """Draw six channels of one group side by side and save the figure.

    title    -- figure title
    groups   -- group name; used as the title of both axes and as the base
                name of the saved files
    units, data, time, channels
             -- per-channel sequences (at least six items each); items 0-2
                go to the left axes, items 3-5 to the right axes
    root     -- output folder, created when it does not exist
    processed -- when equal to True the samples are drawn with markers only
                and '_Extracted' is appended to the file names

    The figure is written as <root>/<name>.png (100 dpi) and
    <root>/<name>.svg and then closed.
    """
    markers = ['s', 'D', 'o', 's', 'D', 'o']
    # Evaluated once; same comparison the function has always used.
    with_markers = (processed == True)  # noqa: E712

    figure = plt.figure(figsize=(12,8))
    figure.suptitle(title, fontsize=16)

    # Two axes in one figure
    left_axes = plt.subplot(121)
    right_axes = plt.subplot(122)

    layout = ((left_axes, range(3)), (right_axes, range(3, 6)))
    for axes, channel_ids in layout:
        for idx in channel_ids:
            # The label of the last channel drawn on the axes wins.
            axes.set_ylabel(units[idx], fontsize=12)
            if with_markers:
                axes.plot(time[idx], data[idx], markers[idx],
                          label = channels[idx])
            else:
                axes.plot(time[idx], data[idx], label = channels[idx])

    for axes in (left_axes, right_axes):
        # NOTE(review): extract_data labels the same time values with [s];
        # confirm which unit is correct.
        axes.set_xlabel('Time (ms)', fontsize=12)
        axes.legend(loc = 'best')
        axes.set_title(groups, fontsize = 12)
        axes.grid(True)

    if not os.path.exists(root):
        os.makedirs(root)

    # Saves files as .png and .svg
    suffix = '_Extracted' if with_markers else ''
    target = root + '/' + groups + suffix
    figure.savefig(target + '.png', format = "png", dpi = 100)
    figure.savefig(target + '.svg', format = "svg")

    plt.close(figure)

    
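"""
The function find_indices determines the time indices of interest. For a
kinetics group, it determines the index at which consecutive values are the
same from the desired kinetics data and uses the final index at which the
values are the same. For a kinematics group, it uses the indices at which the
flexion channel is within 0.01 of each 5 unit increment. Returns a list that
holds the same sorted list of indices six times (one per channel). This data is
then plotted using the function tdms_plot. One may use the start point, or find
midpoint by creating range of values using start point and end point of each list.
"""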
def _constant_windows(samples, tolerance=0.000000001):
    """Return the (start, end) index pairs of the constant runs in samples.

    A run starts at index i when sample i + 1 is within tolerance of sample
    i, and extends over all following samples within tolerance of sample i.
    A run is never started on the second to last sample, so a plateau made
    of only the final two samples is not reported.
    """
    windows = []
    last = len(samples) - 1
    i = 0
    while i < last - 1:
        # "not (... < tolerance)" rather than ">=" so that NaN samples are
        # treated as a change, exactly like a failed "<" comparison.
        if not abs(samples[i] - samples[i + 1]) < tolerance:
            i += 1
            continue
        k = i + 2
        while k <= last and abs(samples[i] - samples[k]) < tolerance:
            k += 1
        # k is the first sample that left the run (or one past the end).
        windows.append((i, k - 1))
        i = min(k, last)
    return windows


def find_indices(channel_list, channel_data_list, channel_time_list,
                  title, group, channel_unit_list, root):
    """Return the sample indices at which data should be extracted.

    channel_list      -- channel labels of the group
    channel_data_list -- one sequence of samples per channel
    group             -- group name; selects the strategy, see below
    channel_time_list, title, channel_unit_list, root
                      -- accepted for interface compatibility, not used

    'Kinetics' in group: every channel is scanned for runs of constant
    value (see _constant_windows); the result is the sorted union, over all
    channels, of the last index of each run. The start index of each run is
    available from _constant_windows should the start or the midpoint of a
    run be preferred. Channels that are too short to hold a run contribute
    nothing.

    'Kinematics' in group: the first channel whose label contains "Flexion"
    is searched for samples within 0.01 of 0, 5, 10, ... (below its
    maximum). An increment matched by fewer than ten samples contributes
    its first and last matching index, otherwise the integer average of the
    matching indices. Increments matched by no sample are skipped.

    Returns a list holding the same sorted index list six times, one entry
    per channel.

    Raises ValueError when group is neither a kinetics nor a kinematics
    group, and IndexError when a kinematics group has no "Flexion" channel.
    """
    if 'Kinetics' in group:
        window_ends = set()
        for channel_data, _label in zip(channel_data_list, channel_list):
            window_ends.update(
                end for _start, end in _constant_windows(channel_data))
        indices = sorted(window_ends)

    elif 'Kinematics' in group:
        flex_idx = [i for i, x in enumerate(channel_list) if "Flexion" in x][0]
        data = np.array(channel_data_list[flex_idx])
        indices = []
        for angle in np.arange(0, np.max(data), 5):
            matches = np.where(np.abs(data - angle) < 0.01)[0]
            if len(matches) == 0:
                # No sample close enough to this increment; nothing to pick.
                continue
            if len(matches) < 10:
                indices.append(matches[0])
                indices.append(matches[-1])
            else:
                indices.append(int(np.average(matches)))
        indices = sorted(indices)

    else:
        raise ValueError("find_indices expects a 'Kinetics' or 'Kinematics' "
                         "group, got %r" % (group,))

    # Same key "time index" list for each of the six channels.
    return [indices] * 6
           
    
"""
Extracts data from other groups and channels for plotting based on the time
indices found by find_indices. Returns the extracted time points and data as a
pandas DataFrame.
"""
def extract_data(index_list, title, group, channel_unit_list,
                 channel_data_list, channel_time_list, channel_list, root, processed):
    """Pick the samples at the given indices, plot them and tabulate them.

    index_list        -- one sequence of sample indices per channel
    title             -- base title; '_extracted' is appended for the plot
    group             -- group name, forwarded to tdms_plot
    channel_unit_list -- unit of each channel
    channel_data_list -- samples of each channel
    channel_time_list -- time track of each channel
    channel_list      -- label of each channel
    root              -- output folder, forwarded to tdms_plot
    processed         -- ignored; the plot is always made as processed data

    Returns a DataFrame whose first column, 'Extracted Time Points [s]',
    holds the extracted times of the first channel, followed by one column
    per channel named '<label>  [<unit>]'.
    """
    data_processed = []
    time_processed = []

    # Unpack channels; each array is built once from the picked samples
    # (an empty index sequence simply gives empty arrays).
    for index_end, channel_data, channel_time in zip(
            index_list, channel_data_list, channel_time_list):
        data_processed.append(
            np.array([channel_data[index] for index in index_end]))
        time_processed.append(
            np.array([channel_time[index] for index in index_end]))

    tdms_plot(title + '_extracted', group, channel_unit_list, data_processed,
              time_processed, channel_list, root, True)

    df = pd.DataFrame()

    df['Extracted Time Points [s]'] = time_processed[0]
    for k, label in enumerate(channel_list):
        df[label + '  ' + '[' + channel_unit_list[k] + ']'] = data_processed[k]

    return df

    
if __name__ == "__main__":
    # tdms_contents reads the list of tdms files from the last argument.
    # Without one, that last argument is this script itself, whose lines
    # would then be treated as tdms file paths.
    if len(sys.argv) < 2:
        sys.exit("usage: %s <text file listing tdms files>"
                 % os.path.basename(sys.argv[0]))
    tdms_contents(sys.argv)