from __future__ import annotations

import dataclasses
import json
import math
from logging import getLogger
from pathlib import Path
from typing import Any

from bd_mcts.search_algo.base import Result, SearchAlgo, Trial
from bd_mcts.token_sequence import TokenSequence

logger = getLogger(__name__)


@dataclasses.dataclass
class MultiModelMCTSConfig:
    """Settings consumed by ``MultiModelMCTS.__init__``."""

    # Model names; each one becomes a selectable action at every tree node.
    # Must be non-empty (``MultiModelMCTS`` raises ``ValueError`` otherwise).
    models: list[str]
    # Mask count assumed for the root state and the scale applied to
    # ``demask_schedule`` fractions. Must be positive.
    gen_length: int
    # Starting budget; ``tell`` subtracts each result's ``num_func_evals``.
    # ``ask`` reports it as ``remaining_func_evals`` (or -1 once it is <= 0).
    num_func_eval_budget: int
    # ``MultiModelMCTS`` always runs full rollouts: ``False`` is overridden
    # to ``True`` with a warning.
    full_rollout: bool = True
    # Enables lookup (``ask``) and storage (``tell``) of cached rollouts.
    enable_rollout_cache: bool = True
    # Fractions of ``gen_length``; each is converted to a mask-count target
    # via ceil and clamped to [0, gen_length]. Must be non-empty.
    demask_schedule: list[float] = dataclasses.field(
        default_factory=lambda: [0.9, 0.8, 0.7, 0.6, 0.5, 0.4, 0.2]
    )
    # Lower bound on tokens demasked per step (capped at the masks remaining).
    min_unmask_num: int = 3
    # Multiplier of the exploration term in the UCB child score.
    exploration_const: float = 1.0
    # Selection attempts per ``ask`` call before the tree is reset.
    max_selection_attempts: int = 20


@dataclasses.dataclass(eq=False)
class MCTSNode:
    """One state in the search tree.

    ``eq=False`` keeps identity-based equality and hashing, so nodes can be
    used as dictionary keys (the exporters in this module rely on that).
    """

    # State reached at this node; ``None`` for the root.
    token_sequence: TokenSequence | None
    # Set by ``MultiModelMCTS.tell`` once the node is terminal.
    clean_pred: str | None
    parent: "MCTSNode | None"
    depth: int
    # Metadata about the rollout/evaluation that created (or last updated) this node.
    incoming_action: str | None = None
    trial_id: str | None = None
    rollout_reward: float | None = None
    rollout_clean_pred: str | None = None
    rollout_num_func_evals: int | None = None
    rollout_token_seqs: list[TokenSequence] | None = None
    rollout_final_token_len: int | None = None
    rollout_token_len_history: list[int] | None = None
    rollout_token_len_with_mask_history: list[int] | None = None
    demask_step_indices: list[int] | None = None
    demask_mask_counts: list[int] | None = None
    # Child nodes keyed by the action (model name) that produced them.
    children: dict[str, "MCTSNode"] = dataclasses.field(default_factory=dict)
    # Backup statistics.
    visit_count: int = 0
    value_sum: float = 0.0
    # True once nothing below this node can be expanded any further.
    is_exhausted: bool = False

    def is_terminal(self) -> bool:
        """Return True when the node has a token sequence with no masks left."""
        seq = self.token_sequence
        if seq is None:
            return False
        return seq.num_masks <= 0

    def is_root(self) -> bool:
        """Return True when the node has no parent."""
        return self.parent is None

    @property
    def value_mean(self) -> float:
        """Average backed-up reward; 0.0 for a node that was never visited."""
        return self.value_sum / self.visit_count if self.visit_count > 0 else 0.0


@dataclasses.dataclass(frozen=True)
class CachedRollout:
    """Immutable record of the remainder of a previously observed rollout.

    Entries are created by ``MultiModelMCTS._store_rollout_cache`` and turned
    into a pre-filled ``Result`` by ``MultiModelMCTS.ask``.
    """

    # State that follows the cached (state, action) key.
    new_token_seq: TokenSequence
    # Rollout states from ``new_token_seq`` (inclusive) to the end of the path.
    rollout_token_seqs: list[TokenSequence]
    # Prediction reported for the rollout the entry was taken from.
    clean_pred: str
    # Value reported for the whole rollout (not re-sliced per entry).
    final_token_len: int | None = None
    # Tail of the length histories aligned with ``rollout_token_seqs``;
    # ``None`` when the source history did not match the path length.
    token_len_history: list[int] | None = None
    token_len_with_mask_history: list[int] | None = None


