import unittest
from contextlib import closing

import numpy as np
from gym import envs
from gym.envs.tests.spec_list import skip_mujoco, SKIP_MUJOCO_WARNING_MESSAGE


def verify_environments_match(old_environment_id,
                              new_environment_id,
                              seed=1,
                              num_actions=1000,
                              atol=1e-6):
    """Assert that two registered environments produce matching rollouts.

    Both environments are created with ``envs.make``, seeded with ``seed``
    and reset. Then ``num_actions`` actions sampled from the old
    environment's action space are applied to both. Observations, rewards,
    done flags and info values are compared after the reset and after
    every step.

    Only the keys present in the *old* environment's ``info`` dict are
    checked, so the new environment may report additional info keys.
    The reverse does not hold: see ``Raises``.

    Both environments are closed before returning, including when a
    comparison fails.

    Args:
        old_environment_id: Registered id of the reference environment.
        new_environment_id: Registered id of the environment under test.
        seed: Seed applied to both environments and to the action space
            used for sampling, so the action sequence is reproducible.
        num_actions: Number of steps to take in each environment.
        atol: Absolute tolerance for the numeric comparisons.

    Raises:
        AssertionError: If an observation, reward, done flag or shared
            info value differs beyond tolerance.
        KeyError: If the old environment's ``info`` contains a key that
            the new environment's ``info`` lacks.
    """
    # closing() releases both simulators even when an assertion or the
    # KeyError described above aborts the comparison early. If creating
    # the new environment fails, the old one is still closed.
    with closing(envs.make(old_environment_id)) as old_environment:
        with closing(envs.make(new_environment_id)) as new_environment:
            _compare_rollouts(old_environment, new_environment,
                              seed=seed, num_actions=num_actions, atol=atol)


def _compare_rollouts(old_environment, new_environment, seed, num_actions,
                      atol):
    """Seed, reset and step both environments in lockstep, asserting equality."""
    old_environment.seed(seed)
    new_environment.seed(seed)
    # Seed the sampler too, so a failure can be reproduced from ``seed``.
    old_environment.action_space.seed(seed)

    old_reset_observation = old_environment.reset()
    new_reset_observation = new_environment.reset()
    # Use the same tolerance as the per-step comparisons below.
    np.testing.assert_allclose(old_reset_observation, new_reset_observation,
                               atol=atol)

    for _ in range(num_actions):
        action = old_environment.action_space.sample()
        old_observation, old_reward, old_done, old_info = old_environment.step(
            action)
        new_observation, new_reward, new_done, new_info = new_environment.step(
            action)

        np.testing.assert_allclose(old_observation, new_observation, atol=atol)
        np.testing.assert_allclose(old_reward, new_reward, atol=atol)
        # Done flags are booleans, so require exact equality.
        np.testing.assert_equal(old_done, new_done)

        # Iterate the old env's keys only. A key missing from new_info
        # raises KeyError, which the reverse-direction tests rely on.
        for key, old_value in old_info.items():
            np.testing.assert_allclose(old_value, new_info[key], atol=atol)


@unittest.skipIf(skip_mujoco, SKIP_MUJOCO_WARNING_MESSAGE)
class Mujocov2Tov3ConversionTest(unittest.TestCase):
    """Check that the MuJoCo v3 environments reproduce their v2 versions."""

    # (v2 id, v3 id) pairs. Each v3 environment, built with default
    # arguments, is expected to match its v2 predecessor step for step.
    ENVIRONMENT_ID_PAIRS = (
        ('Swimmer-v2', 'Swimmer-v3'),
        ('Hopper-v2', 'Hopper-v3'),
        ('Walker2d-v2', 'Walker2d-v3'),
        ('HalfCheetah-v2', 'HalfCheetah-v3'),
        ('Ant-v2', 'Ant-v3'),
        ('Humanoid-v2', 'Humanoid-v3'),
    )

    def test_environments_match(self):
        """v3 matches v2; swapping the roles fails on v3's extra info keys."""
        for old_id, new_id in self.ENVIRONMENT_ID_PAIRS:
            with self.subTest(old_id=old_id, new_id=new_id):
                verify_environments_match(old_id, new_id)

        # Swapping the roles raises KeyError because the new (v3) envs
        # have extra info keys that the v2 envs do not report. Every pair
        # is checked; previously the Swimmer case was run twice.
        for v2_id, v3_id in self.ENVIRONMENT_ID_PAIRS:
            with self.subTest(old_id=v3_id, new_id=v2_id):
                with self.assertRaises(KeyError):
                    verify_environments_match(v3_id, v2_id)


if __name__ == "__main__":
    # Run this module's test cases when it is executed as a script.
    unittest.main()
