import tensorflow as tf
import tensorflow_addons as tfa

class Sampling(tf.keras.layers.Layer):
    """Draw a Gaussian sample from ``(mean, log_var)`` via the reparameterization trick.

    Call with a pair ``[mean, log_var]`` of equally-shaped 2-D tensors
    ``(batch, dim)``; returns ``mean + exp(log_var / 2) * eps`` where ``eps`` is
    fresh standard-normal noise of the same shape.
    """

    def call(self, inputs):
        mean, log_var = inputs
        dynamic_shape = tf.shape(mean)
        # Noise is drawn with the runtime (batch, dim) so variable batch sizes work.
        noise = tf.keras.backend.random_normal(shape=(dynamic_shape[0], dynamic_shape[1]))
        std_dev = tf.exp(0.5 * log_var)
        return mean + std_dev * noise

    def compute_output_shape(self, input_shape):
        # The sample has exactly the shape of the mean input.
        mean_shape = input_shape[0]
        return mean_shape


class GVAEDelta(tf.keras.Model):
    """Variational autoencoder over stacks of transformed copies of an input.

    Every sample is a stack of ``n_transforms`` inputs, each of shape
    ``input_shape``. The latent code of each copy has two parts:

    * ``z`` (size ``m_dim``, the invariant latent): the per-copy Gaussian
      parameters are averaged over the transform axis, one ``z`` is sampled
      from them and the same ``z`` is shared by every copy.
    * ``y`` (size ``l_dim``, the transformation latent): sampled independently
      for each copy from that copy's own Gaussian parameters.

    Each copy's ``[z, y]`` is decoded independently.

    Args:
        l_dim: Size of the per-copy transformation latent ``y``.
        m_dim: Size of the shared invariant latent ``z``.
        input_shape: Shape of a single input, without batch or transform axes.
        kl_weight: Multiplier applied to both KL terms.
        encoder_backbone: Callable Keras layer/model mapping one input to features.
        decoder_backbone: Keras model exposing ``.inputs``; the last dimension
            of its first input sets the width of the projection fed into it.
        n_transforms: Number of transformed copies per sample.
        use_tl: If True, ``train_step`` expects ``(inputs, (outputs, labels))``
            and adds a semi-hard triplet loss on the averaged ``z`` mean.
        class_alpha: Weight of the triplet loss (only stored when ``use_tl``).
        **kwargs: Forwarded to ``tf.keras.Model``.
    """

    def __init__(self, l_dim, m_dim, input_shape, kl_weight, encoder_backbone, decoder_backbone, n_transforms: int,
                 use_tl=False, class_alpha=1.0,
                 **kwargs):
        super().__init__(**kwargs)
        # Parameters
        self.input_shape_ = input_shape
        self.n_transforms = n_transforms

        # Hyperparameters
        self.l_dim = l_dim  # Dimension of transformation latent space
        self.m_dim = m_dim  # Dimension of invariant latent space
        self.kl_weight = kl_weight

        # Neural network architectures
        self.encoder_backbone = encoder_backbone
        self.decoder_backbone = decoder_backbone
        # Latent parameters: [[z_mean, y_mean], [z_log_var, y_log_var]] single-input models
        self.location_parameters, self.scale_parameters = self.set_latent_parameters()

        # Set encoder and decoder
        self.encoder = self.set_encoder()
        self.decoder = self.set_decoder()

        # Loss trackers
        self.total_loss_tracker = tf.keras.metrics.Mean(name="total_loss")
        self.kl_weight_tracker = tf.keras.metrics.Mean(name="kl_weight")
        self.reconstruction_loss_tracker = tf.keras.metrics.Mean(
            name="reconstruction_loss"
        )
        self.kl_loss_tracker = tf.keras.metrics.Mean(name="kl_loss")
        self.log_var_z_tracker = tf.keras.metrics.Mean(name="log_var_z")

        self.use_tl = use_tl

        if self.use_tl:
            self.tl_loss_tracker = tf.keras.metrics.Mean(name="triplet_loss")
            self.class_alpha = class_alpha

    def set_latent_parameters(self):
        """Build the four single-input heads that share ``encoder_backbone``.

        Returns:
            ``[[encoder_z_mean, encoder_y_mean], [encoder_z_log_var, encoder_y_log_var]]``,
            each a ``tf.keras.Model`` taking one input of shape ``input_shape``.
        """
        # Define encoder backbone
        input_encoder = tf.keras.layers.Input(self.input_shape_)
        x = self.encoder_backbone(input_encoder)

        # Define z parameters. The log-variance heads are zero-initialized, so
        # training starts from log_var = 0 (unit variance).
        z_mean = tf.keras.layers.Dense(self.m_dim, name="z_mean")(x)
        z_log_var = tf.keras.layers.Dense(self.m_dim,
                                          kernel_initializer=tf.keras.initializers.Zeros(),
                                          bias_initializer=tf.keras.initializers.Zeros(),
                                          name="z_log_var"
                                          )(x)

        # Define y parameters
        y_mean = tf.keras.layers.Dense(self.l_dim, name="y_mean")(x)
        y_log_var = tf.keras.layers.Dense(self.l_dim,
                                          kernel_initializer=tf.keras.initializers.Zeros(),
                                          bias_initializer=tf.keras.initializers.Zeros(),
                                          name="y_log_var")(x)

        # Define models for predicting z and y parameters
        encoder_y_mean = tf.keras.Model(input_encoder, y_mean)
        encoder_y_log_var = tf.keras.Model(input_encoder, y_log_var)
        encoder_z_mean = tf.keras.Model(input_encoder, z_mean)
        encoder_z_log_var = tf.keras.Model(input_encoder, z_log_var)
        return [[encoder_z_mean, encoder_y_mean], [encoder_z_log_var, encoder_y_log_var]]

    def set_encoder(self) -> tf.keras.Model:
        """Build the encoder over a stack of ``n_transforms`` inputs.

        The model input has shape ``(batch, n_transforms, *input_shape)``. Its
        output is the nested list
        ``[[mult_z_mean, mult_y_mean], [mult_z_log_var, mult_y_log_var], [mult_z, mult_y]]``,
        where every tensor keeps the transform axis at position 1:

        * ``mult_z_mean`` / ``mult_z_log_var``: per-copy z parameters.
        * ``mult_z``: one z sampled from the transform-averaged parameters,
          repeated ``n_transforms`` times along axis 1.
        * ``mult_y_mean`` / ``mult_y_log_var`` / ``mult_y``: per-copy y
          parameters and an independent y sample for every copy.
        """
        encoder_z_mean, encoder_y_mean = self.location_parameters
        encoder_z_log_var, encoder_y_log_var = self.scale_parameters

        # Define the input
        mult_input_layer = tf.keras.layers.Input((self.n_transforms, *self.input_shape_))

        # Apply the per-image z heads to every transformed copy.
        mult_z_mean = tf.keras.layers.TimeDistributed(encoder_z_mean)(mult_input_layer)
        mult_z_log_var = tf.keras.layers.TimeDistributed(encoder_z_log_var)(mult_input_layer)
        # Pool over the transform axis so a single z is shared by all copies.
        z_mean = tf.keras.layers.Lambda(lambda y: tf.math.reduce_mean(y, axis=1))(mult_z_mean)
        z_log_var = tf.keras.layers.Lambda(lambda y: tf.math.reduce_mean(y, axis=1))(mult_z_log_var)
        mult_z = Sampling()([z_mean, z_log_var])
        # Tile the shared z back along a new transform axis.
        # NOTE(review): some older Keras versions reject Concatenate on a single
        # input, so n_transforms == 1 may fail here — confirm for the TF in use.
        mult_z = tf.keras.layers.Lambda(lambda y: tf.expand_dims(y, axis=1))(mult_z)
        mult_z = tf.keras.layers.Concatenate(axis=1)([mult_z for _ in range(self.n_transforms)])

        mult_y_mean = tf.keras.layers.TimeDistributed(encoder_y_mean)(mult_input_layer)
        mult_y_log_var = tf.keras.layers.TimeDistributed(encoder_y_log_var)(mult_input_layer)

        # Stack mean and log-variance on a trailing axis so one TimeDistributed
        # Sampling call can draw an independent y for every copy.
        mult_y_mean_reshaped = tf.keras.layers.Reshape(mult_y_mean.shape[1:] + (1,))(mult_y_mean)
        mult_y_log_var_reshaped = tf.keras.layers.Reshape(mult_y_log_var.shape[1:] + (1,))(mult_y_log_var)
        mult_y_params = tf.keras.layers.Concatenate(-1)([mult_y_mean_reshaped, mult_y_log_var_reshaped])

        mult_y = tf.keras.layers.TimeDistributed(
            tf.keras.layers.Lambda(lambda x: Sampling()([x[:, :, 0], x[:, :, 1]])))(mult_y_params)

        # Create encoder from z and y
        encoder = tf.keras.models.Model(mult_input_layer, [[mult_z_mean, mult_y_mean], [mult_z_log_var, mult_y_log_var],
                                                           [mult_z, mult_y]])
        return encoder

    def set_decoder(self) -> tf.keras.Model:
        """Build a decoder mapping ``(batch, n_transforms, l_dim + m_dim)`` codes to outputs.

        A Dense projection resizes each code to the width of
        ``decoder_backbone``'s first input; the resulting single-code decoder is
        applied to every copy with ``TimeDistributed``.
        """
        input_decoder = tf.keras.layers.Input((self.l_dim + self.m_dim,))
        x = tf.keras.layers.Dense(self.decoder_backbone.inputs[0].shape[-1])(input_decoder)
        decoder = tf.keras.Model(input_decoder, self.decoder_backbone(x))

        # Pass multiple codes to decoder
        mult_input_layer = tf.keras.layers.Input((self.n_transforms, self.l_dim + self.m_dim))
        x = tf.keras.layers.TimeDistributed(decoder)(mult_input_layer)
        decoder = tf.keras.Model(mult_input_layer, x)
        return decoder

    @property
    def metrics(self):
        """All trackers updated in ``train_step``.

        Keras calls ``reset_states()`` on every metric listed here at the start
        of each epoch; a tracker missing from this list would accumulate across
        epochs.
        """
        list_metrics = [
            self.total_loss_tracker,
            self.reconstruction_loss_tracker,
            self.kl_loss_tracker,
            self.kl_weight_tracker,
            self.log_var_z_tracker,
        ]
        if self.use_tl:
            list_metrics.append(self.tl_loss_tracker)
        return list_metrics

    def call(self, inputs, training=False):
        """Encode ``inputs``, sample the latents and return the reconstruction.

        ``training`` is accepted for Keras compatibility but is not forwarded.
        """
        _, _, (mult_z, mult_y) = self.encoder(inputs)
        codes = tf.concat([mult_z, mult_y], axis=-1)
        return self.decoder(codes)

    def train_step(self, data):
        """Run one optimization step and return the running metric values.

        Args:
            data: The stacked encoder input when ``use_tl`` is False; otherwise
                ``(image_input, (image_output, label))`` where ``label`` feeds
                the triplet loss.

        Returns:
            Dict with ``loss``, ``reconstruction_loss``, ``kl_loss``,
            ``kl_weight``, ``log_var_z`` and, when ``use_tl``, ``tl_loss``.
        """
        if self.use_tl:
            # NOTE(review): image_output is unpacked but unused; the
            # reconstruction target is the encoder input itself — confirm intent.
            image_input, (image_output, label) = data
        else:
            image_input = data

        with tf.GradientTape() as tape:
            [[mult_z_mean, mult_y_mean], [mult_z_log_var, mult_y_log_var], [mult_z, mult_y]] = self.encoder(
                image_input)

            reconstruction = self.decoder(tf.concat([mult_z, mult_y], axis=-1))

            # binary_crossentropy already averages over the last axis. Sum every
            # remaining non-batch axis (derived from the rank, so the batch axis
            # is never included) to get a per-sample loss, then average over the
            # batch so this term is on the same scale as the batch-averaged KL.
            pixel_bce = tf.keras.losses.binary_crossentropy(image_input, reconstruction)
            reconstruction_loss = tf.reduce_mean(
                tf.reduce_sum(pixel_bce, axis=tf.range(1, tf.rank(pixel_bce)))
            )

            # The shared z posterior uses the transform-averaged parameters (as
            # in set_encoder); y has one posterior per transformed copy.
            z_log_var = tf.math.reduce_mean(mult_z_log_var, axis=1)
            z_mean = tf.math.reduce_mean(mult_z_mean, axis=1)
            kl_loss_z = tf.reduce_sum(
                -0.5 * (1 + z_log_var - tf.square(z_mean) - tf.exp(z_log_var)) * self.kl_weight, axis=-1)
            kl_loss_y = tf.reduce_sum(
                -0.5 * (1 + mult_y_log_var - tf.square(mult_y_mean) - tf.exp(mult_y_log_var)) * self.kl_weight,
                axis=(-1, -2))
            kl_loss = tf.reduce_mean(kl_loss_z + kl_loss_y)
            total_loss = reconstruction_loss + kl_loss

            if self.use_tl:
                triplet_loss = tf.reduce_mean(tfa.losses.TripletSemiHardLoss()(label, z_mean)) * self.class_alpha
                total_loss += triplet_loss

        grads = tape.gradient(total_loss, self.trainable_weights)
        self.optimizer.apply_gradients(zip(grads, self.trainable_weights))

        # Update metric trackers
        self.total_loss_tracker.update_state(total_loss)
        self.kl_weight_tracker.update_state(self.kl_weight)
        self.reconstruction_loss_tracker.update_state(reconstruction_loss)
        self.kl_loss_tracker.update_state(kl_loss)
        self.log_var_z_tracker.update_state(tf.math.reduce_mean(mult_z_log_var))
        if self.use_tl:
            self.tl_loss_tracker.update_state(triplet_loss)

        output_dictionary = {
            "loss": self.total_loss_tracker.result(),
            "reconstruction_loss": self.reconstruction_loss_tracker.result(),
            "kl_loss": self.kl_loss_tracker.result(),
            "kl_weight": self.kl_weight_tracker.result(),
            "log_var_z": self.log_var_z_tracker.result()
        }
        if self.use_tl:
            output_dictionary.update({"tl_loss": self.tl_loss_tracker.result()})

        return output_dictionary