from __future__ import annotations

import argparse
import json
import math
import os
import random
import re
import select
import site
import subprocess
import sys
import tempfile
from dataclasses import asdict, dataclass, field
from pathlib import Path
from typing import Any, Dict, Iterable, Iterator, List, Sequence, Tuple

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset

try:
    from sgfmill import boards as sgf_boards
    from sgfmill import sgf
    from sgfmill import sgf_moves
except ImportError:  # pragma: no cover - optional dependency
    sgf = None
    sgf_boards = None
    sgf_moves = None


# Reserved vocabulary entries: padding/unknown markers plus the structural
# markers that delimit the segments of an encoded sequence
# ([BOS] position [RAT] rationale [CLAIM] claim [EOS], see encode_example).
SPECIAL_TOKENS = [
    "[PAD]",
    "[UNK]",
    "[BOS]",
    "[RAT]",
    "[CLAIM]",
    "[EOS]",
]
# Claim tokens V0..V9; the count is hard-coded to ten here.
BIN_TOKENS = [f"V{i}" for i in range(10)]
# Objective variants recognised by train_variant.
VARIANT_NAMES = (
    "lm_only",
    "no_consistency_loss",
    "rationale_only",
    "full_consistency",
    "random_consistency",
)
# The claim tokens are appended to the special-token list in place, so
# build_vocab always includes them.
SPECIAL_TOKENS.extend(BIN_TOKENS)

# Matches either a run of ASCII letters/digits/underscores or a single
# character that is neither a word character nor whitespace.
WORD_RE = re.compile(r"[A-Za-z0-9_]+|[^\w\s]")
# Column letters used for board coordinates; the letter "I" is not present.
BOARD_COLUMNS = "ABCDEFGHJKLMNOPQRSTUVWXYZ"
# Rule-set name used when a row or call does not supply one.
DEFAULT_RULES = "japanese"
# Number of leading position tokens truncate_sequence_parts keeps when it
# trims the position segment (build_position_tokens emits four metadata
# tokens before the stones).
MIN_POSITION_METADATA_TOKENS = 4


@dataclass
class Config:
    """Hyperparameters and I/O settings for the training/evaluation run."""

    # JSONL dataset locations; resolve_dataset_paths requires both unless
    # smoke_test is enabled.
    train_path: str = ""
    eval_path: str = ""
    # Size of the positional-embedding table; truncate_sequence_parts reserves
    # 5 of these slots for the structural/claim tokens.
    max_seq_len: int = 256
    # Upper bound on position tokens kept per example.
    max_position_tokens: int = 196
    batch_size: int = 32
    # Number of passes over the training loader in train_variant.
    epochs: int = 10
    # AdamW learning rate and weight decay.
    lr: float = 3e-4
    weight_decay: float = 0.01
    # Transformer dimensions passed to TinyWinProbLM / DecoderBlock.
    d_model: int = 256
    n_layers: int = 4
    n_heads: int = 8
    d_ff: int = 1024
    dropout: float = 0.1
    # Loss weights used by train_variant: cross-entropy on the bin head,
    # language-model cross-entropy, and scalar MSE respectively.
    consistency_weight: float = 0.5
    rationale_loss_weight: float = 1.0
    claim_loss_weight: float = 1.0
    # Maximum number of eval examples sampled by evaluate_counterfactual.
    counterfactual_samples: int = 256
    # Seeds set_seed, the counterfactual sampler and the mock datasets.
    seed: int = 42
    # NOTE(review): consumers of the two output paths are outside the visible
    # part of the file - presumably where results are written; confirm.
    output_csv: str = "katago_winprob_results.csv"
    output_markdown: str = ""
    # Variants trained in turn by run_experiment.
    variants: Tuple[str, ...] = VARIANT_NAMES
    # When enabled and no dataset paths are given, mock datasets of the sizes
    # below are generated into a temporary directory.
    smoke_test: bool = False
    smoke_train_size: int = 128
    smoke_eval_size: int = 64
    # Number of win-probability bins; also the output width of the bin head.
    num_bins: int = 10


@dataclass
class PreprocessConfig:
    """Settings container for a preprocessing step.

    NOTE(review): the code that reads these fields is outside the visible
    part of the file; the field comments below are inferred from names and
    defaults only and should be confirmed against the consumer.
    """

    # Required: input directory and output file path (no defaults).
    sgf_dir: str
    output_path: str
    # Presumably position-sampling controls per game - TODO confirm.
    sample_every_n_moves: int = 20
    max_positions_per_game: int = 8
    # Presumably 0 means "no limit" - TODO confirm.
    max_games: int = 0
    # Same default values as Config.max_position_tokens / Config.num_bins.
    max_position_tokens: int = 196
    num_bins: int = 10
    default_rules: str = DEFAULT_RULES
    seed: int = 42
    # Presumably KataGo executable/model/config locations and search budget
    # - TODO confirm.
    katago_binary: str = ""
    katago_model: str = ""
    katago_config: str = ""
    katago_visits: int = 32
    # Defaults to an empty tuple via the factory.
    katago_extra_args: Tuple[str, ...] = field(default_factory=tuple)


def set_seed(seed: int) -> None:
    """Seed the Python, NumPy and PyTorch random number generators.

    CUDA generators are seeded as well when CUDA is available.
    """
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)
    if not torch.cuda.is_available():
        return
    torch.cuda.manual_seed_all(seed)


def simple_tokenize_text(text: str) -> List[str]:
    """Split *text* into lower-cased word and punctuation tokens.

    Returns the single placeholder token ``"no_commentary"`` when the
    stripped text yields no tokens.
    """
    lowered = [piece.lower() for piece in WORD_RE.findall(text.strip())]
    return lowered if lowered else ["no_commentary"]


def probability_to_bin(prob: float, num_bins: int = 10) -> int:
    """Map a probability to a bin index by rounding onto ``0..num_bins - 1``.

    The probability is first clipped to ``[0, 1]``, then scaled by
    ``num_bins - 1`` and rounded with the built-in ``round``; the result is
    finally clipped to the valid index range.
    """
    top = num_bins - 1
    value = float(prob)
    # Clip below first, then above, mirroring min(1.0, max(0.0, value)).
    clipped = value if value > 0.0 else 0.0
    clipped = clipped if clipped < 1.0 else 1.0
    index = round(clipped * top)
    if index < 0:
        index = 0
    if index > top:
        index = top
    return index


def normalize_komi_token(komi: float) -> str:
    """Render a komi value as a ``KOMI_*`` token.

    The string form of *komi* has ``-`` replaced by ``NEG_`` and ``.``
    replaced by ``_``.
    """
    substitutions = str.maketrans({"-": "NEG_", ".": "_"})
    return "KOMI_" + str(komi).translate(substitutions)


def rule_token(rules: str) -> str:
    """Render a rule-set name as a ``RULE_*`` token.

    The name is stripped and upper-cased, and every run of characters
    outside ``A-Z``, ``a-z``, ``0-9`` collapses to a single underscore.
    """
    normalized = rules.strip().upper()
    slug = re.sub(r"[^A-Za-z0-9]+", "_", normalized)
    return f"RULE_{slug}"


def maybe_float(value: Any, default: float) -> float:
    """Convert *value* to ``float``, returning *default* when that fails.

    Args:
        value: Anything ``float()`` might accept.
        default: Value returned when the conversion is not possible.

    Returns:
        ``float(value)``, or *default* if the conversion raises
        ``TypeError``, ``ValueError`` or ``OverflowError``.
    """
    try:
        return float(value)
    # OverflowError: integers too large for a float (e.g. 10**400).
    except (TypeError, ValueError, OverflowError):
        return default


def maybe_int(value: Any, default: int) -> int:
    """Convert *value* to ``int``, returning *default* when that fails.

    Args:
        value: Anything ``int()`` might accept.
        default: Value returned when the conversion is not possible.

    Returns:
        ``int(value)``, or *default* if the conversion raises
        ``TypeError``, ``ValueError`` or ``OverflowError``.
    """
    try:
        return int(value)
    # OverflowError: int(float("inf")); JSON input may contain Infinity.
    except (TypeError, ValueError, OverflowError):
        return default


def clamp_probability(prob: float) -> float:
    """Return *prob* as a float clipped to the closed interval ``[0, 1]``."""
    value = float(prob)
    # Lower bound first, then upper bound (same order as min(1, max(0, v))).
    floored = value if value > 0.0 else 0.0
    return floored if floored < 1.0 else 1.0


def sgf_property_or_default(node: Any, prop: str, default: Any) -> Any:
    """Return ``node.get(prop)``, or *default* if that raises ``KeyError``."""
    try:
        value = node.get(prop)
    except KeyError:
        value = default
    return value


def row_col_to_gtp(row: int, col: int, board_size: int) -> str:
    """Convert a (row, col) index pair to a letter+number coordinate.

    The letter comes from ``BOARD_COLUMNS[col]`` and the number is
    ``board_size - row``.

    Raises:
        ValueError: If *board_size* exceeds the number of column letters.
    """
    if len(BOARD_COLUMNS) < board_size:
        raise ValueError(f"Board size {board_size} exceeds supported symbolic columns")
    letter = BOARD_COLUMNS[col]
    number = board_size - row
    return letter + str(number)


def build_position_tokens(
    board_size: int,
    to_move: str,
    komi: float,
    rules: str,
    stones: Dict[str, Sequence[str]],
    max_position_tokens: int | None = None,
) -> List[str]:
    """Serialise a board position into a flat token list.

    The list starts with four metadata tokens (size, side to move, komi,
    rules), followed by the black stones and then the white stones, each
    group sorted by coordinate string. An empty *rules* falls back to
    ``DEFAULT_RULES``. When *max_position_tokens* is given the list is cut
    to that length.
    """
    # Sorting happens on the coordinates as given, before upper-casing.
    black_sorted = sorted(map(str, stones.get("black", [])))
    white_sorted = sorted(map(str, stones.get("white", [])))

    header = [
        f"SZ{board_size}",
        f"TM_{to_move.upper()}",
        normalize_komi_token(komi),
        rule_token(rules or DEFAULT_RULES),
    ]
    black_tokens = [f"B_{coord.upper()}" for coord in black_sorted]
    white_tokens = [f"W_{coord.upper()}" for coord in white_sorted]
    tokens = header + black_tokens + white_tokens

    if max_position_tokens is None:
        return tokens
    return tokens[:max_position_tokens]


