#!/usr/bin/env python
"""
Copyright (c) 2015 Computational Biomechanics (CoBi) Core, Department of
Biomedical Engineering, Cleveland Clinic

Permission is hereby granted, free of charge, to any person obtaining a
copy of this software and associated documentation files (the
"Software"), to deal in the Software without restriction, including
without limitation the rights to use, copy, modify, merge, publish,
distribute, sublicense, and/or sell copies of the Software, and to permit
persons to whom the Software is furnished to do so, subject to the
following conditions:

The above copyright notice and this permission notice shall be included
in all copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN
NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM,
DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR
OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE
USE OR OTHER DEALINGS IN THE SOFTWARE.
-------------------

tdms_plotting.py

DESCRIPTION:

Python script to read tdms files and create plots of their data. Plots are
saved in .svg and .png formats in a new directory named after each tdms file
(the tdms file path without its extension). Kinetic and kinematic data are also
processed, and the processed data are plotted. List the names of the tdms files,
one per line, in a text file and pass that text file as the only argument when
running the script.
Finally, the script can compare AP laxity data for reproducibility testing.
To run this analysis, list exactly three tdms files in the text file, pass it
as the argument when running the script, and enter the value '1' when prompted.
This produces plots comparing the three AP laxity tests and a table of RMSD
values, together with the maximum, minimum, and mean of each axis, for the
kinematic and kinetic data of the three AP laxity tests.

REQUIREMENTS:

Python (http://www.python.org)
nptdms (https://pypi.python.org/pypi/npTDMS/)
matplotlib (http://matplotlib.org/)
NumPy (http://www.numpy.org/)

DEVELOPERS:

Omar M. Gad & Ahmet Erdemir
Computational Biomodeling (CoBi) Core
Department of Biomedical Engineering
Lerner Research Institute
Cleveland Clinic
Cleveland, OH
gado@ccf.org
erdemira@ccf.org
"""

import sys
from nptdms import TdmsFile
import numpy as np
import os

def USAGE(argv):
    """Write the normal-mode usage line, framed by blank lines, to stdout."""
    usage_line = 'USAGE: ' + argv[0] + ' <text file with tdms files>'
    sys.stdout.write('\n' + usage_line + '\n' + '\n')

def USAGE1(argv):
    """Write the reproducibility-mode usage line, framed by blank lines, to stdout."""
    usage_line = ('USAGE: ' + argv[0] +
                  ' <text file with three tdms files for reproducibility testing> ')
    sys.stdout.write('\n' + usage_line + '\n' + '\n')
    
"""
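This function extracts the channel information and data from each tdms file
listed in the text file. Every group with six channels is plotted as raw data
with tdms_plot. The desired kinetics group (Kinetics.JCS.Desired) is passed to
find_indices, which finds the last index of each run of constant values; the
function extract_data then uses those indices to create processed plots for the
desired kinetics and for the actual kinetic and kinematic data (State.JCS Load
and State.Knee JCS).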
"""    
def tdms_contents(argv):
    """Plot, extract and optionally compare the tdms files listed in a text file.

    argv is the command-line vector; argv[-1] must name a text file listing
    one tdms file path per line. For every file, each six-channel group is
    plotted with tdms_plot, the Kinetics.JCS.Desired group is scanned with
    find_indices, and the State.JCS Load, State.Knee JCS and
    Kinetics.JCS.Desired groups are sampled at those indices by extract_data.
    If the user enters '1' at the prompt, exactly three files are expected and
    the reproducibility outputs (repro_test, repro_ap_plot, repro_plot) are
    produced as well.

    Exits with status 1 on a wrong argument count or, in reproducibility mode,
    when the list does not contain exactly three files. Raises ValueError when
    a group must be extracted but its file has no Kinetics.JCS.Desired group,
    or when reproducibility mode finds no load or kinematic data.
    """
    if len(argv) != 2:
        USAGE(argv)
        sys.exit(1)

    # Read the file list; 'with' closes the list file right away.
    with open(argv[-1]) as f:
        files = f.read().splitlines()
    processed = False
    repro_testing = raw_input('Enter the value 1 and enter for reproducibility testing. Otherwise hit enter.\n')
    print('')

    if repro_testing == '1' and len(files) != 3:
        USAGE1(argv)
        sys.exit(1)

    # Per-file results collected for the reproducibility analysis.
    group_loads = []
    group_kinematics = []
    time_repro = []
    title_repro = []
    range_loads = []
    range_kine = []
    data_info_loads = []
    data_info_kine = []

    # Groups whose channels are sampled at the desired-kinetics end indices.
    extracted_groups = ('State.JCS Load', 'State.Knee JCS',
                        'Kinetics.JCS.Desired')

    for file in files:

        tdms_file = TdmsFile(file)
        file_name = os.path.split(file)[1]
        groups = tdms_file.groups()

        path = os.path.abspath(file)
        title = os.path.splitext(file_name)[0]
        # Output folder: the tdms path without its extension.
        root = os.path.splitext(path)[0]
        base = os.path.split(root)[0]

        # The end indices come from Kinetics.JCS.Desired, so handle that group
        # first (the stable sort keeps the other groups in file order). They
        # are reset for every file so no file reuses another file's indices.
        index_list = None
        ordered_groups = sorted(groups,
                                key=lambda g: g != 'Kinetics.JCS.Desired')

        for group in ordered_groups:

            channel_list = []
            channel_data_list = []
            channel_unit_list = []
            channel_time_list = []

            # Collect name, unit, data and time track of each channel.
            for channel in tdms_file.group_channels(group):
                try:
                    channel_label = channel.properties["NI_ChannelName"]
                    channel_units = channel.properties['NI_UnitDescription']
                    channel_data = channel.data
                    channel_time = channel.time_track()
                except Exception:
                    # Stop at the first channel lacking this metadata, as
                    # before, without also swallowing KeyboardInterrupt /
                    # SystemExit the way a bare except would.
                    break
                channel_data_list.append(channel_data)
                channel_list.append(channel_label)
                channel_unit_list.append(channel_units)
                channel_time_list.append(channel_time)

            # tdms_plot lays out exactly six channels (two axes of three).
            if len(channel_data_list) == 6:
                tdms_plot(title, group, channel_unit_list,
                          channel_data_list, channel_time_list,
                          channel_list, root, processed)

            if group == 'Kinetics.JCS.Desired':
                index_list = find_indices(channel_list, channel_data_list,
                                          channel_time_list, title, group,
                                          channel_unit_list, root)

            if group not in extracted_groups:
                continue

            if index_list is None:
                raise ValueError('No Kinetics.JCS.Desired group in ' + file +
                                 '; cannot extract data for group ' + group)

            (data_processed, time_processed, range_list, data_info) = extract_data(
                index_list, title, group, channel_unit_list,
                channel_data_list, channel_time_list,
                channel_list, root, processed)

            if repro_testing != '1':
                continue

            # Save the information and data needed for the RMSD calculations
            # of the kinetic and kinematic groups.
            if group == 'State.JCS Load':
                group_loads.append(data_processed)
                time_repro.append(time_processed)
                range_loads.append(range_list)
                data_info_loads.append(data_info)
                title_repro.append(title)
                group_info_loads = (channel_list, channel_unit_list, base)
                channel_loads = channel_list
                unit_loads = channel_unit_list
                group_loads_label = group

            elif group == 'State.Knee JCS':
                group_kinematics.append(data_processed)
                group_info_kinematics = (channel_list, channel_unit_list,
                                         base)
                channel_kine = channel_list
                unit_kine = channel_unit_list
                group_kine_label = group
                range_kine.append(range_list)
                data_info_kine.append(data_info)

        print ('Plots and data can be found in the following directory: '
                + root + '\n')

    if repro_testing == '1':
        if not group_loads or not group_kinematics:
            raise ValueError('Reproducibility testing needs State.JCS Load and '
                             'State.Knee JCS groups in the listed files')

        repro_test(group_loads, group_kinematics, group_loads_label, group_kine_label,
                   channel_loads, channel_kine, base, title_repro, unit_kine, unit_loads,
                   range_kine, range_loads, data_info_loads, data_info_kine)

        repro_ap_plot(title_repro, group_loads_label, group_kine_label,
                      group_loads, group_kinematics, group_info_loads, group_info_kinematics)

        # 'groups' is the group list of the last file processed.
        for group in groups:
            if group == 'State.JCS Load':
                repro_plot(title_repro, group, group_loads, time_repro,
                           group_info_loads)
            elif group == 'State.Knee JCS':
                repro_plot(title_repro, group, group_kinematics, time_repro,
                           group_info_kinematics)
    

"""
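Plots the six channels of one tdms group (channels 0-2 on the left axes,
channels 3-5 on the right) and saves the figure as .png and .svg in the folder
named after the tdms file (its path without the extension). If the data are
processed, points are drawn with markers and the file names include
'_Extracted'.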
"""                            
import matplotlib.pyplot as plt

def tdms_plot(title, groups, units, data, time, channels, root, processed):
    """Draw six channels of one group on two side-by-side axes and save them.

    Channels 0-2 go on the left axes and channels 3-5 on the right; the y
    label of each axes ends up as the unit of its last channel. When
    processed == True each series is drawn with markers only and the files
    are named '<groups>_Extracted'; otherwise lines are drawn and the files
    are named '<groups>'. A PNG (100 dpi) and an SVG are written into root,
    which is created if it does not exist.
    """
    markers = ['s', 'D', 'o', 's', 'D', 'o']
    # Equality with True (not plain truthiness) decides the processed style.
    is_processed = processed == True

    fig = plt.figure(figsize=(12, 8))
    fig.suptitle(title, fontsize=16)

    # Two graphs in one figure.
    left_axis = plt.subplot(121)
    right_axis = plt.subplot(122)
    layout = [(left_axis, (0, 1, 2)), (right_axis, (3, 4, 5))]

    for axis, channel_ids in layout:
        for n in channel_ids:
            axis.set_ylabel(units[n], fontsize=12)
            if is_processed:
                axis.plot(time[n], data[n], markers[n], label=channels[n])
            else:
                axis.plot(time[n], data[n], label=channels[n])

    for axis, _ in layout:
        axis.set_xlabel('Time (ms)', fontsize=12)
        axis.legend(loc='best')
        axis.set_title(groups, fontsize=12)
        axis.grid(True)

    if not os.path.exists(root):
        os.makedirs(root)

    # Save as .png and .svg.
    stem = root + '/' + groups + ('_Extracted' if is_processed else '')
    fig.savefig(stem + '.png', format="png", dpi=100)
    fig.savefig(stem + '.svg', format="svg")

    plt.close(fig)

    
"""
The function find_indices scans the desired kinetics data for runs in which
consecutive values stay the same and records the last index of each run.
It returns a list of six references to the sorted end indices (one per
channel), which extract_data uses to sample the other groups. The start index
of each run could be used instead, or a midpoint computed from each run's start
and end points.
"""
def _hold_windows(values, tol=1e-9):
    """Return (starts, ends) of runs in which samples stay at one value.

    A run starts at index i when values[i + 1] is within tol of values[i] and
    continues while later samples stay within tol of values[i] (the run's first
    sample, not the previous sample). Only runs of two or more samples are
    reported and ends are inclusive. The scan stops once fewer than three
    samples remain from the current position, so a run made of only the final
    two samples is not reported; this matches the original scanner.
    """
    starts = []
    ends = []
    n = len(values)
    i = 0
    while i < n - 2:
        if abs(values[i] - values[i + 1]) < tol:
            k = i + 2
            while k < n and abs(values[i] - values[k]) < tol:
                k += 1
            starts.append(i)
            ends.append(k - 1)
            # Resume scanning at the first sample that left the run.
            i = k
        else:
            i += 1
    return starts, ends


def find_indices(channel_list, channel_data_list, channel_time_list,
                  title, group, channel_unit_list, root):
    """Find the last index of every constant-value run in the desired kinetics.

    Each channel in channel_data_list is scanned with _hold_windows (values
    equal within 1e-9). The end indices of all channels are merged, duplicates
    are removed, and the result is sorted. Channels are paired with
    channel_list, so extra data channels without a name are ignored.

    channel_time_list, title, group, channel_unit_list and root are not used;
    they are accepted so existing callers keep working.

    Returns a list of six references to the same sorted index list (one per
    channel of a group), as consumed by extract_data. Returns six empty lists
    when no channel has a run, including for empty or very short channels.
    """
    end_points = set()
    for channel_data, _label in zip(channel_data_list, channel_list):
        # Start indices are available from _hold_windows if a start or midpoint
        # rule is ever preferred; only the end indices are used here.
        _starts, ends = _hold_windows(channel_data)
        end_points.update(ends)

    indices = sorted(end_points)

    # The same key "time index" list is applied to each of the six channels.
    return [indices] * 6
           
    
"""
Extracts a group's channel data at the end indices found by find_indices,
plots the extracted points, and writes them to a text file.
"""
def extract_data(index_list, title, group, channel_unit_list,
                 channel_data_list, channel_time_list, channel_list, root, processed):
    """Sample a group's channels at the given indices, plot them and save a text file.

    index_list holds one sequence of sample indices per channel, as returned by
    find_indices. Each channel's data and time track are sampled at those
    indices. The samples are plotted with tdms_plot (marker style, title
    '<title>_extracted', file names ending in '_Extracted') and written to
    root/<group>.txt.

    The processed argument is ignored: extracted data are always plotted as
    processed.

    Returns (data_processed, time_processed, range_list, data_info_list). For
    each channel:
      - data_processed and time_processed hold the sampled arrays;
      - data_info_list holds [max, min, mean], each rounded to 8 decimals;
      - range_list holds the difference of the rounded max and min.

    Raises ValueError if a channel has no indices to sample.
    """
    title_p = title + '_extracted'
    processed = True
    data_processed = []
    time_processed = []
    range_list = []
    data_info_list = []

    for index_end, channel_data, channel_time in zip(index_list,
            channel_data_list, channel_time_list):

        index_end = list(index_end)
        if not index_end:
            raise ValueError('No indices to extract for group ' + group +
                             ' in ' + title)

        # One fancy-indexing step instead of rebuilding an array per index.
        data_array = np.asarray(channel_data)[index_end]
        time_array = np.asarray(channel_time)[index_end]

        data_max = round(max(data_array), 8)
        data_min = round(min(data_array), 8)
        data_mean = round(np.mean(data_array), 8)

        data_info_list.append([data_max, data_min, data_mean])
        range_list.append(data_max - data_min)
        data_processed.append(data_array)
        time_processed.append(time_array)

    tdms_plot(title_p, group, channel_unit_list, data_processed,
              time_processed, channel_list, root, processed)

    with open(root + '/' + group + '.txt', 'w') as fd:
        fd.write('Text file for extracted data for each experimental test.\n\n')
        fd.write('file = ' + title + '\n\n')
        fd.write('group = ' + group + '\n\n')
        fd.write('Extracted Time Points [ms]\n')
        fd.write(str(time_processed[0]) + '\n\n')
        for k in range(len(channel_list)):
            fd.write(channel_list[k] + '  ' + '[' + channel_unit_list[k] + ']\n')
            fd.write(str(data_processed[k]) + '\n\n')
        fd.write('\n')

    return (data_processed, time_processed, range_list, data_info_list)

"""
This function compares multiple AP laxity tests to assess whether the loads
applied to the knee differed between tests, and whether this resulted in
different kinematics. It compares every pair of laxity tests and reports the
RMSD between the two tests for each channel. Normalized RMSD values (as a
percentage of the mean data range of the two tests) are also computed for
kinematics and kinetics but are not written to the output table. The groups
used are State.Knee JCS and State.JCS Load, and the analysis uses the
processed (extracted) data.
"""
import itertools

def repro_test(group_loads, group_kinematics, group_loads_label, 
               group_kine_label, channel_loads, channel_kine, base, title_repro,
               units_kine, units_loads, range_kine, range_loads, data_info_loads,
               data_info_kine):
    
    # unpacking variables
    for channel_list in group_loads:
        channel_list_length = len(channel_list)
        for channel in channel_list:
            channel_length = len(channel)

    #Calculate RMS values for loads data
    diff = []
    rms_loads = []
    norm_loads = []
    group_num = range(len(group_loads))
    for a, b in itertools.combinations(group_num, 2):
        for j in range(channel_list_length):
            for k in range(channel_length):
                diff_value = (group_loads[a][j][k] - group_loads[b][j][k])
                diff.append(diff_value)
            rms_l = np.sqrt(np.divide(float(np.sum(np.square(diff))), 
                                      channel_length))
            
            range_mean = np.mean([range_loads[a][j], range_loads[b][j]])
            norm_l = np.multiply(np.divide(float(rms_l), range_mean), float(100))
            
            norm_loads.append(norm_l)
            rms_loads.append(rms_l)
            diff = []
    #Calculate RMS values for kinematic data
    diff = []
    rms_kine = []
    norm_kine = []

    group_num = range(len(group_kinematics))
    for a, b in itertools.combinations(group_num, 2):
        for j in range(channel_list_length):
            for k in range(channel_length):
                diff_value = (group_kinematics[a][j][k] - group_kinematics[b][j][k])
                diff.append(diff_value)
            rms_k = np.sqrt(np.divide(float(np.sum(np.square(diff))), channel_length))
            rms_kine.append(rms_k) 
            range_mean = np.mean([range_kine[a][j], range_kine[b][j]])
            norm_k = np.multiply(np.divide(float(rms_l), range_mean), float(100))
            norm_kine.append(norm_k)
            diff = []
       
    root = base + '/Repro_' + title_repro[2]
    if not os.path.exists(root): 
        os.makedirs(root)
        
    rms_loads_array = np.array(rms_loads)
    rms_kine_array = np.array(rms_kine)
    
    for a, b in itertools.combinations(group_num, 2):
            for k in range(6):
                for j in range(2):
                    if j == 0: 
                        
                        if data_info_loads[a][k][j] >= data_info_loads[b][k][j]:
                            data_info_loads[b][k][j] = data_info_loads[a][k][j]
                        else:
                            data_info_loads[a][k][j] = data_info_loads[b][k][j]
                            
                        if data_info_kine[a][k][j] >= data_info_kine[b][k][j]:
                            data_info_kine[b][k][j] = data_info_kine[a][k][j]
                        else:
                            data_info_kine[a][k][j] = data_info_kine[b][k][j]
                            
                    if j == 1:
                        
                        if data_info_loads[a][k][j] <= data_info_loads[b][k][j]:
                            data_info_loads[b][k][j] = data_info_loads[a][k][j]
                        else:
                            data_info_loads[a][k][j] = data_info_loads[b][k][j]
                            
                        if data_info_kine[a][k][j] <= data_info_kine[b][k][j]:
                            data_info_kine[b][k][j] = data_info_kine[a][k][j]
                        else:
                            data_info_kine[a][k][j] = data_info_kine[b][k][j]
                            
                data_info_loads[0][k][2] = round(np.mean([data_info_loads[0][k][2], 
                        data_info_loads[1][k][2], data_info_loads[2][k][2]]), 8)
                
                data_info_kine[0][k][2] = round(np.mean([data_info_kine[0][k][2], 
                        data_info_kine[1][k][2], data_info_kine[2][k][2]]), 8)
    
    
#    Create text file for RMSD Values
    ft = open(root + '/RMSD_Repro_Table.txt', 'w')
    ft.write('RMSD Values for Joint Mechanics Reproducibility Tests\n\n')
    for i in range(1, 4):
        ft.write(str(i) + ' = ' + title_repro[i - 1] + '\n')
    ft.write('\n')
    ft.write('RMSD Values for ' + group_loads_label + '\tRow1: 1, 2; Row2: 2, 3; Row3: 1, 3\n')
    c = 0
    for k in range(6):
        ft.write(channel_loads[k] + ' [' + units_loads[k] + ']\t\t[Maximum Value, Minimum Value, Mean Value]\n')
        while k <= c + 12:
            if k in range(6): 
                ft.write(str(round(rms_loads_array[k], 8)) + '\t\t\t\t' + str(data_info_loads[0][c]) + '\n')
            else:
                ft.write(str(round(rms_loads_array[k], 8)) + '\n')
            k = k + 6
        c = c + 1
        ft.write('\n')
    
    ft.write('\n\n\n')
    
    ft.write('RMSD Values for ' + group_kine_label + '\tRow1: 1, 2; Row2: 2, 3; Row3: 1, 3\n')
    c = 0
    for k in range(6):
        ft.write(channel_kine[k] + ' [' + units_kine[k] + ']\t\t\t[Maximum Value, Minimum Value, Mean Value]\n')
        while k <= c + 12:
            if k in range(6): 
                ft.write(str(round(rms_kine_array[k], 8)) + '\t\t\t\t' + str(data_info_kine[0][c]) + '\n')
            else:
                ft.write(str(round(rms_kine_array[k], 8)) + '\n')
            k = k + 6
        c = c + 1
        ft.write('\n')
    ft.close()
    
        
#    fd = open(root + '/Data_Info_Repro_Table.txt', 'w')
#    fd.write('Maximum, Minimum, and Mean values for Joint Mechanics Reproducibility Tests\n\n')
#    fd.write('[Maximum Value, Minimum Value, Mean Value]\n\n')
#    for i in range(3):
#        fd.write(title_repro[i] + '\n')
#        fd.write(group_loads_label + '\n')
#        for k in range(6):
#            fd.write(channel_loads[k] + ' [' + units_loads[k] + ']\t\t' + str(data_info_loads[i][k]) + '\n')
#        fd.write('\n')
#    fd.write('\n')
#    for i in range(3):
#        fd.write(title_repro[i] + '\n')
#        fd.write(group_kine_label + '\n')
#        for k in range(6):
#            fd.write(channel_kine[k] + ' [' + units_kine[k] + ']\t\t\t' + str(data_info_kine[i][k]) + '\n')
#        fd.write('\n')
#    fd.close()
        
        
"""
Plots data for the three AP laxity tests so they can be compared directly.
Kinetic and kinematic data are plotted: translations and rotations on separate
figures for kinematic data, and forces and torques on separate figures for
kinetic data. The extracted data are also written to a text file. Outputs are
stored in the directory of the tdms files, in a folder named 'Repro_' followed
by the name of the third test file.
"""
def repro_plot(title_repro, group, group_repro, time_repro, group_info_repro):
    """Overlay the extracted data of the reproducibility tests for one group.

    group_repro[t][c] and time_repro[t][c] are the extracted values and times
    of channel c in test t (as returned by extract_data). group_info_repro is
    (channel names, channel units, base directory).

    Output goes to base/Repro_<third test title>:
      - <group>_Data_Repro.txt with the extracted values of every test;
      - for State.JCS Load and State.Knee JCS only, two three-panel figures:
        channels 0-2 (Forces / Translations) and channels 3-5 (Torques /
        Rotations), each saved as .png and .svg.
    """
    channel_list = group_info_repro[0]
    unit_list = group_info_repro[1]
    base = group_info_repro[2]
    root = base + '/Repro_' + title_repro[2]
    if not os.path.exists(root):
        os.makedirs(root)

    with open(root + '/' + group + '_Data_Repro.txt', 'w') as f:
        f.write('Text file for extracted data from three AP laxity tests.\n\n')
        f.write('Extracted Time Points [ms]\n')
        f.write(str(time_repro[0][0]) + '\n\n')
        for test_title, test_data in zip(title_repro, group_repro):
            f.write('file = ' + test_title + '\n\n')
            f.write('group = ' + group + '\n')
            for k in range(len(channel_list)):
                f.write(channel_list[k] + '  ' + '[' + unit_list[k] + ']\n')
                f.write(str(test_data[k]) + '\n\n')
            f.write('\n')

    point = ['s', 'D', 'o']
    # Figure names per group: first figure = channels 0-2, second = 3-5.
    figure_names = {'State.JCS Load': ('Forces', 'Torques'),
                    'State.Knee JCS': ('Translations', 'Rotations')}

    for x in range(2):
        fig = plt.figure(figsize=(18, 8))

        # Three graphs in one figure.
        ax1 = plt.subplot(131)
        ax2 = plt.subplot(132)
        ax3 = plt.subplot(133)
        subplots = [ax1, ax2, ax3]
        channels = range(3 * x, 3 * x + 3)

        # Plot each test's full extracted series for the three channels; the
        # middle panel carries the legend labels.
        for i in range(len(group_repro)):
            marker = point[i % len(point)]
            for s, (ax, c) in enumerate(zip(subplots, channels)):
                if s == 1:
                    ax.plot(time_repro[i][c], group_repro[i][c], marker,
                            label=title_repro[i])
                else:
                    ax.plot(time_repro[i][c], group_repro[i][c], marker)

        for s, (ax, c) in enumerate(zip(subplots, channels)):
            ax.set_xlabel('Time (ms)', fontsize=12)
            # Only the middle panel has labelled lines to put in a legend.
            if s == 1:
                ax.legend(loc='best')
            ax.set_title(channel_list[c], fontsize=12)
            ax.set_ylabel(unit_list[c])
            ax.grid(True)

        # Save as .png and .svg (only for the two known groups).
        names = figure_names.get(group)
        if names is not None:
            stem = root + '/' + group + '_' + names[x]
            fig.suptitle(group + ' ' + names[x] + ' Reproducibility Tests',
                         fontsize=16)
            fig.savefig(stem + '.png', format="png", dpi=100)
            fig.savefig(stem + '.svg', format="svg")

        plt.close(fig)

    if group == 'State.JCS Load':
        print('\nPlots and table for reproducibility testing can be found in '
              + root)
        print('')
        
def repro_ap_plot(title_repro, group_loads_label, group_kinematics_label,
              group_loads, group_kinematics, group_info_loads, group_info_kine):
    """Plot channel 1 of the kinematics against channel 1 of the loads per test.

    For each test t, group_loads[t][1] (x) is plotted against
    group_kinematics[t][1] (y) with a distinct marker. The axes are labelled
    with the matching channel names and units. The figure is saved as
    Repro_Plot.png and Repro_Plot.svg in base/Repro_<third test title>,
    where base comes from group_info_kine.
    NOTE(review): channel index 1 is presumably the AP axis of the AP laxity
    tests -- confirm against the channel order.
    """
    point = ['s', 'D', 'o']

    channel_list_loads = group_info_loads[0]
    unit_list_loads = group_info_loads[1]

    channel_list = group_info_kine[0]
    unit_list = group_info_kine[1]
    base = group_info_kine[2]
    root = base + '/Repro_' + title_repro[2]
    if not os.path.exists(root):
        os.makedirs(root)

    # Use an explicit figure rather than whatever pyplot figure is current.
    fig = plt.figure()
    plt.title(group_loads_label + ' vs ' + group_kinematics_label)
    for i, (loads, kine) in enumerate(zip(group_loads, group_kinematics)):
        plt.plot(loads[1], kine[1], point[i % len(point)], label=title_repro[i])
    # x data are the loads and y data the kinematics; label the axes to match.
    plt.xlabel(channel_list_loads[1] + ' [' + unit_list_loads[1] + ']')
    plt.ylabel(channel_list[1] + ' [' + unit_list[1] + ']')
    plt.legend(loc='best')
    plt.grid(True)

    fig.savefig(root + '/Repro_Plot' + '.png', format="png")
    fig.savefig(root + '/Repro_Plot' + '.svg', format="svg")

    plt.close(fig)
    
    
if __name__ == "__main__":
    # Script entry point: hand the full argument vector to the workflow so it
    # can validate the argument count and report usage itself.
    command_line = sys.argv
    tdms_contents(command_line)