import os
from collections import OrderedDict
from copy import deepcopy
from typing import Dict, Optional, List, Union, Callable, Iterator, Tuple

import torch
from torch import nn

from torch.nn import functional as F
from torch.utils.data import DataLoader
from tqdm import tqdm
from transformers import PretrainedConfig

from ..utils.configuration_deepseek import DeepseekConfig
from ..utils.modeling_deepseek import DeepseekForCausalLM, DeepseekMoE




from .permutation import (
    permute_OLMoE_mlp_dense_expert_,
    compute_OLMoE_permutation_by_weight_matching,
    compute_OLMoE_permutation_by_activation_matching,
    merge_olmoe_mlp_by_activation_matching_within_and_across_models,
)
from .utils import generate_random_group_labels

from ..utils.constants import FP32_EPS


# Public API of this module.
# NOTE: every entry needs its own trailing comma -- adjacent string literals are
# silently concatenated by Python, which previously fused the last two names into
# a single, non-existent export.
# NOTE(review): 'Deepseek_merge_by_groups' is expected to be defined further down
# in this module -- confirm.
__all__ = [
    'ExpertsGrouperForDeepseek',
    'LEGAL_SIMILARITY_BASES',
    'SIMILARITY_MAPPING_FUNCTION',
    'Deepseek_merge_by_groups',
    'Deepseek_merge_by_groups_with_usage_frequency_weighting',
]

def _cosine_similarity_01(x, y):
    # Cosine similarity along the last dim, shifted from [-1, 1] to [0, 1].
    return (F.cosine_similarity(x, y, dim=-1, eps=FP32_EPS) + 1) / 2


def _log_mse_similarity(x, y):
    # Reciprocal of (1 + 0.1 * log(sum of squared differences)).
    return 1 / (1 + 0.1 * torch.log(F.mse_loss(x, y, reduction="sum")))


def _sum_squared_error(x, y):
    # Plain sum of squared differences (a distance, not a bounded similarity).
    return F.mse_loss(x, y, reduction="sum")


# Maps the `similarity_fn` name accepted by the grouper to the callable that
# scores a pair of tensors.
SIMILARITY_MAPPING_FUNCTION = {
    "cosine": _cosine_similarity_01,
    "mse": _log_mse_similarity,
    "l2": _sum_squared_error,
}

# Names accepted as `similarity_base` by the grouper's constructor.
LEGAL_SIMILARITY_BASES = [
    "weight",
    "feature",
    "feature.abs",
    "weight-feature",
    "gradient",
    "weight-gradient",
    "router-logits",
    "router-weight",
    "router-weight-feature",
    "mse",
    "random",
    "feature-correlation.lsa",
    "feature-correlation.max",
    "gate-weight",
    "gate-up-weight",
    "gate-act",
]


