# -- coding:utf-8 --
import re
import json
import os

# Single-entry cache shared with parse_output: the path of the last cached
# feature file and the records parsed from it (both None until first use).
previous_out_file = cached_datas = None


def add_node(reindex_graph):
    """Insert a placeholder edge for every target that never appears as a source.

    Each edge is a ``[source, weight, target]`` triple.  Walking the original
    edges in order, whenever an edge's target is not (yet) present in the
    source column, a placeholder edge ``[target, 0, 0]`` is inserted directly
    after that edge and the target is recorded, so each missing node gets at
    most one placeholder.  Placeholders themselves are not re-examined.

    NOTE(review): the placeholder's weight and target are both 0; why target 0
    was chosen is not visible here -- confirm with the graph consumer.

    Args:
        reindex_graph: list of ``[source, weight, target]`` edges.  It is
            updated in place.

    Returns:
        The same list object, now including the placeholder edges.
    """
    # Set lookup is O(1); checking a list made the whole pass quadratic.
    seen_sources = {edge[0] for edge in reindex_graph}
    expanded = []
    for edge in reindex_graph:
        _source, _weight, target = edge
        expanded.append(edge)
        if target not in seen_sources:
            expanded.append([target, 0, 0])
            seen_sources.add(target)
    # Build once, then write back, instead of repeated O(n) list.insert calls;
    # slice assignment keeps the in-place contract for callers holding the list.
    reindex_graph[:] = expanded
    return reindex_graph


def insert_graph(graph):
    """Make every edge bidirectional by inserting missing reverse edges.

    For each original ``[source, num, target]`` edge whose opposite direction
    ``(target, source)`` is absent (the weight is ignored for this check), a
    reverse edge ``[target, num, source]`` carrying the SAME weight is inserted
    directly after it.  Inserted reverse edges count as present for later
    edges, so each direction is added at most once; they are not re-examined.

    Args:
        graph: list of ``[source, num, target]`` edges, updated in place.

    Returns:
        The same list object, now including the inserted reverse edges.
    """
    # Directed (source, target) pairs present so far, for O(1) lookup.
    existing_pairs = {(elem[0], elem[2]) for elem in graph}
    result = []
    for edge in graph:
        source, num, target = edge
        result.append(edge)
        if (target, source) not in existing_pairs:
            result.append([target, num, source])
            existing_pairs.add((target, source))
    # Single rebuild instead of repeated O(n) list.insert; write back in place.
    graph[:] = result
    return graph


def reindex_graph_data(graph):
    """Relabel node ids with consecutive integers starting at 0.

    Every id seen in the source (index 0) or target (index 2) position is
    ranked in ascending order and replaced by its rank; the edge type at
    index 1 is copied unchanged.

    Args:
        graph: iterable of edges indexable as ``edge[0]``, ``edge[1]``,
            ``edge[2]``; node ids must be mutually comparable.

    Returns:
        A new list of ``[new_source, edge_type, new_target]`` lists.
    """
    unique_nodes = set()
    for edge in graph:
        unique_nodes.update((edge[0], edge[2]))
    rank_of = {node: rank for rank, node in enumerate(sorted(unique_nodes))}
    return [[rank_of[edge[0]], edge[1], rank_of[edge[2]]] for edge in graph]


def parse_cfg_blocks(cfg_file):
    """Parse a CFG dump into a list of index-based, colour-weighted edges.

    Block definition lines (matching ``block_<name> [ ... \\l"]``) are numbered
    in order of appearance starting at 0.  Edge lines
    (``block_<a> -> block_<b> ... color=<c>``) become
    ``[index_of_a, weight, index_of_b]`` where the weight comes from the edge
    colour: blue=1, green=2, red=3, cyan=4, yellow=5, and 9 for any other
    colour or when the line carries no ``color=`` attribute at all.

    Blocks must be defined on earlier lines than the edges that use them.

    Args:
        cfg_file: path to the CFG text file (read as UTF-8).

    Returns:
        List of ``[source_index, weight, target_index]`` edges.

    Raises:
        KeyError: if an edge refers to a block not defined on an earlier line.
    """
    weight_color = {'blue': 1, 'green': 2, 'red': 3, 'cyan': 4, 'yellow': 5}
    block_re = re.compile(r"block_(\w+)\s*\[.+\\l\"]")
    edge_re = re.compile(r"block_(\w+)\s*->\s*block_(\w+)")
    color_re = re.compile(r"color=([a-zA-Z]+)")

    block_dict = {}
    connection_node = []
    index_node = 0
    with open(cfg_file, "r", encoding="utf-8") as f:
        for raw_line in f:
            line = raw_line.strip()
            block = block_re.search(line)
            connection = edge_re.search(line)
            if block:
                # Every definition line consumes an index, even a repeated name.
                block_dict[block.group(1)] = index_node
                index_node += 1
            if connection:
                color = color_re.search(line)
                # A missing colour used to crash on None.group(); treat it like
                # an unknown colour instead.
                weight = weight_color.get(color.group(1), 9) if color else 9
                connection_node.append(
                    [block_dict[connection.group(1)], weight, block_dict[connection.group(2)]])
    return connection_node


