import dicom
import numpy as np
import matplotlib.pyplot as plt
import peakutils
import csv
import os
import time
import SimpleITK as sitk
from nptdms import TdmsFile
import xml.etree.ElementTree as ET
import fileName
import lxml.etree as etree
start = time.time()
from XMLparser import getAcceptance

"""
The Manual matching script will match based on 1 or 2 r waves in two files THAT ARE ALREADY KNOWN TO MATCH

This script was written to associate files for the MULTIS in vivo and in vitro experimentation.
It is a supplement to the FileAssociation.py script. This script is intended to manually match
2 files at a time that are known to be matching.

    Original Author:
        Tyler Schimmoeller
        Department of Biomedical Engineering
        Lerner Research Institute
        Cleveland Clinic
        Cleveland, OH
        schimmt@ccf.org


    Part 1 - Reading files, only extracting tdms files with pulse and dicom files with rwave
"""

# Master switch: the matching loop below only executes while this is True.
run = True

# When True, every successful match also renders and saves a PNG figure.
Plot = True

# fileName.setDirectory() is indexed positionally: [0] output/base directory,
# [1] subject ID, [2] TDMS directory, [3] DICOM directory.
# The TDMS and DICOM directories are later concatenated directly with file
# names, so they are expected to end with a path separator.
# NOTE(review): setDirectory() is called once per value; confirm it has no
# side effects (e.g. prompting) before collapsing these into a single call.
SubjectID = fileName.setDirectory()[1]
Directory = fileName.setDirectory()[0]
tdmsDirectory = fileName.setDirectory()[2]
dicomDirectory = fileName.setDirectory()[3]

# for i in range(12):

# One entry per file pair; each entry is used below as a list of indices into
# the peaks detected in that pair's TDMS pulse signal.
theChosenRwaves = fileName.rwaveSelection()

