import os
import tempfile
import numpy as np
from vedo import Plotter, Box, Cylinder, Cone, Text2D, Arrow, Plane

# Side length passed to both floor-aligned planes (the visible floor and the
# plane that shadows are projected onto).
FLOOR_SIZE = 30
# RGB colour used for the floor plane and as the background of the arrow inset.
FLOOR_COLOR = [0.8, 1.0, 0.8]

# Gap, in normalized viewport coordinates, left around and between viewports.
VIEWPORT_PADDING = 0.01
# Width and height, in normalized viewport coordinates, of the arrow inset.
ARROW_VIEWPORT = 0.2

# Camera position shared by every renderer in the figure.
CAMERA_POSITION = np.array([7, 7, 4])
# z coordinate given to every candidate light ("sun") position.
HEIGHT = 4
# Radius of the circle around the z axis on which light positions are sampled.
DISTANCE = 5
# Light positions must be strictly farther apart than this (Euclidean distance).
MIN_SEPARATION = 3


def generate_box():
    """Create a white, unlit box with random dimensions and z-rotation.

    Each of the three edge lengths is drawn uniformly from [1, 2) and the
    rotation angle about the z axis uniformly from [0, 360).

    Returns:
        A ``(box, center_point)`` tuple; ``center_point`` is the ``[x, y, z]``
        list the box was positioned at.
    """
    dims = np.random.uniform(1, 2, 3)
    spin = np.random.uniform(0, 360)

    mesh = Box(size=dims)
    mesh.rotate_z(spin)

    # Raise the box by half of its z-extent (presumably so that its base
    # rests on the z = 0 floor -- this relies on Box being origin-centred).
    half_height = dims[2] / 2
    center_point = [0, 0, half_height]
    mesh.pos(center_point)

    mesh.color("white")
    mesh.lighting("off")
    return mesh, center_point


def generate_pyramid():
    """Create a white, unlit pyramid with a random base polygon and size.

    The pyramid is a low-resolution ``Cone``: its ``res`` argument is a random
    integer in {3, 4, 5}. Height is drawn from [1.5, 2.5), base radius from
    [1, 2) and the rotation angle about the z axis from [0, 360).

    Returns:
        A ``(pyramid, center_point)`` tuple; ``center_point`` is the
        ``[x, y, z]`` list the pyramid was positioned at.
    """
    # Draw order (height, base radius, side count, angle) is kept fixed so
    # that seeded runs stay reproducible.
    tall = np.random.uniform(1.5, 2.5)
    base_radius = np.random.uniform(1, 2)
    side_count = np.random.randint(3, 6)
    spin = np.random.uniform(0, 360)

    mesh = Cone(
        pos=(0, 0, 0),
        r=base_radius,
        height=tall,
        axis=(0, 0, 1),
        res=side_count,
    )
    mesh.rotate_z(spin)

    # Shift upwards by half the height (presumably to sit the base on z = 0).
    center_point = [0, 0, tall / 2]
    mesh.pos(center_point)

    mesh.color("white")
    mesh.lighting("off")
    return mesh, center_point


def generate_cone():
    """Create a white, unlit cone with random height, radius and z-rotation.

    Height is drawn uniformly from [1.5, 2.5), radius from [0.8, 1.5) and the
    rotation angle about the z axis from [0, 360). The cone is built with
    ``res=32``.

    Returns:
        A ``(cone, center_point)`` tuple; ``center_point`` is the
        ``[x, y, z]`` list the cone was positioned at.
    """
    tall = np.random.uniform(1.5, 2.5)
    base_radius = np.random.uniform(0.8, 1.5)
    spin = np.random.uniform(0, 360)

    mesh = Cone(
        pos=(0, 0, 0),
        r=base_radius,
        height=tall,
        axis=(0, 0, 1),
        res=32,
    )
    mesh.rotate_z(spin)

    # Shift upwards by half the height (presumably to sit the base on z = 0).
    center_point = [0, 0, tall / 2]
    mesh.pos(center_point)

    mesh.color("white")
    mesh.lighting("off")
    return mesh, center_point


def generate_cylinder():
    """Create a white, unlit cylinder with random height, radius and z-rotation.

    Height is drawn uniformly from [1.5, 2.5), radius from [0.8, 1.5) and the
    rotation angle about the z axis from [0, 360).

    Returns:
        A ``(cylinder, center_point)`` tuple; ``center_point`` is the
        ``[x, y, z]`` list the cylinder was positioned at.
    """
    tall = np.random.uniform(1.5, 2.5)
    base_radius = np.random.uniform(0.8, 1.5)
    spin = np.random.uniform(0, 360)

    mesh = Cylinder(
        pos=(0, 0, 0),
        r=base_radius,
        height=tall,
        axis=(0, 0, 1),
    )
    mesh.rotate_z(spin)

    # Shift upwards by half the height (presumably to sit the base on z = 0).
    center_point = [0, 0, tall / 2]
    mesh.pos(center_point)

    mesh.color("white")
    mesh.lighting("off")
    return mesh, center_point


