#%% 
import shap  # https://github.com/slundberg/shap
import shapreg  # https://github.com/iancovert/shapley-regression
import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
import seaborn as sns

# mkdir
import os

# Create the output directory for the pickled results. exist_ok=True makes
# this a single idempotent call instead of a separate "exists?" check
# followed by a create.
os.makedirs('results_losscurve', exist_ok=True)
#%% 
# Read the raw table; skipinitialspace makes the parser skip spaces that
# follow a delimiter.
df = pd.read_csv("OnlineNewsPopularity/OnlineNewsPopularity.csv", skipinitialspace=True)

# Target column, taken before it is removed from the inputs.
sr_Y = df['shares']

# Inputs: drop the non-predictive columns and the target feature
# (58 features remaining).
df_X = df.drop(columns=['url', 'timedelta', 'shares'])


# preprocessing
'''
Input features:

After dropping non-predictive input features, there are 58 input features remaining.

Data types: All the input features have numerical values, but by counting the unique values, we can tell some of them are actually binary. And some of them are related in content. So we can merge them into multinomial features encoded by integers to reduce dimension.

Feature scales: The scales (ranges) of these features are quite different, so standardization is needed before fitting models on them.

Outliers: If we define values more than 5 standard deviations from the mean as outliers, some features have a considerable number of outliers, so outlier removal is also needed during preprocessing.
'''
def outlierCounts(series):
    """Count the values lying more than 5 standard deviations from the mean.

    The comparison is strict (``>``), which makes this the exact complement
    of a ``<= 5 * std`` keep-mask. In particular a zero-variance series
    (std == 0, every deviation == 0) reports 0 outliers rather than
    flagging every row.

    in: series (pandas Series of numeric values)
    out: number of outliers as a plain int
    """
    deviation = np.abs(series - series.mean())
    is_outlier = deviation > (5 * series.std())
    return int(is_outlier.sum())

def uniqueValueCount(series):
    """Return how many distinct values the series holds.

    Missing values are counted as a value of their own (dropna=False).
    """
    return series.nunique(dropna=False)

# Per-feature summary table: one row per input column.
input_feats = df_X.dtypes.reset_index()
input_feats.columns = ['name', 'dtype']

# Each statistic is a Series indexed by column name; resetting to a
# positional index lines it up with the rows of input_feats.
summary_columns = {
    'mean': df_X.mean(),
    'std': df_X.std(),
    'range': df_X.max() - df_X.min(),
    'unique_values_count': df_X.apply(uniqueValueCount, axis=0),
    'outliers_count': df_X.apply(outlierCounts, axis=0),
}
for stat_name, per_feature in summary_columns.items():
    input_feats[stat_name] = per_feature.reset_index(drop=True)


# Merge Binary Features
'''
Among those binary features, there are 6 describing the content categories of news and 7 describing the publish weekday. We can merge them and reduce the number of features to 47.
'''
def mergeFeatures(df, old_feats, new_feat):
    """Collapse a group of binary indicator columns into one integer column.

    The new column starts at 0. The k-th name in ``old_feats`` (1-based)
    writes code k into every row where that indicator equals 1.0, and the
    indicator column is then removed. If several indicators are set on the
    same row, the one listed last wins.

    in: dataframe, binaryFeatureNames and multinomialFeatureName
    out: the same dataframe object, modified in place
    """
    df[new_feat] = 0

    for code, flag_col in enumerate(old_feats, start=1):
        is_set = df[flag_col] == 1.0
        df.loc[is_set, new_feat] = code
        df.drop(columns=flag_col, inplace=True)

    return df

# Indicator columns for the content channel; list position + 1 becomes the
# integer code written by mergeFeatures.
data_channels = [
    'data_channel_is_lifestyle', 'data_channel_is_entertainment',
    'data_channel_is_bus', 'data_channel_is_socmed',
    'data_channel_is_tech', 'data_channel_is_world',
]
# Indicator columns for the publication weekday, Monday first.
weekdays = [
    'weekday_is_monday', 'weekday_is_tuesday', 'weekday_is_wednesday',
    'weekday_is_thursday', 'weekday_is_friday',
    'weekday_is_saturday', 'weekday_is_sunday',
]

