from __future__ import annotations

import json
import math
import multiprocessing as mp
import os
import queue
import random
import shutil
import subprocess
from logging import getLogger
from concurrent.futures import ThreadPoolExecutor
from dataclasses import dataclass
from datetime import datetime
from pathlib import Path
from typing import Any, Iterable

import hydra
from joblib import Parallel, delayed, parallel_config
from omegaconf import DictConfig, OmegaConf

from bd_mcts.search_algo.abmcts import (
    ABMCTSD,
    ABMCTSDConfig,
    ABMCTSDParallel,
    render_abmcts_tree_html,
)
from bd_mcts.search_algo.base import Result, SearchAlgo
from bd_mcts.search_algo.best_of_n import BestOfN, BestOfNConfig
from bd_mcts.search_algo.dts import DTS, DTSConfig
from bd_mcts.search_algo.mask_strategy_mcts import (
    MaskStrategyMCTS,
    MaskStrategyMCTSConfig,
)
from bd_mcts.search_algo.temperature_mcts import (
    TemperatureMCTS,
    TemperatureMCTSConfig,
)
from bd_mcts.search_algo.multimodel_mcts import (
    MultiModelMCTS,
    MultiModelMCTSConfig,
)
from bd_mcts.tasks.base import ProbResult, Task
from bd_mcts.tasks.code_eval_task import (
    BigCodeBenchTask,
    CruxEvalTask,
    EvalPlusTask,
    make_code_eval_task,
)

logger = getLogger(__name__)


@dataclass(frozen=True)
class GenerationResult:
    """Immutable outcome of a single ``DemaskGenerator.generate`` call."""

    # Snapshot taken the first time the remaining mask count reaches the demask
    # target; falls back to the final sequence when the target is never reached.
    new_token_seq: TokenSequence
    # Decoded text of the final tokens (cut at the first stop id, mask/pad skipped).
    lm_response: str
    # Number of demask steps actually executed during the call.
    num_func_evals: int
    # Snapshots collected during a full rollout; None when full_rollout is off.
    rollout_token_seqs: list[TokenSequence] | None
    # Generated-token count of the final output, excluding mask and pad tokens.
    final_token_len: int | None = None
    # Per-snapshot token counts for a full rollout (mask/pad excluded).
    token_len_history: list[int] | None = None
    # Per-snapshot token counts for a full rollout, counting mask tokens too.
    token_len_with_mask_history: list[int] | None = None
    # Executed step indices; only populated when step tracing is enabled.
    demask_step_indices: list[int] | None = None
    # Mask count left after each traced step (parallel to demask_step_indices).
    demask_mask_counts: list[int] | None = None


@dataclass(frozen=True)
class ModelHandle:
    """Immutable bundle of per-model objects and token ids used by generation."""

    # Model name stored with the handle.
    name: str
    # Tokenizer used to tokenize sequences and decode generated ids.
    tokenizer: Any
    # Base demasker; it must expose `predict_x0` and `generation_config`.
    demasker: Any
    # Maximum token length passed to tokenization; None disables truncation.
    max_seq_len: int | None
    # Device on which the input token tensor is created.
    device: Any
    # Token id that marks a still-masked position.
    mask_token_id: int
    # Padding token id; skipped when counting or decoding generated tokens.
    pad_token_id: int
    # End-of-sequence token id; treated as a stop token.
    eos_token_id: int
    # Additional stop token ids, combined with eos_token_id during generation.
    eot_token_ids: tuple[int, ...]


