"""
Projection Analysis of Kernel Eigenvectors
------------------------------------------

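This script analyzes the projection of target functions onto the eigenvectors of a
kernel built from Gaussian Radial Basis Function (RBF) basis functions (the `kantk`
kernel, described as a Kolmogorov-Arnold Network tangent kernel). The code:
1. Implements the RBF-based kernel and computes its Gram matrix on 1D sample points
2. Computes the eigenvalues and eigenvectors of the kernel matrix
3. Projects a true target function and a random-noise function onto the eigenvectors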
4. Visualizes the projection norms and eigenvalues

Key components:
- RBF kernel implementation using basis functions
- Eigenvalue decomposition of kernel matrix
- Projection of target functions onto eigenvectors
- Comparative visualization of projections
"""

import numpy as np
import matplotlib.pyplot as plt



# ===========================================
# Kernel Parameters and Basis Functions
# ===========================================

# RBF centres: g equally spaced points covering [grid_min, grid_max].
grid_min, grid_max = -2.0, 2.0
g = 8
grid = np.linspace(grid_min, grid_max, g)

# Spacing between neighbouring centres. Choosing sigma = spacing / sqrt(2)
# makes 2 * sigma**2 equal to spacing**2 in the Gaussian exponent.
denominator = (grid_max - grid_min) / (g - 1)
sigma = denominator / np.sqrt(2)

def phi(x, i):
    """
    Evaluate the i-th Gaussian RBF basis function at x.

    The function is centred on the module-level ``grid[i]`` and uses the
    module-level width ``sigma``:
        phi(x, i) = exp(-(x - grid[i])**2 / (2 * sigma**2))

    Args:
        x: Input value (numpy operations apply elementwise to arrays)
        i: Index of the basis function centre in ``grid``

    Returns:
        float: Value of the i-th basis function at x
    """
    center = grid[i]
    two_var = 2 * sigma**2
    return np.exp(-(x - center)**2 / two_var)

def kantk(x, y):
    """
    Compute the 1D kernel value k(x, y), described by the author as the
    tangent kernel of a Kolmogorov-Arnold Network (KAN) built on the RBF
    basis functions ``phi(., i)`` centred on the module-level ``grid``.

    The construction uses the g basis values at x and y to build a matrix G,
    the square root of its determinant, per-grid-point exponential factors
    (T, Z), and second/third-order tensors (Y, X). These are combined with
    per-centre weights ``alpha``. The terms are evaluated with explicit
    nested loops over g, so the cost is O(g**3) per call.
    NOTE(review): the derivation behind these terms is not given here;
    document its source.

    Args:
        x: First input (scalar, or ndarray of which only x[0] is used)
        y: Second input (scalar, or ndarray of which only y[0] is used)

    Returns:
        float: Kernel value k(x,y)
    """
    # Reduce ndarray inputs to their first element so the rest works on scalars
    if isinstance(x, np.ndarray):
        x = x[0]
    if isinstance(y, np.ndarray):
        y = y[0]

    # Precompute frequently used values
    t = -1 / (2 * sigma**2)  # Gaussian exponent factor: phi(x, i) == exp(t * (x - grid[i])**2)
    
    # Compute basis function values for x and y
    phi_x = np.array([phi(x, i) for i in range(g)])
    phi_y = np.array([phi(y, i) for i in range(g)])
    
    # Compute intermediate matrices for kernel calculation
    b = -2 * (phi_x + phi_y).reshape(-1, 1)  # (g, 1) column vector
    A = np.outer(phi_x, phi_x) + np.outer(phi_y, phi_y)  # symmetric (g, g) matrix
    
    # G = (I - 2 t A)^-1 and the square root of its determinant
    G = np.linalg.inv(np.eye(g) - 2 * t * A)  # (g, g) matrix; not the Gram matrix H
    det = np.sqrt(np.linalg.det(G))  # sqrt(det(G)), despite the name
    
    # Compute exponential terms
    exponent = 0.5 * (b.T @ G @ b)[0, 0]  # scalar 0.5 * b^T G b
    T = np.exp(exponent * (t * grid)**2)  # length-g array, one entry per grid point
    Z = T * det
    
    # Compute intermediate variables
    c = G @ b  # (g, 1) column vector, so c[i] is a length-1 array
    
    # Initialize tensors for higher-order terms
    Y = np.zeros((g, g))  # Second-order terms
    X = np.zeros((g, g, g))  # Third-order terms
    
    # Fill Y and X element by element; the trailing [0] unwraps the length-1
    # arrays produced by indexing the (g, 1) vector c. X is symmetric in its
    # first two indices because G is symmetric.
    for i in range(g):
        for j in range(g):
            Y[i, j] = (t * grid[j]**2 * det * c[i] * T[j])[0]
            for k in range(g):
                X[i, j, k] = (T[k] * (det * G[j, i] + (t * grid[k])**2 * det**2 * c[i] * c[j]))[0]
    
    # Final kernel value computation
    output = 0.0
    for l in range(g):
        # Weight for each basis function
        alpha = (phi_x[l] * phi_y[l] * np.exp(-grid[l]**2 / sigma**2)) / sigma**4
        
        # Sum over intermediate terms
        sum_term = 0.0
        for s in range(g):
            for p in range(g):
                # (s == p) acts as a Kronecker delta: the second term only
                # contributes on the diagonal s == p
                sum_term += (phi_x[s] * phi_y[p] * X[p, s, l] + 
                            (grid[l]**2 * Z[l] + b[s] * Y[s, l]) * (s == p))
        output += sum_term * alpha
    
    # b[s] is a length-1 array, so output ends up as a length-1 array; unwrap it
    return output[0]

# ===========================================
# Kernel Matrix Computation
# ===========================================

# Define evaluation points
n = 50  # Number of evaluation points
x_arr = np.linspace(-1, 1, n)  # Input points
x_copy = x_arr.copy()  # Same points, used as the second kernel argument

# Compute the kernel matrix H (Gram matrix): H[i, j] = kantk(x_arr[i], x_copy[j]).
# kantk(x, y) == kantk(y, x) mathematically: every intermediate it builds is
# invariant under swapping phi_x and phi_y, and its tensor X is symmetric in
# its first two indices. Since x_copy equals x_arr, only the upper triangle is
# evaluated and mirrored. This halves the number of expensive kernel calls and
# makes H exactly symmetric, rather than symmetric only up to rounding.
H = np.zeros((n, n))
for i, x in enumerate(x_arr):
    for j in range(i, n):
        H[i, j] = H[j, i] = kantk(x, x_copy[j])

# ===========================================
# Eigenanalysis and Projections
# ===========================================

# Compute eigenvalues and eigenvectors of the kernel matrix.
# H is the Gram matrix of a symmetric kernel. The symmetric solver `eigh`
# guarantees real eigenvalues and real orthonormal eigenvectors. The general
# `eig` can return complex pairs in arbitrary order. Averaging with H.T removes
# any floating-point asymmetry before the solve.
eigenvalues, eigenvectors = np.linalg.eigh(0.5 * (H + H.T))

# eigh returns eigenvalues in ascending order. Reverse to descending so that
# index 1 in the plot corresponds to the dominant eigenvector.
order = np.argsort(eigenvalues)[::-1]
eigenvalues = eigenvalues[order]
eigenvectors = eigenvectors[:, order]

# Project the true function onto the eigenbasis.
# NOTE: log(|x|) is -inf at x == 0; with n = 50 (even) x_arr never hits 0,
# but an odd n would put 0 on the grid and make every projection nan.
y = np.log(np.abs(x_arr)) + x_arr**2 + 1  # True target function
# Column i of eigenvectors is the i-th eigenvector. Its sign is arbitrary,
# so only the magnitude of each projection coefficient is kept.
proj = np.abs(eigenvectors.T @ y)

# Project random noise onto the eigenbasis
y_rand = np.random.randn(n)  # Random target function
proj_rand = np.abs(eigenvectors.T @ y_rand)


# Plotting results
# 1-based eigen-index shared by every curve
idx = np.arange(1, n + 1)

# Create a figure and primary y-axis (ax1, left) for the projection norms
fig, ax1 = plt.subplots()

# Projections of the true label and of random noise share the left axis
ax1.plot(idx, proj, color='red', label='True Label')
ax1.plot(idx, proj_rand, color='green', label='Random Label')
ax1.set_xlabel("Eigenvector index")
ax1.set_ylabel("Norm of Projections")
ax1.tick_params(axis='y', labelcolor='red')

# Secondary y-axis (right) for the eigenvalues, sharing the same x-axis.
# np.real guards against a complex-typed result from a general eigensolver;
# for a real symmetric solver it is a no-op.
ax2 = ax1.twinx()
ax2.plot(idx, np.real(eigenvalues), color='blue', label='Eigenvalue')
ax2.set_ylabel("Eigenvalue")
ax2.tick_params(axis='y', labelcolor='blue')

# One legend covering the curves of both axes. ax1.legend() alone omits
# ax2's line, and drawing it on ax2 keeps it above all curves.
handles1, labels1 = ax1.get_legend_handles_labels()
handles2, labels2 = ax2.get_legend_handles_labels()
ax2.legend(handles1 + handles2, labels1 + labels2)

# Keep the right-hand axis label from being clipped in the saved file
fig.tight_layout()

# Save the figure to a file, then display it
plt.savefig("projections-2.png")
plt.show()