class MultiModelMCTS(SearchAlgo):
    def __init__(self, config: MultiModelMCTSConfig) -> None:
        if not config.models:
            raise ValueError("MultiModelMCTS requires at least one model")
        if config.gen_length <= 0:
            raise ValueError("gen_length must be positive")
        if not config.demask_schedule:
            raise ValueError("demask_schedule must be non-empty")

        self.actions = list(config.models)
        self.gen_length = int(config.gen_length)
        self.num_func_eval_budget = int(config.num_func_eval_budget)
        if not config.full_rollout:
            logger.warning(
                "MultiModelMCTS always uses full rollout; overriding full_rollout=False"
            )
        self.full_rollout = True
        self.enable_rollout_cache = bool(config.enable_rollout_cache)
        self.demask_schedule = list(config.demask_schedule)
        self.min_unmask_num = int(config.min_unmask_num)
        self.exploration_const = float(config.exploration_const)
        self.max_selection_attempts = int(config.max_selection_attempts)

        self.schedule_targets = self._build_schedule_targets()

        self.root = MCTSNode(
            token_sequence=None,
            clean_pred=None,
            parent=None,
            depth=0,
        )
        self.pending_trials: dict[str, tuple[MCTSNode, str]] = {}
        self.results: list[tuple[str, float]] = []
        self.trial_id = 0

        self.rollout_cache: dict[tuple[tuple, str], CachedRollout] = {}

    def next_trial_id(self) -> str:
        self.trial_id += 1
        return str(self.trial_id)

    def ask(self) -> Trial:
        for _ in range(self.max_selection_attempts):
            selection = self._select_node_action()
            if selection is None:
                continue

            node, action, num_tokens_to_demask, target_idx = selection
            trial_id = self.next_trial_id()
            self.pending_trials[trial_id] = (node, action)

            remaining_budget = (
                self.num_func_eval_budget if self.num_func_eval_budget > 0 else -1
            )
            parent_token_seq = node.token_sequence if not node.is_root() else None

            cached = (
                self._get_cached_rollout(node, action)
                if self.enable_rollout_cache
                else None
            )
            cached_result = None
            rollout_targets = None
            if cached is not None:
                cached_result = Result(
                    trial_id=trial_id,
                    new_token_seq=cached.new_token_seq,
                    clean_pred=cached.clean_pred,
                    num_func_evals=0,
                    reward=0.0,
                    rollout_token_seqs=cached.rollout_token_seqs,
                    rollout_clean_preds=None,
                    final_token_len=cached.final_token_len,
                    token_len_history=cached.token_len_history,
                    token_len_with_mask_history=cached.token_len_with_mask_history,
                )
            else:
                rollout_targets = self._rollout_mask_targets(target_idx)

            return Trial(
                trial_id=trial_id,
                parent_token_seq=parent_token_seq,
                num_tokens_to_demask=num_tokens_to_demask,
                remaining_func_evals=remaining_budget,
                full_rollout=self.full_rollout,
                action=action,
                rollout_mask_targets=rollout_targets,
                cached_result=cached_result,
            )

        logger.warning(
            "ask failed to find an expandable node. will reset the tree and retry."
        )
        self._reset_tree()
        return self.ask()

    def tell(self, result: Result) -> None:
        pending = self.pending_trials.pop(result.trial_id, None)
        if pending is None:
            raise RuntimeError(
                f"trial_id {result.trial_id} not found in pending trials"
            )
        parent, action = pending

        self.num_func_eval_budget -= result.num_func_evals

        child = parent.children.get(action)
        if child is None:
            child = MCTSNode(
                token_sequence=result.new_token_seq,
                clean_pred=None,
                parent=parent,
                depth=parent.depth + 1,
                incoming_action=action,
                trial_id=result.trial_id,
                rollout_reward=float(result.reward),
                rollout_clean_pred=result.clean_pred,
                rollout_num_func_evals=int(result.num_func_evals),
            )
            parent.children[action] = child
        else:
            child.token_sequence = result.new_token_seq
            child.incoming_action = action
            child.trial_id = result.trial_id
            child.rollout_reward = float(result.reward)
            child.rollout_clean_pred = result.clean_pred
            child.rollout_num_func_evals = int(result.num_func_evals)
            child.rollout_token_seqs = None
            child.rollout_final_token_len = None
            child.rollout_token_len_history = None
            child.rollout_token_len_with_mask_history = None
            child.demask_step_indices = None
            child.demask_mask_counts = None

        if result.rollout_token_seqs:
            child.rollout_token_seqs = list(result.rollout_token_seqs)
        child.rollout_final_token_len = result.final_token_len
        if result.token_len_history is not None:
            child.rollout_token_len_history = list(result.token_len_history)
        if result.token_len_with_mask_history is not None:
            child.rollout_token_len_with_mask_history = list(
                result.token_len_with_mask_history
            )
        if result.demask_step_indices is not None:
            child.demask_step_indices = list(result.demask_step_indices)
        if result.demask_mask_counts is not None:
            child.demask_mask_counts = list(result.demask_mask_counts)

        if child.is_terminal():
            child.clean_pred = result.clean_pred
            child.is_exhausted = True

        if self.enable_rollout_cache:
            rollout_path = self._normalize_rollout_path(result)
            if rollout_path:
                self._store_rollout_cache(
                    parent,
                    action,
                    result.clean_pred,
                    rollout_path,
                    result.final_token_len,
                    result.token_len_history,
                    result.token_len_with_mask_history,
                )

        self._backup(child, result.reward)
        self.results.append((result.clean_pred, result.reward))

    def top_k(self, k: int) -> list[tuple[str, float]]:
        return sorted(self.results, key=lambda x: x[1], reverse=True)[:k]

    def render_tree_dot(
        self,
        output_path: str | Path,
        *,
        max_depth: int | None = None,
        max_nodes: int | None = None,
        include_clean_pred: bool = False,
        correct_reward_threshold: float = 1.0,
    ) -> None:
        render_multimodel_mcts_tree_dot(
            self.root,
            output_path,
            gen_length=self.gen_length,
            max_depth=max_depth,
            max_nodes=max_nodes,
            include_clean_pred=include_clean_pred,
            correct_reward_threshold=correct_reward_threshold,
        )

    def export_tree_json(
        self,
        output_path: str | Path,
        *,
        max_depth: int | None = None,
        max_nodes: int | None = None,
        include_token_sequences: bool = True,
        include_state_text: bool = True,
        state_text_max_chars: int | None = 600,
        include_rollouts: bool = True,
        include_rollout_texts: bool = True,
        rollout_text_max_chars: int | None = 600,
        include_rollout_token_sequences: bool = False,
        include_full_preds: bool = False,
        pred_preview_chars: int | None = 160,
        correct_reward_threshold: float = 1.0,
    ) -> None:
        export_multimodel_mcts_tree_json(
            self.root,
            output_path,
            gen_length=self.gen_length,
            max_depth=max_depth,
            max_nodes=max_nodes,
            include_token_sequences=include_token_sequences,
            include_state_text=include_state_text,
            state_text_max_chars=state_text_max_chars,
            include_rollouts=include_rollouts,
            include_rollout_texts=include_rollout_texts,
            rollout_text_max_chars=rollout_text_max_chars,
            include_rollout_token_sequences=include_rollout_token_sequences,
            include_full_preds=include_full_preds,
            pred_preview_chars=pred_preview_chars,
            correct_reward_threshold=correct_reward_threshold,
        )

    def _build_schedule_targets(self) -> list[int]:
        targets: list[int] = []
        for frac in self.demask_schedule:
            try:
                value = float(frac)
            except (TypeError, ValueError) as exc:
                raise ValueError(
                    f"demask_schedule entries must be numbers, got {frac!r}"
                ) from exc
            masks = int(math.ceil(self.gen_length * value))
            masks = max(0, min(self.gen_length, masks))
            targets.append(masks)
        return targets

    def _resolve_demask_plan(self, node: MCTSNode) -> tuple[int, int] | None:
        if not self.schedule_targets:
            return None
        if node.is_root() or node.token_sequence is None:
            parent_masks = self.gen_length
        else:
            parent_masks = int(node.token_sequence.num_masks)
        parent_depth = node.depth
        if parent_depth >= len(self.schedule_targets):
            return None

        target_idx = None
        target_masks = None
        for idx in range(parent_depth, len(self.schedule_targets)):
            candidate = self.schedule_targets[idx]
            if parent_masks > candidate:
                target_idx = idx
                target_masks = candidate
                break

        if target_idx is None or target_masks is None:
            return None

        num_tokens_to_demask = parent_masks - target_masks
        if num_tokens_to_demask < self.min_unmask_num:
            num_tokens_to_demask = min(parent_masks, self.min_unmask_num)
        if num_tokens_to_demask <= 0:
            return None
        return num_tokens_to_demask, target_idx

    def _rollout_mask_targets(self, target_idx: int) -> list[int] | None:
        if not self.full_rollout:
            return None
        if not self.schedule_targets:
            return None
        if target_idx + 1 >= len(self.schedule_targets):
            return None
        return list(self.schedule_targets[target_idx + 1 :])

    def _select_node_action(
        self,
    ) -> tuple[MCTSNode, str, int, int] | None:
        node = self.root
        while True:
            if node.is_exhausted:
                return None
            if node.is_terminal():
                self._mark_exhausted(node)
                return None

            plan = self._resolve_demask_plan(node)
            if plan is None:
                self._mark_exhausted(node)
                return None
            num_tokens_to_demask, target_idx = plan

            unexpanded = [a for a in self.actions if a not in node.children]
            if unexpanded:
                action = self._select_action(node, unexpanded)
                return node, action, num_tokens_to_demask, target_idx

            candidates = [
                child for child in node.children.values() if not child.is_exhausted
            ]
            if not candidates:
                self._mark_exhausted(node)
                return None
            node = self._select_child_ucb(node, candidates)

    def _select_action(self, node: MCTSNode, actions: list[str]) -> str:
        if self.enable_rollout_cache:
            cached = [
                action for action in actions if self._get_cached_rollout(node, action)
            ]
            candidates = cached if cached else actions
            return sorted(candidates)[0]
        return sorted(actions)[0]

    def _select_child_ucb(self, parent: MCTSNode, children: list[MCTSNode]) -> MCTSNode:
        parent_visits = max(1, parent.visit_count)
        best_child = children[0]
        best_score = -float("inf")
        for child in children:
            child_visits = max(1, child.visit_count)
            uct_score = child.value_mean + self.exploration_const * math.sqrt(
                math.log(parent_visits) / child_visits
            )
            if uct_score > best_score:
                best_score = uct_score
                best_child = child
        return best_child

    def _backup(self, node: MCTSNode, reward: float) -> None:
        current = node
        while current is not None:
            current.visit_count += 1
            current.value_sum += reward
            current = current.parent

    def _mark_exhausted(self, node: MCTSNode) -> None:
        node.is_exhausted = True
        current = node.parent
        while current is not None:
            if current.children and all(
                child.is_exhausted for child in current.children.values()
            ):
                current.is_exhausted = True
            else:
                break
            current = current.parent

    def _reset_tree(self) -> None:
        self.root = MCTSNode(
            token_sequence=None,
            clean_pred=None,
            parent=None,
            depth=0,
        )
        self.pending_trials.clear()

    def _get_cached_rollout(
        self, node: MCTSNode, action: str
    ) -> CachedRollout | None:
        key = (self._state_key(node.depth, node.token_sequence), action)
        return self.rollout_cache.get(key)

    def _store_rollout_cache(
        self,
        parent: MCTSNode,
        action: str,
        clean_pred: str,
        rollout_path: list[TokenSequence],
        final_token_len: int | None,
        token_len_history: list[int] | None,
        token_len_with_mask_history: list[int] | None,
    ) -> None:
        if len(rollout_path) < 2:
            return
        for idx in range(len(rollout_path) - 1):
            depth = parent.depth + 1 + idx
            state_key = self._state_key(depth, rollout_path[idx])
            cached_history = None
            if token_len_history is not None and len(token_len_history) == len(
                rollout_path
            ):
                cached_history = token_len_history[idx + 1 :]
            cached_with_mask_history = None
            if token_len_with_mask_history is not None and len(
                token_len_with_mask_history
            ) == len(rollout_path):
                cached_with_mask_history = token_len_with_mask_history[idx + 1 :]
            cached = CachedRollout(
                new_token_seq=rollout_path[idx + 1],
                rollout_token_seqs=rollout_path[idx + 1 :],
                clean_pred=clean_pred,
                final_token_len=final_token_len,
                token_len_history=cached_history,
                token_len_with_mask_history=cached_with_mask_history,
            )
            self.rollout_cache[(state_key, action)] = cached

    def _normalize_rollout_path(self, result: Result) -> list[TokenSequence]:
        path = list(result.rollout_token_seqs or [])
        if not path:
            return [result.new_token_seq]
        if self._token_sequence_key(path[0]) != self._token_sequence_key(
            result.new_token_seq
        ):
            path.insert(0, result.new_token_seq)
        return path

    def _state_key(self, depth: int, seq: TokenSequence | None) -> tuple:
        return (depth, self._token_sequence_key(seq))

    def _token_sequence_key(self, seq: TokenSequence | None) -> tuple:
        if seq is None:
            return ("<root>",)
        segments = tuple(
            (segment.kind, segment.content, int(segment.repetition))
            for segment in seq.segments
        )
        prompt = tuple(
            (msg.get("role", ""), msg.get("content", ""))
            for msg in seq.prompt_msgs
        )
        return (prompt, segments, int(seq.gen_length))


