"""
Kernel functions for aggregating context set information onto latent grid.
Implements various kernels κ_ρ as described in Section 4.2.
"""

import torch
import torch.nn as nn
import numpy as np
from typing import Optional, Tuple


class RBFKernel(nn.Module):
    """
    Radial Basis Function (RBF) kernel: κ_ρ(r) = exp(-r²/(2ρ²))
    
    This is the default kernel used in the paper.
    The lengthscale ρ is stored as log(ρ) and is learnable by default.
    """
    
    def __init__(
        self,
        initial_lengthscale: float = 0.1,
        learnable: bool = True,
        min_lengthscale: float = 1e-3
    ):
        """
        Args:
            initial_lengthscale: Initial value of ρ
            learnable: Whether ρ is learnable (registered as a parameter)
                or fixed (registered as a buffer)
            min_lengthscale: Minimum allowed lengthscale (for stability)
        """
        super().__init__()
        
        self.min_lengthscale = min_lengthscale
        
        # Store log-lengthscale for unconstrained optimization.
        # np.log returns an np.float64 scalar, which torch.tensor would turn
        # into a float64 tensor; converting to a Python float first makes the
        # tensor use torch's default dtype like the rest of the model.
        log_init = torch.tensor(float(np.log(initial_lengthscale)))
        
        if learnable:
            self.log_lengthscale = nn.Parameter(log_init)
        else:
            self.register_buffer('log_lengthscale', log_init)
    
    @property
    def lengthscale(self) -> torch.Tensor:
        """Get the actual lengthscale ρ = exp(log ρ), clamped from below at min_lengthscale."""
        return torch.exp(self.log_lengthscale).clamp(min=self.min_lengthscale)
    
    def forward(self, distances: torch.Tensor) -> torch.Tensor:
        """
        Compute kernel values from distances.
        
        Args:
            distances: Euclidean distances, shape (...)
        
        Returns:
            Kernel values κ_ρ(r), same shape as input
        """
        rho = self.lengthscale
        return torch.exp(-0.5 * (distances / rho) ** 2)
    
    def __repr__(self):
        return f"RBFKernel(lengthscale={self.lengthscale.item():.4f})"


class MaternKernel(nn.Module):
    """
    Matérn kernel with ν = 3/2:
    κ_ρ(r) = (1 + √3 * r/ρ) * exp(-√3 * r/ρ)
    
    Less smooth than RBF but often more robust.
    The lengthscale ρ is stored as log(ρ) and is learnable by default.
    """
    
    def __init__(
        self,
        initial_lengthscale: float = 0.1,
        learnable: bool = True,
        min_lengthscale: float = 1e-3
    ):
        """
        Args:
            initial_lengthscale: Initial value of ρ
            learnable: Whether ρ is learnable (registered as a parameter)
                or fixed (registered as a buffer)
            min_lengthscale: Minimum allowed lengthscale (for stability)
        """
        super().__init__()
        
        self.min_lengthscale = min_lengthscale
        
        # np.log returns an np.float64 scalar, which torch.tensor would turn
        # into a float64 tensor; converting to a Python float first makes the
        # tensor use torch's default dtype like the rest of the model.
        log_init = torch.tensor(float(np.log(initial_lengthscale)))
        
        if learnable:
            self.log_lengthscale = nn.Parameter(log_init)
        else:
            self.register_buffer('log_lengthscale', log_init)
    
    @property
    def lengthscale(self) -> torch.Tensor:
        """Get the actual lengthscale ρ = exp(log ρ), clamped from below at min_lengthscale."""
        return torch.exp(self.log_lengthscale).clamp(min=self.min_lengthscale)
    
    def forward(self, distances: torch.Tensor) -> torch.Tensor:
        """
        Matern 3/2 kernel.
        
        Args:
            distances: Euclidean distances, shape (...)
        
        Returns:
            Kernel values κ_ρ(r), same shape as input
        """
        rho = self.lengthscale
        # Plain Python float (same value as np.sqrt(3.0)) so the product is an
        # ordinary scalar-times-tensor op rather than numpy-scalar dispatch.
        sqrt3 = float(np.sqrt(3.0))
        sqrt3_r = sqrt3 * distances / rho
        return (1.0 + sqrt3_r) * torch.exp(-sqrt3_r)
    
    def __repr__(self):
        return f"MaternKernel(lengthscale={self.lengthscale.item():.4f})"


