"""
An agent, which is only responsible for the write_patch tool call.
"""

from collections import defaultdict
from collections.abc import Generator
from copy import deepcopy
from os.path import join as pjoin
from pathlib import Path
from tempfile import NamedTemporaryFile
from typing import TypeAlias
from app.config import config
from loguru import logger

from app.agents import agent_common
from app.agents.agent_common import InvalidLLMResponse
from app.data_structures import BugLocation, MessageThread
from app.log import print_acr, print_code_generation
from app.model import common
from app.post_process import (
    ExtractStatus,
    convert_response_to_diff,
    extract_diff_one_instance,
    record_extract_status,
)
from app.search.search_manage import SearchManager
from app.task.task import Task
from app.agents.patch_utils import fix_function_body_indentation

SYSTEM_PROMPT = """You are a software developer working on a code generation task.
The task contains a description marked between <issue> and </issue>.
Other developers have already located the target function, collected and analyzed the code context related to the task for you.
Your task is to generate the function body of the target function based on their analysis.
REMEMBER:
- Make sure the generated code ONLY contains the code (function body) within the target function, no extra.
- Your code should preserve the program functionality as much as possible.
- Make sure your generated code has THE RIGHT INDENTION!!!
"""


USER_PROMPT_INIT = """Now, generate the **function body** of the target function based on the provided analysis and any available relections from the test feedback.

First, explain your reasoning. Then, generate the actual code.

Please follow these instructions carefully:
-  Only generate the **function body** — do NOT include the function signature, docstring, or any extra content. The output will be directly inserted into the target function during testing. you can generate comments withthin the code, if it helps you to generate better solutions
-  Do NOT include line numbers at the beginning of each line.
-  Ensure your code is properly indented to match the indentation level of the target.
-  Pay close attention to prior analysis or test feedback that could guide your implementation.

Format your response like this:

<reason>
...your explanation of how you derived the implementation...
</reason>
<code>
...your generated function body (with correct indentation)...
</code>

Make sure:
- The explanation in `<reason>` clearly connects context to implementation choices.
- The code in `<code>` is **ready for direct insertion** and properly formatted.
"""


# Identifier of a registered applicable patch. PatchAgent uses the string form
# of the request index that produced the patch (str(self._request_idx)).
PatchHandle: TypeAlias = str


