from pathlib import Path

import cv2
import numpy as np
from torch import Tensor
from torch import multiprocessing as mp

# Positional cv2.putText arguments that follow (img, text, org):
# fontFace, fontScale, color (BGR, i.e. orange), thickness, lineType.
# Stored as a tuple so the shared defaults cannot be mutated by accident.
_DEFAULT_TEXT_ARGS = (cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 165, 255), 2, cv2.LINE_AA)


def apply_ts_text(ts: float, frame: np.ndarray, extra: str = "") -> np.ndarray:
    """Draw a tense/timestamp label on ``frame`` and, optionally, a second line.

    The first line reads ``"<tense>: <ts with sign>"`` at pixel (25, 50), where the
    tense is "past" for negative ``ts``, "present" for zero and "future" otherwise.
    A non-empty ``extra`` string is drawn underneath at (25, 75).

    Returns the frame returned by ``cv2.putText``.
    """
    tense = "past" if ts < 0 else ("present" if ts == 0 else "future")
    labelled = cv2.putText(frame, f"{tense}: {ts:+}", (25, 50), *_DEFAULT_TEXT_ARGS)

    if not extra:
        return labelled

    return cv2.putText(labelled, extra, (25, 75), *_DEFAULT_TEXT_ARGS)


def create_occupancy_frame(
    ground_truth: np.ndarray,
    prediction: np.ndarray | None,
    resolution: tuple[int, int],
    threshold: float | None = None,
) -> np.ndarray:
    """Create a BGR frame comparing a ground-truth grid with a prediction.

    Colour coding (OpenCV BGR channel order):
      * white - free in ground truth and not predicted
      * red   - occupied in ground truth but not predicted (false negative)
      * blue  - predicted but free in ground truth (false positive)
      * green - occupied in ground truth and predicted (true positive)

    Args:
        ground_truth: 2D grid; cells equal to 1 count as occupied, cells equal to
            0 as free.
        prediction: optional 2D grid of predicted values, expected in [0, 1]
            (values outside are clipped). When ``None`` only the ground truth is drawn.
        resolution: output size as ``(width, height)`` - the same convention as the
            ``dsize`` argument of ``cv2.resize``.
        threshold: if given, cells with ``prediction > threshold`` are drawn in solid
            colours; otherwise false positives are shaded blue by prediction
            intensity and true positives (prediction > 0.5) are shaded green.

    Returns:
        ``uint8`` array of shape ``(height, width, 3)``.
    """
    # cv2.resize produces arrays of shape (height, width) for dsize=(width, height),
    # so the canvas must be allocated in that order to line up with the masks.
    width, height = resolution
    bgr_frame = np.full((height, width, 3), 255, dtype=np.uint8)

    ground_truth = cv2.resize(ground_truth, resolution, interpolation=cv2.INTER_NEAREST)
    gt_occupied = ground_truth == 1

    # Red for every occupied ground-truth cell; true positives are repainted green
    # below, so the cells that stay red are the false negatives.
    bgr_frame[gt_occupied] = np.array((0, 0, 200), dtype=np.uint8)
    if prediction is None:
        return bgr_frame

    prediction = cv2.resize(prediction, resolution, interpolation=cv2.INTER_LINEAR)

    if threshold is not None:
        predicted = prediction > threshold
        # Blue for positives; those that are also occupied become green next.
        bgr_frame[predicted] = np.array((255, 0, 0), dtype=np.uint8)
        # Green for true positives
        bgr_frame[predicted & gt_occupied] = np.array((0, 255, 0), dtype=np.uint8)
        return bgr_frame

    # Clip so the uint8 casts below cannot wrap around for values outside [0, 1].
    prob = np.clip(prediction, 0.0, 1.0)

    # Blue shade for false positives: dim the G and R channels of white free cells
    # in proportion to the predicted probability.
    gt_free = ground_truth == 0
    rg = (255 * prob[gt_free]).astype(np.uint8)
    bgr_frame[gt_free] -= np.stack([np.zeros_like(rg), rg, rg], axis=-1)

    # Green shade for true positives; higher confidence -> purer green.
    true_pos = (prob > 0.5) & gt_occupied
    rb = ((1 - prob[true_pos]) * 255).astype(np.uint8)
    bgr_frame[true_pos] = np.stack([rb, np.full_like(rb, 200), rb], axis=-1)

    return bgr_frame