class DemaskGenerator:
    def __init__(
        self,
        handles: dict[str, ModelHandle],
        *,
        gen_length: int,
        system_prompt: str | None,
        debug_dir: Path | None = None,
        trace_demask_steps: bool = False,
    ) -> None:
        self._handles = handles
        self._gen_length = gen_length
        self._system_prompt = system_prompt
        self._debug_dir = debug_dir
        self._debug_fp = None
        self._trace_demask_steps = bool(trace_demask_steps)
        self._mask_strategy_demaskers: dict[tuple[str, str], Any] = {}
        self._temperature_demaskers: dict[tuple[str, str, float], Any] = {}

    def _resolve_demasker_for_mask_strategy(
        self, *, model_name: str, mask_strategy: str | None
    ) -> Any:
        handle = self._handles[model_name]
        if not mask_strategy:
            return handle.demasker

        key = (model_name, mask_strategy)
        cached = self._mask_strategy_demaskers.get(key)
        if cached is not None:
            return cached

        demasker = handle.demasker

        try:
            from bd_mcts.demask_interface.dream import Dream
        except Exception:  # pragma: no cover - optional dependency paths
            Dream = None  # type: ignore[assignment]

        if Dream is not None and isinstance(demasker, Dream):
            cfg = demasker.generation_config
            if getattr(cfg, "alg", None) == mask_strategy:
                return demasker

            override = Dream(
                demasker.model,
                steps=cfg.steps,
                alg=mask_strategy,
                alg_temp=cfg.alg_temp,
                eos_penalty=cfg.eos_penalty,
                temperature=cfg.temperature,
                top_p=cfg.top_p,
                top_k=cfg.top_k,
                max_length=cfg.max_length,
                max_new_tokens=cfg.max_new_tokens,
                eps=cfg.eps,
                mask_token_id=cfg.mask_token_id,
                pad_token_id=cfg.pad_token_id,
                bos_token_id=cfg.bos_token_id,
                eos_token_id=cfg.eos_token_id,
            )
            self._mask_strategy_demaskers[key] = override
            return override

        try:
            from bd_mcts.demask_interface.llada import Llada
        except Exception:  # pragma: no cover - optional dependency paths
            Llada = None  # type: ignore[assignment]

        if Llada is not None and isinstance(demasker, Llada):
            cfg = demasker.generation_config
            if getattr(cfg, "alg", None) == mask_strategy:
                return demasker

            override = Llada(
                demasker.model,
                steps=cfg.steps,
                gen_length=cfg.gen_length,
                block_length=cfg.block_length,
                alg=mask_strategy,
                temperature=cfg.temperature,
                cfg_scale=cfg.cfg_scale,
                max_new_tokens=cfg.max_new_tokens,
                eos_penalty=cfg.eos_penalty,
                eot_confidence_penalty=cfg.eot_confidence_penalty,
                mask_token_id=cfg.mask_token_id,
                pad_token_id=cfg.pad_token_id,
                bos_token_id=cfg.bos_token_id,
                eos_token_id=cfg.eos_token_id,
                eot_token_ids=cfg.eot_token_ids,
            )
            self._mask_strategy_demaskers[key] = override
            return override

        raise ValueError(
            f"mask_strategy overrides are not supported for demasker={type(demasker).__name__}"
        )

    def _resolve_demasker_for_temperature(
        self,
        demasker: Any,
        *,
        model_name: str,
        demask_temperature: float | None,
    ) -> Any:
        if demask_temperature is None:
            return demasker
        try:
            temperature = float(demask_temperature)
        except (TypeError, ValueError) as exc:
            raise ValueError(
                f"demask_temperature must be a number, got {demask_temperature!r}"
            ) from exc
        if temperature < 0.0:
            raise ValueError(
                f"demask_temperature must be >= 0, got {temperature}"
            )

        cfg = getattr(demasker, "generation_config", None)
        current = getattr(cfg, "temperature", None)
        if current is not None:
            try:
                if float(current) == temperature:
                    return demasker
            except (TypeError, ValueError):
                pass

        alg = str(getattr(cfg, "alg", ""))
        key = (model_name, alg, temperature)
        cached = self._temperature_demaskers.get(key)
        if cached is not None:
            return cached

        try:
            from bd_mcts.demask_interface.dream import Dream
        except Exception:  # pragma: no cover - optional dependency paths
            Dream = None  # type: ignore[assignment]

        if Dream is not None and isinstance(demasker, Dream):
            override = Dream(
                demasker.model,
                steps=cfg.steps,
                alg=cfg.alg,
                alg_temp=cfg.alg_temp,
                eos_penalty=cfg.eos_penalty,
                temperature=temperature,
                top_p=cfg.top_p,
                top_k=cfg.top_k,
                max_length=cfg.max_length,
                max_new_tokens=cfg.max_new_tokens,
                eps=cfg.eps,
                mask_token_id=cfg.mask_token_id,
                pad_token_id=cfg.pad_token_id,
                bos_token_id=cfg.bos_token_id,
                eos_token_id=cfg.eos_token_id,
            )
            self._temperature_demaskers[key] = override
            return override

        try:
            from bd_mcts.demask_interface.llada import Llada
        except Exception:  # pragma: no cover - optional dependency paths
            Llada = None  # type: ignore[assignment]

        if Llada is not None and isinstance(demasker, Llada):
            override = Llada(
                demasker.model,
                steps=cfg.steps,
                gen_length=cfg.gen_length,
                block_length=cfg.block_length,
                alg=cfg.alg,
                temperature=temperature,
                cfg_scale=cfg.cfg_scale,
                max_new_tokens=cfg.max_new_tokens,
                eos_penalty=cfg.eos_penalty,
                eot_confidence_penalty=cfg.eot_confidence_penalty,
                mask_token_id=cfg.mask_token_id,
                pad_token_id=cfg.pad_token_id,
                bos_token_id=cfg.bos_token_id,
                eos_token_id=cfg.eos_token_id,
                eot_token_ids=cfg.eot_token_ids,
            )
            self._temperature_demaskers[key] = override
            return override

        raise ValueError(
            f"temperature overrides are not supported for demasker={type(demasker).__name__}"
        )

    def _debug_write(self, line: str) -> None:
        if self._debug_dir is None:
            return
        if self._debug_fp is None:
            self._debug_dir.mkdir(parents=True, exist_ok=True)
            path = self._debug_dir / f"demask_debug_pid{os.getpid()}.log"
            self._debug_fp = path.open("a", encoding="utf-8")
            self._debug_fp.write(
                f"# demask debug log started pid={os.getpid()} time={datetime.now().isoformat(timespec='seconds')}\n"
            )
            self._debug_fp.flush()
        ts = datetime.now().isoformat(timespec="seconds")
        self._debug_fp.write(f"[{ts}] {line}\n")
        self._debug_fp.flush()

    def generate(
        self,
        *,
        prompt: str,
        parent_seq: TokenSequence | None,
        num_tokens_to_demask: int,
        full_rollout: bool,
        model_name: str,
        mask_strategy: str | None = None,
        demask_temperature: float | None = None,
        rollout_mask_targets: list[int] | None = None,
        remaining_func_evals: int | None = None,
    ) -> GenerationResult:
        """Run iterative demasking from ``parent_seq`` (or a fresh sequence).

        Each iteration calls ``demasker.predict_x0`` and then ``remasker.step``.
        The first state whose mask count is at or below
        ``initial_masks - num_tokens_to_demask`` is captured as the new
        sequence. Without ``full_rollout`` the loop stops there and the
        response is decoded from that step's ``x0``; with ``full_rollout`` the
        loop continues until no masks remain or the step range is exhausted.

        Diagnostics are controlled by the environment variables
        ``BD_MCTS_DEBUG_DEMASK`` and ``BD_MCTS_DEMASK_ALLOW_MASK_TOKEN``; in
        this method the latter only enables the same diagnostics and log-file
        writes.

        Args:
            prompt: User prompt text.
            parent_seq: Sequence to continue from; None starts from a fully
                fresh diffusion input of the configured generation length.
            num_tokens_to_demask: Number of masks to remove before capturing
                the new sequence.
            full_rollout: Whether to keep demasking after the capture point.
            model_name: Key of the model handle to use.
            mask_strategy: Optional ``alg`` override for the demasker.
            demask_temperature: Optional temperature override for the demasker.
            rollout_mask_targets: Mask-count thresholds at which snapshots are
                recorded during a full rollout.
            remaining_func_evals: Optional cap on the number of steps executed.

        Returns:
            A ``GenerationResult`` describing the captured sequence, the decoded
            response, the number of steps executed and rollout bookkeeping.
        """
        def _env_flag(name: str) -> bool:
            # Truthy values are 1/true/yes/y/on, case-insensitive.
            value = os.environ.get(name, "")
            return value.strip().lower() in ("1", "true", "yes", "y", "on")

        debug_enabled = _env_flag("BD_MCTS_DEBUG_DEMASK")
        allow_mask_token = _env_flag("BD_MCTS_DEMASK_ALLOW_MASK_TOKEN")
        handle = self._handles[model_name]
        prompt_msgs = _build_prompt_msgs(prompt, self._system_prompt)

        # Local import; elsewhere in the visible part of this file the name is
        # only used in annotations.
        from bd_mcts.token_sequence import TokenSequence

        if parent_seq is None:
            base_seq = TokenSequence.init_diffusion_input(
                prompt_msgs, gen_length=self._gen_length
            )
        else:
            base_seq = parent_seq

        # Only reported in the debug header below.
        base_masks = int(getattr(base_seq, "num_masks", 0))

        prompt_len, token_ids = _tokenize_sequence(
            base_seq, handle.tokenizer, handle.max_seq_len
        )
        # Token counting and decoding stop at EOS or any end-of-turn token.
        stop_ids = (handle.eos_token_id, *handle.eot_token_ids)

        def _count_len_from_tokens(tokens: list[int]) -> int:
            return _count_generated_tokens(
                tokens,
                stop_ids=stop_ids,
                mask_token_id=handle.mask_token_id,
                pad_token_id=handle.pad_token_id,
            )

        def _count_len_with_masks_from_tokens(tokens: list[int]) -> int:
            return _count_generated_tokens_with_masks(
                tokens,
                stop_ids=stop_ids,
                pad_token_id=handle.pad_token_id,
            )

        def _count_len_from_state(state: Any) -> int:
            return _count_len_from_tokens(state[0, prompt_len:].tolist())

        def _count_len_with_masks_from_state(state: Any) -> int:
            return _count_len_with_masks_from_tokens(state[0, prompt_len:].tolist())

        # Size of the generation region as actually tokenized.
        gen_length = len(token_ids) - prompt_len
        # Apply the optional overrides in order: mask strategy, temperature,
        # then the LLaDA generation-length adjustment.
        demasker = self._resolve_demasker_for_mask_strategy(
            model_name=model_name, mask_strategy=mask_strategy
        )
        demasker = self._resolve_demasker_for_temperature(
            demasker,
            model_name=model_name,
            demask_temperature=demask_temperature,
        )
        demasker = _resolve_llada_demasker(demasker, gen_length)
        x = _to_tensor(token_ids, handle.device)
        steps = _resolve_demask_steps(demasker)

        initial_masks = _count_masks(x, handle.mask_token_id)
        # Mask count at which the new sequence is captured, clamped at zero.
        target_masks = max(0, initial_masks - num_tokens_to_demask)

        # Sanitize rollout thresholds: keep non-negative ints in input order,
        # dropping non-numeric values and duplicates.
        rollout_targets: list[int] = []
        if full_rollout and rollout_mask_targets:
            seen = set()
            for target in rollout_mask_targets:
                try:
                    value = int(target)
                except (TypeError, ValueError):
                    continue
                if value < 0 or value in seen:
                    continue
                rollout_targets.append(value)
                seen.add(value)
        # When continuing from a parent, the start step is derived from the
        # current mask count instead of starting at step 0.
        start_step_idx = 0
        if parent_seq is not None:
            start_step_idx = _resolve_demask_start_step(
                demasker,
                initial_masks=initial_masks,
                gen_length=gen_length,
            )

        # Cap the step range by the remaining budget; a non-numeric or negative
        # budget is ignored.
        max_step = steps
        if remaining_func_evals is not None:
            try:
                remaining_steps = int(remaining_func_evals)
            except (TypeError, ValueError):
                remaining_steps = -1
            if remaining_steps >= 0:
                max_step = min(steps, start_step_idx + remaining_steps)

        if debug_enabled or allow_mask_token:
            labels = []
            if mask_strategy is not None:
                labels.append(f"mask_strategy={mask_strategy}")
            if demask_temperature is not None:
                labels.append(f"demask_temperature={demask_temperature}")
            strategy_label = f" {' '.join(labels)}" if labels else ""
            header = (
                f"[demask] begin model={model_name}{strategy_label} prompt_len={prompt_len} "
                f"gen_len={gen_length} base_masks={base_masks} initial_masks={initial_masks} "
                f"target_masks={target_masks} steps={steps} start_step={start_step_idx} "
                f"max_step={max_step} full_rollout={full_rollout}"
            )
            self._debug_write(header)
            if debug_enabled:
                logger.info(header)

        # Loop state: captured sequence, rollout snapshots and diagnostics.
        new_seq: TokenSequence | None = None
        new_seq_token_len: int | None = None
        new_seq_token_len_with_masks: int | None = None
        steps_used = 0
        x0_at_target = None
        final_tokens: list[int] | None = None
        rollout_path: list[TokenSequence] | None = None
        rollout_lengths: list[int] | None = None
        rollout_lengths_with_masks: list[int] | None = None
        rollout_target_idx = 0
        stall_steps = 0
        x0_mask_total = 0
        x0_mask_warned = 0
        stall_transfer_warned = 0
        min_masks_seen = initial_masks
        # Per-step traces are only allocated when tracing was requested.
        demask_step_indices: list[int] | None = (
            [] if self._trace_demask_steps else None
        )
        demask_mask_counts: list[int] | None = (
            [] if self._trace_demask_steps else None
        )

        prev_masks = initial_masks

        for step_idx in range(start_step_idx, max_step):
            x0, remasker = demasker.predict_x0(x, step_idx=step_idx)
            # Diagnostic: count positions masked in `x` where `x0` also holds
            # the mask token. Only computed when a debug flag is set.
            x0_mask_pred = 0
            if debug_enabled or allow_mask_token:
                try:
                    x0_mask_pred = int(
                        ((x0 == handle.mask_token_id) & (x == handle.mask_token_id))
                        .sum()
                        .item()
                    )
                except Exception:
                    x0_mask_pred = 0
            if x0_mask_pred > 0:
                x0_mask_total += x0_mask_pred
                x0_mask_warned += 1
                # Outside debug mode only the first 10 occurrences are reported.
                if debug_enabled or x0_mask_warned <= 10:
                    msg = (
                        f"[demask] predicted mask_token_id in x0 at masked positions: "
                        f"model={model_name} mask_strategy={mask_strategy} demask_temperature={demask_temperature} "
                        f"step={step_idx} count={x0_mask_pred}"
                    )
                    self._debug_write(msg)
                    logger.warning(msg)
            x = remasker.step(x0)
            steps_used += 1

            masks_left = _count_masks(x, handle.mask_token_id)
            min_masks_seen = min(min_masks_seen, masks_left)
            if demask_step_indices is not None and demask_mask_counts is not None:
                demask_step_indices.append(int(step_idx))
                demask_mask_counts.append(int(masks_left))
            # Stall detection: the step left the mask count unchanged.
            if masks_left == prev_masks and prev_masks > 0:
                stall_steps += 1
                transfer_count = getattr(remasker, "_debug_transfer_count", None)
                selected_mask_count = getattr(remasker, "_debug_selected_mask_count", None)
                if transfer_count is not None and int(transfer_count) > 0:
                    stall_transfer_warned += 1
                    if debug_enabled or stall_transfer_warned <= 10:
                        msg = (
                            f"[demask] transferred tokens but masks did not decrease: "
                            f"model={model_name} mask_strategy={mask_strategy} demask_temperature={demask_temperature} "
                            f"step={step_idx} masks={masks_left} "
                            f"target={target_masks} transfer={transfer_count} "
                            f"selected_mask={selected_mask_count}"
                        )
                        self._debug_write(msg)
                        logger.warning(msg)
                elif debug_enabled and stall_steps <= 10:
                    msg = (
                        f"[demask] no mask decrease: model={model_name} mask_strategy={mask_strategy} "
                        f"demask_temperature={demask_temperature} step={step_idx} "
                        f"masks={masks_left} target={target_masks}"
                    )
                    self._debug_write(msg)
                    logger.info(msg)
            prev_masks = masks_left
            # Per-step values, computed at most once and shared by the capture
            # and rollout-snapshot branches below.
            current_seq = None
            current_len = None
            current_len_with_masks = None
            # Capture the new sequence the first time the target is reached.
            if new_seq is None and masks_left <= target_masks:
                new_seq = base_seq.build_from_generated_token_ids(
                    handle.tokenizer, x[0, prompt_len:].tolist()
                )
                current_seq = new_seq
                new_seq_token_len = _count_len_from_state(x)
                new_seq_token_len_with_masks = _count_len_with_masks_from_state(x)
                current_len = new_seq_token_len
                current_len_with_masks = new_seq_token_len_with_masks
                x0_at_target = x0
                if full_rollout and rollout_targets:
                    rollout_path = [new_seq]
                    rollout_lengths = [new_seq_token_len]
                    rollout_lengths_with_masks = [new_seq_token_len_with_masks]
                if not full_rollout:
                    # Without a full rollout the response comes from this
                    # step's x0 rather than from the partially masked state.
                    final_tokens = x0[0, prompt_len:].tolist()
                    break

            # Snapshots start only after the capture above has created
            # `rollout_path`.
            if (
                full_rollout
                and rollout_targets
                and rollout_path is not None
                and rollout_target_idx < len(rollout_targets)
                and masks_left <= rollout_targets[rollout_target_idx]
            ):
                if current_seq is None:
                    current_seq = base_seq.build_from_generated_token_ids(
                        handle.tokenizer, x[0, prompt_len:].tolist()
                    )
                if current_len is None:
                    current_len = _count_len_from_state(x)
                if current_len_with_masks is None:
                    current_len_with_masks = _count_len_with_masks_from_state(x)
                # One entry is appended per satisfied threshold, so a single
                # step that crosses several thresholds adds the same snapshot
                # several times.
                while (
                    rollout_target_idx < len(rollout_targets)
                    and masks_left <= rollout_targets[rollout_target_idx]
                ):
                    rollout_path.append(current_seq)
                    if rollout_lengths is not None:
                        rollout_lengths.append(current_len)
                    if rollout_lengths_with_masks is not None:
                        rollout_lengths_with_masks.append(current_len_with_masks)
                    rollout_target_idx += 1

            if full_rollout and masks_left <= 0:
                final_tokens = x[0, prompt_len:].tolist()
                break

        # The loop ended without setting final tokens (steps exhausted, or no
        # steps ran): fall back to the current state.
        if final_tokens is None:
            if full_rollout:
                final_tokens = x[0, prompt_len:].tolist()
            # NOTE(review): this branch looks unreachable — x0_at_target is only
            # set where the non-full-rollout path also sets final_tokens and
            # breaks. Confirm before relying on it.
            elif x0_at_target is not None:
                final_tokens = x0_at_target[0, prompt_len:].tolist()
            else:
                final_tokens = x[0, prompt_len:].tolist()

        final_seq = base_seq.build_from_generated_token_ids(
            handle.tokenizer, final_tokens
        )
        # The target was never reached: the final sequence is the new sequence.
        if new_seq is None:
            new_seq = final_seq
        if full_rollout and rollout_targets and rollout_path is None:
            rollout_path = [new_seq]

        if debug_enabled or allow_mask_token:
            footer = (
                f"[demask] done model={model_name} mask_strategy={mask_strategy} "
                f"demask_temperature={demask_temperature} initial_masks={initial_masks} "
                f"min_masks={min_masks_seen} final_masks={final_seq.num_masks} "
                f"new_seq_masks={new_seq.num_masks} steps_used={steps_used} "
                f"target_masks={target_masks} stall_steps={stall_steps} "
                f"x0_mask_total={x0_mask_total} stall_transfer_steps={stall_transfer_warned}"
            )
            self._debug_write(footer)
            # A full rollout that still contains masks is reported as a warning.
            if full_rollout and final_seq.num_masks > 0:
                logger.warning(footer)
            elif debug_enabled:
                logger.info(footer)

        final_token_len = _count_generated_tokens(
            final_tokens,
            stop_ids=stop_ids,
            mask_token_id=handle.mask_token_id,
            pad_token_id=handle.pad_token_id,
        )
        final_token_len_with_masks = _count_generated_tokens_with_masks(
            final_tokens,
            stop_ids=stop_ids,
            pad_token_id=handle.pad_token_id,
        )

        # Rollout outputs stay None unless a full rollout was requested.
        rollout_token_seqs = None
        token_len_history = None
        token_len_with_mask_history = None
        if full_rollout:
            if rollout_path:
                rollout_token_seqs = rollout_path
                # Recompute the length lists whenever they are missing or out
                # of step with the snapshot list.
                if rollout_lengths is None or len(rollout_lengths) != len(rollout_path):
                    rollout_lengths = [
                        _count_sequence_tokens(
                            seq,
                            handle.tokenizer,
                            handle.max_seq_len,
                            stop_ids=stop_ids,
                            mask_token_id=handle.mask_token_id,
                            pad_token_id=handle.pad_token_id,
                        )
                        for seq in rollout_path
                    ]
                if rollout_lengths_with_masks is None or len(
                    rollout_lengths_with_masks
                ) != len(rollout_path):
                    rollout_lengths_with_masks = [
                        _count_sequence_tokens_with_masks(
                            seq,
                            handle.tokenizer,
                            handle.max_seq_len,
                            stop_ids=stop_ids,
                            pad_token_id=handle.pad_token_id,
                        )
                        for seq in rollout_path
                    ]
                # Make the final sequence the last entry: append it if the last
                # snapshot differs, otherwise just refresh the last lengths.
                if _token_sequence_key(rollout_token_seqs[-1]) != _token_sequence_key(
                    final_seq
                ):
                    rollout_token_seqs.append(final_seq)
                    if rollout_lengths is not None:
                        rollout_lengths.append(final_token_len)
                    if rollout_lengths_with_masks is not None:
                        rollout_lengths_with_masks.append(final_token_len_with_masks)
                else:
                    if rollout_lengths is not None and rollout_lengths:
                        rollout_lengths[-1] = final_token_len
                    if rollout_lengths_with_masks is not None and rollout_lengths_with_masks:
                        rollout_lengths_with_masks[-1] = final_token_len_with_masks
                token_len_history = rollout_lengths
                token_len_with_mask_history = rollout_lengths_with_masks
            elif new_seq is final_seq:
                # No explicit targets and the target was never reached.
                rollout_token_seqs = [final_seq]
                token_len_history = [final_token_len]
                token_len_with_mask_history = [final_token_len_with_masks]
            else:
                # No explicit targets: report the captured and final sequences.
                rollout_token_seqs = [new_seq, final_seq]
                if new_seq_token_len is None:
                    new_seq_token_len = _count_sequence_tokens(
                        new_seq,
                        handle.tokenizer,
                        handle.max_seq_len,
                        stop_ids=stop_ids,
                        mask_token_id=handle.mask_token_id,
                        pad_token_id=handle.pad_token_id,
                    )
                if new_seq_token_len_with_masks is None:
                    new_seq_token_len_with_masks = _count_sequence_tokens_with_masks(
                        new_seq,
                        handle.tokenizer,
                        handle.max_seq_len,
                        stop_ids=stop_ids,
                        pad_token_id=handle.pad_token_id,
                    )
                token_len_history = [new_seq_token_len, final_token_len]
                token_len_with_mask_history = [
                    new_seq_token_len_with_masks,
                    final_token_len_with_masks,
                ]

        lm_response = _decode_generated(
            handle.tokenizer,
            final_tokens,
            stop_ids=stop_ids,
            mask_token_id=handle.mask_token_id,
            pad_token_id=handle.pad_token_id,
        )

        return GenerationResult(
            new_token_seq=new_seq,
            lm_response=lm_response,
            num_func_evals=steps_used,
            rollout_token_seqs=rollout_token_seqs,
            final_token_len=final_token_len,
            token_len_history=token_len_history,
            token_len_with_mask_history=token_len_with_mask_history,
            demask_step_indices=demask_step_indices,
            demask_mask_counts=demask_mask_counts,
        )


