import json
import os.path
import random
import numpy as np
from collections import Counter
from sklearn.metrics.pairwise import rbf_kernel, cosine_similarity
import torch


def get_dirname(file_num):
    """Map a contract file number to its vulnerability-category directory name.

    The numbering is split into contiguous, inclusive ranges:

    * ``<= 41``        -> ``"access_control"``
    * ``42 .. 298``    -> ``"delegatecall"``
    * ``299 .. 1060``  -> ``"reentrancy"``
    * ``1061 .. 2918`` -> ``"wild-clean"``
    * ``2919 .. 3336`` -> ``"external_call"``

    Any number that falls in none of these ranges yields an empty string.
    """
    category_ranges = (
        (float('-inf'), 41, "access_control"),
        (42, 298, "delegatecall"),
        (299, 1060, "reentrancy"),
        (1061, 2918, "wild-clean"),
        (2919, 3336, "external_call"),
    )
    for lower, upper, category in category_ranges:
        if lower <= file_num <= upper:
            return category
    return ''


#
# def read_data(file_path):
#     with open(file_path, 'r') as f:
#         data = json.load(f)
#     return data
#
#
# def plot_token_frequency(data):
#     all_tokens = []
#     for item in data:
#         all_tokens.extend(item['token'])
#
#     processed_tokens = []
#     for token in all_tokens:
#         if token == 'math':
#             processed_tokens.append('safe math')
#         elif token == 'operations':
#             processed_tokens.append('math operations')
#         elif token != 'title' and token != 'safe':
#             processed_tokens.append(token)
#
#     token_freq = Counter(processed_tokens)
#     return token_freq
#
#
# def calculate_attention_weights(data, token_freq):
#     token_freq['safe'] = token_freq['safe math']
#     token_freq['title'] = 1
#     token_freq['math'] = token_freq['safe math']
#     token_freq['operations'] = token_freq['math operations']
#
#     token_to_weight = {token: freq / sum(token_freq.values()) for token, freq in token_freq.items()}
#
#     attention_weights = []
#     for item in data:
#         token_weights = [token_to_weight[token] for token in item['token'] if token in token_to_weight]
#         attention_weights.append(token_weights)
#
#     return attention_weights
#
#
# def calculate_weighted_embedding(data, attention_weights):
#     weighted_embeddings = []
#     for item, weights in zip(data, attention_weights):
#         token_embeddings = np.array(item['comments_feature'])
#         if len(weights) == 0:
#             weighted_embedding = np.zeros((512,), dtype=np.float32)
#         else:
#             weighted_embedding = np.dot(weights, token_embeddings)
#         weighted_embeddings.append(weighted_embedding)
#
#     return weighted_embeddings
#
#
# def preprocess_item(item, attention_weights):
#     weighted_embedding = np.array(attention_weights, dtype=np.float32).reshape(1, -1)
#     ast_features = np.array(item['ast_features'], dtype=np.float32).reshape(1, -1)
#     cfg_features = np.array(item['cfg_features'], dtype=np.float32).reshape(1, -1)
#     return np.vstack([weighted_embedding, ast_features, cfg_features])
#
#
# def compute_kernel_similarity(matrix1, matrix2, gamma=0.1):
#     # Flatten the matrices
#     matrix1_flat = matrix1.flatten().reshape(1, -1)
#     matrix2_flat = matrix2.flatten().reshape(1, -1)
#
#     # Compute RBF kernel (Gaussian kernel)
#     similarity = rbf_kernel(matrix1_flat, matrix2_flat, gamma=gamma)
#     return similarity[0][0]
#
#
# def main():
#     # 读取数据
#     dataset_path = './dataset_clone/all_features.json'
#     evaluate_path = './dataset_clone/evaluate_clone.json'
#     dataset = read_data(dataset_path)
#     evaluate_data = read_data(evaluate_path)
#
#     # 计算 token 频率和注意力权重
#     token_freq = plot_token_frequency(dataset)
#     attention_weights = calculate_attention_weights(dataset, token_freq)
#     weighted_embeddings = calculate_weighted_embedding(dataset, attention_weights)
#
#     # 预处理数据集中的所有项目
#     dataset_features = []
#     contract_names = []
#     for item, embedding in zip(dataset, weighted_embeddings):
#         features = preprocess_item(item, embedding)
#         dataset_features.append(features)
#         contract_names.append(item['contract_name'])
#
#     # 转换为numpy数组
#     dataset_features = np.array(dataset_features)
#
#     # 随机选择evaluate数据中的一个item并预处理
#     # itemb = random.choice(evaluate_data)
#     itemb = next(item for item in evaluate_data if item['contract_name'] == '932.sol')
#     itemb_attention_weights = calculate_attention_weights([itemb], token_freq)
#     itemb_weighted_embedding = calculate_weighted_embedding([itemb], itemb_attention_weights)
#     itemb_features = preprocess_item(itemb, itemb_weighted_embedding[0])
#     itemb_name = itemb['contract_name']
#
#     # 计算itemb与dataset中所有项的相似度
#     similarities = [compute_kernel_similarity(itemb_features, dataset_feature) for dataset_feature in dataset_features]
#
#     # # 找到前4个相似度最大的索引
#     # closest_indices = np.argsort(similarities)[-8:][::-1]
#     #
#     # # 输出对应的 contract_name 和相似度
#     # closest_contract_names = [contract_names[i] for i in closest_indices]
#     # closest_similarities = [similarities[i] for i in closest_indices]
#     # print("The input item is:", itemb_name)
#     # for name, similarity in zip(closest_contract_names, closest_similarities):
#     #     print(f"Contract Name: {name}, Similarity: {similarity}")
#     #
#     # # 文件路径
#     # file_path = os.path.join("./source_code_no_comments", itemb_name)
#     #
#     # # 统计文件行数
#     # with open(file_path, 'r', encoding='utf-8') as f:
#     #     num_lines = sum(1 for line in f)
#     #
#     # print(f"The number of lines in '{file_path}' is: {num_lines}")
#     # 找到前4个相似度最大的索引
#     closest_indices = np.argsort(similarities)[-15:][::-1]
#
#     # 输出对应的 contract_name 和相似度
#     closest_contract_names = [contract_names[i] for i in closest_indices]
#     closest_similarities = [similarities[i] for i in closest_indices]
#     max_similarity = max(similarities)
#     print("The input item is:", itemb_name)
#     for name, similarity in zip(closest_contract_names, closest_similarities):
#         if similarity / max_similarity > 0.85:
#             name_num = int(name[:-4])
#             type_name = get_dirname(name_num)
#
#             print(f"Contract Name: {name}, Similarity: {similarity / max_similarity}, Vulnarible type: {type_name}")
#
#     # 文件路径
#     file_path = os.path.join("./source_code_no_comments", itemb_name)
#
#     # 统计文件行数
#     with open(file_path, 'r', encoding='utf-8') as f:
#         num_lines = sum(1 for line in f)
#
#     print(f"The number of lines in '{file_path}' is: {num_lines}")
#
#
# if __name__ == "__main__":
#     main()


