import torch
import numpy as np
from tqdm import tqdm
from transformers import AutoTokenizer
import os
import pandas as pd
import matplotlib.pyplot as plt
import time
from typing import Dict, List, Optional, Tuple
from dataclasses import dataclass
from enum import Enum

from .modeling_mixtralscore import MixtralForCausalLM
from .configuration_mixtralscore import MixtralConfig

# Global smoothing constant to prevent divide-by-zero, aligned with paper (ε = 1e-8).
# It is added to denominators and to log() arguments elsewhere in this file.
EPSILON = 1e-8

# Request expandable segments from the CUDA caching allocator to avoid memory
# fragmentation that could cause OOM. setdefault() keeps any value that is
# already present in the environment.
# NOTE(review): this runs after `import torch`; presumably it still takes effect
# because no CUDA allocation has happened yet — confirm.
os.environ.setdefault("PYTORCH_CUDA_ALLOC_CONF", "expandable_segments:True")

# Device the tokenized inputs are moved to in the evaluation loop.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Identifier passed to from_pretrained() for the config, the weights and the tokenizer.
model_path = 'mistralai/Mixtral-8x7B-Instruct-v0.1'

class ExpertSelectionStrategy(Enum):
    """Enumeration of expert selection strategies.

    Each member's value is the lowercase string form of its name.

    NOTE(review): no code visible in this file reads these members; presumably
    they are consumed by another module — confirm before removing or renaming.
    """
    ADAPTIVE_THRESHOLD = "adaptive_threshold"
    IMPORTANCE_BASED = "importance_based"
    LOAD_BALANCED = "load_balanced"
    HYBRID = "hybrid"
    DYNAMIC_LEARNING = "dynamic_learning"

@dataclass
class ExpertUsageStats:
    """Running usage statistics for a single expert.

    Instances are created and mutated by ExpertImportanceTracker; the field
    comments below describe how that class writes each value.
    """
    # Number of (token, slot) entries in which this expert was recorded.
    activation_count: int = 0
    # Sum of the routing weights seen for those entries.
    total_probability: float = 0.0
    # total_probability / activation_count, refreshed on every recorded entry.
    avg_probability: float = 0.0
    # time.time() value of the update call that last recorded this expert.
    last_used: float = 0.0
    # 0.4 * activation share + 0.4 * avg_probability + 0.2 * efficiency_score.
    importance_score: float = 1.0
    # Exponentially smoothed L2 norm of the expert's output.
    efficiency_score: float = 1.0

class ExpertImportanceTracker:
    """Accumulate per-expert usage statistics and derive importance scores.

    One ExpertUsageStats record is kept per expert. Each call to
    ``update_stats`` folds a batch of routing decisions into those records and
    then recomputes every expert's ``importance_score``.

    NOTE(review): ``decay_factor`` is stored but never read by the code in this
    class; the smoothing factor used for the contribution score is the fixed
    value 0.95.
    """

    def __init__(self, num_experts: int, decay_factor: float = 0.95):
        self.num_experts = num_experts
        self.decay_factor = decay_factor
        self.global_step = 0
        self.expert_stats = [ExpertUsageStats() for _ in range(num_experts)]

    def update_stats(self, selected_experts: torch.Tensor,
                    routing_weights: torch.Tensor,
                    expert_outputs: torch.Tensor):
        """Record one batch of routing decisions.

        Args:
            selected_experts: 2-D tensor of expert indices, indexed ``[i, j]``.
            routing_weights: tensor indexed the same way, holding the routing
                weight of each entry.
            expert_outputs: tensor whose ``[i, j]`` slice is the output used
                for the contribution score, or ``None`` to skip that update.

        NOTE(review): only entries with an index strictly greater than zero are
        recorded, so expert 0 never accumulates statistics. Presumably 0 is
        treated as "no expert selected" — confirm against the caller.
        """
        self.global_step += 1
        timestamp = time.time()
        # Contribution score Cs uses exponential smoothing.
        # Paper: Cs = (1-α) * C_prev + α * ||o||, α=0.95
        alpha = 0.95

        for row in range(selected_experts.shape[0]):
            for slot in range(selected_experts.shape[1]):
                expert_id = selected_experts[row, slot].item()
                if not expert_id > 0:
                    continue

                record = self.expert_stats[expert_id]
                record.activation_count += 1
                record.total_probability += routing_weights[row, slot].item()
                record.avg_probability = record.total_probability / record.activation_count
                record.last_used = timestamp

                if expert_outputs is None:
                    continue
                magnitude = torch.norm(expert_outputs[row, slot]).item()
                record.efficiency_score = (1.0 - alpha) * record.efficiency_score + alpha * magnitude

        # Refresh the combined scores now that the raw statistics have changed.
        self._compute_importance_scores()

    def _compute_importance_scores(self):
        """Recompute ``importance_score`` for every expert.

        The score is 0.4 * (share of all activations) + 0.4 * (average routing
        probability) + 0.2 * (efficiency score). Nothing is changed while no
        activation has been recorded yet.
        """
        activation_total = sum(record.activation_count for record in self.expert_stats)
        if activation_total == 0:
            return

        for record in self.expert_stats:
            activation_share = record.activation_count / activation_total
            record.importance_score = (
                0.4 * activation_share +
                0.4 * record.avg_probability +
                0.2 * record.efficiency_score
            )

