# Run from the repository root as a module so the relative imports resolve:
# python -m evaluation.visualise_model_performance <checkpoint_path>

import torch

from data_generation import SimpleGraphConfig
from visualise_graph import visualise_graph
from ._util import parse_checkpoint_path_argument, prepare_model


@torch.no_grad()
def visualise_model_performance(
    checkpoint_path: str,
    *,
    num_nodes: int = 20,
    num_clusters: int = 3,
    min_edges_between_clusters: int = 4,
    max_edges_between_clusters: int = 4,
) -> None:
    """Load a model checkpoint, run it on a generated graph and plot its edge predictions.

    Gradient tracking is disabled for the whole call via ``torch.no_grad``,
    since this is inference only.

    Args:
        checkpoint_path: Path to the checkpoint, handed to ``prepare_model``.
        num_nodes: Number of nodes in the generated graph.
        num_clusters: Number of clusters in the generated graph.
        min_edges_between_clusters: Lower bound on inter-cluster edges,
            forwarded to ``SimpleGraphConfig``.
        max_edges_between_clusters: Upper bound on inter-cluster edges,
            forwarded to ``SimpleGraphConfig``.

    The keyword-only graph parameters default to the values this script
    originally hard-coded, so existing callers see identical behaviour.
    """
    model = prepare_model(checkpoint_path)
    config = SimpleGraphConfig(
        num_nodes=num_nodes,
        num_clusters=num_clusters,
        min_edges_between_clusters=min_edges_between_clusters,
        max_edges_between_clusters=max_edges_between_clusters,
    )
    # NOTE(review): whether generate_graph() is randomised (and therefore
    # differs between runs) depends on SimpleGraphConfig — confirm.
    graph = config.generate_graph()
    # The raw model output is squashed with a sigmoid before plotting;
    # presumably the model emits per-edge logits — TODO confirm.
    edge_predictions = torch.sigmoid(model(graph))
    visualise_graph(graph, edge_predictions)


if __name__ == "__main__":
    # Script entry point: read the checkpoint path from the command line,
    # then load that checkpoint and plot its predictions.
    checkpoint_path = parse_checkpoint_path_argument()
    visualise_model_performance(checkpoint_path=checkpoint_path)
