import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from datasets import load_dataset
import numpy as np
from collections import defaultdict
import json
from typing import Dict, List, Optional, Tuple
from tqdm import tqdm
import sys
import os
import argparse

# Module-level sink for routing records: the patched MoE gate forwards append
# one dict per gate call, collect_deepseek_expert_statistics resets it before a
# run, and save_statistics reads it afterwards.
EXPERT_ROUTING_DATA: List[Dict] = []

def _make_recording_gate_forward(gate, layer_idx):
    """Build a replacement for ``gate.forward`` that also records routing data.

    The returned closure re-implements the DeepSeek MoEGate forward pass
    (softmax scoring, unsorted top-k selection, optional renormalisation of the
    selected weights, and the training-only auxiliary loss). It returns the same
    ``(topk_idx, topk_weight, aux_loss)`` triple. Before computing the
    auxiliary loss it appends one record to ``EXPERT_ROUTING_DATA``.

    Args:
        gate: the MoE gate module whose attributes (``weight``, ``top_k``,
            ``scoring_func``, ``norm_topk_prob``, ``n_routed_experts``,
            ``alpha``, ``seq_aux``, ``training``) drive the computation.
        layer_idx: decoder-layer index stored in every record.
    """
    def recording_forward(hidden_states):
        bsz, seq_len, hidden_dim = hidden_states.shape

        # Router logits for every token position, batch and sequence flattened.
        flat_states = hidden_states.view(-1, hidden_dim)
        logits = torch.nn.functional.linear(flat_states, gate.weight, None)

        if gate.scoring_func != 'softmax':
            raise NotImplementedError(f'unsupported scoring function for MoE gating: {gate.scoring_func}')
        scores = logits.softmax(dim=-1)

        topk_weight, topk_idx = torch.topk(scores, k=gate.top_k, dim=-1, sorted=False)

        # Rescale the selected weights so each token's top-k weights sum to 1.
        if gate.top_k > 1 and gate.norm_topk_prob:
            topk_weight = topk_weight / (topk_weight.sum(dim=-1, keepdim=True) + 1e-20)

        # CPU copies so recording does not pin accelerator memory.
        EXPERT_ROUTING_DATA.append({
            'layer_idx': layer_idx,
            'router_logits': logits.detach().cpu(),      # [batch*seq_len, num_experts]
            'router_probs': scores.detach().cpu(),       # [batch*seq_len, num_experts]
            'topk_indices': topk_idx.detach().cpu(),     # [batch*seq_len, top_k]
            'topk_weights': topk_weight.detach().cpu(),  # [batch*seq_len, top_k]
            'batch_size': bsz,
            'seq_len': seq_len,
            'num_experts': gate.n_routed_experts,
            'top_k': gate.top_k,
        })

        # Load-balancing auxiliary loss, only computed while training.
        aux_loss = None
        if gate.training and gate.alpha > 0.0:
            n_experts = gate.n_routed_experts
            per_seq_topk = topk_idx.view(bsz, -1)
            if gate.seq_aux:
                per_seq_scores = scores.view(bsz, seq_len, -1)
                ce = torch.zeros(bsz, n_experts, device=hidden_states.device)
                ce.scatter_add_(1, per_seq_topk, torch.ones(bsz, seq_len * gate.top_k, device=hidden_states.device))
                ce.div_(seq_len * gate.top_k / n_experts)
                aux_loss = (ce * per_seq_scores.mean(dim=1)).sum(dim=1).mean() * gate.alpha
            else:
                selection_mask = torch.nn.functional.one_hot(per_seq_topk.view(-1), num_classes=n_experts)
                fi = selection_mask.float().mean(0) * n_experts
                pi = scores.mean(0)
                aux_loss = (pi * fi).sum() * gate.alpha

        return topk_idx, topk_weight, aux_loss

    return recording_forward


