import dicom
import numpy as np
import matplotlib
matplotlib.use('TkAgg')
import matplotlib.pyplot as plt
import peakutils
import csv
import os
import time
import SimpleITK as sitk
from nptdms import TdmsFile
import xml.etree.ElementTree as ET
import fileName
import lxml.etree as etree
import warnings
import tdsmParserMultis
import zipfile
import shutil
start = time.time()
from cXMLparser import getAcceptance
warnings.simplefilter('ignore', UserWarning)
warnings.simplefilter('ignore', RuntimeWarning)
"""
The manual matching script matches two files THAT ARE ALREADY KNOWN TO MATCH, based on one or two R waves.

This script was written to associate files for the MULTIS in vivo and in vitro experimentation.
It is a supplement to the FileAssociation.py script. This script is intended to manually match
2 files at a time that are known to be matching.

    Original Author:
        Tyler Schimmoeller
        Department of Biomedical Engineering
        Lerner Research Institute
        Cleveland Clinic
        Cleveland, OH
        schimmt@ccf.org


    Part 1 - Reading files, only extracting tdms files with pulse and dicom files with rwave
"""



"""

This beta matcher does not require the user to manually select R waves. However, it does require paired file inputs
that are known to match.

"""

# import time
# print("something")
# time.sleep(30)    # pause 30 seconds
# print("something")

# Resolve the working directories for the current subject.
Directory, SubjectID, tdmsDirectory, dicomDirectory = fileName.setDirectory(1)



# ans = raw_input('Unpack files from Export[y/n]')
# unpack = False
#
# if ans == 'y':
#     zip_ref = zipfile.ZipFile('../MULTIS_trials/EXPORT/' + SubjectID + '.zip', 'r')
#     zip_ref.extractall(path='../MULTIS_trials')
#     zip_ref.close()
#     shutil.move('../MULTIS_trials/EXPORT', '../MULTIS_trials/' + SubjectID)
#     os.rename('../MULTIS_trials/' + SubjectID + '/EXPORT', '../MULTIS_trials/' + SubjectID + '/Ultrasound')

# Save a diagnostic figure for every matched pair.
Plot = True

# Number of DICOM/TDMS pairs that pass the R-wave alignment check.
passed = 0

# Zero-pad single-digit trial numbers ('1' -> '01') so the DICOM files sort in
# trial order: every name shorter than the reference name (the 17th entry in
# sorted order) gets a '0' inserted 40 characters from its end.
# NOTE(review): assumes the directory holds at least 17 entries and that the
# 17th one has a two-digit trial number -- confirm for every subject.
dicomFiles = sorted(os.listdir(dicomDirectory))
lengthOfLongFile = len(dicomFiles[16])
shortNames = [name for name in dicomFiles if len(name) < lengthOfLongFile]
for name in shortNames:
    padded = name[:-40] + '0' + name[-40:]
    os.rename(dicomDirectory + name, dicomDirectory + padded)

# Re-list both directories now that DICOM names may have changed.
tdmsFiles = sorted(os.listdir(tdmsDirectory))
dicomFiles = sorted(os.listdir(dicomDirectory))

part1 = time.time()

def inprogress():
    """Print one progress dot (called once per .tdms file read below).

    The trailing comma is the Python 2 print-statement form that suppresses
    the newline, so successive dots accumulate on a single line.
    """
    print ("."),
# Number of entries in the TDMS directory. A row of ". " (one pair per entry)
# is printed here; the row is reprinted, and `count` decremented, after each
# match later in the script.
count = len(tdmsFiles)
print ". " * count

# Containers filled by the two scans below (the TDMS lists by the next loop);
# Std later collects one residual spread per matched pair.
pulseList = []
TdmsFilesNoPulse = []
TdmsFiles_WithPulse = []
RWaveList = []
DicomNoRwave = []
DicomFiles_WithRwave = []
Std = []

