import re
from typing import Any, Optional

from agent.agent import Agent
from agent.prompt_constructor import LMParsingError, PromptConstructor
from browser_env.actions import (
    Action,
    ActionParsingError,
    ActionTypes,
    create_id_based_action,
    create_none_action,
    create_playwright_action,
)
from browser_env.utils import map_url_to_local, map_url_to_real
from llms.prompt_utils import get_message, get_messages
from llms.tokenizers import Tokenizer
from llms.types import Message
from utils.logger_utils import logger
from utils.string_utils import clean_spaces, safe_format
from utils.trajectory_view import TrajectoryView
from utils.types import ImageInput
from vwa_utils.score_logger import ScoreLogger


class ExecutorPromptConstructor(PromptConstructor):
    """Build the chat prompt for the executor agent and parse its responses.

    The prompt is assembled from the instruction template loaded by the base
    class (`self.instruction`): optional system prompt, few-shot examples,
    optional web knowledge, execution history, and the current observation.
    """

    def __init__(self, lm_config: dict[str, Any], agent_config: dict[str, Any]):
        """Cache the instruction fields that are used on every prompt build.

        Args:
            lm_config: Language-model configuration forwarded to the base class.
            agent_config: Agent configuration forwarded to the base class; the
                optional `use_web_k` flag enables the web-knowledge message.
        """
        super().__init__(lm_config, agent_config)
        self.action_splitter: str = self.instruction["meta_data"]["action_splitter"]
        self.template_cur_observation: str = self.instruction["template_cur_observation"]
        self.ins_metadata: dict[str, Any] = self.instruction["meta_data"]
        self.parsing_error_msg_key: str = self.instruction["meta_data"]["parsing_error_msg_key"]
        self.parsing_error_msg_template: str = self.instruction["parsing_error_msg_template"]
        self.use_web_k = agent_config.get("use_web_k", False)

    def construct(
        self,
        trajectory: TrajectoryView,
        objective: str,
        objective_imgs: list[ImageInput],
        meta_data: dict[str, Any],
        idxs_trajectory: Optional[list[int]] = None,
        **kwargs: Any,
    ) -> list[Message]:
        """Assemble the full list of messages to send to the model.

        Args:
            trajectory: View over the states and actions collected so far.
            objective: Text of the task objective.
            objective_imgs: Images attached to the objective.
            meta_data: Per-episode metadata. Reads `web_knowledge`,
                `action_str_history`, `critique_feedback` and
                `prev_utterance_for_feedback` when present.
            idxs_trajectory: Indices of the trajectory entries to include in
                the history. `None` (the default) is treated as an empty list.
            **kwargs: Accepted for interface compatibility; not used here.

        Returns:
            The messages in prompt order.

        Raises:
            ValueError: If the instruction's `history_type` is not one of
                `interaction_history` or `rationale_action`.

        Side effects:
            When `critique_feedback` is set in `meta_data`, it and
            `prev_utterance_for_feedback` are reset to empty strings after
            being added to the prompt.
        """
        # Avoid a shared mutable default: build a fresh list per call.
        if idxs_trajectory is None:
            idxs_trajectory = []

        messages: list[Message] = []

        # System prompt
        if sys_prompt := self.instruction.get("system_prompt", ""):
            messages.append(get_message(inputs=sys_prompt, role="system"))

        # --- examples ---
        # If Gemini, add a hint that examples are coming (cannot hint with name like 'example_user')
        if "gemini" in self.lm_config["model"]:
            messages.append(
                get_message(
                    inputs="The following are examples of observations and the corresponding responses:\n",
                    role="user",
                    name=self.lm_config.get("name_user", ""),
                )
            )

        # examples = [example1, example2, ...];
        # example = [{"text": str, "image": str, "utterance": str}, ...]
        examples: list[list[dict[str, Any]]] = self.select_examples()
        intro_img_obs_user = "IMAGES: (1) current page screenshot"  # TODO: refactor this away
        img_detail = self.lm_config.get("img_detail", "auto")
        for example in examples:
            for state in example:
                messages.append(
                    get_message(
                        inputs=[state["text"], intro_img_obs_user, state["image"]],
                        role="user",
                        name="example_user",
                        img_detail=img_detail,
                    )
                )
                messages.append(
                    get_message(inputs=state["utterance"], role="assistant", name="example_assistant")
                )

        if self.use_web_k and meta_data.get("web_knowledge"):
            messages.append(
                get_message(
                    inputs=f"""## General web knowledge:\n{meta_data["web_knowledge"]}""",
                    role="user",
                    name=self.lm_config.get("name_user", ""),
                )
            )

        # --- Execution history ---
        # Either appended as separate messages, or folded into the current
        # text input as a single string, depending on the instruction.
        prompt_prev_utterances = ""
        history_type = self.ins_metadata["history_type"]
        if history_type == "interaction_history":
            messages.extend(self.build_interaction_history(trajectory, meta_data, idxs_trajectory))
        elif history_type == "rationale_action":
            prompt_prev_utterances = self.build_rationale_action_history(trajectory, meta_data, idxs_trajectory)
        else:
            raise ValueError(f"Unknown history type: {self.ins_metadata['history_type']}")

        # -- Build current text input --
        text_input = self.build_current_text_input(trajectory, meta_data, objective, prompt_prev_utterances)
        # If the last action is invalid, add the error message to the text input
        if len(trajectory.actions) > 0 and "invalid" in trajectory.actions[-1]:
            text_input += self.instruction["env_parsing_error_msg_template"].format(
                error_msg=meta_data["action_str_history"][-1]
            )

        if meta_data.get("critique_feedback"):
            text_input += self.instruction["feedback_template"].format(
                previous_response=meta_data["prev_utterance_for_feedback"],
                feedback=meta_data["critique_feedback"],
            )
            # Clear feedback fields; TODO: change this logic
            meta_data["critique_feedback"] = ""
            meta_data["prev_utterance_for_feedback"] = ""

        messages.extend(
            self.build_intent_message(
                trajectory,
                text_input,
                objective_imgs,
                meta_data,
                objective_image_captions=self.get_image_captions(objective_imgs, meta_data),
                state_img_intros=["current webpage screenshot at state `t`"],
                add_states_idxs=[-1],
                add_state_text=False,
                add_state_img=True,
                role="user",
                name=self.lm_config.get("name_user", ""),
            )
        )
        return messages

    # TODO: dynamic examples
    def select_examples(self) -> list[list[dict[str, Any]]]:
        """Return the few-shot examples stored in the instruction."""
        return self.instruction["examples"]  # type: ignore

    def build_current_text_input(
        self, trajectory: TrajectoryView, meta_data: dict[str, Any], objective: str, prompt_prev_utterances: str = ""
    ) -> str:
        """Fill the current-observation template for the latest state.

        Args:
            trajectory: View over the trajectory; only the last state is read.
            meta_data: Per-episode metadata; `action_str_history` is read when
                the history type is not `rationale_action`.
            objective: Text of the task objective.
            prompt_prev_utterances: Pre-rendered history, used only when the
                history type is `rationale_action`.

        Returns:
            The formatted text for the current observation.
        """
        state_info = trajectory.states[-1]
        text_obs = state_info["observation"]["text"]
        url = state_info["info"]["page"].url

        if not self.use_text_observation:
            # Keep only the first line starting with "Tab" (the tab info);
            # drop the rest of the text observation.
            first_line = re.search(r"^Tab.*", text_obs, re.MULTILINE)
            text_obs = clean_spaces(first_line.group(0)) if first_line else ""

        if self.ins_metadata["history_type"] == "rationale_action":
            return safe_format(
                string_template=self.template_cur_observation,
                rationale_action_history=prompt_prev_utterances,
                text_obs=text_obs,
                url=map_url_to_real(url),
                objective=objective,
            )

        return safe_format(
            string_template=self.template_cur_observation,
            fill_with="",
            text_obs=text_obs,
            url=map_url_to_real(url),
            objective=objective,
            previous_action=clean_spaces(meta_data["action_str_history"][-1]),
        )

    def extract_action(self, response: str) -> str:
        """Return the text between the first pair of action splitters.

        The splitter is matched literally (regex metacharacters in the
        configured splitter are escaped) and case-insensitively; the enclosed
        text may span multiple lines.

        Args:
            response: Raw model response.

        Returns:
            The enclosed text with surrounding whitespace stripped.

        Raises:
            LMParsingError: If no splitter-delimited span is found.
        """
        splitter = re.escape(self.action_splitter)
        # DOTALL lets `.` cross newlines without the backtracking-heavy `(.|\n)`.
        match = re.search(rf"{splitter}(.*?){splitter}", response, re.IGNORECASE | re.DOTALL)
        if match is None:
            raise LMParsingError(f'Cannot find the action identifier "{self.action_splitter}" in "{response}"')
        return match.group(1).strip()


