#%%
import sys 
sys.path.append("..")

import numpy as np
import torch
import torch.nn as nn
from fastshap import ImageSurrogate
import time
#%% Load Model and Surrogate
# Use the GPU when present, otherwise fall back to CPU. map_location lets
# checkpoints that were saved from a GPU be loaded on whichever device is used.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# NOTE(review): these files are loaded as whole modules (they are moved with
# .to() and called directly below). On torch >= 2.6, torch.load defaults to
# weights_only=True and may need weights_only=False for such files.
model = torch.load('cifar resnet.pt', map_location=device).to(device)
surr = torch.load('cifar surrogate.pt', map_location=device).to(device)
# Every benchmark below is inference-only. Put both networks in eval mode so
# that layers such as BatchNorm use inference behaviour and their running
# statistics are not modified by the timed forward passes.
model.eval()
surr.eval()
# 32x32 images explained on 2x2 superpixels.
surrogate = ImageSurrogate(surr, width=32, height=32, superpixel_size=2)


#%% 
import torchvision.datasets as dsets
import torchvision.transforms as transforms
# Preprocessing pipelines. Training adds random crop + horizontal flip;
# evaluation only converts to a tensor. Both normalise with the same
# per-channel mean/std.
transform_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])
transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])

# Both CIFAR-10 splits are downloaded into the working directory if missing.
# The test split serves as the validation set for this benchmark.
train_set = dsets.CIFAR10('./', train=True, download=True, transform=transform_train)
val_set = dsets.CIFAR10('./', train=False, download=True, transform=transform_test)

# Draw 100 test indices per class (10 classes -> 1000 images) with a fixed seed.
# NOTE(review): np.random.choice samples WITH replacement by default, so a class
# may contribute duplicate images. Passing replace=False would give distinct
# images but would also change which images are selected.
np.random.seed(0)
num_classes = 10
targets = np.array(val_set.targets)
inds_lists = [np.flatnonzero(targets == cat) for cat in range(num_classes)]
inds = np.concatenate([np.random.choice(cat_inds, size=100) for cat_inds in inds_lists])

# Materialise the selected (image, label) pairs as device tensors.
x, y = zip(*(val_set[ind] for ind in inds))
x = torch.stack(x).to(device)
y = torch.tensor(y).to(device)

#%% fastshap speed

import sys
sys.path.append('..')
from fastshap.fastshap import FastSHAP
import time
explainer = torch.load('cifar fastshap.pt', map_location=device).to(device)
fastshap = FastSHAP(explainer, surrogate, link=nn.Identity())

# CUDA kernels run asynchronously. Drain pending GPU work before starting the
# clock, and wait for this method's work before stopping it, so the interval
# covers exactly the explanation step. perf_counter is the clock intended for
# measuring durations (wall-clock time.time() can jump).
if device.type == 'cuda':
    torch.cuda.synchronize()
start = time.perf_counter()
fastshap_values = fastshap.shap_values(x)  # x was already moved to `device`
if device.type == 'cuda':
    torch.cuda.synchronize()
end = time.perf_counter()
print('fastshap time: ', end - start)


# %% simshap speed
from simshap.simshap_sampling import SimSHAPSampling
explainer = torch.load('cifar simshap.pt', map_location=device).to(device)
simshap = SimSHAPSampling(explainer, surrogate, device=device)

# Synchronise around the timed region: CUDA work is asynchronous, so without
# this the clock could stop before the GPU has finished.
if device.type == 'cuda':
    torch.cuda.synchronize()
start = time.perf_counter()
simshap_values = simshap.shap_values(x)
if device.type == 'cuda':
    torch.cuda.synchronize()
end = time.perf_counter()
print('simshap time: ', end - start)

