"""
AlphaZero-style training skeleton.
"""

import argparse
import threading
import time
import math
import random
from queue import Queue, Empty
from collections import deque, defaultdict, namedtuple

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

import logging
from tqdm.auto import tqdm
import copy  
import os
import json
import sys, shlex
from src.utils import TensorBoardHandler, BatchingTensorBoardHandler


from torch.utils.tensorboard import SummaryWriter

from src.alphazero.models import AlphaNet, TransformerAlphaNet
from src.games.battlefield_duel import BattlefieldDuel, BattlefieldDuelSquad2 

from src.alphazero.trainer import ReplayBuffer, Trainer
from src.alphazero.mcts import InferenceServer, SelfPlayWorker

from src.alphazero import shaping as _shaping
from src.alphazero.shaping import ShapingConfig as _ShapingConfig


# MP backend (optional)
try:
    import multiprocessing as mp
    from src.alphazero.mp_infer import MPInferenceBroker, _MPConfig
    from src.alphazero.mp_workers import run_selfplay_proc
    _MP_AVAILABLE = True
except Exception:
    _MP_AVAILABLE = False
from src.alphazero.arena import Arena, MCTSPlayer  # NEW
from src.alphazero.featurizer import TransformerFeaturizer  # NEW

# Shared CLI utilities (train/common_cli.py)
from train.common_cli import setup_logger_and_tb, resolve_device, log_hparams, select_game, get_git_info

# Checkpoint + analysis helpers
from src.alphazero.utils import (
    resolve_remote_path,
    extract_state_dict,
    infer_arch_from_state_dict,
    infer_in_channels,
    infer_transformer_hparams,
    infer_alphazero_from_state_dict,
)

# ----------------------------
# === Orchestration / Main ===
# ----------------------------

# --- Acceptance helpers (SPRT + Wilson) ---
class _SPRTGate:
    """Sequential probability ratio test (SPRT) gate for candidate-vs-baseline games.

    Each result is folded into a running log-likelihood ratio ``llr`` of
    H1 (win rate ``p1``) against H0 (win rate ``p0``). ``A`` and ``B`` are the
    upper/lower decision thresholds derived from the type I (``alpha``) and
    type II (``beta``) error rates.
    """

    def __init__(self, p0=0.50, p1=0.55, alpha=0.05, beta=0.05):
        eps = 1e-9  # floor that keeps log()/division finite for degenerate inputs
        self.p0, self.p1 = float(p0), float(p1)
        self.alpha, self.beta = float(alpha), float(beta)
        # Decision thresholds: accept once llr >= A, reject once llr <= B.
        self.A = math.log((1.0 - self.beta) / max(self.alpha, eps))
        self.B = math.log(max(self.beta, eps) / (1.0 - self.alpha))
        # Running statistics over all results seen so far.
        self.llr = 0.0
        self.n = 0
        self.total_score = 0.0

    def update(self, x: float):
        """Record one game result and report the current decision.

        Args:
            x: Score of the candidate, expected in {1.0, 0.5, 0.0}.

        Returns:
            True (accept), False (reject) or None (undecided; keep playing).
        """
        eps = 1e-9
        score = float(x)
        self.n += 1
        self.total_score += score
        miss = 1.0 - score
        # The score is used as the mean of a Bernoulli trial, so 0.5 weighs both outcomes equally.
        log_p1 = math.log(max(self.p1, eps))
        log_q1 = math.log(max(1.0 - self.p1, eps))
        log_p0 = math.log(max(self.p0, eps))
        log_q0 = math.log(max(1.0 - self.p0, eps))
        self.llr += log_p1 * score + log_q1 * miss - log_p0 * score - log_q0 * miss
        if self.llr >= self.A:
            return True
        return False if self.llr <= self.B else None


def _wilson_lcb(phat: float, n: int, z: float = 1.96) -> float:
    """Return the lower bound of the Wilson score interval for a proportion.

    Args:
        phat: Observed success rate.
        n: Number of trials; a non-positive count yields 0.0.
        z: Standard-normal quantile for the desired confidence level.

    Returns:
        The lower confidence bound of the interval.
    """
    if n <= 0:
        return 0.0
    rate = float(phat)
    z_sq = z * z
    shrink = 1.0 + z_sq / n
    midpoint = rate + z_sq / (2.0 * n)
    half_width = z * math.sqrt((rate * (1.0 - rate) / n) + z_sq / (4.0 * n * n))
    return (midpoint - half_width) / shrink

def _propagate_overrides_to_bases(game_cls, overrides: dict):
    """Mirror selected class-attribute overrides onto every base class of ``game_cls``.

    Some game subclasses read parent-class constants (e.g. ROWS/COLS) directly, so
    overriding geometry only on the subclass can leave parents inconsistent. This
    matters for Arena/threads paths where no worker-side propagation happens.

    For each base in the MRO (excluding ``game_cls`` itself and ``object``) that
    already exposes an attribute named in ``overrides``, the value is coerced to the
    existing attribute's type when possible and assigned raw otherwise. ``None``
    values are ignored. All errors are swallowed: this is strictly best-effort.
    """
    try:
        if not overrides:
            return
        parents = [klass for klass in getattr(game_cls, '__mro__', ())[1:] if klass is not object]
        for klass in parents:
            for name, value in overrides.items():
                if value is None or not hasattr(klass, name):
                    continue
                try:
                    # Keep the attribute's original type (e.g. int geometry constants).
                    setattr(klass, name, type(getattr(klass, name))(value))
                except Exception:
                    try:
                        setattr(klass, name, value)
                    except Exception:
                        pass
    except Exception:
        pass


def _apply_game_overrides(game, args):
    """Best-effort: apply non-None board/tunable overrides from ``args`` to ``game``.

    Uses ``game.configure(...)`` when the class provides a callable ``configure``;
    otherwise matching class attributes are set directly, coerced to the existing
    attribute's type when possible. The values are then mirrored onto base classes
    via ``_propagate_overrides_to_bases``. Failures are logged and otherwise ignored.
    """
    # (game class attribute, argparse destination) pairs.
    override_map = [
        ('ROWS', 'board_rows'), ('COLS', 'board_cols'),
        ('SHOOT_RANGE', 'shoot_range'), ('SHRINK_INTERVAL', 'shrink_interval'),
        ('NUM_OBSTACLES', 'num_obstacles'), ('MAX_STEPS', 'max_steps'),
        ('CAPTURE_STEPS', 'capture_steps'), ('MAX_HEALTH', 'max_health'),
    ]
    overrides = {}
    for attr, flag in override_map:
        val = getattr(args, flag, None)
        if val is not None:
            overrides[attr] = val
    if not overrides:
        return
    requested = dict(overrides)  # snapshot for the warning below (overrides is mutated)
    try:
        board_rows = overrides.pop('ROWS', None)
        board_cols = overrides.pop('COLS', None)
        configure = getattr(game, 'configure', None)
        if callable(configure):
            configure(board_rows=board_rows, board_cols=board_cols, **overrides)
        else:
            if board_rows is not None and hasattr(game, 'ROWS'):
                setattr(game, 'ROWS', int(board_rows))
            if board_cols is not None and hasattr(game, 'COLS'):
                setattr(game, 'COLS', int(board_cols))
            for k, v in overrides.items():
                if hasattr(game, k):
                    try:
                        setattr(game, k, type(getattr(game, k))(v))
                    except Exception:
                        setattr(game, k, v)
        # Also propagate to base classes that expose the same attributes.
        # After popping ROWS/COLS, `overrides` holds only the remaining tunables.
        base_overrides = {'ROWS': board_rows, 'COLS': board_cols}
        base_overrides.update(overrides)
        _propagate_overrides_to_bases(game, base_overrides)
    except Exception:
        # Stay best-effort, but make the failure visible: a silently ignored override
        # leaves the game (and the network built from it) at the wrong geometry.
        logging.getLogger(__name__).warning(
            "Failed to apply game overrides %s to %s; game may be only partially configured.",
            requested, getattr(game, '__name__', game), exc_info=True,
        )


