import os
import math
import tempfile
import numpy as np
from vedo import (
    Plotter,
    Box,
    Cylinder,
    Plane,
    Arrow,
    Text3D,
    Text2D,
    settings,
    color_map,
)

# Range sampled for the x and z coordinates of random object positions.
OBJECT_MIN_POS, OBJECT_MAX_POS = -5, 5

# Range sampled for the x and z coordinates of random camera positions.
CAMERA_MIN_POS, CAMERA_MAX_POS = -5, 5

# Minimum Euclidean distance between two camera positions, and between a
# camera position and any object's position.
MIN_CAMERA_DISTANCE, MIN_OBJECT_DISTANCE = 2.0, 2.0

# Gap (in normalised viewport units) around each viewport, and the factor
# applied to colormap colours before clipping them back into [0, 1].
VIEWPORT_PADDING, COLOR_BRIGHTNESS = 0.01, 1.1


def get_color(color_idx, cmap_name):
    """Sample a colormap, brighten the result and clamp it to valid colour range.

    Args:
        color_idx: scalar position in the colormap, interpreted over [0, 1].
        cmap_name: name of the colormap passed to vedo's ``color_map``.

    Returns:
        The sampled colour multiplied by ``COLOR_BRIGHTNESS`` and clipped to
        [0, 1] (presumably an RGB NumPy array — depends on ``color_map``).
    """
    sampled = color_map(color_idx, name=cmap_name, vmin=0, vmax=1)
    brightened = sampled * COLOR_BRIGHTNESS
    # Brightening can push channels above 1; clamp back into range.
    return np.clip(brightened, 0, 1)


def create_random_box(cmap_name):
    """Return an unlit, opaque box with random size, heading, position and colour.

    Each side length is drawn from [1, 2). The box is rotated about the y axis
    by a random angle in [0, 360) degrees. It is then moved to a random x/z
    within [OBJECT_MIN_POS, OBJECT_MAX_POS], with y set to half its y-size so
    that it sits on the y = 0 floor.

    Args:
        cmap_name: colormap name used to pick the box colour.
    """
    dims = np.random.uniform(1, 2, 3)
    box = Box(size=dims)
    box.rotate_y(np.random.uniform(0, 360))

    # Lift the box by half its y-extent so its base rests on y = 0.
    x = np.random.uniform(OBJECT_MIN_POS, OBJECT_MAX_POS)
    z = np.random.uniform(OBJECT_MIN_POS, OBJECT_MAX_POS)
    box.pos(np.array([x, dims[1] / 2, z]))

    box.color(get_color(np.random.random(), cmap_name))
    box.alpha(1.0)
    box.lighting("off")
    return box


def create_random_cylinder(cmap_name):
    """Return an unlit, opaque, upright cylinder with random size, position and colour.

    The radius is drawn from [0.8, 1.2) and the height from [1, 2). The axis
    is the y axis. The position is a random x/z within
    [OBJECT_MIN_POS, OBJECT_MAX_POS], with y set to half the height. This
    assumes vedo's ``pos`` is the cylinder centre, so the base rests on y = 0.

    Args:
        cmap_name: colormap name used to pick the cylinder colour.
    """
    r = np.random.uniform(0.8, 1.2)
    h = np.random.uniform(1.0, 2)
    cx = np.random.uniform(OBJECT_MIN_POS, OBJECT_MAX_POS)
    cz = np.random.uniform(OBJECT_MIN_POS, OBJECT_MAX_POS)
    center = np.array([cx, h / 2, cz])

    cylinder = Cylinder(pos=center, r=r, height=h, axis=(0, 1, 0))
    cylinder.color(get_color(np.random.random(), cmap_name))
    cylinder.alpha(1.0)
    cylinder.lighting("off")
    return cylinder


def create_floor():
    """Return a light-grey 12 x 12 ground plane at the origin with a +y normal."""
    light_grey = [0.9] * 3
    ground = Plane(pos=(0, 0, 0), normal=(0, 1, 0), s=(12, 12))
    ground.color(light_grey)
    return ground


