# System imports
import sys
import os
import time
sys.path.append("../../")
DIR_PATH = os.getcwd()
PROJECT_PATH = os.path.dirname(os.path.dirname(DIR_PATH))
sys.path.append(PROJECT_PATH)

import matplotlib.pyplot as plt
import scipy.linalg
import numpy as np
import tensorflow as tf

from modules.utils import plotting
from modules.general_metric import general_metric
from data import data_loader
from modules.utils.experiment_control.experiment import Experiment




DIS_PROJECT_PATH = os.path.join(os.path.dirname(PROJECT_PATH), "disentanglement_lib")
sys.path.append(DIS_PROJECT_PATH)

from experiments.disentanglement_lib import evaluate_metrics_dis_lib
from disentanglement_lib.data.ground_truth.named_data import get_named_ground_truth_data
from disentanglement_lib.evaluation import evaluate


# Expose only GPU 0 to the process.
# NOTE(review): tensorflow is imported above this line; the variable is only
# honoured if CUDA has not been initialised yet. Consider moving this before
# the tensorflow import.
os.environ.update(CUDA_VISIBLE_DEVICES="0")

# ----- Evaluation configuration -----
# Dataset names to evaluate (passed to the ground-truth and data loaders below).
dataset_list = ["modelnet40"]
# Architecture name; part of the experiment name and of the results path.
architecture = "dense"
# Number of labels; part of the experiment name and of the output folder.
num_labels = 0
# Number of trained repetitions (previous experiments) to evaluate per dataset.
trained_repetitions = 10

def _dataset_params(dataset_name):
    """Return the keyword arguments for ``data_loader.load_factor_data``.

    Args:
        dataset_name: Name of the dataset to load (one of the keys below).

    Returns:
        A fresh dict of loader parameters for *dataset_name*.

    Raises:
        ValueError: If *dataset_name* has no configuration here. Raising is
            clearer than passing ``None`` on to ``load_factor_data(**...)``,
            which would fail with an unrelated ``TypeError``.
    """
    # Built per call so that callers never share the mutable inner lists/dicts.
    configs = {
        "arrow": {
            "data": "arrow",
            "arrow_size": 64,
            "n_hues": 64,
            "n_rotations": 64,
        },
        "pixel4": {
            "data": "pixel",
            "height": 64,
            "width": 64,
            "step_size_vert": 1,
            "step_size_hor": 1,
            "square_size": 4,
        },
        "modelnet": {
            "dataset_filename": "modelnet_color_single_64_64.h5",
            "data": "modelnet_colors",
        },
        "modelnet40": {
            "root_path": "/data/aligned64",
            "data": "modelnet40",
            "collection_list": ["airplane"],
            "data_type": "train",
            "dataset_directory": "",
        },
        "coil100": {
            "root_path": PROJECT_PATH,
            "data": "coil100",
        },
    }
    try:
        return configs[dataset_name]
    except KeyError:
        raise ValueError(
            "Unknown dataset {!r}; expected one of {}".format(dataset_name, sorted(configs))
        ) from None


# Settings that do not change between datasets or repetitions.
ICML_RESULTS_PATH = "/home/ICML/results"  # root folder for the saved lsbd.npy files
k_values = np.arange(-10, 11)  # integers -10..10, passed to calculate_metric_rotations

for data in dataset_list:

    experiment_name = data + "_" + architecture + "_" + str(num_labels)
    experiments_path = "/home/disentangling_everything/results_neurips_final/" + architecture
    experiment_parameters = {"path": experiments_path, "experiment_name": experiment_name}

    print("Evaluating dataset", data)
    # NOTE(review): ground_truth_data is not used below; the call is kept in case
    # it validates the dataset name or has other required side effects.
    # Confirm and remove if unneeded (it may load the whole dataset).
    ground_truth_data = get_named_ground_truth_data(data)

    dataset_params = _dataset_params(data)
    data_class = data_loader.load_factor_data(**dataset_params)
    x_full = np.expand_dims(data_class.flat_images, axis=1)  # change shape to (n_data_points, 1, h, w, d)

    for repetition in range(trained_repetitions):
        print("Evaluating repetition", repetition)
        exp = Experiment(**experiment_parameters)
        exp.select_target_previous_experiment(repetition)
        exp.load_parameters_name()
        print(exp.model_parameters)
        # Rename a "TransformVAE" entry to "HypercylinderTransformVAE" when present
        # (presumably so recreate_model() builds that class - confirm).
        # dict.pop raises KeyError when the entry is absent; that case is expected
        # and only reported.
        try:
            exp.model_parameters["HypercylinderTransformVAE"] = exp.model_parameters.pop("TransformVAE")
        except KeyError:
            print("Couldnt rename hypertorus")
        model_class = exp.recreate_model()
        model_u = model_class.setup_model(1)

        # Load the trained weights for every model returned by setup_model.
        for key, sub_model in model_u.items():
            exp.load_weights(sub_model, key)

        encoded = model_u["encoder_params"].predict(x_full)
        # Take every second encoder output (indices 0, 2, 4, ...), one per factor.
        representations_list = [encoded[2 * i] for i in range(data_class.n_factors)]

        representations_array = np.concatenate(representations_list, axis=-1).squeeze()
        # Lay the representations out on the factor grid:
        # data_class.factors_shape + (last axis of the concatenated representations,).
        representations_reshaped = representations_array.reshape(
            (*data_class.factors_shape, representations_array.shape[-1]))

        # ----------- CALCULATE LSBD METRIC ----------
        print("Start calculation of LSBD metric")
        # Start timing before the metric, so the printed duration covers the
        # computation and the save rather than just the np.save call.
        start_time = time.perf_counter()
        lsbd_score, _ = general_metric.calculate_metric_rotations(representations_reshaped, k_values, verbose=1)

        # Saving folder: <results>/<dataset>/lsbd_method/<num_labels>/<repetition>
        experiment_metrics_save_path = os.path.join(ICML_RESULTS_PATH, data, "lsbd_method", str(num_labels),
                                                    str(repetition))
        os.makedirs(experiment_metrics_save_path, exist_ok=True)

        np.save(os.path.join(experiment_metrics_save_path, "lsbd.npy"), lsbd_score)
        print("--- %s seconds ---" % (time.perf_counter() - start_time))
