import torch
from jaxtyping import Float
from torch import Tensor, nn
from torch_geometric.data import Batch

from src.utils.sparse_utils import (
    make_sparse_matrix,
    pad_values_of_sparse_matrix,
)


class AdjacencyDotProd(nn.Module):
    """Adjacency Dot Product Feature.

    Builds an adjacency matrix ``A`` on the sparsity pattern of a fully connected
    edge index, computes ``A @ A.T`` and returns the stored values of the product
    as a single-channel feature.
    """

    def __init__(self) -> None:
        """Initialize Adjacency Dot Product Feature (no learnable parameters)."""
        super().__init__()

    @property
    def d(self) -> int:
        """Return dimension of Adjacency Dot Product Feature.

        :return: Dimension of feature (one scalar per stored entry).
        """
        return 1

    def forward(
        self, data: Batch, fully_connected_index: Float[Tensor, "2 n_fully_connected_edges"]
    ) -> Float[Tensor, "nnz 1"]:
        """Compute the Adjacency Dot Product Features.

        :param data: PyG batch object; its ``edge_index`` and ``edge_attr`` define
            the adjacency matrix ``A``.
        :param fully_connected_index: ``(row, col)`` index pairs whose positions are
            all stored in the padded adjacency matrix before the product is taken.
        :return: Values of the coalesced sparse product ``A @ A.T``, shaped
            ``(nnz, d)`` where ``nnz`` is the product's number of stored entries.
        """
        # Sparse adjacency built from the real edges only.
        edge_adj = make_sparse_matrix(data.edge_index, data.edge_attr)
        # Re-express the adjacency on the fully connected pattern. The padded values
        # are used as one value per column of fully_connected_index below, so every
        # pair in that index becomes a stored entry (missing edges presumably 0).
        padded_vals = pad_values_of_sparse_matrix(edge_adj, fully_connected_index)
        padded_adj = make_sparse_matrix(fully_connected_index, padded_vals)
        adj_dot_prod = torch.sparse.mm(padded_adj, padded_adj.T)
        # NOTE(review): coalesce() sorts entries by (row, col). The returned values
        # only line up 1:1 with fully_connected_index if that index is sorted the same
        # way and matches the product's sparsity pattern exactly (A @ A.T likely has a
        # stored diagonal, so self-loops may need to be in the index) — confirm with
        # callers.
        adj_dot_prod_feature = adj_dot_prod.coalesce().values()
        return adj_dot_prod_feature.view(-1, self.d)
