import tensorflow as tf
import tensorflow_addons as tfa
import os
import sys
from typing import List, Tuple

sys.path.append(os.path.dirname(os.path.abspath(__file__)))
from modules.latent_space.latentspace2 import LatentSpace, GaussianLatentSpace, HyperSphericalLatentSpace
from modules.vae.architectures import encoder_decoder_dense


class GAE(tf.keras.Model):
    """Autoencoder over stacks of transformed views with a multi-part latent code.

    Each input sample is a stack of ``n_transforms`` views of shape
    ``input_shape``. Every latent space gets its own location-parameter
    encoder, applied to each view independently. For latent spaces flagged in
    ``average_mask`` the per-view estimates are passed through
    ``latent_space.average``. The per-space outputs are concatenated along the
    last axis and decoded view by view.

    Args:
        latent_spaces: latent spaces making up the code; the sum of their
            ``latent_dim`` values is the decoder input width.
        input_shape: shape of a single view (no batch or transform axis).
        encoder_backbones: one backbone per latent space, or a single-element
            list whose backbone is reused for every latent space.
        decoder_backbone: model applied after a dense projection of the
            latent code; the width of its first input defines that projection.
        n_transforms: number of views per sample.
        average_mask: one flag per latent space; ``True`` routes the estimates
            of that space through ``latent_space.average``.
        **kwargs: forwarded to ``tf.keras.Model``.
    """

    def __init__(self,
                 latent_spaces: List[LatentSpace],
                 input_shape: Tuple[int, ...],
                 encoder_backbones: List[tf.keras.models.Model],
                 decoder_backbone: tf.keras.models.Model,
                 n_transforms: int,
                 average_mask: List[bool],
                 **kwargs):
        super(GAE, self).__init__(**kwargs)
        # Parameters
        self.input_shape_ = input_shape
        self.n_transforms = n_transforms

        # Hyperparameters
        self.latent_spaces = latent_spaces
        self.latent_dim = sum([latent_space.latent_dim for latent_space in self.latent_spaces])
        self.average_mask = average_mask
        assert len(self.latent_spaces) == len(
            self.average_mask), "Average mask length does not match latent space list length"

        # Neural network architectures (the setter needs self.latent_spaces, so it must be set first)
        self.encoder_backbones = encoder_backbones
        self.decoder_backbone = decoder_backbone
        # One location-parameter encoder per latent space
        self.lst_encoder_loc = self.set_lst_encoders()

        # Set encoder and decoder
        self.encoder = self.set_encoder()
        self.decoder = self.set_decoder()

        # Loss trackers
        self.total_loss_tracker = tf.keras.metrics.Mean(name="total_loss")
        self.reconstruction_loss_tracker = tf.keras.metrics.Mean(
            name="reconstruction_loss"
        )
        self.log_var_z_tracker = tf.keras.metrics.Mean(name="log_var_z")

    @property
    def encoder_backbones(self):
        """List of encoder backbones, one entry per latent space."""
        return self.__encoder_backbones

    @encoder_backbones.setter
    def encoder_backbones(self, encoder_backbones):
        """Store the backbones, repeating a single backbone for every latent space."""
        if len(encoder_backbones) == 1:
            # The same model object is repeated, so its weights are shared by all latent spaces.
            self.__encoder_backbones = encoder_backbones * len(self.latent_spaces)
        else:
            assert len(self.latent_spaces) == len(encoder_backbones), (f"Number of encoder backbones "
                                                                       f"{len(encoder_backbones)} is not the same as"
                                                                       f" the number of latent spaces"
                                                                       f" {len(self.latent_spaces)}")
            self.__encoder_backbones = encoder_backbones

    def set_lst_encoders(self):
        """Build one single-view encoder per latent space.

        Returns:
            A list of models mapping one view to the output of the latent
            space's ``loc_param_layer`` applied on top of its backbone.
        """
        lst_encoder_loc = []
        # All per-space encoders share the same symbolic single-view input
        input_encoder = tf.keras.layers.Input(self.input_shape_)

        for encoder_backbone, latent_space in zip(self.encoder_backbones, self.latent_spaces):
            h_enc = encoder_backbone(input_encoder)
            lst_encoder_loc.append(tf.keras.Model(input_encoder, latent_space.loc_param_layer(h_enc)))

        return lst_encoder_loc

    def set_encoder(self) -> tf.keras.Model:
        """Build the multi-view encoder.

        Returns:
            A model taking inputs of shape ``(n_transforms, *input_shape)``
            and returning one tensor per latent space.
        """
        # Define the input
        mult_input_layer = tf.keras.layers.Input((self.n_transforms, *self.input_shape_))

        lst_avg_loc = []
        for num_latent_space, latent_space in enumerate(self.latent_spaces):
            # Apply the single-view encoder to every view along the transform axis
            loc_param_estimate = tf.keras.layers.TimeDistributed(self.lst_encoder_loc[num_latent_space])(
                mult_input_layer)
            # Combine the per-view estimates only for the latent spaces selected by the mask
            if self.average_mask[num_latent_space]:
                avg_loc = latent_space.average(loc_param_estimate)
            else:
                avg_loc = loc_param_estimate
            lst_avg_loc.append(avg_loc)
        encoder = tf.keras.models.Model(mult_input_layer, lst_avg_loc)
        return encoder

    def set_decoder(self) -> tf.keras.Model:
        """Build the multi-view decoder.

        Returns:
            A model taking codes of shape ``(n_transforms, latent_dim)`` and
            decoding each of them with the same single-code decoder.
        """
        print(f"Total latent dimension is {self.latent_dim}")
        input_decoder = tf.keras.layers.Input((self.latent_dim,))
        # Project the code to the width expected by the decoder backbone's first input
        x = tf.keras.layers.Dense(self.decoder_backbone.inputs[0].shape[-1])(input_decoder)
        decoder = tf.keras.Model(input_decoder, self.decoder_backbone(x))

        # Pass multiple codes to decoder
        mult_input_layer = tf.keras.layers.Input((self.n_transforms, self.latent_dim))
        x = tf.keras.layers.TimeDistributed(decoder)(mult_input_layer)
        decoder = tf.keras.Model(mult_input_layer, x)
        return decoder

    @property
    def metrics(self):
        """Metric trackers reported by this model."""
        list_metrics = [
            self.total_loss_tracker,
            self.reconstruction_loss_tracker
        ]
        return list_metrics

    def call(self, inputs, training=False, **kwargs):
        """Encode a batch of view stacks and return their reconstructions."""
        loc_parameter_estimates = self.encoder(inputs, training=training)
        z = tf.keras.layers.Concatenate(-1)(loc_parameter_estimates)
        reconstruction = self.decoder(z, training=training)
        return reconstruction

    def train_step(self, data):
        """Run one optimisation step on the reconstruction loss.

        Args:
            data: batch of view stacks; used both as input and as target.

        Returns:
            Dictionary with the running means of the total and
            reconstruction losses.
        """
        image_input = data
        with tf.GradientTape() as tape:
            # training=True so that training-dependent layers in the backbones
            # (dropout, batch normalisation, ...) run in training mode.
            loc_parameter_estimates = self.encoder(image_input, training=True)
            z = tf.keras.layers.Concatenate(-1)(loc_parameter_estimates)
            reconstruction = self.decoder(z, training=True)

            # binary_crossentropy already averages over the last axis, so its
            # result has one axis fewer than the images.
            pixel_bce = tf.keras.losses.binary_crossentropy(image_input, reconstruction)
            # Sum over every axis except the batch axis, then average over the
            # batch, so the loss does not scale with the batch size.
            per_sample_bce = tf.reduce_sum(pixel_bce, axis=list(range(1, pixel_bce.shape.rank)))
            reconstruction_loss = tf.reduce_mean(per_sample_bce)

            total_loss = reconstruction_loss

        grads = tape.gradient(total_loss, self.trainable_weights)
        self.optimizer.apply_gradients(zip(grads, self.trainable_weights))

        # Update metric trackers
        self.total_loss_tracker.update_state(total_loss)
        self.reconstruction_loss_tracker.update_state(reconstruction_loss)

        output_dictionary = {
            "loss": self.total_loss_tracker.result(),
            "reconstruction_loss": self.reconstruction_loss_tracker.result(),
        }

        return output_dictionary

    def make_embedding_function(self, num_latent):
        """
        Creates an embedding function that retrieves the encoder output of a given latent space
        Args:
            num_latent: index of the latent space whose encoder output is retrieved
        Returns:
            A function which receives input data with the appropriate shape and returns the embedded data
        """

        def embedding_function(input_data):
            # The encoder returns one output per latent space; keep the requested one
            prediction = self.encoder.predict(input_data)[num_latent]
            return prediction

        return embedding_function

    def data_transformer(self, dictionary_):
        """Extract the model input (the ``"image"`` entry) from a dataset dictionary."""
        return dictionary_["image"]


