import os
from collections import OrderedDict
from copy import deepcopy
from typing import Dict, Optional, List, Union, Callable, Iterator, Tuple

import torch
from torch import nn
from scipy.optimize import linear_sum_assignment
from sklearn.cluster import SpectralClustering
from torch.nn import functional as F
from torch.utils.data import DataLoader
from tqdm import tqdm
from transformers import (
    OlmoeConfig,
    OlmoeForCausalLM,
    PretrainedConfig
)
from transformers.modeling_outputs import Seq2SeqMoEOutput
from transformers.models.olmoe.modeling_olmoe import OlmoeMLP, OlmoeSparseMoeBlock


from .permutation import (
    permute_OLMoE_mlp_dense_expert_,
    compute_OLMoE_permutation_by_weight_matching,
    compute_OLMoE_permutation_by_activation_matching,
    merge_olmoe_mlp_by_activation_matching_within_and_across_models,
)

from .myolmoe import MyOlmoeForCausalLM, load_pretrained_weights
from .utils import generate_random_group_labels
from ..utils.constants import FP32_EPS

import ot

# Public API of this module (the names exported by ``from <module> import *``).
__all__ = [
    "ExpertsGrouperForOLMoE",
    "LEGAL_SIMILARITY_BASES",
    "SIMILARITY_MAPPING_FUNCTION",
    "ExpertUsageFrequencyTracker",
    "prune_non_core_experts_by_groups",
    "OLMoE_merge_by_groups_with_usage_frequency_weighting",
]

def _cosine_similarity_01(x, y):
    """Cosine similarity along the last dim, rescaled from [-1, 1] to [0, 1]."""
    return (F.cosine_similarity(x, y, dim=-1, eps=FP32_EPS) + 1) / 2


def _log_mse_similarity(x, y):
    """Return ``1 / (1 + 0.1 * log(sum of squared differences))``."""
    # NOTE(review): identical inputs give log(0) = -inf, i.e. a similarity of 0. For a sum below 1
    # the log is negative, so the value exceeds 1 and diverges near sum = exp(-10). Confirm this
    # mapping is intended.
    squared_error_sum = F.mse_loss(x, y, reduction="sum")
    return 1 / (1 + 0.1 * torch.log(squared_error_sum))


def _sum_squared_error(x, y):
    """Return the sum of squared differences (a distance, larger means less similar)."""
    return F.mse_loss(x, y, reduction="sum")


# similarity_fn name -> callable comparing two tensors.
SIMILARITY_MAPPING_FUNCTION = {
    "cosine": _cosine_similarity_01,
    "mse": _log_mse_similarity,
    "l2": _sum_squared_error,
}

# Accepted values for ``ExpertsGrouperForOLMoE(similarity_base=...)``.
LEGAL_SIMILARITY_BASES = [
    "weight",
    "feature",
    "feature.abs",
    "weight-feature",
    "gradient",
    "weight-gradient",
    "router-logits",
    "router-weight",
    "router-weight-feature",
    "mse",
    "random",
    "feature-correlation.lsa",
    "feature-correlation.max",
    "gate-weight",
    "gate-up-weight",
    "gate-act",
]


