import os
import numpy as np
from typing import Optional, Dict, List
import time
import logging
import torch
import yaml
from fire import Fire
from transformers import AutoTokenizer
from transformers.models.deepseek_v2.modeling_deepseek_v2 import (
    DeepseekV2MoE,
    DeepseekV2ForCausalLM
) 
from eval import get_calib_dataloder
from tqdm import tqdm
from scipy.optimize import minimize
import matplotlib.pyplot as plt
import torch.nn as nn
import json

logger = logging.getLogger(__name__)

# Position -> name of the summary statistic held at that index of a
# 4-element per-layer statistics tensor (order: max, min, mean, std).
STAT_TYPE_DICT = dict(enumerate(("max", "min", "mean", "std")))



def check_metrics_mat(
        metrics_path: str
):
    """Print a short preview of every top-level entry in a saved metrics file.

    For each entry a header line is printed, followed by:
      * dict entries: the first five ``key: value`` pairs;
      * tensor entries: the first five elements, or the whole tensor when it
        is 0-dimensional (a 0-dim tensor cannot be sliced);
      * anything else: its type and a preview truncated to 200 characters.

    Args:
        metrics_path: Path to a file produced by ``torch.save`` whose payload
            supports ``.items()`` (e.g. a dict of metric name -> data).

    Returns:
        None. A missing file is reported on stdout rather than raised.
    """
    if not os.path.exists(metrics_path):
        print(f"Metrics file not found at {metrics_path}")
        return

    # NOTE: torch.load unpickles its input; only use this on trusted files.
    metrics = torch.load(metrics_path)
    print(f"Loaded metrics from {metrics_path}")

    preview_len = 5
    for metric_name, metric_data in metrics.items():
        print(f"\n{'-'*20} {metric_name} {'-'*20}")

        if isinstance(metric_data, dict):
            # Stop after the preview instead of copying every item into a list.
            for i, (layer_name, values) in enumerate(metric_data.items()):
                if i >= preview_len:
                    break
                print(f"{layer_name}: {values}")
        elif isinstance(metric_data, torch.Tensor):
            # Slicing a 0-dim (scalar) tensor raises IndexError, so print it whole.
            print(metric_data if metric_data.dim() == 0 else metric_data[:preview_len])
        else:
            print(f"Type: {type(metric_data)}")
            text = str(metric_data)
            print(text[:200] + "..." if len(text) > 200 else text)

    print("\nMetrics check completed.")

def cal_prune_num(
    save_dir: str = "./results/metrics",
    save_name: str = "deepseek_v2_moe_metrics.pt",
    metric: str = "mutual_information",  
    stat_type: int = 3,  
    min_margin: int = None,  
    pruning_ratio: float = 0.75,
    custom_max_experts: dict = None,  
    optimize_order: str = "prune_first",  
    max_neighbor_diff: int = None,  
    objective_type: str = "linear",  
    rounding_residuals: Optional[np.ndarray] = None,  
    return_residuals: bool = False,  
) -> np.ndarray | tuple[np.ndarray, np.ndarray]:
    
    if stat_type is None:
        stat_type = 2  
    
    
    stat_type_map = {0: "max", 1: "min", 2: "mean", 3: "std"}
    
    metrics = torch.load(os.path.join(save_dir, save_name))
    
    weight_outlier = metrics.get("weight-outlier", {})
    output_outlier = metrics.get("output-outlier", {})
    frequency_metrics = metrics.get("frequency", {})
    routing_score_metrics = metrics.get("routing-score", {})
    weight_mean_metrics = metrics.get("weight-mean", {})
    
    weight_cka_metrics = metrics.get("weight-cka", {})
    output_cka_metrics = metrics.get("output-cka", {})
    weight_cosine_metrics = metrics.get("weight-cosine", {})
    output_cosine_metrics = metrics.get("output-cosine", {})
    mutual_information_metrics = metrics.get("mutual_information", {})

    adapt_down_metrics = metrics.get("adapt_down", {})

    
    
    def normalize(x):
        return (x - np.min(x)) / (np.max(x) - np.min(x) + 1e-10)
    
    
    weight_outlier_scores = np.array([tensor.sum().item() for tensor in weight_outlier.values()])
    output_outlier_scores = np.array([tensor.sum().item() for tensor in output_outlier.values()])
    frequency_scores = np.array([tensor.sum().item() for tensor in frequency_metrics.values()])
    weight_mean_scores = np.array([tensor.sum().item() for tensor in weight_mean_metrics.values()]) 
    
    frequency_factor = (1 + frequency_scores) / (frequency_scores + 1e-10)
    
    if metric == "weight-outlier":
        
        importance_scores = frequency_factor * weight_outlier_scores
        importance_scores = normalize(importance_scores)
    elif metric == "output-outlier":
        
        importance_scores = frequency_factor * output_outlier_scores
        importance_scores = normalize(importance_scores)
    elif metric == "combined":
        
        importance_scores = frequency_factor * output_outlier_scores * weight_outlier_scores
        importance_scores = normalize(importance_scores)
    elif metric == "weight-mean":
        
        importance_scores = frequency_factor * weight_mean_scores
        importance_scores = normalize(importance_scores)
    elif metric == "mutual_information":
        if not mutual_information_metrics:
            raise ValueError("mutual_information is miss, calculate it first")
        importance_scores = np.array([tensor.sum().item() for tensor in mutual_information_metrics.values()])
    elif metric == "output-cosine":
        if not output_cosine_metrics:
            raise ValueError("output-cosine is miss, calculate it first")
        print(f"use {metric}'s {stat_type_map.get(stat_type, 'std')} as the important score")
        output_cosine_scores = np.array([tensor[stat_type].item() for tensor in output_cosine_metrics.values()])
        
        
        if stat_type in [0, 2]:  
            importance_scores = (-output_cosine_scores)
        else:  
            importance_scores = output_cosine_scores
        importance_scores = normalize(importance_scores)
    elif metric == "weight-cosine":
        if not weight_cosine_metrics:
            raise ValueError("weight-cosine is miss, calculate it first")
        print(f"use {metric}'s {stat_type_map.get(stat_type, 'std')} as the important score")
        weight_cosine_scores = np.array([tensor[stat_type].item() for tensor in weight_cosine_metrics.values()])
        
        
        if stat_type in [0, 2]:  
            importance_scores = (-weight_cosine_scores)
        else:  
            importance_scores = weight_cosine_scores
        importance_scores = normalize(importance_scores)
    elif metric == "adapt_down":
        if not adapt_down_metrics:
            raise ValueError("adapt_down is miss, calculate it first")
        importance_scores = np.array([tensor.sum().item() for tensor in adapt_down_metrics.values()])
        importance_scores = normalize(importance_scores)
    else:
        raise ValueError(f"unspported method {metric}，Use 'weight-outlier', 'output-outlier', 'combined', 'weight-mean', 'mutual_information', 'output-cosine', 'weight-cosine'")
    
    
    if rounding_residuals is not None:
        importance_scores = importance_scores + (1e-3 * rounding_residuals)
    
    
    num_experts = 64
    num_layers = 26

    
    layer_names = [f"model.layers.{i}.mlp" for i in range(num_layers)]
    layer_max_experts = []
    
    for i, layer_name in enumerate(layer_names):
        if custom_max_experts and layer_name in custom_max_experts and optimize_order in ["merge_first", "prune_first"]:
            
            max_expert = custom_max_experts[layer_name]
        else:
            
            max_expert = num_experts
        layer_max_experts.append(max_expert)
    
    
    if optimize_order in ["prune_first", "prune_only"]:
        
        total_experts = sum(layer_max_experts)
        num_experts_to_keep = int(total_experts * pruning_ratio)
    else:
        
        total_experts = sum(layer_max_experts)
        num_experts_to_keep = int(total_experts * pruning_ratio)
    
    
    
    bounds = []
    initial_experts = []
    
    for i, layer_name in enumerate(layer_names):
        layer_max = layer_max_experts[i]
        
        initial_val = layer_max * pruning_ratio
        initial_experts.append(initial_val)

        
        bounds.append((max(int(initial_val - min_margin), 16), layer_max))
    
    
    initial_experts = np.array(initial_experts)
    
    
    def constraint_total(x):
        return num_experts_to_keep - np.sum(x)
    
    
    cons = [{'type': 'eq', 'fun': constraint_total}]
    
    
    if max_neighbor_diff is not None and max_neighbor_diff >= 0:
        for i in range(num_layers - 1):
            
            def make_neighbor_constraint(i, j):
                def constraint(x):
                    return max_neighbor_diff - abs(x[i] - x[j])
                return constraint
            
            cons.append({'type': 'ineq', 'fun': make_neighbor_constraint(i, i+1)})
    
    
    if objective_type == "exp":
        
        def objective(x):
            return -np.sum(importance_scores * np.exp(x / 10))  

        def obj_grad(x):
            return -importance_scores * np.exp(x / 10) / 10
    elif objective_type == "log":
        
        def objective(x):
            return -np.sum(importance_scores * np.log(x + 1))  

        def obj_grad(x):
            return -importance_scores / (x + 1)
    elif objective_type == "square":
        
        def objective(x):
            return -np.sum(importance_scores * x**2)

        def obj_grad(x):
            return -2 * importance_scores * x
    elif objective_type == "linear":
        def objective(x):
            return -np.sum(importance_scores * x)

        def obj_grad(x):
            return -importance_scores
    else:  
        raise ValueError(f"Unsupported objective function type: {objective_type}，Use 'linear', 'exp', 'log', 'square'")
    
    
    res = minimize(
        fun=objective,
        x0=initial_experts,
        method="SLSQP",
        bounds=bounds,
        constraints=cons,  
        jac=obj_grad,
        options={"maxiter": 1000, "disp": True}
    )
    
    
    if not res.success:
        print(f"Failed to optimize: {res.message}")
    else:
        print(f"Optimized successfully: {res.message}")
    
    print(f"Iterations: {res.nit}")
    print(f"Optimized result: {res.x}")
    if hasattr(res, 'constr_violation'):
        print(f"Constraint violation: {res.constr_violation}")
    print(f"Objective value: {res.fun}")
    
    
    desired_counts_float = res.x
    rounded_counts = np.rint(desired_counts_float).astype(int)
    
    lower_bounds = np.array([max(int(initial_experts[i] - min_margin), 16) for i in range(num_layers)])
    upper_bounds = np.array(layer_max_experts)
    rounded_counts = np.clip(rounded_counts, lower_bounds, upper_bounds)
    rounding_residuals_out = desired_counts_float - rounded_counts.astype(float)

    print(f"Final each layer experts: {rounded_counts}")
    print(f"Total experts(rounded): {np.sum(rounded_counts)}，Target: {num_experts_to_keep}")

    if return_residuals:
        return rounded_counts, rounding_residuals_out
    return rounded_counts

