"""
Convergence Analysis of Kolmogorov-Arnold Networks (KAN) with Different Initializations
----------------------------------------------------------------------------------------

This script analyzes the convergence behavior of a Kolmogorov-Arnold Network (KAN) model
trained on different target label vectors (best, random, and worst cases) chosen from the
eigenvalue spectrum of its tangent kernel. It demonstrates how the choice of label vector
affects the convergence rate during training.

Key components:
- RBF-based tangent kernel for KAN
- Eigenvalue analysis of the kernel matrix
- Training on the different target label vectors
- Convergence rate visualization
"""

# Standard library imports
import numpy as np
import matplotlib.pyplot as plt

# Deep learning imports
import torch
import torch.optim as optim
import torch.nn as nn
from fastkan import FastKANLayer  # Custom KAN implementation

# ===========================================
# Tangent Kernel Parameters and Basis Functions
# ===========================================

# Centres of the Gaussian RBF basis used by the analytic kernel.
# NOTE(review): these values appear chosen to match FastKANLayer's default RBF
# grid (range -2..2 with 8 centres) -- confirm against the installed fastkan.
grid_min, grid_max = -2.0, 2.0  # Range spanned by the centres
g = 8                           # Number of centres
grid = np.linspace(grid_min, grid_max, g)

# Gaussian width. With centre spacing h, sigma = h / sqrt(2) gives
# 2 * sigma**2 = h**2, i.e. phi(x, i) = exp(-((x - grid[i]) / h)**2).
denominator = (grid_max - grid_min) / (g - 1)
sigma = denominator / np.sqrt(2)


def phi(x, i):
    """
    Evaluate the i-th Gaussian radial basis function at ``x``.

    Args:
        x: Evaluation point (scalar or NumPy array).
        i: Index into ``grid`` selecting the centre; an integer array of
            indices selects several centres at once.

    Returns:
        exp(-(x - grid[i])**2 / (2 * sigma**2)), broadcast over ``x`` and ``i``.
    """
    offset = x - grid[i]
    return np.exp(-offset**2 / (2 * sigma**2))

def kantk(x, y):
    """
    Compute the 1D tangent kernel k(x, y) of a KAN built on the RBF basis ``phi``.

    The kernel formula is evaluated with vectorised NumPy operations. The
    rank-3 intermediate tensor
        X[p, s, l] = T[l] * (sqrt_det * G[s, p]
                             + (t * grid[l])**2 * sqrt_det**2 * c[p] * c[s])
    is only ever contracted against phi_x[s] * phi_y[p]. That contraction is
    done in closed form below, so no g x g x g array is built.

    NOTE(review): the derivation of the formula is not part of this file; the
    expression is reproduced term by term, not re-derived.

    Args:
        x: First input. A scalar, or an array whose first (flattened) element
            is used.
        y: Second input, same conventions as ``x``.

    Returns:
        numpy.float64: the kernel value k(x, y). The expression is symmetric
        under swapping ``x`` and ``y``.
    """
    # Reduce array inputs to a single scalar (scalars pass through unchanged).
    x = np.ravel(x)[0]
    y = np.ravel(y)[0]

    t = -1 / (2 * sigma**2)

    # Basis function values at every centre: phi_x[i] = phi(x, i).
    centres = np.arange(g)
    phi_x = phi(x, centres)
    phi_y = phi(y, centres)

    b = -2 * (phi_x + phi_y)
    A = np.outer(phi_x, phi_x) + np.outer(phi_y, phi_y)
    G = np.linalg.inv(np.eye(g) - 2 * t * A)
    sqrt_det = np.sqrt(np.linalg.det(G))  # sqrt(det(G)), not det(G)

    exponent = 0.5 * (b @ G @ b)
    T = np.exp(exponent * (t * grid) ** 2)
    Z = T * sqrt_det
    c = G @ b

    # Y[s, l] = t * grid[l]**2 * sqrt_det * c[s] * T[l]
    Y = t * sqrt_det * np.outer(c, grid**2 * T)

    # sum_{s,p} phi_x[s] * phi_y[p] * X[p, s, l], contracted analytically.
    cross = T * (
        sqrt_det * (phi_x @ G @ phi_y)
        + (t * grid) ** 2 * sqrt_det**2 * (phi_x @ c) * (phi_y @ c)
    )

    # Kronecker-delta term: each of the g diagonal pairs (s == p) contributes
    # grid[l]**2 * Z[l] + b[s] * Y[s, l].
    diagonal = g * grid**2 * Z + b @ Y

    # Per-centre weights, then sum over the centres l.
    alpha = phi_x * phi_y * np.exp(-grid**2 / sigma**2) / sigma**4
    return alpha @ (cross + diagonal)

# ===========================================
# Kernel Matrix and Data Preparation
# ===========================================

# Number of training points at which the kernel is evaluated
n = 30

# Training inputs: n evenly spaced points in [-1, 1]
x_arr = np.linspace(-1, 1, n)

# Kernel (Gram) matrix H[i, j] = kantk(x_i, x_j). kantk is symmetric in its
# two arguments, so only the upper triangle is evaluated and then mirrored.
# This halves the number of kernel calls and makes H exactly symmetric
# (independently rounded entries would otherwise differ in the last bits).
print("Computing kernel matrix...")
H = np.zeros((n, n))
for i in range(n):
    for j in range(i, n):
        H[i, j] = H[j, i] = kantk(x_arr[i], x_arr[j])

# H is real symmetric, so use eigh. It returns real eigenvalues in ascending
# order with orthonormal real eigenvectors. np.linalg.eig can instead return
# spurious complex pairs and non-orthogonal vectors for such a matrix.
eigenvalues, eigenvectors = np.linalg.eigh(H)

