import numpy as np
import torch
import copy
import warnings
from mmcv.cnn.bricks.transformer import TransformerLayerSequence, build_attention
from mmcv.utils import ext_loader

from mmengine.registry import MODELS
from mmengine.model import ModuleList

from .custom_base_transformer_layer import MyCustomBaseTransformerLayer

# Load the `_ext` extension module through mmcv's `ext_loader`, requesting the
# `ms_deform_attn_backward` / `ms_deform_attn_forward` functions by name.
# NOTE(review): `ext_module` is not referenced in the visible part of this
# file -- confirm whether it is still needed before removing it.
ext_module = ext_loader.load_ext(
    '_ext', ['ms_deform_attn_backward', 'ms_deform_attn_forward'])


# MGFM layer implementation
@MODELS.register_module()
class MM_BEVFormerLayer(MyCustomBaseTransformerLayer):
    """Multi-modality fusion transformer layer.

    Each ``'cross_attn'`` entry of ``operation_order`` is expanded into
    ``num_modality`` attention modules, one per modality. Their outputs are
    fused with the learnable weights ``cross_model_weights`` (initialised to
    ``1 / num_modality`` each) and normalised by the sum of those weights.
    Each ``'self_attn'`` entry uses a single attention module whose key and
    value are the current query reshaped to ``(bs, bev_h, bev_w, channel)``.

    Args:
        attn_cfgs (Sequence[dict]): Attention configs, consumed in the order
            the attentions appear in ``operation_order``: one config per
            ``'self_attn'`` and ``num_modality`` consecutive configs per
            ``'cross_attn'``. ``batch_first`` is written into every config
            that does not already define it.
        feedforward_channels (int): Forwarded to the base layer.
        ffn_dropout (float): Forwarded to the base layer. Default: 0.0.
        num_modality (int): Number of modalities fused by every
            ``'cross_attn'`` step. Default: 2.
        operation_order (tuple[str]): Sequence of 4 or 6 operations drawn
            from ``'norm'``, ``'cross_attn'``, ``'ffn'`` and ``'self_attn'``.
        act_cfg (dict): Forwarded to the base layer.
        norm_cfg (dict): Forwarded to the base layer.
        ffn_num_fcs (int): Forwarded to the base layer. Default: 2.
        **kwargs: Forwarded to the base layer.
    """

    def __init__(self,
                 attn_cfgs,
                 feedforward_channels,
                 ffn_dropout=0.0,
                 num_modality=2,
                 operation_order=None,
                 act_cfg=dict(type='ReLU', inplace=True),
                 norm_cfg=dict(type='LN'),
                 ffn_num_fcs=2,
                 **kwargs):
        super(MM_BEVFormerLayer, self).__init__(
            attn_cfgs=attn_cfgs,
            feedforward_channels=feedforward_channels,
            ffn_dropout=ffn_dropout,
            operation_order=operation_order,
            act_cfg=act_cfg,
            norm_cfg=norm_cfg,
            ffn_num_fcs=ffn_num_fcs,
            **kwargs)
        assert len(operation_order) == 4 or len(operation_order) == 6
        assert set(operation_order) <= {'norm', 'cross_attn', 'ffn', 'self_attn'}, "attention fusion operation not found"

        self.num_modality = num_modality

        # Every cross-attention step owns one attention module per modality.
        self.num_attn = (
            operation_order.count('self_attn')
            + operation_order.count('cross_attn') * self.num_modality)

        # Rebuild the attention list so that cross-attention steps get
        # `num_modality` modules each.
        self.attentions = ModuleList()
        index = 0
        for operation_name in operation_order:
            if operation_name == 'cross_attn':
                num_cfgs = self.num_modality
            elif operation_name == 'self_attn':
                num_cfgs = 1
            else:
                continue

            for offset in range(num_cfgs):
                attn_cfg = attn_cfgs[index + offset]
                # Validate / default `batch_first` on EVERY config, including
                # those of the second and later modalities.
                self._sync_batch_first(attn_cfg)
                attention = build_attention(attn_cfg)
                attention.operation_name = operation_name
                self.attentions.append(attention)
            index += num_cfgs

        # Learnable per-modality fusion weights, initialised uniformly.
        base_weight = 1.0 / self.num_modality
        self.cross_model_weights = torch.nn.Parameter(
            torch.full((self.num_modality,), base_weight),
            requires_grad=True)

    def _sync_batch_first(self, attn_cfg):
        """Make ``attn_cfg['batch_first']`` agree with ``self.batch_first``.

        Asserts equality when the key is already present, otherwise writes
        ``self.batch_first`` into ``attn_cfg`` in place.
        """
        if 'batch_first' in attn_cfg:
            assert self.batch_first == attn_cfg['batch_first']
        else:
            attn_cfg['batch_first'] = self.batch_first

    def forward(self,
                query,
                multi_modality_keys=None,
                multi_modality_values=None,
                value=None,
                attn_masks=None,
                bev_h=None,
                bev_w=None,
                reference_points=None,
                spatial_shapes=None,
                level_start_index=None,
                positional_encoding=None,
                **kwargs):
        """Run the operations of ``operation_order`` on ``query``.

        **kwargs contains some specific arguments of attentions.

        Args:
            query (Tensor): The input query with shape
                [num_queries, bs, embed_dims] if
                self.batch_first is False, else
                [bs, num_queries embed_dims].
                NOTE(review): the ``'self_attn'`` branch unpacks
                ``query.size()`` as ``(bs, num_query, channel)``, i.e. it
                assumes the batch-first layout -- confirm with callers.
            multi_modality_keys (List[Tensor]): BEV feature map of each
                modality; each element is [bs, h, w, embed_dims].
            multi_modality_values (List[Tensor]): Initially a deep copy of
                the keys; BEV feature map of each modality, each element
                is [bs, h, w, embed_dims].
            value (Tensor, optional): Not read by this method; kept in the
                signature for interface compatibility.
            attn_masks (None | Tensor | List[Tensor | None]): ``None`` means
                no mask for any attention; a single tensor is deep-copied
                for every attention (with a warning); a list must contain
                exactly ``self.num_attn`` entries, ordered like
                ``self.attentions`` (one entry per modality for each
                ``'cross_attn'``).
            bev_h (int): BEV height; used to reshape the query in
                ``'self_attn'`` and forwarded to the attentions.
            bev_w (int): BEV width; used like ``bev_h``.
            reference_points (Tensor): [h*w, 2]; forwarded to the attentions.
            spatial_shapes: Forwarded to the attentions.
            level_start_index: Forwarded to the attentions.
            positional_encoding: Forwarded to the attentions.

        Returns:
            Tensor: The query after all operations have been applied.
            NOTE(review): the original documentation stated the shape
            [num_queries, bs, embed_dims]; confirm against the layout
            assumed by the ``'self_attn'`` branch.
        """
        norm_index = 0
        attn_index = 0
        ffn_index = 0
        identity = query
        if attn_masks is None:
            attn_masks = [None for _ in range(self.num_attn)]
        elif isinstance(attn_masks, torch.Tensor):
            attn_masks = [
                copy.deepcopy(attn_masks) for _ in range(self.num_attn)
            ]
            warnings.warn(f'Use same attn_mask in all attentions in '
                          f'{self.__class__.__name__} ')
        else:
            assert len(attn_masks) == self.num_attn, f'The length of ' \
                                                     f'attn_masks {len(attn_masks)} must be equal ' \
                                                     f'to the number of attention in ' \
                                                     f'operation_order {self.num_attn}'

        for layer in self.operation_order:
            if layer == 'norm':
                query = self.norms[norm_index](query)
                norm_index += 1

            elif layer == 'cross_attn':
                # Accumulate the weighted outputs of the per-modality
                # attentions, all of which see the same input query.
                new_query = torch.zeros_like(query)
                for i in range(len(multi_modality_keys)):
                    modality_attn_index = attn_index + i
                    single_modality_query = self.attentions[modality_attn_index](
                        query,
                        multi_modality_keys[i],
                        multi_modality_values[i],
                        identity if self.pre_norm else None,
                        reference_points=reference_points,
                        # Each modality owns its own mask slot.
                        attn_mask=attn_masks[modality_attn_index],
                        bev_h=bev_h,
                        bev_w=bev_w,
                        spatial_shapes=spatial_shapes,
                        level_start_index=level_start_index,
                        positional_encoding=positional_encoding,
                        **kwargs)

                    new_query = new_query + self.cross_model_weights[i] * single_modality_query

                # Normalise by the sum of all modality weights.
                weights_sum = self.cross_model_weights.sum()
                query = new_query / weights_sum
                attn_index += self.num_modality
                identity = query

            elif layer == 'ffn':
                query = self.ffns[ffn_index](
                    query, identity if self.pre_norm else None)
                ffn_index += 1

            elif layer == 'self_attn':
                # Key and value are independent copies of the query laid out
                # on the BEV grid.
                bs, num_query, channel = query.size()
                self_key = query.clone().reshape(bs, bev_h, bev_w, channel)
                self_value = query.clone().reshape(bs, bev_h, bev_w, channel)
                query = self.attentions[attn_index](
                    query,
                    self_key,
                    self_value,
                    identity if self.pre_norm else None,
                    reference_points=reference_points,
                    attn_mask=attn_masks[attn_index],
                    bev_h=bev_h,
                    bev_w=bev_w,
                    spatial_shapes=spatial_shapes,
                    level_start_index=level_start_index,
                    positional_encoding=positional_encoding,
                    **kwargs)

                attn_index += 1
                identity = query

        return query
