from concurrent.futures import ThreadPoolExecutor
import os
from utils.utils import CONSTANTS, dump_jsonl, json_to_graph, CodexTokenizer, load_jsonl, make_needed_dir
import copy
import networkx as nx
import queue
import Levenshtein
import argparse
import time
import ast
import hashlib
import joblib
import re
from tree_sitter import Parser, Language
import numpy as np
from utils.metrics import hit
from functools import partial
from build_query_graph import build_query_subgraph
from collections import Counter
from sklearn.metrics.pairwise import cosine_similarity
from cachetools import LRUCache
import torch
import javalang
from javalang import parse
from javalang.tree import MethodDeclaration, ClassDeclaration, ForStatement, WhileStatement,DoStatement,IfStatement, SwitchStatement, TryStatement,SynchronizedStatement
from scipy.spatial.distance import cosine
from scipy.interpolate import interp1d
from sklearn.feature_extraction.text import TfidfVectorizer


class SimilarityScore:
    """Stateless text-level and graph-level similarity measures."""

    @staticmethod
    def text_edit_similarity(str1: str, str2: str):
        """Return ``1 - edit_distance / max_length`` for two strings.

        Two empty strings are treated as identical (1.0) instead of raising
        ``ZeroDivisionError``.
        """
        longest = max(len(str1), len(str2))
        if longest == 0:
            return 1.0
        return 1 - Levenshtein.distance(str1, str2) / longest

    @staticmethod
    def counter_union(counter1, counter2):
        """Return the multiset union of two Counters (max count per key)."""
        return counter1 | counter2

    @staticmethod
    def counter_intersection(counter1, counter2):
        """Return the multiset intersection of two Counters (min count per key)."""
        return counter1 & counter2

    @staticmethod
    def text_jaccard_similarity(list1, list2):
        """Return the Jaccard similarity of the distinct items of two sequences.

        Duplicates are ignored. When both inputs are empty the union is empty;
        they are then treated as identical (1.0), matching
        ``levenshtein_similarity``, instead of dividing by zero.
        """
        set1 = set(list1)
        set2 = set(list2)
        union = len(set1 | set2)
        if union == 0:
            return 1.0
        return float(len(set1 & set2)) / union

    @staticmethod
    def text_jaccard_similarity_pro(list1, list2):
        """Return the multiset (duplicate-aware) Jaccard similarity.

        The score is ``sum(min counts) / sum(max counts)``. Summing the counts
        is what makes duplicates matter: ``len()`` of a Counter only counts its
        distinct keys, which would reduce this to ``text_jaccard_similarity``.
        Two empty inputs are treated as identical (1.0).
        """
        counter1 = Counter(list1)
        counter2 = Counter(list2)

        union = sum(SimilarityScore.counter_union(counter1, counter2).values())
        if union == 0:
            return 1.0
        intersection = sum(SimilarityScore.counter_intersection(counter1, counter2).values())
        return float(intersection) / union

    @staticmethod
    def levenshtein_distance(list1, list2):
        """Return the Levenshtein (edit) distance between two sequences.

        Insertions, deletions and substitutions all cost 1. Only the previous
        and the current row of the dynamic-programming table are kept.
        """
        len2 = len(list2)

        # Row 0: turning an empty prefix into list2[:j] takes j insertions.
        previous_row = list(range(len2 + 1))
        for i, item1 in enumerate(list1, start=1):
            # Column 0: turning list1[:i] into an empty prefix takes i deletions.
            current_row = [i]
            for j, item2 in enumerate(list2, start=1):
                if item1 == item2:
                    # Equal items need no operation.
                    current_row.append(previous_row[j - 1])
                else:
                    current_row.append(1 + min(
                        previous_row[j],      # deletion
                        current_row[j - 1],   # insertion
                        previous_row[j - 1],  # substitution
                    ))
            previous_row = current_row

        return previous_row[len2]

    @staticmethod
    def levenshtein_similarity(list1, list2):
        """Map the Levenshtein distance of two sequences onto ``[-1, 1]``.

        The distance is normalised by the longer sequence's length and then
        rescaled so that identical sequences score 1.0 and sequences at maximal
        distance score -1.0. Two empty sequences are treated as identical.
        """
        max_len = max(len(list1), len(list2))
        if max_len == 0:
            return 1.0

        normalized_distance = SimilarityScore.levenshtein_distance(list1, list2) / max_len
        return 1.0 - normalized_distance * 2

    @staticmethod
    def subgraph_edit_similarity(query_graph: "nx.MultiDiGraph", graph: "nx.MultiDiGraph", gamma=0.1):
        """Score how well ``graph`` matches ``query_graph`` (higher is better).

        The score is a node term plus an edge term:

        * The node with the largest id in each graph is taken as its root and
          the two roots are matched; their token Jaccard similarity seeds the
          node term.
        * Matching then proceeds breadth-first. For a matched pair ``(v, u)``
          at depth ``hop``, the candidate query nodes are the unvisited
          successors and predecessors of ``v``; the candidate target nodes are
          *all* still-unvisited nodes of ``graph``. Candidate pairs are taken
          greedily in order of decreasing Jaccard similarity, and each accepted
          pair adds ``gamma ** (hop + 1) * similarity``.
        * Every query edge ``(a, b, key)`` whose endpoints are both matched and
          whose image exists in ``graph`` with the same key adds
          ``gamma ** hop_of_a``.

        Args:
            query_graph: Graph built from the query context; nodes must carry a
                ``'sourceLines'`` attribute.
            graph: Candidate graph with the same node attribute.
            gamma: Per-hop decay applied to node and edge contributions.

        Returns:
            The combined similarity as a float; 0.0 when either graph has no
            nodes (there is no root to anchor the matching on).
        """
        if len(query_graph.nodes) == 0 or len(graph.nodes) == 0:
            return 0.0

        query_root = max(query_graph.nodes)
        root = max(graph.nodes)

        tokenizer = CodexTokenizer()
        # Each node is tokenised at most once; the same node is compared
        # against many candidates across BFS steps.
        query_tokens = {}
        graph_tokens = {}

        def tokens_of(owner, node, cache):
            if node not in cache:
                cache[node] = set(tokenizer.tokenize("".join(owner.nodes[node]['sourceLines'])))
            return cache[node]

        def node_similarity(query_node, graph_node):
            return SimilarityScore.text_jaccard_similarity(
                tokens_of(query_graph, query_node, query_tokens),
                tokens_of(graph, graph_node, graph_tokens),
            )

        node_sim = node_similarity(query_root, root)

        # query node -> (matched graph node, hop), or None when left unmatched.
        node_match = {query_root: (root, 0)}
        match_queue = queue.Queue()
        match_queue.put((query_root, root, 0))

        query_graph_visited = {query_root}
        graph_visited = {root}
        graph_nodes = set(graph.nodes)

        # Breadth-first greedy matching.
        while not match_queue.empty():
            v, _, hop = match_queue.get()
            v_neighbors = (set(query_graph.neighbors(v)) | set(query_graph.predecessors(v))) - query_graph_visited
            u_neighbors = graph_nodes - graph_visited

            sim_score = [
                (node_similarity(vn, un), vn, un)
                for vn in v_neighbors
                for un in u_neighbors
            ]
            sim_score.sort(key=lambda entry: -entry[0])

            for sim, vn, un in sim_score:
                if vn not in query_graph_visited and un not in graph_visited:
                    match_queue.put((vn, un, hop + 1))
                    node_match[vn] = (un, hop + 1)
                    query_graph_visited.add(vn)
                    graph_visited.add(un)
                    v_neighbors.remove(vn)
                    u_neighbors.remove(un)
                    # Contributions decay with the distance from the root.
                    node_sim += (gamma ** (hop + 1)) * sim

                if len(v_neighbors) == 0 or len(u_neighbors) == 0:
                    break

            # Query neighbours that found no partner stay unmatched and are
            # not expanded later.
            for vn in v_neighbors:
                node_match[vn] = None
                query_graph_visited.add(vn)

        # Edge term. ``keys=True`` yields (source, target, key) triples, the
        # same triples that iterating ``MultiDiGraph.edges`` produces.
        edge_sim = 0
        for v_query, u_query, edge_key in query_graph.edges(keys=True):
            # ``get`` returns None both for nodes never reached and for nodes
            # explicitly recorded as unmatched.
            matched_source = node_match.get(v_query)
            matched_target = node_match.get(u_query)
            if matched_source is None or matched_target is None:
                continue
            v, hop_v = matched_source
            u, _ = matched_target
            if graph.has_edge(v, u, edge_key):
                edge_sim += gamma ** hop_v

        return node_sim + edge_sim


