#%% 
import shap  # https://github.com/slundberg/shap
import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
import seaborn as sns
import time
#%% 
# Read the dataset; skipinitialspace=True skips the spaces that follow each
# delimiter in the CSV.
df = pd.read_csv("OnlineNewsPopularity/OnlineNewsPopularity.csv", skipinitialspace=True)

# Target column.
sr_Y = df['shares']

# Inputs: everything except the non-predictive columns and the target.
df_X = df.drop(columns=['url', 'timedelta', 'shares'])  # 58 features remaining


#%% preprocessing
'''
Input features:

After dropping non-predictive input features, there are 58 input features remaining.

Data types: All the input features have numerical values, but counting their unique values shows that some of them are actually binary, and some of those binary features are related in content. We can therefore merge them into multinomial features encoded by integers to reduce the dimensionality.

Feature scales: The scales (ranges) of these features are quite different, so standardization is needed before fitting models on them.

Outliers: If we define values more than 5 standard deviations from the mean as outliers, some features have a considerable number of outliers, so outlier removal is also needed during preprocessing.
'''

def outlierCounts(series):
    """Count the entries lying at least 5 standard deviations from the mean.

    in: numeric series
    out: int, number of entries with |value - mean| >= 5 * std
    """
    distance  = np.abs(series - series.mean())
    threshold = 5 * series.std()
    # the comparison is inclusive, so a zero-variance series counts every entry
    return int(np.count_nonzero(distance >= threshold))

def uniqueValueCount(series):
    """Return the number of distinct values in `series` (NaN counts as one value)."""
    return series.nunique(dropna=False)

# Per-feature summary table: one row per input column with its name and dtype.
input_feats = df_X.dtypes.rename_axis('name').reset_index(name='dtype')

# Each statistic is a Series indexed by column name; dropping that index
# lines it up positionally with the rows of input_feats.
feature_stats = {
    'mean': df_X.mean(),
    'std': df_X.std(),
    'range': df_X.max() - df_X.min(),
    'unique_values_count': df_X.apply(uniqueValueCount, axis=0),
    'outliers_count': df_X.apply(outlierCounts, axis=0),
}
for stat_name, stat_values in feature_stats.items():
    input_feats[stat_name] = stat_values.reset_index(drop=True)


#%% Merge Binary Features
'''
Among those binary features, 6 describe the content category of the news article and 7 describe the weekday of publication. Merging each group into a single feature reduces the number of features to 47.
'''
def mergeFeatures(df, old_feats, new_feat):
    """ collapse a group of binary indicator columns into one int-coded column.

    The dataframe is modified in place: `new_feat` is added and every column
    listed in `old_feats` is removed. A row receives the 1-based position
    (within `old_feats`) of the indicator equal to 1; rows with no indicator
    set keep 0, and if several are set the last one in `old_feats` wins.

    in: dataframe, binaryFeatureNames and multinomialFeatureName
    out: the same dataframe object that was passed in
    """
    df[new_feat] = 0  # default code for rows with no indicator set

    for code, flag_col in enumerate(old_feats, start=1):
        is_set = df[flag_col] == 1.0
        df.loc[is_set, new_feat] = code
        del df[flag_col]  # the indicator is now redundant

    return df

# Indicator columns for the content channel. Their order fixes the integer
# codes (1..6); a row with none of them set is coded 0.
data_channels = [
    'data_channel_is_lifestyle', 'data_channel_is_entertainment',
    'data_channel_is_bus', 'data_channel_is_socmed',
    'data_channel_is_tech', 'data_channel_is_world',
]
# Indicator columns for the publication weekday (codes 1..7, Monday first).
weekdays = [
    'weekday_is_monday', 'weekday_is_tuesday', 'weekday_is_wednesday',
    'weekday_is_thursday', 'weekday_is_friday', 'weekday_is_saturday',
    'weekday_is_sunday',
]

# mergeFeatures edits df_X in place and returns that same object, so after
# this loop `df` refers to the merged df_X rather than the raw CSV frame.
for indicator_cols, merged_col in (
    (data_channels, 'data_channel'),
    (weekdays, 'pub_weekday'),
):
    df = mergeFeatures(df_X, indicator_cols, merged_col)


#%% Remove Outliers and Normalize features
import sklearn.preprocessing as prep

# remove outliers: go through the columns one at a time and keep only the
# rows within 5 standard deviations of that column's mean. The filtering is
# cumulative, so each column's mean/std is computed on the rows that
# survived the previous columns.
for col in df_X.columns:
    column = df_X[col]
    spread = np.abs(column - column.mean())
    df_X   = df_X[spread <= 5 * column.std()]

# keep only the targets whose rows survived the filtering
sr_Y = sr_Y[df_X.index]