def patch_moe_layers(model):
    """Instrument every MoE gate of ``model`` so each call records its routing.

    Walks ``model.model.layers``. For every layer whose ``mlp`` has both a
    ``gate`` and an ``experts`` attribute, it saves ``mlp.gate.forward`` and
    replaces it with a recording re-implementation of the gate forward.

    Args:
        model: loaded causal LM whose decoder layers live at ``model.model.layers``.

    Returns:
        ``(moe_layers, original_forwards)``: a list of ``(layer_idx, mlp)``
        pairs and, in the same order, the gate ``forward`` callables that were
        replaced. Pass both to ``restore_moe_layers`` to undo the patch.
    """
    found_layers = []
    saved_forwards = []

    print("Scanning model for MoE layers...")

    for idx, decoder_layer in enumerate(model.model.layers):
        mlp = decoder_layer.mlp
        # DeepseekMoE blocks expose a router ("gate") plus an expert list.
        if not (hasattr(mlp, 'gate') and hasattr(mlp, 'experts')):
            continue

        print(f"Found MoE layer at index {idx}")
        found_layers.append((idx, mlp))
        saved_forwards.append(mlp.gate.forward)
        mlp.gate.forward = _make_recording_gate_forward(mlp.gate, idx)

    print(f"Patched {len(found_layers)} MoE layers")
    return found_layers, saved_forwards

def restore_moe_layers(moe_layers, original_forwards):
    """Undo ``patch_moe_layers`` by reinstating each gate's saved ``forward``.

    Args:
        moe_layers: ``(layer_idx, mlp)`` pairs as returned by ``patch_moe_layers``.
        original_forwards: saved gate ``forward`` callables, aligned with
            ``moe_layers``; extra entries on either side are ignored (zip).
    """
    print("Restoring original MoE implementations...")
    for entry, saved_forward in zip(moe_layers, original_forwards):
        _, mlp = entry
        mlp.gate.forward = saved_forward

def _batch_expert_counts(topk_indices, topk_weights, num_experts: int):
    """Tally one routing record's top-k selections per expert (vectorized).

    Args:
        topk_indices: integer tensor [tokens, top_k] of selected expert ids.
        topk_weights: tensor [tokens, top_k] of the matching gate weights.
        num_experts: number of routed experts (length of the outputs).

    Returns:
        Three NumPy arrays of length ``num_experts``:
        - top-k slots assigned to each expert (int64);
        - summed gate weight of those slots (float64);
        - number of tokens whose top-k contains the expert at least once (int64).
    """
    indices = topk_indices.detach().cpu().long()
    flat_indices = indices.reshape(-1)
    flat_weights = topk_weights.detach().cpu().reshape(-1).double()

    usage = torch.bincount(flat_indices, minlength=num_experts)
    weight = torch.bincount(flat_indices, weights=flat_weights, minlength=num_experts)

    # Mark each (token, expert) pair once so an expert is counted at most once
    # per token for the "activation" statistic.
    token_hits = torch.zeros(indices.shape[0], num_experts, dtype=torch.int64)
    token_hits.scatter_(1, indices, 1)
    activation = token_hits.sum(dim=0)

    return usage.numpy(), weight.numpy(), activation.numpy()


def _token_entropies(router_probs) -> np.ndarray:
    """Return the per-token Shannon entropy (nats) of router distributions.

    Args:
        router_probs: tensor [tokens, num_experts] of routing probabilities
            (any float dtype; computed in float32).

    Returns:
        float32 NumPy array of length ``tokens``. 1e-12 is added to every
        probability before the log so zero probabilities do not produce NaN.
    """
    probs = router_probs.detach().cpu().float() + 1e-12
    return (-(probs * probs.log()).sum(dim=-1)).numpy()