# Classify every DICOM (.IMA) file by whether its header carries an R-wave
# time vector with more than one entry.
for dcmName in dicomFiles:
    if not dcmName.endswith('.IMA'):
        continue
    try:
        header = dicom.read_file(dicomDirectory + dcmName, stop_before_pixels=True)  # reading all DICOM files
        rwaveTimes = header.RWaveTimeVector
    except AttributeError:
        # Header has no RWaveTimeVector element.
        DicomNoRwave.append(dcmName)
        continue

    try:
        hasTrain = len(rwaveTimes) > 1
    except TypeError:
        # A value without a length is not a usable R-wave train.
        hasTrain = False

    if hasTrain:
        RWaveList.append(rwaveTimes)  # creating list of rwaves
        DicomFiles_WithRwave.append(dcmName)
    else:
        DicomNoRwave.append(dcmName)

# Parse every TDMS file once and keep those that carry the run-number pulse
# train used for matching. (The previous version also opened each file a
# second time with TdmsFile() and never used the result.)
for tdmsName in tdmsFiles:
    if not tdmsName.endswith('.tdms'):
        continue
    inprogress()
    tdmsData = tdsmParserMultis.parseTDMSfile(tdmsDirectory + tdmsName)  # reading all TDMS data
    try:
        pulse1 = tdmsData[u'Sensor.Run Number Pulse Train'][u'Run Number Pulse Train']
    except KeyError:
        # Some files were recorded without a pulse train and cannot be matched.
        TdmsFilesNoPulse.append(tdmsName)
    else:
        pulseList.append(pulse1)
        TdmsFiles_WithPulse.append(tdmsName)

print ""
# print "{} Dicom Files without RWaves".format(len(DicomNoRwave))

for files in DicomNoRwave:
    print files

print ""
# print "{} TDMS Files without Pulse".format(len(TdmsFilesNoPulse))

for files in TdmsFilesNoPulse:
    print files

part2 = time.time()
elapsePart2 = part2 - part1



print ". " * count

"""
    Part 2 - iterating through files, matching associated files
"""

# Pair each file with its signal. list() keeps both re-iterable (the TDMS
# list is scanned once per DICOM file) on any Python version.
Tdsm_and_pulses = list(zip(TdmsFiles_WithPulse, pulseList))
Dicom_and_rwaves = list(zip(DicomFiles_WithRwave, RWaveList))

MatchList = []      # "<SubjectID>_<expRun>" for every matched pair
MatchTDMS = []      # original names of matched TDMS files
MatchDICOM = []     # original names of matched DICOM files
expRunList = []     # decoded experiment-run number (3 digits) per match
timeSynchList = []  # time-synchronization offset per match
matches = 0         # number of matched pairs
failed = []
iterList = []       # peak shift (0..N_PEAK_SHIFTS-1) chosen per match

# Largest allowed |aligned R wave - nearest peak| residual for a pair to be
# accepted (same units as the R-wave times / pulse sample index).
MAX_DELTA_T = 190
# The R waves are aligned to peaks[0], peaks[1], ..., so up to
# N_PEAK_SHIFTS - 1 leading pulse peaks may be skipped.
N_PEAK_SHIFTS = 4
# Bit-marker positions (first CSV column) used to decode the pulse train.
PULSE_WIDTHS_CSV = '../PulseWidths300.csv'


def findPeaks(pulse):
    """Return the indices of the peaks of *pulse* (peakutils.indexes).

    NOTE(review): peakutils.indexes documents ``thres`` as a normalized value
    in [0, 1]; passing 0.5 * max(pulse) mixes in the signal amplitude. Kept
    unchanged so peak detection matches earlier runs -- confirm intent.
    """
    return peakutils.indexes(pulse, thres=0.5 * max(pulse), min_dist=100)


def getBits(filename):
    """Return the first column of the CSV file *filename* as a list of ints."""
    with open(filename) as f:
        return [int(row[0]) for row in csv.reader(f)]