class CodeSearchWorker:
    """Retrieve similar repository snippets for a batch of query cases.

    Depending on ``mode`` the worker ranks repository cases by token Jaccard
    similarity (``'coarse'``), by subgraph edit similarity (``'fine'``), or by
    a two-phase pipeline (any other mode) that shortlists by text similarity,
    filters with GraphCodeBERT embeddings and then scores graphs.
    """

    # Size of the thread pool used to score repository cases.
    _MAX_WORKERS = 32
    # Hugging Face model used for semantic filtering and embeddings.
    _ENCODER_NAME = "microsoft/graphcodebert-base"
    # Matches a comment or a quoted literal. Literals are matched only so that
    # comment markers inside them are skipped.
    _COMMENT_OR_STRING_RE = re.compile(
        r'//[^\n]*'                  # line comment
        r'|/\*.*?\*/'                # block comment (may span lines)
        r'|"(?:\\.|[^"\\\n])*"'      # double-quoted literal
        r"|'(?:\\.|[^'\\\n])*'",     # single-quoted literal
        re.DOTALL,
    )

    def __init__(self, query_cases, output_path, mode, gamma=None, max_top_k=None, remove_threshold=0):
        """Store the search configuration; models are loaded lazily.

        Args:
            query_cases: Iterable of query-case dicts to search for.
            output_path: Path the results are dumped to by ``run``.
            mode: ``'coarse'``, ``'fine'`` or anything else for two-phase.
            gamma: Hop decay forwarded to the graph similarity; when ``None``
                the similarity function's own default is used.
            max_top_k: Number of results kept per query; ``None`` means
                ``CONSTANTS.max_search_top_k``.
            remove_threshold: Minimum similarity a result must reach.
        """
        self.query_cases = query_cases
        self.output_path = output_path
        self.max_top_k = CONSTANTS.max_search_top_k if max_top_k is None else max_top_k
        self.remove_threshold = remove_threshold
        self.mode = mode
        self.gamma = gamma
        self._tokenizer = None
        self._code_encoder = None
        # Bounded so that embeddings do not accumulate without limit.
        self.embedding_cache = LRUCache(maxsize=1000)
        self.tfidf_vectorizer = None
        self.cross_cases = None

    @staticmethod
    def _is_context_after_hole(query_case, repo_case):
        """Return True when the repository case must be excluded for this query.

        That is the case when both refer to the same file and the query's
        ``line_no`` lies within the repository case's
        ``[min_line_no, max_line_no]`` range (inclusive).
        """
        metadata = query_case['metadata']
        if "/".join(metadata['fpath_tuple']) != "/".join(repo_case['fpath_tuple']):
            return False
        return repo_case['min_line_no'] <= metadata['line_no'] <= repo_case['max_line_no']

    @staticmethod
    def _load_repo_cases(query_case):
        """Load the repository cases of the repo named in the query's task id."""
        repo_name = query_case['metadata']['task_id'].split('/')[0]
        return load_jsonl(os.path.join(CONSTANTS.graph_database_save_dir, f"{repo_name}.jsonl"))

    @staticmethod
    def _format_context(repo_case, score):
        """Build the result tuple stored under ``'top_k_context'``."""
        return (repo_case['val'], repo_case['statement'],
                repo_case['key_forward_context'], repo_case['fpath_tuple'], score)

    def _score_repo_cases(self, scorer, query_case, repo_cases):
        """Apply ``scorer(query_case, repo_case)`` to every case, keeping order."""
        with ThreadPoolExecutor(max_workers=self._MAX_WORKERS) as executor:
            return list(executor.map(partial(scorer, query_case), repo_cases))

    def _text_jaccard_similarity_wrapper(self, query_case, repo_case):
        """Return ``(repo_case, token Jaccard similarity)``; 0 for excluded cases."""
        if self._is_context_after_hole(query_case, repo_case):
            return repo_case, 0

        sim = SimilarityScore.text_jaccard_similarity(query_case['query_forward_encoding'],
                                                     repo_case['key_forward_encoding'])
        return repo_case, sim

    def _graph_node_prior_similarity_wrapper(self, query_case, repo_case):
        """Return ``(repo_case, subgraph edit similarity)``.

        The score is 0 for excluded cases and whenever either graph has no
        nodes (the similarity needs a root node in both graphs).
        """
        # Cheap exclusion test first so no graph is built for skipped cases.
        if self._is_context_after_hole(query_case, repo_case):
            return repo_case, 0

        repo_graph = json_to_graph(repo_case['key_forward_graph'])
        if len(repo_graph.nodes) == 0:
            return repo_case, 0
        query_graph = json_to_graph(query_case['query_forward_graph'])
        if len(query_graph.nodes) == 0:
            return repo_case, 0

        if self.gamma is None:
            # Let the similarity fall back to its own default decay.
            sim = SimilarityScore.subgraph_edit_similarity(query_graph, repo_graph)
        else:
            sim = SimilarityScore.subgraph_edit_similarity(query_graph, repo_graph, gamma=self.gamma)
        return repo_case, sim

    def _find_top_k_context_one_phase(self, query_case):
        """Rank every repository case with a single similarity measure.

        Text Jaccard is used in ``'coarse'`` mode and graph similarity
        otherwise. The result is a deep copy of ``query_case`` extended with
        ``'top_k_context'`` (the ``max_top_k`` best results, ordered from
        lowest to highest score) and the runtime fields.
        """
        start_time = time.time()
        search_res = copy.deepcopy(query_case)
        repo_cases = self._load_repo_cases(query_case)

        if self.mode == 'coarse':
            scorer = self._text_jaccard_similarity_wrapper
        else:
            scorer = self._graph_node_prior_similarity_wrapper
        scored = self._score_repo_cases(scorer, query_case, repo_cases)

        kept = [self._format_context(repo_case, sim)
                for repo_case, sim in scored if sim >= self.remove_threshold]
        kept.sort(key=lambda item: item[-1])
        # ``[-0:]`` would return the whole list, so zero is handled explicitly.
        search_res['top_k_context'] = kept[-self.max_top_k:] if self.max_top_k > 0 else []

        case_id = query_case['metadata']['task_id']
        print(f'case {case_id} finished')
        elapsed = time.time() - start_time

        if self.mode == 'coarse':
            search_res['text_runtime'] = elapsed
            search_res['graph_runtime'] = 0
        else:
            search_res['text_runtime'] = 0
            search_res['graph_runtime'] = elapsed
        return search_res

    def _ensure_encoder(self):
        """Load the GraphCodeBERT tokenizer and model on first use."""
        if self._code_encoder is not None:
            return
        # Imported lazily so the dependency is only needed for two-phase runs.
        from transformers import AutoTokenizer, AutoModel

        self._tokenizer = AutoTokenizer.from_pretrained(self._ENCODER_NAME)
        encoder = AutoModel.from_pretrained(self._ENCODER_NAME)
        device = "cuda" if torch.cuda.is_available() else "cpu"
        self._code_encoder = encoder.to(device)
        self._code_encoder.eval()

    def _encode_texts(self, text_list):
        """Return L2-normalised first-token embeddings as an ``(n, dim)`` array."""
        self._ensure_encoder()
        inputs = self._tokenizer(
            text_list,
            padding=True,
            truncation=True,
            max_length=512,
            return_tensors="pt"
        ).to(self._code_encoder.device)

        with torch.no_grad():
            outputs = self._code_encoder(**inputs)
        # The hidden state at position 0 represents the whole input.
        embeddings = outputs.last_hidden_state[:, 0, :]
        embeddings = torch.nn.functional.normalize(embeddings, p=2, dim=1)
        return embeddings.cpu().numpy()

    def _semantic_filter(self, query_text, candidates, similarity_threshold=0.75):
        """Keep the candidates that are semantically closest to the query.

        Candidates are embedded with GraphCodeBERT and compared with the query
        by cosine similarity. The cut-off is derived from the scores
        themselves: it is the score at index ``int(0.75 * n)`` of the
        descending ranking, so roughly the top three quarters are kept.

        Args:
            query_text: Query code/text.
            candidates: List of candidate code strings.
            similarity_threshold: Kept for backward compatibility; it is
                superseded by the rank-based cut-off described above.

        Returns:
            List of ``(candidate, similarity)`` pairs in input order; an empty
            list when there are no candidates.
        """
        if len(candidates) == 0:
            return []

        query_embed = self._encode_texts([query_text])
        candidate_embeds = self._encode_texts(candidates)
        similarities = cosine_similarity(query_embed, candidate_embeds)[0]

        ranked = sorted(similarities, reverse=True)
        similarity_threshold = ranked[min(int(len(ranked) * 0.75), len(ranked) - 1)]

        return [
            (cand, sim) for cand, sim in zip(candidates, similarities)
            if sim >= similarity_threshold
        ]

    def _deduplicate_codes(self, cases):
        """Drop cases whose ``'val'`` code already occurred; first one wins."""
        unique_cases = []
        seen_codes = set()
        for case in cases:
            code = case['val']
            if code not in seen_codes:
                seen_codes.add(code)
                unique_cases.append(case)
        return unique_cases

    def remove_comments(self, code):
        """Strip ``//`` line comments and ``/* */`` block comments from code.

        Quoted literals are recognised and left untouched, so comment markers
        inside strings (e.g. ``"http://host"``) survive, and ``//`` inside a
        block comment does not swallow the code following that comment.
        """
        def _keep_literals(match):
            text = match.group(0)
            # Comments start with '/', literals with a quote character.
            return '' if text.startswith('/') else text

        return self._COMMENT_OR_STRING_RE.sub(_keep_literals, code)

    def _code_similarity(self, code1, code2):
        """Return the Jaccard similarity of whitespace-separated tokens.

        Two inputs without any token are treated as identical (1.0) rather
        than dividing by zero.
        """
        tokens1 = set(code1.split())
        tokens2 = set(code2.split())
        union = tokens1 | tokens2
        if not union:
            return 1.0
        return len(tokens1 & tokens2) / len(union)

    def _get_code_embedding(self, code: str) -> np.ndarray:
        """Return the cached or freshly computed GraphCodeBERT embedding.

        The result is a float32 array of shape ``(1, dim)``, L2-normalised.
        """
        # The model name is part of the key so other encoders cannot collide.
        cache_key = hashlib.md5(f"graphcodebert_{code}".encode()).hexdigest()
        cached = self.embedding_cache.get(cache_key)
        if cached is not None:
            return cached

        embedding = self._encode_texts([code]).astype(np.float32)
        self.embedding_cache[cache_key] = embedding
        return embedding

    def _find_top_k_context_two_phase(self, query_case):
        """Run the text -> semantic -> graph pipeline for one query case.

        Steps:
            1. Score all repository cases by token Jaccard similarity and keep
               the ``2 * max_top_k`` best.
            2. Apply ``_semantic_filter`` and remove duplicate code.
            3. Score the survivors by graph similarity, keep those reaching
               ``remove_threshold`` and of these the ``max_top_k + 2`` best.
            4. Combine ``0.5 * graph + 0.5 * text`` and take the ``max_top_k``
               best; every candidate after the first is rescored as
               ``0.8 * score - 0.2 * s`` where ``s`` is its largest cosine
               similarity to the candidates processed before it.

        ``query_case`` is updated in place with ``'top_k_context'`` (best
        first), ``'import_info'`` and the runtime fields; a deep copy of it is
        returned.
        """
        repo_cases = self._load_repo_cases(query_case)

        # Phase 1: text similarity shortlist.
        text_runtime_start = time.time()
        text_scored = self._score_repo_cases(self._text_jaccard_similarity_wrapper, query_case, repo_cases)
        text_scored.sort(key=lambda pair: pair[1], reverse=True)
        shortlist = text_scored[:self.max_top_k * 2]

        # Best text score per code string; the shortlist is sorted, so the
        # first occurrence carries the highest score.
        text_score_by_code = {}
        for case, text_score in shortlist:
            text_score_by_code.setdefault(case['val'], text_score)

        # Semantic filtering with GraphCodeBERT embeddings.
        semantic_filtered = self._semantic_filter(
            query_text=query_case['query_forward_context'],
            candidates=[case['val'] for case, _ in shortlist],
            similarity_threshold=0.75
        )
        kept_codes = {code for code, _ in semantic_filtered}
        unique_candidates = self._deduplicate_codes(
            [case for case, _ in shortlist if case['val'] in kept_codes]
        )

        text_runtime_end = time.time()

        # Phase 2: graph similarity on the surviving candidates.
        graph_scored = [
            (case, graph_score)
            for case, graph_score in self._score_repo_cases(
                self._graph_node_prior_similarity_wrapper, query_case, unique_candidates)
            if graph_score >= self.remove_threshold
        ]
        graph_scored.sort(key=lambda pair: pair[1], reverse=True)

        combined_scores = [
            (case, 0.5 * graph_score + 0.5 * text_score_by_code.get(case['val'], 0))
            for case, graph_score in graph_scored[:self.max_top_k + 2]
        ]
        combined_scores.sort(key=lambda pair: pair[1], reverse=True)

        # Diversity-aware rescoring of the best ``max_top_k`` candidates.
        final_results = []
        selected_embeddings = []
        for case, score in combined_scores[:max(self.max_top_k, 0)]:
            curr_embed = self._get_code_embedding(case['val']).reshape(1, -1)
            if selected_embeddings:
                max_sim = cosine_similarity(curr_embed, np.vstack(selected_embeddings)).max()
                rescored = 0.8 * score - 0.2 * max_sim
            else:
                rescored = score
            # NOTE(review): cast to a builtin float because NumPy scalars may
            # not be JSON-serialisable -- confirm against dump_jsonl.
            final_results.append((case, float(rescored)))
            selected_embeddings.append(curr_embed)

        final_results.sort(key=lambda pair: pair[1], reverse=True)
        query_case['top_k_context'] = [
            self._format_context(case, score) for case, score in final_results[:self.max_top_k]
        ]
        query_case['import_info'] = self._find_import(query_case)
        # query_case['cross_file_line'] = self._find_crossfile(query_case)

        graph_runtime_end = time.time()

        case_id = query_case['metadata']['task_id']
        print(f'case {case_id} finished')
        query_case['text_runtime'] = text_runtime_end - text_runtime_start
        query_case['graph_runtime'] = graph_runtime_end - text_runtime_end
        return copy.deepcopy(query_case)

    def _find_top_k_context_two_phase_0(self, query_case):
        """Original two-phase search: text shortlist, then graph scoring.

        The ``max_top_k`` best cases by text similarity are scored by graph
        similarity; those reaching ``remove_threshold`` are stored best first
        under ``'top_k_context'``. ``query_case`` is updated in place and a
        deep copy of it is returned.
        """
        repo_cases = self._load_repo_cases(query_case)

        text_runtime_start = time.time()
        text_scored = self._score_repo_cases(self._text_jaccard_similarity_wrapper, query_case, repo_cases)
        text_scored.sort(key=lambda pair: pair[1], reverse=True)
        shortlist = [case for case, _ in text_scored[:self.max_top_k]]
        text_runtime_end = time.time()

        graph_scored = self._score_repo_cases(self._graph_node_prior_similarity_wrapper, query_case, shortlist)
        kept = [self._format_context(repo_case, sim)
                for repo_case, sim in graph_scored if sim >= self.remove_threshold]
        kept.sort(key=lambda item: item[-1], reverse=True)
        query_case['top_k_context'] = kept[:self.max_top_k]

        graph_runtime_end = time.time()

        case_id = query_case['metadata']['task_id']
        print(f'case {case_id} finished')
        query_case['text_runtime'] = text_runtime_end - text_runtime_start
        query_case['graph_runtime'] = graph_runtime_end - text_runtime_end
        return copy.deepcopy(query_case)

    def _find_import(self, query_case):
        """Return the ``'import_info'`` recorded for the query's file.

        The records are read from ``./output1/<repo>.jsonl``. ``None`` is
        returned when no record matches the query's file path.
        """
        query_name = query_case['metadata']['task_id'].split('/')[0]
        import_cases = load_jsonl(os.path.join(f"./output1", f"{query_name}.jsonl"))
        hole_fpath_str = "/".join(query_case['metadata']['fpath_tuple'])
        for import_case in import_cases:
            # NOTE(review): presumably 'file_path' is already a '/'-joined
            # string (joining a string with "" leaves it unchanged) -- verify.
            if hole_fpath_str == "".join(import_case["file_path"]):
                return import_case['import_info']
        return None

    def _find_crossfile(self, query_case):
        """Return the ``'cross_file_prompt'`` for the query's task id, if any.

        ``None`` is returned when ``cross_cases`` has not been loaded or holds
        no entry for the task.
        """
        query_name = query_case['metadata']['task_id']
        for cross_file in self.cross_cases or ():
            if query_name == cross_file['metadata']['task_id']:
                return cross_file['cross_file_prompt']
        return None

    def run(self):
        """Search every query case and dump all results to ``output_path``."""
        # self.cross_cases = load_jsonl(f'other_cross_file_method_output.jsonl')
        if self.mode in ('coarse', 'fine'):
            search = self._find_top_k_context_one_phase
        else:
            search = self._find_top_k_context_two_phase

        # Both search methods already return an independent copy per query.
        query_lines_with_retrieved_results = [search(query_case) for query_case in self.query_cases]

        dump_jsonl(query_lines_with_retrieved_results, self.output_path)


