Module data_preprocess.fl_datasets.femnist

Expand source code
import os
import numpy as np
import torchvision

from ..h5_tff_dataset import H5TFFDataset


class FEMNIST(H5TFFDataset):
    """
    Federated Extended MNIST Dataset.

    Each client corresponds to a different person's handwriting.
    """

    def __init__(self, h5_path, train=True, client_id=None):
        """
        Args:
            h5_path (str): root directory that contains the ``femnist``
                sub-directory holding ``fed_emnist_train.h5`` and
                ``fed_emnist_test.h5`` (not the path of the HDF5 file itself).
            train (bool): if truthy, use the training split; otherwise the
                test split.
            client_id (int, optional): forwarded to ``H5TFFDataset``, which
                uses it to switch the dataset to that client's view.
        """
        file_name = 'fed_emnist_train.h5' if train else 'fed_emnist_test.h5'
        # Join component-wise so the platform's path separator is used
        # instead of a hard-coded '/'.
        h5_path = os.path.join(h5_path, 'femnist', file_name)
        super().__init__(h5_path, client_id, 'femnist', 'pixels')
        self.transform = torchvision.transforms.ToTensor()

    def __getitem__(self, index):
        """
        Fetch one sample on behalf of the currently selected client.

        Args:
            index (int): index of the item to fetch.

        Returns:
            tuple: (image, target), where image is the ``ToTensor``-converted
            pixel array transformed as ``1 - pixels`` and target is the
            class index as ``np.int64``.
        """
        client, i = self._get_item_preprocess(index)
        record = self.dataset[client]
        # Invert intensities (1 - value). Presumably this makes strokes bright
        # on a dark background, as in classic MNIST -- TODO confirm against
        # the TFF data.
        x = 1 - self.transform(record['pixels'][i])
        y = np.int64(record['label'][i])
        return x, y

Classes

class FEMNIST (h5_path, train=True, client_id=None)

Federated Extended MNIST Dataset. Each client corresponds to a different person's handwriting.

Ctor.

Args

h5_path : str
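path to the HDF5 file with the dataset (HDF5 is not a format native to systems like TensorFlow). For FEMNIST, this is instead the root directory that contains the femnist/ sub-directory; the train or test file is selected by the train argument.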
client_id : int
switch the dataset to the view of client client_id

data_key (str): if h5_path does not exist in the filesystem and download is True, an attempt is made to download the dataset from the TFF_DATASETS[data_key] URL. download (bool): whether downloading the dataset is allowed.

Expand source code
class FEMNIST(H5TFFDataset):
    """
    Federated Extended MNIST Dataset.

    Each client corresponds to a different person's handwriting.
    """

    def __init__(self, h5_path, train=True, client_id=None):
        """
        Args:
            h5_path (str): root directory that contains the ``femnist``
                sub-directory holding ``fed_emnist_train.h5`` and
                ``fed_emnist_test.h5`` (not the path of the HDF5 file itself).
            train (bool): if truthy, use the training split; otherwise the
                test split.
            client_id (int, optional): forwarded to ``H5TFFDataset``, which
                uses it to switch the dataset to that client's view.
        """
        file_name = 'fed_emnist_train.h5' if train else 'fed_emnist_test.h5'
        # Join component-wise so the platform's path separator is used
        # instead of a hard-coded '/'.
        h5_path = os.path.join(h5_path, 'femnist', file_name)
        super().__init__(h5_path, client_id, 'femnist', 'pixels')
        self.transform = torchvision.transforms.ToTensor()

    def __getitem__(self, index):
        """
        Fetch one sample on behalf of the currently selected client.

        Args:
            index (int): index of the item to fetch.

        Returns:
            tuple: (image, target), where image is the ``ToTensor``-converted
            pixel array transformed as ``1 - pixels`` and target is the
            class index as ``np.int64``.
        """
        client, i = self._get_item_preprocess(index)
        record = self.dataset[client]
        # Invert intensities (1 - value). Presumably this makes strokes bright
        # on a dark background, as in classic MNIST -- TODO confirm against
        # the TFF data.
        x = 1 - self.transform(record['pixels'][i])
        y = np.int64(record['label'][i])
        return x, y

Ancestors

  • data_preprocess.h5_tff_dataset.H5TFFDataset
  • data_preprocess.fl_dataset.FLDataset
  • torch.utils.data.dataset.Dataset
  • typing.Generic