import os
from collections import OrderedDict
from copy import deepcopy
from typing import Dict, Optional, List, Union, Callable, Iterator, Tuple

import torch
from torch import nn
import re
import math
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
from torch.nn import functional as F
from torch.utils.data import DataLoader
from tqdm import tqdm
from transformers import (
    Qwen3MoeConfig,
    PretrainedConfig,
    Qwen3MoeForCausalLM
)

from transformers.models.qwen3_moe.modeling_qwen3_moe import Qwen3MoeMLP, Qwen3MoeSparseMoeBlock

from .permutation import (
    permute_OLMoE_mlp_dense_expert_,
    compute_OLMoE_permutation_by_weight_matching,
    compute_OLMoE_permutation_by_activation_matching,
    merge_olmoe_mlp_by_activation_matching_within_and_across_models,
)

from .utils import generate_random_group_labels
from ..utils.constants import FP32_EPS
from .arcee_fusion import ArceeFusionMerge

# Public API of this module. The two ``Qwen_merge_*`` functions are presumably defined
# later in this module.
__all__ = [
    "ExpertsGrouperForQwen",
    "LEGAL_SIMILARITY_BASES",
    "SIMILARITY_MAPPING_FUNCTION",
    "Qwen_merge_by_groups_with_usage_frequency_weighting",
    "Qwen_merge_by_SVD",
]

# Maps a ``similarity_fn`` name to a callable ``(x, y) -> 0-d tensor`` comparing two
# (typically flattened, 1-D) tensors.
#   "cosine": cosine similarity rescaled from [-1, 1] to [0, 1].
#   "mse":    1 / (1 + 0.1 * log(1 + sum of squared errors)). Equals 1 for identical inputs and
#             decreases monotonically as the error grows. ``log1p`` is used instead of ``log`` so
#             that an error of 0 (or below 1) cannot produce -0.0, a zero denominator or a
#             negative/exploding "similarity".
#   "l2":     the raw sum of squared errors -- a DISTANCE, not a similarity (larger = less similar).
SIMILARITY_MAPPING_FUNCTION = {
    "cosine": lambda x, y: (F.cosine_similarity(x, y, dim=-1, eps=FP32_EPS) + 1) / 2,
    "mse": lambda x, y: 1 / (1 + 0.1 * torch.log1p(F.mse_loss(x, y, reduction="sum"))),
    "l2": lambda x, y: F.mse_loss(x, y, reduction="sum"),
}
# Accepted values of ``similarity_base``. Only "weight", "gate-weight", "gate-up-weight",
# "router-weight", "router-logits" and "gate-act" are dispatched by
# ExpertsGrouperForQwen.compute_all_similarities; the others raise NotImplementedError there.
LEGAL_SIMILARITY_BASES = ["weight", "feature", "feature.abs", "weight-feature", "gradient", "weight-gradient",
                          "router-logits", "router-weight", "router-weight-feature", "mse", "random",
                          "feature-correlation.lsa", "feature-correlation.max", "gate-weight", "gate-up-weight", "gate-act"]