import numpy as np


def _decode_concatenated_json(content):
    """Decode a string of back-to-back JSON values (e.g. JSON Lines) into a list."""
    decoder = json.JSONDecoder()
    values = []
    pos = 0
    length = len(content)
    while True:
        # raw_decode does not skip leading whitespace, so skip it here.
        while pos < length and content[pos].isspace():
            pos += 1
        if pos >= length:
            return values
        value, pos = decoder.raw_decode(content, pos)
        values.append(value)


def _load_output_datas(out_file):
    """Read one feature-extractor output file and return its list of records."""
    parent_folder = os.path.basename(os.path.dirname(out_file))
    if parent_folder in ('reentrancy', 'wild-clean') and out_file.endswith('all.json'):
        # These all.json files hold concatenated JSON objects, not one array.
        with open(out_file, 'r', encoding='utf-8') as file:
            datas = _decode_concatenated_json(file.read())
        print("2 datas len is ", len(datas))
    else:
        with open(out_file, "r", encoding="utf-8") as f:
            datas = json.load(f)
        print("2.1 datas len is ", len(datas))
    return datas


def _mean_node_features(datas, graphs):
    """Average token vectors of each edge's source record, rounded to 6 places."""
    node_features = []
    for src_node, _weight, _tgt_node in graphs:
        features = datas[src_node]['features']
        if len(features) > 3:
            # Drop the first and last entries (presumably special tokens such
            # as [CLS]/[SEP] -- TODO confirm against the extractor).
            token_vectors = [item['layers'][0]['values'] for item in features[1:-1]]
        else:
            token_vectors = [features[1]['layers'][0]['values']]
        mean = np.mean(np.array(token_vectors), axis=0)
        node_features.append([np.round(value, 6) for value in mean])
    return node_features


def parse_output(out_file, graphs):
    """Return one averaged feature vector per edge in *graphs*.

    Records are loaded from *out_file*.  Each record must have a
    ``'features'`` list whose items carry ``['layers'][0]['values']``.  For
    every ``[src_node, weight, tgt_node]`` edge, the record at index
    ``src_node`` is reduced to the mean of its token vectors (all but the
    first and last when there are more than three, otherwise only the second),
    rounded to 6 decimals.

    ``all.json`` files are cached in the module globals ``previous_out_file``
    and ``cached_datas`` so repeated calls for the same path skip re-reading it.
    The cache is keyed on the exact path, so a different file never receives
    another file's records.

    NOTE(review): one vector is produced per EDGE (keyed by its source node),
    not per distinct node -- confirm this is what the downstream model expects.

    Args:
        out_file: path to the feature JSON file.
        graphs: iterable of ``[src_node, weight, tgt_node]`` edges.

    Returns:
        List of feature vectors (lists of numpy floats), in edge order.
    """
    global previous_out_file, cached_datas

    if out_file == previous_out_file and cached_datas is not None:
        datas = cached_datas
        print("666 we already has datas")
    else:
        datas = _load_output_datas(out_file)
        # Only the large shared all.json is worth keeping; per-contract files
        # are read once each and must not evict it.
        if os.path.basename(out_file) == 'all.json':
            previous_out_file = out_file
            cached_datas = datas
            print("777 we already reserve datas")

    return _mean_node_features(datas, graphs)


def get_block_dict_in_features_input(input_path):
    """Map each whitespace-stripped line of *input_path* to its 0-based line number.

    If the same stripped text occurs on several lines, the last occurrence's
    index is kept.

    Args:
        input_path: path to a UTF-8 text file with one entry per line.

    Returns:
        Dict from stripped line text to line index.
    """
    with open(input_path, "r", encoding="utf-8") as handle:
        lines = handle.readlines()
    return {text.strip(): line_no for line_no, text in enumerate(lines)}


