from collections import Counter
import json
import re

def is_step_marker(name):
    """Return True when ``name`` begins with the ``step_execute_model[`` prefix."""
    step_prefix = "step_execute_model["
    return name.startswith(step_prefix)

def is_nccl_op(name):
    """Return True when ``name`` begins with the lowercase prefix ``nccl``."""
    nccl_prefix = "nccl"
    return name.startswith(nccl_prefix)


def strip_op_id(name):
    """Return ``name`` with every ``, op_id = <digits>`` fragment removed."""
    op_id_pattern = re.compile(r", op_id = \d+")
    return op_id_pattern.sub("", name)


# Per-category event totals, all starting at zero.
# NOTE: `driven_count` is the total for "DRIVER" events; the spelling is kept
# because other code refers to it by this name.
(
    step_count,
    torch_frontend_count,
    torch_backend_count,
    nccl_count,
    runtime_count,
    driven_count,
    kernel_count,
) = (0,) * 7

# Per-category name -> occurrence tables; each name gets its own Counter.
(
    torch_frontend_ops,
    torch_backend_ops,
    nccl_ops,
    runtime_apis,
    driver_apis,
    kernel_names,
) = (Counter() for _ in range(6))

# Load every non-blank line of the input file as one JSON value.
# The path is named once so the progress message always reports the file
# that is actually opened.
INPUT_PATH = "output1.jsonl"

events = []
print(f"[INFO] 开始读取 {INPUT_PATH} 文件...")
with open(INPUT_PATH, "r", encoding="utf-8") as f:
    for line in f:
        line = line.strip()
        if not line:
            continue
        try:
            ev = json.loads(line)
        except (json.JSONDecodeError, RecursionError) as e:
            # Best effort: report the malformed line and keep reading.
            print(f"[ERROR] 解析 JSON 行失败: {line}\n错误: {e}")
        else:
            events.append(ev)
print(f"[INFO] 读取完毕，事件数量: {len(events)}")

# Classify every loaded event by its "type" field and tally it, both in the
# per-category total and in the per-name Counter for that category.
for ev in events:
    # A syntactically valid JSON line need not be an object (e.g. `[]` or
    # `1`); such values have no .get() and cannot be classified.
    if not isinstance(ev, dict):
        continue

    ev_type = ev.get("type")
    name = ev.get("name", "")

    if ev_type == "NVTX_MARKER":
        # Skip markers without a usable name: JSON null, the literal
        # "(null)" placeholder, or any other non-string value that the
        # prefix/substring checks below cannot handle.
        if not isinstance(name, str) or name == "(null)":
            continue
        if is_step_marker(name):
            step_count += 1
        elif "op_id = " in name:
            # The op id is stripped so that the same op with different ids
            # is counted under one key.
            torch_backend_count += 1
            torch_backend_ops[strip_op_id(name)] += 1
        elif is_nccl_op(name):
            nccl_count += 1
            nccl_ops[name] += 1
        else:
            # Every remaining named marker is counted as a frontend op.
            torch_frontend_count += 1
            torch_frontend_ops[name] += 1
    elif ev_type == "RUNTIME":
        runtime_count += 1
        runtime_apis[name] += 1
    elif ev_type == "DRIVER":
        driven_count += 1
        driver_apis[name] += 1
    elif ev_type == "KERNEL":
        kernel_count += 1
        kernel_names[name] += 1

# Print the per-category totals. Each label carries its own trailing padding,
# so the rows are emitted exactly as laid out here.
summary_rows = (
    ("step 数量            ", step_count),
    ("PyTorch 前端数量     ", torch_frontend_count),
    ("PyTorch 后端数量     ", torch_backend_count),
    ("NCCL 数量            ", nccl_count),
    ("runtime 数量         ", runtime_count),
    ("driver 数量          ", driven_count),
    ("kernel 数量          ", kernel_count),
)

print("\n[统计结果]")
for label, total in summary_rows:
    print(f"{label}: {total}")

def print_top(counter, title, top_n=100000000):
    """Print the ``top_n`` most common entries of ``counter`` under a header.

    Args:
        counter: A ``collections.Counter`` mapping names to occurrence counts.
        title: Text shown inside the bracketed header line.
        top_n: Maximum number of entries to print, most frequent first.
    """
    print(f"\n[{title} 前 {top_n} 名]")
    for name, count in counter.most_common(top_n):
        # Keys are not guaranteed to be strings (a JSON null name arrives as
        # None), and such objects reject the "<60" format spec. Convert to
        # str first so the row is printed instead of raising TypeError.
        print(f"{str(name):<60} {count}")

# Dump the frequency table of every category, in this fixed order.
for ops_counter, section_title in (
    (torch_frontend_ops, "PyTorch 前端 OP"),
    (torch_backend_ops, "PyTorch 后端 OP"),
    (nccl_ops, "NCCL OP"),
    (runtime_apis, "CUDA Runtime API"),
    (driver_apis, "CUDA Driver API"),
    (kernel_names, "CUDA Kernel"),
):
    print_top(ops_counter, section_title)
