from typing import Any, Callable

from agent.agent import Agent
from agent.critic import CriticAgent
from agent.executor import ExecutorAgent
from agent.image_refiner import ImageRefiner
from agent.prompt_constructor import PromptConstructor
from agent.request_refiner import RequestRefiner
from agent.text_refiner import TextRefiner
from browser_env.actions import ActionTypes, action2str
from llms.tokenizers import Tokenizer
from utils.logger_utils import logger
from utils.string_utils import clean_spaces
from utils.timing_utils import timeit
from utils.trajectory_view import TrajectoryView
from utils.types import ImageInput


class ModularAgent(Agent):
    """Agent that wraps an executor with optional refiner and critic modules.

    Each call to :meth:`next_action` runs the configured modules in a fixed
    order: request refiner, text refiner, image refiner, then the executor.
    When a critic is configured and the executor proposes a ``STOP`` action,
    executor and critic alternate until the critic accepts the answer or the
    critic's ``max_critique_executor_loop`` budget is exceeded.
    """

    def __init__(
        self,
        image_refiner: ImageRefiner | None,
        text_refiner: TextRefiner | None,
        executor_agent: ExecutorAgent,
        request_refiner: RequestRefiner | None,
        critique_agent: CriticAgent | None,
    ) -> None:
        """Store the modules that make up this agent.

        Args:
            image_refiner: Optional module invoked on every step before the executor.
            text_refiner: Optional module invoked on every step before the executor.
            executor_agent: Module that produces the low-level action.
            request_refiner: Optional module invoked first on every step.
            critique_agent: Optional critic that reviews ``STOP`` actions.
        """
        super().__init__()
        self.image_refiner = image_refiner
        self.text_refiner = text_refiner
        self.executor_agent = executor_agent
        self.request_refiner = request_refiner
        self.critique_agent = critique_agent
        # Every module slot (including unset ones), iterated by reset().
        self.all_modules = [
            self.image_refiner,
            self.text_refiner,
            self.executor_agent,
            self.request_refiner,
            self.critique_agent,
        ]
        # Cached result of the critic's web-knowledge retrieval; cleared by reset().
        self.web_knowledge = ""

    @timeit(custom_name="AGENT:next_action")
    def next_action(
        self,
        trajectory: TrajectoryView,
        intent: str,
        intent_images: list[ImageInput],
        meta_data: dict[str, Any],
    ) -> Any:
        """Run all configured modules and return the executor's action.

        Args:
            trajectory: Current trajectory view, handed to every module.
            intent: The user request.
            intent_images: Images that accompany the request.
            meta_data: Mutable per-step metadata. This method writes the keys
                ``web_knowledge``, ``critique_feedback`` and
                ``prev_utterance_for_feedback`` into it and reads
                ``action_str_history`` when the critic is consulted.

        Returns:
            The last action returned by the executor. If at least one critic
            round ran, the action carries a ``critique_executor_loop_utterances``
            entry holding the raw executor and critic utterances of this call.
        """
        # Refine user request.
        # NOTE(review): the refiner's return value is not used by any later
        # module, which all receive the original `intent` - confirm whether the
        # refined intent should be forwarded.
        if self.request_refiner is not None:
            self.request_refiner.next_action(trajectory, intent, intent_images, meta_data)

        # Create plan #TODO

        # Refine text obs
        if self.text_refiner is not None:
            self.text_refiner.next_action(trajectory, intent, intent_images, meta_data)

        # Refine image obs
        if self.image_refiner is not None:
            self.image_refiner.next_action(trajectory, intent, intent_images, meta_data)

        # Execute low-level action on environment
        critique_executor_loop = 0
        utterance_history: dict[str, list[str]] = {"executor": [], "critique": []}

        # If the executor uses web knowledge, retrieve it once through the
        # critic and keep it cached until reset() clears it.
        if (
            self.critique_agent is not None
            and "two_pass" in self.critique_agent.mode
            and not self.web_knowledge
            and self.executor_agent.use_web_knowledge
        ):
            self.web_knowledge, _, _ = self.critique_agent.retrieve_web_knowledge(
                trajectory, intent, intent_images, meta_data
            )

        while True:
            # Hand the cached web knowledge to the executor through meta_data.
            meta_data["web_knowledge"] = self.web_knowledge
            logger.info("Executor Agent action")
            # log_score is only set for the first executor call of this step.
            low_level_action = self.executor_agent.next_action(
                trajectory, intent, intent_images, meta_data, log_score=(critique_executor_loop == 0)
            )
            # An action carrying an "early_stop" key ends the loop immediately.
            if "early_stop" in low_level_action:
                break

            utterance_history["executor"].append(low_level_action["raw_prediction"])

            # If critique agent provided, start executor-critique loop
            if low_level_action["action_type"] == ActionTypes.STOP and self.critique_agent is not None:
                critique_executor_loop += 1

                # Stop the critique-revise loop if max_critique_executor_loop is reached
                if critique_executor_loop > self.critique_agent.max_critique_executor_loop:
                    break

                # Create a temporary trajectory that ends with the proposed action
                temp_trajectory = TrajectoryView(trajectory.trajectory + [low_level_action])

                # Extend a copy of the metadata with the proposed action's string
                # form so the caller's action history is left untouched.
                action_str = action2str(low_level_action, self.executor_agent.action_set_tag)
                temp_meta_data = meta_data.copy()
                temp_meta_data["action_str_history"] = meta_data["action_str_history"] + [action_str]

                # Call Critic
                critique_response, raw_critique_response = self.critique_agent.next_action(
                    temp_trajectory, intent, intent_images, temp_meta_data
                )
                utterance_history["critique"].append(raw_critique_response)

                # Attach the utterances before logging: temp_trajectory holds
                # this same action object, so the score logger can see them.
                low_level_action["critique_executor_loop_utterances"] = utterance_history  # type: ignore

                # Log scores
                # NOTE(review): the logger receives the original meta_data, not
                # temp_meta_data - confirm this is intended.
                if self.executor_agent.score_logger:
                    self.executor_agent.score_logger.log_scores_per_round(temp_trajectory, intent, meta_data)

                if self.critique_agent.is_success(critique_response):
                    break
                else:
                    # Feed the critic's feedback and the rejected utterance into
                    # the next executor call.
                    # NOTE(review): these keys stay in the caller's meta_data
                    # after this method returns - confirm they are cleared or
                    # consumed elsewhere before the next step.
                    meta_data["critique_feedback"] = clean_spaces(self.critique_agent.get_feedback(critique_response))
                    meta_data["prev_utterance_for_feedback"] = clean_spaces(low_level_action["raw_prediction"])
            else:
                break

        # Covers the exits that happen before the in-loop attachment (early
        # stop, exceeded critic budget) once at least one critic round ran.
        if utterance_history["executor"] and utterance_history["critique"]:
            low_level_action["critique_executor_loop_utterances"] = utterance_history  # type: ignore

        return low_level_action

    def get_prompt_constructor(self) -> PromptConstructor:
        """Return the executor's prompt constructor.

        Raises:
            ValueError: If the executor agent is falsy.
        """
        if self.executor_agent:
            return self.executor_agent.prompt_constructor
        else:
            raise ValueError("No executor agent provided")

    def get_model(self, agent_name: str = "executor") -> str:
        """Return the ``model`` entry of the named agent's ``lm_config``.

        Raises:
            NotImplementedError: For any agent name other than ``"executor"``.
        """
        # TODO: finish this function
        if agent_name == "executor" and self.executor_agent:
            return self.executor_agent.lm_config["model"]  # type: ignore
        else:
            raise NotImplementedError(f"Unknown agent {agent_name}")

    def get_provider(self, agent_name: str = "executor") -> str:
        """Return the ``provider`` of the named agent's tokenizer.

        Raises:
            NotImplementedError: For any agent name other than ``"executor"``.
        """
        # TODO: finish this function
        if agent_name == "executor" and self.executor_agent:
            return self.executor_agent.tokenizer.provider
        else:
            raise NotImplementedError(f"Unknown agent {agent_name}")

    def get_action_splitter(self, agent_name: str = "executor") -> str:
        """Return the action splitter of the named agent.

        Raises:
            ValueError: For any agent name other than ``"executor"``.
        """
        if agent_name == "executor" and self.executor_agent:
            return self.executor_agent.get_action_splitter()
        else:
            raise ValueError(f"Unknown agent {agent_name}")

    def get_tokenizer(self, agent_name: str = "executor") -> Tokenizer:
        """Return the tokenizer of the named agent.

        Raises:
            ValueError: For any agent name other than ``"executor"``.
        """
        if agent_name == "executor" and self.executor_agent:
            return self.executor_agent.get_tokenizer()
        else:
            raise ValueError(f"Unknown agent {agent_name}")

    def reset(self, test_config_file: str) -> None:
        """Clear the cached web knowledge and reset every configured module.

        Args:
            test_config_file: Passed unchanged to each module's ``reset``.
        """
        self.web_knowledge = ""
        for module in self.all_modules:
            if module is not None:
                module.reset(test_config_file)


