
#!/usr/bin/env python
# coding: utf-8

# In[1]:


#get_ipython().run_line_magic('matplotlib', 'inline')

import random
from torch.nn import functional as F
import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
import torch
from torchvision import datasets, models, transforms
from tqdm import tqdm
from torchray.attribution.grad_cam import grad_cam
from src import Pruner, Plot_tools
from src.attribution_methods import gradcam, vanilla_saliency, integrated_gradients, grad_times_image, prune_grad_mid, guided_backprop, RectGrad
import matplotlib.pyplot as plt
from scipy.ndimage import gaussian_filter
import math
from TorchRay.torchray.attribution.grad_cam_positive import grad_cam_positive
from cdrp import main as cdrp
import sys
sys.path.insert(0, './cdrp/')

#from src.attribution_methods.grad_cam_pytorch.grad_cam import GradCAM
#from src.attribution_methods.grad_cam import GradCam
# ### Setup Imagenet
# 
# As of October 2019, ImageNet can no longer be downloaded through torchvision.  
# See https://github.com/pytorch/vision/issues/1453.  
# To obtain ImageNet, see http://image-net.org/.  

# In[3]:


imagenet_dir = '/home/ashkan/data/ILSVRC2012/'
# ImageNet evaluation preprocessing: resize the short side to 256, take the
# central 224x224 crop, then normalise with the ImageNet channel mean/std.
transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
# The archives must already exist under `imagenet_dir`. The `download` flag is
# not passed: torchvision deprecated it for ImageNet and later removed it.
imagenet = datasets.ImageNet(imagenet_dir, split='val', transform=transform)
classes = imagenet.classes
mpl.rcParams['figure.dpi'] = 400


# ### Method to get attributions

# In[4]:


EPSILON_DOUBLE = torch.tensor(2.220446049250313e-16, dtype=torch.float64)
EPSILON_SINGLE = torch.tensor(1.19209290E-07, dtype=torch.float32)
SQRT_TWO_DOUBLE = torch.tensor(math.sqrt(2), dtype=torch.float32)
SQRT_TWO_SINGLE = SQRT_TWO_DOUBLE.to(torch.float32)
def imsmooth(tensor,
             sigma,
             stride=1,
             padding=0,
             padding_mode='constant',
             padding_value=0):
    r"""Apply a 2D Gaussian filter to a tensor.
    The 2D filter itself is implementing by separating the 2D convolution in
    two 1D convolutions, first along the vertical direction and then along
    the horizontal one. Each 1D Gaussian kernel is given by:
    .. math::
        f_i \propto \exp\left(-\frac{1}{2} \frac{i^2}{\sigma^2} \right),
            ~~~ i \in \{-W,\dots,W\},
            ~~~ W = \lceil 4\sigma \rceil.
    This kernel is normalized to sum to one exactly. Given the latter, the
    function calls `torch.nn.functional.conv2d`
    to perform the actual convolution. Various padding parameters and the
    stride are passed to the latter.
    Args:
        tensor (:class:`torch.Tensor`): :math:`N\times C\times H\times W`
            image tensor.
        sigma (float): standard deviation of the Gaussian kernel.
        stride (int, optional): subsampling factor. Default: ``1``.
        padding (int, optional): extra padding. Default: ``0``.
        padding_mode (str, optional): ``'constant'``, ``'reflect'`` or
            ``'replicate'``. Default: ``'constant'``.
        padding_value (float, optional): constant value for the `constant`
            padding mode. Default: ``0``.
    Returns:
        :class:`torch.Tensor`: :math:`N\times C\times H\times W` tensor with
        the smoothed images.
    """
    assert sigma >= 0
    width = math.ceil(4 * sigma)
    filt = (torch.arange(-width,
                         width + 1,
                         dtype=torch.float32,
                         device=device) /
            (SQRT_TWO_SINGLE * sigma + EPSILON_SINGLE))
    filt = torch.exp(-filt * filt)
    filt /= torch.sum(filt)
    num_channels = tensor.shape[1]
    width = width + padding
    if padding_mode == 'constant' and padding_value == 0:
        other_padding = width
        x = tensor
    else:
        # pad: (before, after) pairs starting from last dimension backward
        x = F.pad(tensor,
                  (width, width, width, width),
                  mode=padding_mode,
                  value=padding_value)
        other_padding = 0
        padding = 0
    x = F.conv2d(x,
                 filt.reshape((1, 1, -1, 1)).expand(num_channels, -1, -1, -1),
                 padding=(other_padding, padding),
                 stride=(stride, 1),
                 groups=num_channels)
    x = F.conv2d(x,
                 filt.reshape((1, 1, 1, -1)).expand(num_channels, -1, -1, -1),
                 padding=(padding, other_padding),
                 stride=(1, stride),
                 groups=num_channels)
    return x


