import os
import time
import warnings
import gc
from typing import List, Tuple, Dict

# Process-wide environment tweaks; they run at import time, before torch is
# imported further down.
# Set CUDA allocator before importing torch to reduce fragmentation
# (setdefault: a value already exported by the user is left untouched).
os.environ.setdefault("PYTORCH_CUDA_ALLOC_CONF", "expandable_segments:True")
# Redirect temporary files to ~/.tmp and make sure the directory exists.
# NOTE(review): unlike the allocator setting above, this overwrites any TMPDIR
# that is already set in the environment.
os.environ["TMPDIR"] = os.path.expanduser("~/.tmp")
os.makedirs(os.environ["TMPDIR"], exist_ok=True)

import numpy as np
import pandas as pd
import torch
from tqdm import tqdm
from transformers import AutoTokenizer
from dataclasses import dataclass
from enum import Enum

# Suppress flash_attn related warnings: any warning whose message matches one
# of the patterns below is ignored. The filters are installed before the local
# model module is imported.
warnings.filterwarnings("ignore", message=".*flash_attn.*")
warnings.filterwarnings("ignore", message=".*undefined symbol.*")

# Import local PhiMoE model and config
from .modeling_phimoe import PhiMoEForCausalLM
from .configuration_phimoe import PhiMoEConfig


class ExpertSelectionStrategy(Enum):
    """Names of the expert-selection strategies.

    Each member's value is the lowercase string form of its name.
    """
    ADAPTIVE_THRESHOLD = "adaptive_threshold"
    IMPORTANCE_BASED = "importance_based"
    LOAD_BALANCED = "load_balanced"
    HYBRID = "hybrid"
    DYNAMIC_LEARNING = "dynamic_learning"


@dataclass
class ExpertUsageStats:
    """Per-expert usage statistics.

    Mutable record; ``ExpertImportanceTracker`` keeps one instance per expert
    and updates the fields in place.
    """
    # Number of times the expert was recorded as selected.
    activation_count: int = 0
    # Running sum of the routing weights observed for the expert.
    total_probability: float = 0.0
    # total_probability / activation_count.
    avg_probability: float = 0.0
    # time.time() timestamp of the most recent recorded selection.
    last_used: float = 0.0
    # Weighted mix of frequency, average probability and efficiency.
    importance_score: float = 1.0
    # Exponential moving average of the expert's output norm.
    efficiency_score: float = 1.0


class ExpertImportanceTracker:
    """Accumulate per-expert routing statistics and derive importance scores.

    One ``ExpertUsageStats`` record is kept per expert. ``update_stats`` folds
    a batch of routing decisions into those records and then refreshes every
    record's ``importance_score``.

    ``decay_factor`` is stored on the instance but is not read by any method
    of this class.
    """

    def __init__(self, num_experts: int, decay_factor: float = 0.95):
        self.num_experts = num_experts
        self.decay_factor = decay_factor
        self.expert_stats = [ExpertUsageStats() for _ in range(num_experts)]
        self.global_step = 0

    def update_stats(self, selected_experts: torch.Tensor,
                    routing_weights: torch.Tensor,
                    expert_outputs: torch.Tensor):
        """Fold one batch of routing decisions into the per-expert records.

        Args:
            selected_experts: 2-D tensor of expert indices, read element by
                element as ``selected_experts[row, col]``.
            routing_weights: tensor indexed the same way; supplies the routing
                weight of each selection.
            expert_outputs: optional tensor indexed the same way; when given,
                the norm of ``expert_outputs[row, col]`` feeds the efficiency
                moving average. ``None`` leaves efficiency untouched.

        NOTE(review): entries whose index is not strictly positive are
        skipped, so expert 0 never accumulates statistics here. Confirm
        whether 0 is a padding sentinel upstream or the check should be
        ``>= 0``.
        """
        self.global_step += 1
        now = time.time()

        for row in range(selected_experts.shape[0]):
            for col in range(selected_experts.shape[1]):
                expert_id = selected_experts[row, col].item()
                if not expert_id > 0:
                    continue

                gate_weight = routing_weights[row, col].item()

                record = self.expert_stats[expert_id]
                record.activation_count += 1
                record.total_probability += gate_weight
                record.avg_probability = record.total_probability / record.activation_count
                record.last_used = now

                if expert_outputs is None:
                    continue
                # Efficiency is an EMA (0.9 old / 0.1 new) of the output norm.
                magnitude = torch.norm(expert_outputs[row, col]).item()
                record.efficiency_score = 0.9 * record.efficiency_score + 0.1 * magnitude

        self._compute_importance_scores()

    def _compute_importance_scores(self):
        """Refresh ``importance_score`` on every record.

        The score is ``0.4 * activation share + 0.4 * average routing
        probability + 0.2 * efficiency``. Nothing is changed while no
        activation has been recorded at all.
        """
        grand_total = sum(record.activation_count for record in self.expert_stats)
        if grand_total == 0:
            return

        for record in self.expert_stats:
            activation_share = record.activation_count / grand_total
            record.importance_score = (
                0.4 * activation_share +
                0.4 * record.avg_probability +
                0.2 * record.efficiency_score
            )


