import torch
import torch.nn.functional as F


def extract_patch(image, centers, patch_size=7):
    """Bilinearly sample a square patch around one center per batch item.

    Coordinates follow the ``F.grid_sample`` convention with
    ``align_corners=False``: ``(-1, -1)`` is the outer corner of the top-left
    pixel and ``(1, 1)`` is the outer corner of the bottom-right pixel, so one
    pixel spans ``2 / W`` horizontally and ``2 / H`` vertically.

    Neighbouring samples in the patch are exactly one pixel apart and the
    patch is symmetric about the center. For odd ``patch_size`` the middle
    sample sits on the center itself; for even ``patch_size`` the center
    falls halfway between the two middle samples.

    Args:
        image: Tensor of shape (B, C, H, W).
        centers: Tensor holding one ``(x, y)`` pair per batch item, i.e.
            ``B * 2`` values, normally shaped (B, 2), in normalized
            coordinates in [-1, 1]. It is converted to the dtype and device
            of ``image`` before use.
        patch_size: Side length of the patch in pixels (default 7 for 7x7).

    Returns:
        Tensor of shape (B, C, patch_size, patch_size). Samples falling
        outside the image use ``grid_sample``'s default padding mode.

    Raises:
        ValueError: If ``patch_size`` is smaller than 1.
    """
    if patch_size < 1:
        raise ValueError(f"patch_size must be a positive integer, got {patch_size}")

    B, _, H, W = image.shape

    # grid_sample requires the grid to share the input's dtype and device, so
    # align centers with the image first. reshape (rather than view) also
    # accepts non-contiguous tensors. Result shape: (B, 1, 1, 2).
    centers = centers.to(device=image.device, dtype=image.dtype).reshape(B, 1, 1, 2)

    # Size of one pixel in normalized units (the full extent [-1, 1] has length 2).
    dx = 2.0 / W
    dy = 2.0 / H

    # Offsets in whole-pixel steps, centered on zero. For patch_size=7 this is
    # -3..3; for patch_size=6 it is -2.5..2.5. Using arange(-half, half + 1)
    # would yield patch_size + 1 samples whenever patch_size is even.
    steps = torch.arange(patch_size, device=image.device, dtype=image.dtype)
    steps = steps - (patch_size - 1) / 2
    offsets_x = steps * dx
    offsets_y = steps * dy

    # 'ij' indexing returns (rows, cols) = (y, x); grid_sample expects the last
    # dimension ordered as (x, y), hence the swap when stacking.
    grid_y, grid_x = torch.meshgrid(offsets_y, offsets_x, indexing="ij")
    offset_grid = torch.stack((grid_x, grid_y), dim=-1)  # (patch_size, patch_size, 2)

    # Broadcast (B, 1, 1, 2) + (patch_size, patch_size, 2)
    # -> (B, patch_size, patch_size, 2).
    sampling_grid = centers + offset_grid

    # align_corners=False matches the pixel-corner convention used for dx/dy.
    return F.grid_sample(image, sampling_grid, mode="bilinear", align_corners=False)


# Demo: run this file directly to extract one patch per image from a random batch.
if __name__ == "__main__":
    # Random batch: 2 images, 3 channels each, 32x32 pixels.
    demo_batch = torch.randn(2, 3, 32, 32)

    # One (x, y) center per image in normalized [-1, 1] coordinates:
    # the first is (0, 0), the second is (0.5, -0.5).
    demo_centers = torch.tensor(
        [
            [0.0, 0.0],
            [0.5, -0.5],
        ],
        dtype=torch.float32,
    )

    # Default patch_size is 7, so the expected output is torch.Size([2, 3, 7, 7]).
    print(extract_patch(demo_batch, demo_centers).shape)
