"""
Model definitions for AlphaZero. Includes:
    - AlphaNet: A CNN-based architecture for policy/value prediction.
    - TransformerAlphaNet: A transformer-based architecture with attention mechanisms.

"""
import math
from typing import Optional, List, Dict
import torch
import torch.nn as nn
import torch.nn.functional as F

class AlphaNet(nn.Module):
    """Residual CNN that maps a board tensor to policy logits and a scalar value.

    Architecture: a 3x3 conv stem, ``num_res_blocks`` residual blocks
    (conv-BN-ReLU-conv-BN with an identity skip), then a 1x1-conv policy head
    and a 1x1-conv value head.

    Args:
        board_shape: (C, H, W) of the input board tensor.
        action_size: Number of policy logits produced.
        channels: Width of the convolutional trunk; also the value head's hidden size.
        policy_out_channels: Channels produced by the policy head's 1x1 convolution.
        num_res_blocks: Number of residual blocks in the trunk. Defaults to 3 (the
            previously hard-coded depth), which keeps the same parameter names.

    Forward:
        x: (B, C, H, W) -> (policy_logits (B, action_size), value (B,) in [-1, 1])
    """

    def __init__(self, board_shape=(2, 3, 3), action_size=9, channels=64, policy_out_channels: int = 2,
                 num_res_blocks: int = 3):
        super().__init__()
        c, h, w = board_shape

        self.C, self.H, self.W = int(c), int(h), int(w)
        self.policy_out_channels = int(policy_out_channels)
        num_res_blocks = int(num_res_blocks)
        if num_res_blocks < 0:
            raise ValueError(f"num_res_blocks must be >= 0, got {num_res_blocks}")

        self.conv1 = nn.Conv2d(self.C, channels, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(channels)

        self.res_layers = nn.ModuleList([
            nn.Sequential(
                nn.Conv2d(channels, channels, 3, padding=1),
                nn.BatchNorm2d(channels),
                nn.ReLU(),
                nn.Conv2d(channels, channels, 3, padding=1),
                nn.BatchNorm2d(channels),
            ) for _ in range(num_res_blocks)
        ])
        # policy head
        self.policy_conv = nn.Conv2d(channels, self.policy_out_channels, kernel_size=1)
        self.policy_bn = nn.BatchNorm2d(self.policy_out_channels)
        self.policy_fc = nn.Linear(self.policy_out_channels * self.H * self.W, action_size)
        # value head
        self.value_conv = nn.Conv2d(channels, 1, kernel_size=1)
        self.value_bn = nn.BatchNorm2d(1)
        self.value_fc1 = nn.Linear(1 * self.H * self.W, channels)
        self.value_fc2 = nn.Linear(channels, 1)

    def forward(self, x):
        """Return (policy_logits (B, action_size), value (B,)) for a (B, C, H, W) batch."""
        out = F.relu(self.bn1(self.conv1(x)))
        for res in self.res_layers:
            # Residual connection: add the block output to its input, then ReLU.
            out = F.relu(out + res(out))
        # policy
        p = F.relu(self.policy_bn(self.policy_conv(out)))
        # flatten(1) instead of view(): view() raises on non-contiguous
        # (e.g. channels_last) activations, flatten copies only when required.
        p = self.policy_fc(p.flatten(1))
        # value
        v = F.relu(self.value_bn(self.value_conv(out)))
        v = F.relu(self.value_fc1(v.flatten(1)))
        v = torch.tanh(self.value_fc2(v)).squeeze(-1)
        return p, v


class TransformerAlphaNet(nn.Module):
    """
    Transformer-based policy/value network.

    Inputs:
      - x: (B, C, H, W) canonical board tensor

    Outputs:
      - policy_logits: (B, action_size)
      - value: (B,) in [-1, 1]

    Attention capture:
      When ``capture_attention`` is True (the default; it can also be toggled
      later through the attribute of the same name), every forward pass records
      the per-head self-attention probabilities of each encoder layer as
      detached CPU tensors of shape (B, heads, T, T). Read them back with
      ``get_captured_attention``. Cross-attention maps of the optional
      action-token head are kept separately in ``_captured_cross_attn``.
      Capturing costs an extra attention computation and a device->CPU copy per
      layer, so disable it on hot inference paths.
    """
    def __init__(
        self,
        board_shape=(2, 3, 3),
        action_size=9,
        embed_dim=128,
        depth=4,
        num_heads=4,
        mlp_ratio=2.0,
        dropout=0.1,
        use_cls_token=True,
        include_action_tokens=False,
        action_token_in_dim=None,
        cross_attn_layers=1,
        use_sinusoidal_2d_pe=False,
        use_relative_bias=False,
        enable_global_context=False,
        capture_attention=True,
    ):
        super().__init__()
        self.C, self.H, self.W = board_shape
        self.N = self.H * self.W
        self.action_size = action_size
        self.embed_dim = embed_dim
        self.use_cls = use_cls_token

        self.include_action_tokens = bool(include_action_tokens)
        self.cross_attn_layers = max(0, int(cross_attn_layers))
        self.use_sinusoidal_2d_pe = bool(use_sinusoidal_2d_pe)
        self.use_relative_bias = bool(use_relative_bias)
        self.enable_global_context = bool(enable_global_context)
        self.capture_attention = bool(capture_attention)

        # Fail at construction instead of on the first forward pass.
        if self.use_sinusoidal_2d_pe and embed_dim % 4 != 0:
            raise ValueError("embed_dim must be divisible by 4 for 2D sincos positional embeddings")

        # Project per-cell features to token embeddings
        self.proj = nn.Conv2d(self.C, embed_dim, kernel_size=1, bias=True)

        # Positional embeddings
        n_tokens = self.N + (1 if self.use_cls else 0)
        if not self.use_sinusoidal_2d_pe:
            self.pos_embed = nn.Parameter(torch.zeros(1, n_tokens, embed_dim))
            nn.init.trunc_normal_(self.pos_embed, std=0.02)
        else:
            self.register_parameter('pos_embed', None)

        # Optional class token for global aggregation
        if self.use_cls:
            self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
            nn.init.trunc_normal_(self.cls_token, std=0.02)
        else:
            self.cls_token = None

        # Transformer encoder (pre-norm layers)
        dff = int(embed_dim * mlp_ratio)
        enc_layer = nn.TransformerEncoderLayer(
            d_model=embed_dim,
            nhead=num_heads,
            dim_feedforward=dff,
            dropout=dropout,
            activation="gelu",
            batch_first=True,
            norm_first=True,
        )
        self.encoder = nn.TransformerEncoder(enc_layer, num_layers=depth)

        # Heads
        self.head_norm = nn.LayerNorm(embed_dim)
        self.policy_head = nn.Linear(embed_dim, action_size)

        self.value_mlp = nn.Sequential(
            nn.LayerNorm(embed_dim),
            nn.Linear(embed_dim, embed_dim),
            nn.GELU(),
            nn.Linear(embed_dim, 1),
        )

        # Optional global context modules
        if self.enable_global_context:
            self.side_embed = nn.Embedding(2, embed_dim)
            self.turn_proj = nn.Sequential(
                nn.Linear(1, embed_dim), nn.GELU(), nn.Linear(embed_dim, embed_dim)
            )
        else:
            self.side_embed = None
            self.turn_proj = None

        # Optional action-token modules
        if self.include_action_tokens:
            at_in = action_token_in_dim if action_token_in_dim is not None else embed_dim
            self.action_proj = nn.Linear(at_in, embed_dim)
            self.action_norm_q = nn.LayerNorm(embed_dim)
            self.action_attn = nn.ModuleList([
                nn.MultiheadAttention(embed_dim, num_heads, dropout=dropout, batch_first=True)
                for _ in range(self.cross_attn_layers or 1)
            ])
            self.action_ffn = nn.ModuleList([
                nn.Sequential(
                    nn.LayerNorm(embed_dim),
                    nn.Linear(embed_dim, int(embed_dim * mlp_ratio)),
                    nn.GELU(),
                    nn.Linear(int(embed_dim * mlp_ratio), embed_dim),
                )
                for _ in range(self.cross_attn_layers or 1)
            ])
            self.action_policy_head = nn.Linear(embed_dim, 1)
        else:
            self.action_proj = None
            self.action_norm_q = None
            self.action_attn = None
            self.action_ffn = None
            self.action_policy_head = None

        self.apply(self._init_weights)

        # Relative bias tables (2D) if enabled
        if self.use_relative_bias:
            # Learn a scalar bias per 2D displacement (shared across heads)
            self.rel_bias_table = nn.Parameter(torch.zeros((2 * self.H - 1, 2 * self.W - 1)))
            nn.init.trunc_normal_(self.rel_bias_table, std=0.02)
            # Precompute index maps for (i,j) -> (dr_index, dc_index) over spatial tokens
            coords = torch.stack(
                torch.meshgrid(torch.arange(self.H), torch.arange(self.W), indexing='ij'),
                dim=-1,
            ).view(-1, 2)  # (N,2) of (r,c)
            r_i = coords[:, 0].unsqueeze(1)
            c_i = coords[:, 1].unsqueeze(1)
            r_j = coords[:, 0].unsqueeze(0)
            c_j = coords[:, 1].unsqueeze(0)
            dr = (r_i - r_j) + (self.H - 1)  # (N,N)
            dc = (c_i - c_j) + (self.W - 1)  # (N,N)
            # Non-persistent buffers: follow .to(device) but stay out of the state_dict.
            self.register_buffer('rel_bias_dr_idx', dr.long(), persistent=False)
            self.register_buffer('rel_bias_dc_idx', dc.long(), persistent=False)
        else:
            self.rel_bias_table = None
            self.register_buffer('rel_bias_dr_idx', None, persistent=False)
            self.register_buffer('rel_bias_dc_idx', None, persistent=False)

        # Attention maps recorded by the most recent forward pass.
        # Encoder self-attention, indexed by encoder layer; each entry is a
        # detached CPU tensor (B, heads, T, T) or None when not captured.
        self._captured_attn = []  # type: List[Optional[torch.Tensor]]
        # Action-token cross-attention, one (B, heads, L, T) entry per cross-attention layer.
        self._captured_cross_attn = []  # type: List[torch.Tensor]

    def _ensure_attn_capture(self):
        """Backward-compatibility shim; intentionally a no-op.

        Attention maps are now computed inside ``forward`` from each layer's
        actual input. Earlier versions monkey-patched
        ``MultiheadAttention.forward`` instead. That wrapper returned only the
        output tensor where ``TransformerEncoderLayer`` expects a tuple, so its
        ``[0]`` picked batch item 0. It also made the model unpicklable and
        broke ``copy.deepcopy``.
        """
        return None

    @staticmethod
    def _attention_weights(mha, query, key, attn_mask=None, key_padding_mask=None):
        """Recompute the per-head softmax attention probabilities of ``mha``.

        Projects ``query``/``key`` with the module's own input projection,
        takes the scaled dot product, applies the optional masks (bool: True =
        masked out; float: additive), and returns the probabilities *before*
        attention dropout. ``add_bias_kv`` / ``add_zero_attn`` are not modelled.

        Args:
          mha: the ``nn.MultiheadAttention`` whose weights are used.
          query: (B, T, E) if ``mha.batch_first`` else (T, B, E).
          key: (B, S, E) if ``mha.batch_first`` else (S, B, E).
          attn_mask: optional (T, S) or (B * heads, T, S) mask.
          key_padding_mask: optional (B, S) mask.

        Returns:
          Detached CPU tensor of shape (B, heads, T, S).
        """
        if not getattr(mha, 'batch_first', False):
            query = query.transpose(0, 1)
            key = key.transpose(0, 1)
        n_heads = mha.num_heads
        head_dim = mha.embed_dim // n_heads

        if mha.in_proj_weight is not None:
            w_q, w_k, _ = mha.in_proj_weight.chunk(3, dim=0)
        else:
            # Separate projection weights (kdim/vdim differ from embed_dim).
            w_q, w_k = mha.q_proj_weight, mha.k_proj_weight
        if mha.in_proj_bias is not None:
            b_q, b_k, _ = mha.in_proj_bias.chunk(3, dim=0)
        else:
            b_q = b_k = None

        q = F.linear(query, w_q, b_q)
        k = F.linear(key, w_k, b_k)
        B, T, _ = q.shape
        S = k.shape[1]
        q = q.reshape(B, T, n_heads, head_dim).transpose(1, 2)  # (B, heads, T, hd)
        k = k.reshape(B, S, n_heads, head_dim).transpose(1, 2)  # (B, heads, S, hd)
        scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(head_dim)  # (B, heads, T, S)

        if attn_mask is not None:
            if attn_mask.dim() == 3:
                # nn.MultiheadAttention's 3D convention is (B * heads, T, S).
                attn_mask = attn_mask.reshape(B, n_heads, T, S)
            if attn_mask.dtype == torch.bool:
                scores = scores.masked_fill(attn_mask, float('-inf'))
            else:
                scores = scores + attn_mask.to(scores.dtype)
        if key_padding_mask is not None:
            kpm = key_padding_mask.reshape(B, 1, 1, S)
            if kpm.dtype == torch.bool:
                scores = scores.masked_fill(kpm, float('-inf'))
            else:
                scores = scores + kpm.to(scores.dtype)

        return torch.softmax(scores, dim=-1).detach().cpu()

    def _init_weights(self, m):
        """Truncated-normal init for Linear/Conv2d weights, zero biases, identity LayerNorm."""
        if isinstance(m, (nn.Linear, nn.Conv2d)):
            nn.init.trunc_normal_(m.weight, std=0.02)
            if m.bias is not None:
                nn.init.zeros_(m.bias)
        elif isinstance(m, nn.LayerNorm):
            nn.init.zeros_(m.bias)
            nn.init.ones_(m.weight)

    def _add_pos_embed(self, tokens, cur_hw=None):
        """Add positional embeddings to tokens, adapting to current input HxW.

        - Uses learned pos_embed when available; if token count mismatches the
          initialized grid, interpolate to current (h, w).
        - Falls back to 1D interpolation across the token dimension if 2D shape
          cannot be reliably inferred.
        - With 2D sin/cos embeddings, only spatial tokens receive a position;
          the class token (if any) is left unchanged.
        """
        B, T, D = tokens.shape
        n_cls = 1 if self.use_cls else 0

        # Split class token and spatial tokens if applicable
        if self.use_cls:
            cls_tok, spatial = tokens[:, :1], tokens[:, 1:]
        else:
            cls_tok, spatial = None, tokens

        # Determine current spatial grid from caller if provided; else try to infer
        if isinstance(cur_hw, (tuple, list)) and len(cur_hw) == 2:
            cur_h, cur_w = int(cur_hw[0]), int(cur_hw[1])
        else:
            prod = spatial.size(1)
            if self.H * self.W == prod:
                cur_h, cur_w = int(self.H), int(self.W)
            else:
                # Heuristic: the largest divisor of prod not exceeding sqrt(prod).
                best_h = max(1, int(math.sqrt(prod)))
                while best_h > 1 and (prod % best_h) != 0:
                    best_h -= 1
                cur_h, cur_w = best_h, prod // best_h

        if self.use_sinusoidal_2d_pe:
            pos = self._build_2d_sincos(cur_h, cur_w, D, tokens.device).to(tokens.dtype)  # (H*W, D)
            out_spatial = spatial + pos.unsqueeze(0)
            if self.use_cls:
                return torch.cat([cls_tok, out_spatial], dim=1)
            return out_spatial

        # Learned positional embeddings
        if T == (self.N + n_cls):
            # Perfect match with initialized grid
            return tokens + self.pos_embed

        # Interpolate spatial positions to match current HxW if needed
        pe_full = self.pos_embed
        pe_spatial = pe_full[:, 1:] if self.use_cls else pe_full  # (1, N0, D)

        try:
            # 2D interpolation from the initialized HxW grid to cur_h x cur_w
            pe_2d = pe_spatial.reshape(1, self.H, self.W, D).permute(0, 3, 1, 2)  # (1,D,H0,W0)
            pe_interp = F.interpolate(pe_2d, size=(cur_h, cur_w), mode="bilinear", align_corners=False)
            pe_tokens = pe_interp.permute(0, 2, 3, 1).reshape(1, cur_h * cur_w, D)
            out_spatial = spatial + pe_tokens
        except (RuntimeError, ValueError):
            # Fallback: 1D interpolate along token axis to match spatial length
            pe_seq = pe_spatial.permute(0, 2, 1)  # (1, D, T0)
            pe_seq_interp = F.interpolate(pe_seq, size=spatial.size(1), mode="linear", align_corners=False)
            out_spatial = spatial + pe_seq_interp.permute(0, 2, 1)  # (1, T_spatial, D)

        if self.use_cls:
            cls_added = cls_tok + pe_full[:, :1]
            return torch.cat([cls_added, out_spatial], dim=1)
        return out_spatial

    @staticmethod
    def _build_2d_sincos(H, W, D, device):
        """Return (H*W, D) fixed 2D sin/cos position codes: first D/2 encode row, last D/2 column."""
        assert D % 4 == 0, "embed_dim must be divisible by 4 for 2D sincos"
        d = D // 4
        inv = 1.0 / (10000 ** (torch.arange(0, d, device=device).float() / d))
        y = torch.arange(H, device=device).float()
        x = torch.arange(W, device=device).float()
        siny = torch.sin(torch.einsum('i,j->ij', y, inv))
        cosy = torch.cos(torch.einsum('i,j->ij', y, inv))
        sinx = torch.sin(torch.einsum('i,j->ij', x, inv))
        cosx = torch.cos(torch.einsum('i,j->ij', x, inv))
        pos_y = torch.cat([siny, cosy], dim=1)  # (H, 2d)
        pos_x = torch.cat([sinx, cosx], dim=1)  # (W, 2d)
        pos = pos_y[:, None, :].expand(H, W, -1)
        pos = torch.cat([pos, pos_x[None, :, :].expand(H, W, -1)], dim=2)  # (H,W,4d)
        pos = pos.view(H * W, D)
        return pos

    def _relative_attn_bias(self, h, w):
        """Build the additive (T, T) self-attention bias from the learned 2D relative table.

        The index maps were precomputed for the configured board size, so a
        different (h, w) cannot be supported and raises ValueError.
        """
        if (h, w) != (self.H, self.W):
            raise ValueError(
                f"use_relative_bias supports only the configured board size "
                f"{self.H}x{self.W}, got {h}x{w}"
            )
        bias = self.rel_bias_table[self.rel_bias_dr_idx, self.rel_bias_dc_idx]  # (N, N)
        if self.use_cls:
            # Class-token row and column receive zero bias.
            bias = F.pad(bias, (1, 0, 1, 0))  # (N+1, N+1)
        return bias

    def _run_encoder(self, tokens, attn_mask=None):
        """Apply the encoder layers one by one, recording self-attention maps when enabled.

        Mirrors ``self.encoder(tokens, mask=attn_mask)`` (layer loop followed by
        the optional final norm). Driving the layers explicitly exposes each
        layer's attention input without patching ``MultiheadAttention``.
        """
        layers = self.encoder.layers
        self._captured_attn = [None] * len(layers)
        x = tokens
        for li, layer in enumerate(layers):
            if self.capture_attention:
                with torch.no_grad():
                    # Pre-norm layers attend over norm1(x); post-norm layers over x.
                    sa_in = layer.norm1(x) if getattr(layer, 'norm_first', False) else x
                    weights = self._attention_weights(layer.self_attn, sa_in, sa_in, attn_mask)
                layer.self_attn._last_attn = weights
                self._captured_attn[li] = weights
            x = layer(x, src_mask=attn_mask)
        if self.encoder.norm is not None:
            x = self.encoder.norm(x)
        return x

    def _action_token_logits(self, enc, pooled, action_tokens, action_index):
        """Score action tokens by cross-attending to the encoded board, then scatter into policy slots.

        Slots not referenced by ``action_index`` keep a logit of -1e9.
        """
        b = enc.size(0)
        q = self.action_proj(action_tokens)  # (B, L, D)
        for mha, ffn in zip(self.action_attn, self.action_ffn):
            qn = self.action_norm_q(q)
            if self.capture_attention:
                with torch.no_grad():
                    weights = self._attention_weights(mha, qn, enc)  # (B, heads, L, T)
                mha._last_attn = weights
                self._captured_cross_attn.append(weights)
            attn_out, _ = mha(qn, enc, enc, need_weights=False)
            q = q + attn_out
            q = q + ffn(q)
        logits_small = self.action_policy_head(q).squeeze(-1)  # (B, L)
        logits = pooled.new_full((b, self.action_size), -1e9)
        logits.scatter_(1, action_index.long().clamp(0, self.action_size - 1), logits_small)
        return logits

    def forward(self, x, side_to_move=None, turn=None, legal_mask=None,
                action_tokens=None, action_index=None):
        """Compute policy logits and value for a batch of boards.

        Args:
          x: (B, C, H, W) board tensor; C must equal the configured channel count.
          side_to_move: optional per-sample side indicator, clamped to {0, 1}
            before the embedding lookup (used only with enable_global_context).
          turn: optional per-sample scalar turn feature (used only with
            enable_global_context).
          legal_mask: optional (B, action_size) mask; False/0 entries get logit -1e9.
          action_tokens: optional (B, L, action_token_in_dim) features for the
            action-token policy head (requires include_action_tokens=True).
          action_index: optional (B, L) policy-slot index for each action token.

        Returns:
          (logits, value) of shapes (B, action_size) and (B,), value in [-1, 1].
        """
        b, c, h, w = x.shape
        assert c == self.C, f"Expected channels={self.C}, got {c}"
        self._captured_cross_attn = []

        # Per-cell tokens: (B, D, H, W) -> (B, H*W, D), row-major over the board.
        tokens = self.proj(x).flatten(2).transpose(1, 2)
        if self.use_cls:
            tokens = torch.cat([self.cls_token.expand(b, -1, -1), tokens], dim=1)  # (B, N+1, D)
        tokens = self._add_pos_embed(tokens, cur_hw=(h, w))  # (B, N(+1), D)

        attn_bias = self._relative_attn_bias(h, w) if self.use_relative_bias else None
        enc = self._run_encoder(tokens, attn_bias)  # (B, N(+1), D)

        pooled = enc[:, 0] if self.use_cls else enc.mean(dim=1)  # (B, D)

        # Optional global context
        if self.enable_global_context:
            if side_to_move is not None and self.side_embed is not None:
                side = torch.as_tensor(side_to_move, device=pooled.device).long().clamp(0, 1)
                pooled = pooled + self.side_embed(side)
            if turn is not None and self.turn_proj is not None:
                # Match pooled's dtype so half-precision models also work.
                t = torch.as_tensor(turn, device=pooled.device, dtype=pooled.dtype).reshape(-1, 1)
                pooled = pooled + self.turn_proj(t)

        # Policy logits
        use_action_head = (
            self.include_action_tokens
            and action_tokens is not None
            and action_index is not None
            and self.action_proj is not None
        )
        if use_action_head:
            logits = self._action_token_logits(enc, pooled, action_tokens, action_index)
        else:
            logits = self.policy_head(self.head_norm(pooled))  # (B, A)

        if legal_mask is not None:
            logits = logits.masked_fill(~legal_mask.bool(), -1e9)

        value = torch.tanh(self.value_mlp(pooled)).squeeze(-1)  # (B,)
        return logits, value

    def get_captured_attention(self, layer: Optional[int] = None, *, stack: bool = False,
                               clear: bool = False, detach: bool = True):
        """Return encoder self-attention maps captured during the last forward pass.

        Returns:
          - If layer is None: a list of tensors [L0, L1, ...], each shaped (B, heads, T, T)
            (layers without a capture are skipped)
          - If layer is an int: the tensor for that layer/index, or None if out of range

        Args:
          layer: Optional index (negative allowed) into the captured attention list.
          stack: If True and layer is None, attempts to stack the list into a single
            tensor; returns the list if stacking is impossible.
          clear: If True, clears the internal buffer after reading (in both modes).
          detach: If True, returns detached CPU tensors to avoid tying to graph/device.
        """
        attn_list = getattr(self, '_captured_attn', None)
        if not isinstance(attn_list, list):
            return [] if layer is None else None

        def _present(t):
            return isinstance(t, torch.Tensor) and t.numel() > 0

        def _export(t):
            return t.detach().cpu() if detach else t

        if layer is not None:
            idx = int(layer)
            result = None
            if -len(attn_list) <= idx < len(attn_list) and _present(attn_list[idx]):
                result = _export(attn_list[idx])
        else:
            result = [_export(t) for t in attn_list if _present(t)]
            if stack:
                try:
                    result = torch.stack(result, dim=0)
                except RuntimeError:
                    pass  # empty list or mismatched shapes: keep the list

        if clear:
            self._captured_attn = []
        return result

    def get_attention_head(self, layer: int, head: int, batch_idx: int = 0):
        """Get attention weights for a specific layer, head, and batch item.

        Args:
            layer: Layer index
            head: Head index
            batch_idx: Batch index (default 0)

        Returns:
            torch.Tensor of shape (T, T) or None if not available
        """
        attn = self.get_captured_attention(layer=layer)
        if attn is None or attn.dim() != 4:  # expected (B, H, T, T)
            return None
        try:
            return attn[batch_idx, head]  # (T, T)
        except IndexError:
            return None

    def get_attention_summary(self, batch_idx: int = 0):
        """Get a summary of attention patterns across all layers and heads.

        Layer keys (``layer_<i>``) use the true encoder layer index even when
        some layers have no captured map.

        Args:
            batch_idx: Which item in the batch to analyze

        Returns:
            dict with attention statistics
        """
        attn_list = getattr(self, '_captured_attn', None)
        if not isinstance(attn_list, list):
            attn_list = []
        captured = [
            (i, t.detach().cpu()) for i, t in enumerate(attn_list)
            if isinstance(t, torch.Tensor) and t.numel() > 0
        ]
        if not captured:
            return {"error": "No attention weights captured"}

        summary = {
            "num_layers": len(captured),
            "layers": {}
        }

        for layer_idx, attn in captured:
            try:
                if not (attn.dim() == 4 and batch_idx < attn.shape[0]):  # (B, H, T, T)
                    continue
                layer_attn = attn[batch_idx]  # (H, T, T)
                num_heads, seq_len, _ = layer_attn.shape

                # Compute statistics per head
                head_stats = {}
                for h in range(num_heads):
                    head_attn = layer_attn[h]  # (T, T)
                    head_stats[f"head_{h}"] = {
                        "mean": float(head_attn.mean()),
                        "std": float(head_attn.std()),
                        "max": float(head_attn.max()),
                        "min": float(head_attn.min()),
                        "entropy": float(-torch.sum(head_attn * torch.log(head_attn + 1e-12), dim=-1).mean()),
                    }

                summary["layers"][f"layer_{layer_idx}"] = {
                    "num_heads": num_heads,
                    "sequence_length": seq_len,
                    "shape": list(layer_attn.shape),
                    "head_statistics": head_stats,
                    "mean_attention": float(layer_attn.mean()),
                    "attention_variance": float(layer_attn.var()),
                }
            except Exception as e:  # best-effort diagnostics: report per layer, never raise
                summary["layers"][f"layer_{layer_idx}"] = {"error": str(e)}

        return summary
