"""
Data utilities: WikiText-2 loader.
"""

from __future__ import annotations

import collections

from typing import Dict, Iterable, List, Tuple

import torch
from torch.utils.data import Dataset





# --- WikiText-2 utilities (adapted from wikitext_data.py) ---


def _require_datasets():
    try:
        from datasets import load_dataset  # type: ignore
    except ImportError as e:  # pragma: no cover
        raise ImportError(
            "The 'datasets' package is required for WikiText-2. Install with `pip install datasets`."
        ) from e
    return load_dataset


def simple_tokenize(text: str) -> List[str]:
    """Split *text* into whitespace-delimited tokens.

    Leading and trailing whitespace is removed first, then the remainder is
    split on runs of whitespace. An empty or whitespace-only string yields
    an empty list.
    """
    trimmed = text.strip()
    tokens = trimmed.split()
    return tokens


def build_vocab(texts: Iterable[str], max_vocab: int = 20000, min_freq: int = 1) -> Dict[str, int]:
    """Build a token -> id vocabulary from whitespace-tokenized texts.

    Id 0 is always reserved for the ``"<unk>"`` entry. The remaining ids
    ``1, 2, ...`` are assigned to the most frequent tokens in descending
    frequency order, so the returned ids are exactly ``0 .. len(vocab) - 1``.

    Args:
        texts: iterable of strings; each is tokenized with ``simple_tokenize``.
        max_vocab: maximum size of the returned vocabulary, *including* the
            ``"<unk>"`` entry. Values below 1 yield a vocabulary holding only
            ``"<unk>"``.
        min_freq: tokens occurring fewer than this many times are dropped.

    Returns:
        Mapping from token string to integer id.
    """
    unk_token = "<unk>"

    counter: collections.Counter = collections.Counter()
    for txt in texts:
        counter.update(simple_tokenize(txt))

    # A literal "<unk>" in the corpus must not be assigned a second id: it
    # would overwrite the reserved entry, leave id 0 unused, and push the
    # largest id up to len(vocab), which is out of range for anything sized
    # by len(vocab).
    counter.pop(unk_token, None)

    # Clamp the budget: with max_vocab <= 0 a raw ``max_vocab - 1`` would be
    # a negative slice bound and drop entries from the tail instead of
    # keeping none.
    budget = max(0, max_vocab - 1)
    kept = [tok for tok, freq in counter.most_common() if freq >= min_freq][:budget]

    vocab = {unk_token: 0}
    for i, tok in enumerate(kept, start=1):
        vocab[tok] = i
    return vocab


def encode(text: str, vocab: Dict[str, int]) -> List[int]:
    """Convert *text* into a list of token ids using *vocab*.

    The text is tokenized with ``simple_tokenize``; any token missing from
    *vocab* is mapped to the id stored under ``"<unk>"``.

    Raises:
        KeyError: if *vocab* has no ``"<unk>"`` entry.
    """
    fallback_id = vocab["<unk>"]
    token_ids: List[int] = []
    for token in simple_tokenize(text):
        token_ids.append(vocab.get(token, fallback_id))
    return token_ids


class WikiText2SeqDataset(Dataset):
    """
    Sliding-window next-token dataset over a flat list of token ids.

    Item ``i`` is ``(input_ids, target_id)``: ``input_ids`` is a 1-D long
    tensor holding ``ids[i : i + seq_len]`` and ``target_id`` is the int
    ``ids[i + seq_len]``, i.e. the token right after the input window.
    """

    def __init__(self, ids: List[int], seq_len: int, max_samples: int | None = None):
        """
        Args:
            ids: flat list of token ids; stored by reference, not copied.
            seq_len: number of tokens in each input window.
            max_samples: optional upper bound on the dataset length.

        Raises:
            ValueError: if ``seq_len`` or ``max_samples`` is negative.
        """
        if seq_len < 0:
            raise ValueError(f"seq_len must be non-negative, got {seq_len}")
        if max_samples is not None and max_samples < 0:
            raise ValueError(f"max_samples must be non-negative, got {max_samples}")
        self.ids = ids
        self.seq_len = seq_len
        self.max_samples = max_samples

    def __len__(self) -> int:
        """Return the number of samples, capped by ``max_samples`` if set."""
        # NOTE: the extra ``- 1`` leaves the last possible window (start index
        # ``len(ids) - seq_len - 1``) unused. It is kept as-is so that
        # existing dataset sizes do not change.
        length = max(0, len(self.ids) - self.seq_len - 1)
        if self.max_samples is not None:
            length = min(length, self.max_samples)
        return length

    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, int]:
        """Return the ``(input_ids, target_id)`` pair for sample ``idx``.

        Negative indices count from the end, as for a list.

        Raises:
            IndexError: if ``idx`` is outside ``[-len(self), len(self))``.
        """
        size = len(self)
        if idx < 0:
            idx += size
        # Enforce the advertised length. Without this check, plain iteration
        # (which stops on IndexError) would run past ``max_samples``, and a
        # negative index would silently produce a wrong or truncated window.
        if not 0 <= idx < size:
            raise IndexError(f"index out of range for dataset of length {size}")
        x = self.ids[idx : idx + self.seq_len]
        y = self.ids[idx + self.seq_len]
        return torch.tensor(x, dtype=torch.long), int(y)


def load_wikitext2(
    seq_len: int = 64,
    max_vocab: int = 20000,
    train_max_samples: int | None = None,
    val_max_samples: int | None = None,
) -> Tuple[WikiText2SeqDataset, WikiText2SeqDataset, int]:
    """Load WikiText-2 and wrap its train/validation splits as window datasets.

    The "wikitext" / "wikitext-2-raw-v1" dataset is loaded through
    ``datasets.load_dataset``. The vocabulary is built from the training
    split only and is then used to encode both splits; each split's lines
    are encoded and concatenated into one flat id list.

    Args:
        seq_len: input window length passed to each dataset.
        max_vocab: vocabulary size limit passed to ``build_vocab``.
        train_max_samples: optional sample cap for the training dataset.
        val_max_samples: optional sample cap for the validation dataset.

    Returns:
        ``(train_dataset, val_dataset, vocab_size)`` where ``vocab_size`` is
        the number of entries in the vocabulary.

    Raises:
        ImportError: if the 'datasets' package is not available.
    """
    dataset_loader = _require_datasets()
    splits = dataset_loader("wikitext", "wikitext-2-raw-v1")
    train_lines = splits["train"]["text"]
    val_lines = splits["validation"]["text"]

    # Only the training split contributes to the vocabulary.
    token_to_id = build_vocab(train_lines, max_vocab=max_vocab)

    def flatten(lines) -> List[int]:
        # Encode every line and concatenate the results into one id stream.
        return [tok_id for line in lines for tok_id in encode(line, token_to_id)]

    train_stream = flatten(train_lines)
    val_stream = flatten(val_lines)

    return (
        WikiText2SeqDataset(train_stream, seq_len=seq_len, max_samples=train_max_samples),
        WikiText2SeqDataset(val_stream, seq_len=seq_len, max_samples=val_max_samples),
        len(token_to_id),
    )
