import dicom
import numpy as np
import matplotlib.pyplot as plt
import peakutils
import csv
import os
import time
from nptdms import TdmsFile
import SimpleITK as sitk
start = time.time()
from XMLparser import getAcceptance
import xml.etree.ElementTree as ET
import lxml.etree as etree
import warnings
import fileName
warnings.simplefilter('ignore', RuntimeWarning)
"""

This script was written to associate files for the MULTIS in vivo and in vitro experimentation.
Force data and ultrasound images are collected simultaneously during indentation. In order to
synchronize the force data from the .tdms file (from LABview) and the ultrasound images from
the dicom file, an encoded signal in the form of pulses is sent from the data collection program
(LABview) to the ECG analog input of the ultrasound. On the ultrasound monitor, the pulse can be
seen in place of a typical ECG. The ultrasound software typically uses the ECG data to monitor heart
rate. It does so by calculating the time intervals between each beat, specifically, the time between
each R-Wave peak. By sending a simulated and simplified ECG signal to the ultrasound, we can extract
each of the R-Wave time intervals from the DICOM metadata.

Next, the peak of each pulse in the signal is found. The array of peaks should correspond to the
RWave time vector from the ultrasound. The signal of pulses sent to the ultrasound consists of
    2 start pulses + 10 bit Experimental Run Number + 8 Bit Subject ID + 2 end pulses.

The start and end pulses act as primers so that the ultrasound always registers the 10-bit and
8-bit fields.

The 10-bit and 8-bit fields are binary encoded: a pulse represents a 1, and the absence of a pulse
represents a 0.

A TDMS file is associated with a DICOM file by comparing the R-wave time vector to the time vector of
the pulses (peak to peak). Because of the limited sampling frequency of the ultrasound, its detection
of each pulse peak can be off by up to +/- 95 ms from the actual pulse. Sufficient spacing between
pulses allows for this margin of error. When comparing signals, the maximum distance between an R-wave
and its pulse peak must not exceed this maximum absolute error (the matching code below currently
uses a 120 ms threshold).

    Original Author:
        Tyler Schimmoeller
        Department of Biomedical Engineering
        Lerner Research Institute
        Cleveland Clinic
        Cleveland, OH
        schimmt@ccf.org


    Place the DICOM and TDMS files in the directories returned by fileName.setDirectory(); matched
    DICOM files are renamed with the experimental run number as a prefix.

    Part 1 - Reading files, only extracting tdms files with pulse and dicom files with rwave
"""
Plot = True

# Directory = "../TEST/MULTIS001-1"  # change Directory for each subject
# tdmsDirectory = Directory + "/Data/"

# SubjectID = "MULTIS015-1" # change Directory for each subject
# Directory = "../MULTIS_trials/" + SubjectID
# tdmsDirectory = Directory + "/Data/"
SubjectID = fileName.setDirectory()[1]
Directory = fileName.setDirectory()[0]
tdmsDirectory = fileName.setDirectory()[2]
dicomDirectory = fileName.setDirectory()[3]


print ""  # for clarity in output

dicomFiles = sorted(os.listdir(dicomDirectory))  # This turns a folder of DICOM files in to a list called DicomFiles
if len(dicomFiles) < 3:
    print "No DICOM files present"
    exit()

lengthOfLongFile = len(dicomFiles[16]) # for renaming purposes

for files in dicomFiles: # change single digit trial numbers from '1' to '01' for alphabetical order purposes
    # print lengthOfLongFile, len(files)
    if len(files) < lengthOfLongFile:
        newFilename = files[0:-41] + '0' + files[-41:]
        os.rename(dicomDirectory + files, dicomDirectory + newFilename)

# sorted list of files in directory for purpose of checking file extensions in for loop
dicomFiles = sorted(os.listdir(dicomDirectory))
tdmsFiles = sorted(os.listdir(tdmsDirectory))

part1 = time.time()

# printing a progress line
def inprogress():
    print ("."),

print "{} TDMS Files".format(len(tdmsFiles))
print "{} DICOM Files".format(len(dicomFiles))

