"""
Convergence Analysis of Kolmogorov-Arnold Networks (KAN) with Different Widths
----------------------------------------------------------------------------

This script analyzes the convergence behavior of Kolmogorov-Arnold Networks (KAN) 
with varying hidden layer widths. The code:
1. Trains multiple KAN models with different hidden layer sizes
2. Tracks training loss and parameter distances from initialization
3. Visualizes the relationship between network width and convergence rate

Key components:
- KAN model implementation with FastKANLayer
- Training loop with PyTorch
- Convergence and distance visualization
"""

# Imports (standard library and third-party packages)
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.optim as optim
import torch.nn as nn
import time
from fastkan import FastKANLayer

# Input dimensionality of the synthetic data.
d = 100

# Sample sizes n to sweep over; the hidden width m is the same for every run.
samples_list = [500, 750, 1000]
m = 1000

# Run on the GPU when one is available, otherwise on the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(device)

# Two-layer KAN model used in the sweep.
class Net(nn.Module):
    """Two stacked FastKANLayer layers mapping d -> m -> 1.

    Both layers are built with layer norm and base update disabled and a
    spline weight init scale of 1.0. The final output is multiplied by
    1/sqrt(m).
    """

    def __init__(self, m, d):
        super().__init__()
        self.m = m
        layer_kwargs = dict(
            use_layernorm=False,
            use_base_update=False,
            spline_weight_init_scale=1.0,
        )
        self.kan1 = FastKANLayer(d, m, **layer_kwargs)
        self.kan2 = FastKANLayer(m, 1, **layer_kwargs)

    def forward(self, x):
        hidden = self.kan1(x)
        scale = 1 / np.sqrt(self.m)
        return scale * self.kan2(hidden)


class ReLUNet(nn.Module):
    """Fully connected two-layer baseline: Linear(d -> m), ReLU, Linear(m -> 1).

    The final output is multiplied by 1/sqrt(m).
    """

    def __init__(self, m, d):
        super().__init__()
        self.m = m
        # Registration order (fc1, relu, fc2) fixes the parameters() order.
        self.fc1 = nn.Linear(d, m)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(m, 1)

    def forward(self, x):
        scale = 1 / np.sqrt(self.m)
        hidden = self.relu(self.fc1(x))
        return scale * self.fc2(hidden)
    
def calculate_inf_norm(params, params_0):
    """Return the largest absolute element-wise change between two tensor lists.

    Each input is flattened and concatenated in iteration order, so both must
    list their tensors in the same order. The result is the infinity norm
    ``max |w - w0|`` of the difference, returned as a Python float.

    The reduction runs on the tensors' own device. Only the final scalar is
    copied to the host, rather than cloning and transferring every parameter
    to NumPy on each call.

    Args:
        params: iterable of current tensors (e.g. ``model.parameters()``).
        params_0: iterable of reference tensors, e.g. clones taken at init.

    Returns:
        The maximum absolute difference. Returns 0.0 when either input is
        empty or the inputs contain no elements at all.

    Raises:
        ValueError: if the two inputs hold a different total number of elements.
    """
    with torch.no_grad():
        flat = [p.detach().reshape(-1) for p in params]
        flat_0 = [p0.detach().reshape(-1) for p0 in params_0]
        if not flat or not flat_0:
            return 0.0
        w = torch.cat(flat)
        # The reference copies may live on another device; align them with w.
        w0 = torch.cat(flat_0).to(w.device)
        if w.numel() != w0.numel():
            # Fail loudly instead of silently broadcasting (e.g. 1 vs N elements).
            raise ValueError(
                f"parameter size mismatch: {w.numel()} vs {w0.numel()} elements"
            )
        if w.numel() == 0:
            return 0.0
        return (w - w0).abs().max().item()

# NOTE(review): `models` is not referenced anywhere else in this file.
models = []
epochs = 5000  # number of full-batch training epochs per model
s_count = len(samples_list)

# One row per sample size, one column per epoch; filled in by the training sweep.
max_distance_kan, max_distance_relu = (np.zeros((s_count, epochs)) for _ in range(2))

# Trained networks are appended here so they can be inspected after the run.
models_kan, models_relu = [], []

# Plot: training loss curves for different sample sizes
plt.figure(dpi=120)
plt.xlabel('Epochs')
plt.ylabel('log(Training Errors)')
plt.title('Convergence Rates (varying n, m=%d)' % m)
plt.grid()

criterion = nn.MSELoss()


def _make_dataset(n_samples, dim):
    """Draw ``(X_np, y_np)``: unit-norm rows of shape (n_samples, dim) and standard-normal labels.

    The labels are drawn first and the inputs are drawn in row-major order.
    This consumes NumPy's global random stream in the same order as drawing
    one input row at a time.
    """
    y_np = np.random.randn(n_samples,)
    X_np = np.random.randn(n_samples, dim)
    X_np /= np.linalg.norm(X_np, axis=1, keepdims=True)
    return X_np, y_np


