import tensorflow as tf
import tensorflow.keras as keras
from tensorflow.keras import layers, metrics
from tensorflow.keras import losses
from matplotlib import pyplot as plt


class VAE(keras.Model):
    """Combines the encoder and decoder into an end-to-end model for training.

    The model is conditional: every batch is expected to be an indexable
    ``(images, conditions)`` pair. The encoder is called on the whole pair and
    must return ``(z_mean, z_log_var, z)``; the decoder is called on
    ``[z, conditions]`` and returns reconstructed images. (``plot_images``
    refers to the conditions as labels.)

    Training objective::

        total_loss = reconstruction_loss + kl_weight * kl_loss

    where ``reconstruction_loss`` is a weighted sum of binary cross-entropy,
    MSE, MAE, multi-scale SSIM and a Sobel-edge "sharpness" term.

    Args:
        encoder: model mapping ``[images, conditions]`` to
            ``(z_mean, z_log_var, z)``.
        decoder: model mapping ``[z, conditions]`` to reconstructed images.
        bce_weight: weight of the binary cross-entropy term.
        mse_weight: weight of the mean-squared-error term.
        mae_weight: weight of the mean-absolute-error term.
        ssim_weight: weight of the ``1 - MS-SSIM`` term.
        sharpness_weight: weight of the Sobel-gradient difference term.
        kl_weight: weight of the KL-divergence term.
        img_shape: shape of one image; presumably channels-last
            ``(H, W, C)`` since SSIM/Sobel expect NHWC batches.
        ssim_max_val: dynamic range of pixel values (max - min), passed to
            ``tf.image.ssim_multiscale`` as ``max_val``. 1.0 suits images
            scaled to [0, 1] (which BCE already requires); use 255 for raw
            8-bit intensities.
        **kwargs: forwarded to ``keras.Model``.
    """

    def __init__(
            self,
            encoder,
            decoder,
            bce_weight,
            mse_weight,
            mae_weight,
            ssim_weight,
            sharpness_weight,
            kl_weight,
            img_shape,
            ssim_max_val=1.0,
            **kwargs):
        super().__init__(**kwargs)
        self.encoder = encoder
        self.decoder = decoder
        self.img_shape = img_shape
        self.bce_weight = bce_weight
        self.mse_weight = mse_weight
        self.mae_weight = mae_weight
        self.ssim_weight = ssim_weight
        self.sharpness_weight = sharpness_weight
        self.kl_weight = kl_weight
        self.ssim_max_val = ssim_max_val
        # Running means shared by train_step and test_step; Keras resets them
        # at the start of each epoch and each evaluate() because they are
        # listed in the `metrics` property below.
        self.total_loss_tracker = metrics.Mean(name="total_loss")
        self.reconstruction_loss_tracker = metrics.Mean(name="reconstruction_loss")
        self.kl_loss_tracker = metrics.Mean(name="kl_loss")

    @property
    def metrics(self):
        """Metric trackers that Keras should reset automatically."""
        return [
            self.total_loss_tracker,
            self.reconstruction_loss_tracker,
            self.kl_loss_tracker,
        ]

    @staticmethod
    def _mean_over_batch_of_pixel_sum(pixel_losses):
        """Sum a per-pixel loss map over axes 1 and 2, then average over the batch."""
        # Keras' elementwise losses reduce the last (channel) axis, so for
        # NHWC inputs this is presumably a (batch, H, W) map.
        return tf.reduce_mean(tf.reduce_sum(pixel_losses, axis=(1, 2)))

    def _reconstruction_loss(self, images, reconstruction):
        """Return the weighted sum of all reconstruction terms as a scalar."""
        bce = self._mean_over_batch_of_pixel_sum(
            losses.binary_crossentropy(images, reconstruction))
        mse = self._mean_over_batch_of_pixel_sum(
            losses.mean_squared_error(images, reconstruction))
        mae = self._mean_over_batch_of_pixel_sum(
            losses.mean_absolute_error(images, reconstruction))
        # ssim_multiscale yields one similarity per image. Average (not sum)
        # over the batch so that, like the other terms, this term does not
        # grow with the batch size. The three power factors are the first
        # three of TensorFlow's defaults, i.e. only three scales are used.
        ssim = tf.reduce_mean(1. - tf.image.ssim_multiscale(
            img1=images,
            img2=reconstruction,
            max_val=self.ssim_max_val,
            power_factors=(0.0448, 0.2856, 0.3001),
            filter_size=7,
        ))
        # Mean absolute difference between the Sobel gradients of the
        # reconstruction and the target (presumably to penalise blur).
        sharpness = tf.reduce_mean(tf.abs(
            tf.image.sobel_edges(reconstruction) - tf.image.sobel_edges(images)))
        return (self.bce_weight * bce
                + self.mse_weight * mse
                + self.mae_weight * mae
                + self.ssim_weight * ssim
                + self.sharpness_weight * sharpness)

    def _compute_losses(self, inputs, training):
        """Encode/decode one batch and return (total, reconstruction, kl) losses."""
        # Forward the training flag so Dropout/BatchNorm layers (if any) in
        # the sub-models behave correctly in fit() vs evaluate().
        z_mean, z_log_var, z = self.encoder(inputs, training=training)
        reconstruction = self.decoder([z, inputs[1]], training=training)

        reconstruction_loss = self._reconstruction_loss(inputs[0], reconstruction)

        # Closed-form KL(N(z_mean, exp(z_log_var)) || N(0, I)), summed over
        # the latent axis and averaged over the batch.
        kl_loss = -0.5 * (1 + z_log_var - tf.square(z_mean) - tf.exp(z_log_var))
        kl_loss = tf.reduce_mean(tf.reduce_sum(kl_loss, axis=1))

        total_loss = reconstruction_loss + self.kl_weight * kl_loss
        return total_loss, reconstruction_loss, kl_loss

    def _update_trackers(self, total_loss, reconstruction_loss, kl_loss):
        """Accumulate this batch's losses and return the running means."""
        self.total_loss_tracker.update_state(total_loss)
        self.reconstruction_loss_tracker.update_state(reconstruction_loss)
        self.kl_loss_tracker.update_state(kl_loss)
        return {
            "loss": self.total_loss_tracker.result(),
            "reconstruction_loss": self.reconstruction_loss_tracker.result(),
            "kl_loss": self.kl_loss_tracker.result(),
        }

    def train_step(self, inputs):
        """One optimisation step on an ``(images, conditions)`` batch."""
        with tf.GradientTape() as tape:
            total_loss, reconstruction_loss, kl_loss = self._compute_losses(
                inputs, training=True)

        grads = tape.gradient(total_loss, self.trainable_weights)
        self.optimizer.apply_gradients(zip(grads, self.trainable_weights))

        return self._update_trackers(total_loss, reconstruction_loss, kl_loss)

    def test_step(self, inputs):
        """Evaluate one batch; returns running means over the evaluation so far.

        Updating the trackers (instead of returning raw batch values) makes
        ``evaluate()`` and validation report the mean over all batches rather
        than just the last one.
        """
        total_loss, reconstruction_loss, kl_loss = self._compute_losses(
            inputs, training=False)
        return self._update_trackers(total_loss, reconstruction_loss, kl_loss)

    def plot_images(self, x_train, epoch=None, logs=None, num_rows=2, num_cols=4):
        """Show originals next to their reconstructions for the first batch.

        Draws ``num_rows * num_cols // 2`` (original, reconstruction) pairs on
        a ``num_rows`` x ``num_cols`` grid, each original immediately followed
        by its reconstruction. Only the first batch of ``x_train`` is used;
        if it holds fewer images, fewer pairs are drawn.

        Args:
            x_train: iterable of ``(images, conditions)`` batches, e.g. a
                batched ``tf.data.Dataset``.
            epoch: unused; presumably kept so the method fits an epoch-end
                callback signature — confirm with callers.
            logs: unused; see ``epoch``.
            num_rows: number of grid rows.
            num_cols: number of grid columns.
        """
        first_batch = next(iter(x_train), None)
        if first_batch is None:
            # Empty dataset: nothing to plot.
            return

        plot_size = num_rows * num_cols
        images = tf.reshape(first_batch[0][:plot_size], (-1, *self.img_shape))
        lbls = first_batch[1][:plot_size]

        z = self.encoder([images, lbls], training=False)[-1]
        reconstructed_images = tf.reshape(
            self.decoder([z, lbls], training=False),
            (-1, *self.img_shape),
        )

        num_pairs = min(plot_size // 2, int(images.shape[0]))
        plt.figure(figsize=(num_cols + 1, num_rows + 1))
        for i in range(num_pairs):
            plt.subplot(num_rows, num_cols, 2 * i + 1)
            plt.imshow(images[i], cmap='gray')
            plt.axis('off')

            plt.subplot(num_rows, num_cols, 2 * i + 2)
            plt.imshow(reconstructed_images[i], cmap='gray')
            plt.axis('off')
        plt.show()
# -------------------------------
