from __future__ import annotations

import argparse
import json
import math
import os
import random
import re
import select
import site
import subprocess
import sys
import time
from collections import Counter, defaultdict
from dataclasses import asdict, dataclass, field
from pathlib import Path
from typing import Any, Dict, Iterable, Iterator, List, Sequence, Tuple

try:
    from sgfmill import boards as sgf_boards
    from sgfmill import sgf
except ImportError:  # pragma: no cover
    sgf = None
    sgf_boards = None


# GTP-style column letters, in column order; note there is no "I".
BOARD_COLUMNS = "ABCDEFGH" + "JKLMNOPQRSTUVWXYZ"
# Rules string used when an SGF carries no RU property.
DEFAULT_RULES = "japanese"
# Tokenizer: a run of ASCII letters/digits/underscore, or any single character
# that is neither a word character nor whitespace.
WORD_RE = re.compile(r"[A-Za-z0-9_]+|[^\w\s]")


@dataclass
class ThresholdConfig:
    """Cut-points that turn raw KataGo numbers into categorical claim labels.

    Each field is read by one ``classify_*`` helper in this module; the
    comments below name that helper and the comparison the value feeds.
    """

    # classify_phase_estimate: rootInfo.rawVarTimeLeft >= each minimum, checked
    # in order OPENING, MID, LATE; anything below phase_late_min is SETTLED.
    phase_opening_min: float = 12.0
    phase_mid_min: float = 6.0
    phase_late_min: float = 2.5
    # classify_main_control_region: when the region with the largest |mean
    # ownership| is below this magnitude, the label is CTRL_NONE.
    control_none_abs_threshold: float = 0.08
    # classify_main_contested_region: CONTEST_NONE when the highest regional
    # mean contest score is below this.
    contest_none_threshold: float = 0.25
    # classify_global_contestedness: board-wide mean contest score < low -> LOW,
    # < high -> MED, otherwise HIGH.
    contest_low_threshold: float = 0.25
    contest_high_threshold: float = 0.50
    # classify_move_urgency: tiers are tried MUST, HIGH, MED; a tier is chosen when
    # ANY of its three values is reached (>=): winrate gap or scoreLead gap between
    # the top two moves, or the top move's share of edgeVisits. Otherwise URG_LOW.
    urgency_must_win_gap: float = 0.10
    urgency_high_win_gap: float = 0.05
    urgency_med_win_gap: float = 0.02
    urgency_must_score_gap: float = 12.0
    urgency_high_score_gap: float = 6.0
    urgency_med_score_gap: float = 3.0
    urgency_must_visit_share: float = 0.75
    urgency_high_visit_share: float = 0.55
    urgency_med_visit_share: float = 0.35
    # classify_search_surprise: surprise = (1 - top move prior) * top move visit
    # share; >= high -> HIGH, >= med -> MED, else LOW.
    surprise_high_threshold: float = 0.50
    surprise_med_threshold: float = 0.20


@dataclass
class BuildConfig:
    """All settings for one dataset build: SGF mining, KataGo querying and splitting."""

    # Directory searched recursively for *.sgf files.
    sgf_dir: str
    # Directory receiving the JSONL outputs and dataset_manifest.json.
    output_dir: str
    # Cap on the number of candidates kept after shuffling with `seed`.
    max_positions: int = 10_000
    # Only the first N SGF paths (sorted) are read; <= 0 means no limit.
    max_games: int = 0
    # Per-SGF cap on mined comments; <= 0 means no limit.
    max_positions_per_game: int = 8
    # Minimum normalized comment length accepted by comment_is_useful.
    min_comment_chars: int = 20
    # Seed for the game-level train/eval shuffle in split_rows_by_game.
    split_seed: int = 42
    # Target fraction of rows placed in the eval split.
    eval_fraction: float = 0.10
    # Position token lists are truncated to this many tokens.
    max_position_tokens: int = 196
    # KataGo executable, model file and analysis config; run_katago_queries
    # requires all three to be non-empty.
    katago_binary: str = ""
    katago_model: str = ""
    katago_config: str = ""
    # Sent as maxVisits in every query.
    katago_visits: int = 2000
    # Maximum number of submitted-but-unanswered queries.
    katago_inflight_limit: int = 64
    # Seconds to wait for the next response before abandoning all in-flight queries.
    query_timeout_sec: float = 180.0
    # Sent as includeOwnership / includeOwnershipStdev. derive_claim_fields raises
    # when either array is missing from a response, so disabling either presumably
    # makes every query fail — NOTE(review): confirm.
    include_ownership: bool = True
    include_ownership_stdev: bool = True
    # When > 0, forwarded to KataGo as reportDuringSearchEvery.
    report_during_search_every: float = 0.0
    # Completed rows are appended to full.incremental.jsonl in batches of this size.
    incremental_flush_every: int = 50
    # Seed for the global `random` module and for the candidate shuffle.
    seed: int = 42
    thresholds: ThresholdConfig = field(default_factory=ThresholdConfig)


@dataclass
class CandidatePosition:
    """One comment-bearing main-line position mined from an SGF, ready for a KataGo query."""

    # Position in the shuffled, truncated candidate list; -1 until
    # mine_comment_positions assigns it. Output rows are ordered by it.
    index: int
    # Built by candidate_id(); also used as the KataGo query id.
    id: str
    # SGF file stem.
    game_id: str
    # Number of main-line moves played to reach this position (0 = none yet).
    turn_number: int
    board_size: int
    # Side to play next ("B" or "W").
    to_move: str
    rules: str
    komi: float
    # {"black": [...], "white": [...]} GTP coordinates of stones on the board.
    stones: Dict[str, List[str]]
    # Output of build_position_tokens for this position.
    position_tokens: List[str]
    # Comment text after normalize_commentary.
    rationale_text: str
    # Root setup stones as [["B", "D4"], ...]; sent to KataGo as initialStones.
    initial_stones: List[List[str]]
    # Moves played so far as [["B", "Q16"], ["W", "pass"], ...]; sent as moves.
    moves: List[List[str]]
    # First player; sent as initialPlayer only when `moves` is empty.
    initial_player: str


def set_seed(seed: int) -> None:
    """Seed the module-global :mod:`random` generator so runs are reproducible."""
    random.seed(a=seed)


def clamp_probability(value: float) -> float:
    """Coerce *value* to float and clip it into the closed interval [0, 1].

    NaN becomes 0.0 because ``max(0.0, nan)`` keeps its first argument.
    """
    as_float = float(value)
    floored = max(0.0, as_float)
    return min(1.0, floored)


