from collections import deque, defaultdict, namedtuple
import threading
from queue import Queue, Empty
import time
import math
import numpy as np
import torch
import torch.nn as nn
import logging  # NEW
from typing import Optional, Callable, Any
try:
    # Optional featurizer for history-only stacking; code works without it
    from src.alphazero.featurizer import TransformerFeaturizer  # noqa: F401
except Exception:
    TransformerFeaturizer = None  # type: ignore

# Optional shaping utilities
try:
    from .shaping import ShapingConfig, annealed_scale, call_phi  # type: ignore
except Exception:
    # The shaping module could not be imported: install no-op stand-ins so the
    # rest of this file can call them unconditionally and get zero shaping.
    ShapingConfig = None  # type: ignore
    def annealed_scale(cfg, step):
        # Fallback: the shaping scale is always zero, whatever cfg/step are.
        return 0.0
    def call_phi(phi_fn, game, board, player):
        # Fallback: the potential is always zero; phi_fn is never invoked.
        return 0.0

# Shared constants
# Maximum time MCTS._simulate blocks on an inference request's event before it
# gives up and falls back to uniform priors / a neutral value.
INFERENCE_WAIT_TIMEOUT = 20.0  # seconds

def stack_with_history(current: np.ndarray, history_deque, history_steps: int) -> np.ndarray:
    """Stack the current frame with up to ``history_steps`` past frames along axis 0.

    Args:
        current: The current frame, shape (C, H, W). Any array-like is accepted
            (ndarray, nested lists, ...) and is converted to float32.
        history_deque: Indexable sequence of past frames, most recent first.
            It is only consulted when ``history_steps`` is positive.
        history_steps: Number of history slots to append after ``current``.
            Slots with no corresponding entry in ``history_deque`` are filled
            with zero frames shaped like ``current``.

    Returns:
        A float32 array: ``current`` followed by ``history_steps`` frames,
        concatenated along axis 0.

    Raises:
        ValueError: If ``current`` is not 3-dimensional.
    """
    # Coerce first so that lists/tuples work too; the previous code tried
    # `.astype` on the same object whose `.shape` had just failed.
    current = np.asarray(current, dtype=np.float32)
    if current.ndim != 3:
        raise ValueError(f"current must have shape (C, H, W); got {current.shape}")
    planes = [current]
    for i in range(history_steps):
        if i < len(history_deque):
            planes.append(np.asarray(history_deque[i], dtype=np.float32))
        else:
            # Not enough history yet: pad with an all-zero frame.
            planes.append(np.zeros_like(current))
    return np.concatenate(planes, axis=0)


# Result delivered to a caller of InferenceServer.submit():
#   policy_logits - the caller's row of the network's first output, as numpy
#   value         - the caller's entry of the network's second output, as float
InferenceResult = namedtuple('InferenceResult', ['policy_logits', 'value'])


