import torch
import torch.nn as nn

class Block(nn.Module):
    """
    长尾感知的特征细化模块，通过门控网络和类别分布条件调制缓解头类偏见
    """
    def __init__(self, in_dim=64, out_dim=64, cls_num_tensor=None):
        super(Block, self).__init__()
        self.in_dim = in_dim  # 输入特征维度
        self.out_dim = out_dim  # 输出特征维度
        
        if cls_num_tensor is not None:
            self.cls_num_tensor = cls_num_tensor
        else:
            self.cls_num_tensor = None
        
        # 隐藏层维度：取4和输入维度1/4中的较大值，保证轻量级结构
        hidden = max(4, in_dim // 4)
        
        # 门控网络：生成感知尾类的特征调制权重（用于细化特征，突出尾类信息）
        self.gate_fc1 = nn.Linear(in_dim, hidden)
        self.gate_fc2 = nn.Linear(hidden, in_dim)
        
        self.relu = nn.ReLU(inplace=True)  # 激活函数（原地操作，节省内存）
        
        # 标记是否拥有类别分布信息（即是否传入了各类别样本数量张量）
        self.has_cls_info = cls_num_tensor is not None
        
        if self.has_cls_info:
            # 类别分布条件调制层：从全局不平衡统计信息中学习调制参数（轻量级线性层）
            self.cls_condition = nn.Linear(1, in_dim, bias=True)
        else:
            self.cls_condition = None
        
        # 投影层：保证输出维度符合要求（输入输出维度不一致时使用线性层，否则使用恒等映射）
        if in_dim != out_dim:
            self.proj = nn.Linear(in_dim, out_dim, bias=True)
        else:
            self.proj = nn.Identity()
        
        # 可训练的强度参数：控制门控调制的幅度（初始值设为0.6）
        self.alpha = nn.Parameter(torch.tensor(0.6))

    def forward(self, feats):
        # 输入特征预期形状：[批量大小N, 输入维度in_dim]
        if feats.dim() != 2 or feats.size(1) != self.in_dim:
            raise ValueError("输入特征feats的形状必须为 [N, in_dim]")
        
        # 基于门控网络的特征细化路径
        g = self.relu(self.gate_fc1(feats))  # 第一步：压缩特征维度，提取核心信息 [N, hidden]
        gate = torch.sigmoid(self.gate_fc2(g))  # 第二步：生成门控权重，归一化到[0,1]区间 [N, in_dim]
        
        # 门控调制：对原始特征进行加权细化，alpha控制调制强度（突出尾类微弱特征）
        refined = feats * (1.0 + self.alpha * gate)
        
        # 可选的类别不平衡条件调制（仅当拥有类别分布信息时执行）
        if self.has_cls_info and self.cls_condition is not None:
            # 不计算梯度：全局统计信息仅用于提供调制参考，不参与反向传播更新
            with torch.no_grad():
                # 计算全局类别样本数量的均值，整理为[1,1]形状的张量（适配线性层输入）
                if self.cls_num_tensor is not None:
                    stat = self.cls_num_tensor.float().mean().view(1, 1)
                else:
                    stat = torch.tensor([[0.0]], dtype=feats.dtype, device=feats.device)
            
            # 生成全局分布调制权重，归一化到[0,1]区间 [1, in_dim]
            scales = torch.sigmoid(self.cls_condition(stat))
            # 全局调制：修正模型对头部类别的偏见，均衡头尾类特征表达
            refined = refined * (1.0 + scales)
        
        # 投影到目标输出维度，保证模块无缝集成到编码器-分类头之间
        out = self.proj(refined)
        return out