import argparse

import numpy as np

from magni.src.graph_classification.training_nt import (results_to_file,
                                                  run_experiments)
from magni.src.modules.graclus import GRACLUS
from magni.src.modules.repeat_pooling import repeat_pooling

import math

# Command-line options for one Graclus-pooling graph-classification experiment.
# ArgumentDefaultsHelpFormatter appends each option's default to its --help line.
parser = argparse.ArgumentParser(
    description="Run graph-classification experiments with Graclus pooling.",
    formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)
parser.add_argument("--dataset", type=str, default="PROTEINS",
                    help="Dataset name forwarded to run_experiments.")
parser.add_argument("--lr", type=float, default=5e-4,
                    help="Learning rate.")
parser.add_argument("--batch-size", type=int, default=32,
                    help="Mini-batch size.")
parser.add_argument("--patience", type=int, default=50,
                    help="Patience value forwarded to run_experiments.")
parser.add_argument("--runs", type=int, default=3,
                    help="Number of repeated runs.")
parser.add_argument("--model", type=str, default="GNN",
                    help="Model name forwarded to run_experiments.")
parser.add_argument("--ratio", type=float, default=0.5,
                    help="Pooling ratio; must be greater than 0 and at most 0.5.")
args = parser.parse_args()


def pooling_once(X, A):
    """Apply a single Graclus coarsening step (levels ``[0, 1]``) to every graph.

    This is exactly ``pooling_more_than_once(X, A, levels=[0, 1])``. Delegating
    keeps the node-masking and GRACLUS post-processing in one place instead of
    maintaining two identical copies.

    Note: like ``pooling_more_than_once``, this modifies the lists ``X`` and
    ``A`` in place. Nodes whose adjacency row sums to zero are dropped before
    clustering.

    Returns:
        ``(X, A, A_pool, S)`` as produced by ``pooling_more_than_once``.
    """
    return pooling_more_than_once(X, A, levels=[0, 1])

def pooling_more_than_once(X, A, levels):
    """Prune empty-row nodes from each graph, then run GRACLUS at ``levels``.

    The lists ``X`` and ``A`` are modified in place. For each graph ``idx``:
    - ``A[idx]`` is restricted to the nodes whose adjacency row sum is non-zero
      and converted back to COO format;
    - ``X[idx]`` keeps only the rows of those same nodes.

    Args:
        X: per-graph node-feature arrays, indexed here with a boolean mask.
        A: per-graph sparse adjacency matrices (presumably ``scipy.sparse``,
            given the ``tocsr``/``tocoo`` calls — TODO confirm).
        levels: coarsening levels forwarded unchanged to ``GRACLUS``. The
            two-way unpack below requires every per-graph GRACLUS result to
            hold exactly two entries (presumably one per level, so ``[0, k]``).

    Returns:
        ``(X, A, A_pool, S)``: the features returned by GRACLUS; the first and
        second entries of each graph's GRACLUS adjacency output, regrouped into
        two tuples; and the first element of each graph's ``S`` entry.
    """
    for idx, adj in enumerate(A):
        row_sums = np.array(adj.sum(-1))[:, 0]
        keep = row_sums != 0
        A[idx] = adj.tocsr()[keep, :][:, keep].tocoo()
        X[idx] = X[idx][keep]

    X, per_graph_adjs, S = GRACLUS(X, A, levels)
    # Regroup per-graph [first, second] adjacency pairs into two parallel tuples.
    A, A_pool = list(zip(*per_graph_adjs))
    S = [graph_s[0] for graph_s in S]
    return X, A, A_pool, S

def pooling(X, A, **kwargs):
    """Graclus pooling entry point used by ``run_experiments``.

    The ``ratio`` keyword selects how many Graclus levels to request:
    ``floor(-log2(ratio))``. Presumably each level roughly halves the node
    count, so ``ratio=0.25`` requests level 2.
    A ratio of exactly 0.5 uses ``pooling_once`` (a single level).

    Args:
        X: per-graph node features (forwarded unchanged).
        A: per-graph adjacency matrices (forwarded unchanged).
        **kwargs: must contain ``ratio``, a float in ``(0, 0.5]``.

    Returns:
        ``(X, A, A_pool, S)`` from ``pooling_once`` / ``pooling_more_than_once``.

    Raises:
        TypeError: if ``ratio`` is not supplied.
        ValueError: if ``ratio`` is above 0.5 or not positive.
    """
    ratio = kwargs.get("ratio")
    if ratio is None:
        raise TypeError("pooling() requires a 'ratio' keyword argument.")
    # The old threshold was 0.54. Ratios in (0.5, 0.54] then mapped to
    # level 0, i.e. no coarsening at all, contradicting the error message.
    if ratio > 0.5:
        raise ValueError("Pooling ratios above 0.5 are not supported.")
    if ratio <= 0:
        raise ValueError("Pooling ratio must be positive.")
    if ratio == 0.5:
        return pooling_once(X, A)
    # math.log2 is more accurate than math.log(ratio, 0.5), which divides two
    # rounded natural logs and can land just below an integer that int() would
    # then truncate.
    n_levels = int(-math.log2(ratio))
    return pooling_more_than_once(X, A, levels=[0, n_levels])


# Run the experiment series with Graclus pooling, then hand the results to
# results_to_file under the same dataset / method labels.
results = run_experiments(
    method="Graclus",
    pooling=pooling,
    ratio=args.ratio,
    model_name=args.model,
    dataset_name=args.dataset,
    runs=args.runs,
    learning_rate=args.lr,
    batch_size=args.batch_size,
    patience=args.patience,
)
results_to_file(args.dataset, "Graclus", *results)
