import torch
from torch.utils.data import Subset
import torchvision
import torchvision.transforms as transforms
from models import model_attributes
from data.utils import *


### CIFAR10 ###
def load_CIFAR10(args, train):
    """Load CIFAR-10 and return it as a collection of dataset subsets.

    Args:
        args: Configuration object. This function reads ``args.root_dir``
            (dataset root directory) and ``args.val_fraction``; ``args`` is
            also forwarded to ``get_transform_CIFAR10``.
        train: Truthy to load the training split, falsy for the test split.

    Returns:
        For the training split, whatever ``train_val_split`` returns for the
        dataset and ``args.val_fraction`` (presumably the train/validation
        subsets -- confirm against ``data.utils``). For the test split, a
        one-element list holding the full dataset.
    """
    # Build the transform first, then the dataset that applies it.
    preprocessing = get_transform_CIFAR10(args, train)
    cifar = torchvision.datasets.CIFAR10(
        args.root_dir,
        train=train,
        transform=preprocessing,
        # Always requested, so the data is fetched if not already on disk.
        download=True,
    )

    # The test split is not partitioned; wrap it in a list so callers get a
    # list in this branch.
    if not train:
        return [cifar]

    return train_val_split(cifar, args.val_fraction)


def get_transform_CIFAR10(args, train):
    """Compose the preprocessing pipeline applied to CIFAR-10 images.

    The pipeline is an optional resize, followed by conversion to a tensor
    and per-channel normalization.

    Args:
        args: Configuration object; ``args.model`` is used as the key into
            ``model_attributes`` to look up ``"target_resolution"``.
        train: Accepted for signature parity with the loader; not used by
            this function, so the same pipeline is built for both splits.

    Returns:
        A ``transforms.Compose`` wrapping the assembled steps.
    """
    resolution = model_attributes[args.model]["target_resolution"]

    # Only resize when the selected model declares a target resolution.
    resize_steps = [] if resolution is None else [transforms.Resize(resolution)]

    # Normalize takes (mean, std) per channel; these constants are presumably
    # CIFAR-10 channel statistics -- confirm before reusing elsewhere.
    channel_mean = (0.4914, 0.4822, 0.4465)
    channel_std = (0.2023, 0.1994, 0.2010)
    tensor_steps = [
        transforms.ToTensor(),
        transforms.Normalize(channel_mean, channel_std),
    ]

    return transforms.Compose(resize_steps + tensor_steps)