def generate_sun_positions(
    count=4,
    distance=None,
    height=None,
    min_separation=None,
    max_attempts=10000,
):
    """Sample well-separated light positions on a horizontal circle.

    Each position is ``[distance*cos(t), distance*sin(t), height]`` for a
    uniformly random angle ``t``. A candidate is accepted only if it is
    strictly farther than ``min_separation`` (Euclidean distance) from every
    position accepted so far; otherwise a new angle is drawn.

    Args:
        count: Number of positions to generate.
        distance: Circle radius; defaults to the module-level ``DISTANCE``.
        height: z coordinate of every position; defaults to ``HEIGHT``.
        min_separation: Exclusive lower bound on pairwise distance; defaults
            to ``MIN_SEPARATION``.
        max_attempts: Maximum candidates drawn per position before giving up.

    Returns:
        A list of ``count`` NumPy arrays of shape ``(3,)``.

    Raises:
        RuntimeError: If no valid candidate is found for some position within
            ``max_attempts`` draws (e.g. the separation constraint cannot be
            satisfied on a circle of the given radius).
    """
    # Resolve defaults at call time so the module constants stay the single
    # source of truth for the default configuration.
    if distance is None:
        distance = DISTANCE
    if height is None:
        height = HEIGHT
    if min_separation is None:
        min_separation = MIN_SEPARATION

    positions = []
    for _ in range(count):
        for _attempt in range(max_attempts):
            theta = np.random.uniform(0, 2 * np.pi)
            candidate = np.array(
                [distance * np.cos(theta), distance * np.sin(theta), height]
            )
            # all() over an empty list is True, so the first candidate is
            # always accepted without a special case.
            if all(
                np.linalg.norm(candidate - placed) > min_separation
                for placed in positions
            ):
                positions.append(candidate)
                break
        else:
            # Bounded retry: fail loudly instead of spinning forever when the
            # constraints cannot be met.
            raise RuntimeError(
                f"Could not place sun position {len(positions) + 1} of {count} "
                f"with min_separation={min_separation} on a circle of radius "
                f"{distance} after {max_attempts} attempts"
            )

    return positions


def _aim_camera(camera, focal_point):
    """Place *camera* at ``CAMERA_POSITION`` looking at *focal_point*, z up."""
    camera.SetPosition(CAMERA_POSITION)
    camera.SetFocalPoint(focal_point)
    camera.SetViewUp(0, 0, 1)


