#%% 
import shap  # https://github.com/slundberg/shap
import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
#%%
# Load the raw dataset from CSV; the label column 'Flag' is split off below.
file = 'Taiwan_data_ENG_95.csv'
data = pd.read_csv(filepath_or_buffer=file, encoding='utf-8')

#%%
# Feature matrix = every column except the 'Flag' label; label vector = 'Flag'.
X = np.array(data.drop(columns=['Flag']))
Y = np.array(data['Flag'])

# 80/20 train/test split, then 20% of the training part becomes validation.
# NOTE(review): keep these seeds unchanged -- the saved models loaded later are
# presumably trained on exactly this split; a different split would leak data.
X_train, X_test, Y_train, Y_test = train_test_split(
    X, Y,
    test_size=0.2,
    random_state=70,
)
X_train, X_val, Y_train, Y_val = train_test_split(
    X_train, Y_train,
    test_size=0.2,
    random_state=42,
)

# Data scaling: standardize with statistics estimated on the training split
# only, then apply the identical transform to validation and test.
num_features = X_train.shape[1]
ss = StandardScaler()
X_train = ss.fit_transform(X_train)
X_val, X_test = ss.transform(X_val), ss.transform(X_test)

#%% Load Model
import sys
# Presumably the local packages (fastshap, simshap, unet) live one directory
# up; extend the search path before they are imported.
sys.path.append("..")
import time
import numpy as np
import torch
import torch.nn as nn
import pandas as pd
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import train_test_split
from tqdm.auto import tqdm
from fastshap import Surrogate
import pickle as pkl
import os

# Seed the NumPy and torch RNGs so random choices made later in the script
# are reproducible.
# NOTE: PYTHONHASHSEED is only read at interpreter startup, so setting it here
# affects child processes, not the hash seed of the running interpreter.
os.environ['PYTHONHASHSEED'] = str(2)
np.random.seed(2)
torch.manual_seed(2)
torch.cuda.manual_seed(2)

# Use the GPU when present, otherwise fall back to CPU instead of failing later
# on the first `.to(device)`.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# SECURITY: pickle.load can execute arbitrary code -- only load trusted files.
with open('bank model.pkl', 'rb') as f:
    model = pkl.load(f)
def original_model(x):
    """Run the pickled model on a torch batch and return two-column scores.

    The input tensor is moved to the CPU and converted to NumPy for
    ``model.predict``. For a 1-D prediction vector ``p`` the result has rows
    ``[1 - p, p]``. It is returned as a float32 tensor on ``x``'s device.

    NOTE(review): the NumPy round trip detaches the result from the autograd
    graph, so gradient-based attribution methods cannot differentiate through
    this function. Whether ``predict`` yields hard labels or probabilities
    depends on the pickled model -- confirm.
    """
    features = x.detach().cpu().numpy()
    positive = model.predict(features)
    scores = np.stack((1 - positive, positive)).T
    return torch.tensor(scores, dtype=torch.float32, device=x.device)
# Load the trained surrogate network onto `device`; map_location lets a
# checkpoint saved on a different device load correctly.
# NOTE: torch>=2.6 defaults torch.load to weights_only=True, which rejects fully
# pickled modules; pass weights_only=False (trusted files only) if that occurs.
surr = torch.load('bank surrogate.pt', map_location=device).to(device)
surr.eval()
# Use the data's feature count rather than a hard-coded 94.
surrogate = Surrogate(surr, num_features)

# Number of test rows to explain. NOTE: the name `all` shadows the builtin but
# later cells reference it, so it is kept. Clamp so sampling without
# replacement never asks for more rows than exist.
all = min(1000, len(X_test))
# Sample distinct rows so no instance is explained (and timed) twice.
ind = np.random.choice(len(X_test), size=all, replace=False)

x = X_test[ind]
y = Y_test[ind]

#%% Fastshap
# NOTE(review): UNet is not referenced directly; kept in case unpickling the
# saved explainer depends on it being importable.
from unet import UNet
from simshap.fastshap_plus import FastSHAP

explainer = torch.load('bank fastshap.pt', map_location=device).to(device)
explainer.eval()
fastshap = FastSHAP(explainer, surrogate, link=nn.Identity(), normalization='additive')