class ExpertsGrouperForQwen(object):
    def __init__(
            self,
            config: Union[Qwen3MoeConfig, PretrainedConfig],
            similarity_fn: str = "cosine",
            similarity_base: str = "weight",
    ):
        if similarity_fn not in SIMILARITY_MAPPING_FUNCTION:
            raise ValueError(
                f"[Merging]similarity_fn should be one of {SIMILARITY_MAPPING_FUNCTION.keys()}, got {similarity_fn} instead.")
        if similarity_base not in LEGAL_SIMILARITY_BASES:
            raise ValueError(
                f"[Merging]similarity_base should be one of {LEGAL_SIMILARITY_BASES}, got {similarity_base} instead.")

        self.num_experts = config.num_experts
        self.num_experts_per_tok = config.num_experts_per_tok
        self.avg_num_merged_experts = self.num_experts
        self.hidden_size = config.hidden_size
        self.sparse_layer_indices = list(range(0, config.num_hidden_layers))
        self.similarity_fn = SIMILARITY_MAPPING_FUNCTION[similarity_fn]
        self.similarity_base = similarity_base
        self._group_state_dict = None
        self._similarity_state_dict = None
        self._usage_frequency_state_dict = None
        self.cross_layer_activation_stats = None # 跨层协同
        self.reset_all()

        # SVD
        self.composed_matrixes = dict()

    def reset_all(self):
        if self.similarity_base == "mse":
            self.similarity_fn = SIMILARITY_MAPPING_FUNCTION["mse"]
            print("[Merging]Set similarity_fn to mse for mse similarity_base.")
        self._group_state_dict = dict()
        self._similarity_state_dict = dict()
        self._usage_frequency_state_dict = dict()
        # Similarity range: [0, 2]
        for layer_idx in self.sparse_layer_indices:
            mlp_name = f"model.layers.{layer_idx}.mlp"
            self._group_state_dict[mlp_name] = torch.arange(self.num_experts,
                                                                    device="cuda")
            self._similarity_state_dict[mlp_name] = torch.zeros(
                (self.num_experts, self.num_experts), device="cuda"
            ) + torch.eye(self.num_experts, device="cuda")
            self._usage_frequency_state_dict[mlp_name] = torch.zeros(self.num_experts, device="cuda")
            self._usage_frequency_state_dict[mlp_name] = torch.zeros(self.num_experts, device="cuda")

        self.transport_matrices = dict()

    def similarity_state_dict(self) -> Dict[str, torch.Tensor]:
        return deepcopy(self._similarity_state_dict)

    def group_state_dict(self) -> Dict[str, torch.LongTensor]:
        return deepcopy(self._group_state_dict)

    def usage_frequency_state_dict(self) -> Dict[str, torch.Tensor]:
        return deepcopy(self._usage_frequency_state_dict)

    def save_similarity(self, mlp_name: str, i: int, j: int, similarity: float):
        self._similarity_state_dict[mlp_name][i, j] = similarity
        self._similarity_state_dict[mlp_name][j, i] = similarity

    def get_similarity(self, mlp_name: str, i: int, j: int) -> float:
        return self._similarity_state_dict[mlp_name][i, j].item()

    def get_similarity_matrix(self, mlp_name: str) -> torch.Tensor:
        return deepcopy(self._similarity_state_dict[mlp_name])

    def get_transport_matrix(self, mlp_name: str) -> torch.Tensor:
        return deepcopy(self.transport_matrixes[mlp_name])

    def get_composed_matrixes(self, mlp_name: str) -> List[torch.Tensor]:
        return self.composed_matrixes[mlp_name]

    def del_composed_matrixes(self, mlp_name: str):
        del self.composed_matrixes[mlp_name]

    def save_group_state_dict(self, save_dir: str):
        if not os.path.exists(save_dir):
            os.makedirs(save_dir)
        torch.save(self._group_state_dict, os.path.join(save_dir, "group_state_dict.pt"))

    def load_group_state_dict(self, load_dir: str):
        self._group_state_dict = torch.load(os.path.join(load_dir, "group_state_dict.pt"))

    def set_avg_num_merged_experts(self, avg_num_experts: int):
        self.avg_num_merged_experts = avg_num_experts

    def _assign_num_groups_per_layer(
            self,
            average_num_groups: int,
            merging_layers: List[int],
    ) -> Dict[str, int]:
        
        num_groups_per_layer = dict()
        for i, layer_idx in enumerate(self.sparse_layer_indices):
            if layer_idx not in merging_layers:
                num_groups_per_layer[f"model.layers.{layer_idx}.mlp"] = self.num_experts
            else:
                num_groups_per_layer[f"model.layers.{layer_idx}.mlp"] = average_num_groups

        return num_groups_per_layer

    def group_experts_into_clusters_by_routing_guided_globally(
            self,
            average_num_groups: int,
            merging_layers: List[int],
            layer_group_capacity: Optional[int] = None,
    ) -> Dict[str, List[int]]:
        """
        Group each merged layer's experts into clusters by routing-guided clustering: the most
        frequently used experts become cluster cores and every other expert joins its most
        similar core.

        The number of clusters per layer comes from ``_assign_num_groups_per_layer``. Layers in
        ``merging_layers`` get ``average_num_groups`` clusters; layers not in ``merging_layers``
        are skipped by the loop below. For each processed layer:
          1. the ``num_groups`` most-used experts (per ``self._usage_frequency_state_dict``)
             become cores with labels ``0..num_groups-1``;
          2. every remaining expert, visited in descending usage order, takes the label of the
             most similar core still in the candidate list (per ``get_similarity_matrix``);
          3. when a group's member count exceeds ``layer_group_capacity``, that group's core is
             removed from the candidate list.
        Labels are written in place into ``self._group_state_dict``; per-layer member counts
        are printed.

        Parameters
        ----------
        average_num_groups: int
            Number of clusters for every layer in ``merging_layers``.
        merging_layers: List[int]
            The layers to group/merge; all other sparse layers are skipped.
            NOTE(review): the loop tolerates ``None``, but the value is also passed to
            ``_assign_num_groups_per_layer`` -- confirm that helper accepts ``None``.
        layer_group_capacity: Optional[int]
            Maximum number of experts per group. If None, ``num_experts`` is used, which never
            triggers the capacity branch.

        Returns
        -------
        core_experts: Dict[str, List[int]]
            For each processed layer, the initial core experts of its clusters, recorded before
            any core is removed for exceeding capacity.

        Raises
        ------
        ValueError
            If a group exceeds capacity while its core is the only remaining candidate.
        """
        # Default capacity: num_experts, i.e. effectively no limit.
        # 1. Assign num_groups respectively for each layer according to average_num_groups
        # Compute the number of groups for each layer.
        layer_group_capacity = layer_group_capacity if layer_group_capacity is not None else self.num_experts
        num_groups_per_layer = self._assign_num_groups_per_layer(
            average_num_groups, merging_layers, 
        )
        # print(f"[Merging]Number of groups of each layer: {num_groups_per_layer}")
        # print("[Merging] Number of groups of each layer 每一层的组数:")
        # for layer_name, num_groups in num_groups_per_layer.items():
        #     print(f"  {layer_name}: {num_groups}")
        # 2. Group experts into clusters for each layer
        core_experts = dict()
        for layer_idx in tqdm(self.sparse_layer_indices,
                              desc=f"Globally routing-guided clustering experts into average {average_num_groups} clusters"
        ):
            if merging_layers is not None and layer_idx not in merging_layers:
                continue
            mlp_name = f"model.layers.{layer_idx}.mlp"
            num_groups = num_groups_per_layer[mlp_name]
            group_member_count = torch.zeros(num_groups)

            indices_sorted_by_usage = torch.argsort(self._usage_frequency_state_dict[mlp_name], descending=True)
            # 1.1 Assign top-K most-used experts with label 0 to K-1 respectively
            # Pick the most frequently used experts as the cluster cores.
            core_expert_indices = indices_sorted_by_usage[:num_groups]
            core_experts[mlp_name] = core_expert_indices.tolist()
            #print(f"\n第 {layer_idx} 层 ({mlp_name}):")
            #print(f"  核心专家（簇中心）: {core_expert_indices.tolist()}")
            for i in range(num_groups):
                self._group_state_dict[mlp_name][core_expert_indices[i]] = i
                group_member_count[i] += 1

            # 1.2 Assign left unassigned experts to the cluster with the most similar core
            # Assign the remaining (non-core) experts.
            similarity_matrix = self.get_similarity_matrix(mlp_name)
            for i in range(num_groups, self.num_experts):
                # Find the most similar core
                expert_idx = indices_sorted_by_usage[i]
                most_similar_core = core_expert_indices[
                    torch.argmax(similarity_matrix[expert_idx, core_expert_indices])
                ]
                most_similar_group_label = self._group_state_dict[mlp_name][most_similar_core]
                self._group_state_dict[mlp_name][expert_idx] = most_similar_group_label
                group_member_count[most_similar_group_label] += 1
                #print(f"  专家 {expert_idx.item()} → 分配给 核心专家 {most_similar_core.item()} 的簇（簇编号 {most_similar_group_label}）")
                # NOTE(review): the check uses `>` after the increment, so a group ends up with
                # layer_group_capacity + 1 members before its core is removed -- confirm
                # whether `>=` was intended.
                if group_member_count[self._group_state_dict[mlp_name][expert_idx]] > layer_group_capacity:
                    # NOTE(review): the message says "Encoder layer" although this is a decoder MoE layer.
                    if len(core_expert_indices) == 1:
                        raise ValueError(
                            f"[Merging]The number of groups at Encoder layer {layer_idx} is too small!"
                        )
                    # Kick out the filled group as well as its core, by pop the core from core_experts
                    # (same argmax as above, i.e. the position of the core whose group just overflowed;
                    # labels already assigned to that group are kept).
                    core_index = torch.argmax(similarity_matrix[expert_idx, core_expert_indices])
                    core_expert_indices = torch.cat(
                        [core_expert_indices[:core_index], core_expert_indices[core_index + 1:]]
                    )
            print(f"[Merging] Layer {layer_idx}成员数: {group_member_count.tolist()}")

        return core_experts

    def group_experts_into_equal_clusters_by_routing_guided_globally(
        self,
        num_groups: int,
        merging_layers: List[int],
    ) -> Dict[str, List[int]]:
        """
        均匀分组：根据激活频率,相似度，将专家分成大小均衡的簇
        参数
        num_groups: int每一层要分的组数。
        merging_layers: List[int]需要合并的层。
        返回
        core_experts: Dict[str, List[int]]每组的核心专家
        """
        core_experts = dict()
        # 遍历每一稀疏层
        for layer_idx in tqdm(self.sparse_layer_indices,
                            desc=f"[Merging]均匀分组：每层分成 {num_groups} 个簇 (usage+similarity guided)"
        ):
            if merging_layers is not None and layer_idx not in merging_layers:
                continue
            mlp_name = f"model.layers.{layer_idx}.mlp"
            # 每组的容量
            experts_per_group = self.num_experts // num_groups
            remainder = self.num_experts % num_groups
            max_capacity = [experts_per_group + (1 if i < remainder else 0) for i in range(num_groups)]

            group_member_count = torch.zeros(num_groups)
            # 按使用频率排序
            indices_sorted_by_usage = torch.argsort(self._usage_frequency_state_dict[mlp_name], descending=True)

            # 1: 选核心专家
            core_expert_indices = indices_sorted_by_usage[:num_groups]
            core_experts[mlp_name] = core_expert_indices.tolist()
            for i, core in enumerate(core_expert_indices):
                self._group_state_dict[mlp_name][core.item()] = i
                group_member_count[i] += 1

            # 2: 分配剩余专家
            similarity_matrix = self.get_similarity_matrix(mlp_name)
            for i in range(num_groups, self.num_experts):
                expert_idx = indices_sorted_by_usage[i]
                # 计算该专家与所有核心专家的相似度
                sim_scores = similarity_matrix[expert_idx, core_expert_indices]
                sorted_core_indices = torch.argsort(sim_scores, descending=True)

                assigned = False
                for core_rank in sorted_core_indices:
                    core = core_expert_indices[core_rank]
                    group_label = self._group_state_dict[mlp_name][core.item()]
                    if group_member_count[group_label] < max_capacity[group_label]:
                        self._group_state_dict[mlp_name][expert_idx.item()] = group_label
                        group_member_count[group_label] += 1
                        assigned = True
                        break

                if not assigned:
                    raise RuntimeError(f"[Merging] Expert {expert_idx.item()} 无法分配到任何组")

            print(f"[Merging] Layer {layer_idx}成员数: {group_member_count.tolist()}")
            # print(f"\n[Merging] Layer {layer_idx} ({mlp_name}):")
            # for group_id, core in enumerate(core_expert_indices.tolist()):
            #     members = [
            #         expert for expert, label in enumerate(self._group_state_dict[mlp_name])
            #         if label == group_id
            #     ]
            #     print(f"  核心专家 {core}: {members}")
            
        return core_experts

    def compute_all_usages(
            self,
            model: Qwen3MoeForCausalLM,
            batch: Dict[str, torch.Tensor],
            mini_batch_size: Optional[int] = 128,
            merging_layers: Optional[List[int]] = None,
    ):
        #model.cuda()
        # 在计算前初始化 usage frequency，避免累加
        self._usage_frequency_state_dict = {
            f"model.layers.{layer_idx}.mlp": torch.zeros(self.num_experts, device="cuda")
            for layer_idx in self.sparse_layer_indices
        }
        model.eval()
        total_batch_size = batch["input_ids"].shape[0]
        if mini_batch_size > total_batch_size:
            mini_batch_size = total_batch_size
        num_batches = total_batch_size // mini_batch_size

        for i in tqdm(range(num_batches), desc="[Merging]Computing all usages 计算所有专家使用频率..."):
            with torch.no_grad():
                mini_batch = {k: v[i * mini_batch_size: (i + 1) * mini_batch_size] for k, v in batch.items()}
                mini_batch = {k: v.cuda() for k, v in mini_batch.items()}
                outputs = model(**mini_batch, output_router_logits=True)
                for layer_idx in self.sparse_layer_indices:
                    mlp_name = f"model.layers.{layer_idx}.mlp"
                    routing_weights = F.softmax(outputs.router_logits[layer_idx], dim=1, dtype=torch.float)
                    routing_weights, selected_experts = torch.topk(routing_weights, self.num_experts_per_tok, dim=-1)
                    for idx in selected_experts.reshape(-1):
                        self._usage_frequency_state_dict[mlp_name][idx] += 1
                    
        self._usage_frequency_state_dict = {
            k: v / torch.sum(v) for k, v in self._usage_frequency_state_dict.items()
        }

    def compute_all_usages_in_group(
            self,
            model: Qwen3MoeForCausalLM,
            batch: Dict[str, torch.Tensor],
            mini_batch_size: Optional[int] = 128,
            merging_layers: Optional[List[int]] = None,
    ):
        # 初始化全局专家使用频率
        self._usage_frequency_state_dict = {
            f"model.layers.{layer_idx}.mlp": torch.zeros(self.num_experts, device="cuda")
            for layer_idx in self.sparse_layer_indices
        }
        # 初始化分组专家使用频率
        self._group_usage_frequency_state_dict = {
            f"model.layers.{layer_idx}.mlp": {}
            for layer_idx in self.sparse_layer_indices
        }

        model.eval()
        total_batch_size = batch["input_ids"].shape[0]
        if mini_batch_size > total_batch_size:
            mini_batch_size = total_batch_size
        num_batches = total_batch_size // mini_batch_size

        for i in tqdm(range(num_batches), desc="[Merging]Computing all usages 计算所有专家使用频率..."):
            with torch.no_grad():
                mini_batch = {k: v[i * mini_batch_size: (i + 1) * mini_batch_size] for k, v in batch.items()}
                mini_batch = {k: v.cuda() for k, v in mini_batch.items()}
                outputs = model(**mini_batch, output_router_logits=True)

                for layer_idx in self.sparse_layer_indices:
                    mlp_name = f"model.layers.{layer_idx}.mlp"
                    routing_weights = F.softmax(outputs.router_logits[layer_idx], dim=1, dtype=torch.float)
                    routing_weights, selected_experts = torch.topk(routing_weights, self.num_experts_per_tok, dim=-1)
                    
                    # 1. 更新全局专家使用频率
                    for idx in selected_experts.reshape(-1):
                        self._usage_frequency_state_dict[mlp_name][idx] += 1

                    # 2. 更新分组专家使用频率
                    group_labels = self._group_state_dict[mlp_name]  # 获取当前层的专家分组标签
                    for expert_idx in selected_experts.reshape(-1):
                        group_label = group_labels[expert_idx].item()  # 当前专家所属的组号
                        if group_label not in self._group_usage_frequency_state_dict[mlp_name]:
                            self._group_usage_frequency_state_dict[mlp_name][group_label] = 0
                        self._group_usage_frequency_state_dict[mlp_name][group_label] += 1

        # 归一化全局专家使用频率
        self._usage_frequency_state_dict = {
            k: v / torch.sum(v) for k, v in self._usage_frequency_state_dict.items()
        }
        # 归一化分组专家使用频率
        for mlp_name in self._group_usage_frequency_state_dict:
            total_group_usage = sum(self._group_usage_frequency_state_dict[mlp_name].values())
            if total_group_usage > 0:
                for group_label in self._group_usage_frequency_state_dict[mlp_name]:
                    self._group_usage_frequency_state_dict[mlp_name][group_label] /= total_group_usage

        return self._group_usage_frequency_state_dict  # 返回分组统计结果

    def print_usage_frequencies(self):
        '''打印专家使用频率统计'''
        print("[Merging]专家使用频率统计：")
        for layer_name, usage_freq in self._usage_frequency_state_dict.items():
            freqs_str = ", ".join([f"专{idx}:{freq.item():.6f}" for idx, freq in enumerate(usage_freq)])
            print(f"{layer_name}: {freqs_str}")
            print()  # 每层之间空一行

    def plot_usage_frequencies(self, save_path="/home/panjiaming/MOE/new_moe/MoEMerge/image/usage_frequencies.png"):
        '''绘制专家使用频率分布图'''
        plt.clf()  # 清空之前的图形
        usage_frequency_state_dict = self._usage_frequency_state_dict
        num_layers = len(usage_frequency_state_dict)
        fig, axes = plt.subplots(num_layers, 1, figsize=(12, num_layers * 1.2), sharex=True)

        if num_layers == 1:
            axes = [axes]

        max_experts = max(len(v) for v in usage_frequency_state_dict.values())

        for ax, (layer_name, usage_freq) in zip(axes, usage_frequency_state_dict.items()):
            if hasattr(usage_freq, 'cpu'):
                usage_freq = usage_freq.cpu().numpy()
            x = np.arange(len(usage_freq))
            ax.bar(x, usage_freq, width=0.8, color='skyblue')
            ax.set_ylim(0, max(usage_freq)*1.1)
            # 保留层索引部分作为y轴label
            layer_num = layer_name.split('.')[-2]
            ax.set_ylabel(f"Layer {layer_num}", fontsize=8)
            ax.tick_params(axis='y', labelsize=6)
            ax.grid(True, axis='y', linestyle='--', alpha=0.5)
            ax.set_xlim(-0.5, max_experts-0.5)

        axes[-1].set_xlabel("Expert Index", fontsize=10)
        plt.suptitle("Expert Usage Frequency Distribution", fontsize=14)
        plt.tight_layout(rect=[0, 0, 1, 0.96])
        plt.savefig(save_path)
        print(f"Image saved to {save_path}")

    def print_usage_frequencies_in_group(self):
        """打印分组专家使用频率统计"""
        print("[Merging]分组专家使用频率统计：")
        for layer_name, group_usage_freq in self._group_usage_frequency_state_dict.items():
            layer_idx = layer_name.split(".")[2]
            print(f"Layer {layer_idx} - Group Usage Frequencies:")
            for group_label, freq in sorted(group_usage_freq.items()):
                print(f"  Group {group_label}: {freq:.4f}")
    
    def plot_usage_frequencies_in_group(self, save_path="/home/panjiaming/MOE/new_moe/MoEMerge/image/group_usage_frequencies_in_group.png"):
        """绘制分组专家使用频率分布图（按组统计）"""
        plt.clf()  # 清空之前的图形
        group_usage_dict = self._group_usage_frequency_state_dict  # 使用分组频数字典
        num_layers = len(group_usage_dict)
        
        # 创建子图
        fig, axes = plt.subplots(num_layers, 1, figsize=(12, num_layers * 1.2), sharex=True)
        if num_layers == 1:
            axes = [axes]

        for ax, (layer_name, group_usage) in zip(axes, group_usage_dict.items()):
            layer_idx = layer_name.split(".")[2]  # 提取层号
            groups = sorted(group_usage.keys())   # 按组号排序
            freqs = [group_usage[g] for g in groups]  # 获取对应频数

            # 绘制柱状图（红色表示分组频数）
            ax.bar(groups, freqs, width=0.8, color='red')
            ax.set_ylim(0, max(freqs) * 1.1)
            ax.set_ylabel(f"Layer {layer_idx}", fontsize=8)
            ax.tick_params(axis='y', labelsize=6)
            ax.grid(True, axis='y', linestyle='--', alpha=0.5)
            ax.set_xlim(min(groups) - 0.5, max(groups) + 0.5)  # 横轴范围按组号动态调整

        # 全局标签和标题
        axes[-1].set_xlabel("Group Index", fontsize=10)
        plt.suptitle("Group Usage Frequency Distribution (After Merging)", fontsize=14)
        plt.tight_layout(rect=[0, 0, 1, 0.96])
        plt.savefig(save_path)
        print(f"Group usage frequency plot saved to {save_path}")

    def reverse_all_similarities(self):
        print("[Merging]Reversing all similarities...")
        for key in self._similarity_state_dict.keys():
            self._similarity_state_dict[key] = 1 - self._similarity_state_dict[key]

    def compute_layer_act(
            self,
            model: Qwen3MoeForCausalLM,
            merging_layer_idx: int,
            batch: Dict[str, torch.Tensor]
    ):
        #model = model.cuda()
        model = model.eval()
        batch = {k: v.cuda() for k, v in batch.items()}
        self.up_acts = {}
        self.gate_acts = {}

        mlp_name = f"model.layers.{merging_layer_idx}.mlp"
        handle = model.model.layers[merging_layer_idx].mlp.register_forward_hook(
            self._get_mlp_activation(mlp_name)
        )

        with torch.no_grad():
            model(**batch)

        self.composed_matrixes[mlp_name] = [self.gate_acts[mlp_name], self.up_acts[mlp_name]]
        handle.remove()

        del self.up_acts
        del self.gate_acts

    def compute_all_similarities(
            self,
            model: Qwen3MoeForCausalLM,
            batch: Dict[str, torch.Tensor] = None,
            merging_layers: Optional[List[int]] = None,
    ):
        if self.similarity_base not in ["weight", "router-weight"] and batch is None:
            raise ValueError(
                "[Merging]batch should be provided when similarity_base is not 'weight' or 'router-weight'")

        #model = model.cuda()
        model = model.eval()
        if self.similarity_base == "weight":
            self._compute_all_similarities_by_weight(model.state_dict())
        elif self.similarity_base == 'gate-weight':
            self._compute_all_similarities_by_gate_weight(model.state_dict())
        elif self.similarity_base == 'gate-up-weight':
            self._compute_all_similarities_by_gate_up_weight(model.state_dict())
        elif self.similarity_base == 'router-weight':
            self._compute_all_similarities_by_router_weight(model.state_dict())
        elif self.similarity_base == 'router-logits':
            batch = {k: v.cuda() for k, v in batch.items()}
            self._compute_all_similarities_by_router_logits(model, batch)
        elif self.similarity_base == 'gate-act':
            batch = {k: v.cuda() for k, v in batch.items()}
            self._compute_all_similarities_by_gate_act(model, batch)

        else:
            raise NotImplementedError

    def print_similarity_matrix(self, mlp_name: str):
            """打印指定层的相似度矩阵"""
            sim_matrix = self._similarity_state_dict[mlp_name].cpu().numpy()
            print(f"Similarity matrix for {mlp_name}:")
            print(sim_matrix)

    def print_all_similarity_matrices(self):
        """打印所有层的相似度矩阵"""
        for mlp_name in self._similarity_state_dict.keys():
            self.print_similarity_matrix(mlp_name)

    def plot_similarity_heatmap(self, mlp_name: str, save_dir: str = "/home/panjiaming/MOE/new_moe/MoEMerge/image"):
        """绘制指定层的相似度热力图，并保存到指定目录"""
        os.makedirs(save_dir, exist_ok=True)
        sim_matrix = self._similarity_state_dict[mlp_name].cpu().numpy()

        plt.figure(figsize=(8, 6))
        sns.heatmap(sim_matrix, cmap="viridis", square=True, annot=False)
        plt.title(f"Similarity Heatmap: {mlp_name}")
        plt.xlabel("Expert Index")
        plt.ylabel("Expert Index")
        plt.tight_layout()

        match = re.search(r"layers\.(\d+)\.", mlp_name)
        layer_num = match.group(1) if match else "unknown"

        save_path = os.path.join(save_dir, f"similarity_layer{layer_num}.png")
        plt.savefig(save_path)
        plt.close()

    def plot_all_similarity_heatmaps(self, save_dir: str = "/home/panjiaming/MOE/new_moe/MoEMerge/image"):
        """绘制所有层的相似度热力图，并保存到指定目录"""
        for mlp_name in self._similarity_state_dict.keys():
            self.plot_similarity_heatmap(mlp_name, save_dir)
        print(f"Heatmap saved to image")
    
    def _compute_all_similarities_by_weight(self, state_dict: Dict[str, torch.Tensor]):
        for layer_idx in tqdm(self.sparse_layer_indices, desc="[Merging]Computing similarities by weight..."):
            mlp_name = f"model.layers.{layer_idx}.mlp"
            for i in range(self.num_experts):
                for j in range(i + 1, self.num_experts):
                    i_flat = torch.cat(
                        [state_dict[f"{mlp_name}.experts.{i}.up_proj.weight"].flatten(),
                         state_dict[f"{mlp_name}.experts.{i}.down_proj.weight"].flatten()],
                        dim=0
                    )
                    j_flat = torch.cat(
                        [state_dict[f"{mlp_name}.experts.{j}.up_proj.weight"].flatten(),
                         state_dict[f"{mlp_name}.experts.{j}.down_proj.weight"].flatten()],
                        dim=0
                    )
                    similarity = self.similarity_fn(i_flat, j_flat)
                    self.save_similarity(mlp_name, i, j, similarity)

    def _compute_all_similarities_by_gate_weight(self, state_dict: Dict[str, torch.Tensor]):
        for layer_idx in tqdm(self.sparse_layer_indices, desc="[Merging]Computing similarities by gate weight..."):
            mlp_name = f"model.layers.{layer_idx}.mlp"
            for i in range(self.num_experts):
                for j in range(i + 1, self.num_experts):
                    i_flat = state_dict[f"{mlp_name}.experts.{i}.gate_proj.weight"].flatten()
                    j_flat = state_dict[f"{mlp_name}.experts.{j}.gate_proj.weight"].flatten()
                    similarity = self.similarity_fn(i_flat, j_flat)
                    self.save_similarity(mlp_name, i, j, similarity)

    def _compute_all_similarities_by_gate_up_weight(self, state_dict: Dict[str, torch.Tensor]):
        for layer_idx in tqdm(self.sparse_layer_indices, desc="[Merging]Computing similarities by gate-up weight..."):
            mlp_name = f"model.layers.{layer_idx}.mlp"
            for i in range(self.num_experts):
                for j in range(i + 1, self.num_experts):
                    i_flat = torch.cat(
                        [state_dict[f"{mlp_name}.experts.{i}.gate_proj.weight"].flatten(),
                         state_dict[f"{mlp_name}.experts.{i}.up_proj.weight"].flatten()],
                        dim=0
                    )
                    j_flat = torch.cat(
                        [state_dict[f"{mlp_name}.experts.{j}.gate_proj.weight"].flatten(),
                         state_dict[f"{mlp_name}.experts.{j}.up_proj.weight"].flatten()],
                        dim=0
                    )
                    similarity = self.similarity_fn(i_flat, j_flat)
                    self.save_similarity(mlp_name, i, j, similarity)

    def _compute_all_similarities_by_router_weight(
            self, state_dict: Dict[str, torch.Tensor]
    ):
        for layer_idx in tqdm(self.sparse_layer_indices, desc="[Merging]Computing similarities by router rows..."):
            mlp_name = f"model.layers.{layer_idx}.mlp"
            for i in range(self.num_experts):
                for j in range(i + 1, self.num_experts):
                    i_flat = state_dict[f"{mlp_name}.gate.weight"][i]
                    j_flat = state_dict[f"{mlp_name}.gate.weight"][j]
                    similarity = self.similarity_fn(i_flat, j_flat)
                    self.save_similarity(mlp_name, i, j, similarity)

    def _compute_all_similarities_by_router_logits(
            self, model: Qwen3MoeForCausalLM, batch: Dict[str, torch.Tensor]
    ):
        with torch.no_grad():
            outputs = model(**batch, output_router_logits=True)
        for layer_idx in tqdm(self.sparse_layer_indices, desc="[Merging]Computing similarities by router logits..."):
            mlp_name = f"model.layers.{layer_idx}.mlp"
            router_logits = outputs.router_logits[layer_idx].reshape(-1, self.num_experts)
            with torch.no_grad():
                for i in range(self.num_experts):
                    for j in range(i + 1, self.num_experts):
                        i_flat = router_logits[:, i].flatten()
                        j_flat = router_logits[:, j].flatten()
                        similarity = self.similarity_fn(i_flat, j_flat)
                        self.save_similarity(mlp_name, i, j, similarity)
    
    def _compute_all_similarities_by_gate_act(
            self, model: Qwen3MoeForCausalLM, batch: Dict[str, torch.Tensor]
    ):
        for layer_idx in tqdm(self.sparse_layer_indices, desc="[Merging]Computing similarities by gate logits..."):
            
            self.up_acts = {}
            self.gate_acts = {}
            mlp_name = f"model.layers.{layer_idx}.mlp"
            handle = model.model.layers[layer_idx].mlp.register_forward_hook(
                self._get_mlp_activation(mlp_name)
            )
            with torch.no_grad():
                model(**batch)

            gate_acts = self.gate_acts[mlp_name]
            handle.remove()

            del self.up_acts
            del self.gate_acts
            #print(gate_acts.shape)
            with torch.no_grad():
                for i in range(self.num_experts):
                    for j in range(i + 1, self.num_experts):
                        similarity = self.similarity_fn(gate_acts[i], gate_acts[j])
                        self.save_similarity(mlp_name, i, j, similarity)

    def _get_mlp_activation(self, name):
        """Build a forward hook that records every expert's up/gate projections.

        The hook runs each expert's ``up_proj`` and ``gate_proj`` on *all* tokens
        that reach the hooked MoE block, not only the tokens routed to that expert.
        It stores the stacked results in ``self.up_acts[name]`` and
        ``self.gate_acts[name]``, each shaped ``(num_experts, num_tokens, out_features)``.
        The gate output is stored *before* the activation function is applied.

        The caller must create ``self.up_acts`` and ``self.gate_acts`` as dicts
        before the hooked forward pass runs.

        Parameters
        ----------
        name: str
            Key under which the captured tensors are stored.

        Returns
        -------
        Callable
            A hook suitable for ``nn.Module.register_forward_hook``.
        """
        def hook(module, input, output):
            # Flatten all leading dims into tokens: (..., hidden) -> (num_tokens, hidden).
            # reshape (unlike view) also handles non-contiguous inputs.
            hidden_states = input[0].reshape(-1, input[0].shape[-1])
            # The captured tensors are only analysed, so skip building an autograd graph.
            with torch.no_grad():
                experts = [module.experts[expert_idx] for expert_idx in range(module.num_experts)]
                self.up_acts[name] = torch.stack([expert.up_proj(hidden_states) for expert in experts])
                self.gate_acts[name] = torch.stack([expert.gate_proj(hidden_states) for expert in experts])
        return hook

    def compute_cross_layer_coactivation(
        self,
        model: Qwen3MoeForCausalLM,
        batch: Dict[str, torch.Tensor],
        mini_batch_size: Optional[int] = 128,
        merging_layers: Optional[List[int]] = None,
    ):
        """
        Estimate cross-layer expert co-activation probabilities from each token's top-k routing.

        Consider each pair of layers ``(layer_i, layer_j)`` where ``layer_i`` comes
        before ``layer_j`` in ``merging_layers``. For every token, each (expert in
        ``layer_i``'s top-k, expert in ``layer_j``'s top-k) combination is counted.
        The counts are then row-normalised so every non-empty row sums to 1. A row
        estimates P(e_j selected in layer_j | e_i selected in layer_i).

        Output: ``self.cross_layer_activation_stats[(layer_i, layer_j)]`` is a
        ``num_experts x num_experts`` tensor.

        Parameters
        ----------
        model: Qwen3MoeForCausalLM
            Model to run; it is switched to eval mode.
        batch: Dict[str, torch.Tensor]
            Inputs sliced along dim 0 into mini-batches, each moved to CUDA.
        mini_batch_size: Optional[int]
            Mini-batch size. ``None``, or a value larger than the batch, uses the
            whole batch. A trailing partial mini-batch is included.
        merging_layers: Optional[List[int]]
            Layers to analyse; ``None`` uses ``self.sparse_layer_indices``.
        """
        if getattr(self, "cross_layer_activation_stats", None) is None:
            self.cross_layer_activation_stats = {}
        if merging_layers is None:
            merging_layers = self.sparse_layer_indices
        merging_layers = list(merging_layers)

        model.eval()
        total_batch_size = batch["input_ids"].shape[0]
        if mini_batch_size is None or mini_batch_size > total_batch_size:
            mini_batch_size = total_batch_size
        mini_batch_size = max(mini_batch_size, 1)
        # Ceil division so the trailing partial mini-batch is not dropped.
        num_batches = (total_batch_size + mini_batch_size - 1) // mini_batch_size

        # Ordered layer pairs (earlier, later) in merging_layers order.
        layer_pairs = [
            (layer_i, layer_j)
            for pos, layer_i in enumerate(merging_layers)
            for layer_j in merging_layers[pos + 1:]
        ]
        for pair in layer_pairs:
            self.cross_layer_activation_stats[pair] = torch.zeros(
                (self.num_experts, self.num_experts), device="cuda"
            )

        num_expert_pairs = self.num_experts * self.num_experts
        for batch_idx in tqdm(range(num_batches), desc="[Merging] Computing cross-layer co-activation 计算协同激活..."):
            start = batch_idx * mini_batch_size
            with torch.no_grad():
                mini_batch = {k: v[start:start + mini_batch_size].cuda() for k, v in batch.items()}
                outputs = model(**mini_batch, output_router_logits=True)

                # Per layer: selected expert ids flattened to (num_tokens, top_k).
                layer_expert_selection = {}
                for layer_idx in merging_layers:
                    routing_weights = F.softmax(outputs.router_logits[layer_idx], dim=-1)
                    _, selected_experts = torch.topk(routing_weights, self.num_experts_per_tok, dim=-1)
                    layer_expert_selection[layer_idx] = selected_experts.reshape(-1, self.num_experts_per_tok)

                for layer_i, layer_j in layer_pairs:
                    experts_i = layer_expert_selection[layer_i]
                    experts_j = layer_expert_selection[layer_j]
                    # Encode each token's (e_i, e_j) combination as e_i * E + e_j.
                    # Broadcasting gives shape (num_tokens, top_k, top_k).
                    combos = (experts_i.unsqueeze(2) * self.num_experts + experts_j.unsqueeze(1)).flatten()
                    # Count occurrences on-device and reshape back into an E x E matrix.
                    counts = torch.bincount(combos, minlength=num_expert_pairs)
                    self.cross_layer_activation_stats[(layer_i, layer_j)] += counts.view(
                        self.num_experts, self.num_experts
                    )

        # Row-normalise: every non-empty row becomes a distribution over layer_j's experts.
        for key, matrix in self.cross_layer_activation_stats.items():
            if matrix.sum() > 0:
                row_sums = matrix.sum(dim=1, keepdim=True)
                row_sums[row_sums == 0] = 1  # avoid division by zero for never-selected experts
                self.cross_layer_activation_stats[key] = matrix / row_sums

    def print_all_coactivation_matrices(self):
        """Print every cross-layer co-activation matrix with 3 decimals and no scientific notation.

        The NumPy print options apply only for the duration of this call
        (``np.printoptions`` context manager); NumPy's global settings are left unchanged.
        """
        with np.printoptions(precision=3, suppress=True):
            for (layer_i, layer_j), matrix in self.cross_layer_activation_stats.items():
                print(f"Cross-layer co-activation matrix for layers {layer_i}->{layer_j}:")
                print(matrix.cpu().numpy())

    def print_top_expert_pairs(self, layers=None, top_k=5):
        """Print the ``top_k`` highest-scoring expert pairs for each consecutive pair in ``layers``.

        Parameters
        ----------
        layers: Optional[Sequence[int]]
            Ordered layer indices. Each consecutive ``(layers[i], layers[i + 1])``
            must be a key of ``self.cross_layer_activation_stats``. ``None`` uses
            every layer appearing in those keys, sorted.
        top_k: int
            Number of pairs printed per layer pair, clamped to the number of
            matrix entries (``torch.topk`` raises if k exceeds it).
        """
        if layers is None:
            layers = sorted({layer for pair in self.cross_layer_activation_stats for layer in pair})

        print(f"Top {top_k} expert pairs between adjacent layers:")

        for layer_i, layer_j in zip(layers[:-1], layers[1:]):
            matrix = self.cross_layer_activation_stats[(layer_i, layer_j)]  # [num_experts, num_experts]
            flat = matrix.flatten()
            top_values, top_indices = torch.topk(flat, min(top_k, flat.numel()))

            print(f"Layer {layer_i}-{layer_j}:")
            for val, idx in zip(top_values.tolist(), top_indices.tolist()):
                # Flat index -> (row expert in layer_i, column expert in layer_j).
                e1, e2 = divmod(idx, self.num_experts)
                print(f"  Expert {e1} -> Expert {e2} | Score: {val:.6f}")

    def print_top_expert_chains(self, layers=None, top_k=5):
        """Beam-search and print the ``top_k`` highest-scoring expert chains across ``layers``.

        Every expert of the first layer starts a chain. At each consecutive layer
        pair, each chain is extended by the ``top_k`` strongest co-activations of its
        last expert. Chain scores are accumulated as sums of ``log(value + 1e-12)``,
        and only the best ``top_k`` chains are kept.

        Parameters
        ----------
        layers: Optional[Sequence[int]]
            Ordered layer indices. Each consecutive pair must be a key of
            ``self.cross_layer_activation_stats``. ``None`` uses every layer
            appearing in those keys, sorted.
        top_k: int
            Beam width and per-step branching factor. The branching factor is
            clamped to the number of experts.
        """
        if layers is None:
            layers = sorted({layer for pair in self.cross_layer_activation_stats for layer in pair})

        # Beam entries: (expert path, accumulated log score).
        paths = [([e], 0.0) for e in range(self.num_experts)]

        for layer_i, layer_j in zip(layers[:-1], layers[1:]):
            matrix = self.cross_layer_activation_stats[(layer_i, layer_j)]
            new_paths = []

            for path, score in paths:
                co_values = matrix[path[-1]]
                top_values, top_indices = torch.topk(co_values, min(top_k, co_values.numel()))
                for val, idx in zip(top_values.tolist(), top_indices.tolist()):
                    # The epsilon keeps log() finite when a co-activation value is 0.
                    new_paths.append((path + [idx], score + math.log(val + 1e-12)))

            # Keep only the top_k strongest chains.
            new_paths.sort(key=lambda item: item[1], reverse=True)
            paths = new_paths[:top_k]

        print(f"Top {top_k} expert chains across layers {layers}:")
        for path, log_score in paths:
            chain_str = " -> ".join([f"Expert {e}" for e in path])
            original_score = math.exp(log_score)
            print(f"  {chain_str} | LogScore: {log_score:.4f} | Score: {original_score:.4e}")

    def plot_all_coactivation_heatmaps(self, save_dir: str = "/home/panjiaming/MOE/new_moe/MoEMerge/image"):
        """Draw every cross-layer co-activation matrix as a heatmap and save it as a PNG.

        One file ``{layer_i}_to_{layer_j}.png`` is written per matrix into
        ``save_dir``, which is created if missing.

        Parameters
        ----------
        save_dir: str
            Output directory. NOTE(review): the default is a user-specific absolute
            path; callers on other machines should pass their own directory.
        """
        os.makedirs(save_dir, exist_ok=True)
        for (layer_i, layer_j), matrix in self.cross_layer_activation_stats.items():
            sim_matrix = matrix.cpu().numpy()

            fig = plt.figure(figsize=(8, 6))
            try:
                sns.heatmap(sim_matrix, cmap="viridis", square=True, annot=False)
                plt.title(f"Cross-layer Co-activation Heatmap: Layer {layer_i} -> Layer {layer_j}")
                plt.xlabel(f"Expert Index (Layer {layer_j})")
                plt.ylabel(f"Expert Index (Layer {layer_i})")
                plt.tight_layout()
                plt.savefig(os.path.join(save_dir, f"{layer_i}_to_{layer_j}.png"))
            finally:
                # Close even on failure so figures do not accumulate in memory.
                plt.close(fig)
        print(f"Co-activation heatmap saved to {save_dir}")