class PerformanceMonitor:
    """Sliding-window record of a routing quality proxy and expert usage.

    Two parallel lists are maintained, both capped at ``window_size`` entries
    with the oldest entry dropped first.
    """

    def __init__(self, window_size: int = 100):
        self.window_size = window_size
        self.performance_history = []
        self.expert_usage_history = []

    def update(self, selected_experts: torch.Tensor,
               selected_weights: torch.Tensor,
               expert_outputs: torch.Tensor):
        """Append one observation to both histories.

        The performance value is the average of two scalars: the mean L2 norm
        of ``expert_outputs`` along the last dimension, and the mean entropy
        of ``selected_weights`` along the last dimension (with ``1e-8`` added
        inside the logarithm).

        NOTE(review): usage is counted as the number of entries in
        ``selected_experts`` that differ from 0, which treats index 0 as
        "no expert". Confirm that this matches the upstream padding
        convention.
        """
        nonzero_per_token = (selected_experts != 0).sum(dim=-1)
        usage = nonzero_per_token.float().mean().item()

        mean_norm = torch.norm(expert_outputs, dim=-1).mean().item()

        log_weights = torch.log(selected_weights + 1e-8)
        entropy = -(torch.sum(selected_weights * log_weights, dim=-1).mean().item())

        self.performance_history.append((mean_norm + entropy) / 2)
        self.expert_usage_history.append(usage)

        # Both lists grow in lockstep, so trimming one entry from each keeps
        # them aligned.
        if len(self.performance_history) > self.window_size:
            del self.performance_history[0]
            del self.expert_usage_history[0]

    def get_performance_stats(self) -> Dict:
        """Return window averages plus the last ten entries of each history.

        An empty dict is returned while nothing has been recorded.
        """
        if len(self.performance_history) == 0:
            return {}

        summary = {}
        summary['avg_performance'] = np.mean(self.performance_history)
        summary['avg_experts_per_token'] = np.mean(self.expert_usage_history)
        summary['performance_trend'] = self.performance_history[-10:]
        summary['usage_trend'] = self.expert_usage_history[-10:]
        return summary