# beginning completion countdown
count = len(tdmsFiles)
print ". " * count

pulseList = []
TdmsFilesNoPulse = []
TdmsFiles_WithPulse = []
RWaveList = []
DicomNoRwave = []
DicomFiles_WithRwave = []
Std = []

#   check and store directory for dicom files ending in IMA only
for i in dicomFiles:
    if i.endswith('.IMA'):
        try:
            dataset = dicom.read_file(dicomDirectory + i, stop_before_pixels=True)  # reading all DICOM files
            RwaveRead = dataset.RWaveTimeVector
            try:
                if len(RwaveRead) > 0:
                    RWaveList.append(RwaveRead)  # creating list of rwaves
                    DicomFiles_WithRwave.append(i)
                else:
                    # print "File {} does not have a pulse train".format(i)
                    DicomNoRwave.append(i)
            except TypeError:
                DicomNoRwave.append(i)
                # print "File {} does not have a pulse train".format(i)
        except AttributeError:
            DicomNoRwave.append(i)
            # print "File {} does not have a pulse train" .format(i)

#   check and store directory for tdms files ending in tdms only
for i in tdmsFiles:
    if i.endswith('.tdms'):
        inprogress()
        tdmsFileName = tdmsDirectory + i
        # tdmsData = tdsmParserMultisCopy.parseTDMSfile(TdmsDirectory + filename)   # reading all TDMS data
        tdmsFile = TdmsFile(tdmsFileName)
        try:  # check to see if pulse exists, if not, exclue TDMS file from directory

            # pulse1 = tdmsData[u'Sensor.Run Number Pulse Train'][u'Run Number Pulse Train'] # the old way of parsing
            group = tdmsFile.groups()[3]
            pulse1 = tdmsFile.group_channels(group)[0].data
            pulseList.append(pulse1)
            TdmsFiles_WithPulse.append(i)

        except KeyError:
            TdmsFilesNoPulse.append(i)  # must create new list because some files have no Pulse Train
            # print " File '" + i + "' does not have a pulse train"

print ""
# print "{} Dicom Files without RWaves".format(len(DicomNoRwave))

for files in DicomNoRwave:
    print files

print ""
print "{} TDMS Files without Pulse".format(len(TdmsFilesNoPulse))

for files in TdmsFilesNoPulse:
    print files

part2 = time.time()
elapsePart2 = part2 - part1

print ""
print "Time Part 2 : {}".format(elapsePart2)

"""
    Part 2 - iterating through files, matching associated files
"""

Tdsm_and_pulses = zip(TdmsFiles_WithPulse, pulseList)
Dicom_and_rwaves = zip(DicomFiles_WithRwave, RWaveList)

count = len(DicomFiles_WithRwave)

def removeZeros_binary(num):
    """Decode the run-number prefix (first three characters) of *num*.

    Returns ``(number, Binary)``: ``number`` is the prefix with leading zeros removed
    (a string, e.g. '005' -> '5', '000' -> '0', '123' -> '123') and ``Binary`` is its
    base-2 representation without the '0b' prefix. If the prefix is not purely digits,
    ``(0, 0)`` is returned, which is the original function's fallback value.
    """
    prefix = num[0:3]
    if not prefix.isdigit():
        return 0, 0
    value = int(prefix)
    # Previously prefixes without a leading zero (e.g. '123') returned (0, 0)
    # and '00' raised IndexError.
    return str(value), bin(value)[2:]

MatchTDMS = []      # list of matched tdms files
MatchDICOM = []     # list of matched dicom files
expRunList = []     # list of trial numbers which have matches
timeSynchList = []  # list of delta T's for each matched trial run
matches = 0

#   Lists of running totals of the iteration in which each trial was matched. Ideally, all would match on iter 0
iter_0 = []
iter_1 = []
iter_2 = []
iter_3 = []
iter_4 = []
iter_5 = []
iter_6 = []
iterList = []

