#%% Fastshap
import sys 
sys.path.append("..")
from metric.metric import Metric
import numpy as np
import torch
import torch.nn as nn
import os
from tqdm.auto import tqdm
from fastshap import ImageSurrogate
import pickle as pkl
# Seed every RNG the evaluation touches so the data split and sampling are reproducible.
SEED = 2
# NOTE: PYTHONHASHSEED is read at interpreter start-up, so setting it here only
# affects subprocesses launched afterwards; export it in the shell for the
# current process as well.
os.environ['PYTHONHASHSEED'] = str(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed(SEED)

# Output directory for the pickled per-method metrics. exist_ok avoids the
# check-then-create race of os.path.exists() followed by os.makedirs().
os.makedirs('results_deletion_insertion', exist_ok=True)
#%% Load Model and Surrogate
device = torch.device('cpu')
# The checkpoints are whole pickled modules. map_location lets a checkpoint that
# was saved from a GPU session be loaded on this CPU evaluation device.
model = torch.load('cifar resnet.pt', map_location=device).to(device)
model.eval()
surr = torch.load('cifar surrogate.pt', map_location=device).to(device)
surr.eval()
# Wrap the surrogate network; configured for 32x32 images with 2x2 superpixels.
surrogate = ImageSurrogate(surr, width=32, height=32, superpixel_size=2)


#%% Build the class-balanced evaluation set
import torchvision.datasets as dsets
import torchvision.transforms as transforms
# Transformations (CIFAR-10 per-channel mean/std normalization)
transform_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])
transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])

# Load CIFAR-10. The official test split is randomly halved into a validation
# part and a held-out test part (random_split uses torch's global RNG).
train_set = dsets.CIFAR10('./', train=True, download=True, transform=transform_train)
val_set = dsets.CIFAR10('./', train=False, download=True, transform=transform_test)
val_set, test_set = torch.utils.data.random_split(val_set, [5000, 5000])
num_classes = 10
each_size = 500  # images drawn per class -> num_classes * each_size = 5000 images
batch_size = 100

# Group the held-out test indices by label.
targets = np.array([test_set[i][1] for i in range(len(test_set))])
inds_lists = [np.where(targets == cat)[0] for cat in range(num_classes)]

# Draw each_size indices per class.
# NOTE(review): np.random.choice samples with replacement by default, so some
# images may repeat while others are skipped. replace=False would fail if a
# class has fewer than each_size images in the random 5000-image split.
inds = np.array([np.random.choice(cat_inds, size=each_size) for cat_inds in inds_lists]).flatten()
# `inds` were derived from test_set's labels, so they must index test_set.
xx, yy = zip(*[test_set[ind] for ind in inds])
yy = torch.tensor(yy).to(device)
xx = torch.stack(xx).to(device)
# Form a dataset and an (unshuffled) dataloader over the sampled images.
metric_dataset = torch.utils.data.TensorDataset(xx, yy)
metric_loader = torch.utils.data.DataLoader(metric_dataset, batch_size=batch_size, shuffle=False)
metric = Metric(step_per=5, klen=11, ksig=5, device=device, use_softmax=False, superpixel=2)

#%% Fastshap Evaluation
# UNet is otherwise unused here; presumably it is imported so torch.load can
# unpickle the saved explainer module — confirm before removing.
from unet import UNet
from fastshap.fastshap import FastSHAP

# map_location keeps a GPU-saved checkpoint loadable on the CPU evaluation device.
explainer = torch.load('cifar explainer.pt', map_location=device).to(device)
explainer.eval()
fastshap = FastSHAP(explainer, surrogate, link=nn.Identity())

# Per-batch means of the insertion/deletion curves and AUCs.
insertion_score_fast = []
deletion_score_fast = []
insertion_auc_fast = []
deletion_auc_fast = []
for x, y in tqdm(metric_loader):
    output_dict = metric(model, x, y, layer_name=None, explainer_type='Fastshap', explainer=fastshap)
    insertion_score_fast.append(torch.mean(output_dict['insertion']['score'], dim=0))
    deletion_score_fast.append(torch.mean(output_dict['deletion']['score'], dim=0))
    insertion_auc_fast.append(np.mean(output_dict['insertion']['auc'], axis=0))
    deletion_auc_fast.append(np.mean(output_dict['deletion']['auc'], axis=0))

