import json
import re
from typing import Any, List

from agent.to_json import python_to_json_prompt_by_id
from browser_env.utils import map_url_to_local, map_url_to_real
from llms.prompt_utils import get_interleaved_img_txt_msg, get_message, get_messages
from llms.tokenizers import Tokenizer
from llms.types import Message
from utils.image_utils import any_to_pil
from utils.string_utils import clean_spaces, safe_format
from utils.trajectory_view import TrajectoryView
from utils.types import ImageInput

# Directory in which load_instruction() looks for "<prompt_id>.json".
# The path is relative, so it is resolved against the current working directory.
BASE_PROMPT_JSON_PATH = "agent/prompts/jsons"


class LMParsingError(Exception):
    """Exception carrying a parsing error message and the raw LM response.

    Attributes are stored on the instance:
      * ``message`` -- the error description (also the sole exception argument).
      * ``raw_response`` -- the raw response text; empty string when not supplied.
    """

    def __init__(self, message: str, raw_response: str = "") -> None:
        """Store ``message`` and ``raw_response`` and initialise the base exception."""
        super().__init__(message)
        self.raw_response = raw_response
        self.message = message


def load_instruction(prompt_id: str, path_raw_prompts: str = "./agent/prompts/raw") -> dict[str, Any]:
    """Load the JSON instruction for ``prompt_id``.

    ``python_to_json_prompt_by_id`` is invoked first with the raw-prompt directory
    (presumably it (re)generates the JSON file from the raw prompt -- confirm against
    ``agent.to_json``). The file ``<BASE_PROMPT_JSON_PATH>/<prompt_id>.json`` is then
    read and parsed.

    Args:
        prompt_id: Identifier of the prompt; also the JSON file's base name.
        path_raw_prompts: Directory handed to ``python_to_json_prompt_by_id``.

    Returns:
        The parsed JSON content, or an empty dict when the JSON file does not exist.
    """
    python_to_json_prompt_by_id(path_raw_prompts=path_raw_prompts, prompt_id=prompt_id)
    instruction_path = f"{BASE_PROMPT_JSON_PATH}/{prompt_id}.json"
    try:
        # JSON text is UTF-8; do not depend on the platform's default encoding.
        with open(instruction_path, "r", encoding="utf-8") as f:
            return json.load(f)  # type: ignore
    except FileNotFoundError:
        # Best-effort by design: callers receive an empty instruction.
        return {}


