import json
import os
import random
import sys

# Root directory that is searched recursively for dataset folders.
source_dir = r'E:\2024\experiment_code_clone\GraphFeatureExtractor-main\GraphFeatureExtractor-main\data-cfg'

# # Test code
# graph = [
#     [28, 3, 27],
#     [29, 2, 31],
#     [27, 1, 31],
#     [34, 4, 32]
# ]

# print(fd.reindex_graph_data(graph))

# A directory is processed when its name contains one of these fragments.
# To process a single category, shorten this list.
PROCESS_DIRNAMES = ['reentrancy', 'wild-clean', 'external_call', 'access_control', 'delegatecall']

# Share of the shuffled samples written to train.json; the rest goes to
# valid.json.
TRAIN_FRACTION = 0.6


def iter_matching_dirs(root, name_fragment):
    """Yield the path of every directory below *root* whose name contains *name_fragment*.

    The match is a plain substring test on the directory's own name (not on
    its full path). Directories at any depth are considered. A *root* that
    does not exist yields nothing.
    """
    for dirpath, dirnames, _filenames in os.walk(root):
        for dirname in dirnames:
            if name_fragment in dirname:
                yield os.path.join(dirpath, dirname)


def repair_json_text(content):
    """Return *content* with two known defects of ``all.json`` patched.

    * ``]]}][{`` becomes ``]]},{``: the point where one JSON array ends and
      the next one starts immediately is turned into an element separator,
      so back-to-back arrays parse as a single array. (Presumably the file
      was produced by appending several dumps - confirm with the producer.)
    * Every run of exactly ten ``[]`` is removed.
    """
    content = content.replace("]]}][{", "]]},{")
    return content.replace("[][][][][][][][][][]", "")


def normalize_all_json(folder):
    """Repair ``all.json`` in *folder*, rewrite it as valid JSON and return the data.

    The file is read and written as UTF-8 so that the first read agrees with
    the encoding used for every later write and read of the same file.
    Prints the number of parsed top-level items.

    Returns ``None`` (after printing a notice) when *folder* has no
    ``all.json``, so one incomplete folder does not abort the whole run.

    Raises:
        json.JSONDecodeError: if the text is still not valid JSON after the
            repairs; the file is left untouched in that case.
    """
    json_file = os.path.join(folder, 'all.json')
    if not os.path.isfile(json_file):
        print('skipping', folder, '- no all.json')
        return None

    with open(json_file, 'r', encoding='utf-8') as f:
        content = f.read()
    data = json.loads(repair_json_text(content))
    print(len(data))

    # Disabled experiment code, kept for reference.
    #
    # for item in data:
    #     graph_data = item['graph']
    #     targets = item['targets']
    #     node_features = item['node_features']
    #     str_target = str(targets)
    #     item['targets'] = str_target
    #     if len(graph_data) != len(node_features):
    #         print(len(graph_data), len(node_features))
    #         print(graph_data)
    #     if len(graph_data) > 200:
    #         print(len(graph_data))

    # # Split into groups of 4 elements
    # new_features = []
    # for node in node_features:
    #     # Reduce to 64 values
    #     grouped_features = [node[i:i + 4] for i in range(0, len(node), 4)]
    #
    #     # Compute the mean of each group
    #     averaged_features = [round(sum(group) / len(group), 6) for group in grouped_features]
    #
    #     # # Find the position of the maximum, similar to one-hot
    #     # max_index = averaged_features.index(max(averaged_features))
    #     # #
    #     # # # Create an all-zero list and set the position of the maximum to 1
    #     # binary_features = [0] * 64
    #     # binary_features[max_index] = 1
    #
    #     # Use a threshold to decide whether a value becomes 1
    #     # binary_features = [1 if x > 0.8 else 0 for x in averaged_features]
    #
    #     new_features.append(averaged_features)
    # item['node_features'] = new_features

    with open(json_file, 'w', encoding='utf-8') as f:
        json.dump(data, f)
    return data


def split_train_valid(folder, train_fraction=TRAIN_FRACTION, rng=None):
    """Shuffle ``all.json`` of *folder* and write ``train.json`` / ``valid.json``.

    Args:
        folder: Directory containing ``all.json``; both output files are
            written next to it and overwritten if they exist.
        train_fraction: Share of the items that goes to ``train.json``. The
            count is ``int(train_fraction * len(data))``, i.e. rounded down.
        rng: Optional ``random.Random`` instance used for shuffling, e.g. a
            seeded one for a reproducible split. Defaults to the module-level
            generator of ``random``.

    Returns:
        ``(train_count, valid_count)``, or ``None`` (after printing a notice)
        when *folder* has no ``all.json``.

    Prints the total number of items and the size of the training part.
    """
    json_file = os.path.join(folder, 'all.json')
    if not os.path.isfile(json_file):
        print('skipping', folder, '- no all.json')
        return None

    with open(json_file, 'r', encoding='utf-8') as f:
        data = json.load(f)
    print(len(data))

    # Shuffle in place, then cut once: head -> train, tail -> valid.
    shuffle = random.shuffle if rng is None else rng.shuffle
    shuffle(data)
    train_size = int(train_fraction * len(data))
    print(train_size)
    train_data = data[:train_size]
    valid_data = data[train_size:]

    with open(os.path.join(folder, 'train.json'), 'w', encoding='utf-8') as f:
        json.dump(train_data, f)
    with open(os.path.join(folder, 'valid.json'), 'w', encoding='utf-8') as f:
        json.dump(valid_data, f)
    return len(train_data), len(valid_data)


for process_dirname in PROCESS_DIRNAMES:
    # Pass 1: make every all.json of this category valid JSON.
    for source_folder_dir in iter_matching_dirs(source_dir, process_dirname):
        normalize_all_json(source_folder_dir)

    # Pass 2: write the data to the train and valid files.
    for source_folder_dir in iter_matching_dirs(source_dir, process_dirname):
        split_train_valid(source_folder_dir)
