# -*- coding: utf-8 -*-
"""explainable_nn_mnist_iclr.ipynb

Automatically generated by Colaboratory.

Original file is located at
    https://colab.research.google.com/drive/1wEDN9joTDMi98IsXZZW7b4PNylEpLj1x
"""

import os
from PIL import Image

import numpy as np
import matplotlib.pyplot as plt

import torch
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from torchvision import datasets, transforms, utils

# Notebook-only shell escape (IPython `!` syntax, not valid in a plain .py file):
# installs Captum, which provides the IntegratedGradients explainer used below.
!pip install captum

# Colab-only: mount Google Drive so the checkpoint and figures written below
# under '/content/gdrive/My Drive/' persist after the session ends.
from google.colab import drive
drive.mount('/content/gdrive')

# Both MNIST splits only go through ToTensor (no normalization); a single
# stateless transform pipeline is shared between them.
to_tensor = transforms.Compose([transforms.ToTensor()])

train_dataset = datasets.MNIST(
    './data', train=True, download=True, transform=to_tensor,
)
test_dataset = datasets.MNIST(
    './data', train=False, download=True, transform=to_tensor,
)

# Mini-batches of 128 samples, reshuffled on every pass, for both splits.
train_loader = DataLoader(train_dataset, batch_size=128, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=128, shuffle=True)

import sys
import matplotlib.pyplot as plt

# Show three digits from the first (shuffled) training batch, one figure each.
fig1, ax1 = plt.subplots()
fig2, ax2 = plt.subplots()
fig3, ax3 = plt.subplots()
for batch in train_loader:
    for axis, sample_idx in ((ax1, 13), (ax2, 10), (ax3, 100)):
        # batch[0] holds the images; channel 0 is the single grayscale plane.
        axis.imshow(batch[0][sample_idx, 0])
    break  # only the first batch is needed

class NN(torch.nn.Module):
    """Two-layer fully connected classifier that outputs class log-probabilities.

    Architecture: Linear(input_size, hidden_size) -> ReLU ->
    Linear(hidden_size, output_size) -> LogSoftmax over the class dimension,
    so the output pairs with ``torch.nn.NLLLoss``.

    Args:
        input_size: number of input features per sample. Inputs are reshaped to
            ``(-1, input_size)``, so their total element count must be a
            multiple of it.
        hidden_size: width of the hidden layer.
        output_size: number of classes.
    """

    def __init__(self, input_size, hidden_size, output_size):
        super().__init__()
        self.fc1 = torch.nn.Linear(input_size, hidden_size)
        self.fc2 = torch.nn.Linear(hidden_size, output_size)
        self.relu = torch.nn.ReLU()
        self.log_softmax = torch.nn.LogSoftmax(dim=1)

    def forward(self, x):
        """Return log-probabilities of shape ``(batch, output_size)``.

        For ``input_size == 784`` a ``(1, 28, 28)`` image becomes a batch of
        one and a ``(B, 1, 28, 28)`` batch becomes ``(B, 784)``.
        """
        # Flatten to the configured input width (not a hard-coded 28*28), and
        # use reshape rather than view so non-contiguous inputs also work.
        x = x.reshape(-1, self.fc1.in_features)
        x = self.relu(self.fc1(x))
        return self.log_softmax(self.fc2(x))

nn = NN(28*28, 512, 10)

loss_function = torch.nn.NLLLoss()  # negative log likelihood; the model outputs log-probabilities
optimizer = torch.optim.Adam(nn.parameters(), lr=0.01)
num_epochs = 10


def _accuracy(model, dataset, batch_size=1000):
    """Return the percentage of samples in *dataset* that *model* classifies correctly.

    Evaluates in batches under ``torch.no_grad()`` instead of one sample at a
    time with autograd tracking, which was very slow for 70k samples.
    """
    loader = DataLoader(dataset, batch_size=batch_size)
    model.eval()
    correct = 0
    with torch.no_grad():
        for images, labels in loader:
            correct += (model(images).argmax(dim=1) == labels).sum().item()
    return correct / len(dataset) * 100


nn.train()
for epoch in range(num_epochs):
    epoch_loss = 0.0
    for images, labels in train_loader:
        optimizer.zero_grad()
        loss = loss_function(nn(images), labels)
        loss.backward()
        optimizer.step()
        epoch_loss += loss.item()
    # One summary line per epoch instead of printing every batch loss.
    print(f"epoch {epoch + 1}/{num_epochs}: mean loss {epoch_loss / len(train_loader):.4f}")

print("Training accuracy: " + str(_accuracy(nn, train_dataset)) + "%")
print("Test accuracy: " + str(_accuracy(nn, test_dataset)) + "%")

model_path = '/content/gdrive/My Drive/fcn_mnist.pth'
torch.save(nn, model_path)

