import os
import numpy as np

from typing import Optional, Dict, List, Union, Tuple
import logging
import torch
import time
import json
from argparse import Namespace
from fire import Fire
from transformers import MixtralForCausalLM, AutoTokenizer
from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock

from eval import get_calib_dataloder, evaluate_fewshot
from merge_method import (
    run_mixtral_tamp,
    adapt_compress_mixtral,
    dump_mixtral_moe_metrics,
    STAT_TYPE_DICT
)
from utils import (
    eval_task_to_list,
    load_opt_config,
    save_opt_config,
    set_project_seed,
    gen_smooth_staircase
)

from extend import evaluation, ppl_eval

from prune_method import adapt_prune

logger = logging.getLogger(__name__)


# (pruning_ratio, merging_ratio) pairs selectable via preset_id when
# adapt_mode == "adapt"; the preset_id is the position in this tuple.
# Presets 0-9 are tuned so that pruning_ratio * merging_ratio ~= 0.5 and
# presets 10-19 so that the product ~= 0.75.
# NOTE(review): preset 5 (0.75 * 0.67 = 0.5025) is the one pair that does not
# land on its target product as closely as its neighbours — confirm intended.
_PRESET_RATIO_PAIRS = (
    (0.5454545454545454, 0.9166666666666667),
    (0.5909090909090909, 0.8461538461538461),
    (0.6363636363636364, 0.7857142857142857),
    (0.6818181818181819, 0.7333333333333333),
    (0.7272727272727273, 0.6875),
    (0.75, 0.67),
    (0.8181818181818181, 0.6111111111111112),
    (0.8636363636363636, 0.5789473684210527),
    (0.9090909090909092, 0.5499999999999999),
    (0.9545454545454546, 0.5238095238095238),
    (0.76, 0.986842),
    (0.8, 0.937500),
    (0.82, 0.914634),
    (0.85, 0.882353),
    (0.875, 0.857),
    (0.90, 0.833333),
    (0.92, 0.815217),
    (0.95, 0.789474),
    (0.97, 0.773196),
    (0.98, 0.765306),
)

# preset_id -> {"pruning_ratio": ..., "merging_ratio": ...}
PRESET_RATIOS = {
    preset_idx: {"pruning_ratio": prune_r, "merging_ratio": merge_r}
    for preset_idx, (prune_r, merge_r) in enumerate(_PRESET_RATIO_PAIRS)
}

def call_expert_sparsity(
    model_path: str,
    output_path: str,
    max_block_size: int = 2048,
    calib_set: str = "c4",
    prune: str = "layerwise_pruning",
    experts_per_layer: Optional[List[int]] = None,
    n_blocks_for_stat: int = 32,
    batch_size: int = 16,
    seed: int = 42,
    eval_tasks: Optional[List[str]] = None,
    eval_path: Optional[str] = None,
    config_path: Optional[str] = None,
    model = None,
    tokenizer = None,
    result_save_path: Optional[str] = None,
    opt_save_dir: Optional[str] = None,
    model_output_dir: Optional[str] = None,
) -> Tuple[str, MixtralForCausalLM]:
    """Prune Mixtral experts by delegating to ``adapt_prune`` (Expert Sparsity).

    ``output_path`` is created if missing, an ``argparse.Namespace`` is built
    from the arguments (plus a few fixed settings), ``adapt_prune(args)`` is
    called, and the resulting model's ``config.layer_num_experts`` is set to the
    per-layer expert counts.

    Args:
        model_path: Model name/path, forwarded as ``args.model_path``.
        output_path: Output directory (created if missing); also forwarded.
        max_block_size: Forwarded as ``args.max_block_size``.
        calib_set: Calibration dataset name, forwarded as ``args.calib_set``.
        prune: Pruning method name, forwarded as ``args.method``.
        experts_per_layer: Per-layer expert counts. Not forwarded to
            ``adapt_prune``; used for ``config.layer_num_experts`` only when
            ``config_path`` is None.
        n_blocks_for_stat: Forwarded as ``args.n_blocks_for_stat``.
        batch_size: Forwarded as ``args.batch_size``.
        seed: Forwarded as ``args.seed``.
        eval_tasks: Accepted for interface compatibility; currently unused.
        eval_path: Accepted for interface compatibility; currently unused.
        config_path: Expert-count YAML path, forwarded as ``args.config_path``
            (with ``args.expert_config="prune_experts"``); its
            ``"prune_experts"`` entry is written to ``config.layer_num_experts``.
        model: Optional preloaded model, forwarded as ``args.provided_model``.
        tokenizer: Optional preloaded tokenizer, forwarded as
            ``args.provided_tokenizer``.
        result_save_path: Forwarded as ``args.result_save_path``.
        opt_save_dir: Forwarded as ``args.opt_save_dir``.
        model_output_dir: Forwarded as ``args.model_output_dir``.

    Returns:
        ``(os.path.join(output_path, "pruned_model"), pruned_model)``.
        NOTE(review): nothing in this function writes to that path; whether
        ``adapt_prune`` saves there is not visible here — confirm.
    """
    os.makedirs(output_path, exist_ok=True)

    # adapt_prune takes an argparse-style namespace; options this wrapper does
    # not expose are pinned to fixed values here.
    args = Namespace(
        method=prune,
        r=None,
        calib_set=calib_set,
        model_path=model_path,
        output_path=output_path,
        max_block_size=max_block_size,
        n_blocks_for_stat=n_blocks_for_stat,
        batch_size=batch_size,
        num_workers=8,
        seed=seed,
        use_flash_attention_2=False,
        layer_index=None,
        evaluate_model=False,
        config_path=config_path,
        expert_config="prune_experts",
        result_save_path=result_save_path,
        opt_save_dir=opt_save_dir,
        model_output_dir=model_output_dir,
        # Reuse already-loaded objects instead of reloading from model_path.
        provided_model=model,
        provided_tokenizer=tokenizer,
    )

    logger.info(f"Executing Expert Sparsity with method: {prune}")
    pruned_model = adapt_prune(args)

    # Record per-layer expert counts on the config. The YAML handed to
    # adapt_prune stays the source of truth; the explicit list is only a
    # fallback so a missing config_path does not reach load_opt_config(None).
    if config_path is None and experts_per_layer is not None:
        layer_num_experts = list(experts_per_layer)
    else:
        layer_num_experts = load_opt_config(config_path)["prune_experts"]
    pruned_model.config.layer_num_experts = layer_num_experts

    pruned_model_path = os.path.join(output_path, "pruned_model")
    return pruned_model_path, pruned_model