if __name__ == '__main__':
    # Command-line options: benchmark name, search mode and hop decay.
    cli = argparse.ArgumentParser()
    cli.add_argument('--query_cases', default="api_level", type=str)
    cli.add_argument('--mode', type=str, default='coarse2fine')
    cli.add_argument('--gamma', default=0.1, type=float)
    options = cli.parse_args()

    # Build the query graphs before they are loaded below.
    build_query_subgraph(f"{options.query_cases}.test.jsonl")

    cases_to_search = load_jsonl(
        os.path.join(CONSTANTS.query_graph_save_dir, f"{options.query_cases}.test.jsonl")
    )
    result_path = os.path.join(
        f"./search_results/{options.query_cases}.{options.mode}.{options.gamma*100}.search_res.jsonl"
    )
    make_needed_dir(result_path)

    # Time the complete search run.
    started_at = time.time()
    CodeSearchWorker(cases_to_search, result_path, options.mode, gamma=options.gamma).run()
    elapsed_seconds = time.time() - started_at

    # Evaluate the dumped results.
    hit_at_1, hit_at_5, hit_at_10 = hit(load_jsonl(result_path), hits=[1, 5, 10])

    rule = '-' * 20
    print(rule + "Parameters" + rule)
    print(f"query_cases: {options.query_cases}")
    print(f'mode: {options.mode}')
    print(f'gamma: {options.gamma}')
    print(rule + "Results" + rule)
    print(f'save_path: {result_path}')
    print('hit1 %.4f' % hit_at_1)
    print('hit5 %.4f' % hit_at_5)
    print('hit10 %.4f' % hit_at_10)
    print('runtime %.4f' % elapsed_seconds)
