import numpy as np
from scipy.spatial.distance import pdist, squareform


class Kernel:
    """
        Common interface for the kernels k(x, y) that drive SVGD updates.

        Concrete kernels override two hooks:
            * ``evaluate`` -> the kernel matrix between two point sets
            * ``gradient`` -> the gradient of k with respect to its second argument

        Calling either hook on this base class raises ``NotImplementedError``.
    """
    def __init__(self):
        """
            Create a kernel. The base class keeps no state of its own.
        """


    def evaluate(self, x, y=None):
        """
            Build the kernel matrix K[i, j] = k(x_i, y_j).

            Args:
                x (np.ndarray): Points of shape (n, d)
                y (np.ndarray, optional): Points of shape (m, d); when omitted,
                    implementations are expected to pair ``x`` with itself

            Returns:
                np.ndarray: Matrix of shape (n, m), or (n, n) when ``y`` is None

            Raises:
                NotImplementedError: Always, on the base class
        """
        message = "Subclasses must implement evaluate"
        raise NotImplementedError(message)


    def gradient(self, x, y=None):
        """
            Build the kernel gradient G[i, j] = ∇_y k(x_i, y_j).

            Args:
                x (np.ndarray): Points of shape (n, d)
                y (np.ndarray, optional): Points of shape (m, d); when omitted,
                    implementations are expected to pair ``x`` with itself

            Returns:
                np.ndarray: Array of shape (n, m, d), or (n, n, d) when ``y`` is None

            Raises:
                NotImplementedError: Always, on the base class
        """
        message = "Subclasses must implement gradient"
        raise NotImplementedError(message)

        
