import random
from torchvision.transforms import functional as TFF
from torchvision import transforms as TF


def paired_transform(
    img1,
    img2,
    crop_size=128,
    hflip=True,
    rotation=True,
    crop_prob=0.7,
):
    """Apply one identical random spatial augmentation to a pair of images.

    Both images receive exactly the same geometry. First comes either the
    same random crop (chosen with probability ``crop_prob``, and only if both
    images are at least ``crop_size``) or a plain resize to ``crop_size``.
    Then come the same random flips / 90-degree rotation. Finally both are
    converted with ``TFF.to_tensor``.

    Args:
        img1: First image. Its ``height``/``width`` attributes are read and
            it is passed to ``TFF.to_tensor``, so a PIL image is expected
            (NOTE(review): tensors are not supported by this code path).
        img2: Second image, transformed identically to ``img1``.
        crop_size: Output size. An int ``s`` means ``(s, s)``; otherwise a
            ``(height, width)`` pair, as used by torchvision.
        hflip: If True, flip horizontally with probability 0.5.
        rotation: If True, flip vertically with probability 0.5 and,
            independently, rotate by 90 degrees with probability 0.5.
        crop_prob: Probability of choosing a random crop over a resize.

    Returns:
        Tuple ``(tensor1, tensor2)`` as produced by ``TFF.to_tensor``.
    """
    if isinstance(crop_size, int):
        crop_size = (crop_size, crop_size)

    # -------- Random choice: crop or resize --------
    use_crop = random.random() < crop_prob
    # The same crop box is applied to both images, so both must be large
    # enough; otherwise fall back to resizing the whole image.
    if use_crop and any(
        im.height < crop_size[0] or im.width < crop_size[1]
        for im in (img1, img2)
    ):
        use_crop = False
    if use_crop:
        # Crop box is sampled once (from img1) and reused for img2.
        i, j, h, w = TF.RandomCrop.get_params(img1, output_size=crop_size)

    # -------- Random augmentation parameters (shared by both images) --------
    do_hflip = hflip and random.random() < 0.5
    # `rotation` gates both the vertical flip and the 90-degree rotation.
    do_vflip = rotation and random.random() < 0.5
    do_rot90 = rotation and random.random() < 0.5

    # -------- Single transform applied to each image --------
    def apply(img):
        # Geometry strategy: shared crop box, or resize the full image
        # (resize may change the aspect ratio).
        if use_crop:
            img = TFF.crop(img, i, j, h, w)
        else:
            img = TFF.resize(img, crop_size)

        # Shared augmentations.
        if do_hflip:
            img = TFF.hflip(img)
        if do_vflip:
            img = TFF.vflip(img)
        if do_rot90:
            # NOTE: rotate() keeps the input size (expand=False), so for a
            # non-square crop_size the corners are cut off and filled.
            img = TFF.rotate(img, 90)
        return TFF.to_tensor(img)

    # Apply the identical transform to both images.
    return apply(img1), apply(img2)
