# Default value of the ``batch_size`` argument of ``compute_ntk_matrix``.
DEFAULT_BS = 64
# Mini-batch size used by the module-level ``train_loader`` further down.
batch_size = 64

from torch.utils.data import Dataset, DataLoader, TensorDataset, Subset
import torch
import torch.nn as nn
from torchvision import datasets, transforms
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import math

# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

def trainable_parameters(network: nn.Module):
    """Lazily yield the parameters of ``network`` that require gradients.

    Parameters come out in the order of ``network.parameters()``; any
    parameter whose ``requires_grad`` flag is false is skipped.
    """
    yield from (p for p in network.parameters() if p.requires_grad)

def nparams(network: nn.Module):
    """Return the total number of scalar entries in the trainable parameters.

    The count is the sum of ``numel()`` over ``trainable_parameters(network)``,
    so no flattened copy of the parameters is built and a network without any
    trainable parameter yields ``0``.
    """
    return sum(param.numel() for param in trainable_parameters(network))

def vectorize(tensor: torch.Tensor):
    """Flatten every dimension of ``tensor`` except the first.

    The result is 2-D: the first dimension is kept and all remaining
    dimensions are merged into the second one.
    """
    leading_dim = tensor.size(0)
    flattened = tensor.reshape(leading_dim, -1)
    return flattened


def grad_batch2vec(network: nn.Module):
    """Join the ``grad_batch`` attributes of all trainable parameters.

    Each parameter's ``grad_batch`` is flattened to 2-D with ``vectorize``
    (first dimension kept), detached, and the pieces are concatenated along
    dimension 1.

    NOTE(review): nothing in this file sets ``grad_batch``; presumably an
    external per-sample-gradient extension populates it with the batch as the
    first dimension -- confirm against the caller.
    """
    pieces = tuple(
        vectorize(p.grad_batch).detach()
        for p in trainable_parameters(network)
    )
    return torch.cat(pieces, dim=1)

def grad_2vec(network: nn.Module):
    """Flatten the current ``.grad`` of every trainable parameter into one row.

    Each gradient gets a leading dimension of size 1, is flattened by
    ``vectorize`` and detached; the rows are concatenated along dimension 1,
    so the result has shape ``(1, total number of trainable entries)``.
    A trainable parameter whose ``.grad`` is ``None`` makes this fail.
    """
    rows = tuple(
        vectorize(p.grad.unsqueeze(0)).detach()
        for p in trainable_parameters(network)
    )
    return torch.cat(rows, dim=1)

def vec2param(vec: torch.Tensor, network: nn.Module):
    """Overwrite the parameters of ``network`` with consecutive slices of ``vec``.

    Parameters are visited in ``network.parameters()`` order; each one takes
    the next ``numel()`` entries of ``vec``, reshaped to the parameter's shape,
    as its new ``.data``.

    Note that this walks *all* parameters, including frozen ones, whereas the
    other helpers in this file only look at ``trainable_parameters``.
    NOTE(review): the assigned data is a reshaped slice of ``vec`` and may
    share storage with it -- confirm callers do not mutate ``vec`` afterwards.
    """
    offset = 0
    for p in network.parameters():
        stop = offset + p.numel()
        p.data = vec[offset:stop].reshape(p.shape)
        offset = stop

def iterate_dataset_idx(dataset: Dataset, batch_size: int):
    """Walk ``dataset`` in order, yielding each batch with its index range.

    Yields tuples ``(X, y, start, end)`` where ``X`` and ``y`` have been moved
    to the module-level ``device`` and ``start``/``end`` are the half-open
    positions of the batch within the dataset. The loader does not shuffle,
    so the ranges are contiguous and increasing.
    """
    batches = DataLoader(dataset, batch_size=batch_size, shuffle=False)
    for batch_no, (inputs, labels) in enumerate(batches):
        first = batch_no * batch_size
        last = first + len(inputs)
        yield inputs.to(device), labels.to(device), first, last


def compute_output(network: nn.Module, k: int, dataset: Dataset, batch_size):
    """Evaluate ``network`` on every example of ``dataset``.

    Args:
        network: model to evaluate.
        k: number of outputs the network produces per example.
        dataset: dataset yielding ``(X, y)`` pairs; the labels are ignored.
        batch_size: number of examples per forward pass.

    Returns:
        A tensor of shape ``(n * k, 1)`` holding the outputs of all ``n``
        examples, example-major, with no autograd history attached.
    """
    n = len(dataset)
    output_list = torch.zeros(n, k)

    # The result carries no gradient information, so run the forward passes
    # without autograd; otherwise the in-place writes below would keep every
    # batch's computation graph alive until the function returns.
    with torch.no_grad():
        for (X, _, start, end) in iterate_dataset_idx(dataset, batch_size):
            output_list[start:end, :] = network(X)
    return output_list.detach().reshape(-1, 1)

