import os
import json
import argparse
from collections import defaultdict

from .utils import (
    is_step_marker, is_nccl_op, is_torch_frontend_op, is_comm_op,
    is_kernel_launch_api, is_nccl_kernel, parse_op_id, get_event_time,
    assign_events_to_ops_best
)
from .vertical_chrome_events import get_vertival_chrome_output
from .horizontal_chrome_events import get_horizontal_chrome_output

def build_all_event_and_ranges(input_file):
    """Load a JSON-lines trace file and rebuild its NVTX push/pop ranges.

    Each non-empty line of ``input_file`` is parsed as one JSON event; lines
    that fail to parse are reported and skipped. All events are then sorted
    with ``get_event_time``.

    For events whose ``type`` is ``"NVTX_MARKER"``: a marker whose ``name`` is
    the string ``"null"`` is a pop that closes the most recently pushed, still
    open marker with the same ``id`` on the same ``process_id``; any other
    marker is a push. Markers missing ``id`` or ``process_id`` are ignored.

    Args:
        input_file: Path to a UTF-8 JSON-lines trace file.

    Returns:
        A tuple ``(events, ranges, cut_ts)``:

        * ``events`` -- all parsed events, sorted by ``get_event_time``.
        * ``ranges`` -- closed NVTX ranges as dicts with keys ``name``,
          ``start``, ``end``, ``id`` and ``process_id``, in the order they
          were closed (popped).
        * ``cut_ts`` -- the earliest ``start`` among ranges whose name
          satisfies ``is_step_marker``, or ``None`` if there is none.
    """
    events = []
    print(f"[INFO] 开始读取 {input_file} 文件...")

    with open(input_file, "r", encoding="utf-8") as f:
        for raw_line in f:
            line = raw_line.strip()
            if not line:
                continue
            try:
                ev = json.loads(line)
            except ValueError as e:  # json.JSONDecodeError is a ValueError
                print(f"[ERROR] 解析 JSON 行失败: {line}\n错误: {e}")
                continue
            events.append(ev)

    print(f"[INFO] 读取完毕，事件数量: {len(events)}")

    events.sort(key=get_event_time)
    print("[INFO] 已按时间戳排序")

    # Open (pushed but not yet popped) markers, one stack per process.
    process_stacks = defaultdict(list)
    ranges = []

    for ev in events:
        if ev["type"] != "NVTX_MARKER":
            continue

        marker_id = ev.get("id")
        process_id = ev.get("process_id")
        # Without both keys a push cannot be paired with its pop.
        if process_id is None or marker_id is None:
            continue

        name = ev["name"]
        ts = ev["timestamp"]

        if name == "null":
            # Pop: close the most recently pushed open marker with this id.
            stack = process_stacks[process_id]
            for i in reversed(range(len(stack))):
                if stack[i]["id"] == marker_id:
                    closed = stack.pop(i)
                    closed["end"] = ts
                    ranges.append(closed)
                    break
            else:
                print(f"[WARN] 找不到对应的 push 事件，process_id={process_id}, id={marker_id}")
        else:
            process_stacks[process_id].append({
                "name": name,
                "start": ts,
                "id": marker_id,
                "process_id": process_id,
            })

    print(f"[INFO] NVTX 区间构建完成，共: {len(ranges)}")

    # Pushes that were never popped have no end timestamp and are dropped;
    # report them instead of losing them silently.
    unclosed = sum(len(stack) for stack in process_stacks.values())
    if unclosed:
        print(f"[WARN] {unclosed} 个 push 事件没有对应的 pop，已忽略")

    # ``ranges`` is in close order, so the first step marker in it is the first
    # step to *end*. Use the earliest *start* so an enclosing or overlapping
    # step is not clipped when events before ``cut_ts`` are discarded.
    step_starts = [r["start"] for r in ranges if is_step_marker(r["name"])]
    cut_ts = min(step_starts) if step_starts else None

    return events, ranges, cut_ts

