# %%
import os
import pathlib
import shutil
from collections import defaultdict

import torch
from huggingface_hub import constants
from retry import retry

from clcp import Model
from clcp.data import CLF_BAD_DSS, CLF_DSS, build_dl
from clcp.metrics import BinaryMetrics, MultiClassMetrics
from clcp.models.base import TestDL
from ml_utils import init_mlflow, log

# Root directory of the local Hugging Face Hub cache, as resolved by
# huggingface_hub; purge() looks for per-model folders directly under it.
CACHE = pathlib.Path(constants.HF_HUB_CACHE)


# %%

# Candidate models, grouped into one list per kind. Entries that are commented
# out are skipped; uncomment an entry to include it in the run.
# NOTE(review): the strings look like Hugging Face Hub repo ids ("org/name");
# purge() maps them to cache-folder names on that assumption — confirm against
# Model.load.
base_mdls = [
    # "answerdotai/ModernBERT-large",
    # "aarabil/deberta-v3-large",
    # "aarabil/bert-large-uncased",
]


nli_mdls = [
    # "aarabil/bart-large-mnli",
    # "aarabil/nli-roberta-base",
    # "aarabil/bert-base-uncased-nli",
    # "aarabil/bert-large-uncased-nli",
    # "aarabil/bert-large-uncased-nli-triplet",
    # "aarabil/deberta-v3-base-nli",
    # "aarabil/deberta-v3-large-nli",
    # "aarabil/deberta-v3-large-nli-triplet",
    # "aarabil/modernbert-base-nli",
    # "aarabil/modernbert-large-nli",
    # "aarabil/modernbert-large-nli-triplet",
]


rerank_mdls = [
    # "aarabil/ms-marco-MiniLM-L6-v2",
    # "Alibaba-NLP/gte-reranker-modernbert-base",
    # "aarabil/bge-reranker-base",
    # "aarabil/bge-reranker-large",
    # "aarabil/Qwen3-Reranker-0.6B",
    # "aarabil/Qwen3-Reranker-8B",
]

emb_mdls = [
    # "aarabil/all-MiniLM-L6-v2",
    # "aarabil/Qwen3-Embedding-0.6B",
    # "aarabil/Qwen3-Embedding-8B",
    # "aarabil/bge-base-en-v1.5",
    # "aarabil/bge-large-en-v1.5",
    # "Alibaba-NLP/gte-base-en-v1.5",
    # "Alibaba-NLP/gte-large-en-v1.5",
    # "Alibaba-NLP/gte-modernbert-base",
    # "aarabil/e5-base-v2",
    # "aarabil/e5-large-v2",
    "intfloat/e5-mistral-7b-instruct",
]


# Every active entry, concatenated in group order; this is the list the
# evaluation loop at the bottom of the file iterates over.
models = base_mdls + nli_mdls + rerank_mdls + emb_mdls


# Names of the datasets every model is evaluated on. get_test_dls() passes this
# list to build_test_dls(), which builds one "test"-split dataloader per entry.
test_dss = [
    "mnli_m",
    "mnli_mm",
    "anli_r1",
    "anli_r2",
    "anli_r3",
    "wanli",
    "fevernli",
    "lingnli",
    "amazonpolarity",
    "imdb",
    "appreviews",
    "yelpreviews",
    "rottentomatoes",
    "emotiondair",
    "empathetic",
    "financialphrasebank",
    "banking77",
    "massive",
    "wikitoxic_toxicaggregated",
    "wikitoxic_obscene",
    "wikitoxic_threat",
    "wikitoxic_insult",
    "biasframes_offensive",
    "biasframes_sex",
    "biasframes_intent",
    "agnews",
    "yahootopics",
    "trueteacher",
    "manifesto",
    "capsotu",
]


def build_test_dls(
    names: list[str],
    mdl_name: str,
    batch_size: int,
    *,
    paired_data: bool,
    is_test: bool,
    is_dummy: bool,
) -> list[TestDL]:
    """Build one ``TestDL`` per dataset name, in the order given.

    For each name a "test"-split dataloader is created with ``build_dl`` and
    paired with a metrics object: datasets listed in ``CLF_DSS`` get
    ``MultiClassMetrics`` (f1, acc), all others get ``BinaryMetrics`` (roc).

    Args:
        names: Dataset names to build loaders for.
        mdl_name: Model name, forwarded to ``build_dl`` as ``mdl_name``.
        batch_size: Batch size, forwarded to ``build_dl``.
        paired_data: Forwarded to ``build_dl`` as ``paired``.
        is_test: Forwarded unchanged to ``build_dl``.
        is_dummy: Forwarded unchanged to ``build_dl``.

    Returns:
        A list with one ``TestDL`` per entry of *names*.
    """

    def _test_dl_for(ds_name: str) -> TestDL:
        # The loader is built first, then the metrics, then the wrapper.
        loader = build_dl(
            mdl_name=mdl_name,
            name=ds_name,
            split="test",
            batch_size=batch_size,
            paired=paired_data,
            is_test=is_test,
            is_dummy=is_dummy,
        )
        if ds_name in CLF_DSS:
            scorer = MultiClassMetrics(metrics=["f1", "acc"])
        else:
            scorer = BinaryMetrics(metrics=["roc"])
        return TestDL(name=ds_name, dl=loader, metrics=scorer)

    return [_test_dl_for(ds_name) for ds_name in names]