def read_data(file_path):
    """Load and return the JSON document stored at *file_path*.

    The file is decoded explicitly as UTF-8 instead of the platform's locale
    encoding, which on Windows is a legacy codepage and can mis-decode
    non-ASCII text. The ``utf-8-sig`` codec reads plain UTF-8 unchanged and
    also strips a leading byte-order mark, which ``json.load`` would
    otherwise reject.
    """
    with open(file_path, 'r', encoding='utf-8-sig') as f:
        return json.load(f)


def plot_token_frequency(data):
    """Count token occurrences over every item's ``'token'`` list.

    Despite the name, nothing is plotted: a ``Counter`` is returned. Before
    counting, tokens are normalised as follows:

    * ``'math'`` is counted as ``'safe math'``;
    * ``'operations'`` is counted as ``'math operations'``;
    * ``'title'`` and ``'safe'`` are skipped entirely;
    * every other token is counted as-is.
    """
    replacements = {'math': 'safe math', 'operations': 'math operations'}
    skipped = {'title', 'safe'}

    def _normalise(tokens):
        # Yield each token after applying the replacement / skip rules.
        for tok in tokens:
            if tok in replacements:
                yield replacements[tok]
            elif tok not in skipped:
                yield tok

    return Counter(
        tok
        for entry in data
        for tok in _normalise(entry['token'])
    )