def createBinary(Bits, peakBits):
    """Decode the pulse train into a bit string.

    A bit is '1' when its marker position occurs in *peakBits* (exact match,
    no tolerance) and '0' otherwise. The string is built in reverse marker
    order because the binary signal was created in reverse.

    Returns (Binary, ExpRunNum_bin, SubjID_bin, ExpRunNum_dec, SubjID_dec):
    the subject ID is bits [0:8] and the experiment-run number bits [9:18]
    of the string, each also returned as a decimal string.
    """
    Binary = ''.join('1' if b in peakBits else '0' for b in reversed(list(Bits)))
    SubjID_bin = Binary[0:8]
    # NOTE(review): bit 8 lies between the two fields and is ignored -- confirm.
    ExpRunNum_bin = Binary[9:18]
    SubjID_dec = str(int(SubjID_bin, 2))
    ExpRunNum_dec = str(int(ExpRunNum_bin, 2))
    return Binary, ExpRunNum_bin, SubjID_bin, ExpRunNum_dec, SubjID_dec


def make3digits(value):
    """Left-pad a decimal string with zeros to three characters ('5' -> '005')."""
    return str(value).zfill(3)


def removeZeros_binary(num):
    """Strip leading zeros from a 3-character run number and give its binary form.

    Only the first three characters of *num* are used. '00X' -> ('X', bin(X)),
    '0XY' -> ('XY', bin(XY)); names not starting with '0' give (0, 0). If the
    digits are not numeric, the name is printed and Binary stays 0.
    Defined here, outside the loop, so it exists even when nothing matched
    (it is used when writing the readme).
    """
    num = num[0:3]
    Binary = 0
    number = 0
    if num[0] == '0':
        try:
            if num[1] == '0':
                number = num[2]
            else:
                number = num[1:]
            Binary = bin(int(number))[2:]
        except ValueError:
            print(num)
    return number, Binary


def _plotMatch(dicomPath, pulse, RWave, RWave_adjusted, timeSynch, iteration,
               bitMarkers, title, pngPath):
    """Save a diagnostic figure for one matched DICOM/TDMS pair to *pngPath*.

    Top panel: the TDMS pulse, the original (red, dotted) and aligned (green,
    dashed) R waves, the bit markers (gray) and the applied shift.
    Bottom panel: a crop of the last DICOM frame showing the pulse recorded
    on the ultrasound image.
    """
    reader = sitk.ImageFileReader()
    reader.SetFileName(dicomPath)
    img = reader.Execute()
    # NOTE(review): crop bounds 0:880 x 658:702 are assumed to contain the
    # pulse strip of the scanner layout -- confirm.
    nda = sitk.GetArrayFromImage(img[0:880, 658:702, -1])

    # A plain figure avoids leaving the unused axes of plt.subplots() behind
    # the subplot2grid axes.
    fig = plt.figure(figsize=(18, 9), dpi=90)
    ax1 = plt.subplot2grid((7, 1), (0, 0), rowspan=5)
    ax2 = plt.subplot2grid((7, 1), (5, 0), rowspan=2)
    fig.tight_layout(pad=3, w_pad=.5, h_pad=2)
    ax2.imshow(nda[:, :, 0], interpolation=None, cmap='gray')

    ax1.plot(pulse, color='black', linewidth=2.5, label='$Pulse$')
    ax1.set_xlim(0, 8000)
    ax1.set_ylim(-0.1, 0.6)

    # Arrow and label showing the shift applied to the first R wave.
    ax1.annotate('', xy=(RWave[0], 0.52), xytext=(RWave_adjusted[0], 0.52),
                 arrowprops=dict(arrowstyle="<|-|>", color="red", lw=2, alpha=.6))
    Xcord = RWave[0] + (RWave_adjusted[0] - RWave[0]) / 2
    ax1.annotate('dTavg: ' + str(round(timeSynch, 2)) + " It: " + str(iteration),
                 xy=(Xcord, 0.52),
                 xycoords='data',
                 xytext=(-120, 60), textcoords='offset points',
                 size=12,
                 bbox=dict(boxstyle="round,pad=.2", fc='white', ec='blue', lw=2, alpha=.6),
                 arrowprops=dict(arrowstyle="-", color='red', lw=2, alpha=.6,
                                 connectionstyle="angle,angleA=0,angleB=90,rad=5"))

    # Only the first line of each group gets a legend entry.
    for idx, wave in enumerate(RWave_adjusted):
        legend = {'label': '$R Wave Adjusted$'} if idx == 0 else {}
        ax1.axvline(x=wave, linewidth=2, color='green', linestyle='--', **legend)
    for idx, wave in enumerate(RWave):
        legend = {'label': '$R Wave Un-Adjusted$'} if idx == 0 else {}
        ax1.axvline(x=wave, linewidth=2, color='red', linestyle=':', **legend)

    # Bit markers: 0, 2 and 12 highlighted, every marker drawn thin.
    for idx in (0, 2, 12):
        ax1.axvline(x=bitMarkers[idx], linewidth=2, color='gray', linestyle='-', alpha=0.4)
    for marker in bitMarkers:
        ax1.axvline(x=marker, linewidth=.5, color='gray', linestyle='-')

    ax1.set_xlabel("time(ms)")
    ax1.set_ylabel("Volts")
    ax1.set_title(title, color='black')
    ax1.legend(loc='upper right', prop={'size': 10}, borderpad=0.1, handlelength=5)

    ax2.set_title("DICOM pulse extracted from last image in sequence. Pattern, not start"
                  " time, should match pulse above", fontsize=14)
    ax2.set_xlabel(
        "* if pulse not present, error when recording. However, pulse data still collected ",
        fontsize=10)
    ax2.set_xticks([])
    ax2.set_yticks([])

    pngDir = os.path.dirname(pngPath)
    if not os.path.exists(pngDir):
        os.makedirs(pngDir)
    fig.savefig(pngPath)
    plt.close(fig)