def analyze_expert_routing_data(routing_data_list: List[Dict]) -> Dict:
    """Aggregate recorded routing data into per-expert and per-layer statistics.

    Args:
        routing_data_list: routing records as appended by the patched gates.
            Each record needs 'layer_idx', 'router_probs' [tokens, num_experts],
            'topk_indices' and 'topk_weights' [tokens, top_k]. The first
            record's 'num_experts' and 'top_k' are used for every layer.

    Returns:
        ``{}`` for empty input. Otherwise a dict containing:
        - global token and expert counters;
        - usage and activation percentages;
        - average gate weight per expert;
        - a routing-entropy summary (mean, std, first 1000 per-token scores);
        - balance metrics;
        - 'layer_statistics' keyed by layer index.
        Percentages are 0.0 when no tokens were recorded.
    """
    if not routing_data_list:
        return {}

    # Configuration is taken from the first record and assumed shared by all layers.
    first_batch = routing_data_list[0]
    num_experts = first_batch['num_experts']
    top_k = first_batch['top_k']

    # Group records by layer, preserving first-seen layer order.
    layer_data = defaultdict(list)
    for batch_data in routing_data_list:
        layer_data[batch_data['layer_idx']].append(batch_data)

    print(f"Analyzing routing data from {len(layer_data)} layers...")

    expert_usage_count = np.zeros(num_experts, dtype=np.int64)
    expert_total_weight = np.zeros(num_experts, dtype=np.float64)
    expert_activation_count = np.zeros(num_experts, dtype=np.int64)

    total_tokens = 0
    entropy_chunks = []
    layer_statistics = {}

    for layer_idx, layer_batches in layer_data.items():
        print(f"Processing layer {layer_idx} with {len(layer_batches)} batches...")

        layer_expert_usage = np.zeros(num_experts, dtype=np.int64)
        layer_expert_weight = np.zeros(num_experts, dtype=np.float64)
        layer_tokens = 0
        layer_entropy_chunks = []

        for batch_data in layer_batches:
            router_probs = batch_data['router_probs']  # [batch*seq_len, num_experts]
            usage, weight, activation = _batch_expert_counts(
                batch_data['topk_indices'], batch_data['topk_weights'], num_experts
            )

            num_tokens = router_probs.shape[0]
            total_tokens += num_tokens
            layer_tokens += num_tokens

            expert_usage_count += usage
            expert_total_weight += weight
            expert_activation_count += activation
            layer_expert_usage += usage
            layer_expert_weight += weight
            layer_entropy_chunks.append(_token_entropies(router_probs))

        layer_entropies = np.concatenate(layer_entropy_chunks)
        entropy_chunks.append(layer_entropies)

        layer_statistics[layer_idx] = {
            'expert_usage_count': layer_expert_usage.tolist(),
            'expert_total_weight': layer_expert_weight.tolist(),
            'tokens_processed': layer_tokens,
            'mean_entropy': float(np.mean(layer_entropies, dtype=np.float64)) if layer_entropies.size else 0.0,
            'most_used_expert': int(np.argmax(layer_expert_usage)),
            'least_used_expert': int(np.argmin(layer_expert_usage)),
        }

    # Guard the divisions so an all-empty run yields zeros instead of NaN.
    total_selections = total_tokens * top_k
    if total_selections > 0:
        expert_usage_percentage = (expert_usage_count / total_selections) * 100
    else:
        expert_usage_percentage = np.zeros(num_experts, dtype=np.float64)
    if total_tokens > 0:
        expert_activation_percentage = (expert_activation_count / total_tokens) * 100
    else:
        expert_activation_percentage = np.zeros(num_experts, dtype=np.float64)
    expert_avg_weight = expert_total_weight / np.maximum(expert_usage_count, 1)

    all_entropies = np.concatenate(entropy_chunks)
    if all_entropies.size:
        mean_entropy = float(np.mean(all_entropies, dtype=np.float64))
        std_entropy = float(np.std(all_entropies, dtype=np.float64))
    else:
        mean_entropy = 0.0
        std_entropy = 0.0

    most_used_expert = int(np.argmax(expert_usage_count))
    least_used_expert = int(np.argmin(expert_usage_count))
    mean_usage = np.mean(expert_usage_count)

    return {
        'total_tokens_processed': int(total_tokens),
        'num_experts': int(num_experts),
        'top_k': int(top_k),
        'num_moe_layers': len(layer_data),
        'expert_usage_count': expert_usage_count.tolist(),
        'expert_usage_percentage': expert_usage_percentage.tolist(),
        'expert_average_weight': expert_avg_weight.tolist(),
        'expert_activation_count': expert_activation_count.tolist(),
        'expert_activation_percentage': expert_activation_percentage.tolist(),
        'routing_entropy': {
            'mean': mean_entropy,
            'std': std_entropy,
            'all_scores': all_entropies[:1000].tolist()  # Store first 1000 for analysis
        },
        'expert_balance': {
            'most_used_expert': most_used_expert,
            'most_used_count': int(expert_usage_count[most_used_expert]),
            'least_used_expert': least_used_expert,
            'least_used_count': int(expert_usage_count[least_used_expert]),
            'usage_coefficient_of_variation': float(np.std(expert_usage_count) / mean_usage) if mean_usage > 0 else 0.0
        },
        'layer_statistics': layer_statistics
    }

