import matplotlib

matplotlib.use("Agg")

import matplotlib.pyplot as plt
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.backends.backend_agg import FigureCanvasAgg as FigureCanvas

import functools as ft
import matplotlib.gridspec as gridspec
import math


def most_squarelike(n):
    c = int(n ** 0.5)
    while c > 0:
        if n % c in [0, c - 1]:
            return (c, int(math.ceil(n / c)))
        c -= 1


def make_visual(images, metrics, visualization_methods=[]):
    w, h = most_squarelike(len(visualization_methods))
    gs = gridspec.GridSpec(h + 1, w)

    fig = plt.figure(tight_layout=True)
    canvas = FigureCanvas(fig)

    ax = fig.add_subplot(gs[0, :])
    view_images(ax, images, n_images=w * 4)

    for i in range(len(visualization_methods)):
        wi, hi = i % w, i // w
        ax = fig.add_subplot(gs[hi + 1, wi])
        visualization_methods[i](ax=ax, metrics=metrics)

    plt.tight_layout()
    canvas.draw()
    out_image = np.frombuffer(canvas.tostring_rgb(), dtype="uint8")
    out_image = out_image.reshape(fig.canvas.get_width_height()[::-1] + (3,))
    return out_image


def make_visual_no_image(metrics, visualization_methods=[]):
    w, h = most_squarelike(len(visualization_methods))
    gs = gridspec.GridSpec(h, w)

    fig = plt.figure(tight_layout=True)
    canvas = FigureCanvas(fig)

    for i in range(len(visualization_methods)):
        wi, hi = i % w, i // w
        ax = fig.add_subplot(gs[hi, wi])
        visualization_methods[i](ax=ax, metrics=metrics)

    plt.tight_layout()
    canvas.draw()
    out_image = np.frombuffer(canvas.tostring_rgb(), dtype="uint8")
    out_image = out_image.reshape(fig.canvas.get_width_height()[::-1] + (3,))
    return out_image


def np_unstack(array, axis):
    arr = np.split(array, array.shape[axis], axis)
    arr = [a.squeeze() for a in arr]
    return arr


def view_images(ax, images, n_images=4):
    assert len(images.shape) == 4
    assert images.shape[-1] == 3
    interval = images.shape[0] // n_images
    sel_images = images[::interval]
    sel_images = np.concatenate(np_unstack(sel_images, 0), 1)
    ax.imshow(sel_images)


def visualize_metric(ax, metrics, *, metric_name, linestyle="--", marker="o", **kwargs):
    metric = metrics[metric_name]
    ax.plot(metric, linestyle=linestyle, marker=marker, **kwargs)
    ax.set_ylabel(metric_name)


def visualize_metrics(ax, metrics, *, ylabel=None, metric_names, **kwargs):
    for metric_name in metric_names:
        metric = metrics[metric_name]
        ax.plot(metric, linestyle="--", marker="o", **kwargs)
    ax.set_ylabel(ylabel or metric_names[0])
