import torchmetrics
import torch
from tqdm import tqdm
import ipdb
def GPFEva(loader, gnn, module_list, num_class, device):
    """Evaluate ``gnn`` wrapped by the modules in ``module_list`` on ``loader``.

    For each batch, ``module_list[0]`` is applied to the node features
    ``batch.x``. The result goes through
    ``gnn(batch.x, batch.edge_index, batch.batch)``, and ``module_list[1]``
    maps the GNN output to per-class scores. Those scores are used as logits:
    ``argmax`` gives the prediction and cross-entropy gives the loss.

    Args:
        loader: Sized iterable of batches. Each batch needs ``x``, ``y``,
            ``edge_index`` and ``batch`` attributes and a ``to(device)``
            method that returns the moved batch.
        gnn: Model called as ``gnn(x, edge_index, batch)``.
        module_list: Indexable container of at least two modules with an
            ``eval()`` method (e.g. ``torch.nn.ModuleList``).
        num_class: Number of classes. Kept for interface compatibility;
            accuracy is computed directly from the predictions.
        device: Device each batch is moved to before the forward pass.

    Returns:
        Tuple ``(accuracy, predictions, logits, loss, labels)``:

        - ``accuracy``: fraction of correctly classified samples over the
          whole loader.
        - ``predictions`` and ``logits``: concatenated along dim 0.
        - ``loss``: mean cross-entropy per sample.
        - ``labels``: targets as yielded by the loader, captured before the
          batch is moved to ``device``.

    Raises:
        ValueError: If ``loader`` yields no samples.
    """
    gnn.eval()
    module_list.eval()
    # Sum-reduced so every sample counts equally; a smaller final batch is
    # weighted by its true size instead of counting as a full batch.
    criterion = torch.nn.CrossEntropyLoss(reduction="sum")

    preds, logits, labels = [], [], []
    correct = 0
    loss_sum = 0.0
    n_samples = 0

    with torch.no_grad():
        for batch in loader:
            # Captured before the move, so labels keep the loader's device
            # (same contract as before).
            labels.append(batch.y)
            batch = batch.to(device)
            batch.x = module_list[0](batch.x)
            out = gnn(batch.x, batch.edge_index, batch.batch)
            out = module_list[1](out)
            pred = out.argmax(dim=1)

            preds.append(pred)
            logits.append(out)
            loss_sum += criterion(out, batch.y).item()
            correct += (pred == batch.y).sum().item()
            n_samples += batch.y.size(0)

    if n_samples == 0:
        raise ValueError("GPFEva: loader yielded no samples to evaluate")

    return (
        correct / n_samples,
        torch.cat(preds, dim=0),
        torch.cat(logits, dim=0),
        loss_sum / n_samples,
        torch.cat(labels, dim=0),
    )