# -- coding:utf-8 --
import re
import json
import os
import numpy as np
import format_data as fd

# Module-level cache state read and updated by get_node_features():
# `previous_out_file` holds the path recorded for the cached JSON feature file
# and `cached_datas` holds the parsed contents, so that a repeated request for
# the same file can reuse the parsed data instead of reading it again.
previous_out_file = None
cached_datas = None


def block_dict_in_features_input(cfg_file, feature_input):
    """Load the block dictionary from ``all1.txt`` inside *feature_input*.

    Args:
        cfg_file: Accepted for signature compatibility; not used here.
        feature_input: Directory that contains the ``all1.txt`` input file.

    Returns:
        A ``(block_dict, file_name)`` tuple: the value produced by
        ``fd.get_block_dict_in_features_input`` for the joined path, and the
        fixed file name ``"all1.txt"``.
    """
    features_file = "all1.txt"
    return (
        fd.get_block_dict_in_features_input(
            os.path.join(feature_input, features_file)
        ),
        features_file,
    )


def get_node_features(input_file_name, graphs, feature_output):
    """Build one feature vector per entry of *graphs* from a JSON feature file.

    The JSON file is ``<input_file_name minus its last 4 characters>.json``
    inside *feature_output*.  It may be either a single JSON array or a
    sequence of JSON objects, one after another, separated by newlines; the
    latter form is stitched into an array before parsing.

    The parsed file is kept in the module-level ``cached_datas`` /
    ``previous_out_file`` pair, so consecutive calls for the same path do not
    read and parse the file again.

    Args:
        input_file_name: Name of the feature input file (e.g. ``all1.txt``);
            only used to derive the JSON file name.
        graphs: Iterable of 3-item entries.  Only the first item is used, as
            an index into the parsed records.
        feature_output: Directory that contains the JSON feature file.

    Returns:
        A list with one vector (list of numbers) per entry of *graphs*.  For a
        record with more than three features, the vector is the element-wise
        mean of the first layer's ``values`` over all features except the
        first and the last, rounded to 6 decimals.  Otherwise it is the
        first layer's ``values`` of the record's second feature, unchanged.

    Raises:
        OSError: If the JSON feature file cannot be opened.
        json.JSONDecodeError: If the file content is not valid JSON after
            stitching.
        IndexError / KeyError: If a record does not have the expected layout
            (for instance fewer than two features).
    """
    global previous_out_file, cached_datas

    output_file_name = input_file_name[:-4] + ".json"
    output_file_path = os.path.join(feature_output, output_file_name)

    if previous_out_file == output_file_path and cached_datas is not None:
        datas = cached_datas
        print("555 the datas is exist")
    else:
        with open(output_file_path, "r", encoding="utf-8") as f:
            # Strip surrounding whitespace so a trailing newline does not hide
            # an existing closing bracket from the endswith() check below.
            content = f.read().strip()
        # Consecutive top-level objects are separated only by a newline; add
        # the missing commas so the whole file parses as one JSON array.
        content = content.replace("]}]}]}\n{", "]}]}]},\n{")
        if not content.startswith('['):
            content = '[' + content
        if not content.endswith(']'):
            content = content + ']'
        datas = json.loads(content)

        if cached_datas is not None:
            print("previous_out_file is ", previous_out_file)
        # Always record the path together with the data; otherwise the very
        # first load could never be recognised as a cache hit.
        cached_datas = datas
        previous_out_file = output_file_path
        print("666 the datas is saved")

    node_features = []
    for block_index, _color, _target in graphs:
        token_features = datas[block_index]['features']
        if len(token_features) > 3:
            # Average the inner features, leaving out the first and last one
            # (presumably special boundary tokens -- TODO confirm).
            inner_values = np.array(
                [item['layers'][0]['values'] for item in token_features[1:-1]]
            )
            # tolist() yields plain Python floats, which serialise to JSON
            # without relying on NumPy scalar types.
            vector = np.round(np.mean(inner_values, axis=0), 6).tolist()
        else:
            vector = token_features[1]['layers'][0]['values']
        node_features.append(vector)

    return node_features


