from kappamodules.functional.pos_embed import get_sincos_1d_from_seqlen
from torch import nn

from src.modules.kappa import MLP


class VectorConditioner(nn.Module):
    """Embed a conditioning vector with an ``MLP``.

    Args:
        in_dim: Input feature size, forwarded to ``MLP`` as ``input_dim``.
        hidden_dim: Forwarded unchanged to ``MLP`` as ``hidden_dims``.
            NOTE(review): the keyword is plural; confirm whether ``MLP``
            expects a single int or a sequence of hidden sizes here.
        condition_dim: Output size of the embedding, forwarded to ``MLP`` as
            ``output_dim``. It is also stored on the instance for callers.
        init_weights: Weight-initialization scheme name, forwarded to ``MLP``.
    """

    def __init__(
        self,
        in_dim,
        hidden_dim,
        condition_dim,
        init_weights="truncnormal",
    ):
        super().__init__()
        mlp_kwargs = dict(
            input_dim=in_dim,
            output_dim=condition_dim,
            hidden_dims=hidden_dim,
            init_weights=init_weights,
        )
        # Plain (non-module) attribute so downstream code can query the size
        # of the embedding this conditioner produces.
        self.condition_dim = condition_dim
        self.mlp = MLP(**mlp_kwargs)

    def forward(self, x):
        """Return ``self.mlp(x)``, the condition embedding for ``x``."""
        return self.mlp(x)
