"""
This agent selects the search APIs to use, and returns the selected APIs in its response in
non-json format.
"""

import re
from collections.abc import Generator

from loguru import logger

from app.config import config
from app.data_structures import MessageThread
from app.log import print_acr, print_retrieval
from app.model import common, ollama

# System message for the retrieval agent. It frames the job as implementing a
# target function, with the task text delivered inside <issue>...</issue> tags
# (the format produced by prepare_issue_prompt below).
# NOTE(review): several lines end in two trailing spaces, presumably intended
# as Markdown hard line breaks — keep this literal byte-for-byte.
SYSTEM_PROMPT = """You are an intelligent software developer assigned to a code generation task.

The task description is provided between the tags <issue> and </issue>.  
Your goal is to generate an accurate and well-structured implementation for the target function.

To do this, you should first iteratively invoke search APIs to retrieve relevant code context from the codebase.  
Analyze the retrieved context carefully to understand the target functionality, dependencies, and any useful patterns or examples that can inform your implementation.

Focus solely on generating a high-quality patch.  
You do not need to handle testing or validation — your responsibility is limited to crafting the function implementation based on the available context.
"""

# Prompt that lists every search API the model may request, followed by usage
# rules. Each API appears exactly once as "\n- name(args): description", so the
# model is never shown duplicate or conflicting descriptions of the same tool.
SELECT_PROMPT = (
    "Based on the task, you can use the following search APIs to get more context:"
    "\n- search_test_cases(): Search for test cases of the target function. Analyzing test cases can help you to refine your solution. These test cases are filtered using dynamic analysis based on pytest. The API will return the test in pytest nodeid format. Based on the pytest nodeid, you can further use other API calls to retrieve the source code of the test cases. You don't need to provide any arguments for `search_test_cases()` API."
    "\n- search_import_in_file(file_name: str): Search for top-level import statements in given file `file_name`."
    "\n- search_target_usage_example(example_num: int): Search for a given number (`example_num`) of methods that call the target function directly. This will help you to understand how the target function is actually used or tested in the codebase. If `example_num` is greater than the total number of usage examples, the API will return all of them."
    "\n- search_relevant_method(top_num: int): Search for the method that is most relevant to the target function's docstring by default. We will return the `top_num` methods with the highest BM25 score. This may give you hints about the implementation of your target function from similar ones."
    # pdb semantics: `n` executes the next line and steps OVER calls; `s` steps INTO them.
    "\n- run_pdb_cmd(cmd: str): Execute a specified debugging command (e.g., pdb) within a container terminal. You can carry out line-by-line execution such as variable inspection and stack frame traversal for fine-grained debugging, for example:\n```\nl → list source around the current line\nn → step to the next line (step over function calls)\ns → step into a function\nc → continue execution until the next breakpoint\nb 23 → set a breakpoint at line 23\np var → print value of variable var\nq → quit debugger\n```"
    "\n- search_class(class_name: str): Search for a class in the codebase. The class signature includes class name, base classes, and signatures for all of its methods/properties."
    "\n- search_class_in_file(class_name:str, file_name: str): Search for class with name `class_name` in given file `file_name`."
    "\n- search_method(method_name: str): Search for a method in the entire codebase."
    "\n- search_method_in_file(method_name: str, file_path: str): Search for method with name `method_name` in file `file_path`."
    "\n- search_method_in_class(method_name: str, class_name: str): Search for method with name `method_name` in class with name `class_name`."
    "\n- search_code(code_str: str): Search for a code snippet in the entire codebase. Only `code_str` is needed."
    "\n- search_code_in_file(code_str: str, file_path: str): Search for code snippets containing `code_str` in given `file_path`."
    "\n- get_code_around_line(file_path: str, line_number: int, window_size: int): Gets the code around the specified line_number in the file `file_path`. `window_size` is the number of lines before and after `line_number`. Please make sure to provide all 3 parameters."
    "\n\nRemember:"
    "\n\n\tYou MUST provide correct number of arguments when invoking APIs! Do not leave any necessary arguments blank."
    "\n\n\tYou can use multiple APIs in one round."
    "\n\n\tDo not call the same API with the same parameters repeatedly."
    "\n\n\tYou SHOULD NOT generate hallucinated API results. We will provide you with the searched context in the next round after you provide the needed APIs."
    "\n\nNow analyze the task and select necessary APIs to get more context. It's better to provide the APIs you need to call and their arguments in your response."
)
    

# Follow-up prompt sent once search results are back: asks the model to judge
# whether the retrieved context is relevant and sufficient, line by line.
ANALYZE_PROMPT = "\n".join(
    [
        "Let's analyze collected context first.",
        "If an API call could not find any code, you should think about why and do we need other API calls to collect more context.",
        "If an API call returns some result, you should analyze the result and think about these questions:",
        "1. Does this information help us understand the expected behavior, input/output, dependencies, or usage patterns of the target function?",
        "   - For example: Does it include code that calls the target function? Implements a similar function? Shares utility functions or constants?",
        "   - Does it show how inputs are prepared or how outputs are used?",
        "   - Does it help clarify edge cases, logic structure, or external interactions (e.g., with files, databases, APIs)?",
        "2. Is this context sufficient? If not, what kind of information is missing?",
        "What are the necessary API calls you want to use?",
        "   - For example, searching for the source code and input of the test cases;",
        "",
        "Always focus on high-signal information. Even if you've collected a lot of context, prioritize what directly helps reason about or construct the target function.",
    ]
)