def collect_deepseek_expert_statistics(
    model_id: str = "deepseek-ai/deepseek-moe-16b-chat",
    dataset_name: str = "OpenAssistant/oasst1", 
    num_samples: int = 100,
    max_length: int = 512,
    batch_size: int = 2
):
    """Run a DeepSeek MoE model over a dataset slice and summarise its routing.

    Steps:
    - Load ``model_id`` (bfloat16, ``device_map="auto"``).
    - Load the first ``num_samples`` rows of the ``train`` split of
      ``dataset_name``.
    - Tokenize the ``text`` column, or ``prompt`` when there is no ``text``
      column.
    - Patch every MoE gate so it records routing decisions.
    - Run no-grad forward passes and analyse the recorded data.

    The gate patches are removed again even if a forward pass raises.

    Args:
        model_id: Hugging Face model id (loaded with ``trust_remote_code=True``).
        dataset_name: Hugging Face dataset id; its ``train`` split is used.
        num_samples: number of leading rows of the split to process.
        max_length: truncation length in tokens per example.
        batch_size: examples per forward pass; batches are padded with
            ``tokenizer.pad``.

    Returns:
        The dict from ``analyze_expert_routing_data``, or ``None`` when the
        model has no MoE layers or no routing data was recorded. The raw
        records remain in the module-level ``EXPERT_ROUTING_DATA`` list.

    Raises:
        ValueError: if the dataset has neither a ``text`` nor a ``prompt`` column.
    """
    global EXPERT_ROUTING_DATA
    from torch.utils.data import DataLoader

    print(f"Loading model: {model_id}")

    tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        torch_dtype=torch.bfloat16,
        device_map="auto",
        trust_remote_code=True
    )

    # Padding is required to batch variable-length examples.
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    print(f"Loading dataset: {dataset_name}")

    try:
        dataset = load_dataset(dataset_name, split=f"train[:{num_samples}]")
    except Exception as err:
        # Some datasets reject split slicing; load the full split and take the
        # same leading rows, clamped to the split's actual size.
        print(f"Split slicing failed ({err}); loading the full 'train' split instead")
        dataset = load_dataset(dataset_name, split="train")
        dataset = dataset.select(range(min(num_samples, len(dataset))))

    # Resolve the text column once rather than per batch, failing loudly if absent.
    if "text" in dataset.column_names:
        text_column = "text"
    elif "prompt" in dataset.column_names:
        text_column = "prompt"
    else:
        raise ValueError(
            f"Dataset {dataset_name!r} has neither a 'text' nor a 'prompt' column "
            f"(columns: {dataset.column_names})"
        )

    def tokenize_function(examples):
        return tokenizer(
            examples[text_column],
            truncation=True,
            max_length=max_length,
            padding=False
        )

    dataset = dataset.map(tokenize_function, batched=True, remove_columns=dataset.column_names)

    def collate_fn(batch):
        return tokenizer.pad(batch, return_tensors="pt")

    # NOTE(review): with batch_size > 1, tokenizer.pad adds pad positions and
    # nothing masks them before the gates, so pad tokens are presumably routed
    # and counted in the statistics too - confirm whether that is acceptable.
    dataloader = DataLoader(dataset, batch_size=batch_size, collate_fn=collate_fn, shuffle=False)

    print("Patching DeepSeek MoE layers...")

    moe_layers, original_forwards = patch_moe_layers(model)

    if not moe_layers:
        print("No MoE layers found in the model!")
        return None

    try:
        print("Collecting expert routing statistics...")

        EXPERT_ROUTING_DATA = []

        model.eval()
        with torch.no_grad():
            for batch_idx, batch in enumerate(tqdm(dataloader, desc="Processing batches")):
                batch = {k: v.to(model.device) for k, v in batch.items()}

                # The output is unused: the patched gates record routing as a side effect.
                model(**batch)

                if batch_idx % 10 == 0:
                    print(f"Processed {batch_idx + 1} batches, collected {len(EXPERT_ROUTING_DATA)} routing samples")

        print(f"Collected routing data from {len(EXPERT_ROUTING_DATA)} samples")

        if not EXPERT_ROUTING_DATA:
            print("Warning: No routing data was collected.")
            return None

        print("Analyzing expert routing patterns...")
        return analyze_expert_routing_data(EXPERT_ROUTING_DATA)

    finally:
        # Always restore original implementations, even on failure.
        restore_moe_layers(moe_layers, original_forwards)

