import numpy as np
# import matplotlib
# matplotlib.use('Agg')  # Must be before importing matplotlib.pyplot or pylab!
import matplotlib.pyplot as plt
from matplotlib import patches
import imageio
import os
from PIL import Image

from modules.utils import utils


def plot_subset(x_array, cols=None, outlines=True):
    """Plot a grid of images and return the figure.

    Args:
        x_array: array of images of shape (rows, cols, h, w, d), where d is 1
            (greyscale) or 3 (RGB).
        cols: number of subplot columns. Defaults to the number of images per
            data row. When smaller, each data row wraps over
            ceil(x_cols / cols) grid rows.
        outlines: if True, hide the x/y axes (ticks and labels) but keep the
            frame; otherwise turn the axes off entirely.

    Returns:
        The matplotlib Figure containing the grid.
    """
    x_rows, x_cols, height, width, depth = x_array.shape
    assert depth == 1 or depth == 3, "x_array must contain greyscale or RGB images"
    cols = (cols if cols else x_cols)
    # Grid rows used by one data row. Kept as an int so that subplot indices
    # are integers (matplotlib rejects float subplot numbers).
    rows_per_data_row = int(np.ceil(x_cols / cols))
    rows = x_rows * rows_per_data_row

    fig = plt.figure(figsize=(cols * 2, rows * 2))

    def draw_subplot(x_, ax_):
        if depth == 1:
            ax_.imshow(x_.reshape([height, width]), cmap="Greys_r", vmin=0, vmax=1)
        elif depth == 3:
            ax_.imshow(x_)
        if outlines:
            ax_.get_xaxis().set_visible(False)
            ax_.get_yaxis().set_visible(False)
        else:
            ax_.set_axis_off()

    for j, x_row in enumerate(x_array):
        # 1-based subplot index of the first image of data row j.
        row_offset = j * cols * rows_per_data_row
        for i, x in enumerate(x_row, 1):
            ax = fig.add_subplot(rows, cols, row_offset + i)
            draw_subplot(x, ax)

    return fig


def plot_histograms(s_array, cols=None, outlines=True, filepath=None):
    """Draw one bar chart per categorical variable, laid out in a grid.

    Args:
        s_array: array of shape (n_variables, n_classes). Row k holds the bar
            heights for variable k.
        cols: number of grid columns (defaults to n_variables, i.e. one row).
        outlines: if True, hide the x/y axes but keep the frame; otherwise
            turn the axes off entirely.
        filepath: if None, show the figure interactively; otherwise save it
            to filepath + ".png".
    """
    n_variables, n_classes = s_array.shape
    n_cols = cols or n_variables
    n_rows = int(np.ceil(n_variables / n_cols))

    plt.figure(figsize=(n_cols * 2, n_rows * 2))

    for position, heights in enumerate(s_array, start=1):
        axis = plt.subplot(n_rows, n_cols, position)  # subplot numbering starts at 1
        plt.bar(range(n_classes), heights)
        # The y-axis is deliberately not fixed to [0, 1]; that reads poorly with many classes.
        if not outlines:
            axis.set_axis_off()
        else:
            axis.get_xaxis().set_visible(False)
            axis.get_yaxis().set_visible(False)

    if filepath is not None:
        plt.savefig(filepath + ".png", bbox_inches='tight')
    else:
        plt.show()
    plt.close()


def plot_rotations(rotations_array, n_cols=5, filepath=None):
    """Scatter-plot each sample's sequence of 2D points in its own subplot.

    Points are coloured by their position in the sequence using the "hsv" colormap.

    Args:
        rotations_array: 3D array. rotations_array[i, :, 0] and [i, :, 1] are
            the x and y coordinates for sample i.
        n_cols: number of subplot columns.
        filepath: if None, show the plot interactively; otherwise save it to
            filepath + ".png".
    """
    assert len(rotations_array.shape) == 3, "rotations_array must have shape (sample_size, n_rotations, 2)"
    sample_size = int(rotations_array.shape[0])
    n_rotations = int(rotations_array.shape[1])
    # matplotlib requires an integer row count; np.ceil alone returns a float.
    n_rows = int(np.ceil(sample_size / n_cols))
    # The colour index is the same for every sample, so build it once.
    c = np.arange(n_rotations)

    for i, rotations in enumerate(rotations_array):
        x = rotations[:, 0]
        y = rotations[:, 1]
        plt.subplot(n_rows, n_cols, i + 1)
        plt.axis("off")
        plt.axis("equal")
        plt.scatter(x, y, c=c, cmap="hsv", marker=".")

    if filepath is None:
        plt.show()
    else:
        plt.savefig(filepath + ".png", bbox_inches='tight')
    plt.close()


