import torchmetrics
import torch
from tqdm import tqdm
import ipdb
def GPPTEva(data, idx_test, gnn, prompt, num_class, device):
    """Evaluate a GNN + prompt pipeline on the test nodes.

    Runs ``gnn(data.x, data.edge_index)`` to get node embeddings, feeds them
    through ``prompt`` to get per-node class scores, and scores the rows
    selected by ``idx_test``. The work is done without autograd tracking.

    Args:
        data: Graph object exposing ``x``, ``edge_index`` and ``y``.
        idx_test: Index (or mask) selecting the test nodes in ``data``.
        gnn: Module mapping ``(x, edge_index)`` to node embeddings.
        prompt: Module mapping ``(node_embedding, edge_index)`` to class
            scores; it is switched to eval mode by this function.
        num_class: Number of classes, passed to the accuracy metric.
        device: Device the accuracy metric is moved to.

    Returns:
        tuple: ``(acc, pred, logits, loss, labels)``, where
        - ``acc`` is a float (multiclass accuracy on the test nodes),
        - ``pred`` is the argmax predictions for the test nodes,
        - ``logits`` is the prompt output rows for the test nodes,
        - ``loss`` is a float (cross-entropy on the test nodes),
        - ``labels`` is ``data.y[idx_test]``.
        Returned tensors do not require grad.
    """
    # NOTE(review): gnn.eval() was left commented out upstream, so the GNN
    # keeps whatever train/eval mode the caller set (e.g. dropout may still
    # be active). Preserved as-is; confirm this is intentional.
    # gnn.eval()
    prompt.eval()
    accuracy = torchmetrics.classification.Accuracy(task="multiclass", num_classes=num_class).to(device)
    criterion = torch.nn.CrossEntropyLoss()

    # Evaluation only: don't build an autograd graph over the full-graph
    # forward pass (saves memory and time).
    with torch.no_grad():
        node_embedding = gnn(data.x, data.edge_index)
        out = prompt(node_embedding, data.edge_index)
        test_out = out[idx_test]
        test_y = data.y[idx_test]
        loss = criterion(test_out, test_y)
        # Row-wise argmax, so taking it on the selected rows equals
        # selecting rows of the full argmax.
        test_pred = test_out.argmax(dim=1)
        acc = accuracy(test_pred, test_y)

    return acc.item(), test_pred, test_out, loss.item(), test_y