def cal_merge_num(
    save_dir: str = "./results/metrics",
    save_name: str = "deepseek_v2_moe_metrics.pt",
    metric: str = "output-cosine",  
    stat_type: int = 2,  
    min_margin: int = 2,  
    merging_ratio: float = 0.75,  
    custom_max_experts: dict = None,  
    optimize_order: str = "prune_first",  
    max_neighbor_diff: int = None,  
    objective_type: str = "linear",  
    rounding_residuals: Optional[np.ndarray] = None,  
    return_residuals: bool = False,
) -> np.ndarray | tuple[np.ndarray, np.ndarray]:
    
    if stat_type is None:
        stat_type = 3  
    
    metrics = torch.load(os.path.join(save_dir, save_name))
    
    output_cka_metrics = metrics.get("output-cka", {})
    weight_cka_metrics = metrics.get("weight-cka", {})
    output_cosine_metrics = metrics.get("output-cosine", {})
    weight_cosine_metrics = metrics.get("weight-cosine", {})
    
    
    stat_type_map = {0: "max", 1: "min", 2: "mean", 3: "std"}
    print(f"use {metric}'s {stat_type_map.get(stat_type, 'std')} as the important score")

    
    if metric == "output-cosine":
        if not output_cosine_metrics:
            raise ValueError("output-cosine is miss, calculate it first")
        importance_scores = np.array([tensor[stat_type].item() for tensor in output_cosine_metrics.values()])
    elif metric == "weight-cosine":
        if not weight_cosine_metrics:
            raise ValueError("weight-cosine is miss, calculate it first")
        importance_scores = np.array([tensor[stat_type].item() for tensor in weight_cosine_metrics.values()])
    else:
        raise ValueError(f"unspported method {metric}")
    
    
    
    
    if stat_type in [0, 2]:  
        
        importance_scores = -importance_scores
    else:  
        
        pass

    
    if rounding_residuals is not None:
        importance_scores = importance_scores + (1e-3 * rounding_residuals)
    
    
    num_experts = 64
    num_layers = 26

    
    layer_names = [f"model.layers.{i}.mlp" for i in range(num_layers)]
    layer_max_experts = []
    
    for i, layer_name in enumerate(layer_names):
        if custom_max_experts and layer_name in custom_max_experts and optimize_order in ["prune_first", "merge_first"]:
            
            layer_max = custom_max_experts[layer_name]
        else:
            
            layer_max = 64
        layer_max_experts.append(layer_max)
    
    
    if optimize_order in ["prune_first", "merge_only"]:
        
        total_experts = sum(layer_max_experts)
        num_experts_to_keep = int(total_experts * merging_ratio)
    else:
        
        total_experts = num_layers * num_experts
        num_experts_to_keep = int(total_experts * merging_ratio)
    

    
    

    bounds = []
    initial_target_experts = []
    
    for i, layer_name in enumerate(layer_names):
        layer_max = layer_max_experts[i]
        
        initial_val = layer_max * merging_ratio
        initial_target_experts.append(initial_val)
        bounds.append((max(int(initial_val - min_margin), 16), layer_max))
    
    
    initial_target_experts = np.array(initial_target_experts)
    
    
    def constraint_total(x):
        return num_experts_to_keep - np.sum(x)
    
    
    cons = [{'type': 'eq', 'fun': constraint_total}]
    
    
    if max_neighbor_diff is not None and max_neighbor_diff >= 0:
        for i in range(num_layers - 1):
            
            def make_neighbor_constraint(i, j):
                def constraint(x):
                    return max_neighbor_diff - abs(x[i] - x[j])
                return constraint
            
            cons.append({'type': 'ineq', 'fun': make_neighbor_constraint(i, i+1)})
    
    
    if objective_type == "exp":
        
        def objective(x):
            return -np.sum(importance_scores * np.exp(x / 10))  

        def obj_grad(x):
            return -importance_scores * np.exp(x / 10) / 10

    elif objective_type == "log":
        
        def objective(x):
            return -np.sum(importance_scores * np.log(x + 1))  

        def obj_grad(x):
            return -importance_scores / (x + 1)

    elif objective_type == "square":
        
        def objective(x):
            return -np.sum(importance_scores * x**2)

        def obj_grad(x):
            return -2 * importance_scores * x
    
    elif objective_type == "linear":
        def objective(x):
            return -np.sum(importance_scores * x)

        def obj_grad(x):
            return -importance_scores
    else:  
        raise ValueError(f"Unsupported objective function type: {objective_type}，Use 'linear', 'exp', 'log', 'square'")
    
    
    res = minimize(
        fun=objective,
        x0=initial_target_experts,
        method="SLSQP",
        bounds=bounds,
        constraints=cons,  
        jac=obj_grad,
        options={"maxiter": 1000, "disp": True}
    )
    
    
    if not res.success:
        print(f"Failed to optimize: {res.message}")
    else:
        print(f"Optimized successfully: {res.message}")
    
    print(f"Iterations: {res.nit}")
    print(f"Optimized result: {res.x}")
    if hasattr(res, 'constr_violation'):
        print(f"Constraint violation: {res.constr_violation}")
    print(f"Objective value: {res.fun}")
    
    
    desired_counts_float = res.x
    rounded_counts = np.rint(desired_counts_float).astype(int)
    lower_bounds = np.array([max(int(initial_target_experts[i] - min_margin), 16) for i in range(num_layers)])
    upper_bounds = np.array(layer_max_experts)
    rounded_counts = np.clip(rounded_counts, lower_bounds, upper_bounds)
    rounding_residuals_out = desired_counts_float - rounded_counts.astype(float)

    print(f"Final each layer target experts: {rounded_counts}")
    print(f"Total target experts(rounded): {np.sum(rounded_counts)}，Target: {num_experts_to_keep}")

    if return_residuals:
        return rounded_counts, rounding_residuals_out
    return rounded_counts

