import re
from typing import Any

from agent.agent import Agent
from agent.prompt_constructor import PromptConstructor
from browser_env.actions import create_refine_action
from llms.llm_utils import call_llm
from llms.prompt_utils import get_message
from utils.trajectory_view import TrajectoryView

# TODO: integrate joonhyuk changes


class ImageRefinerPromptConstructor(PromptConstructor):
    """Build the chat prompt that asks the model to shortlist webpage elements.

    The prompt consists of a fixed system instruction, the current screenshot
    (plus any objective images) interleaved with short captions, and a final
    user text message carrying the objective and, optionally, the text
    observation.
    """

    def __init__(self, lm_config: dict, agent_config: dict):
        super().__init__(lm_config, agent_config)
        # Required config entry; a missing key raises KeyError.
        self.max_refine: int = agent_config["max_refine"]

    def construct(
        self,
        trajectory: TrajectoryView,
        objective: str,
        objective_imgs: list,
        meta_data: dict[str, Any],
    ) -> list[list[dict[str, Any]]]:
        """Assemble the message list sent to the model for the latest state.

        Args:
            trajectory: Trajectory whose last state supplies the screenshot
                (``observation["image"]``) and, when text observations are
                enabled, ``observation["text"]``.
            objective: Natural-language task description.
            objective_imgs: Optional images that are part of the objective;
                may be ``None`` or empty.
            meta_data: Must contain ``"intent_images_captions"`` when
                ``objective_imgs`` is non-empty. That mapping is looked up with
                ``hash(img.tobytes())`` to fetch an optional caption per image.

        Returns:
            The accumulated messages: system prompt, interleaved image/caption
            messages, then the user text message.
        """
        # The original file called this helper without importing it, which
        # raised NameError on every call.
        # NOTE(review): assumed to live next to `get_message` in
        # llms.prompt_utils -- confirm the module path.
        from llms.prompt_utils import get_interleaved_img_txt_msg

        # Rule 2 of the prompt below caps the answer at 5-10 elements; an
        # earlier draft allowed an unlimited number of elements.
        sys_prompt = """You are an autonomous intelligent agent tasked with selecting webpage elements to complete web-based tasks. You will work with a team of other agents to accomplish the task.
        You will think step by step and carefully analyze webpage observations to identify webpage elements that are worth further investigation to help progress towards the objective.
        \n
        ## Here's the information you'll have:
        ### The objective: This is the task your team is trying to accomplish.
        ### The webpage screenshots: These are screenshots of the webpage, with each interactable element assigned a unique numerical ID. Each bounding box and its respective ID shares the same color.
        \n
        ## To be successful, it is crucial to follow the following rules:
        1. Analyze all webpage elements thoroughly, treating each element as potentially significant until proven otherwise.
        2. You should return at least 5 elements and up to 10 elements.
        3. Do not decide a task is impossible or complete. If in doubt, return the most promising candidate elements so the team can further reason over them.
        4. Do not decide on next steps. Your task in ONLY to select ALL potentially useful elements in the current page.
        5. VERY IMPORTANT: Your abilities to understand webpage observations are limited, so strive to select as many elements as possible to avoid leaving out promising ones.
        6. Don't generate anything after the list of candidate elements.
        \n
        ## Provide your answer as following:
        ### REASONING: Reason step by step to come up with a comprehensive list of candidate elements that are worth further investigating to progress towards the objective. Connect the dots between the objective, the current text observations, the webpage screenshots and come up with all potentially useful elements.
        ### CANDIDATE ELEMENTS: A list of IDs with concise descriptions of how each can help advance progress towards the objective.
        [id1]; description
        ...
        [idn]; description
        """
        messages = []

        messages.extend(
            get_message(
                inputs=sys_prompt,
                provider=self.lm_config["provider"],
                role="system",
            )
        )

        # -- examples ---

        # -- interaction history ---

        # -- current image input ---
        # `imgs` and `prefix_imgs` are kept in lockstep: one caption per image.
        latest_observation = trajectory.states[-1]["observation"]
        imgs = [latest_observation["image"]]
        prefix_imgs = ["IMAGES: (1) current webpage screenshot at state `t`"]

        # Images defining objective, if any
        if objective_imgs is not None and len(objective_imgs) > 0:
            imgs.extend(objective_imgs)
            known_captions = meta_data["intent_images_captions"]
            for i, img in enumerate(objective_imgs):
                img_caption = known_captions.get(hash(img.tobytes()), "")
                if img_caption:
                    img_caption = f" description: {img_caption}"
                # The screenshot is image (1), so objective images start at (2).
                prefix_imgs.append(f"({i + 2}) objective image {i + 1}{img_caption}")

        # -- current text input ---
        template = """{text_obs}Objective: {objective}"""
        if self.use_text_observation:
            text_obs = latest_observation["text"]
            text_obs = f"\nTEXT OBSERVATION: {text_obs}"
        else:
            text_obs = ""
        text_input = template.format(objective=objective, text_obs=text_obs)

        # -- current input ---
        # Pass the per-image captions built above. The previous code passed the
        # raw `meta_data["intent_images_captions"]` lookup table, which left
        # `prefix_imgs` unused and raised KeyError when the key was absent.
        # NOTE(review): assumes `img_captions` takes one caption string per
        # image -- confirm against get_interleaved_img_txt_msg.
        image_msgs = get_interleaved_img_txt_msg(
            images=imgs,
            img_captions=prefix_imgs,
            role="user",
            name="",
            text_first=True,
            img_detail=self.img_detail,
        )
        # The original assigned `msg` but extended with the undefined `msgs`.
        messages.extend(image_msgs)
        # NOTE(review): the system message above is added with `extend` and a
        # `provider` argument, this one with `append` and no provider; the call
        # is kept as written -- confirm which shape get_message returns.
        messages.append(get_message(inputs=text_input, role="user", name=""))
        return messages


