import os
import numpy as np
import time
from typing import Optional, Dict, List, Union, Tuple
import logging
import torch
import json
from argparse import Namespace
from fire import Fire
from transformers import Qwen2MoeForCausalLM, AutoTokenizer, AutoModelForCausalLM

from eval import get_calib_dataloder, evaluate_fewshot
from merge_method import (
    run_qwen_tamp,
    adapt_compress_qwen,
    cal_prune_num,
    cal_merge_num,
    dump_qwen_moe_metrics,
    STAT_TYPE_DICT
)
from utils import (
    eval_task_to_list,
    load_opt_config,
    save_opt_config,
    set_project_seed,
    gen_smooth_staircase
)

from extend import evaluation, ppl_eval

from prune_method import adapt_prune
from numpy.random import default_rng

logger = logging.getLogger(__name__)


# Preset (pruning_ratio, merging_ratio) pairs selectable through ``preset_id``.
# Going from preset 0 to 9 the pruning ratio rises while the merging ratio
# falls; the product of each pair is approximately 0.5.
PRESET_RATIOS = {
    preset_key: {"pruning_ratio": prune_value, "merging_ratio": merge_value}
    for preset_key, (prune_value, merge_value) in enumerate(
        (
            (0.5454545454545454, 0.9166666666666667),
            (0.5909090909090909, 0.8461538461538461),
            (0.6363636363636364, 0.7857142857142857),
            (0.6818181818181819, 0.7333333333333333),
            (0.7272727272727273, 0.6875),
            (0.7727272727272727, 0.6470588235294118),
            (0.8181818181818181, 0.6111111111111112),
            (0.8636363636363636, 0.5789473684210527),
            (0.9090909090909092, 0.5499999999999999),
            (0.9545454545454546, 0.5238095238095238),
        )
    )
}

def call_expert_sparsity(
    model_path: str,
    output_path: str,
    max_block_size: int = 2048,
    calib_set: str = "c4",
    prune: str = "layerwise_pruning_qwen",
    experts_per_layer: Optional[List[int]] = None,
    n_blocks_for_stat: int = 32,
    batch_size: int = 16,
    seed: int = 42,
    eval_tasks: Optional[List[str]] = None,
    eval_path: Optional[str] = None,
    config_path: Optional[str] = None,
    model = None,
    tokenizer = None,
    result_save_path: Optional[str] = None,
    opt_save_dir: Optional[str] = None,
    model_output_dir: Optional[str] = None,
) -> Tuple[str, Qwen2MoeForCausalLM]:
    """Run the expert-pruning stage by delegating to ``adapt_prune``.

    The arguments are packed into an ``argparse.Namespace`` (the interface
    ``adapt_prune`` is called with here) together with a few fixed settings
    (``num_workers=8``, ``use_flash_attention_2=False``,
    ``evaluate_model=False``, ``expert_config="prune_experts"``).

    Args:
        model_path: Stored on the namespace as ``model_path``.
        output_path: Directory that is created if missing and stored on the
            namespace; also the parent of the returned model path.
        max_block_size, calib_set, n_blocks_for_stat, batch_size, seed:
            Stored on the namespace under the same names.
        prune: Stored on the namespace as ``method``.
        experts_per_layer, eval_tasks, eval_path: Accepted for call-site
            compatibility; not read by this function.
        config_path: Stored on the namespace and afterwards loaded with
            ``load_opt_config`` to read the ``'prune_experts'`` entry.
            NOTE(review): the default ``None`` is passed to
            ``load_opt_config`` unchanged, after pruning has already run.
        model, tokenizer: Stored on the namespace as ``provided_model`` and
            ``provided_tokenizer``.
        result_save_path, opt_save_dir, model_output_dir: Stored on the
            namespace under the same names.

    Returns:
        ``(pruned_model_path, pruned_model)`` where the path is
        ``<output_path>/pruned_model`` and the model is what ``adapt_prune``
        returned, with ``config.layer_num_experts`` set from the config file.
        This function itself writes nothing to the returned path.
    """
    os.makedirs(output_path, exist_ok=True)

    # Keyword order mirrors the attribute order adapt_prune has always received.
    prune_settings = {
        "method": prune,
        "r": None,
        "calib_set": calib_set,
        "model_path": model_path,
        "output_path": output_path,
        "max_block_size": max_block_size,
        "n_blocks_for_stat": n_blocks_for_stat,
        "batch_size": batch_size,
        "num_workers": 8,
        "seed": seed,
        "use_flash_attention_2": False,
        "layer_index": None,
        "evaluate_model": False,
        "config_path": config_path,
        "expert_config": "prune_experts",
        "result_save_path": result_save_path,
        "opt_save_dir": opt_save_dir,
        "model_output_dir": model_output_dir,
        "provided_model": model,
        "provided_tokenizer": tokenizer,
    }
    args = Namespace(**prune_settings)

    logger.info(f"Executing Expert Sparsity with method: {prune}")
    pruned_model = adapt_prune(args)

    # Record the per-layer expert counts from the saved configuration on the model.
    layer_expert_counts = load_opt_config(config_path)['prune_experts']
    pruned_model.config.layer_num_experts = layer_expert_counts

    return os.path.join(output_path, "pruned_model"), pruned_model

