# ---------------------------
# _, _ -- 2019
# The University of _, The _ Institute
# contact: _, _
# ---------------------------
"""Functions to help visualize images and latent embeddings.
"""
import tensorflow as tf
import tfmpl
import numpy as np
import matplotlib.pyplot as plt


def pack_images(images, rows, cols, params):
    """Tile a batch of flattened images into a single image grid.

    Each row of ``images`` holds only ``params["n_x"]`` retained values.
    They are scattered back into a zero-filled flat image at the positions
    listed in ``params["good_dims"]``, reshaped to ``IMAGE_SHAPE`` and then
    tiled.

    Args:
      images: tensor reshapeable to (batch, params["n_x"]).
      rows: maximum number of tile rows.
      cols: maximum number of tile columns.
      params: dict with keys "IMAGE_SHAPE" (three entries), "n_x" and
        "good_dims" (flat indices into the full image, one per retained
        value).

    Returns:
      Tensor of shape (1, rows' * IMAGE_SHAPE[0], cols' * IMAGE_SHAPE[1],
      IMAGE_SHAPE[2]), where rows' = min(rows, batch) and
      cols' = min(batch // rows', cols). Batches smaller than rows * cols
      are tiled into a smaller grid instead of failing.
    """
    width = params["IMAGE_SHAPE"][0]
    height = params["IMAGE_SHAPE"][1]
    depth = params["IMAGE_SHAPE"][2]
    images = tf.reshape(images, (-1, params["n_x"]))
    # Clamp the grid to the real batch size *before* slicing. The former
    # per-image loop indexed rows * cols images unconditionally, so smaller
    # batches failed at run time and the clamp below it was dead code.
    batch = tf.shape(images)[0]
    rows = tf.minimum(rows, batch)
    cols = tf.minimum(batch // rows, cols)
    n_images = rows * cols
    images = images[:n_images]
    # scatter_nd writes along the leading axis of its output shape, so all
    # images are scattered at once with the pixel axis leading. Indices and
    # shape are both int32 because scatter_nd requires matching dtypes.
    indices = tf.cast([[d] for d in params["good_dims"]], tf.int32)
    full_shape = tf.stack(
        [tf.cast(width * height * depth, tf.int32), n_images])
    images = tf.transpose(
        tf.scatter_nd(indices, tf.transpose(images), full_shape))
    images = tf.reshape(images, (rows, cols, width, height, depth))
    images = tf.transpose(images, [0, 2, 1, 3, 4])
    images = tf.reshape(images, [1, rows * width, cols * height, depth])
    return images


def image_tile_summary(name, tensor, params, rows=8, cols=8):
    """
    Log a batch of images as a single tiled image summary.

    Args:
      name: name of the image summary
      tensor: batch of flattened images, in the layout ``pack_images``
        expects
      params: run params as dict, forwarded to ``pack_images``
      rows: number of rows in the tile
      cols: number of cols in the tile
    """
    grid = pack_images(tensor, rows, cols, params)
    # pack_images returns a batch of exactly one tiled image.
    tf.summary.image(name, grid, max_outputs=1)


def tf_pca(x, latent_size, n_dims=2):
    """
    Project data onto its leading principal components using an SVD.

    Args:
      x: tensor reshapeable to (-1, latent_size); each row is one sample.
      latent_size: number of features per sample.
      n_dims: number of leading components to keep.

    Returns:
      Tensor of shape (num_samples, min(n_dims, k)), with k the number of
      singular values. It holds the mean-centered data projected onto the
      first ``n_dims`` right singular vectors.
    """
    samples = tf.reshape(x, (-1, latent_size))
    # Subtract the per-feature mean (taken over the sample axis).
    centered = samples - tf.reduce_mean(samples, -2, keepdims=True)
    # The decomposition is explicitly pinned to the CPU.
    with tf.device('/cpu:0'):
        singular_values, left_vectors, _ = tf.svd(
            centered, full_matrices=False, compute_uv=True)
    # For X = U S V^T the projection X V equals U S, so V is never needed.
    scale = tf.expand_dims(singular_values, -2)
    return left_vectors[:, 0:n_dims] * scale[:, 0:n_dims]


@tfmpl.figure_tensor
def draw_pca_scatter(projected, sample_labels, n_samples):
    """
    Draws a scatter plot of the first two columns of projected embeddings.

    Args:
      projected: 2-D array of projected embeddings (e.g. PCA output);
        columns 0 and 1 are plotted.
      sample_labels: per-input labels, either class indices (1-D) or
        per-class rows (2-D, reduced to indices with argmax).
      n_samples: number of samples per input; the label vector is repeated
        this many times to label every row of ``projected``.
    Returns:
      A list with one matplotlib figure. It is presumably converted to an
      image tensor by the ``tfmpl.figure_tensor`` decorator.
    """
    # --- Reduce per-class rows to class indices, then repeat the whole
    # label vector n_samples times (np.tile), so row i of `projected` gets
    # the label of input i % num_inputs.
    # NOTE(review): assumes `projected` is ordered sample-major (all inputs
    # for sample 0, then sample 1, ...); confirm against the caller.
    if len(sample_labels.shape) > 1:
        labels = np.tile(np.argmax(sample_labels, axis=1), n_samples)
    else:
        labels = np.tile(sample_labels, n_samples)
    figs = tfmpl.create_figures(1, figsize=(4, 4))
    for f in figs:
        ax = f.add_subplot(111)
        # --- Scatter along the first two projected dimensions
        im = ax.scatter(
            projected[:, 0],
            projected[:, 1],
            c=labels,
            # plt.get_cmap works on all matplotlib versions, whereas
            # plt.cm.get_cmap was removed in matplotlib 3.9.
            # The 10-colour discrete map is presumably sized for <= 10
            # classes; confirm.
            cmap=plt.get_cmap('Spectral', 10),
            alpha=0.5)
        ax.grid(True)
        f.colorbar(im)
        f.tight_layout()
    return figs