import torch
import torch.nn as nn
import math


class TimestepEmbedding(nn.Module):
    """Embed timesteps with fixed Fourier features followed by a small MLP.

    Each element of the input ``t`` is multiplied by the fixed frequencies
    ``k * pi`` for ``k = 1..num_freqs``. The cosines and sines of those
    products are concatenated (cosines first) into a ``2 * num_freqs``
    feature vector. A ``Linear -> act -> Linear`` MLP then maps that vector
    to ``output_dim``.

    Args:
        num_freqs: Number of frequencies ``pi, 2*pi, ..., num_freqs*pi``.
        hidden_dim: Width of the MLP's hidden layer.
        output_dim: Size of the returned embedding.
        pos_dim: Stored as ``self.pos_dim`` but not used by this class's
            own code. NOTE(review): presumably meant for multi-dimensional
            positions; confirm against callers before relying on it.
        act: Activation module placed between the two linear layers.
            When omitted (``None``), a fresh ``nn.GELU()`` is created for
            each instance.
    """

    def __init__(self, num_freqs, hidden_dim, output_dim, pos_dim=1, act=None):
        super().__init__()
        self.num_freqs = num_freqs
        self.output_dim = output_dim
        self.hidden_dim = hidden_dim
        self.pos_dim = pos_dim

        # A default of ``nn.GELU()`` in the signature would be evaluated once
        # at definition time, so every instance would share (and register)
        # the same module object. Build a fresh one per instance instead.
        if act is None:
            act = nn.GELU()

        # Buffer rather than Parameter: it follows .to()/.cuda() and is
        # saved in the state dict, but is not returned by parameters().
        self.register_buffer("freqs", torch.arange(1, num_freqs + 1) * math.pi)

        self.mlp = nn.Sequential(
            nn.Linear(2 * num_freqs, hidden_dim),
            act,
            nn.Linear(hidden_dim, output_dim),
        )

    def forward(self, t):
        """Return the embedding of ``t``.

        Args:
            t: Tensor of timesteps. A trailing axis is appended so that each
                element is multiplied by every frequency.

        Returns:
            Tensor of shape ``t.shape + (output_dim,)``.
        """
        # t: (...)  ->  (..., num_freqs)
        temb = self.freqs * t[..., None]
        # (..., num_freqs)  ->  (..., 2 * num_freqs), cosines then sines
        temb = torch.cat((temb.cos(), temb.sin()), dim=-1)
        return self.mlp(temb)
    