def compute_network_gradients(network: nn.Module, optimizer, dataset: Dataset, batch_size):
    """Compute the gradient of every network output for every example.

    Args:
        network: model whose trainable-parameter gradients are collected.
        optimizer: only used for ``zero_grad()`` before each backward pass.
        dataset: dataset yielding ``(X, y)`` pairs.
        batch_size: number of examples per forward pass.

    Returns:
        A tensor of shape ``(n, k, p)``: entry ``[a, c, :]`` is the flattened
        gradient of output ``c`` at example ``a`` with respect to the ``p``
        trainable parameter entries. ``k`` is read from a forward pass on the
        first example.

    Every (example, output) pair reruns the forward pass on the whole batch
    containing the example before backpropagating its single score.
    """
    num_params = nparams(network)
    num_examples = len(dataset)
    probe = dataset[0][0].unsqueeze(0).to(device)
    num_outputs = network(probe).shape[1]
    jacobian = torch.zeros(num_examples, num_outputs, num_params)

    for out_idx in range(num_outputs):
        for (X, y, start, _end) in iterate_dataset_idx(dataset, batch_size):
            for offset in range(len(y)):
                # Clear stale gradients so .grad holds only this score's gradient.
                optimizer.zero_grad()
                score = network(X)[offset, out_idx]
                score.backward()
                jacobian[start + offset, out_idx, :] = grad_2vec(network)

    return jacobian


def compute_1D_output_error(network: nn.Module, dataset: Dataset, batch_size):
    """Return the residuals ``network(X) - y`` and the targets for ``dataset``.

    The labels are cast to float and reshaped to a column before subtraction,
    so the network output is expected to be broadcast-compatible with shape
    ``(batch, 1)``.

    Returns:
        ``(error, target)``, both 1-D tensors with one entry per example.
        ``error`` is not detached here.
    """
    num_examples = len(dataset)
    residuals = torch.zeros(num_examples, 1)
    labels = torch.zeros(num_examples, 1)

    for (X, y, start, end) in iterate_dataset_idx(dataset, batch_size):
        prediction = network(X)
        column_y = y.float().reshape(y.size(0), 1)
        residuals[start:end] = prediction - column_y
        labels[start:end] = column_y

    flat_error = residuals.reshape(-1)
    flat_target = labels.reshape(-1)
    return flat_error, flat_target


