import torch
import torch.nn as nn
import math

class Block(nn.Module):
    """
    Spectral-Filter Block (DCT gating).

    Principle:
    - Domain: operates in the FREQUENCY domain, via an orthonormal DCT-II
      applied along the feature (last) axis, rather than directly on features.
    - Mechanism: a small MLP maps the (log inverse-frequency) class
      distribution to one gain per DCT coefficient. The result is a single
      global, learnable "equalizer" shared by every sample.
    - Intuition (design hypothesis): long-tail classes may suffer from
      high-frequency noise or lack low-frequency structure. The gains can act
      as a learnable band-/low-pass filter. The final Sigmoid keeps gains in
      (0, 1), so the filter can only attenuate.

    Parameter count (defaults: 100 classes, in_dim = out_dim = 64):
    - DCT matrix: 0 (fixed buffer)
    - Filter net: 100*16+16 + 16*64+64 = 2,704
    - post_linear: 64*64+64 = 4,160
    - LayerNorm: 2*64 = 128
    - Total: 6,992 (~7k). When in_dim != out_dim, a bias-free projection
      shortcut adds in_dim*out_dim more.

    Args:
        in_dim: size of the last axis of the input features.
        out_dim: size of the last axis of the output.
        cls_num_tensor: per-class sample counts (1-D tensor or sequence).
            Defaults to 100 classes with equal counts.
    """
    def __init__(self, in_dim=64, out_dim=64, cls_num_tensor=None):
        super().__init__()
        self.in_dim = in_dim
        self.out_dim = out_dim

        if cls_num_tensor is None:
            cls_num_tensor = torch.ones(100)
        # as_tensor also accepts plain lists/tuples of counts.
        cls_num = torch.as_tensor(cls_num_tensor, dtype=torch.float32)
        self.register_buffer('cls_num', cls_num)

        # 1. Fixed (non-learnable) orthonormal DCT-II matrix [in_dim, in_dim].
        #    It maps [..., in_dim] features to [..., in_dim] spectra.
        self.register_buffer('dct_matrix', self._get_dct_matrix(in_dim))

        # 2. Spectral filter generator.
        #    Input: the class distribution. Output: one gain per DCT
        #    coefficient. The Sigmoid limits gains to (0, 1), i.e. pure
        #    attenuation; remove or rescale it to allow amplification.
        num_classes = len(cls_num)
        self.filter_gen = nn.Sequential(
            nn.Linear(num_classes, 16),  # bottleneck
            nn.ReLU(inplace=True),
            nn.Linear(16, in_dim),       # frequency-domain mask
            nn.Sigmoid(),
        )

        # 3. Feature mixing after returning from the frequency domain.
        self.post_linear = nn.Linear(in_dim, out_dim)
        self.norm = nn.LayerNorm(in_dim)
        self.act = nn.GELU()

        # Residual path: `feats + out` cannot broadcast when in_dim != out_dim,
        # so project in that case. The Identity has no parameters, which keeps
        # the state_dict unchanged for the in_dim == out_dim case.
        if in_dim == out_dim:
            self.shortcut = nn.Identity()
        else:
            self.shortcut = nn.Linear(in_dim, out_dim, bias=False)

    def _get_dct_matrix(self, N):
        """
        Build the orthonormal DCT-II matrix of size N x N.

        C[k, n] = s_k * cos(pi * k * (2n + 1) / (2N)),
        where s_0 = sqrt(1/N) and s_k = sqrt(2/N) for k > 0.
        Rows index frequency k, columns index feature position n.
        With this normalization C is orthogonal, so its inverse is C^T.
        """
        # Evaluate in float64, then cast, to match element-wise math.cos values.
        k = torch.arange(N, dtype=torch.float64).unsqueeze(1)  # frequency index
        n = torch.arange(N, dtype=torch.float64).unsqueeze(0)  # feature index
        basis = torch.cos(math.pi * k * (2 * n + 1) / (2 * N))
        scale = torch.full((N, 1), math.sqrt(2 / N), dtype=torch.float64)
        scale[:1] = math.sqrt(1 / N)  # DC row uses the smaller normalization
        return (scale * basis).to(torch.get_default_dtype())

    def _get_spectral_filter(self):
        """Return one gain in (0, 1) per DCT coefficient, shape [in_dim]."""
        # 1. Prior from the class distribution: log inverse frequency,
        #    scaled by its maximum.
        # Use at least float32: after .half() the count buffer is fp16, and its
        # sum can overflow (fp16 max is 65504), which would turn the prior into NaN.
        counts = self.cls_num.to(torch.promote_types(self.cls_num.dtype, torch.float32))
        inv_freq = counts.sum() / (counts + 1e-6)
        prior = torch.log(inv_freq)
        # The max can be exactly 0, e.g. a single class whose count swallows
        # the eps in float32. Avoid the resulting 0/0 = NaN.
        denom = prior.max()
        denom = torch.where(denom != 0, denom, torch.ones_like(denom))
        prior = prior / denom

        # 2. Generate the filter [in_dim]: it decides which frequencies pass
        #    and which are suppressed.
        return self.filter_gen(prior.to(self.filter_gen[0].weight.dtype))

    def forward(self, feats):
        """
        Filter `feats` in the DCT domain, mix, and add back residually.

        Args:
            feats: tensor whose last axis has size `in_dim`. Leading axes are
                treated as batch axes (LayerNorm, matmul and Linear broadcast
                over them).

        Returns:
            Tensor with the same leading axes and last axis `out_dim`.
        """
        x = self.norm(feats)

        # Step 1: feature domain -> frequency domain (DCT).
        # [..., in_dim] @ [in_dim, in_dim]^T = [..., in_dim]
        x_freq = torch.matmul(x, self.dct_matrix.t())

        # Step 2: spectral gating (element-wise product with per-coefficient gains).
        spectral_gate = self._get_spectral_filter()  # [in_dim]
        x_freq_filtered = x_freq * spectral_gate

        # Step 3: frequency domain -> feature domain (inverse DCT).
        # The DCT matrix is orthonormal, so the inverse is its transpose:
        # [..., in_dim] @ [in_dim, in_dim] = [..., in_dim]
        x_restored = torch.matmul(x_freq_filtered, self.dct_matrix)

        # Step 4: mixing.
        out = self.act(self.post_linear(x_restored))

        # Residual connection (projected when in_dim != out_dim).
        return self.shortcut(feats) + out