import os
import numpy as np
from typing import Optional, Dict, List
import time
import logging
import torch
import torch.nn as nn
import random
import yaml
from fire import Fire
from types import MethodType
from transformers import AutoTokenizer, Qwen2MoeForCausalLM 
from transformers.models.qwen2_moe.modeling_qwen2_moe import Qwen2MoeSparseMoeBlock
from eval import get_calib_dataloder
from tqdm import tqdm
from scipy.optimize import minimize
import matplotlib.pyplot as plt
from utils import linear_CKA, batch_linear_CKA
logger = logging.getLogger(__name__)

# Statistic index -> display name: 0 = max, 1 = min, 2 = mean, 3 = std.
STAT_TYPE_DICT = dict(enumerate(("max", "min", "mean", "std")))



def check_metrics_mat(
        metrics_path: str
):
    """Print a short preview of every metric stored in a saved metrics file.

    The file is read with ``torch.load`` and is expected to hold a mapping of
    metric name -> metric data. For each entry the preview is:

    * ``dict``: the first five ``key: value`` pairs,
    * ``torch.Tensor``: the first five entries along dim 0 (a 0-dim tensor is
      printed whole, since it cannot be sliced),
    * anything else: its type and at most 200 characters of ``str(value)``.

    Args:
        metrics_path: Path of the metrics file to inspect.

    Returns:
        None. A missing file is reported on stdout instead of raising.
    """
    if not os.path.exists(metrics_path):
        print(f"Metrics file not found at {metrics_path}")
        return

    # Inspection only: map storages to CPU so that a file written on a GPU
    # host can still be read on a machine without CUDA.
    metrics = torch.load(metrics_path, map_location="cpu")
    print(f"Loaded metrics from {metrics_path}")

    for metric_name, metric_data in metrics.items():
        print(f"\n{'-'*20} {metric_name} {'-'*20}")

        if isinstance(metric_data, dict):
            for layer_name, values in list(metric_data.items())[:5]:
                print(f"{layer_name}: {values}")
        elif isinstance(metric_data, torch.Tensor):
            # Slicing a 0-dim tensor raises IndexError, so print scalars as-is.
            print(metric_data if metric_data.dim() == 0 else metric_data[:5])
        else:
            print(f"Type: {type(metric_data)}")
            text = str(metric_data)
            print(text[:200] + "..." if len(text) > 200 else text)

    print("\nMetrics check completed.")

# Display names for the statistic selected by ``stat_type``.
_STAT_NAMES = {0: "max", 1: "min", 2: "mean", 3: "std"}

# Metrics whose per-layer tensor is indexed by ``stat_type``.
_SIMILARITY_METRICS = ("output-cka", "weight-cka", "output-cosine", "weight-cosine")

# Metrics that are scaled by the per-layer frequency factor before normalising.
_FREQUENCY_WEIGHTED_METRICS = ("weight-outlier", "output-outlier", "combined", "weight-mean")

# Smallest number of experts the optimiser may assign to a layer.
_MIN_EXPERTS_PER_LAYER = 2

# Margin used when ``min_margin`` is None (the default of ``cal_merge_num``).
_DEFAULT_MIN_MARGIN = 2


def _config_value(model_config, key):
    """Read ``key`` from a config given as a mapping or as an attribute object."""
    if model_config is None:
        raise ValueError(
            "model_config is required (it must provide num_experts and num_hidden_layers)"
        )
    if isinstance(model_config, dict):
        return model_config[key]
    return getattr(model_config, key)


def _min_max_normalize(values):
    """Rescale ``values`` to [0, 1]; the epsilon guards against a zero range."""
    return (values - np.min(values)) / (np.max(values) - np.min(values) + 1e-10)


def _summed_layer_scores(metrics, name):
    """Return one score per layer: the sum of that layer's tensor in ``metrics[name]``."""
    table = metrics.get(name, {})
    if not table:
        raise ValueError(f"{name} is miss, calculate it first")
    return np.array([tensor.sum().item() for tensor in table.values()])


def _similarity_layer_scores(metrics, metric, stat_type):
    """Return one score per layer taken from entry ``stat_type`` of each tensor.

    Statistics 0 (max) and 2 (mean) are negated, so a layer with a larger
    value gets a smaller score; min and std are used unchanged.
    """
    table = metrics.get(metric, {})
    if not table:
        raise ValueError(f"{metric} is miss, calculate it first")
    scores = np.array([tensor[stat_type].item() for tensor in table.values()])
    return -scores if stat_type in (0, 2) else scores


def _layer_expert_limits(model_config, custom_max_experts, optimize_order):
    """Return ``(num_experts, per-layer maximum expert counts)``.

    A layer's entry in ``custom_max_experts`` (keyed ``model.layers.<i>.mlp``)
    replaces ``num_experts`` only for the two-stage orders.
    """
    num_experts = _config_value(model_config, "num_experts")
    num_layers = _config_value(model_config, "num_hidden_layers")
    use_custom = bool(custom_max_experts) and optimize_order in ("merge_first", "prune_first")

    limits = []
    for layer_idx in range(num_layers):
        layer_name = f"model.layers.{layer_idx}.mlp"
        if use_custom and layer_name in custom_max_experts:
            limits.append(custom_max_experts[layer_name])
        else:
            limits.append(num_experts)
    return num_experts, limits


