import tensorflow as tf
from keras.layers import InputLayer, Dense, Conv2D, Conv2DTranspose,\
    Dropout, BatchNormalization, Reshape, Flatten, MaxPool2D, UpSampling2D
from keras.models import Sequential
from keras.preprocessing.image import ImageDataGenerator
from keras.optimizers import Adam
from keras.losses import BinaryCrossentropy
import matplotlib.pyplot as plt
from os import makedirs, getcwd
import sys
from math import ceil
import time
from tqdm import tqdm

from IPython import display


# Make the current working directory importable so the project-local
# `utils` and `src` packages imported below resolve.
# NOTE(review): assumes the script is launched from the repository root.
conf_path = getcwd()
sys.path.append(conf_path)

from utils.celeba_hq import create_celeba_hq
from src.geo_conv import GeoConv2D, GeoConv2DTranspose
from src.coord_conv import CoordConv, CoordConvTranspose

# Global variables
INPUT_SHAPE = (256, 256, 3)  # (height, width, channels) of real and generated images
BATCH_SIZE = 32
LATENT_DIMS = [128, 256]  # candidate latent sizes; only LATENT_DIMS[1] is used (noise_dim below)
ACTIVATION = tf.nn.leaky_relu  # activation of the hidden layers in both networks
CROSS_ENTROPY = BinaryCrossentropy(from_logits=False)  # both networks end in a sigmoid, so inputs are probabilities
EPOCHS = 500
NUM_EXAMPLES = 16  # number of fixed latent vectors rendered in the per-epoch sample grid
STD = 1  # standard deviation of the Gaussian latent noise
SEED = 0
CONV_ARC = GeoConv2D  # conv layer class for both networks; the generator accepts Conv2D, GeoConv2D or CoordConv
ARC_FOLDER = CONV_ARC.__name__ + "/"  # per-architecture results subfolder

# Global TF seed; set before any model or random tensor is created.
tf.random.set_seed(SEED)

# Length of the generator's input noise vector.
noise_dim = LATENT_DIMS[1]


"""Prepare the Dataset"""
CELEBA_HQ_DIRS = create_celeba_hq()
DATASET_DIR = "resources/celeba_hq/"
TRAIN_DIR = CELEBA_HQ_DIRS["train"]
TEST_DIR = CELEBA_HQ_DIRS["val"]

# Set up the generator
datagen = ImageDataGenerator(
    rescale=1./255.,
    horizontal_flip=True,
)

# Read only the training split. DATASET_DIR is hard-coded and need not match
# what create_celeba_hq() returned. Also, flow_from_directory collects images
# from *every* sub-directory of `directory`, so pointing it at the dataset
# root would mix in any other split stored there (presumably "val").
# Instead, point it at the parent of TRAIN_DIR and whitelist the train folder
# as the single "class". Images inside that folder are gathered recursively,
# so this works whether TRAIN_DIR holds images directly or in sub-folders.
from pathlib import PurePath

_train_path = PurePath(TRAIN_DIR)

train_datagen = datagen.flow_from_directory(
    str(_train_path.parent),
    classes=[_train_path.name],
    target_size=INPUT_SHAPE[:2],
    batch_size=BATCH_SIZE,
    shuffle=True,
    class_mode=None
)


