import os
import numpy as np
from typing import Optional, Dict, List
import time
import logging
import torch
import yaml
from fire import Fire
from transformers import MixtralForCausalLM, AutoTokenizer
from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock
from eval import get_calib_dataloder
from tqdm import tqdm
from scipy.optimize import minimize
import matplotlib.pyplot as plt
import torch.nn as nn
import json

# Module-level logger named after this module.
logger = logging.getLogger(__name__)

# Index -> statistic name. This is the same mapping the functions in this
# file use locally to interpret their integer ``stat_type`` argument.
STAT_TYPE_DICT = dict(enumerate(("max", "min", "mean", "std")))



def check_metrics_mat(
        metrics_path: str
) -> None:
    """Print a short preview of every metric stored in a saved metrics file.

    The file is loaded with ``torch.load`` and is expected to contain a
    mapping of metric name -> metric data. For each entry a header line is
    printed, followed by:

    * the first 5 ``key: value`` pairs when the data is a ``dict``;
    * the first 5 elements along dim 0 when the data is a ``torch.Tensor``;
    * otherwise the type and the ``str()`` form, truncated to 200 characters
      (with a trailing ``...`` when truncated).

    Args:
        metrics_path: Path of the file to inspect. If it does not exist a
            message is printed and the function returns without raising.
    """
    if not os.path.exists(metrics_path):
        print(f"Metrics file not found at {metrics_path}")
        return

    # NOTE: torch.load unpickles the file; only point this at trusted files.
    metrics = torch.load(metrics_path)
    print(f"Loaded metrics from {metrics_path}")

    preview_items = 5    # max dict entries / tensor elements shown per metric
    preview_chars = 200  # max characters shown for any other object

    for metric_name, metric_data in metrics.items():
        print(f"\n{'-'*20} {metric_name} {'-'*20}")

        if isinstance(metric_data, dict):
            # Stop after the preview window instead of copying every item
            # into a list first.
            for shown, (layer_name, values) in enumerate(metric_data.items()):
                if shown >= preview_items:
                    break
                print(f"{layer_name}: {values}")
        elif isinstance(metric_data, torch.Tensor):
            print(metric_data[:preview_items])
        else:
            print(f"Type: {type(metric_data)}")
            text = str(metric_data)  # computed once; may be expensive
            print(text[:preview_chars] + "..." if len(text) > preview_chars else text)

    print("\nMetrics check completed.")

def cal_prune_num(
    save_dir: str = "./results/metrics",
    save_name: str = "mixtral_moe_metrics.pt",
    metric: str = "mutual_information",  
    stat_type: int = 3,  
    min_margin: int = None,  
    pruning_ratio: float = 0.75,
    custom_max_experts: dict = None,  
    optimize_order: str = "prune_first",  
    max_neighbor_diff: int = None,  
    objective_type: str = "linear",  
    rounding_residuals: Optional[np.ndarray] = None,  
    return_residuals: bool = False,  
) -> np.ndarray | tuple[np.ndarray, np.ndarray]:
    
    if stat_type is None:
        stat_type = 2  
    
    
    stat_type_map = {0: "max", 1: "min", 2: "mean", 3: "std"}
    
    metrics = torch.load(os.path.join(save_dir, save_name))
    
    weight_outlier = metrics.get("weight-outlier", {})
    output_outlier = metrics.get("output-outlier", {})
    frequency_metrics = metrics.get("frequency", {})
    routing_score_metrics = metrics.get("routing-score", {})
    weight_mean_metrics = metrics.get("weight-mean", {})
    
    weight_cka_metrics = metrics.get("weight-cka", {})
    output_cka_metrics = metrics.get("output-cka", {})
    weight_cosine_metrics = metrics.get("weight-cosine", {})
    output_cosine_metrics = metrics.get("output-cosine", {})
    mutual_information_metrics = metrics.get("mutual_information", {})

    adapt_down_metrics = metrics.get("adapt_down", {})

    
    
    def normalize(x):
        return (x - np.min(x)) / (np.max(x) - np.min(x) + 1e-10)
    
    
    weight_outlier_scores = np.array([tensor.sum().item() for tensor in weight_outlier.values()])
    output_outlier_scores = np.array([tensor.sum().item() for tensor in output_outlier.values()])
    frequency_scores = np.array([tensor.sum().item() for tensor in frequency_metrics.values()])
    weight_mean_scores = np.array([tensor.sum().item() for tensor in weight_mean_metrics.values()]) 
    
    frequency_factor = (1 + frequency_scores) / (frequency_scores + 1e-10)
    
    if metric == "weight-outlier":
        
        importance_scores = frequency_factor * weight_outlier_scores
        importance_scores = normalize(importance_scores)
    elif metric == "output-outlier":
        
        importance_scores = frequency_factor * output_outlier_scores
        importance_scores = normalize(importance_scores)
    elif metric == "combined":
        
        importance_scores = frequency_factor * output_outlier_scores * weight_outlier_scores
        importance_scores = normalize(importance_scores)
    elif metric == "weight-mean":
        
        importance_scores = frequency_factor * weight_mean_scores
        importance_scores = normalize(importance_scores)
    elif metric == "mutual_information":
        if not mutual_information_metrics:
            raise ValueError("mutual_information is miss, calculate it first")
        importance_scores = np.array([tensor.sum().item() for tensor in mutual_information_metrics.values()])
    elif metric == "output-cosine":
        if not output_cosine_metrics:
            raise ValueError("output-cosine is miss, calculate it first")
        print(f"use {metric}'s {stat_type_map.get(stat_type, 'std')} as the important score")
        output_cosine_scores = np.array([tensor[stat_type].item() for tensor in output_cosine_metrics.values()])
        
        
        if stat_type in [0, 2]:  
            importance_scores = (-output_cosine_scores)
        else:  
            importance_scores = output_cosine_scores
        importance_scores = normalize(importance_scores)
    elif metric == "weight-cosine":
        if not weight_cosine_metrics:
            raise ValueError("weight-cosine is miss, calculate it first")
        print(f"use {metric}'s {stat_type_map.get(stat_type, 'std')} as the important score")
        weight_cosine_scores = np.array([tensor[stat_type].item() for tensor in weight_cosine_metrics.values()])
        
        
        if stat_type in [0, 2]:  
            importance_scores = (-weight_cosine_scores)
        else:  
            importance_scores = weight_cosine_scores
        importance_scores = normalize(importance_scores)
    elif metric == "adapt_down":
        if not adapt_down_metrics:
            raise ValueError("adapt_down is miss, calculate it first")
        importance_scores = np.array([tensor.sum().item() for tensor in adapt_down_metrics.values()])
        importance_scores = normalize(importance_scores)
    else:
        raise ValueError(f"Unsupported method: {metric}, please use 'weight-outlier', 'output-outlier', 'combined', 'weight-mean', 'mutual_information', 'output-cosine', 'weight-cosine'")
    
    
    if rounding_residuals is not None:
        importance_scores = importance_scores + (1e-3 * rounding_residuals)
    
    
    num_experts = 8
    num_layers = 32

    
    layer_names = [f"model.layers.{i}.block_sparse_moe" for i in range(num_layers)]
    layer_max_experts = []
    
    for i, layer_name in enumerate(layer_names):
        if custom_max_experts and layer_name in custom_max_experts and optimize_order in ["merge_first", "prune_first"]:
            
            max_expert = custom_max_experts[layer_name]
        else:
            
            max_expert = num_experts
        layer_max_experts.append(max_expert)
    
    
    if optimize_order in ["prune_first", "prune_only"]:
        
        total_experts = sum(layer_max_experts)
        num_experts_to_keep = int(total_experts * pruning_ratio)
    else:
        
        total_experts = sum(layer_max_experts)
        num_experts_to_keep = int(total_experts * pruning_ratio)
    
    
    
    bounds = []
    initial_experts = []
    
    for i, layer_name in enumerate(layer_names):
        layer_max = layer_max_experts[i]
        
        initial_val = layer_max * pruning_ratio
        initial_experts.append(initial_val)

        
        bounds.append((max(int(initial_val - min_margin), 2), layer_max))
    
    
    initial_experts = np.array(initial_experts)
    
    
    def constraint_total(x):
        return num_experts_to_keep - np.sum(x)
    
    
    cons = [{'type': 'eq', 'fun': constraint_total}]
    
    
    if max_neighbor_diff is not None and max_neighbor_diff >= 0:
        for i in range(num_layers - 1):
            
            def make_neighbor_constraint(i, j):
                def constraint(x):
                    return max_neighbor_diff - abs(x[i] - x[j])
                return constraint
            
            cons.append({'type': 'ineq', 'fun': make_neighbor_constraint(i, i+1)})
    
    
    if objective_type == "exp":
        
        def objective(x):
            return -np.sum(importance_scores * np.exp(x / 10))  

        def obj_grad(x):
            return -importance_scores * np.exp(x / 10) / 10
    elif objective_type == "log":
        
        def objective(x):
            return -np.sum(importance_scores * np.log(x + 1))  

        def obj_grad(x):
            return -importance_scores / (x + 1)
    elif objective_type == "square":
        
        def objective(x):
            return -np.sum(importance_scores * x**2)

        def obj_grad(x):
            return -2 * importance_scores * x
    elif objective_type == "linear":
        def objective(x):
            return -np.sum(importance_scores * x)

        def obj_grad(x):
            return -importance_scores
    else:  
        raise ValueError(f"Unsupported objective function type: {objective_type}，Use 'linear', 'exp', 'log', 'square'")
    
    
    res = minimize(
        fun=objective,
        x0=initial_experts,
        method="SLSQP",
        bounds=bounds,
        constraints=cons,  
        jac=obj_grad,
        options={"maxiter": 1000, "disp": True}
    )
    
    
    if not res.success:
        print(f"Failed to optimize: {res.message}")
    else:
        print(f"Optimized successfully: {res.message}")
    
    print(f"Iterations: {res.nit}")
    print(f"Optimized result: {res.x}")
    if hasattr(res, 'constr_violation'):
        print(f"Constraint violation: {res.constr_violation}")
    print(f"Objective value: {res.fun}")
    
    
    desired_counts_float = res.x
    rounded_counts = np.rint(desired_counts_float).astype(int)
    
    lower_bounds = np.array([max(int(initial_experts[i] - min_margin), 2) for i in range(num_layers)])
    upper_bounds = np.array(layer_max_experts)
    rounded_counts = np.clip(rounded_counts, lower_bounds, upper_bounds)
    rounding_residuals_out = desired_counts_float - rounded_counts.astype(float)

    print(f"Final each layer experts: {rounded_counts}")

    if return_residuals:
        return rounded_counts, rounding_residuals_out
    return rounded_counts

