from copy import deepcopy
import random
import math
import itertools as I
import logging
from typing import List, Optional, Tuple, Union

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint

from transformers.models.mixtral.modeling_mixtral import (
    MixtralForCausalLM,
    MixtralSparseMoeBlock,
    MixtralBlockSparseTop2MLP,
)
from transformers.models.qwen2_moe.modeling_qwen2_moe import (
    Qwen2MoeForCausalLM,
    Qwen2MoeSparseMoeBlock,
    Qwen2MoeMLP,
)
from transformers.models.qwen3_moe.modeling_qwen3_moe import Qwen3MoeSparseMoeBlock

from transformers.models.deepseek_v2.modeling_deepseek_v2 import (
    DeepseekV2ForCausalLM,
    DeepseekV2MoE,
    DeepseekV2MLP,
    DeepseekV2MoEGate,
    
)

from data import MergeCacheDataset, PruneCacheDataset

# Module-level logger; the wrappers below use it to warn when activations are
# still being cached after experts have already been dropped.
logger = logging.getLogger(__name__)

class PrunableMixtralSparseMoeBlockWrapper(torch.nn.Module):
    """Wrap a Mixtral sparse-MoE block so its experts can be searched, pruned or merged.

    The wrapper re-implements the block's top-k routing, can mask a set of experts out
    of the router (``experts_to_drop``), and can record router logits (``alpha``),
    inputs (``X``), outputs (``Z``) and -- in merge mode -- selected expert indices
    (``R``) into a cache dataset that :meth:`enumerate` and :meth:`merge` consume.

    Args:
        model: a ``MixtralSparseMoeBlock``, or an object whose ``.model`` attribute is one.
        r: number of experts to keep.
        dom_experts: indices of dominant experts. Passing them switches the wrapper to
            merge mode (``MergeCacheDataset``, routing decisions cached); otherwise it
            runs in prune mode (``PruneCacheDataset``).
        usage_freq: per-expert usage frequencies used by ``merge_method="freq"``.
        merge_method: ``"average"``, ``"weighted"``, ``"freq"``, ``"zipit"`` or
            ``"fix-dom-same"``; any other value (including the default) makes
            :meth:`merge` raise ``ValueError``.
        mode: passed through to the zipit-style merge helper.
        weight: ``[dominant_weight, normal_weight]`` used by ``merge_method="weighted"``.
    """

    def __init__(self, model,
                 r: Optional[int] = None,
                 dom_experts: Optional[List[int]] = None,
                 usage_freq: Optional[torch.Tensor] = None,
                 merge_method: Optional[str] = "norm_drop_fre",
                 mode: Optional[str] = "normal",
                 weight: Optional[List[int]] = None,
                 ):
        super().__init__()
        if isinstance(model, MixtralSparseMoeBlock):
            self.model = model
        else:
            self.model = model.model
        self.r = r

        # Search / merge state.
        self.experts_to_drop = None  # tuple of expert indices masked out of the router
        self.experts_assignment = None  # group label per normal expert; expected to be set externally before merge()

        # Activation caching: which tensors forward() records into cache_space.
        self.cache_dataset_type = "merge" if dom_experts is not None else "prune"
        self.cache_space = MergeCacheDataset() if self.cache_dataset_type == "merge" else PruneCacheDataset()
        self.cache_logits = False
        self.cache_X = False
        self.cache_Z = False
        self.cache_R = self.cache_dataset_type == "merge"

        # Merge configuration.
        self.dominant_experts = dom_experts
        self.usage_freq = usage_freq
        self.group_state_dict = {}
        if dom_experts is not None:
            self.normal_experts = [i for i in range(self.model.num_experts) if i not in self.dominant_experts]
        else:
            self.normal_experts = None
        self.merge_method = merge_method
        self.mode = mode
        self.weight = weight
        self.device = self.model.gate.weight.device

    def forward(self, hidden_states: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Route ``hidden_states`` through the (possibly masked) experts.

        Args:
            hidden_states: tensor of shape ``(batch, seq_len, hidden_dim)``.

        Returns:
            ``(output, router_logits)``: ``output`` has the input's shape;
            ``router_logits`` is ``(batch * seq_len, num_experts)`` with dropped
            experts set to ``-inf``.
        """
        batch_size, sequence_length, hidden_dim = hidden_states.shape
        hidden_states = hidden_states.view(-1, hidden_dim)
        router_logits = self.model.gate(hidden_states)

        # Masked experts get -inf logits, i.e. zero probability after the softmax.
        if self.experts_to_drop is not None:
            for e in self.experts_to_drop:
                router_logits[:, e] = -float('inf')

        routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
        routing_weights, selected_experts = torch.topk(
            routing_weights, self.model.top_k, dim=-1)
        # Renormalize the selected top-k probabilities so they sum to one.
        routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
        # Cast back to the activation dtype before mixing expert outputs.
        routing_weights = routing_weights.to(hidden_states.dtype)

        final_hidden_states = torch.zeros(
            (batch_size * sequence_length, hidden_dim), dtype=hidden_states.dtype, device=hidden_states.device
        )

        # (num_experts, top_k, tokens) mask: which token/slot was routed to which expert.
        expert_mask = F.one_hot(
            selected_experts, num_classes=self.model.num_experts).permute(2, 1, 0)

        for expert_idx in range(self.model.num_experts):
            expert_layer = self.model.experts[expert_idx]
            idx, top_x = torch.where(expert_mask[expert_idx])

            if top_x.shape[0] == 0:
                continue

            top_x_list = top_x.tolist()
            idx_list = idx.tolist()

            current_state = hidden_states[None, top_x_list].reshape(-1, hidden_dim)
            # Scale the expert *output* by its routing weight, exactly as the wrapped
            # Mixtral block does. The expert MLP is non-linear, so scaling its input
            # instead would change the block's function (and corrupt cached Z targets).
            current_hidden_states = expert_layer(current_state) * routing_weights[top_x_list, idx_list, None]

            final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states.dtype))

        if self.experts_to_drop is not None and (self.cache_logits or self.cache_X or self.cache_Z):
            logger.warning(
                f'Already dropped {self.experts_to_drop} but still storing activations.')

        if self.cache_dataset_type == "merge":
            self.cache_space.append(
                alpha=(router_logits if self.cache_logits else None),
                X=(hidden_states if self.cache_X else None),
                Z=(final_hidden_states if self.cache_Z else None),
                R=(selected_experts if self.cache_R else None),
            )
        else:
            self.cache_space.append(
                alpha=(router_logits if self.cache_logits else None),
                X=(hidden_states if self.cache_X else None),
                Z=(final_hidden_states if self.cache_Z else None)
            )

        final_hidden_states = final_hidden_states.reshape(
            batch_size, sequence_length, hidden_dim)

        return final_hidden_states, router_logits

    def _routing_cache_as_tensor(self) -> torch.Tensor:
        """Concatenate the cached routing decisions (``cache_space.Rs``) into one tensor, once."""
        routing = self.cache_space.Rs
        if not isinstance(routing, torch.Tensor):
            routing = torch.concat(list(routing))
            self.cache_space.Rs = routing
        return routing

    @torch.no_grad()
    def enumerate(self):
        """Exhaustively search which ``num_experts - r`` experts to drop.

        Every combination of experts to drop is scored by replaying each cached
        input ``X`` with those experts masked and summing the norm of the difference
        to the cached output ``Z`` (in float64). The best combination is stored in
        ``self.experts_to_drop``.

        Returns:
            dict mapping each dropped-expert tuple to its accumulated loss.
        """
        # Stop recording anything (including routing decisions) while replaying
        # cached activations, so the search does not pollute the caches.
        self.cache_logits = False
        self.cache_X = False
        self.cache_Z = False
        self.cache_R = False
        loss_history = dict()

        # Merge mode: collapse per-batch routing decisions into one tensor for merge().
        if self.cache_dataset_type == "merge":
            self._routing_cache_as_tensor()

        device = self.model.gate.weight.data.device
        with torch.inference_mode():
            for dropped in I.combinations(range(self.model.num_experts), self.model.num_experts - self.r):
                self.experts_to_drop = dropped
                loss = 0

                for hidden_states, final_hidden_states in zip(self.cache_space.Xs, self.cache_space.Zs):
                    hidden_states = hidden_states.to(device=device, non_blocking=True)
                    final_hidden_states = final_hidden_states.to(
                        dtype=torch.float64, device=device, non_blocking=True)

                    final_hidden_states_e, _ = self.forward(hidden_states.unsqueeze(0))
                    loss += torch.norm(final_hidden_states -
                                       final_hidden_states_e.squeeze(0).to(torch.float64)).item()
                loss_history[dropped] = loss

        self.experts_to_drop = min(loss_history, key=loss_history.get)
        return loss_history

    @torch.no_grad()
    def merge(self):
        """Merge the normal experts into groups given by ``experts_assignment``.

        ``experts_assignment[i]`` is the group label of ``normal_experts[i]``; the
        labels, ordered by expert index, are handed to the merge helper selected by
        ``merge_method``. The wrapped block is replaced by the helper's result and
        the activation cache is released.

        Raises:
            ValueError: if ``merge_method`` is not a supported method (or is
                ``"weighted"`` without ``weight``).
        """
        assert self.cache_dataset_type == "merge", "只有在合并模式下才能使用merge方法"
        assert self.experts_assignment is not None
        assert len(self.experts_assignment) == self.model.num_experts - self.r

        self.cache_X = False
        self.cache_Z = False
        self.cache_R = False

        # Only the normal experts are filled here; any other entries in
        # group_state_dict come from the caller.
        for i in range(self.model.num_experts - self.r):
            self.group_state_dict[self.normal_experts[i]] = self.experts_assignment[i]
        group_labels = [self.group_state_dict[key] for key in sorted(self.group_state_dict.keys())]
        print("merge: ", group_labels)

        # Each branch imports the helper it calls, so every method is usable on its own.
        if self.merge_method == "average":
            from merge_method.merging import merge_mixtral_mlp_experts_by_frequency_weighting
            self.model = merge_mixtral_mlp_experts_by_frequency_weighting(
                ffn=self.model,
                group_labels=torch.tensor(group_labels),
                usage_frequencies=torch.tensor([1] * self.model.num_experts),
            )
        elif self.merge_method == "weighted" and self.weight is not None:
            from merge_method.merging import merge_mixtral_mlp_experts_by_frequency_weighting
            usage_frequencies = [
                self.weight[0] if i in self.dominant_experts else self.weight[1]
                for i in range(self.model.num_experts)
            ]
            self.model = merge_mixtral_mlp_experts_by_frequency_weighting(
                ffn=self.model,
                group_labels=torch.tensor(group_labels),
                usage_frequencies=torch.tensor(usage_frequencies),
            )
        elif self.merge_method == "freq":
            from merge_method.merging import merge_mixtral_mlp_experts_by_frequency_weighting
            self.model = merge_mixtral_mlp_experts_by_frequency_weighting(
                ffn=self.model,
                group_labels=torch.tensor(group_labels),
                usage_frequencies=self.usage_freq,
            )
        elif self.merge_method in ("zipit", "fix-dom-same"):
            from merge_method.merging import merge_mixtral_moe_experts_within_and_across_models
            # Cached inputs are only needed here; gather, for each expert, the tokens
            # that were routed to it in any top-k slot.
            cache_space_Xs = torch.concat(self.cache_space.Xs)
            routing = self._routing_cache_as_tensor()
            layer_forwarded_hidden_states = tuple(
                cache_space_Xs[torch.any(routing == expert_idx, dim=-1)]
                for expert_idx in range(self.model.num_experts)
            )

            self.model = merge_mixtral_moe_experts_within_and_across_models(
                moe=self.model,
                group_labels=torch.tensor(group_labels),
                forwarded_hidden_states=layer_forwarded_hidden_states,
                dominant_alone=False,
                merge=self.merge_method,
                mode=self.mode,
                core_expert_indices=self.dominant_experts,
                usage_frequencies=None,
            )
            del layer_forwarded_hidden_states, cache_space_Xs
        else:
            raise ValueError("Invalid merge method")

        del self.cache_space
        # memory_summary() initialises CUDA and fails on CPU-only hosts.
        if torch.cuda.is_available():
            print(torch.cuda.memory_summary())

    @torch.no_grad()
    def prune(self):
        """Physically remove ``experts_to_drop``: shrink the gate and the expert list to ``r``."""
        assert self.experts_to_drop is not None
        assert len(self.experts_to_drop) == self.model.num_experts - self.r

        # Cached activations are no longer needed once the expert set is fixed.
        if hasattr(self, 'cache_space'):
            del self.cache_space
        self.cache_X = False
        self.cache_Z = False
        if hasattr(self, 'cache_R'):
            self.cache_R = False

        experts_to_reserve = sorted(
            set(range(self.model.num_experts)) - set(self.experts_to_drop))
        print("experts_to_reserve: ", experts_to_reserve)
        print("experts_to_drop: ", self.experts_to_drop)

        # Build the smaller gate directly on the original weight's device/dtype and
        # keep only the rows of the surviving experts.
        old_weight = self.model.gate.weight.data
        gate_new = torch.nn.Linear(in_features=self.model.gate.in_features,
                                   out_features=self.r, bias=False,
                                   device=old_weight.device, dtype=old_weight.dtype)
        gate_new.weight.data = old_weight[experts_to_reserve]
        self.model.gate = gate_new

        self.model.experts = torch.nn.ModuleList(
            [self.model.experts[i] for i in experts_to_reserve])
        self.model.num_experts = self.r

class PrunableQwen2MoeSparseMoeBlockWrapper(torch.nn.Module):
    """Wrap a Qwen2-MoE sparse block so that its routed experts can be searched and pruned.

    The wrapper reproduces the block's forward pass (routed experts plus the
    sigmoid-gated shared expert), can mask routed experts listed in
    ``experts_to_drop`` out of the router, and can record router logits / inputs /
    outputs into a ``PruneCacheDataset`` that :meth:`enumerate` replays.

    Args:
        model: a ``Qwen2MoeSparseMoeBlock``, or an object whose ``.model`` is one.
        r: number of routed experts to keep after pruning.
    """

    def __init__(self, model,
                 r: Optional[int] = None,
                 ):
        super().__init__()
        if isinstance(model, Qwen2MoeSparseMoeBlock):
            self.model = model
        else:
            self.model = model.model
        self.r = r

        self.experts_to_drop = None  # tuple of routed-expert indices masked out of the router
        self.cache_space = PruneCacheDataset()
        self.cache_logits = False
        self.cache_X = False
        self.cache_Z = False

    def forward(self, hidden_states: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Run the MoE block with the experts in ``experts_to_drop`` masked.

        Args:
            hidden_states: tensor of shape ``(batch, seq_len, hidden_dim)``.

        Returns:
            ``(output, router_logits)``: ``output`` has the input's shape and
            ``router_logits`` is ``(batch * seq_len, num_experts)``.
        """
        batch_size, sequence_length, hidden_dim = hidden_states.shape
        hidden_states = hidden_states.view(-1, hidden_dim)
        router_logits = self.model.gate(hidden_states)

        # Masked experts get -inf logits, i.e. zero probability after the softmax.
        if self.experts_to_drop is not None:
            for e in self.experts_to_drop:
                router_logits[:, e] = -float('inf')

        routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
        routing_weights, selected_experts = torch.topk(routing_weights, self.model.top_k, dim=-1)
        if self.model.norm_topk_prob:
            routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
        # Cast back to the activation dtype before mixing expert outputs.
        routing_weights = routing_weights.to(hidden_states.dtype)

        final_hidden_states = torch.zeros(
            (batch_size * sequence_length, hidden_dim), dtype=hidden_states.dtype, device=hidden_states.device
        )

        # (num_experts, top_k, tokens) mask: which token/slot was routed to which expert.
        expert_mask = F.one_hot(selected_experts, num_classes=self.model.num_experts).permute(2, 1, 0)

        for expert_idx in range(self.model.num_experts):
            expert_layer = self.model.experts[expert_idx]
            idx, top_x = torch.where(expert_mask[expert_idx])

            if top_x.shape[0] == 0:
                continue

            top_x_list = top_x.tolist()
            idx_list = idx.tolist()

            current_state = hidden_states[None, top_x_list].reshape(-1, hidden_dim)
            current_hidden_states = expert_layer(current_state) * routing_weights[top_x_list, idx_list, None]

            final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states.dtype))

        # The shared expert sees every token and is gated by a per-token sigmoid.
        shared_expert_output = self.model.shared_expert(hidden_states)
        shared_expert_output = F.sigmoid(self.model.shared_expert_gate(hidden_states)) * shared_expert_output

        final_hidden_states = final_hidden_states + shared_expert_output

        if self.experts_to_drop is not None and (self.cache_logits or self.cache_X or self.cache_Z):
            # Logger.warn is a deprecated alias of Logger.warning.
            logger.warning(
                f'Already dropped {self.experts_to_drop} but still storing activations.')
        self.cache_space.append(
            alpha=(router_logits if self.cache_logits else None),
            X=(hidden_states if self.cache_X else None),
            Z=(final_hidden_states if self.cache_Z else None),
        )

        final_hidden_states = final_hidden_states.reshape(batch_size, sequence_length, hidden_dim)
        return final_hidden_states, router_logits

    @staticmethod
    def random_comb(iterable, r):
        """Return a sorted tuple of ``r`` distinct elements sampled from ``iterable`` (a sequence)."""
        return tuple(sorted(random.sample(iterable, r)))

    @torch.no_grad()
    def enumerate(self, max_iter: int = 100):
        """Randomly sample drop sets and keep the one with the lowest reconstruction error.

        Draws ``max_iter`` random sets of ``num_experts - r`` experts. Each distinct set
        is scored by replaying every cached input ``X`` with those experts masked and
        summing the norm of the difference to the cached output ``Z`` (in float64).
        The best set is stored in ``self.experts_to_drop``.

        Returns:
            dict mapping each evaluated drop tuple to its accumulated loss.
        """
        self.cache_logits = False
        self.cache_X = False
        self.cache_Z = False
        loss_history = dict()

        device = self.model.gate.weight.data.device
        num_to_drop = self.model.num_experts - self.r
        with torch.inference_mode():
            for _ in range(max_iter):
                dropped = self.random_comb(range(self.model.num_experts), num_to_drop)
                # A drop set that was already scored would only recompute the same loss.
                if dropped in loss_history:
                    continue
                self.experts_to_drop = dropped
                loss = 0

                for hidden_states, final_hidden_states in zip(self.cache_space.Xs, self.cache_space.Zs):
                    hidden_states = hidden_states.to(device=device, non_blocking=True)
                    final_hidden_states = final_hidden_states.to(
                        dtype=torch.float64, device=device, non_blocking=True)

                    final_hidden_states_e, _ = self.forward(hidden_states.unsqueeze(0))
                    loss += torch.norm(final_hidden_states -
                                       final_hidden_states_e.squeeze(0).to(torch.float64)).item()
                loss_history[dropped] = loss

        self.experts_to_drop = min(loss_history, key=loss_history.get)
        return loss_history

    @torch.no_grad()
    def prune(self):
        """Physically remove ``experts_to_drop``: shrink the gate and routed experts to ``r``."""
        assert self.experts_to_drop is not None
        assert len(self.experts_to_drop) == self.model.num_experts - self.r
        if hasattr(self, 'cache_space'):
            del self.cache_space
        self.cache_X = False
        self.cache_Z = False

        experts_to_reserve = sorted(
            set(range(self.model.num_experts)) - set(self.experts_to_drop))

        # Build the smaller gate directly on the original weight's device/dtype.
        old_weight = self.model.gate.weight.data
        gate_new = torch.nn.Linear(in_features=self.model.gate.in_features,
                                   out_features=self.r, bias=False,
                                   device=old_weight.device, dtype=old_weight.dtype)
        gate_new.weight.data = old_weight[experts_to_reserve]
        self.model.gate = gate_new

        self.model.experts = torch.nn.ModuleList(
            [self.model.experts[i] for i in experts_to_reserve])
        self.model.num_experts = self.r

