import torch



def apply_jigsaw_mask(image, patch_size=32, mask_ratio=0.5, mask_type='random',
                      shuffle=True, center_offset_x=0.0, center_offset_y=0.0,
                      fill_value=0.5):
    """Split an image batch into square patches, mask some of them, optionally shuffle.

    Masked patches are overwritten with the constant ``fill_value`` (0.5 by
    default, i.e. mid-gray for images in [0, 1]). Shuffling happens after
    masking, so filled patches are shuffled together with the kept ones.

    Args:
        image: Input tensor [B, C, H, W]; H and W must be divisible by ``patch_size``.
        patch_size: Side length of each square patch, in pixels.
        mask_ratio: Fraction of patches to mask. Its meaning depends on ``mask_type``:
            - 'random': each patch is kept independently with probability
              ``1 - mask_ratio``.
            - 'center': one rectangle of roughly ``1 - mask_ratio`` of the
              patch grid is kept. The kept fraction is clamped to [0, 1], and
              at least one patch per side is always kept.
            - 'grid': 0 keeps everything and 1 masks everything. Any other
              value keeps one colour of a checkerboard (about half of the
              patches); the colour is chosen at random per sample.
        mask_type: One of 'random', 'center' or 'grid'.
        shuffle: If True, permute patch positions independently for each sample.
        center_offset_x: For 'center' only. Horizontal shift of the kept
            rectangle, clamped to [-1, 1]. The value is a fraction of half the
            free space around the centred rectangle, so 0 means centred.
        center_offset_y: Vertical shift, same convention as ``center_offset_x``.
        fill_value: Value written into masked patches (default 0.5).

    Returns:
        Tensor with the same shape, dtype and device as ``image``.

    Raises:
        ValueError: If H or W is not divisible by ``patch_size``, or if
            ``mask_type`` is unknown.
    """
    B, C, H, W = image.shape
    device = image.device

    # Explicit check rather than `assert`, which `python -O` strips out.
    if H % patch_size != 0 or W % patch_size != 0:
        raise ValueError(
            f"patch_size {patch_size} must evenly divide the image size {H}x{W}")
    grid_h, grid_w = H // patch_size, W // patch_size
    num_patches = grid_h * grid_w

    # Split into patches: [B, C, H, W] -> [B, C, gh, gw, p, p]
    # -> [B, gh, gw, C, p, p] -> [B, N, C, p, p] (row-major patch order).
    patches = image.unfold(2, patch_size, patch_size).unfold(3, patch_size, patch_size)
    patches = patches.permute(0, 2, 3, 1, 4, 5).contiguous().view(
        B, num_patches, C, patch_size, patch_size)

    # Keep-mask of shape [B, N, 1, 1, 1]; True means the patch is kept.
    if mask_type == 'random':
        mask = torch.rand(B, num_patches, 1, 1, 1, device=device) > mask_ratio

    elif mask_type == 'center':
        # Clamp so an out-of-range mask_ratio cannot hand a negative number to
        # ** 0.5 (Python would return a complex value and round() would fail).
        keep_ratio = min(1.0, max(0.0, 1.0 - mask_ratio))
        # sqrt per side so that keep_h * keep_w ~= keep_ratio * grid_h * grid_w.
        side_ratio = keep_ratio ** 0.5
        keep_h = max(1, round(grid_h * side_ratio))
        keep_w = max(1, round(grid_w * side_ratio))

        offset_x = max(-1.0, min(1.0, center_offset_x))
        offset_y = max(-1.0, min(1.0, center_offset_y))
        slack_h, slack_w = grid_h - keep_h, grid_w - keep_w
        top = slack_h // 2 + int(offset_y * slack_h / 2)
        left = slack_w // 2 + int(offset_x * slack_w / 2)

        region = torch.zeros(grid_h, grid_w, device=device, dtype=torch.bool)
        region[max(0, top):min(grid_h, top + keep_h),
               max(0, left):min(grid_w, left + keep_w)] = True
        mask = region.view(1, num_patches, 1, 1, 1).expand(B, -1, -1, -1, -1)

    elif mask_type == 'grid':
        # NOTE: apart from the 0 / 1 special cases, mask_ratio does not change
        # how many patches are kept: always one checkerboard colour.
        if mask_ratio == 0:
            mask = torch.ones(B, num_patches, 1, 1, 1, device=device, dtype=torch.bool)
        elif mask_ratio == 1:
            mask = torch.zeros(B, num_patches, 1, 1, 1, device=device, dtype=torch.bool)
        else:
            rows = torch.arange(grid_h, device=device)
            cols = torch.arange(grid_w, device=device)
            odd_cells = (rows[:, None] + cols[None, :]) % 2 == 1  # [gh, gw]
            # Per sample, keep either the odd or the even cells at random.
            flip = torch.rand(B, 1, 1, device=device) > 0.5
            mask = (odd_cells.unsqueeze(0) ^ flip).view(B, num_patches, 1, 1, 1)
    else:
        raise ValueError(f"未知mask类型: {mask_type}")

    # Broadcast a 0-dim fill tensor instead of materialising a full-size
    # background; new_full keeps the patches' dtype and device.
    masked_patches = torch.where(mask, patches, patches.new_full((), fill_value))

    if shuffle:
        # Independent random permutation of patch slots for every sample.
        order = torch.rand(B, num_patches, device=device).argsort(dim=1)
        order = order.view(B, num_patches, 1, 1, 1).expand(-1, -1, C, patch_size, patch_size)
        masked_patches = torch.gather(masked_patches, 1, order)

    # Reassemble: [B, gh, gw, C, p, p] -> [B, C, gh, p, gw, p] -> [B, C, H, W].
    jigsaw_image = masked_patches.view(B, grid_h, grid_w, C, patch_size, patch_size)
    jigsaw_image = jigsaw_image.permute(0, 3, 1, 4, 2, 5).contiguous().view(B, C, H, W)
    return jigsaw_image



