import tensorflow as tf
import tensorflow_hub as hub
import os


def reload_encoder_decoder_dislib(result_path, image_shape, dataset, model_name, repetition):
    """
    Reload an encoder/decoder pair previously trained with disentanglement_lib
    (https://github.com/google-research/disentanglement_lib) from its TF-Hub export.

    The TF-Hub module is loaded from
    ``<result_path>/<dataset>_<model_name>_<repetition>/<model_name>/<model_name>/tfhub``.

    Args:
        result_path: directory under which all results are saved.
        image_shape: shape of a single input image, excluding the batch dimension
            (passed straight to ``tf.keras.layers.Input``).
        dataset: name of the dataset, e.g. arrow, pixel4, modelnet.
        model_name: name of the previously trained model, e.g. dip_vae, dip_vae2,
            factor_vae, tcvae, betavae, vae.
        repetition: index of the training repetition.

    Returns:
        A tuple ``(encoder, decoder)`` of ``tf.keras.Model`` objects:
        ``encoder`` maps images to the ``"mean"`` output of the module's
        ``gaussian_encoder`` signature; ``decoder`` maps latent vectors to the
        ``"images"`` output of the module's ``decoder`` signature.
    """
    run_dir = f"{dataset}_{model_name}_{repetition}"
    model_path = os.path.join(result_path, run_dir, model_name, model_name, "tfhub")

    # Encoder: images -> "mean" output of the "gaussian_encoder" signature.
    image_input = tf.keras.layers.Input(shape=image_shape)
    encoder_outputs = hub.KerasLayer(
        model_path, signature="gaussian_encoder", signature_outputs_as_dict=True
    )(image_input)
    encoder = tf.keras.models.Model(image_input, encoder_outputs["mean"])

    # Decoder: latent vector -> "images" output of the "decoder" signature.
    # The latent size is taken from the last axis of the encoder's "mean" output.
    # Keras documents `shape` as a tuple, so wrap the single dimension rather than
    # passing a bare int (which newer Keras versions reject).
    latent_dim = encoder_outputs["mean"].shape[-1]
    latent_input = tf.keras.layers.Input(shape=(latent_dim,))
    decoder_outputs = hub.KerasLayer(
        model_path, signature="decoder", signature_outputs_as_dict=True
    )(latent_input)
    decoder = tf.keras.models.Model(latent_input, decoder_outputs["images"])
    return encoder, decoder
