from os import makedirs, getcwd
import sys
from math import floor
import numpy as np


import tensorflow as tf
from keras.preprocessing.image import ImageDataGenerator

# from keras import mixed_precision
# policy = mixed_precision.Policy('mixed_float16')
# mixed_precision.set_global_policy(policy)

import typing
import imageio
import cv2
from keras.callbacks import TensorBoard

from keras.layers import Dense, Conv2D, Conv2DTranspose, Dropout, BatchNormalization,\
    Reshape, Flatten, Input, UpSampling2D, MaxPool2D
from keras.optimizers import Optimizer

from src.geo_conv import GeoConv2D, GeoConv2DTranspose
from src.coord_conv import CoordConv, CoordConvTranspose
from utils.celeba_hq import create_celeba_hq
import matplotlib.pyplot as plt


# Put the working directory on the import path. NOTE: this runs after the
# `src.*` / `utils.*` imports at the top of the file, so it does not affect them.
conf_path = getcwd()
sys.path.append(conf_path)


# Dataset specification: name, folder, image shape and training schedule.
CELEBA_HQ = dict(
    name='celeba_hq',
    path='celeba_hq',
    img_shape=(256, 256, 3),
    batch_size=24,
    num_epochs=100,
)

# Dataset selected for this run.
DATASET = CELEBA_HQ


img_shape, batch_size, num_epochs = (
    DATASET[key] for key in ('img_shape', 'batch_size', 'num_epochs')
)

# TODO: bump ATTEMPT (and/or CONV_ARC) for every new experiment so its results
# land in a fresh "<conv class name>/<attempt>/" sub-folder.
ATTEMPT = '0'
CONV_ARC = GeoConv2D
ARC_FOLDER = f"{CONV_ARC.__name__}/{ATTEMPT}/"

latent_dim = 128
ACTIVATION = tf.nn.leaky_relu

# Output folders: model checkpoints, TensorBoard histories and sample images.
RESULTS_DIR = f"results/wgan/{DATASET['name']}/"
MODELS_DIR = f"{RESULTS_DIR}models/{ARC_FOLDER}"
HISTORIES_DIR = f"{RESULTS_DIR}histories/{ARC_FOLDER}"
GENERATED_IMGS_DIR = f"{RESULTS_DIR}images/{ARC_FOLDER}"
makedirs(MODELS_DIR, exist_ok=True)
makedirs(HISTORIES_DIR, exist_ok=True)
makedirs(GENERATED_IMGS_DIR, exist_ok=True)


# -------------------------------
# Load the dataset
# -------------------------------
CELEBA_HQ_DIRS = create_celeba_hq()
DATASET_DIR = "resources/celeba_hq/"
TRAIN_DIR, TEST_DIR = CELEBA_HQ_DIRS["train"], CELEBA_HQ_DIRS["val"]
# NOTE(review): TRAIN_DIR / TEST_DIR are not used below; the iterator reads
# every image found under DATASET_DIR. Confirm whether TRAIN_DIR was intended.

# Pixel values are mapped linearly so that 0 -> -1 and 255 -> 1 (the range of the
# generator's default tanh output); images are also randomly mirrored left/right.
datagen = ImageDataGenerator(
    horizontal_flip=True,
    preprocessing_function=lambda pixels: (pixels / 127.5) - 1.0,
)

# class_mode=None: batches contain images only, no labels.
train_datagen = datagen.flow_from_directory(
    DATASET_DIR,
    class_mode=None,
    shuffle=True,
    batch_size=batch_size,
    target_size=img_shape[:2],
)
# -------------------------------


