"""Download results of the experiment."""

from __future__ import annotations

import sys

import fire

sys.path.insert(0, "../../..")
from experiments import utils


def download_results(
    which: list[str] | tuple[str, ...] | str = ("training", "test"),
    created_after: str | None = None,
    group: str = "sample_size_experiment",
) -> None:
    """Download results of the experiment.

    Args:
        which: Names of the result sets to download.
            ``"training"`` writes the per-step validation metrics
            (global step, epoch, validation accuracy/NLL/ECE) to
            ``experiment_results_training.csv``.
            ``"test"`` writes the test metrics (accuracy/NLL/ECE on
            dataloader 0, AUROC on dataloader 1) to
            ``experiment_results.csv``.
            A single string is treated as a one-element selection.
            Unrecognized names are ignored.
        created_after: Forwarded unchanged to
            ``utils.wandb.download_results``.
        group: Run group forwarded to ``utils.wandb.download_results``.
    """
    # Fire hands over a bare string for e.g. ``--which=test``. Wrap it so
    # that membership checks compare whole names rather than substrings.
    selected = {which} if isinstance(which, str) else set(which)

    if "training" in selected:
        # Per-step validation curves recorded during training.
        utils.wandb.download_results(
            file="experiment_results_training.csv",
            group=group,
            created_after=created_after,
            keys=[
                "trainer/global_step",
                "epoch",
                "Validation Accuracy",
                "Validation NLL",
                "Validation ECE",
            ],
        )

    if "test" in selected:
        # Test metrics; AUROC is logged for the second dataloader.
        utils.wandb.download_results(
            file="experiment_results.csv",
            group=group,
            created_after=created_after,
            keys=[
                "Test Accuracy/dataloader_idx_0",
                "Test NLL/dataloader_idx_0",
                "Test ECE/dataloader_idx_0",
                "Test AUROC/dataloader_idx_1",
            ],
        )


if __name__ == "__main__":
    # Expose download_results as a command-line interface, e.g.
    # ``python <this script> --which=test --group=<run group>``.
    fire.Fire(component=download_results)
