import numpy as np
import argparse
import shap
import pandas as pd
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import train_test_split
from tqdm import tqdm

from tensorflow.python.keras.layers import Dense, Input, Flatten, Add, Multiply, Lambda
from tensorflow.python.keras.layers.normalization import BatchNormalization
from tensorflow.python.keras import regularizers
from tensorflow.python.keras.models import Model, Sequential
from tensorflow.python.keras import optimizers
from tensorflow.python.keras.callbacks import ModelCheckpoint

from utils.explanations import calculate_robust_astute_sampled


# Seed NumPy's global RNG once at import time so the SHAP background sample
# drawn with np.random.choice inside shap_explainer is reproducible across
# invocations of this script. Successive calls in the same process (e.g.
# --run_times > 1) continue from the advancing global state.
np.random.seed(0)

def _load_telescope_split():
    """Load the MAGIC telescope CSV and return standardized train/val features.

    The last column holds the class letter ('h' -> 0, 'g' -> 1); every other
    column is a feature. The split is stratified on the label with a fixed
    ``random_state`` so the same validation rows are produced on every call.

    Returns:
        (x_train, x_val): feature matrices, both standardized with the mean
        and standard deviation of the *training* split.
    """
    # NOTE(review): read_csv uses its default header inference here, so the
    # first line of the file is consumed as column names. If magic04.data has
    # no header row, one sample is dropped. Left unchanged so the split stays
    # identical to previous runs (presumably the one the black-box was trained
    # against) -- confirm before switching to header=None.
    raw = pd.read_csv('data/magic04.data').values
    data = raw[:, :-1]
    labels = raw[:, -1]
    labels[labels == 'h'] = 0
    labels[labels == 'g'] = 1
    x_train, x_val, _, _ = train_test_split(
        data, labels, test_size=0.05, stratify=labels, random_state=42)

    # Fit the scaler on the training split only and reuse it for validation.
    # A separate scaler fitted on x_val would put validation inputs on a
    # different scale from the training rows the model and the SHAP background
    # use, and would leak validation statistics into preprocessing.
    scaler = StandardScaler().fit(x_train)
    return scaler.transform(x_train), scaler.transform(x_val)


def _build_blackbox(input_dim, weights_path):
    """Rebuild the one-hidden-layer sigmoid classifier and load its weights.

    Layer names ('dense1', 'dense3') must match the names stored in the HDF5
    file, because weights are loaded with ``by_name=True``.
    """
    model_input = Input(shape=(input_dim,), dtype='float32')
    hidden = Dense(32, activation='relu', name='dense1',
                   kernel_regularizer=regularizers.l2(1e-3))(model_input)
    preds = Dense(1, activation='sigmoid', name='dense3',
                  kernel_regularizer=regularizers.l2(1e-3))(hidden)
    model = Model(model_input, preds)
    model.load_weights(weights_path, by_name=True)
    return model


def shap_explainer(datatype, ball_r, epsilon, prop_points, exponentiate):
    """Compute absolute SHAP attributions for the telescope black-box model.

    Args:
        datatype: dataset key; selects the weights file
            ``models/<datatype>_blackbox.hdf5``. The data file is always
            ``data/magic04.data``.
        ball_r: passed through to ``calculate_robust_astute_sampled``.
        epsilon: passed through to ``calculate_robust_astute_sampled``.
        prop_points: fraction of validation rows to use; converted to a count
            with ``int(prop_points * len(x_val))``.
        exponentiate: passed through to ``calculate_robust_astute_sampled``.

    Returns:
        ``np.abs`` of the value returned by ``calculate_robust_astute_sampled``.
    """
    blackbox_path = 'models/' + datatype + '_blackbox.hdf5'
    x_train, x_val = _load_telescope_split()
    bbox_model = _build_blackbox(x_train.shape[-1], blackbox_path)

    # 100 distinct training rows (drawn from NumPy's global RNG) serve as the
    # background dataset for shap.GradientExplainer.
    background = x_train[np.random.choice(len(x_train), 100, replace=False)]
    explainer = shap.GradientExplainer(bbox_model, background)

    explanation = calculate_robust_astute_sampled(
        data=x_val,
        explainer=explainer,
        explainer_type='shap',
        explanation_type='attribution',
        data_explanation=None,
        ball_r=ball_r,
        epsilon=epsilon,
        num_points=int(prop_points * len(x_val)),
        exponentiate=exponentiate,
    )

    return np.abs(explanation)

if __name__ == '__main__':
    import os

    parser = argparse.ArgumentParser()
    parser.add_argument('--datatype', type=str,
                        choices=['telescope'], default='telescope',
                        help='dataset key; selects models/<datatype>_blackbox.hdf5')
    parser.add_argument('--ball_radius', type=float, default=2,
                        help='ball_r passed to calculate_robust_astute_sampled')
    parser.add_argument('--epsilon', type=float, default=0.05,
                        help='epsilon passed to calculate_robust_astute_sampled')
    parser.add_argument('--prop_points', type=float, default=.01,
                        help='fraction of validation points to explain')
    parser.add_argument('--run_times', type=int, default=1,
                        help='number of repeated runs; one output file per run')
    parser.add_argument('--exponentiate', type=int, default=0,
                        help='exponentiate flag passed to calculate_robust_astute_sampled')
    args = parser.parse_args()

    out_dir = os.path.join('explained_weights', 'shap')
    # np.savetxt does not create parent directories; make sure ours exists.
    os.makedirs(out_dir, exist_ok=True)

    for i in tqdm(range(args.run_times)):
        fname = os.path.join(out_dir, 'shap_' + args.datatype + '_' + str(i) + '.gz')
        explanation = shap_explainer(datatype=args.datatype,
                                     ball_r=args.ball_radius,
                                     epsilon=args.epsilon,
                                     prop_points=args.prop_points,
                                     exponentiate=args.exponentiate)
        # The '.gz' suffix makes np.savetxt write gzip-compressed text.
        # NOTE(review): only explanation[0] is saved; confirm this is the
        # intended slice of calculate_robust_astute_sampled's return value.
        np.savetxt(X=explanation[0], fname=fname, delimiter=',')