def plot_manifold_2d(decoding_function, grid_x, grid_y):
    """Decode a 2D grid of latent points and plot the decoded images as a grid.

    Args:
        decoding_function: callable taking an array of shape (n_x, 2), where
            column 0 is grid_x and column 1 is the current grid_y value. It
            should return one image per row; presumably the shape is
            (n_x, height, width, depth), as plot_subset expects — confirm
            with the caller.
        grid_x: 1D array of shape (n_x,) with the first latent coordinate.
        grid_y: 1D array of shape (n_y,) with the second latent coordinate.
            Each value produces one row of the plot.

    Returns:
        The Figure created by plot_subset. It was previously discarded, so
        callers could only reach it through pyplot's global state.
    """
    n_x = len(grid_x)
    grid_x_column = np.expand_dims(grid_x, axis=1)  # shape (n_x, 1)
    x_array = []
    for y in grid_y:
        y_column = np.expand_dims(np.repeat(y, n_x), axis=1)  # shape (n_x, 1)
        z_sample = np.concatenate((grid_x_column, y_column), axis=1)  # shape (n_x, 2), decoder input
        x_array.append(decoding_function(z_sample))
    return plot_subset(np.array(x_array))


def yiq_to_rgb(yiq):
    """Convert YIQ colours to RGB.

    Args:
        yiq: array whose last axis has length 3 and holds (Y, I, Q).

    Returns:
        Array of the same shape whose last axis holds (R, G, B). Values are
        not clipped to [0, 1].
    """
    # Each row is one output channel (R, G, B); columns are the Y, I, Q
    # coefficients of the standard inverse transform. Note the negative Q
    # coefficient for G: G = Y - 0.272 I - 0.647 Q.
    conv_matrix = np.array([[1., 0.956, 0.619],
                            [1., -0.272, -0.647],
                            [1., -1.106, 1.703]])
    # Contract the last axis of yiq with the last (column) axis of the matrix.
    return np.tensordot(yiq, conv_matrix, axes=((-1,), (-1,)))


def yiq_embedding(theta, phi, steps=12, rounding=True):
    """Map a pair of angles (theta, phi) to an RGB colour via a YIQ colour wheel.

    Chroma (I, Q) follows phi around the colour wheel. Luminance (Y) has a
    ripple in theta + phi and a shift in sin(phi).

    Args:
        theta: angle(s) in radians.
        phi: angle(s) in radians; must broadcast against theta.
        steps: number of discrete levels per full turn. It also sets the
            frequency of the luminance ripple. Default 12 (previously
            hard-coded).
        rounding: if True (previously hard-coded), snap both angles to the
            nearest multiple of 2*pi / steps first.

    Returns:
        RGB array of shape broadcast(theta, phi).shape + (3,). Not clipped.
    """
    theta = np.asarray(theta)
    phi = np.asarray(phi)
    # Use the broadcast shape so that theta and phi need not have identical shapes.
    result = np.zeros(np.broadcast(theta, phi).shape + (3,))
    if rounding:
        theta = 2 * np.pi * np.round(steps * theta / (2 * np.pi)) / steps
        phi = 2 * np.pi * np.round(steps * phi / (2 * np.pi)) / steps
    result[..., 0] = 0.5 + 0.14 * np.cos((theta + phi) * steps / 2) - 0.2 * np.sin(phi)
    result[..., 1] = 0.25 * np.cos(phi)
    result[..., 2] = 0.25 * np.sin(phi)
    return yiq_to_rgb(result)


