import os
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
import warnings
warnings.filterwarnings('ignore')

import tensorflow
tensorflow.compat.v1.logging.set_verbosity(tensorflow.compat.v1.logging.ERROR)
from mnist_model_sup import base_model_mnist, ct_model_mnist
from mnist_utils_sup import DataGeneratorCMNIST
from tensorflow.keras.callbacks import ModelCheckpoint
from tensorflow.keras.datasets import mnist
from tensorflow.keras.utils import to_categorical
import numpy as np

######################

# Experiment configuration.
# batch_size: samples per batch handed to every data generator; also used to
#   derive the number of steps per epoch / per evaluation run.
batch_size = 128
# img_dim: side length of the (square) model input. The data-preparation code
#   below hard-codes 28x28 arrays, so this must stay 28 unless that is changed.
img_dim = 28
# num_classes: number of output classes passed to the model (digit 2 vs. digit 3).
num_classes = 2
# num_epochs: number of training epochs.
num_epochs = 20
use_bn = True # set to False to train a model w/o batch normalization

######################

# Load the full MNIST splits, then keep only the digits 2 and 3 and relabel
# them as the binary classes 0 (digit 2) and 1 (digit 3).
(x_train, y_train), (x_test, y_test) = mnist.load_data()

train_keep = (y_train == 2) | (y_train == 3)
indices_train = np.where(train_keep)
x_train_df = x_train[indices_train[0]]
# After filtering only 2s and 3s remain, so "is it a 3?" is exactly the 0/1
# label; casting back preserves the dtype of the source label array.
y_train_df = (y_train[indices_train[0]] == 3).astype(y_train.dtype)

test_keep = (y_test == 2) | (y_test == 3)
indices_test = np.where(test_keep)
x_test_df = x_test[indices_test[0]]
y_test_df = (y_test[indices_test[0]] == 3).astype(y_test.dtype)


# Allocate one zero-filled float RGB buffer (N, 28, 28, 3) per dataset variant:
#   *_rb   -> red and blue markers, *_r -> red only, *_b -> blue only,
#   *_none -> no markers.
# The single-channel (N, 28, 28, 1) reshapes are kept alongside them; in this
# section they are only used to size the buffers.
n_train = len(x_train_df)
n_test = len(x_test_df)

x_train_df_rb = x_train_df.reshape((n_train, 28, 28, 1))
x_train_rgb_rb = np.zeros((x_train_df_rb.shape[0], 28, 28, 3))

x_test_df_rb = x_test_df.reshape((n_test, 28, 28, 1))
x_test_rgb_rb = np.zeros((x_test_df_rb.shape[0], 28, 28, 3))

x_test_df_r = x_test_df.reshape((n_test, 28, 28, 1))
x_test_rgb_r = np.zeros((x_test_df_r.shape[0], 28, 28, 3))

x_test_df_b = x_test_df.reshape((n_test, 28, 28, 1))
x_test_rgb_b = np.zeros((x_test_df_b.shape[0], 28, 28, 3))

x_test_df_none = x_test_df.reshape((n_test, 28, 28, 1))
x_test_rgb_none = np.zeros((x_test_df_none.shape[0], 28, 28, 3))


# 6x6 RGB marker patches (float arrays, channel order R, G, B).
# Assigning a 3-tuple broadcasts it over every pixel of the patch.
red_square = np.zeros((6, 6, 3))
red_square[:] = (255, 0, 0)

blue_square = np.zeros((6, 6, 3))
blue_square[:] = (0, 0, 255)

def _fill_colored(rgb_out, gray_images, labels, red_patch=None, blue_patch=None):
    """Fill ``rgb_out`` in place with RGB copies of ``gray_images`` plus markers.

    Every grayscale image is replicated into all three channels. For samples
    whose label equals 1, ``red_patch`` (if given) overwrites the last 6 rows
    and columns and ``blue_patch`` (if given) overwrites the first 6 rows and
    columns. Samples with any other label are left unmarked.

    Args:
        rgb_out: pre-allocated array of shape (N, H, W, 3), modified in place.
        gray_images: array of shape (N, H, W) holding the source images.
        labels: 1-D array of N integer class labels (not one-hot).
        red_patch: optional (6, 6, 3) patch for the trailing corner.
        blue_patch: optional (6, 6, 3) patch for the leading corner.
    """
    # The added trailing axis broadcasts each image across the 3 channels,
    # which is what np.dstack([img] * 3) produced one sample at a time.
    rgb_out[...] = gray_images[..., np.newaxis]
    is_marked = labels == 1
    if red_patch is not None:
        rgb_out[is_marked, 22:, 22:, :] = red_patch
    if blue_patch is not None:
        rgb_out[is_marked, 0:6, 0:6, :] = blue_patch


