#%% 
import shap  # https://github.com/slundberg/shap
import shapreg  # https://github.com/iancovert/shapley-regression
import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
import time
import torch

# Load the Adult census data and carve it into train / validation / test.
# The first split holds out 20% of all rows as the test set; the second
# holds out 20% of the remaining training rows as the validation set.
X_train, X_test, Y_train, Y_test = train_test_split(
    *shap.datasets.adult(),
    test_size=0.2,
    random_state=7,
)
X_train, X_val, Y_train, Y_val = train_test_split(
    X_train,
    Y_train,
    test_size=0.2,
    random_state=0,
)

# Feature metadata is read while X_train still exposes `.columns`
# (presumably the scaler below returns plain arrays -- confirm).
feature_names = X_train.columns.tolist()
num_features = X_train.shape[1]

# Standardize every split with statistics fitted on the training split only.
ss = StandardScaler().fit(X_train)
X_train, X_val, X_test = (
    ss.transform(split) for split in (X_train, X_val, X_test)
)
import os

# Fix the seed (2) of every random source used by this script.
# NOTE(review): PYTHONHASHSEED is normally read at interpreter start-up, so
# setting it here likely only affects child processes -- confirm.
os.environ['PYTHONHASHSEED'] = '2'
torch.manual_seed(2)
torch.cuda.manual_seed(2)
np.random.seed(2)

# Number of test rows to explain.
# NOTE(review): `all` shadows the builtin of the same name; it is read again
# further down the file (`range(all)`), so the name is kept here.
all = 1000

# Draw `all` row indices from the test split (np.random.choice samples with
# replacement by default, so indices may repeat) and slice out the
# corresponding inputs and labels.
ind = np.random.choice(len(X_test), all)
x, y = X_test[ind], Y_test[ind]
#%% Load model
import pickle
from fastshap import Surrogate, FastSHAP
from simshap.simshap_sampling import SimSHAPSampling
import torch.nn as nn
import os
import sys
sys.path.append('..')
from models import SimSHAPTabular
# Device on which the tensors and torch modules below are placed.
device = torch.device('cuda')

# Restore the pickled black-box classifier.
# SECURITY: unpickling can execute arbitrary code -- only load files that come
# from a trusted source.
with open('census model.pkl', mode='rb') as f:
    model = pickle.loads(f.read())
def original_model(x):
    """Run the pickled model on a tensor batch and return two-column scores.

    ``x`` is detached, moved to the CPU and converted to NumPy for
    ``model.predict``. The prediction ``p`` is expanded to ``[1 - p, p]``
    and returned as a new float32 tensor on ``x``'s device. Because the
    result is rebuilt from a NumPy array, it carries no autograd history
    back to ``x``.
    """
    features = x.detach().cpu().numpy()
    positive = model.predict(features)
    # Stack the complement and the prediction, then transpose so that each
    # row holds the pair for one sample.
    scores = np.stack((1 - positive, positive)).T
    return torch.tensor(scores, dtype=torch.float32, device=x.device)
# Surrogate network: loaded from disk, moved to `device`, and wrapped in
# fastshap's `Surrogate` together with the number of input features.
surr = torch.load('census surrogate.pt')
surr = surr.to(device)
surrogate = Surrogate(surr, num_features)

# SimSHAP explainer network, switched to eval mode, plus the sampling wrapper
# that pairs it with the surrogate.
explainer_simshap = torch.load('census simshap.pt')
explainer_simshap = explainer_simshap.to(device)
explainer_simshap.eval()
simshap = SimSHAPSampling(
    explainer=explainer_simshap,
    imputer=surrogate,
    device=device,
)

#%% Fastshap
# This import rebinds `FastSHAP`, replacing the class of the same name that
# was imported from `fastshap` earlier in the file.
from simshap.fastshap_plus import FastSHAP

# Load the FastSHAP explainer network and put it in eval mode.
explainer_fastshap = torch.load('census fastshap.pt')
explainer_fastshap = explainer_fastshap.to(device)
explainer_fastshap.eval()
fastshap = FastSHAP(
    explainer_fastshap,
    surrogate,
    normalization='additive',
    link=nn.Identity(),
)

# Time one batched call over the sampled rows; the host-to-device tensor
# conversion is deliberately left inside the timed region.
# NOTE(review): if `shap_values` returns a CUDA tensor, GPU work may still be
# in flight when `end` is taken -- confirm, or synchronize before reading it.
start = time.time()
fastshap_values = fastshap.shap_values(
    torch.tensor(x, dtype=torch.float32, device=device))
end = time.time()
print('fastshap running time:', end - start)
#%% Simshap
from simshap.simshap_sampling import SimSHAPSampling

