import numpy as np

import numpy as np
import pickle
from datetime import datetime
import time
from scipy.spatial.distance import pdist

from utils.explanations import calculate_prob_lipschitz
import matplotlib.pyplot as plt

from tensorflow.python.keras.layers import Dense, Input, Flatten, Add, Multiply, Lambda
from tensorflow.python.keras.layers.normalization import BatchNormalization
from tensorflow.python.keras import regularizers
from tensorflow.python.keras.models import Model, Sequential
from tensorflow.python.keras import optimizers
from tensorflow.python.keras.callbacks import ModelCheckpoint

# ---------------------------------------------------------------------------
# Configuration
# ---------------------------------------------------------------------------
datatype = 'orange_skin'  # selects data/<datatype>.pk and the models/ files
train = False  # NOTE(review): not read anywhere in the visible code
BATCH_SIZE = 1000  # NOTE(review): not read anywhere in the visible code
epochs = 2  # NOTE(review): not read anywhere in the visible code
# True: recompute the Lipschitz estimates and cache them to `save_lipschitz`.
# False: reload the previously cached estimates from that file instead.
calculate = True
save_lipschitz = 'plots/blackbox_' + datatype + '_lipschitz.pk'
np.random.seed(0)  # seed NumPy's global RNG for reproducible runs

# Load the pickled dataset dictionary (a local, trusted artifact). The
# `with` block closes the handle deterministically; the original left it open.
with open('data/' + datatype + '.pk', 'rb') as data_file:
    data_dict = pickle.load(data_file)

# Result table: one row per classifier, one column per entry of L_range.
classifiers = ['2layer', '4layer', 'linear', 'svm']
L_range = np.arange(0, 1.1, 0.1)  # 0.0 .. 1.0 in steps of 0.1 (stop 1.1 keeps 1.0)
total_lipschitz = np.zeros(shape=(len(classifiers), len(L_range)))

x_train = data_dict['x_train']
y_train = data_dict['y_train']
x_val = data_dict['x_val']
y_val = data_dict['y_val']
datatype_val = data_dict['datatype_val']
input_shape = data_dict['input_shape']

# Neighbourhood radius passed to calculate_prob_lipschitz: half the median
# pairwise (Euclidean, pdist's default metric) distance between validation points.
median_rad = 0.5 * np.median(pdist(x_val))

# ---------------------------------------------------------------------------
# Classifier row 0 ('2layer'): two 200-unit hidden layers, each followed by
# batch normalisation, then a 2-way softmax head. Pre-trained weights are
# loaded from disk; nothing is fitted here.
# ---------------------------------------------------------------------------
if datatype in ('orange_skin', 'XOR'):
    activation = 'relu'
else:
    activation = 'selu'

model_input = Input(shape=(input_shape,), dtype='float32')

# Layers are created in the same order as they are applied, so any
# auto-generated layer names stay the same as in a hand-unrolled version.
net = model_input
for hidden_name in ('dense1', 'dense2'):
    net = Dense(200, activation=activation, name=hidden_name,
                kernel_regularizer=regularizers.l2(1e-3))(net)
    net = BatchNormalization()(net)  # batchnorm for stability

preds = Dense(2, activation='softmax', name='dense4',
              kernel_regularizer=regularizers.l2(1e-3))(net)
model = Model(model_input, preds)
# NOTE(review): only the Dense layers carry explicit names. load_weights with
# by_name=True restores layers whose names match the .hdf5 file, so the
# BatchNormalization weights are restored only if their auto-generated names
# coincide with the stored ones — TODO confirm.
model.load_weights('models/' + datatype + '_blackbox.hdf5', by_name=True)

# A second Model over the same tensors, and therefore the same (now loaded)
# layers. It is compiled without a loss, presumably only for prediction.
pred_model = Model(model_input, preds)
pred_model.compile(loss=None, optimizer='rmsprop', metrics=None)

if calculate:
    # NOTE(review): calculate_prob_lipschitz lives in utils.explanations; it
    # presumably yields one value per entry of L_range (fills row 0).
    total_lipschitz[0, :] = calculate_prob_lipschitz(
        x_val, pred_model,
        r=median_rad,
        L_range=L_range,
        num_points=len(x_val),
    )


###

print("Training classifier with extra layer")

