#!/usr/bin/anaconda3/bin/python3

# Code formatting imports
# TODO: add typing hints

import numpy as np
import neptune
import matplotlib.pyplot as plt

import os
import sys

# The script is expected to be launched from "disentangling-everything/experiments";
# its parent directory ("disentangling-everything") is the project root.
DIR_PATH = os.getcwd()
PROJECT_PATH = os.path.split(DIR_PATH)[0]
# Put the project root on the import path so the project-specific imports below resolve.
sys.path.append(PROJECT_PATH)

# project-specific imports

from modules.vae.hypercylinder_transformvae import HypercylinderTransformVAE
from modules.utils import plotting, utils, callbacks
from modules.general_metric import general_metric
from data import data_loader
from neptunecontrib.monitoring.keras import NeptuneMonitor
from sklearn.decomposition import PCA


def plot_traversal_between_objects(model, euclidean_encoded, num_object_start, num_object_end, num_interpolation_points,
                                   num_angles):
    """Plot decoder outputs while linearly interpolating the Euclidean latent between two objects.

    The Euclidean codes of each of the two objects are averaged over their rotation axis to get a
    start and an end point. Each figure row decodes one interpolated point, from the start object
    (first row) to the end object (last row), both included. Each column shows one of `num_angles`
    rotations taken at a fixed stride from the decoded sequence.

    Args:
        model: decoder called as ``model.predict([circle, euclidean])`` with ``circle`` of shape
            (1, n_rotations, 2) (cos/sin of each angle) and ``euclidean`` of shape
            (1, n_rotations, *code_shape) holding the same point repeated.
        euclidean_encoded: Euclidean codes indexed as ``euclidean_encoded[object, rotation, ...]``;
            ``shape[1]`` is taken as the number of rotations.
        num_object_start: index of the object at the start of the interpolation.
        num_object_end: index of the object at the end of the interpolation.
        num_interpolation_points: number of interpolation steps (figure rows).
        num_angles: number of rotations shown per step (figure columns).

    Returns:
        ``(fig, axes)``, where ``axes`` is always a 2-D array of shape
        ``(num_interpolation_points, num_angles)``.
    """
    start_point = np.mean(euclidean_encoded[num_object_start], axis=0)
    end_point = np.mean(euclidean_encoded[num_object_end], axis=0)
    total_rotations = euclidean_encoded.shape[1]
    # Stride between the rotation indices shown in consecutive columns.
    rotation_step = total_rotations // num_angles

    # squeeze=False keeps `axes` 2-D, so axes[row, col] also works for a single row or column.
    fig, axes = plt.subplots(num_interpolation_points, num_angles,
                             figsize=(num_angles, num_interpolation_points), squeeze=False)

    # Sample the circle evenly WITHOUT the 2*pi endpoint: with the endpoint, the first and last
    # samples coincide, which would not match the evenly spaced (endpoint=False) training views.
    angles = 2 * np.pi * np.linspace(0, 1, total_rotations, endpoint=False)
    circle = np.expand_dims(np.stack([np.cos(angles), np.sin(angles)], axis=-1), 0)  # (1, T, 2)

    # Weights run from 0 to 1 inclusive, so the last row decodes the end object itself.
    for row, weight in enumerate(np.linspace(0, 1, num_interpolation_points)):
        point = (1 - weight) * start_point + weight * end_point
        constant_euclidean = np.expand_dims([point] * total_rotations, 0)
        traversal = model.predict([circle, constant_euclidean])
        for col in range(num_angles):
            ax = axes[row, col]
            ax.imshow(traversal[0, col * rotation_step])
            ax.set_xticks([])
            ax.set_yticks([])
    return fig, axes


def _log_figure(fig, image_name, log_name="plots"):
    """Log `fig` to Neptune under `log_name` / `image_name`, then close it to release its memory."""
    neptune.log_image(log_name, fig, image_name=image_name)
    plt.close(fig)


