# Copyright (c) OpenMMLab. All rights reserved.
import numpy as np
import torch
from torch.nn import functional as F

from mmdet.core.bbox import BaseBBoxCoder
from mmdet.core.bbox.builder import BBOX_CODERS


@BBOX_CODERS.register_module()
class MonoFlexCoder(BaseBBoxCoder):
    """Bbox Coder for MonoFlex.

    Args:
        depth_mode (str): The mode for depth calculation.
            Available options are "linear", "inv_sigmoid", and "exp".
        base_depth (tuple[float]): References for decoding box depth.
        depth_range (list): Depth range of predicted depth.
        combine_depth (bool): Whether to use combined depth (direct depth
            and depth from keypoints) or use direct depth only.
        uncertainty_range (list): Uncertainty range of predicted depth.
        base_dims (tuple[tuple[float]]): Dimensions mean and std of decode bbox
            dimensions [l, h, w] for each category.
        dims_mode (str): The mode for dimension calculation.
            Available options are "linear" and "exp".
        multibin (bool): Whether to use multibin representation.
        num_dir_bins (int): Number of Number of bins to encode
            direction angle.
        bin_centers (list[float]): Local yaw centers while using multibin
            representations.
        bin_margin (float): Margin of multibin representations.
        code_size (int): The dimension of boxes to be encoded.
        eps (float, optional): A value added to the denominator for numerical
            stability. Default 1e-3.
    """

    def __init__(self,
                 depth_mode,
                 base_depth,
                 depth_range,
                 combine_depth,
                 uncertainty_range,
                 base_dims,
                 dims_mode,
                 multibin,
                 num_dir_bins,
                 bin_centers,
                 bin_margin,
                 code_size,
                 eps=1e-3):
        super(MonoFlexCoder, self).__init__()

        # depth related
        self.depth_mode = depth_mode
        self.base_depth = base_depth
        self.depth_range = depth_range
        self.combine_depth = combine_depth
        self.uncertainty_range = uncertainty_range

        # dimensions related
        self.base_dims = base_dims
        self.dims_mode = dims_mode

        # orientation related
        self.multibin = multibin
        self.num_dir_bins = num_dir_bins
        self.bin_centers = bin_centers
        self.bin_margin = bin_margin

        # output related
        self.bbox_code_size = code_size
        self.eps = eps

    def encode(self, gt_bboxes_3d):
        """Encode ground truth to prediction targets.

        Args:
            gt_bboxes_3d (`BaseInstance3DBoxes`): Ground truth 3D bboxes.
                shape: (N, 7).

        Returns:
            torch.Tensor: Targets of orientations.
        """
        local_yaw = gt_bboxes_3d.local_yaw
        # encode local yaw (-pi ~ pi) to multibin format
        encode_local_yaw = local_yaw.new_zeros(
            [local_yaw.shape[0], self.num_dir_bins * 2])
        bin_size = 2 * np.pi / self.num_dir_bins
        margin_size = bin_size * self.bin_margin

        bin_centers = local_yaw.new_tensor(self.bin_centers)
        range_size = bin_size / 2 + margin_size

        offsets = local_yaw.unsqueeze(1) - bin_centers.unsqueeze(0)
        offsets[offsets > np.pi] = offsets[offsets > np.pi] - 2 * np.pi
        offsets[offsets < -np.pi] = offsets[offsets < -np.pi] + 2 * np.pi

        for i in range(self.num_dir_bins):
            offset = offsets[:, i]
            inds = abs(offset) < range_size
            encode_local_yaw[inds, i] = 1
            encode_local_yaw[inds, i + self.num_dir_bins] = offset[inds]

        orientation_target = encode_local_yaw

        return orientation_target

    def decode(self, bbox, base_centers2d, labels, downsample_ratio, cam2imgs):
        """Decode bounding box regression into 3D predictions.

        Args:
            bbox (Tensor): Raw bounding box predictions for each
                predict center2d point.
                shape: (N, C)
            base_centers2d (torch.Tensor): Base centers2d for 3D bboxes.
                shape: (N, 2).
            labels (Tensor): Batch predict class label for each predict
                center2d point.
                shape: (N, )
            downsample_ratio (int): The stride of feature map.
            cam2imgs (Tensor): Batch images' camera intrinsic matrix.
                shape: kitti (N, 4, 4)  nuscenes (N, 3, 3)

        Return:
            dict: The 3D prediction dict decoded from regression map.
            the dict has components below:
                - bboxes2d (torch.Tensor): Decoded [x1, y1, x2, y2] format
                    2D bboxes.
                - dimensions (torch.Tensor): Decoded dimensions for each
                    object.
                - offsets2d (torch.Tenosr): Offsets between base centers2d
                    and real centers2d.
                - direct_depth (torch.Tensor): Decoded directly regressed
                    depth.
                - keypoints2d (torch.Tensor): Keypoints of each projected
                    3D box on image.
                - keypoints_depth (torch.Tensor): Decoded depth from keypoints.
                - combined_depth (torch.Tensor): Combined depth using direct
                    depth and keypoints depth with depth uncertainty.
                - orientations (torch.Tensor): Multibin format orientations
                    (local yaw) for each objects.
        """

        # 4 dimensions for FCOS style regression
        pred_bboxes2d = bbox[:, 0:4]

        # change FCOS style to [x1, y1, x2, y2] format for IOU Loss
        pred_bboxes2d = self.decode_bboxes2d(pred_bboxes2d, base_centers2d)

        # 2 dimensions for projected centers2d offsets
        pred_offsets2d = bbox[:, 4:6]

        # 3 dimensions for 3D bbox dimensions offsets
        pred_dimensions_offsets3d = bbox[:, 29:32]

        # the first 8 dimensions are for orientation bin classification
        # and the second 8 dimensions are for orientation offsets.
        pred_orientations = torch.cat((bbox[:, 32:40], bbox[:, 40:48]), dim=1)

        # 3 dimensions for the uncertainties of the solved depths from
        # groups of keypoints
        pred_keypoints_depth_uncertainty = bbox[:, 26:29]

        # 1 dimension for the uncertainty of directly regressed depth
        pred_direct_depth_uncertainty = bbox[:, 49:50].squeeze(-1)

        # 2 dimension of offsets x keypoints (8 corners + top/bottom center)
        pred_keypoints2d = bbox[:, 6:26].reshape(-1, 10, 2)

        # 1 dimension for depth offsets
        pred_direct_depth_offsets = bbox[:, 48:49].squeeze(-1)

        # decode the pred residual dimensions to real dimensions
        pred_dimensions = self.decode_dims(labels, pred_dimensions_offsets3d)
        pred_direct_depth = self.decode_direct_depth(pred_direct_depth_offsets)
        pred_keypoints_depth = self.keypoints2depth(pred_keypoints2d,
                                                    pred_dimensions, cam2imgs,
                                                    downsample_ratio)

        pred_direct_depth_uncertainty = torch.clamp(
            pred_direct_depth_uncertainty, self.uncertainty_range[0],
            self.uncertainty_range[1])
        pred_keypoints_depth_uncertainty = torch.clamp(
            pred_keypoints_depth_uncertainty, self.uncertainty_range[0],
            self.uncertainty_range[1])

        if self.combine_depth:
            pred_depth_uncertainty = torch.cat(
                (pred_direct_depth_uncertainty.unsqueeze(-1),
                 pred_keypoints_depth_uncertainty),
                dim=1).exp()
            pred_depth = torch.cat(
                (pred_direct_depth.unsqueeze(-1), pred_keypoints_depth), dim=1)
            pred_combined_depth = \
                self.combine_depths(pred_depth, pred_depth_uncertainty)
        else:
            pred_combined_depth = None

        preds = dict(
            bboxes2d=pred_bboxes2d,
            dimensions=pred_dimensions,
            offsets2d=pred_offsets2d,
            keypoints2d=pred_keypoints2d,
            orientations=pred_orientations,
            direct_depth=pred_direct_depth,
            keypoints_depth=pred_keypoints_depth,
            combined_depth=pred_combined_depth,
            direct_depth_uncertainty=pred_direct_depth_uncertainty,
            keypoints_depth_uncertainty=pred_keypoints_depth_uncertainty,
        )

        return preds

    def decode_direct_depth(self, depth_offsets):
        """Transform depth offset to directly regressed depth.

        Args:
            depth_offsets (torch.Tensor): Predicted depth offsets.
                shape: (N, )

        Return:
            torch.Tensor: Directly regressed depth.
                shape: (N, )
        """
        if self.depth_mode == 'exp':
            direct_depth = depth_offsets.exp()
        elif self.depth_mode == 'linear':
            base_depth = depth_offsets.new_tensor(self.base_depth)
            direct_depth = depth_offsets * base_depth[1] + base_depth[0]
        elif self.depth_mode == 'inv_sigmoid':
            direct_depth = 1 / torch.sigmoid(depth_offsets) - 1
        else:
            raise ValueError

        if self.depth_range is not None:
            direct_depth = torch.clamp(
                direct_depth, min=self.depth_range[0], max=self.depth_range[1])

        return direct_depth

    def decode_location(self,
                        base_centers2d,
                        offsets2d,
                        depths,
                        cam2imgs,
                        downsample_ratio,
                        pad_mode='default'):
        """Retrieve object location.

        Args:
            base_centers2d (torch.Tensor): predicted base centers2d.
                shape: (N, 2)
            offsets2d (torch.Tensor): The offsets between real centers2d
                and base centers2d.
                shape: (N , 2)
            depths (torch.Tensor): Depths of objects.
                shape: (N, )
            cam2imgs (torch.Tensor): Batch images' camera intrinsic matrix.
                shape: kitti (N, 4, 4)  nuscenes (N, 3, 3)
            downsample_ratio (int): The stride of feature map.
            pad_mode (str, optional): Padding mode used in
                training data augmentation.

        Return:
            tuple(torch.Tensor): Centers of 3D boxes.
                shape: (N, 3)
        """
        N = cam2imgs.shape[0]
        # (N, 4, 4)
        cam2imgs_inv = cam2imgs.inverse()
        if pad_mode == 'default':
            centers2d_img = (base_centers2d + offsets2d) * downsample_ratio
        else:
            raise NotImplementedError
        # (N, 3)
        centers2d_img = \
            torch.cat((centers2d_img, depths.unsqueeze(-1)), dim=1)
        # (N, 4, 1)
        centers2d_extend = \
            torch.cat((centers2d_img, centers2d_img.new_ones(N, 1)),
                      dim=1).unsqueeze(-1)
        locations = torch.matmul(cam2imgs_inv, centers2d_extend).squeeze(-1)

        return locations[:, :3]

    def keypoints2depth(self,
                        keypoints2d,
                        dimensions,
                        cam2imgs,
                        downsample_ratio=4,
                        group0_index=[(7, 3), (0, 4)],
                        group1_index=[(2, 6), (1, 5)]):
        """Decode depth form three groups of keypoints and geometry projection
        model. 2D keypoints inlucding 8 coreners and top/bottom centers will be
        divided into three groups which will be used to calculate three depths
        of object.

        .. code-block:: none

                Group center keypoints:

                             + --------------- +
                            /|   top center   /|
                           / |      .        / |
                          /  |      |       /  |
                         + ---------|----- +   +
                         |  /       |      |  /
                         | /        .      | /
                         |/ bottom center  |/
                         + --------------- +

                Group 0 keypoints:

                             0
                             + -------------- +
                            /|               /|
                           / |              / |
                          /  |            5/  |
                         + -------------- +   +
                         |  /3            |  /
                         | /              | /
                         |/               |/
                         + -------------- + 6

                Group 1 keypoints:

                                               4
                             + -------------- +
                            /|               /|
                           / |              / |
                          /  |             /  |
                       1 + -------------- +   + 7
                         |  /             |  /
                         | /              | /
                         |/               |/
                       2 + -------------- +


        Args:
            keypoints2d (torch.Tensor): Keypoints of objects.
                8 vertices + top/bottom center.
                shape: (N, 10, 2)
            dimensions (torch.Tensor): Dimensions of objetcts.
                shape: (N, 3)
            cam2imgs (torch.Tensor): Batch images' camera intrinsic matrix.
                shape: kitti (N, 4, 4)  nuscenes (N, 3, 3)
            downsample_ratio (int, opitonal): The stride of feature map.
                Defaults: 4.
            group0_index(list[tuple[int]], optional): Keypoints group 0
                of index to calculate the depth.
                Defaults: [0, 3, 4, 7].
            group1_index(list[tuple[int]], optional): Keypoints group 1
                of index to calculate the depth.
                Defaults: [1, 2, 5, 6]

        Return:
            tuple(torch.Tensor): Depth computed from three groups of
                keypoints (top/bottom, group0, group1)
                shape: (N, 3)
        """

        pred_height_3d = dimensions[:, 1].clone()
        f_u = cam2imgs[:, 0, 0]
        center_height = keypoints2d[:, -2, 1] - keypoints2d[:, -1, 1]
        corner_group0_height = keypoints2d[:, group0_index[0], 1] \
            - keypoints2d[:, group0_index[1], 1]
        corner_group1_height = keypoints2d[:, group1_index[0], 1] \
            - keypoints2d[:, group1_index[1], 1]
        center_depth = f_u * pred_height_3d / (
            F.relu(center_height) * downsample_ratio + self.eps)
        corner_group0_depth = (f_u * pred_height_3d).unsqueeze(-1) / (
            F.relu(corner_group0_height) * downsample_ratio + self.eps)
        corner_group1_depth = (f_u * pred_height_3d).unsqueeze(-1) / (
            F.relu(corner_group1_height) * downsample_ratio + self.eps)

        corner_group0_depth = corner_group0_depth.mean(dim=1)
        corner_group1_depth = corner_group1_depth.mean(dim=1)

        keypoints_depth = torch.stack(
            (center_depth, corner_group0_depth, corner_group1_depth), dim=1)
        keypoints_depth = torch.clamp(
            keypoints_depth, min=self.depth_range[0], max=self.depth_range[1])

        return keypoints_depth

    def decode_dims(self, labels, dims_offset):
        """Retrieve object dimensions.

        Args:
            labels (torch.Tensor): Each points' category id.
                shape: (N, K)
            dims_offset (torch.Tensor): Dimension offsets.
                shape: (N, 3)

        Returns:
            torch.Tensor: Shape (N, 3)
        """

        if self.dims_mode == 'exp':
            dims_offset = dims_offset.exp()
        elif self.dims_mode == 'linear':
            labels = labels.long()
            base_dims = dims_offset.new_tensor(self.base_dims)
            dims_mean = base_dims[:, :3]
            dims_std = base_dims[:, 3:6]
            cls_dimension_mean = dims_mean[labels, :]
            cls_dimension_std = dims_std[labels, :]
            dimensions = dims_offset * cls_dimension_mean + cls_dimension_std
        else:
            raise ValueError

        return dimensions

    def decode_orientation(self, ori_vector, locations):
        """Retrieve object orientation.

        Args:
            ori_vector (torch.Tensor): Local orientation vector
                in [axis_cls, head_cls, sin, cos] format.
                shape: (N, num_dir_bins * 4)
            locations (torch.Tensor): Object location.
                shape: (N, 3)

        Returns:
            tuple[torch.Tensor]: yaws and local yaws of 3d bboxes.
        """
        if self.multibin:
            pred_bin_cls = ori_vector[:, :self.num_dir_bins * 2].view(
                -1, self.num_dir_bins, 2)
            pred_bin_cls = pred_bin_cls.softmax(dim=2)[..., 1]
            orientations = ori_vector.new_zeros(ori_vector.shape[0])
            for i in range(self.num_dir_bins):
                mask_i = (pred_bin_cls.argmax(dim=1) == i)
                start_bin = self.num_dir_bins * 2 + i * 2
                end_bin = start_bin + 2
                pred_bin_offset = ori_vector[mask_i, start_bin:end_bin]
                orientations[mask_i] = pred_bin_offset[:, 0].atan2(
                    pred_bin_offset[:, 1]) + self.bin_centers[i]
        else:
            axis_cls = ori_vector[:, :2].softmax(dim=1)
            axis_cls = axis_cls[:, 0] < axis_cls[:, 1]
            head_cls = ori_vector[:, 2:4].softmax(dim=1)
            head_cls = head_cls[:, 0] < head_cls[:, 1]
            # cls axis
            orientations = self.bin_centers[axis_cls + head_cls * 2]
            sin_cos_offset = F.normalize(ori_vector[:, 4:])
            orientations += sin_cos_offset[:, 0].atan(sin_cos_offset[:, 1])

        locations = locations.view(-1, 3)
        rays = locations[:, 0].atan2(locations[:, 2])
        local_yaws = orientations
        yaws = local_yaws + rays

        larger_idx = (yaws > np.pi).nonzero(as_tuple=False)
        small_idx = (yaws < -np.pi).nonzero(as_tuple=False)
        if len(larger_idx) != 0:
            yaws[larger_idx] -= 2 * np.pi
        if len(small_idx) != 0:
            yaws[small_idx] += 2 * np.pi

        larger_idx = (local_yaws > np.pi).nonzero(as_tuple=False)
        small_idx = (local_yaws < -np.pi).nonzero(as_tuple=False)
        if len(larger_idx) != 0:
            local_yaws[larger_idx] -= 2 * np.pi
        if len(small_idx) != 0:
            local_yaws[small_idx] += 2 * np.pi

        return yaws, local_yaws

    def decode_bboxes2d(self, reg_bboxes2d, base_centers2d):
        """Retrieve [x1, y1, x2, y2] format 2D bboxes.

        Args:
            reg_bboxes2d (torch.Tensor): Predicted FCOS style
                2D bboxes.
                shape: (N, 4)
            base_centers2d (torch.Tensor): predicted base centers2d.
                shape: (N, 2)

        Returns:
            torch.Tenosr: [x1, y1, x2, y2] format 2D bboxes.
        """
        centers_x = base_centers2d[:, 0]
        centers_y = base_centers2d[:, 1]

        xs_min = centers_x - reg_bboxes2d[..., 0]
        ys_min = centers_y - reg_bboxes2d[..., 1]
        xs_max = centers_x + reg_bboxes2d[..., 2]
        ys_max = centers_y + reg_bboxes2d[..., 3]

        bboxes2d = torch.stack([xs_min, ys_min, xs_max, ys_max], dim=-1)

        return bboxes2d

    def combine_depths(self, depth, depth_uncertainty):
        """Combine all the prediced depths with depth uncertainty.

        Args:
            depth (torch.Tensor): Predicted depths of each object.
                2D bboxes.
                shape: (N, 4)
            depth_uncertainty (torch.Tensor): Depth uncertainty for
                each depth of each object.
                shape: (N, 4)

        Returns:
            torch.Tenosr: combined depth.
        """
        uncertainty_weights = 1 / depth_uncertainty
        uncertainty_weights = \
            uncertainty_weights / \
            uncertainty_weights.sum(dim=1, keepdim=True)
        combined_depth = torch.sum(depth * uncertainty_weights, dim=1)

        return combined_depth