def stone_heuristics(stones: Dict[str, Sequence[str]], board_size: int) -> Dict[str, float]:
    """Compute simple stone-count and centrality statistics for a position.

    Returns a dict with ``black_count``, ``white_count``, ``stone_balance``
    (count difference over total stones, with the total floored at 1) and
    ``center_balance`` (black minus white mean centrality).
    """
    midpoint = (board_size - 1) / 2.0

    def mean_centrality(points: Sequence[str]) -> float:
        # Average of 1 - manhattan_distance_to_centre / board_size.
        if not points:
            return 0.0
        accumulated = 0.0
        for point in points:
            col_idx = BOARD_COLUMNS.index(point[0])
            row_idx = board_size - int(point[1:])
            manhattan = abs(row_idx - midpoint) + abs(col_idx - midpoint)
            accumulated += 1.0 - manhattan / max(1.0, board_size)
        return accumulated / len(points)

    black_stones = [point.upper() for point in stones.get("black", [])]
    white_stones = [point.upper() for point in stones.get("white", [])]
    n_black, n_white = len(black_stones), len(white_stones)
    denominator = max(1, n_black + n_white)

    summary: Dict[str, float] = {}
    summary["black_count"] = float(n_black)
    summary["white_count"] = float(n_white)
    summary["stone_balance"] = (n_black - n_white) / denominator
    summary["center_balance"] = mean_centrality(black_stones) - mean_centrality(white_stones)
    return summary


def templated_rationale(
    win_prob: float,
    to_move: str,
    stones: Dict[str, Sequence[str]],
    board_size: int,
) -> str:
    """Produce a canned one-sentence commentary for a position.

    *win_prob* is treated as Black's winning probability: values of 0.5 and
    above name Black as the leading side, lower values name White. The
    sentence template is chosen by probability band; the two "close game"
    templates are selected by the stone balance from ``stone_heuristics``.

    Args:
        win_prob: Winning probability used to pick the leader and template.
        to_move: Side to move. Accepted for interface compatibility; it does
            not influence the generated text.
        stones: Mapping with ``"black"`` / ``"white"`` coordinate lists.
        board_size: Board edge length, forwarded to ``stone_heuristics``.

    Returns:
        The commentary sentence.
    """
    heur = stone_heuristics(stones, board_size)
    balance = heur["stone_balance"]
    center = heur["center_balance"]
    leader = "Black" if win_prob >= 0.5 else "White"
    trailer = "White" if leader == "Black" else "Black"
    pressure = "stable overall position" if abs(center) < 0.1 else "stronger central influence"
    if win_prob >= 0.8:
        return f"{leader} is clearly ahead with the {pressure} and fewer urgent weaknesses."
    if win_prob >= 0.6:
        return f"{leader} is ahead because the shape looks steadier and the board flow favors {leader.lower()}."
    # In the low bands the leader is White. The sentences must name the
    # leader as ahead (previously the two sides were swapped here, so a low
    # Black win probability still produced "Black is ... ahead").
    if win_prob <= 0.2:
        return f"{leader} is clearly ahead while {trailer.lower()} still has multiple unsettled groups to solve."
    if win_prob <= 0.4:
        return f"{leader} appears ahead with the cleaner territorial picture and the more reliable groups."
    if abs(balance) > 0.1:
        return "The game looks close, but the stone balance is still contested and both sides have unresolved fights."
    return "The position looks close, with unsettled areas and no side fully in control yet."


def load_jsonl(path: str | Path) -> List[Dict[str, Any]]:
    rows: List[Dict[str, Any]] = []
    with Path(path).open() as f:
        for line in f:
            line = line.strip()
            if line:
                rows.append(json.loads(line))
    return rows


def write_jsonl(path: str | Path, rows: Sequence[Dict[str, Any]]) -> None:
    out_path = Path(path)
    out_path.parent.mkdir(parents=True, exist_ok=True)
    with out_path.open("w") as f:
        for row in rows:
            f.write(json.dumps(row) + "\n")


def normalize_example(
    row: Dict[str, Any],
    num_bins: int,
    max_position_tokens: int,
) -> Dict[str, Any]:
    """Fill in defaults and derived fields for one raw dataset row.

    Missing or unparsable ``board_size``, ``komi`` and ``win_prob`` values
    fall back to 19, 6.5 and 0.5 respectively. A missing ``win_prob_bin`` is
    derived from the win probability, a missing ``rationale_text`` from
    ``templated_rationale`` and missing ``position_tokens`` from
    ``build_position_tokens``.

    Args:
        row: Raw example as loaded from JSONL.
        num_bins: Number of win-probability bins; the bin index is clipped
            to ``0..num_bins - 1``.
        max_position_tokens: Maximum number of position tokens to keep.

    Returns:
        A new dict with keys ``id``, ``board_size``, ``to_move``, ``rules``,
        ``komi``, ``stones``, ``position_tokens``, ``rationale_text``,
        ``rationale_tokens``, ``win_prob`` and ``win_prob_bin``.
    """
    board_size = maybe_int(row.get("board_size"), 19)
    to_move = str(row.get("to_move", "B")).upper()
    rules = str(row.get("rules", DEFAULT_RULES))
    komi = maybe_float(row.get("komi"), 6.5)
    stones = row.get("stones") or {"black": [], "white": []}
    # Parse defensively like the other numeric fields, so a null or
    # non-numeric win_prob falls back to 0.5 instead of raising.
    win_prob = clamp_probability(maybe_float(row.get("win_prob"), 0.5))
    win_prob_bin = maybe_int(row.get("win_prob_bin"), probability_to_bin(win_prob, num_bins))
    rationale_text = str(
        row.get("rationale_text")
        or templated_rationale(win_prob, to_move, stones, board_size)
    )
    position_tokens = list(
        row.get("position_tokens")
        or build_position_tokens(
            board_size=board_size,
            to_move=to_move,
            komi=komi,
            rules=rules,
            stones=stones,
            max_position_tokens=max_position_tokens,
        )
    )
    # Only draw a random fallback id when the row really lacks one, so rows
    # that carry an id do not consume the global random stream.
    example_id = row.get("id")
    if example_id is None:
        example_id = f"example_{random.randrange(1_000_000)}"
    return {
        "id": str(example_id),
        "board_size": board_size,
        "to_move": to_move,
        "rules": rules,
        "komi": komi,
        "stones": {
            "black": [coord.upper() for coord in stones.get("black", [])],
            "white": [coord.upper() for coord in stones.get("white", [])],
        },
        "position_tokens": position_tokens[:max_position_tokens],
        "rationale_text": rationale_text,
        "rationale_tokens": simple_tokenize_text(rationale_text),
        "win_prob": win_prob,
        "win_prob_bin": min(num_bins - 1, max(0, win_prob_bin)),
    }


def build_vocab(*datasets: Sequence[Dict[str, Any]]) -> Tuple[Dict[str, int], Dict[int, str]]:
    """Build token<->id mappings over the special tokens and all datasets.

    Every row contributes its ``position_tokens`` and ``rationale_tokens``.
    Ids are assigned in sorted token order, so the mapping is deterministic
    for a given token set.

    Returns:
        ``(tok2id, id2tok)``.
    """
    seen = set(SPECIAL_TOKENS)
    for dataset in datasets:
        for example in dataset:
            seen.update(example["position_tokens"])
            seen.update(example["rationale_tokens"])

    ordered = sorted(seen)
    tok2id = dict(zip(ordered, range(len(ordered))))
    id2tok = dict(enumerate(ordered))
    return tok2id, id2tok


def truncate_sequence_parts(
    position_tokens: Sequence[str],
    rationale_tokens: Sequence[str],
    cfg: Config,
) -> Tuple[List[str], List[str]]:
    """Trim position and rationale tokens to fit the sequence budget.

    The budget is ``cfg.max_seq_len - 5``; the five reserved slots match the
    tokens that encode_example adds around the two segments. Trimming order:

    1. stones are dropped from the end of the position segment, keeping at
       least ``MIN_POSITION_METADATA_TOKENS`` leading tokens;
    2. the rationale is shortened from the end, keeping at least one token;
    3. if still over budget, the position segment is cut further.

    An empty rationale is replaced by ``["no_commentary"]``.

    Returns:
        ``(position, rationale)`` as new lists.
    """
    budget = cfg.max_seq_len - 5
    kept_position = list(position_tokens[: cfg.max_position_tokens])
    kept_rationale = list(rationale_tokens) or ["no_commentary"]

    excess = len(kept_position) + len(kept_rationale) - budget
    if excess <= 0:
        return kept_position, kept_rationale

    # Step 1: give up stones, but never the leading metadata tokens.
    spare_position = len(kept_position) - MIN_POSITION_METADATA_TOKENS
    if spare_position > 0:
        dropped = min(excess, spare_position)
        kept_position = kept_position[: len(kept_position) - dropped]
        excess -= dropped

    # Step 2: shorten the rationale, always leaving one token.
    if excess > 0:
        kept_rationale = kept_rationale[: max(1, len(kept_rationale) - excess)]

    # Step 3: last resort, cut into the metadata tokens as well. Slicing is a
    # no-op when the position segment already fits.
    kept_position = kept_position[: max(0, budget - len(kept_rationale))]

    return kept_position, kept_rationale or ["no_commentary"]


