#%%
import shap  # https://github.com/slundberg/shap
import shapreg  # https://github.com/iancovert/shapley-regression
import numpy as np
import torch
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
# All tensors and models in this script are placed on the CPU.
device = torch.device("cpu")

# Load the Adult (Census) data and carve out the splits:
# 20% held out for test, then 20% of the remainder for validation.
X_train, X_test, Y_train, Y_test = train_test_split(
    *shap.datasets.adult(), test_size=0.2, random_state=7
)
X_train, X_val, Y_train, Y_val = train_test_split(
    X_train, Y_train, test_size=0.2, random_state=0
)

# Capture column metadata while X_train is still a DataFrame; the scaler's
# transform() below replaces each split with its scaled output.
feature_names = X_train.columns.tolist()
num_features = len(feature_names)

# Standardize every split with statistics fitted on the training split only.
ss = StandardScaler().fit(X_train)
X_train, X_val, X_test = [ss.transform(part) for part in (X_train, X_val, X_test)]

#%% load model
import pickle
from fastshap import Surrogate, FastSHAP
from simshap.simshap_sampling import SimSHAPSampling
import torch.nn as nn
import sys
sys.path.append('..')
from models import SimSHAPTabular
# Trained black-box model. pickle executes arbitrary code on load, so only
# open files you produced or trust.
with open('census model.pkl', 'rb') as f:
    model = pickle.load(f)

# map_location remaps any tensors saved on a GPU onto `device`. Without it,
# torch.load tries to restore them to their original device and fails on a
# CPU-only machine before `.to(device)` is ever reached.
# NOTE(review): these files appear to hold whole pickled modules; on newer
# PyTorch releases where torch.load defaults to weights_only=True, this may
# additionally need weights_only=False — confirm the installed version.
surr = torch.load('census surrogate.pt', map_location=device).to(device)
surrogate = Surrogate(surr, num_features)

# FastSHAP explainer, wrapped around the same surrogate.
explainer_fastshap = torch.load('census fastshap.pt', map_location=device).to(device)
fastshap = FastSHAP(explainer_fastshap, surrogate, normalization='additive',
                    link=nn.Identity())

# SimSHAP explainer, also using the surrogate as its imputer.
explainer_simshap = torch.load('census simshap.pt', map_location=device).to(device)
simshap = SimSHAPSampling(explainer=explainer_simshap, imputer=surrogate, device=device)

#%% Get SHAP values of fastshap and simshap

def imputer(x, S):
    """Evaluate the surrogate model on `x` with feature-subset mask `S`.

    Used as the value function of ``shapreg.games.PredictionGame``. Inputs are
    converted to float32 tensors on ``device``, and the prediction is returned
    as a NumPy array.

    Args:
        x: array-like input instance(s).
        S: array-like subset mask; presumably one 0/1 entry per feature per
            row (TODO confirm against shapreg's PredictionGame).

    Returns:
        numpy.ndarray holding the surrogate's output.
    """
    x = torch.tensor(x, dtype=torch.float32, device=device)
    S = torch.tensor(S, dtype=torch.float32, device=device)
    # Inference only: skip autograd graph construction, which the estimators
    # would otherwise pay for on every game evaluation.
    with torch.no_grad():
        pred = surrogate(x, S)
    return pred.detach().cpu().numpy()
num_eval = np.arange(0, 600, 4)  # not referenced later; the plotting cell defines its own x-axis
loss_fastshap_lst = []
loss_simshap_lst = []
np.random.seed(200)  # seeds NumPy's global RNG, which the index draw below uses
num_samples = 256    # number of test instances to explain
samples = 4096       # sample budget handed (directly or rescaled) to the estimators below
thresh = 0.001       # `thresh` passed to the ground-truth ShapleyRegression run
# Draw *distinct* test rows. The default replace=True can select the same row
# more than once, which double-counts it in the averaged curves and narrows
# the confidence band as if the repeats were independent draws.
ind = np.random.choice(len(X_test), size=num_samples, replace=False)

# Accumulators not referenced in the cells shown here.
loss_kernelshap = []
loss_kernelshap_pair = []
loss_permutation = []
loss_antithesis = []

