import os
import pandas as pd
import json
from datasets import load_dataset
from tqdm import tqdm

# Skeleton for every emitted train/eval/test JSON file. The script swaps in a
# per-file "Instances" list and writes the remaining fields unchanged.
template = {
    "Definition": [""],
    "Positive Examples": [],
    "Negative Examples": [],
    "Instances": [],
}

# Load BeaverTails once and take both splits from the same DatasetDict,
# instead of resolving/loading the dataset a second time for the test split.
beavertails = load_dataset("PKU-Alignment/BeaverTails")
train_and_eval_data = beavertails["330k_train"]
test_data = beavertails["330k_test"]

# Manually selected categories (CL — presumably continual learning; confirm).
# Alternative category set, kept for reference:
# categorys_for_CL = ["self_harm", "drug_abuse,weapons,banned_substance", "financial_crime,property_crime,theft", "animal_abuse", "violence,aiding_and_abetting,incitement"]
categorys_for_CL = [
    "terrorism,organized_crime",
    "misinformation_regarding_ethics,laws_and_safety",
]


# Per-category sample counters and the collected train/eval pool.
category_count = {name: 0 for name in categorys_for_CL}
raw_instances = {}

# Train and eval datasets: keep up to 1250 samples per category; the pool is
# split 80/20 into train/eval afterwards.
TRAIN_EVAL_PER_CATEGORY = 1250
for sample in tqdm(train_and_eval_data):
    # Keys of sample["category"] whose value is True.
    flagged = [name for name, value in sample["category"].items() if value]
    # Keep only samples with exactly one flagged category that belongs to the
    # selected set, until that category's quota is reached.
    if len(flagged) != 1:
        continue
    label = flagged[0]
    if label not in categorys_for_CL or category_count[label] >= TRAIN_EVAL_PER_CATEGORY:
        continue
    category_count[label] += 1
    # setdefault both creates the list and stores the sample. The previous
    # if/else created the list on the first hit but dropped that sample, so the
    # stored pool was always one short of category_count.
    raw_instances.setdefault(label, []).append(
        {"input": sample["prompt"], "output": sample["response"]}
    )

print("train and eval sample: ", category_count)

# Split each category's pool 80/20 into train/eval and write both files under
# train/CL_1k/<category>/.
# NOTE(review): the split follows dataset order with no shuffling; shuffle with
# a fixed seed first if the source order is not already random.
train_instances = {name: [] for name in categorys_for_CL}
eval_instances = {name: [] for name in categorys_for_CL}

for category, instances in raw_instances.items():
    n_train = int(len(instances) * 0.8)
    train_instances[category].extend(instances[:n_train])
    eval_instances[category].extend(instances[n_train:])

    # Each file reuses the template's fields and only substitutes its own
    # "Instances" list (key order is preserved, so the JSON layout is unchanged).
    train_template = {**template, "Instances": train_instances[category]}
    eval_template = {**template, "Instances": eval_instances[category]}

    out_dir = f"train/CL_1k/{category}"
    # exist_ok=True replaces the racy os.path.exists check + makedirs pair.
    os.makedirs(out_dir, exist_ok=True)

    with open(f"{out_dir}/train.json", "w", encoding="utf-8") as f:
        json.dump(train_template, f)
    with open(f"{out_dir}/eval.json", "w", encoding="utf-8") as f:
        json.dump(eval_template, f)

# Reset the counters and the pool before sampling the test split.
category_count = {name: 0 for name in categorys_for_CL}
raw_instances = {}

# Test dataset: keep up to 100 single-category prompts per category; the
# "output" field is written as an empty string.
TEST_PER_CATEGORY = 100
for sample in tqdm(test_data):
    # Keys of sample["category"] whose value is True.
    flagged = [name for name, value in sample["category"].items() if value]
    if len(flagged) != 1:
        continue
    label = flagged[0]
    if label not in categorys_for_CL or category_count[label] >= TEST_PER_CATEGORY:
        continue
    category_count[label] += 1
    # setdefault stores the first sample too; the previous if/else dropped it,
    # leaving the pool one short of category_count.
    raw_instances.setdefault(label, []).append({"input": sample["prompt"], "output": ''})

print("test sample: ", category_count)

# Write each category's sampled test prompts to train/CL_1k/<category>/test.json.
for category, instances in raw_instances.items():
    # Reuse the template's fields, substituting only this file's "Instances".
    test_template = {**template, "Instances": instances}

    out_dir = f"train/CL_1k/{category}"
    # exist_ok=True replaces the racy os.path.exists check + makedirs pair.
    os.makedirs(out_dir, exist_ok=True)

    with open(f"{out_dir}/test.json", "w", encoding="utf-8") as f:
        json.dump(test_template, f)