class ExpertsGrouperForDeepseek(object):
    """Bookkeeping and statistics used to group the routed experts of a Deepseek MoE model.

    For every layer index in ``sparse_layer_indices`` the grouper keeps, keyed by
    ``"model.layers.{idx}.mlp"``:

    * one group label per expert (``_group_state_dict``),
    * a symmetric expert-to-expert similarity matrix (``_similarity_state_dict``),
    * a per-expert usage counter, later normalised to a frequency
      (``_usage_frequency_state_dict``).

    All state tensors are allocated on ``"cuda"``.
    """

    def __init__(
            self,
            config: Union[DeepseekConfig, PretrainedConfig],
            similarity_fn: str = "cosine",
            similarity_base: str = "weight",
    ):
        """
        Parameters
        ----------
        config: Union[DeepseekConfig, PretrainedConfig]
            Model configuration; ``n_routed_experts``, ``num_experts_per_tok``, ``hidden_size`` and
            ``num_hidden_layers`` are read from it.
        similarity_fn: str
            Key of ``SIMILARITY_MAPPING_FUNCTION``.
        similarity_base: str
            One of ``LEGAL_SIMILARITY_BASES``.

        Raises
        ------
        ValueError
            If ``similarity_fn`` or ``similarity_base`` is not a known name.
        """
        if similarity_fn not in SIMILARITY_MAPPING_FUNCTION:
            raise ValueError(
                f"[Merging]similarity_fn should be one of {SIMILARITY_MAPPING_FUNCTION.keys()}, got {similarity_fn} instead.")
        if similarity_base not in LEGAL_SIMILARITY_BASES:
            raise ValueError(
                f"[Merging]similarity_base should be one of {LEGAL_SIMILARITY_BASES}, got {similarity_base} instead.")

        self.num_experts = config.n_routed_experts
        self.num_experts_per_tok = config.num_experts_per_tok
        self.avg_num_merged_experts = self.num_experts
        self.hidden_size = config.hidden_size
        # Every layer except layer 0 is treated as a sparse (MoE) layer.
        # NOTE(review): presumably layer 0 is the only dense layer of the target checkpoint -- confirm.
        self.sparse_layer_indices = list(range(1, config.num_hidden_layers))
        self.similarity_fn = SIMILARITY_MAPPING_FUNCTION[similarity_fn]
        self.similarity_base = similarity_base
        self._group_state_dict = None
        self._similarity_state_dict = None
        self._usage_frequency_state_dict = None
        self.reset_all()

        # Per-layer matrices/activations stored by the `compute_layer_*` methods.
        self.composed_matrixes = dict()

    def reset_all(self):
        """Re-initialise group labels, similarity matrices, usage counters and transport matrices."""
        if self.similarity_base == "mse":
            self.similarity_fn = SIMILARITY_MAPPING_FUNCTION["mse"]
            print("[Merging]Set similarity_fn to mse for mse similarity_base.")
        self._group_state_dict = dict()
        self._similarity_state_dict = dict()
        self._usage_frequency_state_dict = dict()
        for layer_idx in self.sparse_layer_indices:
            mlp_name = f"model.layers.{layer_idx}.mlp"
            # Initially every expert is its own group.
            self._group_state_dict[mlp_name] = torch.arange(self.num_experts,
                                                                    device="cuda")
            # Identity matrix: only the diagonal (self-similarity) starts at 1.
            self._similarity_state_dict[mlp_name] = torch.zeros(
                (self.num_experts, self.num_experts), device="cuda"
            ) + torch.eye(self.num_experts, device="cuda")
            self._usage_frequency_state_dict[mlp_name] = torch.zeros(self.num_experts, device="cuda")

        self.transport_matrices = dict()

    def similarity_state_dict(self) -> Dict[str, torch.Tensor]:
        """Return a deep copy of all similarity matrices."""
        return deepcopy(self._similarity_state_dict)

    def group_state_dict(self) -> Dict[str, torch.LongTensor]:
        """Return a deep copy of all group-label tensors."""
        return deepcopy(self._group_state_dict)

    def usage_frequency_state_dict(self) -> Dict[str, torch.Tensor]:
        """Return a deep copy of all usage-frequency tensors."""
        return deepcopy(self._usage_frequency_state_dict)

    def save_similarity(self, mlp_name: str, i: int, j: int, similarity: float):
        """Store `similarity` symmetrically at (i, j) and (j, i) for layer `mlp_name`."""
        self._similarity_state_dict[mlp_name][i, j] = similarity
        self._similarity_state_dict[mlp_name][j, i] = similarity

    def get_similarity(self, mlp_name: str, i: int, j: int) -> float:
        """Return the stored similarity between experts `i` and `j` as a Python float."""
        return self._similarity_state_dict[mlp_name][i, j].item()

    def get_similarity_matrix(self, mlp_name: str) -> torch.Tensor:
        """Return a deep copy of the similarity matrix of layer `mlp_name`."""
        return deepcopy(self._similarity_state_dict[mlp_name])

    def get_transport_matrix(self, mlp_name: str) -> torch.Tensor:
        """Return a deep copy of the transport matrix stored for layer `mlp_name`.

        Raises
        ------
        KeyError
            If no transport matrix has been stored for `mlp_name`.
        """
        # The attribute created in `reset_all` is `transport_matrices`; the former
        # `transport_matrixes` spelling here always raised AttributeError.
        return deepcopy(self.transport_matrices[mlp_name])

    def get_composed_matrixes(self, mlp_name: str) -> List[torch.Tensor]:
        """Return (without copying) the list stored for `mlp_name` by a `compute_layer_*` method."""
        return self.composed_matrixes[mlp_name]

    def del_composed_matrixes(self, mlp_name: str):
        """Drop the entry stored for `mlp_name` so its tensors can be freed."""
        del self.composed_matrixes[mlp_name]

    def save_group_state_dict(self, save_dir: str):
        """Save the group labels to ``<save_dir>/group_state_dict.pt``, creating the directory if needed."""
        os.makedirs(save_dir, exist_ok=True)
        torch.save(self._group_state_dict, os.path.join(save_dir, "group_state_dict.pt"))

    def load_group_state_dict(self, load_dir: str):
        """Load the group labels from ``<load_dir>/group_state_dict.pt``."""
        # NOTE: torch.load unpickles the file -- only load files from a trusted source.
        self._group_state_dict = torch.load(os.path.join(load_dir, "group_state_dict.pt"))

    def set_avg_num_merged_experts(self, avg_num_experts: int):
        """Set the number of groups used by `compute_layer_composed_matrixes`."""
        self.avg_num_merged_experts = avg_num_experts

    def _assign_num_groups_per_layer(
            self,
            average_num_groups: int,
            merging_layers: List[int],
    ) -> Dict[str, int]:
        """Return the number of groups for each sparse layer.

        Layers listed in `merging_layers` get `average_num_groups` groups; every other sparse
        layer keeps `self.num_experts` groups (i.e. is not merged). `merging_layers=None` is
        treated as "all sparse layers", matching how the merge functions of this module
        interpret ``None``.
        """
        return {
            f"model.layers.{layer_idx}.mlp": (
                average_num_groups
                if merging_layers is None or layer_idx in merging_layers
                else self.num_experts
            )
            for layer_idx in self.sparse_layer_indices
        }

    def group_experts_into_clusters_by_routing_guided_globally(
            self,
            average_num_groups: int,
            merging_layers: List[int],
            layer_group_capacity: Optional[int] = None,
    ) -> Dict[str, List[int]]:
        """
        Group the experts of every sparse layer around its most-used ("core") experts.

        In each layer the `num_groups` most frequently used experts become the cores of groups
        ``0..num_groups-1``; every remaining expert joins the group of the core it is most similar
        to according to the stored similarity matrix. Group labels are written in place into the
        grouper's group state. Requires usage frequencies and similarities to have been computed.

        Parameters
        ----------
        average_num_groups: int
            The number of groups used for every layer in `merging_layers`.
        merging_layers: List[int]
            The layers whose experts are merged. Sparse layers not listed keep one group per expert.
        layer_group_capacity: Optional[int]
            The maximum number of experts in each group in the layers. If None, the number of experts
            in each group is not limited. A group stops accepting members once its size exceeds this value.

        Returns
        -------
        core_experts: Dict[str, List[int]]
            The core experts of each cluster

        Raises
        ------
        ValueError
            If the only remaining core's group exceeds `layer_group_capacity`.
        """
        # 1. Assign num_groups respectively for each layer according to average_num_groups
        layer_group_capacity = layer_group_capacity if layer_group_capacity is not None else self.num_experts
        num_groups_per_layer = self._assign_num_groups_per_layer(
            average_num_groups, merging_layers,
        )
        print(f"[Merging]Number of groups of each layer: {num_groups_per_layer}")
        # 2. Group experts into clusters for each layer
        core_experts = dict()
        for layer_idx in tqdm(self.sparse_layer_indices,
                              desc=f"Globally routing-guided clustering experts into average {average_num_groups} clusters"):
            mlp_name = f"model.layers.{layer_idx}.mlp"
            num_groups = num_groups_per_layer[mlp_name]
            group_labels = self._group_state_dict[mlp_name]  # modified in place below
            group_member_count = torch.zeros(num_groups)
            indices_sorted_by_usage = torch.argsort(self._usage_frequency_state_dict[mlp_name], descending=True)
            # 2.1 Assign top-K most-used experts with label 0 to K-1 respectively
            core_expert_indices = indices_sorted_by_usage[:num_groups]
            core_experts[mlp_name] = core_expert_indices.tolist()
            for i in range(num_groups):
                group_labels[core_expert_indices[i]] = i
                group_member_count[i] += 1

            # 2.2 Assign left unassigned experts to the cluster with the most similar core
            similarity_matrix = self.get_similarity_matrix(mlp_name)
            for i in range(num_groups, self.num_experts):
                expert_idx = indices_sorted_by_usage[i]
                # Position (within the still-open cores) of the most similar core
                core_position = torch.argmax(similarity_matrix[expert_idx, core_expert_indices])
                most_similar_core = core_expert_indices[core_position]
                most_similar_group_label = int(group_labels[most_similar_core])
                group_labels[expert_idx] = most_similar_group_label
                group_member_count[most_similar_group_label] += 1
                if group_member_count[most_similar_group_label] > layer_group_capacity:
                    if len(core_expert_indices) == 1:
                        raise ValueError(
                            f"[Merging]The number of groups at Encoder layer {layer_idx} is too small!"
                        )
                    # Kick out the filled group as well as its core, by pop the core from core_experts
                    core_expert_indices = torch.cat(
                        [core_expert_indices[:core_position], core_expert_indices[core_position + 1:]]
                    )

        return core_experts

    def compute_all_usages(
            self,
            model: DeepseekForCausalLM,
            batch: Dict[str, torch.Tensor],
            mini_batch_size: Optional[int] = 128,
            merging_layers: Optional[List[int]] = None,
    ):
        """Count how often each expert is selected by its layer's gate and normalise to frequencies.

        The model is moved to CUDA, put in eval mode and run over `batch` in mini-batches. The
        first output of every sparse layer's gate (the selected expert indices) is accumulated
        into the usage counters, which are finally divided by their per-layer sum.

        Parameters
        ----------
        model: DeepseekForCausalLM
            The model to run.
        batch: Dict[str, torch.Tensor]
            Model inputs; must contain ``"input_ids"``. Every tensor is sliced along dim 0.
        mini_batch_size: Optional[int]
            Samples per forward pass; ``None`` or a value larger than the batch uses the whole batch.
        merging_layers: Optional[List[int]]
            Unused; kept for interface compatibility.

        Raises
        ------
        ValueError
            If `batch` contains no samples.
        """
        model.cuda()
        model.eval()
        total_batch_size = batch["input_ids"].shape[0]
        if total_batch_size == 0:
            raise ValueError("[Merging]batch should contain at least one sample to compute usages.")
        if mini_batch_size is None or mini_batch_size > total_batch_size:
            mini_batch_size = total_batch_size
        # Ceil division: a trailing partial mini-batch is processed instead of being dropped.
        num_batches = (total_batch_size + mini_batch_size - 1) // mini_batch_size

        self.topk_idx = {}
        handles = []
        try:
            for layer_idx in self.sparse_layer_indices:
                gate_moe_name = f"model.layers.{layer_idx}.mlp.gate"
                handles.append(model.model.layers[layer_idx].mlp.gate.register_forward_hook(
                    self._get_topk_idx(gate_moe_name)
                ))

            with torch.no_grad():
                for i in tqdm(range(num_batches), desc="[Merging]Computing all usages..."):
                    mini_batch = {
                        k: v[i * mini_batch_size: (i + 1) * mini_batch_size].cuda() for k, v in batch.items()
                    }
                    model(**mini_batch)
                    for layer_idx in self.sparse_layer_indices:
                        gate_moe_name = f"model.layers.{layer_idx}.mlp.gate"
                        mlp_name = f"model.layers.{layer_idx}.mlp"
                        usage = self._usage_frequency_state_dict[mlp_name]
                        # One vectorised histogram instead of one increment per selected expert.
                        counts = torch.bincount(
                            self.topk_idx[gate_moe_name].reshape(-1), minlength=self.num_experts
                        )
                        usage += counts.to(device=usage.device, dtype=usage.dtype)
        finally:
            # Always detach the hooks, even if a forward pass failed.
            for handle in handles:
                handle.remove()
            del self.topk_idx

        self._usage_frequency_state_dict = {
            k: v / torch.sum(v) for k, v in self._usage_frequency_state_dict.items()
        }

    def reverse_all_similarities(self):
        """Replace every stored similarity `s` by `1 - s`."""
        print("[Merging]Reversing all similarities...")
        for key in self._similarity_state_dict.keys():
            self._similarity_state_dict[key] = 1 - self._similarity_state_dict[key]

    def compute_layer_composed_matrixes(
            self,
            model: DeepseekForCausalLM,
            merging_layer_idx: int,
            batch: Dict[str, torch.Tensor]
    ):
        """Run `batch` through the model and store ``[A, B, C, D]`` for layer `merging_layer_idx`.

        From ``self.activations[mlp_name]`` (permuted so that dim 0 and dim 1 are swapped) a square
        matrix ``W`` is formed as the mean over dim 0 of ``ACT @ ACT^T``. Each row ``i`` is scored
        by ``W_i W^-1 W_i^T``; the ``group_num - 1`` highest-scoring rows each get their own
        group, all remaining rows share the last group. ``A`` is the one-hot row-to-group
        assignment, ``B`` stacks ``W_t W^-1`` per group (the last row uses the mean of the
        remaining rows), ``C`` is ``self.gate_acts[mlp_name]`` and ``D`` is ``self.inputs[mlp_name]``.

        NOTE(review): the hook registered here (`_get_mlp_activation`) writes only
        ``self.up_acts`` and ``self.gate_acts``; ``self.up_acts`` is not initialised in this method
        and nothing visible in this class fills ``self.activations`` or ``self.inputs``. As
        written this method does not appear able to run to completion -- confirm the intended hook.
        """
        model = model.cuda()
        model = model.eval()
        batch = {k: v.cuda() for k, v in batch.items()}
        self.activations = {}
        self.gate_acts = {}
        self.inputs = {}

        mlp_name = f"model.layers.{merging_layer_idx}.mlp"
        handle = model.model.layers[merging_layer_idx].mlp.register_forward_hook(
            self._get_mlp_activation(mlp_name)
        )
        try:
            with torch.no_grad():
                model(**batch)
        finally:
            # Never leave the hook attached to the model, even if the forward pass fails.
            handle.remove()

        group_num = self.avg_num_merged_experts

        ACT = self.activations[mlp_name].permute(1, 0, 2)

        # compose with topk
        ACT_T = torch.transpose(ACT, 1, 2)
        W = torch.mean(torch.bmm(ACT, ACT_T), dim=0)
        W_inv = torch.linalg.inv(W)
        expert_num = W.size(0)
        value_list = []
        for i in range(expert_num):
            Wi = W[i:i+1, :]
            tmp = torch.matmul(Wi, W_inv)
            value_list.append(torch.matmul(tmp, Wi.T))
        arr = torch.tensor(value_list)
        _, sorted_indices = torch.sort(arr, descending=True)

        group_labels = torch.zeros_like(arr, dtype=torch.int32)
        A = torch.zeros((expert_num, group_num), device=W.device)
        B = []
        # The top (group_num - 1) rows each form their own group ...
        for i in range(group_num - 1):
            t = sorted_indices[i]
            group_labels[t] = i
            B.append(torch.matmul(W[t:t+1, :], W_inv))

        # ... and all remaining rows share the last group.
        avg_w = torch.zeros((1, expert_num), device=W.device)
        for i in range(group_num - 1, expert_num):
            t = sorted_indices[i]
            group_labels[t] = group_num - 1
            avg_w += W[t:t+1, :]

        B.append(torch.matmul(avg_w / (expert_num - group_num + 1), W_inv))
        B = torch.cat(B, dim = 0)
        for i in range(expert_num):
            A[i][group_labels[i]] = 1

        C = self.gate_acts[mlp_name]
        D = self.inputs[mlp_name]
        self.composed_matrixes[mlp_name] = [A, B, C, D]

        del self.activations
        del self.gate_acts
        del self.inputs

    def compute_layer_act(
            self,
            model: DeepseekForCausalLM,
            merging_layer_idx: int,
            batch: Dict[str, torch.Tensor]
    ):
        """Run `batch` through the model and store ``[gate_acts, up_acts]`` for one layer.

        The stored tensors are each expert's ``gate_proj`` and ``up_proj`` outputs on the MLP's
        input, stacked over experts (see `_get_mlp_activation`). They are kept in
        ``self.composed_matrixes["model.layers.{merging_layer_idx}.mlp"]`` until removed with
        `del_composed_matrixes`.
        """
        model = model.cuda()
        model = model.eval()
        batch = {k: v.cuda() for k, v in batch.items()}
        self.up_acts = {}
        self.gate_acts = {}

        mlp_name = f"model.layers.{merging_layer_idx}.mlp"
        handle = model.model.layers[merging_layer_idx].mlp.register_forward_hook(
            self._get_mlp_activation(mlp_name)
        )
        try:
            with torch.no_grad():
                model(**batch)
        finally:
            # Never leave the hook attached to the model, even if the forward pass fails.
            handle.remove()

        self.composed_matrixes[mlp_name] = [self.gate_acts[mlp_name], self.up_acts[mlp_name]]

        del self.up_acts
        del self.gate_acts

    def compute_all_similarities(
            self,
            model: DeepseekForCausalLM,
            batch: Dict[str, torch.Tensor] = None,
            merging_layers: Optional[List[int]] = None,
    ):
        """Fill the similarity matrices of all sparse layers according to `self.similarity_base`.

        Parameters
        ----------
        model: DeepseekForCausalLM
            The model to inspect; it is moved to CUDA and put in eval mode.
        batch: Dict[str, torch.Tensor]
            Model inputs; required for every base that needs a forward pass.
        merging_layers: Optional[List[int]]
            Unused; kept for interface compatibility.

        Raises
        ------
        ValueError
            If `batch` is missing for a base that needs a forward pass.
        NotImplementedError
            If the similarity base has no implementation here.
        """
        # These bases only read the model's weights, so they do not need data.
        weight_only_bases = ("weight", "gate-weight", "gate-up-weight", "router-weight")
        if self.similarity_base not in weight_only_bases and batch is None:
            raise ValueError(
                f"[Merging]batch should be provided when similarity_base is not one of {list(weight_only_bases)}")

        model = model.cuda()
        model = model.eval()
        if self.similarity_base == "weight":
            self._compute_all_similarities_by_weight(model.state_dict())
        elif self.similarity_base == 'gate-weight':
            self._compute_all_similarities_by_gate_weight(model.state_dict())
        elif self.similarity_base == 'gate-up-weight':
            self._compute_all_similarities_by_gate_up_weight(model.state_dict())
        elif self.similarity_base == 'router-weight':
            self._compute_all_similarities_by_router_weight(model.state_dict())
        elif self.similarity_base == 'router-logits':
            batch = {k: v.cuda() for k, v in batch.items()}
            self._compute_all_similarities_by_router_logits(model, batch)
        elif self.similarity_base == 'gate-act':
            batch = {k: v.cuda() for k, v in batch.items()}
            # NOTE(review): `_compute_all_similarities_by_gate_act` is not defined in this class;
            # unless a subclass provides it this branch raises AttributeError -- confirm.
            self._compute_all_similarities_by_gate_act(model, batch)
        else:
            raise NotImplementedError

    def _compute_similarities_from_expert_weights(
            self,
            state_dict: Dict[str, torch.Tensor],
            proj_names: List[str],
            desc: str,
    ):
        """Store pairwise similarities of the experts' flattened `proj_names` weights for every sparse layer."""

        def flat_weights(mlp_name: str, expert_idx: int) -> torch.Tensor:
            parts = [
                state_dict[f"{mlp_name}.experts.{expert_idx}.{proj_name}.weight"].flatten()
                for proj_name in proj_names
            ]
            return parts[0] if len(parts) == 1 else torch.cat(parts, dim=0)

        for layer_idx in tqdm(self.sparse_layer_indices, desc=desc):
            mlp_name = f"model.layers.{layer_idx}.mlp"
            for i in range(self.num_experts):
                # Built once per expert i instead of once per (i, j) pair.
                i_flat = flat_weights(mlp_name, i)
                for j in range(i + 1, self.num_experts):
                    j_flat = flat_weights(mlp_name, j)
                    similarity = self.similarity_fn(i_flat, j_flat)
                    self.save_similarity(mlp_name, i, j, similarity)

    def _compute_all_similarities_by_weight(self, state_dict: Dict[str, torch.Tensor]):
        """Similarity of the concatenated, flattened ``up_proj`` and ``down_proj`` weights."""
        self._compute_similarities_from_expert_weights(
            state_dict, ["up_proj", "down_proj"], "[Merging]Computing similarities by weight..."
        )

    def _compute_all_similarities_by_gate_weight(self, state_dict: Dict[str, torch.Tensor]):
        """Similarity of the flattened ``gate_proj`` weights."""
        self._compute_similarities_from_expert_weights(
            state_dict, ["gate_proj"], "[Merging]Computing similarities by weight..."
        )

    def _compute_all_similarities_by_gate_up_weight(self, state_dict: Dict[str, torch.Tensor]):
        """Similarity of the concatenated, flattened ``gate_proj`` and ``up_proj`` weights."""
        self._compute_similarities_from_expert_weights(
            state_dict, ["gate_proj", "up_proj"], "[Merging]Computing similarities by weight..."
        )

    def _compute_all_similarities_by_router_weight(
            self, state_dict: Dict[str, torch.Tensor]
    ):
        """Similarity of the rows of each layer's ``gate.weight`` (one row per expert)."""
        for layer_idx in tqdm(self.sparse_layer_indices, desc="[Merging]Computing similarities by router rows..."):
            mlp_name = f"model.layers.{layer_idx}.mlp"
            router_weight = state_dict[f"{mlp_name}.gate.weight"]
            for i in range(self.num_experts):
                i_flat = router_weight[i]
                for j in range(i + 1, self.num_experts):
                    j_flat = router_weight[j]
                    similarity = self.similarity_fn(i_flat, j_flat)
                    self.save_similarity(mlp_name, i, j, similarity)

    def _compute_all_similarities_by_router_logits(
            self, model: DeepseekForCausalLM, batch: Dict[str, torch.Tensor]
    ):
        """Similarity of the experts' router-logit columns collected over one forward pass of `batch`."""
        self.router_logits = {}
        handles = []
        try:
            for layer_idx in self.sparse_layer_indices:
                gate_moe_name = f"model.layers.{layer_idx}.mlp.gate"
                handles.append(model.model.layers[layer_idx].mlp.gate.register_forward_hook(
                    self._get_router_logits(gate_moe_name)
                ))
            with torch.no_grad():
                model(**batch)
        finally:
            # Always detach the hooks, even if the forward pass failed.
            for handle in handles:
                handle.remove()

        for layer_idx in tqdm(self.sparse_layer_indices, desc="[Merging]Computing similarities by router logits..."):
            gate_moe_name = f"model.layers.{layer_idx}.mlp.gate"
            mlp_name = f"model.layers.{layer_idx}.mlp"
            router_logits = self.router_logits[gate_moe_name].reshape(-1, self.num_experts)
            with torch.no_grad():
                for i in range(self.num_experts):
                    i_flat = router_logits[:, i].flatten()
                    for j in range(i + 1, self.num_experts):
                        j_flat = router_logits[:, j].flatten()
                        similarity = self.similarity_fn(i_flat, j_flat)
                        self.save_similarity(mlp_name, i, j, similarity)

        del self.router_logits

    def _get_mlp_activation(self, name):
        """Build a forward hook storing every expert's ``up_proj``/``gate_proj`` output under `name`.

        The hook flattens the module's first input to ``(-1, hidden_dim)``, applies each expert's
        ``up_proj`` and ``gate_proj`` to *all* rows (regardless of routing) and stacks the results
        over experts into ``self.up_acts[name]`` and ``self.gate_acts[name]``. The gate output is
        stored before the activation function is applied.
        """
        def hook(module, input, output):
            _, _, hidden_dim = input[0].shape
            hidden_states = input[0].view(-1, hidden_dim)
            gate_acts = []
            up_acts = []

            for expert in module.experts:
                up_acts.append(expert.up_proj(hidden_states))
                gate_acts.append(expert.gate_proj(hidden_states))
            self.up_acts[name] = torch.stack(up_acts)
            self.gate_acts[name] = torch.stack(gate_acts)
        return hook

    def _get_router_logits(self, name):
        """Build a forward hook storing ``input @ module.weight^T`` (no bias) in ``self.router_logits[name]``."""
        def hook(module, input, output):
            _, _, h = input[0].shape
            ### compute gating score
            hidden_states = input[0].view(-1, h)
            logits = F.linear(hidden_states, module.weight, None)
            self.router_logits[name] = logits
        return hook

    def _get_topk_idx(self, name):
        """Build a forward hook storing the first element of the gate's 3-tuple output in ``self.topk_idx[name]``."""
        def hook(module, input, output):
            topk_idx, topk_weight, aux_loss = output
            self.topk_idx[name] = topk_idx
        return hook


