from itertools import count
import os
import random
import tempfile
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.patches import Polygon
from mpl_toolkits.mplot3d.art3d import Poly3DCollection


def generate_pyramid(base_ngon):
    """Build the faces of a pyramid over a regular ``base_ngon``-sided base.

    Base vertices lie on the unit circle in the z=0 plane, starting at angle 0
    and evenly spaced; the apex sits at (0, 0, 1).

    Args:
        base_ngon: Number of sides of the base polygon.

    Returns:
        A list whose first element is the base polygon (list of vertex
        tuples) followed by one triangle ``[apex, v_k, v_{k+1}]`` per side,
        with the last side wrapping back to the first vertex.
    """
    thetas = np.linspace(0, 2 * np.pi, base_ngon, endpoint=False)
    ring = [(np.cos(theta), np.sin(theta), 0) for theta in thetas]
    tip = (0, 0, 1)
    # One lateral triangle per base edge; modulo closes the loop.
    sides = [[tip, ring[k], ring[(k + 1) % base_ngon]] for k in range(base_ngon)]
    return [ring, *sides]


def render_pyramid_3d(ax, faces, edge_colors):
    """Draw the pyramid as a coloured wireframe on a 3D axes.

    Edge colours are taken from ``edge_colors`` in this order: the base
    edges (vertex k to vertex k+1, wrapping), then the apex-to-vertex edges
    (apex to vertex k). Indices wrap modulo ``len(edge_colors)``.

    Args:
        ax: A 3D matplotlib axes (``projection="3d"``).
        faces: Output of ``generate_pyramid``: base polygon first, then the
            lateral triangles whose first vertex is the apex.
        edge_colors: Sequence of matplotlib colour specs.
    """
    ring = faces[0]
    n_sides = len(ring)
    tip = faces[1][0]
    n_colors = len(edge_colors)

    ring_segments = [[ring[k], ring[(k + 1) % n_sides]] for k in range(n_sides)]
    spoke_segments = [[tip, vertex] for vertex in ring]

    ax.view_init(elev=25, azim=180 / n_sides)

    # Base edges occupy colour slots 0..n_sides-1, spokes the next n_sides.
    for idx, segment in enumerate(ring_segments + spoke_segments):
        xs, ys, zs = zip(*segment)
        ax.plot(xs, ys, zs, color=edge_colors[idx % n_colors], linewidth=3.5)

    # Fully transparent, edgeless face patches added for every face.
    for face in faces:
        shell = Poly3DCollection([face], linewidths=0)
        shell.set_facecolor((1, 1, 1, 0))
        ax.add_collection3d(shell)

    ax.set_box_aspect([1, 1, 1])
    ax.set_axis_off()
    ax.set_xlim(-1, 1)
    ax.set_ylim(-1, 1)
    ax.set_zlim(0, 1)


def render_pyramid_projection(ax, faces, edge_colors):
    """Draw a top-down (z dropped) coloured view of the pyramid on 2D axes.

    The x/y coordinates are rotated by ``-90 - 180 / n`` degrees, where n is
    the number of base vertices. Edge colours follow the same slot order as
    the 3D renderer: base edges first, then apex-to-vertex edges, with
    indices wrapping modulo ``len(edge_colors)``.

    Args:
        ax: A 2D matplotlib axes.
        faces: Output of ``generate_pyramid``: base polygon first, then the
            lateral triangles whose first vertex is the apex.
        edge_colors: Sequence of matplotlib colour specs.
    """
    ax.set_aspect("equal")
    ax.axis("off")

    ring = faces[0]
    tip = faces[1][0]
    n_sides = len(ring)
    n_colors = len(edge_colors)

    theta = np.radians(-90 - 180 / n_sides)
    cos_t = np.cos(theta)
    sin_t = np.sin(theta)

    def rotate(px, py):
        # Standard counter-clockwise 2D rotation by theta.
        return (px * cos_t - py * sin_t, px * sin_t + py * cos_t)

    flat_ring = [rotate(px, py) for px, py, _ in ring]
    flat_tip = rotate(tip[0], tip[1])

    # Pair each vertex with its successor, wrapping the last back to the first.
    successors = flat_ring[1:] + flat_ring[:1]
    for k, (start, end) in enumerate(zip(flat_ring, successors)):
        ax.plot(
            [start[0], end[0]],
            [start[1], end[1]],
            color=edge_colors[k % n_colors],
            linewidth=3.5,
        )

    for k, (vx, vy) in enumerate(flat_ring):
        ax.plot(
            [flat_tip[0], vx],
            [flat_tip[1], vy],
            color=edge_colors[(k + n_sides) % n_colors],
            linewidth=3.5,
        )

    # White base fill drawn beneath the edges.
    ax.fill(*zip(*flat_ring), color="white", zorder=0)

    ax.set_xlim(-1, 1)
    ax.set_ylim(-1, 1)


