import os
import math
import tempfile
import numpy as np
from vedo import LightKit, Plotter, Sphere, Box, color_map

# Diameter of every ball (create_ball uses BALL_SIZE / 2 as the sphere radius).
BALL_SIZE = 1

# Plank footprint: PLANK_LENGTH along its orientation axis, PLANK_WIDTH across
# it. PLANK_HEIGHT is the slab thickness.
PLANK_WIDTH = 1
PLANK_LENGTH = 4
PLANK_HEIGHT = 0.1

# Vertical distance between consecutive layers: one ball plus one plank.
LAYER_HEIGHT = BALL_SIZE + PLANK_HEIGHT

# Bottom layer: a single ball at the origin under a single plank.
# Balls are (x, y). Planks are (x, y, orient_x, orient_y), where the
# orientation pair selects the axis the plank's length runs along
# ((1, 0) -> x, (0, 1) -> y).
INIT_LAYER = {
    "balls": [
        (0, 0),
    ],
    "planks": [
        (0, 0, 0, 1),
    ],
}

# Multiplier applied to colormap colors before they are clipped to [0, 1];
# values above 1 brighten the palette.
COLOR_BRIGHTNESS = 1.3


def get_color(color_idx, cmap_name):
    """Sample colormap *cmap_name* at *color_idx* and brighten the result.

    The colormap is evaluated over the range [0, 1]. The sampled color is
    scaled by ``COLOR_BRIGHTNESS`` and then clipped back into [0, 1] so it
    stays a valid RGB value.
    """
    base_rgb = color_map(color_idx, name=cmap_name, vmin=0, vmax=1)
    brightened = base_rgb * COLOR_BRIGHTNESS
    return np.clip(brightened, 0, 1)


def create_ball(grid_pos, layer_index, color_idx, color_map_name):
    """Build a colored sphere sitting at the bottom of layer *layer_index*.

    ``grid_pos`` supplies the (x, y) centre. The sphere has diameter
    BALL_SIZE, and its centre is half a diameter above the layer's base
    height. The color comes from ``get_color(color_idx, color_map_name)``.
    """
    layer_base = layer_index * (BALL_SIZE + PLANK_HEIGHT)
    center = (grid_pos[0], grid_pos[1], layer_base + BALL_SIZE / 2)
    sphere = Sphere(pos=center, r=BALL_SIZE / 2)
    return sphere.color(get_color(color_idx, color_map_name))


def create_plank(grid_pos_rot, layer_index):
    """Build the white slab resting on top of the balls of layer *layer_index*.

    ``grid_pos_rot`` is ``(x, y, orient_x, orient_y)``: the plank centre
    followed by an orientation pair. The plank's long side (PLANK_LENGTH)
    runs along x for ``(1, 0)`` and along y for ``(0, 1)``. The short side is
    PLANK_WIDTH and the thickness is PLANK_HEIGHT.
    """
    *center_xy, orient_x, orient_y = grid_pos_rot

    # The slab's centre is half a thickness above the top of the layer's balls.
    top_of_balls = layer_index * (BALL_SIZE + PLANK_HEIGHT) + BALL_SIZE
    z = top_of_balls + PLANK_HEIGHT / 2

    extent_x = PLANK_LENGTH * orient_x + PLANK_WIDTH * orient_y
    extent_y = PLANK_LENGTH * orient_y + PLANK_WIDTH * orient_x

    slab = Box(
        pos=(center_xy[0], center_xy[1], z),
        size=(extent_x, extent_y, PLANK_HEIGHT),
    )
    return slab.color("white")


def place_balls(plank):
    """Randomly choose ball positions on top of *plank* for the next layer.

    ``plank`` is ``(x, y, a, b)``, where ``(a, b)`` is the plank's orientation
    pair. Ball positions are offset from the plank centre along that
    direction in steps of ``PLANK_LENGTH / 3``. A single uniform draw decides
    the outcome:

    * below 0.2: no balls;
    * below 0.3: one ball at step -1, 0 or +1, chosen uniformly;
    * otherwise: two balls, at steps -1 and +1.

    Returns a list of ``(x, y)`` tuples.
    """
    *center, dir_x, dir_y = plank
    roll = np.random.random()
    step = PLANK_LENGTH / 3

    def offset(k):
        # Position k steps away from the centre along the plank's axis.
        return (center[0] + dir_x * step * k, center[1] + dir_y * step * k)

    if roll < 0.2:
        return []
    if roll < 0.3:
        return [offset(np.random.randint(-1, 2))]
    return [offset(-1), offset(1)]


