import torch
import torch.nn as nn
import torch.nn.functional as F

class AttentionPooling(torch.nn.Module):
    """
    Attention pooling: learned query vectors attend over a token sequence
    (scaled dot-product attention) and each query yields one pooled vector.
    """
    def __init__(self, d_model, num_queries=1):
        """
        d_model: feature dimension of the input tokens; queries, keys and
            values are all projected to this width (bias-free linear maps).
        num_queries: number of learned query vectors.
        """
        super().__init__()
        self.w_q = torch.nn.Linear(d_model, d_model, bias=False)
        self.w_k = torch.nn.Linear(d_model, d_model, bias=False)
        self.w_v = torch.nn.Linear(d_model, d_model, bias=False)
        # Zero queries project to zero (w_q has no bias), so the attention
        # weights start out uniform over the unmasked tokens.
        self.q = torch.nn.Parameter(torch.zeros(num_queries, d_model))
        self.softmax = torch.nn.Softmax(dim=-1)
        self.scale = 1 / (d_model ** 0.5)

    def forward(self, x, attention_mask=None):
        """
        x: Tensor of shape (B, N, D)  [Batch, Sequence Length, Feature Dim]
        attention_mask: (optional) Tensor of shape (B, N) indicating valid tokens
            (non-zero / True = valid). Masked tokens receive zero weight.
        Returns: Tensor of shape (B, D) when num_queries == 1, otherwise
            (B, num_queries, D). A sequence with no valid token pools to
            an all-zero vector instead of NaN.
        """
        # Unpacking three values also rejects inputs that are not rank 3.
        B, _, _ = x.shape
        # expand() broadcasts the projected queries over the batch without
        # materialising B copies.
        q_i = self.w_q(self.q.to(x.dtype)).unsqueeze(0).expand(B, -1, -1)
        k = self.w_k(x)
        v = self.w_v(x)

        scores = torch.bmm(q_i, k.transpose(1, 2)) * self.scale
        no_valid = None
        if attention_mask is not None:
            invalid = ~attention_mask.to(torch.bool).unsqueeze(1)  # (B, 1, N)
            scores = scores.masked_fill(invalid, float('-inf'))
            # A row consisting only of -inf would make softmax return NaN.
            # Give such rows finite scores here and zero their weights below.
            no_valid = invalid.all(dim=-1, keepdim=True)  # (B, 1, 1)
            scores = scores.masked_fill(no_valid, 0.0)
        weights = self.softmax(scores)
        if no_valid is not None:
            weights = weights.masked_fill(no_valid, 0.0)
        # squeeze(1) only drops the query axis when there is a single query.
        return torch.bmm(weights, v).squeeze(1)
    
class WeightedAveragePooling(nn.Module):
    """
    Collapse the token axis with a learned, softmax-normalized weighting.
    One learnable logit is kept per token position.
    """
    def __init__(self, n_tokens, dtype=torch.float32):
        super().__init__()
        # Equal starting logits; nn.Parameter enables gradients by default.
        initial_logits = torch.ones(n_tokens, dtype=dtype).div(n_tokens)
        self.w = nn.Parameter(initial_logits)

    def forward(self, x):
        """
        Args:
            x: Tensor of shape (B, N, D)
        Returns:
            Tensor of shape (B, D)
        """
        # Unpacking three values rejects inputs that are not rank 3.
        _batch, _tokens, _features = x.shape
        # Normalize the logits across token positions, then broadcast the
        # resulting weights over the batch and feature axes.
        position_weights = torch.softmax(self.w, dim=0)
        weighted_tokens = x * position_weights[None, :, None]
        return weighted_tokens.sum(dim=1)
    

class MaxPooling(nn.Module):
    """
    Element-wise maximum taken along dimension 1 (the token axis).
    """
    def __init__(self):
        super().__init__()

    def forward(self, x):
        # max(dim=...) yields a (values, indices) pair; only the values are used.
        return x.max(dim=1).values
    









if __name__ == "__main__":
    # Running this file as a script does nothing.
    pass