
import sys
import matplotlib
matplotlib.use("Qt5Agg")
import matplotlib.pyplot as plt
import matplotlib.patches as patches
from matplotlib.lines import Line2D
from PyQt5 import QtCore, QtWidgets
from matplotlib.backends.backend_qt4agg import FigureCanvasQTAgg as FigureCanvas
from matplotlib.backends.backend_qt4agg import NavigationToolbar2QT as NavigationToolbar
from matplotlib.figure import Figure
import numpy as np
import cv2
from cv2 import imread
# from qrangeslider import QRangeSlider
import os
# from skimage import color

class DraggablePoint:
    """Red circular marker that the user can drag with the mouse.

    Adapted from
    http://stackoverflow.com/questions/21654008/matplotlib-drag-overlapping-points-interactively

    Each point is added to ``parent.fig.axes[0]``. Every point created after
    the first one also owns a ``Line2D`` running from ``parent.list_points[0]``
    to itself; the drag code keeps the line owned by ``parent.list_points[1]``
    in sync, so the first two markers plus that line form the calibration
    baseline. Dragging uses blitting: the static background is cached on
    press, and on each motion only the marker and the baseline are redrawn.
    """

    lock = None  # the point currently being dragged; only one can be animated at a time

    def __init__(self, parent, x, y, size=0.1):
        """Create the marker at (x, y) and connect it to the canvas mouse events.

        Args:
            parent: object exposing ``fig`` (a matplotlib Figure whose first
                axes receives the marker) and ``list_points`` (the points
                created so far).
            x, y: initial centre of the marker, in data coordinates.
            size: width and height of the marker ellipse, in data units.
        """
        self.parent = parent
        self.point = patches.Ellipse((x, y), size, size, fc='r', alpha=0.5, edgecolor='r')
        self.x = x
        self.y = y
        parent.fig.axes[0].add_patch(self.point)
        self.press = None
        self.background = None
        self.connect()

        # Points after the first get a line from the first point to themselves.
        if self.parent.list_points:
            line_x = [self.parent.list_points[0].x, self.x]
            line_y = [self.parent.list_points[0].y, self.y]

            self.line = Line2D(line_x, line_y, color='r', alpha=0.5)
            parent.fig.axes[0].add_line(self.line)

    def _baseline(self):
        """Return the line joining the first two points (owned by ``list_points[1]``)."""
        return self.parent.list_points[1].line

    def connect(self):
        """Connect the press/release/motion mouse events to this point's handlers."""
        self.cidpress = self.point.figure.canvas.mpl_connect('button_press_event', self.on_press)
        self.cidrelease = self.point.figure.canvas.mpl_connect('button_release_event', self.on_release)
        self.cidmotion = self.point.figure.canvas.mpl_connect('motion_notify_event', self.on_motion)

    def on_press(self, event):
        """Start a drag if the click hit this marker and no other point is being dragged."""
        if event.inaxes != self.point.axes:
            return
        if DraggablePoint.lock is not None:
            return
        contains, _ = self.point.contains(event)
        if not contains:
            return
        self.press = self.point.center, event.xdata, event.ydata
        DraggablePoint.lock = self

        canvas = self.point.figure.canvas
        axes = self.point.axes
        line = self._baseline()

        # Animated artists are skipped by canvas.draw(), so the cached
        # background holds everything except the marker and the baseline.
        self.point.set_animated(True)
        line.set_animated(True)
        canvas.draw()
        self.background = canvas.copy_from_bbox(axes.bbox)

        # Redraw both animated artists (not just the marker, otherwise the
        # baseline vanishes until the first motion event) and blit.
        axes.draw_artist(line)
        axes.draw_artist(self.point)
        canvas.blit(axes.bbox)

    def on_motion(self, event):
        """Move the dragged marker with the mouse and blit it together with the baseline."""
        if DraggablePoint.lock is not self:
            return
        if event.inaxes != self.point.axes:
            return
        (x0, y0), xpress, ypress = self.press
        dx = event.xdata - xpress
        dy = event.ydata - ypress
        self.point.center = (x0 + dx, y0 + dy)
        self.x = self.point.center[0]
        self.y = self.point.center[1]

        # Update the baseline *before* drawing it; drawing first made the
        # line lag one motion event behind the marker.
        first = self.parent.list_points[0]
        second = self.parent.list_points[1]
        line = second.line
        line.set_data([first.x, second.x], [first.y, second.y])

        canvas = self.point.figure.canvas
        axes = self.point.axes
        canvas.restore_region(self.background)
        axes.draw_artist(line)
        axes.draw_artist(self.point)
        canvas.blit(axes.bbox)

    def on_release(self, event):
        """Finish the drag: reset press state, stop animating and redraw the full figure."""
        if DraggablePoint.lock is not self:
            return

        self.press = None
        DraggablePoint.lock = None

        self.point.set_animated(False)
        self._baseline().set_animated(False)
        self.background = None

        self.point.figure.canvas.draw()

        self.x = self.point.center[0]
        self.y = self.point.center[1]

    def disconnect(self):
        """Disconnect all the stored mouse-event connection ids."""
        self.point.figure.canvas.mpl_disconnect(self.cidpress)
        self.point.figure.canvas.mpl_disconnect(self.cidrelease)
        self.point.figure.canvas.mpl_disconnect(self.cidmotion)

