import os

from torchvision import datasets
import torchvision.transforms as transforms
from torchvision.datasets import ImageFolder, ImageNet
from torch.utils.data import DataLoader

def get_dataset(batch_size, augmentation=False, download=False, n_workers=16,
                root="/globalscratch/ucl/elen/avander/imagenet/ILSVRC/Data/CLS-LOC"):
    """Build ImageNet train/val/test data loaders from a local folder layout.

    Args:
        batch_size: Batch size used by every returned ``DataLoader``.
        augmentation: If True, the training split uses random-resized-crop
            plus horizontal-flip augmentation; otherwise it uses the same
            deterministic resize/center-crop pipeline as validation.
        download: Must be False. ImageNet cannot be fetched automatically,
            so requesting a download is rejected.
        n_workers: Number of worker processes for each ``DataLoader``.
        root: Directory containing the ``train`` and ``val`` split folders.
            Defaults to the original cluster location so existing callers
            keep working unchanged.

    Returns:
        A ``(loaders, num_classes)`` tuple. ``loaders`` maps ``"train"``,
        ``"val"`` and ``"test"`` to ``DataLoader`` objects; ``"test"`` is the
        very same loader object as ``"val"``. ``num_classes`` is 1000.

    Raises:
        ValueError: If ``download`` is True.
    """
    if download:
        # Raise rather than calling the site-provided ``exit()``: a data helper
        # must not terminate the interpreter, and ``exit`` is not guaranteed to
        # exist (e.g. under ``python -S``).
        raise ValueError("ImageNet can not be downloaded")

    # Per-channel mean/std used by torchvision's ImageNet-pretrained models.
    normalize = transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
    )
    # Deterministic pipeline: always used for validation, and for training
    # when augmentation is disabled.
    simple_transform = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        normalize,
    ])
    if augmentation:
        aug_transform = transforms.Compose([
            transforms.RandomResizedCrop(224),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ])
    else:
        aug_transform = simple_transform

    trainset = ImageFolder(os.path.join(root, "train"), aug_transform)
    # NOTE(review): ImageFolder expects one subdirectory per class under each
    # split; confirm the "val" split on disk is arranged that way.
    valset = ImageFolder(os.path.join(root, "val"), simple_transform)

    train_loader = DataLoader(
        trainset,
        batch_size=batch_size, shuffle=True,
        num_workers=n_workers, pin_memory=True,
    )
    val_loader = DataLoader(
        valset,
        batch_size=batch_size, shuffle=False,
        num_workers=n_workers, pin_memory=True,
    )
    # No separate test split is loaded; the validation loader doubles as test.
    loaders = {
        "train": train_loader,
        "val": val_loader,
        "test": val_loader,
    }
    # Class count is hard-coded to 1000 rather than read from trainset.classes.
    return loaders, 1000