def _allocate_experts(
    importance_scores,
    layer_max_experts,
    ratio,
    min_margin,
    budget,
    max_neighbor_diff,
    objective_type,
):
    """Distribute ``budget`` experts over the layers and return integer counts.

    SLSQP maximises an importance-weighted objective subject to
    ``sum(x) == budget`` and per-layer bounds, starting from the uniform
    allocation ``layer_max * ratio``. The continuous solution is floored,
    clamped to the same integer bounds the optimiser used, and the rounding
    gap is then closed greedily: surplus budget goes to the most important
    layers with spare capacity, an overshoot is taken from the least
    important layers that are above their lower bound.
    """
    num_layers = len(layer_max_experts)
    if len(importance_scores) != num_layers:
        raise ValueError(
            f"expected {num_layers} layer scores for the model, got {len(importance_scores)}"
        )

    upper = np.asarray(layer_max_experts, dtype=int)
    start = upper * ratio
    # Integer lower bound (truncation matches int()), never above the layer max.
    lower = np.minimum(
        np.maximum((start - min_margin).astype(int), _MIN_EXPERTS_PER_LAYER),
        upper,
    )
    bounds = list(zip(lower.tolist(), upper.tolist()))

    constraints = [{'type': 'eq', 'fun': lambda x: budget - np.sum(x)}]
    if max_neighbor_diff is not None and max_neighbor_diff >= 0:
        # Adjacent layers may differ by at most max_neighbor_diff experts.
        # The index is bound as a default to avoid late-binding in the loop.
        for idx in range(num_layers - 1):
            constraints.append({
                'type': 'ineq',
                'fun': lambda x, i=idx: max_neighbor_diff - abs(x[i] - x[i + 1]),
            })

    # minimize() minimises, so every objective is the negated weighted sum.
    if objective_type == "exp":
        def objective(x):
            return -np.sum(importance_scores * np.exp(x / 10))

        def obj_grad(x):
            return -importance_scores * np.exp(x / 10) / 10
    elif objective_type == "log":
        def objective(x):
            return -np.sum(importance_scores * np.log(x + 1))

        def obj_grad(x):
            return -importance_scores / (x + 1)
    elif objective_type == "square":
        def objective(x):
            return -np.sum(importance_scores * x**2)

        def obj_grad(x):
            return -2 * importance_scores * x
    else:
        # Any other value selects the linear objective.
        def objective(x):
            return -np.sum(importance_scores * x)

        def obj_grad(x):
            return -importance_scores

    res = minimize(
        fun=objective,
        x0=start,
        method="SLSQP",
        bounds=bounds,
        constraints=constraints,
        jac=obj_grad,
        options={"maxiter": 1000, "disp": True}
    )

    if not res.success:
        print(f"Failed to optimize: {res.message}")
    else:
        print(f"Optimized successfully: {res.message}")

    print(f"Iterations: {res.nit}")
    print(f"Optimized result: {res.x}")
    if hasattr(res, 'constr_violation'):
        print(f"Constraint violation: {res.constr_violation}")
    print(f"Objective value: {res.fun}")

    # Round down, then keep every layer inside its integer bounds.
    counts = np.clip(np.floor(res.x).astype(int), lower, upper)

    remaining = int(budget - counts.sum())
    if remaining > 0:
        # Hand the leftover budget to the most important layers first.
        for idx in np.argsort(-importance_scores):
            room = int(upper[idx] - counts[idx])
            if room > 0:
                addition = min(remaining, room)
                counts[idx] += addition
                remaining -= addition
                if remaining == 0:
                    break
    elif remaining < 0:
        # Take the overshoot from the least important layers first.
        for idx in np.argsort(importance_scores):
            slack = int(counts[idx] - lower[idx])
            if slack > 0:
                reduction = min(-remaining, slack)
                counts[idx] -= reduction
                remaining += reduction
                if remaining == 0:
                    break

    return counts


def cal_prune_num(
    save_dir: str = "./results/metrics",
    save_name: str = "qwen_1.5_moe_metrics.pt",
    metric: str = "output-outlier",  
    stat_type: int = 3,  
    min_margin: int = None,  
    pruning_ratio: float = 0.75,
    custom_max_experts: dict = None,  
    optimize_order: str = "prune_first",  
    max_neighbor_diff: int = None,  
    objective_type: str = "log",  
    model_config: dict = None,
) -> np.ndarray:
    """Compute how many experts each layer keeps after pruning.

    Per-layer importance scores are read from the saved metrics file,
    min-max normalised, and used to split the global budget
    ``int(sum(layer_max) * pruning_ratio)`` across the layers.

    Args:
        save_dir: Directory that contains the metrics file.
        save_name: File name of the metrics file (read with ``torch.load``).
        metric: Importance metric. ``weight-outlier``, ``output-outlier``,
            ``combined`` and ``weight-mean`` sum each layer's tensor and
            multiply by a factor derived from the ``frequency`` metric;
            ``output-cka``, ``weight-cka``, ``output-cosine`` and
            ``weight-cosine`` use entry ``stat_type`` of each layer's tensor;
            ``adapt_down`` sums each layer's tensor.
        stat_type: Statistic index (0 max, 1 min, 2 mean, 3 std) for the
            CKA/cosine metrics. ``None`` means 3.
        min_margin: How far below the uniform allocation
            ``layer_max * pruning_ratio`` a layer may go. ``None`` falls back
            to 2, the default of ``cal_merge_num``.
        pruning_ratio: Fraction of experts to keep overall.
        custom_max_experts: Optional ``{"model.layers.<i>.mlp": max}`` map of
            per-layer maxima; only honoured when ``optimize_order`` is
            ``"merge_first"`` or ``"prune_first"``.
        optimize_order: Pipeline order; here it only controls whether
            ``custom_max_experts`` is honoured.
        max_neighbor_diff: If not ``None`` and ``>= 0``, the largest allowed
            difference between adjacent layers in the optimisation.
        objective_type: ``"exp"``, ``"log"`` or ``"square"``; any other value
            selects a linear objective.
        model_config: Object or dict providing ``num_experts`` and
            ``num_hidden_layers``.

    Returns:
        Integer array with one entry per layer.

    Raises:
        ValueError: Unknown ``metric``, a required metric missing from the
            file, missing ``model_config``, or a score count that differs
            from the number of layers.
    """
    if stat_type is None:
        stat_type = 3
    if min_margin is None:
        min_margin = _DEFAULT_MIN_MARGIN

    # Only scalars are extracted below, so loading on CPU is sufficient.
    metrics = torch.load(os.path.join(save_dir, save_name), map_location="cpu")

    if metric in _FREQUENCY_WEIGHTED_METRICS:
        frequency_scores = _summed_layer_scores(metrics, "frequency")
        frequency_factor = (1 + frequency_scores) / (frequency_scores + 1e-10)
        if metric == "combined":
            raw_scores = (
                frequency_factor
                * _summed_layer_scores(metrics, "output-outlier")
                * _summed_layer_scores(metrics, "weight-outlier")
            )
        else:
            raw_scores = frequency_factor * _summed_layer_scores(metrics, metric)
    elif metric in _SIMILARITY_METRICS:
        print(f"use {metric}'s {_STAT_NAMES.get(stat_type, 'std')} as the important score")
        raw_scores = _similarity_layer_scores(metrics, metric, stat_type)
    elif metric == "adapt_down":
        raw_scores = _summed_layer_scores(metrics, "adapt_down")
    else:
        raise ValueError(f"unspported method {metric}，请使用 'weight-outlier', 'output-outlier', 'combined', 'weight-mean', 'output-cka', 'weight-cka', 'output-cosine', 'weight-cosine'")
    importance_scores = _min_max_normalize(raw_scores)

    _, layer_max_experts = _layer_expert_limits(model_config, custom_max_experts, optimize_order)
    num_experts_to_keep = int(sum(layer_max_experts) * pruning_ratio)

    prune_num = _allocate_experts(
        importance_scores,
        layer_max_experts,
        pruning_ratio,
        min_margin,
        num_experts_to_keep,
        max_neighbor_diff,
        objective_type,
    )

    print(f"Final each layer experts: {prune_num}")
    print(f"Total experts: {np.sum(prune_num)}/{num_experts_to_keep}")

    return prune_num