class PerformanceMonitor:
    """Sliding-window record of a routing quality metric and expert usage.

    Both histories are plain lists capped at ``window_size`` entries; the
    oldest entry is dropped once the cap is exceeded.
    """

    def __init__(self, window_size: int = 100):
        self.window_size = window_size
        self.performance_history = []
        self.expert_usage_history = []

    def update(self, selected_experts: torch.Tensor,
               selected_weights: torch.Tensor,
               expert_outputs: torch.Tensor):
        """Append one observation to both histories.

        The performance metric is the average of two numbers: the mean L2 norm
        of ``expert_outputs`` along its last dimension, and the mean entropy of
        ``selected_weights`` along its last dimension. Expert usage is the mean
        count of non-zero entries of ``selected_experts`` along its last
        dimension.
        """
        # Usage: how many non-zero expert entries each row carries on average.
        nonzero_slots = selected_experts != 0
        usage = nonzero_slots.sum(dim=-1).float().mean().item()

        # Output magnitude, averaged over all rows.
        magnitude = torch.norm(expert_outputs, dim=-1).mean().item()

        # Entropy of the selected weights; EPSILON keeps log() finite at zero.
        weighted_logs = selected_weights * torch.log(selected_weights + EPSILON)
        entropy = -(torch.sum(weighted_logs, dim=-1).mean().item())

        self.performance_history.append((magnitude + entropy) / 2)
        self.expert_usage_history.append(usage)

        # Both lists grow in lockstep, so one length check covers both.
        if len(self.performance_history) > self.window_size:
            del self.performance_history[0]
            del self.expert_usage_history[0]

    def get_performance_stats(self) -> Dict:
        """Return window averages and the last ten entries of each history.

        An empty dict is returned while nothing has been recorded.
        """
        if len(self.performance_history) == 0:
            return {}

        summary = {}
        summary['avg_performance'] = np.mean(self.performance_history)
        summary['avg_experts_per_token'] = np.mean(self.expert_usage_history)
        summary['performance_trend'] = self.performance_history[-10:]
        summary['usage_trend'] = self.expert_usage_history[-10:]
        return summary

