import os
import torch
import torchvision


def save_plots(fig_names, outputs, save_root):
    """Save each trailing 2-D slice of ``outputs`` as its own image file.

    ``outputs`` is reshaped to ``(-1, 1, h, w)``, where ``h, w`` are its last
    two dimensions. Every leading dimension (batch, channel, ...) is therefore
    flattened, and each ``h x w`` slice is written as a separate
    single-channel image. Slices are paired with ``fig_names`` in order.

    Args:
        fig_names: Iterable of file names, each joined onto ``save_root`` to
            form the output path passed to ``torchvision.utils.save_image``.
        outputs: Array-like object with ``.shape`` and ``.reshape`` (e.g. a
            ``torch.Tensor`` or a NumPy array) with at least two dimensions.
        save_root: Output directory; created (including parents) if missing.

    Note:
        Pairing is done with ``zip``, so if the number of names and the number
        of slices differ, the surplus on the longer side is silently ignored.
    """
    os.makedirs(save_root, exist_ok=True)
    h, w = outputs.shape[-2:]
    # Flatten all leading dims: each (h, w) slice becomes one 1-channel image.
    outputs = outputs.reshape(-1, 1, h, w)

    for fig_name, output in zip(fig_names, outputs):
        # torch.tensor() on an existing tensor copy-constructs it and emits a
        # UserWarning on every call; detach() gives the same grad-free data
        # without the warning. Non-tensor slices (e.g. NumPy) are still
        # converted with torch.tensor() exactly as before.
        if isinstance(output, torch.Tensor):
            image = output.detach()
        else:
            image = torch.tensor(output)
        torchvision.utils.save_image(image, os.path.join(save_root, fig_name))
