import os
import random
import tempfile
import numpy as np
import matplotlib.pyplot as plt
from app.generator.renderer.voxel import VoxelRenderer


def place_random_cuboid(grid, max_box_size=(2, 2, 2)):
    """Fill one randomly sized, randomly placed axis-aligned cuboid of *grid* with 1.

    Args:
        grid: 3-D array indexed ``(z, y, x)``; modified in place.
        max_box_size: Upper bounds ``(dz, dy, dx)`` on the cuboid's edge
            lengths. Each bound is clamped to the matching grid dimension,
            so an oversized request can never push the box outside the grid.

    Returns:
        None. The cuboid is written directly into *grid*.
    """
    z_max, y_max, x_max = grid.shape
    if 0 in (z_max, y_max, x_max):
        # An empty grid has no cell to fill; randint(1, 0) would raise.
        return

    # Sizes are drawn before offsets (dz, dy, dx, then z0, y0, x0) so seeded
    # runs consume the random stream in the same order as before. Clamping
    # keeps ``dim - size`` non-negative for the offset draws below.
    dz = random.randint(1, min(max_box_size[0], z_max))
    dy = random.randint(1, min(max_box_size[1], y_max))
    dx = random.randint(1, min(max_box_size[2], x_max))

    z0 = random.randint(0, z_max - dz)
    y0 = random.randint(0, y_max - dy)
    x0 = random.randint(0, x_max - dx)

    grid[z0 : z0 + dz, y0 : y0 + dy, x0 : x0 + dx] = 1


def generate_shape_by_boxes(
    grid_size=5, num_boxes=4, max_box_size=(2, 2, 2), seed=None
):
    """Return a cubic voxel grid built from ``num_boxes`` random cuboids.

    Args:
        grid_size: Edge length of the cubic grid.
        num_boxes: Number of cuboids stamped into the grid. Boxes may
            overlap, so the filled volume can be smaller than their sum.
        max_box_size: Per-axis upper bounds on each cuboid's edge lengths,
            forwarded unchanged to ``place_random_cuboid``.
        seed: When not ``None``, reseeds both the ``random`` module and
            NumPy's global generator before any box is drawn.

    Returns:
        ``numpy.uint8`` array of shape ``(grid_size, grid_size, grid_size)``
        with 1 in every cell covered by some cuboid and 0 elsewhere.
    """
    if seed is not None:
        # Reseeding mutates process-wide RNG state (both generators).
        for reseed in (random.seed, np.random.seed):
            reseed(seed)

    voxels = np.zeros((grid_size,) * 3, dtype=np.uint8)
    for _box_index in range(num_boxes):
        place_random_cuboid(voxels, max_box_size)
    return voxels


def _occupancy_projections(model):
    """Return, per axis of *model*, the boolean mask of lines containing a voxel."""
    occupied = model != 0
    return [occupied.any(axis=axis) for axis in range(occupied.ndim)]


def _same_projections(a, b):
    """Return True if *a* and *b* have identical occupancy projections on every axis."""
    return all(
        np.array_equal(pa, pb)
        for pa, pb in zip(_occupancy_projections(a), _occupancy_projections(b))
    )


def _perturb_model(correct_model):
    """Return a copy of *correct_model* after 2-5 random single-voxel edits.

    Each edit is, with roughly equal probability, either "set a random cell
    to 1" (a no-op if that cell is already filled) or "clear a random filled
    cell" (a no-op if the model is empty).
    """
    model = correct_model.copy()
    for _ in range(random.randint(2, 5)):
        if random.random() > 0.5:
            cell = tuple(random.randint(0, dim - 1) for dim in model.shape)
            if model[cell] == 0:
                model[cell] = 1
        else:
            filled = np.argwhere(model == 1)
            if len(filled) > 0:
                model[tuple(filled[random.randint(0, len(filled) - 1)])] = 0
    return model