class AdaptiveThresholdManager:
    """Adaptive threshold manager to determine thresholds.

    Keeps a bounded history of suggested thresholds and of the performance
    values (perplexity) they were suggested for, and writes the suggested
    threshold onto every MoE block that exposes ``adaptive_expert_selection``.
    """

    def __init__(self, num_layers: int, base_config):
        self.num_layers = num_layers
        self.base_config = base_config
        self.threshold_history = []
        self.performance_history = []

    def analyze_routing_distribution(self, router_logits_list: List[torch.Tensor]) -> Dict:
        """Analyze the distribution of routing weights.

        Args:
            router_logits_list: per-layer router logits; ``None`` entries are
                skipped. The tensors may live on different devices.

        Returns:
            ``{}`` when there is nothing to analyze, otherwise a dict with
            ``mean_weight``, ``std_weight``, ``percentiles`` (10/25/50/75/90)
            and ``total_weights`` computed over the softmax of all logits.
        """
        if not router_logits_list:
            return {}

        gather_device = None
        all_weights = []
        for router_logits in router_logits_list:
            if router_logits is None:
                continue
            if gather_device is None:
                gather_device = router_logits.device
            weights = torch.nn.functional.softmax(router_logits, dim=-1)
            # Layers can be sharded across devices, and torch.cat requires a
            # single device, so gather onto the first tensor's device.
            # float32 is needed because torch.quantile rejects bfloat16.
            all_weights.append(
                weights.detach().flatten().to(device=gather_device, dtype=torch.float32)
            )

        if not all_weights:
            return {}

        combined_weights = torch.cat(all_weights)
        if combined_weights.numel() == 0:
            # mean/quantile are undefined for an empty tensor.
            return {}

        mean_weight = combined_weights.mean().item()
        # The sample standard deviation of a single value is NaN; report 0 instead.
        std_weight = combined_weights.std().item() if combined_weights.numel() > 1 else 0.0
        q = torch.tensor([0.1, 0.25, 0.5, 0.75, 0.9], device=combined_weights.device, dtype=torch.float32)
        percentiles = torch.quantile(combined_weights, q).tolist()

        return {
            'mean_weight': mean_weight,
            'std_weight': std_weight,
            'percentiles': percentiles,
            'total_weights': combined_weights.numel()
        }

    def suggest_optimal_threshold(self, routing_stats: Dict, current_performance: float) -> float:
        """Suggest optimal threshold based on routing distribution and performance.

        Implements the three-stage strategy in the paper:
        1) Theoretical grounding: τ_theo = max(α1·μ_w, μ_w − α2·σ_w)
        2) Performance-based adjustment: τ_PPL = τ_theo·(1±λ)
        3) Boundedness with exponential smoothing: τ_final = κ1·τ_PPL + κ2·τ_theo, and clamp to [τ_min, τ_max]

        Args:
            routing_stats: result of ``analyze_routing_distribution``. When it
                is empty, the configured ``expert_probability_threshold``
                (default 0.3) is returned unchanged.
            current_performance: current perplexity, compared with the last
                entry of ``performance_history``.

        Returns:
            The suggested threshold. This method does not modify any history.
        """
        if not routing_stats:
            return getattr(self.base_config, 'expert_probability_threshold', 0.3)

        mean_weight = routing_stats['mean_weight']
        std_weight = routing_stats['std_weight']
        percentiles = routing_stats.get('percentiles', None)

        # 1) Theoretical Grounding
        alpha1 = getattr(self.base_config, 'alpha1', 1.0)
        alpha2 = getattr(self.base_config, 'alpha2', 1.0)
        tau_theo = max(alpha1 * mean_weight, mean_weight - alpha2 * std_weight)

        # 2) Performance-based Adjustment with ΔPPL
        # ΔPPL = (PPL_current - PPL_prev) / PPL_prev
        tau_ppl = tau_theo
        if self.performance_history:
            prev_ppl = self.performance_history[-1]
            denom = prev_ppl if abs(prev_ppl) > 1e-8 else 1e-8
            delta_ppl = (current_performance - prev_ppl) / denom

            # λ− decreases threshold (activate more experts) when ΔPPL > 0 (worse)
            # λ+ increases threshold (more sparsity) when ΔPPL < 0 (better)
            lambda_decrease = getattr(self.base_config, 'lambda_decrease', 0.2)  # λ−
            lambda_increase = getattr(self.base_config, 'lambda_increase', 0.1)  # λ+

            if delta_ppl > 0:
                tau_ppl = tau_theo * (1.0 - lambda_decrease)
            elif delta_ppl < 0:
                tau_ppl = tau_theo * (1.0 + lambda_increase)

        # 3) Boundedness Enforcement and Exponential Smoothing
        # Bounds: use 10th/90th percentiles of routing weights if available
        if percentiles and len(percentiles) >= 5:
            min_thresh = percentiles[0]
            max_thresh = percentiles[4]
        else:
            min_thresh = getattr(self.base_config, 'min_threshold', 0.2)
            max_thresh = getattr(self.base_config, 'max_threshold', 0.8)

        kappa1 = getattr(self.base_config, 'kappa1', 0.7)
        kappa2 = getattr(self.base_config, 'kappa2', 0.3)
        tau_final = kappa1 * tau_ppl + kappa2 * tau_theo
        tau_final = max(min_thresh, min(max_thresh, tau_final))

        return tau_final

    def update_thresholds(self, model, routing_stats: Dict, current_performance: float):
        """Update thresholds for all layers.

        Writes the suggested threshold to ``current_threshold`` of every layer
        whose ``block_sparse_moe`` exposes ``adaptive_expert_selection`` and
        whose current value differs by more than 0.001, then appends the
        threshold and ``current_performance`` to the histories (capped at 100
        entries each).
        """
        optimal_threshold = self.suggest_optimal_threshold(routing_stats, current_performance)

        updated_layers = 0
        for i in range(self.num_layers):
            mlp = getattr(model.model.layers[i], 'block_sparse_moe', None)
            # hasattr(None, ...) is False, so layers without an MoE block are skipped.
            if not hasattr(mlp, 'adaptive_expert_selection'):
                continue
            current_thresh = getattr(mlp, 'current_threshold', 0.3)
            # Update only if the difference is meaningful.
            if abs(current_thresh - optimal_threshold) > 0.001:
                mlp.current_threshold = optimal_threshold
                updated_layers += 1

        if updated_layers > 0:
            print(f"[Threshold Optimization] Updated {updated_layers} layers to {optimal_threshold:.4f}")

        self.threshold_history.append(optimal_threshold)
        self.performance_history.append(current_performance)

        if len(self.threshold_history) > 100:
            self.threshold_history.pop(0)
            self.performance_history.pop(0)


