from pytest import fixture
from cross_validate import BestRuns


class TestBestRuns:
    """Smoke tests for ``BestRuns`` against one recorded hyper-parameter sweep.

    Every test shares a single ``BestRuns`` instance, built once per class by
    the ``best_runs`` fixture. The expected values ("1" as the best run id, two
    matching runs) are tied to the specific sweep directory below.
    """

    # NOTE(review): "DATADIR" looks like a placeholder prefix that must point
    # at real sweep logs for these tests to pass — confirm how it is resolved.
    sweep_dir = "DATADIR/logs/tmp/shapes_resnet_linear_shapes_classifier/2022-02-11_13-50-03"

    @fixture(scope="class")
    def best_runs(self):
        """Construct the shared ``BestRuns`` for this sweep and "val_canonical_loss"."""
        sweep = self.sweep_dir
        return BestRuns(sweep, "val_canonical_loss")

    def test_best_run_id(self, best_runs):
        """The best run in this sweep is expected to have id "1"."""
        expected_id = "1"
        assert best_runs.best_run_id == expected_id

    def test_id_to_metrics(self, best_runs):
        """``id_to_metrics`` is a dict whose entry for run "1" names these metrics."""
        mapping = best_runs.id_to_metrics
        assert isinstance(mapping, dict)

        required = {
            "val_canonical_loss",
            "train_canonical_top_1_accuracy",
            "test_diverse_2d_top_1_accuracy",
        }
        # Convert to a set of the entry's iterated items, then compare as sets.
        available = set(mapping["1"])
        assert required <= available

    def test_id_to_parameters(self, best_runs):
        """``id_to_parameters`` is a dict whose entry for run "1" has a learning rate."""
        params_by_id = best_runs.id_to_parameters
        assert isinstance(params_by_id, dict)
        run_params = params_by_id["1"]
        assert "learning_rate" in run_params

    def test_find_matching_runs(self, best_runs):
        """Exactly two runs are reported as matching run "1"."""
        matches = best_runs.find_matching_runs("1")
        assert len(matches) == 2

    def test_best_runs_metric(self, best_runs):
        """The metric lookup returns two values, the first non-negative."""
        metric_name = "train_canonical_top_1_accuracy"
        values = best_runs.get_best_runs_metric(metric_name)
        assert len(values) == 2
        first_value = values[0]
        assert first_value >= 0