# Reorder so that index 0 holds the largest eigenvalue (fancy indexing copies).
idx = np.argsort(eigenvalues)[::-1]
eigenvalues = eigenvalues[idx]
eigenvectors = eigenvectors[:, idx]
print(f"Kernel eigenvalues: largest {eigenvalues[0]:.3e}, "
      f"smallest {eigenvalues[-1]:.3e}")

# Target label vectors derived from the kernel spectrum, each normalised to
# unit Euclidean norm.
# 1. Best case: eigenvector of the largest eigenvalue
y_best = eigenvectors[:, 0]
y_best = y_best / np.linalg.norm(y_best)

# 2. Random case: Gaussian vector, seeded so runs are reproducible
rng = np.random.default_rng(0)
y_rand = rng.standard_normal(n)
y_rand = y_rand / np.linalg.norm(y_rand)

# 3. Worst case: eigenvector of the smallest eigenvalue
y_worst = eigenvectors[:, -1]
y_worst = y_worst / np.linalg.norm(y_worst)

# ===========================================
# Model Definition and Training
# ===========================================

# Train on the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(f"Using device: {device}")

class Net(nn.Module):
    """
    Two-layer Kolmogorov-Arnold Network (KAN): d inputs -> m hidden units -> 1 output.

    Both layers are FastKANLayer instances with layer norm and the base update
    disabled. The second layer's output is multiplied by 1/sqrt(m).

    Args:
        m: Number of hidden units (width of the network).
        d: Input dimension.
        grid_min: Lower end of the RBF grid used by both layers.
        grid_max: Upper end of the RBF grid used by both layers.
        num_grids: Number of RBF centres per input.

    The grid defaults (-2.0, 2.0, 8) equal the grid_min / grid_max / g
    constants used for the analytic kernel in this script. Keep them in sync,
    or the kernel no longer describes this model.
    """

    def __init__(self, m, d, grid_min=-2.0, grid_max=2.0, num_grids=8):
        super().__init__()
        self.m = m
        # 1/sqrt(m) output scaling, computed once rather than on every forward pass.
        self.output_scale = float(1 / np.sqrt(m))
        # Settings shared by both layers; the RBF grid is passed explicitly so
        # the model and the analytic kernel are built on the same centres.
        layer_kwargs = dict(
            grid_min=grid_min,
            grid_max=grid_max,
            num_grids=num_grids,
            use_layernorm=False,
            use_base_update=False,
        )
        # First KAN layer: input to hidden
        self.kan1 = FastKANLayer(d, m, spline_weight_init_scale=1.0, **layer_kwargs)
        # Second KAN layer: hidden to output
        self.kan2 = FastKANLayer(m, 1, spline_weight_init_scale=0.05, **layer_kwargs)

    def forward(self, x):
        """Apply kan1, then kan2, and scale the result by 1/sqrt(m)."""
        hidden = self.kan1(x)
        return self.output_scale * self.kan2(hidden)

# ===========================================
# Training and Visualization
# ===========================================

# Target label vectors from the kernel-spectrum section and how to draw them
labels = [y_best, y_rand, y_worst]        # Different target label vectors
labels_str = ['best', 'random', 'worst']  # Legend entries
colors = ['red', 'green', 'cyan']         # Colour of each training curve

# Model and training hyperparameters
m = 5000        # Width of the hidden layer
epochs = 3000   # Total number of training epochs
lr = 1e-3       # SGD step size. Passed explicitly: older PyTorch releases
                # require it, and recent ones default to 1e-3.
init_seed = 0   # torch is re-seeded with this before every model is built, so
                # each run draws the same initial weights and only the labels differ.

# Inputs as a float32 tensor on the training device, plus the (n, 1) view fed to the model
X = torch.tensor(x_arr, dtype=torch.float32, device=device)
X_in = X.reshape(n, 1)

criterion = nn.MSELoss()  # Mean Squared Error loss

# Set up the plot
plt.figure(dpi=120)
plt.xlabel('Epochs')
plt.ylabel('Log10(Normalized Training Loss)')
plt.title('Convergence Rates of KAN with Different Label Initializations')
plt.grid(True, alpha=0.3)

# Train one model per label vector
for name, label, color in zip(labels_str, labels, colors):
    y = torch.tensor(label, dtype=torch.float32, device=device).reshape(n)

    # Identical initialisation for every run (see init_seed)
    torch.manual_seed(init_seed)
    model = Net(m, 1).to(device)

    # Freeze the second layer so only kan1 is trained, and hand the optimizer
    # only the parameters that will actually receive gradients.
    for param in model.kan2.parameters():
        param.requires_grad = False
    trainable = [p for p in model.parameters() if p.requires_grad]
    optimizer = optim.SGD(trainable, lr=lr)  # Full-batch gradient descent

    loss_hist = np.zeros(epochs)  # Loss recorded at every epoch

    print("Number of samples is", n)
    print(f"Start training with {name} initialization")

    model.train()
    for epoch in range(epochs):
        optimizer.zero_grad()
        output = model(X_in).reshape(n)
        # nn.MSELoss is called as (input, target): prediction first, labels second.
        loss = criterion(output, y)
        loss.backward()
        optimizer.step()

        loss_value = loss.item()
        loss_hist[epoch] = loss_value
        if epoch % 200 == 0:
            print(f"loss at epoch {epoch}:", loss_value)

    # Loss relative to the first epoch, on a log10 scale (epsilon avoids log10(0))
    normalized_loss = loss_hist / loss_hist[0]
    plt.plot(
        np.arange(1, epochs + 1),
        np.log10(normalized_loss + 1e-10),
        label=name,
        color=color,
    )

# Add legend, save and display the plot
plt.legend()
plt.tight_layout()
plt.savefig('convergence_true_random.png')
plt.show()
