# Copyright (c) OpenMMLab. All rights reserved.
import torch

from mmdet3d.core.bbox import points_cam2img


def get_keypoints(gt_bboxes_3d_list,
                  centers2d_list,
                  img_metas,
                  use_local_coords=True):
    """Project 3D box keypoints onto the image and flag their visibility.

    For every box, ten 3D keypoints are assembled: the eight box corners
    (indices 0-7) followed by the mean of corners 0, 1, 4, 5 (index 8, named
    the "top" center in the code) and the mean of corners 2, 3, 6, 7
    (index 9, the "bottom" center). The keypoints are projected with the
    image's ``cam2img`` matrix. A keypoint counts as visible when its
    projection lies inside the image and the last coordinate of the 3D point
    is positive (i.e. in front of the camera).

    Args:
        gt_bboxes_3d_list (list[object]): Ground truth 3D boxes of each
            image. Each item must expose a ``corners`` attribute of shape
            (num_gt, 8, 3); presumably :obj:`CameraInstance3DBoxes` --
            confirm against the caller.
        centers2d_list (list[Tensor]): Projected 3D centers onto 2D image,
            shape (num_gt, 2).
        img_metas (list[dict]): Meta information of each image. The keys
            ``'img_shape'`` (first two entries are height and width) and
            ``'cam2img'`` are read. Must hold at least one entry per image.
        use_local_coords (bool, optional): Whether to express the 2D
            keypoints relative to the matching entry of ``centers2d_list``
            instead of in absolute image coordinates. Default: True.

    Returns:
        tuple[list[Tensor]]: It contains two lists with one element per
        image.

        - keypoints2d_list: Tensors of shape (num_gt, 10, 3) holding the
          projected x, y of each keypoint and a 0/1 float visibility flag.
        - keypoints_depth_mask_list: Boolean tensors of shape (num_gt, 3).
          Column 0 is True when keypoints 8 and 9 are both visible,
          column 1 when corners 0, 3, 5, 6 are all visible and column 2
          when corners 1, 2, 4, 7 are all visible.
    """
    num_imgs = len(gt_bboxes_3d_list)
    assert num_imgs == len(centers2d_list), \
        'gt_bboxes_3d_list and centers2d_list must have the same length'
    # Fail early with a clear message instead of an IndexError mid-loop.
    assert len(img_metas) >= num_imgs, \
        'img_metas must provide one entry per image'

    keypoints2d_list = []
    keypoints_depth_mask_list = []

    for gt_bboxes_3d, centers2d, img_meta in zip(gt_bboxes_3d_list,
                                                 centers2d_list, img_metas):
        img_shape = img_meta['img_shape']
        cam2img = img_meta['cam2img']
        h, w = img_shape[:2]

        # (N, 8, 3)
        corners3d = gt_bboxes_3d.corners
        top_centers3d = torch.mean(corners3d[:, [0, 1, 4, 5], :], dim=1)
        bot_centers3d = torch.mean(corners3d[:, [2, 3, 6, 7], :], dim=1)
        # (N, 2, 3)
        top_bot_centers3d = torch.stack((top_centers3d, bot_centers3d), dim=1)
        # (N, 10, 3): corners first, then the top and bottom centers
        keypoints3d = torch.cat((corners3d, top_bot_centers3d), dim=1)
        # (N, 10, 2)
        keypoints2d = points_cam2img(keypoints3d, cam2img)

        # Visibility: the projection must fall inside the image and the
        # 3D point must be in front of the camera (positive last coordinate).
        keypoints_x_visible = (keypoints2d[..., 0] >= 0) & (
            keypoints2d[..., 0] <= w - 1)
        keypoints_y_visible = (keypoints2d[..., 1] >= 0) & (
            keypoints2d[..., 1] <= h - 1)
        keypoints_z_visible = (keypoints3d[..., -1] > 0)

        # (N, 10)
        keypoints_visible = keypoints_x_visible & \
            keypoints_y_visible & keypoints_z_visible

        # (N, 3): one flag per keypoint group, True only when every keypoint
        # of the group is visible. Groups, in order: center (8, 9),
        # "diag-02" (0, 3, 5, 6) and "diag-13" (1, 2, 4, 7).
        keypoints_depth_valid = torch.stack(
            (keypoints_visible[:, [8, 9]].all(dim=1),
             keypoints_visible[:, [0, 3, 5, 6]].all(dim=1),
             keypoints_visible[:, [1, 2, 4, 7]].all(dim=1)),
            dim=1)

        # Float flag so it can be concatenated with the coordinates.
        keypoints_visible = keypoints_visible.float().unsqueeze(-1)

        if use_local_coords:
            # Offsets from each object's projected center; centers2d is
            # broadcast over the keypoint dimension.
            keypoints2d = keypoints2d - centers2d.unsqueeze(1)

        # (N, 10, 3): x, y, visibility
        keypoints2d = torch.cat((keypoints2d, keypoints_visible), dim=2)

        keypoints2d_list.append(keypoints2d)
        keypoints_depth_mask_list.append(keypoints_depth_valid)

    return (keypoints2d_list, keypoints_depth_mask_list)
