from typing import Any, SupportsFloat, TypeVar

import gymnasium as gym
import numpy as np

# Type variables, presumably mirroring gymnasium's ObsType / ActType /
# WrapperObsType / WrapperActType generic parameters.
# NOTE(review): RemoveFalseWrapper below parametrizes gym.Wrapper with
# gym.spaces.Dict rather than these TypeVars — confirm whether they are used
# elsewhere in the project or can be removed.
ObsType = TypeVar("ObsType")
ActType = TypeVar("ActType")
WrapperObsType = TypeVar("WrapperObsType")
WrapperActType = TypeVar("WrapperActType")


def remove_false(obs: dict[Any, Any]) -> dict[Any, Any]:
    """Return a copy of *obs* without the entries whose value is boolean ``False``.

    Only genuine booleans are dropped: Python ``False`` and NumPy's
    ``np.False_`` scalar. Other falsy values (``0``, ``0.0``, ``None``, ``""``,
    empty containers) are kept, as are NumPy arrays, including 0-d ones.

    Args:
        obs: Mapping to filter. It is not modified.

    Returns:
        A new dict, in the original key order, without the ``False`` entries.
    """
    # Check the type explicitly rather than testing ``value is np.bool_(False)``:
    # identity against a freshly constructed NumPy scalar only works because
    # NumPy happens to cache its two bool scalars.
    return {
        key: value
        for key, value in obs.items()
        if not (isinstance(value, (bool, np.bool_)) and not value)
    }


class RemoveFalseWrapper(
    gym.Wrapper[gym.spaces.Dict, gym.spaces.Dict, gym.spaces.Dict, gym.spaces.Dict]
):
    """
    Drops every ``False``-valued entry from the wrapped env's dict observations.

    Observations returned by both ``step`` and ``reset`` are passed through
    ``remove_false``. Rewards, termination/truncation flags and ``info`` are
    forwarded unchanged.

    NOTE(review): this wrapper does not override ``observation_space``, so the
    filtered observations may lack keys that the declared ``Dict`` space
    contains — confirm downstream consumers tolerate this.
    """

    def __init__(self, env: gym.Env[gym.spaces.Dict, gym.spaces.Dict]) -> None:
        # gym.Wrapper.__init__ stores ``env`` as ``self.env``; no need to repeat it.
        super().__init__(env)

    def step(
        self,
        action: gym.spaces.Dict,
    ) -> tuple[
        dict[str, bool | None],
        SupportsFloat,
        bool,
        bool,
        dict[str, Any],
    ]:
        """
        Step the wrapped env and strip ``False`` entries from the observation.

        Args:
            action: Action forwarded verbatim to the wrapped env.

        Returns:
            ``(obs, reward, terminated, truncated, info)`` where ``obs`` has had
            its ``False`` values removed and the rest is passed through as-is.
        """
        obs, reward, terminated, truncated, info = self.env.step(action)
        return remove_false(obs), reward, terminated, truncated, info

    def reset(
        self, seed: int | None = None, options: dict | None = None
    ) -> tuple[
        dict[str, bool | None],
        dict[str, Any],
    ]:
        """
        Reset the wrapped env and strip ``False`` entries from the observation.

        Args:
            seed: Seed forwarded to the wrapped env's ``reset``.
            options: Reset options forwarded to the wrapped env's ``reset``.

        Returns:
            ``(obs, info)`` where ``obs`` has had its ``False`` values removed.
        """
        # ``options`` must be forwarded too; otherwise reset options passed to
        # this wrapper would never reach the wrapped env.
        obs, info = self.env.reset(seed=seed, options=options)
        return remove_false(obs), info