# Define the generator model
def build_generator(
        noise_dim,
        output_channels=3,
        activation="tanh",
        conv_arc=Conv2D,
        geo_shift=True,
        geo_tr_shift=False
):
    """Build the generator: latent vector -> 256 x 256 x output_channels image.

    Args:
        noise_dim: Length of the latent input vector (an int, or a 1-element
            shape tuple).
        output_channels: Number of channels of the generated image.
        activation: Activation of the final convolution ("tanh" keeps the
            output in [-1, 1]).
        conv_arc: Convolution family to use: Conv2D, GeoConv2D or CoordConv.
            The matching transposed-convolution layer is chosen automatically.
        geo_shift: ``shift`` passed to GeoConv2D layers (GeoConv2D only).
        geo_tr_shift: ``shift`` passed to GeoConv2DTranspose layers
            (GeoConv2D only).

    Returns:
        An uncompiled ``tf.keras.Model``.

    Raises:
        ValueError: If ``conv_arc`` is not one of the supported layer classes.
    """
    if conv_arc == Conv2D:
        conv_tr_arc = Conv2DTranspose
    elif conv_arc == GeoConv2D:
        def conv_arc(**kwargs):
            shift = kwargs.pop("shift", geo_shift)
            return GeoConv2D(shift=shift, **kwargs)

        def conv_tr_arc(**kwargs):
            shift = kwargs.pop("shift", geo_tr_shift)
            return GeoConv2DTranspose(shift=shift, **kwargs)

    elif conv_arc == CoordConv:
        conv_tr_arc = CoordConvTranspose
    else:
        # Any other class would leave conv_tr_arc undefined further down.
        raise ValueError(f"Unsupported conv_arc: {conv_arc!r}")

    # Keras expects a shape tuple; accept a bare integer length as well.
    input_shape = (int(noise_dim),) if np.isscalar(noise_dim) else tuple(noise_dim)
    latent_inputs = Input(shape=input_shape, name="input")

    # NOTE: the depth of the 16 x 16 seed feature map comes from the
    # module-level `latent_dim`, not from `noise_dim`.
    x = Dense(16*16*latent_dim, use_bias=False, trainable=True)(latent_inputs)

    # Reshape from 1D to 3D
    x = Reshape((16, 16, latent_dim))(x)

    # 16 x 16 x latent_dim  --->  32 x 32 x 512
    x = conv_tr_arc(filters=512, kernel_size=5, strides=2, padding='same', activation=ACTIVATION)(x)
    x = conv_arc(filters=512, kernel_size=3, strides=1, padding='same', activation=ACTIVATION)(x)
    x = BatchNormalization()(x)

    # 32 x 32 x 512  --->  64 x 64 x 256
    x = conv_tr_arc(filters=256, kernel_size=5, strides=2, padding='same', activation=ACTIVATION)(x)
    x = conv_arc(filters=256, kernel_size=3, strides=1, padding='same', activation=ACTIVATION)(x)
    x = BatchNormalization()(x)

    # 64 x 64 x 256  --->  128 x 128 x 128
    x = conv_tr_arc(filters=128, kernel_size=5, strides=2, padding='same', activation=ACTIVATION)(x)
    x = conv_arc(filters=128, kernel_size=3, strides=1, padding='same', activation=ACTIVATION)(x)
    x = Dropout(0.1)(x)

    # 128 x 128 x 128  --->  256 x 256 x 64
    x = conv_tr_arc(filters=64, kernel_size=5, strides=2, padding='same', activation=ACTIVATION)(x)
    x = conv_arc(filters=64, kernel_size=3, strides=1, padding='same', activation=ACTIVATION)(x)
    x = BatchNormalization()(x)

    # 256 x 256 x 64  --->  512 x 512 x 64 (upsample)  --->  256 x 256 x 32 (stride 2)
    #                 --->  256 x 256 x 16  --->  256 x 256 x 8
    x = UpSampling2D()(x)
    x = conv_arc(filters=32, kernel_size=3, strides=2, padding='same', activation=ACTIVATION)(x)
    x = conv_arc(filters=16, kernel_size=3, strides=1, padding='same', activation=ACTIVATION)(x)
    x = conv_arc(filters=8, kernel_size=5, strides=1, padding='same', activation=ACTIVATION)(x)
    x = Dropout(0.25)(x)
    x = BatchNormalization()(x)

    # 256 x 256 x 8  --->  256 x 256 x output_channels
    x = conv_arc(filters=output_channels, kernel_size=3, strides=1, padding='same', activation=activation)(x)
    assert x.shape == (None, 256, 256, output_channels)

    model = tf.keras.Model(inputs=latent_inputs, outputs=x)

    return model


