import torchvision.transforms as T

def default_transform(img_size=224):
    """Build the default resize / center-crop / normalize pipeline.

    Steps, in order:
      1. ``T.Resize(img_size)`` -- with an int size, torchvision matches the
         image's shorter edge to ``img_size``.
      2. ``T.CenterCrop(img_size)`` -- square ``img_size`` x ``img_size`` crop
         from the center.
      3. ``T.Normalize`` with per-channel mean 0.5 and std 0.5 for 3 channels.

    NOTE(review): ``T.Normalize`` operates on tensors, so this presumably
    expects tensor input (e.g. after an upstream ``ToTensor``); if inputs are
    in [0, 1] the result lands in [-1, 1] -- confirm against callers.

    Args:
        img_size: Target side length for the resize and the center crop.

    Returns:
        A ``T.Compose`` pipeline.
    """
    channel_mean = [0.5, 0.5, 0.5]
    channel_std = [0.5, 0.5, 0.5]
    steps = [
        T.Resize(img_size),
        T.CenterCrop(img_size),
        T.Normalize(channel_mean, channel_std),
    ]
    return T.Compose(steps)

def get_train_crop_transform(original_img_size=140, crop_size=128):
    """Build the training-time random-crop augmentation pipeline.

    Takes a random ``crop_size`` x ``crop_size`` patch, resizes it back to
    ``original_img_size`` x ``original_img_size``, then normalizes each of the
    three channels with mean 0.5 and std 0.5.

    NOTE(review): ``T.Normalize`` operates on tensors, so this presumably
    expects tensor input; the input must also be at least ``crop_size`` on
    each side, since ``RandomCrop`` is used without padding -- confirm
    against callers.

    Args:
        original_img_size: Side length the cropped patch is resized to.
        crop_size: Side length of the square random crop (default 128, the
            previously hard-coded value).

    Returns:
        A ``T.Compose`` pipeline.
    """
    return T.Compose([
        T.RandomCrop((crop_size, crop_size)),                      # random crop_size x crop_size patch
        T.Resize((original_img_size, original_img_size)),          # resize back to original_img_size
        T.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
    ])

def get_train_crop_transform_resnet(original_img_size=140, crop_size=128):
    """Build the ResNet training-time random-crop pipeline (crop only).

    Unlike ``get_train_crop_transform``, this variant neither resizes the
    patch back up nor normalizes it; it only takes a random
    ``crop_size`` x ``crop_size`` crop.

    Args:
        original_img_size: Unused; kept so the signature matches the
            non-ResNet variant and existing callers keep working.
        crop_size: Side length of the square random crop (default 128, the
            previously hard-coded value).

    Returns:
        A ``T.Compose`` pipeline containing a single ``RandomCrop``.
    """
    del original_img_size  # intentionally unused (signature parity)
    return T.Compose([
        T.RandomCrop((crop_size, crop_size)),  # random crop_size x crop_size patch
    ])

def get_eval_crop_transform(original_img_size=140, crop_size=128):
    """Build the deterministic evaluation-time crop pipeline.

    Takes the central ``crop_size`` x ``crop_size`` patch, resizes it back to
    ``original_img_size`` x ``original_img_size``, then normalizes each of the
    three channels with mean 0.5 and std 0.5. It is the deterministic
    counterpart of ``get_train_crop_transform``.

    NOTE(review): ``T.Normalize`` operates on tensors, so this presumably
    expects tensor input -- confirm against callers.

    Args:
        original_img_size: Side length the cropped patch is resized to.
        crop_size: Side length of the square center crop (default 128, the
            previously hard-coded value).

    Returns:
        A ``T.Compose`` pipeline.
    """
    return T.Compose([
        T.CenterCrop((crop_size, crop_size)),                      # center crop_size x crop_size patch
        T.Resize((original_img_size, original_img_size)),          # resize back to original_img_size
        T.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
    ])

def get_eval_crop_transform_resnet(original_img_size=140, crop_size=128):
    """Build the ResNet evaluation-time center-crop pipeline (crop only).

    Unlike ``get_eval_crop_transform``, this variant neither resizes the
    patch back up nor normalizes it; it only takes the central
    ``crop_size`` x ``crop_size`` crop.

    Args:
        original_img_size: Unused; kept so the signature matches the
            non-ResNet variant and existing callers keep working.
        crop_size: Side length of the square center crop (default 128, the
            previously hard-coded value).

    Returns:
        A ``T.Compose`` pipeline containing a single ``CenterCrop``.
    """
    del original_img_size  # intentionally unused (signature parity)
    return T.Compose([
        T.CenterCrop((crop_size, crop_size)),  # center crop_size x crop_size patch
    ])