"""
Kernel Analysis for Kolmogorov-Arnold Networks (KAN)
----------------------------------------------------

This script implements the tangent kernel of a one-dimensional Kolmogorov-Arnold Network (KAN)
that uses Gaussian radial basis functions (RBFs) as its basis. It demonstrates the kernel on
synthetic linear and polynomial datasets by fitting Kernel Ridge Regression and plotting the fits.

Key components:
- RBF basis function implementation
- KAN kernel computation
- Visualization of kernel regression results
"""

# Third-party numerical and plotting imports
import numpy as np
import matplotlib.pyplot as plt

# Machine learning imports
from sklearn.kernel_ridge import KernelRidge

# ===========================================
# Kernel Configuration
# ===========================================

# RBF grid configuration: g Gaussian centres evenly spaced on [grid_min, grid_max]
grid_min = -2.0   # Left end of the grid
grid_max = 2.0    # Right end of the grid
g = 8             # Number of centres, i.e. number of basis functions

# Distance between neighbouring centres. The RBF width is tied to it so that
# 2 * sigma**2 == grid_spacing**2, i.e. each Gaussian decays by e^-1 one
# spacing away from its centre.
grid_spacing = (grid_max - grid_min) / (g - 1)
grid = np.linspace(start=grid_min, stop=grid_max, num=g)
sigma = grid_spacing / np.sqrt(2)

def phi(x, i):
    """
    Evaluate the i-th Gaussian RBF basis function at x.

    The basis function is centred at ``grid[i]`` and has width ``sigma``
    (both module-level values).

    Args:
        x: Point (or NumPy array of points) at which to evaluate.
        i: Index of the basis-function centre in ``grid``.

    Returns:
        float: exp(-(x - grid[i])**2 / (2 * sigma**2)).
    """
    center = grid[i]
    width_sq = sigma ** 2
    return np.exp(-((x - center) ** 2) / (2 * width_sq))

def kantk(x, y):
    """
    Compute the 1D tangent kernel k(x, y) for the RBF-based Kolmogorov-Arnold Network (KAN).

    The kernel is built from the g basis functions defined by the module-level
    ``grid``, ``sigma`` and ``phi``. All intermediate tensors are formed with
    NumPy broadcasting rather than nested Python loops. The arithmetic is the
    same as the element-wise formulas given in the comments below.

    Args:
        x: First input. A scalar or any array-like; only its first element is used.
           scikit-learn passes each sample as a 1-D row, which has length 1 in this script.
        y: Second input, same conventions as ``x``.

    Returns:
        float: Kernel value k(x, y).
    """
    # Reduce each input to a scalar. np.ravel also accepts plain scalars,
    # lists and 0-d arrays (indexing a 0-d array directly would fail).
    x = np.ravel(x)[0]
    y = np.ravel(y)[0]

    t = -1 / (2 * sigma**2)

    # Basis-function values at x and y, shape (g,)
    phi_x = np.array([phi(x, i) for i in range(g)])
    phi_y = np.array([phi(y, i) for i in range(g)])

    # Linear term b, shape (g,), and quadratic term A, shape (g, g)
    b = -2 * (phi_x + phi_y)
    A = np.outer(phi_x, phi_x) + np.outer(phi_y, phi_y)

    # G = (I - 2tA)^-1. Because t < 0, I - 2tA = I + A / sigma**2 is positive
    # definite (A is a sum of outer products), so the inverse and sqrt(det) exist.
    G = np.linalg.inv(np.eye(g) - 2 * t * A)
    sqrt_det = np.sqrt(np.linalg.det(G))

    # T[k] = exp(0.5 * b^T G b * (t * grid[k])**2) and Z[k] = T[k] * sqrt_det
    tg2 = (t * grid) ** 2
    T = np.exp(0.5 * (b @ G @ b) * tg2)
    Z = T * sqrt_det

    c = G @ b  # shape (g,)

    # Y[i, j] = t * grid[j]**2 * sqrt_det * c[i] * T[j]
    Y = t * sqrt_det * np.outer(c, grid**2 * T)

    # X[i, j, k] = T[k] * (sqrt_det * G[j, i] + (t * grid[k])**2 * sqrt_det**2 * c[i] * c[j])
    # NOTE(review): the c[i]*c[j] term is scaled by sqrt_det**2 while the G term
    # is scaled by sqrt_det. This asymmetry is preserved from the original
    # formulation; confirm it against the kernel derivation.
    X = T * (sqrt_det * G.T[:, :, None] + sqrt_det**2 * np.outer(c, c)[:, :, None] * tg2)

    # Per-basis weight alpha[l]
    alpha = phi_x * phi_y * np.exp(-grid**2 / sigma**2) / sigma**4

    # Off-diagonal contribution: sum over s, p of phi_x[s] * phi_y[p] * X[p, s, l]
    cross = np.einsum("s,p,psl->l", phi_x, phi_y, X)
    # Kronecker-delta (s == p) contribution: sum over s of (grid[l]**2 * Z[l] + b[s] * Y[s, l]).
    # The first summand does not depend on s, so it appears g times.
    diag = g * grid**2 * Z + b @ Y

    return float(np.sum((cross + diag) * alpha))

