





import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from datasets import load_dataset
from tqdm import tqdm
import argparse
import os

def parse_args():
    """Parse the command-line options of the expert-profiling script.

    Options:
        --model_path   (str, required)  Base model path.
        --num_samples  (int, default 128)  Number of chunks to profile.
        --seq_len      (int, default 2048)  Length of each chunk in tokens.
        --save_path    (str, default "./expert_counts_c4.pt")  Output file.

    Returns:
        argparse.Namespace: the parsed options, read from ``sys.argv``.
    """
    cli = argparse.ArgumentParser()

    # Each entry is (flag, keyword arguments forwarded to add_argument).
    option_specs = (
        ("--model_path", {"type": str, "required": True, "help": "Base model path"}),
        (
            "--num_samples",
            {"type": int, "default": 128, "help": "Number of samples (chunks) to profile"},
        ),
        ("--seq_len", {"type": int, "default": 2048}),
        ("--save_path", {"type": str, "default": "./expert_counts_c4.pt"}),
    )
    for flag, spec in option_specs:
        cli.add_argument(flag, **spec)

    return cli.parse_args()

def _find_router_logits(output, num_experts):
    """Return the tensor in a gate module's output that holds per-expert scores.

    A plain tensor is returned unchanged. When the gate returns a tuple or
    list, the first floating-point tensor whose last dimension equals
    ``num_experts`` is used.

    Raises:
        TypeError: if no suitable tensor can be found in ``output``.
    """
    if torch.is_tensor(output):
        return output
    if isinstance(output, (tuple, list)):
        for item in output:
            if (
                torch.is_tensor(item)
                and item.is_floating_point()
                and item.dim() > 0
                and item.shape[-1] == num_experts
            ):
                return item
    raise TypeError(
        f"Cannot locate router logits in gate output of type {type(output).__name__}"
    )


def _make_gate_hook(expert_counts, layer_idx, num_experts, top_k):
    """Build a forward hook that adds one layer's top-k picks to ``expert_counts``."""

    def hook(module, inputs, output):
        # reshape (not view) so non-contiguous gate outputs are accepted too.
        scores = _find_router_logits(output, num_experts).reshape(-1, num_experts)
        # NOTE(review): the selection is recomputed here as a plain top-k over
        # the gate scores; a router that adjusts scores before selecting may
        # pick differently — confirm for the model being profiled.
        chosen = torch.topk(scores, k=top_k, dim=-1).indices.reshape(-1)
        counts = torch.bincount(chosen, minlength=num_experts)
        # The gate may live on a different device than the accumulator.
        expert_counts[layer_idx] += counts.to(
            device=expert_counts.device, dtype=expert_counts.dtype
        )

    return hook


def _register_gate_hooks(model, num_layers, expert_counts, num_experts, top_k):
    """Attach a counting hook to every layer's ``mlp.gate`` and return the handles.

    Layers that expose no ``mlp.gate`` module are skipped (their row in
    ``expert_counts`` stays zero).

    Raises:
        RuntimeError: if no layer exposes an ``mlp.gate`` module.
    """
    decoder = model.model if hasattr(model, "model") else model
    handles = []
    for layer_idx in range(num_layers):
        mlp = getattr(decoder.layers[layer_idx], "mlp", None)
        gate = getattr(mlp, "gate", None)
        if not isinstance(gate, torch.nn.Module):
            print(f"Layer {layer_idx}: no mlp.gate module found, skipping (counts stay zero)")
            continue
        hook = _make_gate_hook(expert_counts, layer_idx, num_experts, top_k)
        handles.append(gate.register_forward_hook(hook))
    if not handles:
        raise RuntimeError("No layer exposes an mlp.gate module; nothing to profile.")
    return handles


def _load_c4_stream():
    """Open the English C4 training split in streaming mode.

    Tries the ``"c4"`` dataset id first and falls back to ``"allenai/c4"``.
    """
    try:
        return load_dataset("c4", "en", split="train", streaming=True, trust_remote_code=True)
    except Exception as err:
        # Deliberately broad: any failure of the first id triggers the
        # fallback, but the cause is reported instead of being swallowed.
        print(f"Loading 'c4' failed ({err!r}); falling back to 'allenai/c4'...")
        return load_dataset(
            "allenai/c4", "en", split="train", streaming=True, trust_remote_code=True
        )