class WhiteNoiseKernel(nn.Module):
    """
    White noise kernel: κ(r) = σ² * δ(r)
    
    Used in Burgers equation benchmark to model observation noise.
    Only contributes at zero distance (i.e., on the diagonal).
    The scale σ² is stored as log(σ²).
    """
    
    def __init__(
        self,
        initial_log_scale: float = -5.0,
        learnable: bool = True
    ):
        """
        Args:
            initial_log_scale: Initial log(σ²)
            learnable: Whether σ² is learnable (registered as a parameter)
                or fixed (registered as a buffer)
        """
        super().__init__()
        
        # float(...) guarantees a floating-point tensor even when an int such
        # as -5 is passed; an integer tensor cannot be wrapped in nn.Parameter.
        log_init = torch.tensor(float(initial_log_scale))
        
        if learnable:
            self.log_scale = nn.Parameter(log_init)
        else:
            self.register_buffer('log_scale', log_init)
    
    @property
    def scale(self) -> torch.Tensor:
        """Get σ²."""
        return torch.exp(self.log_scale)
    
    def forward(self, distances: torch.Tensor, epsilon: float = 1e-6) -> torch.Tensor:
        """
        White noise kernel (delta function approximation).
        
        Note: `compute_pairwise_distances(..., metric='euclidean')` in this
        module returns sqrt(d² + 1e-8), i.e. values of at least 1e-4, so with
        the default `epsilon` those distances never fall below the threshold.
        Pass a larger `epsilon` when feeding such distances.
        
        Args:
            distances: Distances, shape (...)
            epsilon: Threshold for considering distance as zero
        
        Returns:
            Kernel values, shape (...); σ² where distance < epsilon, else 0
        """
        # Indicator in the dtype of the input so float64/float16 distances
        # yield an output of matching dtype, like the other kernels here.
        is_zero = (distances < epsilon).to(distances.dtype)
        # Only non-zero when distance is very small (approximate delta function)
        return self.scale * is_zero
    
    def __repr__(self):
        return f"WhiteNoiseKernel(scale={self.scale.item():.6f})"


class MultiScaleKernel(nn.Module):
    """
    Convex combination of RBF kernels, one per lengthscale.
    Useful for capturing multi-scale phenomena.
    """
    
    def __init__(
        self,
        num_scales: int = 3,
        initial_lengthscales: Optional[Tuple[float, ...]] = None,
        learnable: bool = True
    ):
        """
        Args:
            num_scales: How many RBF components to mix
            initial_lengthscales: Starting lengthscale of each component;
                when omitted, 0.1 * 2**i for i in range(num_scales) is used
            learnable: Whether lengthscales and mixing weights are learnable
        """
        super().__init__()
        
        # Default: lengthscales doubling from 0.1
        if initial_lengthscales is None:
            initial_lengthscales = tuple(
                0.1 * (2.0 ** exponent) for exponent in range(num_scales)
            )
        
        assert len(initial_lengthscales) == num_scales
        
        # One RBF component per requested lengthscale
        components = [
            RBFKernel(rho0, learnable=learnable) for rho0 in initial_lengthscales
        ]
        self.kernels = nn.ModuleList(components)
        
        # Unnormalized log mixing weights, initialised uniform
        initial_log_weights = torch.zeros(num_scales)
        if learnable:
            self.log_weights = nn.Parameter(initial_log_weights)
        else:
            self.register_buffer('log_weights', initial_log_weights)
    
    @property
    def weights(self) -> torch.Tensor:
        """Mixing weights: softmax of log_weights, so they sum to 1."""
        return self.log_weights.softmax(dim=0)
    
    def forward(self, distances: torch.Tensor) -> torch.Tensor:
        """
        Evaluate the weighted sum of the component kernels.
        
        Args:
            distances: Distances, shape (...)
        
        Returns:
            Combined kernel values, shape (...)
        """
        mix = self.weights
        
        # Accumulate from 0.0 in component order
        return sum(
            (mix[idx] * component(distances)
             for idx, component in enumerate(self.kernels)),
            0.0,
        )
    
    def __repr__(self):
        weights = self.weights.detach().cpu().numpy()
        lengthscales = [component.lengthscale.item() for component in self.kernels]
        return f"MultiScaleKernel(lengthscales={lengthscales}, weights={weights})"


