"""
Helper functions for visualising target tracking
"""

from dataclasses import dataclass
from pathlib import Path

import cv2
import numpy as np
from torch import Tensor
from torch import multiprocessing as mp


@dataclass(slots=True)
class TextFormat:
    string: str
    color: tuple[int, int, int]


def add_text_info(bgr_frame: np.ndarray, tidx: int) -> np.ndarray:
    """Draw the colour legend and the time index onto a frame.

    The legend entries are stacked in the top-left corner and ``ts:<tidx>`` is
    drawn near the top-right corner.

    Args:
        bgr_frame: BGR image of shape (height, width, 3) to annotate.
        tidx: Time index to display.

    Returns:
        The annotated frame, as returned by ``cv2.putText``.
    """
    fmt_common = {
        "fontFace": cv2.FONT_HERSHEY_COMPLEX,
        "fontScale": 1,
        "thickness": 2,
        "lineType": cv2.LINE_AA,
    }
    # Legend entries (BGR colours): predicted links, ground-truth links,
    # targets (red) and chasers/agents (blue).
    printables = [
        TextFormat("Predict", (0, 0, 0)),
        TextFormat("True", (0, 255, 0)),
        TextFormat("Target", (0, 0, 255)),
        TextFormat("Chaser", (255, 0, 0)),
    ]

    x = 20
    y = 40
    for printable in printables:
        bgr_frame = cv2.putText(
            bgr_frame, printable.string, (x, y), color=printable.color, **fmt_common
        )
        y += 30

    ts = TextFormat(f"ts:{tidx:4}", (0, 165, 255))
    # OpenCV points are (x=column, y=row), so the right edge is the image
    # width, shape[1]. (shape[0] is the height and only coincides with the
    # width for square frames.)
    x = bgr_frame.shape[1] - 160
    y = 40
    bgr_frame = cv2.putText(bgr_frame, ts.string, (x, y), color=ts.color, **fmt_common)

    return bgr_frame


def sim2pixel(xy: np.ndarray, resolution: np.ndarray) -> np.ndarray:
    """Map simulation coordinates onto integer pixel coordinates.

    A sim value of -1 lands on pixel 0 and +1 on ``resolution`` along each
    axis. The result is cast with ``astype(int)``, i.e. truncated toward zero.
    """
    shifted = xy + 1
    pixels = shifted * resolution / 2
    return pixels.astype(int)


def create_target_assignment_frame(
    agents: np.ndarray,
    targets: np.ndarray,
    agent_target_gt: np.ndarray,
    agent_target_pred: np.ndarray,
    resolution: np.ndarray,
    min_score: float = 0.2,
):
    """Create a BGR frame of the ground-truth and predicted target assignments.

    Agents are drawn as blue discs and targets as red discs. Predicted links
    are grey lines (darker = higher score). Ground-truth links are thin green
    lines drawn on top.

    Args:
        agents: Agent features, one row per agent; columns 0-1 are sim (x, y).
        targets: Target features, one row per target; columns 0-1 are sim (x, y).
        agent_target_gt: For each agent, the index into ``targets`` of its
            ground-truth target.
        agent_target_pred: For each agent, a row of scores paired
            positionally with the rows of ``targets``.
        resolution: Frame size as (width, height) in pixels. This is the same
            convention ``cv2.VideoWriter`` uses for its frame size.
        min_score: Predicted links scoring below this are not drawn.

    Returns:
        A ``uint8`` image of shape ``(height, width, 3)`` on a white background.
    """
    # sim2pixel scales x by resolution[0] and y by resolution[1], and OpenCV
    # points are (x=column, y=row). The array therefore needs resolution[1]
    # rows and resolution[0] columns; ``(*resolution, 3)`` was only correct
    # for square frames.
    width, height = (int(v) for v in resolution)
    frame = np.full((height, width, 3), 255, dtype=np.uint8)

    # OpenCV drawing functions expect points as tuples of Python ints.
    agent_pts = [(int(x), int(y)) for x, y in sim2pixel(agents[..., :2], resolution)]
    target_pts = [(int(x), int(y)) for x, y in sim2pixel(targets[..., :2], resolution)]

    # Render agents
    for pt in agent_pts:
        frame = cv2.circle(frame, pt, 10, (255, 0, 0), -1)

    # Render targets
    for pt in target_pts:
        frame = cv2.circle(frame, pt, 10, (0, 0, 255), -1)

    # Render predicted assignments; zip stops at the shorter of targets/scores.
    for agent_pt, scores in zip(agent_pts, agent_target_pred):
        for target_pt, score in zip(target_pts, scores):
            if score < min_score:
                continue
            # Clamp so a score above 1 maps to black rather than a negative level.
            level = int(255 * (1 - min(float(score), 1.0)))
            frame = cv2.line(frame, agent_pt, target_pt, (level,) * 3, 5)

    # Render ground-truth assignments last so they sit on top of predictions.
    for agent_pt, target_idx in zip(agent_pts, agent_target_gt):
        # NOTE(review): a negative index wraps to the end of the list, as the
        # previous numpy indexing did; confirm whether e.g. -1 means "unassigned".
        frame = cv2.line(frame, agent_pt, target_pts[int(target_idx)], (0, 255, 0), 2)

    return frame


