Module data_preprocess.fl_datasets.femnist
Expand source code
import os
import numpy as np
import torchvision
from ..h5_tff_dataset import H5TFFDataset
class FEMNIST(H5TFFDataset):
    """
    Federated Extended MNIST (FEMNIST) dataset.

    Each client corresponds to a different person's handwriting.

    Args:
        h5_path (str): root directory that contains a ``femnist`` folder
            holding ``fed_emnist_train.h5`` and ``fed_emnist_test.h5``.
        train (bool): load the training split if True, otherwise the test
            split.
        client_id (int, optional): client whose data view is selected;
            forwarded unchanged to ``H5TFFDataset``.
    """

    def __init__(self, h5_path, train=True, client_id=None):
        file_name = 'fed_emnist_train.h5' if train else 'fed_emnist_test.h5'
        # Join path components separately rather than embedding '/' so the
        # platform's own separator is used.
        h5_path = os.path.join(h5_path, 'femnist', file_name)
        # NOTE(review): arguments are passed positionally; 'femnist' and
        # 'pixels' are presumably the dataset key and the input-feature key
        # -- confirm the order against H5TFFDataset.__init__.
        super().__init__(h5_path, client_id, 'femnist', 'pixels')
        self.transform = torchvision.transforms.ToTensor()

    def __getitem__(self, index):
        """
        Fetch one sample on behalf of the currently selected client.

        Args:
            index (int): index of the item within the current client view.

        Returns:
            tuple: (image, target) where image is ``1 - ToTensor(pixels)``
            and target is the class index as ``np.int64``.
        """
        # Presumably maps the dataset-level index to (client key, index
        # within that client) -- defined in H5TFFDataset; verify there.
        client, i = self._get_item_preprocess(index)
        # Inverts pixel values (1 - v); presumably so that strokes end up
        # bright on a dark background -- confirm against the stored data.
        x = 1 - self.transform(self.dataset[client]['pixels'][i])
        y = np.int64(self.dataset[client]['label'][i])
        return x, y
Classes
class FEMNIST (h5_path, train=True, client_id=None)
-
Federated Extended MNIST dataset. Each client corresponds to a different person's handwriting.
Ctor.
Args
h5_path
:str
- path to the HDF5 file with the dataset (a format not native to systems like TensorFlow)
client_id
:int
- switch the dataset to the view of client `client_id`
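data_key (str): if h5_path does not exist in the filesystem and download is True, an attempt is made to download the dataset from the TFF_DATASETS[data_key] URL. download (bool): allow downloading the dataset. (Neither parameter is exposed by the FEMNIST constructor signature shown above.)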
Expand source code
class FEMNIST(H5TFFDataset):
    """
    Federated Extended MNIST (FEMNIST) dataset.

    Each client corresponds to a different person's handwriting.

    Args:
        h5_path (str): root directory that contains a ``femnist`` folder
            holding ``fed_emnist_train.h5`` and ``fed_emnist_test.h5``.
        train (bool): load the training split if True, otherwise the test
            split.
        client_id (int, optional): client whose data view is selected;
            forwarded unchanged to ``H5TFFDataset``.
    """

    def __init__(self, h5_path, train=True, client_id=None):
        file_name = 'fed_emnist_train.h5' if train else 'fed_emnist_test.h5'
        # Join path components separately rather than embedding '/' so the
        # platform's own separator is used.
        h5_path = os.path.join(h5_path, 'femnist', file_name)
        # NOTE(review): arguments are passed positionally; 'femnist' and
        # 'pixels' are presumably the dataset key and the input-feature key
        # -- confirm the order against H5TFFDataset.__init__.
        super().__init__(h5_path, client_id, 'femnist', 'pixels')
        self.transform = torchvision.transforms.ToTensor()

    def __getitem__(self, index):
        """
        Fetch one sample on behalf of the currently selected client.

        Args:
            index (int): index of the item within the current client view.

        Returns:
            tuple: (image, target) where image is ``1 - ToTensor(pixels)``
            and target is the class index as ``np.int64``.
        """
        # Presumably maps the dataset-level index to (client key, index
        # within that client) -- defined in H5TFFDataset; verify there.
        client, i = self._get_item_preprocess(index)
        # Inverts pixel values (1 - v); presumably so that strokes end up
        # bright on a dark background -- confirm against the stored data.
        x = 1 - self.transform(self.dataset[client]['pixels'][i])
        y = np.int64(self.dataset[client]['label'][i])
        return x, y
Ancestors
- data_preprocess.h5_tff_dataset.H5TFFDataset
- data_preprocess.fl_dataset.FLDataset
- torch.utils.data.dataset.Dataset
- typing.Generic