def cal_merge_num(
    save_dir: str = "./results/metrics",
    save_name: str = "qwen_1.5_moe_metrics.pt",
    metric: str = "output-cka",  
    stat_type: int = 3,  
    min_margin: int = 2,  
    merging_ratio: float = 0.75,  
    custom_max_experts: dict = None,  
    optimize_order: str = "prune_first",  
    max_neighbor_diff: int = None,  
    objective_type: str = "log",  
    model_config: dict = None,
) -> np.ndarray:
    """Compute the target number of experts per layer after merging.

    Works like ``cal_prune_num`` but accepts only the CKA/cosine metrics and
    uses their scores without min-max normalisation.

    Args:
        save_dir: Directory that contains the metrics file.
        save_name: File name of the metrics file (read with ``torch.load``).
        metric: One of ``output-cka``, ``weight-cka``, ``output-cosine``,
            ``weight-cosine``.
        stat_type: Statistic index (0 max, 1 min, 2 mean, 3 std); scores for
            0 and 2 are negated. ``None`` means 3.
        min_margin: How far below the uniform allocation
            ``layer_max * merging_ratio`` a layer may go. ``None`` means 2.
        merging_ratio: Fraction of experts to keep overall.
        custom_max_experts: Optional ``{"model.layers.<i>.mlp": max}`` map of
            per-layer maxima; only honoured when ``optimize_order`` is
            ``"prune_first"`` or ``"merge_first"``.
        optimize_order: For ``"prune_first"`` and ``"merge_only"`` the budget
            is ``int(sum(layer_max) * merging_ratio)``; otherwise it is based
            on ``num_layers * num_experts``.
        max_neighbor_diff: If not ``None`` and ``>= 0``, the largest allowed
            difference between adjacent layers in the optimisation.
        objective_type: ``"exp"``, ``"log"`` or ``"square"``; any other value
            selects a linear objective.
        model_config: Object or dict providing ``num_experts`` and
            ``num_hidden_layers``.

    Returns:
        Integer array with one entry per layer.

    Raises:
        ValueError: Unknown ``metric``, the metric missing from the file,
            missing ``model_config``, or a score count that differs from the
            number of layers.
    """
    if stat_type is None:
        stat_type = 3
    if min_margin is None:
        min_margin = _DEFAULT_MIN_MARGIN

    # Only scalars are extracted below, so loading on CPU is sufficient.
    metrics = torch.load(os.path.join(save_dir, save_name), map_location="cpu")

    print(f"use {metric}'s {_STAT_NAMES.get(stat_type, 'std')} as the important score")

    if metric not in _SIMILARITY_METRICS:
        raise ValueError(f"unspported method {metric}")
    importance_scores = _similarity_layer_scores(metrics, metric, stat_type)

    num_experts, layer_max_experts = _layer_expert_limits(
        model_config, custom_max_experts, optimize_order
    )

    if optimize_order in ("prune_first", "merge_only"):
        total_experts = sum(layer_max_experts)
    else:
        # Budget is relative to the full expert count, not the custom maxima.
        total_experts = len(layer_max_experts) * num_experts
    num_experts_to_keep = int(total_experts * merging_ratio)

    target_experts = _allocate_experts(
        importance_scores,
        layer_max_experts,
        merging_ratio,
        min_margin,
        num_experts_to_keep,
        max_neighbor_diff,
        objective_type,
    )

    print(f"Final each layer target experts: {target_experts}")
    print(f"Total target experts: {np.sum(target_experts)}/{num_experts_to_keep}")

    return target_experts