def _train_full_batch(model, optimizer, params_0, X, y, dist_row, tag, n_epochs, loss_fn):
    """Run ``n_epochs`` full-batch optimizer steps and return the per-epoch losses.

    After each step, the max-abs parameter distance from ``params_0`` is
    written in place into ``dist_row[epoch]``. The loss is printed every 100
    epochs, prefixed with ``tag``.
    """
    n = X.shape[0]
    losses = np.zeros((n_epochs,))
    model.train()  # the mode never changes inside the loop, so set it once
    for epoch in range(n_epochs):
        optimizer.zero_grad()
        output = model(X)
        loss = loss_fn(output.reshape(n,), y.reshape(n,))
        loss.backward()
        optimizer.step()

        losses[epoch] = loss.item()
        dist_row[epoch] = calculate_inf_norm(list(model.parameters()), params_0)

        if epoch % 100 == 0:
            print(f"{tag} loss at epoch {epoch}:", loss.item())
    return losses


epoch_axis = np.linspace(1, epochs, epochs)
for ind, n_samples in enumerate(samples_list):
    # generate dataset for this n
    X_np, y_np = _make_dataset(n_samples, d)
    X = torch.Tensor(X_np).to(device)
    y = torch.Tensor(y_np).to(device)

    # --- Train KAN model (with frozen second layer) ---
    model_kan = Net(m, d).to(device)
    params_0_kan = [p.clone().detach() for p in model_kan.parameters()]
    for p in model_kan.kan2.parameters():
        p.requires_grad = False
    trainable_params_kan = [p for p in model_kan.parameters() if p.requires_grad]
    optimizer_kan = optim.SGD(trainable_params_kan, lr=0.01)

    print("Number of samples is", n_samples)
    print(f"Start training KAN with n={n_samples}, m={m}")
    pytorch_total_params = sum(p.numel() for p in model_kan.parameters())
    print("Number of parameters (KAN) is ", pytorch_total_params)

    start_time = time.perf_counter()
    # max_distance_kan[ind] is a row view, so the helper's writes land in the 2-D array.
    loss_hist_kan = _train_full_batch(
        model_kan, optimizer_kan, params_0_kan, X, y,
        max_distance_kan[ind], "KAN", epochs, criterion,
    )
    # Stop the clock before plotting so the reported time covers training only.
    elapsed = time.perf_counter() - start_time
    plt.plot(epoch_axis, np.log10(loss_hist_kan), label=f'KAN n={n_samples}', linestyle='-')
    print(f"Training time for KAN n={n_samples}: {elapsed:.2f} seconds")
    models_kan.append(model_kan)

    # --- Train ReLU 2-layer network for comparison ---
    # NOTE(review): unlike the KAN, both ReLU layers are trained here.
    model_relu = ReLUNet(m, d).to(device)
    params_0_relu = [p.clone().detach() for p in model_relu.parameters()]
    optimizer_relu = optim.SGD(model_relu.parameters(), lr=0.01)

    print(f"Start training ReLU net with n={n_samples}, m={m}")
    pytorch_total_params_relu = sum(p.numel() for p in model_relu.parameters())
    print("Number of parameters (ReLU) is ", pytorch_total_params_relu)

    start_time = time.perf_counter()
    loss_hist_relu = _train_full_batch(
        model_relu, optimizer_relu, params_0_relu, X, y,
        max_distance_relu[ind], "ReLU", epochs, criterion,
    )
    elapsed = time.perf_counter() - start_time
    plt.plot(epoch_axis, np.log10(loss_hist_relu), label=f'ReLU n={n_samples}', linestyle='--')
    print(f"Training time for ReLU n={n_samples}: {elapsed:.2f} seconds")
    models_relu.append(model_relu)

# Add legend to the convergence plot
plt.legend()

# Second figure: per-epoch max-abs parameter distance from initialization,
# one solid (KAN) and one dashed (ReLU) curve per sample size.
plt.figure(dpi=120)

epoch_grid = np.linspace(1, epochs, epochs)
for kan_row, relu_row, n_val in zip(max_distance_kan, max_distance_relu, samples_list):
    plt.plot(epoch_grid, kan_row, label=f"KAN n={n_val}", linestyle='-')
    plt.plot(epoch_grid, relu_row, label=f"ReLU n={n_val}", linestyle='--')

# Axes, labels and legend for the distance figure.
plt.grid()
plt.title("Maximum Distances From Initialization (varying n)")
plt.xlabel("Epochs")
plt.ylabel("Maximum Distances")
plt.legend()

# Show both figures.
plt.show()