# ===========================================
# Linear Dataset Example
# ===========================================

# Generate synthetic linear data with Gaussian noise. A fixed seed keeps the
# data, and therefore the saved figure, reproducible across runs.
rng = np.random.default_rng(0)
n = 50  # Number of data points
x = np.linspace(0, 1, n)  # Input features
y = 5.0 * x + 0.1 * rng.standard_normal(n)

# Initialize and train Kernel Ridge Regression with the KAN kernel.
# `alpha` is the L2 regularization strength.
krr = KernelRidge(alpha=1e-1, kernel=kantk)
krr.fit(x.reshape((-1, 1)), y.reshape((-1, 1)))  # sklearn expects 2-D inputs

# Predict on a denser grid in a single call. The targets were 2-D, so predict
# returns shape (m, 1); ravel flattens it for plotting.
x_plot = np.linspace(0, 1, 100)
y_pred = krr.predict(x_plot.reshape((-1, 1))).ravel()

# Plot results
plt.figure(dpi=120, figsize=(8, 5))
plt.scatter(x, y, color='blue', ec='white', label='Training Data', alpha=0.7)
plt.plot(x_plot, y_pred, 'r-', linewidth=2, label='KAN Kernel Regression')
plt.title('KAN Kernel Regression on Linear Data', fontsize=12)
plt.xlabel('Input Feature (x)', fontsize=10)
plt.ylabel('Target Value', fontsize=10)
plt.grid(True, alpha=0.3)
plt.legend()
plt.tight_layout()
plt.savefig("Linear.png", bbox_inches='tight')
plt.show()

# ===========================================
# Polynomial Dataset Example
# ===========================================

# Generate synthetic polynomial data with Gaussian noise (seeded for reproducibility)
rng = np.random.default_rng(1)
n = 50
x = np.linspace(-1, 1, n)
# Target values: 4th-degree polynomial plus Gaussian noise
y = (0.5 * x**4 - 0.86 * x**3 + 1.32 * x**2 + 2.0 * x +
     0.1 * rng.standard_normal(n))

# Train KRR with the KAN kernel (same regularization strength as the linear example)
krr = KernelRidge(alpha=0.1, kernel=kantk)
krr.fit(x.reshape((-1, 1)), y.reshape((-1, 1)))

# Predict on a denser grid in a single call; ravel flattens the (m, 1) output
x_plot = np.linspace(-1, 1, 200)
y_pred = krr.predict(x_plot.reshape((-1, 1))).ravel()

# Plot results
plt.figure(dpi=120, figsize=(8, 5))
plt.scatter(x, y, color='green', ec='white', label='Training Data', alpha=0.7)
plt.plot(x_plot, y_pred, 'm-', linewidth=2, label='KAN Kernel Regression')
plt.title('KAN Kernel Regression on Polynomial Data', fontsize=12)
plt.xlabel('Input Feature (x)', fontsize=10)
plt.ylabel('Target Value', fontsize=10)
plt.grid(True, alpha=0.3)
plt.legend()
plt.tight_layout()
plt.savefig("Polynomial.png", bbox_inches='tight')
plt.show()