import argparse

from magni.src.graph_classification.training_nt import (results_to_file,
                                                 run_experiments)

from magni.src.modules.pooling_utils import to_nx_graph
from magni.src.modules.compare_graphs import choose_graph_metric
import networkx as nx
import tensorflow as tf

from magni.src.edge_dropping_magnitude import edge_pooling_magnitude, edge_pooling_magnitude_repeated

# Command-line options as (flag, type, default), registered in this order.
# Flag spellings are part of the CLI contract and are kept verbatim, including
# the mix of "-" and "_" separators (argparse maps "--batch-size" to
# ``args.batch_size``).
_CLI_OPTIONS = (
    ("--dataset", str, "PROTEINS"),
    ("--lr", float, 5e-4),
    ("--batch-size", int, 32),
    ("--patience", int, 50),
    ("--runs", int, 3),
    ("--metric", str, "diffusion_distance"),
    ("--mag_method", str, "cholesky"),
    ("--model", str, "GNN"),
    ("--ratio", float, 0.5),
)

parser = argparse.ArgumentParser()
for _flag, _kind, _default in _CLI_OPTIONS:
    parser.add_argument(_flag, type=_kind, default=_default)
args = parser.parse_args()

# Switch TensorFlow tensors to NumPy-style semantics for the rest of the run.
tf.experimental.numpy.experimental_enable_numpy_behavior()

def mag_pool(X, A, **kwargs):
    """Pool a single graph with magnitude-guided edge dropping.

    The graph described by ``X`` and ``A`` is converted with ``to_nx_graph``
    and reduced by ``edge_pooling_magnitude`` (``ratio >= 0.5``) or
    ``edge_pooling_magnitude_repeated`` (``ratio < 0.5``).

    Args:
        X: Node features of one graph (forwarded to ``to_nx_graph``).
        A: Adjacency of the same graph; ``A.shape[0]`` is used as the node
            count when computing the number of pooling steps.
        **kwargs: Optional overrides.
            ratio (float, default 0.5): the number of pooling steps is
                ``int(round((1 - ratio) * A.shape[0]))``. Presumably the
                fraction of nodes kept; TODO confirm against the pooling
                routines.
            metric (str, default ``args.metric``): metric name passed to
                ``choose_graph_metric(..., mode="structure")``.
            mag_method (str, default ``args.mag_method``): ``method``
                argument of the pooling routine.
            ts (list, default ``[1]``): ``ts`` argument of the pooling routine.

    Returns:
        Tuple ``(A, X, A_out, mask)``:
        - the unchanged inputs, in that order (adjacency first; note this
          differs from ``pooling``'s ``(X, A, ...)`` order);
        - ``nx.to_numpy_array`` of the pooled graph;
        - the transpose of the ``S`` value returned by the pooling routine.
    """
    ratio = kwargs.get("ratio", 0.5)
    metric = kwargs.get("metric", args.metric)
    mag_method = kwargs.get("mag_method", args.mag_method)
    ts = kwargs.get("ts", [1])

    n_steps = int(round((1 - ratio) * A.shape[0]))

    g = to_nx_graph(X, A)
    dist_fn = choose_graph_metric(metric, mode="structure")

    # NOTE(review): presumably the single-pass routine is only suitable when
    # at most half of the graph is pooled away, hence the repeated variant
    # for small ratios -- confirm with the edge_dropping_magnitude module.
    if ratio >= 0.5:
        pool_fn = edge_pooling_magnitude
    else:
        pool_fn = edge_pooling_magnitude_repeated

    # Only the pooled graph and S are needed; the other outputs are discarded.
    g_sub, _, _, S, _ = pool_fn(
        g=g, ts=ts, dist_fn=dist_fn, n_steps=n_steps, method=mag_method
    )

    A_out = nx.to_numpy_array(g_sub)
    mask = S.T

    return A, X, A_out, mask

def pooling(X, A, ratio=0.5):
    """Apply ``mag_pool`` to each graph of a collection.

    Args:
        X: Iterable of per-graph node-feature matrices.
        A: Iterable of per-graph adjacency matrices, aligned with ``X``.
            Iteration uses ``zip``, so it stops at the shorter of the two.
        ratio: Forwarded unchanged to ``mag_pool`` for every graph.

    Returns:
        Tuple ``(X, A, A_out, S_out)``: the input objects as given, plus two
        lists holding ``mag_pool``'s pooled adjacency and mask for each graph,
        in input order.
    """
    A_out = []
    S_out = []
    for x, a in zip(X, A):
        # mag_pool echoes its inputs in the first two slots; only the pooled
        # adjacency and mask are collected here.
        _, _, a_out, s_out = mag_pool(x, a, ratio=ratio)
        A_out.append(a_out)
        S_out.append(s_out)

    return X, A, A_out, S_out

# Label prefix for the pooling variant: "SPREAD" for the spread method,
# "MAG" for every other magnitude computation method.
pre = "SPREAD" if args.mag_method == "spread" else "MAG"

# Identifier of this configuration, shared by the experiment run and the
# results file.
method_name = pre + "_EDGE_" + args.metric

results = run_experiments(
    runs=args.runs,
    pooling=pooling,
    dataset_name=args.dataset,
    learning_rate=args.lr,
    batch_size=args.batch_size,
    patience=args.patience,
    method=method_name,
    model_name=args.model,
    ratio=args.ratio,
)
results_to_file(args.dataset, method_name, *results)