def cal_merge_num(
    save_dir: str = "./results/metrics",
    save_name: str = "mixtral_moe_metrics.pt",
    metric: str = "output-cosine",  
    stat_type: int = 2,  
    min_margin: int = 2,  
    merging_ratio: float = 0.75,  
    custom_max_experts: dict = None,  
    optimize_order: str = "prune_first",  
    max_neighbor_diff: int = None,  
    objective_type: str = "linear",  
    rounding_residuals: Optional[np.ndarray] = None,  
    return_residuals: bool = False,
) -> np.ndarray | tuple[np.ndarray, np.ndarray]:
    
    if stat_type is None:
        stat_type = 3  
    
    metrics = torch.load(os.path.join(save_dir, save_name))
    
    output_cka_metrics = metrics.get("output-cka", {})
    weight_cka_metrics = metrics.get("weight-cka", {})
    output_cosine_metrics = metrics.get("output-cosine", {})
    weight_cosine_metrics = metrics.get("weight-cosine", {})
    
    
    stat_type_map = {0: "max", 1: "min", 2: "mean", 3: "std"}
    print(f"use {metric}'s {stat_type_map.get(stat_type, 'std')} as the important score")

    
    if metric == "output-cosine":
        if not output_cosine_metrics:
            raise ValueError("output-cosine is miss, calculate it first")
        importance_scores = np.array([tensor[stat_type].item() for tensor in output_cosine_metrics.values()])
    elif metric == "weight-cosine":
        if not weight_cosine_metrics:
            raise ValueError("weight-cosine is miss, calculate it first")
        importance_scores = np.array([tensor[stat_type].item() for tensor in weight_cosine_metrics.values()])
    else:
        raise ValueError(f"unspported method {metric}")
    
    
    
    
    if stat_type in [0, 2]:  
        
        importance_scores = -importance_scores
    else:  
        
        pass

    
    if rounding_residuals is not None:
        importance_scores = importance_scores + (1e-3 * rounding_residuals)
    
    
    num_experts = 8
    num_layers = 32

    
    layer_names = [f"model.layers.{i}.block_sparse_moe" for i in range(num_layers)]
    layer_max_experts = []
    
    for i, layer_name in enumerate(layer_names):
        if custom_max_experts and layer_name in custom_max_experts and optimize_order in ["prune_first", "merge_first"]:
            
            layer_max = custom_max_experts[layer_name]
        else:
            
            layer_max = 8
        layer_max_experts.append(layer_max)
    
    
    if optimize_order in ["prune_first", "merge_only"]:
        
        total_experts = sum(layer_max_experts)
        num_experts_to_keep = int(total_experts * merging_ratio)
    else:
        
        total_experts = num_layers * num_experts
        num_experts_to_keep = int(total_experts * merging_ratio)
    

    
    

    bounds = []
    initial_target_experts = []
    
    for i, layer_name in enumerate(layer_names):
        layer_max = layer_max_experts[i]
        
        initial_val = layer_max * merging_ratio
        initial_target_experts.append(initial_val)
        bounds.append((max(int(initial_val - min_margin), 2), layer_max))
    
    
    initial_target_experts = np.array(initial_target_experts)
    
    
    def constraint_total(x):
        return num_experts_to_keep - np.sum(x)
    
    
    cons = [{'type': 'eq', 'fun': constraint_total}]
    
    
    if max_neighbor_diff is not None and max_neighbor_diff >= 0:
        for i in range(num_layers - 1):
            
            def make_neighbor_constraint(i, j):
                def constraint(x):
                    return max_neighbor_diff - abs(x[i] - x[j])
                return constraint
            
            cons.append({'type': 'ineq', 'fun': make_neighbor_constraint(i, i+1)})
    
    
    if objective_type == "exp":
        
        def objective(x):
            return -np.sum(importance_scores * np.exp(x / 10))  

        def obj_grad(x):
            return -importance_scores * np.exp(x / 10) / 10

    elif objective_type == "log":
        logging.info(f"Use logarithmic objective function")
        
        def objective(x):
            return -np.sum(importance_scores * np.log(x + 1))  

        def obj_grad(x):
            return -importance_scores / (x + 1)

    elif objective_type == "square":
        logging.info(f"Use square objective function")
        
        def objective(x):
            return -np.sum(importance_scores * x**2)

        def obj_grad(x):
            return -2 * importance_scores * x
    
    elif objective_type == "linear":
        logging.info(f"Use linear objective function")
        def objective(x):
            return -np.sum(importance_scores * x)

        def obj_grad(x):
            return -importance_scores
    else:  
        raise ValueError(f"Unsupported objective function type: {objective_type}，Use 'linear', 'exp', 'log', 'square'")
    
    
    res = minimize(
        fun=objective,
        x0=initial_target_experts,
        method="SLSQP",
        bounds=bounds,
        constraints=cons,  
        jac=obj_grad,
        options={"maxiter": 1000, "disp": True}
    )
    
    
    if not res.success:
        print(f"Failed to optimize: {res.message}")
    else:
        print(f"Optimized successfully: {res.message}")
    
    print(f"Iterations: {res.nit}")
    print(f"Optimized result: {res.x}")
    if hasattr(res, 'constr_violation'):
        print(f"Constraint violation: {res.constr_violation}")
    print(f"Objective value: {res.fun}")
    
    
    desired_counts_float = res.x
    rounded_counts = np.rint(desired_counts_float).astype(int)
    lower_bounds = np.array([max(int(initial_target_experts[i] - min_margin), 2) for i in range(num_layers)])
    upper_bounds = np.array(layer_max_experts)
    rounded_counts = np.clip(rounded_counts, lower_bounds, upper_bounds)
    rounding_residuals_out = desired_counts_float - rounded_counts.astype(float)

    print(f"Final each layer target experts: {rounded_counts}")
    print(f"Total target experts (rounded): {np.sum(rounded_counts)}, target: {num_experts_to_keep}")

    if return_residuals:
        return rounded_counts, rounding_residuals_out
    return rounded_counts