# Stack the per-batch results: one row per batch.
fastshap_metric = {}
fastshap_metric['insertion'] = {}
fastshap_metric['deletion'] = {}
fastshap_metric['insertion']['score'] = torch.stack(insertion_score_fast).cpu().numpy()
fastshap_metric['deletion']['score'] = torch.stack(deletion_score_fast).cpu().numpy()
fastshap_metric['insertion']['auc'] = np.array(insertion_auc_fast)
fastshap_metric['deletion']['auc'] = np.array(deletion_auc_fast)

with open('results_deletion_insertion/fastshap_metric.pkl', 'wb') as f:
    pkl.dump(fastshap_metric, f)

#%% Simshap Evaluation
from simshap.simshap_sampling import SimSHAPSampling
# map_location keeps a GPU-saved checkpoint loadable on the CPU evaluation device.
explainer_sim = torch.load('cifar simshap.pt', map_location=device).to(device)
explainer_sim.eval()
simshap = SimSHAPSampling(explainer_sim, surrogate, device=device)

# Per-batch means of the insertion/deletion curves and AUCs.
insertion_score_sim = []
deletion_score_sim = []
insertion_auc_sim = []
deletion_auc_sim = []
for x, y in tqdm(metric_loader):
    output_dict = metric(model, x, y, layer_name=None, explainer_type='Simshap', explainer=simshap)
    insertion_score_sim.append(torch.mean(output_dict['insertion']['score'], dim=0))
    deletion_score_sim.append(torch.mean(output_dict['deletion']['score'], dim=0))
    insertion_auc_sim.append(np.mean(output_dict['insertion']['auc'], axis=0))
    deletion_auc_sim.append(np.mean(output_dict['deletion']['auc'], axis=0))

# Stack the per-batch results: one row per batch.
simshap_metric = {}
simshap_metric['insertion'] = {}
simshap_metric['deletion'] = {}
simshap_metric['insertion']['score'] = torch.stack(insertion_score_sim).cpu().numpy()
simshap_metric['deletion']['score'] = torch.stack(deletion_score_sim).cpu().numpy()
simshap_metric['insertion']['auc'] = np.array(insertion_auc_sim)
simshap_metric['deletion']['auc'] = np.array(deletion_auc_sim)
print('insertion auc mean:', np.mean(simshap_metric['insertion']['auc']))
print('insertion auc std:', np.std(simshap_metric['insertion']['auc']))
print('deletion auc mean:', np.mean(simshap_metric['deletion']['auc']))
print('deletion auc std:', np.std(simshap_metric['deletion']['auc']))

with open('results_deletion_insertion/simshap_metric.pkl', 'wb') as f:
    pkl.dump(simshap_metric, f)

#%% Grad-cam Evaluation
from methods.gradcam import GradCAM

Gradcam = GradCAM()
# Layer name handed to the metric for the Grad-CAM explanation.
layer_name = 'layers[3][1].bn2'

# Collect, per batch, the mean curve and mean AUC for each phase.
score_batches = {'insertion': [], 'deletion': []}
auc_batches = {'insertion': [], 'deletion': []}
for x, y in tqdm(metric_loader):
    output_dict = metric(model, x, y, layer_name=layer_name, explainer_type='gradcam', explainer=Gradcam)
    for phase in ('insertion', 'deletion'):
        score_batches[phase].append(torch.mean(output_dict[phase]['score'], dim=0))
        auc_batches[phase].append(np.mean(output_dict[phase]['auc'], axis=0))

# One row per batch for both the curves and the AUCs.
gradcam_metric = {
    phase: {
        'score': torch.stack(score_batches[phase]).cpu().numpy(),
        'auc': np.array(auc_batches[phase]),
    }
    for phase in ('insertion', 'deletion')
}

# save
with open('results_deletion_insertion/gradcam_metric.pkl', 'wb') as f:
    pkl.dump(gradcam_metric, f)


#%% IG Evaluation
from captum.attr import IntegratedGradients, NoiseTunnel

explainer_ig = IntegratedGradients(model)

# Per-batch mean curves and AUCs, keyed by phase.
phases = ('insertion', 'deletion')
batch_scores = {phase: [] for phase in phases}
batch_aucs = {phase: [] for phase in phases}
for x, y in tqdm(metric_loader):
    output_dict = metric(model, x, y, layer_name=None, explainer_type='IG', explainer=explainer_ig)
    for phase in phases:
        batch_scores[phase].append(torch.mean(output_dict[phase]['score'], dim=0))
        batch_aucs[phase].append(np.mean(output_dict[phase]['auc'], axis=0))

ig_metric = {}
for phase in phases:
    ig_metric[phase] = {
        'score': torch.stack(batch_scores[phase]).cpu().numpy(),
        'auc': np.array(batch_aucs[phase]),
    }

