import matplotlib.pyplot as plt
import numpy as np

import utils


def gridimshow(image, ax):
    """
    Draws a channel-first image onto ``ax`` with every axis decoration hidden.

    Single-channel images are shown inverted (``1 - image``) with the "Greys"
    colormap; all other images are permuted to channel-last and shown as-is.
    :param image: channel-first image tensor (presumably a torch tensor, since
        ``.permute`` is used -- TODO confirm against callers)
    :param ax: matplotlib Axes to draw on
    """
    single_channel = image.shape[0] == 1
    if not single_channel:
        rgb = utils.tensor2numpy(image.permute(1, 2, 0))
        ax.imshow(rgb)
    else:
        gray = utils.tensor2numpy(image[0, ...])
        ax.imshow(1 - gray, cmap="Greys")

    # Strip the frame, the tick marks and the tick labels.
    for side in ("top", "right", "left", "bottom"):
        ax.spines[side].set_visible(False)
    ax.tick_params(axis="both", length=0)
    ax.set_xticklabels("")
    ax.set_yticklabels("")


def plot_hist_marginals(data, lims=None, gt=None):
    """
    Plots marginal histograms and pairwise scatter plots of a dataset.

    :param data: array of samples. A 1-D array is plotted as a single
        histogram; a 2-D array is treated as ``(n_samples, n_dim)`` and plotted
        as an ``n_dim x n_dim`` grid with histograms on the diagonal and
        scatter plots (dimension ``i`` on x, dimension ``j`` on y) elsewhere.
    :param lims: optional axis limits. For 1-D data, a ``(low, high)`` pair.
        For 2-D data, either a single pair applied to every dimension or one
        pair per dimension.
    :param gt: optional ground-truth value(s) drawn in red: x position(s) of a
        vertical line for 1-D data, otherwise one value per dimension.
    :return: ``(fig, ax)``. For 1-D data ``ax`` is a single Axes; for 2-D data
        it is an ``(n_dim, n_dim)`` array of Axes.
    """

    # Square-root rule for the bin count; keep at least one bin so an empty
    # dataset doesn't make ``hist`` reject ``bins=0``.
    n_bins = max(1, int(np.sqrt(data.shape[0])))

    if data.ndim == 1:

        fig, ax = plt.subplots(1, 1)
        # ``density`` replaces ``normed``, which matplotlib 3.1 removed.
        ax.hist(data, n_bins, density=True)
        ax.set_ylim([0, ax.get_ylim()[1]])
        if lims is not None:
            ax.set_xlim(lims)
        if gt is not None:
            ax.vlines(gt, 0, ax.get_ylim()[1], color="r")

    else:

        n_dim = data.shape[1]
        fig, ax = plt.subplots(n_dim, n_dim)
        # subplots(1, 1) returns a bare Axes; wrap it so ax[i, j] always works.
        ax = np.array([[ax]]) if n_dim == 1 else ax

        if lims is not None:
            lims = np.asarray(lims)
            # A single (low, high) pair is shared by every dimension.
            lims = np.tile(lims, [n_dim, 1]) if lims.ndim == 1 else lims

        for i in range(n_dim):
            for j in range(n_dim):

                if i == j:
                    ax[i, j].hist(data[:, i], n_bins, density=True)
                    ax[i, j].set_ylim([0, ax[i, j].get_ylim()[1]])
                    if lims is not None:
                        ax[i, j].set_xlim(lims[i])
                    if gt is not None:
                        ax[i, j].vlines(gt[i], 0, ax[i, j].get_ylim()[1], color="r")

                else:
                    ax[i, j].plot(data[:, i], data[:, j], "k.", ms=2)
                    if lims is not None:
                        ax[i, j].set_xlim(lims[i])
                        ax[i, j].set_ylim(lims[j])
                    if gt is not None:
                        ax[i, j].plot(gt[i], gt[j], "r.", ms=8)

    plt.show(block=False)

    return fig, ax


def disp_imdata(xs, imsize, layout=(1, 1)):
    """
    Displays an array of images, a page at a time. The user can navigate pages with
    left and right arrows, start over by pressing space, or close the figure by esc.
    Paging wraps around the end of ``xs``.
    :param xs: a numpy array with images as rows
    :param imsize: shape each row is reshaped to before display
    :param layout: (rows, cols) grid of images shown per page
    :return: none
    :raises ValueError: if ``xs`` contains no images
    """

    num_plots = np.prod(layout)
    num_xs = xs.shape[0]
    if num_xs == 0:
        # Otherwise the page index is taken modulo zero and indexing fails obscurely.
        raise ValueError("disp_imdata: xs contains no images to display")

    idx = [0]  # first image of the current page (a list so the closures can mutate it)
    shown = []  # image artists drawn for the current page

    # create a figure with subplots
    fig, axs = plt.subplots(layout[0], layout[1])

    if isinstance(axs, np.ndarray):
        axs = axs.flatten()
    else:
        axs = [axs]

    for ax in axs:
        ax.axes.get_xaxis().set_visible(False)
        ax.axes.get_yaxis().set_visible(False)

    def plot_page():
        """Plots the page starting at image idx[0]."""

        # Drop the previous page's images; without this every page turn stacks
        # another AxesImage on each axes and redraws keep getting slower.
        while shown:
            shown.pop().remove()

        ii = np.arange(idx[0], idx[0] + num_plots) % num_xs

        for ax, i in zip(axs, ii):
            shown.append(ax.imshow(xs[i].reshape(imsize), cmap="gray", interpolation="none"))
            ax.set_title(str(i))

        fig.canvas.draw()

    def on_key_event(event):
        """Event handler after key press."""

        key = event.key

        if key == "right":
            # show next page
            idx[0] = (idx[0] + num_plots) % num_xs
            plot_page()

        elif key == "left":
            # show previous page
            idx[0] = (idx[0] - num_plots) % num_xs
            plot_page()

        elif key == " ":
            # show first page
            idx[0] = 0
            plot_page()

        elif key == "escape":
            # close figure
            plt.close(fig)

    fig.canvas.mpl_connect("key_press_event", on_key_event)
    plot_page()

