import tensorflow as tf
import os
import sys
import numpy as np
import tensorflow_hub as hub

# Repository root: three directory levels above this script's absolute path.
PROJECT_PATH = os.path.dirname(
    os.path.dirname(
        os.path.dirname(os.path.abspath(__file__))
    )
)
# A `disentanglement_lib` checkout expected to sit next to the repository root.
DIS_PROJECT_PATH = os.path.join(os.path.dirname(PROJECT_PATH), "disentanglement_lib")

# Make both trees importable. They are appended (not prepended), so anything
# already on sys.path with the same package name is found first.
for _extra_path in (PROJECT_PATH, DIS_PROJECT_PATH):
    sys.path.append(_extra_path)

for _label, _value in (("PROJECT_PATH", PROJECT_PATH), ("DIS_PROJECT_PATH", DIS_PROJECT_PATH)):
    print(_label, _value)
del _extra_path, _label, _value
from disentanglement_lib.data.ground_truth.named_data import get_named_ground_truth_data
from modules.utils import plotting
from modules.general_metric import general_metric

# Evaluation configuration: each model family in `model_list` is evaluated on
# each dataset in `dataset_list`, once for every repetition index in
# range(repetitions). The loop below derives the model directory names from
# these three values.
repetitions = 10
dataset_list = ["modelnet", "arrow", "pixel4"]
# Model families evaluated in earlier runs (swap into `model_list` as needed):
# "betavae", "factor_vae", "tcvae", "vae".
model_list = ["dip_vae", "dip_vae2"]

def _compute_embedding(model_path, images):
    """Encode `images` with the TF-Hub model at `model_path`; return the means.

    Each call builds its own `tf.Graph`. Loading many models into the shared
    default graph would otherwise keep every previous model's ops and variables
    alive for the whole run. The images are fed through a placeholder instead of
    being embedded in the graph as a constant, so the array is not copied into
    every graph definition.

    Args:
        model_path: Directory of an exported TF-Hub module exposing a
            "gaussian_encoder" signature with a "mean" output.
        images: float32 numpy array passed to the encoder as one batch.

    Returns:
        numpy array holding the encoder's "mean" output for `images`.
    """
    graph = tf.Graph()
    with graph.as_default():
        encoder = hub.KerasLayer(model_path, signature="gaussian_encoder",
                                 signature_outputs_as_dict=True)
        images_ph = tf.placeholder(tf.float32, shape=images.shape)
        embedding_tensor = encoder(images_ph)["mean"]
        with tf.Session(graph=graph) as sess:
            # Variables created by the hub layer must be initialized before
            # anything in this graph can be evaluated.
            sess.run(tf.global_variables_initializer())
            return sess.run(embedding_tensor, feed_dict={images_ph: images})


# Loop-invariant locations, computed once rather than per model.
results_path = os.path.join(DIS_PROJECT_PATH, "results_models_data")
# Central copy of the scores; the default matches the original hard-coded path,
# and the environment variable allows running elsewhere.
ICML_RESULTS_PATH = os.environ.get("ICML_RESULTS_PATH", "/home/ICML/results")

for dataset in dataset_list:
    # Get the data once per dataset; the float32 cast is shared by all models.
    data_class = get_named_ground_truth_data(dataset)
    images = data_class.dataset_class.flat_images.astype(np.float32)
    factor_angles = data_class.dataset_class.flat_factor_mesh_as_angles
    colors_flat = plotting.yiq_embedding(factor_angles[:, 0], factor_angles[:, 1])
    for model_name in model_list:
        for repetition in range(repetitions):
            run_name = dataset + "_" + model_name + "_" + str(repetition)
            model_path0 = os.path.join(results_path, run_name, model_name)
            model_path = os.path.join(model_path0, model_name, "tfhub")
            save_path = os.path.join(model_path0, "results")
            save_path_images = os.path.join(save_path, "images")
            save_path_scores = os.path.join(save_path, "scores")
            os.makedirs(save_path_images, exist_ok=True)
            os.makedirs(save_path_scores, exist_ok=True)

            # Load the trained model and encode all images in a single run. Only
            # the embedding feeds the plots and the LSBD score, so the decoder
            # is not evaluated.
            print("Loading model in ", model_path)
            embedding = _compute_embedding(model_path, images)

            print("Plotting embeddings")
            fig, axes = plotting.plot_latent_dimension_combinations(embedding, colors_flat)
            fig.savefig(os.path.join(save_path_images, "latent_space"), bbox_inches="tight")
            # NOTE(review): `fig` is never closed. If `plotting` creates figures
            # through pyplot, they accumulate across iterations; consider
            # closing it (e.g. matplotlib.pyplot.close(fig)).

            print("Calculating LSBD score")
            k_values = general_metric.create_combinations_k_values_range(start_value=-10, end_value=10)
            # NOTE(review): assumes 64*64 factor combinations and a 4-dimensional
            # latent space for every dataset/model. TODO confirm.
            embedding_reshaped = embedding.reshape((64, 64, 4))
            score, final_k = general_metric.calculate_metric_k_list(embedding_reshaped, k_values)

            print(score, final_k)

            experiment_metrics_save_path = os.path.join(ICML_RESULTS_PATH, dataset, model_name, str(repetition))
            os.makedirs(experiment_metrics_save_path, exist_ok=True)
            np.save(os.path.join(experiment_metrics_save_path, "lsbd.npy"), score)
            np.save(os.path.join(save_path_scores, "score.npy"), score)