def maybe_float(value: Any, default: float) -> float:
    """Return ``float(value)``, or *default* when the value cannot be converted."""
    try:
        converted = float(value)
    except (TypeError, ValueError):
        converted = default
    return converted


def maybe_int(value: Any, default: int) -> int:
    """Return ``int(value)``, or *default* when the value cannot be converted.

    Floats are truncated by ``int``; numeric strings with a fraction ("3.0") fail.
    """
    try:
        converted = int(value)
    except (TypeError, ValueError):
        converted = default
    return converted


def maybe_text(node: Any, prop: str) -> str:
    """Read property *prop* from *node* as text.

    A KeyError from ``node.get`` (property absent) or a falsy value yields "".
    """
    try:
        raw = node.get(prop)
    except KeyError:
        raw = None
    return str(raw) if raw else ""


def rule_token(rules: str) -> str:
    """Build a ``RULE_*`` token: trimmed, upper-cased, each non-alphanumeric run -> "_"."""
    normalized = rules.strip().upper()
    return f"RULE_{re.sub(r'[^A-Za-z0-9]+', '_', normalized)}"


def normalize_komi_token(komi: float) -> str:
    """Encode *komi* as a token-safe ``KOMI_*`` string (6.5 -> KOMI_6_5, -0.5 -> KOMI_NEG_0_5)."""
    text = str(komi)
    for old, new in (("-", "NEG_"), (".", "_")):
        text = text.replace(old, new)
    return "KOMI_" + text


def row_col_to_gtp(row: int, col: int, board_size: int) -> str:
    """Convert a 0-based (row, col) into a GTP vertex string such as "D4".

    The column letter comes from ``BOARD_COLUMNS``; row 0 maps to GTP row
    number ``board_size`` and the last row to 1.
    """
    letter = BOARD_COLUMNS[col]
    number = board_size - row
    return letter + str(number)


def gtp_to_row_col(coord: str, board_size: int) -> Tuple[int, int]:
    """Convert a GTP vertex such as "Q16" into a 0-based (row, col) pair.

    Inverse of ``row_col_to_gtp``. Column letters skip "I" (A..H, J..Z), as in
    ``BOARD_COLUMNS``, and GTP row ``board_size`` maps to row 0. Multi-letter
    columns ("AA", ...) are read as bijective base-25 over the same I-less
    alphabet.

    Raises:
        ValueError: for "pass", malformed text, a column containing "I", or a
            vertex outside a ``board_size`` x ``board_size`` board.
    """
    text = coord.strip().upper()
    if text == "PASS":
        raise ValueError("pass has no board coordinate")
    match = re.fullmatch(r"([A-Z]+)(\d+)", text)
    if not match or "I" in match.group(1):
        raise ValueError(f"Bad GTP coordinate: {text}")
    col_text, row_text = match.groups()
    col = 0
    for ch in col_text:
        # "I" is not a GTP column, so every letter after it shifts down by one.
        letter_index = ord(ch) - ord("A") - (1 if ch > "I" else 0)
        col = col * 25 + letter_index + 1
    col -= 1
    row = board_size - int(row_text)
    if not (0 <= row < board_size and 0 <= col < board_size):
        raise ValueError(f"GTP coordinate {text} is off a {board_size}x{board_size} board")
    return row, col


def build_position_tokens(
    board_size: int,
    to_move: str,
    komi: float,
    rules: str,
    stones: Dict[str, Sequence[str]],
    max_position_tokens: int,
) -> List[str]:
    """Serialize a position as a token list truncated to *max_position_tokens*.

    Layout: size, side to move, komi, rules (DEFAULT_RULES when empty), then
    sorted upper-cased black stones as ``B_<coord>`` and white as ``W_<coord>``.
    """
    header = [
        "SZ" + str(board_size),
        "TM_" + to_move.upper(),
        normalize_komi_token(komi),
        rule_token(rules or DEFAULT_RULES),
    ]
    black_tokens = ["B_" + point for point in sorted(c.upper() for c in stones.get("black", []))]
    white_tokens = ["W_" + point for point in sorted(c.upper() for c in stones.get("white", []))]
    return (header + black_tokens + white_tokens)[:max_position_tokens]


def normalize_commentary(text: str) -> str:
    """Collapse whitespace runs to single spaces, then trim spaces/brackets at both ends."""
    flattened = text.replace("\r", " ").replace("\n", " ")
    words = flattened.split()
    return " ".join(words).strip(" []")


def board_to_stones(board: Any, board_size: int) -> Dict[str, List[str]]:
    """List the GTP coordinates of every black ("b") and white ("w") stone on *board*.

    Points are scanned row by row (row 0 first), so lists come out in that order.
    """
    colour_to_key = {"b": "black", "w": "white"}
    stones: Dict[str, List[str]] = {"black": [], "white": []}
    for row in range(board_size):
        for col in range(board_size):
            key = colour_to_key.get(board.get(row, col))
            if key is not None:
                stones[key].append(row_col_to_gtp(row, col, board_size))
    return stones


def initial_stones_tuples(stones: Dict[str, Sequence[str]]) -> List[List[str]]:
    """Flatten a stones dict into KataGo-style ``[["B", "D4"], ..., ["W", "Q16"]]`` pairs.

    All black stones come first, then white; coordinates are upper-cased.
    """
    return [
        [label, coord.upper()]
        for label, key in (("B", "black"), ("W", "white"))
        for coord in stones.get(key, [])
    ]


def iter_mainline_nodes(game: Any) -> List[Any]:
    """Return ``game.get_main_sequence()``; callers treat element 0 as the root node.

    Raises:
        ImportError: if sgfmill could not be imported.
    """
    if sgf is not None:
        return game.get_main_sequence()
    raise ImportError("sgfmill is required")


def comment_is_useful(text: str, min_comment_chars: int) -> bool:
    """True when the normalized comment has at least *min_comment_chars* characters
    and contains at least one alphabetic character."""
    cleaned = normalize_commentary(text)
    long_enough = len(cleaned) >= min_comment_chars
    return long_enough and any(ch.isalpha() for ch in cleaned)


def infer_initial_player(root: Any, moves: Sequence[List[str]]) -> str:
    """Pick the first player: the first move's colour if any moves exist,
    otherwise the root's PL property when it is "B"/"W", otherwise "B"."""
    declared = maybe_text(root, "PL").strip().upper()
    if not moves:
        return declared if declared in {"B", "W"} else "B"
    return moves[0][0]


