import torch
from torch import Tensor


def get_img_extent(h: int, w: int, pixel_size: Tensor):
    """Return ``[lower, upper]`` bounds of an ``h`` x ``w`` pixel grid scaled by ``pixel_size``.

    The pixel counts ``[h, w]`` are multiplied by ``pixel_size`` and paired with a
    zero lower bound along a new trailing axis. For a scalar or 2-element
    ``pixel_size`` the result is
    ``[[0, h * pixel_size[0]], [0, w * pixel_size[1]]]``.

    Args:
        h: Pixel count along the first axis.
        w: Pixel count along the second axis.
        pixel_size: Scale factor(s); must broadcast against a length-2 tensor.
            Python numbers are accepted as well.

    Returns:
        Tensor of shape ``broadcast(pixel_size.shape, (2,)) + (2,)`` whose last
        axis holds ``(0, size)``, located on ``pixel_size``'s device.
    """
    # as_tensor returns an existing tensor as-is (no copy, autograd preserved).
    pixel_size = torch.as_tensor(pixel_size)
    # Allocate the pixel counts on pixel_size's device so the product does not
    # fail with a device mismatch (the legacy ``Tensor([h, w])`` always lived on
    # the CPU). Keep the default float dtype that ``Tensor(...)`` used so dtype
    # promotion with pixel_size is unchanged.
    pixel_counts = torch.tensor(
        [h, w], dtype=torch.get_default_dtype(), device=pixel_size.device
    )
    img_size = pixel_counts * pixel_size
    return torch.stack([torch.zeros_like(img_size), img_size], dim=-1)


def get_vol_extent(h: int, w: int, pixel_size: Tensor, z_extent: Tensor):
    """Append ``z_extent`` as an extra row below the image extent.

    The image extent comes from ``get_img_extent(h, w, pixel_size)``. ``z_extent``
    gets a new leading axis and is concatenated along dim 0. For a scalar or
    2-element ``pixel_size`` and a 2-element ``z_extent``, the result has shape
    ``(3, 2)``.

    Args:
        h: Pixel count along the first axis.
        w: Pixel count along the second axis.
        pixel_size: Scale factor(s) forwarded to ``get_img_extent``.
        z_extent: Bounds for the extra axis; assumed to be ``(lower, upper)``
            (TODO confirm with callers). Must share a device with the image extent,
            because ``torch.cat`` requires that.

    Returns:
        The image extent rows followed by one ``z_extent`` row.
    """
    planar = get_img_extent(h=h, w=w, pixel_size=pixel_size)
    depth_row = z_extent.unsqueeze(0)
    return torch.cat((planar, depth_row), dim=0)
