import jax
import jax.numpy as jnp
from jax import Array
from typing import Any, Dict, Tuple

from gymnax.environments.environment import EnvParams
from flax import struct

from navix.actions import MINIGRID_ACTION_SET, COMPLETE_ACTION_SET
from navix.environments.environment import Environment, Timestep
from navix.entities import Entities
from navix.states import State
from .gymnax_wrappers import NavixToGymnax, GymnaxState


class NavixToGymnaxGraph(NavixToGymnax):
    """
    A Gymnax wrapper for Navix that returns, in the info dict, the local dependency graph
    (adjacency matrix) of the underlying state variables for the taken action.
    Supports both the MINIGRID_ACTION_SET and COMPLETE_ACTION_SET.

    The graph is stored under ``info['local_graph']`` as a ``D x D`` integer
    matrix, where ``D`` is the number of names returned by ``get_state_names``.
    Row ``i`` holds a 1 in every column ``j`` whose variable is treated as a
    parent of variable ``i`` for the taken action. Rows of variables whose
    value did not change during the step are replaced by the identity row.
    """

    def __init__(self, env: Environment, autoreset: bool = True):
        """Wrap ``env`` and precompute the per-action adjacency matrices.

        Args:
            env: The Navix environment to wrap.
            autoreset: Forwarded unchanged to ``NavixToGymnax.__init__``.
        """
        super().__init__(env, autoreset)
        # Filled in by `_build_graphs` below; `var_names` staying None is what
        # marks the graphs as "not built yet".
        self.var_names: list[str] | None = None
        self.I: Array | None = None
        self.minigrid_adj: Array | None = None
        self.complete_adj: Array | None = None

        # Reset once with a fixed key purely to obtain a state from which the
        # variable ordering can be read; the resulting timestep is discarded.
        timestep = env.reset(jax.random.PRNGKey(0))
        self._build_graphs(timestep.state)

    def _extract_vars(self, state: State) -> Array:
        """Return the vector of variable values in the same order as `var_names`."""
        # Delegates to the parent wrapper helper, which is expected to follow
        # the same ordering as `get_state_names`.
        return self.get_state(state)

    def step(
        self,
        key: Array,
        state: GymnaxState,
        action: jax.Array,
        params: EnvParams
    ) -> Tuple[Any, GymnaxState, Array, Array, Dict[str, Any]]:
        """Run the parent step and attach the local dependency graph to ``info``.

        Args:
            key: PRNG key forwarded to the parent ``step``.
            state: Wrapper state before the action.
            action: Index of the action taken; used directly to index the
                per-action adjacency stack, so it may be a traced value.
            params: Environment parameters forwarded to the parent ``step``.

        Returns:
            The parent's ``(obs, new_state, reward, done, info)`` tuple, with
            ``info['local_graph']`` set to the adjacency matrix for this step.

        Raises:
            ValueError: If the wrapped environment's action set has a size
                other than that of MINIGRID_ACTION_SET or COMPLETE_ACTION_SET.
        """
        # First run the regular NavixToGymnax step
        obs, new_state, reward, done, info = super().step(key, state, action, params)

        # The action set in use is recognised by its size only.
        n_actions = len(self.env.action_set)
        if n_actions == len(MINIGRID_ACTION_SET):
            per_action_adj = self.minigrid_adj
        elif n_actions == len(COMPLETE_ACTION_SET):
            per_action_adj = self.complete_adj
        else:
            raise ValueError(
                f"Unsupported Navix action-set of size {len(self.env.action_set)}"
            )
        # Index with the array itself (no `int(action)`), so this also works
        # while JAX is tracing the function.
        local_adj = per_action_adj[action]

        # If a variable didn't actually change, fall back to the identity
        # (noop) graph for that row.
        # NOTE(review): if the parent step auto-resets, `new_state` may already
        # be a post-reset state and this comparison spans two episodes - confirm.
        old_vars = self._extract_vars(state.timestep.state)
        new_vars = self._extract_vars(new_state.timestep.state)
        changed = new_vars != old_vars            # bool[D]
        # Row-wise select: changed rows keep the static row, others get identity.
        info['local_graph'] = jnp.where(changed[:, None], local_adj, self.I)
        # Variable names are not attached to `info`: strings cannot be carried
        # through JAX transformations.
        return obs, new_state, reward, done, info

    # ------------------------------------------------------------------
    # Graph construction helpers
    # ------------------------------------------------------------------
    def _index(self, name: str) -> "int | None":
        """Return the position of ``name`` in ``var_names``, or None if absent.

        Also returns None when ``var_names`` has not been populated yet.
        """
        if self.var_names is None:
            return None
        try:
            return self.var_names.index(name)
        except ValueError:
            return None

    @staticmethod
    def _add_edges(adj: Array, rows: list, cols: list) -> Array:
        """Return ``adj`` with ``adj[r, c] = 1`` for every ``r`` in ``rows`` and ``c`` in ``cols``.

        ``adj`` is returned unchanged when either index list is empty.
        """
        if not rows or not cols:
            return adj
        row_arr = jnp.asarray(rows, dtype=jnp.int32)
        col_arr = jnp.asarray(cols, dtype=jnp.int32)
        # Broadcasting a column of rows against a row of cols addresses the
        # full rows x cols block in one functional update.
        return adj.at[row_arr[:, None], col_arr[None, :]].set(1)

    def _build_graphs(self, state: State):
        """Build identity and action-specific adjacency matrices based on the
        variable ordering returned by NavixToGymnax.get_state_names.

        Populates ``var_names``, ``I``, ``minigrid_adj`` and ``complete_adj``.
        Does nothing if ``var_names`` is already set.

        Args:
            state: A Navix state passed to ``get_state_names`` to obtain the
                variable names.
        """
        if self.var_names is not None:
            return  # already built

        # variable names according to parent helper
        self.var_names = self.get_state_names(state)
        names = self.var_names
        self.I = jnp.eye(len(names), dtype=jnp.int32)

        def matching(exact: str, prefix: str) -> list:
            """Indices of variables named ``exact`` or starting with ``prefix``."""
            return [i for i, n in enumerate(names) if n == exact or n.startswith(prefix)]

        # player variables (each may be None if the state does not expose it)
        px_idx = self._index('player_x')
        py_idx = self._index('player_y')
        pdir_idx = self._index('player_dir')
        pocket_idx = self._index('player_pocket')

        # key variables
        key_x_indices = matching('key_x', 'key_x_')
        key_y_indices = matching('key_y', 'key_y_')
        key_id_indices = matching('key_id', 'key_id')

        # door variables
        door_state_indices = matching('door_state', 'door_state')
        door_x_indices = matching('door_x', 'door_x_')
        door_y_indices = matching('door_y', 'door_y_')

        player_pose = [px_idx, py_idx, pdir_idx]
        has_full_pose = all(idx is not None for idx in player_pose)
        # Player pose plus pocket, restricted to the variables that exist.
        # Computed unconditionally so that the drop graph can still be built
        # for states that expose keys but no `player_pocket` variable.
        player_parents = [idx for idx in player_pose + [pocket_idx] if idx is not None]

        # -------------------------------- movement ----------------------
        # Position rows additionally depend on the facing direction.
        A_move = self.I
        if pdir_idx is not None:
            position_rows = [idx for idx in (px_idx, py_idx) if idx is not None]
            A_move = self._add_edges(A_move, position_rows, [pdir_idx])

        # ----------------------------- pickup ---------------------------
        A_pick = self.I
        # pocket row deps: player vars, pocket itself, key pos and id
        if pocket_idx is not None:
            A_pick = self._add_edges(
                A_pick,
                [pocket_idx],
                player_parents + key_x_indices + key_y_indices + key_id_indices,
            )
        # key_x / key_y rows depend on the player pose and the old key
        # coordinates on the same axis; only wired when the full pose exists.
        if has_full_pose:
            A_pick = self._add_edges(A_pick, key_x_indices, player_pose + key_x_indices)
            A_pick = self._add_edges(A_pick, key_y_indices, player_pose + key_y_indices)

        # ----------------------------- drop -----------------------------
        A_drop = self.I
        # pocket row depends on player vars and the pocket itself
        if pocket_idx is not None:
            A_drop = self._add_edges(A_drop, [pocket_idx], player_parents)
        # key_x / key_y rows depend on player vars and pocket
        A_drop = self._add_edges(A_drop, key_x_indices, player_parents)
        A_drop = self._add_edges(A_drop, key_y_indices, player_parents)

        # ----------------------------- toggle/open ----------------------
        A_tog = self.I
        # pocket row deps: player vars, pocket itself, door state and position
        if pocket_idx is not None:
            A_tog = self._add_edges(
                A_tog,
                [pocket_idx],
                player_parents + door_state_indices + door_x_indices + door_y_indices,
            )
        # door state rows depend on player vars (incl. pocket) and door states
        A_tog = self._add_edges(
            A_tog, door_state_indices, player_parents + door_state_indices
        )

        # Build stacks according to action sets
        # Minigrid order: rotate_ccw, rotate_cw, forward, pickup, drop, toggle, done
        self.minigrid_adj = jnp.stack([
            self.I,  # rotate_ccw
            self.I,  # rotate_cw
            A_move,  # forward
            A_pick,  # pickup
            A_drop,  # drop
            A_tog,   # toggle
            self.I,  # done
        ], axis=0)

        # Complete order: noop, rotate_cw, rotate_ccw, forward, right, backward, left, pickup, open, done
        self.complete_adj = jnp.stack([
            self.I,  # noop
            self.I,  # rotate_cw
            self.I,  # rotate_ccw
            A_move,  # forward
            A_move,  # right
            A_move,  # backward
            A_move,  # left
            A_pick,  # pickup
            A_tog,   # open
            self.I,  # done
        ], axis=0)