#%%
import torch 
import torch.nn as nn
import torch.nn.functional as F

import numpy as np


import torchvision
import torchvision.transforms as transforms
from sklearn.model_selection import KFold
import pickle
import random 
from models.LinearModel  import KronLinear
from models.LeNet import KPDLeNet

from utils.loss import regularizer

# Seed the RNGs this script seeds explicitly: Python's `random` and torch.
random_seed = 30
random.seed(random_seed)
torch.manual_seed(random_seed)

# Prefer the first CUDA device when one is available, otherwise run on the CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

# MNIST train and test splits (in that order), stored under ./data and
# converted to tensors by ToTensor.
train_mnist, test_mnist = (
    torchvision.datasets.MNIST(root='./data', train=is_train, download=True, transform=transforms.ToTensor())
    for is_train in (True, False)
)

# Mini-batch size shared by the train and test DataLoaders built further down.
batch_size = 3000

# Candidate patterns; PatternSelect builds one KPDLeNet per entry.
# Each entry holds three [int, int] pairs -- presumably one pair per Kronecker
# layer of KPDLeNet; TODO confirm against models/LeNet.py.
patterns = [
    [[16, 12], [12, 7], [7, 2]],
    [[16, 10], [10, 12], [12, 2]],
    [[8, 2], [2, 4], [4, 2]],
    [[16, 5], [5, 4], [4, 2]],
    [[4, 5], [5, 12], [12, 2]],
]


class PatternSelect(nn.Module):
    """Ensemble that holds one ``KPDLeNet`` per candidate pattern.

    All sub-models receive the same input; their outputs are returned side by
    side so that the candidates can be trained and compared together.
    """

    def __init__(self, patterns):
        """Build one ``KPDLeNet`` for every entry of ``patterns``.

        Args:
            patterns: Sequence of pattern specifications; each one is passed
                unchanged to the ``KPDLeNet`` constructor.
        """
        super().__init__()
        self.patterns = patterns
        self.models = nn.ModuleList([KPDLeNet(spec) for spec in patterns])

    def forward(self, x):
        """Return a list with the output of every sub-model for input ``x``."""
        return [net(x) for net in self.models]

    def get_group_lasso(self, lr):
        """Return the sum over sub-models of ``get_group_lasso() * lr``.

        Args:
            lr: Multiplier applied to every sub-model's group-lasso value.

        Returns:
            The accumulated penalty; the plain integer ``0`` when there are
            no sub-models.
        """
        return sum(net.get_group_lasso() * lr for net in self.models)


# Ensemble with one KPDLeNet per candidate pattern.
model = PatternSelect(patterns=patterns)

# Weights of the two penalty terms in train_model's loss:
# lambda1 scales the group-lasso term, lambda2 the `regularizer` term.
lambda1, lambda2 = 0.01, 0.01

# Accuracy rows appended by train_model each time it evaluates on the test set.
acc = list()
def test_model(model, test_loader):
    """Evaluate every sub-model of ``model`` and print their accuracies.

    The model is switched to eval mode and left in it.

    Args:
        model: Module exposing a ``patterns`` sequence whose forward pass
            returns one logits tensor per pattern (e.g. ``PatternSelect``).
        test_loader: Iterable yielding ``(inputs, labels)`` batches.

    Returns:
        list[float]: Accuracy of each sub-model, in the order of
        ``model.patterns``. Every entry is ``nan`` when the loader yields no
        samples.
    """
    model.eval()

    # Send batches to wherever the model's weights live; the module-level
    # `device` is only a fallback for models without parameters.
    first_param = next(model.parameters(), None)
    target = first_param.device if first_param is not None else device

    total = 0
    each_correct = [0] * len(model.patterns)
    with torch.no_grad():
        for x, y in test_loader:
            x = x.to(target)
            y = y.to(target)
            total += y.size(0)
            # enumerate instead of list.index: index() compares tensors with
            # `==`, which is ambiguous for multi-element tensors and raises.
            for idx, output in enumerate(model(x)):
                predicted = output.argmax(dim=1)
                each_correct[idx] += (predicted == y).sum().item()

    if total:
        accuracies = [hits / total for hits in each_correct]
    else:
        accuracies = [float("nan")] * len(each_correct)
    # Report every pattern, however many there are.
    print("Accuracy: " + ", ".join(str(value) for value in accuracies))
    return accuracies