# Define the discriminator model
def build_discriminator(img_shape, activation='linear', conv_arc=Conv2D, geo_shift=True):
    """Build the critic: image -> a single score per sample.

    The body is five stages of
    ``conv 5x5 -> Dropout(0.2) -> conv 3x3 -> MaxPool2D`` ('valid' padding,
    leaky-ReLU activations) with growing filter counts, followed by
    ``Flatten -> Dropout(0.25) -> Dense(1)``.

    Args:
        img_shape: Shape of one input image, e.g. ``(256, 256, 3)``.
        activation: Activation of the single output unit.
        conv_arc: Convolution layer class; GeoConv2D layers receive
            ``shift=geo_shift`` unless a ``shift`` is given explicitly.
        geo_shift: Default ``shift`` for GeoConv2D layers.

    Returns:
        An uncompiled ``tf.keras.Model``.
    """
    if conv_arc == GeoConv2D:
        def make_conv(**layer_kwargs):
            layer_kwargs.setdefault("shift", geo_shift)
            return GeoConv2D(**layer_kwargs)
    else:
        make_conv = conv_arc

    # (filters of the 5x5 conv, filters of the 3x3 conv) for each stage.
    stage_filters = ((16, 32), (64, 128), (192, 256), (320, 384), (448, 512))

    img_inputs = Input(shape=img_shape, name="input")
    features = img_inputs
    for first_filters, second_filters in stage_filters:
        features = make_conv(filters=first_filters, kernel_size=5, strides=1, activation=ACTIVATION)(features)
        features = Dropout(0.2)(features)
        features = make_conv(filters=second_filters, kernel_size=3, strides=1, activation=ACTIVATION)(features)
        features = MaxPool2D()(features)

    features = Flatten()(features)
    features = Dropout(0.25)(features)

    # float32 output head regardless of any global precision policy.
    score = Dense(1, activation=activation, dtype='float32')(features)

    return tf.keras.Model(inputs=img_inputs, outputs=score)