def _scale_to_unit_peak(tensor):
    """Return *tensor* divided by its largest absolute value (unchanged if that is 0)."""
    peak = tensor.abs().max()
    if peak > 0:
        return tensor / peak
    # An all-zero map would otherwise turn into NaNs through 0 / 0.
    return tensor


def _to_normalized_map(saliency, make_single_channel):
    """Convert a raw attribution map to a ``(C, 224, 224)`` tensor with peak ``|value| == 1``.

    ``C`` is 1 when *make_single_channel* is true and 3 otherwise. *saliency*
    may be anything ``np.asarray`` accepts (e.g. a numpy array or a CPU tensor
    that does not require grad).
    """
    channels = 1 if make_single_channel else 3
    tensor = torch.from_numpy(np.asarray(saliency)).reshape([channels, 224, 224])
    # Out-of-place scaling: torch.from_numpy shares memory with the source
    # array, so an in-place division would also rescale the caller's data.
    return _scale_to_unit_peak(tensor)


def get_attribution(attribution_name, data, model, gradcam_layer, pruneGradMid_layer, model_sparsity_threshold, cumulative_layers=None):
    """Compute one attribution map for the model's top-1 prediction on *data*.

    Args:
        attribution_name: method to run. One of "Gradients", "InputMCT",
            "InputIntGrad", "GBP", "RectGrad", "GradCAM", "NeuronMCT",
            "RandomPruning", "NeuronIntGrad", "PruneGradSum", "GreedyPruning",
            "DGR", "PrunePGD", "PruneGrad-Mid", "GradCAM Positive" or
            "Cumulative GradCAM Positive".
        data: input batch for *model*. The prediction is read with ``.item()``,
            so it must contain a single sample.
        model: classifier; it is put in eval mode before predicting.
        gradcam_layer: layer name used by "GradCAM" and "GradCAM Positive".
        pruneGradMid_layer: layer handed to ``PruneGradMid``.
        model_sparsity_threshold: sparsity level forwarded to the Pruner-based
            methods and to ``PruneGradMid.get_saliency``.
        cumulative_layers: layer names whose positive GradCAM maps are summed
            by "Cumulative GradCAM Positive".

    Returns:
        torch.Tensor of shape (1, 224, 224), scaled so its largest absolute
        value is 1 (left unscaled when the map is all zeros).

    Raises:
        ValueError: for an unknown *attribution_name*, or when
            "Cumulative GradCAM Positive" is requested without layers.

    The module-level ``device`` is passed to the attribution helpers.
    """
    # All methods are asked for single-channel maps.
    make_single_channel = True
    # Switch to eval mode *before* predicting, so train-mode layers
    # (dropout, batch-norm) cannot change the explained class.
    model.eval()
    with torch.no_grad():
        class_id = model(data).argmax(dim=1).item()

    if attribution_name == "Gradients":
        vanilla_sal = vanilla_saliency.VanillaSaliency(model, device)
        saliency = vanilla_sal.generate_saliency(data, class_id, make_single_channel)
    elif attribution_name == "InputMCT":
        saliency = grad_times_image.generate_grad_times_image_saliency(model, data, class_id, device, make_single_channel)
    elif attribution_name == "InputIntGrad":
        integ_grad = integrated_gradients.IntegratedGradients(model, device)
        saliency = integ_grad.generate_integrated_gradients(data, class_id, 50, make_single_channel)
    elif attribution_name == "GBP":
        GB = guided_backprop.GuidedBackprop(model, device)
        saliency = GB.generate_gradients(data, class_id, make_single_channel)
    elif attribution_name == "RectGrad":
        saliency = RectGrad.generate_rectgrad(model, data, class_id, 90, device, make_single_channel)

    elif attribution_name == "GradCAM":
        saliency = grad_cam(model, data, class_id, saliency_layer=gradcam_layer)
        saliency = F.interpolate(saliency, 224, mode="bilinear")
        saliency = saliency.detach().cpu().numpy()

    elif attribution_name == "NeuronMCT":
        pruner = Pruner.Pruner(model, data, device)
        pruner.prune(model_sparsity_threshold, debug=False)
        saliency = pruner.generate_saliency(make_single_channel=make_single_channel)
        pruner.remove_handles()

    elif attribution_name == "RandomPruning":
        pruner = Pruner.Pruner(model, data, device)
        pruner.prune_random(model_sparsity_threshold, debug=False)
        saliency = pruner.generate_saliency(make_single_channel=make_single_channel)
        pruner.remove_handles()

    elif attribution_name == "NeuronIntGrad":
        pruner = Pruner.Pruner(model, data, device)
        pruner.prune_integrad(model_sparsity_threshold, debug=False)
        saliency = pruner.generate_saliency(make_single_channel=make_single_channel)
        pruner.remove_handles()

    elif attribution_name == "PruneGradSum":
        # Sum the peak-normalised Pruner maps over sparsity levels 80..99,
        # then rescale the sum to unit peak.
        maps = []
        for sparsity in range(80, 100):
            pruner = Pruner.Pruner(model, data.clone(), device)
            pruner.prune(sparsity, debug=False)
            level_map = pruner.generate_saliency(make_single_channel=make_single_channel)
            pruner.remove_handles()
            maps.append(_to_normalized_map(level_map, make_single_channel).squeeze(0))
        summed = torch.stack(maps).sum(dim=0)
        return _scale_to_unit_peak(summed).unsqueeze(0)

    elif attribution_name == "GreedyPruning":
        pruner = Pruner.Pruner(model, data, device)
        pruner.prune_iterative(1, model_sparsity_threshold, debug=False)
        saliency = pruner.generate_saliency(make_single_channel=make_single_channel)
        pruner.remove_handles()

    elif attribution_name == "DGR":
        pruner = Pruner.Pruner(model, data, device)
        pruner.prune_cdrp(model_sparsity_threshold, debug=False)
        saliency = pruner.generate_saliency(make_single_channel=make_single_channel)
        pruner.remove_handles()

    elif attribution_name == "PrunePGD":
        pruner = Pruner.Pruner(model, data, device)
        pruner.prune(model_sparsity_threshold, debug=False)
        saliency = pruner.generate_saliency_pgd_l2(epsilon=500, alpha=25, num_iter=50, make_single_channel=make_single_channel, debug=False)
        pruner.remove_handles()

    elif attribution_name == "PruneGrad-Mid":
        pgmid = prune_grad_mid.PruneGradMid(model, data, class_id, pruneGradMid_layer, device)
        saliency, _, _ = pgmid.get_saliency(model_sparsity_threshold)

    elif attribution_name == "GradCAM Positive":
        saliency = grad_cam_positive(model, data, class_id, saliency_layer=gradcam_layer)
        saliency = F.interpolate(saliency, 224, mode="bilinear")
        saliency = saliency.detach().cpu().numpy()

    elif attribution_name == "Cumulative GradCAM Positive":
        if not cumulative_layers:
            raise ValueError("'Cumulative GradCAM Positive' requires a non-empty cumulative_layers list")
        saliency = None
        for layer in cumulative_layers:
            # Each layer's map is upsampled to 224x224 so maps from layers of
            # different resolution can be summed.
            layer_map = grad_cam_positive(model, data, class_id, saliency_layer=layer)
            layer_map = F.interpolate(layer_map, 224, mode="bilinear").detach().cpu().numpy()
            saliency = layer_map if saliency is None else saliency + layer_map

    else:
        raise ValueError("Unknown attribution method: {!r}".format(attribution_name))

    return _to_normalized_map(saliency, make_single_channel)