def call_hcmoe(
    model_path: str,
    output_path: str,
    task: str = "mmlu",
    num_average_groups: Optional[List[int]] = None,
    dominant: str = "no",
    group: str = "expert",
    similarity_base: str = "router-logits",
    merge: str = "zipit",
    mode: str = "normal",
    n_sentences: int = 16,
    max_block_size: int = 2048,
    train_batch_size: int = 16,
    eval_batch_size: int = 32,
    num_fewshot: int = 0,
    dynamic_group: bool = False,
    opt_path: Optional[str] = None,
    model = None,
    tokenizer = None,
    cluster: Optional[str] = "hierarchical",
    linkage: Optional[str] = "ward",
    hierarchical_stopping_metric: Optional[str] = "silhouette",
    ingredient: Optional[str] = "act",
    overlap_metric: Optional[str] = "cosine",
    calib_set: str = "c4",
    seed: int = 42,
) -> Tuple[str, Qwen2MoeForCausalLM]:
    """Run the expert-merging stage by delegating to ``run_qwen_tamp``.

    Args:
        model_path: Forwarded to ``run_qwen_tamp`` as ``model_name`` (its own
            ``model_path`` argument is always passed as ``None``).
        output_path: Directory created if missing; forwarded, used as the
            parent of ``eval_results.json`` and returned as the model path.
        num_average_groups: Only element ``[0]`` is forwarded.
            NOTE(review): the default ``None`` is not subscriptable, so callers
            must always supply a non-empty sequence.
        opt_path: Forwarded, and afterwards loaded with ``load_opt_config`` to
            read the ``'merge_experts'`` entry.
        model, tokenizer: Forwarded unchanged.
        task, dominant, group, similarity_base, merge, mode, n_sentences,
        max_block_size, train_batch_size, eval_batch_size, num_fewshot,
        dynamic_group, cluster, linkage, hierarchical_stopping_metric,
        ingredient, overlap_metric, calib_set, seed: Forwarded under the same
            names.

    Returns:
        ``(output_path, merged_model)`` where the model is what
        ``run_qwen_tamp`` returned, with ``config.layer_num_experts`` set from
        the config file.
    """
    os.makedirs(output_path, exist_ok=True)
    result_path = os.path.join(output_path, "eval_results.json")

    logger.info(f"Executing HCMOE with similarity_base: {similarity_base}, merge: {merge}, mode: {mode}")

    # Caller-controlled settings first, then the values this wrapper pins.
    tamp_kwargs = dict(
        task=task,
        num_average_groups=num_average_groups[0],
        model_name=model_path,
        dominant=dominant,
        group=group,
        similarity_base=similarity_base,
        merge=merge,
        mode=mode,
        n_sentences=n_sentences,
        max_block_size=max_block_size,
        calib_set=calib_set,
        train_batch_size=train_batch_size,
        eval_batch_size=eval_batch_size,
        seed=seed,
        output_path=output_path,
        result_path=result_path,
        num_fewshot=num_fewshot,
        cluster=cluster,
        linkage=linkage,
        hierarchical_stopping_metric=hierarchical_stopping_metric,
        ingredient=ingredient,
        overlap_metric=overlap_metric,
        dynamic_group=dynamic_group,
        opt_path=opt_path,
        model=model,
        tokenizer=tokenizer,
    )
    tamp_kwargs.update(
        partition=1,
        start_layer=0,
        model_path=None,
        group_limit=4,
        data_limit=50000,
        try_oracle=False,
        random_start_center=False,
        weight=None,
    )
    merged_model = run_qwen_tamp(**tamp_kwargs)

    # Record the per-layer expert counts from the saved configuration on the model.
    layer_expert_counts = load_opt_config(opt_path)['merge_experts']
    merged_model.config.layer_num_experts = layer_expert_counts

    return output_path, merged_model

# Allocation strategies and execution orders accepted by adapt_qwen.
_ADAPT_MODES = ("adapt", "uniform", "owl", "random", "up", "down", "customize")
_OPTIMIZE_ORDERS = ("prune_first", "merge_first", "prune_only", "merge_only")
_PRUNE_LED_ORDERS = ("prune_first", "prune_only")
_MERGE_LED_ORDERS = ("merge_first", "merge_only")


