from typing import Union
import jax
import jax.numpy as jnp
from jax import Array
from flax import struct

from navix import rewards, observations, terminations
from navix.components import EMPTY_POCKET_ID, DISCARD_PILE_COORDS
from navix.rendering.cache import RenderingCache
from navix.rendering.registry import PALETTE
from navix.environments import Environment
from navix.entities import Player, Key, Door, Goal, Wall
from navix.states import State
from navix import Timestep
from navix.grid import mask_by_coordinates, room, random_positions, random_directions
from navix.environments.registry import register_env

# ======================= Sampling probabilities ==============================

# Chance, when the door starts closed, that the player already holds the key.
PROB_KEY_IN_POCKET_IF_DOOR_CLOSED: float = 0.5

# ======================= Key / goal placement ================================
# When True, a key lying on the floor is always placed at KEY_COORDS; when
# False, its cell is drawn uniformly from the left-hand room.
KEY_FIXED_POS: bool = True
# (row, col) of the fixed key. Negative entries count back from the grid edge,
# Python-style: (-2, 1) is the second row from the bottom and the second column
# from the left, i.e. just inside the bottom-left wall corner.
KEY_COORDS: tuple[int, int] = (-2, 1)

# When True, the goal is always placed at GOAL_COORDS (same negative-index
# convention as KEY_COORDS); when False, it is drawn uniformly from the
# right-hand room.
GOAL_FIXED_POS: bool = True
GOAL_COORDS: tuple[int, int] = (-2, -2)

# ======================= Player spawn biases =================================
# Chance of spawning on a cell adjacent to the key and facing it; only used
# while the key is still on the floor.
PROB_ADJACENT_KEY: float = 0.4
# Chance of spawning on the cell in front of a closed door (inside the left
# room) and facing it; only drawn when the key-adjacent spawn was not chosen.
PROB_ADJACENT_DOOR: float = PROB_ADJACENT_KEY

# Chance that the door starts open: one minus 0.9 times the summed adjacency
# probabilities. NOTE(review): the origin of the 0.9 factor is not documented
# here - confirm the intended value.
PROB_DOOR_OPEN: float = 1.0 - 0.9 * (PROB_ADJACENT_KEY + PROB_ADJACENT_DOOR)