class ExecutorAgent(Agent):
    """prompt-based agent that emits action given the history"""

    def __init__(self, lm_config: dict[str, Any], agent_config: dict[str, Any]) -> None:
        """Store configuration and build the prompt constructor.

        Args:
            lm_config: Language-model configuration; `model` is required.
            agent_config: Agent configuration; `num_previous_state_actions`,
                `action_set_tag` and `max_model_call` are required, the
                remaining keys read here are optional.
        """
        super().__init__()
        self.lm_config = lm_config
        self.tokenizer = Tokenizer(model_name=lm_config["model"])
        self.num_previous_state_actions = agent_config["num_previous_state_actions"]
        self.action_set_tag = agent_config["action_set_tag"]
        self.max_model_call = agent_config["max_model_call"]
        self.out_utterance = agent_config.get("out_utterance", True)
        self.conversation_dir = agent_config.get("conversation_dir", None)
        self.usage_dir = agent_config.get("usage_dir", None)
        self.prompt_constructor: ExecutorPromptConstructor = ExecutorPromptConstructor(
            lm_config=lm_config, agent_config=agent_config
        )
        self.agent_config = agent_config
        self.score_logger = ScoreLogger()
        self.use_web_knowledge = agent_config.get("use_web_k", False)

    def select_states(self, trajectory: TrajectoryView, num_states: int = -1) -> list[int]:
        """Return indices of the most recent states, excluding the last one.

        At most `self.num_previous_state_actions` indices are returned.

        Args:
            trajectory: View over the trajectory whose states are indexed.
            num_states: Currently ignored; the window size always comes from
                `self.num_previous_state_actions`.

        Returns:
            Consecutive state indices ending just before the last state.
        """
        last_idx = len(trajectory.states) - 1
        window = min(self.num_previous_state_actions, last_idx)
        return list(range(last_idx - window, last_idx))

    def next_action(
        self,
        trajectory: TrajectoryView,
        intent: str,
        intent_images: list[ImageInput],
        meta_data: dict[str, Any],
        log_score: bool = True,
    ) -> Action:
        """Query the model and turn its response into an environment action.

        The model is called (with retries) until its response can be parsed
        into an action, or until the budget of `self.max_model_call` tries is
        used up, in which case a NONE action is returned instead.

        Args:
            trajectory: View over the trajectory so far.
            intent: Text of the task objective.
            intent_images: Images attached to the objective.
            meta_data: Per-episode metadata forwarded to the prompt builder.
            log_score: Whether to log scores when the action is a STOP action.

        Returns:
            The created action, with the raw model response stored under
            `raw_prediction`.
        """
        idxs_trajectory = self.select_states(trajectory)

        total_tries = self.max_model_call
        while True:
            # Call model and parse response. Obs.: This is just text parsing, done by the PromptConstructor.
            try:
                parsed_response, raw_response, num_tries = self.act_parse_retry(
                    trajectory=trajectory,
                    intent=intent,
                    intent_images=intent_images,
                    meta_data=meta_data,
                    idxs_trajectory=idxs_trajectory,
                    parser_fn=self.prompt_constructor.extract_action,
                    max_tries=total_tries,
                    error_msg_template=self.prompt_constructor.parsing_error_msg_template,
                    error_msg_key=self.prompt_constructor.parsing_error_msg_key,
                )
                self.dump_html = False  # Dump only one time to save execution time

            # Failed to parse the action with the splitters more than `max_model_call` times
            except LMParsingError as e:
                # Create a NONE action for environment feedback in next iteration
                raw_response = e.raw_response
                logger.info(e.message)
                action = create_none_action()
                action.update({"raw_prediction": raw_response, "wait_for": 0, "early_stop": e.message})  # type: ignore
                break

            # Charge the tries just spent against the remaining budget.
            total_tries = total_tries - num_tries

            # Create action. Obs.: This is environment-specific; environment will parse the action.
            try:
                action = self.create_action(parsed_response)
                action.update({"raw_prediction": raw_response})
                if action["action_type"] == ActionTypes.STOP and log_score and self.score_logger:
                    self.score_logger.log_scores_per_round(
                        TrajectoryView(trajectory.trajectory + [action]), intent, meta_data
                    )
                break

            except ActionParsingError:
                # Budget exhausted: create a NONE action for environment feedback in next iteration.
                # Otherwise loop again and re-query the model with the remaining tries.
                if total_tries <= 0:
                    action = create_none_action()
                    action.update({"raw_prediction": raw_response, "wait_for": 0})  # type: ignore
                    break

        if self.out_utterance:
            logger.info(f"\n[Executor Agent]: {raw_response}")
        return action

    def create_action(self, action_str: str) -> Action:
        """Build an action from its string form according to the action set.

        Args:
            action_str: Action text extracted from the model response.

        Returns:
            The created action, with `action_str` stored under `parsed_action`.

        Raises:
            ValueError: If `self.action_set_tag` is not a known action set.
        """
        # Both "id_accessibility_tree" and "som" use id-based actions.
        if self.action_set_tag in ("id_accessibility_tree", "som"):
            action = create_id_based_action(action_str)
        elif self.action_set_tag == "playwright":
            action = create_playwright_action(action_str)
        else:
            raise ValueError(f"Unknown action type {self.action_set_tag}")
        action["parsed_action"] = action_str  # type: ignore
        return action

    def get_action_splitter(self, agent_name: str = "executor") -> str:
        """Return the action splitter from the instruction metadata.

        `agent_name` is not used by this implementation.
        """
        return self.prompt_constructor.instruction["meta_data"]["action_splitter"]  # type:ignore

    def get_tokenizer(self, agent_name: str = "executor") -> Tokenizer:
        """Return the prompt constructor's tokenizer.

        `agent_name` is not used by this implementation.
        """
        return self.prompt_constructor.tokenizer

    def set_action_set_tag(self, tag: str) -> None:
        """Set the action set used by `create_action`."""
        self.action_set_tag = tag

    def reset(self, test_config_file: str) -> None:
        """Reset the agent by delegating once to the base class."""
        super().reset(test_config_file)