def compute_ntk_matrix(network: nn.Module, optimizer, dataset: Dataset, real_chunk_size: int = None, batch_size: int = DEFAULT_BS, verbose=True):
    """Compute the empirical neural tangent kernel of ``network`` on ``dataset``.

    Args:
        network: model whose output gradients define the kernel.
        optimizer: forwarded to ``compute_network_gradients`` (used there to
            zero the gradients).
        dataset: must expose a ``.tensors`` attribute (e.g. a
            ``TensorDataset``), because chunks are built by slicing it.
        real_chunk_size: maximum number of example*class pairs to process at
            once; ``None`` processes all examples in a single chunk.
        batch_size: batch size used inside ``compute_network_gradients``.
        verbose: print the example range of each row chunk.

    Returns:
        A tensor ``ntk`` of shape ``(n, n, k, k)`` where ``ntk[a, b, c, d]`` is
        the inner product between the gradient of output ``c`` at example
        ``a`` and the gradient of output ``d`` at example ``b``.
    """
    n = len(dataset)
    k = network(dataset[0][0].unsqueeze(0).to(device)).shape[1]
    # chunk_size = maximum number of examples to process at once. Keep it at
    # least 1 so a real_chunk_size smaller than k cannot produce a zero
    # chunk size (and a division by zero below).
    chunk_size = n if real_chunk_size is None else max(1, real_chunk_size // k)
    ntk = torch.zeros(n, n, k, k)
    nchunks = math.ceil(float(n) / chunk_size)
    for i in range(nchunks):
        if verbose:
            print(f"examples {i*chunk_size} through {min((i+1)*chunk_size, n)}", flush=True)
        rows = slice(i * chunk_size, (i + 1) * chunk_size)
        # The row-chunk Jacobian does not depend on j, so compute it once.
        chunk_i = TensorDataset(*[tensor[rows] for tensor in dataset.tensors])
        Ji = compute_network_gradients(network, optimizer, chunk_i, batch_size).detach()
        for j in range(i, nchunks):
            cols = slice(j * chunk_size, (j + 1) * chunk_size)
            if j == i:
                # Diagonal block: both sides use the same Jacobian.
                Jj = Ji
            else:
                chunk_j = TensorDataset(*[tensor[cols] for tensor in dataset.tensors])
                Jj = compute_network_gradients(network, optimizer, chunk_j, batch_size).detach()

            # ntk_chunk is [example_chunk_size, example_chunk_size, k, k]
            ntk_chunk = torch.einsum("ikp,jlp->ijkl", Ji, Jj)
            ntk[rows, cols, :, :] = ntk_chunk
            if j != i:
                # Mirror into the lower triangle. Swapping the two examples
                # also swaps which output belongs to which example, so the
                # class axes must be exchanged together with the example axes.
                ntk[cols, rows, :, :] = ntk_chunk.permute(1, 0, 3, 2)

            del Jj
        del Ji
    return ntk

### Dataset
# Convert images to tensors, then normalize each of the three channels with
# mean 0.5 and standard deviation 0.5.
transform = transforms.Compose(
    [transforms.ToTensor(),
     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

# CIFAR-10 training split (downloaded to ./pytorch/data/ if missing).
train_data_full = datasets.CIFAR10(root='./pytorch/data/', train=True,
                                        download=True, transform=transform)

# CIFAR-10 test split, loaded with the same transform.
test_data_full = datasets.CIFAR10(root='./pytorch/data/', train=False,
                                       download=True, transform=transform)

# Human-readable class names, indexed by integer label.
classes = ('airplane', 'automobile', 'bird', 'cat',
           'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

# Build a binary subset: keep examples labelled 0 or 1 (the first two entries
# of ``classes``) among the first 10000 training examples.
targets = torch.tensor(train_data_full.targets)
# Adding the two boolean masks acts as an element-wise OR.
subset_indices = ((targets == 0)+(targets == 1))
# Positions (within the first 10000 examples) where the mask is true.
subset_indices = subset_indices[0:10000].nonzero().view(-1)
train_data = Subset(train_data_full,subset_indices)
# No shuffling, so the concatenated tensors below keep the subset order.
train_loader = torch.utils.data.DataLoader(train_data, batch_size=batch_size, shuffle=False)

# Materialize the whole subset as two tensors by concatenating all batches.
xs, ys = [], []
for x, y in train_loader:
    xs.append(x)
    ys.append(y)
#
train_x = torch.cat(xs)
train_y = torch.cat(ys)
# Map labels {0, 1} to {-1, +1}, as a float64 column vector.
train_y = (2*train_y-1).reshape(-1,1).double()
# Target vector used in the kernel-target-alignment computation below.
target_y = train_y

# Alternative experiment configuration, currently disabled.
# NOTE(review): presumably the VGG11 runs, judging by the path -- confirm.
# path = './vgg11_models_NTKs/'
# seed = 0
# lr_list = [0.005,0.01,0.015,0.02]
# start_step = [0, 0, 0, 0]
# final_step = [220, 160, 120, 100]
# interval = 20

# Active experiment configuration.
# Spacing, in training steps, between consecutive saved kernels.
interval = 20
# Root directory holding one 'lr<value>/' sub-directory per learning rate.
path = './vit_models_NTKs/'
# Seed that appears in the kernel file names.
seed = 200
# Learning rates to analyse; start_step/final_step are indexed in parallel.
lr_list = [0.005,0.01,0.02]
# First and last step to load for each learning rate.
start_step = [0,0,0]
final_step = [740,460,400]

# One array of alignment values per learning rate, in ``lr_list`` order.
alignment_values = []

for i in range(len(lr_list)):
    learning_rate = lr_list[i]
    path_lr = path + 'lr' + str(learning_rate) + '/'
    # Steps start_step[i], start_step[i] + interval, ..., final_step[i]
    # (evenly spaced as long as the range is a multiple of ``interval``).
    epoch_list = np.linspace(start_step[i], final_step[i], int((final_step[i] - start_step[i]) / interval) + 1, dtype=int)
    alignment_list = np.zeros_like(epoch_list, dtype=float)

    for j in range(len(epoch_list)):
        epoch = epoch_list[j]

        # Load the kernel saved for this step; the first CSV column is the index.
        # NOTE(review): assumes the kernel is n x n with rows ordered like
        # ``target_y`` -- confirm against the code that wrote the CSV.
        kernel_df = pd.read_csv(path_lr + 'NTK_seed' + str(seed) + '_step' + str(epoch) + '.csv', index_col=0)
        kernel = kernel_df.to_numpy()
        kernel_torch = torch.tensor(kernel)

        # Kernel-target alignment: y^T K y / (sum(y^2) * ||K||), where
        # ``norm()`` with no arguments is the Frobenius norm of the matrix.
        alignment = target_y.T @ kernel_torch @ target_y / torch.sum(target_y ** 2) / kernel_torch.norm()
        alignment_list[j] = alignment.item()
        print('lr:{},step:{}, alignment1:{}'.format(learning_rate, epoch, alignment_list[j]))
    alignment_values.append(alignment_list)

# Plot the kernel-target alignment against the training step, one curve per
# learning rate.
plt.figure()
plt.ylabel("KTA")
plt.title('Kernel Target Alignment')
plt.xlabel("Iteration")
for i in range(len(lr_list)):
    learning_rate = lr_list[i]
    # Same step grid as used when the alignment values were computed.
    epoch_list = np.linspace(start_step[i], final_step[i], int((final_step[i] - start_step[i]) / interval) + 1, dtype=int)
    plt.plot(epoch_list, alignment_values[i], label='lr' + str(learning_rate))
plt.legend()
plt.show()