import torchmetrics
import torch
from tqdm import tqdm
def AllInOneEva(loader, gnn, module_list, num_class, device):
    """Evaluate a prompt + GNN + answering-head pipeline on ``loader``.

    For each batch, ``module_list[0]`` produces a prompted graph, ``gnn``
    embeds it, and ``module_list[1]`` maps the embedding to class logits.

    Args:
        loader: Iterable of batches exposing ``.y`` and ``.to(device)``. The
            object returned by ``module_list[0]`` must expose ``.x``,
            ``.edge_index`` and ``.batch``.
        gnn: Callable ``gnn(x, edge_index, batch)`` returning graph embeddings.
            If it is a ``torch.nn.Module``, it is run in eval mode, and its
            previous train/eval mode is restored before returning.
        module_list: Indexable module container (e.g. ``torch.nn.ModuleList``)
            with the prompt at index 0 and the answering head at index 1.
            It is left in eval mode.
        num_class: Number of classes. It is not needed to compute accuracy
            from argmax predictions and is kept for API compatibility.
        device: Device that each batch is moved to before the forward pass.

    Returns:
        Tuple ``(acc, pred_labels, logits, loss, true_labels)``:

        - ``acc``: overall accuracy (correct / total) as a float.
        - ``pred_labels``: concatenated argmax predictions.
        - ``logits``: concatenated outputs of the answering head.
        - ``loss``: mean cross-entropy over all evaluated samples, as a float.
        - ``true_labels``: concatenated labels exactly as the loader yielded
          them, i.e. on the loader's device rather than ``device``.

    Raises:
        ValueError: If ``loader`` yields no samples.
    """
    module_list.eval()

    # Evaluate the GNN without dropout and without updating BatchNorm running
    # statistics, but leave the caller's train/eval setting untouched.
    gnn_was_training = gnn.training if isinstance(gnn, torch.nn.Module) else None
    if gnn_was_training is not None:
        gnn.eval()

    # Sum per batch and divide once at the end, so the reported loss is a
    # per-sample mean that does not depend on how samples are batched.
    criterion = torch.nn.CrossEntropyLoss(reduction="sum")
    logits_list = []
    pred_list = []
    label_list = []
    loss_sum = 0.0
    correct = 0
    total = 0
    try:
        with torch.no_grad():
            for batch in loader:
                # Captured before .to(device), so returned labels stay on the
                # loader's device (unchanged behaviour for callers).
                label_list.append(batch.y)
                batch = batch.to(device)
                prompted_graph = module_list[0](batch)
                graph_emb = gnn(
                    prompted_graph.x,
                    prompted_graph.edge_index,
                    prompted_graph.batch,
                )
                logits = module_list[1](graph_emb)
                pred = logits.argmax(dim=1)

                # Accumulate as tensors to avoid a device sync per batch.
                loss_sum += criterion(logits, batch.y)
                correct += (pred == batch.y).sum()
                total += batch.y.numel()

                pred_list.append(pred)
                logits_list.append(logits)
    finally:
        if gnn_was_training is not None:
            gnn.train(gnn_was_training)

    if total == 0:
        raise ValueError("AllInOneEva: loader yielded no samples")

    acc = float(correct) / total
    loss = float(loss_sum) / total
    return (
        acc,
        torch.cat(pred_list, dim=0),
        torch.cat(logits_list, dim=0),
        loss,
        torch.cat(label_list, dim=0),
    )
