'''ImageNet (2012) dataset loader for data stored as zip archives on Philly.

Reference: https://github.com/Hadisalman/smoothing-adversarial/blob/master/code/datasets.py'''

import os
from typing import Optional

from torch.utils.data import Dataset

from dataset.zipdata import ZipData

# Default root directory for imagenet_on_philly(). It is expected to contain
# the split archives (train.zip, val.zip) and their map files
# (train_map.txt, val_map.txt).
IMAGENET_ON_PHILLY_DIR = "/hdfs/public/imagenet/2012/"


def imagenet_on_philly(train=True, transform=None, *,
                       root: Optional[str] = None) -> Dataset:
  """Return the ImageNet train or validation split as a ``ZipData`` dataset.

  The split is read from ``<root>/<split>.zip`` together with its companion
  ``<root>/<split>_map.txt`` file, where ``<split>`` is ``train`` or ``val``.

  Args:
    train: If truthy, load the training split (``train.zip`` /
      ``train_map.txt``); otherwise load the validation split (``val.zip`` /
      ``val_map.txt``).
    transform: Forwarded unchanged to ``ZipData``.
    root: Directory containing the zip archives and map files. When omitted,
      ``IMAGENET_ON_PHILLY_DIR`` is used.

  Returns:
    The ``ZipData`` instance for the requested split.
  """
  if root is None:
    # Resolve at call time rather than as a default argument value, so a
    # caller that reassigns the module-level constant after import still
    # gets the new directory (matching the original behavior).
    root = IMAGENET_ON_PHILLY_DIR

  split = 'train' if train else 'val'
  zip_path = os.path.join(root, split + '.zip')
  map_path = os.path.join(root, split + '_map.txt')
  return ZipData(zip_path, map_path, transform)