#%% Kernelshap
import shap
from tqdm.auto import tqdm
import math
def mask_image(masks, image, background=None):  # for kernelshap
    """Upsample flat superpixel masks to pixel resolution and apply them to `image`.

    Args:
        masks: array of shape (n_masks, d), where d is a perfect square. Each
            row is a flattened sqrt(d) x sqrt(d) grid of keep (1) / drop (0)
            values.
        image: array whose last axis is the image width. The width must be an
            integer multiple of sqrt(d). In this script it is a (1, C, H, W)
            slice of the evaluation batch.
        background: optional fill for dropped pixels.
            * None: dropped pixels become 0.
            * 3-D array: ``background[0]`` is the fill used for every mask.
            * otherwise: an iterable of fills; each mask yields one image per
              fill.

    Returns:
        With no background, ``np.vstack`` of ``mask * image`` over all masks
        (shape (n_masks, C, H, W) for a (1, C, H, W) image). With a 3-D
        background, the same with an extra singleton axis per mask. Otherwise,
        a list with one stacked array (one entry per fill) per mask.

    Raises:
        ValueError: if the image width is not a multiple of the mask grid side.
    """
    side = int(masks.shape[1] ** .5)
    masks = np.reshape(masks, (masks.shape[0], 1, side, side))
    # np.repeat requires an integer count. True division would pass a float
    # (e.g. 32 / 16 == 2.0), which NumPy rejects, so use integer division and
    # fail loudly if the sizes do not divide evenly.
    width = image.shape[-1]
    if width % side:
        raise ValueError(
            'image width {} is not a multiple of mask side {}'.format(width, side)
        )
    scale = width // side
    masks = np.repeat(masks, scale, axis=2)
    masks = np.repeat(masks, scale, axis=3)

    if background is None:
        # Dropped pixels are simply zeroed.
        return np.vstack([mask * image for mask in masks])

    if len(background.shape) == 3:
        # A single fill (background[0]) shared by every mask.
        return np.vstack([
            np.expand_dims((mask * image) + ((1 - mask) * background[0]), 0)
            for mask in masks
        ])

    # Several fills: for each mask, stack one masked image per fill.
    masked_images = []
    for mask in masks:
        fills = [im * (1 - mask) for im in background]
        masked_images.append(
            np.vstack([np.expand_dims((mask * image) + fill, 0) for fill in fills])
        )
    return masked_images
def model_wrapper(x):
    """Run the module-level `model` on a NumPy batch and return NumPy outputs.

    Args:
        x: array-like batch convertible to a float32 tensor. In this script it
            is the output of ``mask_image`` for a (1, 3, 32, 32) image.

    Returns:
        NumPy array holding the model's outputs for the batch.
    """
    # Inference only: no_grad avoids building an autograd graph that would be
    # discarded immediately. That graph would waste memory and inflate the
    # measured KernelSHAP time.
    with torch.no_grad():
        inputs = torch.tensor(x, dtype=torch.float32, device=device)
        return model(inputs).cpu().numpy()
start = time.perf_counter()
for i in range(x.shape[0]):
    image = x[i:i+1].cpu().numpy()
    background = None  # None -> mask_image zeroes the dropped superpixels

    def f_mask(z):
        """Model outputs for `image` under each superpixel coalition row of `z`.

        `z` is the (n, 256) coalition matrix supplied by KernelExplainer. Rows
        are evaluated in chunks of 100. A progress bar is shown only for
        multi-row calls, as before.
        """
        n_chunks = int(math.ceil(z.shape[0] / 100))
        outputs = []
        for b in tqdm(range(n_chunks), disable=z.shape[0] == 1):
            chunk = z[b*100:(b+1)*100]
            masked_images = mask_image(chunk, image, background)
            if background is None or len(background.shape) == 3:
                outputs.append(model_wrapper(masked_images))
            else:
                # Several background fills per mask: average the outputs over fills.
                outputs.extend(np.mean(model_wrapper(mi), 0) for mi in masked_images)
        return np.vstack(outputs)

    # 256 features = a 16x16 grid of 2x2 superpixels over the 32x32 image. The
    # reference is all-dropped (zeros) and the explained point is all-kept (ones).
    explainer_kernelshaps = shap.KernelExplainer(f_mask, np.zeros((1, 256)))
    shap_values = explainer_kernelshaps.shap_values(np.ones((1, 256)), nsamples='auto')
end = time.perf_counter()
print('kernelshap time: ', end - start)
#%% Kernelshap-S 
import shap
from tqdm.auto import tqdm
import math

def imputer(x, S):
    """Evaluate the module-level `surrogate` on image(s) `x` with coalitions `S`.

    Args:
        x: array-like image batch, converted to a float32 tensor. In this
            script it is a single (1, 3, 32, 32) image.
        S: array-like coalition matrix, converted to a float32 tensor. In this
            script its rows are 256-long keep/drop vectors from KernelExplainer.

    Returns:
        NumPy array of the surrogate's outputs.
    """
    # Inference only: skip autograd bookkeeping so it is not part of the timing.
    with torch.no_grad():
        x = torch.tensor(x, dtype=torch.float32, device=device)
        S = torch.tensor(S, dtype=torch.float32, device=device)
        return surrogate(x, S).cpu().numpy()
