from tensorflow.keras.layers import Input, concatenate
from tensorflow.keras.models import Model
from tensorflow.keras.datasets import cifar10
from tensorflow.keras.utils import to_categorical
from tensorflow.keras.callbacks import ModelCheckpoint, LearningRateScheduler
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.losses import categorical_crossentropy as cce
from tensorflow.keras.preprocessing.image import ImageDataGenerator
import numpy as np
import tensorflow as tf
from model import resnet_v1
import utils
tf.compat.v1.disable_eager_execution()


def gal_loss(y_true, y_pred, num_model=utils.num_ensemble):
    """Ensemble cross-entropy plus a gradient-alignment penalty.

    ``y_true`` and ``y_pred`` are split along their last axis into
    ``num_model`` equal parts, one per ensemble member.  Each member
    contributes its categorical cross-entropy.  In addition, the gradient
    of every member's loss with respect to ``model.layers[0].output`` is
    flattened, and the cosine similarity of each pair of those gradients
    is penalised through ``0.5 * log(sum(exp(cos_ij)))``.

    The function reads the module-level ``model`` and uses
    ``tf.gradients``, so it only works in graph mode (eager execution is
    disabled at the top of this file) and only after ``model`` exists.

    Args:
        y_true: Labels, ``num_model`` blocks concatenated on the last axis.
        y_pred: Predictions laid out the same way as ``y_true``.
        num_model: Number of ensemble members packed into the last axis.

    Returns:
        The summed per-member cross-entropy plus the alignment penalty.
        With fewer than two members there are no pairs, so only the
        cross-entropy is returned.
    """
    member_preds = tf.split(y_pred, num_model, axis=-1)
    member_labels = tf.split(y_true, num_model, axis=-1)
    losses = [cce(t, p) for t, p in zip(member_labels, member_preds)]

    # Flattened gradient of each member's loss w.r.t. the first layer output.
    wrt = model.layers[0].output
    flat_grads = [tf.reshape(tf.gradients(loss, wrt), [-1]) for loss in losses]
    grad_norms = [tf.norm(g) for g in flat_grads]

    eps = 1e-9
    exp_similarities = []
    for i in range(num_model):
        for j in range(i + 1, num_model):
            dot_product = tf.tensordot(flat_grads[i], flat_grads[j], axes=1)
            norm_product = tf.multiply(grad_norms[i], grad_norms[j])
            # The zero check has to be a graph op: a Python `if` on a
            # symbolic tensor cannot see its runtime value, so the guard
            # would never fire and 0/0 would yield NaN.
            norm_product = tf.where(
                tf.equal(norm_product, 0.0), norm_product + eps, norm_product
            )
            cosine = tf.divide(dot_product, norm_product)
            exp_similarities.append(tf.math.exp(cosine))

    total_loss = tf.add_n(losses)
    if not exp_similarities:
        return total_loss

    alignment_penalty = tf.math.log(tf.add_n(exp_similarities))
    return total_loss + 0.5 * alignment_penalty


# Load CIFAR-10, convert the images to float32 and divide by 255.
num_classes = 10
(x_train, y_train), (x_test, y_test) = cifar10.load_data()
x_train = x_train.astype('float32') / 255
x_test = x_test.astype('float32') / 255

# Subtract the training-set mean image from both splits, so the test data
# is shifted by the same statistics as the training data.
x_train_mean = np.mean(x_train, axis=0)
x_train -= x_train_mean
x_test -= x_train_mean

# One-hot encode the class labels.
y_train = to_categorical(y_train, num_classes)
y_test = to_categorical(y_test, num_classes)

# The model emits the ensemble members' outputs concatenated on the last
# axis, so repeat the one-hot labels once per member to match that layout.
y_train_ensemble = np.concatenate([y_train] * utils.num_ensemble, axis=-1)
y_test_ensemble = np.concatenate([y_test] * utils.num_ensemble, axis=-1)


# A single input tensor feeds one resnet_v1(depth=20) call per ensemble
# member; the members' outputs are concatenated on the last axis so the
# whole ensemble trains as one Keras model.
model_input = Input(shape=(32, 32, 3))
model_out = [
    resnet_v1(model_input, depth=20, num_classes=num_classes)
    for _ in range(utils.num_ensemble)
]
model_output = concatenate(model_out)
model = Model(inputs=model_input, outputs=model_output)