@torch.inference_mode()
def dump_deepseek_v2_moe_metrics(
    model: DeepseekV2ForCausalLM,
    metric: list[str] = ["frequency", "routing-score", "weight-outlier", "output-outlier", "output-cosine", "weight-cosine", "weight-mean", "mutual_information"],
    dataloader = None,
    save_dir: str = "./results/metrics",
    save_name: str = "deepseek_v2_moe_metrics.pt",
    override: bool = True,
    sigmoid_t: float = 0.001,
):
    if os.path.exists(os.path.join(save_dir, save_name)) and not override:
        print(f"File {os.path.join(save_dir, save_name)} already exists.")
        return
    if "all" in metric:
        metric = ["frequency", "routing-score", "weight-outlier", "output-outlier", "output-cosine", "weight-cosine", "weight-mean", "mutual_information"]
    print(f"Get the metrics of each layer of Deepseek-V2-Lite: {metric}")
    
    metrics_results = {}
    
    num_layers = len(model.model.layers)
    num_experts_default = int(model.config.n_routed_experts)
    
    
    if "frequency" in metric and dataloader is not None:
        frequency_metric = {}
        
        for layer_idx in range(len(model.model.layers)):
            ffn_name = f"model.layers.{layer_idx}.mlp"
            frequency_metric[ffn_name] = torch.zeros(num_experts_default, dtype=torch.float32)

        model.eval()
        for p in model.parameters():
            p.requires_grad_(False)

        
        original_forwards = {}
        try:
            for layer_idx, layer in enumerate(model.model.layers):
                if isinstance(layer.mlp, DeepseekV2MoE):
                    moe_block = layer.mlp
                    original_forwards[layer_idx] = moe_block.forward

                    def _wrap_forward_factory(idx):
                        def _wrapped_forward(self, hidden_states):
                            topk_idx, topk_weight = self.gate(hidden_states)
                            flat_idx = topk_idx.view(-1)
                            if flat_idx.numel() > 0 and num_experts_default > 0:
                                binc = torch.bincount(flat_idx, minlength=num_experts_default).to(torch.float32)
                                key = f"model.layers.{idx}.mlp"
                                frequency_metric[key] += binc.cpu()
                            return self._original_forward(hidden_states)
                        return _wrapped_forward

                    moe_block._original_forward = moe_block.forward
                    moe_block.forward = _wrap_forward_factory(layer_idx).__get__(moe_block, type(moe_block))

            for batch in tqdm(dataloader, desc="Calculating frequency metric"):
                batch = {k: v.cuda() for k, v in batch.items()}
                if "labels" in batch:
                    batch.pop("labels")
                with torch.no_grad():
                    _ = model(**batch)
        finally:
            
            for layer_idx, orig in original_forwards.items():
                model.model.layers[layer_idx].mlp.forward = orig

        
        for k, v in frequency_metric.items():
            s = torch.sum(v)
            if s > 0:
                frequency_metric[k] = v / s
        metrics_results["frequency"] = frequency_metric
        print("Frequency metric calculated")
        torch.cuda.empty_cache()

    
    if "routing-score" in metric and dataloader is not None:
        routing_score_metric = {}
        token_counter = {}
        
        for layer_idx in range(len(model.model.layers)):
            ffn_name = f"model.layers.{layer_idx}.mlp"
            routing_score_metric[ffn_name] = torch.zeros(num_experts_default, dtype=torch.float32)
            token_counter[ffn_name] = 0

        model.eval()
        for p in model.parameters():
            p.requires_grad_(False)

        original_forwards = {}
        try:
            for layer_idx, layer in enumerate(model.model.layers):
                if isinstance(layer.mlp, DeepseekV2MoE):
                    moe_block = layer.mlp
                    original_forwards[layer_idx] = moe_block.forward

                    def _wrap_forward_factory(idx):
                        def _wrapped_forward(self, hidden_states):
                            bsz, seq_len, h = hidden_states.shape
                            flat = hidden_states.view(-1, h).to(torch.float32)
                            
                            logits = torch.nn.functional.linear(flat, self.gate.weight.to(torch.float32), None)
                            scores = torch.nn.functional.softmax(logits, dim=-1)
                            key = f"model.layers.{idx}.mlp"
                            if num_experts_default > 0:
                                routing_score_metric[key] += scores.sum(dim=0).to(torch.float32).cpu()
                            token_counter[key] += scores.shape[0]
                            return self._original_forward(hidden_states)
                        return _wrapped_forward

                    moe_block._original_forward = moe_block.forward
                    moe_block.forward = _wrap_forward_factory(layer_idx).__get__(moe_block, type(moe_block))

            for batch in tqdm(dataloader, desc="Calculating routing-score metric"):
                batch = {k: v.cuda() for k, v in batch.items()}
                if "labels" in batch:
                    batch.pop("labels")
                with torch.no_grad():
                    _ = model(**batch)
        finally:
            for layer_idx, orig in original_forwards.items():
                model.model.layers[layer_idx].mlp.forward = orig

        
        for k in routing_score_metric.keys():
            cnt = token_counter[k]
            if cnt > 0:
                routing_score_metric[k] = routing_score_metric[k] / float(cnt)
        metrics_results["routing-score"] = routing_score_metric
        print("Routing-score metric calculated")
        torch.cuda.empty_cache()

    
    if "weight-mean" in metric:
        weight_mean_metric = {}
        
        for layer_idx in range(len(model.model.layers)):
            ffn_name = f"model.layers.{layer_idx}.mlp"
            weight_mean_metric[ffn_name] = torch.zeros(num_experts_default)
        
        model.eval()
        for p in model.parameters():
            p.requires_grad_(False)
        
        
        expert_info = []
        for name, module in model.named_modules():
            
            
            if isinstance(module, nn.Linear) and any(proj_name in name for proj_name in ['gate_proj', 'up_proj', 'down_proj']) and 'shared' not in name:
                parts = name.split('.')
                if len(parts) >= 6:
                    layer_name = f"model.layers.{parts[2]}.mlp"
                    expert_idx = int(parts[5])
                    expert_info.append((layer_name, expert_idx, module))
        
        
        for layer_name, expert_idx, module in tqdm(expert_info, desc="Calculating weight-mean metric"):
            
            with torch.no_grad():
                weight = module.weight.data
                weight_mean = weight.abs().mean().item()
                weight_mean_metric[layer_name][expert_idx] = weight_mean
        
        metrics_results["weight-mean"] = weight_mean_metric
        print("Weight-Mean metric calculated")
        
        
        del weight_mean_metric, expert_info
        torch.cuda.empty_cache()

    
    if "weight-outlier" in metric:
        num_experts = num_experts_default
        layer_names = [f"model.layers.{i}.mlp" for i in range(len(model.model.layers))]
        
        
        outlier_score = {name: torch.zeros(num_experts) for name in layer_names}
        
        
        def compute_outlier_score(module):
            with torch.no_grad():
                weight = module.weight.data
                abs_weight = weight.abs()
                return torch.max(abs_weight.max(dim=0).values / abs_weight.mean(dim=0)).item()
        
        
        expert_info = []
        for name, module in model.named_modules():
            if isinstance(module, nn.Linear) and any(proj_name in name for proj_name in ['gate_proj', 'up_proj', 'down_proj']) and 'shared' not in name:
                parts = name.split('.')
                if len(parts) >= 6:
                    layer_name = f"model.layers.{parts[2]}.mlp"
                    expert_idx = int(parts[5])
                    expert_info.append((layer_name, expert_idx, module))
        
        
        for layer_name, expert_idx, module in tqdm(expert_info, desc="Calculating outlier metric"):
            score = compute_outlier_score(module)
            outlier_score[layer_name][expert_idx] += score

        metrics_results["weight-outlier"] = outlier_score
        print("Weight-Outlier metric calculated")
        
        
        del outlier_score, expert_info
        torch.cuda.empty_cache()

    
    if "output-outlier" in metric and dataloader is not None:
        output_outlier_metric = {}
        
        for layer_idx in range(len(model.model.layers)):
            ffn_name = f"model.layers.{layer_idx}.mlp"
            output_outlier_metric[ffn_name] = torch.zeros(num_experts_default)
        
        model.eval()
        for p in model.parameters():
            p.requires_grad_(False)
            
        
        def compute_output_outlier_score(expert_output):
            with torch.no_grad():
                abs_output = expert_output.abs()
                return torch.max(abs_output.max(dim=0).values / abs_output.mean(dim=0)).item()
        
        
        original_forwards = {}
        expert_outputs = {}
        
        def _custom_moe_forward(self, hidden_states):
            
            original_output = self._original_forward(hidden_states)
            
            
            batch_size, sequence_length, hidden_dim = hidden_states.shape
            hidden_states_view = hidden_states.view(-1, hidden_dim)
            
            
            topk_idx, topk_weight = self.gate(hidden_states)
            routing_indices = topk_idx.view(-1, self.num_experts_per_tok)
            
            
            current_batch_experts = []
            for expert_idx, expert in enumerate(self.experts):
                mask = (routing_indices == expert_idx).any(dim=-1)
                if mask.any():
                    expert_inputs = hidden_states_view[mask]
                    with torch.no_grad():
                        expert_output = expert(expert_inputs)
                    current_batch_experts.append(expert_output)
                else:
                    current_batch_experts.append(torch.tensor([]))
            
            
            expert_outputs[self._module_name] = current_batch_experts
            
            return original_output
        
        
        for name, module in model.named_modules():
            if isinstance(module, DeepseekV2MoE):
                expert_outputs[name] = []
                module._original_forward = module.forward
                module._module_name = name
                module.forward = _custom_moe_forward.__get__(module, type(module))
        
        
        for batch in tqdm(dataloader, desc="Calculating output-outlier metric"):
            batch = {k: v.cuda() for k, v in batch.items()}
            if "labels" in batch:
                batch.pop("labels")
            
            with torch.no_grad():
                model(**batch)
            
            
            for layer_name, expert_outputs_list in expert_outputs.items():
                for expert_idx, expert_output in enumerate(expert_outputs_list):
                    if len(expert_output) > 0:  
                        outlier_score = compute_output_outlier_score(expert_output)
                        output_outlier_metric[layer_name][expert_idx] += outlier_score
        
        
        for name, module in model.named_modules():
            if isinstance(module, DeepseekV2MoE):
                module.forward = module._original_forward
        
        metrics_results["output-outlier"] = output_outlier_metric
        print("Output-Outlier metric calculated")
        
        
        del output_outlier_metric, expert_outputs, original_forwards
        torch.cuda.empty_cache()

    
    if "output-cka" in metric and dataloader is not None:
        output_cka_metric = {}
        for layer_idx in range(len(model.model.layers)):
            ffn_name = f"model.layers.{layer_idx}.mlp"
            output_cka_metric[ffn_name] = torch.zeros(4)

        model.eval()
        for p in model.parameters():
            p.requires_grad_(False)

        def _linear_cka(X: torch.Tensor, Y: torch.Tensor, eps: float = 1e-10) -> float:
            X = X.to(torch.float32)
            Y = Y.to(torch.float32)
            X = X - X.mean(dim=0, keepdim=True)
            Y = Y - Y.mean(dim=0, keepdim=True)
            cross = torch.matmul(X.T, Y)
            num = cross.pow(2).sum()
            XX = torch.matmul(X.T, X)
            YY = torch.matmul(Y.T, Y)
            den = torch.sqrt((XX.pow(2).sum() + eps) * (YY.pow(2).sum() + eps))
            return (num / (den + eps)).item()

        
        expert_outputs: Dict[str, List[List[torch.Tensor]]] = {}

        def _custom_moe_forward_cka(self, hidden_states):
            original_output = self._original_forward(hidden_states)
            batch_size, sequence_length, hidden_dim = hidden_states.shape
            hidden_states_view = hidden_states.view(-1, hidden_dim)
            current_batch_experts = [[] for _ in range(len(self.experts))]
            for expert_idx, expert in enumerate(self.experts):
                with torch.no_grad():
                    expert_output = expert(hidden_states_view)
                current_batch_experts[expert_idx] = expert_output.detach().cpu()
            if self._module_name not in expert_outputs:
                expert_outputs[self._module_name] = [[] for _ in range(len(self.experts))]
            for i, outputs in enumerate(current_batch_experts):
                if outputs.numel() > 0:
                    expert_outputs[self._module_name][i].append(outputs)
            return original_output

        
        for name, module in model.named_modules():
            if isinstance(module, DeepseekV2MoE):
                module._original_forward = module.forward
                module._module_name = name
                module.forward = _custom_moe_forward_cka.__get__(module, type(module))

        
        for batch in tqdm(dataloader, desc="Collecting expert outputs for CKA"):
            batch = {k: v.cuda() for k, v in batch.items()}
            if "labels" in batch:
                batch.pop("labels")
            with torch.no_grad():
                model(**batch)

        
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        max_samples = 4096
        for layer_name, layer_expert_outputs in tqdm(expert_outputs.items(), desc='Calculating output-cka metric'):
            merged_outputs: List[Optional[torch.Tensor]] = []
            for expert_batch_outputs in layer_expert_outputs:
                if len(expert_batch_outputs) > 0:
                    merged = torch.cat(expert_batch_outputs, dim=0)
                    if merged.size(0) > max_samples:
                        merged = merged[:max_samples]
                    merged_outputs.append(merged)
                else:
                    merged_outputs.append(None)

            cka_values: List[float] = []
            num_experts_local = len(merged_outputs)
            for i in range(num_experts_local):
                for j in range(i + 1, num_experts_local):
                    if merged_outputs[i] is not None and merged_outputs[j] is not None:
                        Xi = merged_outputs[i].to(device)
                        Yj = merged_outputs[j].to(device)
                        
                        if Xi.dim() > 2:
                            Xi = Xi.reshape(Xi.shape[0], -1)
                        if Yj.dim() > 2:
                            Yj = Yj.reshape(Yj.shape[0], -1)
                        try:
                            val = _linear_cka(Xi, Yj)
                        except Exception:
                            val = 0.0
                        cka_values.append(float(val))
                    else:
                        cka_values.append(0.0)

            if cka_values:
                t = torch.tensor(cka_values)
                output_cka_metric[layer_name][0] = torch.max(t)
                output_cka_metric[layer_name][1] = torch.min(t)
                output_cka_metric[layer_name][2] = torch.mean(t)
                output_cka_metric[layer_name][3] = torch.std(t)

        
        for name, module in model.named_modules():
            if isinstance(module, DeepseekV2MoE):
                module.forward = module._original_forward

        metrics_results["output-cka"] = output_cka_metric
        print("Output-CKA metric calculated")
        torch.cuda.empty_cache()

    
    if "weight-cka" in metric:
        weight_cka_metric = {}
        for layer_idx in range(len(model.model.layers)):
            ffn_name = f"model.layers.{layer_idx}.mlp"
            weight_cka_metric[ffn_name] = torch.zeros(4)

        model.eval()
        for p in model.parameters():
            p.requires_grad_(False)

        def _linear_cka_2d(X: torch.Tensor, Y: torch.Tensor, eps: float = 1e-10) -> float:
            
            X = X.to(torch.float32)
            Y = Y.to(torch.float32)
            X = X - X.mean(dim=0, keepdim=True)
            Y = Y - Y.mean(dim=0, keepdim=True)
            cross = torch.matmul(X.T, Y)
            num = cross.pow(2).sum()
            XX = torch.matmul(X.T, X)
            YY = torch.matmul(Y.T, Y)
            den = torch.sqrt((XX.pow(2).sum() + eps) * (YY.pow(2).sum() + eps))
            return (num / (den + eps)).item()

        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        for layer_idx in tqdm(range(len(model.model.layers)), desc="Calculating weight-cka metric"):
            moe_block = model.model.layers[layer_idx].mlp
            experts = moe_block.experts
            cka_values: List[float] = []
            for i in range(len(experts)):
                for j in range(i + 1, len(experts)):
                    ei = experts[i]
                    ej = experts[j]
                    total_val = 0.0
                    for wi, wj in ((ei.gate_proj.weight, ej.gate_proj.weight), (ei.up_proj.weight, ej.up_proj.weight), (ei.down_proj.weight, ej.down_proj.weight)):
                        Xi = wi.detach().to(device)
                        Yj = wj.detach().to(device)
                        if Xi.dim() == 2 and Yj.dim() == 2 and Xi.size(1) == Yj.size(1):
                            val = _linear_cka_2d(Xi, Yj)
                        else:
                            
                            val = _linear_cka_2d(Xi.reshape(1, -1), Yj.reshape(1, -1))
                        total_val += float(val)
                    cka_values.append(total_val)

            if cka_values:
                t = torch.tensor(cka_values)
                ffn_name = f"model.layers.{layer_idx}.mlp"
                weight_cka_metric[ffn_name][0] = torch.max(t)
                weight_cka_metric[ffn_name][1] = torch.min(t)
                weight_cka_metric[ffn_name][2] = torch.mean(t)
                weight_cka_metric[ffn_name][3] = torch.std(t)

        metrics_results["weight-cka"] = weight_cka_metric
        print("Weight-CKA metric calculated")
        torch.cuda.empty_cache()

    
    if "output-cosine" in metric and dataloader is not None:
        output_cosine_metric = {}
        
        for layer_idx in range(1, len(model.model.layers)):
            ffn_name = f"model.layers.{layer_idx}.mlp"
            output_cosine_metric[ffn_name] = torch.zeros(4)  
        
        model.eval()
        
        
        def compute_cosine_similarity(X, Y):
            
            if X.dim() > 2:
                X = X.reshape(X.shape[0], -1).to(torch.float32)
            if Y.dim() > 2:
                Y = Y.reshape(Y.shape[0], -1).to(torch.float32)

            
            min_rows = min(X.shape[0], Y.shape[0])
            X = X[:min_rows]
            Y = Y[:min_rows]
            
            
            return torch.nn.functional.cosine_similarity(X, Y).mean().item()
        
        
        original_forwards = {}
        expert_outputs = {}
        
        def _custom_moe_forward_cosine(self, hidden_states):
            
            original_output = self._original_forward(hidden_states)
            
            
            batch_size, sequence_length, hidden_dim = hidden_states.shape
            hidden_states_view = hidden_states.view(-1, hidden_dim)
            
            
            current_batch_experts = [[] for _ in range(len(self.experts))]
            for expert_idx, expert in enumerate(self.experts):
                with torch.no_grad():
                    expert_output = expert(hidden_states_view)
                current_batch_experts[expert_idx] = expert_output.detach().cpu()
            
            
            if self._module_name not in expert_outputs:
                expert_outputs[self._module_name] = [[] for _ in range(len(self.experts))]
            
            for i, outputs in enumerate(current_batch_experts):
                if len(outputs) > 0:
                    expert_outputs[self._module_name][i].append(outputs)
            
            return original_output
        
        
        for name, module in model.named_modules():
            if isinstance(module, DeepseekV2MoE):
                module._original_forward = module.forward
                module._module_name = name
                module.forward = _custom_moe_forward_cosine.__get__(module, type(module))
        
        
        for batch in tqdm(dataloader, desc="Collecting expert outputs for cosine similarity"):
            batch = {k: v.cuda() for k, v in batch.items()}
            if "labels" in batch:
                batch.pop("labels")
            
            with torch.no_grad():
                model(**batch)
        
        
        for layer_name, layer_expert_outputs in tqdm(expert_outputs.items(), desc="Calculating output-cosine metric"):
            
            merged_outputs = []
            for expert_idx, expert_batch_outputs in enumerate(layer_expert_outputs):
                
                merged = torch.cat(expert_batch_outputs, dim=0).detach().cpu()  
                merged_outputs.append(merged)
            
            
            num_experts = len(merged_outputs)
            cosine_values = []
            
            for i in range(num_experts):
                for j in range(i+1, num_experts):
                    if merged_outputs[i] is not None and merged_outputs[j] is not None:
                        
                        
                        cosine_value = compute_cosine_similarity(merged_outputs[i], merged_outputs[j])
                        cosine_values.append(cosine_value)
                    else:
                        
                        cosine_values.append(0.0)
            
            
            if cosine_values:  
                cosine_tensor = torch.tensor(cosine_values)
                
                output_cosine_metric[layer_name][0] = torch.max(cosine_tensor)      
                output_cosine_metric[layer_name][1] = torch.min(cosine_tensor)      
                output_cosine_metric[layer_name][2] = torch.mean(cosine_tensor)     
                output_cosine_metric[layer_name][3] = torch.std(cosine_tensor)      
            
            
            del merged_outputs, cosine_values
        
        
        for name, module in model.named_modules():
            if isinstance(module, DeepseekV2MoE):
                module.forward = module._original_forward
        
        metrics_results["output-cosine"] = output_cosine_metric
        print("Output-Cosine metric calculated")
        
        
        del output_cosine_metric, expert_outputs
        torch.cuda.empty_cache()

    
    if "weight-cosine" in metric:
        weight_cosine_metric = {}
        
        for layer_idx in range(len(model.model.layers)):
            ffn_name = f"model.layers.{layer_idx}.mlp"
            weight_cosine_metric[ffn_name] = torch.zeros(4)  
        
        model.eval()
        
        
        def compute_cosine_similarity_optimized(X, Y):
            
            X_flat = X.reshape(-1).to(torch.float32)
            Y_flat = Y.reshape(-1).to(torch.float32)
            
            
            return torch.nn.functional.cosine_similarity(X_flat.unsqueeze(0), Y_flat.unsqueeze(0))[0]
        
        
        for layer_idx in tqdm(range(len(model.model.layers)), desc="Calculating weight-cosine metric"):
            layer = model.model.layers[layer_idx]
            moe_block = layer.mlp
            
            
            experts = moe_block.experts
            num_experts = len(experts)
            
            expert_weights = []
            for expert in experts:
                expert_w = {
                    "gate_proj": expert.gate_proj.weight,
                    "up_proj": expert.up_proj.weight,
                    "down_proj": expert.down_proj.weight
                }
                expert_weights.append(expert_w)
            
            
            cosine_values = []
            
            
            for i in range(num_experts):
                for j in range(i+1, num_experts):
                    
                    w1_cosine = compute_cosine_similarity_optimized(
                        expert_weights[i]["gate_proj"], expert_weights[j]["gate_proj"]
                    )
                    w2_cosine = compute_cosine_similarity_optimized(
                        expert_weights[i]["up_proj"], expert_weights[j]["up_proj"]
                    )
                    w3_cosine = compute_cosine_similarity_optimized(
                        expert_weights[i]["down_proj"], expert_weights[j]["down_proj"]
                    )
                    
                    
                    total_cosine = w1_cosine + w2_cosine + w3_cosine
                    
                    cosine_values.append(total_cosine.item())
            
            
            if cosine_values:  
                ffn_name = f"model.layers.{layer_idx}.mlp"
                cosine_tensor = torch.tensor(cosine_values)
                
                
                weight_cosine_metric[ffn_name][0] = torch.max(cosine_tensor)      
                weight_cosine_metric[ffn_name][1] = torch.min(cosine_tensor)      
                weight_cosine_metric[ffn_name][2] = torch.mean(cosine_tensor)     
                weight_cosine_metric[ffn_name][3] = torch.std(cosine_tensor)      
            
            
            del expert_weights, cosine_values
        
        metrics_results["weight-cosine"] = weight_cosine_metric
        print("Weight-Cosine metric calculated")
        
        
        del weight_cosine_metric, cosine_tensor
        torch.cuda.empty_cache()

    if "mutual_information" in metric:
        
        
        
        t = 1e-5
        try:
            json_path = os.path.join("importance_score", "layer_wise", "deepseek_v2_mutual_information_values.json")
            with open(json_path, "r", encoding="utf-8") as f:
                mi_data = json.load(f)
            mi_metric = {}
            num_layers = len(model.model.layers)
            for layer_idx in range(num_layers):
                if layer_idx == 0:
                    continue
                layer_key = f"layer_{layer_idx}"
                layer_name = f"model.layers.{layer_idx}.mlp"
                if layer_key in mi_data:
                    final_score = float(mi_data[layer_key].get("mutual_information_sum", 0.0))
                else:
                    final_score = 0.0
                final_score = 1 / (1 + np.exp(-t * final_score))
                final_score = 1 - final_score
                mi_metric[layer_name] = torch.tensor(final_score)
            metrics_results["mutual_information"] = mi_metric
            print("Mutual-Information metric loaded")
        except Exception as e:
            print(f"read mutual_information JSON failure: {e}")

    if not os.path.exists(save_dir):
        os.makedirs(save_dir)
    torch.save(metrics_results, os.path.join(save_dir, save_name))
    print(f"Dumped to {os.path.join(save_dir, save_name)}")
    return metrics_results

