#!/usr/bin/anaconda3/bin/python3

import neptune
import tensorflow as tf

import os
import sys

# Make the repository root importable so the `experiments` and `modules`
# packages imported below can be resolved.
# NOTE(review): both paths derive from the *working directory*, so the script
# must be launched from inside "disentangling-everything/experiments".
# os.path.dirname(os.path.abspath(__file__)) would remove that requirement if
# this file lives in experiments/ -- confirm before switching.
DIR_PATH = os.getcwd()  # launch directory, expected ".../disentangling-everything/experiments"
PROJECT_PATH = os.path.split(DIR_PATH)[0]  # its parent, expected ".../disentangling-everything"
sys.path.append(PROJECT_PATH)

from experiments import main_transformvae, neptune_config
from modules.utils.experiment_control.experiment import Experiment
# Restrict this process to GPU index 2.
# NOTE(review): tensorflow is already imported above; this presumably only
# takes effect if no GPU has been initialised yet -- consider moving it above
# the imports.
os.environ["CUDA_VISIBLE_DEVICES"] = "2"

# Predefined dataset configurations, keyed by the `data` switch used below.
_DATASET_PRESETS = {
    "arrow": {
        "data": "arrow",
        "arrow_size": 64,
        "n_hues": 64,
        "n_rotations": 64,
    },
    "pixel4": {
        "data": "pixel",
        "height": 64,
        "width": 64,
        "step_size_vert": 1,
        "step_size_hor": 1,
        "square_size": 4,
    },
    "modelnet": {
        "dataset_filename": "modelnet_color_single_64_64.h5",
        "data": "modelnet_colors",
    },
}


def get_dataset_params(data):
    """Return a fresh copy of the dataset parameters for the preset ``data``.

    Args:
        data: Preset name; one of "arrow", "pixel4" or "modelnet".

    Returns:
        A new dict of dataset parameters (safe for the caller to mutate).

    Raises:
        ValueError: If ``data`` is not a known preset. Failing here is clearer
            than returning None, which would only crash later when the result
            is unpacked with ``**``.
    """
    try:
        return dict(_DATASET_PRESETS[data])
    except KeyError:
        known = ", ".join(sorted(_DATASET_PRESETS))
        raise ValueError(f"Unknown dataset {data!r}; expected one of: {known}") from None


if __name__ == '__main__':

    # Dataset
    data = "modelnet"
    epochs = 1500

    # Model parameters
    architectures = "dis_lib"

    # Resolve the dataset preset up front so a typo fails before any Neptune
    # or TensorFlow work is started.
    dataset_params = get_dataset_params(data)

    # Label budgets to sweep (largest first) and number of runs per budget.
    n_labels_list = [0, 64, 128, 192, 256, 320, 384, 448, 512]
    n_repetitions = 5

    # ##### SETTINGS SHARED BY ALL RUNS (Neptune and Local) #####
    group = "TUe"
    api_token = neptune_config.API_KEY  # read api token from neptune config file
    # OPTIONAL: source code saved with each experiment (relative to the working directory)
    upload_source_files = ["main_transformvae.py"]
    experiment_path = os.path.join(PROJECT_PATH, "results3", architectures)

    # Select the Neptune project once; the create_experiment() calls below use
    # the project set here (legacy neptune-client global API).
    # For finer control over metric logging, remove the keras integration and
    # use a callback such as:
    # class NeptuneMonitor(tf.keras.callbacks.Callback):
    #     def on_epoch_end(self, epoch, logs=None):
    #         for metric_name, value in (logs or {}).items():
    #             neptune.send_metric(metric_name, epoch, value)
    # neptune_monitor = NeptuneMonitor()
    neptune.init(project_qualified_name=group + "/sandbox", api_token=api_token)

    for n_labels in reversed(n_labels_list):
        for _ in range(n_repetitions):
            # ##### SETUP ALL PARAMETERS #####
            # Rebuilt for every run so a callee that mutates them cannot leak
            # changes into the next run.
            labelling_parameters = {
                "n_labels": n_labels,
            }

            model_parameters = {
                "dist_weight": 100,
                "separate_encoders": False,
                "stop_gradient": False,
                "log_t_limit": (-10.0, -9.0),
                "architectures": architectures,
            }

            training_parameters = {
                "epochs": epochs,
                "batch_size": 128,
            }

            # combine parameters for experiment logging
            parameters = {
                **dataset_params,
                **labelling_parameters,
                **model_parameters,
                **training_parameters,
            }

            # Saving name for the experiment, used for both Neptune and the local Experiment.
            experiment_name = f"{data}_{architectures}_{n_labels}"
            exp = Experiment(path=experiment_path, experiment_name=experiment_name)

            with neptune.create_experiment(name=experiment_name, params=parameters,
                                           upload_source_files=upload_source_files):
                exp.start_experiment(parameters)
                main_transformvae.run_gridworld_torus(exp, dataset_params, labelling_parameters, model_parameters,
                                                      training_parameters)

                # Further optional logging inside the experiment context:
                # neptune.log_image(log_name='plots', x=plt.gcf(), image_name="test")  # matplotlib figure
                # neptune.log_artifact(artifact="path_to_file")  # any file, e.g. trained weights
                # neptune.log_metric(log_name="metrics", x=evaluate_result)

            # Reset Keras' global state before the next run.
            tf.keras.backend.clear_session()