def get_all_ops(events, ranges, steps, cut_ts):
    """Trim one process's trace to the profiled window and classify its ranges.

    Args:
        events: Events as returned by ``build_all_event_and_ranges``.
        ranges: Closed NVTX ranges (dicts with at least ``name``, ``start``
            and ``end``).
        steps: Step ranges to start from, e.g. another process's steps when
            classifying a worker. The caller's list is NOT modified: step
            markers found in ``ranges`` are appended to a copy.
        cut_ts: Start timestamp of the first step. Events earlier than it and
            ranges that end before it are dropped. ``None`` disables trimming.

    Returns:
        ``(events, steps, torch_frontend_ops, torch_backend_ops, nccl_ops,
        comm_ops)``. Step, backend, frontend and comm ranges are given an empty
        ``children`` list; backend ranges also get ``op_id`` parsed from their
        name. NCCL ranges are returned unchanged.
    """
    # Copy so repeated calls that share one ``steps`` list (e.g. the master's
    # steps reused for several workers) cannot accumulate each other's markers.
    steps = list(steps or [])

    if cut_ts is not None:
        events = [ev for ev in events if get_event_time(ev) >= cut_ts]
        ranges = [r for r in ranges if r["end"] >= cut_ts]
    print(f"[INFO] 剪裁后剩余事件数: {len(events)}, 区间数: {len(ranges)}")

    torch_frontend_ops = []
    torch_backend_ops = []
    nccl_ops = []
    comm_ops = []

    for r in ranges:
        name = r["name"]
        # Order matters: each range goes to the first category it matches.
        if is_step_marker(name):
            r["children"] = []
            steps.append(r)
        elif "op_id = " in name:
            r["op_id"] = parse_op_id(name)  # backend ops carry an op id in their name
            r["children"] = []
            torch_backend_ops.append(r)
        elif is_nccl_op(name):
            nccl_ops.append(r)
        elif is_torch_frontend_op(name):
            r["children"] = []
            torch_frontend_ops.append(r)
        elif is_comm_op(name):
            r["children"] = []
            comm_ops.append(r)

    print(f"[INFO] 分类完成: steps={len(steps)}, torch_frontend_ops={len(torch_frontend_ops)}, "
        f"torch_backend_ops={len(torch_backend_ops)}, nccl_ops={len(nccl_ops)}")
    return events, steps, torch_frontend_ops, torch_backend_ops, nccl_ops, comm_ops

def _kernel_summaries(kernels):
    """Project kernel events down to the fields written to the JSON outputs."""
    return [
        {"name": k.get("name"), "gpu_start": k.get("gpu_start"), "gpu_end": k.get("gpu_end")}
        for k in kernels or []
    ]


def _api_entry(api, kernels):
    """Serialize one runtime/driver API call together with its kernels."""
    return {
        "name": api["name"],
        "start": api["start"],
        "end": api.get("end"),
        "kernels": _kernel_summaries(kernels),
    }


def _attach_backend_to_frontend(torch_frontend_ops, torch_backend_ops):
    """Append each backend op to ``children`` of the frontend op containing its start.

    Matching runs independently per ``process_id`` with one forward-only cursor
    per process, so a frontend op of one process can never push the cursor past
    frontend ops of another process. Within a process, only the first frontend
    op (in list order) that has not ended before the backend op starts is
    checked for containment.

    NOTE(review): the forward-only cursor assumes both lists are in the order
    produced by the range builder (close order) -- confirm for new callers.
    """
    fops_by_pid = defaultdict(list)
    for fop in torch_frontend_ops:
        fops_by_pid[fop.get("process_id")].append(fop)

    cursors = defaultdict(int)
    for bop in torch_backend_ops:
        pid = bop.get("process_id")
        candidates = fops_by_pid.get(pid, [])
        idx = cursors[pid]
        while idx < len(candidates) and candidates[idx].get("end", 0) < bop["start"]:
            idx += 1
        cursors[pid] = idx

        if idx < len(candidates):
            fop = candidates[idx]
            if fop["start"] <= bop["start"] <= fop.get("end", 0):
                fop.setdefault("children", []).append(bop)


def _nearest_launch_api(kernel, candidates):
    """Return the kernel-launch API in ``candidates`` whose start is closest to the kernel's.

    Distance is ``abs(api["start"] - kernel["gpu_start"])``; ties keep the
    earlier candidate. Returns ``None`` if no candidate is a launch API.
    NOTE(review): assumes CPU ``start`` and GPU ``gpu_start`` share a clock.
    """
    k_start = kernel.get("gpu_start", 0)
    best, best_dist = None, float("inf")
    for api in candidates:
        if not is_kernel_launch_api(api.get("name", "")):
            continue
        dist = abs(api.get("start", 0) - k_start)
        if dist < best_dist:
            best, best_dist = api, dist
    return best


