#!/usr/bin/anaconda3/bin/python3

import os
import sys

# Select the GPU before TensorFlow, or any project module that may touch the
# GPU on import, is loaded. CUDA reads this variable when it is initialised,
# so setting it afterwards may have no effect.
os.environ["CUDA_VISIBLE_DEVICES"] = "1"

import neptune  # noqa: E402
import tensorflow as tf  # noqa: E402

# NOTE(review): assumes the script is launched from inside the
# "experiments" directory. The parent directory is then put on sys.path so
# that the project packages imported below can be resolved.
DIR_PATH = os.getcwd()  # current directory ("disentangling-everything/experiments")
PROJECT_PATH = os.path.dirname(DIR_PATH)  # parent directory ("disentangling-everything")
sys.path.append(PROJECT_PATH)

from experiments import main_transformvae2, neptune_config  # noqa: E402
from modules.utils.experiment_control.experiment import Experiment  # noqa: E402
def build_dataset_params(data, project_path):
    """Return the dataset configuration (``data_parameters``) for ``data``.

    Args:
        data: Short dataset identifier: "arrow", "pixel4", "modelnet",
            "modelnet40", "coil100" or "smallnorb".
        project_path: Project root directory, used as ``root_path`` for the
            "coil100" and "smallnorb" datasets.

    Returns:
        A new dict with the configuration for the selected dataset.

    Raises:
        ValueError: If ``data`` is not one of the identifiers listed above.
    """
    known_datasets = {
        "arrow": {
            "data": "arrow",
            "arrow_size": 64,
            "n_hues": 64,
            "n_rotations": 64,
        },
        "pixel4": {
            "data": "pixel",
            "height": 64,
            "width": 64,
            "step_size_vert": 1,
            "step_size_hor": 1,
            "square_size": 4,
        },
        "modelnet": {
            "dataset_filename": "modelnet_color_single_64_64.h5",
            "data": "modelnet_colors",
        },
        "modelnet40": {
            "root_path": "/data/aligned64",
            "data": "modelnet40",
            "collection_list": ["airplane"],
            "data_type": "train",
            "dataset_directory": "",
        },
        "coil100": {
            "root_path": project_path,
            "data": "coil100",
        },
        "smallnorb": {
            "root_path": project_path,
            "data": "smallnorb",
        },
    }
    try:
        return known_datasets[data]
    except KeyError:
        # Fail here with a clear message. A None result would otherwise only
        # blow up later, when the dict is unpacked with **.
        raise ValueError(
            f"Unknown dataset {data!r}; expected one of {sorted(known_datasets)}"
        ) from None


def build_training_params(n_labels, epochs):
    """Return the training settings for a run that uses ``n_labels`` labels.

    Args:
        n_labels: Labelling setting of the run. ``-1`` trains with batch size
            128; ``0`` trains with batch size 128 * 72.
        epochs: Number of training epochs.

    Returns:
        A dict with ``epochs`` and ``batch_size``, or an empty dict for any
        other ``n_labels`` value.
    """
    if n_labels == -1:
        return {"epochs": epochs, "batch_size": 128}
    if n_labels == 0:
        return {"epochs": epochs, "batch_size": 128 * 72}
    # NOTE(review): other label counts get no training settings; presumably
    # the training routine falls back to its own defaults. Confirm.
    return {}


if __name__ == '__main__':

    # Dataset
    data = "modelnet40"
    epochs = 1500

    # Model architecture ("dis_lib" is the alternative option).
    architecture = "dense"

    # Dataset configuration; the same dict is passed to every run below.
    dataset_params = build_dataset_params(data, PROJECT_PATH)

    # Experimental grid. Alternative setting: n_repetitions = 10 with
    # n_labels_list = [-1].
    n_labels_list = [0]
    n_repetitions = 1  # ModelNet40

    # ##### EXPERIMENT INFO THAT IS THE SAME FOR EVERY RUN (Neptune and Local) #####
    group = "TUe"
    api_token = neptune_config.API_KEY  # read api token from neptune config file
    upload_source_files = ["main_transformvae_multiple2.py"]  # OPTIONAL: save the source code used for the experiment
    neptune.init(project_qualified_name=group + "/sandbox", api_token=api_token)
    experiment_path = os.path.join(PROJECT_PATH, "results_neurips_final", architecture)

    for n_labels in reversed(n_labels_list):
        # Saving name for both Neptune and local storage; repetitions share it.
        experiment_name = f"{data}_{architecture}_{n_labels}"

        for _ in range(n_repetitions):
            # ##### SETUP ALL PARAMETERS #####
            labelling_parameters = {"n_labels": n_labels}
            model_parameters = {
                "input_shape": (64, 64, 3),
                "latent_dim": 5,
                "separate_encoders": False,
                "stop_gradient": False,
                "log_t_limit": (-10, -5),
                "kl_weight": 1,
                "dist_weight": 1,
                "architecture": architecture,
            }
            training_parameters = build_training_params(n_labels, epochs)

            # Local experiment bookkeeping.
            experiment_parameters = {"path": experiment_path, "experiment_name": experiment_name}
            exp = Experiment(**experiment_parameters)

            # Combine all parameters into one flat dict for experiment logging.
            parameters = {
                **dataset_params,
                **labelling_parameters,
                **model_parameters,
                **training_parameters,
                **experiment_parameters,
            }

            with neptune.create_experiment(name=experiment_name, params=parameters,
                                           upload_source_files=upload_source_files):
                exp.start_experiment(parameters)
                main_transformvae2.run_gridworld_torus(exp, dataset_params, labelling_parameters, model_parameters,
                                                       training_parameters)

            # Reset Keras' global state so the next run starts from a clean slate.
            tf.keras.backend.clear_session()