def _scatter_2d(points, colors, title):
    """Return a new 5x5 figure scatter-plotting points[:, 0] vs points[:, 1], coloured by `colors`."""
    fig, ax = plt.subplots(1, 1, figsize=(5, 5))
    ax.scatter(points[:, 0], points[:, 1], c=colors)
    ax.set_title(title)
    return fig


def run_gridworld_torus(exp, data_parameters, labelling_parameters, model_parameters, training_parameters,
                        animations_periodicity=None):
    """Train a HypercylinderTransformVAE, compute the LSBD metric and log diagnostic plots.

    The function runs these steps in order:
      1. Load the data with ``data_loader.load_factor_data``.
      2. Train the model:
         - if ``labelling_parameters["n_labels"] == 0``, train the unlabelled model on flat images;
         - otherwise, train the labelled model on (images, circular angles, euclidean ones).
      3. Save every unlabelled sub-model's weights through `exp`.
      4. Log the LSBD score to Neptune.
      5. Log reconstruction, PCA, circle-embedding, parameter-histogram and latent-traversal
         figures.

    Args:
        exp: experiment helper; this function calls its ``set_model_parameters``,
            ``save_model_weights`` and ``save_figure``.
        data_parameters: keyword arguments for ``data_loader.load_factor_data``.
        labelling_parameters: mapping with at least ``"n_labels"``; 0 selects unlabelled training.
        model_parameters: keyword arguments for ``HypercylinderTransformVAE``.
        training_parameters: extra keyword arguments forwarded to the model's ``fit``. They must not
            contain ``callbacks``, which is passed explicitly.
        animations_periodicity: if not None, a ``callbacks.Animations`` callback with this
            periodicity is added to the training callbacks.
    """
    exp.set_model_parameters(model_parameters, HypercylinderTransformVAE.__name__)

    data_class = data_loader.load_factor_data(**data_parameters)
    x_full = np.expand_dims(data_class.flat_images, axis=1)  # (n_data_points, 1, h, w, d)

    x_l = data_class.images
    print("Data shape", x_l.shape)
    n_views = x_l.shape[1]
    # One evenly spaced angle per view (2*pi excluded), repeated for every object.
    views = 2 * np.pi * np.linspace(0, 1, n_views, endpoint=False)
    transformations_circular = np.expand_dims(np.array([views] * len(x_l)), -1)
    transformations_euclidean = np.ones(transformations_circular.shape)
    x_l_transformations = [transformations_circular, transformations_euclidean]
    labelled_inputs = [x_l, *x_l_transformations]

    # Define model
    vae = HypercylinderTransformVAE(**model_parameters)
    model_l, model_u = vae.setup_semi_supervised_models(n_views)

    # Previously these extra callbacks were collected but never handed to `fit`,
    # so `animations_periodicity` had no effect.
    extra_callbacks = []
    if animations_periodicity is not None:
        extra_callbacks.append(callbacks.Animations(model_u, exp, data_class, animations_periodicity))

    if labelling_parameters["n_labels"] == 0:
        print("Training unlabelled")
        model_to_train, x, y = model_u["full_model"], x_full, x_full
    else:
        print("Training labelled")
        model_to_train, x, y = model_l["full_model"], labelled_inputs, x_l
    model_to_train.fit(x=x,
                       y=y,
                       callbacks=[NeptuneMonitor(), *extra_callbacks],
                       **training_parameters)

    # Local saving of weights
    print("Saving weights")
    for key, sub_model in model_u.items():
        print(key)
        exp.save_model_weights(sub_model, key)

    # ----------- CALCULATE LSBD METRIC ----------
    # Representations are the even-indexed encoder outputs (encoded[0], encoded[2], ...),
    # one per factor, for the full dataset.
    # NOTE: only works for HypersphericalLatentSpace, which has parameters mu_euclidean, mu, log_t
    encoded = model_u["encoder_params"].predict(x_full)
    representations_list = [encoded[2 * i] for i in range(data_class.n_factors)]
    representations_array = np.concatenate(representations_list, axis=-1).squeeze()
    representations_reshaped = representations_array.reshape(
        (*data_class.factors_shape, representations_array.shape[-1]))

    print("Start calculation of LSBD metric")
    k_values = np.arange(-10, 11)
    lsbd_score, _ = general_metric.calculate_metric_rotations(representations_reshaped, k_values, verbose=1)
    neptune.log_metric("LSBD", lsbd_score)
    print("METRIC LSBD Score", lsbd_score)

    # ----------- PLOTTING ----------
    plot_log_name = "plots"
    print("...plotting reconstructions")
    # Guard against datasets smaller than the default sample (choice without replacement).
    sample_size = min(20, len(x_full))
    indices = np.random.choice(len(x_full), size=sample_size, replace=False)
    x_sample = x_full[indices]
    x_sample_reconstr = model_u["full_model"].predict(x_sample)
    x_array = np.concatenate((x_sample, x_sample_reconstr), axis=1)  # (sample_size, 2, h, w, d)
    x_array = np.moveaxis(x_array, 0, 1)  # (2, sample_size, h, w, d)
    plotting.plot_subset(x_array)
    name = "reconstructions"
    fig = plt.gcf()
    neptune.log_image(plot_log_name, fig, image_name=name)
    exp.save_figure(fig, name)
    plt.close(fig)

    # The following plots only work for torus-like data, i.e. with two cyclic factors.
    encoded_data = model_l["encoder_params"].predict(labelled_inputs)
    circle_encoded = encoded_data[0]
    euclidean_encoded = encoded_data[2]

    print("Start Plotting")
    print("Plotting PCA")
    # Per-object Euclidean code: mean over axis 1 of the encoder output.
    object_embedding = PCA(n_components=2).fit_transform(np.mean(euclidean_encoded, 1))
    _log_figure(_scatter_2d(object_embedding, data_class.labels[:, 0],
                            "Embedding Object Euclidean Color: Class"),
                "euclidean_object", plot_log_name)

    # Per-image Euclidean codes: one PCA projection, shown with two different colourings.
    n_images = euclidean_encoded.shape[0] * euclidean_encoded.shape[1]
    image_embedding = PCA(n_components=2).fit_transform(euclidean_encoded.reshape(n_images, -1))
    labels_grid_shape = (data_class.labels.shape[0], data_class.labels.shape[1])
    _log_figure(_scatter_2d(image_embedding, data_class.labels.reshape(labels_grid_shape),
                            "Embedding Image Euclidean Color: Class"),
                "euclidean_image", plot_log_name)
    _log_figure(_scatter_2d(image_embedding, transformations_circular.reshape(labels_grid_shape),
                            "Embedding Image Euclidean Color: Angles"),
                "euclidean_image_rotation", plot_log_name)

    print("Plotting circle embeddings")
    # One ring per object; radius scale*(index+1) so object 0 is not collapsed onto the origin.
    scale = 2
    fig, ax = plt.subplots(1, 1, figsize=(5, 5))
    for num_object, object_circle in enumerate(circle_encoded):
        radius = scale * (num_object + 1)
        ax.scatter(radius * object_circle[:, 0], radius * object_circle[:, 1],
                   c=range(n_views), cmap="Reds")
    ax.set_title(r"$S^1$")
    _log_figure(fig, "circle_embeddings", plot_log_name)

    print("Plotting parameter distribution")
    # Histograms of the first two labelled-encoder outputs, overlaid (alpha keeps both visible).
    fig, ax = plt.subplots(1, 1, figsize=(5, 5))
    ax.hist(encoded_data[0].flatten(), label="location", alpha=0.5)
    ax.hist(encoded_data[1].flatten(), label="scale", alpha=0.5)
    ax.set_title("Distribution of circular embeddings parameters")
    ax.legend()
    _log_figure(fig, "circular_parameters_distribution", plot_log_name)

    # Latent traversal between objects 0 and 1 (requires at least two objects).
    print("Plotting latent traversal distribution")
    fig, _ = plot_traversal_between_objects(model_l["decoder"],
                                            euclidean_encoded,
                                            num_object_start=0,
                                            num_object_end=1,
                                            num_interpolation_points=10,
                                            num_angles=10)
    _log_figure(fig, "latent_traversal_objects", plot_log_name)

    print("Done!")


