"""Test suite for DelayObservationV0."""
import re

import pytest

import gymnasium as gym
from gymnasium.experimental.wrappers import DelayObservationV0
from gymnasium.experimental.wrappers.utils import create_zero_array
from gymnasium.utils.env_checker import data_equivalence
from tests.experimental.wrappers.utils import (
    SEED,
    TESTING_OBS_ENVS,
    TESTING_OBS_ENVS_IDS,
)


@pytest.mark.parametrize("env", TESTING_OBS_ENVS, ids=TESTING_OBS_ENVS_IDS)
def test_env_obs(env, delay: int = 3, extra_steps: int = 4):
    """Compare a delayed rollout against an undelayed rollout of the same env.

    The env is first rolled out unwrapped, then wrapped in
    ``DelayObservationV0`` and rolled out again using the same seeds. The
    wrapped rollout is expected to emit zero-filled observations until the
    delay has elapsed, and afterwards the unwrapped observations shifted by
    ``delay`` steps.

    Args:
        env: The environment under test, supplied by the parametrization.
        delay: The observation delay handed to the wrapper.
        extra_steps: How many steps beyond ``delay`` each rollout takes.
    """
    total_steps = delay + extra_steps

    def take_step(target):
        """Step ``target`` with a sampled action and return the observation."""
        step_obs, _reward, _terminated, _truncated, _info = target.step(
            target.action_space.sample()
        )
        return step_obs

    # Reference rollout without any delay.
    env.action_space.seed(SEED)
    obs, _info = env.reset(seed=SEED)
    baseline = [obs]
    baseline.extend(take_step(env) for _ in range(total_steps))

    # Same rollout again, this time through the delay wrapper.
    wrapped = DelayObservationV0(env, delay=delay)
    zero_obs = create_zero_array(wrapped.observation_space)
    wrapped.action_space.seed(SEED)
    obs, _info = wrapped.reset(seed=SEED)
    assert data_equivalence(obs, zero_obs)
    assert obs in wrapped.observation_space

    history = [obs]
    for step_index in range(total_steps):
        obs = take_step(wrapped)
        history.append(obs)
        assert obs in wrapped.observation_space

        # Together with the reset output, the first ``delay`` outputs must
        # still be the zero placeholder.
        if step_index + 1 < delay:
            assert data_equivalence(obs, zero_obs)

    # Once the delay has passed, outputs are the baseline shifted by ``delay``.
    assert data_equivalence(history[delay:], baseline[:-delay])


@pytest.mark.parametrize("delay", [1, 2, 3, 4])
def test_delay_values(delay):
    """Check when the first real observation surfaces for several delays.

    After a reset, the wrapped env must return the zero placeholder for the
    reset output and the following ``delay - 1`` steps; the ``delay``-th step
    must then return the observation an unwrapped env yields on reset with the
    same seed.
    """
    # Reference: first observation of an unwrapped env with the same seed.
    reference_env = gym.make("CartPole-v1")
    expected_first_obs, _info = reference_env.reset(seed=123)

    delayed_env = DelayObservationV0(gym.make("CartPole-v1"), delay=delay)
    placeholder = create_zero_array(delayed_env.observation_space)

    current_obs, _info = delayed_env.reset(seed=123)
    assert data_equivalence(current_obs, placeholder)

    # Still inside the delay window: only placeholders are expected.
    for _ in range(delay - 1):
        current_obs, _reward, _terminated, _truncated, _info = delayed_env.step(
            delayed_env.action_space.sample()
        )
        assert data_equivalence(current_obs, placeholder)

    # The delay has elapsed: the original reset observation must appear now.
    current_obs, _reward, _terminated, _truncated, _info = delayed_env.step(
        delayed_env.action_space.sample()
    )
    assert data_equivalence(expected_first_obs, current_obs)


def test_delay_failures():
    """Check the errors raised by DelayObservationV0 for invalid delays.

    A float delay must raise ``TypeError`` and a negative delay must raise
    ``ValueError``, each with the exact message listed below.
    """
    env = gym.make("CartPole-v1")

    # (invalid delay, expected exception type, expected message)
    invalid_cases = (
        (
            1.0,
            TypeError,
            "The delay is expected to be an integer, actual type: <class 'float'>",
        ),
        (
            -1,
            ValueError,
            "The delay needs to be greater than zero, actual value: -1",
        ),
    )

    for bad_delay, expected_error, expected_message in invalid_cases:
        # The messages contain regex metacharacters, so escape them for ``match``.
        with pytest.raises(expected_error, match=re.escape(expected_message)):
            DelayObservationV0(env, delay=bad_delay)