def _require_ratio(name, value, adapt_mode, optimize_order):
    """Return ``value``, raising a clear error when a required ratio is unset."""
    if value is None:
        raise ValueError(
            f"{name} must be specified when adapt_mode='{adapt_mode}' "
            f"and optimize_order='{optimize_order}'."
        )
    return value


def _validate_customized_experts(name, experts, num_layers, num_experts):
    """Check one customized per-layer list against the loaded model's shape."""
    if experts is None:
        return
    if len(experts) != num_layers:
        raise ValueError(f"{name} must have length {num_layers}, got {len(experts)}")
    for i, num in enumerate(experts):
        if num < 2 or num > num_experts:
            raise ValueError(f"{name}[{i}] = {num} is out of range [2, {num_experts}]")


def _customized_allocation(optimize_order, customized_prune_experts, customized_merge_experts):
    """Derive (prune_experts, merge_experts) from user-supplied per-layer lists."""
    if optimize_order in _PRUNE_LED_ORDERS:
        # list(...) rather than .copy(): the CLI may hand over a tuple.
        prune_experts = list(customized_prune_experts)
        if optimize_order == "prune_only" or customized_merge_experts is None:
            merge_experts = list(prune_experts)
        else:
            # The second step can never keep more experts than the first left.
            merge_experts = [min(m, p) for m, p in zip(customized_merge_experts, prune_experts)]
    else:
        merge_experts = list(customized_merge_experts)
        if optimize_order == "merge_only" or customized_prune_experts is None:
            prune_experts = list(merge_experts)
        else:
            prune_experts = [min(p, m) for p, m in zip(customized_prune_experts, merge_experts)]
    return prune_experts, merge_experts


def _sample_until_close(draw, target_total, tolerance, lowest_total, highest_total):
    """Rejection-sample ``draw()`` until its sum is within ``tolerance`` of the target."""
    if target_total < lowest_total - tolerance or target_total > highest_total + tolerance:
        # Without this guard the loop below could never terminate.
        raise ValueError(
            f"Cannot reach a target of {target_total} total experts: achievable totals "
            f"lie in [{lowest_total}, {highest_total}] (tolerance {tolerance})."
        )
    while True:
        candidate = draw()
        if abs(np.sum(candidate) - target_total) <= tolerance:
            return candidate


def _random_allocation(optimize_order, num_layers, num_experts, min_experts,
                       pruning_ratio, merging_ratio, seed):
    """Draw random per-layer expert counts whose totals match the requested ratios."""
    np.random.seed(seed)
    max_experts = num_experts
    tolerance = num_layers // 2
    total_experts = num_layers * num_experts

    def draw_first_step():
        return np.random.randint(min_experts, max_experts + 1, size=num_layers)

    if optimize_order in ("merge_only", "prune_only"):
        if optimize_order == "merge_only":
            ratio = _require_ratio("merging_ratio", merging_ratio, "random", optimize_order)
        else:
            ratio = _require_ratio("pruning_ratio", pruning_ratio, "random", optimize_order)
        target_total_experts = int(total_experts * ratio)
        logger.info(f"Target total experts after merging: {target_total_experts}")
        step2_experts = _sample_until_close(
            draw_first_step, target_total_experts, tolerance,
            num_layers * min_experts, num_layers * max_experts,
        )
        # Single-step orders use the same allocation for both stages.
        step1_experts = step2_experts.copy()
    else:
        if optimize_order == "prune_first":
            first_name, first_ratio = "pruning_ratio", pruning_ratio
            second_name, second_ratio = "merging_ratio", merging_ratio
        else:
            first_name, first_ratio = "merging_ratio", merging_ratio
            second_name, second_ratio = "pruning_ratio", pruning_ratio
        _require_ratio(first_name, first_ratio, "random", optimize_order)
        _require_ratio(second_name, second_ratio, "random", optimize_order)

        target_experts = int(total_experts * first_ratio)
        logger.info(f"Target total experts after step1: {target_experts}")
        step1_experts = _sample_until_close(
            draw_first_step, target_experts, tolerance,
            num_layers * min_experts, num_layers * max_experts,
        )

        target_experts = int(np.sum(step1_experts) * second_ratio)
        logger.info(f"Target total experts after step2: {target_experts}")

        def draw_second_step():
            # Each layer keeps at most what the first step left for it.
            return np.array([
                np.random.randint(min_experts, min(upper, max_experts) + 1)
                for upper in step1_experts
            ])

        step2_experts = _sample_until_close(
            draw_second_step, target_experts, tolerance,
            num_layers * min_experts, int(np.sum(np.minimum(step1_experts, max_experts))),
        )

    if optimize_order in _PRUNE_LED_ORDERS:
        return step1_experts, step2_experts
    return step2_experts, step1_experts