# Load the 'census bestsimshap.pt' explainer and rebuild `simshap` around it;
# this rebinds the `simshap` object created in the model-loading cell.
explainer_sim = torch.load('census bestsimshap.pt')
explainer_sim = explainer_sim.to(device)
explainer_sim.eval()
simshap = SimSHAPSampling(explainer_sim, surrogate, device=device)

# Time one call over the sampled rows (passed as the NumPy array `x`).
start = time.time()
simshap_values = simshap.shap_values(x)
end = time.time()
print('simshap running time:', end - start)
#%% Kernelshap
import shap
def model_wrapper(x):
    """NumPy-in / NumPy-out adapter around ``original_model``.

    Converts ``x`` to a float32 tensor on ``device``, evaluates
    ``original_model`` on it, and returns the scores as a NumPy array.
    """
    batch = torch.tensor(x, dtype=torch.float32, device=device)
    scores = original_model(batch)
    return scores.detach().cpu().numpy()
# Background data for KernelSHAP: a single row holding the per-feature median
# of the training split.
med = np.median(X_train, axis=0).reshape(1, -1)
kernelshap = shap.KernelExplainer(model_wrapper, med)

# Time KernelSHAP over all sampled rows.
start = time.time()
kernelshap_values = kernelshap.shap_values(x, nsamples='auto')
end = time.time()
print('kernelshap running time:', end - start)
#%% Kernelshap-S
def imputer(x, S):
    """Evaluate the surrogate on inputs ``x`` with mask ``S``.

    Both arguments are converted to float32 tensors on ``device``; the
    surrogate's output is returned as a NumPy array.
    """
    inputs, mask = (
        torch.tensor(arr, dtype=torch.float32, device=device)
        for arr in (x, S)
    )
    return surrogate(inputs, mask).detach().cpu().numpy()


# Kept from the original script; not referenced anywhere in this cell.
med = np.median(X_train, axis=0).reshape((1, -1))

# KernelSHAP-S runs KernelExplainer in *mask space*: the function being
# explained takes a coalition mask, the background is the all-zeros mask
# (nothing observed) and the explained point is the all-ones mask (everything
# observed). KernelExplainer builds its synthetic rows by copying entries of
# the explained point into the background, so explaining `x[i:i+1]` directly
# (as the script did before) would hand raw feature values to the surrogate
# as the mask `S`. The mask width follows `num_features`, not a literal 12.
mask_background = np.zeros((1, num_features))
mask_full = np.ones((1, num_features))

start = time.time()
for i in range(all):
    sample = x[i:i + 1]  # keep the row 2-D for the surrogate

    def f_mask(z):
        # z: batch of coalition masks generated by KernelExplainer.
        return imputer(sample, z)

    # A fresh explainer per row, because f_mask closes over that row; its
    # construction is part of the timed work, as in the original loop.
    kernelshap_s = shap.KernelExplainer(f_mask, mask_background)
    kernelshap_s.shap_values(mask_full, nsamples='auto')
end = time.time()
print('kernelshap-s running time:', end - start)

#%% IG
from captum.attr import IntegratedGradients

# Integrated Gradients over the sampled rows.
# NOTE(review): `original_model` round-trips through NumPy and rebuilds its
# output with `torch.tensor`, so autograd cannot trace it back to the inputs.
# Gradient-based attribution most likely needs a differentiable forward
# function here -- confirm which model this cell is meant to explain.
ig = IntegratedGradients(original_model)

# `original_model` returns two columns per row, so Captum needs a `target`
# selecting which column to attribute. Use each row's label, as the
# SmoothGrad cell does, converted to the integer tensor Captum expects.
ig_target = torch.tensor(y, dtype=torch.long, device=device)

start = time.time()
ig_values = ig.attribute(
    torch.tensor(x, dtype=torch.float32, device=device),
    target=ig_target,
)
end = time.time()
print('ig running time:', end - start)
#%% SmoothGrad
from captum.attr import IntegratedGradients, NoiseTunnel

# SmoothGrad-style noise averaging (4 noisy copies per row) wrapped around
# Integrated Gradients.
# NOTE(review): `original_model` round-trips through NumPy and rebuilds its
# output with `torch.tensor`, so autograd cannot trace it back to the inputs.
# Gradient-based attribution most likely needs a differentiable forward
# function here -- confirm which model this cell is meant to explain.
ig = IntegratedGradients(original_model)
sg = NoiseTunnel(ig)

# Captum accepts int / tuple / tensor / list targets, not NumPy arrays, so
# the per-row labels are converted to an integer tensor before the call.
sg_target = torch.tensor(y, dtype=torch.long, device=device)

start = time.time()
sg_values = sg.attribute(
    torch.tensor(x, dtype=torch.float32, device=device),
    nt_type='smoothgrad',
    nt_samples=4,
    target=sg_target,
)
end = time.time()
print('sg running time:', end - start)