def _cfg_block_entry(line):
    """Return ``(block_name, normalized_instructions)`` for a block line, or None if unlabeled."""
    name_match = re.search(r"block_(\w+)\s*\[", line)
    label_match = re.search(r'label="(.*)\\l', line)
    if name_match is None or label_match is None:
        return None
    features = label_match.group(1)
    if features.endswith('\\l'):
        features = features[:-2]
    # Drop the leading "<name>:" prefix; keep the whole label if it has none.
    _, sep, rest = features.partition(':')
    formatted_content = (rest if sep else features).strip()
    formatted_content = re.sub(r'\d+:\\l', '', formatted_content)
    formatted_content = re.sub(r'\\l', ' ', formatted_content)
    formatted_content = re.sub(r'.: ', ' ', formatted_content)
    formatted_content = re.sub(r'(?<=\s)\d(?=\s)', '', formatted_content)
    formatted_content = re.sub(r'\s+', ' ', formatted_content)
    return name_match.group(1), formatted_content


def block_dict_in_cfg(file_path):
    """Extract block instruction text and coloured edges from a CFG dump.

    Lines starting with ``block_`` that contain no ``->`` are block definitions:
    their ``label="...\\l`` text is normalised to a single-space-separated
    instruction string, e.g. ``{'0': 'PUSH1 0x80 PUSH1 0x40 MSTORE ...'}``.
    Definition lines without a label are skipped.
    Lines containing ``block_<a> -> block_<b>`` are edges, returned as
    ``[a, weight, b]`` with block names as strings, e.g. ``['576', 3, '57f']``.

    Edge weights come from the colour: blue=1, green=2, red=3, cyan=4,
    yellow=5, and 9 for any other colour or no ``color=`` attribute.  This is
    the same 1-based mapping as parse_cfg_blocks; a 0-based mapping would make
    ``blue`` collide with the weight 0 that add_node/deal_graph use for
    placeholder and reverse edges.

    Args:
        file_path: path to the CFG text file (read as UTF-8).

    Returns:
        Tuple ``(cfg_block_dict, connection_node)``.
    """
    weight_color = {'blue': 1, 'green': 2, 'red': 3, 'cyan': 4, 'yellow': 5}
    cfg_block_dict = {}
    connection_node = []
    with open(file_path, "r", encoding="utf-8") as f:
        content = f.readlines()
    for line in content:
        if line.strip().startswith("block_") and "->" not in line:
            entry = _cfg_block_entry(line)
            if entry is not None:
                block_name, formatted_content = entry
                cfg_block_dict[block_name] = formatted_content
        connection_name = re.search(r"block_(\w+)\s*->\s*block_(\w+)", line)
        if connection_name:
            color = re.search(r"color=([a-zA-Z]+)", line)
            weight = weight_color.get(color.group(1), 9) if color else 9
            connection_node.append([connection_name.group(1), weight, connection_name.group(2)])
    return cfg_block_dict, connection_node


def get_graph(block_dict, connected_nodes):
    """Turn named CFG edges into a compact, index-based graph.

    Edges whose endpoints are both keys of *block_dict* are translated to
    ``[block_dict[a], weight, block_dict[b]]``; others are dropped.  The
    result is then relabelled to consecutive ids by reindex_graph_data and
    padded with placeholder edges by add_node.

    NOTE(review): after reindexing, node ids no longer equal the indices
    stored in *block_dict*, yet merge_data passes this graph to parse_output,
    which indexes the feature file by these ids -- confirm this is intended.

    Args:
        block_dict: mapping from block name to an integer index.
        connected_nodes: iterable of ``[name_a, weight, name_b]`` edges.

    Returns:
        List of ``[source, weight, target]`` edges with reindexed node ids.
    """
    kept_edges = [
        [block_dict[first], weight, block_dict[second]]
        for first, weight, second in connected_nodes
        if first in block_dict and second in block_dict
    ]
    return add_node(reindex_graph_data(kept_edges))


def deal_graph(graph):
    """Add a weight-0 reverse edge for every edge lacking an exact reverse.

    For each original ``[source, num, target]`` edge, if no edge
    ``(target, num, source)`` (same weight, opposite direction) is known, the
    edge ``[target, 0, source]`` is inserted directly after it.  Inserted
    edges are recorded so later lookups see them, but they are not
    themselves examined.

    Args:
        graph: list of ``[source, num, target]`` edges, updated in place.

    Returns:
        The same list object, now including the inserted reverse edges.
    """
    # The set must hold each edge in its own orientation. Storing the edges
    # already reversed made the membership test below always succeed, so
    # nothing was ever inserted.
    existing_edges = {(elem[0], elem[1], elem[2]) for elem in graph}
    result = []
    for edge in graph:
        source, num, target = edge
        result.append(edge)
        if (target, num, source) not in existing_edges:
            result.append([target, 0, source])
            existing_edges.add((target, 0, source))
    # Single rebuild instead of repeated O(n) list.insert; write back in place.
    graph[:] = result
    return graph