def _build_prompt_msgs(prompt: str, system_prompt: str | None) -> list[dict[str, str]]:
    msgs: list[dict[str, str]] = []
    if system_prompt:
        msgs.append({"role": "system", "content": system_prompt})
    msgs.append({"role": "user", "content": prompt})
    return msgs


def _tokenize_sequence(
    seq: TokenSequence, tokenizer: Any, max_seq_len: int | None
) -> tuple[int, list[int]]:
    if max_seq_len is None:
        return seq.to_token_ids(tokenizer)
    return seq.to_token_ids(
        tokenizer,
        truncate_tid_strategy="right_pad_mask_remove",
        max_tid_len=max_seq_len,
    )


def _token_sequence_key(seq: TokenSequence) -> tuple:
    segments = tuple(
        (segment.kind, segment.content, int(segment.repetition))
        for segment in seq.segments
    )
    prompt = tuple((msg.get("role", ""), msg.get("content", "")) for msg in seq.prompt_msgs)
    return (prompt, segments, int(seq.gen_length))


def _decode_generated(
    tokenizer: Any,
    token_ids: list[int],
    *,
    stop_ids: Iterable[int],
    mask_token_id: int,
    pad_token_id: int,
) -> str:
    stop_set = set(stop_ids)
    cleaned: list[int] = []
    for token_id in token_ids:
        if token_id in stop_set:
            break
        if token_id in (mask_token_id, pad_token_id):
            continue
        cleaned.append(token_id)
    if not cleaned:
        return ""
    return tokenizer.decode(cleaned)


def _count_generated_tokens(
    token_ids: list[int],
    *,
    stop_ids: Iterable[int],
    mask_token_id: int,
    pad_token_id: int,
) -> int:
    stop_set = set(stop_ids)
    count = 0
    for token_id in token_ids:
        if token_id in stop_set:
            break
        if token_id in (mask_token_id, pad_token_id):
            continue
        count += 1
    return count


def _count_generated_tokens_with_masks(
    token_ids: list[int],
    *,
    stop_ids: Iterable[int],
    pad_token_id: int,
) -> int:
    stop_set = set(stop_ids)
    count = 0
    for token_id in token_ids:
        if token_id in stop_set:
            break
        if token_id == pad_token_id:
            continue
        count += 1
    return count


def _count_sequence_tokens(
    seq: TokenSequence,
    tokenizer: Any,
    max_seq_len: int | None,
    *,
    stop_ids: Iterable[int],
    mask_token_id: int,
    pad_token_id: int,
) -> int:
    """Tokenize ``seq`` and count its generated tokens (mask and pad excluded)."""
    tokenized = _tokenize_sequence(seq, tokenizer, max_seq_len)
    prompt_len, token_ids = tokenized
    # Everything after the prompt is the generation region.
    generated = token_ids[prompt_len:]
    return _count_generated_tokens(
        generated,
        stop_ids=stop_ids,
        mask_token_id=mask_token_id,
        pad_token_id=pad_token_id,
    )


def _count_sequence_tokens_with_masks(
    seq: TokenSequence,
    tokenizer: Any,
    max_seq_len: int | None,
    *,
    stop_ids: Iterable[int],
    pad_token_id: int,
) -> int:
    """Tokenize ``seq`` and count its generated tokens, mask tokens included."""
    tokenized = _tokenize_sequence(seq, tokenizer, max_seq_len)
    prompt_len, token_ids = tokenized
    # Everything after the prompt is the generation region.
    generated = token_ids[prompt_len:]
    return _count_generated_tokens_with_masks(
        generated,
        stop_ids=stop_ids,
        pad_token_id=pad_token_id,
    )


def _count_masks(x: Any, mask_token_id: int) -> int:
    return int((x == mask_token_id).sum().item())


def _to_tensor(token_ids: list[int], device: Any) -> Any:
    import torch

    return torch.tensor(token_ids, dtype=torch.long, device=device).unsqueeze(0)


def _resolve_llada_demasker(demasker: Any, gen_length: int) -> Any:
    try:
        from bd_mcts.demask_interface.llada import Llada
    except Exception:
        return demasker

    if not isinstance(demasker, Llada):
        return demasker
    if gen_length <= 0:
        raise ValueError("LLaDA gen_length must be >= 1 after truncation")

    cfg = demasker.generation_config
    if gen_length == cfg.gen_length:
        return demasker

    block_length = cfg.block_length
    if block_length <= 0 or gen_length % block_length != 0:
        block_length = gen_length
    else:
        num_blocks = gen_length // block_length
        if num_blocks <= 0 or cfg.steps % num_blocks != 0:
            block_length = gen_length

    return Llada(
        demasker.model,
        steps=cfg.steps,
        gen_length=gen_length,
        block_length=block_length,
        alg=cfg.alg,
        temperature=cfg.temperature,
        cfg_scale=cfg.cfg_scale,
        max_new_tokens=cfg.max_new_tokens,
        eos_penalty=cfg.eos_penalty,
        eot_confidence_penalty=cfg.eot_confidence_penalty,
        mask_token_id=cfg.mask_token_id,
        pad_token_id=cfg.pad_token_id,
        bos_token_id=cfg.bos_token_id,
        eos_token_id=cfg.eos_token_id,
        eot_token_ids=cfg.eot_token_ids,
    )


def _resolve_demask_steps(demasker: Any) -> int:
    steps = getattr(getattr(demasker, "generation_config", None), "steps", None)
    if steps is None:
        raise ValueError("demasker does not expose generation_config.steps")
    try:
        steps = int(steps)
    except (TypeError, ValueError):
        raise ValueError("demasker steps must be an integer")
    if steps < 1:
        raise ValueError("steps must be >= 1")
    return steps


def _resolve_demask_start_step(demasker: Any, *, initial_masks: int, gen_length: int) -> int:
    """
    Resume diffusion from an intermediate masked state.

    Dream-style demaskers expose `timesteps` and `generation_config.eps` where the noise
    parameter t decreases linearly from 1 to eps over `steps` iterations. When we start
    from a partially demasked parent sequence, starting at step 0 wastes many iterations
    where the transfer count is deterministically 0 (because of floor/rounding).

    For non-Dream demaskers, return 0 (start from the beginning).

    Args:
        demasker: Object inspected for `generation_config` (with `eps`, `steps` and
            optionally `alg`) and a `timesteps` attribute.
        initial_masks: Number of masked positions in the state being resumed.
        gen_length: Length used as the denominator of the mask fraction.

    Returns:
        A step index in `[0, steps - 1]`. 0 is returned whenever the demasker does not
        look Dream-style, `eps` is not a float strictly between 0 and 1, `steps <= 1`,
        or either `gen_length` or `initial_masks` is not positive.

    Raises:
        ValueError: propagated from `_resolve_demask_steps` when the step count is
            missing or invalid.
    """
    cfg = getattr(demasker, "generation_config", None)
    # Duck-typed detection: both `timesteps` and `generation_config.eps` must exist.
    if cfg is None or not hasattr(demasker, "timesteps") or not hasattr(cfg, "eps"):
        return 0
    try:
        eps = float(getattr(cfg, "eps"))
    except (TypeError, ValueError):
        return 0
    # The inversion below divides by (1 - eps) and clamps to eps, so require 0 < eps < 1.
    if not (0.0 < eps < 1.0):
        return 0

    steps = _resolve_demask_steps(demasker)
    if steps <= 1 or gen_length <= 0 or initial_masks <= 0:
        return 0

    mask_fraction = initial_masks / gen_length
    if not math.isfinite(mask_fraction):
        return 0
    # Clamp into [eps, 1], the range covered by the schedule described below.
    mask_fraction = max(eps, min(1.0, mask_fraction))

    denom = 1.0 - eps
    # Defensive guard; eps < 1 was already enforced above.
    if denom <= 0:
        return 0

    # timesteps[k] = 1 - k * (1 - eps) / steps
    # => k ~= (1 - mask_fraction) * steps / (1 - eps)
    start_float = (1.0 - mask_fraction) * steps / denom
    # Round up, then clamp into the valid index range [0, steps - 1].
    start_step = int(math.ceil(start_float))
    start_step = max(0, min(steps - 1, start_step))

    # For Dream non-origin algorithms, the number of transferred tokens per step is
    # `int(num_masks * (1 - s/t))`. With the linear schedule this becomes
    # `int(num_masks * delta / t)`, which can be 0 for early steps. Advance until the
    # first step that deterministically transfers at least one token.
    alg = str(getattr(cfg, "alg", ""))
    if alg and alg != "origin":
        # delta is the per-step decrease of t under the linear schedule.
        delta = denom / steps
        while start_step < steps - 1:
            t = 1.0 - start_step * delta
            if t <= 0:
                break
            expected = initial_masks * delta / t
            if expected >= 1.0:
                break
            start_step += 1

    return start_step


