import os
import random
import tempfile
import numpy as np
from polyomino.board import Rectangle, Shape
from polyomino.solution import Solution
from polyomino.constant import ALL_TETROMINOS, ALL_PENTOMINOS
from polyomino.tileset import any_number_of, exactly
from polyomino.transform import rotate
import matplotlib.pyplot as plt


# Default board side-length bounds that generate_task_data forwards to
# generate_board.
MIN_SIZE = 6
MAX_SIZE = 10

# Edge-adjacent (4-neighbourhood) offsets, as (dx, dy) pairs.
DIRECTIONS = [(0, 1), (1, 0), (0, -1), (-1, 0)]

# Full 8-neighbourhood: the four edge offsets followed by the four diagonals.
DIRECTIONS_8 = DIRECTIONS + [(1, 1), (1, -1), (-1, 1), (-1, -1)]


def random_rotate(tile):
    """Return one orientation of ``tile`` picked uniformly from ``rotate(tile)``."""
    orientations = list(rotate(tile))
    return random.choice(orientations)


def random_rotate_tileset(tileset):
    """Return a new list with every tile of ``tileset`` independently re-oriented.

    Order is preserved; each element is passed through ``random_rotate``.
    """
    return list(map(random_rotate, tileset))


def untransform_tile(tile):
    """Translate ``tile`` so its smallest x and smallest y coordinates become 0.

    Returns a new list of ``(x, y)`` tuples in the same order as ``tile``.
    Raises ``ValueError`` for an empty tile (no minimum exists).
    """
    xs = [cell[0] for cell in tile]
    ys = [cell[1] for cell in tile]
    shift_x, shift_y = min(xs), min(ys)
    return [(cell[0] - shift_x, cell[1] - shift_y) for cell in tile]


def move_tile(tile, dx, dy):
    """Return a new list of cells: every ``(x, y)`` of ``tile`` shifted by ``(dx, dy)``."""
    shifted = []
    for cx, cy in tile:
        shifted.append((cx + dx, cy + dy))
    return shifted


def tile_size(tile):
    """Return ``(width, height)`` of the axis-aligned bounding box of ``tile``.

    Sizes are counted in cells, so a single cell is ``(1, 1)``.
    """
    xs = [cell[0] for cell in tile]
    ys = [cell[1] for cell in tile]
    return max(xs) - min(xs) + 1, max(ys) - min(ys) + 1


def pack_size(occupied):
    """Return ``(width, height)`` of the bounding box of every cell in ``occupied``.

    ``occupied`` is iterated directly, so for a dict the keys (cells) are
    measured.
    """
    distinct_x = {cell[0] for cell in occupied}
    distinct_y = {cell[1] for cell in occupied}
    width = max(distinct_x) - min(distinct_x) + 1
    height = max(distinct_y) - min(distinct_y) + 1
    return width, height


def is_overlapping(tile, occupied):
    """Return True if any cell of ``tile`` is already present in ``occupied``."""
    return any(cell in occupied for cell in tile)


def generate_board(
    min_size=6,
    max_size=12,
    max_attempts=10,
):
    """Find a tiling of a random rectangle by tetrominoes and pentominoes.

    Each attempt draws a width and a height independently from the inclusive
    range ``[min_size, max_size]``. It then asks the solver to tile that
    ``Rectangle`` with any number of pieces from ``ALL_PENTOMINOS`` and
    ``ALL_TETROMINOS``. An attempt whose solver call raises or returns a
    falsy result is discarded, and a new rectangle is drawn.

    Args:
        min_size: Smallest side length to draw (inclusive).
        max_size: Largest side length to draw (inclusive). Must be
            ``>= min_size``.
        max_attempts: How many rectangles to try before giving up.

    Returns:
        The first truthy ``solve()`` result. If every attempt failed, returns
        the last falsy ``solve()`` result, or ``None`` if no attempt returned
        at all.
    """
    solution = None

    for _ in range(max_attempts):
        # np.random.randint excludes ``high``. Add 1 so ``max_size`` is
        # actually reachable and ``min_size == max_size`` is a valid request.
        w = np.random.randint(min_size, max_size + 1)
        h = np.random.randint(min_size, max_size + 1)
        board = Rectangle(w, h)

        try:
            solution = board.tile_with_set(
                any_number_of(ALL_PENTOMINOS + ALL_TETROMINOS)
            ).solve()
        except Exception:
            # Best effort: any solver failure just means "try another size".
            continue

        if solution:
            break

    return solution