# Output directory for cached results; no error if it already exists.
import os
os.makedirs('results_losscurve', exist_ok=True)
#%% Ground-truth SHAP values
# Reference SHAP values: one ShapleyRegression run per selected test instance,
# using the tight `thresh` configured above.
shap_values = []
for i in range(num_samples):
    game = shapreg.games.PredictionGame(imputer, X_test[ind[i]])
    gt_expl = shapreg.shapley.ShapleyRegression(game, thresh=thresh, bar=False)
    # Transposed, matching how the estimator trajectories are stored below.
    shap_values.append(gt_expl.values.T)
    print('Done with sample = {}'.format(i))

# Cache to disk so the slow reference computation need not be repeated.
with open('results_losscurve/census_shap.pkl', 'wb') as out_file:
    pickle.dump(shap_values, out_file)
#%% kernelshap

# KernelSHAP without paired sampling. return_all=True keeps the whole estimate
# trajectory; convergence detection is off so every run uses `samples`.
kernelshap_curves = []
for i in range(num_samples):
    game = shapreg.games.PredictionGame(imputer, X_test[ind[i]])
    results = shapreg.shapley.ShapleyRegression(
        game,
        batch_size=32,
        n_samples=samples,
        detect_convergence=False,
        bar=False,
        paired_sampling=False,
        return_all=True,
    )
    tracking = results[1]
    kernelshap_curves.append(np.array([est.T for est in tracking['values']]))
    print('Done with sample = {}'.format(i))

# Checkpoint positions from the last run are reused for the x-axis; presumably
# identical for every instance since the budget is fixed and there is no early stop.
kernelshap_iters = results[1]['iters']

#%% kernelshap_pair
# KernelSHAP with paired sampling. Each paired draw also evaluates the
# complementary subset, so n_samples is halved to keep the evaluation budget
# comparable with the unpaired run (assumption about shapreg's accounting —
# confirm). Integer division keeps the sample count an int.
paired_curves = []
for i in range(num_samples):
    x = X_test[ind[i]]
    game = shapreg.games.PredictionGame(imputer, x)
    results = shapreg.shapley.ShapleyRegression(
        game, batch_size=32, n_samples=samples // 2, detect_convergence=False,
        bar=False, paired_sampling=True, return_all=True)
    # One row per recorded checkpoint, each transposed like the ground truth.
    curve = np.array([explanation.T for explanation in results[1]['values']])
    paired_curves.append(curve)
    print('Done with sample = {}'.format(i))

# Checkpoint positions from the last run serve as the x-axis for this method.
paired_iters = results[1]['iters']

#%% permutation
# Permutation sampling. The permutation count ceil(samples / num_features) is
# loop-invariant; presumably chosen so total evaluations roughly match
# `samples` (each permutation touches every feature) — confirm against shapreg.
sampling_curves = []
n_permutations = int(np.ceil(samples / num_features))

for i in range(num_samples):
    # Get instance
    x = X_test[ind[i]]

    # Set up game
    game = shapreg.games.PredictionGame(imputer, x)

    # Estimate SHAP values by permutation sampling, keeping the full trajectory
    results = shapreg.shapley_sampling.ShapleySampling(
        game, batch_size=1, n_samples=n_permutations, detect_convergence=False,
        bar=False, return_all=True)
    curve = np.array([explanation.T for explanation in results[1]['values']])
    sampling_curves.append(curve)
    print('Done with sample = {}'.format(i))

# Checkpoint positions from the last run serve as the x-axis for this method.
sampling_iters = results[1]['iters']


#%% antithetical 
# Permutation sampling with antithetical pairs. Same permutation budget as the
# plain permutation cell; batch_size=2 presumably keeps each permutation and
# its antithetical partner in one batch — confirm against shapreg.
antithetical_curves = []
n_permutations = int(np.ceil(samples / num_features))

