"""
Potential-based reward shaping utilities

"""

from __future__ import annotations

from dataclasses import dataclass
from typing import Optional, Callable, Any


@dataclass
class ShapingConfig:
    """Settings for potential-based reward shaping.

    In the visible part of this module only `scale` and `anneal_steps` are read
    (by `annealed_scale`); the other fields are presumably consumed by callers
    elsewhere — verify against the training / search code.
    """

    # Presumably the discount applied to the next-state potential, as in
    # gamma * phi(s') - phi(s) — TODO confirm against the caller.
    gamma: float = 1.0
    # Base multiplier for the shaping signal, before annealing is applied.
    scale: float = 0.1
    # NOTE(review): looks like a switch for applying shaping inside MCTS — confirm.
    use_in_mcts: bool = True
    # NOTE(review): looks like a switch for applying shaping to training targets — confirm.
    use_in_targets: bool = True
    anneal_steps: int = 200_000  # linear anneal to 0 over this many steps; <=0 disables anneal


def annealed_scale(cfg: ShapingConfig, global_step: int | float) -> float:
    """Return the shaping scale after linear annealing.

    The multiplier falls linearly from 1 at step 0 to 0 at ``cfg.anneal_steps``
    and stays at 0 afterwards. Annealing is disabled (multiplier fixed at 1)
    when ``cfg.anneal_steps`` is falsy or not positive.

    Args:
        cfg: Object exposing ``scale`` and ``anneal_steps``.
        global_step: Current training step; anything accepted by ``float()``.

    Returns:
        ``cfg.scale`` times a multiplier in [0, 1].
    """
    steps = cfg.anneal_steps
    if not steps or steps <= 0:
        return float(cfg.scale)
    progress = float(global_step) / float(steps)
    # Clamp on both sides: a negative step (e.g. a warm-up offset) must not
    # push the result above the configured base scale.
    alpha = min(1.0, max(0.0, 1.0 - progress))
    return float(cfg.scale * alpha)



def clamp_unit(x: float) -> float:
    """Clamp *x* to the closed interval [-1, 1] and return it as a float."""
    if x < -1.0:
        return -1.0
    if x > 1.0:
        return 1.0
    # Neither comparison fired (in range, or NaN): pass the value through.
    return float(x)



