from mmengine.registry import MODELS
from mmengine.model import BaseModule
import torch.nn.functional as F
import torch, torch.nn as nn
from .blocks import gen_sineembed_for_position, linear_relu_ln

# from .vis_utils import draw_scenes


@MODELS.register_module()
class ImplicitFlattenFusion(BaseModule):
    """Fuse flattened image and point-cloud tokens into implicit query features.

    Image and point tokens are concatenated along dim 1 and offset by a
    learnable positional embedding ``pos_emb`` (zero-initialised). The implicit
    queries are then refined in two steps:

    1. ``cross_attention_layer``: a standard ``nn.TransformerDecoderLayer``
       (self-attention over the queries, cross-attention to the fused tokens,
       then a feed-forward block). The encoded anchor feature is added to the
       query input before this layer.
    2. ``self_attention_layer``: a ``CrossAttentionDecoderLayer`` called with
       the queries as both target and memory. This is effectively an extra
       self-attention + feed-forward step, with the anchor feature used as the
       positional embedding for both queries and keys.

    NOTE: the two attribute names do not describe their contents well.
    ``cross_attention_layer`` also performs self-attention, and
    ``self_attention_layer`` is built from a cross-attention module. They are
    kept as-is because they are part of the checkpoint (state-dict) keys.

    Args:
        embed_dims: Channel width of every token. Must be divisible by
            ``num_groups``.
        num_groups: Number of attention heads in both transformer layers.
        attn_drop: Dropout probability passed to both transformer layers.
        img_feature_size: Three ints whose product is the number of image
            tokens (presumably e.g. (num_cams, H, W); TODO confirm).
        pts_feature_size: Two ints whose product is the number of point-cloud
            tokens (presumably (H, W) of a BEV map; TODO confirm).
        **kwargs: Accepted and ignored.

    Raises:
        ValueError: If ``embed_dims`` is not divisible by ``num_groups``.
    """

    def __init__(
        self,
        embed_dims: int = 256,
        num_groups: int = 8,
        attn_drop: float = 0.0,
        # Tuples instead of lists: immutable defaults cannot leak state
        # between instances. Only indexing is used, so lists still work.
        img_feature_size=(3, 8, 14),
        pts_feature_size=(8, 8),
        **kwargs,
    ):
        super().__init__()
        if embed_dims % num_groups != 0:
            raise ValueError(
                f"embed_dims must be divisible by num_groups, "
                f"but got {embed_dims} and {num_groups}"
            )

        self.embed_dims = embed_dims

        # One learnable embedding per fused token: image tokens first, then
        # point tokens, matching the concatenation order in ``forward``.
        num_img_tokens = (
            img_feature_size[0] * img_feature_size[1] * img_feature_size[2]
        )
        num_pts_tokens = pts_feature_size[0] * pts_feature_size[1]
        self.pos_emb = nn.Parameter(
            torch.zeros(1, num_img_tokens + num_pts_tokens, embed_dims)
        )

        # Projects the sine-encoded anchor into the embedding space.
        self.anchor_pos_encoder = nn.Sequential(
            *linear_relu_ln(embed_dims, 1, 1, embed_dims),
            nn.Linear(embed_dims, embed_dims),
        )

        # Full decoder layer: self-attn over queries + cross-attn to tokens.
        self.cross_attention_layer = nn.TransformerDecoderLayer(
            d_model=embed_dims,
            nhead=num_groups,
            dim_feedforward=embed_dims * 4,
            dropout=attn_drop,
            batch_first=True,
        )

        # Used as query-to-query attention (tgt == memory in ``forward``).
        self.self_attention_layer = CrossAttentionDecoderLayer(
            d_model=embed_dims,
            nhead=num_groups,
            dim_feedforward=embed_dims * 4,
            dropout=attn_drop,
        )

    def init_weights(self):
        """Re-initialise every ``nn.Linear`` and ``nn.LayerNorm`` submodule.

        Linear weights are drawn from N(0, 0.02) and their biases are zeroed.
        LayerNorm weights are set to 1 and biases to 0.

        Parameters that do not belong to these module types are left
        untouched. This includes ``pos_emb`` and the packed input projection
        of ``nn.MultiheadAttention``, which is a raw parameter rather than an
        ``nn.Linear``.
        """
        for module in self.modules():
            if isinstance(module, nn.Linear):
                module.weight.data.normal_(mean=0.0, std=0.02)
                if module.bias is not None:
                    module.bias.data.zero_()
            elif isinstance(module, nn.LayerNorm):
                # A LayerNorm built without affine params (or without a bias)
                # has ``None`` here; skip it instead of raising AttributeError.
                if module.bias is not None:
                    module.bias.data.zero_()
                if module.weight is not None:
                    module.weight.data.fill_(1.0)

    def forward(
        self,
        instance_feature: torch.Tensor,
        implicit_feature: torch.Tensor,
        anchor: torch.Tensor,
        img_features: torch.Tensor,
        pts_features: torch.Tensor,
        **kwargs: dict,
    ):
        """Refine ``implicit_feature`` using the fused image/point tokens.

        Args:
            instance_feature: Not used by this module. It is accepted so that
                callers can pass a common argument set.
            implicit_feature: Implicit query tokens, batch-first; presumably
                (B, Q, embed_dims) (TODO confirm).
            anchor: Anchors fed to ``gen_sineembed_for_position``. The encoded
                result must broadcast against ``implicit_feature``.
            img_features: Flattened image tokens. They are concatenated with
                ``pts_features`` along dim 1, and the combined token count must
                match ``pos_emb``.
            pts_features: Flattened point-cloud tokens.
            **kwargs: Ignored.

        Returns:
            The refined implicit features, with the shape of
            ``implicit_feature + anchor_pos_feature``.
        """
        feature_map = torch.cat([img_features, pts_features], dim=1) + self.pos_emb
        anchor_embed = gen_sineembed_for_position(anchor, self.embed_dims)
        anchor_pos_feature = self.anchor_pos_encoder(anchor_embed)

        # The positional feature is added to the decoder-layer input itself,
        # so it also flows through that layer's residual connections.
        implicit_feature = self.cross_attention_layer(
            implicit_feature + anchor_pos_feature, feature_map
        )
        # Query-to-query attention: the same tensor is used as target and
        # memory, with the anchor feature as positional embedding on both.
        implicit_feature = self.self_attention_layer(
            implicit_feature,
            implicit_feature,
            pos=anchor_pos_feature,
            query_pos=anchor_pos_feature,
        )
        return implicit_feature


