import torch
import torchvision.transforms as transforms
from torchvision.datasets import ImageFolder
from torchvision.datasets.folder import default_loader

import os


class ImageNet(ImageFolder):
    """ImageNet split loaded through torchvision's ``ImageFolder``.

    The split directory is resolved as ``<cfg.dataset.root>/train`` when
    ``train`` is truthy and ``<cfg.dataset.root>/val`` otherwise. Sample
    discovery, loading and transforms are all delegated to ``ImageFolder``.
    """

    def __init__(
        self,
        cfg,
        train=True,
        transform=None,
        target_transform=None,
        loader=default_loader,
        is_valid_file=None,
        **kwargs,
    ):
        """Build the dataset for the requested split.

        Args:
            cfg: Config object; only ``cfg.dataset.root`` is read here.
            train: Use the ``train`` subdirectory if truthy, else ``val``.
            transform: Forwarded unchanged to ``ImageFolder``.
            target_transform: Forwarded unchanged to ``ImageFolder``.
            loader: Forwarded unchanged to ``ImageFolder``.
            is_valid_file: Forwarded unchanged to ``ImageFolder``.
            **kwargs: Accepted but ignored. They are NOT forwarded to
                ``ImageFolder``. NOTE(review): presumably this lets callers
                pass a shared option set; confirm against call sites.
        """
        if train:
            split = 'train'
        else:
            split = 'val'
        split_dir = os.path.join(cfg.dataset.root, split)

        super().__init__(
            split_dir,
            loader=loader,
            transform=transform,
            target_transform=target_transform,
            is_valid_file=is_valid_file,
        )

