#!/usr/bin/anaconda3/bin/python3

import numpy as np
import neptune
# import neptune_tensorboard
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split


import os
import sys
# NOTE: both paths are derived from the process working directory, not from the
# location of this file; the project-specific imports below only resolve when the
# parent of the working directory is the project root.
DIR_PATH = os.getcwd()  # current working directory (expected: "disentangling-everything/experiments")
PROJECT_PATH = os.path.dirname(DIR_PATH)  # parent of the working directory (expected: "disentangling-everything")
sys.path.append(PROJECT_PATH)  # lets the project-root packages (modules, data, experiments) be imported

# project-specific imports
from modules.vae.shapeculturevae import ShapeCultureVAE
from modules.vae import architectures
from modules.utils import plotting, utils, callbacks, shrec_utils
from modules.utils.experiment_control.experiment import Experiment
from data import data_loader
from experiments import neptune_config


def run_shapeculturevae(exp, model_parameters, training_parameters, splitting_parameters):
    """Train a ShapeCultureVAE on the SHREC 2021 "Shape" and "Culture" training sets.

    Both factor datasets are loaded, each is split into a stratified
    train/validation pair, the labels are converted to one-hot form, and the
    model is built and trained via ``train_supervised``.

    Args:
        exp: experiment object; only ``set_model_parameters`` is called on it here.
        model_parameters: dict forwarded as keyword arguments to ``ShapeCultureVAE``.
        training_parameters: dict forwarded as keyword arguments to
            ``ShapeCultureVAE.train_supervised``.
        splitting_parameters: dict forwarded as keyword arguments to
            ``sklearn.model_selection.train_test_split``.

    Returns:
        None. Progress is reported by printing "Training done!" at the end.
    """
    exp.set_model_parameters(model_parameters, ShapeCultureVAE.__name__)
    # NOTE(review): this list is created but never handed to the training call
    # below -- confirm whether train_supervised is supposed to receive it.
    callback_list = [callbacks.NeptuneMonitor()]

    def load_train_set(challenge):
        # fetch the SHREC 2021 training data of a single challenge
        return data_loader.load_factor_data(data="shrec2021", root_path=PROJECT_PATH,
                                            challenge=challenge, data_type="train")

    def stratified_split(dataset):
        # returns (x_train, x_valid, y_train, y_valid), stratified on the labels
        return train_test_split(dataset.images, dataset.labels,
                                stratify=dataset.labels, **splitting_parameters)

    def first_label_onehot(labels):
        # labels have shape (num_objects, 12) and are NOT onehot, but all 12 labels
        # of an object are the same: keep the first one and convert it to onehot,
        # such that the shape becomes (num_objects, n_classes)
        return shrec_utils.change_labels_to_onehot(labels[:, 0])

    # load data
    shape_set = load_train_set("Shape")
    culture_set = load_train_set("Culture")

    # train/validation split per factor
    shape_x_trn, shape_x_val, shape_y_trn, shape_y_val = stratified_split(shape_set)
    culture_x_trn, culture_x_val, culture_y_trn, culture_y_val = stratified_split(culture_set)

    # label conversion
    shape_y_trn = first_label_onehot(shape_y_trn)
    shape_y_val = first_label_onehot(shape_y_val)
    culture_y_trn = first_label_onehot(culture_y_trn)
    culture_y_val = first_label_onehot(culture_y_val)

    # the input shape is taken from the "Shape" dataset
    image_shape = shape_set.image_shape
    model = ShapeCultureVAE(input_shape=image_shape, **model_parameters)

    model.train_supervised(shape_x_trn, shape_x_val, shape_y_trn, shape_y_val,
                           culture_x_trn, culture_x_val, culture_y_trn, culture_y_val,
                           **training_parameters)
    print("Training done!")


if __name__ == '__main__':
    # ----- hyperparameters -----
    # keyword arguments for the ShapeCultureVAE constructor
    model_parameters = dict(
        enc_dec_architecture="vgg",
        separate_encoders=True,
        latent_dim_s=32,
        latent_dim_c=32,
        weight_kl=1.0,
        weight_clf=1.0,
        weight_kl_posterior_prior=1.0,
        weight_dist_to_avg_r=100,
        weight_dist_to_avg_s=100,
        weight_dist_to_avg_c=100,
    )
    # keyword arguments for ShapeCultureVAE.train_supervised
    training_parameters = dict(epochs=5, batch_size=64)
    # keyword arguments for sklearn's train_test_split
    splitting_parameters = dict(test_size=0.3, random_state=0)

    # one flat dict holding every parameter, used for experiment logging
    parameters = {}
    for parameter_group in (model_parameters, training_parameters, splitting_parameters):
        parameters.update(parameter_group)

    # ----- experiment bookkeeping (Neptune and local) -----
    # the same name is used for the Neptune experiment and the local experiment
    experiment_name = "shapeculturevae_test"

    # Neptune: select the project, authenticating with the token from the neptune config file
    group = "TUe"
    api_token = neptune_config.API_KEY
    # OPTIONAL: source files stored together with the Neptune experiment
    upload_source_files = ["main_shapeculturevae.py"]
    neptune.init(project_qualified_name="{}/shrec2021".format(group), api_token=api_token)

    # Local: the experiment is rooted in the "results" directory of the project
    experiment_path = os.path.join(PROJECT_PATH, "results")
    exp = Experiment(path=experiment_path, experiment_name=experiment_name)

    # ----- run -----
    with neptune.create_experiment(name=experiment_name,
                                   params=parameters,
                                   upload_source_files=upload_source_files):
        exp.start_experiment(parameters)
        run_shapeculturevae(exp, model_parameters, training_parameters, splitting_parameters)