def standarize(arr_X):
    """Min-max scale the columns of `arr_X`, then subtract each row's mean.

    in: 2-D array
    out: array of the same shape, every row centred on its own mean
    """
    scaled = prep.MinMaxScaler().fit_transform(arr_X)
    # NOTE(review): the mean is taken along axis=1, i.e. per row across the
    # columns, not per column -- confirm this is intended.
    row_means = scaled.mean(axis=1, keepdims=True)
    return scaled - row_means

# Scale the filtered feature table, passed as a plain NumPy array.
arr_X = standarize(df_X.values)

#%% Binarize Y
'''
As mentioned before, we use the median value 1400 as the threshold to binarize the target feature and divide the data points into 2 classes. Because of the outlier removal, the two classes are no longer exactly the same size, but they are still more or less balanced.
'''
# Turn the share counts into class labels by thresholding at 1400.
arr_Y = prep.binarize(
    sr_Y.to_numpy().reshape(-1, 1),  # column vector, one row per sample
    threshold=1400,  # using original median as threshold
)
# Rebuilt from the flat label array, so this Series gets a fresh 0..n-1 index.
sr_Y = pd.Series(arr_Y.reshape(-1))

# class labels and their sizes, for checking the class balance
unique_items, counts = np.unique(arr_Y, return_counts=True)
#%% 
# Load and split data: hold out 20% for testing, then 20% of the remainder
# for validation.
X_train, X_test, Y_train, Y_test = train_test_split(
    arr_X, arr_Y, test_size=0.2, random_state=7
)
X_train, X_val, Y_train, Y_val = train_test_split(
    X_train, Y_train, test_size=0.2, random_state=0
)


# Data scaling: the scaler is fitted on the training split only and then
# applied unchanged to all three splits.
num_features = X_train.shape[1]
ss = StandardScaler().fit(X_train)
X_train, X_val, X_test = (
    ss.transform(split) for split in (X_train, X_val, X_test)
)

#%% Load Model
import pickle
import os.path
import sys
sys.path.append('..')
from copy import deepcopy
import torch
import torch.nn as nn
from tqdm.auto import tqdm
from fastshap import Surrogate
from torch.utils.data import DataLoader
import torch.optim as optim
import lightgbm as lgb
from lightgbm.callback import log_evaluation, early_stopping
# NOTE(review): CUDA is hard-coded; nothing here falls back to the CPU.
device = torch.device('cuda')

# Use one seed for Python hashing (env var), NumPy and PyTorch (CPU + CUDA).
seed = 2
os.environ['PYTHONHASHSEED'] = str(seed)
for seed_fn in (np.random.seed, torch.manual_seed, torch.cuda.manual_seed):
    seed_fn(seed)

# Number of test rows to explain. NOTE(review): this name shadows the builtin
# all(); a later cell reads it, so it is left unchanged here.
all = 1000
# np.random.choice samples with replacement by default, so rows may repeat.
ind = np.random.choice(len(X_test), size=all)

x, y = X_test[ind], Y_test[ind]

# NOTE(review): pickle.load can execute arbitrary code from the file; only
# load model files that come from a trusted source.
with open('news model.pkl', 'rb') as f:
    model = pickle.load(f)
def original_model(x):
    """Run the pickled `model` on a tensor batch and return a two-column tensor.

    in: torch tensor; it is detached and handed to `model.predict` as a
        NumPy array on the CPU
    out: float32 tensor on x's device built from [1 - pred, pred]
         (shape (batch, 2) assuming `predict` returns a 1-D array)
    """
    features = x.cpu().detach().numpy()
    positive = model.predict(features)
    # stack the complement and the prediction, then flip to batch-major
    both = np.stack([1 - positive, positive]).T
    return torch.tensor(both, dtype=torch.float32, device=x.device)
# Load the saved surrogate object and move it to the target device.
# NOTE(review): this unpickles a whole object rather than a state dict;
# depending on the installed PyTorch version torch.load may need
# weights_only=False for that -- verify.
surr = torch.load('news surrogate.pt').to(device)
# Wrap it in fastshap's Surrogate together with the number of input features.
surrogate = Surrogate(surr, num_features)

#%% Fastshap
from simshap.fastshap_plus import FastSHAP

# Load the saved FastSHAP explainer (presumably an nn.Module) onto the target
# device and put it in eval mode.
explainer_fastshap = torch.load('news fastshap.pt').to(device)
explainer_fastshap.eval()
# Pair the explainer with the surrogate; 'additive' normalization and an
# identity link are passed straight through to FastSHAP.
fastshap = FastSHAP(explainer_fastshap, surrogate,normalization='additive',
                        link=nn.Identity())

