#!/usr/bin/env python
"""
Copyright (c) 2015 Computational Biomechanics (CoBi) Core, Department of
Biomedical Engineering, Cleveland Clinic

Permission is hereby granted, free of charge, to any person obtaining a
copy of this software and associated documentation files (the
"Software"), to deal in the Software without restriction, including
without limitation the rights to use, copy, modify, merge, publish,
distribute, sublicense, and/or sell copies of the Software, and to permit
persons to whom the Software is furnished to do so, subject to the
following conditions:

The above copyright notice and this permission notice shall be included
in all copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN
NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM,
DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR
OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE
USE OR OTHER DEALINGS IN THE SOFTWARE.
-------------------

tdms_plotting.py

DESCRIPTION:

Python script to read a tdms file and create plots for data. Places plots in 
new directory with the path of the tdms file in svg and png formats. Processes
kinetic and kinematic data, which is also plotted. One must place the names of 
the tdms files in a text file and pass that text file as the argument for the 
script to run properly. 
Finally, the script can compare AP laxity data for reproducibility testing. 
If performing this analysis, place all relevant files in text file, include as 
argument when running script, and enter the value '1' when prompted. This will
produce plots comparing all three AP laxity tests and a table of RMSD values 
and normalized % error for kinematic and kinetic data for the three AP laxity 
tests.

REQUIREMENTS:

Python (http://www.python.org)
nptdms (https://pypi.python.org/pypi/npTDMS/)
matplotlib (http://matplotlib.org/)
NumPy (http://www.numpy.org/)

DEVELOPERS:

Omar M. Gad & Ahmet Erdemir
Computational Biomodeling (CoBi) Core
Department of Biomedical Engineering
Lerner Research Institute
Cleveland Clinic
Cleveland, OH
gado@ccf.org
erdemira@ccf.org
"""
import sys
from nptdms import TdmsFile
import numpy as np
import os

# Print the usage message: the script expects a single argument, a text file
# that lists the tdms files to process.
def USAGE(argv):
    """Print a one-line usage message framed by blank lines.

    argv is the command-line argument list; only argv[0] (the script name)
    is used to build the message.
    """
    # Single-argument, parenthesised print calls produce the same output
    # under the Python 2 print statement and the Python 3 print function.
    print('')
    print('USAGE: ' + argv[0] + ' <text file with tdms files>')
    print('')
"""
This function extracts all the information and data from each tdms file. Then, 
a function (tdms_plot) plots the raw data. Desired kinetic data is processed
via the function find_indices. The indices it returns are used (by the function
extract_data) to create processed plots for the actual kinetic and kinematic 
data from the tdms file.
"""    
def tdms_contents(argv):
    """Plot every group of each listed tdms file and post-process its data.

    argv is the command line: argv[0] is the script name and argv[1] the path
    of a text file holding whitespace-separated tdms file paths. If argv does
    not have exactly two entries, the usage message is printed and the script
    exits with status 1.

    For each tdms file, every group for which exactly six channels could be
    read is plotted with tdms_plot. The indices returned by find_indices for
    the 'Kinetics.JCS.Desired' group are then passed to extract_data for the
    'State.JCS Load', 'State.Knee JCS', 'Kinetics.JCS.Desired' and
    'State.Knee PTFJ' groups. The output directory handed to those functions
    is the absolute path of the tdms file without its extension.

    Groups that are absent from a file are reported and skipped instead of
    raising NameError or reusing the data of a previously processed file.
    """
    # Exactly one argument (the list file) must follow the script name.
    if len(argv) != 2:
        USAGE(argv)
        sys.exit(1)
    # Read the tdms paths; the context manager closes the list file as soon
    # as its content has been read.
    with open(argv[-1]) as list_file:
        files = list_file.read().split()
    # Everything collected here is raw data, so it is flagged as unprocessed.
    processed = False
    # Groups handed to extract_data, in the order they are processed.
    extracted_groups = ('State.JCS Load', 'State.Knee JCS',
                        'Kinetics.JCS.Desired', 'State.Knee PTFJ')
    for tdms_path in files:
        tdms_file = TdmsFile(tdms_path)
        file_name = os.path.split(tdms_path)[1]
        groups = tdms_file.groups()
        path = os.path.abspath(tdms_path)
        title = os.path.splitext(file_name)[0]
        root = os.path.splitext(path)[0]
        # Start each file from a clean slate so that a group missing from
        # this file can never silently pick up data from a previous file.
        group_info = {}
        index_list = None
        for group in groups:
            channel_list = []
            channel_data_list = []
            channel_unit_list = []
            channel_time_list = []
            # Collect label, units, samples and time track of each channel.
            for channel in tdms_file.group_channels(group):
                try:
                    channel_label = channel.properties["NI_ChannelName"]
                    channel_units = channel.properties['NI_UnitDescription']
                    channel_data = channel.data
                    channel_time = channel.time_track()
                except Exception:
                    # Best effort, as before: the first channel that cannot
                    # be read (e.g. a missing property) ends the scan of this
                    # group; channels read so far are kept. Unlike a bare
                    # except, KeyboardInterrupt/SystemExit still propagate.
                    break
                channel_data_list.append(channel_data)
                channel_list.append(channel_label)
                channel_unit_list.append(channel_units)
                channel_time_list.append(channel_time)
            # Remember the groups that are re-sampled after the scan.
            if group in extracted_groups:
                group_info[group] = [title, group, channel_unit_list,
                                     channel_data_list, channel_time_list,
                                     channel_list, root, processed]
            # tdms_plot lays out exactly six channels (3 + 3).
            if len(channel_data_list) == 6:
                tdms_plot(title, group, channel_unit_list,
                          channel_data_list, channel_time_list,
                          channel_list, root, processed)
            if group == 'Kinetics.JCS.Desired':
                index_list = find_indices(channel_list, channel_data_list,
                                          channel_time_list, title, group,
                                          channel_unit_list, root)
        print('Plots and data can be found in the following directory: '
              + root + '\n')
        # Without the desired kinetics there are no indices to extract at.
        if index_list is None:
            print("Skipping data extraction for " + title
                  + ": group 'Kinetics.JCS.Desired' was not found.")
            continue
        for name in extracted_groups:
            if name in group_info:
                extract_data(index_list, group_info[name])
            else:
                print("Group '" + name + "' was not found in " + title
                      + "; skipping its data extraction.")
"""
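Plots the data of one group (six channels) from a tdms file. Saves the plot, 
in png and svg formats, in the folder given by root (the path of the tdms file 
without its extension), creating it if needed. If data is processed, the file 
name of the plot will include '_Extracted'.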
"""                            
import matplotlib.pyplot as plt

def tdms_plot(title, groups, units, data, time, channels, root, processed):
    """Draw six channels on a two-panel figure and save it as png and svg.

    Channels 0-2 are drawn on the left panel and channels 3-5 on the right
    one. When processed equals True each channel is drawn with its entry of
    the marker format list and the file name gets an '_Extracted' suffix;
    otherwise the default plot style is used. The figure is titled with
    title, each panel with groups, and the files are written to
    root + '/' + groups (root is created when it does not exist). The figure
    is closed before returning.
    """
    markers = ['s', 'D', 'o', 's', 'D', 'o']
    extracted = processed == True
    figure = plt.figure(figsize=(12, 8))
    figure.suptitle(title, fontsize=16)
    # Two side-by-side panels on the same figure.
    left_axis = plt.subplot(121)
    right_axis = plt.subplot(122)
    panels = (left_axis, right_axis)
    # Each panel takes three consecutive channels, starting at 0 and at 3.
    for axis, first in zip(panels, (0, 3)):
        for n in range(first, first + 3):
            axis.set_ylabel(units[n], fontsize=12)
            if extracted:
                axis.plot(time[n], data[n], markers[n], label=channels[n])
            else:
                axis.plot(time[n], data[n], label=channels[n])
    for axis in panels:
        axis.set_xlabel('Time (ms)', fontsize=12)
        axis.legend(loc='best')
        axis.set_title(groups, fontsize=12)
        axis.grid(True)
    if not os.path.exists(root):
        os.makedirs(root)
    # Common file stem; extracted data is kept apart by its suffix.
    stem = root + '/' + groups
    if extracted:
        stem = stem + '_Extracted'
    figure.savefig(stem + '.png', format="png", dpi=100)
    figure.savefig(stem + '.svg', format="svg")

    plt.close(figure)
"""
The function find_indices determines the indices at which consecutive values 
are the same in the desired kinetics data. Returns index_list, which holds 
(once per channel) the sorted final indices of the windows in which the values 
are the same. The data at these indices is then extracted by extract_data and 
plotted using the function tdms_plot. One may use the start point instead, or 
find the midpoint by creating a range of values using the start point and end 
point of each window.
"""
def _plateau_end_indices(values, tolerance=0.000000001):
    """Return the end index of each run of repeated values in values.

    A run starts where a sample and its successor differ by less than
    tolerance; it extends while later samples stay within tolerance of the
    run's FIRST sample, and the index of its last sample is recorded. When
    the scan reaches the second-to-last sample outside a run, the last index
    is recorded unconditionally. Sequences with fewer than two samples give
    an empty list.
    """
    last = len(values) - 1
    ends = []
    i = 0
    while i < last:
        if i + 1 == last:
            # Only the final sample is left: it closes the last window.
            ends.append(last)
            break
        if not abs(values[i] - values[i + 1]) < tolerance:
            # No run starts here; move on to the next sample.
            i += 1
            continue
        # A run starts at i: advance k to the first sample leaving the run,
        # or to the last sample if the run never ends before it.
        k = i + 2
        while k < last and abs(values[i] - values[k]) < tolerance:
            k += 1
        if abs(values[i] - values[k]) < tolerance:
            # Only reachable with k == last: the run extends to the end.
            ends.append(k)
        else:
            ends.append(k - 1)
        i = k
    return ends


def find_indices(channel_list, channel_data_list, channel_time_list,
                  title, group, channel_unit_list, root):
    """Find the sample indices at which runs of repeated values end.

    Every channel in channel_data_list (paired with its label from
    channel_list) is scanned for runs of values that agree to within 1e-9;
    the end indices of all channels are merged, de-duplicated and sorted.
    channel_time_list, title, group, channel_unit_list and root are accepted
    for interface compatibility but are not used.

    Returns a list holding the same sorted index list six times, one entry
    per channel expected by extract_data. Channels with fewer than two
    samples contribute nothing, and an empty channel list yields six empty
    lists (both cases previously raised UnboundLocalError).
    """
    merged = set()
    # zip() is kept so that, as before, only channels having a label are
    # taken into account.
    for channel_data, _label in zip(channel_data_list, channel_list):
        merged.update(_plateau_end_indices(channel_data))
    indices = sorted(merged)
    # The same index list applies to each of the six channels.
    return [indices] * 6
"""
Extracts data from other groups and channels for plotting, based on the window 
end indices found by find_indices.
"""
def extract_data(index_list, channel_infoPF):
    """Sample one group's channels at the given indices, plot and save them.

    index_list holds one sequence of sample indices per channel.
    channel_infoPF is the list
    [title, group, units, data, time, channel names, root, processed];
    its 'processed' entry is ignored because the data produced here is always
    plotted as processed.

    The sampled data is plotted with tdms_plot (title suffixed with
    '_processed') and written, together with the sampled time points of the
    first channel, to root + '/' + group + '.txt'. A confirmation line is
    printed at the end.
    """
    (title, group, channel_unit_list, channel_data_list,
     channel_time_list, channel_list, root) = channel_infoPF[:7]
    title_p = title + '_processed'
    data_processed = []
    time_processed = []
    # Pick, for every channel, the samples and time stamps at its indices.
    # Building the arrays from comprehensions keeps them defined even when a
    # channel has no indices.
    for index_end, channel_data, channel_time in zip(index_list,
                                                     channel_data_list,
                                                     channel_time_list):
        data_processed.append(
            np.array([channel_data[index] for index in index_end]))
        time_processed.append(
            np.array([channel_time[index] for index in index_end]))
    # Plot the extracted points; this also creates root when it is missing.
    tdms_plot(title_p, group, channel_unit_list, data_processed,
              time_processed, channel_list, root, True)
    # The context manager closes the text file even if a write fails.
    with open(root + '/' + group + '.txt', 'w') as fd:
        fd.write('Text file for extracted data for each experimental test.\n\n')
        fd.write('file = ' + title + '\n\n')
        fd.write('group = ' + group + '\n\n')
        fd.write('Extracted Time Points [ms]\n')
        fd.write(str(time_processed[0]) + '\n\n')
        for label, unit, values in zip(channel_list, channel_unit_list,
                                       data_processed):
            fd.write(label + '  ' + '[' + unit + ']\n')
            fd.write(str(values) + '\n\n')
        fd.write('\n')

    # Parenthesised single-argument print works under Python 2 and 3.
    print('Data has been processed for ' + group + '.')

#make this script callable from the terminal, and run all of its functions   
# Entry point: runs only when the file is executed as a script, handing the
# raw command-line arguments to tdms_contents.
if __name__ == '__main__':
    tdms_contents(sys.argv)