# In[5]:


def visualize(model, dataloader, images, classes, gradcam_layer, pruneGradMid_layer, sparsity_levels, cumulative_layers=None):
    """Plot attribution maps for every sample of *dataloader* and save the figure.

    Layout: each sample gets a group of ``len(name_methods)`` rows. Column 0
    of the group's first row shows the input image, annotated with the
    predicted class name and its softmax probability. Column ``k + 1`` of row
    ``j`` shows method ``j`` at ``sparsity_levels[k]``. A sparsity of 0 is
    rendered with plain "Gradients" and titled "Original Network".

    The figure is written to ``./fig<id>.png``, where ``id`` is a module-level
    counter that this function increments after saving.

    Args:
        model: classifier, moved to the module-level ``device`` and set to eval.
        dataloader: yields ``(data, label)`` batches. Each batch is treated as
            one plotted sample, so the prediction must hold a single element.
        images: not used by this function; kept for call compatibility.
        classes: indexable by class id; ``classes[i][0]`` is used as the label.
        gradcam_layer, pruneGradMid_layer, cumulative_layers: forwarded to
            ``get_attribution``.
        sparsity_levels: one figure column per level.
    """
    global id
    model = model.to(device)
    model.eval()
    # Methods plotted (one row each). Other names understood by get_attribution
    # include "DGR", "GreedyPruning", "RandomPruning", "GradCAM",
    # "GradCAM Positive" and "Cumulative GradCAM Positive".
    name_methods = ["NeuronIntGrad", "NeuronMCT"]
    num_methods = len(name_methods)
    num_cols = len(sparsity_levels) + 1  # input image + one column per sparsity level
    num_samples = len(dataloader)
    num_rows = num_samples * num_methods
    fig = plt.figure(figsize=(7, num_samples * 2))
    for chosen, (data, _) in enumerate(tqdm(dataloader, total=num_samples)):
        data = data.to(device)
        with torch.no_grad():
            probs = F.softmax(model(data.clone()), dim=1)
        top_prob, top_class = probs.max(dim=1)
        predicted_prob = top_prob.item()
        predicted_logit = top_class.item()
        image = Plot_tools.reverse_preprocess_imagenet_image(data.clone())

        # The input image occupies column 0 of this sample's first row.
        ax = fig.add_subplot(num_rows, num_cols, num_methods * chosen * num_cols + 1)
        if chosen == 0:
            ax.set_title("Original Image", fontsize=6)
        label = "{}\n{:.2f}%".format(classes[predicted_logit][0], round(predicted_prob * 100, 2))
        ax.text(-0.13, 0.5, label, fontsize=6, rotation=90, horizontalalignment='center', verticalalignment='center', transform=ax.transAxes)
        ax.axis('off')
        ax.imshow(image)

        for j, method in enumerate(name_methods):
            row = num_methods * chosen + j
            for k, sparsity in enumerate(sparsity_levels):
                ax = fig.add_subplot(num_rows, num_cols, row * num_cols + k + 2)
                # Sparsity 0 means the unpruned network: show plain gradients.
                method_name = 'Gradients' if sparsity == 0 else method
                attribution = get_attribution(method_name, data.clone(), model, gradcam_layer, pruneGradMid_layer, sparsity, cumulative_layers)
                attribution = np.asarray(attribution.squeeze(0))
                if chosen == 0 and j == 0:
                    title = "Sparsity={}".format(sparsity) if sparsity != 0 else "Original Network"
                    ax.set_title(title, fontsize=6)
                ax.imshow(np.abs(attribution), cmap='jet', vmin=0, vmax=1)
                if k == len(sparsity_levels) - 1:
                    ax.text(1.05, 0.5, method, fontsize=5, rotation=-90, horizontalalignment='center', verticalalignment='center', transform=ax.transAxes, fontweight='bold')
                ax.axis('off')
    fig.tight_layout()
    fig.subplots_adjust(wspace=0.05, hspace=0.01)
    fig.savefig('./fig' + str(id) + '.png', dpi=300)
    id += 1


