import os

import torchvision
from matplotlib import pyplot


# Directory that contains this script; main() resolves the data folder
# relative to it.
DIR = os.path.split(__file__)[0]


def main():
    """Load MNIST and display the first sample of each split.

    The dataset lives under ``<script dir>/../../data``, which must already
    exist; torchvision downloads the files there if they are missing. For
    the training and test splits in turn, this prints the split size and the
    first image's ``size``, then shows that image in a matplotlib window
    titled with its label.

    Raises:
        NotADirectoryError: if the data directory does not exist.
    """
    root = os.path.join(DIR, "..", "..", "data")
    # Explicit check rather than ``assert``: assert statements are removed
    # when Python runs with -O, which would silently skip this validation.
    if not os.path.isdir(root):
        raise NotADirectoryError(f"MNIST data directory not found: {root}")

    train_data = torchvision.datasets.MNIST(root, train=True, download=True)
    test_data = torchvision.datasets.MNIST(root, train=False, download=True)

    splits = (("Training", train_data), ("Testing", test_data))
    for title, data in splits:
        print(f"{title} set size: {len(data)}")
        x0, y0 = data[0]
        # No transform is passed to MNIST, so x0 is presumably a PIL image
        # and .size is its (width, height) — confirm if a transform is added.
        print(x0.size)
        # MNIST digits are single-channel; without an explicit cmap,
        # matplotlib would colour them with its default (viridis) colormap.
        pyplot.imshow(x0, cmap="gray")
        pyplot.title(f"{title} sample 0 (label={y0})")
        pyplot.show()


if __name__ == "__main__":
    # Run the preview only when executed as a script, not when imported.
    main()