import pandas as pd
import json
from datasets import load_dataset
from tqdm import tqdm

# Skeleton shared by every split file written by this script: "Definition"
# holds a single empty string, both example lists are empty, and "Instances"
# is replaced with the split's samples before the file is dumped.
template = {
    "Definition": [""],
    "Positive Examples": [],
    "Negative Examples": [],
    "Instances": [],
}

# Load the "330k_train" split of the PKU-Alignment/BeaverTails dataset.
inputs = load_dataset("PKU-Alignment/BeaverTails")["330k_train"]

# Upper bound on the number of samples kept per category. With the 80/20
# split applied later in this script, 125 samples give 100 train + 25 eval.
PER_CATEGORY_LIMIT = 125

# Category names are the keys of the per-row "category" dict. Index the row
# first so only one row is read instead of the whole "category" column.
category_count = {name: 0 for name in inputs[0]["category"]}
raw_instances = {}

# Keep up to PER_CATEGORY_LIMIT samples per category, considering only samples
# that are flagged with exactly one category.
for sample in tqdm(inputs):
    # Names of the categories whose flag is True for this sample.
    flagged = [name for name, value in sample["category"].items() if value]
    if len(flagged) != 1:
        continue
    name = flagged[0]
    if category_count[name] >= PER_CATEGORY_LIMIT:
        continue
    category_count[name] += 1
    # setdefault both creates the per-category list on first sight and returns
    # it, so the first sample of a category is stored like every later one
    # (previously it was counted but never appended).
    raw_instances.setdefault(name, []).append(
        {"input": sample["prompt"], "output": sample["response"]}
    )

print(category_count)
# Split the dataset: for each category the first 80% of the collected samples
# (rounded down) go to train and the remainder goes to eval.
train_instances, eval_instances = [], []
for samples in raw_instances.values():
    cut = int(len(samples) * 0.8)
    train_instances += samples[:cut]
    eval_instances += samples[cut:]

# Shallow copies of the template with "Instances" rebound to the split's data.
train_template = {**template, "Instances": train_instances}
eval_template = {**template, "Instances": eval_instances}

# Write train first, then eval.
for path, payload in (
    ("train/BeaverTails_100_14/train.json", train_template),
    ("train/BeaverTails_100_14/eval.json", eval_template),
):
    with open(path, "w") as f:
        json.dump(payload, f)

# Load the existing test set from test/Safety/BeaverTails_100_14/test.json and
# re-dump it next to the generated train/eval files.
# The encoding is given explicitly so the file is decoded as UTF-8 regardless
# of the platform's default locale encoding.
with open("test/Safety/BeaverTails_100_14/test.json", "r", encoding="utf-8") as f:
    test_template = json.load(f)

with open("train/BeaverTails_100_14/test.json", "w", encoding="utf-8") as f:
    json.dump(test_template, f)