import tensorflow as tf
import tensorflow_addons as tfa



class GAE(tf.keras.Model):
    """Autoencoder over stacks of transformed images with a split latent space.

    Every input sample is a stack of ``n_transforms`` images. Each image is
    encoded into two codes:

    * an ``m_dim``-dimensional "invariant" code, averaged over the stack and
      then repeated so every image in the stack shares it;
    * an ``l_dim``-dimensional "transformation" code, kept per image.

    The two codes are concatenated and decoded image by image. When ``use_tl``
    is True, a semi-hard triplet loss on the averaged invariant code (weighted
    by ``class_alpha``) is added to the reconstruction loss.

    Args:
        l_dim: Dimension of the transformation latent space.
        m_dim: Dimension of the invariant latent space.
        input_shape: Shape of a single image (no batch / transform axes).
        encoder_backbone: Callable Keras model/layer mapping one image to features.
        decoder_backbone: Keras model mapping a feature vector to one image; the
            last dimension of its first input sets the width of the projection
            applied to the latent code.
        n_transforms: Number of images in each input stack.
        use_tl: If True, training data must be ``(x, (x_target, label))`` and the
            triplet loss is added.
        class_alpha: Weight of the triplet loss term.
        **kwargs: Forwarded to ``tf.keras.Model``.
    """

    def __init__(self, l_dim, m_dim, input_shape, encoder_backbone, decoder_backbone, n_transforms: int,
                 use_tl=False, class_alpha=1.0,
                 **kwargs):
        super(GAE, self).__init__(**kwargs)
        # Parameters
        self.input_shape_ = input_shape
        self.n_transforms = n_transforms

        # Hyperparameters
        self.l_dim = l_dim  # Dimension of transformation latent space
        self.m_dim = m_dim  # Dimension of invariant latent space

        # Neural network architectures
        self.encoder_backbone = encoder_backbone
        self.decoder_backbone = decoder_backbone
        # Latent parameters: (encoder_z_mean, encoder_y_mean) single-image models
        self.location_parameters = self.set_latent_parameters()

        # Set encoder and decoder
        self.encoder = self.set_encoder()
        self.decoder = self.set_decoder()

        # Loss trackers
        self.total_loss_tracker = tf.keras.metrics.Mean(name="total_loss")
        self.reconstruction_loss_tracker = tf.keras.metrics.Mean(
            name="reconstruction_loss"
        )
        self.use_tl = use_tl
        # Always defined so the attribute exists regardless of use_tl.
        self.class_alpha = class_alpha

        if self.use_tl:
            self.tl_loss_tracker = tf.keras.metrics.Mean(name="triplet_loss")
            # Built once instead of on every training step.
            self._triplet_loss_fn = tfa.losses.TripletSemiHardLoss()

    def set_latent_parameters(self):
        """Build single-image models predicting the invariant (z) and transformation (y) codes.

        Both models share ``encoder_backbone`` and differ only in their final Dense head.

        Returns:
            Tuple ``(encoder_z_mean, encoder_y_mean)`` of Keras models.
        """
        # Define encoder backbone
        input_encoder = tf.keras.layers.Input(self.input_shape_)
        x = self.encoder_backbone(input_encoder)

        # Define z parameters
        z_mean = tf.keras.layers.Dense(self.m_dim, name="z_mean")(x)

        # Define y parameters
        y_mean = tf.keras.layers.Dense(self.l_dim, name="y_mean")(x)

        # Define models for predicting z and y parameters
        # (creation order kept stable so auto-generated model names do not change)
        encoder_y_mean = tf.keras.Model(input_encoder, y_mean)
        encoder_z_mean = tf.keras.Model(input_encoder, z_mean)
        return encoder_z_mean, encoder_y_mean

    def set_encoder(self) -> tf.keras.Model:
        """Build the encoder over a stack of ``n_transforms`` images.

        Returns:
            Keras model mapping ``(batch, n_transforms, *input_shape)`` to
            ``[mult_z, mult_y]``. ``mult_z`` is the invariant code averaged over the
            stack and repeated ``n_transforms`` times, of shape
            ``(batch, n_transforms, m_dim)``. ``mult_y`` is the per-image
            transformation code, of shape ``(batch, n_transforms, l_dim)``.
        """
        encoder_z_mean, encoder_y_mean = self.location_parameters

        # Define the input
        mult_input_layer = tf.keras.layers.Input((self.n_transforms, *self.input_shape_))

        # Encode every image of the stack, then average the invariant code over the stack axis.
        mult_z_mean = tf.keras.layers.TimeDistributed(encoder_z_mean)(mult_input_layer)
        z_mean = tf.keras.layers.Lambda(lambda y: tf.math.reduce_mean(y, axis=1))(mult_z_mean)

        # Repeat the averaged code along the stack axis so every image receives the same z.
        mult_z = tf.keras.layers.Lambda(lambda y: tf.expand_dims(y, axis=1))(z_mean)
        mult_z = tf.keras.layers.Concatenate(axis=1)([mult_z for _ in range(self.n_transforms)])

        mult_y = tf.keras.layers.TimeDistributed(encoder_y_mean)(mult_input_layer)

        # Create encoder from z and y
        encoder = tf.keras.models.Model(mult_input_layer, [mult_z, mult_y])
        return encoder

    def set_decoder(self) -> tf.keras.Model:
        """Build the decoder over a stack of concatenated latent codes.

        Returns:
            Keras model mapping ``(batch, n_transforms, l_dim + m_dim)`` to the
            per-image outputs of ``decoder_backbone``.
        """
        input_decoder = tf.keras.layers.Input((self.l_dim + self.m_dim,))
        # Project the latent code to the width expected by the decoder backbone.
        x = tf.keras.layers.Dense(self.decoder_backbone.inputs[0].shape[-1])(input_decoder)
        single_decoder = tf.keras.Model(input_decoder, self.decoder_backbone(x))

        # Pass multiple codes to decoder
        mult_input_layer = tf.keras.layers.Input((self.n_transforms, self.l_dim + self.m_dim))
        x = tf.keras.layers.TimeDistributed(single_decoder)(mult_input_layer)
        decoder = tf.keras.Model(mult_input_layer, x)
        return decoder

    @property
    def metrics(self):
        """Trackers reset by Keras at the start of each epoch."""
        list_metrics = [
            self.total_loss_tracker,
            self.reconstruction_loss_tracker
        ]
        if self.use_tl:
            list_metrics.append(self.tl_loss_tracker)
        return list_metrics

    def _encode_decode(self, inputs, training=False):
        """Encode a stack of images and decode it back.

        Returns:
            Tuple ``(mult_z, mult_y, reconstruction)``.
        """
        mult_z, mult_y = self.encoder(inputs, training=training)
        # Per-image latent = [invariant code, transformation code]; tf.concat avoids
        # building a new Concatenate layer on every call.
        z = tf.concat([mult_z, mult_y], axis=-1)
        reconstruction = self.decoder(z, training=training)
        return mult_z, mult_y, reconstruction

    @staticmethod
    def _reconstruction_loss(target, reconstruction):
        """Binary cross-entropy summed over each sample and averaged over the batch."""
        # binary_crossentropy already reduces the last (channel) axis. Sum over every
        # remaining NON-batch axis (transforms and spatial dims) only, so the outer mean
        # really averages over samples. Summing axis 0 as well would make the loss
        # scale with batch size.
        bce = tf.keras.losses.binary_crossentropy(target, reconstruction)
        sample_axes = list(range(1, bce.shape.rank))
        return tf.reduce_mean(tf.reduce_sum(bce, axis=sample_axes))

    def call(self, inputs, training=False):
        """Return the reconstruction of a ``(batch, n_transforms, *input_shape)`` stack."""
        _, _, reconstruction = self._encode_decode(inputs, training=training)
        return reconstruction

    def train_step(self, data):
        """Run one optimisation step and return the running loss averages.

        Args:
            data: Image stack when ``use_tl`` is False; ``(image_input, (image_output, label))``
                when ``use_tl`` is True.
        """
        if self.use_tl:
            # NOTE(review): image_output is unpacked but unused; the reconstruction
            # target is image_input. Confirm this is intended.
            image_input, (image_output, label) = data
        else:
            image_input = data

        with tf.GradientTape() as tape:
            # training=True so any BatchNorm/Dropout in the backbones runs in training mode.
            mult_z, _, reconstruction = self._encode_decode(image_input, training=True)
            reconstruction_loss = self._reconstruction_loss(image_input, reconstruction)
            total_loss = reconstruction_loss

            if self.use_tl:
                # Triplet loss on the stack-averaged invariant code.
                embeddings = tf.math.reduce_mean(mult_z, axis=1)
                triplet_loss = tf.reduce_mean(self._triplet_loss_fn(label, embeddings)) * self.class_alpha
                total_loss += triplet_loss
        grads = tape.gradient(total_loss, self.trainable_weights)
        self.optimizer.apply_gradients(zip(grads, self.trainable_weights))

        # Update metric trackers
        self.total_loss_tracker.update_state(total_loss)
        self.reconstruction_loss_tracker.update_state(reconstruction_loss)
        if self.use_tl:
            self.tl_loss_tracker.update_state(triplet_loss)

        output_dictionary = {
            "loss": self.total_loss_tracker.result(),
            "reconstruction_loss": self.reconstruction_loss_tracker.result()
        }
        if self.use_tl:
            output_dictionary.update({"tl_loss": self.tl_loss_tracker.result()})

        return output_dictionary

    def make_embedding_function(self, num_latent):
        """
        Creates an embedding function that retrieves the location parameter of a given latent space.
        Args:
            num_latent: index of the encoder output to return (0 for the invariant code
                mult_z, 1 for the transformation code mult_y)
        Returns:
            A function that receives input data with the appropriate shape and returns the embedded data
        """
        def embedding_function(input_data):
            prediction = self.encoder.predict(input_data)[num_latent]
            return prediction
        return embedding_function