import os
import sys
from typing import List, Optional, Tuple

import tensorflow as tf
import tensorflow_addons as tfa

sys.path.append(os.path.dirname(os.path.abspath(__file__)))
from modules.latent_space.latentspace2 import LatentSpace, GaussianLatentSpace, HyperSphericalLatentSpace
from modules.vae.architectures import encoder_decoder_dense


class GVAE(tf.keras.Model):
    """Variational autoencoder that encodes groups of ``n_transforms`` inputs.

    Every latent space in ``latent_spaces`` gets its own location and scale
    encoder. Both are applied to each of the ``n_transforms`` inputs of a
    sample through ``TimeDistributed``; for latent spaces flagged in
    ``average_mask`` the resulting parameters are then combined with
    ``latent_space.average_parameters``. The samples drawn from all latent
    spaces are concatenated along the last axis and decoded per input.
    """

    def __init__(self,
                 latent_spaces: List[LatentSpace],
                 input_shape: Tuple[int, ...],
                 encoder_backbones: List[tf.keras.models.Model],
                 decoder_backbone: tf.keras.models.Model,
                 n_transforms: int,
                 average_mask: Optional[List[bool]] = None,
                 kl_weight: float = 1,
                 **kwargs):
        """Build the encoders, the decoder and the loss trackers.

        Args:
            latent_spaces: latent spaces whose samples are concatenated to form the code.
            input_shape: shape of a single input, without the batch and ``n_transforms`` axes.
            encoder_backbones: one backbone per latent space, or a single backbone that is
                reused for every latent space.
            decoder_backbone: model applied after a dense projection of the latent code.
            n_transforms: number of inputs grouped in one sample.
            average_mask: one flag per latent space; ``True`` means the parameters predicted
                for the ``n_transforms`` inputs are passed to ``average_parameters``.
                The argument is mandatory even though it has a default for signature
                compatibility.
            kl_weight: factor applied to the KL term of the training loss.
            **kwargs: forwarded to ``tf.keras.Model``.

        Raises:
            TypeError: if ``average_mask`` is not provided.
            AssertionError: if ``average_mask`` or ``encoder_backbones`` do not match the
                number of latent spaces.
        """
        super(GVAE, self).__init__(**kwargs)
        if average_mask is None:
            raise TypeError("GVAE requires `average_mask`: a list with one boolean per latent space")

        # Parameters
        self.input_shape_ = input_shape
        self.n_transforms = n_transforms

        # Hyperparameters
        self.latent_spaces = latent_spaces
        self.latent_dim = sum(latent_space.latent_dim for latent_space in self.latent_spaces)
        self.average_mask = average_mask
        assert len(self.latent_spaces) == len(self.average_mask), "Average mask length does not match latent space list length"
        self.kl_weight = kl_weight

        # Neural network architectures.
        # The encoder_backbones setter reads self.latent_spaces, so it has to run after it is set.
        self.encoder_backbones = encoder_backbones
        self.decoder_backbone = decoder_backbone
        # Latent parameters
        self.lst_encoder_loc, self.lst_encoder_scale = self.set_lst_parameter_encoders()

        # Set encoder and decoder
        self.encoder = self.set_encoder()
        self.decoder = self.set_decoder()

        # Loss trackers
        self.total_loss_tracker = tf.keras.metrics.Mean(name="total_loss")
        self.kl_weight_tracker = tf.keras.metrics.Mean(name="kl_weight")
        self.reconstruction_loss_tracker = tf.keras.metrics.Mean(
            name="reconstruction_loss"
        )
        self.kl_loss_tracker = tf.keras.metrics.Mean(name="kl_loss")
        # Created for completeness; nothing in this class updates it.
        self.log_var_z_tracker = tf.keras.metrics.Mean(name="log_var_z")

    @property
    def encoder_backbones(self):
        """Encoder backbones, one entry per latent space."""
        return self.__encoder_backbones

    @encoder_backbones.setter
    def encoder_backbones(self, encoder_backbones):
        """Store the backbones, repeating a single backbone for every latent space.

        When one backbone is given, the same model object is reused for all
        latent spaces (its weights are therefore shared).
        """
        if len(encoder_backbones) == 1:
            self.__encoder_backbones = encoder_backbones * len(self.latent_spaces)
        else:
            assert len(self.latent_spaces) == len(encoder_backbones), (f"Number of encoder backbones "
                                                                       f"{len(encoder_backbones)} is not the same as"
                                                                       f" the number of latent spaces"
                                                                       f" {len(self.latent_spaces)}")
            self.__encoder_backbones = encoder_backbones

    def set_lst_parameter_encoders(self):
        """Create one location model and one scale model per latent space.

        Returns:
            Two lists ``(location_models, scale_models)``; entry ``i`` maps a single
            input of shape ``input_shape`` to the parameters of latent space ``i``.
        """
        lst_encoder_loc = []
        lst_encoder_scale = []
        # All parameter models share the same symbolic input
        input_encoder = tf.keras.layers.Input(self.input_shape_)

        for encoder_backbone, latent_space in zip(self.encoder_backbones, self.latent_spaces):
            h_enc = encoder_backbone(input_encoder)
            lst_encoder_loc.append(tf.keras.Model(input_encoder, latent_space.loc_param_layer(h_enc)))
            lst_encoder_scale.append(tf.keras.Model(input_encoder, latent_space.scale_param_layer(h_enc)))

        return lst_encoder_loc, lst_encoder_scale

    def set_encoder(self) -> tf.keras.Model:
        """Build the encoder applied to a whole group of inputs.

        Returns:
            A model taking inputs of shape ``(n_transforms, *input_shape)`` (plus batch axis)
            and returning ``[locations, scales, samples]``, each a list with one entry per
            latent space.
        """
        # Define the input
        mult_input_layer = tf.keras.layers.Input((self.n_transforms, *self.input_shape_))

        lst_sample = []
        lst_avg_loc = []
        lst_scale = []
        for num_latent_space, latent_space in enumerate(self.latent_spaces):
            # Estimate parameter tensors with shape (batch_size, n_transforms, param_shape)
            loc_param_estimate = tf.keras.layers.TimeDistributed(self.lst_encoder_loc[num_latent_space])(
                mult_input_layer)
            scale_param_estimate = tf.keras.layers.TimeDistributed(self.lst_encoder_scale[num_latent_space])(
                mult_input_layer)
            # Average the predicted distributions only where the mask asks for it
            if self.average_mask[num_latent_space]:
                avg_loc, avg_scale = latent_space.average_parameters(loc_param_estimate, scale_param_estimate)
            else:
                avg_loc = loc_param_estimate
                avg_scale = scale_param_estimate
            lst_avg_loc.append(avg_loc)
            lst_scale.append(avg_scale)
            lst_sample.append(latent_space.sampling([avg_loc, avg_scale]))

        # The encoder exposes the parameters as well as the samples
        encoder = tf.keras.models.Model(mult_input_layer, [lst_avg_loc, lst_scale, lst_sample])
        return encoder

    def set_decoder(self) -> tf.keras.Model:
        """Build the decoder applied to every code of a group.

        A dense layer first projects the concatenated latent code to the size of
        the last axis of the decoder backbone's first input.

        Returns:
            A model taking codes of shape ``(n_transforms, latent_dim)`` (plus batch axis)
            and applying the single-code decoder to each of them.
        """
        print(f"Total latent dimension is {self.latent_dim}")
        input_decoder = tf.keras.layers.Input((self.latent_dim,))
        x = tf.keras.layers.Dense(self.decoder_backbone.inputs[0].shape[-1])(input_decoder)
        decoder = tf.keras.Model(input_decoder, self.decoder_backbone(x))

        # Pass multiple codes to decoder
        mult_input_layer = tf.keras.layers.Input((self.n_transforms, self.latent_dim))
        x = tf.keras.layers.TimeDistributed(decoder)(mult_input_layer)
        decoder = tf.keras.Model(mult_input_layer, x)
        return decoder

    @property
    def metrics(self):
        """Trackers exposed to Keras so that they are reset between epochs."""
        list_metrics = [
            self.total_loss_tracker,
            self.reconstruction_loss_tracker,
            self.kl_loss_tracker,
            # Updated in train_step, so it has to be reset together with the losses.
            self.kl_weight_tracker,
        ]
        return list_metrics

    def call(self, inputs, training=False, **kwargs):
        """Encode ``inputs``, sample every latent space and decode the concatenated code."""
        loc_parameter_estimates, scale_parameter_estimates, samples = self.encoder(inputs)
        # tf.concat avoids creating a new layer per call and also accepts a single latent space
        z = tf.concat(samples, axis=-1)
        reconstruction = self.decoder(z)
        return reconstruction

    def train_step(self, data):
        """Run one optimisation step on a batch of inputs.

        Args:
            data: batch of inputs; it is used both as encoder input and as
                reconstruction target.

        Returns:
            Dictionary with the running means of the total loss, the reconstruction
            loss, the (unweighted) KL loss and the KL weight.
        """
        with tf.GradientTape() as tape:
            image_input = data
            loc_parameter_estimates, scale_parameter_estimates, samples = self.encoder(image_input)
            z = tf.concat(samples, axis=-1)
            reconstruction = self.decoder(z)
            # TODO fix reconstruction
            # NOTE(review): binary_crossentropy already averages over the last axis, so for a
            # three-dimensional input_shape the four summed axes appear to include the batch
            # axis - confirm this is the intended reduction.
            reconstruction_loss = tf.reduce_mean(
                tf.reduce_sum(
                    tf.keras.losses.binary_crossentropy(image_input, reconstruction), axis=(-1, -2, -3, -4)
                )
            )

            kl_losses = []
            for num_latent_space, latent_space in enumerate(self.latent_spaces):
                kl_losses.append(latent_space.kl_loss(
                    [loc_parameter_estimates[num_latent_space], scale_parameter_estimates[num_latent_space]]))
            kl_loss = tf.add_n(kl_losses)
            # Sums over the first axis (presumably the batch axis - TODO confirm) and averages the rest
            kl_loss = tf.reduce_sum(kl_loss, axis=0)
            kl_loss = tf.reduce_mean(kl_loss)
            # kl_weight was tracked and reported but never applied; weight the KL term with it
            total_loss = reconstruction_loss + self.kl_weight * kl_loss

        grads = tape.gradient(total_loss, self.trainable_weights)
        self.optimizer.apply_gradients(zip(grads, self.trainable_weights))

        # Update metric trackers
        self.total_loss_tracker.update_state(total_loss)
        self.kl_weight_tracker.update_state(self.kl_weight)
        self.reconstruction_loss_tracker.update_state(reconstruction_loss)
        self.kl_loss_tracker.update_state(kl_loss)

        output_dictionary = {
            "loss": self.total_loss_tracker.result(),
            "reconstruction_loss": self.reconstruction_loss_tracker.result(),
            "kl_loss": self.kl_loss_tracker.result(),
            "kl_weight": self.kl_weight_tracker.result()
        }

        return output_dictionary

    def make_embedding_function(self, num_latent):
        """
        Creates an embedding function that retrieves the location parameter of a given latent space
        Args:
            num_latent: index of the latent space whose location parameter is retrieved
        Returns:
            A function which receives input data with the appropriate shape and returns the embedded data
        """
        def embedding_function(input_data):
            # The first encoder output is the list of location parameters
            prediction = self.encoder.predict(input_data)[0][num_latent]
            return prediction
        return embedding_function

    def data_transformer(self, dictionary_):
        """Return the ``"image"`` entry of a dataset dictionary."""
        return dictionary_["image"]


