import multiprocessing as mp
import os
from itertools import cycle
from pathlib import Path
from typing import Annotated

import cv2
import numpy as np
import torch
import typer

try:
    from ffmpegcv import VideoWriterNV
except RuntimeError as err:
    print(f"Unable to import ffmpegcv with error: {err}")
from konductor.utilities.pbar import LivePbar
from matplotlib import pyplot as plt
from sc2_serializer import ReplayDataAll, set_replay_database_logger_level, spdlog_lvl
from torch import Tensor

from ..dataset import sc2_dataset as sc2
from ..dataset.sc2_preproc import (
    ReplayDataset,
    SequenceData,
    calculate_smooth_damage_dealt,
    find_battle_events,
)
from ..utils.eval_common import initialize
from ..utils.position_transforms import PositionTransform

# Typer CLI application; every function decorated with ``@app.command()`` in this
# module is exposed as a sub-command.
app = typer.Typer()
# Set the sc2_serializer replay-database logger to the ``err`` level
# (presumably to hide lower-severity log output while loading replays).
set_replay_database_logger_level(spdlog_lvl.err)


@app.command()
def show_battle_phases(
    dataset: Annotated[Path, typer.Option()] = Path(
        os.environ.get("DATAPATH", "/data")
    ),
    index: Annotated[int, typer.Option()] = 0,
    window_size: Annotated[int, typer.Option()] = sc2.BATTLE_WINDOW_SIZE_DEFAULT,
    threshold: Annotated[float, typer.Option()] = sc2.BATTLE_THRESHOLD_DEFAULT,
):
    """Plot the instantaneous damage dealt by each player and overlay the detected
    battle phases.

    The curve comes from ``calculate_smooth_damage_dealt`` (using ``window_size``)
    and is plotted against the replay's game step. Every event returned by
    ``find_battle_events`` is marked with vertical lines at its start and end
    step; consecutive events alternate between blue and red.
    """
    replay = ReplayDataset.sample(dataset, index)
    replay_data = replay.data

    smoothed_damage = calculate_smooth_damage_dealt(replay_data, window_size)
    plt.figure("damage dealt")
    plt.plot(replay_data.gameStep, smoothed_damage)

    palette = cycle(["b", "r"])
    for battle in find_battle_events(replay_data, window_size, threshold):
        edge_color = next(palette)
        # Bracket the battle with one line at each boundary, in the same colour.
        for boundary in (battle.start, battle.end):
            plt.axvline(replay_data.gameStep[boundary], color=edge_color)
    plt.show()


def validate_tgtId(
    parser: sc2.ReplayDataAllParser,
    event: sc2.BattleEvent,
    attack_ability_id: int = 3674,
):
    """Check that each attacking unit's ``tgtId`` matches its first order's target.

    For every time index in ``[event.start, event.end)`` and every unit whose first
    order (``order0``) has ``ability_id == attack_ability_id`` (3674, "Attack", by
    default) AND a non-zero ``order0.tgtId``, ``unit.tgtId`` must equal
    ``unit.order0.tgtId``. The number of checked units is printed.

    Args:
        parser: Parser whose ``data.units`` is indexable by time index, each entry
            being an iterable of units.
        event: Battle event whose ``start``/``end`` bound the time indices checked.
        attack_ability_id: Ability id treated as the attack order.

    Raises:
        AssertionError: On the first unit whose ``tgtId`` disagrees. Raised
            explicitly rather than with ``assert`` so the check is not stripped
            when Python runs with ``-O``.
    """
    counter = 0
    for time_idx in range(event.start, event.end):
        for unit in parser.data.units[time_idx]:
            order = unit.order0
            # Only attack orders aimed at a specific unit carry a target id to compare.
            if order.ability_id != attack_ability_id or order.tgtId == 0:
                continue
            if unit.tgtId != order.tgtId:
                raise AssertionError(
                    f"unit.tgtId ({unit.tgtId}) != unit.order0.tgtId "
                    f"({order.tgtId}) at time index {time_idx}"
                )
            counter += 1
    print(f"validated {counter} targets")


