from kappamodules.functional.pos_embed import get_sincos_1d_from_seqlen
from torch import nn

from src.modules.kappa import MLP
from src.models.kappa_overrides.prenorm_block import PrenormBlock
from src.modules.rope_frequency import RopeFrequency

class AttentionConditioner(nn.Module):
    """Embed inputs with an MLP, then refine them with RoPE self-attention blocks.

    Pipeline:
        1. ``self.mlp`` maps each input from ``in_dim`` to ``condition_dim``.
        2. ``transformer_depth`` ``PrenormBlock`` layers (width ``condition_dim``)
           are applied in sequence to that embedding. Rotary frequencies derived
           from ``pos`` are passed to every block.

    NOTE(review): an earlier docstring mentioned a final "aggregate" step, but
    no aggregation/pooling is performed here; the output keeps whatever
    per-token structure the MLP output has. Confirm with callers whether a
    pooling step is still expected.

    Args:
        in_dim: Feature size of the input ``x`` fed to the MLP.
        x_dim: Passed as ``ndim`` to ``RopeFrequency``; presumably the number
            of coordinate dimensions in ``pos`` -- TODO confirm.
        hidden_dim: Forwarded to the MLP as ``hidden_dims`` (int or sequence,
            depending on the MLP implementation -- verify).
        condition_dim: Output width of the MLP and model width of every block.
        n_heads: Number of attention heads per block. The per-head rotary
            dimension is ``condition_dim // n_heads``.
        transformer_depth: Number of self-attention blocks (0 means MLP only).
        init_weights: Weight-init scheme forwarded to the MLP and the blocks.
    """

    def __init__(
        self,
        in_dim,
        x_dim,
        hidden_dim,
        condition_dim,
        n_heads,
        transformer_depth,
        init_weights="truncnormal",
    ):
        super().__init__()
        self.condition_dim = condition_dim

        self.mlp = MLP(
            input_dim=in_dim,
            output_dim=condition_dim,
            hidden_dims=hidden_dim,
            init_weights=init_weights,
        )

        self.blocks = nn.ModuleList(
            [
                PrenormBlock(
                    dim=condition_dim,
                    num_heads=n_heads,
                    n_anchors=-1,  # explicitly turn off anchor handling in PrenormBlock
                    init_weights=init_weights,
                )
                for _ in range(transformer_depth)
            ]
        )
        # Rotary frequencies are computed per attention head, hence the head dim.
        self.rope = RopeFrequency(dim=condition_dim // n_heads, ndim=x_dim)

    def forward(self, x, pos):
        """Return the conditioning embedding for ``x`` located at ``pos``.

        Args:
            x: Input features with trailing size ``in_dim``; assumes
                (batch, tokens, in_dim) -- TODO confirm.
            pos: Positions used to build the rotary frequencies; presumably
                (batch, tokens, x_dim) -- TODO confirm.

        Returns:
            The MLP embedding after all attention blocks, with trailing size
            ``condition_dim``.
        """
        cond_embed = self.mlp(x)

        # The frequencies depend only on ``pos``, so compute them once instead
        # of once per block.
        rope_freqs = self.rope(pos)

        # Refine the embedding (not the raw input): the blocks operate at
        # ``condition_dim`` width and their output must be kept.
        for blk in self.blocks:
            cond_embed = blk(cond_embed, rope_freqs=rope_freqs)

        return cond_embed