"""Define the Generator"""
def make_generator_model(latent_dim, conv_arc=Conv2D):
    """Build the image generator as a Keras ``Sequential`` model.

    A Dense layer projects the latent vector to a 16 x 16 x latent_dim map.
    Four stride-2 transposed convolutions then upsample it to 256 x 256, each
    followed by a stride-1 convolution. A final convolution maps to 3 channels
    with a sigmoid, so outputs lie in [0, 1].

    Args:
        latent_dim: length of the input noise vector. Also the channel count
            of the 16 x 16 map produced by the Dense projection.
        conv_arc: convolution layer class. Must be Conv2D, GeoConv2D or
            CoordConv; it selects the matching transposed-convolution class.

    Returns:
        An uncompiled ``Sequential`` model.

    Raises:
        ValueError: if ``conv_arc`` is not one of the supported classes.
    """
    # Each supported convolution class, paired with its transposed counterpart.
    transpose_for = (
        (Conv2D, Conv2DTranspose),
        (GeoConv2D, GeoConv2DTranspose),
        (CoordConv, CoordConvTranspose),
    )
    conv_tr_arc = next((tr for arc, tr in transpose_for if arc is conv_arc), None)
    if conv_tr_arc is None:
        raise ValueError("Invalid Convolutional Layer Architecture")

    # Shape comments below assume 'same' padding behaves as in Keras Conv2D
    # for the custom layer classes as well.
    model = Sequential([
        InputLayer(input_shape=(latent_dim,)),
        Dense(16*16*latent_dim, use_bias=False),
        BatchNormalization(),

        # Reshape from 1D to 3D
        Reshape((16, 16, latent_dim)),

        # 16 x 16 x latent_dim  --->  32 x 32 x 512
        conv_tr_arc(filters=512, kernel_size=3, strides=2, padding='same', activation=ACTIVATION),
        conv_arc(filters=512, kernel_size=3, strides=1, padding='same', activation=ACTIVATION),
        BatchNormalization(),

        # 32 x 32 x 512  --->  64 x 64 x 256
        conv_tr_arc(filters=256, kernel_size=5, strides=2, padding='same', activation=ACTIVATION),
        conv_arc(filters=256, kernel_size=3, strides=1, padding='same', activation=ACTIVATION),
        BatchNormalization(),

        # 64 x 64 x 256  --->  128 x 128 x 128
        conv_tr_arc(filters=128, kernel_size=3, strides=2, padding='same', activation=ACTIVATION),
        conv_arc(filters=128, kernel_size=3, strides=1, padding='same', activation=ACTIVATION),
        Dropout(0.1),

        # 128 x 128 x 128  --->  256 x 256 x 64
        conv_tr_arc(filters=64, kernel_size=3, strides=2, padding='same', activation=ACTIVATION),
        conv_arc(filters=64, kernel_size=3, strides=1, padding='same', activation=ACTIVATION),
        BatchNormalization(),

        # 256 x 256 x 64  --->  512 x 512 x 64 (UpSampling2D)
        #                 --->  256 x 256 x 32 (stride-2 conv)  --->  256 x 256 x 16
        # NOTE: the upsample/stride-2 pair keeps the resolution at 256 but
        # materialises a 512 x 512 x 64 activation per sample.
        UpSampling2D(),
        conv_arc(filters=32, kernel_size=3, strides=2, padding='same', activation=ACTIVATION),
        conv_arc(filters=16, kernel_size=3, strides=1, padding='same', activation=ACTIVATION),
        BatchNormalization(),

        # 256 x 256 x 16  --->  256 x 256 x 3, sigmoid output in [0, 1]
        conv_arc(filters=3, kernel_size=3, strides=1, padding='same', activation='sigmoid')
    ])

    return model


"""Define the Discriminator"""
def make_discriminator_model(conv_arc=Conv2D):
    """Build the image discriminator as a Keras ``Sequential`` model.

    The body has five identical stages, each made of:
    conv -> Dropout(0.2) -> conv -> MaxPool2D.
    Every conv is 3x3, stride 1, default padding, with ``ACTIVATION``; the
    filter counts per stage are listed in ``stage_filters``. The head is
    Flatten -> Dense(32) -> Dense(1, sigmoid).

    Args:
        conv_arc: convolution layer class used for every conv layer.

    Returns:
        An uncompiled ``Sequential`` model that takes ``INPUT_SHAPE`` images.
    """
    # (first conv filters, second conv filters) for each stage, in order.
    stage_filters = ((16, 32), (64, 128), (192, 256), (320, 384), (448, 512))

    stack = [InputLayer(input_shape=INPUT_SHAPE)]
    for lead, trail in stage_filters:
        stack += [
            conv_arc(filters=lead, kernel_size=3, strides=1, activation=ACTIVATION),
            Dropout(0.2),
            conv_arc(filters=trail, kernel_size=3, strides=1, activation=ACTIVATION),
            MaxPool2D(),
        ]

    stack += [
        Flatten(),
        Dense(units=32, activation=ACTIVATION),
        Dense(units=1, activation='sigmoid'),
    ]
    return Sequential(stack)


"""Loss functions"""
# The Discriminator's Loss function
def dis_loss(real_output, fake_output):
    """Discriminator loss: binary cross-entropy on real and fake predictions.

    Real predictions are scored against a target of 1 and fake predictions
    against a target of 0. The two cross-entropy terms are summed.
    """
    real_targets = tf.ones_like(real_output)
    fake_targets = tf.zeros_like(fake_output)
    return (CROSS_ENTROPY(real_targets, real_output)
            + CROSS_ENTROPY(fake_targets, fake_output))


# The Generator's Loss function
def gen_loss(fake_output):
    """Generator loss: cross-entropy of the discriminator's fake predictions against a target of 1."""
    want_real = tf.ones_like(fake_output)
    return CROSS_ENTROPY(want_real, fake_output)


"""Optimisers"""
dis_optimiser = Adam(1e-4)  # Discriminator's Optimiser
gen_optimiser = Adam(1e-4)  # Generator's Optimiser


# Build both networks with the same convolution architecture (CONV_ARC).
# The generator is built first, then the discriminator, so their weight
# initialisation draws from the seeded RNG in that order.
gen_model = make_generator_model(noise_dim, CONV_ARC)
dis_model = make_discriminator_model(CONV_ARC)


