import random, json
from pathlib import Path
from datasets import load_dataset
from transformers import AutoTokenizer

# ---------------------------------------------------------------------------
# Configuration
# ---------------------------------------------------------------------------
# Hugging Face hub id: the tokenizer below is loaded from it, and it also
# names a sub-directory of the output path.
MODEL = "allenai/OLMo-2-0425-1B"
# Source corpus: dataset id, config name and split passed to load_dataset.
DATASET, CONFIG, SPLIT = "allenai/olmo-mix-1124", "wiki", "train"

SEQ_LEN = 32            # default window length (in tokens) for sampled snippets
TRAIN_SAMPLES = 32000   # target number of windows for the train file
EVAL_SAMPLES = 200      # target number of windows for the eval file
SEED = 42

OUT_DIR = Path("data/forget_samples", MODEL, "from_wiki_olmo")
TRAIN_FILE = OUT_DIR.joinpath("sampled_texts.jsonl")
EVAL_FILE = OUT_DIR.joinpath("sampled_texts_for_evaluate_loss.jsonl")

# ---------------------------------------------------------------------------
# Global setup (side effects: seeds the global RNG, creates OUT_DIR, and
# loads the tokenizer and the streamed dataset)
# ---------------------------------------------------------------------------
random.seed(SEED)
OUT_DIR.mkdir(exist_ok=True, parents=True)

# NOTE(review): trust_remote_code=True lets the hub repo execute arbitrary
# code; recent transformers versions may support OLMo-2 natively — confirm and
# consider dropping the flag.
tokenizer = AutoTokenizer.from_pretrained(MODEL, trust_remote_code=True, use_fast=True)

# streaming=True yields an iterable dataset; nothing is downloaded up front.
ds = load_dataset(DATASET, CONFIG, streaming=True, split=SPLIT)

def reservoir_sample(stream, k, seed=42):
    """Return a uniform random sample of up to ``k`` items from ``stream``.

    Implements reservoir sampling (Algorithm R): the iterable is consumed in a
    single pass while holding at most ``k`` items in memory, and every item of
    the stream is equally likely to end up in the result.

    The reservoir is shuffled before it is returned. In raw slot order, a
    prefix of the reservoir is *not* a uniform sample. Slot ``j`` keeps the
    ``j``-th stream item unless it is later replaced. Items at stream positions
    ``m+1..k`` can therefore never appear in the first ``m`` slots. Callers
    such as ``make_samples`` stop after consuming only a prefix.

    Args:
        stream: any iterable; it is consumed completely.
        k: maximum number of items to keep (``<= 0`` yields ``[]``).
        seed: seed for a private ``random.Random``. The global ``random``
            state is left untouched.

    Returns:
        A list of ``min(k, number of items in stream)`` items in random order.
    """
    rng = random.Random(seed)
    reservoir = []
    for i, ex in enumerate(stream, start=1):
        if i <= k:
            reservoir.append(ex)
        else:
            # Item i replaces a random slot with probability k / i.
            j = rng.randint(1, i)
            if j <= k:
                reservoir[j - 1] = ex
    # Remove the slot-order bias so any prefix is itself a uniform sample.
    rng.shuffle(reservoir)
    return reservoir

def sample_windows_from_article(text, seq_len=SEQ_LEN, max_per_article=5):
    """Return up to ``max_per_article`` random ``seq_len``-token windows of ``text``.

    The article is tokenized with the module-level ``tokenizer`` without
    special tokens. Distinct start offsets are then drawn uniformly at random
    with the global ``random`` module. Windows with different offsets may still
    overlap.

    Args:
        text: article text.
        seq_len: window length in tokens.
        max_per_article: maximum number of windows to return (values below 0
            are treated as 0).

    Returns:
        A list of ``{"input_ids": ids, "text": snippet}`` dicts, where ``ids``
        is the window's token-id slice and ``snippet`` its decoded text. Note
        that re-encoding ``snippet`` may not reproduce ``ids`` exactly. Returns
        ``[]`` when the article has fewer than ``seq_len`` tokens.
    """
    ids = tokenizer.encode(text, add_special_tokens=False)
    n_starts = len(ids) - seq_len + 1
    if n_starts <= 0:
        return []
    # random.sample over a range picks k distinct offsets in O(k), instead of
    # building and shuffling a list of every possible start in pure Python.
    k = max(0, min(max_per_article, n_starts))
    out = []
    for s in random.sample(range(n_starts), k):
        win_ids = ids[s:s + seq_len]
        snippet = tokenizer.decode(win_ids, skip_special_tokens=True)
        out.append({"input_ids": win_ids, "text": snippet})
    return out