def _bind_kernels_to_launches(kernels, api_calls):
    """Attach each non-zero-duration kernel to its launch API (key ``kernels``).

    Candidates are the ``api_calls`` sharing the kernel's truthy
    ``correlation_id``, kept in ``api_calls`` order.
    """
    by_corr = defaultdict(list)
    for api in api_calls:
        key = api.get("correlation_id")
        if key:
            by_corr[key].append(api)

    for k in kernels:
        if k.get("duration", 0) == 0:
            continue
        target = _nearest_launch_api(k, by_corr.get(k.get("correlation_id"), []))
        if target is not None:
            target.setdefault("kernels", []).append(k)


def _collect_comm_pairs(runtime_calls):
    """Return ``(api, kernels)`` for runtime calls that launched at least one NCCL kernel."""
    return [
        (api, api.get("kernels", []))
        for api in runtime_calls
        if any(is_nccl_kernel(k.get("name", "")) for k in api.get("kernels", []))
    ]


def _attach_launches_to_nccl(nccl_ops, comm_pairs):
    """Record on each NCCL range (key ``launches``) the NCCL launches issued inside it.

    A launch belongs to an NCCL range when the runtime call's ``start`` lies
    in ``[nccl.start, nccl.end]``. Both sides are sorted by start so a shared
    lower cursor can skip launches that ended before the current range.
    """
    pairs = sorted(comm_pairs, key=lambda pair: pair[0]["start"])
    lo = 0
    for nccl in sorted(nccl_ops, key=lambda op: op["start"]):
        start, end = nccl["start"], nccl["end"]

        while lo < len(pairs) and pairs[lo][0]["end"] < start:
            lo += 1

        k = lo
        while k < len(pairs) and pairs[k][0]["start"] <= end:
            api, ks = pairs[k]
            if start <= api["start"] <= end:
                nccl.setdefault("launches", []).append({"runtime": api, "kernels": ks})
            k += 1


def _in_step(step, op):
    """True when ``op`` starts inside ``step`` (inclusive bounds)."""
    return step["start"] <= op["start"] <= step["end"]


def _backend_op_entry(bop):
    """Serialize a backend op with its runtime and driver calls."""
    return {
        "op_name": bop["name"],
        "op_id": bop.get("op_id"),
        "start": bop["start"],
        "end": bop.get("end"),
        "runtime_calls": [_api_entry(rt, rt.get("kernels")) for rt in bop.get("runtimes", [])],
        "driver_calls": [_api_entry(drv, drv.get("kernels")) for drv in bop.get("drivers", [])],
    }


def _build_vertical_chain(steps, torch_frontend_ops):
    """Build step -> frontend op -> backend op -> API call -> kernel records."""
    output = []
    for step in steps:
        step_dict = {
            "step_name": step["name"],
            "start": step["start"],
            "end": step["end"],
            "torch_frontend_ops": [],
        }
        for fop in torch_frontend_ops:
            if not _in_step(step, fop):
                continue
            step_dict["torch_frontend_ops"].append({
                "op_name": fop["name"],
                "start": fop["start"],
                "end": fop.get("end"),
                "torch_backend_ops": [_backend_op_entry(bop) for bop in fop.get("children", [])],
            })
        output.append(step_dict)
    return output


def _nccl_op_entry(nccl):
    """Serialize an NCCL range with the kernel-launch calls attached to it."""
    launches = []
    for launch in nccl.get("launches", []):
        api = launch.get("runtime")
        if not api or not is_kernel_launch_api(api.get("name", "")):
            continue
        launches.append(_api_entry(api, launch.get("kernels")))
    return {
        "nccl_name": nccl["name"],
        "start": nccl["start"],
        "end": nccl.get("end"),
        "cudalaunchkernel": launches,
    }


