# Default mini-batch size of the NTK helpers (compute_ntk_matrix, compute_kernel_train_test).
DEFAULT_BS: int = 64
# Mini-batch size of this script's CIFAR-10 DataLoaders and of its NTK computations.
batch_size: int = 256

from torch.utils.data import Dataset, DataLoader, TensorDataset, Subset
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets, transforms
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import os
import math

# Compute device for models and batches: CUDA when available, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

def trainable_parameters(network: nn.Module):
    """Lazily yield the parameters of ``network`` whose ``requires_grad`` flag is set.

    Parameters are produced in ``network.parameters()`` order; frozen ones are skipped.
    """
    yield from (prm for prm in network.parameters() if prm.requires_grad)


def nparams(network: nn.Module) -> int:
    """Return the total number of scalar entries across the trainable parameters of ``network``.

    Only parameters with ``requires_grad`` set are counted (the same filter as
    ``trainable_parameters``). Counting via ``numel()`` avoids concatenating every
    parameter into a temporary vector. It also returns 0 for a fully frozen network,
    and works when parameters live on different devices.
    """
    return sum(prm.numel() for prm in network.parameters() if prm.requires_grad)

def vectorize(tensor: torch.Tensor):
    """Collapse all dimensions after the first into one, giving shape ``(tensor.size(0), -1)``."""
    leading = tensor.size(0)
    return tensor.reshape(leading, -1)

def grad_batch2vec(network: nn.Module):
    """Stack each trainable parameter's ``grad_batch`` into one ``(batch, p)`` matrix.

    ``grad_batch`` is not a standard PyTorch attribute. It presumably holds per-sample
    gradients set by an extension (e.g. BackPACK) — TODO confirm. Each entry is
    flattened to ``(batch, -1)``, detached, and the pieces are concatenated along
    dim 1 in ``trainable_parameters`` order.
    """
    flattened = [vectorize(prm.grad_batch).detach() for prm in trainable_parameters(network)]
    return torch.cat(flattened, dim=1)

def grad_2vec(network: nn.Module):
    """Flatten the current ``.grad`` of every trainable parameter into a single ``(1, p)`` row.

    Pieces are detached and concatenated in ``trainable_parameters`` order. Every
    trainable parameter must already have a gradient (``.grad`` not ``None``).
    """
    rows = [vectorize(prm.grad.unsqueeze(0)).detach() for prm in trainable_parameters(network)]
    return torch.cat(rows, dim=1)

def vec2param(vec: torch.Tensor, network: nn.Module):
    """Load consecutive slices of the flat vector ``vec`` into every parameter of ``network``.

    All parameters are filled (not only trainable ones), in ``network.parameters()``
    order. Each ``param.data`` is rebound to a reshaped slice of ``vec``, so parameters
    may share storage with ``vec``. Entries beyond the total parameter count are ignored.
    """
    offset = 0
    for prm in network.parameters():
        count = prm.numel()
        prm.data = vec[offset:offset + count].reshape(prm.shape)
        offset += count

def iterate_dataset_idx(dataset: Dataset, batch_size: int):
    """Walk ``dataset`` in order (no shuffling), one mini-batch at a time.

    Yields ``(X, y, start, end)``. ``X`` and ``y`` are the batch moved to ``device``,
    and ``[start, end)`` is the batch's index range within ``dataset``.
    """
    ordered = DataLoader(dataset, batch_size=batch_size, shuffle=False)
    for step, (inputs, labels) in enumerate(ordered):
        first = step * batch_size
        yield inputs.to(device), labels.to(device), first, first + len(inputs)

