import json

def _span(record):
    """Return ``(start, duration)`` read from *record*'s "start"/"end" keys.

    Missing keys default to 0, and the duration is clamped to a minimum of 1
    so that zero-length or inverted intervals still yield a positive value.
    """
    start = record.get("start", 0)
    return start, max(record.get("end", 0) - start, 1)


def _complete_event(name, start_ns, dur_ns, pid=0, tid=0, args=None):
    """Build one Chrome Trace complete event dict.

    ``start_ns`` and ``dur_ns`` are divided by 1000 before being stored in
    ``ts`` / ``dur`` (the input is treated as nanoseconds, the output as
    microseconds).
    """
    return {
        "name": name,
        "ph": "X",                    # complete event
        "ts": start_ns / 1000,        # ns -> us
        "dur": dur_ns / 1000,         # ns -> us
        "pid": pid,
        "tid": tid,
        "args": args or {},
    }


def _build_trace_events(steps, pid=0, tid=0):
    """Flatten the nested step records into a list of trace events.

    The hierarchy walked is: step -> "comm_ops" -> "nccl_ops" ->
    "cudalaunchkernel" -> "kernels". Events are emitted parent-first, in
    input order, and all share the same ``pid`` / ``tid``.
    """
    events = []

    def emit(name, start_ns, dur_ns, args=None):
        events.append(_complete_event(name, start_ns, dur_ns, pid, tid, args))

    for step_idx, step in enumerate(steps):
        emit(step.get("step_name", f"step_{step_idx}"), *_span(step))

        for cop in step.get("comm_ops", []):
            emit(cop.get("comm_name", "comm_op"), *_span(cop))

            for nc in cop.get("nccl_ops", []):
                emit(nc.get("nccl_name", "nccl_op"), *_span(nc))

                for launch in nc.get("cudalaunchkernel", []):
                    launch_start, launch_dur = _span(launch)
                    emit(
                        launch.get("name", "cudaLaunchKernel"),
                        launch_start,
                        launch_dur,
                    )

                    for kernel in launch.get("kernels", []):
                        gpu_start = kernel.get("gpu_start", 0)
                        gpu_end = kernel.get("gpu_end", 0)
                        # The kernel is drawn over its launch's interval; the
                        # GPU-side timings are only attached as metadata (and
                        # "gpu_dur" is the raw difference, not clamped).
                        emit(
                            kernel.get("name", "kernel"),
                            launch_start,
                            launch_dur,
                            args={
                                "gpu_start": gpu_start,
                                "gpu_end": gpu_end,
                                "gpu_dur": gpu_end - gpu_start,
                            },
                        )

    return events


def get_horizontal_chrome_output(input_path, output_path):
    """Convert a nested step JSON file into a Chrome Trace JSON file.

    Reads a JSON array of step records from ``input_path``, emits one
    complete ("X") event per step, comm op, NCCL op, kernel launch and
    kernel (all on pid 0 / tid 0), and writes ``{"traceEvents": [...]}`` to
    ``output_path``. A confirmation message is printed afterwards.

    Args:
        input_path: Path of the UTF-8 JSON file to read.
        output_path: Path of the UTF-8 JSON file to write.

    Returns:
        None.
    """
    with open(input_path, "r", encoding="utf-8") as f:
        steps = json.load(f)

    trace_events = _build_trace_events(steps, pid=0, tid=0)

    with open(output_path, "w", encoding="utf-8") as f:
        json.dump({"traceEvents": trace_events}, f, indent=2)

    print(f"✅ 通信调用链 Chrome Trace 已保存到 {output_path}")