for iteration in range(7):
    for dicom_files, rwaves in Dicom_and_rwaves:
        if dicom_files not in MatchDICOM:   # once a dicom has found a match, it is no longer checked
            for tdms_files, pulses in Tdsm_and_pulses:
                if tdms_files not in MatchTDMS: # once a tdms file has found a match, it is no longer checked
                    # dcm_range_check = int(dicom_files[-42:-40])
                    # tdms_range_check = int(tdms_files[-30:-27])
                    # range_check = abs(dcm_range_check - tdms_range_check)
                    #
                    # if range_check > 10:
                    #     print dcm_range_check, dicom_files
                    #     print tdms_range_check, tdms_files
                    #     go = raw_input('presss Enter:')
                    #
                    #     break
                    # else:

                    # iterations cycle through RWave and Peak variations. The ultrasound has a tendency to skip the first,
                    # second, or fist and second, or last RWaves
                    try:
                        RWave1 = np.array(rwaves)[:]
                        RWave2 = RWave1[1:]
                        RWave3 = RWave1[2:]
                        RWave4 = RWave1[1:]
                        RWave = [RWave1, RWave1, RWave1, RWave1, RWave4, RWave2, RWave2]
                        RWave = RWave[iteration]
                    except ValueError:
                        donothing = 0

                    pulse = pulses

                    def findPeaks():
                        Peaks = peakutils.indexes(pulse, thres=0.5 * max(pulse), min_dist=100)
                        return Peaks

                    peaks0 = findPeaks()[:]
                    peaks1 = findPeaks()[1:]
                    peaks2 = findPeaks()[2:]
                    peaks3 = findPeaks()[3:]
                    peaks4 = findPeaks()[:-1]

                    peakLIST = [peaks0, peaks1, peaks2, peaks4, peaks2, peaks0, peaks4]
                    peaks = peakLIST[iteration]

                    # useful when iteration > 1 when peaks taken away
                    peakBits = peaks0 - peaks[0]

                    # Find the avg difference between R Wave's and pulse Peaks
                    def findDeltaT(rwave, Peaks):
                        Delta_T1 = []
                        for i in range(len(Peaks)):
                            dt = rwave[i] - Peaks[i]
                            Delta_T1.append(dt)
                        Delta_T = sum(Delta_T1) / len(Peaks)
                        return Delta_T

                    # Find a vector of differences to later find max
                    def findDeltaT_vector(rwave, Peaks):
                        Delta_T_Vector = []
                        for rw, p in zip(rwave, Peaks):
                            Delta_T_Vector.append(rw - p)
                        return Delta_T_Vector

                    # Ensure number of RWave number of Peaks, else end
                    if len(RWave) != len(peaks):  # must be equal to continue
                        donothing = True  # necessary for an if loop
                        # print "not equal rwave and peaks"
                    else:
                        # find initial difference, use to adjust rwaves for alignment
                        Delta_T_initial = findDeltaT(RWave, peaks)

                        # used for synchronizing files in later modeling
                        if Delta_T_initial > 0:
                            timeSynch = Delta_T_initial
                        else:  # if DAQ started before ultrasound
                            timeSynch = RWave[0] + Delta_T_initial

                        RWave_adjusted = RWave - Delta_T_initial

                        # now that rwaves are optimally aligned, the max diff between any rwave and associated peak from signal
                        # should not exceed 95 else not a match
                        Delta_T_opt = np.array(findDeltaT_vector(RWave_adjusted, peaks))
                        Std.append(round(np.std(Delta_T_opt), 2))
                        max_dt = max(abs(Delta_T_opt))
                        if max_dt > 120:
                            donothing = 0

                        else:
                            print "Files Match"
                            print "Iteration {} ".format(iteration)

                            # import bit widths to decode pulse, each number represents 1 bit
                            def getBits(filename):
                                f = open(filename)
                                PulseWidthsFile = csv.reader(f)
                                Bits = []
                                for row in PulseWidthsFile:
                                    Bits.append(row[0])
                                return map(int, Bits)  # map converts list of strings to list of int

                            # loading bit markers (each at specified location)
                            Bits = getBits('../PulseWidths300.csv')
                            Bits_Original = Bits  # to plot unaltered bits if desired

                            # ensuring bit markers line up with pulse peaks in each iteration
                            Bits = Bits - (peakLIST[iteration][0] - peakLIST[0][0])

                            # Create the x values for pulse plot
                            def createXvals_msec(pulse):
                                xVals_msec = []
                                for index in range(len(pulse)):  # change peaks to peaks0 to plot original peak set
                                    xVals_msec.append(range(len(pulse))[index] - len(pulse[0:peaks0[0]]))
                                return xVals_msec

                            def createBinary(Bits, peakBits):
                                # may need to add margin somehow. right now peaks and bits have to be exactly equal
                                # which works for now

                                # notice always adding new number to beginning, hence in reverse
                                # because binary signal was created in reverse
                                Binary = ""
                                for b in Bits:
                                    if b in peakBits:  # if a peak is located at bit marker -> binary = 1, else 0
                                        Binary = '1' + Binary
                                    else:
                                        Binary = '0' + Binary

                                SubjID_bin = Binary[0:8]
                                ExpRunNum_bin = Binary[9:18]
                                SubjID_dec = str(int(SubjID_bin, 2))
                                ExpRunNum_dec = str(int(ExpRunNum_bin, 2))
                                return Binary, ExpRunNum_bin, SubjID_bin, ExpRunNum_dec, SubjID_dec

                            SubjID_dec = createBinary(Bits, peakBits)[-1]

                            def make3digits(input):
                                if len(input) == 1:
                                    output = '00' + str(input)
                                elif len(input) == 2:
                                    output = '0' + str(input)
                                else:
                                    output = str(input)
                                return output

                            ExpRunNum_dec = createBinary(Bits, peakBits)[-2]
                            subID = make3digits(SubjID_dec)
                            expRun = make3digits(ExpRunNum_dec)

                            if iteration is 0:
                                iter_0.append(expRun)
                            elif iteration is 1:
                                iter_1.append(expRun)
                            elif iteration is 2:
                                iter_2.append(expRun)
                            elif iteration is 3:
                                iter_3.append(expRun)
                            elif iteration is 4:
                                iter_4.append(expRun)
                            elif iteration is 5:
                                iter_5.append(expRun)
                            elif iteration is 6:
                                iter_6.append(expRun)

                            iterList.append(iteration)
                            timeSynchList.append(timeSynch)
                            expRunList.append(expRun)
                            matches += 1
                            if Plot:
                                # reading the dicom pixel data
                                reader = sitk.ImageFileReader()
                                reader.SetFileName(dicomDirectory + dicom_files)
                                img = reader.Execute()

                                img = img[0:880, 658:702, -1]   # indexed to crop image, and use last image in seq.

                                nda = sitk.GetArrayFromImage(img)

                                # optional filter to create black and white figure
                                # filtered = []
                                # for rows, columns in enumerate(nda):
                                #     rowList = []
                                #     for cols in range(len(columns)):
                                #         if any(nda[rows][cols] > 160):
                                #             rowList.append([0, 0, 0])
                                #         else:
                                #             rowList.append([255, 0, 0])
                                #     filtered.append(rowList)

                                fig, (ax1, ax2) = plt.subplots(nrows=2, ncols=1, figsize=(18, 9), dpi=90)

                                ax1 = plt.subplot2grid((7, 1), (0, 0), rowspan=5)
                                ax2 = plt.subplot2grid((7, 1), (5, 0), rowspan=2)

                                # image data found only in R channel of RGB. hence the '0' in the nda index indicating
                                # use of R channel only
                                ax2.imshow(nda[:, :, 0], interpolation=None, cmap='gray')
                                fig.tight_layout(pad=3, w_pad=.5, h_pad=2)

                                # plotting pulse

                                ax1.plot(pulse, color='black', linewidth=2.5, label='$Pulse$')
                                ax1.set_xlim(0, 8000)
                                ax1.set_ylim(-0.1, 0.6)

                                ax1.annotate('', xy=(RWave[0], 0.52), xytext=(RWave_adjusted[0], 0.52),
                                             arrowprops=dict(arrowstyle="<|-|>", color="blue", lw=2, alpha=.6))
                                Xcord = RWave[0] + (RWave_adjusted[0] - RWave[0]) / 2

                                ax1.annotate('dTavg: ' + str(round(Delta_T_initial, 2)) + " It: " + str(iteration),
                                             xy=(Xcord, 0.52),
                                             xycoords='data',
                                             xytext=(-120, 60), textcoords='offset points',
                                             size=12,
                                             bbox=dict(boxstyle="round,pad=.2", fc='white', ec='blue', lw=2, alpha=.6),
                                             arrowprops=dict(arrowstyle="-", color='blue', lw=2, alpha=.6,
                                                             connectionstyle="angle,angleA=0,angleB=90,rad=5"))

                                # plotting adjusted R Wave
                                # RWave_adjusted = RWave1 # comment/uncomment to plot synched RWave with  ORIGINAL pulse
                                for wave in RWave_adjusted:

                                    if wave == RWave_adjusted[0]:
                                        ax1.axvline(x=wave, linewidth=2, color='green', linestyle='--',
                                                    label='$R Wave Adjusted$')
                                    else:
                                        ax1.axvline(x=wave, linewidth=2, color='green', linestyle='--')

                                # plotting UN adjusted R Wave
                                for wave in RWave:

                                    if wave == RWave[0]:
                                        ax1.axvline(x=wave, linewidth=2, color='red', linestyle=':',
                                                    label='$R Wave Un-Adjusted$')
                                    else:
                                        ax1.axvline(x=wave, linewidth=2, color='red', linestyle=':')

                                # plotting pulse widths for encoding/decoding
                                Bits = Bits_Original + peaks0[0]  # comment/uncomment when want to plot original bits
                                for bits in Bits:
                                    if bits == Bits[0]:
                                        ax1.axvline(x=Bits[0], linewidth=2, color='gray', linestyle='-', alpha=0.4)

                                    elif bits == Bits[2]:
                                        ax1.axvline(x=Bits[2], linewidth=2, color='gray', linestyle='-', alpha=0.4)
                                    elif bits == Bits[12]:
                                        ax1.axvline(x=Bits[12], linewidth=2, color='gray', linestyle='-', alpha=0.4)

                                for bits in range(len(Bits)):
                                    ax1.axvline(x=Bits[bits], linewidth=.5, color='gray', linestyle='-')

                                # optionally plotting pulse peaks
                                # for j in range(len(peaks):
                                #     if j == 0:
                                #         ax1.axvline(x=shiftPeaks(peaks)[j], linewidth=1, color='red',linestyle='-', label ='Peaks')
                                #     else:
                                #         ax1.axvline(x=shiftPeaks(peaks)[j], linewidth=1, color='red', linestyle='-')
                                #
                                # # plotting un-adjusted/raw r wave intervalsi], l
                                # for i in range(len(RWave)):
                                #
                                #     if i == 0:
                                #         ax1.axvline(x=RWave[i], linewidth=1, color='orange',linestyle='--', label ='$R Wave$')
                                #     else:
                                #         ax1.axvline(x=RWave[i], linewidth=1, color='orange',linestyle='--')


                                trial = "MULTIS" + str(subID) + " - Trial # " + str(expRun)

                                ax1.set_xlabel("time(ms)")
                                ax1.set_ylabel("Volts")
                                ax1.set_title("TDMS #:" + trial + " - " + "DICOM #: " + dicom_files[-42:-40])
                                ax1.legend(loc='upper right', prop={'size': 10}, borderpad=0.1, handlelength=5)

                                ax2.set_title("DICOM pulse extracted from last image in sequence. The pattern "
                                              "should match pulse above. Will be offset.", fontsize=14)
                                ax2.set_xlabel(
                                    "* if pulse not present, error when recording. However, pulse data still collected ",
                                    fontsize=10)

                                frame1 = plt.gca()
                                frame1.axes.get_xaxis().set_ticks([])
                                frame1.axes.get_yaxis().set_ticks([])

                                plt.draw()
                                if not os.path.exists(Directory + '/FileAssociationPNG'):
                                    os.makedirs(Directory + '/FileAssociationPNG')
                                plt.savefig(Directory + '/FileAssociationPNG/' + tdms_files[:-5] + '.png')

                            else:
                                donothing = 0

                            # Renaming matched files
                            try:
                                os.rename(dicomDirectory + dicom_files,
                                          dicomDirectory + expRun + '_' + dicom_files[-61:])
                                # MatchDICOM.append(expRun + '_' + dicom_files[-61:])
                            except OSError:

                                try:
                                    os.rename(dicomDirectory + dicom_files,
                                              dicomDirectory + expRun + '_' + dicom_files[:])
                                    # MatchDICOM.append(expRun + '_' + dicom_files[:])
                                except OSError:
                                    print "check file length"
                                    print dicom_files
                                    exit()

                            MatchTDMS.append(tdms_files)
                            MatchDICOM.append(dicom_files)
                            print "Time Synchronization : ", timeSynch
                            print "Delta_T_Opt avg SHOULD BE ZERO : ", findDeltaT(RWave_adjusted, peaks)
                            print "Std : {}".format(Std[-1])
                            print "SubjectID = ", subID
                            print "Experiment Run Num = ", expRun
                            print ". " * count
                            count -= 1
                            # plt.show()
                else:
                    continue
        else:
            donothing = 0



