from jaxtyping import Float
import torch
from torch import Tensor, nn
from torch_geometric.data import Batch

from src.utils.sparse_utils import make_sparse_matrix, pad_values_of_sparse_matrix

class Identity(nn.Module):
    """Identity Feature.

    Builds a sparse ``num_nodes x num_nodes`` identity matrix (ones on the
    diagonal) for the whole batch. It then reads that matrix's values at the
    edges of the fully connected graph, giving a single scalar feature per
    edge. Self-loop edges ``(i, i)`` pick up the diagonal value 1.0. The value
    for other edges depends on how ``pad_values_of_sparse_matrix`` fills
    missing entries (presumably 0.0; TODO confirm).
    """

    def __init__(self) -> None:
        """Initialize Identity Feature (no learnable parameters)."""
        super().__init__()

    @property
    def d(self) -> int:
        """Return dimension of Identity Feature.

        :return: Dimension of Identity Feature (always 1).
        """
        return 1

    def forward(
        self, data: Batch, fully_connected_index: Float[Tensor, "2 n_fully_connected_edges"]
    ) -> Float[Tensor, "n_edges 1"]:
        """Compute the Identity Feature.

        :param data: PyG batch object; only ``data.num_nodes`` is read.
        :param fully_connected_index: Index of fully connected graph, a ``(2, E)``
            tensor of (row, col) node pairs.
        :return: Identity Feature, reshaped to a single column via ``view(-1, 1)``.
        """
        # Allocate on the same device as ``fully_connected_index``, because the
        # identity matrix is combined with it in ``pad_values_of_sparse_matrix``.
        # This also avoids requiring an unrelated ``t`` attribute on the batch.
        device = fully_connected_index.device
        node = torch.arange(data.num_nodes, device=device)
        # Diagonal coordinates (i, i) for every node in the batch.
        index = torch.stack([node, node], dim=0)
        values = torch.ones_like(node, dtype=torch.float)
        identity_matrix = make_sparse_matrix(index, values)
        # Look up the identity-matrix value at every fully connected edge.
        identity_feature = pad_values_of_sparse_matrix(identity_matrix, fully_connected_index)
        return identity_feature.view(-1, 1)