def load_replay_sample(
    replay: ReplayDataAll,
    window_size: int,
    threshold: float,
    duration: int,
    roi_shape: np.ndarray | None,
):
    """Load an example of battle sequence data from a replay.

    Battle events are detected with ``sc2.find_battle_events``. The first event
    lasting at least ``duration`` steps is used; if none is long enough, the
    longest event is used instead. The chosen event's end is then set to
    ``start + duration``, so in the fallback case the window extends past the
    detected battle.

    Args:
        replay: Replay to extract the sequence from.
        window_size: Window size forwarded to battle-event detection.
        threshold: Threshold forwarded to battle-event detection.
        duration: Number of time steps in the extracted sequence.
        roi_shape: If given, the sequence is cropped with ``sc2.apply_roi_transforms``
            around the centre found by ``sc2.find_main_roi``, and the minimap (if
            any) is rescaled with ``(x - 127) / 128``.

    Returns:
        The extracted sequence data with the self units separated out.

    Raises:
        ValueError: If no battle events are found in the replay.
    """
    events = sc2.find_battle_events(replay.data, window_size, threshold)
    print(f"Found {len(events)} battle events in the replay")
    if not events:
        raise ValueError("No battle events found in the replay")

    # Prefer the first event that is long enough, otherwise fall back to the longest.
    event = next((e for e in events if e.duration >= duration), None)
    if event is None:
        event = max(events, key=lambda e: e.duration)
    # Set event end to start + duration
    event.end = event.start + duration

    parser = sc2.ReplayDataAllParser(sc2.GAME_INFO_FILE)
    parser.parse_replay(replay)
    unit_feat_idx = [sc2.UnitOH.x, sc2.UnitOH.y, sc2.UnitOH.t, sc2.UnitOH.alliance_self]

    validate_tgtId(parser, event)

    sequence_data = sc2.extract_battle_sequence(parser, event, unit_feat_idx)
    sequence_data.minimap = sc2.normalize_minimap_coordinates(
        replay.header.heightMap.data, replay.header.mapHeight, replay.header.mapWidth
    )

    if roi_shape is not None:
        map_size = np.array([replay.header.mapHeight, replay.header.mapWidth])
        roi_center = sc2.find_main_roi(
            sequence_data.units, sequence_data.unit_targets, roi_shape, map_size
        )
        sc2.apply_roi_transforms(
            sequence_data, roi_shape, roi_center, minimap_res=np.full(2, 256, dtype=int)
        )
        if sequence_data.minimap is not None:
            sequence_data.minimap = (sequence_data.minimap.astype(float) - 127) / 128

    # Split units into self/enemy using the alliance_self feature column.
    sequence_data.separate_self_units(unit_feat_idx.index(sc2.UnitOH.alliance_self))

    return sequence_data


@app.command()
def show_unit_assignments(
    dataset: Annotated[Path, typer.Option()] = Path(
        os.environ.get("DATAPATH", "/data")
    ),
    index: Annotated[int, typer.Option()] = 0,
    duration: Annotated[int, typer.Option()] = 50,
    roi_size: Annotated[int, typer.Option()] = 20,
    window_size: Annotated[int, typer.Option()] = sc2.BATTLE_WINDOW_SIZE_DEFAULT,
    threshold: Annotated[float, typer.Option()] = sc2.BATTLE_THRESHOLD_DEFAULT,
    output: Annotated[Path, typer.Option()] = Path("test.mp4"),
):
    """Visualize video of unit motion and targeting.

    Loads replay ``index`` from ``dataset`` and extracts a battle sequence of
    ``duration`` steps cropped to a ``roi_size`` x ``roi_size`` region. Every frame
    is rasterized with ``rasterize_frame`` and encoded with ``VideoWriterNV``
    (h264, 2 fps) to ``output``.
    """
    roi_shape = np.array([roi_size, roi_size], dtype=int)
    dataloader = ReplayDataset(dataset)
    replay = dataloader[index]
    sequence_data = load_replay_sample(
        replay, window_size, threshold, duration, roi_shape
    )

    writer = VideoWriterNV(str(output), "h264", 2)
    try:
        for idx in range(len(sequence_data)):
            writer.write(rasterize_frame(sequence_data, idx))
    finally:
        # Close the writer so the encoder is flushed and the file finalized,
        # even if rasterizing a frame fails part-way through.
        writer.close()