def first_mainline_player(nodes: Sequence[Any]) -> str:
    """Upper-cased colour of the first move after the root node, or "B" if none."""
    colours = (node.get_move()[0] for node in nodes[1:])
    first = next((colour for colour in colours if colour is not None), None)
    return "B" if first is None else str(first).upper()


def candidate_id(game_id: str, turn_number: int, comment_index: int) -> str:
    """Build a candidate id of the form ``<game_id>_move<turn>_c<comment_index>``."""
    return "{}_move{}_c{}".format(game_id, turn_number, comment_index)


def _place_root_setup_stones(root: Any, board: Any) -> None:
    """Put the root node's AB/AW setup stones on *board*, black first, then white.

    sgfmill's ``Node.get_setup_stones()`` takes no arguments and returns
    ``(black_points, white_points, empty_points)`` as sets of (row, col); the
    points are sorted so placement order is deterministic.
    """
    black_points, white_points, _empty_points = root.get_setup_stones()
    for colour, points in (("b", black_points), ("w", white_points)):
        for row, col in sorted(points):
            board.play(row, col, colour)


def extract_comment_positions_from_sgf(path: Path, cfg: BuildConfig) -> List[CandidatePosition]:
    """Turn every useful comment on one SGF's main line into a CandidatePosition.

    Comments are taken from the root, from move-less nodes before the first
    move (both as turn 0), and from move nodes (position after that move).
    At most ``cfg.max_positions_per_game`` are kept (<= 0 means unlimited).
    Root AB/AW setup stones are placed on the board before any move.

    Raises:
        ImportError: if sgfmill is not installed.
        ValueError: if a main-line move lands on an occupied point; other
            parse errors from sgfmill propagate unchanged.
    """
    if sgf is None or sgf_boards is None:
        raise ImportError("sgfmill is required for SGF parsing")
    game = sgf.Sgf_game.from_bytes(path.read_bytes())
    nodes = iter_mainline_nodes(game)
    root = nodes[0]
    board_size = game.get_size()
    rules = maybe_text(root, "RU") or DEFAULT_RULES
    komi = maybe_float(maybe_text(root, "KM"), 6.5)

    board = sgf_boards.Board(board_size)
    # Previously this called get_setup_stones(prop), which raises TypeError and was
    # swallowed, so handicap/setup stones were silently dropped.
    _place_root_setup_stones(root, board)

    moves: List[List[str]] = []
    initial_stones = initial_stones_tuples(board_to_stones(board, board_size))
    # An explicit PL wins; otherwise the colour of the first main-line move.
    root_player = maybe_text(root, "PL").strip().upper()
    initial_player = root_player if root_player in {"B", "W"} else first_mainline_player(nodes)

    game_id = path.stem
    candidates: List[CandidatePosition] = []
    comment_index = 0

    def has_room() -> bool:
        return cfg.max_positions_per_game <= 0 or len(candidates) < cfg.max_positions_per_game

    def add_candidate(turn_number: int, to_move: str, comment: str) -> None:
        # Snapshot the current board; `moves` is copied so later plays don't leak in.
        nonlocal comment_index
        stones = board_to_stones(board, board_size)
        candidates.append(
            CandidatePosition(
                index=-1,
                id=candidate_id(game_id, turn_number, comment_index),
                game_id=game_id,
                turn_number=turn_number,
                board_size=board_size,
                to_move=to_move,
                rules=rules,
                komi=komi,
                stones=stones,
                position_tokens=build_position_tokens(
                    board_size, to_move, komi, rules, stones, cfg.max_position_tokens
                ),
                rationale_text=normalize_commentary(comment),
                initial_stones=initial_stones,
                moves=list(moves),
                # Once moves exist, the first mover is authoritative.
                initial_player=moves[0][0] if moves else initial_player,
            )
        )
        comment_index += 1

    root_comment = maybe_text(root, "C")
    if root_comment and comment_is_useful(root_comment, cfg.min_comment_chars) and has_room():
        add_candidate(0, initial_player, root_comment)

    for node in nodes[1:]:
        color, move = node.get_move()
        if color is None:
            # Move-less node: only comments before the first move describe a turn-0 position.
            comment = maybe_text(node, "C")
            if (
                comment
                and not moves
                and comment_is_useful(comment, cfg.min_comment_chars)
                and has_room()
            ):
                add_candidate(0, initial_player, comment)
            continue

        color = color.upper()
        gtp_move = "pass" if move is None else row_col_to_gtp(move[0], move[1], board_size)
        if move is not None:
            try:
                board.play(move[0], move[1], color.lower())
            except ValueError as exc:
                raise ValueError(f"invalid move in {path.name} at turn {len(moves)+1}: {color}[{gtp_move}]") from exc
        moves.append([color, gtp_move])
        comment = maybe_text(node, "C")
        if not comment or not comment_is_useful(comment, cfg.min_comment_chars):
            continue
        if not has_room():
            continue
        add_candidate(len(moves), "W" if color == "B" else "B", comment)
    return candidates


def mine_comment_positions(cfg: BuildConfig) -> Tuple[List[CandidatePosition], Counter]:
    """Collect comment-bearing positions from every SGF under ``cfg.sgf_dir``.

    SGF paths are sorted and optionally capped by ``cfg.max_games``. Files that
    fail to parse are skipped with a warning and counted in ``skipped_sgf``.
    Candidates are shuffled with ``random.Random(cfg.seed)``, capped at
    ``cfg.max_positions``, and given their final ``index``. As with
    ``max_games``, a cap <= 0 means "no limit".

    Returns:
        (candidates, stats) where stats counts games_considered,
        candidate_positions, skipped_sgf and selected_positions.
    """
    sgf_paths = sorted(Path(cfg.sgf_dir).glob("**/*.sgf"))
    if cfg.max_games > 0:
        sgf_paths = sgf_paths[: cfg.max_games]
    stats: Counter = Counter()
    candidates: List[CandidatePosition] = []
    for sgf_path in sgf_paths:
        try:
            rows = extract_comment_positions_from_sgf(sgf_path, cfg)
        except Exception as exc:  # one malformed SGF must not abort the whole build
            stats["skipped_sgf"] += 1
            print(f"[WARN] Skipping {sgf_path}: {exc}")
            continue
        candidates.extend(rows)
        stats["games_considered"] += 1
        stats["candidate_positions"] += len(rows)

    rng = random.Random(cfg.seed)
    rng.shuffle(candidates)
    # Previously a non-positive cap sliced the list to nothing.
    if cfg.max_positions > 0:
        candidates = candidates[: cfg.max_positions]
    for idx, candidate in enumerate(candidates):
        candidate.index = idx
    stats["selected_positions"] = len(candidates)
    return candidates, stats


