from __future__ import annotations
from typing import Union

import jax
import jax.numpy as jnp
from jax import Array

from navix import observations, rewards, terminations

from navix.components import EMPTY_POCKET_ID
from navix.entities import Entities, Goal, Player, Wall
from navix.states import State
from navix.grid import (
    random_positions,
    random_directions,
    room,
    horizontal_wall,
    vertical_wall,
)
from navix.rendering.cache import RenderingCache
from navix.environments.environment import Environment, Timestep
from navix.environments.registry import register_env


class FourRoomsFixed(Environment):
    """Room split into four by one vertical and one horizontal partition wall.

    Each partition wall gets two randomly placed openings, and the player is
    spawned at a random cell not occupied by a wall, facing a random direction.
    """

    def _reset(self, key: Array, cache: Union[RenderingCache, None] = None) -> Timestep:
        """Sample a fresh layout and return the initial timestep.

        Args:
            key: PRNG key; it is split into an independent subkey for every
                random draw made here, plus a carry key stored in the state.
            cache: optional rendering cache to reuse. When ``None``, a new one
                is built with ``RenderingCache.init`` from the sampled grid.

        Returns:
            A ``Timestep`` with ``t=0``, ``action=0``, ``reward=0.0`` and
            ``step_type=0`` wrapping the new state and its observation.
        """
        assert self.height > 4, f"Insufficient height for room {self.height} <= 4"
        assert self.width > 4, f"Insufficient width for room {self.width} <= 4"
        # One subkey per random quantity: reusing a key across draws (same
        # shape/dtype) produces the same underlying bits, correlating samples.
        key, k_v1, k_v2, k_h1, k_h2, k_pos, k_dir = jax.random.split(key, 7)

        # map
        grid = room(height=self.height, width=self.width)
        # Partition walls sit at the room's midlines, the same points the
        # opening ranges below are centred on (both are 9 for a 19x19 room).
        mid_row = self.height // 2
        mid_col = self.width // 2

        # vertical partition: openings are drawn from the height range, one on
        # each side of mid_row (randint's maxval is exclusive).
        # NOTE(review): maxval=self.height permits index self.height - 1;
        # confirm `vertical_wall` treats that as a valid opening index.
        opening_1 = jax.random.randint(k_v1, shape=(), minval=1, maxval=mid_row)
        opening_2 = jax.random.randint(
            k_v2, shape=(), minval=mid_row + 2, maxval=self.height
        )
        openings = jnp.stack([opening_1, opening_2])
        wall_pos_vert = vertical_wall(grid, mid_col, openings)

        # horizontal partition: openings are drawn from the width range, one on
        # each side of mid_col.
        opening_1 = jax.random.randint(k_h1, shape=(), minval=1, maxval=mid_col)
        opening_2 = jax.random.randint(
            k_h2, shape=(), minval=mid_col + 2, maxval=self.width
        )
        openings = jnp.stack([opening_1, opening_2])
        wall_pos_hor = horizontal_wall(grid, mid_row, openings)

        walls_pos = jnp.concatenate([wall_pos_vert, wall_pos_hor])
        walls = Wall.create(position=walls_pos)

        # player: spawned anywhere except on a wall cell
        player_pos = random_positions(k_pos, grid, n=1, exclude=walls_pos)
        direction = random_directions(k_dir, n=1)
        player = Player.create(
            position=player_pos,
            direction=direction,
            pocket=EMPTY_POCKET_ID,
        )
        # goal
        # NOTE(review): a goal at (1, 1) is built but its entity entry below is
        # commented out, so the state holds no GOAL entity even though the
        # registered defaults use goal-reached reward/termination. Confirm this
        # is intentional. If it is re-enabled, also exclude (1, 1) from the
        # player spawn cells.
        goal = Goal.create(position=jnp.asarray([1, 1]), probability=jnp.asarray(1.0))
        entities = {
            Entities.PLAYER: player[None],
            # Entities.GOAL: goal[None],
            Entities.WALL: walls,
        }

        # systems
        state = State(
            key=key,
            grid=grid,
            cache=cache if cache is not None else RenderingCache.init(grid),
            entities=entities,
        )

        return Timestep(
            t=jnp.asarray(0, dtype=jnp.int32),
            observation=self.observation_fn(state),
            action=jnp.asarray(0, dtype=jnp.int32),
            reward=jnp.asarray(0.0, dtype=jnp.float32),
            step_type=jnp.asarray(0, dtype=jnp.int32),
            state=state,
        )


def _make_four_rooms_fixed(*args, **kwargs):
    """Build the ``FourRoomsFixed`` environment registered below.

    ``height`` and ``width`` default to 19 and may be overridden by keyword.
    ``observation_fn`` defaults to ``observations.symbolic``; ``reward_fn`` and
    ``termination_fn`` default to their ``on_goal_reached`` variants. All other
    positional and keyword arguments are forwarded to ``FourRoomsFixed.create``.
    """
    return FourRoomsFixed.create(
        # Positional args first: star-unpacking after keywords is legal but
        # misleading, since Python binds it before the keywords (B026).
        *args,
        height=kwargs.pop("height", 19),
        width=kwargs.pop("width", 19),
        observation_fn=kwargs.pop("observation_fn", observations.symbolic),
        reward_fn=kwargs.pop("reward_fn", rewards.on_goal_reached),
        termination_fn=kwargs.pop("termination_fn", terminations.on_goal_reached),
        **kwargs,
    )


register_env("Navix-FourRooms-fixed-v0", _make_four_rooms_fixed)
