import os
import torch
import torchvision


def save_plots(fig_names, outputs, save_root, min_height=256):
    """Save a batch of 3-channel images to ``save_root``, one file per name.

    Args:
        fig_names: File names (joined onto ``save_root``), one per image.
            Names and images are paired with ``zip``, so surplus names or
            surplus images are silently ignored.
        outputs: Array or tensor whose last two dims are (H, W). It is
            reshaped to ``(-1, 3, H, W)``, so its total size must be a
            multiple of ``3 * H * W``. Presumably floating point in [0, 1]
            (the ``torchvision.utils.save_image`` convention) -- confirm
            with callers; bicubic resizing needs a floating dtype.
        save_root: Output directory; created (with parents) if missing.
        min_height: Images shorter than this are upscaled with bicubic
            interpolation to exactly this height. The width is scaled by the
            same factor (truncated to int), so tiny images stay legible.
            Defaults to 256.
    """
    os.makedirs(save_root, exist_ok=True)
    h, w = outputs.shape[-2:]
    # torch.as_tensor reuses an existing tensor (or wraps a numpy array)
    # instead of copy-constructing it like torch.tensor(), which warns and
    # copies. detach() keeps the original's behaviour of not recording the
    # resize in an autograd graph.
    images = torch.as_tensor(outputs.reshape(-1, 3, h, w)).detach()
    if h < min_height:
        new_w = int(w * min_height / h)
        images = torch.nn.functional.interpolate(
            images,
            size=(min_height, new_w),
            mode='bicubic',
            align_corners=False,  # PyTorch's default, stated explicitly
        )

    for fig_name, image in zip(fig_names, images):
        torchvision.utils.save_image(image, os.path.join(save_root, fig_name))
