import inspect
import json
from collections.abc import Mapping
from os.path import join as pjoin
from pathlib import Path

from loguru import logger

from app.config import config
from app.agents import agent_proxy
from app.debugger import agent_search_in_debugger
from app.data_structures import BugLocation, MessageThread, SearchResult
from app.log import print_acr, print_banner
from app.search.search_backend import SearchBackend
from app.search.search_utils import *
from app.task.task import Task
from app.task.raw_tasks import RawLocalTask
from app.utils import parse_function_invocation


class Deubgger_SearchManager:
    """Drive the iterative context-retrieval conversation of the debugging stage.

    Each round, the debugger search agent proposes search API calls, the agent
    proxy extracts them as JSON, and they are executed against a
    ``SearchBackend``. The collected results are fed back to the agent for the
    next round.
    """

    def __init__(self, project_path: str, output_dir: str, task: Task, prefix_thread: MessageThread):
        """
        Args:
            project_path: project root handed to the ``SearchBackend``.
            output_dir: base output directory; search artifacts go into its
                ``search`` sub-directory, which is created if missing.
            task: the task under investigation, forwarded to the backend.
            prefix_thread: conversation used to start the debugger search agent.
        """
        # output dir for writing search-related things
        self.output_dir = pjoin(output_dir, "search")
        Path(self.output_dir).mkdir(parents=True, exist_ok=True)

        # record the search APIs being used, one layer per retrieval round
        self.tool_call_layers: list[list[Mapping]] = []
        self.task = task
        self.backend: SearchBackend = SearchBackend(project_path, self.task, config.test_in_refine)
        self.prefix_thread = prefix_thread

    def search_iterative(
        self
    ) -> MessageThread:
        """
        Main entry point of the search manager in the debugging stage.

        Runs at most ``config.conv_round_limit_in_debugger`` retrieval rounds.
        The search conversation and the agent-proxy messages of every round are
        saved as JSON files under ``self.output_dir``.

        Returns:
            The message thread that contains the search conversation. It is
            returned early when the agent proxy yields ``"API_calls": null``
            (retrieval finished), otherwise after the round limit is used up.
        """
        search_api_generator = agent_search_in_debugger.generator(
            self.prefix_thread
        )
        # Input sent to the generator: (search_result_msg, re_search).
        # - search_result_msg: text sent back to the agent (backend results or a hint)
        # - re_search: whether the agent should select APIs again (True) or
        #   proceed to analysis (False)
        # The very first send must be None to start the generator.
        generator_input = None
        search_msg_thread: MessageThread | None = None  # for typing

        for round_no in range(config.conv_round_limit_in_debugger):
            self.start_new_tool_call_layer()
            print_banner(f"DEBUGGER RETRIEVAL ROUND {round_no}")
            # invoke agent search to choose search APIs; the response holds the
            # agent's reasoning together with the API calls it wants to make
            agent_search_response, search_msg_thread = search_api_generator.send(
                generator_input
            )
            conversation_file = Path(self.output_dir, f"Debugger_search_round_{round_no}.json")
            # save current state before starting a new round
            search_msg_thread.save_to_file(conversation_file)

            # extract JSON API calls from the raw response via the proxy agent
            selected_apis, proxy_threads = agent_proxy.run_with_retries(
                agent_search_response
            )
            logger.debug("Agent proxy return the following json: {}", selected_apis)
            # example of selected_apis (a JSON string):
            #   {"API_calls": ["search_method_relevance('provided docstring')"]}

            proxy_msg_log = Path(self.output_dir, f"agent_proxy_{round_no}.json")
            proxy_messages = [thread.to_msg() for thread in proxy_threads]
            proxy_msg_log.write_text(json.dumps(proxy_messages, indent=4))

            if selected_apis is None:
                # agent search response could not be propagated to backend;
                # ask it to retry
                logger.debug(
                    "Could not extract API calls from agent search response, asking search agent to re-generate response."
                )
                search_result_msg = "The search API calls seem not valid. Please check the arguments you give carefully and try again."
                generator_input = (search_result_msg, True)
                continue

            # there are valid search APIs - parse them
            selected_apis_json: dict = json.loads(selected_apis)
            json_api_calls = selected_apis_json.get("API_calls", [])

            if json_api_calls is None:
                # an explicit null marks the end of retrieval; next is patch generation
                return search_msg_thread

            formatted = []
            if json_api_calls:
                formatted.append("API calls:")
                formatted.extend(f"\n- `{call}`" for call in json_api_calls)
            print_acr("\n".join(formatted), "Agent-selected API calls")

            if not json_api_calls:
                # Without fresh input the generator would be re-sent the previous
                # round's message (or None, which it cannot unpack), so prompt the
                # agent explicitly to either search or finish.
                logger.debug(
                    "Agent selected no API calls, asking search agent to select APIs again."
                )
                search_result_msg = "No search API calls were provided. Please give valid search API calls, or finish the retrieval if enough context has been collected."
                generator_input = (search_result_msg, True)
                continue

            # still need context search - send backend results and go to next round
            collated_parts: list[str] = []
            for api_call in json_api_calls:
                result_str = self._run_api_call(api_call)
                collated_parts.append(f"Result of {api_call}:\n\n{result_str}\n\n")
            collated_search_res_str = "".join(collated_parts)

            print_acr(collated_search_res_str, f"context retrieval round {round_no}")
            logger.debug(
                "Obtained search results from API invocation. Going into next retrieval round."
            )
            generator_input = (collated_search_res_str, False)

        # used up all the rounds without the agent finishing retrieval
        logger.info("Too many rounds. Try writing patch anyway.")
        assert search_msg_thread is not None
        return search_msg_thread

    def _run_api_call(self, api_call: str) -> str:
        """Execute one textual API call against the backend and record it.

        Args:
            api_call: call expression produced by the agent, e.g.
                ``"search_class('Foo')"``.

        Returns:
            The backend's result string. If the call names something that is
            not a public callable of the backend, or passes the wrong number of
            arguments, an explanatory message for the agent is returned instead
            and the call is recorded with ``call_ok=False``.
        """
        func_name, func_args = parse_function_invocation(api_call)

        function = getattr(self.backend, func_name, None)
        # The call text comes from an LLM: only allow public backend callables,
        # never private or dunder attributes.
        if func_name.startswith("_") or not callable(function):
            self.add_tool_call_to_curr_layer(func_name, {}, False)
            return f"Unknown search API `{func_name}` in API call: {api_call}. Please use only the provided search APIs."

        # TODO: there are currently duplicated code here and in agent_proxy.
        # Unwrap decorators (functools.wraps chains) to recover the real parameters.
        arg_spec = inspect.getfullargspec(inspect.unwrap(function))
        arg_names = arg_spec.args[1:]  # first parameter is self

        if len(func_args) != len(arg_names):
            self.add_tool_call_to_curr_layer(func_name, {}, False)
            return (
                f"Number of argument is wrong in API call: {api_call}. "
                f"Expected {len(arg_names)} ({', '.join(arg_names)}), got {len(func_args)}."
            )

        kwargs = dict(zip(arg_names, func_args))
        result_str, _, call_ok = function(**kwargs)  # the search results themselves are not used here
        # record the api call made and the call status
        self.add_tool_call_to_curr_layer(func_name, kwargs, call_ok)
        return result_str

    def start_new_tool_call_layer(self):
        """Open a new (empty) layer in which this round's tool calls are recorded."""
        self.tool_call_layers.append([])

    def add_tool_call_to_curr_layer(
        self, func_name: str, args: dict[str, str], result: bool
    ):
        """Record one tool call in the most recently opened layer.

        Args:
            func_name: name of the backend method that was called.
            args: keyword arguments the call was made with.
            result: the call status (``call_ok``) reported for the call.
        """
        self.tool_call_layers[-1].append(
            {
                "func_name": func_name,
                "arguments": args,
                "call_ok": result,
            }
        )

    def dump_tool_call_layers_to_file(self):
        """Dump the layers of tool calls to ``tool_call_layers.json`` in the output dir."""
        tool_call_file = Path(self.output_dir, "tool_call_layers.json")
        tool_call_file.write_text(json.dumps(self.tool_call_layers, indent=4))


# if __name__ == "__main__":
#     manager = SearchManager("/tmp", "/tmp/one")
#     func_name = "search_code"
#     func_args = {"code_str": "_separable"}

#     # func_name = "search_class"
#     # func_args = {"class_name": "ABC"}

#     function = getattr(manager.backend, func_name)

#     while "__wrapped__" in function.__dict__:
#         function = function.__wrapped__
#     arg_spec = inspect.getfullargspec(function)

#     print(arg_spec)
#     arg_names = arg_spec.args[1:]  # first parameter is self
#     kwargs = func_args

#     orig_func = getattr(manager.backend, func_name)
#     search_result, _, call_ok = orig_func(**kwargs)
