#!/usr/bin/anaconda3/bin/python3

# Code formatting imports
# TODO: add typing hints

import numpy as np
import neptune
# import neptune_tensorboard
import matplotlib.pyplot as plt

import os
import sys

# Resolve the paths from this file's location rather than from the working directory, so that
# the project-specific imports below (and every path derived from PROJECT_PATH) are correct
# regardless of the directory the script is launched from.
DIR_PATH = os.path.dirname(os.path.abspath(__file__))  # this script's directory ("disentangling-everything/experiments")
PROJECT_PATH = os.path.dirname(DIR_PATH)  # parent directory ("disentangling-everything")
sys.path.append(PROJECT_PATH)

# project-specific imports
from modules.vae import hypertorus_transformvae
from modules.utils import plotting, utils, callbacks
from modules.utils.experiment_control.experiment import Experiment
from modules.general_metric import general_metric
from data import data_loader
from experiments import neptune_config


def _log_and_save_current_figure(exp, log_name, image_name):
    """Log the current matplotlib figure to Neptune, save it through ``exp``, then close it."""
    figure = plt.gcf()
    neptune.log_image(log_name, figure, image_name=image_name)
    exp.save_figure(figure, image_name)
    plt.close()


def run_gridworld_torus(exp, data_parameters, labelling_parameters, model_parameters, training_parameters,
                        animations_periodicity=None):
    """Train a semi-supervised HypertorusTransformVAE, log the LSBD metric and produce plots.

    Steps, in order: load the factor data, build and train the VAE, save the weights of every
    sub-model through ``exp``, compute and log the LSBD score, and finally plot reconstructions
    (plus latent traversals and embeddings when the data has exactly two factors). Every figure
    is logged to Neptune and saved through ``exp``.

    Args:
        exp: experiment object; its ``set_model_parameters``, ``save_model_weights`` and
            ``save_figure`` methods are used here, and it is handed to the animations callback.
        data_parameters: keyword arguments forwarded to ``data_loader.load_factor_data``.
        labelling_parameters: keyword arguments forwarded to
            ``data_class.setup_circles_dataset_labelled_pairs``.
        model_parameters: keyword arguments for ``HypertorusTransformVAE``. NOTE: this dict is
            updated in place with the keys ``"input_shape"`` and ``"num_circles"``.
        training_parameters: keyword arguments forwarded to ``vae.train_semi_supervised``.
        animations_periodicity: when not ``None``, a ``callbacks.Animations`` callback is added
            to training and receives this value; ``None`` disables it.

    Returns:
        None. Results are reported through Neptune, ``exp`` and standard output.
    """
    callback_list = [callbacks.NeptuneMonitor()]

    data_class = data_loader.load_factor_data(root_path=PROJECT_PATH, **data_parameters)

    # Add data information to the model parameters (mutates the caller's dict on purpose, so the
    # stored parameters are sufficient to rebuild the model later).
    model_parameters.update({"input_shape": data_class.image_shape,
                             "num_circles": data_class.n_factors})
    # Save the model parameters for later reloading
    exp.set_model_parameters(model_parameters, hypertorus_transformvae.HypertorusTransformVAE.__name__)
    x_full = np.expand_dims(data_class.flat_images, axis=1)  # change shape to (n_data_points, 1, h, w, d)
    x_l, x_l_transformations, x_u = data_class.setup_circles_dataset_labelled_pairs(**labelling_parameters)

    vae = hypertorus_transformvae.HypertorusTransformVAE(**model_parameters)
    model_l, model_u = vae.setup_semi_supervised_models(x_l.shape[1])
    if animations_periodicity is not None:
        animations_cb = callbacks.Animations(model_u, exp, data_class, animations_periodicity)
        callback_list.append(animations_cb)
    vae.train_semi_supervised(model_l, model_u, x_l, x_l_transformations, x_u, callback_list=callback_list,
                              **training_parameters)

    # Local saving of weights, one file per sub-model
    for key, sub_model in model_u.items():
        exp.save_model_weights(sub_model, key)

    # METRIC
    # compute representations (mu's) for full dataset
    # NOTE: only works for HypersphericalLatentSpace, which has parameters mu_euclidean, mu, log_t
    # (hence the stride of 3 and the offset of 1 when picking mu for each factor).
    encoded = model_u["encoder_params"].predict(x_full)
    representations_list = [encoded[3 * i + 1] for i in range(data_class.n_factors)]
    # The mean-distance metrics of vae.compute_metrics are not computed here; only LSBD is logged.

    representations_array = np.concatenate(representations_list, axis=-1).squeeze()
    print("Initial representations shape", representations_array.shape)

    # Arrange the representations on the factor grid: one axis per factor, 2 coordinates per factor.
    representations_reshaped = representations_array.reshape((*data_class.factors_shape, 2 * data_class.n_factors))
    print("Reshaped representations shape", representations_reshaped.shape)
    k_values = general_metric.create_combinations_k_values_range()
    lsbd_score, _ = general_metric.calculate_metric_k_list(representations_reshaped, k_values)
    neptune.log_metric("LSBD", lsbd_score)
    print("METRIC LSBD Score", lsbd_score)

    # PLOTTING:
    plot_log_name = "plots"
    # plot some reconstructions
    print("...plotting reconstructions")
    # Sampling without replacement fails if more items are requested than exist, so clamp the size.
    sample_size = min(20, len(x_full))
    indices = np.random.choice(len(x_full), size=sample_size, replace=False)
    x_sample = x_full[indices]
    x_sample_reconstr = model_u["full_model"].predict(x_sample)
    x_array = np.concatenate((x_sample, x_sample_reconstr), axis=1)  # shape (sample_size, 2, h, w, d)
    x_array = np.moveaxis(x_array, 0, 1)  # originals in the first row, reconstructions in the second
    plotting.plot_subset(x_array)
    _log_and_save_current_figure(exp, plot_log_name, "reconstructions")

    # the following plots only work for torus-like data, i.e. with two cyclic factors
    if data_class.n_factors == 2:
        # plot latent traversals in a grid
        print("...plotting latent traversals grid")
        grid_size = 20
        angles_grid = np.linspace(0, 2 * np.pi, num=grid_size, endpoint=False)
        z_grid = utils.angle_to_point(angles_grid)  # shape (grid_size, 2)
        z_grid = np.expand_dims(z_grid, axis=1)  # shape (grid_size, 1, 2)
        # For every fixed value z of the second latent, sweep the first latent over the whole grid.
        x_array = np.concatenate(
            [model_u["decoder"].predict([z_grid, np.tile(z, (grid_size, 1, 1))])
             for z in z_grid],
            axis=1)  # shape (grid_size, grid_size, h, w, d)
        plotting.plot_subset(x_array)
        _log_and_save_current_figure(exp, plot_log_name, "latent_traversal_grid")

        # plot embeddings, coloured by the two ground-truth factor angles
        print("...plotting embeddings")
        v_angle = data_class.flat_factor_mesh_as_angles[:, 0]
        h_angle = data_class.flat_factor_mesh_as_angles[:, 1]
        colors = plotting.yiq_embedding(v_angle, h_angle)

        plotting.plot_torus_angles(encoded, colors=colors)
        _log_and_save_current_figure(exp, plot_log_name, "encoded_torus")

        plotting.plot_euclidean_embedding(encoded, colors=colors)
        _log_and_save_current_figure(exp, plot_log_name, "encoded_euclidean")

    print("Done!")