for i in range(num_samples):
    # Get instance
    x = X_test[ind[i]]

    # Set up game
    game = shapreg.games.PredictionGame(imputer, x)

    # Estimate SHAP values by antithetical permutation sampling, keeping the trajectory
    results = shapreg.shapley_sampling.ShapleySampling(
        game, batch_size=2, n_samples=n_permutations, detect_convergence=False,
        bar=False, antithetical=True, return_all=True)
    curve = np.array([explanation.T for explanation in results[1]['values']])
    antithetical_curves.append(curve)
    print('Done with sample = {}'.format(i))

# Checkpoint positions from the last run serve as the x-axis for this method.
antithetical_iters = results[1]['iters']

#%% save
# Bundle every iterative estimator's trajectories with their checkpoint
# positions. The payload is built BEFORE the file is opened: opening with 'wb'
# truncates immediately, so a missing variable (e.g. a skipped cell) would
# otherwise wipe the previously saved results.
save_dict = {
    'kernelshap': kernelshap_curves,
    'kernelshap_iters': kernelshap_iters,

    'paired_sampling': paired_curves,
    'paired_sampling_iters': paired_iters,

    'sampling_curves': sampling_curves,
    'sampling_iters': sampling_iters,

    'antithetical_curves': antithetical_curves,
    'antithetical_iters': antithetical_iters,
}
with open('results_losscurve/census_curves.pkl', 'wb') as f:
    pickle.dump(save_dict, f)

#%% Fastshap
# FastSHAP: one call of the amortized explainer per selected instance; the
# result is transposed to match the ground-truth layout.
fastshap_curves = []
for i in range(num_samples):
    row = X_test[ind[i]][None, :]  # add a leading batch axis of size 1
    fastshap_curves.append(fastshap.shap_values(row)[0].T)
    print('Done with sample = {}'.format(i))

with open('results_losscurve/census_fastshap.pkl', 'wb') as out_file:
    pickle.dump(fastshap_curves, out_file)

#%% simshap
# SimSHAP: one call of the amortized explainer per selected instance.
# NOTE(review): unlike the FastSHAP cell, the result is stored without `.T`;
# presumably SimSHAPSampling.shap_values already returns the ground-truth
# layout — confirm, since the plots compare these directly to `shap_values`.
simshap_curves = []
for i in range(num_samples):
    row = X_test[ind[i]][None, :]  # add a leading batch axis of size 1
    simshap_curves.append(simshap.shap_values(row)[0])
    print('Done with sample = {}'.format(i))

with open('results_losscurve/census_simshap.pkl', 'wb') as out_file:
    pickle.dump(simshap_curves, out_file)
#%% Load curves
import numpy as np
import pickle

# Reload cached results so the plotting cells can run in a fresh session.
# (pickle files are trusted here: they are written by the cells above.)
with open('results_losscurve/census_curves.pkl', 'rb') as f:
    save_dict = pickle.load(f)

kernelshap_curves, kernelshap_iters = save_dict['kernelshap'], save_dict['kernelshap_iters']
paired_curves, paired_iters = save_dict['paired_sampling'], save_dict['paired_sampling_iters']
sampling_curves, sampling_iters = save_dict['sampling_curves'], save_dict['sampling_iters']
antithetical_curves, antithetical_iters = (
    save_dict['antithetical_curves'],
    save_dict['antithetical_iters'],
)

# Ground truth as one stacked array; the amortized methods stay as lists.
with open('results_losscurve/census_shap.pkl', 'rb') as f:
    shap_values = np.array(pickle.load(f))
with open('results_losscurve/census_fastshap.pkl', 'rb') as f:
    fastshap_curves = pickle.load(f)
with open('results_losscurve/census_simshap.pkl', 'rb') as f:
    simshap_curves = pickle.load(f)
#%% Visualization
import seaborn as sns
import matplotlib.pyplot as plt

# Apply the seaborn style before creating the figure/axes so it takes effect,
# then create exactly one figure for the whole comparison.
sns.set_style('white')
plt.figure(figsize=(9, 5.5))
ax = plt.gca()
# Hide the top and right spines.
ax.spines['right'].set_color('none')
ax.spines['top'].set_color('none')
def euclidean_dist(values, target):
    """Return the l2 distance between `values` and `target` over the last two axes.

    The difference is formed with NumPy broadcasting, squared, summed over the
    two trailing axes and square-rooted, so stacked inputs of shape
    (..., rows, cols) yield one distance per leading index. `values` may be a
    list of equally-shaped arrays when `target` is an ndarray.
    """
    diff = values - target
    return np.sqrt(np.square(diff).sum(axis=(-2, -1)))

