"""Tests for the pixel observation wrapper."""
from typing import Optional

import numpy as np
import pytest

import gymnasium as gym
from gymnasium import spaces
from gymnasium.wrappers.pixel_observation import STATE_KEY, PixelObservationWrapper


class FakeEnvironment(gym.Env):
    def __init__(self, render_mode="single_rgb_array"):
        self.action_space = spaces.Box(shape=(1,), low=-1, high=1, dtype=np.float32)
        self.render_mode = render_mode

    def render(self, mode="human", width=32, height=32):
        image_shape = (height, width, 3)
        return np.zeros(image_shape, dtype=np.uint8)

    def reset(self, *, seed: Optional[int] = None, options: Optional[dict] = None):
        super().reset(seed=seed)
        observation = self.observation_space.sample()
        return observation, {}

    def step(self, action):
        del action
        observation = self.observation_space.sample()
        reward, terminal, info = 0.0, False, {}
        return observation, reward, terminal, info


class FakeArrayObservationEnvironment(FakeEnvironment):
    def __init__(self, *args, **kwargs):
        self.observation_space = spaces.Box(
            shape=(2,), low=-1, high=1, dtype=np.float32
        )
        super().__init__(*args, **kwargs)


class FakeDictObservationEnvironment(FakeEnvironment):
    def __init__(self, *args, **kwargs):
        self.observation_space = spaces.Dict(
            {
                "state": spaces.Box(shape=(2,), low=-1, high=1, dtype=np.float32),
            }
        )
        super().__init__(*args, **kwargs)


@pytest.mark.parametrize("pixels_only", (True, False))
def test_dict_observation(pixels_only):
    pixel_key = "rgb"

    env = FakeDictObservationEnvironment()

    # Make sure we are testing the right environment for the test.
    observation_space = env.observation_space
    assert isinstance(observation_space, spaces.Dict)

    width, height = (320, 240)

    # The wrapper should only add one observation.
    wrapped_env = PixelObservationWrapper(
        env,
        pixel_keys=(pixel_key,),
        pixels_only=pixels_only,
        render_kwargs={pixel_key: {"width": width, "height": height}},
    )

    assert isinstance(wrapped_env.observation_space, spaces.Dict)

    if pixels_only:
        assert len(wrapped_env.observation_space.spaces) == 1
        assert list(wrapped_env.observation_space.spaces.keys()) == [pixel_key]
    else:
        assert (
            len(wrapped_env.observation_space.spaces)
            == len(observation_space.spaces) + 1
        )
        expected_keys = list(observation_space.spaces.keys()) + [pixel_key]
        assert list(wrapped_env.observation_space.spaces.keys()) == expected_keys

    # Check that the added space item is consistent with the added observation.
    observation, info = wrapped_env.reset()
    rgb_observation = observation[pixel_key]

    assert isinstance(info, dict)
    assert rgb_observation.shape == (height, width, 3)
    assert rgb_observation.dtype == np.uint8


@pytest.mark.parametrize("pixels_only", (True, False))
def test_single_array_observation(pixels_only):
    pixel_key = "depth"

    env = FakeArrayObservationEnvironment()
    observation_space = env.observation_space
    assert isinstance(observation_space, spaces.Box)

    wrapped_env = PixelObservationWrapper(
        env, pixel_keys=(pixel_key,), pixels_only=pixels_only
    )
    wrapped_env.observation_space = wrapped_env.observation_space
    assert isinstance(wrapped_env.observation_space, spaces.Dict)

    if pixels_only:
        assert len(wrapped_env.observation_space.spaces) == 1
        assert list(wrapped_env.observation_space.spaces.keys()) == [pixel_key]
    else:
        assert len(wrapped_env.observation_space.spaces) == 2
        assert list(wrapped_env.observation_space.spaces.keys()) == [
            STATE_KEY,
            pixel_key,
        ]

    observation, info = wrapped_env.reset()
    depth_observation = observation[pixel_key]

    assert isinstance(info, dict)
    assert depth_observation.shape == (32, 32, 3)
    assert depth_observation.dtype == np.uint8

    if not pixels_only:
        assert isinstance(observation[STATE_KEY], np.ndarray)
