import dicom
import numpy as np
import matplotlib
matplotlib.use('TkAgg')
import matplotlib.pyplot as plt
import peakutils
import csv
import os
import time
import SimpleITK as sitk
from nptdms import TdmsFile
import xml.etree.ElementTree as ET
import fileName
import lxml.etree as etree
import warnings
import zipfile
import shutil
start = time.time()
from XMLparser import getAcceptance
warnings.simplefilter('ignore', UserWarning)
warnings.simplefilter('ignore', RuntimeWarning)
"""
The manual matching script matches two files THAT ARE ALREADY KNOWN TO MATCH, based on one or two R-waves.

This script was written to associate files for the MULTIS in vivo and in vitro experimentation.
It is a supplement to the FileAssociation.py script. This script is intended to manually match
2 files at a time that are known to be matching.

    Original Author:
        Tyler Schimmoeller
        Department of Biomedical Engineering
        Lerner Research Institute
        Cleveland Clinic
        Cleveland, OH
        schimmt@ccf.org


    Part 1 - Reading files: keep only the TDMS files that contain a pulse train and the DICOM files that contain R-waves
"""



"""

This beta matcher does not require the user to manually select R-waves. However, it does require paired file inputs
of known matches.

"""

# import time
# print("something")
   # pause 5.5 seconds
# print("something")

# Resolve the working directories for this subject via the project's fileName
# helper (argument 1 as in the original call).
Directory, SubjectID, tdmsDirectory, dicomDirectory = fileName.setDirectory(1)

# ans = raw_input('Unpack files from Export[y/n]')
# unpack = False
#
# if ans == 'y':
#     zip_ref = zipfile.ZipFile('../MULTIS_trials/EXPORT/' + SubjectID + '.zip', 'r')
#     zip_ref.extractall(path='../MULTIS_trials')
#     zip_ref.close()
#     shutil.move('../MULTIS_trials/EXPORT', '../MULTIS_trials/' + SubjectID)
#     os.rename('../MULTIS_trials/' + SubjectID + '/EXPORT', '../MULTIS_trials/' + SubjectID + '/Ultrasound')

Plot = True  # when True, Part 2 saves a diagnostic .eps figure for every match

passed = 0  # number of matched DICOM/TDMS pairs (incremented in Part 2)

dicomFiles = sorted(os.listdir(dicomDirectory))

# Zero-pad single-digit trial numbers ('1' -> '01') so that alphabetical order
# matches trial order. The trial number sits immediately before the last 41
# characters of each name, so every .IMA name shorter than the longest .IMA
# name gets a '0' inserted at that position.
# Only .IMA files are considered (they are the only ones read later).
# Previously, the reference length came from whichever entry sorted first,
# e.g. a stray hidden file. That could disable the padding or rename unrelated
# files, and an empty directory raised IndexError.
imaFiles = [name for name in dicomFiles if name.endswith('.IMA')]
lengthOfLongFile = max(len(name) for name in imaFiles) if imaFiles else 0  # for renaming purposes

for files in imaFiles:
    if len(files) < lengthOfLongFile:
        newFilename = files[0:-41] + '0' + files[-41:]
        os.rename(dicomDirectory + files, dicomDirectory + newFilename)

# Re-list both directories so the names reflect the renames above.
tdmsFiles = sorted(os.listdir(tdmsDirectory))
dicomFiles = sorted(os.listdir(dicomDirectory))

part1 = time.time()

def inprogress():
    """Print a single progress dot without ending the line.

    This is a Python 2 print statement: the trailing comma suppresses the
    newline, so repeated calls build up a row of dots on one line.
    """
    print ("."),
count = len(tdmsFiles)
print ". " * count

pulseList = []
TdmsFilesNoPulse = []
TdmsFiles_WithPulse = []
RWaveList = []
DicomNoRwave = []
DicomFiles_WithRwave = []
Std = []
for i in dicomFiles:
    if i.endswith('.IMA'):
        try:
            dataset = dicom.read_file(dicomDirectory + i, stop_before_pixels=True)  # reading all DICOM files
            RwaveRead = dataset.RWaveTimeVector
            try:
                if len(RwaveRead) > 1:
                    RWaveList.append(RwaveRead)  # creating list of rwaves
                    DicomFiles_WithRwave.append(i)
                else:
                    # print "File {} does not have a pulse train".format(i)
                    DicomNoRwave.append(i)
            except TypeError:
                DicomNoRwave.append(i)
                # print "File {} does not have a pulse train".format(i)
        except AttributeError:
            DicomNoRwave.append(i)
            # print "File {} does not have a pulse train" .format(i)