def _collect_chunks(dataset, tokenizer, num_samples, seq_len):
    """Concatenate tokenized documents and cut them into fixed-length chunks.

    Documents are tokenized without special tokens and appended to a running
    buffer; every full ``seq_len`` tokens become one chunk. Tokens left over
    at the end of a document carry into the next one.

    Returns:
        list[list[int]]: at most ``num_samples`` chunks of ``seq_len`` token
        ids each (fewer if the dataset runs out first).
    """
    chunks = []
    buffer = []
    records = iter(dataset)
    with tqdm(total=num_samples) as pbar:
        while len(chunks) < num_samples:
            try:
                record = next(records)
            except StopIteration:
                print("Dataset exhausted before reaching num_samples.")
                break
            buffer.extend(tokenizer(record["text"], add_special_tokens=False).input_ids)

            # Slice chunks out by offset and trim the buffer once, instead of
            # re-copying the remaining buffer after every chunk.
            consumed = 0
            while len(buffer) - consumed >= seq_len and len(chunks) < num_samples:
                chunks.append(buffer[consumed:consumed + seq_len])
                consumed += seq_len
                pbar.update(1)
            del buffer[:consumed]
    return chunks


def main():
    """Profile how often each expert is selected on C4 text and save the counts.

    Loads the causal LM at ``--model_path``, hooks every layer's ``mlp.gate``
    module, runs ``--num_samples`` chunks of ``--seq_len`` tokens through the
    model, and accumulates per-layer top-k selection counts. The resulting
    ``(num_layers, num_experts)`` float32 tensor is written to ``--save_path``
    with ``torch.save``.

    Raises:
        ValueError: if ``--seq_len`` is not positive.
        RuntimeError: if no layer exposes an ``mlp.gate`` module.
    """
    args = parse_args()
    if args.seq_len < 1:
        raise ValueError("--seq_len must be a positive integer")
    device = "cuda" if torch.cuda.is_available() else "cpu"

    print(f"Loading model: {args.model_path}")
    model = AutoModelForCausalLM.from_pretrained(
        args.model_path,
        device_map="auto",
        torch_dtype=torch.bfloat16,
        trust_remote_code=True,
    )
    tokenizer = AutoTokenizer.from_pretrained(args.model_path, trust_remote_code=True)

    num_layers = model.config.num_hidden_layers
    num_experts = model.config.num_experts
    top_k = getattr(model.config, "num_experts_per_tok", 8)

    print(f"Model Config: {num_layers} Layers, {num_experts} Experts, Top-{top_k}")

    # One row of selection counts per layer.
    expert_counts = torch.zeros(num_layers, num_experts, device=device, dtype=torch.float32)

    handles = _register_gate_hooks(model, num_layers, expert_counts, num_experts, top_k)
    try:
        print("Loading C4 dataset (streaming mode)...")
        dataset = _load_c4_stream()

        print(f"Processing data to get {args.num_samples} chunks of length {args.seq_len}...")
        chunks = _collect_chunks(dataset, tokenizer, args.num_samples, args.seq_len)

        print(f"Profiling on {len(chunks)} chunks...")
        model.eval()
        with torch.no_grad():
            for token_ids in tqdm(chunks):
                # batch shape: [1, seq_len]
                model(torch.tensor([token_ids], device=device))
    finally:
        # Always detach the hooks, even if loading or profiling failed.
        for handle in handles:
            handle.remove()

    # Create the target directory up front so a finished run is not lost to
    # a missing path.
    save_dir = os.path.dirname(args.save_path)
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)
    torch.save(expert_counts.cpu(), args.save_path)
    print(f"Expert counts saved to {args.save_path}")

    # Clamp k so models with fewer than 5 experts do not crash the summary.
    print("Layer 0 Top 5 Experts:", torch.topk(expert_counts[0], min(5, num_experts)).indices)

# Run the profiling pipeline only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()



