"""
Convergence Analysis of Kolmogorov-Arnold Networks (KAN) on CIFAR-10
-------------------------------------------------------------------

This script trains two-layer FastKAN models with several hidden-layer widths on
flattened CIFAR-10 images and plots, per epoch, the average training loss and
the largest absolute change of any trainable parameter from its initial value.
"""

import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.optim as optim
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
from fastkan import FastKANLayer
from torch.utils.data import DataLoader

# Hyperparameters
batch_size = 100
input_channels = 3        # number of colour channels in a CIFAR-10 image
image_size = 32 * 32 * 3  # length of one flattened 32x32 three-channel image
num_classes = 10

# Preprocessing pipeline: convert to tensor, apply (x - 0.5) / 0.5 to each
# channel, then collapse the image into a single 1-D vector.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
    transforms.Lambda(lambda img: img.view(-1)),
])

# CIFAR-10 training split stored under ./data (download enabled)
train_dataset = torchvision.datasets.CIFAR10(
    './data',
    train=True,
    transform=transform,
    download=True,
)

# Shuffled mini-batches of `batch_size` samples
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

# Define device: use CUDA when it is available, otherwise fall back to the CPU
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(f"Using device: {device}")

# Define the KAN model for CIFAR-10
class KAN_CIFAR10(nn.Module):
    """Two-layer FastKAN classifier operating on flattened inputs.

    The input is passed through two ``FastKANLayer`` modules in sequence and
    the output of the second layer is multiplied by ``1 / sqrt(hidden_size)``.

    Args:
        hidden_size: Output width of the first layer and input width of the
            second layer.
        input_dim: Length of each input vector (default 3072 = 32*32*3).
        num_classes: Number of output values per sample.
    """

    def __init__(self, hidden_size, input_dim=3072, num_classes=10):
        super().__init__()
        self.hidden_size = hidden_size
        # Both layers are built with the same FastKAN options.
        layer_options = dict(use_layernorm=True, spline_weight_init_scale=1.0)
        self.kan1 = FastKANLayer(input_dim, hidden_size, **layer_options)
        self.kan2 = FastKANLayer(hidden_size, num_classes, **layer_options)

    def forward(self, x):
        """Return the output of the stacked layers scaled by 1/sqrt(hidden_size)."""
        hidden = self.kan1(x)
        scale = 1 / np.sqrt(self.hidden_size)
        return scale * self.kan2(hidden)

def calculate_inf_norm(params, params_0):
    """Return the largest absolute element-wise difference between paired tensors.

    Tensors are paired positionally with ``zip``, so surplus entries in the
    longer sequence are ignored. Only pairs whose first tensor has
    ``requires_grad`` set contribute to the result.

    Args:
        params: Current parameter tensors.
        params_0: Reference tensors, in the same order and with shapes
            compatible with ``params``.

    Returns:
        The maximum of ``|p - p0|`` over all compared elements as a Python
        float. ``0.0`` when nothing is compared, and ``nan`` as soon as any
        compared pair yields a NaN difference.
    """
    with torch.no_grad():
        max_dist = 0.0
        for p, p0 in zip(params, params_0):
            # Frozen tensors are excluded; zero-element tensors have nothing
            # to compare and would make torch.max raise.
            if not p.requires_grad or p.numel() == 0:
                continue
            dist = torch.max(torch.abs(p - p0)).item()
            # Python's max() silently discards NaN, which would report a
            # diverged model as "distance 0". NaN != NaN, so surface it.
            if dist != dist:
                return dist
            max_dist = max(max_dist, dist)
        return max_dist

def train_model(model, train_loader, num_epochs=10, learning_rate=0.01):
    """Train ``model`` with Adam on cross-entropy loss and track parameter drift.

    Args:
        model: Module to optimise in place. Each batch is moved to the device
            of the model's first parameter before the forward pass.
        train_loader: Iterable yielding ``(images, labels)`` batches; it is
            iterated once per epoch.
        num_epochs: Number of passes over ``train_loader``.
        learning_rate: Learning rate passed to ``optim.Adam``.

    Returns:
        ``(losses, distances)``: two lists with one entry per epoch.
        ``losses`` holds the mean of the per-batch loss values and
        ``distances`` holds ``calculate_inf_norm`` between the current and
        the initial trainable parameters.

    Raises:
        ZeroDivisionError: If ``train_loader`` yields no batches in an epoch.
    """
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)

    # Follow the model's own device instead of a module-level global, so the
    # function also works for a model that was placed somewhere else.
    model_device = next(model.parameters()).device

    # Build the trainable-parameter list once and snapshot the initial values
    # from that same list, which keeps the two sequences aligned.
    params = [p for p in model.parameters() if p.requires_grad]
    params_0 = [p.clone().detach() for p in params]

    losses = []
    distances = []

    for epoch in range(num_epochs):
        model.train()
        running_loss = 0.0
        # Counted rather than taken from len(), so loaders without __len__
        # are supported as well.
        num_batches = 0

        for images, labels in train_loader:
            images, labels = images.to(model_device), labels.to(model_device)

            # Forward pass
            outputs = model(images)
            loss = criterion(outputs, labels)

            # Backward and optimize
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            num_batches += 1

        # Average of the per-batch losses for this epoch
        avg_loss = running_loss / num_batches
        losses.append(avg_loss)

        # Maximum absolute parameter change relative to initialization
        max_dist = calculate_inf_norm(params, params_0)
        distances.append(max_dist)

        print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {avg_loss:.4f}, Max Distance: {max_dist:.6f}')

    return losses, distances

# Main experiment: one independent training run per hidden-layer width
widths = [64, 256, 1024, 4096]  # Increased widths for CIFAR-10
num_epochs = 20
results = {}  # width -> {'losses': [...], 'distances': [...]}

for width in widths:
    print(f"\nTraining model with hidden size: {width}")
    model = KAN_CIFAR10(hidden_size=width)
    model = model.to(device)

    # Report how many scalar parameters have requires_grad set
    total_params = sum(
        p.numel()
        for p in model.parameters()
        if p.requires_grad
    )
    print(f"Number of trainable parameters: {total_params:,}")

    # Train and keep the per-epoch curves for the plots below
    losses, distances = train_model(model, train_loader, num_epochs=num_epochs)
    results[width] = dict(losses=losses, distances=distances)

def _plot_per_width(metric, title, ylabel):
    """Show one figure with a curve of ``results[width][metric]`` per width."""
    plt.figure(figsize=(10, 5))
    for width in widths:
        plt.plot(results[width][metric], label=f'm={width}')

    plt.title(title)
    plt.xlabel('Epoch')
    plt.ylabel(ylabel)
    plt.legend()
    plt.grid(True)
    plt.tight_layout()
    plt.show()


# Create separate figures: first the training loss, then the parameter distances
_plot_per_width('losses', 'Training Loss', 'Loss')
_plot_per_width('distances', 'Parameter Distance from Initialization', 'Max Distance')
