import unittest
import numpy as np
from gym import envs
from gym.envs.tests.spec_list import skip_mujoco, SKIP_MUJOCO_WARNING_MESSAGE


def verify_environments_match(
    old_environment_id, new_environment_id, seed=1, num_actions=1000
):
    old_environment = envs.make(old_environment_id)
    new_environment = envs.make(new_environment_id)

    old_environment.seed(seed)
    new_environment.seed(seed)

    old_reset_observation = old_environment.reset()
    new_reset_observation = new_environment.reset()

    np.testing.assert_allclose(old_reset_observation, new_reset_observation)

    for i in range(num_actions):
        action = old_environment.action_space.sample()
        old_observation, old_reward, old_done, old_info = old_environment.step(action)
        new_observation, new_reward, new_done, new_info = new_environment.step(action)

        eps = 1e-6
        np.testing.assert_allclose(old_observation, new_observation, atol=eps)
        np.testing.assert_allclose(old_reward, new_reward, atol=eps)
        np.testing.assert_allclose(old_done, new_done, atol=eps)

        for key in old_info:
            np.testing.assert_allclose(old_info[key], new_info[key], atol=eps)


@unittest.skipIf(skip_mujoco, SKIP_MUJOCO_WARNING_MESSAGE)
class Mujocov2Tov3ConversionTest(unittest.TestCase):
    def test_environments_match(self):
        test_cases = (
            {"old_id": "Swimmer-v2", "new_id": "Swimmer-v3"},
            {"old_id": "Hopper-v2", "new_id": "Hopper-v3"},
            {"old_id": "Walker2d-v2", "new_id": "Walker2d-v3"},
            {"old_id": "HalfCheetah-v2", "new_id": "HalfCheetah-v3"},
            {"old_id": "Ant-v2", "new_id": "Ant-v3"},
            {"old_id": "Humanoid-v2", "new_id": "Humanoid-v3"},
        )

        for test_case in test_cases:
            verify_environments_match(test_case["old_id"], test_case["new_id"])

        # Raises KeyError because the new envs have extra info
        with self.assertRaises(KeyError):
            verify_environments_match("Swimmer-v3", "Swimmer-v2")

        # Raises KeyError because the new envs have extra info
        with self.assertRaises(KeyError):
            verify_environments_match("Humanoid-v3", "Humanoid-v2")

        # Raises KeyError because the new envs have extra info
        with self.assertRaises(KeyError):
            verify_environments_match("Swimmer-v3", "Swimmer-v2")


if __name__ == "__main__":
    unittest.main()
