import os
import sys
import glob
import argparse
import tempfile
import shutil
from typing import List, Optional
from PIL import Image, ImageTk
import tkinter as tk
from tkinter import messagebox, filedialog
from app.generator.generator import generate_captcha

# Append the parent of this file's directory to the module search path.
# NOTE(review): this statement executes after the `app.generator.generator`
# import above, so it cannot help that import resolve — confirm whether it
# should run before the project import or is still needed at all.
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))


class CollageApp:
    """Tkinter window that shows a grid of generated captcha images.

    Each grid cell can be clicked to regenerate its image (Shift-click moves
    the cell to the next scene in ``scenes`` first), and the current grid can
    be exported to ``output_path`` as a single collage image.

    Args:
        scenes: Scene names to generate from; the first one is used for the
            initial grid. Must not be empty.
        n_rows: Number of grid rows.
        n_cols: Number of grid columns.
        output_path: File the collage is written to by ``save_collage``.
        padding_x: Horizontal padding (pixels) used in the saved collage.
        padding_y: Vertical padding (pixels) used in the saved collage.

    Raises:
        ValueError: If ``scenes`` is empty.
    """

    def __init__(
        self,
        scenes: List[str],
        n_rows: int,
        n_cols: int,
        output_path: str,
        padding_x: int = 10,
        padding_y: int = 10,
    ):
        # Validate before creating any window or temporary directory.
        if not scenes:
            raise ValueError("At least one scene name is required")

        self.scenes = scenes
        self.n_rows = n_rows
        self.n_cols = n_cols
        self.output_path = output_path
        self.padding_x = padding_x
        self.padding_y = padding_y
        self.max_size = 300  # longest edge, in pixels, of the on-screen previews
        self._closed = False  # set once the window has been torn down

        # Pre-size the per-cell state so a click that arrives before the
        # initial images exist can index it safely.
        total_cells = n_rows * n_cols
        self.image_paths: List[Optional[str]] = [None] * total_cells
        self.image_scenes: List[str] = [scenes[0]] * total_cells
        self.image_labels = []

        # Create the window first: if Tk cannot start (e.g. no display) no
        # temporary directory is left behind.
        self.root = tk.Tk()
        self.root.title(f"Interactive Collage - {','.join(scenes)}")
        self.temp_dir = tempfile.mkdtemp()

        self.setup_ui()

    def find_manifest_path(self, scene_name: str) -> str:
        """Return the path of the ``manifest.json`` belonging to *scene_name*.

        The ``scenes`` directory is looked up three directory levels above
        this file. ``scenes/<scene_name>/manifest.json`` is preferred;
        otherwise the first manifest (in sorted order) whose directory,
        relative to ``scenes``, contains *scene_name* is returned.

        Raises:
            FileNotFoundError: If no manifest matches.
        """
        base_dir = os.path.dirname(
            os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
        )
        scenes_dir = os.path.join(base_dir, "scenes")

        manifest_path = os.path.join(scenes_dir, scene_name, "manifest.json")
        if os.path.exists(manifest_path):
            return manifest_path

        pattern = os.path.join(scenes_dir, "**", "manifest.json")
        for manifest in sorted(glob.glob(pattern, recursive=True)):
            # Match only against the part of the path below ``scenes`` so the
            # checkout location and the "manifest.json" file name can never
            # produce a false positive.
            scene_dir = os.path.relpath(os.path.dirname(manifest), scenes_dir)
            if scene_name and scene_name in scene_dir:
                return manifest

        raise FileNotFoundError(f"No manifest found for scene: {scene_name}")

    def setup_ui(self):
        """Build the grid of image labels and the row of action buttons."""
        main_frame = tk.Frame(self.root)
        main_frame.pack(padx=10, pady=10)

        self.grid_frame = tk.Frame(main_frame)
        self.grid_frame.pack()

        for i in range(self.n_rows):
            for j in range(self.n_cols):
                # compound="center" lets status text ("Regenerating...",
                # "Failed to load") stay visible on top of an image; with
                # Tk's default, text is hidden whenever an image is set.
                label = tk.Label(
                    self.grid_frame,
                    text="Loading...",
                    relief="raised",
                    cursor="hand2",
                    compound="center",
                )
                label.grid(row=i, column=j, padx=2, pady=2)
                # idx is bound as a default argument so every label keeps its
                # own index; bit 0x1 of the event state is the Shift modifier.
                label.bind(
                    "<Button-1>",
                    lambda e, idx=i * self.n_cols + j: self.regenerate_image(
                        idx, cycle_scene=bool(e.state & 0x1)
                    ),
                )
                self.image_labels.append(label)

        button_frame = tk.Frame(main_frame)
        button_frame.pack(pady=10)

        save_btn = tk.Button(
            button_frame,
            text="Save Collage",
            command=self.save_collage,
            font=("Arial", 12, "bold"),
        )
        save_btn.pack(side=tk.LEFT, padx=5)

        regenerate_all_btn = tk.Button(
            button_frame,
            text="Regenerate All",
            command=self.regenerate_all_images,
            font=("Arial", 12, "bold"),
        )
        regenerate_all_btn.pack(side=tk.LEFT, padx=5)

        exit_btn = tk.Button(
            button_frame,
            text="Exit",
            command=self.cleanup_and_exit,
            font=("Arial", 12, "bold"),
        )
        exit_btn.pack(side=tk.LEFT, padx=5)

    def generate_single_image(self, scene_name: str) -> Optional[str]:
        """Generate one image for *scene_name* inside the temp directory.

        Returns:
            Path of the generated PNG, or ``None`` if the manifest could not
            be found, generation failed, or the expected file is missing.
        """
        try:
            # The manifest lookup is inside the ``try`` so an unknown scene is
            # reported like any other generation failure instead of raising
            # out of a Tk callback.
            manifest_path = self.find_manifest_path(scene_name)
            instance_ids = generate_captcha(manifest_path, self.temp_dir)
            image_file = os.path.join(
                self.temp_dir, "images", f"{instance_ids[0]}.png"
            )
            if os.path.exists(image_file):
                return image_file
            print(f"Error generating image: {image_file} was not created")
        except Exception as e:
            print(f"Error generating image: {e}")

        return None

    def load_and_resize_image(
        self, image_path: str
    ) -> "Optional[ImageTk.PhotoImage]":
        """Load *image_path* as a Tk photo no larger than ``max_size`` pixels.

        Returns ``None`` (after printing the error) if the file cannot be
        read.
        """
        try:
            # The context manager closes the underlying file handle once the
            # pixel data has been handed to Tk.
            with Image.open(image_path) as image:
                image.thumbnail(
                    (self.max_size, self.max_size), Image.Resampling.LANCZOS
                )
                return ImageTk.PhotoImage(image)
        except Exception as e:
            print(f"Error loading image {image_path}: {e}")
            return None

    def _show_image(self, index: int, image_path: Optional[str]) -> None:
        """Record *image_path* for cell *index* and update its label."""
        label = self.image_labels[index]

        if image_path is None:
            # Generation failed: keep whatever the cell showed before (that
            # file is still what save_collage will use) and flag the failure.
            label.config(text="Failed to load")
            return

        self.image_paths[index] = image_path
        photo = self.load_and_resize_image(image_path)
        if photo is None:
            # Drop the stale preview so the cell does not display an image
            # that no longer matches image_paths[index].
            label.config(image="", text="Failed to load")
            label.image = None
            return

        label.config(image=photo, text="")
        # Keep a reference on the widget; otherwise the photo is garbage
        # collected and the label goes blank.
        label.image = photo

    def generate_initial_images(self):
        """Fill every grid cell with a freshly generated image."""
        for index in range(len(self.image_labels)):
            # Let Tk repaint and handle pending events between (slow)
            # generations; the user may close the window while this runs.
            self.root.update()
            if self._closed:
                return

            self._show_image(
                index, self.generate_single_image(self.image_scenes[index])
            )

    def regenerate_image(self, index: int, cycle_scene: bool = False):
        """Regenerate the image in cell *index*.

        Args:
            index: Row-major index of the grid cell.
            cycle_scene: If true, switch the cell to the scene that follows
                its current one in ``scenes`` (wrapping around) before
                generating.
        """
        if self._closed:
            return

        self.image_labels[index].config(text="Regenerating...")
        self.root.update()
        if self._closed:
            return

        scene_to_use = self.image_scenes[index]
        if cycle_scene:
            next_scene_index = (self.scenes.index(scene_to_use) + 1) % len(
                self.scenes
            )
            scene_to_use = self.scenes[next_scene_index]

        # The scene is recorded even when generation fails, so a later click
        # retries (or cycles on from) the scene that was just attempted.
        self.image_scenes[index] = scene_to_use
        self._show_image(index, self.generate_single_image(scene_to_use))

    def regenerate_all_images(self):
        """Regenerate every cell, keeping each cell's current scene."""
        for index in range(len(self.image_labels)):
            if self._closed:
                return
            self.regenerate_image(index)

    def save_collage(self):
        """Compose the current images into one picture at ``output_path``.

        Cells without a usable image become white 512x512 placeholders. The
        result is shrunk to fit within 1920x1920 before saving. Success and
        failure are both reported through message boxes.
        """
        if not any(path and os.path.exists(path) for path in self.image_paths):
            messagebox.showerror("Error", "No valid images to save")
            return

        try:
            images = []
            for path in self.image_paths:
                if path and os.path.exists(path):
                    # copy() loads the pixels so the file can be closed
                    # immediately instead of leaking the handle.
                    with Image.open(path) as source:
                        images.append(source.copy())
                else:
                    images.append(Image.new("RGB", (512, 512), color="white"))

            # Every cell is as large as the largest image, so no image ever
            # needs to be scaled down to fit its cell.
            cell_width = max(img.width for img in images)
            cell_height = max(img.height for img in images)

            # NOTE(review): horizontally there is padding on the outer edges
            # as well as between columns, vertically only between rows.
            # Layout kept as-is; confirm whether the asymmetry is intended.
            collage_width = (
                self.n_cols * cell_width + (self.n_cols + 1) * self.padding_x
            )
            collage_height = (
                self.n_rows * cell_height + (self.n_rows - 1) * self.padding_y
            )

            collage = Image.new("RGB", (collage_width, collage_height), color="white")

            for i, img in enumerate(images):
                row, col = divmod(i, self.n_cols)

                cell_x = col * (cell_width + self.padding_x) + self.padding_x
                cell_y = row * (cell_height + self.padding_y)

                # Centre the image inside its cell.
                paste_x = cell_x + (cell_width - img.width) // 2
                paste_y = cell_y + (cell_height - img.height) // 2

                collage.paste(img, (paste_x, paste_y))

            collage.thumbnail((1920, 1920), Image.Resampling.LANCZOS)
            collage.save(self.output_path)

            messagebox.showinfo("Success", f"Collage saved to {self.output_path}")

        except Exception as e:
            messagebox.showerror("Error", f"Failed to save collage: {e}")

    def cleanup_and_exit(self):
        """Delete the temporary directory and close the window (idempotent)."""
        if self._closed:
            return
        self._closed = True
        # Best-effort cleanup: a failure to delete temp files must not keep
        # the window from closing.
        shutil.rmtree(self.temp_dir, ignore_errors=True)
        self.root.destroy()

    def run(self):
        """Start the Tk main loop; initial images are generated shortly after."""
        self.root.protocol("WM_DELETE_WINDOW", self.cleanup_and_exit)
        # Defer generation so the window is on screen before the slow work.
        self.root.after(100, self.generate_initial_images)
        self.root.mainloop()