class GAEProto(GAE):
    """GAE variant with an additional penalty on the first latent space.

    Besides the reconstruction loss, training penalises the squared
    difference between the per-view estimates of latent space 0 and the
    output of the main encoder for that latent space (which has gone through
    ``latent_space.average`` when ``average_mask[0]`` is true).

    Args:
        latent_spaces: latent spaces making up the code.
        input_shape: shape of a single view (no batch or transform axis).
        encoder_backbones: one backbone per latent space, or a single shared one.
        decoder_backbone: backbone used to decode the latent code.
        n_transforms: number of views per sample.
        average_mask: one flag per latent space selecting which are averaged.
        class_alpha: weight of the additional penalty in the total loss.
        **kwargs: forwarded to ``tf.keras.Model``.
    """

    def __init__(self,
                 latent_spaces: List[LatentSpace],
                 input_shape: Tuple[int, ...],
                 encoder_backbones: List[tf.keras.models.Model],
                 decoder_backbone: tf.keras.models.Model,
                 n_transforms: int,
                 average_mask: List[bool],
                 class_alpha: float = 1.0,
                 **kwargs):
        super().__init__(latent_spaces, input_shape, encoder_backbones, decoder_backbone, n_transforms, average_mask,
                         **kwargs)
        # Attributes are attached only once the Keras model has been initialised.
        self.class_alpha = class_alpha
        self.proto_loss_tracker = tf.keras.metrics.Mean(name="proto_loss")
        self.encoder_proto_embeddings = self.set_encoder_proto_embeddings()

    @property
    def metrics(self):
        """Metric trackers reported by this model."""
        list_metrics = [
            self.total_loss_tracker,
            self.reconstruction_loss_tracker,
            self.proto_loss_tracker
        ]
        return list_metrics

    def set_encoder_proto_embeddings(self) -> tf.keras.Model:
        """Build a model returning the per-view estimates of latent space 0.

        Returns:
            A model taking inputs of shape ``(n_transforms, *input_shape)``
            and applying the first single-view encoder to every view,
            without any averaging.
        """
        # Define the input
        mult_input_layer = tf.keras.layers.Input((self.n_transforms, *self.input_shape_))

        # Assume that location encoder 0 is the one to be averaged
        loc_param_estimate = tf.keras.layers.TimeDistributed(self.lst_encoder_loc[0])(
            mult_input_layer)
        encoder = tf.keras.models.Model(mult_input_layer, loc_param_estimate)
        return encoder

    def train_step(self, data):
        """Run one optimisation step on reconstruction plus the weighted penalty.

        Args:
            data: batch of view stacks; used both as input and as target.

        Returns:
            Dictionary with the running means of the total, reconstruction
            and penalty losses.
        """
        image_input = data
        with tf.GradientTape() as tape:
            # training=True so that training-dependent layers in the backbones
            # (dropout, batch normalisation, ...) run in training mode.
            loc_parameter_estimates = self.encoder(image_input, training=True)
            loc_param_estimate = self.encoder_proto_embeddings(image_input, training=True)

            z = tf.keras.layers.Concatenate(-1)(loc_parameter_estimates)
            reconstruction = self.decoder(z, training=True)

            # binary_crossentropy already averages over the last axis, so its
            # result has one axis fewer than the images.
            pixel_bce = tf.keras.losses.binary_crossentropy(image_input, reconstruction)
            # Sum over every axis except the batch axis, then average over the
            # batch, so the loss does not scale with the batch size.
            per_sample_bce = tf.reduce_sum(pixel_bce, axis=list(range(1, pixel_bce.shape.rank)))
            reconstruction_loss = tf.reduce_mean(per_sample_bce)

            # Sum of squares instead of tf.norm(...) ** 2: the gradient of the
            # norm is NaN when the difference is exactly zero (e.g. when
            # latent space 0 is not averaged). Reduced like the reconstruction
            # loss: summed per sample, averaged over the batch.
            squared_diff = tf.square(loc_param_estimate - loc_parameter_estimates[0])
            per_sample_sq = tf.reduce_sum(squared_diff, axis=list(range(1, squared_diff.shape.rank)))
            proto_loss = self.class_alpha * tf.reduce_mean(per_sample_sq)

            total_loss = reconstruction_loss + proto_loss

        grads = tape.gradient(total_loss, self.trainable_weights)
        self.optimizer.apply_gradients(zip(grads, self.trainable_weights))

        # Update metric trackers
        self.total_loss_tracker.update_state(total_loss)
        self.reconstruction_loss_tracker.update_state(reconstruction_loss)
        self.proto_loss_tracker.update_state(proto_loss)

        output_dictionary = {
            "loss": self.total_loss_tracker.result(),
            "reconstruction_loss": self.reconstruction_loss_tracker.result(),
            "proto_loss": self.proto_loss_tracker.result()
        }

        return output_dictionary