# # Evaluate on Resnet50

# In[8]:


num_samples = 4
# Figure counter read and incremented by visualize() via `global id`.
# The name shadows the builtin but must stay because visualize() uses it.
id = 0
indices = random.sample(range(0, len(imagenet)), num_samples)
# Hand-picked index sets inspected in earlier runs:
#indices = [7211, 21856, 19258, 8701]
#indices = [4622]
#indices = [11814]
#indices = [16305] #butterfly
#indices = [34155, 2184] #[7066, 11836] #, 34155, 2184]
dataset = torch.utils.data.Subset(imagenet, indices)
dataloader = torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=True)

model = models.resnet50(pretrained=True)
# Prefer the second GPU, as before. Fall back to the first GPU, or to the CPU,
# on machines without one instead of failing at the first `.to(device)`.
if torch.cuda.device_count() > 1:
    device = 'cuda:1'
elif torch.cuda.is_available():
    device = 'cuda:0'
else:
    device = 'cpu'
resnet_gradcam_layer = 'layer4'
pruneGradMid_resnet_layer = 'layer2'
cumulative_layers = ['layer4', 'layer3', 'layer2', 'layer1']
model_sparsity_threshold = 75  # Threshold computed for 15% output change for Pruner
#visualize(model, dataloader, indices, classes, resnet_gradcam_layer, pruneGradMid_resnet_layer, [60, 75, 90], cumulative_layers)


# # Evaluate on VGG16

# In[7]:

#num_samples = 1

#indices = random.sample(range(0, len(imagenet)), num_samples)
# Re-wrap the same sampled indices for the VGG16 run.
dataset = torch.utils.data.Subset(imagenet, indices)
dataloader = torch.utils.data.DataLoader(
    dataset,
    batch_size=1,
    shuffle=True,
)

model = models.vgg16(pretrained=True)
# Alternative backbone tried earlier: models.googlenet(pretrained=True)
vgg_gradcam_layer = 'features'
pruneGradMid_vgg_layer = 'features.20'
cumulative_layers = [
    'features',
    'features.23',
    'features.16',
    'features.9',
    'features.4',
]
# Threshold computed for 15% output change for Pruner (95 was used before).
model_sparsity_threshold = 89
visualize(
    model,
    dataloader,
    indices,
    classes,
    vgg_gradcam_layer,
    pruneGradMid_vgg_layer,
    [0, 70, 80, 85, 90, 95, 99],
    cumulative_layers,
)

