"""

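Description:
    Tk application that lists the MULTIS subjects found under a trials
    directory (one checkbox per subject that has a *TA_inclusion.xml file).
    For each selected subject and each indentation location (LA_A ... UL_P),
    it reads the ultrasound thickness/force XML files listed in the subject's
    inclusion file, fits linear (through the origin) and quadratic models of
    pressure versus tissue compression, and compiles one row per trial,
    together with subject demographics, into a master CSV file.

Getting started:
    Run this file directly (Python 2: Tkinter, tkFileDialog, ConfigParser;
    it also needs pandas, numpy, matplotlib and statsmodels). The home
    directory is searched for a "MULTIS_test" folder; if none is found you
    are asked to choose the trials directory. Tick the subjects to process
    (or press "Select All") and press "Okay". The output CSV location is set
    in the FileSelectionApp class.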

    Original Author:
        Erica Morrill
        Department of Biomedical Engineering
        Lerner Research Institute
        Cleveland Clinic
        Cleveland, OH
        morrile2@ccf.org

"""

import xml.etree.ElementTree as ET
import os
import Tkinter as tk
import pandas
import tkFileDialog
import matplotlib.pyplot as plt
from matplotlib.ticker import MaxNLocator
import numpy as np
import time
import math
import ConfigParser
import statsmodels.formula.api as sm



class FileSelectionApp(tk.Tk):
    """Tk window that lists MULTIS subjects and compiles indentation results.

    One checkbox is shown per subject directory that contains a
    ``*TA_inclusion.xml`` file (ten checkboxes per grid column).  Pressing
    "Okay" runs ``checkBoxes``, which processes every checked subject at
    every location in ``LOCATIONS`` and writes one row per indentation trial
    to ``output_csv``.
    """

    # Locations processed by checkBoxes(); narrow this (e.g. ['LA_A']) when testing.
    LOCATIONS = ['LA_A', 'LA_P', 'UA_A', 'UA_P', 'LL_A', 'LL_P', 'UL_A', 'UL_P']

    # Where checkBoxes() writes the compiled table; override to write elsewhere.
    output_csv = ('/home/morrile2/Documents/Multis/studies/Indentation/dat/'
                  '002_MasterList_indentation.csv')

    # Probe contact area keyed by the '6-DOF Load' Description in a trial's
    # *_State.cfg (with its first and last characters, presumably quotes,
    # removed).  45*15 and 45*8 are products of two lengths, i.e. areas
    # (presumably mm^2).
    PROBE_AREAS = {
        '9L4 Ultrasound': 45 * 15,
        '14L5 Ultrasound': 45 * 8,
    }

    # Anatomical_Measurements child tag -> attribute set by saveDemographicData().
    _SEGMENT_ATTRS = {
        'Upper_Arm': 'df_UA',
        'Lower_Arm': 'df_LA',
        'Upper_Leg': 'df_UL',
        'Lower_Leg': 'df_LL',
    }

    # Cluster label -> output column used by saveDemographicData().
    _CIRC_COLUMNS = {
        'Distal Circumference': 'dc (cm)',
        'Central Circumference': 'cc (cm)',
        'Proximal Circumference': 'pc (cm)',
    }

    def __init__(self):
        tk.Tk.__init__(self)

        self.directory = self._find_multis_dir()

        self.getSubjects()
        self.title('Select Files')

        self.masterdf = pandas.DataFrame(data=None)

        self.columnconfigure(0, weight=1)

        # One IntVar/Checkbutton per subject, filled ten rows per column.
        self.var = []
        self.i = 0
        for item in self.subFiles:
            self.var.append(tk.IntVar())
            c = tk.Checkbutton(self, text=item[0:-2], variable=self.var[self.i])
            c.grid(column=self.i // 10, row=self.i % 10, sticky='w')
            self.i += 1

        # max(..., 1) avoids ZeroDivisionError when no subjects were found.
        self.minsize(200, 100 // max(self.i, 1))
        button_col = self.i // 10 + 1
        tk.Button(self, text="Okay", command=self.checkBoxes).grid(row=1, column=button_col, sticky='ens')
        tk.Button(self, text="Select All", command=self.SelectAll).grid(row=0, column=button_col, sticky='ens')

    def _find_multis_dir(self):
        """Return the MULTIS trials directory.

        Walks the user's home directory for a sub-folder whose name contains
        "MULTIS_test" (skipping parents whose path contains "studies"); the
        last match found wins.  If there is no match, the user is asked to
        choose a directory.
        """
        multis_dir = None
        for dirname, subdirs, _ in os.walk(os.path.expanduser('~')):
            for sub in subdirs:
                if "MULTIS_test" in sub and "studies" not in dirname:
                    multis_dir = os.path.join(dirname, sub)

        if multis_dir is None:
            multis_dir = tkFileDialog.askdirectory(title="Open MULTIS trials directory")
        return multis_dir

    def yview(self, *args):
        """Scroll callback kept for interface compatibility; does nothing.

        This window has no scrollable child widget to delegate to, and
        delegating to ``self.yview`` would recurse forever, so the arguments
        are ignored.
        """
        return None

    def checkBoxes(self):
        """Process every checked subject at every location and write the CSV.

        Rows are accumulated in ``self.masterdf`` by saveData() and written to
        ``self.output_csv``.  The elapsed time is printed, then the Tk
        mainloop is stopped.
        """
        s_time = time.time()

        for loc in self.LOCATIONS:
            self.df = pandas.DataFrame(data=None)
            self.location = loc
            for count, bb in enumerate(self.var):
                if bb.get() == 1:
                    subj_dir = self.dir[count]
                    xml = os.path.join(subj_dir, 'Configuration', os.path.split(subj_dir)[1] + '.xml')
                    self.saveData(xml, count)

            # NOTE(review): nothing in this class writes to self.df, so these
            # summaries are computed on an empty frame; kept for compatibility.
            self.avg = self.df.mean()
            self.stdDev = self.df.std()

        self.masterdf.to_csv(self.output_csv)

        print("Elapsed Time: %f seconds" % float(time.time() - s_time))

        self.quit()

    def saveDemographicData(self, xml_name):
        """Read limb circumferences from a subject configuration XML.

        For each ``Subject_Data/Anatomical_Measurements`` child whose ``type``
        attribute is ``'Cluster'`` and whose tag is Upper_Arm, Lower_Arm,
        Upper_Leg or Lower_Leg, a one-row DataFrame with columns
        ``['dc (cm)', 'cc (cm)', 'pc (cm)']`` is stored on ``self.df_UA``,
        ``self.df_LA``, ``self.df_UL`` or ``self.df_LL`` respectively.
        A circumference missing for a segment is NaN; it is never carried
        over from a previously parsed segment.
        """
        column_names = ['dc (cm)', 'cc (cm)', 'pc (cm)']
        anatomical = ET.parse(xml_name).getroot().find("Subject_Data").find("Anatomical_Measurements")

        for child in anatomical:
            if child.attrib.get('type') != 'Cluster':
                continue

            values = dict.fromkeys(column_names, np.nan)
            for cl in child.findall("Cluster"):
                # Each Cluster holds a label element followed by a value element.
                fields = list(cl)
                column = self._CIRC_COLUMNS.get(fields[0].text)
                if column is not None:
                    values[column] = float(fields[1].text)

            attr = self._SEGMENT_ATTRS.get(child.tag)
            if attr is not None:
                setattr(self, attr, pandas.DataFrame([[values[c] for c in column_names]],
                                                     columns=column_names))

    def lin_fit(self, x, y):
        """Least-squares fit of ``y = m * x`` (a line through the origin).

        Returns ``(m, R_sqr, y_fit, RMS)``.  R_sqr uses the mean-centred
        total sum of squares.  The fit is best-effort: on any error all four
        values are ``np.nan`` so one bad trial does not abort the batch.
        """
        try:
            A = np.vstack([x]).T
            m, _, _, _ = np.linalg.lstsq(A, y)

            y_fit = m * x
            y_bar = np.average(y)
            SS_tot = np.sum((y - y_bar) ** 2)
            SS_res = np.sum((y - y_fit) ** 2)
            R_sqr = 1 - SS_res / SS_tot

            RMS = np.sqrt(np.average((y - y_fit) ** 2))
            return m[0], R_sqr, y_fit, RMS
        except Exception:
            return np.nan, np.nan, np.nan, np.nan

    def poly_fit(self, x, y):
        """Least-squares fit of ``y = a*x + b*x**2`` (no intercept) via statsmodels OLS.

        Returns ``(R_sqr, y_fit, RMS)``.  The fit is best-effort: on any
        error all three values are ``np.nan``.
        """
        try:
            data = pandas.DataFrame(data={'x': x, 'y': y})
            olsres2 = sm.ols(formula='y~x+I(x**2)-1', data=data).fit()
            y_fit = olsres2.predict(data)
            y_bar = np.average(y)
            SS_tot = np.sum((y - y_bar) ** 2)
            SS_res = np.sum((y - y_fit) ** 2)
            R_sqr = 1 - SS_res / SS_tot

            RMS = np.sqrt(np.average((y - y_fit) ** 2))
            return R_sqr, y_fit, RMS
        except Exception:
            return np.nan, np.nan, np.nan

    def _read_demographics(self, subj_xml):
        """Return ``(age, gender, ethnicity, race, BMI, activity)`` from a subject XML."""
        subject_data = ET.parse(subj_xml).getroot().find('Subject_Data')

        demo = subject_data.find("Demographics")
        age = demo.find("Age").text
        gender = demo.find("Gender").get("sel")
        ethnicity = demo.find("Ethnicity").get("sel")
        race = demo.find("Race").get("sel")

        H_M = subject_data.find("Height_and_Mass")
        height = float(H_M.find("Height").find("Magnitude").text)
        mass = float(H_M.find("Mass").find("Magnitude").text)
        # Height is divided by 100 before squaring (presumably cm -> m, mass in kg).
        BMI = mass / ((height / 100) ** 2)

        activity = subject_data.find("Activity_Level").find("Lifestyle").get("sel")
        return age, gender, ethnicity, race, BMI, activity

    def _read_frames(self, xml_name):
        """Parse a thickness XML into per-frame lists ``(Fx, Fy, Fz, total, times)``.

        ``total`` is Skin + Fat + Muscle thickness of each ``Subject/Source/Frame``.
        """
        src = ET.parse(xml_name).getroot().find('Subject').find('Source')
        Fx, Fy, Fz, total, times = [], [], [], [], []
        for frame in src.findall("Frame"):
            thick = frame.find("Thickness")
            force = frame.find("Forces")
            Fx.append(float(force.find("Fx").text))
            Fy.append(float(force.find("Fy").text))
            Fz.append(float(force.find("Fz").text))
            total.append(float(thick.find("Skin").text)
                         + float(thick.find("Fat").text)
                         + float(thick.find("Muscle").text))
            times.append(float(frame.find("Time").attrib["value"]))
        return Fx, Fy, Fz, total, times

    def _probe_area(self, count, xml_name):
        """Return the probe contact area for a trial, from its ``*_State.cfg``.

        Raises IOError if the config file cannot be read and ValueError if the
        probe description is not a key of ``PROBE_AREAS``.
        """
        cfg_path = os.path.join(self.dir[count], 'Configuration',
                                os.path.split(xml_name)[1][0:25] + '_State.cfg')
        config = ConfigParser.RawConfigParser()
        if not config.read(cfg_path):
            raise IOError("Cannot load configuration file... Check path.")

        probe = config.get('6-DOF Load', 'Description')
        try:
            # Drop the first/last characters (presumably surrounding quotes).
            return self.PROBE_AREAS[probe[1:-1]]
        except KeyError:
            raise ValueError("Unrecognized probe %r in %s" % (probe, cfg_path))

    def saveData(self, subj_xml, count):
        """Append one row per indentation trial of one subject to ``self.masterdf``.

        Parameters
        ----------
        subj_xml : str
            Path to the subject's configuration XML (demographics, height, mass).
        count : int
            Index of the subject in ``self.dir``.

        Only trials at ``self.location`` that are listed in the subject's
        inclusion file are processed; files whose name lacks ``'C_'`` are
        reported and skipped.

        Raises IOError if a trial's ``*_State.cfg`` cannot be read and
        ValueError if its probe is not listed in ``PROBE_AREAS``.
        """
        files = self.getThickFiles(self.location, self.dir[count])
        age, gender, ethnicity, race, BMI, activity = self._read_demographics(subj_xml)

        for xml_name in files:
            Fx, Fy, Fz, total, times = self._read_frames(xml_name)

            if 'C_' not in os.path.split(xml_name)[1]:
                print("Error", os.path.split(xml_name)[1])
                continue

            probe_area = self._probe_area(count, xml_name)

            Fx, Fy, Fz = np.array(Fx), np.array(Fy), np.array(Fz)
            # Magnitude of the force change relative to the first frame.
            force_mag = np.sqrt((Fx - Fx[0]) ** 2 + (Fy - Fy[0]) ** 2 + (Fz - Fz[0]) ** 2)

            df_data = pandas.DataFrame([[(total[0] - total[-1]) / total[0]]], columns=['TotalStrain'])

            # NOTE(review): force_mag[0] is 0 by construction, so min(force_mag)
            # is never > 2 and this warning cannot fire; confirm intended check.
            if min(force_mag) > 2:
                print(os.path.split(xml_name), min(force_mag))

            pressure = force_mag / probe_area
            total = np.array(total)
            times = np.array(times)
            compression = -(total - total[0])

            R_sqr_poly, _, RMS_poly = self.poly_fit(compression, pressure)
            stiff_t, R_sqr, _, RMS_lin = self.lin_fit(compression, pressure)
            EMod_t = stiff_t * total[0]

            # Slope of thickness loss versus time since the first frame.
            ind_rate = self.lin_fit(times - times[0], total[0] - total)[0]

            stiff_columns = ['Thickness', 'Total_Stiff', 'R_squared', 'R_squared_poly', 'Diff rsquared',
                             'RMS_lin', 'RMS_poly', 'Diff RMS', 'Total_EMod', 'Rate']
            df_cent_stiff = pandas.DataFrame(
                [[total[0], stiff_t, R_sqr, R_sqr_poly, R_sqr_poly - R_sqr, RMS_lin, RMS_poly,
                  RMS_lin - RMS_poly, EMod_t, ind_rate * 1000]],
                columns=stiff_columns)

            demo_columns = ['SubID', 'Location', 'Age', 'Gender', 'BMI', 'ActivityLevel', 'Race',
                            'Ethnicity']
            df_demo = pandas.DataFrame(
                [[xml_name[-40:-37], self.location[0:4], age, gender, BMI, activity, race, ethnicity]],
                columns=demo_columns)

            row = pandas.concat([df_demo, df_data, df_cent_stiff], axis=1)
            # DataFrame.append was removed in pandas 2.0; concat is its equivalent.
            self.masterdf = pandas.concat([self.masterdf, row], ignore_index=True)

    def SelectAll(self):
        """Tick every subject checkbox."""
        for bb in self.var:
            bb.set(1)

    def getSubjects(self):
        """Populate ``self.subFiles`` and ``self.dir`` from ``self.directory``.

        Every file whose name contains "TA_inclusion.xml" marks a subject
        directory two levels above the folder holding it.  ``self.dir`` holds
        those paths and ``self.subFiles`` their basenames, both as tuples
        sorted together by (name, path).  Both are empty tuples when nothing
        is found.
        """
        found = []
        for dirname, _, filenames in os.walk(self.directory):
            for name in filenames:
                if "TA_inclusion.xml" in name:
                    subject_dir = os.path.dirname(os.path.dirname(dirname))
                    found.append((os.path.basename(subject_dir), subject_dir))
        found.sort()
        self.subFiles = tuple(name for name, _ in found)
        self.dir = tuple(path for _, path in found)

    def getThickFiles(self, location, directory):
        """Return full paths of the thickness XMLs for ``location`` listed in
        ``<directory>/TissueThickness/UltrasoundManual/<subject>_TA_inclusion.xml``."""
        thick_dir = os.path.join(directory, 'TissueThickness', 'UltrasoundManual')
        inclusion = os.path.join(thick_dir, os.path.split(directory)[1] + "_TA_inclusion.xml")
        return [os.path.join(thick_dir, th_xml)
                for th_xml in self.parseXMLinclusion(location, inclusion)]

    def parseXMLinclusion(self, location, inclusion):
        """Return the ``Indentation`` file names in an inclusion XML that match ``location``.

        Grandchildren of the root without an ``Indentation`` attribute, or
        with the value "None", are skipped.
        """
        root = ET.parse(inclusion).getroot()
        thickness_xmls = []

        for child in root:
            for subChild in child:
                XML = subChild.attrib.get("Indentation")
                if XML is None or XML == "None":
                    continue
                # Fixed-offset checks on the file name: the character 28 from
                # the end must be "I" and [-34:-30] must equal the location code.
                if XML[-28:-27] == "I" and XML[-34:-30] == location:
                    thickness_xmls.append(XML)

        return thickness_xmls

if __name__ == "__main__":
    # Launch the subject-selection window and hand control to Tk.
    FileSelectionApp().mainloop()
