import logging
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import List, Optional, Tuple, Type
from sam2.modeling.sam2_utils import LayerNorm2d, MLP


class BoundaryDecoder(nn.Module):
    """Predict boundary maps from image and prompt embeddings with a transformer.

    A simplified variant of SAM2's MaskDecoder specialised for boundary
    segmentation. The decoder works in four steps:

    1. Learnable boundary tokens are concatenated with the sparse prompt
       embeddings.
    2. The tokens are run through ``transformer`` together with the
       (dense-prompt-augmented) image embedding.
    3. Each resulting boundary token is mapped by its own hypernetwork MLP to
       a per-pixel linear head.
    4. That head is applied to the upsampled image features.
    """

    def __init__(
        self,
        *,
        transformer_dim: int,
        transformer: nn.Module,
        num_multimask_outputs: int = 3,  # kept consistent with MaskDecoder; number of boundary hypotheses
        activation: Type[nn.Module] = nn.GELU,
        use_high_res_features: bool = False,
    ) -> None:
        """
        Arguments:
          transformer_dim (int): channel dimension of the transformer.
          transformer (nn.Module): transformer used to predict boundaries. It is
            called as ``transformer(src, pos_src, tokens)`` and must return
            ``(hs, src)`` with ``hs`` shaped (B, num_tokens, C) and ``src``
            shaped (B, H*W, C).
          num_multimask_outputs (int): number of boundary maps (and boundary
            tokens) predicted.
          activation (nn.Module): activation class used in the upscaling module.
          use_high_res_features (bool): whether high-resolution encoder feature
            maps are fused into the upscaling path to recover fine detail.
        """
        super().__init__()
        self.transformer_dim = transformer_dim
        self.transformer = transformer

        self.num_boundary_tokens = num_multimask_outputs
        self.boundary_tokens = nn.Embedding(self.num_boundary_tokens, transformer_dim)

        # 4x spatial upsampling of the transformer output (two stride-2
        # transposed convs). This mirrors MaskDecoder.output_upscaling. In SAM2
        # it takes the stride-16 embedding to 1/4 input resolution; this
        # assumes the same encoder stride here.
        self.output_upscaling = nn.Sequential(
            nn.ConvTranspose2d(
                transformer_dim, transformer_dim // 4, kernel_size=2, stride=2
            ),
            LayerNorm2d(transformer_dim // 4),
            activation(),
            nn.ConvTranspose2d(
                transformer_dim // 4, transformer_dim // 8, kernel_size=2, stride=2
            ),
            activation(),
        )

        # 1x1 projections that bring high-res encoder features to the channel
        # widths of the upscaling stages.
        # NOTE(review): these are NOT applied inside this module. Callers are
        # presumably expected to project their features with them before
        # passing ``high_res_features`` (as SAM2Base does for MaskDecoder).
        # Confirm against the caller.
        self.use_high_res_features = use_high_res_features
        if use_high_res_features:
            self.conv_s1 = nn.Conv2d(  # matches the output of the first ConvTranspose2d
                transformer_dim, transformer_dim // 4, kernel_size=1, stride=1
            )
            self.conv_s0 = nn.Conv2d(  # matches the output of the second ConvTranspose2d
                transformer_dim, transformer_dim // 8, kernel_size=1, stride=1
            )

        # One hypernetwork MLP per boundary token; each produces the weights of
        # a per-pixel linear head over the transformer_dim // 8 upscaled channels.
        self.output_hypernetworks_mlps = nn.ModuleList(
            [
                MLP(transformer_dim, transformer_dim, transformer_dim // 8, 3)
                for _ in range(self.num_boundary_tokens)
            ]
        )

    def forward(
        self,
        image_embeddings: torch.Tensor,
        image_pe: torch.Tensor,
        sparse_prompt_embeddings: torch.Tensor,
        dense_prompt_embeddings: torch.Tensor,
        multimask_output: bool,
        repeat_image: bool,
        high_res_features: Optional[List[torch.Tensor]] = None,
    ) -> torch.Tensor:
        """
        Predict boundaries given image and prompt embeddings.

        Arguments:
          image_embeddings (torch.Tensor): output embeddings of the image encoder.
          image_pe (torch.Tensor): positional encoding of the image embeddings.
            It must have batch size 1.
          sparse_prompt_embeddings (torch.Tensor): embeddings of sparse prompts
            (points, boxes), shaped (B, N, C).
          dense_prompt_embeddings (torch.Tensor): embeddings of dense prompts
            (masks). They are added to the image embedding.
          multimask_output (bool): return all boundary maps if True, otherwise
            only the first one.
          repeat_image (bool): repeat the image embedding along the batch
            dimension to match the number of prompts.
          high_res_features (Optional[List[torch.Tensor]]): ``[feat_s0, feat_s1]``
            high-resolution encoder features. These are required when the
            decoder was built with ``use_high_res_features=True``.

        Returns:
          torch.Tensor: boundary logits, shaped (B, num_multimask_outputs, H', W'),
          or (B, 1, H', W') when ``multimask_output`` is False.

        Raises:
          ValueError: if high-res fusion is enabled but ``high_res_features``
            is missing or does not contain exactly two tensors.
        """
        boundaries = self.predict_boundaries(
            image_embeddings=image_embeddings,
            image_pe=image_pe,
            sparse_prompt_embeddings=sparse_prompt_embeddings,
            dense_prompt_embeddings=dense_prompt_embeddings,
            repeat_image=repeat_image,
            high_res_features=high_res_features,
        )

        # Select either all hypotheses or only the first one (kept 4-D).
        if multimask_output:
            return boundaries
        return boundaries[:, 0:1, :, :]

    def predict_boundaries(
        self,
        image_embeddings: torch.Tensor,
        image_pe: torch.Tensor,
        sparse_prompt_embeddings: torch.Tensor,
        dense_prompt_embeddings: torch.Tensor,
        repeat_image: bool,
        high_res_features: Optional[List[torch.Tensor]] = None,
    ) -> torch.Tensor:
        """Core boundary prediction.

        Returns all ``num_boundary_tokens`` boundary maps as a tensor shaped
        (B, num_boundary_tokens, H', W').
        """
        # Fail fast, before spending compute in the transformer. Without this
        # check, tuple-unpacking None raises an opaque TypeError.
        if self.use_high_res_features and (
            high_res_features is None or len(high_res_features) != 2
        ):
            raise ValueError(
                "use_high_res_features=True requires high_res_features to be a "
                "sequence of two tensors [feat_s0, feat_s1]"
            )

        # Transformer query tokens: learnable boundary tokens followed by the
        # sparse prompt embeddings, broadcast over the prompt batch.
        output_tokens = self.boundary_tokens.weight.unsqueeze(0).expand(
            sparse_prompt_embeddings.size(0), -1, -1
        )
        tokens = torch.cat((output_tokens, sparse_prompt_embeddings), dim=1)

        # Image features (src) for the transformer, one copy per prompt batch entry.
        if repeat_image:
            src = torch.repeat_interleave(image_embeddings, tokens.shape[0], dim=0)
        else:
            assert image_embeddings.shape[0] == tokens.shape[0]
            src = image_embeddings
        src = src + dense_prompt_embeddings

        assert image_pe.size(0) == 1, "image_pe should have size 1 in batch dim"
        pos_src = torch.repeat_interleave(image_pe, tokens.shape[0], dim=0)
        b, c, h, w = src.shape

        # Run the transformer.
        hs, src = self.transformer(src, pos_src, tokens)

        # The boundary tokens occupy the first num_boundary_tokens positions.
        boundary_tokens_out = hs[:, : self.num_boundary_tokens, :]

        # Restore the (B, H*W, C) token sequence to a (B, C, H, W) map and upsample.
        src = src.transpose(1, 2).view(b, c, h, w)

        if not self.use_high_res_features:
            upscaled_embedding = self.output_upscaling(src)
        else:
            # Fuse high-resolution features after each transposed conv for finer detail.
            dc1, ln1, act1, dc2, act2 = self.output_upscaling
            feat_s0, feat_s1 = high_res_features

            upscaled_embedding = act1(ln1(dc1(src) + feat_s1))
            upscaled_embedding = act2(dc2(upscaled_embedding) + feat_s0)

        # Hypernetworks: one per-pixel linear head per boundary token -> (B, T, C').
        hyper_in = torch.stack(
            [
                mlp(boundary_tokens_out[:, i, :])
                for i, mlp in enumerate(self.output_hypernetworks_mlps)
            ],
            dim=1,
        )

        # Apply the heads: [B, T, C'] @ [B, C', H'*W'] -> [B, T, H'*W'] -> [B, T, H', W'].
        b, c, h, w = upscaled_embedding.shape
        boundaries = (hyper_in @ upscaled_embedding.view(b, c, h * w)).view(b, -1, h, w)

        return boundaries