# mergeFeatures edits df_X in place and returns it, so after this loop `df`
# is bound to that same frame.
for flag_columns, merged_name in (
    (data_channels, 'data_channel'),
    (weekdays, 'pub_weekday'),
):
    df = mergeFeatures(df_X, flag_columns, merged_name)


# Remove Outliers and Normalize features
import sklearn.preprocessing as prep

# remove outliers
# Columns are filtered one after another: each column's mean/std is computed
# on the rows that survived the previous columns' filters.
for feature in df_X.columns:
    column = df_X[feature]
    within_5_std = (column - column.mean()).abs() <= 5 * column.std()
    df_X = df_X.loc[within_5_std]

# Keep only the targets whose rows survived the filtering.
sr_Y = sr_Y.loc[df_X.index]

def standarize(arr_X):
    """Min-max scale ``arr_X`` with a default MinMaxScaler, then subtract
    each row's own mean from that row.

    NOTE(review): the centering is per row (axis=1), not per feature;
    presumably intentional, and any saved model trained on this output
    depends on it -- confirm before changing.
    """
    scaled = prep.MinMaxScaler().fit_transform(arr_X)
    row_means = scaled.mean(axis=1, keepdims=True)
    return scaled - row_means

# Convert the cleaned frame to a plain array and run it through standarize().
arr_X = standarize(df_X.values)

# Binarize Y
'''
As mentioned before, we use the median value 1400 as the threshold to binarize the target feature and divide the data points into two classes. Because of the outlier removal, the two classes are no longer the same size, but they are still roughly balanced.
'''
# binarize expects a 2-D input, so the targets are reshaped to one column.
shares_column = sr_Y.values.reshape(-1, 1)
# using original median as threshold
arr_Y = prep.binarize(shares_column, threshold=1400)
sr_Y = pd.Series(arr_Y.ravel())

# Class labels present and how many samples fall in each.
unique_items, counts = np.unique(arr_Y, return_counts=True)
#%% 
# Load and split data
# First hold out 20% as the test set, then 20% of the remainder as validation.
X_trainval, X_test, Y_trainval, Y_test = train_test_split(
    arr_X, arr_Y, test_size=0.2, random_state=7)
X_train, X_val, Y_train, Y_val = train_test_split(
    X_trainval, Y_trainval, test_size=0.2, random_state=0)

# Data scaling
num_features = X_train.shape[1]
# The scaler is fitted on the training split only and then applied to all three.
ss = StandardScaler()
ss.fit(X_train)
X_train, X_val, X_test = [ss.transform(part) for part in (X_train, X_val, X_test)]

#%% load model
import pickle
from fastshap import Surrogate, FastSHAP
from simshap.simshap_sampling import SimSHAPSampling
import torch
import torch.nn as nn
import sys
sys.path.append('..')
from models import SimSHAPTabular
# Use the GPU when one is available, otherwise fall back to the CPU so the
# script still runs on machines without CUDA.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# NOTE: pickle.load / torch.load can execute arbitrary code while
# deserializing -- only load artifacts produced by a trusted training run.
with open('news model.pkl', 'rb') as f:
    model = pickle.load(f)

# map_location lets checkpoints that were saved from a GPU be deserialized
# on whichever device was selected above.
# NOTE(review): recent PyTorch releases changed torch.load's default to
# weights_only=True -- confirm these whole-object checkpoints still load on
# the installed version.
surr = torch.load('news surrogate.pt', map_location=device).to(device)
surrogate = Surrogate(surr, num_features)

explainer_fastshap = torch.load('news fastshap.pt', map_location=device).to(device)
explainer_fastshap.eval()
fastshap = FastSHAP(explainer_fastshap, surrogate, normalization='additive',
                    link=nn.Identity())

explainer_simshap = torch.load('news simshap.pt', map_location=device).to(device)
explainer_simshap.eval()
simshap = SimSHAPSampling(explainer=explainer_simshap, imputer=surrogate, device=device)

#%% Get SHAP values of fastshap and simshap