class WGAN_GP(tf.keras.models.Model):
    """Wasserstein GAN with gradient penalty (WGAN-GP).

    Holds a critic (``discriminator``) and a ``generator``. Each call to
    ``train_step`` runs ``discriminator_extra_steps`` critic updates followed
    by one generator update on the given batch of real images.
    """
    def __init__(
            self,
            discriminator: tf.keras.models.Model,
            generator: tf.keras.models.Model,
            noise_dim: int,
            discriminator_extra_steps: int=5,
            gp_weight: typing.Union[float, int]=10.0
        ) -> None:
        """
        Args:
            discriminator: Critic model scoring a batch of images.
            generator: Model mapping noise of shape (batch, noise_dim) to images.
            noise_dim: Length of the latent noise vector.
            discriminator_extra_steps: Critic updates per generator update (>= 1).
            gp_weight: Weight of the gradient penalty in the critic loss.

        Raises:
            ValueError: If ``discriminator_extra_steps`` is smaller than 1.
        """
        if discriminator_extra_steps < 1:
            # train_step needs at least one critic update to define the critic loss.
            raise ValueError(
                f"discriminator_extra_steps must be >= 1, got {discriminator_extra_steps}"
            )
        super(WGAN_GP, self).__init__()
        self.discriminator = discriminator
        self.generator = generator
        self.noise_dim = noise_dim
        self.discriminator_extra_steps = discriminator_extra_steps
        self.gp_weight = gp_weight

    def compile(
            self,
            discriminator_opt: Optimizer,
            generator_opt: Optimizer,
            discriminator_loss: typing.Callable,
            generator_loss: typing.Callable,
            **kwargs
        ) -> None:
        """Store the optimizers and loss functions used by ``train_step``.

        ``discriminator_loss(pred_real, pred_fake)`` and
        ``generator_loss(pred_fake)`` receive critic scores. Remaining keyword
        arguments are forwarded to ``tf.keras.Model.compile``.
        """
        super(WGAN_GP, self).compile(**kwargs)
        self.discriminator_opt = discriminator_opt
        self.generator_opt = generator_opt
        self.discriminator_loss = discriminator_loss
        self.generator_loss = generator_loss

    def add_instance_noise(self, x: tf.Tensor, stddev: float=0.1) -> tf.Tensor:
        """Return ``x`` plus zero-mean Gaussian noise with standard deviation ``stddev``."""
        noise = tf.random.normal(tf.shape(x), mean=0.0, stddev=stddev, dtype=x.dtype)
        return x + noise

    def gradient_penalty(
            self,
            images: tf.Tensor,
            fake_samples: tf.Tensor,
            discriminator: tf.keras.models.Model
        ) -> tf.Tensor:
        """Compute the WGAN-GP gradient penalty.

        The critic is evaluated on random per-sample interpolations between
        ``images`` and ``fake_samples``; the penalty is
        ``mean((||d critic / d interpolated||_2 - 1) ** 2)``. It is called
        inside the critic's outer GradientTape so the penalty itself is
        differentiated w.r.t. the critic weights.
        """
        batch_size = tf.shape(images)[0]
        # One interpolation coefficient per sample, broadcast over the other axes.
        alpha = tf.random.uniform(shape=[batch_size, 1, 1, 1], minval=0, maxval=1)

        # 1. Interpolate between real and fake samples
        interpolated_samples = alpha * images + ((1 - alpha) * fake_samples)

        with tf.GradientTape() as tape:
            tape.watch(interpolated_samples)
            # 2. Critic output for the interpolated images
            logits = discriminator(interpolated_samples, training=True)

        # 3. Gradients w.r.t. the interpolated images
        gradients = tape.gradient(logits, interpolated_samples)

        # 4. Per-sample L2 norm. The tiny constant keeps sqrt differentiable
        #    (finite second-order gradients) when a gradient is exactly zero.
        gradients_norm = tf.sqrt(tf.reduce_sum(tf.square(gradients), axis=[1, 2, 3]) + 1e-12)

        # 5. Penalize deviation of the norm from 1
        return tf.reduce_mean((gradients_norm - 1.0) ** 2)

    def train_step(self, real_samples: tf.Tensor) -> typing.Dict[str, float]:
        """Run the critic updates, then one generator update, on one batch.

        Returns a dict with the compiled metrics plus ``d_loss`` (critic loss
        of the last critic step), ``g_loss`` and ``gp`` (mean gradient penalty
        over the critic steps).
        """
        images = real_samples
        batch_size = tf.shape(images)[0]
        gps = []

        # Step 1. Train the critic more often than the generator
        for _ in range(self.discriminator_extra_steps):
            # Fresh latent vectors for every critic update.
            noise = tf.random.normal([batch_size, self.noise_dim])

            with tf.GradientTape() as tape:
                fake_samples = self.generator(noise, training=True)
                pred_real = self.discriminator(images, training=True)
                pred_fake = self.discriminator(fake_samples, training=True)

                # Instance noise goes into local copies used only for the
                # gradient penalty. `images` must not be reassigned here,
                # otherwise the noise accumulates over the critic steps and
                # leaks into later predictions and the metric update.
                noisy_images = self.add_instance_noise(images)
                noisy_fakes = self.add_instance_noise(fake_samples)

                # WGAN-GP gradient penalty on the noisy interpolations
                gp = self.gradient_penalty(noisy_images, noisy_fakes, self.discriminator)
                gps.append(gp)

                # Wasserstein critic loss plus the weighted penalty
                disc_loss = self.discriminator_loss(pred_real, pred_fake) + gp * self.gp_weight

            grads = tape.gradient(disc_loss, self.discriminator.trainable_variables)
            self.discriminator_opt.apply_gradients(zip(grads, self.discriminator.trainable_variables))

        # Step 2. Train the generator on a fresh latent batch
        noise = tf.random.normal([batch_size, self.noise_dim])
        with tf.GradientTape() as tape:
            fake_samples = self.generator(noise, training=True)
            pred_fake = self.discriminator(fake_samples, training=True)
            gen_loss = self.generator_loss(pred_fake)

        grads = tape.gradient(gen_loss, self.generator.trainable_variables)
        self.generator_opt.apply_gradients(zip(grads, self.generator.trainable_variables))

        # Metrics (if any) are configured in `compile()`.
        self.compiled_metrics.update_state(images, fake_samples)

        results = {m.name: m.result() for m in self.metrics}
        results.update({"d_loss": disc_loss, "g_loss": gen_loss, "gp": tf.reduce_mean(gps)})

        return results


