# -*- coding: utf-8 -*-


# IPython/Colab shell escape that installs torchattacks into the notebook
# environment. This line is not valid plain-Python syntax: run the file as
# notebook cells, or install the package beforehand and drop this line when
# executing it as a script.
!pip install torchattacks

# Parts of this code are copied directly from the PyTorch transfer-learning tutorial: https://pytorch.org/tutorials/beginner/transfer_learning_tutorial.html
import torch
import torchvision
import torchvision.transforms as transforms
# Load the "cifar10_vgg16_bn" entry of the chenyaofo/pytorch-cifar-models
# torch.hub repository with pretrained weights (fetched over the network the
# first time it runs).
model = torch.hub.load("chenyaofo/pytorch-cifar-models", "cifar10_vgg16_bn", pretrained=True)
# Alternative backbone (unused): the ImageNet VGG-16 from torchvision.
#model = torchvision.models.vgg16(pretrained=True)
# The VGG network is pre-trained, taken from https://github.com/chenyaofo/pytorch-cifar-models or from torchvision.
import torchattacks
import numpy as np
import matplotlib.pyplot as plt

# CIFAR10SubLoader keeps only (size of the two-class subset) // cifar10_div_factor
# randomly chosen samples of each split.
cifar10_div_factor = 200

class CIFAR10SubLoader(torchvision.datasets.CIFAR10):
    """CIFAR-10 restricted to a small random subset of two classes.

    All constructor arguments are forwarded unchanged to
    ``torchvision.datasets.CIFAR10``. After the base class has loaded the
    split, two filtering steps run:

    1. keep only samples whose label is in ``kept_classes`` (0 and 1);
    2. keep ``n // cifar10_div_factor`` of those, chosen uniformly at random
       *without replacement* using NumPy's global RNG (seed it with
       ``np.random.seed`` for reproducibility).

    Afterwards ``self.targets`` is a 1-D NumPy array rather than the base
    class's Python list, and ``self.data`` is indexed consistently with it.
    """

    # Labels retained by the class filter.
    kept_classes = (0, 1)

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        targets = np.asarray(self.targets)

        # Step 1: keep only the samples belonging to ``kept_classes``.
        class_mask = np.isin(targets, self.kept_classes)
        self.data = self.data[class_mask]
        targets = targets[class_mask]

        # Step 2: random subset of size n // cifar10_div_factor. Indices are
        # drawn without replacement so no image appears twice (drawing them
        # with np.random.randint would allow duplicates).
        subset_size = targets.shape[0] // cifar10_div_factor
        subset_idx = np.random.choice(targets.shape[0], size=subset_size, replace=False)
        self.data = self.data[subset_idx]
        self.targets = targets[subset_idx]

        print(self.data.shape)
        print(self.targets)

    def get_data(self):
        """Return the retained ``(data, targets)`` arrays."""
        return self.data, self.targets
# Build the two sub-sampled splits (train first, then test), each with its own
# ToTensor-only pipeline; no normalisation is applied to the images.
cifar_train_dataset, cifar_test_dataset = (
    CIFAR10SubLoader('./data', train=is_train, download=True,
                     transform=transforms.Compose([transforms.ToTensor()]))
    for is_train in (True, False)
)

# One shuffled batch per split: the batch size equals the number of retained
# samples, so each loader yields the whole split at once.
cifar_train_loader, cifar_test_loader = (
    torch.utils.data.DataLoader(split, batch_size=split.targets.shape[0], shuffle=True)
    for split in (cifar_train_dataset, cifar_test_dataset)
)

dataloaders = {'train': cifar_train_loader, 'val': cifar_test_loader}

# Sample counts per phase, used to average loss/accuracy during training.
dataset_sizes = {phase: len(loader.dataset) for phase, loader in dataloaders.items()}