def test_threshold_mechanism():
    """Print how many random routing weights exceed a range of thresholds.

    A random (2, 1, 8) logits tensor is turned into routing weights with
    softmax. For each threshold from 0.1 to 0.5 the function prints the count
    and percentage of weights above it, followed by the average number of
    experts per token that would be activated. Nothing is returned.
    """
    batch_size, seq_len, num_experts = 2, 1, 8
    logits = torch.randn(batch_size, seq_len, num_experts)
    probs = torch.nn.functional.softmax(logits, dim=-1)
    total_weights = probs.numel()

    for thresh in (0.1, 0.2, 0.3, 0.4, 0.5):
        # One boolean mask serves both the global count and the per-token count.
        passed = probs > thresh

        above_threshold = passed.sum().item()
        percentage = (above_threshold / total_weights) * 100
        print(f"  Threshold {thresh:>5}: {above_threshold:>3}/{total_weights} ({percentage:>5.1f}%) weights above threshold")

        avg_experts = passed.sum(dim=-1).float().mean().item()
        print(f"         Avg activated experts: {avg_experts:.2f}")
    

# Build the model config from the pretrained checkpoint and attach the routing
# hyperparameters used by this script. The inline notes record the tuning
# intent stated for each value.
base_config = MixtralConfig.from_pretrained(
    model_path,
    num_experts_per_tok=2,
    num_local_experts=8,
    use_innovative_routing=True,
    expert_probability_threshold=0.18,  # higher base threshold to increase expert utilization
    min_experts_per_token=1,            # minimum 1 expert
    max_experts_per_token=2,            # maximum 2 experts
    importance_weight=0.25,             # lower importance weight to reduce bias
    prune_weight_epsilon=5e-4,          # lower pruning threshold to keep more experts
    adaptation_rate=0.08,               # lower adaptation rate for stability
    min_threshold=0.12,                 # higher min threshold
    max_threshold=0.28,                 # higher max threshold
    sparsity_control_factor=0.75,       # stronger sparsity control
    expert_activation_threshold=0.005,  # lower activation threshold to activate more experts
    # Adaptive threshold hyperparameters from the paper
    # (read back by AdaptiveThresholdManager.suggest_optimal_threshold)
    lambda_decrease=0.2,               # λ−
    lambda_increase=0.1,               # λ+
    kappa1=0.7,                        # κ1
    kappa2=0.3,                        # κ2
    alpha1=1.0,                        # α1
    alpha2=1.0,                        # α2
    # Weights for history-based importance Sh and final IS
    w1=0.4,
    w2=0.4,
    w3=0.2,
    beta_history=0.25,                 # β (default 0.25)
)

# Folder handed to from_pretrained() for offloaded weights.
# NOTE(review): this is an absolute path at the filesystem root; creating it
# may require elevated permissions. Confirm it is not meant to be relative
# (e.g. "./.offload_cache").
offload_dir = "/.offload_cache"
os.makedirs(offload_dir, exist_ok=True)

# No per-GPU memory budget is computed here: max_memory is passed as None.
dynamic_max_memory = None

# Load the weights in bfloat16; placement is automatic when CUDA is available,
# otherwise the whole model is put on the CPU.
base_model = MixtralForCausalLM.from_pretrained(
    model_path,
    config=base_config,
    dtype=torch.bfloat16,
    low_cpu_mem_usage=True,
    offload_state_dict=True,
    offload_folder=offload_dir,
    device_map="auto" if torch.cuda.is_available() else {"": "cpu"},
    max_memory=dynamic_max_memory,
)
base_model.eval()
# Best effort: turn the cache flag off for this evaluation; any failure to set
# the attribute is ignored.
try:
    base_model.config.use_cache = False
except Exception:
    pass

tokenizer = AutoTokenizer.from_pretrained(model_path)
# Reuse the end-of-sequence token as the padding token.
tokenizer.pad_token = tokenizer.eos_token
num_layers = base_model.config.num_hidden_layers

print(f"Model loaded, num layers: {num_layers}")
print(f"Innovative expert routing enabled: {base_model.config.use_innovative_routing}")
print(f"Initial threshold: {getattr(base_model.config, 'expert_probability_threshold', 0.3)}")
print(f"Importance weight: {getattr(base_model.config, 'importance_weight', 0.35)}")