# The training steps
@tf.function
def train_step(imgs):
    """Run one joint GAN update: one generator step and one discriminator step.

    Args:
        imgs: a batch of real images, fed directly to ``dis_model``.

    Returns:
        ``(gen_loss, dis_loss)`` for this batch, each reduced to a scalar.
    """
    # Size the fake batch from the real one, so a short final batch from the
    # directory iterator still gets the same number of real and fake samples.
    batch_size = tf.shape(imgs)[0]
    noise = tf.random.normal([batch_size, noise_dim], stddev=STD)

    with tf.GradientTape() as gen_tape, tf.GradientTape() as dis_tape:
        fake_imgs = gen_model(noise, training=True)

        real_output = dis_model(imgs, training=True)
        fake_output = dis_model(fake_imgs, training=True)

        batch_gen_loss = gen_loss(fake_output)
        batch_dis_loss = dis_loss(real_output, fake_output)

    # Differentiate outside the tape contexts. Otherwise dis_tape, which is
    # still recording, also records the generator's backward pass and keeps
    # those intermediates alive.
    gen_gradients = gen_tape.gradient(batch_gen_loss, gen_model.trainable_variables)
    dis_gradients = dis_tape.gradient(batch_dis_loss, dis_model.trainable_variables)

    gen_optimiser.apply_gradients(zip(gen_gradients, gen_model.trainable_variables))
    dis_optimiser.apply_gradients(zip(dis_gradients, dis_model.trainable_variables))

    return tf.reduce_mean(batch_gen_loss), tf.reduce_sum(batch_dis_loss)


@tf.function
def train_gen_only():
    """Update only the generator, using a freshly sampled noise batch.

    No real images are read. The discriminator is run in training mode but
    its weights are not updated.

    Returns:
        The generator loss for this step, as a scalar.
    """
    noise = tf.random.normal([BATCH_SIZE, noise_dim], stddev=STD)

    with tf.GradientTape() as gen_tape:
        fake_imgs = gen_model(noise, training=True)
        fake_output = dis_model(fake_imgs, training=True)
        batch_gen_loss = gen_loss(fake_output)

    # Gradient and update happen outside the recording context (same pattern as train_step).
    gen_gradients = gen_tape.gradient(batch_gen_loss, gen_model.trainable_variables)
    gen_optimiser.apply_gradients(zip(gen_gradients, gen_model.trainable_variables))

    return tf.reduce_mean(batch_gen_loss)


@tf.function
def train_dis_only(imgs):
    """Update only the discriminator on a real batch and a matching fake batch.

    Args:
        imgs: a batch of real images, fed directly to ``dis_model``.

    Returns:
        ``(gen_loss, dis_loss)`` as scalars. The generator loss is reported
        for monitoring only; the generator's weights are not updated.
    """
    noise = tf.random.normal([tf.shape(imgs)[0], noise_dim], stddev=STD)
    # The generator is not trained here, so sample outside the tape; dis_tape
    # then does not record (and hold on to) the generator's activations.
    fake_imgs = gen_model(noise, training=True)

    with tf.GradientTape() as dis_tape:
        real_output = dis_model(imgs, training=True)
        fake_output = dis_model(fake_imgs, training=True)
        batch_dis_loss = dis_loss(real_output, fake_output)

    # Monitoring value only; needs no gradient tracking.
    batch_gen_loss = gen_loss(fake_output)

    dis_gradients = dis_tape.gradient(batch_dis_loss, dis_model.trainable_variables)
    dis_optimiser.apply_gradients(zip(dis_gradients, dis_model.trainable_variables))

    return tf.reduce_mean(batch_gen_loss), tf.reduce_sum(batch_dis_loss)


NUM_TRAIN_IMGS = train_datagen.n
# Round up so a final partial batch still counts as a step.
STEPS_PER_EPOCH = ceil(NUM_TRAIN_IMGS/BATCH_SIZE)
# Fixed latent vectors, reused every epoch so the saved sample grids are comparable.
imgs_seed = tf.random.normal([NUM_EXAMPLES, noise_dim], stddev=STD)

# Output layout: results/gan/{models,histories,images}/<ARC_FOLDER>
RESULTS_DIR = "results/gan/"
MODELS_DIR, HISTORIES_DIR, GENERATED_IMGS_DIR = (
    RESULTS_DIR + _sub + ARC_FOLDER
    for _sub in ("models/", "histories/", "images/")
)
for _out_dir in (MODELS_DIR, HISTORIES_DIR, GENERATED_IMGS_DIR):
    makedirs(_out_dir, exist_ok=True)