def generate(variables):
    """Render a "which shadow matches the light direction" puzzle image.

    The figure uses six renderers:

    * renderer 0 (top half): the shape on the floor, without a shadow;
    * renderer 5 (inset over the lower-left corner of renderer 0): a red
      arrow derived from the x/y coordinates of the target light position;
    * renderers 1-4 (bottom half, labelled A-D): the shape plus a black
      shadow projected onto the floor plane from light position 0-3.

    Args:
        variables: Mapping with an optional ``"SHAPE"`` key, one of
            ``"box"``, ``"pyramid"``, ``"cone"`` or ``"cylinder"``
            (default ``"box"``).

    Returns:
        A dict with ``"IMAGE"`` (path of the PNG written into a fresh
        temporary directory, which the caller owns) and ``"CORRECT"`` (the
        label, ``"A"``-``"D"``, of the variant lit from the target position).

    Raises:
        ValueError: If the requested shape type is not recognised.
    """
    shape_type = variables.get("SHAPE", "box")

    shape_generators = {
        "box": generate_box,
        "pyramid": generate_pyramid,
        "cone": generate_cone,
        "cylinder": generate_cylinder,
    }
    # Validate before creating the plotter so that a bad request does not
    # allocate a render window. TypeError covers unhashable values.
    try:
        make_shape = shape_generators[shape_type]
    except (KeyError, TypeError):
        raise ValueError(f"Unknown shape type: {shape_type}") from None

    half_pad = VIEWPORT_PADDING / 2

    plt = Plotter(
        N=6, size=(1600, 1600), bg="lightblue", sharecam=False, offscreen=True
    )
    try:
        project_plane = Plane(s=(FLOOR_SIZE, FLOOR_SIZE), normal=(0, 0, 1))
        # The visible floor sits slightly below the projection plane,
        # presumably to avoid z-fighting with the shadows drawn at z = 0.
        floor = Plane(
            pos=(0, 0, -0.05), s=(FLOOR_SIZE, FLOOR_SIZE), normal=(0, 0, 1)
        ).color(FLOOR_COLOR)

        # Random draws happen in the order: shape, light positions, target
        # index. Keep this order so seeded runs stay reproducible.
        shape, center_point = make_shape()
        positions = generate_sun_positions()

        plt.renderers[0].SetViewport(
            VIEWPORT_PADDING,
            0.5 + half_pad,
            1 - VIEWPORT_PADDING,
            1 - VIEWPORT_PADDING,
        )

        target_idx = np.random.randint(0, 4)

        # Arrow inset, anchored to the lower-left corner of the main view.
        plt.renderers[5].SetViewport(
            VIEWPORT_PADDING,
            0.5 + half_pad,
            ARROW_VIEWPORT + VIEWPORT_PADDING,
            0.5 + half_pad + ARROW_VIEWPORT,
        )
        plt.at(5).background(FLOOR_COLOR)

        # Flatten the (scaled) target light position onto z = 0; the arrow
        # then runs from half of that vector to its mirror image through the
        # origin. The multiplication yields a copy, so positions is untouched.
        arrow_start = positions[target_idx] * 0.7
        arrow_start[2] = 0
        arrow_end = np.array([0, 0, 0]) - arrow_start / 2
        arrow_start /= 2

        arrow = Arrow(arrow_start, arrow_end, c="red")
        _aim_camera(plt.at(5).camera, center_point)
        plt.at(5).add(arrow)

        plt.at(0).add(floor, shape)

        # NOTE(review): vedo's silhouette() presumably depends on the active
        # renderer's camera, so it is created only after plt.at(...) has
        # selected the renderer -- confirm before reordering these calls.
        silhouette = shape.silhouette().linewidth(5)
        plt.at(0).add(silhouette)

        _aim_camera(plt.at(0).camera, center_point)

        # 2 x 2 grid of answer viewports filling the bottom half:
        # A (upper-left), B (upper-right), C (lower-left), D (lower-right).
        viewports = [
            (
                VIEWPORT_PADDING,
                0.25 + half_pad,
                0.5 - half_pad,
                0.5 - half_pad,
            ),
            (
                0.5 + half_pad,
                0.25 + half_pad,
                1 - VIEWPORT_PADDING,
                0.5 - half_pad,
            ),
            (
                VIEWPORT_PADDING,
                VIEWPORT_PADDING,
                0.5 - half_pad,
                0.25 - half_pad,
            ),
            (
                0.5 + half_pad,
                VIEWPORT_PADDING,
                1 - VIEWPORT_PADDING,
                0.25 - half_pad,
            ),
        ]

        variant_labels = ["A", "B", "C", "D"]

        for i, variant_label in enumerate(variant_labels):
            idx = i + 1  # renderer 0 is the main view; variants use 1-4

            plt.renderers[idx].SetViewport(*viewports[i])
            _aim_camera(plt.at(idx).camera, center_point)

            plt.at(idx).add(floor, shape)

            silhouette = shape.silhouette().linewidth(5)
            plt.at(idx).add(silhouette)

            # Variant i shows the shadow projected from light position i.
            shadow = shape.clone().project_on_plane(
                project_plane, point=positions[i]
            )
            shadow.color("black")
            plt.at(idx).add(shadow)

            title = Text2D(
                variant_label,
                pos="top-left",
                c="black",
                bg="white",
                font="ComicMono",
                alpha=1.0,
                s=3.5,
            )
            plt.at(idx).add(title)

        correct_answer = variant_labels[target_idx]

        temp_dir = tempfile.mkdtemp()
        output_path = os.path.join(temp_dir, "render.png")

        # Cameras were positioned explicitly above, so do not reset them.
        plt.render(
            resetcam=False,
        )
        plt.screenshot(output_path)
    finally:
        # Release the offscreen render window; without this every call
        # leaves a Plotter (and its window) alive.
        plt.close()

    return {"IMAGE": output_path, "CORRECT": correct_answer}


if __name__ == "__main__":
    # Manual smoke run: render one pyramid puzzle and report the outcome.
    demo = generate({"SHAPE": "pyramid"})
    answer, image_path = demo["CORRECT"], demo["IMAGE"]
    print(f"Correct answer: {answer}")
    print(f"Image saved to: {image_path}")
