import argparse
from magni.src.spectral_similarity.training_nt import (run_experiment)

from magni.src.modules.pooling_utils import to_nx_graph
from magni.src.modules.compare_graphs import choose_graph_metric
from magni.src.modules.compute_graph_magnitude import median_heuristic, compute_magnitude_graph
import networkx as nx

from magni.src.edge_dropping_magnitude import edge_pooling_magnitude

# Command-line interface. Every option is optional and falls back to the
# default given here.
parser = argparse.ArgumentParser()
# Forwarded to ``run_experiment`` as ``name``.
parser.add_argument(
    "--name",
    default="Grid2d",
    type=str,
)
# Forwarded to ``run_experiment`` as ``runs``.
parser.add_argument(
    "--runs",
    default=3,
    type=int,
)
# Metric name resolved through ``choose_graph_metric`` inside ``pooling``.
parser.add_argument(
    "--metric",
    default="diffusion_distance",
    type=str,
)
# Passed to ``edge_pooling_magnitude`` as ``method``; the value "spread" also
# switches the experiment label prefix from "MAG" to "SPREAD".
parser.add_argument(
    "--mag_method",
    default="cholesky",
    type=str,
)
args = parser.parse_args()

def pooling(X, A, ratio=0.5):
    """Pool a graph with magnitude-based edge pooling.

    Builds a graph from ``X`` and ``A`` via ``to_nx_graph``, prints it, and
    runs ``edge_pooling_magnitude`` on it using the metric and magnitude
    method selected on the command line (``args.metric``, ``args.mag_method``).

    Args:
        X: Forwarded to ``to_nx_graph`` and returned unchanged
            (presumably the node-feature matrix -- confirm against caller).
        A: Forwarded to ``to_nx_graph`` and returned unchanged; ``A.shape[0]``
            is used to derive the number of pooling steps.
        ratio: The step count is ``int(round((1 - ratio) * A.shape[0]))``
            (presumably the fraction of the graph to keep -- TODO confirm).

    Returns:
        A 4-tuple ``(A, X, A_out, mask_t)`` where ``A_out`` is the dense
        adjacency array of the pooled graph (``nx.to_numpy_array``) and
        ``mask_t`` is the transpose of the fourth value returned by
        ``edge_pooling_magnitude``.
    """
    num_rows = A.shape[0]
    step_count = int(round((1 - ratio) * num_rows))

    graph = to_nx_graph(X, A)
    print(graph)

    metric_fn = choose_graph_metric(args.metric, mode="structure")

    # Only the pooled graph (1st value) and the selection matrix (4th value)
    # of the five returned values are used here.
    pooled_graph, _, _, selection, _ = edge_pooling_magnitude(
        g=graph,
        ts=[1],
        dist_fn=metric_fn,
        n_steps=step_count,
        method=args.mag_method,
    )

    return A, X, nx.to_numpy_array(pooled_graph), selection.T

# Label prefix for the experiment: "SPREAD" when the spread method is
# selected, "MAG" for every other magnitude method.
pre = "SPREAD" if args.mag_method == "spread" else "MAG"

# Run the experiment with the edge-pooling callback defined above; the method
# label encodes both the prefix and the chosen metric.
results = run_experiment(
    name=args.name,
    method=f"{pre}_EDGE_{args.metric}",
    pooling=pooling,
    runs=args.runs,
)
