#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Construct a DAG over partitioned reasoning steps using an OpenAI-compatible API.
"""

from __future__ import annotations

import json
import os
import random
import re
import time
from dataclasses import dataclass
from string import Template
from typing import Any, Dict, Iterable, List, Optional, Tuple
from concurrent.futures import ThreadPoolExecutor, as_completed

from rich.console import Console
from rich.progress import Progress, SpinnerColumn, BarColumn, TextColumn, TimeElapsedColumn, TimeRemainingColumn

from openai import OpenAI


# Prompt template sent to the LLM once per target step.
# Placeholders (string.Template syntax) filled in by build_dag via safe_substitute:
#   $current_id             - index of the step being classified
#   $last_previous_step_id  - last step id on the current main path
#   $input_steps            - text block produced by build_input_block
#   $available_steps        - comma-separated ids of the current main path
#   $leaf_steps             - comma-separated ids of other open branch leaves, or "(none)"
# The expected reply carries an <|action|> marker and, for backtrack/merge, a
# <|previous|> marker; both are extracted by ACTION_RE / PREV_RE.
DAG_PROMPT = Template(
    """
### Goal
You will receive a sequence of reasoning steps.
For exactly one target step **s$current_id**, you must decide how it connects to earlier steps.

You must classify this connection as exactly ONE of:
- continue
- backtrack
- merge

### Input Format
The input block contains only:
  - s0
  - steps whose IDs are on the current available main path
  - the current step s$current_id
  - and (in a separate section) a small set of OTHER_OPEN_BRANCH_LEAVES, which are leaf steps from other branches

### Output Format
For the single target step s$current_id, output exactly:

1. **One short line explanation** (1-2 lines) of what s$current_id is doing or how it relates to earlier steps. Explain whether it is a normal continuation, a backtrack/restart from an earlier step, or is merging branches.

2. A line of the exact form:
   <|action|>continue
   OR
   <|action|>backtrack
   OR
   <|action|>merge

3. If (and only if) the action is:
   - backtrack:
        You MUST also output a single parent marker line in the exact form
        <|previous|>K
        where:
        - K is an integer step ID **strictly less than $current_id.
        - K must be chosen from the steps on the current main path in the Input Block.
        - K must be strictly earlier than $last_previous_step_id.
        - Never invent new or unseen step IDs.

   - merge:
        You MUST also output a parent marker line in the exact form
        <|previous|>K1,K2,...
        where:
        - K1,K2,... are two or more integer step IDs.
        - All IDs must be **strictly less than $current_id.
        - At least one of these IDs must come from OTHER_OPEN_BRANCH_LEAVES.
        - Never invent new or unseen step IDs.

   - continue:
        You DO NOT output a <|previous|>... line. (It is implicitly understood that the parent is $last_previous_step_id.)

### Input Block
$input_steps

### Previous Step Selection Rules

- If s$current_id logically continues from the most recent step on the current main path, classify as **continue**. (No <|previous|> line is needed in this case.)

- Otherwise, treat as **backtrack** if s$current_id is restarting / revisiting / re-checking an earlier point on the SAME main path.
  In this case you must output <|action|>backtrack and then <|previous|>K for that earlier step K.

- Rarely, if s$current_id is explicitly combining reasoning from more than one branch (for example, it is stitching together partial arguments from two different leaf steps, at least one of which is from OTHER_OPEN_BRANCH_LEAVES), classify as **merge**.
  In this case you must output <|action|>merge and then <|previous|>K1,K2,... with two or more parent IDs.

### Backtracking

If s$current_id appears to the following patterns, it is likely to **backtrack to K < $last_previous_step_id** in the available steps:
1. The steps before step K contain sufficient information to derive s$current_id, so that the sequence of previous steps before K + s$current_id can be seen as (part of) a logical flow.
2. The steps after step K contain extensive reasoning information that is similar to s$current_id, making them redundant or repetitive.
3. Step K should be selected as large as possible (i.e., the nearest valid earlier step).
4. s$current_id may show reflection, restart, or repetition (RRR) patterns, such as: "Alternatively...", "Let me check again...", etc.

### Merging

If s$current_id is actively combining reasoning from multiple branches, it is likely to **merge**:
1. At least 1 from other branches and at most 1 from the main path must be included.
2. Only applies when earlier reasoning actually branched: multiple sibling branches explored different possibilities at roughly the same stage, and now s$current_id is explicitly stitching those partial results together to form a single next conclusion. In that case, s$current_id should inherit from BOTH of those branch tips.
4. Merging is rare and should only be used when there is clear evidence of combining distinct prior lines of reasoning.

