# Load Data
#%%
import torchvision.datasets as dsets
from torch.utils.data import Subset
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import numpy as np
import os
import torch
# Seed every RNG the script relies on (dataset split, per-class sampling, torch ops).
# NOTE: PYTHONHASHSEED is only read when an interpreter starts, so setting it here
# affects processes launched afterwards, not the current one.
import random

_SEED = 420
os.environ['PYTHONHASHSEED'] = str(_SEED)
for _seed_fn in (random.seed, np.random.seed, torch.random.manual_seed):
    _seed_fn(_SEED)

#%%
# Transformations: both pipelines end with the same per-channel normalisation;
# only the training pipeline adds random crop/flip augmentation.
_cifar_normalize = transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
transform_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    _cifar_normalize,
])
transform_test = transforms.Compose([transforms.ToTensor(), _cifar_normalize])

# Load train set
train_set = dsets.CIFAR10('./', train=True, download=True, transform=transform_train)
# Load the official test split and halve it: `val_set` for validation, `test_set`
# is where the explained examples are drawn from.
val_set = dsets.CIFAR10('./', train=False, download=True, transform=transform_test)
val_set, test_set = torch.utils.data.random_split(val_set, [5000, 5000])
device = torch.device('cpu')
dset = test_set
# Read the labels straight from the underlying dataset through the subset indices,
# instead of indexing the Subset (which decodes and transforms every image).
targets = np.asarray(test_set.dataset.targets)[test_set.indices]
num_classes = targets.max() + 1
inds_lists = [np.where(targets == cat)[0] for cat in range(num_classes)]
each_size = 10
# `each_size` distinct examples per class; replace=False prevents explaining the
# same image twice.
inds = np.array([
    np.random.choice(cat_inds, size=each_size, replace=False)
    for cat_inds in inds_lists
]).flatten()
x, y = zip(*[dset[ind] for ind in inds])
x = torch.stack(x).to(device)
y = torch.tensor(y).to(device)
# Total number of explained examples (num_classes * each_size).
all_size = len(inds)
#%% mkdir generation_results
import os
# Create the output directory if needed; exist_ok avoids the check-then-create race.
os.makedirs('generation_results', exist_ok=True)
#%% Model and Surrogate Loader
from fastshap import ImageSurrogate
import sys
sys.path.append('..')
from resnet import ResNet18
# NOTE(review): these checkpoints appear to be whole pickled modules (`.to` is called
# on what torch.load returns), so their classes (e.g. ResNet18 imported above) must
# be importable. torch.load unpickles arbitrary objects: only load trusted files.
# On PyTorch >= 2.6 `weights_only` defaults to True and may reject such files;
# pass weights_only=False there if loading fails.
# map_location lets checkpoints saved from a GPU load on this CPU-only device.
model = torch.load('cifar resnet.pt', map_location=device).to(device)
surr = torch.load('cifar surrogate.pt', map_location=device).to(device)
# Explanations need deterministic inference: disable train-mode behaviour (e.g.
# batch-norm batch statistics, if present) that a saved module may still have on.
model.eval()
surr.eval()
surrogate = ImageSurrogate(surr, width=32, height=32, superpixel_size=2)

#%% Generate raw image
import matplotlib.pyplot as plt
transform_vis = transforms.Compose([
    transforms.ToTensor()])
test_set_untransformed = dsets.CIFAR10('./', train=False, download=True, transform=transform_vis)

# `inds` are positions inside the random_split subset `test_set`, not inside the
# full CIFAR-10 test split, so route them through the same subset indices to show
# the images that are actually explained.
vis_set = Subset(test_set_untransformed, test_set.indices)
x_vis, y_vis = zip(*[vis_set[ind] for ind in inds])
x_vis = torch.stack(x_vis).to(device)
y_vis = torch.tensor(y_vis).to(device)
if not torch.equal(y_vis, y):
    raise RuntimeError('raw images do not line up with the explained examples')
for i in range(all_size):
    # Unnormalised (ToTensor only) image; CHW -> HWC for imshow.
    plt.imshow(x_vis[i].permute(1, 2, 0).cpu().numpy())
    plt.axis('off')
    plt.savefig('generation_results/raw{}.png'.format(i))
    plt.show()

