"""
Set aggregation module for aggregating context set onto latent grid.
Implements Equation 8 from the paper.
"""

import torch
import torch.nn as nn
from typing import Optional
from models.kernels import compute_pairwise_distances


class SetAggregator(nn.Module):
    """
    Kernel-weighted pooling of encoded context points onto a latent grid.

    For every grid location x_g the module evaluates

        ρ(x_g) = Σ_i κ_ρ(||x_g - x_c^i||) φ_z(x_c^i, y_c^i) / Z(x_g)

    with Z(x_g) = Σ_i κ_ρ(||x_g - x_c^i||). The division by Z is only
    applied when ``normalize`` is enabled.
    """

    def __init__(
        self,
        kernel: nn.Module,
        normalize: bool = True,
        epsilon: float = 1e-8
    ):
        """
        Args:
            kernel: Module implementing κ_ρ; it is called on the distance tensor
            normalize: If True, divide the kernel weights by their sum Z
            epsilon: Constant added to Z before dividing
        """
        super().__init__()

        self.kernel = kernel
        self.epsilon = epsilon
        self.normalize = normalize

    def forward(
        self,
        x_context: torch.Tensor,
        latent_features: torch.Tensor,
        x_grid: torch.Tensor
    ) -> torch.Tensor:
        """
        Pool the context features onto the grid.

        Args:
            x_context: Context locations, shape (batch, n_context, spatial_dim)
            latent_features: Encoded context features φ_z(x_c, y_c),
                           shape (batch, n_context, latent_dim)
            x_grid: Grid locations, shape (batch, n_grid, spatial_dim)

        Returns:
            Aggregated features on grid, shape (batch, n_grid, latent_dim)
        """
        # Grid-to-context distances pushed through the kernel; the result is
        # expected to have shape (batch, n_grid, n_context).
        kernel_weights = self.kernel(
            compute_pairwise_distances(x_grid, x_context, metric='euclidean')
        )

        if self.normalize:
            # Z(x_g): total kernel mass per grid point, kept as a trailing
            # singleton axis so it broadcasts over the context axis.
            # epsilon keeps the division finite when the mass is zero.
            z = kernel_weights.sum(dim=-1, keepdim=True) + self.epsilon
            kernel_weights = kernel_weights / z

        # Batched product:
        # (batch, n_grid, n_context) @ (batch, n_context, latent_dim)
        #   -> (batch, n_grid, latent_dim)
        return torch.bmm(kernel_weights, latent_features)


class DensityChannel(nn.Module):
    """
    Compute a density channel that represents concentration of context points.

    The density at a grid point is the sum of kernel responses to all context
    points. By default the raw sum is returned; with ``normalize=True`` each
    batch element is additionally divided by its maximum over the grid.

    This is concatenated with the aggregated features to help the model
    understand regions with sparse/dense observations.
    """

    def __init__(
        self,
        kernel: nn.Module,
        epsilon: float = 1e-8,
        *,
        normalize: bool = False
    ):
        """
        Args:
            kernel: Kernel function to compute density
            epsilon: Numerical stability constant, added to the per-batch
                maximum before dividing (only used when ``normalize`` is True)
            normalize: If True, rescale the density of every batch element by
                its peak value over the grid. For a non-negative kernel this
                maps the output into [0, 1]. Defaults to False, which returns
                the raw kernel sum.
        """
        super().__init__()
        self.kernel = kernel
        self.epsilon = epsilon
        self.normalize = normalize

    def forward(
        self,
        x_context: torch.Tensor,
        x_grid: torch.Tensor
    ) -> torch.Tensor:
        """
        Compute density at grid points.

        Args:
            x_context: Context locations, shape (batch, n_context, spatial_dim)
            x_grid: Grid locations, shape (batch, n_grid, spatial_dim)

        Returns:
            Density values, shape (batch, n_grid, 1). Raw kernel sums unless
            ``normalize`` was enabled at construction time.
        """
        # Grid-to-context distances; expected shape (batch, n_grid, n_context)
        distances = compute_pairwise_distances(x_grid, x_context, metric='euclidean')

        # Apply kernel and sum over the context axis
        weights = self.kernel(distances)
        density = weights.sum(dim=-1, keepdim=True)  # (batch, n_grid, 1)

        # Optional peak normalization. Skipped for an empty grid because a
        # max-reduction over a zero-length axis raises.
        if self.normalize and density.size(1) > 0:
            peak = density.amax(dim=1, keepdim=True)  # (batch, 1, 1)
            density = density / (peak + self.epsilon)

        return density


class MultiChannelAggregator(nn.Module):
    """
    Feature aggregation with an optional extra density channel.

    Wraps a SetAggregator and, when requested, a DensityChannel built on the
    same kernel. The output is [aggregated_features, density] concatenated on
    the last axis, so downstream layers can see how well each grid point is
    covered by observations.
    """

    def __init__(
        self,
        kernel: nn.Module,
        normalize: bool = True,
        include_density: bool = True,
        epsilon: float = 1e-8
    ):
        """
        Args:
            kernel: Kernel function, shared by both sub-modules
            normalize: Forwarded to SetAggregator (normalize aggregation weights)
            include_density: Whether to append the density channel
            epsilon: Numerical stability constant, forwarded to both sub-modules
        """
        super().__init__()

        self.include_density = include_density
        self.aggregator = SetAggregator(kernel, normalize=normalize, epsilon=epsilon)

        # The density sub-module only exists when the channel is requested.
        if include_density:
            self.density_channel = DensityChannel(kernel, epsilon=epsilon)

    def forward(
        self,
        x_context: torch.Tensor,
        latent_features: torch.Tensor,
        x_grid: torch.Tensor
    ) -> torch.Tensor:
        """
        Aggregate onto the grid, appending the density channel if enabled.

        Args:
            x_context: Context locations, shape (batch, n_context, spatial_dim)
            latent_features: Encoded features, shape (batch, n_context, latent_dim)
            x_grid: Grid locations, shape (batch, n_grid, spatial_dim)

        Returns:
            If include_density=True: shape (batch, n_grid, latent_dim + 1)
            If include_density=False: shape (batch, n_grid, latent_dim)
        """
        grid_features = self.aggregator(x_context, latent_features, x_grid)

        if self.include_density:
            coverage = self.density_channel(x_context, x_grid)
            # Density goes last, after the latent feature channels.
            return torch.cat((grid_features, coverage), dim=-1)

        return grid_features

    def get_output_dim(self, latent_dim: int) -> int:
        """Return the size of the last output axis for a given latent_dim."""
        if self.include_density:
            return latent_dim + 1
        return latent_dim