import torchvision
import torchvision.transforms as transforms

from data.utils import train_val_split
from configs import model_attributes


### CIFAR10 ###
def load_CIFAR10(args, train):
    """Load CIFAR-10 and return it as a list of dataset subsets.

    Args:
        args: Config namespace. Reads ``args.root_dir`` (passed to torchvision
            as the dataset root, with ``download=True``), ``args.model``
            (indirectly, through ``get_transform_CIFAR10``) and, only when
            ``train`` is true, ``args.val_fraction``.
        train: If true, load the training split and partition it with
            ``train_val_split``; otherwise load the test split unchanged.

    Returns:
        For the training split, whatever ``train_val_split`` returns for
        ``args.val_fraction``; for the test split, a one-element list
        containing the dataset.
    """
    dataset = torchvision.datasets.CIFAR10(
        args.root_dir,
        train,
        transform=get_transform_CIFAR10(args, train),
        download=True,
    )
    if not train:
        # The test split is not partitioned. Wrapping it in a list presumably
        # mirrors the sequence returned by train_val_split; verify against callers.
        return [dataset]
    return train_val_split(dataset, args.val_fraction)


def get_transform_CIFAR10(args, train):
    """Build the torchvision preprocessing pipeline for CIFAR-10 images.

    Args:
        args: Config namespace. ``args.model`` selects the entry in
            ``model_attributes``. That entry's ``'target_resolution'``
            decides whether a resize step is prepended.
        train: Currently unused. The same pipeline is built for training
            and evaluation, and no augmentation is applied.

    Returns:
        A ``transforms.Compose`` that does three things, in order:
        an optional ``Resize``, then ``ToTensor``, then a per-channel
        ``Normalize``.
    """
    resolution = model_attributes[args.model]['target_resolution']
    # A target resolution of None means no resize step is added.
    steps = [] if resolution is None else [transforms.Resize(resolution)]
    steps.append(transforms.ToTensor())
    # Per-channel (R, G, B) mean and std used for normalization.
    # NOTE(review): these std values are the widely copied CIFAR-10 constants.
    # The per-channel std over the training set is closer to
    # (0.247, 0.243, 0.262). Confirm which is intended before changing,
    # because existing checkpoints depend on the current values.
    steps.append(transforms.Normalize((0.4914, 0.4822, 0.4465),
                                      (0.2023, 0.1994, 0.2010)))
    return transforms.Compose(steps)
