"""
Convergence Rate Analysis of Kolmogorov-Arnold Networks (KAN)
------------------------------------------------------------

This script analyzes the convergence behavior of a Kolmogorov-Arnold Network (KAN) model
by comparing the theoretical convergence rate with empirical training results. The code:
1. Implements the tangent kernel of a KAN built on Radial Basis Function (RBF) bases
2. Computes the kernel matrix and its eigenvalues
3. Defines and trains a simple KAN model
4. Compares theoretical and empirical convergence rates

Key components:
- RBF tangent kernel implementation
- Training loop with PyTorch
- Convergence rate visualization
"""

# Numerical and plotting imports (third-party)
import numpy as np
import matplotlib.pyplot as plt

# Deep learning imports
import torch
import torch.optim as optim
import torch.nn as nn
from fastkan import FastKANLayer  # Custom KAN implementation

# ===========================================
# Tangent Kernel Parameters and Basis Functions
# ===========================================

# RBF centers: g equally spaced points covering [grid_min, grid_max]
grid_min, grid_max = -2.0, 2.0  # Lower and upper end of the center grid
g = 8                           # Number of grid points (centers)
grid = np.linspace(grid_min, grid_max, num=g)

# The RBF width is derived from the spacing between neighboring centers
denominator = (grid_max - grid_min) / (g - 1)  # Center-to-center spacing
sigma = denominator / np.sqrt(2)               # Standard deviation for RBF

def phi(x, i):
    """
    Evaluate a single Gaussian radial basis function.

    The center is taken from the module-level ``grid`` and the width from the
    module-level ``sigma``.

    Args:
        x: Point at which the basis function is evaluated
        i: Index into ``grid`` selecting the center

    Returns:
        exp(-(x - grid[i])**2 / (2 * sigma**2))
    """
    offset = x - grid[i]
    return np.exp(-offset**2 / (2 * sigma**2))

def kantk(x, y):
    """
    Evaluate the 1D KAN tangent kernel k(x, y) built on the RBF basis.

    Uses the module-level ``grid``, ``sigma``, ``g`` and ``phi``. The
    expression treats ``x`` and ``y`` symmetrically (they only enter through
    phi_x + phi_y, the sum of their outer products, and products
    phi_x[s] * phi_y[p] weighted by a tensor symmetric in s and p).

    Args:
        x: First input; a scalar, or a NumPy array whose first element is used
        y: Second input; a scalar, or a NumPy array whose first element is used

    Returns:
        Scalar kernel value k(x, y) (the single element of a length-1 array)
    """
    # Arrays are reduced to their first element; remaining elements are ignored.
    # NOTE(review): a 0-d array would fail at x[0] -- presumably callers pass
    # scalars or 1-D arrays; confirm.
    if isinstance(x, np.ndarray):
        x = x[0]
    if isinstance(y, np.ndarray):
        y = y[0]

    # t = -1 / (2 * sigma^2): the coefficient of the squared distance in phi
    t = -1 / (2 * sigma**2)

    # All g basis functions evaluated at x and at y (length-g vectors)
    phi_x = np.array([phi(x, i) for i in range(g)])
    phi_y = np.array([phi(y, i) for i in range(g)])

    # b is a (g, 1) column vector; A is the (g, g) sum of the two outer products
    b = -2 * (phi_x + phi_y).reshape(-1, 1)
    A = np.outer(phi_x, phi_x) + np.outer(phi_y, phi_y)

    # G = (I - 2 t A)^{-1}. Note that `det` holds sqrt(det(G)), not det(G) itself.
    G = np.linalg.inv(np.eye(g) - 2 * t * A)
    det = np.sqrt(np.linalg.det(G))

    # exponent is the scalar 0.5 * b^T G b; T and Z have one entry per grid point
    exponent = 0.5 * (b.T @ G @ b)[0, 0]
    T = np.exp(exponent * (t * grid)**2)
    Z = T * det

    # c = G b, a (g, 1) column vector
    c = G @ b

    # Y is indexed [basis index, grid index]; X is indexed [basis, basis, grid]
    Y = np.zeros((g, g))
    X = np.zeros((g, g, g))

    # Fill Y and X entry by entry. c[i] is a length-1 row of the column vector
    # c, so each right-hand side is a length-1 array and [0] extracts its scalar.
    for i in range(g):
        for j in range(g):
            Y[i, j] = (t * grid[j]**2 * det * c[i] * T[j])[0]
            for k in range(g):
                X[i, j, k] = (T[k] * (det * G[j, i] + (t * grid[k])**2 * det**2 * c[i] * c[j]))[0]

    # Accumulate the kernel value as a weighted sum over grid points l
    output = 0.0
    for l in range(g):
        # Weight of grid point l
        alpha = (phi_x[l] * phi_y[l] * np.exp(-grid[l]**2 / sigma**2)) / sigma**4

        # Double sum over basis indices s and p
        sum_term = 0.0
        for s in range(g):
            for p in range(g):
                # (s == p) acts as a Kronecker delta on the second term.
                # b[s] is a length-1 array, so sum_term becomes a length-1 array.
                sum_term += (phi_x[s] * phi_y[p] * X[p, s, l] + 
                            (grid[l]**2 * Z[l] + b[s] * Y[s, l]) * (s == p))
        output += sum_term * alpha

    # output is a length-1 array at this point (see b[s] above); unwrap it
    return output[0]