def tensor_to_sequence_data(data: sc2.TorchSC2Data) -> SequenceData:
    """Convert a torch-loaded training sample into the ``SequenceData`` format.

    Each tensor is iterated along its first dimension (the time index used by
    ``data.units[tidx]`` elsewhere). Every step is filtered by its matching boolean
    mask and moved to the CPU as a numpy array, yielding per-step lists of arrays.
    Optional fields (enemy units, positions, unique positions) are only filled
    when present on ``data``.
    """

    def masked(values, masks):
        # One numpy array per time step, keeping only the rows selected by the mask.
        return [step[keep].cpu().numpy() for step, keep in zip(values, masks)]

    seq_data = SequenceData(
        units=masked(data.units, data.units_mask),
        unit_targets=masked(data.unit_targets, data.units_mask),
        minimap=None if data.minimap is None else data.minimap.cpu().numpy(),
    )

    if data.enemy_units is not None:
        assert data.enemy_mask is not None
        seq_data.enemy_units = masked(data.enemy_units, data.enemy_mask)

    if data.positions is not None:
        seq_data.positions = masked(data.positions, data.units_mask)

    if data.positions_unique is not None:
        assert data.positions_mask is not None
        assert data.position_targets is not None
        seq_data.positions_unique = masked(data.positions_unique, data.positions_mask)
        # Position targets are per unit, so they share the unit mask.
        seq_data.position_targets = masked(data.position_targets, data.units_mask)

    return seq_data


def get_relative_line_tip_length(
    src: tuple[int, int], dst: tuple[int, int], tip_length: int = 8
):
    """Convert an absolute arrow-tip length into OpenCV's relative ``tipLength``.

    ``cv2.arrowedLine`` takes the tip length as a fraction of the line length.
    Returning ``tip_length / |dst - src|`` gives arrows a constant tip size
    regardless of how long they are.

    Args:
        src: Start point of the line.
        dst: End point of the line.
        tip_length: Desired tip length, in the same units as the points.

    Returns:
        The relative tip length. For a zero-length line (``src == dst``) this is
        0.0 instead of dividing by zero, which would give ``inf``. An ``inf``
        makes ``cv2.arrowedLine`` compute a NaN tip size (0 * inf).
    """
    line_length = np.linalg.norm(np.array(src) - np.array(dst))
    if line_length == 0:
        return 0.0
    return tip_length / line_length


