# autointerp_hf/data_utils.py

from __future__ import annotations

import itertools
import json
from typing import Iterable, Optional, Tuple, Union

import torch
from datasets import load_dataset
from tqdm import tqdm
from transformers import PreTrainedTokenizerBase


def _iter_texts_from_hf(
    dataset_name: str,
    split: str = "train",
    text_key: str = "text",
) -> Iterable[str]:
    """Yield raw text rows from a HuggingFace dataset in streaming mode.

    The field ``text_key`` is handled the same way as by the JSONL reader:
      - a string value is yielded as-is
      - a list value has each of its string items yielded, in order
        (non-string items are skipped)
    Rows that lack ``text_key`` or hold any other value type are skipped.
    """
    ds = load_dataset(dataset_name, split=split, streaming=True)
    for row in ds:
        if text_key not in row:
            continue
        val = row[text_key]
        if isinstance(val, str):
            yield val
        elif isinstance(val, list):
            # Flatten list-valued fields instead of silently dropping them.
            yield from (item for item in val if isinstance(item, str))


def _iter_texts_from_jsonl(
    jsonl_path: str,
    text_key: str = "text",
) -> Iterable[str]:
    """Yield raw text rows from a local JSONL file.

    Each line must be valid JSON. The field `text_key` can be:
      - a string
      - a list of strings
    """
    with open(jsonl_path, "r", encoding="utf-8") as f:
        for line in f:
            line = line.strip()
            if not line:
                continue
            try:
                obj = json.loads(line)
            except json.JSONDecodeError:
                continue
            if text_key not in obj:
                continue
            val = obj[text_key]
            if isinstance(val, str):
                yield val
            elif isinstance(val, list):
                for item in val:
                    if isinstance(item, str):
                        yield item


def _skip_first(text_iter: Iterable[str], n: int) -> Iterable[str]:
    """Skip the first n yielded items."""
    it = iter(text_iter)
    for _ in range(max(n, 0)):
        try:
            next(it)
        except StopIteration:
            return
    for x in it:
        yield x


def load_and_tokenize_data(
    dataset_name: str,
    context_length: int,
    total_tokens: int,
    tokenizer: PreTrainedTokenizerBase,
    device: Union[str, torch.device],
    *,
    split: str = "train",
    hf_text_key: str = "text",
    heldout_jsonl: Optional[str] = None,
    heldout_text_key: str = "text",
    skip_first_n_examples: int = 0,
    add_special_tokens: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Load text from HF streaming or a local JSONL file, tokenize, and chunk.

    Each text item is tokenized on its own, without truncation. It is then cut
    into consecutive, non-overlapping windows of exactly ``context_length``
    tokens. Windows never span two texts. The trailing remainder of each text
    that is shorter than ``context_length`` is discarded, so rows are never
    padded.

    Args:
        dataset_name: HuggingFace dataset name. Ignored when ``heldout_jsonl``
            is given.
        context_length: Number of tokens per returned row. Must be positive.
        total_tokens: Stop once at least this many tokens have been collected.
            The result may exceed it by up to ``context_length - 1`` tokens.
            It falls short if the text source runs out first.
        tokenizer: Called once per text with ``return_tensors="pt"`` and
            ``truncation=False``.
        device: Device the returned tensors are moved to.
        split: Split to stream from the HF dataset.
        hf_text_key: Row field holding the text for the HF source.
        heldout_jsonl: Optional path to a local JSONL file, used instead of HF.
        heldout_text_key: Row field holding the text for the JSONL source.
        skip_first_n_examples: Number of text items to skip before tokenizing.
        add_special_tokens: Forwarded to the tokenizer.

    Returns:
        input_ids: (N, context_length)
        attention_mask: (N, context_length), all ones, with the same dtype and
            device as ``input_ids``.

    Raises:
        ValueError: If ``context_length`` is not positive. A non-positive
            window size cannot make progress through a text.
        RuntimeError: If no full-length window could be produced.
    """
    if context_length <= 0:
        raise ValueError(f"context_length must be positive, got {context_length}")

    if heldout_jsonl is not None:
        base_iter = _iter_texts_from_jsonl(heldout_jsonl, text_key=heldout_text_key)
    else:
        base_iter = _iter_texts_from_hf(dataset_name, split=split, text_key=hf_text_key)

    text_iter = _skip_first(base_iter, skip_first_n_examples)

    segments = []
    tokens_collected = 0

    # The context manager closes the progress bar even if the text source or
    # the tokenizer raises part-way through.
    with tqdm(
        total=total_tokens, desc="Tokenizing", unit="tok", dynamic_ncols=True
    ) as pbar:
        for text in text_iter:
            enc = tokenizer(
                text,
                add_special_tokens=add_special_tokens,
                return_tensors="pt",
                truncation=False,
            )
            ids = enc["input_ids"][0]

            # Keep whole windows only; the partial tail of this text is dropped.
            n_windows = ids.numel() // context_length
            for w in range(n_windows):
                if tokens_collected >= total_tokens:
                    break
                start = w * context_length
                segments.append(ids[start:start + context_length])
                tokens_collected += context_length
                pbar.update(context_length)

            # Check here, not at the loop top, so no extra text is pulled from
            # the (possibly streaming) source once the budget is met.
            if tokens_collected >= total_tokens:
                break

    if not segments:
        raise RuntimeError(
            "No token chunks were created (empty dataset, too-large "
            "skip_first_n_examples, every text shorter than context_length, "
            "or non-positive total_tokens)."
        )

    input_ids = torch.stack(segments, dim=0).to(device)
    # Rows are never padded, so the mask is all ones. Building it once avoids
    # a separate mask tensor per window.
    attention_mask = torch.ones_like(input_ids)
    return input_ids, attention_mask


# Backward-compat alias: older code imported this function under the name
# `load_and_tokenize_dataset`. Both names bind the same function object.
load_and_tokenize_dataset = load_and_tokenize_data