@torch.inference_mode()
def dump_mixtral_moe_metrics(
    model: MixtralForCausalLM,
    metric: list[str] = ["frequency", "routing-score", "weight-outlier", "output-outlier", "output-cosine", "weight-cosine", "weight-mean", "mutual_information"],
    dataloader = None,
    save_dir: str = "./results/metrics",
    save_name: str = "mixtral_moe_metrics.pt",
    override: bool = True,
    sigmoid_t: float = 0.001,
):
    if os.path.exists(os.path.join(save_dir, save_name)) and not override:
        print(f"File {os.path.join(save_dir, save_name)} already exists.")
        return
    if "all" in metric:
        metric = ["frequency", "routing-score", "weight-outlier", "output-outlier", "output-cosine", "weight-cosine", "weight-mean", "mutual_information"]
    print(f"Get {metric} metrics for each layer of mixtral-8x7b-v0.1")
    
    metrics_results = {}
    
    
    if "frequency" in metric and dataloader is not None:
        frequency_metric = {}
        
        for layer_idx in range(len(model.model.layers)):
            ffn_name = f"model.layers.{layer_idx}.block_sparse_moe"
            frequency_metric[ffn_name] = torch.zeros(model.config.num_local_experts)
        
        model.eval()
        for p in model.parameters():
            p.requires_grad_(False)
            
        for batch in tqdm(dataloader, desc="Calculating frequency metric"):
            batch = {k: v.cuda() for k, v in batch.items()}
            if "labels" in batch:
                
                batch.pop("labels")
            with torch.no_grad():
                outputs = model(**batch, output_router_logits=True)
            all_router_logits = outputs.router_logits
            all_router_logits = torch.stack(all_router_logits)  
            selected_experts = torch.topk(all_router_logits, 2, dim=-1)[1].reshape(
                model.config.num_hidden_layers, -1
            )  
            
            for layer_idx in range(len(model.model.layers)):
                ffn_name = f"model.layers.{layer_idx}.block_sparse_moe"
                unique, counts = torch.unique(selected_experts[layer_idx], return_counts=True)
                frequency_metric[ffn_name][unique.cpu()] += counts.cpu()
        
        
        frequency_metric = {
            k: v / torch.sum(v) for k, v in frequency_metric.items()
        }
        metrics_results["frequency"] = frequency_metric
        print("Frequency metric calculated")
        
        
        del frequency_metric, all_router_logits, selected_experts
        torch.cuda.empty_cache()

    
    if "routing-score" in metric and dataloader is not None:
        routing_score_metric = {}
        
        for layer_idx in range(len(model.model.layers)):
            ffn_name = f"model.layers.{layer_idx}.block_sparse_moe"
            routing_score_metric[ffn_name] = torch.zeros(model.config.num_local_experts)
        
        model.eval()
        for p in model.parameters():
            p.requires_grad_(False)
            
        for batch in tqdm(dataloader, desc="Calculating routing-score metric"):
            batch = {k: v.cuda() for k, v in batch.items()}
            if "labels" in batch:
                batch.pop("labels")
            with torch.no_grad():
                outputs = model(**batch, output_router_logits=True)
            all_router_logits = outputs.router_logits
            
            for layer_idx in range(len(model.model.layers)):
                ffn_name = f"model.layers.{layer_idx}.block_sparse_moe"
                router_score = torch.nn.functional.softmax(all_router_logits[layer_idx], dim=1)
                scores = router_score.float().sum(0) / router_score.shape[0]
                routing_score_metric[ffn_name] += scores.cpu()
        
        metrics_results["routing-score"] = routing_score_metric
        print("Routing-score metric calculated")
        
        
        del routing_score_metric, all_router_logits, router_score, scores
        torch.cuda.empty_cache()

    
    if "weight-mean" in metric:
        weight_mean_metric = {}
        
        for layer_idx in range(len(model.model.layers)):
            ffn_name = f"model.layers.{layer_idx}.block_sparse_moe"
            weight_mean_metric[ffn_name] = torch.zeros(model.config.num_local_experts)
        
        model.eval()
        for p in model.parameters():
            p.requires_grad_(False)
        
        
        expert_info = []
        for name, module in model.named_modules():
            if ('block_sparse_moe' in name and 'gate' not in name and 
                isinstance(module, torch.nn.Linear) and 'experts' in name):
                parts = name.split('.')
                if len(parts) >= 6:
                    layer_name = f"model.layers.{parts[2]}.block_sparse_moe"
                    expert_idx = int(parts[5])
                    expert_info.append((layer_name, expert_idx, module))
        
        
        for layer_name, expert_idx, module in tqdm(expert_info, desc="Calculating weight-mean metric"):
            
            with torch.no_grad():
                weight = module.weight.data
                weight_mean = weight.abs().mean().item()
                weight_mean_metric[layer_name][expert_idx] = weight_mean
        
        metrics_results["weight-mean"] = weight_mean_metric
        print("Weight-Mean metric calculated")
        
        
        del weight_mean_metric, expert_info
        torch.cuda.empty_cache()

    
    if "weight-outlier" in metric:
        num_experts = model.config.num_local_experts
        layer_names = [f"model.layers.{i}.block_sparse_moe" for i in range(len(model.model.layers))]
        
        
        outlier_score = {name: torch.zeros(num_experts) for name in layer_names}
        
        
        def compute_outlier_score(module):
            with torch.no_grad():
                weight = module.weight.data
                abs_weight = weight.abs()
                return torch.max(abs_weight.max(dim=0).values / abs_weight.mean(dim=0)).item()
        
        
        expert_info = []
        for name, module in model.named_modules():
            if ('block_sparse_moe' in name and 'gate' not in name and 
                isinstance(module, torch.nn.Linear) and 'experts' in name):
                parts = name.split('.')
                if len(parts) >= 6:
                    layer_name = f"model.layers.{parts[2]}.block_sparse_moe"
                    expert_idx = int(parts[5])
                    expert_info.append((layer_name, expert_idx, module))
        
        
        for layer_name, expert_idx, module in tqdm(expert_info, desc="Calculating outlier metric"):
            score = compute_outlier_score(module)
            outlier_score[layer_name][expert_idx] += score

        metrics_results["weight-outlier"] = outlier_score
        print("Weight-Outlier metric calculated")
        
        
        del outlier_score, expert_info
        torch.cuda.empty_cache()

    
    if "output-outlier" in metric and dataloader is not None:
        output_outlier_metric = {}
        
        for layer_idx in range(len(model.model.layers)):
            ffn_name = f"model.layers.{layer_idx}.block_sparse_moe"
            output_outlier_metric[ffn_name] = torch.zeros(model.config.num_local_experts)
        
        model.eval()
        for p in model.parameters():
            p.requires_grad_(False)
            
        
        def compute_output_outlier_score(expert_output):
            with torch.no_grad():
                abs_output = expert_output.abs()
                return torch.max(abs_output.max(dim=0).values / abs_output.mean(dim=0)).item()
        
        
        original_forwards = {}
        expert_outputs = {}
        
        def _custom_moe_forward(self, hidden_states):
            
            original_output = self._original_forward(hidden_states)
            
            
            batch_size, sequence_length, hidden_dim = hidden_states.shape
            hidden_states_view = hidden_states.view(-1, hidden_dim)
            
            
            router_logits = self.gate(hidden_states_view)
            routing_weights, routing_indices = torch.topk(router_logits, self.top_k, dim=-1)
            
            
            current_batch_experts = []
            for expert_idx, expert in enumerate(self.experts):
                mask = (routing_indices == expert_idx).any(dim=-1)
                if mask.any():
                    expert_inputs = hidden_states_view[mask]
                    with torch.no_grad():
                        expert_output = expert(expert_inputs)
                    current_batch_experts.append(expert_output)
                else:
                    current_batch_experts.append(torch.tensor([]))
            
            
            expert_outputs[self._module_name] = current_batch_experts
            
            return original_output
        
        
        for name, module in model.named_modules():
            if isinstance(module, MixtralSparseMoeBlock):
                expert_outputs[name] = []
                module._original_forward = module.forward
                module._module_name = name
                module.forward = _custom_moe_forward.__get__(module, type(module))
        
        
        for batch in tqdm(dataloader, desc="Calculating output-outlier metric"):
            batch = {k: v.cuda() for k, v in batch.items()}
            if "labels" in batch:
                batch.pop("labels")
            
            with torch.no_grad():
                model(**batch)
            
            
            for layer_name, expert_outputs_list in expert_outputs.items():
                for expert_idx, expert_output in enumerate(expert_outputs_list):
                    if len(expert_output) > 0:  
                        outlier_score = compute_output_outlier_score(expert_output)
                        output_outlier_metric[layer_name][expert_idx] += outlier_score
        
        
        for name, module in model.named_modules():
            if isinstance(module, MixtralSparseMoeBlock):
                module.forward = module._original_forward
        
        metrics_results["output-outlier"] = output_outlier_metric
        print("Output-Outlier metric calculated")
        
        
        del output_outlier_metric, expert_outputs, original_forwards
        torch.cuda.empty_cache()

    
    if "output-cka" in metric and dataloader is not None:
        output_cka_metric = {}
        for layer_idx in range(len(model.model.layers)):
            ffn_name = f"model.layers.{layer_idx}.block_sparse_moe"
            output_cka_metric[ffn_name] = torch.zeros(4)

        model.eval()
        for p in model.parameters():
            p.requires_grad_(False)

        def _linear_cka(X: torch.Tensor, Y: torch.Tensor, eps: float = 1e-10) -> float:
            X = X.to(torch.float32)
            Y = Y.to(torch.float32)
            X = X - X.mean(dim=0, keepdim=True)
            Y = Y - Y.mean(dim=0, keepdim=True)
            cross = torch.matmul(X.T, Y)
            num = cross.pow(2).sum()
            XX = torch.matmul(X.T, X)
            YY = torch.matmul(Y.T, Y)
            den = torch.sqrt((XX.pow(2).sum() + eps) * (YY.pow(2).sum() + eps))
            return (num / (den + eps)).item()

        
        expert_outputs: Dict[str, List[List[torch.Tensor]]] = {}

        def _custom_moe_forward_cka(self, hidden_states):
            original_output = self._original_forward(hidden_states)
            batch_size, sequence_length, hidden_dim = hidden_states.shape
            hidden_states_view = hidden_states.view(-1, hidden_dim)
            current_batch_experts = [[] for _ in range(len(self.experts))]
            for expert_idx, expert in enumerate(self.experts):
                with torch.no_grad():
                    expert_output = expert(hidden_states_view)
                current_batch_experts[expert_idx] = expert_output.detach().cpu()
            if self._module_name not in expert_outputs:
                expert_outputs[self._module_name] = [[] for _ in range(len(self.experts))]
            for i, outputs in enumerate(current_batch_experts):
                if outputs.numel() > 0:
                    expert_outputs[self._module_name][i].append(outputs)
            return original_output

        
        for name, module in model.named_modules():
            if isinstance(module, MixtralSparseMoeBlock):
                module._original_forward = module.forward
                module._module_name = name
                module.forward = _custom_moe_forward_cka.__get__(module, type(module))

        
        for batch in tqdm(dataloader, desc="Collecting expert outputs for CKA"):
            batch = {k: v.cuda() for k, v in batch.items()}
            if "labels" in batch:
                batch.pop("labels")
            with torch.no_grad():
                model(**batch)

        
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        max_samples = 4096
        for layer_name, layer_expert_outputs in tqdm(expert_outputs.items(), desc='Calculating output-cka metric'):
            merged_outputs: List[Optional[torch.Tensor]] = []
            for expert_batch_outputs in layer_expert_outputs:
                if len(expert_batch_outputs) > 0:
                    merged = torch.cat(expert_batch_outputs, dim=0)
                    if merged.size(0) > max_samples:
                        merged = merged[:max_samples]
                    merged_outputs.append(merged)
                else:
                    merged_outputs.append(None)

            cka_values: List[float] = []
            num_experts_local = len(merged_outputs)
            for i in range(num_experts_local):
                for j in range(i + 1, num_experts_local):
                    if merged_outputs[i] is not None and merged_outputs[j] is not None:
                        Xi = merged_outputs[i].to(device)
                        Yj = merged_outputs[j].to(device)
                        
                        if Xi.dim() > 2:
                            Xi = Xi.reshape(Xi.shape[0], -1)
                        if Yj.dim() > 2:
                            Yj = Yj.reshape(Yj.shape[0], -1)
                        try:
                            val = _linear_cka(Xi, Yj)
                        except Exception:
                            val = 0.0
                        cka_values.append(float(val))
                    else:
                        cka_values.append(0.0)

            if cka_values:
                t = torch.tensor(cka_values)
                output_cka_metric[layer_name][0] = torch.max(t)
                output_cka_metric[layer_name][1] = torch.min(t)
                output_cka_metric[layer_name][2] = torch.mean(t)
                output_cka_metric[layer_name][3] = torch.std(t)

        
        for name, module in model.named_modules():
            if isinstance(module, MixtralSparseMoeBlock):
                module.forward = module._original_forward

        metrics_results["output-cka"] = output_cka_metric
        print("Output-CKA metric calculated")
        torch.cuda.empty_cache()

    
    if "weight-cka" in metric:
        weight_cka_metric = {}
        for layer_idx in range(len(model.model.layers)):
            ffn_name = f"model.layers.{layer_idx}.block_sparse_moe"
            weight_cka_metric[ffn_name] = torch.zeros(4)

        model.eval()
        for p in model.parameters():
            p.requires_grad_(False)

        def _linear_cka_2d(X: torch.Tensor, Y: torch.Tensor, eps: float = 1e-10) -> float:
            
            X = X.to(torch.float32)
            Y = Y.to(torch.float32)
            X = X - X.mean(dim=0, keepdim=True)
            Y = Y - Y.mean(dim=0, keepdim=True)
            cross = torch.matmul(X.T, Y)
            num = cross.pow(2).sum()
            XX = torch.matmul(X.T, X)
            YY = torch.matmul(Y.T, Y)
            den = torch.sqrt((XX.pow(2).sum() + eps) * (YY.pow(2).sum() + eps))
            return (num / (den + eps)).item()

        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        for layer_idx in tqdm(range(len(model.model.layers)), desc="Calculating weight-cka metric"):
            moe_block = model.model.layers[layer_idx].block_sparse_moe
            experts = moe_block.experts
            cka_values: List[float] = []
            for i in range(len(experts)):
                for j in range(i + 1, len(experts)):
                    ei = experts[i]
                    ej = experts[j]
                    total_val = 0.0
                    for wi, wj in ((ei.w1.weight, ej.w1.weight), (ei.w2.weight, ej.w2.weight), (ei.w3.weight, ej.w3.weight)):
                        Xi = wi.detach().to(device)
                        Yj = wj.detach().to(device)
                        if Xi.dim() == 2 and Yj.dim() == 2 and Xi.size(1) == Yj.size(1):
                            val = _linear_cka_2d(Xi, Yj)
                        else:
                            
                            val = _linear_cka_2d(Xi.reshape(1, -1), Yj.reshape(1, -1))
                        total_val += float(val)
                    cka_values.append(total_val)

            if cka_values:
                t = torch.tensor(cka_values)
                ffn_name = f"model.layers.{layer_idx}.block_sparse_moe"
                weight_cka_metric[ffn_name][0] = torch.max(t)
                weight_cka_metric[ffn_name][1] = torch.min(t)
                weight_cka_metric[ffn_name][2] = torch.mean(t)
                weight_cka_metric[ffn_name][3] = torch.std(t)

        metrics_results["weight-cka"] = weight_cka_metric
        print("Weight-CKA metric calculated")
        torch.cuda.empty_cache()

    
    time_start = time.time()
    if "output-cosine" in metric and dataloader is not None:
        output_cosine_metric = {}
        
        for layer_idx in range(len(model.model.layers)):
            ffn_name = f"model.layers.{layer_idx}.block_sparse_moe"
            output_cosine_metric[ffn_name] = torch.zeros(4)  
        
        model.eval()
        
        
        def compute_cosine_similarity(X, Y):
            
            if X.dim() > 2:
                X = X.reshape(X.shape[0], -1).to(torch.float32)
            if Y.dim() > 2:
                Y = Y.reshape(Y.shape[0], -1).to(torch.float32)

            
            min_rows = min(X.shape[0], Y.shape[0])
            X = X[:min_rows]
            Y = Y[:min_rows]
            
            
            return torch.nn.functional.cosine_similarity(X, Y).mean().item()
        
        
        original_forwards = {}
        expert_outputs = {}
        
        def _custom_moe_forward_cosine(self, hidden_states):
            
            original_output = self._original_forward(hidden_states)
            
            
            batch_size, sequence_length, hidden_dim = hidden_states.shape
            hidden_states_view = hidden_states.view(-1, hidden_dim)
            
            
            current_batch_experts = [[] for _ in range(len(self.experts))]
            for expert_idx, expert in enumerate(self.experts):
                with torch.no_grad():
                    expert_output = expert(hidden_states_view)
                current_batch_experts[expert_idx] = expert_output.detach().cpu()
            
            
            if self._module_name not in expert_outputs:
                expert_outputs[self._module_name] = [[] for _ in range(len(self.experts))]
            
            for i, outputs in enumerate(current_batch_experts):
                if len(outputs) > 0:
                    expert_outputs[self._module_name][i].append(outputs)
            
            return original_output
        
        
        for name, module in model.named_modules():
            if isinstance(module, MixtralSparseMoeBlock):
                module._original_forward = module.forward
                module._module_name = name
                module.forward = _custom_moe_forward_cosine.__get__(module, type(module))
        
        
        for batch in tqdm(dataloader, desc="Collecting expert outputs for cosine similarity"):
            batch = {k: v.cuda() for k, v in batch.items()}
            if "labels" in batch:
                batch.pop("labels")
            
            with torch.no_grad():
                model(**batch)
        
        
        for layer_name, layer_expert_outputs in tqdm(expert_outputs.items(), desc="Calculating output-cosine metric"):
            
            merged_outputs = []
            for expert_idx, expert_batch_outputs in enumerate(layer_expert_outputs):
                
                merged = torch.cat(expert_batch_outputs, dim=0).detach().cpu()  
                merged_outputs.append(merged)
            
            
            num_experts = len(merged_outputs)
            cosine_values = []
            
            for i in range(num_experts):
                for j in range(i+1, num_experts):
                    if merged_outputs[i] is not None and merged_outputs[j] is not None:
                        
                        
                        cosine_value = compute_cosine_similarity(merged_outputs[i], merged_outputs[j])
                        cosine_values.append(cosine_value)
                    else:
                        
                        cosine_values.append(0.0)
            
            
            if cosine_values:  
                cosine_tensor = torch.tensor(cosine_values)
                
                output_cosine_metric[layer_name][0] = torch.max(cosine_tensor)      
                output_cosine_metric[layer_name][1] = torch.min(cosine_tensor)      
                output_cosine_metric[layer_name][2] = torch.mean(cosine_tensor)     
                output_cosine_metric[layer_name][3] = torch.std(cosine_tensor)      
            
            
            del merged_outputs, cosine_values
        
        
        for name, module in model.named_modules():
            if isinstance(module, MixtralSparseMoeBlock):
                module.forward = module._original_forward
        
        metrics_results["output-cosine"] = output_cosine_metric
        print("Output-Cosine metric calculated")
        
        
        del output_cosine_metric, expert_outputs
        torch.cuda.empty_cache()
    print(f"Output-Cosine metric calculated时间: {time.time() - time_start}")

    
    time_start = time.time()
    if "weight-cosine" in metric:
        weight_cosine_metric = {}
        
        for layer_idx in range(len(model.model.layers)):
            ffn_name = f"model.layers.{layer_idx}.block_sparse_moe"
            weight_cosine_metric[ffn_name] = torch.zeros(4)  
        
        model.eval()
        
        
        def compute_cosine_similarity_optimized(X, Y):
            
            X_flat = X.reshape(-1).to(torch.float32)
            Y_flat = Y.reshape(-1).to(torch.float32)
            
            
            return torch.nn.functional.cosine_similarity(X_flat.unsqueeze(0), Y_flat.unsqueeze(0))[0]
        
        
        for layer_idx in tqdm(range(len(model.model.layers)), desc="Calculating weight-cosine metric"):
            layer = model.model.layers[layer_idx]
            moe_block = layer.block_sparse_moe
            
            
            experts = moe_block.experts
            num_experts = len(experts)
            
            expert_weights = []
            for expert in experts:
                expert_w = {
                    "w1": expert.w1.weight,
                    "w2": expert.w2.weight,
                    "w3": expert.w3.weight
                }
                expert_weights.append(expert_w)
            
            
            cosine_values = []
            
            
            for i in range(num_experts):
                for j in range(i+1, num_experts):
                    
                    w1_cosine = compute_cosine_similarity_optimized(
                        expert_weights[i]["w1"], expert_weights[j]["w1"]
                    )
                    w2_cosine = compute_cosine_similarity_optimized(
                        expert_weights[i]["w2"], expert_weights[j]["w2"]
                    )
                    w3_cosine = compute_cosine_similarity_optimized(
                        expert_weights[i]["w3"], expert_weights[j]["w3"]
                    )
                    
                    
                    total_cosine = w1_cosine + w2_cosine + w3_cosine
                    
                    cosine_values.append(total_cosine.item())
            
            
            if cosine_values:  
                ffn_name = f"model.layers.{layer_idx}.block_sparse_moe"
                cosine_tensor = torch.tensor(cosine_values)
                
                
                weight_cosine_metric[ffn_name][0] = torch.max(cosine_tensor)      
                weight_cosine_metric[ffn_name][1] = torch.min(cosine_tensor)      
                weight_cosine_metric[ffn_name][2] = torch.mean(cosine_tensor)     
                weight_cosine_metric[ffn_name][3] = torch.std(cosine_tensor)      
            
            
            del expert_weights, cosine_values
        
        metrics_results["weight-cosine"] = weight_cosine_metric
        print("Weight-Cosine metric calculated")
        
        
        del weight_cosine_metric, cosine_tensor
        torch.cuda.empty_cache()

    if "mutual_information" in metric:
        
        
        
        t = sigmoid_t
        try:
            json_path = os.path.join("importance_score", "layer_wise", "mixtral_mutual_information_values.json")
            with open(json_path, "r", encoding="utf-8") as f:
                mi_data = json.load(f)
            mi_metric = {}
            num_layers = len(model.model.layers)
            for layer_idx in range(num_layers):
                layer_key = f"layer_{layer_idx}"
                layer_name = f"model.layers.{layer_idx}.block_sparse_moe"
                if layer_key in mi_data:
                    final_score = float(mi_data[layer_key].get("mutual_information_sum", 0.0))
                else:
                    final_score = 0.0
                final_score = 1 / (1 + np.exp(-t * final_score))
                final_score = 1 - final_score
                mi_metric[layer_name] = torch.tensor(final_score)
            metrics_results["mutual_information"] = mi_metric
            print("Mutual-Information metric loaded")
        except Exception as e:
            print(f"read mutual_information JSON failure: {e}")

    if not os.path.exists(save_dir):
        os.makedirs(save_dir)
    torch.save(metrics_results, os.path.join(save_dir, save_name))
    print(f"Dumped to {os.path.join(save_dir, save_name)}")
    return metrics_results

