import numpy as np
import argparse
import pickle
from tqdm import tqdm
import matplotlib.pyplot as plt

from keras.layers import Dense, Input, Flatten, Add, Multiply, Lambda
from keras.layers.normalization import BatchNormalization
from keras import regularizers
from keras.models import Model, Sequential
from keras.callbacks import ModelCheckpoint
from keras import optimizers

from utils.explanations import calculate_stability, calculate_robust_astute_sampled

def parse_args(argv=None):
    """Parse command-line options for the RISE astuteness experiment.

    Args:
        argv: Optional list of argument strings; ``None`` means ``sys.argv[1:]``.

    Returns:
        argparse.Namespace with ``datatype``, ``run_times``, ``radius_range``,
        ``epsilon_range``, ``prop_points`` and ``calculate``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--datatype', type=str,
                        choices=['orange_skin', 'XOR', 'nonlinear_additive', 'switch'], default='switch')
    parser.add_argument('--run_times', type=int, default=10)
    # type/nargs make CLI-supplied ranges numeric lists; without them argparse
    # returns one raw string, and len()/indexing would then act on characters.
    parser.add_argument('--radius_range', type=float, nargs='+', default=np.arange(0, 5, 1))
    parser.add_argument('--epsilon_range', type=float, nargs='+', default=np.arange(0, 1, 0.1))
    parser.add_argument('--prop_points', type=float, default=1)
    parser.add_argument('--calculate', dest='calculate', action='store_true')
    parser.add_argument('--no-calculate', dest='calculate', action='store_false')
    parser.set_defaults(calculate=True)
    return parser.parse_args(argv)


def build_blackbox(input_shape, activation):
    """Rebuild the black-box classifier architecture (weights are NOT loaded).

    Two 200-unit Dense layers with L2(1e-3) kernel regularisation, each followed
    by BatchNormalization, then a 2-unit softmax head.

    Args:
        input_shape: Number of input features (used as ``Input(shape=(input_shape,))``).
        activation: Activation for the two hidden Dense layers.

    Returns:
        An uncompiled keras ``Model``.
    """
    model_input = Input(shape=(input_shape,), dtype='float32')

    # Layer names must match the saved weight file, which is loaded by_name.
    net = Dense(200, activation=activation, name='dense1',
                kernel_regularizer=regularizers.l2(1e-3))(model_input)
    # NOTE(review): the BatchNormalization layers are unnamed, so by_name
    # loading restores them only if their auto-generated names match the
    # names stored in the weight file -- confirm against the training script.
    net = BatchNormalization()(net)  # Add batchnorm for stability.
    net = Dense(200, activation=activation, name='dense2',
                kernel_regularizer=regularizers.l2(1e-3))(net)
    net = BatchNormalization()(net)

    preds = Dense(2, activation='softmax', name='dense4',
                  kernel_regularizer=regularizers.l2(1e-3))(net)
    return Model(model_input, preds)


def compute_astuteness(bbox_model, x_val, args):
    """Evaluate astuteness for every (run, radius, epsilon) combination.

    Args:
        bbox_model: Keras model passed as the explainer to
            ``calculate_robust_astute_sampled`` with ``explainer_type='rise'``.
        x_val: Validation inputs to sample points from.
        args: Parsed CLI namespace (``run_times``, ``radius_range``,
            ``epsilon_range``, ``prop_points``).

    Returns:
        ndarray of shape ``(run_times, len(radius_range), len(epsilon_range))``
        holding the second value returned by ``calculate_robust_astute_sampled``.
    """
    # Invariant across all iterations, so compute it once.
    num_points = int(args.prop_points * len(x_val))
    total_astuteness = np.zeros(shape=(args.run_times, len(args.radius_range), len(args.epsilon_range)))
    for i in range(args.run_times):
        print('Completing Run ' + str(i + 1) + ' of ' + str(args.run_times))
        for j, radius in enumerate(args.radius_range):
            for k in tqdm(range(len(args.epsilon_range))):
                _, total_astuteness[i, j, k], _ = calculate_robust_astute_sampled(
                    data=x_val,
                    explainer=bbox_model,
                    explainer_type='rise',
                    explanation_type='attribution',
                    ball_r=radius,
                    epsilon=args.epsilon_range[k],
                    num_points=num_points)
    return total_astuteness


def plot_astuteness(total_astuteness, radius_range, epsilon_range, image_name):
    """Plot the mean +/- std astuteness (over runs) against epsilon, one line per radius.

    Args:
        total_astuteness: Array indexed ``[run, radius_idx, epsilon_idx]``.
        radius_range: Radius values, used for the legend labels.
        epsilon_range: Epsilon values, used as the x axis.
        image_name: Output path for the saved figure.
    """
    astuteness_mean = total_astuteness.mean(axis=0)
    astuteness_std = total_astuteness.std(axis=0)
    fig, ax = plt.subplots()
    for i, radius in enumerate(radius_range):
        ax.errorbar(x=epsilon_range, y=astuteness_mean[i, :], yerr=astuteness_std[i, :],
                    label='radius: ' + str(radius))
    ax.legend()
    fig.savefig(image_name)
    plt.close(fig)  # release the figure; matters if called repeatedly


def main(argv=None):
    """Compute (or reload) RISE astuteness results for one dataset and plot them."""
    args = parse_args(argv)

    # NOTE: pickle.load executes arbitrary code from the file; only use trusted data files.
    with open('data/' + args.datatype + '.pk', 'rb') as f:
        data_dict = pickle.load(f)
    x_val, input_shape = data_dict['x_val'], data_dict['input_shape']
    save_astuteness_file = 'plots/rise_' + args.datatype + '_astuteness.pk'

    if args.calculate:
        activation = 'relu' if args.datatype in ['orange_skin', 'XOR'] else 'selu'
        bbox_model = build_blackbox(input_shape, activation)
        bbox_model.load_weights('models/' + args.datatype + '_blackbox.hdf5',
                                by_name=True)
        total_astuteness = compute_astuteness(bbox_model, x_val, args)
        with open(save_astuteness_file, 'wb') as f:
            pickle.dump(total_astuteness, f)
    else:
        with open(save_astuteness_file, 'rb') as f:
            total_astuteness = pickle.load(f)

    image_name = 'plots/rise_' + args.datatype + '_astuteness.PNG'
    plot_astuteness(total_astuteness, args.radius_range, args.epsilon_range, image_name)


if __name__ == '__main__':
    main()
