# import json
# import os
# from smartembed import SmartEmbed
#
#
# def read_contract(file_path):
#     with open(file_path, 'r', encoding='utf-8') as file:
#         return file.read()
#
#
# def get_file_vector(se, file_path):
#     contract = read_contract(file_path)
#     return se.get_vector(contract)
#
#
# def calculate_similarities(target_file, all_files_dir, se):
#     """
#     计算目标文件与所有文件的相似度，并返回相似度大于0.95的文件列表
#     """
#     target_vector = get_file_vector(se, target_file)
#     similar_files = []
#
#     for file_name in os.listdir(all_files_dir):
#         file_path = os.path.join(all_files_dir, file_name)
#         if os.path.isfile(file_path) and file_path.endswith('.sol'):
#             print(file_path)
#             vector = get_file_vector(se, file_path)
#             similarity = se.get_similarity(target_vector, vector)
#             print(
#                 f"similarity score between {os.path.basename(target_file)} and {os.path.basename(file_path)}: {similarity}")
#             if similarity > 0.95:
#                 similar_files.append(file_name)
#                 print(f"File {file_name} is similar with a similarity score of {similarity}")
#
#     return similar_files
#
#
# def main():
#     se = SmartEmbed()
#
#     clone_target_dir = r'E:\2024\experiment_code_clone\total4\all_features\dataset_clone\evaluate_dataset'
#     clone_all_file_dir = r'E:\2024\experiment_code_clone\total4\all_features\source_code_no_comments'
#
#     # 随机选择目标文件夹中的一个文件
#     # target_file_name = random.choice(os.listdir(clone_target_dir))
#     for target_file_name in os.listdir(clone_target_dir):
#         if target_file_name.endswith('.sol'):
#             target_file_path = os.path.join(clone_target_dir, target_file_name)
#
#             print(f"Selected target file: {target_file_name}")
#             output_dir = os.path.join(r'E:\2024\experiment_code_clone\total4\all_features\clone_experiment\output',
#                                       target_file_name)
#             if not os.path.exists(output_dir):
#                 os.makedirs(output_dir)
#             if os.path.exists(os.path.join(output_dir, 'smartembed.json')):
#                 continue
#
#             # 计算相似度并记录相似度大于0.95的文件名称
#             similar_files = calculate_similarities(target_file_path, clone_all_file_dir, se)
#
#             print("Files with similarity > 0.95:", similar_files)
#             with open(os.path.join(output_dir, 'smartembed.json'), 'w') as file:
#                 json.dump(similar_files, file, ensure_ascii=False)
#
#
# if __name__ == "__main__":
#     main()

import json
import os
from smartembed import SmartEmbed
import time

# Module-level memo shared by get_file_vector: maps a contract file path to the
# vector computed for it, so successful embeddings are reused, not recomputed.
vector_cache = dict()


def read_contract(file_path):
    """Return the full text of the file at *file_path*, decoded as UTF-8."""
    with open(file_path, encoding='utf-8') as source:
        text = source.read()
    return text


def get_file_vector(se, file_path):
    """Return the embedding vector for the contract at *file_path*, memoized.

    Results are stored in the module-level ``vector_cache`` keyed by path.
    Failures are cached too (as ``None``), so a file that cannot be read or
    embedded is attempted only once instead of once per target file.

    Args:
        se: the embedding model; only its ``get_vector(str)`` method is used.
        file_path: path of the contract source to embed.

    Returns:
        Whatever ``se.get_vector`` returns for the file's text, or ``None`` if
        reading or embedding the file raised (the error is printed).
    """
    if file_path in vector_cache:
        return vector_cache[file_path]
    try:
        # Reading is inside the try as well: an unreadable or non-UTF-8 file
        # is skipped like an embedding failure instead of aborting the batch.
        contract = read_contract(file_path)
        vector = se.get_vector(contract)
    except Exception as e:  # best-effort batch: get_vector may raise anything
        print(f"Error processing file {file_path}: {e}")
        vector = None
    # Cache unconditionally so failures are not retried on every lookup.
    vector_cache[file_path] = vector
    return vector


