
from __future__ import annotations

import functools
import json
from dataclasses import dataclass
from typing import List, Dict, Optional, Iterable, Tuple

import torch
from torch.utils.data import Dataset, DataLoader, RandomSampler

# Sentinel end-of-sequence token string. load_vocab appends it to a vocab built
# by scanning data files when the data never contains it; build_dataloader's
# pad_id message indicates its id is also meant to double as the padding id.
EOS_TOKEN = "[EOS]"


def load_vocab(vocab_path: Optional[str], scan_paths: Optional[List[str]] = None) -> Dict[str, int]:
    """
    Build a token -> id mapping.

    If ``vocab_path`` is given, it must be a JSON file of the form
    ``{"vocab": [...]}``; each token's id is its position in that list.
    Otherwise every JSONL file in ``scan_paths`` is scanned and tokens are
    numbered in first-seen order (stable for the same files), with EOS_TOKEN
    appended at the end if it never appeared.

    Args:
        vocab_path: Path to a saved vocab JSON, or None/"" to scan data files.
        scan_paths: JSONL files whose records carry a "tokens" list; only
            consulted when ``vocab_path`` is falsy. Blank lines are skipped.

    Returns:
        Dict mapping each token string to its integer id.

    Raises:
        ValueError: if neither ``vocab_path`` nor ``scan_paths`` is provided.
    """
    if vocab_path:
        with open(vocab_path, "r", encoding="utf-8") as f:
            vocab: List[str] = json.load(f)["vocab"]
    else:
        # Build vocab by scanning data files (rarely needed if you saved vocab from datagen).
        # Explicit raise rather than assert so the check survives `python -O`.
        if not scan_paths:
            raise ValueError("Either vocab_path or scan_paths must be provided.")
        # A dict preserves insertion order, so it dedups in O(1) per token while
        # keeping the stable first-seen ordering a list-membership test would give.
        seen: Dict[str, None] = {}
        for p in scan_paths:
            with open(p, "r", encoding="utf-8") as f:
                for line in f:
                    if not line.strip():
                        continue  # tolerate blank / trailing lines in JSONL
                    rec = json.loads(line)
                    for t in rec["tokens"]:
                        seen.setdefault(t, None)
        vocab = list(seen)
        if EOS_TOKEN not in seen:
            vocab.append(EOS_TOKEN)

    # If a token is listed more than once, its last position wins.
    return {t: i for i, t in enumerate(vocab)}


class JsonlSequenceDataset(Dataset):
    """
    Map-style dataset over a JSONL file of token sequences.

    Each non-blank line must be a JSON object with a "tokens" list of token
    strings, e.g. {"tokens": ["1", "0", "[EOS]"], ...}; other keys are ignored.
    The whole file is read eagerly in __init__ and every example is kept as a
    list[int] of token ids. Sequences are NOT padded here; padding happens in
    the collate function.
    """

    def __init__(self, path: str, tok2id: Dict[str, int], max_seq_len: int):
        self.path = path
        self.tok2id = tok2id
        self.max_seq_len = max_seq_len
        self.examples: List[List[int]] = []
        self._load()

    def _load(self) -> None:
        """Read ``self.path`` and append one id list per record to ``self.examples``.

        Raises:
            KeyError: if a record contains a token missing from ``self.tok2id``
                (message includes the file, line number and token).
        """
        with open(self.path, "r", encoding="utf-8") as f:
            for lineno, line in enumerate(f, start=1):
                if not line.strip():
                    continue  # tolerate blank / trailing lines in JSONL
                rec = json.loads(line)
                toks: List[str] = rec["tokens"]
                try:
                    ids = [self.tok2id[t] for t in toks]
                except KeyError as err:
                    # Same exception type as a bare lookup, plus enough context to find the record.
                    raise KeyError(
                        f"{self.path}:{lineno}: token {err.args[0]!r} is not in the vocab"
                    ) from err
                # Safety clip in case upstream produced a too-long example.
                # NOTE(review): a sequence of exactly max_seq_len tokens is clipped too
                # (to max_seq_len - 1, dropping its last token) to "keep room for at
                # least one token to predict"; confirm this is intended.
                if len(ids) >= self.max_seq_len:
                    ids = ids[: self.max_seq_len - 1]
                self.examples.append(ids)

    def __len__(self) -> int:
        return len(self.examples)

    def __getitem__(self, idx: int) -> List[int]:
        return self.examples[idx]