def _merge_mlp_experts_by_averaging(
        mlp: Qwen3MoeSparseMoeBlock,
        group_labels: torch.LongTensor,
        permute: bool,
        permute_strategy: str,
        forwarded_hidden_states: Optional[Tuple[torch.Tensor]] = None,
) -> Qwen3MoeSparseMoeBlock:
    """Merge each group of experts into one by plain (unweighted) weight averaging.

    Optionally, every non-first expert of a group is first permuted to align with
    the group's first expert. Then ``up_proj``, ``down_proj`` and ``gate_proj``
    weights are averaged and written into the first expert. All other group
    members in ``mlp.experts`` are rebound to that same module.

    Parameters
    ----------
    mlp: Qwen3MoeSparseMoeBlock
        MoE block whose ``experts`` are merged in place.
    group_labels: torch.LongTensor
        Group label per expert; experts sharing a label are merged together.
    permute: bool
        Whether to align experts by permutation before averaging.
    permute_strategy: str
        ``"weight-matching"`` or ``"activation-matching"``; only read when ``permute`` is True.
    forwarded_hidden_states: Optional[Tuple[torch.Tensor]]
        Per-expert input hidden states. Required for ``"activation-matching"``.

    Returns
    -------
    Qwen3MoeSparseMoeBlock
        The same ``mlp`` object, modified in place.

    Raises
    ------
    ValueError
        If ``permute_strategy`` is unknown, or activation-matching is requested
        without ``forwarded_hidden_states``.
    """
    if permute:
        if permute_strategy not in ("weight-matching", "activation-matching"):
            raise ValueError(f"Unknown permute strategy: {permute_strategy}")
        if permute_strategy == "activation-matching" and forwarded_hidden_states is None:
            raise ValueError("forwarded_hidden_states is required for activation-matching permutation")

    for label in group_labels.unique():
        expert_indices = torch.where(group_labels == label)[0]
        if len(expert_indices) < 2:
            # A single-expert group is already its own average.
            continue
        reference_idx = expert_indices[0]

        if permute:
            if permute_strategy == "weight-matching":
                for expert_idx in expert_indices[1:]:
                    perm = compute_OLMoE_permutation_by_weight_matching(
                        reference_mlp=mlp.experts[reference_idx],
                        target_mlp=mlp.experts[expert_idx],
                        include_wo=True
                    )
                    mlp.experts[expert_idx] = permute_OLMoE_mlp_dense_expert_(
                        mlp.experts[expert_idx], perm
                    )
            else:  # "activation-matching" (validated above)
                # Inputs seen by any expert of the group, used to match activations.
                group_forwarded_hidden_states = torch.cat([
                    forwarded_hidden_states[expert_idx] for expert_idx in expert_indices
                ], dim=0)
                for expert_idx in expert_indices[1:]:
                    perm = compute_OLMoE_permutation_by_activation_matching(
                        reference_mlp=mlp.experts[reference_idx],
                        target_mlp=mlp.experts[expert_idx],
                        forwarded_hidden_states=group_forwarded_hidden_states,
                    )
                    mlp.experts[expert_idx] = permute_OLMoE_mlp_dense_expert_(
                        mlp.experts[expert_idx], perm
                    )

        with torch.no_grad():
            reference = mlp.experts[reference_idx]
            # Each projection is averaged independently, so updating the reference
            # in place one projection at a time does not affect the others.
            for proj_name in ("up_proj", "down_proj", "gate_proj"):
                averaged = torch.mean(
                    torch.stack([getattr(mlp.experts[expert_idx], proj_name).weight
                                 for expert_idx in expert_indices]),
                    dim=0
                )
                getattr(reference, proj_name).weight.copy_(averaged)
            for expert_idx in expert_indices[1:]:
                # Bind merged experts to the first of them.
                mlp.experts[expert_idx] = reference
    return mlp