def make_samples(reservoir, n_samples, exclude_set=None):
    """Collect up to ``n_samples`` unique token windows from a sequence of articles.

    Articles are visited in order. Each non-blank article contributes the
    windows returned by ``sample_windows_from_article`` until enough samples
    have been gathered.

    Args:
        reservoir: iterable of dict-like examples whose ``"text"`` entry holds
            the article text. Missing or blank texts are skipped.
        n_samples: target number of windows; ``<= 0`` returns ``[]``.
        exclude_set: optional collection of window texts that must not be
            emitted, e.g. texts already used by another split. It is not
            modified.

    Returns:
        A list of window dicts (``"input_ids"``, ``"text"``). Their ``"text"``
        values are pairwise distinct and none of them is in ``exclude_set``.
        The list may be shorter than ``n_samples`` if the articles run out.
    """
    samples = []
    # Check before looping: the append-then-compare below would otherwise
    # return one sample when n_samples is 0.
    if n_samples <= 0:
        return samples
    excluded = exclude_set or set()
    seen = set()  # texts already emitted by this call (dedupe within a split)
    for ex in reservoir:
        text = ex.get("text")
        if not text or not text.strip():
            continue
        for w in sample_windows_from_article(text):
            snippet = w["text"]
            if snippet in excluded or snippet in seen:
                continue
            seen.add(snippet)
            samples.append(w)
            if len(samples) >= n_samples:
                return samples
    return samples

def _write_jsonl(path, records):
    """Write ``records`` to ``path`` as UTF-8 JSON Lines (one object per line)."""
    with path.open("w", encoding="utf-8") as f:
        for rec in records:
            f.write(json.dumps(rec, ensure_ascii=False) + "\n")


# Draw ONE article pool for both splits. This has two benefits:
#  - the streamed dataset is read once instead of twice;
#  - train and eval articles are disjoint by construction.
# Two independent reservoirs over the same stream can select the same article
# for both splits. Exact-text exclusion alone would then let near-identical
# windows (e.g. offsets one token apart) from that article leak into eval.
_pool_size = 10 * (TRAIN_SAMPLES + EVAL_SAMPLES)
article_pool = reservoir_sample(ds, _pool_size, seed=SEED)
# Shuffle with a private RNG so the split below is a random partition even if
# the pool is returned in slot/stream order. The global RNG is not touched.
random.Random(SEED).shuffle(article_pool)
# Give each split ~10 articles per requested sample. If the stream was shorter
# than requested, split what was drawn in the same proportion.
_n_train_articles = (
    len(article_pool) * TRAIN_SAMPLES // max(TRAIN_SAMPLES + EVAL_SAMPLES, 1)
)
reservoir_for_train = article_pool[:_n_train_articles]
reservoir_for_eval = article_pool[_n_train_articles:]

train_samples = make_samples(reservoir_for_train, TRAIN_SAMPLES)
_write_jsonl(TRAIN_FILE, train_samples)

# Also exclude exact window texts seen in train (identical passages can occur
# in different articles).
train_texts = {rec["text"] for rec in train_samples}
eval_samples = make_samples(reservoir_for_eval, EVAL_SAMPLES, exclude_set=train_texts)
_write_jsonl(EVAL_FILE, eval_samples)

for _split, _got, _wanted in (
    ("train", len(train_samples), TRAIN_SAMPLES),
    ("eval", len(eval_samples), EVAL_SAMPLES),
):
    if _got < _wanted:
        print(f"warning: {_split} produced only {_got} of {_wanted} requested samples")

print(f"train: {len(train_samples)} → {TRAIN_FILE}")
print(f"eval : {len(eval_samples)} → {EVAL_FILE}")