import pickle
import numpy as np
import argparse
import cxplain

from tensorflow.python.keras.layers import Dense, Input, Flatten, Add, Multiply, Lambda
from tensorflow.python.keras.layers.normalization import BatchNormalization
from tensorflow.python.keras import regularizers
from tensorflow.python.keras.models import Model, Sequential

from cxplain import MLPModelBuilder, ZeroMasking, CXPlain
from tensorflow.python.keras.losses import categorical_crossentropy

from utils.explanations import calculate_robust_astute_sampled

# Fix NumPy's global RNG at import time so every np.random draw made by this
# script (e.g. the np.random.choice call in cxplain_explainer) is repeatable.
# This does not seed TensorFlow's own random number generator.
np.random.seed(seed=0)


_SUPPORTED_CLASSIFIERS = ('2layer', '4layer', 'linear', 'svm')


def _build_keras_blackbox(input_shape, hidden_layers, output_name, activation, weights_path):
    """Rebuild a black-box Keras MLP and load its saved weights by layer name.

    Args:
        input_shape: Number of input features; the input layer has shape
            ``(input_shape,)``.
        hidden_layers: Sequence of ``(layer_name, num_units)`` pairs. Each hidden
            ``Dense`` layer is followed by a ``BatchNormalization`` layer.
        output_name: Name of the 2-unit softmax output layer.
        activation: Activation of the hidden ``Dense`` layers (``None`` = linear).
        weights_path: HDF5 weights file, loaded with ``by_name=True``.

    Returns:
        The Keras ``Model`` mapping the input to the 2-class softmax output.
    """
    model_input = Input(shape=(input_shape,), dtype='float32')
    net = model_input
    for layer_name, num_units in hidden_layers:
        net = Dense(num_units, activation=activation, name=layer_name,
                    kernel_regularizer=regularizers.l2(1e-3))(net)
        # NOTE(review): BatchNormalization layers are not explicitly named, so
        # their auto-generated names depend on how many such layers were
        # created earlier in the process; with by_name=True their weights load
        # only if those names match the file — confirm when several models are
        # built in one run.
        net = BatchNormalization()(net)  # Add batchnorm for stability.
    preds = Dense(2, activation='softmax', name=output_name,
                  kernel_regularizer=regularizers.l2(1e-3))(net)
    model = Model(model_input, preds)
    # by_name=True: weights are assigned only to layers whose names match the
    # saved file, so the Dense names passed in must match the trained model.
    model.load_weights(weights_path, by_name=True)
    return model