import matplotlib.pyplot as plt
# Preview three images from the (single, shuffled) training batch.
# ToTensor() produces channel-first (C, H, W) tensors, so each image is
# permuted to (H, W, C); imshow then renders the actual RGB picture instead
# of a false-colour map of channel 0 (red) alone.
fig4, ax4 = plt.subplots()
fig5, ax5 = plt.subplots()
fig6, ax6 = plt.subplots()
batch = next(iter(cifar_train_loader))
n_preview_pool = batch[0].shape[0]
for ax, idx in ((ax4, 13), (ax5, 10), (ax6, 49)):
    # The subset can hold fewer images if cifar10_div_factor is increased.
    if idx < n_preview_pool:
        ax.imshow(batch[0][idx].permute(1, 2, 0))

import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler


# Fixed-feature-extractor transfer learning: freeze every parameter, then
# re-enable gradients only for the final classifier layer. Naming the layer
# keeps the trainable set identical to the parameters handed to the
# optimizer below, rather than depending on the ordering of
# model.parameters().
for param in model.parameters():
    param.requires_grad_(False)
for param in model.classifier[6].parameters():
    param.requires_grad_(True)


# Alternative (unused): swap the pretrained head for a 2-way classifier.
# Do not enable the Softmax line together with nn.CrossEntropyLoss, which
# already applies log-softmax to raw logits.
#model.classifier[6] = nn.Linear(model.classifier[6].in_features,2)
#model.classifier.append(nn.Softmax(dim=1))

print(model)
criterion = nn.CrossEntropyLoss()

# only parameters of final layer are being optimized
optimizer_conv = optim.SGD(model.classifier[6].parameters(), lr=0.008, momentum=0.2)

# StepLR multiplies the learning rate by gamma=0.01 (a 100x drop) every
# step_size=7 calls to scheduler.step().
exp_lr_scheduler = lr_scheduler.StepLR(optimizer_conv, step_size=7, gamma=0.01)

import time
import copy
import numpy as np

def train_model(model, criterion, optimizer, scheduler, num_epochs=25, *,
                loaders=None, sizes=None):
    """Train ``model`` and return it loaded with its best validation weights.

    Each epoch runs a ``'train'`` phase (forward, backward, optimizer step per
    batch, then a single ``scheduler.step()``) followed by a ``'val'`` phase
    (forward only, autograd disabled). Whenever validation accuracy strictly
    improves, the state dict is deep-copied; that copy is restored at the end.

    Args:
        model: network to train; switched between ``train()`` and ``eval()``.
        criterion: loss called as ``criterion(outputs, labels)``.
        optimizer: optimizer over the trainable parameters.
        scheduler: LR scheduler, stepped once per training epoch.
        num_epochs: number of train+val epochs.
        loaders: optional ``{'train': loader, 'val': loader}`` mapping.
            Defaults to the module-level ``dataloaders``.
        sizes: optional ``{'train': int, 'val': int}`` sample counts used to
            average loss and accuracy. Defaults to the module-level
            ``dataset_sizes`` when ``loaders`` is also omitted, otherwise to
            ``len(loader.dataset)`` for each phase.

    Returns:
        ``model`` itself, with the best-validation-accuracy weights loaded.
    """
    if loaders is None:
        loaders = dataloaders
        if sizes is None:
            sizes = dataset_sizes
    if sizes is None:
        sizes = {phase: len(loader.dataset) for phase, loader in loaders.items()}

    since = time.time()

    best_model_wts = copy.deepcopy(model.state_dict())
    best_acc = 0.0

    for epoch in range(num_epochs):
        print(f'Epoch {epoch}/{num_epochs - 1}')
        print('-' * 10)

        for phase in ['train', 'val']:
            # NOTE(review): model.train() also puts any BatchNorm layers inside
            # frozen parts of the network into training mode, so their running
            # statistics keep updating even though their parameters are frozen.
            # Switch those layers back to eval() if the pretrained statistics
            # should be preserved.
            if phase == 'train':
                model.train()
            else:
                model.eval()

            running_loss = 0.0
            running_corrects = 0

            for inputs, labels in loaders[phase]:
                optimizer.zero_grad()

                with torch.set_grad_enabled(phase == 'train'):
                    outputs = model(inputs)
                    _, preds = torch.max(outputs, 1)
                    loss = criterion(outputs, labels)

                    if phase == 'train':
                        loss.backward()
                        optimizer.step()

                running_loss += loss.item() * inputs.size(0)
                # Accumulate a plain int so the epoch metrics are Python floats
                # (and still defined if a loader yields no batches).
                running_corrects += torch.sum(preds == labels).item()
            if phase == 'train':
                scheduler.step()

            epoch_loss = running_loss / sizes[phase]
            epoch_acc = running_corrects / sizes[phase]

            print(f'{phase} Loss: {epoch_loss:.4f} Acc: {epoch_acc:.4f}')

            if phase == 'val' and epoch_acc > best_acc:
                best_acc = epoch_acc
                best_model_wts = copy.deepcopy(model.state_dict())

        print()

    time_elapsed = time.time() - since
    print(f'Training complete in {time_elapsed // 60:.0f}m {time_elapsed % 60:.0f}s')
    print(f'Best val Acc: {best_acc:.4f}')

    model.load_state_dict(best_model_wts)
    return model