@torch.inference_mode()
def load_model_and_dump_metrics(
    model_name: str = "/Path/Mixtral-8x7B-v0.1",
    calib_set: Optional[str] = "c4",
    opt_batch_size: Optional[int] = 16,
    n_sentences: Optional[int] = 16,
    metric_save_dir: Optional[str] = "./results/metrics",
    metrics: list[str] = ['all'],
    override: bool = False,
    sigmoid_t: float = 0.5,
):
    """Load a Mixtral checkpoint, compute the requested MoE metrics and dump them.

    The tokenizer and an fp16 model (``device_map="auto"``) are loaded from
    ``model_name``, a calibration dataloader is built with
    ``get_calib_dataloder`` and everything is handed to
    ``dump_mixtral_moe_metrics``. The model, tokenizer and dataloader are
    released before returning.

    Args:
        model_name: Path or hub id passed to ``from_pretrained``.
        calib_set: Calibration dataset name forwarded to ``get_calib_dataloder``.
        opt_batch_size: Batch size of the calibration dataloader.
        n_sentences: Forwarded as ``n_blocks_for_stat`` to the dataloader factory.
        metric_save_dir: Directory forwarded as ``save_dir`` to the dump routine.
        metrics: Metric names forwarded as ``metric`` to the dump routine.
        override: Forwarded unchanged to the dump routine.
        sigmoid_t: Forwarded unchanged to the dump routine.

    Returns:
        The metrics file *name* (``"mixtral_moe_metrics.pt"``), not a full path.
    """
    # NOTE: the banner always says "C4", whatever ``calib_set`` is.
    print(f"Get OWL Metric on C4.\n Model: {model_name}\n opt_batch_size={opt_batch_size}, n_sentences={n_sentences}")

    # Tokenizer first: the EOS token doubles as the padding token.
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    tokenizer.pad_token_id = tokenizer.eos_token_id

    # Half-precision weights, placed across available devices; ``eval()``
    # returns the module itself, so the call can be chained.
    model = MixtralForCausalLM.from_pretrained(
        model_name,
        torch_dtype=torch.float16,
        device_map="auto",
    ).eval()

    calib_loader = get_calib_dataloder(
        dataset=calib_set,
        tokenizer=tokenizer,
        max_block_size=2048,
        n_blocks_for_stat=n_sentences,
        batch_size=opt_batch_size,
        num_workers=8,
    )

    output_name = "mixtral_moe_metrics.pt"
    dump_mixtral_moe_metrics(
        model=model,
        metric=metrics,
        dataloader=calib_loader,
        save_dir=metric_save_dir,
        save_name=output_name,
        override=override,
        sigmoid_t=sigmoid_t,
    )

    # Drop the large objects before asking CUDA to release cached blocks.
    del model, tokenizer, calib_loader
    torch.cuda.empty_cache()

    return output_name