def _plank_cells(plank):
    """Yield the integer grid cells ``(cx, cy)`` covered by *plank*.

    ``plank`` is ``(x, y, orient_x, orient_y)``. Its footprint spans the
    half-open rectangle ``[x - xs/2, x + xs/2) x [y - ys/2, y + ys/2)``, where
    ``xs``/``ys`` are the plank's extents along each axis. The yielded cells
    are the integer points inside that rectangle; ``math.ceil`` on both edges
    selects them.

    NOTE(review): only integer lattice points are tracked, so two planks at
    fractional positions can overlap geometrically without sharing a cell.
    """
    x, y, orient_x, orient_y = plank

    xs = PLANK_LENGTH * orient_x + PLANK_WIDTH * orient_y
    ys = PLANK_LENGTH * orient_y + PLANK_WIDTH * orient_x

    half_xs = xs / 2
    half_ys = ys / 2

    x_range = range(math.ceil(x - half_xs), math.ceil(x + half_xs))
    y_range = range(math.ceil(y - half_ys), math.ceil(y + half_ys))

    for cell_x in x_range:
        for cell_y in y_range:
            yield (cell_x, cell_y)


def is_plank_overlapping(plank, occupied):
    """Return True if any grid cell covered by *plank* is already in *occupied*.

    ``occupied`` is any container supporting ``in`` on ``(cx, cy)`` tuples.
    The dict filled by ``mark_plank_occupied`` is one example.
    """
    return any(cell in occupied for cell in _plank_cells(plank))


def mark_plank_occupied(plank, occupied):
    """Record every grid cell covered by *plank* in *occupied* (mutated in place).

    Each covered ``(cx, cy)`` cell is set to ``True`` in the ``occupied`` dict.
    """
    for cell in _plank_cells(plank):
        occupied[cell] = True


def _place_pair_plank(i, balls, partners, processed_balls, occupied):
    """Try to lay one plank across ball *i* and one of its *partners*.

    Candidates are tried in a random order (``np.random.shuffle``), and
    partners already in ``processed_balls`` are skipped. The plank is centred
    at the midpoint of the two balls and oriented along the axis joining them.
    On success it marks ``occupied``, adds both ball indices to
    ``processed_balls`` and returns the plank tuple. Otherwise it returns None.
    """
    candidates = partners.copy()
    np.random.shuffle(candidates)

    for j in candidates:
        if j in processed_balls:
            continue

        ball1, ball2 = balls[i], balls[j]
        mid_x = (ball1[0] + ball2[0]) / 2
        mid_y = (ball1[1] + ball2[1]) / 2

        # Balls sharing an x coordinate are stacked along y, so the plank runs along y.
        if ball1[0] == ball2[0]:
            orient_x, orient_y = 0, 1
        else:
            orient_x, orient_y = 1, 0

        plank = (mid_x, mid_y, orient_x, orient_y)
        if not is_plank_overlapping(plank, occupied):
            mark_plank_occupied(plank, occupied)
            processed_balls.add(i)
            processed_balls.add(j)
            return plank

    return None


def _place_single_plank(ball_pos, occupied, attempts=10):
    """Try to centre a plank on *ball_pos* in a randomly drawn orientation.

    Up to *attempts* orientations are drawn (x or y, each with probability
    0.5). The first one that does not overlap ``occupied`` is marked there and
    returned. If every attempt overlaps, it returns None.
    """
    for _ in range(attempts):
        orient = 1 if np.random.rand() < 0.5 else 0
        plank = (ball_pos[0], ball_pos[1], orient, 1 - orient)

        if not is_plank_overlapping(plank, occupied):
            mark_plank_occupied(plank, occupied)
            return plank

    return None


def place_planks(balls):
    """Randomly lay non-overlapping planks on top of *balls*.

    Two balls count as neighbours when they share an x or y coordinate and
    are at most 2 units apart along the other axis. Each ball that is not yet
    covered is handled as follows:

    * it is skipped with probability 0.2;
    * if it has neighbours, then with probability 0.3 a plank spanning it and
      a neighbour is attempted;
    * otherwise, or if pairing fails, a plank centred on the ball alone is
      attempted.

    Args:
        balls: sequence of ``(x, y)`` ball positions.

    Returns:
        list of ``(x, y, orient_x, orient_y)`` plank tuples.
    """
    new_planks = []
    occupied = {}
    processed_balls = set()

    ball_pairs = {i: [] for i in range(len(balls))}
    for i, ball1 in enumerate(balls):
        for j, ball2 in enumerate(balls[i + 1 :], i + 1):
            x_dist = abs(ball1[0] - ball2[0])
            y_dist = abs(ball1[1] - ball2[1])

            if (x_dist <= 2 and y_dist == 0) or (y_dist <= 2 and x_dist == 0):
                ball_pairs[i].append(j)
                ball_pairs[j].append(i)

    for i, ball_pos in enumerate(balls):
        if i in processed_balls:
            continue

        if np.random.random() < 0.2:
            continue

        # The second draw only happens when the ball has neighbours, which
        # keeps the RNG sequence identical to the original implementation.
        if ball_pairs[i] and np.random.random() < 0.3:
            plank = _place_pair_plank(
                i, balls, ball_pairs[i], processed_balls, occupied
            )
            if plank is not None:
                new_planks.append(plank)
                continue

        plank = _place_single_plank(ball_pos, occupied)
        if plank is not None:
            new_planks.append(plank)
            processed_balls.add(i)

    return new_planks