def region_name_for_point(row: int, col: int, board_size: int) -> str:
    """Name the board region containing (row, col): TOP, BOTTOM, LEFT, RIGHT or CENTER.

    The outer bands are ``ceil(size/4)`` lines thick on the near side and start
    at ``floor(3*size/4)`` on the far side. Row bands are checked before column
    bands, so corners count as TOP/BOTTOM.
    """
    near_edge = math.ceil(0.25 * board_size)
    far_edge = math.floor(0.75 * board_size)
    bands = (
        (row < near_edge, "TOP"),
        (row >= far_edge, "BOTTOM"),
        (col < near_edge, "LEFT"),
        (col >= far_edge, "RIGHT"),
    )
    return next((name for inside, name in bands if inside), "CENTER")


def region_scores_from_board(values: Sequence[float], board_size: int) -> Dict[str, float]:
    """Average a row-major per-point array over the five named board regions.

    Index ``i`` maps to ``(i // board_size, i % board_size)``. Returns a dict
    keyed TOP, BOTTOM, LEFT, RIGHT, CENTER (in that order). A region with no
    points, which can happen on very small boards, averages to 0.0 instead of
    raising ZeroDivisionError.

    Raises:
        ValueError: if ``len(values)`` is not ``board_size * board_size``.
            A wrong length would otherwise be silently mis-mapped onto the board.
    """
    expected = board_size * board_size
    if len(values) != expected:
        raise ValueError(
            f"expected {expected} per-point values for a {board_size}x{board_size} board, got {len(values)}"
        )
    sums: Dict[str, float] = defaultdict(float)
    counts: Dict[str, int] = defaultdict(int)
    for idx, value in enumerate(values):
        row, col = divmod(idx, board_size)
        region = region_name_for_point(row, col, board_size)
        sums[region] += float(value)
        counts[region] += 1
    return {
        region: (sums[region] / counts[region]) if counts[region] else 0.0
        for region in ("TOP", "BOTTOM", "LEFT", "RIGHT", "CENTER")
    }


def compute_contest_scores(ownership: Sequence[float], ownership_stdev: Sequence[float]) -> List[float]:
    """Per-point contest score: 0.5 * (1 - |ownership|) + 0.5 * ownership stdev.

    Raises:
        ValueError: if the two sequences differ in length.
    """
    if len(ownership) != len(ownership_stdev):
        raise ValueError("ownership and ownershipStdev length mismatch")
    return [
        0.5 * (1.0 - abs(float(own))) + 0.5 * float(spread)
        for own, spread in zip(ownership, ownership_stdev)
    ]


def classify_win_prob_bin(win_prob: float) -> int:
    """Map a win probability to an integer bin 0..9 via ``round(clamped_p * 9)``.

    ``round`` is round-half-to-even, so exact .5 products go to the even bin.
    """
    scaled = round(clamp_probability(win_prob) * 9)
    return max(0, min(9, scaled))


def classify_score_lead_bin(score: float) -> str:
    """Bucket a score lead into one of seven LEAD_* labels.

    Negative cut-points are inclusive (<= -20, <= -10, <= -5). Positive
    cut-points are exclusive upper bounds (< 5 is CLOSE, < 10, < 20). Anything
    else, including NaN, is LEAD_P20P.
    """
    for cut, label in ((-20, "LEAD_N20P"), (-10, "LEAD_N10_20"), (-5, "LEAD_N5_10")):
        if score <= cut:
            return label
    for cut, label in ((5, "LEAD_CLOSE"), (10, "LEAD_P5_10"), (20, "LEAD_P10_20")):
        if score < cut:
            return label
    return "LEAD_P20P"


def classify_phase_estimate(raw_var_time_left: float, thresholds: ThresholdConfig) -> str:
    """Label the game phase from KataGo's rawVarTimeLeft.

    The first minimum reached (>=), in order OPENING, MID, LATE, wins;
    otherwise the result is PHASE_SETTLED.
    """
    value = float(raw_var_time_left)
    ladder = (
        (thresholds.phase_opening_min, "PHASE_OPENING"),
        (thresholds.phase_mid_min, "PHASE_MID"),
        (thresholds.phase_late_min, "PHASE_LATE"),
    )
    for minimum, label in ladder:
        if value >= minimum:
            return label
    return "PHASE_SETTLED"


def classify_main_control_region(
    ownership: Sequence[float], board_size: int, thresholds: ThresholdConfig
) -> Tuple[str, Dict[str, float]]:
    """Find the region with the strongest average ownership and who owns it.

    Returns ``("CTRL_<REGION>_<B|W>", region_means)``, or ``"CTRL_NONE"`` when
    the largest |mean| is below ``control_none_abs_threshold``. A positive mean
    is labelled B, a non-positive one W.
    NOTE(review): assumes KataGo reports ownership from Black's perspective —
    confirm against the analysis config (reportAnalysisWinratesAs).
    """
    region_means = region_scores_from_board(ownership, board_size)
    strongest, strongest_mean = max(region_means.items(), key=lambda pair: abs(pair[1]))
    if abs(strongest_mean) < thresholds.control_none_abs_threshold:
        return "CTRL_NONE", region_means
    owner = "B" if strongest_mean > 0 else "W"
    return "CTRL_" + strongest + "_" + owner, region_means


def classify_main_contested_region(
    ownership: Sequence[float], ownership_stdev: Sequence[float], board_size: int, thresholds: ThresholdConfig
) -> Tuple[str, Dict[str, float], float]:
    """Find the region with the highest mean contest score.

    Returns ``(label, region_means, global_mean)``. The label is
    ``CONTEST_<REGION>``, or ``CONTEST_NONE`` when the best regional mean is
    below ``contest_none_threshold``. ``global_mean`` averages the per-point
    scores from ``compute_contest_scores`` over the whole board.
    """
    point_scores = compute_contest_scores(ownership, ownership_stdev)
    region_means = region_scores_from_board(point_scores, board_size)
    hottest, hottest_score = max(region_means.items(), key=lambda pair: pair[1])
    global_mean = sum(point_scores) / len(point_scores)
    if hottest_score < thresholds.contest_none_threshold:
        label = "CONTEST_NONE"
    else:
        label = "CONTEST_" + hottest
    return label, region_means, global_mean


