import torch

def XYZMapHW(points, aug, intrinsics, height, width):
    r"""Project a point cloud into the image and flag the points that land inside it.

    Args:
        points: torch.Tensor (N, 3), input point cloud
        aug: torch.Tensor (3, 3), right-multiplied onto the points to remove augmentation (if exists) for image mapping
        intrinsics: torch.Tensor (3, 3) or (3, 4), intrinsics of camera
        height: int, height of image
        width: int, width of image

    Returns:
        hw: torch.Tensor (N, 2), int32 (row, column) pixel of each point; rows whose mask is False are not meaningful
        mask: torch.Tensor (N), bool valid mask; True only when the point has positive depth,
            a finite projection, and a rounded pixel inside [0, height) x [0, width)

    Raises:
        ValueError: if intrinsics is not a 2-D tensor with 3 or 4 columns
    """

    if intrinsics.dim() != 2 or intrinsics.shape[1] not in (3, 4):
        raise ValueError(
            "intrinsics must have shape (3, 3) or (3, 4), got {}".format(tuple(intrinsics.shape))
        )

    # remove the augmentation
    points = points @ aug  # (R @ P')', R is the transpose of aug_rot

    # point cloud maps to image
    if intrinsics.shape[1] == 3:  # Intrinsics matrix
        depth = points[:, 2]
        points = points.T / points.T[2, :]
        uv = torch.matmul(intrinsics, points)
    else:  # Camera matrix
        ones = torch.ones((points.shape[0], 1), device=points.device, dtype=points.dtype)
        points = torch.cat([points, ones], dim=-1).T
        uv = torch.matmul(intrinsics, points)
        depth = uv[2, :]
        uv = uv[:3, :] / depth

    # (u, v) -> (row, column)
    hw_float = uv[[1, 0], :].T.round()
    hw = hw_float.to(torch.int32)

    # valid mask
    # Bounds are tested on the float values: casting NaN/inf or values beyond the
    # int32 range has no defined result and could otherwise slip into the image.
    # Points with non-positive depth are rejected because the perspective divide
    # would mirror them onto the image plane.
    mask = torch.isfinite(hw_float).all(dim=1) & (depth > 0)
    mask = mask & (hw_float[:, 0] >= 0) & (hw_float[:, 0] < height)
    mask = mask & (hw_float[:, 1] >= 0) & (hw_float[:, 1] < width)

    return hw, mask

def PatchMapping(points_hw, height, width, patch_feats, radius = 0):
    r"""Collect the corresponding local features from the patch feature maps.

    Each image pixel is rescaled to a feature-map cell (truncating toward zero) and the
    (2 * radius + 1)-wide square window centred on that cell is returned. The feature
    map is zero-padded by ``radius`` so windows near the border keep their full size.

    Args:
        points_hw: torch.Tensor (N, 2), input hw mappings
        height: int, height of image
        width: int, width of image
        patch_feats: torch.Tensor (C, H, W), input patch feature maps
        radius: int, radius of local window, set to 0 for one-to-one mapping

    Returns:
        torch.Tensor (N, C, R * 2 + 1, R * 2 + 1), collected local features, same dtype as patch_feats

    Raises:
        ValueError: if a point maps outside the feature map
    """

    feat_h, feat_w = patch_feats.shape[1], patch_feats.shape[2]

    # image pixel -> feature-map cell
    patch_idxs = points_hw.float()
    patch_idxs = patch_idxs / torch.tensor([height, width], device=patch_idxs.device).unsqueeze(0)
    patch_idxs = patch_idxs * torch.tensor([feat_h, feat_w], device=patch_idxs.device).unsqueeze(0)
    patch_idxs = patch_idxs.long()

    # Fail loudly instead of letting negative indices wrap around silently.
    upper = torch.tensor([feat_h, feat_w], device=patch_idxs.device)
    if bool(((patch_idxs < 0) | (patch_idxs >= upper)).any()):
        raise ValueError("points_hw contains points that map outside the patch feature map")

    if radius > 0:
        # zero padding on both spatial dims; keeps the dtype/device of patch_feats
        patch_feats = torch.nn.functional.pad(patch_feats, (radius, radius, radius, radius))
        patch_idxs = patch_idxs + radius

    # Gather every window in one indexing call rather than looping over the points.
    patch_idxs = patch_idxs.to(patch_feats.device)
    offsets = torch.arange(-radius, radius + 1, device=patch_feats.device)
    rows = patch_idxs[:, 0, None, None] + offsets[None, :, None]  # (N, K, 1)
    cols = patch_idxs[:, 1, None, None] + offsets[None, None, :]  # (N, 1, K)
    local_feats = patch_feats[:, rows, cols]  # (C, N, K, K)

    return local_feats.permute(1, 0, 2, 3).contiguous()

def normalize_hw(hw, size):
    r"""Centre hw coordinates on the middle of ``size`` and scale by half of its largest side.

    Args:
        hw: torch.Tensor (N, 2), coordinates to normalize (converted to float)
        size: torch.Tensor or sequence, extent along each coordinate axis; a non-tensor
            is converted to a tensor on hw's device with the float dtype

    Returns:
        torch.Tensor, (hw - size / 2) / (max(size) / 2)
    """
    coords = hw.float()
    if isinstance(size, torch.Tensor):
        extent = size
    else:
        extent = torch.tensor(size, device=coords.device, dtype=coords.dtype)
    center = extent / 2
    # a single scale (half of the longest side) is shared by both axes
    half_longest = extent.max(-1).values / 2
    return (coords - center[None, :]) / half_longest

def rotate_half(x: torch.Tensor) -> torch.Tensor:
    r"""Map every consecutive pair (a, b) along the last dimension to (-b, a).

    The last dimension is split into pairs, so its size must be even; the output has
    the same shape as the input.
    """
    pairs = x.unflatten(-1, (-1, 2))  # [..., C // 2, 2]
    first = pairs[..., 0]
    second = pairs[..., 1]
    rotated = torch.stack((-second, first), dim=-1)
    return rotated.flatten(start_dim=-2)

def apply_cached_rotary_emb(
        freqs: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
    r"""Blend ``t`` with its pair-rotated copy using the two cached tables in ``freqs``.

    ``freqs[0]`` multiplies ``t`` and ``freqs[1]`` multiplies ``rotate_half(t)``.
    NOTE(review): these are presumably the cached cos and sin tables - confirm against the caller.
    """
    direct_term = t * freqs[0]
    rotated_term = rotate_half(t) * freqs[1]
    return direct_term + rotated_term