def check_collision(new_obj, existing_objects, margin=0.0):
    """Return True if ``new_obj``'s bounding box overlaps any existing object's.

    Bounds come from each object's ``bounds()`` and are compared as
    axis-aligned boxes ``(xmin, xmax, ymin, ymax, zmin, zmax)``. The
    comparisons are strict, so boxes that merely touch do not collide. For
    rotated meshes the axis-aligned box is looser than the mesh itself, so
    this test is conservative.

    Args:
        new_obj: candidate object exposing ``bounds()``.
        existing_objects: iterable of objects exposing ``bounds()``.
        margin: extra clearance required on every axis. The default 0.0 is a
            plain overlap test.

    Returns:
        True on the first overlapping object, otherwise False.
    """
    # The candidate's bounds do not change between iterations; fetch them once.
    new_bounds = new_obj.bounds()

    for obj in existing_objects:
        obj_bounds = obj.bounds()
        # Overlap on all three axes (indices 0/1 = x, 2/3 = y, 4/5 = z).
        if all(
            new_bounds[lo] < obj_bounds[lo + 1] + margin
            and new_bounds[lo + 1] > obj_bounds[lo] - margin
            for lo in (0, 2, 4)
        ):
            return True

    return False


def place_object_without_collision(
    create_func, existing_objects, cmap_name, max_tries=20
):
    """Create candidate objects until one does not collide with the scene.

    Args:
        create_func: callable taking ``cmap_name`` and returning a new object.
        existing_objects: objects already placed, checked via ``check_collision``.
        cmap_name: colormap name forwarded to ``create_func``.
        max_tries: number of attempts. Values below 1 are treated as 1 so an
            object is always produced. Previously this raised
            UnboundLocalError.

    Returns:
        The first non-colliding candidate. If every attempt collides, the last
        candidate is returned anyway as a best effort, so the result may
        still overlap existing objects.
    """
    candidate = None
    for _ in range(max(1, max_tries)):
        candidate = create_func(cmap_name)
        if not check_collision(candidate, existing_objects):
            return candidate

    # Best effort: accept the last (colliding) candidate rather than fail.
    return candidate


def generate_scene(box_count, cylinder_count, cmap_name):
    """Build the random scene: ``box_count`` boxes, then ``cylinder_count`` cylinders.

    Each new object is placed with up to 20 attempts to avoid every object
    placed before it (best effort; see ``place_object_without_collision``).

    Returns:
        list of scene objects, boxes first, in placement order.
    """
    placed = []
    plan = (
        (create_random_box, box_count),
        (create_random_cylinder, cylinder_count),
    )
    for factory, how_many in plan:
        for _ in range(how_many):
            placed.append(
                place_object_without_collision(
                    factory, placed, cmap_name, max_tries=20
                )
            )
    return placed


def create_camera_point(existing_objects, existing_cameras=None, max_tries=50):
    """Pick a random camera position at height 1 and a horizontal view direction.

    Candidates are sampled uniformly in x/z within
    [CAMERA_MIN_POS, CAMERA_MAX_POS] at y = 1.0. A candidate is rejected if
    either of these holds:

    - it is closer than MIN_CAMERA_DISTANCE to an existing camera;
    - it is closer than MIN_OBJECT_DISTANCE (3-D Euclidean) to any object's
      ``pos()``.

    The view direction points from the candidate toward the origin in the
    xz-plane. It is rotated in that plane by a random angle in [-45°, 45°]
    and then normalised.

    Args:
        existing_objects: scene objects exposing ``pos()``.
        existing_cameras: previously placed ``(position, direction)`` pairs.
            ``None`` (the default) means no cameras yet.
        max_tries: number of candidates to sample before giving up.

    Returns:
        ``(position, direction)`` as NumPy arrays. If no candidate is accepted
        within ``max_tries``, the fixed fallback ``([0, 1, 0], [1, 0, 0])`` is
        returned. It ignores both spacing constraints.
    """
    # Avoid a shared mutable default argument.
    if existing_cameras is None:
        existing_cameras = []

    for _ in range(max_tries):
        pos = np.array(
            [
                np.random.uniform(CAMERA_MIN_POS, CAMERA_MAX_POS),
                1.0,
                np.random.uniform(CAMERA_MIN_POS, CAMERA_MAX_POS),
            ]
        )

        # Reject candidates too close to an already-placed camera.
        if any(
            np.linalg.norm(pos - cam_pos) < MIN_CAMERA_DISTANCE
            for cam_pos, _ in existing_cameras
        ):
            continue

        # Horizontal vector from the candidate toward the scene origin.
        # (-pos already allocates a new array, so pos is not modified.)
        direction = -pos
        direction[1] = 0

        # Jitter the heading by a random angle in [-45°, 45°] within the xz-plane.
        angle = np.random.uniform(-math.pi / 4, math.pi / 4)
        cos_a, sin_a = math.cos(angle), math.sin(angle)
        x, z = direction[0], direction[2]
        direction[0] = x * cos_a - z * sin_a
        direction[2] = x * sin_a + z * cos_a

        # Normalise. np.where would evaluate both arms and warn on the
        # divide-by-zero when pos sits exactly at the origin.
        norm = np.linalg.norm(direction)
        if norm > 0:
            direction = direction / norm
        else:
            direction = np.array([1.0, 0.0, 0.0])

        # Reject candidates too close to any object's position.
        if any(
            np.linalg.norm(pos - obj.pos()) < MIN_OBJECT_DISTANCE
            for obj in existing_objects
        ):
            continue

        return pos, direction

    # NOTE: the fallback ignores spacing; several cameras may share this spot.
    return np.array([0, 1.0, 0]), np.array([1.0, 0.0, 0.0])


