"""

Notation:
tr = training
te = test

x, y: sample of data (input, output)
xs, ys: iterable of data (inputs, outputs)
xb, yb: a batch of training data (inputs, outputs)
xv, yv: a batch of test data (inputs, outputs)
yh, yhs: y-hat / iterable of y-hat

_buf: suffix representing buffer (for temporary or intermediate storage)
"""

import json
from abc import ABC, abstractmethod
from collections import defaultdict
from itertools import cycle
from pathlib import Path
from typing import ClassVar, Literal, NamedTuple, TypedDict

import mlflow
import torch
import torch.nn.functional as f
from huggingface_hub import snapshot_download
from torch import Tensor, nn
from torch.optim.lr_scheduler import LambdaLR
from torch.utils.data import DataLoader
from transformers.modeling_utils import PreTrainedModel
from transformers.models.auto.configuration_auto import AutoConfig
from transformers.models.auto.modeling_auto import AutoModel, AutoModelForCausalLM, AutoModelForSequenceClassification
from transformers.models.auto.tokenization_auto import AutoTokenizer

from clcp import HF_TOKEN
from clcp.data import CLF_DSS
from clcp.lr_scheduler import get_scheduler
from clcp.metrics import Metrics
from ml_utils import log
from ml_utils.timer import timer


class TestDL(NamedTuple):
    """A named test DataLoader bundled with the metrics object used to score it."""

    name: str  # dataset name; used as the prefix of logged metric/loss keys
    dl: DataLoader  # iterated as (inputs, targets) batches by Model._eval_dl
    metrics: Metrics  # called as metrics(yhs, ys) -> {metric_name: score}; also provides plot_cm(yhs, ys)


class LRDict(TypedDict):
    """Learning rates for the two optimizer parameter groups used by ``Model.fit``."""

    backbone: float  # LR applied to the encoder parameters once the backbone is unfrozen
    head: float  # LR applied to every non-encoder (head) parameter


class LossBuf:
    """Running, batch-size-weighted loss accumulator kept on ``device``.

    The sum lives in a device tensor, so ``update`` never calls ``.item()``;
    the only host read happens in ``compute``, which also resets the buffer.
    """

    def __init__(self, device: torch.device):
        self.device = device
        self._clear()

    def _clear(self) -> None:
        """Reset the sample count and the weighted loss sum."""
        self.count = 0
        self.total = torch.zeros((), device=self.device)

    def update(self, loss: Tensor, bsz: int) -> None:
        """Accumulate ``loss`` weighted by ``bsz`` samples (gradient history is dropped)."""
        weighted = loss.detach() * bsz
        self.total += weighted
        self.count = self.count + bsz

    def compute(self) -> float:
        """Return the sample-weighted mean loss since the last reset, then reset."""
        avg = self.total / self.count
        result = avg.item()
        self._clear()
        return result


class EarlyStopping:
    """Signal a stop once the monitored loss stops improving, checkpointing the best weights.

    Every call whose loss beats the best so far by more than ``delta`` saves
    ``mdl.state_dict()`` to ``save_path`` and resets the counter; otherwise the
    counter increments. ``__call__`` returns True once ``patience`` consecutive
    non-improving calls have been seen.
    """

    def __init__(self, mdl: "Model", save_path: str = "mdl.pt", patience: int = 10, delta: float = 1e-4) -> None:
        self.mdl = mdl
        self.save_path = Path(save_path)
        self.patience = patience
        self.delta = delta
        self.best_loss = float("inf")
        self.best_step = 0
        self.counter = 0
        # True once *this* instance has written a checkpoint. Guards restore against
        # a missing file or a stale file left at save_path by an earlier run
        # (e.g. when every observed loss was NaN, so no call ever improved).
        self.has_checkpoint = False

    def __call__(self, loss: float, step: int) -> bool:
        """Record ``loss`` observed at ``step``; return True when training should stop."""
        if loss < self.best_loss - self.delta:
            self.best_loss = loss
            self.best_step = step
            self.counter = 0
            self.save_path.parent.mkdir(parents=True, exist_ok=True)
            torch.save(self.mdl.state_dict(), f=self.save_path)
            self.has_checkpoint = True
        else:
            self.counter += 1
        return self.counter >= self.patience

    def restore_best_weights(self) -> None:
        """Load the best checkpoint saved by this instance back into ``mdl``.

        If no checkpoint was ever saved, the current weights are kept and a
        ``UserWarning`` is emitted instead of loading a possibly stale file.
        """
        if not self.has_checkpoint:
            import warnings  # noqa: PLC0415

            warnings.warn(
                f"EarlyStopping: no checkpoint was saved to {self.save_path}; keeping current weights",
                stacklevel=2,
            )
            return
        # The file only ever holds a state_dict of tensors, so the safe loader suffices.
        state_dict = torch.load(f=self.save_path, map_location=self.mdl.get_device(), weights_only=True)
        self.mdl.load_state_dict(state_dict=state_dict)


