from torchvision.datasets import MNIST
from torchvision import transforms
from .data_utils import *


def get_dataset(root: str = './data/MNIST', train: bool = True) -> MNIST:
    """Build the MNIST dataset with the preprocessing pipeline attached.

    Each image is resized with ``Resize(32)``, converted to a tensor, and
    normalized with mean 0.5 and std 0.5. The data is downloaded into
    ``root`` when it is not already present there.

    Args:
        root: Directory the dataset is stored in / downloaded to.
        train: Forwarded to ``MNIST``; ``True`` selects the training split.

    Returns:
        The configured ``torchvision.datasets.MNIST`` instance.
    """
    preprocess = transforms.Compose([
        transforms.Resize(32),
        transforms.ToTensor(),
        # Normalize is documented to take one value per channel as a
        # sequence; MNIST images have a single channel, so use 1-tuples
        # rather than relying on scalar broadcasting.
        transforms.Normalize(mean=(0.5,), std=(0.5,)),
    ])
    return MNIST(root, train=train, transform=preprocess, download=True)

def get_fl_dataset(args):
    """Load MNIST, materialize every sample, and pass them to ``process_data``.

    Args:
        args: Forwarded unchanged as the first argument of ``process_data``.

    Returns:
        Whatever ``process_data(args, samples)`` returns.
    """
    mnist = get_dataset()
    # Fetch each item by index so the whole dataset is held eagerly in a
    # plain Python list before it is handed over.
    samples = []
    for idx in range(len(mnist)):
        samples.append(mnist[idx])
    return process_data(args, samples)
    
# Manual entry point: builds the federated dataset once, passing 0 in place
# of a real ``args`` object.
# NOTE(review): this module uses a relative import, so it presumably has to
# be launched as part of its package (e.g. ``python -m pkg.module``) - confirm.
if __name__ == "__main__":
    get_fl_dataset(0)
