# ---------------------------------------------
# Copyright (c) OpenMMLab. All rights reserved.
# ---------------------------------------------
#  Modified by Zhiqi Li
# ---------------------------------------------

from projects.mmdet3d_plugin.models.utils.bricks import run_time
from projects.mmdet3d_plugin.models.utils.visual import save_tensor
from .custom_base_transformer_layer import MyCustomBaseTransformerLayer
import copy
import warnings
from mmcv.cnn.bricks.registry import (ATTENTION,
                                      TRANSFORMER_LAYER,
                                      TRANSFORMER_LAYER_SEQUENCE)
from mmcv.cnn.bricks.transformer import TransformerLayerSequence
from mmcv.runner import force_fp32, auto_fp16
import numpy as np
import torch
import cv2 as cv
import mmcv
from mmcv.utils import TORCH_VERSION, digit_version
from mmcv.utils import ext_loader
# Load mmcv's compiled `_ext` extension, requesting the multi-scale deformable
# attention forward/backward functions.
# NOTE(review): `ext_module` is not referenced anywhere else in this file;
# presumably kept so that a missing/broken mmcv ops build fails at import time.
ext_module = ext_loader.load_ext(
    '_ext', ['ms_deform_attn_backward', 'ms_deform_attn_forward'])

@TRANSFORMER_LAYER_SEQUENCE.register_module()
class VoxelFormerEncoder(TransformerLayerSequence):
    """Voxel-query encoder: a stack of layers with self- and cross-attention.

    Every layer in ``self.layers`` receives the voxel queries together with
    two kinds of reference points:

    * normalized 3D points projected into every camera image, used by the
      spatial cross-attention, and
    * normalized (x, y, z) voxel centers, used by the self-attention.

    Args:
        pc_range (list[float]): Point-cloud range
            ``[x_min, y_min, z_min, x_max, y_max, z_max]`` used to map
            normalized reference points to metric coordinates.
        num_points_in_pillar (int): Stored on the instance but not used by
            this class. Default: 4.
        num_points_in_voxel (int): Number of 3D reference points sampled per
            voxel for cross-attention. ``1`` means the voxel center only.
            Default: 1.
        return_intermediate (bool): Whether to return the stacked outputs of
            all layers instead of only the last one. Default: False.
        dataset_type (str): Accepted but not used by this class.
            Default: 'nuscenes'.
    """

    def __init__(self, *args, pc_range=None, num_points_in_pillar=4, num_points_in_voxel=1,
                 return_intermediate=False, dataset_type='nuscenes',
                 **kwargs):

        super(VoxelFormerEncoder, self).__init__(*args, **kwargs)
        self.return_intermediate = return_intermediate

        self.num_points_in_pillar = num_points_in_pillar
        self.num_points_in_voxel = num_points_in_voxel
        self.pc_range = pc_range
        self.fp16_enabled = False

    @staticmethod
    def get_reference_points(bev_z, bev_h, bev_w, num_points_in_voxel=1, dim='3d', bs=1, device='cuda', dtype=torch.float):
        """Get normalized reference points on the voxel grid.

        Queries are enumerated in (z, y, x) row-major order, i.e.
        ``num_query = bev_z * bev_h * bev_w`` with z varying slowest.

        Args:
            bev_z, bev_h, bev_w (int): Spatial shape of the voxel grid.
            num_points_in_voxel (int): Only used when ``dim == '3d'``. When
                greater than 1, that many points are sampled per voxel along
                its main diagonal, strictly inside the voxel. Otherwise only
                the voxel center is used.
            dim (str): ``'3d'`` for the points used by spatial
                cross-attention, or ``'2d'`` for the points used by
                self-attention. Despite the name, the ``'2d'`` points are
                also (x, y, z).
            bs (int): Batch size.
            device: Device on which the points are created.
            dtype: Dtype of the points.

        Returns:
            Tensor: Normalized ``(x, y, z)`` coordinates in ``(0, 1)``.
            For ``'3d'`` the shape is ``(bs, D, num_query, 3)``, where ``D``
            is ``num_points_in_voxel`` if it is greater than 1 and ``1``
            otherwise. For ``'2d'`` the shape is ``(bs, num_query, 1, 3)``.

        Raises:
            ValueError: If ``dim`` is neither ``'3d'`` nor ``'2d'``.
        """

        # reference points in 3D space, used in spatial cross-attention (SCA)
        if dim == '3d':
            # Voxel centers, normalized to (0, 1) along each axis.
            zs = torch.linspace(0.5, bev_z - 0.5, bev_z, dtype=dtype,
                                device=device).view(1, bev_z, 1, 1).expand(1, bev_z, bev_h, bev_w) / bev_z
            ys = torch.linspace(0.5, bev_h - 0.5, bev_h, dtype=dtype,
                                device=device).view(1, 1, bev_h, 1).expand(1, bev_z, bev_h, bev_w) / bev_h
            xs = torch.linspace(0.5, bev_w - 0.5, bev_w, dtype=dtype,
                                device=device).view(1, 1, 1, bev_w).expand(1, bev_z, bev_h, bev_w) / bev_w

            ref_3d = torch.stack((xs, ys, zs), -1)  # (1, bev_z, bev_h, bev_w, 3)
            ref_3d = ref_3d.permute(0, 4, 1, 2, 3).flatten(2).permute(0, 2, 1)  # (1, num_query, 3)
            ref_3d = ref_3d[None].repeat(bs, 1, 1, 1)  # (bs, 1, num_query, 3)

            if num_points_in_voxel > 1:
                num = num_points_in_voxel
                # Half of the normalized voxel size along each axis.
                delta_z, delta_y, delta_x = 0.5 / bev_z, 0.5 / bev_h, 0.5 / bev_w
                # ``num`` evenly spaced offsets strictly inside [-delta, delta]
                # (the linspace endpoints are dropped). The same index is used
                # on every axis, so the samples lie on the voxel's diagonal.
                zs_offset = torch.linspace(-delta_z, delta_z, num + 2, dtype=dtype, device=device)[1:-1]
                zs_offset = zs_offset.view(num, 1, 1, 1).expand(num, bev_z, bev_h, bev_w)

                ys_offset = torch.linspace(-delta_y, delta_y, num + 2, dtype=dtype, device=device)[1:-1]
                ys_offset = ys_offset.view(num, 1, 1, 1).expand(num, bev_z, bev_h, bev_w)

                xs_offset = torch.linspace(-delta_x, delta_x, num + 2, dtype=dtype, device=device)[1:-1]
                xs_offset = xs_offset.view(num, 1, 1, 1).expand(num, bev_z, bev_h, bev_w)

                offset_3d = torch.stack((xs_offset, ys_offset, zs_offset), -1)  # (num, bev_z, bev_h, bev_w, 3)
                offset_3d = offset_3d.permute(0, 4, 1, 2, 3).flatten(2).permute(0, 2, 1)  # (num, num_query, 3)
                offset_3d = offset_3d[None].repeat(bs, 1, 1, 1)  # (bs, num, num_query, 3)
                # The centers (bs, 1, num_query, 3) broadcast over the offsets.
                ref_3d = offset_3d + ref_3d  # (bs, num, num_query, 3)

            return ref_3d

        # Voxel centers used in temporal self-attention (TSA); despite the
        # '2d' name they carry all three coordinates.
        if dim == '2d':
            ref_z, ref_y, ref_x = torch.meshgrid(
                torch.linspace(0.5,
                               bev_z - 0.5,
                               bev_z,
                               dtype=dtype,
                               device=device),
                torch.linspace(0.5,
                               bev_h - 0.5,
                               bev_h,
                               dtype=dtype,
                               device=device),
                torch.linspace(0.5,
                               bev_w - 0.5,
                               bev_w,
                               dtype=dtype,
                               device=device)
            )  # shape: (bev_z, bev_h, bev_w)
            ref_z = ref_z.reshape(-1)[None] / bev_z
            ref_y = ref_y.reshape(-1)[None] / bev_h
            ref_x = ref_x.reshape(-1)[None] / bev_w
            ref_2d = torch.stack((ref_x, ref_y, ref_z), -1)  # (1, num_query, 3)
            ref_2d = ref_2d.repeat(bs, 1, 1).unsqueeze(2)  # (bs, num_query, 1, 3)
            return ref_2d

        raise ValueError(f"dim must be '3d' or '2d', got {dim!r}")

    # This function must use fp32!!!
    @force_fp32(apply_to=('reference_points', 'img_metas'))
    def point_sampling(self, reference_points, pc_range,  img_metas):
        """Project normalized 3D reference points into every camera image.

        Args:
            reference_points (Tensor): Normalized points of shape
                ``(bs, D, num_query, 3)``, as returned by
                ``get_reference_points(dim='3d')``.
            pc_range (list[float]): ``[x_min, y_min, z_min, x_max, y_max,
                z_max]`` used to de-normalize the points.
            img_metas (list[dict]): One dict per sample. Each must contain
                ``'lidar2img'`` (a stack of 4x4 matrices, one per camera).
                The first one must also contain ``'img_shape'``.

        Returns:
            tuple[Tensor, Tensor]:
                - reference_points_cam: ``(num_cam, bs, num_query, D, 2)``
                  pixel coordinates divided by the first camera's image
                  width (x) and height (y).
                - bev_mask: ``(num_cam, bs, num_query, D)`` mask of points
                  that lie in front of the camera and inside the image.
        """

        lidar2img = []
        for img_meta in img_metas:
            lidar2img.append(img_meta['lidar2img'])
        lidar2img = np.asarray(lidar2img)
        lidar2img = reference_points.new_tensor(lidar2img)  # (B, N, 4, 4)
        # Work on a copy so the caller's normalized points stay untouched.
        reference_points = reference_points.clone()

        # De-normalize from (0, 1) to metric coordinates inside pc_range.
        reference_points[..., 0:1] = reference_points[..., 0:1] * \
            (pc_range[3] - pc_range[0]) + pc_range[0]
        reference_points[..., 1:2] = reference_points[..., 1:2] * \
            (pc_range[4] - pc_range[1]) + pc_range[1]
        reference_points[..., 2:3] = reference_points[..., 2:3] * \
            (pc_range[5] - pc_range[2]) + pc_range[2]

        # Homogeneous coordinates for the 4x4 projection.
        reference_points = torch.cat(
            (reference_points, torch.ones_like(reference_points[..., :1])), -1)  # (bs, D, num_query, 4)

        reference_points = reference_points.permute(1, 0, 2, 3)  # (D, bs, num_query, 4)
        D, B, num_query = reference_points.size()[:3]
        num_cam = lidar2img.size(1)

        reference_points = reference_points.view(
            D, B, 1, num_query, 4).repeat(1, 1, num_cam, 1, 1).unsqueeze(-1)

        lidar2img = lidar2img.view(
            1, B, num_cam, 1, 4, 4).repeat(D, 1, 1, num_query, 1, 1)  # (D, B, num_cam, num_query, 4, 4)

        reference_points_cam = torch.matmul(lidar2img.to(torch.float32),
                                            reference_points.to(torch.float32)).squeeze(-1)  # (D, B, num_cam, num_query, 4)

        eps = 1e-5

        # Keep only points with positive depth, then perspective-divide with
        # the depth clamped to eps to avoid dividing by ~0.
        bev_mask = (reference_points_cam[..., 2:3] > eps)  # (D, B, num_cam, num_query, 1)
        reference_points_cam = reference_points_cam[..., 0:2] / torch.maximum(
            reference_points_cam[..., 2:3], torch.ones_like(reference_points_cam[..., 2:3]) * eps)  # in pixel system

        # NOTE(review): every camera is normalized by the first camera's
        # img_shape, which assumes all views share one image size.
        reference_points_cam[..., 0] /= img_metas[0]['img_shape'][0][1]
        reference_points_cam[..., 1] /= img_metas[0]['img_shape'][0][0]

        bev_mask = (bev_mask & (reference_points_cam[..., 1:2] > 0.0)
                    & (reference_points_cam[..., 1:2] < 1.0)
                    & (reference_points_cam[..., 0:1] < 1.0)
                    & (reference_points_cam[..., 0:1] > 0.0))
        if digit_version(TORCH_VERSION) >= digit_version('1.8'):
            bev_mask = torch.nan_to_num(bev_mask)
        else:
            bev_mask = bev_mask.new_tensor(
                np.nan_to_num(bev_mask.cpu().numpy()))

        reference_points_cam = reference_points_cam.permute(2, 1, 3, 0, 4)  # (num_cam, B, num_query, D, 2)
        bev_mask = bev_mask.permute(2, 1, 3, 0, 4).squeeze(-1)  # (num_cam, B, num_query, D)

        return reference_points_cam, bev_mask

    @auto_fp16()
    def forward(self,
                bev_query,
                key,
                value,
                *args,
                bev_z=None,
                bev_h=None,
                bev_w=None,
                bev_pos=None,
                spatial_shapes=None,
                level_start_index=None,
                valid_ratios=None,
                prev_bev=None,
                shift=0.,
                **kwargs):
        """Run every encoder layer over the voxel queries.

        Args:
            bev_query (Tensor): Voxel queries of shape
                ``(num_query, bs, embed_dims)``.
            key, value (Tensor): Multi-camera image features, forwarded
                unchanged to every layer (presumably
                ``(num_cam, num_value, bs, embed_dims)``; confirm against
                the cross-attention module).
            bev_z, bev_h, bev_w (int): Spatial shape of the voxel grid.
            bev_pos (Tensor, optional): Positional encoding with the same
                layout as ``bev_query``. It is permuted to batch-first when
                given.
            spatial_shapes, level_start_index: Forwarded to every layer.
            valid_ratios: Accepted but unused.
            prev_bev (Tensor, optional): Previous-frame queries with the same
                layout as ``bev_query``. When given, each sample's previous
                and current queries are interleaved along the batch axis.
            shift (float | Sequence | Tensor): xy shift added to the
                self-attention reference points. It may be a scalar, a
                length-2 xy vector, or a ``(bs, 2)`` tensor of per-sample
                shifts. Default: 0.
            **kwargs: Must contain ``img_metas`` (see ``point_sampling``).
                Everything is forwarded to every layer.

        Returns:
            Tensor: The last layer's output, or, when ``return_intermediate``
            is True, the outputs of all layers stacked along a new dim 0.
            Layers are called with batch-first queries
            ``(bs, num_query, embed_dims)``.
        """

        output = bev_query
        intermediate = []

        ref_3d = self.get_reference_points(
            bev_z, bev_h, bev_w, self.num_points_in_voxel, dim='3d', bs=bev_query.size(1),  device=bev_query.device, dtype=bev_query.dtype)

        # the following ref_2d is actually 3d:  (bs, num_query, 1, 3)
        ref_2d = self.get_reference_points(
            bev_z, bev_h, bev_w, dim='2d', bs=bev_query.size(1), device=bev_query.device, dtype=bev_query.dtype)

        reference_points_cam, bev_mask = self.point_sampling(
            ref_3d, self.pc_range, kwargs['img_metas'])

        # Normalize ``shift`` to a (num_shift, 2) tensor. The default
        # ``shift=0.`` is a Python float, which has no ``new_zeros``.
        if not torch.is_tensor(shift):
            shift = ref_2d.new_tensor(shift)
        if shift.dim() < 2:
            shift = shift.expand(1, 2)
        # Lift the xy shift to xyz with a zero z-offset. One row per sample
        # (or a single row) broadcasts over (bs, num_query, 1, 3).
        shift3d = shift.new_zeros(shift.size(0), 3)
        shift3d[:, :2] = shift

        # bug: this code should be 'shift_ref_2d = ref_2d.clone()', we keep this bug for reproducing our results in paper.
        # Without the clone, the in-place add below also shifts ``ref_2d``
        # itself, so both halves of the hybrid reference points are shifted.
        shift_ref_2d = ref_2d  # .clone()
        shift_ref_2d += shift3d[:, None, None, :]

        # (num_query, bs, embed_dims) -> (bs, num_query, embed_dims)
        bev_query = bev_query.permute(1, 0, 2)
        if bev_pos is not None:
            bev_pos = bev_pos.permute(1, 0, 2)
        bs, num_query, num_bev_level, _ = ref_2d.shape  # (bs, num_query, 1, 3)
        if prev_bev is not None:
            prev_bev = prev_bev.permute(1, 0, 2)
            # Interleave along the batch axis: [prev_0, cur_0, prev_1, cur_1, ...].
            prev_bev = torch.stack(
                [prev_bev, bev_query], 1).reshape(bs * 2, num_query, -1)
            hybrid_ref_2d = torch.stack([shift_ref_2d, ref_2d], 1).reshape(
                bs * 2, num_query, num_bev_level, 3)
        else:
            hybrid_ref_2d = torch.stack([ref_2d, ref_2d], 1).reshape(
                bs * 2, num_query, num_bev_level, 3)

        for lid, layer in enumerate(self.layers):
            output = layer(
                bev_query,
                key,
                value,
                *args,
                bev_pos=bev_pos,
                ref_2d=hybrid_ref_2d,
                ref_3d=ref_3d,
                bev_z=bev_z,
                bev_h=bev_h,
                bev_w=bev_w,
                spatial_shapes=spatial_shapes,
                level_start_index=level_start_index,
                reference_points_cam=reference_points_cam,
                bev_mask=bev_mask,
                prev_bev=prev_bev,
                **kwargs)

            bev_query = output
            if self.return_intermediate:
                intermediate.append(output)

        if self.return_intermediate:
            return torch.stack(intermediate)

        return output


