import pickle
import numpy as np
import argparse

from tensorflow.python.keras.layers import Dense, Input, Flatten, Add, Multiply, Lambda
from tensorflow.python.keras.layers.normalization import BatchNormalization
from tensorflow.python.keras import regularizers
from tensorflow.python.keras.models import Model, Sequential
from tensorflow.python.keras import optimizers
from tensorflow.python.keras.callbacks import ModelCheckpoint
import warnings
warnings.filterwarnings("ignore")
from tqdm import tqdm

from utils.explanations import calculate_robust_astute_sampled

# Seed NumPy's global random number generator with a fixed value so that
# NumPy-based random draws made after import are repeatable between runs.
np.random.seed(0)


# Per-classifier network layout:
#   (number of hidden Dense layers, name of the softmax output layer,
#    suffix of the weights file under ``models/``).
# Hidden layers are named 'dense1'..'denseN'; the names must stay exactly as
# they are because the weights are loaded with ``by_name=True``.
_NN_ARCHITECTURES = {
    '2layer': (2, 'dense4', '_blackbox.hdf5'),
    '4layer': (4, 'dense5', '_blackbox_extra.hdf5'),
    'linear': (2, 'dense4', '_blackbox_linear.hdf5'),
}
_SUPPORTED_CLASSIFIERS = tuple(_NN_ARCHITECTURES) + ('svm',)


def _load_pickle(path):
    """Unpickle and return the object stored at ``path``, closing the file afterwards."""
    # NOTE(security): unpickling can execute arbitrary code. Only point this
    # at data/model files produced by this project or another trusted source.
    with open(path, 'rb') as handle:
        return pickle.load(handle)


def _build_blackbox(datatype, classifier, input_shape):
    """Rebuild the Keras black-box network for ``classifier`` and load its saved weights."""
    n_hidden, output_name, weights_suffix = _NN_ARCHITECTURES[classifier]

    if classifier == 'linear':
        # No non-linearity on the hidden layers for the "linear" variant.
        activation = None
    else:
        activation = 'relu' if datatype in ['orange_skin', 'XOR'] else 'selu'

    model_input = Input(shape=(input_shape,), dtype='float32')
    net = model_input
    for layer_idx in range(1, n_hidden + 1):
        net = Dense(200, activation=activation, name='dense' + str(layer_idx),
                    kernel_regularizer=regularizers.l2(1e-3))(net)
        net = BatchNormalization()(net)  # Add batchnorm for stability.
    preds = Dense(2, activation='softmax', name=output_name,
                  kernel_regularizer=regularizers.l2(1e-3))(net)

    model = Model(model_input, preds)
    # Weights are matched to layers by layer name, hence the fixed names above.
    model.load_weights('models/' + datatype + weights_suffix, by_name=True)
    return model


def rise_explainer(datatype, ball_r, epsilon, prop_points, exponentiate, classifier):
    """Compute absolute RISE attributions for a pre-trained classifier.

    Loads the dataset ``data/<datatype>.pk``, restores the requested
    classifier from ``models/`` and passes it, together with the validation
    inputs, to ``calculate_robust_astute_sampled`` with
    ``explainer_type='rise'`` and ``explanation_type='attribution'``.

    Args:
        datatype: Dataset name; selects both the data file and the model
            weight/pickle files.
        ball_r: Forwarded unchanged as ``ball_r``.
        epsilon: Forwarded unchanged as ``epsilon``.
        prop_points: Multiplied by the number of validation rows and
            truncated to an int to obtain ``num_points``.
        exponentiate: Forwarded unchanged as ``exponentiate``.
        classifier: One of ``'2layer'``, ``'4layer'``, ``'linear'`` (Keras
            networks) or ``'svm'`` (pickled model).

    Returns:
        ``np.abs`` of the value returned by ``calculate_robust_astute_sampled``.

    Raises:
        ValueError: If ``classifier`` is not one of the supported names.
    """
    # Fail fast, before any file I/O or graph construction. Previously an
    # unknown name only surfaced later as an unbound ``pred_model`` variable.
    if classifier not in _SUPPORTED_CLASSIFIERS:
        raise ValueError(
            'Unknown classifier {!r}; expected one of {}'.format(
                classifier, ', '.join(_SUPPORTED_CLASSIFIERS)))

    data_dict = _load_pickle('data/' + datatype + '.pk')
    x_val = data_dict['x_val']

    is_neural_net = classifier != 'svm'
    if is_neural_net:
        pred_model = _build_blackbox(datatype, classifier, data_dict['input_shape'])
    else:
        pred_model = _load_pickle('models/' + datatype + '_svm.pk')

    explanation = calculate_robust_astute_sampled(data=x_val,
                                                  explainer=pred_model,
                                                  explainer_type='rise',
                                                  explanation_type='attribution',
                                                  ball_r=ball_r,
                                                  epsilon=epsilon,
                                                  num_points=int(prop_points * len(x_val)),
                                                  exponentiate=exponentiate,
                                                  calculate_astuteness=False,
                                                  NN=is_neural_net)

    return np.abs(explanation)


if __name__ == '__main__':
    # Script entry point: for every classifier variant, compute RISE
    # attributions ``run_times`` times and save each result as a CSV file.
    import os  # Local import: only needed when this file is run as a script.

    parser = argparse.ArgumentParser()
    parser.add_argument('--datatype', type=str,
                        choices=['orange_skin', 'XOR', 'nonlinear_additive', 'switch'], default='XOR')
    parser.add_argument('--ball_radius', type=float, default=2)
    parser.add_argument('--epsilon', type=float, default=0.05)
    parser.add_argument('--prop_points', type=float, default=.01)
    parser.add_argument('--run_times', type=int, default=1)
    parser.add_argument('--exponentiate', type=int, default=0)
    args = parser.parse_args()

    # Create the output directory up front so a missing folder is not
    # discovered only after an explanation has already been computed.
    out_dir = 'explained_weights/rise/'
    os.makedirs(out_dir, exist_ok=True)

    classifiers = ['2layer', 'linear', '4layer', 'svm']
    for classifier in classifiers:
        for i in tqdm(range(args.run_times)):
            fname = out_dir + 'rise_' + args.datatype + '_' + classifier + '_' + str(i) + '.gz'
            explanation = rise_explainer(datatype=args.datatype,
                                         ball_r=args.ball_radius,
                                         epsilon=args.epsilon,
                                         prop_points=args.prop_points,
                                         exponentiate=args.exponentiate,
                                         classifier=classifier)

            # np.savetxt gzip-compresses the output when the name ends in '.gz'.
            np.savetxt(X=explanation, fname=fname, delimiter=',')

