"""
usage: manim route_visulization.py RouteVisualization
"""

from dataclasses import dataclass
import json
import os
import random

import manim as man
import numpy as np

# Global manim render settings, applied at import time:
# 1920x1080 output at 60 fps on a fully opaque white background.
man.config.frame_rate = 60
man.config.pixel_height = 1080
man.config.pixel_width = 1920
man.config.background_color = man.WHITE
man.config.background_opacity = 1

# Parameters to adjust the visualization position and size.
# Grid positions are 2-vectors: component 0 runs along the GPU axis and is
# drawn as negated scene y, component 1 runs along the token axis and is drawn
# as scene x (see the convert_to_*_manim_coords helpers).
# Scene units per grid cell; also scales the label font sizes.
G_SIZE = 0.6
# Offset added to scaled grid positions of the start (pre-routing) layout.
START_OFFSET = np.array([-5.0 * G_SIZE, -7.0 * G_SIZE])
# Offset for the end (post-routing) layout: the start offset moved 11 cells
# along the token axis.
END_OFFSET = np.array([0, 11.0 * G_SIZE]) + START_OFFSET
# Extra spacing along the GPU axis, multiplied by the GPU / bag index.
PER_GPU_GAP = G_SIZE * 0.2
# Small amount added to every rectangle's width and height
# (presumably to hide seams between adjacent blocks -- TODO confirm).
EPS = 2e-2 * G_SIZE
# Sequence and chunk lengths are divided by this to obtain grid units.
UNIT_CHUNK_LEN = 256

# Default value - can be overridden by the BALANCER_PLAN_FILE environment variable.
# This file is generated by knapformer.sequence_balancer.SequenceBalancer.get_routing_plan_summary()
BALANCER_PLAN_FILE_PATH = os.environ.get("BALANCER_PLAN_FILE", "routing_plan.json")

# Fill colors for the chunk rectangles, indexed by sequence id
# (COLORS[seq_id]), so at most len(COLORS) == 16 sequences can be drawn.
COLORS = [
    "#B56E6E", "#C08D6A", "#C5A86A", "#AEB46A",
    "#77B46E", "#6EB496", "#6EB0B5", "#6E90B5",
    "#6F6EB5", "#8C6EB5", "#AA6EB5", "#B56E9C",
    "#B56E82", "#8E8E8E", "#B5B5A0", "#A0B5B5",
]

# Shuffle the palette in place with a fixed seed so the color assignment is the
# same on every run (presumably so consecutive sequence ids do not get
# neighbouring hues -- TODO confirm).
# NOTE: this seeds the process-wide `random` generator at import time.
random.seed(0)
random.shuffle(COLORS)


@dataclass
class ChunkRoutingPlan:
    start: np.ndarray
    end: np.ndarray
    shape: np.ndarray
    from_gpu: int
    to_bag: int
    color: str
    center_start: np.ndarray | None = None
    center_end: np.ndarray | None = None

    def __post_init__(self):
        self.center_start = self.start + self.shape / 2
        self.center_end = self.end + self.shape / 2