class PatchAgent:
    """Generate the body of a target function with the LLM and turn it into a patch.

    Each applicable patch is registered under a ``PatchHandle``: the string form
    of the request index that produced it. Feedback can later be attached to a
    registered patch with ``add_feedback``.
    """

    EMPTY_PATCH_HANDLE = "EMPTY"

    def __init__(
        self,
        task: Task,
        search_manager: SearchManager,
        issue_stmt: str,
        context_thread: MessageThread,
        bug_locs: list[BugLocation],
        task_dir: str,
    ) -> None:
        self.task = task
        self.search_manager = search_manager
        self.issue_stmt = issue_stmt
        self.context_thread = context_thread  # the search conversation history thread
        self.bug_locs: list[BugLocation] = bug_locs
        self.task_dir = task_dir

        # Index of the latest generation request. It names the per-request output
        # files and becomes the handle of an applicable patch (-1: none made yet).
        self._request_idx: int = -1
        self._responses: dict[PatchHandle, str] = {}
        self._diffs: dict[PatchHandle, str] = {}
        self._feedbacks: dict[PatchHandle, list[str]] = defaultdict(list)
        # Handles of applicable patches, in registration order.
        self._history: list[PatchHandle] = []

    def write_applicable_code_without_feedback(
        self, entire_thread, retries: int = 3
    ) -> tuple[PatchHandle, str, str, MessageThread]:
        """Generate an applicable patch without consulting previous patches.

        Returns:
            (patch_handle, diff_content, raw_response, entire_thread)

        Raises:
            InvalidLLMResponse: if no applicable patch is produced in ``retries`` attempts.
        """
        return self._write_applicable_code(entire_thread, max_feedbacks=0, retries=retries)

    def write_applicable_code_with_feedback(
        self, entire_thread, max_feedbacks: int = 1, retries: int = 3
    ) -> tuple[PatchHandle, str, str, MessageThread]:
        """Generate an applicable patch, considering up to ``max_feedbacks`` previous patches.

        The only difference from the "without feedback" variant is the value of
        ``max_feedbacks``; a negative value means "all previous patches".

        Returns:
            (patch_handle, diff_content, raw_response, entire_thread)

        Raises:
            InvalidLLMResponse: if no applicable patch is produced in ``retries`` attempts.
        """
        return self._write_applicable_code(
            entire_thread, max_feedbacks=max_feedbacks, retries=retries
        )

    def add_feedback(self, handle: PatchHandle, feedback: str) -> None:
        """Attach a piece of feedback to a previously registered patch.

        Raises:
            ValueError: if ``handle`` does not identify a registered patch.
        """
        if handle not in self._diffs:
            raise ValueError(f"patch {handle} does not exist")
        self._feedbacks[handle].append(feedback)

    def _write_applicable_code(
        self, entire_thread, max_feedbacks: int, retries: int
    ) -> tuple[PatchHandle, str, str, MessageThread]:
        """Request code up to ``retries`` times until an applicable patch is produced.

        For every attempt the raw response and the conversation are saved under
        ``task_dir`` (``patch_raw_<idx>.md`` / ``conv_patch_<idx>.json``).

        Returns:
            (patch_handle, diff_content, raw_response, thread)

        Raises:
            InvalidLLMResponse: if none of the attempts yields an applicable patch.
        """
        # A negative max_feedbacks means "consider every previous patch".
        if max_feedbacks < 0:
            max_feedbacks = len(self._history)
        num_feedbacks = min(max_feedbacks, len(self._history))
        # Guard the zero case explicitly: ``lst[-0:]`` is the WHOLE list, not [].
        history_handles = self._history[-num_feedbacks:] if num_feedbacks > 0 else []

        for _ in range(retries):
            applicable, response, diff_content, thread = self._write_code(
                entire_thread, history_handles
            )
            self._request_idx += 1
            print_code_generation(response)
            Path(self.task_dir, f"patch_raw_{self._request_idx}.md").write_text(
                response, encoding="utf-8"
            )
            thread.save_to_file(
                Path(self.task_dir, f"conv_patch_{self._request_idx}.json")
            )
            msg = "Patch is applicable" if applicable else "Patch is not applicable"
            print_acr(msg)
            if applicable:
                print_acr(f"```diff\n{diff_content}\n```", "Extracted patch")
                handle = self._register_applicable_patch(response, diff_content)
                return handle, diff_content, response, thread

        raise InvalidLLMResponse(
            f"Failed to write an applicable patch in {retries} attempts"
        )

    def _write_code(
        self,
        entire_thread: MessageThread,
        history_handles: list[PatchHandle] | None = None,
    ) -> tuple[bool, str, str, MessageThread]:
        """Ask the model once for the target function body and convert it to a diff.

        ``entire_thread`` is mutated in place: the generation prompt and the
        model's indentation-fixed response are appended to it.

        NOTE: feedback attached to ``history_handles`` is not currently added to
        the thread; the handles only affect the ``is_first_try`` log line and
        whether the prompt is echoed with ``print_acr``.

        Returns:
            (applicable, response, diff_content, entire_thread)
        """
        history_handles = history_handles or []

        is_first_try = not any(handle in self._feedbacks for handle in history_handles)
        logger.debug(f"<agent write code> is_first_try: {is_first_try}")

        entire_thread.add_user(USER_PROMPT_INIT)
        if not history_handles:
            print_acr(USER_PROMPT_INIT)

        logger.debug(
            f"thread before generation in PatchAgent._write_code:\n{entire_thread}"
        )

        patch_resp, *_ = common.SELECTED_MODEL.call(entire_thread.to_msg())
        # Normalize the indentation of the generated body before it is recorded
        # and converted, since indentation mistakes are a common model failure.
        patch_resp = fix_function_body_indentation(self.task, patch_resp)
        entire_thread.add_model(patch_resp)

        extract_status, msg, diff_content = convert_response_to_diff(
            patch_resp, self.task_dir, self.bug_locs
        )
        applicable = extract_status == ExtractStatus.APPLICABLE_PATCH
        if not applicable:
            # Report and return False so the caller's retry loop can run,
            # instead of terminating the whole process.
            logger.warning(
                f"Generated code is not applicable (status: {extract_status}): {msg}"
            )

        record_extract_status(self.task_dir, extract_status)
        return applicable, patch_resp, diff_content, entire_thread

    def _construct_init_thread(self) -> MessageThread:
        """
        Construct the initial patch gen conv thread, based on whether bug location is available.

        The search conversation thread is copied in both cases. Without a
        location, its system prompt is replaced with ``SYSTEM_PROMPT``.
        """
        thread = self.context_thread.copy()
        if self.bug_locs:
            # Keep the whole search trajectory.
            return self._construct_code_context_prompt(thread)

        logger.warning(
            "in _construct_init_thread, the code generation location is empty; "
            "falling back to the search conversation history"
        )
        return agent_common.replace_system_prompt(thread, SYSTEM_PROMPT)

    def _construct_code_context_prompt(self, thread: MessageThread) -> MessageThread:
        """Return ``thread`` unchanged.

        The prompt listing the target locations
        (``BugLocation.multiple_locs_to_str_for_model``) is currently disabled.
        """
        return thread

    def _register_applicable_patch(
        self, response: str, diff_content: str
    ) -> PatchHandle:
        """Record an applicable patch under the current request index and return its handle."""
        handle = str(self._request_idx)

        # Each request index may be registered at most once.
        assert handle not in self._responses
        assert handle not in self._feedbacks
        assert handle not in self._diffs
        assert handle not in self._history

        self._responses[handle] = response
        self._diffs[handle] = diff_content
        self._history.append(handle)

        return handle
