
import collections
from typing import Any, Callable, Optional

import torch
from torchvision.datasets.folder import DatasetFolder, default_loader
from training.utils import image_transform


class ImageNetDataset(DatasetFolder):
    """Image-folder dataset yielding images together with text labels.

    Image files under ``root`` are indexed by torchvision's ``DatasetFolder``.
    Each item is a dict with:
      * ``'images'``: the image after ``image_transform(image, resolution=image_size)``;
      * ``'input_ids'``: the label description string looked up in the
        label-mapping file;
      * ``'class_ids'``: the integer class index as a ``torch.tensor``.
    """

    # File extensions accepted when no custom ``is_valid_file`` is supplied.
    IMG_EXTENSIONS = (
        ".jpg", ".jpeg", ".png", ".ppm", ".bmp", ".pgm", ".tif", ".tiff", ".webp",
    )

    # Batch keys that collate_fn keeps as plain Python lists instead of
    # stacking (the label strings cannot be stacked into a tensor).
    # NOTE: this must be a real tuple -- ``('input_ids')`` is just a string,
    # and ``k in 'input_ids'`` would then be a substring test.
    _UNSTACKED_KEYS = ("input_ids",)

    def __init__(
        self,
        root: str,
        loader: Callable[[str], Any] = default_loader,
        is_valid_file: Optional[Callable[[str], bool]] = None,
        image_size: int = 256,
        label_mapping_path: str = './training/imagenet_label_mapping',
    ):
        """Index image files under ``root`` and load the label descriptions.

        Args:
            root: Dataset root directory, expected in torchvision
                ``DatasetFolder`` layout (one sub-directory per class).
            loader: Callable that opens an image given its file path.
            is_valid_file: Optional predicate on file paths. When given, it
                replaces the ``IMG_EXTENSIONS`` extension filter.
            image_size: Passed as ``resolution`` to ``image_transform``.
            label_mapping_path: Text file with one
                ``<class index>:<description>`` entry per line. The default
                is the previously hard-coded path, resolved relative to the
                current working directory.

        Raises:
            ValueError: if a non-blank line of the mapping file has no ``:``
                or a non-integer class index.
        """
        self.image_size = image_size

        super().__init__(
            root,
            loader,
            # DatasetFolder accepts either an extension filter or a custom
            # predicate, not both.
            self.IMG_EXTENSIONS if is_valid_file is None else None,
            transform=image_transform,
            target_transform=None,
            is_valid_file=is_valid_file,
        )

        self.labels = self._load_label_mapping(label_mapping_path)

        print("ImageNet dataset loaded.")

    @staticmethod
    def _load_label_mapping(path):
        """Parse a ``<int>:<description>``-per-line file into ``{int: str}``."""
        labels = {}
        with open(path, 'r', encoding='utf-8') as f:
            for line in f:
                if not line.strip():
                    continue  # tolerate blank lines, e.g. a trailing newline
                # Split on the first colon only, so descriptions may contain ':'.
                num, description = line.split(":", 1)
                labels[int(num)] = description.strip()
        return labels

    def __getitem__(self, idx):
        """Return the sample at ``idx``, skipping forward past unloadable ones.

        If loading or transforming a sample raises, the error is printed and
        the next sample is tried. The search wraps around to the start of the
        dataset and visits each sample at most once.

        Raises:
            IndexError: if ``idx`` is outside ``[-len(samples), len(samples))``.
                This lets ``for x in dataset`` terminate normally.
            RuntimeError: if every sample fails to load.
        """
        num_samples = len(self.samples)
        if not -num_samples <= idx < num_samples:
            raise IndexError(
                f"index {idx} out of range for dataset of size {num_samples}"
            )

        last_error = None
        # Iterative retry instead of recursion: bounded, and it wraps at the
        # end instead of running off the sample list.
        for offset in range(num_samples):
            sample_idx = (idx + offset) % num_samples
            path, target = self.samples[sample_idx]
            try:
                image = self.loader(path)
                image = self.transform(image, resolution=self.image_size)
                # NOTE(review): ``target`` is the class index assigned by
                # DatasetFolder; this assumes the mapping file numbers classes
                # in the same order -- confirm.
                input_ids = str(self.labels[target])
            except Exception as e:  # deliberate best-effort: skip bad samples
                print(f"Skipping sample {sample_idx} ({path}): {e}")
                last_error = e
                continue

            return {
                'images': image,
                'input_ids': input_ids,
                'class_ids': torch.tensor(target),
            }

        raise RuntimeError(
            f"all {num_samples} samples failed to load"
        ) from last_error

    def collate_fn(self, batch):
        """Merge a list of sample dicts into one batch dict.

        Values are grouped per key. Every key except those in
        ``_UNSTACKED_KEYS`` is stacked along a new leading dimension with
        ``torch.stack``. The label strings under ``'input_ids'`` stay a list.

        Returns:
            A ``collections.defaultdict(list)`` mapping each key to its
            batched value.
        """
        batched = collections.defaultdict(list)
        for sample in batch:
            for key, value in sample.items():
                batched[key].append(value)
        # Reassigning existing keys while iterating is safe: dict size is unchanged.
        for key, values in batched.items():
            if key not in self._UNSTACKED_KEYS:
                batched[key] = torch.stack(values, dim=0)

        return batched


if __name__ == "__main__":
    # No standalone behaviour: running this module as a script does nothing.
    ...