def _truncate_text(text: str, limit: int) -> str:
    if limit <= 0:
        return ""
    if len(text) <= limit:
        return text
    if limit <= 3:
        return text[:limit]
    return text[: limit - 3] + "..."


def _dot_escape(text: str) -> str:
    safe = text.replace("\r", " ").replace("\n", " ").replace("\t", " ")
    safe = safe.encode("ascii", "backslashreplace").decode("ascii")
    safe = safe.replace("\\", "\\\\").replace('"', '\\"')
    return safe


def _format_node_label(
    node: MCTSNode,
    node_id: int,
    *,
    gen_length: int | None,
    include_clean_pred: bool,
    correct_reward_threshold: float = 1.0,
) -> str:
    """Build the multi-line DOT label describing ``node``.

    The label lists ``key=value`` entries (id, depth, child count, visit
    statistics, optional trial/reward/mask details and status flags). Each
    entry is DOT-escaped and the entries are joined with a literal ``\\n``,
    which DOT renders as a line break.
    """
    seq = node.token_sequence
    reward = node.rollout_reward

    entries = [
        f"id={node_id}",
        f"depth={node.depth}",
        f"children={len(node.children)}",
        f"visits={node.visit_count}",
        f"value={node.value_mean:.3f}",
    ]
    if node.trial_id is not None:
        entries.append(f"trial={node.trial_id}")
    if reward is not None:
        entries.append(f"reward={reward:.3f}")
        if reward >= float(correct_reward_threshold):
            entries.append("correct=1")
    if node.rollout_num_func_evals is not None:
        entries.append(f"nfe={int(node.rollout_num_func_evals)}")

    if seq is not None:
        entries.append(f"masks={int(seq.num_masks)}")
        entries.append(f"mask_frac={seq.mask_fraction:.3f}")
    elif gen_length is not None:
        # A node without a sequence is reported as fully masked.
        entries.append(f"masks={gen_length}")

    if node.is_terminal():
        entries.append("terminal=1")
    if node.is_exhausted:
        entries.append("exhausted=1")

    if include_clean_pred:
        prediction = node.clean_pred or node.rollout_clean_pred
        if prediction:
            entries.append(f"pred={_truncate_text(prediction, 80)}")

    return "\\n".join(map(_dot_escape, entries))


