import os
from collections import defaultdict

from datasets import load_dataset

# Build a small, level-balanced subset of the MATH benchmark: take the first
# SAMPLES_PER_LEVEL examples (in dataset order) for each difficulty level in
# LEVELS and write them out as JSON Lines.
SAMPLES_PER_LEVEL = 50
LEVELS = range(1, 6)
OUTPUT_PATH = "data/validation_dataset/test.jsonl"

dataset = load_dataset("nlile/hendrycks-MATH-benchmark", split="train")

# Maps difficulty level -> row indices of the first SAMPLES_PER_LEVEL
# examples seen at that level.
level_buckets = defaultdict(list)

# Levels that still need more samples; lets the scan stop as soon as every
# bucket is full instead of walking the rest of the split.
unfilled_levels = set(LEVELS)

# Only the "level" column is needed to choose rows, so iterate that column
# rather than materializing every full example dict.
for idx, raw_level in enumerate(dataset["level"]):
    level = int(raw_level)
    if level not in unfilled_levels:
        continue
    bucket = level_buckets[level]
    bucket.append(idx)
    if len(bucket) == SAMPLES_PER_LEVEL:
        unfilled_levels.discard(level)
        if not unfilled_levels:
            break

# Concatenate the buckets in ascending level order, failing loudly if any
# level is under-represented in the source split.
selected_indices = []
for level in LEVELS:
    selected = level_buckets[level]
    if len(selected) < SAMPLES_PER_LEVEL:
        raise ValueError(f"Level {level} only has {len(selected)} samples.")
    selected_indices.extend(selected)

subset = dataset.select(selected_indices)

# Make sure the target directory exists instead of relying on the writer to
# create it (behaviour may differ between `datasets` releases).
os.makedirs(os.path.dirname(OUTPUT_PATH), exist_ok=True)
subset.to_json(OUTPUT_PATH, lines=True, force_ascii=False)
print(f"Saved subset with {len(subset)} samples.")