def _staircase_allocation(adapt_mode, optimize_order, num_layers, num_experts, min_experts,
                          pruning_ratio, merging_ratio, rng_seed):
    """Build a sorted per-layer allocation ('up' ascending, 'down' descending)."""
    max_experts = num_experts
    if optimize_order == "merge_only":
        ratio = _require_ratio("merging_ratio", merging_ratio, adapt_mode, optimize_order)
    else:
        ratio = _require_ratio("pruning_ratio", pruning_ratio, adapt_mode, optimize_order)
    target_total_experts = num_layers * num_experts * ratio

    step_experts = gen_smooth_staircase(
        num_layers=num_layers,
        min_experts=min_experts,
        max_experts=max_experts,
        target_total=target_total_experts,
        adapt_mode=adapt_mode,
        base_noise=0.25,
        max_step=2,
        rng_seed=rng_seed,
    )
    step_experts = np.array(sorted(step_experts, reverse=(adapt_mode == "down")))

    # Both stages receive the same allocation in these modes.
    if optimize_order == "merge_only":
        merge_experts = step_experts
        prune_experts = step_experts.copy()
    else:
        prune_experts = step_experts
        merge_experts = step_experts.copy()

    logger.info(f"Up allocation generated experts configuration with sum: {sum(step_experts)}")
    logger.info(f"Up/Down mode prune experts: {prune_experts}")
    logger.info(f"Up/Down mode merge experts: {merge_experts}")

    # Sanity checks on the generated schedule (internal invariant, not user input).
    assert all(min_experts <= e <= max_experts for e in prune_experts), "Prune experts configuration violates min/max constraints"
    assert all(min_experts <= e <= max_experts for e in merge_experts), "Merge experts configuration violates min/max constraints"
    return prune_experts, merge_experts


def _uniform_allocation(optimize_order, num_layers, uniform_prune_experts, uniform_merge_experts):
    """Give every layer the same expert count for each stage."""
    if optimize_order in _PRUNE_LED_ORDERS:
        prune_experts = np.full(num_layers, uniform_prune_experts, dtype=np.int32)
        logger.info(f"Using uniform pruning with {uniform_prune_experts} experts per layer")
        if optimize_order == "prune_first" and uniform_merge_experts is not None:
            merged_count = min(uniform_prune_experts, uniform_merge_experts)
            merge_experts = np.full(num_layers, merged_count, dtype=np.int32)
            logger.info(f"Using uniform merging with {merged_count} experts per layer")
        else:
            merge_experts = prune_experts.copy()
            if optimize_order == "prune_first":
                logger.info(f"No uniform_merge_experts specified, using prune_experts value")
        return prune_experts, merge_experts

    if uniform_merge_experts is None:
        raise ValueError("uniform_merge_experts must be specified for merge_first or merge_only in uniform mode")
    merge_experts = np.full(num_layers, uniform_merge_experts, dtype=np.int32)
    logger.info(f"Using uniform merging with {uniform_merge_experts} experts per layer")
    if optimize_order == "merge_first" and uniform_prune_experts is not None:
        pruned_count = min(uniform_merge_experts, uniform_prune_experts)
        prune_experts = np.full(num_layers, pruned_count, dtype=np.int32)
        logger.info(f"Using uniform pruning with {pruned_count} experts per layer")
    else:
        prune_experts = merge_experts.copy()
        if optimize_order == "merge_first":
            logger.info(f"No uniform_prune_experts specified, using merge_experts value")
    return prune_experts, merge_experts