def print_expert_statistics(stats: Dict):
    """Pretty-print a statistics dict from ``analyze_expert_routing_data``.

    Prints, in order:
    - global counters;
    - routing-entropy mean and std;
    - balance metrics;
    - the ten most and ten least used experts;
    - one summary line per MoE layer.

    When ``stats`` is empty or None it prints a single notice instead.
    """
    if not stats:
        print("No statistics to display")
        return

    rule = "=" * 60
    print("\n" + rule)
    print("DEEPSEEK MoE EXPERT USAGE STATISTICS")
    print(rule)

    print(f"Total tokens processed: {stats['total_tokens_processed']:,}")
    print(f"Number of experts: {stats['num_experts']}")
    print(f"Number of MoE layers: {stats['num_moe_layers']}")
    print(f"Top-K routing: {stats['top_k']}")

    entropy = stats['routing_entropy']
    print("\nRouting Entropy:")
    print(f"  Mean: {entropy['mean']:.4f}")
    print(f"  Std:  {entropy['std']:.4f}")

    balance = stats['expert_balance']
    print("\nExpert Balance:")
    print(f"  Most used expert:  #{balance['most_used_expert']} ({balance['most_used_count']:,} times)")
    print(f"  Least used expert: #{balance['least_used_expert']} ({balance['least_used_count']:,} times)")
    print(f"  Usage coefficient of variation: {balance['usage_coefficient_of_variation']:.4f}")

    print("\nTop 10 Most Used Experts:")
    counts = np.array(stats['expert_usage_count'])
    percents = np.array(stats['expert_usage_percentage'])
    weights = np.array(stats['expert_average_weight'])

    def show_expert(rank, expert):
        print(f"  {rank:2d}. Expert #{expert:2d}: {counts[expert]:6,} uses ({percents[expert]:5.2f}%), avg weight: {weights[expert]:.4f}")

    ascending = np.argsort(counts)
    for rank, expert in enumerate(ascending[::-1][:10], start=1):
        show_expert(rank, expert)

    print("\nBottom 10 Least Used Experts:")
    for rank, expert in enumerate(ascending[:10], start=1):
        show_expert(rank, expert)

    print("\nPer-Layer Statistics:")
    for layer_idx, per_layer in stats['layer_statistics'].items():
        top_expert = per_layer['most_used_expert']
        top_count = np.array(per_layer['expert_usage_count'])[top_expert]
        print(f"  Layer {layer_idx}: {per_layer['tokens_processed']:,} tokens, "
              f"entropy={per_layer['mean_entropy']:.3f}, "
              f"most used expert #{top_expert} ({top_count:,} times)")

def convert_numpy_types(obj):
    """Recursively convert NumPy values into native Python types for JSON.

    Conversions:
    - dicts: both keys and values are converted;
    - lists and tuples: recursed into, and tuples stay tuples;
    - NumPy integer, floating and boolean scalars become ``int``, ``float``
      and ``bool``;
    - ``np.ndarray`` becomes a nested list via ``tolist()``.

    Anything else is returned unchanged.
    """
    if isinstance(obj, dict):
        # json.dump rejects NumPy scalar keys (np.int64 is not an int subclass),
        # so keys are converted too; non-NumPy keys are left as-is.
        return {
            (key.item() if isinstance(key, np.generic) else key): convert_numpy_types(value)
            for key, value in obj.items()
        }
    if isinstance(obj, list):
        return [convert_numpy_types(item) for item in obj]
    if isinstance(obj, tuple):
        return tuple(convert_numpy_types(item) for item in obj)
    if isinstance(obj, np.integer):
        return int(obj)
    if isinstance(obj, np.floating):
        return float(obj)
    if isinstance(obj, np.bool_):
        # np.bool_ is neither an np.integer nor JSON-serializable.
        return bool(obj)
    if isinstance(obj, np.ndarray):
        return obj.tolist()
    return obj

