from datasets import load_dataset
import os

# Directory that receives the train/validation/test .jsonl files.
output_dir = "./data/pubmedqa/split/"
os.makedirs(output_dir, exist_ok=True)

# Fixed seed so both splits are reproducible across runs.
SPLIT_SEED = 42
# Fraction of the original 'train' split held out for validation + test.
HOLDOUT_FRACTION = 0.3
# Fraction of the held-out portion that becomes the test set; the remainder
# becomes the validation set (0.3 * 0.5 = 15% each, 70% left for training).
TEST_FRACTION_OF_HOLDOUT = 0.5

# Load the local copy of PubMedQA; only its 'train' split is re-split below.
dataset = load_dataset("./data/pubmedqa/")

# First split: carve the held-out portion off the original 'train' split.
train_testvalid = dataset['train'].train_test_split(
    test_size=HOLDOUT_FRACTION, seed=SPLIT_SEED
)

# Second split: divide the held-out portion into validation ('train' side)
# and test ('test' side).
test_valid = train_testvalid['test'].train_test_split(
    test_size=TEST_FRACTION_OF_HOLDOUT, seed=SPLIT_SEED
)

final_dataset = {
    'train': train_testvalid['train'],
    'validation': test_valid['train'],
    'test': test_valid['test'],
}

# Console labels (Chinese): "train set size", "validation set size",
# "test set size". Kept verbatim so the script's output is unchanged.
split_labels = {
    'train': "训练集大小",
    'validation': "验证集大小",
    'test': "测试集大小",
}
for split_name, label in split_labels.items():
    print(f"{label}: {len(final_dataset[split_name])}")

# output_dir already ends with '/', so the previous f"{output_dir}/..." built
# paths like 'split//train.jsonl'; os.path.join avoids the doubled separator.
# Dict order is preserved, so files are written train -> validation -> test.
for split_name, split in final_dataset.items():
    split.to_json(os.path.join(output_dir, f"{split_name}.jsonl"))
