import numpy as np
import torch as th
from gymnasium import spaces

from stable_baselines3 import PPO
from stable_baselines3.common.vec_env import VecEnv, VecExtractDictObs, VecMonitor


class DictObsVecEnv(VecEnv):
    """Custom Environment that produces observation in a dictionary like the procgen env"""

    metadata = {"render_modes": ["human"]}

    def __init__(self):
        self.num_envs = 4
        self.action_space = spaces.Discrete(2)
        self.observation_space = spaces.Dict({"rgb": spaces.Box(low=0.0, high=255.0, shape=(86, 86), dtype=np.float32)})
        self.n_steps = 0
        self.max_steps = 5
        self.render_mode = None

    def step_async(self, actions):
        self.actions = actions

    def step_wait(self):
        self.n_steps += 1
        done = self.n_steps >= self.max_steps
        if done:
            infos = [
                {"terminal_observation": {"rgb": th.zeros((86, 86), dtype=th.float32)}, "TimeLimit.truncated": True}
                for _ in range(self.num_envs)
            ]
        else:
            infos = []
        return (
            {"rgb": th.zeros((self.num_envs, 86, 86), dtype=th.float32)},
            th.zeros((self.num_envs,), dtype=th.float32),
            th.ones((self.num_envs,), dtype=bool) * done,
            infos,
        )

    def reset(self):
        self.n_steps = 0
        return {"rgb": th.zeros((self.num_envs, 86, 86), dtype=th.float32)}

    def render(self, mode=""):
        pass

    def get_attr(self, attr_name, indices=None):
        indices = range(self.num_envs) if indices is None else indices
        return [getattr(self, attr_name) for _ in indices]

    def close(self):
        pass

    def env_is_wrapped(self, wrapper_class, indices=None):
        indices = range(self.num_envs) if indices is None else indices
        return [False for _ in indices]

    def env_method(self):
        raise NotImplementedError  # not used in the test

    def set_attr(self, attr_name, value, indices=None) -> None:
        raise NotImplementedError  # not used in the test


def test_extract_dict_obs():
    """Test VecExtractDictObs"""

    env = DictObsVecEnv()
    env = VecExtractDictObs(env, "rgb")
    assert env.reset().shape == (4, 86, 86)

    for _ in range(10):
        obs, _, dones, infos = env.step([env.action_space.sample() for _ in range(env.num_envs)])
        assert obs.shape == (4, 86, 86)
        for idx, info in enumerate(infos):
            if "terminal_observation" in info:
                assert dones[idx]
                assert info["terminal_observation"].shape == (86, 86)
            else:
                assert not dones[idx]


def test_vec_with_ppo():
    """
    Test the `VecExtractDictObs` with PPO
    """
    env = DictObsVecEnv()
    env = VecExtractDictObs(env, "rgb")
    monitor_env = VecMonitor(env)
    model = PPO("MlpPolicy", monitor_env, verbose=1, n_steps=64, device="cpu")
    model.learn(total_timesteps=250)