# %% Fastshap Generation
from fastshap import FastSHAP
import sys
sys.path.append('..')
from unet import UNet
import torch.nn as nn
# The explainer checkpoint appears to be a whole pickled module (`.to` is called on
# it), so UNet must be importable. map_location lets a GPU-saved checkpoint load on
# this CPU-only device.
explainer = torch.load('cifar fastshap.pt', map_location=device).to(device)
fastshap = FastSHAP(explainer, surrogate, link=nn.Identity())

# Surrogate class probabilities on the unmasked images (every player present).
# Not read anywhere else in this script.
pred = surrogate(
    x.to(device),
    torch.ones(all_size, surrogate.num_players, device=device)
).softmax(dim=1).cpu().data.numpy()

fastshap_values = fastshap.shap_values(x.to(device))
for i in range(fastshap_values.shape[0]):
    # Colour scale symmetric around 0, shared by all class maps of this image.
    m = np.abs(fastshap_values[i]).max()
    plt.imshow(fastshap_values[i][int(y[i])], cmap='bwr', vmin=-m, vmax=m)
    plt.axis('off')
    plt.tight_layout()
    plt.savefig('generation_results/fastshap{}.png'.format(i))
    plt.show()

#%% SimSHAP Generation
import torch.nn as nn
from simshap.simshap_sampling import SimSHAPSampling
# map_location lets a GPU-saved checkpoint load on this CPU-only device.
explainer = torch.load('cifar simshap.pt', map_location=device).to(device)
simshap = SimSHAPSampling(explainer=explainer, imputer=surrogate, device=device)

simshap_values = simshap.shap_values(x.to(device))
for i in range(simshap_values.shape[0]):
    # Colour scale symmetric around 0, shared by all class maps of this image.
    m = np.abs(simshap_values[i]).max()
    plt.imshow(simshap_values[i][int(y[i])], cmap='bwr', vmin=-m, vmax=m)
    plt.axis('off')
    plt.tight_layout()
    plt.savefig('generation_results/simshap{}.png'.format(i))
    plt.show()

#%% Kernelshap Generation with shap package
import shap
from tqdm.auto import tqdm
import math

## modified from https://github.com/iclr1814/fastshap/blob/master/experiments/images/imagenette/ks_explain.py
def model_wrapper(x):
    """Run the module-level ``model`` on a NumPy batch and return its outputs as NumPy.

    ``x`` is converted to a float32 tensor on the module-level ``device``.
    KernelSHAP only needs forward values, so autograd is disabled to avoid
    building (and holding) a computation graph for every query batch.
    """
    with torch.no_grad():
        x = torch.tensor(x, dtype=torch.float32, device=device)
        pred = model(x)
    return pred.cpu().numpy()
def mask_image(masks, image, background=None):
    """Apply flattened square superpixel masks to an image batch.

    Each row of ``masks`` is a flattened k x k grid of superpixel switches. It is
    upsampled to the image's spatial size and multiplied in, so 1 keeps a
    superpixel and 0 removes it.

    Args:
        masks: array of shape (n, k*k).
        image: array whose last two axes form a square spatial grid with a side
            divisible by k (the KernelSHAP cell passes a (1, C, H, W) batch).
        background: optional fill for removed pixels. A 3-D array contributes
            ``background[0]``; any other array is iterated, producing one masked
            copy per entry. ``None`` leaves removed pixels at 0.

    Returns:
        The masked images stacked along axis 0 (a list of stacked arrays in the
        multi-background case).

    Raises:
        ValueError: if the image side is not a multiple of the mask side.
    """
    # Reshape/resize masks
    mask_side = math.isqrt(masks.shape[1])
    masks = np.reshape(masks, (masks.shape[0], 1, mask_side, mask_side))
    if image.shape[-1] % mask_side:
        raise ValueError(
            f"image side {image.shape[-1]} is not a multiple of mask side {mask_side}")
    # np.repeat needs an integer count; true division would yield a float.
    scale = image.shape[-1] // mask_side
    masks = np.repeat(masks, scale, axis=2)
    masks = np.repeat(masks, scale, axis=3)

    # Mask image
    if background is not None:
        if len(background.shape) == 3:
            masked_images = np.vstack([
                np.expand_dims((mask * image) + ((1 - mask) * background[0]), 0)
                for mask in masks
            ])
        else:
            # Fill with each background in turn
            masked_images = []
            for mask in masks:
                bg = [im * (1 - mask) for im in background]
                masked_images.append(
                    np.vstack([np.expand_dims((mask * image) + fill, 0) for fill in bg]))
    else:
        masked_images = np.vstack([mask * image for mask in masks])

    return masked_images

