from datasets import load_dataset, Dataset
import pandas as pd

# Load only the first 50 examples of the "test" split from the "wmdp-bio"
# configuration of the cais/wmdp dataset.
ds_bio = load_dataset(
    "cais/wmdp",
    "wmdp-bio",
    split="test[:50]",
)

def process_samples(dataset, category):
    """Turn multiple-choice examples into flat sample records.

    Args:
        dataset: Iterable of mappings, each providing "question" (the prompt),
            "choices" (an indexable collection of answer options) and
            "answer" (the index of the correct option within "choices").
        category: Label stored on every record; its lower-cased form is also
            used as the prefix of each record's "id".

    Returns:
        A list with one dict per example, in input order, with the keys
        "id", "category", "input", "target" and "choices". "id" is the
        lower-cased category followed by "_" and the example's position
        zero-padded to at least three digits.
    """
    records = []
    for position, row in enumerate(dataset):
        sample_id = f"{category.lower()}_{position:03}"
        question = row["question"]
        options = row["choices"]
        # "answer" holds an index; resolve it to the option it points at.
        correct_option = options[row["answer"]]
        records.append(
            {
                "id": sample_id,
                "category": category,
                "input": question,
                "target": correct_option,
                "choices": options,
            }
        )
    return records

# Flatten the biology examples into records and tabulate them, one row per
# sample.
bio_samples = process_samples(ds_bio, "WMDP-Bio")
df = pd.DataFrame(bio_samples)

# Define the file name once so the confirmation message below always matches
# the file that is actually written (the message previously named a
# different file than the one saved).
output_filename = "wmdp_50_samples.csv"
output_path = f"llm_unlearning/wmdp/data/{output_filename}"

# index=False keeps the DataFrame's row index out of the CSV.
df.to_csv(output_path, index=False)
print(f"✅ Saved as {output_filename}")