def construct_modular_agent(
    agents_configs: dict[str, Any],
    caption_image_fn: Callable,  # type: ignore
) -> ModularAgent:
    """Assemble a ``ModularAgent`` from per-module configuration entries.

    Args:
        agents_configs: Maps module names to their configuration.
            ``executor_agent`` is mandatory; ``image_refiner``,
            ``text_refiner``, ``request_refiner`` and ``critique_agent`` are
            each built only when their key is present.
        caption_image_fn: Passed to the request refiner as ``captioning_fn``.

    Returns:
        A ``ModularAgent`` in which every unconfigured optional module is ``None``.

    Raises:
        ValueError: If ``agents_configs`` has no ``executor_agent`` entry.
    """
    if "executor_agent" not in agents_configs:
        raise ValueError("Please provide a low-level executor module.")

    def build_lm_module(module_cls: Any, config_key: str) -> Any:
        # Shared construction for modules taking `lm_config` + `agent_config`;
        # a missing config entry means the module is left out.
        if config_key not in agents_configs:
            return None
        module_config = agents_configs[config_key]
        return module_cls(
            lm_config=module_config["lm_config"],
            agent_config=module_config,
        )

    # Modules are created in the order: executor, image refiner, text
    # refiner, request refiner, critic.
    executor = build_lm_module(ExecutorAgent, "executor_agent")
    image_module = build_lm_module(ImageRefiner, "image_refiner")
    text_module = build_lm_module(TextRefiner, "text_refiner")

    # The request refiner takes the captioning function and no `lm_config`.
    request_module = (
        RequestRefiner(
            agent_config=agents_configs["request_refiner"],
            captioning_fn=caption_image_fn,
        )
        if "request_refiner" in agents_configs
        else None
    )

    critic = build_lm_module(CriticAgent, "critique_agent")

    return ModularAgent(
        image_refiner=image_module,
        text_refiner=text_module,
        executor_agent=executor,
        request_refiner=request_module,
        critique_agent=critic,
    )