def l1_dist(values, target):
    """Return the l1 distance between `values` and `target` over the last two axes.

    Absolute differences (NumPy broadcasting) are summed over the two trailing
    axes, giving one distance per leading index of a stacked input.
    """
    abs_diff = np.absolute(values - target)
    return abs_diff.sum(axis=(-2, -1))

def _plot_mean_ci(evals, dist, label, color, z=1.96):
    """Plot the mean of `dist` over explained instances with a z * SEM band.

    Args:
        evals: x positions (number of evaluations) for the curve.
        dist: per-instance distances; axis 0 indexes the explained instances.
            A 1-D `dist` (one value per instance) is drawn flat across `evals`.
        label: legend label.
        color: matplotlib color for the line and band.
        z: normal quantile for the band half-width (1.96 gives ~95%).
    """
    dist = np.asarray(dist)
    mean = np.mean(dist, axis=0)
    half_width = z * np.std(dist, axis=0) / np.sqrt(len(dist))
    shape = np.shape(evals)
    plt.plot(evals, np.broadcast_to(mean, shape), label=label, color=color)
    plt.fill_between(evals,
                     np.broadcast_to(mean - half_width, shape),
                     np.broadcast_to(mean + half_width, shape),
                     color=color, alpha=0.1)


# Works whether shap_values came from the reload cell (ndarray) or straight
# from the ground-truth cell (list), which `[:, np.newaxis]` cannot index.
gt_values = np.asarray(shap_values)
max_evals = 1250  # x-axis extent; amortized methods are drawn flat across it
amortized_evals = np.arange(0, max_evals, 10)

# KernelSHAP: compare every checkpoint of each trajectory with the ground truth
dist = euclidean_dist(kernelshap_curves, gt_values[:, np.newaxis])
_plot_mean_ci(kernelshap_iters, dist, 'KernelSHAP', 'tab:blue')

# KernelSHAP (paired sampling)
dist = euclidean_dist(paired_curves, gt_values[:, np.newaxis])
_plot_mean_ci(paired_iters, dist, 'KernelSHAP (Paired)', 'tab:orange')

# Permutation sampling
dist = euclidean_dist(sampling_curves, gt_values[:, np.newaxis])
_plot_mean_ci(sampling_iters, dist, 'Permutation Sampling', 'tab:purple')

# Antithetical sampling
dist = euclidean_dist(antithetical_curves, gt_values[:, np.newaxis])
_plot_mean_ci(antithetical_iters, dist, 'Permutation Sampling (Antithetical)', 'tab:pink')

# FastSHAP: a single estimate per instance, shown as a horizontal line
dist = euclidean_dist(fastshap_curves, gt_values)
_plot_mean_ci(amortized_evals, dist, 'FastSHAP', 'tab:green')

# SimSHAP: a single estimate per instance, shown as a horizontal line
distl1 = l1_dist(simshap_curves, gt_values)  # l1 distance; computed but not plotted
dist = euclidean_dist(simshap_curves, gt_values)
_plot_mean_ci(amortized_evals, dist, 'SimSHAP', 'tab:red')

# Formatting (optionally zoom in with plt.ylim(0, 0.2))
plt.xlim(0, max_evals)
font = {
    'family': 'Times New Roman',
    'weight': 'normal',
    'size': 16,
}
plt.legend(fontsize=16, frameon=False)
plt.tick_params(labelsize=14)
plt.ylabel(r'Mean $\ell_2$ distance', fontdict=font)
plt.xlabel('# Evals', fontsize=16)
plt.title('Census', fontsize=18)

plt.tight_layout()
# savefig does not create missing directories, so make sure the target exists.
import os
os.makedirs('results_figures', exist_ok=True)
plt.savefig('results_figures/census_l2_curves.pdf')
plt.savefig('results_figures/census_l2_curves.png', dpi=300)
plt.show()
# %%