@TRANSFORMER_LAYER.register_module()
class VoxelFormerLayer(MyCustomBaseTransformerLayer):
    """One voxel-encoder layer built from self-attention, cross-attention,
    FFN and normalization sub-modules.

    The sub-modules run in the order given by ``operation_order``. That
    order must have exactly six entries and use each of ``'self_attn'``,
    ``'norm'``, ``'cross_attn'`` and ``'ffn'``, and nothing else.

    Args:
        attn_cfgs (list[`mmcv.ConfigDict`] | list[dict] | dict): Configs of
            the attention modules, in the order they appear in
            ``operation_order``. A single dict is expanded to every attention.
        feedforward_channels (int): Hidden dimension of the FFN.
        ffn_dropout (float): Dropout probability inside the FFN.
            Default: 0.0.
        operation_order (tuple[str]): Execution order of the sub-modules,
            e.g. ('self_attn', 'norm', 'cross_attn', 'norm', 'ffn', 'norm').
        act_cfg (dict): Activation config of the FFN.
            Default: dict(type='ReLU', inplace=True).
        norm_cfg (dict): Normalization config. Default: dict(type='LN').
        ffn_num_fcs (int): Number of fully-connected layers in the FFN.
            Default: 2.
    """

    def __init__(self,
                 attn_cfgs,
                 feedforward_channels,
                 ffn_dropout=0.0,
                 operation_order=None,
                 act_cfg=dict(type='ReLU', inplace=True),
                 norm_cfg=dict(type='LN'),
                 ffn_num_fcs=2,
                 **kwargs):
        super().__init__(
            attn_cfgs=attn_cfgs,
            feedforward_channels=feedforward_channels,
            ffn_dropout=ffn_dropout,
            operation_order=operation_order,
            act_cfg=act_cfg,
            norm_cfg=norm_cfg,
            ffn_num_fcs=ffn_num_fcs,
            **kwargs)
        self.fp16_enabled = False
        required_ops = {'self_attn', 'norm', 'cross_attn', 'ffn'}
        assert len(operation_order) == 6
        assert set(operation_order) == required_ops

    def forward(self,
                query,
                key=None,
                value=None,
                bev_pos=None,
                query_pos=None,
                key_pos=None,
                attn_masks=None,
                query_key_padding_mask=None,
                key_padding_mask=None,
                ref_2d=None,
                ref_3d=None,
                bev_z=None,
                bev_h=None,
                bev_w=None,
                reference_points_cam=None,
                mask=None,
                spatial_shapes=None,
                level_start_index=None,
                prev_bev=None,
                **kwargs):
        """Apply the sub-modules in ``self.operation_order`` to ``query``.

        Extra keyword arguments are forwarded to both attention modules.

        Args:
            query (Tensor): Voxel queries. The encoder in this file passes
                them batch-first, ``(bs, num_query, embed_dims)``.
            key, value (Tensor): Image features for the cross-attention.
            bev_pos (Tensor): Positional encoding, used as both query and key
                positions of the self-attention.
            query_pos, key_pos (Tensor): Positional encodings for the
                cross-attention.
            attn_masks (list[Tensor] | Tensor | None): One mask per
                attention. A single tensor is deep-copied to every attention
                and triggers a warning. ``None`` means no masks.
            query_key_padding_mask (Tensor): Padding mask passed as
                ``key_padding_mask`` to the self-attention.
            key_padding_mask (Tensor): Padding mask for ``key``, passed to
                the cross-attention.
            ref_2d (Tensor): Reference points of the self-attention.
            ref_3d (Tensor): Reference points of the cross-attention.
            bev_z, bev_h, bev_w (int): Voxel-grid shape. It becomes the
                single-level ``spatial_shapes`` of the self-attention.
            reference_points_cam (Tensor): Camera-projected reference points
                for the cross-attention.
            mask: Passed as ``mask`` to the cross-attention.
                NOTE(review): the encoder supplies its visibility mask as
                ``bev_mask``, which reaches both attentions through
                ``**kwargs`` rather than through this parameter.
            spatial_shapes, level_start_index: Feature-level metadata for
                the cross-attention.
            prev_bev (Tensor): Used as key and value of the self-attention.

        Returns:
            Tensor: The updated queries.
        """

        if attn_masks is None:
            attn_masks = [None] * self.num_attn
        elif isinstance(attn_masks, torch.Tensor):
            attn_masks = [
                copy.deepcopy(attn_masks) for _ in range(self.num_attn)
            ]
            warnings.warn(
                f'Use same attn_mask in all attentions in {self.__class__.__name__} ')
        else:
            assert len(attn_masks) == self.num_attn, (
                f'The length of attn_masks {len(attn_masks)} must be equal '
                f'to the number of attention in operation_order {self.num_attn}')

        identity = query
        attn_idx = 0
        norm_idx = 0
        ffn_idx = 0

        for op in self.operation_order:
            if op == 'norm':
                query = self.norms[norm_idx](query)
                norm_idx += 1
                continue

            if op == 'ffn':
                residual = identity if self.pre_norm else None
                query = self.ffns[ffn_idx](query, residual)
                ffn_idx += 1
                continue

            residual = identity if self.pre_norm else None
            if op == 'self_attn':
                # Temporal self-attention over [prev_bev, current] pairs on a
                # single voxel-grid level.
                grid_shape = torch.tensor(
                    [[bev_z, bev_h, bev_w]], device=query.device)
                level_start = torch.tensor([0], device=query.device)
                query = self.attentions[attn_idx](
                    query,
                    prev_bev,
                    prev_bev,
                    residual,
                    query_pos=bev_pos,
                    key_pos=bev_pos,
                    attn_mask=attn_masks[attn_idx],
                    key_padding_mask=query_key_padding_mask,
                    reference_points=ref_2d,
                    spatial_shapes=grid_shape,
                    level_start_index=level_start,
                    **kwargs)
            elif op == 'cross_attn':
                # Spatial cross-attention from voxel queries into image features.
                query = self.attentions[attn_idx](
                    query,
                    key,
                    value,
                    residual,
                    query_pos=query_pos,
                    key_pos=key_pos,
                    reference_points=ref_3d,
                    reference_points_cam=reference_points_cam,
                    mask=mask,
                    attn_mask=attn_masks[attn_idx],
                    key_padding_mask=key_padding_mask,
                    spatial_shapes=spatial_shapes,
                    level_start_index=level_start_index,
                    **kwargs)
            else:
                continue

            # Both attention kinds advance the shared attention counter and
            # become the residual identity for what follows.
            attn_idx += 1
            identity = query

        return query