def cxplain_explainer(datatype, ball_r, epsilon, prop_points, exponentiate, classifier):
    """Fit a CXPlain explainer for one black-box classifier and return its attributions.

    Loads ``data/<datatype>.pk``, restores the requested black-box model from
    ``models/``, fits a CXPlain explanation model on the full training set and
    evaluates it on a sample of validation points via
    ``calculate_robust_astute_sampled``.

    Args:
        datatype: Dataset name; selects the data file and the model files.
        ball_r: Forwarded unchanged to ``calculate_robust_astute_sampled``.
        epsilon: Forwarded unchanged to ``calculate_robust_astute_sampled``.
        prop_points: Fraction of ``x_val`` to evaluate; the number of points is
            ``int(prop_points * len(x_val))``.
        exponentiate: Forwarded unchanged to ``calculate_robust_astute_sampled``.
        classifier: One of ``'2layer'``, ``'4layer'``, ``'linear'``, ``'svm'``.

    Returns:
        ``np.abs`` of the explanation returned by ``calculate_robust_astute_sampled``.

    Raises:
        ValueError: If ``classifier`` is not one of the supported names.
    """
    # Fail fast, before touching the filesystem, instead of hitting an
    # unbound local further down.
    if classifier not in _SUPPORTED_CLASSIFIERS:
        raise ValueError('Unknown classifier {!r}; expected one of {}'.format(
            classifier, _SUPPORTED_CLASSIFIERS))

    # NOTE: pickle.load can execute arbitrary code; only load trusted local files.
    with open('data/' + datatype + '.pk', 'rb') as data_file:
        data_dict = pickle.load(data_file)
    x_train, y_train = data_dict['x_train'], data_dict['y_train']
    x_val, input_shape = data_dict['x_val'], data_dict['input_shape']

    if classifier == 'linear':
        activation = None
    else:
        activation = 'relu' if datatype in ['orange_skin', 'XOR'] else 'selu'

    # Restore the black box; num_layers/num_units size the CXPlain explanation MLP.
    if classifier == '2layer':
        pred_model = _build_keras_blackbox(
            input_shape, [('dense1', 200), ('dense2', 200)], 'dense4', activation,
            'models/' + datatype + '_blackbox.hdf5')
        num_layers, num_units = 2, 200
    elif classifier == '4layer':
        pred_model = _build_keras_blackbox(
            input_shape, [('dense1', 50), ('dense2', 50), ('dense3', 50), ('dense4', 50)],
            'dense5', activation, 'models/' + datatype + '_blackbox_extra.hdf5')
        num_layers, num_units = 4, 50
    elif classifier == 'linear':
        pred_model = _build_keras_blackbox(
            input_shape, [('dense1', 200)], 'dense4', activation,
            'models/' + datatype + '_blackbox_linear.hdf5')
        num_layers, num_units = 1, 200
    else:  # 'svm'
        with open('models/' + datatype + '_svm.pk', 'rb') as model_file:
            pred_model = pickle.load(model_file)
        num_layers, num_units = 2, 200

    model_builder = MLPModelBuilder(num_layers=num_layers, num_units=num_units, activation=activation,
                                    verbose=1, batch_size=1000, learning_rate=0.001, num_epochs=5,
                                    early_stopping_patience=15, with_bn=True)

    # NOTE(review): these indices were never used — the explainer is fit on the
    # full x_train. The draw is kept only so later draws from NumPy's global RNG
    # stay identical to earlier runs. If 1% subsampling was intended, fit on
    # x_train[training_indices] / y_train[training_indices] instead.
    _training_indices = np.random.choice(len(x_train), int(0.01 * len(x_train)), replace=False)

    explainer = CXPlain(pred_model, model_builder, ZeroMasking(), categorical_crossentropy, num_models=1)
    explainer.fit(x_train, y_train)

    # The SVM black box is not a neural network; the Keras models rely on the
    # helper's default for NN.
    extra_kwargs = {'NN': False} if classifier == 'svm' else {}
    explanation = calculate_robust_astute_sampled(data=x_val,
                                                  explainer=explainer,
                                                  explainer_type='cxplain',
                                                  explanation_type='attribution',
                                                  ball_r=ball_r,
                                                  epsilon=epsilon,
                                                  num_points=int(prop_points * len(x_val)),
                                                  exponentiate=exponentiate,
                                                  calculate_astuteness=False,
                                                  **extra_kwargs)
    return np.abs(explanation)


if __name__ == '__main__':
    import os  # only the script entry point needs filesystem helpers

    parser = argparse.ArgumentParser()
    parser.add_argument('--datatype', type=str,
                        choices=['orange_skin', 'XOR', 'nonlinear_additive', 'switch'], default='XOR')
    parser.add_argument('--ball_radius', type=float, default=2)
    parser.add_argument('--epsilon', type=float, default=0.05)
    parser.add_argument('--prop_points', type=float, default=.01)
    parser.add_argument('--run_times', type=int, default=1)
    parser.add_argument('--exponentiate', type=int, default=0)
    args = parser.parse_args()

    # np.savetxt does not create missing directories, so ensure the output
    # directory exists before the (long) explainer runs start.
    output_dir = os.path.join('explained_weights', 'cxplain')
    os.makedirs(output_dir, exist_ok=True)

    classifiers = ['2layer', '4layer', 'linear', 'svm']
    for classifier in classifiers:
        for run in range(args.run_times):
            # The .gz suffix makes np.savetxt write gzip-compressed CSV.
            fname = os.path.join(output_dir,
                                 'cxplain_{}_{}_{}.gz'.format(args.datatype, classifier, run))
            explanation = cxplain_explainer(datatype=args.datatype,
                                            ball_r=args.ball_radius,
                                            epsilon=args.epsilon,
                                            prop_points=args.prop_points,
                                            exponentiate=args.exponentiate,
                                            classifier=classifier)
            np.savetxt(X=explanation, fname=fname, delimiter=',')