class GAEClass(GAE):
    """GAE variant with a softmax classifier on top of the first latent space.

    Besides the reconstruction loss, training adds a categorical
    cross-entropy between the provided one-hot labels and the classifier
    output computed from latent space 0.

    Args:
        latent_spaces: latent spaces making up the code; the first one feeds
            the classifier.
        input_shape: shape of a single view (no batch or transform axis).
        encoder_backbones: one backbone per latent space, or a single shared one.
        decoder_backbone: backbone used to decode the latent code.
        num_classes: number of classifier outputs.
        n_transforms: number of views per sample.
        average_mask: one flag per latent space selecting which are averaged.
        class_alpha: weight of the classification loss in the total loss.
        **kwargs: forwarded to ``tf.keras.Model``.
    """

    def __init__(self,
                 latent_spaces: List[LatentSpace],
                 input_shape: Tuple[int, ...],
                 encoder_backbones: List[tf.keras.models.Model],
                 decoder_backbone: tf.keras.models.Model,
                 num_classes: int,
                 n_transforms: int,
                 average_mask: List[bool],
                 class_alpha: float = 1.0,
                 **kwargs):
        super().__init__(latent_spaces, input_shape, encoder_backbones, decoder_backbone, n_transforms, average_mask,
                         **kwargs)
        # Attributes are attached only once the Keras model has been initialised.
        self.class_alpha = class_alpha
        self.num_classes = num_classes
        self.class_loss_tracker = tf.keras.metrics.Mean(name="class_loss")
        self.classifier = self.set_classifier()

    @property
    def metrics(self):
        """Metric trackers reported by this model."""
        list_metrics = [
            self.total_loss_tracker,
            self.reconstruction_loss_tracker,
            self.class_loss_tracker
        ]
        return list_metrics

    def set_classifier(self) -> tf.keras.Model:
        """
        Set up the classifier. The first latent space in the list is assumed to encode the class of the object.
        Returns:
            A model mapping a vector of size ``latent_spaces[0].latent_dim`` to ``num_classes`` softmax outputs
        """
        input_classifier = tf.keras.Input((self.latent_spaces[0].latent_dim,))
        output_classifier = tf.keras.layers.Dense(self.num_classes, activation="softmax")(input_classifier)
        return tf.keras.Model(input_classifier, output_classifier)

    def train_step(self, data):
        """Run one optimisation step on reconstruction plus weighted classification loss.

        Args:
            data: tuple ``(image_input, (image_output, label))``; only
                ``image_input`` and ``label`` are used.

        Returns:
            Dictionary with the running means of the total, reconstruction
            and classification losses.
        """
        # The second image entry is not used; the input itself is the reconstruction target.
        image_input, (_, label) = data
        with tf.GradientTape() as tape:
            # training=True so that training-dependent layers in the backbones
            # (dropout, batch normalisation, ...) run in training mode.
            loc_parameter_estimates = self.encoder(image_input, training=True)
            # Assume the first latent space is the averaged one; collapse axis 1
            # (presumably the transform axis - TODO confirm) to one vector per sample.
            z_reduced = tf.reduce_mean(loc_parameter_estimates[0], axis=1)
            classification_output = self.classifier(z_reduced, training=True)
            z = tf.keras.layers.Concatenate(-1)(loc_parameter_estimates)

            reconstruction = self.decoder(z, training=True)

            # binary_crossentropy already averages over the last axis, so its
            # result has one axis fewer than the images.
            pixel_bce = tf.keras.losses.binary_crossentropy(image_input, reconstruction)
            # Sum over every axis except the batch axis, then average over the
            # batch, so the loss does not scale with the batch size.
            per_sample_bce = tf.reduce_sum(pixel_bce, axis=list(range(1, pixel_bce.shape.rank)))
            reconstruction_loss = tf.reduce_mean(per_sample_bce)

            # Estimate the classification loss
            classification_loss = tf.reduce_mean(
                tf.keras.losses.categorical_crossentropy(label, classification_output)) * self.class_alpha

            total_loss = reconstruction_loss + classification_loss

        grads = tape.gradient(total_loss, self.trainable_weights)
        self.optimizer.apply_gradients(zip(grads, self.trainable_weights))

        # Update metric trackers
        self.total_loss_tracker.update_state(total_loss)
        self.reconstruction_loss_tracker.update_state(reconstruction_loss)
        self.class_loss_tracker.update_state(classification_loss)

        output_dictionary = {
            "loss": self.total_loss_tracker.result(),
            "reconstruction_loss": self.reconstruction_loss_tracker.result(),
            "classification_loss": self.class_loss_tracker.result(),
        }

        return output_dictionary

    def data_transformer(self, dictionary_):
        """Map a dataset dictionary to ``(image, (image, one_hot_label))``.

        The label is taken from index 0 of ``dictionary_["label"]`` and
        one-hot encoded with ``num_classes`` entries.
        """
        return dictionary_["image"], (dictionary_["image"], tf.one_hot(dictionary_["label"][0], depth=self.num_classes))


