import torch
import torch.nn as nn
import numpy as np
from transformers import AutoModelForCausalLM, AutoTokenizer
from datasets import load_dataset
import gc
import copy
from tqdm import tqdm
import random
import os
import time
import math


class DeepSeekCompressor:
    def __init__(self, layers_expert_groups, layer_group_params, num_calibration_samples=32, max_seq_length=2048, compression_ratio=None, model_name='deepseek-ai/deepseek-moe-16b-base'):
        """Load model, tokenizer and calibration data, and summarise the parameter budget.

        Args:
            layers_expert_groups: nested dict ``{layer_idx: {group_idx: [expert_idx, ...]}}``.
            layer_group_params: nested dict ``{layer_idx: {group_idx: budget}}``;
                ``None`` budgets are ignored when summing.
            num_calibration_samples: number of calibration windows to attempt.
            max_seq_length: token length of every calibration window.
            compression_ratio: stored as ``self.compression_ratio`` (not used here).
            model_name: Hugging Face id used for both model and tokenizer.

        Raises:
            ValueError: if the non-None group budgets sum to zero (including the
                case where there are none). This is checked before the model is
                loaded; previously it surfaced as a ZeroDivisionError afterwards.
        """
        # Nested dict {layer_idx: {group_idx: [expert_idx, ...]}}.
        self.layers_expert_groups = layers_expert_groups
        self.layer_group_params = layer_group_params

        # Validate the budget first so a bad config fails before the expensive load.
        total_params_budget, layer_budget = self._summarize_param_budget(layer_group_params)
        if total_params_budget == 0:
            raise ValueError(
                "layer_group_params must contain non-None group budgets with a non-zero total")

        self.model = AutoModelForCausalLM.from_pretrained(
            model_name, torch_dtype=torch.bfloat16, trust_remote_code=True, device_map="balanced").eval()
        self.tokenizer = AutoTokenizer.from_pretrained(
            model_name, trust_remote_code=True, use_fast=False)

        self.device = next(self.model.parameters()).device
        self.num_devices = torch.cuda.device_count()
        self.dtype = next(self.model.parameters()).dtype

        # Calibration dataset.
        self.calibration_dataset = load_dataset(
            "wikitext", "wikitext-103-raw-v1", split="train")
        self.num_calibration_samples = num_calibration_samples
        self.max_seq_length = max_seq_length

        self.collected_features = {}
        self.target_layers = []  # indices of layers that contain an MoE block

        # Detect every MoE layer automatically.
        self._identify_moe_layers()

        self.gate_up_max_rank = 512

        # Enable the low-rank-preservation compression path.
        self.use_low_rank_preservation = True

        # --- Performance / stability switches (consumed outside this method) ---
        # 1) Whether the final projection step uses whitening (per the original
        #    note: when off, SVD runs on the raw matrix, faster and more stable).
        self.finalize_use_whitening = False
        # 2) SVD backend: 'auto' | 'lowrank' | 'full'. Per the original note,
        #    'auto' uses full for small matrices / when exactness matters, else lowrank.
        self.svd_backend = 'auto'
        # 3) Vectorise (de)whitening instead of multiplying by dense diagonal matrices.
        self.vectorize_whitening = True

        print(f"[INFO] 总参数预算: {total_params_budget/1e6:.2f}M")
        print(f"[INFO] 各层参数预算: {layer_budget}")
        self.layer_budget_ratio = {
            layer_idx: budget / total_params_budget
            for layer_idx, budget in layer_budget.items()
        }

        print(f"[INFO] 各层压缩率: {self.layer_budget_ratio}")
        # Non-empty here: a non-zero total implies at least one layer budget.
        self.avg_layer_budget_ratio = 1 / len(layer_budget)
        print(f"[INFO] 平均层压缩率: {self.avg_layer_budget_ratio:.4f}")

        self.compression_ratio = compression_ratio

    @staticmethod
    def _summarize_param_budget(layer_group_params):
        """Sum the non-None group budgets.

        Args:
            layer_group_params: nested dict ``{layer_idx: {group_idx: budget or None}}``.

        Returns:
            ``(total, per_layer)``: the overall sum, and a dict mapping every layer
            with at least one non-None budget to the sum of its group budgets
            (in first-seen layer order).
        """
        total = 0
        per_layer = {}
        for layer_idx in layer_group_params:
            groups = layer_group_params[layer_idx]
            for group_idx in groups:
                params = groups[group_idx]
                if params is not None:
                    total += params
                    per_layer[layer_idx] = per_layer.get(layer_idx, 0) + params
        return total, per_layer

    def _identify_moe_layers(self):
        """Record in ``self.target_layers`` which decoder layers contain an MoE block.

        A layer qualifies when it has an ``mlp`` attribute that in turn has an
        ``experts`` attribute. The indices are stored in ascending order and
        their count is printed.
        """
        decoder_layers = self.model.model.layers
        self.target_layers = [
            position
            for position, decoder_layer in enumerate(decoder_layers)
            if hasattr(decoder_layer, 'mlp') and hasattr(decoder_layer.mlp, 'experts')
        ]

        print(f"Total MoE layers found: {len(self.target_layers)}")

    def _collect_features_hook_all_experts(self, layer_idx):
        """Create a forward-pre-hook that records per-expert calibration features.

        The hook is meant for an MoE module (it uses ``module.gate`` and
        ``module.experts``). On each call it re-runs the gate on the hook input,
        keeps every (token, expert) routing pair whose weight is > 1e-6 and, for
        each routed expert, appends to
        ``self.collected_features[layer_idx][expert_idx]``:

        * ``'selected_hidden'``: the routed input rows,
        * ``'intermediate'``: ``act_fn(gate_proj(x)) * up_proj(x)`` for those rows,
        * ``'routing_weights'``: the routing weight of each row (same row order),
          cast to the hidden-state dtype.

        Stored tensors are detached and moved to CPU. Any exception is printed
        and swallowed so feature collection never aborts the forward pass.

        Args:
            layer_idx: key under which this layer's features are stored.

        Returns:
            ``hook_fn(module, args)`` suitable for ``register_forward_pre_hook``.
        """
        import contextlib

        feature_keys = ('selected_hidden', 'intermediate', 'routing_weights')

        def hook_fn(module, args):
            try:
                hidden_states = args[0]
                # Unpacked as (indices, weights, extra), matching the gate of the
                # model loaded in __init__; confirm for other gate implementations.
                routing_indices, routing_weights, _ = module.gate(hidden_states)

                # Lazily create empty per-expert lists for this layer.
                if layer_idx not in self.collected_features:
                    self.collected_features[layer_idx] = {
                        expert_idx: {key: [] for key in feature_keys}
                        for expert_idx in range(len(module.experts))
                    }
                layer_store = self.collected_features[layer_idx]

                hidden_dim = hidden_states.shape[-1]
                top_k = routing_indices.shape[-1]
                flat_hidden_states = hidden_states.reshape(-1, hidden_dim)
                flat_routing_indices = routing_indices.reshape(-1, top_k)
                flat_routing_weights = routing_weights.reshape(-1, top_k)

                # Vectorised routing table. The previous per-(token, k) Python
                # loop called .item() twice per element, syncing the device each
                # time. Compared in float64 to match the old Python-float check.
                keep = flat_routing_weights.double() > 1e-6
                hidden_on_device = {}  # one copy of the flat inputs per expert device

                for expert_idx in torch.unique(flat_routing_indices[keep]).tolist():
                    # Row-major nonzero() yields (token, k) pairs in token order,
                    # i.e. the same order the old nested loop produced.
                    token_rows, k_cols = (
                        (flat_routing_indices == expert_idx) & keep
                    ).nonzero(as_tuple=True)
                    token_weights = flat_routing_weights[token_rows, k_cols]

                    expert = module.experts[expert_idx]
                    expert_device = expert.gate_proj.weight.device
                    if expert_device not in hidden_on_device:
                        hidden_on_device[expert_device] = flat_hidden_states.to(expert_device)

                    # torch.cuda.device() rejects non-CUDA devices.
                    device_ctx = (torch.cuda.device(expert_device)
                                  if expert_device.type == 'cuda'
                                  else contextlib.nullcontext())
                    with device_ctx:
                        selected_hidden = hidden_on_device[expert_device].index_select(
                            0, token_rows.to(expert_device))

                        gate_proj_output = expert.gate_proj(selected_hidden)
                        up_proj_output = expert.up_proj(selected_hidden)
                        intermediate = expert.act_fn(gate_proj_output) * up_proj_output

                        # Only the activations needed downstream are stored.
                        expert_store = layer_store[expert_idx]
                        expert_store['selected_hidden'].append(
                            selected_hidden.detach().cpu())
                        expert_store['intermediate'].append(
                            intermediate.detach().cpu())
                        expert_store['routing_weights'].append(
                            token_weights.to(selected_hidden.dtype).detach().cpu())

                        del gate_proj_output, up_proj_output

            except Exception as e:
                print(f"Error in layer {layer_idx} feature collection: {e}")
                import traceback
                traceback.print_exc()

        return hook_fn

    def collect_all_features(self):
        """Run calibration text through the model and record per-expert features.

        Registers ``_collect_features_hook_all_experts`` as a forward-pre-hook on
        the ``mlp`` of every layer in ``self.target_layers``. It then makes
        ``self.num_calibration_samples`` attempts, each drawing a random window
        from the concatenated non-blank dataset lines (RNGs seeded with 42). Each
        window is trimmed to start after its first '.', tokenized with truncation
        to ``self.max_seq_length``, and skipped if shorter than that, so fewer
        samples than requested may be processed. Per-sample errors are printed
        and skipped.

        Hooks are removed in a ``finally`` block, so an exception escaping the
        loop (or a KeyboardInterrupt) cannot leave them attached to the model.
        Afterwards the features are consolidated and a summary is printed.
        """
        print("Collecting features from all MoE layers...")

        # Discard results from any previous collection.
        self.collected_features = {}

        hooks = []
        sample_count = 0
        try:
            for layer_idx in self.target_layers:
                moe_module = self.model.model.layers[layer_idx].mlp
                hooks.append(moe_module.register_forward_pre_hook(
                    self._collect_features_hook_all_experts(layer_idx)))

            random.seed(42)
            torch.manual_seed(42)

            # Concatenate all non-blank texts.
            all_texts = [
                text for text in self.calibration_dataset['text'] if text.strip()]
            tot_text = "\n\n".join(all_texts)
            print(f"Total text length: {len(tot_text)} characters")

            # Inputs go to the first decoder layer's device (relevant when the
            # model is sharded with device_map); loop-invariant, so looked up once.
            target_device = next(
                self.model.model.layers[0].parameters()).device

            for _ in range(self.num_calibration_samples):
                try:
                    # Random start position.
                    i = random.randint(0, len(tot_text) - self.max_seq_length - 1)
                    j = i + self.max_seq_length * 10
                    text_segment = tot_text[i:j]

                    # Start after the first full stop (kept consistent with
                    # get_svd_scale, per the original note).
                    ind = text_segment.find(".")
                    if ind != -1:
                        text_segment = text_segment[ind + 1:].strip()

                    inputs = self.tokenizer(
                        text_segment, return_tensors="pt",
                        truncation=True, max_length=self.max_seq_length
                    )

                    # Only full-length windows are used.
                    if inputs.input_ids.shape[1] < self.max_seq_length:
                        continue

                    for key in inputs:
                        if torch.is_tensor(inputs[key]):
                            inputs[key] = inputs[key].to(target_device)

                    with torch.no_grad():
                        self.model(**inputs)

                    sample_count += 1
                    if sample_count % 10 == 0:
                        print(f"Processed {sample_count} samples...")

                except Exception as e:
                    print(f"Error processing sample {sample_count}: {e}")
        finally:
            for hook in hooks:
                hook.remove()

        # Turn the per-forward lists into tensors.
        self._consolidate_features()

        print(
            f"Feature collection completed. Processed {sample_count} samples.")
        self._print_collection_summary()

    def _consolidate_features(self):
        """Concatenate every expert's per-forward feature chunks into single tensors.

        Layers are processed one after another. For each layer and expert, the
        lists under ``'selected_hidden'``, ``'intermediate'`` and
        ``'routing_weights'`` are concatenated along dim 0 and cast to
        ``self.dtype``; an empty list becomes ``None``. If a concatenation raises
        ``RuntimeError``, a retry concatenates roughly ten partial slices first.
        The raw lists are then dropped, the layer entry is replaced by the merged
        dict, and garbage collection / CUDA cache release runs after each layer.
        """
        print("Consolidating collected features...")
        kinds = ('selected_hidden', 'intermediate', 'routing_weights')

        for layer_idx in self.collected_features:
            raw_layer = self.collected_features[layer_idx]
            merged_layer = {}

            for expert_idx in raw_layer:
                raw_expert = raw_layer[expert_idx]
                merged = {'selected_hidden': None, 'intermediate': None}
                merged_layer[expert_idx] = merged

                for kind in kinds:
                    chunks = raw_expert[kind]
                    if not chunks:
                        merged[kind] = None
                        continue
                    try:
                        merged[kind] = torch.cat(chunks, dim=0).to(self.dtype)
                    except RuntimeError:
                        print(
                            f"Memory error for {layer_idx}-{expert_idx}-{kind}, trying batch concat")
                        step = max(1, len(chunks) // 10)
                        partials = [
                            torch.cat(chunks[start:start + step], dim=0)
                            for start in range(0, len(chunks), step)
                        ]
                        merged[kind] = (torch.cat(partials, dim=0).to(self.dtype)
                                        if partials else None)

                # Drop references to the raw chunks so their memory can be reclaimed.
                for kind in kinds:
                    raw_expert[kind] = None

            self.collected_features[layer_idx] = merged_layer

            gc.collect()
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
            print(f"Layer {layer_idx} consolidation completed")

    def _print_collection_summary(self):
        """Print, per layer (sorted), the expert count and a row total.

        A layer's total is the sum over its experts of the row counts of
        ``selected_hidden`` plus ``intermediate``; ``None`` entries count as zero.
        A grand total and a separator line close the report.
        """
        print("\n=== Feature Collection Summary ===")
        grand_total = 0

        for layer_idx in sorted(self.collected_features.keys()):
            experts = self.collected_features[layer_idx]
            layer_total = sum(
                0 if feats[key] is None else feats[key].shape[0]
                for feats in experts.values()
                for key in ('selected_hidden', 'intermediate')
            )
            print(
                f"Layer {layer_idx}: {len(experts)} experts, {layer_total} total features")
            grand_total += layer_total

        print(f"Total features collected: {grand_total}")
        print("=" * 40)

    def get_layer_expert_features(self, layer_idx, expert_idx):
        """Return one expert's consolidated calibration features on its own device.

        The stored ``selected_hidden``, ``intermediate`` and ``routing_weights``
        tensors are moved to the device of the expert's ``gate_proj`` weight
        rather than ``self.device``, because the model may be sharded.

        Returns:
            ``(selected_hidden, gate_proj, up_proj, down_proj, intermediate,
            routing_weights)``. The three projection slots are always ``None``.
            Those outputs are not materialised here to save memory, and the slots
            are kept so callers that unpack six values keep working.

        Raises:
            ValueError: if any of the three features is missing or ``None``, for
                example when the expert was routed no calibration tokens.
                Previously this surfaced as ``AttributeError`` on ``None.to``.
        """
        expert_device = self.model.model.layers[layer_idx].mlp.experts[expert_idx].gate_proj.weight.device
        stored = self.collected_features[layer_idx][expert_idx]

        feature_keys = ('selected_hidden', 'intermediate', 'routing_weights')
        missing = [key for key in feature_keys if stored.get(key) is None]
        if missing:
            raise ValueError(
                f"No calibration features {missing} for layer {layer_idx}, expert {expert_idx}; "
                "the expert may not have been routed any tokens - try more calibration samples.")

        selected_hidden, intermediate, routing_weights = (
            stored[key].to(expert_device) for key in feature_keys)

        return (selected_hidden, None, None, None, intermediate, routing_weights)

    def _reconstruct_full_matrix(self, ef, high_mat, high_idx, medium_comp, medium_recov, orig_shape):
        """Rebuild a dense ``(n_rows, in_dim)`` weight from its high/medium parts.

        The result starts as zeros, and ``high_mat`` is added to it (so it must
        broadcast to ``orig_shape``). The rows listed in ``ef['medium_indices']``
        are then overwritten with ``medium_recov @ medium_comp``. Rows covered by
        neither part stay zero.

        Args:
            ef: expert factor dict; only ``'medium_indices'`` is read, and only
                when both medium factors are given.
            high_mat: dense high-importance part, or None.
            high_idx: unused; kept for call compatibility.
            medium_comp: ``(rank, in_dim)`` compressed medium factor, or None.
            medium_recov: ``(num_medium_rows, rank)`` recovery factor, or None.
            orig_shape: ``(n_rows, in_dim)`` of the dense matrix.

        Returns:
            Dense tensor with the dtype/device of ``high_mat``, else of
            ``medium_comp``, else bfloat16 on ``self.device``.
        """
        n_rows, in_dim = orig_shape
        if high_mat is not None:
            dtype, device = high_mat.dtype, high_mat.device
        elif medium_comp is not None:
            dtype, device = medium_comp.dtype, medium_comp.device
        else:
            dtype, device = torch.bfloat16, self.device

        # Allocate directly on the target device; a CPU allocation followed by a
        # copy saves no device memory and costs an extra host buffer + transfer.
        full = torch.zeros((n_rows, in_dim), dtype=dtype, device=device)

        if high_mat is not None:
            full = full + high_mat.to(full.device)

        if medium_comp is not None and medium_recov is not None:
            med = ef.get('medium_indices').long().to(full.device)
            med_mat = medium_recov.to(full.device) @ medium_comp.to(full.device)
            # Indexed assignment needs matching dtypes; the medium factors may
            # differ from the high part (e.g. float32 vs bfloat16).
            full[med] = med_mat.to(full.dtype)
        return full

    def compress_all_experts(self):
        """Compress the experts of every MoE layer and swap in the compressed modules.

        Steps (2-3 only run when ``self.use_low_rank_preservation`` is True):

        1. Collect calibration features for all layers (``collect_all_features``).
        2. Per layer and expert group:
           - score the intermediate channels (``_get_importance``);
           - compress gate/up (``_apply_differential_compression``);
           - rebuild the compressed gate/up to obtain compressed intermediates
             ``Xc``, alongside the true intermediates ``Xt`` and the true
             down_proj outputs ``Y``.
        3. Fit by ridge regression a bridge ``R`` (Xc -> Xt) and a shared down
           matrix ``B`` (Xc -> Y), and set ``base_down = B @ R``. Each expert's
           residual ``W_down - base_down`` is compressed with
           ``_apply_differential_compression_down``.
        4. Replace the layer's MoE module (``_create_and_replace_oage_moe_v2``)
           and free that layer's features.

        Returns:
            ``(model, tokenizer)``; the model is modified in place.
        """
        # Collect all features first.
        self.collect_all_features()

        compressed_model = self.model
        self._prepare_features_for_processing()

        for layer_idx in self.target_layers:
            print(f"\nCompressing layer {layer_idx}...")

            # Bound before the branch so the replacement step below also works
            # when the low-rank path is disabled (previously UnboundLocalError).
            layer_expert_factors = {}
            # Shared per-layer base_down; only one copy is stored.
            layer_base_down = None

            if self.use_low_rank_preservation:
                print(f"[INFO] 使用新的低秩保留架构处理层 {layer_idx}")
                Xc_cat, Xt_cat, Y_cat = [], [], []

                for expert_group in self.layers_expert_groups[layer_idx]:
                    expert_indices = self.layers_expert_groups[layer_idx][expert_group]

                    selected_hidden_features_list = {}
                    intermediate_features_list = {}
                    routing_weights_list = {}
                    for expert_idx in expert_indices:
                        selected_hidden, _, _, _, intermediate, routing_weights = self.get_layer_expert_features(
                            layer_idx, expert_idx)
                        selected_hidden_features_list[expert_idx] = selected_hidden
                        intermediate_features_list[expert_idx] = intermediate
                        routing_weights_list[expert_idx] = routing_weights

                    experts = compressed_model.model.layers[layer_idx].mlp.experts
                    experts_gate_proj_weights = {
                        idx: experts[idx].gate_proj.weight.data for idx in expert_indices}
                    experts_up_proj_weights = {
                        idx: experts[idx].up_proj.weight.data for idx in expert_indices}
                    experts_down_proj_weights = {
                        idx: experts[idx].down_proj.weight.data for idx in expert_indices}

                    # Activation-aware importance scores and channel groups.
                    importance_info = self._get_importance(
                        expert_indices,
                        intermediate_features_list,
                        routing_weights_list,
                        experts_down_proj_weights,
                        layer_idx,
                    )

                    # Differential (high/medium/low) compression of gate/up.
                    group_expert_factors = self._apply_differential_compression(
                        expert_indices, importance_info, experts_gate_proj_weights,
                        experts_up_proj_weights, selected_hidden_features_list, {},
                        layer_idx=layer_idx, group_idx=expert_group,
                    )

                    Xc_cat, Xt_cat, Y_cat, cached = self._collect_compressed_activations(
                        layer_idx, expert_indices, group_expert_factors,
                        selected_hidden_features_list, intermediate_features_list,
                    )

                # NOTE(review): this block sits AFTER the group loop, so it only
                # sees the LAST group's variables (expert_indices, importance_info,
                # experts_down_proj_weights, group_expert_factors, cached,
                # expert_group, Xc_cat/Xt_cat/Y_cat). Experts of earlier groups
                # never reach layer_expert_factors. The "keep only the first
                # base_down per layer" logic suggests it was meant to run per
                # group, but then later groups' deltas would be relative to their
                # own base rather than layer_base_down. Left unchanged pending a
                # check of _apply_differential_compression_down and
                # _create_and_replace_oage_moe_v2.
                if len(Xc_cat) > 0:
                    compute_device = self._select_group_compute_device()

                    Xc_cat = torch.cat([x.to(compute_device) for x in Xc_cat], dim=0)
                    Xt_cat = torch.cat([x.to(compute_device) for x in Xt_cat], dim=0)
                    Y_cat = torch.cat([x.to(compute_device) for x in Y_cat], dim=0)

                    base_down_matrix, R = self._fit_base_down_bridge(Xc_cat, Xt_cat, Y_cat)

                    delta_weights = {}
                    for idx in expert_indices:
                        W_orig = experts_down_proj_weights[idx].to(base_down_matrix.device)
                        delta_weights[idx] = (
                            W_orig - base_down_matrix).float().detach().cpu()

                    # Reuse the gate/up row groups as down_proj column groups:
                    # intermediate channel i is row i of gate/up and column i of down.
                    per_expert = importance_info.get('per_expert_groups')
                    per_expert_down = {}
                    for idx in expert_indices:
                        grp = per_expert.get(
                            idx, {"high": [], "medium": [], "low": []})
                        per_expert_down[idx] = {
                            'high': list(grp.get('high', [])),
                            'medium': list(grp.get('medium', [])),
                            'low': list(grp.get('low', [])),
                        }
                    importance_info_down = {
                        'per_expert_groups': per_expert_down,
                        'high_count': importance_info.get('high_count', 0),
                        'mid_count': importance_info.get('mid_count', 0),
                        'low_count': importance_info.get('low_count', 0),
                    }
                    importance_info['down'] = importance_info_down

                    # None budgets are allowed by __init__; int(None) used to raise.
                    group_budget = None
                    if hasattr(self, 'layer_group_params') and layer_idx in self.layer_group_params and expert_group in self.layer_group_params[layer_idx]:
                        raw_budget = self.layer_group_params[layer_idx][expert_group]
                        if raw_budget is not None:
                            group_budget = int(raw_budget)

                    # Column-grouped low-rank compression of the down residuals.
                    down_factors = self._apply_differential_compression_down(
                        expert_indices,
                        importance_info_down,
                        delta_weights,
                        expert_features=cached,
                        expert_factors=group_expert_factors,
                        base_down_matrix=base_down_matrix,
                        R=R,
                        layer_idx=layer_idx,
                        group_idx=expert_group,
                        group_budget=group_budget,
                    )

                    # Merge down factors into the gate/up factors used at runtime.
                    for idx in expert_indices:
                        gf = group_expert_factors.get(idx, {})
                        df = down_factors.get(idx, {})
                        if df:
                            gf.update(df)
                        group_expert_factors[idx] = gf

                    # base_down is stored once per layer, not per expert.
                    if layer_base_down is None and base_down_matrix is not None:
                        layer_base_down = base_down_matrix

                    layer_expert_factors.update(group_expert_factors)

            print(
                f"Replacing MoE module for layer {layer_idx} with compressed experts. layer_base_down set: {layer_base_down is not None}")
            if layer_expert_factors:
                self._create_and_replace_oage_moe_v2(
                    compressed_model, layer_idx, layer_expert_factors, layer_base_down)

            if layer_idx in self.collected_features:
                print(f"Releasing features for layer {layer_idx}...")
                del self.collected_features[layer_idx]
                gc.collect()
                torch.cuda.empty_cache()

        self._release_all_features()

        return compressed_model, self.tokenizer

    def _collect_compressed_activations(self, layer_idx, expert_indices, group_expert_factors,
                                        selected_hidden_features_list, intermediate_features_list,
                                        sample_cap_per_expert=8192 * 2):
        """Run each expert's calibration inputs through its compressed gate/up.

        Returns:
            ``(Xc_list, Xt_list, Y_list, cached)``. The three lists hold one
            float32 CPU tensor per expert: compressed intermediates, true
            intermediates and true down_proj outputs, randomly sub-sampled to at
            most ``sample_cap_per_expert`` rows. ``cached[expert_idx]`` is
            ``{'Xc': <all compressed intermediates on the expert's device>}``.
        """
        experts = self.model.model.layers[layer_idx].mlp.experts
        cached = {}
        Xc_list, Xt_list, Y_list = [], [], []

        for expert_idx in expert_indices:
            expert = experts[expert_idx]
            ef = group_expert_factors.get(expert_idx, {})
            orig_shape = ef.get('original_shape', None)
            high_idx = ef.get('high_indices', None)
            dev = expert.gate_proj.weight.device

            G_full = self._reconstruct_full_matrix(
                ef, ef.get('gate_high', None), high_idx,
                ef.get('gate_medium_compressed', None), ef.get('gate_medium_recovery', None),
                orig_shape).to(dev)
            U_full = self._reconstruct_full_matrix(
                ef, ef.get('up_high', None), high_idx,
                ef.get('up_medium_compressed', None), ef.get('up_medium_recovery', None),
                orig_shape).to(dev)

            H = selected_hidden_features_list[expert_idx].to(dev)
            Gout = H @ G_full.T
            Uout = H @ U_full.T
            Xc = expert.act_fn(Gout) * Uout

            # True intermediate state and true output.
            Xt = intermediate_features_list[expert_idx].to(dev)
            Y = expert.down_proj(Xt)

            # Random down-sampling bounds memory for the group regression.
            if Xc.shape[0] > sample_cap_per_expert:
                idx = torch.randperm(Xc.shape[0], device=dev)[:sample_cap_per_expert]
                Xc_s, Xt_s, Y_s = Xc[idx], Xt[idx], Y[idx]
            else:
                Xc_s, Xt_s, Y_s = Xc, Xt, Y

            Xc_list.append(Xc_s.float().cpu())
            Xt_list.append(Xt_s.float().cpu())
            Y_list.append(Y_s.float().cpu())

            cached[expert_idx] = dict(Xc=Xc)

            del Xc, Xt, Y, G_full, U_full, Gout, Uout, H
            gc.collect()
            torch.cuda.empty_cache()

        return Xc_list, Xt_list, Y_list, cached

    def _select_group_compute_device(self):
        """Return ``"cuda:<i>"`` for the GPU with the most free memory, else ``"cpu"``."""
        if not torch.cuda.is_available():
            return "cpu"
        best_device, max_free_memory = 0, 0
        for i in range(torch.cuda.device_count()):
            # mem_get_info reports memory actually free on the device. The old
            # total - memory_allocated() ignored the caching allocator's reserved
            # blocks and other processes, so it could pick an effectively full GPU.
            free_memory, _ = torch.cuda.mem_get_info(i)
            if free_memory > max_free_memory:
                max_free_memory, best_device = free_memory, i
        compute_device = f"cuda:{best_device}"
        print(
            f"  Using device {compute_device} for group compression (most free memory: {max_free_memory/1e9:.1f}GB)")
        return compute_device

    def _fit_base_down_bridge(self, Xc_cat, Xt_cat, Y_cat, ridge=1e-3, jitter=1e-6):
        """Fit the bridge ``R`` and the shared down matrix ``B`` by ridge regression.

        Solves ``min_R ||Xt - Xc R||^2 + ridge ||R||^2`` and
        ``min_B ||Y - Xc B^T||^2 + ridge ||B||^2``, each with an extra
        ``jitter * I`` for numerical stability. Both share the Gram matrix
        ``Xc^T Xc``, so it is formed once.

        Args:
            Xc_cat: ``(N, inter)`` compressed intermediates.
            Xt_cat: ``(N, inter)`` true intermediates.
            Y_cat: ``(N, hidden)`` true down_proj outputs.

        Returns:
            ``(base_down, R)``. ``base_down`` is ``B @ R`` with shape
            ``(hidden, inter)``, cast to ``self.dtype``. ``R`` has shape
            ``(inter, inter)`` and keeps the input dtype.
        """
        eye = torch.eye(Xc_cat.shape[1], device=Xc_cat.device, dtype=Xc_cat.dtype)
        gram = Xc_cat.T @ Xc_cat + ridge * eye
        gram = gram + jitter * eye
        R = torch.linalg.solve(gram, Xc_cat.T @ Xt_cat)           # (inter, inter)
        B_refit = torch.linalg.solve(gram, Xc_cat.T @ Y_cat).T    # (hidden, inter)
        del eye, gram
        torch.cuda.empty_cache()
        # Absorb the bridge on the right: base_down = B_refit @ R.
        return (B_refit @ R).to(self.dtype), R

    def _release_all_features(self):
        """Free every collected activation.

        Empties ``self.collected_features`` in place, so any other reference to
        the same dict also sees it emptied. It then runs the garbage collector
        and, when CUDA is available, returns cached GPU blocks.
        """
        print("Releasing all collected features...")

        features = self.collected_features
        features.clear()

        gc.collect()

        cuda_ready = torch.cuda.is_available()
        if cuda_ready:
            torch.cuda.empty_cache()

        print("All features released successfully!")

    def _prepare_features_for_processing(self):
        """Make every stored feature tensor contiguous in memory.

        Replaces each non-None ``selected_hidden`` / ``intermediate`` /
        ``routing_weights`` tensor in ``self.collected_features[layer][expert]``
        with ``tensor.contiguous()``. Tensors stay on their current device; this
        method does not move them.
        """
        print("Preparing features for batch processing...")

        for experts in self.collected_features.values():
            for store in experts.values():
                for key in ('selected_hidden', 'intermediate', 'routing_weights'):
                    tensor = store[key]
                    if tensor is None:
                        continue
                    store[key] = tensor.contiguous()

    def _get_importance(self, expert_indices, features, routing_weights, down_weights, layer_idx=None, eps: float = 1e-12):
        """计算重要性并分组。"""
        print("Computing Row-based Intermediate Importance scores (per-expert)...")

        all_scores = []
        all_positions = []
        anomaly_threshold_c = 2.5  # MAD异常阈值系数

        # 直接使用已收集的intermediate特征，避免重复计算
        for idx in expert_indices:
            # 获取设备并确保数据在正确设备上
            device = down_weights[idx].device
            intermediate_features = features.get(idx, None).to(device).float()
            intermediate_routing_weights = routing_weights.get(
                idx, None).to(device).float()
            Dw = down_weights.get(idx, None).to(device).float()

            # 内存优化：如果样本数太多，进行采样；同时保证 routing_weights 与样本一一对应
            max_samples_for_stats = int(4096 * 8)
            if intermediate_features.shape[0] > max_samples_for_stats:
                sample_norms = torch.norm(intermediate_features, dim=1, p=2)
                _, top_indices = torch.topk(
                    sample_norms, max_samples_for_stats)
                intermediate_features = intermediate_features[top_indices]
                intermediate_routing_weights = intermediate_routing_weights[top_indices]
                print(
                    f"    Sampled {max_samples_for_stats} most active samples from {intermediate_features.shape[0]}")

            # 向量化计算：W_act (N, D) = intermediate * weight
            weights = intermediate_routing_weights.view(-1, 1)  # (N,1)
            W_act = intermediate_features * weights  # (N, D)

            # # 统计量（按列/维度）
            # magnitude_score = torch.mean(torch.abs(W_act), dim=0)  # (D,)
            # variance_score = torch.var(W_act, dim=0)  # (D,)

            # median_val = torch.median(W_act, dim=0).values  # (D,)
            # mad_val = torch.median(torch.abs(W_act - median_val.unsqueeze(0)), dim=0).values  # (D,)
            # # 用 std 作为备选，避免 mad 为 0
            # small_mad_mask = mad_val < eps
            # if small_mad_mask.any():
            #     std_val = torch.std(W_act, dim=0)
            #     mad_val = torch.where(small_mad_mask, std_val + eps, mad_val)

            # threshold = median_val + anomaly_threshold_c * mad_val  # (D,)
            # anomaly_frequency = (torch.abs(W_act) > torch.abs(threshold.unsqueeze(0))).float().mean(dim=0)  # (D,)

            # # Dw 列范数 -> 每个中间维对应的 down 列范数
            # down_col_norms = torch.norm(Dw, dim=0, p=2)  # (D,)
            # xc_abs_var = torch.var(torch.abs(W_act), dim=0)  # (D,)
            # output_dynamics_score = (down_col_norms ** 2) * xc_abs_var  # (D,)

            # importance_scores = (
            #     0.2 * magnitude_score +
            #     0.2 * torch.sqrt(variance_score + eps) +
            #     0.3 * anomaly_frequency * magnitude_score +
            #     0.4 * output_dynamics_score
            # )  # (D,)

            w_mean = 0.0 # 稳定高能的重要性
            w_var = 0.1  # 能量波动的重要性
            w_peak = 0.9 # 峰值的重要性

            weighted_activations_abs = torch.abs(W_act)

            # 1. 计算三个核心指标
            # a. 平均能量得分 (mean of squares)
            mean_energy_score = torch.mean(weighted_activations_abs ** 2, dim=0)

            # b. 能量波动得分 (variance of squares)
            var_energy_score = torch.var(weighted_activations_abs ** 2, dim=0)

            # c. 峰值奖励
            peak_bonus_score = torch.max(weighted_activations_abs, dim=0).values

            # 2. 对各指标进行归一化，防止数值尺度差异过大
            mean_energy_score_norm = mean_energy_score / (mean_energy_score.max() + eps)
            var_energy_score_norm = var_energy_score / (var_energy_score.max() + eps)
            peak_bonus_score_norm = peak_bonus_score / (peak_bonus_score.max() + eps)

            # 3. 加权求和得到“混合激活分数”
            hybrid_activation_scores = (
                w_mean * mean_energy_score_norm + 
                w_var * var_energy_score_norm + 
                w_peak * peak_bonus_score_norm
            )

            # 3. 计算“输出能量 (Projection Energy)”
            projection_energy = torch.norm(Dw, p=2, dim=0) ** 2

            # 4. 计算最终的重要性分数
            intrinsic_power_scores = hybrid_activation_scores * projection_energy
            
            # 2. 计算“外部协同性”分数 (全新的、快速的下游影响代理)
            if layer_idx < self.target_layers[-1]: # 如果不是最后一层
                next_layer_idx = self.target_layers[self.target_layers.index(layer_idx) + 1]
                next_moe_module = self.model.model.layers[next_layer_idx].mlp

                # a. 准备下一层所有微专家的“期望输入方向” V_in_next
                next_layer_input_vectors = []
                for e_idx_next in range(len(next_moe_module.experts)):
                    gate_w = next_moe_module.experts[e_idx_next].gate_proj.weight.data.float()
                    up_w = next_moe_module.experts[e_idx_next].up_proj.weight.data.float()
                    v_in = (gate_w + up_w) / 2.0  # (inter_dim_next, hidden_dim)
                    next_layer_input_vectors.append(v_in)
                
                # (total_next_dims, hidden_dim)
                V_in_next = torch.cat(next_layer_input_vectors, dim=0).to(device)

                # b. 用下一层输入权重的范数作为其重要性的简单代理
                # (total_next_dims,)
                next_layer_proxy_importance = torch.norm(V_in_next, p=2, dim=1)
                
                # c. 当前专家的“输出方向” V_out_curr
                # W_down_curr (Dw) shape is (hidden_dim, inter_dim_curr)
                # V_out_curr shape should be (inter_dim_curr, hidden_dim)
                V_out_curr = Dw.T

                # d. 计算加权对齐度
                # (inter_dim_curr, hidden_dim) @ (hidden_dim, total_next_dims) -> (inter_dim_curr, total_next_dims)
                alignment_matrix = torch.abs(V_out_curr @ V_in_next.T)
                
                # (inter_dim_curr, total_next_dims) @ (total_next_dims,) -> (inter_dim_curr,)
                extrinsic_synergy_scores = alignment_matrix @ next_layer_proxy_importance
            else:
                # 最后一层没有下游，协同性为0
                extrinsic_synergy_scores = torch.zeros_like(intrinsic_power_scores)
                
            norm_synergy_scores = extrinsic_synergy_scores / (extrinsic_synergy_scores.max() + eps) * intrinsic_power_scores.max()
            importance_scores = intrinsic_power_scores + 0.0 * norm_synergy_scores

            # 展开到全局 score/position 列表
            D = importance_scores.shape[0]
            for col_idx in range(D):
                all_scores.append(importance_scores[col_idx].item())
                all_positions.append((idx, col_idx))

            # 清理中间变量
            del intermediate_features, intermediate_routing_weights, W_act
            gc.collect()
            torch.cuda.empty_cache()

        print(
            f"  Computed global importance scores for {len(all_scores)} dims across {len(expert_indices)} experts")
        sorted_indices = sorted(range(len(all_scores)),
                                key=lambda i: all_scores[i], reverse=True)
        return self._create_global_groups(sorted_indices, all_positions, expert_indices, layer_idx=layer_idx)

    def _create_global_groups(self, sorted_indices, all_positions, expert_indices, layer_idx=None):
        """创建全局重要性分组"""
        total_dims = len(sorted_indices)

        high_ratio = 0.8 * (1 - self.compression_ratio) * self.layer_budget_ratio[layer_idx] / self.avg_layer_budget_ratio
        high_count = int(high_ratio * total_dims)
        medium_count = (total_dims - high_count) // 3
        low_count = total_dims - high_count - medium_count

        print(
            f"  Global grouping: {high_count} high, {medium_count} medium, {low_count} low")

        # 按专家索引组织分组结果
        per_expert_groups = {}
        for idx in expert_indices:
            per_expert_groups[idx] = {"high": [], "medium": [], "low": []}

        # 分配维度到各组
        for rank, global_idx in enumerate(sorted_indices):
            expert_idx, dim_idx = all_positions[global_idx]

            if rank < high_count:
                per_expert_groups[expert_idx]["high"].append(dim_idx)
            elif rank < high_count + medium_count:
                per_expert_groups[expert_idx]["medium"].append(dim_idx)
            else:
                per_expert_groups[expert_idx]["low"].append(dim_idx)

        # 统计总数
        total_high = sum(len(groups["high"])
                         for groups in per_expert_groups.values())
        total_mid = sum(len(groups["medium"])
                        for groups in per_expert_groups.values())
        total_low = sum(len(groups["low"])
                        for groups in per_expert_groups.values())

        print("ACI-based global grouping completed:")
        print(f"  - High importance dims: {total_high}")
        print(f"  - Medium importance dims: {total_mid}")
        print(f"  - Low importance dims: {total_low}")

        return {
            'per_expert_groups': per_expert_groups,
            'high_count': int(total_high),
            'mid_count': int(total_mid),
            'low_count': int(total_low),
        }

    def _apply_differential_compression(self, expert_indices, importance_info, experts_gate_proj_weights, experts_up_proj_weights, expert_features, group_expert_ranks, layer_idx=None, group_idx=None):
        """应用差异化压缩策略（只作用于 gate/up；按 up 的重要性分组）
        - 低组：gate 与 up 结构化剪枝（置零）
        - 高/中组：在专家内对对应行进行两段式 SVD（不同 rank），并写回相同行位置（行块压缩）。
        - 然后对“行块压缩后的完整矩阵”做一次最终两段式 SVD，得到统一的 (compressed, recovery) 因子。
        - down 不在本函数中处理。

        白化矩阵使用“该专家的输入激活”计算，严格按专家分布自适应。
        返回：compressed_gate_weights, compressed_up_weights, expert_factors
        其中 expert_factors[expert_idx] = {Gc,Grec,Uc,Urec}
        """
        expert_micro = importance_info['per_expert_groups']

        # 依据外部提供的 layer/group 参数预算，基于“中重要性组的敏感度”按比例分配（优先使用 99% 能量所需秩）
        per_expert_alloc = {idx: None for idx in expert_indices}

        if layer_idx is not None and group_idx is not None and hasattr(self, 'layer_group_params') and layer_idx in self.layer_group_params and group_idx in self.layer_group_params[layer_idx]:
            group_budget = int(
                self.layer_group_params[layer_idx][group_idx]) // 3
            # For each expert, compute a sensitivity metric based on the medium-group rows.
            sensitivities = {}
            total_sens = 0.0

            high_count = int(importance_info.get('high_count', 0))
            low_count = int(importance_info.get('low_count', 0))
            rep_w = next(iter(experts_gate_proj_weights.values()))
            hidden_dim_est = int(rep_w.shape[1]) if rep_w is not None else 1
            remain_budget = max(0, int(group_budget) -
                                high_count * hidden_dim_est)

            for idx in expert_indices:
                groups = expert_micro.get(idx, {})
                medium_rows = groups.get('medium', [])
                if not medium_rows:
                    sensitivities[idx] = 0.0
                    continue

                gw = experts_gate_proj_weights[idx]
                uw = experts_up_proj_weights[idx]
                rows_tensor = torch.tensor(
                    medium_rows, dtype=torch.long, device=gw.device)
                W_parts = []
                if gw is not None and rows_tensor.numel() > 0:
                    W_parts.append(gw.index_select(
                        0, rows_tensor).float().cpu())
                if uw is not None and rows_tensor.numel() > 0:
                    W_parts.append(uw.index_select(
                        0, rows_tensor).float().cpu())
                if not W_parts:
                    sensitivities[idx] = 0.0
                    continue
                W_med = torch.cat(W_parts, dim=0)

                # SVD on CPU for safety; compute squared singular values as energy
                with torch.no_grad():
                    _, s, _ = torch.linalg.svd(W_med)
                    energy = (s ** 2)
                    total_energy = float(
                        energy.sum().item()) if energy.numel() > 0 else 0.0
                    if total_energy <= 0:
                        rank_needed = 0
                    else:
                        cum = torch.cumsum(energy, dim=0)
                        thresh = 0.99 * total_energy
                        # find first index where cumulative >= thresh
                        idxs = (cum >= thresh).nonzero(as_tuple=False)
                        rank_needed = int(
                            idxs[0].item()) + 1 if idxs.numel() > 0 else int(len(energy))

                sensitivities[idx] = float(rank_needed)
                total_sens += sensitivities[idx]

            # 如果所有敏感度为零，则退回均等分配
            if total_sens <= 0.0:
                per = max(1, remain_budget // max(1, len(expert_indices)))
                for idx in expert_indices:
                    per_expert_alloc[idx] = per
            else:
                # 直接使用敏感度值进行分配（不做 log 平滑）
                sum_sens = sum(sensitivities.get(idx, 0.0)
                               for idx in expert_indices) or 1.0
                for idx in expert_indices:
                    per_expert_alloc[idx] = max(
                        1, int(remain_budget * (sensitivities.get(idx, 0.0) / float(sum_sens))))

        # 克隆权重以便修改
        compressed_gate_weights = {
            idx: experts_gate_proj_weights[idx].clone() for idx in expert_indices}
        compressed_up_weights = {
            idx: experts_up_proj_weights[idx].clone() for idx in expert_indices}

        # 行块压缩 helper（两段式）- 修改为返回SVD因子
        def _compress_rows_two_stage(W, row_indices, U, S_s_sqrt, S_s_inv_sqrt, expert_idx, proj_type, return_factors=False, target_params=None):
            if W is None or not row_indices:
                return W if not return_factors else (W, None, None, None)
            dev = W.device
            dtype_out = W.dtype
            rows = torch.tensor(row_indices, device=dev, dtype=torch.long)
            W_sub = W.index_select(0, rows).float()

            # 向量化白化：Ww = (W_sub @ U) * sqrt_s
            if self.vectorize_whitening and S_s_sqrt.ndim == 2 and S_s_sqrt.shape[0] == S_s_sqrt.shape[1]:
                sqrt_s = torch.diag(S_s_sqrt).to(dev)
                Ww = (W_sub @ U.to(dev)) * sqrt_s
            else:
                Ww = W_sub @ U.to(dev) @ S_s_sqrt.to(dev)

            U_w, S_w, Vt_w = torch.linalg.svd(Ww, full_matrices=False)
            hard_cap = min(Ww.shape[0], Ww.shape[1], getattr(
                self, 'gate_up_max_rank', Ww.shape[1]))

            if return_factors:
                # 基于预算约束计算最优rank：中重要性组压缩到原来的10%
                num_medium_rows = len(row_indices)
                hidden_dim = W.shape[1]  # 4096
                # 原始参数量: num_medium_rows * hidden_dim
                original_params = num_medium_rows * hidden_dim

                # 如果外部传入了 target_params（专家分配后的剩余预算），优先使用它，否则退回到默认10%
                if target_params is None:
                    target_params = int(0.1 * original_params)

                optimal_rank = target_params // (hidden_dim + num_medium_rows)
                optimal_rank = min(optimal_rank, hard_cap)

                rank = optimal_rank
                print(f"  预算约束rank选择: expert={expert_idx}, proj={proj_type}, medium_rows={num_medium_rows}, "
                      f"original_params={original_params}, target_params={target_params}, optimal_rank={optimal_rank}")

                # 计算实际压缩比（rank==0 表示全部置零）
                if rank == 0:
                    print(
                        f"  实际压缩: expert={expert_idx}, proj={proj_type}, rank=0 -> 中组置零")
                else:
                    actual_params = rank * (hidden_dim + num_medium_rows)
                    actual_ratio = actual_params / original_params
                    print(f"  实际压缩: expert={expert_idx}, proj={proj_type}, rank={rank}, "
                          f"actual_params={actual_params}, compression_ratio={actual_ratio:.3f}")

            rank = max(0, min(rank, Ww.shape[0], Ww.shape[1]))

            U_k = U_w[:, :rank]
            S_k = S_w[:rank]
            Vt_k = Vt_w[:rank, :]

            if return_factors:
                # 返回SVD因子而不还原
                # (rank, input_dim)
                compressed_factor = (
                    Vt_k @ S_s_inv_sqrt.to(dev) @ U.to(dev).T).to(dtype_out)
                recovery_factor = U_k.to(
                    dtype_out) @ torch.diag(S_k).to(dtype_out)  # (num_rows, rank)
                return W, compressed_factor, recovery_factor, rows
            else:
                # 原逻辑：还原完整矩阵
                Ww_approx = U_k @ torch.diag(S_k) @ Vt_k
                # 反白化向量化：先按列除以 sqrt_s，再右乘 U^T
                if self.vectorize_whitening and S_s_inv_sqrt.ndim == 2 and S_s_inv_sqrt.shape[0] == S_s_inv_sqrt.shape[1]:
                    inv_sqrt_s = torch.diag(S_s_inv_sqrt).to(dev)
                    Wtmp = Ww_approx * inv_sqrt_s
                    W_approx = (Wtmp @ U.to(dev).T).to(dtype_out)
                else:
                    W_approx = (Ww_approx @ S_s_inv_sqrt.to(dev)
                                @ U.to(dev).T).to(dtype_out)
                W[rows, :] = W_approx
                return W

        compression_stats = {'high_compressed': 0,
                             'medium_compressed': 0, 'low_pruned': 0}
        expert_factors = {}

        # 逐专家处理
        for expert_idx in expert_indices:
            start_time = time.time()

            # 获取重要性分组
            groups = expert_micro.get(
                expert_idx, {'high': [], 'medium': [], 'low': []})
            high_rows = groups.get('high', [])
            medium_rows = groups.get('medium', [])
            low_rows = groups.get('low', [])

            print(
                f"  - Expert {expert_idx}: 高={len(high_rows)}, 中={len(medium_rows)}, 低={len(low_rows)}")

            # 获取专家特征用于白化
            X = expert_features.get(expert_idx, None)
            U = S_s_sqrt = S_s_inv_sqrt = None
            if X is not None and X.numel() > 0:
                Xf = X.float()
                Xf = Xf - Xf.mean(dim=0, keepdim=True)
                cov = (Xf.T @ Xf) / max(1, (Xf.shape[0] - 1))
                S_vec, U_ = torch.linalg.eigh(cov)
                S_vec = S_vec.clamp(min=1e-10)
                sqrt_s_vec = torch.sqrt(S_vec)
                inv_sqrt_s_vec = 1.0 / sqrt_s_vec
                S_s_sqrt_ = torch.diag(sqrt_s_vec)
                S_s_inv_sqrt_ = torch.diag(inv_sqrt_s_vec)
                U, S_s_sqrt, S_s_inv_sqrt = U_, S_s_sqrt_, S_s_inv_sqrt_

            # 处理gate_proj
            gate_weight = compressed_gate_weights[expert_idx]
            if low_rows:
                gate_weight[low_rows, :] = 0
                compression_stats['low_pruned'] += len(low_rows)

            gate_compressed_factor = gate_recovery_factor = gate_row_indices = None
            if U is not None and medium_rows:
                print(
                    f"[INFO] 专家 {expert_idx}: 中重要性行 {len(medium_rows)}/{gate_weight.shape[0]} -> 保存SVD因子")
                # 计算并传入基于组预算分配的 target_params（如果可用）
                tp = None
                alloc = per_expert_alloc.get(expert_idx, None)
                # 保存gate_proj中重要性行的SVD因子
                _, gate_compressed_factor, gate_recovery_factor, gate_row_indices = _compress_rows_two_stage(
                    gate_weight.clone(), medium_rows, U, S_s_sqrt, S_s_inv_sqrt, expert_idx, 'gate', return_factors=True, target_params=alloc
                )
                compression_stats['medium_compressed'] += len(medium_rows)

            if U is not None and high_rows:
                # 高重要性行保留原始格式，不进行压缩
                print(
                    f"[INFO] 专家 {expert_idx}: 高重要性行 {len(high_rows)}/{gate_weight.shape[0]} -> 保留原始格式")
                # 不对高重要性行进行任何压缩，保持原始权重

            compressed_gate_weights[expert_idx] = gate_weight

            # 处理up_proj
            up_weight = compressed_up_weights[expert_idx]
            if low_rows:
                up_weight[low_rows, :] = 0

            up_compressed_factor = up_recovery_factor = up_row_indices = None
            if U is not None and medium_rows:
                # 对 up 使用与 gate 相同的 target_params 逻辑
                tp_up = None
                alloc = per_expert_alloc.get(expert_idx, None)
                # 保存up_proj中重要性行的SVD因子
                _, up_compressed_factor, up_recovery_factor, up_row_indices = _compress_rows_two_stage(
                    up_weight.clone(), medium_rows, U, S_s_sqrt, S_s_inv_sqrt, expert_idx, 'up', return_factors=True, target_params=alloc
                )

            compressed_up_weights[expert_idx] = up_weight

            end_time = time.time()
            print(
                f"  - Expert {expert_idx} compression took {end_time - start_time:.2f} seconds.")

            # 计算稀疏度统计
            gate_nonzero = (gate_weight.abs() > 1e-6).sum().item()
            gate_total = gate_weight.numel()
            up_nonzero = (up_weight.abs() > 1e-6).sum().item()
            up_total = up_weight.numel()

            # 创建专家因子
            # 对高重要性组：保存完整矩阵，但将非高行置零（便于 runtime 直接使用完整形状）
            expert_high_gate_matrix = None
            expert_high_up_matrix = None
            if high_rows:
                orig_gate = experts_gate_proj_weights[expert_idx]
                orig_up = experts_up_proj_weights[expert_idx]
                # 构建全零矩阵并仅填充高重要性行
                expert_high_gate_matrix = torch.zeros_like(orig_gate)
                expert_high_up_matrix = torch.zeros_like(orig_up)
                idx_tensor = torch.tensor(
                    high_rows, dtype=torch.long, device=orig_gate.device)
                expert_high_gate_matrix[idx_tensor,
                                        :] = orig_gate[idx_tensor, :].clone()
                expert_high_up_matrix[idx_tensor,
                                      :] = orig_up[idx_tensor, :].clone()
            else:
                expert_high_gate_matrix = torch.zeros_like(
                    experts_gate_proj_weights[expert_idx])
                expert_high_up_matrix = torch.zeros_like(
                    experts_up_proj_weights[expert_idx])

            expert_factors[expert_idx] = dict(
                # gate_high/up_high 保存为完整矩阵（非高行为 0），以便 runtime 直接使用相同形状
                gate_high=expert_high_gate_matrix,
                up_high=expert_high_up_matrix,
                high_indices=torch.tensor(
                    high_rows, dtype=torch.long) if high_rows else None,

                # 中重要性组：保存低秩分解因子（不重构）
                # SVD压缩因子 (rank, input_dim)
                gate_medium_compressed=gate_compressed_factor,
                # SVD恢复因子 (num_medium_rows, rank)
                gate_medium_recovery=gate_recovery_factor,
                # SVD压缩因子 (rank, input_dim)
                up_medium_compressed=up_compressed_factor,
                # SVD恢复因子 (num_medium_rows, rank)
                up_medium_recovery=up_recovery_factor,
                medium_indices=gate_row_indices if gate_row_indices is not None else (
                    torch.tensor(medium_rows) if medium_rows else None),

                # 低重要性组：置零（稀疏）
                low_indices=torch.tensor(low_rows) if low_rows else None,

                # 保留元数据用于调试
                original_shape=(gate_weight.shape[0], gate_weight.shape[1]),
                compression_applied=True
            )

        print(f"Differential compression stats: HighCompressed={compression_stats['high_compressed']}, "
              f"MediumCompressed={compression_stats['medium_compressed']}, "
              f"LowPruned={compression_stats['low_pruned']}")

        return expert_factors

    def _apply_differential_compression_down(self, expert_indices, importance_info_down, experts_down_weights, expert_features=None, expert_factors=None, base_down_matrix=None, R=None, layer_idx=None, group_idx=None, group_budget=None):
        """按列（hidden 输出维度）对 down_proj 做差异化压缩。

        思路：
        - importance_info_down 使用与 `_create_global_groups` 相同的分组结构（per_expert_groups）
        - 对每个专家，计算 W_delta = W_orig - base_down_matrix (若提供)
        - 将 W_delta 转置，得到 (inter_dim, hidden_dim)，对转置后的行（即原始的输出列）按 high/medium/low 分组进行两段式 SVD
        - 中组返回低秩因子（compressed, recovery），高组保留原始列切片，低组置零

        返回：compressed_down_weights, recovery_matrices, expert_down_factors
        """
        print("Applying column-wise differential compression for down_proj...")

        expert_micro = importance_info_down.get(
            'per_expert_groups', {}) if importance_info_down else {}

        # 如果提供了 group_budget，则基于所有专家的 medium 列敏感度按比例分配下游 budget
        per_expert_target_params = {idx: None for idx in expert_indices}
        group_budget = group_budget // 3
        base_matrix_budget = base_down_matrix.numel() if base_down_matrix is not None else 0
        group_budget = (group_budget - int(importance_info_down.get('high_count', 0)) * next(iter(
            experts_down_weights.values())).shape[0] - base_matrix_budget) if group_budget is not None else None
        if group_budget is not None and group_budget > 0:
            # 计算每个专家的敏感度：在其 medium_cols 上计算达到99%能量所需秩
            sensitivities = {}
            total_sens = 0.0
            for idx in expert_indices:
                groups = expert_micro.get(idx, {})
                medium_cols = groups.get('medium', [])
                if not medium_cols:
                    sensitivities[idx] = 0.0
                    continue
                try:
                    W_orig = experts_down_weights.get(idx, None)
                    if W_orig is None:
                        sensitivities[idx] = 0.0
                        continue
                    # operate on CPU for safety
                    Wc = W_orig.index_select(1, medium_cols)
                    # we want singular values of Wc.T (columns as rows) -> transpose
                    Wc_t = Wc
                    with torch.no_grad():
                        _, s, _ = torch.linalg.svd(Wc_t)
                        energy = (s ** 2)
                        total_energy = float(
                            energy.sum().item()) if energy.numel() > 0 else 0.0
                        if total_energy <= 0:
                            rank_needed = 0
                        else:
                            cum = torch.cumsum(energy, dim=0)
                            thresh = 0.99 * total_energy
                            idxs = (cum >= thresh).nonzero(as_tuple=False)
                            rank_needed = int(
                                idxs[0].item()) + 1 if idxs.numel() > 0 else int(len(energy))
                    sensitivities[idx] = float(rank_needed)
                    total_sens += sensitivities[idx]
                except Exception:
                    sensitivities[idx] = 0.0

            if total_sens <= 0.0:
                # 均等分配
                per = max(1, group_budget // max(1, len(expert_indices)))
                for idx in expert_indices:
                    per_expert_target_params[idx] = per
            else:
                # 直接使用敏感度值进行分配（不做 log 平滑）
                sum_sens = sum(sensitivities.get(idx, 0.0)
                               for idx in expert_indices) or 1.0
                for idx in expert_indices:
                    per_expert_target_params[idx] = max(
                        1, int(group_budget * (sensitivities.get(idx, 0.0) / float(sum_sens))))

        compressed_down = {}
        recovery_down = {}
        expert_down_factors = {}

        # 行块压缩器（对转置后的矩阵进行行压缩，返回因子）
        def _compress_cols_two_stage_local(W, col_indices, U=None, S_s_sqrt=None, S_s_inv_sqrt=None, expert_idx=None, proj_type='down', return_factors=True, target_params=None):
            dev = W.device
            dtype_out = W.dtype
            cols = torch.tensor(col_indices, device=dev, dtype=torch.long)
            W_sub = W.index_select(1, cols).float()

            # 白化（如果提供了 U / S_s_sqrt）
            if U is not None and S_s_sqrt is not None and S_s_inv_sqrt is not None:
                try:
                    # 优先使用向量化形式
                    Ww = (W_sub @ U.to(W.device)) * \
                        S_s_sqrt.unsqueeze(0).to(W.device)
                except Exception:
                    # 退回到明确的矩阵乘法
                    Ww = W_sub @ U.to(W.device) @ S_s_sqrt.to(W.device)
            else:
                Ww = W_sub

            # 直接对 Ww 做SVD（此处不做白化，保持稳定）
            U_w, S_w, Vt_w = torch.linalg.svd(Ww, full_matrices=False)
            hard_cap = min(W_sub.shape[0], W_sub.shape[1])

            # 基于预算约束选择rank（简单策略：中组压缩到原始参数的10%或基于target_params）
            num_medium_cols = len(col_indices)
            # W_sub shape: (hidden_dim, num_medium_cols)
            hidden_dim = W_sub.shape[0]
            original_params = num_medium_cols * hidden_dim
            if target_params is None:
                target_params = int(0.1 * original_params)
            optimal_rank = target_params // (hidden_dim + num_medium_cols) if (
                hidden_dim + num_medium_cols) > 0 else 0
            optimal_rank = min(optimal_rank, hard_cap)
            rank = max(1, int(optimal_rank))

            U_k = U_w[:, :rank]        # (hidden_dim, rank)
            S_k = S_w[:rank]           # (rank,)
            Vt_k = Vt_w[:rank, :]      # (rank, num_medium_cols)

            # 期望输出（供 runtime 使用）:
            # - compressed_factor: (rank, hidden_dim)  (用于 rank->hidden_dim 的线性层, 存为 dc)
            # - recovery_factor: (num_medium_cols, rank) (用于 num_med->rank 的线性层, 存为 dr)
            # compressed = (U_k @ diag(S_k)).T -> shape (rank, hidden_dim)
            compressed_factor = (U_k @ torch.diag(S_k)).T.to(dtype_out)
            # recovery = Vt_k.T -> shape (num_medium_cols, rank)
            recovery_factor = (Vt_k @ S_s_inv_sqrt.to(dev) @
                               U.to(dev).T).T.to(dtype_out)

            if recovery_factor.dim() == 3:
                recovery_factor = recovery_factor.squeeze(0)

            return W, compressed_factor, recovery_factor, cols

        for expert_idx in expert_indices:
            groups = expert_micro.get(
                expert_idx, {'high': [], 'medium': [], 'low': []})
            high_cols = groups.get('high', [])
            medium_cols = groups.get('medium', [])
            low_cols = groups.get('low', [])

            print(
                f"  - Expert {expert_idx} down cols: high={len(high_cols)}, mid={len(medium_cols)}, low={len(low_cols)}")

            W_orig = experts_down_weights.get(expert_idx)
            if W_orig is None:
                continue

            # 准备工作矩阵和 delta
            W = W_orig.clone()
            if base_down_matrix is not None:
                W = W.to(base_down_matrix.device)
                # (hidden_dim, inter_dim)
                W_delta = (W - base_down_matrix).to(W.dtype)
                Wt = W_delta
            else:
                # 无基准矩阵时直接使用原始 W
                W = W.to(next(iter(experts_down_weights.values())).device)
                Wt = W

            if low_cols:
                W[:, low_cols] = 0

            down_compressed_factor = None
            down_recovery = None
            down_row_indices = None

            # 计算该专家的白化矩阵（如果提供了 expert_features）
            X = None
            U = S_s_sqrt = S_s_inv_sqrt = None
            if isinstance(expert_features, dict):
                X = expert_features.get(expert_idx, None)
                X = X['Xc']
                X = X.index_select(1, torch.tensor(
                    medium_cols, dtype=torch.long).to(X.device))
            if X is not None and isinstance(X, torch.Tensor) and X.numel() > 0:
                Xf = X.float()
                Xf = Xf - Xf.mean(dim=0, keepdim=True)
                cov = (Xf.T @ Xf) / max(1, (Xf.shape[0] - 1))
                S_vec, U_ = torch.linalg.eigh(cov)
                S_vec = S_vec.clamp(min=1e-10)
                sqrt_s_vec = torch.sqrt(S_vec)
                inv_sqrt_s_vec = 1.0 / sqrt_s_vec
                S_s_sqrt_ = torch.diag(sqrt_s_vec)
                S_s_inv_sqrt_ = torch.diag(inv_sqrt_s_vec)
                U, S_s_sqrt, S_s_inv_sqrt = U_, S_s_sqrt_, S_s_inv_sqrt_

            if medium_cols:
                tp = None
                if per_expert_target_params.get(expert_idx) is not None:
                    tp = per_expert_target_params.get(expert_idx)
                _, comp_fac, rec_fac, rows = _compress_cols_two_stage_local(
                    Wt, medium_cols, U=U, S_s_sqrt=S_s_sqrt, S_s_inv_sqrt=S_s_inv_sqrt, expert_idx=expert_idx, proj_type='down', return_factors=True, target_params=tp)
                down_compressed_factor = comp_fac
                down_recovery = rec_fac
                down_row_indices = rows

            # 高重要性列保留为完整矩阵（非高列置零），以便 runtime 直接使用相同形状
            down_high_matrix = None
            if high_cols:
                orig = W_orig
                # 构建与原始相同形状的全零矩阵并仅填充高重要性列
                down_high_matrix = torch.zeros_like(orig)
                idx_tensor = torch.tensor(
                    high_cols, dtype=torch.long, device=orig.device)
                down_high_matrix[:, idx_tensor] = orig[:, idx_tensor].clone()
            else:
                down_high_matrix = torch.zeros_like(W_orig)

            # 构建返回结构
            ef = {}
            if down_compressed_factor is not None:
                ef['down_medium_compressed'] = down_compressed_factor
                ef['down_medium_recovery'] = down_recovery
                ef['down_medium_indices'] = down_row_indices
            if down_high_matrix is not None:
                ef['down_high'] = down_high_matrix
            if low_cols:
                ef['down_low_indices'] = torch.tensor(
                    low_cols, dtype=torch.long)

            # 保留原始形状信息
            ef['down_original_shape'] = (W_orig.shape[0], W_orig.shape[1])
            expert_down_factors[expert_idx] = ef

        print("Column-wise differential compression for down_proj completed.")
        return expert_down_factors

    def _create_and_replace_oage_moe_v2(self, model, layer_idx, experts_factors, layer_base_down=None):
        """Swap the MoE block of one decoder layer for an ``OAGEDeepseekMoE`` (factor-dict format).

        Args:
            model: model whose ``model.model.layers[layer_idx].mlp`` is replaced.
            layer_idx: index of the decoder layer to modify.
            experts_factors: dict expert_idx -> per-expert factor dict, passed through
                to ``OAGEDeepseekMoE``.
            layer_base_down: optional value forwarded as ``layer_base_down``.
        """
        print(f"Creating OAGE MoE v2 for layer {layer_idx}...")

        decoder_layer = model.model.layers[layer_idx]
        replaced_moe = decoder_layer.mlp
        decoder_layer.mlp = OAGEDeepseekMoE(
            original_moe=replaced_moe,
            experts_factors=experts_factors,
            layer_base_down=layer_base_down,
        )

        # Drop the local reference and release cached GPU memory.
        del replaced_moe
        torch.cuda.empty_cache()

        print(f"Layer {layer_idx} MoE module replaced with OAGE MoE v2")

    def _create_and_replace_oage_moe(self, model, layer_idx, layer_compressed_weights, layer_recovery_matrices, layer_bias_corrections, layer_based_weights):
        """Legacy path: install an ``OAGEDeepseekMoE`` built from compressed data in layer ``layer_idx``.

        The layer's expert grouping is read from ``self.layers_expert_groups``; the
        remaining arguments are forwarded unchanged to ``OAGEDeepseekMoE``.
        NOTE(review): the v2 path constructs ``OAGEDeepseekMoE`` with different keyword
        arguments; confirm the constructor still accepts these.
        """
        print(f"Creating OAGE MoE for layer {layer_idx}...")

        decoder_layer = model.model.layers[layer_idx]
        decoder_layer.mlp = OAGEDeepseekMoE(
            original_moe=decoder_layer.mlp,
            expert_groups=self.layers_expert_groups[layer_idx],
            recovery_matrices=layer_recovery_matrices,
            compressed_weights=layer_compressed_weights,
            bias_corrections=layer_bias_corrections,
            base_weights=layer_based_weights,
        )

        print(f"Layer {layer_idx} MoE module replaced with OAGE MoE")


class AddAuxiliaryLoss(torch.autograd.Function):
    """
    Pass ``x`` through unchanged while wiring an auxiliary loss into autograd.

    Forward returns ``x`` as-is. Backward forwards ``grad_output`` to ``x``. If the
    auxiliary ``loss`` requires grad, backward also hands it a gradient of ones,
    so the aux loss is back-propagated without being added to the output.
    """

    @staticmethod
    def forward(ctx, x, loss):
        # Only single-element auxiliary losses are supported.
        assert loss.numel() == 1
        ctx.loss_dtype, ctx.loss_needs_grad = loss.dtype, loss.requires_grad
        return x

    @staticmethod
    def backward(ctx, grad_output):
        aux_grad = (
            torch.ones(1, dtype=ctx.loss_dtype, device=grad_output.device)
            if ctx.loss_needs_grad
            else None
        )
        return grad_output, aux_grad


class OAGEDeepseekMLP(nn.Module):
    """
    Compressed (OAGE) DeepSeek expert MLP rebuilt from per-importance-group factors.

    The gate/up activations are assembled in three steps. First, full "high"
    matrices are applied over the whole intermediate width. Second, low-rank
    factors recompute the "medium" rows and overwrite them. The output projection
    is then the sum of three parts: a full "high" down matrix, a low-rank path over
    the "medium" columns, and an optional base down projection supplied at call time.
    """

    def __init__(self, expert_factors, act_fn, anchor_weight=None):
        """
        Args:
            expert_factors: dict of compression outputs; every key is optional.
                - 'gate_high', 'up_high': full (intermediate_size, hidden_dim)
                  matrices, each used directly as an nn.Linear weight. Rows outside
                  the high group are expected to be 0.
                - 'gate_medium_compressed', 'up_medium_compressed': (rank, hidden_dim)
                  weights of the first low-rank stage.
                - 'gate_medium_recovery', 'up_medium_recovery': (num_medium_rows, rank)
                  weights of the second low-rank stage.
                - 'medium_indices': rows of the medium group (tensor, or sequence of
                  ints converted to an int64 tensor). The medium gate/up path is built
                  only when these indices and all four medium factors are present.
                - 'down_high': (hidden_dim, intermediate_size) weight applied to the
                  whole intermediate activation.
                - 'down_medium_recovery' (num_medium_cols, rank),
                  'down_medium_compressed' (rank, hidden_dim) and
                  'down_medium_indices': low-rank path for the medium columns of
                  down_proj.
                - 'down_original_shape': (hidden_dim, intermediate_size) of the
                  original down_proj; used to validate 'down_high'.
                - 'original_shape': element 0 is used as intermediate_size.
            act_fn: activation applied to the assembled gate output.
            anchor_weight: optional tensor whose device/dtype the module adopts.
                Otherwise the placement comes from gate_high, up_high or
                gate_medium_compressed, in that order.

        Raises:
            ValueError: if 'down_high' disagrees with 'down_original_shape'.
        """
        super().__init__()

        self.act_fn = act_fn

        gate_high = expert_factors.get('gate_high', None)
        up_high = expert_factors.get('up_high', None)
        gate_compressed = expert_factors.get('gate_medium_compressed', None)
        gate_recovery = expert_factors.get('gate_medium_recovery', None)
        up_compressed = expert_factors.get('up_medium_compressed', None)
        up_recovery = expert_factors.get('up_medium_recovery', None)
        down_comp = expert_factors.get('down_medium_compressed', None)
        down_rec = expert_factors.get('down_medium_recovery', None)
        down_med_indices = expert_factors.get('down_medium_indices', None)
        down_high = expert_factors.get('down_high', None)
        down_orig_shape = expert_factors.get('down_original_shape', None)

        target_device, target_dtype = self._infer_placement(
            anchor_weight, gate_high, up_high, gate_compressed)

        # Index tensors are non-persistent buffers. module.to() moves them together
        # with the weights, so forward does not copy them host->device on every call.
        # They stay out of state_dict, so checkpoint keys are unchanged.
        self.register_buffer(
            'medium_indices',
            self._as_index_tensor(expert_factors.get('medium_indices', None)),
            persistent=False)

        # Presence of the high group depends only on gate_high/up_high.
        self.has_high_group = gate_high is not None or up_high is not None

        self.has_medium_group = all(
            t is not None for t in (gate_compressed, gate_recovery,
                                    up_compressed, up_recovery, self.medium_indices))
        if self.has_medium_group:
            # gate/up medium rows: hidden_dim -> rank -> num_medium_rows
            self.gate_medium_compressed = self._linear_from_weight(
                gate_compressed, target_device, target_dtype)
            self.gate_medium_recovery = self._linear_from_weight(
                gate_recovery, target_device, target_dtype)
            self.up_medium_compressed = self._linear_from_weight(
                up_compressed, target_device, target_dtype)
            self.up_medium_recovery = self._linear_from_weight(
                up_recovery, target_device, target_dtype)

        self.has_down_medium = (down_comp is not None and down_rec is not None
                                and down_med_indices is not None)
        if self.has_down_medium:
            # The factors are stored as (num_med, rank) and (rank, hidden_dim), so the
            # Linear weights are their transposes: num_med -> rank -> hidden_dim.
            self.down_medium_recovery = self._linear_from_weight(
                down_rec.T, target_device, target_dtype)
            self.down_medium_compressed = self._linear_from_weight(
                down_comp.T, target_device, target_dtype)
            self.register_buffer(
                'down_medium_indices',
                self._as_index_tensor(down_med_indices),
                persistent=False)

        self.has_down_high = down_high is not None
        if self.has_down_high:
            if down_orig_shape is not None and down_orig_shape[0] != down_high.shape[0]:
                raise ValueError(
                    f"down_high has {down_high.shape[0]} rows but down_original_shape "
                    f"declares hidden_dim={down_orig_shape[0]}")
            # Maps the full intermediate activation -> hidden_dim. Columns outside
            # the high group are expected to be 0.
            self.down_high_linear = self._linear_from_weight(
                down_high, target_device, target_dtype)

        # Original gate/up matrix shape; element 0 sizes the intermediate activation.
        self.original_shape = expert_factors.get('original_shape', None)

        # Full-size high-importance gate/up layers: hidden_dim -> intermediate_size.
        self.gate_high_linear = (None if gate_high is None else
                                 self._linear_from_weight(gate_high, target_device, target_dtype))
        self.up_high_linear = (None if up_high is None else
                               self._linear_from_weight(up_high, target_device, target_dtype))

        # Make sure the whole module (including the index buffers) sits on the target device/dtype.
        self.to(device=target_device, dtype=target_dtype)

        # Inference-only module: freeze every parameter.
        for param in self.parameters():
            param.requires_grad = False

    @staticmethod
    def _infer_placement(*candidates):
        """Return (device, dtype) of the first non-None candidate, else (cuda:0 or cpu, default dtype)."""
        for tensor in candidates:
            if tensor is not None:
                return tensor.device, tensor.dtype
        fallback = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
        return fallback, torch.get_default_dtype()

    @staticmethod
    def _as_index_tensor(indices):
        """Return indices as a tensor: None and tensors pass through, sequences become int64."""
        if indices is None or isinstance(indices, torch.Tensor):
            return indices
        return torch.tensor(indices, dtype=torch.long)

    @staticmethod
    def _linear_from_weight(weight, device, dtype):
        """Build a bias-free nn.Linear whose (out_features, in_features) weight is a copy of `weight`."""
        out_features, in_features = weight.shape
        # skip_init avoids a throw-away random initialisation of a weight we overwrite anyway.
        layer = torch.nn.utils.skip_init(
            nn.Linear, in_features, out_features, bias=False, device=device, dtype=dtype)
        layer.weight.data.copy_(weight.to(device=device, dtype=dtype))
        return layer

    def _intermediate_size(self):
        """Return the intermediate width: original_shape[0], else inferred from a full-size layer."""
        if self.original_shape is not None:
            return self.original_shape[0]
        for name in ('gate_high_linear', 'up_high_linear'):
            layer = getattr(self, name, None)
            if layer is not None:
                return layer.out_features
        down_high_linear = getattr(self, 'down_high_linear', None)
        if down_high_linear is not None:
            return down_high_linear.in_features
        raise ValueError(
            "OAGEDeepseekMLP cannot determine intermediate_size: provide 'original_shape' "
            "or a full-size gate_high/up_high/down_high factor")

    def forward(self, hidden_states, base_down_proj=None):
        """
        Compute the high/medium gate and up groups, apply act_fn(gate) * up, and
        project back to hidden_dim. If a base down projection is given, its output
        on the same intermediate state is added to the result.

        Args:
            hidden_states: input unpacked as (batch_size, seq_len, hidden_dim).
            base_down_proj: optional base down projection, either an nn.Linear or a
                weight of shape (hidden_dim, intermediate_size) used via F.linear.

        Returns:
            Tensor of shape (batch_size, seq_len, hidden_dim) in hidden_states' dtype.
        """
        batch_size, seq_len, hidden_dim = hidden_states.shape
        intermediate_size = self._intermediate_size()
        device, dtype = hidden_states.device, hidden_states.dtype

        gate_output = torch.zeros(batch_size, seq_len, intermediate_size,
                                  dtype=dtype, device=device)
        up_output = torch.zeros(batch_size, seq_len, intermediate_size,
                                dtype=dtype, device=device)

        # High group: full-size matrices contribute over the whole intermediate width.
        if self.has_high_group:
            gate_high_linear = getattr(self, 'gate_high_linear', None)
            if gate_high_linear is not None:
                gate_output += gate_high_linear(hidden_states)
            up_high_linear = getattr(self, 'up_high_linear', None)
            if up_high_linear is not None:
                up_output += up_high_linear(hidden_states)

        # Medium group: the low-rank reconstruction overwrites the medium rows.
        if self.has_medium_group and self.medium_indices is not None:
            rows = self.medium_indices.to(device)
            gate_output[:, :, rows] = self.gate_medium_recovery(
                self.gate_medium_compressed(hidden_states))
            up_output[:, :, rows] = self.up_medium_recovery(
                self.up_medium_compressed(hidden_states))

        intermediate_state = self.act_fn(gate_output) * up_output  # [B, S, intermediate_size]

        down_out = torch.zeros(batch_size, seq_len, hidden_dim,
                               dtype=dtype, device=device)

        # High columns: applied to the full intermediate state.
        if self.has_down_high and getattr(self, 'down_high_linear', None) is not None:
            down_out += self.down_high_linear(intermediate_state)

        # Medium columns: gather, then recovery (num_med -> rank) and compressed (rank -> hidden_dim).
        if self.has_down_medium:
            cols = self.down_medium_indices.to(device)
            down_out += self.down_medium_compressed(
                self.down_medium_recovery(intermediate_state[:, :, cols]))

        # Base down projection shared by the layer, if provided.
        if base_down_proj is not None:
            if isinstance(base_down_proj, nn.Linear):
                down_out += base_down_proj(intermediate_state)
            else:
                down_out += torch.nn.functional.linear(intermediate_state, base_down_proj)

        return down_out


class OAGEDeepseekMoE(nn.Module):
    """
    Runtime MoE block that routes tokens to compressed (OAGE) experts.

    Gating and the shared experts are reused from the original MoE module. The
    optional layer-level base down matrix is not merged here. It is passed to each
    compressed expert's forward as ``base_down_proj``, and the expert combines it
    with its own down output.
    """

    def __init__(self, original_moe, experts_factors, layer_base_down=None):
        super().__init__()
        self.gate = original_moe.gate
        self.shared_experts = original_moe.shared_experts
        self.num_experts_per_tok = original_moe.num_experts_per_tok
        self.act_fn = original_moe.experts[0].act_fn

        # Layer-shared base down matrix, kept as a plain (non-trainable) tensor.
        # Best effort: align it with the original experts' down_proj device/dtype,
        # and keep the tensor as given if that fails.
        base_down = layer_base_down
        if base_down is not None:
            try:
                reference = original_moe.experts[0].down_proj.weight
                base_down = base_down.to(device=reference.device, dtype=reference.dtype)
            except Exception:
                base_down = layer_base_down
        self.layer_base_down = base_down

        # has_compressed_expert maps expert_idx -> position in compressed_experts,
        # or -1 when compression was not applied to that expert.
        self.compressed_experts = nn.ModuleList()
        self.has_compressed_expert = {}
        for expert_idx, factors in experts_factors.items():
            if not factors.get('compression_applied', False):
                self.has_compressed_expert[expert_idx] = -1
                continue
            expert = OAGEDeepseekMLP(factors, self.act_fn)
            self.has_compressed_expert[expert_idx] = len(self.compressed_experts)
            self.compressed_experts.append(expert)

    def forward(self, hidden_states):
        residual_input = hidden_states
        input_shape = hidden_states.shape
        topk_idx, topk_weight, _aux_loss = self.gate(hidden_states)
        routed = self.moe_infer(
            hidden_states.view(-1, hidden_states.shape[-1]),
            topk_idx.view(-1),
            topk_weight.view(-1, 1),
        ).view(*input_shape)
        return routed + self.shared_experts(residual_input)

    @torch.no_grad()
    def moe_infer(self, x, flat_expert_indices, flat_expert_weights):
        """
        Dispatch flattened (token, top-k slot) assignments to their experts and
        sum the weighted expert outputs back per token.
        """
        expert_cache = torch.zeros_like(x)
        slot_order = flat_expert_indices.argsort()
        # Cumulative slot counts per expert: expert e owns slot_order[bounds[e-1]:bounds[e]].
        bounds = flat_expert_indices.bincount().cpu().numpy().cumsum(0)
        slot_to_token = slot_order // self.num_experts_per_tok
        width = x.shape[-1]

        prev_end = 0
        for expert_idx, end in enumerate(bounds):
            begin, prev_end = prev_end, end
            if begin == end:
                continue

            token_ids = slot_to_token[begin:end]
            expert_tokens = x[token_ids]

            if expert_idx not in self.has_compressed_expert:
                raise NotImplementedError("所有专家都应该在压缩配置中")
            position = self.has_compressed_expert[expert_idx]
            if position < 0:
                raise NotImplementedError("原始专家应该从压缩配置中移除")

            expert_out = self.compressed_experts[position](
                expert_tokens.unsqueeze(1), base_down_proj=self.layer_base_down
            ).squeeze(1)

            expert_out.mul_(flat_expert_weights[slot_order[begin:end]])
            expert_cache.scatter_reduce_(
                0, token_ids.view(-1, 1).repeat(1, width), expert_out, reduce='sum'
            )

        return expert_cache