def render_multimodel_mcts_tree_dot(
    root: MCTSNode,
    output_path: str | Path,
    *,
    gen_length: int | None = None,
    max_depth: int | None = None,
    max_nodes: int | None = None,
    include_clean_pred: bool = False,
    correct_reward_threshold: float = 1.0,
) -> None:
    """Write the tree below ``root`` to ``output_path`` as a Graphviz DOT file.

    Nodes are visited depth-first with children taken in ascending action
    order and numbered in visiting order. Nodes deeper than ``max_depth`` are
    skipped, and once ``max_nodes`` nodes have been emitted no further nodes
    are added. Missing parent directories of ``output_path`` are created.
    """

    def node_budget_reached() -> bool:
        return max_nodes is not None and len(assigned_ids) >= max_nodes

    assigned_ids: dict[MCTSNode, int] = {}
    node_stmts: list[str] = []
    edge_stmts: list[str] = []

    # Entries: (node, id of the parent, action leading to the node, depth).
    todo: list[tuple[MCTSNode, int | None, str | None, int]] = [
        (root, None, None, 0)
    ]
    while todo:
        node, parent_id, via_action, depth = todo.pop()
        if max_depth is not None and depth > max_depth:
            continue

        node_id = assigned_ids.get(node)
        if node_id is None:
            if node_budget_reached():
                continue
            # Ids are handed out consecutively, so the next id is the count so far.
            node_id = len(assigned_ids)
            assigned_ids[node] = node_id
            label = _format_node_label(
                node,
                node_id,
                gen_length=gen_length,
                include_clean_pred=include_clean_pred,
                correct_reward_threshold=correct_reward_threshold,
            )
            node_stmts.append(f'  n{node_id} [label="{label}"];')

        if parent_id is not None:
            edge_label = _dot_escape(via_action or "")
            edge_stmts.append(
                f'  n{parent_id} -> n{node_id} [label="{edge_label}"];'
            )

        if node_budget_reached():
            continue

        # Push in descending order so the stack pops children in ascending order.
        for child_action in sorted(node.children, reverse=True):
            todo.append((node.children[child_action], node_id, child_action, depth + 1))

    document = [
        "digraph MultiModelMCTS {",
        "  rankdir=TB;",
        '  node [shape=box, fontname="Courier"];',
        '  edge [fontname="Courier"];',
        *node_stmts,
        *edge_stmts,
        "}",
    ]

    target = Path(output_path)
    target.parent.mkdir(parents=True, exist_ok=True)
    target.write_text("\n".join(document) + "\n", encoding="utf-8")


