from typing import Any

from agent.agent import Agent
from agent.executor import ExecutorAgent
from browser_env.actions import ActionParsingError
from utils.logger_utils import logger
from utils.trajectory_view import TrajectoryView
from utils.types import ImageInput
from vwa_utils.extract_trajectory_html import rebuild_trajectory_vwa_format


class TeacherForcingAgent(Agent):
    """Agent that replays previously recorded actions instead of querying a model.

    ``reset`` loads the recorded trajectory for a task from a rendered HTML
    file and queues its actions; ``next_action`` then hands them out one at a
    time. Once the queue is empty, the raw prediction is requested
    interactively on stdin.
    """

    def __init__(
        self,
        trajectory_html_path: str,
        agents_configs: dict[str, Any] | None = None,
        action_set_tag: str = "som",
        action_splitter: str = "```",
    ) -> None:
        """Build the underlying executor agent and ask how to parse trajectories.

        Args:
            trajectory_html_path: Directory containing the ``render_<task_id>.html``
                files read by ``reset``.
            agents_configs: Configuration with an ``"executor_agent"`` entry. When
                ``None``, a default configuration is created.
            action_set_tag: Only used to fill the default configuration.
            action_splitter: Only used to fill the default configuration.

        Note:
            Construction blocks on ``input()`` to decide whether trajectory
            parsing should stop at the critique.
        """
        if agents_configs is None:
            agents_configs = {
                "executor_agent": {
                    "action_set_tag": action_set_tag,
                    "action_splitter": action_splitter,
                    "lm_config": None,
                    "max_model_call": 3,
                    "prompt": "p_som_cot_id_actree_3s_prev_utterances",
                }
            }
            logger.warning(f"No agents config provided. Creating a default config: {agents_configs}")

        executor_config = agents_configs["executor_agent"]
        self.action_set_tag = executor_config["action_set_tag"]
        self.executor_agent = ExecutorAgent(
            lm_config=executor_config["lm_config"],
            agent_config=executor_config,
        )
        self.trajectory_html_path = trajectory_html_path
        # The default config carries no "out_utterance" entry, so a plain
        # subscript would raise KeyError; a missing key means "do not log".
        self.out_utterance = executor_config.get("out_utterance", False)
        self.action_list: list[dict[str, Any]] = []

        ans = input("Teacher Forcing Agent. Stop at critique during action parsing? (y/n)")
        self.stop_at_critique = ans.strip().lower() == "y"

    def next_action(
        self,
        trajectory: TrajectoryView,
        intent: str,
        intent_images: list[ImageInput],
        meta_data: dict[str, Any],
    ) -> Any:
        """Return the next recorded action, or one typed in by the operator.

        The arguments are part of the agent interface but are not consulted:
        the action comes from the queue filled by ``reset``.

        Returns:
            The action created by the executor agent, with the raw response
            stored under ``"raw_prediction"``.

        Raises:
            ValueError: If the executor agent cannot build an action from the
                response.
        """
        if self.action_list:
            raw_response = self.action_list.pop(0)["raw_prediction"]
        else:
            # Recorded actions are exhausted: fall back to manual input rather
            # than aborting the run.
            raw_response = input("Enter the raw prediction. e.g.: ```Therefore, click [element_id]```")

        parsed_response = self.executor_agent.prompt_constructor.extract_action(raw_response)

        if self.out_utterance:
            logger.info(f"\n[Executor Agent]: {raw_response}")

        # Build the action exactly once, inside the guard, so a parsing failure
        # is reported as the ValueError below.
        try:
            action = self.executor_agent.create_action(parsed_response)
        except ActionParsingError as err:
            raise ValueError(f"Cannot parse the action from the response: {raw_response}") from err

        action.update({"raw_prediction": raw_response})
        return action

    def reset(self, test_config_file: str) -> None:
        """Load the recorded actions for the task named by ``test_config_file``.

        The task id is the config file's base name up to its first ``.``; the
        trajectory is read from ``<trajectory_html_path>/render_<task_id>.html``.
        """
        task_id = test_config_file.split("/")[-1].split(".")[0]
        html_path = f"{self.trajectory_html_path}/render_{task_id}.html"
        executed_trajectory, meta_data, _ = rebuild_trajectory_vwa_format(
            html_path=html_path, stop_at_critique=self.stop_at_critique
        )
        # Keep every second entry starting at index 1 — presumably the rebuilt
        # trajectory alternates states and actions; confirm against
        # rebuild_trajectory_vwa_format.
        self.action_list = executed_trajectory[1::2]
