import json
from datasets import load_dataset

# Input JSON file and the reduced copy that is derived from it below.
RAW_KNOWLEDGE_PATH = '/llm_unlearning/wmdp/data/fictional_knowledge.json'
PROCESSED_KNOWLEDGE_PATH = '/llm_unlearning/wmdp/data/processed_fictional_knowledge.json'

# Parse the whole input file into memory.
with open(RAW_KNOWLEDGE_PATH, 'r', encoding='utf-8') as raw_file:
    data = json.load(raw_file)

# Keep only the "train_context" field of each element; every other field is
# dropped. An element lacking that key raises KeyError.
selected_data = [{"train_context": record["train_context"]} for record in data]

# Write the reduced list back out; ensure_ascii=False keeps non-ASCII
# characters as-is instead of emitting \u escapes.
with open(PROCESSED_KNOWLEDGE_PATH, 'w', encoding='utf-8') as out_file:
    json.dump(selected_data, out_file, ensure_ascii=False, indent=4)

# Load the file that was just written through the 'json' builder, requesting
# the 'train' split.
dataset = load_dataset(
    'json',
    data_files=PROCESSED_KNOWLEDGE_PATH,
    split='train',
)