# Every layer is counted as an MoE layer; nothing is verified here.
# NOTE(review): moe_layers is not read anywhere in the visible code.
moe_layers = num_layers

# Print a quick demonstration of thresholding on random routing weights.
test_threshold_mechanism()

# If some MoE layers lack threshold state, inject it and print first few layers.
# Only blocks exposing `adaptive_expert_selection` are touched. Any error in the
# whole section is reported and then ignored so that evaluation can continue.
try:
    injected = 0
    for i in range(num_layers):
        mlp = getattr(base_model.model.layers[i], 'block_sparse_moe', None)
        if hasattr(mlp, 'adaptive_expert_selection'):
            # Sync config into module to ensure min/max experts take effect
            try:
                mlp.config.min_experts_per_token = getattr(base_model.config, 'min_experts_per_token', 1)
                mlp.config.max_experts_per_token = getattr(base_model.config, 'max_experts_per_token', 2)
                # Set to module attributes directly
                mlp.min_experts_per_token = getattr(base_model.config, 'min_experts_per_token', 1)
                mlp.max_experts_per_token = getattr(base_model.config, 'max_experts_per_token', 2)
                # Sync threshold settings
                mlp.current_threshold = getattr(base_model.config, 'expert_probability_threshold', 0.3)
                mlp.min_threshold = getattr(base_model.config, 'min_threshold', 0.2)
                mlp.max_threshold = getattr(base_model.config, 'max_threshold', 0.8)
            except Exception:
                # Best effort: a module without `config` (or any other failure)
                # leaves the remaining attributes of this step unset.
                pass

            # Inject the full threshold/statistics state when the module has none.
            # NOTE(review): the sync step above already assigns current_threshold,
            # so this branch only runs when that step failed before reaching the
            # assignment. Modules that lack the statistics tensors are otherwise
            # never given them — confirm this is intended.
            if getattr(mlp, 'current_threshold', None) is None:
                mlp.current_threshold = getattr(base_model.config, 'expert_probability_threshold', 0.3)
                mlp.min_threshold = getattr(base_model.config, 'min_threshold', 0.2)
                mlp.max_threshold = getattr(base_model.config, 'max_threshold', 0.8)
                mlp.adaptation_rate = getattr(base_model.config, 'adaptation_rate', 0.12)
                mlp.min_experts_per_token = getattr(base_model.config, 'min_experts_per_token', 1)
                mlp.max_experts_per_token = getattr(base_model.config, 'max_experts_per_token', 2)
                mlp.importance_weight = getattr(base_model.config, 'importance_weight', 0.35)
                num_local = getattr(base_model.config, 'num_local_experts', 8)
                # History-based stats per expert
                mlp.expert_importance = torch.ones(num_local, dtype=torch.float32)  # Sh
                mlp.expert_usage_count = torch.zeros(num_local, dtype=torch.long)   # N
                mlp.expert_total_probability = torch.zeros(num_local, dtype=torch.float32)  # P_total
                mlp.expert_contribution_score = torch.zeros(num_local, dtype=torch.float32)  # Cs
                injected += 1

            # Ensure max_experts_per_token is set properly.
            # NOTE(review): the value 2 is hard-coded here and overrides any
            # other configured maximum.
            if not hasattr(mlp, 'max_experts_per_token') or mlp.max_experts_per_token != 2:
                mlp.max_experts_per_token = 2
                print(f"Layer {i}: force set max_experts_per_token=2")

            # Move stats tensors to the same device as gating to avoid device mismatch.
            # expert_importance and expert_usage_count are read without a hasattr
            # guard; if either is missing, the AttributeError is swallowed below
            # and the remaining tensors are not moved.
            try:
                device_mlp = mlp.gate.weight.device if hasattr(mlp, 'gate') else device
                if mlp.expert_importance is not None:
                    mlp.expert_importance = mlp.expert_importance.to(device_mlp)
                if mlp.expert_usage_count is not None:
                    mlp.expert_usage_count = mlp.expert_usage_count.to(device_mlp)
                if hasattr(mlp, 'expert_total_probability') and mlp.expert_total_probability is not None:
                    mlp.expert_total_probability = mlp.expert_total_probability.to(device_mlp)
                if hasattr(mlp, 'expert_contribution_score') and mlp.expert_contribution_score is not None:
                    mlp.expert_contribution_score = mlp.expert_contribution_score.to(device_mlp)
            except Exception:
                pass

    if injected:
        print(f"Injected threshold state for {injected} MoE layers")
    # Show the resulting threshold and expert limits for up to the first five layers.
    for i in range(min(5, num_layers)):
        mlp = getattr(base_model.model.layers[i], 'block_sparse_moe', None)
        if hasattr(mlp, 'adaptive_expert_selection'):
            print(
                f"Layer {i} initial threshold: {getattr(mlp, 'current_threshold', None)} | "
                f"min_k={getattr(mlp, 'min_experts_per_token', getattr(base_model.config, 'min_experts_per_token', 'NA'))}, "
                f"max_k={getattr(mlp, 'max_experts_per_token', getattr(base_model.config, 'max_experts_per_token', 'NA'))}"
            )
