import argparse
from magni.src.modules.pooling_utils import to_nx_graph
from magni.src.modules.compare_graphs import choose_graph_metric
import networkx as nx
import numpy as np
from magni.src.autoencoder.training import results_to_file
from magni.src.autoencoder.training_nt import run_experiment
from magni.src.edge_dropping_magnitude import edge_pooling_magnitude

# Command-line options for the magnitude-based edge-pooling experiment.
# Each entry is (flag, type, default); every option is optional and falls
# back to its default when omitted.
_CLI_OPTIONS = (
    ("--name", str, "Grid2d"),
    ("--runs", int, 3),
    ("--lr", float, 5e-4),
    ("--patience", int, 1000),
    ("--tol", float, 1e-6),
    ("--metric", str, "diffusion_distance"),
    ("--mag_method", str, "cholesky"),
)

parser = argparse.ArgumentParser()
for _flag, _kind, _default in _CLI_OPTIONS:
    parser.add_argument(_flag, type=_kind, default=_default)
# Parsed once at import time; `pooling` and the experiment driver read it.
args = parser.parse_args()

def pooling(X, A, ratio=0.5):
    """Pool a graph by magnitude-guided edge removal.

    Converts ``(X, A)`` to a networkx graph and runs ``edge_pooling_magnitude``
    on it. The distance function and magnitude method come from the
    module-level command-line ``args`` (``args.metric`` and
    ``args.mag_method``), not from parameters.

    Args:
        X: Node features; forwarded to ``to_nx_graph`` and returned unchanged.
        A: Adjacency; forwarded to ``to_nx_graph`` and returned unchanged.
            ``A.shape[0]`` sets the step budget.
        ratio: The number of pooling steps is
            ``round((1 - ratio) * A.shape[0])``.

    Returns:
        Tuple ``(A, X, A_out, mask)``:
            ``A`` and ``X`` are the inputs, unchanged.
            ``A_out`` is the dense numpy adjacency of the pooled graph.
            ``mask`` is the transpose of the ``S`` output of
            ``edge_pooling_magnitude``.
    """
    # NOTE(review): a ratio outside [0, 1] yields a negative or oversized step
    # count; behavior then depends on edge_pooling_magnitude — confirm.
    n_steps = int(round((1 - ratio) * A.shape[0]))

    g = to_nx_graph(X, A)
    print(g)  # NOTE(review): debug output on every call; consider logging.
    dist_fn = choose_graph_metric(args.metric, mode="structure")

    # Passed as the `ts` argument; presumably the magnitude scale(s) — confirm.
    ts = [1]

    # Only the pooled graph and S are used; the other outputs are discarded.
    g_sub, _, _, S, _ = edge_pooling_magnitude(
        g=g, ts=ts, dist_fn=dist_fn, n_steps=n_steps, method=args.mag_method
    )

    A_out = nx.to_numpy_array(g_sub)
    return A, X, A_out, S.T

# Label prefix: "SPREAD" for the spread magnitude method, "MAG" otherwise.
pre = "SPREAD" if args.mag_method == "spread" else "MAG"

# Build the method label once so the experiment run and the results file
# always agree on it.
method_name = f"{pre}_EDGE_{args.metric}"

results = run_experiment(
    name=args.name,
    method=method_name,
    pooling=pooling,
    runs=args.runs,
    learning_rate=args.lr,
    es_patience=args.patience,
    es_tol=args.tol,
)
results_to_file(args.name, method_name, *results)
