"""Download results of the hyperparameter transfer experiments."""

from __future__ import annotations

import sys
import pandas as pd

import fire

sys.path.insert(0, "../../..")
from experiments import utils


def _keep_last_epoch(path: str) -> None:
    """Rewrite the CSV at *path*, keeping only rows from the final epoch.

    A row is kept when ``epoch == max_epochs - 1``, i.e. epochs are assumed
    to be zero-indexed (NOTE(review): confirm against the training loop).
    The CSV must contain both an ``epoch`` and a ``max_epochs`` column.
    """
    df = pd.read_csv(path)
    df = df[df["epoch"] == df["max_epochs"] - 1]
    # index=False so the rewritten file keeps exactly the columns that were
    # read, instead of gaining an extra unnamed index column on every rewrite.
    df.to_csv(path, index=False)


def download_results(
    created_after: str | None = None,
    group: str = "hyperparameter_transfer_v2",
) -> None:
    """Download results of the hyperparameter transfer experiments.

    Writes three CSV files into the current working directory:

    - ``experiment_results_train.csv``: ``epoch`` and ``Train Loss``,
      then filtered to rows where ``epoch == max_epochs - 1``.
    - ``experiment_results_test.csv``: test metrics and final validation
      NLL (not filtered).
    - ``experiment_results_val.csv``: ``epoch`` plus validation accuracy,
      NLL and ECE, then filtered like the train file.

    The train/val CSVs produced by ``utils.wandb.download_results`` are
    expected to include a ``max_epochs`` column (presumably taken from the
    run config - TODO confirm).

    Args:
        created_after: Forwarded unchanged to
            ``utils.wandb.download_results``; presumably restricts the
            download to runs created after this point in time (accepted
            format is defined by that helper - TODO confirm).
        group: Run group forwarded to ``utils.wandb.download_results``.
    """
    train_file = "experiment_results_train.csv"
    test_file = "experiment_results_test.csv"
    val_file = "experiment_results_val.csv"

    utils.wandb.download_results(
        file=train_file,
        group=group,
        created_after=created_after,
        keys=[
            "epoch",
            "Train Loss",
        ],
    )
    _keep_last_epoch(train_file)

    utils.wandb.download_results(
        file=test_file,
        group=group,
        created_after=created_after,
        keys=[
            "Test Accuracy/dataloader_idx_0",
            "Test NLL/dataloader_idx_0",
            "Test ECE/dataloader_idx_0",
            "Test AUROC/dataloader_idx_1",
            "Final Validation NLL/dataloader_idx_2",
            "Final Validation NLL of mean/dataloader_idx_2",
        ],
    )

    utils.wandb.download_results(
        file=val_file,
        group=group,
        created_after=created_after,
        keys=[
            "epoch",
            "Validation Accuracy",
            "Validation NLL",
            "Validation ECE",
        ],
    )
    _keep_last_epoch(val_file)


if __name__ == "__main__":
    # Expose download_results as a command-line tool: Fire turns the
    # function's parameters into command-line flags.
    fire.Fire(component=download_results)
