import tensorflow as tf
import os
import sys
import numpy as np
import tensorflow_hub as hub

# Locate the repository roots relative to this script so the project-local
# packages imported below can be resolved: the project root is two directory
# levels above the directory containing this file, and disentanglement_lib is
# expected to be checked out next to the project root.
PROJECT_PATH = os.path.abspath(__file__)
for _ in range(3):
    PROJECT_PATH = os.path.dirname(PROJECT_PATH)
DIS_PROJECT_PATH = os.path.join(os.path.dirname(PROJECT_PATH), "disentanglement_lib")
sys.path.extend([PROJECT_PATH, DIS_PROJECT_PATH])

for _label, _value in (("PROJECT_PATH", PROJECT_PATH), ("DIS_PROJECT_PATH", DIS_PROJECT_PATH)):
    print(_label, _value)
from disentanglement_lib.data.ground_truth.named_data import get_named_ground_truth_data
from modules.utils import plotting
from modules.general_metric import general_metric
# Pin the visible GPU unless the caller already chose one via the environment.
# NOTE(review): this runs after `import tensorflow`; presumably still effective
# because no session/device has been created yet — consider moving it above the
# TensorFlow imports to be safe.
os.environ.setdefault("CUDA_VISIBLE_DEVICES", "3")

# Load previously trained models and evaluate repetitions
# [start_repetition, end_repetition) of each one.
start_repetition = 0
end_repetition = 10

dataset_list = ["coil100"]
model_list = ["vae", "dip_vae", "dip_vae2", "betavae", "factor_vae", "tcvae"]
# model_list = ["betavae", "factor_vae", "tcvae"]
# model_list = ["vae"]
# model_list = ["dip_vae", "dip_vae2"]
# model_list = ["betavae"]

# Directory containing one exported-model folder per dataset/model/repetition.
results_path = os.path.join(DIS_PROJECT_PATH, "results_models_data")
# Directory the LSBD scores are additionally copied to; overridable via env.
ICML_RESULTS_PATH = os.environ.get("ICML_RESULTS_PATH", "/home/ICML/results")
# Range of k values passed to the LSBD metric.
k_values = np.arange(-10, 11)


def _encode_and_decode(model_path, images):
    """Run an exported TF-Hub model's encoder and decoder over ``images``.

    A fresh graph and session are created on every call, so that evaluating
    many models in one process does not keep adding hub modules, variables and
    data to a single shared default graph.

    Args:
        model_path: Directory of the TF-Hub export exposing the
            ``gaussian_encoder`` and ``decoder`` signatures.
        images: Array of input images; converted to float32 and fed to the
            encoder as a single batch.

    Returns:
        ``(embedding, reconstruction)`` as numpy arrays: the encoder's
        ``"mean"`` output and the decoder's ``"images"`` output for it.
    """
    images = np.asarray(images, dtype=np.float32)
    graph = tf.Graph()
    with graph.as_default():
        encoder = hub.KerasLayer(model_path, signature="gaussian_encoder", signature_outputs_as_dict=True)
        decoder = hub.KerasLayer(model_path, signature="decoder", signature_outputs_as_dict=True)
        # Feed the data via a placeholder rather than baking the whole dataset
        # into the graph as a constant.
        images_ph = tf.placeholder(tf.float32, shape=(None,) + images.shape[1:])
        embedding_tensor = encoder(images_ph)["mean"]
        reconstruction_tensor = decoder(embedding_tensor)["images"]
        init = tf.global_variables_initializer()
    with tf.Session(graph=graph) as sess:
        sess.run(init)
        # A single run evaluates both outputs, so the encoder executes once
        # (separate .eval() calls would run it twice).
        embedding, reconstruction = sess.run(
            [embedding_tensor, reconstruction_tensor],
            feed_dict={images_ph: images},
        )
    return embedding, reconstruction


for dataset in dataset_list:
    # Get the data
    data_class = get_named_ground_truth_data(dataset)
    images = data_class.dataset_class.flat_images
    # NOTE(review): assumes `images` is indexed [object, rotation, ...] and
    # `flat_images` is its object-major flattening — TODO confirm.
    num_objects = data_class.dataset_class.images.shape[0]
    num_rotations = data_class.dataset_class.images.shape[1]
    # Per-sample plot colours built from the first two factor-angle columns.
    factor_angles = data_class.dataset_class.flat_factor_mesh_as_angles
    colors_flat = plotting.yiq_embedding(factor_angles[:, 0], factor_angles[:, 1])

    print("Loading models from ", results_path)
    for model_name in model_list:
        for repetition in range(start_repetition, end_repetition):
            # NOTE(review): model_name appears twice in model_path
            # (.../<model>/<model>/tfhub); presumably the export layout — confirm.
            model_path0 = os.path.join(results_path, dataset + "_" + model_name + "_" + str(repetition), model_name)
            model_path = os.path.join(model_path0, model_name, "tfhub")
            save_path = os.path.join(model_path0, "results")
            save_path_images = os.path.join(save_path, "images")
            save_path_scores = os.path.join(save_path, "scores")
            os.makedirs(save_path_images, exist_ok=True)
            os.makedirs(save_path_scores, exist_ok=True)

            # Load trained model and embed/reconstruct the whole dataset.
            print("Loading model in ", model_path)
            # `reconstruction` is not used further below.
            embedding, reconstruction = _encode_and_decode(model_path, images)

            print("Plotting embeddings")
            fig, axes = plotting.plot_latent_dimension_combinations(embedding, colors_flat)
            fig.savefig(os.path.join(save_path_images, "latent_space"), bbox_inches="tight")
            # NOTE(review): fig is never closed; if plotting creates it through
            # pyplot, call matplotlib.pyplot.close(fig) here to avoid
            # accumulating open figures across iterations.

            print("Calculating LSBD score")
            embedding_reshaped = embedding.reshape((num_objects, num_rotations, embedding.shape[-1]))
            score, final_k = general_metric.calculate_metric_rotations(embedding_reshaped, k_values, verbose=1)

            print(score, final_k)

            experiment_metrics_save_path = os.path.join(ICML_RESULTS_PATH, dataset, model_name, str(repetition))
            os.makedirs(experiment_metrics_save_path, exist_ok=True)
            np.save(os.path.join(experiment_metrics_save_path, "lsbd.npy"), score)
            np.save(os.path.join(save_path_scores, "score.npy"), score)

            