def _compact_segments(seq: TokenSequence) -> list[dict[str, Any]]:
    compact: list[dict[str, Any]] = []
    for segment in seq.segments:
        kind = getattr(segment, "kind", "")
        content = getattr(segment, "content", "")
        repetition = int(getattr(segment, "repetition", 1) or 1)
        if (
            compact
            and compact[-1]["kind"] == kind
            and compact[-1]["content"] == content
        ):
            compact[-1]["repetition"] += repetition
        else:
            compact.append(
                {"kind": kind, "content": content, "repetition": repetition}
            )
    return compact


def _format_masked_text_from_compact_segments(
    segments: list[dict[str, Any]],
) -> str:
    parts: list[str] = []
    for seg in segments:
        kind = str(seg.get("kind", ""))
        if kind == "text":
            parts.append(str(seg.get("content", "")))
            continue
        if kind == "mask":
            rep = int(seg.get("repetition", 1) or 1)
            if rep == 1:
                parts.append("<mask>")
            else:
                parts.append(f"<mask>*{rep}")
            continue
        # Ignore pad/eos/eot for visualization/export.
    return "".join(parts)


def _token_sequence_to_dict(seq: TokenSequence) -> dict[str, Any]:
    """Return a JSON-serializable summary of ``seq``.

    ``token_count`` and ``mask_positions`` are ``None`` when the sequence
    does not provide them; ``segments`` holds the compacted segment list.
    """
    segments = _compact_segments(seq)

    summary: dict[str, Any] = {"gen_length": int(seq.gen_length)}
    generated = seq.generated_token_count
    summary["token_count"] = None if generated is None else int(generated)
    summary["num_masks"] = int(seq.num_masks)
    summary["mask_fraction"] = float(seq.mask_fraction)
    positions = seq.mask_positions
    summary["mask_positions"] = None if positions is None else list(positions)
    summary["segments"] = segments
    return summary


