import torch
from einops import rearrange, repeat

def get_position_map_from_depth(depth, intrinsics, extrinsics):
    """Back-project a batch of depth maps into a per-pixel position map.

    Pixel centres are expressed in normalized image coordinates
    (``(index + 0.5) / H``), lifted into camera space with the pinhole model
    ``x = (u - cx) / fx * z``, ``y = (v - cy) / fy * z`` and finally
    transformed with ``extrinsics``.

    Args:
        depth (torch.Tensor): Depth maps with the shape (B, F, 1, H, W).
        intrinsics (torch.Tensor): Camera intrinsics in the OpenCV system,
            indexed as ``[..., row, col]``. Because the pixel grid is divided
            by the image height, the intrinsics must be expressed in the same
            normalized units. Expected shape is presumably (B, F, 3, 3) --
            TODO confirm against the caller.
        extrinsics (torch.Tensor): Camera extrinsics with the shape
            (B, F, 4, 4). Homogeneous camera points are contracted with the
            *first* matrix index (i.e. ``point @ extrinsics``).

    Returns:
        torch.Tensor: The position map with the shape (B, F, H, W, 3).

    Raises:
        ValueError: If ``depth`` is not 5-D or its channel dimension is not 1.
    """
    b, f, channels, h, w = depth.shape
    if channels != 1:
        raise ValueError(
            f"depth must have shape (B, F, 1, H, W), got {tuple(depth.shape)}"
        )

    # NOTE(review): both axes are normalized by the image height, exactly as
    # before. For non-square inputs confirm whether x should use the width.
    resolution = h
    scale = 1.0 * 1 / resolution
    offset = 0.5 * 1 / resolution

    z_cam = depth[:, :, 0]  # b x f x h x w

    x_axis = torch.arange(w, dtype=torch.float32, device=depth.device)
    y_axis = torch.arange(h, dtype=torch.float32, device=depth.device)
    x_axis = x_axis * scale + offset
    y_axis = y_axis * scale + offset

    # x must vary along the last (width) axis and y along the height axis so
    # that the grid lines up with the (H, W) layout of the depth map. The
    # expanded tensors are broadcast views, so no per-batch copy is made.
    x_cam = x_axis.view(1, 1, 1, w).expand(b, f, h, w)
    y_cam = y_axis.view(1, 1, h, 1).expand(b, f, h, w)

    # Two trailing singleton axes let the per-camera scalars broadcast over
    # the spatial dimensions.
    fx = intrinsics[..., 0, 0].unsqueeze(-1).unsqueeze(-1)
    fy = intrinsics[..., 1, 1].unsqueeze(-1).unsqueeze(-1)
    cx = intrinsics[..., 0, 2].unsqueeze(-1).unsqueeze(-1)
    cy = intrinsics[..., 1, 2].unsqueeze(-1).unsqueeze(-1)

    # Pinhole back-projection into camera space.
    x_lift = (x_cam - cx) / fx * z_cam
    y_lift = (y_cam - cy) / fy * z_cam

    # Homogeneous camera-space points: b x f x h x w x 4
    camera_coords = torch.stack(
        [x_lift, y_lift, z_cam, torch.ones_like(z_cam)], dim=-1
    )

    # NOTE(review): this contracts over the first matrix index, which equals
    # ``camera_coords @ extrinsics`` (row-vector convention). Confirm that the
    # caller stores extrinsics accordingly.
    world_coords = torch.einsum(
        "b f i j, b f h w i -> b f h w j", extrinsics, camera_coords
    )  # b x f x h x w x 4

    # Drop the homogeneous coordinate.
    position_map = world_coords[..., :3]

    return position_map