import os
import random
from typing import Any, Literal, TypeAlias, cast

import torch
from datasets import Dataset as HFDataset
from datasets import concatenate_datasets, load_dataset
from torch import Tensor
from torch.utils.data import DataLoader, Dataset
from transformers.data.data_collator import DataCollatorWithPadding
from transformers.models.auto.tokenization_auto import AutoTokenizer
from transformers.tokenization_utils_base import PreTrainedTokenizerBase

from clcp import HF_TOKEN
from ml_utils import log

# Allowed values for the ``split`` argument of ``Data`` / ``build_dl``.
Splits: TypeAlias = Literal["train", "test"]

# Default number of blocks ``sample_ds`` draws from a large clf test set
# (``Data`` lets the ``K`` environment variable override it).
K = 1000

# Every classification dataset name. ``Data`` loads each one from the hub as
# ``aarabil/clcp_<name>``, and ``sample_ds`` only block-samples names listed here.
CLF_DSS = {
    "agnews",
    "amazonpolarity",
    "appreviews",
    "banking77",
    "biasframes_intent",
    "biasframes_offensive",
    "biasframes_sex",
    "capsotu",
    "emocontext",
    "emotiondair",
    "empathetic",
    "financialphrasebank",
    "hateoffensive",
    "hatexplain",
    "imdb",
    "manifesto",
    "massive",
    "rottentomatoes",
    "spam",
    "trueteacher",
    "wellformedquery",
    "wikitoxic_identityhate",
    "wikitoxic_insult",
    "wikitoxic_obscene",
    "wikitoxic_threat",
    "wikitoxic_toxicaggregated",
    "yahootopics",
    "yelpreviews",
}

# Classification datasets kept out of CLF_TRAIN_DSS (presumably the held-out evaluation set).
CLF_TEST_DSS = {
    "agnews",
    "capsotu",
    "yahootopics",
    "wikitoxic_insult",
    "financialphrasebank",
    "empathetic",
    "yelpreviews",
    "biasframes_intent",
    "appreviews",
}

# Clf training datasets: CLF_DSS minus CLF_TEST_DSS, additionally dropping every name
# containing "wiki" or "bias" (i.e. all wikitoxic_* and biasframes_* variants).
# NOTE(review): entries of CLF_BAD_DSS (e.g. "spam", "massive") are NOT excluded here —
# confirm whether callers filter them out separately.
CLF_TRAIN_DSS = {ds for ds in CLF_DSS if ds not in CLF_TEST_DSS and "wiki" not in ds and "bias" not in ds}

# NLI dataset names, presumably used for evaluation.
NLI_TEST_DSS = {
    "mnli_m",
    "mnli_mm",
    "anli_r1",
    "anli_r2",
    "anli_r3",
    "wanli",
    "fevernli",
    "lingnli",
}


# Classification datasets flagged as unsuitable, with the reason noted per entry.
CLF_BAD_DSS = {
    "hateoffensive",  # hierarchical labelling problem
    "hatexplain",  # hierarchical labelling problem
    "emocontext",  # texts are ambiguous
    "emotiondair",  # texts are ambiguous
    "wikitoxic_identityhate",  # labels are too obscure
    "wellformedquery",  # no semantics to capture a well-formed query
    "spam",  # no semantics to capture spam vs. no spam
    # -- large datasets
    "banking77",
    "manifesto",
    "massive",
}


def get_tokenizer(
    mdl_name: str = "answerdotai/ModernBERT-base",
) -> PreTrainedTokenizerBase:
    """Load the HF tokenizer for ``mdl_name`` with a model-appropriate padding side.

    Model names containing "mistral" or "qwen3" (case-insensitive) are padded on the
    left; every other model is padded on the right. The chosen side is emitted via
    ``log.info`` before loading.
    """
    lowered_name = mdl_name.lower()
    padding_side = "right"
    for family in ("mistral", "qwen3"):
        if family in lowered_name:
            padding_side = "left"
            break
    log.info(f"{padding_side=}")
    return AutoTokenizer.from_pretrained(mdl_name, token=HF_TOKEN, padding_side=padding_side)


