Seite 1 von 1

Plot in Tkinter einbinden und mit "after"-Funktion aktualisieren

Verfasst: Samstag 31. Oktober 2020, 10:05
von Finn_h
Hallo,
ich habe ein GUI mit 3 aufeinander folgenden Frames mittels tkinter erstellt und mit dummy Daten gefüllt. An zwei Stellen komme ich nicht weiter:

1. ich habe ein 3D-plot über Matplotlib integriert. Jetzt möchte ich über 4 Buttons (hoch, runter, links, rechts) das Plot im Raum drehen. Meine Idee war es die entsprechenden Werte in die Funktion zu übergeben und mit "ax.view_init" das neue leicht gedrehte Plot zu berechnen. Das Plot würde ich dann mit einer "after"-Funktion nach einem gewissen Zeitinterval neu aufbauen.

2. Wenn ich aus meinem letzten Frame (precises_frame) in meinem mittleren Frame (image_frame) zurückkehre und ein neuen Bildbereich auswähle und dann wieder in den letzten Frame wechsel, bekomme ich eine Fehlermeldung.

Über Hilfe wäre ich sehr dankbar.
Code:

Code: Alles auswählen

import cv2
from matplotlib.backends.backend_tkagg import FigureCanvasTkAgg
import matplotlib.pyplot as plt
import tkinter as tk
import numpy as np
import PIL.Image, PIL.ImageTk
from matplotlib.ticker import MaxNLocator

LINE_AMOUNT = 10