def mark_tile(tile, coord_tile_map, add_padding=False):
    """Map every cell of ``tile`` to ``tile`` itself inside ``coord_tile_map``.

    When ``add_padding`` is true, the eight cells surrounding each tile cell
    are also mapped to ``tile``. This reserves a one-cell border around it.
    The mapping is updated in place, and the function returns ``None``.
    """
    for point in tile:
        halo = (
            [(point[0] + off_x, point[1] + off_y) for off_x, off_y in DIRECTIONS_8]
            if add_padding
            else []
        )
        for cell in halo:
            coord_tile_map[cell] = tile
        # The tile's own cell is written last, after its padding.
        coord_tile_map[point] = tile


def is_connected(tile, coord_tile_map):
    """Return True if some edge-neighbour of a cell of ``tile`` is in ``coord_tile_map``.

    Only the four ``DIRECTIONS`` offsets are tested, so diagonal contact does
    not count.
    """
    return any(
        (point[0] + step_x, point[1] + step_y) in coord_tile_map
        for point in tile
        for step_x, step_y in DIRECTIONS
    )


def pick_tiles(
    board: Solution,
    num_tiles=5,
    max_attempts=10,
):
    """Pick up to ``num_tiles`` edge-connected tiles from a solved tiling.

    The first tile is the one covering a uniformly random cell of the board.
    Each further tile must share an edge with an already-picked tile.
    Candidates are first found by sampling random cells (up to
    ``max_attempts`` per tile). If that sampling fails, the function picks
    uniformly among the remaining unpicked tiles that touch the picked set.
    It therefore returns fewer than ``num_tiles`` tiles only when no adjacent
    unpicked tile is left.

    Args:
        board: Solved tiling. Reads ``board.tiling`` (iterable of tiles) and
            ``board.board.x`` / ``board.board.y`` (sampling range).
        num_tiles: Desired number of tiles.
        max_attempts: Random cell samples to try per tile before falling back.

    Returns:
        List of the picked tiles, in pick order.
    """
    # cell -> tile covering it, for every tile in the solution.
    coord_tile_map = {}
    for tile in board.tiling:
        mark_tile(tile, coord_tile_map)

    board_width = board.board.x
    board_height = board.board.y

    picked_tiles = []
    picked_coords = {}

    for _ in range(num_tiles):
        chosen = None
        for _ in range(max_attempts):
            x = np.random.randint(0, board_width)
            y = np.random.randint(0, board_height)

            if not picked_tiles:
                chosen = coord_tile_map[(x, y)]
                break

            if (x, y) in picked_coords:
                continue

            candidate = coord_tile_map[(x, y)]
            if is_connected(candidate, picked_coords):
                chosen = candidate
                break

        if chosen is None:
            # Random sampling ran out of attempts. Choose directly among the
            # unpicked tiles that touch the picked ones, so a connected tile
            # is never silently skipped.
            neighbours = [
                tile
                for tile in board.tiling
                if tile not in picked_tiles and is_connected(tile, picked_coords)
            ]
            if not neighbours:
                break
            chosen = neighbours[np.random.randint(len(neighbours))]

        picked_tiles.append(chosen)
        mark_tile(chosen, picked_coords)

    return picked_tiles


def generate_fake_answer(
    board: Shape,
    num_tiles=3,
    max_attempts=10,
):
    """Draw a random piece multiset that does NOT tile ``board``.

    Each attempt samples ``num_tiles`` pieces with replacement from
    ``ALL_PENTOMINOS + ALL_TETROMINOS``. The first draw is returned if the
    solver finds no tiling of ``board`` using exactly those pieces, or if the
    solver raises. Returns ``None`` if every one of the ``max_attempts``
    draws was solvable.
    """
    pool = ALL_PENTOMINOS + ALL_TETROMINOS
    for _ in range(max_attempts):
        candidate = random.choices(pool, k=num_tiles)
        try:
            if not board.tile_with_set(exactly(candidate)).solve():
                return candidate
        except Exception:
            # A solver error is treated the same as "no solution found".
            return candidate
    return None


