#%%
import torch 
import torch.nn as nn
import torch.nn.functional as F

import numpy as np


import torchvision
import torchvision.transforms as transforms
from sklearn.model_selection import KFold
import pickle
import random 
from models.LinearModel  import KronLinear

from utils.loss import regularizer

# Seed torch's and Python's RNGs. NumPy's global RNG is not seeded here.
random_seed = 42
torch.manual_seed(random_seed)
random.seed(random_seed)

# Use the first CUDA device when one is available, otherwise the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# MNIST train/test splits, downloaded into ./data when missing; images are
# converted to tensors by ToTensor().
train_mnist = torchvision.datasets.MNIST(root='./data', train=True, download=True, transform=transforms.ToTensor())
test_mnist = torchvision.datasets.MNIST(root='./data', train=False, download=True, transform=transforms.ToTensor())

# Batch size shared by the train and test DataLoaders built further down.
batch_size = 3000

# One entry per KronLinear head created by SelectionModel; each entry is
# passed as KronLinear's third positional argument.
# NOTE(review): presumably the shape of a Kronecker factor — confirm against
# models.LinearModel.KronLinear.
patterns = [
    [2, 2],
    [4, 2],
    [8, 2],
    [16, 2],
    
]




class SelectionModel(nn.Module):
    """Bundle of KronLinear heads, one per candidate pattern, run side by side.

    Every head is constructed as ``KronLinear(784, 10, pattern,
    structured_sparse=True, rank=2)`` and stored in ``self.models`` in the
    same order as ``patterns``.
    """

    def __init__(self, patterns):
        super().__init__()
        self.patterns = patterns
        heads = []
        for shape in patterns:
            heads.append(KronLinear(784, 10, shape, structured_sparse=True, rank=2))
        self.models = nn.ModuleList(heads)

    def forward(self, x):
        """Apply every head to ``x`` and return the outputs as a list, in head order."""
        outputs = []
        for head in self.models:
            outputs.append(head(x))
        return outputs

    def get_group_lasso(self, lr, pattern='dim'):
        """Return the summed penalty on each head's ``s`` tensor.

        For every head the term is ``sqrt(||s||_2 / sqrt(numel(s))) * lr``.

        Args:
            lr: scalar multiplier applied to every head's term.
            pattern: accepted for interface compatibility; not used.

        Returns:
            The sum of the per-head terms, or the integer ``0`` when there
            are no heads.
        """
        def head_term(head):
            mask = head.s
            size_scale = np.sqrt(mask.flatten().shape[0])
            return torch.sqrt(torch.norm(mask, p=2) / size_scale) * lr

        return sum(head_term(head) for head in self.models)


    
# Weight of the group-lasso term (`group_loss`) in train_model's total loss.
lambda1 = 0.1
# Weight of the `regularizer(model, p=1)` term (`regular_loss`) in train_model's total loss.
lambda2 = 0.1

def test_model(model, test_loader):
    """Evaluate every head of ``model`` on ``test_loader`` and print per-head accuracy.

    Batches are moved to the module-level ``device`` and flattened to
    ``(-1, 784)`` before the forward pass.

    Args:
        model: object with ``eval()``, a ``patterns`` sequence (its length
            fixes how many heads are scored) and a call that returns one
            logits tensor per head.
        test_loader: iterable of ``(x, y)`` batches.

    Returns:
        list[float]: accuracy of each head, in head order. Entries are
        ``nan`` when the loader yields no samples.
    """
    model.eval()

    each_correct = [0] * len(model.patterns)
    total = 0
    with torch.no_grad():
        for x, y in test_loader:
            x = x.to(device).view(-1, 784)
            y = y.to(device)
            outputs = model(x)
            total += y.size(0)
            # Track the position with enumerate: list.index() compares
            # tensors with ==, which raises for multi-element tensors on
            # every head after the first.
            for head_idx, output in enumerate(outputs):
                _, predicted = torch.max(output, 1)
                each_correct[head_idx] += (predicted == y).sum().item()

    # Guard the division so an empty loader reports nan instead of raising.
    accuracies = [hits / total if total else float("nan") for hits in each_correct]
    # Same text as before for four heads, but valid for any number of heads.
    print("Accuracy: " + ", ".join(str(acc) for acc in accuracies))
    return accuracies
                
# Default number of training epochs.
epochs = 50
# Per-head, per-epoch mean of each head's `s` tensor (rows: heads, columns:
# epochs); filled in by train_model.
sparse = np.zeros((len(patterns), epochs))