def _merge_mlp_experts_by_averaging(
        mlp: DeepseekMoE,
        group_labels: torch.LongTensor,
        permute: bool,
        permute_strategy: str,
        forwarded_hidden_states: Optional[Tuple[torch.Tensor]] = None,
) -> DeepseekMoE:
    """Merge the experts of every group by plain (unweighted) weight averaging.

    For each distinct label in `group_labels`, the ``up_proj``, ``down_proj`` and ``gate_proj``
    weights of the group's experts are averaged and written into the group's first expert; all
    other slots of the group are then re-bound to that same expert module.

    Parameters
    ----------
    mlp: DeepseekMoE
        The MoE block whose ``experts`` are merged in place.
    group_labels: torch.LongTensor
        One group label per expert.
    permute: bool
        Whether to align every non-first expert to the group's first expert before averaging.
    permute_strategy: str
        ``"weight-matching"`` or ``"activation-matching"``; only consulted when `permute` is true.
    forwarded_hidden_states: Optional[Tuple[torch.Tensor]]
        Per-expert hidden states, indexed by expert index; used by ``"activation-matching"``.

    Returns
    -------
    DeepseekMoE
        The same `mlp` object, modified in place.

    Raises
    ------
    ValueError
        If `permute` is true and `permute_strategy` is unknown.
    """
    for group_id in group_labels.unique():
        members = torch.where(group_labels == group_id)[0]
        anchor_idx = members[0]
        follower_indices = members[1:]

        if permute:
            if permute_strategy == "weight-matching":
                for member_idx in follower_indices:
                    perm = compute_OLMoE_permutation_by_weight_matching(
                        reference_mlp=mlp.experts[anchor_idx],
                        target_mlp=mlp.experts[member_idx],
                        include_wo=True
                    )
                    mlp.experts[member_idx] = permute_OLMoE_mlp_dense_expert_(
                        mlp.experts[member_idx], perm
                    )
            elif permute_strategy == "activation-matching":
                # Hidden states of all experts of the group, concatenated along dim 0
                group_hidden_states = torch.cat(
                    [forwarded_hidden_states[member_idx] for member_idx in members], dim=0
                )
                for member_idx in follower_indices:
                    perm = compute_OLMoE_permutation_by_activation_matching(
                        reference_mlp=mlp.experts[anchor_idx],
                        target_mlp=mlp.experts[member_idx],
                        forwarded_hidden_states=group_hidden_states,
                    )
                    mlp.experts[member_idx] = permute_OLMoE_mlp_dense_expert_(
                        mlp.experts[member_idx], perm
                    )
            else:
                raise ValueError(f"Unknown permute strategy: {permute_strategy}")

        with torch.no_grad():
            # Average each projection over the group and store it in the first expert.
            for proj_name in ("up_proj", "down_proj", "gate_proj"):
                stacked = torch.stack(
                    [getattr(mlp.experts[member_idx], proj_name).weight for member_idx in members]
                )
                averaged = torch.mean(stacked, dim=0)
                getattr(mlp.experts[anchor_idx], proj_name).weight.copy_(averaged)
            # Binding merged experts to the first of them
            for member_idx in follower_indices:
                mlp.experts[member_idx] = mlp.experts[anchor_idx]
    return mlp


