import torch
import torch.nn as nn
import math

class SinePositionalEncoding(nn.Module):
    """Sine and Cosine Positional Encoding.

    Precomputes a fixed table where, for position ``pos`` and pair index ``i``,
    ``pe[pos, 2i] = sin(pos * w_i)`` and ``pe[pos, 2i + 1] = cos(pos * w_i)``
    with ``w_i = exp(-2i * ln(10000) / d_model)``. ``forward`` adds the slice of
    the table for the requested positions to the input and applies dropout.

    Args:
        d_model: Embedding dimension. May be odd; the final column then only
            receives a sine term.
        max_seq_len: Number of positions stored in the precomputed table.
        dropout: Dropout probability applied after adding the encoding.
    """

    def __init__(self, d_model: int, max_seq_len: int = 5000, dropout: float = 0.1):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)

        # Positional encoding matrix (max_seq_len, d_model).
        pe = torch.zeros(max_seq_len, d_model)
        position = torch.arange(max_seq_len, dtype=torch.float).unsqueeze(1)  # (max_seq_len, 1)
        div_term = torch.exp(
            torch.arange(0, d_model, 2, dtype=torch.float)
            * -(math.log(10000.0) / d_model)
        )
        angles = position * div_term  # (max_seq_len, ceil(d_model / 2))

        # Sine on even columns, cosine on odd columns. When d_model is odd there
        # is one fewer odd column than even column, so trim the angles to fit.
        pe[:, 0::2] = torch.sin(angles)
        pe[:, 1::2] = torch.cos(angles[:, : d_model // 2])

        # Buffer: saved in the state dict and moved with the module, but not trained.
        self.register_buffer("pe", pe.unsqueeze(0))  # (1, max_seq_len, d_model)

    def forward(self, x: torch.Tensor, position_offset: int = 0) -> torch.Tensor:
        """Add positional encodings for positions ``[offset, offset + seq_len)``.

        Args:
            x: Tensor, shape [batch_size, seq_len, embedding_dim].
            position_offset: Index of the first position to encode (e.g. the
                number of tokens already processed during incremental decoding).

        Returns:
            ``dropout(x + pe[:, offset:offset + seq_len])``.

        Raises:
            IndexError: If the requested positions fall outside the table.
                Without this check, a short slice could broadcast silently
                (every token would get the same encoding) instead of failing.
        """
        seq_len = x.size(1)
        max_len = self.pe.size(1)
        if position_offset < 0 or position_offset + seq_len > max_len:
            raise IndexError(
                f"positions [{position_offset}, {position_offset + seq_len}) are "
                f"outside the positional table of size {max_len}"
            )
        pos_encoding = self.pe[:, position_offset:position_offset + seq_len]
        return self.dropout(x + pos_encoding)
    
    
class LearnedPositionalEncoding(nn.Module):
    """Trainable absolute positional embeddings.

    Looks up one learned vector per position from an ``nn.Embedding`` table,
    adds it to the input, and applies dropout.

    Args:
        d_model: Embedding dimension.
        max_seq_len: Number of positions in the embedding table.
        dropout: Dropout probability applied after adding the embeddings.
    """

    def __init__(self, d_model: int, max_seq_len: int = 5000, dropout: float = 0.1):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)
        self.embedding = nn.Embedding(max_seq_len, d_model)

    def forward(self, x: torch.Tensor, position_offset: int = 0) -> torch.Tensor:
        """Add learned embeddings for positions ``[offset, offset + seq_len)``.

        Args:
            x: Input whose dim 1 is the sequence length; the embeddings are
                broadcast over dim 0 (presumably [batch_size, seq_len, d_model]).
            position_offset: Index of the first position to embed.

        Returns:
            ``dropout(x + embedding(positions))``.

        Raises:
            IndexError: If any requested position is outside the embedding
                table. The check runs up front so that a GPU input fails with
                a clear message instead of a device-side assert.
        """
        seq_len = x.size(1)
        max_len = self.embedding.num_embeddings
        if position_offset < 0 or position_offset + seq_len > max_len:
            raise IndexError(
                f"positions [{position_offset}, {position_offset + seq_len}) are "
                f"outside the positional table of size {max_len}"
            )
        position_ids = torch.arange(seq_len, device=x.device) + position_offset
        position_embeds = self.embedding(position_ids)  # (seq_len, d_model)

        return self.dropout(x + position_embeds.unsqueeze(0))

    