class GVAEClass(GVAE):
    """GVAE with an auxiliary softmax classifier on the first latent space."""

    def __init__(self,
                 latent_spaces: List[LatentSpace],
                 input_shape: Tuple[int, ...],
                 encoder_backbones: List[tf.keras.models.Model],
                 decoder_backbone: tf.keras.models.Model,
                 num_classes: int,
                 n_transforms: int,
                 average_mask: Optional[List[bool]] = None,
                 kl_weight: float = 1,
                 class_alpha: float = 1.0,
                 **kwargs):
        """Build the GVAE and the classifier head.

        Args:
            latent_spaces: latent spaces; the first one feeds the classifier.
            input_shape: shape of a single input.
            encoder_backbones: one backbone per latent space, or a single shared one.
            decoder_backbone: backbone of the decoder.
            num_classes: number of outputs of the softmax classifier.
            n_transforms: number of inputs grouped in one sample.
            average_mask: one flag per latent space, forwarded to ``GVAE``; it must be
                provided even though it has a default for signature compatibility.
            kl_weight: factor applied to the KL term of the training loss.
            class_alpha: factor applied to the classification loss.
            **kwargs: forwarded to ``GVAE``.
        """
        # Initialise the Keras model first; attributes are attached afterwards.
        super().__init__(latent_spaces, input_shape, encoder_backbones, decoder_backbone, n_transforms, average_mask,
                         kl_weight, **kwargs)
        self.class_alpha = class_alpha
        self.num_classes = num_classes
        self.class_loss_tracker = tf.keras.metrics.Mean(name="class_loss")
        self.classifier = self.set_classifier()

    @property
    def metrics(self):
        """Trackers exposed to Keras so that they are reset between epochs."""
        list_metrics = [
            self.total_loss_tracker,
            self.reconstruction_loss_tracker,
            self.kl_loss_tracker,
            self.class_loss_tracker,
            # Updated in train_step, so it has to be reset together with the losses.
            self.kl_weight_tracker,
        ]
        return list_metrics

    def set_classifier(self) -> tf.keras.Model:
        """
        Set up the classifier. It assumes that the first latent space in the list encodes the
        information about the class of the object.
        Returns:
            A model mapping a vector of size ``latent_spaces[0].latent_dim`` to ``num_classes``
            softmax probabilities.
        """
        input_classifier = tf.keras.Input((self.latent_spaces[0].latent_dim,))
        output_classifier = tf.keras.layers.Dense(self.num_classes, activation="softmax")(input_classifier)
        return tf.keras.Model(input_classifier, output_classifier)

    def train_step(self, data):
        """Run one optimisation step with reconstruction, KL and classification losses.

        Args:
            data: ``(image_input, (image_output, label))``. The reconstruction target is
                ``image_input``; ``image_output`` is not used. ``label`` is compared with the
                softmax output through categorical cross-entropy.

        Returns:
            Dictionary with the running means of the tracked losses and the KL weight.
        """
        with tf.GradientTape() as tape:
            image_input, (_image_output, label) = data
            loc_parameter_estimates, scale_parameter_estimates, samples = self.encoder(image_input)
            # Average the samples of the first latent space over axis 1 before classifying
            z_reduced = tf.reduce_mean(samples[0], axis=1)
            classification_output = self.classifier(z_reduced)
            # tf.concat avoids creating a new layer per step and also accepts a single latent space
            z = tf.concat(samples, axis=-1)

            reconstruction = self.decoder(z)
            # NOTE(review): binary_crossentropy already averages over the last axis, so for a
            # three-dimensional input_shape the four summed axes appear to include the batch
            # axis - confirm this is the intended reduction.
            reconstruction_loss = tf.reduce_mean(
                tf.reduce_sum(
                    tf.keras.losses.binary_crossentropy(image_input, reconstruction), axis=(-1, -2, -3, -4)
                )
            )

            kl_losses = []
            for num_latent_space, latent_space in enumerate(self.latent_spaces):
                kl_losses.append(latent_space.kl_loss(
                    [loc_parameter_estimates[num_latent_space], scale_parameter_estimates[num_latent_space]]))
            kl_loss = tf.add_n(kl_losses)
            kl_loss = tf.reduce_sum(kl_loss, axis=0)
            kl_loss = tf.reduce_mean(kl_loss)

            # Estimate the classification loss
            classification_loss = tf.reduce_mean(
                tf.keras.losses.categorical_crossentropy(label, classification_output)) * self.class_alpha

            # kl_weight was tracked and reported but never applied; weight the KL term with it
            total_loss = reconstruction_loss + self.kl_weight * kl_loss + classification_loss

        grads = tape.gradient(total_loss, self.trainable_weights)
        self.optimizer.apply_gradients(zip(grads, self.trainable_weights))

        # Update metric trackers
        self.total_loss_tracker.update_state(total_loss)
        self.kl_weight_tracker.update_state(self.kl_weight)
        self.reconstruction_loss_tracker.update_state(reconstruction_loss)
        self.kl_loss_tracker.update_state(kl_loss)
        self.class_loss_tracker.update_state(classification_loss)

        output_dictionary = {
            "loss": self.total_loss_tracker.result(),
            "reconstruction_loss": self.reconstruction_loss_tracker.result(),
            "kl_loss": self.kl_loss_tracker.result(),
            "kl_weight": self.kl_weight_tracker.result(),
            "classification_loss": self.class_loss_tracker.result(),
        }

        return output_dictionary