# The checkpoint is a whole pickled module written by this script, so it needs
# weights_only=False (PyTorch >= 2.6 defaults to True and would refuse to load it).
# Only load files you trust this way: unpickling can execute arbitrary code.
nn = torch.load(model_path, weights_only=False)

import captum.attr as attr

# Integrated Gradients explainer over the trained classifier; attributions are
# multiplied by (input - baseline).
ig = attr.IntegratedGradients(forward_func=nn, multiply_by_inputs=True)

def imshow(img, path=None):
    """Display a channel-first image tensor with matplotlib and optionally save it.

    Args:
        img: tensor in ``(C, H, W)`` layout, e.g. the output of
            ``torchvision.utils.make_grid``. It is detached and moved to the
            CPU first, so tensors that require grad or live on a GPU also work.
        path: file the current figure is saved to before being shown; saving
            is skipped when None.
    """
    npimg = img.detach().cpu().numpy()
    plt.imshow(np.transpose(npimg, (1, 2, 0)))  # CHW -> HWC for matplotlib
    if path is not None:
        plt.savefig(path)
    plt.show()

explanation_size = []  # number of removed pixels for each explained example
running_time = []  # seconds spent finding each minimal feature removal

import copy
import time

# Pixel value a "removed" feature is set to. The same baseline image is given
# to Integrated Gradients so both explanations share one reference.
baseline_value = 0
baseline = torch.full((1, 28, 28), float(baseline_value))


def _single_pixel_scores(model, image, reference, target):
    """Return the target-class score after removing each pixel on its own.

    *image* and *reference* are ``(1, H, W)`` tensors. Entry ``i * W + j`` of
    the returned list is ``model(x)[0, target]`` where ``x`` is *image* with
    pixel ``(i, j)`` replaced by ``reference[0, i, j]``. All ``H * W``
    perturbed copies are scored in one batched, gradient-free forward pass.
    """
    _, height, width = image.shape
    rows = torch.arange(height).repeat_interleave(width)
    cols = torch.arange(width).repeat(height)
    perturbed = image.expand(height * width, height, width).clone()
    perturbed[torch.arange(height * width), rows, cols] = reference[0, rows, cols]
    with torch.no_grad():
        return model(perturbed)[:, target].tolist()


for batch in train_loader:
    # Explain only the first example of each shuffled batch, and only when
    # the model already classifies it correctly.
    image, label = batch[0][0], batch[1][0].item()
    with torch.no_grad():
        if nn(image).argmax().item() != label:
            continue
    # IG differentiates through the model, so it must run with autograd enabled.
    ig_attributions = ig.attribute(inputs=image, baselines=baseline,
                                   target=label, n_steps=100)

    t = time.time()
    scores = _single_pixel_scores(nn, image, baseline, label)
    width = image.shape[2]
    # Greedy order: pixels whose individual removal lowers the target-class
    # score most come first (stable sort keeps row-major order on ties).
    order = sorted(range(len(scores)), key=scores.__getitem__)

    minimal_explanation = []
    modified_image = image.clone()
    with torch.no_grad():
        for k in order:
            i, j = divmod(k, width)
            minimal_explanation.append((i, j))
            modified_image[0, i, j] = baseline[0, i, j]
            prediction = nn(modified_image).argmax().item()
            if prediction == label:
                continue
            elapsed_time = time.time() - t
            print(elapsed_time)
            # Store seconds, matching the "(in seconds)" histogram axis label.
            running_time.append(elapsed_time)
            print(f'prediction: {prediction}')
            print(len(minimal_explanation))
            explanation_size.append(len(minimal_explanation))
            fig = plt.figure()
            images = torch.stack([image[0].repeat(3, 1, 1),
                                  modified_image[0].repeat(3, 1, 1),
                                  10 * ig_attributions[0].repeat(3, 1, 1)])
            # NOTE(review): the same file is overwritten for every explained example.
            imshow(utils.make_grid(images), '/content/gdrive/My Drive/images_mnist.png')
            plt.close(fig)  # don't accumulate hundreds of open figures
            break

from statistics import mean


def _save_histogram(values, xlabel, title, path, bins=30):
    """Plot *values* as a histogram in a new figure and save it to *path*."""
    plt.figure()
    plt.xlabel(xlabel)
    plt.ylabel('Number of examples')
    plt.title(title)
    plt.hist(values, bins=bins)
    plt.savefig(path)


if explanation_size:
    print(mean(explanation_size))
    print(mean(running_time))
else:
    # mean() raises StatisticsError on empty data: no prediction was ever flipped.
    print('No minimal feature removal was found; nothing to summarize.')

_save_histogram(explanation_size,
                'Size of minimal feature removal',
                'Histogram of sizes of minimal feature removals',
                '/content/gdrive/My Drive/histogram_size_minimal_feature_removal.pdf')
_save_histogram(running_time,
                'Time to compute minimal feature removal (in seconds)',
                'Histogram of computing times for minimal feature removals',
                '/content/gdrive/My Drive/histogram_time_minimal_feature_removal.pdf')