def _merge_mlp_experts_by_usage_frequency_weighting(
        mlp: Qwen3MoeSparseMoeBlock,
        group_labels: torch.LongTensor,
        usage_frequencies: torch.Tensor,
        permute: bool,
) -> Qwen3MoeSparseMoeBlock:
    """Merge each group of experts by usage-frequency-weighted weight averaging.

    For every group, ``up_proj``, ``down_proj`` and ``gate_proj`` become
    ``sum_i(f_i * W_i) / (sum_i f_i + FP32_EPS)``, written into the group's first
    expert. All other group members in ``mlp.experts`` are rebound to that same module.

    Parameters
    ----------
    mlp: Qwen3MoeSparseMoeBlock
        MoE block whose ``experts`` are merged in place.
    group_labels: torch.LongTensor
        Group label per expert.
    usage_frequencies: torch.Tensor
        Per-expert averaging weights; moved to the experts' device.
    permute: bool
        Permutation is not supported; must be False.

    Returns
    -------
    Qwen3MoeSparseMoeBlock
        The same ``mlp`` object, modified in place.

    Raises
    ------
    NotImplementedError
        If ``permute`` is True.
    """
    if permute:
        # Raise explicitly: `assert(False, msg)` asserts a non-empty tuple and never fails.
        raise NotImplementedError("Do not support permute")

    device = mlp.experts[0].down_proj.weight.device
    usage_frequencies = usage_frequencies.to(device)
    for label in group_labels.unique():
        expert_indices = torch.where(group_labels == label)[0].to(device)
        with torch.no_grad():
            # Shared normalizer; the epsilon avoids division by zero for all-zero usage.
            normalizer = torch.sum(usage_frequencies[expert_indices], dim=0) + FP32_EPS
            reference = mlp.experts[expert_indices[0]]
            for proj_name in ("up_proj", "down_proj", "gate_proj"):
                weighted = torch.stack(
                    [getattr(mlp.experts[expert_idx], proj_name).weight * usage_frequencies[expert_idx]
                     for expert_idx in expert_indices], dim=0
                )
                getattr(reference, proj_name).weight.copy_(torch.sum(weighted, dim=0) / normalizer)

            for expert_idx in expert_indices[1:]:
                # Bind merged experts to the first of them.
                mlp.experts[expert_idx] = reference
    return mlp

