import torch
from torchvision.datasets import ImageNet
from torchvision.transforms import Compose, ToTensor, Normalize, CenterCrop, Resize



# Per-channel RGB statistics passed to Normalize below (the widely used
# ImageNet mean/std values). Tuples so they cannot be mutated by accident.
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)


def imagenet_dataset(split, *, root='../imagenet_dataset/images',
                     resize_size=256, crop_size=256):
    """Build a torchvision ``ImageNet`` dataset with a fixed eval-style pipeline.

    Each sample goes through, in order:
    ``Resize(resize_size)`` -> ``CenterCrop(crop_size)`` -> ``ToTensor()`` ->
    ``Normalize(IMAGENET_MEAN, IMAGENET_STD)``.

    Args:
        split: Dataset split name, forwarded unchanged to
            ``torchvision.datasets.ImageNet`` (which validates it).
        root: Directory passed to ``ImageNet`` as its root. The default is a
            relative path, so it resolves against the current working
            directory, not this file's location.
        resize_size: Size argument for ``Resize``.
        crop_size: Size argument for ``CenterCrop``.

    Returns:
        A ``torchvision.datasets.ImageNet`` instance with the transform
        pipeline above attached.
    """
    # NOTE(review): the common ImageNet eval recipe is resize 256 -> crop 224;
    # the defaults keep the original 256 -> 256 behavior. Confirm intended.
    transform = Compose([
        Resize(resize_size),
        CenterCrop(crop_size),
        ToTensor(),
        Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
    ])
    return ImageNet(root, split=split, transform=transform)