def calculate_attention_weights(data, token_freq):
    """Convert token frequencies into a list of per-token weights for each item.

    ``plot_token_frequency`` folds the raw tokens ``'math'``/``'operations'``
    into the phrases ``'safe math'``/``'math operations'`` and drops
    ``'safe'``/``'title'``. The raw words are therefore re-added here, so
    that they receive a weight when they appear in an item:

    * ``'safe'`` and ``'math'`` get the count of ``'safe math'``;
    * ``'operations'`` gets the count of ``'math operations'``;
    * ``'title'`` gets a fixed count of 1.

    Each token's weight is its count divided by the total of all counts.

    Args:
        data: Iterable of items, each with a ``'token'`` list.
        token_freq: Mapping of token -> count. It is NOT modified; a copy
            is augmented instead.

    Returns:
        One list per item, holding one weight per entry in ``item['token']``,
        in the same order. Tokens absent from ``token_freq`` get ``0.0``.
        Keeping the lists aligned with the tokens matters because
        ``calculate_weighted_embedding`` dots them against
        ``item['comments_feature']``.
    """
    # Work on a copy so the caller's frequency table is left untouched.
    freq = Counter(token_freq)
    freq['safe'] = freq['safe math']
    freq['title'] = 1
    freq['math'] = freq['safe math']
    freq['operations'] = freq['math operations']

    # Compute the normaliser once rather than once per key.
    total = sum(freq.values())
    token_to_weight = {token: count / total for token, count in freq.items()}

    # Unseen tokens get 0.0 instead of being dropped, so the weight list
    # stays the same length as item['token'].
    return [
        [token_to_weight.get(token, 0.0) for token in item['token']]
        for item in data
    ]


def calculate_weighted_embedding(data, attention_weights):
    """Collapse each item's ``comments_feature`` rows into one weighted vector.

    Items and weight lists are paired with ``zip``, so iteration stops at the
    shorter of the two. For each pair the result is
    ``np.dot(weights, np.array(item['comments_feature']))``. When the weight
    list is empty, a float32 zero vector of length 512 is returned instead.

    Presumably ``comments_feature`` holds one embedding row per token, with
    the same length as the weight list (``np.dot`` requires this). TODO:
    confirm against the data producer.
    """
    def _weigh(entry, token_weights):
        # The conversion always happens, even when the weights are empty.
        embedding_rows = np.array(entry['comments_feature'])
        if not len(token_weights):
            return np.zeros((512,), dtype=np.float32)
        return np.dot(token_weights, embedding_rows)

    return [_weigh(entry, token_weights)
            for entry, token_weights in zip(data, attention_weights)]


def preprocess_item(item, attention_weights):
    """Build a 3-row float32 feature matrix for one item.

    The rows are, in order:

    1. the weighted comment embedding passed in as ``attention_weights``;
    2. ``item['ast_features']``;
    3. ``item['cfg_features']``.

    Each source is flattened to a single row before the rows are stacked, so
    all three must contain the same number of values.
    """
    def _as_row(values):
        # Flatten any input into a (1, n) float32 row.
        return np.array(values, dtype=np.float32).reshape(1, -1)

    stacked_rows = [
        _as_row(attention_weights),
        _as_row(item['ast_features']),
        _as_row(item['cfg_features']),
    ]
    return np.vstack(stacked_rows)


def compute_cosine_similarity(matrix1, matrix2):
    """Return the cosine similarity of two arrays treated as flat vectors.

    Both arrays are flattened into single-row matrices and passed to
    scikit-learn's ``cosine_similarity``. The only entry of the resulting
    1x1 matrix is returned. The two arrays must contain the same number of
    elements.
    """
    row_a, row_b = (m.flatten().reshape(1, -1) for m in (matrix1, matrix2))
    return cosine_similarity(row_a, row_b)[0, 0]


DEFAULT_DATASET_PATH = r'E:\2024\experiment_code_clone\total4\all_features\dataset\all_features.json'
DEFAULT_EVALUATE_PATH = r'E:\2024\experiment_code_clone\total4\all_features\dataset_clone\evaluate_clone.json'
DEFAULT_SOURCE_DIR = "./source_code_no_comments"