def generate_task_data(
    min_size=MIN_SIZE,
    max_size=MAX_SIZE,
    max_attempts=10,
    num_tiles=3,
):
    """Build one puzzle: a target shape, its true pieces, and decoy piece sets.

    A random rectangle is tiled, ``num_tiles`` connected tiles are picked, and
    every other tile is removed from the board. The remaining shape is the
    target.

    Args:
        min_size: Forwarded to ``generate_board``.
        max_size: Forwarded to ``generate_board``.
        max_attempts: Retry budget forwarded to every helper.
        num_tiles: Requested number of pieces in the answer.

    Returns:
        A tuple ``(board, correct_answer, fake_answers)``:

        - ``board``: the solved board with every unpicked tile removed
          (via ``remove_all``).
        - ``correct_answer``: the picked tiles, translated to the origin and
          randomly re-oriented.
        - ``fake_answers``: up to three randomly re-oriented piece lists that
          the solver could not fit on ``board``. Each has the same piece
          count as ``correct_answer``.

    Raises:
        RuntimeError: If no random rectangle could be tiled.
    """
    board_solution = generate_board(
        min_size=min_size, max_size=max_size, max_attempts=max_attempts
    )
    if not board_solution:
        raise RuntimeError(
            f"could not tile a random board within {max_attempts} attempts"
        )

    picked_tiles = pick_tiles(
        board_solution,
        num_tiles=num_tiles,
        max_attempts=max_attempts,
    )

    # Carve the board down to exactly the cells covered by the picked tiles.
    board = board_solution.board
    for tile in board_solution.tiling:
        if tile in picked_tiles:
            continue

        board = board.remove_all(tile)

    correct_answer = [untransform_tile(tile) for tile in picked_tiles]

    # Decoys must have as many pieces as the real answer. Otherwise the piece
    # count alone gives the answer away when fewer than num_tiles were picked.
    fake_answers = []
    for _ in range(3):
        fake_answer = generate_fake_answer(board, len(correct_answer), max_attempts)
        if fake_answer:
            fake_answers.append(random_rotate_tileset(fake_answer))

    return board, random_rotate_tileset(correct_answer), fake_answers


def pack_tileset(tileset):
    """Lay tiles out compactly for display, keeping a one-cell gap between them.

    The first tile stays where it is. Each later tile is translated by the
    offset (searched in a window) that keeps it off every previously placed
    tile and that tile's one-cell border. Among such offsets, the one giving
    the smallest padded bounding box is chosen: smallest area first, then the
    squarest shape. Ties go to the first offset found.

    Args:
        tileset: Sequence of tiles, each a sequence of ``(x, y)`` cells.

    Returns:
        List of translated tiles, in the same order as ``tileset``. An empty
        ``tileset`` gives an empty list.
    """
    if not tileset:
        return []

    occupied = {}
    mark_tile(tileset[0], occupied, add_padding=True)

    # The search window must be wide enough to line every tile up in a row
    # or a column, so bound it by the total width / height of the whole set.
    tile_sizes = [tile_size(tile) for tile in tileset]
    max_dx = sum(w for w, _ in tile_sizes)
    max_dy = sum(h for _, h in tile_sizes)

    packed_tiles = [tileset[0]]

    for tile in tileset[1:]:
        best_key = None
        best_placement = None
        for dx in range(-max_dx, max_dx + 1):
            for dy in range(-max_dy, max_dy + 1):
                placed_tile = move_tile(tile, dx, dy)
                if is_overlapping(placed_tile, occupied):
                    continue
                new_occupied = occupied.copy()
                mark_tile(placed_tile, new_occupied, add_padding=True)
                w, h = pack_size(new_occupied)
                # Use a total order (area, then longest side). Testing
                # "either side smaller" is not an ordering and lets later,
                # larger layouts overwrite better ones.
                key = (w * h, max(w, h))
                if best_key is None or key < best_key:
                    best_key = key
                    best_placement = (dx, dy)

        if best_placement is None:
            # Nothing fit in the window: park the tile just to the right of
            # the current (padded) layout, where it cannot overlap anything.
            occupied_max_x = max(point[0] for point in occupied)
            tile_min_x = min(point[0] for point in tile)
            best_placement = (occupied_max_x + 1 - tile_min_x, 0)

        packed_tile = move_tile(tile, *best_placement)
        mark_tile(packed_tile, occupied, add_padding=True)
        packed_tiles.append(packed_tile)

    return packed_tiles


def plot_tile(ax, tile, color):
    """Draw every cell of ``tile`` on ``ax`` as a unit square centred on the cell.

    Squares are filled with ``color`` and outlined in thick black.
    """
    square_style = dict(facecolor=color, edgecolor="black", linewidth=3.5)
    for cx, cy in tile:
        square = plt.Rectangle((cx - 0.5, cy - 0.5), 1, 1, **square_style)
        ax.add_patch(square)


