import dicom
import numpy as np
import matplotlib.pyplot as plt
import peakutils
import csv
from nptdms import TdmsFile
import SimpleITK as sitk
import os
import fileName

# Resolve the working directories for subject index 0.
Directory, SubjectID, tdmsDirectory, dicomDirectory = fileName.setDirectory(0)

# Output folder for the per-pass figures. The figure-saving code writes into
# this exact folder name (spelling included), so do not "fix" it here alone.
pngOutputDir = Directory + '/RwaveSelctionPNG'
if not os.path.exists(pngOutputDir):
    os.makedirs(pngOutputDir)

# The last element returned by setFileNames(...) is the list of trial files.
# Its length sets how many trials the main loop processes and the width of
# the progress ruler (it is not used for printing only).
fileNames = fileName.setFileNames(0)[-1]
filesLength = len(fileNames)
def inprogress():
    """Emit one progress mark (". ") on stdout, without a newline.

    Writes directly and flushes so the mark appears immediately and the
    output is identical on Python 2 and 3 (``print ("."),`` is a softspace
    print on 2, but prints "." followed by a newline on 3).
    """
    import sys  # local import keeps this helper self-contained

    sys.stdout.write(". ")
    sys.stdout.flush()

# Progress ruler: one ". " per processing pass (4 passes per trial file in
# the main loop below). Single-argument print(...) works on Python 2 and 3.
print(". " * filesLength * 4)

# ---------------------------------------------------------------------------
# Per-pass configuration. Each trial file is processed NUM_PASSES times.
# Pass k drops the first RWAVE_START[k] R-wave times and keeps
# PEAK_SLICES[k] of the detected pulse peaks. Only the first NUM_PASSES
# entries of each table are used.
# ---------------------------------------------------------------------------
NUM_PASSES = 4
RWAVE_START = [0, 0, 0, 0, 0, 0, 1, 1]
PEAK_SLICES = [slice(None), slice(1, -1), slice(2, -1), slice(3, -1),
               slice(None, -2), slice(4, None), slice(None, -2), slice(2, -1)]
PULSE_WIDTHS_CSV = '../PulseWidths300.csv'


def findPeaks(pulse):
    """Return the sample indices of peaks in the pulse trace ``pulse``.

    NOTE(review): by default ``peakutils.indexes`` treats ``thres`` as a
    normalized threshold in [0, 1] (relative to the trace's min/max), so
    ``0.5 * max(pulse)`` only means "half way" when max(pulse) is about 1.
    Kept as-is; confirm against the recorded data.
    """
    return peakutils.indexes(pulse, thres=0.5 * max(pulse), min_dist=100)


def shiftPeaks(peaks):
    """Return ``peaks`` shifted so that its first entry sits at 0."""
    return peaks - peaks[0]


def findDeltaT(rwave, Peaks):
    """Return ``(mean, differences)`` of rwave[i] - Peaks[i] over all peaks.

    Raises IndexError when ``rwave`` has fewer entries than ``Peaks``.
    Extra R-wave entries are ignored.
    """
    Delta_T1 = [rwave[i] - peak for i, peak in enumerate(Peaks)]
    return sum(Delta_T1) / len(Peaks), Delta_T1


def getBits(filename):
    """Return the first column of CSV file ``filename`` as a list of ints."""
    with open(filename) as f:
        return [int(row[0]) for row in csv.reader(f)]


def createBinary(Bits, peaks_adjusted_ALLpeaks):
    """Decode a bit string from which bit times coincide with a pulse peak.

    A bit is '1' only if its time exactly equals one of the peak indices
    (no tolerance). Characters come out in reverse order of ``Bits`` (the
    last bit time becomes Binary[0]).

    Returns ``(Binary, ExpRunNum_bin, SubjID_bin, ExpRunNum_dec, SubjID_dec)``:
    ExpRunNum is Binary[0:8], SubjID is Binary[9:18] (Binary[8] is not
    decoded), and the ``_dec`` values are decimal strings. Raises ValueError
    if either field slice is empty (fewer than 10 bit times).
    """
    flags = ['1' if b in peaks_adjusted_ALLpeaks else '0' for b in Bits]
    Binary = ''.join(reversed(flags))
    ExpRunNum_bin = Binary[0:8]
    SubjID_bin = Binary[9:18]
    return (Binary, ExpRunNum_bin, SubjID_bin,
            str(int(ExpRunNum_bin, 2)), str(int(SubjID_bin, 2)))


def make3digits(value):
    """Left-pad the decimal string ``value`` with zeros to 3 characters."""
    return str(value).zfill(3)