# def apply_content_block_jigsaw(
#     image, 
#     content_patch_size=64, 
#     small_patch_size=16, 
#     shuffle=True,
#     content_pos_x=None,
#     content_pos_y=None
# ):
#     """
#     应用基于内容块的拼图变换：保留一个大块内容区域，其余部分划分为小块并打乱
    
#     Args:
#         image: 输入图像张量 [B, C, H, W]
#         content_patch_size: 内容块的大小（正方形）
#         small_patch_size: 小块的大小（正方形）
#         shuffle: 是否打乱小块顺序
#         content_pos_x: 可选，内容块的x起始位置（可以是长度为B的张量或标量）
#         content_pos_y: 可选，内容块的y起始位置（可以是长度为B的张量或标量）
#     Returns:
#         jigsaw_image: 变换后的图像张量 [B, C, H, W]
#     """
#     B, C, H, W = image.shape
#     device = image.device
    
#     # 1. 确定内容块位置
#     max_x = W - content_patch_size
#     max_y = H - content_patch_size
    
#     # 检查内容块大小是否有效
#     if max_x < 0 or max_y < 0:
#         raise ValueError(f"内容块大小 {content_patch_size} 超过图像尺寸 {H}x{W}")
    
#     # 处理x位置
#     if content_pos_x is not None:
#         if isinstance(content_pos_x, (int, float)):
#             cx = torch.full((B,), int(content_pos_x), device=device)
#         else:
#             cx = content_pos_x.to(device)
#         # 检查边界
#         if (cx < 0).any() or (cx > max_x).any():
#             raise ValueError(f"content_pos_x 超出有效范围 [0, {max_x}]")
#     else:
#         cx = torch.randint(0, max_x + 1, (B,), device=device)
    
#     # 处理y位置
#     if content_pos_y is not None:
#         if isinstance(content_pos_y, (int, float)):
#             cy = torch.full((B,), int(content_pos_y), device=device)
#         else:
#             cy = content_pos_y.to(device)
#         # 检查边界
#         if (cy < 0).any() or (cy > max_y).any():
#             raise ValueError(f"content_pos_y 超出有效范围 [0, {max_y}]")
#     else:
#         cy = torch.randint(0, max_y + 1, (B,), device=device)
    