class PromptConstructor(object):
    """Base class for assembling LM prompt messages from an agent trajectory.

    The constructor loads the instruction named by ``agent_config["prompt"]`` and
    caches the observation flags found in its ``meta_data`` section. Subclasses
    implement :meth:`construct`; the other methods are building blocks for the
    interaction history and for the intent message.
    """

    def __init__(self, lm_config: dict[str, Any], agent_config: dict[str, Any]):
        """Load the instruction and cache prompt settings.

        Args:
            lm_config: LM settings; ``text_first`` and ``img_detail`` are read here,
                ``name_user`` / ``name_assistant`` are read when messages are built.
            agent_config: Must contain ``"prompt"``, the id of the instruction to load.

        Raises:
            KeyError: If ``agent_config`` has no ``"prompt"`` entry, or the loaded
                instruction lacks ``meta_data`` or its required observation flags.
        """
        self.lm_config: dict[str, Any] = lm_config
        self.text_first: bool = self.lm_config.get("text_first", True)
        self.instruction: dict[str, Any] = load_instruction(agent_config["prompt"])
        if "meta_data" not in self.instruction:
            # load_instruction() yields {} when the JSON file is missing; surface that
            # with an explicit message instead of a bare KeyError('meta_data').
            raise KeyError(
                f"Instruction for prompt '{agent_config['prompt']}' has no 'meta_data' section "
                "(was the prompt JSON file found?)"
            )
        prompt_meta = self.instruction["meta_data"]
        self.use_text_observation: bool = prompt_meta["use_text_observation"]
        self.use_img_observation: bool = prompt_meta["use_img_observation"]
        self.parsing_error_msg_key: str | None = prompt_meta.get("parsing_error_msg_key", None)
        self.img_detail: str = self.lm_config.get("img_detail", "auto")

    def construct(
        self,
        trajectory: TrajectoryView,
        objective: str,
        objective_imgs: list[ImageInput],
        meta_data: dict[str, Any],
        idxs_trajectory: list[int] = [],
    ) -> list[Message]:
        """Build the full prompt; must be provided by subclasses.

        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError("Subclasses must implement the construct method")

    def build_interaction_history(
        self,
        trajectory: TrajectoryView,
        meta_data: dict[str, Any],
        idxs_history: list[int] = [],
    ) -> list[Message]:
        """Build alternating user/assistant messages for past trajectory steps.

        Steps are visited in ascending index order; steps whose action contains the
        key ``"invalid"`` are dropped. Each remaining step contributes one user
        message (observation) followed by one assistant message (utterance/action).

        Args:
            trajectory: Provides ``actions`` and ``states`` indexable by step index.
            meta_data: Must contain ``"action_str_history"`` when environment-parsed
                actions are used; entry ``idx + 1`` is read for step ``idx``.
            idxs_history: Step indices to include (not mutated).

        Returns:
            The list of messages; empty when there is no valid step.
        """
        messages: list[Message] = []
        if not idxs_history:
            return messages

        # Pre-filter valid indices from the sorted history in one pass.
        valid_indices = [idx for idx in sorted(idxs_history) if "invalid" not in trajectory.actions[idx]]
        if not valid_indices:
            return messages

        # If there is an intro execution history, add it
        if self.instruction.get("intro_execution_history"):
            messages.append(
                get_message(
                    inputs=self.instruction["intro_execution_history"],
                    role="user",
                    name=self.lm_config.get("name_user", ""),
                )
            )

        # Loop-invariant settings.
        prompt_meta = self.instruction["meta_data"]
        keep_last_utterance = prompt_meta.get("last_u")
        keep_all_utterances = prompt_meta.get("use_assistant_utterance")
        name_user = self.lm_config.get("name_user", "")
        name_assistant = self.lm_config.get("name_assistant", "")
        img_detail = self.lm_config.get("img_detail", "auto")

        n_valid = len(valid_indices)
        for pos, idx in enumerate(valid_indices):
            # "Last step" is decided by position in the filtered history. Comparing the
            # trajectory index itself against the list length is wrong as soon as the
            # history does not start at 0 or an invalid step was filtered out.
            is_last = pos == n_valid - 1
            # Counts down so that the most recent step is formatted with t=1.
            t = n_valid - pos

            action = trajectory.actions[idx]
            state = trajectory.states[idx]
            text_obs = state["observation"]["text"] if self.use_text_observation else ""
            img_obs = state["observation"]["image"] if self.use_img_observation else ""

            intro_txt_obs = ""
            if self.use_text_observation and "intro_txt_obs_history" in self.instruction:
                intro_txt_obs = self.instruction["intro_txt_obs_history"].format(t=t)

            intro_img_obs = ""
            if self.use_img_observation and "intro_img_obs_history" in self.instruction:
                intro_img_obs = self.instruction["intro_img_obs_history"].format(t=t)

            utterance = ""
            no_prediction_flag = False

            # Add assistant responses: for every step, or only for the last one ("last_u").
            if keep_all_utterances or (keep_last_utterance and is_last):
                utterance = clean_spaces(action["raw_prediction"])
                no_prediction_flag = not utterance

            if prompt_meta.get("use_low_level_actions"):
                utterance += clean_spaces(action["extracted_action"])

            # Add action strings parsed by the environment. OR condition: if no prediction, this is added as fallback
            elif prompt_meta.get("use_low_level_actions_env_parsed") or no_prediction_flag:
                if not (keep_last_utterance and is_last):
                    # Use action_str_history. Obs: this is environment specific to VWA
                    u = clean_spaces(meta_data["action_str_history"][idx + 1])
                    u = u.replace("\n", "")
                    utterance += f"\n{u}"

            # FIXME: remove this; some addition from get trace made necessary to strip from this "Executor:"
            utterance = re.sub(r"^Executor:\s*", "", utterance)
            user_msg = get_message(
                inputs=[intro_txt_obs, text_obs, intro_img_obs, img_obs],
                role="user",
                name=name_user,
                img_detail=img_detail,
            )
            assistant_msg = get_message(
                inputs=utterance,
                role="assistant",
                name=name_assistant,
                img_detail=img_detail,
            )

            messages.extend([user_msg, assistant_msg])
        return messages

    def build_rationale_action_history(
        self,
        trajectory: TrajectoryView,
        meta_data: dict[str, Any],
        idxs_history: list[int] = [],
    ) -> str:
        """Render past utterances and parsed actions as one text block.

        Each valid step is formatted with the ``intro_txt_obs_history`` template of
        the instruction, receiving ``t`` (1 for the last listed step, increasing
        towards earlier ones), ``utterance`` and ``parsed_action``. Steps whose
        action contains the key ``"invalid"`` are skipped. Entries are joined with
        newlines in the order given by ``idxs_history``.

        Returns:
            The joined text, or the string ``"None"`` when there is nothing to show.
        """
        # If no history yet, "None"
        if not idxs_history:
            return "None"

        # Else, build prompt with previous actions and utterances (walked newest-first
        # so that t can be assigned incrementally).
        prompt_parts: list[str] = []
        t = 0
        for idx in reversed(idxs_history):
            action = trajectory.actions[idx]
            # If invalid action, skip (action,state) pair
            if "invalid" in action:
                continue
            t += 1

            utterance = clean_spaces(action["raw_prediction"])
            parsed_action = clean_spaces(meta_data["action_str_history"][idx + 1])

            prompt_parts.append(
                safe_format(
                    self.instruction["intro_txt_obs_history"],
                    t=t,
                    utterance=utterance,
                    parsed_action=parsed_action,
                )
            )
        if not prompt_parts:
            return "None"

        # Restore the caller's ordering.
        return "\n".join(reversed(prompt_parts))

    def get_image_captions(
        self,
        images: list[ImageInput],
        meta_data: dict[str, Any],
        key: str = "intent_images_captions",
    ) -> list[str]:
        """Look up stored captions for ``images``.

        Each image is converted with ``any_to_pil`` and looked up in
        ``meta_data[key]`` under ``hash(img.tobytes())``.

        Returns:
            One caption per image, in order; ``""`` for images without an entry.
        """
        img_captions: list[str] = []
        for img in images:
            pil_img = any_to_pil(img)
            img_captions.append(meta_data[key].get(hash(pil_img.tobytes()), ""))
        return img_captions

    # TODO: make this more general; refine function; modularize away the textual intros to instruction
    def build_intent_message(
        self,
        trajectory: TrajectoryView,
        text_intent_input: str,
        objective_imgs: list[ImageInput],
        meta_data: dict[str, Any],
        objective_image_captions: list[str] = [],
        add_states_idxs: list[int] = [],
        add_state_text: bool = False,
        add_state_img: bool = False,
        state_img_intros: list[str] = [],
        state_text_intros: list[str] = [],
        role: str = "user",
        name: str = "",
    ) -> list[Message]:
        """Build the message holding the intent text, optional states and objective images.

        Args:
            trajectory: Provides ``states`` indexable by the entries of ``add_states_idxs``.
            text_intent_input: Intent text; placed after any state text observations.
            objective_imgs: Objective images appended after any state screenshots.
            meta_data: If ``self.parsing_error_msg_key`` is set and present here with a
                truthy value, that value is appended to the text.
            objective_image_captions: Optional description per objective image; missing
                or empty entries yield a caption without description.
            add_states_idxs: Indices of states to include.
            add_state_text: Include state text (only if text observations are enabled).
            add_state_img: Include state screenshots (only if image observations are enabled).
            state_img_intros: Caption text per included state image; missing entries
                are treated as empty.
            state_text_intros: Text placed before each state's text observation; missing
                entries are skipped.
            role: Role passed to the message builder.
            name: Name passed to the message builder.

        Returns:
            A single-element list with the interleaved image/text message.
        """
        imgs: list[ImageInput] = []
        full_img_captions: list[str] = []
        text_observations_input: str = ""

        include_state_imgs = add_state_img and self.use_img_observation
        include_state_text = add_state_text and self.use_text_observation
        img_intros = state_img_intros or []
        text_intros = state_text_intros or []
        extra_captions = objective_image_captions or []

        # Add environment state as part of current intent
        for i, idx in enumerate(add_states_idxs or []):
            state = trajectory.states[idx]

            # Add screenshot of current state
            if include_state_imgs:
                intro = img_intros[i] if i < len(img_intros) else ""
                # len(imgs) is the position this image will take in the final image list,
                # so caption numbers always match the images actually attached.
                full_img_captions.append(f"Image {len(imgs)}: {intro}")
                imgs.append(state["observation"]["image"])

            # Add text observation of current state
            if include_state_text:
                if i < len(text_intros):
                    text_observations_input += f"{text_intros[i]}"
                text_observations_input += f"{state['observation']['text']}\n\n"

        # Full textual input = text observations + text intent
        final_text_input = f"{text_observations_input}{text_intent_input}"
        if self.parsing_error_msg_key and meta_data.get(self.parsing_error_msg_key):
            final_text_input += meta_data[self.parsing_error_msg_key]

        # Get intent images and corresponding captions
        for i, img in enumerate(objective_imgs or []):
            # Index of image in the full list of images
            img_caption = f"Image {len(imgs)}: objective image {i + 1}"

            # If image was captioned by previous agents, add it to the full caption
            additional_caption = extra_captions[i] if i < len(extra_captions) else ""
            if additional_caption:
                img_caption = f"{img_caption}; description: {additional_caption}."
            else:
                img_caption = f"{img_caption}."

            imgs.append(img)
            full_img_captions.append(img_caption)

        # Prepend image captions with "IMAGES:"
        if full_img_captions:
            full_img_captions[0] = f"IMAGES:\n{full_img_captions[0]}"

        # Build message
        msg = get_interleaved_img_txt_msg(
            images=imgs,
            img_captions=full_img_captions,
            role=role,
            name=name,
            img_detail=self.lm_config.get("img_detail", "auto"),
            text_prefix=final_text_input,
            text_first=self.lm_config.get("text_first", True),
        )
        return [msg]