def _merge_mlp_experts_by_weighting_act(
        mlp: Qwen3MoeSparseMoeBlock,
        group_labels: torch.LongTensor,
        usage_frequencies: torch.Tensor,
        composed_matrixes: List[torch.Tensor]
) -> Qwen3MoeSparseMoeBlock:
    """Merge each expert group, refitting ``down_proj`` by least squares on recorded activations.

    For every group (experts sharing a label in ``group_labels``):

    * ``up_proj`` and ``gate_proj`` become the usage-frequency-weighted average of
      the group's weights.
    * For groups with more than one expert, ``down_proj`` is refit:
      1. The usage-weighted average of the recorded gate/up activations gives a
         merged intermediate activation ``act_fn(merged_gate) * merged_up``.
      2. ``torch.linalg.lstsq`` solves for a linear map from that merged
         activation to every member's original activation ``act_fn(gate_i) * up_i``.
      3. Each member's ``down_proj`` is composed with its block of that map, and
         the results are usage-weight-averaged.
    * Single-expert groups keep their own ``down_proj``.

    The merged weights are written into the group's first expert, and every other
    member slot of ``mlp.experts`` is rebound to that same module.

    Parameters
    ----------
    mlp: Qwen3MoeSparseMoeBlock
        MoE block whose ``experts`` are merged in place.
    group_labels: torch.LongTensor
        Group label per expert.
    usage_frequencies: torch.Tensor
        Per-expert coefficients used for every weighted average.
    composed_matrixes: List[torch.Tensor]
        ``[gate_acts, up_acts]``: pre-activation gate and up projection outputs,
        indexed by expert along dim 0. Presumably shaped
        ``(num_experts, num_samples, intermediate_size)``; TODO confirm with the producer.

    Returns
    -------
    Qwen3MoeSparseMoeBlock
        The same ``mlp`` object, modified in place.
    """
    device = mlp.experts[0].down_proj.weight.device
    gate_acts = composed_matrixes[0].to(device)
    up_acts = composed_matrixes[1].to(device)
    act_fn = mlp.experts[0].act_fn.to(device)
    usage_frequencies = usage_frequencies.to(device)
    # Intermediate activations each expert originally produced: act_fn(gate) * up.
    original_acts = act_fn(gate_acts) * up_acts
    
    for label in group_labels.unique():
        expert_indices = torch.where(group_labels == label)[0].to(device)
        # Normalizer for every usage-weighted average in this group (epsilon avoids division by zero).
        usage_freq_sum = torch.sum(usage_frequencies[expert_indices], dim=0) + FP32_EPS
        with torch.no_grad():
            if expert_indices.numel() > 1:
                # Usage-weighted average of the members' recorded gate / up activations.
                merged_gate_acts_list = torch.stack(
                    [gate_acts[expert_idx] * usage_frequencies[expert_idx] for expert_idx in expert_indices], dim=0
                )
                merged_up_acts_list = torch.stack(
                    [up_acts[expert_idx] * usage_frequencies[expert_idx] for expert_idx in expert_indices], dim=0
                )
                merged_gate_acts = torch.sum(merged_gate_acts_list, dim=0) / usage_freq_sum
                merged_up_acts = torch.sum(merged_up_acts_list, dim=0) / usage_freq_sum
                
                # Intermediate activation the merged expert would produce; 2-D (samples, intermediate).
                merged_acts = act_fn(merged_gate_acts) * merged_up_acts
                sample_num, intermediate_size = merged_acts.size()
                # (k, samples, inter) -> (samples, k * inter): members' original activations side by side.
                unmerged_acts = original_acts[expert_indices].permute(1,0,2).reshape(sample_num, -1)
                # Solve merged_acts @ X ~= unmerged_acts in float32; X is (inter, k * inter).
                # NOTE(review): on CUDA, lstsq only supports the 'gels' driver, which assumes
                # merged_acts has full rank; rank-deficient activations may give poor fits.
                solution, _, _, _ = torch.linalg.lstsq(merged_acts.to(torch.float), unmerged_acts.to(torch.float))

                # Transposed to (k * inter, inter): row block i maps merged activations to member i's.
                solution = solution.T.to(mlp.experts[0].down_proj.weight.dtype)
                down_proj_weight = torch.zeros_like(mlp.experts[0].down_proj.weight)
                # Fold each member's fitted map into its down_proj (W_i @ X_i^T), usage-weighted.
                for i, expert_idx in enumerate(expert_indices):
                    down_proj_weight += torch.matmul(mlp.experts[expert_idx].down_proj.weight * usage_frequencies[expert_idx], 
                                solution[i*intermediate_size:(i+1)*intermediate_size])
                down_proj_weight /= usage_freq_sum
            else:
                # group size == 1: keep the expert's own down_proj.
                assert(expert_indices.numel() == 1)
                down_proj_weight = mlp.experts[expert_indices[0]].down_proj.weight

            # up_proj / gate_proj: plain usage-weighted averages of the weights.
            up_proj_weight_list = torch.stack(
                [mlp.experts[expert_idx].up_proj.weight * usage_frequencies[expert_idx] for expert_idx in
                 expert_indices], dim=0
            )
            gate_proj_weight_list = torch.stack(
                [mlp.experts[expert_idx].gate_proj.weight * usage_frequencies[expert_idx] for expert_idx in
                 expert_indices], dim=0
            )
            up_proj_weight = torch.sum(up_proj_weight_list, dim=0) / usage_freq_sum
            gate_proj_weight = torch.sum(gate_proj_weight_list, dim=0) / usage_freq_sum
            
            mlp.experts[expert_indices[0]].up_proj.weight.copy_(up_proj_weight)
            mlp.experts[expert_indices[0]].gate_proj.weight.copy_(gate_proj_weight)
            mlp.experts[expert_indices[0]].down_proj.weight.copy_(down_proj_weight)
 
            for expert_idx in expert_indices[1:]:
                # Binding merged experts to the first of them
                mlp.experts[expert_idx] = mlp.experts[expert_indices[0]]
    return mlp

