import os
import torch
import torchvision.transforms as transforms
import torchvision.datasets as datasets


def build_imagenet_data(data_path: str = '', input_size: int = 224, batch_size: int = 64, workers: int = 4,
                        val_batch_size: int = 100, crop_pct: float = 0.875):
    """Build train/val DataLoaders for an ImageNet-style ImageFolder tree.

    The training images are read from ``os.path.join(data_path, 'train')`` and the
    validation images from ``os.path.join(data_path, 'val')``; both are loaded with
    ``torchvision.datasets.ImageFolder``.

    Args:
        data_path: Root directory containing the ``train`` and ``val`` folders.
        input_size: Side length of the square crops produced for both splits.
        batch_size: Batch size of the training loader.
        workers: Number of loader worker processes (applied to both loaders).
        val_batch_size: Batch size of the validation loader. Defaults to 100,
            the value this function always used before it was configurable.
        crop_pct: Fraction of the resized validation image kept by the center
            crop. The image is first resized to ``int(input_size / crop_pct)``
            and then center-cropped to ``input_size``. The default 0.875 gives
            the classic ``Resize(256)`` + ``CenterCrop(224)`` for ``input_size=224``,
            and keeps the crop inside the image for larger input sizes (a fixed
            256 resize would make ``CenterCrop`` pad with zeros when
            ``input_size > 256``).

    Returns:
        A ``(train_loader, val_loader)`` tuple. The training loader shuffles and
        applies random resized crops plus horizontal flips; the validation loader
        is unshuffled and uses a deterministic resize + center crop.

    Raises:
        ValueError: If ``crop_pct`` is not in the interval ``(0, 1]``.
    """
    if not 0.0 < crop_pct <= 1.0:
        raise ValueError(f'crop_pct must be in (0, 1], got {crop_pct!r}')

    print('==> Using Pytorch Dataset')

    # Standard ImageNet per-channel RGB statistics.
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
    traindir = os.path.join(data_path, 'train')
    valdir = os.path.join(data_path, 'val')

    train_dataset = datasets.ImageFolder(
        traindir,
        transforms.Compose([
            transforms.RandomResizedCrop(input_size),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ]))

    # Scale the pre-crop resize with input_size so the center crop keeps a
    # constant fraction of the image (224 / 0.875 == 256 exactly).
    val_resize = int(input_size / crop_pct)
    val_dataset = datasets.ImageFolder(
        valdir,
        transforms.Compose([
            transforms.Resize(val_resize),
            transforms.CenterCrop(input_size),
            transforms.ToTensor(),
            normalize,
        ]))

    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=batch_size, shuffle=True,
        num_workers=workers, pin_memory=True)
    val_loader = torch.utils.data.DataLoader(
        val_dataset, batch_size=val_batch_size, shuffle=False,
        num_workers=workers, pin_memory=True)
    return train_loader, val_loader