def _merge_mlp_experts_by_usage_frequency_weighting(
        mlp: DeepseekMoE,
        group_labels: torch.LongTensor,
        usage_frequencies: torch.Tensor,
        permute: bool,
) -> DeepseekMoE:
    """Merge the experts of every group by usage-frequency-weighted weight averaging.

    For each distinct label in `group_labels`, the ``up_proj``, ``down_proj`` and ``gate_proj``
    weights of the group's experts are combined as ``sum(w_e * f_e) / (sum(f_e) + FP32_EPS)``
    and written into the group's first expert; all other slots of the group are then re-bound
    to that same expert module.

    Parameters
    ----------
    mlp: DeepseekMoE
        The MoE block whose ``experts`` are merged in place.
    group_labels: torch.LongTensor
        One group label per expert.
    usage_frequencies: torch.Tensor
        One weighting factor per expert.
    permute: bool
        Must be false; permutation is not supported by this function.

    Returns
    -------
    DeepseekMoE
        The same `mlp` object, modified in place.

    Raises
    ------
    NotImplementedError
        If `permute` is true.
    """
    if permute:
        # The former `assert(False, "...")` asserted a non-empty tuple, which is always
        # truthy, so the unsupported request was silently ignored.
        raise NotImplementedError("Do not support permute")

    for label in group_labels.unique():
        expert_indices = torch.where(group_labels == label)[0]
        with torch.no_grad():
            # Shared denominator of the weighted mean; the epsilon guards against a zero sum.
            usage_freq_sum = torch.sum(usage_frequencies[expert_indices], dim=0) + FP32_EPS

            up_proj_weight = torch.sum(torch.stack(
                [mlp.experts[expert_idx].up_proj.weight * usage_frequencies[expert_idx] for expert_idx in
                 expert_indices], dim=0
            ), dim=0) / usage_freq_sum
            down_proj_weight = torch.sum(torch.stack(
                [mlp.experts[expert_idx].down_proj.weight * usage_frequencies[expert_idx] for expert_idx in
                 expert_indices], dim=0
            ), dim=0) / usage_freq_sum
            gate_proj_weight = torch.sum(torch.stack(
                [mlp.experts[expert_idx].gate_proj.weight * usage_frequencies[expert_idx] for expert_idx in
                 expert_indices], dim=0
            ), dim=0) / usage_freq_sum

            mlp.experts[expert_indices[0]].up_proj.weight.copy_(up_proj_weight)
            mlp.experts[expert_indices[0]].down_proj.weight.copy_(down_proj_weight)
            mlp.experts[expert_indices[0]].gate_proj.weight.copy_(gate_proj_weight)

            for expert_idx in expert_indices[1:]:
                # Binding merged experts to the first of them
                mlp.experts[expert_idx] = mlp.experts[expert_indices[0]]
    return mlp




