# -*- coding: utf-8 -*-


import os
from PIL import Image

import numpy as np
import matplotlib.pyplot as plt
import copy

import torch
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from torchvision import datasets, transforms, utils
import torchattacks

# SubLoader keeps roughly 1/div_factor of the class-filtered samples.
div_factor = 100


class SubLoader(datasets.MNIST):
    """MNIST restricted to a few digit classes and randomly subsampled.

    After the standard ``datasets.MNIST`` initialisation, only samples whose
    label is in ``classes`` are kept. Then ``n_kept // div_factor`` of them
    are drawn uniformly at random *without replacement* through ``np.random``,
    so ``np.random.seed`` controls the draw.

    Args:
        *args, **kwargs: forwarded unchanged to ``datasets.MNIST``.
        classes: digit labels to keep; defaults to ``(0, 1)``.
    """

    def __init__(self, *args, classes=(0, 1), **kwargs):
        super().__init__(*args, **kwargs)

        # Boolean mask of the samples whose label is one of the requested classes.
        keep = torch.zeros_like(self.targets, dtype=torch.bool)
        for cls in classes:
            keep |= self.targets == cls
        self.data = self.data[keep]
        self.targets = self.targets[keep]

        # Sample WITHOUT replacement. Duplicated images would give duplicated
        # rows in any feature / Gram matrix built from this data, and hence
        # spurious exactly-zero eigenvalues.
        n_kept = self.targets.shape[0]
        idx = np.random.choice(n_kept, size=n_kept // div_factor, replace=False)
        idx = torch.as_tensor(idx, dtype=torch.long)
        self.data = self.data[idx]
        self.targets = self.targets[idx]

    def get_data(self):
        """Return the ``(data, targets)`` tensors of the kept subset."""
        return self.data, self.targets




# Binary MNIST splits built by SubLoader. Both splits only convert the PIL
# images to float tensors via ToTensor.
train_dataset = SubLoader(
    './data',
    train=True,
    download=True,
    transform=transforms.Compose([transforms.ToTensor()]),
)
test_dataset = SubLoader(
    './data',
    train=False,
    download=True,
    transform=transforms.Compose([transforms.ToTensor()]),
)

# The batch size equals the split size, so each loader yields the whole
# (subsampled) split as one shuffled batch.
n_train = len(train_dataset.targets)
n_test = len(test_dataset.targets)
train_loader = DataLoader(train_dataset, batch_size=n_train, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=n_test, shuffle=True)

import sys
import matplotlib.pyplot as plt
# Preview three training digits from the first (and only) full batch.
fig1, ax1 = plt.subplots()
fig2, ax2 = plt.subplots()
fig3, ax3 = plt.subplots()
batch = next(iter(train_loader))
preview_images = batch[0]
for preview_ax, preview_idx in ((ax1, 13), (ax2, 10), (ax3, 99)):
    # The subsample may hold fewer than 100 images (it shrinks as div_factor
    # grows), so wrap the index instead of raising IndexError.
    preview_ax.imshow(preview_images[preview_idx % len(preview_images), 0, :, :])

class NN(torch.nn.Module):
    """Two-layer fully connected classifier: Linear -> ReLU -> Linear.

    ``forward`` returns raw logits. ``torch.nn.CrossEntropyLoss`` expects
    exactly this, because it applies log-softmax internally. Feeding it softmax
    probabilities instead squashes the scores twice, which flattens the loss
    and weakens gradients. Apply ``self.softmax`` to the output when
    probabilities are needed.

    After every forward pass, the post-ReLU hidden activations are stored in
    ``self.feature_map`` with shape ``(batch, hidden_size)``.
    """

    def __init__(self, input_size, hidden_size, output_size):
        super().__init__()
        self.fc1 = torch.nn.Linear(input_size, hidden_size)
        self.fc2 = torch.nn.Linear(hidden_size, output_size)
        self.relu = torch.nn.ReLU()
        # Not used by forward(); kept so callers can turn logits into probabilities.
        self.softmax = torch.nn.Softmax(dim=1)

    def forward(self, x):
        """Flatten ``x`` to ``(-1, input_size)`` and return the class logits."""
        # reshape (unlike view) also accepts non-contiguous inputs, and the
        # width follows input_size instead of a hard-coded 28*28.
        x = x.reshape(-1, self.fc1.in_features)
        x = self.relu(self.fc1(x))
        self.feature_map = x
        return self.fc2(x)

fcn = NN(28*28, 512, 2)

loss_function = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(fcn.parameters(), lr=0.01)
num_epochs = 10

# Full-batch training: train_loader yields the whole training split at once.
# Note: `batch` stays bound to the last batch seen here and is reused later.
for epoch in range(num_epochs):
    for batch in train_loader:
        optimizer.zero_grad()
        loss = loss_function(fcn(batch[0]), batch[1])
        print(f"epoch {epoch}: loss {loss.item():.4f}")
        loss.backward()
        optimizer.step()


def _accuracy_percent(model, dataset):
    """Return the percentage of ``dataset`` samples that ``model`` classifies correctly.

    Each ``(image, label)`` sample is evaluated on its own, and the prediction
    is the argmax of the model output. Autograd is disabled because this is
    inference only; no graph needs to be built.
    """
    with torch.no_grad():
        correct = sum(
            int(model(image).argmax().item() == label) for image, label in dataset
        )
    return correct / len(dataset) * 100


print("Training accuracy: " + str(_accuracy_percent(fcn, train_dataset)) + "%")
print("Test accuracy: " + str(_accuracy_percent(fcn, test_dataset)) + "%")

def _input_grad(output, x):
    """Return d(output)/dx as a NumPy array without touching ``x.grad``.

    ``retain_graph=True`` lets the caller differentiate several outputs of the
    same forward pass.
    """
    (grad,) = torch.autograd.grad(output, x, retain_graph=True)
    return grad.detach().numpy().copy()


def deepfool(image, net, num_classes=2, overshoot=0.02, max_iter=10):
    """Find a small perturbation that changes ``net``'s prediction on ``image`` (DeepFool).

    At each step, the scores of the ``num_classes`` top-ranked classes are
    linearised around the current point. The image is then moved along the
    gradient difference ``w_k = grad f_k - grad f_label`` towards the closest
    linearised decision boundary. The accumulated step is scaled by
    ``1 + overshoot`` so that it lands past the boundary.

    Args:
        image: input tensor. ``net`` must accept both ``image`` and
            ``image[None][0]`` and return scores indexed as ``[0, class]``.
        net: differentiable model.
        num_classes: number of top-ranked classes considered as targets.
        overshoot: relative overshoot applied to the accumulated perturbation.
        max_iter: maximum number of linearisation steps.

    Returns:
        ``(r_tot, loop_i, label, k_i, pert_image)``:
        - ``r_tot``: the total perturbation, scaled by ``1 + overshoot``.
        - ``loop_i``: the number of steps taken.
        - ``label``: the original predicted label.
        - ``k_i``: the label predicted at the final point.
        - ``pert_image``: the perturbed image.
        After at least one step, ``r_tot`` and ``pert_image`` carry an extra
        leading axis. If no step was taken (``max_iter == 0``, or every
        gradient difference vanished), ``pert_image`` is a detached copy of
        ``image`` without that axis.
    """
    image = image.detach()
    with torch.no_grad():
        f_image = net(image).numpy().flatten()
    I = f_image.argsort()[::-1][:num_classes]
    label = I[0]

    pert_image = image.clone()
    r_tot = np.zeros(tuple(image.shape), dtype=np.float32)
    loop_i = 0

    x = pert_image[None, :].clone().requires_grad_(True)
    fs = net(x[0])
    k_i = label

    while k_i == label and loop_i < max_iter:
        pert = np.inf
        w = None
        grad_orig = _input_grad(fs[0, I[0]], x)

        for k in range(1, num_classes):
            # autograd.grad does not accumulate into x.grad, so this is the true
            # gradient difference grad f_k - grad f_label.
            w_k = _input_grad(fs[0, I[k]], x) - grad_orig
            f_k = (fs[0, I[k]] - fs[0, I[0]]).item()
            norm_w_k = np.linalg.norm(w_k.flatten())
            if norm_w_k == 0:
                continue  # flat direction: distance to this boundary is undefined
            pert_k = abs(f_k) / norm_w_k

            # Keep the closest linearised boundary.
            if pert_k < pert:
                pert = pert_k
                w = w_k

        if w is None:
            break  # no usable direction, so stop instead of producing NaNs

        # 1e-4 keeps the step strictly positive for numerical stability.
        r_i = (pert + 1e-4) * w / np.linalg.norm(w)
        r_tot = np.float32(r_tot + r_i)

        pert_image = image + (1 + overshoot) * torch.from_numpy(r_tot)
        x = pert_image.clone().requires_grad_(True)
        fs = net(x[0])
        k_i = np.argmax(fs.detach().numpy().flatten())

        loop_i += 1

    r_tot = (1 + overshoot) * r_tot
    return r_tot, loop_i, label, k_i, pert_image

def imshow(img):
    """Display an image tensor with matplotlib, then call ``plt.show()``.

    Assumes channel-first ``(C, H, W)`` input. The axes are reordered to
    ``(H, W, C)`` before plotting.
    """
    channels_last = img.numpy().transpose(1, 2, 0)
    plt.imshow(channels_last)
    plt.show()

# Switch to evaluation mode
fcn.eval()

# Run DeepFool on every test image and collect the perturbed images.
# Images are kept even if the attack did not flip the label within max_iter.
pert_image_list = []
for im in test_dataset:
    r, loop_i, label_orig, label_pert, pert_image = deepfool(im[0], fcn, max_iter=50)
    # A non-batched (< 4-D) result means deepfool took no step; skip it.
    if pert_image.dim() < 4:
        continue
    pert_image_list.append(pert_image)

if not pert_image_list:
    raise RuntimeError("DeepFool produced no perturbed test images")
pert_images = torch.cat(pert_image_list, dim=0)
print(pert_images.shape)
imshow(pert_images[0, :, :, :])

# Model scores on the last clean test image vs. its perturbed version.
with torch.no_grad():
    print(fcn(im[0]), label_orig)
    print(fcn(pert_image), label_pert)

print(batch[0].shape)
print(torch.cat([batch[0], pert_images], dim=0).shape)

# Spectral analysis of the hidden-layer Gram matrix K = F F^T / N.
# F holds the post-ReLU features, one sample per row, and N is the hidden width.
# The spectrum of the clean training batch is compared with the spectrum
# obtained after appending one DeepFool-perturbed test image.
_SPECTRUM_MAX = 0.2  # only eigenvalues in [0, _SPECTRUM_MAX) are analysed
_N_BINS = 100


def _hidden_features(model, data):
    """Run ``model`` on ``data`` without autograd and return ``model.feature_map``."""
    with torch.no_grad():
        model(data)
    return model.feature_map


def _gram_eigenvalues(features):
    """Return the eigenvalues of ``features @ features.T / N`` below ``_SPECTRUM_MAX``.

    ``N`` is the number of feature columns. The Gram matrix is symmetric
    positive semi-definite, so ``eigvalsh`` is used in float64. It returns real
    values and is more stable than a general ``eig``. Tiny negative round-off
    values are clipped to 0.
    """
    feats = features.detach().double()
    gram = feats @ feats.T / feats.shape[1]
    eigvals = np.clip(torch.linalg.eigvalsh(gram).numpy(), 0.0, None)
    return eigvals[eigvals < _SPECTRUM_MAX]


def _near_zero_integral(density, edges):
    """Riemann-sum estimate of the integral of rho(lambda) / lambda**2.

    ``np.histogram`` spans the data's own [min, max], so bin widths differ
    between spectra. Using each bin's real width keeps integrals of different
    spectra comparable. The right bin edge avoids dividing by a 0 left edge.
    """
    return float(np.sum(density * np.diff(edges) / edges[1:] ** 2))


def _plot_spectrum(eigvals, xlabel):
    """Open a new figure with a density histogram of ``eigvals`` on fixed axes."""
    plt.figure()
    plt.ylabel('Probability density')
    plt.xlabel(xlabel)
    plt.xlim((0, _SPECTRUM_MAX))
    plt.ylim((0, 1))
    plt.hist(eigvals, bins=_N_BINS, density=True)


original_data_distribution = batch[0]
original_feature_map = _hidden_features(fcn, original_data_distribution)
N = original_feature_map.shape[1]
print(N)
print(original_feature_map.shape)

eig_values_K_original = _gram_eigenvalues(original_feature_map)
density_original = np.histogram(eig_values_K_original, bins=_N_BINS, density=True)
print(density_original)
eig_values_K_original_list = list(eig_values_K_original)
print(eig_values_K_original_list)

_plot_spectrum(eig_values_K_original_list, 'Eigenvalues of original kernel matrix')
plt.show()

min_eig_values_original = min(eig_values_K_original_list)
integral_near_zero_original = _near_zero_integral(*density_original)
print(integral_near_zero_original)
print(min_eig_values_original)

min_eig_values_modified_list = []
integral_near_zero_modified_list = []
eig_values_K_modified_list = []  # defined even if every iteration is skipped
for pert_image in pert_image_list:
    modified_data_distribution = torch.cat([batch[0], pert_image], dim=0)
    modified_feature_map = _hidden_features(fcn, modified_data_distribution)
    print(modified_feature_map.shape)

    try:
        eig_values_K_modified = _gram_eigenvalues(modified_feature_map)
    except RuntimeError:  # eigensolver failure: skip this perturbed image
        continue
    density_modified = np.histogram(eig_values_K_modified, bins=_N_BINS, density=True)
    print(density_modified)
    eig_values_K_modified_list = list(eig_values_K_modified)
    print(eig_values_K_modified_list)

    _plot_spectrum(eig_values_K_modified_list, 'Eigenvalues of modified kernel random matrix')
    plt.show()

    min_eig_values_modified_list.append(min(eig_values_K_modified_list))
    print(min(eig_values_K_modified_list))
    integral_near_zero_modified = _near_zero_integral(*density_modified)
    integral_near_zero_modified_list.append(integral_near_zero_modified)
    print(integral_near_zero_modified)

    fig1, ax1 = plt.subplots()
    ax1.imshow(pert_image[0, 0, :, :])

# Overlay: original spectrum (orange) vs. the last modified spectrum.
plt.figure()
plt.xlim((0, _SPECTRUM_MAX))
plt.ylim((0, 1))
plt.hist(eig_values_K_original_list, bins=_N_BINS, density=True, color='orange')
plt.hist(eig_values_K_modified_list, bins=_N_BINS, density=True)

# Distribution of the near-zero integral over perturbed images; orange = original.
plt.figure()
plt.hist(integral_near_zero_modified_list)
plt.axvline(x=integral_near_zero_original, color='orange')

# Distribution of the smallest eigenvalue over perturbed images; orange = original.
plt.figure()
plt.hist(min_eig_values_modified_list)
plt.axvline(x=min_eig_values_original, color='orange')
