import random, json
from pathlib import Path
from datasets import load_dataset
from transformers import AutoTokenizer

# --- Hugging Face identifiers -------------------------------------------
# MODEL selects the tokenizer checkpoint and is also a component of OUT_DIR.
MODEL = "EleutherAI/pythia-1b"
DATASET = "EleutherAI/the_pile_deduplicated"
NAME = "default"
SPLIT = "train"

# --- Sampling parameters ------------------------------------------------
SEQ_LEN = 32           # tokens per sampled window
TRAIN_SAMPLES = 32000  # target number of windows for TRAIN_FILE
EVAL_SAMPLES = 200     # target number of windows for EVAL_FILE
SEED = 42

# --- Output locations ---------------------------------------------------
# MODEL contains a "/", so it contributes two nested directory levels.
OUT_DIR = Path("data", "forget_samples", MODEL, "from_pile_all")
TRAIN_FILE = OUT_DIR.joinpath("sampled_texts.jsonl")
EVAL_FILE = OUT_DIR.joinpath("sampled_texts_for_evaluate_loss.jsonl")

# --- One-time setup -----------------------------------------------------
OUT_DIR.mkdir(parents=True, exist_ok=True)
random.seed(SEED)

tokenizer = AutoTokenizer.from_pretrained(MODEL, use_fast=True)

# streaming=True: examples are iterated lazily rather than loaded up front.
ds = load_dataset(DATASET, NAME, split=SPLIT, streaming=True)

def reservoir_sample(stream, k, seed=42, log_every=100_000):
    """Draw a uniform random sample of up to ``k`` items from ``stream``.

    Single-pass reservoir sampling: the first ``k`` items fill the
    reservoir; after that, the item at 1-based position ``total`` overwrites
    a uniformly chosen slot with probability ``k / total``.

    Args:
        stream: Any iterable; it is consumed exactly once, to exhaustion.
        k: Maximum number of items to keep. ``k <= 0`` yields an empty list.
        seed: Seed for the private random generator used for replacement
            decisions. The same seed and stream give the same result.
        log_every: Print a progress line every ``log_every`` items. A falsy
            value (``0`` or ``None``) disables progress logging.

    Returns:
        A list holding ``min(k, n)`` items, where ``n`` is the stream length.
    """
    # Use a private generator: seeding the module-level RNG here would
    # silently reset the random state of every other user of ``random``.
    rng = random.Random(seed)
    reservoir = []
    total = 0  # stays 0 when the stream is empty, so the final print works
    for total, ex in enumerate(stream, start=1):
        if total <= k:
            reservoir.append(ex)
        else:
            # randint is inclusive on both ends, so j <= k has probability
            # k / total.
            j = rng.randint(1, total)
            if j <= k:
                reservoir[j - 1] = ex
        if log_every and total % log_every == 0:
            print(f"[reservoir] processed {total:,} examples")
    print(f"[reservoir] done. total seen = {total:,}")
    return reservoir

def sample_windows_from_article(text, seq_len=SEQ_LEN, max_per_article=5):
    """Pick up to ``max_per_article`` random ``seq_len``-token windows from ``text``.

    ``text`` is encoded with the module-level ``tokenizer`` (no special
    tokens added). Start offsets are shuffled with the module-level
    ``random`` state and the first ``max_per_article`` are used, so windows
    from the same article may overlap.

    Returns:
        A list of dicts with keys ``"input_ids"`` (the window's token ids)
        and ``"text"`` (those ids decoded back to a string). The list is
        empty when the article has fewer than ``seq_len`` tokens.
    """
    token_ids = tokenizer.encode(text, add_special_tokens=False)
    n_tokens = len(token_ids)
    if n_tokens < seq_len:
        return []

    # Every offset at which a full window still fits inside the article.
    offsets = list(range(n_tokens - seq_len + 1))
    random.shuffle(offsets)

    chunks = [token_ids[begin:begin + seq_len] for begin in offsets[:max_per_article]]
    return [
        {
            "input_ids": chunk,
            "text": tokenizer.decode(chunk, skip_special_tokens=True),
        }
        for chunk in chunks
    ]

def make_samples(reservoir, n_samples, exclude_set=None, log_every=5_000):
    """Collect token windows from ``reservoir`` articles until ``n_samples`` is reached.

    Args:
        reservoir: Iterable of mapping-like examples; each is read via
            ``.get("text")``. Examples with missing or blank text are skipped.
        n_samples: Stop and return as soon as this many windows are collected.
        exclude_set: Optional collection of decoded window texts to reject.
            ``None`` or an empty collection excludes nothing.
        log_every: Print progress whenever the 1-based article index is a
            multiple of this value (checked only for non-blank articles).

    Returns:
        The list of collected window dicts. It may be shorter than
        ``n_samples`` if the reservoir runs out first.
    """
    blocked = exclude_set if exclude_set else set()
    samples = []
    for idx, article in enumerate(reservoir, start=1):
        body = article.get("text")
        if not (body and body.strip()):
            continue

        for window in sample_windows_from_article(body):
            if window["text"] not in blocked:
                samples.append(window)
                if len(samples) >= n_samples:
                    print(f"[make_samples] done. took {idx} reservoir articles.")
                    return samples

        if idx % log_every == 0:
            print(f"[make_samples] scanned {idx:,} reservoir articles / collected {len(samples):,}")

    print(f"[make_samples] finished reservoir. collected {len(samples):,}")
    return samples

def _dump_jsonl(path, records):
    """Write ``records`` to ``path`` as UTF-8, one JSON object per line."""
    with path.open("w", encoding="utf-8") as sink:
        for record in records:
            sink.write(json.dumps(record, ensure_ascii=False) + "\n")


# Train split: keep a reservoir of 10x as many articles as windows wanted,
# then cut windows from it until TRAIN_SAMPLES are collected.
print("=== Reservoir sampling for train ===")
reservoir_for_train = reservoir_sample(ds, 10 * TRAIN_SAMPLES, seed=SEED, log_every=100_000)
train_samples = make_samples(reservoir_for_train, TRAIN_SAMPLES)
_dump_jsonl(TRAIN_FILE, train_samples)

# Eval split: a second pass over a freshly opened stream with a different
# seed; windows whose decoded text already appears in train are rejected.
print("=== Reservoir sampling for eval ===")
ds_eval = load_dataset(DATASET, NAME, split=SPLIT, streaming=True)
reservoir_for_eval = reservoir_sample(ds_eval, 10 * EVAL_SAMPLES, seed=SEED + 1, log_every=100_000)
train_texts = {rec["text"] for rec in train_samples}
eval_samples = make_samples(reservoir_for_eval, EVAL_SAMPLES, exclude_set=train_texts)
_dump_jsonl(EVAL_FILE, eval_samples)

print(f"train: {len(train_samples)} → {TRAIN_FILE}")
print(f"eval : {len(eval_samples)} → {EVAL_FILE}")