import numpy as np
import cv2 as cv
from matplotlib import pyplot as plt
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import torch.nn.functional as F
import os
import sys
from kan import KAN
from tqdm import tqdm

class MyDataset(Dataset):
    """Dataset that pairs each stored sequence with the label at the same index."""

    def __init__(self, sequences, labels):
        # Both containers are kept by reference; nothing is copied or validated.
        self.labels = labels
        self.sequences = sequences

    def __len__(self):
        """Return the number of stored sequences."""
        return len(self.sequences)

    def __getitem__(self, idx):
        """Return the ``(sequence, label)`` pair stored at position ``idx``."""
        return self.sequences[idx], self.labels[idx]

class CNN(nn.Module):
    """Convolutional feature extractor made of two conv -> ReLU -> max-pool stages.

    Takes single-channel images and returns a 16-channel feature map. The
    5x5 convolutions (stride 1, padding 2) keep height/width unchanged, while
    each 2x2 max-pool with stride 2 halves them.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(
            in_channels=1, out_channels=8, kernel_size=5, stride=1, padding=2
        )
        # The same parameter-free pooling layer is reused after both convolutions.
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2, padding=0)
        self.conv2 = nn.Conv2d(
            in_channels=8, out_channels=16, kernel_size=5, stride=1, padding=2
        )

    def forward(self, x):
        """Run both stages on ``x`` and return the pooled feature map."""
        first_stage = self.pool(F.relu(self.conv1(x)))
        second_stage = self.pool(F.relu(self.conv2(first_stage)))
        return second_stage
        

class KANNet(nn.Module):
    """Classifier head: flattens the features of ``cnn_model`` and feeds them to a KAN.

    ``input_size`` has to equal the number of feature values ``cnn_model``
    produces per sample, because the flattened features are passed straight
    into a KAN of width ``[input_size, output_size]``. ``grid_size`` and ``k``
    are forwarded to ``KAN`` as ``grid`` and ``k``.
    """

    def __init__(self, cnn_model, input_size=784, output_size=10, grid_size=2, k=1):
        super().__init__()
        self.cnn_model = cnn_model
        layer_widths = [input_size, output_size]
        self.kan_layer = KAN(width=layer_widths, grid=grid_size, k=k)

    def forward(self, x):
        """Return the KAN output for the batch ``x``."""
        features = self.cnn_model(x)
        batch_size = features.size(0)
        # Collapse everything except the batch dimension into one feature axis.
        flat_features = features.view(batch_size, -1)
        return self.kan_layer(flat_features)

def print_total_parameters(model, model_name="Model"):
    """Print the total number of parameter elements held by ``model``.

    Every tensor yielded by ``model.parameters()`` is counted. ``model_name``
    is only used to label the printed line. Nothing is returned.
    """
    total_params = 0
    for param in model.parameters():
        total_params += param.numel()
    print(f"Total number of parameters in {model_name}: {total_params}")

def _load_image_folder(path):
    """Load all grayscale images below ``path`` with one integer label per class folder.

    ``path`` must contain one sub-directory per class. Sub-directories are
    visited in sorted name order, so the folder -> label mapping is
    deterministic and identical for every dataset split that uses the same
    folder names (``os.listdir`` alone returns entries in arbitrary order).

    Returns:
        ``(data, label)``: ``data`` is a list of arrays with a leading axis of
        size 1 added in front of each image; ``label`` is the list of matching
        class indices.

    Raises:
        ValueError: if no readable image was found under ``path``.
    """
    class_dirs = sorted(
        entry for entry in os.listdir(path)
        if os.path.isdir(os.path.join(path, entry))  # ignore stray files
    )
    data = []
    label = []
    for idx, folder in enumerate(class_dirs):
        folder_path = os.path.join(path, folder)
        for file_name in sorted(os.listdir(folder_path)):
            file_path = os.path.join(folder_path, file_name)
            img = cv.imread(file_path, 0)  # flag 0: read as grayscale
            if img is None:
                # cv.imread signals failure by returning None instead of raising;
                # skipping here avoids an obscure failure when batches are built.
                print(f"Warning: skipping unreadable image {file_path}", file=sys.stderr)
                continue
            data.append(np.expand_dims(img, axis=0))
            label.append(idx)
    if not data:
        raise ValueError(f"No readable images found under {path!r}")
    return data, label


def _evaluate_accuracy(model, dataloader, device):
    """Return the fraction of samples in ``dataloader`` that ``model`` classifies correctly."""
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in dataloader:
            images = images.to(device).float()
            labels = labels.to(device)
            outputs = model(images)
            correct += (outputs.argmax(1) == labels).sum().item()
            total += labels.size(0)
    return correct / total if total else 0.0


def train_and_test(mode):
    """Train the CNN+KAN classifier or evaluate a previously saved checkpoint.

    Args:
        mode: ``'train'`` trains for a fixed number of epochs, reports test
            accuracy after every epoch, saves the accuracy curve to
            ``'Accuracy Over Epochs.png'`` and the weights to
            ``'cnn+kan_model.pth'``. ``'test'`` loads that checkpoint and
            prints its test accuracy.

    Raises:
        ValueError: if ``mode`` is neither ``'train'`` nor ``'test'``, or if a
            required image directory contains no readable images.
    """
    if mode not in ('train', 'test'):
        # Fail before any expensive data loading instead of silently doing nothing.
        raise ValueError(f"mode must be 'train' or 'test', got {mode!r}")

    batch = 512
    learning_rate = 0.001
    num_epochs = 100
    checkpoint_path = 'cnn+kan_model.pth'
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    test_data, test_label = _load_image_folder('../MNIST/transformed/TEST/')
    # Accuracy does not depend on sample order, so evaluation needs no shuffling.
    test_dataloader = DataLoader(MyDataset(test_data, test_label), batch_size=batch, shuffle=False)

    cnn_model = CNN().to(device)
    kan_model = KANNet(cnn_model).to(device)
    print_total_parameters(kan_model, "KANNet")

    if mode == 'test':
        # map_location lets a checkpoint saved on GPU be loaded on a CPU-only machine.
        state_dict = torch.load(checkpoint_path, map_location=device, weights_only=True)
        kan_model.load_state_dict(state_dict)
        accuracy = _evaluate_accuracy(kan_model, test_dataloader, device)
        print(f"Test Accuracy: {accuracy:.4f}")
        return

    # Training data is only needed (and only loaded) in 'train' mode.
    train_data, train_label = _load_image_folder('../MNIST/transformed/TRAIN/')
    train_dataloader = DataLoader(MyDataset(train_data, train_label), batch_size=batch, shuffle=True)

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(kan_model.parameters(), lr=learning_rate)
    accuracy = []  # test accuracy after each epoch

    for epoch in range(num_epochs):
        kan_model.train()
        total_loss = 0
        correct = 0
        total = 0
        for images, labels in tqdm(train_dataloader, desc=f"Epoch {epoch + 1}"):
            labels = labels.to(device)
            images = images.to(device).float()

            optimizer.zero_grad()
            outputs = kan_model(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            total_loss += loss.item()
            # Training accuracy uses the predictions made before this step's update.
            correct += (outputs.argmax(1) == labels).sum().item()
            total += labels.size(0)

        epoch_loss = total_loss / len(train_dataloader)
        epoch_acc = correct / total
        print(f"Epoch [{epoch+1}] Complete. Loss: {epoch_loss:.4f}, Accuracy: {epoch_acc:.4f}")

        test_accuracy = _evaluate_accuracy(kan_model, test_dataloader, device)
        accuracy.append(test_accuracy)
        print(f"Test Accuracy after Epoch [{epoch+1}]: {test_accuracy:.4f}")

    fig = plt.figure()
    plt.plot(range(1, num_epochs + 1), accuracy)
    plt.title('Accuracy Over Epochs')
    plt.xlabel('Epoch')
    plt.ylabel('Accuracy')
    plt.grid(True)
    plt.savefig('Accuracy Over Epochs.png')
    plt.close(fig)  # release the figure once it has been written to disk
    torch.save(kan_model.state_dict(), checkpoint_path)




if __name__ == "__main__":
    # Exactly one argument is expected: the mode, either 'train' or 'test'.
    if len(sys.argv) != 2 or sys.argv[1] not in ('train', 'test'):
        # Report misuse on stderr and exit non-zero so shells and scripts can detect it.
        print("Usage: python script.py [train/test]", file=sys.stderr)
        sys.exit(2)
    train_and_test(sys.argv[1])
