import numpy as np
import pickle
from datetime import datetime
import time

# from keras.layers import Dense, Input, Flatten, Add, Multiply, Lambda
# from keras.layers.normalization import BatchNormalization
# from keras import regularizers
# from keras.models import Model, Sequential
# from keras.callbacks import ModelCheckpoint
# from keras import optimizers

from tensorflow.python.keras.layers import Dense, Input, Flatten, Add, Multiply, Lambda
from tensorflow.python.keras.layers.normalization import BatchNormalization
from tensorflow.python.keras import regularizers
from tensorflow.python.keras.models import Model, Sequential
from tensorflow.python.keras import optimizers
from tensorflow.python.keras.callbacks import ModelCheckpoint
from sklearn.svm import SVC


# Run configuration shared by every dataset.
TRAIN = True  # True: fit and save every model; False: load previously saved weights/pickles.
BATCH_SIZE = 1000
EPOCHS = 2


def _hidden_activation(datatype):
    """Return the hidden-layer activation for *datatype*.

    'relu' is used for the 'orange_skin' and 'XOR' datasets, 'selu' for all others.
    """
    return 'relu' if datatype in ['orange_skin', 'XOR'] else 'selu'


def _build_classifier(input_shape, hidden_widths, activation, output_name):
    """Build a 2-class softmax classifier.

    Each hidden layer is ``Dense(width, name='dense<k>')`` (k counts from 1,
    L2 kernel penalty 1e-3) followed by ``BatchNormalization``.

    Args:
        input_shape: number of input features; the input tensor is built with
            shape ``(input_shape,)``.
        hidden_widths: widths of the hidden Dense layers, in order.
        activation: activation of the hidden Dense layers (``None`` means no
            nonlinearity, which makes the whole network a linear classifier).
        output_name: name of the final softmax layer. It must match the name
            in the saved weight file, because weights are reloaded with
            ``by_name=True``.

    Returns:
        ``(model_input, preds, model)``: the input tensor, the softmax output
        tensor and the (uncompiled) ``Model`` connecting them.
    """
    model_input = Input(shape=(input_shape,), dtype='float32')
    net = model_input
    for index, width in enumerate(hidden_widths, start=1):
        net = Dense(width, activation=activation, name='dense%d' % index,
                    kernel_regularizer=regularizers.l2(1e-3))(net)
        net = BatchNormalization()(net)  # Add batchnorm for stability.
    preds = Dense(2, activation='softmax', name=output_name,
                  kernel_regularizer=regularizers.l2(1e-3))(net)
    return model_input, preds, Model(model_input, preds)


def _fit_or_load_and_predict(model, model_input, preds, filepath,
                             x_train, y_train, x_val, y_val,
                             train=TRAIN, epochs=EPOCHS, batch_size=BATCH_SIZE):
    """Train *model* (checkpointing to *filepath*) or load its weights, then predict on *x_val*.

    Args:
        model: the ``Model`` built from ``model_input`` -> ``preds``.
        model_input, preds: input and output tensors of *model*; used to build
            a separate prediction model that shares the same layers/weights.
        filepath: HDF5 weight file, written when training and read otherwise.
        x_train, y_train, x_val, y_val: training and validation data.
        train: fit the model if true, otherwise load weights from *filepath*.
        epochs, batch_size: training/prediction settings.

    Returns:
        The output of ``predict`` on *x_val*.
    """
    if train:
        adam = optimizers.Adam(lr=1e-3)
        model.compile(loss='categorical_crossentropy',
                      optimizer=adam,
                      metrics=['acc'])
        # Only the epoch with the highest validation accuracy is written to disk.
        checkpoint = ModelCheckpoint(filepath, monitor='val_acc',
                                     verbose=1, save_best_only=True, mode='max')
        model.fit(x_train, y_train, validation_data=(x_val, y_val),
                  callbacks=[checkpoint], epochs=epochs, batch_size=batch_size)
    else:
        model.load_weights(filepath, by_name=True)
    # NOTE: after training, the in-memory weights are those of the final epoch,
    # not necessarily the best checkpoint saved above.
    pred_model = Model(model_input, preds)
    pred_model.compile(loss=None,
                       optimizer='rmsprop',
                       metrics=None)
    return pred_model.predict(x_val, verbose=1, batch_size=batch_size)


for datatype in ['orange_skin', 'XOR', 'nonlinear_additive', 'switch']:
    # Reseed per dataset so the SVM subsample drawn below is reproducible.
    np.random.seed(0)

    # NOTE: pickle can execute arbitrary code; only load trusted local files.
    with open('data/' + datatype + '.pk', 'rb') as data_file:
        data_dict = pickle.load(data_file)
    x_train, y_train = data_dict['x_train'], data_dict['y_train']
    x_val, y_val = data_dict['x_val'], data_dict['y_val']
    input_shape = data_dict['input_shape']

    activation = _hidden_activation(datatype)

    # Black-box model: two hidden layers of 200 units.
    model_input, preds, model = _build_classifier(
        input_shape, [200, 200], activation, 'dense4')
    pred_val_blackbox = _fit_or_load_and_predict(
        model, model_input, preds, 'models/' + datatype + '_blackbox.hdf5',
        x_train, y_train, x_val, y_val)

    print('Training Linear Classifier')
    # One hidden layer without activation (BatchNorm is affine at inference),
    # so the model is linear in the input up to the softmax.
    model_input, preds, model = _build_classifier(
        input_shape, [200], None, 'dense4')
    pred_val_linear = _fit_or_load_and_predict(
        model, model_input, preds, 'models/' + datatype + '_blackbox_linear.hdf5',
        x_train, y_train, x_val, y_val)

    print("Training classifier with extra layer")
    # Deeper, narrower model: four hidden layers of 50 units.
    model_input, preds, model = _build_classifier(
        input_shape, [50, 50, 50, 50], activation, 'dense5')
    pred_val_extra = _fit_or_load_and_predict(
        model, model_input, preds, 'models/' + datatype + '_blackbox_extra.hdf5',
        x_train, y_train, x_val, y_val)

    ### train SVM

    print("train svm")
    svm_path = 'models/' + datatype + '_svm.pk'
    if TRAIN:
        # Fit on a random 0.1% subsample of the training set (without replacement).
        # NOTE(review): int(0.001 * len(x_train)) is 0 for fewer than 1000 rows,
        # which SVC cannot fit.
        training_indices = np.random.choice(len(x_train), int(0.001 * len(x_train)),
                                            replace=False)
        # argmax converts the (presumably one-hot) label rows into class indices.
        svm_classif = SVC(probability=True).fit(
            x_train[training_indices], np.argmax(y_train[training_indices], axis=1))
        with open(svm_path, 'wb') as svm_file:
            pickle.dump(svm_classif, svm_file)
    else:
        with open(svm_path, 'rb') as svm_file:
            svm_classif = pickle.load(svm_file)

    pred_val = svm_classif.predict_proba(x_val)