def _set_seed(seed: int | None) -> None:
    if seed is None:
        return
    random.seed(seed)
    try:
        import numpy as np

        np.random.seed(seed)
    except Exception:
        pass
    try:
        import torch

        torch.manual_seed(seed)
    except Exception:
        pass


def _progress_enabled(cfg: dict[str, Any]) -> bool:
    progress_cfg = cfg.get("progress")
    if progress_cfg is None:
        return True
    if isinstance(progress_cfg, dict):
        return bool(progress_cfg.get("enabled", True))
    return bool(progress_cfg)


def _progress_nfe_total(cfg: dict[str, Any], sample_total: int) -> int | None:
    try:
        nfe = int(cfg.get("nfe", 0))
    except (TypeError, ValueError):
        return None
    if nfe <= 0:
        return None
    return nfe * sample_total


def _progress_label(cfg: dict[str, Any], shard_id: int) -> str | None:
    try:
        num_workers = int(cfg.get("parallel", {}).get("num_workers", 1))
    except (TypeError, ValueError):
        num_workers = 1
    if num_workers <= 1:
        return None
    return str(shard_id)


class _ProgressTracker:
    def __init__(
        self,
        *,
        enable: bool,
        sample_total: int,
        nfe_total: int | None,
        position: int | None,
        label: str | None,
    ) -> None:
        self._sample_bar = None
        self._nfe_bar = None
        if not enable:
            return
        from tqdm import tqdm

        sample_desc = "Samples"
        nfe_desc = "NFE"
        if label:
            sample_desc = f"{sample_desc}[{label}]"
            nfe_desc = f"{nfe_desc}[{label}]"

        self._sample_bar = tqdm(
            total=sample_total,
            desc=sample_desc,
            unit="sample",
            position=position,
            dynamic_ncols=True,
            ascii=True,
        )
        nfe_position = None if position is None else position + 1
        self._nfe_bar = tqdm(
            total=nfe_total,
            desc=nfe_desc,
            unit="nfe",
            position=nfe_position,
            dynamic_ncols=True,
            ascii=True,
        )

    def update_samples(self, count: int = 1) -> None:
        if self._sample_bar is None or count <= 0:
            return
        self._sample_bar.update(count)

    def update_nfe(self, count: int) -> None:
        if self._nfe_bar is None or count <= 0:
            return
        self._nfe_bar.update(count)

    def close(self) -> None:
        if self._sample_bar is not None:
            self._sample_bar.close()
        if self._nfe_bar is not None:
            self._nfe_bar.close()


def _json_safe(value: Any) -> Any:
    if isinstance(value, dict):
        return {k: _json_safe(v) for k, v in value.items()}
    if isinstance(value, list):
        return [_json_safe(v) for v in value]
    if isinstance(value, float):
        if math.isfinite(value):
            return value
        return str(value)
    return value


def _resolve_torch_dtype(dtype_name: str | None) -> Any:
    if dtype_name is None:
        return None
    import torch

    if not hasattr(torch, dtype_name):
        raise ValueError(f"unknown torch dtype: {dtype_name}")
    return getattr(torch, dtype_name)


def _resolve_penalty_value(value: Any, *, name: str) -> float:
    if value is None:
        return 0.0
    if isinstance(value, (int, float)):
        return float(value)
    if isinstance(value, str):
        text = value.strip().lower()
        if text in ("inf", "+inf", ".inf", "+.inf", "infinity", "+infinity", "torch.inf"):
            return math.inf
        if text in ("-inf", "-.inf", "-infinity"):
            return -math.inf
        if text in ("nan", ".nan", "torch.nan"):
            return math.nan
        try:
            return float(text)
        except ValueError as exc:
            raise ValueError(
                f"{name} must be a number or inf, got {value!r}"
            ) from exc
    raise TypeError(f"{name} must be a number or inf, got {type(value).__name__}")


def _resolve_model_entries(cfg: dict[str, Any]) -> list[dict[str, Any]]:
    if cfg["algo"]["name"] in ("abmcts", "multimodel_mcts"):
        models = cfg["algo"].get("models", [])
        if not models:
            raise ValueError(
                f"{cfg['algo']['name']} requires algo.models to be non-empty"
            )
        entries: list[dict[str, Any]] = []
        for idx, model in enumerate(models):
            if isinstance(model, str):
                entries.append({"name": model, "demask": {}})
                continue
            if not isinstance(model, dict):
                raise ValueError(
                    "multi-model entries must be strings or mappings, "
                    f"got {type(model).__name__} at index {idx}"
                )
            name = model.get("name") or model.get("model_name")
            if not name:
                raise ValueError(
                    f"{cfg['algo']['name']} models[{idx}] must include name"
                )
            demask_override = {
                k: v
                for k, v in model.items()
                if k not in ("name", "model_name", "demask")
            }
            if "demask" in model:
                demask_cfg = model.get("demask") or {}
                if not isinstance(demask_cfg, dict):
                    raise ValueError(
                        f"{cfg['algo']['name']} models[{idx}].demask must be a mapping"
                    )
                demask_override.update(demask_cfg)
            entries.append({"name": name, "demask": demask_override})
        return entries
    model_name = cfg["demask"].get("model_name")
    if not model_name:
        raise ValueError("demask.model_name is required for non-abmcts runs")
    return [{"name": model_name, "demask": {}}]


def _resolve_model_names(cfg: dict[str, Any]) -> list[str]:
    """Return only the ``name`` of every entry produced by ``_resolve_model_entries``."""
    names: list[str] = []
    for entry in _resolve_model_entries(cfg):
        names.append(entry["name"])
    return names


def _merge_demask_cfg(
    base: dict[str, Any] | None, override: dict[str, Any] | None
) -> dict[str, Any]:
    merged = dict(base or {})
    if override and "type" in override and merged.get("type") != override.get("type"):
        shared_keys = {
            "dtype",
            "trust_remote_code",
            "use_fast",
            "max_seq_len",
            "steps",
        }
        merged = {key: value for key, value in merged.items() if key in shared_keys}
    merged.update(override or {})
    return merged


def _build_model_handles(
    cfg: dict[str, Any], model_entries: list[dict[str, Any]], device: Any
) -> dict[str, ModelHandle]:
    """Load the tokenizer, model and demasker of every entry and wrap them in handles.

    For each entry the global ``cfg["demask"]`` section is merged with the
    entry's own overrides, the Hugging Face tokenizer and model are loaded
    (the model is moved to ``device`` and put in eval mode), and a ``Dream``
    or ``Llada`` demasker is constructed according to ``demask.type``.

    Args:
        cfg: Full run configuration; ``cfg["algo"]["gen_length"]`` is read
            when building the demasker.
        model_entries: ``{"name", "demask"}`` mappings, one per model.
        device: Device passed to ``model.to`` and stored on each handle.

    Returns:
        Mapping from model name to its ``ModelHandle``.

    Raises:
        ValueError: if a model is absent from ``TOKEN_ID_MAPPER``, if
            ``demask.type`` or ``demask.steps`` is missing, or if the type is
            neither ``"dream"`` nor ``"llada"``. Penalty and dtype parsing
            may raise as well.
    """
    from transformers import AutoModel, AutoTokenizer

    from bd_mcts.token_sequence import TOKEN_ID_MAPPER

    base_demask_cfg = cfg.get("demask", {})
    handles: dict[str, ModelHandle] = {}

    for entry in model_entries:
        name = entry["name"]
        if name not in TOKEN_ID_MAPPER:
            raise ValueError(f"model {name} is not in TOKEN_ID_MAPPER")

        demask_cfg = _merge_demask_cfg(base_demask_cfg, entry.get("demask", {}))
        eos_penalty = _resolve_penalty_value(
            demask_cfg.get("eos_penalty", 0.0), name="demask.eos_penalty"
        )
        eot_confidence_penalty = _resolve_penalty_value(
            demask_cfg.get("eot_confidence_penalty", 0.0),
            name="demask.eot_confidence_penalty",
        )
        if "type" not in demask_cfg:
            raise ValueError(f"demask.type is required for model {name}")
        if "steps" not in demask_cfg:
            raise ValueError(f"demask.steps is required for model {name}")
        dtype = _resolve_torch_dtype(demask_cfg.get("dtype"))
        tokenizer = AutoTokenizer.from_pretrained(
            name,
            trust_remote_code=demask_cfg.get("trust_remote_code", False),
            use_fast=demask_cfg.get("use_fast", True),
        )
        model = AutoModel.from_pretrained(
            name,
            trust_remote_code=demask_cfg.get("trust_remote_code", False),
            torch_dtype=dtype,
        ).to(device)
        model.eval()

        ids = TOKEN_ID_MAPPER[name]
        # Token id 0 is a legitimate id, so fall back to the mapped EOS id only
        # when the tokenizer reports no id at all (None); a truthiness test
        # would silently discard a real id of 0.
        bos_id = tokenizer.bos_token_id
        if bos_id is None:
            bos_id = ids["eos"]
        eos_id = tokenizer.eos_token_id
        if eos_id is None:
            eos_id = ids["eos"]

        if demask_cfg["type"] == "dream":
            from bd_mcts.demask_interface.dream import Dream

            # Prefer the configured limit, then the tokenizer's, and finally
            # the generation length itself.
            max_len = demask_cfg.get("max_seq_len") or getattr(
                tokenizer, "model_max_length", None
            )
            demasker = Dream(
                model,
                steps=demask_cfg["steps"],
                alg=demask_cfg.get("alg", "origin"),
                alg_temp=demask_cfg.get("alg_temp"),
                temperature=demask_cfg.get("temperature", 0.0),
                top_p=demask_cfg.get("top_p"),
                top_k=demask_cfg.get("top_k"),
                eos_penalty=eos_penalty,
                max_length=max_len or cfg["algo"]["gen_length"],
                mask_token_id=ids["mask"],
                pad_token_id=ids["pad"],
                bos_token_id=bos_id,
                eos_token_id=eos_id,
            )
            eot_ids = (ids["eot"],)
        elif demask_cfg["type"] == "llada":
            from bd_mcts.demask_interface.llada import Llada

            demasker = Llada(
                model,
                steps=demask_cfg["steps"],
                gen_length=cfg["algo"]["gen_length"],
                block_length=demask_cfg.get("block_length", -1),
                alg=demask_cfg.get("alg", "random"),
                temperature=demask_cfg.get("temperature", 0.0),
                cfg_scale=demask_cfg.get("cfg_scale", 0.0),
                max_new_tokens=demask_cfg.get("max_new_tokens"),
                eos_penalty=eos_penalty,
                eot_confidence_penalty=eot_confidence_penalty,
                mask_token_id=ids["mask"],
                pad_token_id=ids["pad"],
                bos_token_id=bos_id,
                eos_token_id=eos_id,
                eot_token_ids=(ids["eot"],),
            )
            eot_ids = (ids["eot"],)
        else:
            raise ValueError(f"unknown demask type: {demask_cfg['type']}")

        handles[name] = ModelHandle(
            name=name,
            tokenizer=tokenizer,
            demasker=demasker,
            max_seq_len=demask_cfg.get("max_seq_len"),
            device=device,
            mask_token_id=ids["mask"],
            pad_token_id=ids["pad"],
            eos_token_id=eos_id,
            eot_token_ids=eot_ids,
        )

    return handles