@torch.inference_mode()
def dump_qwen_moe_metrics(
    model: Qwen2MoeForCausalLM,
    metric: list[str] = ["frequency", "routing-score", "weight-outlier", "output-outlier", "output-cosine", "weight-cosine", "weight-mean"],
    dataloader = None,
    save_dir: str = "./results/metrics",
    save_name: str = "qwen_moe_metrics.pt",
    override: bool = True,
):
    if os.path.exists(os.path.join(save_dir, save_name)) and not override:
        print(f"File {os.path.join(save_dir, save_name)} already exists.")
        return
    if "all" in metric:
        metric = ["frequency", "routing-score", "weight-outlier", "output-outlier", "output-cosine", "weight-cosine", "weight-mean"]
    print(f"Get {metric} metrics for each layer of Qwen")
    
    metrics_results = {}
    
    
    if "frequency" in metric and dataloader is not None:
        frequency_metric = {}
        
        for layer_idx in range(len(model.model.layers)):
            ffn_name = f"model.layers.{layer_idx}.mlp"
            frequency_metric[ffn_name] = torch.zeros(model.config.num_experts)
        
        model.eval()
        for p in model.parameters():
            p.requires_grad_(False)
            
        for batch in tqdm(dataloader, desc="Calculating frequency metric"):
            batch = {k: v.cuda() for k, v in batch.items()}
            if "labels" in batch:
                
                batch.pop("labels")
            with torch.no_grad():
                outputs = model(**batch, output_router_logits=True)
            all_router_logits = outputs.router_logits
            all_router_logits = torch.stack(all_router_logits)  
            selected_experts = torch.topk(all_router_logits, 2, dim=-1)[1].reshape(
                model.config.num_hidden_layers, -1
            )  
            
            for layer_idx in range(len(model.model.layers)):
                ffn_name = f"model.layers.{layer_idx}.mlp"
                unique, counts = torch.unique(selected_experts[layer_idx], return_counts=True)
                frequency_metric[ffn_name][unique.cpu()] += counts.cpu()
        
        
        frequency_metric = {
            k: v / torch.sum(v) for k, v in frequency_metric.items()
        }
        metrics_results["frequency"] = frequency_metric
        print("Frequency metric calculated")
        
        
        del frequency_metric, all_router_logits, selected_experts
        torch.cuda.empty_cache()

    
    if "routing-score" in metric and dataloader is not None:
        routing_score_metric = {}
        
        for layer_idx in range(len(model.model.layers)):
            ffn_name = f"model.layers.{layer_idx}.mlp"
            routing_score_metric[ffn_name] = torch.zeros(model.config.num_experts)
        
        model.eval()
        for p in model.parameters():
            p.requires_grad_(False)
            
        for batch in tqdm(dataloader, desc="Calculating routing-score metric"):
            batch = {k: v.cuda() for k, v in batch.items()}
            if "labels" in batch:
                batch.pop("labels")
            with torch.no_grad():
                outputs = model(**batch, output_router_logits=True)
            all_router_logits = outputs.router_logits
            
            for layer_idx in range(len(model.model.layers)):
                ffn_name = f"model.layers.{layer_idx}.mlp"
                router_score = torch.nn.functional.softmax(all_router_logits[layer_idx], dim=1)
                scores = router_score.float().sum(0) / router_score.shape[0]
                routing_score_metric[ffn_name] += scores.cpu()
        
        metrics_results["routing-score"] = routing_score_metric
        print("Routing-score metric calculated")
        
        
        del routing_score_metric, all_router_logits, router_score, scores
        torch.cuda.empty_cache()

    
    if "weight-mean" in metric:
        weight_mean_metric = {}
        
        for layer_idx in range(len(model.model.layers)):
            ffn_name = f"model.layers.{layer_idx}.mlp"
            weight_mean_metric[ffn_name] = torch.zeros(model.config.num_experts)
        
        model.eval()
        for p in model.parameters():
            p.requires_grad_(False)
        
        
        expert_info = []
        for name, module in model.named_modules():
            if ('mlp' in name and 'gate' not in name and 
                isinstance(module, torch.nn.Linear) and 'experts' in name):
                parts = name.split('.')
                if len(parts) >= 6:
                    layer_name = f"model.layers.{parts[2]}.mlp"
                    expert_idx = int(parts[5])
                    expert_info.append((layer_name, expert_idx, module))
        
        
        for layer_name, expert_idx, module in tqdm(expert_info, desc="Calculating weight-mean metric"):
            
            with torch.no_grad():
                weight = module.weight.data
                weight_mean = weight.abs().mean().item()
                weight_mean_metric[layer_name][expert_idx] = weight_mean
        
        metrics_results["weight-mean"] = weight_mean_metric
        print("Weight-Mean metric calculated")
        
        
        del weight_mean_metric, expert_info
        torch.cuda.empty_cache()

    
    if "weight-outlier" in metric:
        num_experts = model.config.num_experts
        layer_names = [f"model.layers.{i}.mlp" for i in range(len(model.model.layers))]
        
        
        outlier_score = {name: torch.zeros(num_experts) for name in layer_names}
        
        
        def compute_outlier_score(module):
            with torch.no_grad():
                weight = module.weight.data
                abs_weight = weight.abs()
                return torch.max(abs_weight.max(dim=0).values / abs_weight.mean(dim=0)).item()
        
        
        expert_info = []
        for name, module in model.named_modules():
            if ('mlp' in name and 'gate' not in name and 
                isinstance(module, torch.nn.Linear) and 'experts' in name):
                parts = name.split('.')
                if len(parts) >= 6:
                    layer_name = f"model.layers.{parts[2]}.mlp"
                    expert_idx = int(parts[5])
                    expert_info.append((layer_name, expert_idx, module))
        
        
        for layer_name, expert_idx, module in tqdm(expert_info, desc="Calculating outlier metric"):
            score = compute_outlier_score(module)
            outlier_score[layer_name][expert_idx] += score

        metrics_results["weight-outlier"] = outlier_score
        print("Weight-Outlier metric calculated")
        
        
        del outlier_score, expert_info
        torch.cuda.empty_cache()

    
    if "output-outlier" in metric and dataloader is not None:
        output_outlier_metric = {}
        
        for layer_idx in range(len(model.model.layers)):
            ffn_name = f"model.layers.{layer_idx}.mlp"
            output_outlier_metric[ffn_name] = torch.zeros(model.config.num_experts)
        
        model.eval()
        for p in model.parameters():
            p.requires_grad_(False)
            
        
        def compute_output_outlier_score(expert_output):
            with torch.no_grad():
                abs_output = expert_output.abs()
                return torch.max(abs_output.max(dim=0).values / abs_output.mean(dim=0)).item()
        
        
        original_forwards = {}
        expert_outputs = {}
        
        def _custom_moe_forward(self, hidden_states):
            
            original_output = self._original_forward(hidden_states)
            
            
            batch_size, sequence_length, hidden_dim = hidden_states.shape
            hidden_states_view = hidden_states.view(-1, hidden_dim)
            
            
            router_logits = self.gate(hidden_states_view)
            routing_weights, routing_indices = torch.topk(router_logits, self.top_k, dim=-1)
            
            
            current_batch_experts = []
            for expert_idx, expert in enumerate(self.experts):
                mask = (routing_indices == expert_idx).any(dim=-1)
                if mask.any():
                    expert_inputs = hidden_states_view[mask]
                    with torch.no_grad():
                        expert_output = expert(expert_inputs)
                    current_batch_experts.append(expert_output)
                else:
                    current_batch_experts.append(torch.tensor([]))
            
            
            expert_outputs[self._module_name] = current_batch_experts
            
            return original_output
        
        
        for name, module in model.named_modules():
            if isinstance(module, Qwen2MoeSparseMoeBlock):
                expert_outputs[name] = []
                module._original_forward = module.forward
                module._module_name = name
                module.forward = _custom_moe_forward.__get__(module, type(module))
        
        
        for batch in tqdm(dataloader, desc="Calculating output-outlier metric"):
            batch = {k: v.cuda() for k, v in batch.items()}
            if "labels" in batch:
                batch.pop("labels")
            
            with torch.no_grad():
                model(**batch)
            
            
            for layer_name, expert_outputs_list in expert_outputs.items():
                for expert_idx, expert_output in enumerate(expert_outputs_list):
                    if len(expert_output) > 0:  
                        outlier_score = compute_output_outlier_score(expert_output)
                        output_outlier_metric[layer_name][expert_idx] += outlier_score
        
        
        for name, module in model.named_modules():
            if isinstance(module, Qwen2MoeSparseMoeBlock):
                module.forward = module._original_forward
        
        metrics_results["output-outlier"] = output_outlier_metric
        print("Output-Outlier metric calculated")
        
        
        del output_outlier_metric, expert_outputs, original_forwards
        torch.cuda.empty_cache()

    
    if "output-cka" in metric and dataloader is not None:
        output_cka_metric = {}
        
        for layer_idx in range(len(model.model.layers)):
            ffn_name = f"model.layers.{layer_idx}.mlp"
            output_cka_metric[ffn_name] = torch.zeros(4)  
        
        model.eval()
        for p in model.parameters():
            p.requires_grad_(False)
               
        
        original_forwards = {}
        expert_outputs = {}
        
        def _custom_moe_forward_cka(self, hidden_states):
            
            original_output = self._original_forward(hidden_states)
            
            
            batch_size, sequence_length, hidden_dim = hidden_states.shape
            hidden_states_view = hidden_states.view(-1, hidden_dim)
            
            
            current_batch_experts = [[] for _ in range(len(self.experts))]  
            for expert_idx, expert in enumerate(self.experts):
                with torch.no_grad():
                    expert_output = expert(hidden_states_view)
                current_batch_experts[expert_idx] = expert_output.detach().cpu()
            
            
            if self._module_name not in expert_outputs:
                expert_outputs[self._module_name] = [[] for _ in range(len(self.experts))]
            
            for i, outputs in enumerate(current_batch_experts):
                if len(outputs) > 0:
                    expert_outputs[self._module_name][i].append(outputs)
            
            return original_output
        
        
        for name, module in model.named_modules():
            if isinstance(module, Qwen2MoeSparseMoeBlock):
                module._original_forward = module.forward
                module._module_name = name
                module.forward = _custom_moe_forward_cka.__get__(module, type(module))
        
        
        for batch in tqdm(dataloader, desc="Collecting expert outputs for CKA"):
            batch = {k: v.cuda() for k, v in batch.items()}
            if "labels" in batch:
                batch.pop("labels")
            
            with torch.no_grad():
                model(**batch)
        
        
        for layer_name, layer_expert_outputs in tqdm(expert_outputs.items(), desc='计算每层专家输出的CKA'):
            
            merged_outputs = []
            for expert_idx, expert_batch_outputs in enumerate(layer_expert_outputs):
                if expert_batch_outputs:
                    
                    merged_tensors = []
                    for batch_output in expert_batch_outputs:
                        
                        merged_tensors.append(batch_output.detach().cpu())
                        
                        del batch_output
                        
                    merged = torch.cat(merged_tensors, dim=0)
                    del merged_tensors
                    merged_outputs.append(merged)
                else:
                    merged_outputs.append(None)
            
            
            cka_values = []
            valid_cka_count = 0
            
            for i in range(num_experts):
                for j in range(i+1, num_experts):
                    if merged_outputs[i] is not None and merged_outputs[j] is not None:
                        try:
                            
                            max_samples = 2000  
                            
                            
                            X = merged_outputs[i][:max_samples]
                            Y = merged_outputs[j][:max_samples]
                            
                            
                            cka_value = linear_CKA(X, Y)
                            
                            
                            if not (np.isnan(cka_value) or np.isinf(cka_value)):
                                cka_values.append(cka_value)
                                valid_cka_count += 1
                            else:
                                print(f"Warning: The CKA calculation result for expert {i}-{j} in layer {layer_name} is {cka_value}, skipped")
                                
                            
                            if torch.cuda.is_available():
                                torch.cuda.empty_cache()
                                
                        except Exception as e:
                            print(f"Error calculating the CKA for expert {i}-{j} in layer {layer_name}: {str(e)}")
                            cka_values.append(0.0)  
                    else:
                        cka_values.append(0.0)
            
            
            del merged_outputs
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
            
            
            if cka_values and valid_cka_count > 0:
                
                valid_cka_values = [v for v in cka_values if not (np.isnan(v) or np.isinf(v))]
                if valid_cka_values:
                    cka_tensor = torch.tensor(valid_cka_values)
                    
                    output_cka_metric[layer_name][0] = torch.max(cka_tensor)      
                    output_cka_metric[layer_name][1] = torch.min(cka_tensor)      
                    output_cka_metric[layer_name][2] = torch.mean(cka_tensor)     
                    output_cka_metric[layer_name][3] = torch.std(cka_tensor)      
                else:
                    print(f"Warning: The layer {layer_name} has no valid CKA values, set to zero")
                    output_cka_metric[layer_name] = torch.zeros(4)
        
        
        for name, module in model.named_modules():
            if isinstance(module, Qwen2MoeSparseMoeBlock):
                module.forward = module._original_forward
        
        metrics_results["output-cka"] = output_cka_metric
        print("Output-CKA metric calculated")
        
        
        del output_cka_metric, expert_outputs
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    
    if "weight-cka" in metric:
        weight_cka_metric = {}
        
        for layer_idx in range(len(model.model.layers)):
            ffn_name = f"model.layers.{layer_idx}.mlp"
            weight_cka_metric[ffn_name] = torch.zeros(4)  
        
        model.eval()
        for p in model.parameters():
            p.requires_grad_(False)
     
        
        for layer_idx in tqdm(range(len(model.model.layers)), desc="Calculating weight-cka metric"):
            layer = model.model.layers[layer_idx]
            moe_block = layer.mlp
            
            
            experts = moe_block.experts
            
            
            cka_values = []
            
            for i in range(len(experts)):
                for j in range(i+1, len(experts)):
                    
                    expert_i = experts[i]
                    expert_j = experts[j]
                    
                    
                    w1_cka = linear_CKA(expert_i.gate_proj.weight, expert_j.gate_proj.weight)
                    w2_cka = linear_CKA(expert_i.up_proj.weight, expert_j.up_proj.weight)
                    w3_cka = linear_CKA(expert_i.down_proj.weight, expert_j.down_proj.weight)
                    
                    
                    total_cka = w1_cka + w2_cka + w3_cka
                    
                    cka_values.append(total_cka.item())
            
            
            if cka_values:  
                ffn_name = f"model.layers.{layer_idx}.mlp"
                cka_tensor = torch.tensor(cka_values)
                
                
                weight_cka_metric[ffn_name][0] = torch.max(cka_tensor)      
                weight_cka_metric[ffn_name][1] = torch.min(cka_tensor)      
                weight_cka_metric[ffn_name][2] = torch.mean(cka_tensor)     
                weight_cka_metric[ffn_name][3] = torch.std(cka_tensor)      
        
        metrics_results["weight-cka"] = weight_cka_metric
        print("Weight-CKA metric calculated")
        
        
        del weight_cka_metric, cka_values, cka_tensor
        torch.cuda.empty_cache()

    
    if "output-cosine" in metric and dataloader is not None:

        
        K_RESERVOIR   = 2048
        DTYPE_COLLECT = torch.float16
        DTYPE_SIM     = torch.float16
        DEVICES       = ["cuda:0", "cuda:1"]      
        

        
        class _Reservoir:
            """Fixed-capacity uniform random sample of rows, stored on the CPU.

            Implements classic reservoir sampling: the first ``cap`` rows are
            stored directly; afterwards the row with global index ``g``
            (0-based, counted over every row ever passed to ``update``)
            replaces a random slot with probability ``cap / (g + 1)``.

            Args:
                cap: maximum number of rows kept.
                dim: width of each row; the buffer starts as ``(0, dim)``.
            """

            def __init__(self, cap: int, dim: int):
                self.cap, self.dim = cap, dim
                self.buf = torch.empty((0, dim), dtype=DTYPE_COLLECT, device="cpu")
                self.n_seen = 0  # total number of rows offered so far

            def update(self, x: torch.Tensor):
                """Offer the rows of ``x`` to the reservoir.

                ``x`` is expected to be a 2-D CPU tensor whose second
                dimension matches the buffer width (callers pass
                ``e_out.detach().to("cpu")``).
                """
                m = x.size(0)
                # Fill phase: copy rows straight in while there is free room.
                take = min(self.cap - self.buf.size(0), m)
                if take:
                    self.buf = torch.cat([self.buf, x[:take]], 0)
                # Replacement phase: must resume at the first row that was NOT
                # copied above (index ``take``).  Starting at the buffer length
                # instead would skip rows, and would skip the whole batch once
                # the buffer is full and the batch is no larger than it.
                for i in range(take, m):
                    # random.randint is inclusive on both ends, so j is
                    # uniform over the n_seen + i + 1 rows seen so far.
                    j = random.randint(0, self.n_seen + i)
                    if j < self.cap:
                        self.buf[j] = x[i]
                self.n_seen += m

            def fetch(self):
                """Return the sampled rows converted to ``DTYPE_SIM``."""
                return self.buf.to(DTYPE_SIM)

        
        output_cosine_metric = {
            f"model.layers.{i}.mlp": torch.zeros(4, device="cpu")
            for i in range(len(model.model.layers))
        }

        
        expert_res = {}  
        def _patch_layer(block, layer_name: str):
            """Wrap ``block.forward`` so every expert's output is sampled.

            The current forward is kept on ``block._orig_forward`` (so it can
            be restored later) and replaced by a wrapper that:

            1. runs the original forward under fp16 autocast and returns its
               output unchanged;
            2. additionally feeds *all* flattened input rows through *every*
               expert of the block (independent of what the router selected)
               and pushes each expert's output into that expert's
               ``_Reservoir`` in ``expert_res[layer_name]``.

            Args:
                block: the MoE block to instrument; must expose ``experts``.
                layer_name: key under which the reservoirs are stored.
            """
            block._orig_forward = block.forward

            def _new_forward(self, hidden, *args, **kw):
                with torch.amp.autocast(device_type="cuda", dtype=DTYPE_COLLECT):
                    out = self._orig_forward(hidden, *args, **kw)

                # Collapse all leading dimensions into one row axis.
                feats = hidden.view(-1, hidden.size(-1))

                if layer_name not in expert_res:
                    # Only the row width is needed here, so read it from the
                    # on-device tensor instead of copying activations to CPU.
                    # NOTE(review): this uses the *input* width for the
                    # reservoir; it presumes each expert maps back to the same
                    # width — confirm against the expert MLP definition.
                    expert_res[layer_name] = [
                        _Reservoir(K_RESERVOIR, feats.size(-1))
                        for _ in range(len(self.experts))
                    ]

                # The dtype conversion does not depend on the expert: do it once.
                feats_in = feats.to(DTYPE_COLLECT)
                reservoirs = expert_res[layer_name]
                for idx, expert in enumerate(self.experts):
                    with torch.amp.autocast(device_type="cuda", dtype=DTYPE_COLLECT):
                        e_out = expert(feats_in)
                    reservoirs[idx].update(e_out.detach().to("cpu"))
                return out

            block.forward = MethodType(_new_forward, block)

        for n, m in model.named_modules():
            if isinstance(m, Qwen2MoeSparseMoeBlock):
                _patch_layer(m, n)

        
        model.eval()
        for p in model.parameters(): p.requires_grad_(False)
        with torch.no_grad():
            for batch in tqdm(dataloader, desc="Collecting outputs (2-GPU)"):
                batch = {k: v.to(model.device, non_blocking=True) for k, v in batch.items()}
                batch.pop("labels", None)
                model(**batch)

        
        streams = {d: torch.cuda.Stream(device=d) for d in DEVICES}

        def _mean_cos(x: torch.Tensor, y: torch.Tensor) -> float:
            """Return the mean of the row-wise cosine similarity of ``x`` and ``y``.

            Row ``k`` of ``x`` is compared with row ``k`` of ``y``; both inputs
            are L2-normalised along dim 1 first.
            """
            unit_x = torch.nn.functional.normalize(x, dim=1)
            unit_y = torch.nn.functional.normalize(y, dim=1)
            per_row = (unit_x * unit_y).sum(1)
            return float(per_row.mean())

        for idx, (lyr, res_list) in enumerate(tqdm(expert_res.items(), desc="Cosine on 2 GPUs")):
            device = DEVICES[idx & 1]                   
            torch.cuda.set_device(device)
            with torch.cuda.stream(streams[device]):
                num_exp = len(res_list)
                cos_vals = []
                for i in range(num_exp):
                    for j in range(i + 1, num_exp):
                        xi = res_list[i].fetch().to(device, non_blocking=True)
                        yj = res_list[j].fetch().to(device, non_blocking=True)
                        cos_vals.append(_mean_cos(xi, yj))
                        del xi, yj
                if cos_vals:
                    t = torch.tensor(cos_vals, device="cpu")
                    output_cosine_metric[lyr][0] = t.max()
                    output_cosine_metric[lyr][1] = t.min()
                    output_cosine_metric[lyr][2] = t.mean()
                    output_cosine_metric[lyr][3] = t.std()
            torch.cuda.empty_cache()

        
        for n, m in model.named_modules():
            if hasattr(m, "_orig_forward"):
                m.forward = m._orig_forward

        metrics_results["output-cosine"] = output_cosine_metric
        print("Output-Cosine (2-GPU)  calculated")
        del output_cosine_metric, expert_res
        for d in DEVICES: torch.cuda.empty_cache()

    
    if "weight-cosine" in metric:

        
        DEVICES        = ["cuda:0", "cuda:1"]  
        DTYPE_COLLECT  = torch.float16        
        DTYPE_COMPUTE  = torch.float16        
        

        
        weight_cosine_metric = {
            f"model.layers.{i}.mlp": torch.zeros(4, device="cpu")
            for i in range(len(model.model.layers))
        }

        
        def _layer_weight_cosine(experts, device: str) -> torch.Tensor:
            """Pairwise weight similarity between the experts of one layer.

            For each of ``gate_proj``, ``up_proj`` and ``down_proj`` the
            flattened weights of all experts are stacked into one matrix
            (one row per expert), L2-normalised per row and multiplied with
            their transpose, which yields an expert-by-expert cosine matrix.
            The three matrices are added, so every entry is a *sum* of three
            cosines.

            Args:
                experts: iterable of expert modules exposing ``gate_proj``,
                    ``up_proj`` and ``down_proj`` with a ``weight`` tensor.
                device: device on which the similarity is computed.

            Returns:
                1-D CPU tensor with the strictly-upper-triangular entries of
                the summed matrix, i.e. one value per unordered expert pair.
            """
            per_projection = []
            for proj_name in ("gate_proj", "up_proj", "down_proj"):
                flat_rows = [
                    getattr(expert, proj_name).weight.detach().to(device, dtype=DTYPE_COLLECT).flatten()
                    for expert in experts
                ]
                unit_rows = torch.nn.functional.normalize(
                    torch.stack(flat_rows, dim=0).to(DTYPE_COMPUTE), dim=1
                )
                per_projection.append(unit_rows @ unit_rows.T)

            gate_sim, up_sim, down_sim = per_projection
            combined = gate_sim + up_sim + down_sim

            # Keep each unordered pair once: entries strictly above the diagonal.
            n_experts = combined.size(0)
            upper = torch.triu_indices(n_experts, n_experts, offset=1, device=device)
            rows, cols = upper[0], upper[1]
            return combined[rows, cols].to("cpu")

        
        for layer_idx in tqdm(range(len(model.model.layers)), desc="Weight-cosine (2-GPU)"):
            layer      = model.model.layers[layer_idx]
            moe_block  = layer.mlp
            experts    = moe_block.experts
            device     = DEVICES[layer_idx & 1]           

            with torch.cuda.device(device), torch.amp.autocast(device_type="cuda", dtype=DTYPE_COLLECT):
                cos_vec = _layer_weight_cosine(experts, device)

            
            if cos_vec.numel():
                out = weight_cosine_metric[f"model.layers.{layer_idx}.mlp"]
                out[0] = cos_vec.max()
                out[1] = cos_vec.min()
                out[2] = cos_vec.mean()
                out[3] = cos_vec.std()

            
            torch.cuda.empty_cache()

        
        metrics_results["weight-cosine"] = weight_cosine_metric
        

        del weight_cosine_metric, cos_vec
        for d in DEVICES:
            torch.cuda.empty_cache()
    

    if not os.path.exists(save_dir):
        os.makedirs(save_dir)
    torch.save(metrics_results, os.path.join(save_dir, save_name))
    print(f"Dumped to {os.path.join(save_dir, save_name)}")
    return metrics_results