def _build_comm_chain(steps, comm_ops):
    """Build step -> comm op -> NCCL range -> launch -> kernel records."""
    output = []
    for step in steps:
        step_comm_dict = {
            "step_name": step["name"],
            "start": step["start"],
            "end": step["end"],
            "comm_ops": [],
        }
        for cop in comm_ops:
            # Only time containment is checked, not process_id: ``steps`` may
            # come from a different process than ``comm_ops``.
            if not _in_step(step, cop):
                continue
            step_comm_dict["comm_ops"].append({
                "comm_name": cop["name"],
                "start": cop["start"],
                "end": cop.get("end"),
                "nccl_ops": [_nccl_op_entry(nccl) for nccl in cop.get("nccls", [])],
            })
        output.append(step_comm_dict)
    return output


def _write_json(data, path):
    """Write ``data`` as indented UTF-8 JSON to ``path``."""
    with open(path, "w", encoding="utf-8") as f:
        json.dump(data, f, indent=2, ensure_ascii=False)


def process_single_file(events, steps, torch_frontend_ops, torch_backend_ops, nccl_ops, comm_ops, output_dir, is_chrome):
    """Link one process's ops, API calls and kernels, then write its chain outputs.

    Pipeline:

    1. Attach backend ops to their enclosing frontend op (``children``),
       matching within each ``process_id`` separately.
    2. Split ``events`` into RUNTIME / DRIVER / KERNEL events. Hand runtime
       and driver calls to ``assign_events_to_ops_best`` with backend ops
       (keys ``runtimes`` / ``drivers``), and NCCL ranges with comm ops
       (key ``nccls``).
    3. Bind every non-zero-duration kernel to the nearest kernel-launch API
       with the same ``correlation_id`` (key ``kernels`` on the API dict).
    4. Attach runtime calls that launched NCCL kernels to the NCCL range
       during which they started (key ``launches``).
    5. Write ``nvtx_chain_output.json`` and ``comm_chain_output.json`` to
       ``output_dir``. When ``is_chrome`` is true, also write the vertical
       and horizontal Chrome-trace variants.

    The op and event dicts passed in are mutated in place.

    Returns:
        ``{"nvtx_chain_output": path, "comm_chain_output": path}``.
    """
    _attach_backend_to_frontend(torch_frontend_ops, torch_backend_ops)

    buckets = {"RUNTIME": [], "DRIVER": [], "KERNEL": []}
    for ev in events:
        bucket = buckets.get(ev["type"])
        if bucket is not None:
            bucket.append(ev)
    runtime_calls = buckets["RUNTIME"]
    driver_calls = buckets["DRIVER"]
    kernels = buckets["KERNEL"]

    assign_events_to_ops_best(runtime_calls, torch_backend_ops, "runtimes")
    assign_events_to_ops_best(driver_calls, torch_backend_ops, "drivers")
    assign_events_to_ops_best(nccl_ops, comm_ops, "nccls")

    # Runtime calls are indexed before driver calls, so among equally close
    # candidates a runtime launch API wins.
    _bind_kernels_to_launches(kernels, runtime_calls + driver_calls)
    comm_pairs = _collect_comm_pairs(runtime_calls)
    print("[INFO] 所有 kernel 事件已绑定到 kernel launch API")

    _attach_launches_to_nccl(nccl_ops, comm_pairs)
    print("[INFO] 所有事件归属完成")

    os.makedirs(output_dir, exist_ok=True)

    nvtx_chain_output_path = os.path.join(output_dir, "nvtx_chain_output.json")
    _write_json(_build_vertical_chain(steps, torch_frontend_ops), nvtx_chain_output_path)
    if is_chrome:
        get_vertival_chrome_output(nvtx_chain_output_path, os.path.join(output_dir, "vertical_chrome_output.json"))

    comm_chain_output_path = os.path.join(output_dir, "comm_chain_output.json")
    _write_json(_build_comm_chain(steps, comm_ops), comm_chain_output_path)
    if is_chrome:
        get_horizontal_chrome_output(comm_chain_output_path, os.path.join(output_dir, "horizontal_chrome_output.json"))

    return {
        "nvtx_chain_output": nvtx_chain_output_path,
        "comm_chain_output": comm_chain_output_path
    }

