import torch
from torch import nn, Tensor


def batched_positions_to_image(positions, h, w, inv_pixel_size):
    """
    Creates binary images with 1s at given positions.

    Args:
        positions: a sequence of ``bs`` per-sample tensors (or a tensor whose
            first dimension is the batch). Each sample is indexed as
            ``sample[:, :2]``: column 0 is the x coordinate (width axis) and
            column 1 the y coordinate (height axis); extra columns are ignored.
            A sample may contain zero rows.
        h: height of the output images (number of rows).
        w: width of the output images (number of columns).
        inv_pixel_size: factor multiplied onto ``sample[:, :2]`` to convert
            positions into pixel units; a scalar or anything broadcastable
            against ``(N, 2)`` (e.g. a per-axis ``(sx, sy)`` pair).

    Returns:
        Tensor of shape ``(bs, h, w)`` with the dtype and device of
        ``positions[0]``, filled with 0 except for 1 at every marked pixel.
        Positions that fall outside the image are clamped onto the nearest
        border pixel rather than dropped. For an empty batch, an empty
        ``(0, h, w)`` tensor with torch's default dtype on the CPU is returned.
    """
    bs = len(positions)
    if bs == 0:
        # No sample to take device/dtype from; return an empty batch instead of
        # failing with IndexError on positions[0].
        return torch.zeros((0, h, w))

    device, dtype = positions[0].device, positions[0].dtype
    img = torch.zeros((bs, h, w), dtype=dtype, device=device)

    for i, sample in enumerate(positions):
        # Scale to pixel units, then cast to integer indices (the cast
        # truncates the fractional part).
        idx_2d = (sample[:, :2] * inv_pixel_size).to(dtype=torch.int64)
        idx_2d[:, 0] = idx_2d[:, 0].clamp(0, w - 1)  # Clamp within image width
        idx_2d[:, 1] = idx_2d[:, 1].clamp(0, h - 1)  # Clamp within image height
        # Row index comes from y (column 1), column index from x (column 0).
        img[i, idx_2d[:, 1], idx_2d[:, 0]] = 1.0  # Mark detections

    return img


class DetectionLoss(nn.Module):
    """Binary cross-entropy between a predicted map and rasterised GT positions.

    Ground-truth positions are rendered into a binary target image of size
    ``(h, w)`` with ``batched_positions_to_image`` and compared against the
    prediction using ``nn.BCELoss`` (default mean reduction).
    """

    def __init__(self, img_size: "tuple[int, int]", inv_pixel_size: Tensor):
        """
        Args:
            img_size: ``(h, w)`` of the target image.
            inv_pixel_size: factor converting position units into pixel
                indices. Anything accepted by ``torch.as_tensor`` works (a
                Tensor, a float, or a per-axis ``(sx, sy)`` pair). It is
                stored as a non-persistent buffer, so it follows
                ``.to(device)`` but is not written to ``state_dict``.
        """
        super().__init__()
        # nn.BCELoss expects m_map to already hold probabilities in [0, 1].
        self.bce = nn.BCELoss()
        self.h, self.w = img_size
        # register_buffer accepts only Tensors; as_tensor lets callers pass
        # plain numbers and returns an existing Tensor unchanged.
        self.register_buffer(
            "inv_pixel_size", torch.as_tensor(inv_pixel_size), persistent=False
        )

    def forward(self, m_map: Tensor, x_gt: Tensor) -> Tensor:
        """Compute the BCE loss between ``m_map`` and the rasterised ``x_gt``.

        Args:
            m_map: predicted map; BCELoss requires it to have the same shape
                as the target, i.e. ``(bs, h, w)``.
            x_gt: batch of per-sample position tensors (a padded tensor or any
                sequence of tensors) as accepted by
                ``batched_positions_to_image``.

        Returns:
            Scalar loss tensor.
        """
        m_gt = batched_positions_to_image(
            positions=x_gt,
            h=self.h,
            w=self.w,
            inv_pixel_size=self.inv_pixel_size,
        )
        # The target inherits dtype/device from x_gt; BCELoss needs it to match
        # the prediction (e.g. float64 or integer positions vs. a float32 map).
        m_gt = m_gt.to(dtype=m_map.dtype, device=m_map.device)
        return self.bce(m_map, m_gt)
