import os
import sys

import tensorflow as tf
import tensorflow_hub as hub

# Locate the two source checkouts relative to the launch directory:
#   <workspace>/<project>/<x>/<y>     <- current working directory
#   <workspace>/<project>             <- PROJECT_PATH
#   <workspace>/disentanglement_lib   <- DIS_PROJECT_PATH
PROJECT_PATH = os.path.dirname(os.path.dirname(os.getcwd()))
DIS_PROJECT_PATH = os.path.join(
    os.path.dirname(PROJECT_PATH), "disentanglement_lib"
)
# Both checkouts must be on sys.path (in this order) before the project
# imports below can be resolved.
sys.path.extend([DIS_PROJECT_PATH, PROJECT_PATH])
from experiments.disentanglement_lib import (  # noqa: E402
    evaluate_metrics_dis_lib,
)
from disentanglement_lib.data.ground_truth.named_data import (  # noqa: E402
    get_named_ground_truth_data,
)
# Restrict the process to GPU 0.
# NOTE(review): this runs after `import tensorflow`; it only takes effect if
# CUDA has not been initialized yet — consider moving it above the TF import.
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

# Number of independently trained checkpoints per (dataset, model) pair; the
# loop below loads one exported encoder for each index in range(...).
trained_repetitions = 10
# Models to evaluate. A smaller subset used before: ["dip_vae", "dip_vae2"].
model_list = ["vae", "dip_vae", "dip_vae2", "betavae", "factor_vae", "tcvae"]
# Datasets to evaluate. Others used before: "arrow", "modelnet", "pixel4".
dataset_list = ["coil100"]
# Settings forwarded unchanged to evaluate_metrics_dis_lib.evaluate_metrics.
# NOTE(review): `repetitions` is presumably the number of metric-evaluation
# repeats (distinct from `trained_repetitions`) — confirm in that module.
repetitions = 10
num_train = 100
num_test = num_train // 2


# Directory containing the exported TF-Hub modules of every trained model.
MODELS_ROOT_PATH = "/home/disentanglement_lib/results_models_data"
# Directory under which the computed metrics are written, one folder per
# dataset / model / repetition.
ICML_RESULTS_PATH = "/home/ICML/results"


def _trained_model_path(dataset, model_name, repetition):
    """Return the TF-Hub export directory of one trained model repetition.

    Layout: <MODELS_ROOT_PATH>/<dataset>_<model>_<rep>/<model>/<model>/tfhub
    """
    run_dir = "_".join([dataset, model_name, str(repetition)])
    return os.path.join(MODELS_ROOT_PATH, run_dir, model_name, model_name, "tfhub")


def _build_representation_function(model_path, image_shape):
    """Wrap a hub module's ``gaussian_encoder`` signature in a Keras model.

    Args:
        model_path: TF-Hub export directory to load.
        image_shape: Per-example input shape passed to ``tf.keras.layers.Input``.

    Returns:
        The Keras model's ``predict`` method, mapping inputs to the encoder's
        ``"mean"`` output.
    """
    input_layer = tf.keras.layers.Input(image_shape)
    encoder_outputs = hub.KerasLayer(
        model_path, signature="gaussian_encoder", signature_outputs_as_dict=True
    )(input_layer)
    encoder = tf.keras.models.Model(input_layer, encoder_outputs["mean"])
    return encoder.predict


# Evaluate every (dataset, model, trained repetition) combination and save
# the metrics of each one to its own results folder.
for dataset in dataset_list:
    ground_truth_data = get_named_ground_truth_data(dataset)
    print("Evaluating dataset", dataset)
    for model_name in model_list:
        print("Evaluating model ", model_name)
        for repetition in range(trained_repetitions):
            print("Evaluating repetition ", repetition)
            representation_function = _build_representation_function(
                _trained_model_path(dataset, model_name, repetition),
                ground_truth_data.dataset_class.image_shape,
            )

            experiment_metrics_save_path = os.path.join(
                ICML_RESULTS_PATH, dataset, model_name, str(repetition)
            )
            os.makedirs(experiment_metrics_save_path, exist_ok=True)

            evaluate_metrics_dis_lib.evaluate_metrics(
                representation_function,
                ground_truth_data,
                repetitions,
                num_train,
                num_test,
                experiment_metrics_save_path,
            )

            # A new model (and hub module) is built every iteration; reset
            # Keras' global state so they do not pile up across repetitions.
            tf.keras.backend.clear_session()