# ===========================================
# Tangent Kernel Matrix Computation
# ===========================================

# Number of sample points for kernel evaluation
n = 10

# Create input points in [-1, 1]
x_arr = np.linspace(-1, 1, n)
x_copy = x_arr.copy()

# Initialize and compute the kernel matrix H, with H[i, j] = k(x_i, x_j)
H = np.zeros((n, n))
for i, x_1 in enumerate(x_arr):
    for j, x_2 in enumerate(x_copy):
        H[i, j] = kantk(x_1, x_2)

# kantk treats its two arguments symmetrically, so H is symmetric up to
# round-off. The general solver np.linalg.eig may return complex or unordered
# eigenvalues for such a matrix, which would corrupt lambda_0. Symmetrizing
# and using the symmetric solver guarantees real eigenvalues in ascending
# order (eigenvectors are the matching columns).
eigenvalues, eigenvectors = np.linalg.eigh((H + H.T) / 2)
lambda_0 = np.min(eigenvalues)  # smallest eigenvalue

# ===========================================
# Dataset and Training Setup
# ===========================================

# Regression targets on x_arr: a Gaussian bump plus a quadratic
y_arr = np.exp(-x_arr**2) + x_arr**2

# Training hyperparameters
lr = 1e-3                 # Learning rate
epochs = 1000             # Number of training epochs
m = 5000                  # Hidden layer width
criterion = nn.MSELoss()  # Mean Squared Error loss

# Theoretical trajectory in log10 units: the predicted normalized loss after
# k steps is (1 - lr * lambda_0 / 2)**k, so its log10 is linear in k.
pred_loss = np.zeros((epochs,))
pred_loss[:] = np.arange(epochs) * np.log10(1 - lr * lambda_0 / 2)

# Network used to measure the empirical training loss

class Net(nn.Module):
    """
    Two-layer Kolmogorov-Arnold Network whose output is scaled by 1/sqrt(m).

    Args:
        m: Number of hidden units
        d: Input dimension
    """
    def __init__(self, m, d):
        super().__init__()
        self.m = m
        # Both layers are built with the same FastKAN options
        layer_options = dict(use_layernorm=False,
                             use_base_update=False,
                             spline_weight_init_scale=1.0)
        self.kan1 = FastKANLayer(d, m, **layer_options)  # input -> hidden
        self.kan2 = FastKANLayer(m, 1, **layer_options)  # hidden -> output

    def forward(self, x):
        """Apply both KAN layers and scale the result by 1/sqrt(m)."""
        hidden = self.kan1(x)
        scale = 1 / np.sqrt(self.m)
        return scale * self.kan2(hidden)

