import json
import os
import random

from datasets import load_dataset

# Build two question-only JSON files from the SQuAD dataset:
#   data/clean_data.json - the first CLEAN_SIZE questions after shuffling
#   data/aug_data.json   - every remaining question

# Number of shuffled questions that go into the "clean" file; the rest go
# into the "aug" file.
CLEAN_SIZE = 50000

ds = load_dataset("rajpurkar/squad")

# Pool the question text of every record, 'train' split first, then
# 'validation', preserving the order in which the dataset yields them.
questions = [
    d['question']
    for split in ['train', 'validation']
    for d in ds[split]
]

# Seed the global RNG so the shuffle (and therefore the split below) is
# repeatable as long as the pooled list arrives in the same order.
random.seed(4234)
random.shuffle(questions)

clean_data = questions[:CLEAN_SIZE]
aug_data = questions[CLEAN_SIZE:]

# open(..., 'w') does not create missing parent directories, so make sure the
# output directory exists instead of failing with FileNotFoundError after the
# dataset has already been loaded and shuffled.
os.makedirs('data', exist_ok=True)

with open('data/clean_data.json', 'w') as f:
    json.dump(clean_data, f)

with open('data/aug_data.json', 'w') as f:
    json.dump(aug_data, f)