def classify_global_contestedness(global_score: float, thresholds: ThresholdConfig) -> str:
    """Bucket the board-wide mean contest score into CONTEST_LOW / MED / HIGH.

    Upper bounds are exclusive; anything not below ``contest_high_threshold``
    (including NaN) is HIGH.
    """
    bands = (
        (thresholds.contest_low_threshold, "CONTEST_LOW"),
        (thresholds.contest_high_threshold, "CONTEST_MED"),
    )
    for upper, label in bands:
        if global_score < upper:
            return label
    return "CONTEST_HIGH"


def ordered_move_infos(move_infos: Sequence[Dict[str, Any]]) -> List[Dict[str, Any]]:
    """Return move infos sorted by their ``order`` field (stable sort).

    Entries whose order is missing or not int-convertible sort as 1_000_000,
    i.e. after any entry with a smaller order.
    """
    def order_key(info: Dict[str, Any]) -> int:
        return maybe_int(info.get("order"), 1_000_000)

    return sorted(move_infos, key=order_key)


def classify_best_move_region(best_move: str, board_size: int) -> str:
    """Label the region of the best move (``BESTREG_<REGION>``), or BESTREG_PASS.

    Raises:
        ValueError: propagated from ``gtp_to_row_col`` for unparseable moves.
    """
    if best_move.lower() != "pass":
        row, col = gtp_to_row_col(best_move, board_size)
        return "BESTREG_" + region_name_for_point(row, col, board_size)
    return "BESTREG_PASS"


def classify_move_urgency(
    move_infos: Sequence[Dict[str, Any]], thresholds: ThresholdConfig
) -> Tuple[str, Dict[str, float]]:
    """Label how forced the top move looks: URG_MUST, URG_HIGH, URG_MED or URG_LOW.

    Three signals come from the ordered move list:
    - the absolute winrate gap between the first and second moves;
    - the absolute scoreLead gap between the same two moves;
    - the first move's share of the summed non-negative edgeVisits.
    Both gaps are 0 when there is only one move. Tiers are tried MUST, HIGH,
    MED, and the first tier where any signal reaches its threshold wins.
    Returns the label and the three signals.
    """
    ranked = ordered_move_infos(move_infos)
    if not ranked:
        return "URG_LOW", {"win_gap": 0.0, "score_gap": 0.0, "visit_share": 0.0}
    best = ranked[0]
    runner_up = ranked[1] if len(ranked) > 1 else ranked[0]

    def stat(info: Dict[str, Any], key: str) -> float:
        return maybe_float(info.get(key), 0.0)

    visits_total = sum(max(0.0, stat(info, "edgeVisits")) for info in ranked)
    win_gap = abs(stat(best, "winrate") - stat(runner_up, "winrate"))
    score_gap = abs(stat(best, "scoreLead") - stat(runner_up, "scoreLead"))
    visit_share = stat(best, "edgeVisits") / max(1.0, visits_total)

    tiers = (
        ("URG_MUST", thresholds.urgency_must_win_gap, thresholds.urgency_must_score_gap, thresholds.urgency_must_visit_share),
        ("URG_HIGH", thresholds.urgency_high_win_gap, thresholds.urgency_high_score_gap, thresholds.urgency_high_visit_share),
        ("URG_MED", thresholds.urgency_med_win_gap, thresholds.urgency_med_score_gap, thresholds.urgency_med_visit_share),
    )
    label = "URG_LOW"
    for name, win_min, score_min, share_min in tiers:
        if win_gap >= win_min or score_gap >= score_min or visit_share >= share_min:
            label = name
            break
    return label, {"win_gap": win_gap, "score_gap": score_gap, "visit_share": visit_share}


def classify_search_surprise(
    move_infos: Sequence[Dict[str, Any]], thresholds: ThresholdConfig
) -> Tuple[str, Dict[str, float]]:
    """Score how much search favoured a move the policy prior did not.

    ``surprise = (1 - clamped prior of top move) * top move's edgeVisits share``.
    >= ``surprise_high_threshold`` is HIGH, >= ``surprise_med_threshold`` is MED,
    otherwise LOW. Returns the label and the three numbers used.
    """
    ranked = ordered_move_infos(move_infos)
    if not ranked:
        return "SURPRISE_LOW", {"prior_top": 1.0, "edge_share": 0.0, "surprise_score": 0.0}
    leader = ranked[0]
    visits_total = sum(max(0.0, maybe_float(info.get("edgeVisits"), 0.0)) for info in ranked)
    prior_top = clamp_probability(maybe_float(leader.get("prior"), 0.0))
    edge_share = maybe_float(leader.get("edgeVisits"), 0.0) / max(1.0, visits_total)
    surprise_score = (1.0 - prior_top) * edge_share

    label = "SURPRISE_LOW"
    for minimum, name in (
        (thresholds.surprise_high_threshold, "SURPRISE_HIGH"),
        (thresholds.surprise_med_threshold, "SURPRISE_MED"),
    ):
        if surprise_score >= minimum:
            label = name
            break
    return label, {"prior_top": prior_top, "edge_share": edge_share, "surprise_score": surprise_score}


def derive_claim_fields(
    candidate: CandidatePosition,
    payload: Dict[str, Any],
    thresholds: ThresholdConfig,
) -> Dict[str, Any]:
    """Turn one KataGo analysis response into the categorical/numeric claim fields of a row.

    Reads ``rootInfo`` (winrate, scoreLead, rawVarTimeLeft), ``ownership``,
    ``ownershipStdev`` and ``moveInfos``. The returned dict holds the labels,
    followed by the underlying numbers and the urgency/surprise debug values.
    ``raw_win_prob`` is the same clamped value as ``win_prob``.

    Raises:
        ValueError: if ownership or ownershipStdev is absent, or moveInfos is empty.
    """
    root_info = payload.get("rootInfo") or {}
    move_infos = payload.get("moveInfos") or []
    ownership = payload.get("ownership")
    ownership_stdev = payload.get("ownershipStdev")
    if not move_infos or ownership is None or ownership_stdev is None:
        raise ValueError("missing ownership/ownershipStdev/moveInfos in KataGo response")

    size = candidate.board_size
    win_prob = clamp_probability(maybe_float(root_info.get("winrate"), 0.5))
    score_lead = maybe_float(root_info.get("scoreLead"), 0.0)
    raw_var_time_left = maybe_float(root_info.get("rawVarTimeLeft"), 0.0)
    best_move = str(ordered_move_infos(move_infos)[0].get("move", "pass"))

    control_label, control_means = classify_main_control_region(ownership, size, thresholds)
    contest_label, contest_means, contest_global = classify_main_contested_region(
        ownership, ownership_stdev, size, thresholds
    )
    urgency_label, urgency_debug = classify_move_urgency(move_infos, thresholds)
    surprise_label, surprise_debug = classify_search_surprise(move_infos, thresholds)

    claims: Dict[str, Any] = {
        "win_prob": win_prob,
        "win_prob_bin": classify_win_prob_bin(win_prob),
        "score_lead": score_lead,
        "score_lead_bin": classify_score_lead_bin(score_lead),
        "phase_estimate": classify_phase_estimate(raw_var_time_left, thresholds),
        "main_control_region": control_label,
        "main_contested_region": contest_label,
        "global_contestedness": classify_global_contestedness(contest_global, thresholds),
        "best_move": best_move,
        "best_move_region": classify_best_move_region(best_move, size),
        "move_urgency": urgency_label,
        "search_surprise": surprise_label,
        "raw_win_prob": win_prob,
        "raw_score_lead": score_lead,
        "raw_var_time_left": raw_var_time_left,
        "control_region_means": control_means,
        "contested_region_means": contest_means,
        "global_contest_score": contest_global,
    }
    claims.update(urgency_debug)
    claims.update(surprise_debug)
    return claims