for dicom_files, rwaves in Dicom_and_rwaves:
    if dicom_files in MatchDICOM:
        continue
    try:
        rwaveTimes = np.array(rwaves)
    except ValueError:
        # Previously swallowed, which silently reused the R waves of an
        # earlier DICOM file; skip this file instead.
        print("Could not read R waves of %s" % dicom_files)
        continue

    for tdms_files, pulses in Tdsm_and_pulses:
        if tdms_files in MatchTDMS:
            continue

        # Compare the numeric fields at DICOM name chars [-41:-39] and TDMS
        # name chars [-31:-28]; pairs differing by less than 32 are skipped.
        # NOTE(review): an earlier `range_check > 36` test had an empty body,
        # so no upper bound is applied -- confirm whether one was intended.
        dcm_range_check = int(dicom_files[-41:-39])
        tdms_range_check = int(tdms_files[-31:-28])
        range_check = abs(dcm_range_check - tdms_range_check)
        if range_check < 32:
            continue

        peaks0 = findPeaks(pulses)
        if len(peaks0) < N_PEAK_SHIFTS:
            # Too few peaks to try every shift (previously an IndexError).
            continue
        peakLIST = [peaks0[shift:] for shift in range(N_PEAK_SHIFTS)]
        # Peak positions relative to the first peak, used to decode the bits.
        peakBits = peaks0 - peaks0[0]

        RWave = rwaveTimes
        if len(RWave) > len(peaks0):
            RWave = RWave[2:]
        if RWave[0] < 200:  # RWave from prev trial
            RWave = RWave[1:]

        DeltaVecList = []       # residuals (minus their mean), per shift
        DeltaAvgList = []       # mean residual, per shift
        RWave_adjusted = []     # R waves moved onto the first peak, per shift
        initial_diff_temp = []  # |first R wave - first peak|, per shift
        for iteration, peaks in enumerate(peakLIST):
            signed_diff = RWave[0] - peaks[0]
            RWave_adjusted_temp = RWave - signed_diff  # first R wave on first peak
            # NOTE(review): stored as a magnitude, so the sign is lost when the
            # first R wave precedes the first peak; timeSynch inherits this.
            initial_diff = abs(signed_diff)

            # Signed distance from each R wave to the closest peak among those
            # whose index is >= wave - 4.
            DeltaIterationVec = []
            for wave, waveTime in enumerate(RWave_adjusted_temp):
                candidates = [waveTime - peakTime
                              for peakIdx, peakTime in enumerate(peaks)
                              if wave <= peakIdx + 4]
                if not candidates:
                    print('No peak candidate for R wave %s (shift %s); skipped'
                          % (wave, iteration))
                    continue
                DeltaIterationVec.append(min(candidates, key=abs))

            Delta_Tavg = round(sum(DeltaIterationVec) / len(DeltaIterationVec), 4)
            DeltaAvgList.append(Delta_Tavg)
            DeltaVecList.append([d - Delta_Tavg for d in DeltaIterationVec])
            RWave_adjusted.append(RWave_adjusted_temp)
            initial_diff_temp.append(initial_diff)

        # Keep the shift whose worst residual is smallest.
        worstResiduals = [max(abs(d) for d in vec) for vec in DeltaVecList]
        MaxIdx = worstResiduals.index(min(worstResiduals))
        DeltaVecList_1 = [abs(d) for d in DeltaVecList[MaxIdx]]
        Delta_T_avg = DeltaAvgList[MaxIdx]
        Delta_T_max = max(DeltaVecList_1)
        initial_diff = initial_diff_temp[MaxIdx]

        if Delta_T_max > MAX_DELTA_T:
            continue  # not a match; try the next TDMS file

        print("Match")
        passed += 1

        # used for synchronizing files in later modeling
        timeSynch = initial_diff + Delta_T_avg
        RWave_adjusted = RWave_adjusted[MaxIdx] - Delta_T_avg
        Std.append(round(np.std(DeltaVecList_1), 2))

        # Decode subject ID and experiment-run number: a bit is 1 when a pulse
        # peak sits exactly on the bit marker.
        Bits_Original = getBits(PULSE_WIDTHS_CSV)
        # NOTE(review): the markers used to be shifted by
        # peakLIST[0][0] - peakLIST[0][0], which is always 0.
        Bits = np.asarray(Bits_Original)
        _, _, _, ExpRunNum_dec, SubjID_dec = createBinary(Bits, peakBits)
        subID = make3digits(SubjID_dec)
        expRun = make3digits(ExpRunNum_dec)
        expRunList.append(expRun)
        # The chosen shift; previously the loop variable, which was always
        # the last shift.
        iterList.append(MaxIdx)
        # Recorded regardless of Plot so the lists zipped later stay aligned.
        timeSynchList.append(timeSynch)
        matches += 1

        if Plot:
            title = ("MULTIS" + str(subID) + " - Trial # " + str(expRun) + " - "
                     + 'Dicom #: ' + dicom_files[-41:-39])
            _plotMatch(dicomDirectory + dicom_files, pulses, RWave, RWave_adjusted,
                       timeSynch, MaxIdx, np.asarray(Bits_Original) + peaks0[0],
                       title,
                       Directory + '/FileAssociation/' + tdms_files[-31:-5] + '.png')

        # Renaming Matched Files: prefix the DICOM name with the run number.
        try:
            os.rename(dicomDirectory + dicom_files,
                      dicomDirectory + expRun + '_' + dicom_files[-63:])
        except OSError:
            try:
                os.rename(dicomDirectory + dicom_files,
                          dicomDirectory + expRun + '_' + dicom_files)
            except OSError:
                print("check file length")
                print(dicom_files)
                raise SystemExit
        # Trim the TDMS name to its last 31 characters.
        os.rename(tdmsDirectory + tdms_files, tdmsDirectory + tdms_files[-31:])

        MatchList.append(SubjectID + '_' + expRun)
        MatchTDMS.append(tdms_files)
        MatchDICOM.append(dicom_files)

        print("       TDMS %s" % expRun)
        print("        DCM  %s" % dicom_files[-41:-39])
        print(". " * count)
        count -= 1
        break

