import os
from torchvision.datasets import CIFAR10, MNIST
from .distributed_dataset import distributed_dataset
import torchvision.transforms as transforms

from PIL import Image

# class CustomTransform(object):
#     def __init__(self):
#         pass
#     def __call__(self, x):
#         img = Image.fromarray(x, mode='L')

#         return img

def mnist(rank, batch_size=None,
          transform=None, sample_size=100, remove_index=0,
          is_distribute=True, seed=777, path="../data", node=2):
    """Load the MNIST training split, optionally wrapped by ``distributed_dataset``.

    Args:
        rank: Forwarded to ``distributed_dataset`` when ``is_distribute`` is
            true (presumably this node's index -- confirm against
            ``distributed_dataset``).
        batch_size: Currently unused; kept for interface compatibility.
        transform: Transform applied to each image. Defaults to
            ``Resize(32)`` (to match the dimension of CIFAR-10 images),
            ``ToTensor()`` and ``Normalize(mean=(0.1307,), std=(0.3081,))``.
        sample_size: Forwarded to ``distributed_dataset``.
        remove_index: Forwarded to ``distributed_dataset``.
        is_distribute: If true, wrap the dataset with ``distributed_dataset``;
            otherwise return the plain ``MNIST`` training set.
        seed: Forwarded to ``distributed_dataset``.
        path: Dataset root directory. It is created (including any missing
            parent directories) if absent, and MNIST is downloaded into it
            when needed.
        node: Forwarded to ``distributed_dataset``.

    Returns:
        The MNIST training set, or the ``distributed_dataset`` wrapping it.
    """
    if transform is None:
        transform = transforms.Compose([
            transforms.Resize(32),  # to match the dimension of CIFAR-10 images
            transforms.ToTensor(),
            transforms.Normalize(mean=(0.1307,), std=(0.3081,)),
        ])

    # makedirs(exist_ok=True) also creates missing parent directories and
    # avoids the check-then-create race of os.path.exists() + os.mkdir().
    os.makedirs(path, exist_ok=True)

    train_set = MNIST(root=path, train=True, download=True, transform=transform)
    if is_distribute:
        train_set = distributed_dataset(
            train_set, sample_size, rank,
            remove_index=remove_index, seed=seed, node=node)

    return train_set
