import numpy as np
import torch.distributed as dist
from torchvision.datasets import ImageNet as ImageNet_pytorch
from PIL import Image

from .base_dataset import BaseDataset
from utils.parallel import get_dist_info
from .utils import download_and_extract_archive, check_integrity

class ImageNet(BaseDataset):
    """ImageNet dataset backed by ``torchvision.datasets.ImageNet``.

    torchvision's dataset class is used only to index the files on disk
    (class names, image paths and integer targets); images themselves are
    decoded lazily in ``__getitem__``.
    """

    def load_annotations(self):
        """Index the dataset and build the per-sample annotation list.

        Side effects:
            Sets ``self.CLASSES``, ``self.imgs`` and ``self.gt_labels`` from
            the torchvision dataset.

        Returns:
            list[dict]: One dict per sample with keys ``'img'`` (first
            element of the torchvision ``imgs`` entry, i.e. the file path)
            and ``'gt_label'`` (0-d ``np.int64`` array).
        """
        # torchvision only accepts the strings 'train' / 'val' for ``split``.
        # A string ``test_mode`` is forwarded untouched (original behaviour);
        # any other value is treated as a truthy/falsy test flag.
        # NOTE(review): ``test_mode`` is defined outside this file -- confirm
        # whether it is a bool or a split name.
        test_mode = self.test_mode
        if isinstance(test_mode, str):
            split = test_mode
        else:
            split = 'val' if test_mode else 'train'

        # The torchvision dataset is used purely as an index of the files.
        bkp_dataset = ImageNet_pytorch(root=self.root, split=split)

        self.CLASSES = bkp_dataset.classes
        self.imgs = bkp_dataset.imgs
        self.gt_labels = bkp_dataset.targets

        return [
            {'img': sample[0], 'gt_label': np.array(target, dtype=np.int64)}
            for sample, target in zip(self.imgs, self.gt_labels)
        ]

    def __getitem__(self, idx):
        """Return the ``(img, label)`` pair stored at position ``idx``.

        When ``self.transforms`` is set, the image file is decoded, converted
        to RGB and passed through the transforms. When it is ``None`` the
        stored ``'img'`` value (the file path) is returned unchanged, which
        preserves the previous behaviour of this method.

        Args:
            idx (int): Index into ``self.data_infos``.

        Returns:
            tuple: ``(img, label)`` after the optional image and target
            transforms have been applied.
        """
        info = self.data_infos[idx]
        img, label = info['img'], info['gt_label']

        if self.transforms is not None:
            # ``Image.open`` is lazy; ``convert`` forces the pixel data to be
            # read while the file is still open, so the handle can be closed
            # deterministically instead of leaking. Converting to RGB also
            # normalises the grayscale / CMYK files present in ImageNet so
            # 3-channel transforms do not fail on them.
            with open(img, 'rb') as fh:
                img = Image.open(fh).convert('RGB')
            img = self.transforms(img)

        if self.target_transforms is not None:
            label = self.target_transforms(label)

        return img, label