class InferenceServer:
    """
    Collects inference requests from many MCTS workers and processes them in batches.

    A background daemon thread pulls requests off an internal queue, groups them
    into batches of at most ``max_batch_size`` (flushing a partial batch once its
    oldest request has waited ``max_batch_wait`` seconds), runs the network under
    ``model_lock`` with gradients disabled, and hands each caller its row.

    Each request is a ``(board_tensor, event, container)`` triple. When the event
    is set, ``container['out']`` holds an ``InferenceResult`` on success or
    ``None`` if the forward pass for that batch failed.
    """

    def __init__(self, net: nn.Module, device='cuda', max_batch_size=64, max_batch_wait=0.02, model_lock = None):
        """Move ``net`` to ``device`` (falling back to CPU on failure) and start the worker thread.

        Args:
            net: Module called as ``net(batch)`` and expected to return a
                ``(policy_logits, value)`` pair.
            device: Device string or ``torch.device``. Any other type selects
                CUDA when available, else CPU.
            max_batch_size: Upper bound on requests per forward pass (clamped to >= 1).
            max_batch_wait: Seconds to wait for a batch to fill before flushing it.
            model_lock: Optional lock shared with code that swaps weights; a
                private ``RLock`` is created when omitted.
        """
        self.net = net
        self.net.eval()
        # Respect requested device; don't override to CUDA unless specified
        try:
            self.device = torch.device(device) if isinstance(device, (str, torch.device)) else torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        except Exception:
            logging.exception("Invalid device '%s'; defaulting to CPU", str(device))
            self.device = torch.device('cpu')
        # Optional lock to coordinate with weight swaps
        self.model_lock = model_lock or threading.RLock()
        # Ensure model is on the same device as tensors we'll feed
        try:
            self.net.to(self.device)
        except Exception:
            logging.exception("Failed to move net to device %s", self.device)
            # Fallback to CPU to keep inference alive
            self.device = torch.device('cpu')
            try:
                self.net.to(self.device)
                logging.warning("InferenceServer fell back to CPU device.")
            except Exception:
                logging.exception("Fallback to CPU also failed; inference likely broken.")
        # A non-positive batch size would make the worker slice out empty
        # batches forever without ever consuming a request.
        self.max_batch_size = max(1, int(max_batch_size))
        self.max_batch_wait = max_batch_wait
        self.queue = Queue()
        self._stop = threading.Event()
        self.worker = threading.Thread(target=self._run, daemon=True, name="InferenceServer")
        self.worker.start()
        logging.info("InferenceServer started on device %s (batch_size=%d, wait=%.3fs)", str(self.device), self.max_batch_size, self.max_batch_wait)

    def submit(self, board_tensor):
        """Submit a single board tensor and return a future-like ``(event, container)`` pair.

        When the event is set, ``container['out']`` contains an ``InferenceResult``
        or ``None`` on failure. If the server has already been stopped, or the
        request cannot be enqueued, the event is set immediately with ``None``
        so the caller does not block until its own timeout.
        """
        ev = threading.Event()
        container = {}
        if self._stop.is_set():
            # Best-effort fast path: the worker is gone (or going), nobody will answer.
            container['out'] = None
            ev.set()
            return ev, container
        try:
            self.queue.put((board_tensor, ev, container))
        except Exception:
            logging.exception("Failed to enqueue inference request")
            container['out'] = None
            ev.set()
        return ev, container

    @staticmethod
    def _prepare_input(board_tensor, expected_C):
        """Convert one request payload to a float32 tensor, zero-padding channels if needed.

        Padding is applied only to 3-D inputs whose channel count is smaller than
        ``expected_C`` and divides it evenly (same rule as before).
        """
        t = torch.as_tensor(board_tensor, dtype=torch.float32)
        if expected_C is not None and t.dim() == 3:
            c_in, h, w = t.shape
            if 0 < c_in < expected_C and expected_C % c_in == 0:
                pad = torch.zeros((expected_C - c_in, h, w), dtype=t.dtype, device=t.device)
                t = torch.cat([t, pad], dim=0)
        return t

    def _process_batch(self, batch):
        """Run one forward pass for ``batch`` and deliver a result to every request in it.

        Every event in ``batch`` is set exactly once: with an ``InferenceResult``
        on success, or with ``None`` if anything in the batch failed.
        """
        if not batch:
            return
        try:
            logging.debug("InferenceServer processing batch of %d", len(batch))
            expected_C = getattr(self.net, 'C', None)
            if expected_C is not None:
                try:
                    expected_C = int(expected_C)
                except (TypeError, ValueError):
                    expected_C = None  # unusable hint: skip channel padding
            tensors = [self._prepare_input(bt, expected_C) for bt, _, _ in batch]
            # Stack before touching the device so that a shape mismatch is
            # reported as a batch failure rather than mistaken for a device
            # failure (which would needlessly demote the model to CPU).
            stacked = torch.stack(tensors, dim=0)
            try:
                bat = stacked.to(self.device, non_blocking=True)
            except Exception:
                logging.exception("Failed to move batch to device %s; retrying on CPU", str(self.device))
                # Fall back to CPU for both inputs and model to keep service alive
                self.device = torch.device('cpu')
                try:
                    with self.model_lock:
                        self.net.to(self.device)
                    logging.warning("InferenceServer moved model to CPU after device transfer failure")
                except Exception:
                    logging.exception("Failed to move model to CPU after device transfer failure")
                bat = stacked.to(self.device)
            with torch.no_grad():
                with self.model_lock:
                    p_logits, v = self.net(bat)
            p_logits = p_logits.detach().cpu().numpy()
            v = v.detach().cpu().numpy()
            # `.item()` accepts both scalar and single-element rows and still
            # rejects multi-element rows, without relying on the deprecated
            # implicit array -> float conversion.
            results = [
                InferenceResult(policy_logits=p_logits[i], value=float(np.asarray(v[i]).item()))
                for i in range(len(batch))
            ]
        except Exception:
            logging.exception("InferenceServer forward failed (batch_size=%d)", len(batch))
            # Unblock callers to avoid deadlock
            results = [None] * len(batch)
        for (_, ev, container), res in zip(batch, results):
            container['out'] = res
            ev.set()

    def _run(self):
        """Worker loop: accumulate requests, flush batches, and answer everything on shutdown."""
        buffer = []  # pending requests, oldest first
        times = []   # arrival timestamps, parallel to `buffer`
        while not self._stop.is_set():
            try:
                item = self.queue.get(timeout=self.max_batch_wait)
            except Empty:
                pass
            else:
                buffer.append(item)
                times.append(time.monotonic())
            if not buffer:
                continue
            if len(buffer) >= self.max_batch_size or (time.monotonic() - times[0]) >= self.max_batch_wait:
                batch = buffer[:self.max_batch_size]
                buffer = buffer[self.max_batch_size:]
                times = times[len(batch):]
                self._process_batch(batch)
        # Shutdown: answer everything still outstanding. This includes requests
        # already pulled into `buffer` but not yet flushed, which would
        # otherwise never have their events set.
        while True:
            try:
                buffer.append(self.queue.get_nowait())
            except Empty:
                break
        for start in range(0, len(buffer), self.max_batch_size):
            self._process_batch(buffer[start:start + self.max_batch_size])

    def stop(self):
        """Signal the worker to finish outstanding requests and wait for it to exit."""
        self._stop.set()
        self.worker.join()

