import argparse
import os
import pickle

import numpy as np
import shap
from tensorflow.python.keras import regularizers
from tensorflow.python.keras.layers import Dense, Input, Flatten, Add, Multiply, Lambda
from tensorflow.python.keras.layers.normalization import BatchNormalization
from tensorflow.python.keras.models import Model, Sequential

from utils.explanations import calculate_robust_astute_sampled


# Seed NumPy's global RNG once, at import time, so that the np.random.choice
# subsampling of training data inside shap_explainer is reproducible across
# script invocations. The seed is NOT reset per call, so repeated calls in one
# process (e.g. --run_times > 1) draw different samples.
# NOTE(review): any sampling inside calculate_robust_astute_sampled is presumably
# affected too if it uses np.random — not visible from this file.
np.random.seed(0)

# Keras black-box architectures, keyed by the ``classifier`` argument of
# shap_explainer:
#   (units per hidden layer, number of hidden layers,
#    name of the softmax output layer, weight-file suffix under models/)
# Weights are loaded with ``by_name=True``, so these layer names must match the
# names stored in the saved weight files.
_KERAS_BLACKBOXES = {
    '2layer': (200, 2, 'dense4', '_blackbox.hdf5'),
    '4layer': (50, 4, 'dense5', '_blackbox_extra.hdf5'),
    'linear': (200, 1, 'dense4', '_blackbox_linear.hdf5'),
}


def _load_pickle(path):
    """Return the object unpickled from *path*, closing the file afterwards.

    NOTE(review): unpickling can execute arbitrary code; only use this with
    trusted, locally produced files.
    """
    with open(path, 'rb') as fh:
        return pickle.load(fh)


def _build_keras_blackbox(datatype, classifier, input_shape):
    """Rebuild the Keras black box for *classifier* and load its saved weights.

    Each hidden layer is a ``Dense`` layer with L2(1e-3) kernel regularisation,
    followed by ``BatchNormalization``. The output is a 2-unit softmax
    ``Dense`` layer. The 'linear' classifier uses no hidden activation. The
    other classifiers use ReLU for the 'orange_skin'/'XOR' datasets and SELU
    otherwise. Weights come from ``models/<datatype><suffix>``.
    """
    # Reset Keras' global layer-name counters so that the auto-named
    # BatchNormalization layers get the same names on every call. Loading is
    # by name, so names that drift between calls (e.g. several runs in one
    # process) would silently leave those layers with their initial weights.
    from tensorflow.python.keras import backend as K
    K.clear_session()

    units, n_hidden, output_name, weight_suffix = _KERAS_BLACKBOXES[classifier]
    if classifier == 'linear':
        activation = None
    else:
        activation = 'relu' if datatype in ['orange_skin', 'XOR'] else 'selu'

    model_input = Input(shape=(input_shape,), dtype='float32')
    net = model_input
    for layer_idx in range(1, n_hidden + 1):
        net = Dense(units, activation=activation, name='dense%d' % layer_idx,
                    kernel_regularizer=regularizers.l2(1e-3))(net)
        net = BatchNormalization()(net)  # Add batchnorm for stability.
    preds = Dense(2, activation='softmax', name=output_name,
                  kernel_regularizer=regularizers.l2(1e-3))(net)
    model = Model(model_input, preds)
    model.load_weights('models/' + datatype + weight_suffix, by_name=True)
    return model


