from typing import Any

from agent.agent import Agent
from llms.tokenizers import Tokenizer
from utils.trajectory_view import TrajectoryView
from utils.types import ImageInput


class TextRefiner(Agent):
    """Agent that refines the text observation of a trajectory state instead of acting.

    ``next_action`` picks a state via ``select_states`` and overwrites its
    ``observation["text"]`` with a version trimmed to at most
    ``agent_config["max_obs_length"]`` tokens (see ``refine_text_obs``).
    """

    def __init__(self, lm_config: dict[str, Any], agent_config: dict[str, Any]) -> None:
        """Store the configs and build the tokenizer used for trimming.

        Args:
            lm_config: Language-model config; ``lm_config["model"]`` selects the tokenizer.
            agent_config: Agent config. Must contain ``"max_obs_length"``; when trimming is
                enabled and the tokenizer provider is ``"google"`` it must also contain
                ``"google_chars_per_token"``.
        """
        self.tokenizer = Tokenizer(model_name=lm_config["model"])
        self.lm_config = lm_config
        self.max_obs_length = agent_config["max_obs_length"]
        self.agent_config = agent_config

    def select_states(self, trajectory: TrajectoryView, num_states: int = 1) -> list[dict[str, Any]]:
        """Delegate state selection to ``Agent.select_states`` unchanged."""
        return super().select_states(trajectory, num_states=num_states)

    def next_action(
        self, trajectory: TrajectoryView, intent: str, intent_images: list[ImageInput], meta_data: dict[str, Any]
    ) -> None:
        """Trim the text observation of the first selected state, in place.

        ``intent``, ``intent_images`` and ``meta_data`` are accepted for interface
        compatibility with ``Agent`` but are not used. Nothing is returned; the result
        is written into the state dict returned by ``select_states``.
        NOTE(review): this relies on ``select_states`` returning references into the
        trajectory (not copies) for the change to be visible to callers — confirm.
        """
        env_state = self.select_states(trajectory)[0]
        text_obs = env_state["observation"]["text"]
        env_state["observation"]["text"] = self.refine_text_obs(text_obs)

    def refine_text_obs(self, text_observation: str) -> str:
        """Return ``text_observation`` trimmed to ``self.max_obs_length`` tokens.

        A non-positive ``max_obs_length`` disables trimming. For the ``"google"``
        tokenizer provider the token budget is approximated in characters
        (``max_obs_length * agent_config["google_chars_per_token"]``). For any other
        provider the text is tokenized and cut to its first ``max_obs_length`` tokens.
        Text that already fits the token budget is returned unchanged.

        Args:
            text_observation: The raw text observation.

        Returns:
            The (possibly) trimmed observation text.
        """
        if self.max_obs_length <= 0:
            return text_observation

        if self.tokenizer.provider == "google":
            # Gemini: trim per character instead of per token.
            # int() keeps the slice valid if the configured ratio is fractional
            # (a float slice bound would raise TypeError).
            max_chars = int(self.max_obs_length * self.agent_config["google_chars_per_token"])
            return text_observation[:max_chars]

        tok_obs = self.tokenizer.encode(text_observation, add_special_tokens=False)
        if len(tok_obs) <= self.max_obs_length:
            # Already within budget: skip the decode round-trip, which is not
            # guaranteed to reproduce the original text exactly.
            return text_observation
        return self.tokenizer.decode(tok_obs[: self.max_obs_length], skip_special_tokens=False)