# Synchronization results of every match, sorted by experiment-run number.
synchList = sorted(zip(expRunList, iterList, timeSynchList, Std))
TDMS_name_list, Total_Accepted = getAcceptance(Directory)

# Index the results by run number so each trial from the acceptance list is
# looked up directly. The previous lockstep walk over both lists marked every
# following trial 'N/A' as soon as one matched run was missing from the
# acceptance list (or the list was not sorted). For a run matched more than
# once, the first entry in sorted order is used.
synchByRun = {}
for runNum, shiftIdx, tSynch, stdDev in synchList:
    synchByRun.setdefault(runNum, (shiftIdx, tSynch, stdDev))

resultsFileName = []
resultsAcceptance = []
resultsIteration = []
resultsTimeSynch = []
resultsStd = []
for filename, acceptance in TDMS_name_list:
    resultsFileName.append(filename)
    resultsAcceptance.append(acceptance)

    entry = synchByRun.get(filename[0:3])  # first 3 chars = run number
    if entry is None:
        resultsIteration.append('N/A')
        resultsTimeSynch.append('N/A')
        resultsStd.append('N/A')
    else:
        resultsIteration.append(entry[0])
        resultsTimeSynch.append(round(entry[1], 4))
        resultsStd.append(round(entry[2], 4))

