import numpy as np
import torch
from sklearn.metrics import roc_auc_score, average_precision_score

def auc_score_train(predictions, labels, target_ids):
    """Compute the ROC-AUC for each target and their NaN-ignoring mean.

    Rows are grouped by their value in ``target_ids``. For each group, the
    ROC-AUC of ``predictions`` against ``labels`` is computed with
    ``sklearn.metrics.roc_auc_score``. A group whose labels do not contain
    exactly two distinct values gets ``np.nan``, because AUC is undefined
    for it.

    Args:
        predictions: Tensor of scores. It may require grad and may live on
            any device; it is detached and copied to host memory before
            scoring. Assumed to share its leading dimension with ``labels``
            and ``target_ids`` (NOTE(review): confirm against callers).
        labels: Tensor of ground-truth labels, indexed like ``predictions``.
        target_ids: Tensor of target identifiers used to group rows.

    Returns:
        A tuple ``(mean_auc, aucs, target_id_list)``:
        - ``mean_auc``: mean of the defined per-target AUCs. It is NaN when
          no target is evaluable or the input is empty.
        - ``aucs``: per-target AUC values, with ``np.nan`` for skipped targets.
        - ``target_id_list``: Python scalars of the target ids, aligned with
          ``aucs``, in the order returned by ``torch.unique``.
    """
    aucs = []
    target_id_list = []

    for target_idx in torch.unique(target_ids):
        rows = torch.where(target_ids == target_idx)
        target_id_list.append(target_idx.item())

        # sklearn needs host-side 1-D arrays. Detach from autograd, move off
        # any accelerator (np.asarray fails on CUDA tensors), and flatten.
        preds = predictions[rows].detach().cpu().numpy().ravel()
        y = labels[rows].detach().cpu().numpy().ravel()

        if np.unique(y).size == 2:
            aucs.append(roc_auc_score(y, preds))
        else:
            # AUC is undefined with a single class present.
            aucs.append(np.nan)

    # np.nanmean warns "Mean of empty slice" when every entry is NaN (or the
    # list is empty); return NaN quietly in that case instead.
    if np.all(np.isnan(aucs)):
        mean_auc = np.float64(np.nan)
    else:
        mean_auc = np.nanmean(aucs)
    return mean_auc, aucs, target_id_list

def deltaAUPRC_score_train(predictions, labels, target_ids):
    """Compute the AUPRC above a random baseline for each target, and their NaN-ignoring mean.

    Rows are grouped by their value in ``target_ids``. For each group, the
    average precision (``sklearn.metrics.average_precision_score``, positive
    label 1) is computed. The AP of a random classifier is then subtracted;
    that baseline equals the fraction of rows labelled 1 in the group. A
    group whose labels do not contain exactly two distinct values gets
    ``np.nan``.

    Args:
        predictions: Tensor of scores. It may require grad and may live on
            any device; it is detached and copied to host memory before
            scoring. Assumed to share its leading dimension with ``labels``
            and ``target_ids`` (NOTE(review): confirm against callers).
        labels: Tensor of labels, cast to int before scoring. Presumably
            {0, 1}; {-1, 1} also yields a correct baseline.
        target_ids: Tensor of target identifiers used to group rows.

    Returns:
        A tuple ``(mean_delta, delta_auprcs, target_id_list)``:
        - ``mean_delta``: mean of the defined per-target values. It is NaN
          when no target is evaluable or the input is empty.
        - ``delta_auprcs``: per-target ``AP - baseline`` values, with
          ``np.nan`` for skipped targets.
        - ``target_id_list``: Python scalars of the target ids, aligned with
          ``delta_auprcs``, in the order returned by ``torch.unique``.

    Raises:
        ValueError: propagated from sklearn when a two-class group does not
            contain the positive label 1.
    """
    delta_auprcs = []
    target_id_list = []

    for target_idx in torch.unique(target_ids):
        rows = torch.where(target_ids == target_idx)
        target_id_list.append(target_idx.item())

        # .numpy() requires a detached CPU tensor; flatten for sklearn.
        preds = predictions[rows].detach().cpu().numpy().ravel()
        y = labels[rows].detach().cpu().int().numpy().ravel()

        if np.unique(y).size != 2:
            # AP is undefined with a single class present.
            delta_auprcs.append(np.nan)
            continue

        # A random classifier's expected AP is the positive-class prevalence.
        # Dividing by the group size, rather than by (#1 + #0), keeps {-1, 1}
        # labels from producing a baseline of 1.0 or a division by zero.
        n_actives = int((y == 1).sum())
        random_clf_auprc = n_actives / y.size

        auprc = average_precision_score(y, preds)
        delta_auprcs.append(auprc - random_clf_auprc)

    # np.nanmean warns "Mean of empty slice" when every entry is NaN (or the
    # list is empty); return NaN quietly in that case instead.
    if np.all(np.isnan(delta_auprcs)):
        mean_delta = np.float64(np.nan)
    else:
        mean_delta = np.nanmean(delta_auprcs)
    return mean_delta, delta_auprcs, target_id_list