class GraphicalUserInterface:

    def __init__(self, master):
        master.geometry("1400x700+30+30")
        master.title("Graphical User Interface")

        #Start Frame

        self.start_frame = tk.Frame(master)

        self.start_frame.pack(expand=True, fill="both")
        self.start_frame.pack_propagate(0)
        self.start_frame_label = tk.Label(self.start_frame, text="Start Frame")
        self.start_frame_label.pack()

        self.continue_button_start_frame = tk.Button(self.start_frame, text="Take Image and Continue",
                                                     command=lambda: self.start_to_image_frame())
        self.continue_button_start_frame.place(relx=0.85, rely=0.05)

        video_label = tk.Label(self.start_frame)
        video_label.pack()

        self.width, self.height = 640, 480

        self.cap = cv2.VideoCapture(0)
        self.cap.set(cv2.CAP_PROP_FRAME_WIDTH, self.width)
        self.cap.set(cv2.CAP_PROP_FRAME_HEIGHT, self.height)


        def show_video():
            _, image = self.cap.read()
            image = cv2.flip(image, 1)
            image = cv2.cvtColor(image, cv2.COLOR_BGR2RGBA)
            image = PIL.Image.fromarray(image)
            image_tk = PIL.ImageTk.PhotoImage(image=image)
            video_label.image_tk = image_tk
            video_label.configure(image=image_tk)
            video_label.after(10, show_video)

        show_video()

        #Image Frame

        self.image_frame = tk.Frame(master)
        self.image_frame_label = tk.Label(self.image_frame, text="Image Frame")
        self.image_frame_label.pack()

        self.continue_button_image_frame = tk.Button(self.image_frame, text="Continue",
                                                     command=lambda:
                                                     self.image_to_precise_frame(self.x_entry_widget.get(),
                                                                                 self.y_entry_widget.get(), self.image))
        self.continue_button_image_frame.place(relx=0.85, rely=0.05)

        self.back_button_image_frame = tk.Button(self.image_frame, text="Back",
                                                     command=lambda: self.image_to_start_frame())
        self.back_button_image_frame.place(relx=0.15, rely=0.05)


        self.image_panel = tk.Label(self.image_frame)
        self.image_panel.place(x=750, y= 550, anchor="sw")

        self.arrow_x_canvas = tk.Canvas(self.image_frame)
        self.arrow_x_canvas.config(width=210, height=30)
        self.arrow_x_canvas.place(x=750, y=565)
        self.arrow_x_canvas.create_line(0, 15, 200, 15, arrow=tk.LAST)
        self.x_axis_tag = tk.Label(self.image_frame, text="x")
        self.x_axis_tag.place(x=960, y=590, anchor="s")

        self.arrow_y_canvas = tk.Canvas(self.image_frame)
        self.arrow_y_canvas.config(width=30, height=210)
        self.arrow_y_canvas.place(x=735, y=550, anchor="se")
        self.arrow_y_canvas.create_line(15, 210, 15, 10, arrow=tk.LAST)
        self.y_axis_tag = tk.Label(self.image_frame, text="y")
        self.y_axis_tag.place(x=720, y=320)

        self.x_entry_widget = tk.Entry(self.image_frame, bg="yellow", text="10")
        self.x_entry_widget.place(x=900, y=610)
        self.x_entry_tag = tk.Label(self.image_frame, text="x")
        self.x_entry_tag.place(x=900, y=610, anchor="ne")

        self.y_entry_widget = tk.Entry(self.image_frame, bg="yellow")
        self.y_entry_widget.place(x=900, y=630)
        self.y_entry_tag = tk.Label(self.image_frame, text="y")
        self.y_entry_tag.place(x=900, y=630, anchor="ne")

        ########################################################

        self.vertical_display = tk.Label(self.image_frame, text=30, bg="red")
        self.vertical_display.place(relx=0.1, rely=0.9)

        self.horizontal_display = tk.Label(self.image_frame, text=30, bg="red")
        self.horizontal_display.place(relx=0.1, rely=0.85)

        self.up_button = tk.Button(master=self.image_frame, text="hoch",
                              command=lambda: self.turn_horizontal(self.horizontal_display, 5))
        self.up_button.place(x=300, y=570)

        self.down_button = tk.Button(master=self.image_frame, text="runter",
                                command=lambda: self.turn_horizontal(self.horizontal_display, -5))
        self.down_button.place(x=300, y=620)

        self.left_button = tk.Button(master=self.image_frame, text="links",
                                command=lambda: self.turn_vertical(self.vertical_display, -5))
        self.left_button.place(x=250, y=595)

        self.right_button = tk.Button(master=self.image_frame, text="rechts",
                                 command=lambda: self.turn_vertical(self.vertical_display, 5))
        self.right_button.place(x=350, y=595)

        # Precise Frame

        self.precise_frame = tk.Frame(master)
        self.precise_frame_label = tk.Label(self.precise_frame, text="Precise Frame")
        self.precise_frame_label.pack()

        self.x_value_label = tk.Label(self.precise_frame, bg="red")
        self.x_value_label.pack()

        self.y_value_label = tk.Label(self.precise_frame, bg="red")
        self.y_value_label.pack()

        self.image_panel_precise_frame = tk.Label(self.precise_frame)
        self.image_panel_precise_frame.place(x=750, y= 550, anchor="sw")



        self.back_button_precise_frame = tk.Button(self.precise_frame, text="Back",
                                                   command=lambda: self.precise_to_image_frame())
        self.back_button_precise_frame.place(relx=0.15, rely=0.05)

    def get_dummy_three_d_image(self, turn_vertical, turn_horizontal):
        ARRAY_HEIGHT = self.height
        ARRAY_WIDTH = self.width

        sensor_input = np.random.rand(ARRAY_HEIGHT, ARRAY_WIDTH)

        plot_square_size = 8

        if ARRAY_HEIGHT % plot_square_size or ARRAY_WIDTH % plot_square_size:
            print("Fehlermeldung")

        self.plot_data = list()

        for x in range(int(ARRAY_WIDTH / plot_square_size)):
            x = x * plot_square_size
            for y in range(int(ARRAY_HEIGHT / plot_square_size)):
                y = y * plot_square_size
                single_plot_item = np.array([x + plot_square_size / 2 - 0.5,
                                             y + plot_square_size / 2 - 0.5,
                                             sensor_input[y:y + plot_square_size, x:x + plot_square_size].mean()])
                self.plot_data.append(single_plot_item)
        #
        # eventuell noch alle Randpunkte einbeziehen
        #
        self.plot_data = np.asarray(self.plot_data)

        x = self.plot_data[:, 0]
        y = self.plot_data[:, 1]
        z = self.plot_data[:, 2]

        fig = plt.figure()

        ax = fig.add_subplot(111, projection="3d")
        ax.set_zlim(0, 5)

        surf = ax.plot_trisurf(x, y, z, cmap="plasma", linewidth=0)
        fig.colorbar(surf)

        ax.xaxis.set_major_locator(MaxNLocator(5))
        ax.yaxis.set_major_locator(MaxNLocator(5))
        ax.zaxis.set_major_locator(MaxNLocator(5))

        fig.tight_layout()

        ax.view_init(turn_vertical, turn_horizontal)
        return fig

    def turn_vertical(self, display, value):
        display["text"] = display["text"] + value

    def turn_horizontal(self, display, value):
        display["text"] = display["text"] + value

    def start_to_image_frame(self):
        self.start_frame.pack_forget()
        _, self.image = self.cap.read()
        self.cluster_image = self.draw_cluster_on_image(self.image)
        self.image_panel["image"]= self.cluster_image

        self.fig = self.get_dummy_three_d_image(self.horizontal_display["text"], self.vertical_display["text"])
        self.plot_canvas = FigureCanvasTkAgg(self.fig, master=self.image_frame)
        self.plot_canvas._tkcanvas.place(x=10, y=550, anchor="sw")

        self.image_frame.pack(expand=True, fill="both")
        self.image_frame.pack_propagate(0)

    def image_to_precise_frame(self, x_value, y_value, image):
        self.image_frame.pack_forget()
        self.x_value_label["text"] = x_value
        self.y_value_label["text"] = y_value
        self.image = self.highlight_location_on_image(image, x_value, y_value)



        self.precise_fig = self.get_dummy_precise_three_d_image(self.vertical_display["text"],
                                                                self.horizontal_display["text"],
                                                                self.plot_data,
                                                                x_value,
                                                                y_value)

        self.precise_plot_canvas = FigureCanvasTkAgg(self.precise_fig, master=self.precise_frame)
        self.precise_plot_canvas._tkcanvas.place(x=10, y=550, anchor="sw")


        self.image_panel_precise_frame["image"] = self.image
        self.precise_frame.pack(expand=True, fill="both")
        self.precise_frame.pack_propagate(0)

    def image_to_start_frame(self):
        self.image_frame.pack_forget()
        self.start_frame.pack(expand=True, fill="both")
        self.start_frame.pack_propagate(0)

    def precise_to_image_frame(self):
        self.precise_frame.pack_forget()
        self.image_frame.pack(expand=True, fill="both")
        self.image_frame.pack_propagate(0)

    def draw_cluster_on_image(self, image):

        image = cv2.flip(image, 1)

        height, width, channel = image.shape

        line_distance_x = width / LINE_AMOUNT
        line_distance_y = height / LINE_AMOUNT

        for line in range(LINE_AMOUNT):
            image = cv2.line(image, (int(line * line_distance_x), 0), (int(line * line_distance_x), height),
                             (100, 100, 100))
            image = cv2.line(image, (0, int(line * line_distance_y)), (width, int(line * line_distance_y)),
                             (100, 100, 100))

        image_as_array = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

        image = PIL.Image.fromarray(image_as_array)
        image = PIL.ImageTk.PhotoImage(image)
        return image

    def highlight_location_on_image(self, image, x_value, y_value):
        image = cv2.flip(image, 1)
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

        first_point = int((int(x_value))/LINE_AMOUNT*self.width), int(self.height-(int(y_value)+1) / LINE_AMOUNT * self.height)
        second_point = int((int(x_value)+1)/LINE_AMOUNT*self.width), int(self.height-(int(y_value))/LINE_AMOUNT*self.height)
        blurred_image = cv2.GaussianBlur(image, (29, 29), 0)
        mask = np.zeros((image.shape), dtype=np.uint8)
        mask = cv2.rectangle(mask, (first_point), (second_point), (255, 255, 255), -1)

        image = np.where(mask == np.array([255, 255, 255]), image, blurred_image)

        image = PIL.Image.fromarray(image)
        image = PIL.ImageTk.PhotoImage(image)
        return image

    def get_dummy_precise_three_d_image(self, turn_vertical, turn_horizontal, plot_data, x_value, y_value):
        x_value, y_value = int(x_value), int(y_value)
        first_point = (x_value/LINE_AMOUNT*self.width, y_value/LINE_AMOUNT*self.height)
        second_point = ((x_value+1)/LINE_AMOUNT*self.width, (y_value+1)/LINE_AMOUNT*self.height)
        selected_data = list()
        diselected_data = list()
        for row in plot_data:
            if first_point[0] <= row[0] <= second_point[0] and first_point[1] <= row[1] <= second_point[1]:
                selected_data.append(row)
            else:
                diselected_data.append(row)

        selected_data = np.asarray(selected_data)
        diselected_data = np.asarray(diselected_data)

        x_diselected = diselected_data[:, 0]
        y_diselected = diselected_data[:, 1]
        z_diselected = diselected_data[:, 2]

        x_selected = selected_data[:, 0]
        y_selected = selected_data[:, 1]
        z_selected = selected_data[:, 2]

        fig = plt.figure()

        ax = fig.add_subplot(111, projection="3d")
        ax.set_zlim(0, 5)

        surf = ax.plot_trisurf(x_diselected, y_diselected, z_diselected, cmap="plasma", linewidth=0, alpha=0.3)
        fig.colorbar(surf)

        ax.plot_trisurf(x_selected, y_selected, z_selected, color="red", linewidth=0)

        ax.xaxis.set_major_locator(MaxNLocator(5))
        ax.yaxis.set_major_locator(MaxNLocator(5))
        ax.zaxis.set_major_locator(MaxNLocator(5))

        fig.tight_layout()


        X_plane = np.array([[first_point[0], second_point[0]], [first_point[0], second_point[0]]])
        Y_plane = np.array([[first_point[1], first_point[1]], [second_point[1], second_point[1]]])
        Z_plane = np.array([[4.5, 4.5], [4.5, 4.5]])

        ax.plot_surface(X_plane, Y_plane, Z_plane, color="green")

        ax.view_init(turn_vertical, turn_horizontal)
        return fig

root = tk.Tk()
GraphicalUserInterface(root)
root.mainloop()