def plot_training_output(training_output, filepath=None):
    """Plot each training metric against epochs, one figure per metric.

    The y-range is set from the 5%-95% quantiles of the finite values, padded
    by 5% of that range on each side, so outliers do not flatten the curve.

    Args:
        training_output: mapping from metric name to a sequence of per-epoch values.
        filepath: if None, each plot is shown interactively; otherwise it is
            saved to filepath + "_" + metric + ".png".
    """
    gap = 0.05
    for metric, values_list in training_output.items():
        # Use a float dtype so NaN can be stored even when the metric values are ints.
        values = np.array(values_list, dtype=float)
        # Turn +/- inf into NaN so nanquantile ignores them and the axes stay finite.
        values[np.isinf(values)] = np.nan
        if np.isfinite(values).any():
            bottom = np.nanquantile(values, gap)
            top = np.nanquantile(values, 1 - gap)
            # Pad both ends by the same fraction of the unpadded range.
            margin = gap * (top - bottom)
            plt.ylim(bottom - margin, top + margin)
        # With no finite values there is no sensible range; leave autoscaling on
        # (plt.ylim(nan, nan) would raise).
        plt.xlabel("epochs")
        plt.ylabel(metric)
        plt.plot(values)
        if filepath is None:
            plt.show()
        else:
            plt.savefig(filepath + "_" + metric + ".png", bbox_inches='tight')
        plt.close()


def plot_torus_angles(encoded, colors):
    """Scatter two angles derived from encoded[1] and encoded[4] against each other.

    Each angle is arctan2 of component 0 over component 1 along the last axis.
    The entries are presumably 2D circle coordinates from a torus encoder —
    confirm with the caller.

    Args:
        encoded: indexable whose entries 1 and 4 are 3D arrays indexed as
            [:, :, 0] and [:, :, 1].
        colors: colour specification passed to scatter.

    Returns:
        The created Figure.
    """
    def to_angle(points):
        return np.arctan2(points[:, :, 0], points[:, :, 1])

    horizontal = to_angle(encoded[1])
    vertical = to_angle(encoded[4])

    figure = plt.figure(figsize=(5, 5))
    axes = figure.gca()
    axes.scatter(horizontal, vertical, color=colors)
    axes.set_title("Torus encoded")
    return figure


def plot_decoded(decoded, height_grid, width_grid, divisor=2):
    """Plot a subsampled grid of decoded images.

    Draws a (height_grid // divisor) x (width_grid // divisor) grid. Subplot k
    shows decoded[k * divisor, 0]. A trailing channel axis of size 1 is
    dropped so greyscale images can be passed to imshow.

    Args:
        decoded: array indexed as decoded[k, 0, :, :, channel]; presumably
            (n, 1, height, width, channels) — confirm with the caller.
        height_grid: number of grid rows before subsampling.
        width_grid: number of grid columns before subsampling.
        divisor: subsampling factor for both the grid size and the image index.

    Returns:
        The created Figure.
    """
    # Compute rows and columns separately: `h // d * w // d` parses as
    # ((h // d) * w) // d, which exceeds the grid size whenever w is not a
    # multiple of d, making add_subplot raise.
    n_rows = height_grid // divisor
    n_cols = width_grid // divisor
    fig = plt.figure(figsize=(5, 5))
    for data_num in range(n_rows * n_cols):
        ax = fig.add_subplot(n_rows, n_cols, data_num + 1)
        image = decoded[data_num * divisor, 0]
        if decoded.shape[-1] == 1:
            image = image[..., 0]  # drop the singleton channel for imshow
        ax.imshow(image)
        ax.set_xticks([])
        ax.set_yticks([])
        ax.set_aspect("equal")
    return fig


