from jaxtyping import Float
from torch import Tensor, nn
from torch_geometric.data import Batch


class GinkgoFeatures(nn.Module):
    """Ginkgo Features.

    Parameter-free module that fills in the 4-dimensional features of the nodes selected by
    ``data.parent_mask`` as a weighted sum of the features of the nodes connected to them
    through ``data.edge_index``.
    """

    def __init__(self) -> None:
        """Initialize Ginkgo Features (the module defines no parameters of its own)."""
        super().__init__()

    @property
    def d(self) -> int:
        """Return dimension of Ginkgo Features.

        :return: Dimension of Ginkgo Features.
        """
        return 4

    def forward(self, data: Batch) -> Float[Tensor, "n_nodes 4"]:
        """Compute the Ginkgo Features for internal nodes. Makes use of the fact that features of
        parents is sum of features of children.

        The rows of ``data.x`` selected by ``data.parent_mask`` are reset to zero, then every
        edge adds ``data.x[edge_index[0]] * p_anc`` onto row ``edge_index[1]``.

        NOTE: ``data.x`` is overwritten in place and the returned tensor is that same tensor.

        :param data: PyG batch object. Must contain ``x`` and ``p_anc``; ``parent_mask`` and
            ``edge_index`` are read as well. ``p_anc`` is assumed to hold one weight per edge
            (it is expanded with ``unsqueeze(1)`` to scale every feature column) - TODO confirm.
        :return: Ginkgo Features.
        :raises AssertionError: If ``x`` or ``p_anc`` is missing from ``data``.
        """
        # Explicit raises instead of ``assert`` statements: asserts are stripped under
        # ``python -O``. The exception type and messages are kept so callers see no change.
        if "x" not in data:
            raise AssertionError("Node features are required to compute Ginkgo Features.")
        if "p_anc" not in data:
            raise AssertionError(
                "Parent ancestor probabilities are required to compute Ginkgo Features."
            )

        features = data.x
        # Discard whatever the masked rows held; they are rebuilt from the edges below.
        features[data.parent_mask] = 0

        senders = data.edge_index[0, :]
        receivers = data.edge_index[1, :]
        # The contributions are gathered before the accumulation starts, so rows that were
        # just zeroed contribute nothing when they appear as senders.
        contributions = features[senders] * data.p_anc.unsqueeze(1)
        # index_add_ accumulates, so a receiver listed on several edges gets the sum of all
        # its weighted contributions.
        features.index_add_(0, receivers, contributions)
        return features