def row_from_candidate_and_claims(candidate: CandidatePosition, claims: Dict[str, Any]) -> Dict[str, Any]:
    """Merge a candidate's descriptive attributes with its claim fields into one output row.

    Candidate attributes come first; claim keys are appended and win on collision.
    """
    copied_attributes = (
        "id",
        "game_id",
        "turn_number",
        "board_size",
        "to_move",
        "rules",
        "komi",
        "stones",
        "position_tokens",
        "rationale_text",
    )
    row: Dict[str, Any] = {attr: getattr(candidate, attr) for attr in copied_attributes}
    row.update(claims)
    return row


class KataGoAnalysisClient:
    def __init__(self, cfg: BuildConfig):
        self.cfg = cfg
        self.proc: subprocess.Popen[str] | None = None

    def _subprocess_env(self) -> Dict[str, str]:
        env = os.environ.copy()
        env.setdefault("APPIMAGE_EXTRACT_AND_RUN", "1")
        lib_dirs: List[str] = [p for p in env.get("LD_LIBRARY_PATH", "").split(":") if p]
        search_roots = {Path(sys.prefix)}
        for path_str in sys.path:
            if path_str:
                search_roots.add(Path(path_str))
        for path_str in site.getsitepackages():
            search_roots.add(Path(path_str))
        candidate_dirs: set[Path] = set()
        for root in search_roots:
            if not root.exists():
                continue
            direct_lib = root / "lib"
            if direct_lib.exists():
                candidate_dirs.add(direct_lib)
            for pattern in ["torch/lib", "nvidia/*/lib"]:
                for match in root.glob(pattern):
                    if match.is_dir():
                        candidate_dirs.add(match)
        needed_patterns = ("libcublas.so*", "libcudnn.so*", "libcudart.so*", "libcusolver.so*")
        for directory in sorted(candidate_dirs):
            try:
                if any(any(directory.glob(pattern)) for pattern in needed_patterns):
                    lib_dirs.append(str(directory))
            except OSError:
                continue
        if lib_dirs:
            env["LD_LIBRARY_PATH"] = ":".join(dict.fromkeys(lib_dirs))
        return env

    def __enter__(self) -> "KataGoAnalysisClient":
        command = [
            self.cfg.katago_binary,
            "analysis",
            "-model",
            self.cfg.katago_model,
            "-config",
            self.cfg.katago_config,
        ]
        self.proc = subprocess.Popen(
            command,
            stdin=subprocess.PIPE,
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
            text=True,
            bufsize=1,
            env=self._subprocess_env(),
        )
        return self

    def __exit__(self, exc_type, exc, tb) -> None:
        if self.proc is None:
            return
        try:
            if self.proc.stdin is not None:
                self.proc.stdin.close()
        except Exception:
            pass
        try:
            self.proc.wait(timeout=5)
        except subprocess.TimeoutExpired:
            self.proc.terminate()
            self.proc.wait(timeout=5)

    def submit(self, candidate: CandidatePosition) -> None:
        if self.proc is None or self.proc.stdin is None:
            raise RuntimeError("KataGo process is not running")
        query: Dict[str, Any] = {
            "id": candidate.id,
            "initialStones": candidate.initial_stones,
            "moves": candidate.moves,
            "rules": candidate.rules,
            "komi": candidate.komi,
            "boardXSize": candidate.board_size,
            "boardYSize": candidate.board_size,
            "maxVisits": self.cfg.katago_visits,
            "includeOwnership": self.cfg.include_ownership,
            "includeOwnershipStdev": self.cfg.include_ownership_stdev,
        }
        if not candidate.moves:
            query["initialPlayer"] = candidate.initial_player
        if self.cfg.report_during_search_every > 0:
            query["reportDuringSearchEvery"] = self.cfg.report_during_search_every
        self.proc.stdin.write(json.dumps(query) + "\n")
        self.proc.stdin.flush()

    def _read_ready_lines(self, timeout: float) -> List[str]:
        if self.proc is None or self.proc.stdout is None:
            raise RuntimeError("KataGo process is not running")
        ready, _, _ = select.select([self.proc.stdout], [], [], timeout)
        if not ready:
            return []
        lines: List[str] = []
        while True:
            line = self.proc.stdout.readline()
            if line == "":
                break
            lines.append(line.rstrip("\n"))
            ready, _, _ = select.select([self.proc.stdout], [], [], 0)
            if not ready:
                break
        return lines

    def poll_stderr(self) -> str:
        if self.proc is None or self.proc.stderr is None:
            return ""
        ready, _, _ = select.select([self.proc.stderr], [], [], 0)
        if not ready:
            return ""
        chunks: List[str] = []
        while True:
            line = self.proc.stderr.readline()
            if line == "":
                break
            chunks.append(line.rstrip("\n"))
            ready, _, _ = select.select([self.proc.stderr], [], [], 0)
            if not ready:
                break
        return "\n".join(chunks)

    def recv_one(self, timeout: float) -> Dict[str, Any] | None:
        lines = self._read_ready_lines(timeout=timeout)
        for line in lines:
            stripped = line.strip()
            if not stripped:
                continue
            try:
                payload = json.loads(stripped)
            except json.JSONDecodeError:
                continue
            if "warning" in payload:
                print(f"[KATAGO WARNING] {payload}")
                continue
            return payload
        return None


