from matplotlib import pyplot as plt
from PIL import Image
from torchvision.utils import make_grid


def plot_diffusion(x, nrow=16, figsize=(20, 20)):
    """Display a batch of images as a single tiled grid in a matplotlib figure.

    Args:
        x: Batch of images handed directly to ``torchvision.utils.make_grid``
            (presumably a ``(B, C, H, W)`` tensor -- confirm against callers).
            It may live on any device and may require grad.
        nrow: Number of images placed in each row of the grid.
        figsize: Figure size in inches, forwarded to ``plt.figure``.
    """
    # normalize=True asks make_grid to rescale pixel values into [0, 1] so
    # imshow renders them correctly regardless of the input's value range.
    grid = make_grid(x, nrow=nrow, normalize=True)
    # matplotlib converts its input through NumPy, which raises for tensors on
    # an accelerator or tensors that track gradients; use a detached CPU copy.
    # Both calls are no-ops for an ordinary CPU tensor.
    grid = grid.detach().cpu()

    plt.figure(figsize=figsize)
    plt.axis("off")
    # The grid is channels-first (C, H, W); imshow expects (H, W, C).
    plt.imshow(grid.permute(1, 2, 0))
    plt.show()


def image_grid(imgs, rows):
    """Paste a sequence of PIL images into a single RGB grid image.

    Images are laid out row-major: left to right, then top to bottom. Every
    cell is the size of ``imgs[0]``. The images are assumed to share that size;
    larger ones overlap neighbouring cells and smaller ones leave gaps.

    Args:
        imgs: Non-empty sized sequence of PIL images.
        rows: Number of grid rows; must be >= 1.

    Returns:
        A new RGB ``PIL.Image`` of size ``(cols * w, rows * h)``, where
        ``cols = ceil(len(imgs) / rows)`` and ``(w, h) = imgs[0].size``.
        Unfilled cells in the last row keep ``Image.new``'s default fill.

    Raises:
        ValueError: If ``rows`` is less than 1 or ``imgs`` is empty.
    """
    if rows < 1:
        raise ValueError(f"rows must be >= 1, got {rows}")
    if not imgs:
        raise ValueError("imgs must contain at least one image")

    # Ceiling division: flooring here would leave too few columns whenever
    # len(imgs) is not a multiple of rows, so the leftover images would be
    # pasted below the canvas and silently dropped.
    cols = -(-len(imgs) // rows)

    w, h = imgs[0].size
    grid = Image.new("RGB", size=(cols * w, rows * h))
    for i, img in enumerate(imgs):
        row, col = divmod(i, cols)
        grid.paste(img, box=(col * w, row * h))
    return grid