class ResultsCallback(tf.keras.callbacks.Callback):
    """Keras callback that saves generator samples (and optionally models).

    At every epoch end it renders a fixed set of latent vectors into a PNG grid
    and, if ``save_model`` is set, saves both sub-models as ``.h5``. At training
    end the collected grids are written as an animated GIF.
    """
    def __init__(
            self,
            noise_dim: int,
            model_path: str,
            imgs_path: str,
            examples_to_generate: int = 16,
            grid_size: tuple = (4, 4),
            spacing: int = 5,
            gif_size: tuple = (416, 416),
            duration: float = 0.1,
            save_model: bool = True
        ) -> None:
        """
        Args:
            noise_dim: Length of the generator's latent input.
            model_path: Directory for ``discriminator.h5`` / ``generator.h5``.
            imgs_path: Directory for the per-epoch PNG grids and ``output.gif``.
            examples_to_generate: Number of fixed latent vectors (sampled once).
            grid_size: ``(rows, cols)`` of the image grid.
            spacing: Gap in pixels between grid cells.
            gif_size: Size each grid is resized to for the GIF, in
                ``cv2.resize`` order ``(width, height)``.
            duration: Frame duration passed to ``imageio.mimsave``.
            save_model: Whether to save both sub-models every epoch.
        """
        super(ResultsCallback, self).__init__()
        self.seed = tf.random.normal([examples_to_generate, noise_dim])
        self.results = []
        self.model_path = model_path
        self.imgs_path = imgs_path
        self.grid_size = grid_size
        self.spacing = spacing
        self.gif_size = gif_size
        self.duration = duration
        self.save_model = save_model

    def save_plt(self, epoch: int, results: np.ndarray):
        """Tile ``results`` into a grid, save it as PNG and keep a GIF-sized copy.

        ``results`` is a uint8 batch ``(n, height, width, channels)`` treated as
        RGB. Images fill the grid row by row; cells without an image stay black
        and images beyond ``rows * cols`` are ignored.
        """
        rows, cols = self.grid_size
        height, width, channels = results[0].shape
        grid = np.zeros(
            (rows * height + (rows - 1) * self.spacing,
             cols * width + (cols - 1) * self.spacing,
             channels),
            dtype=np.uint8,
        )
        for index, image in enumerate(results[:rows * cols]):
            row, col = divmod(index, cols)
            top = row * (height + self.spacing)
            left = col * (width + self.spacing)
            grid[top:top + height, left:left + width] = image

        # OpenCV writes images in BGR channel order.
        grid = cv2.cvtColor(grid, cv2.COLOR_RGB2BGR)

        # save the image
        cv2.imwrite(f'{self.imgs_path}/img_{epoch}.png', grid)

        # Keep a resized (still BGR) copy in memory for the GIF.
        self.results.append(cv2.resize(grid, self.gif_size, interpolation=cv2.INTER_AREA))

    def on_epoch_end(self, epoch: int, logs: dict=None):
        """Render the fixed seed batch and optionally save both sub-models."""
        predictions = self.model.generator(self.seed, training=False)
        # Map generator outputs from [-1, 1] to [0, 255].
        predictions_uint8 = (predictions * 127.5 + 127.5).numpy().astype(np.uint8)
        self.save_plt(epoch, predictions_uint8)

        if self.save_model:
            # save keras model to disk
            self.model.discriminator.save(self.model_path + "/discriminator.h5")
            self.model.generator.save(self.model_path + "/generator.h5")

    def on_train_end(self, logs: dict=None):
        """Write the collected per-epoch grids as an animated GIF."""
        if not self.results:
            # No epoch finished, so there are no frames to write.
            return

        # Stored frames are BGR; flip the channel axis back to RGB for imageio.
        imageio_images = [imageio.core.util.Image(image[..., ::-1]) for image in self.results]

        imageio.mimsave(self.imgs_path + "/output.gif", imageio_images, duration=self.duration)


