from kappamodules.functional.pos_embed import get_sincos_1d_from_seqlen
from torch import nn

from src.modules.kappa import MLP
from src.modules.positional_embeddings import ContinuousSincosEmbed


class TimestepConditioner(nn.Module):
    """Map a timestep to a conditioning vector.

    The timestep is multiplied by ``max_range``, encoded with a continuous
    sin-cos positional embedding of width ``timestep_mlp_hidden``, and the
    encoding is projected by an MLP to ``condition_dim`` features.

    Args:
        dim: forwarded to ``ContinuousSincosEmbed`` as its ``ndim`` argument
            (NOTE(review): presumably the number of coordinate dimensions of
            the timestep input -- confirm against ContinuousSincosEmbed).
        timestep_mlp_hidden: width of the sin-cos embedding; also used as the
            MLP's ``input_dim`` and passed as its ``hidden_dims``.
        condition_dim: ``output_dim`` of the MLP, i.e. the width of the
            produced conditioning vector.
        init_weights: weight-initialization option forwarded to ``MLP``.
        max_range: multiplier applied to the timestep before embedding
            (NOTE(review): suggests timesteps arrive normalized, e.g. in
            [0, 1] -- confirm with callers).
    """

    def __init__(
        self,
        dim,
        timestep_mlp_hidden,
        condition_dim,
        init_weights="torch",
        max_range=1000,
    ):
        super().__init__()
        # stored hyperparameters; only max_range is read again (in forward)
        self.dim = dim
        self.condition_dim = condition_dim
        self.max_range = max_range

        # submodules -- embedding is registered before the MLP on purpose,
        # keeping parameter/state_dict order stable
        self.posenc = ContinuousSincosEmbed(ndim=dim, dim=timestep_mlp_hidden)
        width = timestep_mlp_hidden
        self.timestep_mlp = MLP(
            input_dim=width,
            hidden_dims=width,
            output_dim=condition_dim,
            init_weights=init_weights,
        )

    def forward(self, timestep):
        """Scale ``timestep`` by ``max_range``, sin-cos encode it, and project it with the MLP."""
        scaled = timestep * self.max_range
        return self.timestep_mlp(self.posenc(scaled))