# One row per trial: (name, acceptance, shift, timeSynch, std).
# list() keeps it indexable/re-iterable on any Python version.
results = list(zip(resultsFileName, resultsAcceptance, resultsIteration,
                   resultsTimeSynch, resultsStd))

# Accepted trials, split by whether a synchronization result exists.
# Loop names avoid shadowing the `time` module and the `iter` builtin.
unMatched_but_accepted = []
matched_and_accepted = []
for TDMS_name, acc, _shift, tSynch, stdDev in results:
    if acc != 1:
        continue
    if tSynch == 'N/A':
        unMatched_but_accepted.append(TDMS_name)
    else:
        matched_and_accepted.append([TDMS_name, tSynch, stdDev])


def prettyPrintXml(xmlFilePathToPrettyPrint):
    """Re-write an XML file in place with indentation.

    The file is parsed with lxml (entity resolution disabled, CDATA kept) and
    written back pretty-printed as UTF-8. ``xml_declaration=True`` is passed
    because lxml omits the declaration for UTF-8 output by default, which
    would drop the one the caller wrote.

    :param xmlFilePathToPrettyPrint: path of the XML file to rewrite.
    :raises ValueError: if the path is None (explicit check; ``assert`` is
        stripped under ``python -O``).
    """
    if xmlFilePathToPrettyPrint is None:
        raise ValueError("xmlFilePathToPrettyPrint must not be None")
    parser = etree.XMLParser(resolve_entities=False, strip_cdata=False)
    document = etree.parse(xmlFilePathToPrettyPrint, parser)
    document.write(xmlFilePathToPrettyPrint, pretty_print=True, encoding='utf-8',
                   xml_declaration=True)


# One XML file per matched-and-accepted trial, written to
# <Directory>/TimeSynchronization/<first 26 chars of name>_dT.xml.
analysis_path = Directory + '/TimeSynchronization/'
if not os.path.exists(analysis_path):
    os.makedirs(analysis_path)

ii = 0
for trialName, dTvalue, _trialStd in matched_and_accepted:
    # Root tag taken from the trial name (also reused by the readme below).
    subID = trialName[4:16]
    xmlRoot = ET.Element(subID)
    location = ET.SubElement(xmlRoot, 'Location')
    ET.SubElement(location, "Name").text = trialName[17:26]
    ET.SubElement(location, "dT").text = str(round(dTvalue, 4))

    xmlPath = '%s%s_dT.xml' % (analysis_path, trialName[0:26])
    ET.ElementTree(xmlRoot).write(xmlPath, xml_declaration=True)
    prettyPrintXml(xmlPath)
    ii += 1

