def img_to_patch(x, patch_size, flatten_channels=True):
    """
    Split a batch of images into non-overlapping square patches.

    Args:
        x: Tensor representing the image of shape [B, C, H, W]. H and W are
           expected to be multiples of patch_size; otherwise the element count
           no longer matches and the reshape below raises.
        patch_size: Side length of each square patch in pixels (integer).
        flatten_channels: If True, each patch is returned as a flat feature
                          vector; otherwise it keeps its [C, p, p] layout.

    Returns:
        Tensor of shape [B, (H // p) * (W // p), C * p * p] when
        flatten_channels is True, else [B, (H // p) * (W // p), C, p, p].
        Patches are ordered row-major over the patch grid.
    """
    batch, channels, height, width = x.shape
    p = patch_size
    grid_h, grid_w = height // p, width // p

    # Split each spatial axis into (grid index, offset inside the patch).
    tiles = x.reshape(batch, channels, grid_h, p, grid_w, p)
    # Bring the grid axes in front of the channel axis: [B, H', W', C, p, p]
    tiles = tiles.permute(0, 2, 4, 1, 3, 5)
    # Merge the two grid axes into one patch axis: [B, H'*W', C, p, p]
    tiles = tiles.flatten(1, 2)
    # Optionally collapse each patch into a vector: [B, H'*W', C*p*p]
    return tiles.flatten(2, 4) if flatten_channels else tiles


def img_to_overlapping_patch(x, patch_size, flatten_channels=True, overlap=0.5):
    """
    Split a batch of images into square patches that may overlap.

    Args:
        x: Tensor representing the image of shape [B, C, H, W]
        patch_size: Number of pixels per dimension of the patches (integer)
        flatten_channels: If True, the patches will be returned in a flattened format
                           as a feature vector instead of an image grid.
        overlap: Fraction of patch_size shared by consecutive patches. The
                 stride is int(patch_size * (1 - overlap)), i.e. truncated
                 toward zero, so it must work out to at least 1 pixel.

    Returns:
        Tensor of shape [B, n_h * n_w, C * pH * pW] when flatten_channels is
        True, else [B, n_h * n_w, C, pH, pW], where
        n_h = (H - patch_size) // stride + 1 (likewise n_w for W).
        Patches are ordered row-major over the patch grid.

    Raises:
        ValueError: if the computed stride is smaller than 1 (e.g. overlap >= 1).
    """
    B, C, H, W = x.shape
    # NOTE: int() truncates, and float error can push an "exact" product just
    # below an integer (patch_size=10, overlap=0.9 -> 0.9999999999999998 -> 0).
    # The same formula is used by compute_num_patches, so keep them in sync.
    step_size = int(patch_size * (1 - overlap))
    if step_size < 1:
        # unfold needs a positive step; fail here with an actionable message
        # instead of an opaque error from inside unfold.
        raise ValueError(
            f"overlap={overlap} with patch_size={patch_size} gives a stride of "
            f"{step_size}; the stride must be at least 1"
        )

    # Extract patches using unfold. Each call appends a new trailing axis.
    x = x.unfold(2, patch_size, step_size)  # [B, C, n_h, W, pH]
    x = x.unfold(3, patch_size, step_size)  # [B, C, n_h, n_w, pH, pW]

    # Permute dimensions to [B, n_h, n_w, C, pH, pW]
    x = x.permute(0, 2, 3, 1, 4, 5)

    # Flatten the patch indices into a single dimension
    x = x.flatten(1, 2)  # [B, num_patches, C, pH, pW]

    if flatten_channels:
        x = x.flatten(2, 4)  # [B, num_patches, C*pH*pW]

    return x


def compute_num_patches(patch_size, overlap, H=224, W=224):
    """
    Count the sliding-window patches for an H x W image.

    The stride is derived exactly as in img_to_overlapping_patch:
    int(patch_size * (1 - overlap)), truncated toward zero.

    Args:
        patch_size: Side length of each square patch in pixels (integer).
        overlap: Fraction of patch_size shared by consecutive patches.
        H: Image height in pixels.
        W: Image width in pixels.

    Returns:
        Tuple (total_patches, patches_per_row), where
        total_patches = n_h * n_w and patches_per_row = n_w. A dimension
        smaller than patch_size contributes 0 patches along that axis.

    Raises:
        ValueError: if the computed stride is smaller than 1 (e.g. overlap >= 1).
    """
    step_size = int(patch_size * (1 - overlap))
    if step_size < 1:
        # A zero stride would divide by zero, and a negative stride would
        # silently produce a meaningless count.
        raise ValueError(
            f"overlap={overlap} with patch_size={patch_size} gives a stride of "
            f"{step_size}; the stride must be at least 1"
        )
    # Clamp at 0: when patch_size exceeds the image, floor division would
    # otherwise yield negative per-axis counts (whose product can look valid).
    num_h = max(0, (H - patch_size) // step_size + 1)
    num_w = max(0, (W - patch_size) // step_size + 1)
    return num_h * num_w, num_w