def plot_answer(ax, answer, colors):
    """Pack the tiles of ``answer`` and draw them on ``ax``.

    Tile ``i`` is drawn in ``colors[i % len(colors)]``. The axes are fitted
    to the drawn cells with one unit of margin, use equal aspect, and are
    hidden.
    """
    packed = pack_tileset(answer)

    all_x, all_y = [], []
    for idx, tile in enumerate(packed):
        all_x += [cx for cx, _ in tile]
        all_y += [cy for _, cy in tile]
        plot_tile(ax, tile, colors[idx % len(colors)])

    margin = 1
    ax.set_xlim(min(all_x) - margin, max(all_x) + margin)
    ax.set_ylim(min(all_y) - margin, max(all_y) + margin)
    ax.set_aspect("equal")
    ax.axis("off")


def plot_board_and_answers(
    board, correct_answer, fake_answers, num_tiles=3, color_map="viridis"
):
    """Render the target board next to the candidate answers.

    The figure uses a 2x4 grid. The board fills the left 2x2 cells, and up to
    four answers fill the right 2x2 cells, labelled A-D. The correct answer
    is swapped with a uniformly random slot.

    Returns:
        ``(fig, label)`` where ``label`` is the letter of the correct slot.
    """
    fig = plt.figure(figsize=(12, 8))
    grid = fig.add_gridspec(2, 4)

    palette = plt.colormaps[color_map]
    colors = [palette(k / num_tiles) for k in range(num_tiles)]

    ax_board = fig.add_subplot(grid[0:2, 0:2])
    ax_board.set_title("Board", fontsize=16)

    board_xs = [x for x, _ in board.squares]
    board_ys = [y for _, y in board.squares]
    x_hi, y_hi = max(board_xs), max(board_ys)
    x_lo, y_lo = min(board_xs), min(board_ys)

    for x, y in board.squares:
        cell = plt.Rectangle(
            (x - 0.5, y - 0.5),
            1,
            1,
            facecolor="white",
            edgecolor="black",
            linewidth=2,
        )
        ax_board.add_patch(cell)

    ax_board.set_xlim(x_lo - 1, x_hi + 1)
    ax_board.set_ylim(y_lo - 1, y_hi + 1)
    ax_board.set_aspect("equal")
    ax_board.axis("off")

    labels = ["A", "B", "C", "D"]
    slots = [(0, 2), (0, 3), (1, 2), (1, 3)]

    all_answers = [correct_answer] + fake_answers

    # The correct answer starts in slot 0; swap it into a random slot.
    correct_idx = np.random.randint(0, len(all_answers))
    all_answers[0], all_answers[correct_idx] = all_answers[correct_idx], all_answers[0]

    for label, (row, col), answer in zip(labels, slots, all_answers):
        ax = fig.add_subplot(grid[row, col])
        ax.set_title(label, fontsize=16)
        plot_answer(ax, answer, colors)

    plt.tight_layout()
    return fig, labels[correct_idx]


def generate(variables):
    """Generate one puzzle image and report which option is correct.

    Args:
        variables: Mapping of template variables. Reads ``"NUM_TILES"``
            (default 3) and ``"COLOR_MAP"`` (default ``"viridis"``).

    Returns:
        ``{"IMAGE": path, "CORRECT": label}``. ``path`` is ``render.png``
        inside a fresh directory from ``tempfile.mkdtemp()``. This function
        does not delete that directory.
    """
    num_tiles = variables.get("NUM_TILES", 3)
    color_map = variables.get("COLOR_MAP", "viridis")

    board, correct_answer, fake_answers = generate_task_data(
        max_attempts=20, num_tiles=num_tiles
    )

    fig, correct_label = plot_board_and_answers(
        board, correct_answer, fake_answers, num_tiles, color_map
    )

    temp_dir = tempfile.mkdtemp()
    output_path = os.path.join(temp_dir, "render.png")
    try:
        # Save and close the figure we built, not whatever pyplot currently
        # treats as "current". The figure is released even if saving fails.
        fig.savefig(output_path)
    finally:
        plt.close(fig)

    return {"IMAGE": output_path, "CORRECT": correct_label}


if __name__ == "__main__":
    # Generate a single sample puzzle with default settings and report it.
    variables = {"NUM_TILES": 3, "COLOR_MAP": "viridis"}
    result = generate(variables)
    image_path, correct_label = result["IMAGE"], result["CORRECT"]
    print(f"Generated image: {image_path}")
    print(f"Correct answer: {correct_label}")
