import torchvision.transforms as transforms
from antgine.datasets import ImageFolder

# Per-channel normalization statistics (the standard ImageNet mean / std), defined
# once so the training and evaluation pipelines cannot drift apart. Tuples keep the
# shared values immutable even though both Normalize instances reference them.
_IMAGENET_MEAN = (0.485, 0.456, 0.406)
_IMAGENET_STD = (0.229, 0.224, 0.225)

# Default training pipeline: random resized crop to 224, random horizontal flip,
# conversion to a tensor, then per-channel normalization.
_default_train_transform = transforms.Compose([
    transforms.RandomResizedCrop(224),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(mean=_IMAGENET_MEAN, std=_IMAGENET_STD),
])

# Default evaluation pipeline: deterministic resize to 256 followed by a 224 center
# crop (no random augmentation), conversion to a tensor, then the same normalization
# as training.
_default_test_transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=_IMAGENET_MEAN, std=_IMAGENET_STD),
])


class ImageNet(ImageFolder):
    """
        ImageNet dataset built on top of ``ImageFolder``.

        The constructor fixes the split sub-paths to ``'train'`` (training) and
        ``'validation'`` (testing) and forwards everything else to ``ImageFolder``.
        How those sub-paths are resolved against ``root`` is decided by
        ``ImageFolder`` (presumably joined onto ``root`` -- confirm there).
    """
    def __init__(self, root: str, batch_size: int,
                 train_transform: transforms.Compose = _default_train_transform,
                 test_transform: transforms.Compose = _default_test_transform,
                 num_workers=8):
        """
        :param str root: Root directory that holds the dataset splits.
        :param int batch_size: Number of samples per batch.
        :param transforms.Compose train_transform: Pipeline applied to training inputs
            (defaults to the module-level random-crop/flip/normalize pipeline).
        :param transforms.Compose test_transform: Pipeline applied to evaluation inputs
            (defaults to the module-level resize/center-crop/normalize pipeline).
        :param int num_workers: Number of data-loading workers to launch.
        """
        # Everything except the root and the fixed split names is passed through unchanged.
        forwarded = {
            'batch_size': batch_size,
            'train_transform': train_transform,
            'test_transform': test_transform,
            'num_workers': num_workers,
        }
        super().__init__(
            root=root,
            train_path='train',
            test_path='validation',
            **forwarded,
        )