def overlay_unit_prediction(
    canvas: np.ndarray,
    pred: Tensor,
    sources: Tensor,
    target: Tensor,
    truth: Tensor,
):
    """Draw weighted unit-assignment predictions and the ground truth on a canvas.

    For each source, the top (up to) three candidates of ``pred`` are considered.
    Those with weight >= 0.5 are drawn in the blue channel (OpenCV BGR) with
    intensity ``255 * weight``: a circle around the source for the null
    assignment, otherwise an arrow to the target. The ground truth is then drawn
    in green if it is among those top candidates (regardless of weight) and red
    otherwise.

    Args:
        canvas: Image to draw on; coordinates in [-1, 1] are mapped to pixels
            using its height.
        pred: Per-source assignment scores. Column 0 is the null assignment and
            column ``j + 1`` refers to ``target[j]``.
        sources: Source units; the first two features are used as x/y.
        target: Candidate target units; the first two features are used as x/y.
        truth: Ground-truth target index per source, -1 meaning no target.

    Returns:
        The canvas with the overlay drawn.
    """

    def to_pixel(xy: np.ndarray):
        return tuple(((xy + 1) * canvas.shape[0] / 2).astype(int))

    def draw_assignment(image, start, index, colour):
        # index == -1 is the null assignment: mark the source with a ring.
        if index == -1:
            return cv2.circle(image, start, 10, colour, thickness=2)
        end = to_pixel(target[index, :2].cpu().numpy())
        return cv2.arrowedLine(
            image,
            start,
            end,
            colour,
            thickness=2,
            tipLength=get_relative_line_tip_length(start, end),
        )

    top_weights, top_indices = torch.topk(pred, k=min(pred.shape[-1], 3), dim=-1)
    top_indices -= 1  # shift so the null-assignment column becomes -1

    for candidates, scores, source, gt in zip(top_indices, top_weights, sources, truth):
        origin = to_pixel(source[:2].cpu().numpy())
        for candidate, score in zip(candidates, scores):
            if score < 0.5:
                continue
            canvas = draw_assignment(
                canvas, origin, candidate, (int(255 * score), 0, 0)
            )

        # Green for TP, Red for FN
        hit = torch.isin(gt, candidates).item()
        canvas = draw_assignment(
            canvas, origin, gt, (0, 255, 0) if hit else (0, 0, 255)
        )

    return canvas


def render_target(
    canvas: np.ndarray,
    src: tuple[int, int],
    dst: tuple[int, int],
    marker: int,
    color: tuple[int, int, int],
):
    """Mark ``dst`` with a marker and draw an arrow pointing to it from ``src``.

    Args:
        canvas (np.ndarray): Image to draw on.
        src (tuple[int,int]): Pixel coordinates of the arrow's start.
        dst (tuple[int,int]): Pixel coordinates of the arrow's end, where the
            marker is placed.
        marker (int): OpenCV marker type, e.g. ``cv2.MARKER_TRIANGLE_UP``.
        color (tuple[int,int,int]): Colour of both marker and arrow.

    Returns:
        np.ndarray: The canvas with marker and arrow drawn, both with thickness 2.
        The arrow tip length is computed by ``get_relative_line_tip_length``.
    """
    stroke = {"color": color, "thickness": 2}
    marked = cv2.drawMarker(canvas, dst, markerType=marker, **stroke)
    tip = get_relative_line_tip_length(src, dst)
    return cv2.arrowedLine(marked, src, dst, tipLength=tip, **stroke)


