import torch
from torch import Tensor


def batched_insert_2d(to_insert: Tensor, i_idx: Tensor, j_idx: Tensor, output: Tensor) -> Tensor:
    """
    Write a batch of small 2D matrices into per-sample windows of larger 2D matrices.

    For every batch element ``n``, the ``(h, w)`` matrix ``to_insert[n]`` is written into
    ``output[n]`` with its top-left corner at row ``i_idx[n]`` and column ``j_idx[n]``.
    ``output`` is modified in place.

    No bounds checking is performed here: offsets are passed straight to tensor
    indexing, so callers are expected to keep every window inside ``output``.

    Args:
        to_insert (Tensor): A tensor of shape (N, h, w) containing the smaller matrices to insert.
        i_idx (Tensor): A tensor of shape (N,) with the top row index of each insertion.
        j_idx (Tensor): A tensor of shape (N,) with the left column index of each insertion.
        output (Tensor): A tensor of shape (N, H, W) where the smaller matrices will be inserted.

    Returns:
        Tensor: ``output`` itself (the same object, updated in place), for convenient chaining.
    """

    bs, h, w = to_insert.size(0), to_insert.size(1), to_insert.size(2)

    # Offsets 0..h-1 and 0..w-1 inside a single small matrix.
    rows, cols = (
        torch.arange(h, device=to_insert.device),
        torch.arange(w, device=to_insert.device),
    )
    # Shift the local offsets by each sample's top-left corner:
    # row_indices broadcasts to (N, h, 1), col_indices to (N, 1, w).
    row_indices = rows[None, :, None] + i_idx[:, None, None]
    col_indices = cols[None, None, :] + j_idx[:, None, None]

    # Batch indices, shaped (N, 1, 1) so all three index tensors broadcast to (N, h, w).
    batch_indices = torch.arange(bs, device=to_insert.device)[:, None, None]

    # Embed values in output (in-place advanced-index assignment).
    output[batch_indices, row_indices, col_indices] = to_insert

    # Return the updated tensor, as the documented contract promises.
    return output


def main():
    """Demo: insert two 2x3 patches into a batch of two zeroed 5x5 matrices and print the result."""
    output = torch.zeros((2, 5, 5))
    # One (row, column) top-left corner per batch element.
    idx_2d = torch.tensor([[1, 2], [3, 0]])
    to_insert = torch.arange(2 * 2 * 3, dtype=output.dtype).view(2, 2, 3)
    # batched_insert_2d takes the row and column offsets as separate tensors,
    # so split the (N, 2) index tensor into its two columns.
    batched_insert_2d(
        to_insert, i_idx=idx_2d[:, 0], j_idx=idx_2d[:, 1], output=output
    )
    print(output)


# Run the demo only when this file is executed as a script, not when it is imported.
if __name__ == "__main__":
    main()
