import torch
import numpy as np
import matplotlib.pyplot as plt
import os
import torchvision.transforms.functional as F
# Make every savefig call crop the output to the drawn content by default.
plt.rcParams["savefig.bbox"] = 'tight'

# Directory that receives the saved visualisations. The path is relative, so
# it is resolved against the current working directory of the process.
data_dir = "../dataviz"
# Runs at import time; exist_ok=True makes it a no-op when the directory exists.
os.makedirs(data_dir, exist_ok=True)

def showimage(imgs, *, close=True):
    """Draw one or more images in a single row and save the figure as a PNG.

    Each image is converted with torchvision's ``to_pil_image``, drawn on its
    own axes with ticks and tick labels removed, and the resulting figure is
    written to ``batch_vis.png`` inside ``data_dir``. The file name is fixed,
    so every call writes to the same path.

    Args:
        imgs: A single image, or a list/tuple of images. Every image must
            provide ``detach()`` and be accepted by ``to_pil_image``
            (presumably ``torch.Tensor`` images; confirm against callers).
        close: If true (the default), close the figure once it has been
            saved so that repeated calls do not accumulate open figures in
            pyplot. Pass ``close=False`` to keep the figure open, e.g. to
            display it interactively afterwards.

    Raises:
        ValueError: If ``imgs`` is an empty list or tuple.
    """
    # Accept tuples as sequences too; anything else is treated as one image.
    if isinstance(imgs, tuple):
        imgs = list(imgs)
    elif not isinstance(imgs, list):
        imgs = [imgs]
    if not imgs:
        # plt.subplots(ncols=0) would fail with a much less helpful message.
        raise ValueError("showimage() requires at least one image")

    # squeeze=False keeps ``axs`` two-dimensional even for a single image.
    fig, axs = plt.subplots(ncols=len(imgs), squeeze=False)
    try:
        for ax, img in zip(axs[0], imgs):
            pil_img = F.to_pil_image(img.detach())
            ax.imshow(np.asarray(pil_img))
            ax.set(xticklabels=[], yticklabels=[], xticks=[], yticks=[])
        # Save this specific figure rather than whatever pyplot considers
        # the "current" one.
        fig.savefig(os.path.join(data_dir, "batch_vis.png"))
    finally:
        # pyplot keeps a reference to every figure until it is closed
        # explicitly, so release it even when drawing or saving failed.
        if close:
            plt.close(fig)