def overlay_pos_prediction(
    canvas: np.ndarray,
    units: Tensor,
    positions: Tensor,
    position_probs: Tensor,
    targets: Tensor,
    truth: Tensor,
):
    """Overlay predicted and ground-truth positions for each unit.

    Coordinates are mapped from [-1, 1] to pixels using the canvas height.
    For every unit:

    * If every component of its ground-truth position lies strictly inside
      (-1, 1), a cyan up-triangle marker and an arrow from the unit are drawn.
    * If the unit's highest-scoring assignment in ``targets`` is column 0 (the
      null / no-unit-target slot) and its probability is at least 0.25, the
      predicted position is drawn as a purple down-triangle with an arrow. A
      3-component position also draws a circle whose radius is the third
      component. A 4-component position draws an ellipse whose half-axes are
      the last two components. Both are scaled to pixels.

    Args:
        canvas: Image to draw on.
        units: Units; the first two features are used as x/y.
        positions: Decoded position predictions per unit.
        position_probs: Probability that each unit's position prediction is valid.
        targets: Per-unit assignment scores (column 0 = null assignment).
        truth: Ground-truth positions per unit.

    Returns:
        The canvas with the overlay drawn.
    """
    has_truth = torch.all(truth < 1, dim=-1)
    has_truth &= torch.all(truth > -1, dim=-1)

    half_dim = canvas.shape[0] / 2

    def to_image_coords(x: np.ndarray):
        return tuple(((x + 1) * half_dim).astype(int))

    for unit, position, prob, target, gt, has_gt in zip(
        units, positions, position_probs, targets, truth, has_truth
    ):
        unit_img = to_image_coords(unit[:2].cpu().numpy())
        if has_gt:
            canvas = render_target(
                canvas,
                unit_img,
                to_image_coords(gt.cpu().numpy()),
                cv2.MARKER_TRIANGLE_UP,
                (255, 255, 0),  # Cyan
            )

        # Only draw position predictions for units not assigned to a unit target,
        # and only when the prediction is confident enough.
        if torch.argmax(target) != 0 or prob < 0.25:
            continue

        color = (255, 0, 255)  # Purple
        pos_img = to_image_coords(position[:2].cpu().numpy())
        canvas = render_target(
            canvas, unit_img, pos_img, cv2.MARKER_TRIANGLE_DOWN, color
        )
        if position.shape[0] == 3:
            # cv2.circle rejects negative radii; use the magnitude like the
            # ellipse branch below does.
            rad = int(position[-1].abs() * half_dim)
            canvas = cv2.circle(canvas, pos_img, rad, color, 2)
        elif position.shape[0] == 4:
            axes = tuple(int(i) for i in position[-2:].abs() * half_dim)
            canvas = cv2.ellipse(canvas, pos_img, axes, 0, 0, 360, color, thickness=2)

    return canvas


def overlay_prediction(
    canvas: np.ndarray,
    tidx: int,
    pred: dict[str, Tensor],
    data: sc2.TorchSC2Data,
    pos_decoder: PositionTransform,
):
    """Draw the model's predictions for time step ``tidx`` onto ``canvas``.

    Unit-target predictions are always drawn. Targets are the enemy units when
    ``data.enemy_units`` is set, otherwise the units themselves. When ``pred``
    contains ``"position"``, the positions are decoded with ``pos_decoder`` and
    drawn as well. They are weighted by ``sigmoid(pred["pos-logit"])`` if
    available, else by 1.
    """
    unit_mask = data.units_mask[tidx]
    units = data.units[tidx, unit_mask]

    if data.enemy_units is not None:
        assert data.enemy_mask is not None
        candidates, candidate_mask = data.enemy_units[tidx], data.enemy_mask[tidx]
    else:
        candidates, candidate_mask = data.units[tidx], unit_mask
    candidates = candidates[candidate_mask]

    # Column 0 of the unit-target prediction is the null assignment; keep it
    # together with the columns of the valid targets.
    keep_columns = torch.cat([candidate_mask.new_ones(1), candidate_mask], dim=0)
    assign_pred = pred["unit-target"][tidx, unit_mask][:, keep_columns]
    canvas = overlay_unit_prediction(
        canvas, assign_pred, units, candidates, data.unit_targets[tidx, unit_mask]
    )

    if "position" not in pred:
        return canvas

    assert "unit-target" in pred
    assert data.positions is not None, "Only position values implemented"
    gt_positions = data.positions[tidx, unit_mask]
    decoded = pos_decoder(pred["position"][tidx, unit_mask], gt_positions)

    if "pos-logit" in pred:
        confidence = torch.sigmoid(pred["pos-logit"][tidx, unit_mask])
    else:
        confidence = torch.ones(
            decoded.shape[0], dtype=torch.float32, device=decoded.device
        )

    return overlay_pos_prediction(
        canvas, units, decoded, confidence, assign_pred, gt_positions
    )