# save
with open('results_deletion_insertion/ig_metric.pkl', 'wb') as f:
    pkl.dump(ig_metric, f)

#%% Deepshap Evaluation
import shap

# DeepExplainer uses its own checkpoint. Keep it in a separate variable so the
# original `model` used by the later cells (SmoothGrad, Kernelshap) is not replaced.
model_deep = torch.load('cifar resnet deeplift.pt', map_location=device).to(device)
model_deep.eval()
# Background set: 20 all-zero images, so attributions are relative to an all-zero
# input (after Normalize, zero corresponds to the per-channel mean colour).
explainer_deep = shap.DeepExplainer(model_deep, torch.zeros(20, 3, 32, 32, device=device))

# Per-batch means of the insertion/deletion curves and AUCs.
insertion_score_deep = []
deletion_score_deep = []
insertion_auc_deep = []
deletion_auc_deep = []
for x, y in tqdm(metric_loader):
    output_dict = metric(model_deep, x, y, layer_name=None, explainer_type='Deepshap', explainer=explainer_deep)
    insertion_score_deep.append(torch.mean(output_dict['insertion']['score'], dim=0))
    deletion_score_deep.append(torch.mean(output_dict['deletion']['score'], dim=0))
    insertion_auc_deep.append(np.mean(output_dict['insertion']['auc'], axis=0))
    deletion_auc_deep.append(np.mean(output_dict['deletion']['auc'], axis=0))

deep_metric = {}
deep_metric['insertion'] = {}
deep_metric['deletion'] = {}
deep_metric['insertion']['score'] = torch.stack(insertion_score_deep).cpu().numpy()
deep_metric['deletion']['score'] = torch.stack(deletion_score_deep).cpu().numpy()
deep_metric['insertion']['auc'] = np.array(insertion_auc_deep)
# Required by the AUC report, which reads deep_metric['deletion']['auc'].
deep_metric['deletion']['auc'] = np.array(deletion_auc_deep)

# save
with open('results_deletion_insertion/deep_metric.pkl', 'wb') as f:
    pkl.dump(deep_metric, f)

#%% SmoothGrad Evaluation
from captum.attr import IntegratedGradients, NoiseTunnel
# "SmoothGrad" here is a NoiseTunnel wrapped around IntegratedGradients.
explainer_ig = IntegratedGradients(model)
nt = NoiseTunnel(explainer_ig)

curve_means = {'insertion': [], 'deletion': []}
auc_means = {'insertion': [], 'deletion': []}
for x, y in tqdm(metric_loader):
    output_dict = metric(model, x, y, layer_name=None, explainer_type='SmoothGrad', explainer=nt)
    for phase, bucket in curve_means.items():
        bucket.append(torch.mean(output_dict[phase]['score'], dim=0))
    for phase, bucket in auc_means.items():
        bucket.append(np.mean(output_dict[phase]['auc'], axis=0))

smoothgrad_metric = {}
for phase in ('insertion', 'deletion'):
    smoothgrad_metric[phase] = {}
    smoothgrad_metric[phase]['score'] = torch.stack(curve_means[phase]).cpu().numpy()
    smoothgrad_metric[phase]['auc'] = np.array(auc_means[phase])

# save
with open('results_deletion_insertion/smoothgrad_metric.pkl', 'wb') as f:
    pkl.dump(smoothgrad_metric, f)
#%% Kernelshap Evaluation(Original Model)
import shap
def model_wrapper(x):
    """Run the classifier on an array batch and return its outputs as NumPy.

    Passed as the explainer callable for the Kernelshap evaluation.

    Args:
        x: array-like image batch; converted to a float32 tensor on `device`.

    Returns:
        np.ndarray holding the raw outputs of the module-level `model`.
    """
    x = torch.tensor(x, dtype=torch.float32, device=device)
    # Inference only: no autograd graph is needed for the perturbed batches.
    with torch.no_grad():
        pred = model(x)
    return pred.cpu().numpy()

# Per-batch mean curves / AUCs for Kernelshap on the original classifier.
kernelshap_batches = {
    phase: {'score': [], 'auc': []} for phase in ('insertion', 'deletion')
}
for x, y in tqdm(metric_loader):
    output_dict = metric(model, x, y, layer_name=None, explainer_type='Kernelshap', explainer=model_wrapper)
    for phase, store in kernelshap_batches.items():
        store['score'].append(torch.mean(output_dict[phase]['score'], dim=0))
        store['auc'].append(np.mean(output_dict[phase]['auc'], axis=0))

