import torch
from sklearn.metrics import balanced_accuracy_score, roc_auc_score


def memory_stats(device=None):
    """
    Print CUDA memory usage as reported by ``torch.cuda.memory_allocated``
    and ``torch.cuda.memory_reserved``.

    Both figures are printed in MiB (bytes divided by 1024**2).

    :param device: device to query; forwarded unchanged to both torch
        calls. ``None`` (the default) keeps torch's default device
        selection, i.e. the same behaviour as calling with no argument.
    """
    bytes_per_mib = 1024**2
    print("Memory stats for cuda: ")
    print(f"allocated memory: {torch.cuda.memory_allocated(device) / bytes_per_mib}")
    print(f"reserved memory: {torch.cuda.memory_reserved(device) / bytes_per_mib}")


def roc_auc_gpu_safe(y_true, y_scores, **roc_auc_args):
    """
    Like regular roc_auc_score, but first brings torch tensor inputs to CPU.

    sklearn converts its inputs to numpy arrays. A torch tensor that lives on
    a non-CPU device, or that still requires grad, cannot be converted
    directly. So every torch tensor argument (labels *and* scores) is detached
    and copied to CPU. Any other array-like (list, numpy array, ...) is passed
    through unchanged.

    :param y_true: integer labels (array-like or torch tensor on any device)
    :param y_scores: array of scores for the classes (array-like or torch
        tensor on any device)
    :return: whatever ``roc_auc_score`` returns for these inputs

    The rest of the arguments get passed to roc_auc_score.
    """
    # isinstance instead of probing `.device`: non-tensor inputs (lists,
    # older numpy arrays) have no `.device` attribute.
    if isinstance(y_true, torch.Tensor):
        y_true = y_true.detach().cpu()
    if isinstance(y_scores, torch.Tensor):
        y_scores = y_scores.detach().cpu()
    return roc_auc_score(y_true, y_scores, **roc_auc_args)


def accuracy_gpu_safe(y_true, y_scores, thr=0.5, **accuracy_args):
    """
    Balanced accuracy computed from class scores, safe for GPU tensors.

    The scores are turned into hard labels and passed, together with
    ``y_true``, to ``balanced_accuracy_score``:

    * one score per sample, i.e. shape ``(n_samples,)`` or
      ``(n_samples, 1)``: thresholded, label 1 if ``score > thr`` else 0;
    * shape ``(n_samples, n_classes)`` with ``n_classes > 1``: argmax over
      the class axis (``thr`` is ignored).

    Tensor inputs are detached and moved to CPU first, because sklearn
    cannot convert tensors that live on a non-CPU device or require grad.

    :param y_true: integer labels (array-like or torch tensor on any device)
    :param y_scores: array of scores for the classes (array-like or torch
        tensor on any device)
    :param thr: decision threshold for the single-score case
    :return: whatever ``balanced_accuracy_score`` returns for these inputs
    :raises ValueError: if ``y_scores`` is not 1-D or 2-D

    The rest of the arguments get passed to balanced_accuracy_score.
    """
    if isinstance(y_true, torch.Tensor):
        y_true = y_true.detach().cpu()
    # as_tensor is a no-op for tensors and also accepts numpy arrays / lists,
    # so the tensor methods below work for any array-like input.
    y_scores = torch.as_tensor(y_scores).detach().cpu()

    single_score = y_scores.ndim == 1 or (
        y_scores.ndim == 2 and y_scores.shape[1] == 1
    )
    if single_score:
        # flatten (n, 1) -> (n,) so sklearn receives 1-D predicted labels
        y_labels = (y_scores.reshape(-1) > thr).int()
    elif y_scores.ndim == 2:
        # multi-class case, no use of threshold
        y_labels = y_scores.argmax(dim=1)
    else:
        raise ValueError(
            "y_scores must be 1-D (n_samples,) or 2-D (n_samples, n_classes), "
            f"got shape {tuple(y_scores.shape)}"
        )

    return balanced_accuracy_score(y_true, y_labels, **accuracy_args)