start = time.perf_counter()
for i in range(x.shape[0]):
    image = x[i:i+1].cpu().numpy()

    def f_mask(z):
        """Surrogate outputs for `image` under each coalition row of `z`.

        `z` is the (n, 256) coalition matrix supplied by KernelExplainer. Rows
        are passed to `imputer` in chunks of 100.
        """
        if z.shape[0] == 1:
            return imputer(image, z)
        n_chunks = int(math.ceil(z.shape[0] / 100))
        return np.vstack([imputer(image, z[b*100:(b+1)*100]) for b in tqdm(range(n_chunks))])

    # Same 256-superpixel setup as plain KernelSHAP, but masked inputs are
    # evaluated through the surrogate instead of the original classifier.
    explainer_kernelshaps = shap.KernelExplainer(f_mask, np.zeros((1, 256)))
    shap_values = explainer_kernelshaps.shap_values(np.ones((1, 256)), nsamples='auto')
end = time.perf_counter()
print('kernelshap-S time: ', end - start)
#%% GradCAM
import sys
sys.path.append('..')
import time
from methods.gradcam import GradCAM

gradcam = GradCAM()
# Target layer, written as an attribute path into `model`
# (presumably resolved by GradCAM; confirm against methods.gradcam).
layer_name = 'layers[3][1].bn2'

# GPU work is asynchronous: synchronise before starting and before stopping
# the clock so the measured interval covers the whole batched GradCAM call.
if device.type == 'cuda':
    torch.cuda.synchronize()
start = time.perf_counter()
gradcam_values = gradcam(model, x, layer_name)
if device.type == 'cuda':
    torch.cuda.synchronize()
end = time.perf_counter()
print('gradcam time: ', end - start)

#%% IG speed
from captum.attr import IntegratedGradients, NoiseTunnel
import time
from tqdm.auto import tqdm
explainer_ig = IntegratedGradients(model)

# Captum returns GPU tensors without forcing a sync, so synchronise explicitly
# or the clock may stop while kernels are still queued.
if device.type == 'cuda':
    torch.cuda.synchronize()
start = time.perf_counter()
# One image at a time, attributing to its true label. Iterate over the actual
# batch size rather than a hard-coded count.
for i in tqdm(range(x.shape[0])):
    ig_values = explainer_ig.attribute(x[i:i+1], target=y[i:i+1])
if device.type == 'cuda':
    torch.cuda.synchronize()
end = time.perf_counter()
print('ig time: ', end - start)

#%% smoothgrad
from captum.attr import IntegratedGradients, NoiseTunnel
import time
from tqdm.auto import tqdm
explainer_ig = IntegratedGradients(model)
# NoiseTunnel wraps IntegratedGradients here, so each image runs 4 noisy IG
# attributions (nt_samples=4) and smooths them ('smoothgrad').
nt = NoiseTunnel(explainer_ig)

if device.type == 'cuda':
    torch.cuda.synchronize()
start = time.perf_counter()
for i in tqdm(range(x.shape[0])):
    ig_values = nt.attribute(x[i:i+1], target=y[i:i+1], nt_type='smoothgrad', nt_samples=4)
# Wait for queued CUDA work so it is included in the measurement.
if device.type == 'cuda':
    torch.cuda.synchronize()
end = time.perf_counter()
print('smoothgrad time: ', end - start)

#%% deepshap
import shap
from tqdm.auto import tqdm  # imported here too so this cell runs on its own

# Separate checkpoint for DeepSHAP. This rebinds the global `model`, so any
# cell run after this one sees this network.
model = torch.load('cifar resnet deeplift.pt', map_location=device).to(device)
model.eval()
# Baseline of 20 all-zero images: attributions measure each input's effect
# relative to an all-zero input.
explainer_deep = shap.DeepExplainer(model, torch.zeros(20, 3, 32, 32, device=device))

if device.type == 'cuda':
    torch.cuda.synchronize()
start = time.perf_counter()
for i in tqdm(range(x.shape[0])):
    deep_values = explainer_deep.shap_values(x[i:i+1])
if device.type == 'cuda':
    torch.cuda.synchronize()
end = time.perf_counter()
print('deepshap time: ', end - start)