archai.datasets package

Submodules

archai.datasets.aug_policies module

archai.datasets.aug_policies.fa_reduced_cifar10()[source]
archai.datasets.aug_policies.fa_reduced_svhn()[source]
archai.datasets.aug_policies.fa_resnet50_rimagenet()[source]

archai.datasets.augmentation module

class archai.datasets.augmentation.Augmentation(policies)[source]

Bases: object

archai.datasets.augmentation.AutoContrast(img, _)[source]
archai.datasets.augmentation.Brightness(img, v)[source]
archai.datasets.augmentation.Color(img, v)[source]
archai.datasets.augmentation.Contrast(img, v)[source]
archai.datasets.augmentation.Cutout(img, v)[source]
archai.datasets.augmentation.CutoutAbs(img, v)[source]
archai.datasets.augmentation.Equalize(img, _)[source]
archai.datasets.augmentation.Flip(img, _)[source]
archai.datasets.augmentation.Invert(img, _)[source]
archai.datasets.augmentation.Posterize(img, v)[source]
archai.datasets.augmentation.Posterize2(img, v)[source]
archai.datasets.augmentation.Rotate(img, v)[source]
archai.datasets.augmentation.SamplePairing(imgs)[source]
archai.datasets.augmentation.Sharpness(img, v)[source]
archai.datasets.augmentation.ShearX(img, v)[source]
archai.datasets.augmentation.ShearY(img, v)[source]
archai.datasets.augmentation.Solarize(img, v)[source]
archai.datasets.augmentation.TranslateX(img, v)[source]
archai.datasets.augmentation.TranslateXAbs(img, v)[source]
archai.datasets.augmentation.TranslateY(img, v)[source]
archai.datasets.augmentation.TranslateYAbs(img, v)[source]
archai.datasets.augmentation.add_named_augs(transform_train, aug: Union[List, str], cutout: int)[source]
archai.datasets.augmentation.apply_augment(img, name, level)[source]
archai.datasets.augmentation.arsaug_policy()[source]
archai.datasets.augmentation.augment_list(for_autoaug=True)[source]
archai.datasets.augmentation.autoaug2arsaug(f)[source]
archai.datasets.augmentation.autoaug_paper_cifar10()[source]
archai.datasets.augmentation.autoaug_policy()[source]
archai.datasets.augmentation.float_parameter(level, maxval)[source]
archai.datasets.augmentation.get_augment(name)[source]
archai.datasets.augmentation.int_parameter(level, maxval)[source]
archai.datasets.augmentation.no_duplicates(f)[source]
archai.datasets.augmentation.policy_decoder(augment, num_policy, num_op)[source]
archai.datasets.augmentation.remove_deplicates(policies)[source]

archai.datasets.data module

class archai.datasets.data.SubsetSampler(indices)[source]

Bases: torch.utils.data.sampler.Sampler

Samples elements from a given list of indices, without replacement.

Arguments:

indices (sequence): a sequence of indices

archai.datasets.data.create_dataset_provider(conf_dataset: archai.common.config.Config) → archai.datasets.dataset_provider.DatasetProvider[source]
archai.datasets.data.get_data(conf_loader: archai.common.config.Config) → Tuple[Optional[torch.utils.data.dataloader.DataLoader], Optional[torch.utils.data.dataloader.DataLoader], Optional[torch.utils.data.dataloader.DataLoader]][source]
archai.datasets.data.get_dataloaders(ds_provider: archai.datasets.dataset_provider.DatasetProvider, load_train: bool, train_batch_size: int, load_test: bool, test_batch_size: int, aug, cutout: int, val_ratio: float, apex: archai.common.apex_utils.ApexUtils, val_fold=0, train_workers: Optional[int] = None, test_workers: Optional[int] = None, target_lb=-1, max_batches: int = -1) → Tuple[Optional[torch.utils.data.dataloader.DataLoader], Optional[torch.utils.data.dataloader.DataLoader], Optional[torch.utils.data.dataloader.DataLoader]][source]

archai.datasets.dataset_provider module

class archai.datasets.dataset_provider.DatasetProvider(conf_dataset: archai.common.config.Config)[source]

Bases: overrides.enforce.EnforceOverrides

abstract get_datasets(load_train: bool, load_test: bool, transform_train, transform_test) → Tuple[Optional[torch.utils.data.dataset.Dataset], Optional[torch.utils.data.dataset.Dataset]][source]
abstract get_transforms() → tuple[source]
archai.datasets.dataset_provider.get_provider_type(name: str) → overrides.enforce.EnforceOverridesMeta[source]
archai.datasets.dataset_provider.register_dataset_provider(name: str, class_type: overrides.enforce.EnforceOverridesMeta) → None[source]

archai.datasets.distributed_stratified_sampler module

class archai.datasets.distributed_stratified_sampler.DistributedStratifiedSampler(dataset: torch.utils.data.dataset.Dataset, world_size: Optional[int] = None, rank: Optional[int] = None, shuffle=True, val_ratio: Optional[float] = 0.0, is_val=False, max_items: Optional[int] = None)[source]

Bases: torch.utils.data.sampler.Sampler

set_epoch(epoch)[source]

archai.datasets.limit_dataset module

class archai.datasets.limit_dataset.LimitDataset(dataset, n)[source]

Bases: torch.utils.data.dataset.Dataset

Module contents