'''Unified SVHN'''

import torchvision
from PIL import Image

class UnifiedSVHN(torchvision.datasets.SVHN):
  """SVHN adapted to the interface shared by this project's other datasets.

  Differences from ``torchvision.datasets.SVHN``:

  * ``data`` is transposed once at construction time with axes
    ``(0, 2, 3, 1)``, so ``__getitem__`` can hand each sample straight to
    ``PIL.Image.fromarray`` without a per-item transpose.
  * Labels are also exposed as ``targets``. This is a live alias of
    ``labels``: assigning to either name updates the other, and the change
    is seen by ``__getitem__``.
  """

  def __init__(self, root, *args, **kwargs):
    """Build the dataset.

    Args:
      root: dataset root directory, forwarded to the torchvision base class.
      *args, **kwargs: forwarded unchanged to ``torchvision.datasets.SVHN``
        (e.g. ``split``, ``transform``, ``target_transform``, ``download``).
    """
    super().__init__(root, *args, **kwargs)
    # Move the channel axis last so each image is H x W x C, the layout PIL
    # expects. Assumes torchvision stores data as (N, C, H, W) -- this is
    # what the (0, 2, 3, 1) permutation implies.
    self.data = self.data.transpose(0, 2, 3, 1)

  @property
  def targets(self):
    """Per-sample labels; always the same object as ``self.labels``."""
    return self.labels

  @targets.setter
  def targets(self, value):
    # Write through to ``labels`` so ``__getitem__`` never serves stale
    # labels after someone reassigns ``targets``.
    self.labels = value

  def __getitem__(self, index):
    """Return ``(img, target)`` for sample ``index``.

    ``img`` is a PIL Image, passed through ``self.transform`` if one is set.
    ``target`` is the label converted to a Python ``int``, passed through
    ``self.target_transform`` if one is set.
    """
    img, target = self.data[index], int(self.labels[index])

    # Return a PIL Image so the output type matches the other datasets,
    # which return PIL Images before any transform is applied.
    img = Image.fromarray(img)

    if self.transform is not None:
      img = self.transform(img)

    if self.target_transform is not None:
      target = self.target_transform(target)

    return img, target