class RBFKernel(Kernel):
    """
        Radial Basis Function (RBF) kernel: k(x, y) = exp(-||x-y||²/h)
        
        Includes adaptive bandwidth selection using median heuristic.

        NOTE: when ``adaptive`` is True (the default) the bandwidth is recomputed
        from ``x`` on every call to ``evaluate``/``gradient``, overwriting any
        ``bandwidth`` passed to the constructor. Use ``adaptive=False`` to keep a
        fixed, user-supplied bandwidth.
    """
    
    def __init__(self, bandwidth=None, adaptive=True):
        """
            Initialize RBF kernel
            
            Args:
                bandwidth (float, optional): Kernel bandwidth. If None, it is
                    computed from the data using the median heuristic on first use
                adaptive (bool): Whether to recompute the bandwidth from the data
                    on every call
        """
        super().__init__()
        self.bandwidth = bandwidth
        self.adaptive = adaptive


    def _compute_bandwidth(self, x):
        """
            Compute bandwidth using median heuristic

            h = median(||x_i - x_j||²) / log(n + 1), with n the number of points.
            Degenerate inputs use a median of 1.0 instead. These are: fewer than
            two points, all points coincident, or a non-finite median.
            
            Args:
                x (np.ndarray): Points with shape (n, d)
                
            Returns:
                float: Computed bandwidth (always >= 1e-8)
        """
        n = x.shape[0]

        # & Compute pairwise squared distances (empty when n < 2)
        pairwise_dists = pdist(x, metric='sqeuclidean')
        
        # & Use median of squared distances; np.median of an empty array is NaN,
        # & so short-circuit that case to the degenerate fallback below
        h = np.median(pairwise_dists) if pairwise_dists.size else 0.0
        
        # & Zero / NaN / infinite medians would make the kernel meaningless
        if not np.isfinite(h) or h < 1e-8:
            h = 1.0
            
        # & Divide by log(n + 1), n = number of particles (SVGD heuristic
        # & h = med² / log n); max(n, 1) keeps the log positive when n == 0
        h = h / np.log(max(n, 1) + 1)
        
        # & Ensure bandwidth is not too small
        h = max(h, 1e-8)
        
        return h


    def evaluate(self, x, y=None):
        """
            Evaluate RBF kernel k(x, y) = exp(-||x-y||²/h)

            NaN inputs are replaced by 0. +/-Inf inputs are replaced by +/-1e10.
            This happens before the bandwidth and distances are computed.
            
            Args:
                x (np.ndarray): First set of points with shape (n, d)
                y (np.ndarray, optional): Second set of points with shape (m, d)
                    If None, uses x for both sets
                    
            Returns:
                np.ndarray: Kernel matrix with shape (n, m) or (n, n) if y is None
        """
        if y is None:
            y = x

        # & Handle potential NaN or Inf values in input
        x_clean = np.nan_to_num(x, nan=0.0, posinf=1e10, neginf=-1e10)
        y_clean = np.nan_to_num(y, nan=0.0, posinf=1e10, neginf=-1e10)
            
        # & Compute adaptive bandwidth if needed, from the cleaned points so a
        # & non-finite input cannot turn the bandwidth into NaN
        if self.bandwidth is None or self.adaptive:
            self.bandwidth = self._compute_bandwidth(x_clean)
        
        # & Squared Euclidean distances via ||x||² + ||y||² - 2 x·y
        x_norm = np.sum(x_clean**2, axis=1)[:, np.newaxis]   # (n, 1)
        y_norm = np.sum(y_clean**2, axis=1)[np.newaxis, :]   # (1, m)
        dist_mat = x_norm + y_norm - 2 * np.dot(x_clean, y_clean.T)
        
        # & Ensure no negative distances (from numerical errors)
        dist_mat = np.maximum(dist_mat, 0.0)
        
        # & Ensure bandwidth is positive
        bandwidth = max(self.bandwidth, 1e-8)
        
        # & Clip the exponent to a range exp() handles without overflow
        K = np.exp(np.clip(-dist_mat / bandwidth, -700, 700))
        
        # & Replace any NaN or Inf values that might have occurred
        K = np.nan_to_num(K, nan=0.0, posinf=1.0, neginf=0.0)
        
        return K


    def gradient(self, x, y=None):
        """
            Evaluate gradient of RBF kernel ∇_y k(x, y) = -2/h * (y-x) * k(x,y)
            
            Args:
                x (np.ndarray): First set of points with shape (n, d)
                y (np.ndarray, optional): Second set of points with shape (m, d)
                    If None, uses x for both sets
                    
            Returns:
                np.ndarray: float64 array grad_K[i, j] = ∇_y k(x_i, y_j) with
                    shape (n, m, d) or (n, n, d) if y is None
        """
        if y is None:
            y = x
        
        # & Handle potential NaN or Inf values in input
        x_clean = np.nan_to_num(x, nan=0.0, posinf=1e10, neginf=-1e10)
        y_clean = np.nan_to_num(y, nan=0.0, posinf=1e10, neginf=-1e10)
        
        # & evaluate() refreshes self.bandwidth (when adaptive or unset), so the
        # & bandwidth read below is the same one used to build K
        K = self.evaluate(x_clean, y_clean)  # Shape: (n, m)
        
        # & Ensure bandwidth is positive
        bandwidth = max(self.bandwidth, 1e-8)
        
        # & Pairwise differences y_j - x_i for all i, j at once: shape (n, m, d)
        grad_K = np.subtract(
            y_clean[np.newaxis, :, :], x_clean[:, np.newaxis, :], dtype=np.float64
        )
        
        # & Apply RBF gradient formula: -2/h * (y-x) * k(x,y). K is used without
        # & any lower floor so that negligible kernel values give negligible gradients
        grad_K *= -2.0 / bandwidth
        grad_K *= K[:, :, np.newaxis]
        
        # & Clean any potential NaN/Inf values
        grad_K = np.nan_to_num(grad_K, nan=0.0, posinf=0.0, neginf=0.0)
            
        return grad_K

