from typing import Any, Dict, List, Optional, TypedDict
from skyrl_gym import Env
from typing import Tuple

MessageType = Dict[str, str]
ConversationType = List[MessageType]


class BaseTextEnvStepOutput(TypedDict):
    observations: ConversationType  # OpenAI API Messages Format
    reward: float
    done: bool
    metadata: Dict[str, Any]
    postprocessed_action: Optional[str] = None


class BaseTextEnv(Env[ConversationType, str]):
    """
    Base environment class for all text-in / text-out environments.
    Supports tool-calling and multi-turn trajectories.

    Exposes only `step`, `init` and `close`.

    Input Types:
        - ObsType: ConversationType (tool output, LLM input)
        - ActType: str (LLM output)
    """

    def __init__(self):
        super().__init__()

        # Metadata
        self.turns = 0
        self.max_turns = 1

        # Tool groups
        self.tool_groups = []
        self.tool_to_toolgroup = {}

    def init_tool_groups(self, tool_groups: List = []) -> None:
        """
        Initialize the tool groups for the environment.
        """
        # Find ToolGroup for a given tool
        self.tool_groups = tool_groups
        self.tool_to_toolgroup = {}
        for tool_group in self.tool_groups:
            self.tool_to_toolgroup.update(tool_group.get_tool_to_group_mapping())

    def _execute_tool(self, tool_group_name: str, tool_name: str, tool_input: Any) -> str:
        """
        Find the right ToolGroup and Tool and execute it.
        """
        for group in self.tool_groups:
            if group.name == tool_group_name:
                return group.execute_tool(tool_name, *tool_input)  # tool_input must be tuple or list

        raise ValueError(f"ToolGroup '{tool_group_name}' not found.")

    def step(self, action: str) -> BaseTextEnvStepOutput:
        """
        Runs one environment step.

        Return:
        - observations: [{"role": "user", "content": observation}]
        - reward: float
        - done: bool
        - postprocessed_action: Optional[str]
        - metadata: Dict[str, Any] any metadata
        """
        pass

    def init(self, prompt: ConversationType) -> Tuple[ConversationType, Dict[str, Any]]:
        """
        Return the first prompt to be given to the model and optional metadata.
        """
        return prompt, {}

    def close(self):
        """
        Closes the environment, override if needed by subclasses.
        """
        pass

    def get_metrics(self) -> Dict[str, Any]:
        """
        Return environment-specific metrics for the episode.
        Default is empty dict (no metrics).
        """
        return {}

    @staticmethod
    def aggregate_metrics(metrics: List[Dict[str, Any]]) -> Dict[str, Any]:
        """
        Static method to aggregate metrics across many episodes of this env class.
        Default behavior: average the numerics, drop the non-numerics.
        """
        from skyrl_gym.metrics import default_aggregate_metrics

        return default_aggregate_metrics(metrics)
