#%% 
import shap  # https://github.com/slundberg/shap
import shapreg  # https://github.com/iancovert/shapley-regression
import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler

import argparse
def _str2bool(value):
    """Convert a command-line token into a bool.

    ``argparse`` applies ``type`` to the raw string, and ``bool('False')`` is
    ``True`` because any non-empty string is truthy. This converter reads the
    text instead.

    Args:
        value: The raw argument string (a bool is passed through unchanged).

    Returns:
        True for true/t/yes/y/1/on, False for false/f/no/n/0/off or the empty
        string (case-insensitive, surrounding whitespace ignored).

    Raises:
        argparse.ArgumentTypeError: If the text is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    text = value.strip().lower()
    if text in ('true', 't', 'yes', 'y', '1', 'on'):
        return True
    if text in ('false', 'f', 'no', 'n', '0', 'off', ''):
        return False
    raise argparse.ArgumentTypeError(
        'expected a boolean value, got {!r}'.format(value))


parser = argparse.ArgumentParser()
# nargs='?' with const=True keeps "--test_training_speed True/False" working
# and additionally accepts the bare flag "--test_training_speed".
parser.add_argument('--test_training_speed', type=_str2bool, nargs='?',
                    const=True, default=False)
args = parser.parse_args()
#%% 
# Fetch the Adult dataset bundled with shap, hold out 20% as the test split,
# then hold out 20% of the remainder as the validation split
X_all, Y_all = shap.datasets.adult()
X_train, X_test, Y_train, Y_test = train_test_split(
    X_all, Y_all, test_size=0.2, random_state=7)
X_train, X_val, Y_train, Y_val = train_test_split(
    X_train, Y_train, test_size=0.2, random_state=0)

# Capture feature metadata before X_train is rebound to the scaler's output
feature_names = X_train.columns.tolist()
num_features = X_train.shape[1]

# Fit the scaler on the training split only, then apply it to every split
ss = StandardScaler()
ss.fit(X_train)
X_train, X_val, X_test = (
    ss.transform(split) for split in (X_train, X_val, X_test))

#%% Train Model
import pickle
import os.path
import lightgbm as lgb
from lightgbm import log_evaluation, early_stopping

#%% 
# Fit and cache a LightGBM model unless a pickled one is already on disk
if not os.path.isfile('census model.pkl'):
    # Booster hyperparameters
    params = dict(
        max_bin=512,
        learning_rate=0.05,
        boosting_type="gbdt",
        objective="binary",
        metric="binary_logloss",
        num_leaves=10,
        verbose=-1,
        min_data=100,
        boost_from_average=True,
    )

    # Wrap the training and validation splits for LightGBM
    d_train = lgb.Dataset(X_train, label=Y_train)
    d_val = lgb.Dataset(X_val, label=Y_val)

    # Evaluation logging with period 1000 and early stopping with
    # stopping_rounds=50, both driven by the validation set passed below
    callbacks = [
        log_evaluation(period=1000),
        early_stopping(stopping_rounds=50),
    ]

    # Train for at most 10000 boosting rounds
    model = lgb.train(
        params,
        d_train,
        num_boost_round=10000,
        valid_sets=[d_val],
        callbacks=callbacks,
    )

    # Cache the trained model for later runs
    with open('census model.pkl', 'wb') as model_file:
        pickle.dump(model, model_file)

else:
    print('Loading saved model')
    # NOTE(review): pickle executes code on load; only use a cache file that
    # this script wrote itself.
    with open('census model.pkl', 'rb') as model_file:
        model = pickle.load(model_file)

#%% Train surrogate
import torch
import torch.nn as nn
from fastshap.utils import MaskLayer1d
from fastshap import Surrogate, KLDivLoss

# Select device: prefer the GPU, but fall back to the CPU so the script still
# runs on machines without a usable CUDA device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

#%% 
# Build and train the surrogate unless a saved network is already on disk
if not os.path.isfile('census surrogate.pt'):
    # Surrogate network; the first layer is a MaskLayer1d with append=True
    # (presumably this concatenates the mask to the input, which would explain
    # the 2 * num_features input width -- confirm against fastshap.utils)
    layers = [
        MaskLayer1d(value=0, append=True),
        nn.Linear(2 * num_features, 128),
        nn.ELU(inplace=True),
        nn.Linear(128, 128),
        nn.ELU(inplace=True),
        nn.Linear(128, 2),
    ]
    surr = nn.Sequential(*layers).to(device)

    # Wrap the network in the fastshap Surrogate helper
    surrogate = Surrogate(surr, num_features)

    def original_model(x):
        """Return the LightGBM model's predictions for ``x`` as two columns.

        ``x`` is moved to the CPU and converted to NumPy for
        ``model.predict``. The prediction ``p`` is stacked as ``[1 - p, p]``
        and returned as a float32 tensor on the same device as ``x``.
        """
        positive = model.predict(x.cpu().numpy())
        two_class = np.stack([1 - positive, positive]).T
        return torch.tensor(two_class, dtype=torch.float32, device=x.device)

    # Fit the surrogate to the original model's outputs
    surrogate.train_original_model(
        X_train,
        X_val,
        original_model,
        batch_size=64,
        max_epochs=100,
        loss_fn=KLDivLoss(),
        validation_samples=10,
        validation_batch_size=10000,
        verbose=True)

    # Save from the CPU, then move the network back to the working device
    surr.cpu()
    torch.save(surr, 'census surrogate.pt')
    surr.to(device)

else:
    print('Loading saved surrogate model')
    # NOTE(review): this loads a whole pickled module; newer PyTorch releases
    # default torch.load to weights_only=True, which may reject it -- confirm
    # against the installed version.
    surr = torch.load('census surrogate.pt').to(device)
    surrogate = Surrogate(surr, num_features)

#%% Train SimSHAP
import time
from simshap.simshap_sampling import SimSHAPSampling
import sys
sys.path.append('..')
from models import SimSHAPTabular
import time
# Train the SimSHAP explainer unless a saved one is already on disk
if not os.path.isfile('census simshap.pt'):
    # Create explainer model
    explainer = SimSHAPTabular(
        in_dim=num_features, hidden_dim=64, out_dim=2).to(device)

    # Set up SimSHAP object, using the surrogate as the imputer
    simshap = SimSHAPSampling(
        explainer=explainer, imputer=surrogate, device=device)

    # Train, optionally reporting the wall-clock training time
    if args.test_training_speed:
        train_start = time.time()
    simshap.train(
        X_train,
        X_val[:100],
        batch_size=2048,
        num_samples=64,
        max_epochs=1000,
        lr=1.5e-2,
        bar=False,
        validation_samples=1024,
        verbose=True,
        lookback=20,
        lr_factor=0.5)
    if args.test_training_speed:
        print('simshap training time: ', time.time() - train_start)

    # Save from the CPU, then move the explainer back to the working device
    explainer.cpu()
    torch.save(explainer, 'census simshap.pt')
    explainer.to(device)

else:
    print('Loading saved explainer model')
    # NOTE(review): this loads a whole pickled module; newer PyTorch releases
    # default torch.load to weights_only=True, which may reject it -- confirm
    # against the installed version.
    explainer = torch.load('census simshap.pt').to(device)
    simshap = SimSHAPSampling(
        explainer=explainer, imputer=surrogate, device=device)

#%% Train FastSHAP
from simshap.fastshap_plus import FastSHAP
import time
# Train the FastSHAP explainer unless a saved one is already on disk
if not os.path.isfile('census fastshap.pt'):
    # Create explainer model; the output layer has 2 * num_features units
    explainer_layers = [
        nn.Linear(num_features, 128),
        nn.ReLU(inplace=True),
        nn.Linear(128, 128),
        nn.ReLU(inplace=True),
        nn.Linear(128, 2 * num_features),
    ]
    explainer_fastshap = nn.Sequential(*explainer_layers).to(device)

    # Set up FastSHAP object
    fastshap = FastSHAP(explainer_fastshap, surrogate,
                        link=nn.Identity(), normalization='additive')

    # Train, optionally reporting the wall-clock training time
    if args.test_training_speed:
        train_start = time.time()
    fastshap.train(
        X_train,
        X_val[:100],
        batch_size=32,
        num_samples=32,
        max_epochs=200,
        validation_samples=128,
        verbose=True)
    if args.test_training_speed:
        print('fastshap training time: ', time.time() - train_start)

    # Save from the CPU, then move the explainer back to the working device
    explainer_fastshap.cpu()
    torch.save(explainer_fastshap, 'census fastshap.pt')
    explainer_fastshap.to(device)

else:
    print('Loading saved explainer model')
    # NOTE(review): this loads a whole pickled module; newer PyTorch releases
    # default torch.load to weights_only=True, which may reject it -- confirm
    # against the installed version.
    explainer_fastshap = torch.load('census fastshap.pt').to(device)
    fastshap = FastSHAP(explainer_fastshap, surrogate, normalization='additive',
                        link=nn.Identity())

#%% Compare with KernelSHAP
import matplotlib.pyplot as plt
# Setup for KernelSHAP
def imputer(x, S):
    """Evaluate the surrogate on ``x`` and ``S`` and return a NumPy array.

    Both arguments are converted to float32 tensors on the module-level
    ``device`` before being passed to ``surrogate``; the result is moved back
    to the CPU and returned as a NumPy array.

    NOTE(review): ``S`` is presumably a feature-subset mask supplied by
    shapreg -- confirm against ``shapreg.games.PredictionGame``.
    """
    tensor_options = dict(dtype=torch.float32, device=device)
    inputs = torch.tensor(x, **tensor_options)
    subsets = torch.tensor(S, **tensor_options)
    return surrogate(inputs, subsets).cpu().data.numpy()

# Seed NumPy's global RNG, then pick the test instance to explain.
# The first test row is used; np.random.choice(len(X_test)) would pick a
# random one instead.
np.random.seed(20)
ind = 0
x = X_test[ind:ind + 1]
y = int(Y_test[ind])

# Explanations from the two trained explainers
simshap_values = simshap.shap_values(x)[0].transpose(1, 0)
fastshap_values = fastshap.shap_values(x)[0]

# Run KernelSHAP to convergence on the surrogate, keeping every intermediate
# estimate (return_all=True) for the comparison plot
game = shapreg.games.PredictionGame(imputer, x)
regression_options = dict(
    batch_size=32,
    paired_sampling=False,
    detect_convergence=True,
    bar=True,
    return_all=True,
)
shap_values, all_results = shapreg.shapley.ShapleyRegression(
    game, **regression_options)

# Create figure
plt.figure(figsize=(9, 5.5))

# Bar chart layout: four adjacent bars per feature, each width / 4 wide
width = 0.75
kernelshap_iters = 128
bar_width = width / 4
ticks = np.arange(num_features)

# KernelSHAP estimate recorded at exactly `kernelshap_iters` in the tracking
# results (list.index raises ValueError if that count was never recorded)
snapshot = list(all_results['iters']).index(kernelshap_iters)
kernelshap_values = all_results['values'][snapshot]

# (x offset, bar heights for class y, legend label, colour), in draw order
series = [
    (-width / 2, shap_values.values[:, y],
     'True SHAP values', 'tab:gray'),
    (-width / 4, simshap_values[:, y],
     'SimSHAP', 'tab:green'),
    (0, fastshap_values[:, y],
     'fastSHAP', 'tab:blue'),
    (width / 4, kernelshap_values[:, y],
     'KernelSHAP @ {}'.format(kernelshap_iters), 'tab:red'),
]
for offset, heights, label, color in series:
    plt.bar(ticks + offset, heights, bar_width, label=label, color=color)

# Annotations
plt.legend(fontsize=16)
plt.tick_params(labelsize=14)
plt.ylabel('SHAP Values', fontsize=16)
plt.title('Census Explanation Example', fontsize=18)
plt.xticks(ticks, feature_names,
           rotation=35, rotation_mode='anchor', ha='right')

# Euclidean distance of each explainer's values from the converged values
simshap_gap = shap_values.values - simshap_values
print('simshap:', np.sqrt(np.sum(simshap_gap**2)))
fastshap_gap = shap_values.values - fastshap_values
print('fastshap', np.sqrt(np.sum(fastshap_gap**2)))

# Finalise the layout, write the figure to disk, then display it
plt.tight_layout()
plt.savefig('census simshap.png')
plt.show()