def generate_and_save_images(model, epoch, seed):
    """Render a grid of samples from ``model`` and save it as a PNG.

    Args:
        model: the generator; called as ``model(seed, training=False)``.
        epoch: epoch number embedded in the output file name.
        seed: batch of latent vectors; one subplot is drawn per sample.

    The grid is written to ``<GENERATED_IMGS_DIR>image_at_epoch_<epoch>.png``.
    """
    # Notice `training` is set to False.
    # This is so all layers run in inference mode (batchnorm).
    predictions = model(seed, training=False).numpy()

    num_imgs = predictions.shape[0]
    # Near-square grid sized from the batch (16 samples -> 4 x 4, as before).
    # A hard-coded 4 x 4 grid would fail for more than 16 samples.
    n_cols = max(1, ceil(num_imgs ** 0.5))
    n_rows = max(1, ceil(num_imgs / n_cols))

    fig = plt.figure(figsize=(4, 4))

    for i in range(num_imgs):
        plt.subplot(n_rows, n_cols, i + 1)
        plt.imshow(predictions[i, :, :, :])
        plt.axis('off')

    plt.savefig('%simage_at_epoch_%d.png' % (GENERATED_IMGS_DIR, epoch))
    plt.show()
    # Release the figure. Backends that don't close figures on show() (e.g.
    # non-interactive runs) would otherwise accumulate one per epoch.
    plt.close(fig)


def no_dis_training(gen_avg_loss, dis_avg_loss, gen_dis_ratio=4.):
    """Decide whether to skip the discriminator update this step.

    Returns True (a Python bool) when ``gen_avg_loss / dis_avg_loss`` exceeds
    ``gen_dis_ratio``, i.e. when the discriminator is winning by a wide margin.
    """
    loss_ratio = gen_avg_loss / dis_avg_loss
    return bool(loss_ratio > gen_dis_ratio)


def no_gen_training(gen_avg_loss, dis_avg_loss, gen_dis_ratio=2.):
    """Decide whether to skip the generator update this step.

    Returns True (a Python bool) when ``dis_avg_loss / gen_avg_loss`` exceeds
    ``gen_dis_ratio``, i.e. when the generator is winning by a wide margin.
    Despite its name, ``gen_dis_ratio`` bounds the discriminator-to-generator
    loss ratio here.
    """
    loss_ratio = dis_avg_loss / gen_avg_loss
    return bool(loss_ratio > gen_dis_ratio)


def train(dataset, epochs):
    """Train the GAN for ``epochs`` epochs of ``STEPS_PER_EPOCH`` steps each.

    Each step picks one of three updates from the most recent losses:
      * generator only (``train_gen_only``) when ``no_dis_training`` fires
        (discriminator too strong);
      * discriminator only (``train_dis_only``) when ``no_gen_training`` fires
        (generator too strong);
      * a joint ``train_step`` otherwise.
    After every epoch, a sample grid generated from the fixed ``imgs_seed`` is
    saved.

    Args:
        dataset: iterator of real image batches, advanced with ``next()``. It
            is not advanced on generator-only steps.
        epochs: number of epochs to run.
    """
    start_time = time.time()
    # Equal starting losses give a ratio of 1, so the first step is a joint update.
    gen_avg_loss, dis_avg_loss = 1, 1
    for epoch in range(epochs):
        epoch_start_time = time.time()

        with tqdm(total=STEPS_PER_EPOCH) as pbar:
            for step in range(STEPS_PER_EPOCH):
                if no_dis_training(gen_avg_loss, dis_avg_loss):    # If Discriminator is too good
                    # Only the generator loss is refreshed on this path; the
                    # next check reuses the last measured discriminator loss.
                    gen_avg_loss = train_gen_only()

                elif no_gen_training(gen_avg_loss, dis_avg_loss):  # If Generator is too good
                    gen_avg_loss, dis_avg_loss = train_dis_only(next(dataset))

                else:                                              # Otherwise
                    gen_avg_loss, dis_avg_loss = train_step(next(dataset))

                pbar.update(1)
                # `step` is 0-based; display the number of completed steps.
                pbar.set_description("%d/%d, gen %.3f, disc %.3f"
                                     % (step + 1, STEPS_PER_EPOCH, gen_avg_loss, dis_avg_loss))

        # Produce images for the GIF as you go
        display.clear_output(wait=True)
        generate_and_save_images(gen_model,
                                 epoch + 1,
                                 imgs_seed)

        # TODO: periodic checkpointing (e.g. every 15 epochs) is not implemented yet.

        print("Time for epoch %d is %.3f seconds." % (epoch + 1, time.time() - epoch_start_time))

    # Report the epochs actually run (the argument), not the global default.
    print("Training for %d epochs completed in %.3f seconds" % (epochs, time.time() - start_time))


# Run training, then persist both networks' weights under MODELS_DIR.
train(train_datagen, EPOCHS)

for _net, _weights_file in ((gen_model, "generator.h5"),
                            (dis_model, "discriminator.h5")):
    _net.save_weights(MODELS_DIR + _weights_file)
