import os
from pathlib import Path
from typing import Annotated, Optional

import numpy as np
import typer
from konductor.data import Split
from konductor.data.dali import SampleInfo
from konductor.utilities.pbar import IntervalPbar
from sc2_serializer import GAME_INFO_FILE, ReplayDataAll

import src.dataset.sc2_preproc as sc2
from src.dataset.dataset_utils import (
    create_replay_subsequence,
    gather_unique_unit_types,
)
from src.dataset.sc2_common import find_main_roi
from src.dataset.sc2_dataset import (
    BATTLE_THRESHOLD_DEFAULT,
    BATTLE_WINDOW_SIZE_DEFAULT,
    DaliLoaderConfig,
    SC2BattleCfg,
)

# Typer application object; the functions in this file decorated with
# @app.command() are registered on it as CLI sub-commands.
app = typer.Typer()


def make_battle_events_regular_duration(
    events: list[sc2.BattleEvent], duration: int, replay_len: int
):
    """Adjust the battle events such that they are all the same fixed duration by extending or
    subdividing. When extending, prefer into the future unless that exceeds replay_len, then
    extend into the past.

    Events whose own duration is at least ``duration`` are subdivided into
    consecutive windows of ``duration`` beginning at ``event.start``; whatever is
    left over after the last full window is discarded. Events shorter than
    ``duration`` but longer than ``duration // 2`` are extended to ``duration``.
    Events of ``duration // 2`` or shorter, and short events that can be extended
    in neither direction, are dropped.

    Args:
        events: Battle events to regularise.
        duration: Length every returned event must have, must be positive.
        replay_len: Length of the replay the events belong to, in the same units
            as the event start/end.

    Returns:
        A new list of events, each built with the score of the event it was
        derived from.

    Raises:
        ValueError: If ``duration`` is not positive, or if a subdivided window
            would not finish before ``replay_len``.
    """
    if duration <= 0:
        raise ValueError(f"duration must be positive, got {duration}")

    new_events: list[sc2.BattleEvent] = []
    for event in events:
        # Subdivide into as many full-length windows as fit inside the event
        for subsplit in range(event.duration // duration):
            new_start = event.start + subsplit * duration
            new_end = new_start + duration
            # Explicit raise rather than assert so the check is not stripped
            # when running with `python -O`
            if new_end >= replay_len:
                raise ValueError(
                    f"battle window [{new_start}, {new_end}) does not end before "
                    f"replay length {replay_len}"
                )
            new_events.append(sc2.BattleEvent(new_start, new_end, event.score))

        # Only events longer than half the target duration are worth extending
        if duration // 2 < event.duration < duration:
            if event.start + duration < replay_len:
                # Room after the event, extend into the future
                new_events.append(
                    sc2.BattleEvent(event.start, event.start + duration, event.score)
                )
            elif event.end - duration > 0:
                # Otherwise extend into the past, keeping the original end
                new_events.append(
                    sc2.BattleEvent(event.end - duration, event.end, event.score)
                )
    return new_events


def add_battles_to_database(
    battle_db: sc2.ReplayDataAllDatabase,
    replay: ReplayDataAll,
    duration: int,
    window_size: int,
    threshold: float,
):
    """Find battle events in the replay and add each event sequence as a separate replay to the
    database.

    Args:
        battle_db: Database that receives one entry per battle event.
        replay: Replay to search for battle events.
        duration: Fixed duration each battle event is regularised to.
        window_size: Forwarded to ``sc2.find_battle_events``.
        threshold: Forwarded to ``sc2.find_battle_events``.

    Raises:
        RuntimeError: If the database reports that an entry could not be added.
    """
    events = sc2.find_battle_events(replay.data, window_size, threshold)
    events = make_battle_events_regular_duration(events, duration, len(replay))
    for event in events:
        subsequence = create_replay_subsequence(replay, event.start, event.duration)
        # The write must not live inside an `assert`: assert statements are
        # removed entirely under `python -O`, which would skip the insert.
        if not battle_db.addEntry(subsequence):
            raise RuntimeError("error adding new subsequence")


@app.command()
def create_battle_dataset(
    datapath: Annotated[Path, typer.Option()] = Path(
        os.environ.get("DATAPATH", "/data")
    ),
    outpath: Annotated[Path, typer.Option()] = Path().cwd()
    / "battle_dataset.SC2Replays",
    duration: Annotated[int, typer.Option()] = 30,
    window_size: Annotated[int, typer.Option()] = BATTLE_WINDOW_SIZE_DEFAULT,
    threshold: Annotated[float, typer.Option()] = BATTLE_THRESHOLD_DEFAULT,
):
    """Creates dataset of battle snippets for faster dataloading during training"""
    # NOTE: the docstring above doubles as the CLI help text, keep it short.
    dataset = sc2.ReplayDataset(datapath)

    # Raise explicitly (instead of assert) so the guard still applies under
    # `python -O`; entries are only ever added, so re-running onto an existing
    # output would duplicate them.
    if outpath.exists():
        raise FileExistsError("Output already exists and there is no deduplication")
    battle_db = sc2.ReplayDataAllDatabase(outpath)

    with IntervalPbar(len(dataset), fraction=0.01, desc="Creating Dataset") as pbar:
        for replay in dataset:
            try:
                add_battles_to_database(
                    battle_db, replay, duration, window_size, threshold
                )
            except Exception as err:
                # Best effort: report the failing replay and carry on with the rest
                print(f"Skipping Replay {replay.header.replayHash} with error: {err}")
            pbar.update(1)


@app.command()
def precalculate_roi(
    path: Annotated[Path, typer.Option()],
    roi_size: Annotated[int, typer.Option()],
    outdir: Annotated[Optional[Path], typer.Option()] = None,
    overwrite: Annotated[Optional[bool], typer.Option()] = False,
):
    """Pre-calculate battle ROI with kmeans clustering"""
    # NOTE: the docstring above doubles as the CLI help text, keep it short.
    dataset = sc2.ReplayDataset(path)

    # Default to writing the csv into the parent directory of the dataset path
    if outdir is None:
        outdir = path.parent

    output = outdir / f"sc2_roi_center_{roi_size}_{roi_size}.csv"
    if output.exists() and not overwrite:
        raise FileExistsError(output)

    # Square ROI of roi_size x roi_size
    roi_shape = np.full(2, fill_value=roi_size)
    parser = sc2.ReplayDataAllParser(GAME_INFO_FILE)
    unit_feats = [sc2.UnitOH.x, sc2.UnitOH.y, sc2.UnitOH.t]

    # Open the csv once rather than re-opening it in append mode per replay.
    # Each row is flushed as soon as it is written so that the rows produced so
    # far are still on disk if the run is interrupted.
    with open(output, "w", encoding="utf-8") as f:
        f.write("replay,cx,cy\n")
        with typer.progressbar(
            length=len(dataset), label="Writing Center Points:"
        ) as pbar:
            replay: ReplayDataAll
            for replay in dataset:
                map_size = np.array([replay.header.mapWidth, replay.header.mapHeight])
                parser.parse_replay(replay)
                # Treat the whole replay as a single battle event
                sequence = sc2.extract_battle_sequence(
                    parser, sc2.BattleEvent(0, len(replay), 0), unit_feats
                )
                roi_center = find_main_roi(
                    sequence.units, sequence.unit_targets, roi_shape, map_size
                )
                f.write(f"{replay.header.replayHash},{roi_center[0]},{roi_center[1]}\n")
                f.flush()
                pbar.update(1)


@app.command()
def get_unique_units(path: Annotated[Path, typer.Option()]):
    """Write a file with unique ids of unit types in the dataset"""
    # NOTE: the docstring above doubles as the CLI help text, keep it short.
    dataset = sc2.ReplayDataset(path)

    all_unit_types: set[int] = set()
    # Label describes this command (previously copy-pasted from precalculate_roi)
    with typer.progressbar(dataset, label="Gathering Unit Types:") as pbar:
        replay: ReplayDataAll
        for replay in pbar:
            all_unit_types.update(gather_unique_unit_types(replay.data))

    # Written to the current working directory as a single comma-separated line.
    # Sorted so the file content is deterministic for a given dataset.
    with open("unique_unit_list.txt", "w", encoding="utf-8") as f:
        f.write(",".join(str(e) for e in sorted(all_unit_types)))
        f.write("\n")


def get_basic_dataloader(
    batch_size: int = 8,
    shuffle: bool = False,
    py_workers: int = 1,
    workers: int = 1,
    dataset: Path = Path(os.environ.get("DATAPATH", "/data")),
):
    """Build a basic ``SC2BattleCfg`` for the dataset at ``dataset``.

    Despite the function name, the returned object is the dataset
    configuration, not a dataloader; callers obtain sources or loaders from it.

    Args:
        batch_size: Batch size given to the loader configuration.
        shuffle: Shuffle flag given to the loader configuration.
        py_workers: Passed as ``py_num_workers`` of the loader configuration.
        workers: Passed as ``workers`` of the loader configuration.
        dataset: Base path of the dataset, defaults to ``$DATAPATH`` or ``/data``.

    Returns:
        Dataset configuration with clip length 30, 128x128 minimap and 20x20 ROI.
    """
    shared_loader = DaliLoaderConfig(
        shuffle=shuffle,
        batch_size=batch_size,
        py_num_workers=py_workers,
        workers=workers,
    )
    minimap_hw = np.full(2, fill_value=128)
    roi_hw = np.full(2, fill_value=20)
    # The same loader configuration is supplied for both leading positional
    # arguments (presumably the train and val loaders - confirm against SC2BattleCfg)
    return SC2BattleCfg(
        shared_loader,
        shared_loader,
        clip_length=30,
        enable_pos_values=True,
        minimap_size=minimap_hw,
        roi_size=roi_hw,
        basepath=dataset,
    )


@app.command()
def profile_loader(
    n_iter: Annotated[int, typer.Option()] = 20,
    shuffle: Annotated[bool, typer.Option()] = False,
    split: Annotated[Split, typer.Option()] = Split.TRAIN,
):
    """Run with scalene off to begin with so that only loading is profiled:
    scalene --off --cpu other.py sc2 profile-loader --n-iter=500
    NOTE: Ensure zlib-ng is used with LD_PRELOAD
    """
    # Imported inside the function so scalene is only needed by this command
    from scalene import scalene_profiler

    dataset_cfg = get_basic_dataloader(shuffle=shuffle)
    source = dataset_cfg.make_source(split)
    # NOTE(review): _post_init is invoked manually here; presumably the loading
    # pipeline normally does this itself - confirm against the source class.
    source._post_init()

    # Setup above is deliberately left outside of the profiled region
    scalene_profiler.start()
    try:
        for i in range(n_iter):
            # Only the sampling call is of interest, the sample is discarded
            source(SampleInfo(i, 0, i, 0))
    finally:
        # Always stop the profiler, even if sampling raises
        scalene_profiler.stop()


@app.command()
def sample_loader(
    index: Annotated[int, typer.Option()] = 0,
    dataset: Annotated[Path, typer.Option()] = Path(os.environ.get("DATAPATH", ".")),
    split: Annotated[Split, typer.Option()] = Split.TRAIN,
):
    """Sample specific index from dataset"""
    # NOTE: the docstring above doubles as the CLI help text, keep it short.
    sample_source = get_basic_dataloader(dataset=dataset).make_source(split)
    # NOTE(review): _post_init is invoked manually before the source is used;
    # presumably the loading pipeline normally does this - confirm.
    sample_source._post_init()
    # Fetch the requested sample and dump it to stdout
    print(sample_source.get_data(index))


def run_loader_with_progress(loader, desc: str):
    """Exhaust ``loader`` once while reporting progress.

    Args:
        loader: Sized iterable to consume; every item is discarded.
        desc: Description shown on the progress bar.
    """
    n_steps = len(loader)
    with IntervalPbar(total=n_steps, fraction=0.05, desc=desc) as progress:
        # Items are not inspected, iterating is enough to surface loading errors
        for _batch in loader:
            progress.update(1)


@app.command()
def validate_loader(
    dataset: Annotated[Path, typer.Option()] = Path(os.environ.get("DATAPATH", ".")),
    batch_size: Annotated[int, typer.Option()] = 64,
    py_workers: Annotated[int, typer.Option()] = 8,
    workers: Annotated[int, typer.Option()] = 4,
):
    """Run over entire dataset, throwing any errors that occur"""
    # NOTE: the docstring above doubles as the CLI help text, keep it short.
    cfg = get_basic_dataloader(
        batch_size=batch_size, workers=workers, py_workers=py_workers, dataset=dataset
    )
    # Train split first, then validation; each loader is only created once the
    # previous one has been fully consumed.
    splits_to_check = (
        (Split.TRAIN, "Train Loader"),
        (Split.VAL, "Validation Loader"),
    )
    for split, desc in splits_to_check:
        run_loader_with_progress(cfg.get_dataloader(split), desc)
