import argparse

from magni.src.graph_classification.training_nt import (results_to_file,
                                                  run_experiments)
from magni.src.modules.nmf import NMF, preprocess
from magni.src.modules.repeat_pooling import repeat_pooling

# Command-line options for this NMF pooling experiment.
# Each entry is (flag, type, default); every parsed value is forwarded to
# run_experiments (argparse maps "--batch-size" to the attribute batch_size).
parser = argparse.ArgumentParser()
_CLI_OPTIONS = (
    ("--dataset", str, "PROTEINS"),
    ("--lr", float, 5e-4),
    ("--batch-size", int, 32),
    ("--patience", int, 50),
    ("--runs", int, 3),
    ("--model", str, "GNN"),
    ("--ratio", float, 0.5),
)
for _flag, _type, _default in _CLI_OPTIONS:
    parser.add_argument(_flag, type=_type, default=_default)
args = parser.parse_args()


def pooling_once(X, A, ratio=0.5):
    """Run a single NMF pooling step over a batch of graphs.

    Every ``(x, a)`` pair from ``zip(X, A)`` is passed through ``preprocess``.
    Only the second component of each result is kept, and that collection is
    passed to ``NMF`` together with the fixed value 0.5.

    Args:
        X: per-graph inputs, paired element-wise with ``A``.
        A: per-graph inputs (presumably adjacency matrices; confirm).
        ratio: accepted but ignored, because ``NMF`` is always called with 0.5.
            NOTE(review): presumably kept so the signature matches what
            ``repeat_pooling`` expects. Confirm.

    Returns:
        Tuple ``(X, A, A_pool, S)``. ``X`` and ``A`` are the inputs, unchanged.
        ``A_pool`` and ``S`` are the two values returned by ``NMF``.
        NOTE(review): the original ``A`` is returned, not the preprocessed
        values. Confirm this is intended.
    """
    preprocessed_pairs = [preprocess(x_i, a_i) for x_i, a_i in zip(X, A)]
    # Transpose the list of pairs. The first component is discarded.
    _, adjacency_in = zip(*preprocessed_pairs)
    A_pooled, S_nmf = NMF(adjacency_in, 0.5)
    return X, A, A_pooled, S_nmf

def pooling(X, A, **kwargs):
    """Pool a batch of graphs with NMF at the requested ratio.

    A ratio of exactly 0.5 is handled by a single ``pooling_once`` call.
    A ratio strictly between 0 and 0.5 is delegated to ``repeat_pooling``,
    with ``pooling_once`` as the step function.

    Args:
        X: per-graph inputs, forwarded unchanged.
        A: per-graph inputs, forwarded unchanged.
        **kwargs: only ``ratio`` is read. It defaults to 0.5, the same default
            ``pooling_once`` uses.

    Returns:
        Whatever ``pooling_once`` or ``repeat_pooling`` returns.

    Raises:
        ValueError: if ``ratio`` is above 0.5 or is not positive.
    """
    ratio = kwargs.get("ratio", 0.5)
    # The threshold must match the error message. Previously 0.54 let ratios
    # in (0.5, 0.54] slip through to repeat_pooling.
    if ratio > 0.5:
        raise ValueError("Pooling ratios above 0.5 are not supported.")
    if ratio <= 0:
        raise ValueError("Pooling ratio must be positive.")
    if ratio == 0.5:
        return pooling_once(X, A)
    return repeat_pooling(X, A, ratio=ratio, pooling=pooling_once)


# Run the configured experiment with NMF pooling and persist the results.
method_name = "NMF"
experiment_kwargs = {
    "runs": args.runs,
    "pooling": pooling,
    "dataset_name": args.dataset,
    "learning_rate": args.lr,
    "batch_size": args.batch_size,
    "patience": args.patience,
    "method": method_name,
    "model_name": args.model,
    "ratio": args.ratio,
}
results = run_experiments(**experiment_kwargs)
# run_experiments' return value is unpacked positionally into results_to_file.
results_to_file(args.dataset, method_name, *results)
