import torch
from torch.nn import Module
from torch_geometric.loader import DataLoader

from data_generation import Dataset
from ._util import parse_checkpoint_path_argument, prepare_model


@torch.no_grad()
def evaluate_model(model: Module, val_loader: DataLoader, device: torch.device | str) -> float:
    """
    Runs the model on the given validation dataset and returns the edge classification accuracy.
    """
    model.eval()
    model.to(device)

    correct_edges = 0
    total_edges = 0

    for graph in val_loader:
        graph.to(device)
        edge_predictions = model(graph)
        predicted_labels = edge_predictions > 0

        correct_edges += (predicted_labels == graph.y).sum()
        total_edges += predicted_labels.size(0)

    val_accuracy = correct_edges / total_edges
    return val_accuracy.item()


def _main(
    checkpoint_path: str,
    dataset_name: str,
    batch_size: int = 64,
    device: torch.device | str = "cpu",
):
    """
    Loads a checkpointed model and prints its edge accuracy on a dataset's validation split.

    Args:
        checkpoint_path: Path handed to ``prepare_model`` to build the model.
        dataset_name: Name handed to ``Dataset.load``; its ``val_graphs`` are evaluated.
        batch_size: Number of graphs per validation batch.
        device: Device to run evaluation on. Defaults to ``"cpu"``, the previously hard-coded value.
    """
    model = prepare_model(checkpoint_path)
    dataset = Dataset.load(dataset_name)
    val_loader = DataLoader(dataset.val_graphs, batch_size=batch_size)
    val_accuracy = evaluate_model(model, val_loader, device=device)
    print("Validation accuracy:", val_accuracy)


if __name__ == "__main__":
    # Script entry point: evaluate the checkpoint given on the command line against a fixed dataset.
    _main(
        checkpoint_path=parse_checkpoint_path_argument(),
        dataset_name="simple-graph-plus--n-100--k-2--max-cut-value-30",
    )