class AdaptiveThresholdManager:
    """Adaptive threshold manager (paper-aligned three-stage method).

    Keeps a bounded history (at most 100 entries) of the mean per-layer
    threshold and of the performance value passed to ``update_thresholds``,
    and writes new thresholds to ``moe_block.current_threshold`` on each MoE
    layer of the model.
    """

    # Quantile levels computed for every routing-weight summary, in order:
    # P10, P25, P50, P60, P75, P80, P85, P90.
    _QUANTILE_LEVELS = (0.1, 0.25, 0.5, 0.6, 0.75, 0.8, 0.85, 0.9)
    # Threshold assumed for a block that has no ``current_threshold`` yet.
    _DEFAULT_LAYER_THRESHOLD = 0.18
    # Maximum number of entries kept in the two history lists.
    _HISTORY_LIMIT = 100

    def __init__(self, num_layers: int, base_config):
        self.num_layers = num_layers
        self.base_config = base_config
        self.threshold_history = []
        self.performance_history = []

    @staticmethod
    def _is_usable(router_logits) -> bool:
        """Return True for a real (non-meta) tensor; None and other types fail."""
        return isinstance(router_logits, torch.Tensor) and not router_logits.is_meta

    @classmethod
    def _summarize_weights(cls, flat_weights: torch.Tensor) -> Dict:
        """Return mean, std and percentiles of a non-empty 1-D weight tensor."""
        # torch.quantile requires a floating dtype; float32 also keeps the
        # statistics independent of the model's compute dtype.
        flat_weights = flat_weights.to(torch.float32)
        levels = torch.tensor(cls._QUANTILE_LEVELS, device=flat_weights.device, dtype=torch.float32)
        p10, p25, p50, p60, p75, p80, p85, p90 = torch.quantile(flat_weights, levels).tolist()
        return {
            'mean_weight': float(flat_weights.mean().item()),
            'std_weight': float(flat_weights.std().item()),
            'percentiles': [p10, p25, p50, p75, p90],
            'p60': p60,
            'p80': p80,
            'p85': p85,
            'total_weights': flat_weights.numel(),
        }

    def analyze_routing_distribution(self, router_logits_list: List[torch.Tensor]) -> Dict:
        """Analyze the routing weight distribution pooled over all layers.

        Args:
            router_logits_list: per-layer router logits. Entries that are
                ``None``, not tensors, or meta tensors are ignored.

        Returns:
            ``{}`` when there is nothing to analyze; otherwise a dict with
            ``mean_weight``, ``std_weight``, ``percentiles``
            (``[P10, P25, P50, P75, P90]``), ``p60``, ``p80``, ``p85`` and
            ``total_weights``. These are the same keys as the per-layer
            statistics built in ``update_thresholds``.
        """
        if not router_logits_list:
            return {}

        pooled = []
        for router_logits in router_logits_list:
            if not self._is_usable(router_logits):
                continue
            weights = torch.nn.functional.softmax(router_logits, dim=-1).flatten()
            if pooled:
                # Layers may live on different devices when the model is
                # sharded; torch.cat needs them on a single device.
                weights = weights.to(pooled[0].device)
            pooled.append(weights)

        if not pooled:
            return {}

        combined_weights = torch.cat(pooled)
        if combined_weights.numel() == 0:
            # torch.quantile rejects empty input.
            return {}
        return self._summarize_weights(combined_weights)

    def suggest_optimal_threshold(self, routing_stats: Dict, current_performance: float) -> float:
        """Three-stage threshold selection: τ_theo -> τ_PPL -> τ_final.

        Args:
            routing_stats: dict with ``mean_weight``, ``std_weight`` and
                ``percentiles`` (``[P10, P25, P50, P75, P90]``). When empty,
                the configured ``expert_probability_threshold`` (default 0.3)
                is returned.
            current_performance: latest performance value (the caller in this
                file passes perplexity); it is compared with the most recent
                entry of ``performance_history``.

        Returns:
            The suggested threshold, clamped to ``[P10, P90]``.
        """
        if not routing_stats:
            return getattr(self.base_config, 'expert_probability_threshold', 0.3)

        mean_weight = routing_stats['mean_weight']
        std_weight = routing_stats['std_weight']
        percentiles = routing_stats['percentiles']  # [P10,P25,P50,P75,P90]

        # 1) Theoretical grounding: τ_theo = max(α1·μ_w, μ_w − α2·σ_w)
        alpha1 = getattr(self.base_config, 'alpha1', 0.8)
        alpha2 = getattr(self.base_config, 'alpha2', 0.5)
        tau_theo = max(alpha1 * mean_weight, mean_weight - alpha2 * std_weight)

        # 2) Performance-based adjustment: a rising value lowers τ, a falling
        #    value raises it, and the first call leaves τ_theo unchanged.
        tau_ppl = tau_theo
        if self.performance_history:
            prev = self.performance_history[-1]
            denom = prev if abs(prev) > 1e-8 else 1e-8
            delta = (current_performance - prev) / denom
            lambda_decrease = getattr(self.base_config, 'lambda_decrease', 0.2)
            lambda_increase = getattr(self.base_config, 'lambda_increase', 0.1)
            if delta > 0:
                tau_ppl = tau_theo * (1.0 - lambda_decrease)
            elif delta < 0:
                tau_ppl = tau_theo * (1.0 + lambda_increase)

        # 3) Boundedness and smoothing: τ_final = κ1·τ_PPL + κ2·τ_theo, then clamp to [P10,P90]
        min_thresh = percentiles[0]
        max_thresh = percentiles[4]
        k1 = getattr(self.base_config, 'kappa1', 0.7)
        k2 = getattr(self.base_config, 'kappa2', 0.3)
        tau_final = k1 * tau_ppl + k2 * tau_theo
        return max(min_thresh, min(max_thresh, tau_final))

    def update_thresholds(self, model, router_logits_list: List[torch.Tensor], current_performance: float):
        """Per-layer threshold update using the three-stage method.

        For each of the first ``num_layers`` layers of ``model.model.layers``
        the MoE block is looked up under ``mlp`` and then
        ``block_sparse_moe``. A block counts as MoE when it has both ``gate``
        and ``experts``. Its ``current_threshold`` is overwritten when the
        suggestion differs from it by more than 0.001. Layers without usable
        router logits keep their current threshold.

        Afterwards the NaN-ignoring mean of the per-layer thresholds and
        ``current_performance`` are appended to the histories.
        """
        updated_layers = 0
        per_layer_thresholds = []

        for i in range(self.num_layers):
            layer_ref = model.model.layers[i]
            moe_block = getattr(layer_ref, 'mlp', None)
            if moe_block is None:
                moe_block = getattr(layer_ref, 'block_sparse_moe', None)
            if not (hasattr(moe_block, 'gate') and hasattr(moe_block, 'experts')):
                # Not an MoE layer: NaN keeps it out of the mean below.
                per_layer_thresholds.append(float('nan'))
                continue

            # Router logits for this layer (the list may be shorter than the
            # number of layers, or missing altogether).
            try:
                router_logits = router_logits_list[i]
            except (IndexError, KeyError, TypeError):
                router_logits = None

            current_thresh = getattr(moe_block, 'current_threshold', self._DEFAULT_LAYER_THRESHOLD)
            if not self._is_usable(router_logits) or router_logits.numel() == 0:
                per_layer_thresholds.append(current_thresh)
                continue

            weights = torch.nn.functional.softmax(router_logits, dim=-1)
            routing_stats = self._summarize_weights(weights.flatten())
            optimal_threshold = self.suggest_optimal_threshold(routing_stats, current_performance)

            if abs(current_thresh - optimal_threshold) > 0.001:
                moe_block.current_threshold = optimal_threshold
                updated_layers += 1
            per_layer_thresholds.append(optimal_threshold)

        # Mean over the layers that reported a threshold; NaN when none did.
        reported = [t for t in per_layer_thresholds if not np.isnan(t)]
        avg_thr = float(np.mean(reported)) if reported else float('nan')

        if updated_layers > 0:
            print(f"[Threshold Optimization] Updated {updated_layers} layers; mean: {avg_thr:.4f}")

        # Record average threshold for history
        self.threshold_history.append(avg_thr)
        self.performance_history.append(current_performance)
        if len(self.threshold_history) > self._HISTORY_LIMIT:
            self.threshold_history.pop(0)
            self.performance_history.pop(0)


