import torchvision
from PIL import Image

class UnifiedMNIST(torchvision.datasets.MNIST):
  """MNIST dataset that keeps its samples as NumPy arrays.

  Behaves like ``torchvision.datasets.MNIST`` except that ``data`` and
  ``targets`` are converted with ``.numpy()`` once at construction time, and
  ``__getitem__`` hands a PIL image to ``transform`` so that this dataset is
  consistent with the other datasets.
  """

  def __init__(self, root, **kwargs):
    """Load MNIST through the torchvision base class, then convert to NumPy.

    Args:
      root: Dataset root directory, forwarded unchanged to
        ``torchvision.datasets.MNIST``.
      **kwargs: Any other keyword argument accepted by
        ``torchvision.datasets.MNIST`` (e.g. ``train``, ``transform``,
        ``target_transform``, ``download``), forwarded unchanged.
    """
    super().__init__(root, **kwargs)
    # Convert the base class's torch tensors once so __getitem__ indexes
    # plain NumPy arrays.
    self.data = self.data.numpy()
    self.targets = self.targets.numpy()

  def __getitem__(self, index):
    """Return the ``(image, target)`` pair at ``index``.

    The raw sample is wrapped in a PIL image before ``transform`` runs, and
    the label is converted to a Python ``int`` before ``target_transform``
    runs.

    Args:
      index: Position of the sample to fetch.

    Returns:
      Tuple ``(img, target)`` after the optional transforms have been applied.
    """
    img, target = self.data[index], int(self.targets[index])

    # Return a PIL Image so this dataset is consistent with all the others.
    # Assumes each sample is a 2-D uint8 array, as torchvision's MNIST loader
    # provides; Pillow then infers mode 'L' (8-bit grayscale) by itself. The
    # explicit ``mode=`` argument is deprecated since Pillow 11.3.
    img = Image.fromarray(img)

    if self.transform is not None:
      img = self.transform(img)

    if self.target_transform is not None:
      target = self.target_transform(target)

    return img, target