import os
import shutil
import subprocess
import sys
import unittest


class TestAtari(unittest.TestCase):
    def __init__(self, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        self.base_args = (
            "--seed 1 --disable_wandb --features 2 3 1 15 --replay_buffer_capacity 100 --batch_size 3 --update_horizon 1 "
            + "--horizon 10 --n_epochs 3 --n_training_steps_per_epoch 5 --update_to_data 3 "
            + "--target_update_frequency 3 --n_initial_samples 3 --epsilon_end 0.01 --epsilon_duration 4"
        )

    def run_core_test(self, algo_name, algo_args):
        save_path = os.path.join(
            os.path.dirname(os.path.abspath(__file__)), f"../experiments/atari/exp_output/_test_{algo_name}_Pong"
        )
        if os.path.exists(save_path):
            shutil.rmtree(save_path)

        returncode = subprocess.run(
            (
                f"python3 experiments/atari/{algo_name}.py --experiment_name _test_{algo_name}_Pong {self.base_args} {algo_args}"
            ).split(" ")
        ).returncode
        assert returncode == 0, "The command should not have raised an error."

        shutil.rmtree(save_path)

    def test_dqn(self):
        self.run_core_test("dqn", "--gamma 0.9")

    def test_metadqn(self):
        self.run_core_test("metadqn", "--gamma_init 0.99 --gamma_validation 0.995 --meta_learning_rate 0.001")

    def test_adadqn(self):
        self.run_core_test(
            "adadqn",
            "--gamma_range 0.1 1 --gamma_validation 0.995 --n_networks 5 --exploitation_type elitism --hp_update_frequency 3",
        )