def plan_per_chunk_routes(plan_file_path: str = None):
    if plan_file_path is None:
        plan_file_path = BALANCER_PLAN_FILE_PATH

    balancer_plan = json.load(open(plan_file_path))

    for k, v in balancer_plan.items():
        if isinstance(v, dict):
            # To int-key dict
            balancer_plan[k] = {int(k): v for k, v in v.items()}

    seq_id2chunk_ids = {}
    for chunk_id, seq_id in balancer_plan["chunk_id2seq_id"].items():
        if seq_id not in seq_id2chunk_ids:
            seq_id2chunk_ids[int(seq_id)] = []

        seq_id2chunk_ids[int(seq_id)].append(int(chunk_id))
    balancer_plan["seq_id2chunk_ids"] = seq_id2chunk_ids

    balance_gpu_id2_chunk_ids = {}
    for chunk_id, bag_id in balancer_plan["balance_chunk_id2gpu_id"].items():
        if bag_id not in balance_gpu_id2_chunk_ids:
            balance_gpu_id2_chunk_ids[int(bag_id)] = []

        balance_gpu_id2_chunk_ids[int(bag_id)].append(int(chunk_id))
    balancer_plan["balance_gpu_id2_chunk_ids"] = balance_gpu_id2_chunk_ids

    gpu_id2_seq_len_pre_sums = {}

    for gpu_id, seq_ids in balancer_plan["gpu_id2seq_ids"].items():
        pre_sums = []
        pre_sum = 0
        for seq_id in seq_ids:
            pre_sums.append(pre_sum)
            pre_sum += balancer_plan["seq_id2seq_len"][seq_id]

        gpu_id2_seq_len_pre_sums[int(gpu_id)] = pre_sums
    balancer_plan["gpu_id2_seq_len_pre_sums"] = gpu_id2_seq_len_pre_sums

    seq_id_chunk_len_pre_sums = {}
    for seq_id, chunk_ids in seq_id2chunk_ids.items():
        pre_sums = []
        pre_sum = 0
        for chunk_id in chunk_ids:
            pre_sums.append(pre_sum)
            pre_sum += balancer_plan["balance_chunk_id2chunk_len"][chunk_id]

        seq_id_chunk_len_pre_sums[int(seq_id)] = pre_sums
    balancer_plan["seq_id_chunk_len_pre_sums"] = seq_id_chunk_len_pre_sums

    balance_gpu_id2_chunk_len_pre_sums = {}
    for gpu_id, chunk_ids in balance_gpu_id2_chunk_ids.items():
        pre_sums = []
        pre_sum = 0
        for chunk_id in chunk_ids:
            pre_sums.append(pre_sum)
            pre_sum += balancer_plan["balance_chunk_id2chunk_len"][chunk_id]

        balance_gpu_id2_chunk_len_pre_sums[int(gpu_id)] = pre_sums
    balancer_plan["balance_gpu_id2_chunk_len_pre_sums"] = balance_gpu_id2_chunk_len_pre_sums

    seq_id2gpu_id = {}
    for gpu_id, seq_ids in balancer_plan["gpu_id2seq_ids"].items():
        for seq_id in seq_ids:
            seq_id2gpu_id[int(seq_id)] = int(gpu_id)
    balancer_plan["seq_id2gpu_id"] = seq_id2gpu_id

    balancer_plan["seq_id2chunk_ids"] = seq_id2chunk_ids

    per_chunk_routing_plans = []
    for i in range(len(balancer_plan["chunk_id2seq_id"])):
        chunk_len = balancer_plan["balance_chunk_id2chunk_len"][i]
        seq_id = balancer_plan["chunk_id2seq_id"][i]
        gpu_id = balancer_plan["seq_id2gpu_id"][seq_id]
        to_bag_id = balancer_plan["balance_chunk_id2gpu_id"][i]
        to_gpu_id = balancer_plan["balance_chunk_id2gpu_id"][i]
        local_chunk_at_seq_position = seq_id_chunk_len_pre_sums[seq_id][balancer_plan["seq_id2chunk_ids"][seq_id].index(i)]
        local_seq_at_gpu_position = gpu_id2_seq_len_pre_sums[gpu_id][balancer_plan["gpu_id2seq_ids"][gpu_id].index(seq_id)]

        position = (local_seq_at_gpu_position + local_chunk_at_seq_position) / UNIT_CHUNK_LEN
        to_position = (balance_gpu_id2_chunk_len_pre_sums[to_gpu_id][balancer_plan["balance_gpu_id2_chunk_ids"][to_gpu_id].index(i)]) / UNIT_CHUNK_LEN
        chunk_len = chunk_len / UNIT_CHUNK_LEN

        cur_routing_plan = ChunkRoutingPlan(
            start=np.array([gpu_id, position]),
            end=np.array([to_gpu_id, to_position]),
            shape=np.array([1, chunk_len]),
            from_gpu=gpu_id,
            to_bag=to_bag_id,
            color=COLORS[seq_id]
        )
        assert seq_id < len(COLORS)

        per_chunk_routing_plans.append(cur_routing_plan)

    return per_chunk_routing_plans, balancer_plan


def convert_to_manim_coords(coords: np.ndarray) -> np.ndarray:
    """Map a grid position to a manim scene point using the start-side origin.

    The position is scaled by ``G_SIZE`` and shifted by ``START_OFFSET``.  Its
    second component becomes scene x, its negated first component becomes
    scene y, and z is fixed at 0.  ``coords`` itself is left unmodified.
    """
    scaled = coords * G_SIZE  # new array, so the in-place shift is safe
    scaled += START_OFFSET
    scene_x, scene_y = scaled[1], -scaled[0]
    return np.array([scene_x, scene_y, 0.0])