def _merge_mlp_experts_by_weighting_act(
        mlp: DeepseekMoE,
        group_labels: torch.LongTensor,
        usage_frequencies: torch.Tensor,
        composed_matrixes: List[torch.Tensor]
) -> DeepseekMoE:
    """Merge the experts of every group using captured gate/up activations.

    ``up_proj`` and ``gate_proj`` of a group are merged by usage-frequency-weighted averaging.
    For groups with more than one expert, ``down_proj`` is rebuilt: a least-squares map from the
    merged intermediate activation to each member's original intermediate activation is solved,
    each member's ``down_proj`` weight is multiplied by its slice of that map, and the results
    are combined with the same usage-frequency weighting. Single-expert groups keep their
    ``down_proj`` weight. The merged weights are written into the group's first expert and all
    other slots of the group are re-bound to it.

    Parameters
    ----------
    mlp: DeepseekMoE
        The MoE block whose ``experts`` are merged in place.
    group_labels: torch.LongTensor
        One group label per expert.
    usage_frequencies: torch.Tensor
        One weighting factor per expert.
    composed_matrixes: List[torch.Tensor]
        ``[gate_acts, up_acts]``, each indexed by expert along dim 0.
        # assumes each is (num_experts, num_samples, intermediate_size) -- TODO confirm

    Returns
    -------
    DeepseekMoE
        The same `mlp` object, modified in place.
    """
    gate_acts, up_acts = composed_matrixes[0], composed_matrixes[1]
    act_fn = mlp.experts[0].act_fn
    # Intermediate activation of every original expert
    original_acts = act_fn(gate_acts) * up_acts

    for group_id in group_labels.unique():
        members = torch.where(group_labels == group_id)[0]
        anchor_idx = members[0]
        follower_indices = members[1:]
        usage_freq_sum = torch.sum(usage_frequencies[members], dim=0) + FP32_EPS

        def weighted_mean(per_member_tensors):
            # sum(t_e * f_e) / (sum(f_e) + eps) over the members of the group
            scaled = torch.stack(
                [tensor * usage_frequencies[member_idx]
                 for tensor, member_idx in zip(per_member_tensors, members)],
                dim=0
            )
            return torch.sum(scaled, dim=0) / usage_freq_sum

        with torch.no_grad():
            if members.numel() > 1:
                merged_gate_acts = weighted_mean([gate_acts[member_idx] for member_idx in members])
                merged_up_acts = weighted_mean([up_acts[member_idx] for member_idx in members])

                merged_acts = act_fn(merged_gate_acts) * merged_up_acts
                sample_num, intermediate_size = merged_acts.size()
                # Members' original activations laid side by side: one column block per member
                unmerged_acts = original_acts[members].permute(1, 0, 2).reshape(sample_num, -1)
                lstsq_result = torch.linalg.lstsq(
                    merged_acts.to(torch.float), unmerged_acts.to(torch.float)
                )
                solution = lstsq_result[0]

                reference_down_weight = mlp.experts[0].down_proj.weight
                solution = solution.T.to(reference_down_weight.dtype)
                down_proj_weight = torch.zeros_like(reference_down_weight)
                for position, member_idx in enumerate(members):
                    block_start = position * intermediate_size
                    block_end = (position + 1) * intermediate_size
                    down_proj_weight += torch.matmul(
                        mlp.experts[member_idx].down_proj.weight * usage_frequencies[member_idx],
                        solution[block_start:block_end]
                    )
                down_proj_weight /= usage_freq_sum
            else:
                # group size == 1
                assert members.numel() == 1
                down_proj_weight = mlp.experts[anchor_idx].down_proj.weight

            up_proj_weight = weighted_mean(
                [mlp.experts[member_idx].up_proj.weight for member_idx in members]
            )
            gate_proj_weight = weighted_mean(
                [mlp.experts[member_idx].gate_proj.weight for member_idx in members]
            )

            merged_expert = mlp.experts[anchor_idx]
            merged_expert.up_proj.weight.copy_(up_proj_weight)
            merged_expert.gate_proj.weight.copy_(gate_proj_weight)
            merged_expert.down_proj.weight.copy_(down_proj_weight)

            # Binding merged experts to the first of them
            for member_idx in follower_indices:
                mlp.experts[member_idx] = mlp.experts[anchor_idx]
    return mlp




