"""
Comparative Analysis of Full vs First-Layer Training in Kolmogorov-Arnold Networks (KAN)
--------------------------------------------------------------------------------------

This script compares two training strategies for Kolmogorov-Arnold Networks (KAN):
1. Full network training (all layers updated)
2. First-layer only training (only the first layer's parameters are updated)

The implementation includes:
1. Training KAN models with both strategies
2. Tracking and comparing convergence metrics
3. Analyzing the impact of training strategy on model performance
4. Visualizing the differences in learning dynamics

Key components:
- KAN implementation with FastKANLayer
- Selective parameter freezing for first-layer training
- Comparative performance analysis
- Training visualization utilities
"""

# Third-party imports
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.optim as optim
import torch.nn as nn
from fastkan import FastKANLayer

# Synthetic regression problem: unit-norm Gaussian directions as inputs and
# independent standard-normal targets.
n = 100  # number of samples
d = 100  # number of dimensions
X = np.empty((n, d))  # every row is overwritten below
y = np.random.randn(n)
for i in range(n):
    x = np.random.randn(d)
    # Scale the draw onto the unit sphere before storing it as row i.
    X[i] = x / np.linalg.norm(x)

# Select the compute device: CUDA when PyTorch reports it as available,
# otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(device)

# defining model
class Net(nn.Module):
    """Two stacked ``FastKANLayer`` modules with a ``1/sqrt(m)`` output scaling.

    Args:
        m: Width passed as the second argument of the first layer and the
            first argument of the second layer (presumably the hidden width;
            confirm against the ``FastKANLayer`` signature).
        d: First argument of the first layer (presumably the input dimension).
    """

    def __init__(self, m, d):
        super().__init__()
        self.m = m
        # Both layers share the same construction options.
        layer_options = dict(
            use_layernorm=False,
            use_base_update=False,
            spline_weight_init_scale=1.0,
        )
        self.kan1 = FastKANLayer(d, m, **layer_options)
        self.kan2 = FastKANLayer(m, 1, **layer_options)

    def forward(self, x):
        """Apply both layers and scale the result by ``1/sqrt(m)``."""
        hidden = self.kan1(x)
        scale = 1 / np.sqrt(self.m)
        return scale * self.kan2(hidden)
    
def calculate_inf_norm(params, params_0):
    """Return the largest absolute entry-wise change across all trainable tensors.

    Tensors with ``requires_grad == False`` are skipped in both iterables.
    The remaining tensors are paired in iteration order, so both iterables
    must yield their trainable tensors in the same order and with matching
    sizes.

    Args:
        params: Iterable of current parameter tensors.
        params_0: Iterable of reference (e.g. initial) parameter tensors.

    Returns:
        ``max |w - w0|`` taken over every entry of every trainable tensor
        pair, or ``0.0`` when there is nothing to compare.

    Raises:
        ValueError: If the two iterables disagree on the number of trainable
            tensors or on the number of entries in a paired tensor.
    """
    with torch.no_grad():
        current = [
            p.detach().cpu().numpy().reshape(-1) for p in params if p.requires_grad
        ]
        reference = [
            p0.detach().cpu().numpy().reshape(-1) for p0 in params_0 if p0.requires_grad
        ]

    if len(current) != len(reference):
        raise ValueError(
            f"Trainable tensor count mismatch: {len(current)} vs {len(reference)}"
        )
    for index, (w, w0) in enumerate(zip(current, reference)):
        # Guard against silent broadcasting or misaligned concatenation.
        if w.shape != w0.shape:
            raise ValueError(
                f"Size mismatch for trainable tensor {index}: {w.size} vs {w0.size}"
            )

    if not current:
        return 0.0

    # Flatten everything into one vector so the max covers every parameter,
    # not only the last tensor.
    diff = np.abs(np.concatenate(current) - np.concatenate(reference))
    if diff.size == 0:
        return 0.0
    return np.max(diff)
    
# Convert the NumPy dataset to torch tensors on the selected device.
X, y = (torch.Tensor(array).to(device) for array in (X, y))

# Experiment configuration
widths = [500, 1000, 2000, 4000, 8000]  # widths
epochs = 2000  # number of training epochs

# Per-width loss histories, one dictionary per training mode
first_layer_only_results = dict()
full_training_results = dict()

