# Default mini-batch size for the script.
DEFAULT_BS = 64
# Batch size actually handed to the data loader; currently just the default.
batch_size = DEFAULT_BS

from torch.utils.data import Dataset, DataLoader, Subset
import torch
import torch.nn as nn
from torchvision import datasets, transforms
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

### Dataset
# Convert images to tensors, then normalise each of the three channels with
# mean 0.5 and standard deviation 0.5.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])

# CIFAR-10 train / test splits (downloaded into ./pytorch/data/ if missing).
train_data_full = datasets.CIFAR10(
    root='./pytorch/data/', train=True, download=True, transform=transform)
test_data_full = datasets.CIFAR10(
    root='./pytorch/data/', train=False, download=True, transform=transform)

classes = (
    'airplane', 'automobile', 'bird', 'cat',
    'deer', 'dog', 'frog', 'horse', 'ship', 'truck',
)

# Binary sub-problem: among the first 10000 training samples, keep only those
# labelled 0 or 1 (the first two entries of `classes`).
targets = torch.tensor(train_data_full.targets)
is_class_0_or_1 = (targets == 0) | (targets == 1)
subset_indices = torch.nonzero(is_class_0_or_1[:10000]).view(-1)
train_data = Subset(train_data_full, subset_indices)

# shuffle=False keeps the samples in dataset order.
# NOTE(review): presumably this order must match the row order of the saved
# kernel matrices loaded later -- confirm against the script that wrote them.
train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=False)
n1 = len(train_data)

# Walk the loader once and keep the image batches and label batches apart.
batches = list(train_loader)
xs = [images for images, _ in batches]
ys = [labels for _, labels in batches]

# Column vector of integer class labels (not one-hot encoded), shape (n, 1).
train_target = torch.cat(ys).reshape(-1, 1).numpy()

# Experiment configuration: root directory of the saved kernels, the seed used
# in the file names, and, per learning rate, the training step whose kernel
# is analysed.
path = './vit_models_NTKs/'
seed = 200
lr_list = [0.005, 0.01, 0.02]
final_step_list = [740, 460, 400]

# Alternative configuration (presumably the VGG11 runs, going by the path).
# path = './vgg11_models_NTKs/'
# seed = 0
# lr_list = [0.005,0.01,0.015,0.02]
# final_step_list = [220, 160, 120, 100]

num_lr = len(lr_list)
# Affine relabelling y -> 2*y - 1 (labels 0/1 become -1/+1).
train_target = 2 * train_target - 1
# Number of leading spectral components kept per learning rate.
num_index = 20
# Row k, column i: k-th component for lr_list[i].
coeff_target_array = np.zeros((num_index, num_lr))
normalized_eigval_array = np.zeros((num_index, num_lr))

# The target norm does not depend on the learning rate; compute it once.
target_norm = np.linalg.norm(train_target)

for i, ntk_lr in enumerate(lr_list):
    print(f'learning rate:{ntk_lr}')
    path_lr = f'{path}lr{ntk_lr}/'
    step = final_step_list[i]

    # Load the saved kernel matrix; the first CSV column holds the index.
    kernel_csv = f'{path_lr}NTK_seed{seed}_step{step}.csv'
    train_kernel_final = pd.read_csv(kernel_csv, index_col=0).to_numpy()

    # NOTE(review): the results are labelled eigenvalues/eigenvectors, which
    # matches the SVD only if the kernel is symmetric PSD -- confirm.
    left_vectors, singular_values, _ = np.linalg.svd(train_kernel_final)

    # Projection of the unit-normalised target onto each left singular vector.
    coeff_target_full = (left_vectors.T @ train_target) / target_norm
    coeff_target_array[:, i] = coeff_target_full[:num_index, 0]

    # Singular values rescaled so the full spectrum has unit Euclidean norm.
    normalized_eigval_full = singular_values / np.linalg.norm(singular_values)
    normalized_eigval_array[:, i] = normalized_eigval_full[:num_index]

def _draw_panel(position, title, xlabel, curves, **legend_kwargs):
    """Draw one subplot with one labelled curve per learning rate.

    `curves` yields one 1-D array per entry of `lr_list`, in the same order.
    Extra keyword arguments are forwarded to `plt.legend`.
    """
    plt.subplot(position)
    plt.title(title)
    plt.xlabel(xlabel)
    for lr, curve in zip(lr_list, curves):
        plt.plot(curve, label=f'lr={lr}')
    plt.legend(**legend_kwargs)


# Three stacked panels: spectrum, per-component alignment, cumulative alignment.
fig = plt.figure()

# Columns of the result arrays correspond to learning rates, so iterate over
# the transposed arrays to get one curve per learning rate.
squared_alignment = coeff_target_array ** 2

_draw_panel(311, 'Normalized eigenvalues', "Eigenvalue index",
            normalized_eigval_array.T)
_draw_panel(312, 'Individual Eigenvector Target Alignment',
            "Eigenvector Index", squared_alignment.T)
_draw_panel(313,
            'Cumulative Summation of Individual Eigenvector Target Alignment',
            "Eigenvector Index",
            [np.cumsum(column) for column in squared_alignment.T],
            loc='upper right')

plt.tight_layout()
plt.show()
