import math

import torchvision.transforms as T

from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD


# Public API of this module.
__all__ = ["get_transforms"]


def get_transforms(
    model,
    model_source: str = 'timm',
    transform_test = None
):
    if model_source == 'timm':
        if hasattr(model, 'pretrained_cfg'):
            mean = model.pretrained_cfg['mean']
            std = model.pretrained_cfg['std']
            crop_size = model.pretrained_cfg['input_size'][-1]
            resize_size = round(crop_size / model.pretrained_cfg['crop_pct'])
        else:
            # set to ImageNet defaults
            mean = IMAGENET_DEFAULT_MEAN
            std = IMAGENET_DEFAULT_STD
            crop_size = 224
            resize_size = 256
        
        # create transform
        transform_train = T.Compose([
            T.RandomResizedCrop(crop_size, interpolation=T.InterpolationMode.BICUBIC),
            T.RandomHorizontalFlip(),
            T.ToTensor(),
            T.Normalize(mean, std)]
        )
        transform_test = T.Compose([
            T.Resize(resize_size, interpolation=T.InterpolationMode.BICUBIC),
            T.CenterCrop(crop_size),
            T.ToTensor(),
            T.Normalize(mean, std)
        ])
    elif model_source == 'open_clip':
        # first transform is resize
        resize_size = transform_test.transforms[0].size
        interpolation = transform_test.transforms[0].interpolation
        # second transform is center crop
        crop_size = transform_test.transforms[1].size
        # last transform is normalization
        mean = transform_test.transforms[-1].mean
        std = transform_test.transforms[-1].std

        transform_train = T.Compose([
            T.RandomResizedCrop(crop_size, interpolation=interpolation),
            T.RandomHorizontalFlip(),
            T.ToTensor(),
            T.Normalize(mean, std)]
        )

    return transform_train, transform_test
    