import torch
from jaxtyping import Float, Int
from torch import Tensor, nn
from torch_geometric.data import Batch

from src.utils.sparse_utils import sparse_mode


class Mode(nn.Module):
    """Mode-based node features.

    For every node two integer features are produced:

    1. How many nodes have this node as their mode (the parent with the
       highest probability).
    2. The value of feature 1 for this node's own mode.
    """

    def __init__(self) -> None:
        """Initialize the Mode feature (it registers no learnable parameters)."""
        super().__init__()

    @property
    def d(self) -> int:
        """Return the dimension of the mode feature.

        :return: Number of feature columns produced by ``forward`` (always 2).
        """
        return 2

    def forward(self, data: Batch) -> Int[Tensor, "n_nodes 2"]:
        """Compute the mode features.

        :param data: PyG batch object; ``edge_index``, ``edge_attr`` and
            ``num_nodes`` are read from it.
        :return: Integer (``torch.long``) tensor with one row per node.
            Column 0 holds how many nodes have this node as their mode, and
            column 1 holds that same count for the node's own mode.
        """
        # One mode index per node; the final stack requires mode to have the
        # same length as mode_counts (i.e. num_nodes entries).
        mode = sparse_mode(data.edge_index, data.edge_attr)
        # bincount without weights yields int64 counts, hence the Int annotation.
        mode_counts = torch.bincount(mode, minlength=data.num_nodes)
        # 0 is the mode of all root nodes, therefore set the count of 0 to 0.
        # Guarded so an empty batch (zero nodes, empty counts) does not raise
        # IndexError.
        # NOTE(review): this also discards any genuine count for global node 0
        # in a batch -- confirm sparse_mode never yields 0 as a real mode.
        if mode_counts.numel() > 0:
            mode_counts[0] = 0
        return torch.stack([mode_counts, mode_counts[mode]]).T