def get_test_dls(mdl_name, paired_data):
    """Build the test dataloaders for every dataset in ``test_dss``.

    Thin wrapper around ``build_test_dls`` with this script's fixed settings:
    batch size 64, ``is_test=False`` and ``is_dummy=False``.

    Args:
        mdl_name: Model name, forwarded to ``build_test_dls``.
        paired_data: Forwarded to ``build_test_dls`` as ``paired_data``.

    Returns:
        Whatever ``build_test_dls`` returns for ``test_dss``.
    """
    fixed_settings = {"is_test": False, "is_dummy": False}
    return build_test_dls(test_dss, mdl_name, 64, paired_data=paired_data, **fixed_settings)


def _dir_size(path: pathlib.Path) -> int:
    """Return the total size in bytes of the regular files under *path*.

    The tree is walked recursively with ``Path.rglob("*")``. Symbolic links
    are skipped: ``is_file()`` and ``stat()`` both follow links, so a file
    reachable directly and through a link would otherwise be counted twice
    (e.g. in the Hugging Face Hub cache, whose snapshot entries are normally
    symlinks into ``blobs/``).

    Args:
        path: Directory to measure.

    Returns:
        Sum of ``st_size`` over every non-symlink regular file found; ``0``
        when there are none.
    """
    return sum(
        entry.stat().st_size
        for entry in path.rglob("*")
        if entry.is_file() and not entry.is_symlink()
    )


def _human(n_bytes: int) -> str:
    """Format a byte count with binary (base-1024) units and one decimal.

    The value is divided by 1024 until it drops below 1024 or the largest
    unit (TiB) is reached, e.g. ``1536`` -> ``"1.5 KiB"``.
    """
    size = n_bytes
    *smaller_units, largest_unit = ("B", "KiB", "MiB", "GiB", "TiB")
    for unit in smaller_units:
        if size < 1024:
            return f"{size:.1f} {unit}"
        size /= 1024
    # Anything still >= 1024 here stays in the largest unit.
    return f"{size:.1f} {largest_unit}"


def purge(model_id: str) -> None:
    """Delete the cache folder of *model_id*, printing its path and size first.

    The folder is ``models--<model_id with "/" replaced by "-->"`` directly
    under ``CACHE``. Nothing happens when that folder does not exist.
    Removal is best-effort: errors raised while deleting are ignored.

    Args:
        model_id: Model identifier, e.g. ``"org/name"``.
    """
    # Exact-name lookup rather than a prefix glob: a pattern such as
    # "models--org--bert-large*" would also match (and delete) the cache of
    # "org/bert-large-nli".
    folder = CACHE / f"models--{model_id.replace('/', '--')}"
    if not folder.is_dir():
        return

    size = _dir_size(folder)
    print(f"🗑️  {folder}  ({_human(size)})")
    shutil.rmtree(folder, ignore_errors=True)


# %%
@retry(tries=3, delay=60, backoff=2)
def eval_model(mdl_name: str) -> None:
    """Evaluate one model on all test datasets, then release its resources.

    Loads the model via ``Model.load`` and calls ``.eval()`` on it, builds the
    test dataloaders, and runs ``full_eval_step`` at ``step=0``. Metrics are
    prefixed with the second "/"-separated segment of *mdl_name*
    ("org/name" -> "name"); an id without "/" is used whole.

    After a successful evaluation the model reference is dropped, the CUDA
    cache is emptied and the model's cache folder is purged. This cleanup is
    skipped when an exception occurs, so the cache folder is not purged
    before a retry.

    Args:
        mdl_name: Model identifier passed to ``Model.load``.

    Raises:
        Exception: Any error is logged and re-raised so the ``retry``
            decorator (tries=3, delay=60, backoff=2) can run the function
            again.
    """
    try:
        # Fall back to the whole id when there is no "/" instead of raising
        # IndexError; ids that contain "/" resolve exactly as before.
        segments = mdl_name.split("/")
        metric_prefix = segments[1] if len(segments) > 1 else segments[0]

        mdl = Model.load(name=mdl_name).eval()
        te_dls = get_test_dls(mdl_name=mdl_name, paired_data=mdl.requires_paired_inp)
        mdl.full_eval_step(te_dls=te_dls, step=0, prefix_metric=metric_prefix)

        # Drop the model before emptying the CUDA cache and deleting its files.
        del mdl
        torch.cuda.empty_cache()
        purge(mdl_name)

    except Exception as exc:
        log.exception(f"Exception while evaluating '{mdl_name}': {exc}")
        raise


# Script entry: call init_mlflow() once (imported from ml_utils — presumably
# sets up MLflow tracking; confirm there), then evaluate the selected models
# one after another.
# NOTE: there is no error handling around the loop, so an exception that
# escapes eval_model stops the run and the remaining models are skipped.
init_mlflow()
for mdl_name in models:
    eval_model(mdl_name)