@torch.inference_mode()
def load_model_and_dump_metrics(
    model_name: str = "/Path/Deepseek-V2-Lite",
    calib_set: Optional[str] = "c4",
    opt_batch_size: Optional[int] = 16,
    n_sentences: Optional[int] = 16,
    metric_save_dir: Optional[str] = "./results/metrics",
    metrics: Optional[List[str]] = None,
    override: bool = False,
    sigmoid_t: float = 0.5,
):
    """Load DeepSeek-V2, compute MoE expert metrics on a calibration set, and save them.

    Runs entirely under ``torch.inference_mode()``. The model is loaded in
    float16 with ``device_map="auto"``. A calibration dataloader is built
    from ``calib_set`` with blocks of at most 2048 tokens, and
    ``dump_deepseek_v2_moe_metrics`` writes the metrics to
    ``metric_save_dir/deepseek_v2_moe_metrics.pt``.

    The model, tokenizer and dataloader are released, and the CUDA cache is
    emptied, even if metric computation raises.

    Args:
        model_name: Path or hub id passed to ``from_pretrained``.
        calib_set: Calibration dataset name forwarded to ``get_calib_dataloder``.
        opt_batch_size: Calibration batch size.
        n_sentences: Number of calibration blocks (``n_blocks_for_stat``).
        metric_save_dir: Directory the metrics file is written into.
        metrics: Metric names to compute; ``None`` means ``["all"]``.
        override: Forwarded to ``dump_deepseek_v2_moe_metrics``.
        sigmoid_t: Forwarded to ``dump_deepseek_v2_moe_metrics``.

    Returns:
        The metrics file name (not a full path), relative to ``metric_save_dir``.
    """
    # Avoid a shared mutable default; None keeps the historical ['all'] default.
    if metrics is None:
        metrics = ["all"]

    print(f"Get OWL Metric on {calib_set}.\n Model: {model_name}\n opt_batch_size={opt_batch_size}, n_sentences={n_sentences}")

    metrics_file = "deepseek_v2_moe_metrics.pt"

    tokenizer = AutoTokenizer.from_pretrained(model_name)
    # Batched calibration needs a pad token; reuse EOS.
    tokenizer.pad_token_id = tokenizer.eos_token_id

    # Pre-bind so the cleanup in `finally` is valid even if loading fails midway.
    model = None
    dataloader = None
    try:
        model = DeepseekV2ForCausalLM.from_pretrained(
            model_name,
            torch_dtype=torch.float16, device_map="auto"
        )
        model.eval()

        dataloader = get_calib_dataloder(
            dataset=calib_set,
            tokenizer=tokenizer,
            max_block_size=2048,
            n_blocks_for_stat=n_sentences,
            batch_size=opt_batch_size,
            num_workers=8,
        )

        dump_deepseek_v2_moe_metrics(
            model=model,
            metric=metrics,
            dataloader=dataloader,
            save_dir=metric_save_dir,
            save_name=metrics_file,
            override=override,
            sigmoid_t=sigmoid_t,
        )
    finally:
        # Free GPU memory held by the model even when metric computation fails.
        del model, tokenizer, dataloader
        torch.cuda.empty_cache()

    return metrics_file

