from magni.src.modules.pooling_utils import to_nx_graph
from magni.src.modules.make_splits import make_splits
from magni.src.modules.compare_graphs import choose_graph_metric
from magni.src.modules.loaders import MFreeLoader
import time
from magni.src.modules.compute_graph_magnitude import median_heuristic, compute_magnitude_graph
from magni.src.modules.structure_difference import wasserstein_distance_eigenvalues, wasserstein_distance_eigenvalues_power
from spektral.utils import laplacian
import numpy as np

def get_mag_results(inputs, dist_fn):
    """Collect magnitude, spread and spectral comparison metrics per graph pair.

    Args:
        inputs: Iterable of ``(x, a_original, a_pooled)`` triples. ``x`` is
            paired with each adjacency in ``to_nx_graph``; the adjacencies
            are also passed to ``laplacian`` on their own.
        dist_fn: Distance function forwarded unchanged to every
            ``compute_magnitude_graph`` call.

    Returns:
        dict: Maps each metric name to a list with one entry per triple in
        ``inputs``, in iteration order. All lists are empty when ``inputs``
        is empty.
    """
    results = {
        "magnitude": [],
        "spread": [],
        "magnitude_pooled": [],
        "spread_pooled": [],
        "mag_diffs": [],
        "spread_diffs": [],
        "n_nodes": [],
        "n_nodes_sub": [],
        "wasserstein_distance": [],
        "wasserstein_distance_normalized": [],
        "spectral_distance": [],
        "spectral_distance_normalized": [],
    }

    def _evaluate(graph, ts, method):
        # Every call in this function shares these fixed keyword arguments;
        # only the graph, the scale list and the method vary.
        return compute_magnitude_graph(
            graph, ts=ts, dist_fn=dist_fn, get_weights=False, n_ts=1, method=method
        )

    for x, a_original, a_pooled in inputs:
        # --- Magnitude / spread comparison ---
        graph_full = to_nx_graph(x, a_original)
        graph_pooled = to_nx_graph(x, a_pooled)

        # The scales returned for the first graph are reused for the second
        # one so both are evaluated at the same ts.
        mag_full, scales = _evaluate(graph_full, [1], "cholesky")
        mag_pooled, _ = _evaluate(graph_pooled, scales, "cholesky")
        mag_gap = abs(mag_full[0] - mag_pooled[0])

        node_count = graph_full.number_of_nodes()
        node_count_pooled = graph_pooled.number_of_nodes()

        spread_full, scales = _evaluate(graph_full, [1], "spread")
        spread_pooled, _ = _evaluate(graph_pooled, scales, "spread")
        spread_gap = abs(spread_full[0] - spread_pooled[0])

        results["spread_diffs"].append(spread_gap)
        results["mag_diffs"].append(mag_gap)
        results["n_nodes"].append(node_count)
        results["n_nodes_sub"].append(node_count_pooled)
        results["magnitude"].append(mag_full[0])
        results["spread"].append(spread_full[0])
        results["magnitude_pooled"].append(mag_pooled[0])
        results["spread_pooled"].append(spread_pooled[0])

        # --- Spectral comparison on the graph Laplacians ---
        lap_pooled = laplacian(a_pooled)
        lap_full = laplacian(a_original)
        wd_norm, sp_norm = wasserstein_distance_eigenvalues_power(
            lap_full, lap_pooled, a_original, a_pooled, power=-0.5
        )
        wd_plain, sp_plain = wasserstein_distance_eigenvalues(lap_full, lap_pooled)

        results["wasserstein_distance"].append(float(wd_plain))
        results["wasserstein_distance_normalized"].append(float(wd_norm))
        results["spectral_distance"].append(float(sp_plain))
        results["spectral_distance_normalized"].append(float(sp_norm))

    return results

import tensorflow as tf

def graph_structure_eval(dataset):
    """Run ``get_mag_results`` over a dataset of graphs.

    Args:
        dataset: Iterable of graph objects exposing ``x``, ``a`` and ``a_1``
            attributes; ``a`` is used as the original adjacency and ``a_1``
            as the pooled adjacency.

    Returns:
        dict: The metrics dictionary returned by ``get_mag_results``.
    """
    # The metric is resolved before the dataset is traversed.
    metric = choose_graph_metric("diffusion_distance", mode="structure")
    triples = [(graph.x, graph.a, graph.a_1) for graph in dataset]
    return get_mag_results(triples, metric)

def _to_dense_numpy(value, nt, squeeze=True):
    """Convert ``value`` to a NumPy array, densifying ``tf.SparseTensor`` input.

    Sparse tensors are densified and returned via ``.numpy()``. Other inputs
    go through ``np.array``; they are squeezed first only when ``nt`` is
    falsy and ``squeeze`` is true.
    """
    if isinstance(value, tf.SparseTensor):
        return tf.sparse.to_dense(value).numpy()
    if nt or not squeeze:
        return np.array(value)
    return np.array(np.squeeze(value))


def graph_structure_eval_trainable(dataset, nt=False, path_to_save_graphs=None):
    """Run ``get_mag_results`` over the items of a trainable-pooling dataset.

    Args:
        dataset: Iterable of items where ``item[0]`` is an indexable holding
            four entries, bound here to ``a_original``, ``a_pooled``,
            ``x_original`` and ``x_pooled`` in index order, and ``item[1]``
            is the target.
        nt: When truthy, non-sparse inputs are converted with ``np.array``
            without being squeezed first.
        path_to_save_graphs: Optional string prefix. When not ``None``, each
            item is written with ``np.savez`` to
            ``path_to_save_graphs + "graph_<i>.npz"``. The prefix is
            concatenated as-is, so it must already end with a path separator
            if it names a directory.

    Returns:
        dict: The metrics dictionary returned by ``get_mag_results``.
    """
    metric = choose_graph_metric("diffusion_distance", mode="structure")
    triples = []

    for index, item in enumerate(dataset):
        tensors = item[0]
        target = item[1]

        a_pooled = _to_dense_numpy(tensors[1], nt)
        # The target is never squeezed, whatever ``nt`` is.
        target = _to_dense_numpy(target, nt, squeeze=False)
        a_original = _to_dense_numpy(tensors[0], nt)
        x_original = _to_dense_numpy(tensors[2], nt)
        x_pooled = _to_dense_numpy(tensors[3], nt)

        print(a_original.shape, a_pooled.shape)

        # Node features are replaced by an identity matrix sized from a_pooled.
        identity_features = np.eye(a_pooled.shape[0])
        # NOTE(review): the tuple is stored as (x, a_pooled, a_original), while
        # get_mag_results unpacks each item as (x, a_original, a_pooled), so
        # the two adjacencies swap roles downstream. Confirm this is intended.
        triples.append((identity_features, a_pooled, a_original))

        if path_to_save_graphs is None:
            continue

        # Persist the converted arrays for this item as a .npz archive.
        npz_path = path_to_save_graphs + "graph_" + str(int(index)) + ".npz"
        np.savez(
            npz_path,
            i=index,
            x_original=x_original,
            x_pooled=x_pooled,
            a_original=a_original,
            a_pooled=a_pooled,
            target=target,
        )

    return get_mag_results(triples, metric)