# Evaluation helper, not a pytest test case: opt out of name-based collection.
test_model.__test__ = False
# Number of training epochs; also the default of train_model's `epochs` argument.
epochs = 200
# sparse[k, e] receives models[k].get_s_norm() after epoch e (filled by train_model).
sparse = np.zeros(shape=(len(patterns), epochs))
def train_model(model, train_loader, test_loader, epochs=epochs):
    """Jointly train all sub-models of ``model`` with SGD and record statistics.

    The per-step loss is the sum of the cross-entropy of every sub-model
    output plus two penalties, ``lambda1 * group lasso`` and
    ``lambda2 * regularizer``. Both penalties are scaled by
    ``(epochs + 2 * epoch) / epochs``, which grows from 1 towards 3 over
    training. ``model.get_group_lasso`` is called with ``epoch + 1`` as its
    multiplier.

    Side effects on module-level state:
        * ``acc`` gets one accuracy row appended on every even epoch.
        * ``sparse[k, epoch]`` is set to ``model.models[k].get_s_norm()``.
          ``sparse`` is sized with the module-level ``epochs``, so passing a
          larger ``epochs`` here would index past its columns.

    Args:
        model: Ensemble exposing ``models``, ``get_group_lasso`` and a forward
            pass returning one logits tensor per sub-model.
        train_loader: Iterable of ``(inputs, labels)`` training batches.
        test_loader: Iterable of ``(inputs, labels)`` evaluation batches.
        epochs: Number of passes over ``train_loader``.

    Returns:
        list: One array of per-sub-model test accuracies for every evaluated
        (even) epoch, in order.
    """
    model = model.to(device)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.05, momentum=0.9)

    accuracy = []
    for epoch in range(epochs):
        for x, y in train_loader:
            x = x.to(device)
            y = y.to(device)
            outputs = model(x)
            group_loss = model.get_group_lasso(epoch + 1)
            # Evaluate the regulariser once and reuse both of its components.
            reg_terms = regularizer(model, p=1)
            regular_loss = reg_terms[0] / reg_terms[1]

            total_loss = lambda1 * group_loss * (epochs + epoch * 2) / epochs
            total_loss = total_loss + lambda2 * regular_loss * (epochs + epoch * 2) / epochs
            for out in outputs:
                total_loss = total_loss + F.cross_entropy(out, y)

            total_loss.backward()
            optimizer.step()
            optimizer.zero_grad()

        if epoch % 2 == 0:
            # One counter per sub-model, sized from the model itself.
            correct_each = np.zeros(len(model.models), dtype=np.int64)
            total = 0
            # No gradients are needed for evaluation.
            # NOTE(review): the model stays in training mode here, as before;
            # if KPDLeNet contains dropout/batch-norm layers this should be
            # wrapped in model.eval() / model.train() -- confirm.
            with torch.no_grad():
                for x, y in test_loader:
                    x = x.to(device)
                    y = y.to(device)
                    for idx, out in enumerate(model(x)):
                        predicted = out.argmax(dim=1)
                        correct_each[idx] += (predicted == y).sum().item()
                    total += y.size(0)
            epoch_accuracy = correct_each / total
            acc.append(epoch_accuracy)
            accuracy.append(epoch_accuracy)

        for idx, sub_model in enumerate(model.models):
            sparse[idx, epoch] = sub_model.get_s_norm()
        print(sparse[:, epoch])

    return accuracy

train_loader = torch.utils.data.DataLoader(train_mnist, batch_size=batch_size, shuffle=True)
test_loader = torch.utils.data.DataLoader(test_mnist, batch_size=batch_size, shuffle=True)

train_model(model, train_loader, test_loader, epochs=epochs)

# Final test-set evaluation: count the correct predictions of every sub-model.
# The counter is sized from the model so that it has exactly one slot per
# sub-model; `correct_each` is reused further down to pick the best pattern.
correct_each = np.zeros(len(model.models), dtype=np.int64)
total = 0
with torch.no_grad():  # inference only, no autograd graph needed
    for x, y in test_loader:
        x = x.to(device)
        y = y.to(device)
        output = model(x)
        for i, out in enumerate(output):
            _, predicted = torch.max(out, 1)
            correct_each[i] += (predicted == y).sum().item()
        total += y.size(0)
print(correct_each/total)
#%%
# Sanitise the recorded norms: np.nan_to_num turns NaN into 0 (and infinities
# into large finite values).
n_sparse = np.nan_to_num(sparse)
#%%
# Manually zero the value recorded for pattern 0 at epoch index 90, then echo it.
n_sparse[0, 90] = 0
print(n_sparse[0, 90])
#%%
import matplotlib.pyplot as plt

