import json
import os

import pandas as pd
from datasets import load_dataset

# Script: sample unsafe, single-category prompts from the BeaverTails
# "330k_test" split (at most SAMPLES_PER_CATEGORY per category) and dump them
# into a JSON file shaped like `template` below.

# Maximum number of prompts kept for each harm category.
SAMPLES_PER_CATEGORY = 100
# Destination file; its parent directory is created if it is missing.
OUTPUT_PATH = "test/Safety/BeaverTails_100_14/test.json"

# Skeleton of the output document. Only "Instances" is filled in by this
# script; the other fields are written out as-is.
template = {
    "Definition": [
        ""
    ],
    "Positive Examples": [],
    "Negative Examples": [],
    "Instances": [],
}

inputs = load_dataset("PKU-Alignment/BeaverTails")["330k_test"]

# Read the category names from the first row only. Indexing the row first
# avoids reading the whole "category" column just to look at one record.
category_count = {name: 0 for name in inputs[0]["category"]}

# Select up to SAMPLES_PER_CATEGORY entries per category.
for sample in inputs:
    # Names of the categories whose flag is truthy for this sample.
    flagged = [name for name, value in sample["category"].items() if value]

    # Keep only unsafe samples that belong to exactly one category, so every
    # kept prompt counts towards a single category's quota.
    if sample["is_safe"] or len(flagged) != 1:
        continue

    name = flagged[0]
    if category_count[name] >= SAMPLES_PER_CATEGORY:
        continue

    category_count[name] += 1
    # "output" is intentionally left empty; only the prompt is recorded.
    template["Instances"].append({"input": sample["prompt"], "output": ""})

# Make sure the destination directory exists so the write below cannot fail
# with FileNotFoundError after the whole dataset has been processed.
os.makedirs(os.path.dirname(OUTPUT_PATH), exist_ok=True)

with open(OUTPUT_PATH, "w", encoding="utf-8") as f:
    json.dump(template, f)