from __future__ import annotations

from typing import Literal

import torch
from datasets import load_dataset


def get_wikitext2(tokenizer, *, split: Literal["train", "test"] = "test") -> torch.Tensor:
    """Load one WikiText-2 split and tokenize it into a single tensor of ids.

    The ``text`` column of the ``wikitext-2-raw-v1`` configuration is joined
    with blank lines into one document and passed to
    ``tokenizer.encode(text, bos=True, eos=True)``.

    Args:
        tokenizer: Object exposing ``encode(text, bos=..., eos=...)``.
            NOTE(review): assumed to return a flat sequence of ints, which
            is what makes the result 1D -- confirm against the tokenizer.
        split: Name of the dataset split to load.

    Returns:
        A ``torch.long`` tensor built from the encoded ids.

    Side effects:
        Sets ``HF_DATASETS_DISABLE_FILE_LOCKING=1`` in ``os.environ`` for the
        whole process and prints progress messages to stdout.
    """
    import os

    # Intended to turn off HuggingFace file locking so that parallel
    # processes do not contend on the same lock files.
    os.environ["HF_DATASETS_DISABLE_FILE_LOCKING"] = "1"

    print(f"[data] loading wikitext2 split={split}...", flush=True)
    dataset = load_dataset("Salesforce/wikitext", "wikitext-2-raw-v1", split=split)

    print("[data] loaded wikitext2, tokenizing...", flush=True)
    rows = dataset["text"]
    document = "\n\n".join(rows)
    token_ids = tokenizer.encode(document, bos=True, eos=True)

    print(f"[data] tokenized {len(token_ids)} tokens", flush=True)
    return torch.tensor(token_ids, dtype=torch.long)


def split_dataset(token_ids: torch.Tensor, seqlen: int) -> torch.Tensor:
    """Split a 1D token tensor into rows of ``seqlen`` tokens.

    Trailing tokens that do not fill a complete row are dropped. If the
    input holds fewer than ``seqlen`` tokens the result has shape
    ``(0, seqlen)``.

    Args:
        token_ids: 1D tensor of token ids.
        seqlen: Number of tokens per row; coerced with ``int()`` and
            required to be positive.

    Returns:
        A tensor of shape ``(nseq, seqlen)`` where
        ``nseq = len(token_ids) // seqlen``.

    Raises:
        ValueError: If ``token_ids`` is not 1D or ``seqlen`` is not positive.
    """
    if token_ids.ndim != 1:
        raise ValueError("split_dataset expects a 1D tensor")
    seqlen = int(seqlen)
    # Reject non-positive lengths up front: 0 would raise ZeroDivisionError
    # below and a negative value would produce a nonsensical reshape error.
    if seqlen <= 0:
        raise ValueError(f"seqlen must be a positive integer, got {seqlen}")
    nseq = token_ids.shape[0] // seqlen
    # Drop the incomplete tail so the remaining length divides evenly.
    return token_ids[: nseq * seqlen].reshape(nseq, seqlen)


def take_nseq(token_ids_2d: torch.Tensor, nsamples: int | None) -> torch.Tensor:
    """Return the first ``nsamples`` sequences of ``token_ids_2d``.

    Args:
        token_ids_2d: Tensor whose first dimension indexes sequences.
        nsamples: How many leading sequences to keep. ``None`` returns the
            input unchanged. Values larger than the number of available
            sequences return all of them.

    Returns:
        The input tensor itself when ``nsamples`` is ``None``, otherwise a
        slice of its first ``nsamples`` rows.

    Raises:
        ValueError: If ``nsamples`` is negative.
    """
    if nsamples is None:
        return token_ids_2d
    nsamples = int(nsamples)
    # A negative bound would slice from the end and silently drop trailing
    # sequences instead of taking the first ones, so reject it explicitly.
    if nsamples < 0:
        raise ValueError(f"nsamples must be non-negative, got {nsamples}")
    return token_ids_2d[:nsamples]