except Exception as _e:
    print(f"Failed to inject threshold state: {_e}")

# Load the evaluation corpus: one sample per non-empty line, whitespace-stripped.
# NOTE(review): the path names a dot-prefixed file in the current working
# directory — confirm it is not meant to be './wiki_demo.txt'.
wiki_data_path = '.wiki_demo.txt'
with open(wiki_data_path, 'r', encoding='utf-8') as f:
    wiki_data = [line.strip() for line in f if line.strip()]
print(f"Loaded {len(wiki_data)} text samples")


def _select_expert_mask(scores, threshold, max_experts):
    """Return a boolean mask of activated experts along the last dimension.

    Three-case rule, applied independently to every token:
      * no score reaches ``threshold``   -> only the highest-scoring expert;
      * 1..``max_experts`` scores reach it -> exactly those experts;
      * more than ``max_experts`` reach it -> the ``max_experts`` best of them.
    """
    candidate_mask = scores >= threshold
    k = min(int(max_experts), scores.shape[-1])

    # Rank only candidates: everything below the threshold is pushed to -inf,
    # and any -inf entry that topk still returns is removed by the final AND.
    ranked = scores.masked_fill(~candidate_mask, float('-inf'))
    top_idx = torch.topk(ranked, k, dim=-1).indices
    top_mask = torch.zeros_like(candidate_mask).scatter_(-1, top_idx, True)
    capped_candidates = top_mask & candidate_mask

    # Fallback for tokens without any candidate: {argmax IS_i}.
    best_idx = torch.topk(scores, 1, dim=-1).indices
    best_mask = torch.zeros_like(candidate_mask).scatter_(-1, best_idx, True)

    has_candidate = candidate_mask.any(dim=-1, keepdim=True)
    return torch.where(has_candidate, capped_candidates, best_mask)


