import torch
from torchmetrics import AUROC
from torchmetrics.functional import average_precision

# Module-level torchmetrics AUROC metric object, configured with class 1 as the
# positive label.
# NOTE(review): torchmetrics Metric objects are stateful — calling the instance
# (forward) also accumulates its inputs into internal state until ``.reset()``
# is called. Confirm that accumulation is intended wherever this is used.
auroc = AUROC(pos_label=1)
def class_eval(feat_dict):
    """Compute per-graph binary classification metrics for vertex predictions.

    The batched logits in ``feat_dict['bsp']`` are passed through a sigmoid and
    split into per-graph chunks using ``feat_dict['num_verts']``. For each
    chunk, ground-truth labels are built by marking the vertex indices found in
    column 0 of the matching ``feat_dict['iface_p2p']`` entry as positive (1);
    every other vertex is negative (0).

    Args:
        feat_dict: Mapping with keys
            ``'bsp'``: tensor of logits, presumably one per vertex across the
                whole batch (e.g. shape ``(N,)`` or ``(N, 1)``) — TODO confirm.
            ``'num_verts'``: sequence of per-graph vertex counts; passed to
                ``torch.split``, so it must sum to the number of logits.
            ``'iface_p2p'``: per-graph index arrays, indexed by graph position;
                column 0 holds the positive vertex indices for that graph.

    Returns:
        Tuple of five lists ``(precisions, recalls, fscores, AUCs, APs)`` with
        one entry per graph. Precision, recall and F1 use a 0.5 probability
        threshold and are 0-d float tensors (zero when undefined, i.e. no
        predicted or no true positives). AUC and AP are computed from the
        probabilities.
    """
    # The functional AUROC is stateless. A shared ``AUROC`` Metric object also
    # appends every call's preds/targets to internal state in forward(), and
    # nothing here resets it, so memory would grow with every call.
    from torchmetrics.functional import auroc as functional_auroc

    precisions = []
    recalls = []
    fscores = []
    AUCs = []
    APs = []
    with torch.no_grad():
        b_iface_p2p = feat_dict['iface_p2p']
        # reshape(-1) instead of squeeze(): squeeze() turns a single-vertex
        # batch into a 0-d tensor, which torch.split cannot split.
        b_pred_prob = torch.sigmoid(feat_dict['bsp'].reshape(-1))
        b_pred_label = b_pred_prob > 0.5
        num_verts = feat_dict['num_verts']
        pred_prob_list = torch.split(b_pred_prob, num_verts)
        pred_label_list = torch.split(b_pred_label, num_verts)
        for bid, (pred_prob, pred_label) in enumerate(zip(pred_prob_list, pred_label_list)):
            device = pred_label.device
            label = torch.zeros(pred_label.size(0), dtype=torch.long, device=device)
            label[b_iface_p2p[bid][:, 0]] = 1
            # Fallback for undefined ratios. A tensor (not int 0) keeps every
            # list entry the same type.
            zero = torch.zeros((), device=device)
            tp = torch.sum(pred_label & (label == 1))
            ps = torch.sum(pred_label)
            rs = torch.sum(label)
            prec = tp / ps if ps > 0 else zero
            rec = tp / rs if rs > 0 else zero
            fsc = (2 * prec * rec) / (prec + rec) if prec + rec > 0 else zero
            precisions.append(prec)
            recalls.append(rec)
            fscores.append(fsc)
            AUCs.append(functional_auroc(pred_prob, label, pos_label=1))
            APs.append(average_precision(pred_prob, label))

    return precisions, recalls, fscores, AUCs, APs