def Qwen_merge_by_groups_with_ACT(
        model: Qwen3MoeForCausalLM,
        grouper: ExpertsGrouperForQwen,
        merging_layers: Optional[List[int]],
        batch: Dict[str, torch.Tensor],
) -> Qwen3MoeForCausalLM:
    """
    Merge grouped experts layer by layer, refitting ``down_proj`` from captured activations.

    Sparse layers are visited from last to first. For each selected layer:
    1. ``grouper.compute_layer_act`` records that layer's activations on ``batch``.
    2. The layer's experts are merged with ``_merge_mlp_experts_by_weighting_act``,
       using the grouper's group labels, usage frequencies and recorded activation
       matrices.

    Args:
        model: Model whose MoE layers are merged in place.
        grouper: Supplies group labels, usage frequencies and activation capture.
        merging_layers: Layer indices to merge; ``None`` merges every sparse layer.
        batch: Inputs used to capture activations.

    Returns:
        The same model instance, with merged experts.
    """
    frequencies_by_mlp = grouper.usage_frequency_state_dict()
    layers_last_to_first = grouper.sparse_layer_indices[::-1]

    for layer_idx in tqdm(layers_last_to_first, desc="[Merging]Merging experts with act..."):
        if merging_layers is not None and layer_idx not in merging_layers:
            continue

        grouper.compute_layer_act(model=model, merging_layer_idx=layer_idx, batch=batch)

        mlp_name = f"model.layers.{layer_idx}.mlp"
        activation_matrices = grouper.get_composed_matrixes(mlp_name)
        labels = grouper.group_state_dict()[mlp_name]
        frequencies = frequencies_by_mlp[mlp_name]

        decoder_layer = model.model.layers[layer_idx]
        decoder_layer.mlp = _merge_mlp_experts_by_weighting_act(
            mlp=decoder_layer.mlp,
            group_labels=labels,
            usage_frequencies=frequencies,
            composed_matrixes=activation_matrices,
        )

    return model

'''
MIT License
Copyright (c) 2023 UNITES Lab
This function is modified from (https://github.com/UNITES-Lab/MC-SMoE)
'''
def Qwen_merge_by_groups_with_usage_frequency_weighting(
        model: Qwen3MoeForCausalLM,
        grouper: ExpertsGrouperForQwen,
        strategy: str = "normal",
        merging_layers: Optional[List[int]] = None,
        permute: Optional[bool] = False,
        within_and_across_models: Optional[bool] = False,
) -> Qwen3MoeForCausalLM:
    """
    Merge experts by usage-frequency-weighted averaging, strategies include:
        1. normal: merge experts in each group by usage-frequency-weighted averaging.
        2. reversed: reverse usage frequencies by 1 - usage_frequency and merge experts in each group by
                        usage-frequency-weighted averaging.
        3. random: randomly initialize usage frequencies and merge experts in each group by
                        usage-frequency-weighted averaging.

    Parameters
    ----------
    model: Qwen3MoeForCausalLM
        The model to merge experts.
    grouper: ExpertsGrouperForQwen
        The grouper to group experts. It is expected to have already run
        `grouper.compute_all_usages()` and one of `grouper.group_experts()`,
        i.e. it has group labels.
    strategy: str
        The strategy to merge experts, one of ["normal", "reversed", "random"]
    merging_layers: Optional[List[int]]
        The layers to merge experts, if None, merge all layers
    permute: Optional[bool]
        Whether to permute the experts in the same group, only available when `within_and_across_models` is False.
        NOTE: `_merge_mlp_experts_by_usage_frequency_weighting` does not implement permutation.
    within_and_across_models: Optional[bool]
        Whether to merge experts within and across models.
        NOTE(review): currently not used by this function.

    Returns
    -------
    Qwen3MoeForCausalLM
        The same model instance, with merged experts.

    Raises
    ------
    ValueError
        If `strategy` is unknown.
    """
    if permute:
        print("[Merging]Permutation is enabled, will permute experts in the same group.")
    usage_frequency_dict = grouper.usage_frequency_state_dict()
    # Build new dicts rather than writing into the returned mapping, so the
    # grouper's own usage statistics are never overwritten if it returns internal state.
    if strategy == "reversed":
        usage_frequency_dict = {key: 1 - value for key, value in usage_frequency_dict.items()}
    elif strategy == "random":
        usage_frequency_dict = {key: torch.rand_like(value) for key, value in usage_frequency_dict.items()}
    elif strategy != "normal":
        raise ValueError(f"[Merging]Unknown strategy {strategy}")

    group_labels_dict = grouper.group_state_dict()
    for layer_idx in tqdm(
            grouper.sparse_layer_indices,
            desc=f"[Merging]Merging experts with {strategy} usage-frequency-weighted averaging..."
    ):
        if merging_layers is None or layer_idx in merging_layers:
            mlp_name = f"model.layers.{layer_idx}.mlp"
            model.model.layers[layer_idx].mlp = _merge_mlp_experts_by_usage_frequency_weighting(
                mlp=model.model.layers[layer_idx].mlp,
                group_labels=group_labels_dict[mlp_name],
                usage_frequencies=usage_frequency_dict[mlp_name],
                permute=permute
            )

    return model

