import argparse

from magni.src.graph_classification.training_nt import (results_to_file,
                                                  run_experiments)

from magni.src.modules.pooling_utils import to_nx_graph,  edge_dropping_random
#from magni.src.compute_graph_magnitude import compute_magnitude_subgraphs
import networkx as nx
import numpy as np
import numpy as np
import scipy.sparse as sp

# Command-line options as (flag, type, default). argparse derives each
# attribute name on ``args`` from the flag: leading dashes are stripped and
# inner dashes become underscores (so "--batch-size" -> ``args.batch_size``).
# NOTE(review): "--metric" and "--mag_method" are parsed but not referenced in
# the visible part of this script -- confirm whether they are still needed.
_CLI_OPTIONS = (
    ("--dataset", str, "PROTEINS"),
    ("--lr", float, 5e-4),
    ("--batch-size", int, 32),
    ("--patience", int, 50),
    ("--runs", int, 3),
    ("--metric", str, "diffusion_distance"),
    ("--mag_method", str, "cholesky"),
    ("--model", str, "GNN"),
    ("--ratio", float, 0.5),
)


def _build_parser():
    """Return an ArgumentParser with every option in ``_CLI_OPTIONS`` registered."""
    cli = argparse.ArgumentParser()
    for flag, kind, default in _CLI_OPTIONS:
        cli.add_argument(flag, type=kind, default=default)
    return cli


parser = _build_parser()
args = parser.parse_args()

def mag_pool(X, A, ratio=0.5):
    """Pool a single graph by random edge dropping.

    Converts ``(X, A)`` to a networkx graph with ``to_nx_graph``. It then runs
    ``edge_dropping_random`` for ``round((1 - ratio) * A.shape[0])`` steps and
    densifies the resulting graph with ``nx.to_numpy_array``.

    Parameters
    ----------
    X : array-like
        Node features for one graph. They are passed straight to
        ``to_nx_graph``.
    A : array-like
        Adjacency for one graph. Presumably square with ``A.shape[0]`` nodes,
        so the step count scales with graph size (TODO confirm).
    ratio : float
        The number of dropping steps is ``round((1 - ratio) * A.shape[0])``.

    Returns
    -------
    tuple
        ``(A, X, A_out, S.T)``. Note that the order is A then X, the reverse
        of ``pooling``'s return. ``A_out`` is the dense adjacency of the
        reduced graph. ``S`` is the fourth value returned by
        ``edge_dropping_random``, presumably a node selection/assignment
        matrix (verify against that helper).
    """
    num_rows = A.shape[0]
    steps = int(round((1 - ratio) * num_rows))

    graph = to_nx_graph(X, A)

    # Only the reduced graph and the selection matrix are used here; the
    # helper's other two outputs are discarded.
    reduced_graph, _result, _removed, selection = edge_dropping_random(
        g=graph, n_steps=steps
    )

    pooled_adjacency = nx.to_numpy_array(reduced_graph)
    return A, X, pooled_adjacency, selection.T


def pooling(X, A, ratio=0.5, *, verbose=False):
    """Apply ``mag_pool`` to every graph in a batch.

    Parameters
    ----------
    X : iterable
        Per-graph node feature matrices. They are zipped with ``A``, so
        iteration stops at the shorter of the two.
    A : iterable
        Per-graph adjacency matrices, aligned with ``X``.
    ratio : float
        Forwarded unchanged to ``mag_pool`` for each graph.
    verbose : bool, keyword-only
        When True, print each adjacency's ``shape`` before pooling it. This
        was previously unconditional debug output, emitted for every graph
        of every batch.

    Returns
    -------
    tuple
        ``(X, A, A_out, S_out)``. ``X`` and ``A`` are the inputs, returned
        unchanged. ``A_out`` and ``S_out`` are lists holding, per graph, the
        third and fourth values returned by ``mag_pool``.
    """
    A_out = []
    S_out = []
    for x, a in zip(X, A):
        if verbose:
            print(a.shape)
        _, _, a_out, s_out = mag_pool(x, a, ratio=ratio)
        A_out.append(a_out)
        S_out.append(s_out)

    return X, A, A_out, S_out

# Label for this random-edge-dropping baseline. It is used both when running
# the experiments and when writing their results.
METHOD_NAME = "RAND_EDGE"

# Keyword arguments forwarded to run_experiments. Everything except the
# pooling callable and the method label comes from the parsed CLI options.
experiment_kwargs = {
    "runs": args.runs,
    "pooling": pooling,
    "dataset_name": args.dataset,
    "learning_rate": args.lr,
    "batch_size": args.batch_size,
    "patience": args.patience,
    "method": METHOD_NAME,
    "model_name": args.model,
    "ratio": args.ratio,
}

results = run_experiments(**experiment_kwargs)
results_to_file(args.dataset, METHOD_NAME, *results)