class ExpertsGrouperForOLMoE(object):
    """
    Group the experts of every OLMoE sparse MLP layer into clusters for merging / pruning.

    The grouper keeps per-layer state keyed by ``"model.layers.{idx}.mlp"``:

    * group labels (``torch.arange(num_experts)`` until a grouping method assigns cluster ids),
    * a ``(num_experts, num_experts)`` pairwise similarity matrix,
    * per-expert usage frequencies (filled by :meth:`compute_all_usages`),
    * optimal-transport maps in ``transport_matrixes`` (filled by
      :meth:`group_experts_into_clusters_by_routing_guided_globally`).

    Group, similarity and usage tensors are allocated on ``"cuda"``.

    Parameters
    ----------
    config: OlmoeConfig or PretrainedConfig
        Must expose ``num_experts``, ``num_experts_per_tok``, ``hidden_size`` and ``num_hidden_layers``.
    similarity_fn: str
        Key of ``SIMILARITY_MAPPING_FUNCTION`` used to compare two tensors.
    similarity_base: str
        One of ``LEGAL_SIMILARITY_BASES``; selects what is compared (weights, router rows, activations...).

    Raises
    ------
    ValueError
        If ``similarity_fn`` or ``similarity_base`` is not a recognised key.
    """

    # Similarity bases computed from model weights only, hence needing no input batch.
    _WEIGHT_ONLY_BASES = ("weight", "gate-weight", "gate-up-weight", "router-weight")

    def __init__(
            self,
            config: Union[OlmoeConfig, PretrainedConfig],
            similarity_fn: str = "cosine",
            similarity_base: str = "weight",
    ):
        if similarity_fn not in SIMILARITY_MAPPING_FUNCTION:
            raise ValueError(
                f"[Merging]similarity_fn should be one of {SIMILARITY_MAPPING_FUNCTION.keys()}, got {similarity_fn} instead.")
        if similarity_base not in LEGAL_SIMILARITY_BASES:
            raise ValueError(
                f"[Merging]similarity_base should be one of {LEGAL_SIMILARITY_BASES}, got {similarity_base} instead.")

        self.num_experts = config.num_experts
        self.num_experts_per_tok = config.num_experts_per_tok
        self.avg_num_merged_experts = self.num_experts
        self.hidden_size = config.hidden_size
        self.sparse_layer_indices = list(range(0, config.num_hidden_layers))
        self.similarity_fn = SIMILARITY_MAPPING_FUNCTION[similarity_fn]
        self.similarity_base = similarity_base
        self._group_state_dict = None
        self._similarity_state_dict = None
        self._usage_frequency_state_dict = None

        # OT maps per layer, filled by group_experts_into_clusters_by_routing_guided_globally.
        self.transport_matrixes = dict()
        # Per-layer matrices captured by compute_layer_composed_matrixes / compute_layer_act.
        self.composed_matrixes = dict()

        self.reset_all()

    def reset_all(self):
        """
        Reinitialise groups, similarities, usage frequencies and OT maps for every layer.

        Groups become the identity labelling ``arange(num_experts)``. Similarity matrices become the
        identity matrix (1 on the diagonal, 0 elsewhere until similarities are computed). Usage
        counts become zeros.
        """
        if self.similarity_base == "mse":
            self.similarity_fn = SIMILARITY_MAPPING_FUNCTION["mse"]
            print("[Merging]Set similarity_fn to mse for mse similarity_base.")
        self._group_state_dict = dict()
        self._similarity_state_dict = dict()
        self._usage_frequency_state_dict = dict()
        for layer_idx in self.sparse_layer_indices:
            mlp_name = f"model.layers.{layer_idx}.mlp"
            self._group_state_dict[mlp_name] = torch.arange(self.num_experts, device="cuda")
            self._similarity_state_dict[mlp_name] = torch.eye(self.num_experts, device="cuda")
            self._usage_frequency_state_dict[mlp_name] = torch.zeros(self.num_experts, device="cuda")

        # Previously this assigned a misspelled attribute, so stale OT maps survived a reset.
        self.transport_matrixes = dict()

    def similarity_state_dict(self) -> Dict[str, torch.Tensor]:
        """Return a deep copy of the per-layer similarity matrices."""
        return deepcopy(self._similarity_state_dict)

    def group_state_dict(self) -> Dict[str, torch.LongTensor]:
        """Return a deep copy of the per-layer group labels."""
        return deepcopy(self._group_state_dict)

    def usage_frequency_state_dict(self) -> Dict[str, torch.Tensor]:
        """Return a deep copy of the per-layer expert usage frequencies."""
        return deepcopy(self._usage_frequency_state_dict)

    def save_similarity(self, mlp_name: str, i: int, j: int, similarity: float):
        """Store ``similarity`` symmetrically at ``[i, j]`` and ``[j, i]`` of the layer's matrix."""
        self._similarity_state_dict[mlp_name][i, j] = similarity
        self._similarity_state_dict[mlp_name][j, i] = similarity

    def get_similarity(self, mlp_name: str, i: int, j: int) -> float:
        """Return the stored similarity between experts ``i`` and ``j`` as a Python float."""
        return self._similarity_state_dict[mlp_name][i, j].item()

    def get_similarity_matrix(self, mlp_name: str) -> torch.Tensor:
        """Return a deep copy of one layer's similarity matrix."""
        return deepcopy(self._similarity_state_dict[mlp_name])

    def get_transport_matrix(self, mlp_name: str) -> torch.Tensor:
        """Return a deep copy of one layer's OT map."""
        return deepcopy(self.transport_matrixes[mlp_name])

    def get_composed_matrixes(self, mlp_name: str) -> List[torch.Tensor]:
        """Return (without copying) the matrices captured for one layer."""
        return self.composed_matrixes[mlp_name]

    def del_composed_matrixes(self, mlp_name: str):
        """Free the matrices captured for one layer."""
        del self.composed_matrixes[mlp_name]

    def save_group_state_dict(self, save_dir: str):
        """Save the group labels to ``<save_dir>/group_state_dict.pt``, creating the directory if needed."""
        if not os.path.exists(save_dir):
            os.makedirs(save_dir)
        torch.save(self._group_state_dict, os.path.join(save_dir, "group_state_dict.pt"))

    def load_group_state_dict(self, load_dir: str):
        """Load group labels from ``<load_dir>/group_state_dict.pt``."""
        # torch.load unpickles its input: only load files you trust.
        self._group_state_dict = torch.load(os.path.join(load_dir, "group_state_dict.pt"))

    def set_avg_num_merged_experts(self, avg_num_experts: int):
        """Set the group count used by :meth:`compute_layer_composed_matrixes`."""
        self.avg_num_merged_experts = avg_num_experts

    def _resolve_layers(self, merging_layers: Optional[List[int]]) -> List[int]:
        """Return ``merging_layers``, or every sparse layer index when it is None."""
        return list(self.sparse_layer_indices) if merging_layers is None else merging_layers

    def _assign_num_groups_per_layer(
            self,
            average_num_groups: int,
            merging_layers: List[int],
    ) -> Dict[str, int]:
        """
        Give every merged layer ``average_num_groups`` groups.

        Layers not in ``merging_layers`` keep one group per expert.
        """
        return {
            f"model.layers.{layer_idx}.mlp": (
                average_num_groups if layer_idx in merging_layers else self.num_experts
            )
            for layer_idx in self.sparse_layer_indices
        }

    def group_experts_into_clusters_by_routing_guided_globally(
            self,
            average_num_groups: int,
            merging_layers: List[int],
            layer_group_capacity: Optional[int] = None,
    ) -> Dict[str, List[int]]:
        """
        Group experts into clusters by routing-guided clustering, and build an OT map per layer.

        For each layer in ``merging_layers``, the following steps run in order:

        1. The ``num_groups`` most-used experts become cores with labels ``0..num_groups-1``.
        2. Every remaining expert, in decreasing usage order, joins the group of its most similar
           still-open core. A group exceeding ``layer_group_capacity`` is closed by removing its core.
        3. An exact OT plan (``ot.emd``) is computed from all experts (mass = usage frequency) to the
           remaining cores (mass rescaled to the same total). It is transposed, column-normalised and
           stored in ``self.transport_matrixes``.

        Parameters
        ----------
        average_num_groups: int
            Number of clusters for each merged layer.
        merging_layers: List[int]
            The layers whose experts are grouped. Other layers keep one group per expert.
        layer_group_capacity: Optional[int]
            The maximum number of experts in each group. If None, groups are not limited.

        Returns
        -------
        core_experts: Dict[str, List[int]]
            The core expert indices of each merged layer, in label order.

        Raises
        ------
        ValueError
            If closing an over-capacity group would leave no core.
        """
        layer_group_capacity = layer_group_capacity if layer_group_capacity is not None else self.num_experts
        num_groups_per_layer = self._assign_num_groups_per_layer(average_num_groups, merging_layers)
        print(f"[Merging]Number of groups of each layer: {num_groups_per_layer}")

        core_experts = dict()
        for layer_idx in tqdm(merging_layers,
                              desc=f"Globally routing-guided clustering experts into average {average_num_groups} clusters"):
            mlp_name = f"model.layers.{layer_idx}.mlp"
            num_groups = num_groups_per_layer[mlp_name]
            group_member_count = torch.zeros(num_groups)
            indices_sorted_by_usage = torch.argsort(self._usage_frequency_state_dict[mlp_name], descending=True)

            # Step 1: the num_groups most-used experts become cores labelled 0..num_groups-1.
            core_expert_indices = indices_sorted_by_usage[:num_groups]
            core_experts[mlp_name] = core_expert_indices.tolist()
            for i in range(num_groups):
                self._group_state_dict[mlp_name][core_expert_indices[i]] = i
                group_member_count[i] += 1

            # Step 2: attach each remaining expert to the group of its most similar open core.
            similarity_matrix = self.get_similarity_matrix(mlp_name)
            for i in range(num_groups, self.num_experts):
                expert_idx = indices_sorted_by_usage[i]
                core_index = torch.argmax(similarity_matrix[expert_idx, core_expert_indices])
                most_similar_core = core_expert_indices[core_index]
                most_similar_group_label = self._group_state_dict[mlp_name][most_similar_core]
                self._group_state_dict[mlp_name][expert_idx] = most_similar_group_label
                group_member_count[most_similar_group_label] += 1
                if group_member_count[most_similar_group_label] > layer_group_capacity:
                    if len(core_expert_indices) == 1:
                        raise ValueError(
                            f"[Merging]The number of groups at layer {layer_idx} is too small!"
                        )
                    # Group is full: remove its core so no further expert can join it.
                    core_expert_indices = torch.cat(
                        [core_expert_indices[:core_index], core_expert_indices[core_index + 1:]]
                    )

            # Step 3: OT from all experts to the remaining cores.
            # NOTE(review): ot.emd minimises cost, so the similarity matrix presumably has to be turned
            # into a distance first (e.g. reverse_all_similarities()). Confirm with callers.
            cost = similarity_matrix[:, core_expert_indices]
            source_prob = self._usage_frequency_state_dict[mlp_name]
            target_prob = source_prob[core_expert_indices]  # advanced indexing -> copy
            target_prob *= torch.sum(source_prob) / torch.sum(target_prob)
            ot_map = ot.emd(source_prob.detach().cpu(), target_prob.detach().cpu(), cost.detach().cpu()).transpose(0, 1)
            ot_map = ot_map / ot_map.sum(dim=0)
            self.transport_matrixes[mlp_name] = ot_map

        return core_experts

    def compute_all_usages(
            self,
            model: OlmoeForCausalLM,
            batch: Dict[str, torch.Tensor],
            mini_batch_size: Optional[int] = 128,
            merging_layers: Optional[List[int]] = None,
    ):
        """
        Count how often each expert is in a token's top-k routing, then normalise counts per layer.

        Parameters
        ----------
        model: OlmoeForCausalLM
            Moved to CUDA and set to eval mode.
        batch: Dict[str, torch.Tensor]
            Model inputs; must contain ``input_ids``. Tensors are sliced along dim 0 into mini-batches.
        mini_batch_size: Optional[int]
            Rows per forward pass. None or a value above the batch size means a single pass.
        merging_layers: Optional[List[int]]
            Layers to count. None means every layer.

        Counts accumulate into the usage state, which is then divided by its per-layer sum. Layers
        with zero total count are left as zeros instead of becoming NaN.
        """
        merging_layers = self._resolve_layers(merging_layers)
        model.cuda()
        model.eval()
        total_batch_size = batch["input_ids"].shape[0]
        if mini_batch_size is None or mini_batch_size > total_batch_size:
            mini_batch_size = total_batch_size
        mini_batch_size = max(mini_batch_size, 1)  # avoid a zero step on an empty batch

        # Stepping by mini_batch_size also covers the trailing partial mini-batch.
        for start in tqdm(range(0, total_batch_size, mini_batch_size), desc="[Merging]Computing all usages..."):
            with torch.no_grad():
                mini_batch = {k: v[start: start + mini_batch_size].cuda() for k, v in batch.items()}
                outputs = model(**mini_batch, output_router_logits=True)
                for layer_idx in merging_layers:
                    mlp_name = f"model.layers.{layer_idx}.mlp"
                    routing_weights = F.softmax(outputs.router_logits[layer_idx], dim=1, dtype=torch.float)
                    _, selected_experts = torch.topk(routing_weights, self.num_experts_per_tok, dim=-1)
                    counts = torch.bincount(selected_experts.reshape(-1), minlength=self.num_experts)
                    usage = self._usage_frequency_state_dict[mlp_name]
                    usage += counts.to(device=usage.device, dtype=usage.dtype)

        normalised = dict()
        for name, usage in self._usage_frequency_state_dict.items():
            total = torch.sum(usage)
            normalised[name] = usage / total if total > 0 else usage
        self._usage_frequency_state_dict = normalised

    def reverse_all_similarities(self):
        """Replace every similarity matrix ``S`` with ``1 - S``."""
        print("[Merging]Reversing all similarities...")
        for key in self._similarity_state_dict.keys():
            self._similarity_state_dict[key] = 1 - self._similarity_state_dict[key]

    def compute_layer_composed_matrixes(
            self,
            model: OlmoeForCausalLM,
            merging_layer_idx: int,
            batch: Dict[str, torch.Tensor]
    ):
        """
        Run one forward pass with a hook on one layer and store ``[A, B, C, D]`` in ``composed_matrixes``.

        ``A`` is an ``(experts, groups)`` one-hot assignment. The first ``groups - 1`` experts, ranked by
        ``W_i W^{-1} W_i^T``, get their own groups and the rest share the last group. ``B`` stacks the
        rows ``W_t W^{-1}`` of the kept experts plus the averaged row of the shared group. ``C`` / ``D``
        are the hook's captured gate activations and inputs.

        NOTE(review): ``_get_mlp_activation`` writes ``self.up_acts`` / ``self.gate_acts`` only.
        ``self.up_acts`` is not initialised here and nothing fills ``self.activations`` /
        ``self.inputs``, so as written the forward hook appears to raise AttributeError. Confirm
        which hook this method was meant to use.
        """
        model = model.cuda()
        model = model.eval()
        batch = {k: v.cuda() for k, v in batch.items()}
        self.activations = {}
        self.gate_acts = {}
        self.inputs = {}

        mlp_name = f"model.layers.{merging_layer_idx}.mlp"
        handle = model.model.layers[merging_layer_idx].mlp.register_forward_hook(
            self._get_mlp_activation(mlp_name)
        )
        try:
            with torch.no_grad():
                model(**batch)
        finally:
            # Always detach the hook; a failing forward must not leave it on the model.
            handle.remove()

        group_num = self.avg_num_merged_experts

        ACT = self.activations[mlp_name].permute(1, 0, 2)

        # W: expert-by-expert Gram matrix averaged over the leading (permuted) dim.
        ACT_T = torch.transpose(ACT, 1, 2)
        W = torch.mean(torch.bmm(ACT, ACT_T), dim=0)
        W_inv = torch.linalg.inv(W)
        expert_num = W.size(0)
        value_list = []
        for i in range(expert_num):
            Wi = W[i:i+1, :]
            tmp = torch.matmul(Wi, W_inv)
            value_list.append(torch.matmul(tmp, Wi.T))
        arr = torch.tensor(value_list)
        _, sorted_indices = torch.sort(arr, descending=True)

        group_labels = torch.zeros_like(arr, dtype=torch.int32)
        A = torch.zeros((expert_num, group_num), device=W.device)
        B = []
        for i in range(group_num - 1):
            t = sorted_indices[i]
            group_labels[t] = i
            B.append(torch.matmul(W[t:t+1, :], W_inv))

        avg_w = torch.zeros((1, expert_num), device=W.device)
        for i in range(group_num - 1, expert_num):
            t = sorted_indices[i]
            group_labels[t] = group_num - 1
            avg_w += W[t:t+1, :]

        B.append(torch.matmul(avg_w / (expert_num - group_num + 1), W_inv))
        B = torch.cat(B, dim=0)
        for i in range(expert_num):
            A[i][group_labels[i]] = 1

        C = self.gate_acts[mlp_name]
        D = self.inputs[mlp_name]
        self.composed_matrixes[mlp_name] = [A, B, C, D]

        del self.activations
        del self.gate_acts
        del self.inputs

    def compute_layer_act(
            self,
            model: OlmoeForCausalLM,
            merging_layer_idx: int,
            batch: Dict[str, torch.Tensor]
    ):
        """
        Capture one layer's per-expert gate / up projections of its input with one forward pass.

        Stores ``[gate_acts, up_acts]`` (each stacked over experts) in ``composed_matrixes[mlp_name]``.
        """
        model = model.cuda()
        model = model.eval()
        batch = {k: v.cuda() for k, v in batch.items()}
        self.up_acts = {}
        self.gate_acts = {}

        mlp_name = f"model.layers.{merging_layer_idx}.mlp"
        handle = model.model.layers[merging_layer_idx].mlp.register_forward_hook(
            self._get_mlp_activation(mlp_name)
        )
        try:
            with torch.no_grad():
                model(**batch)
        finally:
            handle.remove()

        self.composed_matrixes[mlp_name] = [self.gate_acts[mlp_name], self.up_acts[mlp_name]]

        del self.up_acts
        del self.gate_acts

    def compute_all_similarities(
            self,
            model: OlmoeForCausalLM,
            batch: Dict[str, torch.Tensor] = None,
            merging_layers: Optional[List[int]] = None,
    ):
        """
        Fill pairwise expert similarities for ``merging_layers`` (all layers if None).

        Dispatches on ``self.similarity_base``. Weight-based bases read ``model.state_dict()``.
        ``router-logits`` and ``gate-act`` run ``batch`` through the model.

        Raises
        ------
        ValueError
            If ``batch`` is None for a base that needs activations.
        NotImplementedError
            For bases not handled here.
        """
        if self.similarity_base not in self._WEIGHT_ONLY_BASES and batch is None:
            raise ValueError(
                f"[Merging]batch should be provided when similarity_base is not one of {list(self._WEIGHT_ONLY_BASES)}")

        merging_layers = self._resolve_layers(merging_layers)
        model = model.cuda()
        model = model.eval()
        if self.similarity_base == "weight":
            self._compute_all_similarities_by_weight(model.state_dict(), merging_layers)
        elif self.similarity_base == 'gate-weight':
            self._compute_all_similarities_by_gate_weight(model.state_dict(), merging_layers)
        elif self.similarity_base == 'gate-up-weight':
            self._compute_all_similarities_by_gate_up_weight(model.state_dict(), merging_layers)
        elif self.similarity_base == 'router-weight':
            self._compute_all_similarities_by_router_weight(model.state_dict(), merging_layers)
        elif self.similarity_base == 'router-logits':
            batch = {k: v.cuda() for k, v in batch.items()}
            self._compute_all_similarities_by_router_logits(model, batch, merging_layers)
        elif self.similarity_base == 'gate-act':
            batch = {k: v.cuda() for k, v in batch.items()}
            self._compute_all_similarities_by_gate_act(model, batch, merging_layers)
        else:
            raise NotImplementedError

    @staticmethod
    def _flatten_expert_weights(
            state_dict: Dict[str, torch.Tensor],
            mlp_name: str,
            expert_idx: int,
            proj_names: Tuple[str, ...],
    ) -> torch.Tensor:
        """Concatenate one expert's flattened ``{proj}.weight`` tensors, in ``proj_names`` order."""
        return torch.cat(
            [state_dict[f"{mlp_name}.experts.{expert_idx}.{proj}.weight"].flatten() for proj in proj_names],
            dim=0,
        )

    def _compute_pairwise_weight_similarities(
            self,
            state_dict: Dict[str, torch.Tensor],
            merging_layers: Optional[List[int]],
            proj_names: Tuple[str, ...],
            desc: str,
    ):
        """Compare every expert pair of each layer on its concatenated, flattened projection weights."""
        for layer_idx in tqdm(self._resolve_layers(merging_layers), desc=desc):
            mlp_name = f"model.layers.{layer_idx}.mlp"
            for i in range(self.num_experts - 1):
                # Build expert i's vector once per row instead of once per (i, j) pair.
                i_flat = self._flatten_expert_weights(state_dict, mlp_name, i, proj_names)
                for j in range(i + 1, self.num_experts):
                    j_flat = self._flatten_expert_weights(state_dict, mlp_name, j, proj_names)
                    similarity = self.similarity_fn(i_flat, j_flat)
                    self.save_similarity(mlp_name, i, j, similarity)

    def _compute_all_similarities_by_weight(self, state_dict: Dict[str, torch.Tensor],
                                            merging_layers: Optional[List[int]] = None,):
        """Similarity on concatenated ``up_proj`` + ``down_proj`` weights."""
        self._compute_pairwise_weight_similarities(
            state_dict, merging_layers, ("up_proj", "down_proj"),
            "[Merging]Computing similarities by weight...",
        )

    def _compute_all_similarities_by_gate_weight(self, state_dict: Dict[str, torch.Tensor],
                                                 merging_layers: Optional[List[int]] = None,):
        """Similarity on ``gate_proj`` weights."""
        self._compute_pairwise_weight_similarities(
            state_dict, merging_layers, ("gate_proj",),
            "[Merging]Computing similarities by gate-weight...",
        )

    def _compute_all_similarities_by_gate_up_weight(self, state_dict: Dict[str, torch.Tensor],
                                                    merging_layers: Optional[List[int]] = None,):
        """Similarity on concatenated ``gate_proj`` + ``up_proj`` weights."""
        self._compute_pairwise_weight_similarities(
            state_dict, merging_layers, ("gate_proj", "up_proj"),
            "[Merging]Computing similarities by gate-up-weight...",
        )

    def _compute_all_similarities_by_router_weight(
            self, state_dict: Dict[str, torch.Tensor],
            merging_layers: Optional[List[int]] = None,
    ):
        """Similarity between rows ``i`` and ``j`` of the layer's router (``gate``) weight."""
        for layer_idx in tqdm(self._resolve_layers(merging_layers), desc="[Merging]Computing similarities by router rows..."):
            mlp_name = f"model.layers.{layer_idx}.mlp"
            router_weight = state_dict[f"{mlp_name}.gate.weight"]
            for i in range(self.num_experts):
                for j in range(i + 1, self.num_experts):
                    similarity = self.similarity_fn(router_weight[i], router_weight[j])
                    self.save_similarity(mlp_name, i, j, similarity)

    def _compute_all_similarities_by_router_logits(
            self, model: OlmoeForCausalLM, batch: Dict[str, torch.Tensor],
            merging_layers: Optional[List[int]] = None,
    ):
        """Similarity between experts' router-logit columns over all tokens of ``batch``."""
        with torch.no_grad():
            outputs = model(**batch, output_router_logits=True)
        for layer_idx in tqdm(self._resolve_layers(merging_layers), desc="[Merging]Computing similarities by router logits..."):
            mlp_name = f"model.layers.{layer_idx}.mlp"
            router_logits = outputs.router_logits[layer_idx].reshape(-1, self.num_experts)
            with torch.no_grad():
                for i in range(self.num_experts):
                    for j in range(i + 1, self.num_experts):
                        i_flat = router_logits[:, i].flatten()
                        j_flat = router_logits[:, j].flatten()
                        similarity = self.similarity_fn(i_flat, j_flat)
                        self.save_similarity(mlp_name, i, j, similarity)

    def _compute_all_similarities_by_gate_act(
            self, model: OlmoeForCausalLM, batch: Dict[str, torch.Tensor],
            merging_layers: Optional[List[int]] = None,
    ):
        """
        Similarity between experts' ``gate_proj`` outputs on ``batch``.

        Runs one forward pass per layer so that only one layer's activations are held at a time.
        """
        for layer_idx in tqdm(self._resolve_layers(merging_layers), desc="[Merging]Computing similarities by gate act..."):
            self.up_acts = {}
            self.gate_acts = {}
            mlp_name = f"model.layers.{layer_idx}.mlp"
            handle = model.model.layers[layer_idx].mlp.register_forward_hook(
                self._get_mlp_activation(mlp_name)
            )
            try:
                with torch.no_grad():
                    model(**batch)
            finally:
                handle.remove()

            gate_acts = self.gate_acts[mlp_name]
            del self.up_acts
            del self.gate_acts

            with torch.no_grad():
                for i in range(self.num_experts):
                    # Flatten so every similarity_fn yields a single scalar, like the other bases.
                    # Cosine on the unflattened (tokens, dim) tensors returned one value per token.
                    i_flat = gate_acts[i].flatten()
                    for j in range(i + 1, self.num_experts):
                        similarity = self.similarity_fn(i_flat, gate_acts[j].flatten())
                        self.save_similarity(mlp_name, i, j, similarity)

    def _get_mlp_activation(self, name):
        """
        Build a forward hook that records every expert's ``up_proj`` / ``gate_proj`` of the block input.

        The hook flattens ``input[0]`` (unpacked as 3-D, last dim = hidden size) to
        ``(tokens, hidden_dim)`` and stores ``torch.stack`` over experts in ``self.up_acts[name]`` and
        ``self.gate_acts[name]``. Both dicts must exist on ``self`` before the forward pass.
        ``gate_proj`` outputs are stored before any activation function is applied.
        """
        def hook(module, input, output):
            _, _, hidden_dim = input[0].shape
            hidden_states = input[0].view(-1, hidden_dim)
            gate_acts = []
            up_acts = []
            for expert_idx in range(module.num_experts):
                expert = module.experts[expert_idx]
                up_acts.append(expert.up_proj(hidden_states))
                gate_acts.append(expert.gate_proj(hidden_states))
            self.up_acts[name] = torch.stack(up_acts)
            self.gate_acts[name] = torch.stack(gate_acts)
        return hook


