from typing import Optional

import gym
import numpy as np
from numpy.testing import assert_almost_equal
import pytest

from gym.wrappers.normalize import NormalizeObservation, NormalizeReward


class DummyRewardEnv(gym.Env):
    """Deterministic toy environment driven by a single step counter.

    The counter ``t`` is set to ``return_reward_idx`` on construction and on
    every reset. Each step increments it and reports the new value as both the
    observation and the reward, which makes the running statistics computed by
    the normalization wrappers easy to predict by hand.
    """

    metadata = {}

    def __init__(self, return_reward_idx=0):
        """Set up the spaces and start the counter at ``return_reward_idx``."""
        self.return_reward_idx = return_reward_idx
        self.returned_rewards = [0, 1, 2, 3, 4]
        self.t = return_reward_idx

        self.action_space = gym.spaces.Discrete(2)
        # NOTE: step()/reset() return the raw counter and never clip it to
        # these bounds.
        self.observation_space = gym.spaces.Box(
            low=np.array([-1.0]), high=np.array([1.0])
        )

    def step(self, action):
        """Advance the counter by one; ``action`` is ignored.

        Returns:
            ``(observation, reward, done, info)`` where the observation is the
            counter wrapped in a one-element array, the reward is the counter
            itself, ``done`` is true once the counter equals
            ``len(self.returned_rewards)``, and ``info`` is an empty dict.
        """
        self.t += 1
        counter = self.t
        episode_over = counter == len(self.returned_rewards)
        return np.array([counter]), counter, episode_over, {}

    def reset(
        self,
        *,
        seed: Optional[int] = None,
        return_info: Optional[bool] = False,
        options: Optional[dict] = None
    ):
        """Rewind the counter to ``return_reward_idx``.

        Returns:
            The initial observation, or ``(observation, {})`` when
            ``return_info`` is truthy. ``options`` is accepted but unused.
        """
        super().reset(seed=seed)
        self.t = self.return_reward_idx
        initial_obs = np.array([self.t])
        if return_info:
            return initial_obs, {}
        return initial_obs


def make_env(return_reward_idx):
    """Return a zero-argument factory that builds a ``DummyRewardEnv``.

    The environment is only constructed when the returned callable is
    invoked, with its counter starting at ``return_reward_idx``.
    """

    def _factory():
        return DummyRewardEnv(return_reward_idx)

    return _factory


@pytest.mark.parametrize("return_info", [False, True])
def test_normalize_observation(return_info: bool):
    """Check the running observation mean after a reset and two steps."""
    wrapped = NormalizeObservation(DummyRewardEnv(return_reward_idx=0))
    wrapped.reset(return_info=return_info)

    # The dummy env emits 0 on reset, then 1 and 2 on the following steps.
    # The expected means below correspond to averaging [0, 1] and [0, 1, 2].
    for expected_mean in (0.5, 1.0):
        wrapped.step(wrapped.action_space.sample())
        assert_almost_equal(wrapped.obs_rms.mean, expected_mean, decimal=4)


def test_normalize_reset_info():
    """Check the shape of ``reset``'s result with and without ``return_info``."""
    wrapped = NormalizeObservation(DummyRewardEnv(return_reward_idx=0))

    # Omitting return_info and passing False must both yield a bare array.
    for reset_kwargs in ({}, {"return_info": False}):
        result = wrapped.reset(**reset_kwargs)
        assert isinstance(result, np.ndarray)

    # With return_info=True the result is an (observation, info) pair.
    observation, info = wrapped.reset(return_info=True)
    assert isinstance(observation, np.ndarray)
    assert isinstance(info, dict)


def test_normalize_return():
    """Check the running mean of discounted returns over two steps."""
    wrapped = NormalizeReward(DummyRewardEnv(return_reward_idx=0))
    wrapped.reset()

    # First step: the dummy env pays a reward of 1, so the return is 1.
    wrapped.step(wrapped.action_space.sample())
    first_return = 1
    assert_almost_equal(
        wrapped.return_rms.mean,
        np.mean([first_return]),
        decimal=4,
    )

    # Second step: reward 2 plus the discounted previous return.
    wrapped.step(wrapped.action_space.sample())
    second_return = 2 + wrapped.gamma * first_return
    assert_almost_equal(
        wrapped.return_rms.mean,
        np.mean([second_return, first_return]),
        decimal=4,
    )


def test_normalize_observation_vector_env():
    """Check observation statistics when wrapping a two-env vector env."""
    # Sanity-check the unwrapped vector env: counters start at 0 and 1, so
    # the first step yields observations/rewards of 1 and 2.
    plain_envs = gym.vector.SyncVectorEnv([make_env(0), make_env(1)])
    plain_envs.reset()
    raw_obs, raw_reward, _, _ = plain_envs.step(plain_envs.action_space.sample())
    np.testing.assert_almost_equal(raw_obs, np.array([[1], [2]]), decimal=4)
    np.testing.assert_almost_equal(raw_reward, np.array([1, 2]), decimal=4)

    wrapped = NormalizeObservation(
        gym.vector.SyncVectorEnv([make_env(0), make_env(1)])
    )

    # After reset the statistics cover the initial observations only.
    wrapped.reset()
    reset_observations = [0, 1]
    assert_almost_equal(
        wrapped.obs_rms.mean,
        np.mean(reset_observations),  # 0.5
        decimal=4,
    )

    # After one step the statistics cover both batches of observations.
    _, _, _, _ = wrapped.step(wrapped.action_space.sample())
    step_observations = [1, 2]
    assert_almost_equal(
        wrapped.obs_rms.mean,
        np.mean([reset_observations, step_observations]),  # 1.0
        decimal=4,
    )


def test_normalize_return_vector_env():
    """Check discounted-return statistics when wrapping a two-env vector env."""
    wrapped = NormalizeReward(
        gym.vector.SyncVectorEnv([make_env(0), make_env(1)])
    )
    wrapped.reset()

    # First step: the sub-envs pay rewards 1 and 2, which are also the returns.
    _, _, _, _ = wrapped.step(wrapped.action_space.sample())
    first_returns = [1, 2]
    assert_almost_equal(
        wrapped.return_rms.mean,
        np.mean(first_returns),  # 1.5
        decimal=4,
    )

    # Second step: rewards 2 and 3, each added to the discounted first return.
    _, _, _, _ = wrapped.step(wrapped.action_space.sample())
    discount = wrapped.gamma
    second_returns = [2 + discount * 1, 3 + discount * 2]
    assert_almost_equal(
        wrapped.return_rms.mean,
        np.mean([first_returns, second_returns]),
        decimal=4,
    )