def encode_example(
    example: Dict[str, Any],
    tok2id: Dict[str, int],
    cfg: Config,
) -> Dict[str, Any]:
    """Turn a normalised example into token ids plus span bookkeeping.

    The token layout is
    ``[BOS] position... [RAT] rationale... [CLAIM] V<bin> [EOS]``.
    ``rat_positions`` is the half-open index range of the rationale tokens
    in that layout, and ``claim_position`` is the index of ``[CLAIM]``.
    Tokens missing from *tok2id* map to the ``[UNK]`` id.
    """
    position, rationale = truncate_sequence_parts(
        example["position_tokens"],
        example["rationale_tokens"],
        cfg,
    )
    claim = f"V{example['win_prob_bin']}"
    sequence = ["[BOS]", *position, "[RAT]", *rationale, "[CLAIM]", claim, "[EOS]"]

    # [BOS] and [RAT] precede the first rationale token.
    first_rationale_idx = len(position) + 2
    # The token right after the rationale is [CLAIM]; its hidden state is the
    # one that predicts the claim token autoregressively.
    claim_marker_idx = first_rationale_idx + len(rationale)

    fallback_id = tok2id["[UNK]"]
    token_ids = [tok2id.get(token, fallback_id) for token in sequence]

    encoded: Dict[str, Any] = {
        "id": example["id"],
        "input_ids": torch.tensor(token_ids, dtype=torch.long),
        "claim_bin": example["win_prob_bin"],
        "scalar": example["win_prob"],
        "pad_id": tok2id["[PAD]"],
        "rat_positions": (first_rationale_idx, claim_marker_idx),
        "claim_position": claim_marker_idx,
        "position_tokens": position,
        "rationale_tokens": rationale,
        "claim_token": claim,
    }
    return encoded


class GoWinProbDataset(Dataset):
    """In-memory dataset that encodes every example once at construction.

    Attributes:
        raw_examples: The normalised rows as given.
        pad_id: Id of the ``[PAD]`` token.
        encoded: Output of ``encode_example`` for each row, in order.
    """

    def __init__(self, rows: Sequence[Dict[str, Any]], tok2id: Dict[str, int], cfg: Config):
        self.raw_examples = list(rows)
        self.pad_id = tok2id["[PAD]"]
        encoded_rows = []
        for example in rows:
            encoded_rows.append(encode_example(example, tok2id, cfg))
        self.encoded = encoded_rows

    def __len__(self) -> int:
        """Return the number of encoded examples."""
        return len(self.encoded)

    def __getitem__(self, idx: int) -> Dict[str, Any]:
        """Return the encoded example at *idx*."""
        item = self.encoded[idx]
        return item


def collate_fn(batch: Sequence[Dict[str, Any]]) -> Dict[str, Any]:
    """Pad a list of encoded examples into next-token-prediction tensors.

    Each sequence is right-padded to the longest one in the batch, then
    split into inputs (all but the last token) and targets (all but the
    first). The attention mask marks the real input positions with 1. The
    pad id is taken from the first example.
    """
    longest = max(len(example["input_ids"]) for example in batch)
    pad_id = batch[0]["pad_id"]

    inputs: List[torch.Tensor] = []
    shifted: List[torch.Tensor] = []
    masks: List[torch.Tensor] = []
    for example in batch:
        ids = example["input_ids"]
        n_pad = longest - len(ids)
        padded = F.pad(ids, (0, n_pad), value=pad_id)
        inputs.append(padded[:-1])
        shifted.append(padded[1:])
        # One fewer real position than tokens, because the last token is
        # only ever a target.
        mask_values = [1] * (len(ids) - 1) + [0] * n_pad
        masks.append(torch.tensor(mask_values, dtype=torch.long))

    bins = [example["claim_bin"] for example in batch]
    win_probs = [example["scalar"] for example in batch]
    return {
        "input_ids": torch.stack(inputs),
        "targets": torch.stack(shifted),
        "attention_mask": torch.stack(masks),
        "claim_bins": torch.tensor(bins, dtype=torch.long),
        "scalars": torch.tensor(win_probs, dtype=torch.float32),
        "rat_positions": [example["rat_positions"] for example in batch],
        "claim_positions": [example["claim_position"] for example in batch],
    }


class DecoderBlock(nn.Module):
    def __init__(self, d_model: int, n_heads: int, d_ff: int, dropout: float):
        super().__init__()
        self.ln1 = nn.LayerNorm(d_model)
        self.attn = nn.MultiheadAttention(d_model, n_heads, dropout=dropout, batch_first=True)
        self.ln2 = nn.LayerNorm(d_model)
        self.ff = nn.Sequential(
            nn.Linear(d_model, d_ff),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(d_ff, d_model),
            nn.Dropout(dropout),
        )

    def forward(
        self,
        x: torch.Tensor,
        key_padding_mask: torch.Tensor | None = None,
        attn_mask: torch.Tensor | None = None,
    ) -> torch.Tensor:
        h = self.ln1(x)
        attn_out, _ = self.attn(
            h,
            h,
            h,
            key_padding_mask=key_padding_mask,
            attn_mask=attn_mask,
            need_weights=False,
        )
        x = x + attn_out
        x = x + self.ff(self.ln2(x))
        return x