#     # 2. 计算可划分的小块网格
#     grid_h = H // small_patch_size
#     grid_w = W // small_patch_size
    
#     # 如果没有足够空间划分小块，直接返回原图
#     if grid_h <= 0 or grid_w <= 0:
#         return image.clone()
    
#     # 3. 创建输出图像（先复制原图）
#     jigsaw_image = image.clone()
    
#     # 4. 对每个batch独立处理
#     for b in range(B):
#         # 获取当前batch的内容块位置
#         cur_cx, cur_cy = cx[b], cy[b]
#         content_end_x = cur_cx + content_patch_size
#         content_end_y = cur_cy + content_patch_size
        
#         # 收集当前batch的所有有效小块
#         small_patches = []
#         positions = []
        
#         # 遍历所有可能的网格位置
#         for i in range(grid_h):
#             for j in range(grid_w):
#                 # 计算当前小块位置
#                 y_start = i * small_patch_size
#                 y_end = y_start + small_patch_size
#                 x_start = j * small_patch_size
#                 x_end = x_start + small_patch_size
                
#                 # 检查是否完全在内容块外部（允许边界相邻）
#                 # 判断条件：小块与内容块无重叠区域
#                 if (y_end <= cur_cy or  # 完全在内容块上方
#                     y_start >= content_end_y or  # 完全在内容块下方
#                     x_end <= cur_cx or  # 完全在内容块左侧
#                     x_start >= content_end_x):  # 完全在内容块右侧
                    
#                     # 提取小块 [C, P, P]
#                     patch = image[b, :, y_start:y_end, x_start:x_end].clone()
#                     small_patches.append(patch)
#                     positions.append((y_start, x_start))
        
#         # 如果没有有效小块，跳过该batch
#         if not small_patches:
#             continue
        
#         # 5. 打乱小块顺序
#         if shuffle:
#             # 创建随机排列索引
#             perm = torch.randperm(len(small_patches), device=device)
#             shuffled_patches = [small_patches[i] for i in perm]
#         else:
#             shuffled_patches = small_patches
        
#         # 6. 将打乱后的小块放回图像
#         for idx, (y_start, x_start) in enumerate(positions):
#             jigsaw_image[b, :, 
#                         y_start:y_start+small_patch_size, 
#                         x_start:x_start+small_patch_size] = shuffled_patches[idx]
    
#     return jigsaw_image



def _content_axis_starts(pos, max_start, batch, device, name):
    """Return a length-``batch`` tensor of content-block start offsets along one axis.

    ``pos`` may be:
    - None: drawn uniformly from [0, max_start] for each sample.
    - A Python int/float: truncated with int() and shared by all samples.
    - A tensor or array-like holding one value or ``batch`` values.
    """
    if pos is None:
        return torch.randint(0, max_start + 1, (batch,), device=device)
    if isinstance(pos, (int, float)):
        starts = torch.full((batch,), int(pos), device=device)
    else:
        starts = torch.as_tensor(pos, device=device).reshape(-1).expand(batch)
    if (starts < 0).any() or (starts > max_start).any():
        raise ValueError(f"{name} 超出有效范围 [0, {max_start}]")
    return starts


