# %%
import contextlib
import pathlib
import shutil
import time

import mlflow
import torch
from huggingface_hub import constants
from retry import retry
from torch import Tensor

from clcp import Model
from clcp.data import build_dl
from clcp.models.pretrained import DualEncoder
from ml_utils import init_mlflow, log

CACHE = pathlib.Path(constants.HF_HUB_CACHE)


# %% DualEncoderTimer: Subclass DualEncoder and adjust forward to only encode labels once
class DualEncoderTimer(DualEncoder, arch="dual-encoder"):
    """
    Inference-only dual-encoder with an on-the-fly label cache.

    The first time a label description is seen, the embedding returned by
    ``get_embedding`` is stored; subsequent passes just reuse the cached
    tensor. Documents that repeat inside one batch are encoded only once.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # maps   bytes(token_ids_without_pad)  →  embedding Tensor (E,)
        # Only finished embeddings are ever stored here (never placeholders).
        self._label_cache: dict[bytes, torch.Tensor] = {}
        # Padding token id, falling back to EOS when the encoder config has
        # no pad token. `forward` no longer derives masks from it (it uses the
        # caller's attention mask), but the attribute is kept for any code
        # that reads it.
        self.pad_id = (
            self.encoder.config.pad_token_id
            if self.encoder.config.pad_token_id is not None
            else self.encoder.config.eos_token_id
        )

    @staticmethod
    def _make_key(ids_row: Tensor, mask_row: Tensor) -> bytes:
        """
        Return a hashable key built from the un-masked token ids of one row.

        Tokens are selected with the attention mask itself, so the key is the
        same whether the row is padded on the right or on the left. For a
        right-padded row this yields exactly the bytes of ``ids_row[:valid_len]``.
        """
        keep = mask_row.bool()
        return ids_row[keep].cpu().numpy().tobytes()  # padding-free

    @torch.no_grad()
    def forward(self, input_ids: Tensor, attention_mask: Tensor) -> Tensor:
        """
        Score each (document, label) pair by dot product of their embeddings.

        Args:
            input_ids: token ids; the first half of the rows are documents and
                the second half the matching labels (indexing below assumes a
                2-D ``(2B, L)`` layout).
            attention_mask: mask aligned row-for-row with ``input_ids``.

        Returns:
            One similarity score per pair, shape ``(B,)``.

        Raises:
            AssertionError: if the number of rows is odd.
        """
        B2 = input_ids.size(0)  # noqa: N806
        assert B2 % 2 == 0, "batch = docs + labels (concatenated)"  # noqa: S101
        B = B2 // 2  # noqa: N806

        # split docs / labels
        doc_ids, lbl_ids = input_ids.split(B, dim=0)
        doc_mask, lbl_mask = attention_mask.split(B, dim=0)

        # ----- docs: de-duplicate docs inside the batch -----
        # De-duplicate on ids *and* mask together so each unique row keeps the
        # caller's real attention mask. Rebuilding the mask from `pad_id` would
        # also mask genuine EOS tokens whenever pad_id falls back to EOS.
        seq_len = doc_ids.size(1)
        packed = torch.cat([doc_ids, doc_mask.to(doc_ids.dtype)], dim=1)
        uniq_rows, inv = torch.unique(packed, return_inverse=True, dim=0)
        uniq_doc_ids = uniq_rows[:, :seq_len].contiguous()
        uniq_doc_mask = uniq_rows[:, seq_len:].to(doc_mask.dtype).contiguous()
        doc_emb_unique = self.get_embedding(uniq_doc_ids, uniq_doc_mask)  # (U, E)
        doc_emb = doc_emb_unique[inv]  # (B, E)

        # ----- labels: cache with padding-invariant keys -----
        # Move the label tensors to the host once instead of once per row.
        ids_host = lbl_ids.cpu()
        mask_host = lbl_mask.cpu()
        keys: list[bytes] = [
            self._make_key(ids_host[i], mask_host[i]) for i in range(B)
        ]

        # key -> first row index carrying it, for labels not cached yet.
        # Kept local so that a failing `get_embedding` call leaves the
        # persistent cache untouched (no half-initialised entries).
        pending: dict[bytes, int] = {}
        for row, k in enumerate(keys):
            if k not in self._label_cache:
                pending.setdefault(k, row)

        if pending:  # encode only the truly new ones
            log.info("caching labels")
            miss_rows = list(pending.values())
            new_emb = self.get_embedding(lbl_ids[miss_rows], lbl_mask[miss_rows])  # (M, E)

            # dicts preserve insertion order, so keys line up with new_emb rows
            for k, emb in zip(pending, new_emb):
                self._label_cache[k] = emb

        # gather cached embeddings in original label order
        lbl_emb = torch.stack([self._label_cache[k] for k in keys])  # (B, E)

        # dot-product similarity
        return (doc_emb * lbl_emb).sum(dim=1)  # (B,)


# %% Constants

# Model identifiers on the HuggingFace Hub, grouped by family.
# Each entry is "<organisation>/<repository>".

# --- base checkpoints ---
base_mdls = [
    "answerdotai/ModernBERT-large",
    "aarabil/deberta-v3-large",
    "aarabil/bert-large-uncased",
]

# --- NLI checkpoints ---
nli_mdls = [
    "aarabil/bart-large-mnli",
    "aarabil/nli-roberta-base",
    "aarabil/bert-base-uncased-nli",
    "aarabil/bert-large-uncased-nli",
    "aarabil/bert-large-uncased-nli-triplet",
    "aarabil/deberta-v3-base-nli",
    "aarabil/deberta-v3-large-nli",
    "aarabil/deberta-v3-large-nli-triplet",
    "aarabil/modernbert-base-nli",
    "aarabil/modernbert-large-nli",
    "aarabil/modernbert-large-nli-triplet",
]

# --- reranker checkpoints ---
rerank_mdls = [
    "aarabil/ms-marco-MiniLM-L6-v2",
    "Alibaba-NLP/gte-reranker-modernbert-base",
    "aarabil/bge-reranker-base",
    "aarabil/bge-reranker-large",
    "aarabil/Qwen3-Reranker-0.6B",
    "aarabil/Qwen3-Reranker-8B",
]

# --- embedding checkpoints ---
emb_mdls = [
    "aarabil/all-MiniLM-L6-v2",
    "aarabil/Qwen3-Embedding-0.6B",
    "aarabil/Qwen3-Embedding-8B",
    "aarabil/bge-base-en-v1.5",
    "aarabil/bge-large-en-v1.5",
    "Alibaba-NLP/gte-base-en-v1.5",
    "Alibaba-NLP/gte-large-en-v1.5",
    "Alibaba-NLP/gte-modernbert-base",
    "aarabil/e5-base-v2",
    "aarabil/e5-large-v2",
    "intfloat/e5-mistral-7b-instruct",
]

# Full benchmark list, in evaluation order: base → NLI → rerankers → embedders.
models = [*base_mdls, *nli_mdls, *rerank_mdls, *emb_mdls]


# %% Data
def get_dl(mdl_name, paired_data):
    """
    Build the data loader used for the latency benchmark.

    Forwards to ``build_dl`` with a fixed configuration: the ``test`` split of
    the ``latency_eval`` dataset, batches of 20, and the test/dummy switches
    turned off.

    Args:
        mdl_name: model identifier, passed through as ``mdl_name``.
        paired_data: passed through as ``paired``.

    Returns:
        Whatever ``build_dl`` returns for this configuration.
    """
    loader_cfg = {
        "mdl_name": mdl_name,
        "name": "latency_eval",
        "split": "test",
        "batch_size": 20,
        "paired": paired_data,
        "is_test": False,
        "is_dummy": False,
    }
    return build_dl(**loader_cfg)


def _dir_size(path: pathlib.Path) -> int:
    """
    Return the total size in bytes of the regular files under *path*.

    The tree is walked recursively with ``rglob``. Symbolic links are skipped:
    the HuggingFace cache stores each file once under ``blobs/`` and exposes it
    again through symlinks under ``snapshots/``, so following links would count
    every blob twice. An empty directory yields ``0``.
    """
    return sum(
        entry.stat().st_size
        for entry in path.rglob("*")
        # is_file() follows links, so exclude symlinks explicitly
        if entry.is_file() and not entry.is_symlink()
    )


def _human(n_bytes: int) -> str:
    """
    Format a byte count as a human-readable size using 1024-based units.

    The value is divided by 1024 until it drops below 1024 or the largest
    unit (TiB) is reached, then rendered with one decimal place.
    """
    value = n_bytes
    for unit in ("B", "KiB", "MiB", "GiB"):
        if value < 1024:
            return f"{value:.1f} {unit}"
        value /= 1024
    # Anything that is still >= 1024 GiB is reported in TiB, however large.
    return f"{value:.1f} TiB"


def purge(model_id: str) -> None:
    """
    Delete the hub cache folder of *model_id*, printing its path + size first.

    The folder name is ``models--<org>--<name>``. It is matched exactly, not
    by prefix, so purging ``org/bert-large-uncased`` leaves the caches of
    ``org/bert-large-uncased-nli`` and other models sharing that prefix alone.
    Does nothing when the model is not cached. Deletion is best-effort:
    errors raised while removing files are ignored.
    """
    target = CACHE / f"models--{model_id.replace('/', '--')}"
    if not target.is_dir():
        return

    size = _dir_size(target)
    print(f"🗑️  {target}  ({_human(size)})")
    shutil.rmtree(target, ignore_errors=True)


# %% Run


@contextlib.contextmanager
def wall_timer(mdl_name: str):
    """
    Time the enclosed block with a wall clock; CPU + GPU aware.

    When CUDA is available the device is synchronised before the clock starts
    and again before it stops, so queued kernels are included in the
    measurement. On CPU-only machines the synchronisation is skipped, because
    ``torch.cuda.synchronize()`` raises there.

    The elapsed time in seconds is logged and recorded as an MLflow metric
    whose key is *mdl_name*. If the enclosed block raises, nothing is logged
    and the exception propagates.
    """
    use_cuda = torch.cuda.is_available()

    if use_cuda:
        torch.cuda.synchronize()  # make sure all preceding kernels finish
    start = time.perf_counter()
    yield
    if use_cuda:
        torch.cuda.synchronize()  # wait for the timed block's kernels
    elapsed = time.perf_counter() - start  # seconds

    log.info(f"{mdl_name}: {elapsed:.1f} s")
    mlflow.log_metric(key=mdl_name, value=elapsed)


@retry(tries=3, delay=60, backoff=2)
def eval_model(mdl_name: str) -> None:
    """
    Time one full inference pass of *mdl_name* over the latency dataset.

    Steps: load the model in eval mode, build its data loader, run every batch
    through ``forward_pass`` inside ``wall_timer`` (metric key = the part of
    the id after the first ``/``), then release the model, empty the CUDA
    cache and purge the model's hub cache.

    Any exception is logged with its traceback and re-raised, so the ``retry``
    decorator can run the whole evaluation again.
    """
    try:
        mdl = Model.load(name=mdl_name).eval()
        dl = get_dl(mdl_name=mdl_name, paired_data=mdl.requires_paired_inp)

        # Only the repository part of "<org>/<repo>" is used as metric key.
        metric_key = mdl_name.split("/")[1]
        with wall_timer(mdl_name=metric_key):
            with torch.inference_mode():
                for batch_x, batch_y in dl:
                    _ = mdl.forward_pass(batch_x, batch_y, use_amp=False)

        # Free memory and disk space before the next model is loaded.
        del mdl, dl
        torch.cuda.empty_cache()
        purge(mdl_name)

    except Exception as exc:
        log.exception(f"Exception while evaluating '{mdl_name}': {exc}")
        raise


# Driver: runs at import / cell execution time.
# init_mlflow() is called first; presumably it configures the MLflow run that
# `wall_timer` logs its metrics to (confirm in ml_utils).
init_mlflow()
# Benchmark the models one after another, in list order. An exception that
# escapes `eval_model` (i.e. after its retries) propagates and ends the loop.
for mdl_name in models:
    eval_model(mdl_name)