kernelshap_metric = {}
for phase, store in kernelshap_batches.items():
    kernelshap_metric[phase] = {
        'score': torch.stack(store['score']).cpu().numpy(),
        'auc': np.array(store['auc']),
    }

# save
with open('results_deletion_insertion/kernelshap_metric.pkl', 'wb') as f:
    pkl.dump(kernelshap_metric, f)
#%% Kernelshap-S Evaluation
import shap
def imputer(x, S):
    """Evaluate the surrogate on an array batch and return its outputs as NumPy.

    Passed as the explainer callable for the Kernelshap-S evaluation.

    Args:
        x: array-like image batch; converted to a float32 tensor on `device`.
        S: array-like second argument for `surrogate`, also converted to float32
           (presumably the superpixel subset mask — confirm against ImageSurrogate).

    Returns:
        np.ndarray holding the outputs of the module-level `surrogate`.
    """
    x = torch.tensor(x, dtype=torch.float32, device=device)
    S = torch.tensor(S, dtype=torch.float32, device=device)
    # Inference only: no autograd graph is needed for the imputed batches.
    with torch.no_grad():
        pred = surrogate(x, S)
    return pred.cpu().numpy()

# Per-batch mean curves / AUCs for Kernelshap driven by the surrogate imputer.
kernelS_scores = {'insertion': [], 'deletion': []}
kernelS_aucs = {'insertion': [], 'deletion': []}
for x, y in tqdm(metric_loader):
    output_dict = metric(model, x, y, layer_name=None, explainer_type='Kernelshap-S', explainer=imputer)
    for phase in kernelS_scores:
        kernelS_scores[phase].append(torch.mean(output_dict[phase]['score'], dim=0))
        kernelS_aucs[phase].append(np.mean(output_dict[phase]['auc'], axis=0))

kernelS_metric = {
    phase: {
        'score': torch.stack(kernelS_scores[phase]).cpu().numpy(),
        'auc': np.array(kernelS_aucs[phase]),
    }
    for phase in kernelS_scores
}

# save
with open('results_deletion_insertion/kernelS_metric.pkl', 'wb') as f:
    pkl.dump(kernelS_metric, f)

#%% Load
def _load_metric(path):
    """Unpickle one of the metric dictionaries written by the evaluation cells."""
    # These files are produced locally by this script, so unpickling is trusted.
    with open(path, 'rb') as f:
        return pkl.load(f)


fastshap_metric = _load_metric('results_deletion_insertion/fastshap_metric.pkl')
simshap_metric = _load_metric('results_deletion_insertion/simshap_metric.pkl')
gradcam_metric = _load_metric('results_deletion_insertion/gradcam_metric.pkl')
ig_metric = _load_metric('results_deletion_insertion/ig_metric.pkl')
deep_metric = _load_metric('results_deletion_insertion/deep_metric.pkl')
smoothgrad_metric = _load_metric('results_deletion_insertion/smoothgrad_metric.pkl')
kernelshap_metric = _load_metric('results_deletion_insertion/kernelshap_metric.pkl')
kernelS_metric = _load_metric('results_deletion_insertion/kernelS_metric.pkl')
#%% Insertion
import seaborn as sns
import pickle as pkl
import numpy as np
import matplotlib.pyplot as plt
from scipy.ndimage import gaussian_filter1d
def normalize(data, mtype='del'):
    """Rescale each row of a 2-D curve array so its endpoints map to 0 and 1.

    For ``mtype="del"`` each row is shifted so its last value becomes 0, then
    divided by its (shifted) first value. For ``mtype="ins"`` the roles of the
    first and last columns are swapped. Any other ``mtype`` returns ``data``
    unchanged (the same object).
    """
    if mtype not in ("del", "ins"):
        return data
    # Column used as the zero level, and column that is scaled to 1.
    base_col, top_col = (-1, 0) if mtype == "del" else (0, -1)
    shifted = data - data[:, [base_col]]
    return shifted / shifted[:, [top_col]]

sns.set_theme(style="white")
sns.despine()
ax = plt.gca()
ax.spines['right'].set_color('none')
ax.spines['top'].set_color('none')

# Percent of features inserted at each step, with the final 100% point appended.
# 5 * 100 / 256 presumably mirrors Metric(step_per=5) over the 256 2x2
# superpixels of a 32x32 image — confirm against Metric.
x_axis = np.arange(0, 100, 5 * 100 / 256)  # inclusion percent
x_axis = np.append(x_axis, 100)