# Save only the best checkpoint and drive the learning rate from
# utils.lr_schedule.
checkpoint = ModelCheckpoint("cifar10_gal.h5", verbose=1, save_best_only=True)
lr_scheduler = LearningRateScheduler(utils.lr_schedule)
callbacks = [checkpoint, lr_scheduler]

# Start Adam at the schedule's epoch-0 learning rate.
initial_lr = utils.lr_schedule(0)
model.compile(
    optimizer=Adam(learning_rate=initial_lr),
    loss=gal_loss,
    metrics=[utils.acc_metric],
)
# Augmentation settings for the training stream.  Only small horizontal and
# vertical shifts and horizontal flips are enabled; every other option is
# written out with the value it should have so the configuration is explicit.
augmentation_options = {
    # -- normalisation (all disabled) --
    # set input mean to 0 over the dataset
    'featurewise_center': False,
    # set each sample mean to 0
    'samplewise_center': False,
    # divide inputs by std of dataset
    'featurewise_std_normalization': False,
    # divide each input by its std
    'samplewise_std_normalization': False,
    # apply ZCA whitening
    'zca_whitening': False,
    # epsilon for ZCA whitening
    'zca_epsilon': 1e-06,
    # -- geometric / colour transforms --
    # range in degrees for random rotations (0 disables rotation)
    'rotation_range': 0,
    # randomly shift images horizontally
    'width_shift_range': 0.1,
    # randomly shift images vertically
    'height_shift_range': 0.1,
    # set range for random shear
    'shear_range': 0.,
    # set range for random zoom
    'zoom_range': 0.,
    # set range for random channel shifts
    'channel_shift_range': 0.,
    # set mode for filling points outside the input boundaries
    'fill_mode': 'nearest',
    # value used for fill_mode = "constant"
    'cval': 0.,
    # randomly flip images horizontally
    'horizontal_flip': True,
    # randomly flip images vertically
    'vertical_flip': False,
    # -- miscellaneous --
    # set rescaling factor (applied before any other transformation)
    'rescale': None,
    # set function that will be applied on each input
    'preprocessing_function': None,
    # image data format, either "channels_first" or "channels_last"
    'data_format': None,
    # fraction of images reserved for validation (strictly between 0 and 1)
    'validation_split': 0.0,
}
datagen = ImageDataGenerator(**augmentation_options)
# fit() gathers the dataset statistics used by the featurewise / ZCA
# options; those are all switched off above.
datagen.fit(x_train)
# Train for 100 epochs on augmented batches of 64, validating on the
# (un-augmented) test split after every epoch.
with tf.device('/device:GPU:0'):
    train_batches = datagen.flow(x_train, y_train_ensemble, batch_size=64)
    model.fit(
        train_batches,
        validation_data=(x_test, y_test_ensemble),
        epochs=100,
        verbose=1,
        callbacks=callbacks,
    )

# Reload the best checkpoint written by ModelCheckpoint during training.
# `load_model` is reached through the already-imported `tf` namespace (the
# bare name was never imported).  compile=False skips rebuilding the loss:
# only predict() is needed here, and gal_loss differentiates with respect to
# the global `model`, which still refers to the training model while the
# checkpoint is being loaded.
model = tf.keras.models.load_model("cifar10_gal.h5", compile=False)
model_predict = model.predict(x_test)

# The output holds every member's class probabilities side by side on the
# last axis; split it back into one array per member.
y_p = np.split(model_predict, utils.num_ensemble, axis=-1)
true_classes = np.argmax(y_test, axis=1)
probabilities = []
for i in range(utils.num_ensemble):
    probabilities.append(y_p[i])
    accuracy = np.sum(np.argmax(y_p[i], axis=1) == true_classes) / len(y_test)
    # Report each member's accuracy instead of silently discarding it.
    print("Model " + str(i) + " Accuracy: " + str(accuracy))

# Ensemble prediction = argmax of the members' averaged probabilities.
average_ensemble_probability = np.mean(probabilities, axis=0)
ensemble_prediction_average = np.argmax(average_ensemble_probability, axis=1)
ensemble_accuracy = np.sum(ensemble_prediction_average == true_classes) / len(y_test)
print("Ensemble Accuracy: " + str(ensemble_accuracy))