def train_model(model, train_loader, test_loader, epochs=epochs):
    """Jointly train all heads of ``model`` and score them after every epoch.

    The loss of each step is the sum of every head's cross-entropy plus the
    group-lasso term (weighted by ``lambda1``) and the ``regularizer`` term
    (weighted by ``lambda2``); both penalty weights are additionally scaled
    by ``(epochs + 2 * epoch) / epochs``.

    Reads the module-level ``device``, ``lambda1`` and ``lambda2`` and writes
    the per-epoch mean of each head's ``s`` into the module-level ``sparse``.

    Args:
        model: SelectionModel-like module exposing ``models``, ``patterns``
            and ``get_group_lasso``.
        train_loader: iterable of ``(x, y)`` training batches.
        test_loader: iterable of ``(x, y)`` evaluation batches.
        epochs: number of passes over ``train_loader``.

    Returns:
        list: one array per epoch holding the test accuracy of every head.
    """
    model = model.to(device)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.05, momentum=0.9)
    # Report the initial mean absolute value of each head's `s`.
    for head in model.models:
        print(torch.sum(torch.abs(head.s)) / head.s.flatten().shape[0])

    accuracy = []
    for epoch in range(epochs):
        for step, (x, y) in enumerate(train_loader):
            x = x.to(device).view(-1, 784)
            y = y.to(device)
            output = model(x)
            # The epoch index is passed as get_group_lasso's multiplier, so
            # the group term is zero during epoch 0.
            group_loss = model.get_group_lasso(epoch)
            # Evaluate the regularizer once and reuse both components.
            # NOTE(review): assumes regularizer() has no side effects — confirm.
            reg_terms = regularizer(model, p=1)
            regular_loss = reg_terms[0] / reg_terms[1]
            total_loss = lambda1 * group_loss * (epochs + epoch * 2) / epochs
            total_loss = total_loss + lambda2 * regular_loss * (epochs + epoch * 2) / epochs
            for out in output:
                loss = F.cross_entropy(out, y)
                total_loss = total_loss + loss
            total_loss.backward()
            optimizer.step()
            optimizer.zero_grad()

            if step % 10 == 0:
                # NOTE: `loss` is the cross-entropy of the last head only,
                # not the full objective in `total_loss`.
                print(f"Epoch {epoch} Iteration {step} Loss {loss.item()}")

        with torch.no_grad():
            # One counter per head instead of a fixed six-slot array.
            correct_each = np.zeros(len(model.models), dtype=int)
            total = 0
            for x, y in test_loader:
                x = x.to(device).view(-1, 784)
                y = y.to(device)
                output = model(x)
                for head_idx, out in enumerate(output):
                    _, predicted = torch.max(out, 1)
                    correct_each[head_idx] += (predicted == y).sum().item()
                total += y.size(0)
            epoch_accuracy = correct_each / total
            print(epoch_accuracy)
            # Keep the per-epoch result so the returned history is not empty.
            accuracy.append(epoch_accuracy)

        for head_idx, head in enumerate(model.models):
            # `sparse` is sized from the module-level defaults; skip writes
            # that would fall outside it when called with more heads/epochs.
            if head_idx < sparse.shape[0] and epoch < sparse.shape[1]:
                # .item() converts explicitly instead of relying on NumPy's
                # implicit handling of a tensor on assignment.
                sparse[head_idx, epoch] = (head.s.sum() / head.s.flatten().shape[0]).item()

    return accuracy

# Build the loaders, train all heads jointly, then plot the values that
# train_model recorded into `sparse`.
train_loader = torch.utils.data.DataLoader(train_mnist, batch_size=batch_size, shuffle=True)
# NOTE(review): the test loader is shuffled as well; the accuracy totals
# computed in this file do not depend on batch order.
test_loader = torch.utils.data.DataLoader(test_mnist, batch_size=batch_size, shuffle=True)
model = SelectionModel(patterns)
train_model(model, train_loader, test_loader, epochs=epochs)
import matplotlib.pyplot as plt
# One curve per head: `sparse` has heads as rows and epochs as columns, so
# the transpose puts epochs on the x-axis.
plt.plot(sparse.T)
# Same curves with the sign removed, plotted right after the signed ones.
abs_sparse = np.abs(sparse)
plt.plot(abs_sparse.T)
#%%
# Plot the absolute recorded value of every head against the epoch.
abs_sparse = np.abs(sparse)
# keep only the first 50 epochs (rows of abs_sparse.T)
show_out = abs_sparse.T[:50]
plt.plot(show_out, linewidth=2.5)
# Four fixed labels, matching the four entries of `patterns`.
plt.legend(["k=1", "k=2", "k=3", "k=4"])
plt.xlabel("Epoch", fontsize=15)
# NOTE(review): the label reads as an L1 norm, while `sparse` stores
# s.sum() / numel(s) per head — confirm the label matches the plotted value.
plt.ylabel(r'$||S^{(k)}||_1$',fontsize=15)
plt.grid()
# plt.title("Sparsity of the patterns")
plt.show()

