from jaxtyping import Float
from torch import Tensor, nn
from torch_geometric.data import Batch


class GraphSize(nn.Module):
    """Graph Size Feature.

    Assigns to every node the number of nodes in the graph it belongs to,
    derived from the ``ptr`` / ``batch`` bookkeeping of a PyG batch.
    """

    def __init__(self) -> None:
        """Initialize the Graph Size Feature (it has no learnable parameters)."""
        super().__init__()

    @property
    def d(self) -> int:
        """Return the dimension of the Graph Size Feature.

        :return: Dimension of Graph Size Feature (always 1).
        """
        return 1

    def forward(self, data: Batch) -> Float[Tensor, "n_nodes 1"]:
        """Compute the Graph Size Feature using the ptr attribute of the PyG batch object.

        :param data: PyG batch object. It must expose ``ptr`` (node offsets, where
            consecutive entries delimit each graph) and ``batch`` (the graph index
            of every node).
        :return: Graph Size Feature of shape ``(n_nodes, 1)`` as a float tensor.
        """
        # Consecutive differences of the offsets give each graph's node count.
        nodes_per_graph = data.ptr.diff()
        # Broadcast each graph's count onto its nodes. Cast to float so the output
        # matches the ``Float`` annotation and can feed float-valued layers;
        # ``ptr`` is an integer tensor, so the raw result would be int64.
        graph_size = nodes_per_graph[data.batch].float()
        return graph_size.view(-1, 1)