def imputer(x, S):
    """Evaluate the surrogate on inputs ``x`` under the subset masks ``S``.

    Both arguments are converted to float32 tensors on ``device``, passed to
    the module-level ``surrogate``, and the prediction is returned as a
    NumPy array on the host.

    in: x (array-like inputs), S (array-like masks, passed along with x)
    out: numpy array of surrogate predictions
    """
    x = torch.tensor(x, dtype=torch.float32, device=device)
    S = torch.tensor(S, dtype=torch.float32, device=device)
    # Inference only: no_grad avoids recording an autograd graph on every
    # one of the many calls made by the sampling-based estimators.
    with torch.no_grad():
        pred = surrogate(x, S)
    return pred.detach().cpu().numpy()
num_eval = np.arange(0, 600, 4)
loss_fastshap_lst, loss_simshap_lst = [], []

# Evaluation settings.
num_samples = 256  # how many test instances are explained
samples = 4096     # sampling budget handed to the estimators below
thresh = 0.001     # convergence threshold for the reference run

# Fixed seed so the same test rows are drawn on every run.
np.random.seed(200)
ind = np.random.choice(len(X_test), size=num_samples)
# ind = np.arange(num_samples)

loss_kernelshap, loss_kernelshap_pair = [], []
loss_permutation, loss_antithesis = [], []

#%% GT shap_Values
# Reference ("ground truth") SHAP values: ShapleyRegression is run with the
# convergence threshold `thresh` for each selected test row.
shap_values = []

for i, test_row in zip(range(num_samples), ind):
    instance = X_test[test_row]
    reference_game = shapreg.games.PredictionGame(imputer, instance)
    reference = shapreg.shapley.ShapleyRegression(
        reference_game, thresh=thresh, bar=False)
    shap_values.append(reference.values.T)
    print('Done with sample = {}'.format(i))

# NOTE(review): the file name says "census" although this script processes
# the news data -- presumably a leftover; the loading cell reads the same
# name, so both must change together if it is ever renamed.
with open('results_losscurve/census_shap.pkl', 'wb') as out_file:
    pickle.dump(shap_values, out_file)
#%% kernelshap

# KernelSHAP without paired sampling, fixed budget of `samples`.
kernelshap_curves = []
for i, test_row in zip(range(num_samples), ind):
    game = shapreg.games.PredictionGame(imputer, X_test[test_row])
    results = shapreg.shapley.ShapleyRegression(
        game, batch_size=64, n_samples=samples, detect_convergence=False,
        bar=False, paired_sampling=False, return_all=True)

    # results[1] carries the intermediate estimates ('values') and the
    # matching evaluation counts ('iters').
    tracking = results[1]
    snapshots = [np.transpose(estimate) for estimate in tracking['values']]
    kernelshap_curves.append(np.array(snapshots))
    print('Done with sample = {}'.format(i))

# Taken from the last run only.
kernelshap_iters = tracking['iters']

#%% kernelshap_pair
# KernelSHAP with paired sampling; it is given half of `samples`.
paired_curves = []
for i, test_row in zip(range(num_samples), ind):
    game = shapreg.games.PredictionGame(imputer, X_test[test_row])
    results = shapreg.shapley.ShapleyRegression(
        game, batch_size=64, n_samples=(samples / 2), detect_convergence=False,
        bar=False, paired_sampling=True, return_all=True)

    # results[1] carries the intermediate estimates ('values') and the
    # matching evaluation counts ('iters').
    tracking = results[1]
    snapshots = [np.transpose(estimate) for estimate in tracking['values']]
    paired_curves.append(np.array(snapshots))
    print('Done with sample = {}'.format(i))

# Taken from the last run only.
paired_iters = tracking['iters']

#%% permutation
# Permutation sampling; the budget is ceil(samples / num_features).
sampling_curves = []
permutation_budget = int(np.ceil(samples / num_features))

for i, test_row in zip(range(num_samples), ind):
    # Set up the game for this instance
    game = shapreg.games.PredictionGame(imputer, X_test[test_row])

    # Estimate SHAP values by permutation sampling, keeping every
    # intermediate estimate
    results = shapreg.shapley_sampling.ShapleySampling(
        game, batch_size=1, n_samples=permutation_budget,
        detect_convergence=False, bar=False, return_all=True)
    tracking = results[1]
    snapshots = [np.transpose(estimate) for estimate in tracking['values']]
    sampling_curves.append(np.array(snapshots))
    print('Done with sample = {}'.format(i))