def plot_euclidean_embedding(encoded, colors):
    """Scatter two 2D embeddings side by side, each with the unit circle for reference.

    The left panel uses encoded[0] and the right panel uses encoded[3]. Each
    is indexed as [:, :, 0] (x) and [:, :, 1] (y).

    Args:
        encoded: indexable whose entries 0 and 3 are 3D arrays of 2D points.
        colors: colour specification passed to scatter.

    Returns:
        The created Figure.
    """
    fig = plt.figure(figsize=(10, 5))
    panels = ((0, "Euclidean embedding 1"), (3, "Euclidean embedding 2"))
    for panel_number, (entry, title) in enumerate(panels, start=1):
        points = encoded[entry]
        ax = fig.add_subplot(1, 2, panel_number)
        ax.scatter(points[:, :, 0], points[:, :, 1], color=colors)
        # Unit circle drawn behind the points (zorder=-1) as a reference.
        ax.add_artist(patches.Circle((0, 0), 1.0, fill=False, zorder=-1))
        ax.set_title(title)
    return fig


def make_gif(source_path, image_names_sorted, gif_path, gif_name, resolution=None, duration=0.25):
    """Assemble image files into an animated GIF.

    Args:
        source_path: directory containing the frame images.
        image_names_sorted: frame file names, in playback order.
        gif_path: output directory.
        gif_name: output file name without the ".gif" extension.
        resolution: optional (width, height) passed to PIL `Image.resize`.
        duration: passed unchanged to `imageio.mimsave`. Its unit depends on
            the installed imageio version — confirm.
    """
    frames = []
    for filename in image_names_sorted:
        # Open each file with a context manager so its handle is released
        # immediately. copy()/resize() load the pixels into a new image that
        # no longer depends on the (now closed) file.
        with Image.open(os.path.join(source_path, filename)) as image:
            if resolution is None:
                frames.append(image.copy())
            else:
                frames.append(image.resize(resolution))
    imageio.mimsave(os.path.join(gif_path, gif_name + '.gif'), frames, duration=duration)


def plot_latent_dimension_combinations(z, colors_flat):
    """
    Plot the embeddings for every pair-wise combination of their dimensions.

    Subplot (dim1, dim2) scatters dimension dim1 (x) against dimension dim2 (y).

    Args:
        z: array of num_vectors embeddings with z_dim dimensions, shape (num_vectors, z_dim)
        colors_flat: array of colour values for each of the num_vectors embeddings,
            shape (num_vectors, color_channels)

    Returns:
        fig, axes. axes is always a 2D (z_dim, z_dim) array of Axes, including
        when z_dim == 1.
    """
    total_dimensions = z.shape[-1]
    # squeeze=False keeps `axes` 2D. Otherwise z_dim == 1 yields a bare Axes
    # and the axes[dim1, dim2] indexing below fails.
    fig, axes = plt.subplots(total_dimensions, total_dimensions, figsize=(10, 10), squeeze=False)
    for dim1 in range(total_dimensions):
        for dim2 in range(total_dimensions):
            axes[dim1, dim2].scatter(z[:, dim1], z[:, dim2], c=colors_flat)
            axes[dim1, dim2].set_title("Dim ({}, {})".format(dim1, dim2))
    return fig, axes


def latent_traversals_s1_x_rd(decoder, z2_dim, n_samples=10, n_traversals=10):
    """
    Plot latent traversals for a latent space S^1 x R^d.

    Draws n_samples points from a standard normal on R^d. For each point, the
    decoder is evaluated at n_traversals evenly spaced angles on S^1.

    Args:
        decoder (tf.keras.Model): decoder model with two inputs of dim 2 and d, respectively, for S^1 and R^d
        z2_dim (int): dimension d of R^d
        n_samples (int): number of samples from R^d, and number of rows in plot
        n_traversals (int): number of traversals in S^1, and number of columns in plot

    Returns:
         fig
    """
    # Evenly spaced angles; endpoint=False avoids repeating 0 == 2*pi.
    angles = np.linspace(0, 2 * np.pi, num=n_traversals, endpoint=False)
    circle_points = utils.angle_to_point(angles)  # shape (n_traversals, 2)
    euclidean_samples = np.random.normal(size=(n_samples, z2_dim))

    decoded_rows = []
    for sample in euclidean_samples:
        # Pair every angle with the same R^d sample.
        repeated_sample = np.tile(sample, (n_traversals, 1))
        decoded_rows.append(decoder.predict([circle_points, repeated_sample]))

    x_array = np.stack(decoded_rows, axis=0)  # shape (n_samples, n_traversals, h, w, d)
    return plot_subset(x_array)
