"""Tests for the flatten observation wrapper."""

from collections import OrderedDict

import numpy as np
import pytest

import gym
from gym.spaces import Box, Dict, unflatten, flatten
from gym.wrappers import FlattenObservation


class FakeEnvironment(gym.Env):
    """Bare-bones environment that only knows how to sample observations.

    ``reset`` draws one sample from ``observation_space`` and keeps it on
    ``self.observation`` so a test can compare it with whatever a wrapper
    returned for the same reset.
    """

    def __init__(self, observation_space):
        self.observation_space = observation_space

    def reset(self):
        """Draw a fresh sample, remember it, and return it."""
        sample = self.observation_space.sample()
        self.observation = sample
        return sample


# Parametrization cases, each an ``(observation_space, ordered_values)`` pair.
#
# When ``ordered_values`` is True the tests additionally assert that the
# flattened observation is already sorted. Those cases therefore use Boxes with
# ``low == high`` (presumably pinning every sampled entry to that constant) and
# assign the constants 0, 1, 2 in the OrderedDict's insertion order.
OBSERVATION_SPACES = (
    # Insertion order key1, key2, key3 with constants 0, 1, 2.
    (
        Dict(
            OrderedDict(
                [
                    ("key1", Box(shape=(2, 3), low=0, high=0, dtype=np.float32)),
                    ("key2", Box(shape=(), low=1, high=1, dtype=np.float32)),
                    ("key3", Box(shape=(2,), low=2, high=2, dtype=np.float32)),
                ]
            )
        ),
        True,
    ),
    # Same sub-space shapes, but inserted as key2, key3, key1: the constants
    # still ascend in insertion order, not in alphabetical key order.
    (
        Dict(
            OrderedDict(
                [
                    ("key2", Box(shape=(), low=0, high=0, dtype=np.float32)),
                    ("key3", Box(shape=(2,), low=1, high=1, dtype=np.float32)),
                    ("key1", Box(shape=(2, 3), low=2, high=2, dtype=np.float32)),
                ]
            )
        ),
        True,
    ),
    # Plain dict with bounds [-1, 1] on every sub-space; ordered_values is
    # False, so only the flatten/unflatten round trip is checked for this case.
    (
        Dict(
            {
                "key1": Box(shape=(2, 3), low=-1, high=1, dtype=np.float32),
                "key2": Box(shape=(), low=-1, high=1, dtype=np.float32),
                "key3": Box(shape=(2,), low=-1, high=1, dtype=np.float32),
            }
        ),
        False,
    ),
)


class TestFlattenEnvironment(object):
    """Round-trip and ordering checks for flattening Dict observations."""

    @pytest.mark.parametrize("observation_space, ordered_values", OBSERVATION_SPACES)
    def test_flattened_environment(self, observation_space, ordered_values):
        """
        the FlattenObservation wrapper must emit values in the expected order
        """
        env = FakeEnvironment(observation_space=observation_space)
        flattened = FlattenObservation(env).reset()

        # Rebuild the dict observation from the flat vector, then compare it
        # with the sample the fake environment recorded during reset().
        restored = unflatten(env.observation_space, flattened)

        self._check_observations(
            env.observation, flattened, restored, ordered_values
        )

    @pytest.mark.parametrize("observation_space, ordered_values", OBSERVATION_SPACES)
    def test_flatten_unflatten(self, observation_space, ordered_values):
        """
        exercise flatten and unflatten directly, without the wrapper
        """
        original = observation_space.sample()
        flattened = flatten(observation_space, original)

        self._check_observations(
            original,
            flattened,
            unflatten(observation_space, flattened),
            ordered_values,
        )

    def _check_observations(self, original, flattened, unflattened, ordered_values):
        """Assert the round trip is lossless and, optionally, that order is kept.

        ``original`` and ``unflattened`` are compared key by key;
        ``flattened`` is only inspected when ``ordered_values`` is truthy.
        """
        # unflatten(flatten(original)) has to reproduce original exactly:
        # same key set, and numerically matching values under every key.
        assert set(unflattened.keys()) == set(original.keys())
        for key in original:
            np.testing.assert_allclose(unflattened[key], original[key])

        if not ordered_values:
            return

        # The ordered cases are built so that values ascend in insertion
        # order; an already-sorted flat vector means that order was kept.
        np.testing.assert_allclose(sorted(flattened), flattened)
