import numpy as np
from matplotlib import pyplot as plt
from torchvision.datasets import CIFAR100


def display_multiple_img(images):
    """Show ``images`` side by side in one row of a 12x4 inch figure.

    Each image is drawn with ``Axes.imshow`` with its axis hidden, then
    ``plt.show()`` is called.

    Args:
        images: A sized sequence of images accepted by ``Axes.imshow``.
            An empty sequence is a no-op (nothing is drawn or shown).
    """
    if len(images) == 0:
        # plt.subplots rejects ncols=0, so there is nothing to draw.
        return
    # squeeze=False keeps ``axes`` a 2-D array even for a single image; with
    # the default squeeze=True a lone Axes object is returned and cannot be
    # indexed like an array.
    _, axes = plt.subplots(
        nrows=1, ncols=len(images), figsize=(12, 4), squeeze=False
    )
    for axis, image in zip(axes[0], images):
        axis.imshow(image)
        axis.set_axis_off()
    plt.tight_layout()
    plt.show()


def show_images(classes, number=8, root='./data', download=False):
    """Display a random sample of CIFAR-100 training images from ``classes``.

    Args:
        classes: Label id(s) to keep; passed to ``np.isin`` as the test
            elements, so an int or a sequence of ints.
        number: Maximum number of images to show. Fewer are shown when the
            selected classes contain fewer images.
        root: Dataset root directory passed to ``CIFAR100``.
        download: Forwarded to ``CIFAR100``'s ``download`` argument.

    Raises:
        ValueError: If ``number`` is not positive, or if no training image
            has a label in ``classes``.
    """
    if number <= 0:
        # Previously this failed deep inside matplotlib (ncols=0), or for a
        # negative value silently sliced off the *end* of the selection.
        raise ValueError(f'number must be positive, got {number}')
    ds = CIFAR100(root, train=True, download=download)
    valid = ds.data[np.isin(ds.targets, classes)]
    if len(valid) == 0:
        raise ValueError(f'no training images found for classes {classes!r}')
    # Boolean-mask indexing produced a copy, so shuffling it in place does
    # not reorder ds.data. Uses NumPy's global RNG (seedable via np.random.seed).
    np.random.shuffle(valid)
    display_multiple_img(valid[:number])


if __name__ == '__main__':
    # Demo run: sample training images whose CIFAR-100 label is 0 or 1.
    show_images(classes=[0, 1])
