import torch
from torch import Tensor


def map_coordinates_cell_top_left(
    h: int,
    w: int,
    cell_height: float,
    cell_width: float,
    device=None,
) -> Tensor:
    """
    Build a (2, h, w) tensor holding the top-left corner coordinates of each grid cell.

    Channel 0 holds x coordinates (column index * cell_width) and channel 1 holds
    y coordinates (row index * cell_height).

    Example: for h=2, w=3, cell_height=1, cell_width=1 the output is

    [[[0., 1., 2.],
      [0., 1., 2.]],
     [[0., 0., 0.],
      [1., 1., 1.]]]

    Args:
        h (int): Number of rows in the grid.
        w (int): Number of columns in the grid.
        cell_height (float): Height of one cell; scales the y coordinates.
        cell_width (float): Width of one cell; scales the x coordinates.
        device: Device on which to allocate the result (default: torch's default device).

    Returns:
        Tensor: Float tensor of shape (2, h, w).
    """
    coords_h = torch.arange(h, device=device).float() * cell_height
    coords_w = torch.arange(w, device=device).float() * cell_width

    # indexing="xy" with (coords_w, coords_h) yields (h, w) grids where
    # x varies along columns and y varies along rows.
    x_coords, y_coords = torch.meshgrid([coords_w, coords_h], indexing="xy")
    return torch.stack((x_coords, y_coords), dim=0)


def map_coordinates_cell_center(
    h: int, w: int, cell_height: float, cell_width: float, device=None
) -> Tensor:
    """
    Build a (2, h, w) tensor holding the center coordinates of each grid cell.

    Channel 0 holds x coordinates ((column index + 0.5) * cell_width) and channel 1
    holds y coordinates ((row index + 0.5) * cell_height).

    Example: for h=3, w=4, cell_height=1, cell_width=1 the output is

    [[[0.5, 1.5, 2.5, 3.5],
      [0.5, 1.5, 2.5, 3.5],
      [0.5, 1.5, 2.5, 3.5]],
     [[0.5, 0.5, 0.5, 0.5],
      [1.5, 1.5, 1.5, 1.5],
      [2.5, 2.5, 2.5, 2.5]]]

    Args:
        h (int): Number of rows in the grid.
        w (int): Number of columns in the grid.
        cell_height (float): Height of one cell; scales the y coordinates.
        cell_width (float): Width of one cell; scales the x coordinates.
        device: Device on which to allocate the result (default: torch's default device).

    Returns:
        Tensor: Float tensor of shape (2, h, w).
    """
    coords_h = (torch.arange(h, device=device).float() + 0.5) * cell_height
    coords_w = (torch.arange(w, device=device).float() + 0.5) * cell_width

    # indexing="xy" with (coords_w, coords_h) yields (h, w) grids where
    # x varies along columns and y varies along rows.
    x_coords, y_coords = torch.meshgrid([coords_w, coords_h], indexing="xy")
    return torch.stack((x_coords, y_coords), dim=0)


def create_boolean_map_from_idx(
    idx: Tensor,
    h: int,
    w: int,
) -> Tensor:
    """
    Return an (h, w) float map with 1.0 at every (x, y) position listed in `idx`
    and 0.0 elsewhere.

    Args:
        idx (Tensor): Long tensor of shape (N, 2); column 0 is the x (column) index
            in [0, w) and column 1 is the y (row) index in [0, h). N may be 0.
        h (int): Height of the output grid.
        w (int): Width of the output grid.

    Returns:
        Tensor: Tensor of shape (h, w) with torch's default float dtype, on the same
        device as `idx`. Duplicate indices are allowed.

    Raises:
        ValueError: If `idx` is not of dtype long, is not of shape (N, 2), or has
            any index outside the grid.
    """
    if idx.dtype != torch.long:
        raise ValueError("Expected input to be a tensor of type long.")
    if len(idx.shape) != 2 or idx.shape[1] != 2:
        raise ValueError("Expected input to be a tensor of shape (N, 2).")
    xs = idx[:, 0]
    ys = idx[:, 1]
    if torch.any(xs < 0) or torch.any(xs >= w):
        raise ValueError(
            f"Expected x idx to be in [0, {w}), but found [{idx[..., 0].min()},{idx[..., 0].max()}]."
        )
    if torch.any(ys < 0) or torch.any(ys >= h):
        raise ValueError(
            f"Expected y idx to be in [0, {h}), but found [{idx[..., 1].min()},{idx[..., 1].max()}]."
        )
    grid = torch.zeros((h, w), device=idx.device)
    # Rows are indexed by y, columns by x.
    grid[ys, xs] = 1.0
    return grid
