from .cifar_aug import CIFARTransform
from .imagenet_aug import IMAGENETTransform
from .imagenet_aug import SimMIMTransform
from .imagenet_aug import SimSiamTransform

# Per-channel (RGB) normalization statistics, as [mean, std].
# CIFAR-10 training-set statistics.
cifar_norm = [[0.4914, 0.4822, 0.4465], [0.2470, 0.2435, 0.2615]]
# Standard ImageNet statistics (as used by torchvision / timm pretrained models).
# NOTE(review): this previously duplicated the CIFAR-10 values above, which looks
# like a copy-paste error; checkpoints trained with the old values may need them.
imagenet_norm = [[0.485, 0.456, 0.406], [0.229, 0.224, 0.225]]

def get_aug(config=None, name='cifar', is_train=True, transform_single=True, to_pil_image=False):
    """Construct the augmentation transform for a dataset.

    Args:
        config: Experiment config. Required when ``name == 'imagenet'``; its
            ``config.model.backbone`` value selects the transform
            (``'simsiam'`` -> SimSiamTransform, ``'simmim'`` -> SimMIMTransform,
            anything else -> IMAGENETTransform). Unused for ``'cifar'``.
        name: Dataset name, either ``'cifar'`` or ``'imagenet'``.
        is_train: Forwarded unchanged to the selected transform.
        transform_single: Forwarded unchanged to the selected transform.
        to_pil_image: Forwarded to CIFARTransform and IMAGENETTransform only;
            the SimSiam/SimMIM transforms are not given this argument.

    Returns:
        The constructed transform object.

    Raises:
        ValueError: If ``name == 'imagenet'`` and ``config`` is None.
        NotImplementedError: If ``name`` is not a supported dataset.
    """
    if name == 'cifar':
        return CIFARTransform(normalize=cifar_norm, is_train=is_train,
                              transform_single=transform_single, to_pil_image=to_pil_image)

    if name == 'imagenet':
        # The ImageNet branch dispatches on the model backbone, so a config is
        # mandatory here even though the parameter defaults to None.
        if config is None:
            raise ValueError(
                "get_aug(name='imagenet') requires a config; "
                "config.model.backbone selects the transform"
            )
        backbone = config.model.backbone
        if backbone == 'simsiam':
            return SimSiamTransform(config=config, normalize=imagenet_norm, is_train=is_train,
                                    transform_single=transform_single)
        if backbone == 'simmim':
            return SimMIMTransform(config=config, normalize=imagenet_norm, is_train=is_train,
                                   transform_single=transform_single)
        return IMAGENETTransform(normalize=imagenet_norm, is_train=is_train,
                                 transform_single=transform_single, to_pil_image=to_pil_image)

    raise NotImplementedError(f"Unsupported augmentation dataset: {name!r}")