def _merge_mlp_experts_by_averaging(
        mlp: OlmoeMLP,
        group_labels: torch.LongTensor,
        permute: bool,
        permute_strategy: str,
        forwarded_hidden_states: Optional[Tuple[torch.Tensor]] = None,
) -> OlmoeMLP:
    """
    Collapse each group of experts into one expert by plain element-wise weight averaging.

    For every distinct label, the experts carrying it are optionally permuted to align with the
    group's first ("anchor") expert. The anchor's ``up_proj`` / ``down_proj`` / ``gate_proj``
    weights are then overwritten with the group mean. Every other slot of the group is re-pointed
    at the anchor module object.

    Parameters
    ----------
    mlp:
        Module with an indexable ``experts`` container whose items have ``up_proj``,
        ``down_proj`` and ``gate_proj`` submodules exposing ``.weight``.
    group_labels:
        One label per expert.
    permute:
        Whether to align experts to the anchor before averaging.
    permute_strategy:
        ``"weight-matching"`` or ``"activation-matching"``. Only read when ``permute`` is True.
    forwarded_hidden_states:
        Indexed by expert index and concatenated along dim 0 over the group. Only read for
        ``"activation-matching"``.

    Returns
    -------
    The same ``mlp`` object, modified in place.

    Raises
    ------
    ValueError
        If ``permute`` is True and ``permute_strategy`` is unrecognised (checked on the first group).
    """
    projection_names = ("up_proj", "down_proj", "gate_proj")

    for label in group_labels.unique():
        members = torch.where(group_labels == label)[0]
        anchor_idx = members[0]
        followers = members[1:]

        if permute:
            if permute_strategy == "weight-matching":
                for expert_idx in followers:
                    perm = compute_OLMoE_permutation_by_weight_matching(
                        reference_mlp=mlp.experts[anchor_idx],
                        target_mlp=mlp.experts[expert_idx],
                        include_wo=True
                    )
                    mlp.experts[expert_idx] = permute_OLMoE_mlp_dense_expert_(mlp.experts[expert_idx], perm)
            elif permute_strategy == "activation-matching":
                group_hidden_states = torch.cat(
                    [forwarded_hidden_states[member] for member in members], dim=0
                )
                for expert_idx in followers:
                    perm = compute_OLMoE_permutation_by_activation_matching(
                        reference_mlp=mlp.experts[anchor_idx],
                        target_mlp=mlp.experts[expert_idx],
                        forwarded_hidden_states=group_hidden_states,
                    )
                    mlp.experts[expert_idx] = permute_OLMoE_mlp_dense_expert_(mlp.experts[expert_idx], perm)
            else:
                raise ValueError(f"Unknown permute strategy: {permute_strategy}")

        with torch.no_grad():
            # All means are computed before any copy so no average sees an already-merged weight.
            averaged = {
                proj: torch.mean(
                    torch.stack([getattr(mlp.experts[member], proj).weight for member in members]),
                    dim=0,
                )
                for proj in projection_names
            }
            anchor = mlp.experts[anchor_idx]
            for proj, weight in averaged.items():
                getattr(anchor, proj).weight.copy_(weight)
            # Tie the remaining slots to the merged anchor module.
            for expert_idx in followers:
                mlp.experts[expert_idx] = anchor

    return mlp


