# Evaluate model cost (forward-pass FLOPs, training/validation step time and GPU memory) from a few samples of real data
import tempfile
from copy import deepcopy
from functools import partial
from pathlib import Path
from typing import Annotated, Optional

import torch
import typer
from konductor.data import Split, get_dataset_config
from konductor.init import ExperimentInitConfig
from konductor.metadata import DataManager
from konductor.metadata.database.sqlite import DEFAULT_FILENAME
from konductor.models import get_model
from konductor.trainer.pytorch import (
    PyTorchTrainer,
    PyTorchTrainerConfig,
    PyTorchTrainerModules,
)
from konductor.utilities.pbar import LivePbar, PbarType, pbar_wrapper
from konductor.utilities.tools import parameter_count
from sqlalchemy import create_engine, update
from sqlalchemy.orm import (
    DeclarativeBase,
    Mapped,
    MappedAsDataclass,
    mapped_column,
    sessionmaker,
)
from torch import nn

from .trainer import DaliPipeParams, GymTrainer, SC2Trainer, setup_init_config


class Base(MappedAsDataclass, DeclarativeBase):
    """Declarative base shared by the ORM tables in this module.

    Mixing in ``MappedAsDataclass`` makes every mapped subclass a dataclass,
    so rows can be built with keyword arguments (e.g. ``TrainCost(hash=...)``).
    """


class TrainCost(Base):
    """Measured training/validation cost of a single experiment run.

    Rows are keyed by the run hash (the run folder name) and are written by
    ``training_cost``; ``param_count`` is also refreshed by
    ``gather_param_counts``.
    """

    __tablename__ = "training_cost"

    # Run hash, taken from the run folder name (``run_path.stem``)
    hash: Mapped[str] = mapped_column(primary_key=True)

    # Training-phase measurements: number of samples measured, loader batch
    # size, step time and average/peak GPU memory. Units presumably follow the
    # column suffixes (seconds, GB) but the values come straight from the
    # trainer's collected stats -- TODO confirm.
    train_samples: Mapped[int]
    train_batch: Mapped[int]
    train_step_s: Mapped[float]
    train_mem_avg_gb: Mapped[float]
    train_mem_max_gb: Mapped[float]

    # Validation-phase measurements, mirroring the training columns above
    val_samples: Mapped[int]
    val_batch: Mapped[int]
    val_step_s: Mapped[float]
    val_mem_avg_gb: Mapped[float]
    val_mem_max_gb: Mapped[float]

    # Model parameter count. Kept last: it is the only column with a default,
    # and dataclass conversion forbids non-default fields after a default one.
    # -1 presumably means "not computed yet" -- confirm.
    param_count: Mapped[int] = mapped_column(default=-1)


try:
    from torchtnt.utils.flops import FlopTensorDispatchMode
except ImportError:
    pass  # Don't throw if this isn't used anyway


@torch.no_grad
def evaluate_flops(model: nn.Module, dataloader, n_samples: int):
    """Count the FLOPs of ``model``'s forward pass over about ``n_samples`` samples.

    Each batch yielded by ``dataloader`` is passed to the model as keyword
    arguments (``model(**sample)``). After every batch, a copy of the FLOP
    counts recorded by torchtnt's ``FlopTensorDispatchMode`` is stored and the
    counter is reset, so each entry covers a single batch.

    Args:
        model: module to profile.
        dataloader: iterable of mapping-like batches; its ``batch_size``
            attribute is used to track how many samples have been processed.
        n_samples: stop once at least this many samples have been processed
            (at least one batch is always run if the loader is non-empty).

    Returns:
        list: one FLOP-count snapshot per processed batch (also printed).

    Raises:
        ImportError: if ``torchtnt`` is not installed.
    """
    # Imported here so the rest of the CLI works without torchtnt, while this
    # command fails with a clear ImportError instead of a NameError.
    from torchtnt.utils.flops import FlopTensorDispatchMode

    flops_data = []
    batch_size = dataloader.batch_size
    seen = 0
    with LivePbar(n_samples, "Counting FLOPS") as pbar:
        with FlopTensorDispatchMode(model) as ftdm:
            for sample in dataloader:
                _ = model(**sample)
                # flop_counts is reset below, so keep an independent copy
                flops_data.append(deepcopy(ftdm.flop_counts))
                ftdm.reset()
                pbar.update(batch_size)
                # Fall back to counting batches if no batch size is exposed
                seen += batch_size if batch_size else 1
                if seen >= n_samples:
                    break
    print(flops_data)
    return flops_data


# Typer CLI application; the functions decorated with @app.command() below
# are registered as its subcommands.
app = typer.Typer()


@app.command()
def calculate_flops(
    run_path: Annotated[Path, typer.Option()],
    n_samples: Annotated[int, typer.Option()] = 128,
):
    """Calculate the FLOPS used in a model's forward pass"""
    # Rebuild dataset configuration and model from the saved run
    init_cfg = ExperimentInitConfig.from_run(run_path)
    data_cfg = get_dataset_config(init_cfg)
    net = get_model(init_cfg)

    # Always profile on the validation split so measurements are comparable
    evaluate_flops(net, data_cfg.get_dataloader(Split.VAL), n_samples)


