#!/usr/bin/env python
# vllm_collect_stats_r1.py
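"""
Collect MoE-routing statistics from QuixiAI/DeepSeek-R1-AWQ
using the OpenAssistant/oasst1 dataset.

Outputs <prefix>_raw_routing_data.pt (captured top-k expert ids per routing
call) and <prefix>_layer_expert_matrix.pt (per-layer expert selection counts).
"""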

import argparse, os, types
from collections import defaultdict

import torch
from datasets import load_dataset
from vllm import LLM, SamplingParams

def load_llm(model_id: str, tp_size: int = 8):
    """Build a vLLM engine for ``model_id`` sharded over ``tp_size`` GPUs.

    The engine runs eagerly: ``compilation_config=0`` disables torch.compile
    and ``enforce_eager=True`` is set as well. Presumably this is so the
    monkey-patched Python ``select_experts`` is actually executed — confirm.
    The context length is capped at 1024 tokens, and 90% of GPU memory
    may be used.

    Returns:
        The constructed ``vllm.LLM`` instance.
    """
    engine_kwargs = dict(
        tensor_parallel_size=tp_size,
        trust_remote_code=True,
        dtype="auto",
        compilation_config=0,
        enforce_eager=True,
        max_model_len=1024,
        gpu_memory_utilization=0.9,
    )
    return LLM(model_id, **engine_kwargs)


def _install_moe_logger(model):
    """
    Patch the experts class so every ``select_experts`` call is recorded.

    Runs inside every worker (passed to ``model_executor.apply_model``).
    The experts module of decoder layer 3 is used for two things: to find
    the class to patch, and to host the ``GLOBAL_EXPERTS_LOG`` buffer.
    Layer 3 is presumably the first MoE layer of DeepSeek-style models —
    confirm for other architectures.

    Because the method is replaced on the *class*, every layer whose experts
    module is an instance of that class appends the returned expert ids
    (detached, moved to CPU) to the same buffer, in call order. The pattern
    follows quant_model_v1.py.

    Args:
        model: worker-local model exposing ``model.layers[3].mlp.experts``.

    Returns:
        The string ``"logger installed"``.
    """
    import functools
    import inspect

    source_layer = model.model.layers[3].mlp.experts
    source_layer.GLOBAL_EXPERTS_LOG = []
    expert_class = type(source_layer)

    # Inspect the raw class attribute so we know whether to re-wrap the
    # replacement as a staticmethod or install it as a plain method.
    raw_attr = inspect.getattr_static(expert_class, "select_experts")
    is_static = isinstance(raw_attr, staticmethod)
    original = raw_attr.__func__ if is_static else raw_attr
    # If a previous call already patched the class, wrap the true original
    # instead, so re-installing does not log every call twice.
    original = getattr(original, "_moe_logger_original", original)

    @functools.wraps(original)
    def _wrapped(*args, **kwargs):
        w, ids = original(*args, **kwargs)
        # Resolve the buffer at call time, not install time: callers may
        # rebind GLOBAL_EXPERTS_LOG to a new list when draining it.
        source_layer.GLOBAL_EXPERTS_LOG.append(ids.detach().cpu())
        return w, ids

    _wrapped._moe_logger_original = original
    expert_class.select_experts = staticmethod(_wrapped) if is_static else _wrapped

    return "logger installed"


def _collect_stats(model):
    """
    Drain the routing buffer populated by the patched ``select_experts``.

    Runs inside every worker (passed to ``model_executor.apply_model``).

    Args:
        model: worker-local model whose ``model.layers[3].mlp.experts`` holds
            ``GLOBAL_EXPERTS_LOG`` (raises AttributeError if the logger was
            never installed).

    Returns:
        One nested Python list per captured call (``tensor.tolist()``), in
        capture order. The buffer is emptied afterwards.
    """
    source_layer = model.model.layers[3].mlp.experts
    log = source_layer.GLOBAL_EXPERTS_LOG

    res = [entry.tolist() for entry in log]
    # Clear in place rather than rebinding the attribute: the installed
    # wrapper may hold a reference to this exact list object.
    log.clear()
    return res


def _moe_layout(model):
    """
    Describe where the MoE layers sit in the decoder stack.

    Runs inside a worker (passed to ``model_executor.apply_model``).

    Returns:
        ``(num_layers, moe_layer_ids)``: the total number of decoder layers,
        and the indices of layers whose ``mlp`` has an ``experts`` submodule.
        ``main`` attributes captured routing calls to these layers, in order.
    """
    layers = model.model.layers
    moe_layer_ids = [
        idx for idx, layer in enumerate(layers)
        if hasattr(getattr(layer, "mlp", None), "experts")
    ]
    return len(layers), moe_layer_ids


def main(
    model_id="QuixiAI/DeepSeek-R1-AWQ",
    num_samples=64,
    max_len=512,
    tp_size=8,
    out_prefix="DeepSeek_R1_AWQ",
):
    """
    Run oasst1 prompts through the model and save expert-routing statistics.

    Each prompt is truncated to ``max_len`` tokens and generated with
    ``max_tokens=1``. After every prompt, the routing buffer is drained from
    the workers.

    Writes:
        ``<out_prefix>_raw_routing_data.pt``: list of dicts, one per captured
            ``select_experts`` call.
        ``<out_prefix>_layer_expert_matrix.pt``: LongTensor of shape
            ``(num_layers, num_experts)`` with expert-selection counts.

    Raises:
        RuntimeError: if the model has no MoE layers or nothing was captured.
    """
    llm = load_llm(model_id, tp_size)
    executor = llm.llm_engine.model_executor
    executor.apply_model(_install_moe_logger)

    # The model lives in the worker processes (the engine has no ``.model``),
    # so ask rank 0 for the layer layout.
    num_layers, moe_layers = executor.apply_model(_moe_layout)[0]
    if not moe_layers:
        raise RuntimeError("No MoE layers (layers[i].mlp.experts) found in model")

    ds = load_dataset("OpenAssistant/oasst1", split=f"train[:{num_samples}]")
    prompts = [ex.get("text") or ex.get("prompt") or "" for ex in ds]

    tokenizer = llm.get_tokenizer()
    sampling_params = SamplingParams(max_tokens=1)

    all_raw_data = []
    top_k = None

    for sample_idx, prompt in enumerate(prompts):
        print(f"Processing sample {sample_idx + 1}/{len(prompts)}")
        if not prompt:
            print(f"  skipping sample {sample_idx + 1}: empty prompt")
            continue

        # Truncate so max_len is honoured and prompts stay within max_model_len.
        token_ids = tokenizer.encode(prompt)[:max_len]
        llm.generate({"prompt_token_ids": token_ids}, sampling_params)

        # apply_model returns one result per worker. Tensor-parallel ranks
        # presumably compute identical routing, so keep rank 0 only: merging
        # every rank would multiply all counts by tp_size.
        calls = executor.apply_model(_collect_stats)[0]
        if len(calls) % len(moe_layers):
            print(f"  warning: {len(calls)} routing calls is not a multiple "
                  f"of {len(moe_layers)} MoE layers; layer attribution may be off")

        for call_idx, ids_list in enumerate(calls):
            ids_tensor = torch.tensor(ids_list, dtype=torch.long)
            if ids_tensor.dim() == 1:
                ids_tensor = ids_tensor.unsqueeze(0)

            if top_k is None:
                top_k = ids_tensor.shape[-1]
                print(f"Detected: {num_layers} layers ({len(moe_layers)} MoE), top_k={top_k}")

            all_raw_data.append({
                # Only MoE layers call select_experts, so map calls onto them.
                "layer_idx": moe_layers[call_idx % len(moe_layers)],
                "topk_indices": ids_tensor,
                "batch_size": 1,
                "seq_len": ids_tensor.shape[0],
                "num_experts": None,  # filled in below once all ids are seen
                "top_k": ids_tensor.shape[-1],
                "sample_idx": sample_idx,
            })

    if not all_raw_data:
        raise RuntimeError("No routing data captured – patch failed?")

    # Expert count: the configured value when the HF config exposes one,
    # otherwise (or if larger) the highest observed expert id + 1.
    hf_config = getattr(getattr(llm.llm_engine, "model_config", None), "hf_config", None)
    configured = (
        getattr(hf_config, "n_routed_experts", None)
        or getattr(hf_config, "num_experts", None)
        or 0
    )
    observed = max(
        (int(e["topk_indices"].max()) for e in all_raw_data if e["topk_indices"].numel()),
        default=-1,
    ) + 1
    num_experts = max(int(configured), observed)
    for entry in all_raw_data:
        entry["num_experts"] = num_experts

    # save raw routing data
    torch.save(all_raw_data, f"{out_prefix}_raw_routing_data.pt")
    print(f"Saved raw routing → {out_prefix}_raw_routing_data.pt ({len(all_raw_data)} entries)")

    # build layer-expert count matrix
    layer_mat = torch.zeros(num_layers, num_experts, dtype=torch.long)
    for entry in all_raw_data:
        idxs = entry["topk_indices"].reshape(-1)
        layer_mat[entry["layer_idx"]] += torch.bincount(idxs, minlength=num_experts)

    torch.save(layer_mat, f"{out_prefix}_layer_expert_matrix.pt")
    print(f"Saved layer‑expert matrix → {out_prefix}_layer_expert_matrix.pt ({layer_mat.shape})")
    print(f"Total samples processed: {len(prompts)}")
    print(f"Total routing calls captured: {len(all_raw_data)}")

# ----------------------------------------------------------------------
if __name__ == "__main__":
    ap = argparse.ArgumentParser(description="Collect MoE routing statistics with vLLM.")
    ap.add_argument("--model_id", default="QuixiAI/DeepSeek-R1-AWQ",
                    help="Hugging Face model id or local path")
    ap.add_argument("--num_samples", type=int, default=64,
                    help="number of oasst1 train rows to run")
    ap.add_argument("--max_len", type=int, default=512,
                    help="forwarded to main() as max_len")
    ap.add_argument("--tp", type=int, default=8, help="tensor-parallel size")
    ap.add_argument("--out_prefix", default=None,
                    help="output file prefix (default: model id with '/' replaced by '_')")
    args = ap.parse_args()
    main(
        model_id=args.model_id,
        num_samples=args.num_samples,
        max_len=args.max_len,
        tp_size=args.tp,
        out_prefix=args.out_prefix or args.model_id.replace("/", "_"),
    )
