import re
from typing import Any

from agent.agent import Agent
from agent.prompt_constructor import LMParsingError, PromptConstructor
from llms.prompt_utils import get_message, get_messages
from llms.types import Message
from utils.logger_utils import logger
from utils.string_utils import clean_spaces
from utils.trajectory_view import TrajectoryView
from utils.types import ImageInput
from vwa_utils.score_logger import hard_verify


class CriticPromptConstructor(PromptConstructor):
    """Build critic prompts and parse the critic model's responses.

    The prompt flow is selected by ``agent_config["mode"]``:

    * A mode containing ``"one_pass"`` yields a single request holding the
      objective, the execution history and the evaluation prompt.
    * Any other mode is handled as two-pass. The first request
      (``second_pass=False``) asks for knowledge; the second request
      (``second_pass=True``) injects the first-pass answer, adds the
      execution history and asks for the evaluation.
    """

    mode: str

    def __init__(self, lm_config: dict[str, Any], agent_config: dict[str, Any]):
        """Read the critic-specific settings from the config and instruction.

        Args:
            lm_config: Language-model configuration, forwarded to the base class.
            agent_config: Agent configuration. ``"mode"`` is required;
                ``"expert"`` is optional and defaults to ``False``.
        """
        super().__init__(lm_config, agent_config)
        self.mode = agent_config["mode"]

        instruction_meta = self.instruction["meta_data"]
        # Key in `meta_data` under which a parsing-error message is handed to `construct`.
        self.parsing_error_msg_key: str = instruction_meta["parsing_error_msg_key"]
        # Key under which the parsed evaluation score is stored by `parse_second_pass`.
        self.eval_key: str = instruction_meta["eval_key"]
        # Score labels recognised by `parse_evaluation`.
        self.eval_scores: list[str] = instruction_meta["eval_scores"]
        self.parsing_error_msg_template: str = self.instruction["parsing_error_msg_template"]
        # Selects the expert variants of the first-pass prompts in `construct`.
        self.expert: bool = agent_config.get("expert", False)

    def build_execution_history(
        self, trajectory: TrajectoryView, meta_data: dict[str, Any], idxs_trajectory: list[int]
    ) -> list[Message]:
        """Return the execution-history messages for the selected trajectory indices.

        Currently a thin wrapper around ``build_interaction_history``.
        """
        # TODO: add other types of execution history prompt
        # Without utterances, parsed actions instead of utterances, etc.
        return self.build_interaction_history(trajectory, meta_data, idxs_trajectory)

    def _pop_parsing_error(self, meta_data: dict[str, Any]) -> str:
        """Return the pending parsing-error message and clear it from ``meta_data``.

        Returns an empty string, and leaves ``meta_data`` untouched, when no
        non-empty message is stored under ``self.parsing_error_msg_key``.
        """
        pending_msg = meta_data.get(self.parsing_error_msg_key)
        if not pending_msg:
            return ""
        # Consume the message so it is not repeated in a later prompt.
        meta_data[self.parsing_error_msg_key] = ""
        return pending_msg

    def construct(
        self,
        trajectory: TrajectoryView,
        objective: str,
        objective_imgs: list[ImageInput],
        meta_data: dict[str, Any],
        idxs_trajectory: list[int] | None = None,
        first_pass_utterance: str = "",
        second_pass: bool = False,
    ) -> list[Message]:
        """Assemble the message list for one critic request.

        Args:
            trajectory: Trajectory under evaluation.
            objective: Task objective, substituted into ``objective_template``.
            objective_imgs: Images that accompany the objective.
            meta_data: Shared metadata. A non-empty value under
                ``self.parsing_error_msg_key`` is appended to the request
                prompt and then reset to ``""``.
            idxs_trajectory: Trajectory indices forwarded to
                ``build_execution_history``. ``None`` is treated as an empty list.
            first_pass_utterance: Knowledge returned by the first pass; only
                used when ``second_pass`` is true.
            second_pass: Selects the second (evaluation) request of a
                two-pass mode. Ignored for one-pass modes.

        Returns:
            The result of ``get_messages`` for the assembled inputs.
        """
        # Avoid a shared mutable default argument.
        if idxs_trajectory is None:
            idxs_trajectory = []

        user_name = self.lm_config.get("name_user", "")
        is_one_pass = "one_pass" in self.mode
        messages: list[Message | str | ImageInput] = []

        # -----------------------------------------------------------------------
        # System prompt, evaluation prompt
        # -----------------------------------------------------------------------
        sys_prompt: str = self.instruction[f"system_prompt_{self.mode}"]
        eval_prompt: str = self.instruction[f"eval_prompt_{self.mode}"]

        # -----------------------------------------------------------------------
        # Examples
        # -----------------------------------------------------------------------
        # TODO

        # -----------------------------------------------------------------------
        # Build intent input
        # -----------------------------------------------------------------------
        # Text intro for the objective.
        text_input = self.instruction["objective_template"].format(objective=objective)

        # Intent images and their textual prefixes.
        state_img_intros: list[str] = []
        add_state_idxs: list[int] = []

        # First pass of a two-pass mode: the initial page is needed to define the intent.
        if "two_pass" in self.mode and not second_pass:
            add_state_idxs.append(0)
            if self.use_img_observation:
                state_img_intros.append("description: Initial webpage screenshot")

            if self.use_text_observation:
                text_input += f"\nTEXT OBSERVATION: {trajectory.states[0]['observation']['text']}"

        messages.extend(
            self.build_intent_message(
                trajectory,
                text_input,
                objective_imgs,
                meta_data,
                objective_image_captions=self.get_image_captions(objective_imgs, meta_data),
                state_img_intros=state_img_intros,
                add_states_idxs=add_state_idxs,
                add_state_text=False,
                add_state_img=True,
                role="user",
                name=user_name,
            )
        )

        # -----------------------------------------------------------------------
        # Execution history
        # -----------------------------------------------------------------------
        if is_one_pass or second_pass:
            messages.extend(self.build_execution_history(trajectory, meta_data, idxs_trajectory))

        # -----------------------------------------------------------------------
        # Critique request
        # -----------------------------------------------------------------------
        if is_one_pass:
            # One-pass mode: evaluation prompt plus any parsing-error feedback.
            messages.append(
                get_message(
                    inputs=eval_prompt + self._pop_parsing_error(meta_data),
                    role="user",
                    name=user_name,
                )
            )
        elif not second_pass:
            # Two-pass mode, first pass: prompt the critic for knowledge.
            if self.expert:
                k_retrieval_prompt = self.instruction["k_retrieval_expert"]
                sys_prompt = self.instruction["system_prompt_k_expert"]
            else:
                k_retrieval_prompt = self.instruction["knowledge_retrieval_prompt"]
            # The error message is consumed here, as in the other branches, so a
            # first-pass parsing error does not leak into the second-pass prompt.
            messages.append(k_retrieval_prompt + self._pop_parsing_error(meta_data))
        else:
            # Two-pass mode, second pass: inject the knowledge, then ask for the evaluation.
            if first_pass_utterance:
                messages.append(
                    self.instruction["knowledge_injection_prompt"].format(
                        knowledge_retrieval_response=first_pass_utterance,
                    )
                )
            messages.append(eval_prompt + self._pop_parsing_error(meta_data))

        return get_messages(
            inputs=messages,
            sys_prompt=sys_prompt,
            role="user",
            name=user_name,
        )

    def parse_first_pass(self, response: str) -> str:
        """Extract the first-pass answer that follows the first-pass splitter.

        Only the first entry of ``splitters_first_pass`` is used. If it is
        empty, the whole response is returned after ``clean_spaces``.

        Raises:
            LMParsingError: If the splitter does not occur in ``response``.
        """
        splitter = self.instruction["meta_data"]["splitters_first_pass"][0]
        if not splitter:
            return clean_spaces(response)

        # NOTE(review): without re.DOTALL, "." stops at a newline, so only the
        # remainder of the splitter's own line is captured. Confirm that
        # single-line capture is intended for multi-line answers.
        match = re.search(rf"{re.escape(splitter)}(.*)", response, re.IGNORECASE)
        if match is None:
            raise LMParsingError(f"Cannot find {splitter} in {response}")
        return clean_spaces(match.group(1))

    def parse_second_pass(self, response: str) -> dict[str, Any]:
        """Split an evaluation response into its sections.

        For every splitter in ``splitters_second_pass`` the text after the
        splitter is captured up to the next newline that is directly followed
        by any splitter, or up to the end of the response. A splitter that
        contains ``self.eval_key`` is stored under ``self.eval_key`` as the
        score returned by ``parse_evaluation``; every other section is stored
        under the splitter text with one trailing ``":"`` removed.

        Raises:
            LMParsingError: If a key listed in ``required_splitters`` is
                missing from the parsed result, or if ``parse_evaluation``
                cannot find a score.
        """
        instruction_meta = self.instruction["meta_data"]
        splitters: list[str] = instruction_meta["splitters_second_pass"]
        splitters_group: str = "|".join(map(re.escape, splitters))
        parsed_data: dict[str, Any] = {}

        for splitter in splitters:
            pattern = rf"{re.escape(splitter)}(.*?)(?=\n(?:{splitters_group})|$)"
            match = re.search(pattern, response, re.IGNORECASE | re.DOTALL)
            if match is None:
                continue

            content = match.group(1).strip()
            if self.eval_key in splitter:
                # Map the section text to one of the configured score labels.
                parsed_data[self.eval_key] = self.parse_evaluation(content)
            else:
                # Generic section: key is the splitter without its trailing ":".
                parsed_data[re.sub(r":$", "", splitter)] = content

        required_splitters = instruction_meta["required_splitters"]
        if any(req_splitter not in parsed_data for req_splitter in required_splitters):
            raise LMParsingError(f"Cannot find all required splitters in {response}")

        return parsed_data

    def parse_evaluation(self, content: str) -> str:
        """Return the first configured score label found in ``content``, upper-cased.

        Matching is case-insensitive. Longer labels are tried first so that a
        label which is a prefix of another one cannot shadow it.

        Raises:
            LMParsingError: If none of ``self.eval_scores`` occurs in ``content``.
        """
        # Alternation picks the first alternative that matches at a position,
        # so order the labels from longest to shortest.
        ordered_scores = sorted(self.eval_scores, key=len, reverse=True)
        status_pattern = "(" + "|".join(map(re.escape, ordered_scores)) + ")"

        match = re.search(status_pattern, content, re.IGNORECASE)
        if match is None:
            raise LMParsingError(f"Cannot determine evaluation score in: {content}")
        return match.group(1).upper()  # Normalize to uppercase