def Deepseek_merge_by_groups_with_ACT(
        model: DeepseekForCausalLM,
        grouper: ExpertsGrouperForDeepseek,
        merging_layers: Optional[List[int]],
        batch: Dict[str, torch.Tensor],
) -> DeepseekForCausalLM:
    """
    Merges experts in model using activation-based merging strategy.

    For every selected sparse layer, visited from the last layer to the first:
    1. the grouper captures the layer's per-expert gate/up activations on `batch`,
    2. the experts of each group are merged with `_merge_mlp_experts_by_weighting_act`,
    3. the captured activations are released from the grouper again.

    The grouper is expected to already hold usage frequencies and group labels.

    Args:
        model: The model to be merged
        grouper: Expert grouper containing grouping information and usage frequencies
        merging_layers: List of layer indices to merge (None merges all layers)
        batch: Input batch used for computing activations

    Returns:
        The same model object, with the experts of each group bound to one merged expert
    """
    usage_frequency_dict = grouper.usage_frequency_state_dict()
    # The accessor deep-copies every layer's labels, so take the snapshot once up front.
    group_state_dict = grouper.group_state_dict()

    # Layers are visited in reverse order.
    # NOTE(review): presumably so that the inputs of a layer are still produced by unmerged
    # earlier layers when its activations are captured -- confirm.
    for layer_idx in tqdm(
            grouper.sparse_layer_indices[::-1],
            desc="[Merging]Merging experts with act..."
    ):
        if merging_layers is not None and layer_idx not in merging_layers:
            continue

        mlp_name = f"model.layers.{layer_idx}.mlp"
        grouper.compute_layer_act(
            model=model,
            merging_layer_idx=layer_idx,
            batch=batch
        )
        try:
            model.model.layers[layer_idx].mlp = _merge_mlp_experts_by_weighting_act(
                mlp=model.model.layers[layer_idx].mlp,
                group_labels=group_state_dict[mlp_name],
                usage_frequencies=usage_frequency_dict[mlp_name],
                composed_matrixes=grouper.get_composed_matrixes(mlp_name)
            )
        finally:
            # Release this layer's captured activations; otherwise the grouper keeps the
            # activations of every processed layer alive until the end of the run.
            grouper.del_composed_matrixes(mlp_name)

    return model

'''
MIT License
Copyright (c) 2023 UNITES Lab
This function is modified from (https://github.com/UNITES-Lab/MC-SMoE)
'''
def Deepseek_merge_by_groups_with_usage_frequency_weighting(
        model: DeepseekForCausalLM,
        grouper: ExpertsGrouperForDeepseek,
        strategy: str = "normal",
        merging_layers: Optional[List[int]] = None,
        permute: Optional[bool] = False,
        within_and_across_models: Optional[bool] = False,
) -> DeepseekForCausalLM:
    """
    Merge the experts of each group by usage-frequency-weighted averaging.

    Supported strategies:
        1. normal: use the usage frequencies reported by the grouper as they are.
        2. reversed: replace every usage frequency by `1 - usage_frequency`.
        3. random: replace every usage frequency tensor by `torch.rand_like` of it.

    Parameters
    ----------
    model: DeepseekForCausalLM
        The model whose experts are merged (modified in place).
    grouper: ExpertsGrouperForDeepseek
        The grouper providing `usage_frequency_state_dict()`, `group_state_dict()`
        and `sparse_layer_indices`; it is expected to have computed usages and
        grouped the experts already.
    strategy: str
        One of ["normal", "reversed", "random"].
    merging_layers: Optional[List[int]]
        The layers to merge; if None, every sparse layer is merged.
    permute: Optional[bool]
        Forwarded to `_merge_mlp_experts_by_usage_frequency_weighting`; a notice
        is printed when it is truthy.
    within_and_across_models: Optional[bool]
        Accepted for interface compatibility; this function's body never reads it.

    Returns
    -------
    model: DeepseekForCausalLM
        The same `model` object that was passed in.

    Raises
    ------
    ValueError
        If `strategy` is not one of the supported values.
    """

    def _reverse(frequency):
        return 1 - frequency

    if permute:
        print("[Merging]Permutation is enabled, will permute experts in the same group.")

    usage_frequency_dict = grouper.usage_frequency_state_dict()

    # Pick the per-layer remapping implied by the strategy (None keeps values as-is).
    if strategy == "normal":
        remap = None
    elif strategy == "reversed":
        remap = _reverse
    elif strategy == "random":
        remap = torch.rand_like
    else:
        raise ValueError(f"[Merging]Unknown strategy {strategy}")

    if remap is not None:
        # Existing keys are overwritten in place; the key set is unchanged.
        for mlp_key, frequency in usage_frequency_dict.items():
            usage_frequency_dict[mlp_key] = remap(frequency)

    progress = tqdm(
        grouper.sparse_layer_indices,
        desc=f"[Merging]Merging experts with {strategy} usage-frequency-weighted averaging...",
    )
    for layer_idx in progress:
        if merging_layers is not None and layer_idx not in merging_layers:
            continue
        mlp_name = f"model.layers.{layer_idx}.mlp"
        layer_labels = grouper.group_state_dict()[mlp_name]
        layer_frequencies = usage_frequency_dict[mlp_name]
        target_layer = model.model.layers[layer_idx]
        target_layer.mlp = _merge_mlp_experts_by_usage_frequency_weighting(
            mlp=target_layer.mlp,
            group_labels=layer_labels,
            usage_frequencies=layer_frequencies,
            permute=permute,
        )
    return model

