import argparse

from spektral.data import DisjointLoader

from magni.src.graph_classification.training import results_to_file, run_experiments
from magni.src.layers import TopKPool
from magni.src.models.classifiers import MainModel, MainModelGIN

# Command-line options for a TopK-pooling graph-classification experiment.
# Every value is forwarded to run_experiments / create_model further below.
parser = argparse.ArgumentParser(
    description="Run graph-classification experiments with TopK pooling."
)
parser.add_argument("--dataset", type=str, default="PROTEINS", help="Dataset name.")
parser.add_argument("--lr", type=float, default=5e-4, help="Learning rate.")
parser.add_argument("--batch-size", type=int, default=32, help="Batch size.")
parser.add_argument(
    "--patience", type=int, default=50, help="Patience passed to run_experiments."
)
parser.add_argument("--runs", type=int, default=3, help="Number of runs.")
# Restrict to the architectures create_model knows about, so a typo is
# rejected at parse time instead of after the dataset has been loaded.
parser.add_argument(
    "--model",
    type=str,
    default="GNN",
    choices=["GNN", "GIN"],
    help="Model architecture.",
)
parser.add_argument("--ratio", type=float, default=0.5, help="TopK pooling ratio.")
args = parser.parse_args()


def create_model(n_out, **kwargs):
    """Build a classifier that uses a TopK pooling layer.

    The architecture is chosen by the ``--model`` command-line option
    (``args.model``): ``"GNN"`` builds ``MainModel`` and ``"GIN"`` builds
    ``MainModelGIN``.

    Args:
        n_out: Forwarded as the first argument to the model constructor
            (presumably the number of output classes -- confirm in
            ``magni.src.models.classifiers``).
        **kwargs: Options supplied by the caller (presumably
            ``run_experiments``). Only ``ratio`` is read; when it is absent,
            the ``--ratio`` command-line value is used instead of ``None``.

    Returns:
        The constructed model instance.

    Raises:
        ValueError: If ``args.model`` is not a known architecture.
    """
    model_classes = {"GNN": MainModel, "GIN": MainModelGIN}
    model_class = model_classes.get(args.model)
    if model_class is None:
        raise ValueError(f"Unknown model: {args.model}")

    # Validate the architecture first so an unknown name fails without
    # constructing any layers.
    pool = TopKPool(kwargs.get("ratio", args.ratio))
    return model_class(n_out, pool)


# Label identifying this pooling method; used both for the experiment run and
# for the results file so the two cannot drift apart.
METHOD = "TopK"

# Run the experiments with the command-line configuration, then write the
# returned results (unpacked positionally; presumably a sequence of metrics --
# confirm against run_experiments) to file.
results = run_experiments(
    runs=args.runs,
    create_model=create_model,
    loader_class=DisjointLoader,
    dataset_name=args.dataset,
    learning_rate=args.lr,
    batch_size=args.batch_size,
    patience=args.patience,
    method=METHOD,
    model_name=args.model,
    ratio=args.ratio,
)
results_to_file(args.dataset, METHOD, *results)