def merge_data(cfg_path, output_path, write_path, targets, all_block_dict):
    """Build one record per CFG file in *cfg_path* and append them to *write_path*.

    For each CFG file ``<stem>.<ext>``:
      * if ``output_path/<stem>.json`` exists, edges come from
        parse_cfg_blocks and node features from that JSON file;
      * otherwise blocks are matched by instruction text against
        *all_block_dict* (text -> index), edges come from get_graph, and
        node features from ``output_path/all.json``.
    A record ``{'targets', 'graph', 'contract_name', 'node_features'}`` is kept
    only when both the graph and the features are non-empty.  All kept records
    are appended to *write_path* as one JSON array; the file is opened only
    when there is something to write.

    NOTE(review): the file is opened in append mode, so repeated calls for the
    same *write_path* concatenate several JSON arrays into one file -- confirm
    the reader expects that.

    Args:
        cfg_path: directory containing the CFG dump files.
        output_path: directory containing the feature JSON files.
        write_path: destination file for the merged records.
        targets: label stored verbatim in every record.
        all_block_dict: mapping from normalised block text to feature index.
    """
    print("333 the write path is", write_path)
    all_res = []
    # Sorted for a deterministic record order (os.listdir order is arbitrary).
    for each_cfg_file in sorted(os.listdir(cfg_path)):
        stem = os.path.splitext(each_cfg_file)[0]
        cfg_file = os.path.join(cfg_path, each_cfg_file)
        out_file = os.path.join(output_path, stem + '.json')
        if os.path.exists(out_file):
            print("3.2 so cfg_file is connted with graphs: ", cfg_file)
            graphs = parse_cfg_blocks(cfg_file)
            print("3.3 out_file is conntected with node_features", out_file)
        else:
            print("3.5 file.json is not exist ,so cfg_file is connted with graphs: ", cfg_file)
            cfg_block_dict, connected_nodes = block_dict_in_cfg(cfg_file)
            # NOTE(review): inverting by text keeps only the LAST block name
            # per distinct instruction text, so edges of other blocks with
            # identical text are dropped by get_graph -- confirm intended.
            reversed_cfg_block_dict = {value: key for key, value in cfg_block_dict.items()}
            block_dict = {
                block_name: all_block_dict[content]
                for content, block_name in reversed_cfg_block_dict.items()
                if content in all_block_dict
            }
            graphs = get_graph(block_dict, connected_nodes)
            out_file = os.path.join(output_path, "all.json")
            print("3.6 out_file is conntected with node_features", out_file)
        node_features = parse_output(out_file, graphs)
        if graphs and node_features:
            all_res.append({
                'targets': targets,  # 针对access datas (for the access-control dataset)
                'graph': graphs,
                'contract_name': stem + '.sol',
                'node_features': node_features,
            })
    if all_res:
        print(len(all_res), "this is the len of all the cfg files")
        with open(write_path, 'a', encoding='utf-8') as f:
            json.dump(all_res, f, ensure_ascii=False)


if __name__ == '__main__':
    source_dir = '/root/ll_code/graphextractor6.30/binary_cfg_code/'
    output_dir = '/root/ll_code/graphextractor6.30/feature/output/'
    data_dir = '/root/ll_code/graphextractor6.30/GraphFeatureExtractor-main/data-cfg/'
    input_dir = '/root/ll_code/graphextractor6.30/feature/input/'

    # Non-reentrancy 'wild' categories: (keyword in dir name, output dir name).
    # Checked in order; the first keyword found wins.
    category_keywords = (
        ('access_control', 'access_control'),
        ('delegatecall', 'delegatecall'),
        ('wild-clean', 'wild-clean'),
        ('external', 'external_call'),
    )

    for dirpath, dirnames, _filenames in os.walk(source_dir):
        for dirname in dirnames:
            source_folder_dir = os.path.join(dirpath, dirname)
            print("0 dirname is", dirname)
            if 'wild' not in dirname:
                continue
            if 'reentrancy' in dirname:
                targets = "1"
                output_dirname = 'reentrancy'
            else:
                targets = "0"
                output_dirname = next(
                    (name for keyword, name in category_keywords if keyword in dirname), '')
            if not output_dirname:
                # An unknown category used to fall through with an empty name
                # and write into data_dir/all.json; skip it instead.
                print("skip: unrecognised dataset directory", dirname)
                continue

            cfg_path = source_folder_dir
            output_path = os.path.join(output_dir, output_dirname)
            write_path = os.path.join(data_dir, output_dirname, 'all.json')
            input_folder = os.path.join(input_dir, output_dirname)
            input_file = os.path.join(input_folder, 'all.txt')

            block_in_input_features = {}
            if os.path.exists(input_file):
                print("1.5 input_file is exist, input_folder is exist", input_folder)
                block_in_input_features = get_block_dict_in_features_input(input_file)

            merge_data(cfg_path, output_path, write_path, targets, block_in_input_features)
