from typing import Any

from compression_autoencoder.envs.multireward_wrapper import MultiRewardWrapper
import gymnasium as gym


class MultiRewardMountainCarWrapper(MultiRewardWrapper):
    """Multi-component reward wrapper for MountainCarContinuous-v0.

    Every step produces three named reward terms:

    * ``"speed"``  - the squared velocity component of the observation.
    * ``"height"`` - the square of the value returned by the unwrapped
      environment's ``_height(position)``, counted only when that value is
      at least 0.2 (otherwise 0.0).
    * ``"left"``   - 100.0 on the first step of an episode whose position is
      at or below -1.1; on every other step, minus 0.1 times the square of
      the first action component.
    """

    def __init__(self, env: gym.Env, reward_type: str = "standard") -> None:
        """Wrap *env* and start with the "left" bonus still available.

        Args:
            env: Environment to wrap.
            reward_type: Forwarded unchanged to ``MultiRewardWrapper``.
        """
        # Assigned ahead of the base initialiser, so the flag already exists
        # before any parent set-up code runs.
        self._reached_left = False
        super().__init__(env, reward_type)

    @property
    def reward_keys(self) -> set[str]:
        """Names of the reward terms emitted by ``_calculate_rewards``."""
        keys = {"speed", "left", "height"}
        return keys

    def _calculate_rewards(
        self, obs, reward, terminated, truncated, info, action
    ) -> dict[str, float]:
        """Build the per-step reward terms.

        Only ``obs`` and ``action`` are read; ``reward``, ``terminated``,
        ``truncated`` and ``info`` are accepted but unused here.

        Args:
            obs: Observation that unpacks into ``(position, velocity)``.
            action: Indexable action; only its first element is used.

        Returns:
            Mapping with the keys ``"speed"``, ``"height"`` and ``"left"``.
        """
        pos, vel = obs
        # ``_height`` is a private helper of the unwrapped environment.
        elevation = self.env.unwrapped._height(pos)  # noqa

        speed_term = vel**2

        if elevation >= 0.2:
            height_term = elevation**2
        else:
            height_term = 0.0

        if not self._reached_left and pos <= -1.1:
            # The bonus is paid once; the flag blocks it until the next reset.
            self._reached_left = True
            left_term = 100.0
        else:
            left_term = -(action[0] ** 2) * 0.1

        return {"speed": speed_term, "height": height_term, "left": left_term}

    def reset(
        self, *, seed: int | None = None, options: dict[str, Any] | None = None
    ) -> tuple[Any, dict[str, Any]]:
        """Re-arm the one-time "left" bonus, then reset via the parent class.

        Args:
            seed: Forwarded to the parent ``reset``.
            options: Forwarded to the parent ``reset``.

        Returns:
            Whatever the parent ``reset`` returns.
        """
        self._reached_left = False
        outcome = super().reset(seed=seed, options=options)
        return outcome