def _merge_mlp_experts_by_fisher_weighted_averaging(
        mlp: OlmoeMLP,
        group_labels: torch.LongTensor,
        name_prefix: str,
        experts_fisher_state_dict: Dict[str, torch.Tensor],
        permute: Optional[bool] = False
) -> OlmoeMLP:
    """
    Collapse each group of experts into one expert by Fisher-weighted averaging of its weights.

    For each projection ``P`` in (``up_proj``, ``down_proj``, ``gate_proj``) the merged weight is
    ``sum_k F_k * W_k / (sum_k F_k + FP32_EPS)`` over the group members ``k``. It is written into the
    group's first ("anchor") expert, and every other slot of the group is re-pointed at the anchor.

    This previously used Switch-Transformer layout: string keys ``f"expert_{idx}"`` built from tensor
    indices (yielding ``"expert_tensor(3)"``) and ``wi`` / ``wo`` projections. It now uses the OLMoE
    layout used by the other merge functions in this file: integer ``experts[idx]`` and
    ``up_proj`` / ``down_proj`` / ``gate_proj``.

    Parameters
    ----------
    mlp:
        Module with an indexable ``experts`` container of gated MLP experts.
    group_labels:
        One label per expert.
    name_prefix:
        Prefix of the Fisher keys, presumably ``"model.layers.{idx}.mlp"``. Confirm against callers.
    experts_fisher_state_dict:
        Fisher information keyed ``f"{name_prefix}.experts.{idx}.{proj}.weight"``, matching
        ``model.state_dict()`` naming; each value must broadcast with the corresponding weight.
    permute:
        Align followers to the anchor by weight matching before averaging.
        NOTE(review): the Fisher tensors are not permuted alongside the weights. Confirm this is
        acceptable.

    Returns
    -------
    The same ``mlp`` object, modified in place.
    """
    projection_names = ("up_proj", "down_proj", "gate_proj")

    for label in group_labels.unique():
        # Plain ints so ModuleList indexing and the key f-strings render "3", not "tensor(3)".
        expert_indices = torch.where(group_labels == label)[0].tolist()
        anchor_idx = expert_indices[0]

        if permute:
            for expert_idx in expert_indices[1:]:
                perm = compute_OLMoE_permutation_by_weight_matching(
                    reference_mlp=mlp.experts[anchor_idx],
                    target_mlp=mlp.experts[expert_idx],
                    include_wo=True
                )
                mlp.experts[expert_idx] = permute_OLMoE_mlp_dense_expert_(mlp.experts[expert_idx], perm)

        with torch.no_grad():
            merged = dict()
            for proj in projection_names:
                weights = torch.stack(
                    [getattr(mlp.experts[idx], proj).weight for idx in expert_indices], dim=0
                )
                fishers = torch.stack(
                    [experts_fisher_state_dict[f"{name_prefix}.experts.{idx}.{proj}.weight"]
                     for idx in expert_indices], dim=0
                )
                merged[proj] = torch.sum(weights * fishers, dim=0) / (torch.sum(fishers, dim=0) + FP32_EPS)

            anchor = mlp.experts[anchor_idx]
            for proj, weight in merged.items():
                getattr(anchor, proj).weight.copy_(weight)
            for expert_idx in expert_indices[1:]:
                # Binding merged experts to the first of them
                mlp.experts[expert_idx] = anchor
    return mlp


def _prune_mlp_experts_by_replacing_non_core(
        mlp: OlmoeMLP,
        group_labels: torch.LongTensor,
        core_expert_indices: List[int],
) -> OlmoeMLP:
    """
    Prune non-core experts by pointing their slots at their group's core expert.

    For each core index, every other expert sharing the core's label has its ``mlp.experts`` slot
    replaced by the core module object. Weights are not copied; the slots share one module.

    Returns the same ``mlp`` object, modified in place.
    """
    with torch.no_grad():
        for core_idx in core_expert_indices:
            same_group = torch.where(group_labels == group_labels[core_idx])[0]
            core_expert = mlp.experts[core_idx]
            for member in same_group:
                if member != core_idx:
                    mlp.experts[member] = core_expert

    return mlp


def _prune_mlp_experts_by_dropping_non_core(
        mlp: OlmoeMLP,
        core_expert_indices: List[int],
) -> OlmoeMLP:
    """
    Drop non-core experts: mask their router logits and replace their modules with placeholders.

    The function:

    * builds a boolean mask that is True for every expert not listed in ``core_expert_indices``;
    * rebinds ``mlp.gate.forward`` to ``_custom_forward``. That forward fills masked logits with
      -1e6 and returns ``(one-hot argmax expert index, max routing probability, masked logits)``;
    * replaces every non-core entry of ``mlp.experts`` with a ``torch.nn.Linear(in_features, 1)``
      placeholder.

    NOTE(review): this looks carried over from a Switch-Transformer implementation. It reads
    ``mlp.router.classifier`` and, on the gate, ``_compute_router_probabilities`` / ``dtype`` /
    ``input_dtype`` / ``num_experts``. The other OLMoE code in this file reads the router as
    ``{mlp_name}.gate.weight``, and ``_custom_forward`` returns a 3-tuple where an ``nn.Linear``
    gate returns one tensor. Confirm this works with the OLMoE sparse block before using
    ``strategy="drop"``.

    Parameters
    ----------
    mlp: OlmoeMLP
        The MoE block to prune in place.
    core_expert_indices: List[int]
        Experts that stay routable.

    Returns
    -------
    The same ``mlp`` object, modified in place.
    """
    # True for experts that must never be routed to.
    non_core_expert_mask = torch.ones(len(mlp.experts), dtype=torch.bool)
    non_core_expert_mask[core_expert_indices] = False
    non_core_expert_indices = torch.where(non_core_expert_mask)[0]
    # Large negative logit so masked experts get ~0 probability after softmax.
    mask_value = -1e6

    def _custom_forward(self, hidden_states: torch.Tensor) -> Tuple:
        _, router_logits = self._compute_router_probabilities(hidden_states)

        router_logits = router_logits.masked_fill(non_core_expert_mask.to(router_logits.device), mask_value)
        router_probs = F.softmax(router_logits, dim=-1, dtype=self.dtype).to(self.input_dtype)

        # Top-1 routing: one-hot of the most probable (unmasked) expert.
        expert_index = torch.argmax(router_probs, dim=-1)
        expert_index = F.one_hot(expert_index, num_classes=self.num_experts)
        # Since experts are pruned, no need to do max-capacity dropping
        router_probs = torch.max(router_probs, dim=-1).values.unsqueeze(-1)
        return expert_index, router_probs, router_logits

    # Bind _custom_forward as a method of this gate instance only.
    mlp.gate.forward = _custom_forward.__get__(mlp.gate, torch.nn.Linear)

    for expert_idx in non_core_expert_indices:
        # fake expert, this will raise error if used by mistake
        mlp.experts[expert_idx] = torch.nn.Linear(mlp.router.classifier.in_features, 1)

    return mlp