def _resolve_content_block(content_block, batch, height, width, device):
    """Normalize ``content_block`` to a validated int32 [batch, 4] tensor of (x1, y1, x2, y2)."""
    if isinstance(content_block, torch.Tensor):
        rect = content_block.to(device)
    elif len(content_block) == 4 and isinstance(content_block[0], (int, float)):
        rect = torch.tensor(content_block, device=device)
    elif len(content_block) == batch and all(len(r) == 4 for r in content_block):
        rect = torch.tensor(content_block, device=device)
    else:
        raise ValueError("content_block 格式无效: 需要是 [x1,y1,x2,y2] 或批量列表")

    input_shape = rect.shape
    if rect.dim() == 1:
        rect = rect.unsqueeze(0)
    # Reject malformed shapes here rather than failing later when each row
    # is unpacked into four coordinates.
    if rect.dim() != 2 or rect.size(1) != 4 or rect.size(0) not in (1, batch):
        raise ValueError(f"content_block 需要是形状为 [B, 4] 或 [4] 的张量, 实际为 {input_shape}")
    rect = rect.expand(batch, 4)

    # Only float coordinates need rounding; round() may not be implemented
    # for integral dtypes on some PyTorch versions.
    if rect.is_floating_point():
        rect = rect.round()
    rect = rect.to(torch.int32)

    x1, y1, x2, y2 = rect.unbind(dim=1)
    invalid = (x1 < 0) | (y1 < 0) | (x2 > width) | (y2 > height) | (x1 >= x2) | (y1 >= y2)
    if invalid.any():
        raise ValueError("content_block 坐标无效或超出图像范围")
    return rect


def apply_content_block_jigsaw(
    image,
    content_patch_size=64,
    small_patch_size=16,
    shuffle=True,
    content_pos_x=None,
    content_pos_y=None,
    content_block=None
):
    """Keep one content rectangle intact and shuffle the small tiles around it.

    The image is tiled by a ``small_patch_size`` grid anchored at (0, 0).
    Tiles that do not overlap the content rectangle (sharing an edge is
    allowed) are randomly permuted among themselves, independently per sample.
    The following are left unchanged:
    - tiles that overlap the rectangle;
    - the bottom/right remainder strip when H or W is not a multiple of
      ``small_patch_size``.

    Args:
        image: Input tensor [B, C, H, W].
        content_patch_size: Side of the square content block used when
            ``content_block`` is None.
        small_patch_size: Side of the shuffled tiles.
        shuffle: If False, the tiles stay in place and the output equals the input.
        content_pos_x: Optional left edge of the square block. Either an
            int/float (truncated with int()) or a tensor/array-like with one
            or B values. Drawn at random per sample when None.
        content_pos_y: Optional top edge; same conventions as ``content_pos_x``.
        content_block: Optional explicit rectangle(s) [x1, y1, x2, y2] in
            pixels, with x2/y2 exclusive. Accepted forms: a tensor of shape
            [4], [1, 4] or [B, 4]; a list of 4 numbers; or a list of B such
            lists. Float coordinates are rounded. When given, it overrides
            ``content_patch_size``, ``content_pos_x`` and ``content_pos_y``.

    Returns:
        A new tensor [B, C, H, W]; ``image`` itself is not modified.

    Raises:
        ValueError: If the square block does not fit in the image, if a
            position is out of range, or if ``content_block`` has an invalid
            format, shape or coordinates.
    """
    B, _, H, W = image.shape
    device = image.device

    # 1. Resolve the content rectangle(s) as a [B, 4] tensor (x1, y1, x2, y2).
    if content_block is None:
        max_x = W - content_patch_size
        max_y = H - content_patch_size
        if max_x < 0 or max_y < 0:
            raise ValueError(f"内容块大小 {content_patch_size} 超过图像尺寸 {H}x{W}")
        # x is resolved before y, so random draws happen in the same order as before.
        cx = _content_axis_starts(content_pos_x, max_x, B, device, "content_pos_x")
        cy = _content_axis_starts(content_pos_y, max_y, B, device, "content_pos_y")
        content_rect = torch.stack(
            [cx, cy, cx + content_patch_size, cy + content_patch_size], dim=1)
    else:
        content_rect = _resolve_content_block(content_block, B, H, W, device)

    # 2. Tile grid; nothing to shuffle if not even one tile fits.
    grid_h = H // small_patch_size
    grid_w = W // small_patch_size
    if grid_h <= 0 or grid_w <= 0:
        return image.clone()

    jigsaw_image = image.clone()
    p = small_patch_size

    # 3. Per sample: collect tiles outside the rectangle and permute them.
    #    One .tolist() replaces a tensor comparison (a device sync on GPU)
    #    per grid cell.
    for b, (x1, y1, x2, y2) in enumerate(content_rect.tolist()):
        positions = [
            (i * p, j * p)
            for i in range(grid_h)
            for j in range(grid_w)
            if (i + 1) * p <= y1 or i * p >= y2 or (j + 1) * p <= x1 or j * p >= x2
        ]
        if not positions or not shuffle:
            continue

        perm = torch.randperm(len(positions), device=device).tolist()
        # Read from the untouched `image`, so no per-tile clone is needed.
        for (dst_y, dst_x), src in zip(positions, perm):
            src_y, src_x = positions[src]
            jigsaw_image[b, :, dst_y:dst_y + p, dst_x:dst_x + p] = \
                image[b, :, src_y:src_y + p, src_x:src_x + p]

    return jigsaw_image