def main(dataset_path=DEFAULT_DATASET_PATH, evaluate_path=DEFAULT_EVALUATE_PATH,
         source_dir=DEFAULT_SOURCE_DIR, contract_name=None, top_k=10, threshold=0.95):
    """Find dataset contracts most similar to one evaluation contract and print them.

    Pipeline:

    1. Load the reference dataset and the evaluation set.
    2. Build token-frequency attention weights from the dataset.
    3. Build a stacked (comment, AST, CFG) feature matrix for every item.
    4. Compare one evaluation item against every dataset item by cosine
       similarity.
    5. Print each of the ``top_k`` best matches whose similarity, relative
       to the best match, exceeds ``threshold``, together with the category
       that ``get_dirname`` derives from its file number.
    6. Print the line count of the query contract's source file.

    Args:
        dataset_path: JSON file holding the reference items.
        evaluate_path: JSON file holding the evaluation (query) items.
        source_dir: Directory containing the query contract's source file.
        contract_name: Query the evaluation item with this ``contract_name``;
            when ``None``, an item is chosen at random.
        top_k: Number of highest-similarity candidates to consider.
        threshold: Minimum ``similarity / best_similarity`` for a candidate
            to be printed.

    Raises:
        ValueError: If either input file has no items, or ``contract_name``
            is not present in the evaluation set.
    """
    # Load the data.
    dataset = read_data(dataset_path)
    evaluate_data = read_data(evaluate_path)
    if not dataset:
        raise ValueError(f"dataset file {dataset_path!r} contains no items")
    if not evaluate_data:
        raise ValueError(f"evaluation file {evaluate_path!r} contains no items")

    # Token frequencies and attention weights, computed over the whole dataset.
    token_freq = plot_token_frequency(dataset)
    attention_weights = calculate_attention_weights(dataset, token_freq)
    weighted_embeddings = calculate_weighted_embedding(dataset, attention_weights)

    # Feature matrix and name for every dataset item.
    dataset_features = [preprocess_item(item, embedding)
                        for item, embedding in zip(dataset, weighted_embeddings)]
    contract_names = [item['contract_name'] for item in dataset]

    # Choose the query item: the named contract if given, otherwise a random one.
    if contract_name is None:
        itemb = random.choice(evaluate_data)
    else:
        itemb = next((item for item in evaluate_data
                      if item['contract_name'] == contract_name), None)
        if itemb is None:
            raise ValueError(f"contract {contract_name!r} not found in {evaluate_path!r}")

    # The query is weighted with the dataset's token frequencies.
    itemb_attention_weights = calculate_attention_weights([itemb], token_freq)
    itemb_weighted_embedding = calculate_weighted_embedding([itemb], itemb_attention_weights)
    itemb_features = preprocess_item(itemb, itemb_weighted_embedding[0])
    itemb_name = itemb['contract_name']

    # Similarity between the query and every dataset item.
    similarities = [compute_cosine_similarity(itemb_features, features)
                    for features in dataset_features]

    # Indices of the top_k most similar items, best first. Reversing before
    # slicing makes top_k == 0 select nothing; a [-0:] slice would keep all.
    closest_indices = np.argsort(similarities)[::-1][:top_k]
    max_similarity = max(similarities)

    print("The input item is:", itemb_name)
    if max_similarity <= 0:
        # Scores are reported relative to the best match. With a non-positive
        # best match that ratio is undefined (0/0) or sign-flipped.
        print("No dataset contract has a positive similarity to the input item.")
    else:
        for index in closest_indices:
            name = contract_names[index]
            relative_similarity = similarities[index] / max_similarity
            if relative_similarity > threshold:
                # Names are expected to look like '<number>.sol'; the number
                # selects the category. Non-numeric names get an empty type
                # instead of crashing the report.
                stem = os.path.splitext(name)[0]
                type_name = get_dirname(int(stem)) if stem.isdecimal() else ''
                print(f"Contract Name: {name}, Similarity: {relative_similarity}, Vulnarible type: {type_name}")

    # Count the lines of the query contract's source file.
    file_path = os.path.join(source_dir, itemb_name)
    with open(file_path, 'r', encoding='utf-8') as f:
        num_lines = sum(1 for _ in f)

    print(f"The number of lines in '{file_path}' is: {num_lines}")


# Run the similarity lookup only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