def train_model(m, training_mode='first_layer_only', lr=1e-3):
    """Train a freshly constructed ``Net`` of width ``m`` and return its loss curve.

    The model is trained with full-batch SGD on the module-level ``X`` / ``y``
    for ``epochs`` steps using a mean-squared-error loss.

    Args:
        m: Network width.
        training_mode: ``'first_layer_only'`` updates only parameters whose
            name contains ``'kan1'``; ``'full_training'`` also updates those
            whose name contains ``'kan2'``.
        lr: SGD learning rate. Passed explicitly because older PyTorch
            releases require it; the default mirrors the value recent
            releases use when ``lr`` is omitted.

    Returns:
        NumPy array of length ``epochs``; entry ``k`` is the loss evaluated
        at the start of epoch ``k`` (before that epoch's parameter update).

    Raises:
        ValueError: If ``training_mode`` is not one of the two supported modes.
    """
    valid_modes = ('first_layer_only', 'full_training')
    if training_mode not in valid_modes:
        # A misspelled mode used to fall through silently to first-layer-only.
        raise ValueError(
            f"Unknown training_mode {training_mode!r}; expected one of {valid_modes}"
        )

    model = Net(m, d).to(device)

    # Set which layers to train based on training mode
    train_second_layer = training_mode == 'full_training'
    for name, param in model.named_parameters():
        if 'kan1' in name:  # First layer parameters
            param.requires_grad = True  # Always train first layer
        elif 'kan2' in name:  # Second layer parameters
            param.requires_grad = train_second_layer

    trainable = [p for p in model.parameters() if p.requires_grad]

    criterion = nn.MSELoss()
    optimizer = optim.SGD(trainable, lr=lr)

    print(f"\nTraining with m={m}, {training_mode}")
    print(f"Number of trainable parameters: {sum(p.numel() for p in trainable)}")
    print(f"Training first layer: {any('kan1' in name and p.requires_grad for name, p in model.named_parameters())}")
    print(f"Training second layer: {any('kan2' in name and p.requires_grad for name, p in model.named_parameters())}")

    loss_hist = np.zeros(epochs)

    # The mode never changes during training, so set it once.
    model.train()
    for epoch in range(epochs):
        optimizer.zero_grad()
        output = model(X)
        # Drop only the trailing output axis so a one-sample batch keeps the
        # same shape as ``y`` instead of collapsing to a 0-d tensor.
        loss = criterion(output.squeeze(-1), y)
        loss.backward()
        optimizer.step()

        loss_value = loss.item()
        loss_hist[epoch] = loss_value

        if epoch % 500 == 0 or epoch == epochs - 1:
            print(f"Epoch {epoch}: Loss = {loss_value:.4f}")

    return loss_hist

# Sweep over the widths; for each one train the first-layer-only model first
# and the fully trained model second, storing each loss history by width.
for m in widths:
    first_layer_only_results[m] = train_model(m, training_mode='first_layer_only')
    full_training_results[m] = train_model(m, training_mode='full_training')

# First plot: Training Loss Comparison
# Dashed = first layer only, solid = full training. Both curves of a given
# width share one colour, and every legend entry names its mode and width so
# no two entries read the same.
plt.figure(dpi=120)
for m in widths:
    (first_only_line,) = plt.plot(
        np.log10(first_layer_only_results[m]),
        '--',
        label=f'First Layer Only (m={m})',
    )
    plt.plot(
        np.log10(full_training_results[m]),
        '-',
        color=first_only_line.get_color(),
        label=f'Full Training (m={m})',
    )

plt.title('Training Loss: First Layer Only vs Full Training')
plt.xlabel('Epochs')
plt.ylabel('log10(Loss)')
plt.grid(True)
plt.legend()
plt.tight_layout()
plt.show()

# Second plot: Final Loss vs Width
plt.figure(dpi=120)
widths_array = np.array(widths)
first_only_final = [np.log10(first_layer_only_results[m][-1]) for m in widths]
full_train_final = [np.log10(full_training_results[m][-1]) for m in widths]

plt.plot(widths_array, first_only_final, 'o--', label='First Layer Only')
plt.plot(widths_array, full_train_final, 's-', label='Full Training')
plt.xscale('log')
plt.xlabel('Network Width (m)')
plt.ylabel('Final log10(Loss)')
plt.title('Final Loss vs Network Width')
plt.grid(True, which='both', linestyle='--')
plt.legend()

# Add text annotation with training details. It is anchored just inside the
# bottom edge of the figure: a negative y in figure coordinates lies outside
# the canvas and is cut off in a regular window.
plt.figtext(0.5, 0.01,
            'First Layer Only: Only first layer weights are updated\n'
            'Full Training: Both first and second layer weights are updated',
            ha='center', va='bottom', fontsize=9,
            bbox=dict(facecolor='white', alpha=0.7))

# Lay out the axes once, keeping the bottom strip free for the annotation.
plt.tight_layout(rect=[0, 0.12, 1, 1])
plt.show()
