from typing import Any, Callable

import yaml

from agent.prompt_constructor import LMParsingError, PromptConstructor
from llms.llm_utils import call_llm
from llms.tokenizers import Tokenizer
from utils.logger_utils import logger
from utils.string_utils import clean_spaces
from utils.timing_utils import time_block
from utils.trajectory_view import TrajectoryView
from utils.types import ImageInput
from vwa_utils.score_logger import ScoreLogger


class Agent:
    """Base class for the agent.

    Subclasses implement `next_action`. The base class provides shared helpers:
    prompt -> LLM -> parse with retries (`act_parse_retry`), simple state
    selection, optional gold-score logging, and accessors for the model and
    prompt-constructor configuration.

    The base `__init__` only sets `score_logger` and `dump_html`. Subclasses must
    set `lm_config`, `prompt_constructor`, `conversation_dir` and `usage_dir`
    themselves, because `act_parse_retry` and the getters read them.
    """

    lm_config: dict[str, Any]  # Configuration for the language model
    agent_config: dict[str, Any]  # Agent-specific configuration settings
    prompt_constructor: PromptConstructor  # construct prompts during navigation
    score_logger: ScoreLogger | None  # computes and logs GT scores per executor-critic round
    dump_html: bool  # Whether to dump interaction to HTML for current state
    conversation_dir: str  # Directory path where conversation logs are stored
    usage_dir: str  # Directory path used to store API-specific metrics

    def __init__(self, *args: Any) -> None:
        # Extra positional arguments are accepted and ignored by the base class.
        self.score_logger = None
        self.dump_html = True

    def next_action(
        self,
        trajectory: TrajectoryView,
        intent: str,
        intent_images: list[ImageInput],
        meta_data: dict[str, Any],
    ) -> Any:
        """Predict the next action given the observation.

        Must be implemented by subclasses.

        Raises:
            NotImplementedError: Always, in the base class.
        """
        raise NotImplementedError

    def select_states(self, trajectory: TrajectoryView, num_states: int) -> list[Any]:
        """Basic state selection of the last `num_states` states.
        Agents can override this method with more sophisticated state selection logic.

        Args:
            trajectory (TrajectoryView): Trajectory of states
            num_states (int): Number of states to select; values below 1 are treated as 1.

        Returns:
            list: The last `max(1, num_states)` states, in trajectory order. Fewer are
                returned if the trajectory is shorter.

        NOTE(review): if `trajectory.states` is empty, the single-state path
        (`states[-1]`) presumably raises IndexError, while the multi-state path
        returns an empty slice.
        """
        num_states = max(1, num_states)
        if num_states == 1:
            return [trajectory.states[-1]]
        else:
            return trajectory.states[-num_states:]

    @staticmethod
    def _reset_error_msg(meta_data: dict[str, Any], error_msg_key: str) -> None:
        """Clear parse-failure feedback.

        An existing key is reset to "". A missing key is inserted with value None.
        """
        meta_data[error_msg_key] = "" if error_msg_key in meta_data else None

    def act_parse_retry(
        self,
        trajectory: TrajectoryView,
        intent: str,
        intent_images: list[ImageInput],
        meta_data: dict[str, Any],
        idxs_trajectory: list[int],
        parser_fn: Callable[[str], Any],
        max_tries: int,
        error_msg_template: str,
        error_msg_key: str,
        **kwargs: Any,
    ) -> tuple[Any, str, int]:
        """Build prompt, call model, parse response, and retry on failure.

        Makes up to `max_tries` attempts. After a parsing failure, the formatted
        `error_msg_template` is stored in `meta_data[error_msg_key]`, so the next
        prompt is built with feedback from the failed attempt. That entry is
        cleared right after each prompt is built, and on every exit path except
        a retry.

        Args:
            trajectory (TrajectoryView): Trajectory of states
            intent (str): User objective
            intent_images (list[ImageInput]): Images of the objective
            meta_data (dict[str, Any]): Metadata; must contain "task_id". Mutated in place
                (the `error_msg_key` entry).
            idxs_trajectory (list[int]): Indices of the trajectory to use
            parser_fn (callable): Parses the raw response. It should raise LMParsingError
                to trigger a retry.
            max_tries (int): Maximum number of attempts
            error_msg_template (str): Template for the feedback message. It is formatted
                with a `response` keyword holding the whitespace-cleaned raw response.
            error_msg_key (str): Key for the feedback message in meta_data
            **kwargs: Additional arguments for the prompt constructor

        Raises:
            LMParsingError: If no response could be parsed within `max_tries` attempts
                (including when `max_tries <= 0`). It is chained to the last parsing
                error, if any.
            Exception: Any other error from prompt construction, the LLM call or the
                parser is re-raised unchanged, after clearing the feedback entry.

        Returns:
            tuple[Any, str, int]: Parsed response, raw response, and the number of failed
                attempts that preceded the successful one (0 if the first attempt parsed).
        """
        # Bind up front so the error paths never reference an unassigned name
        # (e.g. `max_tries <= 0`, or an LMParsingError raised before any response).
        raw_response = ""
        last_error: LMParsingError | None = None

        num_tries = 0
        while num_tries < max_tries:
            try:
                # Build API messages with feedback from previous attempts
                with time_block("AGENT:construct_prompt"):
                    prompt = self.prompt_constructor.construct(
                        trajectory=trajectory,
                        objective=intent,
                        objective_imgs=intent_images,
                        meta_data=meta_data,
                        idxs_trajectory=idxs_trajectory,
                        **kwargs,
                    )

                # Feedback has been incorporated into the prompt; clear it
                self._reset_error_msg(meta_data, error_msg_key)

                # Call model and parse response
                with time_block("AGENT:call_llm"):
                    _, model_generations = call_llm(
                        gen_kwargs=self.lm_config,
                        prompt=prompt,
                        meta_data=meta_data,
                        call_id=meta_data["task_id"],
                        conversation_dir=self.conversation_dir,
                        usage_dir=self.usage_dir,
                        dump_html=self.dump_html,
                    )

                # TODO: multiple generations handling
                raw_response = model_generations[0].text()
                parsed_response = parser_fn(raw_response)
                return parsed_response, raw_response, num_tries

            # If parsing fails, store feedback for the next prompt and try again
            except LMParsingError as e:
                last_error = e
                meta_data[error_msg_key] = error_msg_template.format(response=clean_spaces(raw_response))
                num_tries += 1

            # Any other error: clear feedback and propagate with the original traceback
            except Exception:
                self._reset_error_msg(meta_data, error_msg_key)
                raise

        # Parsing still failed after `max_tries` attempts
        self._reset_error_msg(meta_data, error_msg_key)
        raise LMParsingError(
            f"Failed to parse {self} response after {max_tries} attempts.", raw_response=raw_response
        ) from last_error

    def log_scores_per_round(
        self,
        trajectory: TrajectoryView,
        intent: str,
        meta_data: dict[str, Any],
    ) -> None:
        """Not part of the Agentic pipeline. Helper to log gold scores at each executor-critic loop.

        Does nothing unless `score_logger` is set (it is None after the base `__init__`).

        Args:
            trajectory (TrajectoryView): Trajectory of states
            intent (str): User objective
            meta_data (dict[str, Any]): Metadata
        """
        if self.score_logger:
            self.score_logger.log_scores_per_round(trajectory, intent, meta_data)

    def reset(self, test_config_file: str) -> None:
        """Reset the agent. Agents can override this method with agent-specific reset logic.

        The base implementation only re-enables `dump_html`.

        Args:
            test_config_file (str): Path to the test config file (unused by the base class)
        """
        self.dump_html = True

    def __str__(self) -> str:
        """Return the agent's class name."""
        return f"{self.__class__.__name__}"

    def get_prompt_constructor(self) -> PromptConstructor:
        """Return the agent's prompt constructor."""
        return self.prompt_constructor

    # NOTE(review): `agent_name` is unused by the base getters below; presumably
    # multi-agent subclasses use it to pick a sub-agent. Confirm.
    def get_model(self, agent_name: str = "") -> str:
        """Return the model path from `lm_config["model_path"]`."""
        return self.lm_config["model_path"]  # type: ignore

    def get_provider(self, agent_name: str = "") -> str:
        """Return the provider of the prompt constructor's tokenizer."""
        return self.prompt_constructor.tokenizer.provider

    def get_action_splitter(self, agent_name: str = "") -> str:
        """Return the action splitter declared in the prompt instruction's meta data."""
        return self.prompt_constructor.instruction["meta_data"]["action_splitter"]  # type: ignore

    def get_tokenizer(self, agent_name: str = "") -> Tokenizer:
        """Return the prompt constructor's tokenizer."""
        return self.prompt_constructor.tokenizer


def load_agent_config(agent_config_path: str) -> dict[str, Any]:
    """Load an agent configuration from a YAML file.

    Args:
        agent_config_path (str): Path to the YAML config file.

    Returns:
        dict[str, Any]: Parsed configuration. An empty file yields `{}` rather
            than `None`, matching the declared return type.
    """
    # Decode explicitly as UTF-8 (the YAML default) instead of the platform's locale encoding.
    with open(agent_config_path, "r", encoding="utf-8") as f:
        # safe_load builds only plain Python objects (no arbitrary object construction).
        agent_config = yaml.safe_load(f)  # type: ignore
    # safe_load returns None for an empty document.
    if agent_config is None:
        return {}
    return agent_config  # type: ignore


class ComputerUseAgent(Agent):
    """Placeholder subclass of `Agent`; it currently adds no behavior of its own."""

    # TODO: implement