def calculate_active_experts_innovative(router_logits, layer_idx):
    """Compute number of active experts (Mixtral uses MoE in all layers).

    Re-applies the dual-constraint activation rule (score threshold plus a
    cap of ``max_experts_per_token``) to the router logits of one layer and
    returns the average number of activated experts per token.

    Args:
        router_logits: router logits of the layer, shaped (tokens, experts) or
            (batch, seq, experts). ``None`` or a non-tensor yields the
            configured ``num_experts_per_tok`` as a float.
        layer_idx: index into ``base_model.model.layers``.

    Returns:
        Mean count of experts whose renormalized weight exceeds the block's
        ``prune_weight_epsilon`` (default 1e-8); 0.0 when there are no tokens.

    Side effects:
        When the MoE block carries ``expert_usage_count`` (and
        ``expert_total_probability``), the per-expert selection counts and the
        summed renormalized weights of this call are added to them in place.
    """
    moe_block = base_model.model.layers[layer_idx].block_sparse_moe

    # If no routing info, return configured top-k
    if router_logits is None or not isinstance(router_logits, torch.Tensor):
        return float(getattr(base_model.config, 'num_experts_per_tok', 2))

    routing_weights = torch.nn.functional.softmax(router_logits, dim=-1)
    if routing_weights.dim() == 2:
        # Treat (tokens, experts) as (tokens, 1, experts).
        routing_weights = routing_weights.unsqueeze(1)

    max_experts = getattr(moe_block, 'max_experts_per_token', 2)
    current_threshold = getattr(moe_block, 'current_threshold', 0.3)

    # Build history-based expert importance Sh per paper
    # Sh = w1 * N/ΣN + w2 * P_total/N + w3 * Cs
    # NOTE(review): the statistics are cast to the routing dtype; if that is
    # bfloat16, large usage counts are rounded — confirm this is acceptable.
    num_experts_total = routing_weights.shape[-1]
    has_history = (
        hasattr(moe_block, 'expert_usage_count')
        and hasattr(moe_block, 'expert_total_probability')
        and hasattr(moe_block, 'expert_contribution_score')
    )
    if has_history:
        w1 = getattr(base_model.config, 'w1', 0.4)
        w2 = getattr(base_model.config, 'w2', 0.4)
        w3 = getattr(base_model.config, 'w3', 0.2)

        N_vec = moe_block.expert_usage_count.to(routing_weights.device, dtype=routing_weights.dtype)
        Ptotal_vec = moe_block.expert_total_probability.to(routing_weights.device, dtype=routing_weights.dtype)
        Cs_vec = moe_block.expert_contribution_score.to(routing_weights.device, dtype=routing_weights.dtype)

        sumN = N_vec.sum().clamp_min(EPSILON)
        freq_term = N_vec / sumN
        prob_term = Ptotal_vec / (N_vec + EPSILON)
        Sh = w1 * freq_term + w2 * prob_term + w3 * Cs_vec
        # Normalize Sh to [0,1] for stability
        Sh = (Sh - Sh.min()) / (Sh.max() - Sh.min() + EPSILON)
    else:
        Sh = torch.ones(num_experts_total, device=routing_weights.device, dtype=routing_weights.dtype)

    # Final importance score IS_i = (1-β) * g_i + β * Sh_i
    beta_hist = getattr(base_model.config, 'beta_history', 0.25)
    combined_scores = (1.0 - beta_hist) * routing_weights + beta_hist * Sh.view(1, 1, -1)

    # Selection mask for all tokens at once (no per-token Python loop).
    threshold_mask = _select_expert_mask(combined_scores, current_threshold, max_experts)

    # Apply mask and normalize only over activated set A
    selected_weights = routing_weights * threshold_mask.float()
    weight_sums = selected_weights.sum(dim=-1, keepdim=True)
    selected_weights = selected_weights / (weight_sums + EPSILON)

    # Update history statistics (N, P_total) for next steps.
    # Cs is updated inside the model forward using true expert outputs.
    if hasattr(moe_block, 'expert_usage_count'):
        with torch.no_grad():
            usage = moe_block.expert_usage_count
            counts = threshold_mask.sum(dim=(0, 1)).to(device=usage.device, dtype=usage.dtype)
            usage[:counts.numel()] += counts

            total_prob = getattr(moe_block, 'expert_total_probability', None)
            if total_prob is not None:
                probs_sum = selected_weights.sum(dim=(0, 1)).to(
                    device=total_prob.device, dtype=total_prob.dtype
                )
                total_prob[:probs_sum.numel()] += probs_sum

    # Compute number of active experts (based on non-zero weights)
    eps = getattr(moe_block, 'prune_weight_epsilon', 1e-8)
    active_count_per_token = (selected_weights > eps).sum(dim=-1).float()
    if active_count_per_token.numel() == 0:
        return 0.0

    active_experts = active_count_per_token.mean().item()
    if np.isnan(active_experts) or active_experts < 0:
        return 0.0
    return active_experts

# Manager that periodically re-derives the routing threshold during evaluation.
threshold_manager = AdaptiveThresholdManager(num_layers, base_config)

print("Computing PPL...")
# Per-sample perplexities (only samples with a finite loss are recorded).
ppl_ls = []
# For every layer, the per-sample average number of active experts.
layer_expert_activation_ls = [[] for _ in range(num_layers)]
progress_bar = tqdm(wiki_data, desc="Compute PPL", postfix="PPL: 0.0000")  # Aligned with deepseekscore: iterate all samples

