import numpy as np
from numpy.typing import NDArray

import wandb


def get_test_data_artifact(
    artifact_name: str, skip_cache: bool = True
) -> tuple[NDArray[np.floating], NDArray[np.integer]]:
    """Fetch a W&B artifact and return its test predictions and labels.

    Args:
        artifact_name: Artifact name passed verbatim to
            ``wandb.Api().artifact``.
        skip_cache: Forwarded to ``Artifact.download``. Defaults to ``True``.

    Returns:
        ``(test_preds, test_labels)``: NumPy arrays built from the
        ``"predictions"`` and ``"labels"`` columns of the artifact's
        ``"test_predictions"`` table. The element dtypes come from the table
        columns; the annotated dtypes are not enforced here.
    """
    api = wandb.Api()
    artifact = api.artifact(artifact_name)
    # Honor the caller's choice rather than always forcing skip_cache=True.
    artifact.download(skip_cache=skip_cache)

    # Load test predictions
    test_predictions_table = artifact.get("test_predictions").get_dataframe()  # type: ignore
    test_preds = np.array(test_predictions_table["predictions"])
    test_labels = np.array(test_predictions_table["labels"])
    return test_preds, test_labels


def get_validation_and_test_data_artifacts(
    artifact_names: list[str], skip_cache: bool = True
) -> tuple[
    list[NDArray[np.floating]],
    list[NDArray[np.floating]],
    list[NDArray[np.integer]],
    list[NDArray[np.floating]],
    list[NDArray[np.integer]],
]:
    """Fetch several W&B artifacts and collect their validation/test tables.

    Every artifact is fetched and downloaded first. Then, for each artifact
    (in ``artifact_names`` order), the ``"validation_unlabeled_predictions"``,
    ``"validation_positive_predictions"`` and ``"test_predictions"`` tables
    are read into NumPy arrays.

    Args:
        artifact_names: Artifact names passed verbatim to
            ``wandb.Api().artifact``.
        skip_cache: Forwarded to ``Artifact.download`` for every artifact.
            Defaults to ``True``.

    Returns:
        A 5-tuple of per-artifact lists, each the same length as
        ``artifact_names``, in this order:

        1. positive predictions (``validation_positive_predictions``
           ``"predictions"`` column),
        2. unlabeled predictions (``validation_unlabeled_predictions``
           ``"predictions"`` column),
        3. unlabeled labels (``validation_unlabeled_predictions``
           ``"labels"`` column),
        4. test predictions (``test_predictions`` ``"predictions"`` column),
        5. test labels (``test_predictions`` ``"labels"`` column).

        The element dtypes come from the table columns; the annotated dtypes
        are not enforced here.
    """
    api = wandb.Api()
    artifacts = [api.artifact(name) for name in artifact_names]
    for artifact in artifacts:
        # Honor the caller's choice rather than always forcing skip_cache=True.
        artifact.download(skip_cache=skip_cache)

    unlabeled_preds_list: list[NDArray[np.floating]] = []
    unlabeled_labels_list: list[NDArray[np.integer]] = []
    positive_preds_list: list[NDArray[np.floating]] = []
    test_preds_list: list[NDArray[np.floating]] = []
    test_labels_list: list[NDArray[np.integer]] = []
    for artifact in artifacts:
        # Load validation unlabeled predictions (predictions and labels).
        unlabeled_predictions_table = artifact.get("validation_unlabeled_predictions").get_dataframe()  # type: ignore
        unlabeled_preds_list.append(np.array(unlabeled_predictions_table["predictions"]))
        unlabeled_labels_list.append(np.array(unlabeled_predictions_table["labels"]))

        # Load validation positive predictions. Only the "predictions" column
        # is read from this table.
        positive_predictions_table = artifact.get("validation_positive_predictions").get_dataframe()  # type: ignore
        positive_preds_list.append(np.array(positive_predictions_table["predictions"]))

        # Load test predictions (predictions and labels).
        test_predictions_table = artifact.get("test_predictions").get_dataframe()  # type: ignore
        test_preds_list.append(np.array(test_predictions_table["predictions"]))
        test_labels_list.append(np.array(test_predictions_table["labels"]))

    # NOTE: positive predictions come first in the returned tuple, even though
    # they are loaded second above.
    return (positive_preds_list, unlabeled_preds_list, unlabeled_labels_list, test_preds_list, test_labels_list)
