from jaxtyping import Float, Int
from torch import Tensor, nn
from torch_geometric.data import Batch


class Time(nn.Module):
    """Edge feature holding the time value gathered at each edge's row-0 node.

    The module has no learnable parameters: ``forward`` indexes ``data.t`` with
    row 0 of the fully connected edge index and returns the result as a single
    feature column.
    """

    def __init__(self) -> None:
        """Initialize the (parameter-free) Time Feature."""
        super().__init__()

    @property
    def d(self) -> int:
        """Return the feature dimension produced by :meth:`forward`.

        :return: Always ``1``, since ``forward`` returns a single column.
        """
        return 1

    def forward(
        self, data: Batch, fully_connected_index: Int[Tensor, "2 n_fully_connected_edges"]
    ) -> Float[Tensor, "n_fully_connected_edges 1"]:
        """Compute the Time Feature for every fully connected edge.

        :param data: PyG batch object. Its ``t`` attribute is indexed by node
            index (presumably one time value per node; TODO confirm).
        :param fully_connected_index: Integer edge index of the fully connected
            graph. Row 0 supplies, for each edge, the node index whose time is
            used.
        :return: Tensor of shape ``(n_fully_connected_edges, 1)`` holding the
            gathered time values (dtype follows ``data.t``).
        """
        # Advanced indexing needs an integer index tensor, hence ``Int`` (not
        # ``Float``) in the annotation above.
        t_edge = data.t[fully_connected_index[0]]
        # Reshape to one column so the feature width matches ``self.d``.
        return t_edge.view(-1, 1)