# plt.plot draws one curve per column, so the (pattern, epoch) array is
# transposed to get one curve per pattern. Every recorded epoch is plotted.
abs_sparse = np.absolute(n_sparse)
show_out = np.transpose(abs_sparse)
plt.plot(show_out, linewidth=2.5)
plt.legend([f"k={k}" for k in range(1, 6)])
plt.xlabel("Epoch", fontsize=15)
plt.ylabel(r'$\sum_l^L ||S^{[l],(k)}||_1$', fontsize=15)
plt.grid()
# plt.title("Sparsity of the patterns")
plt.show()
#%%

import matplotlib.pyplot as plt

# Plot the test accuracy recorded for every pattern during training.
acc = np.array(acc)
# train_model evaluates on even epochs only, so row k belongs to epoch 2 * k.
eval_epochs = np.arange(len(acc)) * 2
for i, pattern in enumerate(patterns):
    # Label each curve with its chain of sizes: the first entry of every pair
    # followed by the last entry of the final pair, e.g.
    # [[16, 12], [12, 7], [7, 2]] -> "16-12-7-2".
    label = '-'.join(str(size) for size in [pair[0] for pair in pattern] + [pattern[-1][1]])
    plt.plot(eval_epochs, acc[:, i], label=label)
plt.legend()
plt.xlabel('Epoch')
plt.ylabel('Accuracy')
plt.grid()
plt.show()

#%%
# For every sub-model, print the mean absolute value of the `s` tensor of each
# of its three Kronecker layers, summed over the layers.
# `s_` is reset for every sub-model before it is read, so this cell also works
# on a fresh top-to-bottom run.
for i in model.models:
    s_ = 0
    for layer in (i.kron_fc1, i.kron_fc2, i.kron_fc3):
        s_ += torch.sum(torch.abs(layer.s)) / torch.numel(layer.s)
    print(s_)

# for i in model.models:
#     threshold = 1e-3
#     sparsity = 0
#     total = 0
#     if i.s is not None:
#         sparsity += torch.sum(torch.abs(i.s * i.a) < threshold).item()
#         total += (i.s * i.a).numel()
#     print(sparsity/total)
#%%
# Pick the pattern whose sub-model got the most test samples right in the
# final evaluation (ties go to the lowest index, as np.argmax does).
best_index = correct_each.argmax()
print(best_index)
best_pattern = patterns[best_index]
best_model = model.models[best_index]

# Build a stand-alone KronLinear and hand it the tensors of the winning model.
# NOTE(review): elsewhere in this file the sub-models are accessed through
# kron_fc1 / kron_fc2 / kron_fc3; whether they also expose `a`, `s` and `b`
# directly is not visible from here -- confirm against models/LeNet.py.
new_linear = KronLinear(784, 10, best_pattern, structured_sparse=True, rank=5)
for attr in ("a", "s", "b"):
    setattr(new_linear, attr, getattr(best_model, attr))
# Copy the bias when the winning model has one, otherwise disable it.
new_linear.bias = getattr(best_model, 'bias', None)
new_linear.to(device)

#%%
def finetune(module, train_loader, epoch=1):
    """Fine-tune ``module`` with SGD (lr 0.1, momentum 0.9) on flattened inputs.

    Every batch is moved to the device holding the module's parameters and
    flattened to ``(batch, features)`` before the forward pass, so the
    function works for any per-sample input size. The loss of each step is
    printed, labelled with the index of the current pass.

    Args:
        module: Classifier mapping ``(batch, features)`` inputs to logits.
        train_loader: Iterable of ``(inputs, labels)`` batches.
        epoch: Number of passes over ``train_loader``.

    Returns:
        list[float]: The loss of every optimisation step, in order.
    """
    optimizer = torch.optim.SGD(module.parameters(), lr=0.1, momentum=0.9)
    # The optimizer above already requires at least one parameter, so the
    # module's device can be read from its first parameter.
    target = next(module.parameters()).device

    losses = []
    for i in range(epoch):
        for x, y in train_loader:
            x = x.to(target)
            # Keep the batch dimension and flatten everything else.
            x = x.reshape(x.size(0), -1)
            y = y.to(target)
            output = module(x)
            loss = F.cross_entropy(output, y)
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()
            losses.append(loss.item())
            print(f"Epoch {i} Loss {loss.item()}")
    return losses
#%%
# finetune the new linear model for 5 epochs
finetune(new_linear, train_loader, 5)
#%%
# Test accuracy of the fine-tuned linear model on flattened MNIST images.
total = 0
correct = 0

with torch.no_grad():  # inference only, no autograd graph needed
    for x, y in test_loader:
        x = x.to(device)
        x = x.view(-1, 784)
        y = y.to(device)
        output = new_linear(x)
        _, predicted = torch.max(output, 1)
        correct += (predicted == y).sum().item()
        total += y.size(0)
print(correct/total)
#%%
