from jaxtyping import Float
from torch import Tensor, nn
from torch_geometric.data import Batch

from src.utils.adjacency_utils import compute_P_anc, get_adj_row_indices


class ProbaAncestorPE(nn.Module):
    """Probabilistic ancestor feature.

    Wraps the project helper ``compute_P_anc`` and gathers, for each index
    pair returned by ``get_adj_row_indices``, the corresponding slice of the
    probabilistic-ancestor tensor.
    """

    def __init__(self, n: int) -> None:
        """Initialize the probabilistic ancestor feature.

        :param n: Maximal number of nodes in any graph.
        """
        super().__init__()
        self.n = n

    @property
    def d(self) -> int:
        """Return the dimension of the probabilistic ancestor feature.

        :return: Dimension of the probabilistic ancestor feature,
            computed as ``(n - 1) // 2``.
        """
        # NOTE(review): presumably equals the size of the last axis of the
        # tensor returned by ``compute_P_anc`` -- confirm against that helper.
        return (self.n - 1) // 2

    def forward(self, data: Batch) -> Float[Tensor, "n_edges d"]:
        """Compute probabilistic ancestors.

        :param data: PyG batch object.
        :return: Probabilistic ancestor probabilities, one row per index pair
            produced by ``get_adj_row_indices``.
        """
        P_anc = compute_P_anc(data, self.n)
        batch, node = get_adj_row_indices(data)
        # ``batch`` and ``node`` jointly index the first two axes of ``P_anc``.
        # The trailing ``:`` keeps the third axis, so the result is at least
        # 2-D when the index tensors are 1-D (assumes ``P_anc`` is 3-D --
        # TODO confirm against ``compute_P_anc``).
        return P_anc[batch, node, :]
