import argparse
import itertools

import numpy as np
import tensorflow as tf
from scipy import stats
from tensorflow.keras.callbacks import ModelCheckpoint, LearningRateScheduler
from tensorflow.keras.datasets import cifar10, cifar100
from tensorflow.keras.layers import Input, concatenate
from tensorflow.keras.losses import categorical_crossentropy as cce
from tensorflow.keras.models import Model, load_model
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.utils import to_categorical

import utils
from model import resnet_v1
# parl_loss below calls tf.gradients on layer outputs, which is only valid in
# graph mode, so eager execution is disabled before any model is built.
tf.compat.v1.disable_eager_execution()


def _cosine_similarity(a, b):
    """Return the cosine similarity of two 1-D tensors, or 0 if a norm is zero."""
    # A zero norm product means (barring float underflow) one vector is all
    # zeros, so the dot product is 0 too; divide_no_nan yields 0 there
    # instead of NaN. (The previous `tf.equal(...) is True` epsilon guard
    # compared a Tensor to the `True` singleton and so never ran.)
    return tf.math.divide_no_nan(tf.tensordot(a, b, axes=1),
                                 tf.norm(a) * tf.norm(b))


def parl_loss(y_true, y_pred, num_model=utils.num_ensemble):
    """PARL loss for an ensemble whose member outputs are concatenated.

    The loss is the sum of the per-member categorical cross-entropies plus
    0.5 times the summed pairwise cosine similarity between the gradients of
    corresponding member layers w.r.t. the model input, taken over the first
    ``args.num_layers`` entries of ``utils.conv_layers``.

    Args:
        y_true: One-hot targets of all members concatenated along the last
            axis (split into ``num_model`` equal blocks).
        y_pred: Member predictions concatenated the same way.
        num_model: Number of ensemble members.

    Returns:
        Per-sample cross-entropy sum plus the (broadcast) scalar
        gradient-similarity penalty.

    Note:
        Reads the module-level globals ``args`` and ``model``, so it can only
        be called after both are defined.
    """
    y_p = tf.split(y_pred, num_model, axis=-1)
    y_t = tf.split(y_true, num_model, axis=-1)
    # One cross-entropy term per member, summed.
    ce_loss = tf.add_n([cce(target, pred) for target, pred in zip(y_t, y_p)])

    # Output of model.layers[0], i.e. the model input tensor.
    inputs = model.layers[0].output
    layer_sum = 0
    for i in utils.conv_layers[:args.num_layers]:
        # NOTE(review): assumes the num_model consecutive layers ending at
        # index i are the corresponding layers of the individual members
        # (the original hard-coded i - 2, i - 1, i for three members) —
        # confirm against utils.conv_layers.
        # tf.gradients implicitly sums the layer output before
        # differentiating; the result is flattened to 1-D.
        grads = [
            tf.reshape(
                tf.gradients(model.layers[i - num_model + 1 + k].output,
                             inputs)[0],
                [-1])
            for k in range(num_model)
        ]
        # Pairs in order (0, 1), (0, 2), (1, 2), ... as in the original.
        for g_a, g_b in itertools.combinations(grads, 2):
            layer_sum = layer_sum + _cosine_similarity(g_a, g_b)
    return ce_loss + 0.5 * layer_sum


parser = argparse.ArgumentParser()
# Restrict to the supported datasets; previously any other string silently
# trained on CIFAR-100.
parser.add_argument('--dataset', action='store', type=str, required=True,
                    choices=['cifar10', 'cifar100'])
parser.add_argument('--num_layers', action='store', type=int, required=True)
parser.add_argument('--epochs', action='store', type=int, default=100)
parser.add_argument('--batch_size', action='store', type=int, default=64)
args = parser.parse_args()

if args.dataset == "cifar10":
    (x_train, y_train), (x_test, y_test) = cifar10.load_data()
    num_classes = 10
else:
    (x_train, y_train), (x_test, y_test) = cifar100.load_data()
    num_classes = 100

# Scale pixels to [0, 1], then subtract the per-pixel training-set mean from
# both splits.
x_train = x_train.astype('float32') / 255
x_test = x_test.astype('float32') / 255
x_train_mean = np.mean(x_train, axis=0)
x_train -= x_train_mean
x_test -= x_train_mean
y_train = to_categorical(y_train, num_classes)
y_test = to_categorical(y_test, num_classes)

# The model output concatenates every member's prediction, so each target is
# the one-hot label repeated once per member along the last axis.
y_train_ensemble = np.tile(y_train, (1, utils.num_ensemble))
y_test_ensemble = np.tile(y_test, (1, utils.num_ensemble))

# utils.num_ensemble ResNet-20 members sharing one input; outputs concatenated.
model_input = Input(shape=(32, 32, 3))
model_out = [resnet_v1(model_input, depth=20, num_classes=num_classes)
             for _ in range(utils.num_ensemble)]
model_output = concatenate(model_out)
model = Model(inputs=model_input, outputs=model_output)

checkpoint_path = f"{args.dataset}_parl_{args.num_layers}.h5"
checkpoint = ModelCheckpoint(checkpoint_path, verbose=1, save_best_only=True)
lr_scheduler = LearningRateScheduler(utils.lr_schedule)
callbacks = [checkpoint, lr_scheduler]
model.compile(loss=parl_loss,
              optimizer=Adam(learning_rate=utils.lr_schedule(0)),
              metrics=[utils.acc_metric])

# Augmentation: random 10% shifts and horizontal flips; every other option
# is left at its default. No featurewise statistics or ZCA are enabled, so
# datagen.fit() is not needed.
datagen = ImageDataGenerator(width_shift_range=0.1,
                             height_shift_range=0.1,
                             horizontal_flip=True)
with tf.device('/device:GPU:0'):
    model.fit(datagen.flow(x_train, y_train_ensemble, batch_size=args.batch_size),
              validation_data=(x_test, y_test_ensemble),
              epochs=args.epochs, verbose=1, callbacks=callbacks)

# Reload the best checkpoint for evaluation. compile=False: predict() needs no
# loss, and recompiling parl_loss would rebuild its gradient graph against the
# training-time `model` global.
model = load_model(checkpoint_path,
                   custom_objects={'parl_loss': parl_loss,
                                   'acc_metric': utils.acc_metric},
                   compile=False)
model_predict = model.predict(x_test)
y_p = np.split(model_predict, utils.num_ensemble, axis=-1)
predictions = [np.argmax(member_pred, axis=1) for member_pred in y_p]
# Majority vote across members; scipy.stats.mode breaks ties toward the
# smallest class index. ravel() makes the result 1-D regardless of the SciPy
# version's `keepdims` default.
mode_ensemble_predictions = np.ravel(stats.mode(predictions, axis=0)[0])
ensemble_accuracy = np.mean(mode_ensemble_predictions == np.argmax(y_test, axis=1))
print("Ensemble Accuracy: " + str(ensemble_accuracy))
