from tensorflow.keras.layers import Input, concatenate
from tensorflow.keras.models import Model, load_model
from tensorflow.keras.datasets import cifar10, cifar100
from tensorflow.keras.utils import to_categorical
from tensorflow.keras.callbacks import ModelCheckpoint, LearningRateScheduler
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.preprocessing.image import ImageDataGenerator
import numpy as np
import tensorflow as tf
from model import resnet_v1
import utils
import argparse


# Command-line interface. --dataset selects the CIFAR variant to load and is
# also used as the prefix of the checkpoint file ("<dataset>_baseline.h5").
# It is restricted to the two supported names: otherwise a typo would silently
# fall through to CIFAR-100 and be saved under the mistyped prefix.
parser = argparse.ArgumentParser(description="Train and evaluate a ResNet-20 ensemble baseline.")
parser.add_argument('--dataset', action='store', type=str, required=True,
                    choices=['cifar10', 'cifar100'],
                    help="dataset to train on; also the checkpoint file prefix")
args = parser.parse_args()

# Pick the dataset loader and its class count. Any --dataset value other than
# "cifar10" selects CIFAR-100.
dataset_module, num_classes = (cifar10, 10) if args.dataset == "cifar10" else (cifar100, 100)
(x_train, y_train), (x_test, y_test) = dataset_module.load_data()

# Convert pixels to float32 and divide by 255. Then subtract the mean image
# (the average over training samples, axis 0) from both splits. The test set
# deliberately reuses the training statistics.
x_train = x_train.astype('float32') / 255
x_test = x_test.astype('float32') / 255
x_train_mean = x_train.mean(axis=0)
x_train -= x_train_mean
x_test -= x_train_mean

# One-hot encode the integer class labels.
y_train = to_categorical(y_train, num_classes)
y_test = to_categorical(y_test, num_classes)

# Every ensemble member is trained on the same one-hot targets. The combined
# target is utils.num_ensemble copies of the labels placed side by side along
# the last axis, mirroring how the member outputs are concatenated.
y_train_ensemble = np.concatenate([y_train] * utils.num_ensemble, axis=-1)
y_test_ensemble = np.concatenate([y_test] * utils.num_ensemble, axis=-1)


# One resnet_v1 branch (depth 20) per ensemble member, all fed from the same
# 32x32x3 input. The branch outputs are joined along the last axis into a
# single wide output tensor.
model_input = Input(shape=(32, 32, 3))
model_out = [
    resnet_v1(model_input, depth=20, num_classes=num_classes)
    for _ in range(utils.num_ensemble)
]
model_output = concatenate(model_out)
model = Model(inputs=model_input, outputs=model_output)

# Keep only the best weights, judged by ModelCheckpoint's default monitor, in
# "<dataset>_baseline.h5". The learning rate is driven per epoch by
# utils.lr_schedule, and the optimizer starts at the schedule's epoch-0 rate.
checkpoint = ModelCheckpoint(args.dataset + "_baseline.h5", verbose=1, save_best_only=True)
lr_scheduler = LearningRateScheduler(utils.lr_schedule)
callbacks = [checkpoint, lr_scheduler]
model.compile(
    loss=utils.ens_loss,
    optimizer=Adam(learning_rate=utils.lr_schedule(0)),
    metrics=[utils.acc_metric],
)
# On-the-fly augmentation: random shifts of up to 10% of width/height plus
# random horizontal flips. Every other option is spelled out at its
# ImageDataGenerator default so the whole augmentation config is visible here.
datagen = ImageDataGenerator(
        # set input mean to 0 over the dataset
        featurewise_center=False,
        # set each sample mean to 0
        samplewise_center=False,
        # divide inputs by std of dataset
        featurewise_std_normalization=False,
        # divide each input by its std
        samplewise_std_normalization=False,
        # apply ZCA whitening
        zca_whitening=False,
        # epsilon for ZCA whitening
        zca_epsilon=1e-06,
        # random rotation range in degrees (0 disables rotation)
        rotation_range=0,
        # randomly shift images horizontally (fraction of total width)
        width_shift_range=0.1,
        # randomly shift images vertically (fraction of total height)
        height_shift_range=0.1,
        # set range for random shear
        shear_range=0.,
        # set range for random zoom
        zoom_range=0.,
        # set range for random channel shifts
        channel_shift_range=0.,
        # set mode for filling points outside the input boundaries
        fill_mode='nearest',
        # value used for fill_mode = "constant"
        cval=0.,
        # randomly flip images horizontally
        horizontal_flip=True,
        # randomly flip images vertically (disabled)
        vertical_flip=False,
        # set rescaling factor (applied before any other transformation)
        rescale=None,
        # set function that will be applied on each input
        preprocessing_function=None,
        # image data format, either "channels_first" or "channels_last"
        data_format=None,
        # fraction of images reserved for validation (strictly between 0 and 1)
        validation_split=0.0)
# datagen.fit() is intentionally not called. It only computes the statistics
# used by featurewise_center, featurewise_std_normalization and zca_whitening.
# All three are disabled above, so fitting would just be an extra pass over
# the training set.
with tf.device('/device:GPU:0'):
    model.fit(
        datagen.flow(x_train, y_train_ensemble, batch_size=64),
        validation_data=(x_test, y_test_ensemble),
        epochs=100,
        verbose=1,
        callbacks=callbacks,
    )


# Reload the best checkpoint saved during training. The custom loss and metric
# must be passed so the saved model can be deserialised. Then predict on the
# test set.
model = load_model(args.dataset + "_baseline.h5", custom_objects={'ens_loss': utils.ens_loss, 'acc_metric': utils.acc_metric})
model_predict = model.predict(x_test)

# Ground-truth class indices, computed once and shared by every accuracy below.
y_true = np.argmax(y_test, axis=1)

# The model output is the member outputs concatenated along the last axis.
# Splitting it into utils.num_ensemble equal chunks recovers each member's
# output (presumably per-class probabilities from resnet_v1 -- not visible here).
y_p = np.split(model_predict, utils.num_ensemble, axis=-1)
for i, member_output in enumerate(y_p):
    # Report each member's own accuracy alongside the ensemble's.
    accuracy = np.sum(np.argmax(member_output, axis=1) == y_true) / len(y_test)
    print("Member " + str(i) + " Accuracy: " + str(accuracy))

# Ensemble prediction: average the member outputs element-wise, then take the
# highest-scoring class.
average_ensemble_probability = np.mean(y_p, axis=0)
ensemble_prediction_average = np.argmax(average_ensemble_probability, axis=1)
ensemble_accuracy = np.sum(ensemble_prediction_average == y_true) / len(y_test)
print("Ensemble Accuracy: " + str(ensemble_accuracy))