def phi_duel(state, player: int = 1) -> float:
    """Potential for BattlefieldDuel in [-1, 1].

    Channels (C,H,W):
      0: P1 pos, 1: P2 pos, 2: P1 HP, 3: P2 HP, 4: Obstacles, 5: Steps,
      6: P1 center streak, 7: P2 center streak.
    Emphasizes health, center control, capture streak progress, approach, safe-zone margin, and LOS threat.

    Args:
        state: Object indexed as ``state[channel, row, col]`` and exposing ``.shape``
            (presumably a NumPy array in the layout above — confirm with the game
            encoder). Scalar channels are read from cell (0, 0).
        player: 1 returns the value from P1's perspective; any other value returns
            its negation.

    Returns:
        The weight-normalized sum of the terms, clamped to [-1, 1]. Returns 0.0 if
        anything raises while computing it.
    """
    try:
        import numpy as np
        try:
            from src.games.battlefield_duel import BattlefieldDuel as _Duel
        except Exception:
            _Duel = None  # fallback values used below

        C, H, W = int(state.shape[0]), int(state.shape[1]), int(state.shape[2])

        # First cell equal to 1 in channel `ch` as (row, col), or None if there is none.
        def locate(ch: int):
            idx = np.argwhere(state[ch] == 1)
            return (int(idx[0, 0]), int(idx[0, 1])) if idx.size else None

        p1 = locate(0) if C > 0 else None
        p2 = locate(1) if C > 1 else None
        # Both HP channels (2 and 3) must exist, hence the shared C > 3 guard.
        p1_h = float(state[2, 0, 0]) if C > 3 else 0.0
        p2_h = float(state[3, 0, 0]) if C > 3 else 0.0
        # getattr on None falls through to the default when the game class is unavailable.
        max_h = float(getattr(_Duel, 'MAX_HEALTH', 3))
        health_term = (p1_h - p2_h) / max(1.0, max_h)

        # Center control (inverse Manhattan distance to center)
        center = (H // 2, W // 2)
        def inv_center(pos):
            if pos is None:
                return 0.0
            d = abs(pos[0] - center[0]) + abs(pos[1] - center[1])
            return 1.0 - d / max(1.0, (H - 1) + (W - 1))
        center_term = inv_center(p1) - inv_center(p2)

        # Capture streak progress (normalized)
        cap_steps = float(getattr(_Duel, 'CAPTURE_STEPS', 3))
        cap_term = 0.0
        if C > 7 and cap_steps > 0:
            cap_term = (float(state[6, 0, 0]) - float(state[7, 0, 0])) / cap_steps

        # Approach (closer to opponent is better)
        # NOTE(review): unlike the other terms this is not a P1-minus-P2 difference; it is
        # the same number for both players before the final sign flip, so for player != 1
        # being close lowers the potential — confirm this is intended.
        approach_term = 0.0
        if p1 is not None and p2 is not None:
            md = abs(p1[0] - p2[0]) + abs(p1[1] - p2[1])
            max_md = (H - 1) + (W - 1)
            if max_md > 0:
                approach_term = 1.0 - (md / max_md)

        # Safe zone margin difference (smaller margin is worse). The legacy zone_term uses (m2 - m1),
        # which rewards P1 for being closer to the border than P2. It is kept for backward
        # compatibility, and an explicit border_safety_term is added that rewards the larger margin
        # (farther from the current shrink boundary) to reduce eliminations during shrinking.
        # The two terms are opposite in sign, so with w_zone < w_border the net effect
        # favours the larger margin.
        # Both terms stay 0.0 when the game class could not be imported.
        zone_term = 0.0
        border_safety_term = 0.0
        try:
            steps_left = int(state[5, 0, 0]) if C > 5 else 0
            if _Duel is not None:
                # Number of rings already removed from each side of the board.
                shrink = (int(getattr(_Duel, 'MAX_STEPS', 36)) - steps_left) // int(getattr(_Duel, 'SHRINK_INTERVAL', 6))
                min_r = min_c = shrink
                max_r = H - 1 - shrink
                max_c = W - 1 - shrink
                # Distance from `pos` to the nearest edge of the current zone; -1 for a missing piece.
                def margin(pos):
                    if pos is None:
                        return -1
                    r, c = pos
                    return min(r - min_r, c - min_c, max_r - r, max_c - c)
                m1 = margin(p1)
                m2 = margin(p2)
                # Only scored when both pieces are present and inside the current zone.
                if m1 >= 0 and m2 >= 0:
                    zone_term = (m2 - m1) / max(1.0, H // 2)
                    # New: reward being safer (larger margin); normalize by same denominator
                    denom = max(1.0, H // 2)
                    s1 = m1 / denom
                    s2 = m2 / denom
                    border_safety_term = s1 - s2  # positive if P1 is safer / more central
        except Exception:
            # Best effort: a failure here zeroes only the zone terms, not the whole potential.
            zone_term = 0.0
            border_safety_term = 0.0

        # Immediate LOS threat using variant's shoot range
        shoot_range = int(getattr(_Duel, 'SHOOT_RANGE', 2))
        # True if `dst` lies within shoot_range cells of `src` along a row or column,
        # with no obstacle cell and no board edge reached first.
        def has_los(src, dst):
            if src is None or dst is None:
                return False
            r, c = src
            for dr, dc in [(-1, 0), (1, 0), (0, -1), (0, 1)]:
                tr, tc = r, c
                for _ in range(1, shoot_range + 1):
                    tr += dr; tc += dc
                    if not (0 <= tr < H and 0 <= tc < W):
                        break
                    if C > 4 and state[4, tr, tc] == 1:
                        break
                    if (tr, tc) == dst:
                        return True
            return False
        threat_term = (1.0 if has_los(p1, p2) else 0.0) - (1.0 if has_los(p2, p1) else 0.0)

        # Small elapsed penalty
        # (computed for completeness; w_elapsed is 0.0 below, so it does not affect the result)
        elapsed_term = 0.0
        if C > 5:
            steps_left = float(state[5, 0, 0])
            horizon = float(getattr(_Duel, 'MAX_STEPS', max(36, steps_left)))
            if horizon > 0:
                elapsed_term = -((horizon - steps_left) / horizon)

        # Weighted sum (close to Duel heuristic)
        w_health = 0.35
        w_center = 0.18
        w_cap = 0.12
        w_approach = 0.15
        w_zone = 0.04  # reduced to accommodate new border safety term without inflating total
        w_border = 0.08  # weight for new border_safety_term (central safety)
        w_threat = 0.12
        w_elapsed = 0.0  # optional, keep 0 by default to match heuristic
        # Dividing by the weight total keeps the blend normalized whatever the weights sum to.
        total_w = w_health + w_center + w_cap + w_approach + w_zone + w_border + w_threat + w_elapsed
        v_p1 = (
            w_health * health_term +
            w_center * center_term +
            w_cap * cap_term +
            w_approach * approach_term +
            w_zone * zone_term +
            w_border * border_safety_term +
            w_threat * threat_term +
            w_elapsed * elapsed_term
        ) / (total_w if total_w > 0 else 1.0)
        # The value is computed from P1's side and negated for any other player.
        return clamp_unit(float(v_p1 if int(player) == 1 else -v_p1))
    except Exception:
        # Deliberate best effort: any failure yields a neutral potential.
        return 0.0


def phi_duel_squad(state, player: int = 1) -> float:
    """Potential for BattlefieldDuelSquad2 in [-1, 1].

    Channels (C,H,W):
      0..1: P1 agents, 2..3: P2 agents (one-hot)
      4: Obstacles, 5: P1 team HP, 6: P2 team HP, 7: Steps, 8: P1 streak, 9: P2 streak.
    Emphasizes team HP, center control, capture streak, approach (min distance among agents), zone margin,
    LOS threat, plus a coordination bonus (cohesion + focus-fire) that entices teammates to act in unison.

    Args:
        state: Object indexed as ``state[channel, row, col]`` and exposing ``.shape``
            (presumably a NumPy array in the layout above — confirm with the game
            encoder). Scalar channels are read from cell (0, 0).
        player: 1 returns the value from P1's perspective; any other value returns
            its negation.

    Returns:
        The weight-normalized sum of the terms, clamped to [-1, 1]. Returns 0.0 if
        anything raises while computing it.
    """
    try:
        import numpy as np
        try:
            from src.games.battlefield_duel import BattlefieldDuelSquad2 as _Squad
        except Exception:
            # Game class unavailable: the getattr defaults below are used instead.
            _Squad = None

        C, H, W = int(state.shape[0]), int(state.shape[1]), int(state.shape[2])

        # First cell equal to 1 in channel `ch` as (row, col), or None if there is none.
        def find(ch):
            idx = np.argwhere(state[ch] == 1)
            return (int(idx[0, 0]), int(idx[0, 1])) if idx.size else None

        # One slot per agent; a slot is None when its channel is absent or has no cell set.
        p1_list = [find(0) if C > 0 else None, find(1) if C > 1 else None]
        p2_list = [find(2) if C > 2 else None, find(3) if C > 3 else None]

        # Team HP advantage normalized by initial agents (2)
        p1_hp = float(state[5, 0, 0]) if C > 6 else 0.0
        p2_hp = float(state[6, 0, 0]) if C > 6 else 0.0
        max_hp = float(getattr(_Squad, 'AGENTS', 2))
        health_term = (p1_hp - p2_hp) / max(1.0, max_hp)

        # Center control: min distance among agents to center
        center = (H // 2, W // 2)
        # Best (largest) inverse-distance score among the located agents, i.e. the
        # agent closest to the center; 0.0 when no agent is located.
        def min_inv_center(lst):
            best = -1.0
            for pos in lst:
                if pos is None:
                    continue
                d = abs(pos[0] - center[0]) + abs(pos[1] - center[1])
                inv = 1.0 - d / max(1.0, (H - 1) + (W - 1))
                if inv > best:
                    best = inv
            return best if best >= 0 else 0.0
        center_term = min_inv_center(p1_list) - min_inv_center(p2_list)

        # Capture streak progress
        cap_steps = float(getattr(_Squad, 'CAPTURE_STEPS', 3))
        cap_term = 0.0
        if C > 9 and cap_steps > 0:
            cap_term = (float(state[8, 0, 0]) - float(state[9, 0, 0])) / cap_steps

        # Approach term: minimal inter-team distance (closest pair)
        # NOTE(review): this term is the same number for both players before the final
        # sign flip, so for player != 1 being close lowers the potential — confirm intended.
        def min_pair_dist(a_list, b_list):
            best = None
            for a in a_list:
                if a is None:
                    continue
                for b in b_list:
                    if b is None:
                        continue
                    d = abs(a[0] - b[0]) + abs(a[1] - b[1])
                    best = d if best is None else min(best, d)
            return best
        approach_term = 0.0
        md = min_pair_dist(p1_list, p2_list)
        if md is not None:
            max_md = (H - 1) + (W - 1)
            if max_md > 0:
                approach_term = 1.0 - (md / max_md)

        # Zone margin: compare best (max) margin within ring; keep legacy difference plus new safety term.
        # zone_term uses (m2 - m1) and border_safety_term uses (m1 - m2), so they are opposite in
        # sign; with w_zone < w_border the net effect favours the larger margin.
        # Both stay 0.0 when the game class could not be imported.
        zone_term = 0.0
        border_safety_term = 0.0
        try:
            steps_left = int(state[7, 0, 0]) if C > 7 else 0
            if _Squad is not None:
                # Number of rings already removed from each side of the board.
                shrink = (int(getattr(_Squad, 'MAX_STEPS', 36)) - steps_left) // int(getattr(_Squad, 'SHRINK_INTERVAL', 6))
                min_r = min_c = shrink
                max_r = H - 1 - shrink
                max_c = W - 1 - shrink
                # Largest distance-to-zone-edge among the located agents; None if none is located.
                def margin_list(lst):
                    best = None
                    for pos in lst:
                        if pos is None:
                            continue
                        r, c = pos
                        m = min(r - min_r, c - min_c, max_r - r, max_c - c)
                        best = m if best is None else max(best, m)
                    return best
                m1 = margin_list(p1_list)
                m2 = margin_list(p2_list)
                # The trailing max(1, H // 2) > 0 check is always true; the None checks decide.
                if m1 is not None and m2 is not None and max(1, H // 2) > 0:
                    zone_term = (m2 - m1) / max(1.0, H // 2)
                    denom = max(1.0, H // 2)
                    s1 = (m1 / denom) if m1 is not None else 0.0
                    s2 = (m2 / denom) if m2 is not None else 0.0
                    border_safety_term = s1 - s2
        except Exception:
            # Best effort: a failure here zeroes only the zone terms, not the whole potential.
            zone_term = 0.0
            border_safety_term = 0.0

        # LOS threat: any P1 has LOS to any P2 minus reverse
        shoot_range = int(getattr(_Squad, 'SHOOT_RANGE', 2))
        # True if any agent in a_list reaches a cell held by b_list within shoot_range steps
        # along a row or column, before hitting an obstacle cell or the board edge.
        def any_los(a_list, b_list):
            for a in a_list:
                if a is None:
                    continue
                ar, ac = a
                for dr, dc in [(-1,0),(1,0),(0,-1),(0,1)]:
                    r, c = ar, ac
                    for _ in range(1, shoot_range + 1):
                        r += dr; c += dc
                        if not (0 <= r < H and 0 <= c < W):
                            break
                        if C > 4 and state[4, r, c] == 1:
                            break
                        if (r, c) in b_list:
                            return True
            return False
        threat_term = (1.0 if any_los(p1_list, p2_list) else 0.0) - (1.0 if any_los(p2_list, p1_list) else 0.0)

        # Coordination: encourage teammates to act in unison
        # 1) Cohesion: keep teammates closer together (normalized), relative to enemy cohesion
        def _cohesion_score(lst):
            # For two agents: score in [0,1], 1 when sharing a cell, 0 when at opposite corners.
            # A team with fewer than two located agents scores 0.0.
            if len(lst) >= 2 and lst[0] is not None and lst[1] is not None:
                d = abs(lst[0][0] - lst[1][0]) + abs(lst[0][1] - lst[1][1])
                denom = max(1.0, (H - 1) + (W - 1))
                return 1.0 - (d / denom)
            return 0.0
        cohesion_term = _cohesion_score(p1_list) - _cohesion_score(p2_list)

        # 2) Focus-fire: both teammates prefer the same nearest target enemy (encourages unified focus)
        def _focus_score(ours, theirs):
            # Need at least two of ours alive to coordinate
            if sum(1 for p in ours if p is not None) < 2:
                return 0.0
            enemies = [p for p in theirs if p is not None]
            if not enemies:
                return 0.0
            # Index of the nearest enemy by Manhattan distance; ties go to the lowest index.
            def _nearest_idx(pos):
                best_i = 0; best_d = None
                for i, e in enumerate(enemies):
                    d = abs(pos[0] - e[0]) + abs(pos[1] - e[1])
                    if best_d is None or d < best_d:
                        best_d = d; best_i = i
                return best_i
            idxs = [_nearest_idx(p) for p in ours if p is not None][:2]
            if len(idxs) < 2:
                return 0.0
            return 1.0 if idxs[0] == idxs[1] else 0.0
        focus_term = _focus_score(p1_list, p2_list) - _focus_score(p2_list, p1_list)

        # Elapsed penalty (optional)
        # (w_elapsed is 0.0 below, so this term does not affect the result)
        elapsed_term = 0.0
        if C > 7:
            steps_left = float(state[7, 0, 0])
            horizon = float(getattr(_Squad, 'MAX_STEPS', max(36, steps_left)))
            if horizon > 0:
                elapsed_term = -((horizon - steps_left) / horizon)

        # Weights tuned lightly for squads (add coordination weights)
        w_health = 0.36
        w_center = 0.16
        w_cap = 0.12
        w_approach = 0.16
        w_zone = 0.05  # slightly reduced
        w_threat = 0.12
        w_elapsed = 0.0
        w_cohesion = 0.08
        w_focus = 0.06
        w_border = 0.07  # new border safety weight
        # Dividing by the weight total keeps the blend normalized whatever the weights sum to.
        total_w = w_health + w_center + w_cap + w_approach + w_zone + w_border + w_threat + w_elapsed + w_cohesion + w_focus
        v_p1 = (
            w_health * health_term +
            w_center * center_term +
            w_cap * cap_term +
            w_approach * approach_term +
            w_zone * zone_term +
            w_border * border_safety_term +
            w_threat * threat_term +
            w_elapsed * elapsed_term +
            w_cohesion * cohesion_term +
            w_focus * focus_term
        ) / (total_w if total_w > 0 else 1.0)
        # The value is computed from P1's side and negated for any other player.
        return clamp_unit(float(v_p1 if int(player) == 1 else -v_p1))
    except Exception:
        # Deliberate best effort: any failure yields a neutral potential.
        return 0.0


def phi_duel_squad2(state, player: int = 1) -> float:
    """Potential for BattlefieldDuelSquad2 in [-1, 1], with a goal-proximity term.

    Blends the per-criterion ``term_w_*`` functions with fixed weights and
    normalizes by the weight total. Compared with `phi_duel_squad` it adds
    ``term_w_goal``, which scores the average closeness of each team's agents to
    the board center. Because this is a potential, moving toward the center
    raises phi and so earns positive shaping reward via gamma*phi(s') - phi(s).

    Channels (C,H,W):
      0..1: P1 agents, 2..3: P2 agents (one-hot)
      4: Obstacles, 5: P1 team HP, 6: P2 team HP, 7: Steps, 8: P1 streak, 9: P2 streak.

    Args:
        state: Board tensor handed unchanged to every ``term_w_*`` function.
        player: 1 returns the value from P1's perspective; any other value returns
            its negation.

    Returns:
        The blended value clamped to [-1, 1], or 0.0 if anything raises.
    """
    try:
        # (weight, criterion) pairs in accumulation order. The order is kept
        # fixed so the floating-point sums are reproducible.
        recipe = (
            (0.36, term_w_health),
            (0.16, term_w_center),
            (0.07, term_w_goal),
            (0.12, term_w_cap),
            (0.16, term_w_approach),
            (0.05, term_w_zone),
            (0.07, term_w_border),
            (0.12, term_w_threat),
            (0.0, term_w_elapsed),  # elapsed penalty is switched off
            (0.08, term_w_cohesion),
            (0.06, term_w_focus),
        )

        # Evaluate every criterion before blending.
        scored = [(weight, criterion(state)) for weight, criterion in recipe]

        # Seed the accumulators with the first entry, then fold in the rest
        # left to right.
        head_weight, head_value = scored[0]
        weight_sum = head_weight
        blended = head_weight * head_value
        for weight, value in scored[1:]:
            weight_sum += weight
            blended += weight * value

        if weight_sum > 0:
            blended = blended / weight_sum

        signed = blended if int(player) == 1 else -blended
        return clamp_unit(float(signed))
    except Exception:
        return 0.0

# --- DuelSquad2 per-criterion terms (names based on weight keys) ---
def _squad2_lists(state):
    """Return ``(C, H, W, p1_positions, p2_positions)`` for a squad board.

    Each positions list has two slots (channels 0-1 for P1, 2-3 for P2). A slot
    holds the (row, col) of the first cell equal to 1 in its channel, or None
    when the channel is missing or has no such cell.
    """
    import numpy as np

    dims = state.shape
    n_ch = int(dims[0])
    n_rows = int(dims[1])
    n_cols = int(dims[2])

    def first_hit(channel):
        # Channels beyond the tensor's depth are treated as empty.
        if not n_ch > channel:
            return None
        hits = np.argwhere(state[channel] == 1)
        if not hits.size:
            return None
        return (int(hits[0, 0]), int(hits[0, 1]))

    team_one = [first_hit(ch) for ch in (0, 1)]
    team_two = [first_hit(ch) for ch in (2, 3)]
    return n_ch, n_rows, n_cols, team_one, team_two


def term_w_health(state) -> float:
    """Team-HP advantage for DuelSquad2, clamped to [-1, 1].

    Computes (P1 HP - P2 HP) from channels 5 and 6 at cell (0, 0), divided by the
    game's ``AGENTS`` constant (2 when the game class is unavailable). Both HP
    values count as 0 when the tensor has fewer than 7 channels. Returns 0.0 if
    anything raises.
    """
    try:
        try:
            from src.games.battlefield_duel import BattlefieldDuelSquad2 as squad_cls
        except Exception:
            squad_cls = None

        n_ch = _squad2_lists(state)[0]
        if n_ch > 6:
            hp_one = float(state[5, 0, 0])
            hp_two = float(state[6, 0, 0])
        else:
            hp_one = hp_two = 0.0

        if squad_cls is None:
            hp_cap = 2.0
        else:
            hp_cap = float(getattr(squad_cls, 'AGENTS', 2))

        return clamp_unit((hp_one - hp_two) / max(1.0, hp_cap))
    except Exception:
        return 0.0


def term_w_center(state) -> float:
    """Center-control term for DuelSquad2, clamped to [-1, 1].

    Each team is scored by its agent closest to the board center (H // 2, W // 2):
    1 - Manhattan distance / max(1, (H - 1) + (W - 1)). A team with no located
    agent scores 0. The term is P1's score minus P2's. Returns 0.0 if anything
    raises.
    """
    try:
        _, rows, cols, team_one, team_two = _squad2_lists(state)
        mid_r, mid_c = rows // 2, cols // 2
        span = max(1.0, (rows - 1) + (cols - 1))

        def closest_score(team):
            scores = [
                1.0 - (abs(cell[0] - mid_r) + abs(cell[1] - mid_c)) / span
                for cell in team
                if cell is not None
            ]
            # -1.0 stands in for "no agent located" and maps to 0.0 below.
            top = max(scores, default=-1.0)
            return top if top >= 0 else 0.0

        return clamp_unit(closest_score(team_one) - closest_score(team_two))
    except Exception:
        return 0.0


def term_w_goal(state) -> float:
    """Goal-proximity term for DuelSquad2, clamped to [-1, 1].

    Each team is scored by the average, over its located agents, of
    1 - Manhattan distance to the board center / max(1, (H - 1) + (W - 1)).
    A team with no located agent scores 0. The term is P1's average minus
    P2's. Returns 0.0 if anything raises.
    """
    try:
        _, rows, cols, team_one, team_two = _squad2_lists(state)
        mid_r, mid_c = rows // 2, cols // 2

        def mean_closeness(team):
            closeness = [
                1.0 - (abs(cell[0] - mid_r) + abs(cell[1] - mid_c)) / max(1.0, (rows - 1) + (cols - 1))
                for cell in team
                if cell is not None
            ]
            if not closeness:
                return 0.0
            return sum(closeness) / len(closeness)

        return clamp_unit(mean_closeness(team_one) - mean_closeness(team_two))
    except Exception:
        return 0.0


def term_w_cap(state) -> float:
    """Capture-streak advantage for DuelSquad2, clamped to [-1, 1].

    Computes (P1 streak - P2 streak) from channels 8 and 9 at cell (0, 0), divided
    by the game's ``CAPTURE_STEPS`` constant (3 when the game class is
    unavailable). Returns 0.0 when the tensor has fewer than 10 channels, when
    the divisor is not positive, or when anything raises.
    """
    try:
        try:
            from src.games.battlefield_duel import BattlefieldDuelSquad2 as squad_cls
        except Exception:
            squad_cls = None

        n_ch = _squad2_lists(state)[0]
        if squad_cls is None:
            streak_len = 3.0
        else:
            streak_len = float(getattr(squad_cls, 'CAPTURE_STEPS', 3))

        if not (n_ch > 9 and streak_len > 0):
            return 0.0
        lead = float(state[8, 0, 0]) - float(state[9, 0, 0])
        return clamp_unit(lead / streak_len)
    except Exception:
        return 0.0


def term_w_approach(state) -> float:
    """Approach term for DuelSquad2 in [0, 1].

    Takes the smallest Manhattan distance between any located P1 agent and any
    located P2 agent and returns 1 - distance / ((H - 1) + (W - 1)). Returns 0.0
    when either team has no located agent, when the board is a single cell, or
    when anything raises.
    """
    try:
        _, rows, cols, team_one, team_two = _squad2_lists(state)

        # Manhattan distance for every cross-team pair of located agents.
        gaps = [
            abs(mine[0] - theirs[0]) + abs(mine[1] - theirs[1])
            for mine in team_one
            if mine is not None
            for theirs in team_two
            if theirs is not None
        ]
        if not gaps:
            return 0.0

        farthest = (rows - 1) + (cols - 1)
        if farthest <= 0:
            return 0.0
        return clamp_unit(1.0 - (min(gaps) / farthest))
    except Exception:
        return 0.0


def term_w_zone(state) -> float:
    """Legacy zone-margin term for DuelSquad2, clamped to [-1, 1].

    A team's margin is the largest distance from any of its located agents to
    the nearest edge of the current (shrunken) zone. The term is
    (P2 margin - P1 margin) / max(1, H // 2), so it is positive when P1's
    best-placed agent sits closer to the zone edge than P2's. This is the
    opposite sign to `term_w_border`.

    Returns 0.0 when the game class cannot be imported, when either team has no
    located agent, or when anything raises.
    """
    try:
        try:
            from src.games.battlefield_duel import BattlefieldDuelSquad2 as squad_cls
        except Exception:
            squad_cls = None

        n_ch, rows, cols, team_one, team_two = _squad2_lists(state)
        remaining = int(state[7, 0, 0]) if n_ch > 7 else 0
        if squad_cls is None:
            # Without the game constants the zone geometry is unknown.
            return clamp_unit(0.0)

        # Number of rings already removed from each side of the board.
        ring = (int(getattr(squad_cls, 'MAX_STEPS', 36)) - remaining) // int(getattr(squad_cls, 'SHRINK_INTERVAL', 6))
        inner = ring
        last_row = rows - 1 - ring
        last_col = cols - 1 - ring

        def widest_margin(team):
            margins = [
                min(r - inner, c - inner, last_row - r, last_col - c)
                for r, c in (cell for cell in team if cell is not None)
            ]
            return max(margins) if margins else None

        own = widest_margin(team_one)
        other = widest_margin(team_two)
        if own is None or other is None:
            return clamp_unit(0.0)
        return clamp_unit((other - own) / max(1.0, rows // 2))
    except Exception:
        return 0.0


def term_w_border(state) -> float:
    """Border-safety term for DuelSquad2, clamped to [-1, 1].

    A team's margin is the largest distance from any of its located agents to
    the nearest edge of the current (shrunken) zone, divided by max(1, H // 2).
    A team with no located agent contributes 0. The term is P1's value minus
    P2's, so it is positive when P1 is farther from the zone edge.

    Returns 0.0 when the game class cannot be imported or when anything raises.
    """
    try:
        try:
            from src.games.battlefield_duel import BattlefieldDuelSquad2 as squad_cls
        except Exception:
            squad_cls = None

        n_ch, rows, cols, team_one, team_two = _squad2_lists(state)
        remaining = int(state[7, 0, 0]) if n_ch > 7 else 0
        if squad_cls is None:
            # Without the game constants the zone geometry is unknown.
            return clamp_unit(0.0)

        # Number of rings already removed from each side of the board.
        ring = (int(getattr(squad_cls, 'MAX_STEPS', 36)) - remaining) // int(getattr(squad_cls, 'SHRINK_INTERVAL', 6))
        inner = ring
        last_row = rows - 1 - ring
        last_col = cols - 1 - ring

        def widest_margin(team):
            margins = [
                min(r - inner, c - inner, last_row - r, last_col - c)
                for r, c in (cell for cell in team if cell is not None)
            ]
            return max(margins) if margins else None

        own = widest_margin(team_one)
        other = widest_margin(team_two)
        scale = max(1.0, rows // 2)
        own_safety = 0.0 if own is None else own / scale
        other_safety = 0.0 if other is None else other / scale
        return clamp_unit(own_safety - other_safety)
    except Exception:
        return 0.0


def term_w_threat(state) -> float:
    """Line-of-sight threat term for DuelSquad2 in {-1, 0, 1}.

    A team "sees" the other when one of its located agents reaches an enemy cell
    by walking up to ``SHOOT_RANGE`` steps (2 when the game class is unavailable)
    along a row or column without first leaving the board or meeting an obstacle
    cell (channel 4). The term is 1 if only P1 sees P2, -1 if only P2 sees P1,
    and 0 otherwise. Returns 0.0 if anything raises.
    """
    try:
        try:
            from src.games.battlefield_duel import BattlefieldDuelSquad2 as squad_cls
        except Exception:
            squad_cls = None

        n_ch, rows, cols, team_one, team_two = _squad2_lists(state)
        reach = 2 if squad_cls is None else int(getattr(squad_cls, 'SHOOT_RANGE', 2))
        headings = ((-1, 0), (1, 0), (0, -1), (0, 1))

        def ray_hits(origin, heading, targets):
            # Walk outward from the shooter until the ray is blocked or lands on a target.
            r, c = origin
            for _ in range(reach):
                r += heading[0]
                c += heading[1]
                if r < 0 or r >= rows or c < 0 or c >= cols:
                    return False
                if n_ch > 4 and state[4, r, c] == 1:
                    return False
                if (r, c) in targets:
                    return True
            return False

        def team_sees(shooters, targets):
            return any(
                ray_hits(shooter, heading, targets)
                for shooter in shooters
                if shooter is not None
                for heading in headings
            )

        ours = 1.0 if team_sees(team_one, team_two) else 0.0
        theirs = 1.0 if team_sees(team_two, team_one) else 0.0
        return clamp_unit(ours - theirs)
    except Exception:
        return 0.0


def term_w_elapsed(state) -> float:
    """Elapsed-time penalty for DuelSquad2 in [-1, 0].

    Reads the steps-left counter from channel 7 at cell (0, 0) and returns
    -(horizon - steps_left) / horizon. The horizon is the game's ``MAX_STEPS``,
    falling back to max(36, steps_left) when that is unavailable. Returns 0.0
    when the tensor has fewer than 8 channels, when the horizon is not
    positive, or when anything raises.
    """
    try:
        try:
            from src.games.battlefield_duel import BattlefieldDuelSquad2 as squad_cls
        except Exception:
            squad_cls = None

        n_ch = _squad2_lists(state)[0]
        if n_ch > 7:
            remaining = float(state[7, 0, 0])
            if squad_cls is None:
                total = max(36.0, remaining)
            else:
                total = float(getattr(squad_cls, 'MAX_STEPS', max(36, remaining)))
            if total > 0:
                used_fraction = (total - remaining) / total
                return clamp_unit(-used_fraction)
        return 0.0
    except Exception:
        return 0.0


def term_w_cohesion(state) -> float:
    """Cohesion term for DuelSquad2, clamped to [-1, 1].

    A team with both agents located scores
    1 - Manhattan distance between them / max(1, (H - 1) + (W - 1)); otherwise
    it scores 0. The term is P1's score minus P2's. Returns 0.0 if anything
    raises.
    """
    try:
        _, rows, cols, team_one, team_two = _squad2_lists(state)

        def tightness(team):
            if len(team) < 2:
                return 0.0
            first, second = team[0], team[1]
            if first is None or second is None:
                return 0.0
            spread = abs(first[0] - second[0]) + abs(first[1] - second[1])
            span = max(1.0, (rows - 1) + (cols - 1))
            return 1.0 - (spread / span)

        return clamp_unit(tightness(team_one) - tightness(team_two))
    except Exception:
        return 0.0


def term_w_focus(state) -> float:
    """Focus-fire term for DuelSquad2 in {-1, 0, 1}.

    A team scores 1 when it has at least two located agents and the first two
    share the same nearest enemy (by Manhattan distance, ties going to the
    earlier enemy); otherwise it scores 0. The term is P1's score minus P2's.
    Returns 0.0 if anything raises.
    """
    try:
        _, _, _, team_one, team_two = _squad2_lists(state)

        def shared_target(allies, foes):
            alive = [cell for cell in allies if cell is not None]
            if len(alive) < 2:
                # Coordination needs two agents.
                return 0.0
            targets = [cell for cell in foes if cell is not None]
            if not targets:
                return 0.0

            def pick(cell):
                # min() returns the first minimal index, so ties go to the earlier enemy.
                return min(
                    range(len(targets)),
                    key=lambda i: abs(cell[0] - targets[i][0]) + abs(cell[1] - targets[i][1]),
                )

            return 1.0 if pick(alive[0]) == pick(alive[1]) else 0.0

        return clamp_unit(shared_target(team_one, team_two) - shared_target(team_two, team_one))
    except Exception:
        return 0.0

def call_phi(phi_fn: Optional[Callable[..., float]], game: Any, board: Any, player: int) -> float:
    """Call phi with a flexible signature: (game, board, player) | (board, player) | (board).

    The signatures are tried in that order; a ``TypeError`` moves on to the next,
    shorter one.

    Args:
        phi_fn: Potential function, or None.
        game: Passed only to the 3-argument form.
        board: State passed to every form.
        player: Passed to the 3- and 2-argument forms.

    Returns:
        The potential as a float, or 0.0 if `phi_fn` is None, if every signature
        is rejected, or if any attempt raises a non-``TypeError`` exception.
    """
    if phi_fn is None:
        return 0.0
    # Widest signature first, narrowing each time the call is rejected.
    for args in ((game, board, player), (board, player), (board,)):
        try:
            return float(phi_fn(*args))
        except TypeError:
            # Treated as an arity mismatch. A TypeError raised inside phi_fn
            # itself cannot be told apart here and also falls through.
            continue
        except Exception:
            # Handled per attempt so a failure in a later signature is
            # swallowed too, instead of escaping from inside a handler.
            return 0.0
    return 0.0