def prune_non_core_experts_by_groups(
        model: OlmoeForCausalLM,
        grouper: ExpertsGrouperForOLMoE,
        core_experts: Dict[str, List[int]],
        strategy: str = "replace",
        merging_layers: Optional[List[int]] = None,
) -> OlmoeForCausalLM:
    """
    Prune non-core experts by replacing them with core experts, or by dropping them.

    The difference is routing. With "replace", tokens routed to a pruned expert are served by its
    group's core. With "drop", pruned experts are masked out of routing.

    Parameters
    ----------
    model: OlmoeForCausalLM
        The model to prune; its decoder layers are read from ``model.model.layers``.
    grouper: ExpertsGrouperForOLMoE
        The grouper to group experts. It should already have run ``grouper.compute_all_usages()``
        and one of the grouping methods, so that it holds group labels.
    core_experts: Dict[str, List[int]]
        The core experts dict, normally returned by
        ``grouper.group_experts_into_clusters_by_routing_guided_globally()``.
    strategy: str
        One of ["replace", "drop"].
    merging_layers: Optional[List[int]]
        The layers to prune; if None, prune all layers.

    Returns
    -------
    The same ``model`` object, modified in place.

    Raises
    ------
    ValueError
        If ``strategy`` is not "replace" or "drop". Previously an unknown value silently did nothing.
    """
    if strategy not in ("replace", "drop"):
        raise ValueError(f"[Merging]strategy should be one of ['replace', 'drop'], got {strategy} instead.")

    # OlmoeForCausalLM keeps its decoder layers under model.model.layers; model.layers does not exist.
    layers = model.model.layers

    if strategy == "replace":
        # One deep copy of the labels for all layers, instead of one per layer.
        group_state_dict = grouper.group_state_dict()
        for layer_idx in tqdm(grouper.sparse_layer_indices, desc="[Merging]Pruning non-core experts by replacing..."):
            if merging_layers is not None and layer_idx not in merging_layers:
                continue
            mlp_name = f"model.layers.{layer_idx}.mlp"
            layers[layer_idx].mlp = _prune_mlp_experts_by_replacing_non_core(
                mlp=layers[layer_idx].mlp,
                group_labels=group_state_dict[mlp_name],
                core_expert_indices=core_experts[mlp_name]
            )
    else:
        for layer_idx in tqdm(grouper.sparse_layer_indices, desc="[Merging]Pruning non-core experts by dropping..."):
            if merging_layers is not None and layer_idx not in merging_layers:
                continue
            mlp_name = f"model.layers.{layer_idx}.mlp"
            layers[layer_idx].mlp = _prune_mlp_experts_by_dropping_non_core(
                mlp=layers[layer_idx].mlp,
                core_expert_indices=core_experts[mlp_name]
            )
    return model

def _merge_mlp_experts_by_usage_frequency_weighting(
        mlp: "OlmoeSparseMoeBlock",
        group_labels: torch.LongTensor,
        usage_frequencies: torch.Tensor,
        permute: bool,
) -> "OlmoeSparseMoeBlock":
    """
    Merge the experts of one sparse MoE block group by group with usage-frequency-weighted averaging.

    For every group (experts sharing a label in `group_labels`), the `up_proj`, `down_proj` and
    `gate_proj` weights of its experts are averaged with coefficients proportional to their usage
    frequencies. The result is written into the first expert of the group. Every other slot of the group
    in `mlp.experts` is then re-bound to that same module, so the group shares one set of parameters.

    Parameters
    ----------
    mlp: OlmoeSparseMoeBlock
        The MoE block whose `experts` list is merged in place.
    group_labels: torch.LongTensor
        Group label of every expert, indexed by expert position.
    usage_frequencies: torch.Tensor
        Usage frequency of every expert, indexed by expert position.
    permute: bool
        If True, align each expert of a group to the group's first expert by weight matching
        before averaging.

    Returns
    -------
    OlmoeSparseMoeBlock
        The same `mlp` object, modified in place.
    """
    for label in group_labels.unique():
        expert_indices = torch.where(group_labels == label)[0]
        num_group_experts = expert_indices.numel()
        if num_group_experts == 1:
            # Nothing to merge. Leave the expert untouched rather than rescaling it by f / (f + eps),
            # which would zero it out when its usage frequency is 0.
            continue

        anchor_idx = expert_indices[0]
        if permute:
            for expert_idx in expert_indices[1:]:
                perm = compute_OLMoE_permutation_by_weight_matching(
                    reference_mlp=mlp.experts[anchor_idx],
                    target_mlp=mlp.experts[expert_idx],
                    include_wo=True
                )
                mlp.experts[expert_idx] = permute_OLMoE_mlp_dense_expert_(
                    mlp.experts[expert_idx], perm
                )

        group_usages = usage_frequencies[expert_indices]
        usage_sum = torch.sum(group_usages, dim=0)
        if usage_sum.item() > 0:
            coefs = group_usages / (usage_sum + FP32_EPS)
        else:
            # Every expert in the group is unused. A frequency-weighted average would be all zeros,
            # so fall back to a plain average.
            coefs = torch.full((num_group_experts,), 1.0 / num_group_experts, device=group_usages.device)

        with torch.no_grad():
            anchor = mlp.experts[anchor_idx]
            for proj_name in ("up_proj", "down_proj", "gate_proj"):
                # The stack materializes every (weighted) weight, including the anchor's own,
                # before the anchor is overwritten.
                merged_weight = torch.stack([
                    getattr(mlp.experts[expert_idx], proj_name).weight * coef
                    for expert_idx, coef in zip(expert_indices, coefs)
                ], dim=0).sum(dim=0)
                getattr(anchor, proj_name).weight.copy_(merged_weight)

            for expert_idx in expert_indices[1:]:
                # Bind the merged experts to the first of them.
                mlp.experts[expert_idx] = anchor
    return mlp


def _merge_mlp_experts_by_weighting_act(
        mlp: "OlmoeSparseMoeBlock",
        group_labels: torch.LongTensor,
        usage_frequencies: torch.Tensor,
        composed_matrixes: List[torch.Tensor]
) -> "OlmoeSparseMoeBlock":
    """
    Merge grouped experts, fitting the merged `down_proj` to the experts' recorded activations.

    `up_proj` and `gate_proj` are merged by usage-frequency-weighted averaging. For `down_proj`, the
    merged intermediate activation `act_fn(merged_gate) * merged_up` is computed from the weighted
    average of the recorded gate/up activations. A least-squares map from it to each expert's own
    intermediate activation is then solved. The merged `down_proj` is the usage-weighted sum of
    `W_i @ X_i^T`, where `X_i` is the block of the least-squares solution belonging to expert `i`.
    The result is written into the first expert of each group, and the group's other slots are
    re-bound to it.

    Parameters
    ----------
    mlp: OlmoeSparseMoeBlock
        The MoE block whose `experts` list is merged in place.
    group_labels: torch.LongTensor
        Group label of every expert, indexed by expert position.
    usage_frequencies: torch.Tensor
        Usage frequency of every expert, indexed by expert position.
    composed_matrixes: List[torch.Tensor]
        `[gate_acts, up_acts]`, each indexed by expert and holding a 2-D (samples, intermediate)
        activation matrix per expert. They are presumably the gate_proj / up_proj outputs collected
        by `grouper.compute_layer_act`; confirm with the grouper.

    Returns
    -------
    OlmoeSparseMoeBlock
        The same `mlp` object, modified in place.
    """
    with torch.no_grad():
        gate_acts = composed_matrixes[0]
        up_acts = composed_matrixes[1]
        act_fn = mlp.experts[0].act_fn
        # Per-expert input of down_proj: act_fn(gate) * up.
        original_acts = act_fn(gate_acts) * up_acts
        down_dtype = mlp.experts[0].down_proj.weight.dtype

        for label in group_labels.unique():
            expert_indices = torch.where(group_labels == label)[0]
            num_group_experts = expert_indices.numel()
            if num_group_experts == 1:
                # Nothing to merge; do not rescale a singleton expert by f / (f + eps).
                continue

            group_usages = usage_frequencies[expert_indices]
            usage_sum = torch.sum(group_usages, dim=0)
            if usage_sum.item() > 0:
                coefs = group_usages / (usage_sum + FP32_EPS)
            else:
                # All experts are unused: fall back to a plain average instead of all-zero weights.
                coefs = torch.full((num_group_experts,), 1.0 / num_group_experts, device=group_usages.device)
            group = list(zip(expert_indices, coefs))

            merged_gate_acts = torch.stack([gate_acts[idx] * coef for idx, coef in group], dim=0).sum(dim=0)
            merged_up_acts = torch.stack([up_acts[idx] * coef for idx, coef in group], dim=0).sum(dim=0)
            merged_acts = act_fn(merged_gate_acts) * merged_up_acts
            sample_num, intermediate_size = merged_acts.size()

            # (k, samples, inter) -> (samples, k * inter): column block i holds expert i's activations.
            unmerged_acts = original_acts[expert_indices].permute(1, 0, 2).reshape(sample_num, -1)
            # Solve merged_acts @ X ~= unmerged_acts; X has shape (inter, k * inter).
            solution = torch.linalg.lstsq(merged_acts.to(torch.float), unmerged_acts.to(torch.float)).solution
            solution = solution.T.to(down_dtype)

            down_proj_weight = torch.zeros_like(mlp.experts[0].down_proj.weight)
            for i, (idx, coef) in enumerate(group):
                down_proj_weight += torch.matmul(
                    mlp.experts[idx].down_proj.weight * coef,
                    solution[i * intermediate_size:(i + 1) * intermediate_size],
                )

            up_proj_weight = torch.stack(
                [mlp.experts[idx].up_proj.weight * coef for idx, coef in group], dim=0
            ).sum(dim=0)
            gate_proj_weight = torch.stack(
                [mlp.experts[idx].gate_proj.weight * coef for idx, coef in group], dim=0
            ).sum(dim=0)

            anchor = mlp.experts[expert_indices[0]]
            anchor.up_proj.weight.copy_(up_proj_weight)
            anchor.gate_proj.weight.copy_(gate_proj_weight)
            anchor.down_proj.weight.copy_(down_proj_weight)

            for idx in expert_indices[1:]:
                # Bind the merged experts to the first of them.
                mlp.experts[idx] = anchor
    return mlp