# (results, label, colour, gaussian sigma). Each curve is the mean over the
# stored per-batch curves, smoothed for display.
# NOTE(review): IG and Kernelshap use larger sigmas than the other methods, so
# they are smoothed more — confirm this is intended for a fair comparison.
# Alternative view: np.median(normalize(score, mtype='ins'), axis=0).
insertion_curves = [
    (fastshap_metric, 'Fastshap', 'tab:green', 1),
    (simshap_metric, 'Simshap', 'tab:red', 1),
    (gradcam_metric, 'Gradcam', 'tab:orange', 1),
    (ig_metric, 'IG', 'tab:olive', 3),
    (smoothgrad_metric, 'SmoothGrad', 'tab:purple', 1),
    # tab:blue keeps Deepshap distinct from Fastshap and matches the deletion plot.
    (deep_metric, 'Deepshap', 'tab:blue', 1),
    (kernelshap_metric, 'Kernelshap', 'tab:brown', 2),
    (kernelS_metric, 'Kernelshap-S', 'tab:pink', 1),
]
for results, label, color, sigma in insertion_curves:
    y_smoothed = gaussian_filter1d(np.mean(results['insertion']['score'], axis=0), sigma=sigma)
    plt.plot(x_axis, y_smoothed, label=label, color=color)

plt.title('Insertion Score', fontsize=15, fontweight='bold')
plt.xlabel('Fraction of Features')
font_legend = {'weight': 'normal', 'size': 10}
plt.legend(loc='best', frameon=False, prop=font_legend)

plt.tight_layout()
plt.savefig('cifar insertion.png', dpi=300)
plt.savefig('cifar insertion.pdf')
plt.show()
#%% Deletion
import seaborn as sns
import pickle as pkl
import numpy as np
import matplotlib.pyplot as plt
from scipy.ndimage import gaussian_filter1d
plt.clf()
sns.set_theme(style="white")
sns.despine()
ax = plt.gca()
ax.spines['right'].set_color('none')
ax.spines['top'].set_color('none')

# Percent of features removed at each step, with the final 100% point appended
# (same grid as the insertion plot).
x_axis = np.arange(0, 100, 5 * 100 / 256)
x_axis = np.append(x_axis, 100)

# (results, label, colour, gaussian sigma). Each curve is the mean over the
# stored per-batch curves, smoothed for display.
deletion_curves = [
    (fastshap_metric, 'Fastshap', 'tab:green', 1),
    (simshap_metric, 'Simshap', 'tab:red', 1),
    (gradcam_metric, 'Gradcam', 'tab:orange', 1),
    (ig_metric, 'IG', 'tab:olive', 2),
    (smoothgrad_metric, 'SmoothGrad', 'tab:purple', 1),
    (deep_metric, 'Deepshap', 'tab:blue', 1),
    (kernelshap_metric, 'Kernelshap', 'tab:brown', 1),
    (kernelS_metric, 'Kernelshap-S', 'tab:pink', 1),
]
for results, label, color, sigma in deletion_curves:
    y_smoothed = gaussian_filter1d(np.mean(results['deletion']['score'], axis=0), sigma=sigma)
    plt.plot(x_axis, y_smoothed, label=label, color=color)

plt.title('Deletion Score', fontsize=15, fontweight='bold')
plt.xlabel('Fraction of Features')
font_legend = {'weight': 'normal', 'size': 10}
plt.legend(loc='best', frameon=False, prop=font_legend)

# Same layout step as the insertion figure; without it the saved figure may
# clip the title or axis labels.
plt.tight_layout()
plt.savefig('cifar deletion.png', dpi=300)
plt.savefig('cifar deletion.pdf')
plt.show()
#%% print aucs
# Report the mean and std of the stored AUC values for every method.
# NOTE(review): each stored AUC is already a per-batch mean (see the evaluation
# cells), so the std reported here is across batches, not across images.
for method_name, results in [
    ('Fastshap', fastshap_metric),
    ('Simshap', simshap_metric),
    ('Gradcam', gradcam_metric),
    ('IG', ig_metric),
    ('SmoothGrad', smoothgrad_metric),
    ('Deepshap', deep_metric),
    ('Kernelshap', kernelshap_metric),
    ('Kernelshap-S', kernelS_metric),
]:
    for phase in ('insertion', 'deletion'):
        print(f'{method_name} {phase} auc mean:', np.mean(results[phase]['auc']))
        print(f'{method_name} {phase} auc std:', np.std(results[phase]['auc']))