def _build_task(task_cfg: dict[str, Any]) -> Task:
    """Create the task named ``task_cfg["name"]``, forwarding ``task_cfg["kwargs"]`` (default ``{}``)."""
    task_name = task_cfg["name"]
    extra_kwargs = task_cfg.get("kwargs", {})
    return make_code_eval_task(task_name, **extra_kwargs)


def _resolve_system_prompt(cfg: dict[str, Any]) -> str | None:
    prompt_cfg = cfg.get("prompt")
    if not isinstance(prompt_cfg, dict):
        prompt_cfg = {}
        cfg["prompt"] = prompt_cfg
    if "system" in prompt_cfg and prompt_cfg["system"] is not None:
        return prompt_cfg["system"]
    system_prompt = cfg.get("task", {}).get("kwargs", {}).get("system_prompt")
    if system_prompt is not None:
        prompt_cfg["system"] = system_prompt
    return system_prompt


def _apply_task_defaults(cfg: dict[str, Any]) -> None:
    task_cfg = cfg.get("task", {})
    task_name = str(task_cfg.get("name", "")).lower()
    if task_name not in (
        "lcb_gen",
        "livecodebench_gen",
        "livecodebench_generation",
    ):
        return
    kwargs = task_cfg.setdefault("kwargs", {})
    if not kwargs.get("prompt_use_instruct", False):
        return
    if kwargs.get("prompt_tokenizer_name"):
        return
    model_name = cfg.get("demask", {}).get("model_name")
    if not model_name:
        models = cfg.get("algo", {}).get("models") or []
        if models:
            first = models[0]
            if isinstance(first, dict):
                model_name = first.get("name") or first.get("model_name")
            else:
                model_name = first
    if model_name:
        kwargs["prompt_tokenizer_name"] = model_name


def _apply_nfe_overrides(cfg: dict[str, Any]) -> None:
    algo_cfg = cfg.get("algo", {})
    if algo_cfg.get("name") != "best_of_n":
        return

    try:
        n = int(algo_cfg.get("n", 0))
    except (TypeError, ValueError):
        raise ValueError("best_of_n requires algo.n to be an integer")
    try:
        steps = int(cfg.get("demask", {}).get("steps", 0))
    except (TypeError, ValueError):
        raise ValueError("best_of_n requires demask.steps to be an integer")

    if n <= 0 or steps <= 0:
        raise ValueError("best_of_n requires algo.n and demask.steps to be > 0")

    cfg["nfe"] = n * steps


def _task_size(task: Task) -> int:
    if hasattr(task, "_task_ids"):
        return len(getattr(task, "_task_ids"))
    if hasattr(task, "_samples"):
        return len(getattr(task, "_samples"))
    if hasattr(task, "_public_samples"):
        return len(getattr(task, "_public_samples"))
    raise ValueError("task does not expose sample length")


def _task_prompt(task: Task, sample_id: int) -> str:
    if hasattr(task, "get_prompt"):
        return task.get_prompt(sample_id)
    if isinstance(task, EvalPlusTask):
        task_id = task._task_ids[sample_id]
        return task._problems[task_id]["prompt"]
    if isinstance(task, BigCodeBenchTask):
        task_id = task._task_ids[sample_id]
        return task._problems[task_id]["complete_prompt"]
    if isinstance(task, CruxEvalTask):
        raise ValueError("CruxEval prompts are not supported in this runner")
    raise ValueError(f"unsupported task type: {type(task).__name__}")


def _select_sample_ids(
    task: Task, task_cfg: dict[str, Any], seed: int | None
) -> list[int]:
    if task_cfg.get("sample_ids"):
        return [int(x) for x in task_cfg["sample_ids"]]
    total = _task_size(task)
    ids = list(range(total))
    if task_cfg.get("shuffle", False):
        shuffle_seed = task_cfg.get("shuffle_seed")
        if shuffle_seed is None:
            shuffle_seed = seed
        rng = random.Random(shuffle_seed)
        rng.shuffle(ids)
    limit = task_cfg.get("limit")
    if limit is not None:
        ids = ids[: min(total, int(limit))]
    return ids


def _split_ids(ids: list[int], shards: int) -> list[list[int]]:
    if shards <= 1 or not ids:
        return [ids]
    chunk = max(1, math.ceil(len(ids) / shards))
    return [ids[i : i + chunk] for i in range(0, len(ids), chunk)]


def _default_max_trials(cfg: dict[str, Any], algo: SearchAlgo) -> int | None:
    if cfg.get("max_trials") is not None:
        return int(cfg["max_trials"])
    if isinstance(algo, BestOfN):
        return int(algo.n)
    return None


def _budget_exhausted(cfg: dict[str, Any], algo: SearchAlgo) -> bool:
    if cfg.get("nfe", 0) <= 0:
        return False
    remaining = getattr(algo, "num_func_eval_budget", None)
    if remaining is None:
        return False
    return remaining <= 0


def _resolve_abmcts_ask_workers(cfg: dict[str, Any], sample_ids: list[int]) -> int:
    parallel_cfg = cfg.get("parallel", {})
    ask_workers = parallel_cfg.get("ask_workers", 1)
    try:
        ask_workers = int(ask_workers)
    except (TypeError, ValueError):
        ask_workers = 1
    return max(1, min(ask_workers, len(sample_ids)))


def _resolve_submit_workers(cfg: dict[str, Any], sample_count: int) -> int:
    parallel_cfg = cfg.get("parallel", {})
    submit_workers = parallel_cfg.get("submit_workers", 2)
    try:
        submit_workers = int(submit_workers)
    except (TypeError, ValueError):
        submit_workers = 2
    if sample_count <= 0:
        return max(1, submit_workers)
    return max(1, min(submit_workers, sample_count))


def _resolve_parallel_schedule(cfg: dict[str, Any]) -> str:
    parallel_cfg = cfg.get("parallel", {})
    schedule = parallel_cfg.get("schedule", "static")
    if isinstance(schedule, bool):
        return "dynamic" if schedule else "static"
    schedule = str(schedule).strip().lower()
    if schedule in ("dynamic", "queue"):
        return "dynamic"
    return "static"


def _resolve_abmcts_tree_dir(cfg: dict[str, Any]) -> Path:
    output_cfg = cfg.get("output", {})
    tree_dir = output_cfg.get("abmcts_tree_dir")
    if tree_dir:
        return Path(tree_dir)
    output_path = Path(output_cfg.get("path", "results.json"))
    return output_path.parent / "abmcts_trees"


def _resolve_multimodel_mcts_tree_dir(cfg: dict[str, Any]) -> Path:
    output_cfg = cfg.get("output", {})
    tree_dir = output_cfg.get("multimodel_mcts_tree_dir")
    if tree_dir:
        return Path(tree_dir)
    output_path = Path(output_cfg.get("path", "results.json"))
    return output_path.parent / "multimodel_mcts_trees"


def _resolve_debug_dir(cfg: dict[str, Any]) -> Path:
    output_cfg = cfg.get("output", {})
    output_path = Path(output_cfg.get("path", "results.json"))
    return output_path.parent


def _render_dot_to_png(dot_path: Path, png_path: Path) -> None:
    """Render ``dot_path`` to ``png_path`` with the Graphviz ``dot`` executable.

    Best effort: a missing executable or any failure while running it is
    logged as a warning and never raised.
    """
    executable = shutil.which("dot")
    if executable:
        try:
            command = [executable, "-Tpng", str(dot_path), "-o", str(png_path)]
            subprocess.run(command, check=True)
        except Exception as exc:
            logger.warning("graphviz render failed for %s: %s", dot_path, exc)
        return
    logger.warning("graphviz 'dot' not found; skipping png render for %s", dot_path)


def _render_multimodel_mcts_tree(
    algo: MultiModelMCTS, sample_id: int, cfg: dict[str, Any]
) -> None:
    """Export the search tree of one sample as DOT, and optionally PNG and JSON.

    All options are read from ``cfg["output"]`` under keys prefixed with
    ``multimodel_mcts_tree_``. Nothing is written when both the PNG and JSON
    outputs are disabled; otherwise the DOT file is always produced. Files
    are named ``multimodel_mcts_tree_sample_<sample_id>`` inside the
    directory given by ``_resolve_multimodel_mcts_tree_dir``.
    """
    output_cfg = cfg.get("output", {})

    def _optional_int(raw: Any, fallback: int | None) -> int | None:
        # ``None`` passes through untouched; unparsable values become ``fallback``.
        if raw is None:
            return None
        try:
            return int(raw)
        except (TypeError, ValueError):
            return fallback

    want_png = bool(output_cfg.get("multimodel_mcts_tree_png", True))
    want_json = bool(output_cfg.get("multimodel_mcts_tree_json", True))
    if not (want_png or want_json):
        return

    # Limits and threshold shared by the DOT and JSON exports.
    depth_cap = _optional_int(output_cfg.get("multimodel_mcts_tree_max_depth"), None)
    node_cap = _optional_int(output_cfg.get("multimodel_mcts_tree_max_nodes"), None)
    show_clean_pred = bool(
        output_cfg.get("multimodel_mcts_tree_include_clean_pred", False)
    )
    try:
        reward_threshold = float(
            output_cfg.get("multimodel_mcts_tree_correct_reward_threshold", 1.0)
        )
    except (TypeError, ValueError):
        reward_threshold = 1.0

    tree_dir = _resolve_multimodel_mcts_tree_dir(cfg)
    tree_dir.mkdir(parents=True, exist_ok=True)
    stem = tree_dir / f"multimodel_mcts_tree_sample_{sample_id}"

    dot_path = stem.with_suffix(".dot")
    algo.render_tree_dot(
        dot_path,
        max_depth=depth_cap,
        max_nodes=node_cap,
        include_clean_pred=show_clean_pred,
        correct_reward_threshold=reward_threshold,
    )
    if want_png:
        _render_dot_to_png(dot_path, stem.with_suffix(".png"))

    if not want_json:
        return

    algo.export_tree_json(
        stem.with_suffix(".json"),
        max_depth=depth_cap,
        max_nodes=node_cap,
        include_token_sequences=bool(
            output_cfg.get("multimodel_mcts_tree_json_include_token_sequences", True)
        ),
        include_state_text=bool(
            output_cfg.get("multimodel_mcts_tree_json_include_state_text", True)
        ),
        state_text_max_chars=_optional_int(
            output_cfg.get("multimodel_mcts_tree_json_state_text_max_chars", 600),
            600,
        ),
        include_rollouts=bool(
            output_cfg.get("multimodel_mcts_tree_json_include_rollouts", True)
        ),
        include_rollout_texts=bool(
            output_cfg.get("multimodel_mcts_tree_json_include_rollout_texts", True)
        ),
        rollout_text_max_chars=_optional_int(
            output_cfg.get("multimodel_mcts_tree_json_rollout_text_max_chars", 600),
            600,
        ),
        include_rollout_token_sequences=bool(
            output_cfg.get(
                "multimodel_mcts_tree_json_include_rollout_token_sequences", False
            )
        ),
        include_full_preds=bool(
            output_cfg.get("multimodel_mcts_tree_json_include_full_preds", False)
        ),
        pred_preview_chars=_optional_int(
            output_cfg.get("multimodel_mcts_tree_json_pred_preview_chars", 160),
            160,
        ),
        correct_reward_threshold=reward_threshold,
    )


def _render_abmcts_trees(algos: dict[int, ABMCTSD], cfg: dict[str, Any]) -> None:
    """Render the current and archived AB-MCTS states of every sample to HTML.

    Output goes to the directory returned by ``_resolve_abmcts_tree_dir``,
    which is created if needed. For each sample the current state is
    rendered, followed by one file per entry of ``state_archive``.
    """
    out_dir = _resolve_abmcts_tree_dir(cfg)
    out_dir.mkdir(parents=True, exist_ok=True)

    for sample_id, algo in algos.items():
        render_abmcts_tree_html(
            algo.abmcts_state,
            out_dir / f"abmcts_tree_sample_{sample_id}_current",
        )
        for archive_idx, archived_state in enumerate(algo.state_archive):
            render_abmcts_tree_html(
                archived_state,
                out_dir / f"abmcts_tree_sample_{sample_id}_archive_{archive_idx}",
            )