def shap_explainer(datatype, ball_r, epsilon, prop_points, exponentiate, classifier):
    """Compute absolute SHAP attributions for a sample of validation points.

    Loads ``data/<datatype>.pk`` and restores the requested black-box
    classifier. It then builds a SHAP explainer for that classifier and passes
    it to ``calculate_robust_astute_sampled``, with attribution explanations
    and astuteness calculation disabled.

    Args:
        datatype: dataset name; selects the data file and the model files.
        ball_r: ball radius forwarded to calculate_robust_astute_sampled.
        epsilon: epsilon forwarded to calculate_robust_astute_sampled.
        prop_points: fraction of validation points to explain; the helper
            receives ``int(prop_points * len(x_val))`` as ``num_points``.
        exponentiate: forwarded unchanged to calculate_robust_astute_sampled.
        classifier: '2layer', '4layer' or 'linear' selects a Keras model,
            explained with ``shap.GradientExplainer``. 'svm' selects a pickled
            model that exposes ``predict_proba``, explained with
            ``shap.KernelExplainer``.

    Returns:
        ``np.abs`` of the explanation returned by calculate_robust_astute_sampled.

    Raises:
        ValueError: if *classifier* is not one of the supported names. This is
            checked before any file is read.
    """
    if classifier != 'svm' and classifier not in _KERAS_BLACKBOXES:
        raise ValueError('Unknown classifier %r; expected one of %s'
                         % (classifier, sorted(list(_KERAS_BLACKBOXES) + ['svm'])))

    data_dict = _load_pickle('data/' + datatype + '.pk')
    x_train, x_val = data_dict['x_train'], data_dict['x_val']
    n_train = len(x_train)

    if classifier == 'svm':
        model = _load_pickle('models/' + datatype + '_svm.pk')
        # The KernelExplainer background is a k-means summary (100 clusters)
        # of a 0.1% training subsample. k-means needs at least as many samples
        # as clusters, so the subsample is raised to 100 points (capped at the
        # dataset size) and the cluster count is capped at the subsample size.
        # Both guards are no-ops whenever 0.1% of the data is >= 100 points.
        n_subsample = min(n_train, max(int(0.001 * n_train), 100))
        training_indices = np.random.choice(n_train, n_subsample, replace=False)
        background = shap.kmeans(x_train[training_indices], min(100, n_subsample))
        explainer = shap.KernelExplainer(model.predict_proba, background)
        extra_kwargs = {'NN': False}
    else:
        model = _build_keras_blackbox(datatype, classifier, data_dict['input_shape'])
        # Up to 100 random training points form the GradientExplainer background.
        background = x_train[np.random.choice(n_train, min(100, n_train), replace=False)]
        explainer = shap.GradientExplainer(model, background)
        extra_kwargs = {}

    explanation = calculate_robust_astute_sampled(data=x_val,
                                                  explainer=explainer,
                                                  explainer_type='shap',
                                                  explanation_type='attribution',
                                                  ball_r=ball_r,
                                                  epsilon=epsilon,
                                                  num_points=int(prop_points * len(x_val)),
                                                  exponentiate=exponentiate,
                                                  calculate_astuteness=False,
                                                  **extra_kwargs)
    return np.abs(explanation)

if __name__ == '__main__':
    # Command-line entry point: compute SHAP explanations for each classifier
    # and each run, saving one comma-delimited file per (classifier, run) under
    # explained_weights/shap/.
    parser = argparse.ArgumentParser()
    parser.add_argument('--datatype', type=str,
                        choices=['orange_skin', 'XOR', 'nonlinear_additive', 'switch'], default='XOR',
                        help='dataset name; selects data/<datatype>.pk and the model files')
    parser.add_argument('--ball_radius', type=float, default=2,
                        help='ball radius forwarded to shap_explainer (ball_r)')
    parser.add_argument('--epsilon', type=float, default=0.05,
                        help='epsilon forwarded to shap_explainer')
    parser.add_argument('--prop_points', type=float, default=.01,
                        help='fraction of validation points to explain')
    parser.add_argument('--run_times', type=int, default=1,
                        help='number of repeated runs; one output file per run')
    parser.add_argument('--exponentiate', type=int, default=0,
                        help='forwarded unchanged to shap_explainer')
    args = parser.parse_args()

    classifiers = ['2layer']
    out_dir = os.path.join('explained_weights', 'shap')
    # np.savetxt does not create missing parent directories.
    os.makedirs(out_dir, exist_ok=True)

    for classifier in classifiers:
        for run in range(args.run_times):
            fname = os.path.join(out_dir, 'shap_%s_%s_%d.gz' % (args.datatype, classifier, run))
            explanation = shap_explainer(datatype=args.datatype,
                                         ball_r=args.ball_radius,
                                         epsilon=args.epsilon,
                                         prop_points=args.prop_points,
                                         exponentiate=args.exponentiate,
                                         classifier=classifier)
            np.savetxt(X=explanation, fname=fname, delimiter=',')