def resize_mask(masks, image):
    """Upsample flattened square masks to a channels-last image size.

    Args:
        masks: array of shape (n, k*k).
        image: channels-last array whose first axis is the spatial side H
            (divisible by k). A (1, C, H, W) batch does NOT fit this layout.

    Returns:
        Array of shape (n, H, H, 1).

    Raises:
        ValueError: if the image side is not a multiple of the mask side.
    """
    mask_side = math.isqrt(masks.shape[1])
    masks = np.reshape(masks, (masks.shape[0], mask_side, mask_side, 1))
    if image.shape[0] % mask_side:
        raise ValueError(
            f"image side {image.shape[0]} is not a multiple of mask side {mask_side}")
    # np.repeat needs an integer count; true division would yield a float.
    scale = image.shape[0] // mask_side
    masks = np.repeat(masks, scale, axis=1)
    masks = np.repeat(masks, scale, axis=2)

    return masks

for i in range(all_size):
    img = x[i:i+1].cpu().numpy()
    # No background: removed superpixels are zeroed, i.e. set to the channel mean
    # because `x` is already normalised.
    background = None

    def f_mask(z):
        """Model outputs for each 0/1 superpixel coalition row of ``z``.

        Large coalition batches are evaluated 100 rows at a time to bound memory.
        """
        if background is None or len(background.shape) == 3:
            if z.shape[0] == 1:
                return model_wrapper(mask_image(z, img, background))
            y_p = []
            for batch_idx in tqdm(range(math.ceil(z.shape[0] / 100))):
                chunk = z[batch_idx * 100:(batch_idx + 1) * 100]
                y_p.append(model_wrapper(mask_image(chunk, img, background)))
            outputs = np.vstack(y_p)  # stack once, reuse for logging and return
            print(outputs.shape)
            return outputs
        # Several backgrounds: average the model output over them per coalition.
        # (Not reached while `background` is None above.)
        y_p = []
        if z.shape[0] == 1:
            for masked_image in mask_image(z, img, background):
                y_p.append(np.mean(model_wrapper(masked_image), 0))
        else:
            for batch_idx in tqdm(range(math.ceil(z.shape[0] / 100))):
                chunk = z[batch_idx * 100:(batch_idx + 1) * 100]
                for masked_image in mask_image(chunk, img, background):
                    y_p.append(np.mean(model_wrapper(masked_image), 0))
        return np.vstack(y_p)

    # 16*16 players (32x32 image, 2x2 superpixels). Baseline: every superpixel
    # removed; explained point: every superpixel present.
    explainer = shap.KernelExplainer(f_mask, np.zeros((1, 16*16)), link='identity')
    shap_values = explainer.shap_values(np.ones((1, 16*16)), nsamples='auto', l1_reg=False)
    label = int(y[i])
    # Older shap returns a list of (1, players) arrays, one per class; newer
    # releases return a single (1, players, classes) array.
    if isinstance(shap_values, list):
        class_values = shap_values[label]
    else:
        class_values = shap_values[..., label]
    m = np.abs(np.stack(shap_values)).max()
    plt.imshow(class_values.reshape(16, 16), cmap='bwr', vmin=-m, vmax=m)
    plt.axis('off')
    plt.tight_layout()
    plt.savefig('generation_results/kernelshap{}.png'.format(i))
    plt.show()


#%% Kernelshap-S Generation
import shap
def imputer(x, S):
    """Evaluate the module-level ``surrogate`` on NumPy inputs and return NumPy.

    ``x`` is the image batch and ``S`` the coalition matrix passed by the
    KernelSHAP-S cell. Both are converted to float32 tensors on ``device``.
    Only forward values are needed, so autograd is disabled.
    """
    with torch.no_grad():
        x = torch.tensor(x, dtype=torch.float32, device=device)
        S = torch.tensor(S, dtype=torch.float32, device=device)
        pred = surrogate(x, S)
    return pred.cpu().numpy()