# Taken from the last run only.
sampling_iters = tracking['iters']


#%% antithetical 
# Permutation sampling with antithetical=True; same budget formula as the
# plain permutation run, batch size 2.
antithetical_curves = []
antithetical_budget = int(np.ceil(samples / num_features))

for i, test_row in zip(range(num_samples), ind):
    # Set up the game for this instance
    game = shapreg.games.PredictionGame(imputer, X_test[test_row])

    # Estimate SHAP values with antithetical sampling, keeping every
    # intermediate estimate
    results = shapreg.shapley_sampling.ShapleySampling(
        game, batch_size=2, n_samples=antithetical_budget,
        detect_convergence=False, bar=False, antithetical=True,
        return_all=True)
    tracking = results[1]
    snapshots = [np.transpose(estimate) for estimate in tracking['values']]
    antithetical_curves.append(np.array(snapshots))
    print('Done with sample = {}'.format(i))

# Taken from the last run only.
antithetical_iters = tracking['iters']

#%% save
# Bundle the four sampling-based estimators' curves and evaluation counts
# into a single dictionary and pickle it.
save_dict = {
    'kernelshap': kernelshap_curves,
    'kernelshap_iters': kernelshap_iters,

    'paired_sampling': paired_curves,
    'paired_sampling_iters': paired_iters,

    'sampling_curves': sampling_curves,
    'sampling_iters': sampling_iters,

    'antithetical_curves': antithetical_curves,
    'antithetical_iters': antithetical_iters,
}
with open('results_losscurve/census_curves.pkl', 'wb') as out_file:
    pickle.dump(save_dict, out_file)

#%% Fastshap
# FastSHAP: one call per selected test row; the transposed result for that
# row is stored.
fastshap_curves = []
for i, test_row in zip(range(num_samples), ind):
    single_batch = np.expand_dims(X_test[test_row], axis=0)
    estimate = fastshap.shap_values(single_batch)[0]
    fastshap_curves.append(estimate.T)
    print('Done with sample = {}'.format(i))

with open('results_losscurve/census_fastshap.pkl', 'wb') as out_file:
    pickle.dump(fastshap_curves, out_file)

#%% simshap
# SimSHAP: one call per selected test row.
# NOTE(review): unlike the FastSHAP loop, the result is stored without a
# transpose -- presumably SimSHAP already returns the layout used by the
# reference values; confirm.
simshap_curves = []
for i, test_row in zip(range(num_samples), ind):
    single_batch = np.expand_dims(X_test[test_row], axis=0)
    estimate = simshap.shap_values(single_batch)[0]
    simshap_curves.append(estimate)
    print('Done with sample = {}'.format(i))

with open('results_losscurve/census_simshap.pkl', 'wb') as out_file:
    pickle.dump(simshap_curves, out_file)
#%% Load curves
import numpy as np
import pickle


def _read_pickle(path):
    """Return the object stored in the pickle file at ``path``."""
    with open(path, 'rb') as handle:
        return pickle.load(handle)


# Curves of the four sampling-based estimators, unpacked back into the same
# variable names the computing cells use.
save_dict = _read_pickle('results_losscurve/census_curves.pkl')

kernelshap_curves, kernelshap_iters = (
    save_dict['kernelshap'], save_dict['kernelshap_iters'])
paired_curves, paired_iters = (
    save_dict['paired_sampling'], save_dict['paired_sampling_iters'])
sampling_curves, sampling_iters = (
    save_dict['sampling_curves'], save_dict['sampling_iters'])
antithetical_curves, antithetical_iters = (
    save_dict['antithetical_curves'], save_dict['antithetical_iters'])

# Reference values are stacked into one array; the amortized explainers'
# outputs stay as the pickled lists.
shap_values = np.array(_read_pickle('results_losscurve/census_shap.pkl'))
fastshap_curves = _read_pickle('results_losscurve/census_fastshap.pkl')
simshap_curves = _read_pickle('results_losscurve/census_simshap.pkl')
#%% Visualiztion
import seaborn as sns
import matplotlib.pyplot as plt

