from torch.utils.data import DataLoader
from torchvision import transforms
from stl10_embeddings.stl10_dataset import STL10Embedding


def show(data_loader, max_print=2):
    """Print the first batches produced by *data_loader*.

    Each batch is unpacked as an ``(image, target, embedding)`` triple and
    printed together with the shape of ``embedding``.

    Args:
        data_loader: Iterable of ``(image, target, embedding)`` batches,
            e.g. a ``DataLoader`` over ``STL10Embedding``; ``embedding``
            must expose a ``.shape`` attribute.
        max_print: Maximum number of batches to print. Values <= 0 print
            nothing; ``None`` prints every batch.
    """
    from itertools import islice

    # islice stops before requesting batch number ``max_print``, so exactly
    # ``max_print`` batches are shown and no extra batch is loaded only to be
    # discarded. islice rejects negative stops, hence the clamp to 0.
    limit = None if max_print is None else max(max_print, 0)
    for batch_nr, (image, target, embedding) in enumerate(
            islice(data_loader, limit)):
        print(f"batch_nr: {batch_nr}, "
              f"image: {image}, "
              f"target: {target}, "
              f"embedding: {embedding}, embedding shape: {embedding.shape}")


def main(batch_size=2, use_cuda=False, num_workers=1,
         path="/", embeddings_dir='', max_print=2):
    """Load the STL10 train/test splits with embeddings and print a few batches.

    Args:
        batch_size: Batch size used for both the train and test loaders.
        use_cuda: When true, both loaders additionally get ``num_workers``,
            ``pin_memory=True`` and ``shuffle=True``.
        num_workers: Loader worker count; only applied when ``use_cuda`` is
            true.
        path: Dataset root passed to ``STL10Embedding`` (with
            ``download=True``). NOTE(review): the default ``"/"`` is the
            filesystem root, which is presumably not writable for a download —
            pass an explicit directory.
        embeddings_dir: Value passed to ``STL10Embedding`` as
            ``embeddings_dir`` for both splits.
        max_print: Number of batches to print per loader, forwarded to
            ``show``.
    """
    transform = transforms.Compose([transforms.ToTensor()])
    train_set = STL10Embedding(root=path, embeddings_dir=embeddings_dir,
                               split='train', transform=transform,
                               download=True)
    test_set = STL10Embedding(root=path, embeddings_dir=embeddings_dir,
                              split='test', transform=transform, download=True)

    # Both loaders are configured identically; the dict is only unpacked,
    # never mutated by DataLoader, so sharing it is safe.
    loader_kwargs = {'batch_size': batch_size}
    if use_cuda:
        loader_kwargs.update(num_workers=num_workers,
                             pin_memory=True,
                             shuffle=True)
    train_loader = DataLoader(train_set, **loader_kwargs)
    test_loader = DataLoader(test_set, **loader_kwargs)

    show(train_loader, max_print=max_print)
    show(test_loader, max_print=max_print)


if __name__ == "__main__":
    import argparse

    # Expose main()'s parameters on the command line. The defaults mirror
    # main()'s own defaults, so running the script without arguments behaves
    # exactly as before.
    parser = argparse.ArgumentParser(
        description="Print a few STL10 batches together with their embeddings.")
    parser.add_argument("--path", default="/",
                        help="dataset root directory passed to STL10Embedding")
    parser.add_argument("--batch-size", type=int, default=2)
    parser.add_argument("--num-workers", type=int, default=1)
    parser.add_argument("--use-cuda", action="store_true")
    args = parser.parse_args()
    main(batch_size=args.batch_size, use_cuda=args.use_cuda,
         num_workers=args.num_workers, path=args.path)