def log_loss(name: str, hist: dict[str, list[float]], buf: LossBuf, step: int) -> None:
    """Flush ``buf`` as a single loss value into ``hist[name]`` and MLflow at ``step``.

    Nothing happens when ``buf`` has accumulated no samples since its last flush.
    ``buf.compute()`` resets the buffer, so each sample is reported only once.
    """
    if buf.count:
        mean_loss = buf.compute()
        hist[name].append(mean_loss)
        mlflow.log_metric(key=name, value=mean_loss, step=step)


class Model(nn.Module, ABC):
    """Abstract base for encoder-backed models (default loss: BCE with logits).

    Subclasses are declared as ``class Foo(Model, arch="..."):``. They are
    registered in ``Model.registry`` under that ``arch`` and must implement
    ``forward``. This base class provides construction/loading helpers, device
    and mixed-precision setup, the training loop and the MLflow evaluation
    helpers. ``setup_device`` must run before ``fit``/evaluation because it sets
    ``use_bfloat16`` and ``amp_dtype``.
    """

    requires_paired_inp: bool  # presumably: whether inputs are sentence pairs — TODO confirm
    registry: ClassVar[dict[str, type["Model"]]] = {}
    encoder: PreTrainedModel  # HF backbone; its parameters form the "backbone" group in fit()

    def __init_subclass__(cls, *, arch: str = "NA", **kwargs) -> None:
        """Register ``cls`` in ``Model.registry`` under ``arch`` and expose it as ``self.arch``."""
        super().__init_subclass__(**kwargs)
        cls.arch = property(lambda self: arch)  # noqa: ARG005
        Model.registry[arch] = cls

    @classmethod
    def load(cls, name: str) -> "Model":
        """Load ``name`` as a model saved by ``save_mdl``, else as an external HF checkpoint.

        Any exception from ``load_model`` (e.g. the repo has no ``meta.json``) is
        treated as "not an internal model". The architecture is then inferred from
        the HF config and the registered subclass is instantiated from ``name``.
        """
        try:
            return load_model(name=name)
        except Exception:  # noqa: BLE001  # deliberate best-effort fallback
            log.info("loading external model")
            arch = _determine_architecture(mdl_name=name)
            subclass = Model.registry[arch]
            # NOTE(review): build()/load_model() pass `backbone=`, this passes `name=`;
            # confirm every registered subclass accepts `name`.
            return subclass(name=name)

    @classmethod
    def build(cls, *, arch: str, backbone: str, **kwargs):
        """Instantiate the subclass registered under ``arch`` on top of ``backbone``."""
        return cls.registry[arch](backbone=backbone, **kwargs)

    @abstractmethod
    def forward(self, input_ids: Tensor, attention_mask: Tensor) -> Tensor:
        """Return logits for a tokenized batch (implemented by subclasses)."""

    def setup_device(self) -> None:
        """Move the model to ``get_device()`` and pick the autocast dtype (bf16 if natively supported, else fp16)."""
        self.to(device=self.get_device())
        self.use_bfloat16 = torch.cuda.is_bf16_supported(including_emulation=False)
        self.amp_dtype = torch.bfloat16 if self.use_bfloat16 else torch.float16
        log.info(f"PyTorch - Using {self.get_device()}")

    @staticmethod
    def get_device() -> torch.device:
        """Return the CUDA device when available, otherwise the CPU."""
        device_type = "cuda" if torch.cuda.is_available() else "cpu"
        return torch.device(device_type)

    def loss_fn(self, yhs: Tensor, ys: Tensor) -> Tensor:  # noqa: PLR6301
        """Mean binary cross-entropy between logits ``yhs`` and float targets ``ys``."""
        return f.binary_cross_entropy_with_logits(yhs, ys)

    def forward_pass(self, xb: dict[str, Tensor], yb: Tensor, *, use_amp: bool = True) -> tuple[Tensor, Tensor]:
        """Move a batch to the device and return ``(loss, logits)``, optionally under autocast."""
        xb = {k: v.to(self.get_device(), non_blocking=True) for k, v in xb.items()}
        yb = yb.to(self.get_device(), non_blocking=True)
        with torch.autocast(device_type=self.get_device().type, dtype=self.amp_dtype, enabled=use_amp):
            yhs = self(**xb)  # (B, S) -> (B, 1)
            return self.loss_fn(yhs, yb), yhs  # scalar loss (mean reduction), (B, 1) logits

    @staticmethod
    def backward_pass(
        loss: Tensor,
        opt: torch.optim.Optimizer,
        lr_scheduler: LambdaLR,
        scaler: torch.GradScaler,
    ) -> None:
        """Run one optimizer update from ``loss`` (GradScaler-aware), then advance the LR schedule."""
        scaler.scale(loss).backward()  # 1. Backprop (scaled) loss
        scaler.step(opt)  # 2. (Unscale gradients and) call optimizer.step()
        scaler.update()  # 3. (Update the scale for next iteration)
        lr_scheduler.step()  # 4. Update learning rate *after* optimizer step
        opt.zero_grad()  # 5. Zero gradients at the end

    @torch.inference_mode()
    def _eval_dl(self, te_dl: TestDL) -> tuple[Tensor, Tensor, LossBuf]:
        """Return ys, yhs, LossBuf for a single TestDL.

        Runs without autocast. ``ys``/``yhs`` are the concatenated targets and
        logits (still on the device). The ``LossBuf`` holds the sample-weighted
        loss and has not been computed yet.
        """
        device = self.get_device()
        buf, te_buf = defaultdict(list), LossBuf(device=self.get_device())

        for xv, yv_cpu in te_dl.dl:
            yv = yv_cpu.to(device, non_blocking=True)
            loss, yhv = self.forward_pass(xv, yv, use_amp=False)

            for k, v in (("ys", yv), ("yhs", yhv)):
                buf[k].append(v.detach())
            te_buf.update(loss=loss, bsz=len(yv))

        ys, yhs = torch.cat(buf["ys"]), torch.cat(buf["yhs"])
        return ys, yhs, te_buf

    def eval_step(self, te_dls: list[TestDL], hist: dict[str, list[float]], step: int, prefix_metric: str = ""):
        """Cheap periodic evaluation: score only ``te_dls[0]`` and log its loss and metrics.

        Keys are ``<name>_te_loss`` and ``<name>_<metric>``, where ``<name>`` is the
        dataset name, optionally prefixed by ``prefix_metric``. Values are appended
        to ``hist`` and logged to MLflow at ``step``.
        """
        te_dl = te_dls[0]

        ys, yhs, loss = self._eval_dl(te_dl=te_dl)
        name = f"{prefix_metric}_{te_dl.name}" if prefix_metric else te_dl.name
        log_loss(f"{name}_te_loss", hist=hist, buf=loss, step=step)

        for metric, score in te_dl.metrics(yhs, ys).items():
            key = f"{name}_{metric}"
            hist[key].append(score)
            mlflow.log_metric(key=key, value=score, step=step)

    def full_eval_step(self, te_dls: list[TestDL], step: int, prefix_metric: str = "", *, only_avg: bool = False):
        """Evaluate every TestDL and log metrics to MLflow at ``step``.

        Per-dataset metrics are always averaged within the "clf" bucket (datasets
        in ``CLF_DSS``) and the "nli" bucket (all others), and those averages are
        logged. Unless ``only_avg`` is set, this method also logs:
        - each dataset's metrics;
        - a confusion-matrix SVG per dataset;
        - a JSON diagnostics dump per dataset.
        """
        # GPU phase: run every DataLoader and queue non-blocking device->host copies,
        # so the CPU does not force a sync after each DataLoader.
        buf = {}
        for te_dl in te_dls:
            ys, yhs, _ = self._eval_dl(te_dl=te_dl)
            ys, yhs = ys.detach().to("cpu", non_blocking=True), yhs.detach().to("cpu", non_blocking=True)
            buf[te_dl.name] = (ys, yhs)
        # Wait for the queued copies. torch.cuda.current_stream() needs a CUDA
        # context, so skip it on CPU-only setups (where the copies are synchronous).
        if torch.cuda.is_available():
            torch.cuda.current_stream().synchronize()

        # CPU work
        avg = {"nli": defaultdict(list), "clf": defaultdict(list)}
        metric_buf, fig_buf, dict_buf = {}, [], []
        for te_dl in te_dls:
            bucket = "clf" if te_dl.name in CLF_DSS else "nli"
            name = f"{prefix_metric}_{te_dl.name}" if prefix_metric else te_dl.name
            ys, yhs = buf[te_dl.name]

            for metric, score in te_dl.metrics(yhs, ys).items():
                avg[bucket][metric].append(score)
                metric_buf[f"{name}_{metric}"] = score

            if only_avg:
                continue

            probs = yhs.sigmoid()
            fig_buf.append((te_dl.metrics.plot_cm(yhs, ys), f"figs/{name}_cm.svg"))
            dict_buf.append((
                {
                    "ys": ys.tolist(),
                    "yhs": yhs.tolist(),
                    "misclassified_indices": (ys != probs.round()).nonzero(as_tuple=True)[0].int().tolist(),
                    # Range of the predicted logits (previously computed from ys by mistake).
                    "yhs - min/max": [yhs.min().item(), yhs.max().item()],
                    "yhs - pos/neg (%)": [
                        (yhs > 0).float().mean().item(),
                        (yhs < 0).float().mean().item(),
                    ],
                    "σ(yhs) - mean/median": [probs.mean().item(), probs.median().item()],  # noqa: RUF001
                },
                f"logs/{name}_logs_on_training_end.json",
            ))

        if not only_avg:  # log metrics in one shot
            mlflow.log_metrics(metrics=metric_buf, step=step)
            for fig, path in fig_buf:
                mlflow.log_figure(fig, artifact_file=path, save_kwargs={"format": "svg"})
            for payload, path in dict_buf:
                mlflow.log_dict(payload, path)

        # aggregated averages
        agg_metrics = {
            (f"{prefix_metric}_{b}_{m}" if prefix_metric else f"{b}_{m}"): sum(vals) / len(vals)
            for b, mets in avg.items()
            for m, vals in mets.items()
        }
        mlflow.log_metrics(agg_metrics, step=step)

    @staticmethod
    def _repeat_epochs(dl: DataLoader):
        """Yield batches from ``dl`` indefinitely, re-iterating it on every pass.

        Unlike ``itertools.cycle``, this creates a fresh DataLoader iterator each
        epoch (so a shuffling sampler actually reshuffles) and does not keep every
        batch of the first epoch in memory. Stops if ``dl`` yields nothing, to
        avoid spinning forever on an empty loader.
        """
        while True:
            yielded_any = False
            for batch in dl:
                yielded_any = True
                yield batch
            if not yielded_any:
                return

    def fit(
        self,
        tr_dl: DataLoader,
        te_dls: list[TestDL],
        epochs: int,
        lr: LRDict,
        lr_schedule: Literal["cte", "linear"] = "cte",
        freeze_backbone_until: int = 0,
        patience: int = 10,
        *,
        use_amp: bool = True,
    ) -> dict[str, list[float]]:
        """Train for ``epochs * len(tr_dl)`` optimizer steps with periodic evaluation and early stopping.

        Args:
            tr_dl: training DataLoader, re-iterated each epoch.
            te_dls: test loaders. ``te_dls[0]`` drives the periodic evaluation and
                early stopping; all of them are used for the full evaluations.
            epochs: number of passes over ``tr_dl``.
            lr: learning rates for the backbone and head parameter groups.
            lr_schedule: schedule name passed to ``get_scheduler`` (10% warmup).
            freeze_backbone_until: step at which the encoder is unfrozen. Before
                it, the encoder runs in eval mode with LR 0.
            patience: early-stopping patience, counted in evaluation rounds.
            use_amp: enable autocast (only effective when CUDA is available).

        Returns:
            History mapping each logged metric/loss key to its values per evaluation.
        """
        use_amp = use_amp and torch.cuda.is_available()
        scaler = torch.GradScaler(enabled=(use_amp and not self.use_bfloat16))  # bfloat16 doesn't need scaler
        iters = epochs * len(tr_dl)

        # Param groups: [0] backbone (LR 0 until unfrozen), [1] head.
        self.encoder.eval()
        head_params = [p for n, p in self.named_parameters() if not n.startswith("encoder")]
        opt = torch.optim.AdamW([
            {"params": self.encoder.parameters(), "lr": 0.0},
            {"params": head_params, "lr": lr["head"]},
        ])

        lr_lambda = get_scheduler(lr_schedule, warmup_steps=int(0.10 * iters), total_steps=iters)
        lr_scheduler = LambdaLR(optimizer=opt, lr_lambda=lr_lambda)
        early_stop = EarlyStopping(mdl=self, patience=patience)

        hist, tr_buf = defaultdict(list), LossBuf(device=self.get_device())
        eval_every = max(1, iters // 100)  # roughly 100 evaluations per run
        ckpt_every = eval_every * 10

        # Step 0 performs no update; it only logs a baseline evaluation.
        for i, (xb, yb) in zip(range(iters + 1), self._repeat_epochs(tr_dl)):
            if i == freeze_backbone_until:
                self.encoder.train()
                lr_scheduler.base_lrs[0] = lr["backbone"]
                # Apply the schedule's current factor (as LambdaLR does for the head
                # group) so the first backbone update also respects warmup instead of
                # running at the full backbone LR.
                lr_factor = lr_scheduler.lr_lambdas[0](lr_scheduler.last_epoch)
                opt.param_groups[0]["lr"] = lr["backbone"] * lr_factor
                log.info(f"backbone unfrozen at step {i}")

            self.train()
            if i < freeze_backbone_until:
                self.encoder.eval()
            if i > 0:
                loss, _ = self.forward_pass(xb, yb, use_amp=use_amp)
                self.backward_pass(loss, opt, lr_scheduler, scaler)
                tr_buf.update(loss=loss, bsz=len(yb))

            if i % eval_every == 0 or i == iters:
                self.eval()
                log_loss("tr_loss", hist=hist, buf=tr_buf, step=i)
                self.eval_step(te_dls=te_dls, hist=hist, step=i)
                log.info(f"Step: {i:>4}/{iters} - " + ", ".join(f"{k}={hist[k][-1]:.2f}" for k in hist) + " ✓")

                if i % ckpt_every == 0:
                    self.full_eval_step(te_dls=te_dls, step=i, only_avg=True)

                # Early stopping is only consulted after at least two full epochs.
                loss = hist[f"{te_dls[0].name}_te_loss"][-1]
                early_stop_activated = i >= 2 * len(tr_dl) and early_stop(loss=loss, step=i)

                if early_stop_activated:
                    early_stop.restore_best_weights()
                    self.full_eval_step(te_dls=te_dls, step=early_stop.best_step)
                    msg = f"Early stopping activated: step {early_stop.best_step} loss: {early_stop.best_loss:.2f}"
                    log.info(msg)
                    break

                if i == iters:
                    self.full_eval_step(te_dls=te_dls, step=i)

        return hist

    def get_model(self, name: str) -> PreTrainedModel:
        """Load the HF backbone ``name`` using the Auto class that matches the architecture.

        Uses a causal-LM head for Qwen3 rerankers, a sequence-classification head
        for cross-encoders and the bare ``AutoModel`` otherwise.
        """
        kwargs = {
            "token": HF_TOKEN,
            "trust_remote_code": True,
            # "torch_dtype": torch.bfloat16 if torch.cuda.is_bf16_supported(including_emulation=False) else torch.float16,
        }
        if "qwen3-reranker" in name.lower():
            mdl = AutoModelForCausalLM.from_pretrained(name, **kwargs)
        elif self.arch == "cross-encoder":
            mdl = AutoModelForSequenceClassification.from_pretrained(name, **kwargs)
        else:
            mdl = AutoModel.from_pretrained(name, **kwargs)
        return mdl


def _determine_architecture(mdl_name: str) -> Literal["dual-encoder", "cross-encoder"]:
    """Guess whether the HF checkpoint ``mdl_name`` is a cross-encoder or a dual-encoder.

    The checkpoint is a cross-encoder when either of these holds:
    - its config's ``_name_or_path`` contains "reranker" or "ms-marco";
    - its ``label2id`` mapping has an "entailment" label (an NLI head).
    The label check is case-insensitive, since checkpoints may spell the label as
    ``"entailment"`` or ``"ENTAILMENT"``. Anything else is a dual-encoder.
    """
    cfg = AutoConfig.from_pretrained(mdl_name, token=HF_TOKEN, trust_remote_code=True)
    cfg_dict = cfg.to_dict()

    name_or_path = (cfg_dict.get("_name_or_path") or "").lower()
    is_reranker = "reranker" in name_or_path or "ms-marco" in name_or_path

    labels = {str(label).lower() for label in (cfg_dict.get("label2id") or {})}
    is_nli = "entailment" in labels

    return "cross-encoder" if is_nli or is_reranker else "dual-encoder"


def save_mdl(model: "Model", tokenizer: str, mdl_name: str) -> None:
    """Save ``model`` under ``outputs/models/<repo name>`` in the layout ``load_model`` reads.

    Written files:
      - ``encoder/``: the HF backbone, via ``save_pretrained``;
      - tokenizer files for ``tokenizer``, at the top level;
      - ``head_state.pt``: every ``state_dict`` entry outside ``encoder.``;
      - ``meta.json``: ``{"arch", "class_name"}`` used to pick the subclass on load.

    ``<repo name>`` is the last path component of ``mdl_name``: ``"org/name"``
    gives ``"name"``, and a bare ``"name"`` is used as-is.
    """
    out_dir = Path("outputs/models") / mdl_name.rstrip("/").split("/")[-1]
    out_dir.mkdir(parents=True, exist_ok=True)

    # 1. encoder (HF knows how to reload this)
    model.encoder.save_pretrained(out_dir / "encoder")

    # 2. tokenizer
    AutoTokenizer.from_pretrained(tokenizer).save_pretrained(out_dir)

    # 3. classification head (everything that is not part of the backbone)
    head_only = {k: v for k, v in model.state_dict().items() if not k.startswith("encoder.")}
    torch.save(head_only, out_dir / "head_state.pt")

    # 4. meta data
    (out_dir / "meta.json").write_text(json.dumps({"arch": model.arch, "class_name": model.__class__.__name__}))


def load_model(name: str) -> "Model":
    """Load a model uploaded in the ``save_mdl`` layout from the HF Hub repo ``name``.

    Reads ``meta.json`` to pick the registered subclass, builds it on the saved
    ``encoder/`` directory, then loads ``head_state.pt`` non-strictly. Encoder
    keys are expected to be absent from that file; any other missing or
    unexpected keys are reported. Exceptions from the download or the file reads
    propagate (``Model.load`` treats them as "not an internal model").
    """
    # `local_dir_use_symlinks` was dropped: it only applies together with `local_dir`.
    local_dir = Path(snapshot_download(name, repo_type="model", token=HF_TOKEN))

    meta = json.loads((local_dir / "meta.json").read_text())
    sub_cls = Model.registry[meta["arch"]]

    model = sub_cls(backbone=str(local_dir / "encoder"))
    # head_state.pt holds only tensors (written by save_mdl), so the safe loader suffices.
    head_state = torch.load(local_dir / "head_state.pt", map_location="cpu", weights_only=True)
    result = model.load_state_dict(head_state, strict=False)

    missing_head = [k for k in result.missing_keys if not k.startswith("encoder.")]
    if missing_head or result.unexpected_keys:
        log.info(f"head_state.pt mismatch - missing: {missing_head}, unexpected: {list(result.unexpected_keys)}")

    log.info(f"loading internal model ({meta['arch']})")
    return model