def write_prediction_video(
    filepath: str,
    pred: dict[str, Tensor],
    data: sc2.TorchSC2Data,
    pos_decoder: PositionTransform,
    resolution: int = 1024,
):
    """Render one sample's predictions to a VP9 (``VP90``) video at 2 fps.

    One frame is written per time step, drawn in this order:

    1. The minimap, or a white background when there is none.
    2. The units selected by ``units_mask``, in blue.
    3. The enemy units selected by ``enemy_mask``, in red, if present.
    4. The prediction overlay from ``overlay_prediction``.

    Args:
        filepath: Output video path.
        pred: Model predictions for this single sample.
        data: Single-sample data.
        pos_decoder: Transform used to decode position predictions.
        resolution: Width and height of the square output frames in pixels.

    Raises:
        RuntimeError: If OpenCV cannot open a video writer for ``filepath``
            (for example when the VP90 codec is unavailable).
    """
    writer = cv2.VideoWriter(
        filepath, cv2.VideoWriter_fourcc(*"VP90"), 2.0, (resolution, resolution)
    )
    if not writer.isOpened():
        raise RuntimeError(f"Unable to open video writer for {filepath}")

    try:
        if data.minimap is not None:
            base_frame = rasterize_minimap(data.minimap.cpu().numpy(), resolution)
        else:
            base_frame = np.full([resolution, resolution, 3], 255, dtype=np.uint8)

        for tidx in range(data.sequence_len):
            frame = base_frame.copy()
            # Draw only entries selected by the masks (the same filtering
            # overlay_prediction applies); masked-out entries are presumably padding.
            units = data.units[tidx, data.units_mask[tidx]]
            frame = rasterize_units(frame, units.cpu().numpy(), (255, 0, 0))
            if data.enemy_units is not None:
                assert data.enemy_mask is not None
                enemies = data.enemy_units[tidx, data.enemy_mask[tidx]]
                frame = rasterize_units(frame, enemies.cpu().numpy(), (0, 0, 255))
            frame = overlay_prediction(frame, tidx, pred, data, pos_decoder)
            writer.write(frame)
    finally:
        # Always release so the container is finalized even if rendering fails.
        writer.release()