def build_default(game='battlefield_duel_squad', network="cnn", model_device=None, args=None):
    """Resolve a game class, apply optional CLI overrides, and build its network.

    Args:
        game: Game key; lower-cased and resolved through ``select_game``.
        network: ``"cnn"`` (AlphaNet) or ``"transformer"`` (TransformerAlphaNet).
        model_device: Device the network is moved to. Defaults to ``'cuda'`` when
            available, else ``'cpu'``.
        args: Optional argparse-style namespace. It supplies board/tunable overrides,
            ``history_steps``, the feature-plane flags and the transformer
            hyper-parameters. Missing attributes (or ``args=None``) fall back to
            defaults; transformer sizes default to the CLI defaults
            (embed_dim=128, num_layers=4, num_heads=4).

    Returns:
        ``(game_cls, net)``: the (possibly reconfigured) game class and the network.

    Raises:
        ValueError: if ``network`` is not ``"cnn"`` or ``"transformer"``. This is
            checked before any global game-class state is modified.
    """
    if network not in ("cnn", "transformer"):
        raise ValueError(f"Unknown network: {network}")
    if model_device is None:
        model_device = 'cuda' if torch.cuda.is_available() else 'cpu'
    # Use shared selector for consistent mapping across scripts.
    game = select_game(str(game).lower())
    if args is not None:
        _apply_game_overrides(game, args)
    # Effective input channels: history stacking (+ optional extra planes for transformer).
    hist = int(getattr(args, 'history_steps', 0) or 0)
    base_C = int(getattr(game, 'CHANNELS', 1))
    if network == "cnn":
        eff_C = base_C * (1 + hist)
        net = AlphaNet(board_shape=(eff_C, game.ROWS, game.COLS), action_size=game.action_size())
    else:
        # Each enabled flag adds one broadcast plane to the input.
        extras = sum(
            int(bool(getattr(args, flag, False)))
            for flag in ('feat_steps_left', 'feat_repetition', 'feat_since_damage')
        )
        eff_C = base_C * (1 + hist) + extras
        net = TransformerAlphaNet(
            board_shape=(eff_C, game.ROWS, game.COLS), action_size=game.action_size(),
            embed_dim=int(getattr(args, 'embed_dim', 128)),
            depth=int(getattr(args, 'num_layers', 4)),
            num_heads=int(getattr(args, 'num_heads', 4)),
            use_sinusoidal_2d_pe=bool(getattr(args, 'tr_use_sincos_pe', False)),
            use_relative_bias=bool(getattr(args, 'tr_use_relative_bias', False)),
            enable_global_context=bool(getattr(args, 'tr_enable_global_ctx', False)),
            include_action_tokens=bool(getattr(args, 'tr_include_action_tokens', False)),
            action_token_in_dim=getattr(args, 'tr_action_token_dim', None),
            cross_attn_layers=int(getattr(args, 'tr_cross_attn_layers', 1)),
        )
    net.to(model_device)
    return game, net


