import argparse

from spektral.data import BatchLoader

from magni.src.graph_classification.training import results_to_file, run_experiments
from magni.src.layers import DiffPool
from magni.src.models.classifiers import MainModel, MainModelGIN

# Command-line configuration for the DiffPool graph-classification experiments.
# Every option below is forwarded to `run_experiments` / `create_model` further
# down in this script.
parser = argparse.ArgumentParser()
parser.add_argument("--dataset", type=str, default="PROTEINS")
parser.add_argument("--lr", type=float, default=5e-4)
parser.add_argument("--batch-size", type=int, default=32)
parser.add_argument("--patience", type=int, default=50)
parser.add_argument("--runs", type=int, default=3)
# Restrict to the architectures `create_model` knows how to build, so a typo is
# rejected here with a usage message instead of failing later, mid-experiment.
parser.add_argument(
    "--model",
    type=str,
    default="GNN",
    choices=("GNN", "GIN"),
    help="Classifier architecture: GNN -> MainModel, GIN -> MainModelGIN.",
)
# NOTE(review): forwarded to run_experiments as `ratio`; presumably the pooling
# ratio used to derive DiffPool's `k` — confirm in run_experiments.
parser.add_argument("--ratio", type=float, default=0.5)
args = parser.parse_args()


def create_model(n_out, **kwargs):
    """Build a classifier that uses ``DiffPool`` as its pooling layer.

    The architecture is chosen by the ``--model`` command-line option
    (module-level ``args.model``):

    * ``"GNN"`` -> ``MainModel``
    * ``"GIN"`` -> ``MainModelGIN``

    Both are constructed with ``mask=True``.

    Args:
        n_out: Passed to the model constructor as its first argument
            (presumably the number of output classes — confirm).
        **kwargs: Only ``k`` is read. It is forwarded to ``DiffPool``;
            if absent, ``DiffPool(None)`` is constructed.

    Returns:
        The constructed model instance.

    Raises:
        ValueError: If ``args.model`` is not a supported architecture name.
    """
    model_classes = {"GNN": MainModel, "GIN": MainModelGIN}

    # Resolve the architecture before building any layer so an unknown name
    # fails immediately rather than after constructing the pooling layer.
    try:
        model_cls = model_classes[args.model]
    except KeyError:
        raise ValueError(f"Unknown model: {args.model}") from None

    pool = DiffPool(kwargs.get("k"))
    return model_cls(n_out, pool, mask=True)

#def BatchLoaderMask(dataset, batch_size):
#    return BatchLoader(dataset=dataset, batch_size=batch_size, mask=True, node_level=False)

# Label for the pooling method. It is passed both to the training loop
# (`method=`) and to `results_to_file`, so define it once to keep them in sync.
_POOLING_METHOD = "DiffPool"

# Keyword arguments for `run_experiments`, taken from the parsed CLI options.
experiment_config = {
    "runs": args.runs,
    "create_model": create_model,
    "loader_class": BatchLoader,
    "dataset_name": args.dataset,
    "learning_rate": args.lr,
    "batch_size": args.batch_size,
    "patience": args.patience,
    "method": _POOLING_METHOD,
    "model_name": args.model,
    "ratio": args.ratio,
}
results = run_experiments(**experiment_config)

# NOTE(review): `args.model` is not passed here, so GNN and GIN runs may be
# recorded under the same dataset/method label — confirm in results_to_file.
results_to_file(args.dataset, _POOLING_METHOD, *results)