def _merge_mlp_experts_by_linear(
        mlp: nn.Module,
        group_labels: torch.LongTensor,
        usage_frequencies: torch.Tensor,

) -> nn.Module:
    """
    Write a group-wise, usage-weighted mixing matrix into `mlp.merge.weight`.

    The matrix `M` has shape (num_experts, num_experts). For experts `i` and `j` in the same group,
    `M[i, j]` is expert `j`'s usage frequency normalized over the group; it is 0 across groups.
    How the `merge` module applies `M` is defined by the custom MoE block, which this function
    does not show.

    Parameters
    ----------
    mlp: nn.Module
        The MoE block of a `MyOlmoeForCausalLM` layer; it must own a `merge` module with a square
        `weight`.
    group_labels: torch.LongTensor
        Group label of every expert, indexed by expert position.
    usage_frequencies: torch.Tensor
        Usage frequency of every expert, indexed by expert position.

    Returns
    -------
    nn.Module
        The same `mlp` object with its merge matrix overwritten.
    """
    # Work on CPU so the index tensors and the merge matrix always share a device.
    labels = group_labels.reshape(-1).cpu()
    frequencies = usage_frequencies.reshape(-1).cpu()
    expert_num = frequencies.shape[0]
    merge_matrix = torch.zeros(expert_num, expert_num, dtype=mlp.merge.weight.dtype)

    for label in labels.unique():
        expert_indices = torch.where(labels == label)[0]
        num_group_experts = expert_indices.numel()
        group_usages = frequencies[expert_indices]
        usage_sum = torch.sum(group_usages)
        if usage_sum.item() > 0:
            coefs = group_usages / (usage_sum + FP32_EPS)
        else:
            # All experts of the group are unused: use a plain average instead of an all-zero row.
            coefs = torch.full((num_group_experts,), 1.0 / num_group_experts)
        # Every row of the group gets the same coefficients over the group's columns.
        merge_matrix[expert_indices.unsqueeze(1), expert_indices.unsqueeze(0)] = (
            coefs.to(merge_matrix.dtype).unsqueeze(0)
        )

    with torch.no_grad():
        mlp.merge.weight.copy_(merge_matrix)

    return mlp
        

def _OLMoE_merge_mlp_experts_within_and_across_models(
        mlp: "OlmoeSparseMoeBlock",
        group_labels: torch.LongTensor,
        forwarded_hidden_states: Tuple[torch.Tensor],
        dominant_alone: bool,
        core_expert_indices: Optional[List[int]] = None,
        usage_frequencies: Optional[torch.Tensor] = None,
) -> "OlmoeSparseMoeBlock":
    """
    Merge grouped experts within and across models via activation matching.

    Parameters
    ----------
    mlp: OlmoeSparseMoeBlock
        The MoE block whose `experts` list is merged in place.
    group_labels: torch.LongTensor
        The group labels of experts.
    forwarded_hidden_states: Tuple[torch.Tensor]
        The forwarded hidden states of each expert; should be of length num_experts.
    dominant_alone: bool
        Whether to merge the dominant (core) experts alone.
        If True, the merging process in a group is done in two steps:
            1. Merge all experts except the core ones.
            2. Merge the core experts with the merged expert from step 1.
        Groups that contain no core expert fall back to merging all of their experts at once.
    core_expert_indices: Optional[List[int]]
        Indices of the core experts of this layer; required when `dominant_alone` is True.
    usage_frequencies: Optional[torch.Tensor]
        Per-expert usage frequencies used as averaging coefficients; None means no weighting is passed.

    Returns
    -------
    mlp: OlmoeSparseMoeBlock
        The merged mlp. The first expert of each group holds the merged weights, and the other
        slots of the group are bound to it.

    Raises
    ------
    ValueError
        If `dominant_alone` is True but `core_expert_indices` is None.
    """
    if dominant_alone and core_expert_indices is None:
        raise ValueError("[Merging]dominant_alone is True, but core_expert_indices is None")
    core_set = {int(idx) for idx in core_expert_indices} if core_expert_indices is not None else set()

    for label in group_labels.unique():
        expert_list = torch.where(group_labels == label)[0].tolist()
        group_core = [idx for idx in expert_list if idx in core_set]
        group_non_core = [idx for idx in expert_list if idx not in core_set]
        with torch.no_grad():
            if dominant_alone and group_core:
                if not group_non_core:
                    # Every expert in the group is a core expert: keep the first one as-is.
                    merged_expert = mlp.experts[expert_list[0]]
                elif (usage_frequencies is not None and len(group_core) == 1
                      and torch.sum(usage_frequencies[group_non_core]).item() == 0):
                    # Non-core experts are never used: the single core expert represents the group.
                    merged_expert = mlp.experts[group_core[0]]
                else:
                    # Stage 1: merge all experts except the dominant ones.
                    non_core_usages = (
                        usage_frequencies[group_non_core] if usage_frequencies is not None else None
                    )
                    merged_expert = merge_olmoe_mlp_by_activation_matching_within_and_across_models(
                        mlp_list=[mlp.experts[idx] for idx in group_non_core],
                        forwarded_hidden_states=torch.cat(
                            [forwarded_hidden_states[idx] for idx in group_non_core], dim=0),
                        average_coefs=non_core_usages.tolist() if non_core_usages is not None else None
                    )
                    # Stage 2: merge the dominant experts with the stage-1 result, which carries
                    # the combined usage of all non-core experts.
                    if non_core_usages is not None:
                        stage2_coefs = ([torch.sum(non_core_usages).item()]
                                        + usage_frequencies[group_core].tolist())
                    else:
                        stage2_coefs = None
                    merged_expert = merge_olmoe_mlp_by_activation_matching_within_and_across_models(
                        mlp_list=[merged_expert] + [mlp.experts[idx] for idx in group_core],
                        forwarded_hidden_states=torch.cat(
                            [forwarded_hidden_states[idx] for idx in expert_list], dim=0),
                        average_coefs=stage2_coefs
                    )
            else:
                # Merge all experts in the group (also the fallback for groups without a core expert).
                merged_expert = merge_olmoe_mlp_by_activation_matching_within_and_across_models(
                    mlp_list=[mlp.experts[idx] for idx in expert_list],
                    forwarded_hidden_states=torch.cat(
                        [forwarded_hidden_states[idx] for idx in expert_list], dim=0),
                    average_coefs=(usage_frequencies[expert_list].tolist()
                                   if usage_frequencies is not None else None)
                )

            anchor = mlp.experts[expert_list[0]]
            anchor.up_proj.weight.copy_(merged_expert.up_proj.weight)
            anchor.down_proj.weight.copy_(merged_expert.down_proj.weight)
            anchor.gate_proj.weight.copy_(merged_expert.gate_proj.weight)

            for idx in expert_list[1:]:
                # Bind the merged experts to the first of them.
                mlp.experts[idx] = anchor

    return mlp



def OLMoE_merge_by_Linear(
        model: OlmoeForCausalLM,
        grouper: ExpertsGrouperForOLMoE,
        merging_layers: Optional[List[int]],
        batch: Dict[str, torch.Tensor],
        config_path: Optional[str] = None,
) -> MyOlmoeForCausalLM:
    """
    Build a `MyOlmoeForCausalLM` copy of `model` and set each merged layer's linear merge matrix.

    Parameters
    ----------
    model: OlmoeForCausalLM
        The source model; its weights are loaded into the new model via `load_pretrained_weights`.
    grouper: ExpertsGrouperForOLMoE
        Provides per-layer group labels and usage frequencies.
    merging_layers: Optional[List[int]]
        The layers to merge experts; if None, merge all sparse layers.
    batch: Dict[str, torch.Tensor]
        Unused; kept for signature compatibility with the other merge entry points.
    config_path: Optional[str]
        Optional path to an OLMoE config to build the new model from. By default a copy of
        `model.config` is used, so no machine-specific path is required.

    Returns
    -------
    MyOlmoeForCausalLM
        The new model, with `mlp.merge.weight` set for every merged layer.
    """
    usage_frequency_dict = grouper.usage_frequency_state_dict()
    if config_path is not None:
        config = OlmoeConfig.from_pretrained(config_path)
    else:
        # Copy so that the new model's construction cannot mutate the source model's config.
        config = deepcopy(model.config)
    new_model = MyOlmoeForCausalLM(config)
    new_model = load_pretrained_weights(new_model, model)

    group_label_dict = grouper.group_state_dict()
    for layer_idx in tqdm(
            grouper.sparse_layer_indices[::-1],
            desc="[Merging]Merging experts with linear merge matrix..."
    ):
        if merging_layers is not None and layer_idx not in merging_layers:
            continue

        mlp_name = f"model.layers.{layer_idx}.mlp"
        new_model.model.layers[layer_idx].mlp = _merge_mlp_experts_by_linear(
            mlp=new_model.model.layers[layer_idx].mlp,
            group_labels=group_label_dict[mlp_name],
            usage_frequencies=usage_frequency_dict[mlp_name],
        )

    return new_model

def OLMoE_merge_by_groups_with_ACT(
        model: OlmoeForCausalLM,
        grouper: ExpertsGrouperForOLMoE,
        merging_layers: Optional[List[int]],
        batch: Dict[str, torch.Tensor],
) -> OlmoeForCausalLM:
    """
    Merge the experts of `model` group by group, fitting down_proj to recorded activations.

    Sparse layers are visited from last to first. For every selected layer, the grouper records
    that layer's activations on `batch`, and `_merge_mlp_experts_by_weighting_act` merges the
    layer's MoE block using those activations, the group labels and the usage frequencies.

    Args:
        model: The model to be merged (modified in place).
        grouper: Expert grouper holding group labels, usage frequencies and activation capture.
        merging_layers: Layer indices to merge; None merges every sparse layer.
        batch: Input batch fed to `grouper.compute_layer_act`.

    Returns:
        The same model, with the experts of every group bound to one merged expert.
    """
    frequencies_by_mlp = grouper.usage_frequency_state_dict()
    layers_back_to_front = grouper.sparse_layer_indices[::-1]

    for layer_idx in tqdm(layers_back_to_front, desc="[Merging]Merging experts with act..."):
        if merging_layers is not None and layer_idx not in merging_layers:
            continue

        grouper.compute_layer_act(model=model, merging_layer_idx=layer_idx, batch=batch)
        mlp_name = f"model.layers.{layer_idx}.mlp"
        activations = grouper.get_composed_matrixes(mlp_name)
        labels = grouper.group_state_dict()[mlp_name]
        decoder_layer = model.model.layers[layer_idx]
        decoder_layer.mlp = _merge_mlp_experts_by_weighting_act(
            mlp=decoder_layer.mlp,
            group_labels=labels,
            usage_frequencies=frequencies_by_mlp[mlp_name],
            composed_matrixes=activations,
        )

    return model

