import argparse
import os
import pickle

import matplotlib.pyplot as plt
import numpy as np
from scipy.spatial.distance import pdist
from tensorflow.python.keras import backend as K
from tensorflow.python.keras import optimizers
from tensorflow.python.keras import regularizers
from tensorflow.python.keras.callbacks import ModelCheckpoint
from tensorflow.python.keras.layers import Dense, Input, Flatten, Add, Multiply, Lambda
from tensorflow.python.keras.layers.normalization import BatchNormalization
from tensorflow.python.keras.models import Model, Sequential
from tqdm import tqdm

from utils.explanations import calculate_stability, calculate_robust_astute_sampled

# RISE astuteness experiment: for each black-box classifier, evaluate the
# precomputed RISE explanations with calculate_robust_astute_sampled over a
# sweep of epsilon values, repeat for several runs, then plot mean +/- std.

# Classifiers evaluated, in the order used for axis 1 of the results array.
CLASSIFIERS = ['2layer', '4layer', 'linear', 'svm']

# Keras black-box architectures: (hidden layer widths, weights-file suffix,
# output layer name). Layer names matter because weights are loaded by name.
KERAS_SPECS = {
    '2layer': ((200, 200), '_blackbox', 'dense4'),
    '4layer': ((50, 50, 50, 50), '_blackbox_extra', 'dense5'),
    'linear': ((200,), '_blackbox_linear', 'dense4'),
}


