import json
import torch

# Load the result file. The context manager closes the handle as soon as
# parsing finishes instead of leaving it to the garbage collector.
with open("./data/all-2.json") as result_fp:
    resultFile = json.load(result_fp)

# One bucket per label; entries whose label is not 0-3 are dropped below.
label_data = {0: [], 1: [], 2: [], 3: []}

# Group the entries by label. The label is the integer that follows the
# first "|||" separator in each key.
for key, values in resultFile.items():
    label = int(key.split("|||")[1])
    if label in label_data:
        label_data[label].append((key, values))

# Seed torch's global RNG so that later random permutations are reproducible.
torch.manual_seed(0)


def split_data(data_list, train_ratio=0.8, val_ratio=0.1):
    """Shuffle ``data_list`` and split it into train/validation/test parts.

    The shuffle order comes from ``torch.randperm``, so it follows torch's
    global RNG state (seed it with ``torch.manual_seed`` for reproducible
    splits). The input list itself is not modified.

    Args:
        data_list: Sequence of items to split.
        train_ratio: Fraction of items assigned to the training part.
        val_ratio: Fraction of items assigned to the validation part.

    Returns:
        A ``(train_items, val_items, test_items)`` tuple of lists. The train
        and validation sizes are rounded down; every remaining item goes to
        the test part, so no item is lost.
    """
    total_size = len(data_list)
    train_end = int(total_size * train_ratio)
    val_end = train_end + int(total_size * val_ratio)

    # Build a shuffled copy rather than shuffling the caller's list in place.
    order = torch.randperm(total_size).tolist()
    shuffled = [data_list[i] for i in order]

    # The test part takes whatever is left after the two rounded-down parts.
    return shuffled[:train_end], shuffled[train_end:val_end], shuffled[val_end:]


# Split each label's entries separately and pool the parts, so every label
# is divided with the same ratios.
train_data, val_data, test_data = [], [], []

for label in label_data:
    train_items, val_items, test_items = split_data(label_data[label])
    train_data.extend(train_items)
    val_data.extend(val_items)
    test_data.extend(test_items)

# Turn the (key, values) pairs back into dictionaries.
trainFileDict = dict(train_data)
valFileDict = dict(val_data)
testFileDict = dict(test_data)

# Report the size of each subset.
print("数据集已分割完毕")
print("The size of the training file is: {}".format(len(trainFileDict)))
print("The size of the validation file is: {}".format(len(valFileDict)))
print("The size of the test file is: {}".format(len(testFileDict)))

# Save each subset to its own JSON file. The context manager makes sure
# every file is flushed and closed before the script moves on.
for out_path, subset in (
    ("./data/subsetTrainFiles.json", trainFileDict),
    ("./data/subsetValFiles.json", valFileDict),
    ("./data/subsetTestFiles.json", testFileDict),
):
    with open(out_path, 'w') as out_fp:
        json.dump(subset, out_fp)