def save_statistics(stats: Dict, model_id: str, output_dir: Optional[str] = None):
    """Persist expert statistics and the raw routing records to disk.

    Every file name is prefixed with ``model_id`` with ``/`` and ``-``
    replaced by ``_``:

    * ``<prefix>_expert_stats.json``: the full ``stats`` dict, with NumPy
      values converted to native Python types.
    * ``<prefix>_expert_usage.pt``: ``stats['expert_usage_count']`` as a
      ``torch.long`` tensor.
    * ``<prefix>_layer_expert_matrix.pt``: per-layer usage counts stacked in
      ascending layer order (``torch.long``).
    * ``<prefix>_routing_entropy.pt``: the stored per-token entropy sample
      (``torch.float``).
    * ``<prefix>_raw_routing_data.pt``: one dict per recorded gate call with
      its top-k indices and shape metadata. Written only when the module-level
      ``EXPERT_ROUTING_DATA`` is non-empty.

    Args:
        stats: statistics dict from ``analyze_expert_routing_data``; nothing
            is written when it is empty or ``None``.
        model_id: model identifier used to build the file-name prefix.
        output_dir: directory to write into, created if missing. When falsy
            (the default), files are written to the current working
            directory under bare names.
    """
    if not stats:
        return

    clean_model_id = model_id.replace("/", "_").replace("-", "_")
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)

    def output_path(suffix: str) -> str:
        filename = f"{clean_model_id}_{suffix}"
        return os.path.join(output_dir, filename) if output_dir else filename

    # Detailed JSON; NumPy scalars must be converted before json.dump.
    json_filename = output_path("expert_stats.json")
    with open(json_filename, 'w', encoding='utf-8') as f:
        json.dump(convert_numpy_types(stats), f, indent=2)
    print(f"\nDetailed statistics saved to: {json_filename}")

    pt_filename = output_path("expert_usage.pt")
    torch.save(torch.tensor(stats['expert_usage_count'], dtype=torch.long), pt_filename)
    print(f"Expert usage tensor saved to: {pt_filename}")

    # Rows ordered by ascending layer index.
    layer_stats = stats['layer_statistics']
    layer_usage_matrix = [layer_stats[idx]['expert_usage_count'] for idx in sorted(layer_stats)]
    layer_matrix_filename = output_path("layer_expert_matrix.pt")
    torch.save(torch.tensor(layer_usage_matrix, dtype=torch.long), layer_matrix_filename)
    print(f"Layer-expert usage matrix saved to: {layer_matrix_filename}")

    entropy_filename = output_path("routing_entropy.pt")
    torch.save(torch.tensor(stats['routing_entropy']['all_scores'], dtype=torch.float), entropy_filename)
    print(f"Routing entropy scores saved to: {entropy_filename}")

    # Raw routing records; logits/probs/weights are dropped to keep the file small.
    if EXPERT_ROUTING_DATA:
        raw_keys = ('layer_idx', 'topk_indices', 'batch_size', 'seq_len', 'num_experts', 'top_k')
        filtered_routing_data = [
            {key: routing_batch[key] for key in raw_keys}
            for routing_batch in EXPERT_ROUTING_DATA
        ]
        routing_data_filename = output_path("raw_routing_data.pt")
        torch.save(filtered_routing_data, routing_data_filename)
        print(f"Raw routing data saved to: {routing_data_filename}")

if __name__ == "__main__":
    # Command-line entry point. Exits with status 1 when collection fails or
    # raises, so callers and shell scripts can detect the failure.
    parser = argparse.ArgumentParser(
        description="Collect expert-routing statistics from a DeepSeek MoE model."
    )
    parser.add_argument("--model_id", type=str, default="deepseek-ai/deepseek-moe-16b-chat",
                        help="Hugging Face model id to load")
    parser.add_argument("--dataset_name", type=str, default="OpenAssistant/oasst1",
                        help="Hugging Face dataset whose 'train' split is sampled")
    parser.add_argument("--num_samples", type=int, default=200,
                        help="number of leading dataset rows to process")
    parser.add_argument("--max_length", type=int, default=512,
                        help="truncation length in tokens per example")
    parser.add_argument("--batch_size", type=int, default=1,
                        help="examples per forward pass")
    args = parser.parse_args()

    print("DeepSeek MoE Expert Statistics Collection")
    print("-" * 50)

    exit_code = 0
    try:
        statistics = collect_deepseek_expert_statistics(
            model_id=args.model_id,
            dataset_name=args.dataset_name,
            num_samples=args.num_samples,
            max_length=args.max_length,
            batch_size=args.batch_size
        )

        if statistics:
            print_expert_statistics(statistics)
            save_statistics(statistics, args.model_id)

            print(f"\n{'='*60}")
            print("Expert statistics collection completed successfully!")
            print(f"{'='*60}")
        else:
            print("Failed to collect expert statistics.")
            exit_code = 1

    except Exception as e:
        print(f"Error during expert statistics collection: {e}")
        import traceback
        traceback.print_exc()
        exit_code = 1

    sys.exit(exit_code)