import math
import torch
from torch import nn
from torch import Tensor

from models.config import MiMoEConfig
from models.registry import register_position_embedding


@register_position_embedding("sinusoidal")
class SinusoidalPositionEmbedding(nn.Module):
    """Add fixed sine/cosine position encodings to a batch of embeddings.

    A ``(max_len, embedding_dim)`` table is computed once at construction
    and stored as the buffer ``pe``. Even columns hold the sine of
    ``position * frequency`` and odd columns hold the cosine, where the
    frequencies are ``10000 ** (-2i / embedding_dim)``.

    Args:
        config: Supplies ``hidden_dim`` (the embedding width) and
            ``max_seq_len`` (one extra position is reserved on top of it).
    """

    def __init__(self, config: MiMoEConfig):
        super().__init__()
        self.embedding_dim = config.hidden_dim
        self.max_len = config.max_seq_len + 1  # Add 1 for [CLS]

        position = torch.arange(0, self.max_len).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, self.embedding_dim, 2)
            * (-math.log(10000.0) / self.embedding_dim)
        )
        # Shape (max_len, ceil(embedding_dim / 2)); shared by both halves.
        angles = position * div_term

        pe = torch.zeros(self.max_len, self.embedding_dim)
        pe[:, 0::2] = torch.sin(angles)
        # An odd embedding_dim has one fewer odd column than even columns,
        # so trim the angle table to fit; for an even width this slice is
        # the whole table.
        pe[:, 1::2] = torch.cos(angles[:, : self.embedding_dim // 2])
        self.register_buffer('pe', pe)

    def forward(self, x: Tensor) -> Tensor:
        """Return ``x`` with the position encodings added.

        Args:
            x: A 3-D tensor unpacked as ``(B, T, D)``. ``D`` is expected to
                equal ``embedding_dim`` and ``T`` to be at most ``max_len``.

        Returns:
            ``x + pe[:T]``, with the table broadcast over the batch axis.
        """
        # Unpacking also rejects inputs that are not 3-D.
        _, T, _ = x.shape
        return x + self.pe[:T]



@register_position_embedding("learned")
class LearnedPositionEmbedding(nn.Module):
    """Add trainable per-position vectors to a batch of embeddings.

    Positions ``0 .. T-1`` are looked up in an ``nn.Embedding`` table of
    shape ``(max_len, embedding_dim)`` and added to the input.

    Args:
        config: Supplies ``hidden_dim`` (the embedding width) and
            ``max_seq_len`` (one extra position is reserved on top of it).
    """

    def __init__(self, config: MiMoEConfig):
        super().__init__()
        self.embedding_dim = config.hidden_dim
        self.max_len = config.max_seq_len + 1 # Add 1 for [CLS]
        self.position_embeddings = nn.Embedding(self.max_len, self.embedding_dim)

    def forward(self, x: Tensor) -> Tensor:
        """Return ``x`` with the learned position vectors added.

        Args:
            x: A 3-D tensor unpacked as ``(B, T, D)``. ``D`` is expected to
                equal ``embedding_dim``.

        Returns:
            ``x`` plus the ``(T, D)`` position vectors, broadcast over the
            batch axis.

        Raises:
            IndexError: If ``T`` exceeds ``max_len``.
        """
        # Unpacking also rejects inputs that are not 3-D.
        _, T, _ = x.shape
        # Fail early with a readable message rather than inside the
        # embedding lookup, where an out-of-range index is reported far less
        # clearly (on CUDA it surfaces as a device-side assertion).
        if T > self.max_len:
            raise IndexError(
                f"sequence length {T} exceeds the maximum supported "
                f"length {self.max_len}"
            )
        # The position vectors are the same for every batch element, so look
        # them up once and let the addition broadcast them.
        positions = torch.arange(T, device=x.device)
        return x + self.position_embeddings(positions)
    


@register_position_embedding("rope")
class DummyPositionEmbedding(nn.Module):
    """Identity stand-in registered under the "rope" key.

    This module adds nothing to its input. Position information is left to
    RoPEAttention, which handles the position embedding itself.
    """

    def __init__(self, config: MiMoEConfig):
        # ``config`` is accepted so the constructor has the same signature
        # as the other position embeddings in this file; it is never read.
        super().__init__()

    def forward(self, x: Tensor) -> Tensor:
        """Return ``x`` unchanged."""
        passthrough = x
        return passthrough