import pytest

import gym
from gym.wrappers import ResizeObservation

# Skip every test in this module when the Atari environments cannot be imported.
pytest.importorskip("gym.envs.atari")


@pytest.mark.parametrize(
    "env_id", ["PongNoFrameskip-v0", "SpaceInvadersNoFrameskip-v0"]
)
@pytest.mark.parametrize("shape", [16, 32, (8, 5), [10, 7]])
def test_resize_observation(env_id, shape):
    """Check that ResizeObservation resizes the space and the observations.

    Args:
        env_id: Id of the environment to build with ``gym.make``.
        shape: Target size passed to the wrapper. An ``int`` is expected to
            yield a square ``(shape, shape)`` result; a tuple or list is
            expected to be used as the first two dimensions as given.

    The trailing dimension of size 3 must be kept both in the observation
    space and in the observation returned by ``reset``.
    """
    env = gym.make(env_id)
    # Close the environment even when an assertion (or the wrapper
    # constructor) fails, so parametrized runs do not leak environments.
    try:
        env = ResizeObservation(env, shape)

        # Normalise the requested size so one set of assertions covers both
        # the int and the sequence form.
        if isinstance(shape, int):
            expected_size = (shape, shape)
        else:
            expected_size = tuple(shape)

        assert env.observation_space.shape[-1] == 3
        obs = env.reset()
        assert env.observation_space.shape[:2] == expected_size
        assert obs.shape == expected_size + (3,)
    finally:
        env.close()
