# Copyright (c) OpenMMLab. All rights reserved.
from typing import Dict, List

import numpy as np
import torch
from torch import Tensor

from embodiedqa.registry import TASK_UTILS
from embodiedqa.structures.bbox_3d import BaseInstance3DBoxes
from .partial_bin_based_bbox_coder import PartialBinBasedBBoxCoder


@TASK_UTILS.register_module()
class GroupFree3DBBoxCoder(PartialBinBasedBBoxCoder):
    """Modified partial bin based bbox coder for GroupFree3D.

    Box sizes can be predicted either class-agnostically (a direct 3-value
    size per proposal) or as a size class plus a per-class residual, as
    selected by ``size_cls_agnostic``.

    Args:
        num_dir_bins (int): Number of bins to encode direction angle.
        num_sizes (int): Number of size clusters.
        mean_sizes (list[list[float]]): Mean size of bboxes in each class.
        with_rot (bool, optional): Whether the bbox is with rotation.
            Defaults to True.
        size_cls_agnostic (bool, optional): Whether the predicted size is
            class-agnostic. Defaults to True.
    """

    def __init__(self,
                 num_dir_bins: int,
                 num_sizes: int,
                 mean_sizes: List[List[float]],
                 with_rot: bool = True,
                 size_cls_agnostic: bool = True) -> None:
        super().__init__(
            num_dir_bins=num_dir_bins,
            num_sizes=num_sizes,
            mean_sizes=mean_sizes,
            with_rot=with_rot)
        self.size_cls_agnostic = size_cls_agnostic

    def encode(self, gt_bboxes_3d: BaseInstance3DBoxes,
               gt_labels_3d: Tensor) -> tuple:
        """Encode ground truth to prediction targets.

        Args:
            gt_bboxes_3d (BaseInstance3DBoxes): Ground truth bboxes
                with shape (n, 7).
            gt_labels_3d (torch.Tensor): Ground truth classes. They are used
                directly as the size-class target.

        Returns:
            tuple[torch.Tensor]: A 6-tuple ``(center_target, size_target,
            size_class_target, size_res_target, dir_class_target,
            dir_res_target)``:

                - center_target: gravity centers of the gt boxes.
                - size_target: gt box dimensions.
                - size_class_target: ``gt_labels_3d`` itself.
                - size_res_target: gt dimensions minus the mean size of the
                  box's class.
                - dir_class_target / dir_res_target: heading bin index and
                  residual from ``angle2class``, or zeros when ``with_rot``
                  is False.
        """
        # Center target: the box gravity center.
        center_target = gt_bboxes_3d.gravity_center

        # Size targets. Class labels double as size-class indices, so the
        # residual is measured against the mean size of each box's own class.
        size_target = gt_bboxes_3d.dims
        size_class_target = gt_labels_3d
        mean_sizes = gt_bboxes_3d.tensor.new_tensor(self.mean_sizes)
        size_res_target = size_target - mean_sizes[size_class_target]

        # Direction targets.
        box_num = gt_labels_3d.shape[0]
        if self.with_rot:
            dir_class_target, dir_res_target = self.angle2class(
                gt_bboxes_3d.yaw)
        else:
            # Without rotation every heading target is zero.
            dir_class_target = gt_labels_3d.new_zeros(box_num)
            dir_res_target = gt_bboxes_3d.tensor.new_zeros(box_num)

        return (center_target, size_target, size_class_target,
                size_res_target, dir_class_target, dir_res_target)

    def decode(self, bbox_out: dict, prefix: str = '') -> Tensor:
        """Decode predicted parts to bbox3d.

        Args:
            bbox_out (dict): Predictions from model. It should contain the
                keys below, each prefixed with ``prefix``.

                - center: predicted box center. This is the quantity
                  regressed against ``gravity_center`` in ``encode``.
                - dir_class: predicted per-bin direction class scores
                  (argmax is taken).
                - dir_res: predicted bbox direction residual per bin.
                - size_class: predicted bbox size class scores. Only used
                  when ``size_cls_agnostic`` is False.
                - size_res: predicted bbox size residual per size class.
                  Only used when ``size_cls_agnostic`` is False.
                - size: predicted class-agnostic bbox size. Only used when
                  ``size_cls_agnostic`` is True.
            prefix (str, optional): Decode predictions with specific prefix.
                Defaults to ''.

        Returns:
            torch.Tensor: Decoded bbox3d with shape (batch, n, 7), laid out
            as ``[center (3), size (3), heading angle (1)]``.
        """
        center = bbox_out[f'{prefix}center']
        batch_size, num_proposal = center.shape[:2]

        # Heading: take the most likely bin, then that bin's residual.
        if self.with_rot:
            dir_class = torch.argmax(bbox_out[f'{prefix}dir_class'], -1)
            dir_res = torch.gather(bbox_out[f'{prefix}dir_res'], 2,
                                   dir_class.unsqueeze(-1)).squeeze(2)
            dir_angle = self.class2angle(dir_class, dir_res).reshape(
                batch_size, num_proposal, 1)
        else:
            dir_angle = center.new_zeros(batch_size, num_proposal, 1)

        # Size: either regressed directly, or the mean size of the most
        # likely size class plus that class's residual.
        if self.size_cls_agnostic:
            bbox_size = bbox_out[f'{prefix}size'].reshape(
                batch_size, num_proposal, 3)
        else:
            # (batch, n, 1)
            size_class = torch.argmax(
                bbox_out[f'{prefix}size_class'], -1, keepdim=True)
            # size_res (as built by split_pred) is (batch, n, num_sizes, 3);
            # gathering the chosen class gives (batch, n, 1, 3).
            size_res = torch.gather(
                bbox_out[f'{prefix}size_res'], 2,
                size_class.unsqueeze(-1).repeat(1, 1, 1, 3))
            mean_sizes = center.new_tensor(self.mean_sizes)
            size_base = torch.index_select(mean_sizes, 0,
                                           size_class.reshape(-1))
            bbox_size = size_base.reshape(batch_size, num_proposal,
                                          -1) + size_res.squeeze(2)

        return torch.cat([center, bbox_size, dir_angle], dim=-1)

    def split_pred(self,
                   cls_preds: Tensor,
                   reg_preds: Tensor,
                   base_xyz: Tensor,
                   prefix: str = '') -> Dict[str, Tensor]:
        """Split predicted features to specific parts.

        Channels of ``reg_preds`` are consumed in this order:

        1. center residual (3)
        2. direction class (``num_dir_bins``)
        3. normalized direction residual (``num_dir_bins``)
        4. one of:

           - class-agnostic size (3), or
           - size class (``num_sizes``) followed by normalized size
             residual (``num_sizes * 3``).

        Channel 0 of ``cls_preds`` is the objectness score. The remaining
        channels are semantic scores.

        Args:
            cls_preds (torch.Tensor): Class predicted features to split,
                shape (batch, C_cls, num_proposal).
            reg_preds (torch.Tensor): Regression predicted features to split,
                shape (batch, C_reg, num_proposal).
            base_xyz (torch.Tensor): Coordinates of points. The predicted
                center residual is added to them; they are expected to
                broadcast with (batch, num_proposal, 3).
            prefix (str, optional): Decode predictions with specific prefix.
                Defaults to ''.

        Returns:
            dict[str, torch.Tensor]: Split results.
        """
        results = {}
        start, end = 0, 0

        # Move channels last: (batch, num_proposal, C).
        cls_preds_trans = cls_preds.transpose(2, 1)
        reg_preds_trans = reg_preds.transpose(2, 1)

        # decode center: (batch_size, num_proposal, 3)
        end += 3
        center_residual = reg_preds_trans[..., start:end].contiguous()
        results[f'{prefix}center_residual'] = center_residual
        results[f'{prefix}center'] = base_xyz + center_residual
        start = end

        # decode direction
        end += self.num_dir_bins
        results[f'{prefix}dir_class'] = \
            reg_preds_trans[..., start:end].contiguous()
        start = end

        end += self.num_dir_bins
        dir_res_norm = reg_preds_trans[..., start:end].contiguous()
        start = end

        results[f'{prefix}dir_res_norm'] = dir_res_norm
        # Undo the normalization of the per-bin direction residual.
        results[f'{prefix}dir_res'] = dir_res_norm * (
            np.pi / self.num_dir_bins)

        # decode size
        if self.size_cls_agnostic:
            end += 3
            results[f'{prefix}size'] = \
                reg_preds_trans[..., start:end].contiguous()
            start = end
        else:
            end += self.num_sizes
            results[f'{prefix}size_class'] = reg_preds_trans[
                ..., start:end].contiguous()
            start = end

            end += self.num_sizes * 3
            size_res_norm = reg_preds_trans[..., start:end]
            batch_size, num_proposal = reg_preds_trans.shape[:2]
            # One 3-value residual per size class.
            size_res_norm = size_res_norm.view(
                [batch_size, num_proposal, self.num_sizes, 3])
            start = end

            results[f'{prefix}size_res_norm'] = size_res_norm.contiguous()
            # Residuals are normalized by the per-class mean size.
            mean_sizes = reg_preds.new_tensor(self.mean_sizes)
            results[f'{prefix}size_res'] = (
                size_res_norm * mean_sizes.unsqueeze(0).unsqueeze(0))

        # decode objectness score
        # Group-Free-3D objectness output shape (batch, proposal, 1)
        results[f'{prefix}obj_scores'] = cls_preds_trans[..., :1].contiguous()

        # decode semantic score
        results[f'{prefix}sem_scores'] = cls_preds_trans[..., 1:].contiguous()

        return results
