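"""Data loading utilities for language-model calibration and evaluation.

Provides a single entry point, `build_dataloaders(config, model_config,
tokenizer)`, which returns calibration samples and evaluation sequences for a
HuggingFace causal LM as lists of token-ID tensors (not `DataLoader` objects).

Calibration sources:
  - FineWeb-Edu: random fixed-length windows from individually tokenized
    documents.
  - WikiText-2 (train split): random fixed-length windows from one
    concatenated token stream.
Evaluation:
  - WikiText-2 (test split), chunked into consecutive fixed-length sequences.

Datasets are read through HuggingFace `datasets`, either from the Hub or from
a local copy loaded with `load_from_disk`.

Note: PyTorch is required; importing this module raises a RuntimeError if it
is missing. `datasets` and `torchvision` are imported optionally (a missing
package leaves `load_dataset`, or `T` and `TVDatasets`, set to None). The
torchvision imports are not used by any function in this module.
"""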
from __future__ import annotations

import random
from dataclasses import dataclass
from typing import Any, Iterable, List, Tuple, Optional

from transformers import AutoTokenizer
import constants
from constants import (
    TOKENIZER_INPUT_IDS, TOKENIZER_ATTENTION_MASK,
    TOKENIZER_RETURN_TENSORS, TOKENIZER_ADD_SPECIAL_TOKENS
)
from run_config import RunConfig

try:
    import torch
    from torch.utils.data import DataLoader, Dataset
except Exception as e:  # pragma: no cover
    raise RuntimeError("PyTorch is required for data loading.") from e

# Optional imports guarded for language datasets
try:  # pragma: no cover - optional dependency
    from datasets import load_dataset, load_from_disk  # type: ignore
except Exception:  # pragma: no cover
    load_dataset = None  # type: ignore

try:  # pragma: no cover - optional dependency
    from transformers import PreTrainedTokenizerBase  # type: ignore
except Exception:  # pragma: no cover
    PreTrainedTokenizerBase = Any  # type: ignore

# Vision imports
try:  # pragma: no cover - optional dependency
    import torchvision.transforms as T  # type: ignore
    import torchvision.datasets as TVDatasets  # type: ignore
except Exception:
    T = None  # type: ignore
    TVDatasets = None  # type: ignore


@dataclass(slots=True)
class SimpleTextDataset(Dataset):
    """A minimal text dataset wrapper around a list of token ID tensors."""
    samples: List[torch.Tensor]

    def __len__(self) -> int:  # pragma: no cover - trivial
        return len(self.samples)

    def __getitem__(self, idx: int) -> torch.Tensor:  # pragma: no cover - trivial
        return self.samples[idx]


def _build_packed_wikitext_calibration(
    texts,
    tokenizer,
    seqlen: int,
    num_samples: int,
):
    """Pack a corpus into consecutive, non-overlapping calibration blocks.

    The texts are joined with blank lines, tokenized in a single call, and the
    resulting token stream is cut into back-to-back windows of ``seqlen``
    tokens starting at offset 0. Trailing tokens that do not fill a whole
    window are dropped.

    Returns:
        A list of at most ``num_samples`` tensors, each ``[1, seqlen]``
        (assuming the tokenizer returns a single ``[1, N]`` sequence).
    """
    encoded = tokenizer("\n\n".join(texts), **{TOKENIZER_RETURN_TENSORS: "pt"})
    stream = encoded[TOKENIZER_INPUT_IDS][0]

    # Never produce more blocks than the stream can fill completely.
    count = min(num_samples, stream.numel() // seqlen)
    return [
        stream[offset : offset + seqlen].unsqueeze(0)
        for offset in range(0, count * seqlen, seqlen)
    ]

def _tokenize_texts(texts: Iterable[str], tokenizer: PreTrainedTokenizerBase, max_length: int = 1024) -> List[torch.Tensor]:
    """Encode each non-empty text on its own into a tensor of token IDs.

    Falsy entries (e.g. empty strings) are skipped. Every remaining text is
    truncated to at most ``max_length`` tokens; no padding or batching is
    applied. The leading batch dimension is squeezed away (presumably leaving
    a 1-D tensor, assuming a ``[1, n]`` encoding - TODO confirm).
    """
    def encode(text):
        batch = tokenizer(
            text,
            truncation=True,
            max_length=max_length,
            **{TOKENIZER_RETURN_TENSORS: "pt"},
        )
        return batch[TOKENIZER_INPUT_IDS].squeeze(0)

    return [encode(text) for text in texts if text]


def get_fineweb_edu(
    tokenizer: AutoTokenizer,
    use_dataset_path: bool,
    dataset_path: Optional[str],
    max_sequence_length: int,
    num_calibration_samples: Optional[int] = None,
    seed: int = 42,
) -> List[torch.Tensor]:
    """Sample fixed-length calibration windows from FineWeb-Edu documents.

    Documents are visited in shuffled order. Each document's ``"text"`` is
    tokenized on its own; documents shorter than ``max_sequence_length``
    tokens are skipped, and from every other document one window of exactly
    ``max_sequence_length`` tokens is cut at a random offset.

    Args:
        tokenizer: HF tokenizer used to encode each document.
        use_dataset_path: If true and ``dataset_path`` is set, load the dataset
            with ``load_from_disk`` instead of streaming it from the Hub.
        dataset_path: Local dataset location (used only with ``use_dataset_path``).
            If it holds a dataset dict, its ``"train"`` split is used.
        max_sequence_length: Tokens per calibration window; must be > 0.
        num_calibration_samples: Number of windows to collect. Required despite
            the ``None`` default (kept for signature compatibility).
        seed: Seeds both the dataset shuffle and the window offsets, so results
            are reproducible independently of the global ``random`` state.

    Returns:
        Up to ``num_calibration_samples`` tensors sliced as
        ``input_ids[:, i:i + max_sequence_length]``. Fewer are returned if the
        dataset runs out of long-enough documents.

    Raises:
        RuntimeError: If the HuggingFace ``datasets`` package is not installed.
        ValueError: If ``num_calibration_samples`` is None or
            ``max_sequence_length`` is not positive.
    """
    # load_dataset and load_from_disk are imported together; None means the import failed.
    if load_dataset is None:
        raise RuntimeError(
            "The HuggingFace `datasets` package is required for FineWeb-Edu "
            "calibration; install it with `pip install datasets`."
        )
    if num_calibration_samples is None:
        raise ValueError("num_calibration_samples must be provided for FineWeb-Edu calibration.")
    if max_sequence_length <= 0:
        raise ValueError(f"max_sequence_length must be > 0, got {max_sequence_length}")
    if num_calibration_samples <= 0:
        return []

    if use_dataset_path and dataset_path:
        dataset = load_from_disk(dataset_path)
        if isinstance(dataset, dict):  # DatasetDict -> use the train split
            dataset = dataset["train"]
        # Map-style Dataset.shuffle takes no buffer_size argument.
        dataset = dataset.shuffle(seed=seed)
    else:
        dataset = load_dataset("HuggingFaceFW/fineweb-edu", "sample-10BT", split="train", streaming=True)
        # Streaming datasets can only shuffle approximately, within a buffer.
        dataset = dataset.shuffle(seed=seed, buffer_size=1_000)

    rng = random.Random(seed)
    samples: List[torch.Tensor] = []
    for sample in dataset:
        input_ids = tokenizer(sample["text"], **{TOKENIZER_RETURN_TENSORS: "pt"})[TOKENIZER_INPUT_IDS]
        n_tokens = input_ids.shape[1]
        if n_tokens < max_sequence_length:
            continue
        offset = rng.randint(0, n_tokens - max_sequence_length)  # inclusive bounds
        samples.append(input_ids[:, offset:offset + max_sequence_length])
        if len(samples) >= num_calibration_samples:
            break
    return samples


def get_wikitext_for_calibration(
    tokenizer,
    nsamples: int,
    seqlen: int,
    use_dataset_path,
    dataset_path,
    dataset_name,
    seed: int = 0,
    return_dict: bool = False,
    calib_seqlen: int | None = None,
    align_to_bos: bool = True,
):
    """Draw random fixed-length calibration blocks from the WikiText-2 train split.

    The whole train split is joined with blank lines into one text (FP-Quant
    style) and tokenized once. Then ``nsamples`` windows of ``calib_seqlen``
    tokens are cut at random start offsets drawn from a ``torch.Generator``
    seeded with ``seed``. Windows may overlap.

    Args:
        tokenizer: HF tokenizer; its ``bos_token_id`` (if any) enables BOS alignment.
        nsamples: Number of windows to draw.
        seqlen: Default window length, used when ``calib_seqlen`` is None.
        use_dataset_path: Together with a non-empty ``dataset_path``, load the
            dataset with ``load_from_disk`` instead of from the Hub.
        dataset_path: Local dataset location. A dataset dict must contain a
            ``"train"`` split; a single dataset is used as-is.
        dataset_name: Must start with ``constants.WIKITEXT`` when loading from the Hub.
        seed: Seed for the window-offset generator.
        return_dict: If true, each sample is a dict holding the input IDs and an
            all-ones attention mask instead of a bare tensor.
        calib_seqlen: Window length in tokens (defaults to ``seqlen``).
        align_to_bos: Move each start forward to the next BOS token when one
            exists and a full window still fits after it.

    Returns:
        A list of ``nsamples`` blocks, each ``[1, calib_seqlen]`` (assuming the
        tokenizer returns a single ``[1, N]`` sequence), or dicts when
        ``return_dict`` is set.

    Raises:
        ValueError: For a non-positive ``calib_seqlen``, an unsupported dataset
            name, or a token stream shorter than ``calib_seqlen + 1`` tokens.
        RuntimeError: If the HuggingFace ``datasets`` package is not installed.
    """
    if calib_seqlen is None:
        calib_seqlen = seqlen
    if calib_seqlen <= 0:
        raise ValueError(f"calib_seqlen must be > 0, got {calib_seqlen}")

    from_disk = bool(use_dataset_path and dataset_path)
    if not from_disk and not dataset_name.startswith(constants.WIKITEXT):
        raise ValueError(f"Unsupported dataset for calibration: {dataset_name}")
    # load_dataset and load_from_disk are imported together; None means the import failed.
    if load_dataset is None:
        raise RuntimeError(
            "The HuggingFace `datasets` package is required for WikiText "
            "calibration; install it with `pip install datasets`."
        )

    if from_disk:
        ds = load_from_disk(dataset_path)
    else:
        # All accepted aliases map to the canonical raw WikiText-2 dataset.
        ds = load_dataset("Salesforce/wikitext", "wikitext-2-raw-v1")

    # A dataset dict is indexed by split name; a single saved Dataset is used directly.
    train_split = ds["train"] if isinstance(ds, dict) else ds

    # Build one long stream (FP-Quant style)
    texts = [ex.get("text", "") for ex in train_split]
    full_text = "\n\n".join(texts)

    enc = tokenizer(full_text, **{TOKENIZER_RETURN_TENSORS: "pt", TOKENIZER_ADD_SPECIAL_TOKENS: True})
    input_ids_1d = enc[TOKENIZER_INPUT_IDS][0]  # first (only) sequence of the batch

    bos_id = getattr(tokenizer, "bos_token_id", None)
    if bos_id is None:
        align_to_bos = False

    n_tokens = input_ids_1d.numel()
    if n_tokens < calib_seqlen + 1:
        raise ValueError(
            f"WikiText token stream too short for calib_seqlen={calib_seqlen}. "
            f"Got {n_tokens} tokens."
        )

    g = torch.Generator()
    g.manual_seed(seed)

    # Largest start that still leaves a full calib_seqlen window plus one spare
    # token. It must use the window actually cut (calib_seqlen), not seqlen;
    # the length check above guarantees max_start >= 0.
    max_start = n_tokens - calib_seqlen - 1
    samples = []

    for _ in range(nsamples):
        start = torch.randint(0, max_start + 1, (1,), generator=g).item()

        if align_to_bos and bos_id is not None:
            bos_start = _find_next_bos_start(input_ids_1d, start, bos_id)
            # Only move if a full block still fits; otherwise keep the random start.
            if bos_start <= max_start:
                start = bos_start

        block = input_ids_1d[start:start + calib_seqlen].unsqueeze(0)

        if return_dict:
            samples.append(
                {
                    TOKENIZER_INPUT_IDS: block,
                    TOKENIZER_ATTENTION_MASK: torch.ones_like(block),
                }
            )
        else:
            samples.append(block)
    return samples


# Only for evaluation
def get_wikitext2(tokenizer, sequence_length: int, use_dataset_path, dataset_path, dataset_name):
    """Load an evaluation test split and chunk it into fixed-length token sequences.

    The ``"text"`` column of the test split is joined with blank lines and
    tokenized once. The result is cut into consecutive, non-overlapping
    windows of ``sequence_length`` tokens; trailing tokens that do not fill a
    window are dropped.

    Args:
        tokenizer: HF tokenizer used to encode the joined text.
        sequence_length: Tokens per evaluation sequence; must be > 0.
        use_dataset_path: Together with a non-empty ``dataset_path``, load the
            dataset with ``load_from_disk`` instead of from the Hub.
        dataset_path: Local dataset location. A dataset dict must contain a
            ``"test"`` split; a single dataset is used as-is.
        dataset_name: Names starting with ``constants.WIKITEXT`` map to raw
            WikiText-2 on the Hub; any other name is passed to
            ``load_dataset`` unchanged.

    Returns:
        A list of tensors, each ``input_ids[:, i*L:(i+1)*L]`` of the encoded
        text, where ``L`` is ``sequence_length``.

    Raises:
        ValueError: If ``sequence_length`` is not positive, or a dataset dict
            has no ``"test"`` split.
        RuntimeError: If the HuggingFace ``datasets`` package is not installed.
    """
    if sequence_length <= 0:
        raise ValueError(f"sequence_length must be > 0, got {sequence_length}")
    # load_dataset and load_from_disk are imported together; None means the import failed.
    if load_dataset is None:
        raise RuntimeError(
            "The HuggingFace `datasets` package is required for evaluation data; "
            "install it with `pip install datasets`."
        )

    if use_dataset_path and dataset_path:
        raw = load_from_disk(dataset_path)
    elif dataset_name.startswith(constants.WIKITEXT):
        raw = load_dataset("Salesforce/wikitext", "wikitext-2-raw-v1")
    else:
        raw = load_dataset(dataset_name)

    if isinstance(raw, dict):
        if "test" not in raw:
            raise ValueError(
                f"Evaluation dataset has no 'test' split (available: {sorted(raw)})."
            )
        test_split = raw["test"]
    else:
        test_split = raw

    test_ids = tokenizer(
        "\n\n".join(test_split["text"]), **{TOKENIZER_RETURN_TENSORS: "pt"}
    )[TOKENIZER_INPUT_IDS]
    num_test_sequences = test_ids.numel() // sequence_length
    return [
        test_ids[:, i * sequence_length:(i + 1) * sequence_length]
        for i in range(num_test_sequences)
    ]


def _find_next_bos_start(input_ids_1d: torch.Tensor, start: int, bos_id: int) -> int:
    """
    Return an index >= start such that input_ids_1d[idx] == bos_id.
    If none exists, return start (caller can decide what to do).
    """
    if start <= 0:
        start = 0
    if start >= input_ids_1d.numel():
        return input_ids_1d.numel() - 1
    hits = (input_ids_1d[start:] == bos_id).nonzero(as_tuple=False)
    if hits.numel() == 0:
        return start
    return start + int(hits[0].item())


def build_dataloaders(config: RunConfig, model_config, tokenizer: PreTrainedTokenizerBase | None = None) -> Tuple[Any, Any]:
    """Build calibration samples and evaluation sequences for a causal LM.

    The context length is ``min(2048, model_config.max_position_embeddings)``.

    Calibration data depends on ``config.calibration_dataset``:
      * names containing ``constants.FINEWEB``: random context-length windows
        from FineWeb-Edu documents (``get_fineweb_edu``);
      * anything else: random windows from the WikiText-2 train stream
        (``get_wikitext_for_calibration``). The window length is
        ``config.calib_sequence_length`` (capped at the context length) when
        that attribute is set, otherwise the context length.

    Evaluation data is always the WikiText-2 test split chunked into
    context-length sequences (``get_wikitext2``).

    Args:
        config: Run configuration. Reads ``calibration_dataset``,
            ``use_dataset_path``, ``n_samples`` and ``seed``, plus
            ``dataset_path`` and ``calib_sequence_length`` when present.
        model_config: Model config providing ``max_position_embeddings``.
        tokenizer: Tokenizer used for all encoding; required.

    Returns:
        ``(calib, eval_sequences)``: two lists of token-ID tensors (not
        ``torch.utils.data.DataLoader`` objects).

    Raises:
        ValueError: If ``tokenizer`` is None.
    """
    # Every calibration and evaluation path tokenizes, so check up front.
    if tokenizer is None:
        raise ValueError("Tokenizer required for language backend.")

    dataset_name = config.calibration_dataset
    dataset_path = getattr(config, "dataset_path", None)
    seqlen = min(2048, model_config.max_position_embeddings)

    if constants.FINEWEB in dataset_name:
        calib = get_fineweb_edu(
            tokenizer=tokenizer,
            use_dataset_path=config.use_dataset_path,
            dataset_path=dataset_path,
            max_sequence_length=seqlen,
            num_calibration_samples=config.n_samples,
            seed=config.seed,
        )
    else:
        calib_seqlen = getattr(config, "calib_sequence_length", None)
        if calib_seqlen is not None:
            calib_seqlen = min(int(calib_seqlen), seqlen)
        calib = get_wikitext_for_calibration(
            tokenizer,
            nsamples=config.n_samples,
            seqlen=seqlen,
            use_dataset_path=config.use_dataset_path,
            dataset_path=dataset_path,
            dataset_name=dataset_name,
            seed=config.seed,
            calib_seqlen=calib_seqlen,
            align_to_bos=True,
        )

    # NOTE(review): evaluation reuses the calibration ``dataset_path``. That fits a
    # local WikiText copy (train + test splits) but presumably not a local
    # FineWeb copy - confirm the intended eval source for that configuration.
    eval_sequences = get_wikitext2(tokenizer, seqlen, config.use_dataset_path, dataset_path, constants.WIKITEXT)
    return calib, eval_sequences


# Public API: `from <module> import *` exports only the entry point; the
# loaders above remain importable by name.
__all__ = ["build_dataloaders"]