from torchvision.datasets import ImageNet
import cv2


class SubImagenet(object):
  """Lazy, index-addressable view over one field of a ``samples`` sequence.

  ``samples`` is a sequence of tuples, e.g. the ``(path, class_index)``
  pairs of a torchvision ``DatasetFolder``. With ``offset == 0`` indexing
  treats ``samples[index][0]`` as an image path and decodes it from disk
  with OpenCV; any other offset returns ``samples[index][offset]`` as-is.

  Args:
    samples: indexable sequence of tuples (not copied; shared by reference).
    offset: which tuple field this view exposes.
  """

  def __init__(self, samples, offset):
    self.samples = samples
    self.offset = offset

  def __len__(self):
    # Makes ``len(dataset.data)`` / ``len(dataset.targets)`` work like the
    # plain list attributes of other datasets.
    return len(self.samples)

  def __getitem__(self, index):
    """Return field ``offset`` of sample ``index``; image paths are decoded.

    Raises:
      IndexError: if ``index`` is out of range (propagated from ``samples``).
      OSError: if OpenCV cannot read the image file (``offset == 0`` only).
    """
    field = self.samples[index][self.offset]
    if self.offset != 0:
      return field
    # cv2.imread yields an array in BGR channel order (not PIL's RGB), and
    # returns None instead of raising when the file can't be read/decoded.
    image = cv2.imread(field, cv2.IMREAD_COLOR)
    if image is None:
      raise OSError('cv2.imread failed to read image: %r' % (field,))
    return image


class UnifiedImageNet(ImageNet):
  """torchvision ``ImageNet`` with lazy ``data`` / ``targets`` index views.

  ``data[i]`` decodes the image of sample ``i`` with OpenCV and
  ``targets[i]`` is its class index, both read from ``self.samples``. This
  appears intended to mirror the ``data``/``targets`` attributes of datasets
  such as CIFAR so generic code can index either one.

  Args:
    root: dataset root directory, forwarded to ``ImageNet``.
    split: dataset split (e.g. ``'train'`` or ``'val'``), forwarded.
    download: forwarded to ``ImageNet`` only when truthy (see note below).
    **kwargs: forwarded to ``ImageNet`` (e.g. ``transform``).
  """

  def __init__(self, root, split='train', download=False, **kwargs):
    # Newer torchvision releases deprecated and then removed ImageNet's
    # ``download`` parameter (archives must be fetched manually), so passing
    # it as a third positional argument breaks there. Forward it by keyword,
    # and only when requested; omitting it equals the old default of False.
    if download:
      kwargs['download'] = download
    super(UnifiedImageNet, self).__init__(root, split=split, **kwargs)
    # Both views share self.samples, so they stay aligned index-for-index.
    self.data = SubImagenet(self.samples, 0)
    self.targets = SubImagenet(self.samples, 1)