# Accepted-and-matched count relative to the hard-coded total of 56 / 2.
diff = 56 // 2 - len(unMatched_but_accepted)
percentage = round(diff / float(56 // 2) * 100, 2)

# Write the matching summary (<SubjectID>readme.txt) and the run numbers of
# the accepted-but-unmatched trials (<SubjectID>_UnMatched_Accepted.txt).
# The `with` block closes (and flushes) both files before the script goes on;
# previously they stayed open until interpreter exit.
with open(Directory + '/' + SubjectID + 'readme.txt', 'wb') as f, \
        open(Directory + '/' + SubjectID + '_UnMatched_Accepted.txt', 'wb') as g:
    line1 = "This file contains a summary of file matching results for"
    line2 = "the Reference Models for Multi-Layer Tissue Structures project."
    f.write("%s\n%s\n" % (line1, line2))
    f.write("\n")

    # NOTE(review): subID is only bound after a match or an XML export above;
    # with no matches at all this line raises NameError.
    f.write("Subject ID: %s\n" % ("MULTIS" + str(subID)))
    f.write("\n")

    f.write("Matched Pairs            : %i \n" % (matches))
    f.write("\n")

    f.write("DICOM Files (total)      : %i\n" % (len(dicomFiles)))
    f.write("DICOM Files (with RWave) : %i\n" % (len(DicomFiles_WithRwave)))
    f.write("DICOM Files (w/o RWave)  : %i\n" % (len(DicomNoRwave)))
    f.write("TDMS Files (total)       : %i\n" % (len(tdmsFiles)))
    f.write("TDMS Files (with Pulse)  : %i\n" % (len(TdmsFiles_WithPulse)))
    f.write("TDMS Files (w/o Pulse)   : %i\n" % (len(TdmsFilesNoPulse)))
    f.write("\n")

    f.write("Matching and Accepted: %i files - %s %%\n" % (diff, percentage))
    f.write("\n")

    f.write("Unmatched and Accepted Trials:\n")
    f.write("%s %s %s\n" % ("Run#", "Conversion", "Binary"))
    for trials in unMatched_but_accepted:
        print(trials)
        number, binary = removeZeros_binary(trials)
        f.write("%s %s %s\n" % (trials, number, binary))

    # Copy-paste templates for manual matching of the unmatched trials.
    f.write("\n")
    f.write('[')
    for trials in unMatched_but_accepted:
        f.write("['%s',''],\n" % (trials[0:3]))
        g.write("%s\n" % (trials[0:3]))
    f.write(']')
    f.write("\n")

    f.write("\n")
    f.write('[')
    f.write("(),\n" * len(unMatched_but_accepted))
    f.write(']')
    f.write("\n")

    f.write("%s %s %s %s %s\n" % ("Run#", "Acceptance", "Iteration", "Synch(ms)", "Std"))
    for row in results:
        f.write("%s %s %s %s %s\n" % (row[0], str(row[1]), row[2], str(row[3]), row[4]))
    f.write("\n")

    f.write("Manually Matched :\n")
    f.write("ExpRun# TimeSynch(ms)\n")

print("{} matched trials".format(matches))
print("")
print("Un-matched DICOM Files : ")

# For every DICOM file with R waves that was not matched: save a figure of
# its pulse strip and rename it with an "NMa_" prefix. Matched DICOM files
# whose (original) names are longer than 66 characters are shortened to their
# first 4 plus last 63 characters.
unmatchedFigDir = Directory + '/FileAssociation/'
if not os.path.exists(unmatchedFigDir):
    # Otherwise missing when nothing was matched (savefig would fail).
    os.makedirs(unmatchedFigDir)

for files in DicomFiles_WithRwave:
    if files not in MatchDICOM:
        reader = sitk.ImageFileReader()
        reader.SetFileName(dicomDirectory + files)
        img = reader.Execute()
        # Same crop of the last frame as used for the matched-pair figures.
        nda = sitk.GetArrayFromImage(img[0:880, 658:702, -1])

        fig1, ax3 = plt.subplots(nrows=1, ncols=1, figsize=(18, 9), dpi=90)
        fig1.tight_layout(pad=4, w_pad=.5, h_pad=2)
        ax3.imshow(nda[:, :, 0], interpolation=None)
        ax3.set_title("DICOM pulse extracted from last image in sequence. The pattern should match.", fontsize=14)
        ax3.set_xlabel(
            "* if pulse not present, error when recording. However, pulse data still collected ",
            fontsize=0)
        ax3.set_xticks([])
        ax3.set_yticks([])
        fig1.savefig(unmatchedFigDir + 'NMa_' + files + '.png')
        plt.close(fig1)  # figures were never closed and piled up in memory

        os.rename(dicomDirectory + files, dicomDirectory + "NMa_" + files[-64:])
        print("   %s" % files)

    elif len(files) > 66:
        newName = files[0:4] + files[-63:]
        try:
            os.rename(dicomDirectory + files, dicomDirectory + newName)
        except OSError:
            print("Not in directory: ")
            print("                  {}".format(files))
            print("                  -> {}".format(newName))

print ""
print "Un-matched TDMS Files : "
for files in TdmsFiles_WithPulse:
    if files not in MatchTDMS:
        os.rename(tdmsDirectory + files, tdmsDirectory + "NMa_" + files[-30:])
        print "  ", files
    else:
        if len(files) > 31:
            os.rename(tdmsDirectory + files, tdmsDirectory + files[-30:])
print""

for files in DicomNoRwave:
    os.rename(dicomDirectory + files, dicomDirectory + "NRW_" + files[-63:])

print ""
print "{} Dicom Files without RWaves".format(len(DicomNoRwave))
print ""
for files in DicomNoRwave:
    print files
print ""
print "{} TDMS Files without Pulse".format(len(TdmsFilesNoPulse))
print ""
for files in TdmsFilesNoPulse:
    print files

print "UnMatched but Accepted: ", len(unMatched_but_accepted)
print Total_Accepted

# Final check: re-read both directories and, for every trial the acceptance
# XML marks as accepted, count the DICOM and TDMS files whose names start
# with the same three characters as the trial.
print("")
print("")
print("Final Verification")

dicomFiles = sorted(os.listdir(dicomDirectory))
tdmsFiles = sorted(os.listdir(tdmsDirectory))

TDMS_name_list, Total_Accepted = getAcceptance(Directory)
acceptedTrials = [trial for trial, accepted in TDMS_name_list if accepted == 1]

# One entry per (accepted trial, file sharing its 3-character prefix).
dicomList = [trial for trial in acceptedTrials
             for fname in dicomFiles if fname[0:3] == trial[0:3]]
tdmsList = [trial for trial in acceptedTrials
            for fname in tdmsFiles if fname[0:3] == trial[0:3]]

# Number of (DICOM entry, TDMS entry) pairs referring to the same trial.
cnt = sum(tdmsList.count(trial) for trial in dicomList)

print("Accepted (XML):  " + str(Total_Accepted))
print("Dicom:  " + str(len(dicomList)))
print("Tdms:  " + str(len(tdmsList)))
print("Matchted TDMS with DICOM (checking directory): " + str(cnt))

#
# dicomFiles = sorted(os.listdir(dicomDirectory))
# tdmsFiles = sorted(os.listdir(tdmsDirectory))
#
# TDMS_name_list, Total_Accepted = getAcceptance(Directory)
# tdmsList = []
# dicomList = []
#
# for trials, acceptance in TDMS_name_list:
#     if acceptance == 1:
#         for dicom_files in dicomFiles:
#             if dicom_files[0:3] == trials[0:3]:
#                 dicomList.append(trials)
#         for tdms_files in tdmsFiles:
#             if tdms_files[0:3] == trials[0:3]:
#                 tdmsList.append(trials)
# cnt = 0
#
# for trials in dicomList:
#     for tdmsTrials in tdmsList:
#         if trials == tdmsTrials:
#             cnt += 1
# print "Accepted (XML): ", Total_Accepted
# print "Dicom: ", len(dicomList)
# print "Tdms: ", len(tdmsList)
# print "Matchted TDMS with DICOM (checking directory):", cnt
#
# Hand over to the data-quality step only when the verified pair count equals
# the hard-coded total of 56 / 2; otherwise stop here.
if cnt == 56/2:
    print('running dataQuality')
    import dataQuality
else:
    print('Missing Matches')
    # `exit()` is the interactive helper installed by `site` and is not meant
    # for programs; raising SystemExit directly has the same effect (exit
    # status 0) and also works under `python -S`.
    raise SystemExit

