"""
Convergence Analysis of Kolmogorov-Arnold Networks (KAN) on MNIST
----------------------------------------------------------------

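This script analyzes the convergence behavior of Kolmogorov-Arnold Networks (KAN)
on the MNIST dataset with varying hidden layer widths. For each width it records,
per epoch, the mean training loss and the largest absolute change of any trainable
parameter from its initial value (an infinity-norm distance), and plots both.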
"""

import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.optim as optim
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
from fastkan import FastKANLayer
from torch.utils.data import DataLoader

# Hyperparameters
batch_size = 100  # samples per DataLoader batch
# NOTE(review): input_size and num_classes are not referenced by the visible
# code; KAN_MNIST uses its own input_dim=784 / num_classes=10 defaults, so
# keep these values in sync if either side changes.
input_size = 28 * 28  # MNIST images are 28x28
num_classes = 10  # number of output classes (digits)

# Load MNIST dataset.
# Each sample is converted to a tensor, standardized with mean 0.1307 and
# std 0.3081 (presumably the commonly used MNIST training-set statistics),
# and flattened to a 1-D vector so it can be fed directly into a KAN layer.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,)),
    transforms.Lambda(lambda x: x.view(-1))  # Flatten the image
])

# Training split only; download=True fetches it into ./data if not present.
train_dataset = torchvision.datasets.MNIST(
    root='./data', 
    train=True,
    download=True,
    transform=transform
)

# Mini-batches of the training set, reshuffled every epoch (shuffle=True).
train_loader = DataLoader(
    dataset=train_dataset,
    batch_size=batch_size,
    shuffle=True
)

# Define device: CUDA when available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

# Two-layer KAN classifier used for every width in the experiment.
class KAN_MNIST(nn.Module):
    """Two-layer FastKAN classifier with a ``1/sqrt(hidden_size)`` output scale.

    The first layer maps a flattened input of length ``input_dim`` to
    ``hidden_size`` features; the second maps those features to
    ``num_classes`` logits, which are then multiplied by
    ``1 / sqrt(hidden_size)``.

    Args:
        hidden_size: Width of the hidden KAN layer.
        input_dim: Length of each flattened input vector.
        num_classes: Number of output logits.
    """

    def __init__(self, hidden_size, input_dim=784, num_classes=10):
        super().__init__()
        self.hidden_size = hidden_size
        # Both layers share the same construction options.
        layer_kwargs = dict(use_layernorm=True, spline_weight_init_scale=1.0)
        self.kan1 = FastKANLayer(input_dim, hidden_size, **layer_kwargs)
        self.kan2 = FastKANLayer(hidden_size, num_classes, **layer_kwargs)

    def forward(self, x):
        """Return the width-scaled logits for the batch ``x``."""
        hidden = self.kan1(x)
        logits = self.kan2(hidden)
        # Width-dependent output scaling, read from the current hidden_size.
        scale = 1 / np.sqrt(self.hidden_size)
        return scale * logits

def calculate_inf_norm(params, params_0):
    """Return the largest absolute entry-wise change between two tensor lists.

    Computes ``max_i max |params[i] - params_0[i]|`` over every entry of every
    tensor in ``params`` that requires grad, i.e. the infinity norm of the
    difference across all trainable parameters.

    Args:
        params: Current parameter tensors.
        params_0: Reference (e.g. initial) tensors, paired one-to-one with
            ``params`` and with matching shapes.

    Returns:
        The maximum absolute difference as a Python float. Returns ``0.0``
        when there is nothing to compare, and ``nan`` if any compared entry
        is NaN, so a diverged run is not reported as a finite distance.

    Raises:
        ValueError: If ``params`` and ``params_0`` have different lengths;
            ``zip`` would otherwise silently drop the tail and misalign pairs.
    """
    import math

    params = list(params)
    params_0 = list(params_0)
    if len(params) != len(params_0):
        raise ValueError(
            "params and params_0 must have the same length "
            f"(got {len(params)} and {len(params_0)})"
        )

    max_dist = 0.0
    with torch.no_grad():
        for p, p0 in zip(params, params_0):
            # Frozen tensors are not part of the measurement; empty tensors
            # have no entries (and torch.max raises on them).
            if not p.requires_grad or p.numel() == 0:
                continue
            dist = torch.max(torch.abs(p - p0)).item()
            if math.isnan(dist):
                # Python's max() would silently discard NaN; propagate it.
                return dist
            max_dist = max(max_dist, dist)
    return max_dist

def train_model(model, train_loader, num_epochs=10, learning_rate=0.01, device=None):
    """Train ``model`` with Adam + cross-entropy, tracking loss and parameter drift.

    After every epoch this records the sample-weighted mean training loss and
    the infinity-norm distance of the trainable parameters from their values
    at the start of this call (computed with ``calculate_inf_norm``).

    Args:
        model: Module mapping a batch of inputs to class logits.
        train_loader: Iterable of ``(inputs, labels)`` batches.
        num_epochs: Number of passes over ``train_loader``.
        learning_rate: Adam learning rate.
        device: Device to move each batch to. Defaults to the device of the
            model's first parameter, so batches follow the model.

    Returns:
        ``(losses, distances)``: two lists with one float per epoch.

    Raises:
        ValueError: If ``train_loader`` yields no samples in an epoch.
    """
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)
    if device is None:
        # Adam above already rejects a model with no parameters, so next() is safe.
        device = next(model.parameters()).device

    # optimizer.step() updates these same tensors in place, so one list of
    # trainable parameters stays valid for every epoch.
    trainable = [p for p in model.parameters() if p.requires_grad]
    # Store initial parameters
    params_0 = [p.clone().detach() for p in trainable]

    losses = []
    distances = []

    for epoch in range(num_epochs):
        model.train()
        running_loss = 0.0
        num_samples = 0

        for images, labels in train_loader:
            images, labels = images.to(device), labels.to(device)

            # Forward pass
            outputs = model(images)
            loss = criterion(outputs, labels)

            # Backward and optimize
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # The loss is a per-batch mean; weight it by batch size so a
            # smaller final batch does not skew the epoch average.
            batch_n = labels.size(0)
            running_loss += loss.item() * batch_n
            num_samples += batch_n

        if num_samples == 0:
            raise ValueError("train_loader yielded no samples; cannot compute an epoch loss")
        avg_loss = running_loss / num_samples
        losses.append(avg_loss)

        # Calculate maximum parameter distance from initialization
        max_dist = calculate_inf_norm(trainable, params_0)
        distances.append(max_dist)

        print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {avg_loss:.4f}, Max Distance: {max_dist:.6f}')

    return losses, distances

# Main experiment: train one model per hidden width and record its curves.
widths = [32, 128, 512, 2048]  # Different hidden layer sizes to try
num_epochs = 20
seed = 0  # re-applied before each width so runs are reproducible and comparable
results = {}

# Train models and collect results
for width in widths:
    print(f"\nTraining model with hidden size: {width}")
    # Seed torch's global RNG before building the model. This makes the
    # initialization reproducible, and also the shuffle order of train_loader,
    # which by default draws from the same global generator.
    torch.manual_seed(seed)
    model = KAN_MNIST(hidden_size=width).to(device)

    # Count trainable parameters
    total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"Number of trainable parameters: {total_params:,}")

    # Train the model
    losses, distances = train_model(model, train_loader, num_epochs)
    results[width] = {'losses': losses, 'distances': distances}

def _plot_metric(results, widths, key, title, ylabel):
    """Draw one per-epoch metric for every width in its own figure and show it.

    Args:
        results: Mapping ``width -> {metric_name: list_of_per_epoch_values}``.
        widths: Widths to plot, in legend order.
        key: Metric name to read from each ``results[width]`` dict.
        title: Figure title.
        ylabel: Y-axis label.
    """
    plt.figure(figsize=(10, 5))
    for width in widths:
        values = results[width][key]
        # 1-based epochs so the x-axis matches the "Epoch [k/N]" training log.
        plt.plot(range(1, len(values) + 1), values, label=f'm={width}')

    plt.title(title)
    plt.xlabel('Epoch')
    plt.ylabel(ylabel)
    plt.legend()
    plt.grid(True)
    plt.tight_layout()
    plt.show()


# Create separate figures: training loss, then parameter distance from init.
_plot_metric(results, widths, 'losses', 'Training Loss', 'Loss')
_plot_metric(results, widths, 'distances', 'Parameter Distance from Initialization', 'Max Distance')