class IMQKernel(Kernel):
    """
        Inverse Multi-Quadric (IMQ) kernel: k(x, y) = (c² + ||x-y||²)^β
        
        More robust to high dimensions than RBF kernel.

        NOTE: when ``adaptive`` is True (the default) ``c`` and ``beta`` are
        recomputed from ``x`` on every call to ``evaluate``/``gradient``,
        overwriting the constructor values. Use ``adaptive=False`` to keep them.
    """
    def __init__(self, c=1.0, beta=-0.5, adaptive=True):
        """
            Initialize IMQ kernel
            
            Args:
                c (float): Scale parameter
                beta (float): Power parameter (negative for IMQ)
                adaptive (bool): Whether to use adaptive parameter selection
        """
        super().__init__()
        self.c = c
        self.beta = beta
        self.adaptive = adaptive


    def _compute_parameters(self, x):
        """
            Compute adaptive parameters based on data

            c = sqrt(median(||x_i - x_j||²)). c falls back to 1.0 when it is
            tiny (< 1e-5) or non-finite, or when there are fewer than two points.
            beta = -0.5 when d <= 10, otherwise -0.25.
            
            Args:
                x (np.ndarray): Points with shape (n, d)
                
            Returns:
                tuple: Computed (c, beta) parameters
        """
        # & Compute pairwise squared distances (empty when n < 2)
        pairwise_dists = pdist(x, metric='sqeuclidean')
        
        # & Use median of squared distances for c; np.median of an empty array
        # & is NaN, so short-circuit that case to the fallback below
        median_sq = np.median(pairwise_dists) if pairwise_dists.size else 0.0
        c = np.sqrt(median_sq)
        
        # & Ensure a usable, non-zero, finite c parameter
        if not np.isfinite(c) or c < 1e-5:
            c = 1.0
            
        # & Adjust beta based on dimensionality
        d = x.shape[1]
        beta = -0.5 if d <= 10 else -0.25
        
        return c, beta


    @staticmethod
    def _sq_dist_matrix(x, y):
        """
            Squared Euclidean distance matrix between x (n, d) and y (m, d)

            Returns:
                np.ndarray: (n, m) matrix, clamped at 0 because round-off in the
                    ||x||² + ||y||² - 2 x·y expansion can go slightly negative,
                    which would make (c² + dist)^β NaN for small c
        """
        x_norm = np.sum(x**2, axis=1)[:, np.newaxis]   # (n, 1)
        y_norm = np.sum(y**2, axis=1)[np.newaxis, :]   # (1, m)
        dist_mat = x_norm + y_norm - 2 * np.dot(x, y.T)
        return np.maximum(dist_mat, 0.0)


    def evaluate(self, x, y=None):
        """
            Evaluate IMQ kernel k(x, y) = (c² + ||x-y||²)^β
            
            Args:
                x (np.ndarray): First set of points with shape (n, d)
                y (np.ndarray, optional): Second set of points with shape (m, d)
                    If None, uses x for both sets
                    
            Returns:
                np.ndarray: Kernel matrix with shape (n, m) or (n, n) if y is None
        """
        if y is None:
            y = x
            
        # & Compute adaptive parameters if needed
        if self.adaptive:
            self.c, self.beta = self._compute_parameters(x)
        
        # & Squared distances, shape (n, m)
        dist_mat = self._sq_dist_matrix(x, y)
        
        # & Apply IMQ kernel
        K = (self.c**2 + dist_mat)**self.beta
        
        return K


    def gradient(self, x, y=None):
        """
            Evaluate gradient of IMQ kernel
            ∇_y k(x, y) = 2β * (y-x) * (c² + ||x-y||²)^(β-1)
            
            Args:
                x (np.ndarray): First set of points with shape (n, d)
                y (np.ndarray, optional): Second set of points with shape (m, d)
                    If None, uses x for both sets
                    
            Returns:
                np.ndarray: float64 array grad_K[i, j] = ∇_y k(x_i, y_j) with
                    shape (n, m, d) or (n, n, d) if y is None
        """
        if y is None:
            y = x
            
        # & Compute adaptive parameters if needed
        if self.adaptive:
            self.c, self.beta = self._compute_parameters(x)
        
        # & Kernel base values with power (β-1), shape (n, m)
        dist_mat = self._sq_dist_matrix(x, y)
        K_base = (self.c**2 + dist_mat)**(self.beta - 1)
        
        # & Pairwise differences y_j - x_i for all i, j at once: shape (n, m, d)
        grad_K = np.subtract(y[np.newaxis, :, :], x[:, np.newaxis, :], dtype=np.float64)
        
        # & Apply IMQ gradient formula: 2β * (y-x) * (c² + ||x-y||²)^(β-1)
        grad_K *= 2 * self.beta
        grad_K *= K_base[:, :, np.newaxis]
            
        return grad_K