def write_prediction_video_batch(
    batch_index: int,
    batch_pred: dict[str, Tensor],
    data: sc2.TorchSC2Data,
    pos_decoder: PositionTransform,
    write_folder: Path,
):
    """Write one prediction video per sample of a batch using a process pool.

    Files are named ``pred_{N}.webm`` in ``write_folder``, where
    ``N = batch_index * data.batch_size + sample_index``. Each sample's
    predictions are taken as ``v[:, sample_index]`` from every tensor in
    ``batch_pred``. Workers use the ``forkserver`` start method, and the pool
    has roughly three quarters of the CPU count (at least one process).

    Raises:
        Exception: Any exception raised by ``write_prediction_video`` in a worker
            is re-raised here after all videos have been attempted.
            ``apply_async`` would otherwise discard it silently.
    """
    # cpu_count() * 3 // 4 is 0 on a single-CPU machine, which Pool rejects.
    n_processes = max(1, mp.cpu_count() * 3 // 4)
    start_idx = batch_index * data.batch_size
    with mp.get_context("forkserver").Pool(processes=n_processes) as mpool:
        jobs = [
            mpool.apply_async(
                write_prediction_video,
                kwds={
                    "filepath": str(write_folder / f"pred_{start_idx + idx}.webm"),
                    "pred": {k: v[:, idx] for k, v in batch_pred.items()},
                    "data": data.get_sample(idx),
                    "pos_decoder": pos_decoder,
                },
            )
            for idx in range(data.batch_size)
        ]
        mpool.close()
        mpool.join()

    # All jobs are finished; .get() re-raises the first worker failure, if any.
    for job in jobs:
        job.get()


@app.command()
def predictions(
    run_path: Path,
    split: Annotated[sc2.Split, typer.Option()],
    n_samples: Annotated[int, typer.Option()] = 16,
    workers: Annotated[int, typer.Option()] = 1,
    py_workers: Annotated[int, typer.Option()] = 1,
    batch_size: Annotated[int, typer.Option()] = 16,
):
    """Visualise model predictions.

    Loads the experiment at ``run_path`` and runs the model over ``split`` under
    ``torch.inference_mode``. One prediction video per sample is written to
    ``<exp_path>/sc2-prediction/<split>/``, batch by batch, until at least
    ``n_samples`` samples have been processed.
    """
    exp_cfg = sc2.ExperimentInitConfig.from_run(run_path)
    exp_cfg.set_workers(workers)
    exp_cfg.set_batch_size(batch_size, split)
    exp_cfg.data[0].train_loader.args["py_num_workers"] = py_workers
    exp_cfg.data[0].val_loader.args["py_num_workers"] = py_workers

    model, data_cfg = initialize(exp_cfg)
    pos_decoder = PositionTransform.from_config(exp_cfg)

    outfolder = exp_cfg.exp_path / "sc2-prediction" / split.name.lower()
    outfolder.mkdir(parents=True, exist_ok=True)

    dataloader = data_cfg.get_dataloader(split)
    with torch.inference_mode(), LivePbar(total=n_samples) as pbar:
        for batch_idx, sample in enumerate(dataloader):
            data = sc2.TorchSC2Data.from_dali(sample)
            pred = model(data)
            # Pass the batch index: the callee multiplies it by the batch size to
            # get the first file index. pbar.n counts samples, not batches.
            # NOTE(review): this assumes a uniform batch size across batches.
            write_prediction_video_batch(
                batch_idx, pred, data, pos_decoder, outfolder
            )
            pbar.update(data.batch_size)
            if pbar.n >= n_samples:
                break


def rasterize_units(canvas: np.ndarray, units: np.ndarray, color: tuple[int, int, int]):
    """Draw each unit as a circle plus a short heading tick, in BGR ``color``.

    Each row of ``units`` is ``[x, y, t, ...]``. Positions are mapped from
    [-1, 1] to pixels using the canvas height. ``t`` is passed to cos/sin, so it
    is an angle in radians. The tick runs from the unit along ``(cos t, sin t)``
    with length 0.01 in the same normalized coordinates.
    """

    def to_px(xy: np.ndarray):
        return tuple(((xy + 1) * canvas.shape[0] / 2).astype(int))

    for feat in units:
        position, heading = feat[:2], feat[2]
        centre = to_px(position)
        canvas = cv2.circle(canvas, centre, 5, color, thickness=2)
        direction = 0.01 * np.array([np.cos(heading), np.sin(heading)], dtype=float)
        canvas = cv2.line(canvas, centre, to_px(position + direction), color, 2)

    return canvas


def rasterize_positions(
    canvas: np.ndarray,
    positions: np.ndarray,
    color: tuple[int, int, int] = (255, 255, 0),
    marker: int = cv2.MARKER_DIAMOND,
):
    """Draw ``marker`` (a diamond by default) at every position in ``positions``.

    Positions are mapped from [-1, 1] to pixels using the canvas height. ``color``
    is BGR, cyan by default. Markers are drawn with thickness 2.
    """
    for xy in positions:
        pixel = tuple(((xy + 1) * canvas.shape[0] / 2).astype(int))
        canvas = cv2.drawMarker(canvas, pixel, color, thickness=2, markerType=marker)
    return canvas


def rasterize_assignment(
    canvas: np.ndarray,
    sources: np.ndarray,
    targets: np.ndarray,
    assignment: np.ndarray,
    color: tuple[int, int, int] = (0, 255, 0),
):
    """Draw an arrow from each source to its assigned target, in BGR ``color``.

    ``assignment[i]`` is the index into ``targets`` for ``sources[i]``; entries
    equal to -1 (no assignment) are skipped. The first two features of each
    source/target are x/y, mapped from [-1, 1] to pixels via the canvas height.
    """

    def to_px(xy: np.ndarray):
        return tuple(((xy + 1) * canvas.shape[0] / 2).astype(int))

    linked = ((s, t) for s, t in enumerate(assignment) if t != -1)
    for src_idx, tgt_idx in linked:
        start = to_px(sources[src_idx][:2])
        end = to_px(targets[tgt_idx][:2])
        canvas = cv2.arrowedLine(canvas, start, end, color, thickness=2)

    return canvas


def rasterize_position_values(
    canvas: np.ndarray,
    sources: np.ndarray,
    targets: np.ndarray,
    color: tuple[int, int, int] = (0, 255, 0),
):
    """Draw an arrow from each source to its paired target position, in BGR ``color``.

    Only pairs whose target lies strictly inside (-1, 1) on every axis are drawn.
    Other values, e.g. out-of-range placeholders, are ignored. Coordinates are
    mapped from [-1, 1] to pixels using the canvas height.
    """
    inside = ((targets > -1) & (targets < 1)).all(axis=-1)

    def to_px(xy: np.ndarray):
        return tuple(((xy + 1) * canvas.shape[0] / 2).astype(int))

    for start, end in zip(sources[inside], targets[inside]):
        canvas = cv2.arrowedLine(
            canvas, to_px(start[:2]), to_px(end), color, thickness=2
        )

    return canvas


def rasterize_minimap(minimap: np.ndarray, resolution: int):
    """Rasterize the minimap as a square, 3-channel uint8 image.

    Non-uint8 input is multiplied by 255 and clipped to [0, 255], so values are
    expected in [0, 1] and negatives become 0. A leading single-channel axis is
    dropped. The map is resized to ``resolution`` x ``resolution`` with
    nearest-neighbour interpolation, and the grey level is copied into all three
    channels.

    NOTE(review): minimaps rescaled to roughly [-1, 1] (as ``load_replay_sample``
    does with ``(x - 127) / 128``) lose their negative half here; confirm the
    intended input range.
    """
    if minimap.dtype == np.uint8:
        grey = minimap
    else:
        grey = (minimap * 255).clip(0, 255).astype(np.uint8)

    if grey.ndim == 3:
        assert grey.shape[0] == 1, "Currently only handle 1ch at front"
        grey = grey[0]

    grey = cv2.resize(grey, (resolution, resolution), interpolation=cv2.INTER_NEAREST)
    return np.repeat(grey[..., np.newaxis], 3, axis=-1)


def rasterize_frame(data: SequenceData, t_idx: int, resolution: int = 1024):
    """Render time step ``t_idx`` of ``data`` into a ``resolution``-sized image.

    Layers, in drawing order:

    1. The minimap, or a white background when there is none.
    2. Self units in blue.
    3. Enemy units in red, when present.
    4. Unit-target assignment arrows. They point to the enemy units when
       present, otherwise to the units themselves.
    5. When ``positions_unique`` is set: the unique positions and their
       assignment arrows.
    6. When ``positions`` is set: the per-unit positions and arrows from each
       unit to its position.
    """
    if data.minimap is None:
        image = np.full([resolution, resolution, 3], 255, dtype=np.uint8)
    else:
        image = rasterize_minimap(data.minimap, resolution)

    own = data.units[t_idx]
    enemies = None if data.enemy_units is None else data.enemy_units[t_idx]

    image = rasterize_units(image, own, (255, 0, 0))
    if enemies is not None:
        image = rasterize_units(image, enemies, (0, 0, 255))

    # Without an enemy list, unit targets index into the unit list itself.
    assign_targets = own if enemies is None else enemies
    image = rasterize_assignment(image, own, assign_targets, data.unit_targets[t_idx])

    if data.positions_unique is not None:
        assert data.position_targets is not None
        unique_positions = data.positions_unique[t_idx]
        image = rasterize_positions(image, unique_positions)
        image = rasterize_assignment(
            image, own, unique_positions, data.position_targets[t_idx]
        )

    if data.positions is not None:
        # repeat writing of the same position is okay
        unit_positions = data.positions[t_idx]
        image = rasterize_positions(image, unit_positions)
        image = rasterize_position_values(image, own, unit_positions)

    return image
