from cifar100_embeddings.cifar100_dataset import CIFAR100Embedding
from torch.utils.data import DataLoader
from torchvision import transforms


def show(data_loader, max_print=2):
    """Print the first ``max_print`` batches yielded by ``data_loader``.

    Args:
        data_loader: Iterable yielding ``(image, target, embedding)``
            triples; ``embedding`` must expose a ``shape`` attribute.
        max_print: Maximum number of batches to print. Zero or a negative
            value prints nothing.
    """
    # Imported locally so the function carries its own dependency.
    from itertools import islice

    if max_print <= 0:
        return

    # islice stops after exactly ``max_print`` batches; the previous
    # post-print ``batch_nr >= max_print`` check let one extra batch through.
    limited = islice(data_loader, max_print)
    for batch_nr, (image, target, embedding) in enumerate(limited):
        print(f"batch_nr: {batch_nr}, "
              f"image: {image}, "
              f"target: {target}, "
              f"embedding: {embedding}, embedding shape: {embedding.shape}")


def main(batch_size=2, use_cuda=False, num_workers=1,
         path=""):
    """Build train/test loaders for CIFAR100Embedding and print sample batches.

    Args:
        batch_size: Batch size passed to both data loaders.
        use_cuda: When true, both loaders additionally receive
            ``num_workers``, ``pin_memory=True`` and ``shuffle=True``.
        num_workers: Worker count for the loaders; only used when
            ``use_cuda`` is true.
        path: Value forwarded as ``root`` to ``CIFAR100Embedding``.
    """
    to_tensor = transforms.Compose([transforms.ToTensor()])

    # Train split first, then test split, both with the same transform.
    datasets = [
        CIFAR100Embedding(root=path, train=is_train, transform=to_tensor)
        for is_train in (True, False)
    ]

    # Both loaders take identical options, so one dict serves both.
    loader_options = {'batch_size': batch_size}
    if use_cuda:
        loader_options['num_workers'] = num_workers
        loader_options['pin_memory'] = True
        loader_options['shuffle'] = True

    loaders = [DataLoader(dataset, **loader_options) for dataset in datasets]

    for loader in loaders:
        show(loader)


# Run the demo with its default arguments only when executed as a script,
# so importing this module has no side effects.
if __name__ == "__main__":
    main()