# Prompt that closes each analysis round: the model either lists further API
# calls or describes how to implement the target function (never both).
ANALYZE_AND_SELECT_PROMPT = "\n".join(
    [
        "Based on your analysis, please answer the following:",
        "",
        "1. **Do we need more context?**",
        "   - If yes, list the necessary API calls that should be invoked to gather additional context. These APIs might have already been implied in the analysis above. review carefully.",
        "   - If no additional context is needed, LEAVE THIS EMPTY.",
        "",
        "2. **How should we generate the target function based on the current context?**",
        "   - Describe your strategy in detail, as if you are handing this off to another engineer.",
        "   - Cover: high-level logic, required variables or inputs, usage of existing functions, expected control flow, edge case handling, and any relevant assumptions.",
        "   - suggestion: you can provide proposed implementation of the target if you feel confident enough.",
        "   - If you still require more context to proceed, LEAVE THIS EMPTY.",
        "",
        "Remember: even if a lot of context has been collected, prioritize what's most relevant. Focus on signal over noise, and use your judgment to extract what truly helps implement the function.",
    ]
)


# TODO: move this to some util class, since other agents may need it as well
def prepare_issue_prompt(problem_stmt: str) -> str:
    """
    Sanitize a raw problem statement and wrap it in ``<issue>`` tags.

    ``<!-- ... -->`` comments (which may span several lines) are removed, each
    remaining line is stripped of surrounding whitespace, and lines that end
    up empty are dropped.

    Args:
        problem_stmt (str): The raw problem statement.
            Assumption: the problem statement is the content of a markdown file.
    Returns:
        str: The cleaned text, prefixed with ``<issue>`` and followed by a
        newline and ``</issue>``.
    """
    # DOTALL lets the non-greedy match cross newlines inside a comment.
    uncommented = re.sub(r"<!--.*?-->", "", problem_stmt, flags=re.DOTALL)
    stripped_lines = (line.strip() for line in uncommented.split("\n"))
    body = "\n".join(line for line in stripped_lines if line)
    return f"<issue>{body}\n</issue>"


def generator(
    issue_stmt: str
) -> Generator[tuple[str, MessageThread], tuple[str, bool] | None, None]:
    """
    Run the context-retrieval conversation with the selected model.

    Each round asks the model which search APIs to call and yields that answer
    to the caller. Unless the caller requests a retry, the model then analyzes
    the search results the caller sends back.

    Protocol:
        - Every ``yield`` produces ``(res_text, msg_thread)``: the model's API
          selection for this round and the conversation so far.
        - After a yield, the caller must resume with
          ``send((search_result, re_search))``. ``search_result`` is appended
          to the thread as a user message.
        - If ``re_search`` is truthy, the model is immediately asked to select
          APIs again, because the previous selection could not be consumed
          downstream.
        - Otherwise the model first analyzes ``search_result``. It is then
          prompted with ``ANALYZE_AND_SELECT_PROMPT`` before the next
          selection.
        - The loop never ends on its own. The caller stops by no longer
          resuming (or by closing) the generator.

    Args:
        issue_stmt: Raw problem statement. It is sanitized and wrapped in
            ``<issue>`` tags by ``prepare_issue_prompt``.

    Raises:
        TypeError: If, after a yield, the generator is resumed without a value
            (e.g. ``next()`` instead of ``send(...)``).
    """
    msg_thread = MessageThread()
    msg_thread.add_system(SYSTEM_PROMPT)
    msg_thread.add_user(prepare_issue_prompt(issue_stmt))

    msg_thread.add_user(SELECT_PROMPT)
    print_acr(SELECT_PROMPT, "context retrieval initial prompt")

    # TODO: decide what, if anything, to print to the console for each retrieval round.

    while True:
        # Step 1: ask the model which search APIs to call.
        logger.debug("<Agent search> Selecting APIs to call.")
        res_text, *_ = common.SELECTED_MODEL.call(msg_thread.to_msg())
        msg_thread.add_model(res_text)
        print_retrieval(res_text, "Model response (API selection)")

        # The caller runs the selected searches (backend AST search tool) and
        # sends back (search_result, re_search).
        generator_input = yield res_text, msg_thread
        if generator_input is None:
            # Explicit check rather than `assert`, which is stripped under `python -O`.
            raise TypeError(
                "context retrieval generator must be resumed with "
                "send((search_result, re_search)), not next()"
            )
        search_result, re_search = generator_input

        if re_search:
            # The selected search APIs could not be used; feed back the error
            # text and let the model select again.
            logger.debug(
                "<Agent search> Downstream could not consume our last response. Will retry."
            )
            msg_thread.add_user(search_result)
            continue

        # Step 2: the results are back, so ask the model to analyze them.
        logger.debug("<Agent search> Analyzing search results.")
        msg_thread.add_user(search_result)
        msg_thread.add_user(ANALYZE_PROMPT)
        print_acr(ANALYZE_PROMPT, "context retrieval analyze prompt")

        res_text, *_ = common.SELECTED_MODEL.call(msg_thread.to_msg())
        msg_thread.add_model(res_text)
        print_retrieval(res_text, "Model response (context analysis)")

        # Step 3: ask whether more context is needed; the model's answer is
        # produced by the selection call at the top of the next iteration.
        analyze_and_select_prompt = ANALYZE_AND_SELECT_PROMPT
        if isinstance(common.SELECTED_MODEL, ollama.OllamaModel):
            # llama models tend to always output search APIs and buggy locations.
            analyze_and_select_prompt += "\n\nNOTE: If you already feel confident to generate the code, do not make any search API calls."

        msg_thread.add_user(analyze_and_select_prompt)
        print_acr(
            analyze_and_select_prompt, "context retrieval analyze and select prompt"
        )

    