def export_multimodel_mcts_tree_json(
    root: MCTSNode,
    output_path: str | Path,
    *,
    gen_length: int | None = None,
    max_depth: int | None = None,
    max_nodes: int | None = None,
    include_token_sequences: bool = True,
    include_state_text: bool = True,
    state_text_max_chars: int | None = 600,
    include_rollouts: bool = True,
    include_rollout_texts: bool = True,
    rollout_text_max_chars: int | None = 600,
    include_rollout_token_sequences: bool = False,
    include_full_preds: bool = False,
    pred_preview_chars: int | None = 160,
    correct_reward_threshold: float = 1.0,
) -> None:
    """Write the tree below ``root`` to ``output_path`` as a JSON document.

    The document has the keys ``format``, ``gen_length``,
    ``correct_reward_threshold``, ``nodes``, ``edges`` and ``stats``. Nodes
    are visited depth-first with children in ascending action order and are
    numbered in visiting order. Nodes deeper than ``max_depth`` are skipped,
    and once ``max_nodes`` nodes have been emitted no further nodes are added.

    A node without a token sequence (the root) is described as fully masked
    when ``gen_length`` is given. The ``*_max_chars`` / ``pred_preview_chars``
    limits truncate the corresponding texts; ``None`` disables truncation.
    Missing parent directories of ``output_path`` are created.
    """

    def clip(text: str, limit: int | None) -> str:
        # ``None`` disables truncation; negative limits behave like 0.
        if limit is None:
            return text
        return _truncate_text(text, max(0, int(limit)))

    def token_count_of(ts: Any) -> int:
        # Fall back to the nominal generation length when no count is recorded.
        if ts.generated_token_count is not None:
            return int(ts.generated_token_count)
        return int(ts.gen_length)

    def mask_positions_of(ts: Any) -> list[Any] | None:
        if ts.mask_positions is not None:
            return list(ts.mask_positions)
        return None

    def masked_text_of(ts: Any) -> str:
        return _format_masked_text_from_compact_segments(_compact_segments(ts))

    def rollout_record(node: Any) -> dict[str, Any] | None:
        # One snapshot per rollout state, plus the rollout-level metadata.
        if not (include_rollouts and node.rollout_token_seqs):
            return None
        lengths = node.rollout_token_len_history or []
        lengths_with_masks = node.rollout_token_len_with_mask_history or []
        snapshots: list[dict[str, Any]] = []
        for idx, snap_seq in enumerate(node.rollout_token_seqs):
            snapshot: dict[str, Any] = {
                "idx": int(idx),
                "token_count": token_count_of(snap_seq),
                "masks": int(snap_seq.num_masks),
                "mask_fraction": float(snap_seq.mask_fraction),
                "mask_positions": mask_positions_of(snap_seq),
            }
            if idx < len(lengths):
                snapshot["token_len"] = int(lengths[idx])
            if idx < len(lengths_with_masks):
                snapshot["token_len_with_masks"] = int(lengths_with_masks[idx])
            if include_rollout_texts:
                snapshot["masked_text"] = clip(
                    masked_text_of(snap_seq), rollout_text_max_chars
                )
            if include_rollout_token_sequences:
                snapshot["token_sequence"] = _token_sequence_to_dict(snap_seq)
            snapshots.append(snapshot)
        return {
            "final_token_len": node.rollout_final_token_len,
            "token_len_history": node.rollout_token_len_history,
            "token_len_with_mask_history": node.rollout_token_len_with_mask_history,
            "demask_step_indices": node.demask_step_indices,
            "demask_mask_counts": node.demask_mask_counts,
            "clean_pred": node.rollout_clean_pred,
            "snapshots": snapshots,
        }

    def node_record(node: Any, node_id: int, parent_id: int | None) -> dict[str, Any]:
        seq = node.token_sequence

        masks = mask_fraction = token_count = mask_positions = None
        if seq is not None:
            masks = int(seq.num_masks)
            mask_fraction = float(seq.mask_fraction)
            token_count = token_count_of(seq)
            mask_positions = mask_positions_of(seq)
        elif gen_length is not None:
            # No sequence yet: every one of the gen_length positions is masked.
            masks = int(gen_length)
            mask_fraction = 1.0 if gen_length else None
            token_count = int(gen_length)
            mask_positions = list(range(int(gen_length)))

        pred = node.clean_pred or node.rollout_clean_pred
        preview = clip(pred, pred_preview_chars) if pred else None

        state_text_preview = None
        if include_state_text:
            if seq is not None:
                state_text_preview = clip(masked_text_of(seq), state_text_max_chars)
            elif gen_length is not None:
                state_text_preview = f"<mask>*{int(gen_length)}"

        rollout_obj = rollout_record(node)

        record: dict[str, Any] = {
            "id": node_id,
            "depth": int(node.depth),
            "parent_id": parent_id,
            "incoming_action": node.incoming_action,
            "visit_count": int(node.visit_count),
            "value_sum": float(node.value_sum),
            "value_mean": float(node.value_mean),
            "is_terminal": bool(node.is_terminal()),
            "is_exhausted": bool(node.is_exhausted),
            "masks": masks,
            "mask_fraction": mask_fraction,
            "token_count": token_count,
            "mask_positions": mask_positions,
            "trial_id": node.trial_id,
            "rollout_reward": node.rollout_reward,
            "rollout_num_func_evals": node.rollout_num_func_evals,
            "pred_preview": preview,
            "pred": pred if include_full_preds else None,
            "state_masked_text": state_text_preview,
            "rollout": rollout_obj,
        }
        if node.rollout_reward is not None:
            record["is_correct"] = bool(
                float(node.rollout_reward) >= float(correct_reward_threshold)
            )
        if include_token_sequences and seq is not None:
            record["token_sequence"] = _token_sequence_to_dict(seq)
        return record

    def node_budget_reached() -> bool:
        return max_nodes is not None and len(assigned_ids) >= max_nodes

    assigned_ids: dict[MCTSNode, int] = {}
    nodes_out: list[dict[str, Any]] = []
    edges_out: list[dict[str, Any]] = []

    # Entries: (node, id of the parent, action leading to the node, depth).
    todo: list[tuple[MCTSNode, int | None, str | None, int]] = [
        (root, None, None, 0)
    ]
    while todo:
        node, parent_id, via_action, depth = todo.pop()
        if max_depth is not None and depth > max_depth:
            continue

        node_id = assigned_ids.get(node)
        if node_id is None:
            if node_budget_reached():
                continue
            # Ids are handed out consecutively, so the next id is the count so far.
            node_id = len(assigned_ids)
            assigned_ids[node] = node_id
            nodes_out.append(node_record(node, node_id, parent_id))

        if parent_id is not None:
            edges_out.append(
                {
                    "source": int(parent_id),
                    "target": int(node_id),
                    "action": via_action or "",
                }
            )

        if node_budget_reached():
            continue

        # Push in descending order so the stack pops children in ascending order.
        for child_action in sorted(node.children, reverse=True):
            todo.append((node.children[child_action], node_id, child_action, depth + 1))

    # The earliest emitted node wins among equal best rewards.
    best_reward = None
    best_node_id = None
    for rec in nodes_out:
        raw_reward = rec.get("rollout_reward")
        if raw_reward is None:
            continue
        candidate = float(raw_reward)
        if best_reward is None or candidate > best_reward:
            best_reward, best_node_id = candidate, int(rec["id"])

    output = {
        "format": "bd-mcts.multimodel_mcts.tree.v3",
        "gen_length": int(gen_length) if gen_length is not None else None,
        "correct_reward_threshold": float(correct_reward_threshold),
        "nodes": nodes_out,
        "edges": edges_out,
        "stats": {
            "total_nodes": len(nodes_out),
            "total_edges": len(edges_out),
            "best_reward": best_reward,
            "best_node_id": best_node_id,
        },
    }

    target = Path(output_path)
    target.parent.mkdir(parents=True, exist_ok=True)
    target.write_text(
        json.dumps(output, indent=2, ensure_ascii=True) + "\n", encoding="utf-8"
    )
