# Default value of the ``batch_size`` argument of the kernel/NTK functions below.
DEFAULT_BS = 64
# Batch size used by the script for its data loaders and NTK computation.
batch_size = DEFAULT_BS

from torch.utils.data import Dataset, DataLoader, TensorDataset, Subset
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets, transforms
from backpack import backpack, extend
from backpack.extensions import BatchGrad, BatchL2Grad
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import os
import math
from vit_pytorch import SimpleViT

# Use the GPU when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

def trainable_parameters(network: nn.Module):
    """Yield the parameters of ``network`` that have ``requires_grad`` set.

    Parameters are produced lazily, in the order given by
    ``network.parameters()``.
    """
    yield from filter(lambda p: p.requires_grad, network.parameters())


def nparams(network: nn.Module):
    """Return the total number of scalar entries in the trainable parameters.

    Only parameters with ``requires_grad`` set are counted. The count is
    obtained from ``numel()`` directly, so no flattened copy of the weights
    is allocated and a module without trainable parameters yields ``0``.
    """
    return sum(
        param.numel() for param in network.parameters() if param.requires_grad
    )

def vectorize(tensor: torch.Tensor):
    """Flatten all trailing dimensions of ``tensor`` into a single one.

    The leading dimension is kept, so an input of shape ``(d0, d1, ..., dm)``
    comes back with shape ``(d0, d1 * ... * dm)``.
    """
    leading = tensor.shape[0]
    flat_shape = (leading, -1)
    return tensor.reshape(flat_shape)


def grad_batch2vec(network: nn.Module):
    """Concatenate the ``grad_batch`` attributes of all trainable parameters.

    Every trainable parameter must already carry a ``grad_batch`` attribute
    (``compute_loss_gradients`` populates it with a backward pass run under
    BackPACK's ``BatchGrad`` extension). Each one is flattened to two
    dimensions, keeping its leading dimension, and the pieces are joined
    along dim 1.

    Returns:
        A detached 2-D tensor whose rows follow the leading dimension of the
        ``grad_batch`` tensors.
    """
    flat_pieces = [
        vectorize(param.grad_batch).detach()
        for param in trainable_parameters(network)
    ]
    return torch.cat(flat_pieces, dim=1)

def grad_2vec(network: nn.Module):
    """Flatten the accumulated ``.grad`` of all trainable parameters into one row.

    Parameters are visited in ``trainable_parameters`` order; each gradient is
    flattened and the pieces are concatenated along dim 1.

    A trainable parameter whose ``.grad`` is ``None`` (for example because it
    did not take part in the graph that was differentiated after the gradients
    were reset) contributes a block of zeros, so the result always has one
    column per trainable scalar.

    Returns:
        A detached tensor of shape ``(1, P)`` where ``P`` is the number of
        trainable scalars in ``network``.
    """
    pieces = []
    for param in trainable_parameters(network):
        grad = param.grad
        if grad is None:
            # No gradient reached this parameter: its derivative is zero.
            grad = torch.zeros_like(param)
        pieces.append(vectorize(grad.unsqueeze(0)).detach())
    return torch.cat(pieces, dim=1)

def vec2param(vec: torch.Tensor, network: nn.Module):
    """Overwrite the parameters of ``network`` with consecutive slices of ``vec``.

    ``vec`` is consumed front to back in ``network.parameters()`` order (all
    parameters, not only the trainable ones). Each slice is reshaped to the
    parameter's shape and assigned to ``param.data``.
    """
    offset = 0
    for param in network.parameters():
        stop = offset + param.numel()
        param.data = vec[offset:stop].reshape(param.shape)
        offset = stop

def iterate_dataset_idx(dataset: Dataset, batch_size: int):
    """Yield device-resident batches together with their dataset index range.

    The dataset is traversed in order (no shuffling). Each item is a tuple
    ``(X, y, start, end)`` where ``X`` and ``y`` have been moved to ``device``
    and ``start`` / ``end`` are the half-open index bounds of the batch
    within ``dataset``.
    """
    loader = DataLoader(dataset, batch_size=batch_size, shuffle=False)
    for batch_no, (batch_X, batch_y) in enumerate(loader):
        first = batch_no * batch_size
        last = first + len(batch_X)
        yield batch_X.to(device), batch_y.to(device), first, last