Note:
- If s$current_id lacks explicit reference to the content of earlier steps from other branches, it should generally NOT be classified as a **merge**. If s$current_id is combining results from different nodes in the main path only, treat it as **continue** or **backtrack**.
- Even if the sentence expresses a contrastive meaning (e.g., shows transition or “however”-like tone), if it continues a sequential analysis and does not repeat prior reasoning, treat it as a normal **continue**.
- If the last previous step is an RRR step that carries only the RRR meaning without actual analysis, and it **initiated a new branch**, then the current s$current_id is likely following that RRR lead into the new branch and should be treated as a normal **continue**.
- If s$current_id does not explicitly indicate that the current step is based on certain steps from other branches, do NOT use **merge**.
- Even if s$current_id states it continues from an earlier step, if the main path already contains the information mentioned in other branches, use **backtrack** instead of **merge**.
- **continue** and **backtrack** is common. While in doubt, prefer **continue**.

Available main-path steps (in order): $available_steps  
Available other-branch leaf steps: $leaf_steps

Output:
Explanation text here (CONCISE, 1-2 lines)
<|action|>...
[optional if backtrack (exactly 1 previous step) or merge (≥ 2 previous steps, at least 1 from other branches and at most 1 from the main path)]
<|previous|>...
""".strip()
)


@dataclass
class OpenAIConfig:
    """Connection and sampling settings for the OpenAI-compatible chat API."""

    # API key handed to the OpenAI client constructor.
    api_key: str
    # Optional endpoint override; build_client omits it when it is falsy.
    base_url: Optional[str]
    # Model name sent with every chat-completion request.
    model: str
    # Sampling parameters forwarded to chat.completions.create.
    temperature: float = 0.3
    top_p: float = 0.9
    max_tokens: int = 512
    # Baseline number of completions per request; build_dag raises it on later rounds.
    n: int = 1
    # Client timeout used when the client is constructed (see build_client).
    timeout: Optional[float] = None


@dataclass
class DagParams:
    """Tuning knobs for the per-step LLM voting loop in build_dag."""

    # build_dag runs max(0, regen_limit - 1) + 1 LLM rounds per step at most.
    regen_limit: int = 5
    # Number of trailing main-path steps rendered into the prompt's input block.
    main_path_cap: int = 8
    # Number of trailing other-branch leaf steps rendered into the input block.
    other_leaf_cap: int = 5
    # Seconds passed to time.sleep between consecutive LLM rounds of one step.
    sleep_between_retries: float = 0.5
    # When true, candidates whose parent list does not fit their action are rejected.
    strict_mode: bool = False
    # When set, ties between equally voted parents are broken with
    # random.Random(random_seed + step index); otherwise the first tied entry wins.
    random_seed: Optional[int] = None
    # Upper bound on the `n` requested from the backend in a single call.
    max_n_per_request: int = 4


# Marker patterns for the LLM reply. Both accept the canonical `<|tag|>value`
# form and the looser `|tag|: value` form, case-insensitively.
# ACTION_RE captures the action word following the marker.
ACTION_RE = re.compile(r"(?:<\|action\|>|\|action\|\s*:)\s*(\w+)", re.IGNORECASE)
# PREV_RE captures the comma/space separated id list following the marker.
# The capture is restricted to digits, commas, spaces and tabs so it stops at
# the end of the line: `\s` would also match newlines and would swallow digits
# that merely start the following line (e.g. "<|previous|>3\n2 reasons ...").
PREV_RE = re.compile(r"(?:<\|previous\|>|\|previous\|\s*:)\s*([\d, \t]+)", re.IGNORECASE)


def build_client(config: OpenAIConfig) -> OpenAI:
    """Create an OpenAI client from *config*.

    ``base_url`` and ``timeout`` are forwarded only when they are set.
    Passing ``timeout=None`` explicitly would replace the SDK's built-in
    default timeout with "no timeout", so an unset value is left to the SDK.

    Args:
        config: API key plus optional endpoint and timeout.

    Returns:
        A configured ``OpenAI`` client instance.
    """
    client_kwargs: Dict[str, Any] = {"api_key": config.api_key}
    if config.base_url:
        client_kwargs["base_url"] = config.base_url
    if config.timeout is not None:
        client_kwargs["timeout"] = config.timeout
    return OpenAI(**client_kwargs)


# Single rich console shared by log output and the batch progress bar.
_CONSOLE = Console()


def _log(msg: str) -> None:
    """Print *msg* through the shared rich console (rich markup is interpreted)."""
    _CONSOLE.print(msg)


def _format_step_block(ids: List[int], steps: List[str]) -> str:
    lines = []
    for idx in ids:
        if 0 <= idx < len(steps):
            lines.append(f"s{idx}: {steps[idx]}")
    return "\n".join(lines)


def build_input_block(
    steps: List[str],
    current_id: int,
    main_path: List[int],
    other_leaves: List[int],
    main_path_cap: int,
    other_leaf_cap: int,
) -> str:
    """Assemble the text block that is substituted into the prompt.

    The block has three sections: ``MAIN_PATH`` (the last ``main_path_cap``
    ids of *main_path*), ``CURRENT_STEP`` (the step *current_id*) and
    ``OTHER_OPEN_BRANCH_LEAVES`` (the last ``other_leaf_cap`` ids of
    *other_leaves*, or ``(none)`` when that selection is empty).

    Note: the caps are applied with ``seq[-cap:]``, so a cap of 0 selects the
    whole list rather than nothing.
    """
    shown_main = main_path[-main_path_cap:] if main_path else []
    shown_leaves = other_leaves[-other_leaf_cap:] if other_leaves else []

    if shown_leaves:
        leaves_text = _format_step_block(shown_leaves, steps)
    else:
        leaves_text = "(none)"

    sections = [
        "MAIN_PATH:",
        _format_step_block(shown_main, steps),
        "\nCURRENT_STEP:",
        _format_step_block([current_id], steps),
        "\nOTHER_OPEN_BRANCH_LEAVES:",
        leaves_text,
    ]
    return "\n".join(sections)


@dataclass
class StepBlock:
    """One parsed LLM reply for a single target step."""

    # Lower-cased action word: "continue", "backtrack" or "merge".
    action: str
    # Parent step ids taken from the <|previous|> marker (empty when absent).
    prev_list: List[int]
    # Explanation line extracted from the reply.
    explanation: str
    # The unmodified reply text as returned by the model.
    raw_text: str


def _strip_code_fences(text: str) -> str:
    text = text.strip()
    if text.startswith("```") and text.endswith("```"):
        lines = text.splitlines()
        if lines and lines[0].startswith("```"):
            lines = lines[1:]
        if lines and lines[-1].startswith("```"):
            lines = lines[:-1]
        return "\n".join(lines).strip()
    return text


def parse_llm_output(text: str) -> Tuple[Optional[StepBlock], List[str]]:
    """Parse one LLM reply into a :class:`StepBlock`.

    Args:
        text: Raw reply text; a surrounding code fence is tolerated.

    Returns:
        ``(block, errors)``. On success ``block`` is populated and ``errors``
        is empty. On failure ``block`` is ``None`` and ``errors`` holds one
        message: non-string input, empty output, or a missing/unknown action.
    """
    if not isinstance(text, str):
        return None, ["non-string output"]

    cleaned = _strip_code_fences(text)
    lines = [ln.strip() for ln in cleaned.splitlines() if ln.strip()]
    if not lines:
        return None, ["empty output"]

    action_match = ACTION_RE.search(cleaned)
    action = action_match.group(1).lower() if action_match else None
    if action not in {"continue", "backtrack", "merge"}:
        return None, [f"invalid action: {action}"]

    prev_match = PREV_RE.search(cleaned)
    prev_list: List[int] = []
    if prev_match:
        prev_list = [int(x) for x in re.findall(r"\d+", prev_match.group(1))]

    def _is_marker_line(ln: str) -> bool:
        # Mirror what the marker regexes accept: either tag, in the `<|tag|>`
        # or the `|tag|` form, in any letter case. Otherwise a marker line
        # such as "<|ACTION|>merge" or "|previous|: 3" would be returned as
        # the explanation.
        lowered = ln.lower()
        return (
            "<|action|>" in lowered
            or "<|previous|>" in lowered
            or lowered.startswith(("|action|", "|previous|"))
        )

    # First non-marker line is the explanation; if the reply consists of
    # marker lines only, fall back to its first line.
    explanation = next((ln for ln in lines if not _is_marker_line(ln)), lines[0])

    return StepBlock(action=action, prev_list=prev_list, explanation=explanation, raw_text=text), []


def call_llm_for_step(
    client: OpenAI,
    openai_config: OpenAIConfig,
    prompt_text: str,
    n: int,
) -> Tuple[List[str], Dict[str, int]]:
    """Send *prompt_text* as a single user message and collect the completions.

    Args:
        client: Client exposing ``chat.completions.create``.
        openai_config: Supplies model, temperature, top_p and max_tokens.
        prompt_text: Fully substituted prompt.
        n: Requested number of completions; values that cannot be converted
            to an int, or that are below 1, are replaced by 1.

    Returns:
        ``(outputs, usage)`` where ``outputs`` holds the non-empty message
        contents and ``usage`` holds prompt/completion/total token counts
        (an empty dict when the response carries no usage information).
    """
    # Some OpenAI-compatible backends enforce a small upper bound on `n`.
    # Keep this function defensive to avoid hard failures on backend-specific limits.
    try:
        n = int(n)
    except (TypeError, ValueError, OverflowError):
        n = 1
    n = max(1, n)

    response = client.chat.completions.create(
        model=openai_config.model,
        messages=[{"role": "user", "content": prompt_text}],
        temperature=openai_config.temperature,
        top_p=openai_config.top_p,
        max_tokens=openai_config.max_tokens,
        n=n,
    )

    # A response without choices is treated as "no outputs" so the caller's
    # retry loop can run, instead of failing with a TypeError here.
    choices = getattr(response, "choices", None) or []
    outputs: List[str] = []
    for choice in choices:
        msg = getattr(choice, "message", None)
        if msg and getattr(msg, "content", None):
            outputs.append(msg.content)

    usage: Dict[str, int] = {}
    usage_obj = getattr(response, "usage", None)
    if usage_obj is not None:
        # Missing or None counters are reported as 0.
        usage = {
            "prompt_tokens": int(getattr(usage_obj, "prompt_tokens", 0) or 0),
            "completion_tokens": int(getattr(usage_obj, "completion_tokens", 0) or 0),
            "total_tokens": int(getattr(usage_obj, "total_tokens", 0) or 0),
        }
    return outputs, usage


def _compute_depth(children_map: Dict[int, List[int]], root_id: int = 0) -> Dict[int, int]:
    depth: Dict[int, int] = {root_id: 0}
    queue: List[int] = [root_id]
    for node in list(children_map.keys()):
        depth.setdefault(node, 0)
    while queue:
        cur = queue.pop(0)
        for nxt in children_map.get(cur, []):
            cand = depth[cur] + 1
            if nxt not in depth or cand < depth[nxt]:
                depth[nxt] = cand
                queue.append(nxt)
    return depth


def _graph_to_jsonable(
    parents_map: Dict[int, List[int]],
    children_map: Dict[int, List[int]],
    depth_map: Dict[int, int],
    explanations: Dict[int, str],
    node_texts: Dict[int, str],
    actions: Dict[int, str],
    errors: List[str],
) -> Dict[str, Any]:
    return {
        "parents": {str(k): v for k, v in parents_map.items()},
        "children": {str(k): v for k, v in children_map.items()},
        "depth": {str(k): int(v) for k, v in depth_map.items()},
        "explanations": {str(k): v for k, v in explanations.items()},
        "node_texts": {str(k): v for k, v in node_texts.items()},
        "actions": {str(k): v for k, v in actions.items()},
        "errors": errors,
    }


def _build_merged_view(
    dag_json: Dict[str, Any],
    step_texts: List[str],
) -> Dict[str, Any]:
    """Collapse linear runs of "continue" steps into merged nodes.

    Args:
        dag_json: Graph in the string-keyed form produced by
            _graph_to_jsonable; the "parents", "children" and "actions"
            entries are read.
        step_texts: Raw step texts indexed by raw node id.

    Returns:
        A dict describing the merged graph: the merged nodes with the raw ids
        they contain, the raw-to-merged id mapping, and parents / children /
        depth / node_texts / explanations / names keyed by merged id (as
        strings). "actions" and "errors" are always empty.
    """
    # Convert the string-keyed JSON form back to int-keyed maps.
    parents = {int(k): list(map(int, v)) for k, v in (dag_json.get("parents") or {}).items()}
    children = {int(k): list(map(int, v)) for k, v in (dag_json.get("children") or {}).items()}
    actions = {int(k): str(v) for k, v in (dag_json.get("actions") or {}).items()}

    # An empty graph is treated as a single root node.
    if not parents and not children:
        parents = {0: []}
        children = {0: []}
        actions = {0: "root"}

    all_nodes = sorted(set(parents.keys()) | set(children.keys()))
    if not all_nodes:
        all_nodes = [0]

    # In-degree and out-degree of every raw node.
    indeg = {u: len(parents.get(u, [])) for u in all_nodes}
    outdeg_total = {u: len(children.get(u, [])) for u in all_nodes}

    def is_continue_edge(u: int, v: int) -> bool:
        # True when v is a "continue" node whose single parent is u.
        return actions.get(v) == "continue" and indeg.get(v, 0) == 1 and parents.get(v, [None])[0] == u

    cont_children = {u: [v for v in children.get(u, []) if is_continue_edge(u, v)] for u in all_nodes}
    assigned = set()
    chains: List[List[int]] = []

    def is_chain_head(u: int) -> bool:
        # A node starts a chain unless it is a single-parent "continue" node
        # whose parent has exactly one child.
        if indeg.get(u, 0) != 1 or actions.get(u) != "continue":
            return True
        p = parents[u][0]
        return not (outdeg_total.get(p, 0) == 1)

    # Grow a chain from every head by following the unique continue-child
    # while the current node has exactly one child in total.
    for u in all_nodes:
        if u in assigned or not is_chain_head(u):
            continue
        chain = [u]
        assigned.add(u)
        cur = u
        while True:
            cands = cont_children.get(cur, [])
            if outdeg_total.get(cur, 0) != 1 or len(cands) != 1:
                break
            v = cands[0]
            chain.append(v)
            assigned.add(v)
            if outdeg_total.get(v, 0) != 1:
                break
            cur = v
        chains.append(chain)

    # Any node not picked up above becomes its own single-node chain.
    for u in all_nodes:
        if u not in assigned:
            chains.append([u])
            assigned.add(u)

    # Merged ids are assigned in order of each chain's last raw id.
    chains.sort(key=lambda seq: seq[-1])
    merged_id_of: Dict[int, int] = {}
    merged_nodes: List[Dict[str, Any]] = []
    for mid, seq in enumerate(chains):
        for r in seq:
            merged_id_of[r] = mid
        merged_nodes.append({"id": mid, "raw": seq[:], "last_raw": seq[-1]})

    # Project raw edges onto merged ids, dropping edges internal to a chain.
    parentsM: Dict[int, List[int]] = {m["id"]: [] for m in merged_nodes}
    childrenM: Dict[int, List[int]] = {m["id"]: [] for m in merged_nodes}
    for child_raw, p_list in parents.items():
        mc = merged_id_of[child_raw]
        for p in p_list:
            mp = merged_id_of[p]
            if mp == mc:
                continue
            if mp not in parentsM[mc]:
                parentsM[mc].append(mp)
            if mc not in childrenM[mp]:
                childrenM[mp].append(mc)

    # Deduplicate and sort the adjacency lists.
    for d in (parentsM, childrenM):
        for k in d:
            d[k] = sorted(list(dict.fromkeys(d[k])))

    root_mid = merged_id_of.get(0, 0)
    depthM = _compute_depth(childrenM, root_mid)

    node_textsM: Dict[int, str] = {}
    explanationsM: Dict[int, str] = {}
    for mnode in merged_nodes:
        mid = mnode["id"]
        seq = mnode["raw"]
        header = "[MergedNode] contains raw: " + ", ".join(f"s{r}" for r in seq)
        body_parts = []
        for r in seq:
            raw_txt = step_texts[r] if 0 <= r < len(step_texts) else ""
            # NOTE(review): "\\n" here and below is a literal backslash + "n",
            # not a newline character - confirm this is what consumers expect.
            body_parts.append(f"── s{r} ──\\n{raw_txt}".rstrip())
        node_textsM[mid] = header + "\\n\\n" + "\\n\\n".join(body_parts)
        explanationsM[mid] = f"Merged tail = s{seq[-1]}"

    # Display names "s<rank>", ranked by each merged node's last raw id.
    mids_sorted = sorted([m["id"] for m in merged_nodes], key=lambda m: next(x["last_raw"] for x in merged_nodes if x["id"] == m))
    merged_names: Dict[int, str] = {}
    for rank, mid in enumerate(mids_sorted):
        merged_names[mid] = f"s{rank}"

    return {
        "root_merged": root_mid,
        "merged_nodes": [{"id": m["id"], "raw": m["raw"], "last_raw": m["last_raw"]} for m in merged_nodes],
        "raw_to_merged": {str(r): int(mid) for r, mid in merged_id_of.items()},
        "parents": {str(k): v for k, v in parentsM.items()},
        "children": {str(k): v for k, v in childrenM.items()},
        "depth": {str(k): int(v) for k, v in depthM.items()},
        "node_texts": {str(k): v for k, v in node_textsM.items()},
        "explanations": {str(k): v for k, v in explanationsM.items()},
        "names": {str(k): v for k, v in merged_names.items()},
        "actions": {},
        "errors": [],
    }


def build_dag(
    prompt: str,
    steps: List[str],
    openai_config: OpenAIConfig,
    dag_params: DagParams,
) -> Dict[str, Any]:
    """Build a DAG over *steps* by asking the LLM how each step attaches.

    Step 0 is the root. For every later step the LLM is queried (possibly in
    several rounds), the valid replies are majority-voted into one action
    ("continue", "backtrack" or "merge") with its parent id(s), and the graph
    plus the running main path and leaf set are updated accordingly.

    Args:
        prompt: Opaque value copied into the result under "prompt".
        steps: Step texts; index i is node i.
        openai_config: Client and sampling settings.
        dag_params: Voting/retry settings.

    Returns:
        A dict with nodes, edges, parents, children, actions, raw_outputs,
        the raw and merged graph views, collected parse errors and
        accumulated token usage.
    """
    client = build_client(openai_config)

    n_steps = len(steps)
    # Graph state, pre-seeded with the root node 0.
    parents: Dict[int, List[int]] = {0: []}
    children: Dict[int, List[int]] = {0: []}
    actions: Dict[int, str] = {0: "root"}
    explanations: Dict[int, str] = {0: "root"}
    node_texts: Dict[int, str] = {0: steps[0] if steps else ""}
    raw_outputs: Dict[int, str] = {0: ""}
    errors: List[str] = []
    usage_all = {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0, "calls": 0}

    # main_path: ids from the root to the most recently attached step.
    # leaf_set: ids that currently have no children.
    main_path: List[int] = [0]
    leaf_set = {0}
    base_n = max(1, int(openai_config.n or 1))
    # Number of EXTRA rounds after the first one.
    attempts_max = max(0, int(dag_params.regen_limit) - 1)
    max_n_per_request = max(1, int(getattr(dag_params, "max_n_per_request", 4) or 4))

    _log(f"[bold cyan]DAG[/bold cyan] start: {n_steps} steps")
    for idx in range(1, n_steps):
        _log(f"[green]Step {idx}/{n_steps - 1}[/green] building...")
        last_previous_step_id = main_path[-1]
        input_block = build_input_block(
            steps,
            current_id=idx,
            main_path=main_path,
            other_leaves=sorted(leaf_set - set(main_path)),
            main_path_cap=dag_params.main_path_cap,
            other_leaf_cap=dag_params.other_leaf_cap,
        )
        # The id lists below are NOT capped, unlike the rendered input block.
        available_steps = ",".join(str(i) for i in main_path)
        other_leaves = sorted(leaf_set - set(main_path))
        leaf_steps = ",".join(str(i) for i in other_leaves) or "(none)"
        prompt_text = DAG_PROMPT.safe_substitute(
            current_id=idx,
            input_steps=input_block,
            available_steps=available_steps,
            leaf_steps=leaf_steps,
            last_previous_step_id=last_previous_step_id,
        )

        action = None
        prev_list: List[int] = []
        raw_text = ""
        explanation = ""
        # Valid candidates accumulate across all rounds of this step.
        aggregated_valid: List[StepBlock] = []

        for attempt in range(attempts_max + 1):
            # Ask for two more samples each round, bounded by the per-request cap.
            n_target = base_n + 2 * attempt
            n_this = min(max_n_per_request, n_target)
            _log(
                f"  [yellow]LLM round[/yellow] {attempt + 1}/{attempts_max + 1} "
                f"(n={n_this}, target={n_target}, cap={max_n_per_request})"
            )
            outputs, usage = call_llm_for_step(client, openai_config, prompt_text, n=n_this)
            usage_all["prompt_tokens"] += usage.get("prompt_tokens", 0)
            usage_all["completion_tokens"] += usage.get("completion_tokens", 0)
            usage_all["total_tokens"] += usage.get("total_tokens", 0)
            usage_all["calls"] += 1

            for candidate in outputs:
                block, errs = parse_llm_output(candidate)
                if errs:
                    errors.extend(errs)
                if block is None:
                    continue
                cand_action = block.action
                # Keep only parent ids that refer to earlier steps.
                cand_prev = [p for p in block.prev_list if isinstance(p, int) and 0 <= p < idx]

                valid = True
                if cand_action == "backtrack":
                    # Parent must be on the main path and strictly before its tail;
                    # of several, the one latest on the main path is kept.
                    backtrack_candidates = [p for p in cand_prev if p in main_path and p < last_previous_step_id]
                    valid = len(backtrack_candidates) >= 1
                    if valid:
                        cand_prev = [max(backtrack_candidates, key=lambda x: main_path.index(x))]
                elif cand_action == "merge":
                    if len(cand_prev) < 2:
                        valid = False
                    else:
                        # Needs at least one other-branch leaf and at most one main-path id.
                        uniq_prev = list(dict.fromkeys(cand_prev))
                        main_hits = [p for p in uniq_prev if p in main_path]
                        leaf_hits = [p for p in uniq_prev if p in other_leaves]
                        valid = len(leaf_hits) >= 1 and len(main_hits) <= 1
                        if valid:
                            cand_prev = uniq_prev

                if dag_params.strict_mode:
                    # Strict mode: the parent list must match the action's arity.
                    if cand_action == "continue" and cand_prev:
                        valid = False
                    if cand_action == "backtrack" and len(cand_prev) != 1:
                        valid = False
                    if cand_action == "merge" and len(cand_prev) < 2:
                        valid = False

                if valid:
                    aggregated_valid.append(StepBlock(action=cand_action, prev_list=cand_prev, explanation=block.explanation, raw_text=block.raw_text))

            if aggregated_valid:
                # Stop early once one action strictly leads the vote.
                action_counts: Dict[str, int] = {}
                for blk in aggregated_valid:
                    action_counts[blk.action] = action_counts.get(blk.action, 0) + 1
                ranked = sorted(action_counts.items(), key=lambda x: x[1], reverse=True)
                if len(ranked) == 1 or ranked[0][1] > ranked[1][1]:
                    break

            if attempt < attempts_max:
                time.sleep(dag_params.sleep_between_retries)

        if not aggregated_valid:
            # No usable reply in any round: fall back to a plain continuation.
            errors.append(f"s{idx}: no valid candidates after {attempts_max + 1} rounds")
            action = "continue"
            prev_list = [last_previous_step_id]
            raw_text = ""
        else:
            action_counts: Dict[str, int] = {}
            for blk in aggregated_valid:
                action_counts[blk.action] = action_counts.get(blk.action, 0) + 1
            ranked = sorted(action_counts.items(), key=lambda x: x[1], reverse=True)
            if len(ranked) == 1 or ranked[0][1] > ranked[1][1]:
                action = ranked[0][0]
            else:
                # Tie on the top count: prefer continue, then backtrack, then merge.
                for pref in ["continue", "backtrack", "merge"]:
                    if pref in action_counts:
                        action = pref
                        break
            if action is None:
                action = "continue"

            if action == "continue":
                prev_list = [last_previous_step_id]
                explanation = next((b.explanation for b in aggregated_valid if b.action == "continue"), "")
            elif action == "backtrack":
                # Vote on the backtrack target among the backtrack candidates.
                ks = [b.prev_list[0] for b in aggregated_valid if b.action == "backtrack" and len(b.prev_list) == 1]
                if not ks:
                    action = "continue"
                    prev_list = [last_previous_step_id]
                    explanation = ""
                else:
                    counts: Dict[int, int] = {}
                    for k in ks:
                        counts[k] = counts.get(k, 0) + 1
                    maxc = max(counts.values())
                    tied = [k for k, v in counts.items() if v == maxc]
                    if dag_params.random_seed is not None:
                        # Seeded per step so the tie-break is reproducible.
                        rnd = random.Random(dag_params.random_seed + idx)
                        chosen = rnd.choice(tied)
                    else:
                        chosen = tied[0]
                    prev_list = [chosen]
                    explanation = next((b.explanation for b in aggregated_valid if b.action == "backtrack" and b.prev_list == [chosen]), "")
            else:  # merge
                # Vote on the (order-insensitive) parent set among merge candidates.
                sets = [tuple(sorted(b.prev_list)) for b in aggregated_valid if b.action == "merge" and len(b.prev_list) >= 2]
                if not sets:
                    action = "continue"
                    prev_list = [last_previous_step_id]
                    explanation = ""
                else:
                    counts: Dict[Tuple[int, ...], int] = {}
                    for s in sets:
                        counts[s] = counts.get(s, 0) + 1
                    maxc = max(counts.values())
                    tied = [list(s) for s, v in counts.items() if v == maxc]
                    if dag_params.random_seed is not None:
                        rnd = random.Random(dag_params.random_seed + idx)
                        prev_list = rnd.choice(tied)
                    else:
                        prev_list = tied[0]
                    explanation = next((b.explanation for b in aggregated_valid if b.action == "merge" and sorted(b.prev_list) == sorted(prev_list)), "")
            # Fall back to any candidate of the chosen action for the texts.
            if not explanation:
                explanation = next((b.explanation for b in aggregated_valid if b.action == action), "")
            raw_text = next((b.raw_text for b in aggregated_valid if b.action == action), "")

        # Apply the decided action to the main path and the leaf set.
        if action == "continue":
            parent = main_path[-1]
            prev_list = [parent]
        elif action == "backtrack":
            valid = [p for p in prev_list if p < idx and p in main_path]
            if not valid:
                action = "continue"
                parent = main_path[-1]
                prev_list = [parent]
            else:
                parent = max(valid, key=lambda x: main_path.index(x))
                prev_list = [parent]
                old_tail = main_path[-1]
                if old_tail != parent:
                    leaf_set.add(old_tail)
                # Cut the main path back to the backtrack target.
                main_path = main_path[: main_path.index(parent) + 1]
        else:  # merge
            valid = [p for p in prev_list if p < idx]
            if len(valid) < 2:
                action = "continue"
                parent = main_path[-1]
                prev_list = [parent]
            else:
                prev_list = list(dict.fromkeys(valid))
                main_candidates = [p for p in prev_list if p in main_path]
                if main_candidates:
                    # Cut the main path back to the latest merged main-path parent.
                    main_parent = max(main_candidates, key=lambda x: main_path.index(x))
                    main_path = main_path[: main_path.index(main_parent) + 1]
                else:
                    # NOTE(review): main_parent is not read after this point.
                    main_parent = main_path[-1]
                for p in prev_list:
                    leaf_set.discard(p)

        # Record the new node and its edges.
        parents[idx] = list(prev_list)
        children[idx] = []
        actions[idx] = action
        explanations[idx] = explanation
        node_texts[idx] = steps[idx] if idx < len(steps) else ""
        raw_outputs[idx] = raw_text
        _log(f"  [blue]action[/blue]={action} parents={prev_list}")

        for p in prev_list:
            children.setdefault(p, []).append(idx)
            leaf_set.discard(p)
        leaf_set.add(idx)

        # The new step always becomes the tail of the main path.
        if idx not in main_path:
            main_path.append(idx)

    for nid in parents.keys():
        children.setdefault(nid, [])

    depth_map = _compute_depth(children, 0)
    dag_raw = _graph_to_jsonable(parents, children, depth_map, explanations, node_texts, actions, errors)
    dag_merged = _build_merged_view(dag_raw, steps)

    nodes = [{"id": i, "text": steps[i]} for i in range(n_steps)]
    edges = [{"from": p, "to": i} for i, plist in parents.items() for p in plist]

    _log("[bold cyan]DAG[/bold cyan] done")
    return {
        "prompt": prompt,
        "nodes": nodes,
        "edges": edges,
        "parents": parents,
        "children": children,
        "actions": actions,
        "raw_outputs": raw_outputs,
        "dag_graph_raw": dag_raw,
        "dag_graph_merged": dag_merged,
        "dag_parse_errors": errors,
        "dag_usage": usage_all,
    }


def build_dag_batch(
    items: Iterable[Dict[str, Any]],
    openai_config: OpenAIConfig,
    dag_params: DagParams,
    num_threads: int = 4,
) -> List[Dict[str, Any]]:
    """Run build_dag for every item on a thread pool, showing a progress bar.

    Each item supplies "prompt" (default "") and "steps" (default []).
    Results are returned in input order. An exception raised while building
    any item propagates out of this function.
    """
    tasks = list(items)
    results: List[Optional[Dict[str, Any]]] = [None] * len(tasks)

    def _run_one(position: int, entry: Dict[str, Any]) -> Tuple[int, Dict[str, Any]]:
        # Carry the input position along so results can be stored in order.
        dag = build_dag(entry.get("prompt", ""), entry.get("steps", []), openai_config, dag_params)
        return position, dag

    columns = (
        SpinnerColumn(),
        TextColumn("{task.description}"),
        BarColumn(),
        TextColumn("{task.completed}/{task.total}"),
        TimeElapsedColumn(),
        TimeRemainingColumn(),
    )
    with Progress(*columns, console=_CONSOLE) as progress:
        bar = progress.add_task("Building DAG batch", total=len(tasks))
        with ThreadPoolExecutor(max_workers=num_threads) as pool:
            pending = [pool.submit(_run_one, pos, entry) for pos, entry in enumerate(tasks)]
            for finished in as_completed(pending):
                pos, dag = finished.result()
                results[pos] = dag
                progress.update(bar, advance=1)

    return [dag for dag in results if dag is not None]


def load_json(path: str) -> Any:
    """Read the UTF-8 encoded file at *path* and return its decoded JSON value."""
    with open(path, "r", encoding="utf-8") as handle:
        raw = handle.read()
    return json.loads(raw)


def main() -> None:
    """Build DAGs for every item in ``data.json`` and write ``dag_results.json``.

    Both files live next to this script. Configuration ("openai_config",
    "dag_params", "threading") is read from the FIRST item of the list only.
    The API key comes from ``openai_config.api_key`` or, failing that, from
    the environment variable named by ``openai_config.api_key_env``
    (default ``OPENAI_API_KEY``).

    Raises:
        ValueError: if data.json is not a non-empty list of JSON objects.
        RuntimeError: if no API key can be found.
    """
    here = os.path.dirname(__file__)
    data = load_json(os.path.join(here, "data.json"))
    if not isinstance(data, list):
        raise ValueError("data.json must contain a list of items")
    if not data:
        # Without this check data[0] below fails with a bare IndexError.
        raise ValueError("data.json must contain at least one item")

    first = data[0]
    if not isinstance(first, dict):
        raise ValueError("data.json items must be JSON objects")
    # `or {}` also covers sections that are present but null.
    openai_cfg = first.get("openai_config") or {}
    dag_cfg = first.get("dag_params") or {}
    thread_cfg = first.get("threading") or {}

    api_key_env = openai_cfg.get("api_key_env") or "OPENAI_API_KEY"
    api_key = openai_cfg.get("api_key") or os.getenv(api_key_env, "")
    if not api_key:
        raise RuntimeError("Missing OpenAI API key. Set OPENAI_API_KEY or provide api_key in data.json")

    # Coerce optional numerics here so a malformed value fails at startup
    # rather than later inside a worker thread.
    timeout = openai_cfg.get("timeout")
    random_seed = dag_cfg.get("random_seed")

    openai_config = OpenAIConfig(
        api_key=api_key,
        base_url=openai_cfg.get("base_url"),
        model=openai_cfg.get("model"),
        temperature=float(openai_cfg.get("temperature", 0.3)),
        top_p=float(openai_cfg.get("top_p", 0.9)),
        max_tokens=int(openai_cfg.get("max_tokens", 512)),
        n=int(openai_cfg.get("n", 1)),
        timeout=float(timeout) if timeout is not None else None,
    )
    dag_params = DagParams(
        regen_limit=int(dag_cfg.get("regen_limit", 5)),
        main_path_cap=int(dag_cfg.get("main_path_cap", 8)),
        other_leaf_cap=int(dag_cfg.get("other_leaf_cap", 5)),
        sleep_between_retries=float(dag_cfg.get("sleep_between_retries", 0.5)),
        strict_mode=bool(dag_cfg.get("strict_mode", False)),
        random_seed=int(random_seed) if random_seed is not None else None,
        max_n_per_request=int(dag_cfg.get("max_n_per_request", 4)),
    )
    num_threads = int(thread_cfg.get("num_threads", 4))

    results = build_dag_batch(data, openai_config, dag_params, num_threads=num_threads)
    out_path = os.path.join(here, "dag_results.json")
    with open(out_path, "w", encoding="utf-8") as fh:
        json.dump(results, fh, ensure_ascii=False, indent=2)

    print(f"Wrote {len(results)} DAG results to {out_path}")


# Run the batch job only when this file is executed as a script.
if __name__ == "__main__":
    main()
