import os
import turtle

import matplotlib.pyplot as plt
from PIL import ImageGrab, Image

from src.turtlegfx.utils.colormap import color2rgb
from src.turtlegfx.utils.enums import MIDI_TASKIDS, MAXI_TASKIDS
from src.turtlegfx.utils.load_data import load_task


class Visualizer:
    """Draw line segments, text labels and filled rectangles with a turtle.

    The turtle settings passed to the constructor are stored on the instance
    so that ``clear()`` can rebuild an identically configured turtle after the
    screen has been reset.
    """

    def __init__(self, pensize=2, speed=1, init_heading=90, stretch_wid=2, stretch_len=2):
        """Store the turtle settings and create the screen and the turtle.

        Args:
            pensize: Width of the drawn lines.
            speed: Turtle animation speed.
            init_heading: Initial heading of the turtle, in degrees.
            stretch_wid: Width stretch factor of the turtle shape.
            stretch_len: Length stretch factor of the turtle shape.
        """
        self.pensize = pensize
        self.speed = speed
        self.init_heading = init_heading
        self.stretch_wid = stretch_wid
        self.stretch_len = stretch_len

        self._init_screen()

    def _init_screen(self):
        """(Re)acquire the screen and create a turtle from the stored settings."""
        self.screen = turtle.Screen()
        self.screen.tracer(0)  # Turn off automatic updates for smoother drawing
        self.screen.colormode(255)  # Set the color mode

        self.turtle = turtle.Turtle()
        self._setup_turtle(self.pensize, self.speed, self.init_heading,
                           self.stretch_wid, self.stretch_len)

    def _setup_turtle(self, pensize, speed, init_heading, stretch_wid, stretch_len):
        """Apply the given settings to the turtle and stamp its start position."""
        self.turtle.speed(speed)  # Set the speed of the turtle
        self.turtle.pensize(pensize)  # Set the width of the lines
        self.turtle.turtlesize(stretch_wid=stretch_wid, stretch_len=stretch_len)
        self.turtle.setheading(init_heading)  # Set the initial direction
        self.turtle.stamp()  # Mark the starting position

    def _draw_line(self, x1, y1, x2, y2, color):
        """Draw a segment from (x1, y1) to (x2, y2) in the given color."""
        self.turtle.penup()
        self.turtle.goto(x1, y1)
        self.turtle.pendown()
        self.turtle.color(color)
        self.turtle.goto(x2, y2)

    def _draw_label(self, x, y, text, fontsize):
        """Write ``text`` at (x, y) in Arial at the given font size."""
        self.turtle.penup()
        self.turtle.goto(x, y)
        self.turtle.pendown()
        self.turtle.write(text, font=("Arial", fontsize, "normal"))

    def _draw_rectangle(self, x1, y1, x2, y2, color):
        """Draw a filled rectangle with opposite corners (x1, y1) and (x2, y2)."""
        pen = self.turtle

        # Travel to the first corner *before* begin_fill(): the turtle records
        # every goto() as a fill vertex (even with the pen up), so starting the
        # fill earlier would add the pen's previous position to the polygon.
        pen.penup()
        pen.goto(x1, y1)
        pen.color(color)
        pen.pendown()

        pen.begin_fill()
        for corner_x, corner_y in ((x1, y2), (x2, y2), (x2, y1), (x1, y1)):
            pen.goto(corner_x, corner_y)
        pen.end_fill()

    def draw_task(self, task):
        """Draw a task dict holding 'lineSegments', 'labels' and 'rectangles'."""
        self.draw(task['lineSegments'], task['labels'], task['rectangles'])

    def draw(self, lines, labels, rectangles):
        """Draw all segments, then all labels, then all rectangles.

        Args:
            lines: Iterable of dicts with keys 'x1', 'y1', 'x2', 'y2', 'color'.
            labels: Iterable of dicts with keys 'x', 'y', 'text', 'fontSize'.
            rectangles: Iterable of dicts with keys 'x1', 'y1', 'x2', 'y2',
                'color'.

        Color values are converted with ``color2rgb`` before use.
        """
        for segment in lines:
            self._draw_line(segment["x1"], segment["y1"], segment["x2"], segment["y2"], color2rgb(segment["color"]))

        for label in labels:
            self._draw_label(label["x"], label["y"], label["text"], label["fontSize"])

        for rectangle in rectangles:
            self._draw_rectangle(rectangle["x1"], rectangle["y1"], rectangle["x2"], rectangle["y2"],
                                 color2rgb(rectangle["color"]))

        self.turtle.hideturtle()  # Hide the turtle after drawing is done
        self.screen.update()  # Update the screen to show the drawing

    @staticmethod
    def _eps_path(filename):
        """Return ``filename`` with its extension replaced by '.eps'."""
        # splitext only touches the final extension, so '.png' appearing
        # elsewhere in the path is left alone and a name without an extension
        # still yields a path distinct from the PNG file.
        return os.path.splitext(filename)[0] + '.eps'

    def save(self, filename, show=False, save_png=True, save_eps=False):
        """Export the current drawing.

        Args:
            filename: Path of the PNG file; the EPS file (if requested) uses
                the same path with the extension replaced by '.eps'.
            show: If true, display the captured image with matplotlib.
            save_png: If true, write the captured image to ``filename``.
            save_eps: If true, write the canvas as PostScript.
        """
        # Resize the window to the full screen before capturing
        self.screen.setup(width=1.0, height=1.0)
        self.screen.update()

        if save_eps:
            self.screen.getcanvas().postscript(file=self._eps_path(filename))

        # The screen grab is only needed for the PNG file or the preview.
        if not (save_png or show):
            return

        # ImageGrab captures screen pixels, so the box is the window's area
        # in screen coordinates.
        root = self.screen._root
        left, top = root.winfo_rootx(), root.winfo_rooty()
        img = ImageGrab.grab(bbox=(left,
                                   top,
                                   left + self.screen.window_width(),
                                   top + self.screen.window_height()))

        # Save the captured image
        if save_png:
            img.save(filename, "PNG")

        # Optionally show the image
        if show:
            plt.imshow(img)
            plt.axis('off')  # Hide axes
            plt.show()

    def clear(self):
        """Reset the screen and recreate the turtle with the stored settings."""
        self.screen.clearscreen()
        self._init_screen()


class XLogoVisualizer:
    """Render every task of one level to a PNG file."""

    def __init__(self, level):
        """Remember which level's tasks to render.

        Args:
            level: "midi" selects ``MIDI_TASKIDS``; any other value selects
                ``MAXI_TASKIDS``.
        """
        self.level = level

    def visualize_all_tasks(self, output_dir='../assets/drawings'):
        """Draw each task of the level and save it as ``<task_id>.png``.

        Args:
            output_dir: Directory the PNG files are written to. It is used
                as given, joined to the file name with '/'.
        """
        task_ids = MIDI_TASKIDS if self.level == "midi" else MAXI_TASKIDS

        # One visualizer serves all tasks: clear() resets the screen and
        # rebuilds the turtle with the same settings after every drawing.
        visualizer = Visualizer(pensize=2, speed=1, init_heading=90, stretch_wid=2, stretch_len=2)
        for task_id in task_ids:
            visualizer.draw_task(load_task(task_id))
            visualizer.save(f'{output_dir}/{task_id}.png')
            visualizer.clear()


if __name__ == "__main__":
    # Render the midi tasks first, then the maxi tasks.
    for level_name in ("midi", "maxi"):
        XLogoVisualizer(level=level_name).visualize_all_tasks()
