from typing import Optional

import torch
import torch.nn as nn


class ParticleAggregator(nn.Module):
    """
    # & Network to aggregate particle ensemble into a single vector representation

    Each particle is described by its hidden state ``h``, its latent ``z`` and a
    scalar weight. These per-particle features are concatenated and fed, in
    particle-index order, through a single-layer GRU; the GRU's final hidden
    state is then linearly projected to ``output_dim``.

    Note: because the GRU consumes the particles sequentially, the result
    depends on the order of the particles along dimension 1.

    Args:
        h_dim: Size of each particle's hidden state; also used as the GRU
            hidden size.
        z_dim: Size of each particle's latent vector.
        output_dim: Size of the aggregated output. Falls back to ``h_dim``
            when None.
    """
    def __init__(self, h_dim: int, z_dim: int, output_dim: Optional[int] = None):
        super().__init__()

        if output_dim is None:
            output_dim = h_dim

        self.input_dim = h_dim + z_dim + 1  # & h, z, and weight

        # & RNN for processing particles sequentially
        self.rnn = nn.GRU(input_size=self.input_dim, hidden_size=h_dim, batch_first=True)

        # & Final projection
        self.output_proj = nn.Linear(h_dim, output_dim)

    def forward(self, h_particles: torch.Tensor, z_particles: torch.Tensor,
                weights: torch.Tensor) -> torch.Tensor:
        """
        # & Aggregate particles into a single representation

        Args:
            h_particles: Shape [batch_size, n_particles, h_dim]
            z_particles: Shape [batch_size, n_particles, z_dim]
            weights: Shape [batch_size, n_particles]

        Returns:
            Aggregated representation [batch_size, output_dim]
        """
        # & Prepare inputs: concatenate h, z, and weight for each particle
        particle_inputs = torch.cat(
            [h_particles, z_particles, weights.unsqueeze(-1)], dim=-1
        )

        # & Process particles sequentially through RNN.
        # & No explicit h_0 is passed: nn.GRU then starts from a zero state that
        # & matches the input's dtype and device. A hand-built float32 zero state
        # & would break the module once it is cast to another dtype (e.g. .double()).
        _, h_n = self.rnn(particle_inputs)

        # & Final projection; h_n is [1, batch_size, h_dim] for this single-layer,
        # & unidirectional GRU, so drop the leading layer axis first
        return self.output_proj(h_n.squeeze(0))

