import numpy as np
import tensorflow as tf
from spektral.layers import ops


def quadratic_loss(x, x_pool, L, L_pool):
    """Mean absolute gap between Laplacian quadratic forms (NumPy version).

    For every column ``i`` this compares ``x[:, i]^T L x[:, i]`` with
    ``x_pool[:, i]^T L_pool x_pool[:, i]`` and averages the absolute
    differences over columns.

    Parameters
    ----------
    x : np.ndarray
        Signal on the original graph, 1-D (single signal) or 2-D
        (one signal per column).
    x_pool : np.ndarray
        Signal on the pooled graph; promoted to 2-D together with ``x``
        when ``x`` is 1-D.
    L, L_pool : matrix-like
        Matrices exposing ``.dot`` (e.g. NumPy arrays; presumably also
        SciPy sparse matrices — confirm with callers).

    Returns
    -------
    float
        Mean over columns of ``|x^T L x - x_pool^T L_pool x_pool|``.
    """

    def _energy_matrix(signal, laplacian):
        # Full (F, F) matrix; only its diagonal is used below.
        return np.dot(signal.T, laplacian.dot(signal))

    if x.ndim == 1:
        # Treat a single signal as a one-column matrix.
        x, x_pool = x[:, None], x_pool[:, None]

    gap = np.abs(_energy_matrix(x, L) - _energy_matrix(x_pool, L_pool))
    return np.mean(np.diag(gap))


def quadratic_loss_tf(x, x_pool, L, L_pool):
    """Mean absolute gap between Laplacian quadratic forms (TensorFlow version).

    TensorFlow counterpart of ``quadratic_loss``: for every column ``i`` it
    compares ``x[:, i]^T L x[:, i]`` with ``x_pool[:, i]^T L_pool x_pool[:, i]``
    and averages the absolute differences.

    Parameters
    ----------
    x, x_pool : tf.Tensor
        Signals on the original / pooled graph, 1-D or 2-D (one signal per
        column). A 1-D ``x`` promotes both to a single column.
    L, L_pool : tf.Tensor or tf.SparseTensor
        Passed to ``spektral.layers.ops.dot`` (presumably handles both dense
        and sparse operands — confirm against the spektral version in use).

    Returns
    -------
    tf.Tensor
        Scalar mean of the absolute differences on the diagonal.
    """

    def _energy_matrix(signal, laplacian):
        # Full (F, F) matrix; only its diagonal is used below.
        return tf.matmul(tf.transpose(signal), ops.dot(laplacian, signal))

    if len(x.shape) == 1:
        # Treat a single signal as a one-column matrix.
        x, x_pool = x[:, None], x_pool[:, None]

    gap = tf.abs(_energy_matrix(x, L) - _energy_matrix(x_pool, L_pool))
    return tf.reduce_mean(tf.linalg.diag_part(gap))


def diffusion_distance_tf(L, t=1, **kwargs):
    """Calculate diffusion distance between vertices of a graph using TensorFlow.

    The symmetric part of ``L`` is eigendecomposed as ``V diag(lam) V^T``.
    Vertex ``i`` is embedded as ``psi_i = lam**t * V[i, :]``, i.e. eigenvector
    column ``k`` is scaled by ``lam_k**t``. Entry ``(i, j)`` of the result is
    the Euclidean distance ``||psi_i - psi_j||``.

    NOTE(review): for a graph Laplacian the usual diffusion map weights
    eigenvectors by ``exp(-t * lam)`` (or uses eigenvalues of a transition
    matrix). ``lam**t`` instead emphasises high-frequency components, so
    confirm this is intended. A non-integer ``t`` combined with slightly
    negative eigenvalues from round-off would produce NaN.

    Parameters
    ----------
    L : Laplacian matrix (Tensor, SparseTensor, or array-like)
        Square matrix. Inputs that are neither ``tf.Tensor`` nor
        ``tf.SparseTensor`` are converted to ``tf.float32``. Only the
        symmetric part ``(L + L^T) / 2`` is used.
    t : int
        Time parameter for diffusion (exponent applied to the eigenvalues).
    **kwargs
        Ignored; accepted for call-site compatibility.

    Returns
    -------
    tf.Tensor
        Square matrix of distance values, with zeros wherever two embeddings
        coincide (in particular on the diagonal).

    Raises
    ------
    tf.errors.InvalidArgumentError
        If ``L`` contains NaN/Inf, or if the eigendecomposition fails.
    """
    if isinstance(L, tf.SparseTensor):
        L_tf = tf.sparse.to_dense(L)
    elif isinstance(L, tf.Tensor):
        L_tf = L
    else:
        L_tf = tf.convert_to_tensor(L, dtype=tf.float32)

    # Keep only the symmetric part so that ``eigh`` is applicable.
    L_tf = (L_tf + tf.transpose(L_tf)) / 2

    # Use the returned tensor so the check is an explicit data dependency of
    # everything downstream. A discarded result may be pruned in graph mode.
    L_tf = tf.debugging.check_numerics(
        L_tf, "Laplacian matrix contains invalid values"
    )

    # eigh returns eigenvectors as columns: eigenvectors[:, k] <-> eigenvalues[k].
    # Any decomposition error propagates unchanged to the caller.
    eigenvalues, eigenvectors = tf.linalg.eigh(L_tf)

    # Broadcasting scales column k by eigenvalues[k]**t; row i embeds vertex i.
    psi = tf.multiply(tf.pow(eigenvalues, t), eigenvectors)

    # No-op for real input; kept for inputs that are not real-valued.
    psi = tf.math.real(psi)

    # Pairwise differences of all embeddings. Note this materialises an
    # (N, N, N) tensor, so memory grows cubically with the number of vertices.
    diff = tf.expand_dims(psi, axis=1) - tf.expand_dims(psi, axis=0)
    sq_dist = tf.reduce_sum(tf.square(diff), axis=-1)

    # tf.norm / tf.sqrt have a NaN gradient at zero, which is always reached on
    # the diagonal. Route zero entries through a dummy value so the gradient
    # stays finite, while the forward values match sqrt(sq_dist) exactly.
    nonzero = sq_dist > 0
    safe_sq_dist = tf.where(nonzero, sq_dist, tf.ones_like(sq_dist))
    distances = tf.where(nonzero, tf.sqrt(safe_sq_dist), tf.zeros_like(sq_dist))

    return distances


def mag_loss_tf(x, x_pool, L, L_pool):
    """Absolute difference between the "mag" scores of two graphs.

    For each graph, the diffusion distances ``D`` (from
    ``diffusion_distance_tf`` with its default ``t``) are turned into
    similarities ``Z = exp(-D)``. The score is ``sum_j 1 / sum_i Z[i, j]``.
    The loss is the absolute difference between the scores for ``L`` and
    ``L_pool``.

    NOTE(review): this looks like an approximation of metric-space magnitude
    (the exact value would be the sum of the entries of ``Z^{-1}``). Confirm
    that the reciprocal-column-sum form is intended.

    Parameters
    ----------
    x, x_pool : object
        Unused. Kept so the signature matches ``quadratic_loss_tf``.
    L, L_pool : Laplacian matrices
        Accepted by ``diffusion_distance_tf``.

    Returns
    -------
    tf.Tensor
        Scalar absolute difference between the two scores.
    """

    def _mag_score(laplacian):
        similarity = tf.exp(-diffusion_distance_tf(laplacian))
        return tf.reduce_sum(1 / tf.reduce_sum(similarity, axis=0))

    original_score = _mag_score(L)
    pooled_score = _mag_score(L_pool)
    return tf.reduce_sum(tf.abs(original_score - pooled_score))