def _pad_batch(
    sequences: List[List[int]],
    pad_id: int,
    max_len: Optional[int] = None
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Right-pad variable-length int sequences into a single [B, T] tensor.

    Args:
        sequences: Token-id lists, one per example.
        pad_id: Value written into the padded positions of ``input_ids``.
        max_len: Fixed width T. Longer sequences are truncated to T. If None,
            T is the longest sequence length (0 for an empty batch).

    Returns:
        (input_ids, attention_mask), both torch.long of shape [B, T]; the mask
        is 1 on real tokens and 0 on padding.
    """
    B = len(sequences)
    # default=0 so an empty batch yields [0, 0] tensors instead of max() raising.
    T = max_len if max_len is not None else max((len(s) for s in sequences), default=0)
    input_ids = torch.full((B, T), pad_id, dtype=torch.long)
    attention_mask = torch.zeros((B, T), dtype=torch.long)
    for i, seq in enumerate(sequences):
        L = min(len(seq), T)
        input_ids[i, :L] = torch.tensor(seq[:L], dtype=torch.long)
        attention_mask[i, :L] = 1
    return input_ids, attention_mask


def collate_causal_lm(batch: List[List[int]], pad_id: int, max_len: Optional[int] = None):
    """
    Turn a list of token-id sequences into a GPT-style causal-LM batch.

    Padding is delegated to ``_pad_batch``. ``labels`` is a copy of
    ``input_ids`` in which every padded position (attention_mask == 0) is set
    to -100 so the loss ignores it. No shift is applied here; GPT2LMHeadModel
    handles the internal shift. Masking by the attention mask rather than by
    ``pad_id`` keeps real tokens that share the pad id (e.g. EOS used as pad)
    as training targets.

    Returns:
        dict with "input_ids", "attention_mask" and "labels" tensors.
    """
    ids, mask = _pad_batch(batch, pad_id=pad_id, max_len=max_len)
    is_pad = mask.eq(0)
    targets = ids.masked_fill(is_pad, -100)
    return dict(input_ids=ids, attention_mask=mask, labels=targets)


def build_dataloader(
    dataset: JsonlSequenceDataset,
    batch_size: int,
    max_steps: Optional[int] = None,
    replacement: bool = True,
    num_workers: int = 0,
    pad_id: Optional[int] = None,
    max_len: Optional[int] = None,
) -> DataLoader:
    """
    Build a randomly-sampled DataLoader that yields causal-LM batches.

    Sampling uses ``RandomSampler(replacement=replacement)``. When
    ``replacement`` is True and ``max_steps`` is truthy, the sampler draws
    ``max_steps * batch_size`` samples, so one pass over the loader yields
    ``max_steps`` batches; otherwise ``num_samples`` is left to the sampler's
    default. Incomplete final batches are dropped so every batch has exactly
    ``batch_size`` rows.

    Args:
        dataset: Dataset of token-id lists.
        batch_size: Examples per batch.
        max_steps: Number of batches to draw (only used with replacement).
        replacement: Whether to sample with replacement.
        num_workers: Number of DataLoader worker processes.
        pad_id: Padding id (required; the intended choice is the EOS id).
        max_len: Optional fixed padded width forwarded to collate_causal_lm.

    Raises:
        ValueError: if ``pad_id`` is None.
    """
    # Explicit raise rather than assert so the check survives `python -O`.
    if pad_id is None:
        raise ValueError("pad_id must be provided (use eos as pad).")
    sampler = RandomSampler(
        dataset,
        replacement=replacement,
        num_samples=(max_steps * batch_size) if (replacement and max_steps) else None,
    )
    return DataLoader(
        dataset,
        batch_size=batch_size,
        sampler=sampler,
        num_workers=num_workers,
        # Pinned host memory only benefits copies to a CUDA device.
        pin_memory=torch.cuda.is_available(),
        drop_last=True,  # stable shapes
        # A partial of a module-level function is picklable (a lambda is not),
        # so worker processes started with "spawn" can receive it.
        collate_fn=functools.partial(collate_causal_lm, pad_id=pad_id, max_len=max_len),
    )
