from hydra import initialize, compose
import train_self_supervised_classifier
import pytest


def test_config() -> None:
    """Compose the ``mmnist_vivi_linear_classifier`` experiment on top of
    ``self_supervised_defaults`` (with ``mode=local_test``) and check that the
    resulting ``project`` entry is ``"tmp"``."""
    experiment_overrides = [
        "+experiment=mmnist_vivi_linear_classifier",
        "mode=local_test",
    ]
    # ``config_path`` is resolved relative to the directory of this test file.
    with initialize(config_path="../config"):
        composed = compose(
            config_name="self_supervised_defaults",
            overrides=experiment_overrides,
        )
        assert composed["project"] == "tmp"


def test_config_reload_dataloaders_flag() -> None:
    """Check the dataloader-related trainer flags of the composed
    ``mmnist_resnet_2d_classifier`` experiment (``mode=local_test``).

    Expects ``reload_dataloaders_every_n_epochs`` to be ``1`` and
    ``replace_sampler_ddp`` to be ``False``.
    """
    # ``config_path`` is resolved relative to the directory of this test file.
    with initialize(config_path="../config"):
        composed = compose(
            config_name="classifier_defaults",
            overrides=[
                "+experiment=mmnist_resnet_2d_classifier",
                "mode=local_test",
            ],
        )
        trainer_cfg = composed["trainer"]
        assert trainer_cfg["reload_dataloaders_every_n_epochs"] == 1
        assert trainer_cfg["replace_sampler_ddp"] is False


@pytest.mark.skip(reason="Skipping MMNIST test")
def test_launch_vivi_mmnist() -> None:
    """Launch ``train_self_supervised_classifier.main`` with the
    ``mmnist_vivi_linear_classifier`` experiment composed on top of
    ``self_supervised_defaults`` (``mode=local_test``).

    Currently skipped unconditionally via ``pytest.mark.skip``.
    """
    # ``config_path`` is resolved relative to the directory of this test file.
    with initialize(config_path="../config"):
        launch_cfg = compose(
            config_name="self_supervised_defaults",
            overrides=["+experiment=mmnist_vivi_linear_classifier", "mode=local_test"],
        )
    # NOTE(review): ``main`` is invoked after the ``initialize`` context has
    # exited, so Hydra's global state is presumably torn down by then; if
    # ``main`` relies on it, this call may need to move inside the ``with``.
    train_self_supervised_classifier.main(launch_cfg)
