"""
A proxy agent. Process raw response into json format.
"""

import inspect
import json
from typing import Any

from loguru import logger

from app.data_structures import MessageThread
from app.model import common
from app.post_process import ExtractStatus, is_valid_json
from app.search.search_backend import SearchBackend
from app.utils import parse_function_invocation

# System prompt sent to the model by `run`. It must be a *raw* string: the JSON
# example below relies on the literal `\"` sequences reaching the model, and a
# regular string literal would collapse them to bare quotes, turning the
# example into invalid JSON.
PROXY_PROMPT = r'''
You will be given a text input that may consist of:
What additional context is needed for a repository-level code generation tasks, and what APIs to call to retrieve the needed context.

Your task is to extract all search-related API calls proposed in the text, and return them in a structured JSON format with the key "API_calls". The value must be:
- A list of API call strings (if search API calls are present or can be inferred)
- `null` (if text does not exist or you believe no search API calls is needed)

**Format Example 1 (calls exist):**
```json
{
  "API_calls": ["search_method(\"fit\")", "search_test_cases(self)"]
}
```

**Format Example 2 (no calls, which means the text input believes the existing context is sufficient):**

```json
{
  "API_calls": null
}
```

List of Valid API Calls:
You must only include calls from the following predefined list. Their arguments are just placeholders and should be stripped as explained below.

search_method_in_class(method_name: str, class_name: str)
search_method_in_file(method_name: str, file_path: str)
search_test_cases(self)
search_method(method_name: str)
search_target_usage_example(example_num: int)
search_relevant_method(top_num: int)
search_class_in_file(class_name: str, file_name: str)
search_class(class_name: str)
search_code_in_file(code_str: str, file_path: str)
search_code(code_str: str)
get_code_around_line(file_path: str, line_number: int, window_size: int)
search_import_in_file(file_name: str)

Argument Formatting Rules:
* Ignore parameter names, keep only positional arguments in valid Python call syntax.
* Replace all `.` in file paths with `/`.
* Example corrections:
  * `search_code(code_str="exp")` → `search_code("exp")`
  * `search_method_in_file("method_name", "path.to.file")` → `search_method_in_file("method_name", "path/to/file")`

Additional Notes:
* Do NOT extract any API calls mentioned outside the List of Valid API Calls.
* Ensure every API call in your output is a syntactically correct Python expression.
* Do not miss any API calls needed in the text input. 

'''


# Names of every search API the proxy is allowed to extract. They are matched
# against the input text as plain substrings.
_SEARCH_API_NAMES = (
    "search_method_in_class",
    "search_method_in_file",
    "search_method",
    "search_class_in_file",
    "search_class",
    "search_code_in_file",
    "search_code",
    "get_code_around_line",
    "search_relevant_method",
    "search_import_in_file",
    "search_test_cases",
    "search_target_usage_example",
)


def run_with_retries(text: str, retries: int = 5) -> tuple[str | None, list[MessageThread]]:
    """
    Repeatedly ask the model to convert `text` into JSON until a valid answer is produced.

    Args:
        text: The raw text from which API calls should be extracted.
        retries: Maximum number of attempts.

    Returns:
        A pair `(json_text, msg_threads)`. `json_text` is the accepted JSON
        string, or None when every attempt failed. `msg_threads` holds the
        message thread of every attempt, in order.
    """
    msg_threads: list[MessageThread] = []
    # The input never changes between attempts, so scan it only once.
    mentions_search_api = any(name in text for name in _SEARCH_API_NAMES)

    for idx in range(1, retries + 1):
        logger.debug(
            "Trying to convert API calls into json. Try {} of {}.",
            idx,
            retries,
        )
        res_text, new_thread = run(text)
        msg_threads.append(new_thread)
        extract_status, data = is_valid_json(res_text)

        if extract_status != ExtractStatus.IS_VALID_JSON:
            logger.debug("Invalid json. Will retry.")
            continue

        # Manually guarantee that no API calls are reported when the text does
        # not mention any search API. Only a dict can carry the "API_calls"
        # key; any other JSON value is rejected by is_valid_response below.
        if (
            not mentions_search_api
            and isinstance(data, dict)
            and data.get("API_calls") is not None
        ):
            data["API_calls"] = None
            # Re-serialize so the caller receives the corrected payload rather
            # than the model's original text.
            res_text = json.dumps(data)

        # `diagnosis` is the problem description returned by is_valid_response.
        valid, diagnosis = is_valid_response(data)
        if not valid:
            logger.debug(f"{diagnosis}. Will retry.")
            continue

        logger.debug("Extracted a valid json.")
        return res_text, msg_threads
    return None, msg_threads


def run(text: str) -> tuple[str, MessageThread]:
    """
    Make a single model call that converts `text` into the JSON format
    described by PROXY_PROMPT.

    Returns the model's response text together with the message thread
    (system prompt, user text, model reply) used for the call.
    """
    thread = MessageThread()
    thread.add_system(PROXY_PROMPT)
    thread.add_user(text)

    model = common.SELECTED_MODEL
    # Only the first element of the call result (the response text) is used.
    reply, *_ = model.call(
        thread.to_msg(),
        response_format="json_object",
    )

    thread.add_model(reply, [])  # no tools
    return reply, thread


def is_valid_response(data: Any) -> tuple[bool, str]:
    """
    Check that parsed JSON has the shape expected from the proxy agent.

    Accepted shapes are exactly `{"API_calls": None}`, or a dict whose
    "API_calls" value is a non-empty list of call strings. Each call string
    must parse as a function invocation, name an attribute of SearchBackend
    that is callable, and pass as many arguments as that function declares
    (not counting `self`).

    Args:
        data: The value obtained from parsing the model's JSON response.

    Returns:
        `(True, "OK")` when the data is valid, otherwise `(False, diagnosis)`
        where `diagnosis` describes the first problem found.
    """
    if not isinstance(data, dict):
        return False, "Json is not a dict"

    api_calls = data.get("API_calls")

    if not api_calls:
        # Missing, null or empty value: only the exact null form is accepted.
        if data == {"API_calls": None}:
            return True, "OK"
        return False, "Json should only contain one key-value pair. Key is 'API_calls' and the value should be a list"

    # A truthy non-list (number, string, dict, ...) must be rejected here:
    # iterating it would either raise or silently walk characters/keys.
    if not isinstance(api_calls, list):
        return False, "The value of 'API_calls' must be a list of API call strings or null"

    for api_call in api_calls:
        if not isinstance(api_call, str):
            return False, "Every API call must be a string"
        try:
            func_name, func_args = parse_function_invocation(api_call)
        except Exception:
            return False, (
                "Every API call must be of form api_call(arg1, ..., argn). Make sure every parameter is provided!\n"
                "For example, get_code_around_line() requires you to provide 3 parameters: "
                "file_path: str, line_number: int, window_size: int. "
                "None of them should be missing. You can't leave `line_number` empty."
            )

        function = getattr(SearchBackend, func_name, None)
        # Also covers attributes that exist but are not callable.
        if not callable(function):
            return False, f"the API call '{api_call}' calls a non-existent function"

        try:
            # getfullargspec does not look through decorators, so follow the
            # `__wrapped__` chain to the underlying function first.
            arg_spec = inspect.getfullargspec(inspect.unwrap(function))
        except (TypeError, ValueError):
            return False, f"the API call '{api_call}' calls a function whose signature cannot be inspected"

        arg_names = arg_spec.args[1:]  # first parameter is self

        if len(func_args) != len(arg_names):
            return False, f"the API call '{api_call}' has wrong number of arguments"

    return True, "OK"