class TinyWinProbLM(nn.Module):
    """Small causal transformer LM with two auxiliary heads.

    Besides next-token logits, the model produces a scalar logit read from
    the hidden state at each example's claim position, and bin logits read
    from the mean hidden state over each example's rationale span. The
    output projection shares its weight with the token embedding.
    """

    def __init__(self, vocab_size: int, cfg: Config):
        super().__init__()
        self.token_emb = nn.Embedding(vocab_size, cfg.d_model)
        self.pos_emb = nn.Embedding(cfg.max_seq_len, cfg.d_model)
        self.blocks = nn.ModuleList(
            DecoderBlock(cfg.d_model, cfg.n_heads, cfg.d_ff, cfg.dropout)
            for _ in range(cfg.n_layers)
        )
        self.ln_f = nn.LayerNorm(cfg.d_model)
        self.lm_head = nn.Linear(cfg.d_model, vocab_size, bias=False)
        self.scalar_head = nn.Linear(cfg.d_model, 1)
        self.bin_head = nn.Linear(cfg.d_model, cfg.num_bins)
        # Tie the output projection to the input embedding.
        self.lm_head.weight = self.token_emb.weight
        self.apply(self._init_weights)

    def _init_weights(self, module: nn.Module) -> None:
        """Initialise one submodule (called for every submodule via ``apply``)."""
        if isinstance(module, nn.LayerNorm):
            nn.init.ones_(module.weight)
            nn.init.zeros_(module.bias)
            return
        if not isinstance(module, (nn.Linear, nn.Embedding)):
            return
        nn.init.normal_(module.weight, mean=0.0, std=0.02)
        has_bias = isinstance(module, nn.Linear) and module.bias is not None
        if has_bias:
            nn.init.zeros_(module.bias)

    @staticmethod
    def causal_mask(length: int, device: torch.device) -> torch.Tensor:
        """Return a ``length x length`` bool mask, True strictly above the diagonal."""
        full = torch.ones(length, length, dtype=torch.bool, device=device)
        return torch.triu(full, diagonal=1)

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        rat_positions: Sequence[Tuple[int, int]],
        claim_positions: Sequence[int],
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Run the model on a padded batch.

        Args:
            input_ids: Token ids, two-dimensional (batch, sequence).
            attention_mask: Same layout as *input_ids*; zeros mark padding.
            rat_positions: Per-example half-open rationale index range.
            claim_positions: Per-example index of the claim position.

        Returns:
            ``(lm_logits, scalar_logits, bin_logits)``.
        """
        n_examples, n_tokens = input_ids.shape
        device = input_ids.device
        position_ids = torch.arange(n_tokens, device=device).unsqueeze(0).expand(n_examples, n_tokens)
        hidden = self.token_emb(input_ids) + self.pos_emb(position_ids)

        padding = attention_mask == 0
        future = self.causal_mask(n_tokens, device)
        for layer in self.blocks:
            hidden = layer(hidden, key_padding_mask=padding, attn_mask=future)
        h = self.ln_f(hidden)
        lm_logits = self.lm_head(h)

        # Mean-pool each example's rationale span.
        rationale_hidden = torch.stack(
            [
                h[row, rat_positions[row][0] : rat_positions[row][1], :].mean(dim=0)
                for row in range(n_examples)
            ]
        )
        # Pick the single hidden state at each example's claim position.
        claim_hidden = torch.stack([h[row, claim_positions[row], :] for row in range(n_examples)])

        scalar_logits = self.scalar_head(claim_hidden).squeeze(-1)
        bin_logits = self.bin_head(rationale_hidden)
        return lm_logits, scalar_logits, bin_logits


def pearson_r(x: Sequence[float], y: Sequence[float]) -> float:
    """Return the Pearson correlation of *x* and *y*.

    Returns 0.0 when either input has fewer than two values or zero
    standard deviation.
    """
    if min(len(x), len(y)) < 2:
        return 0.0
    xs = np.asarray(x, dtype=np.float64)
    ys = np.asarray(y, dtype=np.float64)
    if xs.std() == 0.0 or ys.std() == 0.0:
        return 0.0
    correlation_matrix = np.corrcoef(xs, ys)
    return float(correlation_matrix[0, 1])


def average_ranks(values: Sequence[float]) -> np.ndarray:
    """Return 1-based ranks of *values*, with ties sharing their mean rank."""
    arr = np.asarray(values, dtype=np.float64)
    order = np.argsort(arr, kind="mergesort")
    in_order = arr[order]
    count = len(arr)
    ranks = np.empty(count, dtype=np.float64)

    run_start = 0
    for boundary in range(1, count + 1):
        # A run of equal values ends at the array end or at the first value
        # that differs from the run's first element.
        if boundary == count or in_order[boundary] != in_order[run_start]:
            ranks[order[run_start:boundary]] = (run_start + boundary - 1) / 2.0 + 1.0
            run_start = boundary
    return ranks


def spearman_r(x: Sequence[float], y: Sequence[float]) -> float:
    """Return the Spearman correlation: Pearson correlation of average ranks.

    Returns 0.0 when either input has fewer than two values.
    """
    if min(len(x), len(y)) < 2:
        return 0.0
    x_ranks = average_ranks(x)
    y_ranks = average_ranks(y)
    return pearson_r(x_ranks, y_ranks)


def evaluate(
    model: TinyWinProbLM,
    loader: DataLoader,
    device: torch.device,
) -> Dict[str, float]:
    """Compute evaluation metrics for *model* over *loader*.

    The model is put in eval mode and run without gradients.

    Returns:
        A dict with next-token accuracy over non-padded positions
        (``token_acc``), bin-head accuracy (``claim_bin_acc``), mean squared
        and mean absolute error of the sigmoid scalar prediction
        (``scalar_mse``, ``mae_winprob``), and its Pearson / Spearman
        correlation with the target (``pearson_r_winprob``,
        ``spearman_r_winprob``).
    """
    model.eval()
    correct_tokens = 0
    counted_tokens = 0
    correct_bins = 0
    squared_error_sum = 0.0
    absolute_error_sum = 0.0
    predicted_probs: List[float] = []
    reference_probs: List[float] = []

    tensor_keys = ("input_ids", "targets", "attention_mask", "claim_bins", "scalars")
    with torch.no_grad():
        for batch in loader:
            on_device = {key: batch[key].to(device) for key in tensor_keys}
            mask = on_device["attention_mask"]
            gold_probs = on_device["scalars"]

            lm_logits, scalar_logits, bin_logits = model(
                input_ids=on_device["input_ids"],
                attention_mask=mask,
                rat_positions=batch["rat_positions"],
                claim_positions=batch["claim_positions"],
            )

            # Token accuracy counts only positions flagged by the mask.
            is_real = mask.bool()
            token_hits = (lm_logits.argmax(dim=-1) == on_device["targets"]) & is_real
            correct_tokens += token_hits.sum().item()
            counted_tokens += is_real.sum().item()

            bin_hits = bin_logits.argmax(dim=-1) == on_device["claim_bins"]
            correct_bins += bin_hits.sum().item()

            probs = torch.sigmoid(scalar_logits)
            squared_error_sum += F.mse_loss(probs, gold_probs, reduction="sum").item()
            absolute_error_sum += torch.abs(probs - gold_probs).sum().item()
            predicted_probs.extend(probs.detach().cpu().tolist())
            reference_probs.extend(gold_probs.detach().cpu().tolist())

    # Floor both denominators at 1 so an empty loader yields zeros.
    n_examples = max(1, len(loader.dataset))
    n_tokens = max(1, counted_tokens)
    return {
        "token_acc": correct_tokens / n_tokens,
        "claim_bin_acc": correct_bins / n_examples,
        "scalar_mse": squared_error_sum / n_examples,
        "mae_winprob": absolute_error_sum / n_examples,
        "pearson_r_winprob": pearson_r(predicted_probs, reference_probs),
        "spearman_r_winprob": spearman_r(predicted_probs, reference_probs),
    }


def encode_counterfactual_sequence(
    position_tokens: Sequence[str],
    rationale_tokens: Sequence[str],
    claim_bin: int,
    tok2id: Dict[str, int],
    cfg: Config,
) -> Tuple[torch.Tensor, Tuple[int, int], int]:
    """Encode one position/rationale pairing as a single-row model input.

    The layout matches encode_example, except that the final ``[EOS]`` token
    is left out of the returned ids (the same way collate_fn drops the last
    token from the inputs).

    Returns:
        ``(input_ids, (rat_start, rat_end), claim_position)`` where
        *input_ids* has a leading batch dimension of one.
    """
    position, rationale = truncate_sequence_parts(position_tokens, rationale_tokens, cfg)
    sequence = ["[BOS]", *position, "[RAT]", *rationale, "[CLAIM]", f"V{claim_bin}", "[EOS]"]

    span_start = len(position) + 2
    span_end = span_start + len(rationale)
    # As in training, the claim is read at the [CLAIM] position, not at the
    # gold claim token that follows it.
    claim_index = span_end

    fallback_id = tok2id["[UNK]"]
    model_input = [tok2id.get(token, fallback_id) for token in sequence[:-1]]
    batch_of_one = torch.tensor([model_input], dtype=torch.long)
    return batch_of_one, (span_start, span_end), claim_index


def evaluate_counterfactual(
    model: TinyWinProbLM,
    dataset: GoWinProbDataset,
    tok2id: Dict[str, int],
    cfg: Config,
    device: torch.device,
) -> Dict[str, float]:
    """Measure how predictions react when the rationale is swapped.

    For up to ``cfg.counterfactual_samples`` randomly chosen examples, the
    example's position is paired with the rationale of another example whose
    win-probability bin differs, and the model is run on that hybrid input.
    Sampling uses a private ``random.Random(cfg.seed)``, so it is
    reproducible and independent of the global random state.

    Args:
        model: Model to evaluate; switched to eval mode.
        dataset: Source of the raw (normalised) examples.
        tok2id: Vocabulary mapping.
        cfg: Supplies the seed, sample count and truncation limits.
        device: Device the model lives on.

    Returns:
        A dict with the fraction of bin predictions equal to the swapped
        rationale's bin (``cfact_cls_follows_swap``), the fraction equal to
        the original position's bin (``cfact_cls_follows_orig``), and the
        mean squared error between the scalar prediction and the swapped
        example's win probability (``cfact_scalar_mse_to_swap``). Examples
        with no differing-bin partner are skipped; with nothing evaluated
        all three values are 0.
    """
    model.eval()
    rng = random.Random(cfg.seed)
    examples = dataset.raw_examples
    indices = list(range(len(examples)))
    rng.shuffle(indices)
    indices = indices[: min(cfg.counterfactual_samples, len(indices))]

    # The partner pool depends only on the source bin, so build it once per
    # bin instead of rescanning the dataset for every sampled example. Each
    # cached list has the same contents and order as a fresh scan, which
    # keeps rng.choice's picks unchanged.
    candidates_by_bin: Dict[Any, List[Dict[str, Any]]] = {}

    follows_swap = 0
    follows_orig = 0
    scalar_mse_to_swap = 0.0
    evaluated = 0

    with torch.no_grad():
        for idx in indices:
            example_a = examples[idx]
            bin_a = example_a["win_prob_bin"]
            candidates = candidates_by_bin.get(bin_a)
            if candidates is None:
                candidates = [row for row in examples if row["win_prob_bin"] != bin_a]
                candidates_by_bin[bin_a] = candidates
            if not candidates:
                continue
            example_b = rng.choice(candidates)
            input_ids, rat_span, claim_position = encode_counterfactual_sequence(
                position_tokens=example_a["position_tokens"],
                rationale_tokens=example_b["rationale_tokens"],
                claim_bin=bin_a,
                tok2id=tok2id,
                cfg=cfg,
            )
            # Single unpadded sequence: every position is real.
            attention_mask = torch.ones_like(input_ids)
            _, scalar_logits, bin_logits = model(
                input_ids=input_ids.to(device),
                attention_mask=attention_mask.to(device),
                rat_positions=[rat_span],
                claim_positions=[claim_position],
            )
            pred_bin = int(bin_logits.argmax(dim=-1).item())
            pred_scalar = float(torch.sigmoid(scalar_logits).item())

            if pred_bin == example_b["win_prob_bin"]:
                follows_swap += 1
            if pred_bin == bin_a:
                follows_orig += 1
            scalar_mse_to_swap += (pred_scalar - example_b["win_prob"]) ** 2
            evaluated += 1

    denom = max(1, evaluated)
    return {
        "cfact_cls_follows_swap": follows_swap / denom,
        "cfact_cls_follows_orig": follows_orig / denom,
        "cfact_scalar_mse_to_swap": scalar_mse_to_swap / denom,
    }


def train_variant(
    cfg: Config,
    variant: str,
    train_loader: DataLoader,
    eval_loader: DataLoader,
    train_dataset: GoWinProbDataset,
    eval_dataset: GoWinProbDataset,
    tok2id: Dict[str, int],
) -> Dict[str, float]:
    """Train a fresh model under one loss variant and evaluate it.

    A new TinyWinProbLM is created on CUDA when available (CPU otherwise)
    and trained with AdamW for ``cfg.epochs`` passes over *train_loader*.
    The loss always contains the language-model term; *variant* decides
    which of the scalar (MSE) and consistency (bin cross-entropy) terms are
    added:

    - ``lm_only``: LM loss only.
    - ``no_consistency_loss``: LM + scalar.
    - ``rationale_only``: LM + consistency.
    - ``full_consistency``: LM + scalar + consistency.
    - ``random_consistency``: LM + scalar + cross-entropy against uniformly
      random bin labels drawn for each batch.

    Returns:
        The variant name merged with the metrics from ``evaluate`` and
        ``evaluate_counterfactual`` on the eval data.

    Raises:
        ValueError: If *variant* is not one of the names above (raised when
            the first training batch is processed).
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = TinyWinProbLM(len(tok2id), cfg).to(device)
    optimizer = torch.optim.AdamW(model.parameters(), lr=cfg.lr, weight_decay=cfg.weight_decay)

    for _ in range(cfg.epochs):
        # evaluate() is not called inside the loop, but re-entering train
        # mode each epoch keeps the loop self-contained.
        model.train()
        for batch in train_loader:
            input_ids = batch["input_ids"].to(device)
            targets = batch["targets"].to(device)
            attention_mask = batch["attention_mask"].to(device)
            claim_bins = batch["claim_bins"].to(device)
            scalars = batch["scalars"].to(device)

            lm_logits, scalar_logits, bin_logits = model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                rat_positions=batch["rat_positions"],
                claim_positions=batch["claim_positions"],
            )
            vocab_size = lm_logits.size(-1)
            # Padded target positions are excluded from the LM loss.
            lm_loss = F.cross_entropy(
                lm_logits.reshape(-1, vocab_size),
                targets.reshape(-1),
                ignore_index=train_dataset.pad_id,
            )
            # All three loss terms are computed for every variant; the
            # branch below decides which ones enter the total.
            scalar_pred = torch.sigmoid(scalar_logits)
            scalar_loss = F.mse_loss(scalar_pred, scalars)
            consistency_loss = F.cross_entropy(bin_logits, claim_bins)

            if variant == "lm_only":
                loss = cfg.rationale_loss_weight * lm_loss
            elif variant == "no_consistency_loss":
                loss = cfg.rationale_loss_weight * lm_loss + cfg.claim_loss_weight * scalar_loss
            elif variant == "rationale_only":
                loss = cfg.rationale_loss_weight * lm_loss + cfg.consistency_weight * consistency_loss
            elif variant == "full_consistency":
                loss = (
                    cfg.rationale_loss_weight * lm_loss
                    + cfg.claim_loss_weight * scalar_loss
                    + cfg.consistency_weight * consistency_loss
                )
            elif variant == "random_consistency":
                # Control: the bin head is trained against labels that carry
                # no information about the example.
                random_bins = torch.randint(0, cfg.num_bins, claim_bins.shape, device=device)
                loss = (
                    cfg.rationale_loss_weight * lm_loss
                    + cfg.claim_loss_weight * scalar_loss
                    + cfg.consistency_weight * F.cross_entropy(bin_logits, random_bins)
                )
            else:
                raise ValueError(f"Unknown variant: {variant}")

            optimizer.zero_grad(set_to_none=True)
            loss.backward()
            # Clip the global gradient norm to 1.0 before the update.
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()

    metrics = evaluate(model, eval_loader, device)
    cf_metrics = evaluate_counterfactual(model, eval_dataset, tok2id, cfg, device)
    return {"variant": variant, **metrics, **cf_metrics}


