import matplotlib.pyplot as plt
import numpy as np


def plot_grid(dataset, grid, shuffle=False, cmap=None, scale=1.2):
    """Display samples from ``dataset`` as a grid of images and show the figure.

    Args:
        dataset: Indexable dataset where ``dataset[i]`` returns a pair whose
            first element is passed to ``imshow`` (the second element is
            ignored). It must also expose a ``data`` attribute whose
            ``shape[0]`` is used as the number of samples.
        grid: ``(num_rows, num_cols)`` layout of the grid.
        shuffle: If True, samples are taken in a random order (drawn from
            NumPy's global random state via ``np.random.shuffle``); otherwise
            the first ``num_rows * num_cols`` samples are shown in order.
        cmap: Colormap forwarded to ``Axes.imshow``; ``None`` uses
            matplotlib's default.
        scale: Width and height allotted to each grid cell in the figure
            size (``figsize`` is ``(scale * num_cols, scale * num_rows)``).

    Cells beyond the number of available samples are left blank rather
    than raising ``IndexError``.
    """
    num_rows, num_cols = grid
    figsize = (scale * num_cols, scale * num_rows)
    # squeeze=False keeps `ax` a 2-D array even for a single row, a single
    # column, or a 1x1 grid, so ax[i_row, i_col] indexing always works.
    fig, ax = plt.subplots(num_rows, num_cols, figsize=figsize, squeeze=False)

    num_data = dataset.data.shape[0]
    indices = np.arange(num_data)

    if shuffle:
        np.random.shuffle(indices)

    for i_row in range(num_rows):
        for i_col in range(num_cols):
            cell = ax[i_row, i_col]
            index = i_row * num_cols + i_col
            if index < num_data:
                image, _ = dataset[indices[index]]
                cell.imshow(image, cmap=cmap)
            # Hide ticks and frame; cells past the end of the dataset stay blank.
            cell.axis('off')

    fig.tight_layout()

    plt.show()