# ----------------------------
# === MCTS Node & Search ===
# ----------------------------

class MCTSNode:
    """One node of the search tree: edge statistics plus its expanded children."""

    def __init__(self, prior=0.0):
        # Guards reads/writes of this node's children and their statistics.
        self.lock = threading.Lock()
        # action -> MCTSNode
        self.children = {}
        self.prior = prior
        self.total_value = 0.0
        self.visit_count = 0
        # Bookkeeping for failed network evaluations, so a node that was
        # expanded with fallback priors is not retried within the same search.
        self._nn_failures = 0
        self._fallback_expanded = False

    def q_value(self):
        """Return the mean backed-up value, or 0.0 for a node that was never visited."""
        visits = self.visit_count
        return self.total_value / visits if visits != 0 else 0.0

class MCTS:
    """PUCT tree search that evaluates leaves through a batched inference server.

    The ``game`` object supplies the rules (``legal_actions``, ``next_state``,
    ``is_terminal``, ``canonical_form``, ``encode_board``, ``action_size``).
    Players are encoded as +1 / -1; the side to move flips sign at every ply.
    Statistics stored on a child node are from the perspective of the player
    who moves at its parent.
    """

    def __init__(self, game, inference_server: "InferenceServer", cpuct=1.0, num_simulations=50,
                 root_dirichlet_alpha=0.3, root_exploration_frac=0.25, add_root_noise=False,
                 history_steps: int = 0, featurizer=None,
                 shaping_config: Optional[Any] = None,
                 phi_fn: Optional[Callable[..., float]] = None):
        """Create a search with an empty root.

        Args:
            game: Rules provider (see class docstring).
            inference_server: Object whose ``submit(x)`` returns ``(event, container)``.
            cpuct: Exploration constant in the PUCT score.
            num_simulations: Simulations run per ``search()`` call.
            root_dirichlet_alpha: Dirichlet concentration for root noise.
            root_exploration_frac: Weight of the noise when mixed into root priors.
            add_root_noise: Whether to perturb root priors at all.
            history_steps: Number of past frames stacked onto each network input
                when no featurizer is given.
            featurizer: Optional object whose ``make_input(enc)`` builds the network input.
            shaping_config: Optional shaping settings; read via ``use_in_mcts`` and ``gamma``.
            phi_fn: Optional potential function forwarded to ``call_phi``.
        """
        self.game = game
        self.root = MCTSNode()
        self.inference = inference_server
        self.cpuct = cpuct
        self.num_simulations = num_simulations
        # Optional: external featurizer to construct inputs with history
        self.featurizer = featurizer
        # Optional: potential-based shaping
        self.shaping = shaping_config
        self.phi_fn = phi_fn
        self.shaping_step = 0  # external coordinator may update for annealing
        # Track which canonical board the current root corresponds to for reuse
        self.root_key = None  # bytes key of canonical board
        # Root exploration noise params
        self.root_dirichlet_alpha = float(root_dirichlet_alpha)
        self.root_exploration_frac = float(root_exploration_frac)
        self.add_root_noise = bool(add_root_noise)
        # internal flag set per search()
        self._root_noise_applied = False
        # History planes
        self.history_steps = max(0, int(history_steps or 0))
        self._root_history = deque(maxlen=self.history_steps)

    def set_root_history(self, frames):
        """Set root history frames: list/deque of (C,H,W) arrays from most recent backward (t-1,...).

        ``None`` entries are skipped. At most ``history_steps`` frames are kept,
        and they are the first (most recent) ones supplied.
        """
        self._root_history.clear()
        if not frames or self.history_steps <= 0:
            return
        for f in frames:
            if f is None:
                continue
            if len(self._root_history) >= self.history_steps:
                # The deque is bounded: appending further would evict from the
                # left, i.e. silently drop the most recent frames.
                break
            self._root_history.append(f)

    def _key(self, canonical_board):
        """Return a bytes key identifying ``canonical_board`` for root reuse."""
        try:
            return canonical_board.tobytes()
        except Exception:
            # Fallback for objects without tobytes(): go through a contiguous array view
            return bytes(memoryview(np.ascontiguousarray(canonical_board)))

    def _mix_root_noise(self, legal_priors):
        """Blend Dirichlet noise into a vector of priors over the legal actions."""
        noise = np.random.dirichlet([self.root_dirichlet_alpha] * len(legal_priors)).astype(np.float32)
        return (1.0 - self.root_exploration_frac) * legal_priors + self.root_exploration_frac * noise

    def _apply_noise_to_existing_root(self, actual_board, player):
        """Re-normalize the priors of an already-expanded root and mix Dirichlet noise into them."""
        if not self.add_root_noise:
            return
        if not self.root.children:
            return
        # Build prior vector over legal actions from existing children
        legal = self.game.legal_actions(actual_board)
        if not legal:
            return
        with self.root.lock:
            p = np.zeros(self.game.action_size(), dtype=np.float32)
            for a, child in self.root.children.items():
                p[a] = max(0.0, float(child.prior))
            s = p[legal].sum()
            if s <= 0:
                p[legal] = 1.0 / float(len(legal))
            else:
                p[legal] /= s
            mixed = self._mix_root_noise(p[legal])
            # Update child priors
            for idx, a in enumerate(legal):
                if a in self.root.children:
                    self.root.children[a].prior = float(mixed[idx])
            self._root_noise_applied = True

    def search(self, board, player):
        """Run ``num_simulations`` simulations from ``board`` and return visit probabilities.

        The existing root is reused if its key matches this position; otherwise
        the tree is reset. Dirichlet noise is optionally added at the root.

        Returns:
            A float32 vector of length ``game.action_size()`` with the normalized
            root visit counts (all zeros if no child was visited).
        """
        # board is the actual game state object. Build canonical for root key only.
        canonical = self.game.canonical_form(board, player)
        key = self._key(canonical)
        if self.root_key != key:
            # New position: start from a fresh root
            self.root = MCTSNode()
            self.root_key = key
        # reset per-search flag
        self._root_noise_applied = False
        # If root already expanded (due to reuse), perturb existing priors once per search
        if self.add_root_noise and self.root.children:
            self._apply_noise_to_existing_root(board, player)
        for _ in range(self.num_simulations):
            self._simulate(board, self.root, player)
        counts = np.zeros(self.game.action_size(), dtype=np.float32)
        for a, child in self.root.children.items():
            counts[a] = child.visit_count
        total = counts.sum()
        probs = counts if total == 0 else counts / total
        return probs

    def advance_root(self, board_before, action, player, next_board=None):
        """Advance the search tree to the child corresponding to (board_before, action, player).

        Optionally provide ``next_board`` to avoid recomputing ``next_state``.
        If the child does not exist, a fresh root is used. If the next position
        cannot be keyed, ``root_key`` is cleared so the next ``search`` resets.
        """
        # Move root pointer to selected child if it exists, else start fresh
        new_root = self.root.children.get(action)
        if new_root is None:
            new_root = MCTSNode()
        self.root = new_root
        # Update root key to match the next canonical position (from next player perspective)
        try:
            if next_board is None:
                nb = self.game.next_state(board_before, action, player)
            else:
                nb = next_board
            next_canonical = self.game.canonical_form(nb, -player)
            self.root_key = self._key(next_canonical)
        except Exception:
            logging.exception("advance_root failed to compute next canonical; resetting root_key")
            self.root_key = None
        # Reset noise flag; noise to be added on next search if enabled
        self._root_noise_applied = False

    def _simulate(self, actual_board, node: "MCTSNode", to_play):
        """Run a single simulation from an actual game state.

        Descends by PUCT until a terminal state or an unexpanded node, expands
        that node with network priors (or uniform priors if inference fails),
        then backs the value up the visited path.

        The network is evaluated on the canonicalized view for the side to
        play; the actual state is used for rules/transitions.

        Returns:
            The value at the point the descent stopped, from the perspective of
            the player to move at the node being entered there.
        """
        board_actual = actual_board
        cur = node
        path = []  # list of (parent_node, action_taken)
        edge_deltas = []  # gamma*phi(s') - phi(s) per traversed edge (optional)
        cur_player = to_play

        while True:
            # Terminal check on actual state
            terminal, winner = self.game.is_terminal(board_actual)
            if terminal:
                # Value from the perspective of the current player at this node
                if winner == 0:
                    value = 0.0
                else:
                    value = 1.0 if winner == cur_player else -1.0
                break

            with cur.lock:
                if len(cur.children) == 0:
                    # Expand leaf: evaluate NN on canonical board from current player's perspective
                    legal = self.game.legal_actions(board_actual)
                    canonical = self.game.canonical_form(board_actual, cur_player)
                    enc = self.game.encode_board(canonical).astype(np.float32)
                    if self.featurizer is not None:
                        # Use provided featurizer's per-episode history
                        x = self.featurizer.make_input(enc)
                    else:
                        x = enc
                        if self.history_steps > 0:
                            x = stack_with_history(x, self._root_history, self.history_steps)
                    ev, container = self.inference.submit(x)
                    t0_wait = time.time()
                    ok = ev.wait(timeout=INFERENCE_WAIT_TIMEOUT)
                    if not ok or container.get('out', None) is None:
                        # Fallback: expand with uniform/legal priors once to keep tree progressing
                        cur._nn_failures += 1
                        if not cur._fallback_expanded and legal:
                            p = np.zeros(self.game.action_size(), dtype=np.float32)
                            p[legal] = 1.0 / float(len(legal))
                            for a in legal:
                                cur.children[a] = MCTSNode(prior=float(p[a]))
                            cur._fallback_expanded = True
                            if cur is node and self.add_root_noise and not self._root_noise_applied:
                                # Apply root noise to the uniform prior
                                mixed = self._mix_root_noise(p[legal])
                                for idx, a in enumerate(legal):
                                    if a in cur.children:
                                        cur.children[a].prior = float(mixed[idx])
                                self._root_noise_applied = True
                            if cur._nn_failures == 1:
                                waited = time.time() - t0_wait
                                logging.warning("Inference failure at leaf (waited %.2fs); using uniform priors over %d legal actions", waited, len(legal))
                        else:
                            logging.error("Inference timeout or failure; returning neutral value")
                        value = 0.0
                        break
                    res = container['out']
                    policy_logits = res.policy_logits
                    # NN value is already from current player's perspective by convention
                    value = float(res.value)

                    # Softmax numerator restricted to legal moves, then renormalized
                    policy = np.exp(policy_logits - np.max(policy_logits))
                    mask = np.zeros_like(policy, dtype=np.float32)
                    mask[legal] = 1.0
                    policy *= mask
                    if policy.sum() == 0:
                        policy = mask
                    s = policy.sum()
                    if s == 0:
                        logging.error("No legal actions available but non-terminal state detected")
                        value = 0.0
                        break
                    policy /= s

                    # Inject Dirichlet noise at the root on first expansion if enabled and not applied yet
                    if (cur is node) and self.add_root_noise and not self._root_noise_applied:
                        mixed = self._mix_root_noise(policy[legal])
                        policy = policy.copy()
                        policy[legal] = mixed
                        self._root_noise_applied = True

                    for a in legal:
                        cur.children[a] = MCTSNode(prior=float(policy[a]))
                    break

                # Select the child maximizing Q + cpuct * U
                best_score = -1e9
                best_action = None
                best_child = None
                total_vis = sum(c.visit_count for c in cur.children.values())
                for a, c in cur.children.items():
                    u = c.prior * math.sqrt(total_vis + 1e-8) / (1 + c.visit_count)
                    q = c.q_value()
                    score = q + self.cpuct * u
                    if score > best_score:
                        best_score = score
                        best_action = a
                        best_child = c

            path.append((cur, best_action))

            try:
                next_actual = self.game.next_state(board_actual, best_action, cur_player)
            except Exception:
                logging.exception("next_state failed: action=%s player=%s", best_action, cur_player)
                # Penalize the illegal transition for the player who chose it.
                # Backprop credits the child reached by the last edge with
                # -value (parent's perspective), so the value must be expressed
                # from the would-be next player's perspective: +1.0 here lands
                # as -1.0 on the offending edge. (Using -1.0 rewarded it.)
                value = 1.0
                # No next state; no shaping delta
                if self.shaping and getattr(self.shaping, 'use_in_mcts', False):
                    edge_deltas.append(0.0)
                break
            # Compute shaping delta for this edge if enabled
            if self.shaping and getattr(self.shaping, 'use_in_mcts', False) and self.phi_fn is not None:
                try:
                    phi_s = float(call_phi(self.phi_fn, self.game, board_actual, cur_player))
                    phi_s2 = float(call_phi(self.phi_fn, self.game, next_actual, -cur_player))
                    d = float(getattr(self.shaping, 'gamma', 1.0)) * phi_s2 - phi_s
                except Exception:
                    d = 0.0
                edge_deltas.append(d)
            # Next node: flip player and continue from the resulting state
            cur_player = -cur_player
            board_actual = next_actual
            cur = best_child

        # Backpropagate. Child totals are interpreted from the PARENT's perspective; start with -value.
        sign = -1.0
        use_shape = bool(self.shaping and getattr(self.shaping, 'use_in_mcts', False) and self.phi_fn is not None)
        acc_shape = 0.0
        shape_scale = 0.0
        gamma = 1.0
        if use_shape:
            shape_scale = float(annealed_scale(self.shaping, getattr(self, 'shaping_step', 0)))
            gamma = float(getattr(self.shaping, 'gamma', 1.0))
        for i, (nd, a) in enumerate(reversed(path)):
            shaped_val = value
            if use_shape and i < len(edge_deltas):
                # Walk the recorded deltas from the deepest edge upward,
                # discounting what has been accumulated so far.
                d = float(edge_deltas[-1 - i])
                acc_shape = shape_scale * d + gamma * acc_shape
                shaped_val = value + acc_shape
            with nd.lock:
                child = nd.children[a]
                child.visit_count += 1
                child.total_value += sign * shaped_val
            sign = -sign

        return value