class CriticAgent(Agent):
    """Agent that asks a critic model to evaluate a trajectory.

    In a mode containing ``"one_pass"`` a single evaluation request is made.
    In a mode containing ``"two_pass"`` knowledge is retrieved first (once,
    until ``reset`` is called) and then injected into the evaluation request.
    """

    def __init__(self, lm_config: dict[str, Any], agent_config: dict[str, Any]) -> None:
        """Store the configuration and build the prompt constructor.

        Args:
            lm_config: Language-model configuration.
            agent_config: Agent configuration. Required keys read here:
                ``"num_previous_state_actions"``, ``"mode"``,
                ``"out_utterance"``, ``"max_model_call"`` and
                ``"max_critique_executor_loop"``. ``"conversation_dir"`` and
                ``"usage_dir"`` are optional and default to ``None``.
        """
        super().__init__()
        self.lm_config = lm_config
        self.prompt_constructor: CriticPromptConstructor = CriticPromptConstructor(
            lm_config=lm_config, agent_config=agent_config
        )
        self.agent_config = agent_config
        self.num_previous_state_actions = agent_config["num_previous_state_actions"]
        self.mode = agent_config["mode"]
        self.out_utterance = agent_config["out_utterance"]
        self.success_keyword = self.prompt_constructor.instruction["meta_data"]["success_keyword"]
        self.max_model_call = agent_config["max_model_call"]
        self.max_critique_executor_loop = agent_config["max_critique_executor_loop"]
        # Cached first-pass answer; an empty string means "not retrieved yet".
        self.retrieved_knowledge = ""
        self.conversation_dir = agent_config.get("conversation_dir", None)
        self.usage_dir = agent_config.get("usage_dir", None)
        self.score_per_round: dict[str, Any] = {}

    def select_states(self, trajectory: TrajectoryView, num_states: int = -1) -> list[int]:
        """Return the indices of the most recent states of ``trajectory``.

        Args:
            trajectory: Trajectory whose ``states`` are indexed.
            num_states: Number of trailing states to select. A negative value
                (the default) falls back to ``self.num_previous_state_actions``.

        Returns:
            Consecutive indices ending at the last state, capped at the
            number of available states.
        """
        total_states = len(trajectory.states)
        # Honor an explicit request instead of silently discarding the argument.
        requested = num_states if num_states >= 0 else self.num_previous_state_actions
        selected = min(requested, total_states)
        return list(range(total_states - selected, total_states))

    def _call_critic(
        self,
        trajectory: TrajectoryView,
        intent: str,
        intent_images: list[ImageInput],
        meta_data: dict[str, Any],
        idxs_trajectory: list[int],
        parser_fn: Any,
        max_tries: int,
        **extra_kwargs: Any,
    ) -> tuple[Any, str, int]:
        """Call ``act_parse_retry`` with the parsing-error settings shared by all requests."""
        instruction = self.prompt_constructor.instruction
        return self.act_parse_retry(
            trajectory=trajectory,
            intent=intent,
            intent_images=intent_images,
            meta_data=meta_data,
            idxs_trajectory=idxs_trajectory,
            parser_fn=parser_fn,
            max_tries=max_tries,
            error_msg_template=instruction["parsing_error_msg_template"],
            error_msg_key=instruction["meta_data"]["parsing_error_msg_key"],
            **extra_kwargs,
        )

    def retrieve_web_knowledge(
        self,
        trajectory: TrajectoryView,
        intent: str,
        intent_images: list[ImageInput],
        meta_data: dict[str, Any],
    ) -> tuple[str, str, int]:
        """Run the first pass and cache its parsed answer.

        Returns:
            The three values returned by ``act_parse_retry``, unpacked here as
            ``(parsed_response, raw_response, num_tries)``. The parsed response
            is also stored in ``self.retrieved_knowledge``.
        """
        idxs_trajectory = self.select_states(trajectory)
        logger.info("Critique Agent, retrieving knowledge...")
        parsed_response_first_pass, raw_response, num_tries = self._call_critic(
            trajectory,
            intent,
            intent_images,
            meta_data,
            idxs_trajectory,
            parser_fn=self.prompt_constructor.parse_first_pass,
            max_tries=self.max_model_call,
            first_pass_utterance="",
            second_pass=False,
        )
        logger.info(f"Retrieved Knowledge: {parsed_response_first_pass}")
        self.retrieved_knowledge = parsed_response_first_pass

        return parsed_response_first_pass, raw_response, num_tries

    def next_action(
        self,
        trajectory: TrajectoryView,
        intent: str,
        intent_images: list[ImageInput],
        meta_data: dict[str, Any],
    ) -> tuple[Any, str]:
        """Evaluate the trajectory and return the critic's verdict.

        Returns:
            ``(parsed_response, raw_response)`` of the evaluation request.

        Raises:
            ValueError: If ``self.mode`` contains neither ``"one_pass"`` nor
                ``"two_pass"``.
        """
        idxs_trajectory = self.select_states(trajectory)

        if "one_pass" in self.mode:
            logger.info("Critique Agent, evaluating...")
            parsed_response, raw_response, _ = self._call_critic(
                trajectory,
                intent,
                intent_images,
                meta_data,
                idxs_trajectory,
                parser_fn=self.prompt_constructor.parse_second_pass,
                max_tries=self.max_model_call,
            )
            self.dump_html = False  # Dump only one time to save execution time
            if self.out_utterance:
                logger.info(f"\n[Critique Agent]: {raw_response}")
            return parsed_response, raw_response

        if "two_pass" in self.mode:
            # NOTE: workaround to retrieve first pass response just once for experiments with critique at STOP action only.
            # If extending to critique other actions besides STOP, should allow multiple knowledge retrievals.
            num_tries = 0
            if not self.retrieved_knowledge:
                parsed_response_first_pass, _, num_tries = self.retrieve_web_knowledge(
                    trajectory, intent, intent_images, meta_data
                )
                trajectory.states[-1]["retrieved_knowledge"] = parsed_response_first_pass

            logger.info("Critique Agent, evaluating...")
            # The second pass gets what is left of the call budget after the first pass.
            # NOTE(review): this can reach 0 when the first pass used every try;
            # confirm how act_parse_retry treats a non-positive max_tries.
            parsed_response, raw_response, _ = self._call_critic(
                trajectory,
                intent,
                intent_images,
                meta_data,
                idxs_trajectory,
                parser_fn=self.prompt_constructor.parse_second_pass,
                max_tries=self.max_model_call - num_tries,
                first_pass_utterance=self.retrieved_knowledge,
                second_pass=True,
            )

            if self.out_utterance:
                logger.info(f"\n[Critique Agent, second pass]: {raw_response}")

            return parsed_response, raw_response

        raise ValueError(f"Unknown critique mode {self.mode}")

    def get_feedback(self, parsed_response: dict[str, Any]) -> str:
        """Return the entry of ``parsed_response`` stored under the instruction's ``feedback_key``."""
        return parsed_response[self.prompt_constructor.instruction["meta_data"]["feedback_key"]]  # type: ignore

    def get_eval_score(self, parsed_response: dict[str, Any]) -> str:
        """Return the entry of ``parsed_response`` stored under the instruction's ``eval_key``."""
        return parsed_response[self.prompt_constructor.instruction["meta_data"]["eval_key"]]  # type: ignore

    def is_success(self, parsed_response: dict[str, Any]) -> bool:
        """Return whether the evaluation score equals the configured success keyword."""
        return self.get_eval_score(parsed_response) == self.success_keyword  # type: ignore

    def hard_critique(self, trajectory: TrajectoryView, meta_data: dict[str, Any], intent: str) -> int:
        """Return the result of ``hard_verify`` for the given trajectory."""
        return hard_verify(trajectory, meta_data, intent)  # type:ignore

    def reset(self, test_config_file: str) -> None:
        """Reset the base agent and drop the cached first-pass knowledge."""
        super().reset(test_config_file)
        self.retrieved_knowledge = ""