class DoorKeyUniform(Environment):
    """DoorKey environment with a broadly randomised start-state distribution.

    Every reset builds a bordered room split by a vertical wall that contains
    a single yellow door. Depending on the module-level constants, the door
    may start open, the key may already be in the player's pocket, and the
    player may spawn next to (and facing) the key or the closed door.
    """

    # Static (non-pytree) configuration flag; ``_reset`` does not read it.
    random_start: bool = struct.field(pytree_node=False, default=False)

    def _reset(self, key: Array, cache: Union[RenderingCache, None] = None) -> Timestep:
        """Sample an initial state and wrap it in the first ``Timestep``.

        Sampling rules applied below:
          1. An open door implies the key is already pocketed; with a closed
             door the key is pocketed with ``PROB_KEY_IN_POCKET_IF_DOOR_CLOSED``.
          2. With a closed door the player is confined to the left room.
          3. The player spawns next to (and facing) the floor key with
             ``PROB_ADJACENT_KEY``, otherwise next to the closed door with
             ``PROB_ADJACENT_DOOR``, otherwise uniformly at random. The player
             is never placed on the key, door or goal cell.

        Args:
            key: PRNG key consumed by every random draw of this reset.
            cache: rendering cache to reuse; when ``None`` a new one is built
                from the sampled grid.

        Returns:
            A ``Timestep`` with ``t=0``, ``action=-1``, ``reward=0.0`` and
            ``step_type=0`` holding the new ``State``.
        """
        # check minimum height and width
        assert (
            self.height > 3
        ), f"Room height must be greater than 3, got {self.height} instead"
        assert (
            self.width > 4
        ), f"Room width must be greater than 4, got {self.width} instead"

        # Split once up-front so every random draw uses an independent key.
        (
            key,
            k1,  # door col
            k2,  # door row
            k3,  # door open bernoulli
            k4,  # random player position (fallback)
            k5,  # random player orientation (fallback)
            k6,  # key-in-pocket bernoulli (when door closed)
            k7,  # spawn adjacent to key bernoulli
            k8,  # spawn adjacent to door bernoulli
            k9,  # sample specific adjacent-to-key candidate
            k10,  # sample specific adjacent-to-door candidate
            k11,  # bias face key (currently unused)
            k12,  # bias face door (currently unused)
            k13,  # key floor random position (if not fixed)
            k14,  # goal random position (if not fixed)
        ) = jax.random.split(key, 15)

        grid = room(height=self.height, width=self.width)

        # ---------------- Door ----------------
        # door_col in [2, width - 3] leaves at least one interior column on
        # each side of the dividing wall.
        door_col = jax.random.randint(k1, (), 2, self.width - 2)
        door_row = jax.random.randint(k2, (), 1, self.height - 1)
        door_pos = jnp.asarray((door_row, door_col))

        door_open = jax.random.bernoulli(k3, p=jnp.asarray(PROB_DOOR_OPEN))

        doors = Door.create(
            position=door_pos,
            # A closed door requires key id 3 (the key created below); an
            # already-open door gets -1.
            requires=jnp.where(door_open, -1, 3).astype(jnp.int32),
            open=jnp.asarray(door_open).astype(jnp.int32),  # obey sampled status
            colour=PALETTE.YELLOW,
        )

        # ---------------- Dividing wall ----------------
        wall_rows = jnp.arange(1, self.height - 1)
        wall_cols = jnp.asarray([door_col] * (self.height - 2))
        wall_pos = jnp.stack((wall_rows, wall_cols), axis=1)
        # remove the wall segment where the door is
        wall_pos = jnp.delete(
            wall_pos, door_row - 1, axis=0, assume_unique_indices=True
        )
        walls = Wall.create(position=wall_pos)

        # ---------------- Rooms ----------------
        # Columns strictly left / right of the door column; every other cell is
        # set to -1 (treated as a wall).
        first_room_mask = mask_by_coordinates(
            grid, (jnp.asarray(self.height), door_col), jnp.less
        )
        first_room = jnp.where(first_room_mask, grid, -1)
        second_room_mask = mask_by_coordinates(
            grid, (jnp.asarray(0), door_col), jnp.greater
        )
        second_room = jnp.where(second_room_mask, grid, -1)
        full_room = jnp.where(first_room_mask | second_room_mask, grid, -1)

        # rule (2): if the door is closed the player must stay in the left room
        allowed_player_grid = jnp.where(door_open, full_room, first_room)

        # rule (1): door_open => key_in_pocket; otherwise draw a Bernoulli.
        key_in_pocket = jnp.logical_or(
            door_open,
            jax.random.bernoulli(k6, p=jnp.asarray(PROB_KEY_IN_POCKET_IF_DOOR_CLOSED)),
        )

        def _resolve(coords: tuple[int, int]) -> Array:
            """Turn a possibly-negative (row, col) into absolute grid indices."""
            row, col = coords
            if row < 0:
                row += self.height
            if col < 0:
                col += self.width
            return jnp.asarray([row, col])

        # ---------------- Key position on the floor ----------------
        if KEY_FIXED_POS:
            key_pos_floor = _resolve(KEY_COORDS)
        else:
            # Keep the cell directly in front of the door (its left neighbour)
            # free of the key.
            key_pos_floor = random_positions(
                k13, first_room, exclude=door_pos - jnp.asarray([0, 1])
            )
        # A pocketed key is parked at DISCARD_PILE_COORDS instead of the floor.
        key_pos = jnp.where(key_in_pocket[..., None], DISCARD_PILE_COORDS, key_pos_floor)

        # ---------------- Goal position ----------------
        if GOAL_FIXED_POS:
            goal_pos = _resolve(GOAL_COORDS)
        else:
            goal_pos = random_positions(k14, second_room)

        # Cells the player must never spawn on. When pocketed, the key sits at
        # DISCARD_PILE_COORDS (presumably off-grid, so it excludes nothing).
        occupied = jnp.stack([key_pos, door_pos, goal_pos])

        def _adjacent_position(target_pos: Array, rng: Array):
            """Sample a free neighbour of `target_pos` and the direction facing it.

            A candidate must lie inside the grid, be walkable (value 0) in
            `allowed_player_grid`, and not coincide with any `occupied` cell.
            If no neighbour qualifies, fall back to a uniformly random free
            position and a random direction.
            """
            # Row i places the player west / north / east / south of the
            # target, so it must face E(0) / S(1) / W(2) / N(3) respectively.
            offsets = jnp.asarray([[0, -1], [-1, 0], [0, 1], [1, 0]], dtype=jnp.int32)
            directions = jnp.asarray([0, 1, 2, 3], dtype=jnp.int32)
            candidates = target_pos + offsets  # shape (4, 2)

            in_bounds = (
                (candidates[:, 0] >= 0)
                & (candidates[:, 0] < self.height)
                & (candidates[:, 1] >= 0)
                & (candidates[:, 1] < self.width)
            )
            # Out-of-range gathers do not fail in JAX; `in_bounds` masks them.
            free = allowed_player_grid[candidates[:, 0], candidates[:, 1]] == 0
            # Never put the player on top of the key, the door or the goal.
            on_entity = jnp.any(
                jnp.all(candidates[:, None, :] == occupied[None, :, :], axis=-1),
                axis=-1,
            )

            valid = in_bounds & free & ~on_entity
            valid_f = valid.astype(jnp.float32)
            total_valid = valid_f.sum()

            def _choose(_):
                # Guarded denominator: under `vmap`, `lax.cond` evaluates both
                # branches, so this must not produce NaN probabilities.
                probs = valid_f / jnp.maximum(total_valid, 1.0)
                idx = jax.random.choice(rng, 4, p=probs)
                return candidates[idx], directions[idx]

            def _fallback(_):
                # Independent keys so position and direction are uncorrelated.
                pos_rng, dir_rng = jax.random.split(rng)
                return (
                    random_positions(pos_rng, allowed_player_grid, exclude=occupied),
                    random_directions(dir_rng),
                )

            return jax.lax.cond(total_valid > 0, _choose, _fallback, operand=None)

        # ---------------- Player spawn ----------------
        spawn_adj_key = (~key_in_pocket) & jax.random.bernoulli(
            k7, p=jnp.asarray(PROB_ADJACENT_KEY)
        )
        spawn_adj_door = (~door_open) & (~spawn_adj_key) & jax.random.bernoulli(
            k8, p=jnp.asarray(PROB_ADJACENT_DOOR)
        )

        pos_key, dir_key = _adjacent_position(key_pos_floor, k9)
        pos_door, dir_door = _adjacent_position(door_pos, k10)

        # Fallback random placement
        pos_rand = random_positions(k4, allowed_player_grid, exclude=occupied)
        dir_rand = random_directions(k5)

        player_pos = jnp.where(
            spawn_adj_key[..., None],
            pos_key,
            jnp.where(spawn_adj_door[..., None], pos_door, pos_rand),
        )
        player_dir = jnp.where(
            spawn_adj_key,
            dir_key,
            jnp.where(spawn_adj_door, dir_door, dir_rand),
        )

        # ---------------- Entities & state ----------------
        pocket_id = jnp.where(key_in_pocket, jnp.asarray(3), EMPTY_POCKET_ID)
        player = Player.create(position=player_pos, direction=player_dir, pocket=pocket_id)
        keys = Key.create(position=key_pos, id=jnp.asarray(3), colour=PALETTE.YELLOW)
        goals = Goal.create(position=goal_pos, probability=jnp.asarray(1.0))

        # Clear the door cell in the base grid; the Door entity occupies it.
        grid = grid.at[tuple(door_pos)].set(0)

        entities = {
            "player": player[None],
            "key": keys[None],
            "door": doors[None],
            "goal": goals[None],
            "wall": walls,
        }

        state = State(
            key=key,
            grid=grid,
            cache=cache if cache is not None else RenderingCache.init(grid),
            entities=entities,
        )
        return Timestep(
            t=jnp.asarray(0, dtype=jnp.int32),
            observation=self.observation_fn(state),
            action=jnp.asarray(-1, dtype=jnp.int32),
            reward=jnp.asarray(0.0, dtype=jnp.float32),
            step_type=jnp.asarray(0, dtype=jnp.int32),
            state=state,
        )


def _register_door_key_uniform(size: int) -> None:
    """Register the square ``size`` x ``size`` DoorKeyUniform variant.

    The registered factory supplies default observation, reward and
    termination functions (each overridable by keyword), pins the grid to
    ``size`` x ``size`` with ``random_start=False``, and forwards every other
    positional / keyword argument to ``DoorKeyUniform.create``.
    """

    def _factory(*args, **kwargs):
        observation_fn = kwargs.pop("observation_fn", observations.symbolic)
        reward_fn = kwargs.pop("reward_fn", rewards.on_goal_reached)
        termination_fn = kwargs.pop("termination_fn", terminations.on_goal_reached)
        return DoorKeyUniform.create(
            *args,
            observation_fn=observation_fn,
            reward_fn=reward_fn,
            termination_fn=termination_fn,
            height=size,
            width=size,
            random_start=False,
            **kwargs,
        )

    register_env(f"Navix-DoorKey-Uniform-{size}x{size}-v0", _factory)


# Registered in the same order as before: 5x5, 6x6, 8x8, 16x16.
for _size in (5, 6, 8, 16):
    _register_door_key_uniform(_size)
del _size