from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import torchvision

import utils


def make_image_grid(samples, save_path,
                    title=None,
                    usetex=True):
    """
    Tile a batch of images into a grid and save the result as a PNG file.

    :param samples: image batch of shape (N, C, H, W). Values are expected
        to lie in [0, 1]; anything outside is clipped before the conversion
        to 8-bit so it cannot wrap around.
    :param save_path: output location *without* extension (``str`` or
        path-like); ``'.png'`` is appended to it.
    :param title: optional figure title; skipped when falsy.
    :param usetex: value assigned to matplotlib's ``text.usetex`` rc
        parameter. NOTE: this is a global setting and stays in effect after
        the call.
    """
    # make grid: floor(sqrt(N)) images per row
    n = int(np.sqrt(samples.size(0)))
    grid = torchvision.utils.make_grid(samples, nrow=n, padding=0)
    del samples  # drop the local reference to the batch

    # (C, H, W) -> (H, W, C), the layout imshow expects
    # (assumes utils.to_np returns a numpy array -- TODO confirm)
    im = utils.to_np(grid).transpose((1, 2, 0))
    # normalize to RGB range; clip first so out-of-range values saturate
    # instead of wrapping modulo 256 in the uint8 cast
    im = (np.clip(im, 0.0, 1.0) * 255).astype(np.uint8)

    # Configure text rendering before any text artist is created so the
    # title below is built under the requested setting.
    plt.rc('text', usetex=usetex)

    # create figure
    fig, ax = plt.subplots(figsize=(8.27, 8.27))
    try:
        ax.set_xticks([])
        ax.set_yticks([])
        if title:
            ax.set_title(title)
        ax.imshow(im)

        # f-string formatting accepts both str and pathlib.Path
        out_file = f'{save_path}.png'
        fig.savefig(out_file, format='png', dpi=500, transparent=True,
                    bbox_inches='tight')
    finally:
        # always release this specific figure, even if rendering/saving fails
        plt.close(fig)