@torch.inference_mode()
def load_model_and_dump_metrics(
    model_name: str = "/data/Qwen1.5-MoE-A2.7B-Chat",
    calib_set: Optional[str] = "c4",
    opt_batch_size: Optional[int] = 16,
    n_sentences: Optional[int] = 16,
    metric_save_dir: Optional[str] = "./results/metrics",
    metrics: Optional[list[str]] = None,
    override: bool = False,
):
    """Load the Qwen MoE model, compute the requested metrics and save them.

    The model is loaded in fp16 with ``device_map="auto"``, a calibration
    dataloader is built from ``calib_set`` and both are handed to
    ``dump_qwen_moe_metrics``, which writes the metrics file into
    ``metric_save_dir``.  Model, tokenizer and dataloader are released before
    returning.

    Args:
        model_name: path or hub id passed to ``from_pretrained``.
        calib_set: calibration dataset name passed to ``get_calib_dataloder``.
        opt_batch_size: calibration batch size.
        n_sentences: number of calibration blocks (``n_blocks_for_stat``).
        metric_save_dir: directory the metrics file is written to.
        metrics: metric names to compute; ``None`` means ``['all']``.
        override: forwarded unchanged to ``dump_qwen_moe_metrics``.

    Returns:
        The metrics file *name* (not the full path).
    """
    # Avoid a shared mutable default; None stands for the former ['all'].
    if metrics is None:
        metrics = ['all']

    # Report the calibration set that is actually used, not a fixed name.
    print(f"Get OWL Metric on {calib_set}.\n Model: {model_name}\n opt_batch_size={opt_batch_size}, n_sentences={n_sentences}")

    tokenizer = AutoTokenizer.from_pretrained(model_name)
    # Reuse EOS as the padding token.
    tokenizer.pad_token_id = tokenizer.eos_token_id
    model = Qwen2MoeForCausalLM.from_pretrained(
        model_name,
        torch_dtype=torch.float16, device_map="auto"
    )
    model.eval()

    dataloader = get_calib_dataloder(
        dataset=calib_set,
        tokenizer=tokenizer,
        max_block_size=2048,
        n_blocks_for_stat=n_sentences,
        batch_size=opt_batch_size,
        num_workers=8,
    )

    metrics_file = "qwen_1.5_moe_metrics.pt"
    dump_qwen_moe_metrics(
        model=model,
        metric=metrics,
        dataloader=dataloader,
        save_dir=metric_save_dir,
        save_name=metrics_file,
        override=override,
    )

    # Drop the large objects before returning so their memory can be reused.
    del model, tokenizer, dataloader
    torch.cuda.empty_cache()

    return metrics_file