import torch

def vae_feature_apply_jigsaw_mask(latents, patch_scale=1/8, mask_ratio=0.5,
                                  shuffle=True, center_offset_x=0.0, center_offset_y=0.0):
    """Randomly zero out square patches of a latent feature map and optionally shuffle them.

    Args:
        latents: Input latent features [B, C, H, W].
        patch_scale: Patch side as a fraction of the latent height H
            (default 1/8). The resulting size must divide both H and W.
        mask_ratio: Each patch is zeroed independently with probability
            ``mask_ratio`` (it is kept when a uniform draw exceeds it).
        shuffle: If True, permute patch positions independently per sample.
        center_offset_x: Accepted for signature compatibility; currently unused.
        center_offset_y: Accepted for signature compatibility; currently unused.

    Returns:
        Tensor [B, C, H, W] with the same dtype and device as ``latents``.

    Raises:
        ValueError: If the derived patch size does not divide H and W.
    """
    B, C, H, W = latents.shape
    device = latents.device

    # Patch side in latent pixels, derived from H only. The tiny epsilon stops
    # float error (e.g. 49 * (1/49) == 0.9999999999999999) from being truncated
    # one step down by int(). Exact and half-way products are unaffected.
    patch_size = max(1, int(H * patch_scale + 1e-6))

    if H % patch_size != 0 or W % patch_size != 0:
        raise ValueError(f"Patch size {patch_size} 不能整除特征图尺寸 {H}x{W}。"
                         f"请调整patch_scale值 (当前: {patch_scale})")

    grid_h, grid_w = H // patch_size, W // patch_size
    num_patches = grid_h * grid_w

    # [B, C, H, W] -> [B, C, gh, gw, p, p] -> [B, gh, gw, C, p, p] -> [B, N, C, p, p]
    patches = latents.unfold(2, patch_size, patch_size).unfold(3, patch_size, patch_size)
    patches = patches.permute(0, 2, 3, 1, 4, 5).contiguous()
    patches = patches.view(B, num_patches, C, patch_size, patch_size)

    # Keep-mask [B, N, 1, 1, 1]; True means the patch is kept.
    mask = torch.rand(B, num_patches, 1, 1, 1, device=device) > mask_ratio

    # masked_fill writes exact zeros, unlike multiplying by 0, which turns
    # inf/NaN features into NaN instead of clearing them.
    masked_patches = patches.masked_fill(~mask, 0)

    if shuffle:
        order = torch.rand(B, num_patches, device=device).argsort(dim=1)
        order = order.view(B, num_patches, 1, 1, 1).expand(-1, -1, C, patch_size, patch_size)
        masked_patches = torch.gather(masked_patches, 1, order)

    # Reassemble: [B, gh, gw, C, p, p] -> [B, C, gh, p, gw, p] -> [B, C, H, W].
    jigsaw_latent = masked_patches.view(B, grid_h, grid_w, C, patch_size, patch_size)
    jigsaw_latent = jigsaw_latent.permute(0, 3, 1, 4, 2, 5).contiguous()
    jigsaw_latent = jigsaw_latent.view(B, C, H, W)

    return jigsaw_latent
