import json

def get_vertival_chrome_output(input_path, output_path):
    """Convert a nested profiling dump into a single-track Chrome trace file.

    The input file holds a JSON array of step objects. The nesting is:
    step -> ``torch_frontend_ops`` -> ``torch_backend_ops`` ->
    (``nccl_calls`` | ``runtime_calls`` | ``driver_calls``), and runtime and
    driver calls may in turn carry ``kernels``. Every record becomes one
    Chrome "complete" (``ph="X"``) event. A child list that is missing or
    JSON ``null`` is treated as empty.

    All events are emitted with pid 0 / tid 0, so they share one track and
    are stacked by time containment ("vertical" nesting).

    Input timestamps are treated as nanoseconds and converted to the
    microseconds used by the Chrome trace format. Event durations are
    clamped to at least 1 ns so zero-length or inverted records stay
    visible.

    NOTE: the "vertival" spelling in the name is kept for compatibility with
    existing callers.

    Args:
        input_path: Path of the input JSON file (read as UTF-8).
        output_path: Path of the Chrome trace JSON file to write (UTF-8).
    """
    with open(input_path, "r", encoding="utf-8") as f:
        steps = json.load(f)

    # Every event goes to the same process/thread so all steps share one track.
    pid = 0
    tid = 0
    trace_events = []

    def add_trace(name, start_ns, dur_ns, args=None):
        """Append one Chrome complete event; start/duration are given in ns."""
        trace_events.append({
            "name": name,
            "ph": "X",                    # complete event (has ts + dur)
            "ts": start_ns / 1000,        # ns -> µs
            "dur": dur_ns / 1000,         # ns -> µs
            "pid": pid,
            "tid": tid,
            "args": args or {},
        })

    def children(record, key):
        """Return the child list under ``key``; missing or null means none."""
        return record.get(key) or []

    def span(record, start_key="start", end_key="end"):
        """Return ``(start, end, dur)`` with ``dur`` clamped to >= 1 ns."""
        start = record.get(start_key, 0)
        end = record.get(end_key, 0)
        # Clamp so zero-length / inverted records remain visible in the viewer.
        return start, end, max(end - start, 1)

    def add_call(record, default_name, name_key="name", args=None):
        """Emit an event for ``record`` and return its ``(start, dur)`` in ns."""
        start, _, dur = span(record)
        add_trace(record.get(name_key, default_name), start, dur, args)
        return start, dur

    def add_kernels(call, call_start, call_dur):
        """Emit one event per kernel launched by a runtime/driver ``call``.

        Kernel events reuse the launching call's CPU start/duration instead
        of their GPU timestamps, so on the shared track they sit exactly
        inside the call's slice. The raw GPU timing is kept in ``args``.
        """
        for kernel in children(call, "kernels"):
            gpu_start, gpu_end, _ = span(kernel, "gpu_start", "gpu_end")
            add_trace(
                kernel.get("name", "kernel"),
                call_start,
                call_dur,
                args={
                    "gpu_start": gpu_start,
                    "gpu_end": gpu_end,
                    # Raw difference; unlike event durations it is not clamped.
                    "gpu_dur": gpu_end - gpu_start,
                },
            )

    for step_idx, step in enumerate(steps):
        add_call(step, f"step_{step_idx}", name_key="step_name")

        # Frontend (torch) ops
        for fop in children(step, "torch_frontend_ops"):
            add_call(fop, "frontend_op", name_key="op_name")

            # Backend ops
            for bop in children(fop, "torch_backend_ops"):
                add_call(
                    bop,
                    "backend_op",
                    name_key="op_name",
                    args={"op_id": bop.get("op_id")},
                )

                # NCCL calls
                for nc in children(bop, "nccl_calls"):
                    add_call(nc, "nccl_api")

                # Runtime API calls and the kernels they launch
                for rt in children(bop, "runtime_calls"):
                    rt_start, rt_dur = add_call(rt, "runtime_api")
                    add_kernels(rt, rt_start, rt_dur)

                # Driver API calls and the kernels they launch
                for dr in children(bop, "driver_calls"):
                    dr_start, dr_dur = add_call(dr, "driver_api")
                    add_kernels(dr, dr_start, dr_dur)

    with open(output_path, "w", encoding="utf-8") as f:
        json.dump({"traceEvents": trace_events}, f, indent=2)

    print(f"✅ 嵌套模式 Trace（所有 step 同一轨道）已保存到 {output_path}")