def generate_camera_points(count, existing_objects):
    """Place ``count`` cameras one after another.

    Each call to ``create_camera_point`` receives the cameras placed so far,
    so every new camera keeps its distance from the earlier ones (best
    effort).

    Returns:
        list of ``(position, direction)`` tuples in placement order.
    """
    placed = []
    for _ in range(count):
        placed.append(create_camera_point(existing_objects, placed))
    return placed


def create_camera_view(i, pos, direction, camera_labels):
    """Build the overview markers for camera ``i``: a red arrow and its 3-D letter.

    Args:
        i: camera index into ``camera_labels``.
        pos: camera position (NumPy array).
        direction: unit view direction (NumPy array).
        camera_labels: sequence of label strings; ``camera_labels[i]`` is drawn.

    Returns:
        ``(arrow, text)``. The arrow runs from ``pos`` to ``pos + direction``.
        The text is centred half a unit behind the camera and raised 2 units
        in y.
    """
    label = Text3D(
        camera_labels[i],
        pos=(0, 0, 0),
        s=0.5,
        c="red",
        font="ComicMono",
        justify="center",
    )
    # Reorient the label (intended to make it legible from the top-down
    # overview camera).
    label.rotate_x(90)
    label.rotate_z(180)

    # Arithmetic already yields a new array, so pos/direction are untouched.
    anchor = pos - 0.5 * direction
    anchor[1] += 2
    label.pos(anchor)

    return Arrow(pos, pos + direction, s=0.01, c="red"), label


def _setup_overview(plt, all_objects, scene_objects, camera_points, camera_labels):
    """Configure renderer 0 (top half) as an orthographic top-down map with camera markers."""
    plt.renderers[0].SetViewport(
        VIEWPORT_PADDING,
        0.5 + VIEWPORT_PADDING / 2,
        1 - VIEWPORT_PADDING,
        1 - VIEWPORT_PADDING,
    )

    plt.at(0).add(all_objects)
    cam = plt.at(0).camera
    cam.SetPosition(0, 20, 0)
    cam.SetFocalPoint(0, 0, 0)
    cam.SetViewUp(0, 0, 1)
    cam.SetParallelProjection(True)
    cam.SetParallelScale(7)

    for i, (pos, direction) in enumerate(camera_points):
        arrow, text = create_camera_view(i, pos, direction, camera_labels)
        plt.at(0).add(arrow, text)

    # Silhouettes are built while renderer 0 is selected. vedo presumably
    # ties camera-dependent silhouettes to the active camera; confirm in the
    # vedo docs.
    silhouettes = [obj.silhouette().linewidth(5) for obj in scene_objects]
    plt.at(0).add(silhouettes)


def _setup_camera_views(plt, all_objects, scene_objects, camera_points, view_labels):
    """Configure renderers 1..4 (bottom half, 2x2 grid) as numbered per-camera views."""
    viewports = [
        (
            VIEWPORT_PADDING,
            0.25 + VIEWPORT_PADDING / 2,
            0.5 - VIEWPORT_PADDING / 2,
            0.5 - VIEWPORT_PADDING / 2,
        ),
        (
            0.5 + VIEWPORT_PADDING / 2,
            0.25 + VIEWPORT_PADDING / 2,
            1 - VIEWPORT_PADDING,
            0.5 - VIEWPORT_PADDING / 2,
        ),
        (
            VIEWPORT_PADDING,
            VIEWPORT_PADDING,
            0.5 - VIEWPORT_PADDING / 2,
            0.25 - VIEWPORT_PADDING / 2,
        ),
        (
            0.5 + VIEWPORT_PADDING / 2,
            VIEWPORT_PADDING,
            1 - VIEWPORT_PADDING,
            0.25 - VIEWPORT_PADDING / 2,
        ),
    ]

    for i, (pos, direction) in enumerate(camera_points):
        # Renderer 0 is the overview, so camera i renders into renderer i + 1.
        idx = i + 1
        plt.renderers[idx].SetViewport(*viewports[i])

        plt.at(idx).add(all_objects)

        cam = plt.at(idx).camera
        cam.SetPosition(pos)
        cam.SetFocalPoint(pos + direction)
        cam.SetViewUp(0, 1, 0)

        silhouettes = [obj.silhouette().linewidth(5) for obj in scene_objects]
        plt.at(idx).add(silhouettes)

        title = Text2D(
            view_labels[i],
            pos="top-left",
            c="black",
            bg="white",
            font="ComicMono",
            alpha=1.0,
            s=3.5,
        )
        plt.at(idx).add(title)


