import torch
import torch.nn as nn
from einops import rearrange
import math

class GaussianKernel(nn.Module):
    """
    Gaussian Radial Basis Function (RBF) kernel.

    The Gaussian kernel measures similarity between points from their
    squared Euclidean distance:
        K(x, y) = exp(-||x - y||^2 / (2 * bandwidth^2))

    Args:
        dim (int): Dimensionality of input feature space (stored for reference;
            it is not used in the computation).
        bandwidth (float, optional): Kernel bandwidth parameter (sigma). Defaults to 1.0.
        learnable_bandwidth (bool, optional): Whether to learn the bandwidth parameter.
            Defaults to False.
        normalization (bool, optional): If True, each row of the kernel matrix is
            divided by its sum, giving a row-stochastic matrix. Defaults to False.
    """

    def __init__(self, dim, bandwidth=1.0, learnable_bandwidth=False, normalization=False):
        super().__init__()
        self.dim = dim
        self.normalization = normalization
        self.learnable_bandwidth = learnable_bandwidth

        # Bandwidth (sigma) lives on the module either as a trainable Parameter
        # or as a buffer, so it follows .to()/.cuda() and is saved in state_dict.
        if learnable_bandwidth:
            self.bandwidth = nn.Parameter(torch.ones(1) * bandwidth)
        else:
            self.register_buffer('bandwidth', torch.tensor([bandwidth], dtype=torch.float))

    def forward(self, X):
        """
        Compute the Gaussian kernel matrix for input data.

        Args:
            X: Input tensor with shape (B, T, N, D) where:
               - B is batch size
               - T is time steps
               - N is number of nodes/points
               - D is feature dimension

        Returns:
            K: Kernel matrix with shape (B*T, N, N) representing pairwise similarities

        Raises:
            ValueError: If X is not a 4-D tensor.
        """
        if X.dim() != 4:
            raise ValueError(
                f"Expected X with shape (B, T, N, D), got a {X.dim()}-D tensor"
            )

        # Merge batch and time axes: (B, T, N, D) -> (B*T, N, D)
        X_flat = X.flatten(0, 1)

        # Squared pairwise Euclidean distances ||x - y||^2, as required by the
        # Gaussian kernel definition (cdist alone returns the unsquared norm).
        sq_dist = torch.cdist(X_flat, X_flat, p=2).pow(2)

        # Scale by the bandwidth: ||x - y||^2 / (2 * sigma^2)
        scaled_dist = sq_dist / (2 * self.bandwidth.pow(2))

        # Compute kernel values
        K = torch.exp(-scaled_dist)

        # Optional row normalization: every row sums to 1 (row-stochastic).
        if self.normalization:
            row_sums = K.sum(dim=2, keepdim=True)
            K = K / (row_sums + 1e-8)  # small epsilon guards against division by zero

        return K

    def compute_median_bandwidth(self, X, sample_size=1000):
        """
        Set the bandwidth with the median heuristic.

        The bandwidth becomes the median of the non-zero pairwise Euclidean
        distances between (a random subset of) the input points. If there are
        no non-zero distances the bandwidth is left unchanged.

        Args:
            X: Input tensor with shape (B, T, N, D)
            sample_size: Maximum number of points to keep when subsampling
                for efficiency

        Returns:
            This kernel (self), with its bandwidth updated in place

        Raises:
            ValueError: If X is not a 4-D tensor.
        """
        if X.dim() != 4:
            raise ValueError(
                f"Expected X with shape (B, T, N, D), got a {X.dim()}-D tensor"
            )

        with torch.no_grad():
            # Treat every node at every time step as one point: (B*T*N, D)
            X_flat = X.flatten(0, 2)

            # Randomly subsample points if there are too many
            num_points = X_flat.size(0)
            if num_points > sample_size:
                indices = torch.randperm(num_points, device=X_flat.device)[:sample_size]
                X_flat = X_flat[indices]

            # Compute pairwise distances
            dist = torch.cdist(X_flat, X_flat, p=2)

            # Median over non-zero distances (drops the diagonal and duplicates)
            positive = dist[dist > 0]
            if positive.numel() > 0:
                median_dist = torch.median(positive)

                # In-place copy works for both Parameter and buffer and keeps
                # the bandwidth's dtype, device and tensor identity intact.
                self.bandwidth.copy_(median_dist)

        return self
        
        
        
    