def adapt_qwen(
    original_model_name: str = "/Path/Qwen1.5-MoE-A2.7B-Chat",
    
    
    adapt_mode: str = "adapt",  
    preset_id: Optional[int] = None,  
    
    
    uniform_prune_experts: Optional[int] = None,
    uniform_merge_experts: Optional[int] = None,
    min_margin: Optional[int] = 2,  
    max_neighbor_diff: Optional[int] = None,  
    
    
    customized_prune_experts: Optional[List[int]] = None,
    customized_merge_experts: Optional[List[int]] = None,
    
    
    prune_method: Optional[str] = 'weight-outlier', 
    pruning_ratio: Optional[float] = None, 
    
    
    merge_method: Optional[str] = "output-cosine",  
    stat_type: Optional[int] = 3,  
    merging_ratio: Optional[float] = None,  
    
    
    objective_type: Optional[str] = "square",  
    optimize_order: Optional[str] = "prune_first",  

    
    dominant: Optional[str] = "no", 
    group: Optional[str] = "expert", 
    similarity_base: Optional[str] = "expert-output", 
    merge: Optional[str] = "fix-dom-same", 
    mode: Optional[str] = "normal", 
    cluster: Optional[str] = "hierarchical", 
    linkage: Optional[str] = "single", 
    hierarchical_stopping_metric: Optional[str] = "silhouette", 
    ingredient: Optional[str] = "act", 
    overlap_metric: Optional[str] = "cosine", 

    
    prune: Optional[str] = "layerwise_pruning_qwen", 
    
    
    evaluation_mode: Optional[str] = "accuracy", 
    calib_set: str = "c4",
    max_block_size: int = 2048,
    n_sentences: int = 16,
    train_batch_size: int = 16,
    eval_batch_size: int = 8,
    num_fewshot: int = 0,
    seed: int = 42,
    rng_seed: Optional[int] = 42,
    eval_task: Union[List[str], str, None] = ["wikitext", "arc_challenge", "arc_easy", "boolq", "hellaswag", "mmlu", "openbookqa", "rte", "winogrande"],
    
    
    skip_metrics_calculation: bool = False,  
    metrics: Union[list[str], str] = ['all'],  
    override: bool = False,  
    
    
    save_checkpoint: bool = False,  
    metrics_save_dir: Optional[str] = "./output/metrics/qwen",
    opt_save_dir: Optional[str] = "./output/opt/qwen",
    result_output_dir: Optional[str] = "./output/results/qwen",
    model_output_dir: Optional[str] = "./output/model/qwen",
):
    """Compress a Qwen MoE model by pruning and/or merging experts, then evaluate it.

    The function (1) validates its arguments, (2) loads the model and
    tokenizer, (3) obtains per-layer expert counts for the prune and merge
    stages -- computed according to ``adapt_mode`` and saved to a config file,
    or re-read from that file when it already exists and ``override`` is
    false -- (4) runs the stages in the order given by ``optimize_order`` and
    (5) optionally evaluates and saves the result.

    NOTE(review): the config file name is built from the model name,
    ``prune_method``, ``pruning_ratio``, ``merge_method``, ``merging_ratio``
    and ``stat_type`` (or the preset id). It does not include ``adapt_mode``
    or the uniform/customized counts, so runs differing only in those reuse
    the same cached file unless ``override=True``.

    Args:
        original_model_name: Passed to ``from_pretrained`` for model and
            tokenizer; its last path component names the output files.
        adapt_mode: One of 'adapt', 'uniform', 'owl', 'random', 'up', 'down',
            'customize'. 'owl' has no allocation implemented.
        preset_id: Key into ``PRESET_RATIOS``. In 'adapt' mode it overrides
            ``pruning_ratio`` and ``merging_ratio``; in other modes it is
            ignored with a warning.
        uniform_prune_experts, uniform_merge_experts: Per-layer counts used in
            'uniform' mode.
        min_margin, max_neighbor_diff, objective_type, skip_metrics_calculation,
        metrics: Forwarded to ``adapt_compress_qwen`` in 'adapt' mode
            (``metrics`` is split on commas when given as a string).
        customized_prune_experts, customized_merge_experts: Per-layer lists
            used in 'customize' mode; length must equal the model's
            ``num_hidden_layers`` and each value must lie in
            ``[2, num_experts]``.
        prune_method, merge_method, stat_type: Forwarded to
            ``adapt_compress_qwen`` and used in the output file name.
        pruning_ratio, merging_ratio: Forwarded to ``adapt_compress_qwen``; in
            'random', 'up' and 'down' modes the target total expert count is
            the available total multiplied by the ratio.
        optimize_order: 'prune_first', 'merge_first', 'prune_only' or
            'merge_only'.
        dominant, group, similarity_base, merge, mode, cluster, linkage,
        hierarchical_stopping_metric, ingredient, overlap_metric: Forwarded to
            ``call_hcmoe``.
        prune: Forwarded to ``call_expert_sparsity`` as the pruning method.
        evaluation_mode: 'accuracy' runs ``evaluate_fewshot`` and
            ``ppl_eval``; 'speed' currently does nothing; any other value
            skips evaluation.
        calib_set, max_block_size, n_sentences, train_batch_size,
        eval_batch_size, num_fewshot: Forwarded to the stage and evaluation
            helpers.
        seed: Passed to ``set_project_seed``, used to seed NumPy in 'random'
            mode and forwarded to both stages.
        rng_seed: Forwarded to ``gen_smooth_staircase`` in 'up'/'down' modes.
        eval_task: Task name(s) for evaluation, normalised with
            ``eval_task_to_list``.
        override: Recompute and overwrite the expert-count config even when
            the file exists.
        save_checkpoint: Save the final model and tokenizer to the returned
            path.
        metrics_save_dir, opt_save_dir, result_output_dir, model_output_dir:
            Output directories; created when missing.

    Returns:
        The path associated with the final model. The model is only written
        there by this function when ``save_checkpoint`` is true.

    Raises:
        ValueError: On an unsupported ``adapt_mode``/``optimize_order``, an
            unknown ``preset_id``, missing or malformed uniform/customized
            counts, a missing ratio, or an unreachable random target.
        NotImplementedError: When ``adapt_mode='owl'`` needs to compute an
            allocation.
    """
    # --- Fail fast: validate everything that does not need the model before
    # --- touching the filesystem or loading weights.
    if adapt_mode not in _ADAPT_MODES:
        raise ValueError(f"Unsupported adapt_mode: {adapt_mode}. Must be 'adapt', 'uniform', 'owl', 'random', 'up', 'down', or 'customize'.")
    if optimize_order not in _OPTIMIZE_ORDERS:
        raise ValueError(f"Unsupported optimize_order: {optimize_order}. Must be 'prune_first', 'merge_first', 'prune_only', or 'merge_only'.")
    if preset_id is not None and preset_id not in PRESET_RATIOS:
        raise ValueError(f"Invalid preset_id: {preset_id}. Must be in range [0-9].")

    if adapt_mode == "uniform":
        if uniform_prune_experts is None and optimize_order in _PRUNE_LED_ORDERS:
            raise ValueError("In uniform mode with prune_first or prune_only, uniform_prune_experts must be specified.")
        if uniform_merge_experts is None and optimize_order in _MERGE_LED_ORDERS:
            raise ValueError("In uniform mode with merge_first or merge_only, uniform_merge_experts must be specified.")

    if adapt_mode == "customize":
        if customized_prune_experts is None and optimize_order in _PRUNE_LED_ORDERS:
            raise ValueError("In customize mode with prune_first or prune_only, customized_prune_experts must be specified.")
        if customized_merge_experts is None and optimize_order in _MERGE_LED_ORDERS:
            raise ValueError("In customize mode with merge_first or merge_only, customized_merge_experts must be specified.")

    os.makedirs(result_output_dir, exist_ok=True)
    os.makedirs(opt_save_dir, exist_ok=True)
    os.makedirs(model_output_dir, exist_ok=True)
    os.makedirs(os.path.join(opt_save_dir, optimize_order), exist_ok=True)
    set_project_seed(seed=seed)

    # --- Load the model and tokenizer.
    logger.info(f"Loading model from {original_model_name}")
    model = AutoModelForCausalLM.from_pretrained(
        original_model_name,
        torch_dtype=torch.float16,
        device_map="auto",
        attn_implementation="flash_attention_2"
    )
    num_layers = model.config.num_hidden_layers
    num_experts = model.config.num_experts
    model.config.layer_num_experts = [num_experts for _ in range(num_layers)]
    tokenizer = AutoTokenizer.from_pretrained(original_model_name)
    tokenizer.pad_token_id = tokenizer.eos_token_id
    # rstrip so that a trailing slash does not yield an empty model name.
    model_name = original_model_name.rstrip("/").split("/")[-1]

    # Parameter count before compression, for the final compression-rate log.
    original_model_params = model.num_parameters()

    # --- Output naming.
    if preset_id is not None and adapt_mode == "adapt":
        model_path_suffix = f"{model_name}_preset{preset_id}"
    elif stat_type is not None:
        model_path_suffix = f"{model_name}_{prune_method}_{pruning_ratio}_{STAT_TYPE_DICT[stat_type]}-{merge_method}_{merging_ratio}"
    else:
        model_path_suffix = f"{model_name}_{prune_method}_{pruning_ratio}_{merge_method}_{merging_ratio}"
    opt_config_path = os.path.join(opt_save_dir, f"{optimize_order}/{model_path_suffix}.yaml")
    os.makedirs(os.path.join(result_output_dir, f"{optimize_order}"), exist_ok=True)

    # --- Apply the preset (adapt mode only).
    if preset_id is not None:
        if adapt_mode == "adapt":
            preset = PRESET_RATIOS[preset_id]
            pruning_ratio = preset["pruning_ratio"]
            merging_ratio = preset["merging_ratio"]
            logger.info(f"Using preset {preset_id}: pruning_ratio={pruning_ratio:.4f}, merging_ratio={merging_ratio:.4f}")
            logger.info(f"Combined effect (product): {pruning_ratio * merging_ratio:.4f}")
        else:
            logger.warning(f"preset_id is only effective when adapt_mode='adapt'. Ignoring preset_id={preset_id}.")

    # --- Model-dependent validation of customized lists: the expected length
    # --- is the model's real layer count, not a fixed constant.
    if adapt_mode == "customize":
        _validate_customized_experts("customized_prune_experts", customized_prune_experts, num_layers, num_experts)
        _validate_customized_experts("customized_merge_experts", customized_merge_experts, num_layers, num_experts)

    # --- Compute (or reuse) the per-layer expert counts.
    if not os.path.exists(opt_config_path) or override:
        logger.info("Computing metrics and optimizing expert counts...")

        os.makedirs(metrics_save_dir, exist_ok=True)

        if adapt_mode == "adapt":
            logger.info("Using adaptive expert allocation mode")
            results = adapt_compress_qwen(
                model_name=original_model_name,
                calib_set=calib_set,
                opt_batch_size=train_batch_size,
                n_sentences=n_sentences,
                skip_metrics_calculation=skip_metrics_calculation,
                metrics=metrics if isinstance(metrics, list) else metrics.split(","),
                min_margin=min_margin,
                max_neighbor_diff=max_neighbor_diff,
                prune_method=prune_method,
                pruning_ratio=pruning_ratio,
                merge_method=merge_method,
                stat_type=stat_type,
                merging_ratio=merging_ratio,
                optimize_order=optimize_order,
                objective_type=objective_type,
                save_opt=False,  # the config is saved below, once, for every mode
                metric_save_dir=metrics_save_dir,
                opt_save_dir=opt_save_dir,
                override=override,
                model=model,
                tokenizer=tokenizer
            )
            prune_experts = results["prune_experts"]
            merge_experts = results["merge_experts"]
        elif adapt_mode == "customize":
            logger.info("Using customize expert allocation mode")
            prune_experts, merge_experts = _customized_allocation(
                optimize_order, customized_prune_experts, customized_merge_experts
            )
            logger.info(f"Customized prune experts: {prune_experts}")
            logger.info(f"Customized merge experts: {merge_experts}")
        elif adapt_mode == "owl":
            logger.info("Using OWL expert allocation mode")
            # No allocation is implemented for this mode; fail explicitly
            # instead of crashing later on undefined expert counts.
            raise NotImplementedError(
                "adapt_mode='owl' does not implement expert allocation; "
                f"provide an existing configuration at {opt_config_path} or choose another mode."
            )
        elif adapt_mode == "random":
            logger.info("Using random expert allocation mode")
            prune_experts, merge_experts = _random_allocation(
                optimize_order=optimize_order,
                num_layers=num_layers,
                num_experts=num_experts,
                min_experts=model.config.num_experts_per_tok,
                pruning_ratio=pruning_ratio,
                merging_ratio=merging_ratio,
                seed=seed,
            )
            logger.info(f"Random prune experts: {prune_experts}")
            logger.info(f"Random merge experts: {merge_experts}")
        elif adapt_mode in ("up", "down"):
            logger.info("Using up/down expert allocation mode")
            prune_experts, merge_experts = _staircase_allocation(
                adapt_mode=adapt_mode,
                optimize_order=optimize_order,
                num_layers=num_layers,
                num_experts=num_experts,
                min_experts=model.config.num_experts_per_tok,
                pruning_ratio=pruning_ratio,
                merging_ratio=merging_ratio,
                rng_seed=rng_seed,
            )
        else:
            logger.info("Using uniform expert allocation mode")
            prune_experts, merge_experts = _uniform_allocation(
                optimize_order, num_layers, uniform_prune_experts, uniform_merge_experts
            )

        # NumPy arrays are converted to plain lists before serialisation.
        config = {
            "prune_experts": prune_experts.tolist() if isinstance(prune_experts, np.ndarray) else prune_experts,
            "merge_experts": merge_experts.tolist() if isinstance(merge_experts, np.ndarray) else merge_experts
        }
        save_opt_config(config, opt_config_path)
        logger.info(f"Saved expert configuration to {opt_config_path}")

    # Always read the counts back from disk so cached and fresh runs share one path.
    opt_config = load_opt_config(opt_config_path)
    prune_experts_per_layer = opt_config["prune_experts"]
    merge_experts_per_layer = opt_config["merge_experts"]

    logger.info(f"Pruning experts per layer: {prune_experts_per_layer}")
    logger.info(f"Merging experts per layer: {merge_experts_per_layer}")

    # --- Run the compression stages.
    run_model_dir = os.path.join(model_output_dir, f"{optimize_order}/{model_path_suffix}")
    result_path = os.path.join(result_output_dir, f"{optimize_order}/{model_path_suffix}")

    def run_prune(source_path, source_model):
        """Prune ``source_model`` with the per-layer counts from the config."""
        return call_expert_sparsity(
            model_path=source_path,
            output_path=run_model_dir,
            max_block_size=max_block_size,
            calib_set=calib_set,
            experts_per_layer=prune_experts_per_layer,
            prune=prune,
            n_blocks_for_stat=n_sentences,
            batch_size=train_batch_size,
            seed=seed,
            eval_tasks=eval_task_to_list(eval_task),
            eval_path=result_path,
            config_path=opt_config_path,
            model=source_model,
            tokenizer=tokenizer,
            result_save_path=result_path,
            opt_save_dir=os.path.join(opt_save_dir, optimize_order),
            model_output_dir=run_model_dir
        )

    def run_merge(source_path, source_model, task):
        """Merge experts of ``source_model``; ``seed`` is forwarded for every order."""
        return call_hcmoe(
            model_path=source_path,
            output_path=run_model_dir,
            task=task,
            num_average_groups=merge_experts_per_layer,
            dominant=dominant,
            group=group,
            similarity_base=similarity_base,
            merge=merge,
            mode=mode,
            n_sentences=n_sentences,
            max_block_size=max_block_size,
            train_batch_size=train_batch_size,
            eval_batch_size=eval_batch_size,
            seed=seed,
            num_fewshot=num_fewshot,
            dynamic_group=True,
            opt_path=opt_config_path,
            model=source_model,
            tokenizer=tokenizer,
            cluster=cluster,
            linkage=linkage,
            hierarchical_stopping_metric=hierarchical_stopping_metric,
            ingredient=ingredient,
            overlap_metric=overlap_metric,
            calib_set=calib_set
        )

    # NOTE(review): the prune_first order passes only the first task to the
    # merge stage while the merge_first/merge_only orders pass the full task
    # list. Kept as-is; confirm which form run_qwen_tamp expects.
    if optimize_order == "prune_first":
        logger.info("Execute order: Prune first")
        logger.info("Step 1: Pruning experts using Expert Sparsity method")
        pruned_model_path, pruned_model = run_prune(original_model_name, model)
        logger.info("Step 2: Merging experts using HCMOE method")
        final_model_path, final_model = run_merge(
            pruned_model_path,
            pruned_model,
            task=eval_task[0] if isinstance(eval_task, list) else eval_task,
        )
    elif optimize_order == "merge_first":
        logger.info("Execute order: Merge first")
        logger.info("Step 1: Merging experts using HCMOE method")
        merged_model_path, merged_model = run_merge(
            original_model_name, model, task=eval_task_to_list(eval_task)
        )
        logger.info("Step 2: Pruning experts using Expert Sparsity method")
        final_model_path, final_model = run_prune(merged_model_path, merged_model)
    elif optimize_order == "prune_only":
        logger.info("Execute order: Prune only")
        logger.info("Step 1: Pruning experts using Expert Sparsity method")
        final_model_path, final_model = run_prune(original_model_name, model)
    else:  # "merge_only" -- optimize_order was validated on entry
        logger.info("Execute order: Merge only")
        logger.info("Step 1: Merging experts using HCMOE method")
        final_model_path, final_model = run_merge(
            original_model_name, model, task=eval_task_to_list(eval_task)
        )

    logger.info(f"Final adapted model saved at: {final_model_path}")
    logger.info("Evaluating final model...")

    # Fallback for a stage that returned only a path and no in-memory model.
    if final_model is None:
        final_model = Qwen2MoeForCausalLM.from_pretrained(
            final_model_path,
            torch_dtype=torch.bfloat16,
            device_map="auto"
        )

    # --- Evaluation.
    result_path = os.path.join(result_output_dir, f"{optimize_order}/{model_path_suffix}_final_eval_results.txt")
    if evaluation_mode == "accuracy":
        evaluate_fewshot(
            model=final_model,
            tokenizer=tokenizer,
            task=eval_task_to_list(eval_task),
            num_fewshot=num_fewshot,
            eval_batch_size=eval_batch_size,
            output_path=result_path,
            log=True
        )
        ppls = ppl_eval(final_model, tokenizer, datasets=['wikitext2'], batch_size=eval_batch_size)
        # Append perplexity and the final layout to the few-shot results file.
        with open(result_path, 'a') as f:
            f.write(f"\nPPL: {ppls}\n")
            f.write(f"\nFinal Experts: {final_model.config.layer_num_experts}\n")
    elif evaluation_mode == "speed":
        # Placeholder: no speed benchmark is implemented here.
        pass
    else:
        logger.info("Do not carry out evaluation.")

    logger.info(f"Final experts: {final_model.config.layer_num_experts}")
    logger.info(f"Final compression rate: {(original_model_params - final_model.num_parameters()) / original_model_params}")

    if save_checkpoint:
        final_model.save_pretrained(final_model_path, safe_serialization=False)
        tokenizer.save_pretrained(final_model_path, safe_serialization=False)
        logger.info(f"Final model saved at: {final_model_path}")

    return final_model_path

if __name__ == "__main__":
    # Configure root logging before anything is emitted.
    logging.basicConfig(
        level=logging.INFO,
        datefmt="%Y/%m/%d %H:%M:%S",
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
    )

    # List the available presets so the user can pick a preset_id.
    logger.info("Available preset ratios (preset_id: pruning_ratio, merging_ratio):")
    for preset_key in PRESET_RATIOS:
        prune_value = PRESET_RATIOS[preset_key]["pruning_ratio"]
        merge_value = PRESET_RATIOS[preset_key]["merging_ratio"]
        logger.info(f"  {preset_key}: {prune_value:.4f}, {merge_value:.4f} (product: {prune_value * merge_value:.4f})")

    # Expose adapt_qwen's keyword arguments as command-line flags.
    Fire(adapt_qwen)