class GAETL(GAE):
    """GAE variant with a semi-hard triplet loss on the first latent space.

    Besides the reconstruction loss, training adds
    ``tfa.losses.TripletSemiHardLoss`` between the provided labels and the
    per-sample mean of the encoder output for latent space 0.

    Args:
        latent_spaces: latent spaces making up the code; the first one is
            used as the object identifier.
        input_shape: shape of a single view (no batch or transform axis).
        encoder_backbones: one backbone per latent space, or a single shared one.
        decoder_backbone: backbone used to decode the latent code.
        n_transforms: number of views per sample.
        average_mask: one flag per latent space selecting which are averaged.
        class_alpha: weight of the triplet loss in the total loss.
        **kwargs: forwarded to ``tf.keras.Model``.
    """

    # NOTE(review): train_step unpacks ``(image, (image, label))`` but this class
    # inherits GAE.data_transformer, which returns only the image - confirm the
    # data pipeline supplies labels some other way.

    def __init__(self,
                 latent_spaces: List[LatentSpace],
                 input_shape: Tuple[int, ...],
                 encoder_backbones: List[tf.keras.models.Model],
                 decoder_backbone: tf.keras.models.Model,
                 n_transforms: int,
                 average_mask: List[bool],
                 class_alpha: float = 1.0,
                 **kwargs):
        super().__init__(latent_spaces, input_shape, encoder_backbones, decoder_backbone, n_transforms, average_mask,
                         **kwargs)
        # Attributes are attached only once the Keras model has been initialised.
        self.class_alpha = class_alpha
        self.tl_loss_tracker = tf.keras.metrics.Mean(name="triplet_loss")

    @property
    def metrics(self):
        """Metric trackers reported by this model."""
        list_metrics = [
            self.total_loss_tracker,
            self.reconstruction_loss_tracker,
            self.tl_loss_tracker
        ]
        return list_metrics

    def train_step(self, data):
        """Run one optimisation step on reconstruction plus weighted triplet loss.

        Args:
            data: tuple ``(image_input, (image_output, label))``; only
                ``image_input`` and ``label`` are used.

        Returns:
            Dictionary with the running means of the total, reconstruction
            and triplet losses.
        """
        # The second image entry is not used; the input itself is the reconstruction target.
        image_input, (_, label) = data
        with tf.GradientTape() as tape:
            # training=True so that training-dependent layers in the backbones
            # (dropout, batch normalisation, ...) run in training mode.
            loc_parameter_estimates = self.encoder(image_input, training=True)
            z = tf.keras.layers.Concatenate(-1)(loc_parameter_estimates)
            reconstruction = self.decoder(z, training=True)

            # binary_crossentropy already averages over the last axis, so its
            # result has one axis fewer than the images.
            pixel_bce = tf.keras.losses.binary_crossentropy(image_input, reconstruction)
            # Sum over every axis except the batch axis, then average over the
            # batch, so the loss does not scale with the batch size.
            per_sample_bce = tf.reduce_sum(pixel_bce, axis=list(range(1, pixel_bce.shape.rank)))
            reconstruction_loss = tf.reduce_mean(per_sample_bce)

            # Estimate the object identifier by averaging the location parameters along axis 1
            # (presumably the transform axis - TODO confirm)
            object_identifier = tf.reduce_mean(loc_parameter_estimates[0], axis=1)
            # Estimate the triplet loss
            triplet_loss = tf.reduce_mean(tfa.losses.TripletSemiHardLoss()(label, object_identifier)) * self.class_alpha

            total_loss = reconstruction_loss + triplet_loss

        grads = tape.gradient(total_loss, self.trainable_weights)
        self.optimizer.apply_gradients(zip(grads, self.trainable_weights))

        # Update metric trackers
        self.total_loss_tracker.update_state(total_loss)
        self.reconstruction_loss_tracker.update_state(reconstruction_loss)
        self.tl_loss_tracker.update_state(triplet_loss)

        output_dictionary = {
            "loss": self.total_loss_tracker.result(),
            "reconstruction_loss": self.reconstruction_loss_tracker.result(),
            "tl_loss": self.tl_loss_tracker.result(),
        }

        return output_dictionary


if __name__ == "__main__":
    # Example scaffolding: builds sample constructor arguments for the models
    # above. Nothing is instantiated or trained in the visible part of the script.
    import numpy as np

    input_shape = [64, 64, 3]
    latent_spaces = [
        GaussianLatentSpace(dim=10),
        HyperSphericalLatentSpace(dim=1),
    ]
    # Dummy batch: 10 samples, each a stack of 5 views of shape input_shape
    images = np.zeros((10, 5, *input_shape))
    encoder, decoder = encoder_decoder_dense(input_shape=input_shape)
    # A single encoder backbone is given, so it is shared by both latent spaces
    init_params = dict(
        latent_spaces=latent_spaces,
        input_shape=input_shape,
        encoder_backbones=[encoder],
        decoder_backbone=decoder,
        n_transforms=5,
        average_mask=[True, False],
    )