# Time a single call over all sampled rows. The conversion of x to a tensor
# on `device` happens inside the timed region.
start = time.time()
fastshap_values = fastshap.shap_values(torch.tensor(x, dtype=torch.float32, device=device))
end = time.time()
print('fastshap running time:', end - start)
#%% Simshap
from simshap.simshap_sampling import SimSHAPSampling
# Load the saved SimSHAP explainer (presumably an nn.Module) onto the target
# device and put it in eval mode.
explainer_sim = torch.load('news simshap.pt').to(device)
explainer_sim.eval()
simshap = SimSHAPSampling(explainer_sim, surrogate, device=device)

# Time a single call over all sampled rows.
# NOTE(review): x is passed here as the NumPy array sliced from X_test,
# whereas the FastSHAP cell converts it to a tensor first -- confirm that
# SimSHAPSampling.shap_values accepts NumPy input.
start = time.time()
simshap_values = simshap.shap_values(x)
end = time.time()
print('simshap running time:', end - start)
#%% Kernelshap
import shap
def model_wrapper(x):
    """NumPy-in / NumPy-out adapter around `original_model`.

    in: array-like batch of inputs
    out: NumPy array holding `original_model`'s output for that batch
    """
    batch = torch.tensor(x, dtype=torch.float32, device=device)
    return original_model(batch).cpu().data.numpy()
# Background data for KernelSHAP: a single row holding the per-feature
# median of the (scaled) training split.
med = np.median(X_train, axis=0).reshape((1, -1))
kernelshap = shap.KernelExplainer(model_wrapper, med)
# Time the explanation of all sampled rows; nsamples='auto' leaves the number
# of model evaluations per row to shap.
start = time.time()
kernelshap_values = kernelshap.shap_values(x, nsamples='auto')
end = time.time()
print('kernelshap running time:', end - start)
#%% Kernelshap-S
def imputer(x, S):
    """Evaluate `surrogate` on inputs `x` with `S` (presumably a feature mask).

    in: two array-likes; both are converted to float32 tensors on `device`
    out: the surrogate's output as a NumPy array
    """
    x_t, S_t = (
        torch.tensor(arr, dtype=torch.float32, device=device) for arr in (x, S)
    )
    return surrogate(x_t, S_t).cpu().data.numpy()


# NOTE(review): `med` is not used below (the explainer's background is all
# zeros); the assignment is kept so the name stays defined for this cell.
med = np.median(X_train, axis=0).reshape((1, -1))
start = time.time()
# One explainer per sampled row: f_mask closes over the current row index.
for i in range(all):
    def f_mask(z):
        # z is the batch of perturbed rows KernelExplainer builds from the
        # explained row and the background; it reaches the surrogate as S.
        return imputer(x[i:i+1], z)
    # The background width follows num_features (the value given to the
    # Surrogate) instead of a hard-coded 47, so it tracks the feature count.
    kernelshap_s = shap.KernelExplainer(f_mask, np.zeros((1, num_features)))
    # NOTE(review): because the explained row is x[i], the "on" entries of z
    # carry x's feature values rather than 1s. If the surrogate expects a
    # binary mask, explaining np.ones((1, num_features)) may be what is
    # wanted -- confirm. The returned values are discarded (timing only).
    kernelshap_s.shap_values(x[i:i+1], nsamples='auto')
end = time.time()
print('kernelshap-s running time:', end - start)

# %%

#%% IG
from captum.attr import IntegratedGradients
# NOTE(review): original_model detaches its input and goes through NumPy, so
# autograd cannot propagate gradients through it; confirm that a
# gradient-based method can run against this forward function at all.
ig = IntegratedGradients(original_model)
# original_model returns one column per class, so captum needs a target to
# pick the output being attributed. Use each sampled row's label, as a 1-D
# long tensor (y is a column vector of 0/1 labels).
ig_target = torch.as_tensor(y, dtype=torch.long, device=device).reshape(-1)
start = time.time()
ig_values = ig.attribute(
    torch.tensor(x, dtype=torch.float32, device=device),
    target=ig_target,
)
end = time.time()
print('ig running time:', end - start)
#%% SmoothGrad
from captum.attr import IntegratedGradients, NoiseTunnel
# NOTE(review): original_model detaches its input and goes through NumPy, so
# autograd cannot propagate gradients through it; confirm that a
# gradient-based method can run against this forward function at all.
ig = IntegratedGradients(original_model)
# Noise tunnel around Integrated Gradients, averaged smoothgrad-style.
sg = NoiseTunnel(ig)
# captum accepts an int, tuple, tensor or list as target, not a NumPy array;
# convert the (N, 1) label array to a 1-D long tensor, one class per row.
sg_target = torch.as_tensor(y, dtype=torch.long, device=device).reshape(-1)
start = time.time()
sg_values = sg.attribute(
    torch.tensor(x, dtype=torch.float32, device=device),
    nt_type='smoothgrad',
    nt_samples=4,
    target=sg_target,
)
end = time.time()
print('sg running time:', end - start)