from scipy.stats import wasserstein_distance
import numpy as np
import tensorflow as tf
from spektral.utils import degree_power, laplacian
import scipy.sparse as sp

def _to_dense_array(x):
    """Return *x* as a dense :class:`numpy.ndarray`.

    Handles, in this order: NumPy arrays (returned as-is), objects with
    ``toarray()`` (e.g. SciPy sparse matrices), objects with ``numpy()``
    (e.g. eager ``tf.Tensor``), objects with ``to_dense()``,
    ``tf.SparseTensor``, and finally anything ``np.asarray`` accepts
    (e.g. nested lists).
    """
    if isinstance(x, np.ndarray):
        return x
    if hasattr(x, "toarray"):
        return np.asarray(x.toarray())
    if hasattr(x, "numpy"):
        return np.asarray(x.numpy())
    if hasattr(x, "to_dense"):
        return np.asarray(x.to_dense().numpy())
    if isinstance(x, tf.SparseTensor):
        return tf.sparse.to_dense(x).numpy()
    return np.asarray(x)


def _check_square_matrix(M, description):
    """Return *M* unchanged, raising ``ValueError`` unless it is square 2D.

    ``eigvalsh`` would accept stacked (batched) matrices, but the padding
    and ``len()`` logic in ``_spectral_distances`` only works for a single
    matrix, so exactly two equal dimensions are required.
    """
    if M.ndim != 2 or M.shape[0] != M.shape[1]:
        raise ValueError(
            f"{description} should be a square 2D matrix, got shape {M.shape}."
        )
    return M


def _spectral_distances(eigs1, eigs2):
    """Compare two 1-D eigenvalue arrays of possibly different lengths.

    The shorter array is zero-padded to the length of the longer one.

    Returns
    -------
    distance : float
        1-D Wasserstein distance between the two padded value sets.
    spectral_distance : float
        Euclidean norm of the difference of the sorted, padded vectors.
    """
    n = max(len(eigs1), len(eigs2))
    # Padding appends zeros after the (ascending) eigenvalues, so re-sort to
    # match values rank by rank for the Euclidean comparison.
    eigs1 = np.sort(np.pad(eigs1, (0, n - len(eigs1))))
    eigs2 = np.sort(np.pad(eigs2, (0, n - len(eigs2))))
    # The Wasserstein distance ignores ordering, so sorting first is harmless.
    distance = wasserstein_distance(eigs1, eigs2)
    spectral_distance = np.linalg.norm(eigs1 - eigs2)
    return distance, spectral_distance


def wasserstein_distance_eigenvalues(L, L_pool):
    """Compare the spectra of a Laplacian and its pooled counterpart.

    Parameters
    ----------
    L, L_pool : array-like, scipy.sparse matrix, tf.Tensor or tf.SparseTensor
        Square matrices (presumably graph Laplacians). ``np.linalg.eigvalsh``
        reads only the lower triangle, so they are assumed to be symmetric.
        They may have different sizes.

    Returns
    -------
    distance : float
        1-D Wasserstein distance between the zero-padded eigenvalue sets.
    spectral_distance : float
        Euclidean norm of the difference between the sorted, zero-padded
        eigenvalue vectors.

    Raises
    ------
    ValueError
        If either input is not a square 2D matrix.
    """
    L = _check_square_matrix(_to_dense_array(L), "Laplacian matrix L")
    L_pool = _check_square_matrix(
        _to_dense_array(L_pool), "Pooled Laplacian matrix L_pool"
    )
    return _spectral_distances(np.linalg.eigvalsh(L), np.linalg.eigvalsh(L_pool))


def wasserstein_distance_eigenvalues_power(L, L_pool, A, A_pool, power=-0.5):
    """Compare degree-scaled Laplacian spectra of a graph and its pooling.

    For each adjacency matrix the matrix ``D @ laplacian(A) @ D`` is formed,
    where ``D = degree_power(A, power)``. With spektral's ``laplacian``
    (degree matrix minus ``A``), the default ``power=-0.5`` yields the
    symmetric normalized Laplacian. Its eigenvalues are then compared as in
    ``wasserstein_distance_eigenvalues``.

    Parameters
    ----------
    L, L_pool : array-like, scipy.sparse matrix, tf.Tensor or tf.SparseTensor
        Kept for signature compatibility. They are only validated and do
        NOT enter the computation; the Laplacians are rebuilt from ``A`` and
        ``A_pool``.
    A, A_pool : array-like, scipy.sparse matrix, tf.Tensor or tf.SparseTensor
        Square adjacency matrices of the original and pooled graphs
        (assumed symmetric, since only the lower triangle is read).
    power : float, default -0.5
        Exponent passed to ``degree_power``.

    Returns
    -------
    distance : float
        1-D Wasserstein distance between the zero-padded eigenvalue sets.
    spectral_distance : float
        Euclidean norm of the difference between the sorted, zero-padded
        eigenvalue vectors.

    Raises
    ------
    ValueError
        If any input is not a square 2D matrix.
    """
    # L / L_pool are unused, but invalid values keep raising ValueError.
    _check_square_matrix(_to_dense_array(L), "Laplacian matrix L")
    _check_square_matrix(_to_dense_array(L_pool), "Pooled Laplacian matrix L_pool")
    A = _check_square_matrix(_to_dense_array(A), "Adjacency matrix A")
    A_pool = _check_square_matrix(
        _to_dense_array(A_pool), "Pooled adjacency matrix A_pool"
    )

    D = degree_power(A, power)
    D_pool = degree_power(A_pool, power)
    L_scaled = D @ laplacian(A) @ D
    L_pool_scaled = D_pool @ laplacian(A_pool) @ D_pool

    # Every eigenvalue is needed. ARPACK's eigsh cannot return all N of them
    # (it warns and falls back to dense eigh), so use the dense solver
    # directly.
    return _spectral_distances(
        np.linalg.eigvalsh(L_scaled), np.linalg.eigvalsh(L_pool_scaled)
    )
