import torch
from jaxtyping import Float
from torch import Tensor, nn
from torch_geometric.data import Batch

from src.utils.adjacency_utils import (
    compute_magnetic_laplacian,
    get_adj,
    get_adj_row_indices,
)


class MagneticLaplacianPE(nn.Module):
    """Positional encoding taken from eigenvectors of the magnetic Laplacian.

    For every graph in a batch, a dense adjacency of side ``n`` is built and its
    magnetic Laplacian is formed. The row of the eigenvector matrix belonging to
    each real node is then used as that node's encoding (``n`` features per node).
    """

    def __init__(self, n: int, q: float = 0.25) -> None:
        """Store the configuration of the encoding.

        :param n: Maximal number of nodes in one graph; it is also the size of the
            dense adjacency handed to ``get_adj`` and therefore the encoding width.
        :param q: Hyperparameter forwarded unchanged to
            ``compute_magnetic_laplacian`` as its ``q`` argument (presumably the
            potential / charge controlling edge-direction phases — confirm there).
        """
        super().__init__()
        self.n, self.q = n, q

    @property
    def d(self) -> int:
        """Width of the produced positional encoding.

        :return: The encoding dimension, equal to ``self.n``.
        """
        return self.n

    def forward(self, data: Batch) -> Float[Tensor, "n_nodes n"]:
        """Build the magnetic Laplacian eigenvector encoding for a batch.

        :param data: PyG batch object.
        :return: One row per node selected by ``get_adj_row_indices``, holding the
            real part of that node's components across all eigenvectors.
        """
        dense_adjacency = get_adj(data, self.n)
        laplacian = compute_magnetic_laplacian(dense_adjacency, q=self.q)

        # ``eigh`` returns eigenvalues in ascending order; eigenvectors are the
        # columns of the last dimension, so a row slice gives a node's coordinates.
        # NOTE(review): eigh eigenvectors are only unique up to a (complex) phase,
        # so the real part below depends on that arbitrary choice — confirm this
        # is acceptable for downstream use.
        eigvecs = torch.linalg.eigh(laplacian).eigenvectors

        graph_idx, local_idx = get_adj_row_indices(data)
        # NOTE(review): ``.real`` drops the imaginary component, which presumably
        # carries the direction information when q != 0 — confirm intended.
        return eigvecs[graph_idx, local_idx].real
