import pathlib

import click
import torch
import wandb

from ngab.graph_alignment import (
    compute_metrics,
    get_kwargs,
    setup_data,
)
from ngab.models import LaplacianEmbeddings


@click.command()
@click.option(
    "-d",
    "--dataset",
    type=click.Path(exists=True, file_okay=False, path_type=pathlib.Path),
    required=True,
    help="Path to A GADataset",
)
@click.option("--experiment", type=str, required=True, help="Experiment name")
@click.option("--run-name", type=str, required=True, help="Run name")
@click.option(
    "--features", type=int, required=True, help="Number of features per layer"
)
@click.option("--batch-size", type=int, required=True, help="Batch size")
@click.option("--cuda/--cpu", required=True, help="Training backend")
def compute_laplacian_performances(
    dataset: pathlib.Path,
    experiment: str,
    run_name: str,
    features: int,
    batch_size: int,
    cuda: bool,
) -> None:
    """Evaluate Laplacian embeddings on a GADataset and log metrics to wandb.

    Metrics are computed on both the training and validation splits and are
    logged together as a single wandb step.
    \f
    Args:
        dataset: Existing directory holding the dataset; forwarded to
            ``setup_data`` as ``dataset_path``.
        experiment: wandb project name.
        run_name: wandb run name.
        features: Value passed as ``k`` to ``LaplacianEmbeddings``.
        batch_size: Batch size used by the train/validation loaders.
        cuda: Evaluate on the ``cuda`` device when True, else on ``cpu``.

    Raises:
        click.UsageError: If ``--cuda`` is requested but CUDA is unavailable.
    """
    if cuda and not torch.cuda.is_available():
        # Fail fast, before wandb.init, so no failed/empty run is created.
        raise click.UsageError("--cuda was requested but CUDA is not available.")
    # NOTE: `device` is deliberately bound before `get_kwargs()` runs below,
    # exactly as before, in case that helper inspects the caller's locals.
    device = torch.device("cuda") if cuda else torch.device("cpu")

    with wandb.init(
        project=experiment,
        name=run_name,
        config=get_kwargs(),
    ) as run:
        # Only the batched loaders are needed; the raw datasets are discarded.
        _, _, train_loader, val_loader = setup_data(
            dataset_path=dataset,
            batch_size=batch_size,
        )

        gnn_model: torch.nn.Module = LaplacianEmbeddings(k=features)

        # Suffix every metric name with its split so both can share one log.
        train_metrics = {
            f"{k}/train": v
            for k, v in compute_metrics(gnn_model, train_loader, device).items()
        }
        val_metrics = {
            f"{k}/val": v
            for k, v in compute_metrics(gnn_model, val_loader, device).items()
        }

        # One log call keeps train and val metrics on the same wandb step;
        # each separate run.log() call would advance the step counter.
        run.log({**train_metrics, **val_metrics})




def main() -> None:
    """Entry point that runs the ``compute_laplacian_performances`` command."""
    # No arguments are forwarded: the click command reads its own options
    # from the command line.
    cli = compute_laplacian_performances
    cli()


if __name__ == "__main__":
    # Support executing this file directly as a script.
    main()