class LRScheduler(tf.keras.callbacks.Callback):
    """Exponential learning-rate decay for both WGAN-GP optimizers.

    At the end of every epoch with ``epoch < decay_epochs``, the generator and
    discriminator learning rates are multiplied by ``decay_rate`` and clamped
    from below at ``min_lr``. The new values are printed and, when a
    TensorBoard callback is given, logged as scalars through its train writer.
    """
    def __init__(self, decay_epochs: int, tb_callback=None, min_lr: float=0.0000002,
                 decay_rate: float=0.95):
        """
        Args:
            decay_epochs: Epoch index up to which (exclusive) decay is applied.
            tb_callback: Optional ``TensorBoard`` callback whose writer is reused.
            min_lr: Lower bound for both learning rates.
            decay_rate: Multiplicative factor applied once per decaying epoch.
        """
        super(LRScheduler, self).__init__()
        self.decay_epochs = decay_epochs
        self.min_lr = min_lr
        self.decay_rate = decay_rate
        self.tb_callback = tb_callback
        # Initial rates are read from the optimizers on the first epoch end.
        self.compiled = False

    def on_epoch_end(self, epoch, logs=None):
        generator_opt = self.model.generator_opt
        discriminator_opt = self.model.discriminator_opt

        if not self.compiled:
            # `learning_rate` is the documented attribute; `lr` is only an alias.
            self.generator_lr = generator_opt.learning_rate.numpy()
            self.discriminator_lr = discriminator_opt.learning_rate.numpy()
            self.compiled = True

        if epoch >= self.decay_epochs:
            return

        self.generator_lr = max(self.generator_lr * self.decay_rate, self.min_lr)
        generator_opt.learning_rate.assign(self.generator_lr)
        self.discriminator_lr = max(self.discriminator_lr * self.decay_rate, self.min_lr)
        discriminator_opt.learning_rate.assign(self.discriminator_lr)
        print(f"Learning rate generator: {self.generator_lr}, discriminator: {self.discriminator_lr}")

        if self.tb_callback is None:
            return

        # NOTE: relies on the TensorBoard callback's private `_writers` dict.
        writer = getattr(self.tb_callback, '_writers', {}).get('train')
        if writer is None:
            # The train writer may not exist yet (e.g. if this callback runs
            # before TensorBoard's own epoch-end hook); skip logging instead of
            # crashing.
            return
        with writer.as_default():
            tf.summary.scalar('generator_lr', data=self.generator_lr, step=epoch)
            tf.summary.scalar('discriminator_lr', data=self.discriminator_lr, step=epoch)
            writer.flush()


def discriminator_w_loss(pred_real, pred_fake):
    """Wasserstein critic loss: ``mean(pred_fake) - mean(pred_real)``.

    Minimizing it raises the critic's scores on real samples and lowers them
    on generated ones.
    """
    mean_real, mean_fake = tf.reduce_mean(pred_real), tf.reduce_mean(pred_fake)
    return mean_fake - mean_real


def generator_w_loss(pred_fake):
    """Wasserstein generator loss: ``-mean(pred_fake)``.

    Minimizing it raises the critic's scores on generated samples.
    """
    return tf.math.negative(tf.reduce_mean(pred_fake))


# Build both networks with the selected convolution type (generator first).
generator = build_generator(latent_dim, conv_arc=CONV_ARC)
generator.summary()

discriminator = build_discriminator(img_shape, conv_arc=CONV_ARC)
discriminator.summary()

# Adam, lr 2e-5, betas (0.5, 0.9) for both networks.
generator_optimizer = tf.keras.optimizers.Adam(learning_rate=0.00002, beta_1=0.5, beta_2=0.9)
discriminator_optimizer = tf.keras.optimizers.Adam(learning_rate=0.00002, beta_1=0.5, beta_2=0.9)

callback = ResultsCallback(
    noise_dim=latent_dim,
    model_path=MODELS_DIR,
    imgs_path=GENERATED_IMGS_DIR,
    duration=0.04,
)
tb_callback = TensorBoard(f"{HISTORIES_DIR}/logs")
lr_scheduler = LRScheduler(decay_epochs=num_epochs, tb_callback=tb_callback)

gan = WGAN_GP(discriminator, generator, latent_dim, discriminator_extra_steps=5)
gan.compile(
    discriminator_opt=discriminator_optimizer,
    generator_opt=generator_optimizer,
    discriminator_loss=discriminator_w_loss,
    generator_loss=generator_w_loss,
    run_eagerly=False,
)

# Callbacks run in list order; lr_scheduler logs through tb_callback's writer,
# so it is kept after tb_callback.
gan.fit(train_datagen, epochs=num_epochs, callbacks=[callback, tb_callback, lr_scheduler])
