from itertools import islice
from tqdm.auto import tqdm
import torch
import torch.nn as nn
import torchvision
from kan import *
import numpy as np
import time
import copy

# Seed both NumPy's and PyTorch's global RNGs with the same fixed value.
seed = 0
np.random.seed(seed)
torch.manual_seed(seed)

# Grid sizes tried in the sweep (passed as KAN's `grid` argument).
grids = [3, 5, 10]

# Layer-width lists tried in the sweep (passed as KAN's `width` argument).
# Every entry starts at 784 inputs and ends at 10 outputs; only the number
# and size of the hidden layers changes.
widths = [
    [784, 10],
    [784, 10, 10],
    [784, 100, 10],
    [784, 100, 100, 10],
    [784, 100, 100, 100, 10],
]

def cycle(iterable):
    """Yield the items of ``iterable`` endlessly.

    Each time a pass over ``iterable`` finishes, iteration starts again by
    iterating ``iterable`` afresh; items are not cached between passes.
    The generator never terminates on its own.
    """
    while True:
        yield from iterable

# Pick the GPU when torch reports one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

# MNIST is read from ./tmp; download is disabled, so the files must already
# be there.
train = torchvision.datasets.MNIST(
    root="./tmp",
    train=True,
    transform=torchvision.transforms.ToTensor(),
    download=False,
)
test = torchvision.datasets.MNIST(
    root="./tmp",
    train=False,
    transform=torchvision.transforms.ToTensor(),
    download=False,
)

# Build the dict of tensors handed to training: each image is flattened to
# one row and scaled by 1/256, labels are cast to int64.
# NOTE(review): the tensors are taken from `.data` / `.targets` directly, so
# the `transform` given above does not appear to be applied here - confirm.
dataset = {
    'train_input': train.data.reshape(train.data.shape[0], -1).float() / 256,
    'train_label': train.targets.long(),
    'test_input': test.data.reshape(test.data.shape[0], -1).float() / 256,
    'test_label': test.targets.long(),
}

def train_acc():
    """Return the accuracy of the global ``model`` on the training split.

    Reads the module-level ``model``, ``dataset`` and ``train`` objects and
    evaluates in mini-batches of 1024 samples.

    Returns:
        A 0-dim tensor: number of correct predictions divided by the number
        of training samples.
    """
    batch = 1024
    inputs = dataset['train_input']
    labels = dataset['train_label']
    n_samples = train.data.shape[0]
    correct = 0
    # Pure evaluation: no autograd graph is needed for the forward passes.
    with torch.no_grad():
        # Stepping by `batch` includes the final partial batch and never
        # produces an empty slice when n_samples is a multiple of `batch`.
        for start in range(0, n_samples, batch):
            stop = start + batch
            logits = model(inputs[start:stop])
            pred_class = torch.argmax(logits, dim=1)
            correct += torch.sum(pred_class == labels[start:stop])
    return correct / n_samples

def train_loss():
    """Return the mean cross-entropy loss of the global ``model`` on the training split.

    Reads the module-level ``model``, ``dataset`` and ``train`` objects and
    evaluates in mini-batches of 1024 samples. Per-batch losses are summed
    and the total is divided by the number of training samples.

    Returns:
        A 0-dim tensor holding the mean loss; it carries no autograd graph.
    """
    batch = 1024
    inputs = dataset['train_input']
    labels = dataset['train_label']
    n_samples = train.data.shape[0]
    # Built once; 'sum' reduction so the division below yields a per-sample mean.
    criterion = torch.nn.CrossEntropyLoss(reduction='sum')
    loss = 0
    # Without no_grad the running sum would keep every batch's graph alive.
    with torch.no_grad():
        # Stepping by `batch` includes the final partial batch and never
        # produces an empty slice when n_samples is a multiple of `batch`.
        for start in range(0, n_samples, batch):
            stop = start + batch
            logits = model(inputs[start:stop])
            loss += criterion(logits, labels[start:stop])
    return loss / n_samples

def test_acc():
    """Return the accuracy of the global ``model`` on the test split.

    Reads the module-level ``model``, ``dataset`` and ``test`` objects and
    evaluates in mini-batches of 1024 samples.

    Returns:
        A 0-dim tensor: number of correct predictions divided by the number
        of test samples.
    """
    batch = 1024
    inputs = dataset['test_input']
    labels = dataset['test_label']
    n_samples = test.data.shape[0]
    correct = 0
    # Pure evaluation: no autograd graph is needed for the forward passes.
    with torch.no_grad():
        # Stepping by `batch` includes the final partial batch and never
        # produces an empty slice when n_samples is a multiple of `batch`.
        for start in range(0, n_samples, batch):
            stop = start + batch
            logits = model(inputs[start:stop])
            pred_class = torch.argmax(logits, dim=1)
            correct += torch.sum(pred_class == labels[start:stop])
    return correct / n_samples

def test_loss():
    """Return the mean cross-entropy loss of the global ``model`` on the test split.

    Reads the module-level ``model``, ``dataset`` and ``test`` objects and
    evaluates in mini-batches of 1024 samples. Per-batch losses are summed
    and the total is divided by the number of test samples.

    Returns:
        A 0-dim tensor holding the mean loss; it carries no autograd graph.
    """
    batch = 1024
    inputs = dataset['test_input']
    labels = dataset['test_label']
    n_samples = test.data.shape[0]
    # Built once; 'sum' reduction so the division below yields a per-sample mean.
    criterion = torch.nn.CrossEntropyLoss(reduction='sum')
    loss = 0
    # Without no_grad the running sum would keep every batch's graph alive.
    with torch.no_grad():
        # Stepping by `batch` includes the final partial batch and never
        # produces an empty slice when n_samples is a multiple of `batch`.
        for start in range(0, n_samples, batch):
            stop = start + batch
            logits = model(inputs[start:stop])
            loss += criterion(logits, labels[start:stop])
    return loss / n_samples


# Sweep every (width, grid) combination: train a KAN, time the training, then
# write the final losses, accuracies and training time to one text file per
# configuration.
import os  # stdlib; imported here because only the result-writing sweep needs it

# Fail fast: np.savetxt does not create missing directories, so make sure the
# output folder exists before any time is spent on training.
os.makedirs('./results', exist_ok=True)

for width in widths:
    for grid in grids:

        # KAN receives a copy, so the `width` list used in the output file
        # name stays exactly as listed in `widths` (presumably KAN may modify
        # the list it is given).
        width_0 = copy.deepcopy(width)

        # `model` is deliberately a module-level name: the metric helpers
        # (train_acc, test_acc, train_loss, test_loss) read it as a global.
        model = KAN(width=width_0, grid=grid)

        # perf_counter is monotonic, so the interval is not affected by
        # system clock adjustments during a long training run.
        start_time = time.perf_counter()
        model.fit(dataset, steps=2000, batch=1024, stop_grid_update_step=500, loss_fn=torch.nn.CrossEntropyLoss(), opt="Adam", lr=1e-2)
        wall_time = time.perf_counter() - start_time

        # Metrics are evaluation only; skip autograd bookkeeping.
        with torch.no_grad():
            train_accn = train_acc().item()
            test_accn = test_acc().item()
            train_lossn = train_loss().item()
            test_lossn = test_loss().item()

        # One value per line, in this order:
        # train loss, test loss, train accuracy, test accuracy, training seconds.
        np.savetxt(f'./results/kan_width_{width}_grid_{grid}.txt', [train_lossn, test_lossn, train_accn, test_accn, wall_time])