def compute_network_gradients(network: nn.Module, optimizer, dataset: Dataset, batch_size):
    """Compute the gradient of every network output for every example.

    Args:
        network: Model mapping a batch to a 2-D output ``(batch, k)``.
        optimizer: Optimizer whose ``zero_grad()`` is used to reset the
            accumulated gradients before each backward pass.
        dataset: Dataset yielding ``(input, target)`` pairs; targets are only
            used to determine the batch length.
        batch_size: Number of examples per forward pass.

    Returns:
        A tensor ``gradients`` of shape ``(n, k, p)`` where
        ``gradients[a, c, :]`` is the gradient of output ``c`` on example
        ``a`` with respect to the trainable parameters, flattened as done by
        ``grad_2vec``.

    Note:
        As a side effect the ``.grad`` attributes of the parameters are
        overwritten; after the call they hold the gradient of the last
        output of the last example.
    """
    p = nparams(network)
    n = len(dataset)
    # Probe the number of outputs; no graph is needed for a shape query.
    with torch.no_grad():
        k = network(dataset[0][0].unsqueeze(0).to(device)).shape[1]
    gradients = torch.zeros(n, k, p)

    for (X, y, start, end) in iterate_dataset_idx(dataset, batch_size):
        # One forward pass per batch; the graph is kept alive so that every
        # scalar output can be back-propagated through it in turn instead of
        # re-running the whole batch for each (example, output) pair.
        scores = network(X)
        for i in range(len(y)):
            for c in range(k):
                optimizer.zero_grad()
                scores[i, c].backward(retain_graph=True)
                gradients[start + i, c, :] = grad_2vec(network)
        # Release the retained graph before the next batch is processed.
        del scores

    return gradients

def compute_loss_gradients(network: nn.Module, lossfunc, dataset: Dataset, batch_size):
    """Average the per-sample loss gradients over the whole dataset.

    The network and loss are wrapped with BackPACK's ``extend`` and each
    batch is back-propagated under the ``BatchGrad`` extension. The
    per-sample gradients are summed, rescaled by the batch length,
    accumulated, and finally divided by ``len(dataset)``.

    Args:
        network: Model to differentiate.
        lossfunc: Loss module applied to ``(network(X), y)``.
        dataset: Dataset yielding ``(input, target)`` pairs; targets are cast
            to float and reshaped to ``(batch, 1)``.
        batch_size: Number of examples per backward pass.

    Returns:
        A CPU tensor of shape ``(p, 1)`` with ``p = nparams(network)``.
    """
    p = nparams(network)
    gradients = torch.zeros(p, 1)

    network2 = extend(network)
    lossfunc2 = extend(lossfunc)
    for (X, y, start, end) in iterate_dataset_idx(dataset, batch_size):
        y = y.float().reshape(y.size(0), 1)  # (batch, 1)

        loss = lossfunc2(network2(X), y)
        with backpack(BatchGrad()):
            loss.backward()
        grad_batch = grad_batch2vec(network2).sum(dim=0, keepdim=True)
        # NOTE(review): the `* len(y)` factor presumably undoes the 1/batch
        # scaling of a mean-reduced loss -- confirm for other reductions.
        # The batch gradient lives on `device`; move it to the accumulator's
        # device so the in-place add does not fail when running on CUDA.
        gradients += grad_batch.T.to(gradients.device) * len(y)
    gradients = gradients / len(dataset)
    return gradients

def compute_1D_output_error(network: nn.Module, dataset: Dataset, batch_size):
    """Evaluate ``network(X) - y`` and ``y`` for every example of ``dataset``.

    Targets are cast to float and reshaped to ``(batch, 1)`` before the
    subtraction. The evaluation runs under ``torch.no_grad()`` so that no
    autograd graph is built or kept alive by the returned tensors.

    Returns:
        A pair ``(error, target)`` of 1-D CPU tensors of length
        ``len(dataset)``.
    """
    n = len(dataset)
    error = torch.zeros(n, 1)
    target = torch.zeros(n, 1)

    with torch.no_grad():
        for (X, y, start, end) in iterate_dataset_idx(dataset, batch_size):
            output = network(X)
            y = y.float().reshape(y.size(0), 1)
            error[start:end] = output - y
            target[start:end] = y
    return error.reshape(-1), target.reshape(-1)