def calculate_active_experts_innovative(router_logits, layer_idx, model):
    """Return the mean number of activated experts per token for one layer.

    Steps:
      1) g = softmax(router_logits)
      2) History score Sh = w1*N/sumN + w2*P_total/N + w3*Cs, read from the
         MoE block's ``expert_usage_count``, ``expert_total_probability`` and
         ``expert_contribution_score`` buffers (all zeros if unavailable)
      3) IS = (1-β)*g + β*Sh with β=0.25
      4) Candidate set E = {i | IS_i >= τ}, τ from moe_block.current_threshold
         (falling back to config.expert_probability_threshold, then 0.3);
         no additional relaxation is applied here
      5) Activation set A: if |E|=0 take argmax(IS); if 1≤|E|≤K take E;
         if |E|>K take Top-K by IS within E
      6) Return mean |A| per token

    Only |A| is needed, and by rule 5 it equals ``1`` when ``|E| = 0`` and
    ``min(|E|, K)`` otherwise, so the count is computed for all tokens at once
    instead of building the activation mask token by token.

    Args:
        router_logits: tensor whose last dimension indexes experts; a 2-D
            input is treated as (tokens, experts).
        layer_idx: index into ``model.model.layers``.
        model: model exposing ``model.layers`` and ``config``.

    Returns:
        Mean |A| as a float; ``config.num_experts_per_tok`` (default 2) when
        the layer has no gated MoE block or ``router_logits`` is not a tensor;
        ``0.0`` when there are no tokens.
    """
    # Locate the MoE block of this layer
    layer_ref = model.model.layers[layer_idx]
    moe_block = getattr(layer_ref, 'mlp', None)
    if moe_block is None:
        moe_block = getattr(layer_ref, 'block_sparse_moe', None)
    if (
        moe_block is None
        or not hasattr(moe_block, 'gate')
        or not isinstance(router_logits, torch.Tensor)
    ):
        return float(getattr(model.config, 'num_experts_per_tok', 2))

    # Routing probability g, shaped (batch, tokens, experts)
    g = torch.nn.functional.softmax(router_logits, dim=-1)
    if g.dim() == 2:
        g = g.unsqueeze(1)
    device = g.device
    dtype = g.dtype
    num_experts = g.shape[-1]

    # History-based score Sh (from layer buffers)
    eps = 1e-8
    w1, w2, w3 = 0.4, 0.4, 0.2
    beta = 0.25
    try:
        N_vec = moe_block.expert_usage_count.to(device=device, dtype=dtype)
        P_vec = moe_block.expert_total_probability.to(device=device, dtype=dtype)
        Cs_vec = moe_block.expert_contribution_score.to(device=device, dtype=dtype)
    except (AttributeError, RuntimeError, TypeError):
        # Buffers missing or not materialized (e.g. offloaded): fall back to
        # all zeros, which reduces IS to (1-β)*g.
        N_vec = torch.zeros(num_experts, device=device, dtype=dtype)
        P_vec = torch.zeros(num_experts, device=device, dtype=dtype)
        Cs_vec = torch.zeros(num_experts, device=device, dtype=dtype)

    sumN = N_vec.sum().clamp_min(eps)
    freq_term = N_vec / sumN
    prob_term = P_vec / (N_vec + eps)
    Sh = w1 * freq_term + w2 * prob_term + w3 * Cs_vec
    # Final importance IS
    IS = (1.0 - beta) * g + beta * Sh.view(1, 1, -1)

    # Dynamic threshold τ and per-token cap K
    tau = float(getattr(moe_block, 'current_threshold', getattr(model.config, 'expert_probability_threshold', 0.3)))
    K = int(getattr(moe_block, 'max_experts_per_token', getattr(model.config, 'num_experts_per_tok', 2)))

    # |E| per token, then |A| = 1 if |E| == 0 else min(|E|, K)
    candidate_counts = (IS >= tau).sum(dim=-1)
    active_per_token = torch.where(
        candidate_counts == 0,
        torch.ones_like(candidate_counts),
        candidate_counts.clamp(max=K),
    ).float()
    return active_per_token.mean().item() if active_per_token.numel() > 0 else 0.0