# Training data and the "rb" test set carry both markers on class-1 samples;
# the remaining test sets drop one or both markers.
# This must run while the labels are still integer class ids.
_fill_colored(x_train_rgb_rb, x_train_df, y_train_df,
              red_patch=red_square, blue_patch=blue_square)
_fill_colored(x_test_rgb_rb, x_test_df, y_test_df,
              red_patch=red_square, blue_patch=blue_square)
_fill_colored(x_test_rgb_r, x_test_df, y_test_df, red_patch=red_square)
_fill_colored(x_test_rgb_b, x_test_df, y_test_df, blue_patch=blue_square)
_fill_colored(x_test_rgb_none, x_test_df, y_test_df)



# One-hot encode the integer labels. The width is pinned to num_classes (the
# value also passed to the model) instead of being inferred from the largest
# label present, so label and model output shapes cannot drift apart.
y_train_df = to_categorical(y_train_df, num_classes=num_classes)
y_test_df = to_categorical(y_test_df, num_classes=num_classes)


# One generator per dataset variant, all sharing the same batch size.
# NOTE(review): val_generator wraps the same arrays as test_generator_rb, so
# checkpoint selection during training is driven by the "rb" test set --
# confirm this is intended.
train_generator = DataGeneratorCMNIST(x_train_rgb_rb, y_train_df, batch_size=batch_size)
val_generator = DataGeneratorCMNIST(x_test_rgb_rb, y_test_df, batch_size=batch_size)
test_generator_rb = DataGeneratorCMNIST(x_test_rgb_rb, y_test_df, batch_size=batch_size)
test_generator_r = DataGeneratorCMNIST(x_test_rgb_r, y_test_df, batch_size=batch_size)
test_generator_b = DataGeneratorCMNIST(x_test_rgb_b, y_test_df, batch_size=batch_size)
test_generator_none = DataGeneratorCMNIST(x_test_rgb_none, y_test_df, batch_size=batch_size)


# Make sure the checkpoint directory exists. exist_ok avoids the window
# between a separate "is it there?" check and the creation call.
os.makedirs('./weights_ct', exist_ok=True)

################### Training the Model ###################

# Build the classifier for (img_dim, img_dim, 3) inputs and num_classes outputs.
model = base_model_mnist((img_dim, img_dim, 3), num_cls=num_classes, use_bn=use_bn, ct=False)
# model.summary()
# The metric is requested as 'acc' so that the validation value is logged as
# 'val_acc', which is the key the checkpoint below monitors. Requesting
# 'accuracy' makes newer Keras versions log 'val_accuracy' instead, in which
# case save_best_only never finds its monitored value and nothing is saved.
model.compile(optimizer='sgd', loss='categorical_crossentropy', metrics=['acc'])

filepath = './weights_ct/mnist_.h5'

# Keep only the weights of the epoch with the highest validation accuracy.
checkpoint = ModelCheckpoint(filepath, save_weights_only=True, save_best_only=True,
                             monitor='val_acc', mode='max', verbose=0)
_callbacks = [checkpoint]
if use_bn:
    print('\n\n *** Training with Batch Normalization...\n\n')
else:
    print('\n\n *** Training without Batch Normalization...\n\n')

# NOTE(review): fit_generator is deprecated in newer TensorFlow releases in
# favour of model.fit; it is kept here to stay compatible with the TF version
# this script was written for -- confirm before upgrading.
model.fit_generator(generator=train_generator, steps_per_epoch=np.shape(x_train_rgb_rb)[0]//batch_size, verbose=2,
                    epochs=num_epochs, validation_data=val_generator, callbacks=_callbacks)


################### Testing the Model ###################

if use_bn:
    banner = '\n\n *** Testing the model with Batch Normalization on 4 sets...\n\n'
else:
    banner = '\n\n *** Testing the model without Batch Normalization on 4 sets...\n\n'
print(banner)

# Every test variant is built from the same test images, so one step count
# (derived from the "rb" array) is shared by all four evaluations.
# The model is evaluated as it stands after training; no weights are reloaded here.
eval_steps = np.shape(x_test_rgb_rb)[0] // batch_size
test_generators = (test_generator_rb, test_generator_r, test_generator_b, test_generator_none)
scores = [model.evaluate_generator(gen, steps=eval_steps, verbose=0)
          for gen in test_generators]
(loss1, acc1), (loss2, acc2), (loss3, acc3), (loss4, acc4) = scores
print('\nRB: {}, R: {}, B: {}, None: {} \n'.format(acc1, acc2, acc3, acc4))