for i in tdmsFiles:
    if i.endswith('.tdms'):
        # count -= 1
        inprogress()
        tdmsFileName = tdmsDirectory + i
        # tdmsData = tdsmParserMultisCopy.parseTDMSfile(TdmsDirectory + filename)   # reading all TDMS data
        tdmsFile = TdmsFile(tdmsFileName)
        try:  # check to see if pulse exists, if not, exclue TDMS file from directory

            # pulse1 = tdmsData[u'Sensor.Run Number Pulse Train'][u'Run Number Pulse Train']
            group = tdmsFile.groups()[3]
            pulse1 = tdmsFile.group_channels(group)[0].data
            if len(pulse1) == 1:
                TdmsFilesNoPulse.append(i)
            else:

                pulseList.append(pulse1)
                TdmsFiles_WithPulse.append(i)

        except KeyError:
            TdmsFilesNoPulse.append(i)  # must create new list because some files have no Pulse Train
            # print " File '" + i + "' does not have a pulse train"

print ""
# print "{} Dicom Files without RWaves".format(len(DicomNoRwave))

for files in DicomNoRwave:
    print files

print ""
# print "{} TDMS Files without Pulse".format(len(TdmsFilesNoPulse))

for files in TdmsFilesNoPulse:
    print files

part2 = time.time()
elapsePart2 = part2 - part1



print ". " * count

"""
    Part 2 - iterating through files, matching associated files
"""

# Pair every file name with its signal. In Python 2 zip() already returns a
# list; list() states that explicitly, because the TDMS pairs are rescanned
# once for every DICOM file.
Tdsm_and_pulses = list(zip(TdmsFiles_WithPulse, pulseList))
Dicom_and_rwaves = list(zip(DicomFiles_WithRwave, RWaveList))

# Part 2 bookkeeping, filled in as matches are found.
MatchList, MatchTDMS, MatchDICOM = [], [], []
expRunList, timeSynchList = [], []
failed = []
iterList = []
matches = 0