def generate(variables):
    """Render a "which view belongs to which camera?" puzzle and return its answer.

    Builds a random scene of boxes and cylinders on a floor and places four
    cameras. The rendered image has two parts:

    - a top-down overview, where each camera is a red arrow with a shuffled
      letter label;
    - below it, a 2x2 grid of the four camera views numbered "1".."4".

    Args:
        variables: mapping of optional settings:
            ``BOX_COUNT`` (default 2) and ``CYLINDER_COUNT`` (default 2) give
            the number of each shape. ``TARGET`` (default ``"A"``) is the
            camera letter whose view is the answer. ``COLOR_MAP`` (default
            ``"Pastel1"``) is the colormap name used for object colours.

    Returns:
        dict with two keys:

        - ``"IMAGE"``: path of the PNG, written into a fresh temporary
          directory that the caller is responsible for cleaning up.
        - ``"CORRECT"``: view number (``"1"``..``"4"``) rendered from the
          ``TARGET`` camera.

    Raises:
        ValueError: if ``TARGET`` is not one of ``"A"``, ``"B"``, ``"C"``, ``"D"``.
    """
    box_count = variables.get("BOX_COUNT", 2)
    cylinder_count = variables.get("CYLINDER_COUNT", 2)
    # The default must be a real camera label; 0 could never be found below.
    target = variables.get("TARGET", "A")
    # Named cmap_name so it does not shadow vedo's imported color_map().
    cmap_name = variables.get("COLOR_MAP", "Pastel1")

    camera_labels = ["A", "B", "C", "D"]
    view_labels = ["1", "2", "3", "4"]
    camera_count = len(camera_labels)

    # Fail fast, before any rendering resources are allocated.
    if target not in camera_labels:
        raise ValueError(f"TARGET must be one of {camera_labels}, got {target!r}")

    scene_objects = generate_scene(box_count, cylinder_count, cmap_name)
    floor = create_floor()
    camera_points = generate_camera_points(camera_count, scene_objects)

    # Global vedo setting: applies to this and any later Plotter.
    settings.use_depth_peeling = True
    plt = Plotter(
        N=camera_count + 1,
        interactive=False,
        sharecam=False,
        size=(1800, 1800),
        offscreen=True,
    )

    all_objects = [floor] + scene_objects

    # Randomise which letter each camera gets. Camera i is always shown in
    # view i, so the answer is the view number at the target letter's index.
    np.random.shuffle(camera_labels)
    target_idx = camera_labels.index(target)

    try:
        _setup_overview(plt, all_objects, scene_objects, camera_points, camera_labels)
        _setup_camera_views(plt, all_objects, scene_objects, camera_points, view_labels)

        # vedo's automatic layout may allocate more renderers than N; collapse
        # any extras so they draw nothing (no hard-coded index).
        for spare in plt.renderers[camera_count + 1:]:
            spare.SetViewport(0, 0, 0, 0)

        output_path = os.path.join(tempfile.mkdtemp(), "render.png")
        plt.render(resetcam=False)
        plt.screenshot(output_path)
    finally:
        # Release the offscreen render window even if setup or rendering fails.
        plt.close()

    return {"IMAGE": output_path, "CORRECT": view_labels[target_idx]}


if __name__ == "__main__":
    # Demo run: 5 boxes and 2 cylinders; the question asks for camera "A".
    outcome = generate({"BOX_COUNT": 5, "CYLINDER_COUNT": 2, "TARGET": "A"})
    print(f"Generated image: {outcome['IMAGE']}")
    print(f"Correct answer: {outcome['CORRECT']}")
