import torch
import torch.nn as nn
from mmcv.runner import auto_fp16, force_fp32
from mmdet3d.core import bbox3d2result
from mmdet3d.models.detectors.mvx_two_stage import MVXTwoStageDetector
from mmdet.models import DETECTORS
from torch import Tensor

from projects.mmdet3d_plugin.models.utils.grid_mask import GridMask
from projects.mmdet3d_plugin.models.utils.misc import locations


@DETECTORS.register_module()
class Petr3DPruning(MVXTwoStageDetector):
    """Multi-frame, multi-view image 3D detector with iterative query pruning.

    Every head forward during training records a per-query confidence
    (:meth:`update_records`). Each call to :meth:`prune` or :meth:`prune2`
    then removes one low-confidence query from ``self.pts_bbox_head`` until
    the head reaches ``num_query`` pre-defined queries and ``num_propagated``
    propagated queries. Who calls the prune methods (and how often) is not
    visible in this class -- presumably a training hook; confirm.
    """

    def __init__(
        self,
        use_grid_mask=False,
        pts_voxel_layer=None,
        pts_voxel_encoder=None,
        pts_middle_encoder=None,
        pts_fusion_layer=None,
        img_backbone=None,
        pts_backbone=None,
        img_neck=None,
        pts_neck=None,
        pts_bbox_head=None,
        img_roi_head=None,
        img_rpn_head=None,
        train_cfg=None,
        test_cfg=None,
        num_frame_head_grads=2,
        num_frame_backbone_grads=2,
        num_frame_losses=2,
        stride=16,
        position_level=0,
        aux_2d_only=True,
        single_test=False,
        pretrained=None,
        num_pruned=408,
        num_query=236,
        num_propagated=64,
    ):
        """Build the detector.

        Args (beyond those forwarded to ``MVXTwoStageDetector``):
            use_grid_mask (bool): Apply ``GridMask`` augmentation to images.
            num_frame_head_grads (int): Number of trailing frames whose head
                forward keeps gradients.
            num_frame_backbone_grads (int): Number of trailing frames whose
                backbone forward keeps gradients.
            num_frame_losses (int): Number of trailing frames that return
                losses.
            stride (int): Feature stride passed to ``locations``.
            position_level (int): Index of the neck output level to use.
            aux_2d_only (bool): Skip the 2D RoI head outside training.
            single_test (bool): Stored as ``self.single_test``; not read by
                this class.
            num_pruned (int): Stored as ``self.num_pruned``; not read by this
                class.
            num_query (int): Target number of pre-defined queries.
            num_propagated (int): Target number of propagated queries.
        """
        super(Petr3DPruning, self).__init__(
            pts_voxel_layer,
            pts_voxel_encoder,
            pts_middle_encoder,
            pts_fusion_layer,
            img_backbone,
            pts_backbone,
            img_neck,
            pts_neck,
            pts_bbox_head,
            img_roi_head,
            img_rpn_head,
            train_cfg,
            test_cfg,
            pretrained,
        )
        self.grid_mask = GridMask(
            True, True, rotate=1, offset=False, ratio=0.5, mode=1, prob=0.7
        )
        self.use_grid_mask = use_grid_mask
        self.prev_scene_token = None
        self.num_frame_head_grads = num_frame_head_grads
        self.num_frame_backbone_grads = num_frame_backbone_grads
        self.num_frame_losses = num_frame_losses
        self.single_test = single_test
        self.stride = stride
        self.position_level = position_level
        self.aux_2d_only = aux_2d_only
        self.test_flag = False
        self.num_pruned = num_pruned
        self.num_query = num_query
        self.num_propagated = num_propagated
        # Indices of removed pre-defined queries, each relative to the
        # reference-point list as it was at the time of removal.
        self.masked_query_idxs = []
        self.reset_records()

    # ------------------------------------------------------------------ #
    # Query pruning
    # ------------------------------------------------------------------ #
    def reset_records(self):
        """Discard all accumulated per-query confidence records."""
        self.query_records = None

    def _pruning_finished(self):
        """Return True once the head has reached both query targets."""
        head = self.pts_bbox_head
        return (
            head.num_query == self.num_query
            and head.num_propagated == self.num_propagated
        )

    def update_records(self, cls_scores: Tensor):
        """Append one row of per-query confidences to ``query_records``.

        Args:
            cls_scores (Tensor): [B, Nq, C] classification output of the last
                decoder layer.
        """
        if self._pruning_finished():
            return

        # Per-query confidence: max over classes, mean over the batch -> [1, Nq].
        # detach(): the records are bookkeeping only; keeping the autograd
        # graph attached would hold every recorded iteration's graph alive.
        scores = cls_scores.detach().max(dim=-1)[0].mean(dim=0, keepdim=True)

        if self.query_records is None:
            self.query_records = scores
        else:
            self.query_records = torch.cat([self.query_records, scores], dim=0)

    def _ranked_query_indices(self):
        """Return query indices ordered by ascending mean recorded score.

        Indices ``< head.num_query`` are pre-defined queries; the rest are
        propagated queries.

        Raises:
            RuntimeError: if the records do not cover exactly
                ``head.num_query + head.num_propagated`` queries.
        """
        head = self.pts_bbox_head
        scores = self.query_records.mean(dim=0)
        expected = head.num_query + head.num_propagated
        if scores.numel() != expected:
            raise RuntimeError(
                "query_records cover {} queries but the head has {} "
                "(num_query + num_propagated)".format(scores.numel(), expected)
            )
        # argsort the raw scores instead of packing integer indices into the
        # score dtype (which is inexact in fp16 for large indices).
        return scores.argsort().tolist()

    def _drop_lowest_predefined_query(self, order):
        """Remove the first pre-defined query in ``order`` from the head.

        Args:
            order (list[int]): Query indices in ascending score order.
        """
        head = self.pts_bbox_head
        idx = next((i for i in order if i < head.num_query), None)
        if idx is None:
            return
        head.deprecated_reference_points.append(head.reference_points[idx])
        head.reference_points = nn.ParameterList(
            [p for j, p in enumerate(head.reference_points) if j != idx]
        )
        # Keep num_query in sync with the shrunken parameter list.
        head.num_query = len(head.reference_points)
        self.masked_query_idxs.append(idx)

    def prune(self):
        """Remove one query, pruning pre-defined queries before propagated ones.

        While the head has more than ``self.num_query`` pre-defined queries,
        the one with the lowest mean recorded confidence is removed. After
        that, ``head.num_propagated`` is decremented by one per call until it
        reaches ``self.num_propagated``. The head never shrinks below either
        target. Records are reset after every call that consumed them.
        """
        if self.query_records is None or self._pruning_finished():
            return

        head = self.pts_bbox_head
        if head.num_query > self.num_query:
            self._drop_lowest_predefined_query(self._ranked_query_indices())
        elif head.num_propagated > self.num_propagated:
            head.num_propagated -= 1

        self.reset_records()

    def prune2(self):
        """Remove one query, pruning propagated queries before pre-defined ones.

        Mirror image of :meth:`prune`. It first decrements
        ``head.num_propagated`` down to ``self.num_propagated``. Then it
        removes the lowest-scoring pre-defined query per call down to
        ``self.num_query``. Records are reset after every call that consumed
        them.
        """
        if self.query_records is None or self._pruning_finished():
            return

        head = self.pts_bbox_head
        if head.num_propagated > self.num_propagated:
            head.num_propagated -= 1
        elif head.num_query > self.num_query:
            self._drop_lowest_predefined_query(self._ranked_query_indices())

        self.reset_records()

    # ------------------------------------------------------------------ #
    # Feature extraction
    # ------------------------------------------------------------------ #
    def extract_img_feat(self, img, len_queue=1, training_mode=False):
        """Extract image features at level ``self.position_level``.

        Args:
            img (Tensor | None): ``[B, N, C, H, W]`` or ``[B, T, N, C, H, W]``
                images. The time axis is folded into the view axis. A 4-D
                tensor is passed to the backbone unchanged.
            len_queue (int): Number of frames folded into the view axis.
            training_mode (bool): Use the training output layout even when
                the module is in eval mode.

        Returns:
            Tensor | None: ``[B, len_queue, N', C', H', W']`` in training
            (or ``training_mode``), otherwise ``[B, N', C', H', W']``.
            Returns ``None`` when ``img`` is ``None``.
        """
        if img is None:
            return None

        B = img.size(0)
        if img.dim() == 6:
            img = img.flatten(1, 2)
        if img.dim() == 5:
            # Fold batch and view axes with an explicit reshape. A blanket
            # ``squeeze_`` would also drop singleton view/channel axes and
            # mutate the caller's tensor.
            B, N, C, H, W = img.size()
            img = img.reshape(B * N, C, H, W)
        if self.use_grid_mask:
            img = self.grid_mask(img)

        img_feats = self.img_backbone(img)
        if isinstance(img_feats, dict):
            img_feats = list(img_feats.values())
        if self.with_img_neck:
            img_feats = self.img_neck(img_feats)

        feats = img_feats[self.position_level]
        BN, C, H, W = feats.size()
        num_views = BN // (B * len_queue)
        if self.training or training_mode:
            return feats.view(B, len_queue, num_views, C, H, W)
        return feats.view(B, num_views, C, H, W)

    @auto_fp16(apply_to=("img",), out_fp32=True)
    def extract_feat(self, img, T, training_mode=False):
        """Extract image features (fp16-aware wrapper of extract_img_feat)."""
        return self.extract_img_feat(img, T, training_mode)

    # ------------------------------------------------------------------ #
    # Training
    # ------------------------------------------------------------------ #
    def obtain_history_memory(
        self,
        gt_bboxes_3d=None,
        gt_labels_3d=None,
        gt_bboxes=None,
        gt_labels=None,
        img_metas=None,
        centers2d=None,
        depths=None,
        gt_bboxes_ignore=None,
        **data
    ):
        """Run the head frame by frame over the temporal queue.

        Trailing ``num_frame_head_grads`` frames keep gradients. Trailing
        ``num_frame_losses`` frames contribute losses, which are prefixed
        with ``frame_<i>_``. ``gt_bboxes_ignore`` is accepted but unused.

        Returns:
            dict: Merged per-frame losses.
        """
        losses = dict()
        T = data["img"].size(1)
        num_nograd_frames = T - self.num_frame_head_grads
        num_grad_losses = T - self.num_frame_losses
        for i in range(T):
            # Slice frame i out of every per-frame tensor (time is axis 1).
            data_t = {key: value[:, i] for key, value in data.items()}
            loss = self.forward_pts_train(
                gt_bboxes_3d[i],
                gt_labels_3d[i],
                gt_bboxes[i],
                gt_labels[i],
                img_metas[i],
                centers2d[i],
                depths[i],
                requires_grad=i >= num_nograd_frames,
                return_losses=i >= num_grad_losses,
                **data_t
            )
            if loss is not None:
                for key, value in loss.items():
                    losses["frame_" + str(i) + "_" + key] = value
        return losses

    def prepare_location(self, img_metas, **data):
        """Build pixel locations for the feature map, one copy per image view.

        Uses the padded size of the first sample's first view and
        ``self.stride``. The exact output layout is defined by ``locations``
        (not visible here).
        """
        pad_h, pad_w, _ = img_metas[0]["pad_shape"][0]
        bs, n = data["img_feats"].shape[:2]
        x = data["img_feats"].flatten(0, 1)
        location = locations(x, self.stride, pad_h, pad_w)[None].repeat(bs * n, 1, 1, 1)
        return location

    def forward_roi_head(self, location, **data):
        """Run the 2D RoI head, or return ``{"topk_indexes": None}`` if skipped."""
        if (self.aux_2d_only and not self.training) or not self.with_img_roi_head:
            return {"topk_indexes": None}
        return self.img_roi_head(location, **data)

    def forward_pts_train(
        self,
        gt_bboxes_3d,
        gt_labels_3d,
        gt_bboxes,
        gt_labels,
        img_metas,
        centers2d,
        depths,
        requires_grad=True,
        return_losses=False,
        **data
    ):
        """Forward one frame through the heads and optionally compute losses.

        Args:
            gt_bboxes_3d (list[:obj:`BaseInstance3DBoxes`]): Ground-truth
                3D boxes for each sample.
            gt_labels_3d (list[torch.Tensor]): Ground-truth 3D labels.
            gt_bboxes, gt_labels, centers2d, depths: 2D targets for the
                RoI head loss.
            img_metas (list[dict]): Meta information of samples.
            requires_grad (bool): If False, run the 3D head in eval mode
                under ``torch.no_grad`` and skip the 2D RoI head.
            return_losses (bool): Whether to compute and return losses.

        Returns:
            dict | None: Losses when ``return_losses`` is True, else None.
        """
        location = self.prepare_location(img_metas, **data)

        outs_roi = None
        if requires_grad:
            outs_roi = self.forward_roi_head(location, **data)
            topk_indexes = outs_roi["topk_indexes"]
            outs = self.pts_bbox_head(location, img_metas, topk_indexes, **data)
        else:
            self.eval()
            try:
                with torch.no_grad():
                    outs = self.pts_bbox_head(location, img_metas, None, **data)
            finally:
                # Restore training mode even if the head raises.
                self.train()

        self.update_records(outs["all_cls_scores"][-1])

        if not return_losses:
            return None

        losses = self.pts_bbox_head.loss(gt_bboxes_3d, gt_labels_3d, outs)
        # The RoI head only ran on gradient frames. Previously a loss frame
        # without gradients raised NameError on the undefined ``outs_roi``.
        if self.with_img_roi_head and outs_roi is not None:
            losses2d = self.img_roi_head.loss(
                gt_bboxes, gt_labels, centers2d, depths, outs_roi, img_metas
            )
            losses.update(losses2d)
        return losses

    @force_fp32(apply_to=("img",))
    def forward(self, return_loss=True, **data):
        """Dispatch to forward_train (transposing per-frame GT lists) or forward_test."""
        if not return_loss:
            return self.forward_test(**data)
        for key in [
            "gt_bboxes_3d",
            "gt_labels_3d",
            "gt_bboxes",
            "gt_labels",
            "centers2d",
            "depths",
            "img_metas",
        ]:
            # [batch][frame] -> [frame][batch]
            data[key] = list(zip(*data[key]))
        return self.forward_train(**data)

    def forward_train(
        self,
        img_metas=None,
        gt_bboxes_3d=None,
        gt_labels_3d=None,
        gt_labels=None,
        gt_bboxes=None,
        gt_bboxes_ignore=None,
        depths=None,
        centers2d=None,
        **data
    ):
        """Training forward over a temporal queue of multi-view images.

        ``data["img"]`` is ``(bs, que_len, n_views, c, h, w)``. Only the last
        ``num_frame_backbone_grads`` frames go through the backbone with
        gradients; earlier frames are encoded in eval mode without gradients.
        """
        if self.test_flag:  # switching back from interval evaluation
            self.pts_bbox_head.reset_memory()
            self.test_flag = False

        T = data["img"].size(1)
        num_grad = self.num_frame_backbone_grads

        prev_img = data["img"][:, :-num_grad]
        rec_img = data["img"][:, -num_grad:]
        rec_img_feats = self.extract_feat(rec_img, num_grad)

        if T - num_grad > 0:
            self.eval()
            try:
                with torch.no_grad():
                    prev_img_feats = self.extract_feat(prev_img, T - num_grad, True)
            finally:
                self.train()
            data["img_feats"] = torch.cat([prev_img_feats, rec_img_feats], dim=1)
        else:
            data["img_feats"] = rec_img_feats

        return self.obtain_history_memory(
            gt_bboxes_3d,
            gt_labels_3d,
            gt_bboxes,
            gt_labels,
            img_metas,
            centers2d,
            depths,
            gt_bboxes_ignore,
            **data
        )

    # ------------------------------------------------------------------ #
    # Testing
    # ------------------------------------------------------------------ #
    def forward_test(self, img_metas, rescale, **data):
        """Unwrap test-time augmentation lists and run simple_test.

        ``rescale`` is accepted for API compatibility but unused.
        """
        self.test_flag = True
        for var, name in [(img_metas, "img_metas")]:
            if not isinstance(var, list):
                raise TypeError("{} must be a list, but got {}".format(name, type(var)))
        for key in data:
            if key != "img":
                data[key] = data[key][0][0].unsqueeze(0)
            else:
                data[key] = data[key][0]
        return self.simple_test(img_metas[0], **data)

    def simple_test_pts(self, img_metas, **data):
        """Test function of the 3D head.

        Resets the head memory whenever the scene token changes.
        """
        location = self.prepare_location(img_metas, **data)
        outs_roi = self.forward_roi_head(location, **data)
        topk_indexes = outs_roi["topk_indexes"]

        if img_metas[0]["scene_token"] != self.prev_scene_token:
            self.prev_scene_token = img_metas[0]["scene_token"]
            data["prev_exists"] = data["img"].new_zeros(1)
            self.pts_bbox_head.reset_memory()
        else:
            data["prev_exists"] = data["img"].new_ones(1)

        outs = self.pts_bbox_head(location, img_metas, topk_indexes, **data)
        bbox_list = self.pts_bbox_head.get_bboxes(outs, img_metas)
        return [
            bbox3d2result(bboxes, scores, labels)
            for bboxes, scores, labels in bbox_list
        ]

    def simple_test(self, img_metas, **data):
        """Test function without augmentation."""
        data["img_feats"] = self.extract_img_feat(data["img"], 1)

        bbox_list = [dict() for _ in range(len(img_metas))]
        bbox_pts = self.simple_test_pts(img_metas, **data)
        for result_dict, pts_bbox in zip(bbox_list, bbox_pts):
            result_dict["pts_bbox"] = pts_bbox
        return bbox_list
