import os

import pytest

from stable_baselines3 import A2C, PPO, SAC, TD3

MODEL_DICT = {
    "a2c": (A2C, "CartPole-v1"),
    "ppo": (PPO, "CartPole-v1"),
    "sac": (SAC, "Pendulum-v1"),
    "td3": (TD3, "Pendulum-v1"),
}

N_STEPS = 100


@pytest.mark.parametrize("model_name", MODEL_DICT.keys())
def test_tensorboard(tmp_path, model_name):
    # Skip if no tensorboard installed
    pytest.importorskip("tensorboard")

    logname = model_name.upper()
    algo, env_id = MODEL_DICT[model_name]
    model = algo("MlpPolicy", env_id, verbose=1, tensorboard_log=tmp_path)
    model.learn(N_STEPS)
    model.learn(N_STEPS, reset_num_timesteps=False)

    assert os.path.isdir(tmp_path / str(logname + "_1"))
    assert not os.path.isdir(tmp_path / str(logname + "_2"))

    logname = "tb_multiple_runs_" + model_name
    model.learn(N_STEPS, tb_log_name=logname)
    model.learn(N_STEPS, tb_log_name=logname)

    assert os.path.isdir(tmp_path / str(logname + "_1"))
    # Check that the log dir name increments correctly
    assert os.path.isdir(tmp_path / str(logname + "_2"))