def createXvals_msec(pulse, firstPeak):
    """Return each sample index of ``pulse`` minus ``firstPeak``.

    The selected first peak therefore plots at x = 0.
    """
    # NOTE(review): the axis is labelled in ms, i.e. 1 sample per ms is
    # assumed -- confirm the TDMS sampling rate.
    return np.arange(len(pulse)) - firstPeak


# The bit-time table is the same for every trial: read it once.
baseBits = np.array(getBits(PULSE_WIDTHS_CSV))
plot = True

for i in range(filesLength):
    # The third element (file list) is not needed here; filesLength was
    # fixed before the loop, so it is deliberately not rebound.
    dicomFiles, tdmsFiles, _ = fileName.setFileNames(i)

    # Build per-trial paths in new names and leave the directory prefixes
    # untouched (reassigning them would make later passes open
    # "<dir><file><file>...").
    tdmsPath = tdmsDirectory + tdmsFiles
    dicomPath = dicomDirectory + dicomFiles

    # Per-trial inputs are loaded once and shared by every pass.
    dataset = dicom.read_file(dicomPath)
    RWave1 = np.array(dataset.RWaveTimeVector)

    tdmsFile = TdmsFile(tdmsPath)
    group = tdmsFile.groups()[3]
    pulse = tdmsFile.group_channels(group)[0].data  # first channel of group index 3
    allPeaks = findPeaks(pulse)

    # Image strip shown under the pulse plot on pass 0. SimpleITK slices as
    # [x, y, z]: keep x 0:880, y 658:702 of the last slice.
    reader = sitk.ImageFileReader()
    reader.SetFileName(dicomPath)
    zSlice = -1
    nda = sitk.GetArrayFromImage(reader.Execute()[0:880, 658:702, zSlice])

    for iteration in range(NUM_PASSES):
        inprogress()

        RWave = RWave1[RWAVE_START[iteration]:]
        RWave = RWave - RWave[0]  # first retained R-wave at t = 0
        peaks = allPeaks[PEAK_SLICES[iteration]]
        shiftedPeaks = shiftPeaks(peaks)  # first retained peak at t = 0

        try:
            # Alignment statistic, currently only kept for inspection.
            delta_T_opt, dt_vector = findDeltaT(RWave, shiftedPeaks)
        except IndexError:
            # Fewer R-waves than retained peaks: no statistic for this pass.
            delta_T_opt, dt_vector = None, None

        # Shift bit times by the offset between the first detected peak and
        # the first peak kept on this pass.
        Bits = baseBits - (peaks[0] - allPeaks[0])

        # Unpack by name so subject ID and run number are not mixed up.
        (binary, expRunNum_bin, subjID_bin,
         ExpRunNum_dec, SubjID_dec) = createBinary(Bits, shiftedPeaks)
        subID = make3digits(SubjID_dec)
        expRun = make3digits(ExpRunNum_dec)

        if not plot:
            continue

        fig = plt.figure(figsize=(18, 9), dpi=90)
        ax1 = plt.subplot2grid((7, 1), (0, 0), rowspan=5)  # pulse + markers
        ax2 = plt.subplot2grid((7, 1), (5, 0), rowspan=2)  # DICOM strip
        fig.tight_layout(pad=3, w_pad=.5, h_pad=2)

        if iteration == 0:
            # Shows channel 0 only; assumes a multi-component (e.g. RGB)
            # pixel type -- TODO confirm.
            ax2.imshow(nda[:, :, 0], interpolation=None, cmap='gray')
            ax2.set_title("DICOM pulse extracted from last image in sequence", fontsize=14)
            ax2.set_xlabel("* if pulse not present, error when recording. However, pulse data still collected ", fontsize=10)

        ax1.plot(createXvals_msec(pulse, peaks[0]), pulse, color='black', linewidth=2, label='$Pulse$')

        for bit in Bits:
            ax1.axvline(x=bit, linewidth=.5, color='gray', linestyle='-')

        # Shifted R-wave times; only the first line gets a legend entry.
        for k, wave in enumerate(RWave):
            if k == 0:
                ax1.axvline(x=wave, linewidth=2, color='green', linestyle='--', label='$R Wave Adjusted$')
            else:
                ax1.axvline(x=wave, linewidth=2, color='green', linestyle='--')

        ax1.set_xlabel("time(ms)")
        ax1.set_ylabel("Volts")
        ax1.set_title(tdmsFiles)
        ax1.legend(loc='best')

        # tdmsFiles[0:-5] drops a 5-character extension (presumably ".tdms").
        plt.savefig(Directory + '/RwaveSelctionPNG/' + tdmsFiles[0:-5] + '_' + str(iteration) + '.png')
        plt.close(fig)