import unittest

import ray
import agents.impala as impala
import agents.pg as pg
from utils.error import EnvError
from utils.test_utils import framework_iterator


class TestErrors(unittest.TestCase):
    """Tests various failure-modes, making sure we produce meaningful errmsgs."""

    @classmethod
    def setUpClass(cls) -> None:
        ray.init()

    @classmethod
    def tearDownClass(cls) -> None:
        ray.shutdown()

    def test_no_gpus_error(self):
        """Tests errors related to no-GPU/too-few GPUs/etc.

        This test will only work ok on a CPU-only machine.
        """

        config = impala.DEFAULT_CONFIG.copy()
        env = "CartPole-v0"

        for _ in framework_iterator(config):
            self.assertRaisesRegex(
                RuntimeError,
                # (?s): "dot matches all" (also newlines).
                "(?s)Found 0 GPUs on your machine.+To change the config",
                lambda: impala.ImpalaTrainer(config=config, env=env),
            )

    def test_bad_envs(self):
        """Tests different "bad env" errors."""
        config = pg.DEFAULT_CONFIG.copy()
        config["num_workers"] = 0

        # Non existing/non-registered gym env string.
        env = "Alien-Attack-v42"
        for _ in framework_iterator(config):
            self.assertRaisesRegex(
                EnvError,
                f"The env string you provided \\('{env}'\\) is",
                lambda: pg.PGTrainer(config=config, env=env),
            )

        # Malformed gym env string (must have v\d at end).
        env = "Alien-Attack-part-42"
        for _ in framework_iterator(config):
            self.assertRaisesRegex(
                EnvError,
                f"The env string you provided \\('{env}'\\) is",
                lambda: pg.PGTrainer(config=config, env=env),
            )

        # Non-existing class in a full-class-path.
        env = "examples.env.random_env.RandomEnvThatDoesntExist"
        for _ in framework_iterator(config):
            self.assertRaisesRegex(
                EnvError,
                f"The env string you provided \\('{env}'\\) is",
                lambda: pg.PGTrainer(config=config, env=env),
            )

        # Non-existing module inside a full-class-path.
        env = "examples.env.module_that_doesnt_exist.SomeEnv"
        for _ in framework_iterator(config):
            self.assertRaisesRegex(
                EnvError,
                f"The env string you provided \\('{env}'\\) is",
                lambda: pg.PGTrainer(config=config, env=env),
            )


if __name__ == "__main__":
    import pytest
    import sys

    sys.exit(pytest.main(["-v", __file__]))
