import torch
import numpy as np

def th_uv_grid(H: int, W: int, **tensor_kwargs) -> torch.Tensor:
    '''
    Build a dense grid of integer pixel coordinates.

    :param H: number of rows (grid height)
    :param W: number of columns (grid width)
    :param tensor_kwargs: forwarded to ``Tensor.to`` (e.g. ``device``, ``dtype``)
    :return: (H, W, 2) tensor; ``[..., 0]`` holds the column index u and
        ``[..., 1]`` holds the row index v
    '''
    rows = torch.arange(H).to(**tensor_kwargs)
    cols = torch.arange(W).to(**tensor_kwargs)
    # 'ij' indexing: first output varies along rows, second along columns.
    grid_v, grid_u = torch.meshgrid(rows, cols, indexing='ij')
    return torch.stack((grid_u, grid_v), dim=-1)

def depth_to_xyz(intr, depth):
    '''
    Back-project a depth map into per-pixel 3D points (pinhole model).

    For pixel column ``u`` and row ``v`` with depth ``d``::

        x = (u - cx) / fx * d
        y = (v - cy) / fy * d
        z = d

    :param intr: intrinsics ordered ``(fx, fy, cx, cy)``, shape (4,)
    :param depth: depth map of shape (H, W); a ``np.ndarray`` or ``torch.Tensor``
    :return: shape (H, W, 3), a numpy array or torch tensor matching the type of
        ``depth``; the last axis is ``(x, y, depth)``
    :raises ValueError: if ``depth`` is neither a numpy array nor a torch tensor
    '''
    fx, fy, cx, cy = intr[0], intr[1], intr[2], intr[3]
    if isinstance(depth, np.ndarray):
        # v indexes rows (H), u indexes columns (W).
        v, u = np.meshgrid(np.arange(depth.shape[0]), np.arange(depth.shape[1]), indexing='ij')
        x = (u - cx) / fx * depth
        y = (v - cy) / fy * depth
        return np.stack([x, y, depth], axis=-1)
    elif isinstance(depth, torch.Tensor):
        tensor_kwargs = dict(device=depth.device, dtype=depth.dtype)
        # Pass indexing='ij' explicitly. This matches the numpy branch above and
        # avoids torch's warning about the unspecified (and changing) default.
        v, u = torch.meshgrid(
            torch.arange(depth.shape[0]).to(**tensor_kwargs),
            torch.arange(depth.shape[1]).to(**tensor_kwargs),
            indexing='ij',
        )
        x = (u - cx) / fx * depth
        y = (v - cy) / fy * depth
        return torch.stack([x, y, depth], dim=-1)
    else:
        raise ValueError(f'{type(depth)=}')
