#!/usr/bin/env python3
import logging
from functools import partial
from pathlib import Path
from typing import Annotated, Optional

import torch
import typer
import yaml
from konductor.init import ModuleInitConfig
from konductor.metadata import DataManager
from konductor.trainer.pytorch import (
    AsyncFiniteMonitor,
    PyTorchTrainer,
    PyTorchTrainerConfig,
    PyTorchTrainerModules,
)
from konductor.utilities import comm
from konductor.utilities.pbar import PbarType, pbar_wrapper

import src
import src.statistics  # Imports all components into framework

# CLI application object that the `main` command below is registered on.
# Typer's "pretty" exception rendering is switched off via
# `pretty_exceptions_enable`; the other two flags configure that renderer
# (see the Typer exception-handling docs).
app = typer.Typer(
    pretty_exceptions_enable=False,
    pretty_exceptions_short=True,
    pretty_exceptions_show_locals=False,
)


@app.command()
def main(
    workspace: Annotated[Path, typer.Option()],
    epoch: Annotated[int, typer.Option()],
    remote: Annotated[Optional[Path], typer.Option()] = None,
    run_hash: Annotated[Optional[str], typer.Option()] = None,
    config_file: Annotated[Optional[Path], typer.Option()] = None,
    workers: Annotated[int, typer.Option()] = 4,
    dali_py_workers: Annotated[int, typer.Option()] = 4,
    dali_source_prefetch: Annotated[int, typer.Option()] = 2,
    dali_pipe_prefetch: Annotated[int, typer.Option()] = 2,
    pbar: Annotated[bool, typer.Option()] = False,
    brief: Annotated[Optional[str], typer.Option()] = None,
) -> None:
    """Main entrypoint for training model"""
    # NOTE: Typer uses the docstring above as the command's --help text, so
    # implementation notes are kept in comments rather than in the docstring.
    #
    # Flow: build the experiment config from the CLI options, optionally attach
    # a remote-sync module config loaded from the `remote` YAML file, build the
    # training modules and data manager, pick a trainer class based on the
    # first dataset's type, then train up to `epoch`.

    # --- Experiment configuration -------------------------------------------
    pipe_params = src.trainer.DaliPipeParams(
        dali_py_workers, dali_source_prefetch, dali_pipe_prefetch
    )
    exp_cfg = src.trainer.setup_init_config(
        workspace, config_file, run_hash, workers, pipe_params
    )

    if remote is not None:
        # The YAML document's top-level mapping is splatted straight into
        # ModuleInitConfig as keyword arguments.
        remote_kwargs = yaml.safe_load(remote.read_text(encoding="utf-8"))
        exp_cfg.remote_sync = ModuleInitConfig(**remote_kwargs)

    # --- Modules and metadata/checkpoint manager ----------------------------
    train_modules = PyTorchTrainerModules.from_config(exp_cfg)

    checkpointables = train_modules.get_checkpointables()
    statistics = src.statistics.get_statistics(exp_cfg)
    data_manager = DataManager.default_build(exp_cfg, checkpointables, statistics)
    if brief is not None:
        data_manager.metadata.brief = brief

    # --- Trainer configuration ----------------------------------------------
    trainer_config = PyTorchTrainerConfig(**exp_cfg.trainer)
    if comm.get_local_rank() == 0:
        # Only local rank 0 gets a progress bar: a live one when --pbar is
        # given, otherwise an interval-style one with fraction=0.1.
        if pbar:
            pbar_kwargs = {"pbar_type": PbarType.LIVE}
        else:
            pbar_kwargs = {"pbar_type": PbarType.INTERVAL, "fraction": 0.1}
        trainer_config.pbar = partial(pbar_wrapper, **pbar_kwargs)

    # --- Trainer selection --------------------------------------------------
    # Dataset-specific trainers are keyed by the first dataset's type string;
    # anything else falls back to the stock PyTorchTrainer.
    specialised_trainers = {
        "gym-predict": src.trainer.GymTrainer,
        "sc2-battle": src.trainer.SC2Trainer,
    }
    dataset_type = exp_cfg.data[0].dataset.type
    trainer_cls = specialised_trainers.get(dataset_type, PyTorchTrainer)
    trainer: PyTorchTrainer = trainer_cls(trainer_config, train_modules, data_manager)

    # add_backward_monitor(trainer)
    trainer.train(epoch=epoch)

    # An asynchronous loss monitor needs an explicit stop() once training ends.
    loss_monitor = trainer.loss_monitor
    if isinstance(loss_monitor, AsyncFiniteMonitor):
        loss_monitor.stop()

    # from konductor.trainer.profiler import profile_function

    # profile_function(trainer._train, Path.cwd())


if __name__ == "__main__":
    # Script entrypoint: set up konductor's comm layer and torch matmul
    # precision first, then configure logging and hand over to the Typer app.
    comm.initialize()
    torch.set_float32_matmul_precision("high")

    # Every log record is tagged with this process's local rank. The rank is
    # queried only after comm.initialize() has run.
    local_rank = comm.get_local_rank()
    log_format = (
        f"%(asctime)s-RANK:{local_rank}-%(levelname)s-%(name)s: %(message)s"
    )
    # force=True replaces any handlers already attached to the root logger.
    logging.basicConfig(format=log_format, level=logging.INFO, force=True)

    app()