class GVAEProto(GVAE):
    """GVAE with a prototype loss and a classification loss on the first latent space.

    The prototype loss is the squared distance between the location predicted
    for each individual input by location encoder 0 and the location returned
    by the group encoder for latent space 0.
    """

    def __init__(self,
                 latent_spaces: List[LatentSpace],
                 input_shape: Tuple[int, ...],
                 encoder_backbones: List[tf.keras.models.Model],
                 decoder_backbone: tf.keras.models.Model,
                 num_classes: int,
                 n_transforms: int,
                 average_mask: Optional[List[bool]] = None,
                 kl_weight: float = 1,
                 class_alpha: float = 1.0,
                 **kwargs):
        """Build the GVAE and the classifier head used in ``train_step``.

        Args:
            latent_spaces: latent spaces; the first one feeds the classifier and the prototype loss.
            input_shape: shape of a single input.
            encoder_backbones: one backbone per latent space, or a single shared one.
            decoder_backbone: backbone of the decoder.
            num_classes: number of outputs of the softmax classifier.
            n_transforms: number of inputs grouped in one sample.
            average_mask: one flag per latent space, forwarded to ``GVAE``; it must be
                provided even though it has a default for signature compatibility.
            kl_weight: factor applied to the KL term of the training loss.
            class_alpha: factor applied to both the prototype and the classification loss.
            **kwargs: forwarded to ``GVAE``.
        """
        # Initialise the Keras model first; attributes are attached afterwards.
        super().__init__(latent_spaces, input_shape, encoder_backbones, decoder_backbone, n_transforms, average_mask,
                         kl_weight, **kwargs)
        self.class_alpha = class_alpha
        self.num_classes = num_classes
        self.kl_proto_loss_tracker = tf.keras.metrics.Mean(name="kl_proto_loss")
        # train_step calls self.classifier, which was never created before
        self.classifier = self.set_classifier()

    @property
    def metrics(self):
        """Trackers exposed to Keras so that they are reset between epochs."""
        list_metrics = [
            self.total_loss_tracker,
            self.reconstruction_loss_tracker,
            self.kl_loss_tracker,
            self.kl_proto_loss_tracker,
            # Updated in train_step, so it has to be reset together with the losses.
            self.kl_weight_tracker,
        ]
        return list_metrics

    def set_classifier(self) -> tf.keras.Model:
        """Build a softmax classifier on top of the first latent space.

        Returns:
            A model mapping a vector of size ``latent_spaces[0].latent_dim`` to ``num_classes``
            softmax probabilities.
        """
        input_classifier = tf.keras.Input((self.latent_spaces[0].latent_dim,))
        output_classifier = tf.keras.layers.Dense(self.num_classes, activation="softmax")(input_classifier)
        return tf.keras.Model(input_classifier, output_classifier)

    def train_step(self, data):
        """Run one optimisation step with reconstruction, KL, prototype and classification losses.

        Args:
            data: ``(image_input, (image_output, label))``. The reconstruction target is
                ``image_input``; ``image_output`` is not used. ``label`` is compared with the
                softmax output through categorical cross-entropy.

        Returns:
            Dictionary with the running means of the tracked losses and the KL weight.
        """
        with tf.GradientTape() as tape:
            image_input, (_image_output, label) = data
            loc_parameter_estimates, scale_parameter_estimates, samples = self.encoder(image_input)

            # Location predicted for every individual input by location encoder 0
            loc_param_estimate = tf.keras.layers.TimeDistributed(self.lst_encoder_loc[0])(
                image_input)

            # Average the samples of the first latent space over axis 1 before classifying
            z_reduced = tf.reduce_mean(samples[0], axis=1)
            classification_output = self.classifier(z_reduced)
            # tf.concat avoids creating a new layer per step and also accepts a single latent space
            z = tf.concat(samples, axis=-1)

            reconstruction = self.decoder(z)
            # NOTE(review): binary_crossentropy already averages over the last axis, so for a
            # three-dimensional input_shape the four summed axes appear to include the batch
            # axis - confirm this is the intended reduction.
            reconstruction_loss = tf.reduce_mean(
                tf.reduce_sum(
                    tf.keras.losses.binary_crossentropy(image_input, reconstruction), axis=(-1, -2, -3, -4)
                )
            )

            # Squared L2 distance over the last axis. The previous tf.norm(...) without an axis
            # produced a scalar, on which reduce_sum(axis=-1) is not valid.
            # NOTE(review): the encoder output of latent space 0 is used as the prototype so that
            # both terms belong to the same latent space (the code used index -1) - assumes it
            # keeps the n_transforms axis; confirm.
            proto_loss = self.class_alpha * tf.reduce_mean(
                tf.reduce_sum(tf.square(loc_param_estimate - loc_parameter_estimates[0]), axis=-1))

            kl_losses = []
            for num_latent_space, latent_space in enumerate(self.latent_spaces):
                kl_losses.append(latent_space.kl_loss(
                    [loc_parameter_estimates[num_latent_space], scale_parameter_estimates[num_latent_space]]))
            kl_loss = tf.add_n(kl_losses)
            kl_loss = tf.reduce_sum(kl_loss, axis=0)
            kl_loss = tf.reduce_mean(kl_loss)

            # Estimate the classification loss
            classification_loss = tf.reduce_mean(
                tf.keras.losses.categorical_crossentropy(label, classification_output)) * self.class_alpha

            # kl_weight was tracked and reported but never applied; weight the KL term with it
            total_loss = reconstruction_loss + self.kl_weight * kl_loss + classification_loss + proto_loss

        grads = tape.gradient(total_loss, self.trainable_weights)
        self.optimizer.apply_gradients(zip(grads, self.trainable_weights))

        # Update metric trackers
        self.total_loss_tracker.update_state(total_loss)
        self.kl_weight_tracker.update_state(self.kl_weight)
        self.reconstruction_loss_tracker.update_state(reconstruction_loss)
        self.kl_loss_tracker.update_state(kl_loss)
        self.kl_proto_loss_tracker.update_state(proto_loss)

        output_dictionary = {
            "loss": self.total_loss_tracker.result(),
            "reconstruction_loss": self.reconstruction_loss_tracker.result(),
            "kl_loss": self.kl_loss_tracker.result(),
            "kl_weight": self.kl_weight_tracker.result(),
            "kl_proto_loss": self.kl_proto_loss_tracker.result(),
        }

        return output_dictionary

    def data_transformer(self, dictionary_):
        """Return the ``"image"`` entry of a dataset dictionary.

        NOTE(review): train_step also unpacks a label, which this transformer
        does not return - verify how the training data is assembled.
        """
        return dictionary_["image"]



