"""
Aggregator components for GLEAM-AI.

This module contains aggregation functions used in the STNP framework
for combining information across different dimensions.
"""

import torch
import torch.nn as nn
from typing import Optional


class MeanAggregator(nn.Module):
    """
    Average-pool a tensor along one axis fixed at construction time.

    The reduced axis is removed from the result (no ``keepdim``). Which axis
    holds "items to pool" depends on the caller's tensor layout; with the
    default ``dim=0`` the leading axis is averaged away.
    """

    def __init__(self, dim: int = 0):
        """
        Build the aggregator.

        Args:
            dim: Axis that ``forward`` averages over (default: 0)
        """
        super().__init__()
        self.dim = dim

    def forward(self, r: torch.Tensor) -> torch.Tensor:
        """
        Average ``r`` over the configured axis.

        Args:
            r: Input tensor of any rank that has axis ``self.dim``

        Returns:
            Tensor equal to ``r`` averaged over ``self.dim``, with that axis dropped
        """
        reduce_axis = self.dim
        return torch.mean(r, dim=reduce_axis)

    def get_aggregation_dim(self) -> int:
        """Return the axis this aggregator averages over."""
        return self.dim


class MaxAggregator(nn.Module):
    """
    Max-pool a tensor along one axis fixed at construction time.

    Only the maximum values are returned; the arg-max indices that
    ``torch.max`` also produces are discarded. The reduced axis is removed
    from the result.
    """

    def __init__(self, dim: int = 0):
        """
        Build the aggregator.

        Args:
            dim: Axis that ``forward`` takes the maximum over (default: 0)
        """
        super().__init__()
        self.dim = dim

    def forward(self, r: torch.Tensor) -> torch.Tensor:
        """
        Take the element-wise maximum of ``r`` over the configured axis.

        Args:
            r: Input tensor of any rank that has axis ``self.dim``

        Returns:
            Tensor of maxima over ``self.dim``, with that axis dropped
        """
        reduced = torch.max(r, dim=self.dim)
        return reduced.values

    def get_aggregation_dim(self) -> int:
        """Return the axis this aggregator takes the maximum over."""
        return self.dim


class AttentionAggregator(nn.Module):
    """
    Attention-based aggregator for combining representations.

    A small scoring network (Linear -> Tanh -> Linear) assigns one scalar
    score to every position along the sequence axis. The scores are
    normalised with a softmax over that axis and used as weights for a
    weighted sum of the inputs. An optional mask excludes positions (e.g.
    padding) from the aggregation.
    """

    def __init__(self, input_dim: int, attention_dim: Optional[int] = None):
        """
        Initialize the attention aggregator.

        Args:
            input_dim: Input feature dimension
            attention_dim: Attention hidden dimension (default: input_dim // 2,
                clamped to at least 1)
        """
        super().__init__()

        if attention_dim is None:
            attention_dim = max(1, input_dim // 2)

        self.input_dim = input_dim
        self.attention_dim = attention_dim

        # Scoring network: maps each input_dim vector to a single scalar score.
        self.attention = nn.Sequential(
            nn.Linear(input_dim, attention_dim),
            nn.Tanh(),
            nn.Linear(attention_dim, 1)
        )
        # Normalises scores over axis 1, the sequence axis of batched input.
        self.softmax = nn.Softmax(dim=1)

    def forward(self, r: torch.Tensor, mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        """
        Forward pass of the attention aggregator.

        Args:
            r: Input tensor [batch_size, seq_len, input_dim], or an unbatched
                [seq_len, input_dim] tensor
            mask: Optional tensor broadcastable to ``r.shape[:-1]``. Truthy
                entries mark positions that take part in the aggregation;
                False/0 positions receive zero weight. Default: no masking.

        Returns:
            Attention-weighted aggregated tensor [batch_size, input_dim]
            ([input_dim] for unbatched input). A row whose mask is entirely
            False aggregates to a zero vector.
        """
        if r.dim() == 2:
            # Unbatched input: add a batch axis so the dim=1 softmax and sum
            # address the sequence axis, then drop it again.
            batched_mask = None if mask is None else mask.unsqueeze(0)
            return self.forward(r.unsqueeze(0), batched_mask).squeeze(0)

        # Reuse the shared scoring/normalisation path so forward and
        # get_attention_weights always agree.
        attention_weights = self.get_attention_weights(r, mask)  # [batch_size, seq_len]

        # Apply attention weights
        weighted_r = r * attention_weights.unsqueeze(-1)  # [batch_size, seq_len, input_dim]

        # Sum over sequence dimension
        return weighted_r.sum(dim=1)  # [batch_size, input_dim]

    def get_attention_weights(self, r: torch.Tensor, mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        """
        Get attention weights for the input.

        Args:
            r: Input tensor [batch_size, seq_len, input_dim], or an unbatched
                [seq_len, input_dim] tensor
            mask: Optional tensor broadcastable to ``r.shape[:-1]``. Truthy
                entries mark valid positions. Default: no masking.

        Returns:
            Attention weights [batch_size, seq_len] ([seq_len] for unbatched
            input). Weights of valid positions sum to 1 per row; masked
            positions get exactly 0. A fully-masked row is all zeros.
        """
        if r.dim() == 2:
            batched_mask = None if mask is None else mask.unsqueeze(0)
            return self.get_attention_weights(r.unsqueeze(0), batched_mask).squeeze(0)

        scores = self.attention(r).squeeze(-1)  # [batch_size, seq_len]
        if mask is None:
            return self.softmax(scores)

        keep = mask.to(dtype=torch.bool)
        # A finite "very negative" fill (rather than -inf) drives masked
        # exp() terms to 0 without producing NaN when a whole row is masked.
        scores = scores.masked_fill(~keep, torch.finfo(scores.dtype).min)
        # Zero masked positions explicitly. A fully-masked row would
        # otherwise get uniform weights from the softmax.
        return self.softmax(scores).masked_fill(~keep, 0.0)