iterationList = [iter_0, iter_1, iter_2, iter_3, iter_4, iter_5, iter_6]

# One (run#, iteration, timeSynch, std) tuple per match, sorted by run number
synchList = sorted(zip(expRunList, iterList, timeSynchList, Std))
TDMS_name_list, Total_Accepted = getAcceptance(Directory)

# Walk the acceptance list in order, consuming synchList entries whose run number
# matches the first three characters of the TDMS name; other trials get 'N/A'.
cnt = 0
resultsFileName = []
resultsAcceptance = []
resultsIteration = []
resultsTimeSynch = []
resultsStd = []
for filename, acceptance in TDMS_name_list:
    RunNumber = filename[0:3]
    resultsFileName.append(filename)
    resultsAcceptance.append(acceptance)

    if cnt >= len(synchList):
        # all matched trials already consumed
        resultsIteration.append('N/A')
        resultsTimeSynch.append('N/A')
        resultsStd.append('N/A')
        print "index error"
    elif synchList[cnt][0] != RunNumber:
        resultsIteration.append('N/A')
        resultsTimeSynch.append('N/A')
        resultsStd.append('N/A')
    else:
        resultsIteration.append(synchList[cnt][1])
        resultsTimeSynch.append(round(synchList[cnt][2], 4))
        resultsStd.append(round(synchList[cnt][3], 4))
        cnt += 1

