import torch
from torchvision.datasets import CIFAR10
from torchvision.transforms import Compose, ToTensor, Normalize, CenterCrop


def cifar10_dataset(
    split: str,
    *,
    # Default is resolved against the current working directory, not this file.
    root: str = '../cifar10_dataset',
    # Same constants as kuangliu/pytorch-cifar (main.py, commit 49b7aa9, L39):
    # https://github.com/kuangliu/pytorch-cifar/blob/49b7aa97b0c12fe0d4054e670403a16b6b834ddd/main.py#L39
    # NOTE(review): these std values are widely reported to be an average of
    # per-image stds rather than the per-pixel std over the whole training set
    # (often quoted as ~(0.2470, 0.2435, 0.2616)). They are kept as defaults so
    # existing results stay reproducible. Confirm before changing them.
    mean=(0.4914, 0.4822, 0.4465),
    std=(0.2023, 0.1994, 0.2010),
    download: bool = True,
) -> CIFAR10:
    """Return a CIFAR-10 split with tensor conversion and per-channel normalization.

    Each image is passed through ``ToTensor`` and then ``Normalize(mean, std)``.
    No augmentation or cropping is applied.

    Args:
        split: ``'train'`` selects the training set. Any other value selects
            the test set. ``'test'`` is the expected spelling; other values
            still load the test set, but emit a ``UserWarning`` because they
            are most likely typos.
        root: Directory the dataset is read from / downloaded into.
        mean: Per-channel means forwarded to ``Normalize``.
        std: Per-channel standard deviations forwarded to ``Normalize``.
        download: Forwarded to ``CIFAR10``. When true, the dataset is
            downloaded into ``root`` if it is not already there.

    Returns:
        A ``torchvision.datasets.CIFAR10`` instance with the transform attached.
    """
    import warnings

    train = split == 'train'
    if not train and split != 'test':
        # Historically every non-'train' value silently mapped to the test set.
        # Keep that for backward compatibility, but make typos visible.
        warnings.warn(
            f"cifar10_dataset: unrecognized split {split!r}; loading the test "
            "split (expected 'train' or 'test').",
            stacklevel=2,
        )

    transform = Compose([
        ToTensor(),
        Normalize(mean=mean, std=std),
    ])
    return CIFAR10(root=root, train=train, transform=transform, download=download)