def convert_to_start_manim_coords(coords: np.ndarray, from_gpu: int) -> np.ndarray:
    """Map a grid position in the start layout to a manim scene point.

    Works like a plain scale by ``G_SIZE`` and shift by ``START_OFFSET``, but
    first pushes the position ``from_gpu * PER_GPU_GAP`` further along the
    first axis so that consecutive GPU rows are separated by a gap.  The
    second component becomes scene x, the negated first component becomes
    scene y, and z is fixed at 0.  ``coords`` itself is left unmodified.
    """
    scaled = coords * G_SIZE  # new array, so the in-place updates are safe
    row_gap = from_gpu * PER_GPU_GAP
    scaled[0] += row_gap
    scaled += START_OFFSET
    scene_x, scene_y = scaled[1], -scaled[0]
    return np.array([scene_x, scene_y, 0.0])


def convert_to_end_manim_coords(coords: np.ndarray, to_bag: int) -> np.ndarray:
    """Map a grid position in the end layout to a manim scene point.

    The position is scaled by ``G_SIZE``, pushed ``to_bag * PER_GPU_GAP``
    further along the first axis so that consecutive bag rows are separated by
    a gap, and shifted by ``END_OFFSET``.  The second component becomes scene
    x, the negated first component becomes scene y, and z is fixed at 0.
    ``coords`` itself is left unmodified.
    """
    scaled = coords * G_SIZE  # new array, so the in-place updates are safe
    row_gap = to_bag * PER_GPU_GAP
    scaled[0] += row_gap
    scaled += END_OFFSET
    scene_x, scene_y = scaled[1], -scaled[0]
    return np.array([scene_x, scene_y, 0.0])


class RouteVisualization(man.Scene):
    """Scene that animates every chunk moving from its source GPU to its balanced slot.

    The start layout stays on screen as a static copy while a second copy of
    each chunk rectangle is moved to its position in the end layout.
    """

    def construct(self):
        """Load the routing plan, draw the static labels and play the routing animation."""
        self.per_chunk_routing_plans, self.balancer_plan = plan_per_chunk_routes()
        self.add_gpu_blocks()
        self.add_routing_plans()

        # Keep the final state visible for a moment
        self.wait(1)

    def add_routing_plans(self):
        """Draw one rectangle per chunk and animate all of them to the end layout at once."""
        transforms = []
        for routing_plan in self.per_chunk_routing_plans:
            print(routing_plan)

            # EPS slightly enlarges each rectangle in both directions.
            data_block = man.Rectangle(
                height=routing_plan.shape[0] * G_SIZE + EPS,
                width=routing_plan.shape[1] * G_SIZE + EPS,
                fill_color=man.ManimColor(routing_plan.color),
                fill_opacity=1.0,
                stroke_width=0.0,
                stroke_opacity=0.0,
            )
            data_block.move_to(convert_to_start_manim_coords(routing_plan.center_start, routing_plan.from_gpu))

            # A static copy keeps the start layout visible while the
            # original rectangle is transformed into its end position.
            data_block_fixed = data_block.copy()
            self.add(data_block_fixed)

            transforms.append(
                man.Transform(
                    data_block,
                    data_block.copy().move_to(convert_to_end_manim_coords(routing_plan.center_end, routing_plan.to_bag))
                ),
            )

        # Scene.play rejects an empty animation list, so a plan without
        # chunks simply skips the animation instead of crashing the render.
        if transforms:
            self.play(*transforms, run_time=2)

    def add_gpu_blocks(self):
        """Add a "GPU <n>" label for each source GPU plus the two layout titles."""
        # Add GPU labels
        n_gpus = len(self.balancer_plan["gpu_id2seq_ids"])
        for idx in range(n_gpus):
            gpu_block = man.Text(f"GPU {idx}", font_size=25 * G_SIZE, color=man.BLACK, font="Calibri")
            # Half a cell down the GPU's row and 1.5 cells before its first token.
            gpu_block.move_to(convert_to_start_manim_coords(np.array([idx + 0.5, -1.5]), idx))
            self.add(gpu_block)

        # Add title texts
        self._add_title_text("Imbalanced", np.array([10.0, 5.0]))
        self._add_title_text("Balanced", np.array([10.0, 14.0]))

    def _add_title_text(self, text: str, position: np.ndarray):
        """Helper method to add title text with consistent styling.

        ``position`` is a grid position converted with ``convert_to_manim_coords``.
        """
        # Created at 10x the font size and scaled back down to 0.1
        # (presumably for crisper glyphs -- TODO confirm).
        title = man.Text(text, font_size=300 * G_SIZE, color=man.BLACK, font="Calibri")
        title.scale(0.1)
        title.move_to(convert_to_manim_coords(position))
        self.add(title)
