import random
import unittest

import gym
from src.rllib.env.wrappers.exception_wrapper import ResetOnExceptionWrapper, \
    TooManyResetAttemptsException


class TestResetOnExceptionWrapper(unittest.TestCase):
    def test_unstable_env(self):
        class UnstableEnv(gym.Env):
            observation_space = gym.spaces.Discrete(2)
            action_space = gym.spaces.Discrete(2)

            def step(self, action):
                if random.choice([True, False]):
                    raise ValueError("An error from a unstable environment.")
                return self.observation_space.sample(), 0.0, False, {}

            def reset(self):
                return self.observation_space.sample()

        env = UnstableEnv()
        env = ResetOnExceptionWrapper(env)

        try:
            self._run_for_100_steps(env)
        except Exception:
            self.fail()

    def test_very_unstable_env(self):
        class VeryUnstableEnv(gym.Env):
            observation_space = gym.spaces.Discrete(2)
            action_space = gym.spaces.Discrete(2)

            def step(self, action):
                return self.observation_space.sample(), 0.0, False, {}

            def reset(self):
                raise ValueError("An error from a very unstable environment.")

        env = VeryUnstableEnv()
        env = ResetOnExceptionWrapper(env)
        self.assertRaises(TooManyResetAttemptsException,
                          lambda: self._run_for_100_steps(env))

    @staticmethod
    def _run_for_100_steps(env):
        env.reset()
        for _ in range(100):
            env.step(env.action_space.sample())


if __name__ == "__main__":
    import sys
    import pytest
    sys.exit(pytest.main(["-v", __file__]))