def _build_algo(cfg: dict[str, Any]) -> SearchAlgo:
    algo_cfg = cfg["algo"]
    nfe = int(cfg.get("nfe", 0))
    if nfe <= 0:
        nfe = -1

    if algo_cfg["name"] == "best_of_n":
        return BestOfN(
            BestOfNConfig(
                n=int(algo_cfg["n"]),
                gen_length=int(algo_cfg["gen_length"]),
                num_func_eval_budget=nfe,
            )
        )
    if algo_cfg["name"] == "dts":
        backup_lambda = algo_cfg.get("backup_lambda", math.inf)
        if isinstance(backup_lambda, str) and backup_lambda.lower() == "inf":
            backup_lambda = math.inf
        return DTS(
            DTSConfig(
                gen_length=int(algo_cfg["gen_length"]),
                num_func_eval_budget=nfe,
                full_rollout=bool(algo_cfg.get("full_rollout", True)),
                demask_schedule=list(
                    algo_cfg.get(
                        "demask_schedule", [0.9, 0.8, 0.7, 0.6, 0.5, 0.4, 0.2]
                    )
                ),
                min_unmask_num=int(algo_cfg.get("min_unmask_num", 3)),
                exploration_const=float(algo_cfg.get("exploration_const", 1.0)),
                progressive_width_const=float(
                    algo_cfg.get("progressive_width_const", 2.0)
                ),
                progressive_width_alpha=float(
                    algo_cfg.get("progressive_width_alpha", 0.7)
                ),
                backup_lambda=float(backup_lambda),
                max_selection_attempts=int(algo_cfg.get("max_selection_attempts", 20)),
            )
        )
    if algo_cfg["name"] == "abmcts":
        return ABMCTSD(
            ABMCTSDConfig(
                models=_resolve_model_names(cfg),
                gen_length=int(algo_cfg["gen_length"]),
                num_func_eval_budget=nfe,
                full_rollout=bool(algo_cfg.get("full_rollout", True)),
                demask_schedule=list(
                    algo_cfg.get(
                        "demask_schedule", [0.9, 0.8, 0.7, 0.6, 0.5, 0.4, 0.2]
                    )
                ),
                min_unmask_num=int(algo_cfg.get("min_unmask_num", 3)),
            )
        )
    if algo_cfg["name"] == "multimodel_mcts":
        return MultiModelMCTS(
            MultiModelMCTSConfig(
                models=_resolve_model_names(cfg),
                gen_length=int(algo_cfg["gen_length"]),
                num_func_eval_budget=nfe,
                full_rollout=bool(algo_cfg.get("full_rollout", True)),
                enable_rollout_cache=bool(algo_cfg.get("enable_rollout_cache", True)),
                demask_schedule=list(
                    algo_cfg.get(
                        "demask_schedule", [0.9, 0.8, 0.7, 0.6, 0.5, 0.4, 0.2]
                    )
                ),
                min_unmask_num=int(algo_cfg.get("min_unmask_num", 3)),
                exploration_const=float(algo_cfg.get("exploration_const", 1.0)),
                max_selection_attempts=int(
                    algo_cfg.get("max_selection_attempts", 20)
                ),
            )
        )
    if algo_cfg["name"] == "mask_strategy_mcts":
        return MaskStrategyMCTS(
            MaskStrategyMCTSConfig(
                mask_strategies=list(algo_cfg.get("mask_strategies", [])),
                gen_length=int(algo_cfg["gen_length"]),
                num_func_eval_budget=nfe,
                full_rollout=bool(algo_cfg.get("full_rollout", True)),
                enable_rollout_cache=bool(algo_cfg.get("enable_rollout_cache", True)),
                demask_schedule=list(
                    algo_cfg.get(
                        "demask_schedule", [0.9, 0.8, 0.7, 0.6, 0.5, 0.4, 0.2]
                    )
                ),
                min_unmask_num=int(algo_cfg.get("min_unmask_num", 3)),
                exploration_const=float(algo_cfg.get("exploration_const", 1.0)),
                max_selection_attempts=int(
                    algo_cfg.get("max_selection_attempts", 20)
                ),
            )
        )
    if algo_cfg["name"] == "temperature_mcts":
        return TemperatureMCTS(
            TemperatureMCTSConfig(
                temperatures=list(algo_cfg.get("temperatures", [])),
                gen_length=int(algo_cfg["gen_length"]),
                num_func_eval_budget=nfe,
                full_rollout=bool(algo_cfg.get("full_rollout", True)),
                enable_rollout_cache=bool(algo_cfg.get("enable_rollout_cache", True)),
                demask_schedule=list(
                    algo_cfg.get(
                        "demask_schedule", [0.9, 0.8, 0.7, 0.6, 0.5, 0.4, 0.2]
                    )
                ),
                min_unmask_num=int(algo_cfg.get("min_unmask_num", 3)),
                exploration_const=float(algo_cfg.get("exploration_const", 1.0)),
                max_selection_attempts=int(
                    algo_cfg.get("max_selection_attempts", 20)
                ),
            )
        )
    raise ValueError(f"unknown algo name: {algo_cfg['name']}")


def _run_search_for_sample(
    *,
    sample_id: int,
    prompt: str,
    task: Task,
    algo: SearchAlgo,
    generator: DemaskGenerator,
    default_model: str,
    cfg: dict[str, Any],
    progress: _ProgressTracker | None = None,
    do_submit: bool = True,
) -> dict[str, Any]:
    """Run the ask / generate / evaluate / tell loop for a single sample.

    Each iteration asks ``algo`` for a trial. If the trial carries a
    ``cached_result`` that result is re-evaluated and fed back to the
    algorithm without generating; otherwise ``generator.generate`` is called,
    the response is parsed and evaluated by ``task``, and a ``Result`` is
    reported via ``algo.tell``. The loop ends when ``max_trials`` trial records
    have been collected or ``_budget_exhausted`` reports true.

    Args:
        sample_id: Identifier passed to ``task`` for parsing/evaluation/submit.
        prompt: Prompt text forwarded to ``generator.generate``.
        task: Task used for ``parse_answer``, ``evaluate`` and ``submit``.
        algo: Search algorithm queried with ``ask``/``tell``/``top_k``.
        generator: Generator producing a response for each trial.
        default_model: Model name used when the trial does not select one.
        cfg: Experiment config; ``cfg["algo"]["name"]`` decides how
            ``trial.action`` is interpreted and ``cfg["nfe"]`` is checked
            when no trial cap is available.
        progress: Optional tracker; its NFE counter is advanced only for
            freshly generated (non-cached) trials.
        do_submit: When false, ``submit_metric``/``submit_detail`` are left
            as ``None`` in the returned dict.

    Returns:
        A dict with ``sample_id``, ``best_answer``, ``best_reward``,
        ``submit_metric``, ``submit_detail`` and the per-trial ``trials`` log.

    Raises:
        ValueError: If no trial cap is available and ``cfg["nfe"] <= 0``.
        RuntimeError: If a ``temperature_mcts`` trial action is ``None`` or
            cannot be converted to ``float``.
    """
    max_trials = _default_max_trials(cfg, algo)
    if max_trials is None and cfg.get("nfe", 0) <= 0:
        raise ValueError("max_trials is required when nfe <= 0")

    def _log_row(trial, stats, eval_res, answer, *, cached):
        # `stats` is whichever object carries the NFE / token-length fields:
        # the cached result on a cache hit, otherwise the object returned by
        # `generator.generate`. The remaining values are read from the
        # enclosing loop's current iteration.
        return {
            "trial_id": trial.trial_id,
            "model": model_name,
            "mask_strategy": mask_strategy,
            "demask_temperature": demask_temperature,
            "cached": cached,
            "reward": eval_res.metric,
            "answer": answer,
            "num_func_evals": stats.num_func_evals,
            "final_token_len": stats.final_token_len,
            "token_len_history": stats.token_len_history,
            "token_len_with_mask_history": stats.token_len_with_mask_history,
            "eval_detail": eval_res.sample_detail,
        }

    trials: list[dict[str, Any]] = []
    while max_trials is None or len(trials) < max_trials:
        if _budget_exhausted(cfg, algo):
            break

        trial = algo.ask()
        algo_name = cfg.get("algo", {}).get("name")
        is_mask_mcts = algo_name == "mask_strategy_mcts"
        is_temp_mcts = algo_name == "temperature_mcts"

        # For the two single-model MCTS variants the trial action is a
        # generation knob (mask strategy / temperature), not a model name.
        mask_strategy = trial.action if is_mask_mcts else None
        demask_temperature = None
        if is_temp_mcts:
            if trial.action is None:
                raise RuntimeError("temperature_mcts trial.action is None")
            try:
                demask_temperature = float(trial.action)
            except (TypeError, ValueError) as exc:
                raise RuntimeError(
                    f"temperature_mcts trial.action is not a float: {trial.action!r}"
                ) from exc
        if is_mask_mcts or is_temp_mcts:
            model_name = default_model
        else:
            model_name = trial.action or default_model

        cached_result = getattr(trial, "cached_result", None)
        if cached_result is not None:
            # Cache hit: no generation, but the answer is evaluated again and
            # the cached result is re-labelled with this trial's id.
            answer = cached_result.clean_pred
            eval_res = task.evaluate(sample_id, answer)
            cached_result.trial_id = trial.trial_id
            cached_result.reward = eval_res.metric

            algo.tell(cached_result)

            trials.append(
                _log_row(trial, cached_result, eval_res, answer, cached=True)
            )
            continue

        gen_kwargs = {
            "prompt": prompt,
            "parent_seq": trial.parent_token_seq,
            "num_tokens_to_demask": trial.num_tokens_to_demask,
            "full_rollout": trial.full_rollout,
            "model_name": model_name,
            "rollout_mask_targets": getattr(trial, "rollout_mask_targets", None),
            "remaining_func_evals": getattr(trial, "remaining_func_evals", None),
        }
        # The extra knob is only passed for the algorithm that produced it.
        if is_mask_mcts:
            gen_kwargs["mask_strategy"] = mask_strategy
        if is_temp_mcts:
            gen_kwargs["demask_temperature"] = demask_temperature
        gen = generator.generate(**gen_kwargs)

        answer = task.parse_answer(sample_id, gen.lm_response)
        eval_res = task.evaluate(sample_id, answer)

        outcome = Result(
            trial_id=trial.trial_id,
            new_token_seq=gen.new_token_seq,
            clean_pred=answer,
            num_func_evals=gen.num_func_evals,
            reward=eval_res.metric,
            rollout_token_seqs=gen.rollout_token_seqs,
            rollout_clean_preds=None,
            final_token_len=gen.final_token_len,
            token_len_history=gen.token_len_history,
            token_len_with_mask_history=gen.token_len_with_mask_history,
            demask_step_indices=gen.demask_step_indices,
            demask_mask_counts=gen.demask_mask_counts,
        )
        algo.tell(outcome)
        if progress is not None:
            progress.update_nfe(gen.num_func_evals)

        trials.append(_log_row(trial, gen, eval_res, answer, cached=False))

    # Prefer the algorithm's own top-1; fall back to the last logged trial.
    top = algo.top_k(1)
    if top:
        best_answer, best_reward = top[0]
    elif trials:
        best_answer, best_reward = trials[-1]["answer"], trials[-1]["reward"]
    else:
        best_answer, best_reward = "", None

    submit_metric = None
    submit_detail = None
    if do_submit:
        if best_answer:
            submit_res = task.submit(sample_id, best_answer)
        else:
            # An empty best answer is never submitted; it scores zero.
            submit_res = ProbResult(metric=0.0, sample_detail={"error": "no_trials"})
        submit_metric = submit_res.metric
        submit_detail = submit_res.sample_detail

    if isinstance(algo, MultiModelMCTS):
        _render_multimodel_mcts_tree(algo, sample_id, cfg)

    return {
        "sample_id": sample_id,
        "best_answer": best_answer,
        "best_reward": best_reward,
        "submit_metric": submit_metric,
        "submit_detail": submit_detail,
        "trials": trials,
    }


