"""Tests for the filter observation wrapper."""
from typing import Optional

import numpy as np
import pytest

import gymnasium as gym
from gymnasium.spaces import Box, Dict, Tuple
from gymnasium.wrappers import FilterObservation, FlattenObservation


class FakeEnvironment(gym.Env):
    """Minimal environment that emits random samples of a fixed observation space.

    Args:
        observation_space: Space sampled for every observation. It must expose
            ``.spaces.keys()`` (as ``Dict`` does); those keys are kept on
            ``obs_keys``.
        render_mode: Stored on the instance as ``render_mode``; not otherwise
            used by this class.
    """

    def __init__(self, observation_space, render_mode=None):
        self.observation_space = observation_space
        self.obs_keys = self.observation_space.spaces.keys()
        self.action_space = Box(shape=(1,), low=-1, high=1, dtype=np.float32)
        self.render_mode = render_mode

    def render(self, mode="human"):
        """Return an all-zero 32x32x3 ``uint8`` image; ``mode`` is ignored."""
        image_shape = (32, 32, 3)
        return np.zeros(image_shape, dtype=np.uint8)

    def reset(self, *, seed: Optional[int] = None, options: Optional[dict] = None):
        """Forward ``seed`` to ``gym.Env.reset`` and return ``(observation, info)``.

        ``options`` is accepted but unused; ``info`` is always an empty dict.
        """
        super().reset(seed=seed)
        observation = self.observation_space.sample()
        return observation, {}

    def step(self, action):
        """Ignore ``action`` and return a randomly sampled observation.

        Returns:
            ``(observation, reward, terminated, truncated, info)`` with a zero
            reward, both end-of-episode flags ``False`` and an empty ``info``.
        """
        del action
        observation = self.observation_space.sample()
        # Gymnasium's step API has five elements; wrappers unpack all five, so a
        # four-element return would fail as soon as the env is stepped through one.
        reward, terminated, truncated, info = 0.0, False, False, {}
        return observation, reward, terminated, truncated, info


def _unit_box(shape):
    """Build a fresh float32 ``Box`` of the given shape with bounds [-1, 1]."""
    return Box(shape=shape, low=-1, high=1, dtype=np.float32)


# Pairs of (Dict observation space, shape expected once it has been flattened).
NESTED_DICT_TEST_CASES = (
    # A Dict nested inside a Dict: expected flat size 2 + (2 + 2) = 6.
    (
        Dict(
            {
                "key1": _unit_box((2,)),
                "key2": Dict({"subkey1": _unit_box((2,)), "subkey2": _unit_box((2,))}),
            }
        ),
        (6,),
    ),
    # Flat Dict mixing a 2-D, a scalar and a 1-D Box: expected 2 * 3 + 1 + 2 = 9.
    (
        Dict(
            {
                "key1": _unit_box((2, 3)),
                "key2": _unit_box(()),
                "key3": _unit_box((2,)),
            }
        ),
        (9,),
    ),
    # A two-element Tuple inside a Dict: expected (2 + 2) + 1 + 2 = 7.
    (
        Dict(
            {
                "key1": Tuple((_unit_box((2,)), _unit_box((2,)))),
                "key2": _unit_box(()),
                "key3": _unit_box((2,)),
            }
        ),
        (7,),
    ),
    # A single-element Tuple inside a Dict: expected 2 + 1 + 2 = 5.
    (
        Dict(
            {
                "key1": Tuple((_unit_box((2,)),)),
                "key2": _unit_box(()),
                "key3": _unit_box((2,)),
            }
        ),
        (5,),
    ),
    # A Dict inside a Tuple inside a Dict: expected 2 + 1 + 2 = 5.
    (
        Dict(
            {
                "key1": Tuple((Dict({"key9": _unit_box((2,))}),)),
                "key2": _unit_box(()),
                "key3": _unit_box((2,)),
            }
        ),
        (5,),
    ),
)


class TestNestedDictWrapper:
    """Filter every key of a nested ``Dict`` space, flatten it, and check the result."""

    @pytest.mark.parametrize("observation_space, flat_shape", NESTED_DICT_TEST_CASES)
    def test_nested_dicts_size(self, observation_space, flat_shape):
        """The flattened observation space has the expected shape and float32 dtype."""
        base_env = FakeEnvironment(observation_space=observation_space)

        # Guard against a mis-specified case: the env must expose a Dict space.
        observation_space = base_env.observation_space
        assert isinstance(observation_space, Dict)

        filtered_env = FilterObservation(base_env, list(base_env.obs_keys))
        flat_env = FlattenObservation(filtered_env)

        assert flat_env.observation_space.shape == flat_shape
        assert flat_env.observation_space.dtype == np.float32

    @pytest.mark.parametrize("observation_space, flat_shape", NESTED_DICT_TEST_CASES)
    def test_nested_dicts_ravel(self, observation_space, flat_shape):
        """An observation from ``reset`` has the shape of the flattened space."""
        base_env = FakeEnvironment(observation_space=observation_space)
        kept_keys = list(base_env.obs_keys)
        flat_env = FlattenObservation(FilterObservation(base_env, kept_keys))

        first_obs, reset_info = flat_env.reset()

        assert first_obs.shape == flat_env.observation_space.shape
        assert isinstance(reset_info, dict)