def setup_device() -> torch.device:
    """Return the CUDA device when CUDA is available, otherwise the CPU.

    The CUDA device is created without an explicit index, so the set of GPUs
    is whatever is visible to the process (CUDA_VISIBLE_DEVICES is honored).
    """
    if torch.cuda.is_available():
        return torch.device("cuda")
    return torch.device("cpu")


def load_model_and_tokenizer(
    model_path: str,
    device: torch.device,
) -> Tuple[PhiMoEForCausalLM, AutoTokenizer]:
    """Load the PhiMoE model (with routing enhancements enabled) and tokenizer.

    Args:
        model_path: checkpoint path or hub id passed to ``from_pretrained``.
        device: accepted for interface compatibility; placement is decided by
            ``device_map`` ("auto" with CUDA, everything on CPU otherwise), so
            this argument is not read.

    Returns:
        ``(model, tokenizer)``. The model is in eval mode with ``use_cache``
        disabled; the tokenizer falls back to the EOS token for padding when
        it defines no pad token.

    Side effects:
        Creates the ``.offload_cache`` directory in the current working
        directory and prints the per-GPU memory budget.
    """
    # Config (respect checkpoint defaults; do not override key MoE params)
    base_config: PhiMoEConfig = PhiMoEConfig.from_pretrained(
        model_path,
        output_router_logits=True,
        # Enable routing enhancements so that thresholding is effective
        use_innovative_routing=True,
        expert_probability_threshold=0.32,  # Higher base τ so <2 experts are often activated
        min_experts_per_token=1,            # Minimum 1 expert
        max_experts_per_token=2,            # Maximum 2 experts
        importance_weight=0.20,             # Lower importance weight to reduce bias
        prune_weight_epsilon=1e-6,          # Lower pruning epsilon to include small weights
        adaptation_rate=0.08,               # Lower adaptation rate for stability
        min_threshold=0.05,                 # Lower τ_min for flexibility
        max_threshold=0.50,                 # Higher τ_max for sparsity headroom
        sparsity_control_factor=0.75,       # Higher sparsity control factor
        expert_activation_threshold=0.005,  # Lower activation epsilon to activate more experts
        # Tail-layer relaxation
        tail_relax_start_layer=24,
        tail_threshold_scale=0.95,
    )

    # Memory/offload configuration
    offload_dir = ".offload_cache"
    os.makedirs(offload_dir, exist_ok=True)

    # Build per-GPU memory budgets from available VRAM to avoid accelerate max_memory errors
    dynamic_max_memory = None
    if torch.cuda.is_available():
        dynamic_max_memory = {"cpu": "400GiB"}
        num_devices = 0
        try:
            num_devices = torch.cuda.device_count()
            for i in range(num_devices):
                try:
                    with torch.cuda.device(i):
                        free_bytes, total_bytes = torch.cuda.mem_get_info()
                    free_gib = int(free_bytes / (1024 ** 3))
                    total_gib = int(total_bytes / (1024 ** 3))
                    # Leave 2 GiB of headroom and cap at 70 GiB. The budget is
                    # never raised above what is actually free: a fixed floor
                    # larger than the free memory would make the loader plan
                    # more weights for this GPU than can be allocated.
                    budget_gib = max(1, min(70, free_gib - 2))
                    print(f"GPU {i}: 总内存{total_gib}GB, 可用{free_gib}GB, 分配{budget_gib}GB")
                except Exception as e:
                    print(f"GPU {i} 内存检测失败: {e}")
                    budget_gib = 60
                # accelerate expects integer device indices, not "cuda:X"
                dynamic_max_memory[i] = f"{budget_gib}GiB"
        except Exception as e:
            print(f"GPU内存检测失败: {e}")
            # Fallback: conservative budget for every GPU not probed yet
            # (always including GPU 0), without inventing device ids beyond
            # the reported device count.
            for i in range(max(1, num_devices)):
                dynamic_max_memory.setdefault(i, "60GiB")

    model = PhiMoEForCausalLM.from_pretrained(
        model_path,
        config=base_config,
        torch_dtype=torch.bfloat16,
        low_cpu_mem_usage=True,
        offload_state_dict=True,
        offload_folder=offload_dir,
        device_map="auto" if torch.cuda.is_available() else {"": "cpu"},
        max_memory=dynamic_max_memory,
        # Aggressive CPU offloading
        offload_buffers=True,
        # Disable flash_attention_2, use standard attention
        attn_implementation="eager",
    )
    model.eval()

    # The KV cache is not needed for single-pass scoring; failing to disable
    # it is not fatal, but it should not go unnoticed either.
    try:
        model.config.use_cache = False
    except Exception as e:
        print(f"Could not disable use_cache: {e}")

    tokenizer = AutoTokenizer.from_pretrained(model_path)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    return model, tokenizer


