import tensorflow as tf

from typing import List


class Sampling(tf.keras.layers.Layer):
    """Uses (z_mean, z_log_var) to sample z from a Gaussian distribution.

    Implements the reparameterization trick:
    ``z = z_mean + exp(0.5 * z_log_var) * epsilon`` with ``epsilon ~ N(0, 1)``,
    so gradients flow through ``z_mean`` and ``z_log_var`` rather than
    through the random draw.
    """

    def call(self, inputs):
        """Draw one latent sample per element of ``z_mean``.

        Args:
            inputs: pair ``(z_mean, z_log_var)``; presumably tensors of the
                same shape, e.g. ``(batch, dim)`` — confirm at call sites.

        Returns:
            A tensor with the same shape and dtype as ``z_mean``.
        """
        z_mean, z_log_var = inputs
        # Draw noise with z_mean's full shape and dtype, so the layer also works
        # for float16 (mixed precision) / float64 inputs and not just 2-D
        # float32 tensors. For (batch, dim) float32 inputs this equals the
        # previous (batch, dim)-shaped, default-dtype draw.
        epsilon = tf.keras.backend.random_normal(shape=tf.shape(z_mean), dtype=z_mean.dtype)
        return z_mean + tf.exp(0.5 * z_log_var) * epsilon


class VAE(tf.keras.models.Model):
    """Variational autoencoder built around user-supplied encoder/decoder backbones.

    The encoder backbone gets two Dense heads (``z_mean`` and ``z_log_var``)
    followed by a :class:`Sampling` layer. The decoder backbone is preceded by a
    Dense projection from the latent space to the backbone's input width.
    Training minimises a binary cross-entropy reconstruction loss plus a
    ``kl_weight``-scaled KL divergence term.
    """

    def __init__(self, dim: int,
                 input_shape: List,
                 encoder_backbone: tf.keras.Model,
                 decoder_backbone: tf.keras.Model,
                 kl_weight: float = 1.0, **kwargs):
        """Build the encoder, decoder and loss trackers.

        Args:
            dim: dimension of the latent space.
            input_shape: shape of a single input sample (without batch axis).
            encoder_backbone: model whose output feeds the ``z_mean`` and
                ``z_log_var`` Dense heads.
            decoder_backbone: model with a defined first input; the last
                dimension of that input sets the width of the latent projection.
            kl_weight: multiplier on the KL term (similar to beta in beta-VAE).
                Stored as a non-trainable ``tf.Variable`` so it can be
                reassigned during training.
            **kwargs: forwarded to ``tf.keras.Model``.
        """
        # The Keras Model must be initialised before any tracked attribute is
        # set: assigning the tf.Variable below first makes Model.__setattr__
        # raise "forgot to call super().__init__()".
        super(VAE, self).__init__(**kwargs)
        self.dim = dim  # dimension of latent space
        # Presumably suffixed with "_" to avoid Keras' own `input_shape` property.
        self.input_shape_ = input_shape  # input shape of image
        self.kl_weight = tf.Variable(kl_weight, dtype=tf.float32, trainable=False)  # similar to beta in beta-VAE
        self.encoder_backbone = encoder_backbone
        self.decoder_backbone = decoder_backbone
        self.encoder = self.set_encoder()
        self.decoder = self.set_decoder()
        self.total_loss_tracker = tf.keras.metrics.Mean(name="total_loss")
        self.reconstruction_loss_tracker = tf.keras.metrics.Mean(
            name="reconstruction_loss"
        )
        self.kl_loss_tracker = tf.keras.metrics.Mean(name="kl_loss")

    def set_encoder(self) -> tf.keras.models.Model:
        """Return a functional model mapping an input to ``[z_mean, z_log_var, z]``."""
        input_encoder = tf.keras.layers.Input(self.input_shape_)
        x = self.encoder_backbone(input_encoder)
        z_mean = tf.keras.layers.Dense(self.dim, name="z_mean")(x)
        # Small-std kernel and zero bias start log-variance near 0
        # (unit variance), keeping early samples well-behaved.
        z_log_var = tf.keras.layers.Dense(self.dim,
                                          kernel_initializer=tf.keras.initializers.RandomNormal(stddev=0.01),
                                          bias_initializer=tf.keras.initializers.Zeros(),
                                          name="z_log_var")(x)

        z = Sampling()([z_mean, z_log_var])
        encoder = tf.keras.models.Model(input_encoder, [z_mean, z_log_var, z])
        return encoder

    def set_decoder(self) -> tf.keras.models.Model:
        """Return a functional model mapping a latent vector of size ``dim`` to a reconstruction."""
        input_decoder = tf.keras.layers.Input((self.dim,))
        # Project the latent vector to whatever width the decoder backbone expects.
        x = tf.keras.layers.Dense(self.decoder_backbone.inputs[0].shape[-1])(input_decoder)
        decoder = tf.keras.models.Model(input_decoder, self.decoder_backbone(x))
        return decoder

    def call(self, inputs, training=False):
        """Encode, sample and decode ``inputs``; return the reconstruction."""
        _, _, z = self.encoder(inputs, training=training)
        reconstruction = self.decoder(z, training=training)
        return reconstruction

    @property
    def metrics(self):
        """Loss trackers; listing them here lets Keras reset them each epoch."""
        return [
            self.total_loss_tracker,
            self.reconstruction_loss_tracker,
            self.kl_loss_tracker,
        ]

    def train_step(self, data):
        """Run one optimisation step on a batch.

        Args:
            data: a batch of inputs. If ``fit`` passes an ``(x, ...)`` tuple,
                only ``x`` is used (the model is unsupervised).

        Returns:
            dict with the running means of ``loss``, ``reconstruction_loss``
            and ``kl_loss``.
        """
        if isinstance(data, tuple):
            data = data[0]
        with tf.GradientTape() as tape:
            # training=True so Dropout/BatchNorm in the backbones run in
            # training mode during fit().
            z_mean, z_log_var, z = self.encoder(data, training=True)
            reconstruction = self.decoder(z, training=True)
            # binary_crossentropy averages over the last axis, leaving shape
            # (batch, ...). Sum every non-batch position per sample, then average
            # over the batch so the loss does not grow with batch size and is
            # on the same per-sample scale as the KL term below.
            per_position = tf.keras.losses.binary_crossentropy(data, reconstruction)
            per_sample = tf.reduce_sum(
                tf.reshape(per_position, [tf.shape(per_position)[0], -1]), axis=1
            )
            reconstruction_loss = tf.reduce_mean(per_sample)
            # Closed-form KL(N(z_mean, exp(z_log_var)) || N(0, 1)), scaled by kl_weight.
            kl_loss = -0.5 * (1 + z_log_var - tf.square(z_mean) - tf.exp(z_log_var)) * self.kl_weight
            kl_loss = tf.reduce_mean(tf.reduce_sum(kl_loss, axis=1))
            total_loss = reconstruction_loss + kl_loss
        grads = tape.gradient(total_loss, self.trainable_weights)
        self.optimizer.apply_gradients(zip(grads, self.trainable_weights))
        self.total_loss_tracker.update_state(total_loss)
        self.reconstruction_loss_tracker.update_state(reconstruction_loss)
        self.kl_loss_tracker.update_state(kl_loss)
        return {
            "loss": self.total_loss_tracker.result(),
            "reconstruction_loss": self.reconstruction_loss_tracker.result(),
            "kl_loss": self.kl_loss_tracker.result(),
        }