class ImageRefiner(Agent):
    """Agent that prunes the annotated screenshot down to promising elements.

    For the latest state it asks the model which element IDs are worth
    keeping, issues a refine action to the environment with those IDs, and
    writes the resulting observation back into the trajectory. The final
    image is cached per page URL so it can be reused after an invalid action.
    """

    # Matches "[<digits>]; <description>" anywhere in a line.
    _CANDIDATE_LINE_RE = re.compile(r".*?\[(\d+)\];\s*(.+)")

    def __init__(self, lm_config, agent_config) -> None:
        super().__init__()
        self.lm_config = lm_config
        self.prompt_constructor = ImageRefinerPromptConstructor(lm_config=lm_config, agent_config=agent_config)
        # Number of model-query/refine rounds per call to next_action.
        self.max_refine = agent_config["max_refine"]
        self.prev_traj_len = 0
        # Page URL -> last refined image produced for that page.
        self.refined_imgs: dict[str, Any] = {}

    def select_states(self, trajectory: TrajectoryView, num_states: int = 1) -> list:
        """Delegate state selection to the base class."""
        return super().select_states(trajectory, num_states)

    def next_action(self, trajectory: TrajectoryView, intent: str, intent_images, meta_data: dict[str, Any]) -> None:
        """Refine the image of the latest state in ``trajectory`` in place.

        Args:
            trajectory: Trajectory whose last state is updated.
            intent: Task objective passed to the prompt.
            intent_images: Optional objective images passed to the prompt.
            meta_data: Must contain ``"env"`` (stepped with the refine action)
                and ``"task_id"`` (used as the LLM call id).
        """
        # If low-level action failed, use cached refined image if any
        if len(trajectory.actions) > 0 and "invalid" in trajectory.actions[-1]:
            url = trajectory.states[-1]["info"]["page"].url
            if url in self.refined_imgs:
                trajectory.states[-1]["observation"]["image"] = self.refined_imgs[url]  # type: ignore
                return
            # No cached image for this URL: fall through and refine from scratch.

        # Else, refine image
        for _ in range(self.max_refine):
            # Select element Ids to keep
            model_response = self.get_model_response(trajectory, intent, intent_images, meta_data)
            element_ids, _ = self.parse_model_response(model_response)

            # Refine image
            # NOTE(review): when parsing yields no IDs this sends an empty
            # `ids_keep` -- confirm how the environment treats that.
            refine_action = create_refine_action(ids_keep=element_ids, ids_remove=[])
            obs, _, _, _, info = meta_data["env"].step(refine_action)

            # TODO: Modify a copy of obs instead of the original obs in the trajectory
            trajectory.states[-1]["observation"].update(obs)  # type: ignore
            trajectory.states[-1]["info"].update(info)  # type: ignore

        # TODO: cache refined images; study if saving in textified Tree Format or like this is good enough
        url = trajectory.states[-1]["info"]["page"].url
        self.refined_imgs[url] = trajectory.states[-1]["observation"]["image"]

    def get_model_response(self, trajectory, intent, intent_images, meta_data: dict[str, Any]) -> str:
        """Build the prompt for the latest state and return the model's reply."""
        prompt = self.prompt_constructor.construct(
            trajectory=trajectory,
            objective=intent,
            objective_imgs=intent_images,
            meta_data=meta_data,
        )
        return call_llm(
            gen_kwargs=self.lm_config,
            prompt=prompt,
            call_id=meta_data["task_id"],
            safe_genargs_convert=False,
        )

    def parse_model_response(self, response: str) -> tuple[list[str], list[str]]:
        """Extract ``[id]; description`` entries from the model reply.

        Only text after the last ``CANDIDATE ELEMENTS:`` marker is scanned.
        Lines that do not match the expected pattern are ignored.

        Returns:
            Two parallel lists: element IDs (as digit strings) and their
            stripped descriptions. Both are empty when the marker is missing.
        """
        marker = "CANDIDATE ELEMENTS:"
        if marker not in response:
            return [], []
        candidates_section = response.split(marker)[-1].strip()

        element_ids: list[str] = []
        element_utterances: list[str] = []
        for raw_line in candidates_section.split("\n"):
            line = raw_line.strip()
            if not line:
                continue
            match = self._CANDIDATE_LINE_RE.search(line)
            if match is None:
                continue
            element_ids.append(match.group(1))
            element_utterances.append(match.group(2).strip())

        return element_ids, element_utterances

    def reset(self, test_config_file: str) -> None:
        """Restore the per-task state to the values set in ``__init__``.

        ``test_config_file`` is accepted for interface compatibility and is
        not read here.
        """
        self.refined_imgs = {}
        # The original repeated the line above; the other piece of state set
        # in __init__ is prev_traj_len, so it is reset here as well.
        self.prev_traj_len = 0