def _infer_hist_and_extras(game_cls, in_channels: int, is_transformer: bool):
    """Guess (history_steps, extra-plane flags) that reproduce ``in_channels``.

    The per-frame channel count comes from ``game_cls.CHANNELS`` (default 1).
    CNNs carry no extra planes, so only the history depth is inferred. For
    transformers, 0..3 extra planes are tried in increasing order; the first count
    that leaves a positive multiple of the per-frame channels wins, and that many
    flags are enabled in the fixed order steps_left, repetition, since_damage.
    Without an exact match, the floor-division history estimate is returned with
    no extras.

    Returns:
        ``(history_steps, extras_dict)``, where ``extras_dict`` maps the three
        ``feat_*`` flag names to booleans.
    """
    flag_order = ("feat_steps_left", "feat_repetition", "feat_since_damage")
    flags = dict.fromkeys(flag_order, False)
    per_frame = int(getattr(game_cls, 'CHANNELS', 1))
    if per_frame <= 0:
        return 0, flags
    if not is_transformer:
        return max(0, (in_channels // per_frame) - 1), flags
    for n_extra in range(len(flag_order) + 1):
        stacked = in_channels - n_extra
        if stacked < per_frame or stacked % per_frame != 0:
            continue
        # stacked >= per_frame guarantees a non-negative history depth.
        flags.update(dict.fromkeys(flag_order[:n_extra], True))
        return (stacked // per_frame) - 1, flags
    # No exact decomposition: approximate via floor division.
    return max(0, (in_channels // per_frame) - 1), flags


def main():
    parser = argparse.ArgumentParser()
    # removed --mode; always train
    parser.add_argument('--game', default='battlefield_duel_squad',
                    help='Game options: battlefield_duel|battlefield_duel_squad)')
    parser.add_argument('--network', choices=['cnn', 'transformer'], default='cnn', help='Neural network architecture')
    # Device selection
    parser.add_argument('--device', type=str, default=None, help='Compute device: cpu, cuda, or cuda:N (e.g., cuda:0, cuda:1)')
    parser.add_argument('--num-workers', type=int, default=4)
    parser.add_argument('--sims', type=int, default=50, help='MCTS simulations per move in self-play')
    # Board size (optional, supported by some games e.g., BattlefieldDuel, BattlefieldDuelSquad2)
    parser.add_argument('--board-rows', type=int, default=10, help='Board rows for supported games')
    parser.add_argument('--board-cols', type=int, default=10, help='Board cols for supported games')
    # Additional game parameter overrides (supported by some games)
    parser.add_argument('--shoot-range', type=int, default=None, help='Override SHOOT_RANGE (if supported)')
    parser.add_argument('--shrink-interval', type=int, default=None, help='Override SHRINK_INTERVAL (if supported)')
    parser.add_argument('--num-obstacles', type=int, default=None, help='Override NUM_OBSTACLES (if supported)')
    parser.add_argument('--max-steps', type=int, default=None, help='Override MAX_STEPS (if supported)')
    parser.add_argument('--capture-steps', type=int, default=None, help='Override CAPTURE_STEPS (if supported)')
    parser.add_argument('--max-health', type=int, default=None, help='Override MAX_HEALTH (if supported)')
    
    # Root exploration noise
    parser.add_argument('--root-noise-alpha', type=float, default=0.3, help='Dirichlet alpha for root exploration noise')
    parser.add_argument('--root-noise-frac', type=float, default=0.25, help='Mix fraction epsilon for root noise (epsilon in (1-eps)*p + eps*Dir(alpha))')
    parser.add_argument('--no-root-noise', action='store_true', help='Disable adding Dirichlet noise at the root during self-play')

    # Training parameters
    parser.add_argument('--batch-size', type=int, default=64, help='Batch size for inference server')
    parser.add_argument('--train-batch', type=int, default=32, help='Batch size for training')
    parser.add_argument('--replay-size', type=int, default=20000, help='Replay buffer size')
    parser.add_argument('--steps', type=int, default=200)  # kept for backward compat (unused in new loop)
    parser.add_argument('--batch-wait', type=float, default=10, help='Seconds to wait Inference Server for batch filling')
    parser.add_argument('--min-replay', type=int, default=256, help='Warmup replay size before training')  # unused now
    parser.add_argument('--iterations', type=int, default=10000, help='Outer iterations (generate+train cycles)')
    parser.add_argument('--examples-per-iter', type=int, default=1024, help='Self-play examples to collect each iteration')
    parser.add_argument('--train-steps-per-iter', type=int, default=500, help='Optimizer steps per iteration')
    parser.add_argument('--eval-games', type=int, default=20, help='Number of arena games after each iteration')
    parser.add_argument('--eval-sims', type=int, default=50, help='MCTS simulations per move in arena')
    parser.add_argument('--accept-threshold', type=float, default=0.6, help='(prev-threshold mode) Accept new model if winrate >= threshold vs previous')  # NEW
    
    # Acceptance policy (advanced)
    parser.add_argument('--accept-mode', type=str, default='prev-threshold', choices=['prev-threshold', 'pool-sprt'], help='Model acceptance policy')
    parser.add_argument('--accept-pool-size', type=int, default=8, help='Hall-of-fame size for pool-sprt')
    parser.add_argument('--accept-p0', type=float, default=0.50, help='SPRT null winrate (H0)')
    parser.add_argument('--accept-p1', type=float, default=0.55, help='SPRT alternative winrate (H1)')
    parser.add_argument('--accept-alpha', type=float, default=0.05, help='SPRT type I error')
    parser.add_argument('--accept-beta', type=float, default=0.05, help='SPRT type II error')
    parser.add_argument('--accept-max-games', type=int, default=400, help='Maximum evaluation games for pool-sprt before fallback decision')
    parser.add_argument('--accept-wilson-thresh', type=float, default=0.55, help='Wilson LCB threshold if SPRT inconclusive at cap')
    
    # Backend options
    parser.add_argument('--backend', type=str, default='processes', choices=['threads', 'processes'], help='Self-play backend')
    parser.add_argument('--auto-batch-wait', action='store_true', default=False, help='Adaptively tune broker max_batch_wait based on arrival rate (processes backend)')
    
    # Reward shaping (optional)
    parser.add_argument('--shaping-enable', action='store_true', help='Enable potential-based reward shaping (also requires non-zero shaping-scale and a valid phi function)')
    parser.add_argument('--shaping-phi', type=str, default='', help='Name of phi() in src.alphazero.shaping (e.g., chess_material_phi) - default: Phi(s) = 0 (i.e. no shaping)')
    parser.add_argument('--shaping-gamma', type=float, default=1.0, help='Discount factor for shaping dynamics')
    parser.add_argument('--shaping-scale', type=float, default=0.0, help='Scale factor for shaping contribution (annealed if anneal>0)')
    parser.add_argument('--shaping-use-in-mcts', action='store_true', help='Apply shaping in MCTS backups')
    parser.add_argument('--shaping-use-in-targets', action='store_true', help='Apply shaping to training targets')
    parser.add_argument('--shaping-anneal-steps', type=int, default=0, help='Linear anneal shaping scale to 0 over this many steps (0=off)')
    
    # Transformer parameters
    parser.add_argument('--embed-dim', type=float, default=128, help='Transformer: embedding dimension')
    parser.add_argument('--num-heads', type=float, default=4, help='Transformer: number of heads per layer')
    parser.add_argument('--num-layers', type=float, default=4, help='Transformer: number of layers')
    
    # Extended Transformer options
    parser.add_argument('--tr-use-sincos-pe', action='store_true', default=False, help='Use 2D sinusoidal positional encoding (Transformer only)')
    parser.add_argument('--tr-enable-global-ctx', action='store_true', default=False, help='Enable side/turn global context (Transformer only)')
    parser.add_argument('--tr-include-action-tokens', action='store_true', default=False, help='Enable action-token cross-attention head (Transformer only)')
    parser.add_argument('--tr-action-token-dim', type=int, default=None, help='Action token feature dim passed to the model (Transformer only)')
    parser.add_argument('--tr-cross-attn-layers', type=int, default=1, help='Cross-attention layers for action tokens (Transformer only)')
    parser.add_argument('--tr-use-relative-bias', action='store_true', default=False, help='Use 2D relative positional bias in attention (Transformer only)')
    
    # Featurizer options
    parser.add_argument('--history-steps', type=int, default=0, help='Number of past canonical boards to stack as history planes')
    parser.add_argument('--feat-steps-left', action='store_true', default=False, help='Append a broadcast plane with normalized steps_left (if available)')
    parser.add_argument('--feat-repetition', action='store_true', default=False, help='Append a broadcast plane indicating repetition of current root state in episode')
    parser.add_argument('--feat-since-damage', action='store_true', default=False, help='Append a broadcast plane with normalized steps since last damage (best-effort)')
    
    # Logging / run management
    parser.add_argument('--log-level', type=str, default='INFO', choices=['DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL'], help='Logging level')
    parser.add_argument('--run-id', type=str, default=None, help='Run ID; may be any string; defaults to SLURM_JOB_ID or process ID if not set')
    parser.add_argument('--init-checkpoint', type=str, default=None, help='Initialize model weights from this checkpoint (local path, file:// URL, http(s):// URL, or scp user@host:/path). Network/game params are inferred automatically.')
    args = parser.parse_args()

    # Assign/run directory under logs/<ID>
    if args.run_id is None or str(args.run_id).strip() == "":
        # Prefer SLURM_JOB_ID if present; else fall back to current process ID
        slurm_id = os.environ.get('SLURM_JOB_ID')
        args.run_id = str(slurm_id) if slurm_id is not None else str(os.getpid())  # NEW
    else:
        args.run_id = str(args.run_id)
    base_dir = f"logs/alphazero/{args.game}/{args.network}/{args.board_rows}x{args.board_cols}/{args.run_id}"  # NEW
    os.makedirs(base_dir, exist_ok=True)  # NEW
    checkpoint_dir = os.path.join(base_dir, "checkpoints")  # NEW
    os.makedirs(checkpoint_dir, exist_ok=True)  # NEW
    
    # configure logger (including TensorBoard handler)
    logger, writer = setup_logger_and_tb("alphazero.main", base_dir, args.log_level, flush_interval_s=60)
    # Log the full command line
    try:
        cmdline = " ".join(shlex.quote(a) for a in sys.argv)
        logger.info("Command line: %s", cmdline)
        try:
            # Also record explicitly in TensorBoard for convenience
            if writer is not None:
                writer.add_text("run/cmdline", "```bash\n" + cmdline + "\n```", 0)
        except Exception:
            pass
    except Exception:
        logger.debug("Failed to capture command line for logging.")

    git_info = get_git_info()
    if git_info:
        logger.info("Git commit: %s | branch: %s | dirty: %s | describe: %s",
                    git_info.get('commit_short', git_info.get('commit', '?')),
                    git_info.get('branch', '?'),
                    git_info.get('dirty', '?'),
                    git_info.get('describe', '?'))
        try:
            writer.add_text('run/git/commit', str(git_info.get('commit', 'unknown')), 0)
            writer.add_text('run/git/commit_short', str(git_info.get('commit_short', 'unknown')), 0)
            writer.add_text('run/git/branch', str(git_info.get('branch', 'unknown')), 0)
            writer.add_text('run/git/describe', str(git_info.get('describe', 'unknown')), 0)
            writer.add_text('run/git/dirty', str(git_info.get('dirty', 'unknown')), 0)
        except Exception:
            pass

    
    logger.info("Starting with args: %s", vars(args))
    logger.info("Run ID: %s | Output directory: %s", args.run_id, base_dir)  # NEW

    # Choose device: default CPU; resolve & log
    device = resolve_device(args.device, logger)
    # Log final device choice
    if device.startswith('cuda'):
        try:
            idx = int(device.split(':', 1)[1])
            name = torch.cuda.get_device_name(idx)
            logger.info("Using CUDA device: cuda:%d (%s)", idx, name)
        except Exception:
            logger.info("Using device: %s", device)
    else:
        logger.info("Using device: CPU")

    # TensorBoard writer already provided by setup_logger_and_tb
    # Log full args/hparams in a reusable way
    # Avoid creating an extra timestamped TB run dir from add_hparams; we keep text/scalars only.
    log_hparams(writer, vars(args), add_hparams_plugin=False)

    # Record the chosen device in TensorBoard
    writer.add_text("env/device", device)

    # Local counters with closure updaters for clarity
    _collect_step = 0
    def _collect_step_inc(n):
        nonlocal _collect_step
        _collect_step += int(n)

    # Helper: build featurizer config
    def _build_feat_cfg():
        return dict(
            history_steps=int(getattr(args, 'history_steps', 0) or 0),
            include_steps_left_plane=bool(getattr(args, 'feat_steps_left', False)),
            include_repetition_plane=bool(getattr(args, 'feat_repetition', False)),
            include_since_damage_plane=bool(getattr(args, 'feat_since_damage', False)),
        )

    # Helper: setup backend (threads or processes) and start workers
    def _setup_backend(game, net, device, model_lock, feat_cfg):
        backend = args.backend
        # Build shaping config and phi
        shaping_cfg = None
        phi_fn = None
        if args.shaping_enable and _ShapingConfig is not None:
            try:
                shaping_cfg = _ShapingConfig(
                    gamma=float(args.shaping_gamma),
                    scale=float(args.shaping_scale),
                    use_in_mcts=bool(args.shaping_use_in_mcts),
                    use_in_targets=bool(args.shaping_use_in_targets),
                    anneal_steps=int(args.shaping_anneal_steps),
                )
            except Exception:
                shaping_cfg = None
        if args.shaping_enable and _shaping is not None and args.shaping_phi:
            phi_fn = getattr(_shaping, args.shaping_phi, None)
        if backend == 'threads':
            infer = InferenceServer(net, device=device, max_batch_size=args.batch_size, max_batch_wait=args.batch_wait, model_lock=model_lock)
            examples_queue = Queue()
            workers = []
            logger.info("Spawning %d self-play workers (sims=%d) [threads]...", args.num_workers, args.sims)
            for i in range(1, args.num_workers + 1):
                w = SelfPlayWorker(
                    game, infer, examples_queue,
                    num_sims=args.sims, temperature=1.0,
                    history_steps=int(getattr(args, 'history_steps', 0) or 0),
                    featurizer=None, featurizer_config=feat_cfg,
                    shaping_config=shaping_cfg, phi_fn=phi_fn,
                    root_dirichlet_alpha=float(args.root_noise_alpha),
                    root_exploration_frac=float(args.root_noise_frac),
                    add_root_noise=(not args.no_root_noise),
                )
                w.start()
                workers.append(w)
                logger.debug("Worker %d started: %s", i, getattr(w, "name", "self-play"))
            logger.info("All workers started.")
            writer.add_scalar("system/workers", len(workers), 0)
            return backend, infer, workers, None, [], examples_queue
        else:
            if not _MP_AVAILABLE:
                raise RuntimeError("Multiprocessing backend requested but mp modules not available")
            try:
                mp.set_start_method('spawn', force=True)
            except RuntimeError:
                pass
            # Broker config: match local model channels (history + extras if transformer)
            hist = int(getattr(args, 'history_steps', 0) or 0)
            base_C = int(getattr(game, 'CHANNELS', 1))
            extras = int(bool(getattr(args, 'feat_steps_left', False))) \
                     + int(bool(getattr(args, 'feat_repetition', False))) \
                     + int(bool(getattr(args, 'feat_since_damage', False)))
            eff_C = base_C * (1 + hist) + (extras if args.network == 'transformer' else 0)
            board_shape = (eff_C, game.ROWS, game.COLS)
            cfg = _MPConfig(
                model_type=args.network,
                board_shape=board_shape,
                action_size=game.action_size(),
                embed_dim=int(args.embed_dim),
                num_heads=int(args.num_heads),
                num_layers=int(args.num_layers),
                use_sinusoidal_2d_pe=bool(getattr(args, 'tr_use_sincos_pe', False)),
                use_relative_bias=bool(getattr(args, 'tr_use_relative_bias', False)),
                enable_global_context=bool(getattr(args, 'tr_enable_global_ctx', False)),
                include_action_tokens=bool(getattr(args, 'tr_include_action_tokens', False)),
                action_token_in_dim=getattr(args, 'tr_action_token_dim', None),
                cross_attn_layers=int(getattr(args, 'tr_cross_attn_layers', 1)),
                device=device,
                max_batch_size=int(args.batch_size),
                max_batch_wait=float(args.batch_wait),
                auto_batch_wait=bool(getattr(args, 'auto_batch_wait', False)),
            )
            broker = MPInferenceBroker(state_dict=net.state_dict(), config=cfg)
            broker.start()
            examples_queue = mp.Queue(maxsize=0)
            proc_workers = []
            logger.info("Spawning %d self-play workers (sims=%d) [processes]...", args.num_workers, args.sims)
            for i in range(args.num_workers):
                p = mp.Process(
                    target=run_selfplay_proc,
                    args=(broker.request_queue, examples_queue, game, args.sims, 1.0, hist),
                    kwargs={
                        "featurizer_config": feat_cfg,
                        "shaping_config": (vars(shaping_cfg) if shaping_cfg is not None else None),
                        "phi_name": (args.shaping_phi if args.shaping_enable else None),
                        "root_noise": {
                            "alpha": float(args.root_noise_alpha),
                            "frac": float(args.root_noise_frac),
                            "enable": (not args.no_root_noise),
                        },
                        "board_rows": getattr(args, 'board_rows', None),
                        "board_cols": getattr(args, 'board_cols', None),
                        # Additional override params (worker-side configure)
                        "shoot_range": getattr(args, 'shoot_range', None),
                        "shrink_interval": getattr(args, 'shrink_interval', None),
                        "num_obstacles": getattr(args, 'num_obstacles', None),
                        "max_steps": getattr(args, 'max_steps', None),
                        "capture_steps": getattr(args, 'capture_steps', None),
                        "max_health": getattr(args, 'max_health', None),
                    },
                    daemon=True,
                )
                p.start()
                proc_workers.append(p)
            writer.add_scalar("system/workers", len(proc_workers), 0)
            return backend, None, [], broker, proc_workers, examples_queue

    # Helper: wait for the first example to verify the pipeline
    def _wait_for_first_example(examples_queue, replay):
        try:
            logger.info(f"Waiting for first self-play example (up to {args.batch_wait}s)...")
            ex = examples_queue.get(timeout=args.batch_wait)
            replay.push([ex])
            _collect_step_inc(1)
            writer.add_scalar("collect/replay_size", len(replay), _collect_step)  # NEW
            writer.add_scalar("queue/size", getattr(examples_queue, "qsize", lambda: 0)(), _collect_step)  # NEW
            logger.info("Received first example. Replay size now: %d", len(replay))
        except Empty:
            logger.warning("No examples received within %ss. Self-play may be slow or stalled.", args.batch_wait)
            writer.add_scalar("collect/timeout_first_example", 1, _collect_step)  # NEW

    # Helper: collection phase
    def _collect_phase(iter_idx, examples_queue, replay):
        """Collect exactly args.examples_per_iter NEW examples, keeping replay persistent across iterations."""
        target_new = int(args.examples_per_iter)
        new_count = 0
        pcol = tqdm(total=target_new, desc=f"Iter {iter_idx} collect", leave=True) if tqdm else None
        t0 = time.time()
        while new_count < target_new:
            drained = 0
            while not examples_queue.empty() and new_count < target_new:
                try:
                    ex = examples_queue.get_nowait()
                except Empty:
                    break
                replay.push([ex])
                drained += 1
            if drained:
                new_count += drained
                _collect_step_inc(drained)
                writer.add_scalar("collect/replay_size", len(replay), _collect_step)
                writer.add_scalar("queue/size", getattr(examples_queue, "qsize", lambda: 0)(), _collect_step)
                if pcol:
                    pcol.update(min(drained, target_new - (pcol.n or 0)))
            else:
                try:
                    ex = examples_queue.get(timeout=0.5)
                    replay.push([ex])
                    new_count += 1
                    _collect_step_inc(1)
                    writer.add_scalar("collect/replay_size", len(replay), _collect_step)
                    writer.add_scalar("queue/size", getattr(examples_queue, "qsize", lambda: 0)(), _collect_step)
                    if pcol:
                        pcol.update(1)
                except Empty:
                    pass
            if pcol and (new_count % max(1, target_new // 10) == 0):
                pcol.set_postfix({"new": new_count, "replay": len(replay), "qsize": getattr(examples_queue, "qsize", lambda: 0)()})
        if pcol:
            pcol.close()
        elapsed = time.time() - t0
        logger.info("Collected %d new examples in %.1fs (replay=%d, queue≈%s)", new_count, elapsed, len(replay), getattr(examples_queue, "qsize", lambda: 0)())
        writer.add_scalar("collect/examples_per_iter", new_count, iter_idx)
        writer.add_scalar("collect/seconds", elapsed, iter_idx)

    # Helper: training phase
    def _train_phase(iter_idx, trainer, replay, examples_queue, train_global_step):
        """Run ``args.train_steps_per_iter`` optimizer steps for one iteration.

        Before each attempted step, every example currently waiting in
        ``examples_queue`` is moved into ``replay``. A step is only taken once
        ``replay`` holds at least ``args.train_batch`` examples; otherwise the loop
        sleeps for 20 ms and retries. There is no timeout: if the buffer never
        reaches ``args.train_batch`` this loop keeps waiting.

        Per-step losses, entropy, replay/queue sizes and the replay-wide mean of
        ``ex[2]`` are written to TensorBoard at the global step; per-iteration
        averages go under ``train_avg/*`` at ``iter_idx``.

        Returns:
            The advanced ``train_global_step`` counter.
        """
        total_steps = args.train_steps_per_iter
        bar = tqdm(total=total_steps, desc=f"Iter {iter_idx} train  ", leave=True) if tqdm else None
        # Queues without qsize() report 0.
        qsize = getattr(examples_queue, "qsize", lambda: 0)
        steps_done = 0
        sum_total = sum_policy = sum_value = 0.0

        while steps_done < total_steps:
            # Drain whatever the self-play side has produced so far.
            moved = 0
            while not examples_queue.empty():
                try:
                    example = examples_queue.get_nowait()
                except Empty:
                    break
                replay.push([example])
                moved += 1
            if moved and steps_done % 10 == 0:
                logger.debug("Iter %d: drained %d examples (replay=%d)", iter_idx, moved, len(replay))
                writer.add_scalar("collect/drained", moved, train_global_step)
                writer.add_scalar("replay/size", len(replay), train_global_step)

            if len(replay) < args.train_batch:
                # Not enough data for a full batch yet; back off briefly and retry.
                time.sleep(0.02)
                continue

            batch = trainer.replay.sample(args.train_batch)
            loss, pl, vl, ent = trainer.train_step(batch)
            sum_total += loss
            sum_policy += pl
            sum_value += vl
            steps_done += 1
            train_global_step += 1

            per_step = (
                ("train/loss", loss),
                ("train/total_loss", loss),
                ("train/policy_loss", pl),
                ("train/value_loss", vl),
                ("train/entropy", ent),
            )
            for tag, value in per_step:
                writer.add_scalar(tag, value, train_global_step)
            writer.add_scalar("replay/size", len(trainer.replay), train_global_step)
            writer.add_scalar("queue/size", qsize(), train_global_step)
            # NOTE(review): this scans the whole replay buffer on every step
            # (O(len(replay))); ex[2] is presumably the value target — confirm.
            mean_value = sum(ex[2] for ex in replay.buf) / len(replay) if replay else 0
            writer.add_scalar("replay/average_value", mean_value, train_global_step)

            if bar:
                bar.set_postfix({"loss": f"{loss:.4f}", "pl": f"{pl:.4f}", "vl": f"{vl:.4f}", "replay": len(trainer.replay), "q": qsize()})
                bar.update(1)

        if bar:
            bar.close()
        averages = (
            ("train_avg/loss", sum_total),
            ("train_avg/policy_loss", sum_policy),
            ("train_avg/value_loss", sum_value),
        )
        for tag, total in averages:
            writer.add_scalar(tag, total / steps_done if steps_done > 0 else 0.0, iter_idx)
        return train_global_step

    # Helper: arena evaluation and accept/reject
    def _arena_phase(iter_idx, game, net, prev_net_snapshot, prev_optim_state, trainer, backend, broker, model_lock, checkpoint_dir, feat_cfg):
        """Evaluate the freshly trained ``net`` and either accept or reject it.

        The gating mode is chosen by ``args.accept_mode``:

        * ``'pool-sprt'``: play games against models from the hall-of-fame pool
          (``accept_pool.json`` in the parent of ``checkpoint_dir``). An SPRT decides;
          if it is still inconclusive after ``args.accept_max_games`` games, a Wilson
          lower confidence bound is compared against ``args.accept_wilson_thresh``.
        * any other value: play ``args.eval_games`` games against ``prev_net_snapshot``
          and accept when the decisive-game win rate is ``>= args.accept_threshold``.

        On acceptance:
          * a checkpoint is saved;
          * the baseline snapshot and optimizer state are refreshed from ``net``/``trainer``;
          * the process broker, if any, receives the new weights;
          * in pool-sprt mode, the saved checkpoint is added to the hall-of-fame pool.

        On rejection:
          * ``net`` and the optimizer are restored from the baseline;
          * the broker, if any, is reverted.

        Every inference server started for evaluation is stopped even if a game raises.

        Returns:
            ``(prev_net_snapshot, prev_optim_state)``: the (possibly updated) baseline.
        """
        _arena_check_cnn_geometry(game, net)
        hist = int(getattr(args, 'history_steps', 0) or 0)
        if args.accept_mode == 'pool-sprt':
            accepted = _arena_gate_pool_sprt(iter_idx, game, net, prev_net_snapshot, model_lock, checkpoint_dir, feat_cfg, hist)
            winrate = None  # pool-sprt decisions are not threshold-based
        else:
            accepted, winrate = _arena_gate_prev_threshold(iter_idx, game, net, prev_net_snapshot, model_lock, feat_cfg, hist)

        if accepted:
            prev_net_snapshot, prev_optim_state = _arena_accept(
                iter_idx, net, trainer, backend, broker, model_lock, checkpoint_dir, winrate
            )
        else:
            _arena_reject(net, prev_net_snapshot, prev_optim_state, trainer, backend, broker, model_lock, winrate)
        logger.info("Iteration %d complete.", iter_idx)
        return prev_net_snapshot, prev_optim_state

    def _arena_check_cnn_geometry(game, net):
        """Best-effort: make the game's ROWS/COLS match a CNN ``AlphaNet`` policy head.

        Assumes ``policy_fc.in_features == 2 * H * W`` (policy_conv is taken to have 2
        output channels). On a mismatch:
          * if the implied cell count is a perfect square, the game is reconfigured
            to that square board;
          * otherwise a warning is logged.
        Failures never abort evaluation.
        """
        try:
            if not isinstance(net, AlphaNet):
                return
            in_feat = int(net.policy_fc.in_features)
            expected_cells = in_feat // 2  # policy_conv has 2 output channels
            H = int(getattr(game, 'ROWS', 0)); W = int(getattr(game, 'COLS', 0))
            if H * W == expected_cells:
                return
            side = math.isqrt(expected_cells)
            if side > 0 and side * side == expected_cells:
                try:
                    if callable(getattr(game, 'configure', None)):
                        game.configure(board_rows=side, board_cols=side)
                    else:
                        if hasattr(game, 'ROWS'): setattr(game, 'ROWS', side)
                        if hasattr(game, 'COLS'): setattr(game, 'COLS', side)
                    _propagate_overrides_to_bases(game, {'ROWS': side, 'COLS': side})
                except Exception:
                    logger.warning("Arena: failed to auto-adjust game geometry to %dx%d", side, side, exc_info=True)
            else:
                logging.getLogger("alphazero.main").warning(
                    "Arena: CNN model expects 2*H*W=%d inputs but game has H*W=%d (H=%d,W=%d). You may need to adjust board size.",
                    in_feat, H*W, H, W,
                )
        except Exception:
            logger.debug("Arena: geometry sanity check skipped", exc_info=True)

    def _arena_pool_path(checkpoint_dir):
        """Return the hall-of-fame pool file path (``accept_pool.json`` beside ``checkpoint_dir``)."""
        return os.path.join(os.path.dirname(checkpoint_dir), 'accept_pool.json')

    def _arena_read_pool(pool_path):
        """Load hall-of-fame entries as a list of dicts; a missing/unreadable/malformed file yields ``[]``."""
        try:
            if not os.path.exists(pool_path):
                return []
            with open(pool_path, 'r') as f:
                data = json.load(f) or []
        except Exception:
            logger.warning("Could not read hall-of-fame pool at %s; treating it as empty", pool_path, exc_info=True)
            return []
        if not isinstance(data, list):
            logger.warning("Hall-of-fame pool at %s is not a JSON list; treating it as empty", pool_path)
            return []
        # Drop malformed entries so later .get() calls cannot fail.
        return [e for e in data if isinstance(e, dict)]

    def _arena_write_pool(pool_path, entries):
        """Atomically write ``entries`` to ``pool_path`` (temp file + ``os.replace``); return success."""
        tmp_path = pool_path + ".tmp"
        try:
            with open(tmp_path, 'w') as f:
                json.dump(entries, f, indent=2)
            os.replace(tmp_path, pool_path)
            return True
        except Exception:
            try:
                if os.path.exists(tmp_path): os.remove(tmp_path)
            except Exception:
                pass
            return False

    def _arena_sprt_bounds():
        """Return the classical ``(upper, lower)`` SPRT log-likelihood-ratio bounds from ``args``."""
        try:
            upper = math.log((1.0 - float(args.accept_beta)) / float(args.accept_alpha))
            lower = math.log(float(args.accept_beta) / (1.0 - float(args.accept_alpha)))
        except Exception:
            upper = float('inf'); lower = float('-inf')
        return upper, lower

    def _arena_load_pool_models(net, prev_net_snapshot, checkpoint_dir):
        """Build the ``(name, model)`` opponents for pool-SPRT, falling back to the previous snapshot."""
        entries = _arena_read_pool(_arena_pool_path(checkpoint_dir))
        # Most recent accept_pool_size entries first.
        entries = sorted(entries, key=lambda e: e.get('iter', 0), reverse=True)[:int(args.accept_pool_size)]
        models = []
        for e in entries:
            pth = e.get('path')
            if not pth or not os.path.exists(pth):
                continue
            try:
                sd = extract_state_dict(torch.load(pth, map_location='cpu'))
                # Load into a copy of the current net so the architecture matches.
                opp = copy.deepcopy(net).to(device)
                opp.load_state_dict(sd, strict=False)
                models.append((os.path.basename(pth), opp))
            except Exception:
                logger.warning("Skipping hall-of-fame entry %s (failed to load)", pth, exc_info=True)
        if not models:
            models.append(("prev", copy.deepcopy(prev_net_snapshot).to(device)))
        return models

    def _arena_gate_pool_sprt(iter_idx, game, net, prev_net_snapshot, model_lock, checkpoint_dir, feat_cfg, hist):
        """Gate ``net`` against the hall-of-fame pool via SPRT (Wilson-LCB fallback); return accepted."""
        logger.info("Evaluating new model vs hall-of-fame (pool-sprt): p0=%.2f p1=%.2f alpha=%.3f beta=%.3f max_games=%d pool=%d",
                    args.accept_p0, args.accept_p1, args.accept_alpha, args.accept_beta, args.accept_max_games, args.accept_pool_size)
        # The new model's inference server is built once and reused for every game.
        eval_infer_new = InferenceServer(net, device=device, max_batch_size=32, max_batch_wait=0.01, model_lock=model_lock)
        try:
            p_new = MCTSPlayer(game, eval_infer_new, num_sims=args.eval_sims, temperature=0.0, name="new", history_steps=hist, featurizer=None, featurizer_config=feat_cfg)
            pool_models = _arena_load_pool_models(net, prev_net_snapshot, checkpoint_dir)

            gate = _SPRTGate(p0=args.accept_p0, p1=args.accept_p1, alpha=args.accept_alpha, beta=args.accept_beta)
            rng = np.random.default_rng(iter_idx)  # seeded per iteration: reproducible opponent picks
            new_wins = old_wins = draws = 0
            starting_player = 1
            games_played = 0
            decision_reason = ""
            decision_basis = "SPRT"
            while games_played < int(args.accept_max_games):
                # Uniformly random opponent; its inference server lives for this game only.
                name, opp_model = pool_models[rng.integers(len(pool_models))]
                eval_infer_old = InferenceServer(opp_model, device=device, max_batch_size=32, max_batch_wait=0.01, model_lock=model_lock)
                try:
                    p_old = MCTSPlayer(game, eval_infer_old, num_sims=args.eval_sims, temperature=0.0, name=f"opp:{name}", history_steps=hist, featurizer=None, featurizer_config=feat_cfg)
                    arena = Arena(game, p_new, p_old, verbose=False)
                    res = arena.play_game(starting_player=starting_player)
                finally:
                    eval_infer_old.stop()
                if res.winner > 0:
                    new_wins += 1; s = 1.0
                elif res.winner < 0:
                    old_wins += 1; s = 0.0
                else:
                    draws += 1; s = 0.5
                verdict = gate.update(s)
                games_played += 1
                starting_player = -starting_player  # alternate who moves first
                if verdict is True or verdict is False:
                    accepted = verdict
                    upper, lower = _arena_sprt_bounds()
                    phat = gate.total_score / max(gate.n, 1)
                    if accepted:
                        head = f"SPRT accept after {games_played} games: LLR={gate.llr:.3f} >= {upper:.3f}; "
                    else:
                        head = f"SPRT reject after {games_played} games: LLR={gate.llr:.3f} <= {lower:.3f}; "
                    decision_reason = (
                        head
                        + f"p0={args.accept_p0:.2f}, p1={args.accept_p1:.2f}, alpha={args.accept_alpha:.3f}, beta={args.accept_beta:.3f}; "
                        + f"new={new_wins}, old={old_wins}, draws={draws}, p̂={phat:.3f}"
                    )
                    break
            else:
                # Inconclusive within the game budget: fall back to the Wilson LCB.
                logger.info("SPRT inconclusive after %d games: new=%d old=%d draws=%d, fallback to Wilson LCB", games_played, new_wins, old_wins, draws)
                phat = gate.total_score / max(gate.n, 1)
                lcb = _wilson_lcb(phat, gate.n, z=1.96)
                accepted = (lcb >= float(args.accept_wilson_thresh))
                logger.info("Wilson LCB=%.4f (p̂=%.4f n=%d) => %s (thresh=%.4f)", lcb, phat, gate.n, ("ACCEPT" if accepted else "REJECT"), float(args.accept_wilson_thresh))
                decision_basis = "Wilson LCB"
                decision_reason = (
                    f"Wilson LCB decision after {games_played} games: LCB={lcb:.4f} vs thresh={float(args.accept_wilson_thresh):.4f}; "
                    f"p̂={phat:.4f}, n={gate.n}, new={new_wins}, old={old_wins}, draws={draws}"
                )

            writer.add_scalar("arena/current_wins", new_wins, iter_idx)
            writer.add_scalar("arena/previous_wins", old_wins, iter_idx)
            writer.add_scalar("arena/draws", draws, iter_idx)
            decisive = new_wins + old_wins
            winrate = (new_wins / decisive) if decisive > 0 else 0.0
            writer.add_scalar("arena/current_winrate", winrate, iter_idx)
            writer.add_scalar("arena/sprt_llr", gate.llr, iter_idx)
            writer.add_scalar("arena/sprt_n", gate.n, iter_idx)
            writer.add_scalar("arena/accepted", int(accepted), iter_idx)
            logger.info("Decision (pool-sprt): %s — %s", "ACCEPT" if accepted else "REJECT",
                        decision_reason or f"basis={decision_basis}")
            return accepted
        finally:
            try:
                eval_infer_new.stop()
            except Exception:
                pass

    def _arena_gate_prev_threshold(iter_idx, game, net, prev_net_snapshot, model_lock, feat_cfg, hist):
        """Gate ``net`` against the previous snapshot by decisive-game win rate; return ``(accepted, winrate)``."""
        logger.info("Evaluating new model vs previous snapshot in Arena: games=%d, sims=%d", args.eval_games, args.eval_sims)
        eval_infer_new = InferenceServer(net, device=device, max_batch_size=32, max_batch_wait=0.01, model_lock=model_lock)
        eval_infer_old = None
        try:
            old_for_eval = copy.deepcopy(prev_net_snapshot).to(device)
            eval_infer_old = InferenceServer(old_for_eval, device=device, max_batch_size=32, max_batch_wait=0.01, model_lock=model_lock)
            p_new = MCTSPlayer(game, eval_infer_new, num_sims=args.eval_sims, temperature=0.0, name="new", history_steps=hist, featurizer=None, featurizer_config=feat_cfg)
            p_old = MCTSPlayer(game, eval_infer_old, num_sims=args.eval_sims, temperature=0.0, name="old", history_steps=hist, featurizer=None, featurizer_config=feat_cfg)
            arena = Arena(game, p_new, p_old, verbose=False)
            results = arena.play_games_balanced(args.eval_games, alternate_colors=True, show_progress=bool(tqdm), num_workers=args.num_workers)
        finally:
            eval_infer_new.stop()
            if eval_infer_old is not None:
                eval_infer_old.stop()
        new_wins = results.get("new", 0); old_wins = results.get("old", 0); draws = results.get("draws", 0)
        decisive = new_wins + old_wins
        winrate = new_wins / decisive if decisive > 0 else 0.0  # draws are excluded
        logger.info("Arena results: new=%d old=%d draws=%d (new winrate=%.1f%%)", new_wins, old_wins, draws, 100.0 * winrate)
        writer.add_scalar("arena/current_wins", new_wins, iter_idx)
        writer.add_scalar("arena/previous_wins", old_wins, iter_idx)
        writer.add_scalar("arena/draws", draws, iter_idx)
        writer.add_scalar("arena/current_winrate", winrate, iter_idx)
        writer.add_scalar("arena/threshold", args.accept_threshold, iter_idx)
        accepted = winrate >= args.accept_threshold
        writer.add_scalar("arena/accepted", int(accepted), iter_idx)
        return accepted, winrate

    def _arena_save_checkpoint(iter_idx, net, trainer, checkpoint_dir):
        """Atomically save model/optimizer state for an accepted iteration; return ``(path, saved_ok)``."""
        ckpt_path = os.path.join(checkpoint_dir, f"iter_{iter_idx:03d}_accepted.pt")
        tmp_path = ckpt_path + ".tmp"
        ckpt_obj = {'model': net.state_dict(), 'optimizer': trainer.optimizer.state_dict(), 'iter': iter_idx, 'args': vars(args)}
        # Attach git metadata if available
        if git_info:
            ckpt_obj['git'] = git_info
        saved = False
        try:
            try:
                # Legacy (non-zip) format; fall back to the default format if this
                # torch build rejects the keyword.
                torch.save(ckpt_obj, tmp_path, _use_new_zipfile_serialization=False)
            except TypeError:
                torch.save(ckpt_obj, tmp_path)
            os.replace(tmp_path, ckpt_path)
            saved = True
            logger.info("Saved checkpoint: %s", ckpt_path)
            writer.add_text("checkpoint/last", ckpt_path, iter_idx)
        except Exception as e:
            try:
                if os.path.exists(tmp_path): os.remove(tmp_path)
            except Exception:
                pass
            logger.warning("Failed to save checkpoint to %s: %s", ckpt_path, e)
        return ckpt_path, saved

    def _arena_add_to_pool(checkpoint_dir, ckpt_path, iter_idx):
        """Record ``ckpt_path`` in the hall-of-fame pool, keeping the newest ``args.accept_pool_size`` entries."""
        pool_path = _arena_pool_path(checkpoint_dir)
        # Replace any stale entry for the same file (e.g. a resumed run re-accepting an iteration).
        entries = [e for e in _arena_read_pool(pool_path) if e.get('path') != ckpt_path]
        entries.append({'path': ckpt_path, 'iter': int(iter_idx), 'time': time.time()})
        entries = sorted(entries, key=lambda e: e.get('iter', 0), reverse=True)[:int(args.accept_pool_size)]
        if _arena_write_pool(pool_path, entries):
            logger.info("Updated hall-of-fame pool (%d entries)", len(entries))
        else:
            logger.warning("Failed to update hall-of-fame pool at %s", pool_path)

    def _arena_accept(iter_idx, net, trainer, backend, broker, model_lock, checkpoint_dir, winrate):
        """Persist the accepted ``net``, refresh the baseline, sync broker/pool; return the new baseline."""
        if args.accept_mode == 'pool-sprt':
            logger.info("Accepted new model (pool-sprt). Saving checkpoint and updating baseline.")
        else:
            logger.info("Accepted new model (winrate %.1f%% >= %.1f%%). Saving checkpoint and updating baseline.", 100.0 * winrate, 100.0 * args.accept_threshold)
        ckpt_path, saved = _arena_save_checkpoint(iter_idx, net, trainer, checkpoint_dir)
        with model_lock:
            snapshot = copy.deepcopy(net).cpu()
            optim_state = copy.deepcopy(trainer.optimizer.state_dict())
        if backend == 'processes' and broker is not None:
            try:
                broker.update_weights(net.state_dict())
            except Exception:
                logger.exception("Failed to update broker after acceptance")
        if args.accept_mode == 'pool-sprt':
            if saved:
                _arena_add_to_pool(checkpoint_dir, ckpt_path, iter_idx)
            else:
                # A missing file would only occupy (and evict) a pool slot.
                logger.warning("Not adding %s to hall-of-fame pool because the checkpoint was not saved", ckpt_path)
        return snapshot, optim_state

    def _arena_reject(net, prev_net_snapshot, prev_optim_state, trainer, backend, broker, model_lock, winrate):
        """Restore ``net``, the optimizer and broker weights from the baseline after a rejection."""
        if args.accept_mode == 'pool-sprt':
            logger.info("Rejected new model (pool-sprt). Reverting to previous weights.")
        else:
            logger.info("Rejected new model (winrate %.1f%% < %.1f%%). Reverting to previous weights.", 100.0 * winrate, 100.0 * args.accept_threshold)
        with model_lock:
            net.load_state_dict(prev_net_snapshot.state_dict())
            try:
                trainer.optimizer.load_state_dict(prev_optim_state)
            except Exception:
                logger.exception("Failed to restore optimizer state; reinitializing optimizer")
                trainer.optimizer = torch.optim.Adam(net.parameters(), lr=2e-3)
        if backend == 'processes' and broker is not None:
            try:
                broker.update_weights(prev_net_snapshot.state_dict())
            except Exception:
                logger.exception("Failed to revert broker weights after rejection")

    train_global_step = 0

    # Optionally initialize from checkpoint: infer arch, in_channels, and adjust args
    init_sd = None
    ckpt_cleanup = None
    if getattr(args, 'init_checkpoint', None):
        try:
            local_path, ckpt_cleanup = resolve_remote_path(args.init_checkpoint, suffix='.pt', desc='checkpoint')
            ckpt = torch.load(local_path, map_location='cpu')
            sd = extract_state_dict(ckpt)
            arch = infer_arch_from_state_dict(sd)
            az_meta = infer_alphazero_from_state_dict(sd)
            in_ch = az_meta.get('in_channels') or infer_in_channels(arch, sd)
            # Adjust network type
            if arch == 'transformer':
                args.network = 'transformer'
                d_model, num_layers, _dim_ff = infer_transformer_hparams(sd, default_d_model=int(args.embed_dim))
                args.embed_dim = int(d_model)
                # Choose a num_heads that divides d_model; prefer existing if valid
                try:
                    nh = int(args.num_heads)
                except Exception:
                    nh = 4
                if d_model % nh != 0:
                    # pick a divisor (8,4,2) or 1 as last resort
                    for cand in (8, 4, 2, 1):
                        if d_model % cand == 0:
                            nh = cand; break
                args.num_heads = int(nh)
                args.num_layers = int(num_layers)
            elif arch in ('conv', 'resnet', 'conv_player', 'logistic'):
                # Treat everything non-transformer as CNN AlphaNet here
                args.network = 'cnn'
            else:
                args.network = 'cnn'

            # Prefer checkpoint game if available and action size suggests mismatch
            ckpt_args = ckpt.get('args', {}) if isinstance(ckpt, dict) else {}
            ckpt_game_key = None
            try:
                ckpt_game_key = str(ckpt_args.get('game', '')).lower() or None
            except Exception:
                ckpt_game_key = None
            # Select current desired game class
            try:
                desired_game = select_game(str(args.game).lower())
            except Exception:
                desired_game = None
            # Compute policy size from checkpoint if present
            ckpt_policy_size = az_meta.get('action_size')
            # If mismatch likely, switch to checkpoint's game if provided
            if ckpt_game_key:
                try:
                    ckpt_game_cls = select_game(ckpt_game_key)
                except Exception:
                    ckpt_game_cls = None
            else:
                ckpt_game_cls = None
            if ckpt_policy_size is not None and desired_game is not None:
                try:
                    if ckpt_policy_size != desired_game.action_size() and ckpt_game_cls is not None:
                        args.game = ckpt_game_key
                except Exception:
                    pass

            # Infer history/extras to match input channels (if we have in_channels)
            try:
                gcls = select_game(str(args.game).lower())
            except Exception:
                gcls = None
            if in_ch and gcls is not None:
                hist, extras = _infer_hist_and_extras(gcls, int(in_ch), is_transformer=(args.network == 'transformer'))
                args.history_steps = int(hist)
                for k, v in extras.items():
                    setattr(args, k, bool(v))
            # Apply board geometry from checkpoint if present
            try:
                if isinstance(ckpt_args, dict):
                    if ckpt_args.get('board_rows') is not None:
                        args.board_rows = int(ckpt_args['board_rows'])
                    if ckpt_args.get('board_cols') is not None:
                        args.board_cols = int(ckpt_args['board_cols'])
            except Exception:
                pass

            init_sd = sd
            logger.info("Initialized config from checkpoint: network=%s, embed_dim=%s, num_layers=%s, num_heads=%s, history=%s",
                        args.network, getattr(args, 'embed_dim', None), getattr(args, 'num_layers', None), getattr(args, 'num_heads', None), getattr(args, 'history_steps', 0))
            if writer is not None:
                try:
                    writer.add_text("init/checkpoint", str(args.init_checkpoint), 0)
                except Exception:
                    pass
        except Exception:
            logger.exception("Failed to initialize from checkpoint '%s' — proceeding with requested config.", args.init_checkpoint)

    # Build game+net and featurizer config (after possible checkpoint-based arg updates)
    game, net = build_default(args.game, network=args.network, model_device=device, args=args)
    # If we have an init state_dict, load it
    if init_sd is not None:
        missing, unexpected = net.load_state_dict(init_sd, strict=False)
        logger.warning("Loaded weights with missing=%d, unexpected=%d", len(missing), len(unexpected))
        if missing:
            logger.debug("Missing keys: %s", missing)
        if unexpected:
            logger.debug("Unexpected keys: %s", unexpected)
        # cleanup temp file if any
        try:
            if callable(ckpt_cleanup):
                ckpt_cleanup()
        except Exception:
            pass
    feat_cfg = _build_feat_cfg()

    # Initialize synchronization, replay, and backend
    model_lock = threading.RLock()
    replay = ReplayBuffer(capacity=args.replay_size)
    backend, infer, workers, broker, proc_workers, examples_queue = _setup_backend(game, net, device, model_lock, feat_cfg)
    trainer = Trainer(net, replay, device=device, lr=2e-3)

    # Sanity: ensure pipeline produces at least one example
    _wait_for_first_example(examples_queue, replay)

    # Baseline snapshot (on CPU) and optimizer state
    prev_net_snapshot = copy.deepcopy(net).cpu()
    prev_optim_state = copy.deepcopy(trainer.optimizer.state_dict())

    # === Iterations: Collect -> Train -> Evaluate ===
    for iter_idx in range(1, args.iterations + 1):
        logger.info(f"================== Iteration {iter_idx}/{args.iterations} ==================")
        # Keep a persistent replay buffer across iterations to avoid catastrophic forgetting.
        # We no longer reset the buffer here.
        writer.add_scalar("replay/persistent", 1, iter_idx)

        _collect_phase(iter_idx, examples_queue, replay)
        train_global_step = _train_phase(iter_idx, trainer, replay, examples_queue, train_global_step)

        if backend == 'processes' and broker is not None:
            try:
                broker.update_weights(net.state_dict())
                logger.info("Broker weights updated after training")
            except Exception:
                logger.exception("Failed to push updated weights to broker")

        # Arena evaluation only if we ran at least one iteration and requested eval games
        if args.iterations >= 1 and args.eval_games > 0:
            prev_net_snapshot, prev_optim_state = _arena_phase(
                iter_idx, game, net, prev_net_snapshot, prev_optim_state, trainer, backend, broker, model_lock, checkpoint_dir, feat_cfg
            )

    # Shutdown + cleanup
    writer.flush(); writer.close()
    logger.info("Stopping workers and inference server...")
    if backend == 'threads':
        for w in workers: w.stop()
        infer.stop()
    else:
        for p in proc_workers:
            try: p.terminate()
            except Exception: pass
        if broker is not None: broker.stop()
    logger.info("Shutdown complete.")

if __name__ == "__main__":
    # Run main() only when this file is executed as a script, not when it is imported.
    main()