def get_n_classes(ds) -> int:
    """Return the number of rows that make up one text's block in a clf dataset.

    Clf datasets are assumed to be laid out as contiguous blocks, one block per text
    with one row per candidate label. ``sample_ds`` relies on the same layout when it
    selects ``range(idx, idx + n_classes)``. The block size is read off the leading
    run of rows that share the first row's text.

    Args:
        ds: Anything exposing an iterable ``"text"`` column via ``ds["text"]``
            (HF ``Dataset``, pandas ``DataFrame``, or a plain ``dict`` of lists).

    Returns:
        Length of the first block, as a plain Python ``int``.

    Raises:
        ValueError: If ``ds`` has no rows.
    """
    # Only the "text" column is read; the previous to_pandas() call materialised every column.
    texts = iter(ds["text"])
    try:
        first_text = next(texts)
    except StopIteration:
        raise ValueError("Cannot derive n_classes from an empty dataset") from None

    # Count the leading run rather than grouping by text. groupby("text") sorts the keys
    # and measured the lexicographically smallest text, which over-counts as soon as that
    # text occurs in more than one block.
    n_classes = 1
    for text in texts:
        if text != first_text:
            break
        n_classes += 1
    return n_classes


def sample_ds(name: str, ds, k: int = K, min_ds_len: int = 5_000, *, is_test: bool = False) -> HFDataset:
    """Subsample a test dataset so that evaluation stays tractable.

    Non-clf datasets (``name`` not in ``CLF_DSS``) are capped at their first 1,000 rows
    when ``is_test`` is set and returned unchanged otherwise. Clf datasets with fewer
    than ``min_ds_len`` rows are returned unchanged. Larger clf datasets are treated as
    contiguous blocks of ``n_classes`` rows (see ``get_n_classes``). From those, ``k``
    whole blocks are drawn with a fixed seed, so the sample is reproducible.

    Args:
        name: Dataset name, checked against ``CLF_DSS``.
        ds: HF dataset to sample from.
        k: Number of blocks to keep for large clf datasets. If ``k`` is at least the
            number of complete blocks, all complete blocks are kept in order.
        min_ds_len: Clf datasets with fewer rows are not sampled.
        is_test: Quick-test mode for non-clf datasets (1,000-row cap).

    Returns:
        The (possibly) subsampled dataset. For sampled clf datasets the first two
        blocks are always the first two blocks of ``ds``.
    """
    if name not in CLF_DSS:
        return ds.select(range(min(1_000, len(ds)))) if is_test else ds  # type: ignore
    if len(ds) < min_ds_len:
        return ds  # no sampling needed for small clf ds

    n_classes = get_n_classes(ds)
    # Start offsets of *complete* blocks only. A trailing partial block would make
    # ds.select(range(idx, idx + n_classes)) index past the end of ds.
    starts = range(0, len(ds) - n_classes + 1, n_classes)
    if k >= len(starts):
        # random.sample raises ValueError for k > population. When every block is requested,
        # keep all complete blocks in their original order (blocks 0 and 1 stay first).
        return ds.select(range(len(starts) * n_classes))

    rng = random.Random(0)
    indices = rng.sample(starts, k=k)
    # Ensure the first two sampled blocks are the first two of ds, so n_classes can be
    # re-derived downstream as the number of rows before the 2nd positive label.
    # Merely overwriting slots 0/1 left those blocks duplicated whenever the RNG had also
    # drawn them later in the list, so such repeats are dropped.
    head = [0, n_classes]
    indices = head + [idx for idx in indices[2:] if idx not in head]
    return concatenate_datasets([ds.select(range(idx, idx + n_classes)) for idx in indices])


