from kappamodules.functional.pos_embed import get_sincos_1d_from_seqlen
from kappamodules.layers import LinearProjection
from torch import nn

from src.modules.act import GEGLU


class UptTimestepConditioner(nn.Module):
    """Map timestep indices to a conditioning vector.

    A fixed sin-cos table, built by ``get_sincos_1d_from_seqlen``, is registered as a
    buffer and indexed by the incoming timesteps. The selected rows are passed through
    a two-layer MLP: ``dim -> dim * 4 -> cond_dim``, with GELU after each projection.

    Args:
        dim: Width of the sin-cos embedding and input width of the MLP.
        condition_dim: Output width of the MLP. Falls back to ``dim * 4`` when falsy
            (``None`` or ``0``).
        init_weights: Weight-init scheme forwarded to both ``LinearProjection`` layers.
        num_total_timesteps: Sequence length passed to the sin-cos table builder. This
            is the number of distinct timestep indices ``forward`` can look up. Required.
        act: Currently unused; the MLP always uses ``nn.GELU``.

    Raises:
        ValueError: If ``num_total_timesteps`` is ``None``.
    """

    def __init__(
        self,
        dim,
        condition_dim=None,
        init_weights="truncnormal",
        num_total_timesteps=None,
        act: nn.Module = GEGLU,
    ):
        super().__init__()
        # Fail early with a clear message instead of an opaque error from the sin-cos helper.
        if num_total_timesteps is None:
            raise ValueError(
                "num_total_timesteps is required to build the timestep embedding table"
            )
        self.num_total_timesteps = num_total_timesteps
        self.dim = dim
        # `or` also maps condition_dim=0 to the default width
        self.cond_dim = condition_dim or dim * 4
        # NOTE(review): `act` is never used; nn.GELU is hard-coded in the MLP below.
        # Swapping in `act` (default GEGLU) may not be a drop-in change if it alters
        # the feature width. Confirm before wiring it up.

        # Fixed sin-cos table, presumably of shape (num_total_timesteps, dim) (TODO confirm).
        # register_buffer keeps it on the module's device and in state_dict (persistent).
        self.register_buffer(
            "timestep_embed",
            get_sincos_1d_from_seqlen(seqlen=self.num_total_timesteps, dim=dim),
        )
        self.timestep_mlp = nn.Sequential(
            LinearProjection(dim, dim * 4, init_weights=init_weights),
            nn.GELU(),
            LinearProjection(dim * 4, self.cond_dim, init_weights=init_weights),
            nn.GELU(),
        )

    def forward(self, timestep):
        """Embed a batch of timestep indices.

        Args:
            timestep: Tensor with one timestep index per sample, e.g. shape ``(batch,)``
                or ``(batch, 1)``. It is used directly to index the ``timestep_embed``
                buffer, so it is expected to be an integer tensor with values in
                ``[0, num_total_timesteps)``.

        Returns:
            Output of ``timestep_mlp`` applied to the looked-up sin-cos rows.
            Presumably of shape ``(batch, cond_dim)``.

        Raises:
            ValueError: If ``timestep`` holds more than one element per leading entry.
            TypeError: If ``timestep`` is 0-d (raised by ``len``).
        """
        # numel == len means one element per leading entry, i.e. trailing dims are size 1.
        # Explicit raise instead of `assert` so the check survives `python -O`.
        if timestep.numel() != len(timestep):
            raise ValueError(
                "expected one timestep per sample, got tensor of shape "
                f"{tuple(timestep.shape)}"
            )
        timestep = timestep.flatten()
        return self.timestep_mlp(self.timestep_embed[timestep])