class GVAETL(GVAE):
    """GVAE with a semi-hard triplet loss on the location of the first latent space."""

    def __init__(self,
                 latent_spaces: List[LatentSpace],
                 input_shape: Tuple[int, ...],
                 encoder_backbones: List[tf.keras.models.Model],
                 decoder_backbone: tf.keras.models.Model,
                 n_transforms: int,
                 average_mask: Optional[List[bool]] = None,
                 kl_weight: float = 1,
                 class_alpha: float = 1.0,
                 **kwargs):
        """Build the GVAE and the triplet loss tracker.

        Args:
            latent_spaces: latent spaces; the location of the first one is used as object identifier.
            input_shape: shape of a single input.
            encoder_backbones: one backbone per latent space, or a single shared one.
            decoder_backbone: backbone of the decoder.
            n_transforms: number of inputs grouped in one sample.
            average_mask: one flag per latent space, forwarded to ``GVAE``; it must be
                provided even though it has a default for signature compatibility.
            kl_weight: factor applied to the KL term of the training loss.
            class_alpha: factor applied to the triplet loss.
            **kwargs: forwarded to ``GVAE``.
        """
        # Initialise the Keras model first; attributes are attached afterwards.
        super().__init__(latent_spaces, input_shape, encoder_backbones, decoder_backbone, n_transforms, average_mask,
                         kl_weight, **kwargs)
        self.class_alpha = class_alpha
        self.tl_loss_tracker = tf.keras.metrics.Mean(name="triplet_loss")

    @property
    def metrics(self):
        """Trackers exposed to Keras so that they are reset between epochs."""
        list_metrics = [
            self.total_loss_tracker,
            self.reconstruction_loss_tracker,
            self.kl_loss_tracker,
            self.tl_loss_tracker,
            # Updated in train_step, so it has to be reset together with the losses.
            self.kl_weight_tracker,
        ]
        return list_metrics

    def train_step(self, data):
        """Run one optimisation step with reconstruction, KL and triplet losses.

        Args:
            data: ``(image_input, (image_output, label))``. The reconstruction target is
                ``image_input``; ``image_output`` is not used. ``label`` is passed to
                ``TripletSemiHardLoss`` together with the object identifiers.

        Returns:
            Dictionary with the running means of the tracked losses and the KL weight.
        """
        with tf.GradientTape() as tape:
            image_input, (_image_output, label) = data
            loc_parameter_estimates, scale_parameter_estimates, samples = self.encoder(image_input)
            # tf.concat avoids creating a new layer per step and also accepts a single latent space
            z = tf.concat(samples, axis=-1)
            reconstruction = self.decoder(z)
            # TODO fix reconstruction
            # NOTE(review): binary_crossentropy already averages over the last axis, so for a
            # three-dimensional input_shape the four summed axes appear to include the batch
            # axis - confirm this is the intended reduction.
            reconstruction_loss = tf.reduce_mean(
                tf.reduce_sum(
                    tf.keras.losses.binary_crossentropy(image_input, reconstruction), axis=(-1, -2, -3, -4)
                )
            )

            kl_losses = []
            for num_latent_space, latent_space in enumerate(self.latent_spaces):
                kl_losses.append(latent_space.kl_loss(
                    [loc_parameter_estimates[num_latent_space], scale_parameter_estimates[num_latent_space]]))
            kl_loss = tf.add_n(kl_losses)
            kl_loss = tf.reduce_sum(kl_loss, axis=0)
            kl_loss = tf.reduce_mean(kl_loss)

            # Estimate the object identifier by averaging the location parameters along axis 1
            object_identifier = tf.reduce_mean(loc_parameter_estimates[0], axis=1)
            # Estimate the triplet loss
            triplet_loss = tf.reduce_mean(tfa.losses.TripletSemiHardLoss()(label, object_identifier)) * self.class_alpha

            # kl_weight was tracked and reported but never applied; weight the KL term with it
            total_loss = reconstruction_loss + self.kl_weight * kl_loss + triplet_loss

        grads = tape.gradient(total_loss, self.trainable_weights)
        self.optimizer.apply_gradients(zip(grads, self.trainable_weights))

        # Update metric trackers
        self.total_loss_tracker.update_state(total_loss)
        self.kl_weight_tracker.update_state(self.kl_weight)
        self.reconstruction_loss_tracker.update_state(reconstruction_loss)
        self.kl_loss_tracker.update_state(kl_loss)
        self.tl_loss_tracker.update_state(triplet_loss)

        output_dictionary = {
            "loss": self.total_loss_tracker.result(),
            "reconstruction_loss": self.reconstruction_loss_tracker.result(),
            "kl_loss": self.kl_loss_tracker.result(),
            "kl_weight": self.kl_weight_tracker.result(),
            "tl_loss": self.tl_loss_tracker.result(),
        }

        return output_dictionary

    def data_transformer(self, dictionary_):
        """Return the ``"image"`` and ``"label"`` entries of a dataset dictionary."""
        return dictionary_["image"], dictionary_["label"]

if __name__ == "__main__":
    # Smoke test: train each model for two epochs on all-zero images.
    import numpy as np
    latent_spaces = [GaussianLatentSpace(dim=10), HyperSphericalLatentSpace(dim=1)]
    input_shape = [64, 64, 3]
    # 10 samples, each a group of 5 inputs
    images = np.zeros([10, 5] + input_shape)
    encoder, decoder = encoder_decoder_dense(input_shape=input_shape)
    init_params = {"latent_spaces": latent_spaces,
                   "input_shape": input_shape,
                   "encoder_backbones": [encoder],
                   "decoder_backbone": decoder,
                   "n_transforms": 5,
                   "average_mask": [True, False], }
    ugvae = GVAE(**init_params)

    # `lr` is a deprecated alias; `learning_rate` is the supported argument name
    ugvae.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=1e-3))
    ugvae.fit(images, batch_size=2, epochs=2)

    ugvae = GVAETL(**init_params)
    ugvae.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=1e-3))
    # GVAETL.train_step unpacks (input, (output, label)), so targets must be supplied;
    # one dummy integer label per sample is enough for a smoke test
    labels = np.arange(images.shape[0])
    ugvae.fit(images, (images, labels), batch_size=2, epochs=2)