def load_text_data(path: str) -> List[str]:
    """Read a UTF-8 text file and return its non-blank lines, stripped.

    Lines that are empty or contain only whitespace are dropped; the order of
    the remaining lines is preserved.
    """
    samples: List[str] = []
    with open(path, "r", encoding="utf-8") as handle:
        for raw_line in handle:
            cleaned = raw_line.strip()
            if cleaned:
                samples.append(cleaned)
    return samples


def main():
    """Score a text file with the PhiMoE model and report routing statistics.

    For every non-blank line of the input file (truncated to 192 tokens) the
    perplexity and the per-layer mean number of activated experts are
    computed. After every 10th successfully scored sample the per-layer
    routing thresholds are re-estimated. At the end the mean perplexity and
    the per-layer averages are printed.
    """
    model_path = "microsoft/Phi-3.5-MoE-instruct"
    wiki_data_path = ".wiki_demo.txt"

    device = setup_device()
    model, tokenizer = load_model_and_tokenizer(model_path, device)
    num_layers = model.config.num_hidden_layers
    # Value reported for a layer whose activation count cannot be computed.
    fallback_experts = float(getattr(model.config, 'num_experts_per_tok', 2))

    def moe_block_of(layer_index):
        """Return the layer's MoE block (``mlp`` or ``block_sparse_moe``), or None."""
        layer_ref = model.model.layers[layer_index]
        block = getattr(layer_ref, 'mlp', None)
        if block is None:
            block = getattr(layer_ref, 'block_sparse_moe', None)
        return block

    # Inject per-layer thresholds to ensure the thresholding branch is active; print a snapshot for the first layers
    try:
        injected = 0
        for i in range(num_layers):
            moe_block = moe_block_of(i)
            if not (hasattr(moe_block, 'gate') and hasattr(moe_block, 'experts')):
                continue
            try:
                # Set a relatively high initial τ and bounds to avoid over-decreasing early on
                moe_block.current_threshold = getattr(model.config, 'expert_probability_threshold', 0.30)
                moe_block.min_threshold = getattr(model.config, 'min_threshold', 0.05)
                moe_block.max_threshold = getattr(model.config, 'max_threshold', 0.50)
                # Ensure min/max experts per token exist on the block
                if not hasattr(moe_block, 'min_experts_per_token'):
                    setattr(moe_block, 'min_experts_per_token', getattr(model.config, 'min_experts_per_token', 1))
                if not hasattr(moe_block, 'max_experts_per_token'):
                    setattr(moe_block, 'max_experts_per_token', getattr(model.config, 'max_experts_per_token', 2))
                injected += 1
            except Exception as _e_layer_init:
                # Best effort: one failing layer must not stop the others.
                print(f"[Threshold Init] Layer {i} skipped: {_e_layer_init}")
        if injected:
            # Print snapshot of the first 5 layers' thresholds
            sample_layers = min(5, num_layers)
            layer_thresholds = [
                getattr(moe_block_of(li), 'current_threshold', float('nan'))
                for li in range(sample_layers)
            ]
            try:
                mean_thr = float(np.nanmean(layer_thresholds))
            except (TypeError, ValueError):
                mean_thr = float('nan')
            thr_str = ", ".join([f"{t:.4f}" if isinstance(t, float) else str(t) for t in layer_thresholds])
            print(f"[Threshold Init] Set thresholds for {injected} layers; first {sample_layers} layers: [{thr_str}] | mean: {mean_thr:.4f}")
    except Exception as _e_inject:
        print(f"[Threshold Init] Failed: {_e_inject}")

    wiki_data = load_text_data(wiki_data_path)

    ppl_ls: List[float] = []
    layer_expert_activation_ls: List[List[float]] = [[] for _ in range(num_layers)]

    progress_bar = tqdm(wiki_data, desc="Compute PPL", postfix="PPL: 0.0000")

    # Initialize adaptive threshold manager
    threshold_manager = AdaptiveThresholdManager(num_layers, model.config)

    for text in progress_bar:
        # Drop any reference left from the previous sample so that clearing
        # the cache below can actually release its memory.
        outputs = None
        torch.cuda.empty_cache()
        gc.collect()

        # Check GPU free memory
        if torch.cuda.is_available():
            free_memory = torch.cuda.mem_get_info()[0]
            if free_memory < 100 * 1024 * 1024:  # Skip if <100MB free
                print(f"Low GPU memory ({free_memory/1024/1024/1024:.1f}GB free), skip: {text[:50]}...")
                continue

        inputs = tokenizer.encode(text, return_tensors="pt")
        if inputs.size(1) == 0:
            continue
        inputs = inputs[:, :192].to(device)  # Slightly shorten sequence length to reduce peak VRAM

        try:
            torch.cuda.empty_cache()

            with torch.inference_mode():
                outputs = model(
                    inputs,
                    labels=inputs,
                    output_router_logits=True,
                )
            loss = outputs.loss
            if not torch.isfinite(loss):
                # Unusable sample: release everything and move on
                outputs = None
                del loss
                torch.cuda.empty_cache()
                gc.collect()
                continue

            ppl = torch.exp(loss).item()
            ppl_ls.append(ppl)

            del loss
            torch.cuda.empty_cache()
            gc.collect()

            # A model that returns no router logits yields no per-layer
            # statistics instead of failing the whole sample.
            router_logits_all = outputs.router_logits if outputs.router_logits is not None else ()

            # Per-layer statistics (IS-based activation counting)
            layer_avg_experts = []
            for layer_idx, router_logits in enumerate(router_logits_all):
                # Fallback to configured top-k when router logits are missing
                # or when counting fails for this layer.
                avg_experts = fallback_experts
                if isinstance(router_logits, torch.Tensor) and not router_logits.is_meta:
                    try:
                        avg_experts = calculate_active_experts_innovative(
                            router_logits,
                            layer_idx,
                            model
                        )
                    except Exception as _e_layer:
                        # Print reason to avoid silent fallback
                        print(f"[层{layer_idx}] 统计异常: {_e_layer}")
                # Exactly one value per layer, so the sample mean is not
                # skewed by duplicate entries.
                layer_avg_experts.append(avg_experts)
                if layer_idx < num_layers:
                    layer_expert_activation_ls[layer_idx].append(avg_experts)

            global_avg_experts = float(np.mean(layer_avg_experts)) if layer_avg_experts else 0.0
            progress_bar.set_postfix_str(f"PPL: {ppl:.4f}, Active experts: {global_avg_experts:.2f}")

            # Optimize thresholds every 10 samples
            if len(ppl_ls) % 10 == 0:
                # Analyze routing distribution
                routing_stats = threshold_manager.analyze_routing_distribution(router_logits_all)
                # Update per-layer thresholds
                threshold_manager.update_thresholds(model, router_logits_all, ppl)

                # Print current thresholds snapshot
                if routing_stats:
                    print(f"\n[Threshold Analysis] sample {len(ppl_ls)}: mean={routing_stats['mean_weight']:.6f}, "
                          f"std={routing_stats['std_weight']:.6f}")
                    print(f"[Threshold Analysis] weight percentiles: P10={routing_stats['percentiles'][0]:.6f}, "
                          f"P25={routing_stats['percentiles'][1]:.6f}, "
                          f"P50={routing_stats['percentiles'][2]:.6f}, "
                          f"P75={routing_stats['percentiles'][3]:.6f}, "
                          f"P80={routing_stats.get('p80', float('nan')):.6f}, "
                          f"P85={routing_stats.get('p85', float('nan')):.6f}, "
                          f"P90={routing_stats['percentiles'][4]:.6f}")

                    # Print the current τ of the first few layers
                    print(f"[Thresholds] first 5 layers τ: ", end="")
                    for li in range(min(5, num_layers)):
                        moe_block = moe_block_of(li)
                        if moe_block is not None:
                            curr_thresh = getattr(moe_block, 'current_threshold', 0.0)
                            print(f"{curr_thresh:.3f}", end=" ")
                    print()

                    # Compute a theoretical τ for reference
                    p85 = routing_stats.get('p85', routing_stats['percentiles'][4])
                    p90 = routing_stats['percentiles'][4]
                    tighter_mix = ((p85 + p90) / 2) * 1.25
                    theoretical = max(p90 * 1.05, tighter_mix)
                    print(f"[Threshold Theory] P85={p85:.4f}, P90={p90:.4f}, theoretical τ={theoretical:.4f}")

                # Force-clear GPU cache and GC
                torch.cuda.empty_cache()
                gc.collect()

            # Drop outputs to free memory
            outputs = None
            router_logits_all = ()
            torch.cuda.empty_cache()
            gc.collect()

        except Exception as e:
            print(f"Computation error, skip: {text[:50]}... Error: {e}")
            # Release the failed sample's tensors before clearing the cache
            outputs = None
            router_logits_all = ()
            torch.cuda.empty_cache()
            gc.collect()
            continue

    # Final reporting
    if not ppl_ls:
        print("没有有效的PPL计算结果")
        return

    mean_ppl = float(np.mean(ppl_ls))
    print(f"\nMean PPL over {len(ppl_ls)} samples: {mean_ppl:.4f}")

    # Per-layer average number of activated experts (0.0 when a layer has no
    # data or its mean is NaN)
    layer_mean_experts: List[float] = []
    for layer_data in layer_expert_activation_ls:
        val = float(np.mean(layer_data)) if layer_data else 0.0
        if np.isnan(val):
            val = 0.0
        layer_mean_experts.append(val)

    overall_avg_experts = float(np.mean(layer_mean_experts)) if layer_mean_experts else 0.0
    print(f"\nOverall average number of activated experts across layers: {overall_avg_experts:.2f} per token")

    for i, avg_experts in enumerate(layer_mean_experts):
        print(f"Layer {i}: {avg_experts:.2f} per token")

# Run the evaluation only when the file is executed, not when it is imported.
# NOTE(review): the relative imports at the top of the file require running
# this as a module inside its package (python -m <package>.<module>).
if __name__ == "__main__":
    main()