class MyGraph(FigureCanvas):
    """Matplotlib Qt canvas that shows the loaded images and the width measurements.

    Holds the loaded images, the two draggable calibration points, the
    mm-per-pixel conversion factor (``conv_mm_per_pix``, set by
    ``getConversion``) and the per-image width measurements (``width_mm``).
    """

    def __init__(self, parent=None, width=5, height=4, dpi=100):
        self.fig = Figure(figsize=(width, height), dpi=dpi)
        self.axes = self.fig.add_subplot(111)

        self.axes.grid(True)

        FigureCanvas.__init__(self, self.fig)
        self.setParent(parent)

        FigureCanvas.setSizePolicy(self, QtWidgets.QSizePolicy.Expanding, QtWidgets.QSizePolicy.Expanding)
        FigureCanvas.updateGeometry(self)

        # Hidden until images are opened.
        self.fig.set_visible(False)

        # To store the 2 draggable points
        self.list_points = []

        # Image state, filled by fileDialog(). Initialised here so that the
        # navigation/processing slots are safe to call before any image is open.
        self.filenames = []
        self.img_idx = 0
        self.original_imgs = []
        self.processed_imgs = {}
        # Width measurements in mm, keyed by image file name (no directory).
        self.width_mm = {}

    def _image_key(self):
        """Return the file name (without directory) of the current image."""
        return os.path.split(str(self.filenames[self.img_idx]))[1]

    def plotDraggablePoints(self, xy1, xy2, size):
        """Plot and define the 2 draggable points of the baseline"""
        self.list_points.append(DraggablePoint(self, xy1[0], xy1[1], size))
        self.list_points.append(DraggablePoint(self, xy2[0], xy2[1], size))
        self.updateFigure()

    def clearFigure(self):
        """Clear the graph and drop the draggable points.

        The points' mouse callbacks are disconnected first (while their patches
        still belong to the figure); otherwise each reopen would leave stale
        handlers attached to the canvas.
        """
        for point in self.list_points:
            point.disconnect()
        self.list_points.clear()
        self.axes.clear()
        self.axes.grid(True)

    def updateFigure(self):
        """Title the figure with the current image's file name and schedule a redraw."""
        self.fig.suptitle(self._image_key())
        self.draw_idle()

    def fileDialog(self):
        """Open images to be analyzed and plot the first image for calibration.

        Unreadable files are skipped with a warning. If the dialog is cancelled
        or none of the selected files can be read, the current state is left
        untouched. Otherwise previous images, processing results and width
        measurements are discarded, a ``Processed`` folder is created next to
        the first image, and the two calibration points are placed on it.
        """
        selected = QtWidgets.QFileDialog.getOpenFileNames(self, directory=os.getcwd())[0]

        filenames, images = [], []
        for name in selected:
            bgr = imread(str(name))
            if bgr is None:  # cv2.imread returns None when it cannot read a file
                print("WARNING: Could not read {}; skipping it.".format(name))
                continue
            filenames.append(name)
            images.append(cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB))
        if not filenames:
            return

        self.clearFigure()
        self.filenames = filenames
        self.original_imgs = images
        self.processed_imgs = {os.path.split(str(name))[1]: {'Image': img}
                               for name, img in zip(filenames, images)}
        self.width_mm = {}
        self.img_idx = 0

        os.makedirs(os.path.join(os.path.split(self.filenames[0])[0], 'Processed'), exist_ok=True)

        self.axes.imshow(self.original_imgs[0])
        self.fig.set_visible(True)
        n_rows, n_cols = self.original_imgs[0].shape[:2]
        self.plotDraggablePoints([n_cols * .6, n_rows * 0.6], [n_cols * .6, n_rows * 0.4], 50)
        self.updateFigure()

    def getConversion(self, dist_mm):
        """Set ``conv_mm_per_pix`` from the real length of the red calibration line.

        Args:
            dist_mm: length in mm of the segment between the two draggable
                points; anything ``float()`` accepts.

        Raises:
            ValueError: if ``dist_mm`` cannot be converted to float.
        """
        length_mm = float(dist_mm)
        p1 = np.array(self.list_points[0].point.center, dtype=float)
        p2 = np.array(self.list_points[1].point.center, dtype=float)
        dist_pix = np.sqrt(np.sum((p1 - p2) ** 2))
        if dist_pix == 0:
            print("ERROR: The calibration points coincide; move them apart and try again!")
            return
        self.conv_mm_per_pix = length_mm / dist_pix

    def process_image(self, threshold):
        """Detect the edges of the current image in the visible region and measure its width.

        The image is cropped to the current axes limits (clamped to the image
        bounds), smoothed with a bilateral filter and passed through Canny.
        For each row with at least two edge pixels, the distance between the
        first and last one is recorded as a width in pixels, and the edge
        pixels in between are cleared so only the outer edges remain. The edge
        map is stored in ``processed_imgs`` even before calibration; once
        calibrated, the widths are converted to mm, stored in ``width_mm``,
        printed, overlaid on the image, and the figure is saved to
        ``<image folder>/Processed/<image name>_processed.png``.

        Args:
            threshold: unused; kept so existing callers keep working.
        """
        key = self._image_key()
        img = self.original_imgs[self.img_idx]

        # Crop to the visible region. Clamping keeps a view that extends past
        # the image border from wrapping around through negative indices.
        n_rows, n_cols = img.shape[:2]
        x_lo, x_hi = sorted(self.axes.get_xlim())
        y_lo, y_hi = sorted(self.axes.get_ylim())
        c0 = min(max(int(x_lo), 0), n_cols)
        c1 = min(max(int(x_hi), 0), n_cols)
        r0 = min(max(int(y_lo), 0), n_rows)
        r1 = min(max(int(y_hi), 0), n_rows)
        if c1 <= c0 or r1 <= r0:
            print("ERROR: The visible region does not overlap the image!")
            return
        cropped = img[r0:r1, c0:c1]

        blur = cv2.bilateralFilter(cropped, 9, 75, 75)
        edges = cv2.Canny(blur, 50, 255)

        width_pix = []
        for row in edges:
            hits = np.where(row != 0)[0]
            if len(hits) >= 2:
                width_pix.append(hits[-1] - hits[0])
                # Keep only the outermost edge pixels of this row.
                row[hits[0] + 1:hits[-1]] = 0

        entry = self.processed_imgs.setdefault(key, {})
        entry['Image'] = img
        entry['Edges'] = edges
        # Placement of the cropped edge map in image coordinates, as
        # imshow's (left, right, bottom, top) extent.
        entry['Extent'] = (c0, c1, r1, r0)

        if not hasattr(self, 'conv_mm_per_pix'):
            print("ERROR: Calibration of image has not been completed yet!")
            return

        width_mm = np.array(width_pix) * self.conv_mm_per_pix
        self.width_mm[key] = width_mm

        print('Width = {} +/- {} mm'.format(np.average(width_mm), np.std(width_mm)))

        self._show_current_image()
        self.updateFigure()

        # os.path.splitext handles extensions of any length (".tiff", ".jpeg").
        folder, name = os.path.split(str(self.filenames[self.img_idx]))
        out_dir = os.path.join(folder, 'Processed')
        os.makedirs(out_dir, exist_ok=True)
        self.fig.savefig(os.path.join(out_dir, os.path.splitext(name)[0] + '_processed.png'))

    def process_all_images(self, threshold):
        """Run process_image on every open image; leaves img_idx at the last image."""
        for i in range(len(self.original_imgs)):
            self.img_idx = i
            self.process_image(threshold)

    def ManualThickness(self):
        """Let the user click two points and draw the segment between them.

        Returns:
            The list of (x, y) points from ``Figure.ginput``. It may hold fewer
            than two points if the user ends the input early.
        """
        xy = self.fig.ginput(n=2, timeout=0)
        x_m = [p[0] for p in xy]
        y_m = [p[1] for p in xy]
        self.axes.plot(x_m, y_m)
        self.axes.figure.canvas.draw()
        self.updateFigure()
        return xy

    def _show_current_image(self):
        """Draw the current image, overlaid on its stored edge map if it was processed.

        Images drawn earlier are removed first. Otherwise they pile up in
        memory and show through the semi-transparent overlay.
        """
        for old in list(self.axes.images):
            old.remove()
        entry = self.processed_imgs[self._image_key()]
        if 'Edges' in entry:
            self.axes.imshow(entry['Edges'], extent=entry['Extent'])
            self.axes.imshow(entry['Image'], cmap='gray', alpha=0.5)
        else:
            self.axes.imshow(entry['Image'], cmap='gray')

    def next_image(self):
        """Show the next image (stays on the last one; no-op before images are opened)."""
        if not self.original_imgs:
            return
        if self.img_idx + 1 < len(self.original_imgs):
            self.img_idx += 1
        self._show_current_image()
        self.updateFigure()

    def previous_image(self):
        """Show the previous image (stays on the first one; no-op before images are opened)."""
        if not self.original_imgs:
            return
        if self.img_idx - 1 >= 0:
            self.img_idx -= 1
        self._show_current_image()
        self.updateFigure()
    #
    # def save_results(self):
    #     filename_split = os.path.split(self.filenames[0][0:-9])
    #     outputfile = open(os.path.join(filename_split[0], 'Processed', filename_split[1]+'.csv'), 'w')
    #     outputfile.write('ImageName, Major (mm), Minor (mm), Contour Area (mm^2), Ellipse Area (mm^2)\n')
    #     self.img_idx = 0
    #     for f in self.filenames:
    #         cont_area = cv2.contourArea(self.processed_imgs[os.path.split(str(f))[1]]['Contour'][0])*(self.conv_mm_per_pix**2)
    #         axes = np.array(self.processed_imgs[os.path.split(str(f))[1]]['Ellipse'][1])*self.conv_mm_per_pix
    #         ellipse_area = axes[0]*axes[1]*np.pi/4
    #
    #         outputfile.write('{}, {}, {}, {}, {}\n'.format(os.path.split(str(f))[1], np.max(axes), np.min(axes), cont_area, ellipse_area))
    #
    #     print("Results saved")
    #     outputfile.close()