model = train_model(model, criterion, optimizer_conv,
                    exp_lr_scheduler, num_epochs=20)

# Test-set accuracy of the fine-tuned model. Inference runs in eval mode,
# one whole batch per forward pass, and without building autograd graphs.
model.eval()
number_of_examples = 0
correct = 0
with torch.no_grad():
    for inputs, labels in cifar_test_loader:
        preds = model(inputs).argmax(dim=1)
        number_of_examples += labels.shape[0]
        correct += (preds == labels).sum().item()
print(correct / number_of_examples)

def compute_feature_map(data, net=None, num_layers=4):
    """Return intermediate classifier activations for a batch of images.

    Runs ``net.features`` on ``data``, flattens every dimension after the
    batch dimension, then applies the first ``num_layers`` modules of
    ``net.classifier`` in order.

    Args:
        data: batch of images accepted by ``net.features``.
        net: network exposing ``features`` and an indexable ``classifier``;
            defaults to the module-level ``model``.
        num_layers: number of leading ``classifier`` modules to apply
            (default 4).

    Returns:
        The activation tensor after the last applied classifier module,
        presumably ``(batch, features)`` for Linear/ReLU-style layers.
    """
    if net is None:
        net = model
    feature_map = net.features(data)
    # Flattening from dim 1 handles any spatial size; a plain
    # reshape((batch, channels)) would only work for a 1x1 spatial map.
    feature_map = torch.flatten(feature_map, 1)
    for i in range(num_layers):
        feature_map = net.classifier[i](feature_map)
    return feature_map

# Spectrum of the linear kernel K = (1/N) F F^T built from the feature maps of
# the test batch, before and after appending one adversarial (FAB) example.


def _kernel_matrix(feature_map, n_features):
    """Return the linear kernel ``(1 / n_features) * F @ F.T`` (one row per sample)."""
    return (1 / n_features) * torch.matmul(feature_map, feature_map.T)


def _clipped_spectrum(kernel, upper=0.2):
    """Return the eigenvalues of the symmetric ``kernel`` that lie in ``[0, upper)``.

    ``F @ F.T`` is symmetric positive semi-definite, so ``eigvalsh`` yields
    real eigenvalues directly; negative values can only be round-off and are
    clamped to 0 before filtering. Raises ``RuntimeError`` (including
    ``torch.linalg.LinAlgError``) if the solver does not converge.
    """
    eigs = torch.linalg.eigvalsh(kernel).detach().numpy()
    eigs[eigs < 0] = 0.
    return eigs[(eigs >= 0) & (eigs < upper)]


def _near_zero_integral(density):
    """Approximate ``∫ p(λ) / λ² dλ`` from ``np.histogram(..., density=True)`` output.

    Each bin contributes ``height * bin_width / right_edge**2``. The bin
    width is read from the edges: ``np.histogram`` spans each sample's own
    min..max, so a fixed ``1/100`` factor would scale every spectrum's
    integral by a different amount.
    """
    heights, edges = density
    return np.sum(heights * np.diff(edges) / edges[1:] ** 2)