def process_all_files(input_dir, output_dir, task_type, is_chrome):
    """Process every per-process trace file in ``input_dir`` and write metadata.

    The pid of each file is parsed from its name: the last ``_``-separated
    token before the first ``.``, with its first three characters dropped.
    Presumably this is a ``pid`` prefix, e.g. ``trace_pid1234.jsonl`` -> 1234
    (TODO confirm the naming scheme). Files whose name does not yield an
    integer are skipped with a warning.

    The master process is the file whose trace contains a step marker. If
    several do, the last in sorted file-name order wins and a warning is
    printed. The master's cut timestamp and step ranges are reused to trim and
    bucket every other (worker) process. Each process's outputs go to
    ``output_dir/pid_<pid>``. A ``meta_data.json`` summary is written to
    ``output_dir``.

    Raises:
        ValueError: if no trace file in ``input_dir`` contains a step marker.
    """
    all_process_event_range = {}
    master_pid = None
    cut_ts = None

    # Sorted so master selection and worker order are deterministic
    # (os.listdir order is arbitrary).
    for input_file_name in sorted(os.listdir(input_dir)):
        input_file_path = os.path.join(input_dir, input_file_name)
        if not os.path.isfile(input_file_path):
            continue
        try:
            input_pid = int(input_file_name.split(".")[0].split("_")[-1][3:])
        except ValueError:
            print(f"[WARN] 无法从文件名解析 pid，跳过: {input_file_name}")
            continue

        events, ranges, temp_cut_ts = build_all_event_and_ranges(input_file_path)
        # ``is not None``: a step starting at timestamp 0 is still a step.
        if temp_cut_ts is not None:
            if master_pid is not None:
                print(f"[WARN] 多个文件包含 step 标记: pid={master_pid}, pid={input_pid}，使用 pid={input_pid} 作为 master")
            master_pid = input_pid
            cut_ts = temp_cut_ts
        all_process_event_range[input_pid] = (events, ranges)

    if master_pid is None:
        raise ValueError(f"{input_dir} 中没有包含 step 标记的 trace 文件，无法确定 master 进程")

    # Derived after the scan so a superseded master is still listed as a worker.
    worker_pids = [pid for pid in all_process_event_range if pid != master_pid]

    all_process_data = {}

    master_events, master_ranges = all_process_event_range[master_pid]
    (master_events, master_steps, master_torch_frontend_ops, master_torch_backend_ops,
     master_nccl_ops, master_comm_ops) = get_all_ops(master_events, master_ranges, [], cut_ts)
    all_process_data[master_pid] = process_single_file(
        master_events, master_steps, master_torch_frontend_ops, master_torch_backend_ops,
        master_nccl_ops, master_comm_ops, os.path.join(output_dir, f"pid_{master_pid}"), is_chrome,
    )

    for pid in worker_pids:
        events, ranges = all_process_event_range[pid]
        # Worker output is bucketed by the master's steps. Pass a copy so no
        # worker can extend the list seen by the next one.
        events, steps, torch_frontend_ops, torch_backend_ops, nccl_ops, comm_ops = get_all_ops(
            events, ranges, list(master_steps), cut_ts
        )
        all_process_data[pid] = process_single_file(
            events, steps, torch_frontend_ops, torch_backend_ops, nccl_ops, comm_ops,
            os.path.join(output_dir, f"pid_{pid}"), is_chrome,
        )

    meta_data = {
        "task_type": task_type,
        "error_type": "none",
        "master_pid": master_pid,
        "worker_pid": worker_pids,
        "all_process_data": all_process_data
    }

    os.makedirs(output_dir, exist_ok=True)
    with open(os.path.join(output_dir, "meta_data.json"), "w", encoding="utf-8") as f:
        json.dump(meta_data, f, indent=2, ensure_ascii=False)

if __name__ == "__main__":
    # The module uses package-relative imports (``from .utils import ...``),
    # so it must be run as a module (``python -m <package>.<module>``).
    parser = argparse.ArgumentParser(description="Parse NVTX and CUPTI trace events.")
    parser.add_argument("--in_dir", type=str, default="../output/")
    parser.add_argument("--out_dir", type=str, default="../output/test_proj")
    parser.add_argument("--task_type", type=str, default="none",
                        help="value recorded as task_type in meta_data.json")
    parser.add_argument("--no_chrome", action="store_true",
                        help="skip writing the Chrome-trace JSON variants")
    args = parser.parse_args()

    process_all_files(args.in_dir, args.out_dir, args.task_type, is_chrome=not args.no_chrome)