class MainApp(QtWidgets.QMainWindow):
    """Main window: image canvas on the left, calibration and measurement controls on the right.

    The window title walks the user through the workflow
    (open images -> calibrate -> process).
    """

    def __init__(self):
        QtWidgets.QMainWindow.__init__(self)
        self.left = 10
        self.top = 10
        self.title = 'Step 1: Open Images'
        # Initial window size. Private names, because plain ``width``/``height``
        # attributes would shadow QWidget.width() / QWidget.height().
        self._initial_width = 640
        self._initial_height = 400
        self._main = QtWidgets.QWidget()
        self.setCentralWidget(self._main)
        # Not called ``layout``: that would shadow QWidget.layout().
        self.grid_layout = QtWidgets.QGridLayout(self._main)
        self.initUI()

    def _add_button(self, text, tooltip, slot, row, column, row_span=1, column_span=5):
        """Create a push button wired to ``slot`` and place it in the grid."""
        button = QtWidgets.QPushButton(text, self)
        button.setToolTip(tooltip)
        button.clicked.connect(slot)
        self.grid_layout.addWidget(button, row, column, row_span, column_span)
        return button

    def initUI(self):
        """Build the widgets and show the window."""
        self.setWindowTitle(self.title)
        self.setGeometry(self.left, self.top, self._initial_width, self._initial_height)

        # Image canvas in the left column, with the matplotlib toolbar at the bottom
        self.m = MyGraph(self, width=5, height=4)
        self.m.move(0, 0)
        self.grid_layout.addWidget(self.m, 0, 0, 30, 1)
        self.addToolBar(QtCore.Qt.BottomToolBarArea, NavigationToolbar(self.m, self))

        # Open images
        self._add_button('Open Images', 'Use this to open images for analysis', self.openImages, 0, 1)

        # Calibration: known length of the red line, then the Calibrate button
        self.le = QtWidgets.QLineEdit()
        self.le.setText('10')
        self.le.setToolTip('Length of line specified by red dots')
        self.grid_layout.addWidget(self.le, 2, 1, 1, 4)
        self.grid_layout.addWidget(QtWidgets.QLabel('mm'), 2, 5, 1, 1)
        self._add_button('Calibrate', 'Click when ready to calibrate', self.getCalibration, 3, 1)

        # Image navigation. Spans must be ints (QGridLayout.addWidget's
        # signature); a float such as 2.5 is rejected by recent PyQt5/Python.
        self._add_button('Prev', 'Previous Image', self.m.previous_image, 29, 1, 1, 2)
        self._add_button('Next', 'Next Image', self.m.next_image, 29, 4, 1, 2)

        # Automatic processing
        self.grid_layout.addWidget(QtWidgets.QLabel('Threshold range'), 6, 1, 1, 5)
        self.threshold = QtWidgets.QLineEdit()
        self.threshold.setText('125')
        self.grid_layout.addWidget(self.threshold, 7, 1, 1, 5)
        self._add_button('Process', 'Process only visible image', self.processSingleImage, 8, 1)
        self._add_button('Process All', 'Process all open images', self.processImages, 9, 1)

        # Manual width/thickness measurement
        self._add_button('Manual Measurement', 'Use manual measurement for width/thickness',
                         self.ManualProcessing, 10, 1)

        # Results
        self._add_button('Save Data', 'Save data to csv file', self.saveResults, 11, 1)

        self.show()

    def _images_loaded(self):
        """Return True once at least one image has been opened in the canvas."""
        return bool(getattr(self.m, 'original_imgs', None))

    def _current_key(self):
        """Return the file name (without directory) of the image shown in the canvas."""
        return os.path.split(str(self.m.filenames[self.m.img_idx]))[1]

    def openImages(self):
        """Let the user pick images; move on to the calibration step if any are loaded."""
        self.m.fileDialog()
        if self._images_loaded():
            self.setWindowTitle('Step 2: Calibration')

    def getCalibration(self):
        """Calibrate mm-per-pixel from the length typed next to 'mm'.

        Errors are reported with print() instead of being raised, because an
        exception escaping a Qt slot can terminate a PyQt5 application.
        """
        dist = self.le.displayText()
        if not self._images_loaded():
            print("ERROR: Open images before calibrating!")
            return
        try:
            self.m.getConversion(dist)
        except ValueError:
            print("ERROR: '{}' is not a valid length in mm!".format(dist))
            return
        self.setWindowTitle('Step 3: Process Images')

    def processImages(self):
        """Run the automatic width measurement on every open image."""
        if not self._images_loaded():
            print("ERROR: Open images first!")
            return
        self.m.process_all_images(self.threshold)

    def calculateDistance(self, points):
        """Return the Euclidean distance between two (x, y) points."""
        (x1, y1), (x2, y2) = points
        dist = np.sqrt((x2 - x1) ** 2 + (y2 - y1) ** 2)
        return dist

    def ManualProcessing(self):
        """Measure one width by clicking two points and store it in mm for the current image.

        The new value is appended to the image's existing measurements. Any
        automatic ones, which process_image stores as a numpy array, are kept.
        """
        if not hasattr(self.m, 'conv_mm_per_pix'):
            print("ERROR: Calibration of image has not been completed yet!")
            return

        points = self.m.ManualThickness()
        if len(points) < 2:
            # ginput can return early if the user ends the input
            print("ERROR: Two points are needed for a manual measurement!")
            return
        dist_mm = self.calculateDistance(points) * self.m.conv_mm_per_pix

        widths = getattr(self.m, 'width_mm', None)
        if widths is None:
            widths = self.m.width_mm = {}
        key = self._current_key()
        # list() turns an existing ndarray (or list) into a list we can append to
        measurements = list(widths.get(key, []))
        measurements.append(dist_mm)
        widths[key] = measurements

    def processSingleImage(self):
        """Run the automatic width measurement on the visible image."""
        if not self._images_loaded():
            print("ERROR: Open images first!")
            return
        self.m.process_image(self.threshold)

    def saveResults(self):
        """Print mean +/- standard deviation of the widths measured on the visible image.

        Right now this only prints the data for the visible image; it has not
        been tested with the "Process All" function.
        """
        if not self._images_loaded():
            print("ERROR: Open images first!")
            return
        measurements = getattr(self.m, 'width_mm', {}).get(self._current_key())
        if measurements is None or len(measurements) == 0:
            print("ERROR: No width measurements for this image yet!")
            return
        measurement_img = np.array(measurements)
        print('Width = {} +/- {} mm'.format(np.average(measurement_img), np.std(measurement_img)))

if __name__ == '__main__':
    # QApplication expects the list of argument strings itself; wrapping it
    # in another list ([sys.argv]) passes a list of lists.
    qApp = QtWidgets.QApplication(sys.argv)
    app = MainApp()  # keep a reference so the window is not garbage-collected

    sys.exit(qApp.exec_())