from typing import Callable, Optional
import os
from glob import glob

from torchvision.datasets import ImageFolder
from torchvision.datasets.utils import download_url
import torchvision.transforms as transforms


class ImageNet(ImageFolder):
    """ImageNet-1k classification dataset built on torchvision's ``ImageFolder``.

    Expects this layout under ``<root>/imagenet1k``::

        words.txt             # "<synset id>\\t<comma-separated names>" per line
        train/<synset id>/... # training images, one folder per class
        val/<synset id>/...   # validation images, one folder per class

    Every sample is resized to ``image_size`` (224x224 by default) *after*
    the transform has run.
    """

    def __init__(self,
                 root: str,
                 train: bool,
                 transform: Optional[Callable] = None,
                 target_transform: Optional[Callable] = None,
                 download: bool = False,
                 image_size=(224, 224)) -> None:
        """Index one split of the dataset.

        Args:
            root: Directory containing the ``imagenet1k`` folder; ``~`` is
                expanded.
            train: Load the ``train`` split if True, otherwise ``val``.
            transform: Callable applied to each loaded image. Defaults to
                ``transforms.ToTensor()`` (no normalization). If it is a
                ``transforms.Compose``, a *copy* with ImageNet
                ``Normalize(mean, std)`` appended is used; the caller's
                object is left unmodified.
            target_transform: Callable applied to the class index.
            download: Accepted for interface compatibility only; it is
                ignored and nothing is downloaded.
            image_size: Size passed to ``transforms.Resize`` for every
                sample (an int or an ``(h, w)`` pair).

        Raises:
            OSError: If ``words.txt`` cannot be opened.
            KeyError: If a class folder name is absent from ``words.txt``.
        """
        dataset_dir = os.path.join(os.path.expanduser(root), 'imagenet1k')
        self.root = dataset_dir

        # ImageNet channel statistics used for normalization.
        self.mean, self.std = (0.485, 0.456, 0.406), (0.229, 0.224, 0.225)

        words = self._read_words(os.path.join(dataset_dir, 'words.txt'))

        self.path = os.path.join(dataset_dir, 'train' if train else 'val')

        if transform is None:
            transform = transforms.ToTensor()
        elif isinstance(transform, transforms.Compose):
            # Build a new Compose rather than appending in place: mutating the
            # caller's object would add a second Normalize every time the same
            # Compose is reused (e.g. for both train and val datasets).
            transform = transforms.Compose(
                [*transform.transforms,
                 transforms.Normalize(self.mean, self.std)])

        super().__init__(self.path,
                         transform=transform,
                         target_transform=target_transform)
        # NOTE(review): torchvision's VisionDataset.__init__ assigns
        # self.root = self.path here, overwriting the value set above —
        # confirm callers expect the split directory in self.root.

        # Human-readable name per class, in the same order as self.classes.
        self.classes_names = [words[synset] for synset in self.classes]

        # Built once here instead of on every __getitem__ call.
        self.resize = transforms.Resize(image_size, antialias=True)

    @staticmethod
    def _read_words(path: str) -> dict:
        """Map each synset id in ``words.txt`` to the first of its names."""
        words = {}
        with open(path, 'r', encoding='utf-8') as f:
            for line in f:
                fields = line.strip().split('\t')
                if len(fields) < 2:
                    # Skip blank / malformed lines instead of raising IndexError.
                    continue
                words[fields[0]] = fields[1].split(',')[0]
        return words

    def __getitem__(self, index):
        """Return ``(sample, target)`` with ``sample`` resized to ``image_size``.

        The resize is applied after the transform (including any appended
        ``Normalize``).
        """
        sample, target = super().__getitem__(index)
        return self.resize(sample), target


if __name__ == "__main__":
    # Smoke test: walk the training split and print each sample's shape and label.
    train_set = ImageNet('data/', train=True)
    for image, label in train_set:
        print(image.shape, label)