'''
MIT License
Copyright (c) 2023 UNITES Lab
This function is modified from (https://github.com/UNITES-Lab/MC-SMoE)
'''
def OLMoE_merge_by_groups_with_usage_frequency_weighting(
        model: OlmoeForCausalLM,
        grouper: ExpertsGrouperForOLMoE,
        strategy: str = "normal",
        merging_layers: Optional[List[int]] = None,
        permute: Optional[bool] = False,
        within_and_across_models: Optional[bool] = False,
) -> OlmoeForCausalLM:
    """
    Merge experts by usage-frequency-weighted averaging. Strategies include:
        1. normal: merge experts in each group by usage-frequency-weighted averaging.
        2. reversed: reverse usage frequencies by 1 - usage_frequency and merge experts in each group by
                        usage-frequency-weighted averaging.
        3. random: randomly initialize usage frequencies and merge experts in each group by
                        usage-frequency-weighted averaging.

    Parameters
    ----------
    model: OlmoeForCausalLM
        The model to merge experts (modified in place and returned).
    grouper: ExpertsGrouperForOLMoE
        The grouper to group experts. It is expected to have run `grouper.compute_all_usages()` and
            one of `grouper.group_experts()` (i.e. to have group labels).
    strategy: str
        The strategy to merge experts, one of ["normal", "reversed", "random"].
    merging_layers: Optional[List[int]]
        The layers to merge experts; if None, merge all layers.
    permute: Optional[bool]
        Whether to permute the experts in the same group before averaging.
    within_and_across_models: Optional[bool]
        Not supported by this function. A warning is emitted when it is set; use
        `OLMoE_merge_by_groups_within_and_across_models` for that merging mode.

    Raises
    ------
    ValueError
        If `strategy` is not one of the supported strategies.
    """
    if permute:
        print("[Merging]Permutation is enabled, will permute experts in the same group.")
    if strategy not in ("normal", "reversed", "random"):
        raise ValueError(f"[Merging]Unknown strategy {strategy}")
    if within_and_across_models:
        import warnings
        warnings.warn(
            "[Merging]`within_and_across_models` is ignored by usage-frequency-weighted merging; "
            "use `OLMoE_merge_by_groups_within_and_across_models` for that mode.",
            stacklevel=2,
        )

    usage_frequency_dict = grouper.usage_frequency_state_dict()
    if strategy == "reversed":
        usage_frequency_dict = {key: 1 - value for key, value in usage_frequency_dict.items()}
    elif strategy == "random":
        usage_frequency_dict = {key: torch.rand_like(value) for key, value in usage_frequency_dict.items()}

    group_label_dict = grouper.group_state_dict()
    for layer_idx in tqdm(
            grouper.sparse_layer_indices,
            desc=f"[Merging]Merging experts with {strategy} usage-frequency-weighted averaging..."
    ):
        if merging_layers is not None and layer_idx not in merging_layers:
            continue
        mlp_name = f"model.layers.{layer_idx}.mlp"
        model.model.layers[layer_idx].mlp = _merge_mlp_experts_by_usage_frequency_weighting(
            mlp=model.model.layers[layer_idx].mlp,
            group_labels=group_label_dict[mlp_name],
            usage_frequencies=usage_frequency_dict[mlp_name],
            permute=permute
        )

    return model


def OLMoE_merge_by_groups(
        model: OlmoeForCausalLM,
        grouper: ExpertsGrouperForOLMoE,
        merging_layers: Optional[List[int]] = None,
        permute: Optional[bool] = False,
        permute_strategy: Optional[str] = "weight-matching",
        dataloader: Optional[DataLoader] = None,
) -> OlmoeForCausalLM:
    """
    Merge the experts of every group with `_merge_mlp_experts_by_averaging`.

    Parameters
    ----------
    model: OlmoeForCausalLM
        The model to merge experts (modified in place and returned).
    grouper: ExpertsGrouperForOLMoE
        The grouper to group experts. It is expected to have run `grouper.compute_all_usages()` and
            one of `grouper.group_experts()` (i.e. to have group labels).
    merging_layers: Optional[List[int]]
        The layers to merge experts; if None, merge all layers.
    permute: Optional[bool]
        Forwarded to `_merge_mlp_experts_by_averaging`.
    permute_strategy: Optional[str]
        Forwarded to `_merge_mlp_experts_by_averaging`. When it is "activation-matching", the inputs of
        every sparse MoE block are first collected on `dataloader` and passed on per expert.
    dataloader: Optional[DataLoader]
        The dataloader used to compute activations. Required when `permute_strategy` is
        "activation-matching".

    Raises
    ------
    ValueError
        If `permute_strategy` is "activation-matching" but `dataloader` is None.
    """
    use_activations = permute_strategy == "activation-matching"
    if use_activations and dataloader is None:
        raise ValueError("[Merging]permute_strategy is 'activation-matching', but dataloader is None")

    # {mlp_name: [one entry per batch]}; each entry is the block input flattened to (-1, hidden_size).
    forwarded_hidden_states = dict()
    # {mlp_name: [one entry per batch]}; each entry holds the top-k selected expert indices per token.
    router_indices = dict()
    if use_activations:
        model.eval().cuda()
        handles = []

        def _get_activation_hook(name):
            def hook(module, inputs, output):
                forwarded_hidden_states[name].append(inputs[0].detach().reshape(-1, inputs[0].shape[-1]))

            return hook

        try:
            for layer_idx in tqdm(
                    grouper.sparse_layer_indices,
                    desc="[Merging]Registering forward hook..."
            ):
                mlp_name = f"model.layers.{layer_idx}.mlp"
                forwarded_hidden_states[mlp_name] = []
                handles.append(model.model.layers[layer_idx].mlp.register_forward_hook(
                    _get_activation_hook(mlp_name))
                )

            router_indices = {name: [] for name in forwarded_hidden_states}
            with torch.no_grad():
                for batch in tqdm(dataloader, desc="[Merging]Computing activations..."):
                    batch = {k: v.cuda() for k, v in batch.items()}
                    outputs = model(**batch, output_router_logits=True)
                    for layer_idx in grouper.sparse_layer_indices:
                        routing_weights = F.softmax(outputs.router_logits[layer_idx], dim=1, dtype=torch.float)
                        _, selected_experts = torch.topk(routing_weights, grouper.num_experts_per_tok, dim=-1)
                        router_indices[f"model.layers.{layer_idx}.mlp"].append(selected_experts)
        finally:
            # Always detach the hooks, even if a forward pass fails, so the model is left unhooked.
            for handle in handles:
                handle.remove()

    num_experts = grouper.num_experts
    group_label_dict = grouper.group_state_dict()

    for layer_idx in tqdm(grouper.sparse_layer_indices,
                          desc="[Merging]Merging experts with averaging..."):
        if merging_layers is not None and layer_idx not in merging_layers:
            continue
        mlp_name = f"model.layers.{layer_idx}.mlp"
        group_labels = group_label_dict[mlp_name]
        if use_activations:
            batch_pairs = list(zip(forwarded_hidden_states[mlp_name], router_indices[mlp_name]))
            # For each expert, the block inputs of every token that routed to it (in any top-k slot).
            layer_forwarded_hidden_states = tuple(
                torch.cat([hidden[(routed == expert_idx).any(dim=1)] for hidden, routed in batch_pairs], dim=0)
                for expert_idx in range(num_experts)
            )
            model.model.layers[layer_idx].mlp = _merge_mlp_experts_by_averaging(
                mlp=model.model.layers[layer_idx].mlp,
                group_labels=group_labels,
                permute=permute,
                permute_strategy=permute_strategy,
                forwarded_hidden_states=layer_forwarded_hidden_states
            )
        else:
            model.model.layers[layer_idx].mlp = _merge_mlp_experts_by_averaging(
                mlp=model.model.layers[layer_idx].mlp,
                group_labels=group_labels,
                permute=permute,
                permute_strategy=permute_strategy
            )

    del forwarded_hidden_states, router_indices
    if use_activations:
        torch.cuda.empty_cache()
    return model