def adapt_compress_qwen(
    model_name: str = "/data/Qwen1.5-MoE-A2.7B-Chat",
    calib_set: Optional[str] = "c4",
    opt_batch_size: Optional[int] = 16,
    n_sentences: Optional[int] = 16,
    skip_metrics_calculation: bool = True,
    metrics: Optional[list[str]] = None,
    objective_type: str = None,

    min_margin: int = None,
    max_neighbor_diff: Optional[int] = None,

    prune_method: str = "weight-outlier",
    pruning_ratio: float = 0.75,

    merge_method: str = "output-cka",
    stat_type: int = 0,
    merging_ratio: float = 0.75,

    optimize_order: str = "merge_first",

    save_opt: bool = True,
    metric_save_dir: Optional[str] = "./output/metrics",
    opt_save_dir: Optional[str] = "./output/opt",
    override: bool = False,

    model: Optional[Qwen2MoeForCausalLM] = None,
    tokenizer: Optional[AutoTokenizer] = None,

):
    """Decide per-layer expert counts for pruning and merging.

    Workflow:

    1. If a result file for this exact configuration already exists under
       ``opt_save_dir/optimize_order`` and ``override`` is false, load and
       return it.
    2. Unless ``skip_metrics_calculation`` is set, (re)compute the metrics
       file when it is missing or ``override`` is true.
    3. Run ``cal_prune_num`` and/or ``cal_merge_num`` in the order given by
       ``optimize_order``.  In the two-stage orders the first stage's
       per-layer result is handed to the second stage as
       ``custom_max_experts``.
    4. Print a summary and, if ``save_opt`` is set, write the result as YAML.

    Args:
        model_name: path or hub id; used to load the model for metric
            computation and to load the config when ``model`` is not given.
        calib_set, opt_batch_size, n_sentences: calibration settings for
            metric computation.
        skip_metrics_calculation: when true the metrics file is expected to
            exist already and is never recomputed.
        metrics: metric names to compute; ``None`` means ``['all']``.
        objective_type, min_margin, max_neighbor_diff: forwarded unchanged to
            ``cal_prune_num`` and ``cal_merge_num``.
        prune_method, pruning_ratio: metric name and ratio for the pruning
            stage.
        merge_method, merging_ratio: metric name and ratio for the merging
            stage.
        stat_type: key into ``STAT_TYPE_DICT`` (0 max, 1 min, 2 mean, 3 std);
            forwarded to both stages and used in the result file name.
        optimize_order: one of ``'prune_first'``, ``'merge_first'``,
            ``'prune_only'``, ``'merge_only'``.
        save_opt: write the result to disk.
        metric_save_dir, opt_save_dir: directories for metrics and results.
        override: ignore cached results / metrics.
        model, tokenizer: optional already-loaded model and tokenizer.

    Returns:
        Dict with keys ``"prune_experts"`` and ``"merge_experts"``.  For the
        ``*_only`` orders both keys hold the same result.  On a cache hit the
        values are NumPy arrays built from the YAML lists.

    Raises:
        ValueError: if ``optimize_order`` is not one of the supported values.
    """
    # Avoid a shared mutable default; None stands for the former ['all'].
    if metrics is None:
        metrics = ['all']

    stat_type_str = STAT_TYPE_DICT.get(stat_type, "") if stat_type is not None else ""
    opt_save_dir = os.path.join(opt_save_dir, f"{optimize_order}")

    # NOTE: the "mixtral_" prefix and the missing separator between the stat
    # name and the merge method are kept exactly as before so that result
    # files written by earlier runs are still found.
    if stat_type_str:
        result_file_name = f"mixtral_{optimize_order}_{prune_method}_{pruning_ratio}_{stat_type_str}{merge_method}_{merging_ratio}_{objective_type}.yaml"
    else:
        result_file_name = f"mixtral_{optimize_order}_{prune_method}_{pruning_ratio}_{merge_method}_{merging_ratio}_{objective_type}.yaml"
    result_file_path = os.path.join(opt_save_dir, result_file_name)

    # Cache hit: reuse the stored decision.
    if os.path.exists(result_file_path) and not override:
        with open(result_file_path, 'r') as f:
            result_dict = yaml.safe_load(f)
        return {
            "merge_experts": np.array(result_dict["merge_experts"]),
            "prune_experts": np.array(result_dict["prune_experts"]),
        }

    # Fail fast on a bad order instead of after the costly metrics pass.
    if optimize_order not in ("prune_first", "merge_first", "prune_only", "merge_only"):
        raise ValueError(f"Unsupported optimize order: {optimize_order}，Please choose 'prune_first', 'merge_first', 'prune_only' or 'merge_only'")

    metrics_file = "qwen_1.5_moe_metrics.pt"
    metrics_file_path = os.path.join(metric_save_dir, metrics_file)
    if not skip_metrics_calculation and (not os.path.exists(metrics_file_path) or override):
        if model is None or tokenizer is None:
            # No usable model/tokenizer pair was passed in: load one.
            metrics_file = load_model_and_dump_metrics(
                model_name=model_name,
                calib_set=calib_set,
                opt_batch_size=opt_batch_size,
                n_sentences=n_sentences,
                metric_save_dir=metric_save_dir,
                metrics=metrics,
                override=override
            )
        else:
            dataloader = get_calib_dataloder(
                dataset=calib_set,
                tokenizer=tokenizer,
                max_block_size=2048,
                n_blocks_for_stat=n_sentences,
                batch_size=opt_batch_size,
                num_workers=8,
            )
            dump_qwen_moe_metrics(
                model=model,
                metric=metrics,
                dataloader=dataloader,
                save_dir=metric_save_dir,
                save_name=metrics_file,
                override=override,
            )

    # Only the config is needed from here on.  ``model`` is None by default
    # (and always when started from the command line), so load just the
    # config in that case instead of dereferencing None.
    if model is not None:
        model_config = model.config
    else:
        from transformers import AutoConfig
        model_config = AutoConfig.from_pretrained(model_name)
    num_layers = model_config.num_hidden_layers
    num_experts = model_config.num_experts

    # Keyword arguments shared by both stages.
    shared_kwargs = dict(
        save_dir=metric_save_dir,
        save_name=metrics_file,
        stat_type=stat_type,
        min_margin=min_margin,
        optimize_order=optimize_order,
        max_neighbor_diff=max_neighbor_diff,
        objective_type=objective_type,
        model_config=model_config,
    )
    prune_kwargs = dict(shared_kwargs, metric=prune_method, pruning_ratio=pruning_ratio)
    merge_kwargs = dict(shared_kwargs, metric=merge_method, merging_ratio=merging_ratio)

    def _per_layer_caps(counts):
        """Map each layer's MoE block name to the first stage's count."""
        return {
            f"model.layers.{layer_idx}.mlp": counts[layer_idx]
            for layer_idx in range(num_layers)
        }

    time_start = time.time()
    if optimize_order == "prune_first":
        prune_nums = cal_prune_num(**prune_kwargs)
        target_experts = cal_merge_num(
            **merge_kwargs, custom_max_experts=_per_layer_caps(prune_nums)
        )
    elif optimize_order == "merge_first":
        target_experts = cal_merge_num(**merge_kwargs)
        prune_nums = cal_prune_num(
            **prune_kwargs, custom_max_experts=_per_layer_caps(target_experts)
        )
    elif optimize_order == "prune_only":
        prune_nums = cal_prune_num(**prune_kwargs)
        target_experts = prune_nums
    else:  # "merge_only" (validated above)
        target_experts = cal_merge_num(**merge_kwargs)
        prune_nums = target_experts

    results = {"prune_experts": prune_nums, "merge_experts": target_experts}

    print("\n===== Optimized result summary =====")
    print(f"prune_experts: {sum(results['prune_experts'])}")
    print(f"merge_experts: {sum(results['merge_experts'])}")
    print(f"Compression rate: {sum(results['merge_experts'])/(num_experts * num_layers):.4f}")
    print(f"Optimization time: {time.time() - time_start:.2f}s")

    if save_opt:
        os.makedirs(opt_save_dir, exist_ok=True)
        # Build the payload before opening the file: if conversion fails, no
        # empty file is left behind to be mistaken for a cached result.
        results_dict = {
            "prune_experts": results['prune_experts'].tolist(),
            "merge_experts": results['merge_experts'].tolist(),
        }
        with open(result_file_path, "w") as f:
            yaml.dump(results_dict, f, default_flow_style=False, sort_keys=False, indent=4)

    return results

if __name__ == "__main__":
    # Command-line entry point: python-fire turns the keyword parameters of
    # adapt_compress_qwen into command-line flags.
    Fire(adapt_compress_qwen)