# Number of pairs to process; used for the progress dots and decremented on
# every successful match.
count = len(theChosenRwaves)
passed = 0
if run:
    print ". " * count
    for Pairs in range(len(fileName.setFileNames(0)[-1])):
        # print setFileNames(0)[-1][6][0]
        if run: #Pairs == 0:
            chosenRwaves =  theChosenRwaves[Pairs]
            dicomFiles, tdmsFiles = fileName.setFileNames(Pairs)[0:2]
            dicomFiles = [dicomFiles]
            tdmsFiles = [tdmsFiles]
            part1 = time.time()

            def inprogress():
                """Print a single progress marker (". ") to standard output."""
                # Python 2 print statement: the trailing comma suppresses the newline.
                print (". "),

            print "{} TDMS Files".format(len(tdmsFiles))
            print "{} DICOM Files".format(len(dicomFiles))

            inprogress()

            pulseList = []
            TdmsFilesNoPulse = []
            TdmsFiles_WithPulse = []
            RWaveList = []
            DicomNoRwave = []
            DicomFiles_WithRwave = []
            Std = []
            # Sort the DICOM files of this pair into those that carry a usable
            # R-wave time vector and those that do not.
            for i in dicomFiles:
                # Only files with the '.IMA' extension are inspected.
                if i.endswith('.IMA'):
                    try:
                        # stop_before_pixels=True: the pixel data is not read here.
                        dataset = dicom.read_file(dicomDirectory + i, stop_before_pixels=True)  # reading all DICOM files
                        # Raises AttributeError (handled below) when the attribute is absent.
                        RwaveRead = dataset.RWaveTimeVector
                        try:
                            # More than one R wave is required for the file to be kept.
                            if len(RwaveRead) > 1:
                                RWaveList.append(RwaveRead)  # creating list of rwaves
                                DicomFiles_WithRwave.append(i)
                            else:
                                # Zero or one R wave: treated as having no pulse train.
                                DicomNoRwave.append(i)
                        except TypeError:
                            # len() failed, i.e. the value is not a sized sequence.
                            DicomNoRwave.append(i)
                    except AttributeError:
                        # No RWaveTimeVector on this dataset.
                        DicomNoRwave.append(i)

            # Sort the TDMS files of this pair into those that carry a pulse train
            # and those that do not.
            for i in tdmsFiles:
                # Only files with the '.tdms' extension are inspected.
                if not i.endswith('.tdms'):
                    continue
                tdmsFileName = tdmsDirectory + i
                tdmsFile = TdmsFile(tdmsFileName)
                try:  # check to see if pulse exists, if not, exclude TDMS file from directory
                    # The pulse train is read from the first channel of the fourth
                    # group. The retired dict-based parser addressed it as
                    # u'Sensor.Run Number Pulse Train' / u'Run Number Pulse Train'.
                    group = tdmsFile.groups()[3]
                    pulse1 = tdmsFile.group_channels(group)[0].data
                except (KeyError, IndexError):
                    # Both lookups above are positional, so a missing group or
                    # channel raises IndexError; KeyError is kept for
                    # compatibility with the earlier name-based lookup.
                    TdmsFilesNoPulse.append(i)  # must create new list because some files have no Pulse Train
                else:
                    pulseList.append(pulse1)
                    TdmsFiles_WithPulse.append(i)

            print ""
            # print "{} Dicom Files without RWaves".format(len(DicomNoRwave))

            for files in DicomNoRwave:
                print files

            print ""
            # print "{} TDMS Files without Pulse".format(len(TdmsFilesNoPulse))

            for files in TdmsFilesNoPulse:
                print files

            part2 = time.time()
            elapsePart2 = part2 - part1



            """
                Part 2 - iterating through files, matching associated files
            """

            Tdsm_and_pulses = zip(TdmsFiles_WithPulse, pulseList)
            Dicom_and_rwaves = zip(DicomFiles_WithRwave, RWaveList)



            dicom_files = 0
            MatchList = []
            MatchTDMS = []
            MatchDICOM = []
            RunNumList = []
            timeSynchList = []
            matches = 0
            failed = []
            for iteration in range(1):
                iteration = 0
                for dicom_files, rwaves in Dicom_and_rwaves:
                    if dicom_files not in MatchDICOM:
                        for tdms_files, pulses in Tdsm_and_pulses:
                            if tdms_files not in MatchTDMS:
                                # iterations cycle through RWave and Peak variations. The ultrasound has a tendency to skip the first,
                                # second, or first and second, or last RWaves
                                try:

                                    RWave1 = np.array(rwaves)

                                    RWave2 = np.array(rwaves)[:]
                                    RWave3 = RWave1
                                    RWave = [RWave2, RWave1, RWave2]
                                    RWave = RWave[iteration]
                                except ValueError:
                                    donothing = 0

                                pulse = pulses

                                def findPeaks():
                                    """Return the indices of the peaks that peakutils finds in ``pulse``.

                                    ``pulse`` is read from the enclosing script scope. Detected peaks
                                    are required to be at least 100 samples apart (``min_dist=100``).

                                    NOTE(review): ``thres`` is passed as half the maximum of ``pulse``;
                                    peakutils documents ``thres`` as a normalized value by default --
                                    confirm this scaling is intended.
                                    """
                                    half_of_max = 0.5 * max(pulse)
                                    return peakutils.indexes(pulse, thres=half_of_max, min_dist=100)
                                fp = findPeaks()
                                peaks1 = findPeaks()[:]

                                try:
                                    peaks2 = [fp[i] for i in chosenRwaves]
                                except IndexError:
                                    print 'Check RWaves for: ' + tdms_files, dicom_files
                                    exit
                                peakLIST = [peaks2, peaks1, peaks2]
                                peaks = peakLIST[iteration]

                                # useful when iteration > 1 when peaks taken away
                                peakBits = peaks1 - peaks1[0]


                                # Set the pulse to begin at peak[0] at time = 0, ie first peak is at zero
                                # shiftedPeaks = peaks - peaks[0]

                                # Find the avg difference between R Wave's and pulse Peaks
                                def findDeltaT(rwave, Peaks):
                                    """Return the mean of ``rwave[k] - Peaks[k]`` over every index of ``Peaks``.

                                    ``rwave`` must be at least as long as ``Peaks``; extra trailing
                                    ``rwave`` entries are ignored. An empty ``Peaks`` raises
                                    ZeroDivisionError and a shorter ``rwave`` raises IndexError.
                                    """
                                    n_peaks = len(Peaks)
                                    offsets = [rwave[k] - Peaks[k] for k in range(n_peaks)]
                                    return sum(offsets) / n_peaks


                                # Find a vector of differences to later find max
                                def findDeltaT_vector(rwave, Peaks):
                                    """Return the element-wise differences ``rwave - Peaks`` as a list.

                                    The two inputs are paired with zip(), so the result is as long as
                                    the shorter of the two.
                                    """
                                    return [r_time - peak for r_time, peak in zip(rwave, Peaks)]


                                # Ensure number of RWave number of Peaks, else end
                                if len(RWave) != len(peaks):  # must be equal to continue
                                    donothing = True  # necessary for an if loop
                                    print "LENGTH NOT EQUAL --> Change # of Rwaves"

                                    print len(RWave), len(peaks)
                                else:
                                    # find initial difference, use to adjust rwaves for alignment
                                    Delta_T_initial = findDeltaT(RWave, peaks)

                                    # used for synchronizing files in later modeling
                                    if Delta_T_initial > 0:
                                        timeSynch = Delta_T_initial
                                    else:  # if DAQ started before ultrasound
                                        timeSynch = RWave[0] + Delta_T_initial

                                    RWave_adjusted = RWave - Delta_T_initial
                                    # if Delta_T_initial > 0:  # positive error means RWave needs to shift left.
                                    #     RWave_adjusted = RWave - Delta_T_initial
                                    # else:
                                    #     RWave_adjusted = RWave + Delta_T_initial

                                    # now that rwaves are optimally aligned, the max diff between any rwave and associated peak from signal
                                    # should not exceed 95 else not a match
                                    Delta_T_opt = np.array(findDeltaT_vector(RWave_adjusted, peaks))
                                    Std.append(round(np.std(Delta_T_opt), 2))
                                    max_dt = max(abs(Delta_T_opt))

                                    if max_dt > 200:
                                        donothing = 0
                                        print "DO NOT MATCH --> Change # of Rwaves or Select diff waves"
                                        # print "****Files do NOT match***"
                                        # print "Max Delta_T = ", max_dt
                                        # print " . " * count

                                    else:
                                        print "Files Match"
                                        passed += 1
                                        # import bit widths to decode pulse, each number represents 1 bit
                                        def getBits(filename):
                                            """Read the first column of a CSV file as a list of ints.

                                            Args:
                                                filename: path of the CSV file holding one bit marker per row.

                                            Returns:
                                                A list with ``int(row[0])`` for every non-empty row, in file order.

                                            Raises:
                                                IOError: if the file cannot be opened.
                                                ValueError: if a first-column value is not an integer.
                                            """
                                            # The context manager closes the handle even if parsing fails.
                                            with open(filename) as f:
                                                # csv.reader yields [] for blank lines; skip those rows.
                                                return [int(row[0]) for row in csv.reader(f) if row]

                                        # loading bit markers (each at specified location)
                                        Bits = getBits('../PulseWidths300.csv')
                                        Bits_Original = Bits  # to plot unaltered bits if desired
                                        # ensuring bit markers line up with pulse peaks in each iteration
                                        Bits = Bits - (peakLIST[0][0] - peakLIST[0][0])

                                        # Create the x values for pulse plot
                                        def createXvals_msec(pulse):
                                            """Return x values for ``pulse`` shifted so the first peak sits at zero.

                                            Element ``k`` of the result is ``k - len(pulse[0:peaks1[0]])``, where
                                            ``peaks1`` is read from the enclosing script scope.

                                            Args:
                                                pulse: the sampled pulse signal (any sliceable sequence).

                                            Returns:
                                                A list of ints, one per sample of ``pulse``.
                                            """
                                            n_samples = len(pulse)
                                            if n_samples == 0:
                                                # Nothing to shift; also avoids touching peaks1 needlessly.
                                                return []
                                            # The offset does not depend on the sample index: compute it once
                                            # instead of re-slicing the signal for every sample.
                                            first_peak_offset = len(pulse[0:peaks1[0]])
                                            return [index - first_peak_offset for index in range(n_samples)]

                                        def createBinary(Bits, peakBits):
                                            """Decode the pulse into a binary string and its subject/run fields.

                                            Every marker in ``Bits`` contributes one character: '1' when the
                                            marker is contained in ``peakBits``, otherwise '0'. Markers and peaks
                                            must be exactly equal; there is no tolerance margin.

                                            Returns:
                                                A tuple ``(Binary, RunNumNum_bin, SubjID_bin, RunNumNum_dec,
                                                SubjID_dec)`` where the ``_bin`` values are slices ``[9:18]`` and
                                                ``[0:8]`` of ``Binary`` and the ``_dec`` values are their base-2
                                                values rendered as decimal strings.

                                            NOTE(review): character 8 of ``Binary`` belongs to neither slice --
                                            confirm this is intentional.
                                            """
                                            flags = ['1' if marker in peakBits else '0' for marker in Bits]
                                            # Each new character was prepended in the signal's encoding, so the
                                            # string is the marker order reversed.
                                            Binary = ''.join(reversed(flags))

                                            SubjID_bin, RunNumNum_bin = Binary[0:8], Binary[9:18]
                                            SubjID_dec = str(int(SubjID_bin, 2))
                                            RunNumNum_dec = str(int(RunNumNum_bin, 2))
                                            return Binary, RunNumNum_bin, SubjID_bin, RunNumNum_dec, SubjID_dec

                                        SubjID_dec = createBinary(Bits, peakBits)[-1]

                                        def make3digits(input):
                                            """Left-pad a one- or two-character value with zeros to three characters.

                                            Values of any other length (including empty) are returned as
                                            ``str(input)`` without padding.
                                            """
                                            text = str(input)
                                            width = len(input)
                                            if width in (1, 2):
                                                return '0' * (3 - width) + text
                                            return text

                                        RunNumNum_dec = createBinary(Bits, peakBits)[-2]
                                        subID = make3digits(SubjID_dec)
                                        RunNum = make3digits(RunNumNum_dec)

                                        if Plot:

                                            reader = sitk.ImageFileReader()
                                            reader.SetFileName(dicomDirectory + dicom_files)

                                            img = reader.Execute()
                                            zSlice = -1
                                            img = img[0:880, 658:702, zSlice]

                                            nda = sitk.GetArrayFromImage(img)

                                            filtered = []
                                            for rows, columns in enumerate(nda):
                                                rowList = []
                                                for cols in range(len(columns)):
                                                    if any(nda[rows][cols] > 160):
                                                        rowList.append([0, 0, 0])
                                                    else:
                                                        rowList.append([255, 0, 0])
                                                filtered.append(rowList)

                                            fig, (ax1, ax2) = plt.subplots(nrows=2, ncols=1, figsize=(18, 9), dpi=90)
                                            ax1 = plt.subplot2grid((7, 1), (0, 0), rowspan=5)
                                            ax2 = plt.subplot2grid((7, 1), (5, 0), rowspan=2)

                                            fig.tight_layout(pad=3, w_pad=.5, h_pad=2)
                                            ax2.imshow(nda[:,:,0], interpolation=None, cmap='gray')

                                            # plotting pulse
                                            # ax1.plot(createXvals_msec(pulse), pulse, color='black', linewidth=2, label='$Pulse$')
                                            ax1.plot(pulse, color='black', linewidth=2.5, label='$Pulse$')
                                            ax1.set_xlim(0, 8000)
                                            ax1.set_ylim(-0.1, 0.6)

                                            ax1.annotate('', xy=(RWave[0], 0.52), xytext=(RWave_adjusted[0], 0.52),
                                                         arrowprops=dict(arrowstyle="<|-|>", color="red", lw=2, alpha=.6))
                                            Xcord = RWave[0] + (RWave_adjusted[0] - RWave[0]) / 2

                                            ax1.annotate('dTavg: ' + str(round(Delta_T_initial,2)) + " It: " + str(iteration),
                                                         xy=(Xcord, 0.52),
                                                         xycoords='data',
                                                         xytext=(-120, 60), textcoords='offset points',
                                                         size=12,
                                                         bbox=dict(boxstyle="round,pad=.2", fc='red', ec='blue', lw=2, alpha=.6),
                                                         arrowprops=dict(arrowstyle="-", color='red', lw=2, alpha=.6,
                                                                         connectionstyle="angle,angleA=0,angleB=90,rad=5"))

                                            # plotting adjusted R Wave
                                            # RWave_adjusted = RWave1 # comment/uncomment to plot synched RWave with  ORIGINAL pulse
                                            for wave in RWave_adjusted:

                                                if wave == RWave_adjusted[0]:
                                                    ax1.axvline(x=wave, linewidth=2, color='green', linestyle='--',
                                                                label='$R Wave Adjusted$')
                                                else:
                                                    ax1.axvline(x=wave, linewidth=2, color='green', linestyle='--')

                                            # plotting UN adjusted R Wave
                                            for wave in RWave:

                                                if wave == RWave[0]:
                                                    ax1.axvline(x=wave, linewidth=2, color='red', linestyle=':',
                                                                label='$R Wave Un-Adjusted$')
                                                else:
                                                    ax1.axvline(x=wave, linewidth=2, color='red', linestyle=':')

                                            # plotting pulse widths for encoding/decoding
                                            Bits = Bits_Original + peaks1[0]  # comment/uncomment when want to plot original bits
                                            for bits in Bits:
                                                if bits == Bits[0]:
                                                    ax1.axvline(x=Bits[0], linewidth=2, color='gray', linestyle='-', alpha=0.4)

                                                elif bits == Bits[2]:
                                                    ax1.axvline(x=Bits[2], linewidth=2, color='gray', linestyle='-', alpha=0.4)
                                                elif bits == Bits[12]:
                                                    ax1.axvline(x=Bits[12], linewidth=2, color='gray', linestyle='-', alpha=0.4)

                                            for bits in range(len(Bits)):
                                                ax1.axvline(x=Bits[bits], linewidth=.5, color='gray', linestyle='-')

                                            trial = "MULTIS" + str(subID) + " - Trial # " + str(RunNum)

                                            ax1.set_xlabel("time(ms)")
                                            ax1.set_ylabel("Volts")
                                            ax1.set_title(trial + " - " + 'Dicom #: ' + dicom_files[-42:-40] + "   MANUAL MATCH", color='red')
                                            ax1.legend(loc='upper right', prop={'size': 10}, borderpad=0.1, handlelength=5)

                                            ax2.set_title("DICOM pulse extracted from last image in sequence. Pattern, not start"
                                                          " time, should match pulse above", fontsize=14)
                                            ax2.set_xlabel(
                                            "* if pulse not present, error when recording. However, pulse data still collected ",
                                            fontsize=10)

                                            frame1 = plt.gca()
                                            frame1.axes.get_xaxis().set_ticks([])
                                            frame1.axes.get_yaxis().set_ticks([])

                                            plt.draw()
                                            if not os.path.exists(Directory + '/FileAssociationPNG'):
                                                os.makedirs(Directory + '/FileAssociationPNG')
                                            plt.savefig(Directory + '/FileAssociationPNG/' + tdms_files[-30:-5] + '.png')

                                            matches += 1

                                            timeSynchList.append(timeSynch)
                                            RunNumList.append(RunNum)

                                            def removeZeros_binary(num):
                                                """Strip leading zeros from a decimal string and give its binary form.

                                                Args:
                                                    num: decimal digits as a string, e.g. '007' or '123'.

                                                Returns:
                                                    A tuple ``(number, Binary)``: ``number`` is ``num`` without
                                                    leading zeros (a single '0' is kept for an all-zero input)
                                                    and ``Binary`` is its base-2 representation without the
                                                    '0b' prefix.

                                                Raises:
                                                    ValueError: if ``num`` is empty or not a decimal integer.
                                                """
                                                stripped = num.lstrip('0')
                                                # Keep one digit for all-zero input such as '000'.
                                                number = stripped if stripped else num[-1:]
                                                Binary = bin(int(number))[2:]
                                                return number, Binary
                                        else:
                                            donothing = 0

                                        # Append this match to the subject readme (run number, time offset
                                        # and the list of standard deviations collected so far for this
                                        # pair) and to the time-synchronization log (run number and offset).
                                        # Each file is closed right after its line is written so a new
                                        # handle is not left open on every match.
                                        with open(Directory + '/' + SubjectID + 'readme.txt', 'a') as f:
                                            f.write("%s     %s %s\n" % (RunNum, round(timeSynch, 4), Std))
                                        with open(Directory + '/TimeSynchronization.txt', 'a') as timeFile:
                                            timeFile.write("%s %s\n" % (RunNum, round(timeSynch, 4)))

                                        # Write the time offset of this match to a small XML file.
                                        analysis_path = Directory + '/TimeSynchronization/'
                                        # tree.write() does not create missing folders, so make sure the
                                        # output folder exists (same approach as the PNG folder).
                                        if not os.path.exists(analysis_path):
                                            os.makedirs(analysis_path)

                                        TDMS_name = tdms_files

                                        # Name of the xml file: the TDMS name without its 5-character
                                        # extension; names starting with 'NMa' also lose their first
                                        # four characters.
                                        if TDMS_name[0:3] == 'NMa':
                                            xml_name = analysis_path + TDMS_name[4:-5] + '_dT.xml'
                                        else:
                                            xml_name = analysis_path + TDMS_name[:-5] + '_dT.xml'

                                        # Characters 16-24 of the TDMS name are stored as the location name.
                                        tail = TDMS_name[16:25]

                                        # Layout: <SubjectID><Location><Name/><dT/></Location></SubjectID>
                                        root = ET.Element(SubjectID)
                                        loc = ET.SubElement(root, 'Location')
                                        ET.SubElement(loc, "Name").text = tail
                                        ET.SubElement(loc, "dT").text = str(round(timeSynch, 4))

                                        tree = ET.ElementTree(root)
                                        tree.write(xml_name, xml_declaration=True)

                                        # Renaming Matched Files
                                        try:
                                            os.rename(dicomDirectory + dicom_files,
                                                      dicomDirectory + RunNum + '_' + dicom_files[-61:])
                                            # MatchDICOM.append(RunNum + '_' + dicom_files[-61:])
                                        except OSError:

                                            try:
                                                os.rename(dicomDirectory + dicom_files,
                                                          dicomDirectory + RunNum + '_' + dicom_files[:])
                                                # MatchDICOM.append(RunNum + '_' + dicom_files[:])
                                            except OSError:
                                                print "check file length"
                                                print dicom_files
                                                exit()

                                        os.rename(tdmsDirectory + tdms_files, tdmsDirectory + tdms_files[-30:])

                                        MatchList.append(SubjectID + '_' + RunNum)
                                        MatchTDMS.append(tdms_files)
                                        MatchDICOM.append(dicom_files)

                                        # verifying different outputs
                                        # print "Number of RWaves = ", len(RWave)
                                        # print "Number of Peaks = ", len(peaks)
                                        # print "Pulse length = ", len(pulse)
                                        # print "RWaves at ", RWave
                                        # print "Peaks at ", shiftPeaks(peaks)
                                        # print "Bit locations = ", Bits
                                        # print "bits length = ", len(Bits)
                                        # print "delta T vector = ", findDeltaT_vector(RWave, shiftPeaks(peaks))
                                        # print "Delta_T_Opt vector = ", Delta_T_opt
                                        # print "Delta_T_Initial : ",
                                        # print "Time Synchronization : ", timeSynch
                                        # print "Delta_T_Opt avg SHOULD BE ZERO : ", findDeltaT(RWave_adjusted, peaks)
                                        # print "Std : {}".format(Std)
                                        print "SubjectID = ", SubjectID
                                        print "Experiment Run Num = ", RunNum

                                        count -= 1
                                        # plt.show()
                            else:
                                continue
                    else:
                        donothing = 0

print 'Manual Matching is complete'

if passed == len(theChosenRwaves):
    print 'running dataQuality'
    import dataQuality
else:
    exit()

