Module data_preprocess.data_loader

Expand source code
#!/usr/bin/env python3

# Import PyTorch root package
import torch

import torchvision

from torchvision import transforms
from .fl_datasets import FEMNIST, FLCifar100, FLCifar10, FLCifar10ByClass, Shakespeare, SHAKESPEARE_EVAL_BATCH_SIZE
from .artificial_dataset import ArificialDataset
from .libsvm_dataset import LibSVMDataset

CIFAR_NORMALIZATION = ((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))


def get_torch_version() -> int:
    """
    Get PyTorch library version
    """
    return int(torch.__version__.split('+')[0].replace('.', ''))


def load_data(exec_ctx, path, dataset, args, load_trainset=True, download=True, client_id=None):
    """
    Load dataset.

    Args:
      exec_ctx: Execution context that maybe required for pseudo random generations
      path: path to dataset
      args: command line arguments
      load_trainset: Load train dataset or test dataset.
      download: If dataset is not presented in filesystem download it from the web
      client_id: Specified id of the client on bhalf of which dataset will be used

    Returns:
      None
    """
    dataset = dataset.lower()
    trainset = None

    if (client_id is not None and client_id < 0) or dataset in ['emnist', 'full_shakespeare']:
        client_id = None

    if dataset.startswith("cifar"):  # CIFAR-10/100

        transform_train = transforms.Compose([
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize(*CIFAR_NORMALIZATION),
        ])

        transform_test = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(*CIFAR_NORMALIZATION),
        ])

        if dataset == "cifar10":
            if load_trainset:
                trainset = torchvision.datasets.CIFAR10(root=path, train=True, download=download,
                                                        transform=transform_train)
                trainset.num_clients = get_num_clients(dataset)

            testset = torchvision.datasets.CIFAR10(root=path, train=False, download=download, transform=transform_test)

        elif dataset == "cifar10_fl":
            if load_trainset:
                trainset = FLCifar10(exec_ctx, args, root=path, train=True,
                                     download=download, transform=transform_train, client_id=client_id)
            # testset = tv.datasets.CIFAR10(root=path, train=False, download=download, transform=transform_test)
            testset = FLCifar10(exec_ctx, args, root=path, train=False,
                                download=download, transform=transform_test, client_id=client_id)

        elif dataset == "cifar10_fl_by_class":
            if load_trainset:
                trainset = FLCifar10ByClass(exec_ctx, args, root=path, train=True,
                                            download=download, transform=transform_train, client_id=client_id)
            # testset = tv.datasets.CIFAR10(root=path, train=False, download=download, transform=transform_test)
            testset = FLCifar10ByClass(exec_ctx, args, root=path, train=False,
                                       download=download, transform=transform_test, client_id=client_id)

        elif dataset == "cifar100":
            if load_trainset:
                trainset = torchvision.datasets.CIFAR100(root=path, train=True,
                                                         download=download, transform=transform_train)
                trainset.num_clients = get_num_clients(dataset)

            testset = torchvision.datasets.CIFAR100(root=path, train=False,
                                                    download=download, transform=transform_test)
        elif dataset == "cifar100_fl":
            if load_trainset:
                trainset = FLCifar100(path, train=True, transform=transform_train, client_id=client_id)
                trainset.num_clients = get_num_clients(dataset)
            testset = FLCifar100(path, train=False, transform=transform_test)
        else:
            raise NotImplementedError(f'{dataset} is not implemented.')

    elif dataset in ["femnist", 'emnist']:
        if load_trainset:
            trainset = FEMNIST(path, train=True, client_id=client_id)

        testset = FEMNIST(path, train=False)

    elif dataset in ['shakespeare', 'full_shakespeare']:
        if load_trainset:
            trainset = Shakespeare(path, train=True, client_id=client_id)
        testset = Shakespeare(path, train=False)

    elif dataset in ['generated_for_quadratic_minimization']:
        trainset = ArificialDataset(exec_ctx, args, train=True)
        testset = ArificialDataset(exec_ctx, args, train=False)

        trainset.compute_Li_for_linear_regression()
        testset.compute_Li_for_linear_regression()

    elif dataset in LibSVMDataset.allowableDatasets():
        transform_train = None  # transforms.Compose([transforms.ToTensor()])
        transform_test = None  # transforms.Compose([transforms.ToTensor()])

        trainset = LibSVMDataset(exec_ctx, args,
                                 root=path, dataset=dataset, train=True, download=download,
                                 transform=transform_train, target_transform=None, client_id=client_id,
                                 num_clients=get_num_clients(dataset))

        testset = LibSVMDataset(exec_ctx, args,
                                root=path, dataset=dataset, train=False, download=download,
                                transform=transform_test, target_transform=None, client_id=client_id,
                                num_clients=get_num_clients(dataset))

        trainset.compute_Li_for_logregression()
        testset.compute_Li_for_logregression()

    else:
        raise NotImplementedError(f'{dataset} is not implemented.')

    return trainset, testset


def get_test_batch_size(dataset, batch_size):
    """
    Get the evaluation batch size for a dataset: Shakespeare uses a fixed
    evaluation batch size, all other datasets use the provided batch_size.
    """
    dataset = dataset.lower()
    if dataset == 'shakespeare':
        return SHAKESPEARE_EVAL_BATCH_SIZE
    return batch_size


def evalute_num_classes(dataset):
    """
   Helper function for evaluate number of classes for classification via traversing all dataset samples

   Args:
       dataset: dataset object that supports __len __ and __getitem__ routines. __getitem__ should return (in., target)

   Returns:
       Number of classes in dataset
   """
    max_class = 0
    min_class = 0

    samples = len(dataset)
    for sample_idx in range(samples):
        input_sample, target = dataset[sample_idx]
        max_class = max(target, max_class)
        min_class = min(target, min_class)

    number_of_classes_in_dataset = max_class - min_class + 1
    return number_of_classes_in_dataset


def get_num_classes(dataset):
    """ Helper function for get number of classes in a well-known datasets

   Args:
       dataset(str): name of dataset

   Returns:
       Number of classes in dataset
   """
    dataset = dataset.lower()
    if dataset in ['cifar10', 'cifar10_fl', 'cifar10_fl_by_class']:
        num_classes = 10
    elif dataset in ['cifar100', 'cifar100_fl']:
        num_classes = 100
    elif dataset in ['femnist', 'emnist']:
        num_classes = 62
    elif dataset == 'fashion-mnist':
        num_classes = 10
    elif dataset in ['shakespeare', 'full_shakespeare']:
        num_classes = 90
    elif dataset in ['w9a', 'w8a', 'w7a', 'w6a', 'w5a', 'w4a', 'w3a', 'w2a', 'w1a']:
        num_classes = 2
    elif dataset in ['a9a', 'a8a', 'a7a', 'a6a', 'a5a', 'a4a', 'a3a', 'a2a', 'a1a']:
        num_classes = 2
    elif dataset in ['mushrooms', 'phishing']:
        num_classes = 2
    else:
        raise ValueError(f"Dataset {dataset} is not supported.")
    return num_classes


def get_num_clients(dataset):
    """
   Get number of clients for specific dataset.

   Args:
       dataset(str): name of dataset

   Returns:
       Number of clients presented in dataset
   """
    dataset = dataset.lower()
    if dataset in ['emnist', 'cifar10', 'cifar100', 'full_shakespeare']:
        num_clients = 1
    elif dataset == 'shakespeare':
        num_clients = 715
    elif dataset == 'femnist':
        num_clients = 3400
    elif dataset == 'cifar100_fl':
        num_clients = 500
    elif dataset == 'cifar10_fl':
        num_clients = 10
    elif dataset == 'cifar10_fl_by_class':
        num_clients = 10
    elif dataset in ['w9a', 'w8a', 'w7a', 'w6a', 'w5a', 'w4a', 'w3a', 'w2a', 'w1a']:
        num_clients = 10
    elif dataset in ['a9a', 'a8a', 'a7a', 'a6a', 'a5a', 'a4a', 'a3a', 'a2a', 'a1a']:
        num_clients = 10
    elif dataset in ['mushrooms', 'phishing']:
        num_clients = 10
    else:
        raise ValueError(f"Dataset {dataset} is not supported.")

    return num_clients

Functions

def evalute_num_classes(dataset)

Helper function to evaluate the number of classes for classification by traversing all dataset samples.

Args

dataset
Dataset object that supports the __len__ and __getitem__ routines; __getitem__ should return an (input, target) pair

Returns

Number of classes in the dataset

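A minimal usage sketch (the in-memory list below is a hypothetical stand-in for any object with __len__ and __getitem__ returning (input, target) pairs):

toy_dataset = [([0.0, 1.0], 0), ([1.0, 0.0], 2), ([0.5, 0.5], 1)]  # hypothetical toy dataset
print(evalute_num_classes(toy_dataset))  # 3: targets span classes 0..2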
def get_num_classes(dataset)

Helper function to get the number of classes in well-known datasets.

Args

dataset(str): Name of the dataset

Returns

Number of classes in the dataset

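For example (lookups are case-insensitive):

print(get_num_classes('CIFAR100'))   # 100
print(get_num_classes('mushrooms'))  # 2 (binary LibSVM dataset)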
def get_num_clients(dataset)

Get the number of clients for a specific dataset.

Args

dataset(str): Name of the dataset

Returns

Number of clients present in the dataset

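For example, the centralized datasets report a single client while the federated splits report fixed client counts:

print(get_num_clients('cifar10'))      # 1 (centralized)
print(get_num_clients('femnist'))      # 3400
print(get_num_clients('cifar100_fl'))  # 500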
def get_test_batch_size(dataset, batch_size)

Get the evaluation batch size for a dataset: Shakespeare uses a fixed evaluation batch size, all other datasets use the provided batch_size.
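A quick check of the behavior (SHAKESPEARE_EVAL_BATCH_SIZE is imported from .fl_datasets):

print(get_test_batch_size('cifar10', 128))      # 128, passed through unchanged
print(get_test_batch_size('Shakespeare', 128))  # SHAKESPEARE_EVAL_BATCH_SIZE, regardless of the input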
def get_torch_version() -> int

Get the PyTorch library version as an integer with the dots removed (e.g. '1.9.0' -> 190); any local version suffix after '+' is ignored.

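For example:

version = get_torch_version()
print(version)  # e.g. 190 for torch 1.9.0, or 1131 for torch 1.13.1

Note that this encoding is only monotone while every version component is a single digit ('2.0.0' maps to 200, which compares less than 1131 for '1.13.1'), so treat cross-version comparisons with care.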
def load_data(exec_ctx, path, dataset, args, load_trainset=True, download=True, client_id=None)

Load a dataset by name.

Args

exec_ctx
Execution context that may be required for pseudo-random generation
path
Path to the dataset
dataset
Name of the dataset (case-insensitive)
args
Command line arguments
load_trainset
If True, load the train split in addition to the test split
download
If the dataset is not present in the filesystem, download it from the web
client_id
Id of the client on behalf of which the dataset will be used

Returns

Tuple (trainset, testset); trainset is None if it was not loaded

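A usage sketch, assuming the surrounding framework supplies an execution context and parsed command-line arguments (exec_ctx and args below are placeholders for those project objects, and './data' is an arbitrary local path):

trainset, testset = load_data(exec_ctx, path='./data', dataset='cifar10_fl', args=args,
                              load_trainset=True, download=True, client_id=3)
batch_size = get_test_batch_size('cifar10_fl', 128)
test_loader = torch.utils.data.DataLoader(testset, batch_size=batch_size, shuffle=False)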