import torchvision.transforms as transforms
from abc import ABC, abstractmethod



class AbstractTransform(ABC):
    def __init__(self, image_size):
        self.image_size = image_size

    @property
    @abstractmethod
    def train_transform(self):
        """Return the training transforms pipeline."""
        pass

    @property
    @abstractmethod
    def test_transform(self):
        """Return the testing transforms pipeline."""
        pass


class CifarImageTransforms(AbstractTransform):
    """Preprocessing pipelines for CIFAR-style images.

    Training: pad each border by 4 px and take a random 32x32 crop, resize
    to ``image_size``, randomly flip horizontally, convert to a tensor and
    normalize per channel with ``MEAN`` / ``STD``.
    Testing: resize to ``image_size``, convert to a tensor and normalize.

    Args:
        image_size: Passed unchanged to ``transforms.Resize``.
    """

    # Per-channel normalization statistics. They are fixed constants, so
    # they are defined once on the class rather than re-assigned per
    # instance. They remain reachable as ``self.MEAN`` / ``self.STD``.
    MEAN = (0.4914, 0.4822, 0.4465)
    STD = (0.2470, 0.2435, 0.2615)

    def __init__(self, image_size):
        super().__init__(image_size)

    @property
    def train_transform(self):
        """Return the augmenting pipeline used for training samples."""
        return transforms.Compose(
            [
                # The crop size is fixed at 32 regardless of image_size;
                # resizing to the requested size happens afterwards.
                transforms.RandomCrop(32, padding=4),
                transforms.Resize(self.image_size),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
                transforms.Normalize(self.MEAN, self.STD),
            ]
        )

    @property
    def test_transform(self):
        """Return the deterministic pipeline used for evaluation samples."""
        return transforms.Compose(
            [
                transforms.Resize(self.image_size),
                transforms.ToTensor(),
                transforms.Normalize(self.MEAN, self.STD),
            ]
        )


class TinyImgTransform(AbstractTransform):
    """Preprocessing pipelines for Tiny-ImageNet-style images.

    Training: pad each border by 4 px and take a random 64x64 crop, resize
    to ``image_size``, randomly flip horizontally, convert to a tensor and
    normalize per channel with ``MEAN`` / ``STD``.
    Testing: resize to ``image_size``, convert to a tensor and normalize.

    Args:
        image_size: Passed unchanged to ``transforms.Resize``; defaults to 64.
    """

    # Per-channel normalization statistics. They are fixed constants, so
    # they are defined once on the class rather than re-assigned per
    # instance. They remain reachable as ``self.MEAN`` / ``self.STD``.
    MEAN = (0.4802, 0.4480, 0.3975)
    STD = (0.2770, 0.2691, 0.2821)

    def __init__(self, image_size=64):
        super().__init__(image_size)

    @property
    def train_transform(self):
        """Return the augmenting pipeline used for training samples."""
        return transforms.Compose(
            [
                # The crop size is fixed at 64 regardless of image_size;
                # resizing to the requested size happens afterwards.
                transforms.RandomCrop(64, padding=4),
                transforms.Resize(self.image_size),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
                transforms.Normalize(self.MEAN, self.STD),
            ]
        )

    @property
    def test_transform(self):
        """Return the deterministic pipeline used for evaluation samples."""
        return transforms.Compose(
            [
                transforms.Resize(self.image_size),
                transforms.ToTensor(),
                transforms.Normalize(self.MEAN, self.STD),
            ]
        )


class BaseTransform(AbstractTransform):
    """Square-resize pipelines shared by several datasets.

    Both splits resize to ``(image_size, image_size)``, convert to a tensor
    and then apply every transform in ``common_transform``. The training
    pipeline additionally starts with a random horizontal flip.

    Args:
        image_size: Side length of the square resize target.
    """

    def __init__(self, image_size):
        super().__init__(image_size)
        # Zero mean and unit std make this normalization arithmetically an
        # identity, (x - 0) / 1.
        # NOTE(review): confirm whether dataset statistics were intended here.
        identity_norm = transforms.Normalize(
            mean=(0.0, 0.0, 0.0), std=(1.0, 1.0, 1.0)
        )
        self.common_transform = [identity_norm]

    def _compose(self, *leading):
        """Build ``leading`` + square resize + ``ToTensor`` + shared tail."""
        side = self.image_size
        steps = list(leading)
        steps.append(transforms.Resize((side, side)))
        steps.append(transforms.ToTensor())
        steps.extend(self.common_transform)
        return transforms.Compose(steps)

    @property
    def train_transform(self):
        """Return the training pipeline (random flip first)."""
        return self._compose(transforms.RandomHorizontalFlip())

    @property
    def test_transform(self):
        """Return the evaluation pipeline (no augmentation)."""
        return self._compose()


class ImageNetRTransform(BaseTransform):
    """``BaseTransform`` pipelines whose ``image_size`` defaults to 224."""

    def __init__(self, image_size=224):
        super().__init__(image_size=image_size)


class DomainNetTransform(BaseTransform):
    """``BaseTransform`` pipelines whose ``image_size`` defaults to 256."""

    def __init__(self, image_size=256):
        super().__init__(image_size=image_size)
    