import os
from util.crop import RandomResizedCrop
from timm.data.constants import IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
import torchvision.transforms as transforms
import torchvision.datasets as datasets
from .vtab import _DATASET_NAME, ImageFilelist, get_classes_num
def create_dataset(
        name,
        root,
        split='validation',
        search_split=True,
        class_map=None,
        load_bytes=False,
        is_training=False,
        download=False,
        batch_size=None,
        repeats=0,
        final=False,
        **kwargs
):
    """Build the train / val / test datasets from list files found in ``root``.

    Every dataset is an ``ImageFilelist`` created from a text list file located
    directly inside ``root``:

      * train: ``train800val200.txt`` when ``final`` is true, else ``train800.txt``
      * val:   ``val200.txt``
      * test:  ``test.txt``

    All three share the same preprocessing: resize to 224x224
    (``interpolation=3``), conversion to tensor, and normalization with the
    Inception mean / std constants. No augmentation is applied to the
    training set.

    Args:
        name: dataset name; passed to ``get_classes_num`` to obtain the class count.
        root: directory that contains the list files; also forwarded to
            ``ImageFilelist`` as its ``root``.
        final: when true, the training set is built from ``train800val200.txt``
            instead of ``train800.txt``.
        split, search_split, class_map, load_bytes, is_training, download,
        batch_size, repeats, **kwargs: accepted so existing call sites keep
            working; none of them is used by this function.

    Returns:
        Tuple ``(dataset_train, dataset_val, dataset_test, nb_classes)``.
    """
    _mean = IMAGENET_INCEPTION_MEAN
    _std = IMAGENET_INCEPTION_STD
    # Train and eval pipelines are intentionally kept as two separate Compose
    # objects (as in the original code) even though their steps are identical.
    transform_train = transforms.Compose([
            transforms.Resize((224, 224), interpolation=3),
            transforms.ToTensor(),
            transforms.Normalize(mean=_mean, std=_std)])
    transform_val = transforms.Compose([
            transforms.Resize((224, 224), interpolation=3),
            transforms.ToTensor(),
            transforms.Normalize(mean=_mean, std=_std)])

    # The original line read `os.path.join(cfg)` with `cfg` undefined, which
    # raised NameError on every call. A single-argument join only normalises its
    # argument, so `root` is used as given.
    # NOTE(review): if callers pass a parent directory and expect the dataset
    # name to be appended, this should become os.path.join(root, name) - confirm.
    root = os.fspath(root)

    train_list = "train800val200.txt" if final else "train800.txt"
    dataset_train = ImageFilelist(
        root=root, flist=os.path.join(root, train_list), transform=transform_train)
    dataset_test = ImageFilelist(
        root=root, flist=os.path.join(root, "test.txt"), transform=transform_val)
    dataset_val = ImageFilelist(
        root=root, flist=os.path.join(root, "val200.txt"), transform=transform_val)

    # The original read `args.dataset`, but no `args` exists in this scope; the
    # dataset name is available as the `name` parameter.
    nb_classes = get_classes_num(name)
    return dataset_train, dataset_val, dataset_test, nb_classes