# A list (not an iterator) so it can be iterated and indexed more than once
results = list(zip(resultsFileName, resultsAcceptance, resultsIteration, resultsTimeSynch, resultsStd))

# Split the accepted trials into unmatched names and matched [name, timeSynch, std] rows.
# Loop names avoid shadowing the `time` module, the `iter` builtin and the `Std` list.
unMatched_but_accepted = []
matched_and_accepted = []
for TDMS_name, acc, iterNum, timeSynchMs, stdDev in results:
    if acc != 1:
        continue
    if iterNum == 'N/A':
        unMatched_but_accepted.append(TDMS_name)
    else:
        matched_and_accepted.append([TDMS_name, timeSynchMs, stdDev])

def prettyPrintXml(xmlFilePathToPrettyPrint):
    """Re-indent an XML file in place.

    The file is parsed with lxml (entity resolution disabled, CDATA preserved) and
    written back to the same path with pretty printing and UTF-8 encoding.

    Raises:
        ValueError: if ``xmlFilePathToPrettyPrint`` is None.
    """
    # explicit check instead of assert, which is stripped when Python runs with -O
    if xmlFilePathToPrettyPrint is None:
        raise ValueError("xmlFilePathToPrettyPrint must not be None")
    parser = etree.XMLParser(resolve_entities=False, strip_cdata=False)
    document = etree.parse(xmlFilePathToPrettyPrint, parser)
    document.write(xmlFilePathToPrettyPrint, pretty_print=True, encoding='utf-8')