def compute_network_gradients(network: nn.Module, optimizer, dataset: Dataset, batch_size):
    """Return the per-example, per-output parameter Jacobian of ``network`` on ``dataset``.

    Args:
        network: model whose trainable parameters are differentiated.
        optimizer: used only for ``zero_grad()`` between backward passes.
        dataset: dataset yielding ``(x, y)`` pairs; ``y`` is not used.
        batch_size: number of examples per forward pass.

    Returns:
        A CPU tensor of shape ``(n, k, p)``:
        - ``n = len(dataset)``;
        - ``k`` is the network's output width;
        - ``p = nparams(network)``.
        Entry ``[e, c]`` is the gradient of output ``c`` for example ``e`` with respect
        to all trainable parameters, flattened in ``trainable_parameters`` order.

    The network runs forward once per mini-batch. Each scalar output is then
    back-propagated through the retained graph, rather than re-running the whole
    batch forward for every (example, output) pair.

    NOTE(review): outputs come from a forward pass over the whole mini-batch. If the
    network contains BatchNorm in training mode, an example's gradient therefore also
    depends on the other examples in its batch — confirm this is intended.
    """
    p = nparams(network)
    n = len(dataset)
    k = network(dataset[0][0].unsqueeze(0).to(device)).shape[1]
    gradients = torch.zeros(n, k, p)

    for X, _, start, _ in iterate_dataset_idx(dataset, batch_size):
        scores = network(X)  # (batch, k); graph is reused by every backward below
        for i in range(scores.shape[0]):
            for c in range(k):
                optimizer.zero_grad()
                # retain_graph lets the same forward graph serve all batch * k backwards.
                scores[i, c].backward(retain_graph=True)
                gradients[start + i, c, :] = grad_2vec(network)
        del scores  # release the retained graph before the next batch

    return gradients

def compute_1D_output_error(network: nn.Module, dataset: Dataset, batch_size):
    """Return ``(error, target)`` over ``dataset``, each a 1-D CPU tensor of length ``len(dataset)``.

    ``error[e] = network(x_e) - y_e`` and ``target[e] = y_e`` (as float). The network
    is expected to emit one value per example, i.e. output shape ``(batch, 1)``.

    The evaluation runs under ``torch.no_grad()``. Otherwise the slice assignments
    would record every batch's autograd graph on ``error`` and keep it alive.
    """
    n = len(dataset)
    error = torch.zeros(n, 1)
    target = torch.zeros(n, 1)

    with torch.no_grad():
        for X, y, start, end in iterate_dataset_idx(dataset, batch_size):
            output = network(X)
            y = y.float().reshape(y.size(0), 1)
            error[start:end] = output - y
            target[start:end] = y
    return error.reshape(-1), target.reshape(-1)


