import numpy as np
import argparse
import pickle
from tqdm import tqdm
import matplotlib.pyplot as plt

import shap
from scipy.spatial.distance import pdist
from tensorflow.python.keras.layers import Dense, Input, Flatten, Add, Multiply, Lambda
from tensorflow.python.keras.layers.normalization import BatchNormalization
from tensorflow.python.keras import regularizers
from tensorflow.python.keras.models import Model, Sequential
from tensorflow.python.keras import optimizers
from tensorflow.python.keras.callbacks import ModelCheckpoint

from utils.explanations import calculate_stability, calculate_robust_astute_sampled

# Astuteness of SHAP explanations for several black-box classifiers
# (2-layer MLP, 4-layer MLP, linear model, SVM) as a function of epsilon.
# Results are pickled and plotted with mean +/- std over repeated runs.


def parse_args(argv=None):
    """Parse the command-line options of this experiment.

    Args:
        argv: Optional list of argument strings; ``None`` means ``sys.argv[1:]``.

    Returns:
        argparse.Namespace whose ``epsilon_range`` is always a 1-D float ndarray,
        whether it came from the default or from the command line.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--datatype', type=str,
                        choices=['orange_skin', 'XOR', 'nonlinear_additive', 'switch'], default='orange_skin')
    parser.add_argument('--run_times', type=int, default=1)
    # Typed + nargs so `--epsilon_range 0.05 0.1` yields floats; without a type the
    # value arrived as a single string and was iterated character by character.
    # NOTE(review): numpy documents that float-step arange can have an inconsistent
    # length; the default is kept unchanged for reproducibility of earlier results.
    parser.add_argument('--epsilon_range', type=float, nargs='+',
                        default=np.arange(0.02, 0.22, 0.02))
    parser.add_argument('--prop_points', type=float, default=1)
    parser.add_argument('--calculate', dest='calculate', action='store_true')
    parser.add_argument('--no-calculate', dest='calculate', action='store_false')
    parser.set_defaults(calculate=True)

    args = parser.parse_args(argv)
    args.epsilon_range = np.asarray(args.epsilon_range, dtype=float)
    return args


def kernel_background_size(n_train, n_clusters=100):
    """Return how many training rows to subsample for the KernelExplainer background.

    Uses 0.1% of the training set, but never fewer than ``n_clusters`` (``shap.kmeans``
    cannot build more clusters than rows) and never more than ``n_train``.
    """
    return min(n_train, max(n_clusters, int(0.001 * n_train)))


def build_dense_classifier(input_shape, hidden_sizes, activation, output_name, weights_path):
    """Build a Dense/BatchNorm MLP with a 2-way softmax head and load stored weights.

    Args:
        input_shape: Number of input features (the model input has shape ``(input_shape,)``).
        hidden_sizes: Width of each hidden Dense layer; layer ``i`` (1-based) is named
            ``'dense<i>'`` and followed by a BatchNormalization layer.
        activation: Activation of the hidden layers (``None`` for a linear model).
        output_name: Layer name of the softmax output layer.
        weights_path: HDF5 file loaded with ``by_name=True``, so layer names here
            must match the names stored in the file.

    Returns:
        The Keras ``Model`` with weights loaded.
    """
    from tensorflow.python.keras import backend as K

    # Reset Keras' global layer-name counters (and, in graph mode, the default graph)
    # so every build produces the same auto-generated BatchNormalization names.
    # Otherwise names keep incrementing across classifiers/runs and by-name loading
    # skips or mis-assigns BN weights.
    # NOTE(review): assumes the .hdf5 files were saved from a model built in a fresh
    # session -- confirm against the training script.
    K.clear_session()

    model_input = Input(shape=(input_shape,), dtype='float32')
    net = model_input
    for idx, units in enumerate(hidden_sizes, start=1):
        net = Dense(units, activation=activation, name='dense' + str(idx),
                    kernel_regularizer=regularizers.l2(1e-3))(net)
        net = BatchNormalization()(net)  # Add batchnorm for stability.
    preds = Dense(2, activation='softmax', name=output_name,
                  kernel_regularizer=regularizers.l2(1e-3))(net)
    model = Model(model_input, preds)
    model.load_weights(weights_path, by_name=True)
    return model


def astuteness_over_epsilons(explainer, x_val, explanations, ball_r, epsilon_range, num_points, nn):
    """Return an array with the astuteness of ``explainer`` for each epsilon.

    The value stored per epsilon is the second of the three values returned by
    ``calculate_robust_astute_sampled``.
    """
    scores = np.zeros(len(epsilon_range))
    for k in tqdm(range(len(epsilon_range))):
        _, scores[k], _ = calculate_robust_astute_sampled(data=x_val,
                                                          explainer=explainer,
                                                          explainer_type='shap',
                                                          explanation_type='attribution',
                                                          ball_r=ball_r,
                                                          epsilon=epsilon_range[k],
                                                          num_points=num_points,
                                                          NN=nn,
                                                          data_explanation=explanations)
    return scores


def main(argv=None):
    """Compute (or load) astuteness results for every classifier and plot them."""
    args = parse_args(argv)
    epsilon_range = args.epsilon_range

    with open('data/' + args.datatype + '.pk', 'rb') as f:
        data_dict = pickle.load(f)
    x_train = data_dict['x_train']
    x_val = data_dict['x_val']
    input_shape = data_dict['input_shape']

    median_rad = 0.5 * np.median(pdist(x_val))
    # NOTE(review): the 'rise_' prefix looks copied from a RISE script although these
    # are SHAP results; kept so existing result/plot paths stay valid.
    save_astuteness_file = 'plots/rise_' + args.datatype + '_astuteness_classifiers.pk'
    classifiers = ['2layer', '4layer', 'linear', 'svm']

    activation = 'relu' if args.datatype in ['orange_skin', 'XOR'] else 'selu'
    # classifier -> (hidden widths, hidden activation, output layer name, weight-file suffix)
    keras_specs = {
        '2layer': ([200, 200], activation, 'dense4', '_blackbox.hdf5'),
        '4layer': ([50, 50, 50, 50], activation, 'dense5', '_blackbox_extra.hdf5'),
        'linear': ([200], None, 'dense4', '_blackbox_linear.hdf5'),
    }

    if args.calculate:
        num_points = int(args.prop_points * len(x_val))
        total_astuteness = np.zeros(shape=(args.run_times, len(classifiers), len(epsilon_range)))
        for i in range(args.run_times):
            print('Completing Run ' + str(i + 1) + ' of ' + str(args.run_times))
            for j, name in enumerate(classifiers):
                if name == 'svm':
                    # NOTE(review): unpickling executes code; only load trusted model files.
                    with open('models/' + args.datatype + '_svm.pk', 'rb') as f:
                        pred_model = pickle.load(f)
                else:
                    hidden, act, out_name, suffix = keras_specs[name]
                    pred_model = build_dense_classifier(input_shape, hidden, act, out_name,
                                                        'models/' + args.datatype + suffix)

                fname = 'explained_weights/shap/' + 'shap_' + args.datatype + '_' + name + '_' + str(i) + '.gz'
                explanations = np.loadtxt(fname, delimiter=',')

                if name == 'svm':
                    n_background = kernel_background_size(len(x_train))
                    training_indices = np.random.choice(len(x_train), n_background, replace=False)
                    explainer = shap.KernelExplainer(pred_model.predict_proba,
                                                     shap.kmeans(x_train[training_indices],
                                                                 min(100, n_background)))
                    nn = False
                else:
                    background = x_train[np.random.choice(len(x_train), min(100, len(x_train)),
                                                          replace=False)]
                    explainer = shap.GradientExplainer(pred_model, background)
                    nn = True

                total_astuteness[i, j, :] = astuteness_over_epsilons(explainer, x_val, explanations,
                                                                     median_rad, epsilon_range,
                                                                     num_points, nn)
        with open(save_astuteness_file, 'wb') as f:
            pickle.dump(total_astuteness, f)
    else:
        with open(save_astuteness_file, 'rb') as f:
            total_astuteness = pickle.load(f)
        if total_astuteness.shape[1:] != (len(classifiers), len(epsilon_range)):
            raise ValueError('Saved astuteness array has shape ' + str(total_astuteness.shape)
                             + ', which does not match the requested classifiers/epsilon_range.')

    astuteness_mean = total_astuteness.mean(axis=0)
    astuteness_std = total_astuteness.std(axis=0)
    image_name = 'plots/rise_' + args.datatype + '_astuteness_classifiers.PNG'
    fig, ax = plt.subplots()
    for i, name in enumerate(classifiers):
        ax.errorbar(x=epsilon_range, y=astuteness_mean[i, :], yerr=astuteness_std[i, :],
                    label=name)
    ax.legend()
    fig.savefig(image_name)
    plt.close(fig)


if __name__ == '__main__':
    main()