from collections import defaultdict, namedtuple
from collections.abc import MutableMapping
from functools import cache
from pathlib import Path
import random
import timeout_decorator
from loguru import logger
from app.config import config
from app.data_structures import BugLocation, SearchResult
from app.search import search_utils
from app.utils import catch_all_and_log
from app.task.task import Task
from app.search.search_utils import *
# (start, end) line span of a definition inside a source file; lookups in this
# module treat both ends as inclusive (`start <= line_no <= end`).
LineRange = namedtuple("LineRange", ["start", "end"])

# class name -> [(file_name, line_range)]
ClassIndexType = MutableMapping[str, list[tuple[str, LineRange]]]
# class name -> {method name -> [(file_name, line_range)]}
ClassFuncIndexType = MutableMapping[
    str, MutableMapping[str, list[tuple[str, LineRange]]]
]
# top-level function name -> [(file_name, line_range)]
FuncIndexType = MutableMapping[str, list[tuple[str, LineRange]]]
# class name -> [names of its superclasses]
ClassRelationIndexType = MutableMapping[str, list[str]]

# Threshold used by the search APIs: result lists longer than this are
# summarised instead of shown in full, and returned lists are cut to this size.
RESULT_SHOW_LIMIT = 3


class SearchBackend:
    def __init__(self, project_path: str, task: Task, include_test_files: bool):
        """
        Create empty indexes for a project and fill them right away.

        Args:
            project_path (str): Path of the project whose Python files are indexed.
            task (Task): The task this backend is serving; stored for later use.
            include_test_files (bool): Whether test files take part in indexing,
                and therefore in retrieval.
        """
        print(f"include_test_files in SearchBackend: {include_test_files}")
        self.task = task
        self.project_path = project_path

        # Every .py file that could be parsed. These are all ABSOLUTE paths.
        self.parsed_files: list[str] = []

        # File names stored in the indexes below are assumed to be absolute.

        # {class_name -> [(file_name, line_range)]}
        self.class_index: ClassIndexType = {}

        # {function_name -> [(file_name, line_range)]}
        self.function_index: FuncIndexType = {}

        # {class_name -> {func_name -> [(file_name, line_range)]}}
        # The innermost value is a list because (1) a function name can be
        # defined more than once, and (2) several classes may share a name
        # and each define the same method.
        self.class_func_index: ClassFuncIndexType = {}

        # Partially complete record of class relations:
        # {class_name -> [class_name]}
        self.class_relation_index: ClassRelationIndexType = defaultdict(list)

        # Must run last: it fills the containers created above.
        self._build_index(include_test_files)

    def _build_index(self, include_test_files: bool):
        """
        Populate all search indexes of this backend from the project sources.

        The index data is produced by `_build_python_index` (class index,
        class-method index, top-level function index, class relation index
        and the list of parsed files) and merged into this instance by
        `_update_indices`. Names can be defined more than once in a project,
        so every index maps a name to a list of locations.

        Args:
            include_test_files (bool): When True, test files are indexed as
                well, so test cases take part in retrieval.
        """
        python_indices = self._build_python_index(
            self.project_path, include_test_files
        )
        self._update_indices(*python_indices)

    def _update_indices(
        self,
        class_index: ClassIndexType,
        class_func_index: ClassFuncIndexType,
        function_index: FuncIndexType,
        class_relation_index: ClassRelationIndexType,
        parsed_files: list[str],
    ) -> None:
        self.class_index.update(class_index)
        self.class_func_index.update(class_func_index)
        self.function_index.update(function_index)
        self.class_relation_index.update(class_relation_index)
        self.parsed_files.extend(parsed_files)

    @classmethod
    @cache
    def _build_python_index(cls, project_path: str, include_test_files: bool = True) -> tuple[
        ClassIndexType,
        ClassFuncIndexType,
        FuncIndexType,
        ClassRelationIndexType,
        list[str],
    ]:
        """
        Parse the project's Python files and build all lookup indexes.

        This classmethod is wrapped in `@cache`, so a repeated call with the
        same arguments returns the very same index objects as the first call.

        Args:
            project_path (str): Path of the project to scan.
            include_test_files (bool): Forwarded to
                `search_utils.find_python_files`; when True, test files are
                part of the index.
        Returns:
            (class index, class-method index, top-level function index,
            class relation index, list of files that parsed successfully).
        """
        classes_by_name: ClassIndexType = defaultdict(list)
        methods_by_class: ClassFuncIndexType = defaultdict(lambda: defaultdict(list))
        functions_by_name: FuncIndexType = defaultdict(list)
        supers_by_class: ClassRelationIndexType = defaultdict(list)

        candidate_files = search_utils.find_python_files(
            project_path, include_test_files=include_test_files
        )
        # only the files that could actually be parsed end up in here
        parsable_files = []
        for source_file in candidate_files:
            parsed = search_utils.parse_python_file(source_file)
            if parsed is None:
                # this file could not be parsed; leave it out of every index
                continue
            parsable_files.append(source_file)
            classes, class_to_funcs, top_level_funcs, class_relation_map = parsed

            # classes defined in this file
            for class_name, first_line, last_line in classes:
                classes_by_name[class_name].append(
                    (source_file, LineRange(first_line, last_line))
                )

            # methods, grouped by the class that owns them; the per-class
            # entry is only created once a method is actually recorded
            for class_name, methods in class_to_funcs.items():
                for method_name, first_line, last_line in methods:
                    methods_by_class[class_name][method_name].append(
                        (source_file, LineRange(first_line, last_line))
                    )

            # functions defined at module level
            for func_name, first_line, last_line in top_level_funcs:
                functions_by_name[func_name].append(
                    (source_file, LineRange(first_line, last_line))
                )

            # class -> superclasses; a later class with the same name
            # replaces the entry recorded earlier
            for (class_name, _, _), super_classes in class_relation_map.items():
                supers_by_class[class_name] = super_classes

        return (
            classes_by_name,
            methods_by_class,
            functions_by_name,
            supers_by_class,
            parsable_files,
        )

    def _file_line_to_class_and_func(
        self, file_path: str, line_no: int
    ) -> tuple[str | None, str | None]:
        """
        Given a file path and a line number, return the class and function name.
        If the line is not inside a class or function, return None.
        """
        # check whether this line is inside a class
        for class_name in self.class_func_index:
            func_dict = self.class_func_index[class_name]
            for func_name, func_info in func_dict.items():
                for file_name, (start, end) in func_info:
                    if file_name == file_path and start <= line_no <= end:
                        return class_name, func_name

        # not in any class; check whether this line is inside a top-level function
        for func_name in self.function_index:
            for file_name, (start, end) in self.function_index[func_name]:
                if file_name == file_path and start <= line_no <= end:
                    return None, func_name

        # this file-line is not recorded in any of the indexes
        return None, None

    def _search_func_in_class(
        self, function_name: str, class_name: str
    ) -> list[SearchResult]:
        """
        Search for the function name in the class.
        Args:
            function_name (str): Name of the function.
            class_name (str): Name of the class.
        Returns:
            The list of code snippets searched.
        """
        result: list[SearchResult] = []
        if class_name not in self.class_func_index:
            return result
        if function_name not in self.class_func_index[class_name]:
            return result
        for fname, (start, end) in self.class_func_index[class_name][function_name]:
            func_code = search_utils.get_code_snippets(fname, start, end)
            res = SearchResult(fname, start, end, class_name, function_name, func_code)
            result.append(res)
        return result

    def _search_func_in_all_classes(self, function_name: str) -> list[SearchResult]:
        """
        Search for the function name in all classes.
        Args:
            function_name (str): Name of the function.
        Returns:
            The list of code snippets searched.
        """
        result: list[SearchResult] = []
        for class_name in self.class_index:
            res = self._search_func_in_class(function_name, class_name)
            result.extend(res)
        return result

    def _search_top_level_func(self, function_name: str) -> list[SearchResult]:
        """
        Search for top-level function name in the entire project.
        Args:
            function_name (str): Name of the function.
        Returns:
            The list of code snippets searched.
        """
        result: list[SearchResult] = []
        if function_name not in self.function_index:
            return result

        for fname, (start, end) in self.function_index[function_name]:
            func_code = search_utils.get_code_snippets(fname, start, end)
            res = SearchResult(fname, start, end, None, function_name, func_code)
            result.append(res)
        return result

    def _search_func_in_code_base(self, function_name: str) -> list[SearchResult]:
        """
        Search for this function, from both top-level and all class definitions.
        """
        # BUG: which might cause duplicate search.
        result: list[SearchResult] = []  # list of (file_name, func_code)
        # (1) search in top level
        top_level_res = self._search_top_level_func(function_name)
        class_res = self._search_func_in_all_classes(function_name)
        result.extend(class_res)
        
        # the following code is to solve the bug of duplicate search
        for ele in top_level_res:
            ele_file_path = ele.file_path
            ele_start = ele.start
            ele_end = ele.end
            flag = 0
            for res in result:
                if(ele_file_path == res.file_path and ele_start == res.start and ele_end == res.end):
                    flag = 1
                    break
            if (not flag):
                result.append(ele)
        # result.extend(top_level_res)
        return result

    def _get_candidate_matched_py_files(self, target_file_name: str):
        """
        Search for files in the project that may match target_file_name.

        Returns:
            - all matched files, in abs path.
        """
        parsed_files_lower = [f.lower() for f in self.parsed_files]
        parsed_files = zip(self.parsed_files, parsed_files_lower)
        target_lower = target_file_name.lower()

        candidates = []
        for orig_file, lower_file in parsed_files:
            if lower_file.endswith(target_lower):
                candidates.append(orig_file)
        return candidates

    ###############################
    ### Interfaces ################
    ###############################
    ## NOTE: SearchResult objects returned by search APIs are not used when
    ## communicating with the model - they are mainly for our own use cases.
    ## Only the first returned value (`tool_result`) is what gets sent to the model.

    # not search API - for writing patch
    # if we are searching for only a class when writing patch, likely we do not have enough info
    # the result can be too long, so we just show the first two
    # TODO: what to do with this method? It's not a method exposed to the agent, but maybe we also
    # want to catch exceptions from it?
    @catch_all_and_log
    def get_class_full_snippet(
        self, class_name: str
    ) -> tuple[str, list[SearchResult], bool]:
        """
        Return the full source code of the classes named `class_name`.

        At most two classes are shown, since full class bodies can be long.

        Args:
            class_name (str): Name of the class to look up.
        Returns:
            (tool_result, search results for the shown classes, success flag).
        """
        # number of classes whose full code is included in the output
        max_shown = 2

        locations = self.class_index.get(class_name)
        if not locations:
            return f"Could not find class {class_name} in the codebase.", [], False

        tool_result = f"Found {len(locations)} classes with name {class_name} in the codebase:\n\n"
        if len(locations) > max_shown:
            tool_result += f"Too many results, showing full code for {max_shown} of them:\n"

        # read source only for the classes that are actually shown
        final_search_res: list[SearchResult] = []
        for fname, (start, end) in locations[:max_shown]:
            code = search_utils.get_code_snippets(fname, start, end)
            final_search_res.append(
                SearchResult(fname, start, end, class_name, None, code)
            )

        for idx, res in enumerate(final_search_res):
            res_str = res.to_tagged_str(self.project_path)
            # the newline after the closing fence keeps consecutive results
            # apart; without it the next "- Search result" header is glued
            # onto the fence line
            tool_result += f"- Search result {idx + 1}:\n```\n{res_str}\n```\n"
        return tool_result, final_search_res, True

    # agent API
    @catch_all_and_log
    def search_class(self, class_name: str) -> tuple[str, list[SearchResult], bool]:
        """Search for a class in the codebase.

        Only the signature of the class is returned. The class signature
        includes class name, base classes, and signatures for all of its methods/properties.

        Args:
            class_name (string): Name of the class to search for.
        """
        not_found = f"Could not find class {class_name} in the codebase."
        if class_name not in self.class_index:
            return not_found, [], False

        # one result per definition site, carrying the class signature only
        search_res: list[SearchResult] = [
            SearchResult(
                path,
                first_line,
                last_line,
                class_name,
                None,
                search_utils.get_class_signature(path, class_name),
            )
            for path, (first_line, last_line) in self.class_index[class_name]
        ]
        if not search_res:
            # an indexed class name without any location; not expected
            return not_found, search_res, False

        tool_result = f"Found {len(search_res)} classes with name {class_name} in the codebase:\n\n"
        if len(search_res) <= RESULT_SHOW_LIMIT:
            # few enough hits: show every signature
            for position, hit in enumerate(search_res, start=1):
                hit_str = hit.to_tagged_str(self.project_path)
                tool_result += f"- Search result {position}:\n```\n{hit_str}\n```\n"
        else:
            # too many hits: only report which files they are in
            tool_result += "They appeared in the following files:\n"
            tool_result += SearchResult.collapse_to_file_level(
                search_res, self.project_path
            )
        return tool_result, search_res[:RESULT_SHOW_LIMIT], True

    # agent API
    @catch_all_and_log
    def search_class_in_file(
        self, class_name, file_name: str
    ) -> tuple[str, list[SearchResult], bool]:
        """Search for a class in a given file.

        Returns the actual code of the entire class definition.

        Args:
            class_name (string): Name of the class to search for.
            file_name (string): The file to search in. Must be a valid python file name.
        """
        # the requested file has to resolve to at least one parsed file
        candidate_py_abs_paths = self._get_candidate_matched_py_files(file_name)
        if not candidate_py_abs_paths:
            return f"Could not find file {file_name} in the codebase.", [], False

        # the class has to exist somewhere in the code base
        if class_name not in self.class_index:
            return f"Could not find class {class_name} in the codebase.", [], False

        # keep only the definitions located in one of the candidate files
        search_res: list[SearchResult] = [
            SearchResult(
                path,
                first_line,
                last_line,
                class_name,
                None,
                search_utils.get_code_snippets(path, first_line, last_line),
            )
            for path, (first_line, last_line) in self.class_index[class_name]
            if path in candidate_py_abs_paths
        ]
        if not search_res:
            return f"Could not find class {class_name} in file {file_name}.", search_res, False

        # success: render every match in full
        tool_output = f"Found {len(search_res)} classes with name {class_name} in file {file_name}:\n\n"
        for position, hit in enumerate(search_res, start=1):
            hit_str = hit.to_tagged_str(self.project_path)
            tool_output += f"- Search result {position}:\n```\n{hit_str}\n```\n"
        return tool_output, search_res, True

    # agent API
    @catch_all_and_log
    def search_method_in_file(
        self, method_name: str, file_name: str
    ) -> tuple[str, list[SearchResult], bool]:
        """Search for a method in a given file.

        Returns the actual code of the method.

        Args:
            method_name (string): Name of the method to search for.
            file_name (string): The file to search in. Must be a valid python file name.
        """
        # file_name may be a path relative to the project root or just a
        # short file name; both are resolved by suffix matching
        candidate_py_abs_paths = self._get_candidate_matched_py_files(file_name)
        if not candidate_py_abs_paths:
            return f"Could not find file {file_name} in the codebase.", [], False

        # look the method up everywhere first ...
        all_hits: list[SearchResult] = self._search_func_in_code_base(method_name)
        if not all_hits:
            return f"The method {method_name} does not appear in the codebase.", [], False

        # ... then keep only the hits that live in a candidate file
        filtered_res: list[SearchResult] = []
        for hit in all_hits:
            if hit.file_path in candidate_py_abs_paths:
                filtered_res.append(hit)

        if not filtered_res:
            return (
                f"There is no method with name `{method_name}` in file {file_name}.",
                [],
                False,
            )

        # a single file rarely holds many methods of one name, so the
        # output is not trimmed
        tool_output = f"Found {len(filtered_res)} methods with name `{method_name}` in file {file_name}:\n\n"
        for position, hit in enumerate(filtered_res, start=1):
            hit_str = hit.to_tagged_str(self.project_path)
            tool_output += f"- Search result {position}:\n```\n{hit_str}\n```\n"
        return tool_output, filtered_res, True

    # agent API
    @catch_all_and_log
    def search_method_in_class(
        self, method_name: str, class_name: str
    ) -> tuple[str, list[SearchResult], bool]:
        """Search for a method in a given class.

        Returns the actual code of the method.

        Args:
            method_name (string): Name of the method to search for.
            class_name (string): Consider only methods in this class.
        """
        if class_name not in self.class_index:
            tool_output = f"Could not find class {class_name} in the codebase."
            return tool_output, [], False

        # has this class, check its methods
        search_res: list[SearchResult] = self._search_func_in_class(
            method_name, class_name
        )
        if not search_res:
            tool_output = f"Could not find method {method_name} in class {class_name}."
            return tool_output, [], False

        # found some methods, prepare the result
        tool_output = f"Found {len(search_res)} methods with name {method_name} in class {class_name}:\n\n"

        # Several classes of the same name, defined in different files, can
        # each contain this method, so the output is still trimmed.
        if len(search_res) > RESULT_SHOW_LIMIT:
            tool_output += f"Too many results, showing full code for {RESULT_SHOW_LIMIT} of them, and the rest just file names:\n"
        shown = search_res[:RESULT_SHOW_LIMIT]
        for idx, res in enumerate(shown):
            res_str = res.to_tagged_str(self.project_path)
            tool_output += f"- Search result {idx + 1}:\n```\n{res_str}\n```\n"
        # results beyond the limit are only summarised by file
        if rest := search_res[RESULT_SHOW_LIMIT:]:
            tool_output += "Other results are in these files:\n"
            tool_output += SearchResult.collapse_to_file_level(rest, self.project_path)

        return tool_output, shown, True

    # agent API
    @catch_all_and_log
    def search_method(self, method_name: str) -> tuple[str, list[SearchResult], bool]:
        """Search for a method in the entire codebase.

        Returns the actual code of the method.

        Args:
            method_name (string): Name of the method to search for.
        """
        search_res: list[SearchResult] = self._search_func_in_code_base(method_name)
        if not search_res:
            return f"Could not find method {method_name} in the codebase.", [], False

        tool_output = f"Found {len(search_res)} methods with name {method_name} in the codebase:\n\n"

        if len(search_res) <= RESULT_SHOW_LIMIT:
            # few enough hits: show the code of each one
            for position, hit in enumerate(search_res, start=1):
                hit_str = hit.to_tagged_str(self.project_path)
                tool_output += f"- Search result {position}:\n```\n{hit_str}\n```\n"
        else:
            # too many hits: only report which files they are in
            tool_output += "They appeared in the following files:\n"
            tool_output += SearchResult.collapse_to_file_level(
                search_res, self.project_path
            )

        return tool_output, search_res[:RESULT_SHOW_LIMIT], True

    # agent API
    @catch_all_and_log
    @timeout_decorator.timeout(120)
    def search_code(self, code_str: str) -> tuple[str, list[SearchResult], bool]:
        """Search for a code snippet in the entire codebase.

        Returns the method that contains the code snippet, if it is found inside a method.
        Otherwise, returns the region of code surrounding it.

        Args:
            code_str (string): The code snippet to search for.
        """
        # scan every parsed file for occurrences of the snippet
        search_res: list[SearchResult] = []
        for file_path in self.parsed_files:
            occurrences: list[tuple[int, str]] = (
                search_utils.get_code_region_containing_code(file_path, code_str)
            )
            for line_no, code_region in occurrences or ():
                # resolve the class / function enclosing the matched line
                class_name, func_name = self._file_line_to_class_and_func(
                    file_path, line_no
                )
                search_res.append(
                    SearchResult(
                        file_path, line_no, line_no, class_name, func_name, code_region
                    )
                )

        if not search_res:
            return f"Could not find code {code_str} in the codebase.", [], False

        tool_output = f"Found {len(search_res)} snippets containing `{code_str}` in the codebase:\n\n"

        if len(search_res) <= RESULT_SHOW_LIMIT:
            # few enough hits: show each region
            for position, hit in enumerate(search_res, start=1):
                hit_str = hit.to_tagged_str(self.project_path)
                tool_output += f"- Search result {position}:\n```\n{hit_str}\n```\n"
        else:
            # too many hits: only report which files they are in
            tool_output += "They appeared in the following files:\n"
            tool_output += SearchResult.collapse_to_file_level(
                search_res, self.project_path
            )

        return tool_output, search_res[:RESULT_SHOW_LIMIT], True

    
    # agent API
    @catch_all_and_log
    @timeout_decorator.timeout(120)
    def search_relevant_method(self, top_num) -> tuple[str, list[SearchResult], bool]:
        """Rank the functions of the codebase by BM25 relevance to the task's target function.

        Every indexed top-level function and class method is a candidate. The
        query is whatever `extract_target_func_from_Task` returns for
        `self.task`, and the ranking is done by `search_utils.compute_bm25`.

        Args:
            top_num (int): Number of top-ranked functions to report. Values
                that `int()` can convert (e.g. numeric strings) are accepted.
        Returns:
            tuple[str, list[SearchResult], bool]: similar to other apis
        """
        top_num = int(top_num)
        candidate: list[SearchResult] = []
        # (file_path, start, end) of every candidate collected so far, so a
        # definition present in both indexes is only added once
        seen: set[tuple] = set()

        # collect all funcs from top func level
        for function_name, locations in self.function_index.items():
            for fname, (start, end) in locations:
                func_code = search_utils.get_code_snippets(fname, start, end)
                res = SearchResult(fname, start, end, None, function_name, func_code)
                candidate.append(res)
                seen.add((res.file_path, res.start, res.end))

        # collect funcs from class level
        for class_name in self.class_index:
            if class_name not in self.class_func_index:
                continue
            for function_name, locations in self.class_func_index[class_name].items():
                for fname, (start, end) in locations:
                    func_code = search_utils.get_code_snippets(fname, start, end)
                    # NOTE(review): class_name is recorded as None here, as in
                    # the top-level case - confirm this is intended.
                    res = SearchResult(fname, start, end, None, function_name, func_code)
                    key = (res.file_path, res.start, res.end)
                    if key in seen:
                        continue
                    seen.add(key)
                    candidate.append(res)

        # after collecting all the funcs in the repo, collect target function
        # sig and docstring from task
        doc_str = extract_target_func_from_Task(self.task)
        # ranked list of (SearchResult, score) pairs
        bm25_result_list = search_utils.compute_bm25(doc_str, candidate)

        tool_output = f"To help you generate an accurate and high-quality function, we have selected the top {top_num} most relevant functions from the repository based on their BM25 similarity to the target function's signature and docstring, from low to high relevance:\n\n"

        final_search_res = bm25_result_list[:top_num]
        # reverse, so the output goes from low to high relevance
        for count, (search_result, bm25_score) in enumerate(
            final_search_res[::-1], start=1
        ):
            res_str = search_result.to_tagged_str(self.project_path)
            tool_output += f"- Search result {count}: \n```\n{res_str}\n```\nbm25 score: {bm25_score}\n\n"

        # NOTE(review): the returned list holds (SearchResult, score) pairs,
        # not bare SearchResult objects as the annotation says. Kept as-is so
        # existing callers are unaffected - confirm what callers expect.
        return tool_output, final_search_res, True

    # agent API
    @catch_all_and_log
    def search_code_in_file(
        self, code_str: str, file_name: str
    ) -> tuple[str, list[SearchResult], bool]:
        """Search for a code snippet in a given file.

        Returns the entire method that contains the code snippet.

        Args:
            code_str (string): The code snippet to search for.
            file_name (string): The file to search in. Must be a valid python file name in the project.
        """
        # presumably tolerates a stray ")" left at the end of the argument
        # by upstream parsing - TODO confirm
        code_str = code_str.removesuffix(")")

        # resolve the file the same way as the other file-scoped search APIs
        candidate_py_files = self._get_candidate_matched_py_files(file_name)
        if not candidate_py_files:
            tool_output = f"Could not find file {file_name} in the codebase."
            return tool_output, [], False

        # start searching for code in the filtered files
        search_res: list[SearchResult] = []
        for file_path in candidate_py_files:
            searched_line_and_code: list[tuple[int, str]] = (
                search_utils.get_code_region_containing_code(file_path, code_str)
            )
            if not searched_line_and_code:
                continue
            for line_no, code_region in searched_line_and_code:
                # from line_no, check which function and class we are in
                class_name, func_name = self._file_line_to_class_and_func(
                    file_path, line_no
                )
                res = SearchResult(
                    file_path, line_no, line_no, class_name, func_name, code_region
                )
                search_res.append(res)

        if not search_res:
            tool_output = f"Could not find code {code_str} in file {file_name}."
            return tool_output, [], False

        # good path
        # There can be a lot of results, from multiple files.
        tool_output = f"Found {len(search_res)} snippets with code {code_str} in file {file_name}:\n\n"
        if len(search_res) > RESULT_SHOW_LIMIT:
            tool_output += "They appeared in the following methods:\n"
            tool_output += SearchResult.collapse_to_method_level(
                search_res, self.project_path
            )
        else:
            for idx, res in enumerate(search_res):
                res_str = res.to_tagged_str(self.project_path)
                tool_output += f"- Search result {idx + 1}:\n```\n{res_str}\n```\n"

        final_search_res = search_res[:RESULT_SHOW_LIMIT]
        return tool_output, final_search_res, True

    # agent API
    @catch_all_and_log
    def get_code_around_line(
        self, file_name: str, line_no_str: str, window_size_str: str
    ) -> tuple[str, list[SearchResult], bool]:
        """
        Get the region of code around line `line_no` in the file `file_name`.

        Args:
            file_name (str): The file name.
            line_no_str (str): The line number. (1-based)
            window_size_str (str): The number of lines before and after the line number.
        """
        # arguments arrive as strings
        line_no, window_size = int(line_no_str), int(window_size_str)

        # the requested file has to resolve to at least one parsed file
        candidate_py_abs_paths = self._get_candidate_matched_py_files(file_name)
        if not candidate_py_abs_paths:
            return f"Could not find file {file_name} in the codebase.", [], False

        # regions are rendered into the response for the model ...
        region_search_results: list[SearchResult] = []
        # ... while the enclosing functions are what gets returned for record
        func_search_results: list[SearchResult] = []

        for file_path in candidate_py_abs_paths:
            snippet = search_utils.get_code_region_around_line(
                file_path, line_no, window_size
            )
            if snippet is None:
                continue

            class_name, func_name = self._file_line_to_class_and_func(
                file_path, line_no
            )
            # look up the surrounding function, since instrumentation works
            # on function level
            if func_name is None:
                enclosing_funcs = []
            elif class_name is None:
                _, enclosing_funcs, _ = self.search_method(func_name)
            else:
                _, enclosing_funcs, _ = self.search_method_in_class(
                    func_name, class_name
                )
            func_search_results.extend(enclosing_funcs)

            region_search_results.append(
                SearchResult(
                    file_path,
                    line_no - window_size,
                    line_no + window_size,
                    class_name,
                    func_name,
                    snippet,
                )
            )

        if not region_search_results:
            return f"{line_no} is invalid in file {file_name}.", [], False

        tool_output = f"Found {len(region_search_results)} code snippets around line {line_no}:\n\n"
        for position, region in enumerate(region_search_results, start=1):
            region_str = region.to_tagged_str(self.project_path)
            tool_output += f"- Search result {position}:\n```\n{region_str}\n```\n"

        # NOTE: the functions, not the regions, are returned as search
        # results, since they will be instrumented later
        return tool_output, func_search_results, True

    # agent API
    @catch_all_and_log
    def search_import_in_file(self, file_name: str) -> tuple[str, list[SearchResult], bool]:
        """Search for the top-level import statements of a given file.

        Collected are the module-level ``import`` / ``from ... import``
        statements, plus every module-level ``if`` block that contains an
        import somewhere inside it. Such a block is kept whole so that the
        condition guarding the import stays visible.

        Args:
            file_name (str): must be a valid file path

        Returns:
            tuple[str, list[SearchResult], bool]: the response text, one
            SearchResult per matched file that has imports, and whether any
            import statement was found.
        """
        import_node_types = (ast.Import, ast.ImportFrom)

        def node_line_numbers(node) -> range:
            # end_lineno may be missing on some nodes; fall back to a single line
            end_lineno = getattr(node, "end_lineno", None) or node.lineno
            return range(node.lineno, end_lineno + 1)

        def extract_top_level_imports(file_content: str):
            """Return (import_lines, first_line_no, last_line_no), or None when
            the file has no top-level import."""
            included_line_nums: set[int] = set()
            for node in ast.parse(file_content).body:
                if isinstance(node, import_node_types):
                    # simple top-level import
                    included_line_nums.update(node_line_numbers(node))
                elif isinstance(node, ast.If) and any(
                    isinstance(subnode, import_node_types)
                    for subnode in ast.walk(node)
                ):
                    # top-level if/else block guarding an import: keep the
                    # whole block, condition lines included
                    included_line_nums.update(node_line_numbers(node))

            if not included_line_nums:
                return None

            # Extract lines in order, preserving format
            lines = file_content.splitlines()
            ordered_line_nums = sorted(included_line_nums)
            extracted_lines = [lines[i - 1] for i in ordered_line_nums]
            return extracted_lines, ordered_line_nums[0], ordered_line_nums[-1]

        # (1) check whether we can get the file
        candidate_py_abs_paths = self._get_candidate_matched_py_files(file_name)
        if not candidate_py_abs_paths:
            tool_output = f"Could not find file {file_name} in the codebase."
            return tool_output, [], False

        # (2) find imports within each matched file
        search_res: list[SearchResult] = []
        for file_path in candidate_py_abs_paths:
            with open(file_path, "r") as f:
                target_file_content = f.read()
            extracted = extract_top_level_imports(target_file_content)
            if extracted is None:
                # this file has no top-level imports; nothing to report for it
                continue
            top_import_list, start_line, end_line = extracted
            search_res.append(
                SearchResult(
                    file_path=file_path,
                    start=int(start_line),
                    end=int(end_line),
                    class_name=None,
                    func_name=None,
                    code="\n".join(top_import_list),
                )
            )

        if not search_res:
            tool_output = f"Could not find import statements in file {file_name}."
            return tool_output, [], False

        # good path; we have result, now just form a response
        tool_output = f"Found following top level import statements in file {file_name}:\n\n"
        for idx, res in enumerate(search_res):
            res_str = res.to_tagged_str(self.project_path)
            tool_output += f"- Search result {idx + 1}:\n```\n{res_str}\n```\n"
        return tool_output, search_res, True


    @catch_all_and_log
    def get_file_content(self, file_name: str) -> tuple[str, list[SearchResult], bool]:
        """Get actual content of the entire file.
        Mainly used for retrieving actual code snippets at selected bug locations.

        Args:
            - file_name: relative path to the file.

        Returns:
            tuple[str, list[SearchResult], bool]: the file content wrapped in
            <file>/<code> tags, a single SearchResult spanning line 1 to the
            last line of the file, and whether the file was found.
        """
        # check whether we can get the file.
        # This is a plain suffix match on the parsed paths, so `utils.py` also
        # matches e.g. `my_utils.py`.
        candidate_py_files = [f for f in self.parsed_files if f.endswith(file_name)]
        if not candidate_py_files:
            tool_output = f"Could not find file {file_name} in the codebase."
            return tool_output, [], False

        # NOTE: sometimes there can be multiple files.
        # To make the execution safe, we just take the first one
        file_path = candidate_py_files[0]
        file_content = Path(file_path).read_text()
        file_length = len(file_content.splitlines())

        search_res = [SearchResult(file_path, 1, file_length, None, None, file_content)]
        tool_output = f"<file>{file_name}</file> <code>{file_content}</code>"
        return tool_output, search_res, True

    def retrieve_class_context(
        self, class_and_files: set[tuple[str, str]]
    ) -> str | None:
        """
        Args:
            - set of classes to retrieve as additional context.
            Each element is a tuple of (class_name, file_name).
        Returns:
            - A string containing definitions of all classes.
        """
        result_prefix = (
            "As additional context, here are the complete definitions of the classes "
            "around the more specific methods.\n"
        )
        result = ""

        for class_name, file_name in class_and_files:
            kwargs = {"class_name": class_name, "file_name": file_name}
            code, _, search_ok = self.search_class_in_file(**kwargs)
            if search_ok:
                result += f"\n\n{code}\n\n"

        if result:
            # some class definitions could be retrieved
            return result_prefix + result
        else:
            return None

    def _get_inherited_methods(self, class_name: str, method_name: str):
        """
        Given a method in a class, find its inherited classes in the parent class.
        Should eventually return whatever "search_method_in_class" returns.
        """
        class_queue: list[tuple[str, int]] = list(
            map(lambda n: (n, 1), self.class_relation_index[class_name])
        )
        super_calls: list[dict[str, str]] = []
        found_at_depth = -1
        while class_queue:
            (ancestor_name, depth) = class_queue.pop(0)
            if found_at_depth != -1 and depth > found_at_depth:
                break
            functions = self.class_func_index.get(ancestor_name, dict())
            if method_name in functions:
                found_at_depth = depth
                super_calls.append(
                    {"class_name": ancestor_name, "method_name": method_name}
                )
            else:
                for great_ancestor in self.class_relation_index.get(
                    ancestor_name, list()
                ):
                    class_queue.append((great_ancestor, depth + 1))

        final_output = ""
        final_search_res: list[SearchResult] = []

        if super_calls:
            for super_call in super_calls:
                logger.debug(
                    f"Found override of {super_call['method_name']} in {super_call['class_name']}"
                )

                # output, search_res, call_ok = self.search_method_in_class(super_call)
                output, search_res, call_ok = self.search_method_in_class(super_call['method_name'], super_call['class_name'])

                if not call_ok:
                    continue

                final_output += f"As additional context, this is an overriden instance of the method {method_name} inside class {super_call['class_name']}\n\n{output}\n\n"

                final_search_res.extend(search_res)

        return final_output, final_search_res, bool(final_output)

    def get_bug_loc_snippets_new(self, bug_location_dict: dict[str, str]):
        """
        Resolve a bug location reported by the model into BugLocation objects.

        The dict may carry the keys "file", "method" and "class"; at least one
        of them must be non-empty. Lookups are tried from the most specific
        (method in class) to the least specific (whole file), and the first
        successful one is used. Returns an empty list when nothing matches.
        """
        # these are just what the model has returned us, so they may be wrong
        file_name = bug_location_dict.get("file", "")
        method_name = bug_location_dict.get("method", "")
        class_name = bug_location_dict.get("class", "")

        # (1) sometimes model can write class_name and method_name together in
        # the format Class.method
        if method_name and not class_name and "." in method_name:
            pieces = method_name.split(".")
            if len(pieces) != 2:
                logger.warning(
                    "Too many fragments. Examine the method name: {}", method_name
                )
            else:
                class_name, method_name = pieces
                logger.warning(
                    "Successfully split {} and {}", class_name, method_name
                )

        # at least one of method / class / file has to be given
        assert (
            method_name or class_name or file_name
        ), f"Invalid bug location returned from model: {bug_location_dict}"

        # (2) search for this location in the codebase.
        # NOTE: every lookup used here must return search results that hold a
        # valid unit of code (method/class/file) with correct line numbers.
        # Due to legacy reasons some search functions do not satisfy this, so
        # DO NOT add those to the table below.
        # For a method in a class only the method itself is taken as location;
        # the enclosing class and inherited methods are deliberately left out.
        lookups = (
            (
                method_name and class_name,
                lambda: self.search_method_in_class(method_name, class_name),
            ),
            (
                method_name and file_name,
                lambda: self.search_method_in_file(method_name, file_name),
            ),
            (
                class_name and file_name,
                lambda: self.search_class_in_file(class_name, file_name),
            ),
            (class_name, lambda: self.get_class_full_snippet(class_name)),
            (method_name, lambda: self.search_method(method_name)),
            (file_name, lambda: self.get_file_content(file_name)),
        )

        found = False
        search_res: list[SearchResult] = []
        for applicable, run_lookup in lookups:
            if not applicable:
                continue
            _, search_res, found = run_lookup()
            if found:
                break

        if not found:
            # cannot find any location!!
            return []

        # we have some SearchResults => turn these into BugLocations,
        # skipping any result without a complete line range
        return [
            BugLocation(res, self.project_path)
            for res in search_res
            if res.start is not None and res.end is not None
        ]

    # agent API
    @catch_all_and_log
    @timeout_decorator.timeout(120)
    def search_target_usage_example(self, example_num) -> tuple[str, list[SearchResult], bool]:
        """
        Search for usage examples of the task's target function in the codebase
        by finding functions/methods that directly call it.

        The target is taken from `self.task.function_name` and
        `self.task.class_name`.

        Args:
            example_num (int or str): Maximum number of usage examples to show
                in the formatted output.
        Returns:
            Tuple of (formatted string output, list of SearchResult, success flag).
            On success the list holds every caller that was found, not only
            the ones shown in the output.
        """
        example_num = int(example_num)

        # collect all functions. same as search_relevant_method
        candidate: list[SearchResult] = []
        # (file_path, start, end) of every candidate collected so far
        seen_locations: set[tuple] = set()

        # collect all funcs from top func level
        # function_index: {function name -> [(file_name, line_range)]}
        for function_name, locations in self.function_index.items():
            for fname, (start, end) in locations:
                func_code = search_utils.get_code_snippets(fname, start, end)
                res = SearchResult(fname, start, end, None, function_name, func_code)
                candidate.append(res)
                seen_locations.add((res.file_path, res.start, res.end))

        # collect funcs from class level, skipping any location that was
        # already collected (avoids duplicate search results)
        for class_name in self.class_index:
            class_functions = self.class_func_index.get(class_name, {})
            for function_name, locations in class_functions.items():
                for fname, (start, end) in locations:
                    func_code = search_utils.get_code_snippets(fname, start, end)
                    res = SearchResult(fname, start, end, None, function_name, func_code)
                    location = (res.file_path, res.start, res.end)
                    if location in seen_locations:
                        continue
                    seen_locations.add(location)
                    candidate.append(res)

        # after collecting all the function candidates, check through static
        # analysis which of them call the target function
        target_function_name = self.task.function_name
        target_class_name = self.task.class_name
        results: list[SearchResult] = [
            search_result
            for search_result in candidate
            if is_func_calls_target(
                search_result, target_function_name, target_class_name
            )
        ]

        if not results:
            if target_class_name:
                tool_output = f"Could not find any function that calls the target function `{target_function_name}` from class `{target_class_name}`.\n"
            else:
                tool_output = f"Could not find any function that calls the target function `{target_function_name}`.\n"
            return tool_output, [], False

        # never show more examples than we have, and never a negative amount
        example_num = max(0, min(example_num, len(results)))
        final_res = results[:example_num]
        if target_class_name:
            tool_output = f"We have selected {example_num} functions directly call the target function `{target_function_name}` in class `{target_class_name}`. You can treat these as usage examples of the target:\n"
        else:
            tool_output = f"We have selected {example_num} functions directly call the target function `{target_function_name}`. You can treat these as usage examples of the target:\n"
        for idx, res in enumerate(final_res):
            res_str = res.to_tagged_str(self.project_path)
            tool_output += f"- Search result {idx+1}:\n```\n{res_str}\n```\n\n"
        return tool_output, results, True
    
    
    @catch_all_and_log     
    def search_test_cases(self) -> tuple[str, list, bool]:
        """Select test cases of the target function and render them as text.

        Reads `config.test_in_retrieval`, `config.test_num` and
        `config.test_case_selection_mode`. The tests come from
        `self.task.test_data`, whose per-key lists are concatenated in
        iteration order. The chosen tests are stored in
        `self.task.selected_test_list`.

        Supported selection modes: "all", "random", "call_distance",
        "fail_signal" and "diff_call".

        Returns:
            tuple[str, list, bool]: the response text, the selected tests, and
            a success flag. The flag is False only when the task has no test
            cases at all; when retrieval of tests is disabled, an explanatory
            message, an empty list and True are returned.
        """
        # kept as an equality test so non-bool config values behave as before
        if config.test_in_retrieval == False:  # noqa: E712
            tool_output = "Under this task setting, you have no access to test cases of the target function.\nHowever, you can use other APIs to collect informative context.\n\n"
            self.task.selected_test_list = []
            return tool_output, self.task.selected_test_list, True

        test_num = config.test_num
        test_case_selection_mode = config.test_case_selection_mode

        def format_test_cases(test_cases: list, relavent_test_path, max_show: int = 20) -> str:
            """Render at most `max_show` tests in the layout of the configured dataset."""
            output = ""
            if config.DATASET_NAME == "RepoCod":
                for idx, res in enumerate(test_cases[:max_show]):
                    if len(res["nodeid_list"]) > 1:
                        # parametrized test: list every concrete node id
                        output += f"- Test {idx+1}:\npytest node id: `{res['base_nodeid']}`, around line: {res['lineno']}, with {len(res['nodeid_list'])} different test inputs:\n"
                        for count, test_with_input in enumerate(res["nodeid_list"], start=1):
                            output += f"    - {count}. `{test_with_input}`;\n"
                    else:
                        output += f"- Test {idx+1}:\npytest node id: `{res['base_nodeid']}`, around line: {res['lineno']};\n"

                    # a line of -1 means no direct call location is available
                    if res["direct_call_info"]["line"] != -1:
                        output += f'The target function is called in file {res["direct_call_info"]["file"]} around line {res["direct_call_info"]["line"]};\n'
                    # single quotes inside the f-string keep this valid before Python 3.12
                    output += f"    - Test Source Code:\n```\n{res['src_code']}\n```\n\n"

            if config.DATASET_NAME == "RepoEval":
                # relative test file path -> its source code and matched tests
                file2test_match_dict = {}
                for test_dict in relavent_test_path:  # here, its a list
                    rel_test_file_path = test_dict["abs_test_path"].replace(config.REPOEVAL_LOCAL_DIR, '')
                    rel_test_file_path = '/'.join(rel_test_file_path.split('/')[2:])
                    if rel_test_file_path not in file2test_match_dict:
                        file2test_match_dict[rel_test_file_path] = {
                            "pytest_test_list": [],
                            "src_code": test_dict["src_code"],
                        }

                for res in test_cases[:max_show]:
                    rel_path = res['test_path_docker'].replace(config.REPOEVAL_WORK_DIR, '')
                    test_file = rel_path.split('::')[0]
                    if test_file in file2test_match_dict:
                        file2test_match_dict[test_file]["pytest_test_list"].append(rel_path)

                for key, val in file2test_match_dict.items():
                    output += f"- Test {key}```\n{val['src_code']}\n```\n\n"

            return output

        test_data: dict = self.task.test_data
        combined_tests = []
        for test_list in test_data.values():
            combined_tests.extend(test_list)

        target_function_name = self.task.function_name
        target_class_name = self.task.class_name
        relavent_test_path = self.task.relavent_test_path

        if not combined_tests:
            if target_class_name:
                tool_output = f"Can't find related test cases for the target function `{target_function_name}` from class `{target_class_name}`. "
            else:
                tool_output = f"Can't find related test cases for the target function `{target_function_name}`. "
            tool_output += 'You may use other APIs to search for test cases.\n'
            return tool_output, [], False

        # num of tests we execute: never more than we have
        test_num = min(test_num, len(combined_tests))
        # the num of tests to show in the prompt
        show_num = test_num
        tool_output = f"Found {len(combined_tests)} test cases for target function `{target_function_name}`" + (
            f' from class `{target_class_name}` in total.\n' if target_class_name else ' in total.\n'
        )

        # NOTE: an unrecognised selection mode matches none of the branches
        # below and leaves self.task.selected_test_list as it was.
        if test_case_selection_mode == 'all':
            self.task.selected_test_list = combined_tests
            show_num = min(20, len(combined_tests))
            tool_output += f"We will provide you all the {show_num} test cases:\n"

        elif test_case_selection_mode == "random":
            self.task.selected_test_list = random.sample(combined_tests, test_num)
            tool_output += f"We will provide you {test_num} randomly selected test cases:\n"

        elif test_case_selection_mode == 'call_distance':
            self.task.selected_test_list = combined_tests[:test_num]
            tool_output += f"We will provide you top {test_num} test cases that have the shortest call distance from the test to the target function, which means prioritize tests that directly call the target.\n\nHere are {show_num} tests in sequence (begins with shortest call distance):\n"

        elif test_case_selection_mode == "fail_signal":
            all_fail_signal_tests = [
                test for test in combined_tests if test['fail_signal']
            ]
            if test_num >= len(all_fail_signal_tests):
                selected_tests = all_fail_signal_tests
            else:
                selected_tests = random.sample(all_fail_signal_tests, test_num)
            self.task.selected_test_list = selected_tests
            # report how many tests were actually selected, which can be fewer than test_num
            tool_output += f"We will provide you {len(selected_tests)} test cases that contain explicit assertions or exception checks:\n"

        elif test_case_selection_mode == "diff_call":
            seen_call_infos = set()
            selected_tests = []
            remaining_tests = []
            for test in combined_tests:
                direct_call = test.get("direct_call_info", {})
                call_key = (direct_call.get("file"), direct_call.get("line"))
                if call_key not in seen_call_infos:
                    # first test that calls the target from this location
                    seen_call_infos.add(call_key)
                    selected_tests.append(test)
                else:
                    remaining_tests.append(test)
                if len(selected_tests) >= test_num:
                    break
            if len(selected_tests) < test_num:
                # not enough distinct call sites: fill up with the skipped tests
                selected_tests.extend(remaining_tests[: test_num - len(selected_tests)])
            self.task.selected_test_list = selected_tests
            tool_output += f"We will provide you the top {test_num} test cases that have the shortest call distance to the target function, meaning we prioritize tests that are most likely to directly call the target.\nTo potentially increase test coverage, we only include test cases that invoke the target from distinct locations.\nHere are {show_num} selected test cases:\n"

        tool_output += format_test_cases(self.task.selected_test_list, relavent_test_path, show_num)
        return tool_output, self.task.selected_test_list, True
            
            
        
        
        
            
            


if __name__ == "__main__":
    # This module is meant to be imported; running it directly does nothing.
    pass