def main():
    """Parse the command line and launch the interactive collage window.

    Exits with an argparse usage error when ``--scenes`` contains no
    non-empty scene name or when the grid dimensions are not positive.
    """
    parser = argparse.ArgumentParser(
        description="Interactive captcha collage generator"
    )
    parser.add_argument(
        "--scenes",
        required=True,
        help="Comma-separated list of scene names to cycle through",
    )
    parser.add_argument("--n-rows", type=int, default=2, help="Number of rows in grid")
    parser.add_argument(
        "--n-cols", type=int, default=2, help="Number of columns in grid"
    )
    parser.add_argument(
        "--output-path", default="collage.png", help="Output path for saved collage"
    )
    parser.add_argument(
        "--padding-x",
        type=int,
        default=10,
        help="Horizontal padding between images in final collage",
    )
    parser.add_argument(
        "--padding-y",
        type=int,
        default=10,
        help="Vertical padding between images in final collage",
    )

    args = parser.parse_args()

    # Drop empty entries (e.g. from a trailing comma): an empty scene name
    # can never identify a scene.
    scenes = [name.strip() for name in args.scenes.split(",") if name.strip()]
    if not scenes:
        parser.error("--scenes must contain at least one non-empty scene name")
    if args.n_rows < 1 or args.n_cols < 1:
        parser.error("--n-rows and --n-cols must be at least 1")

    app = CollageApp(
        scenes,
        args.n_rows,
        args.n_cols,
        args.output_path,
        args.padding_x,
        args.padding_y,
    )
    app.run()


# Run the command-line entry point only when this file is executed as a
# script, not when it is imported as a module.
if __name__ == "__main__":
    main()