def adapt_compress_deepseek(
    model_name: str = "/Path/Deepseek-V2-Lite",
    calib_set: Optional[str] = "c4",
    opt_batch_size: Optional[int] = 16,
    n_sentences: Optional[int] = 16,
    skip_metrics_calculation: bool = True,
    metrics: Optional[List[str]] = None,
    objective_type: Optional[str] = None,

    # Constraints shared by the prune and merge solvers.
    min_margin: Optional[int] = None,
    max_neighbor_diff: Optional[int] = None,

    # Pruning configuration.
    prune_method: str = "mutual_information",
    pruning_ratio: float = 0.75,
    sigmoid_t: float = 0.001,

    # Merging configuration.
    merge_method: str = "output-cosine",
    stat_type: int = 0,
    merging_ratio: float = 0.67,

    # One of "prune_first", "merge_first", "prune_only", "merge_only".
    optimize_order: str = "prune_first",

    # Output / caching.
    save_opt: bool = True,
    metric_save_dir: Optional[str] = "./output/metrics",
    opt_save_dir: Optional[str] = "./output/opt",
    override: bool = False,

    # Optional pre-loaded model/tokenizer. If either is missing, the model is
    # loaded from `model_name` when metrics need to be computed.
    model: Optional[DeepseekV2ForCausalLM] = None,
    tokenizer: Optional[AutoTokenizer] = None,
):
    """Compute per-layer expert prune/merge counts for DeepSeek-V2 and cache them as YAML.

    Steps:
      1. If a cached result YAML for this configuration exists (and
         ``override`` is false), load and return it.
      2. Unless ``skip_metrics_calculation`` is set, (re)compute the metrics
         file ``deepseek_v2_moe_metrics.pt`` in ``metric_save_dir`` when it is
         missing or ``override`` is true.
      3. Run ``cal_prune_num`` / ``cal_merge_num`` in the order given by
         ``optimize_order``. In the two-stage orders, the first stage's
         per-layer counts cap the second stage, and the first stage's
         rounding residuals are forwarded to it.
      4. Print a summary and, if ``save_opt``, write the result YAML.

    Returns:
        dict with keys ``"prune_experts"`` and ``"merge_experts"``. For
        ``prune_only`` both keys hold the prune result; for ``merge_only``
        both hold the merge result.

    Raises:
        ValueError: if ``optimize_order`` is not one of the supported values.
            This is checked before any metric computation.
    """
    # Validate first so a typo does not cost a full metric-collection pass.
    valid_orders = ("prune_first", "merge_first", "prune_only", "merge_only")
    if optimize_order not in valid_orders:
        raise ValueError(f"Unsupported optimize order: {optimize_order}，Please choose 'prune_first', 'merge_first', 'prune_only' or 'merge_only'")

    # Avoid a shared mutable default; None keeps the historical ['all'] default.
    if metrics is None:
        metrics = ["all"]

    stat_type_str = STAT_TYPE_DICT.get(stat_type, "") if stat_type is not None else ""
    opt_save_dir = os.path.join(opt_save_dir, optimize_order)

    # The cache file name encodes every solver-relevant setting.
    if stat_type_str:
        result_file_path = os.path.join(opt_save_dir, f"deepseek_v2_{optimize_order}_{prune_method}_{pruning_ratio}_{stat_type_str}{merge_method}_{merging_ratio}_{objective_type}.yaml")
    else:
        result_file_path = os.path.join(opt_save_dir, f"deepseek_v2_{optimize_order}_{prune_method}_{pruning_ratio}_{merge_method}_{merging_ratio}_{objective_type}.yaml")

    if os.path.exists(result_file_path) and not override:
        with open(result_file_path, 'r') as f:
            result_dict = yaml.safe_load(f)
        # An empty or partial file (e.g. from an interrupted earlier run)
        # makes safe_load return None or a dict missing keys; recompute
        # instead of crashing.
        if (
            isinstance(result_dict, dict)
            and "merge_experts" in result_dict
            and "prune_experts" in result_dict
        ):
            return {
                "merge_experts": np.array(result_dict["merge_experts"]),
                "prune_experts": np.array(result_dict["prune_experts"]),
            }
        print(f"Ignoring malformed cached result at {result_file_path}; recomputing")

    metrics_file = "deepseek_v2_moe_metrics.pt"
    metrics_file_path = os.path.join(metric_save_dir, metrics_file)
    if not skip_metrics_calculation and (not os.path.exists(metrics_file_path) or override):
        if model is None or tokenizer is None:
            metrics_file = load_model_and_dump_metrics(
                model_name=model_name,
                calib_set=calib_set,
                opt_batch_size=opt_batch_size,
                n_sentences=n_sentences,
                metric_save_dir=metric_save_dir,
                metrics=metrics,
                override=override,
                sigmoid_t=sigmoid_t,
            )
        else:
            # Reuse the caller's already-loaded model and tokenizer.
            dataloader = get_calib_dataloder(
                dataset=calib_set,
                tokenizer=tokenizer,
                max_block_size=2048,
                n_blocks_for_stat=n_sentences,
                batch_size=opt_batch_size,
                num_workers=8,
            )
            dump_deepseek_v2_moe_metrics(
                model=model,
                metric=metrics,
                dataloader=dataloader,
                save_dir=metric_save_dir,
                save_name=metrics_file,
                override=override,
                sigmoid_t=sigmoid_t,
            )

    # Hard-coded model geometry, presumably DeepSeek-V2-Lite's MoE layers x
    # routed experts. TODO(review): confirm, or derive from the model config.
    num_layers = 26
    num_experts = 64

    def _prune(**extra):
        # cal_prune_num with the arguments shared by every optimize order.
        return cal_prune_num(
            save_dir=metric_save_dir,
            save_name=metrics_file,
            metric=prune_method,
            stat_type=stat_type,
            min_margin=min_margin,
            pruning_ratio=pruning_ratio,
            optimize_order=optimize_order,
            max_neighbor_diff=max_neighbor_diff,
            objective_type=objective_type,
            **extra,
        )

    def _merge(**extra):
        # cal_merge_num with the arguments shared by every optimize order.
        return cal_merge_num(
            save_dir=metric_save_dir,
            save_name=metrics_file,
            metric=merge_method,
            stat_type=stat_type,
            min_margin=min_margin,
            merging_ratio=merging_ratio,
            optimize_order=optimize_order,
            objective_type=objective_type,
            max_neighbor_diff=max_neighbor_diff,
            **extra,
        )

    def _per_layer_caps(nums):
        # Map "model.layers.{i}.mlp" -> first-stage count, used to cap the second stage.
        # NOTE(review): keys cover layers 0..num_layers-1. The mutual_information
        # and output-cosine metrics in this file are only recorded from layer 1
        # onward. Confirm that cal_prune_num / cal_merge_num read these caps
        # positionally; if they look them up by name, the caps are shifted by
        # one layer.
        return {f"model.layers.{layer_idx}.mlp": nums[layer_idx] for layer_idx in range(num_layers)}

    results = {}
    time_start = time.time()
    if optimize_order == "prune_first":
        prune_nums, rounding_residuals = _prune(return_residuals=True)
        results["prune_experts"] = prune_nums
        results["merge_experts"] = _merge(
            custom_max_experts=_per_layer_caps(prune_nums),
            rounding_residuals=rounding_residuals,
        )
    elif optimize_order == "merge_first":
        merge_nums, rounding_residuals = _merge(return_residuals=True)
        results["merge_experts"] = merge_nums
        results["prune_experts"] = _prune(
            custom_max_experts=_per_layer_caps(merge_nums),
            rounding_residuals=rounding_residuals,
        )
    elif optimize_order == "prune_only":
        prune_nums = _prune()
        results["prune_experts"] = prune_nums
        results["merge_experts"] = prune_nums
    else:  # "merge_only" (validated above)
        target_experts = _merge()
        results["merge_experts"] = target_experts
        results["prune_experts"] = target_experts

    print("\n===== Optimized result summary =====")
    print(f"prune_experts: {sum(results['prune_experts'])}")
    print(f"merge_experts: {sum(results['merge_experts'])}")
    print(f"Compression rate: {sum(results['merge_experts'])/(num_experts * num_layers):.4f}")
    print(f"Optimization time: {time.time() - time_start:.2f}s")

    if save_opt:
        # Build the payload before touching the file, then write to a temp file
        # and atomically replace. A failure can then never leave a truncated
        # cache that the early-return path above would pick up.
        results_dict = {
            "prune_experts": np.asarray(results["prune_experts"]).tolist(),
            "merge_experts": np.asarray(results["merge_experts"]).tolist(),
        }
        os.makedirs(opt_save_dir, exist_ok=True)
        tmp_path = f"{result_file_path}.tmp"
        with open(tmp_path, "w") as f:
            yaml.dump(results_dict, f, default_flow_style=False, sort_keys=False, indent=4)
        os.replace(tmp_path, result_file_path)

    return results

if __name__ == "__main__":
    # Script entry point: python-fire turns adapt_compress_deepseek's keyword
    # arguments into command-line flags (e.g. --optimize_order=merge_first).
    Fire(component=adapt_compress_deepseek)