def _Deepseek_merge_mlp_experts_within_and_across_models(
        mlp: DeepseekMoE,
        group_labels: torch.LongTensor,
        forwarded_hidden_states: Tuple[torch.Tensor],
        dominant_alone: bool,
        core_expert_indices: Optional[List[int]] = None,
        usage_frequencies: Optional[torch.Tensor] = None,
) -> DeepseekMoE:
    """
    Merge the experts of every group of one MoE block by activation matching.

    For each distinct label in `group_labels`, the experts carrying that label
    are merged with `merge_olmoe_mlp_by_activation_matching_within_and_across_models`.
    The merged `up_proj`, `down_proj` and `gate_proj` weights are copied into the
    first expert of the group, and every other slot of the group in
    `mlp.experts` is rebound to that same module object.

    Parameters
    ----------
    mlp: DeepseekMoE
        The MoE block whose experts are merged (modified in place).
    group_labels: torch.LongTensor
        The group label of each expert; experts sharing a label are merged.
    forwarded_hidden_states: Tuple[torch.Tensor]
        Hidden states indexed by expert index; the entries of a group are
        concatenated along dim 0 before being passed to the merge helper.
    dominant_alone: bool
        Whether to merge the dominant (core) expert alone.
        If True, the merging process in a group is done in two steps:
            1. Merge all experts except the core one(s).
            2. Merge the core expert(s) with the merged expert from step 1.
        Both steps are skipped when every expert of the group is a core expert,
        or when `usage_frequencies` is given, the group has exactly one core
        expert and the non-core usage frequencies sum to 0.
    core_expert_indices: Optional[List[int]]
        Indices of the core experts; required when `dominant_alone` is True.
    usage_frequencies: Optional[torch.Tensor]
        Per-expert usage frequencies, indexed by expert index. When given they
        are passed to the merge helper as `average_coefs`; otherwise
        `average_coefs` is None.

    Returns
    -------
    mlp: DeepseekMoE
        The same `mlp` object, with merged experts.

    Raises
    ------
    ValueError
        If `dominant_alone` is True and `core_expert_indices` is None.
    """
    if dominant_alone and core_expert_indices is None:
        raise ValueError("[Merging]dominant_alone is True, but core_expert_indices is None")

    for label in group_labels.unique():
        # Indices of all experts that belong to the current group.
        expert_indices = torch.where(group_labels == label)[0]
        with torch.no_grad():
            if dominant_alone:
                # NOTE(review): torch.stack raises on an empty list, so this presumably
                # relies on every group containing at least one core expert — confirm
                # against the caller that builds the groups.
                group_core_expert_indices = torch.stack([
                    idx for idx in expert_indices if idx in core_expert_indices])
                to_skip = False
                if len(group_core_expert_indices) == len(expert_indices):
                    # Every expert of the group is a core expert: keep the first one as is.
                    merged_expert = mlp.experts[expert_indices[0]]
                    to_skip = True
                elif usage_frequencies is not None and len(group_core_expert_indices) == 1:
                    non_core_usage_sum = torch.sum(
                        usage_frequencies[[expert_idx.item() for expert_idx in
                                           expert_indices if expert_idx not in group_core_expert_indices]]).item()
                    if non_core_usage_sum == 0:
                        # The non-core experts have zero total usage: the single core
                        # expert is taken as the merged result.
                        merged_expert = mlp.experts[group_core_expert_indices[0]]
                        to_skip = True
                    else:
                        to_skip = False
                if not to_skip:
                    # Stage 1: merge all experts except the dominant one
                    group_forwarded_hidden_states = torch.cat([
                        forwarded_hidden_states[expert_idx] for expert_idx in expert_indices if
                        expert_idx not in group_core_expert_indices
                    ], dim=0)
                    # `non_core_usages` only exists when usage frequencies were given;
                    # every later use is guarded by the same condition.
                    if usage_frequencies is not None:
                        non_core_usages = usage_frequencies[[expert_idx.item() for expert_idx in expert_indices if
                                                             expert_idx not in group_core_expert_indices]]
                    merged_expert = merge_olmoe_mlp_by_activation_matching_within_and_across_models(
                        mlp_list=[mlp.experts[expert_idx] for expert_idx in expert_indices if
                                  expert_idx not in group_core_expert_indices],
                        forwarded_hidden_states=group_forwarded_hidden_states,
                        average_coefs=non_core_usages.tolist() if usage_frequencies is not None else None
                    )
                    # Stage 2: merge the dominant expert with the merged expert in stage 1
                    # (this time using the hidden states of all experts in the group)
                    group_forwarded_hidden_states = torch.cat([
                        forwarded_hidden_states[expert_idx] for expert_idx in expert_indices
                    ], dim=0)
                    if usage_frequencies is not None:
                        core_usages = usage_frequencies[group_core_expert_indices]
                        non_core_usage_sum = torch.sum(non_core_usages).item()
                    # The stage-1 result is weighted by the summed non-core usage, each core
                    # expert by its own usage (or no coefficients at all without usages).
                    merged_expert = merge_olmoe_mlp_by_activation_matching_within_and_across_models(
                        mlp_list=[merged_expert] + [mlp.experts[expert_idx] for expert_idx in
                                                    group_core_expert_indices],
                        forwarded_hidden_states=group_forwarded_hidden_states,
                        average_coefs=[non_core_usage_sum] + core_usages.tolist(
                        ) if usage_frequencies is not None else None
                    )
            else:
                # Merge all experts in the group
                group_forwarded_hidden_states = torch.cat([
                    forwarded_hidden_states[expert_idx] for expert_idx in expert_indices
                ], dim=0)
                merged_expert = merge_olmoe_mlp_by_activation_matching_within_and_across_models(
                    mlp_list=[mlp.experts[expert_idx] for expert_idx in expert_indices],
                    forwarded_hidden_states=group_forwarded_hidden_states,
                    average_coefs=usage_frequencies[expert_indices].tolist() if usage_frequencies is not None else None
                )
            # Write the merged weights into the first expert of the group.
            mlp.experts[expert_indices[0]].up_proj.weight.copy_(merged_expert.up_proj.weight)
            mlp.experts[expert_indices[0]].down_proj.weight.copy_(merged_expert.down_proj.weight)
            mlp.experts[expert_indices[0]].gate_proj.weight.copy_(merged_expert.gate_proj.weight)

            for expert_idx in expert_indices[1:]:
                # Binding merged experts to the first of them
                mlp.experts[expert_idx] = mlp.experts[expert_indices[0]]

    return mlp