for dicom_files, rwaves in Dicom_and_rwaves:
    if dicom_files not in MatchDICOM:
        for tdms_files, pulses in Tdsm_and_pulses:
            if tdms_files not in MatchTDMS:
                # iterations cycle through RWave and Peak variations.

                dcm_range_check = int(dicom_files[-42:-40])
                tdms_range_check = int(tdms_files[-30:-27])
                range_check = abs(dcm_range_check - tdms_range_check)

                if range_check >= 6:
                    # print dcm_range_check, dicom_files
                    # print tdms_range_check, tdms_files
                    # go = raw_input('presss Enter:')

                    donothing = 0
                else:



                    pulse = pulses

                    def findPeaks():
                        Peaks = peakutils.indexes(pulse, thres=0.5 * max(pulse), min_dist=100)
                        return Peaks
                    fp = findPeaks()
                    peaks0 = findPeaks()[:]
                    peaks1 = findPeaks()[1:]
                    peaks2 = findPeaks()[2:]
                    peaks3 = findPeaks()[3:]
                    peakLIST = [peaks0, peaks1, peaks2, peaks3]

                    DeltaVecList = [] # list of dt's for each iteration
                    DeltaAvgList = [] # list of dt avg for each iteration

                    try:

                        RWave = np.array(rwaves)

                    except ValueError:
                        donothing = 0

                    if len(RWave) > len(peaks0):
                        RWave = RWave[2:]

                    if RWave[0] < 270: # RWave from prev trial
                        RWave = RWave[1:]
                    RWave_adjusted = []
                    initial_diff_temp = []
                    for iteration in range(4):

                        peaks = peakLIST[iteration]
                        try:
                            if RWave[0] >= peaks[0]:
                                initial_diff = RWave[0] - peaks[0]  # set rwaves to begin at first peak
                                RWave_adjusted_temp = RWave - initial_diff
                            else:
                                initial_diff = peaks[0] - RWave[0]
                                RWave_adjusted_temp = RWave + initial_diff

                            initial_diff_temp.append(initial_diff)
                            # useful when iteration > 1 when peaks taken away
                            peakBits = peaks0 - peaks0[0]
                        except IndexError:
                            print dicom_files
                            print tdms_files
                        # Find the avg difference between R Wave's and pulse Peaks

                        # find the peak closest to each Rwave

                        # Peaks = Peaks - Peaks[0]
                        DeltaIterationVec = []
                        DeltaIterationAvg = []

                        for wave in range(len(RWave_adjusted_temp)):

                            absDelta_Tvec_perWave = []
                            Delta_Tvec_perWave = []     # diff bt each peak and single rwave, choose closest
                            for peak in range(len(peaks)):
                                if wave <= peak+4:
                                    dt = RWave_adjusted_temp[wave]-peaks[peak]
                                    absDelta_Tvec_perWave.append(abs(dt))
                                    Delta_Tvec_perWave.append(dt)

                            try:

                                Dt = absDelta_Tvec_perWave.index(min(absDelta_Tvec_perWave))
                            except ValueError:
                                print 'ValueError', iteration, wave
                                print absDelta_Tvec_perWave
                            try:

                                Dt = absDelta_Tvec_perWave.index(min(absDelta_Tvec_perWave))
                                Dt = Delta_Tvec_perWave[Dt]
                            except ValueError:
                                print 'ValueError', iteration, wave
                                print absDelta_Tvec_perWave
                                Dt = 1500

                            DeltaIterationVec.append(Dt)

                        Delta_Tavg = round(sum(DeltaIterationVec) / len(DeltaIterationVec), 4)
                        DeltaAvgList.append(Delta_Tavg)
                        Delta_T_opt_vec = [a - Delta_Tavg for a in DeltaIterationVec]
                        DeltaVecList.append(Delta_T_opt_vec)
                        RWave_adjusted.append(RWave_adjusted_temp)
                        it = iteration
                    Max = []
                    for vecs in DeltaVecList:
                        Max.append(max(abs(i) for i in vecs))

                    MaxIdx = Max.index(min(Max))  # iteration that is closest to passing
                    DeltaVecList_1_abs = np.abs(DeltaVecList)
                    DeltaVecList_1 = list(DeltaVecList_1_abs[MaxIdx])
                    Delta_T_avg = DeltaAvgList[MaxIdx]
                    Delta_T_max_index = DeltaVecList_1.index(max(DeltaVecList_1))
                    # Delta_T_opt_vec = [a - Delta_T for a in DeltaVecList_1]
                    Delta_T_max = DeltaVecList_1[Delta_T_max_index]
                    initial_diff = initial_diff_temp[MaxIdx]

                    if abs(Delta_T_max) > 190:
                        donothing = 0
                        # print 'Failure ---------------------------------------------------------- Failure'
                        # print "          TDMS: ", tdms_files[0:3]
                        # print "           DCM: ", dicom_files[-42:-40]
                        # print '     Iteration: ', MaxIdx
                        # print '           Max: ', Delta_T_max
                        # print '         RWave: ', RWave
                        # print '         RWave: ', RWave_adjusted[MaxIdx]- Delta_T
                        # print '         Peaks: ', peaks
                    else:
                        print "Match"
                        # print 'Max', Delta_T_max

                        passed += 1

                        # used for synchronizing files in later modeling
                        timeSynch = initial_diff + Delta_T_avg

                        RWave_adjusted = RWave_adjusted[MaxIdx] - Delta_T_avg

                        Std.append(round(np.std(DeltaVecList_1),2))
                        print Std
                        def getBits(filename):
                            f = open(filename)
                            PulseWidthsFile = csv.reader(f)
                            Bits = []
                            for row in PulseWidthsFile:
                                Bits.append(row[0])
                            return map(int, Bits)  # map converts list of strings to list of int

                        # loading bit markers (each at specified location)
                        Bits = getBits('../PulseWidths300.csv')
                        Bits_Original = Bits  # to plot unaltered bits if desired
                        # ensuring bit markers line up with pulse peaks in each iteration
                        Bits = Bits - (peakLIST[0][0] - peakLIST[0][0])

                        # Create the x values for pulse plot
                        def createXvals_msec(pulse):
                            xVals_msec = []
                            for index in range(len(pulse)):  # change peaks to peaks0 to plot original peak set
                                xVals_msec += [range(len(pulse))[index] - len(pulse[0:peaks1[0]])]
                            return xVals_msec

                        def createBinary(Bits, peakBits):
                            # may need to add margin somehow. right now peaks and bits have to be exactly equal
                            #
                            # notice always adding new number to beginning, hence in reverse
                            # because binary signal was created in reverse
                            Binary = ""
                            for b in Bits:
                                if b in peakBits:  # if a peak is located at bit marker -> binary = 1, else 0
                                    Binary = '1' + Binary
                                else:
                                    Binary = '0' + Binary

                            SubjID_bin = Binary[0:8]
                            ExpRunNum_bin = Binary[9:18]
                            SubjID_dec = str(int(SubjID_bin, 2))
                            ExpRunNum_dec = str(int(ExpRunNum_bin, 2))
                            return Binary, ExpRunNum_bin, SubjID_bin, ExpRunNum_dec, SubjID_dec

                        SubjID_dec = createBinary(Bits, peakBits)[-1]

                        def make3digits(input):
                            if len(input) == 1:
                                output = '00' + str(input)
                            elif len(input) == 2:
                                output = '0' + str(input)
                            else:
                                output = str(input)
                            return output

                        ExpRunNum_dec = createBinary(Bits, peakBits)[-2]
                        subID = make3digits(SubjID_dec)
                        expRun = make3digits(ExpRunNum_dec)
                        expRunList.append(expRun)
                        iterList.append(iteration)

                        if Plot:

                            reader = sitk.ImageFileReader()
                            reader.SetFileName(dicomDirectory + dicom_files)

                            img = reader.Execute()
                            zSlice = -1
                            img = img[0:880, 658:702, zSlice]

                            nda = sitk.GetArrayFromImage(img)

                            filtered = []
                            for rows, columns in enumerate(nda):
                                rowList = []
                                for cols in range(len(columns)):
                                    if any(nda[rows][cols] > 160):
                                        rowList.append([0, 0, 0])
                                    else:
                                        rowList.append([255, 0, 0])
                                filtered.append(rowList)

                            fig, (ax1, ax2) = plt.subplots(nrows=2, ncols=1, figsize=(18, 9), dpi=90)
                            ax1 = plt.subplot2grid((6, 1), (2, 0), rowspan=1)
                            ax2 = plt.subplot2grid((6, 1), (4, 0), rowspan=1)
                            ax0 = plt.subplot2grid((6, 1), (1, 0), rowspan=1)

                            fig.tight_layout(pad=3, w_pad=.5, h_pad=2)
                            ax2.imshow(nda[:,:,0], interpolation=None, cmap='gray')

                            # plotting pulse
                            # ax1.plot(createXvals_msec(pulse), pulse, color='black', linewidth=2, label='$Pulse$')
                            ax1.plot(pulse, color='black', linewidth=1.5)
                            ax1.set_xlim(0, 8000)
                            ax1.set_ylim(-0.1, 0.6)
                            ax0.plot(pulse, color='black', linewidth=1.5)
                            ax0.set_xlim(0, 8000)
                            ax0.set_ylim(-0.1, 0.6)

                            # ax1.annotate('', xy=(RWave[0], 0.52), xytext=(RWave_adjusted[0], 0.52),
                            #              arrowprops=dict(arrowstyle="<|-|>", color="red", lw=2, alpha=.6))
                            Xcord = RWave[0] + (RWave_adjusted[0] - RWave[0]) / 2
                            #
                            # ax1.annotate('dTavg: ' + str(round(timeSynch,2)) + " It: " + str(iteration),
                            #              xy=(Xcord, 0.52),
                            #              xycoords='data',
                            #              xytext=(-120, 60), textcoords='offset points',
                            #              size=12,
                            #              bbox=dict(boxstyle="round,pad=.2", fc='white', ec='blue', lw=2, alpha=.6),
                            #              arrowprops=dict(arrowstyle="-", color='red', lw=2, alpha=.6,
                            #                              connectionstyle="angle,angleA=0,angleB=90,rad=5"))

                            # plotting adjusted R Wave
                            # RWave_adjusted = RWave1 # comment/uncomment to plot synched RWave with  ORIGINAL pulse
                            for wave in RWave_adjusted:

                                if wave == RWave_adjusted[0]:
                                    ax1.axvline(x=wave, linewidth=2, color='green', linestyle='--')
                                else:
                                    ax1.axvline(x=wave, linewidth=2, color='green', linestyle='--')

                            # plotting UN adjusted R Wave
                            for wave in RWave:

                                if wave == RWave[0]:
                                    ax0.axvline(x=wave, linewidth=2, color='red', linestyle='--')
                                else:
                                    ax0.axvline(x=wave, linewidth=2, color='red', linestyle='--')

                            # plotting pulse widths for encoding/decoding
                            Bits = Bits_Original + peaks0[0]  # comment/uncomment when want to plot original bits
                            # for bits in Bits:
                            #     # if bits == Bits[0]:
                            #     #     ax1.axvline(x=Bits[0], linewidth=2, color='gray', linestyle='-', alpha=0.4)
                            #
                            #     if bits == Bits[2]:
                            #         ax1.axvline(x=Bits[2], linewidth=2, color='gray', linestyle='-', alpha=0.4)
                            #         ax0.axvline(x=Bits[2], linewidth=2, color='gray', linestyle='-', alpha=0.4)
                            #
                            #     elif bits == Bits[12]:
                            #         ax1.axvline(x=Bits[12], linewidth=2, color='gray', linestyle='-', alpha=0.4)
                            #         ax0.axvline(x=Bits[12], linewidth=2, color='gray', linestyle='-', alpha=0.4)

                            for bits in range(len(Bits)):
                                if bits > 2:
                                    ax1.axvline(x=Bits[bits], linewidth=1.5, color='gray', linestyle='-', alpha=0.4)
                                    ax0.axvline(x=Bits[bits], linewidth=1.5, color='gray', linestyle='-', alpha=0.4)

                            trial = "MULTIS" + str(subID) + " - Trial # " + str(expRun)

                            ax1.set_xlabel("time(ms)")
                            ax1.set_ylabel("Volts")
                            # ax1.set_title(trial + " - " + 'Dicom #: ' + dicom_files[-42:-40] , color='black')
                            # ax1.legend(loc='upper right', prop={'size': 10}, borderpad=0.1, handlelength=5)

                            # ax2.set_title("DICOM pulse extracted from last image in sequence. Pattern, not start"
                            #               " time, should match pulse above", fontsize=14)
                            # ax2.set_xlabel(
                            # "* if pulse not present, error when recording. However, pulse data still collected ",
                            # fontsize=10)

                            frame1 = plt.gca()
                            frame1.axes.get_xaxis().set_ticks([])
                            frame1.axes.get_yaxis().set_ticks([])

                            # plt.draw()
                            if not os.path.exists(Directory + '/FileAssociation'):
                                os.makedirs(Directory + '/FileAssociation')
                            plt.savefig(Directory + '/FileAssociation/' + tdms_files[-30:-5] + '.eps', format='eps')
                            plt.close()
                            matches += 1


                            timeSynchList.append(timeSynch)


                            def removeZeros_binary(num):
                                num = num[0:3]
                                Binary = 0
                                number = 0
                                if num[0] == '0':
                                    if num[1] == '0':
                                        number = num[2]
                                        Binary = bin(int(number))[2:]
                                    else:
                                        number = num[1:]
                                        Binary = bin(int(number))[2:]

                                return number, Binary
                        else:
                            donothing = 0

                        # Renaming Matched Files
                        try:
                            os.rename(dicomDirectory + dicom_files,
                                      dicomDirectory + expRun + '_' + dicom_files[-61:])
                            # MatchDICOM.append(expRun + '_' + dicom_files[-61:])
                        except OSError:

                            try:
                                os.rename(dicomDirectory + dicom_files,
                                          dicomDirectory + expRun + '_' + dicom_files[:])
                                # MatchDICOM.append(RunNum + '_' + dicom_files[:])
                            except OSError:
                                print "check file length"
                                print dicom_files
                                exit()

                        os.rename(tdmsDirectory + tdms_files, tdmsDirectory + tdms_files[-30:])

                        MatchList.append(SubjectID + '_' + expRun)
                        MatchTDMS.append(tdms_files)
                        MatchDICOM.append(dicom_files)

                        # print SubjectID
                        print "       TDMS", expRun
                        print "        DCM ", dicom_files[-42:-40]
                        print ". " * count
                        count -= 1
                        break
                        # plt.show()
            else:
                continue
    else:
        donothing = 0
