import torch


import math


def triu_to_adj(triu):
    """Expand strictly-upper-triangular pair features into a symmetric tensor.

    Args:
        triu: Tensor of shape ``[b, n(n-1)/2, d]`` holding one ``d``-dimensional
            entry per node pair ``(i, j)`` with ``i < j``, ordered like
            ``torch.triu_indices(n, n, offset=1)``.

    Returns:
        Tensor of shape ``[b, n, n, d]`` with the dtype and device of ``triu``.
        Entries ``[:, i, j]`` and ``[:, j, i]`` are equal and the diagonal is
        left at zero.
    """
    # Recover n from m = n(n-1)/2: sqrt(2m) = sqrt(n(n-1)) lies in [n-1, n),
    # so truncating and adding one yields n.
    n = int(math.sqrt(2 * triu.shape[-2])) + 1
    d = triu.shape[-1]
    # Allocate the buffer and the index tensor directly on triu's device
    # instead of building them on the CPU and copying them over.
    adj = triu.new_zeros((triu.shape[0], n, n, d))
    row, col = torch.triu_indices(n, n, offset=1, device=triu.device)
    # Write each pair feature to both (i, j) and (j, i) to symmetrise.
    adj[:, row, col] = triu
    adj[:, col, row] = triu
    # [b, n, n, d]
    return adj


def adj_to_triu(adj):
    """Gather the strictly-upper-triangular entries of a batched adjacency tensor.

    Args:
        adj: Tensor laid out as ``[b, n, n, d]``; ``n`` is read from the
            second-to-last axis.

    Returns:
        Tensor of shape ``[b, n(n-1)/2, d]`` containing ``adj[:, i, j]`` for
        every ``i < j``, in the order of ``torch.triu_indices(n, n, offset=1)``.
    """
    num_nodes = adj.shape[-2]
    upper_idx = torch.triu_indices(num_nodes, num_nodes, offset=1).to(adj.device)
    # Index axes 1 and 2 jointly with the (row, col) pairs; axis 0 is kept whole.
    return adj[:, upper_idx[0], upper_idx[1]]


def triu_to_sparse(triu):
    """Symmetrise upper-triangular entries and list the non-zero positions.

    Args:
        triu: Tensor of shape ``[b, n(n-1)/2]`` with one value per node pair
            ``(i, j)``, ``i < j``.  The dense buffer is ``torch.long``, so the
            values are presumably integer labels -- TODO confirm with callers.

    Returns:
        A 5-tuple ``(Is, Js, Ks, values, adj)``:
        ``Is``, ``Js``, ``Ks`` are the batch, row and column index of every
        non-zero entry of ``adj``; ``values`` are the entries at those
        positions; ``adj`` is the dense symmetric ``[b, n, n]`` long tensor.
    """
    batch_size, n_pairs = triu.shape[0], triu.shape[1]
    # sqrt(2m) = sqrt(n(n-1)) lies in [n-1, n), so truncate and add one.
    n_nodes = int(math.sqrt(2 * n_pairs)) + 1
    upper_idx = torch.triu_indices(n_nodes, n_nodes, offset=1).to(triu.device)
    rows, cols = upper_idx[0], upper_idx[1]

    dense = torch.zeros((batch_size, n_nodes, n_nodes), dtype=torch.long).to(
        triu.device
    )
    # Mirror every pair value across the diagonal.
    dense[:, rows, cols] = triu
    dense[:, cols, rows] = triu

    # nonzero() yields one (batch, row, col) triple per non-zero entry.
    batch_idx, row_idx, col_idx = dense.nonzero().unbind(dim=1)
    return batch_idx, row_idx, col_idx, dense[batch_idx, row_idx, col_idx], dense


def get_moment_encoding(adj, n_moments: int, return_adj: bool = False):
    r"""Compute moment encodings :math:`\mathbb{E}[z^k]` for orders ``1..n_moments``.

    Reference: https://openreview.net/pdf?id=qaJxPhkYtD

    Self-loops are added to ``adj``, each row is divided by its sum, and
    successive matrix powers of the result are formed.  The node encoding of
    order ``k`` is the diagonal of the ``k``-th power.

    NOTE(review): the previous docstring stated that ``log(a_ij + 1)`` is used
    for numerical stability; no logarithm is applied in this function.

    NOTE(review): a second definition named ``get_moment_encoding`` appears
    later in this module and replaces this one at import time; consider
    keeping a single copy.

    Args:
        adj: Batched adjacency tensor, indexed here as ``[nrows, n, n]``.
        n_moments: Number of matrix powers to compute (expected to be >= 1).
        return_adj: When true, also return the stacked matrix powers.

    Returns:
        ``x`` of shape ``[nrows, n, n_moments]`` (allocated in the default
        float dtype), or ``(x, e)`` when ``return_adj`` is true, where ``e``
        has shape ``[nrows, n, n, n_moments]``.
    """
    n_nodes = adj.size(1)
    # Build the identity directly on adj's device (no CPU round-trip).
    walk = adj + torch.eye(n_nodes, device=adj.device)
    # Row-normalise by broadcasting: same values as diag(1 / deg) @ walk, but
    # without materialising a dense diagonal matrix or paying for a matmul.
    walk = walk * (1 / walk.sum(-1)).unsqueeze(-1)

    # Every product is a fresh tensor, so no defensive clones are required.
    powers = [walk]
    for _ in range(n_moments - 1):
        powers.append(powers[-1] @ walk)

    x = torch.zeros((adj.size(0), n_nodes, n_moments), device=adj.device)
    # Slice so that n_moments == 0 still yields an empty x, as before.
    for k, power in enumerate(powers[:n_moments]):
        x[:, :, k] = torch.diagonal(power, dim1=1, dim2=2)

    if return_adj:
        # [nrows, n, n, k]
        return x, torch.stack(powers, dim=-1)
    # [nrows, n, k]
    return x