def compute_kernel_train_test(network: nn.Module, optimizer, train_dataset: Dataset, test_dataset: Dataset, real_chunk_size: int = None, batch_size: int = DEFAULT_BS, verbose=True):
    """Compute the gradient kernel between test examples and training examples.

    Args:
        network: Model mapping a batch to a 2-D output ``(batch, k)``.
        optimizer: Forwarded to ``compute_network_gradients``.
        train_dataset: ``TensorDataset`` indexing the columns of the kernel.
        test_dataset: ``TensorDataset`` indexing the rows of the kernel.
        real_chunk_size: Maximum number of (example, output) pairs handled at
            once; ``None`` uses ``min(len(test_dataset), len(train_dataset))``
            examples per chunk.
        batch_size: Batch size used for the forward passes.
        verbose: Print the range of test examples being processed.

    Returns:
        A tensor ``kernel`` of shape ``(n_test, n_train, k, k)`` with
        ``kernel[a, b, c, d]`` the inner product between the gradient of
        output ``c`` on test example ``a`` and the gradient of output ``d``
        on training example ``b``.
    """
    n1 = len(test_dataset)
    n2 = len(train_dataset)
    with torch.no_grad():
        k = network(train_dataset[0][0].unsqueeze(0).to(device)).shape[1]
    chunk_size = min(n1, n2) if real_chunk_size is None else real_chunk_size // k
    # A budget smaller than k would give chunks of zero examples.
    chunk_size = max(1, chunk_size)
    kernel = torch.zeros(n1, n2, k, k)
    nchunks1 = math.ceil(float(n1) / chunk_size)
    nchunks2 = math.ceil(float(n2) / chunk_size)
    for i in range(nchunks1):
        if verbose:
            print(f"examples {i*chunk_size} through {min((i+1)*chunk_size, n1)}", flush=True)
        rows = slice(i * chunk_size, (i + 1) * chunk_size)
        # The test-side Jacobian does not depend on j: compute it once per row chunk.
        chunk_i = TensorDataset(*[tensor[rows] for tensor in test_dataset.tensors])
        Ji = compute_network_gradients(network, optimizer, chunk_i, batch_size).detach()
        for j in range(nchunks2):
            cols = slice(j * chunk_size, (j + 1) * chunk_size)
            chunk_j = TensorDataset(*[tensor[cols] for tensor in train_dataset.tensors])
            Jj = compute_network_gradients(network, optimizer, chunk_j, batch_size).detach()

            # kernel chunk is [rows_in_chunk, cols_in_chunk, k, k]
            kernel[rows, cols, :, :] = torch.einsum("ikp,jlp->ijkl", Ji, Jj)

            del Jj
        del Ji
    return kernel

# chunk_size = maximum number of examples to process at once
# real_chunk_size = maximum number of (example, output) pairs to process at once,
#                   i.e. chunk_size = real_chunk_size // k for a network with k outputs
def compute_ntk_matrix(network: nn.Module, optimizer, dataset: Dataset, real_chunk_size: int = None, batch_size: int = DEFAULT_BS, verbose=True):
    """Compute the neural tangent kernel of ``network`` on ``dataset``.

    Args:
        network: Model mapping a batch to a 2-D output ``(batch, k)``.
        optimizer: Forwarded to ``compute_network_gradients``.
        dataset: ``TensorDataset`` whose examples index both kernel axes.
        real_chunk_size: Maximum number of (example, output) pairs handled at
            once; ``None`` processes the whole dataset as a single chunk.
        batch_size: Batch size used for the forward passes.
        verbose: Print the range of examples being processed.

    Returns:
        A tensor ``ntk`` of shape ``(n, n, k, k)`` with ``ntk[a, b, c, d]``
        the inner product between the gradient of output ``c`` on example
        ``a`` and the gradient of output ``d`` on example ``b``.
    """
    n = len(dataset)
    with torch.no_grad():
        k = network(dataset[0][0].unsqueeze(0).to(device)).shape[1]
    chunk_size = n if real_chunk_size is None else real_chunk_size // k
    # A budget smaller than k would give chunks of zero examples.
    chunk_size = max(1, chunk_size)
    ntk = torch.zeros(n, n, k, k)
    nchunks = math.ceil(float(n) / chunk_size)
    for i in range(nchunks):
        if verbose:
            print(f"examples {i*chunk_size} through {min((i+1)*chunk_size, n)}", flush=True)
        rows = slice(i * chunk_size, (i + 1) * chunk_size)
        # The row-side Jacobian does not depend on j: compute it once per row chunk.
        chunk_i = TensorDataset(*[tensor[rows] for tensor in dataset.tensors])
        Ji = compute_network_gradients(network, optimizer, chunk_i, batch_size).detach()
        # Only the upper block triangle is computed; the rest is mirrored.
        for j in range(i, nchunks):
            cols = slice(j * chunk_size, (j + 1) * chunk_size)
            if j == i:
                # Diagonal block: both sides are the same chunk.
                Jj = Ji
            else:
                chunk_j = TensorDataset(*[tensor[cols] for tensor in dataset.tensors])
                Jj = compute_network_gradients(network, optimizer, chunk_j, batch_size).detach()

            # ntk_chunk is [example_chunk_size, example_chunk_size, k, k]
            ntk_chunk = torch.einsum("ikp,jlp->ijkl", Ji, Jj)
            ntk[rows, cols, :, :] = ntk_chunk
            if j != i:
                # Symmetry is ntk[b, a, d, c] == ntk[a, b, c, d]: mirroring a
                # block must swap the example axes AND the output axes.
                ntk[cols, rows, :, :] = ntk_chunk.permute(1, 0, 3, 2)

            del Jj
        del Ji
    return ntk