def adapt_compress_mixtral(
    model_name: str = "/Path/Mixtral-8x7B-v0.1",
    calib_set: Optional[str] = "c4",
    opt_batch_size: Optional[int] = 16,
    n_sentences: Optional[int] = 16,
    skip_metrics_calculation: bool = True,
    metrics: list[str] = ['all'],
    objective_type: str = None,

    min_margin: int = None,
    max_neighbor_diff: Optional[int] = None,

    # Pruning stage.
    prune_method: str = "mutual_information",
    pruning_ratio: float = 0.75,
    sigmoid_t: float = 0.001,

    # Merging stage.
    merge_method: str = "output-cosine",
    stat_type: int = 0,
    merging_ratio: float = 0.67,

    # Stage ordering.
    optimize_order: str = "prune_first",

    # Persistence.
    save_opt: bool = True,
    metric_save_dir: Optional[str] = "./output/metrics",
    opt_save_dir: Optional[str] = "./output/opt",
    override: bool = False,

    # Optional pre-loaded model / tokenizer.
    model: Optional[MixtralForCausalLM] = None,
    tokenizer: Optional[AutoTokenizer] = None,
):
    """Compute per-layer expert budgets for pruning and merging a Mixtral model.

    The function (1) returns a cached YAML result when one exists and
    ``override`` is false, (2) optionally (re)computes the metrics file,
    (3) calls ``cal_prune_num`` / ``cal_merge_num`` in the order selected by
    ``optimize_order`` and (4) optionally writes the result to YAML.

    Args:
        model_name: Checkpoint used when the metrics must be computed and no
            ``model``/``tokenizer`` pair was supplied.
        calib_set: Calibration dataset name for metric computation.
        opt_batch_size: Calibration batch size.
        n_sentences: Forwarded as ``n_blocks_for_stat`` to the dataloader factory.
        skip_metrics_calculation: When true, metric computation is never
            triggered here; the metrics file is presumably expected to exist.
        metrics: Metric names forwarded to the metric dump routine.
        objective_type: Forwarded to both ``cal_*`` calls and embedded in the
            result file name.
        min_margin: Forwarded to both ``cal_*`` calls.
        max_neighbor_diff: Forwarded to both ``cal_*`` calls.
        prune_method: Metric name handed to ``cal_prune_num``.
        pruning_ratio: Forwarded to ``cal_prune_num``.
        sigmoid_t: Forwarded to the metric dump routine.
        merge_method: Metric name handed to ``cal_merge_num``.
        stat_type: Key into ``STAT_TYPE_DICT`` (0 max, 1 min, 2 mean, 3 std);
            forwarded to both ``cal_*`` calls and used in the file name.
        merging_ratio: Forwarded to ``cal_merge_num``.
        optimize_order: One of ``"prune_first"``, ``"merge_first"``,
            ``"prune_only"`` or ``"merge_only"``.
        save_opt: Write the result to YAML when true.
        metric_save_dir: Directory holding ``mixtral_moe_metrics.pt``.
        opt_save_dir: Root directory for result files; a sub-directory named
            after ``optimize_order`` is used.
        override: Ignore an existing result file and allow the metrics file
            to be recomputed.
        model: Already-loaded model; used only together with ``tokenizer``.
        tokenizer: Already-loaded tokenizer; used only together with ``model``.

    Returns:
        Dict with ``"prune_experts"`` and ``"merge_experts"`` arrays. For
        ``"prune_only"`` both keys hold the pruning result; for
        ``"merge_only"`` both hold the merging result.

    Raises:
        ValueError: If ``optimize_order`` is not one of the supported values.
    """
    stat_type_str = STAT_TYPE_DICT.get(stat_type, "") if stat_type is not None else ""
    opt_save_dir = os.path.join(opt_save_dir, f"{optimize_order}")

    # File name layout is kept exactly as before so existing caches still hit
    # (note: no separator between the stat name and the merge method).
    if stat_type_str:
        result_file_path = os.path.join(opt_save_dir, f"mixtral_{optimize_order}_{prune_method}_{pruning_ratio}_{stat_type_str}{merge_method}_{merging_ratio}_{objective_type}.yaml")
    else:
        result_file_path = os.path.join(opt_save_dir, f"mixtral_{optimize_order}_{prune_method}_{pruning_ratio}_{merge_method}_{merging_ratio}_{objective_type}.yaml")

    # Cache hit: reuse a previously saved result unless told to override.
    if os.path.exists(result_file_path) and not override:
        with open(result_file_path, 'r') as f:
            result_dict = yaml.safe_load(f)
        # safe_load yields None for an empty file (e.g. an interrupted write);
        # fall through and recompute rather than crash on the subscript.
        if isinstance(result_dict, dict) and "merge_experts" in result_dict and "prune_experts" in result_dict:
            return {
                "merge_experts": np.array(result_dict["merge_experts"]),
                "prune_experts": np.array(result_dict["prune_experts"]),
            }
        print(f"Cached result {result_file_path} is empty or incomplete, recomputing")

    # Fail fast: reject a bad order before any model loading / metric work.
    if optimize_order not in ("prune_first", "merge_first", "prune_only", "merge_only"):
        raise ValueError(f"Unsupported optimize order: {optimize_order}，Please choose 'prune_first', 'merge_first', 'prune_only' or 'merge_only'")

    metrics_file = "mixtral_moe_metrics.pt"
    metrics_file_path = os.path.join(metric_save_dir, metrics_file)
    if not skip_metrics_calculation and (not os.path.exists(metrics_file_path) or override):
        if model is None or tokenizer is None:
            # No usable model/tokenizer pair supplied: load from disk.
            metrics_file = load_model_and_dump_metrics(
                model_name=model_name,
                calib_set=calib_set,
                opt_batch_size=opt_batch_size,
                n_sentences=n_sentences,
                metric_save_dir=metric_save_dir,
                metrics=metrics,
                override=override,
                sigmoid_t=sigmoid_t,
            )
        else:
            dataloader = get_calib_dataloder(
                dataset=calib_set,
                tokenizer=tokenizer,
                max_block_size=2048,
                n_blocks_for_stat=n_sentences,
                batch_size=opt_batch_size,
                num_workers=8,
            )
            _ = dump_mixtral_moe_metrics(
                model=model,
                metric=metrics,
                dataloader=dataloader,
                save_dir=metric_save_dir,
                save_name=metrics_file,
                override=override,
                sigmoid_t=sigmoid_t,
            )

    # Keyword arguments shared by every cal_prune_num / cal_merge_num call.
    prune_kwargs = dict(
        save_dir=metric_save_dir,
        save_name=metrics_file,
        metric=prune_method,
        stat_type=stat_type,
        min_margin=min_margin,
        pruning_ratio=pruning_ratio,
        optimize_order=optimize_order,
        max_neighbor_diff=max_neighbor_diff,
        objective_type=objective_type,
    )
    merge_kwargs = dict(
        save_dir=metric_save_dir,
        save_name=metrics_file,
        metric=merge_method,
        stat_type=stat_type,
        min_margin=min_margin,
        merging_ratio=merging_ratio,
        optimize_order=optimize_order,
        objective_type=objective_type,
        max_neighbor_diff=max_neighbor_diff,
    )

    def _per_layer_cap(counts):
        # Map each layer's MoE block name to the budget left by the first
        # stage. The layer count comes from the array itself instead of a
        # hard-coded 32, so other model depths work too.
        return {
            f"model.layers.{layer_idx}.block_sparse_moe": count
            for layer_idx, count in enumerate(counts)
        }

    results = {}
    num_experts = 8  # experts per MoE layer assumed for the summary ratio

    time_start = time.time()
    if optimize_order == "prune_first":
        # Stage 1: pruning budget; stage 2: merging capped by stage 1.
        prune_nums, rounding_residuals = cal_prune_num(
            return_residuals=True, **prune_kwargs
        )
        results["prune_experts"] = prune_nums
        results["merge_experts"] = cal_merge_num(
            custom_max_experts=_per_layer_cap(prune_nums),
            rounding_residuals=rounding_residuals,
            **merge_kwargs,
        )
    elif optimize_order == "merge_first":
        # Stage 1: merging budget; stage 2: pruning capped by stage 1.
        merge_nums, rounding_residuals = cal_merge_num(
            return_residuals=True, **merge_kwargs
        )
        results["merge_experts"] = merge_nums
        results["prune_experts"] = cal_prune_num(
            custom_max_experts=_per_layer_cap(merge_nums),
            rounding_residuals=rounding_residuals,
            **prune_kwargs,
        )
    elif optimize_order == "prune_only":
        prune_nums = cal_prune_num(**prune_kwargs)
        results["prune_experts"] = prune_nums
        results["merge_experts"] = prune_nums
    else:  # "merge_only" (validated above)
        target_experts = cal_merge_num(**merge_kwargs)
        results["merge_experts"] = target_experts
        results["prune_experts"] = target_experts

    # Denominator uses the actual number of layers in the result.
    num_layers = len(results["merge_experts"])

    print(f"\n===== Optimized result summary =====")
    print(f"prune_experts: {sum(results['prune_experts'])}")
    print(f"merge_experts: {sum(results['merge_experts'])}")
    print(f"Compression rate: {sum(results['merge_experts'])/(num_experts * num_layers):.4f}")
    print(f"Optimization time: {time.time() - time_start:.2f}s")

    if save_opt:
        # Build the payload BEFORE opening the file: a conversion failure must
        # not leave a truncated YAML that would poison the cache lookup above.
        results_dict = {
            "prune_experts": np.asarray(results['prune_experts']).tolist(),
            "merge_experts": np.asarray(results['merge_experts']).tolist(),
        }
        os.makedirs(opt_save_dir, exist_ok=True)
        with open(result_file_path, "w") as f:
            yaml.dump(results_dict, f, default_flow_style=False, sort_keys=False, indent=4)

    return results

# Script entry point: hand adapt_compress_mixtral to python-fire so it can be
# driven from the command line.
if __name__ == "__main__":
    Fire(adapt_compress_mixtral)