def calculate_similarities(target_file, all_files_dir, se, threshold=1):
    """Return names of ``.sol`` files in *all_files_dir* similar to *target_file*.

    Every ``.sol`` file directly inside *all_files_dir* (sub-directories are not
    searched) is embedded via ``get_file_vector`` and compared to the target
    with ``se.get_similarity``. Files whose score is ``>= threshold`` are kept.

    Args:
        target_file: path of the contract to compare against.
        all_files_dir: directory holding the candidate contracts.
        se: the embedding model (its ``get_similarity`` method is used).
        threshold: minimum similarity score for a match. Defaults to 1, which
            keeps only files scoring at least 1.

    Returns:
        File names (not full paths) in sorted order. The list is empty if the
        target could not be embedded.
    """
    target_vector = get_file_vector(se, target_file)
    if target_vector is None:
        # Nothing to compare against (get_file_vector returns None on failure).
        return []

    similar_files = []
    # sorted() makes the result order, and thus the written JSON, deterministic.
    for file_name in sorted(os.listdir(all_files_dir)):
        file_path = os.path.join(all_files_dir, file_name)
        if not (os.path.isfile(file_path) and file_path.endswith('.sol')):
            continue
        print(file_path)
        vector = get_file_vector(se, file_path)
        if vector is None:
            # Skip candidates that failed to embed rather than passing None on.
            continue
        similarity = se.get_similarity(target_vector, vector)
        print(
            f"similarity score between {os.path.basename(target_file)} and {os.path.basename(file_path)}: {similarity}")
        # NOTE(review): similarity is presumably a float; with threshold=1 this
        # is an exact comparison, so an identical file scoring 0.9999999 would
        # be missed. Confirm get_similarity's output range/rounding.
        if similarity >= threshold:
            similar_files.append(file_name)
            print(f"File {file_name} is similar with a similarity score of {similarity}")

    return similar_files


def main(clone_target_dir=r'E:\2024\experiment_code_clone\total4\all_features\dataset_clone\evaluate_dataset',
         clone_all_file_dir=r'E:\2024\experiment_code_clone\total4\all_features\source_code_no_comments',
         output_root=r'E:\2024\experiment_code_clone\total4\all_features\clone_experiment\output_0.95'):
    """Run the SmartEmbed clone-detection experiment.

    For every ``.sol`` file in *clone_target_dir*, find the ``.sol`` files in
    *clone_all_file_dir* that ``calculate_similarities`` reports as similar,
    excluding a file with the same name as the target. Their names are written
    as a JSON list to ``<output_root>/<target file name>/smartembed.json``.

    The directory arguments default to the original experiment's hard-coded
    paths, so calling ``main()`` with no arguments behaves as before.
    """
    se = SmartEmbed()

    # Embed every corpus file up front so the per-target loop below mostly
    # hits the cache. This phase is timed separately from the query phase.
    preload_start = time.perf_counter()
    for file_name in os.listdir(clone_all_file_dir):
        file_path = os.path.join(clone_all_file_dir, file_name)
        if os.path.isfile(file_path) and file_path.endswith('.sol'):
            get_file_vector(se, file_path)
    print(f"Preload time: {time.perf_counter() - preload_start}")

    # "Total time" printed below covers only the comparison phase.
    start_time = time.perf_counter()
    for target_file_name in os.listdir(clone_target_dir):
        if not target_file_name.endswith('.sol'):
            continue
        target_file_path = os.path.join(clone_target_dir, target_file_name)

        print(f"Selected target file: {target_file_name}")
        output_dir = os.path.join(output_root, target_file_name)
        os.makedirs(output_dir, exist_ok=True)

        # Uses calculate_similarities' default cut-off (>= 1 in this file).
        # NOTE(review): the default output folder is named output_0.95 although
        # the cut-off is >= 1. Confirm which threshold the experiment intends.
        similar_files = calculate_similarities(target_file_path, clone_all_file_dir, se)
        # Drop the target's own entry when the corpus has a same-named file.
        similar_files = [x for x in similar_files if x != target_file_name]

        print("Files with similarity >= 1:", similar_files)
        # ensure_ascii=False may emit non-ASCII text, so pin UTF-8 instead of
        # relying on the platform default encoding (e.g. GBK on Windows).
        with open(os.path.join(output_dir, 'smartembed.json'), 'w', encoding='utf-8') as file:
            json.dump(similar_files, file, ensure_ascii=False)

    end_time = time.perf_counter()
    print(f"Total time: {end_time - start_time}")


if __name__ == '__main__':
    # Script entry point: run the experiment only when executed directly.
    main()