### Dataset
# Preprocessing applied to every image: convert to a tensor, then normalize
# each of the three channels with mean 0.5 and standard deviation 0.5.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
])

# Full CIFAR-10 training and test splits, stored under ./pytorch/data/.
train_data_full = datasets.CIFAR10(
    root='./pytorch/data/', train=True, download=True, transform=transform)
test_data_full = datasets.CIFAR10(
    root='./pytorch/data/', train=False, download=True, transform=transform)

# Class names; presumably indexed by the integer label -- not used below.
classes = (
    'airplane', 'automobile', 'bird', 'cat', 'deer',
    'dog', 'frog', 'horse', 'ship', 'truck',
)


# Binary training set: among the first 10000 training images, keep the ones
# whose label is 0 or 1.
targets = torch.tensor(train_data_full.targets)
subset_indices = (targets == 0) | (targets == 1)
subset_indices = subset_indices[:10000].nonzero().reshape(-1)
train_data = Subset(train_data_full, subset_indices)
train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=False)

# Gather every (already transformed) training batch so the subset can later
# be materialised as in-memory tensors.
xs, ys = [], []
for x, y in train_loader:
    xs.append(x)
    ys.append(y)

# Root output directory; the per-learning-rate result folder is created under it.
path = './vit_models_NTKs/'
# exist_ok avoids the check-then-create race of `os.path.exists` + `os.mkdir`.
os.makedirs(path, exist_ok=True)

# In-memory copy of the binary training subset (inputs and labels).
train_data_tensor = TensorDataset(torch.cat(xs), torch.cat(ys))

# Binary test set: every test image whose label is 0 or 1.
targets = torch.tensor(test_data_full.targets)
subset_indices = (targets == 0) | (targets == 1)
subset_indices = subset_indices.nonzero().reshape(-1)
test_data = Subset(test_data_full, subset_indices)
test_loader = DataLoader(test_data, batch_size=batch_size, shuffle=False)

# Gather every test batch and materialise the subset as in-memory tensors.
xs, ys = [], []
for x, y in test_loader:
    xs.append(x)
    ys.append(y)

test_data_tensor = TensorDataset(torch.cat(xs), torch.cat(ys))

### Model
# Seed the torch RNG before the model is constructed; the same seed value is
# also embedded in the names of the saved NTK / eigenvalue files.
seed = 200
torch.manual_seed(seed)

# SimpleViT with a single output, built on 32x32 inputs split into 4x4 patches.
model = SimpleViT(
    image_size=32,
    patch_size=4,
    num_classes=1,
    dim=64,
    depth=4,
    heads=8,
    mlp_dim=256,
)
model = model.to(device)

### Train and Calculate NTK
learning_rate = 0.01
num_steps = 501
# Per-step history of losses and accuracies, filled in by the training loop.
train_loss_list = np.zeros(num_steps)
test_loss_list = np.zeros(num_steps)
train_acc_list = np.zeros(num_steps)
test_acc_list = np.zeros(num_steps)
# Number of training / test examples in the binary subsets.
n = len(train_data)
n2 = len(test_data)
criterion = nn.MSELoss()

# All artefacts of this run are written below a folder named after the learning rate.
path_lr = path + 'lr'+str(learning_rate)+'/'
# exist_ok avoids the check-then-create race, and missing parents are created too.
os.makedirs(path_lr, exist_ok=True)

optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)
# Full-batch gradient descent: in every step the gradients of all training
# batches are accumulated and a single optimizer step is taken at the end.
# Every 20 steps the NTK on the training set, its eigenvalues and a model
# checkpoint are written to `path_lr`.
# The files are opened with `with` so they are closed even if a step fails.
with open(path_lr + 'loss.txt', 'w') as fo1:
    for epoch in range(num_steps):
        optimizer.zero_grad()
        if epoch % 20 == 0:
            ntk = compute_ntk_matrix(model, optimizer, train_data_tensor, None, batch_size)
            # The model has a single output, so the (n, n, 1, 1) kernel is an n x n matrix.
            ntk = ntk.reshape(n, n)
            ntk_2 = 2 * ntk / n
            ntk_np = ntk_2.cpu().numpy()
            ntk_df = pd.DataFrame(ntk_np)
            ntk_df.to_csv(path_lr + 'NTK_seed' + str(seed) + '_step' + str(epoch) + '.csv')

            model_path = path_lr + 'model_step' + str(epoch) + '.pth.tar'
            torch.save({'model_state_dict': model.state_dict()}, model_path)

            # NOTE(review): the kernel is a Gram matrix, so
            # `torch.linalg.eigvalsh` would be cheaper and return real, sorted
            # values; kept as-is so the order of the written values is unchanged.
            S = torch.linalg.eigvals(ntk_2)
            S_real = S.real
            with open(path_lr + 'eigval_seed' + str(seed) + '_step' + str(epoch) + '.txt', 'w') as fo2:
                fo2.writelines(str(value.item()) + '\n' for value in S_real)

        # ---- Evaluation on the test subset (no gradients needed) ----
        test_loss = 0
        correct = 0
        with torch.no_grad():
            for batch, (images, labels) in enumerate(test_loader):
                images, labels = images.to(device), labels.to(device)
                # Map the {0, 1} labels to {-1, +1} regression targets.
                labels = (2 * labels - 1).reshape(-1, 1)
                labels_fl = labels.float()  # (batch, 1)

                # Forward pass
                outputs = model(images)
                pred = outputs.data.sign().int().view_as(labels)
                correct += pred.eq(labels).sum().item()
                # Weight each batch mean by its size so the total is the exact
                # dataset mean even when the last batch is smaller.
                te_loss = criterion(outputs, labels_fl) * len(labels) / n2
                test_loss += te_loss.item()
        test_acc = 100. * correct / len(test_loader.dataset)

        # ---- Training pass: accumulate the full-batch gradient ----
        train_loss = 0
        correct = 0

        # The NTK computation above leaves stale gradients behind; clear them.
        optimizer.zero_grad()
        for j, (images, labels) in enumerate(train_loader):
            images, labels = images.to(device), labels.to(device)
            # Map the {0, 1} labels to {-1, +1} regression targets.
            labels = (2 * labels - 1).reshape(-1, 1)
            labels_fl = labels.float()  # (batch, 1)

            # Forward pass
            outputs = model(images)
            pred = outputs.data.sign().int().view_as(labels)
            correct += pred.eq(labels).sum().item()

            # Batch mean weighted by batch size / n: the accumulated loss (and
            # gradient) is the exact mean over the n training examples, which
            # is what the `2 * ntk / n` scaling above assumes.
            loss = criterion(outputs, labels_fl) * len(labels) / n
            train_loss += loss.item()

            # Backward pass: gradients accumulate across batches.
            loss.backward()
        train_acc = 100. * correct / len(train_loader.dataset)
        train_loss_list[epoch] = train_loss
        test_loss_list[epoch] = test_loss
        train_acc_list[epoch] = train_acc
        test_acc_list[epoch] = test_acc

        print('Iteration: [{}/{}], Training Loss: {:.4f}, Training Acc: {:.2f}%, Test Loss: {:.4f}, Test Acc: {:.2f}%'
              .format(epoch, num_steps, train_loss, train_acc, test_loss, test_acc))
        fo1.write(str(train_loss) + '\t' + str(test_loss) + '\t' + str(train_acc) + '\t' + str(test_acc) + '\n')
        # Single parameter update using the gradient accumulated over all batches.
        optimizer.step()

# Plot the recorded loss and accuracy curves in two stacked panels, save the
# figure next to the other results of this run, then display it.
x_axis = np.linspace(0, num_steps - 1, num_steps)
fig = plt.figure()

# (subplot position, title prefix, y-axis label, train curve, test curve)
panels = (
    (1, 'Training and test losses, lr=', "Loss", train_loss_list, test_loss_list),
    (2, 'Training and test accuracy, lr=', "Accuracy(%)", train_acc_list, test_acc_list),
)
for position, title_prefix, y_label, train_curve, test_curve in panels:
    plt.subplot(2, 1, position)
    plt.title(title_prefix + str(learning_rate))
    plt.ylabel(y_label)
    plt.xlabel("Iteration")
    plt.plot(x_axis, train_curve, label='Train')
    plt.plot(x_axis, test_curve, label='Test')
    plt.legend()

fig.savefig(path_lr + 'Loss_Acc.png')
plt.show()