# ===========================================
# Model Initialization
# ===========================================

# Run on the GPU when CUDA is available, otherwise on the CPU
use_cuda = torch.cuda.is_available()
device = torch.device('cuda' if use_cuda else 'cpu')
print(f"Using device: {device}")

# Inputs and targets as PyTorch tensors on the chosen device
X, y = (torch.Tensor(arr).to(device) for arr in (x_arr, y_arr))

# Build the network (scalar input, m hidden units) on the same device
model = Net(m, 1).to(device)

# Only the first layer is trained: exclude the second layer from gradients
for frozen in model.kan2.parameters():
    frozen.requires_grad_(False)

# Optimizer, per-epoch loss log, and a placeholder for the first-epoch loss
optimizer = optim.SGD(model.parameters(), lr=lr)
loss_hist = np.zeros(epochs)
init_loss = 0  # Overwritten with the loss of the first epoch

# ===========================================
# Training Loop
# ===========================================

print("Starting training...")

# Training mode is set once: nothing inside the loop switches the mode
model.train()

for epoch in range(epochs):
    # Clear gradients left over from the previous step
    optimizer.zero_grad()

    # Full-batch forward pass; the model expects inputs of shape (n, 1)
    output = model(X.reshape(n, 1))

    # nn.MSELoss is called as (input, target): prediction first, label second
    loss = criterion(output.reshape(n,), y.reshape(n,))

    # Backward pass and parameter update
    loss.backward()
    optimizer.step()

    # Read the scalar loss once per epoch
    loss_value = loss.item()

    # The first epoch's loss is the normalization constant
    if epoch == 0:
        init_loss = loss_value

    # Log log10 of the loss relative to the first epoch (0 at epoch 0)
    loss_hist[epoch] = np.log10(loss_value / init_loss)

    # Report on the first epoch and then every 50 epochs; the printed value
    # is the log10-normalized loss stored above, not the raw loss
    if (epoch + 1) % 50 == 0 or epoch == 0:
        print(f"Epoch [{epoch+1}/{epochs}], Loss: {loss_hist[epoch]:.6f}")

# ===========================================
# Plotting Results
# ===========================================

plt.figure(dpi=120, figsize=(10, 6))

# Plot theoretical and empirical convergence against 1-based epoch numbers
epoch_range = np.linspace(1, epochs, epochs)
plt.plot(epoch_range, pred_loss,
         label='Theoretical Convergence',
         color='red',
         linestyle='--',
         linewidth=2)

plt.plot(epoch_range, loss_hist,
         label='Empirical Loss',
         color='blue',
         linewidth=2)

# Customize plot
plt.xlabel('Epochs', fontsize=12)
plt.ylabel('Log10(Normalized Loss)', fontsize=12)
plt.title('Theoretical vs. Empirical Convergence Rates\nLearning Rate = ' + str(lr), fontsize=12)
plt.grid(True, linestyle='--', alpha=0.7)
plt.legend(fontsize=10)
plt.tight_layout()

# Save the figure; the file name embeds the learning rate, so the path is
# built once and reused in the message below to keep the two in sync
plot_path = f'convergence_plot_{lr}.png'
plt.savefig(plot_path, dpi=300, bbox_inches='tight')
plt.show()

print(f"Training complete. Convergence plot saved as '{plot_path}'")

# Summary of the run. loss_hist and pred_loss hold log10 of the loss
# normalized by its initial value, so the two "final loss" figures are in
# log10 units, while init_loss is the raw first-epoch loss.
final_loss = loss_hist[-1]
theoretical_final = pred_loss[-1]
summary_lines = [
    "\nFinal Results:",
    f"- Final empirical loss: {final_loss:.6f}",
    f"- Final theoretical loss: {theoretical_final:.6f}",
    f"- Initial loss: {init_loss:.6f}",
    f"- Learning rate: {lr}",
    f"- Hidden units: {m}",
    f"- Number of samples: {n}",
]
print("\n".join(summary_lines))