# Main evaluation loop: one forward pass per text sample. Any failure skips the
# sample instead of aborting the run.
for text in progress_bar:
    inputs = tokenizer.encode(text, return_tensors='pt')
    
    # Nothing to score for an empty token sequence.
    if inputs.size(1) == 0:
        continue
    
    # Align with deepseekscore's truncation: follow model limit and cap at 512
    max_length = getattr(base_model.config, 'n_positions', None)
    if max_length is None:
        max_length = getattr(base_model.config, 'max_position_embeddings', 1024)
    max_length = min(max_length, 512)
    inputs = inputs[:, :max_length].to(device)
    
    with torch.no_grad():
        try:
            # The inputs double as labels, so the returned loss is the language
            # modelling loss of the sample; router logits are requested as well.
            outputs = base_model(
                inputs, 
                labels=inputs, 
                output_router_logits=True
            )
            loss = outputs.loss
            
            if torch.isfinite(loss):
                # Perplexity of this sample = exp(loss).
                ppl = torch.exp(loss).item()
                ppl_ls.append(ppl)
                
                layer_avg_experts = []
                for layer_idx, router_logits in enumerate(outputs.router_logits):
                    try:
                        if router_logits is None:
                            # If no routing info, fallback to configured top-k
                            fallback = float(getattr(base_model.config, 'num_experts_per_tok', 2))
                            layer_avg_experts.append(fallback)
                            layer_expert_activation_ls[layer_idx].append(fallback)
                            continue
                        avg_experts = calculate_active_experts_innovative(
                            router_logits, 
                            layer_idx
                        )
                        layer_avg_experts.append(avg_experts)
                        layer_expert_activation_ls[layer_idx].append(avg_experts)
                    except Exception as _e_layer:
                        # Avoid dropping whole sample due to one-layer failure.
                        # The configured top-k is recorded instead, and the error
                        # itself is not reported.
                        fallback = float(getattr(base_model.config, 'num_experts_per_tok', 2))
                        layer_avg_experts.append(fallback)
                        try:
                            layer_expert_activation_ls[layer_idx].append(fallback)
                        except Exception:
                            pass
                
                global_avg_experts = np.mean(layer_avg_experts)
                # The second label in the postfix is Chinese for "active experts".
                progress_bar.set_postfix_str(f"PPL: {ppl:.4f}, 激活专家: {global_avg_experts:.2f}")
                
                # Reduce overhead: analyze/print thresholds less frequently.
                # Optimize thresholds after every 50 recorded samples, using the
                # router logits and the perplexity of the current sample only.
                if len(ppl_ls) % 50 == 0:
                    routing_stats = threshold_manager.analyze_routing_distribution(outputs.router_logits)
                    threshold_manager.update_thresholds(base_model, routing_stats, ppl)
                    
                    # Reduce log density: print detailed distribution every 100 samples when available
                    if routing_stats and (len(ppl_ls) % 100 == 0):
                        print(f"\n[Threshold Analysis] Sample {len(ppl_ls)}: mean={routing_stats['mean_weight']:.6f}, std={routing_stats['std_weight']:.6f}")
                        print(f"[Threshold Analysis] Percentiles: P10={routing_stats['percentiles'][0]:.6f}, P25={routing_stats['percentiles'][1]:.6f}, "
                              f"P50={routing_stats['percentiles'][2]:.6f}, P75={routing_stats['percentiles'][3]:.6f}, P90={routing_stats['percentiles'][4]:.6f}")
            else:
                # A NaN/inf loss is excluded from the statistics.
                print(f"Warning: invalid loss, skip: {text[:50]}...")
                continue
        except RuntimeError as e:
            # Handle OOM: clear cache and skip
            if 'out of memory' in str(e):
                print(f"CUDA OOM, skip: {text[:50]}... clearing cache.")
                try:
                    torch.cuda.empty_cache()
                except Exception:
                    pass
            else:
                print(f"Computation error, skip: {text[:50]}... error: {e}")
            continue
        except Exception as e:
            # Any other failure: report it, free cached GPU memory, move on.
            print(f"Computation error, skip: {text[:50]}... error: {e}")
            try:
                torch.cuda.empty_cache()
            except Exception:
                pass
            continue

# Final report: mean perplexity, per-layer and overall expert activation, and
# the direction in which the adaptive threshold moved over the run.
if not ppl_ls:
    print("No valid PPL results")
else:
    mean_ppl = np.mean(ppl_ls)
    print(f"Mean PPL: {mean_ppl:.4f}")

    # Average per layer; layers without samples (or with a NaN mean) count as 0.
    layer_mean_experts = []
    for layer_data in layer_expert_activation_ls:
        mean_val = np.mean(layer_data) if layer_data else 0.0
        layer_mean_experts.append(0.0 if np.isnan(mean_val) else mean_val)

    overall_avg_experts = np.mean(layer_mean_experts)
    print(f"\nOverall average active experts across all layers: {overall_avg_experts:.2f} per token")

    print("\nPer-layer average active experts:")
    for i, avg_experts in enumerate(layer_mean_experts):
        print(f"Layer {i}: {avg_experts:.2f} per token")

    print(f"- Overall average active experts: {overall_avg_experts:.2f} per token")
    print(f"- Mean PPL: {mean_ppl:.4f}")

    # A trend needs at least two recorded thresholds: compare last with first.
    _threshold_history = getattr(threshold_manager, 'threshold_history', None)
    if _threshold_history and len(_threshold_history) > 1:
        threshold_trend = _threshold_history[-1] - _threshold_history[0]
        if threshold_trend > 0.001:
            print(f"- Threshold trend: Up (+{threshold_trend:.4f})")
        elif threshold_trend < -0.001:
            print(f"- Threshold trend: Down ({threshold_trend:.4f})")
        else:
            print(f"- Threshold trend: Stable (±{abs(threshold_trend):.4f})")
    