def parse_args():
    """Parse the command-line options of the experiment.

    Returns:
        argparse.Namespace with ``datatype``, ``run_times``, ``prop_points``
        and ``calculate`` (True unless ``--no-calculate`` is given).
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--datatype', type=str,
                        choices=['orange_skin', 'XOR', 'nonlinear_additive', 'switch'], default='switch')
    parser.add_argument('--run_times', type=int, default=10)

    parser.add_argument('--prop_points', type=float, default=1)
    parser.add_argument('--calculate', dest='calculate', action='store_true')
    parser.add_argument('--no-calculate', dest='calculate', action='store_false')
    parser.set_defaults(calculate=True)
    return parser.parse_args()


def epsilon_range_for(datatype):
    """Return the array of epsilon values swept for ``datatype``.

    Raises:
        ValueError: if no epsilon grid is configured for ``datatype``.
    """
    if datatype in ('nonlinear_additive', 'orange_skin'):
        return np.arange(0.01, 1.1, 0.1)
    if datatype in ('switch', 'XOR'):
        return np.arange(0.02, 0.22, 0.02)
    raise ValueError('No epsilon range configured for datatype ' + repr(datatype))


def hidden_activation(classifier, datatype):
    """Return the hidden-layer activation of a Keras ``classifier`` for ``datatype``."""
    if classifier == 'linear':
        return None
    return 'relu' if datatype in ('orange_skin', 'XOR') else 'selu'


def build_keras_classifier(classifier, datatype, input_shape):
    """Rebuild a Keras black-box classifier and load its trained weights.

    The network is Input -> [Dense(width, name='denseN') -> BatchNormalization]
    per hidden width -> Dense(2, softmax), with L2(1e-3) kernel regularizers
    on every Dense layer.

    Args:
        classifier: one of the keys of ``KERAS_SPECS``.
        datatype: dataset name; selects the activation and the weights file.
        input_shape: number of input features (used as ``Input(shape=(input_shape,))``).

    Returns:
        The Keras ``Model`` with weights loaded from
        ``models/<datatype><suffix>.hdf5``.
    """
    hidden_units, weights_suffix, output_name = KERAS_SPECS[classifier]
    activation = hidden_activation(classifier, datatype)

    # Reset Keras' global state before every rebuild. BatchNormalization layers
    # are created without explicit names, so Keras auto-names them with a
    # counter that keeps increasing within a session (batch_normalization,
    # batch_normalization_1, ...). load_weights(by_name=True) ignores model
    # layers whose names are absent from the file, so without this reset every
    # model built after the first would keep untrained BN parameters. Clearing
    # also stops the graph from growing across runs.
    # NOTE(review): assumes each weight file was saved from a model built in a
    # fresh session; verify the BN layer names stored in the .hdf5 files.
    K.clear_session()

    model_input = Input(shape=(input_shape,), dtype='float32')
    net = model_input
    for depth, units in enumerate(hidden_units, start=1):
        net = Dense(units, activation=activation, name='dense' + str(depth),
                    kernel_regularizer=regularizers.l2(1e-3))(net)
        net = BatchNormalization()(net)  # Add batchnorm for stability.
    preds = Dense(2, activation='softmax', name=output_name,
                  kernel_regularizer=regularizers.l2(1e-3))(net)
    model = Model(model_input, preds)
    model.load_weights('models/' + datatype + weights_suffix + '.hdf5', by_name=True)
    return model


def load_classifier(classifier, datatype, input_shape):
    """Return the trained model for ``classifier``: a pickled SVM or a Keras net."""
    if classifier == 'svm':
        # NOTE: pickle executes code on load; only use trusted model files.
        with open('models/' + datatype + '_svm.pk', 'rb') as f:
            return pickle.load(f)
    return build_keras_classifier(classifier, datatype, input_shape)


def compute_astuteness(datatype, x_val, input_shape, epsilon_range, run_times, prop_points):
    """Compute astuteness of the RISE explanations for every run/classifier/epsilon.

    Args:
        datatype: dataset name used to locate models and explanation files.
        x_val: validation inputs passed to calculate_robust_astute_sampled.
        input_shape: number of input features of the Keras models.
        epsilon_range: epsilon values to sweep.
        run_times: number of runs; run ``i`` reads explanation file ``..._<i>.gz``.
        prop_points: fraction of ``x_val`` to sample as ``num_points``.

    Returns:
        ndarray of shape (run_times, len(CLASSIFIERS), len(epsilon_range)).
    """
    # Ball radius: half the median pairwise (Euclidean, pdist's default)
    # distance between validation points.
    median_rad = 0.5 * np.median(pdist(x_val))
    num_points = int(prop_points * len(x_val))

    total_astuteness = np.zeros(shape=(run_times, len(CLASSIFIERS), len(epsilon_range)))
    for i in range(run_times):
        print('Completing Run ' + str(i + 1) + ' of ' + str(run_times))
        for j, classifier in enumerate(CLASSIFIERS):
            pred_model = load_classifier(classifier, datatype, input_shape)
            fname = ('explained_weights/rise/' + 'rise_' + datatype + '_' + classifier
                     + '_' + str(i) + '.gz')
            explanations = np.loadtxt(fname, delimiter=',')
            for k in tqdm(range(len(epsilon_range))):
                _, total_astuteness[i, j, k], _ = calculate_robust_astute_sampled(
                    data=x_val,
                    explainer=pred_model,
                    explainer_type='rise',
                    explanation_type='attribution',
                    ball_r=median_rad,
                    epsilon=epsilon_range[k],
                    num_points=num_points,
                    NN=classifier != 'svm',  # only the SVM is not a neural net
                    data_explanation=explanations)
    return total_astuteness


def plot_astuteness(total_astuteness, epsilon_range, image_name):
    """Save an errorbar plot (mean +/- std over runs) per classifier to ``image_name``."""
    astuteness_mean = total_astuteness.mean(axis=0)
    astuteness_std = total_astuteness.std(axis=0)
    fig, ax = plt.subplots()
    for idx, classifier in enumerate(CLASSIFIERS):
        ax.errorbar(x=epsilon_range, y=astuteness_mean[idx, :], yerr=astuteness_std[idx, :],
                    label=classifier)
    ax.legend()
    fig.savefig(image_name)
    plt.close(fig)


def main():
    """Run (or reload) the astuteness experiment and write the plot."""
    args = parse_args()
    epsilon_range = epsilon_range_for(args.datatype)

    with open('data/' + args.datatype + '.pk', 'rb') as f:
        data_dict = pickle.load(f)
    x_val, input_shape = data_dict['x_val'], data_dict['input_shape']

    save_astuteness_file = 'plots/rise_' + args.datatype + '_astuteness_classifiers.pk'
    if args.calculate:
        total_astuteness = compute_astuteness(args.datatype, x_val, input_shape, epsilon_range,
                                              args.run_times, args.prop_points)
        # Make sure the output directory exists so a long computation is not lost.
        save_dir = os.path.dirname(save_astuteness_file)
        if save_dir and not os.path.isdir(save_dir):
            os.makedirs(save_dir)
        with open(save_astuteness_file, 'wb') as f:
            pickle.dump(total_astuteness, f)
    else:
        with open(save_astuteness_file, 'rb') as f:
            total_astuteness = pickle.load(f)

    image_name = 'plots/rise_' + args.datatype + '_astuteness_classifiers.PNG'
    plot_astuteness(total_astuteness, epsilon_range, image_name)


if __name__ == '__main__':
    main()