#%%



#%%
# Re-score every head of the trained model on the test set.
# One counter per head (instead of a fixed six-slot array) so the printed
# vector lines up with `patterns`.
correct_each = np.zeros(len(model.models), dtype=int)
total = 0
# Inference only: no autograd graph is needed.
with torch.no_grad():
    for x, y in test_loader:
        x = x.to(device)
        x = x.view(-1, 784)
        y = y.to(device)
        output = model(x)
        for i, out in enumerate(output):
            _, predicted = torch.max(out, 1)
            correct_each[i] += (predicted == y).sum().item()
        total += y.size(0)
print(correct_each/total)
#%%
# Pick the head to keep and rebuild it as a standalone KronLinear.
# NOTE: the accuracy-based choice below is commented out; head 0 is
# selected unconditionally.
# best_index = np.argmax(correct_each)
best_index = 0
best_pattern = patterns[best_index]

# NOTE(review): SelectionModel builds its heads with rank=2, but no rank is
# passed here — confirm KronLinear's default rank matches.
new_linear = KronLinear(784, 10, best_pattern, structured_sparse=True)
# The trained head's a / s / b attributes are assigned directly; no copy is
# made here.
# NOTE(review): presumably new_linear then shares these tensors with
# `model` — confirm that is intended.
new_linear.a = model.models[best_index].a
new_linear.s = model.models[best_index].s
new_linear.b = model.models[best_index].b
new_linear.to(device)

#%%
# NOTE(review): `index` is not referenced anywhere in the visible code.
index = 0


def finetune(module, train_loader, epoch=1, eval_loader=None):
    """Fine-tune ``module`` with SGD and print its test accuracy after every epoch.

    Batches are moved to the module-level ``device`` and flattened to
    ``(-1, 784)`` before the forward pass.

    Args:
        module: model that maps a ``(batch, 784)`` input to class logits.
        train_loader: iterable of ``(x, y)`` training batches.
        epoch: number of passes over ``train_loader``.
        eval_loader: iterable of ``(x, y)`` evaluation batches. Defaults to
            the module-level ``test_loader``.

    Returns:
        list[float]: accuracy measured after each epoch (``nan`` when the
        evaluation loader yields no samples).
    """
    if eval_loader is None:
        eval_loader = test_loader

    optimizer = torch.optim.SGD(module.parameters(), lr=0.01, momentum=0.9)
    history = []
    for _ in range(epoch):
        for x, y in train_loader:
            x = x.to(device).view(-1, 784)
            y = y.to(device)
            loss = F.cross_entropy(module(x), y)
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()

        total = 0
        correct = 0
        # Inference only: no autograd graph is needed.
        with torch.no_grad():
            for x, y in eval_loader:
                x = x.to(device).view(-1, 784)
                y = y.to(device)
                # Score the module being fine-tuned, not a module-level model.
                output = module(x)
                _, predicted = torch.max(output, 1)
                correct += (predicted == y).sum().item()
                total += y.size(0)
        epoch_accuracy = correct / total if total else float("nan")
        print(epoch_accuracy)
        history.append(epoch_accuracy)
    return history
#%%
# finetune the new linear model for 100 epochs (test accuracy is printed after each one)
finetune(new_linear, train_loader, 100)
#%%
# Final test-set accuracy of the fine-tuned model.
total = 0
correct = 0

# Inference only: disable autograd so no graph is built per batch.
with torch.no_grad():
    for x, y in test_loader:
        x = x.to(device)
        x = x.view(-1, 784)
        y = y.to(device)
        output = new_linear(x)
        _, predicted = torch.max(output, 1)
        correct += (predicted == y).sum().item()
        total += y.size(0)
print(correct/total)
#%%