def _abmcts_sample_done(
    cfg: dict[str, Any],
    algo: ABMCTSD,
    trial_count: int,
    max_trials: int | None,
) -> bool:
    if max_trials is not None and trial_count >= max_trials:
        return True
    return _budget_exhausted(cfg, algo)


def _run_abmcts_parallel_for_samples(
    *,
    sample_ids: list[int],
    task: Task,
    generator: DemaskGenerator,
    default_model: str,
    cfg: dict[str, Any],
    progress: _ProgressTracker | None = None,
    do_submit: bool = True,
) -> list[dict[str, Any]]:
    """Run AB-MCTS over several samples, scheduling trials via ABMCTSDParallel.

    One ``ABMCTSD`` instance is built per sample. In each round the samples
    that are not yet done are passed to ``ABMCTSDParallel.ask``; every returned
    trial is generated, parsed, evaluated and reported back with ``tell``.

    Args:
        sample_ids: Samples to search; also fixes the order of the output.
        task: Task used for prompts, ``parse_answer``, ``evaluate``, ``submit``.
        generator: Generator producing a response for each trial.
        default_model: Model name used when ``trial.action`` is falsy.
        cfg: Experiment config passed to the algorithm/budget helpers.
        progress: Optional tracker updated with finished samples and NFE.
        do_submit: When false, ``submit_metric``/``submit_detail`` stay ``None``.

    Returns:
        One result dict per entry of ``sample_ids`` (same order) with
        ``sample_id``, ``best_answer``, ``best_reward``, ``submit_metric``,
        ``submit_detail`` and ``trials``.

    Raises:
        ValueError: If ``_build_algo`` does not return an ``ABMCTSD``, or if
            no trial cap is available and ``cfg["nfe"] <= 0``.
    """
    prompts = {sample_id: _task_prompt(task, sample_id) for sample_id in sample_ids}
    algos: dict[int, ABMCTSD] = {}
    for sample_id in sample_ids:
        algo = _build_algo(cfg)
        if not isinstance(algo, ABMCTSD):
            raise ValueError("abmcts parallel run requires ABMCTSD")
        algos[sample_id] = algo

    # The trial cap is derived from the first algorithm instance and applied
    # to every sample (all instances are built from the same cfg).
    max_trials = _default_max_trials(cfg, next(iter(algos.values())))
    if max_trials is None and cfg.get("nfe", 0) <= 0:
        raise ValueError("max_trials is required when nfe <= 0")

    trials: dict[int, list[dict[str, Any]]] = {
        sample_id: [] for sample_id in sample_ids
    }
    parallel_algo = ABMCTSDParallel(
        algos,
        ask_workers=_resolve_abmcts_ask_workers(cfg, sample_ids),
    )

    active_ids = list(sample_ids)
    while True:
        # Drop samples that hit their trial cap or budget since the last round.
        prev_active = list(active_ids)
        active_ids = [
            sample_id
            for sample_id in active_ids
            if not _abmcts_sample_done(
                cfg, algos[sample_id], len(trials[sample_id]), max_trials
            )
        ]
        if progress is not None and prev_active:
            # Report only the samples that finished in this round.
            done_now = set(prev_active) - set(active_ids)
            if done_now:
                progress.update_samples(len(done_now))
        if not active_ids:
            break

        for parallel_trial in parallel_algo.ask(active_ids):
            sample_id = parallel_trial.problem_id
            trial = parallel_trial.trial
            model_name = trial.action or default_model
            # NOTE(review): unlike `_run_search_for_sample`, this path does not
            # forward `rollout_mask_targets` and does not look at
            # `trial.cached_result` — confirm that is intentional.
            gen = generator.generate(
                prompt=prompts[sample_id],
                parent_seq=trial.parent_token_seq,
                num_tokens_to_demask=trial.num_tokens_to_demask,
                full_rollout=trial.full_rollout,
                model_name=model_name,
                remaining_func_evals=getattr(trial, "remaining_func_evals", None),
            )
            answer = task.parse_answer(sample_id, gen.lm_response)
            eval_res = task.evaluate(sample_id, answer)

            parallel_algo.tell(
                sample_id,
                Result(
                    trial_id=trial.trial_id,
                    new_token_seq=gen.new_token_seq,
                    clean_pred=answer,
                    num_func_evals=gen.num_func_evals,
                    reward=eval_res.metric,
                    rollout_token_seqs=gen.rollout_token_seqs,
                    rollout_clean_preds=None,
                    final_token_len=gen.final_token_len,
                    token_len_history=gen.token_len_history,
                    token_len_with_mask_history=gen.token_len_with_mask_history,
                    demask_step_indices=gen.demask_step_indices,
                    demask_mask_counts=gen.demask_mask_counts,
                ),
            )
            if progress is not None:
                progress.update_nfe(gen.num_func_evals)

            # Per-trial log entry; this is what ends up under "trials".
            trials[sample_id].append(
                {
                    "trial_id": trial.trial_id,
                    "model": model_name,
                    "reward": eval_res.metric,
                    "answer": answer,
                    "num_func_evals": gen.num_func_evals,
                    "final_token_len": gen.final_token_len,
                    "token_len_history": gen.token_len_history,
                    "token_len_with_mask_history": gen.token_len_with_mask_history,
                    "eval_detail": eval_res.sample_detail,
                }
            )

    results: list[dict[str, Any]] = []
    for sample_id in sample_ids:
        algo = algos[sample_id]
        sample_trials = trials[sample_id]

        # Prefer the algorithm's own top-1; fall back to the last logged trial.
        best_answer = ""
        best_reward = None
        top = algo.top_k(1)
        if top:
            best_answer, best_reward = top[0]
        elif sample_trials:
            best_answer = sample_trials[-1]["answer"]
            best_reward = sample_trials[-1]["reward"]

        submit_metric = None
        submit_detail = None
        if do_submit:
            # An empty best answer is never submitted; it scores zero.
            submit_res = (
                task.submit(sample_id, best_answer)
                if best_answer
                else ProbResult(metric=0.0, sample_detail={"error": "no_trials"})
            )
            submit_metric = submit_res.metric
            submit_detail = submit_res.sample_detail

        results.append(
            {
                "sample_id": sample_id,
                "best_answer": best_answer,
                "best_reward": best_reward,
                "submit_metric": submit_metric,
                "submit_detail": submit_detail,
                "trials": sample_trials,
            }
        )

    _render_abmcts_trees(algos, cfg)
    return results


def _submit_results(
    task: Task,
    results: list[dict[str, Any]],
    *,
    max_workers: int,
) -> None:
    if not results:
        return
    pending: list[tuple[int, int, str]] = []
    for idx, result in enumerate(results):
        if result.get("submit_metric") is not None:
            continue
        best_answer = result.get("best_answer") or ""
        if not best_answer:
            result["submit_metric"] = 0.0
            result["submit_detail"] = {"error": "no_trials"}
            continue
        pending.append((idx, result["sample_id"], best_answer))
    if not pending:
        return

    worker_count = max(1, min(max_workers, len(pending)))
    if worker_count == 1:
        for idx, sample_id, answer in pending:
            submit_res = task.submit(sample_id, answer)
            result = results[idx]
            result["submit_metric"] = submit_res.metric
            result["submit_detail"] = submit_res.sample_detail
        return

    with ThreadPoolExecutor(max_workers=worker_count) as executor:
        futures = {
            executor.submit(task.submit, sample_id, answer): idx
            for idx, sample_id, answer in pending
        }
        for future, idx in futures.items():
            submit_res = future.result()
            result = results[idx]
            result["submit_metric"] = submit_res.metric
            result["submit_detail"] = submit_res.sample_detail


def _count_result_nfe(result: dict[str, Any]) -> int:
    total = 0
    for trial in result.get("trials", []):
        try:
            total += int(trial.get("num_func_evals", 0) or 0)
        except (TypeError, ValueError):
            continue
    return total


def _run_queue_worker(
    worker_id: int,
    gpu: int | None,
    cfg: dict[str, Any],
    task_queue: Any,
    result_queue: Any,
) -> None:
    """Worker-process entry point for the dynamic (queue-based) schedule.

    Builds the task, model handles and generator once, then repeatedly takes
    a sample id from ``task_queue``, runs the configured search for it,
    submits the best answer and puts a payload on ``result_queue``. A ``None``
    item on ``task_queue`` ends the loop.

    Args:
        worker_id: Identifier echoed back in every payload.
        gpu: Value written to ``CUDA_VISIBLE_DEVICES`` when not ``None``.
        cfg: Experiment config for this worker.
        task_queue: Queue yielding sample ids and a final ``None`` sentinel.
        result_queue: Queue receiving ``{"result", "nfe", "worker_id"}`` on
            success or ``{"error", "worker_id"}`` on failure.

    Raises:
        Exception: Any exception is reported on ``result_queue`` and then
            re-raised.
    """
    try:
        if gpu is not None:
            # Set before `import torch` below — presumably so this process
            # only sees its assigned GPU; confirm against deployment setup.
            os.environ["CUDA_VISIBLE_DEVICES"] = str(gpu)

        _apply_task_defaults(cfg)
        _apply_nfe_overrides(cfg)
        system_prompt = _resolve_system_prompt(cfg)

        import torch

        # Use CUDA unless the config asks for CPU or no CUDA device is available.
        device_pref = cfg.get("demask", {}).get("device", "cuda")
        device = (
            torch.device("cuda")
            if device_pref != "cpu" and torch.cuda.is_available()
            else torch.device("cpu")
        )

        task = _build_task(cfg["task"])
        model_entries = _resolve_model_entries(cfg)
        model_names = [entry["name"] for entry in model_entries]
        handles = _build_model_handles(cfg, model_entries, device=device)
        generator = DemaskGenerator(
            handles,
            gen_length=int(cfg["algo"]["gen_length"]),
            system_prompt=system_prompt,
            debug_dir=_resolve_debug_dir(cfg),
            # Step tracing defaults to on for the three MCTS variants listed
            # here unless `output.demask_trace_steps` overrides it.
            trace_demask_steps=bool(
                cfg.get("output", {}).get(
                    "demask_trace_steps",
                    cfg.get("algo", {}).get("name")
                    in ("multimodel_mcts", "mask_strategy_mcts", "temperature_mcts"),
                )
            ),
        )
        # The first configured model is the fallback when a trial names none.
        default_model = model_names[0]
        submit_workers = _resolve_submit_workers(cfg, 1)

        while True:
            sample_id = task_queue.get()
            if sample_id is None:
                # Sentinel: no more work for this worker.
                break

            if cfg["algo"]["name"] == "abmcts":
                # The AB-MCTS runner takes a batch; here the batch is one sample.
                sample_results = _run_abmcts_parallel_for_samples(
                    sample_ids=[sample_id],
                    task=task,
                    generator=generator,
                    default_model=default_model,
                    cfg=cfg,
                    progress=None,
                    do_submit=False,
                )
                if not sample_results:
                    raise RuntimeError(
                        f"worker {worker_id} received empty result for sample {sample_id}"
                    )
                result = sample_results[0]
            else:
                prompt = _task_prompt(task, sample_id)
                algo = _build_algo(cfg)
                result = _run_search_for_sample(
                    sample_id=sample_id,
                    prompt=prompt,
                    task=task,
                    algo=algo,
                    generator=generator,
                    default_model=default_model,
                    cfg=cfg,
                    progress=None,
                    do_submit=False,
                )

            # Search ran with do_submit=False; submission happens here instead.
            _submit_results(task, [result], max_workers=submit_workers)
            result_queue.put(
                {
                    "result": result,
                    "nfe": _count_result_nfe(result),
                    "worker_id": worker_id,
                }
            )
    except Exception as exc:
        # Report the failure on the result queue, then re-raise.
        result_queue.put({"error": str(exc), "worker_id": worker_id})
        raise