def dataframe_to_markdown(df: pd.DataFrame) -> str:
    """Render *df* as a column-aligned pipe table.

    Float cells are formatted with six decimal places; everything else goes
    through ``str``. The index is not included.
    """

    def render_cell(value: Any) -> str:
        return f"{value:.6f}" if isinstance(value, float) else str(value)

    def render_line(cells: Sequence[Any]) -> str:
        padded = [str(cell).ljust(widths[pos]) for pos, cell in enumerate(cells)]
        return "| " + " | ".join(padded) + " |"

    header_cells = list(df.columns)
    table = [header_cells]
    for _, record in df.iterrows():
        table.append([render_cell(value) for value in record.tolist()])

    n_columns = len(header_cells)
    widths = [max(len(str(line[pos])) for line in table) for pos in range(n_columns)]

    rule = "| " + " | ".join("-" * width for width in widths) + " |"
    body_lines = [render_line(line) for line in table[1:]]
    return "\n".join([render_line(table[0]), rule, "\n".join(body_lines)])


def variant_interpretation(row: pd.Series) -> str:
    """Return the canned note for one results row.

    The four non-control variants map to fixed sentences. Any other variant
    is treated as the control; it gets the "suspicious" note when it follows
    the original bin more often than the swapped one and its bin accuracy
    exceeds 0.7.
    """
    name = row["variant"]
    # Read all three metrics up front so a malformed row fails for every
    # variant alike.
    bin_accuracy = float(row["claim_bin_acc"])
    keeps_original = float(row["cfact_cls_follows_orig"])
    follows_swapped = float(row["cfact_cls_follows_swap"])

    fixed_notes = {
        "lm_only": "Pure LM baseline: should model commentary tokens but usually underperforms on calibrated oracle claims.",
        "no_consistency_loss": "LM + scalar baseline: can learn continuous win probability without directly tying rationale-pooled states to the bin claim.",
        "rationale_only": "LM + rationale consistency only: strongest evidence that rationale-pooled hidden states alone can support the inline claim.",
        "full_consistency": "Full objective: if this wins on bin accuracy and scalar quality, the shared hidden-state mechanism is working end to end.",
    }
    if name in fixed_notes:
        return fixed_notes[name]

    looks_suspicious = keeps_original > follows_swapped and bin_accuracy > 0.7
    if looks_suspicious:
        return "Control variant with random consistency labels; strong performance here would be suspicious."
    return "Control variant behavior."


def format_results_markdown(
    df: pd.DataFrame,
    cfg: Config,
    train_size: int,
    eval_size: int,
    train_path: str,
    eval_path: str,
) -> str:
    """Assemble the Markdown report for an experiment run.

    The report has four sections: a setup summary, the results table, one
    note per variant row, and fixed interpretation notes. The returned text
    ends with a newline.
    """
    setup_bullets = [
        f"- Train dataset: `{train_path}`",
        f"- Eval dataset: `{eval_path}`",
        f"- Train rows: `{train_size}`",
        f"- Eval rows: `{eval_size}`",
        f"- Variants: `{', '.join(cfg.variants)}`",
        f"- Batch size: `{cfg.batch_size}`",
        f"- Epochs: `{cfg.epochs}`",
        f"- LR: `{cfg.lr}`",
        f"- Model: `d_model={cfg.d_model}`, `layers={cfg.n_layers}`, `heads={cfg.n_heads}`, `d_ff={cfg.d_ff}`",
        f"- Max sequence length: `{cfg.max_seq_len}`",
        f"- Consistency weight: `{cfg.consistency_weight}`",
    ]
    results_table = dataframe_to_markdown(df)
    variant_bullets = [
        f"- `{record['variant']}`: {variant_interpretation(record)}"
        for _, record in df.iterrows()
    ]
    interpretation_bullets = [
        "- The earlier synthetic generated-rationale scalar experiment showed perfect claim-bin accuracy and perfect orig-following counterfactual behavior for `rationale_only` and `full_consistency`.",
        "- This KataGo version is a stronger domain test because the language model emits both commentary and an inline claim while only the claim is supervised by an oracle.",
        "- FEVER-from-scratch was a weaker bridge task because it supervised evidence classification rather than oracle-graded generated claims.",
    ]

    report: List[str] = ["# KataGo Win-Probability Claim-Consistency Experiment", ""]
    report += ["## Setup Summary", ""]
    report += setup_bullets
    report += ["", "## Results", "", results_table, ""]
    report += ["## Variant Notes", ""]
    report += variant_bullets
    report += ["", "## Interpretation Notes", ""]
    report += interpretation_bullets
    return "\n".join(report) + "\n"


