"""
Comparative Analysis of First-Layer vs Second-Layer Training in Kolmogorov-Arnold Networks (KAN)
----------------------------------------------------------------------------------------------

This script compares two layer-specific training strategies for Kolmogorov-Arnold Networks (KAN):
1. First-layer only training (only the first layer's parameters are updated)
2. Second-layer only training (only the second layer's parameters are updated)

The implementation includes:
1. Training KAN models with both layer-specific strategies
2. Tracking and comparing convergence metrics between the approaches
3. Analyzing the impact of training different layers on model performance
4. Visualizing the differences in learning dynamics between the layers

Key components:
- KAN implementation with FastKANLayer
- Selective parameter freezing for layer-specific training
- Comparative performance analysis between layers
- Training visualization utilities
"""

# Third-party imports
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.optim as optim
import torch.nn as nn
from fastkan import FastKANLayer

n = 100  # number of samples
d = 100  # number of dimensions (features per sample)

# Regression targets: i.i.d. standard-normal values. The targets are drawn
# first and the inputs second, so the global NumPy random stream is consumed
# in a fixed, documented order.
y = np.random.randn(n)

# Inputs: n standard-normal vectors in R^d, each rescaled to unit Euclidean
# norm. The whole matrix is sampled and normalised in one vectorised step
# instead of looping over rows in Python.
X = np.random.randn(n, d)
X /= np.linalg.norm(X, axis=1, keepdims=True)

# defining device: use the GPU when CUDA is available, otherwise the CPU
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)

# defining model
class Net(nn.Module):
    """Two stacked FastKANLayer modules with a 1/sqrt(m) output scaling.

    Args:
        m: Size passed as the second argument of the first layer and the
            first argument of the second layer (presumably the hidden
            width; confirm against the FastKANLayer signature).
        d: Size passed as the first argument of the first layer
            (presumably the input dimension).
    """

    def __init__(self, m, d):
        super().__init__()
        self.m = m

        # Both layers are built with the same optional settings.
        layer_options = dict(
            use_layernorm=False,
            use_base_update=False,
            spline_weight_init_scale=1.0,
        )
        self.kan1 = FastKANLayer(d, m, **layer_options)
        self.kan2 = FastKANLayer(m, 1, **layer_options)

    def forward(self, x):
        """Apply kan1, then kan2, and scale the result by 1/sqrt(m)."""
        hidden = self.kan1(x)
        scale = 1 / np.sqrt(self.m)
        return scale * self.kan2(hidden)
    
def calculate_inf_norm(params, params_0):
    """Return the largest absolute entry-wise change between two parameter sets.

    Every tensor with ``requires_grad`` set is flattened, the flattened
    tensors of each argument are concatenated in iteration order, and the
    infinity norm ``max |w - w0|`` of the difference is returned. All
    trainable tensors contribute to the result, not just the last one.

    Args:
        params: Iterable of tensors holding the current values. Tensors
            with ``requires_grad == False`` are skipped.
        params_0: Iterable of tensors holding the reference (e.g. initial)
            values, in the same order as ``params``. Tensors with
            ``requires_grad == False`` are skipped here as well, so the
            reference tensors must still carry ``requires_grad=True``.

    Returns:
        The maximum absolute difference as a NumPy scalar, or ``0.0`` when
        neither argument contains a trainable tensor.

    Raises:
        ValueError: If the two arguments contain a different number of
            trainable tensors (raised explicitly), or if their flattened
            sizes do not match (raised by NumPy on subtraction).
    """
    def _flatten_trainable(tensors):
        # One flat NumPy vector per trainable tensor; detach() drops the
        # autograd history so the conversion to NumPy is always allowed.
        return [
            t.detach().cpu().numpy().reshape(-1)
            for t in tensors
            if t.requires_grad
        ]

    current = _flatten_trainable(params)
    reference = _flatten_trainable(params_0)

    if len(current) != len(reference):
        raise ValueError(
            "params and params_0 contain a different number of trainable "
            f"tensors ({len(current)} vs {len(reference)})"
        )
    if not current:
        # Nothing is trainable, so nothing can have moved.
        return 0.0

    w = np.concatenate(current)
    w0 = np.concatenate(reference)
    return np.max(np.abs(w - w0))
    
# Convert the dataset to float32 tensors and move it to the selected device.
X = torch.tensor(X, dtype=torch.float32).to(device)
y = torch.tensor(y, dtype=torch.float32).to(device)

# Experiment configuration.
epochs = 2000  # number of training epochs
widths = [500, 1000, 2000, 4000, 8000]  # widths