def OLMoE_merge_by_groups_within_and_across_models(
        model: OlmoeForCausalLM,
        grouper: ExpertsGrouperForOLMoE,
        dataloader: DataLoader,
        merging_layers: Optional[List[int]] = None,
        dominant_alone: Optional[bool] = False,
        core_experts: Optional[Dict[str, List[int]]] = None,
        usage_weighted: Optional[bool] = False,
) -> OlmoeForCausalLM:
    """
    Merge grouped experts by activation matching, using block inputs collected on `dataloader`.

    Parameters
    ----------
    model: OlmoeForCausalLM
        The model to merge experts (moved to CUDA, set to eval, modified in place and returned).
    grouper: ExpertsGrouperForOLMoE
        Provides group labels, usage frequencies, `num_experts` and `num_experts_per_tok`.
    dataloader: DataLoader
        Batches used to record each sparse block's inputs and each token's top-k experts.
    merging_layers: Optional[List[int]]
        The layers to merge experts; if None, merge all sparse layers.
    dominant_alone: Optional[bool]
        Merge the core experts of each group in a separate second stage (requires `core_experts`).
    core_experts: Optional[Dict[str, List[int]]]
        Core expert indices per MoE block name.
    usage_weighted: Optional[bool]
        Whether to pass usage frequencies as averaging coefficients.

    Raises
    ------
    ValueError
        If `dominant_alone` is True but `core_experts` is None. This is checked before the
        (expensive) activation pass.
    """
    if dominant_alone and core_experts is None:
        raise ValueError("[Merging]dominant_alone is True, but core_experts is None")

    # {mlp_name: [one entry per batch]}; each entry is the block input flattened to (-1, hidden_size).
    forwarded_hidden_states = dict()

    usage_frequencies = grouper.usage_frequency_state_dict()

    model.eval().cuda()
    handles = []

    def _get_activation_hook(name):
        def hook(module, inputs, output):
            forwarded_hidden_states[name].append(inputs[0].detach().reshape(-1, inputs[0].shape[-1]))
        return hook

    # {mlp_name: [one entry per batch]}; each entry holds the top-k selected expert indices per token.
    router_indices = dict()
    try:
        for layer_idx in tqdm(
                grouper.sparse_layer_indices,
                desc="[Merging]Registering forward hook..."
        ):
            mlp_name = f"model.layers.{layer_idx}.mlp"
            forwarded_hidden_states[mlp_name] = []
            handles.append(model.model.layers[layer_idx].mlp.register_forward_hook(
                _get_activation_hook(mlp_name))
            )

        router_indices = {name: [] for name in forwarded_hidden_states}
        with torch.no_grad():
            for batch in tqdm(dataloader, desc="[Merging]Computing activations..."):
                batch = {k: v.cuda() for k, v in batch.items()}
                outputs = model(**batch, output_router_logits=True)
                for layer_idx in grouper.sparse_layer_indices:
                    routing_weights = F.softmax(outputs.router_logits[layer_idx], dim=1, dtype=torch.float)
                    _, selected_experts = torch.topk(routing_weights, grouper.num_experts_per_tok, dim=-1)
                    router_indices[f"model.layers.{layer_idx}.mlp"].append(selected_experts)
    finally:
        # Always detach the hooks, even if a forward pass fails, so the model is left unhooked.
        for handle in handles:
            handle.remove()

    num_experts = grouper.num_experts
    group_label_dict = grouper.group_state_dict()
    for layer_idx in tqdm(
            grouper.sparse_layer_indices,
            desc="[Merging]Merging by groups within and across experts..."
    ):
        if merging_layers is not None and layer_idx not in merging_layers:
            continue
        mlp_name = f"model.layers.{layer_idx}.mlp"
        batch_pairs = list(zip(forwarded_hidden_states[mlp_name], router_indices[mlp_name]))
        # For each expert, the block inputs of every token that routed to it (in any top-k slot).
        layer_forwarded_hidden_states = tuple(
            torch.cat([hidden[(routed == expert_idx).any(dim=1)] for hidden, routed in batch_pairs], dim=0)
            for expert_idx in range(num_experts)
        )
        model.model.layers[layer_idx].mlp = _OLMoE_merge_mlp_experts_within_and_across_models(
            mlp=model.model.layers[layer_idx].mlp,
            group_labels=group_label_dict[mlp_name],
            forwarded_hidden_states=layer_forwarded_hidden_states,
            dominant_alone=dominant_alone,
            core_expert_indices=core_experts[mlp_name] if core_experts is not None else None,
            usage_frequencies=usage_frequencies[mlp_name] if usage_weighted else None,
        )

    del forwarded_hidden_states, router_indices
    torch.cuda.empty_cache()
    return model



class ExpertUsageFrequencyTracker(object):
    """
    A class to track the usage frequencies of experts in the model during the training process.
    """

    def __init__(
            self,
            named_parameters_caller: Callable[[], Iterator[Tuple[str, torch.nn.Parameter]]],
            beta: Optional[float] = 0.9,
            compute_every_n_steps: Optional[int] = 10,
            device: Optional[str] = 'cpu',
    ):
        """

        Parameters
        ----------
        named_parameters_caller: Callable[[], Iterator[Tuple[str, torch.nn.Parameter]]]
            Normally, it is `model.named_parameters`. It is called more than once, so it must return a
            fresh iterator on each call.
        beta: Optional[float]
            The beta parameter in exponential moving average
        compute_every_n_steps: Optional[int]
            Compute usage EMA every n steps
        device: Optional[str]
            The device to store the usage frequency state dict
        """
        self.beta = beta
        # Attribute name kept for backward compatibility; it is the usage-EMA update period.
        self.compute_fisher_every_n_steps = compute_every_n_steps
        self.device = device
        self.exp_expert_usage_frequency_dict = None
        # Negative until the first update; `_update_exp_usage_frequency_state` treats that as "first step".
        self.last_error = -1
        # Must be set BEFORE `_init_usage_frequency_dict`, which computes the real value.
        # Assigning it afterwards used to discard that value.
        self.num_layers = None

        self._init_usage_frequency_dict(named_parameters_caller)

    def usage_frequency_state_dict(self) -> Dict[str, torch.Tensor]:
        """Return an independent (deep) copy of the tracked usage-frequency EMA, keyed by MoE layer name."""
        snapshot = deepcopy(self.exp_expert_usage_frequency_dict)
        return snapshot

    def _init_usage_frequency_dict(self,
                                   named_parameters_caller: Callable[[], Iterator[Tuple[str, torch.nn.Parameter]]]):
        """
        Discover the MoE layers from parameter names and start each with a uniform usage distribution.

        Pass 1 registers an MoE layer for every parameter containing "router.classifier"; the layer
        name is the prefix before ".router". Pass 2 adds one expert slot for every parameter
        containing both "expert_" and "wi", under the prefix before ".experts.expert_".
        Each layer's slots are then normalized into equal frequencies, and `num_layers` is set to
        half the number of discovered layers.

        NOTE(review): these name patterns look SwitchTransformers-specific ("router.classifier",
        "expert_N.wi"); confirm they match the model this tracker is used with.
        """
        usage = OrderedDict()
        self.exp_expert_usage_frequency_dict = usage

        for param_name, _ in named_parameters_caller():
            if "router.classifier" not in param_name:
                continue
            usage[param_name.split(".router")[0]] = []

        for param_name, _ in named_parameters_caller():
            if "expert_" not in param_name or "wi" not in param_name:
                continue
            usage[param_name.split(".experts.expert_")[0]].append(1.0)

        for mlp_name in list(usage.keys()):
            slots = torch.tensor(usage[mlp_name], device=self.device)
            # init as average used frequency
            usage[mlp_name] = slots / slots.sum()
        self.num_layers = len(usage) // 2

    def _update_exp_usage_frequency_state(
            self,
            model_outputs: Union[List[Seq2SeqMoEOutput], Seq2SeqMoEOutput]
    ):
        """
        Fold the expert routing observed in `model_outputs` into the usage-frequency EMA.

        For each tracked MoE layer, the selected expert indices are counted and normalized into a
        frequency vector. The L1 distance to the current EMA, summed over layers, is stored in
        `self.last_error`. On the first update the EMA is replaced by the observed frequencies;
        afterwards it becomes `beta * ema + (1 - beta) * observed`. Every EMA vector is then
        renormalized to sum to 1.

        Parameters
        ----------
        model_outputs: Union[List[Seq2SeqMoEOutput], Seq2SeqMoEOutput]
            One output or a list of outputs. For each layer, element `[layer_idx][1]` of
            `encoder_router_logits` / `decoder_router_logits` is read as the selected expert indices.

        Raises
        ------
        IndexError
            If an observed expert index is not smaller than the number of tracked experts.
        """
        outputs = model_outputs if isinstance(model_outputs, list) else [model_outputs]
        usage_frequency_state_dict = OrderedDict()
        is_first_step = self.last_error < 0
        for name, ema_usage in self.exp_expert_usage_frequency_dict.items():
            # `name` is like "encoder.block.1.layer.1.mlp" or "decoder.block.3.layer.2.mlp" or etc.,
            # Capture the layer index from the name
            layer_idx = int(name.split(".block.")[1].split(".layer.")[0])
            router_attr = "encoder_router_logits" if "encoder" in name else "decoder_router_logits"
            router_expert_index = torch.concat(
                [getattr(op, router_attr)[layer_idx][1].reshape(-1) for op in outputs]
            )
            # Vectorized histogram instead of a Python loop over every routed token.
            counts = torch.bincount(
                router_expert_index.to(device=ema_usage.device, dtype=torch.long),
                minlength=ema_usage.numel(),
            )
            if counts.numel() != ema_usage.numel():
                raise IndexError(
                    f"Expert index {counts.numel() - 1} out of range for {name} "
                    f"with {ema_usage.numel()} tracked experts"
                )
            counts = counts.to(ema_usage.dtype)
            usage_frequency_state_dict[name] = counts / torch.sum(counts)
        self.last_error = sum(
            [torch.sum(torch.abs(usage_frequency_state_dict[name] - self.exp_expert_usage_frequency_dict[name]))
             for name in self.exp_expert_usage_frequency_dict.keys()]
        )
        for name in self.exp_expert_usage_frequency_dict.keys():
            if is_first_step:
                self.exp_expert_usage_frequency_dict[name] = usage_frequency_state_dict[name]
            else:
                self.exp_expert_usage_frequency_dict[name] = self.beta * self.exp_expert_usage_frequency_dict[
                    name] + (1 - self.beta) * usage_frequency_state_dict[name]
        # normalize
        for name in self.exp_expert_usage_frequency_dict.keys():
            self.exp_expert_usage_frequency_dict[name] = self.exp_expert_usage_frequency_dict[name] / torch.sum(
                self.exp_expert_usage_frequency_dict[name])

    def step(self, model_outputs: Union[List[Seq2SeqMoEOutput], Seq2SeqMoEOutput], global_step: int) -> float:
        """
        Refresh the usage-frequency EMA on scheduled steps and report the latest L1 error.

        The EMA is updated only when `global_step` is a multiple of `compute_fisher_every_n_steps`.
        On other steps the previously stored error is returned unchanged.

        Parameters
        ----------
        model_outputs: Union[List[Seq2SeqMoEOutput], Seq2SeqMoEOutput]
            The outputs of the model on dataset for merging
        global_step: int
            The global step of the training process

        Returns
        -------
        float
            The L1 distance between the most recent observed frequencies and the EMA it was
            compared against.
        """
        update_due = global_step % self.compute_fisher_every_n_steps == 0
        if update_due:
            self._update_exp_usage_frequency_state(model_outputs)
        return self.last_error