class Data(Dataset):
    """Map-style torch ``Dataset`` over a HF dataset of (anchor, hypothesis) text pairs.

    Items are the raw row dicts of the wrapped HF dataset stored in ``self.data``.
    """

    def __init__(
        self,
        tok: PreTrainedTokenizerBase,
        name: str = "nli",
        split: Splits = "train",
        *,
        is_test: bool = False,
        is_dummy: bool = False,
    ) -> None:
        """(anchor, pos/neg) text pairs.

        Loads ``aarabil/clcp_<name>`` for ``split`` from the HF hub and applies
        dataset-specific quality filters and subsampling. It then rewrites the texts into
        the prompt format expected by some instruction-tuned models, selected from
        ``tok.name_or_path``.

        Args:
            tok: Tokenizer of the target model. It is stored on ``self.tok``; here only its
                ``name_or_path`` is read, to pick the prompt format.
            name: Dataset suffix, e.g. ``"nli"`` or an entry of ``CLF_DSS``.
            split: ``"train"`` or ``"test"``.
            is_test: Quick-test mode. It caps train splits at 1,000 rows and is forwarded
                to ``sample_ds`` for test splits.
            is_dummy: If True, skip the hub and use the 12-row in-memory dataset from
                ``_get_dummy_data``.
        """

        # Make tokenizer accessible through Data class
        self.tok = tok

        if is_dummy:
            self.data = self._get_dummy_data()
            return

        ds_name = f"aarabil/clcp_{name}"
        ds = load_dataset(ds_name, split=split)

        if name == "yahootopics" and split == "test":  # bad quality
            # Keep the first 20 rows as-is, drop every row whose label_text is
            # "Business & Finance", then put the 20 kept rows back in front.
            # NOTE(review): rows among the first 20 that are not "Business & Finance" also
            # survive the filter, so they appear twice in the result — confirm intended.
            first_20 = ds.select(range(20))  # type: ignore
            ds = ds.filter(lambda sample: sample["label_text"] != "Business & Finance")
            ds = concatenate_datasets([first_20, ds])  # type: ignore
        if name == "nli" and split == "train":  # bad quality
            # Drop all rows whose task_name is "mixtral_small_zeroshot".
            ds = ds.filter(lambda sample: sample["task_name"] != "mixtral_small_zeroshot")

        # Test splits go through sample_ds, because large clf test sets are too big to
        # evaluate in full. The K env var, if set, overrides the number of sampled blocks.
        ds = sample_ds(name, ds, k=int(os.getenv("K", K)), is_test=is_test) if split == "test" else ds

        # Sampling for small-scale quick tests: keep at most the first 1,000 train rows
        ds = ds.select(range(min(1_000, len(ds)))) if split == "train" and is_test else ds  # type: ignore

        # Qwen3-Embedding instruct
        # The instruction depends on the task family: any name containing "nli" uses the NLI one.
        inst_nli = "Given a piece of text, retrieve the passage that entails the text the best"
        inst_clf = "Given a piece of text, retrieve relevant label descriptions that best match the text"
        inst = inst_nli if "nli" in name else inst_clf

        # Qwen3-Embedding / e5-mistral: only the anchor ("text") is prefixed with the instruction.
        if any(x in tok.name_or_path for x in ["Qwen3-Embedding", "e5-mistral"]):
            ds = ds.map(lambda sample: {"text": f"Instruct: {inst}\nQuery: {sample['text']}"})

        # Qwen3-Reranker instruct
        # Wrap each pair in the reranker chat template. "text" gets the system prompt, the
        # instruction and the query; "hypothesis" gets the document followed by the tokens
        # that open the assistant turn.
        if "Qwen3-Reranker" in tok.name_or_path:
            prefix = '<|im_start|>system\nJudge whether the Document meets the requirements based on the Query and the Instruct provided. Note that the answer can only be "yes" or "no".<|im_end|>\n<|im_start|>user\n'
            suffix = "<|im_end|>\n<|im_start|>assistant\n<think>\n\n</think>\n\n"
            ds = ds.map(
                lambda sample: {
                    "text": f"{prefix}<Instruct>: {inst}\n<Query>: {sample['text']}",
                    "hypothesis": f"\n<Document>: {sample['hypothesis']}{suffix}",
                }
            )

        log.debug(f"Returning dataset {ds_name} | {split} | {len(ds)} samples")  # type: ignore
        self.data = cast(HFDataset, ds)

    @staticmethod
    def _get_dummy_data() -> HFDataset:
        """Build a 12-row in-memory dataset for smoke tests.

        Every row has the same anchor text. Hypotheses alternate between "Positive text." and
        "Negative text.", with labels 1 and 0 respectively.
        """
        size = 12
        text, hypotheses = "Anchoring text.", "{} text."
        docs_a = [text] * size
        label_text = ["Positive", "Negative"] * (size // 2)
        docs_b = [hypotheses.format(label) for label in label_text]
        labels = [1 if "Positive" in doc_b else 0 for doc_b in docs_b]
        task_name = ["test"] * size

        log.debug(f"Returning test data of size: {size} samples")
        # Column order matters: build_dl's collate functions unpack the first three values
        # of each row positionally as (text, hypothesis, labels).
        return HFDataset.from_dict({
            "text": docs_a,
            "hypothesis": docs_b,
            "labels": labels,
            "task_name": task_name,
            "label_text": label_text,
        })

    def __len__(self) -> int:
        """Number of rows in the wrapped dataset."""
        return len(self.data)

    def __getitem__(self, idx: int) -> dict[str, Any]:
        """Return row ``idx`` of the wrapped dataset as a column-name -> value dict."""
        return self.data[idx]


def build_dl(
    mdl_name: str,
    name: str,
    split: Splits,
    batch_size: int,
    *,
    paired: bool,
    is_test: bool,
    is_dummy: bool,
) -> DataLoader:
    """Create a DataLoader over ``Data`` that tokenizes and pads every batch on the fly.

    Args:
        mdl_name: Model id used to load the tokenizer (via ``get_tokenizer``).
        name: Dataset name forwarded to ``Data``.
        split: ``"train"`` (shuffled) or ``"test"`` (fixed order).
        batch_size: Rows per batch.
        paired: If True, each row's text and hypothesis are encoded together as one
            sequence pair. Otherwise all texts followed by all hypotheses are encoded as
            ``2 * batch_size`` independent sequences.
        is_test: Forwarded to ``Data``.
        is_dummy: Forwarded to ``Data``; also lowers the max sequence length to 12.

    Returns:
        A DataLoader yielding ``(xb, yb)``: the padded tokenizer outputs and the labels
        as a float tensor.
    """
    tok = get_tokenizer(mdl_name=mdl_name)
    # Pad each batch to its longest sequence, rounded up to a multiple of 8 for CUDA efficiency.
    pad_batch = DataCollatorWithPadding(tokenizer=tok, padding="longest", pad_to_multiple_of=8)
    if is_dummy:
        max_len = 12
    else:
        max_len = 512

    def collate(batch: list[dict[str, Any]]) -> tuple[dict[str, Tensor], Tensor]:
        # Transpose rows into columns. The first three columns are taken positionally as
        # (anchor text, hypothesis, label); any further columns are ignored.
        anchors, hypotheses, labels, *_ = zip(*(row.values() for row in batch), strict=True)
        # truncation=True applies the tokenizer's default "longest_first" strategy.
        if paired:
            encoded = tok(
                anchors,  # type: ignore
                hypotheses,  # type: ignore
                max_length=max_len,
                truncation=True,
                return_token_type_ids=False,
            )
        else:
            encoded = tok(
                anchors + hypotheses,  # type: ignore
                max_length=max_len,
                truncation=True,
                return_token_type_ids=False,
            )
        # Jagged token-id lists -> padded pt tensors.
        xb: dict[str, Tensor] = pad_batch(encoded)  # type: ignore
        yb = torch.tensor(labels, dtype=torch.float)
        return xb, yb

    dataset = Data(tok=tok, name=name, split=split, is_test=is_test, is_dummy=is_dummy)
    has_cuda = torch.cuda.is_available()
    return DataLoader(
        dataset=dataset,
        batch_size=batch_size,
        shuffle=split == "train",
        collate_fn=collate,
        pin_memory=has_cuda,
        num_workers=4 if has_cuda else 0,
    )


def dl_factory(
    mdl_name: str,
    name: str,
    batch_size: int = 4,
    *,
    is_test: bool = False,
    is_dummy: bool = False,
):
    """Bind every ``build_dl`` argument except the split and pairing mode.

    The returned callable has signature ``(split, *, paired) -> DataLoader``. Each call
    invokes ``build_dl`` afresh, which loads a new tokenizer, ``Data`` and DataLoader.
    """
    bound_kwargs = {
        "mdl_name": mdl_name,
        "name": name,
        "batch_size": batch_size,
        "is_test": is_test,
        "is_dummy": is_dummy,
    }

    def factory(split: Splits, *, paired: bool) -> DataLoader:
        return build_dl(split=split, paired=paired, **bound_kwargs)

    return factory