def _run_shard(
    shard_id: int,
    sample_ids: list[int],
    cfg: dict[str, Any],
) -> list[dict[str, Any]]:
    """Run the configured search for every sample of one static shard.

    Builds the task, model handles and generator once, runs all samples in
    ``sample_ids`` (batched through ``_run_abmcts_parallel_for_samples`` for
    ``abmcts``, one at a time otherwise) and then submits the best answers.

    Args:
        shard_id: Index of this shard; selects a GPU from
            ``cfg["parallel"]["gpus"]`` (round-robin) and the progress
            position/label.
        sample_ids: Samples assigned to this shard.
        cfg: Experiment config for this shard.

    Returns:
        One result dict per sample; an empty list when ``sample_ids`` is empty.
    """
    if not sample_ids:
        return []

    _apply_task_defaults(cfg)
    _apply_nfe_overrides(cfg)
    system_prompt = _resolve_system_prompt(cfg)

    gpu_list = cfg.get("parallel", {}).get("gpus")
    if gpu_list:
        # Round-robin GPU assignment; set before `import torch` below —
        # presumably so this process only sees its assigned GPU (confirm).
        gpu = gpu_list[shard_id % len(gpu_list)]
        os.environ["CUDA_VISIBLE_DEVICES"] = str(gpu)

    import torch

    # Use CUDA unless the config asks for CPU or no CUDA device is available.
    device_pref = cfg.get("demask", {}).get("device", "cuda")
    device = (
        torch.device("cuda")
        if device_pref != "cpu" and torch.cuda.is_available()
        else torch.device("cpu")
    )

    task = _build_task(cfg["task"])
    model_entries = _resolve_model_entries(cfg)
    model_names = [entry["name"] for entry in model_entries]
    handles = _build_model_handles(cfg, model_entries, device=device)
    generator = DemaskGenerator(
        handles,
        gen_length=int(cfg["algo"]["gen_length"]),
        system_prompt=system_prompt,
        debug_dir=_resolve_debug_dir(cfg),
        # Step tracing defaults to on for the three MCTS variants listed here
        # unless `output.demask_trace_steps` overrides it.
        trace_demask_steps=bool(
            cfg.get("output", {}).get(
                "demask_trace_steps",
                cfg.get("algo", {}).get("name")
                in ("multimodel_mcts", "mask_strategy_mcts", "temperature_mcts"),
            )
        ),
    )
    # The first configured model is the fallback when a trial names none.
    default_model = model_names[0]

    progress = _ProgressTracker(
        enable=_progress_enabled(cfg),
        sample_total=len(sample_ids),
        nfe_total=_progress_nfe_total(cfg, len(sample_ids)),
        position=shard_id * 2 if shard_id >= 0 else None,
        label=_progress_label(cfg, shard_id),
    )
    submit_workers = _resolve_submit_workers(cfg, len(sample_ids))
    try:
        results: list[dict[str, Any]] = []
        if cfg["algo"]["name"] == "abmcts":
            # The batched runner reports finished samples to `progress` itself.
            results = _run_abmcts_parallel_for_samples(
                sample_ids=sample_ids,
                task=task,
                generator=generator,
                default_model=default_model,
                cfg=cfg,
                progress=progress,
                do_submit=False,
            )
        else:
            for sample_id in sample_ids:
                prompt = _task_prompt(task, sample_id)
                algo = _build_algo(cfg)
                result = _run_search_for_sample(
                    sample_id=sample_id,
                    prompt=prompt,
                    task=task,
                    algo=algo,
                    generator=generator,
                    default_model=default_model,
                    cfg=cfg,
                    progress=progress,
                    do_submit=False,
                )
                results.append(result)
                progress.update_samples(1)
        # Searches ran with do_submit=False; submit all best answers together.
        _submit_results(task, results, max_workers=submit_workers)
        return results
    finally:
        # Close the progress tracker whether or not the run succeeded.
        progress.close()


def _failed_workers(workers: list[mp.Process]) -> list[mp.Process]:
    return [worker for worker in workers if worker.exitcode not in (None, 0)]


def _run_queue_schedule(
    sample_ids: list[int],
    cfg: dict[str, Any],
) -> list[dict[str, Any]]:
    if not sample_ids:
        return []

    parallel_cfg = cfg.get("parallel", {})
    try:
        requested_workers = int(parallel_cfg.get("num_workers", 1))
    except (TypeError, ValueError):
        requested_workers = 1
    worker_count = max(1, min(requested_workers, len(sample_ids)))
    gpu_list = list(parallel_cfg.get("gpus") or [])

    ctx = mp.get_context("spawn")
    task_queue = ctx.Queue()
    result_queue = ctx.Queue()

    for sample_id in sample_ids:
        task_queue.put(sample_id)
    for _ in range(worker_count):
        task_queue.put(None)

    workers: list[mp.Process] = []
    for worker_id in range(worker_count):
        gpu = None
        if gpu_list:
            gpu = gpu_list[worker_id % len(gpu_list)]
        worker = ctx.Process(
            target=_run_queue_worker,
            name=f"gpu-worker-{worker_id}",
            args=(worker_id, gpu, cfg, task_queue, result_queue),
        )
        worker.start()
        workers.append(worker)

    progress = None
    if _progress_enabled(cfg):
        progress = _ProgressTracker(
            enable=True,
            sample_total=len(sample_ids),
            nfe_total=_progress_nfe_total(cfg, len(sample_ids)),
            position=None,
            label=None,
        )

    results: list[dict[str, Any]] = []
    completed = 0
    total = len(sample_ids)
    error_msg = None
    try:
        while completed < total:
            try:
                payload = result_queue.get(timeout=1.0)
            except queue.Empty:
                failed = _failed_workers(workers)
                if failed:
                    error_msg = "queue workers failed: " + ", ".join(
                        f"{worker.name}:{worker.exitcode}" for worker in failed
                    )
                    break
                continue

            if not isinstance(payload, dict):
                error_msg = "queue worker returned invalid payload"
                break
            if "error" in payload:
                error_msg = (
                    f"queue worker {payload.get('worker_id')} failed: "
                    f"{payload.get('error')}"
                )
                break
            result = payload.get("result")
            if result is None:
                error_msg = "queue worker returned empty result"
                break
            results.append(result)
            completed += 1
            if progress is not None:
                progress.update_samples(1)
                progress.update_nfe(int(payload.get("nfe", 0) or 0))
    finally:
        if progress is not None:
            progress.close()

    if error_msg:
        for worker in workers:
            if worker.is_alive():
                worker.terminate()
        for worker in workers:
            worker.join()
        raise RuntimeError(error_msg)

    for worker in workers:
        worker.join()

    failed = _failed_workers(workers)
    if failed:
        raise RuntimeError(
            "queue workers failed: "
            + ", ".join(f"{worker.name}:{worker.exitcode}" for worker in failed)
        )
    if completed < total:
        raise RuntimeError(
            f"queue run incomplete: {completed}/{total} results collected"
        )
    return results


def run_experiment(cfg: dict[str, Any]) -> dict[str, Any]:
    """Run the whole experiment described by ``cfg`` and write a JSON report.

    Samples are selected from the task, processed either through the dynamic
    queue schedule or through static shards run with joblib, sorted by
    ``sample_id`` and aggregated into a summary. The summary, the per-sample
    results and the config are passed through ``_json_safe`` and written to
    ``cfg["output"]["path"]`` (default ``results.json``).

    Args:
        cfg: Experiment config as a plain dict.

    Returns:
        The JSON-safe output with ``summary``, ``results`` and ``config``.
    """
    _set_seed(cfg.get("seed"))
    _apply_task_defaults(cfg)
    _apply_nfe_overrides(cfg)
    _resolve_system_prompt(cfg)

    task_cfg = cfg["task"]
    task = _build_task(task_cfg)
    sample_ids = _select_sample_ids(task, task_cfg, cfg.get("seed"))
    requested_workers = int(cfg.get("parallel", {}).get("num_workers", 1))
    shards = _split_ids(sample_ids, requested_workers)
    n_jobs = max(1, min(requested_workers, len(shards)))

    if _resolve_parallel_schedule(cfg) == "dynamic":
        results = _run_queue_schedule(sample_ids, cfg)
    else:
        with parallel_config(backend="loky", n_jobs=n_jobs, prefer="processes"):
            shard_results = Parallel(return_as="list")(
                delayed(_run_shard)(shard_id, shard, cfg)
                for shard_id, shard in enumerate(shards)
            )
        # Flatten the per-shard lists into one list of per-sample results.
        results = []
        for shard_output in shard_results:
            results.extend(shard_output)
    results.sort(key=lambda entry: entry["sample_id"])

    total_trials = 0
    total_cached_trials = 0
    total_nfe = 0
    unique_answer_counts: list[int] = []
    unique_uncached_answer_counts: list[int] = []

    for entry in results:
        trial_log = entry.get("trials")
        if not isinstance(trial_log, list):
            # Results without a trial list do not contribute to trial stats.
            continue

        cached_count = 0
        nfe_sum = 0
        seen_answers: set[str] = set()
        seen_uncached_answers: set[str] = set()
        for record in trial_log:
            if not isinstance(record, dict):
                continue
            was_cached = bool(record.get("cached", False))
            if was_cached:
                cached_count += 1
            try:
                nfe_sum += int(record.get("num_func_evals", 0) or 0)
            except (TypeError, ValueError):
                pass
            text = record.get("answer")
            if not isinstance(text, str):
                continue
            seen_answers.add(text)
            if not was_cached:
                seen_uncached_answers.add(text)

        # Non-dict entries are skipped above but still count as trials here.
        total_trials += len(trial_log)
        total_cached_trials += cached_count
        total_nfe += nfe_sum
        unique_answer_counts.append(len(seen_answers))
        unique_uncached_answer_counts.append(len(seen_uncached_answers))

    def _mean(values: list[Any]) -> Any:
        # Arithmetic mean, or None for an empty list.
        return sum(values) / len(values) if values else None

    def _per_result(amount: int) -> float | None:
        # Average of an aggregate over all results, or None without results.
        return amount / len(results) if results else None

    submit_scores = [
        entry["submit_metric"]
        for entry in results
        if entry["submit_metric"] is not None
    ]
    best_rewards = [
        entry["best_reward"] for entry in results if entry["best_reward"] is not None
    ]
    total_uncached_trials = total_trials - total_cached_trials
    summary = {
        "num_samples": len(results),
        "mean_submit": _mean(submit_scores),
        "mean_best_reward": _mean(best_rewards),
        "total_nfe": total_nfe,
        "mean_nfe": _per_result(total_nfe),
        "total_trials": total_trials,
        "total_cached_trials": total_cached_trials,
        "total_uncached_trials": total_uncached_trials,
        "cache_hit_rate": total_cached_trials / total_trials if total_trials else None,
        "mean_trials": _per_result(total_trials),
        "mean_cached_trials": _per_result(total_cached_trials),
        "mean_uncached_trials": _per_result(total_uncached_trials),
        "mean_unique_answers": _mean(unique_answer_counts),
        "mean_unique_uncached_answers": _mean(unique_uncached_answer_counts),
    }

    output = _json_safe(
        {
            "summary": summary,
            "results": results,
            "config": cfg,
        }
    )
    output_path = Path(cfg.get("output", {}).get("path", "results.json"))
    output_path.parent.mkdir(parents=True, exist_ok=True)
    with output_path.open("w", encoding="utf-8") as f:
        json.dump(output, f, indent=2, ensure_ascii=True)

    return output


@hydra.main(version_base=None, config_path="conf", config_name="experiment")
def main(cfg: DictConfig) -> None:
    """Hydra entry point: resolve the config to a container and run it."""
    plain_cfg = OmegaConf.to_container(cfg, resolve=True)
    run_experiment(plain_cfg)


# Run the hydra-decorated entry point only when executed as a script.
if __name__ == "__main__":
    main()