if __name__ == '__main__':
    # ##### DATASET PRESETS (assign one of them to data_parameters) #####
    rotate_hueshift_arrow_params = dict(data="arrow", arrow_size=32, n_hues=32, n_rotations=32)
    wrapped_pixel_params = dict(data="pixel", height=32, width=32, step_size_vert=1, step_size_hor=1)

    # ##### SETUP ALL PARAMETERS #####
    data_parameters = rotate_hueshift_arrow_params
    labelling_parameters = dict(n_labels=512)
    model_parameters = dict(dist_weight=10, separate_encoders=False, stop_gradient=False)
    training_parameters = dict(epochs=300, batch_size=128)

    # Number of epochs between latent-traversal plots during training; None switches them off.
    animations_periodicity = 4

    # Flatten all parameter groups into a single dict for experiment logging.
    parameters = {}
    for parameter_group in (data_parameters, labelling_parameters, model_parameters, training_parameters):
        parameters.update(parameter_group)
    parameters["animations_periodicity"] = animations_periodicity

    # ##### SETUP EXPERIMENT INFO (Neptune and Local) #####
    # The same name identifies the run on Neptune and in the local results folder.
    experiment_name = "animation_test"

    # Neptune experiment: the API token is read from the neptune config module.
    workspace = "TUe"
    upload_source_files = ["main_transformvae.py"]  # OPTIONAL: source code stored with the experiment
    neptune.init(project_qualified_name=f"{workspace}/sandbox", api_token=neptune_config.API_KEY)
    # Training metrics reach Neptune through the NeptuneMonitor callback that run_gridworld_torus
    # installs; the neptune_tensorboard keras integration is therefore left switched off.

    # Local experiment: results are stored under <project>/results.
    experiment_path = os.path.join(PROJECT_PATH, "results")
    exp = Experiment(path=experiment_path, experiment_name=experiment_name)

    neptune_run = neptune.create_experiment(name=experiment_name, params=parameters,
                                            upload_source_files=upload_source_files)
    with neptune_run:
        exp.start_experiment(parameters)
        run_gridworld_torus(
            exp,
            data_parameters,
            labelling_parameters,
            model_parameters,
            training_parameters,
            animations_periodicity=animations_periodicity,
        )

        # Further logging that can be added inside this context if needed:
        #   - images:    neptune.log_image(log_name, x, image_name=...) with x a matplotlib figure
        #   - artifacts: neptune.log_artifact(artifact) for any file, e.g. a trained weights file
        #   - metrics:   neptune.log_metric(log_name, x)