def get_mask(masks, image):
    """Upsample flattened square masks to the image size and return the first one.

    Args:
        masks: array of shape (n, k*k).
        image: array whose last axis is the (square) spatial side, divisible by k.

    Returns:
        ``masks[0]`` upsampled, shape (1, H, W).

    Raises:
        ValueError: if the image side is not a multiple of the mask side.
    """
    # Reshape/resize masks
    mask_side = math.isqrt(masks.shape[1])
    masks = np.reshape(masks, (masks.shape[0], 1, mask_side, mask_side))
    if image.shape[-1] % mask_side:
        raise ValueError(
            f"image side {image.shape[-1]} is not a multiple of mask side {mask_side}")
    # np.repeat needs an integer count; true division would yield a float.
    scale = image.shape[-1] // mask_side
    masks = np.repeat(masks, scale, axis=2)
    masks = np.repeat(masks, scale, axis=3)
    return masks[0]

for i in range(all_size):
    img = x[i:i+1].cpu().numpy()

    def f_mask(z):
        """Surrogate outputs for each 0/1 superpixel coalition row of ``z``.

        The unmasked image and the coalitions go straight to the surrogate (no
        pixel-level masking here). Large batches are split into 100-row chunks.
        """
        if z.shape[0] == 1:
            return imputer(img, z)
        y_p = []
        for batch_idx in tqdm(range(math.ceil(z.shape[0] / 100))):
            chunk = z[batch_idx * 100:(batch_idx + 1) * 100]
            y_p.append(imputer(img, chunk))
        outputs = np.vstack(y_p)  # stack once, reuse for logging and return
        print(outputs.shape)
        return outputs

    explainer = shap.KernelExplainer(f_mask, np.zeros((1, 16*16)), link='identity')
    shap_values = explainer.shap_values(np.ones((1, 16*16)), nsamples='auto', l1_reg=False)
    # shap_values = [resize_mask(sv, img)  for sv in shap_values]
    label = int(y[i])
    # Older shap returns a list of (1, players) arrays, one per class; newer
    # releases return a single (1, players, classes) array.
    if isinstance(shap_values, list):
        class_values = shap_values[label]
    else:
        class_values = shap_values[..., label]
    m = np.abs(np.stack(shap_values)).max()
    plt.imshow(class_values.reshape(16, 16), cmap='bwr', vmin=-m, vmax=m)
    # plt.title('Kernelshap-S')
    plt.axis('off')
    plt.tight_layout()
    plt.savefig('generation_results/kernelshap-s{}.png'.format(i))
    plt.show()


# %% Deepshap Generation
# Earlier variant using shap.DeepExplainer, kept for reference:
# from captum.attr import DeepLiftShap
# import shap
# import sys
# sys.path.append('..')
# from resnet_deeplift import ResNet18
# model = torch.load('cifar resnet deeplift.pt').to(device)
# model.eval()
# explainer_deep = shap.DeepExplainer(model,) # influence relative to an all-zeros vector
# for i in range(all_size):
#     result = explainer_deep.shap_values(x[i:i+1].to(device))[y[i]][0]
#     result = np.mean(result, axis=0)
#     m = np.abs(result).max()
#     plt.imshow(result, cmap='bwr', vmin=-m, vmax=m)
#     # plt.title('Deepshap')
#     plt.axis('off')
#     plt.savefig('generation_results/deepshap{}.png'.format(i))
#     plt.show()
from captum.attr import DeepLiftShap
import sys
sys.path.append('..')
from resnet_deeplift import ResNet18
# NOTE: this rebinds the global `model`; the IG, SmoothGrad and GradCAM cells below
# explain this network rather than the one loaded at the top of the script.
# map_location lets a GPU-saved checkpoint load on this CPU-only device.
model = torch.load('cifar resnet deeplift.pt', map_location=device).to(device)
model.eval()
explainer = DeepLiftShap(model)
# Ten all-zero baselines, i.e. the per-channel dataset mean after normalisation.
single_tensor = torch.zeros([10, 3, 32, 32], dtype=torch.float32, device=device)
# Track gradients w.r.t. the input; this flag stays set on the global `x`.
x.requires_grad = True
for i in range(all_size):
    result = explainer.attribute(x[i:i+1], target=y[i], baselines=single_tensor)
    # Average over colour channels, then drop the batch axis -> (32, 32).
    result = result.mean(dim=1, keepdim=False)
    Deepshap_values = result.squeeze(0).detach().cpu().numpy()
    # Deepshap_values = nn.AvgPool2d(superpixel)(result).detach().cpu().numpy().squeeze(1)
    m = np.abs(Deepshap_values).max()
    plt.imshow(Deepshap_values, cmap='bwr', vmin=-m, vmax=m)
    # plt.title('Deepshap')
    plt.axis('off')
    plt.tight_layout()
    plt.savefig('generation_results/deepshap{}.png'.format(i))
    plt.show()

