import torch
from jaxtyping import Float, Int
from torch import Tensor, nn
from torch_geometric.data import Batch

from src.utils.sparse_utils import (
    make_normalized_sparse_matrix,
    make_sparse_matrix,
    pad_values_of_sparse_matrix,
)


class AncestorDotProd(nn.Module):
    """Ancestor dot-product edge feature.

    Builds a sparse matrix ``P`` from ``data.p_anc`` over the fully connected
    edge set and returns two features per stored entry of ``P @ P.T``:

    * the raw dot product divided by ``max_n``;
    * the dot product of the matrix built by ``make_normalized_sparse_matrix``.
    """

    def __init__(self, max_n: int) -> None:
        """Initialize the Ancestor Dot Product feature.

        :param max_n: Divisor applied to the un-normalized dot product
            (presumably the maximum number of nodes per graph — confirm
            against the config that sets it).
        """
        super().__init__()
        self.max_n = max_n

    @property
    def d(self) -> int:
        """Return the number of feature channels produced by ``forward``.

        :return: ``2`` — the scaled raw dot product and the normalized dot
            product.
        """
        return 2

    def forward(
        self,
        data: Batch,
        fully_connected_index: Int[Tensor, "2 n_fully_connected_edges"],
    ) -> Float[Tensor, "n_edges 2"]:
        """Compute ancestor dot-product features over the fully connected edges.

        :param data: PyG batch object. ``data.p_anc`` is used as the values
            paired with ``data.edge_index``.
        :param fully_connected_index: Integer index of shape
            ``(2, n_fully_connected_edges)``. The ``p_anc`` values are moved
            onto this index before the products are taken.
        :return: Tensor with one row per stored entry of ``P @ P.T``, in the
            sorted order produced by ``coalesce()``. Column 0 is the raw dot
            product divided by ``self.max_n``. Column 1 is the dot product of
            the normalized matrix.
        """
        # Move p_anc from the graph's own edges onto fully_connected_index.
        # Presumably pad_values_of_sparse_matrix returns zero where the
        # edge-indexed matrix has no entry — confirm in sparse_utils.
        p_anc_edges = make_sparse_matrix(data.edge_index, data.p_anc)
        val = pad_values_of_sparse_matrix(p_anc_edges, fully_connected_index)
        p_anc = make_sparse_matrix(fully_connected_index, val)
        p_anc_normalized = make_normalized_sparse_matrix(fully_connected_index, val)

        # NOTE(review): rows of the result follow the coalesced (sorted)
        # sparsity pattern of each product, not the column order of
        # ``fully_connected_index``. Callers that align this with other
        # per-edge features presumably rely on the two coinciding — confirm.
        p_anc_dot_prod = torch.sparse.mm(p_anc, p_anc.T)
        p_anc_dot_prod_feature = p_anc_dot_prod.coalesce().values() / self.max_n
        p_anc_normalized_dot_prod = torch.sparse.mm(p_anc_normalized, p_anc_normalized.T)
        p_anc_normalized_dot_prod_feature = p_anc_normalized_dot_prod.coalesce().values()
        return torch.stack(
            [p_anc_dot_prod_feature, p_anc_normalized_dot_prod_feature], dim=1
        )