def _plot_spectrum(eig_values, xlabel):
    """Show a 100-bin density histogram of ``eig_values`` on [0, 0.2] x [0, 1]."""
    plt.figure()
    plt.ylabel('Probability density')
    plt.xlabel(xlabel)
    plt.xlim((0, 0.2))
    plt.ylim((0, 1))
    plt.hist(list(eig_values), bins=100, density=True)
    plt.show()


# The test loader's batch size equals the split size, so this is the whole
# (shuffled) sub-sampled test set.
batch = next(iter(cifar_test_loader))
with torch.no_grad():  # features are only analysed, never differentiated
    original_feature_map = compute_feature_map(batch[0])
print(original_feature_map.shape)

N = original_feature_map.shape[1]  # feature dimension used to normalise K

K_original = _kernel_matrix(original_feature_map, N)
print(K_original.shape)

eig_values_K_original = _clipped_spectrum(K_original)
density_original = np.histogram(eig_values_K_original, bins=100, density=True)
print(density_original)
eig_values_K_original_list = list(eig_values_K_original)
print(eig_values_K_original_list)

_plot_spectrum(eig_values_K_original_list, 'Eigenvalues of original kernel matrix')

min_eig_values_original = min(eig_values_K_original_list)
integral_near_zero_original = _near_zero_integral(density_original)
print(integral_near_zero_original)
print(min_eig_values_original)

images = batch[0]
labels = batch[1]
fab_attack = torchattacks.FAB(model, norm='Linf', steps=10, eps=8/255, n_restarts=1, alpha_max=0.1, eta=1.05, beta=0.9, verbose=False, seed=0, n_classes=2)
adv_images = fab_attack(images, labels)

print(adv_images.shape)

min_eig_values_modified_list = []
integral_near_zero_modified_list = []
# Spectrum of the last successfully analysed adversarial example (used by the
# overlay plot below); stays empty if every example is skipped.
eig_values_K_modified_list = []
for i, adv_image in enumerate(adv_images):
    # Test batch with one adversarial image appended as an extra row.
    pert_image = adv_image.unsqueeze(0)
    modified_data_distribution = torch.cat([batch[0], pert_image], dim=0)
    with torch.no_grad():
        modified_feature_map = compute_feature_map(modified_data_distribution)
    print(modified_feature_map.shape)

    K_modified = _kernel_matrix(modified_feature_map, N)
    print(K_modified.shape)

    try:
        eig_values_K_modified = _clipped_spectrum(K_modified)
    except RuntimeError as err:
        print(f'Skipping adversarial example {i}: eigendecomposition failed ({err})')
        continue
    if eig_values_K_modified.size == 0:
        print(f'Skipping adversarial example {i}: no eigenvalues in [0, 0.2)')
        continue
    density_modified = np.histogram(eig_values_K_modified, bins=100, density=True)
    eig_values_K_modified_list = list(eig_values_K_modified)

    _plot_spectrum(eig_values_K_modified_list, 'Eigenvalues of modified kernel random matrix')

    min_eig_values_modified_list.append(min(eig_values_K_modified_list))
    print(min_eig_values_modified_list[-1])
    integral_near_zero_modified = _near_zero_integral(density_modified)
    integral_near_zero_modified_list.append(integral_near_zero_modified)
    print(integral_near_zero_modified)

    # Images are channel-first (C, H, W); imshow needs (H, W, C) to show RGB.
    fig1, ax1 = plt.subplots()
    ax1.imshow(pert_image[0].detach().permute(1, 2, 0).numpy())

# Overlay: original spectrum vs. the last analysed adversarial spectrum.
plt.figure()
plt.xlim((0, 0.2))
plt.ylim((0, 1))
plt.hist(eig_values_K_original_list, bins=100, density=True, color='orange', label='original')
plt.hist(eig_values_K_modified_list, bins=100, density=True, label='with last adversarial example')
plt.legend()

# Distribution over adversarial examples; the orange line marks the original value.
plt.figure()
plt.hist(integral_near_zero_modified_list)
plt.axvline(x=integral_near_zero_original, color='orange')

plt.figure()
plt.hist(min_eig_values_modified_list)
plt.axvline(x=min_eig_values_original, color='orange')
