"""
A new decisive variant BattlefieldDuel with shrinking safe zone, center capture streak, extended shoot range, shorter horizon, and tie-breakers;
added ammo-agnostic heuristic. Registered keys: duel, battlefield_duel, bf_duel. Initial runs show wins occurring (reducing pure draw scenario),
though draws still high (40/50). To further cut draws: shrink faster (SHRINK_INTERVAL 4), lower CAPTURE_STEPS to 2, or reduce MAX_STEPS to ~28;
add minor per-move penalty in MCTS to pressure action. Let me know if you want those adjustments applied.

"""
from random import random
import numpy as np

class BattlefieldDuel:
    """Decisive Battlefield variant designed to reduce draws.

    Mechanics differences vs base Battlefield:
      - Shrinking safe zone: every SHRINK_INTERVAL steps (counted after each action), the safe
        zone shrinks by one layer from all sides. Players outside the safe zone immediately
        lose all remaining health (eliminated). The shrink stops at the innermost layer that
        still contains the center cell, so the zone never becomes empty. An empty zone would
        eliminate both players at once and force a draw before MAX_STEPS is reached.
      - Center capture objective: a player's center streak increases by one after every
        action (by either player) that leaves them on the center cell, and resets to zero
        otherwise. Reaching CAPTURE_STEPS grants an immediate win.
      - Shorter time horizon (MAX_STEPS) and slightly larger shooting range encourage confrontation.
      - Limited obstacles (fewer random blockers) to prevent permanent stalemates.
      - When the step budget runs out, higher health wins. Ties are broken by the smaller
        Manhattan distance to the center; otherwise the game is a draw.

    State tensor channels (CHANNELS=8):
      0: Player 1 position (one-hot)
      1: Player 2 position (one-hot)
      2: Player 1 health (constant plane of current HP value)
      3: Player 2 health
      4: Obstacles (binary)
      5: Steps remaining (uniform plane)
      6: P1 center streak counter (uniform plane)
      7: P2 center streak counter (uniform plane)

    Actions (0-7): Move Up/Down/Left/Right, Shoot Up/Down/Left/Right.
    Players alternate turns; identifiers are 1 and -1.

    Goal is to eliminate the opponent or capture the center.
    """
    ROWS = 10
    COLS = 10
    NUM_OBSTACLES = 8
    MAX_HEALTH = 3
    SHOOT_RANGE = 2
    MAX_STEPS = 36
    CAPTURE_STEPS = 3
    SHRINK_INTERVAL = 6
    CHANNELS = 8

    ACTION_NAMES = [
        "Move Up", "Move Down", "Move Left", "Move Right",
        "Shoot Up", "Shoot Down", "Shoot Left", "Shoot Right"
    ]
    # (row, col) offsets for Up/Down/Left/Right, shared by moves (0-3) and shots (4-7).
    _DIRECTIONS = ((-1, 0), (1, 0), (0, -1), (0, 1))

    @staticmethod
    def action_names():
        """Return a copy of the human-readable action names (index == action id)."""
        return list(BattlefieldDuel.ACTION_NAMES)

    @classmethod
    def configure(cls, board_rows=None, board_cols=None, **overrides):
        """Override core geometry or other tunable class-level parameters.

        Supported (if attributes exist): ROWS, COLS, NUM_OBSTACLES, SHOOT_RANGE, MAX_STEPS,
        CAPTURE_STEPS, SHRINK_INTERVAL, MAX_HEALTH.
        Unknown keys are ignored gracefully. Override values are cast to the type of the
        current attribute when possible (e.g. "12" -> 12); otherwise the raw value is stored.
        A board size that cannot be converted to int is ignored. It does not prevent the
        other dimension or the remaining overrides from being applied.
        """
        for key, value in (("ROWS", board_rows), ("COLS", board_cols)):
            if value is None:
                continue
            try:
                setattr(cls, key, int(value))
            except (TypeError, ValueError, OverflowError):
                pass  # best effort: keep the current geometry
        for key, value in overrides.items():
            if not hasattr(cls, key):
                continue
            try:
                value = type(getattr(cls, key))(value)
            except Exception:
                pass  # not castable: store the raw value
            try:
                setattr(cls, key, value)
            except (AttributeError, TypeError):
                pass  # e.g. read-only class attributes

    @staticmethod
    def action_size():
        """Number of discrete actions (one per entry of ACTION_NAMES, i.e. 8)."""
        return len(BattlefieldDuel.ACTION_NAMES)

    @staticmethod
    def _zone_bounds(steps_left):
        """Return ``(shrink_level, min_r, min_c, max_r, max_c)`` of the safe zone.

        ``shrink_level`` grows by one every SHRINK_INTERVAL elapsed steps. It is capped at
        ``min((ROWS - 1) // 2, (COLS - 1) // 2)``, so the inclusive bounds always contain
        the center cell ``(ROWS // 2, COLS // 2)`` and the zone is never empty.
        """
        rows, cols = BattlefieldDuel.ROWS, BattlefieldDuel.COLS
        elapsed = BattlefieldDuel.MAX_STEPS - int(steps_left)
        shrink = max(0, elapsed // BattlefieldDuel.SHRINK_INTERVAL)
        shrink = min(shrink, (rows - 1) // 2, (cols - 1) // 2)
        return shrink, shrink, shrink, rows - 1 - shrink, cols - 1 - shrink

    @staticmethod
    def get_initial_state(num_obstacles=None):
        """Create a fresh random state.

        P1 starts on the top row and P2 on the bottom row, in different columns when the
        board has more than one column. Obstacles never cover the start cells or the center.

        Args:
          num_obstacles: number of random obstacles. None (default) uses the current,
            possibly configured, ``NUM_OBSTACLES``.

        Returns:
          Integer array of shape (CHANNELS, ROWS, COLS). It is int8 unless the configured
          MAX_STEPS/MAX_HEALTH would not fit, in which case it is int16.
        """
        if num_obstacles is None:
            num_obstacles = BattlefieldDuel.NUM_OBSTACLES
        rows, cols = BattlefieldDuel.ROWS, BattlefieldDuel.COLS
        largest = max(BattlefieldDuel.MAX_STEPS, BattlefieldDuel.MAX_HEALTH)
        dtype = np.int8 if largest <= np.iinfo(np.int8).max else np.int16
        s = np.zeros((BattlefieldDuel.CHANNELS, rows, cols), dtype=dtype)
        # Random columns (ensure not same to reduce early symmetry)
        c1 = np.random.randint(0, cols)
        c2_choices = [c for c in range(cols) if c != c1]
        c2 = np.random.choice(c2_choices) if c2_choices else c1  # single-column board
        s[0, 0, c1] = 1
        s[1, rows - 1, c2] = 1
        s[2, :, :] = BattlefieldDuel.MAX_HEALTH
        s[3, :, :] = BattlefieldDuel.MAX_HEALTH
        s[5, :, :] = BattlefieldDuel.MAX_STEPS
        # Obstacles: avoid starting cells & center; fewer to maintain path connectivity
        forbidden = {(0, c1), (rows - 1, c2), (rows // 2, cols // 2)}
        free = [(r, c) for r in range(rows) for c in range(cols) if (r, c) not in forbidden]
        np.random.shuffle(free)
        for (r, c) in free[:num_obstacles]:
            s[4, r, c] = 1
        return s

    @staticmethod
    def legal_actions(state):
        """All 8 actions are always legal; a blocked move simply leaves the player in place."""
        return list(range(8))

    @staticmethod
    def _locate(state, ch):
        """Return ``(row, col)`` of the first cell equal to 1 in channel ``ch``, or None."""
        idx = np.argwhere(state[ch] == 1)
        return tuple(idx[0]) if idx.size else None

    @staticmethod
    def _apply_move(s, action, player):
        """Move ``player`` one cell (action 0-3) in place on ``s`` if the target cell is free."""
        pos_ch = 0 if player == 1 else 1
        pos = BattlefieldDuel._locate(s, pos_ch)
        if pos is None:
            return
        dr, dc = BattlefieldDuel._DIRECTIONS[action]
        nr, nc = pos[0] + dr, pos[1] + dc
        in_bounds = 0 <= nr < BattlefieldDuel.ROWS and 0 <= nc < BattlefieldDuel.COLS
        # Cannot step onto either player or an obstacle (bounds are checked first so a
        # negative index never wraps around).
        if in_bounds and s[0, nr, nc] == 0 and s[1, nr, nc] == 0 and s[4, nr, nc] == 0:
            s[pos_ch] = 0
            s[pos_ch, nr, nc] = 1

    @staticmethod
    def _apply_shoot(s, action, player):
        """Fire along the direction of ``action`` (4-7) in place on ``s``.

        The shot travels up to SHOOT_RANGE cells. It stops at the board edge or an
        obstacle, and removes one HP from the opponent if it reaches them.
        """
        pos_ch = 0 if player == 1 else 1
        opp_pos_ch = 1 - pos_ch
        opp_health_ch = 3 if player == 1 else 2
        pos = BattlefieldDuel._locate(s, pos_ch)
        if pos is None:
            return
        r, c = pos
        dr, dc = BattlefieldDuel._DIRECTIONS[action - 4]
        for dist in range(1, BattlefieldDuel.SHOOT_RANGE + 1):
            tr, tc = r + dr * dist, c + dc * dist
            if not (0 <= tr < BattlefieldDuel.ROWS and 0 <= tc < BattlefieldDuel.COLS):
                break
            if s[4, tr, tc] == 1:  # obstacle blocks
                break
            if s[opp_pos_ch, tr, tc] == 1:
                s[opp_health_ch, :, :] -= 1
                break

    @staticmethod
    def next_state(state, action, player):
        """Return the state after ``player`` (1 or -1) plays ``action``.

        ``state`` is not modified. Order of resolution:
          1. apply the move or shot;
          2. decrement steps;
          3. update center streaks;
          4. eliminate players outside the (capped) safe zone.

        Raises:
          ValueError: if ``action`` is not in 0..7. Negative values previously
            selected a wrong direction through negative list indexing.
        """
        action = int(action)
        if not 0 <= action < BattlefieldDuel.action_size():
            raise ValueError(f"invalid action {action}; expected 0..7")
        s = state.copy()
        if action <= 3:
            BattlefieldDuel._apply_move(s, action, player)
        else:
            BattlefieldDuel._apply_shoot(s, action, player)
        # Decrement steps
        s[5, :, :] -= 1
        steps_left = int(s[5, 0, 0])
        # Update center streaks (every ply counts, regardless of who moved)
        center = (BattlefieldDuel.ROWS // 2, BattlefieldDuel.COLS // 2)
        p1_pos = BattlefieldDuel._locate(s, 0)
        p2_pos = BattlefieldDuel._locate(s, 1)
        for pos, streak_ch in ((p1_pos, 6), (p2_pos, 7)):
            if pos == center:
                s[streak_ch, :, :] += 1
            else:
                s[streak_ch, :, :] = 0
        # Shrink zone & eliminate anyone left outside it
        shrink_level, min_r, min_c, max_r, max_c = BattlefieldDuel._zone_bounds(steps_left)
        if shrink_level > 0:
            def outside(p):
                if p is None:
                    return False
                r, c = p
                return not (min_r <= r <= max_r and min_c <= c <= max_c)
            if outside(p1_pos):
                s[2, :, :] = 0
            if outside(p2_pos):
                s[3, :, :] = 0
        # Optional debug trace
        try:  # pragma: no cover (debug only)
            import os
            if os.getenv('BF_DUEL_DEBUG'):
                print(f"[BF_DUEL_DEBUG] action={action} player={player} steps_left={steps_left} shrink={shrink_level} p1={p1_pos} p2={p2_pos}")
        except Exception:
            pass
        return s

    @staticmethod
    def is_terminal(state):
        """Return ``(done, winner)`` where winner is 1, -1, 0 (draw) or None (not done).

        Precedence:
          1. center capture (P1 checked first);
          2. elimination (both eliminated -> draw);
          3. once steps run out, higher health wins, then smaller Manhattan distance to
             the center, else draw.
        """
        p1_h = int(state[2, 0, 0])
        p2_h = int(state[3, 0, 0])
        # Center capture
        if int(state[6, 0, 0]) >= BattlefieldDuel.CAPTURE_STEPS:
            return True, 1
        if int(state[7, 0, 0]) >= BattlefieldDuel.CAPTURE_STEPS:
            return True, -1
        if p1_h <= 0 and p2_h <= 0:
            return True, 0
        if p1_h <= 0:
            return True, -1
        if p2_h <= 0:
            return True, 1
        if int(state[5, 0, 0]) > 0:
            return False, None
        if p1_h > p2_h:
            return True, 1
        if p2_h > p1_h:
            return True, -1
        # center distance tie-break
        center = (BattlefieldDuel.ROWS // 2, BattlefieldDuel.COLS // 2)
        p1_pos = BattlefieldDuel._locate(state, 0)
        p2_pos = BattlefieldDuel._locate(state, 1)
        if p1_pos is not None and p2_pos is not None:
            d1 = abs(p1_pos[0] - center[0]) + abs(p1_pos[1] - center[1])
            d2 = abs(p2_pos[0] - center[0]) + abs(p2_pos[1] - center[1])
            if d1 < d2:
                return True, 1
            if d2 < d1:
                return True, -1
        return True, 0

    @staticmethod
    def canonical_form(state, player):
        """Return ``state`` from ``player``'s perspective.

        For player 1 the input array itself is returned (no copy). For player -1 a new
        array is returned with the per-player channels swapped:
            input channels:  P1_pos, P2_pos, P1_health, P2_health, Obstacles, Steps, P1_streak, P2_streak
            output channels: P2_pos, P1_pos, P2_health, P1_health, Obstacles, Steps, P2_streak, P1_streak
        """
        return state if player == 1 else state[[1, 0, 3, 2, 4, 5, 7, 6]]

    @staticmethod
    def encode_board(state):
        """Return the state as float32 (network input)."""
        return state.astype(np.float32)

    @staticmethod
    def symmetries(board, pi):
        """Return only the identity pair; no board symmetries are used for augmentation."""
        return [(board.copy(), pi.copy())]

    @staticmethod
    def render(state):
        """Print an ASCII view of ``state``.

        Symbols: 'x' unsafe cell, '#' obstacle, 'A' player 1, 'B' player 2, '.' empty.
        """
        p1 = BattlefieldDuel._locate(state, 0)
        p2 = BattlefieldDuel._locate(state, 1)
        p1_h = int(state[2, 0, 0]); p2_h = int(state[3, 0, 0])
        steps = int(state[5, 0, 0])
        cst1 = int(state[6, 0, 0]); cst2 = int(state[7, 0, 0])
        shrink_level, min_r, min_c, max_r, max_c = BattlefieldDuel._zone_bounds(steps)
        print("BattlefieldDuel (steps_left=%d shrink=%d):"%(steps,shrink_level))
        print("   "+" ".join(str(c) for c in range(BattlefieldDuel.COLS)))
        for r in range(BattlefieldDuel.ROWS):
            row = []
            for c in range(BattlefieldDuel.COLS):
                if not (min_r <= r <= max_r and min_c <= c <= max_c):
                    ch = 'x'  # unsafe
                elif state[4, r, c] == 1:
                    ch = '#'
                elif p1 is not None and (r, c) == p1:
                    ch = 'A'
                elif p2 is not None and (r, c) == p2:
                    ch = 'B'
                else:
                    ch = '.'
                row.append(ch)
            print(f"{r:>2} "+" ".join(row))
        print(f"HP: P1={p1_h} P2={p2_h} | CenterStreak P1={cst1} P2={cst2}")

    @staticmethod
    def heuristic_value(state, player):
        """Heuristic for Duel variant: emphasize health, center control, approaching opponent, and imminent capture.

        Returns value in [-1,1] relative to `player` (the P1-perspective value, negated
        for player -1).
        """
        try:
            p1_h = float(state[2, 0, 0]); p2_h = float(state[3, 0, 0])
        except Exception:
            return 0.0
        rows, cols = BattlefieldDuel.ROWS, BattlefieldDuel.COLS
        max_d = max(1, (rows - 1) + (cols - 1))  # largest Manhattan distance on the board
        health_term = (p1_h - p2_h) / float(BattlefieldDuel.MAX_HEALTH)
        p1_pos = BattlefieldDuel._locate(state, 0)
        p2_pos = BattlefieldDuel._locate(state, 1)
        center = (rows // 2, cols // 2)

        # Center control (distance inverted)
        def inv_center(p):
            if p is None:
                return 0.0
            d = abs(p[0] - center[0]) + abs(p[1] - center[1])
            return 1.0 - d / max_d
        center_term = inv_center(p1_pos) - inv_center(p2_pos)
        # Capture imminence (streak progress)
        cap_term = (float(state[6, 0, 0]) - float(state[7, 0, 0])) / BattlefieldDuel.CAPTURE_STEPS
        # Distance to opponent: 1 when adjacent (or same), 0 far apart.
        # NOTE(review): this term is symmetric in the players but is added to P1's value
        # only, so it always favors P1 for non-canonical states. Confirm this is intended.
        if p1_pos is not None and p2_pos is not None:
            dist = abs(p1_pos[0] - p2_pos[0]) + abs(p1_pos[1] - p2_pos[1])
            approach_term = 1.0 - dist / max_d
        else:
            approach_term = 0.0
        # Safe zone pressure: the player closer to the zone edge is worse off
        _, min_r, min_c, max_r, max_c = BattlefieldDuel._zone_bounds(int(state[5, 0, 0]))

        def margin(p):
            if p is None:
                return -1
            return min(p[0] - min_r, p[1] - min_c, max_r - p[0], max_c - p[1])
        margin1, margin2 = margin(p1_pos), margin(p2_pos)
        zone_term = 0.0
        if margin1 >= 0 and margin2 >= 0:
            zone_term = (margin2 - margin1) / max(1, rows // 2)

        # Shooting threat: immediate line of sight within SHOOT_RANGE
        def has_los(src, dst):
            if src is None or dst is None:
                return False
            r, c = src
            for dr, dc in BattlefieldDuel._DIRECTIONS:
                for dist in range(1, BattlefieldDuel.SHOOT_RANGE + 1):
                    tr, tc = r + dr * dist, c + dc * dist
                    if not (0 <= tr < rows and 0 <= tc < cols):
                        break
                    if state[4, tr, tc] == 1:
                        break
                    if dst == (tr, tc):
                        return True
            return False
        threat_term = (1 if has_los(p1_pos, p2_pos) else 0) - (1 if has_los(p2_pos, p1_pos) else 0)
        # Weighted sum
        v_p1 = (0.35 * health_term + 0.18 * center_term + 0.12 * cap_term + 0.15 * approach_term +
                0.08 * zone_term + 0.12 * threat_term)
        v_p1 = max(-1.0, min(1.0, v_p1))
        return v_p1 if player == 1 else -v_p1

    @staticmethod
    def _figure_to_rgb(fig, dpi):
        """Render ``fig`` and return its pixels as a uint8 RGB array of shape (H, W, 3)."""
        fig.canvas.draw()
        canvas = fig.canvas
        try:
            if hasattr(canvas, 'buffer_rgba'):
                # np.asarray keeps the buffer's true (h, w, 4) shape. This also holds on HiDPI
                # canvases, where get_width_height() reports smaller logical sizes.
                return np.asarray(canvas.buffer_rgba())[:, :, :3].copy()
            w, h = canvas.get_width_height()
            if hasattr(canvas, 'tostring_argb'):
                argb = np.frombuffer(canvas.tostring_argb(), dtype=np.uint8).reshape(h, w, 4)
                return argb[:, :, 1:4].copy()
            if hasattr(canvas, 'tostring_rgb'):  # deprecated in newer matplotlib
                buf = np.frombuffer(canvas.tostring_rgb(), dtype=np.uint8)
                return buf.reshape(h, w, 3).copy()
            raise RuntimeError('Unsupported matplotlib canvas for frame extraction')
        except Exception:
            # Fallback: round-trip through a temporary PNG, always removing it afterwards.
            import os
            import tempfile
            import imageio.v2 as iio
            fd, tmp_path = tempfile.mkstemp(suffix='.png')
            os.close(fd)
            try:
                fig.savefig(tmp_path, bbox_inches='tight', pad_inches=0.1, dpi=dpi)
                img = np.asarray(iio.imread(tmp_path))
            finally:
                try:
                    os.remove(tmp_path)
                except OSError:
                    pass
            if img.ndim == 3 and img.shape[2] == 4:
                img = img[:, :, :3]  # drop alpha so every path returns RGB
            return img

    @staticmethod
    def display(state, ax=None, show=True, annotate=True, title=True,
        save_path=None, return_frame=False,
        dpi=120, close=True, legend_outside=True,
        action=None, actor=None):
        """Visualize a BattlefieldDuel state using matplotlib.

        Parameters:
          state: np.ndarray game state tensor (CHANNELS x ROWS x COLS)
          ax: optional matplotlib Axes to draw on; if None, create a new figure+axes
          show: call plt.show() after drawing
          annotate: draw text annotations (HP, steps, shrink, center streaks, last action)
          title: display a title with variant name
          save_path: optional path; the figure is saved there with ``dpi``
          return_frame: also return the rendered image as an RGB uint8 array
          dpi: resolution used for saving and for the frame fallback
          close: close the figure at the end, but only when ``fig.stale`` is False
          legend_outside: passed to the legend helper (also widens a newly created figure)
          action: optional action index (0-7). If provided, an arrow is drawn to represent
            the action on the board. Actions 0-3 are moves U/D/L/R; 4-7 are shots U/D/L/R.
          actor: optional player indicator (+1 for P1, -1 for P2). If None and action is provided,
            the last actor is inferred from steps_left parity (P1 on odd move counts).

        Returns:
          ``(fig, ax)``, or ``(fig, ax, frame)`` when ``return_frame`` is True.

        Raises:
          RuntimeError: if matplotlib or the plotting helpers cannot be imported.
        """
        try:
            import matplotlib.pyplot as plt
            from .utils import make_grid, draw_board, draw_action_arrow, add_legend, MoveDirs
        except Exception as e:  # pragma: no cover
            raise RuntimeError("matplotlib required for display(); install it first") from e

        H, W = BattlefieldDuel.ROWS, BattlefieldDuel.COLS

        # Locate players (may be None if eliminated)
        p1 = BattlefieldDuel._locate(state, 0)
        p2 = BattlefieldDuel._locate(state, 1)

        p1_h = int(state[2, 0, 0])
        p2_h = int(state[3, 0, 0])
        steps = int(state[5, 0, 0])
        cst1 = int(state[6, 0, 0])
        cst2 = int(state[7, 0, 0])

        # Safe zone for the current step count
        shrink_level, min_r, min_c, max_r, max_c = BattlefieldDuel._zone_bounds(steps)

        # Obstacles mask and grid
        obs = (state[4] == 1)
        bounds = (min_r, min_c, max_r, max_c) if shrink_level > 0 else None
        grid = make_grid(H, W, obs, bounds, p1, p2)
        center = (H // 2, W // 2)

        if ax is None:
            extra_w = 2.2 if legend_outside else 1.5
            ax = plt.subplots(figsize=(W * 0.8 + extra_w, H * 0.8 + 1.5), constrained_layout=False)[1]
        # Draw board + center marker
        safe_bounds = (min_r, min_c, max_r, max_c) if min_r <= max_r and min_c <= max_c else None
        draw_board(ax, grid, H, W, safe_bounds=safe_bounds, center_marker=center)

        # Overlay explicit markers so agents remain visible even when the board shrinks
        try:
            if p1 is not None:
                r, c = p1
                ax.plot([c], [r], marker='s', markersize=11, mew=1.3,
                        markerfacecolor=(0.2, 0.5, 0.9), markeredgecolor='k', zorder=6)
            if p2 is not None:
                r, c = p2
                ax.plot([c], [r], marker='s', markersize=11, mew=1.3,
                        markerfacecolor=(0.9, 0.4, 0.2), markeredgecolor='k', zorder=6)
        except Exception:
            pass  # markers are cosmetic

        # Infer the last actor once (P1 is assumed to make the odd-numbered moves).
        if action is not None and actor not in (1, -1):
            moves_made = BattlefieldDuel.MAX_STEPS - steps
            actor = 1 if (moves_made % 2 == 1) else -1

        # Optional action arrow overlay
        if action is not None:
            try:
                cur_pos = p1 if actor == 1 else p2
                if cur_pos is not None:
                    color = (0.2, 0.4, 0.9) if actor == 1 else (0.9, 0.3, 0.2)
                    a = int(action)
                    start_pos = cur_pos
                    if 0 <= a <= 3:
                        # Draw moves from the previous cell (assumes the move was not blocked).
                        dr, dc = MoveDirs[a]
                        prev_r, prev_c = cur_pos[0] - dr, cur_pos[1] - dc
                        if 0 <= prev_r < H and 0 <= prev_c < W:
                            start_pos = (prev_r, prev_c)
                    draw_action_arrow(ax, a, color, start_pos, H, W, BattlefieldDuel.SHOOT_RANGE, obs, safe_bounds)
            except Exception:
                pass  # the arrow is cosmetic; never fail the whole render because of it

        # Legend
        add_legend(ax, legend_outside=legend_outside)

        # Annotations
        if annotate:
            txt = f"P1 HP:{p1_h}  P2 HP:{p2_h}  Steps:{steps}  Shrink:{shrink_level}  CenterStreak P1:{cst1} P2:{cst2}"
            if action is not None:
                try:
                    act_name = BattlefieldDuel.ACTION_NAMES[int(action)] if isinstance(action, (int, np.integer)) else str(action)
                except Exception:
                    act_name = str(action)
                side = 'P1' if actor == 1 else 'P2'
                txt += f"  | Last: {side} {act_name}"
            ax.text(0.02, 1.01, txt, transform=ax.transAxes, ha='left', va='bottom', fontsize=9,
                    bbox=dict(facecolor='white', alpha=0.6, edgecolor='none', boxstyle='round,pad=0.2'))

        if title:
            ax.set_title('BattlefieldDuel', pad=18)

        fig = ax.figure
        if save_path is not None:
            fig.savefig(save_path, bbox_inches='tight', pad_inches=0.15, dpi=dpi)
        frame = BattlefieldDuel._figure_to_rgb(fig, dpi) if return_frame else None
        if show:
            try:
                fig.tight_layout()
            except Exception:
                pass  # layout tweaks are optional
            plt.show()
        if close and fig.stale is False:
            plt.close(fig)
        # Always return (fig, ax) for overlay compatibility
        if return_frame:
            return (fig, ax, frame)
        return (fig, ax)

###############################################
# Multi-agent simultaneous variant (2v2)
###############################################

# Primitive per-agent actions for the squad variant (list index == primitive id):
# 0 WAIT | 1 MOVE_N | 2 MOVE_E | 3 MOVE_S | 4 MOVE_W | 5 SHOOT_N | 6 SHOOT_E | 7 SHOOT_S | 8 SHOOT_W
_PRIMS = [
    "WAIT",
    "MOVE_N", "MOVE_E", "MOVE_S", "MOVE_W",
    "SHOOT_N", "SHOOT_E", "SHOOT_S", "SHOOT_W",
]
# (row, col) offsets per move primitive; rows grow downward, so North is (-1, 0).
_MOVE_DELTAS = dict(zip((1, 2, 3, 4), ((-1, 0), (0, 1), (1, 0), (0, -1))))
# Shots reuse the move headings, shifted by four primitive ids (SHOOT_N == MOVE_N + 4, ...).
_SHOOT_DELTAS = {prim + 4: delta for prim, delta in _MOVE_DELTAS.items()}

def _decode_joint(idx:int, agents:int):
    """Split joint action ``idx`` into ``agents`` base-9 digits, least significant first.

    Digit ``i`` is the primitive id (an index into ``_PRIMS``) for agent ``i``.
    """
    digits = []
    remaining = idx
    for _ in range(agents):
        remaining, digit = divmod(remaining, 9)
        digits.append(digit)
    return digits

class BattlefieldDuelSquad2:
    """Battlefield Duel with 2 agents per side and simultaneous per-team actions.

    Mechanics inherited from BattlefieldDuel:
      - Shrinking safe zone, center capture streak, shorter horizon, extended shoot range.
    Differences:
      - Each team has 2 agents that act simultaneously on its turn (joint action space 9^2).
      - No ammo; shoots are always available and resolved after movement.
      - Team HP equals number of alive agents and decreases when an agent is eliminated.

    Channels (C,H,W) with H=ROWS, W=COLS:
      0..1: P1 agent positions (one-hot per agent)
      2..3: P2 agent positions (one-hot per agent)
      4: Obstacles (binary)
      5: P1 team HP (uniform plane)
      6: P2 team HP (uniform plane)
      7: Steps remaining (uniform plane)
      8: P1 center streak (uniform plane)
      9: P2 center streak (uniform plane)
    """
    ROWS = 10
    COLS = 10
    NUM_OBSTACLES = 5
    AGENTS = 2

    # take values from BattlefieldDuel - if required, change
    SHOOT_RANGE = BattlefieldDuel.SHOOT_RANGE
    MAX_STEPS = BattlefieldDuel.MAX_STEPS
    CAPTURE_STEPS = BattlefieldDuel.CAPTURE_STEPS
    SHRINK_INTERVAL = BattlefieldDuel.SHRINK_INTERVAL
    CHANNELS = 10

    @classmethod
    def configure(cls, board_rows=None, board_cols=None, **overrides):
        """Override board geometry and other tunables for Squad2 variant.

        Supported keys (if present): ROWS, COLS, NUM_OBSTACLES, SHOOT_RANGE, MAX_STEPS,
        CAPTURE_STEPS, SHRINK_INTERVAL, AGENTS. Unknown keys are ignored.

        Values are applied to ``cls`` and mirrored onto ``BattlefieldDuelSquad2`` itself.
        The static methods read ``BattlefieldDuelSquad2.<ATTR>`` directly, so without the
        mirror a subclass configuration would not take effect.

        Override values are cast to the type of the current attribute when possible, and
        stored raw otherwise (e.g. ``SHOOT_RANGE=None`` is kept as None). Board sizes are
        converted with ``int``. A size that cannot be converted is stored raw on ``cls``
        only, and no longer stops the other dimension from being mirrored.
        """
        base = BattlefieldDuelSquad2

        def _assign(target, key, value):
            # Cast to the current attribute's type; fall back to the raw value.
            try:
                value = type(getattr(target, key))(value)
            except Exception:
                pass
            try:
                setattr(target, key, value)
            except (AttributeError, TypeError):
                pass  # e.g. read-only attributes; configure is best effort

        # Geometry: each dimension is handled independently.
        for key, raw in (("ROWS", board_rows), ("COLS", board_cols)):
            if raw is None:
                continue
            try:
                value = int(raw)
            except (TypeError, ValueError, OverflowError):
                setattr(cls, key, raw)  # historical behaviour: raw value on cls only
                continue
            setattr(cls, key, value)
            setattr(base, key, value)

        # Other tunables: first on cls, then mirrored onto the base class.
        for target in (cls, base):
            for key, value in overrides.items():
                if hasattr(target, key):
                    _assign(target, key, value)

    @staticmethod
    def action_size():
        """Size of the joint action space: 9 primitives per agent, one base-9 digit per agent."""
        primitives_per_agent = 9
        return pow(primitives_per_agent, BattlefieldDuelSquad2.AGENTS)

    @staticmethod
    def action_name(a:int) -> str:
        """Return a readable label for joint action ``a``, e.g. ``"A0:WAIT | A1:MOVE_N"`` for 9."""
        labels = []
        for agent, prim in enumerate(_decode_joint(a, BattlefieldDuelSquad2.AGENTS)):
            labels.append(f"A{agent}:{_PRIMS[prim]}")
        return " | ".join(labels)

    @staticmethod
    def get_initial_state(num_obstacles=None, seed=None):
        """Create a fresh 2v2 state.

        P1's two agents spawn on the top row and P2's on the bottom row, each team in two
        distinct random columns. Obstacles avoid the spawn cells and the center.

        Args:
          num_obstacles: number of obstacles. If None, it is drawn uniformly from
            0..NUM_OBSTACLES (inclusive) using the same generator as the layout.
          seed: optional seed for ``np.random.default_rng``. With a seed, the whole state,
            including the random obstacle count, is reproducible.

        Returns:
          float32 array of shape (CHANNELS, ROWS, COLS).
        """
        rng = np.random.default_rng(seed)
        H,W = BattlefieldDuelSquad2.ROWS, BattlefieldDuelSquad2.COLS
        s = np.zeros((BattlefieldDuelSquad2.CHANNELS, H, W), dtype=np.float32)
        # Spawn P1 agents on top row, different columns; P2 on bottom row
        cols = list(range(W))
        rng.shuffle(cols)
        p1_cols = sorted(cols[:2])
        rng.shuffle(cols)
        p2_cols = sorted(cols[:2])
        s[0, 0, p1_cols[0]] = 1
        s[1, 0, p1_cols[1]] = 1
        s[2, H-1, p2_cols[0]] = 1
        s[3, H-1, p2_cols[1]] = 1
        # Obstacles
        forbidden = {(0, p1_cols[0]), (0, p1_cols[1]), (H-1, p2_cols[0]), (H-1, p2_cols[1]), (H//2, W//2)}
        free = [(r,c) for r in range(H) for c in range(W) if (r,c) not in forbidden]
        rng.shuffle(free)
        if num_obstacles is None:
            # Drawn from ``rng`` rather than the global NumPy RNG so seeded calls are fully
            # reproducible. It is drawn after the layout shuffles, so for a given seed the
            # spawns and obstacle order are unchanged.
            num_obstacles = int(rng.integers(0, BattlefieldDuelSquad2.NUM_OBSTACLES + 1))
        for (r,c) in free[:num_obstacles]:
            s[4, r, c] = 1
        # HP equals alive agents per team initially = 2
        s[5, :, :] = 2
        s[6, :, :] = 2
        # Steps
        s[7, :, :] = BattlefieldDuelSquad2.MAX_STEPS
        # Streaks: s[8], s[9] stay zero
        return s

    @staticmethod
    def _extract_positions(state):
        """Return ``([p1_agent0, p1_agent1], [p2_agent0, p2_agent1])``.

        Each entry is an ``(row, col)`` tuple of ints taken from the first 1-cell of
        channels 0..3, or None when that agent's channel is empty (eliminated).
        """
        located = []
        for channel in range(4):
            hits = np.argwhere(state[channel] == 1)
            located.append((int(hits[0, 0]), int(hits[0, 1])) if hits.size else None)
        return located[:2], located[2:]

    @staticmethod
    def legal_actions(state, player=None):
        """Return the legal joint actions (``a0 + 9 * a1``) for one team.

        A living agent may WAIT or shoot in any direction. It may also move into an
        in-bounds cell not occupied by an obstacle or by any agent (start-of-turn
        occupancy). An eliminated agent may only WAIT. Moves out of the safe ring are
        not pruned.

        Args:
          state: Squad2 state tensor.
          player: optional team whose agents are used as the template. -1 uses P2's
            agents. None or 1 uses P1's agents, which is the historical behaviour.
            NOTE(review): the P1 template is only right for P2 if callers pass a canonical
            state with the side to move in the P1 channels. Confirm with the callers.
        """
        H,W = BattlefieldDuelSquad2.ROWS, BattlefieldDuelSquad2.COLS
        p1, p2 = BattlefieldDuelSquad2._extract_positions(state)
        # Start-of-turn occupancy: obstacles plus every living agent of both teams.
        occ = state[4] == 1  # fresh boolean array, safe to mutate
        for pos in p1 + p2:
            if pos is not None:
                occ[pos] = True
        team = p2 if player == -1 else p1
        choices = []
        for pos in team:
            acts = [0]  # WAIT is always allowed (and is the only option for a dead agent)
            if pos is not None:
                r, c = pos
                for a, (dr, dc) in _MOVE_DELTAS.items():
                    nr, nc = r + dr, c + dc
                    if 0 <= nr < H and 0 <= nc < W and not occ[nr, nc]:
                        acts.append(a)
                acts.extend([5, 6, 7, 8])  # shoots always allowed
            choices.append(acts)
        # Agent 0 is the low base-9 digit of the joint index (see _decode_joint).
        return [a0 + 9 * a1 for a0 in choices[0] for a1 in choices[1]]

    @staticmethod
    def _ring_bounds(steps_left:int):
        """Return ``(shrink, min_r, min_c, max_r, max_c)`` of the safe ring for ``steps_left``.

        The ring shrinks by one layer every SHRINK_INTERVAL elapsed steps. The shrink is
        capped at ``min((ROWS - 1) // 2, (COLS - 1) // 2)``, so the inclusive bounds always
        keep the innermost layer containing the center cell. Without the cap the ring
        empties late in the game: on a 10x10 board with MAX_STEPS=36 and SHRINK_INTERVAL=6,
        shrink is 5 at steps_left=6, giving min_r=5 > max_r=4. Every agent would then be
        eliminated at once.
        """
        H,W = BattlefieldDuelSquad2.ROWS, BattlefieldDuelSquad2.COLS
        shrink = (BattlefieldDuelSquad2.MAX_STEPS - steps_left) // BattlefieldDuelSquad2.SHRINK_INTERVAL
        shrink = min(shrink, (H - 1) // 2, (W - 1) // 2)
        min_r = min_c = shrink
        max_r = H - 1 - shrink
        max_c = W - 1 - shrink
        return shrink, min_r, min_c, max_r, max_c

    @staticmethod
    def next_state(state, action, player):
        H, W = BattlefieldDuelSquad2.ROWS, BattlefieldDuelSquad2.COLS
        s = state.copy()
        # Decrement steps
        steps_left = int(s[7, 0, 0]) - 1
        s[7, :, :] = steps_left
        # Extract positions and static masks
        p1, p2 = BattlefieldDuelSquad2._extract_positions(s)
        blocks = (s[4] == 1)

        # Safe ring bounds and helper
        shrink, min_r, min_c, max_r, max_c = BattlefieldDuelSquad2._ring_bounds(steps_left)
        def inside(p):
            if p is None:
                return False
            x, y = p
            return (min_r <= x <= max_r) and (min_c <= y <= max_c)

        # Decode acting team's joint action; other team waits
        acts = _decode_joint(int(action), BattlefieldDuelSquad2.AGENTS)
        if player == 1:
            intents_p1 = acts
            intents_p2 = [0] * BattlefieldDuelSquad2.AGENTS  # WAIT
        else:
            intents_p1 = [0] * BattlefieldDuelSquad2.AGENTS
            intents_p2 = acts

        # Occupancy at start of turn (blocks + all agents)
        occ = blocks.copy()
        for pos in p1 + p2:
            if pos is not None:
                occ[pos] = True

        # Resolve intra-team movement with simple collision rule
        def resolve_team(pos_list, intents):
            targets = []
            for pos, a in zip(pos_list, intents):
                if pos is None:
                    targets.append(None)
                elif a in _MOVE_DELTAS:
                    dx, dy = _MOVE_DELTAS[a]
                    nx, ny = pos[0] + dx, pos[1] + dy
                    if 0 <= nx < H and 0 <= ny < W and not occ[nx, ny]:
                        targets.append((nx, ny))
                    else:
                        targets.append(pos)
                else:
                    targets.append(pos)
            # If two agents target same cell and neither stays, both stay
            counts = {}
            for t in targets:
                if t is not None:
                    counts[t] = counts.get(t, 0) + 1
            new_list = []
            for cur, t in zip(pos_list, targets):
                if cur is None:
                    new_list.append(None)
                elif t is None:
                    new_list.append(None)
                elif counts[t] > 1 and t != cur:
                    new_list.append(cur)
                else:
                    new_list.append(t)
            return new_list

        p1_new = resolve_team(p1, intents_p1)
        p2_new = resolve_team(p2, intents_p2)

        # Eliminate agents that end up outside the ring
        def eliminate_outside(pos_list, hp_ch):
            removed = 0
            for i, pos in enumerate(pos_list):
                if pos is not None and not inside(pos):
                    pos_list[i] = None
                    removed += 1
            if removed:
                s[hp_ch, :, :] = max(0, int(s[hp_ch, 0, 0]) - removed)
            return pos_list

        p1_new = eliminate_outside(p1_new, 5)
        p2_new = eliminate_outside(p2_new, 6)

        # Shooting by acting team only
        def ray_hit(start, d, blocks_mask):
            x, y = start
            dx, dy = d
            steps = 0
            while True:
                x += dx
                y += dy
                steps += 1
                if BattlefieldDuelSquad2.SHOOT_RANGE is not None and steps > BattlefieldDuelSquad2.SHOOT_RANGE:
                    return None
                if not (0 <= x < H and 0 <= y < W):
                    return None
                if blocks_mask[x, y]:
                    return None
                if not (min_r <= x <= max_r and min_c <= y <= max_c):
                    return None
                idx = idx_map.get((x, y))
                if idx is not None:
                    return idx

        if player == 1:
            enemy = p2_new
            hp_ch = 6
            shooters = p1_new
            intents = intents_p1
        else:
            enemy = p1_new
            hp_ch = 5
            shooters = p2_new
            intents = intents_p2

        idx_map = {pos: i for i, pos in enumerate(enemy) if pos is not None}
        blocks_mask = (s[4] == 1)
        for pos, a in zip(shooters, intents):
            if pos is None:
                continue
            if a in _SHOOT_DELTAS:
                d = _SHOOT_DELTAS[a]
                hit_idx = ray_hit(pos, d, blocks_mask)
                if hit_idx is not None and enemy[hit_idx] is not None:
                    enemy[hit_idx] = None
                    s[hp_ch, :, :] = max(0, int(s[hp_ch, 0, 0]) - 1)

        # Write back positions after movement and shooting
        s[0:4] = 0
        for i, pos in enumerate(p1_new):
            if pos is not None:
                s[i, pos[0], pos[1]] = 1
        for i, pos in enumerate(p2_new):
            if pos is not None:
                s[2 + i, pos[0], pos[1]] = 1

        # Update center streaks
        center = (H // 2, W // 2)
        if any(pos == center for pos in p1_new if pos is not None):
            s[8, :, :] += 1
        else:
            s[8, :, :] = 0
        if any(pos == center for pos in p2_new if pos is not None):
            s[9, :, :] += 1
        else:
            s[9, :, :] = 0

        return s

    @staticmethod
    def is_terminal(state):
        """Report whether *state* is terminal and, if so, the outcome.

        Decision order:
          1. Center capture: P1's streak (channel 8), then P2's (channel 9),
             reaching ``CAPTURE_STEPS`` wins outright.
          2. Elimination: team HP (channels 5 / 6) at or below zero; both
             teams at zero is a draw.
          3. Time-out (steps-left channel 7 at or below zero): more HP wins;
             on equal HP the side whose closest agent has the smaller Manhattan
             distance to the center cell wins; otherwise a draw.

        The center cell is ``(H // 2, W // 2)`` taken over the last two axes
        of *state* rather than the ``ROWS``/``COLS`` class constants. The
        tie-break therefore measures against the board that is actually
        stored, even when those constants do not match the state.

        Returns:
            ``(True, 1)`` P1 wins, ``(True, -1)`` P2 wins, ``(True, 0)`` draw,
            or ``(False, None)`` if the game continues.
        """
        capture = BattlefieldDuelSquad2.CAPTURE_STEPS
        if int(state[8, 0, 0]) >= capture:
            return True, 1
        if int(state[9, 0, 0]) >= capture:
            return True, -1

        p1_h = int(state[5, 0, 0])
        p2_h = int(state[6, 0, 0])
        if p1_h <= 0 and p2_h <= 0:
            return True, 0
        if p1_h <= 0:
            return True, -1
        if p2_h <= 0:
            return True, 1

        if int(state[7, 0, 0]) > 0:
            return False, None

        # Time is up: winner by HP first.
        if p1_h > p2_h:
            return True, 1
        if p2_h > p1_h:
            return True, -1

        # Center proximity tie-break, using the state's own geometry.
        rows, cols = state.shape[-2], state.shape[-1]
        center_r, center_c = rows // 2, cols // 2

        def closest(channels):
            """Smallest center distance among a team's agents; 999 if none is on the board."""
            dists = []
            for ch in channels:
                idx = np.argwhere(state[ch] == 1)
                if idx.size:
                    r, c = int(idx[0, 0]), int(idx[0, 1])
                    dists.append(abs(r - center_r) + abs(c - center_c))
            return min(dists) if dists else 999

        d1 = closest((0, 1))
        d2 = closest((2, 3))
        if d1 < d2:
            return True, 1
        if d2 < d1:
            return True, -1
        return True, 0

    @staticmethod
    def canonical_form(state, player):
        """Swap player 1 and player 2 channels, and their health and center streak channels"""
        if player == 1:
            return state
        order = [2,3,0,1, 4, 6, 5, 7, 9,8]
        return state[order]

    @staticmethod
    def encode_board(state):
        return state.astype(np.float32)

    @staticmethod
    def symmetries(board, pi):
        return [(board.copy(), pi.copy())]

    @staticmethod
    def action_names():
        """Return a fresh list with the name of every per-agent primitive in ``_PRIMS``."""
        return [*_PRIMS]

    @staticmethod
    def render(state):
        """Print an ASCII view of *state* followed by a team-HP line.

        Legend: ``x`` outside the safe ring, ``#`` obstacle, ``A`` P1 agent,
        ``B`` P2 agent, ``.`` empty cell.

        The safe ring comes from ``_ring_bounds``, the same helper used by
        ``next_state`` and ``display``. The picture therefore matches the zone
        that actually eliminates agents, instead of a separately recomputed ring.
        """
        H, W = BattlefieldDuelSquad2.ROWS, BattlefieldDuelSquad2.COLS
        p1, p2 = BattlefieldDuelSquad2._extract_positions(state)
        steps = int(state[7, 0, 0])
        shrink, min_r, min_c, max_r, max_c = BattlefieldDuelSquad2._ring_bounds(steps)
        # Build the occupancy sets once instead of rebuilding a list per cell.
        p1_cells = {p for p in p1 if p is not None}
        p2_cells = {p for p in p2 if p is not None}
        print(f"DuelSquad2 (steps={steps} shrink={shrink})")
        for r in range(H):
            row = []
            for c in range(W):
                if not (min_r <= r <= max_r and min_c <= c <= max_c):
                    ch = 'x'
                elif state[4, r, c] == 1:
                    ch = '#'
                elif (r, c) in p1_cells:
                    ch = 'A'
                elif (r, c) in p2_cells:
                    ch = 'B'
                else:
                    ch = '.'
                row.append(ch)
            print(" ".join(row))
        print(f"HP P1={int(state[5,0,0])} P2={int(state[6,0,0])}")

    @staticmethod
    def display(state, ax=None, show=True, annotate=True, title=True,
                save_path=None, return_frame=False, dpi=120, close=True,
                legend_outside=True, action=None, actor=None):
        """Visualize a BattlefieldDuelSquad2 state using matplotlib, leveraging shared utils.

        Args:
            state: game state; agents in channels 0-3, obstacles 4, HP 5/6,
                steps left 7, center streaks 8/9.
            ax: existing matplotlib Axes to draw on; a new figure is created if None.
            show: run ``tight_layout`` and ``plt.show()`` after drawing.
            annotate: add an HP / steps / shrink / center-streak caption.
            title: set the axes title.
            save_path: if given, save the figure there with ``bbox_inches='tight'``.
            return_frame: also return the rendered RGB pixels.
            dpi: resolution used when saving.
            close: close the figure at the end, but only if it has already been
                rendered (``fig.stale is False``).
            legend_outside: forwarded to ``add_legend``; also widens a new figure.
            action: last joint action. Either an integer id (Python ``int`` or a
                NumPy integer such as the result of ``np.argmax``) or a per-agent
                list/tuple of primitives. Moves/shots are drawn as arrows and
                WAIT as a cross.
            actor: side that played *action* (1 or -1). When omitted it is
                inferred from the number of elapsed steps.

        Returns:
            ``ax``, or ``(ax, frame)`` when *return_frame* is true. *frame* is an
            ``(h, w, 3)`` uint8 array, or ``None`` if the canvas pixels could not be read.

        Raises:
            RuntimeError: if matplotlib or the ``.utils`` drawing helpers cannot be imported.
        """
        try:
            import matplotlib.pyplot as plt
            from .utils import make_grid, draw_board, draw_joint_action_arrows, add_legend
        except Exception as e:  # pragma: no cover
            raise RuntimeError("matplotlib required for display(); install it first") from e

        H, W = BattlefieldDuelSquad2.ROWS, BattlefieldDuelSquad2.COLS
        # Squad2 primitives (1..4 moves N,E,S,W; 5..8 shoots N,E,S,W) mapped to the
        # indices the utils arrow helpers expect (0..3 moves U,D,L,R; 4..7 shoots U,D,L,R).
        to_utils = {1: 0, 2: 3, 3: 1, 4: 2, 5: 4, 6: 7, 7: 5, 8: 6}

        def locs(ch0, ch1):
            """First marked cell of each channel as (row, col), or None if the channel is empty."""
            out = []
            for ch in (ch0, ch1):
                idx = np.argwhere(state[ch] == 1)
                out.append((int(idx[0, 0]), int(idx[0, 1])) if idx.size else None)
            return out

        def grab_frame(figure):
            """Render *figure* and return an (h, w, 3) uint8 RGB array, or None."""
            canvas = figure.canvas
            canvas.draw()
            w, h = canvas.get_width_height()
            try:
                if hasattr(canvas, 'buffer_rgba'):
                    # np.asarray keeps the buffer's own (physical-pixel) shape.
                    # Reshaping to get_width_height() breaks when the device
                    # pixel ratio is not 1.
                    rgba = np.asarray(canvas.buffer_rgba())
                    if rgba.ndim != 3:
                        rgba = rgba.reshape(h, w, 4)
                    return rgba[:, :, :3].copy()
                if hasattr(canvas, 'tostring_rgb'):  # older matplotlib
                    buf = np.frombuffer(canvas.tostring_rgb(), dtype=np.uint8)
                    return buf.reshape(h, w, 3).copy()
                if hasattr(canvas, 'tostring_argb'):
                    argb = np.frombuffer(canvas.tostring_argb(), dtype=np.uint8).reshape(h, w, 4)
                    return argb[:, :, 1:4].copy()
            except Exception:
                # Best effort: an export failure yields no frame rather than an error.
                return None
            return None

        p1 = locs(0, 1)
        p2 = locs(2, 3)
        steps = int(state[7, 0, 0])
        p1_h = int(state[5, 0, 0])
        p2_h = int(state[6, 0, 0])
        cst1 = int(state[8, 0, 0])
        cst2 = int(state[9, 0, 0])

        shrink, min_r, min_c, max_r, max_c = BattlefieldDuelSquad2._ring_bounds(steps)
        obs = (state[4] == 1)
        bounds = (min_r, min_c, max_r, max_c) if shrink > 0 else None
        grid = make_grid(H, W, obs, bounds, p1, p2)
        center = (H // 2, W // 2)

        if ax is None:
            extra_w = 2.2 if legend_outside else 1.5
            fig, ax = plt.subplots(figsize=(W * 0.8 + extra_w, H * 0.8 + 1.5), constrained_layout=False)

        safe_bounds = (min_r, min_c, max_r, max_c) if min_r <= max_r and min_c <= max_c else None
        draw_board(ax, grid, H, W, safe_bounds=safe_bounds, center_marker=center)

        if actor not in (1, -1):
            # P1 moves first, so an odd number of elapsed steps means P1 acted last.
            moves_made = BattlefieldDuelSquad2.MAX_STEPS - steps
            actor = 1 if (moves_made % 2 == 1) else -1

        # Per-agent primitives for the arrows; NumPy integer ids are accepted
        # as well, matching the caption logic below.
        if isinstance(action, (int, np.integer)):
            acts = _decode_joint(int(action), BattlefieldDuelSquad2.AGENTS)
        elif isinstance(action, (list, tuple)):
            acts = list(action)
        else:
            acts = None
        if acts is not None:
            team_pos = p1 if actor == 1 else p2
            color = (0.2, 0.4, 0.9) if actor == 1 else (0.9, 0.3, 0.2)
            arrow_positions, arrow_actions, wait_positions = [], [], []
            for pos, a in zip(team_pos, acts):
                if pos is None:
                    continue
                a = int(a)
                if a == 0:  # 0 means WAIT in Squad2
                    wait_positions.append(pos)
                    continue
                start = pos
                if a in _MOVE_DELTAS:
                    # Positions are post-move; step back one delta so the arrow
                    # starts on the cell the agent moved from (when on-board).
                    dr, dc = _MOVE_DELTAS[a]
                    sr, sc = pos[0] - dr, pos[1] - dc
                    if 0 <= sr < H and 0 <= sc < W:
                        start = (sr, sc)
                arrow_positions.append(start)
                arrow_actions.append(to_utils.get(a, a))
            if arrow_actions:
                draw_joint_action_arrows(ax, arrow_actions, arrow_positions, color, H, W,
                                         BattlefieldDuelSquad2.SHOOT_RANGE, obs, safe_bounds)
            # Draw a small cross for agents that WAIT
            for (r, c) in wait_positions:
                try:
                    ax.plot([c], [r], marker='x', markersize=9, mew=2.0, color='white', zorder=6)
                    ax.plot([c], [r], marker='x', markersize=7, mew=1.6, color=color, zorder=7)
                except Exception:
                    pass

        # legend & annotations
        add_legend(ax, legend_outside=legend_outside)
        if annotate:
            txt = f"P1 HP:{p1_h}  P2 HP:{p2_h}  Steps:{steps}  Shrink:{shrink}  CenterStreak P1:{cst1} P2:{cst2}"
            if action is not None:
                try:
                    if isinstance(action, (int, np.integer)):
                        act_name = BattlefieldDuelSquad2.action_name(int(action))
                    elif isinstance(action, (list, tuple)):
                        act_name = " | ".join(_PRIMS[int(a)] if 0 <= int(a) < len(_PRIMS) else str(a) for a in action)
                    else:
                        act_name = str(action)
                except Exception:
                    act_name = str(action)
                side = 'P1' if actor == 1 else 'P2'
                txt += f"  | Last: {side} {act_name}"
            ax.text(0.02, 1.01, txt, transform=ax.transAxes, ha='left', va='bottom', fontsize=9,
                    bbox=dict(facecolor='white', alpha=0.6, edgecolor='none', boxstyle='round,pad=0.2'))

        if title:
            ax.set_title('BattlefieldDuelSquad2', pad=18)

        fig = ax.figure
        frame = None
        if save_path is not None:
            fig.savefig(save_path, bbox_inches='tight', pad_inches=0.15, dpi=dpi)
        if return_frame:
            frame = grab_frame(fig)
        if show:
            try:
                fig.tight_layout()
            except Exception:
                pass
            try:
                plt.show()
            except Exception:
                pass
        # Only figures that have been rendered (not stale) are closed; an
        # unrendered figure is left open.
        # NOTE(review): this also closes a figure owned by a caller-supplied ax; confirm intent.
        if close and ax is not None and fig is not None and fig.stale is False:
            try:
                plt.close(fig)
            except Exception:
                pass
        return (ax, frame) if return_frame else ax

###############################################
# Squad variant with ammunition tracking
###############################################

class BattlefieldDuelSquadAmmo(BattlefieldDuelSquad2):
    """Squad2 variant with limited ammunition per agent.

    Adds one extra plane, at index ``BattlefieldDuelSquad2.CHANNELS``, that
    stores the remaining ammo at each living agent's cell (0 elsewhere).

    - Ammo decreases by 1 whenever an agent fires a shot, hit or miss.
    - Agents with ammo == 0 cannot perform shooting primitives.
    - Ammo moves with the agent as it moves and disappears with a dead agent.

    Every method here computes the board size (H, W) and the shrinking safe
    ring from the state array itself, via ``_state_hw`` and ``_state_ring``.
    The rules, the ASCII render and the matplotlib display therefore agree
    even if the class ROWS/COLS constants are stale.
    """

    # Inherit base geometry and rules; add ammo configuration
    AMMO_INIT = 3
    CHANNELS = BattlefieldDuelSquad2.CHANNELS + 1  # +1 ammo plane
    # Radix of the joint-action id: agent i's primitive (0..8) is weighted by
    # 9**i, i.e. ``a0 + 9 * a1`` for two agents.
    _JOINT_BASE = 9
    # Squad2 primitives (1..4 moves N,E,S,W; 5..8 shoots N,E,S,W) mapped to the
    # indices the shared arrow helpers expect (0..3 moves U,D,L,R; 4..7 shoots U,D,L,R).
    _UTILS_ACTION_MAP = {1: 0, 2: 3, 3: 1, 4: 2, 5: 4, 6: 7, 7: 5, 8: 6}

    # ------------------------------------------------------------------
    # Geometry helpers shared by rules and rendering
    # ------------------------------------------------------------------
    @staticmethod
    def _state_hw(state):
        """Return (H, W) from the last two axes of *state*; class ROWS/COLS if unavailable."""
        try:
            return int(state.shape[-2]), int(state.shape[-1])
        except Exception:
            return BattlefieldDuelSquadAmmo.ROWS, BattlefieldDuelSquadAmmo.COLS

    @staticmethod
    def _state_ring(steps_left, H, W):
        """Return (shrink, min_r, min_c, max_r, max_c) of the safe ring on an H x W board.

        The ring loses one layer per ``SHRINK_INTERVAL`` elapsed steps. No
        clamping is applied, so a fully collapsed ring has min > max.
        """
        shrink = (BattlefieldDuelSquad2.MAX_STEPS - steps_left) // BattlefieldDuelSquad2.SHRINK_INTERVAL
        return shrink, shrink, shrink, H - 1 - shrink, W - 1 - shrink

    @staticmethod
    def _grab_rgb_frame(fig):
        """Render *fig* and return an (h, w, 3) uint8 RGB array, or None if unsupported."""
        canvas = fig.canvas
        canvas.draw()
        w, h = canvas.get_width_height()
        try:
            if hasattr(canvas, 'buffer_rgba'):
                # np.asarray keeps the buffer's own (physical-pixel) shape.
                # Reshaping to get_width_height() breaks when the device pixel
                # ratio is not 1.
                rgba = np.asarray(canvas.buffer_rgba())
                if rgba.ndim != 3:
                    rgba = rgba.reshape(h, w, 4)
                return rgba[:, :, :3].copy()
            if hasattr(canvas, 'tostring_rgb'):  # older matplotlib
                buf = np.frombuffer(canvas.tostring_rgb(), dtype=np.uint8)
                return buf.reshape(h, w, 3).copy()
            if hasattr(canvas, 'tostring_argb'):
                argb = np.frombuffer(canvas.tostring_argb(), dtype=np.uint8).reshape(h, w, 4)
                return argb[:, :, 1:4].copy()
        except Exception:
            # Best effort: an export failure yields no frame rather than an error.
            return None
        return None

    # ------------------------------------------------------------------
    # Game API
    # ------------------------------------------------------------------
    @staticmethod
    def get_initial_state(num_obstacles=None, seed=None):
        """Build the Squad2 initial state and append the ammo plane.

        Args:
            num_obstacles: obstacle count. ``None`` uses this class's
                ``NUM_OBSTACLES``, so a subclass value set via ``configure()``
                is respected.
            seed: forwarded to ``BattlefieldDuelSquad2.get_initial_state``.

        Returns:
            float32 array of shape ``(CHANNELS, H, W)``. H and W are taken
            from the parent state, and the last plane holds ``AMMO_INIT`` at
            each agent's starting cell.
        """
        if num_obstacles is None:
            num_obstacles = BattlefieldDuelSquadAmmo.NUM_OBSTACLES
        base = BattlefieldDuelSquad2.get_initial_state(num_obstacles=num_obstacles, seed=seed)
        # Derive actual geometry from produced state to avoid stale class constants
        H, W = BattlefieldDuelSquadAmmo._state_hw(base)
        ammo_ch = BattlefieldDuelSquad2.CHANNELS
        out = np.zeros((BattlefieldDuelSquadAmmo.CHANNELS, H, W), dtype=np.float32)
        out[:ammo_ch] = base
        # Place ammo at agent positions (2 per side)
        for ch in (0, 1, 2, 3):
            idx = np.argwhere(base[ch] == 1)
            if idx.size:
                out[ammo_ch, int(idx[0, 0]), int(idx[0, 1])] = BattlefieldDuelSquadAmmo.AMMO_INIT
        return out

    @staticmethod
    def _extract_ammo(state):
        """Return ammo list for P1 agents then P2 agents by reading ammo plane at agent cells.

        Missing (dead) agents report 0.
        """
        ammo_ch = BattlefieldDuelSquad2.CHANNELS

        def get_at(pos):
            if pos is None:
                return 0
            return int(state[ammo_ch, pos[0], pos[1]])

        p1, p2 = BattlefieldDuelSquad2._extract_positions(state)
        return [get_at(p) for p in p1], [get_at(p) for p in p2]

    @staticmethod
    def legal_actions(state):
        """Return legal joint-action ids for the team stored in agent channels 0-1.

        Per agent: WAIT (0) is always offered, and it is the only option for a
        dead agent. A move is offered when its target is on the board and was
        unoccupied (no obstacle, no agent) at the start of the turn. Shooting
        primitives (5-8) are offered only when the agent's ammo is positive.

        Joint ids are ``sum(a_i * 9**i)`` over agents; for two agents this is
        ``a0 + 9 * a1``, listed with agent 0 varying slowest.

        NOTE(review): only channels 0-1 are inspected, so this presumably
        expects a canonical-form state for the side to move; confirm with caller.
        """
        H, W = BattlefieldDuelSquadAmmo._state_hw(state)
        ammo_ch = BattlefieldDuelSquad2.CHANNELS
        p1, p2 = BattlefieldDuelSquad2._extract_positions(state)
        occ = (state[4] == 1)  # fresh boolean array, safe to mark in place
        for pos in p1 + p2:
            if pos is not None:
                occ[pos] = True

        per_agent = []
        for pos in p1:
            acts = [0]  # WAIT
            if pos is not None:
                x, y = pos
                for a, (dx, dy) in _MOVE_DELTAS.items():
                    nx, ny = x + dx, y + dy
                    if 0 <= nx < H and 0 <= ny < W and not occ[nx, ny]:
                        acts.append(a)
                # Shoots allowed only if ammo > 0 at current cell
                if int(state[ammo_ch, x, y]) > 0:
                    acts.extend([5, 6, 7, 8])
            per_agent.append(acts)

        # Cartesian product in the same order as nested loops (agent 0 outermost).
        base = BattlefieldDuelSquadAmmo._JOINT_BASE
        joint = [0]
        for i, acts in enumerate(per_agent):
            weight = base ** i
            joint = [prefix + a * weight for prefix in joint for a in acts]
        return joint

    @staticmethod
    def next_state(state, action, player):
        """Apply *player*'s joint *action* and return the resulting state (input untouched).

        Turn order matches Squad2:
          1. decrement steps;
          2. move the acting team (the other team waits);
          3. eliminate agents outside the safe ring;
          4. the acting team shoots;
          5. write back positions;
          6. update center streaks.

        In addition, ammo is tracked. A shooting primitive is ignored when the
        agent has no ammo; otherwise it costs one round whether or not it hits.
        Finally the ammo plane is rebuilt at the surviving agents' new cells.

        Args:
            state: array laid out as produced by ``get_initial_state``.
            action: joint action id, decoded with ``_decode_joint``.
            player: 1 or -1, the acting team.
        """
        H, W = BattlefieldDuelSquadAmmo._state_hw(state)
        n_agents = BattlefieldDuelSquad2.AGENTS
        ammo_ch = BattlefieldDuelSquad2.CHANNELS
        s = state.copy()

        # Decrement steps
        steps_left = int(s[7, 0, 0]) - 1
        s[7, :, :] = steps_left

        # Positions and static obstacle mask (s[4] is not modified below)
        p1, p2 = BattlefieldDuelSquad2._extract_positions(s)
        blocks = (s[4] == 1)

        # Snapshot each agent's ammo before anyone moves (dead agents hold 0)
        def read_ammo(pos_list):
            return [0 if pos is None else int(s[ammo_ch, pos[0], pos[1]]) for pos in pos_list]

        ammo_p1 = read_ammo(p1)
        ammo_p2 = read_ammo(p2)

        _, min_r, min_c, max_r, max_c = BattlefieldDuelSquadAmmo._state_ring(steps_left, H, W)

        def inside(p):
            return p is not None and (min_r <= p[0] <= max_r) and (min_c <= p[1] <= max_c)

        # Decode acting team's joint action; other team waits
        acts = _decode_joint(int(action), n_agents)
        if player == 1:
            intents_p1, intents_p2 = acts, [0] * n_agents
        else:
            intents_p1, intents_p2 = [0] * n_agents, acts

        # Occupancy at start of turn (obstacles + all agents)
        occ = blocks.copy()
        for pos in p1 + p2:
            if pos is not None:
                occ[pos] = True

        def resolve_team(pos_list, intents):
            """Move one team. Blocked or off-board moves stay put, and agents
            contesting the same target cell both stay put."""
            targets = []
            for pos, a in zip(pos_list, intents):
                if pos is not None and a in _MOVE_DELTAS:
                    dx, dy = _MOVE_DELTAS[a]
                    nx, ny = pos[0] + dx, pos[1] + dy
                    if 0 <= nx < H and 0 <= ny < W and not occ[nx, ny]:
                        targets.append((nx, ny))
                        continue
                targets.append(pos)
            counts = {}
            for t in targets:
                if t is not None:
                    counts[t] = counts.get(t, 0) + 1
            new_list = []
            for cur, t in zip(pos_list, targets):
                if cur is None or t is None:
                    new_list.append(None)
                elif counts[t] > 1 and t != cur:
                    new_list.append(cur)
                else:
                    new_list.append(t)
            return new_list

        p1_new = resolve_team(p1, intents_p1)
        p2_new = resolve_team(p2, intents_p2)

        # Eliminate agents outside ring (their ammo is dropped on rebuild)
        def eliminate_outside(pos_list, hp_ch):
            removed = 0
            for i, pos in enumerate(pos_list):
                if pos is not None and not inside(pos):
                    pos_list[i] = None
                    removed += 1
            if removed:
                s[hp_ch, :, :] = max(0, int(s[hp_ch, 0, 0]) - removed)
            return pos_list

        p1_new = eliminate_outside(p1_new, 5)
        p2_new = eliminate_outside(p2_new, 6)

        # Shooting by acting team only; enforce ammo > 0.
        # enemy / ammo_team alias p*_new / ammo_p*, so kills and spent rounds
        # are visible to the write-back below.
        if player == 1:
            enemy, hp_ch, shooters, intents, ammo_team = p2_new, 6, p1_new, intents_p1, ammo_p1
        else:
            enemy, hp_ch, shooters, intents, ammo_team = p1_new, 5, p2_new, intents_p2, ammo_p2

        # Built once before the volley, so an enemy killed by an earlier
        # shooter still stops later rays at its cell.
        idx_map = {pos: i for i, pos in enumerate(enemy) if pos is not None}
        shoot_range = BattlefieldDuelSquad2.SHOOT_RANGE

        def ray_hit(start, d):
            """Return the index of the first enemy on the ray, or None if none is reached.

            The ray stops at an obstacle, the board edge, the safe-ring edge or
            ``SHOOT_RANGE``.
            """
            x, y = start
            dx, dy = d
            steps = 0
            while True:
                x += dx
                y += dy
                steps += 1
                if shoot_range is not None and steps > shoot_range:
                    return None
                if not (0 <= x < H and 0 <= y < W):
                    return None
                if blocks[x, y]:
                    return None
                if not (min_r <= x <= max_r and min_c <= y <= max_c):
                    return None
                idx = idx_map.get((x, y))
                if idx is not None:
                    return idx

        for i, (pos, a) in enumerate(zip(shooters, intents)):
            if pos is None or a not in _SHOOT_DELTAS or ammo_team[i] <= 0:
                continue
            hit_idx = ray_hit(pos, _SHOOT_DELTAS[a])
            # Spend ammo regardless of hit
            ammo_team[i] = max(0, ammo_team[i] - 1)
            if hit_idx is not None and enemy[hit_idx] is not None:
                enemy[hit_idx] = None
                s[hp_ch, :, :] = max(0, int(s[hp_ch, 0, 0]) - 1)

        # Write back positions
        s[0:4] = 0
        for i, pos in enumerate(p1_new):
            if pos is not None:
                s[i, pos[0], pos[1]] = 1
        for i, pos in enumerate(p2_new):
            if pos is not None:
                s[2 + i, pos[0], pos[1]] = 1

        # Update center streaks (reset to 0 when the team is off the center)
        center = (H // 2, W // 2)
        if any(pos == center for pos in p1_new if pos is not None):
            s[8, :, :] += 1
        else:
            s[8, :, :] = 0
        if any(pos == center for pos in p2_new if pos is not None):
            s[9, :, :] += 1
        else:
            s[9, :, :] = 0

        # Rebuild ammo plane at new positions from updated ammo lists
        s[ammo_ch] = 0
        for i, pos in enumerate(p1_new):
            if pos is not None:
                s[ammo_ch, pos[0], pos[1]] = ammo_p1[i]
        for i, pos in enumerate(p2_new):
            if pos is not None:
                s[ammo_ch, pos[0], pos[1]] = ammo_p2[i]

        return s

    @staticmethod
    def canonical_form(state, player):
        """Swap player-specific channels same as parent; keep ammo plane intact.

        The ammo plane is keyed by cell, not by team, so it needs no swap.
        Player 1 gets *state* back unchanged (same object).
        """
        if player == 1:
            return state
        base_order = [2, 3, 0, 1, 4, 6, 5, 7, 9, 8]
        ammo_idx = BattlefieldDuelSquad2.CHANNELS
        order = base_order + [ammo_idx]
        return state[order]

    @staticmethod
    def render(state):
        """ASCII render including ammo counts per agent.

        Legend: ``x`` outside the safe ring, ``#`` obstacle, ``A`` P1 agent,
        ``B`` P2 agent, ``.`` empty cell.
        """
        H, W = BattlefieldDuelSquadAmmo._state_hw(state)
        p1, p2 = BattlefieldDuelSquad2._extract_positions(state)
        steps = int(state[7, 0, 0])
        shrink, min_r, min_c, max_r, max_c = BattlefieldDuelSquadAmmo._state_ring(steps, H, W)
        ammo_p1, ammo_p2 = BattlefieldDuelSquadAmmo._extract_ammo(state)
        p1_cells = {p for p in p1 if p is not None}
        p2_cells = {p for p in p2 if p is not None}
        print(f"BattlefieldDuelSquadAmmo (steps={steps} shrink={shrink})")
        for r in range(H):
            row = []
            for c in range(W):
                if not (min_r <= r <= max_r and min_c <= c <= max_c):
                    ch = 'x'
                elif state[4, r, c] == 1:
                    ch = '#'
                elif (r, c) in p1_cells:
                    ch = 'A'
                elif (r, c) in p2_cells:
                    ch = 'B'
                else:
                    ch = '.'
                row.append(ch)
            print(" ".join(row))
        p1_h = int(state[5, 0, 0])
        p2_h = int(state[6, 0, 0])
        print(f"HP P1={p1_h} (ammo {ammo_p1})  P2={p2_h} (ammo {ammo_p2})")

    @staticmethod
    def display(state, ax=None, show=True, annotate=True, title=True,
                save_path=None, return_frame=False, dpi=120, close=True,
                legend_outside=True, action=None, actor=None):
        """Visualize a BattlefieldDuelSquadAmmo state using matplotlib (shows ammo).

        Board size and the safe ring are computed from *state* exactly as
        ``next_state`` computes them, so the drawn zone is the one the rules apply.

        Args:
            state: game state as produced by ``get_initial_state`` / ``next_state``.
            ax: existing matplotlib Axes to draw on; a new figure is created if None.
            show: run ``tight_layout`` and ``plt.show()`` after drawing.
            annotate: add an HP / ammo / steps / shrink / center-streak caption.
            title: set the axes title.
            save_path: if given, save the figure there with ``bbox_inches='tight'``.
            return_frame: also return the rendered RGB pixels.
            dpi: resolution used when saving.
            close: close the figure at the end, but only if it has already been
                rendered (``fig.stale is False``).
            legend_outside: forwarded to ``add_legend``; also widens a new figure.
            action: last joint action. Either an integer id (Python ``int`` or a
                NumPy integer) or a per-agent list/tuple of primitives.
            actor: side that played *action* (1 or -1); inferred from the step
                count when omitted.

        Returns:
            ``ax``, or ``(ax, frame)`` when *return_frame* is true. *frame* is an
            ``(h, w, 3)`` uint8 array, or ``None`` if the canvas pixels could not be read.

        Raises:
            RuntimeError: if matplotlib or the ``.utils`` drawing helpers cannot be imported.
        """
        try:
            import matplotlib.pyplot as plt
            from .utils import make_grid, draw_board, draw_joint_action_arrows, add_legend
        except Exception as e:  # pragma: no cover
            raise RuntimeError("matplotlib required for display(); install it first") from e

        cls = BattlefieldDuelSquadAmmo
        H, W = cls._state_hw(state)

        def locs(ch0, ch1):
            """First marked cell of each channel as (row, col), or None if the channel is empty."""
            out = []
            for ch in (ch0, ch1):
                idx = np.argwhere(state[ch] == 1)
                out.append((int(idx[0, 0]), int(idx[0, 1])) if idx.size else None)
            return out

        p1 = locs(0, 1)
        p2 = locs(2, 3)
        steps = int(state[7, 0, 0])
        p1_h = int(state[5, 0, 0])
        p2_h = int(state[6, 0, 0])
        cst1 = int(state[8, 0, 0])
        cst2 = int(state[9, 0, 0])
        ammo_p1, ammo_p2 = cls._extract_ammo(state)

        shrink, min_r, min_c, max_r, max_c = cls._state_ring(steps, H, W)
        obs = (state[4] == 1)
        bounds = (min_r, min_c, max_r, max_c) if shrink > 0 else None
        grid = make_grid(H, W, obs, bounds, p1, p2)
        center = (H // 2, W // 2)

        if ax is None:
            extra_w = 2.2 if legend_outside else 1.5
            fig, ax = plt.subplots(figsize=(W * 0.8 + extra_w, H * 0.8 + 1.5), constrained_layout=False)

        safe_bounds = (min_r, min_c, max_r, max_c) if min_r <= max_r and min_c <= max_c else None
        draw_board(ax, grid, H, W, safe_bounds=safe_bounds, center_marker=center)

        if actor not in (1, -1):
            # P1 moves first, so an odd number of elapsed steps means P1 acted last.
            moves_made = BattlefieldDuelSquad2.MAX_STEPS - steps
            actor = 1 if (moves_made % 2 == 1) else -1

        # Per-agent primitives for the arrows; NumPy integer ids are accepted too.
        if isinstance(action, (int, np.integer)):
            acts = _decode_joint(int(action), BattlefieldDuelSquad2.AGENTS)
        elif isinstance(action, (list, tuple)):
            acts = list(action)
        else:
            acts = None
        if acts is not None:
            team_pos = p1 if actor == 1 else p2
            color = (0.2, 0.4, 0.9) if actor == 1 else (0.9, 0.3, 0.2)
            arrow_positions, arrow_actions, wait_positions = [], [], []
            for pos, a in zip(team_pos, acts):
                if pos is None:
                    continue
                a = int(a)
                if a == 0:  # WAIT
                    wait_positions.append(pos)
                    continue
                start = pos
                if a in _MOVE_DELTAS:
                    # Positions are post-move; step back one delta so the arrow
                    # starts on the cell the agent moved from (when on-board).
                    dr, dc = _MOVE_DELTAS[a]
                    sr, sc = pos[0] - dr, pos[1] - dc
                    if 0 <= sr < H and 0 <= sc < W:
                        start = (sr, sc)
                arrow_positions.append(start)
                arrow_actions.append(cls._UTILS_ACTION_MAP.get(a, a))
            if arrow_actions:
                draw_joint_action_arrows(ax, arrow_actions, arrow_positions, color, H, W,
                                         BattlefieldDuelSquad2.SHOOT_RANGE, obs, safe_bounds)
            for (r, c) in wait_positions:
                try:
                    ax.plot([c], [r], marker='x', markersize=9, mew=2.0, color='white', zorder=6)
                    ax.plot([c], [r], marker='x', markersize=7, mew=1.6, color=color, zorder=7)
                except Exception:
                    pass

        add_legend(ax, legend_outside=legend_outside)
        if annotate:
            txt = (
                f"P1 HP:{p1_h} Ammo:{ammo_p1}  "
                f"P2 HP:{p2_h} Ammo:{ammo_p2}  "
                f"Steps:{steps}  Shrink:{shrink}  CenterStreak P1:{cst1} P2:{cst2}"
            )
            if action is not None:
                try:
                    if isinstance(action, (int, np.integer)):
                        act_name = BattlefieldDuelSquad2.action_name(int(action))
                    elif isinstance(action, (list, tuple)):
                        act_name = " | ".join(_PRIMS[int(a)] if 0 <= int(a) < len(_PRIMS) else str(a) for a in action)
                    else:
                        act_name = str(action)
                except Exception:
                    act_name = str(action)
                side = 'P1' if actor == 1 else 'P2'
                txt += f"  | Last: {side} {act_name}"
            ax.text(0.02, 1.01, txt, transform=ax.transAxes, ha='left', va='bottom', fontsize=9,
                    bbox=dict(facecolor='white', alpha=0.6, edgecolor='none', boxstyle='round,pad=0.2'))

        if title:
            ax.set_title('BattlefieldDuelSquadAmmo', pad=18)

        fig = ax.figure
        frame = None
        if save_path is not None:
            fig.savefig(save_path, bbox_inches='tight', pad_inches=0.15, dpi=dpi)
        if return_frame:
            frame = cls._grab_rgb_frame(fig)
        if show:
            try:
                fig.tight_layout()
            except Exception:
                pass
            try:
                plt.show()
            except Exception:
                pass
        # Only figures that have been rendered (not stale) are closed; an
        # unrendered figure is left open.
        # NOTE(review): this also closes a figure owned by a caller-supplied ax; confirm intent.
        if close and ax is not None and fig is not None and fig.stale is False:
            try:
                plt.close(fig)
            except Exception:
                pass
        return (ax, frame) if return_frame else ax

# Public game classes exported by ``from <module> import *``.
__all__ = [
    "BattlefieldDuel",
    "BattlefieldDuelSquad2",
    "BattlefieldDuelSquadAmmo",
]