import torch
from jaxtyping import Float
from torch import Tensor, nn
from torch_geometric.data import Batch

from src.utils.backbone_utils import safe_log
from src.utils.sparse_utils import sparse_sum


class Entropy(nn.Module):
    """Node feature: normalized entropy of each node's edge probabilities."""

    def __init__(self) -> None:
        """Build the module; nothing is set up beyond the ``nn.Module`` base."""
        super().__init__()

    @property
    def d(self) -> int:
        """Width of the feature emitted by :meth:`forward`.

        :return: Always ``1``, matching the single output column.
        """
        return 1

    def forward(self, data: Batch) -> Float[Tensor, "n_nodes 1"]:
        """Return the per-node entropy of edge probabilities, scaled by ``log(n_parents)``.

        Nodes with fewer than two parents are assigned ``0``: for them the
        normalizer ``log(n_parents)`` would be zero (or undefined), so the
        division is skipped.

        :param data: PyG batch object. The attributes read here are
            ``n_parents``, ``edge_attr`` (used as edge probabilities),
            ``edge_index`` and ``t``.
        :return: Normalized entropy as a column, one row per entry of ``data.t``.
        """
        parent_counts = data.n_parents
        edge_probs = data.edge_attr

        # Per-edge contribution p * log(p); `safe_log` is a project helper,
        # presumably guarding against log(0) -- confirm in backbone_utils.
        weighted_log = edge_probs * safe_log(edge_probs)

        # Aggregate the per-edge terms via `sparse_sum` and flip the sign to get
        # the entropy. NOTE(review): the trailing `1` is assumed to select which
        # row of `edge_index` the sum is grouped by -- verify in sparse_utils.
        node_entropy = -sparse_sum(data.edge_index, weighted_log, 1)

        # Only nodes with at least two parents have a non-zero normalizer.
        normalizable = parent_counts >= 2

        # Output buffer borrows shape/dtype/device from `data.t`.
        # NOTE(review): this assumes `data.t` has one entry per node -- confirm.
        result = torch.zeros_like(data.t)
        log_counts = torch.log(parent_counts[normalizable])
        result[normalizable] = node_entropy[normalizable] / log_counts

        return result.view(-1, 1)
