import torch
import torch.nn as nn
import torch.nn.functional as F

class Block(nn.Module):
    """Prior-biased mixture of two experts with a global, learned router.

    Two experts process the same input in parallel:

    * ``expert_conservative`` -- a single linear layer.
    * ``expert_aggressive`` -- a linear layer followed by ``Tanh``.

    A bias-free linear router maps the normalized class-count prior
    (``cls_num``) to two logits; their softmax gives the mixing weights
    ``(w_a, w_b)``. The weights are global: every sample in the batch is
    mixed with the same pair, and the input features never influence them.

    The mixed output goes through ``LayerNorm`` and ``ReLU`` and is then
    added to the (possibly projected) input as a residual branch.

    With the defaults (64 -> 64, 100 classes) the module has 8,648
    parameters: 2 * (64 * 64 + 64) for the experts, 200 for the router and
    128 for the LayerNorm.

    Args:
        in_dim: Size of the last dimension of the input features.
        out_dim: Size of the last dimension of the output features.
        cls_num_tensor: Per-class sample counts (tensor or any sequence
            accepted by ``torch.as_tensor``). It is flattened to 1-D and
            stored as the float32 buffer ``cls_num``. Defaults to a
            uniform prior over 100 classes.
    """

    def __init__(self, in_dim=64, out_dim=64, cls_num_tensor=None):
        super().__init__()
        self.in_dim = in_dim
        self.out_dim = out_dim

        # 1. Register the class-distribution prior as a (non-trainable) buffer.
        if cls_num_tensor is None:
            cls_num_tensor = torch.ones(100)
        # as_tensor + reshape lets callers pass lists/arrays as well as
        # tensors and guarantees the router width matches the buffer length.
        prior = torch.as_tensor(cls_num_tensor, dtype=torch.float32).reshape(-1)
        self.register_buffer('cls_num', prior)

        # 2. Two experts with different characters.

        # Expert A ("conservative"): a purely linear map.
        # NOTE: no identity initialisation is applied here; the layer uses
        # the default nn.Linear initialisation.
        self.expert_conservative = nn.Linear(in_dim, out_dim)

        # Expert B ("aggressive"): linear map followed by Tanh, which
        # squashes its output into (-1, 1).
        self.expert_aggressive = nn.Sequential(
            nn.Linear(in_dim, out_dim),
            nn.Tanh(),
        )

        # 3. Router / gating network: prior vector [C] -> expert logits [2].
        self.router = nn.Linear(prior.numel(), 2, bias=False)

        # 4. Post-mixing normalisation and activation.
        self.final_act = nn.ReLU(inplace=True)
        self.norm = nn.LayerNorm(out_dim)

        # 5. Residual shortcut. Identity when the widths agree (adds no
        # parameters or state_dict keys); otherwise a bias-free projection
        # so the residual sum is well defined instead of failing or
        # silently broadcasting. Created last so the initialisation order
        # of the layers above is unaffected.
        if in_dim == out_dim:
            self.shortcut = nn.Identity()
        else:
            self.shortcut = nn.Linear(in_dim, out_dim, bias=False)

    def _compute_routing_weights(self):
        """Return the two global mixing weights as a tensor of shape ``[2]``.

        The prior is normalized to sum to one, passed through the router
        and soft-maxed, so the result is non-negative and sums to one.
        Index 0 weights the conservative expert, index 1 the aggressive one.
        """
        total = self.cls_num.sum()
        # Guard against an all-zero prior: dividing by zero would yield NaN
        # weights and poison every output. Non-zero sums are left untouched.
        safe_total = torch.where(total == 0, torch.ones_like(total), total)
        dist = self.cls_num / safe_total

        # [C] -> [2]
        logits = self.router(dist)
        return F.softmax(logits, dim=0)

    def forward(self, feats):
        """Mix both experts' outputs and add the result to the input.

        Args:
            feats: Tensor whose last dimension has size ``in_dim``.

        Returns:
            Tensor with the same leading dimensions as ``feats`` and a last
            dimension of size ``out_dim``.
        """
        # 1. Run both experts on the same input.
        out_a = self.expert_conservative(feats)
        out_b = self.expert_aggressive(feats)

        # 2. Global routing weights: weights[0] -> expert A, weights[1] -> expert B.
        weights = self._compute_routing_weights()

        # 3. Soft mixing; the scalar weights broadcast over the feature tensor.
        mixed_out = weights[0] * out_a + weights[1] * out_b

        # 4. Normalise, then activate (in place on the LayerNorm output).
        mixed_out = self.norm(mixed_out)
        mixed_out = self.final_act(mixed_out)

        # 5. Residual connection (projected when in_dim != out_dim).
        return self.shortcut(feats) + mixed_out