#!/usr/bin/env python3
import shutil
from pathlib import Path
from typing import Annotated

import cv2
import torch
import typer

from src.utils.sc2_data import app as sc2_util
from src.visualization.gym_env import app as gym_app
from src.visualization.sc2_battle import app as sc2_viz
from src.flops_eval import app as flops_app
from src.summary_table import app as summarize

# Root CLI application; Typer's rich/pretty exception rendering is turned off so
# errors surface as plain tracebacks.
app = typer.Typer(
    pretty_exceptions_enable=False,
    pretty_exceptions_show_locals=False,
)

# Mount each sub-application under its command-group name, in this order.
for _sub_app, _group_name in (
    (gym_app, "gym"),
    (sc2_viz, "sc2-visual"),
    (sc2_util, "sc2-util"),
    (flops_app, "flops"),
    (summarize, "summarize"),
):
    app.add_typer(_sub_app, name=_group_name)
del _sub_app, _group_name  # keep the module namespace free of loop temporaries


@app.command()
def video_to_sequence(
    path: Annotated[Path, typer.Option()],
    out: Annotated[Path, typer.Option()] = Path("sequence"),
    stride: Annotated[int, typer.Option()] = 1,
):
    """
    Convert video to a sequence of frames...better than taking a screenshot of a video.

    Frames are written as `<out>/<frame_idx>.png`, where `frame_idx` is the
    0-based index of the frame in the video. Only every `stride`-th frame
    (indices 0, stride, 2*stride, ...) is written.

    `out` folder should not already exist, will error if it does
    (FileExistsError). Missing parent directories of `out` are created.

    Raises:
        typer.BadParameter: if `stride` < 1 or the video cannot be opened.
        FileExistsError: if `out` already exists.
        OSError: if a frame image cannot be written.
    """
    if stride < 1:
        # Validate up front: `frame_idx % 0` would otherwise crash only after
        # `out` had already been created.
        raise typer.BadParameter(f"stride must be >= 1, got {stride}", param_hint="--stride")

    video = cv2.VideoCapture(str(path))
    try:
        # Check before creating `out` so an unreadable video doesn't leave an
        # empty output folder behind with no error.
        if not video.isOpened():
            raise typer.BadParameter(f"could not open video: {path}", param_hint="--path")

        out.mkdir(parents=True)  # Error if output folder already exists, should be clean folder

        frame_idx = 0
        # grab() advances to the next frame; retrieve() produces the image.
        # Only frames we keep need retrieve(), so skipped frames are cheaper.
        while video.grab():
            if frame_idx % stride == 0:
                ok, frame = video.retrieve()
                if not ok:
                    break
                frame_file = out / f"{frame_idx}.png"
                # cv2.imwrite reports failure via its return value, not an exception.
                if not cv2.imwrite(str(frame_file), frame):
                    raise OSError(f"failed to write frame {frame_idx} to {frame_file}")
            frame_idx += 1
    finally:
        # Always release the capture, even if writing a frame fails.
        video.release()


@app.command()
def fix_goal_transformer_keys(path: Annotated[Path, typer.Option()]):
    """Fix the goal-perceiver transformer encoder name from an old checkpoint
    from encoder.enemy_enc to encoder.target_enc

    Every key in `checkpoint["model"]` starting with `encoder.enemy_enc` gets
    that prefix replaced by `encoder.target_enc`. The original file is first
    copied to `<path>.bak`, then the checkpoint is rewritten in place.

    If no key needs renaming, nothing is written. Running the command twice
    therefore cannot overwrite the backup with the already-fixed file.

    Raises:
        FileExistsError: if `<path>.bak` already exists (an earlier backup is
            never overwritten).
        ValueError: if a renamed key would collide with a key already present
            in the checkpoint.
    """
    old_prefix = "encoder.enemy_enc"
    new_prefix = "encoder.target_enc"

    checkpoint = torch.load(path, weights_only=True)
    model_ckpt: dict[str, torch.Tensor] = checkpoint["model"]

    renames = {
        key: new_prefix + key.removeprefix(old_prefix)
        for key in model_ckpt
        if key.startswith(old_prefix)
    }
    if not renames:
        typer.echo(f"No '{old_prefix}' keys found in {path}; nothing to do.")
        return

    # A renamed key landing on an existing key would silently drop a tensor.
    clashes = sorted(new for new in renames.values() if new in model_ckpt)
    if clashes:
        raise ValueError(f"Renamed keys already exist in checkpoint {path}: {clashes}")

    backup = path.with_suffix(path.suffix + ".bak")
    if backup.exists():
        raise FileExistsError(f"Backup {backup} already exists; refusing to overwrite it")

    # Pop and re-insert every key (renamed or not) in the original order, so the
    # resulting key order matches the original. Mutating in place keeps the same
    # mapping object (and therefore its type).
    for key in list(model_ckpt):
        model_ckpt[renames.get(key, key)] = model_ckpt.pop(key)

    shutil.copyfile(path, backup)
    torch.save(checkpoint, path)
    typer.echo(f"Renamed {len(renames)} keys in {path}; original saved to {backup}")


if __name__ == "__main__":
    # Only applied when run as a script (not when this module is imported).
    # Per the PyTorch docs, "high" allows float32 matmuls to use faster
    # reduced-precision internal algorithms (e.g. TF32) where available.
    torch.set_float32_matmul_precision("high")
    app()
