import os
import random
import tempfile
import numpy as np
import matplotlib.pyplot as plt
from app.utils.voxel_utils import voxel_depth
from app.generator.renderer.voxel import VoxelRenderer


def generate_different_views_shape(size=5):
    """Build a random ``size x size x size`` 0/1 voxel grid.

    Two independent random 2D masks are drawn: one indexed by (x, z) and one
    indexed by (y, z). A voxel at (x, y, z) is filled only when both masks are
    set at their respective cells, so the solid is the intersection of two
    extruded silhouettes.

    Args:
        size: edge length of the cube.

    Returns:
        numpy int array of shape (size, size, size) containing 0s and 1s.
    """

    def _random_mask():
        # One random.random() call per cell in row-major order; seeded runs
        # therefore produce the same masks as a nested i/j loop would.
        cells = [
            [1 if random.random() > 0.5 else 0 for _col in range(size)]
            for _row in range(size)
        ]
        # reshape keeps the (0, 0) shape when size == 0.
        return np.array(cells, dtype=int).reshape(size, size)

    side_mask = _random_mask()   # indexed [x, z]
    face_mask = _random_mask()   # indexed [y, z]

    # Broadcast to (x, y, z): filled iff side_mask[x, z] and face_mask[y, z].
    combined = side_mask[:, np.newaxis, :] * face_mask[np.newaxis, :, :]
    return combined.astype(int)


def verify_different_views(voxel_array):
    """Return True when the top, right and front depth maps are pairwise distinct.

    All three views are computed first via ``voxel_depth``. They are then
    compared in the order top/right, top/front, right/front, stopping at the
    first identical pair.
    """
    views = [
        voxel_depth(voxel_array, direction)
        for direction in ("top", "right", "front")
    ]

    for first_idx, first_view in enumerate(views):
        for second_view in views[first_idx + 1:]:
            if np.array_equal(first_view, second_view):
                return False
    return True


def check_complexity(voxel_array, min_voxels=10):
    """Return True if the voxel shape is "interesting" enough to use.

    A shape is rejected when:
      * it has fewer than ``min_voxels`` filled voxels, or
      * any of its top/right/front depth maps has no non-zero cells, or
      * every non-zero cell in such a depth map has the same value
        (a flat, uninformative view).

    Args:
        voxel_array: voxel grid passed to ``voxel_depth``.
        min_voxels: minimum ``np.sum(voxel_array)`` required (default 10,
            the previously hard-coded threshold).

    Returns:
        bool.
    """
    if np.sum(voxel_array) < min_voxels:
        return False

    for direction in ("top", "right", "front"):
        view = np.asarray(voxel_depth(voxel_array, direction))
        occupied = view[view != 0]
        # Guard the empty case: indexing occupied[0] on an empty selection
        # raised IndexError before. An empty view is treated as trivial.
        if occupied.size == 0 or np.all(occupied == occupied[0]):
            return False

    return True


def generate_and_verify(size=5, max_attempts=30):
    """Generate random voxel shapes until one passes both quality checks.

    Each attempt calls ``generate_different_views_shape(size)`` and accepts
    the result if ``verify_different_views`` and ``check_complexity`` both
    return True.

    Args:
        size: edge length of the voxel cube.
        max_attempts: number of candidates to try. Values below 1 are
            treated as 1, so a shape is always produced (previously this
            raised UnboundLocalError).

    Returns:
        The first passing voxel array. If no candidate passes, the last
        candidate is returned anyway, even though it failed the checks.
    """
    for _ in range(max(1, max_attempts)):
        voxel_array = generate_different_views_shape(size)
        if verify_different_views(voxel_array) and check_complexity(voxel_array):
            return voxel_array

    # Best-effort fallback: callers always receive a shape.
    return voxel_array


def generate(variables):
    """Render a voxel shape and a labelled grid of its 2D projections.

    Args:
        variables: mapping read with ``.get`` for these keys:
            "NUM_VOXELS": cube edge length (default 4).
            "COLOR_MAP": passed through to the renderer.
            "PROJECTIONS": sequence of projection names; each gets its own
                panel, laid out three per row below the 3D view.
            "TARGET_PROJECTION": the projection whose label is the answer;
                must be an element of PROJECTIONS.

    Returns:
        dict with
            "IMAGE": path of the saved PNG inside a new temporary directory,
            "PROJECTION": the target projection,
            "CORRECT": the letter shown above the target projection's panel.

    Raises:
        ValueError: if TARGET_PROJECTION is not in PROJECTIONS.
    """
    size = variables.get("NUM_VOXELS", 4)
    color_map = variables.get("COLOR_MAP")
    projections = variables.get("PROJECTIONS")
    target_projection = variables.get("TARGET_PROJECTION")

    target_idx = projections.index(target_projection)

    voxel_array = generate_and_verify(size)

    # One letter per projection ("A", "B", ...). The old fixed A-F list
    # raised IndexError for more than six projections.
    labels = [chr(ord("A") + i) for i in range(len(projections))]
    correct_label = labels[target_idx]

    renderer = VoxelRenderer()
    fig = plt.figure(figsize=(15, 10))

    cols = 3
    # Ceiling division: a partially filled last row still needs a grid row.
    # Floor division dropped trailing projections (e.g. 4 or 5 of them),
    # which could leave the target projection unrendered.
    projection_rows = -(-len(projections) // cols)
    gs = fig.add_gridspec(2 + projection_rows, cols)

    ax_3d = fig.add_subplot(gs[0:2, 0:cols], projection="3d")
    renderer._render_3d(voxel_array, ax_3d, color_map)

    # Create axes only for real projections so no empty panels appear.
    for idx, (proj, label) in enumerate(zip(projections, labels)):
        row, col = divmod(idx, cols)
        ax = fig.add_subplot(gs[2 + row, col])
        renderer._render_2d(voxel_array, ax, proj, color_map)
        ax.set_title(label, fontsize=20)

    temp_dir = tempfile.mkdtemp()
    output_pattern = os.path.join(temp_dir, "render.png")

    # Operate on this figure explicitly rather than on pyplot's "current
    # figure", which the renderer could in principle change.
    fig.tight_layout()
    fig.savefig(output_pattern)
    plt.close(fig)

    return {
        "IMAGE": output_pattern,
        "PROJECTION": target_projection,
        "CORRECT": correct_label,
    }