if __name__ == "__main__":
    import numpy as np

    # Smoke test on dummy data: 100 all-zero 64x64 RGB images.
    input_shape = [64, 64, 3]
    total_images = 100
    images = np.zeros([total_images] + input_shape, dtype=np.float32)
    latent_dim = 100

    # Define an encoder backbone. padding='same' makes each stride-2 conv
    # halve the spatial size exactly (64 -> 32 -> 16), so the decoder below
    # can mirror it back to 64x64. With the default 'valid' padding the sizes
    # were 64 -> 31 -> 15, and the decoder produced 60x60 outputs.
    encoder_backbone = tf.keras.Sequential(
        [
            tf.keras.layers.InputLayer(input_shape=input_shape),
            tf.keras.layers.Conv2D(
                filters=32, kernel_size=3, strides=(2, 2), padding='same', activation='relu'),
            tf.keras.layers.Conv2D(
                filters=64, kernel_size=3, strides=(2, 2), padding='same', activation='relu'),
            tf.keras.layers.Flatten(),
        ])
    # Define a decoder backbone mirroring the encoder: project, reshape to the
    # last conv feature map, then upsample 16 -> 32 -> 64.
    decoder_backbone = tf.keras.Sequential(
        [
            tf.keras.layers.InputLayer(input_shape=(latent_dim,)),
            tf.keras.layers.Dense(units=encoder_backbone.output_shape[-1], activation=tf.nn.relu),
            tf.keras.layers.Reshape(target_shape=encoder_backbone.layers[-2].output_shape[1:]),
            tf.keras.layers.Conv2DTranspose(
                filters=64, kernel_size=3, strides=2, padding='same',
                activation='relu'),
            tf.keras.layers.Conv2DTranspose(
                filters=32, kernel_size=3, strides=2, padding='same',
                activation='relu'),
            # One output channel per input channel. Sigmoid keeps outputs in
            # [0, 1], as required by the VAE's binary cross-entropy loss.
            tf.keras.layers.Conv2DTranspose(
                filters=input_shape[-1], kernel_size=3, strides=1, padding='same',
                activation='sigmoid'),
        ]
    )

    vae = VAE(dim=latent_dim, input_shape=input_shape, encoder_backbone=encoder_backbone,
              decoder_backbone=decoder_backbone)

    # `learning_rate` replaces the deprecated `lr` alias.
    vae.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=1e-3))
    vae.fit(x=images, batch_size=100, epochs=1)