from .byol_aug import BYOL_transform
from .eval_aug import Transform_single
from .simclr_aug import SimCLRTransform
from .simsiam_aug import SimSiamTransform

# Default normalization statistics as a [mean, std] pair, one value per image
# channel (presumably RGB order -- confirm against the transform modules).
imagenet_mean_std = [
    [0.485, 0.456, 0.406],  # per-channel mean
    [0.229, 0.224, 0.225],  # per-channel standard deviation
]


def get_aug(name='simsiam', mean_std=None, image_size=224, train=True, train_classifier=None):
    """Build the augmentation transform for pre-training or evaluation.

    Args:
        name: Pre-training augmentation to build when ``train`` is true; one of
            ``'simsiam'``, ``'byol'`` or ``'simclr'``. Ignored when ``train``
            is false.
        mean_std: ``[mean, std]`` normalization pair passed to the pre-training
            transform constructors. Defaults to ``imagenet_mean_std``.
        image_size: Target image size passed to the selected transform.
        train: If true, build the pre-training transform selected by ``name``;
            otherwise build an evaluation ``Transform_single``.
        train_classifier: Required when ``train`` is false; passed through as
            ``Transform_single(image_size, train=train_classifier)``.

    Returns:
        The constructed transform object.

    Raises:
        NotImplementedError: If ``train`` is true and ``name`` is unknown.
        ValueError: If ``train`` is false and ``train_classifier`` is None.
    """
    if not train:
        if train_classifier is None:
            raise ValueError(
                "train_classifier must be provided when train is False"
            )
        # NOTE(review): mean_std is not forwarded to Transform_single, so a
        # custom mean_std has no effect on the evaluation transform. Confirm
        # whether Transform_single accepts a normalization argument.
        return Transform_single(image_size, train=train_classifier)

    if mean_std is None:
        mean_std = imagenet_mean_std

    if name == 'simsiam':
        return SimSiamTransform(image_size, mean_std)
    if name == 'byol':
        return BYOL_transform(image_size, mean_std)
    if name == 'simclr':
        return SimCLRTransform(image_size, mean_std)
    raise NotImplementedError(
        f"Unknown augmentation name {name!r}; "
        "expected one of 'simsiam', 'byol', 'simclr'"
    )