# perf_counter is monotonic and high-resolution; time.time() is wall-clock time
# and can jump, so it is a poor choice for benchmarking.
start = time.perf_counter()
fastshap_values = fastshap.shap_values(x)
end = time.perf_counter()
print('fastshap running time:', end - start)
#%% Simshap
from simshap.simshap_sampling import SimSHAPSampling

explainer_sim = torch.load('bank goodsimshap.pt', map_location=device).to(device)
explainer_sim.eval()
simshap = SimSHAPSampling(explainer_sim, surrogate, device=device)

# Monotonic, high-resolution clock for timing (see time.perf_counter docs).
start = time.perf_counter()
simshap_values = simshap.shap_values(x)
end = time.perf_counter()
print('simshap running time:', end - start)
#%% Kernelshap
import shap


def model_wrapper(x):
    """NumPy-in / NumPy-out adapter around ``original_model`` for KernelExplainer."""
    x = torch.tensor(x, dtype=torch.float32, device=device)
    pred = original_model(x)
    return pred.detach().cpu().numpy()


# Single background row: the per-feature median of the standardized training data.
med = np.median(X_train, axis=0).reshape((1, -1))
kernelshap = shap.KernelExplainer(model_wrapper, med)
# Monotonic, high-resolution clock for timing.
start = time.perf_counter()
kernelshap_values = kernelshap.shap_values(x, nsamples='auto')
end = time.perf_counter()
print('kernelshap running time:', end - start)
#%% Kernelshap-S
def imputer(x, S):
    """Evaluate the surrogate on rows ``x`` under feature masks ``S``.

    Both arguments are NumPy arrays, converted to float32 tensors on
    ``device``. The surrogate output is returned as a NumPy array.
    """
    x = torch.tensor(x, dtype=torch.float32, device=device)
    S = torch.tensor(S, dtype=torch.float32, device=device)
    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        pred = surrogate(x, S)
    return pred.cpu().numpy()


# KernelSHAP-S explains the surrogate as a function of the coalition mask.
# KernelExplainer builds synthetic rows by mixing the explained instance with
# the background. With an all-zeros background and an all-ones "instance", every
# synthetic row is exactly a 0/1 mask. Explaining x[i] directly would instead
# pass x[i] * mask (real-valued features) to the surrogate as S.
n_feat = x.shape[1]
mask_background = np.zeros((1, n_feat))
mask_instance = np.ones((1, n_feat))
kernelshap_s_values = []
start = time.perf_counter()
for i in range(len(x)):
    x_row = x[i:i + 1]

    def f_mask(z, x_row=x_row):
        # Pair the current row with every mask in the batch z (bound explicitly
        # so the closure does not depend on the loop variable).
        return imputer(np.repeat(x_row, len(z), axis=0), z)

    kernelshap_s = shap.KernelExplainer(f_mask, mask_background)
    kernelshap_s_values.append(
        kernelshap_s.shap_values(mask_instance, nsamples='auto'))
end = time.perf_counter()
print('kernelshap-s running time:', end - start)
#%% IG
from captum.attr import IntegratedGradients

# NOTE(review): original_model converts its input to NumPy and returns a new
# tensor, so its output has no autograd path back to the input. Captum needs
# that path to compute gradients, so this call is expected to raise unless a
# differentiable model (e.g. the surrogate) is used instead -- confirm.
ig = IntegratedGradients(original_model)
start = time.perf_counter()
# The tensor conversion stays inside the timed region, matching the other
# methods, which receive NumPy input.
inputs = torch.tensor(x, dtype=torch.float32, device=device)
targets = torch.tensor(y, dtype=torch.long, device=device)
ig_values = ig.attribute(inputs, target=targets)
end = time.perf_counter()
print('ig running time:', end - start)

#%% SmoothGrad
from captum.attr import IntegratedGradients, NoiseTunnel

# NOTE(review): as with the IG cell, original_model is not differentiable
# (NumPy round trip), so gradient-based attribution is expected to fail -- confirm.
ig = IntegratedGradients(original_model)
sg = NoiseTunnel(ig)
start = time.perf_counter()
inputs = torch.tensor(x, dtype=torch.float32, device=device)
# Captum accepts int/tuple/list/Tensor targets; a raw NumPy array is rejected.
targets = torch.tensor(y, dtype=torch.long, device=device)
sg_values = sg.attribute(inputs, nt_type='smoothgrad', nt_samples=4, target=targets)
end = time.perf_counter()
print('sg running time:', end - start)