import numpy as np
import torch
import torchvision
from PIL import Image

import utils


def make_grid_from_paths(paths, width=300):
    """Load the images at ``paths`` and tile them into a single grid tensor.

    Args:
        paths: Sized sequence of image paths (a list, tuple or 1-D NumPy
            array all work). Must be non-empty, since ``torch.stack`` raises
            on an empty list. Every image presumably needs the same channel
            count so the tensors can be stacked -- mixed modes (e.g. RGB and
            grayscale) will fail in ``torch.stack``.
        width: Side length in pixels passed to ``open_image`` for each image.

    Returns:
        ``(grid, images)``: ``images`` is the list of tensors returned by
        ``open_image``, in ``paths`` order. ``grid`` is their
        ``torchvision.utils.make_grid`` tiling, with no padding and
        ``int(sqrt(len(paths)))`` images per row.
    """
    # len() accepts lists/tuples as well as NumPy arrays (``.shape[0]`` only
    # worked for arrays); computing it first fails fast before any image I/O.
    nrow = int(np.sqrt(len(paths)))
    images = [open_image(path, width=width) for path in paths]
    grid = torchvision.utils.make_grid(
        torch.stack(images), nrow=nrow, padding=0)
    return grid, images


def open_image(path, width=300, to_numpy=False):
    """Load an image, center-crop it to a square, and resize it.

    Args:
        path: Anything accepted by ``PIL.Image.open`` (typically a file path).
        width: Side length in pixels of the square output.
        to_numpy: If True, move channels last (``permute(1, 2, 0)``) and
            convert with ``utils.to_np`` (presumably to a NumPy array --
            confirm against ``utils``).

    Returns:
        By default, the ``ToTensor`` output of shape ``(C, width, width)``;
        ``C`` depends on the image mode (e.g. 3 for RGB, 1 for grayscale),
        and 8-bit images are scaled to [0, 1]. With ``to_numpy=True``, the
        channels-last ``(width, width, C)`` result of ``utils.to_np``.
    """
    to_tensor = torchvision.transforms.ToTensor()
    resize = torchvision.transforms.Resize((width, width))
    # Image.open is lazy and holds the file handle open; the context manager
    # closes it once the pixels have been read by the transforms below.
    with Image.open(path) as img:
        # Crop the largest centred square first so the square resize does
        # not distort the aspect ratio.
        crop = torchvision.transforms.CenterCrop(min(img.width, img.height))
        im = to_tensor(resize(crop(img)))
    if to_numpy:
        im = utils.to_np(im.permute(1, 2, 0))
    return im