def flush_incremental_rows(path: Path, rows: Sequence[Dict[str, Any]]) -> None:
    """Append *rows* to *path* as JSON lines, creating parent directories.

    Does nothing (not even create directories) when *rows* is empty.
    """
    if rows:
        path.parent.mkdir(parents=True, exist_ok=True)
        with path.open("a") as handle:
            handle.writelines(json.dumps(row) + "\n" for row in rows)


def run_katago_queries(
    cfg: BuildConfig,
    candidates: Sequence[CandidatePosition],
    incremental_full_path: Path | None = None,
) -> Tuple[List[Dict[str, Any]], Counter]:
    """Analyse every candidate with KataGo and turn the responses into dataset rows.

    Keeps up to ``cfg.katago_inflight_limit`` queries in flight (at least 1).
    Responses are matched to candidates by query id. Details:
    - Partial results (``isDuringSearch``), sent when
      ``report_during_search_every > 0``, are ignored; only a final response
      completes a query.
    - A candidate whose id was already submitted is skipped and counted, since
      its response could not be told apart.
    - If no response arrives within ``cfg.query_timeout_sec``, every in-flight
      query and every not-yet-submitted candidate is abandoned and counted.
    - Completed rows are optionally appended to *incremental_full_path* in
      batches of ``cfg.incremental_flush_every``.

    Returns:
        (rows sorted by candidate.index, stats Counter).

    Raises:
        ValueError: if the KataGo binary, model or config path is missing.
    """
    if not (cfg.katago_binary and cfg.katago_model and cfg.katago_config):
        raise ValueError("KataGo binary/model/config are required")

    rows_by_index: Dict[int, Dict[str, Any]] = {}
    incremental_buffer: List[Dict[str, Any]] = []
    stats: Counter = Counter()
    pending: Iterator[CandidatePosition] = iter(candidates)
    inflight: Dict[str, CandidatePosition] = {}
    seen_ids: set[str] = set()
    # A non-positive limit would submit nothing and silently drop every candidate.
    inflight_limit = max(1, cfg.katago_inflight_limit)
    completed = 0
    last_report = time.time()

    with KataGoAnalysisClient(cfg) as client:
        while True:
            # Top up the pipeline so KataGo always has queued work.
            while len(inflight) < inflight_limit:
                try:
                    candidate = next(pending)
                except StopIteration:
                    break
                if candidate.id in seen_ids:
                    stats["duplicate_query_id"] += 1
                    print(f"[WARN] Skipping duplicate query id {candidate.id}")
                    continue
                seen_ids.add(candidate.id)
                client.submit(candidate)
                inflight[candidate.id] = candidate
                stats["submitted"] += 1

            if not inflight:
                break

            payload = client.recv_one(timeout=cfg.query_timeout_sec)
            if payload is None:
                stderr = client.poll_stderr()
                stats["timed_out_queries"] += len(inflight)
                stats["timed_out_query_ids"] = [candidate.id for candidate in inflight.values()]
                print(
                    "[WARN] Timed out waiting for remaining KataGo results. "
                    f"Skipping {len(inflight)} stuck queries."
                )
                # Candidates that never reached KataGo are lost too; count them.
                unsubmitted = sum(1 for _ in pending)
                if unsubmitted:
                    stats["unsubmitted_queries"] += unsubmitted
                    print(f"[WARN] {unsubmitted} queued positions were never submitted to KataGo.")
                if stderr:
                    print(f"[WARN] KataGo stderr before timeout:\n{stderr}")
                break

            query_id = str(payload.get("id", ""))
            if not query_id or query_id not in inflight:
                continue
            if payload.get("isDuringSearch"):
                # Interim report for a still-running query; the final result follows.
                continue
            candidate = inflight.pop(query_id)
            if "error" in payload:
                stats["katago_error"] += 1
                print(f"[WARN] KataGo error for {query_id}: {payload['error']}")
                continue
            try:
                claims = derive_claim_fields(candidate, payload, cfg.thresholds)
            except Exception as exc:  # one malformed response must not abort the run
                stats["field_derivation_error"] += 1
                print(f"[WARN] Skipping {query_id}: {exc}")
                continue
            row = row_from_candidate_and_claims(candidate, claims)
            rows_by_index[candidate.index] = row
            incremental_buffer.append(row)
            completed += 1
            stats["completed"] = completed
            if incremental_full_path is not None and len(incremental_buffer) >= cfg.incremental_flush_every:
                flush_incremental_rows(incremental_full_path, incremental_buffer)
                incremental_buffer = []

            now = time.time()
            if now - last_report >= 10:
                print(
                    f"[INFO] KataGo progress: completed={completed} submitted={stats['submitted']} inflight={len(inflight)}"
                )
                last_report = now

    if incremental_full_path is not None and incremental_buffer:
        flush_incremental_rows(incremental_full_path, incremental_buffer)
    rows = [rows_by_index[idx] for idx in sorted(rows_by_index)]
    stats["final_rows"] = len(rows)
    return rows, stats


def split_rows_by_game(rows: Sequence[Dict[str, Any]], eval_fraction: float, seed: int) -> Tuple[List[Dict[str, Any]], List[Dict[str, Any]]]:
    """Split rows into (train, eval) so that no game appears in both.

    Rows are grouped by ``game_id``, falling back to the part of ``id`` before
    ``"_move"``. Games are shuffled with ``random.Random(seed)``, and whole games
    go to eval until it holds at least ``round(len(rows) * eval_fraction)`` rows
    (minimum 1). The remaining games go to train. ``eval_fraction <= 0`` puts
    every row in train; previously one game was still forced into eval.
    Rows keep their relative order within each game.
    """
    grouped: Dict[str, List[Dict[str, Any]]] = defaultdict(list)
    for row in rows:
        game_id = str(row.get("game_id") or str(row["id"]).split("_move", 1)[0])
        grouped[game_id].append(row)
    game_ids = list(grouped)
    rng = random.Random(seed)
    rng.shuffle(game_ids)
    eval_target = max(1, round(len(rows) * eval_fraction)) if eval_fraction > 0 else 0
    train_rows: List[Dict[str, Any]] = []
    eval_rows: List[Dict[str, Any]] = []
    for game_id in game_ids:
        # Greedy fill: a game is never split, so eval may overshoot the target.
        if len(eval_rows) < eval_target:
            eval_rows.extend(grouped[game_id])
        else:
            train_rows.extend(grouped[game_id])
    return train_rows, eval_rows


def write_jsonl(path: Path, rows: Sequence[Dict[str, Any]]) -> None:
    """Overwrite *path* with one JSON document per line, creating parent directories."""
    path.parent.mkdir(parents=True, exist_ok=True)
    with path.open("w") as handle:
        handle.writelines(json.dumps(row) + "\n" for row in rows)


