import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import List

# ------------------------------------------------------------------------------------
# 1. 辅助模块：相对位置偏置 (Relative Position Bias)
# ------------------------------------------------------------------------------------
class RelativePositionBias(nn.Module):
    """Learnable lookup table mapping relative offsets to per-head attention biases.

    The table holds ``2 * max_distance`` rows of ``num_heads`` values each. A
    signed offset ``d`` selects row ``d + max_distance``, so every offset in
    ``[-max_distance, max_distance - 1]`` has its own row. Offsets outside that
    range are clamped to the nearest end of the table rather than triggering
    an out-of-range indexing error.
    """

    def __init__(self, num_buckets: int, num_heads: int, max_distance: int):
        """
        Args:
            num_buckets: Accepted for interface compatibility but not used;
                the table size is derived from ``max_distance`` only.
            num_heads: Number of attention heads (width of each table row).
            max_distance: Half the table size; also the shift added to a
                signed offset to obtain a non-negative row index.
        """
        super().__init__()
        self.relative_attention_bias = nn.Embedding(2 * max_distance, num_heads)

    def forward(self, relative_pos_matrix: torch.Tensor) -> torch.Tensor:
        """Look up the bias for every pair of positions.

        Args:
            relative_pos_matrix: Signed relative offsets of shape ``(L, L)``,
                where ``L`` is the sequence length. Floating-point input is
                rounded to the nearest integer before the lookup.

        Returns:
            Attention bias of shape ``(1, num_heads, L, L)``.
        """
        table = self.relative_attention_bias
        shift = table.num_embeddings // 2

        # nn.Embedding only accepts integer indices living on the same device
        # as its weight, so normalise dtype and device up front.
        if torch.is_floating_point(relative_pos_matrix):
            relative_pos_matrix = relative_pos_matrix.round()
        offsets = relative_pos_matrix.to(device=table.weight.device, dtype=torch.long)

        # Shift signed offsets to non-negative rows (e.g. [-3, 2] -> [0, 5]) and
        # clamp so that offsets beyond the table share its first / last row.
        bias_indices = (offsets + shift).clamp(min=0, max=table.num_embeddings - 1)

        # (L, L) -> (L, L, num_heads) -> (num_heads, L, L) -> (1, num_heads, L, L)
        relative_bias = table(bias_indices)
        return relative_bias.permute(2, 0, 1).unsqueeze(0)

# ------------------------------------------------------------------------------------
# 2. 核心模块：全局和局部空间编码器
# ------------------------------------------------------------------------------------
class SpatialTransformerEncoder(nn.Module):
    """Single post-norm Transformer encoder layer with a relative-position attention bias.

    The layer applies multi-head self-attention (with a learned additive bias
    looked up from a relative-position matrix), a residual + LayerNorm, a
    two-layer GELU feed-forward network, and a second residual + LayerNorm.
    The same class serves both the global and the local mode; only
    ``max_seq_len`` differs between them.
    """

    def __init__(self, embed_dim: int, num_heads: int, max_seq_len: int, ffn_dim_multiplier: int = 4):
        """
        Args:
            embed_dim: Feature dimension of every token.
            num_heads: Number of attention heads.
            max_seq_len: Passed to the bias generator as ``max_distance``
                (and, doubled, as ``num_buckets``).
            ffn_dim_multiplier: Hidden width of the feed-forward network as a
                multiple of ``embed_dim``.
        """
        super().__init__()
        self.num_heads = num_heads

        self.pos_bias_generator = RelativePositionBias(
            num_buckets=2 * max_seq_len,  # simplification: offsets are mapped directly
            num_heads=num_heads,
            max_distance=max_seq_len
        )

        self.self_attn = nn.MultiheadAttention(embed_dim, num_heads, batch_first=True)
        self.norm1 = nn.LayerNorm(embed_dim)
        self.norm2 = nn.LayerNorm(embed_dim)

        self.ffn = nn.Sequential(
            nn.Linear(embed_dim, embed_dim * ffn_dim_multiplier),
            nn.GELU(),
            nn.Linear(embed_dim * ffn_dim_multiplier, embed_dim),
        )

    def forward(self, features: torch.Tensor, relative_pos_matrix: torch.Tensor) -> torch.Tensor:
        """Encode a batch of sequences.

        Args:
            features: Input features of shape ``(Batch, SeqLen, Dim)``.
            relative_pos_matrix: Relative-position matrix of shape
                ``(SeqLen, SeqLen)``.

        Returns:
            Encoded features of shape ``(Batch, SeqLen, Dim)``.
        """
        # 1. Relative-position bias, shape (1, num_heads, SeqLen, SeqLen).
        attn_bias = self.pos_bias_generator(relative_pos_matrix)

        # 2. Multi-head self-attention.
        # nn.MultiheadAttention takes a 3-D mask of shape
        # (Batch * num_heads, SeqLen, SeqLen); a float mask is added to the
        # attention scores. The bias is therefore repeated for every sample
        # and flattened batch-major.
        batch_size, seq_len, _ = features.shape
        attn_bias = attn_bias.expand(batch_size, -1, -1, -1).reshape(
            batch_size * self.num_heads, seq_len, seq_len
        )

        # The attention weights are never used, so do not ask for them: this
        # avoids building and head-averaging the full attention map.
        attn_output, _ = self.self_attn(
            features, features, features, attn_mask=attn_bias, need_weights=False
        )

        # 3. Add & Norm
        features = self.norm1(features + attn_output)

        # 4. Feed-forward network
        ffn_output = self.ffn(features)

        # 5. Add & Norm
        features = self.norm2(features + ffn_output)

        return features