def run_stat_collection(trainer: GymTrainer | SC2Trainer):
    """Measure performance statistics for a training pass, then a validation pass.

    Each phase runs until it returns or raises ``StopIteration`` (presumably
    raised by the trainer once ``collect_stats`` has enough samples -- TODO
    confirm). A deep copy of ``trainer.collect_stats`` is then taken, the stats
    are reset and the CUDA cache is emptied before the next phase starts.

    Return:
        tuple[TrainStats, TrainStats]: collected train and validation statistics
    """
    collected = []
    for phase_name in ("_train", "_validate"):
        try:
            getattr(trainer, phase_name)()
        except StopIteration:
            pass

        collected.append(deepcopy(trainer.collect_stats))

        # Start the next phase with clean stats and released cached GPU memory
        trainer.collect_stats.reset()
        torch.cuda.empty_cache()

    train_stats, val_stats = collected
    return train_stats, val_stats


def _cast_override_value(current, value: str):
    """Convert override string ``value`` to the type of ``current``, the config
    value it replaces.

    Booleans are parsed from their text form because ``bool("False")`` is True.

    Raises:
        ValueError: if ``value`` cannot be converted.
    """
    if isinstance(current, bool):
        lowered = value.strip().lower()
        if lowered in ("true", "1", "yes", "on"):
            return True
        if lowered in ("false", "0", "no", "off"):
            return False
        raise ValueError(f"Cannot interpret {value!r} as a boolean")
    return type(current)(value)


def setup_trainer(
    run_path: Path,
    n_samples: int,
    temp_dir: str,
    train_batch: int | None,
    val_batch: int | None,
    overrides: list[str] | None,
):
    """Build a trainer for an existing run that records into a scratch directory.

    Args:
        run_path: run folder; its parent is the workspace and its stem the run hash.
        n_samples: forwarded to the trainer as ``collect_stats``.
        temp_dir: used as ``exp_cfg.exp_path`` so nothing is written to the
            real experiment.
        train_batch: optional replacement for the train loader ``batch_size``.
        val_batch: optional replacement for the validation loader ``batch_size``.
        overrides: ``"target:key:value"`` strings. Only ``target == "dataset"``
            is supported; it replaces ``exp_cfg.data[0].dataset.args[key]`` with
            ``value`` converted to the existing value's type. The value may
            itself contain ``:``.

    Returns:
        The trainer: ``GymTrainer`` / ``SC2Trainer`` depending on the dataset
        type, otherwise a plain ``PyTorchTrainer``.

    Raises:
        ValueError: malformed override string or unconvertible value.
        KeyError: override key not present in the target args.
        NotImplementedError: unsupported override target.
    """
    dali_params = DaliPipeParams(py_workers=4, source_prefetch=2, pipe_prefetch=2)
    exp_cfg = setup_init_config(
        run_path.parent,
        config_file=None,
        run_hash=run_path.stem,
        workers=8,
        dali_params=dali_params,
    )

    for override in overrides or []:
        # maxsplit=2 lets the value itself contain ':' (e.g. a URL or path)
        parts = override.split(":", 2)
        if len(parts) != 3:
            raise ValueError(
                f"Override {override!r} must have the form 'target:key:value'"
            )
        target, key, value = parts
        if target == "dataset":
            args = exp_cfg.data[0].dataset.args
        else:
            raise NotImplementedError(f"Unsupported override target {target!r}")
        if key not in args:
            raise KeyError(f"Override key {key!r} not found in {target} args")
        args[key] = _cast_override_value(args[key], value)

    # Don't actually record anything to real experiment
    exp_cfg.exp_path = Path(temp_dir)

    if train_batch is not None:
        exp_cfg.data[0].train_loader.args["batch_size"] = train_batch
    if val_batch is not None:
        exp_cfg.data[0].val_loader.args["batch_size"] = val_batch

    train_modules = PyTorchTrainerModules.from_config(exp_cfg)
    data_manager = DataManager.default_build(exp_cfg, {}, {})
    trainer_config = PyTorchTrainerConfig(**exp_cfg.trainer)
    trainer_config.pbar = partial(pbar_wrapper, pbar_type=PbarType.LIVE)

    # Dataset type selects a specialised trainer, defaulting to PyTorchTrainer
    trainer: PyTorchTrainer = {"gym-predict": GymTrainer, "sc2-battle": SC2Trainer}.get(
        exp_cfg.data[0].dataset.type, PyTorchTrainer
    )(trainer_config, train_modules, data_manager, collect_stats=n_samples)

    return trainer


