from typing import Optional

import torch
import torch.nn as nn


class SequencePoolingHead(nn.Module):
    """Pool a sequence of embeddings into one vector, then apply a prediction head.

    This module takes sequence embeddings and:
    1. Pools them over the sequence dimension (dim 1). Only mean pooling is
       currently supported. Padded positions can be excluded by passing an
       ``attention_mask`` to :meth:`forward`.
    2. Passes the pooled tensor through ``prediction_head``.

    Args:
        prediction_head: Module applied to the pooled tensor.
        pooling_method: Name of the pooling strategy. Only ``"mean"`` is
            supported. Any other value is accepted here but raises
            ``ValueError`` when :meth:`forward` is called.
    """

    def __init__(self, prediction_head: nn.Module, pooling_method: str = "mean"):
        super().__init__()
        self.prediction_head = prediction_head
        self.pooling_method = pooling_method

    def extra_repr(self) -> str:
        # Surfaces the pooling strategy in print(module) next to the child head.
        return f"pooling_method={self.pooling_method!r}"

    def forward(
        self,
        embeddings: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Pool ``embeddings`` over the sequence dimension and apply the prediction head.

        Args:
            embeddings: Tensor of shape [batch_size, seq_len, hidden_dim].
            attention_mask: Optional binary mask of shape [batch_size, seq_len].
                Non-zero / ``True`` marks a real token; zero / ``False`` marks
                padding that must not contribute to the pooled value. When
                ``None``, every position is pooled, exactly as before.

        Returns:
            The output of ``prediction_head`` applied to the pooled tensor of
            shape [batch_size, hidden_dim]. Presumably this is
            [batch_size, output_dim], depending on the head.

        Raises:
            ValueError: If ``pooling_method`` is unsupported, or if
                ``attention_mask`` does not match ``embeddings.shape[:2]``.
        """
        if self.pooling_method == "mean":
            pooled = self._masked_mean(embeddings, attention_mask)  # [batch_size, hidden_dim]
        else:
            raise ValueError(f"Unsupported pooling method: {self.pooling_method}")

        return self.prediction_head(pooled)

    @staticmethod
    def _masked_mean(
        embeddings: torch.Tensor, attention_mask: Optional[torch.Tensor]
    ) -> torch.Tensor:
        """Average ``embeddings`` over dim 1, skipping positions where the mask is zero."""
        if attention_mask is None:
            return embeddings.mean(dim=1)

        if attention_mask.shape != embeddings.shape[:2]:
            raise ValueError(
                f"attention_mask shape {tuple(attention_mask.shape)} does not match "
                f"embeddings batch/sequence shape {tuple(embeddings.shape[:2])}"
            )

        # Cast the mask (bool/int/float) to the embedding dtype and device, then add
        # trailing singleton dims so it broadcasts over the feature dimension(s).
        weights = attention_mask.to(device=embeddings.device, dtype=embeddings.dtype)
        weights = weights.reshape(weights.shape + (1,) * (embeddings.dim() - 2))

        summed = (embeddings * weights).sum(dim=1)
        # Clamp so a fully padded row pools to zeros instead of 0/0 = NaN.
        counts = weights.sum(dim=1).clamp_min(1.0)
        return summed / counts