def get_lap_pe3(
    adj, n_eigs: int = 12, norm: bool = False, return_n_connected: bool = True
):
    """Laplacian eigenvector encodings plus a near-zero-eigenvalue count.

    Self-loops are added to ``adj`` and a graph Laplacian is built from the
    result: ``D - A`` by default, or ``I - D^{-1/2} A D^{-1/2}`` when ``norm``
    is true.  The Laplacian is diagonalised with ``torch.linalg.eigh``.

    Args:
        adj: Batched square adjacency tensor; the slicing below treats it as
            ``[b, n, n]`` -- TODO confirm against callers.
        n_eigs: Number of eigenvectors requested.  Must not exceed ``n``.
        norm: Use the symmetrically normalised Laplacian instead of ``D - A``.
        return_n_connected: Also return the eigenvalue-based count.

    Returns:
        ``eigs``: eigenvector columns ``1 .. n_eigs`` (column 0 is skipped).
        NOTE(review): when ``n_eigs == n`` only ``n - 1`` columns exist after
        skipping the first, so fewer than ``n_eigs`` columns are returned.
        If ``return_n_connected`` is true, a tuple ``(eigs, count)`` is
        returned, where ``count`` has a trailing axis of size 1 and equals the
        number of eigenvalues below ``1e-5`` minus the number that are exactly
        zero.  NOTE(review): presumably the exact zeros are subtracted to
        discount padding nodes -- confirm the intent.

    Raises:
        AssertionError: if ``n_eigs`` is larger than the last axis of ``adj``.
        ValueError: if the eigendecomposition produces NaNs.
    """
    n_nodes = adj.size(-1)
    assert n_eigs <= n_nodes, "n_eigs must be less than or equal to the size of adj."

    identity = torch.eye(n_nodes, device=adj.device)
    # Self-loops keep every degree strictly positive for non-negative adj.
    adj = adj + identity
    degree = torch.sum(adj, dim=-1)

    if not norm:
        # Unnormalised Laplacian.
        laplacian = torch.diag_embed(degree) - adj
    else:
        scale = torch.diag_embed(1.0 / torch.sqrt(degree))
        laplacian = identity - scale @ adj @ scale

    # Batched symmetric eigendecomposition.
    eigvals, eigvecs = torch.linalg.eigh(laplacian)
    if eigvecs.isnan().any():
        raise ValueError("Diagonalization returned NaNs.")

    # Drop the first eigenvector and keep the next n_eigs columns.
    eigs = eigvecs[:, :, 1 : n_eigs + 1]

    near_zero = (eigvals < 1e-5).sum(dim=-1, keepdim=True)
    exactly_zero = (eigvals == 0).sum(dim=-1, keepdim=True)
    n_connected_components = near_zero - exactly_zero

    if not return_n_connected:
        return eigs
    return eigs, n_connected_components


def get_moment_encoding(adj, n_moments: int, return_adj: bool = False):
    r"""Compute moment encodings :math:`\mathbb{E}[z^k]` for orders ``1..n_moments``.

    Reference: https://openreview.net/pdf?id=qaJxPhkYtD

    Self-loops are added to ``adj``, each row is divided by its sum, and
    successive matrix powers of the result are formed.  The node encoding of
    order ``k`` is the diagonal of the ``k``-th power.

    NOTE(review): the previous docstring stated that ``log(a_ij + 1)`` is used
    for numerical stability; no logarithm is applied in this function.

    NOTE(review): an earlier definition named ``get_moment_encoding`` exists
    in this module and is overridden by this one at import time; consider
    keeping a single copy.

    Args:
        adj: Batched adjacency tensor, indexed here as ``[nrows, n, n]``.
        n_moments: Number of matrix powers to compute (expected to be >= 1).
        return_adj: When true, also return the stacked matrix powers.

    Returns:
        ``x`` of shape ``[nrows, n, n_moments]`` (allocated in the default
        float dtype), or ``(x, e)`` when ``return_adj`` is true, where ``e``
        has shape ``[nrows, n, n, n_moments]``.
    """
    n_nodes = adj.size(1)
    # Build the identity directly on adj's device (no CPU round-trip).
    transition = adj + torch.eye(n_nodes, device=adj.device)
    # Row-normalise by broadcasting: same values as diag(1 / deg) @ transition,
    # but without materialising a dense diagonal matrix or paying for a matmul.
    transition = transition * (1 / transition.sum(-1)).unsqueeze(-1)

    # Every product is a fresh tensor, so no defensive clones are required.
    powers = [transition]
    for _ in range(n_moments - 1):
        powers.append(powers[-1] @ transition)

    x = torch.zeros((adj.size(0), n_nodes, n_moments), device=adj.device)
    # Slice so that n_moments == 0 still yields an empty x, as before.
    for k, power in enumerate(powers[:n_moments]):
        x[:, :, k] = torch.diagonal(power, dim1=1, dim2=2)

    if return_adj:
        # [nrows, n, n, k]
        return x, torch.stack(powers, dim=-1)
    # [nrows, n, k]
    return x