def get_engine(workspace: Path, overrides: list[str] | None = None):
    """Open (creating if needed) the workspace's SQLite results database.

    All tables on ``Base.metadata`` (e.g. ``training_cost``) are created if
    they do not exist yet.

    Args:
        workspace: directory that holds the database file.
        overrides: ``"target:key:value"`` override strings. When given and
            non-empty, a separate database is used whose filename has the
            overrides appended (``:`` replaced by ``-``), keeping results
            measured under overrides apart from the default ones.

    Returns:
        SQLAlchemy engine bound to the database file.
    """
    if not overrides:
        db_filename = Path(DEFAULT_FILENAME)
    else:  # Make another database with overrides as a suffix
        default = Path(DEFAULT_FILENAME)
        tag = "_".join(o.replace(":", "-") for o in overrides)
        # stem/suffix instead of split(".") so names with zero or several
        # dots still work
        db_filename = default.with_name(f"{default.stem}_{tag}{default.suffix}")

    engine = create_engine(f"sqlite:///{workspace / db_filename}")
    Base.metadata.create_all(engine)
    return engine


@app.command()
def training_cost(
    run_path: Annotated[Path, typer.Option()],
    n_samples: Annotated[int, typer.Option()],
    train_batch: Annotated[Optional[int], typer.Option()] = None,
    val_batch: Annotated[Optional[int], typer.Option()] = None,
    overrides: Annotated[Optional[list[str]], typer.Option()] = None,
):
    """Calculate training cost (step time and gpu mem) and add result to database

    The result is upserted into the training_cost table keyed by the run hash
    (the run folder name). When overrides are given, a separate database file
    named after them is used.
    """
    # Anything the trainer writes goes to a scratch directory removed afterwards
    with tempfile.TemporaryDirectory() as temp_dir:
        trainer = setup_trainer(
            run_path, n_samples, temp_dir, train_batch, val_batch, overrides
        )
        train_stats, val_stats = run_stat_collection(trainer)
        train_batch_size = trainer.modules.trainloader.batch_size
        val_batch_size = trainer.modules.valloader.batch_size

    n_params = parameter_count(run_path / "train_config.yml")
    engine = get_engine(run_path.parent, overrides)

    record = TrainCost(
        hash=run_path.stem,
        param_count=n_params,
        train_samples=train_stats.count,
        train_batch=train_batch_size,
        train_step_s=train_stats.step_time,
        train_mem_avg_gb=train_stats.gpu_mem_avg,
        train_mem_max_gb=train_stats.gpu_mem_max,
        val_samples=val_stats.count,
        val_batch=val_batch_size,
        val_step_s=val_stats.step_time,
        val_mem_avg_gb=val_stats.gpu_mem_avg,
        val_mem_max_gb=val_stats.gpu_mem_max,
    )

    try:
        with sessionmaker(bind=engine)() as session:
            # merge() is the ORM upsert: it loads the row with the same primary
            # key (hash) and copies every field onto it, or inserts a new row.
            session.merge(record)
            session.commit()
    finally:
        # Release pooled connections; this may be called once per run in a loop
        engine.dispose()


@app.command()
def training_cost_list(
    workspace: Annotated[Path, typer.Option()],
    list_file: Annotated[Path, typer.Option()],
    n_samples: Annotated[int, typer.Option()],
    train_batch: Annotated[Optional[int], typer.Option()] = None,
    val_batch: Annotated[Optional[int], typer.Option()] = None,
    overrides: Annotated[Optional[list[str]], typer.Option()] = None,
):
    """Run over a list of experiments

    list_file holds one run folder name per line, relative to workspace; blank
    lines are ignored. Each run is measured with training-cost.
    """
    with open(list_file, "r", encoding="utf-8") as f:
        # Skip blank lines (e.g. a trailing newline): an empty name would
        # resolve to the workspace directory itself.
        run_folders = [line.strip() for line in f if line.strip()]

    for i, run_folder in enumerate(run_folders, start=1):
        training_cost(
            workspace / run_folder, n_samples, train_batch, val_batch, overrides
        )
        print(f"Completed {i} of {len(run_folders)}")


@app.command()
def gather_param_counts(workspace: Annotated[Path, typer.Option()]):
    """Update the parameter count of every experiment that already has an entry
    in the training cost table of the workspace's default database. Entries
    whose run folder has no train_config.yml are skipped."""
    engine = get_engine(workspace)
    session_factory = sessionmaker(bind=engine)

    with session_factory() as session:
        exp_list = [exp_hash for (exp_hash,) in session.query(TrainCost.hash)]

    exp_params: list[dict] = []
    for exp in exp_list:
        config_path = workspace / exp / "train_config.yml"
        if not config_path.exists():
            # Table rows can outlive their experiment folders; don't abort
            print(f"Skipping {exp}: {config_path} not found")
            continue
        exp_params.append({"hash": exp, "param_count": parameter_count(config_path)})

    try:
        # Nothing to update: don't issue a bulk UPDATE with an empty param list
        if exp_params:
            with session_factory() as session:
                # ORM bulk UPDATE by primary key: each dict must include "hash"
                session.execute(update(TrainCost), exp_params)
                session.commit()
    finally:
        engine.dispose()
