import os
import torch
from torchvision.datasets import MNIST

class FEMNIST(MNIST):
    """FEMNIST split loaded from pre-processed ``.pt`` files.

    Files are read from ``<root>/processed/``:

    - ``server.pt``: 620 samples
    - ``training.pt``: 31272 samples, keeping only users with more than
      260 samples
    - ``test.pt``: 3524 samples

    Each file is loaded with ``torch.load`` and unpacked into
    ``(data, targets)``. ``data`` is scaled by 255 and stored as ``uint8``.
    ``targets`` is cast to ``int64``.

    ``MNIST.__init__`` is intentionally not called, so no download or
    integrity check happens. The attributes torchvision would normally set
    are assigned directly instead.
    """

    def __init__(self, root, train=True, server=False):
        """Load one FEMNIST split.

        Args:
            root: Dataset root directory; the split file is read from
                ``os.path.join(root, 'processed', <file>)``.
            train: Load ``training.pt`` when True, ``test.pt`` when False.
                Ignored when ``server`` is True.
            server: Load ``server.pt``; takes precedence over ``train``.
        """
        self.data_name = 'FEMNIST'
        self.num_classes = 62
        self.classes = list(range(self.num_classes))
        self.training_file = 'training.pt'
        self.test_file = 'test.pt'
        # MNIST.__init__ is skipped, so set the attributes torchvision's
        # __repr__/extra_repr read; without them repr(self) raises
        # AttributeError.
        self.root = root
        self.train = train

        if server:
            data_file = 'server.pt'
        elif train:
            data_file = self.training_file
        else:
            data_file = self.test_file

        # NOTE(review): the .astype calls below imply the file holds NumPy
        # arrays. Newer torch releases default torch.load to
        # weights_only=True, which may refuse to unpickle them -- confirm
        # against the installed torch version.
        self.data, self.targets = torch.load(
            os.path.join(root, 'processed', data_file))
        # Round before casting. After float scaling, a value meant to be k
        # can come out as k - epsilon, and a plain astype('uint8')
        # truncates it to k - 1.
        self.data = (self.data * 255).round().astype('uint8')
        self.targets = self.targets.astype('int64')


