"""
Convergence Analysis of Kolmogorov-Arnold Networks (KAN) with Different Widths
----------------------------------------------------------------------------

This script analyzes the convergence behavior of Kolmogorov-Arnold Networks (KAN) 
with varying hidden layer widths. The code:
1. Trains multiple KAN models with different hidden layer sizes
2. Tracks training loss and parameter distances from initialization
3. Visualizes the relationship between network width and convergence rate

Key components:
- KAN model implementation with FastKANLayer
- Training loop with PyTorch
- Convergence and distance visualization
"""

# Third-party imports
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.optim as optim
import torch.nn as nn
from fastkan import FastKANLayer

# Synthetic regression data: n points on the unit sphere in R^d, each paired
# with an independent standard-normal target.
n = 1000  # number of samples
d = 1000  # number of dimensions
y = np.random.randn(n,)
# Draw all rows in one call and normalize row-wise rather than looping over
# samples in Python; every row of X ends up with unit Euclidean norm.
X = np.random.randn(n, d)
X /= np.linalg.norm(X, axis=1, keepdims=True)

# Select the compute device: CUDA when it is available, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(device)

# defining model
class Net(nn.Module):
    """Two-layer network assembled from FastKANLayer blocks.

    The input goes through ``kan1`` (d -> m) and then ``kan2`` (m -> 1); the
    result is multiplied by ``1 / sqrt(m)``.
    """

    def __init__(self, m, d):
        """Build both layers.

        Args:
            m: Hidden width (output size of ``kan1``, input size of ``kan2``).
            d: Input dimension.
        """
        super().__init__()
        self.m = m
        # Both layers are configured with the same options.
        layer_options = dict(
            use_layernorm=False,
            use_base_update=False,
            spline_weight_init_scale=1.0,
        )
        self.kan1 = FastKANLayer(d, m, **layer_options)
        self.kan2 = FastKANLayer(m, 1, **layer_options)

    def forward(self, x):
        """Return ``kan2(kan1(x))`` scaled by ``1 / sqrt(m)``."""
        hidden = self.kan1(x)
        scale = 1 / np.sqrt(self.m)
        return scale * self.kan2(hidden)
    
def calculate_inf_norm(params, params_0):
    """Return the largest absolute entry-wise gap between two parameter sets.

    Only tensors whose ``requires_grad`` flag is set are considered, on both
    sides. The trainable tensors of ``params`` are matched, in order, with the
    trainable tensors of ``params_0``, and the maximum of ``|p - p0|`` over
    every entry of every matched pair is returned.

    Args:
        params: Iterable of current parameter tensors.
        params_0: Iterable of reference tensors (for example the values at
            initialization), in the same order as ``params``.

    Returns:
        The infinity norm of the difference, as a Python float.

    Raises:
        ValueError: If there is no trainable tensor to compare, or the two
            inputs contain a different number of trainable tensors.
    """
    current = [p for p in params if p.requires_grad]
    reference = [p0 for p0 in params_0 if p0.requires_grad]
    if not current or len(current) != len(reference):
        raise ValueError(
            "expected the same non-zero number of trainable tensors, got "
            f"{len(current)} and {len(reference)}"
        )
    with torch.no_grad():
        # Reduce on the tensors' own device and move only one scalar per pair
        # to the host, instead of copying whole weight matrices to NumPy.
        return max(
            (p.reshape(-1) - p0.reshape(-1).to(p.device)).abs().max().item()
            for p, p0 in zip(current, reference)
        )
    
# Experiment configuration.
widths = [500, 1000, 2000, 4000, 8000]  # hidden-layer widths to compare
epochs = 500  # number of training epochs
models = []  # collects one model per width
# One row per width, one column per epoch.
max_distance = np.zeros((len(widths), epochs))

# Convert the NumPy dataset to single-precision tensors on the chosen device.
X = torch.from_numpy(X).float().to(device)
y = torch.from_numpy(y).float().to(device)

# Train one network per width with full-batch SGD. For each width, record the
# training loss and the distance of the trainable parameters from their
# initial values, and draw the loss curve on the convergence figure.
plt.figure(dpi=120)
plt.xlabel('Epochs')
plt.ylabel('log(Training Errors)')
plt.title('Convergence Rates')
plt.grid()
epoch_axis = np.linspace(1, epochs, epochs)
for ind, m in enumerate(widths):
    # defining model and parameters
    model = Net(m, d).to(device)

    # Set second layer weights constant: freeze the parameters of the model
    # itself, so that only the first layer is updated during training.
    for p in model.kan2.parameters():
        p.requires_grad_(False)

    # Snapshot every parameter at initialization. The copies are detached so
    # they never change, and they carry the same requires_grad flag as the
    # live parameter because calculate_inf_norm selects tensors by that flag
    # on both of its arguments.
    params_0 = [
        p.detach().clone().requires_grad_(p.requires_grad)
        for p in model.parameters()
    ]
    # SGD updates parameters in place, so this list stays valid for all epochs.
    params = list(model.parameters())

    # Define the optimization objects
    criterion = nn.MSELoss()
    loss_hist = np.zeros((epochs,))
    # The learning rate is spelled out (1e-3, the documented SGD default) so
    # the step size does not depend on the installed PyTorch version.
    optimizer = optim.SGD(
        [p for p in params if p.requires_grad],
        lr=1e-3,
    )

    # Print dataset size and model width
    print("Number of samples is", n)
    print(f"Start training with m={m}")

    # Calculate and print the total number of parameters (frozen ones included)
    pytorch_total_params = sum(p.numel() for p in params)
    print("Number of parameters is ", pytorch_total_params)

    # Training loop: one full-batch gradient step per epoch
    model.train()
    for epoch in range(epochs):
        optimizer.zero_grad()

        # Forward pass and loss between predictions and targets
        output = model(X)
        loss = criterion(output.reshape(n,), y.reshape(n,))

        # Backward pass and optimization
        loss.backward()
        optimizer.step()

        # Store loss for this epoch
        loss_hist[epoch] = loss.item()

        # Maximum parameter distance from initialization after this step
        max_distance[ind, epoch] = calculate_inf_norm(params, params_0)

        # Print progress every 100 epochs
        if epoch % 100 == 0:
            print(f"loss at epoch {epoch}:", loss.item())

    # Plot the training curve for this model
    plt.plot(epoch_axis, np.log10(loss_hist), label=f'm={m}')
    print("\n")
    models.append(model)

# The convergence figure is still the current figure here: attach its legend
# before opening a new one.
plt.legend()

# Second figure: maximum distance from initialization, one curve per width.
plt.figure(dpi=120)
epoch_axis = np.linspace(1, epochs, epochs)
for width, distances in zip(widths, max_distance):
    plt.plot(epoch_axis, distances, label=f"m={width}")

# Titles, axis labels, grid and legend for the distance figure
plt.title("Maximum Distances From Initialization")
plt.xlabel("Epochs")
plt.ylabel("Maximum Distances")
plt.grid()
plt.legend()

# Display the plots
plt.show()
