from jaxtyping import Float
from torch import Tensor, nn
from torch_geometric.data import Batch

from src.utils.sparse_utils import sparse_sum


class AncestorProbaMass(nn.Module):
    """Node feature holding the ancestor probability mass.

    The mass is defined as the sum of the ancestor probabilities of all
    potential descendant nodes. The module has no learnable parameters and
    delegates the aggregation to ``sparse_sum``.
    """

    def __init__(self) -> None:
        """Set up the (parameter-free) module."""
        super().__init__()

    @property
    def d(self) -> int:
        """Width of the feature produced by :meth:`forward`.

        :return: Number of feature columns, which is always 1.
        """
        return 1

    def forward(self, data: Batch) -> Float[Tensor, "n_nodes 1"]:
        """Build the Ancestor Probability Mass feature for a batch.

        Reads ``data.edge_index`` and ``data.p_anc`` and hands them to
        ``sparse_sum`` with ``0`` as the third argument.

        :param data: PyG batch object providing ``edge_index`` and ``p_anc``.
        :return: The aggregated values viewed as a single column.
        """
        # NOTE(review): the exact reduction semantics (which axis / index row
        # the ``0`` selects) live in ``sparse_sum`` -- confirm there.
        edge_index, anc_proba = data.edge_index, data.p_anc
        mass = sparse_sum(edge_index, anc_proba, 0)
        # Expose the result as a column so it has a trailing feature axis.
        return mass.view(-1, 1)
