import numpy as np
import pytest
from stable_baselines3.common.noise import NormalActionNoise

from gsde_custom import SAC

# Shared action-noise fixture built from length-1 arrays: first argument all
# zeros, second argument all 0.1.
# NOTE(review): not referenced by the tests visible here -- confirm whether it
# is still needed elsewhere before removing it.
normal_action_noise = NormalActionNoise(
    np.zeros(1),
    np.full(1, 0.1),
)


@pytest.mark.parametrize("ent_coef", ["auto", 0.01, "auto_0.01"])
def test_sac(ent_coef):
    """Smoke-test a short SAC run on Pendulum-v0 for each ``ent_coef`` value.

    The test makes no assertion on the learned policy: it passes as long as
    building the model and calling ``learn`` complete without raising.

    :param ent_coef: entropy-coefficient setting forwarded unchanged to ``SAC``
        (presumably "auto" / "auto_<init>" follow the Stable-Baselines3
        convention -- TODO confirm against ``gsde_custom``).
    """
    sac_options = {
        "policy_kwargs": {"net_arch": [64, 64]},
        "learning_starts": 100,
        "verbose": 1,
        "create_eval_env": True,
        "ent_coef": ent_coef,
    }
    agent = SAC("MlpPolicy", "Pendulum-v0", **sac_options)
    # eval_freq is below total_timesteps, so evaluation is requested mid-run.
    agent.learn(total_timesteps=300, eval_freq=250)


@pytest.mark.parametrize("n_critics", [1, 3])
def test_n_critics(n_critics):
    """Smoke-test SAC with different numbers of critics.

    Passes as long as model construction and a 300-step ``learn`` call
    complete without raising; nothing about the result is asserted.

    :param n_critics: value placed under ``policy_kwargs["n_critics"]``.
    """
    policy_options = {"net_arch": [64], "n_critics": n_critics}
    agent = SAC(
        "MlpPolicy",
        "Pendulum-v0",
        policy_kwargs=policy_options,
        learning_starts=100,
        verbose=1,
    )
    agent.learn(total_timesteps=300)


@pytest.mark.parametrize("deterministic_exploration", [False, True])
def test_sac_param_noise(deterministic_exploration):
    """Smoke-test SAC with ``use_param_noise=True`` and a layer-norm policy.

    Runs once per ``deterministic_exploration`` value. Passes as long as model
    construction and ``learn`` complete without raising; nothing about the
    result is asserted.

    :param deterministic_exploration: flag forwarded unchanged to ``SAC``.
    """
    sac_options = {
        "policy_kwargs": {"net_arch": [64, 64], "layer_norm": True},
        "learning_starts": 100,
        "verbose": 1,
        "create_eval_env": True,
        "use_param_noise": True,
        "deterministic_exploration": deterministic_exploration,
    }
    agent = SAC("MlpPolicy", "Pendulum-v0", **sac_options)
    # eval_freq is below total_timesteps, so evaluation is requested mid-run.
    agent.learn(total_timesteps=300, eval_freq=250)