def compute_ntk_matrix(network: nn.Module, optimizer, dataset: Dataset, real_chunk_size: int = None, batch_size: int = DEFAULT_BS, verbose=True):
    """Compute the empirical NTK of ``network`` on ``dataset``, chunk by chunk.

    Args:
        network: model to differentiate; its output width ``k`` is probed with one example.
        optimizer: forwarded to ``compute_network_gradients`` (used only for ``zero_grad``).
        dataset: a ``TensorDataset``; chunks are cut from ``dataset.tensors``.
        real_chunk_size: maximum number of (example, output) pairs whose Jacobian rows
            are held at once. The per-chunk example count is ``real_chunk_size // k``
            (at least 1). ``None`` processes all ``n`` examples in one chunk.
        batch_size: mini-batch size for the forward passes.
        verbose: print the example range of every outer chunk.

    Returns:
        Tensor of shape ``(n, n, k, k)`` on ``device`` with
        ``ntk[a, b, c, d] = <grad f_c(x_a), grad f_d(x_b)>``.
    """
    n = len(dataset)
    k = network(dataset[0][0].unsqueeze(0).to(device)).shape[1]
    # max(1, ...) avoids a zero chunk size (ZeroDivisionError below) when real_chunk_size < k.
    chunk_size = n if real_chunk_size is None else max(1, real_chunk_size // k)
    ntk = torch.zeros(n, n, k, k).to(device)
    nchunks = math.ceil(float(n) / chunk_size)

    def _chunk(idx):
        # Examples [idx * chunk_size, (idx + 1) * chunk_size) as a TensorDataset.
        lo = idx * chunk_size
        return TensorDataset(*[tensor[lo:lo + chunk_size] for tensor in dataset.tensors])

    for i in range(nchunks):
        i_lo, i_hi = i * chunk_size, (i + 1) * chunk_size
        if verbose:
            print(f"examples {i_lo} through {min(i_hi, n)}", flush=True)
        # The row chunk's Jacobian does not depend on j: compute it once per i.
        Ji = compute_network_gradients(network, optimizer, _chunk(i), batch_size).detach()
        for j in range(i, nchunks):
            j_lo, j_hi = j * chunk_size, (j + 1) * chunk_size
            # On the diagonal the column Jacobian is the row Jacobian; reuse it.
            if j == i:
                Jj = Ji
            else:
                Jj = compute_network_gradients(network, optimizer, _chunk(j), batch_size).detach()

            # ntk_chunk[a, b, c, d] = <Ji[a, c], Jj[b, d]>, shape (|chunk i|, |chunk j|, k, k)
            ntk_chunk = torch.einsum("ikp,jlp->ijkl", Ji, Jj).to(device)
            ntk[i_lo:i_hi, j_lo:j_hi, :, :] = ntk_chunk
            # Mirror block: ntk[b, a, c, d] = <Jj[b, c], Ji[a, d]> = ntk_chunk[a, b, d, c].
            # Both the example axes AND the output axes must be swapped.
            ntk[j_lo:j_hi, i_lo:i_hi, :, :] = ntk_chunk.permute(1, 0, 3, 2)
            del Jj
        del Ji
    return ntk

def compute_kernel_train_test(network: nn.Module, optimizer, train_dataset: Dataset, test_dataset: Dataset, real_chunk_size: int = None, batch_size: int = DEFAULT_BS, verbose=True):
    """Compute the empirical cross-NTK between test and train examples.

    Args:
        network: model to differentiate; its output width ``k`` is probed with one
            training example.
        optimizer: forwarded to ``compute_network_gradients`` (used only for ``zero_grad``).
        train_dataset, test_dataset: ``TensorDataset``s; chunks are cut from ``.tensors``.
        real_chunk_size: maximum number of (example, output) pairs per Jacobian chunk.
            The per-chunk example count is ``real_chunk_size // k`` (at least 1).
            ``None`` uses ``min(n_test, n_train)``.
        batch_size: mini-batch size for the forward passes.
        verbose: print the test-example range of every outer chunk.

    Returns:
        Tensor of shape ``(n_test, n_train, k, k)`` on ``device`` with
        ``kernel[a, b, c, d] = <grad f_c(test_a), grad f_d(train_b)>``.
    """
    n1 = len(test_dataset)
    n2 = len(train_dataset)
    k = network(train_dataset[0][0].unsqueeze(0).to(device)).shape[1]
    chunk_size = min(n1, n2) if real_chunk_size is None else real_chunk_size // k
    # Guard against a zero chunk size (real_chunk_size < k, or an empty dataset).
    chunk_size = max(1, chunk_size)
    kernel = torch.zeros(n1, n2, k, k).to(device)
    nchunks1 = math.ceil(float(n1) / chunk_size)
    nchunks2 = math.ceil(float(n2) / chunk_size)
    for i in range(nchunks1):
        i_lo, i_hi = i * chunk_size, (i + 1) * chunk_size
        if verbose:
            print(f"examples {i_lo} through {min(i_hi, n1)}", flush=True)
        # The test-chunk Jacobian does not depend on j: compute it once per i.
        chunk_i = TensorDataset(*[tensor[i_lo:i_hi] for tensor in test_dataset.tensors])
        Ji = compute_network_gradients(network, optimizer, chunk_i, batch_size).detach()
        for j in range(nchunks2):
            j_lo, j_hi = j * chunk_size, (j + 1) * chunk_size
            chunk_j = TensorDataset(*[tensor[j_lo:j_hi] for tensor in train_dataset.tensors])
            Jj = compute_network_gradients(network, optimizer, chunk_j, batch_size).detach()

            # kernel_chunk is (|test chunk|, |train chunk|, k, k)
            kernel_chunk = torch.einsum("ikp,jlp->ijkl", Ji, Jj).to(device)
            kernel[i_lo:i_hi, j_lo:j_hi, :, :] = kernel_chunk
            del Jj
        del Ji
    return kernel

### Dataset
# ToTensor followed by Normalize(mean=0.5, std=0.5) on every RGB channel.
_normalize = transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
transform = transforms.Compose([transforms.ToTensor(), _normalize])

# Both CIFAR-10 splits, downloaded on first use into a shared root directory.
_cifar_root = './pytorch/data/'
train_data_full = datasets.CIFAR10(root=_cifar_root, train=True, download=True, transform=transform)
test_data_full = datasets.CIFAR10(root=_cifar_root, train=False, download=True, transform=transform)

# Human-readable names of CIFAR-10 labels 0..9.
classes = ('airplane', 'automobile', 'bird', 'cat',
           'deer', 'dog', 'frog', 'horse', 'ship', 'truck')


def _binary_subset_tensor(full_dataset, limit=None):
    """Collect the label-0 / label-1 examples of a CIFAR-10 split into memory.

    Args:
        full_dataset: a torchvision CIFAR-10 dataset (must expose ``.targets``).
        limit: if given, only the first ``limit`` examples of the split are considered.

    Returns:
        ``(subset, loader, tensors)``:
        - ``subset``: the ``Subset`` of matching examples;
        - ``loader``: an unshuffled ``DataLoader`` over it;
        - ``tensors``: a ``TensorDataset`` holding all of its images and labels.
    """
    labels = torch.tensor(full_dataset.targets)
    if limit is not None:
        labels = labels[:limit]
    # Logical OR of the two class masks; labels 0/1 are 'airplane'/'automobile' in ``classes``.
    keep = ((labels == 0) | (labels == 1)).nonzero().view(-1)
    subset = Subset(full_dataset, keep)
    loader = DataLoader(subset, batch_size=batch_size, shuffle=False)
    images, targets = [], []
    for x, y in loader:
        images.append(x)
        targets.append(y)
    return subset, loader, TensorDataset(torch.cat(images), torch.cat(targets))


# Binary task on labels 0 vs 1. The training subset is drawn from the first 10000 training images only.
train_data, train_loader, train_data_tensor = _binary_subset_tensor(train_data_full, limit=10000)
test_data, test_loader, test_data_tensor = _binary_subset_tensor(test_data_full)

# Output directory for kernels, checkpoints and logs.
path = './vgg11_models_NTKs/'
os.makedirs(path, exist_ok=True)

### Model
class VGG(nn.Module):
    """VGG-style network: a convolutional ``features`` stack followed by a 3-layer MLP head.

    Args:
        features: module whose output flattens to exactly 512 values per example.
            The first classifier layer is ``nn.Linear(512, 512)``.
        num_classes: width of the final linear layer. The default of 1 yields one scalar
            score per example, which is what this script trains with MSE on +/-1 labels.
    """

    def __init__(self, features, num_classes=1):
        super().__init__()
        self.features = features
        self.classifier = nn.Sequential(
            nn.Linear(512, 512),
            nn.ReLU(True),
            nn.Linear(512, 512),
            nn.ReLU(True),
            nn.Linear(512, num_classes),
        )
        # Conv layers: N(0, 2 / fan_out) weights with fan_out = kh * kw * out_channels, and
        # zero bias. Linear layers keep PyTorch's default initialisation.
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                nn.init.normal_(m.weight, mean=0.0, std=math.sqrt(2.0 / fan_out))
                if m.bias is not None:  # convs built with bias=False have no bias tensor
                    nn.init.zeros_(m.bias)

    def forward(self, x):
        """Map a batch through ``features``, flatten per example, then apply ``classifier``."""
        x = self.features(x)
        x = torch.flatten(x, 1)
        return self.classifier(x)


def make_layers(cfg, batch_norm=False):
    """Build the convolutional feature stack described by ``cfg``.

    Each integer entry adds a 3x3 padded ``Conv2d`` with that many output channels,
    followed by an optional ``BatchNorm2d`` and an in-place ``ReLU``. Each ``'M'``
    adds a 2x2, stride-2 max-pool. The input is assumed to have 3 channels.
    """
    modules = []
    channels = 3
    for spec in cfg:
        if spec == 'M':
            modules.append(nn.MaxPool2d(kernel_size=2, stride=2))
            continue
        modules.append(nn.Conv2d(channels, spec, kernel_size=3, padding=1))
        if batch_norm:
            modules.append(nn.BatchNorm2d(spec))
        modules.append(nn.ReLU(inplace=True))
        channels = spec
    return nn.Sequential(*modules)


# Layer specifications consumed by make_layers():
#   - an int is the output-channel count of a 3x3 conv (+ optional BatchNorm) + ReLU;
#   - 'M' is a 2x2 max-pool with stride 2.
# Counting the convs plus the three Linear layers of VGG.classifier gives 11 ('A'),
# 13 ('B'), 16 ('D') and 19 ('E') weight layers. Every config ends with 512 channels
# after five pools, so a 32x32 input reaches VGG.classifier as 512 x 1 x 1 = 512 features.
cfg = {
    'A': [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
    'B': [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
    'D': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'],
    'E': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M',
          512, 512, 512, 512, 'M'],
}


def vgg11():
    """Build the plain VGG-11 network (configuration "A") without batch normalization."""
    layers = make_layers(cfg['A'], batch_norm=False)
    return VGG(layers)


def vgg11_bn():
    """Build VGG-11 (configuration "A") with a BatchNorm2d after every convolution."""
    layers = make_layers(cfg['A'], batch_norm=True)
    return VGG(layers)


### Train and Calculate NTK
seed = 0
torch.manual_seed(seed)  # makes the model initialisation below reproducible
learning_rate = 0.01
num_epochs = 200
# Per-epoch metric histories, plotted after training.
train_loss_list = np.zeros(num_epochs)
test_loss_list = np.zeros(num_epochs)
train_acc_list = np.zeros(num_epochs)
test_acc_list = np.zeros(num_epochs)

n = len(train_data)  # number of training examples; the saved NTK is n x n
criterion = nn.MSELoss()
model = vgg11_bn().to(device)
p = nparams(model)  # number of trainable parameters
path_lr = f"{path}lr{learning_rate}/"
# exist_ok avoids the check-then-create race and also creates a missing parent directory.
os.makedirs(path_lr, exist_ok=True)

# Stepped once per epoch on the gradient accumulated over the whole training set
# (full-batch gradient descent).
optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)

with open(path_lr + 'loss.txt', 'w') as fo1:
    for epoch in range(num_epochs):
        if epoch % 20 == 0:
            # Snapshot, taken before this epoch's update: empirical NTK on the training
            # set, model weights, and the NTK spectrum.
            # NOTE(review): each Jacobian chunk is a dense CPU tensor of 1000 x p floats
            # (see compute_network_gradients) — check memory for large models.
            ntk = compute_ntk_matrix(model, optimizer, train_data_tensor, 1000, batch_size)
            ntk = ntk.reshape(n, n)  # single-output model: (n, n, 1, 1) -> (n, n)
            # Presumably the operator of the linearised GD dynamics for the mean-squared loss,
            # df/dt = -(2/n) * NTK (f - y) — confirm the intended scaling.
            ntk_2 = 2 * ntk / n
            pd.DataFrame(ntk_2.cpu().numpy()).to_csv(
                path_lr + 'NTK_seed' + str(seed) + '_step' + str(epoch) + '.csv')

            model_path = path_lr + 'model_step' + str(epoch) + '.pth.tar'
            torch.save({'model_state_dict': model.state_dict()}, model_path)

            # The kernel is symmetric (compute_ntk_matrix fills both triangles). The symmetric
            # eigensolver therefore returns real eigenvalues directly (in ascending order),
            # instead of a general complex eig whose imaginary part would be discarded.
            eigenvalues = torch.linalg.eigvalsh(ntk_2)
            with open(path_lr + 'eigval_seed' + str(seed) + '_step' + str(epoch) + '.txt', 'w') as fo2:
                fo2.writelines(f"{val}\n" for val in eigenvalues.tolist())

        # ---- Evaluation on the test subset (no parameter update, no autograd graph) ----
        # NOTE(review): the model is never switched to eval() mode, so BatchNorm normalises
        # with each test batch's own statistics (and updates its running statistics).
        # Confirm this is intended.
        test_loss = 0
        correct = 0
        with torch.no_grad():
            for images, labels in test_loader:
                images, labels = images.to(device), labels.to(device)
                labels = (2 * labels - 1).reshape(labels.size(0), 1)  # {0, 1} -> {-1, +1}
                labels_fl = labels.float()

                outputs = model(images)
                pred = outputs.sign().int().view_as(labels)
                correct += pred.eq(labels).sum().item()
                test_loss += (criterion(outputs, labels_fl) / len(test_loader)).item()
        test_acc = 100. * correct / len(test_loader.dataset)

        # ---- Full-batch gradient on the training subset ----
        # Gradients of (batch MSE / number of batches) accumulate over every batch;
        # a single SGD step follows below.
        train_loss = 0
        correct = 0
        optimizer.zero_grad()
        for images, labels in train_loader:
            images, labels = images.to(device), labels.to(device)
            labels = (2 * labels - 1).reshape(labels.size(0), 1)  # {0, 1} -> {-1, +1}
            labels_fl = labels.float()

            outputs = model(images)
            pred = outputs.detach().sign().int().view_as(labels)
            correct += pred.eq(labels).sum().item()

            loss = criterion(outputs, labels_fl) / len(train_loader)
            train_loss += loss.item()
            loss.backward()
        train_acc = 100. * correct / len(train_loader.dataset)

        train_loss_list[epoch] = train_loss
        test_loss_list[epoch] = test_loss
        train_acc_list[epoch] = train_acc
        test_acc_list[epoch] = test_acc
        print('Iteration: [{}/{}], Training Loss: {:.4f}, Training Acc: {:.2f}%, Test Loss: {:.4f}, Test Acc: {:.2f}%'
              .format(epoch, num_epochs, train_loss, train_acc, test_loss, test_acc), flush=True)

        optimizer.step()
        fo1.write(str(train_loss) + '\t' + str(test_loss) + '\t' + str(train_acc) + '\t' + str(test_acc) + '\n')

# Loss (top) and accuracy (bottom) curves for every epoch, saved next to the logs.
x_axis = np.arange(num_epochs)
fig, (ax_loss, ax_acc) = plt.subplots(2, 1)

ax_loss.set_title('Training and test losses, lr=' + str(learning_rate))
ax_loss.set_ylabel("Loss")
ax_loss.set_xlabel("Iteration")
ax_loss.plot(x_axis, train_loss_list, label='Train')
ax_loss.plot(x_axis, test_loss_list, label='Test')
ax_loss.legend()

ax_acc.set_title('Training and test accuracy, lr=' + str(learning_rate))
ax_acc.set_ylabel("Accuracy(%)")
ax_acc.set_xlabel("Iteration")
ax_acc.plot(x_axis, train_acc_list, label='Train')
ax_acc.plot(x_axis, test_acc_list, label='Test')
ax_acc.legend()

# Without this, the lower subplot's title overlaps the upper subplot's x-label.
fig.tight_layout()
fig.savefig(path_lr + 'Loss_Acc.png')
plt.show()