# Per-width loss histories, one dictionary per training strategy.
first_layer_results = dict()
second_layer_results = dict()

def train_model(m, train_first_layer, *, lr=1e-3):
    """Train a freshly built Net with exactly one of its two layers trainable.

    A new ``Net(m, d)`` is created on ``device`` and trained with full-batch
    SGD on the module-level data ``X`` / ``y`` for ``epochs`` steps, using
    mean-squared error as the loss.

    Args:
        m: Width passed to ``Net``.
        train_first_layer: If true, parameters whose name contains 'kan1'
            are trainable and those containing 'kan2' are frozen; if false,
            the roles are swapped.
        lr: SGD learning rate. It is passed to the optimizer explicitly so
            the run does not depend on the library's default value.

    Returns:
        NumPy array of length ``epochs`` holding the loss measured at each
        epoch (computed before that epoch's parameter update).
    """
    model = Net(m, d).to(device)

    # Select the trainable layer by parameter name.
    # NOTE(review): this sets requires_grad on every parameter of the chosen
    # layer, including any the layer itself may have created as frozen;
    # confirm that is intended.
    for name, param in model.named_parameters():
        if 'kan1' in name:  # First layer parameters
            param.requires_grad = train_first_layer
        elif 'kan2' in name:  # Second layer parameters
            param.requires_grad = not train_first_layer

    trainable = [p for p in model.parameters() if p.requires_grad]

    criterion = nn.MSELoss()
    optimizer = optim.SGD(trainable, lr=lr)

    layer_type = 'First' if train_first_layer else 'Second'
    print(f"\nTraining with m={m}, {layer_type} layer trainable")
    print(f"Number of trainable parameters: {sum(p.numel() for p in trainable)}")

    loss_hist = np.zeros(epochs)

    # The mode never changes during training, so set it once up front.
    model.train()

    for epoch in range(epochs):
        optimizer.zero_grad()
        output = model(X)
        # Drop only the trailing size-1 output dimension so predictions line
        # up with y even when the batch holds a single sample.
        loss = criterion(output.squeeze(-1), y)
        loss.backward()
        optimizer.step()

        loss_value = loss.item()
        loss_hist[epoch] = loss_value

        if epoch % 500 == 0 or epoch == epochs - 1:
            print(f"Epoch {epoch}: Loss = {loss_value:.4f}")

    return loss_hist

# Sweep over the widths. For each width, train_model builds and trains a new
# model twice: first with only the first layer trainable, then with only the
# second layer trainable. Each loss history is stored under its width.
for m in widths:
    first_layer_results[m] = train_model(m, train_first_layer=True)
    second_layer_results[m] = train_model(m, train_first_layer=False)

# First plot: Training Loss
# Both curves of a given width share one colour (dashed = first-layer
# training, solid = second-layer training), and every legend entry names its
# strategy, so each line can be identified unambiguously.
plt.figure(dpi=120)
first_handles = []
second_handles = []
for idx, m in enumerate(widths):
    color = f'C{idx % 10}'  # entry of the default colour cycle
    (first_line,) = plt.plot(
        np.log10(first_layer_results[m]), '--', color=color,
        label=f'First Layer (m={m})',
    )
    (second_line,) = plt.plot(
        np.log10(second_layer_results[m]), '-', color=color,
        label=f'Second Layer (m={m})',
    )
    first_handles.append(first_line)
    second_handles.append(second_line)

plt.title('Training Loss: First Layer vs Second Layer Training')
plt.xlabel('Epochs')
plt.ylabel('log10(Loss)')
plt.grid(True)
# Legend entries fill column by column, so listing all first-layer handles
# before the second-layer ones gives each strategy its own column.
plt.legend(handles=first_handles + second_handles, ncol=2)
plt.tight_layout()
plt.show()

# Second plot: Final Loss vs Width
# One point per width and strategy: log10 of the last recorded training loss.
widths_array = np.array(widths)
first_final = [np.log10(first_layer_results[w][-1]) for w in widths]
second_final = [np.log10(second_layer_results[w][-1]) for w in widths]

fig, ax = plt.subplots(dpi=120)
ax.plot(widths_array, first_final, 'o--', label='First Layer Training')
ax.plot(widths_array, second_final, 's-', label='Second Layer Training')
ax.set_xscale('log')
ax.set_xlabel('Width (m)')
ax.set_ylabel('Final log10(Loss)')
ax.set_title('Final Loss vs Network Width')
ax.grid(True, which='both', linestyle='--')
ax.legend()
fig.tight_layout()
plt.show()