def write_occupancy_video(
    agent_gt: np.ndarray,
    target_gt: np.ndarray,
    agent_pred: np.ndarray | None,
    target_pred: np.ndarray | None,
    timestamps: list[float],
    path: Path,
    thresh: float | None = None,
) -> None:
    """Write a side-by-side video of occupancy predictions over time.

    Each frame shows the agent grid (left, labelled with the timestamp and "Agents")
    next to the target grid (right, labelled "Targets"). One frame is written per
    entry of ``timestamps``, indexing the first axis of the grids. The video is
    encoded with the VP90 fourcc at 10 fps.

    Args:
        agent_gt / target_gt: per-timestep ground-truth grids, indexed by frame.
        agent_pred / target_pred: optional per-timestep predictions, or ``None``.
        timestamps: timestamp shown on each frame.
        path: output video file.
        thresh: forwarded to ``create_occupancy_frame``.

    Raises:
        RuntimeError: if the video writer cannot be opened.
    """
    frame_sz = (800, 800)
    # Two panels side by side; deriving this from frame_sz keeps the writer's
    # expected size in sync with the concatenated frames (VideoWriter silently
    # drops frames of the wrong size).
    video_sz = (2 * frame_sz[0], frame_sz[1])
    v_writer = cv2.VideoWriter(str(path), cv2.VideoWriter_fourcc(*"VP90"), 10, video_sz)

    try:
        if not v_writer.isOpened():
            raise RuntimeError(f"Can't write video, writer not open: {path}")

        for idx, timestamp in enumerate(timestamps):
            agent_pred_t = agent_pred[idx] if agent_pred is not None else None
            bgr_agent = create_occupancy_frame(
                agent_gt[idx], agent_pred_t, frame_sz, thresh
            )
            bgr_agent = apply_ts_text(timestamp, bgr_agent, extra="Agents")

            target_pred_t = target_pred[idx] if target_pred is not None else None
            bgr_target = create_occupancy_frame(
                target_gt[idx], target_pred_t, frame_sz, thresh
            )
            bgr_target = cv2.putText(
                bgr_target, "Targets", (25, 50), *_DEFAULT_TEXT_ARGS
            )

            v_writer.write(cv2.hconcat([bgr_agent, bgr_target]))
    finally:
        # Always release so the file handle / encoder is not leaked on error.
        v_writer.release()


def write_occupancy_batch(
    data: dict[str, Tensor],
    pred: dict[str, Tensor],
    timestamps: list[float],
    path: Path,
    gid: int,
    threshold: float | None = None,
) -> None:
    """Render one occupancy video per batch element using a process pool.

    Videos are written to ``path / "occupancy" / f"{gid + bidx}.webm"`` for each
    batch index ``bidx``.

    Args:
        data: batch dict; ``data["agents"].shape[0]`` gives the batch size and
            ``"agents_occ"`` / ``"targets_occ"`` hold the ground-truth grids.
        pred: model outputs; optional ``"agents_occ"`` / ``"targets_occ"`` entries
            are logits and are passed through a sigmoid before rendering.
        timestamps: per-frame timestamps shared by every video.
        path: output root directory.
        gid: index of the first batch element, used to name the files.
        threshold: forwarded to ``write_occupancy_video``.

    Raises:
        Re-raises the first exception raised by any rendering worker.
    """
    bz = data["agents"].shape[0]

    occ_path = path / "occupancy"
    occ_path.mkdir(parents=True, exist_ok=True)
    if bz == 0:
        return

    def _probabilities(key: str):
        """Sigmoid the prediction logits for ``key``, or a list of ``None``s if absent."""
        if key not in pred:
            return [None] * bz
        # detach: .numpy() fails on tensors that require grad;
        # float: half-precision outputs are not supported by cv2.resize.
        return pred[key].detach().float().sigmoid().cpu().numpy()

    pred_agents_occ = _probabilities("agents_occ")
    pred_targets_occ = _probabilities("targets_occ")

    # Move ground truth to host once rather than once per batch element.
    agents_gt = data["agents_occ"].detach().cpu().numpy()
    targets_gt = data["targets_occ"].detach().cpu().numpy()

    # cpu_count() * 3 // 4 is 0 on a single-core machine, which Pool rejects, and
    # there is no point in spawning more workers than there are videos.
    n_procs = max(1, min(bz, mp.cpu_count() * 3 // 4))

    jobs = []
    with mp.get_context("forkserver").Pool(processes=n_procs) as mpool:
        for bidx in range(bz):
            jobs.append(
                mpool.apply_async(
                    write_occupancy_video,
                    kwds=dict(
                        agent_gt=agents_gt[bidx],
                        target_gt=targets_gt[bidx],
                        agent_pred=pred_agents_occ[bidx],
                        target_pred=pred_targets_occ[bidx],
                        path=occ_path / f"{gid + bidx}.webm",
                        thresh=threshold,
                        timestamps=timestamps,
                    ),
                )
            )
        mpool.close()
        mpool.join()

    # apply_async stores worker exceptions on the AsyncResult; without get() they
    # would be silently discarded.
    for job in jobs:
        job.get()