def call_hcmoe(
    model_path: str,
    output_path: str,
    task: str = "mmlu",
    num_average_groups: Optional[List[int]] = None,
    dominant: str = "no",
    group: str = "expert",
    similarity_base: str = "router-logits",
    merge: str = "norm_drop_fre",
    mode: str = "normal",
    n_sentences: int = 16,
    max_block_size: int = 2048,
    train_batch_size: int = 16,
    eval_batch_size: int = 32,
    num_fewshot: int = 0,
    dynamic_group: bool = False,
    opt_path: Optional[str] = None,
    model = None,
    tokenizer = None,
    cluster: Optional[str] = "Graph_Partitioning",
    linkage: Optional[str] = "single",
    hierarchical_stopping_metric: Optional[str] = "silhouette",
    ingredient: Optional[str] = "act",
    overlap_metric: Optional[str] = "cosine",
    calib_set: str = "c4",
    seed: int = 42,
) -> Tuple[str, MixtralForCausalLM]:
    """Merge Mixtral experts (HCMOE step) by delegating to ``run_mixtral_tamp``.

    ``output_path`` is created if missing, ``run_mixtral_tamp`` is called with
    the given options plus several fixed ones (``partition=1``,
    ``start_layer=0``, ``group_limit=4``, ``data_limit=50000``,
    ``try_oracle=False``, ``random_start_center=False``, ``weight=None``,
    ``model_path=None``), and the merged model's ``config.layer_num_experts``
    is set to the per-layer expert counts.

    Args:
        model_path: Model name/path, forwarded as ``model_name``.
        output_path: Output directory (created if missing); also returned.
        task: Forwarded as ``task``. NOTE(review): annotated ``str`` but at
            least one caller in this file passes a list — confirm what
            ``run_mixtral_tamp`` expects.
        num_average_groups: Per-layer target group counts, forwarded as-is;
            also used for ``config.layer_num_experts`` when ``opt_path`` is None.
        opt_path: Expert-count YAML path, forwarded as ``opt_path``; its
            ``"merge_experts"`` entry is written to ``config.layer_num_experts``.
        model: Optional preloaded model, forwarded as ``model``.
        tokenizer: Optional preloaded tokenizer, forwarded as ``tokenizer``.
        Remaining arguments are forwarded unchanged under the same names.

    Returns:
        ``(output_path, merged_model)``.
    """
    os.makedirs(output_path, exist_ok=True)

    # Path handed to run_mixtral_tamp as result_path.
    result_path = os.path.join(output_path, "eval_results.json")

    logger.info(f"Executing HCMOE with similarity_base: {similarity_base}, merge: {merge}, mode: {mode}")

    merged_model = run_mixtral_tamp(
        task=task,
        num_average_groups=num_average_groups,
        model_name=model_path,
        dominant=dominant,
        group=group,
        similarity_base=similarity_base,
        merge=merge,
        mode=mode,
        n_sentences=n_sentences,
        max_block_size=max_block_size,
        calib_set=calib_set,
        train_batch_size=train_batch_size,
        eval_batch_size=eval_batch_size,
        seed=seed,
        partition=1,
        start_layer=0,
        output_path=output_path,
        result_path=result_path,
        model_path=None,
        group_limit=4,
        data_limit=50000,
        num_fewshot=num_fewshot,
        try_oracle=False,
        random_start_center=False,
        weight=None,
        cluster=cluster,
        linkage=linkage,
        hierarchical_stopping_metric=hierarchical_stopping_metric,
        ingredient=ingredient,
        overlap_metric=overlap_metric,
        dynamic_group=dynamic_group,
        opt_path=opt_path,
        # Reuse already-loaded objects instead of reloading from model_path.
        model=model,
        tokenizer=tokenizer,
    )

    # Record per-layer expert counts on the config. The YAML at opt_path stays
    # the source of truth; num_average_groups is only a fallback so a missing
    # opt_path does not reach load_opt_config(None).
    if opt_path is None and num_average_groups is not None:
        layer_num_experts = list(num_average_groups)
    else:
        layer_num_experts = load_opt_config(opt_path)["merge_experts"]
    merged_model.config.layer_num_experts = layer_num_experts

    # The merged model is identified by output_path itself (no sub-directory).
    merged_model_path = output_path
    return merged_model_path, merged_model