class PrunableQwen3MoeSparseMoeBlockWrapper(torch.nn.Module):
    """Wrap a Qwen3-MoE sparse block so that its experts can be searched and pruned.

    The wrapper reproduces the block's top-k routed forward pass, can mask experts
    listed in ``experts_to_drop`` out of the router, and can record router logits /
    inputs / outputs into a ``PruneCacheDataset`` that :meth:`enumerate` replays.

    Args:
        model: a ``Qwen3MoeSparseMoeBlock``, or an object whose ``.model`` is one.
        r: number of experts to keep after pruning.
    """

    def __init__(self, model,
                 r: Optional[int] = None,
                 ):
        super().__init__()
        if isinstance(model, Qwen3MoeSparseMoeBlock):
            self.model = model
        else:
            self.model = model.model
        self.r = r

        self.experts_to_drop = None  # tuple of expert indices masked out of the router
        self.cache_space = PruneCacheDataset()
        self.cache_logits = False
        self.cache_X = False
        self.cache_Z = False

    def forward(self, hidden_states: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Run the MoE block with the experts in ``experts_to_drop`` masked.

        Args:
            hidden_states: tensor of shape ``(batch, seq_len, hidden_dim)``.

        Returns:
            ``(output, router_logits)``: ``output`` has the input's shape and
            ``router_logits`` is ``(batch * seq_len, num_experts)``.
        """
        batch_size, sequence_length, hidden_dim = hidden_states.shape
        hidden_states = hidden_states.view(-1, hidden_dim)
        router_logits = self.model.gate(hidden_states)

        # Masked experts get -inf logits, i.e. zero probability after the softmax.
        if self.experts_to_drop is not None:
            for e in self.experts_to_drop:
                router_logits[:, e] = -float('inf')

        routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
        routing_weights, selected_experts = torch.topk(routing_weights, self.model.top_k, dim=-1)
        if self.model.norm_topk_prob:
            routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
        # Cast back to the activation dtype before mixing expert outputs.
        routing_weights = routing_weights.to(hidden_states.dtype)

        final_hidden_states = torch.zeros(
            (batch_size * sequence_length, hidden_dim), dtype=hidden_states.dtype, device=hidden_states.device
        )

        # (num_experts, top_k, tokens) mask: which token/slot was routed to which expert.
        expert_mask = F.one_hot(selected_experts, num_classes=self.model.num_experts).permute(2, 1, 0)

        for expert_idx in range(self.model.num_experts):
            expert_layer = self.model.experts[expert_idx]
            idx, top_x = torch.where(expert_mask[expert_idx])

            if top_x.shape[0] == 0:
                continue

            top_x_list = top_x.tolist()
            idx_list = idx.tolist()

            current_state = hidden_states[None, top_x_list].reshape(-1, hidden_dim)
            current_hidden_states = expert_layer(current_state) * routing_weights[top_x_list, idx_list, None]

            final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states.dtype))

        if self.experts_to_drop is not None and (self.cache_logits or self.cache_X or self.cache_Z):
            # Logger.warn is a deprecated alias of Logger.warning.
            logger.warning(
                f'Already dropped {self.experts_to_drop} but still storing activations.')
        self.cache_space.append(
            alpha=(router_logits if self.cache_logits else None),
            X=(hidden_states if self.cache_X else None),
            Z=(final_hidden_states if self.cache_Z else None),
        )

        final_hidden_states = final_hidden_states.reshape(batch_size, sequence_length, hidden_dim)
        return final_hidden_states, router_logits

    @staticmethod
    def random_comb(iterable, r):
        """Return a sorted tuple of ``r`` distinct elements sampled from ``iterable`` (a sequence)."""
        return tuple(sorted(random.sample(iterable, r)))

    @torch.no_grad()
    def enumerate(self, max_iter: int = 100):
        """Randomly sample drop sets and keep the one with the lowest reconstruction error.

        Draws ``max_iter`` random sets of ``num_experts - r`` experts with this class's
        own :meth:`random_comb`. Each distinct set is scored by replaying every cached
        input ``X`` with those experts masked and summing the norm of the difference
        to the cached output ``Z`` (in float64). The best set is stored in
        ``self.experts_to_drop``.

        Returns:
            dict mapping each evaluated drop tuple to its accumulated loss.
        """
        self.cache_logits = False
        self.cache_X = False
        self.cache_Z = False
        loss_history = dict()

        device = self.model.gate.weight.data.device
        num_to_drop = self.model.num_experts - self.r
        with torch.inference_mode():
            for _ in range(max_iter):
                dropped = self.random_comb(range(self.model.num_experts), num_to_drop)
                # A drop set that was already scored would only recompute the same loss.
                if dropped in loss_history:
                    continue
                self.experts_to_drop = dropped
                loss = 0

                for hidden_states, final_hidden_states in zip(self.cache_space.Xs, self.cache_space.Zs):
                    hidden_states = hidden_states.to(device=device, non_blocking=True)
                    final_hidden_states = final_hidden_states.to(
                        dtype=torch.float64, device=device, non_blocking=True)

                    final_hidden_states_e, _ = self.forward(hidden_states.unsqueeze(0))
                    loss += torch.norm(final_hidden_states -
                                       final_hidden_states_e.squeeze(0).to(torch.float64)).item()
                loss_history[dropped] = loss

        self.experts_to_drop = min(loss_history, key=loss_history.get)
        return loss_history

    @torch.no_grad()
    def prune(self):
        """Physically remove ``experts_to_drop``: shrink the gate and experts to ``r``."""
        assert self.experts_to_drop is not None
        assert len(self.experts_to_drop) == self.model.num_experts - self.r
        if hasattr(self, 'cache_space'):
            del self.cache_space
        self.cache_X = False
        self.cache_Z = False

        experts_to_reserve = sorted(
            set(range(self.model.num_experts)) - set(self.experts_to_drop))

        # Build the smaller gate directly on the original weight's device/dtype.
        old_weight = self.model.gate.weight.data
        gate_new = torch.nn.Linear(in_features=self.model.gate.in_features,
                                   out_features=self.r, bias=False,
                                   device=old_weight.device, dtype=old_weight.dtype)
        gate_new.weight.data = old_weight[experts_to_reserve]
        self.model.gate = gate_new

        self.model.experts = torch.nn.ModuleList(
            [self.model.experts[i] for i in experts_to_reserve])
        self.model.num_experts = self.r

class PrunableMoEGate(torch.nn.Module):
    """Top-k softmax router for DeepSeek-V2 MoE layers that can mask out experts.

    Args:
        config: model config providing ``num_experts_per_tok``, ``n_routed_experts``,
            ``scoring_func``, ``aux_loss_alpha``, ``seq_aux``, ``norm_topk_prob``,
            ``hidden_size`` and, optionally, ``routed_scaling_factor`` (defaults to 1.0).
        r: if given, overrides ``config.n_routed_experts`` as the number of gate rows.
    """

    def __init__(self, config, r=None):
        super().__init__()
        self.config = config
        self.top_k = config.num_experts_per_tok
        if r is None:
            self.n_routed_experts = config.n_routed_experts
        else:
            self.n_routed_experts = r

        self.scoring_func = config.scoring_func
        self.alpha = config.aux_loss_alpha
        self.seq_aux = config.seq_aux

        self.norm_topk_prob = config.norm_topk_prob
        # DeepSeek-V2 multiplies un-normalized top-k weights by this factor; configs
        # without it (or with None) behave as a factor of 1.0.
        scaling = getattr(config, "routed_scaling_factor", None)
        self.routed_scaling_factor = 1.0 if scaling is None else scaling
        self.gating_dim = config.hidden_size
        self.weight = nn.Parameter(torch.empty((self.n_routed_experts, self.gating_dim)))
        self.reset_parameters()

    def reset_parameters(self) -> None:
        """Initialise the gate weight with Kaiming-uniform (``a=sqrt(5)``), as ``nn.Linear`` does."""
        import torch.nn.init as init
        init.kaiming_uniform_(self.weight, a=math.sqrt(5))

    def forward(self, hidden_states, experts_to_drop=None):
        """Select the top-k experts for every token.

        Args:
            hidden_states: tensor of shape ``(batch, seq_len, hidden_size)``.
            experts_to_drop: optional iterable of expert indices whose logits are set
                to ``-inf`` so they are never selected.

        Returns:
            ``(topk_idx, topk_weight, aux_loss, logits)``: indices and weights are
            ``(batch * seq_len, top_k)``; ``aux_loss`` is ``None`` unless training with
            ``aux_loss_alpha > 0``; ``logits`` are the (masked) raw router logits.

        Raises:
            NotImplementedError: if ``scoring_func`` is not ``'softmax'``.
        """
        bsz, seq_len, h = hidden_states.shape

        hidden_states = hidden_states.view(-1, h)
        logits = F.linear(hidden_states, self.weight, None)

        if experts_to_drop is not None:
            for e in experts_to_drop:
                logits[:, e] = -float('inf')

        if self.scoring_func == 'softmax':
            scores = logits.softmax(dim=-1)
        else:
            raise NotImplementedError(f'insupportable scoring function for MoE gating: {self.scoring_func}')

        topk_weight, topk_idx = torch.topk(scores, k=self.top_k, dim=-1, sorted=False)

        # Either renormalize the top-k weights to sum to one, or apply the routed
        # scaling factor -- the two branches of the reference DeepSeek-V2 gate.
        if self.top_k > 1 and self.norm_topk_prob:
            denominator = topk_weight.sum(dim=-1, keepdim=True) + 1e-20
            topk_weight = topk_weight / denominator
        else:
            topk_weight = topk_weight * self.routed_scaling_factor

        # Load-balancing auxiliary loss (training only).
        if self.training and self.alpha > 0.0:
            scores_for_aux = scores
            aux_topk = self.top_k
            topk_idx_for_aux_loss = topk_idx.view(bsz, -1)
            if self.seq_aux:
                scores_for_seq_aux = scores_for_aux.view(bsz, seq_len, -1)
                ce = torch.zeros(bsz, self.n_routed_experts, device=hidden_states.device)
                ce.scatter_add_(1, topk_idx_for_aux_loss, torch.ones(bsz, seq_len * aux_topk, device=hidden_states.device)).div_(seq_len * aux_topk / self.n_routed_experts)
                aux_loss = (ce * scores_for_seq_aux.mean(dim=1)).sum(dim=1).mean() * self.alpha
            else:
                mask_ce = F.one_hot(topk_idx_for_aux_loss.view(-1), num_classes=self.n_routed_experts)
                ce = mask_ce.float().mean(0)
                Pi = scores_for_aux.mean(0)
                fi = ce * self.n_routed_experts
                aux_loss = (Pi * fi).sum() * self.alpha
        else:
            aux_loss = None
        return topk_idx, topk_weight, aux_loss, logits

class PrunableDeepSeekV2MoeWrapper(torch.nn.Module):
    """Wrap a DeepSeek-V2 MoE layer so its routed experts can be searched, pruned or merged.

    Routing is done by a :class:`PrunableMoEGate` that shares the original gate's
    weight storage but can mask experts listed in ``experts_to_drop``. The routed
    computation is delegated to the wrapped layer's ``moe`` method, and its
    ``shared_experts`` output is added on top.

    Args:
        model: a ``DeepseekV2MoE``, or an object whose ``.model`` is one.
        r: number of routed experts to keep (``None`` keeps all of them).
        dom_experts: indices of dominant experts; passing them switches to merge mode
            (``MergeCacheDataset``, routing decisions cached).
        usage_freq: per-expert usage frequencies used by ``merge_method="freq"``.
        merge_method: ``"average"``, ``"weighted"`` or ``"freq"``; anything else falls
            back to uniform frequencies in :meth:`merge`.
        mode: stored for API parity with the Mixtral wrapper.
        weight: ``[dominant_weight, normal_weight]`` used by ``merge_method="weighted"``.
    """

    def __init__(self, model,
                 r: Optional[int] = None,
                 dom_experts: Optional[List[int]] = None,
                 usage_freq: Optional[torch.Tensor] = None,
                 merge_method: Optional[str] = "average",
                 mode: Optional[str] = "normal",
                 weight: Optional[List[int]] = None,
                 ):
        super().__init__()
        if isinstance(model, DeepseekV2MoE):
            self.model = model
        else:
            self.model = model.model
        self.device = self.model.gate.weight.device
        self.dtype = self.model.gate.weight.dtype
        # Maskable router sharing storage with the original gate weight.
        self.gate = PrunableMoEGate(self.model.gate.config).to(self.device, self.dtype)
        self.gate.weight.data = self.model.gate.weight.data
        self.r = r

        # Search / merge state.
        self.experts_to_drop = None
        self.experts_assignment = None  # group label per normal expert; expected to be set externally before merge()
        self.dominant_experts = dom_experts
        self.usage_freq = usage_freq
        self.group_state_dict = {}
        # Non-dominant experts in index order; merge() labels each of them with a group.
        num_experts = self.model.gate.weight.shape[0]
        if dom_experts is not None:
            self.normal_experts = [i for i in range(num_experts) if i not in dom_experts]
        else:
            self.normal_experts = None
        self.merge_method = merge_method
        self.mode = mode
        self.weight = weight

        # Activation caching: which tensors forward() records into cache_space.
        self.cache_dataset_type = "merge" if self.dominant_experts is not None else "prune"
        self.cache_space = MergeCacheDataset() if self.cache_dataset_type == "merge" else PruneCacheDataset()
        self.cache_logits = False
        self.cache_X = False
        self.cache_Z = False
        self.cache_R = self.cache_dataset_type == "merge"

    def forward(self, hidden_states):
        """Run the MoE layer with the experts in ``experts_to_drop`` masked.

        Args:
            hidden_states: tensor of shape ``(batch, seq_len, hidden_size)``.

        Returns:
            Tensor with the same shape: routed-expert output plus shared-expert output.
        """
        identity = hidden_states
        orig_shape = hidden_states.shape
        topk_idx, topk_weight, aux_loss, router_logits = self.gate(hidden_states, self.experts_to_drop)
        flat_hidden = hidden_states.view(-1, hidden_states.shape[-1])

        y = self.model.moe(flat_hidden, topk_idx, topk_weight).view(*orig_shape)
        y = y + self.model.shared_experts(identity)

        if self.experts_to_drop is not None and (self.cache_logits or self.cache_X or self.cache_Z):
            logger.warning(
                f'Already dropped {self.experts_to_drop} but still storing activations.')

        if self.cache_dataset_type == "merge":
            self.cache_space.append(
                alpha=(router_logits if self.cache_logits else None),
                X=(flat_hidden if self.cache_X else None),
                Z=(y.reshape(-1, orig_shape[-1]) if self.cache_Z else None),
                R=(topk_idx if self.cache_R else None),
            )
        else:
            self.cache_space.append(
                alpha=(router_logits if self.cache_logits else None),
                X=(flat_hidden if self.cache_X else None),
                Z=(y.reshape(-1, orig_shape[-1]) if self.cache_Z else None),
            )

        return y

    @staticmethod
    def random_comb(iterable, r):
        """Return a sorted tuple of ``r`` distinct elements sampled from ``iterable`` (a sequence)."""
        return tuple(sorted(random.sample(iterable, r)))

    @torch.no_grad()
    def enumerate(self, max_iter: int = 1000):
        """Randomly sample drop sets and keep the one with the lowest reconstruction error.

        Draws ``max_iter`` random sets of ``num_experts - r`` experts. Each distinct set
        is scored by replaying every cached input ``X`` with those experts masked and
        summing the norm of the difference to the cached output ``Z`` (in float64).
        The best set is stored in ``self.experts_to_drop``. When nothing needs to be
        dropped, the empty tuple is chosen with loss ``0.0``.

        Returns:
            dict mapping each evaluated drop tuple to its accumulated loss.
        """
        # Stop recording anything (including routing decisions) while replaying.
        self.cache_logits = False
        self.cache_X = False
        self.cache_Z = False
        self.cache_R = False
        loss_history = dict()

        with torch.inference_mode():
            device = self.model.gate.weight.data.device
            num_experts_current = self.model.gate.weight.shape[0]
            target_r = self.r if self.r is not None else num_experts_current
            num_to_drop = num_experts_current - target_r

            if num_to_drop <= 0:
                self.experts_to_drop = tuple()
                loss_history[self.experts_to_drop] = 0.0
            else:
                for _ in range(max_iter):
                    dropped = self.random_comb(range(num_experts_current), num_to_drop)
                    # A drop set that was already scored would only recompute the same loss.
                    if dropped in loss_history:
                        continue
                    self.experts_to_drop = dropped
                    loss = 0

                    for hidden_states, final_hidden_states in zip(self.cache_space.Xs, self.cache_space.Zs):
                        hidden_states = hidden_states.to(device=device, non_blocking=True)
                        final_hidden_states = final_hidden_states.to(
                            dtype=torch.float64, device=device, non_blocking=True)

                        final_hidden_states_e = self.forward(hidden_states.unsqueeze(0))
                        loss += torch.norm(final_hidden_states -
                                           final_hidden_states_e.squeeze(0).to(torch.float64)).item()
                    loss_history[dropped] = loss

        self.experts_to_drop = min(loss_history, key=loss_history.get)
        return loss_history

    @torch.no_grad()
    def prune(self):
        """Physically remove ``experts_to_drop`` from the layer and from the wrapper's router.

        Both the wrapped layer's gate and this wrapper's :class:`PrunableMoEGate` are
        shrunk to the surviving rows, so later wrapper forwards route only to the
        remaining experts.
        """
        assert self.experts_to_drop is not None
        num_experts_current = len(self.model.experts)
        target_r = self.r if self.r is not None else num_experts_current
        assert len(self.experts_to_drop) == max(num_experts_current - target_r, 0)
        if hasattr(self, 'cache_space'):
            del self.cache_space
        self.cache_X = False
        self.cache_Z = False
        self.cache_R = False

        experts_to_reserve = sorted(
            set(range(num_experts_current)) - set(self.experts_to_drop))
        num_kept = len(experts_to_reserve)

        # Shrink the wrapped layer's gate to the surviving rows.
        self.model.gate.weight = nn.Parameter(self.model.gate.weight.data[experts_to_reserve].clone())
        for attr in ("n_routed_experts", "num_experts"):
            if hasattr(self.model.gate, attr):
                setattr(self.model.gate, attr, num_kept)

        # Keep the wrapper's own router in sync (it shares storage, as in __init__).
        self.gate.weight = nn.Parameter(self.model.gate.weight.data)
        self.gate.n_routed_experts = num_kept

        self.model.experts = torch.nn.ModuleList([self.model.experts[i] for i in experts_to_reserve])

    @torch.no_grad()
    def merge(self):
        """Merge the normal experts into groups given by ``experts_assignment``.

        ``experts_assignment[i]`` is the group label of ``normal_experts[i]``. Usage
        frequencies are chosen from ``merge_method`` (uniform for ``"average"`` and any
        unrecognised method) and passed with the labels to the DeepSeek merge helper,
        whose result replaces the wrapped layer.
        """
        assert self.cache_dataset_type == "merge", "Only available in merge mode"
        assert self.experts_assignment is not None
        assert len(self.experts_assignment) == self.model.config.n_routed_experts - self.r

        self.cache_X = False
        self.cache_Z = False
        self.cache_R = False

        # Only the normal experts are filled here; any other entries in
        # group_state_dict come from the caller.
        for i in range(self.model.config.n_routed_experts - self.r):
            self.group_state_dict[self.normal_experts[i]] = self.experts_assignment[i]
        group_labels = [self.group_state_dict[key] for key in sorted(self.group_state_dict.keys())]

        num_routed = self.model.config.n_routed_experts
        if self.merge_method == "average":
            usage_frequencies = torch.tensor([1] * num_routed)
        elif self.merge_method == "weighted" and self.weight is not None:
            usage_frequencies = torch.tensor([
                self.weight[0] if i in self.dominant_experts else self.weight[1]
                for i in range(num_routed)
            ])
        elif self.merge_method == "freq" and self.usage_freq is not None:
            usage_frequencies = self.usage_freq
        else:
            usage_frequencies = torch.tensor([1] * num_routed)

        from merge_method.merging import merge_deepseek_mlp_experts_by_usage_frequency_weighting
        self.model = merge_deepseek_mlp_experts_by_usage_frequency_weighting(
            ffn=self.model,
            group_labels=torch.tensor(group_labels),
            usage_frequencies=usage_frequencies,
        )

class DynamicSkippingMixtralSparseMoeBlockWrapper(nn.Module):
    """Top-2 Mixtral MoE block that skips the runner-up expert when it is weak.

    For each token, the second-ranked expert is dropped whenever its routing
    probability is below ``beta`` times the top expert's probability; the remaining
    weights are renormalized to sum to one before mixing expert outputs.

    Args:
        model: the ``MixtralSparseMoeBlock`` whose gate and experts are reused; it
            must route with ``top_k == 2``.
        beta: skipping threshold, relative to the top-1 probability.
    """

    def __init__(self, model: MixtralSparseMoeBlock, beta: float):
        super().__init__()
        assert isinstance(model, MixtralSparseMoeBlock)
        assert model.top_k == 2
        # Mirror the block's configuration ...
        self.hidden_dim = model.hidden_dim
        self.ffn_dim = model.ffn_dim
        self.num_experts = model.num_experts
        self.top_k = model.top_k
        # ... and share (not copy) its router and expert modules.
        self.gate = model.gate
        self.experts = model.experts

        self.beta = beta

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Mix expert outputs with dynamic skipping.

        Returns ``(output, router_logits)``: ``output`` has the input's
        ``(batch, seq_len, hidden_dim)`` shape and ``router_logits`` is the raw gate
        output over the flattened tokens.
        """
        bsz, seq_len, dim = hidden_states.shape
        tokens = hidden_states.view(-1, dim)
        router_logits = self.gate(tokens)

        probs = F.softmax(router_logits, dim=1, dtype=torch.float)
        top_weights, top_experts = torch.topk(probs, self.top_k, dim=-1)

        # Tokens whose runner-up expert is too weak relative to the winner.
        skip_second = top_weights[:, 1] < self.beta * top_weights[:, 0]
        top_weights[skip_second, 1] = 0
        top_weights /= top_weights.sum(dim=-1, keepdim=True)
        top_weights = top_weights.to(tokens.dtype)

        output = torch.zeros((bsz * seq_len, dim), dtype=tokens.dtype, device=tokens.device)

        # one_hot gives (tokens, top_k, num_experts); clear the skipped second slot,
        # then reorder to (num_experts, top_k, tokens) for per-expert lookup.
        dispatch = F.one_hot(top_experts, num_classes=self.num_experts)
        dispatch[skip_second, 1, :] = 0
        dispatch = dispatch.permute(2, 1, 0)

        for e in range(self.num_experts):
            slot_ids, token_ids = torch.where(dispatch[e])
            if token_ids.shape[0] == 0:
                continue

            token_list = token_ids.tolist()
            slot_list = slot_ids.tolist()

            expert_in = tokens[None, token_list].reshape(-1, dim)
            expert_out = self.experts[e](expert_in)
            expert_out = expert_out * top_weights[token_list, slot_list, None]

            output.index_add_(0, token_ids, expert_out.to(tokens.dtype))

        return output.reshape(bsz, seq_len, dim), router_logits
