"""

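Description:
    Tk application for summarising MULTIS tissue-thickness data. It lists
    every trial directory under ../dat/MULTIS_test that contains a
    *TA_inclusion.xml file. For the trials the user ticks, it collects skin,
    fat, muscle and total thickness at the proximal, central and distal sites
    of each limb location. It then saves bar charts of the central
    measurements to ../doc/Figures/, writes the raw table to
    ../sol/Thickness_data_raw.csv and prints the locations with missing data.

Getting started:
    Run this file with Python 2 (it imports Tkinter and tkFileDialog) from a
    directory where ../dat/MULTIS_test, ../doc/Figures and ../sol exist.
    Tick the trials to process (or press "Select All") and press "Okay".
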
    Original Author:
        Erica Morrill
        Department of Biomedical Engineering
        Lerner Research Institute
        Cleveland Clinic
        Cleveland, OH
        morrile2@ccf.org

"""

import xml.etree.ElementTree as ET
import os
import Tkinter as tk
import pandas
import tkFileDialog
import matplotlib.pyplot as plt
from matplotlib.ticker import MaxNLocator
import numpy as np
import time
import math

class FileSelectionApp(tk.Tk):
    """Application to display all of the trials in a list and summarise them.

    Every trial directory found under ``self.directory`` (see ``getSubjects``)
    is shown as a check box. Pressing "Okay" runs ``checkBoxes``, which
    collects per-location thickness data for the ticked trials into
    ``self.masterdf``, saves summary figures and a raw CSV, prints a
    missing-data report and stops the Tk main loop.
    """

    def __init__(self):
        tk.Tk.__init__(self)

        # Trial data is read from a fixed path relative to the working
        # directory. (Earlier versions searched the home directory or asked
        # the user via tkFileDialog.askdirectory.)
        self.directory = '../dat/MULTIS_test'

        self.getSubjects()
        self.title('Select Files')

        # Accumulates one row per processed (trial, location) pair.
        self.masterdf = pandas.DataFrame(data=None)

        self.columnconfigure(0, weight=1)

        # One IntVar per trial; checkBoxes treats a value of 1 as "selected".
        self.var = []
        for row, trial_name in enumerate(self.subFiles):
            flag = tk.IntVar()
            self.var.append(flag)
            tk.Checkbutton(self, text=trial_name, variable=flag).grid(column=0, row=row, sticky='w')
        # Number of check boxes created (attribute kept for existing users).
        self.i = len(self.var)

        tk.Button(self, text="Okay", command=self.checkBoxes).grid(row=1, column=1, sticky='ens')
        tk.Button(self, text="Select All", command=self.SelectAll).grid(row=0, column=1, sticky='ens')

    def yview(self, *args):
        """Scroll callback placeholder; accepts and ignores any arguments.

        The window contains no scrollable widget, so there is nothing to
        delegate to. The former body, ``apply(self.yview, args)``, called this
        same method again and recursed until RuntimeError.
        """
        return None

    def make_summary(self, segment, seg_name):
        """Save bar charts of mean central Skin, Fat and Muscle thickness.

        Args:
            segment: rows of ``self.masterdf`` belonging to one limb segment.
            seg_name: human-readable segment name used in the figure file
                names, e.g. "Upper Arm" -> "../doc/Figures/Upper Arm Skin.png".

        Only the central ("*_Central") columns are summarised. Each bar is the
        mean for one location with a +1 SD error bar. Nothing is saved when
        *segment* has no rows.
        """
        if segment.empty:
            return

        # .copy() so adding 'Loc2' modifies an independent frame rather than a
        # view of masterdf (avoids pandas' SettingWithCopyWarning).
        central = segment[['Segment', 'Location', 'Skin_Central', 'Fat_Central', 'Muscle_Central']].copy()
        central.columns = ['Segment', 'Location', 'Skin', 'Fat', 'Muscle']
        central['Loc2'] = 'M'

        grouped = central.groupby(['Segment', 'Location', 'Loc2'])
        df_mean = grouped.mean()
        df_std = grouped.std()

        # (tissue, y-axis upper limit, explicit y ticks or None for default)
        for tissue, ymax, yticks in (("Skin", 3, [0, 1, 2, 3]),
                                     ("Fat", 30, None),
                                     ("Muscle", 70, None)):
            self._plot_tissue_bar(df_mean, df_std, seg_name, tissue, ymax, yticks)

    @staticmethod
    def _plot_tissue_bar(df_mean, df_std, seg_name, tissue, ymax, yticks=None):
        """Save one bar chart of mean *tissue* thickness per location."""
        fig, ax = plt.subplots(figsize=(4, 3))
        means = df_mean[tissue]
        stds = np.asarray(df_std[tissue].values, dtype=float)
        # pandas takes asymmetric errors for a Series as a (1, 2, N) array:
        # zero below each bar, +1 SD above it.
        error = np.array([[np.zeros(stds.shape), stds]])
        means.plot(kind='bar', width=1.0, yerr=error, ax=ax)

        # Label bars with the location codes from the group index rather than
        # assuming exactly four locations (A, L, M, P) are present.
        labels = list(means.index.get_level_values('Location'))
        plt.xlabel("")
        plt.ylim([0, ymax])
        plt.xticks(np.arange(len(labels)), labels, rotation='horizontal')
        if yticks is not None:
            plt.yticks(yticks, [str(tick) for tick in yticks])
        fig.subplots_adjust(bottom=0.2)

        # Hide the top/right frame and their ticks. Booleans are used instead
        # of 'off' because matplotlib deprecated the 'on'/'off' strings here.
        ax.spines['top'].set_visible(False)
        ax.spines['right'].set_visible(False)
        ax.xaxis.set_tick_params(top=False, direction='out', width=1)
        ax.yaxis.set_tick_params(right=False, direction='out', width=1)

        plt.savefig('../doc/Figures/' + seg_name + ' ' + tissue + '.png', format='png', dpi=150)
        plt.close(fig)

    @staticmethod
    def _append_rows(frame, rows):
        """Return *frame* with *rows* appended and a fresh integer index.

        Replaces ``DataFrame.append``, which was removed in pandas 2.0.
        """
        if frame.empty:
            return rows.reset_index(drop=True)
        return pandas.concat([frame, rows], ignore_index=True)

    def checkBoxes(self):
        """Process all ticked trials, write the outputs and quit the main loop.

        For every location in the fixed list, each ticked trial's configuration
        file is parsed (``saveDemographicData``) and its thickness files are
        appended to ``self.masterdf`` (``saveData``). Afterwards summary figures
        are saved per segment and the raw table is written to
        ``../sol/Thickness_data_raw.csv``. Rows missing proximal, central or
        distal totals are then printed. If nothing is selected, or no data is
        found, a message is printed and the window stays open.
        """
        s_time = time.time()
        locations = ['LA_A', 'LA_L', 'LA_P', 'LA_M', 'UA_A', 'UA_L', 'UA_P', 'UA_M',
                     'LL_A', 'LL_L', 'LL_P', 'LL_M', 'UL_A', 'UL_L', 'UL_P', 'UL_M']

        selected = [index for index, flag in enumerate(self.var) if flag.get() == 1]
        if not selected:
            # Keep the window open so the user can tick a trial and retry.
            print("No trials selected")
            return

        for loc in locations:
            self.df = pandas.DataFrame(data=None)
            self.location = loc
            for index in selected:
                trial_dir = self.dir[index]
                xml = trial_dir + '/Configuration/' + os.path.split(trial_dir)[1] + '.xml'
                self.saveDemographicData(xml)
                self.saveData(index)

            # Statistics of the thickness columns for this location.
            self.avg = self.df.mean()
            self.stdDev = self.df.std()

        if self.masterdf.empty:
            print("No thickness data found for the selected trials")
            return

        for code, name in (("UA", "Upper Arm"), ("LA", "Lower Arm"),
                           ("UL", "Upper Leg"), ("LL", "Lower Leg")):
            self.make_summary(self.masterdf[self.masterdf['Segment'] == code], name)

        plt.show()
        self.masterdf.to_csv('../sol/' + 'Thickness_data_raw.csv')

        # Check for missing data
        id_cols = ["SubID", "Segment", "Location"]
        n_missing = 0
        for label, column in (("Missing proximal: ", "Total_Proximal"),
                              ("Missing central: ", "Total_Central"),
                              ("Missing Distal: ", "Total_Distal")):
            missing = self.masterdf[column].isnull()
            print(label)
            print(self.masterdf[id_cols][missing])
            n_missing += int(missing.sum())

        print("Number of locations missing data", n_missing)
        print("Elapsed Time: %f seconds" % float(time.time() - s_time))

        self.quit()

    def saveDemographicData(self, xml_name):
        """Store each limb segment's proximal/central/distal circumferences.

        Parses the configuration file *xml_name*. For every child of
        ``Subject_Data/Anatomical_Measurements`` whose ``type`` attribute is
        'Cluster' and whose tag is Upper_Arm, Lower_Arm, Upper_Leg or
        Lower_Leg, it sets ``self.df_UA``, ``self.df_LA``, ``self.df_UL`` or
        ``self.df_LL``. Each is a one-row DataFrame with columns
        'pc (cm)', 'cc (cm)' and 'dc (cm)'. A circumference absent from a
        segment is stored as NaN.
        """
        root = ET.parse(xml_name).getroot()
        anatomical = root.find("Subject_Data").find("Anatomical_Measurements")

        column_names = ['pc (cm)', 'cc (cm)', 'dc (cm)']
        cluster_columns = {'Proximal Circumference': 'pc (cm)',
                           'Central Circumference': 'cc (cm)',
                           'Distal Circumference': 'dc (cm)'}
        segment_attrs = {'Upper_Arm': 'df_UA', 'Lower_Arm': 'df_LA',
                         'Upper_Leg': 'df_UL', 'Lower_Leg': 'df_LL'}

        for child in anatomical:
            if child.attrib['type'] != 'Cluster':
                continue
            # Reset for every segment so a missing measurement becomes NaN
            # instead of silently reusing the previous segment's value.
            values = dict.fromkeys(column_names, np.nan)
            for cl in child.findall("Cluster"):
                # First sub-element holds the measurement name, second its value.
                column = cluster_columns.get(cl[0].text)
                if column is not None:
                    values[column] = float(cl[1].text)
            attr = segment_attrs.get(child.tag)
            if attr is not None:
                setattr(self, attr, pandas.DataFrame([[values[c] for c in column_names]],
                                                     columns=column_names))

    def saveData(self, count):
        """Append one row of thickness data for trial ``self.dir[count]``.

        Reads every thickness XML listed for ``self.location`` in the trial's
        inclusion file. It records skin, fat, muscle and total thickness for
        the proximal ('P_'), central ('C_') and distal ('D_') sites; a site
        without a file is filled with NaN. The row is appended to ``self.df``.
        It is also appended to ``self.masterdf`` together with SubID/Segment/
        Location identifiers and the segment circumferences set by
        ``saveDemographicData``. Nothing is appended when the trial has no
        files for the location. A warning is printed when the force magnitude
        exceeds 2 N, except for the listed trials.
        """
        files = self.getThickFiles(self.location, self.dir[count])
        if not files:
            # Without a file there is no subject ID to report; skip the trial
            # instead of failing on an undefined row.
            return

        # Trials for which the > 2 N force warning is suppressed.
        force_exempt = ('MULTIS001-1', 'MULTIS002-1', 'MULTIS004-1', 'MULTIS005-1')
        sites = (('P_', 'Proximal'), ('C_', 'Central'), ('D_', 'Distal'))
        measured = {}

        for xml_name in files:
            root = ET.parse(xml_name).getroot()
            subj = root.find('Subject')
            frame = subj.find('Source').find("Frame")
            thick = frame.find("Thickness")
            force = frame.find("Forces")

            f_mag = math.sqrt(sum(float(force.find(axis).text) ** 2 for axis in ("Fx", "Fy", "Fz")))
            if f_mag > 2 and subj.attrib["ID"] not in force_exempt:
                print('Min force over 2N', xml_name, 'Force = %.2f' % f_mag)

            skin = float(thick.find("Skin").text)
            fat = float(thick.find("Fat").text)
            muscle = float(thick.find("Muscle").text)
            # A NaN layer propagates into the sum, so no explicit check is needed.
            total = skin + fat + muscle

            basename = os.path.split(xml_name)[1]
            for tag, site in sites:
                if tag in basename:
                    # A later file for the same site replaces an earlier one.
                    measured[site] = [skin, fat, muscle, total]
                    break
            else:
                print("Error", basename)

        site_frames = []
        for _, site in sites:
            columns = [layer + '_' + site for layer in ('Skin', 'Fat', 'Muscle', 'Total')]
            site_frames.append(pandas.DataFrame([measured.get(site, [np.nan] * 4)], columns=columns))
        thickness = pandas.concat(site_frames, axis=1)

        segment_code = self.location[0:2]
        # SubID is taken from fixed character positions of the last file path.
        df_demo = pandas.DataFrame([[files[-1][-40:-37], segment_code, self.location[3]]],
                                   columns=['SubID', 'Segment', 'Location'])
        # Circumferences stored by saveDemographicData (df_UA/df_LA/df_UL/df_LL).
        df_circ = getattr(self, 'df_' + segment_code)

        self.df = self._append_rows(self.df, thickness)
        self.masterdf = self._append_rows(self.masterdf,
                                          pandas.concat([df_demo, df_circ, thickness], axis=1))

    def SelectAll(self):
        """Tick every trial check box."""
        for flag in self.var:
            flag.set(1)

    def getSubjects(self):
        """Collect trial directories under ``self.directory``, sorted by name.

        A directory counts as a trial once for every file in it whose name
        contains "TA_inclusion.xml". Sets ``self.subFiles`` (directory base
        names) and ``self.dir`` (full paths) as parallel tuples. Both are empty
        when nothing is found.
        """
        trials = []
        for dirname, _subdirs, file_names in os.walk(self.directory):
            for file_name in file_names:
                if "TA_inclusion.xml" in file_name:
                    trials.append((os.path.split(dirname)[1], dirname))
        trials.sort()
        if trials:
            self.subFiles, self.dir = zip(*trials)
        else:
            # zip(*[]) yields nothing to unpack, so handle the empty case.
            self.subFiles, self.dir = (), ()

    def getThickFiles(self, location, directory):
        """Return full paths of the thickness XMLs for *location* in one trial.

        The file names come from ``<trial>_TA_inclusion.xml`` inside
        *directory* and are resolved relative to *directory*.
        """
        inclusion = os.path.join(directory, os.path.split(directory)[1] + "_TA_inclusion.xml")
        return [os.path.join(directory, name) for name in self.parseXMLinclusion(location, inclusion)]

    def parseXMLinclusion(self, location, inclusion):
        """Return the 'Anatomical' file names in *inclusion* that match *location*.

        Every grandchild of the root element is inspected. Entries equal to
        "None" are skipped. Matches are chosen by fixed character positions in
        the file name: [-34:-30] must equal *location* and [-28] must be "A".
        The naming scheme behind those offsets is not visible here.
        """
        root = ET.parse(inclusion).getroot()
        selected = []
        for group in root:
            for entry in group:
                name = entry.attrib["Anatomical"]
                if name != "None" and name[-28:-27] == "A" and name[-34:-30] == location:
                    selected.append(name)
        return selected

if __name__ == "__main__":
    # Build the trial-selection window and hand control to Tk's event loop;
    # processing starts when the user presses "Okay".
    window = FileSelectionApp()
    window.mainloop()