def adapt_mixtral(
    original_model_name: str = "/Path/Mixtral-8x7B-v0.1",
    
    
    adapt_mode: str = "adapt",  
    preset_id: Optional[int] = None,  
    
    
    uniform_prune_experts: Optional[int] = None,
    uniform_merge_experts: Optional[int] = None,
    min_margin: Optional[int] = 2,  
    max_neighbor_diff: Optional[int] = None,  
    
    
    customized_prune_experts: Optional[List[int]] = [7]*8 + [6]*8 + [5]*8 + [4]*8,
    customized_merge_experts: Optional[List[int]] = [5]*8 + [4]*8 + [3]*8 + [2]*8,
    
    
    prune_method: Optional[str] = None, 
    pruning_ratio: Optional[float] = None,
    sigmoid_t: Optional[float] = 0.001,

    
    merge_method: Optional[str] = None, 
    stat_type: Optional[int] = 2, 
    merging_ratio: Optional[float] = None, 
    
    
    objective_type: Optional[str] = "linear", 
    optimize_order: Optional[str] = "prune_first", 

    
    dominant: Optional[str] = "no",
    group: Optional[str] = "expert",
    similarity_base: Optional[str] = "expert-output",
    merge: Optional[str] = "norm_drop_fre",
    mode: Optional[str] = "normal", 
    cluster: Optional[str] = "Graph_Partitioning", 
    linkage: Optional[str] = "average", 
    hierarchical_stopping_metric: Optional[str] = "silhouette", 
    ingredient: Optional[str] = "act", 
    overlap_metric: Optional[str] = "cosine", 

    
    prune: Optional[str] = "layerwise_pruning", 
    
    
    evaluation_mode: Optional[str] = "accuracy", 
    calib_set: str = "wikitext",
    max_block_size: int = 2048,
    n_sentences: int = 16,
    train_batch_size: int = 16,
    eval_batch_size: int = 8,
    num_fewshot: int = 0,
    seed: int = 42,
    rng_seed: Optional[int] = 42,
    eval_task: Union[List[str], str, None] = ["wikitext", "arc_challenge", "arc_easy", "boolq", "hellaswag", "mmlu", "openbookqa", "rte", "winogrande"],
    
    
    skip_metrics_calculation: bool = False,  
    metrics: Union[list[str], str] = ['all'],  
    override: bool = False,  
    
    
    save_checkpoint: bool = False,  
    metrics_save_dir: Optional[str] = "./output/metrics/mixtral",
    opt_save_dir: Optional[str] = "./output/opt/mixtral",
    result_output_dir: Optional[str] = "./output/results/mixtral",
    model_output_dir: Optional[str] = "./output/model/mixtral",
    
    
    use_wandb: Optional[bool] = True,
    wandb_project: Optional[str] = "DM-MOE",
    wandb_entity: Optional[str] = None,
    wandb_run_name: Optional[str] = "Mixtral-8x7B-v0.1",
    wandb_tags: Optional[List[str]] = None,
    wandb_mode: Optional[str] = None,  
    wandb_resume: Optional[str] = None,  
):
    
    os.makedirs(result_output_dir, exist_ok=True)
    os.makedirs(opt_save_dir, exist_ok=True)
    os.makedirs(model_output_dir, exist_ok=True)
    os.makedirs(os.path.join(opt_save_dir, optimize_order), exist_ok=True)
    set_project_seed(seed=seed)

    
    if preset_id is not None:
        if preset_id not in PRESET_RATIOS:
            raise ValueError(f"Invalid preset_id: {preset_id}. Must be one of {list(PRESET_RATIOS.keys())}.")
        
        if adapt_mode == "adapt":
            
            preset = PRESET_RATIOS[preset_id]
            pruning_ratio = preset["pruning_ratio"]
            merging_ratio = preset["merging_ratio"]
            logger.info(f"Using preset {preset_id}: pruning_ratio={pruning_ratio:.4f}, merging_ratio={merging_ratio:.4f}")
            logger.info(f"Combined effect (product): {pruning_ratio * merging_ratio:.4f}")
        else:
            logger.warning(f"preset_id is only effective when adapt_mode='adapt'. Ignoring preset_id={preset_id}.")

    run = None
    if use_wandb:
        try:
            if wandb_mode in ["offline", "disabled"]:
                os.environ["WANDB_MODE"] = wandb_mode
            
            import wandb  
            init_kwargs = dict(
                project=wandb_project,
                name=wandb_run_name,
                tags=wandb_tags,
            )
            if wandb_entity:
                init_kwargs["entity"] = wandb_entity
            if wandb_resume:
                init_kwargs["resume"] = wandb_resume
            run = wandb.init(**init_kwargs)
            
            if run is not None:
                run.config.update({
                    "original_model_name": original_model_name,
                    "adapt_mode": adapt_mode,
                    "preset_id": preset_id,
                    "uniform_prune_experts": uniform_prune_experts,
                    "uniform_merge_experts": uniform_merge_experts,
                    "customized_prune_experts": customized_prune_experts,
                    "customized_merge_experts": customized_merge_experts,
                    "prune_method": prune_method,
                    "merge_method": merge_method,
                    "pruning_ratio": pruning_ratio,
                    "merging_ratio": merging_ratio,
                    "stat_type": stat_type,
                    "min_margin": min_margin,
                    "max_neighbor_diff": max_neighbor_diff,
                    "sigmoid_t": sigmoid_t,
                    "objective_type": objective_type,
                    "optimize_order": optimize_order,
                    "dominant": dominant,
                    "group": group,
                    "similarity_base": similarity_base,
                    "merge": merge,
                    "mode": mode,
                    "cluster": cluster,
                    "linkage": linkage,
                    "hierarchical_stopping_metric": hierarchical_stopping_metric,
                    "ingredient": ingredient,
                    "overlap_metric": overlap_metric,
                    "prune": prune,
                    "evaluation_mode": evaluation_mode,
                    "calib_set": calib_set,
                    "n_sentences": n_sentences,
                    "train_batch_size": train_batch_size,
                    "eval_batch_size": eval_batch_size,
                    "num_fewshot": num_fewshot,
                    "seed": seed,
                    "rng_seed": rng_seed,
                    "eval_task": eval_task,
                    "skip_metrics_calculation": skip_metrics_calculation,
                    "metrics": metrics,
                    "override": override,
                    "save_checkpoint": save_checkpoint,
                    "metrics_save_dir": metrics_save_dir,
                    "opt_save_dir": opt_save_dir,
                    "result_output_dir": result_output_dir,
                    "model_output_dir": model_output_dir,
                    "use_wandb": use_wandb,
                    "wandb_project": wandb_project,
                    "wandb_entity": wandb_entity,
                    "wandb_run_name": wandb_run_name,
                    "wandb_tags": wandb_tags,
                    "wandb_mode": wandb_mode,
                    "wandb_resume": wandb_resume,
                }, allow_val_change=True)
        except Exception as wandb_e:
            logger.warning(f"W&B initialization failed, continuing without W&B. Error: {wandb_e}")
            run = None
    
    
    logger.info(f"Loading model from {original_model_name}")
    model = MixtralForCausalLM.from_pretrained(
        original_model_name,
        torch_dtype=torch.bfloat16,
        
        device_map="auto",
    )
    model.config.layer_num_experts = [model.config.num_local_experts for _ in range(32)]
    tokenizer = AutoTokenizer.from_pretrained(original_model_name)
    tokenizer.pad_token_id = tokenizer.eos_token_id
    model_name = original_model_name.split("/")[-1]

    
    original_model_params = model.num_parameters()
    
    
    if preset_id is not None and adapt_mode == "adapt":
        
        model_path_suffix = f"{model_name}_preset{preset_id}"
    elif stat_type is not None:
        model_path_suffix = f"{model_name}_{prune_method}_{pruning_ratio}_{STAT_TYPE_DICT[stat_type]}-{merge_method}_{merging_ratio}"
    else:
        model_path_suffix = f"{model_name}_{prune_method}_{pruning_ratio}_{merge_method}_{merging_ratio}"
    opt_config_path = os.path.join(opt_save_dir, f"{optimize_order}/{model_path_suffix}.yaml")
    os.makedirs(os.path.join(result_output_dir, f"{optimize_order}"), exist_ok=True)
    if run is not None:
        run.config.update({
            "opt_config_path": opt_config_path,
            "metrics_save_dir": metrics_save_dir,
            "opt_save_dir": opt_save_dir,
            "result_output_dir": result_output_dir,
            "model_output_dir": model_output_dir,
        }, allow_val_change=True)

    
    if adapt_mode not in ["adapt", "uniform", "owl", "random", "up", "down", "customize"]:
        raise ValueError(f"Unsupported adapt_mode: {adapt_mode}. Must be 'adapt', 'uniform', 'owl', 'random', 'up', 'down', or 'customize'.")
    
    
    if adapt_mode == "uniform":
        if uniform_prune_experts is None and optimize_order in ["prune_first", "prune_only"]:
            raise ValueError("In uniform mode with prune_first or prune_only, uniform_prune_experts must be specified.")
        if uniform_merge_experts is None and optimize_order in ["merge_first", "merge_only"]:
            raise ValueError("In uniform mode with merge_first or merge_only, uniform_merge_experts must be specified.")
    
    
    if adapt_mode == "customize":
        
        if customized_prune_experts is None and optimize_order in ["prune_first", "prune_only"]:
            raise ValueError("In customize mode with prune_first or prune_only, customized_prune_experts must be specified.")
        if customized_merge_experts is None and optimize_order in ["merge_first", "merge_only"]:
            raise ValueError("In customize mode with merge_first or merge_only, customized_merge_experts must be specified.")
            
        
        if customized_prune_experts is not None and len(customized_prune_experts) != 32:
            raise ValueError(f"customized_prune_experts must have length 32, got {len(customized_prune_experts)}")
        if customized_merge_experts is not None and len(customized_merge_experts) != 32:
            raise ValueError(f"customized_merge_experts must have length 32, got {len(customized_merge_experts)}")
            
        
        if customized_prune_experts is not None:
            for i, num in enumerate(customized_prune_experts):
                if num < 2 or num > 8:
                    raise ValueError(f"customized_prune_experts[{i}] = {num} is out of range [2, 8]")
        if customized_merge_experts is not None:
            for i, num in enumerate(customized_merge_experts):
                if num < 2 or num > 8:
                    raise ValueError(f"customized_merge_experts[{i}] = {num} is out of range [2, 8]")
    
    
    if not os.path.exists(opt_config_path) or override:
        logger.info("Computing metrics and optimizing expert counts...")
        
        os.makedirs(metrics_save_dir, exist_ok=True)
        
        if adapt_mode == "adapt":
            
            logger.info("Using adaptive expert allocation mode")
            
            
            time_start = time.time()
            results = adapt_compress_mixtral(
                model_name=original_model_name,
                calib_set=calib_set,
                opt_batch_size=train_batch_size,
                n_sentences=n_sentences,
                skip_metrics_calculation=skip_metrics_calculation,
                metrics=metrics if isinstance(metrics, list) else metrics.split(","),
                min_margin=min_margin,
                max_neighbor_diff=max_neighbor_diff,
                prune_method=prune_method,
                pruning_ratio=pruning_ratio,
                sigmoid_t=sigmoid_t,
                merge_method=merge_method,
                stat_type=stat_type,
                merging_ratio=merging_ratio,
                optimize_order=optimize_order,
                objective_type=objective_type,
                save_opt=False,  
                metric_save_dir=metrics_save_dir,
                opt_save_dir=opt_save_dir,
                override=override,
                
                model=model,
                tokenizer=tokenizer
            )
            print(f"The Time of calculating expert counts: {time.time() - time_start}")
            
            prune_experts = results["prune_experts"]
            merge_experts = results["merge_experts"]
        elif adapt_mode == "customize":
            logger.info("Using customize expert allocation mode")
            
            
            if optimize_order in ["prune_first", "prune_only"]:
                prune_experts = customized_prune_experts
                
                if optimize_order == "prune_only" or customized_merge_experts is None:
                    merge_experts = prune_experts.copy()
                else:
                    merge_experts = customized_merge_experts
                    
                    merge_experts = [min(m, p) for m, p in zip(merge_experts, prune_experts)]
            else:  
                merge_experts = customized_merge_experts
                
                if optimize_order == "merge_only" or customized_prune_experts is None:
                    prune_experts = merge_experts.copy()
                else:
                    prune_experts = customized_prune_experts
                    
                    prune_experts = [min(p, m) for p, m in zip(prune_experts, merge_experts)]
                    
            logger.info(f"Customized prune experts: {prune_experts}")
            logger.info(f"Customized merge experts: {merge_experts}")
        elif adapt_mode == "owl":
            logger.info("Using OWL expert allocation mode")
            logger.warning("OWL expert allocation mode is not implemented for Mixtral-8x7B-v0.1")
            pass
        elif adapt_mode == "random":
            logger.info("Using random expert allocation mode")
            
            
            np.random.seed(seed)
            num_layers = 32
            original_experts_per_layer = 8
            min_experts = 2  
            max_experts = 8  
            
            if optimize_order in ["merge_only", "prune_only"]:
                
                total_experts = num_layers * original_experts_per_layer
                
                ratio = merging_ratio if optimize_order == "merge_only" else pruning_ratio
                target_total_experts = int(total_experts * ratio)
                logger.info(f"Target total experts after merging: {target_total_experts}")
                
                
                while True:
                    
                    random_experts = np.random.randint(min_experts, max_experts + 1, size=num_layers)
                    
                    current_total = np.sum(random_experts)
                    
                    
                    if abs(current_total - target_total_experts) <= num_layers // 2:
                        step2_experts = random_experts
                        
                        step1_experts = random_experts.copy()
                        break
            elif optimize_order in ["prune_first", "merge_first"]:
                
                total_experts = num_layers * original_experts_per_layer
                
                
                target_experts = int(total_experts * pruning_ratio) if optimize_order == "prune_first" else int(total_experts * merging_ratio)
                logger.info(f"Target total experts after step1: {target_experts}")
                
                
                while True:
                    random_experts = np.random.randint(min_experts, max_experts + 1, size=num_layers)
                    current_total = np.sum(random_experts)
                    
                    if abs(current_total - target_experts) <= num_layers // 2:
                        step1_experts = random_experts
                        break
                
                
                target_experts = int(np.sum(step1_experts) * merging_ratio) if optimize_order == "prune_first" else int(np.sum(step1_experts) * pruning_ratio)
                logger.info(f"Target total experts after step2: {target_experts}")
                
                
                while True:
                    random_merge_experts = np.array([np.random.randint(min_experts, min(step1_experts[i], max_experts) + 1) for i in range(num_layers)])
                    current_total = np.sum(random_merge_experts)
                    
                    if abs(current_total - target_experts) <= num_layers // 2:
                        step2_experts = random_merge_experts
                        break
                
            if optimize_order in ["prune_first", "prune_only"]:
                prune_experts = step1_experts
                merge_experts = step2_experts
            elif optimize_order in ["merge_first", "merge_only"]:
                prune_experts = step2_experts
                merge_experts = step1_experts
            
            
            logger.info(f"Random prune experts: {prune_experts}")
            logger.info(f"Random merge experts: {merge_experts}")
        elif adapt_mode in ["up", "down"]:
            logger.info("Using up/down expert allocation mode")
            
            num_layers = 32
            original_experts_per_layer = 8
            min_experts = 2  
            max_experts = 8  

            target_total_experts = num_layers * original_experts_per_layer * merging_ratio if optimize_order == "merge_only" else num_layers * original_experts_per_layer * pruning_ratio
            
            step_experts = gen_smooth_staircase(
                num_layers   = num_layers,
                min_experts  = min_experts,
                max_experts  = max_experts,
                target_total = target_total_experts,
                adapt_mode   = adapt_mode,
                base_noise   = 0.25,   
                max_step     = 2,      
                rng_seed     = rng_seed    
            )
            step_experts = np.array(sorted(step_experts, reverse=True if adapt_mode == "down" else False))
            
            
            if optimize_order == "merge_only":
                merge_experts = step_experts
                prune_experts = step_experts.copy()  
            else:  
                prune_experts = step_experts
                merge_experts = step_experts.copy()  
            
            logger.info(f"Up allocation generated experts configuration with sum: {sum(step_experts)}")
                    

            
            logger.info(f"Up/Down mode prune experts: {prune_experts}")
            logger.info(f"Up/Down mode merge experts: {merge_experts}")
            
            
            assert all(min_experts <= e <= max_experts for e in prune_experts), "Prune experts configuration violates min/max constraints"
            assert all(min_experts <= e <= max_experts for e in merge_experts), "Merge experts configuration violates min/max constraints"
        else:
            
            logger.info("Using uniform expert allocation mode")
            
            num_layers = 32
            
            if optimize_order in ["prune_first", "prune_only"]:
                
                prune_experts = np.full(num_layers, uniform_prune_experts, dtype=np.int32)
                logger.info(f"Using uniform pruning with {uniform_prune_experts} experts per layer")
                
                if optimize_order == "prune_first" and uniform_merge_experts is not None:
                    
                    merge_experts = np.full(num_layers, min(uniform_prune_experts, uniform_merge_experts), dtype=np.int32)
                    logger.info(f"Using uniform merging with {min(uniform_prune_experts, uniform_merge_experts)} experts per layer")
                elif optimize_order == "prune_first":
                    
                    merge_experts = prune_experts.copy()
                    logger.info(f"No uniform_merge_experts specified, using prune_experts value")
                else:
                    
                    merge_experts = prune_experts.copy()
            else:
                
                if uniform_merge_experts is not None:
                    merge_experts = np.full(num_layers, uniform_merge_experts, dtype=np.int32)
                    logger.info(f"Using uniform merging with {uniform_merge_experts} experts per layer")
                else:
                    raise ValueError("uniform_merge_experts must be specified for merge_first or merge_only in uniform mode")
                
                if optimize_order == "merge_first" and uniform_prune_experts is not None:
                    
                    prune_experts = np.full(num_layers, min(uniform_merge_experts, uniform_prune_experts), dtype=np.int32)
                    logger.info(f"Using uniform pruning with {min(uniform_merge_experts, uniform_prune_experts)} experts per layer")
                elif optimize_order == "merge_first":
                    
                    prune_experts = merge_experts.copy()
                    logger.info(f"No uniform_prune_experts specified, using merge_experts value")
                else:
                    
                    prune_experts = merge_experts.copy()
        
        
        config = {
            "prune_experts": prune_experts.tolist() if isinstance(prune_experts, np.ndarray) else prune_experts,
            "merge_experts": merge_experts.tolist() if isinstance(merge_experts, np.ndarray) else merge_experts
        }
        save_opt_config(config, opt_config_path)
        logger.info(f"Saved expert configuration to {opt_config_path}")
        if run is not None:
            run.config.update({
                "computed_prune_experts": config["prune_experts"],
                "computed_merge_experts": config["merge_experts"],
            }, allow_val_change=True)
    
    
    opt_config = load_opt_config(opt_config_path)
    
    prune_experts_per_layer = opt_config["prune_experts"]
    merge_experts_per_layer = opt_config["merge_experts"]
    
    logger.info(f"Pruning experts per layer: {prune_experts_per_layer}")
    logger.info(f"Merging experts per layer: {merge_experts_per_layer}")
    if run is not None:
        run.config.update({
            "prune_experts": prune_experts_per_layer,
            "merge_experts": merge_experts_per_layer,
            "sum_prune_experts": int(np.sum(prune_experts_per_layer)),
            "sum_merge_experts": int(np.sum(merge_experts_per_layer)),
        }, allow_val_change=True)
    
    
    final_model_path = None
    final_model = None
    
    if optimize_order == "prune_first":
        logger.info("Execute order: Prune first")
        
        
        model_output_dir = os.path.join(model_output_dir, f"{optimize_order}/{model_path_suffix}")
        
        
        result_path = os.path.join(result_output_dir, f"{optimize_order}/{model_path_suffix}")
        
        
        logger.info("Step 1: Pruning experts using Expert Sparsity method")
        pruned_model_path, pruned_model = call_expert_sparsity(
            model_path=original_model_name,
            output_path=model_output_dir,
            max_block_size=max_block_size,
            calib_set=calib_set,
            experts_per_layer=prune_experts_per_layer,
            prune=prune,
            n_blocks_for_stat=n_sentences,
            batch_size=train_batch_size,
            seed=seed,
            eval_tasks=eval_task_to_list(eval_task),
            eval_path=result_path,
            config_path=opt_config_path,
            model=model,
            tokenizer=tokenizer,
            result_save_path=result_path,
            opt_save_dir=os.path.join(opt_save_dir, optimize_order),
            model_output_dir=model_output_dir
        )
        
        
        logger.info("Step 2: Merging experts using HCMOE method")
        merged_model_path, merged_model = call_hcmoe(
            model_path=pruned_model_path,
            output_path=model_output_dir,
            max_block_size=max_block_size,
            task=eval_task[0] if isinstance(eval_task, list) else eval_task,
            num_average_groups=merge_experts_per_layer,
            dominant=dominant,
            group=group,
            similarity_base=similarity_base,
            merge=merge,
            mode=mode,
            n_sentences=n_sentences,
            train_batch_size=train_batch_size,
            eval_batch_size=eval_batch_size,
            seed=seed,
            num_fewshot=num_fewshot,
            dynamic_group=True,
            opt_path=opt_config_path,
            model=pruned_model,
            tokenizer=tokenizer,
            cluster=cluster,
            linkage=linkage,
            hierarchical_stopping_metric=hierarchical_stopping_metric,
            ingredient=ingredient,
            overlap_metric=overlap_metric,
            calib_set=calib_set
        )
        
        final_model_path = merged_model_path
        final_model = merged_model

    elif optimize_order == "merge_first":
        logger.info("Execute order: Merge first")
        
        
        model_output_dir = os.path.join(model_output_dir, f"{optimize_order}/{model_path_suffix}")
        
        result_path = os.path.join(result_output_dir, f"{optimize_order}/{model_path_suffix}")
        
        
        logger.info("Step 1: Merging experts using HCMOE method")
        merged_model_path, merged_model = call_hcmoe(
            model_path=original_model_name,
            output_path=model_output_dir,
            task=eval_task_to_list(eval_task),
            num_average_groups=merge_experts_per_layer,
            dominant=dominant,
            group=group,
            similarity_base=similarity_base,
            merge=merge,
            mode=mode,
            n_sentences=n_sentences,
            max_block_size=max_block_size,
            train_batch_size=train_batch_size,
            eval_batch_size=eval_batch_size,
            seed=seed,
            num_fewshot=num_fewshot,
            dynamic_group=True,
            opt_path=opt_config_path,
            model=model,
            tokenizer=tokenizer,
            cluster=cluster,
            linkage=linkage,
            hierarchical_stopping_metric=hierarchical_stopping_metric,
            ingredient=ingredient,
            overlap_metric=overlap_metric,
            calib_set=calib_set
        )
        
        
        logger.info("Step 2: Pruning experts using Expert Sparsity method")
        pruned_model_path, pruned_model = call_expert_sparsity(
            model_path=merged_model_path,
            max_block_size=max_block_size,
            output_path=model_output_dir,
            calib_set=calib_set,
            experts_per_layer=prune_experts_per_layer,
            prune=prune,
            n_blocks_for_stat=n_sentences,
            batch_size=train_batch_size,
            seed=seed,
            eval_tasks=eval_task_to_list(eval_task),
            eval_path=result_path,
            config_path=opt_config_path,
            model=merged_model,
            tokenizer=tokenizer,
            result_save_path=result_path,
            opt_save_dir=os.path.join(opt_save_dir, optimize_order),
            model_output_dir=model_output_dir
        )
        
        final_model_path = pruned_model_path
        final_model = pruned_model
        
    elif optimize_order == "prune_only":
        logger.info("Execute order: Prune only")
        
        
        model_output_dir = os.path.join(model_output_dir, f"{optimize_order}/{model_path_suffix}")
        
        
        result_path = os.path.join(result_output_dir, f"{optimize_order}/{model_path_suffix}")
        
        
        logger.info("Step 1: Pruning experts using Expert Sparsity method")
        pruned_model_path, pruned_model = call_expert_sparsity(
            model_path=original_model_name,
            max_block_size=max_block_size,
            output_path=model_output_dir,
            calib_set=calib_set,
            experts_per_layer=prune_experts_per_layer,
            prune=prune,
            n_blocks_for_stat=n_sentences,
            batch_size=train_batch_size,
            seed=seed,
            eval_tasks=eval_task_to_list(eval_task),
            eval_path=result_path,
            config_path=opt_config_path,
            model=model,
            tokenizer=tokenizer,
            result_save_path=result_path,
            opt_save_dir=os.path.join(opt_save_dir, optimize_order),
            model_output_dir=model_output_dir
        )
        
        final_model_path = pruned_model_path
        final_model = pruned_model
        
    elif optimize_order == "merge_only":
        logger.info("Execute order: Merge only")
        
        
        model_output_dir = os.path.join(model_output_dir, f"{optimize_order}/{model_path_suffix}")
        
        
        result_path = os.path.join(result_output_dir, f"{optimize_order}/{model_path_suffix}")
        
        
        logger.info("Step 1: Merging experts using HCMOE method")
        merged_model_path, merged_model = call_hcmoe(
            model_path=original_model_name,
            output_path=model_output_dir,
            task=eval_task_to_list(eval_task),
            num_average_groups=merge_experts_per_layer,
            dominant=dominant,
            group=group,
            similarity_base=similarity_base,
            merge=merge,
            mode=mode,
            n_sentences=n_sentences,
            max_block_size=max_block_size,
            train_batch_size=train_batch_size,
            eval_batch_size=eval_batch_size,
            num_fewshot=num_fewshot,
            dynamic_group=True,
            opt_path=opt_config_path,
            model=model,
            tokenizer=tokenizer,
            cluster=cluster,
            linkage=linkage,
            hierarchical_stopping_metric=hierarchical_stopping_metric,
            ingredient=ingredient,
            overlap_metric=overlap_metric,
            calib_set=calib_set
        )
        
        final_model_path = merged_model_path
        final_model = merged_model
    else:
        raise ValueError(f"Unsupported optimize_order: {optimize_order}. Must be 'prune_first', 'merge_first', 'prune_only', or 'merge_only'.")
    
    
    logger.info(f"Final adapted model saved at: {final_model_path}")
    logger.info("Evaluating final model...")
    
    
    if final_model is None:
        final_model = MixtralForCausalLM.from_pretrained(
            final_model_path,
            torch_dtype=torch.bfloat16,
            
            device_map="auto"
        )
    
    
    result_path = os.path.join(result_output_dir, f"{optimize_order}/{model_path_suffix}_final_eval_results.txt")
    if evaluation_mode == "accuracy":
        eval_results = evaluate_fewshot(
            model=final_model,
            tokenizer=tokenizer,
            task=eval_task_to_list(eval_task),
            num_fewshot=num_fewshot,
            eval_batch_size=eval_batch_size,
            output_path=result_path,
            log=True
        )
        ppls = ppl_eval(final_model, tokenizer, datasets=['wikitext2'], batch_size=eval_batch_size)
        with open(result_path, 'a') as f:
            f.write(f"\nPPL: {ppls}\n")
            f.write(f"\nFinal Experts: {final_model.config.layer_num_experts}\n")
        
        if run is not None:
            
            for key, value in eval_results.items():
                if key != "results":
                    continue
                for dataset_name, dataset_results in value.items():
                    run.log({f"{dataset_name}/acc": dataset_results["acc,none"]})
                    run.log({f"{dataset_name}/acc_std": dataset_results["acc_stderr,none"]})
            
            serializable_eval = json.loads(json.dumps(eval_results, default=str))
            
            run.log({f"ppl/{k}": v for k, v in ppls.items()})
            run.summary["final_experts"] = final_model.config.layer_num_experts
            compression_rate = (original_model_params - final_model.num_parameters()) / original_model_params
            run.summary["compression_rate"] = compression_rate
    elif evaluation_mode == "speed":
        pass
    else:
        logger.info("Do not carry out evaluation.")
    
    
    logger.info(f"Final experts: {final_model.config.layer_num_experts}")
    logger.info(f"Final compression rate: {(original_model_params - final_model.num_parameters()) / original_model_params}")

    if save_checkpoint:
        
        final_model.save_pretrained(final_model_path, safe_serialization=False)
        tokenizer.save_pretrained(final_model_path, safe_serialization=False)
        logger.info(f"Final model saved at: {final_model_path}")

    
    if run is not None:
        run.finish()

    return final_model_path

if __name__ == "__main__":
    # Configure root logging once, at the script entry point, so this module's
    # `logger` (and any other module logger) emits timestamped INFO records.
    logging.basicConfig(
        level=logging.INFO,
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%Y/%m/%d %H:%M:%S",
    )

    # Log the PRESET_RATIOS table (id, pruning ratio, merging ratio and their
    # product) for reference before the CLI command runs.
    logger.info("Available preset ratios (preset_id: pruning_ratio, merging_ratio):")
    for preset_id, ratios in PRESET_RATIOS.items():
        prune_r, merge_r = ratios["pruning_ratio"], ratios["merging_ratio"]
        combined = prune_r * merge_r
        logger.info(f"  {preset_id}: {prune_r:.4f}, {merge_r:.4f} (product: {combined:.4f})")

    # Hand control to python-fire, which exposes adapt_mixtral's parameters
    # as command-line arguments.
    Fire(adapt_mixtral)
