import math

import torch
from jaxtyping import Float
from torch import nn
from torch_geometric.data import Batch


class TimeEmbedding(nn.Module):
    """Time Embedding Feature.

    Encodes the time ``data.t`` with sinusoidal (sin and cos) functions at
    geometrically spaced frequencies, producing ``d`` features per entry.
    """

    def __init__(self, d: int) -> None:
        """Initialise the embedding.

        :param d: Output embedding dimension. If ``d`` is odd, the last
            feature column is zero-padded so the output still has ``d``
            columns.
        """
        super().__init__()

        self.d = d

    def forward(self, data: Batch) -> Float[torch.Tensor, "n_nodes d"]:
        """Compute the Time Embedding feature.

        With ``half_dim = d // 2``, the frequencies are
        ``exp(-log(10000) * i / (half_dim - 1))`` for ``i = 0 .. half_dim - 1``.
        The output is ``[sin(t * f), cos(t * f)]`` concatenated along the
        last axis, zero-padded by one column when ``d`` is odd.

        :param data: PyG batch object; only ``data.t`` is read. ``data.t`` is
            presumably a 1-D tensor of per-node times -- TODO confirm.
        :return: Time Embedding feature with ``d`` columns in the last axis.
        """
        t = data.t
        half_dim = self.d // 2
        # Guard the denominator: for d in {2, 3}, half_dim - 1 == 0, which
        # would raise ZeroDivisionError. With a single frequency the spacing
        # is irrelevant, so any positive denominator yields frequency 1.
        scale = math.log(10000) / max(half_dim - 1, 1)
        freqs = torch.exp(torch.arange(half_dim, device=t.device) * -scale)
        args = t[:, None] * freqs[None, :]
        embeddings = torch.cat((args.sin(), args.cos()), dim=-1)
        # sin/cos together give 2 * (d // 2) columns; pad odd d with one zero
        # column so the output width matches the declared ``d``.
        if self.d % 2 == 1:
            embeddings = nn.functional.pad(embeddings, (0, 1))
        return embeddings
