import numpy as np
import torch
import torch.nn as nn
from mmengine.model.weight_init import xavier_init, constant_init
from mmcv.cnn.bricks.transformer import build_transformer_layer_sequence
from mmengine.model import BaseModule

from mmengine.registry import MODELS


@MODELS.register_module()
class PerceptionTransformer(BaseModule):
    """Spatial-only BEV encoder wrapper (BEVFormer-style perception transformer).

    Flattens multi-camera, multi-level image features, adds learnable camera
    and level embeddings, and runs the configured encoder to turn BEV queries
    into BEV features. The temporal parts of the BEVFormer-style original
    (ego-motion BEV shift, previous-BEV rotation, CAN-bus fusion) are disabled.

    Args:
        num_feature_levels (int): Number of feature levels in ``mlvl_feats``
            (e.g. FPN outputs). Default: 4.
        num_cams (int): Number of cameras. Default: 6.
        encoder (dict): Config passed to ``build_transformer_layer_sequence``.
        embed_dims (int): Channel dimension of image features and queries.
            Default: 256.
        use_can_bus (bool): Stored for config compatibility; CAN-bus fusion
            is currently disabled in ``get_bev_features``. Default: False.
        can_bus_norm (bool): Append a LayerNorm to ``can_bus_mlp``.
            Default: False.
        use_cams_embeds (bool): Add per-camera embeddings to the image
            features. Default: True.
        **kwargs: Forwarded to ``BaseModule`` (e.g. ``init_cfg``).
    """

    # Attention modules that carry their own initialization and must be
    # re-initialized after the blanket xavier init in ``init_weights``.
    # Matched by class name (including base classes, to mirror ``isinstance``)
    # because these classes are not importable from this module.
    _SELF_INIT_ATTN_NAMES = frozenset({
        'MSDeformableAttention3D',
        'TemporalSelfAttention',
        'CustomMSDeformableAttention',
    })

    def __init__(self,
                 num_feature_levels=4,
                 num_cams=6,
                 encoder=None,
                 embed_dims=256,
                 use_can_bus=False,
                 can_bus_norm=False,
                 use_cams_embeds=True,
                 **kwargs):
        super().__init__(**kwargs)
        self.encoder = build_transformer_layer_sequence(encoder)
        self.embed_dims = embed_dims
        self.num_feature_levels = num_feature_levels
        self.num_cams = num_cams

        self.use_can_bus = use_can_bus
        self.can_bus_norm = can_bus_norm
        self.use_cams_embeds = use_cams_embeds

        self.init_layers()

    def init_layers(self):
        """Create the learnable level/camera embeddings and the CAN-bus MLP.

        The embeddings are allocated uninitialized; ``init_weights`` fills
        them.
        """
        self.level_embeds = nn.Parameter(
            torch.empty(self.num_feature_levels, self.embed_dims))
        self.cams_embeds = nn.Parameter(
            torch.empty(self.num_cams, self.embed_dims))
        # ``reference_points`` is only needed by a decoder, so it is not
        # created here.
        # NOTE(review): a purely spatial (camera-feature) encoder should not
        # need ``can_bus_mlp``, and ``get_bev_features`` never calls it.
        # Presumably it is kept for checkpoint/config compatibility; its
        # parameters receive no gradients, so DDP may need
        # ``find_unused_parameters=True``.
        self.can_bus_mlp = nn.Sequential(
            nn.Linear(18, self.embed_dims // 2),
            nn.ReLU(inplace=True),
            nn.Linear(self.embed_dims // 2, self.embed_dims),
            nn.ReLU(inplace=True),
        )
        if self.can_bus_norm:
            self.can_bus_mlp.add_module('norm', nn.LayerNorm(self.embed_dims))

    def init_weights(self):
        """Initialize the transformer weights.

        Steps, in order:
        1. Xavier-uniform init of every parameter with more than one dim.
        2. Restore the custom init of the deformable/temporal attention
           modules that step 1 overwrote.
        3. Standard-normal init of the level and camera embeddings.
        4. Xavier-uniform weights and zero biases for the Linear layers of
           ``can_bus_mlp``.
        """
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)

        for m in self.modules():
            if not any(cls.__name__ in self._SELF_INIT_ATTN_NAMES
                       for cls in type(m).__mro__):
                continue
            # Implementations spell this either ``init_weight`` or
            # ``init_weights``. Look the method up explicitly rather than
            # catching AttributeError, which would also hide errors raised
            # from inside the init routine itself.
            init_fn = getattr(m, 'init_weight', None)
            if init_fn is None:
                init_fn = getattr(m, 'init_weights', None)
            if init_fn is not None:
                init_fn()

        nn.init.normal_(self.level_embeds)
        nn.init.normal_(self.cams_embeds)

        # mmengine's ``xavier_init`` only touches the ``weight``/``bias`` of
        # the module it is given, so calling it on the nn.Sequential itself
        # would do nothing. Apply it to each Linear layer instead.
        for layer in self.can_bus_mlp:
            if isinstance(layer, nn.Linear):
                xavier_init(layer, distribution='uniform', bias=0.)

    def get_bev_features(
            self,
            mlvl_feats,
            bev_queries,
            bev_h,
            bev_w,
            grid_length=(0.512, 0.512),
            bev_pos=None,
            lidar2image=None,
            lidar_aug_matrix=None,
            metas=None,
            **kwargs):
        """Encode multi-camera image features into BEV features.

        Args:
            mlvl_feats (list[Tensor]): Per-level image features, each of
                shape (bs, num_cams, C, H, W).
            bev_queries (Tensor): BEV queries of shape (num_query, C). They
                are repeated to (num_query, bs, C) before encoding.
            bev_h (int): BEV grid height, forwarded to the encoder.
            bev_w (int): BEV grid width, forwarded to the encoder.
            grid_length (Sequence[float]): Size of one BEV cell (the caller
                passes ``(real_h / bev_h, real_w / bev_w)``). Currently
                unused; it was only needed for the disabled ego-motion shift.
            bev_pos (Tensor, optional): BEV positional encoding, reshaped via
                ``flatten(2).permute(2, 0, 1)``. This is presumably
                (bs, C, bev_h, bev_w) -> (bev_h * bev_w, bs, C); confirm
                against the caller. If None, it is passed to the encoder as
                None.
            lidar2image: Forwarded unchanged to the encoder.
            lidar_aug_matrix: Forwarded unchanged to the encoder.
            metas: Forwarded unchanged to the encoder.
            **kwargs: Forwarded unchanged to the encoder.

        Returns:
            The output of ``self.encoder`` (the BEV embedding).
        """
        bs = mlvl_feats[0].size(0)
        # Take the device from the image features so that ``bev_pos`` may be
        # None.
        device = mlvl_feats[0].device

        # (num_query, C) -> (num_query, bs, C)
        bev_queries = bev_queries.unsqueeze(1).repeat(1, bs, 1)
        if bev_pos is not None:
            bev_pos = bev_pos.flatten(2).permute(2, 0, 1)

        # NOTE: The temporal logic of the BEVFormer-style original is removed
        # because this encoder is purely spatial. That logic is the BEV shift
        # computed from CAN-bus ego motion (which used ``grid_length``) and
        # the rotation of ``prev_bev``.
        # TODO: CAN-bus fusion is disabled because the current data carries
        # no CAN-bus signal. The disabled fusion was:
        #   can_bus = bev_queries.new_tensor(
        #       [each['can_bus'] for each in kwargs['img_metas']])
        #   can_bus = self.can_bus_mlp(can_bus)[None, :, :]
        #   bev_queries = bev_queries + can_bus * self.use_can_bus

        feat_flatten = []
        spatial_shapes = []
        for lvl, feat in enumerate(mlvl_feats):
            h, w = feat.shape[-2:]
            spatial_shapes.append((h, w))
            # (bs, num_cam, C, H, W) -> (num_cam, bs, H*W, C)
            feat = feat.flatten(3).permute(1, 0, 3, 2)
            if self.use_cams_embeds:
                feat = feat + self.cams_embeds[:, None, None, :].to(feat.dtype)
            feat = feat + self.level_embeds[None, None,
                                            lvl:lvl + 1, :].to(feat.dtype)
            feat_flatten.append(feat)

        # (num_cam, bs, sum(H*W), C)
        feat_flatten = torch.cat(feat_flatten, 2)
        spatial_shapes = torch.as_tensor(
            spatial_shapes, dtype=torch.long, device=device)
        # Offset of each level's first token within the flattened sequence.
        level_start_index = torch.cat((spatial_shapes.new_zeros(
            (1,)), spatial_shapes.prod(1).cumsum(0)[:-1]))

        # (num_cam, sum(H*W), bs, C)
        feat_flatten = feat_flatten.permute(0, 2, 1, 3)

        bev_embed = self.encoder(
            bev_queries,
            feat_flatten,
            feat_flatten,
            bev_h=bev_h,
            bev_w=bev_w,
            bev_pos=bev_pos,
            spatial_shapes=spatial_shapes,
            level_start_index=level_start_index,
            lidar2image=lidar2image,
            lidar_aug_matrix=lidar_aug_matrix,
            metas=metas,
            **kwargs
        )

        return bev_embed