# Folder that receives one time-synchronization XML file per matched, accepted trial
analysis_path = Directory + '/TimeSynchronization/'
if not os.path.exists(analysis_path):
    os.makedirs(analysis_path)

ii = 0
for ii, (TDMS_name, time_adj, Std) in enumerate(matched_and_accepted, 1):
    # Slices of the TDMS file name: [4:15] -> root tag, [16:25] -> location name,
    # [0:25] -> XML file stem
    subID = TDMS_name[4:15]
    tail = TDMS_name[16:25]
    xml_name = analysis_path + TDMS_name[0:25] + '_dT.xml'

    root = ET.Element(subID)
    loc = ET.SubElement(root, 'Location')
    ET.SubElement(loc, "Name").text = tail
    ET.SubElement(loc, "dT").text = str(round(time_adj, 4))

    tree = ET.ElementTree(root)
    tree.write(xml_name, xml_declaration=True)

    # rewrite the freshly written file with indentation
    prettyPrintXml(xml_name)
# 56 is presumably the number of trials expected per subject; "Matching and Accepted"
# is reported as that count minus the accepted trials that found no match.
# NOTE(review): Total_Accepted (from getAcceptance) may be the intended base; confirm.
expectedTrials = 56
diff = expectedTrials - len(unMatched_but_accepted)
percentage = round((diff / float(expectedTrials)) * 100, 2)