# %% IG Generation

from captum.attr import IntegratedGradients
# When cells run top to bottom, `model` is the network loaded in the DeepSHAP cell.
explainer_ig = IntegratedGradients(model)
for i in range(all_size):
    # `x` may require grad (set in the DeepSHAP cell). Take a detached copy instead
    # of torch.tensor(x[...]), which copies too but emits a UserWarning.
    result = explainer_ig.attribute(x[i:i+1].detach().clone(), target=y[i:i+1])
    # Average over colour channels: (1, 3, 32, 32) -> (1, 32, 32).
    result = result.mean(dim=1, keepdim=True)
    Ig_values = result.squeeze(1).detach().cpu().numpy()
    m = np.abs(Ig_values[0]).max()
    plt.imshow(Ig_values[0], cmap='bwr', vmin=-m, vmax=m)
    # plt.title('ig')
    plt.axis('off')
    plt.tight_layout()
    plt.savefig('generation_results/ig{}.png'.format(i))
    plt.show()

#%% SmoothGrad Generation
from captum.attr import IntegratedGradients, NoiseTunnel
# NOTE(review): this is NoiseTunnel wrapped around IntegratedGradients, i.e.
# SmoothGrad-smoothed IG. Classic SmoothGrad wraps plain gradients (Saliency).
# Left as-is so the saved figures keep their current meaning; confirm intent.
explainer_ig = IntegratedGradients(model)
nt = NoiseTunnel(explainer_ig)
for i in range(all_size):
    # Detached copy avoids torch.tensor()'s copy-construct UserWarning.
    result = nt.attribute(x[i:i+1].detach().clone(), target=y[i:i+1],
                          nt_type='smoothgrad', nt_samples=4)
    # Average over colour channels: (1, 3, 32, 32) -> (1, 32, 32).
    result = result.mean(dim=1, keepdim=True)
    Ig_values = result.squeeze(1).detach().cpu().numpy()
    m = np.abs(Ig_values[0]).max()
    plt.imshow(Ig_values[0], cmap='bwr', vmin=-m, vmax=m)
    # plt.title('smoothgrad')
    plt.axis('off')
    plt.tight_layout()
    plt.savefig('generation_results/smoothgrad{}.png'.format(i))
    plt.show()

#%% GradCam Generation
import sys
sys.path.append('..')
from methods.gradcam import GradCAM
Gradcam = GradCAM()
# Layer to hook, given as a path string resolved by methods.gradcam.GradCAM
# (presumably the second batch-norm of the last block in the final stage).
layer_name = 'layers[3][1].bn2'
# Whole batch in one call, on the current global `model` (the DeepSHAP cell's
# network when cells run in order).
result = Gradcam(model, x, layer_name)
mask = result.detach().cpu().numpy()
n_images = len(x)
for i in range(n_images):
    heatmap = mask[i]
    m = np.abs(heatmap).max()
    plt.imshow(heatmap, cmap='bwr', vmin=-m, vmax=m)
    # plt.title('Gradcam')
    plt.axis('off')
    plt.tight_layout()
    plt.savefig('generation_results/gradcam{}.png'.format(i))
    plt.show()