from abc import ABC, abstractmethod
from typing import Any, SupportsFloat

import gymnasium as gym


class MultiRewardWrapper(gym.Wrapper[Any, Any, Any, Any], ABC):
    """
    An abstract base class for Gymnasium wrappers that compute multiple reward
    signals.

    The wrapper handles the boilerplate of accumulating rewards over an episode,
    selecting one as the primary reward signal, and logging all cumulative
    rewards in the `info` dictionary when an episode terminates or truncates.

    Besides the subclass-defined ``reward_keys``, two counters are always
    tracked:

    - ``"standard"``: the reward returned by the wrapped environment, unless
      the subclass supplies its own ``"standard"`` value for a step.
    - ``"length"``: the number of steps taken in the current episode.

    Subclasses must implement:
    - `reward_keys`: A property that returns the set of reward keys.
    - `_calculate_rewards`: A method to compute the rewards for the current step.

    Args:
        env: The environment to wrap.
        reward_type: The reward key whose per-step value `step` returns as
            the main reward. Must be ``"standard"`` or one of ``reward_keys``.

    Raises:
        ValueError: If ``reward_type`` is neither ``"standard"`` nor in
            ``reward_keys``.
    """

    def __init__(self, env: gym.Env, reward_type: str = "standard") -> None:
        super().__init__(env)
        # Validate up front so a typo fails at construction time, not on the
        # first call to step().
        if reward_type not in self.reward_keys and reward_type != "standard":
            raise ValueError(
                f"Unknown reward_type '{reward_type}'. "
                f"Available types: {self.reward_keys}"
            )
        self.reward_type = reward_type
        # Running per-episode totals, keyed by reward name.
        self.cumulative_rewards: dict[str, float] = {}
        self._reset_rewards()

    @property
    @abstractmethod
    def reward_keys(self) -> set[str]:
        """Returns the set of available reward keys for the environment."""
        raise NotImplementedError

    @abstractmethod
    def _calculate_rewards(
        self,
        obs: Any,
        reward: SupportsFloat,
        terminated: bool,
        truncated: bool,
        info: dict[str, Any],
        action: Any,
    ) -> dict[str, float]:
        """
        Calculates the dictionary of rewards for the current step.

        This method contains the environment-specific reward logic.

        Args:
            obs: Observation returned by the wrapped environment's `step`.
            reward: Reward returned by the wrapped environment's `step`.
            terminated: Termination flag from the wrapped environment.
            truncated: Truncation flag from the wrapped environment.
            info: Info dict from the wrapped environment.
            action: The action that was passed to `step`.

        Returns:
            A mapping from reward key to this step's reward. It must contain
            ``self.reward_type`` unless that is ``"standard"``, which is
            filled in from ``reward`` when absent. Keys that are not tracked
            counters are ignored for accumulation.
        """
        raise NotImplementedError

    def _reset_rewards(self) -> None:
        """Resets all cumulative reward counters (including "standard" and "length") to 0.0."""
        self.cumulative_rewards = {key: 0.0 for key in self.reward_keys}
        self.cumulative_rewards["standard"] = 0.0
        self.cumulative_rewards["length"] = 0.0

    def step(
        self, action: Any
    ) -> tuple[Any, SupportsFloat, bool, bool, dict[str, Any]]:
        """
        Step the wrapped environment and return the selected reward.

        On the final step of an episode (``terminated or truncated``), every
        cumulative counter is merged into ``info``. If ``info`` already has an
        ``"episode"`` entry, its ``"r"`` field is overwritten with the
        cumulative value of ``self.reward_type``.

        Args:
            action: The action forwarded to the wrapped environment.

        Returns:
            ``(obs, reward, terminated, truncated, info)``, where ``reward`` is
            this step's value for ``self.reward_type``.
        """
        obs, reward, terminated, truncated, info = self.env.step(action)

        # Calculate environment-specific rewards for the current step. Copy the
        # result so that filling in "standard" below never mutates a dict the
        # subclass may hold on to and reuse between steps.
        step_rewards = dict(
            self._calculate_rewards(
                obs, reward, terminated, truncated, info, action
            )
        )
        if "standard" not in step_rewards:
            step_rewards["standard"] = float(reward)

        # Accumulate only keys that have a counter; unknown keys are ignored.
        # NOTE(review): a subclass value under "length" would be added on top
        # of the +1 step count below.
        for key, value in step_rewards.items():
            if key in self.cumulative_rewards:
                self.cumulative_rewards[key] += value
        self.cumulative_rewards["length"] += 1

        # On episode end, add all cumulative rewards to the info dict
        if terminated or truncated:
            info.update(self.cumulative_rewards)
            # An "episode" entry here was produced by a wrapper inside this one
            # (presumably a monitor such as RecordEpisodeStatistics); make its
            # logged return reflect the selected reward type.
            if "episode" in info:
                info["episode"]["r"] = self.cumulative_rewards[self.reward_type]

        # Return the selected reward type as the main reward signal
        main_reward = step_rewards[self.reward_type]
        return obs, main_reward, terminated, truncated, info

    def reset(
        self, *, seed: int | None = None, options: dict[str, Any] | None = None
    ) -> tuple[Any, dict[str, Any]]:
        """
        Zero all cumulative counters, then reset the wrapped environment.

        Args:
            seed: Forwarded to the wrapped environment's `reset`.
            options: Forwarded to the wrapped environment's `reset`.

        Returns:
            Whatever the wrapped environment's `reset` returns, unchanged.
        """
        self._reset_rewards()
        return self.env.reset(seed=seed, options=options)