line1 = "This file contains a summary of file matching results for"
line2 = "the Reference Models for Multi-Layer Tissue Structures project."

# 'with' guarantees the summary is flushed and closed (previously it was never closed)
with open(Directory + '/' + SubjectID + 'readme.txt', 'wb') as f:
    f.write("%s\n%s\n" % (line1, line2))
    f.write("\n")

    f.write("Subject ID: %s\n" % ("MULTIS" + str(subID)))
    f.write("\n")

    f.write("Matched Pairs            : %i \n" % (matches))
    f.write("\n")

    f.write("DICOM Files (total)      : %i\n" % (len(dicomFiles)))
    f.write("DICOM Files (with RWave) : %i\n" % (len(DicomFiles_WithRwave)))
    f.write("DICOM Files (w/o RWave)  : %i\n" % (len(DicomNoRwave)))
    f.write("TDMS Files (total)       : %i\n" % (len(tdmsFiles)))
    f.write("TDMS Files (with Pulse)  : %i\n" % (len(TdmsFiles_WithPulse)))
    f.write("TDMS Files (w/o Pulse)   : %i\n" % (len(TdmsFilesNoPulse)))
    f.write("\n")

    # number of matches made in each of the 7 iterations
    f.write("Iteration Summary:\n")
    f.write("%s : %s\n" % ("#", "Count"))
    for index, iterMatches in enumerate(iterationList):
        f.write("%i : %i\n" % (index, len(iterMatches)))
    f.write("\n")

    f.write("Matching and Accepted: %i files - %s %%\n" % (diff, percentage))
    f.write("\n")

    f.write("Unmatched and Accepted Trials:\n")
    f.write("%s %s %s\n" % ("Run#", "Conversion", "Binary"))
    for trials in unMatched_but_accepted:
        number, binary = removeZeros_binary(trials)
        f.write("%s %s %s\n" % (trials, number, binary))

    # placeholder list literals, one entry per unmatched accepted trial
    f.write("\n")
    f.write('[')
    for trials in unMatched_but_accepted:
        f.write("['%s',''],\n" % (trials[0:3]))
    f.write(']')
    f.write("\n")

    f.write("\n")
    f.write('[')
    for trials in unMatched_but_accepted:
        f.write("(),\n")
    f.write(']')
    f.write("\n")

    # one row per TDMS name from the acceptance list
    f.write("%s %s %s %s %s\n" % ("Run#", "Acceptance", "Iteration", "Synch(ms)", "Std"))
    for row in results:
        f.write("%s %s %s %s %s\n" % (row[0], str(row[1]), row[2], str(row[3]), row[4]))
    f.write("\n")

    f.write("Manually Matched :\n")
    f.write("ExpRun# TimeSynch(ms)\n")

