import numpy as np
from collections import deque


class TransformerFeaturizer:
    """
    Featurizer for TransformerAlphaNet inputs.

    - Stacks K previous canonical-encoded boards as history planes: [current, t-1, ..., t-K].
    - Optionally builds simple action token features per legal action:
      [row/(H-1), col/(W-1), 1.0, 1.0]. The last column is a constant legal flag, since only
      legal actions receive tokens. Action index mapping assumes a grid action space
      (row * W + col) when ROWS, COLS are present on the game class; otherwise the
      normalized action index is used.

    Notes:
    - History is maintained per-episode via a deque of encoded canonical boards (C,H,W) numpy arrays.
    - make_input(encoded_current) returns stacked channels (C*(1+K) + extra, H, W) as numpy float32,
      where `extra` (0-3) is the number of enabled broadcast planes.
    - Action token utilities are provided but not wired into inference queues yet.
    """

    def __init__(self, game_cls, history_steps: int = 0,
                 include_steps_left_plane: bool = False,
                 include_repetition_plane: bool = False,
                 include_since_damage_plane: bool = False):
        """Configure the featurizer.

        Args:
            game_cls: game class; optional ROWS / COLS attributes enable grid action tokens.
            history_steps: number K of previous boards to stack (None/negative -> 0).
            include_steps_left_plane: append a broadcast plane holding steps_left_norm.
            include_repetition_plane: append a broadcast 0/1 plane for root repetition.
            include_since_damage_plane: append a broadcast plane holding since_last_damage_norm.
        """
        self.game = game_cls
        self.K = max(0, int(history_steps or 0))
        self.rows = getattr(game_cls, 'ROWS', None)
        self.cols = getattr(game_cls, 'COLS', None)
        # Bounded to K frames; the oldest frame falls off the right end automatically.
        self._hist = deque(maxlen=self.K)
        # Extra planes toggles
        self.include_steps_left_plane = bool(include_steps_left_plane)
        self.include_repetition_plane = bool(include_repetition_plane)
        self.include_since_damage_plane = bool(include_since_damage_plane)
        # Per-root ephemeral context (see set_root_context)
        self._root_is_repetition = False
        self._since_last_damage_norm = 0.0
        self._steps_left_norm = 0.0

    def reset(self):
        """Clear history and per-root context (call at the start of each episode)."""
        self._hist.clear()
        self._root_is_repetition = False
        self._since_last_damage_norm = 0.0
        self._steps_left_norm = 0.0

    def push(self, encoded_canonical_board: np.ndarray):
        """Push one encoded canonical board (C,H,W) into history (most recent first).

        No-op when K == 0. A copy is stored so later mutation of the caller's array
        does not alter history.
        """
        if self.K <= 0:
            return
        # Store a copy to avoid accidental mutation
        self._hist.appendleft(np.array(encoded_canonical_board, copy=True))

    def set_root_context(self, *, is_repetition: bool = False,
                          since_last_damage_norm: float = 0.0,
                          steps_left_norm: float = 0.0):
        """Set per-root context used for extra broadcast planes.

        - is_repetition: whether the current root has occurred earlier in this episode.
        - since_last_damage_norm: [0,1] normalized steps since last damage (domain-specific best-effort).
        - steps_left_norm: [0,1] normalized steps remaining (if available).

        Numeric values are clipped to [0, 1].
        """
        self._root_is_repetition = bool(is_repetition)
        self._since_last_damage_norm = float(np.clip(since_last_damage_norm, 0.0, 1.0))
        self._steps_left_norm = float(np.clip(steps_left_norm, 0.0, 1.0))

    def make_input(self, encoded_current: np.ndarray) -> np.ndarray:
        """Return stacked input (C*(1+K)+extra, H, W) as float32, with zeros for missing history frames.

        Channel layout: the current board, then up to K history frames in deque order
        (most recently pushed first), zero-padded to exactly K frames. Extra planes
        (broadcast) are appended in order: [steps_left, repetition, since_last_damage]
        for any enabled flags.

        NOTE(review): the current board is not pushed here; presumably the caller pushes
        it after building the input so it becomes t-1 next step. Confirm at call sites.

        Args:
            encoded_current: encoded canonical board of shape (C, H, W).

        Returns:
            float32 array. When K == 0 and no extra plane is enabled, this may be
            encoded_current itself (np.asarray does not copy float32 input).

        Raises:
            ValueError: if encoded_current is not 3-D, or if a stored history frame's
                shape is incompatible with it (from np.concatenate).
        """
        cur = np.asarray(encoded_current, dtype=np.float32)
        C, H, W = cur.shape
        if self.K <= 0:
            stacked = cur
        else:
            frames = [cur]
            # Append up to K previous frames; pad with zeros if not enough
            for i in range(self.K):
                if i < len(self._hist):
                    frames.append(self._hist[i].astype(np.float32, copy=False))
                else:
                    frames.append(np.zeros((C, H, W), dtype=np.float32))
            stacked = np.concatenate(frames, axis=0)
        # (enabled, value) pairs in the documented plane order.
        plane_specs = (
            (self.include_steps_left_plane, self._steps_left_norm),
            (self.include_repetition_plane, 1.0 if self._root_is_repetition else 0.0),
            (self.include_since_damage_plane, self._since_last_damage_norm),
        )
        extras = [
            np.full((1, H, W), np.float32(value), dtype=np.float32)
            for enabled, value in plane_specs
            if enabled
        ]
        if extras:
            stacked = np.concatenate([stacked] + extras, axis=0)
        return stacked

    def build_action_tokens(self, legal_actions: list, action_size: int) -> tuple:
        """
        Build simple (L,D_in) tokens, (L,) indices, and (A,) legal mask for a grid action space.

        D_in = 4. When the game class defines ROWS and COLS, tokens are
        [row/(H-1), col/(W-1), 1.0, 1.0], with each denominator floored at 1 so 1-wide
        boards do not divide by zero. Otherwise they are [index/(A-1), 0.0, 1.0, 1.0].

        NOTE(review): an action >= ROWS*COLS (e.g. a pass move, if the game has one)
        maps to row >= H, so its normalized row exceeds 1.0. Confirm this is intended.

        Args:
            legal_actions: legal action indices. Any 1-D sequence of ints, including
                NumPy arrays, is accepted.
            action_size: total number of actions A (length of the returned mask).

        Returns:
            (tokens float32 (L, 4), indices int64 (L,), mask bool (A,)).

        Raises:
            IndexError: if an action index is >= action_size (from the mask assignment).
        """
        A = int(action_size)
        mask = np.zeros((A,), dtype=bool)
        # Normalize once: `not legal_actions` raises on multi-element NumPy arrays and
        # wrongly reports np.array([0]) as empty, so test the size instead.
        actions = np.asarray(legal_actions, dtype=np.int64).reshape(-1)
        if actions.size == 0:
            return np.zeros((0, 4), dtype=np.float32), np.zeros((0,), dtype=np.int64), mask
        mask[actions] = True
        ones = np.ones(actions.shape, dtype=np.float32)
        if self.rows is None or self.cols is None:
            # Fallback: index normalized, no grid layout known
            idx = actions.astype(np.float32)
            tok = np.stack([
                idx / max(1.0, float(A - 1)),  # normalized index
                np.zeros_like(idx),
                ones,
                ones,
            ], axis=1).astype(np.float32)
            return tok, actions, mask
        H, W = float(self.rows), float(self.cols)
        n_cols = int(self.cols)
        rows = (actions // n_cols).astype(np.float32)
        cols = (actions % n_cols).astype(np.float32)
        tok = np.stack([
            rows / max(1.0, H - 1.0),
            cols / max(1.0, W - 1.0),
            ones,
            ones,
        ], axis=1).astype(np.float32)
        return tok, actions, mask
