from tensorflow import keras
import numpy as np
import neptune
import matplotlib.pyplot as plt

from modules.utils import utils, plotting


class NeptuneMonitor(keras.callbacks.Callback):
    """Keras callback that forwards every end-of-epoch metric to Neptune."""

    def on_epoch_end(self, epoch, logs=None):
        """Send each entry of ``logs`` to Neptune as one metric point.

        Args:
            epoch: Index of the epoch that just finished; passed to
                ``neptune.send_metric`` as the x value.
            logs: Mapping of metric name to value for this epoch. ``None``
                (the declared default) is treated as "no metrics to send".
        """
        # Iterating None would raise TypeError, so fall back to an empty mapping.
        for metric_name, value in (logs or {}).items():
            neptune.send_metric(metric_name, epoch, value)


class Animations(keras.callbacks.Callback):
    """Keras callback that periodically renders and logs diagnostic figures.

    On every epoch where ``epoch % periodicity == 0`` it produces three
    figures: a grid of decoded latent traversals, the encoded dataset drawn
    with ``plotting.plot_torus_angles``, and the encoded dataset drawn with
    ``plotting.plot_euclidean_embedding``. Each figure is sent to Neptune,
    saved through ``exp.save_figure``, and then closed.
    """

    def __init__(self, model_u, exp, data_class, periodicity=1, grid_size=20):
        """
        Args:
            model_u: Mapping of sub-models; this callback uses the
                ``"decoder"`` and ``"encoder_params"`` entries, each of which
                must provide ``predict``.
            exp: Experiment object providing
                ``save_figure(fig, image_name=..., sub_dir=...)``.
            data_class: Dataset holder exposing ``flat_images`` and
                ``flat_factor_mesh_as_angles`` (indexed as ``[:, 0]`` and
                ``[:, 1]`` below).
            periodicity: Figures are produced only on epochs where
                ``epoch % periodicity == 0``.
            grid_size: Number of evenly spaced angles in ``[0, 2*pi)`` used
                for the latent traversal grid (previously hard-coded to 20).
        """
        super().__init__()
        self.model_u = model_u
        self.exp = exp
        self.data_class = data_class
        self.periodicity = periodicity
        self.grid_size = grid_size

    def on_epoch_end(self, epoch, logs=None):
        """Render and log all diagnostic figures on every ``periodicity``-th epoch."""
        if epoch % self.periodicity != 0:
            return
        self._plot_latent_traversals(epoch)
        self._plot_embeddings(epoch)

    def _log_current_figure(self, sub_dir, name):
        """Send the current pyplot figure to Neptune and ``exp``, then close it."""
        fig = plt.gcf()
        try:
            neptune.log_image(sub_dir, fig, image_name=name)
            self.exp.save_figure(fig, image_name=name, sub_dir=sub_dir)
        finally:
            # Close even if logging/saving fails so repeated epochs don't
            # accumulate open figures.
            plt.close(fig)

    def _plot_latent_traversals(self, epoch):
        """Decode a grid of latent points and log it as a single figure."""
        grid_size = self.grid_size
        angles_grid = np.linspace(0, 2 * np.pi, num=grid_size, endpoint=False)
        z_grid = utils.angle_to_point(angles_grid)  # shape (grid_size, 2)
        z_grid = np.expand_dims(z_grid, axis=1)  # shape (grid_size, 1, 2)
        decoder = self.model_u["decoder"]
        # For each z, the decoder's first input sweeps all of z_grid while the
        # second input is held at z; the results are stacked along axis 1.
        x_array = np.concatenate(
            [decoder.predict([z_grid, np.tile(z, (grid_size, 1, 1))])
             for z in z_grid],
            axis=1)  # shape (grid_size, grid_size, h, w, d)
        plotting.plot_subset(x_array)
        self._log_current_figure(
            "animation_traversals", f"latent_traversal_grid_epoch{epoch}")

    def _plot_embeddings(self, epoch):
        """Encode the full dataset and log torus and Euclidean embedding plots."""
        x_full = np.expand_dims(self.data_class.flat_images, axis=1)
        encoded = self.model_u["encoder_params"].predict(x_full)

        print("...plotting embeddings")
        v_angle = self.data_class.flat_factor_mesh_as_angles[:, 0]
        h_angle = self.data_class.flat_factor_mesh_as_angles[:, 1]
        colors = plotting.yiq_embedding(v_angle, h_angle)

        plotting.plot_torus_angles(encoded, colors=colors)
        self._log_current_figure("animation_torus", f"encoded_torus_epoch{epoch}")

        plotting.plot_euclidean_embedding(encoded, colors=colors)
        self._log_current_figure(
            "animation_euclidean", f"encoded_euclidean_epoch{epoch}")