# ------------------------------------------------------------------------------------
# 3. 主模块：空间相关性模块
# ------------------------------------------------------------------------------------
class SpatialCorrelationModule(nn.Module):
    """Enhance FPN features along the frame axis and fuse the results.

    For every FPN level the module owns a global encoder, a local encoder and
    a 1x1 convolution. The forward pass concatenates the original features,
    the globally encoded features and the centre token of the locally encoded
    window along the channel axis, then projects back to the level's channel
    count with the 1x1 convolution.
    """

    def __init__(self, fpn_dims: List[int], num_heads: int, num_frames: int, local_window_size: int):
        """
        Args:
            fpn_dims (List[int]): Channel count of each FPN level.
            num_heads (int): Number of attention heads.
            num_frames (int): Total number of frames / slices in a sequence.
            local_window_size (int): Size of the local window (2*k+1).
        """
        super().__init__()

        def build_encoders(max_len: int) -> nn.ModuleList:
            # One encoder per FPN level, all sharing the same maximum length.
            return nn.ModuleList([
                SpatialTransformerEncoder(embed_dim=channels, num_heads=num_heads, max_seq_len=max_len)
                for channels in fpn_dims
            ])

        # Per-level encoders: the global ones are sized by the frame count,
        # the local ones by the window size.
        self.global_encoders = build_encoders(num_frames)
        self.local_encoders = build_encoders(local_window_size)

        # Fusion: a 1x1 convolution (equivalent to a per-pixel linear layer)
        # that learns how to mix original, global and local features.
        self.fusion_layers = nn.ModuleList([
            nn.Conv2d(3 * channels, channels, kernel_size=1)
            for channels in fpn_dims
        ])

    def forward(self, spatial_features: dict) -> List[torch.Tensor]:
        """Run global and local encoding on every FPN level and fuse the results.

        Args:
            spatial_features (dict): Dictionary produced by
                `prepare_spatial_features_optimized`. This method reads the
                per-level lists ``'global_mode'`` (entries with ``'features'``
                and ``'relative_pos_matrix'``) and ``'local_mode'`` (entries
                with ``'features'`` and ``'relative_pos_vector'``).

        Returns:
            List[torch.Tensor]: One fused tensor per level, each of shape
            (B*T, C, H, W).
        """
        fused_levels = []
        level_pairs = zip(spatial_features['global_mode'], spatial_features['local_mode'])

        for level, (global_entry, local_entry) in enumerate(level_pairs):
            frames = global_entry['features']  # (B, T, C, H, W)
            n_batch, n_frames, n_chan, height, width = frames.shape

            # --- Global branch: every pixel becomes a length-T sequence ---
            pixel_sequences = frames.permute(0, 3, 4, 1, 2)
            pixel_sequences = pixel_sequences.reshape(n_batch * height * width, n_frames, n_chan)
            global_encoded = self.global_encoders[level](
                pixel_sequences, global_entry['relative_pos_matrix']
            )
            # (B*H*W, T, C) -> (B, H, W, T, C) -> (B, T, C, H, W)
            global_enhanced = global_encoded.reshape(
                n_batch, height, width, n_frames, n_chan
            ).permute(0, 3, 4, 1, 2)

            # --- Local branch: every (frame, pixel) becomes a window sequence ---
            windows = local_entry['features']  # (B, T, WinSize, C, H, W)
            window_len = windows.shape[2]
            window_sequences = windows.permute(0, 1, 4, 5, 2, 3)
            window_sequences = window_sequences.reshape(
                n_batch * n_frames * height * width, window_len, n_chan
            )

            # Pairwise differences turn the offset vector into a matrix.
            offsets = local_entry['relative_pos_vector']
            pairwise_offsets = offsets.unsqueeze(0) - offsets.unsqueeze(1)

            local_encoded = self.local_encoders[level](window_sequences, pairwise_offsets)

            # Keep only the token in the middle of each window.
            centre_tokens = local_encoded[:, window_len // 2, :]
            # (B*T*H*W, C) -> (B, T, H, W, C) -> (B, T, C, H, W)
            local_enhanced = centre_tokens.reshape(
                n_batch, n_frames, height, width, n_chan
            ).permute(0, 1, 4, 2, 3)

            # --- Fusion: fold T into the batch axis so a 2-D conv applies ---
            frame_shape = (n_batch * n_frames, n_chan, height, width)
            stacked = torch.cat(
                [branch.reshape(frame_shape) for branch in (frames, global_enhanced, local_enhanced)],
                dim=1,
            )  # (B*T, 3*C, H, W)

            fused_levels.append(self.fusion_layers[level](stacked))  # (B*T, C, H, W)

        return fused_levels