def generate_incorrect_models(correct_model, num_models=3, max_attempts=100):
    """Generate ``num_models`` distractor voxel models derived from *correct_model*.

    Candidates are random perturbations of *correct_model*. A candidate is
    accepted when both of these hold:

    * it is not voxel-identical to the correct model or to any distractor
      already chosen, so no two answer options are the same; and
    * its occupancy projection along at least one axis differs from the
      correct model's, so it cannot be mistaken for the correct answer.

    If no such candidate turns up within ``max_attempts`` draws, the first
    voxel-distinct candidate is used. If even that never appears, the last
    draw is used, so the function always returns ``num_models`` models.

    Args:
        correct_model: Voxel array (0 = empty, 1 = filled). Not modified.
        num_models: Number of distractors to return.
        max_attempts: Maximum perturbations tried per distractor.

    Returns:
        List of ``num_models`` arrays with the same shape and dtype as
        *correct_model*.
    """
    # NOTE(review): the projection test assumes the top/front/right views are
    # axis-aligned orthographic renderings of the model. Under that assumption,
    # a differing occupancy projection is visible in at least one view.
    incorrect_models = []
    for _ in range(num_models):
        taken = [correct_model, *incorrect_models]
        chosen = None
        fallback = None
        candidate = None
        for _ in range(max(1, max_attempts)):
            candidate = _perturb_model(correct_model)
            if any(np.array_equal(candidate, other) for other in taken):
                continue  # would duplicate an existing answer option
            if fallback is None:
                fallback = candidate
            if not _same_projections(candidate, correct_model):
                chosen = candidate
                break
        if chosen is None:
            chosen = fallback if fallback is not None else candidate
        incorrect_models.append(chosen)

    return incorrect_models


def generate(variables):
    """Render a three-view voxel puzzle to a PNG and return it with its answer.

    The figure's top row shows the "top", "front" and "right" 2-D views of a
    randomly generated correct model. The bottom row shows four 3-D options
    labelled A-D. One option is the correct model and the other three come
    from ``generate_incorrect_models``.

    Args:
        variables: Mapping read with ``.get``. ``"GRID_SIZE"`` (default 5) is
            the cubic grid edge length and ``"COLOR_MAP"`` (default
            ``"plasma"``) is passed to the renderer.

    Returns:
        Dict with ``"IMAGE"``, the path of ``render.png`` inside a freshly
        created temporary directory, and ``"CORRECT"``, the label ("A"-"D")
        of the correct option. The temporary directory is not removed here;
        cleaning it up is left to the caller.
    """
    grid_size = variables.get("GRID_SIZE", 5)
    color_map = variables.get("COLOR_MAP", "plasma")
    labels = ["A", "B", "C", "D"]

    correct_model = generate_shape_by_boxes(
        grid_size, num_boxes=6, max_box_size=(2, 2, 2)
    )
    incorrect_models = generate_incorrect_models(correct_model, num_models=3)

    # After shuffling, indices[0] is the slot of the correct model and
    # indices[1:] receive the distractors in order.
    indices = list(range(4))
    random.shuffle(indices)
    correct_label = labels[indices[0]]

    all_models = [None] * 4
    for slot, model in zip(indices, [correct_model, *incorrect_models]):
        all_models[slot] = model

    renderer = VoxelRenderer()
    fig = plt.figure(figsize=(18, 10))
    try:
        # 5x12 grid: rows 0-1 hold three 4-column 2-D views and rows 2-4
        # hold four 3-column 3-D answer options.
        gs = fig.add_gridspec(5, 12)

        views = (("top", "Top View"), ("front", "Front View"), ("right", "Right View"))
        for col, (view, title) in enumerate(views):
            ax = fig.add_subplot(gs[0:2, 4 * col : 4 * col + 4])
            renderer._render_2d(correct_model, ax, view, color_map)
            ax.set_title(title, fontsize=16)

        for col, (label, model) in enumerate(zip(labels, all_models)):
            ax = fig.add_subplot(gs[2:5, 3 * col : 3 * col + 3], projection="3d")
            renderer._render_3d(model, ax, color_map)
            ax.set_xlabel("")
            ax.set_ylabel("")
            ax.set_title(label, fontsize=16)

        output_path = os.path.join(tempfile.mkdtemp(), "render.png")
        fig.tight_layout()
        fig.savefig(output_path)
    finally:
        # Always release the pyplot-managed figure, even if rendering fails.
        plt.close(fig)

    return {"IMAGE": output_path, "CORRECT": correct_label}