def generate_mock_example(idx: int, rng: random.Random, board_size: int = 9, num_bins: int = 10) -> Dict[str, Any]:
    """Generate one synthetic dataset row with a heuristic win probability.

    Stones are placed on distinct random points, alternating Black and White
    starting with Black. The win probability is the logistic function of a
    weighted sum of the stone balance and centre balance (as computed by
    ``stone_heuristics``) plus uniform noise in ``[-0.15, 0.15]``.

    Args:
        idx: Index used to build the row id ``mock_game_<idx>``.
        rng: Random generator; all randomness is drawn from it.
        board_size: Board edge length.
        num_bins: Number of win-probability bins for ``win_prob_bin``.

    Returns:
        A row with the same keys that ``normalize_example`` reads, komi fixed
        at 6.5 and rules set to ``DEFAULT_RULES``.

    Raises:
        ValueError: Propagated from ``rng.randint`` when the board is too
            small to hold the minimum of 8 stones within half its points, or
            from ``row_col_to_gtp`` when the board is too large.
    """
    occupied: set[str] = set()
    black: List[str] = []
    white: List[str] = []
    total_stones = rng.randint(8, min(board_size * board_size // 2, 28))
    for stone_idx in range(total_stones):
        # Rejection-sample until an empty point is found.
        while True:
            row = rng.randrange(board_size)
            col = rng.randrange(board_size)
            coord = row_col_to_gtp(row, col, board_size)
            if coord not in occupied:
                occupied.add(coord)
                break
        if stone_idx % 2 == 0:
            black.append(coord)
        else:
            white.append(coord)

    # Reuse the shared heuristics instead of a private copy of the same
    # formulas. The lists are passed in placement order (not sorted) so the
    # floating-point accumulation order, and therefore the generated data,
    # stays exactly as before.
    heur = stone_heuristics({"black": black, "white": white}, board_size)
    stone_balance = heur["stone_balance"]
    center_balance = heur["center_balance"]
    raw_score = 1.6 * stone_balance + 1.2 * center_balance + rng.uniform(-0.15, 0.15)
    win_prob = 1.0 / (1.0 + math.exp(-raw_score))
    to_move = "B" if rng.random() < 0.5 else "W"
    stones = {"black": sorted(black), "white": sorted(white)}
    rationale_text = templated_rationale(win_prob, to_move, stones, board_size)
    return {
        "id": f"mock_game_{idx}",
        "board_size": board_size,
        "to_move": to_move,
        "rules": DEFAULT_RULES,
        "komi": 6.5,
        "stones": stones,
        "position_tokens": build_position_tokens(board_size, to_move, 6.5, DEFAULT_RULES, stones),
        "rationale_text": rationale_text,
        "win_prob": win_prob,
        "win_prob_bin": probability_to_bin(win_prob, num_bins),
    }


def write_mock_dataset(path: str | Path, size: int, seed: int, num_bins: int = 10) -> str:
    """Generate ``size`` mock examples from a seeded RNG and write them as JSONL.

    Returns ``path`` converted to ``str``.
    """
    generator = random.Random(seed)
    examples: List[Dict[str, Any]] = []
    for example_idx in range(size):
        examples.append(generate_mock_example(example_idx, generator, num_bins=num_bins))
    write_jsonl(path, examples)
    return str(path)


def resolve_dataset_paths(cfg: Config) -> Tuple[str, str]:
    """Return the ``(train_path, eval_path)`` pair the experiment should read.

    Explicit paths win when both are set. Otherwise, in smoke-test mode, mock
    train/eval JSONL files are generated in a fresh temporary directory (the
    eval split uses ``cfg.seed + 1``).

    Raises:
        ValueError: If a path is missing and ``cfg.smoke_test`` is off.
    """
    if cfg.train_path and cfg.eval_path:
        return cfg.train_path, cfg.eval_path

    if cfg.smoke_test:
        scratch_dir = Path(tempfile.mkdtemp(prefix="katago_winprob_smoke_"))
        splits = (
            ("train.jsonl", cfg.smoke_train_size, cfg.seed),
            ("eval.jsonl", cfg.smoke_eval_size, cfg.seed + 1),
        )
        written: List[str] = []
        for filename, split_size, split_seed in splits:
            target = scratch_dir / filename
            write_mock_dataset(target, split_size, split_seed, cfg.num_bins)
            written.append(str(target))
        return written[0], written[1]

    raise ValueError("train_path and eval_path are required unless smoke_test is enabled")


def run_experiment(cfg: Config) -> pd.DataFrame:
    """Train and evaluate every configured variant, then save the results.

    Seeds the run, resolves and normalizes the train/eval JSONL rows, builds a
    shared vocabulary and data loaders, runs ``train_variant`` once per entry
    in ``cfg.variants``, and writes one result row per variant to
    ``cfg.output_csv``. A markdown report is written to ``cfg.output_markdown``
    or, when that is empty, next to the CSV with a ``.md`` suffix.

    Args:
        cfg: Experiment configuration.

    Returns:
        DataFrame with one row per trained variant.
    """
    set_seed(cfg.seed)
    train_path, eval_path = resolve_dataset_paths(cfg)

    def load_normalized(path: str) -> List[Any]:
        """Load a JSONL file and normalize every row with the run's settings."""
        return [
            normalize_example(row, num_bins=cfg.num_bins, max_position_tokens=cfg.max_position_tokens)
            for row in load_jsonl(path)
        ]

    train_rows = load_normalized(train_path)
    eval_rows = load_normalized(eval_path)

    # The vocabulary is built over both splits and shared by both datasets.
    tok2id, _ = build_vocab(train_rows, eval_rows)
    train_dataset = GoWinProbDataset(train_rows, tok2id, cfg)
    eval_dataset = GoWinProbDataset(eval_rows, tok2id, cfg)
    train_loader = DataLoader(train_dataset, batch_size=cfg.batch_size, shuffle=True, collate_fn=collate_fn)
    eval_loader = DataLoader(eval_dataset, batch_size=cfg.batch_size, shuffle=False, collate_fn=collate_fn)

    rows: List[Dict[str, float]] = [
        train_variant(
            cfg=cfg,
            variant=variant,
            train_loader=train_loader,
            eval_loader=eval_loader,
            train_dataset=train_dataset,
            eval_dataset=eval_dataset,
            tok2id=tok2id,
        )
        for variant in cfg.variants
    ]

    df = pd.DataFrame(rows)
    output_csv = Path(cfg.output_csv)
    output_csv.parent.mkdir(parents=True, exist_ok=True)
    df.to_csv(output_csv, index=False)

    output_md = Path(cfg.output_markdown) if cfg.output_markdown else output_csv.with_suffix(".md")
    # The report may live in a different directory than the CSV; create it so a
    # finished run does not fail at the very last write.
    output_md.parent.mkdir(parents=True, exist_ok=True)
    output_md.write_text(
        format_results_markdown(
            df=df,
            cfg=cfg,
            train_size=len(train_rows),
            eval_size=len(eval_rows),
            train_path=train_path,
            eval_path=eval_path,
        )
    )
    return df


def board_to_stones(board: Any, board_size: int) -> Dict[str, List[str]]:
    """Collect the occupied points of ``board`` grouped by colour.

    ``board`` must expose ``get(row, col)``; points for which it returns
    ``"b"`` or ``"w"`` are reported, anything else is treated as empty.
    Points are scanned row by row and converted with ``row_col_to_gtp``.
    """
    blacks: List[str] = []
    whites: List[str] = []
    for row_idx in range(board_size):
        for col_idx in range(board_size):
            occupant = board.get(row_idx, col_idx)
            if occupant not in ("b", "w"):
                continue
            bucket = blacks if occupant == "b" else whites
            bucket.append(row_col_to_gtp(row_idx, col_idx, board_size))
    return {"black": blacks, "white": whites}


def sgf_move_to_gtp(move: tuple[int, int] | None, board_size: int) -> str:
    """Convert a ``(row, col)`` move to its coordinate string; ``None`` means ``"pass"``."""
    if move is not None:
        row_idx, col_idx = move
        return row_col_to_gtp(row_idx, col_idx, board_size)
    return "pass"


def sampled_turn_numbers(total_moves: int, cfg: PreprocessConfig) -> List[int]:
    """Pick the 1-based turn numbers of a game that should become dataset rows.

    Starts at turn 1 and steps by ``cfg.sample_every_n_moves`` (at least 1).
    A positive ``cfg.max_positions_per_game`` keeps only that many of the
    earliest picks; zero or negative means no cap.
    """
    if total_moves <= 0:
        return []
    stride = max(1, cfg.sample_every_n_moves)
    picked = range(1, total_moves + 1, stride)
    cap = cfg.max_positions_per_game
    if cap > 0:
        picked = picked[:cap]
    return list(picked)


class KataGoAnalysisClient:
    """Context manager around a long-lived ``katago analysis`` subprocess.

    Queries are written to the child's stdin as one JSON object per line and
    responses are read back from its stdout, also one JSON object per line.

    Both output pipes are read with ``select`` + ``os.read`` on the raw file
    descriptors and split into lines here. The text-mode wrappers created by
    ``Popen(text=True)`` are deliberately not used for reading: they read
    ahead, so a line can already sit in Python's buffer while ``select``
    reports the descriptor as idle, which would look like a timeout.

    ``select`` on pipes is required, so this only works on POSIX platforms.
    """

    # Largest single read from a pipe.
    _READ_CHUNK_BYTES = 65536
    # Number of most recent stderr lines kept for error messages.
    _STDERR_TAIL_LINES = 200

    def __init__(
        self,
        binary_path: str,
        model_path: str,
        config_path: str,
        extra_args: Sequence[str] = (),
        visits: int = 1000,
    ):
        """Store launch settings; the process itself starts in ``__enter__``.

        Args:
            binary_path: Path to the KataGo executable.
            model_path: Value passed after ``-model``.
            config_path: Value passed after ``-config``.
            extra_args: Extra command-line arguments appended verbatim.
            visits: Sent as ``maxVisits`` with every query.
        """
        self.binary_path = binary_path
        self.model_path = model_path
        self.config_path = config_path
        self.extra_args = list(extra_args)
        self.visits = visits
        self.proc: subprocess.Popen[str] | None = None
        # Bytes read from a pipe that do not yet end in a newline, keyed by fd.
        self._pending: Dict[int, bytes] = {}
        # Bounded history of stderr lines, used to enrich error messages.
        self._stderr_tail: List[str] = []

    def _subprocess_env(self) -> Dict[str, str]:
        """Build the child environment, extending ``LD_LIBRARY_PATH`` with CUDA libs.

        Library directories under the Python prefix, ``sys.path`` entries and
        site-packages are added when they contain one of the CUDA/cuDNN shared
        objects listed in ``needed_patterns``. Existing entries keep priority.
        """
        env = os.environ.copy()
        # Official Linux KataGo release binaries are packaged as AppImages.
        # Modal containers do not provide FUSE, so force extract-and-run mode.
        env.setdefault("APPIMAGE_EXTRACT_AND_RUN", "1")
        lib_dirs: List[str] = [
            existing for existing in env.get("LD_LIBRARY_PATH", "").split(":") if existing
        ]

        search_roots = {Path(sys.prefix)}
        for path_str in sys.path:
            if path_str:
                search_roots.add(Path(path_str))
        for path_str in site.getsitepackages():
            search_roots.add(Path(path_str))

        candidate_dirs: set[Path] = set()
        for root in search_roots:
            if not root.exists():
                continue
            direct_lib = root / "lib"
            if direct_lib.exists():
                candidate_dirs.add(direct_lib)
            for pattern in [
                "torch/lib",
                "nvidia/*/lib",
            ]:
                for match in root.glob(pattern):
                    if match.is_dir():
                        candidate_dirs.add(match)

        needed_patterns = (
            "libcublas.so*",
            "libcudnn.so*",
            "libcudart.so*",
            "libcusolver.so*",
        )
        # Sorted so the resulting search path is deterministic.
        for directory in sorted(candidate_dirs):
            try:
                if any(any(directory.glob(pattern)) for pattern in needed_patterns):
                    lib_dirs.append(str(directory))
            except OSError:
                # Unreadable directory: it cannot contribute libraries anyway.
                continue

        if lib_dirs:
            # dict.fromkeys drops duplicates while preserving first-seen order.
            env["LD_LIBRARY_PATH"] = ":".join(dict.fromkeys(lib_dirs))
        return env

    def __enter__(self) -> "KataGoAnalysisClient":
        """Start the analysis subprocess and return ``self``."""
        command = [
            self.binary_path,
            "analysis",
            "-model",
            self.model_path,
            "-config",
            self.config_path,
            *self.extra_args,
        ]
        self._pending = {}
        self._stderr_tail = []
        self.proc = subprocess.Popen(
            command,
            stdin=subprocess.PIPE,
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
            text=True,
            bufsize=1,
            env=self._subprocess_env(),
        )
        return self

    def __exit__(self, exc_type, exc, tb) -> None:
        """Shut the subprocess down and release its pipes.

        Closing stdin asks the engine to finish; if it has not exited after
        five seconds it is terminated, and killed if that is ignored too.
        ``self.proc`` is reset to ``None`` afterwards.
        """
        proc = self.proc
        if proc is None:
            return
        try:
            if proc.stdin is not None:
                proc.stdin.close()
        except (OSError, ValueError):
            # Best effort: the child may already be gone (broken pipe).
            pass
        try:
            try:
                proc.wait(timeout=5)
            except subprocess.TimeoutExpired:
                proc.terminate()
                try:
                    proc.wait(timeout=5)
                except subprocess.TimeoutExpired:
                    # SIGTERM was ignored; do not leave a stray engine behind.
                    proc.kill()
                    proc.wait()
        finally:
            for stream in (proc.stdout, proc.stderr):
                if stream is None:
                    continue
                try:
                    stream.close()
                except OSError:
                    pass
            self._pending.clear()
            self.proc = None

    def _drain_lines(self, stream: Any, timeout: float) -> List[str]:
        """Return the complete lines currently obtainable from ``stream``.

        Waits up to ``timeout`` seconds for data. Once at least one complete
        line has been collected, only data that is already available is
        drained. A trailing partial line is kept for the next call, or is
        flushed as a final line when the pipe reaches end-of-file.
        """
        fd = stream.fileno()
        pending = self._pending.get(fd, b"")
        lines: List[str] = []
        wait = timeout
        while True:
            ready, _, _ = select.select([fd], [], [], wait)
            if not ready:
                break
            chunk = os.read(fd, self._READ_CHUNK_BYTES)
            if not chunk:
                # End-of-file: hand out whatever unterminated text remains.
                if pending:
                    lines.append(pending.decode("utf-8", errors="replace").rstrip("\r"))
                    pending = b""
                break
            pending += chunk
            *complete, pending = pending.split(b"\n")
            lines.extend(raw.decode("utf-8", errors="replace").rstrip("\r") for raw in complete)
            if lines:
                wait = 0.0
        self._pending[fd] = pending
        return lines

    def _read_ready_lines(self, timeout: float) -> List[str]:
        """Return stdout lines available within ``timeout`` seconds (may be empty).

        Raises:
            RuntimeError: If the subprocess has not been started.
        """
        if self.proc is None or self.proc.stdout is None:
            raise RuntimeError("KataGo process is not running")
        return self._drain_lines(self.proc.stdout, timeout)

    def _poll_stderr(self) -> str:
        """Return stderr text that is available right now, without blocking.

        The lines are also appended to the bounded stderr history.
        """
        if self.proc is None or self.proc.stderr is None:
            return ""
        fresh = self._drain_lines(self.proc.stderr, 0.0)
        if fresh:
            self._stderr_tail = (self._stderr_tail + fresh)[-self._STDERR_TAIL_LINES:]
        return "\n".join(fresh)

    def query_game_win_probs(
        self,
        game_record: Dict[str, Any],
        turn_numbers: Sequence[int],
    ) -> Dict[int, float]:
        """Analyze selected turns of one game and return Black's win probability.

        Args:
            game_record: Needs ``id_prefix``, ``initial_stones``, ``moves``,
                ``rules``, ``komi`` and ``board_size``; ``initial_player`` is
                forwarded when present and truthy.
            turn_numbers: Turn numbers sent as ``analyzeTurns``.

        Returns:
            Mapping from turn number to the win probability for Black.

        Raises:
            RuntimeError: If the process is not running, exits, reports an
                error for this query, or is silent for 60 seconds.
        """
        if self.proc is None or self.proc.stdin is None or self.proc.stdout is None:
            raise RuntimeError("KataGo process is not running")
        query_id = game_record["id_prefix"]
        query = {
            "id": query_id,
            "initialStones": game_record["initial_stones"],
            "moves": game_record["moves"],
            "rules": game_record["rules"],
            "komi": game_record["komi"],
            "boardXSize": game_record["board_size"],
            "boardYSize": game_record["board_size"],
            "analyzeTurns": list(turn_numbers),
            "maxVisits": self.visits,
        }
        if game_record.get("initial_player"):
            query["initialPlayer"] = game_record["initial_player"]
        self.proc.stdin.write(json.dumps(query) + "\n")
        self.proc.stdin.flush()

        expected = set(turn_numbers)
        outputs: Dict[int, float] = {}
        non_json_lines: List[str] = []
        while expected - outputs.keys():
            # Keep the stderr pipe drained so a chatty engine cannot fill it
            # and stall while we are waiting on stdout.
            self._poll_stderr()
            lines = self._read_ready_lines(timeout=60.0)
            if not lines:
                self._poll_stderr()
                stderr = "\n".join(self._stderr_tail)
                exit_code = self.proc.poll()
                if exit_code is not None:
                    raise RuntimeError(
                        f"KataGo exited with code {exit_code} before answering query {query_id}.\nSTDERR:\n{stderr}"
                    )
                raise RuntimeError(
                    f"Timed out waiting for KataGo analysis results for query {query_id}.\nSTDERR:\n{stderr}"
                )
            for line in lines:
                stripped = line.strip()
                if not stripped:
                    continue
                try:
                    payload = json.loads(stripped)
                except json.JSONDecodeError:
                    non_json_lines.append(stripped)
                    continue
                # Ignore responses that belong to a different query.
                if payload.get("id") != query_id:
                    continue
                if "error" in payload:
                    raise RuntimeError(f"KataGo analysis error for {query_id}: {payload['error']}")
                if "turnNumber" not in payload:
                    continue
                turn_number = int(payload.get("turnNumber"))
                root_info = payload.get("rootInfo") or {}
                # NOTE(review): KataGo appears to document currentPlayer under
                # rootInfo rather than at the top level; if so this lookup
                # always falls back to "B" and the winrate is used as-is.
                # Whether that is right depends on the engine's
                # reportAnalysisWinratesAs setting - confirm before changing.
                current_player = payload.get("currentPlayer", "B")
                winrate = clamp_probability(root_info.get("winrate", 0.5))
                black_win_prob = winrate if current_player == "B" else 1.0 - winrate
                outputs[turn_number] = black_win_prob
        if non_json_lines:
            sample = "\n".join(non_json_lines[:5])
            print(
                f"Ignored {len(non_json_lines)} non-JSON KataGo stdout lines for {query_id}:\n{sample}",
                file=sys.stderr,
            )
        return outputs


def katago_analysis_smoke_check(
    binary_path: str,
    model_path: str,
    config_path: str,
    visits: int = 8,
) -> Dict[int, float]:
    """Start KataGo, analyze a one-move 19x19 probe game and return the result.

    Intended as a quick end-to-end check that the binary, model and config
    work together.

    Args:
        binary_path: Path to the KataGo executable.
        model_path: Path to the network file.
        config_path: Path to the analysis config.
        visits: ``maxVisits`` for the probe query; small by default.

    Returns:
        Mapping from turn number (only turn 1 is requested) to Black's win
        probability, exactly as returned by
        ``KataGoAnalysisClient.query_game_win_probs``.
    """
    probe_game = {
        "id_prefix": "smoke-check",
        "initial_stones": [],
        "moves": [["B", "D4"]],
        "rules": DEFAULT_RULES,
        "komi": 6.5,
        "board_size": 19,
        "initial_player": "B",
    }
    with KataGoAnalysisClient(
        binary_path=binary_path,
        model_path=model_path,
        config_path=config_path,
        visits=visits,
    ) as analyzer:
        return analyzer.query_game_win_probs(probe_game, [1])


def parse_sgf_game_record(data: bytes, source_name: str) -> Dict[str, Any]:
    """Parse raw SGF bytes into the game-record dict used by preprocessing.

    Args:
        data: Contents of an SGF file.
        source_name: Path or name of the source; its stem becomes ``id_prefix``.

    Returns:
        Dict with ``id_prefix``, ``board_size``, ``rules``, ``komi``,
        ``initial_stones`` (``[colour, coord]`` pairs, black first),
        ``initial_player`` and ``moves`` (``[colour, coord]`` pairs).

    Raises:
        ImportError: If sgfmill is not installed.
    """
    if any(dependency is None for dependency in (sgf, sgf_moves, sgf_boards)):
        raise ImportError("sgfmill is required for SGF preprocessing. Install with `pip install sgfmill`.")

    game = sgf.Sgf_game.from_bytes(data)
    root = game.get_root()
    board_size = game.get_size()
    rules = sgf_property_or_default(root, "RU", DEFAULT_RULES) or DEFAULT_RULES
    komi = maybe_float(sgf_property_or_default(root, "KM", 6.5), 6.5)

    setup_board, plays = sgf_moves.get_setup_and_moves(game)
    if setup_board is None:
        setup_board = sgf_boards.Board(board_size)

    setup_stones = board_to_stones(setup_board, board_size)
    initial_stones_tuples: List[List[str]] = [["B", coord] for coord in setup_stones["black"]]
    initial_stones_tuples += [["W", coord] for coord in setup_stones["white"]]

    moves: List[List[str]] = [
        [color.upper(), sgf_move_to_gtp(move, board_size)] for color, move in plays
    ]

    # The first recorded move decides who starts; the root PL property is only
    # consulted for games without moves, and Black is the final fallback.
    root_player = str(sgf_property_or_default(root, "PL", "") or "").strip().upper()
    if moves:
        initial_player = moves[0][0]
    else:
        initial_player = root_player if root_player in {"B", "W"} else "B"

    record: Dict[str, Any] = {"id_prefix": Path(source_name).stem}
    record["board_size"] = board_size
    record["rules"] = rules
    record["komi"] = komi
    record["initial_stones"] = initial_stones_tuples
    record["initial_player"] = initial_player
    record["moves"] = moves
    return record


def snapshots_from_game_record(game_record: Dict[str, Any], turn_numbers: Sequence[int]) -> List[Dict[str, Any]]:
    """Replay a game and capture the board after each requested turn.

    The initial stones are played first, then every move in order (passes
    leave the board untouched). After a move whose 1-based turn number is in
    ``turn_numbers`` a snapshot dict is recorded, with ``to_move`` set to the
    opponent of the player who just moved.

    Raises:
        ValueError: If the board rejects a setup stone or a move.
    """
    board_size = game_record["board_size"]
    board = sgf_boards.Board(board_size)

    def locate(coord: str) -> Tuple[int, int]:
        """Translate a coordinate string into the board's ``(row, col)``."""
        file_idx = BOARD_COLUMNS.index(coord[0])
        return board_size - int(coord[1:]), file_idx

    for color, coord in game_record["initial_stones"]:
        if coord.lower() != "pass":
            row_idx, col_idx = locate(coord)
            try:
                board.play(row_idx, col_idx, color.lower())
            except ValueError as exc:
                raise ValueError(f"invalid initial setup move {color}[{coord}]") from exc

    wanted = set(turn_numbers)
    snapshots: List[Dict[str, Any]] = []
    turn_number = 0
    for color, coord in game_record["moves"]:
        turn_number += 1
        if coord.lower() != "pass":
            row_idx, col_idx = locate(coord)
            try:
                board.play(row_idx, col_idx, color.lower())
            except ValueError as exc:
                raise ValueError(f"invalid move at turn {turn_number}: {color}[{coord}]") from exc
        if turn_number not in wanted:
            continue
        next_player = "W" if color == "B" else "B"
        snapshots.append(
            {
                "id": f"{game_record['id_prefix']}_move{turn_number}",
                "turn_number": turn_number,
                "board_size": board_size,
                "to_move": next_player,
                "rules": game_record["rules"],
                "komi": game_record["komi"],
                "stones": board_to_stones(board, board_size),
            }
        )
    return snapshots


def preprocess_sgf_directory(cfg: PreprocessConfig) -> List[Dict[str, Any]]:
    """Turn a directory of SGF files into dataset rows and write them as JSONL.

    Every ``*.sgf`` file under ``cfg.sgf_dir`` (sorted, optionally capped by
    ``cfg.max_games``) is parsed, sampled at the turns chosen by
    ``sampled_turn_numbers`` and replayed into board snapshots. Win
    probabilities come from KataGo when binary, model and config are all
    configured; otherwise a noisy logistic heuristic over ``stone_heuristics``
    is used. Files that fail to parse or replay are skipped with a warning.

    Args:
        cfg: Preprocessing configuration.

    Returns:
        The rows that were written to ``cfg.output_path``.
    """
    # Function-scope import: contextlib is not among the module's imports.
    from contextlib import nullcontext

    set_seed(cfg.seed)
    sgf_paths = sorted(Path(cfg.sgf_dir).glob("**/*.sgf"))
    if cfg.max_games > 0:
        sgf_paths = sgf_paths[: cfg.max_games]

    if cfg.katago_binary and cfg.katago_model and cfg.katago_config:
        analyzer_cm: Any = KataGoAnalysisClient(
            binary_path=cfg.katago_binary,
            model_path=cfg.katago_model,
            config_path=cfg.katago_config,
            extra_args=cfg.katago_extra_args,
            visits=cfg.katago_visits,
        )
    else:
        # nullcontext() yields None, which selects the heuristic fallback.
        analyzer_cm = nullcontext()

    output_rows: List[Dict[str, Any]] = []
    rng = random.Random(cfg.seed)
    skipped_files = 0

    # The with-statement guarantees the KataGo process is shut down even when
    # a query or a file raises.
    with analyzer_cm as active_analyzer:
        for sgf_path in sgf_paths:
            try:
                game_record = parse_sgf_game_record(sgf_path.read_bytes(), str(sgf_path))
            except Exception as exc:
                # Deliberately broad: one malformed file must not abort the run.
                skipped_files += 1
                print(f"[WARN] Skipping {sgf_path}: {exc}")
                continue
            turn_numbers = sampled_turn_numbers(len(game_record["moves"]), cfg)
            if not turn_numbers:
                continue
            try:
                snapshots = snapshots_from_game_record(game_record, turn_numbers)
            except ValueError as exc:
                skipped_files += 1
                print(f"[WARN] Skipping {sgf_path}: {exc}")
                continue

            win_probs_by_turn: Dict[int, float] = {}
            if active_analyzer is not None:
                win_probs_by_turn = active_analyzer.query_game_win_probs(game_record, turn_numbers)

            for row in snapshots:
                if active_analyzer is not None:
                    win_prob = win_probs_by_turn[row["turn_number"]]
                else:
                    heur = stone_heuristics(row["stones"], row["board_size"])
                    raw = 1.3 * heur["stone_balance"] + 0.9 * heur["center_balance"] + rng.uniform(-0.1, 0.1)
                    win_prob = 1.0 / (1.0 + math.exp(-raw))
                row["position_tokens"] = build_position_tokens(
                    row["board_size"],
                    row["to_move"],
                    row["komi"],
                    row["rules"],
                    row["stones"],
                    max_position_tokens=cfg.max_position_tokens,
                )
                row["win_prob"] = clamp_probability(win_prob)
                row["win_prob_bin"] = probability_to_bin(win_prob, cfg.num_bins)
                row["rationale_text"] = templated_rationale(
                    win_prob=row["win_prob"],
                    to_move=row["to_move"],
                    stones=row["stones"],
                    board_size=row["board_size"],
                )
                # turn_number was only needed to look up the KataGo result.
                row.pop("turn_number", None)
                output_rows.append(row)

    write_jsonl(cfg.output_path, output_rows)
    if skipped_files:
        print(f"[INFO] Skipped {skipped_files} SGF files due to parse or setup issues.")
    return output_rows


def parse_args(argv: Sequence[str] | None = None) -> argparse.Namespace:
    """Parse the command line for the ``train`` and ``preprocess-sgf`` commands.

    Args:
        argv: Argument list to parse. ``None`` (the default) makes argparse
            read ``sys.argv[1:]``, so existing callers are unaffected; passing
            a list allows programmatic use.

    Returns:
        The parsed namespace. ``command`` is ``None`` when no sub-command was
        given.
    """
    parser = argparse.ArgumentParser(description="KataGo win-probability claim-consistency experiment")
    subparsers = parser.add_subparsers(dest="command")

    train_parser = subparsers.add_parser("train", help="Run the win-probability experiment")
    train_parser.add_argument("--train-path", type=str, default="")
    train_parser.add_argument("--eval-path", type=str, default="")
    train_parser.add_argument("--max-seq-len", type=int, default=256)
    train_parser.add_argument("--max-position-tokens", type=int, default=196)
    train_parser.add_argument("--batch-size", type=int, default=32)
    train_parser.add_argument("--epochs", type=int, default=10)
    train_parser.add_argument("--lr", type=float, default=3e-4)
    train_parser.add_argument("--weight-decay", type=float, default=0.01)
    train_parser.add_argument("--d-model", type=int, default=256)
    train_parser.add_argument("--n-layers", type=int, default=4)
    train_parser.add_argument("--n-heads", type=int, default=8)
    train_parser.add_argument("--d-ff", type=int, default=1024)
    train_parser.add_argument("--dropout", type=float, default=0.1)
    train_parser.add_argument("--consistency-weight", type=float, default=0.5)
    train_parser.add_argument("--counterfactual-samples", type=int, default=256)
    train_parser.add_argument("--output-csv", type=str, default="katago_winprob_results.csv")
    train_parser.add_argument("--output-markdown", type=str, default="")
    train_parser.add_argument("--seed", type=int, default=42)
    train_parser.add_argument("--variants", nargs="+", default=list(VARIANT_NAMES))
    train_parser.add_argument("--smoke-test", action="store_true", default=False)

    preprocess_parser = subparsers.add_parser("preprocess-sgf", help="Build a JSONL dataset from SGF files")
    preprocess_parser.add_argument("--sgf-dir", required=True, type=str)
    preprocess_parser.add_argument("--output-path", required=True, type=str)
    preprocess_parser.add_argument("--sample-every-n-moves", type=int, default=20)
    preprocess_parser.add_argument("--max-positions-per-game", type=int, default=8)
    preprocess_parser.add_argument("--max-games", type=int, default=0)
    preprocess_parser.add_argument("--max-position-tokens", type=int, default=196)
    preprocess_parser.add_argument("--katago-binary", type=str, default="")
    preprocess_parser.add_argument("--katago-model", type=str, default="")
    preprocess_parser.add_argument("--katago-config", type=str, default="")
    preprocess_parser.add_argument("--katago-visits", type=int, default=32)
    preprocess_parser.add_argument("--seed", type=int, default=42)
    return parser.parse_args(argv)


def config_from_args(args: argparse.Namespace) -> Config:
    """Build the experiment ``Config`` from parsed ``train`` arguments.

    Every field is copied from the namespace attribute of the same name;
    ``variants`` is converted to a tuple. Fields without a command-line flag
    keep their ``Config`` defaults.
    """
    copied_fields = (
        "train_path",
        "eval_path",
        "max_seq_len",
        "max_position_tokens",
        "batch_size",
        "epochs",
        "lr",
        "weight_decay",
        "d_model",
        "n_layers",
        "n_heads",
        "d_ff",
        "dropout",
        "consistency_weight",
        "counterfactual_samples",
        "seed",
        "output_csv",
        "output_markdown",
    )
    settings: Dict[str, Any] = {name: getattr(args, name) for name in copied_fields}
    settings["variants"] = tuple(args.variants)
    settings["smoke_test"] = args.smoke_test
    return Config(**settings)


def preprocess_config_from_args(args: argparse.Namespace) -> PreprocessConfig:
    """Build the ``PreprocessConfig`` from parsed ``preprocess-sgf`` arguments.

    Each listed field is copied from the namespace attribute of the same name.
    """
    copied_fields = (
        "sgf_dir",
        "output_path",
        "sample_every_n_moves",
        "max_positions_per_game",
        "max_games",
        "max_position_tokens",
        "katago_binary",
        "katago_model",
        "katago_config",
        "katago_visits",
        "seed",
    )
    settings: Dict[str, Any] = {name: getattr(args, name) for name in copied_fields}
    return PreprocessConfig(**settings)


def main() -> None:
    """Command-line entry point: dispatch to preprocessing or training.

    Exits with status 2 when no sub-command was given.
    """
    args = parse_args()
    command = args.command
    if command is None:
        print("Specify either `train` or `preprocess-sgf`.\n", file=sys.stderr)
        raise SystemExit(2)

    if command != "preprocess-sgf":
        # Any other sub-command is treated as the training run.
        cfg = config_from_args(args)
        df = run_experiment(cfg)
        print(df.to_string(index=False))
        if cfg.output_markdown:
            md_path = cfg.output_markdown
        else:
            md_path = str(Path(cfg.output_csv).with_suffix(".md"))
        print(f"Saved results to {cfg.output_csv} and {md_path}")
        return

    cfg = preprocess_config_from_args(args)
    rows = preprocess_sgf_directory(cfg)
    print(f"Wrote {len(rows)} dataset rows to {cfg.output_path}")


# Run the command-line interface only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
