from sklearn import metrics
import torch
from torch.nn import Module
from torch_geometric.loader import DataLoader


@torch.no_grad()
def calculate_area_under_roc_curve(model: Module, val_loader: DataLoader, device: torch.device | str) -> float:
    """
    Runs the model on the given validation graphs, then calculates the area under the ROC curve based on the results.

    The graphs must have a ground truth cut attached as ``graph.y``.

    :param model: The model to evaluate. It is switched to evaluation mode and left in that mode.
    :param val_loader: Loader yielding the validation graph batches.
    :param device: Device each batch is moved to before being passed to the model.
    :return: ``sklearn.metrics.roc_auc_score`` of the model outputs against the ground truth cuts.
    """
    model.eval()

    edge_predictions: list[torch.Tensor] = []
    ground_truth_cuts: list[torch.Tensor] = []
    # Iterate the loader exactly once. Iterating it twice (once for predictions, once for labels)
    # repeats all data loading and, if the loader shuffles, pairs predictions with labels from a
    # different batch ordering, which silently corrupts the score.
    for graph in val_loader:
        # Read the labels before moving the batch, so they stay on the CPU for sklearn.
        ground_truth_cuts.append(graph.y.cpu())
        # Move each output back to the CPU right away instead of accumulating them all on `device`.
        edge_predictions.append(model(graph.to(device)).cpu())

    return metrics.roc_auc_score(torch.cat(ground_truth_cuts), torch.cat(edge_predictions))
