import torch
import torch.nn as nn


class LocalContextFusionEncoder(nn.Module):
    """Fuse the features of neighbouring slices into the centre slice of a local 3D window.

    The centre slice's per-token features are the query (``tgt``) of a stack of
    Transformer decoder layers used as cross-attention blocks. The features of every
    slice in the window, each offset by a learned relative-slice-position embedding,
    are the key/value memory. The spatial and batch axes are merged, so attention runs
    independently for every (spatial token, batch) pair, along the slice axis only.
    """

    def __init__(self, embed_dim, num_heads, k, max_relative_distance):
        """
        Args:
            embed_dim (int): Feature dimension (e.g. the hidden_dim in SAM2).
            num_heads (int): Number of attention heads.
            k (int): Radius of the sliding window; a full window has 2k+1 slices.
            max_relative_distance (int): Largest relative slice offset that gets its own
                position embedding. Offsets of larger magnitude are clamped to it.
                Usually set to k.
        """
        super().__init__()
        self.embed_dim = embed_dim
        self.k = k
        self.max_relative_distance = max_relative_distance

        # One learned embedding per relative offset in
        # [-max_relative_distance, max_relative_distance].
        self.relative_pos_embedding = nn.Embedding(2 * max_relative_distance + 1, embed_dim)

        # NOTE(review): this self-attention encoder is constructed but never called in
        # forward(). It is kept only so existing checkpoints (state_dict keys) still load.
        # Its parameters receive no gradients, so DistributedDataParallel needs
        # find_unused_parameters=True while it remains.
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=embed_dim,
            nhead=num_heads,
            dim_feedforward=embed_dim * 4,
            dropout=0.1,
            activation='relu'
        )
        self.transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=2)

        # Decoder layers serve as cross-attention blocks, because they take the query
        # (tgt) and the key/value (memory) from different inputs.
        decoder_layer = nn.TransformerDecoderLayer(
            d_model=embed_dim,
            nhead=num_heads,
            dim_feedforward=embed_dim * 4,
            dropout=0.1,
            activation='relu'
        )
        self.cross_attention_encoder = nn.TransformerDecoder(decoder_layer, num_layers=2)

    def forward(self, sliding_window_feats, relative_indices):
        """Return the centre slice's features after attending over the whole window.

        Args:
            sliding_window_feats (Sequence[torch.Tensor] | torch.Tensor): Features of the
                slices in the window, in the same order as ``relative_indices``. Either
                a list/tuple of S tensors each shaped (H*W, B, C), or one pre-stacked
                tensor shaped (S, H*W, B, C). S is normally 2k+1.
            relative_indices (torch.Tensor | Sequence[int]): Shape (S,). The offset of
                each window entry relative to the centre slice, e.g. [-k, ..., k].

        Returns:
            torch.Tensor: Context-fused centre-slice features, shape (H*W, B, C).

        Raises:
            ValueError: If the window is empty, the stacked features are not 4-D, or
                ``relative_indices`` does not have exactly one entry per slice.
        """
        if isinstance(sliding_window_feats, torch.Tensor):
            stacked_feats = sliding_window_feats
        else:
            if len(sliding_window_feats) == 0:
                raise ValueError("sliding_window_feats must contain at least one slice")
            stacked_feats = torch.stack(list(sliding_window_feats), dim=0)
        if stacked_feats.dim() != 4:
            raise ValueError(
                "expected window features of shape (S, H*W, B, C), got "
                f"{tuple(stacked_feats.shape)}"
            )
        seq_len, hw, b, c = stacked_feats.shape

        # (S, HW, B, C) -> (S, HW*B, C), i.e. (seq_len, batch, embed_dim) as the
        # non-batch-first Transformer modules expect. Flattening the adjacent axes yields
        # the same token ordering as permute+reshape, without forcing a copy.
        feats_for_transformer = stacked_feats.flatten(1, 2)

        rel = torch.as_tensor(relative_indices, dtype=torch.long).reshape(-1)
        if rel.numel() != seq_len:
            raise ValueError(
                f"relative_indices has {rel.numel()} entries but the window has {seq_len} slices"
            )

        # Query the entry whose offset is 0, so windows truncated at the volume boundary
        # (fewer than 2k+1 slices) still use the true centre slice. Fall back to index k
        # (a full, ordered window) when no zero offset is present.
        zero_positions = (rel == 0).nonzero(as_tuple=True)[0]
        center_idx = int(zero_positions[0]) if zero_positions.numel() > 0 else self.k

        # Map offsets in [-max, max] to embedding rows [0, 2*max]. The shift must be
        # max_relative_distance (the table's half-width), not k. Clamping keeps
        # out-of-range offsets from indexing past the table.
        max_dist = self.max_relative_distance
        relative_pos_ids = rel.to(feats_for_transformer.device).clamp(-max_dist, max_dist) + max_dist
        pos_embed = self.relative_pos_embedding(relative_pos_ids)  # (S, C)

        # Broadcast over the HW*B axis instead of materialising an (S, HW*B, C) copy.
        transformer_memory = feats_for_transformer + pos_embed.unsqueeze(1)

        # Centre-slice features (without position embedding) are the query;
        # tgt: (1, HW*B, C), memory: (S, HW*B, C).
        center_slice_feat = feats_for_transformer[center_idx].unsqueeze(0)
        fused_output = self.cross_attention_encoder(tgt=center_slice_feat, memory=transformer_memory)

        center_slice_output = fused_output.squeeze(0)  # (HW*B, C)
        return center_slice_output.reshape(hw, b, c)