def get_features(cfg_file, feature_input, feature_output):
    """Assemble the graph record for one CFG file.

    Steps:
      1. Read the CFG block dictionary and the node relations from *cfg_file*
         via ``fd.block_dict_in_cfg`` (the relations look like
         ``[['576', 2, '57f'], ['1b6d', 1, '1ba9']]`` according to the
         original author's note).
      2. Read the block dictionary of the feature input file.
      3. Keep only the CFG blocks whose value also appears as a key in the
         feature input dictionary, mapping CFG key -> feature input entry.
      4. Build the graph with ``fd.get_graph`` and fetch its node features.

    Args:
        cfg_file: Path of the CFG text file.
        feature_input: Directory holding the feature input file.
        feature_output: Directory holding the JSON feature output file.

    Returns:
        A dict with the keys ``targets`` (always ``"0"``), ``graph``,
        ``contract_name`` (base name of *cfg_file* with its last four
        characters replaced by ``.sol``) and ``node_features``.
    """
    cfg_block_dict, node_ralations = fd.block_dict_in_cfg(cfg_file)
    # Invert the CFG mapping so its values can be looked up directly.
    cfg_key_by_value = {v: k for k, v in cfg_block_dict.items()}

    all_block_dict, input_file_name = block_dict_in_features_input(cfg_file, feature_input)

    matched_blocks = {
        cfg_key_by_value[shared]: all_block_dict[shared]
        for shared in cfg_key_by_value
        if shared in all_block_dict
    }

    graphs = fd.get_graph(matched_blocks, node_ralations)
    vectors = get_node_features(input_file_name, graphs, feature_output)

    return {
        'targets': "0",
        'graph': graphs,
        'contract_name': os.path.basename(cfg_file)[:-4] + '.sol',
        'node_features': vectors,
    }

if __name__ == '__main__':
    # Script entry point: walk the CFG folder tree, extract the features of
    # every file found in sub-directories whose name contains 'access', and
    # write all collected records to `write_file` as one JSON array.
    cfg_txt_folder = r'E:\2024\experiment_code_clone\total4\binary_cfg_code'
    write_file = r'E:\2024\experiment_code_clone\GraphFeatureExtractor-main\GraphFeatureExtractor-main\data-cfg\access_control\all.json'
    feature_input = r'E:\2024\experiment_code_clone\total4\BertPretrainFinetune-main\feature\input\access_control'
    feature_output = r'E:\2024\experiment_code_clone\total4\BertPretrainFinetune-main\feature\output\access_control'

    # Collected across the whole walk.  Dumping once per visited directory
    # would append several arrays (mostly empty ones) back to back, which is
    # not a valid JSON document.
    all_features = []
    for dirpath, dirnames, _filenames in os.walk(cfg_txt_folder):
        for dirname in dirnames:
            if 'access' not in dirname:
                continue
            cfg_folder = os.path.join(dirpath, dirname)
            for filename in os.listdir(cfg_folder):
                cfg_file = os.path.join(cfg_folder, filename)
                features = get_features(cfg_file, feature_input, feature_output)
                num_edges = len(features['graph'])
                num_vectors = len(features['node_features'])

                if not num_edges and not num_vectors:
                    print('features contract name is', features['contract_name'])
                    continue

                if num_edges and num_vectors:
                    # Records whose graph and feature counts disagree are
                    # skipped, as before.
                    if num_edges != num_vectors:
                        continue
                elif num_vectors:
                    # No graph: pad with one all-zero triple per feature
                    # vector (a flat list of zeros would not match the shape
                    # of the graph entries).
                    features['graph'].extend(
                        [[0, 0, 0] for _ in range(num_vectors)]
                    )
                else:
                    # No features: pad with one 256-dimensional zero vector
                    # per graph entry.
                    features['node_features'].extend(
                        [[0.0] * 256 for _ in range(num_edges)]
                    )
                # NOTE(review): padded records are now kept; previously the
                # padding was computed and then discarded -- confirm intent.
                all_features.append(features)

    # NOTE(review): append mode is kept so existing content is not lost, but
    # running the script twice still yields two arrays in the same file.
    # ensure_ascii=False may emit non-ASCII text, so the encoding is explicit.
    with open(write_file, 'a', encoding='utf-8') as f:
        json.dump(all_features, f, ensure_ascii=False)