def compute_pairwise_distances(x1: torch.Tensor, 
                              x2: torch.Tensor,
                              metric: str = 'euclidean') -> torch.Tensor:
    """
    Compute pairwise distances between two sets of points.
    
    Either input may be batched or unbatched independently; an unbatched
    input is broadcast against the batch dimension of the other one (e.g. a
    shared grid of shape (n1, dim) against context points (batch, n2, dim)).
    
    Args:
        x1: First set of points, shape (batch, n1, dim) or (n1, dim)
        x2: Second set of points, shape (batch, n2, dim) or (n2, dim);
            moved to x1's device if it lives on a different one
        metric: Distance metric ('euclidean' or 'squared_euclidean')
    
    Returns:
        Pairwise distances, shape (n1, n2) if both inputs are unbatched,
        otherwise (batch, n1, n2). For 'euclidean' the value is
        sqrt(d² + 1e-8), so coincident points yield 1e-4 rather than 0.
    
    Raises:
        ValueError: If `metric` is not one of the supported names.
    """
    # Validate up front so a bad metric fails before any computation
    if metric not in ('euclidean', 'squared_euclidean'):
        raise ValueError(f"Unknown metric: {metric}")
    
    # Ensure both tensors are on the same device
    if x1.device != x2.device:
        x2 = x2.to(x1.device)
    
    # Handle unbatched inputs: each one gets its own leading batch dim of 1;
    # the result is only squeezed back when neither input was batched.
    squeeze_batch = x1.dim() == 2 and x2.dim() == 2
    if x1.dim() == 2:
        x1 = x1.unsqueeze(0)
    if x2.dim() == 2:
        x2 = x2.unsqueeze(0)
    
    # Compute squared Euclidean distances
    # ||x1 - x2||^2 = ||x1||^2 + ||x2||^2 - 2*x1^T*x2
    x1_sq = torch.sum(x1 ** 2, dim=-1, keepdim=True)  # (batch, n1, 1)
    x2_sq = torch.sum(x2 ** 2, dim=-1, keepdim=True)  # (batch, n2, 1)
    
    # matmul (rather than bmm) broadcasts a batch dim of 1 against the other
    dot_product = torch.matmul(x1, x2.transpose(-2, -1))  # (batch, n1, n2)
    
    distances_sq = x1_sq + x2_sq.transpose(-2, -1) - 2 * dot_product
    
    # Ensure non-negative (the expansion above can go slightly below zero
    # through floating-point cancellation)
    distances_sq = torch.clamp(distances_sq, min=0.0)
    
    if metric == 'euclidean':
        # The 1e-8 offset keeps the sqrt gradient finite at zero distance
        distances = torch.sqrt(distances_sq + 1e-8)
    else:
        distances = distances_sq
    
    if squeeze_batch:
        distances = distances.squeeze(0)
    
    return distances

def build_kernel(
    kernel_type: str,
    initial_lengthscale: float = 0.1,
    learnable: bool = True,
    **kwargs
) -> nn.Module:
    """
    Factory function to build kernel by name.
    
    Args:
        kernel_type: Type of kernel ('rbf', 'matern', 'white_noise',
            'multi_scale'); matched case-insensitively
        initial_lengthscale: Initial lengthscale; used by 'rbf' and 'matern'
            only
        learnable: Whether parameters are learnable
        **kwargs: Additional kernel-specific arguments:
            - 'rbf' / 'matern': min_lengthscale (default 1e-3)
            - 'white_noise': initial_log_scale (default -5.0)
            - 'multi_scale': initial_lengthscales (default None, letting
              MultiScaleKernel pick its own), num_scales (default
              len(initial_lengthscales) when those are given, else 3)
            Unrecognized keys are ignored.
    
    Returns:
        Kernel module
    
    Raises:
        ValueError: If `kernel_type` is not one of the supported names.
    """
    kernel_type = kernel_type.lower()
    
    if kernel_type == 'rbf':
        return RBFKernel(
            initial_lengthscale,
            learnable,
            min_lengthscale=kwargs.get('min_lengthscale', 1e-3)
        )
    elif kernel_type == 'matern':
        return MaternKernel(
            initial_lengthscale,
            learnable,
            min_lengthscale=kwargs.get('min_lengthscale', 1e-3)
        )
    elif kernel_type == 'white_noise':
        return WhiteNoiseKernel(
            initial_log_scale=kwargs.get('initial_log_scale', -5.0),
            learnable=learnable
        )
    elif kernel_type == 'multi_scale':
        # Forward explicit lengthscales instead of silently dropping them,
        # and derive the component count from them unless the caller also
        # specified num_scales.
        initial_lengthscales = kwargs.get('initial_lengthscales')
        if initial_lengthscales is not None:
            initial_lengthscales = tuple(initial_lengthscales)
            num_scales = kwargs.get('num_scales', len(initial_lengthscales))
        else:
            num_scales = kwargs.get('num_scales', 3)
        return MultiScaleKernel(
            num_scales=num_scales,
            initial_lengthscales=initial_lengthscales,
            learnable=learnable
        )
    else:
        raise ValueError(f"Unknown kernel type: {kernel_type}")