def build_manifest(
    cfg: BuildConfig,
    comment_stats: Counter,
    katago_stats: Counter,
    train_rows: Sequence[Dict[str, Any]],
    eval_rows: Sequence[Dict[str, Any]],
) -> Dict[str, Any]:
    """Describe one dataset build as a plain dict, which the caller dumps to JSON.

    Keys:
      * ``config``: ``dataclasses.asdict(cfg)``, with ``thresholds`` set to
        ``asdict(cfg.thresholds)``.
      * ``comment_stats`` / ``katago_stats``: the counters copied into plain dicts.
      * ``train_rows`` / ``eval_rows``: split sizes (row counts, not the rows).
      * ``created_at_epoch``: ``time.time()`` taken when the manifest is built.
    """
    config_section = asdict(cfg)
    # Serialize thresholds explicitly. Assigning keeps the key's position if
    # asdict already emitted it, and appends it otherwise.
    config_section["thresholds"] = asdict(cfg.thresholds)

    manifest: Dict[str, Any] = {}
    manifest["config"] = config_section
    manifest["comment_stats"] = dict(comment_stats)
    manifest["katago_stats"] = dict(katago_stats)
    manifest["train_rows"] = len(train_rows)
    manifest["eval_rows"] = len(eval_rows)
    manifest["created_at_epoch"] = time.time()
    return manifest


def run_build(cfg: BuildConfig) -> Dict[str, Any]:
    """Run the whole dataset build and return the manifest that was written.

    The build proceeds in these steps:
      1. Seed via ``set_seed(cfg.seed)`` and create ``cfg.output_dir``.
      2. Mine comment-bearing positions.
      3. Analyse them with KataGo, streaming rows to ``full.incremental.jsonl``.
      4. Split rows into train/eval by game.
      5. Write ``full.jsonl``, ``train.jsonl``, ``eval.jsonl`` and
         ``dataset_manifest.json`` into the output directory.
    """
    set_seed(cfg.seed)
    out_dir = Path(cfg.output_dir)
    out_dir.mkdir(parents=True, exist_ok=True)

    candidates, comment_stats = mine_comment_positions(cfg)
    print(f"[INFO] Selected {len(candidates)} comment-bearing positions for KataGo analysis")

    # Remove any incremental file left by a previous run before querying.
    # NOTE(review): flush_incremental_rows presumably appends, so stale rows would mix in — confirm.
    incremental_path = out_dir / "full.incremental.jsonl"
    if incremental_path.exists():
        incremental_path.unlink()
    rows, katago_stats = run_katago_queries(cfg, candidates, incremental_full_path=incremental_path)
    train_rows, eval_rows = split_rows_by_game(rows, cfg.eval_fraction, cfg.split_seed)

    full_path = out_dir / "full.jsonl"
    outputs = (
        (full_path, rows),
        (out_dir / "train.jsonl", train_rows),
        (out_dir / "eval.jsonl", eval_rows),
    )
    for target, payload in outputs:
        write_jsonl(target, payload)

    manifest = build_manifest(cfg, comment_stats, katago_stats, train_rows, eval_rows)
    (out_dir / "dataset_manifest.json").write_text(json.dumps(manifest, indent=2))

    print(f"[DONE] Wrote {len(rows)} rows to {full_path}")
    print(f"[DONE] Train rows: {len(train_rows)} Eval rows: {len(eval_rows)}")
    return manifest


def parse_args(argv: Sequence[str] | None = None) -> argparse.Namespace:
    """Parse command-line flags for the dataset builder.

    Args:
        argv: Argument list to parse, without the program name. ``None`` (the
            default) parses ``sys.argv[1:]``, as before. An explicit list lets
            tests and other scripts drive the parser without touching
            ``sys.argv``.

    Returns:
        A namespace whose attributes (``sgf_dir``, ``katago_visits``, ``seed``, ...)
        are consumed by ``config_from_args``.

    Raises:
        SystemExit: raised by argparse on invalid or missing required flags, or
            on ``--help``.
    """
    parser = argparse.ArgumentParser(description="Build a large human-commentary KataGo dataset from SGFs")
    parser.add_argument("--sgf-dir", required=True, type=str)
    parser.add_argument("--output-dir", required=True, type=str)
    # The defaults below duplicate BuildConfig's field defaults; keep the two in sync.
    parser.add_argument("--max-positions", type=int, default=10_000)
    parser.add_argument("--max-games", type=int, default=0)
    parser.add_argument("--max-positions-per-game", type=int, default=8)
    parser.add_argument("--min-comment-chars", type=int, default=20)
    parser.add_argument("--eval-fraction", type=float, default=0.10)
    parser.add_argument("--split-seed", type=int, default=42)
    parser.add_argument("--max-position-tokens", type=int, default=196)
    parser.add_argument("--katago-binary", required=True, type=str)
    parser.add_argument("--katago-model", required=True, type=str)
    parser.add_argument("--katago-config", required=True, type=str)
    parser.add_argument("--katago-visits", type=int, default=2000)
    parser.add_argument("--katago-inflight-limit", type=int, default=64)
    parser.add_argument("--query-timeout-sec", type=float, default=180.0)
    parser.add_argument("--report-during-search-every", type=float, default=0.0)
    parser.add_argument("--incremental-flush-every", type=int, default=50)
    parser.add_argument("--seed", type=int, default=42)
    return parser.parse_args(argv)


def config_from_args(args: argparse.Namespace) -> BuildConfig:
    """Translate a parsed CLI namespace into a ``BuildConfig``.

    Each listed attribute maps one-to-one onto a ``BuildConfig`` keyword of the
    same name. Fields not listed here keep their dataclass defaults; examples are
    ``include_ownership`` and ``include_ownership_stdev``.
    """
    forwarded_fields = (
        "sgf_dir",
        "output_dir",
        "max_positions",
        "max_games",
        "max_positions_per_game",
        "min_comment_chars",
        "eval_fraction",
        "split_seed",
        "max_position_tokens",
        "katago_binary",
        "katago_model",
        "katago_config",
        "katago_visits",
        "katago_inflight_limit",
        "query_timeout_sec",
        "report_during_search_every",
        "incremental_flush_every",
        "seed",
    )
    return BuildConfig(**{name: getattr(args, name) for name in forwarded_fields})


def main() -> None:
    """Command-line entry point: parse flags, build the config, and run the dataset build."""
    run_build(config_from_args(parse_args()))


# Run the CLI only when this file is executed as a script, not when it is imported.
if __name__ == "__main__":
    main()
