import io

import torch
from torchvision.transforms import ToTensor
from PIL import Image
from matplotlib import pyplot as plt
from matplotlib import colors as mcolors

# Select matplotlib's non-interactive 'agg' backend at import time; the helpers
# below only render figures into in-memory buffers and never show a window.
plt.switch_backend('agg')


def plt2tensor(fig):
    """Rasterize a matplotlib figure into an image tensor, then close the figure.

    The figure is encoded as JPEG into an in-memory buffer, decoded with PIL
    and converted with torchvision's ``ToTensor``.

    Args:
        fig: The matplotlib figure to render. It is closed before returning,
            whether or not rendering succeeded.

    Returns:
        The tensor produced by ``ToTensor`` for the decoded image
        (channels-first, values scaled to [0, 1] per torchvision's docs).
    """
    try:
        with io.BytesIO() as buf:
            fig.savefig(buf, format='jpg')
            buf.seek(0)
            # PIL decodes lazily, so the conversion has to happen while the
            # buffer is still open.
            with Image.open(buf) as image:
                return ToTensor()(image)
    finally:
        # Close exactly the figure we were given; a bare plt.close() would
        # close whatever pyplot considers the current figure instead.
        plt.close(fig)


def draw_2d(x, vmin=0, vmax=1, cmap='bwr', add_colorbar=False):
    """Render a 2-D tensor as a heat map and return it as an image tensor.

    Args:
        x: Tensor to plot; expected layout is (nx, nt).
        vmin: Lower colour limit passed to ``imshow``.
        vmax: Upper colour limit passed to ``imshow``.
        cmap: Matplotlib colormap name.
        add_colorbar: When true, attach a colorbar next to the image.

    Returns:
        The image tensor produced by ``plt2tensor`` for the figure.
    """
    figure, axis = plt.subplots(figsize=(4, 4))
    values = x.detach().cpu().numpy()
    heatmap = axis.imshow(
        values,
        cmap=cmap,
        aspect='auto',
        vmin=vmin,
        vmax=vmax,
    )
    # Hide ticks and frame so only the data is visible.
    axis.axis('off')
    if not add_colorbar:
        return plt2tensor(figure)
    figure.colorbar(heatmap, ax=axis, fraction=0.046, pad=0.04)
    return plt2tensor(figure)


def draw_3d(x, downsample=2, nrow=5, vmin=-2.5, vmax=2.5, cmap='bwr', add_colorbar=False):
    """Render the slices of a 3-D tensor along its last axis as a grid of heat maps.

    Args:
        x: Tensor to plot; expected layout is (nx, ny, nt). Each ``x[..., t]``
            becomes one panel of the grid.
        downsample: When greater than 1, keep only every ``downsample``-th
            slice along the last axis.
        nrow: Number of rows in the panel grid. The number of columns is
            chosen so that every remaining slice gets a panel; unused panels
            are left blank.
        vmin: Lower colour limit shared by all panels.
        vmax: Upper colour limit shared by all panels.
        cmap: Matplotlib colormap name.
        add_colorbar: When true, add one shared vertical colorbar to the
            right of the grid.

    Returns:
        The image tensor produced by ``plt2tensor`` for the figure.

    Raises:
        ValueError: If there is no slice to draw along the last axis.
    """
    if downsample > 1:
        x = x[..., ::downsample]
    # Move the slice axis to the front so iterating yields one 2-D frame per
    # panel, and hand matplotlib plain numpy arrays (as draw_2d does).
    frames = x.detach().cpu().permute(2, 0, 1).numpy()
    n_frames = frames.shape[0]
    if n_frames == 0:
        raise ValueError('draw_3d needs at least one slice along the last axis')

    # Ceiling division: never drop trailing slices and never ask for 0 columns.
    ncol = -(-n_frames // nrow)
    # squeeze=False keeps `axes` a 2-D array even for a 1x1 grid.
    fig, axes = plt.subplots(nrow, ncol, figsize=(8, 8), squeeze=False)
    fig.subplots_adjust(hspace=0., wspace=0.)
    panels = axes.ravel()
    for ax, frame in zip(panels, frames):
        ax.imshow(frame, cmap=cmap, aspect='auto', vmin=vmin, vmax=vmax)
    # Hide decorations on every panel, including ones left without a frame.
    for ax in panels:
        ax.axis('off')

    if add_colorbar:
        # The panels share vmin/vmax, so one standalone mappable describes
        # the colour scale of all of them.
        norm = mcolors.Normalize(vmin=vmin, vmax=vmax)
        sm = plt.cm.ScalarMappable(norm=norm, cmap=cmap)
        sm.set_array([])

        # Bounding box of the whole grid in figure coordinates.
        boxes = [ax.get_position() for ax in panels]
        bottom = min(b.y0 for b in boxes)
        top = max(b.y1 for b in boxes)
        right = max(b.x1 for b in boxes)
        pad = 0.01
        max_right = 0.98
        width = min(0.02, max_right - right - pad)
        if width <= 0.001:
            # Hardly any room right of the grid: shrink the gap and fall back
            # to a minimum bar width.
            pad = max(0.002, max_right - right - 0.02)
            width = max(0.005, max_right - right - pad)
        cax = fig.add_axes([right + pad, bottom, width, top - bottom])
        fig.colorbar(sm, cax=cax, orientation='vertical')

    # Render the figure created here rather than whatever pyplot considers current.
    return plt2tensor(fig)


def draw(x, **kwargs):
    """Plot a tensor with the renderer that matches its number of dimensions.

    Args:
        x: Tensor to plot; 2-D input goes to ``draw_2d`` and 3-D input to
            ``draw_3d``.
        **kwargs: Forwarded unchanged to the selected renderer.

    Returns:
        Whatever the selected renderer returns.

    Raises:
        ValueError: If ``x`` is neither 2-D nor 3-D.
    """
    ndim = x.dim()
    if ndim not in (2, 3):
        raise ValueError(f'Unsupported dimension {ndim}')
    renderer = draw_2d if ndim == 2 else draw_3d
    return renderer(x, **kwargs)