def create_layer(layer, max_attempts=5):
    """Generate the layer that rests on top of *layer*.

    Balls are scattered on the planks of *layer* with ``place_balls``, and new
    planks are laid over those balls with ``place_planks``. Generation is
    retried until an attempt yields both balls and planks, or until
    *max_attempts* attempts have been made. The last attempt is then patched:

    * no balls -> a single ball at ``(0, 0)``;
    * no planks -> one plank centred on the first ball, oriented along x.

    Args:
        layer: dict with a ``"planks"`` list of ``(x, y, orient_x, orient_y)``
            tuples (``"balls"`` is not read here).
        max_attempts: maximum number of random generation attempts. Values
            below 1 skip random generation and go straight to the fallbacks.

    Returns:
        dict with ``"balls"`` and ``"planks"`` lists, in the same formats.
    """
    # Initialised so the fallbacks below work even when no attempt runs.
    new_balls = []
    new_planks = []

    for _ in range(max_attempts):
        new_balls = [
            ball for plank in layer["planks"] for ball in place_balls(plank)
        ]
        new_planks = place_planks(new_balls)

        if new_balls and new_planks:
            break

    if not new_balls:
        new_balls = [(0, 0)]

    # new_balls is guaranteed non-empty at this point.
    if not new_planks:
        first_ball = new_balls[0]
        new_planks = [(first_ball[0], first_ball[1], 1, 0)]

    return {"balls": new_balls, "planks": new_planks}


def create_structure(init_layer, num_layers=5):
    """Stack layers on top of *init_layer* until there are *num_layers*.

    Each new layer is derived from the one directly below it via
    ``create_layer``. The returned list starts with *init_layer* itself (the
    same object, not a copy). When ``num_layers`` is 1 or less the result is
    just ``[init_layer]``.
    """
    stack = [init_layer]
    for _ in range(num_layers - 1):
        below = stack[-1]
        stack.append(create_layer(below))
    return stack


def generate(variables):
    """Render a random ball-and-plank tower to a PNG file.

    Args:
        variables: mapping of optional settings:
            ``"COLOR_MAP"`` - colormap name passed to ``get_color``
            (default ``"Pastel1"``);
            ``"NUM_LAYERS"`` - number of layers to build (default 4).

    Returns:
        dict with ``"IMAGE"``: the path of the screenshot, written as
        ``render.png`` inside a fresh temporary directory. The directory is
        not removed here; the caller owns it.
    """
    # Named cmap_name so it does not shadow the imported vedo ``color_map``.
    cmap_name = variables.get("COLOR_MAP", "Pastel1")
    num_layers = variables.get("NUM_LAYERS", 4)

    plt = Plotter(size=(1200, 1200), bg="white", axes=0)
    cam = plt.at(0).camera
    cam.SetPosition((16, 16, 10))
    cam.SetFocalPoint((0, 0, num_layers * LAYER_HEIGHT / 2))
    cam.SetViewUp(0, 0, 1)

    layers = create_structure(INIT_LAYER, num_layers)
    # Use the number of layers actually built: create_structure always returns
    # at least one layer, even when num_layers <= 0.
    top_layer_idx = len(layers) - 1

    for layer_idx, layer in enumerate(layers):
        for ball_pos in layer["balls"]:
            plt.add(create_ball(ball_pos, layer_idx, np.random.rand(), cmap_name))

        # A layer's planks support the next layer; the top layer has none above it.
        if layer_idx == top_layer_idx:
            continue

        for plank_pos in layer["planks"]:
            plt.add(create_plank(plank_pos, layer_idx))

    temp_dir = tempfile.mkdtemp()
    output_path = os.path.join(temp_dir, "render.png")

    plt.add(LightKit())

    try:
        plt.show(interactive=False, resetcam=False, screenshot=output_path)
    finally:
        # Release the render window; otherwise each call leaks a plotter.
        plt.close()

    return {
        "IMAGE": output_path,
    }


if __name__ == "__main__":
    # Render with generate()'s defaults ("Pastel1" colormap, 4 layers). The
    # returned dict holds the image path, which is not printed here.
    generate({})