def generate_fake_colors(edge_colors):
    """Return a shuffled copy of ``edge_colors`` in a different order.

    Args:
        edge_colors: Sequence of colour specs, one per edge. It is not
            modified.

    Returns:
        A new list with the same colours (same multiset) whose order differs
        from ``edge_colors``.

    Raises:
        ValueError: If every entry is equal (including empty or
            single-element input). No reordering can differ in that case, so
            the shuffle-and-retry loop would never terminate.
    """
    original = list(edge_colors)
    # all() over an empty list is True, so [] is rejected without indexing.
    if all(color == original[0] for color in original):
        raise ValueError(
            "edge_colors needs at least two distinct colours to produce "
            "a different ordering"
        )
    while True:
        fake_colors = original.copy()
        random.shuffle(fake_colors)
        if fake_colors != original:
            return fake_colors


def generate(variables):
    """Render a pyramid edge-colour matching puzzle to a PNG file.

    The top half of the figure shows the pyramid in 3D with randomly
    coloured edges. The bottom half shows six labelled top-down projections
    (A-F). Exactly one of them uses the same edge colouring as the 3D view;
    the other five are distinct reorderings of those colours.

    Args:
        variables: Mapping of options. ``"BASE_NGON"`` (default 4) is the
            number of sides of the pyramid's base.

    Returns:
        dict with ``"IMAGE"``, the path to ``render.png`` inside a fresh
        temporary directory (the caller is responsible for cleanup), and
        ``"CORRECT"``, the label of the matching projection.

    Raises:
        ValueError: If ``BASE_NGON`` is below 3, or needs more edge colours
            than the palette can supply.
    """
    base_ngon = variables.get("BASE_NGON", 4)
    EDGE_COLORS = ["black", "blue", "red", "green", "magenta", "orange", "cyan"]
    # Each palette colour may be reused on at most this many edges.
    max_uses_per_color = 3

    # A pyramid has base_ngon base edges plus base_ngon apex edges.
    needed_colors = base_ngon * 2
    max_ngon = max_uses_per_color * len(EDGE_COLORS) // 2
    # Below 3 sides the pyramid is degenerate, and the colours may admit too
    # few distinct orderings (e.g. ["red", "red"]), which would make the fake
    # colour search spin forever. Above max_ngon, random.sample cannot draw
    # enough colours.
    if not 3 <= base_ngon <= max_ngon:
        raise ValueError(
            f"BASE_NGON must be between 3 and {max_ngon}, got {base_ngon!r}"
        )

    edge_colors = random.sample(
        EDGE_COLORS,
        counts=[max_uses_per_color] * len(EDGE_COLORS),
        k=needed_colors,
    )

    faces = generate_pyramid(base_ngon)

    labels = ["A", "B", "C", "D", "E", "F"]
    # Gridspec cells (row, col) of the answer panels, below the 3D view.
    positions = [(2, 0), (2, 1), (2, 2), (3, 0), (3, 1), (3, 2)]
    correct_idx = random.randint(0, len(labels) - 1)
    correct_label = labels[correct_idx]

    # One colouring per panel: the true one first, then distinct fakes.
    # Repeats are rejected so no two panels look identical. With at least 6
    # edges and at most 3 uses per colour, at least 2 distinct colours are
    # present, giving at least 6!/(3!3!) = 20 orderings, so this terminates.
    all_colors = [edge_colors]
    while len(all_colors) < len(labels):
        fake_colors = generate_fake_colors(edge_colors)
        if fake_colors not in all_colors:
            all_colors.append(fake_colors)

    # Move the true colouring into the randomly chosen answer slot.
    all_colors[0], all_colors[correct_idx] = all_colors[correct_idx], all_colors[0]

    fig = plt.figure(figsize=(18, 12))
    try:
        gs = fig.add_gridspec(4, 3)

        # Top two rows: the reference 3D view with the true colouring.
        ax_3d = fig.add_subplot(gs[0:2, :], projection="3d")
        render_pyramid_3d(ax_3d, faces, edge_colors)

        # Bottom two rows: the six labelled answer candidates.
        for (row, col), label, colors in zip(positions, labels, all_colors):
            ax = fig.add_subplot(gs[row, col])
            render_pyramid_projection(ax, faces, colors)
            ax.set_title(label, fontsize=20)

        output_path = os.path.join(tempfile.mkdtemp(), "render.png")
        fig.tight_layout()
        fig.savefig(output_path)
    finally:
        # Release the figure even if rendering or saving fails.
        plt.close(fig)

    return {"IMAGE": output_path, "CORRECT": correct_label}


if __name__ == "__main__":
    # Demo run: hexagonal-base pyramid; print where the image went and the answer.
    outcome = generate({"BASE_NGON": 6})
    image_path, answer = outcome["IMAGE"], outcome["CORRECT"]
    print(f"Generated image: {image_path}")
    print(f"Correct answer: {answer}")