def _Qwen_merge_mlp_experts_within_and_across_models(
        mlp: Qwen3MoeSparseMoeBlock,
        group_labels: torch.LongTensor,
        forwarded_hidden_states: Tuple[torch.Tensor],
        dominant_alone: bool,
        core_expert_indices: Optional[List[int]] = None,
        usage_frequencies: Optional[torch.Tensor] = None,
) -> Qwen3MoeSparseMoeBlock:
    """
    Merge grouped experts within and across models.

    Parameters
    ----------
    mlp: Qwen3MoeSparseMoeBlock
        The MoE block whose `experts` are merged in place.
    group_labels: torch.LongTensor
        The group labels of experts.
    forwarded_hidden_states: Tuple[torch.Tensor]
        The forwarded hidden states of each expert, should be of length num_experts.
    dominant_alone: bool
        Whether to merge the dominant (core) experts alone.
        If True, the merging process in a group will be done in two steps:
            1. Merge all non-core experts of the group.
            2. Merge the core expert(s) with the merged expert from step 1.
        A group containing no core expert is merged all at once, as when this flag is False.
    core_expert_indices: Optional[List[int]]
        Indices of core (dominant) experts; required when `dominant_alone` is True.
    usage_frequencies: Optional[torch.Tensor]
        Per-expert usage frequencies passed as `average_coefs` to the merge helper.
        When None, `average_coefs=None` is passed instead.

    Returns
    -------
    mlp: Qwen3MoeSparseMoeBlock
        The merged MoE block (the same object, modified in place).

    Raises
    ------
    ValueError
        If `dominant_alone` is True but `core_expert_indices` is None.
    """
    if dominant_alone and core_expert_indices is None:
        raise ValueError("[Merging]dominant_alone is True, but core_expert_indices is None")

    for label in group_labels.unique():
        expert_indices = torch.where(group_labels == label)[0]
        with torch.no_grad():
            group_core = (
                [idx for idx in expert_indices if idx in core_expert_indices] if dominant_alone else []
            )
            if group_core:
                group_core_expert_indices = torch.stack(group_core)
                non_core_indices = [
                    expert_idx for expert_idx in expert_indices
                    if expert_idx not in group_core_expert_indices
                ]
                to_skip = False
                if not non_core_indices:
                    # Every expert of the group is core: keep the first one as-is.
                    merged_expert = mlp.experts[expert_indices[0]]
                    to_skip = True
                elif usage_frequencies is not None and len(group_core_expert_indices) == 1:
                    non_core_usage_sum = torch.sum(
                        usage_frequencies[[expert_idx.item() for expert_idx in non_core_indices]]).item()
                    if non_core_usage_sum == 0:
                        # Non-core experts were never used: the lone core expert wins outright.
                        merged_expert = mlp.experts[group_core_expert_indices[0]]
                        to_skip = True
                if not to_skip:
                    # Stage 1: merge all experts except the dominant one(s)
                    group_forwarded_hidden_states = torch.cat([
                        forwarded_hidden_states[expert_idx] for expert_idx in non_core_indices
                    ], dim=0)
                    if usage_frequencies is not None:
                        non_core_usages = usage_frequencies[[expert_idx.item() for expert_idx in non_core_indices]]
                    merged_expert = merge_olmoe_mlp_by_activation_matching_within_and_across_models(
                        mlp_list=[mlp.experts[expert_idx] for expert_idx in non_core_indices],
                        forwarded_hidden_states=group_forwarded_hidden_states,
                        average_coefs=non_core_usages.tolist() if usage_frequencies is not None else None
                    )
                    # Stage 2: merge the dominant expert(s) with the merged expert from stage 1
                    group_forwarded_hidden_states = torch.cat([
                        forwarded_hidden_states[expert_idx] for expert_idx in expert_indices
                    ], dim=0)
                    if usage_frequencies is not None:
                        core_usages = usage_frequencies[group_core_expert_indices]
                        # The stage-1 expert carries the combined weight of the non-core experts.
                        non_core_usage_sum = torch.sum(non_core_usages).item()
                    merged_expert = merge_olmoe_mlp_by_activation_matching_within_and_across_models(
                        mlp_list=[merged_expert] + [mlp.experts[expert_idx] for expert_idx in
                                                    group_core_expert_indices],
                        forwarded_hidden_states=group_forwarded_hidden_states,
                        average_coefs=([non_core_usage_sum] + core_usages.tolist()
                                       if usage_frequencies is not None else None)
                    )
            else:
                # Merge all experts in the group
                group_forwarded_hidden_states = torch.cat([
                    forwarded_hidden_states[expert_idx] for expert_idx in expert_indices
                ], dim=0)
                merged_expert = merge_olmoe_mlp_by_activation_matching_within_and_across_models(
                    mlp_list=[mlp.experts[expert_idx] for expert_idx in expert_indices],
                    forwarded_hidden_states=group_forwarded_hidden_states,
                    average_coefs=usage_frequencies[expert_indices].tolist() if usage_frequencies is not None else None
                )
            mlp.experts[expert_indices[0]].up_proj.weight.copy_(merged_expert.up_proj.weight)
            mlp.experts[expert_indices[0]].down_proj.weight.copy_(merged_expert.down_proj.weight)
            mlp.experts[expert_indices[0]].gate_proj.weight.copy_(merged_expert.gate_proj.weight)

            for expert_idx in expert_indices[1:]:
                # Bind merged experts to the first of them
                mlp.experts[expert_idx] = mlp.experts[expert_indices[0]]

    return mlp

def _Qwen_collect_moe_inputs_and_router_indices(
        model: Qwen3MoeForCausalLM,
        grouper: ExpertsGrouperForQwen,
        dataloader: DataLoader,
) -> Tuple[Dict[str, List[torch.Tensor]], Dict[str, List[torch.Tensor]]]:
    """Run `dataloader` through `model`, recording each sparse layer's MoE-block inputs and top-k expert ids.

    Returns
    -------
    Tuple[Dict[str, List[torch.Tensor]], Dict[str, List[torch.Tensor]]]
        `(forwarded_hidden_states, router_indices)`, both keyed by `model.layers.{i}.mlp`
        and holding one entry per dataloader batch:
        - hidden states flattened to `(num_tokens, hidden)`;
        - top-k expert ids, presumably `(num_tokens, top_k)`.
        Forward hooks are always removed, even if a forward pass fails.
    """
    forwarded_hidden_states = dict()
    model.eval().cuda()
    handles = []

    def _get_activation_hook(name):
        def hook(module, input, output):
            forwarded_hidden_states[name].append(input[0].detach().reshape(-1, input[0].shape[-1]))

        return hook

    try:
        for layer_idx in tqdm(
                grouper.sparse_layer_indices,
                desc="[Merging]Registering forward hook..."
        ):
            mlp_name = f"model.layers.{layer_idx}.mlp"
            forwarded_hidden_states[mlp_name] = []
            handles.append(model.model.layers[layer_idx].mlp.register_forward_hook(
                _get_activation_hook(mlp_name))
            )

        router_indices = {name: [] for name in forwarded_hidden_states}
        with torch.no_grad():
            for batch in tqdm(dataloader, desc="[Merging]Computing activations..."):
                batch = {k: v.cuda() for k, v in batch.items()}
                outputs = model(**batch, output_router_logits=True)
                for layer_idx in grouper.sparse_layer_indices:
                    routing_weights = F.softmax(outputs.router_logits[layer_idx], dim=1, dtype=torch.float)
                    _, selected_experts = torch.topk(routing_weights, grouper.num_experts_per_tok, dim=-1)
                    router_indices[f"model.layers.{layer_idx}.mlp"].append(selected_experts)
    finally:
        for handle in handles:
            handle.remove()

    return forwarded_hidden_states, router_indices


def Qwen_merge_by_groups(
        model: Qwen3MoeForCausalLM,
        grouper: ExpertsGrouperForQwen,
        merging_layers: Optional[List[int]] = None,
        permute: Optional[bool] = False,
        permute_strategy: Optional[str] = "weight-matching",
        dataloader: Optional[DataLoader] = None,
) -> Qwen3MoeForCausalLM:
    """
    Merge the experts of each group by plain weight averaging.

    Parameters
    ----------
    model: Qwen3MoeForCausalLM
        The model to merge experts.
    grouper: ExpertsGrouperForQwen
        The grouper to group experts. It is expected to have already run
        `grouper.compute_all_usages()` and one of `grouper.group_experts()`,
        i.e. it has group labels.
    merging_layers: Optional[List[int]]
        The layers to merge experts, if None, merge all layers.
    permute: Optional[bool]
        Whether to align experts by permutation before averaging.
    permute_strategy: Optional[str]
        "weight-matching" or "activation-matching".
    dataloader: Optional[DataLoader]
        The dataloader used to compute activations. It is only used, and then
        required, when `permute_strategy` is "activation-matching".

    Returns
    -------
    Qwen3MoeForCausalLM
        The same model instance, with merged experts.

    Raises
    ------
    ValueError
        If `permute_strategy` is "activation-matching" but `dataloader` is None.
    """
    activation_matching = permute_strategy == "activation-matching"
    if activation_matching:
        if dataloader is None:
            raise ValueError("[Merging]dataloader is required when permute_strategy is 'activation-matching'")
        forwarded_hidden_states, router_indices = _Qwen_collect_moe_inputs_and_router_indices(
            model, grouper, dataloader
        )

    num_experts = grouper.num_experts

    for layer_idx in tqdm(grouper.sparse_layer_indices,
                          desc="[Merging]Merging experts with averaging..."):
        if merging_layers is not None and layer_idx not in merging_layers:
            continue
        mlp_name = f"model.layers.{layer_idx}.mlp"
        group_labels = grouper.group_state_dict()[mlp_name]

        layer_forwarded_hidden_states = None
        if activation_matching:
            # For each expert: the inputs of every token routed to it (in any top-k slot).
            layer_forwarded_hidden_states = tuple(
                torch.cat([
                    hidden[(indices == expert_idx).any(dim=1)]
                    for hidden, indices in zip(forwarded_hidden_states[mlp_name], router_indices[mlp_name])
                ], dim=0)
                for expert_idx in range(num_experts)
            )

        model.model.layers[layer_idx].mlp = _merge_mlp_experts_by_averaging(
            mlp=model.model.layers[layer_idx].mlp,
            group_labels=group_labels,
            permute=permute,
            permute_strategy=permute_strategy,
            forwarded_hidden_states=layer_forwarded_hidden_states
        )

    return model

