from typing import Tuple

import torch
from torchvision import transforms


class CIFARTransformTrain:
    """Training-time augmentation and normalization for (image, target) samples.

    Applies a random horizontal flip, then a random 32x32 crop after 4-pixel
    padding, then per-channel normalization. The pipeline is built once in
    ``__init__`` and reused for every sample. The random flip/crop parameters
    are still drawn fresh on each call.
    """

    # Presumably the CIFAR-10 per-channel mean/std in RGB order -- confirm
    # against the dataset actually being loaded.
    # NOTE(review): this std tuple is the widely copied one; the per-pixel
    # CIFAR-10 train-set std is often quoted as (0.2470, 0.2435, 0.2616).
    # It is left unchanged so previously trained models see identical inputs.
    DEFAULT_MEAN: Tuple[float, ...] = (0.4914, 0.4822, 0.4465)
    DEFAULT_STD: Tuple[float, ...] = (0.2023, 0.1994, 0.2010)

    def __init__(
        self,
        *,
        mean: Tuple[float, ...] = DEFAULT_MEAN,
        std: Tuple[float, ...] = DEFAULT_STD,
    ) -> None:
        """
        Args:
            mean: Per-channel mean passed to ``transforms.Normalize``.
            std: Per-channel standard deviation passed to
                ``transforms.Normalize``.
        """
        # Built once here instead of on every __call__. The transform objects
        # hold no per-sample state, so reusing them is safe.
        # No ToTensor step: the image is expected to already be a tensor
        # (Normalize does not accept PIL images).
        self.transform = transforms.Compose(
            [
                transforms.RandomHorizontalFlip(),
                transforms.RandomCrop(32, padding=4),
                transforms.Normalize(mean, std),
            ]
        )

    def __call__(
        self, sample: Tuple[torch.Tensor, torch.Tensor]
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Augment and normalize the image of a single sample.

        Args:
            sample: Tuple with (image, target). The image is assumed to be a
                float tensor shaped (..., C, H, W) with C matching ``mean``
                (TODO: confirm against the dataset). Normalize rejects
                non-float tensors.

        Returns:
            A tuple with the transformed image and the unchanged target.
        """
        img, target = sample
        return self.transform(img), target


class CIFARTransformEval:
    """Evaluation-time normalization for (image, target) samples.

    Applies only per-channel normalization (no augmentation). The pipeline is
    built once in ``__init__`` and reused for every sample.
    """

    # Presumably the CIFAR-10 per-channel mean/std in RGB order -- confirm
    # against the dataset actually being loaded.
    # NOTE(review): this std tuple is the widely copied one; the per-pixel
    # CIFAR-10 train-set std is often quoted as (0.2470, 0.2435, 0.2616).
    # It is left unchanged so evaluation matches how models were trained.
    DEFAULT_MEAN: Tuple[float, ...] = (0.4914, 0.4822, 0.4465)
    DEFAULT_STD: Tuple[float, ...] = (0.2023, 0.1994, 0.2010)

    def __init__(
        self,
        *,
        mean: Tuple[float, ...] = DEFAULT_MEAN,
        std: Tuple[float, ...] = DEFAULT_STD,
    ) -> None:
        """
        Args:
            mean: Per-channel mean passed to ``transforms.Normalize``.
            std: Per-channel standard deviation passed to
                ``transforms.Normalize``.
        """
        # Built once here instead of on every __call__.
        # No ToTensor step: the image is expected to already be a tensor
        # (Normalize does not accept PIL images).
        self.transform = transforms.Compose(
            [
                transforms.Normalize(mean, std),
            ]
        )

    def __call__(
        self, sample: Tuple[torch.Tensor, torch.Tensor]
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Normalize the image of a single sample.

        Args:
            sample: Tuple with (image, target). The image is assumed to be a
                float tensor shaped (..., C, H, W) with C matching ``mean``
                (TODO: confirm against the dataset). Normalize rejects
                non-float tensors.

        Returns:
            A tuple with the normalized image and the unchanged target.
        """
        img, target = sample
        return self.transform(img), target
