import json
import random
# Seed the global `random` generator with a fixed value so that anything drawn
# from it afterwards is reproducible from one run to the next.
random.seed(42)

def create_dataset(total, ratios=None, source_path="data/all-500k.json"):
    """Draw a random subset of the source corpus with a fixed category mix.

    The source file must contain a JSON array of objects, each carrying a
    "dataset" key. Entries are bucketed by that key and every bucket is
    sampled without replacement using the global ``random`` generator.

    Args:
        total: Overall sample budget. Each category contributes
            ``int(total * ratio)`` entries, so the result may hold fewer than
            ``total`` entries when the products are not whole numbers.
        ratios: Mapping of category name to the fraction of ``total`` taken
            from it. Categories are sampled in the mapping's iteration order.
            Defaults to sampling exclusively from "code-evol".
        source_path: Path of the JSON file to read.

    Returns:
        A list with the sampled entries, concatenated category by category
        in the order given by ``ratios``.

    Raises:
        KeyError: If an entry has no "dataset" key.
        ValueError: If a category is asked for a negative number of entries
            or for more entries than it contains.
    """
    if ratios is None:
        # Default mix: everything comes from "code-evol". The zero-ratio
        # categories are listed so the sampling order stays stable.
        ratios = {"slimorca": 0, "code-evol": 1, "meta-math": 0}

    # JSON is UTF-8 by specification; do not rely on the platform default.
    with open(source_path, "r", encoding="utf-8") as file:
        data = json.load(file)

    # Bucket entries by their "dataset" tag. Entries whose tag is not one of
    # the requested categories are deliberately skipped.
    buckets = {name: [] for name in ratios}
    for entry in data:
        tag = entry["dataset"]
        if isinstance(tag, str) and tag in buckets:
            buckets[tag].append(entry)

    combined_data = []
    for name, ratio in ratios.items():
        count = int(total * ratio)
        available = len(buckets[name])
        if not 0 <= count <= available:
            # random.sample would raise ValueError as well; naming the
            # category makes the failure actionable.
            raise ValueError(
                f"cannot sample {count} entries from category {name!r}: "
                f"only {available} available"
            )
        combined_data.extend(random.sample(buckets[name], count))

    return combined_data

# Example usage: draw the sample and store it as pretty-printed JSON.
total_samples = 500  # overall number of entries to sample
new_dataset = create_dataset(total_samples)
# ensure_ascii=False writes non-ASCII characters verbatim, so the file is
# opened as UTF-8 explicitly instead of using the platform default encoding.
with open("data/router.json", "w", encoding="utf-8") as f:
    json.dump(new_dataset, f, indent=4, ensure_ascii=False)