def Deepseek_merge_by_groups(
        model: DeepseekForCausalLM,
        grouper: ExpertsGrouperForDeepseek,
        merging_layers: Optional[List[int]] = None,
        permute: Optional[bool] = False,
        permute_strategy: Optional[str] = "weight-matching",
        dataloader: Optional[DataLoader] = None,
) -> DeepseekForCausalLM:
    """
    Merge the experts of each group by averaging, layer by layer.

    When `permute_strategy` is "activation-matching", the model is first put in
    eval mode, moved to CUDA and run over `dataloader` with forward hooks on
    every sparse MoE block. The hooks record each block's flattened input
    hidden states and the top-k experts chosen per token from the block's gate
    weights. The tokens routed to each expert are then passed to
    `_merge_mlp_experts_by_averaging` as `forwarded_hidden_states`.

    Parameters
    ----------
    model: DeepseekForCausalLM
        The model to merge experts (modified in place).
    grouper: ExpertsGrouperForDeepseek
        The grouper to group experts, supposed to have been called `grouper.compute_all_usages()` and
            one of `grouper.group_experts()` (i.e. have grouped labels).
    merging_layers: Optional[List[int]]
        The layers to merge experts, if None, merge all layers.
    permute: Optional[bool]
        Forwarded to `_merge_mlp_experts_by_averaging`.
    permute_strategy: Optional[str]
        Forwarded to `_merge_mlp_experts_by_averaging`; the value
        "activation-matching" additionally triggers the activation collection
        described above (regardless of `permute`).
    dataloader: Optional[DataLoader]
        The dataloader to compute activations, only used (and then required)
        when `permute_strategy` is "activation-matching".

    Returns
    -------
    model: DeepseekForCausalLM
        The same `model` object that was passed in.

    Raises
    ------
    ValueError
        If `permute_strategy` is "activation-matching" and `dataloader` is None.
    """
    use_activations = permute_strategy == "activation-matching"
    # {name: [per-batch tensor of shape (batch_size * seq_len, hidden)]}
    forwarded_hidden_states = dict()
    # {name: [per-batch tensor of shape (batch_size * seq_len, num_experts_per_tok)]}
    router_indices = dict()

    if use_activations:
        if dataloader is None:
            # Fail early with a clear message instead of a TypeError from iterating None.
            raise ValueError(
                "[Merging]dataloader is required when permute_strategy is 'activation-matching'.")
        model.eval().cuda()
        handles = []
        top_k = grouper.num_experts_per_tok

        def _get_activation_hook(name):
            def hook(module, input, output):
                # Flatten the block input to (tokens, hidden).
                hidden_states = input[0].detach().reshape(-1, input[0].shape[-1])
                # Recompute the gating scores and keep only the selected experts;
                # the raw logits are not needed afterwards.
                logits = F.linear(hidden_states, module.gate.weight, None)
                routing_weights = F.softmax(logits, dim=1, dtype=torch.float)
                _, selected_experts = torch.topk(routing_weights, top_k, dim=-1)
                router_indices[name].append(selected_experts)
                forwarded_hidden_states[name].append(hidden_states)

            return hook

        try:
            for layer_idx in tqdm(
                    grouper.sparse_layer_indices,
                    desc="[Merging]Registering forward hook..."
            ):
                mlp_name = f"model.layers.{layer_idx}.mlp"
                forwarded_hidden_states[mlp_name] = []
                router_indices[mlp_name] = []
                handles.append(model.model.layers[layer_idx].mlp.register_forward_hook(
                    _get_activation_hook(mlp_name))
                )

            with torch.no_grad():
                for batch in tqdm(dataloader, desc="[Merging]Computing activations..."):
                    batch = {k: v.cuda() for k, v in batch.items()}
                    model(**batch)
        finally:
            # Always detach the hooks, even if registration or a forward pass failed,
            # so the model is not left with stale hooks attached.
            for handle in handles:
                handle.remove()

    num_experts = grouper.num_experts

    for layer_idx in tqdm(grouper.sparse_layer_indices,
                          desc="[Merging]Merging experts with averaging..."):
        if merging_layers is not None and layer_idx not in merging_layers:
            continue
        mlp_name = f"model.layers.{layer_idx}.mlp"
        group_labels = grouper.group_state_dict()[mlp_name]
        merge_kwargs = dict(
            mlp=model.model.layers[layer_idx].mlp,
            group_labels=group_labels,
            permute=permute,
            permute_strategy=permute_strategy,
        )
        if use_activations:
            # For every expert, gather the tokens (over all recorded batches) whose
            # top-k selection contains that expert.
            merge_kwargs["forwarded_hidden_states"] = tuple(
                torch.cat(
                    [hidden[(selected == expert_idx).any(dim=1)]
                     for hidden, selected in zip(forwarded_hidden_states[mlp_name],
                                                 router_indices[mlp_name])],
                    dim=0)
                for expert_idx in range(num_experts)
            )
        model.model.layers[layer_idx].mlp = _merge_mlp_experts_by_averaging(**merge_kwargs)

    return model


def Deepseek_merge_by_groups_within_and_across_models(
        model: DeepseekForCausalLM,
        grouper: ExpertsGrouperForDeepseek,
        dataloader: DataLoader,
        merging_layers: Optional[List[int]] = None,
        dominant_alone: Optional[bool] = False,
        core_experts: Optional[Dict[str, List[int]]] = None,
        usage_weighted: Optional[bool] = False,
) -> DeepseekForCausalLM:
    """
    Merge the experts of each group by activation matching, layer by layer.

    The model is put in eval mode, moved to CUDA and run over `dataloader` with
    forward hooks on every sparse MoE block. The hooks record each block's
    flattened input hidden states and the top-k experts chosen per token from
    the block's gate weights. For each selected layer, the tokens routed to
    each expert are passed to
    `_Deepseek_merge_mlp_experts_within_and_across_models`, whose return value
    replaces the layer's `mlp`.

    Parameters
    ----------
    model: DeepseekForCausalLM
        The model to merge experts (modified in place).
    grouper: ExpertsGrouperForDeepseek
        The grouper providing `sparse_layer_indices`, `num_experts`,
        `num_experts_per_tok`, `group_state_dict()` and
        `usage_frequency_state_dict()`.
    dataloader: DataLoader
        Yields dict batches of tensors that are moved to CUDA and fed to the model.
    merging_layers: Optional[List[int]]
        The layers to merge experts, if None, merge all sparse layers.
    dominant_alone: Optional[bool]
        Forwarded to the per-layer merge helper.
    core_experts: Optional[Dict[str, List[int]]]
        Maps `"model.layers.{i}.mlp"` to core expert indices; the entry of the
        current layer is forwarded as `core_expert_indices` (None if not given).
    usage_weighted: Optional[bool]
        If truthy, the layer's usage frequencies are forwarded to the merge
        helper; otherwise None is passed.

    Returns
    -------
    model: DeepseekForCausalLM
        The same `model` object that was passed in.
    """
    # {name: [per-batch tensor of shape (batch_size * seq_len, hidden)]}
    forwarded_hidden_states = dict()
    # {name: [per-batch tensor of shape (batch_size * seq_len, num_experts_per_tok)]}
    router_indices = dict()
    usage_frequencies = grouper.usage_frequency_state_dict()

    model.eval().cuda()
    handles = []
    top_k = grouper.num_experts_per_tok

    def _get_activation_hook(name):
        def hook(module, input, output):
            # Flatten the block input to (tokens, hidden).
            hidden_states = input[0].detach().reshape(-1, input[0].shape[-1])
            # Recompute the gating scores and keep only the selected experts;
            # the raw logits are not needed afterwards.
            logits = F.linear(hidden_states, module.gate.weight, None)
            routing_weights = F.softmax(logits, dim=1, dtype=torch.float)
            _, selected_experts = torch.topk(routing_weights, top_k, dim=-1)
            router_indices[name].append(selected_experts)
            forwarded_hidden_states[name].append(hidden_states)

        return hook

    try:
        for layer_idx in tqdm(
                grouper.sparse_layer_indices,
                desc="[Merging]Registering forward hook..."
        ):
            mlp_name = f"model.layers.{layer_idx}.mlp"
            forwarded_hidden_states[mlp_name] = []
            router_indices[mlp_name] = []
            handles.append(model.model.layers[layer_idx].mlp.register_forward_hook(
                _get_activation_hook(mlp_name))
            )

        with torch.no_grad():
            for batch in tqdm(dataloader, desc="[Merging]Computing activations..."):
                batch = {k: v.cuda() for k, v in batch.items()}
                model(**batch)
    finally:
        # Always detach the hooks, even if registration or a forward pass failed,
        # so the model is not left with stale hooks attached.
        for handle in handles:
            handle.remove()

    num_experts = grouper.num_experts
    for layer_idx in tqdm(
            grouper.sparse_layer_indices,
            desc="[Merging]Merging by groups within and across experts..."
    ):
        if merging_layers is not None and layer_idx not in merging_layers:
            continue
        mlp_name = f"model.layers.{layer_idx}.mlp"
        group_labels = grouper.group_state_dict()[mlp_name]
        # For every expert, gather the tokens (over all recorded batches) whose
        # top-k selection contains that expert.
        layer_forwarded_hidden_states = tuple(
            torch.cat(
                [hidden[(selected == expert_idx).any(dim=1)]
                 for hidden, selected in zip(forwarded_hidden_states[mlp_name],
                                             router_indices[mlp_name])],
                dim=0)
            for expert_idx in range(num_experts)
        )
        model.model.layers[layer_idx].mlp = _Deepseek_merge_mlp_experts_within_and_across_models(
            mlp=model.model.layers[layer_idx].mlp,
            group_labels=group_labels,
            forwarded_hidden_states=layer_forwarded_hidden_states,
            dominant_alone=dominant_alone,
            core_expert_indices=core_experts[mlp_name] if core_experts is not None else None,
            usage_frequencies=usage_frequencies[mlp_name] if usage_weighted else None,
        )

    # Drop the recorded activations before releasing cached CUDA memory.
    del forwarded_hidden_states, router_indices
    torch.cuda.empty_cache()
    return model