class CrossAttentionDecoderLayer(nn.Module):
    """Post-norm attention + feed-forward layer with separate positional inputs.

    The queries attend to ``memory`` through a batch-first
    ``nn.MultiheadAttention``. ``query_pos`` is added to the queries and
    ``pos`` to the keys; the values are the raw ``memory``. Each sub-block
    (attention, then feed-forward) is followed by a residual add and a
    LayerNorm.

    Args:
        d_model: Token channel width.
        nhead: Number of attention heads.
        dim_feedforward: Hidden width of the feed-forward block.
        dropout: Dropout probability for the attention weights, both residual
            branches, and the feed-forward hidden activation.
    """

    def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1):
        super().__init__()

        # Attention from target tokens to memory tokens (batch-first layout).
        self.cross_attn = nn.MultiheadAttention(
            d_model, nhead, dropout=dropout, batch_first=True
        )

        # Position-wise FFN: expand -> ReLU -> dropout -> project back.
        self.linear1 = nn.Linear(d_model, dim_feedforward)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(dim_feedforward, d_model)

        # Post-residual normalisation for the attention and FFN sub-blocks.
        self.norm1, self.norm2 = nn.LayerNorm(d_model), nn.LayerNorm(d_model)

        # Dropout on the attention and FFN residual branches respectively.
        self.dropout1, self.dropout2 = nn.Dropout(dropout), nn.Dropout(dropout)

    def forward(
        self,
        tgt,
        memory,
        memory_mask=None,
        memory_key_padding_mask=None,
        pos=None,
        query_pos=None,
    ):
        """Attend from ``tgt`` to ``memory``, then apply the feed-forward block.

        Args:
            tgt: Target/query tokens, (batch, tgt_len, d_model).
            memory: Tokens to attend to, (batch, src_len, d_model).
            memory_mask: Optional ``attn_mask`` forwarded to the attention.
            memory_key_padding_mask: Optional ``key_padding_mask`` forwarded
                to the attention.
            pos: Optional positional embedding added to the keys only,
                (batch, src_len, d_model).
            query_pos: Optional positional embedding added to the queries
                only, (batch, tgt_len, d_model).

        Returns:
            Tensor with the same shape as ``tgt``.
        """
        # Positional terms go into queries/keys; values stay position-free.
        query = tgt if query_pos is None else tgt + query_pos
        key = memory if pos is None else memory + pos

        attended, _ = self.cross_attn(
            query,
            key,
            value=memory,
            attn_mask=memory_mask,
            key_padding_mask=memory_key_padding_mask,
        )
        hidden = self.norm1(tgt + self.dropout1(attended))

        ffn_out = self.linear1(hidden)
        ffn_out = F.relu(ffn_out)
        ffn_out = self.dropout(ffn_out)
        ffn_out = self.linear2(ffn_out)
        return self.norm2(hidden + self.dropout2(ffn_out))