def write_target_video(
    agents: np.ndarray,
    targets: np.ndarray,
    agent_target_gt: np.ndarray,
    agent_target_pred: np.ndarray,
    path: Path,
    timestamps: list[float],
):
    """Write an agent-target assignment prediction video (VP9 fourcc, 10 fps, 800x800).

    One frame is rendered for every index along axis 1 (time) of the inputs.

    Args:
        agents: Agent features indexed ``[agent, time, ...]``; the first two
            feature columns are sim (x, y).
        targets: Target features indexed ``[target, time, ...]``.
        agent_target_gt: Ground-truth target index per ``[agent, time]``.
        agent_target_pred: Predicted scores indexed ``[agent, time, target]``.
        path: Output video file.
        timestamps: NOTE(review): currently unused; kept so callers passing it
            keep working.

    Raises:
        RuntimeError: If the video writer cannot be opened at ``path``.
    """
    # cv2.VideoWriter takes frameSize as (width, height).
    video_shape = (800, 800)
    v_writer = cv2.VideoWriter(
        str(path), cv2.VideoWriter_fourcc(*"VP90"), 10, video_shape
    )
    if not v_writer.isOpened():
        raise RuntimeError(f"Unable to open video writer at {path}")

    # Release the writer even if rendering fails, so the encoder and file
    # handle are not leaked.
    try:
        resolution = np.array(video_shape, dtype=int)
        for tidx in range(agents.shape[1]):
            bgr_frame = create_target_assignment_frame(
                agents[:, tidx],
                targets[:, tidx],
                agent_target_gt[:, tidx],
                agent_target_pred[:, tidx],
                resolution,
            )
            bgr_frame = add_text_info(bgr_frame, tidx)
            v_writer.write(bgr_frame)
    finally:
        v_writer.release()


def write_target_batch(
    data: dict[str, Tensor],
    pred: dict[str, Tensor],
    timestamps: list[float],
    write_folder: Path,
    gid: int,
):
    """Write one target-assignment video per batch element, in parallel.

    Batch element ``bidx`` is written to ``write_folder / f"{gid + bidx}.webm"``
    by ``write_target_video``, which runs in a ``forkserver`` process pool.

    Args:
        data: Batch dict. Reads ``"agents"``, ``"targets"``, ``"agents_valid"``,
            ``"targets_valid"`` and ``"agent_target"`` (ground-truth assignment).
        pred: Prediction dict. Reads ``"agent_target"``.
        timestamps: Forwarded unchanged to ``write_target_video``.
        write_folder: Output directory; created (with parents) if missing.
        gid: Offset added to the batch index to form each file name.

    Raises:
        Any exception raised inside a worker (e.g. the ``RuntimeError`` raised
        when a video writer cannot be opened). It is re-raised here after the
        pool has finished, instead of being silently dropped.
    """
    # exist_ok avoids the race between an exists() check and mkdir().
    write_folder.mkdir(parents=True, exist_ok=True)

    # cpu_count() * 3 // 4 is 0 on a single-core machine, which Pool rejects.
    n_procs = max(1, mp.cpu_count() * 3 // 4)
    mpool = mp.get_context("forkserver").Pool(processes=n_procs)

    pending = []
    try:
        for bidx in range(data["agents"].shape[0]):
            # Agent/target validity should be consistent across time
            agent_valid = data["agents_valid"][bidx, :, 0].bool()
            targets_valid = data["targets_valid"][bidx, :, 0].bool()

            agents = data["agents"][bidx, agent_valid].cpu().numpy()
            targets = data["targets"][bidx, targets_valid].cpu().numpy()
            assignment = data["agent_target"][bidx, agent_valid].cpu().numpy()
            # NOTE(review): the target axis of the prediction (and the GT target
            # indices) is not remapped by targets_valid. Rendering pairs them
            # with the filtered targets positionally, which assumes valid
            # targets form a prefix (padding at the end); confirm.
            prediction = pred["agent_target"][bidx, agent_valid].cpu().numpy()
            pending.append(
                mpool.apply_async(
                    write_target_video,
                    kwds=dict(
                        agents=agents,
                        targets=targets,
                        agent_target_gt=assignment,
                        agent_target_pred=prediction,
                        timestamps=timestamps,
                        path=write_folder / f"{gid + bidx}.webm",
                    ),
                )
            )
        mpool.close()
    except BaseException:
        # Do not leave worker processes running if task submission fails.
        mpool.terminate()
        raise
    finally:
        mpool.join()

    # apply_async swallows worker exceptions unless the result is fetched;
    # get() re-raises the remote exception.
    for result in pending:
        result.get()