# ----------------------------
# === Self Play Worker ===
# ----------------------------

# One training sample: network input, MCTS visit distribution, and value target.
GameExample = namedtuple('GameExample', ['board', 'pi', 'value'])


class SelfPlayWorker(threading.Thread):
    """Daemon thread that plays self-play games in a loop and enqueues training examples.

    Each game is played with a fresh ``MCTS`` instance; player +1 always moves
    first. After a game ends, every recorded position is emitted as a
    ``GameExample`` whose value is the outcome from the perspective of the
    player who was to move at that position (optionally plus a shaping term).
    """

    def __init__(self, game_cls, inference_server: "InferenceServer", examples_queue: Queue, num_sims=50, temperature=1.0, history_steps: int = 0, featurizer=None, featurizer_config: Optional[dict] = None,
                 shaping_config: Optional[Any] = None, phi_fn: Optional[Callable[..., float]] = None,
                 root_dirichlet_alpha: float = 0.3, root_exploration_frac: float = 0.25, add_root_noise: bool = True):
        """Store the configuration; the thread is not started here.

        Args:
            game_cls: Rules provider (``get_initial_state``, ``next_state``,
                ``is_terminal``, ``canonical_form``, ``encode_board``, ``legal_actions``).
            inference_server: Passed through to ``MCTS``.
            examples_queue: Queue receiving one ``GameExample`` per recorded position.
            num_sims: MCTS simulations per move.
            temperature: Sampling temperature over visit probabilities; 0 means argmax.
            history_steps: Number of past frames stacked onto each input when no
                featurizer is in use.
            featurizer: Optional pre-built featurizer reused across games.
            featurizer_config: Optional settings used to build a ``TransformerFeaturizer``
                per game when ``featurizer`` is None and that class is importable.
            shaping_config: Optional shaping settings (``use_in_targets``, ``gamma``, ...).
            phi_fn: Optional potential function forwarded to ``call_phi``.
            root_dirichlet_alpha, root_exploration_frac, add_root_noise: Root noise settings.
        """
        super().__init__(daemon=True)
        self.game_cls = game_cls
        self.inference = inference_server
        self.examples_queue = examples_queue
        self.num_sims = num_sims
        self.temperature = temperature
        # NOTE: must not be called `_stop`. threading.Thread defines an internal
        # `_stop()` method on CPython <= 3.12; shadowing it with an Event makes
        # join()/is_alive() raise TypeError once the thread has finished.
        self._stop_event = threading.Event()
        self.history_steps = max(0, int(history_steps or 0))
        self.featurizer = featurizer
        self.featurizer_config = dict(featurizer_config) if featurizer_config else None
        # Optional shaping for MCTS and targets
        self.shaping = shaping_config
        self.phi_fn = phi_fn
        # Root noise config
        self.root_dirichlet_alpha = float(root_dirichlet_alpha)
        self.root_exploration_frac = float(root_exploration_frac)
        self.add_root_noise = bool(add_root_noise)

    def run(self):
        """Play games until ``stop()`` is called, pushing every example onto the queue."""
        logging.info("SelfPlayWorker started (sims=%d, temp=%.2f)", self.num_sims, self.temperature)
        while not self._stop_event.is_set():
            try:
                examples = self.play_game()
                for ex in examples:
                    self.examples_queue.put(ex)
                logging.debug("Enqueued %d examples from one game", len(examples))
            except Exception:
                # Keep the worker alive: log, back off briefly, start a new game.
                logging.exception("SelfPlayWorker crashed during play_game; restarting after short delay")
                time.sleep(0.1)

    def stop(self):
        """Ask the worker to exit after the game currently in progress."""
        self._stop_event.set()

    @staticmethod
    def _assign_outcomes(examples, last_player, result):
        """Fill value targets for a finished game whose positions alternate +1, -1, ...

        ``result`` is the outcome from ``last_player``'s perspective; positions
        where the other player was to move receive ``-result``.
        """
        filled = []
        mover = 1  # player +1 always makes the first move
        for ex in examples:
            val = result if mover == last_player else -result
            filled.append(GameExample(board=ex.board, pi=ex.pi, value=val))
            mover = -mover
        return filled

    def play_game(self):
        """Generate a single self-play game.

        Returns:
            A list of ``GameExample`` (one per move played) with value targets
            filled in. If the chosen action turns out to be illegal, the game is
            scored as a loss for the acting player and returned immediately.
        """
        board = self.game_cls.get_initial_state()
        player = 1
        # Build a fresh per-episode featurizer if configured
        feat = self.featurizer
        if feat is None and self.featurizer_config is not None and TransformerFeaturizer is not None:
            try:
                feat = TransformerFeaturizer(
                    self.game_cls,
                    history_steps=self.featurizer_config.get('history_steps', 0),
                    include_steps_left_plane=self.featurizer_config.get('include_steps_left_plane', False),
                    include_repetition_plane=self.featurizer_config.get('include_repetition_plane', False),
                    include_since_damage_plane=self.featurizer_config.get('include_since_damage_plane', False),
                )
            except Exception:
                logging.exception("Failed to construct TransformerFeaturizer; proceeding without it")
                feat = None
        if feat is not None:
            try:
                feat.reset()
            except Exception:
                # Best effort: a featurizer without a working reset() is still usable.
                logging.debug("Featurizer reset failed", exc_info=True)
        mcts = MCTS(
            self.game_cls,
            self.inference,
            num_simulations=self.num_sims,
            add_root_noise=self.add_root_noise,
            root_dirichlet_alpha=self.root_dirichlet_alpha,
            root_exploration_frac=self.root_exploration_frac,
            history_steps=self.history_steps,
            featurizer=feat,
            shaping_config=self.shaping,
            phi_fn=self.phi_fn,
        )
        examples = []
        turn = 0
        # Track history of encoded canonical boards (most recent first)
        hist = deque(maxlen=self.history_steps)
        # Track repetition based on canonical board hashes
        seen_keys = set()
        # Track actual trajectory for optional shaped targets
        traj_states = []  # states before action
        traj_next_states = []  # states after action
        traj_players = []  # player who acted
        while True:
            # Provide history frames to MCTS (previous encoded canonical frames)
            if self.history_steps > 0:
                mcts.set_root_history(list(hist))
            # Repetition key of this position; only computed when a featurizer consumes it
            key = None
            if feat is not None:
                try:
                    # Steps-left normalization if the game exposes it
                    steps_left_norm = 0.0
                    if hasattr(self.game_cls, 'STEP_CHANNEL') and hasattr(self.game_cls, 'MAX_STEPS'):
                        try:
                            ch = int(getattr(self.game_cls, 'STEP_CHANNEL'))
                            max_steps = float(getattr(self.game_cls, 'MAX_STEPS') or 1.0)
                            val = float(board[ch, 0, 0])
                            steps_left_norm = max(0.0, min(1.0, val / max_steps))
                        except Exception:
                            steps_left_norm = 0.0
                    # Repetition: compare canonical board bytes against seen set
                    try:
                        canonical = self.game_cls.canonical_form(board, player)
                    except Exception:
                        canonical = board if player == 1 else -board
                    try:
                        key = np.ascontiguousarray(canonical).tobytes()
                    except Exception:
                        key = bytes(memoryview(np.ascontiguousarray(canonical)))
                    is_rep = key in seen_keys
                    feat.set_root_context(is_repetition=is_rep, since_last_damage_norm=0.0, steps_left_norm=steps_left_norm)
                except Exception:
                    logging.debug("Failed to update featurizer root context", exc_info=True)
            probs = mcts.search(board, player)
            # temperature sampling
            if self.temperature == 0:
                action = int(np.argmax(probs))
            else:
                probs_temp = probs ** (1.0 / self.temperature)
                if probs_temp.sum() == 0:
                    # No visits recorded: fall back to uniform over legal moves
                    legal = self.game_cls.legal_actions(board)
                    probs_temp = np.zeros_like(probs)
                    probs_temp[legal] = 1.0
                probs_temp /= probs_temp.sum()
                action = int(np.random.choice(len(probs_temp), p=probs_temp))
            canonical = self.game_cls.canonical_form(board, player)
            board_enc = self.game_cls.encode_board(canonical)
            # Store history-stacked board for training to match model channels
            if feat is not None:
                # Keep the featurizer's internal history in sync with training examples
                stacked = feat.make_input(board_enc.astype(np.float32, copy=False))
                examples.append(GameExample(board=stacked, pi=probs.copy(), value=None))
                feat.push(board_enc)
            elif self.history_steps > 0:
                stacked = stack_with_history(board_enc.astype(np.float32, copy=False), hist, self.history_steps)
                examples.append(GameExample(board=stacked, pi=probs.copy(), value=None))
                hist.appendleft(board_enc)
            else:
                examples.append(GameExample(board=board_enc, pi=probs.copy(), value=None))
            # Mark repetition after consuming this position
            if key is not None:
                seen_keys.add(key)
            prev_board = board
            try:
                board = self.game_cls.next_state(board, action, player)
            except Exception:
                logging.exception("Illegal action chosen: action=%s on turn=%d", action, turn)
                # Treat as a terminal loss for the acting player and return filled examples
                return self._assign_outcomes(examples, player, -1.0)
            # Track trajectory edge for optional shaped targets (actor is current player)
            traj_states.append(prev_board)
            traj_next_states.append(board)
            traj_players.append(player)
            # Advance the tree to reuse subtree for the next turn
            try:
                mcts.advance_root(prev_board, action, player, next_board=board)
            except Exception:
                logging.exception("Failed to advance MCTS root; continuing with fresh root")
            terminal, winner = self.game_cls.is_terminal(board)
            if terminal:
                # Outcome from the perspective of the player who just moved
                if winner == 0:
                    result = 0.0
                else:
                    result = 1.0 if winner == player else -1.0
                # If shaping of targets is enabled, compute shaped returns per timestep
                use_shape_tgt = bool(self.shaping and getattr(self.shaping, 'use_in_targets', False) and self.phi_fn is not None)
                if use_shape_tgt and len(traj_players) == len(examples):
                    try:
                        gamma = float(getattr(self.shaping, 'gamma', 1.0))
                        # NOTE(review): the anneal step is fixed at 0 here — confirm this is intended.
                        scale = float(annealed_scale(self.shaping, 0))
                        # compute per-edge deltas d_t = gamma*phi(s') - phi(s)
                        d_list = []
                        for s, s2, p_act in zip(traj_states, traj_next_states, traj_players):
                            try:
                                d = gamma * float(call_phi(self.phi_fn, self.game_cls, s2, -p_act)) - float(call_phi(self.phi_fn, self.game_cls, s, p_act))
                            except Exception:
                                d = 0.0
                            d_list.append(d)
                        # accumulate discounted deltas backward
                        acc = 0.0
                        shaped_add = [0.0] * len(d_list)
                        for i in reversed(range(len(d_list))):
                            acc = scale * float(d_list[i]) + gamma * acc
                            shaped_add[i] = acc
                        filled = []
                        for i, ex in enumerate(examples):
                            base = result if traj_players[i] == player else -result
                            filled.append(GameExample(board=ex.board, pi=ex.pi, value=base + shaped_add[i]))
                    except Exception:
                        logging.exception("Failed to compute shaped targets; falling back to base targets")
                        filled = self._assign_outcomes(examples, player, result)
                else:
                    filled = self._assign_outcomes(examples, player, result)
                return filled
            player = -player
            turn += 1
