import torch

import torch.nn as nn

import torch.nn.functional as F



class AttentionPooling(nn.Module):
    """
    Attention pooling layer that learns a weighted average of the input.

    A single linear projection scores every position; the scores are
    softmax-normalised over the valid (unmasked) positions of each sequence
    and used to take a weighted sum of the hidden states.
    """

    def __init__(self, d_model: int):
        super().__init__()
        # Projects each d_model-dim hidden state to one scalar score.
        self.attention_weights = nn.Linear(d_model, 1)

    def forward(self, hidden_states: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
        """
        Args:
            hidden_states (torch.Tensor): (batch_size, seq_len, d_model)
            attention_mask (torch.Tensor): (batch_size, seq_len) mask, True (or
                non-zero) for valid tokens. Integer/float 0/1 masks are accepted
                and converted to bool.

        Returns:
            torch.Tensor: (batch_size, d_model) pooled output. Rows whose mask
            contains no valid tokens pool to a zero vector instead of NaN.
        """
        # (batch, seq_len, 1) validity mask; .bool() lets 0/1 int/float masks
        # work with `~` and masked_fill.
        valid = attention_mask.bool().unsqueeze(-1)

        # (batch, seq_len, 1) unnormalised scores.
        scores = self.attention_weights(hidden_states)

        # Use the most negative finite value rather than -inf: softmax over a
        # row that is entirely -inf (no valid tokens) is NaN.
        scores = scores.masked_fill(~valid, torch.finfo(scores.dtype).min)

        # Normalise over the sequence dimension.
        attn_weights = F.softmax(scores, dim=1)

        # For a fully masked row the softmax above is uniform over padding;
        # zero those weights so the row pools to zeros. For partially masked
        # rows the masked weights are already 0, so this changes nothing.
        attn_weights = attn_weights.masked_fill(~valid, 0.0)

        # (batch, 1, seq_len) @ (batch, seq_len, d_model) -> (batch, 1, d_model)
        context_vector = torch.bmm(attn_weights.transpose(1, 2), hidden_states)
        return context_vector.squeeze(1)





class PoolingHead(nn.Module):

    """
    A head for downstream tasks that pools the sequence of hidden states
    and applies a classification layer.
    """

    def __init__(self, d_model: int, num_classes: int, pooling_strategy: str = "mean", dropout: float = 0.1):

        super().__init__()

        if pooling_strategy not in ["first", "mean", "attn"]:

            raise ValueError(f"Unknown pooling strategy: {pooling_strategy}")

        self.pooling_strategy = pooling_strategy



        if self.pooling_strategy == "attn":

            self.pooler = AttentionPooling(d_model)





        self.classifier = nn.Sequential(

            nn.Linear(d_model, num_classes)

        )



    def forward(self, hidden_states: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:

        """
        Args:
            hidden_states (torch.Tensor): (batch_size, seq_len, d_model)
            attention_mask (torch.Tensor): (batch_size, seq_len) boolean mask, True for valid tokens.

        Returns:
            torch.Tensor: (batch_size, num_classes) logits
        """

        if self.pooling_strategy == "first":



            pooled_output = hidden_states[:, 0]

        elif self.pooling_strategy == "mean":



            masked_hidden_states = hidden_states * attention_mask.unsqueeze(-1)

            summed = masked_hidden_states.sum(dim=1)

            num_valid_tokens = attention_mask.sum(dim=1, keepdim=True)

            pooled_output = summed / num_valid_tokens.clamp(min=1e-9)

        elif self.pooling_strategy == "attn":

            pooled_output = self.pooler(hidden_states, attention_mask)



        return self.classifier(pooled_output)

