# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the CC-by-NC license found in the
# LICENSE file in the root directory of this source tree.
import torch
from torchvision.transforms.v2 import (
    Compose,
    RandomHorizontalFlip,
    ToDtype,
    ToImage,
    Resize,
    InterpolationMode,
)


def get_train_transform():
    """Build the default training transform pipeline (no resizing).

    Steps, applied in order:
        1. ``ToImage()`` — convert the input to a torchvision image tensor.
        2. ``RandomHorizontalFlip()`` — random left/right flip with
           torchvision's default probability.
        3. ``ToDtype(torch.float32, scale=True)`` — cast to ``float32``
           with value scaling enabled.

    Returns:
        Compose: the composed transform pipeline.
    """
    to_image = ToImage()
    random_flip = RandomHorizontalFlip()
    to_float = ToDtype(torch.float32, scale=True)
    return Compose([to_image, random_flip, to_float])

def get_train_transform_resize(image_size):
    """Build a training transform pipeline that resizes to a fixed size.

    Steps, applied in order: ``ToImage()``, a bilinear ``Resize`` to the
    requested size, ``RandomHorizontalFlip()``, and
    ``ToDtype(torch.float32, scale=True)``.

    Args:
        image_size (int or tuple): Target size.
            - An int is used as ``(image_size, image_size)``.
            - A ``(h, w)`` tuple or list is used as-is.

    Returns:
        Compose: the composed transform pipeline.

    Raises:
        ValueError: If ``image_size`` is a tuple/list whose length is not 2.
    """
    # Normalize to an explicit (h, w) pair. Previously a tuple argument was
    # wrapped again as (image_size, image_size), i.e. ((h, w), (h, w)),
    # which is not a valid size for Resize.
    if isinstance(image_size, (tuple, list)):
        if len(image_size) != 2:
            raise ValueError(
                f"image_size must be an int or an (h, w) pair, got {image_size!r}"
            )
        size = tuple(image_size)
    else:
        # Scalar: square output, same as the original behavior.
        size = (image_size, image_size)

    transform_list = [
        ToImage(),
        Resize(size, interpolation=InterpolationMode.BILINEAR),
        RandomHorizontalFlip(),
        ToDtype(torch.float32, scale=True),
    ]
    return Compose(transform_list)

def get_train_transform_celeba(image_size):
    """Build the CelebA training transform pipeline.

    Steps, applied in order: ``ToImage()``, a bilinear ``Resize`` to the
    requested size, ``RandomHorizontalFlip()``, and
    ``ToDtype(torch.float32, scale=True)``. The pipeline is currently
    identical to ``get_train_transform_resize``.

    Args:
        image_size (int or tuple): Target size.
            - An int is used as ``(image_size, image_size)``.
            - A ``(h, w)`` tuple or list is used as-is.

    Returns:
        Compose: the composed transform pipeline.

    Raises:
        ValueError: If ``image_size`` is a tuple/list whose length is not 2.
    """
    # Normalize to an explicit (h, w) pair. Previously a tuple argument was
    # wrapped again as (image_size, image_size), i.e. ((h, w), (h, w)),
    # which is not a valid size for Resize.
    if isinstance(image_size, (tuple, list)):
        if len(image_size) != 2:
            raise ValueError(
                f"image_size must be an int or an (h, w) pair, got {image_size!r}"
            )
        size = tuple(image_size)
    else:
        # Scalar: square output, same as the original behavior.
        size = (image_size, image_size)

    transform_list = [
        ToImage(),
        Resize(size, interpolation=InterpolationMode.BILINEAR),
        RandomHorizontalFlip(),
        ToDtype(torch.float32, scale=True),
    ]
    return Compose(transform_list)
