# System imports
import sys
import os
import time
sys.path.append("../../")
DIR_PATH = os.getcwd()
PROJECT_PATH = os.path.dirname(os.path.dirname(DIR_PATH))
sys.path.append(PROJECT_PATH)

import matplotlib.pyplot as plt
import scipy.linalg
import numpy as np
import tensorflow as tf

from modules.utils import plotting
from modules.general_metric import general_metric
from data import data_loader
from modules.utils.experiment_control.experiment import Experiment




DIS_PROJECT_PATH = os.path.join(os.path.dirname(PROJECT_PATH), "disentanglement_lib")
sys.path.append(DIS_PROJECT_PATH)

from experiments.disentanglement_lib import evaluate_metrics_dis_lib
from disentanglement_lib.data.ground_truth.named_data import get_named_ground_truth_data
from disentanglement_lib.evaluation import evaluate


# Restrict the process to the GPU with index 1.
# NOTE(review): this only takes effect if no GPU has been initialized yet.
# TensorFlow and the project modules are imported above this line, so confirm
# none of them touch a device at import time (or move this before the imports).
os.environ["CUDA_VISIBLE_DEVICES"] = "1"

# ---------------------------------------------------------------------------
# Configuration
# ---------------------------------------------------------------------------
# Only used below to build the experiment name and the results sub-folder.
num_labels = 0

dataset_list = ["modelnet40"]
architecture = "dense"
# Repetitions in [start_repetition, end_repetition) are evaluated.
start_repetition = 0
end_repetition = 1
# Forwarded unchanged to evaluate_metrics_dis_lib.evaluate_metrics.
evaluation_repetitions = 10
# Number of latent factors; the representation keeps encoder outputs
# 0, 2, ..., 2 * (num_factors - 1).
num_factors = 2
num_train = 100
num_test = num_train // 2

# Root folder of the trained experiments and root folder for metric results.
EXPERIMENTS_ROOT = "/home/disentangling_everything/results_neurips_final"
ICML_RESULTS_PATH = "/home/ICML/results"


def make_representation_function(encoder, num_factors):
    """Build the representation function passed to the metric evaluation.

    Args:
        encoder: Object exposing ``predict`` (here the ``"encoder_params"``
            sub-model returned by ``setup_model``). Presumably a Keras model
            whose outputs are grouped per latent factor -- TODO confirm.
        num_factors: Number of latent factors to keep.

    Returns:
        A callable ``representation_function(imgs)`` that encodes ``imgs``
        and concatenates encoder outputs ``0, 2, ..., 2 * (num_factors - 1)``
        along the last axis, then squeezes singleton axes.
    """

    def representation_function(imgs):
        # Insert a singleton axis at position 1 before encoding; presumably
        # the encoder expects an extra (views/transformations) axis -- TODO confirm.
        encoded = encoder.predict(np.expand_dims(imgs, 1))
        # Keep every other output. Assumed to be the per-factor embedding,
        # with the odd entries auxiliary outputs -- verify against the model.
        representations_list = [encoded[2 * i] for i in range(num_factors)]
        # NOTE(review): squeeze() also drops the batch axis when a batch holds
        # a single observation; confirm the caller tolerates that.
        return np.concatenate(representations_list, axis=-1).squeeze()

    return representation_function


experiments_path = os.path.join(EXPERIMENTS_ROOT, architecture)

for data_name in dataset_list:
    experiment_name = f"{data_name}_{architecture}_{num_labels}"
    experiment_parameters = {"path": experiments_path, "experiment_name": experiment_name}
    print("Evaluating dataset", data_name)
    ground_truth_data = get_named_ground_truth_data(data_name)

    for repetition in range(start_repetition, end_repetition):
        print("Evaluating repetition", repetition)
        exp = Experiment(**experiment_parameters)
        exp.select_target_previous_experiment(repetition)
        exp.load_parameters_name()
        print(exp.model_parameters)

        # Rename a legacy "TransformVAE" entry to "HypercylinderTransformVAE"
        # before recreating the model. A missing key is tolerated (reported only).
        try:
            exp.model_parameters["HypercylinderTransformVAE"] = exp.model_parameters.pop("TransformVAE")
        except KeyError:
            print("Couldn't rename TransformVAE: key not found in model parameters")

        model_class = exp.recreate_model()
        model_u = model_class.setup_model(1)

        # Load the trained weights of every sub-model.
        for key, submodel in model_u.items():
            exp.load_weights(submodel, key)

        representation_function = make_representation_function(model_u["encoder_params"], num_factors)

        # Saving folder: <results root>/<dataset>/lsbd_method/<num_labels>/<repetition>
        experiment_metrics_save_path = os.path.join(
            ICML_RESULTS_PATH, data_name, "lsbd_method", str(num_labels), str(repetition)
        )
        os.makedirs(experiment_metrics_save_path, exist_ok=True)

        start_time = time.perf_counter()
        evaluate_metrics_dis_lib.evaluate_metrics(
            representation_function,
            ground_truth_data,
            evaluation_repetitions,
            num_train,
            num_test,
            experiment_metrics_save_path,
        )
        print(f"--- {time.perf_counter() - start_time} seconds ---")