def Qwen_merge_by_groups_within_and_across_models(
        model: Qwen3MoeForCausalLM,
        grouper: ExpertsGrouperForQwen,
        dataloader: DataLoader,
        merging_layers: Optional[List[int]] = None,
        dominant_alone: Optional[bool] = False,
        core_experts: Optional[Dict[str, List[int]]] = None,
        usage_weighted: Optional[bool] = False,
) -> Qwen3MoeForCausalLM:
    """Merge grouped experts of the selected sparse layers using recorded activations.

    A forward hook on every sparse MLP block records that block's input (one
    row per token) for each batch of ``dataloader``. The router logits returned
    by the model are softmaxed and reduced to the top-``num_experts_per_tok``
    expert ids per token. For each layer in ``merging_layers`` (every sparse
    layer when ``None``), the inputs of the tokens routed to each expert are
    concatenated across batches. They are then passed with the grouper's group
    labels to ``_Qwen_merge_mlp_experts_within_and_across_models``, whose
    result replaces the layer's MLP.

    Args:
        model: Model whose MLP blocks are replaced in place. It is switched to
            eval mode and moved to CUDA.
        grouper: Source of group labels, usage frequencies,
            ``sparse_layer_indices``, ``num_experts`` and ``num_experts_per_tok``.
        dataloader: Yields dicts of tensors passed as ``model(**batch)``.
        merging_layers: Layer indices to merge; ``None`` merges all sparse layers.
        dominant_alone: Forwarded unchanged to the per-layer merge helper.
        core_experts: Optional ``{mlp_name: expert indices}``; the entry for each
            layer is forwarded as ``core_expert_indices``.
        usage_weighted: When true, the layer's usage frequencies are forwarded;
            otherwise ``None`` is passed.

    Returns:
        The same ``model`` object, with merged MLP blocks.
    """
    # {mlp_name: [one (num_tokens_in_batch, mlp_input_dim) tensor per batch]}
    forwarded_hidden_states: Dict[str, List[torch.Tensor]] = {}
    # {mlp_name: [one tensor of top-k expert ids per batch]}
    # NOTE(review): assumed 2-D (num_tokens_in_batch, num_experts_per_tok) and
    # row-aligned with the hooked inputs above — confirm for the model version.
    router_indices: Dict[str, List[torch.Tensor]] = {}

    usage_frequencies = grouper.usage_frequency_state_dict()

    model.eval().cuda()
    handles = []

    def _get_activation_hook(name):
        def hook(module, input, output):
            # Keep the block's input, flattened to one row per token.
            forwarded_hidden_states[name].append(input[0].detach().reshape(-1, input[0].shape[-1]))
        return hook

    try:
        for layer_idx in tqdm(
                grouper.sparse_layer_indices,
                desc="[Merging]Registering forward hook..."
        ):
            mlp_name = f"model.layers.{layer_idx}.mlp"
            forwarded_hidden_states[mlp_name] = []
            router_indices[mlp_name] = []
            handles.append(model.model.layers[layer_idx].mlp.register_forward_hook(
                _get_activation_hook(mlp_name))
            )

        with torch.no_grad():
            for batch in tqdm(dataloader, desc="[Merging]Computing activations..."):
                batch = {k: v.cuda() for k, v in batch.items()}
                outputs = model(**batch, output_router_logits=True)
                for layer_idx in grouper.sparse_layer_indices:
                    # NOTE(review): router_logits is indexed by absolute layer index,
                    # which assumes every decoder layer reports router logits.
                    routing_weights = F.softmax(outputs.router_logits[layer_idx], dim=1, dtype=torch.float)
                    _, selected_experts = torch.topk(routing_weights, grouper.num_experts_per_tok, dim=-1)
                    router_indices[f"model.layers.{layer_idx}.mlp"].append(selected_experts)
    finally:
        # Always detach the hooks, even when a forward pass fails (e.g. OOM), so
        # the model is not left silently recording activations.
        for handle in handles:
            handle.remove()

    num_experts = grouper.num_experts
    group_labels_by_mlp = grouper.group_state_dict()
    for layer_idx in tqdm(
            grouper.sparse_layer_indices,
            desc="[Merging]Merging by groups within and across experts..."
    ):
        if merging_layers is not None and layer_idx not in merging_layers:
            continue
        mlp_name = f"model.layers.{layer_idx}.mlp"
        batch_states = forwarded_hidden_states[mlp_name]
        batch_routes = router_indices[mlp_name]
        # For each expert: inputs of every token that selected it in its top-k,
        # concatenated over all recorded batches.
        layer_forwarded_hidden_states = tuple(
            torch.cat(
                [
                    states[(routes == expert_idx).any(dim=1)]
                    for states, routes in zip(batch_states, batch_routes)
                ],
                dim=0,
            )
            for expert_idx in range(num_experts)
        )
        model.model.layers[layer_idx].mlp = _Qwen_merge_mlp_experts_within_and_across_models(
            mlp=model.model.layers[layer_idx].mlp,
            group_labels=group_labels_by_mlp[mlp_name],
            forwarded_hidden_states=layer_forwarded_hidden_states,
            dominant_alone=dominant_alone,
            core_expert_indices=core_experts[mlp_name] if core_experts is not None else None,
            usage_frequencies=usage_frequencies[mlp_name] if usage_weighted else None,
        )

    del forwarded_hidden_states, router_indices
    torch.cuda.empty_cache()
    return model

def _merge_mlp_experts_by_arcee_fusion(
        mlp: Qwen3MoeSparseMoeBlock,
        group_labels: torch.LongTensor,
        usage_frequencies: torch.Tensor,
) -> Qwen3MoeSparseMoeBlock:
    """Fuse each group of experts into its most frequently used member via Arcee Fusion.

    For every group label shared by more than one expert:

    * the expert with the highest usage frequency becomes the base ("center");
    * the remaining experts are folded in one at a time, in ASCENDING order of
      usage frequency. Each step calls
      ``ArceeFusionMerge(model=<other expert weight>, base_model=<running result>)``
      separately for the ``up_proj``, ``down_proj`` and ``gate_proj`` weights;
    * the fused weights are copied into the center expert's parameters in
      place, and every expert slot of the group is rebound to that one module.

    Singleton groups are left untouched.

    Args:
        mlp: Sparse MoE block exposing an ``experts`` module list; modified in place.
        group_labels: One group label per expert.
        usage_frequencies: One usage frequency per expert; moved to the experts' device.

    Returns:
        The same ``mlp`` object.
    """
    device = mlp.experts[0].down_proj.weight.device
    usage_frequencies = usage_frequencies.to(device)

    with torch.no_grad():
        for label in group_labels.unique():
            expert_indices = torch.where(group_labels == label)[0].to(device)
            if expert_indices.numel() <= 1:
                continue

            # Center expert: the most frequently used member of the group.
            group_usage_frequencies = usage_frequencies[expert_indices]
            center_expert_idx = expert_indices[torch.argmax(group_usage_frequencies)]

            other_expert_indices = expert_indices[expert_indices != center_expert_idx]
            other_usage_frequencies = usage_frequencies[other_expert_indices]

            # Fold the other experts in ASCENDING order of usage frequency
            # (least used first).
            sorted_indices = torch.argsort(other_usage_frequencies, descending=False)
            sorted_other_experts = other_expert_indices[sorted_indices]

            center_expert = mlp.experts[center_expert_idx]
            up_proj_weight = center_expert.up_proj.weight
            down_proj_weight = center_expert.down_proj.weight
            gate_proj_weight = center_expert.gate_proj.weight
            for expert_idx in sorted_other_experts:
                other_expert = mlp.experts[expert_idx]
                up_proj_weight = ArceeFusionMerge(model=other_expert.up_proj.weight,
                                                  base_model=up_proj_weight)
                down_proj_weight = ArceeFusionMerge(model=other_expert.down_proj.weight,
                                                    base_model=down_proj_weight)
                gate_proj_weight = ArceeFusionMerge(model=other_expert.gate_proj.weight,
                                                    base_model=gate_proj_weight)

            # Write the fused weights into the center expert's parameters.
            center_expert.up_proj.weight.copy_(up_proj_weight)
            center_expert.down_proj.weight.copy_(down_proj_weight)
            center_expert.gate_proj.weight.copy_(gate_proj_weight)

            # Bind every expert slot of the group to the (merged) center expert.
            for expert_idx in expert_indices:
                mlp.experts[expert_idx] = center_expert

    print("[Merging]Expert merging completed for this layer.")
    return mlp

def Qwen_merge_by_groups_with_arcee_fusion(
        model: Qwen3MoeForCausalLM,
        grouper: ExpertsGrouperForQwen,
        merging_layers: Optional[List[int]],
        batch: Dict[str, torch.Tensor],
) -> Qwen3MoeForCausalLM:
    """Merge grouped experts with Arcee Fusion, walking sparse layers last to first.

    Usage frequencies are read from the grouper once, before any layer is
    processed. Each layer in ``merging_layers`` (every sparse layer when
    ``None``) is handled in reverse order:

    * ``grouper.compute_layer_act`` is called for the layer with ``batch``;
    * ``grouper.get_composed_matrixes`` is queried for the layer;
    * the layer's MLP is replaced by ``_merge_mlp_experts_by_arcee_fusion``,
      using the grouper's current group labels and the pre-read frequencies.

    NOTE(review): the reverse order is presumably chosen so that each layer's
    activations are computed while all earlier layers are still unmerged — confirm.

    Returns:
        The same ``model`` object, with merged MLP blocks.
    """
    frequencies_by_mlp = grouper.usage_frequency_state_dict()
    layers_last_to_first = grouper.sparse_layer_indices[::-1]

    for idx in tqdm(layers_last_to_first, desc="[Merging]Processing layers one by one..."):
        if merging_layers is not None and idx not in merging_layers:
            continue

        grouper.compute_layer_act(model=model, merging_layer_idx=idx, batch=batch)
        name = f"model.layers.{idx}.mlp"
        # NOTE(review): the composed matrices are not used by Arcee Fusion; the
        # call is kept in case it updates grouper state — confirm and remove.
        _composed = grouper.get_composed_matrixes(name)
        labels = grouper.group_state_dict()[name]
        frequencies = frequencies_by_mlp[name]

        decoder_layer = model.model.layers[idx]
        decoder_layer.mlp = _merge_mlp_experts_by_arcee_fusion(
            mlp=decoder_layer.mlp,
            group_labels=labels,
            usage_frequencies=frequencies,
        )
    return model