print "{} matched trials".format(matches)
print ""
print "Un-matched DICOM Files : "
g = open(Directory + '/' + SubjectID + 'UnMatched.txt', 'wb')
# creating figures for unmatched dicoms
for files in DicomFiles_WithRwave:
    if files not in MatchDICOM:
        g.write("%s\n" % (files))
        reader = sitk.ImageFileReader()
        reader.SetFileName(dicomDirectory + files)

        img = reader.Execute()
        img = img[0:880, 658:702, -1]

        nda = sitk.GetArrayFromImage(img)
        # filtered = []
        # for rows, columns in enumerate(nda):
        #     rowList = []
        #     for cols in range(len(columns)):
        #         if any(nda[rows][cols] > 160):
        #             rowList.append([0, 0, 0])
        #         else:
        #             rowList.append([255, 0, 0])
        #     filtered.append(rowList)

        fig1, ax3 = plt.subplots(nrows=1, ncols=1, figsize=(18, 9), dpi=90)
        fig1.tight_layout(pad=4, w_pad=.5, h_pad=2)

        ax3.imshow(nda[:, :, 0], interpolation=None)

        ax3.set_title("DICOM pulse extracted from last image in sequence. The pattern should match.", fontsize=14)
        ax3.set_xlabel(
            "* if pulse not present, error when recording. However, pulse data still collected ",
            fontsize=10)
        frame1 = plt.gca()
        frame1.axes.get_xaxis().set_ticks([])
        frame1.axes.get_yaxis().set_ticks([])
        plt.savefig(Directory + '/FileAssociationPNG/NMa_' + files + '.png')
        os.rename(dicomDirectory + files, dicomDirectory + "NMa_" + files[-61:])
        print "  ", files

    else:
        if len(files) > 65:
            try:
                os.rename(dicomDirectory + files, dicomDirectory + (files[0:4] + files[-61:]))
            except OSError:
                print "Not in directory: "
                print "                  {}".format(files)

print ""
print "Un-matched TDMS Files : "
for files in TdmsFiles_WithPulse:
    if files not in MatchTDMS:
        os.rename(tdmsDirectory + files, tdmsDirectory + "NMa_" + files[-30:])
        print "  ", files
    else:
        if len(files) > 30:
            os.rename(tdmsDirectory + files, tdmsDirectory + files[-30:])
print""

for files in DicomNoRwave:
    os.rename(dicomDirectory + files, dicomDirectory + "NRW_" + files[-61:])

print ""
print "{} Dicom Files without RWaves".format(len(DicomNoRwave))
print ""
for files in DicomNoRwave:
    print files
print ""
print "{} TDMS Files without Pulse".format(len(TdmsFilesNoPulse))
print ""
for files in TdmsFilesNoPulse:
    print files

print "UnMatched but Accepted: ", len(unMatched_but_accepted)
print Total_Accepted

# A final check
print ""
print ""
print "Final Verification"

dicomFiles = sorted(os.listdir(dicomDirectory))
tdmsFiles = sorted(os.listdir(tdmsDirectory))

TDMS_name_list, Total_Accepted = getAcceptance(Directory)
tdmsList = []
dicomList = []

for trials, acceptance in TDMS_name_list:
    if acceptance == 1:
        for dicom_files in dicomFiles:
            if dicom_files[0:3] == trials[0:3]:
                dicomList.append(trials)
        for tdms_files in tdmsFiles:
            if tdms_files[0:3] == trials[0:3]:
                tdmsList.append(trials)
cnt = 0
for trials in dicomList:
    for tdmsTrials in tdmsList:
        if trials == tdmsTrials:
            cnt += 1
print "Accepted (XML): ", Total_Accepted
print "Dicom: ", len(dicomList)
print "Tdms: ", len(tdmsList)
print "Matchted TDMS with DICOM (checking directory):", cnt