# Set the seaborn style before creating the figure.
sns.set_style('white')

# Create exactly one figure for all curves; opening an additional figure
# here would leave a blank one behind when plt.show() runs.
plt.figure(figsize=(9, 5.5))

# Hide the right and top borders of the axes.
ax = plt.gca()
for side in ('right', 'top'):
    ax.spines[side].set_color('none')
def euclidean_dist(values, target):
    """Return the L2 distance between ``values`` and ``target``.

    The squared differences are summed over the last two axes, so any
    leading axes of the broadcast difference are kept in the result.
    """
    diff = np.subtract(values, target)
    squared_sum = np.square(diff).sum(axis=(-2, -1))
    return np.sqrt(squared_sum)

def l1_dist(values, target):
    """Return the L1 distance between ``values`` and ``target``.

    The absolute differences are summed over the last two axes, so any
    leading axes of the broadcast difference are kept in the result.
    """
    diff = np.subtract(values, target)
    return np.abs(diff).sum(axis=(-2, -1))

def _plot_mean_with_band(x_vals, dist, label, color, flat=False):
    """Draw the mean of ``dist`` over axis 0 with a shaded error band.

    The band is mean +/- 1.96 * std / sqrt(len(shap_values)), i.e. 1.96
    standard errors (the usual 95% normal interval). With ``flat=True`` the
    mean is a single number and is repeated across ``x_vals`` to draw a
    horizontal line.
    """
    mean = np.mean(dist, axis=0)
    half_width = 1.96 * np.std(dist, axis=0) / np.sqrt(len(shap_values))
    line = mean.repeat(len(x_vals)) if flat else mean
    plt.plot(x_vals, line, label=label, color=color)
    plt.fill_between(x_vals, mean - half_width, mean + half_width,
                     color=color, alpha=0.1)


# Sampling-based estimators: distance to the reference values at every
# recorded evaluation count. The reference array gets an extra axis so it
# broadcasts against the per-iteration axis of the curves.
for curves, iters, label, color in (
    (kernelshap_curves, kernelshap_iters, 'KernelSHAP', 'tab:blue'),
    # KernelSHAP (paired sampling)
    (paired_curves, paired_iters, 'KernelSHAP (Paired)', 'tab:orange'),
    # Permutation sampling
    (sampling_curves, sampling_iters, 'Permutation Sampling', 'tab:purple'),
    # Antithetical sampling
    (antithetical_curves, antithetical_iters,
     'Permutation Sampling (Antithetical)', 'tab:pink'),
):
    dist = euclidean_dist(curves, shap_values[:, np.newaxis])
    _plot_mean_with_band(iters, dist, label, color)

# Fastshap: one estimate per instance, drawn as a horizontal line.
dist = euclidean_dist(fastshap_curves, shap_values)
num_eval = np.arange(0, 1250, 10)
_plot_mean_with_band(num_eval, dist, 'FastSHAP', 'tab:green', flat=True)

# Simshap: one estimate per instance, drawn as a horizontal line.
dist = euclidean_dist(simshap_curves, shap_values)
distl1 = l1_dist(simshap_curves, shap_values)
num_eval = np.arange(0, 1250, 10)
_plot_mean_with_band(num_eval, dist, 'SimSHAP', 'tab:red', flat=True)

import os

# Formatting
plt.ylim(0, 0.7)
plt.xlim(0, 1250)
plt.legend(fontsize=16)
plt.tick_params(labelsize=14)
plt.ylabel(r'Mean $\ell_2$ distance', fontsize=16)
plt.xlabel('# Evals', fontsize=16)
plt.title('News SHAP Estimation', fontsize=18)

plt.tight_layout()

# savefig does not create missing directories, so make sure the output
# folder exists before writing the figure files.
os.makedirs('results_figures', exist_ok=True)
plt.savefig('results_figures/news_l2_curves.pdf')
plt.savefig('results_figures/news_l2_curves.png', dpi=300)
plt.show()