# ---------------------------------------------------------------------------
# Classifier row 1 ('4layer'): four 100-unit ReLU hidden layers, each
# followed by batch normalisation, then a 2-way softmax head. Despite the
# message above, nothing is trained here: pre-trained weights are loaded.
# ---------------------------------------------------------------------------
activation = 'relu'

model_input = Input(shape=(input_shape,), dtype='float32')

# Build order matches application order (Dense, then BatchNormalization, per
# hidden layer), exactly as in a hand-unrolled version.
net = model_input
for hidden_name in ('dense1', 'dense2', 'dense3', 'dense4'):
    net = Dense(100, activation=activation, name=hidden_name,
                kernel_regularizer=regularizers.l2(1e-3))(net)
    net = BatchNormalization()(net)  # batchnorm for stability

preds = Dense(2, activation='softmax', name='dense5',
              kernel_regularizer=regularizers.l2(1e-3))(net)
model = Model(model_input, preds)
# NOTE(review): the BatchNormalization layers are unnamed, so by-name loading
# restores them only if their auto-generated names match the file — TODO confirm.
model.load_weights('models/' + datatype + '_blackbox_extra.hdf5',
                   by_name=True)

# Same graph (shared, loaded layers), compiled without a loss for prediction.
pred_model = Model(model_input, preds)
pred_model.compile(loss=None, optimizer='rmsprop', metrics=None)

if calculate:
    total_lipschitz[1, :] = calculate_prob_lipschitz(
        x_val, pred_model,
        r=median_rad,
        L_range=L_range,
        num_points=len(x_val),
    )


print('Training Linear Classifier')

# ---------------------------------------------------------------------------
# Classifier row 2 ('linear'): a single 200-unit Dense layer with no
# activation, batch normalisation, then a 2-way softmax head. There is no
# nonlinearity before the softmax. Pre-trained weights are loaded, not fitted.
# ---------------------------------------------------------------------------
activation = None

model_input = Input(shape=(input_shape,), dtype='float32')
hidden = Dense(200, activation=activation, name='dense1',
               kernel_regularizer=regularizers.l2(1e-3))(model_input)
net = BatchNormalization()(hidden)  # batchnorm for stability

# The head is named 'dense4' so that by-name loading finds it in the
# linear model's weight file.
preds = Dense(2, activation='softmax', name='dense4',
              kernel_regularizer=regularizers.l2(1e-3))(net)
model = Model(model_input, preds)
model.load_weights('models/' + datatype + '_blackbox_linear.hdf5',
                   by_name=True)

# Same graph (shared, loaded layers), compiled without a loss for prediction.
pred_model = Model(model_input, preds)
pred_model.compile(loss=None, optimizer='rmsprop', metrics=None)

if calculate:
    total_lipschitz[2, :] = calculate_prob_lipschitz(
        x_val, pred_model,
        r=median_rad,
        L_range=L_range,
        num_points=len(x_val),
    )

###


print("SVM")

# Classifier row 3 ('svm'): a pre-trained, pickled SVM (a local, trusted
# artifact). The `with` block closes the file handle; the original left it open.
with open('models/' + datatype + '_svm.pk', 'rb') as svm_file:
    svm_classif = pickle.load(svm_file)

if calculate:
    # NN=False presumably tells calculate_prob_lipschitz that the model is not
    # a Keras network — TODO confirm in utils.explanations.
    total_lipschitz[3, :] = calculate_prob_lipschitz(
        x_val, svm_classif,
        r=median_rad,
        L_range=L_range,
        num_points=len(x_val),
        NN=False,
    )


# Either cache the freshly computed table or reuse the one cached by an
# earlier run. `with` closes the file handles, which the original left open.
if calculate:
    with open(save_lipschitz, 'wb') as lipschitz_file:
        pickle.dump(total_lipschitz, lipschitz_file)
else:
    with open(save_lipschitz, 'rb') as lipschitz_file:
        total_lipschitz = pickle.load(lipschitz_file)

# One curve per classifier: row of total_lipschitz plotted against L_range.
# zip pairs each label with its row, like indexing rows 0..len(classifiers)-1.
image_name = 'plots/classifiers_' + datatype + '_lipschitz.PNG'
for label, curve in zip(classifiers, total_lipschitz):
    plt.errorbar(x=L_range, y=curve, yerr=0, label=label, marker='x')
plt.legend()
plt.savefig(image_name)
r = 3  # NOTE(review): not used in the visible code