# Edited from JaxMarl: https://github.com/FLAIROx/JaxMARL/tree/main/jaxmarl/environments/overcooked

from enum import IntEnum
import time

import numpy as np
import jax
import jax.numpy as jnp
from jax import lax
from typing import Callable, Optional, Tuple, Dict
import chex
from flax import struct


from jaxmarl.environments import spaces

from src.envs.ogc.common import make_overcooked_map
from src.envs.ogc.underspecified_env import UnderspecifiedMultiAgentEnv

from src.envs.overcooked.augmented_layouts import overcooked_layouts

# Hard-coded ASCII kitchen layouts. Each triple-quoted string starts and ends with a
# newline and holds one text row per grid row; the name suffix is <rows>_<cols>
# (e.g. the "_6_9" layouts have 6 rows of 9 characters, "_5_5" have 5 rows of 5).
# NOTE(review): the character legend is assumed to follow the JaxMARL Overcooked
# convention (W counter/wall, A agent, P pot, O onion pile, B plate pile,
# X delivery goal, space = floor) — confirm against the layout parser.
asymm_advantages_6_9 = """
WWWWWWWWW
O WXWOW X
W   P A W
WA  P   W
WWWBWBWWW
WWWWWWWWW
"""

counter_circuit_6_9 = """
WWWPPWWWW
W A    WW
B WWWW XW
W     AWW
WWWOOWWWW
WWWWWWWWW
"""

forced_coord_6_9 = """
WWWPWWWWW
OAWAPWWWW
O W WWWWW
B W WWWWW
WWWXWWWWW
WWWWWWWWW
"""

cramped_room_6_9 = """
WWPWWWWWW
OAA OWWWW
W   WWWWW
WBWXWWWWW
WWWWWWWWW
WWWWWWWWW
"""

coord_ring_6_9 = """
WWWPWWWWW
WA APWWWW
B W WWWWW
O   WWWWW
WOXWWWWWW
WWWWWWWWW
"""

forced_coord_5_5 = """
WWWPW
OAWAP
O W W
B W W
WWWXW
"""

cramped_room_5_5 = """
WWPWW
OAA O
W   W
WBWXW
WWWWW
"""

coord_ring_5_5 = """
WWWPW
WA AP
B W W
O   W
WOXWW
"""


class Actions(IntEnum):
    """Discrete agent actions.

    Values 0-3 are the movement actions; OGC.step_agents uses them directly as
    indices into DIR_TO_VEC (both for the step vector and the new facing index).
    """
    # NOTE(review): DIR_TO_VEC in this module is ordered NORTH, SOUTH, EAST, WEST,
    # so `right` (0) actually moves by (0, -1), `left` (2) by (1, 0), etc. The
    # member names do not match the resulting movement — confirm this is intended.
    right = 0
    down = 1
    left = 2
    up = 3
    stay = 4
    interact = 5
    done = 6  # Not part of OGC.action_set, so it cannot be selected through step_env.


@struct.dataclass
class EnvState:
    """Dynamic state of one OGC episode (a flax struct dataclass, hence a JAX pytree)."""
    # Per-agent (x, y) grid coordinates, one row per agent (y is the row index).
    agent_pos: chex.Array
    # Per-agent facing vector; step_agents keeps it equal to DIR_TO_VEC[agent_dir_idx].
    agent_dir: chex.Array
    # Per-agent index into DIR_TO_VEC.
    agent_dir_idx: chex.Array
    # Per-agent held object, as an OBJECT_TO_INDEX value ("empty" when nothing is held).
    agent_inv: chex.Array
    # NOTE(review): OGC.step_agents treats each row as an (x, y) goal coordinate,
    # whereas Level.goal_pos is an (h, w) 0/1 mask — confirm how reset fills this.
    goal_pos: chex.Array
    # (height, width) 0/1 mask of pot cells (used as a mask in the OGC methods).
    pot_pos: chex.Array
    # (height, width) boolean mask of counter/wall cells, indexed [y, x].
    wall_map: chex.Array
    # Grid with 3 channels: object index, colour index, extra state (pot status for
    # pots, facing index for agents). The OGC methods assume 4 cells of padding per side.
    maze_map: chex.Array
    # Not read by the visible OGC methods; presumably plate-pile positions — confirm.
    bowl_pile_pos: chex.Array
    # Not read by the visible OGC methods.
    onion_pile_pos: chex.Array
    # Number of steps taken so far (incremented by OGC.step_env).
    time: int
    # Whether the episode has ended.
    terminal: bool


@struct.dataclass
class EnvParams:
    """Static environment parameters passed alongside the state.

    NOTE(review): the visible OGC methods take the episode horizon from
    ``OGC.max_steps`` and ignore ``params``; confirm where these fields are read.
    """
    # height: int = 6
    # width: int = 9
    # h_min: int = 4
    # w_min: int = 4
    # n_walls: int = 5
    # agent_view_size: int = 5
    # replace_wall_pos: bool = False
    # normalize_obs: bool = False
    # sample_n_walls: bool = False  # Sample n_walls uniformly in [0, n_walls]
    max_steps: int = 400
    # singleton_seed: int = -1
    max_episode_steps: int = 400


@struct.dataclass
class Observation:
    """Observation container holding a single ``image`` array.

    NOTE(review): OGC.get_obs_sparse returns plain dicts of flat arrays rather than
    this class; confirm whether it is still used.
    """
    image: chex.Array


@struct.dataclass
class EnvParam:
    """Per-level layout description: agent start state and object positions.

    NOTE(review): not referenced by the code visible in this module, and its name
    differs from ``EnvParams`` only by the trailing "s" — confirm it is still needed.
    """
    agent_pos: chex.Array  # per-agent start coordinates
    agent_dir_idx: chex.Array  # per-agent start facing index (into DIR_TO_VEC)
    agent_inv: chex.Array  # per-agent starting held object
    goal_pos: chex.Array
    pot_pos: chex.Array
    onion_pile_pos: chex.Array
    plate_pile_pos: chex.Array
    wall_map: chex.Array


@struct.dataclass
class Level:
    """Static description of one Overcooked kitchen.

    As built by ``from_layout_name``: ``wall_map`` is a boolean ``(height, width)``
    mask; ``goal_pos``, ``plate_pile_pos``, ``onion_pile_pos`` and ``pot_pos`` are
    ``(height, width)`` int32 0/1 masks; ``agent_pos`` has one ``(x, y)`` row per
    agent; ``agent_idx`` holds the agents' flat cell indices; ``agent_dir_idx`` holds
    one DIR_TO_VEC index per agent; ``empty_table_idx`` is ``wall_map`` minus the
    mask of object cells (1 on bare counters, assuming objects sit on counters).
    """
    height: int
    width: int

    wall_map: chex.Array
    empty_table_idx: chex.Array
    agent_pos: chex.Array
    agent_idx: chex.Array
    goal_pos: chex.Array
    plate_pile_pos: chex.Array
    onion_pile_pos: chex.Array
    pot_pos: chex.Array

    agent_dir_idx: chex.Array

    @classmethod
    def stack(cls, levels):
        """Batch ``levels`` into one Level whose leaves gain a leading axis.

        Every level is passed through ``pad_to_shape`` with the largest width and
        height found; since ``pad_to_shape`` performs no padding, all levels must
        already share one shape for ``jnp.stack`` to succeed.
        """
        dims = np.array([(lvl.wall_map.shape[1], lvl.wall_map.shape[0])
                        for lvl in levels])
        max_width, max_height = dims.max(axis=0)
        padded = [lvl.pad_to_shape(max_width, max_height) for lvl in levels]
        return jax.tree.map(lambda *leaves: jnp.stack(leaves), *padded)

    def pad_to_shape(self, w, h):
        """Placeholder for padding to ``(h, w)``: returns the level unchanged."""
        return self

    @classmethod
    def from_layout_name(cls, layout_name):
        """Build a Level from the ``overcooked_layouts`` entry called ``layout_name``.

        The entry supplies ``height``, ``width`` and row-major flat cell indices
        (``idx = y * width + x``) for walls, agents, goals, piles and pots. Agent
        facings are sampled from the fixed key ``PRNGKey(0)``, so the result is
        deterministic per layout.
        """
        layout = overcooked_layouts[layout_name]
        h, w = layout.get("height"), layout.get("width")
        flat_zeros = jnp.zeros((h * w), dtype=jnp.int32)

        def _mask(idx):
            # Flat cell indices -> (h, w) int32 mask with ones at those cells.
            return flat_zeros.at[idx].set(1).reshape(h, w)

        wall_map = _mask(layout.get("wall_idx")).astype(jnp.bool_)

        # One (x, y) row per agent.
        agent_idx = layout.get("agent_idx")
        agent_pos = jnp.array([agent_idx % w, agent_idx // w],
                              dtype=jnp.uint32).transpose()

        # Split the fixed key twice and sample from the second split's subkey;
        # this exact sequence keeps the sampled facings stable.
        key = jax.random.split(jax.random.PRNGKey(0))[0]
        dir_key = jax.random.split(key)[1]
        agent_dir_idx = jax.random.choice(
            dir_key, jnp.arange(len(DIR_TO_VEC), dtype=jnp.int32), shape=(2,))

        goal_idx = layout.get("goal_idx")
        onion_pile_idx = layout.get("onion_pile_idx")
        plate_pile_idx = layout.get("plate_pile_idx")
        pot_idx = layout.get("pot_idx")

        # Flat mask of every cell holding a goal, pile or pot.
        occupied = flat_zeros
        for idx in (goal_idx, onion_pile_idx, plate_pile_idx, pot_idx):
            occupied = occupied.at[idx].set(1)

        return cls(
            height=h,
            width=w,
            wall_map=wall_map,
            empty_table_idx=(wall_map - occupied.reshape(h, w)),
            agent_pos=agent_pos,
            agent_idx=agent_idx,
            goal_pos=_mask(goal_idx),
            plate_pile_pos=_mask(plate_pile_idx),
            onion_pile_pos=_mask(onion_pile_idx),
            pot_pos=_mask(pot_idx),
            agent_dir_idx=agent_dir_idx,
        )


# Pot status is a single integer counting down from POT_EMPTY_STATUS (23) to
# POT_READY_STATUS (0): 23/22/21 = empty / 1 onion / 2 onions, 20 = 3 onions.
# Once the third onion is added the pot starts cooking and the status acts as a
# countdown timer; 0 means the soup is ready.
MAX_ONIONS_IN_POT = 3  # A pot holds at most 3 onions; every soup has exactly 3.
POT_EMPTY_STATUS = 23
POT_FULL_STATUS = POT_EMPTY_STATUS - MAX_ONIONS_IN_POT  # == 20
POT_READY_STATUS = 0

# The urgency observation layer switches on once fewer than this many steps remain.
URGENCY_CUTOFF = 40
# Sparse reward for delivering one soup.
DELIVERY_REWARD = 20


# Named shaped-reward magnitudes.
# NOTE(review): OGC.process_interact (as visible here) hard-codes 3.0 / 5.0 / 5.0
# for onion-in-pot / plate pickup / soup pickup and does not read this dict.
SHAPED_REWARD = dict(
    PLACEMENT_IN_POT_REW=0,
    DISH_PICKUP_REWARD=3,
    SOUP_PICKUP_REWARD=5,
    PICKUP_TOMATO_REWARD=0,
    DISH_DISP_DISTANCE_REW=0,
    POT_DISTANCE_REW=0,
    SOUP_DISTANCE_REW=0,
)

# Object type -> integer id. The id is stored in channel 0 of maze_map and is the
# representation of an agent's inventory. "dish" is a plated soup (a plate used
# on a ready pot); "plate" is an empty plate.
OBJECT_TO_INDEX = {
    name: idx
    for idx, name in enumerate((
        "unseen",
        "empty",
        "wall",
        "onion",
        "onion_pile",
        "plate",
        "plate_pile",
        "goal",
        "pot",
        "dish",
        "agent",
    ))
}


# Named RGB colours.
COLORS = {
    name: np.array(rgb)
    for name, rgb in (
        ('red', [255, 0, 0]),
        ('green', [0, 255, 0]),
        ('blue', [0, 0, 255]),
        ('purple', [112, 39, 195]),
        ('yellow', [255, 255, 0]),
        ('grey', [100, 100, 100]),
        ('white', [255, 255, 255]),
        ('black', [25, 25, 25]),
        ('orange', [230, 180, 0]),
    )
}


# Colour name -> integer id, numbered in the declaration order of COLORS.
COLOR_TO_INDEX = {name: idx for idx, name in enumerate(COLORS)}

# Registry of this module's ASCII layouts, keyed by their variable names.
LAYOUT_STR_TO_LAYOUT = dict(
    asymm_advantages_6_9=asymm_advantages_6_9,
    counter_circuit_6_9=counter_circuit_6_9,
    forced_coord_6_9=forced_coord_6_9,
    cramped_room_6_9=cramped_room_6_9,
    coord_ring_6_9=coord_ring_6_9,
    coord_ring_5_5=coord_ring_5_5,
    forced_coord_5_5=forced_coord_5_5,
    cramped_room_5_5=cramped_room_5_5,
)


# Default maze_map cell (object id, colour id, extra state) for each object id;
# row i corresponds to OBJECT_TO_INDEX value i. Objects without a colour use 0.
OBJECT_INDEX_TO_VEC = jnp.array(
    [
        [OBJECT_TO_INDEX[obj], 0 if color is None else COLOR_TO_INDEX[color], 0]
        for obj, color in (
            ('unseen', None),
            ('empty', None),
            ('wall', 'grey'),
            ('onion', "yellow"),
            ('onion_pile', "yellow"),
            ('plate', "white"),
            ('plate_pile', "white"),
            ('goal', 'green'),
            ('pot', 'black'),
            ('dish', "white"),
            ('agent', 'red'),  # Default colour and direction
        )
    ],
    dtype=jnp.uint8,
)


# Agent direction index -> (dx, dy) step vector, where y is the grid row index.
# NOTE(review): OGC.step_agents indexes this table directly with Actions values
# 0-3, so Actions.right moves NORTH, Actions.down SOUTH, Actions.left EAST and
# Actions.up WEST. Confirm the Actions names vs. this ordering are intended.
DIR_TO_VEC = jnp.array(
    [
        [0, -1],  # 0: NORTH
        [0, 1],   # 1: SOUTH
        [1, 0],   # 2: EAST
        [-1, 0],  # 3: WEST
    ],
    dtype=jnp.int8,
)


class OGC(UnderspecifiedMultiAgentEnv):
    """Overcooked Procedural Multi-Agent"""

    def __init__(
        self,
        height: int,
        width: int,
        random_reset: bool = False,
        max_steps=400,
    ):
        """Set up a two-agent Overcooked environment.

        Args:
            height: Number of rows of the unpadded grid.
            width: Number of columns of the unpadded grid.
            random_reset: Stored as ``self.random_reset``.
            max_steps: Episode horizon, used by ``is_terminal`` and by the urgency
                layer of ``get_obs_sparse``.
        """
        # Grid geometry.
        self.width, self.height = width, height

        # Exactly two agents, always in this order.
        self.num_agents = 2
        self.agents = ["agent_0", "agent_1"]

        # Hard-coded view size; the observation code in this class does not read it.
        self.agent_view_size = 5

        # Full-grid sparse observation with 26 layers.
        self.get_obs = self.get_obs_sparse
        self.obs_shape = (self.width, self.height, 26)

        # Actions selectable through step_env (Actions.done is not included).
        movement = [Actions.right, Actions.down, Actions.left, Actions.up]
        self.action_set = jnp.array(movement + [Actions.stay, Actions.interact])

        self.random_reset = random_reset
        self.max_steps = max_steps

    def step_env(
            self,
            key: chex.PRNGKey,
            state: EnvState,
            actions: Dict[str, chex.Array],
            params: EnvParams,
    ) -> Tuple[Dict[str, chex.Array], EnvState, Dict[str, float], Dict[str, bool], Dict]:
        """Advance the environment by a single step.

        Args:
            key: PRNG key forwarded to ``step_agents``.
            state: Current state.
            actions: Per-agent indices into ``self.action_set``, keyed by agent name.
            params: Not used; the horizon comes from ``self.max_steps``.

        Returns:
            ``(obs, state, rewards, dones, info)``. Both agents receive the same
            sparse reward; ``info`` carries the sparse reward per agent and each
            agent's own shaped reward.
        """
        agent_names = ("agent_0", "agent_1")
        action_idx = jnp.array([actions[name] for name in agent_names])
        acts = self.action_set.take(indices=action_idx)

        state, reward, shaped_alice, shaped_bob = self.step_agents(key, state, acts)

        # Tick the clock first so termination is judged on the new time.
        state = state.replace(time=state.time + 1)
        done = self.is_terminal(state)
        state = state.replace(terminal=done)

        obs = self.get_obs(state)
        rewards = {name: reward for name in agent_names}
        dones = {name: done for name in agent_names}
        dones["__all__"] = done
        info = {
            "sparse_reward": jnp.array([reward, reward]),
            "shaped_reward": jnp.array([shaped_alice, shaped_bob]),
        }
        return lax.stop_gradient(obs), lax.stop_gradient(state), rewards, dones, info

    def get_obs_sparse(self, state: EnvState) -> Dict[str, chex.Array]:
        """Return each agent's full-grid observation, flattened.

        Each observation is assembled as a ``(height, width, 26)`` uint8 array and
        then flattened. Layers are binary (0/1) unless noted otherwise; the encoding
        is very sparse. Agent-specific layers are ordered so that an agent perceives
        its own layers first; env layers are identical for both agents.

        Agent positions:
        0. position of agent i (1 at the agent's cell, 0 otherwise)
        1. position of agent (1-i)

        Agent orientations:
        2-5. agent_{i}_orientation_0 to agent_{i}_orientation_3 (all zero except the
        layer matching the agent's direction index, which has a single 1 at the
        agent's cell)
        6-9. agent_{1-i}_orientation_{dir}

        Static env positions (1 where an object of that type is, 0 otherwise):
        10. pot locations
        11. counter locations (cells whose object is ``wall``; a counter with an
            item on it shows the item instead)
        12. onion pile locations
        13. tomato pile locations (always zero; tomatoes are not supported)
        14. plate pile locations
        15. delivery (goal) locations

        Pot and soup layers (non-binary):
        16. number of onions (0-3) in pots that have not started cooking; 0 once a
            pot is cooking or ready
        17. number of tomatoes in pot (always zero)
        18. number of onions in soup: 3 for cooking/ready pots and for dishes
            (plated soups), whether on a counter or held by an agent; 0 otherwise
        19. number of tomatoes in soup (always zero)
        20. pot cooking time remaining: 19 -> 1 while cooking, 0 otherwise
        21. soup ready: 1 for ready pots and for dishes, whether on a counter or
            held by an agent; 0 otherwise

        Variable env layers (binary; an item held by an agent is shown at the
        agent's cell):
        22. plate locations
        23. onion locations
        24. tomato locations (always zero)

        Urgency:
        25. all ones when fewer than URGENCY_CUTOFF (40) steps remain, else zeros

        Returns:
            ``{"agent_0": obs, "agent_1": obs}``, each a flat uint8 array of
            ``height * width * 26`` entries.
        """
        width = self.obs_shape[0]
        height = self.obs_shape[1]
        n_channels = self.obs_shape[2]
        padding = 4  # maze_map carries 4 cells of padding on every side

        # Object-index channel of the unpadded grid.
        maze_map = state.maze_map[padding:-padding, padding:-padding, 0]
        soup_loc = jnp.array(
            maze_map == OBJECT_TO_INDEX["dish"], dtype=jnp.uint8)

        pot_loc_layer = jnp.array(
            maze_map == OBJECT_TO_INDEX["pot"], dtype=jnp.uint8)
        pot_status = state.maze_map[padding:-padding,
                                    padding:-padding, 2] * pot_loc_layer
        # 0/1/2/3 onions, only while the pot has not started cooking.
        onions_in_pot_layer = (
            jnp.minimum(POT_EMPTY_STATUS - pot_status, MAX_ONIONS_IN_POT)
            * (pot_status >= POT_FULL_STATUS)
        )
        # 0/3: cooking or ready pots, plus dishes lying on the grid.
        onions_in_soup_layer = (
            jnp.minimum(POT_EMPTY_STATUS - pot_status, MAX_ONIONS_IN_POT)
            * (pot_status < POT_FULL_STATUS)
            * pot_loc_layer
            + MAX_ONIONS_IN_POT * soup_loc
        )
        # Cooking countdown (19 -> 1); 0 for pots that are not cooking or are ready.
        pot_cooking_time_layer = pot_status * (pot_status < POT_FULL_STATUS)
        # Ready soups, in a pot or plated on the grid.
        soup_ready_layer = (
            pot_loc_layer * (pot_status == POT_READY_STATUS) + soup_loc
        )
        urgency_layer = jnp.ones(maze_map.shape, dtype=jnp.uint8) * (
            (self.max_steps - state.time) < URGENCY_CUTOFF)

        agent_pos_layers = jnp.zeros((2, height, width), dtype=jnp.uint8)
        agent_pos_layers = agent_pos_layers.at[
            0, state.agent_pos[0, 1], state.agent_pos[0, 0]].set(1)
        agent_pos_layers = agent_pos_layers.at[
            1, state.agent_pos[1, 1], state.agent_pos[1, 0]].set(1)
        agent_cells = jnp.sum(agent_pos_layers, 0)

        # Show held items at the agents' cells. This works because loose items and
        # agents cannot occupy the same cell.
        agent_inv_items = jnp.expand_dims(
            state.agent_inv, (1, 2)) * agent_pos_layers
        held_items = jnp.sum(agent_inv_items, 0)
        maze_map = jnp.where(agent_cells, held_items, maze_map)

        # Soups carried by agents are ready soups containing 3 onions. (Parenthesised
        # expressions are used here on purpose: a dangling `+` on its own line would
        # be a separate, discarded statement.)
        held_soup = (held_items == OBJECT_TO_INDEX["dish"]) * agent_cells
        soup_ready_layer = soup_ready_layer + held_soup
        onions_in_soup_layer = onions_in_soup_layer + held_soup * MAX_ONIONS_IN_POT

        env_layers = jnp.stack([
            # Channel 10
            jnp.array(maze_map == OBJECT_TO_INDEX["pot"], dtype=jnp.uint8),
            jnp.array(maze_map == OBJECT_TO_INDEX["wall"], dtype=jnp.uint8),
            jnp.array(
                maze_map == OBJECT_TO_INDEX["onion_pile"], dtype=jnp.uint8),
            # tomato pile
            jnp.zeros(maze_map.shape, dtype=jnp.uint8),
            jnp.array(
                maze_map == OBJECT_TO_INDEX["plate_pile"], dtype=jnp.uint8),
            # 15
            jnp.array(maze_map == OBJECT_TO_INDEX["goal"], dtype=jnp.uint8),
            jnp.array(onions_in_pot_layer, dtype=jnp.uint8),
            # tomatoes in pot
            jnp.zeros(maze_map.shape, dtype=jnp.uint8),
            jnp.array(onions_in_soup_layer, dtype=jnp.uint8),
            # tomatoes in soup
            jnp.zeros(maze_map.shape, dtype=jnp.uint8),
            # 20
            jnp.array(pot_cooking_time_layer, dtype=jnp.uint8),
            jnp.array(soup_ready_layer, dtype=jnp.uint8),
            jnp.array(maze_map == OBJECT_TO_INDEX["plate"], dtype=jnp.uint8),
            jnp.array(maze_map == OBJECT_TO_INDEX["onion"], dtype=jnp.uint8),
            # tomatoes
            jnp.zeros(maze_map.shape, dtype=jnp.uint8),
            # 25
            urgency_layer,
        ])

        # Direction layers: agent 0 occupies 0-3, agent 1 occupies 4-7.
        agent_direction_layers = jnp.zeros((8, height, width), dtype=jnp.uint8)
        dir_layer_idx = state.agent_dir_idx + jnp.array([0, 4])
        agent_direction_layers = agent_direction_layers.at[dir_layer_idx, :, :].set(
            agent_pos_layers)

        # Both agents see their own layers first, then the other agent's.
        alice_obs = jnp.zeros((n_channels, height, width), dtype=jnp.uint8)
        alice_obs = alice_obs.at[0:2].set(agent_pos_layers)
        alice_obs = alice_obs.at[2:10].set(agent_direction_layers)
        alice_obs = alice_obs.at[10:].set(env_layers)

        bob_obs = jnp.zeros((n_channels, height, width), dtype=jnp.uint8)
        bob_obs = bob_obs.at[0].set(
            agent_pos_layers[1]).at[1].set(agent_pos_layers[0])
        bob_obs = bob_obs.at[2:6].set(agent_direction_layers[4:]).at[6:10].set(
            agent_direction_layers[0:4])
        bob_obs = bob_obs.at[10:].set(env_layers)

        # (channels, height, width) -> (height, width, channels), then flatten.
        alice_obs = jnp.transpose(alice_obs, (1, 2, 0))
        bob_obs = jnp.transpose(bob_obs, (1, 2, 0))
        return {"agent_0": alice_obs.flatten(), "agent_1": bob_obs.flatten()}

    def step_agents(
            self, key: chex.PRNGKey, state: EnvState, action: chex.Array
    ) -> Tuple[EnvState, float, float, float]:
        """Apply both agents' actions: movement, turning, interactions and cooking.

        Args:
            key: Not used.
            state: Current state.
            action: Length-2 array of ``Actions`` values (Alice first, then Bob).

        Returns:
            ``(new_state, reward, alice_shaped_reward, bob_shaped_reward)``, where
            ``reward`` is the sum of both agents' sparse (delivery) rewards.
        """

        # Candidate next cell: moving agents step along their action's direction
        # vector, non-moving agents target the cell they face. Clamped to the grid.
        is_move_action = jnp.logical_and(
            action != Actions.stay, action != Actions.interact)
        is_move_action_transposed = jnp.expand_dims(
            is_move_action, 0).transpose()  # Necessary to broadcast correctly

        fwd_pos = jnp.minimum(
            jnp.maximum(state.agent_pos + is_move_action_transposed * DIR_TO_VEC[jnp.minimum(action, 3)]
                        + ~is_move_action_transposed * state.agent_dir, 0),
            jnp.array((self.width - 1, self.height - 1), dtype=jnp.uint32)
        )

        # Can't go past wall or goal
        def _wall_or_goal(fwd_position, wall_map, goal_pos):
            fwd_wall = wall_map.at[fwd_position[1], fwd_position[0]].get()
            # NOTE(review): each row of goal_pos is treated as an (x, y) coordinate.
            # Level.goal_pos is an (h, w) 0/1 mask; confirm reset converts it,
            # otherwise this check does not actually test goal cells.
            def goal_collision(pos, goal): return jnp.logical_and(
                pos[0] == goal[0], pos[1] == goal[1])
            fwd_goal = jax.vmap(goal_collision, in_axes=(
                None, 0))(fwd_position, goal_pos)
            fwd_goal = jnp.any(fwd_goal)
            return fwd_wall, fwd_goal

        fwd_pos_has_wall, fwd_pos_has_goal = jax.vmap(_wall_or_goal, in_axes=(
            0, None, None))(fwd_pos, state.wall_map, state.goal_pos)

        fwd_pos_blocked = jnp.logical_or(
            fwd_pos_has_wall, fwd_pos_has_goal).reshape((self.num_agents, 1))

        # Agents that are blocked or did not choose a move action stay in place.
        bounced = jnp.logical_or(fwd_pos_blocked, ~is_move_action_transposed)

        # Agents can't overlap
        # Hardcoded for 2 agents (call them Alice and Bob)
        agent_pos_prev = jnp.array(state.agent_pos)
        fwd_pos = (bounced * state.agent_pos + (~bounced)
                   * fwd_pos).astype(jnp.uint32)
        collision = jnp.all(fwd_pos[0] == fwd_pos[1])

        # If both agents end up targeting the same cell, neither moves.
        # This matches original Overcooked env.
        alice_pos = jnp.where(
            collision,
            state.agent_pos[0],
            fwd_pos[0],
        )
        bob_pos = jnp.where(
            collision,
            state.agent_pos[1],
            fwd_pos[1],
        )

        # Prevent swapping places (i.e. passing through each other)
        swap_places = jnp.logical_and(
            jnp.all(fwd_pos[0] == state.agent_pos[1]),
            jnp.all(fwd_pos[1] == state.agent_pos[0]),
        )
        alice_pos = jnp.where(
            ~collision * swap_places,
            state.agent_pos[0],
            alice_pos
        )
        bob_pos = jnp.where(
            ~collision * swap_places,
            state.agent_pos[1],
            bob_pos
        )

        fwd_pos = fwd_pos.at[0].set(alice_pos)
        fwd_pos = fwd_pos.at[1].set(bob_pos)
        agent_pos = fwd_pos.astype(jnp.uint32)

        # Move actions also turn the agent to face the action's direction.
        agent_dir_idx = ~is_move_action * state.agent_dir_idx + is_move_action * action
        agent_dir = DIR_TO_VEC[agent_dir_idx]

        # Handle interacts. Agent 1 first, agent 2 second, no collision handling.
        # This matches the original Overcooked. Interacting agents neither move nor
        # turn, so the pre-step position and direction give the faced cell.
        fwd_pos = state.agent_pos + state.agent_dir
        maze_map = state.maze_map
        is_interact_action = (action == Actions.interact)

        # Compute the effect of interact first, then apply it if needed
        candidate_maze_map, alice_inv, alice_reward, alice_shaped_reward = self.process_interact(
            maze_map, state, fwd_pos[0], state.agent_inv[0], state.agent_inv[1])
        alice_interact = is_interact_action[0]
        bob_interact = is_interact_action[1]

        maze_map = jax.lax.select(alice_interact,
                                  candidate_maze_map,
                                  maze_map)
        alice_inv = jax.lax.select(alice_interact,
                                   alice_inv,
                                   state.agent_inv[0])
        alice_reward = jax.lax.select(alice_interact, alice_reward, 0.)
        alice_shaped_reward = jax.lax.select(
            alice_interact, alice_shaped_reward, 0.)

        candidate_maze_map, bob_inv, bob_reward, bob_shaped_reward = self.process_interact(
            maze_map, state, fwd_pos[1], state.agent_inv[1], state.agent_inv[0])
        maze_map = jax.lax.select(bob_interact,
                                  candidate_maze_map,
                                  maze_map)
        bob_inv = jax.lax.select(bob_interact,
                                 bob_inv,
                                 state.agent_inv[1])
        bob_reward = jax.lax.select(bob_interact, bob_reward, 0.)
        bob_shaped_reward = jax.lax.select(bob_interact, bob_shaped_reward, 0.)

        agent_inv = jnp.array([alice_inv, bob_inv])

        # Update agent component in maze_map
        def _get_agent_updates(agent_dir_idx, agent_pos, agent_pos_prev, agent_idx):
            agent = jnp.array([OBJECT_TO_INDEX['agent'], COLOR_TO_INDEX['red'] +
                              agent_idx*2, agent_dir_idx], dtype=jnp.uint8)
            agent_x_prev, agent_y_prev = agent_pos_prev
            agent_x, agent_y = agent_pos
            return agent_x, agent_y, agent_x_prev, agent_y_prev, agent

        vec_update = jax.vmap(_get_agent_updates, in_axes=(0, 0, 0, 0))
        agent_x, agent_y, agent_x_prev, agent_y_prev, agent_vec = vec_update(
            agent_dir_idx, agent_pos, agent_pos_prev, jnp.arange(self.num_agents))
        empty = jnp.array([OBJECT_TO_INDEX['empty'], 0, 0], dtype=jnp.uint8)

        # Padding added around the grid by the map maker (hard-coded to 4 cells).
        padding = 4

        # Clear the agents' previous cells, then write them at their new cells.
        maze_map = maze_map.at[padding + agent_y_prev,
                               padding + agent_x_prev, :].set(empty)
        maze_map = maze_map.at[padding + agent_y,
                               padding + agent_x, :].set(agent_vec)

        # Update pot cooking status: pots at or below POT_FULL_STATUS count down
        # by one per step until POT_READY_STATUS; other pots keep their status.
        def _cook_pots(maze_map, pot_pos):
            pot_pos_padded = jnp.zeros(
                (maze_map.shape[0], maze_map.shape[1]), dtype=jnp.uint8
            )
            pot_pos_padded = pot_pos_padded.at[
                padding:-padding, padding:-padding].set(pot_pos)
            is_cooking = jnp.array(
                maze_map[:, :, -1] * pot_pos_padded <= POT_FULL_STATUS, dtype=jnp.uint8) * pot_pos_padded
            not_done = jnp.array(
                maze_map[:, :, -1] * pot_pos_padded > POT_READY_STATUS, dtype=jnp.uint8) * pot_pos_padded
            pot_status_is_cooking_not_done = is_cooking * \
                not_done * (maze_map[:, :, -1] - 1) * pot_pos_padded
            pot_status_is_not_cooking = jnp.logical_not(
                is_cooking) * (maze_map[:, :, -1]) * pot_pos_padded  # defaults to zero if done pot_status
            pot_status = pot_status_is_cooking_not_done + pot_status_is_not_cooking

            pot_status_map = pot_pos_padded * pot_status + \
                jnp.logical_not(pot_pos_padded) * maze_map[:, :, -1]
            pot_status_map = jnp.concatenate(
                (jnp.zeros((*pot_status_map.shape, 2), dtype=jnp.uint8), pot_status_map[:, :, jnp.newaxis]), axis=-1)

            pot_pos_3 = jnp.concatenate(
                (jnp.zeros((pot_status_map.shape[0], pot_status_map.shape[1], 2), dtype=jnp.uint8), pot_pos_padded[:, :, jnp.newaxis]), axis=-1)

            # Only the last channel of pot cells is replaced.
            maze_map = maze_map * (1-pot_pos_3) + pot_status_map * pot_pos_3

            return maze_map

        maze_map = _cook_pots(maze_map, state.pot_pos)

        reward = alice_reward + bob_reward

        return (
            state.replace(
                agent_pos=agent_pos,
                agent_dir_idx=agent_dir_idx,
                agent_dir=agent_dir,
                agent_inv=agent_inv,
                maze_map=maze_map,
                terminal=False),
            reward,
            alice_shaped_reward,
            bob_shaped_reward,
        )

    def process_interact(
            self,
            maze_map: chex.Array,
            state: EnvState,
            fwd_pos: chex.Array,
            inventory: chex.Array,
            other_inventory: chex.Array):
        """Resolve one agent's interact action.

        The result depends on the object in the faced cell and on what the agent
        holds:

        * pot: a held onion is added while the pot has fewer than 3 onions; a held
          plate becomes a dish (plated soup) when the pot is ready;
        * counter (a wall cell that is not a pot): with empty hands, take a new item
          from a pile or pick up a loose item; otherwise drop the held item on an
          empty counter, or deliver a held dish at a goal.

        Args:
            maze_map: Padded maze map (4 cells of padding per side) to update.
            state: Current state; only ``wall_map`` and ``pot_pos`` are read.
            fwd_pos: ``(x, y)`` of the faced cell, in unpadded coordinates.
            inventory: OBJECT_TO_INDEX value held by the acting agent.
            other_inventory: OBJECT_TO_INDEX value held by the other agent; used
                only to decide whether a plate pickup earns shaped reward.

        Returns:
            ``(maze_map, inventory, reward, shaped_reward)`` after the interaction.
            ``reward`` is DELIVERY_REWARD for a delivery and 0 otherwise.
            ``shaped_reward`` is 3 for adding an onion to a pot, 5 for a useful
            plate pickup and 5 for plating a ready soup.
        """

        wall_map = state.wall_map
        padding = 4  # maze_map carries 4 cells of padding on every side

        # Get object in front of agent (on the "table")
        maze_object_on_table = maze_map.at[padding +
                                           fwd_pos[1], padding + fwd_pos[0]].get()
        object_on_table = maze_object_on_table[0]  # Simple index

        # Booleans depending on what the object is
        object_is_pile = jnp.logical_or(
            object_on_table == OBJECT_TO_INDEX["plate_pile"], object_on_table == OBJECT_TO_INDEX["onion_pile"])
        object_is_pot = jnp.array(object_on_table == OBJECT_TO_INDEX["pot"])
        object_is_goal = jnp.array(object_on_table == OBJECT_TO_INDEX["goal"])
        object_is_agent = jnp.array(
            object_on_table == OBJECT_TO_INDEX["agent"])
        object_is_pickable = jnp.logical_or(
            jnp.logical_or(
                object_on_table == OBJECT_TO_INDEX["plate"], object_on_table == OBJECT_TO_INDEX["onion"]),
            object_on_table == OBJECT_TO_INDEX["dish"]
        )
        # Whether the object in front is counter space that the agent can drop on.
        is_table = jnp.logical_and(
            wall_map.at[fwd_pos[1], fwd_pos[0]].get(), ~object_is_pot)

        table_is_empty = jnp.logical_or(
            object_on_table == OBJECT_TO_INDEX["wall"], object_on_table == OBJECT_TO_INDEX["empty"])

        # Pot status (used if the object is a pot)
        pot_status = maze_object_on_table[-1]

        # Get inventory object, and related booleans
        inv_is_empty = jnp.array(inventory == OBJECT_TO_INDEX["empty"])
        object_in_inv = inventory
        holding_onion = jnp.array(object_in_inv == OBJECT_TO_INDEX["onion"])
        holding_plate = jnp.array(object_in_inv == OBJECT_TO_INDEX["plate"])
        holding_dish = jnp.array(object_in_inv == OBJECT_TO_INDEX["dish"])

        # Interactions with pot. 3 cases: add onion if missing, collect soup if ready, do nothing otherwise
        case_1 = (pot_status > POT_FULL_STATUS) * holding_onion * object_is_pot
        case_2 = (pot_status == POT_READY_STATUS) * \
            holding_plate * object_is_pot
        case_3 = (pot_status > POT_READY_STATUS) * \
            (pot_status <= POT_FULL_STATUS) * object_is_pot
        else_case = ~case_1 * ~case_2 * ~case_3

        # Update pot status and object in inventory
        new_pot_status = \
            case_1 * (pot_status - 1) \
            + case_2 * POT_EMPTY_STATUS \
            + case_3 * pot_status \
            + else_case * pot_status
        new_object_in_inv = \
            case_1 * OBJECT_TO_INDEX["empty"] \
            + case_2 * OBJECT_TO_INDEX["dish"] \
            + case_3 * object_in_inv \
            + else_case * object_in_inv

        # Interactions with onion/plate piles and objects on counter
        # Pickup if: table, not empty, room in inv & object is not something unpickable (e.g. pot or goal)
        successful_pickup = is_table * ~table_is_empty * inv_is_empty * \
            jnp.logical_or(object_is_pile, object_is_pickable)
        successful_drop = is_table * table_is_empty * ~inv_is_empty
        successful_delivery = is_table * object_is_goal * holding_dish
        no_effect = jnp.logical_and(jnp.logical_and(
            ~successful_pickup, ~successful_drop), ~successful_delivery)

        # Update object on table (piles are inexhaustible; a picked-up loose item
        # leaves a bare counter)
        new_object_on_table = \
            no_effect * object_on_table \
            + successful_delivery * object_on_table \
            + successful_pickup * object_is_pile * object_on_table \
            + successful_pickup * object_is_pickable * OBJECT_TO_INDEX["wall"] \
            + successful_drop * object_in_inv

        # Update object in inventory
        new_object_in_inv = \
            no_effect * new_object_in_inv \
            + successful_delivery * OBJECT_TO_INDEX["empty"] \
            + successful_pickup * object_is_pickable * object_on_table \
            + successful_pickup * (object_on_table == OBJECT_TO_INDEX["plate_pile"]) * OBJECT_TO_INDEX["plate"] \
            + successful_pickup * (object_on_table == OBJECT_TO_INDEX["onion_pile"]) * OBJECT_TO_INDEX["onion"] \
            + successful_drop * OBJECT_TO_INDEX["empty"]

        # Apply inventory update
        inventory = new_object_in_inv

        # Apply changes to maze (a faced agent cell is left untouched)
        new_maze_object_on_table = \
            object_is_pot * OBJECT_INDEX_TO_VEC[new_object_on_table].at[-1].set(new_pot_status) \
            + ~object_is_pot * ~object_is_agent * OBJECT_INDEX_TO_VEC[new_object_on_table] \
            + object_is_agent * maze_object_on_table

        maze_map = maze_map.at[padding + fwd_pos[1],
                               padding + fwd_pos[0], :].set(new_maze_object_on_table)

        # Sparse reward for a soup delivery
        reward = jnp.array(successful_delivery, dtype=float)*DELIVERY_REWARD

        # Plate pickups are only rewarded when useful: no loose plate is lying on a
        # counter, and more pots have started (cooking or ready) than there are
        # empty plates already held by the other agent.
        no_plate_on_counter = (
            (maze_map[padding:-padding, padding:-padding, 0] * wall_map) == OBJECT_TO_INDEX["plate"]).sum() == 0
        pot_statuses = maze_map[padding:-padding, padding:-padding, -1]
        # Status <= POT_FULL_STATUS covers both cooking and ready pots.
        num_pots_started = (
            (pot_statuses <= POT_FULL_STATUS) * state.pot_pos).sum()
        # An empty plate in the other agent's hands is already earmarked for one of
        # those pots. (A dish the other agent holds came from a pot that has since
        # been emptied, so it must not be counted.)
        other_holds_plate = other_inventory == OBJECT_TO_INDEX["plate"]
        # Compare instead of subtracting so an unsigned count cannot wrap around.
        pot_left_over_for_plate = num_pots_started > other_holds_plate

        # Shaped rewards: adding an onion to a pot 3, picking up a useful plate 5,
        # plating a ready soup 5.
        shaped_reward_c1 = (new_object_in_inv == OBJECT_TO_INDEX["empty"]) * (
            object_in_inv == OBJECT_TO_INDEX["onion"]) * case_1 * 3.0
        shaped_reward_c2 = (new_object_in_inv == OBJECT_TO_INDEX["plate"]) * (object_on_table == OBJECT_TO_INDEX["plate_pile"]) * \
            successful_pickup * no_plate_on_counter * pot_left_over_for_plate * 5.0
        shaped_reward_c3 = (new_object_in_inv == OBJECT_TO_INDEX["dish"]) * (
            object_in_inv == OBJECT_TO_INDEX["plate"]) * case_2 * 5.0

        shaped_reward = shaped_reward_c1 + shaped_reward_c2 + shaped_reward_c3
        return maze_map, inventory, reward, shaped_reward

    def is_terminal(self, state: EnvState) -> bool:
        """Return whether the episode is over in ``state``.

        The episode ends once ``state.time`` has reached ``self.max_steps``,
        or when the state's own ``terminal`` flag is already set.
        """
        return (state.time >= self.max_steps) | state.terminal

    def get_eval_solved_rate_fn(self):
        """Build the predicate used to mark an evaluation episode as solved.

        The returned callable takes an episode-statistics mapping and returns
        ``ep_stats['return'] > 20``, i.e. more than one soup delivered.
        """
        solved_threshold = 20  # More than one soup delivered

        def _is_solved(ep_stats):
            return ep_stats['return'] > solved_threshold

        return _is_solved

    @property
    def name(self) -> str:
        """Environment name."""
        return "Overcooked"

    @property
    def num_actions(self) -> int:
        """Number of actions possible in environment."""
        return len(self.action_set)

    def action_space(self, agent_id="") -> spaces.Discrete:
        """Return the discrete action space shared by every agent.

        ``agent_id`` is accepted only for API compatibility: all agents get
        the same space, with ``len(self.action_set)`` actions encoded as
        ``uint8``.
        """
        n_actions = len(self.action_set)
        return spaces.Discrete(n_actions, dtype=jnp.uint8)

    def observation_space(self) -> spaces.Box:
        """Return a box space with bounds ``[0, 255]`` and shape ``self.obs_shape``."""
        low, high = 0, 255
        return spaces.Box(low, high, self.obs_shape)

    def max_episode_steps(self) -> int:
        """Return ``self.params.max_episode_steps``.

        NOTE(review): the ``EnvParams`` dataclass at the top of this file only
        declares ``max_steps`` (no ``max_episode_steps``). If ``self.params``
        is an ``EnvParams`` this lookup raises AttributeError — confirm what
        ``self.params`` holds.
        """
        params = self.params
        return params.max_episode_steps

    def reset_env_to_level(
        self,
        rng: chex.PRNGKey,
        level: Level,
        params: EnvParams,
    ) -> Tuple[Observation, EnvState]:
        """Reset the environment to a given ``level``.

        Args:
            rng: PRNG key forwarded to ``init_state_from_level`` (used there
                for the optional random swap of the two agents).
            level: Level to start from; ``init_state_from_level`` reads its
                agent positions/direction indices, wall map and the goal,
                pot, onion-pile and plate-pile masks.
            params: Environment parameters. Not used here; kept so the
                signature matches the other reset entry points.

        Returns:
            ``(obs, state)``: the observation produced by ``get_obs`` for the
            freshly built state, followed by that state.
        """
        state = self.init_state_from_level(rng, level)
        return self.get_obs(state), state

    def init_state_from_level(self, rng: chex.PRNGKey, level: Level) -> EnvState:
        """Build the initial ``EnvState`` for ``level``.

        When ``self.random_reset`` is truthy, a Bernoulli draw from ``rng``
        (JAX default p=0.5) decides whether the two agents' start positions
        and direction indices are swapped, i.e. flipped along axis 0.

        The rest of the state is initialised as follows:
        - both inventories start as ``OBJECT_TO_INDEX['empty']``;
        - every pot cell gets status 23;
        - the ``onion_pos`` / ``plate_pos`` / ``dish_pos`` masks passed to
          ``make_overcooked_map`` are all zero;
        - ``time`` starts at 0 and ``terminal`` at False.

        NOTE(review): 23 is presumably the "empty pot" status code — confirm
        against the POT_* constants. ``agent_view_size=5`` is hard-coded for
        the map padding, while ``state_space`` reads ``self.agent_view_size``.
        Confirm the two are meant to agree.
        """
        start_pos = level.agent_pos
        start_dir_idx = level.agent_dir_idx

        if self.random_reset:
            _, coin_key = jax.random.split(rng)
            do_swap = jax.random.bernoulli(coin_key)

            def swap_agents(arr):
                return jnp.flip(arr, axis=0)

            def keep_agents(arr):
                return arr

            start_pos = jax.lax.cond(do_swap, swap_agents, keep_agents, start_pos)
            start_dir_idx = jax.lax.cond(
                do_swap, swap_agents, keep_agents, start_dir_idx)

        plate_piles = level.plate_pile_pos
        onion_piles = level.onion_pile_pos
        empty = OBJECT_TO_INDEX['empty']

        maze_map = make_overcooked_map(
            level.wall_map,
            level.goal_pos,
            start_pos,
            start_dir_idx,
            plate_piles,
            onion_piles,
            level.pot_pos,
            jnp.ones_like(level.pot_pos, dtype=jnp.uint8) * 23,  # pot_status
            jnp.zeros_like(onion_piles, dtype=jnp.uint8),  # onion_pos
            jnp.zeros_like(plate_piles, dtype=jnp.uint8),  # plate_pos
            jnp.zeros_like(plate_piles, dtype=jnp.uint8),  # dish_pos
            pad_obs=True,
            num_agents=2,
            agent_view_size=5)

        return EnvState(
            agent_pos=start_pos,
            agent_dir=DIR_TO_VEC.at[start_dir_idx].get(),
            agent_dir_idx=start_dir_idx,
            agent_inv=jnp.array([empty, empty]),
            goal_pos=level.goal_pos,
            pot_pos=level.pot_pos,
            wall_map=level.wall_map,
            maze_map=maze_map,
            bowl_pile_pos=plate_piles,
            onion_pile_pos=onion_piles,
            time=0,
            terminal=False,
        )

    def get_env_metrics(self, state: EnvState) -> dict:
        """Return per-state metrics; currently only the wall-cell count."""
        return {'n_walls': state.wall_map.sum()}

    def state_space(self) -> spaces.Dict:
        """Describe the environment state as a ``spaces.Dict``.

        NOTE(review): these shapes look inherited from a single-agent maze
        env and should be confirmed:
        - ``agent_pos`` is declared as ``(2,)`` although states hold
          positions for two agents.
        - ``maze_map`` is declared as ``(w + view, h + view, 3)``, but the
          interact code indexes ``maze_map[padding + y, padding + x]``
          (rows first).
        - Several ``EnvState`` fields (e.g. ``agent_inv``, ``pot_pos``) are
          not listed.
        """
        h, w = self.height, self.width
        view = self.agent_view_size
        longest_side = max(w, h)
        return spaces.Dict({
            "agent_pos": spaces.Box(0, longest_side, (2,), dtype=jnp.uint32),
            "agent_dir": spaces.Discrete(4),
            "goal_pos": spaces.Box(0, longest_side, (2,), dtype=jnp.uint32),
            "maze_map": spaces.Box(0, 255, (w + view, h + view, 3), dtype=jnp.uint32),
            "time": spaces.Discrete(self.max_steps),
            "terminal": spaces.Discrete(2),
        })

    def max_steps(self) -> int:
        """Return the value found under the attribute name ``max_steps``.

        NOTE(review): this method reads an attribute with its own name.
        - If an instance attribute ``max_steps`` exists (e.g. set in
          ``__init__``, not visible here), it shadows this method. In that
          case ``env.max_steps`` is the attribute, and this method is only
          reachable via the class, e.g. ``type(env).max_steps(env)``.
        - Without such an attribute, this returns the bound method itself
          rather than an int.

        Confirm the intent.
        """
        limit = getattr(self, "max_steps")
        return limit

    def get_monitored_metrics(self):
        """Return the tuple of metric names this environment reports as monitored."""
        metric_names = (
            'reward',
            'shaped_reward',
            'shaped_reward_scaled_by_shaped_reward_coeff',
            'reward_p_shaped_reward_scaled',
        )
        return metric_names

    @property
    def default_params(self) -> EnvParams:
        """A fresh ``EnvParams`` with every field left at its declared default."""
        params = EnvParams()
        return params


###########################
# Generator
###########################
def make_level_generator(
    height: int, width: int, n_walls: int, heldout_set: Optional[Level], sample_n_walls: bool = True
) -> Callable[[chex.PRNGKey], Level]:
    """Build a sampler of random Overcooked layouts.

    Args:
        height (int): Number of grid rows, including the solid outer border.
        width (int): Number of grid columns, including the solid outer border.
        n_walls (int): Number of interior wall indices drawn (with
            replacement, so duplicates are possible) before the border and
            structural walls are added.
        heldout_set (Optional[Level]): Batched levels (leading axis = level
            index) that sampled levels should avoid. ``None`` disables the
            check.
        sample_n_walls (bool): If True, only a random prefix of the drawn
            wall indices is kept; the remaining entries collapse onto the
            first drawn index. Forced to False when ``n_walls == 0``.

    Returns:
        ``sample(key) -> Level``. Levels are not guaranteed to be playable.
        If the first draw exactly matches a held-out level, one redraw is
        made; the redraw is not re-checked.
    """

    if n_walls == 0:
        sample_n_walls = False

    check_held_out = heldout_set is not None

    # Flat cell indices. uint8 is kept whenever the grid fits (the original
    # dtype); larger grids would silently wrap around in uint8.
    pos_dtype = jnp.uint8 if height * width <= 256 else jnp.int32

    def check_match(level: Level) -> bool:
        """Return whether ``level`` equals some level in ``heldout_set`` on every field."""
        if heldout_set is None:
            return False

        def match_stack(stack, single):
            # Compare each stacked entry (flattened) against the flattened single value.
            stacked = stack.reshape(stack.shape[0], -1)
            single_flat = single.reshape(-1)
            return jnp.all(stacked == single_flat, axis=1)

        goal_matches = match_stack(heldout_set.goal_pos, level.goal_pos)
        pot_matches = match_stack(heldout_set.pot_pos, level.pot_pos)
        wall_matches = match_stack(heldout_set.wall_map, level.wall_map)
        plate_matches = match_stack(
            heldout_set.plate_pile_pos, level.plate_pile_pos)
        onion_matches = match_stack(
            heldout_set.onion_pile_pos, level.onion_pile_pos)

        # Optionally include agent positions and dirs too:
        agent_pos_matches = match_stack(heldout_set.agent_pos, level.agent_pos)
        agent_dir_matches = match_stack(
            heldout_set.agent_dir_idx, level.agent_dir_idx)

        all_match = jnp.stack([
            goal_matches,
            pot_matches,
            wall_matches,
            plate_matches,
            onion_matches,
            agent_pos_matches,
            agent_dir_matches
        ], axis=0)

        # any entry where all fields match
        return jnp.any(jnp.all(all_match, axis=0))

    def sample(key: chex.PRNGKey) -> Level:
        """Samples a random layout that might or might not be playable.
        """
        def generate_level(k):
            h, w = height, width

            all_pos = np.arange(h * w, dtype=pos_dtype)

            key, control_key, *subkeys = jax.random.split(k, 10)
            (
                walls_key, nwalls_key, goal_key,
                plate_pile_key, onion_pile_key,
                pot_key, agpos_key, agdir_key
            ) = subkeys

            # Interior wall cells, drawn with replacement.
            wall_idx = jax.random.choice(
                walls_key,
                all_pos,
                shape=(n_walls,),
            )

            if sample_n_walls:
                # randint's maxval is exclusive: sampled_n_walls is in
                # [0, n_walls - 1]. Masked-out entries reuse wall_idx[0].
                sampled_n_walls = jax.random.randint(
                    nwalls_key, (), minval=0, maxval=n_walls)
                sample_wall_mask = jnp.arange(n_walls) < sampled_n_walls
                dummy_wall_idx = wall_idx.at[0].get().repeat(n_walls)
                wall_idx = jax.lax.select(
                    sample_wall_mask,
                    wall_idx,
                    dummy_wall_idx
                )

            walls = jnp.zeros_like(all_pos, dtype=jnp.uint8)
            walls = walls.at[wall_idx].set(1)
            walls = walls.reshape(h, w)
            # Solid outer border.
            walls = walls.at[:, 0].set(1)
            walls = walls.at[0, :].set(1)
            walls = walls.at[:, -1].set(1)
            walls = walls.at[-1, :].set(1)

            # Each structure option returns
            # (walls, agent1_mask, agent2_mask, accessible_wall_mask):
            # - agent masks mark cells an agent may NOT start on;
            # - accessible_wall_mask marks wall cells that objects may be
            #   placed on.

            def insert_single_center_wall(walls, key):
                # Add one wall cell at the grid centre.
                walls = walls.at[h // 2, w // 2].set(1)
                walls_2d = walls.astype(jnp.bool_)  # Keep 2D form
                accessible_wall_mask_2d = walls_2d  # & adjacent_free
                return walls, walls, walls, accessible_wall_mask_2d.astype(jnp.uint8)

            def make_smaller(walls, key):
                # Wall off column 1 or column w-2, shrinking the playable area.
                col = jax.random.choice(key, jnp.array([1, w-2]), shape=())
                walls = walls.at[:, col].set(1)

                walls_2d = walls.astype(jnp.bool_)  # Keep 2D form
                accessible_wall_mask_2d = walls_2d  # & adjacent_free

                return walls, walls, walls, accessible_wall_mask_2d.astype(jnp.uint8)

            def insert_vertical_partition_wall(walls, key):
                # Split the grid at the middle column:
                # - agent 1 may only start right of it, agent 2 only left of it;
                # - the partition cells themselves are not offered for objects.
                col = w//2
                accessible_wall_mask = walls.copy()

                walls = walls.at[:, col].set(1)
                agent1_mask = walls.copy().at[:, :col].set(1)
                agent2_mask = walls.copy().at[:, col:].set(1)
                return walls, agent1_mask, agent2_mask, accessible_wall_mask.astype(jnp.uint8)

            def do_nothing(walls, key):
                walls_2d = walls.astype(jnp.bool_)  # Keep 2D form
                accessible_wall_mask_2d = walls_2d  # & adjacent_free
                return walls, walls, walls, accessible_wall_mask_2d.astype(jnp.uint8)

            key, _key, _key_t = jax.random.split(key, 3)
            # Traced choice of structure option.
            # NOTE(review): maxval is exclusive, so t is in {0, 1, 2} and
            # `do_nothing` (index 3) is never selected — confirm this is
            # intended.
            t = jax.random.randint(_key_t, (), 0, 3)
            walls, agent1_mask, agent2_mask, accessible_wall_mask = jax.lax.switch(
                t,
                [insert_single_center_wall, make_smaller,
                    insert_vertical_partition_wall, do_nothing],
                walls,
                _key
            )

            rotate_key, control_key = jax.random.split(control_key)
            n_turns = jax.random.randint(rotate_key, (), 0, 4)

            def rot0(arr): return arr
            def rot90(arr): return jnp.rot90(arr, k=1)
            def rot180(arr): return jnp.rot90(arr, k=2)
            def rot270(arr): return jnp.rot90(arr, k=3)

            if h == w:
                rotations = [rot0, rot90, rot180, rot270]
            else:
                # lax.switch needs every branch to return the same shape. A
                # quarter turn would transpose a non-square (h, w) grid, so
                # only shape-preserving turns are used here.
                rotations = [rot0, rot180, rot0, rot180]

            def rotate(arr):
                # Apply the same randomly chosen rotation to every layer.
                return jax.lax.switch(n_turns, rotations, arr)

            walls = rotate(walls).reshape(-1)
            agent1_mask = rotate(agent1_mask)
            agent2_mask = rotate(agent2_mask)
            accessible_wall_mask = rotate(accessible_wall_mask).reshape(
                -1).astype(jnp.uint8)

            # Objects placed so far (all of them sit on wall cells).
            occupied_obj_mask = jnp.zeros_like(walls, dtype=jnp.uint8)

            def add_1_or_2_items(key, all_pos, accessible_wall_mask, occupied_obj_mask):
                """Return a flat 0/1 mask with one or two new object cells.

                The first cell is drawn uniformly from accessible, unoccupied
                wall cells. With probability 0.5 a second, independently
                drawn cell is added; it may coincide with the first.
                """
                # occupied_obj_mask is only objects on tables so we can do:
                possible_positions = accessible_wall_mask - occupied_obj_mask
                obj_mask = jnp.zeros_like(all_pos, dtype=jnp.uint8)
                key, subkey1, subkey2, subkey3 = jax.random.split(key, 4)
                item_idx_1 = jax.random.choice(subkey1, all_pos, shape=(
                    1,), p=(possible_positions.astype(jnp.bool_)).astype(jnp.uint8))

                and_2 = jax.random.bernoulli(subkey2, 0.5)

                item_idx_2 = jax.random.choice(subkey3, all_pos, shape=(
                    1,), p=(possible_positions.astype(jnp.bool_)).astype(jnp.uint8))

                obj_mask = obj_mask.at[item_idx_1].set(1)

                update_2 = jnp.logical_or(
                    obj_mask.at[item_idx_2].get(), and_2.astype(jnp.uint8))
                obj_mask = obj_mask.at[item_idx_2].set(update_2)
                return obj_mask

            goal_pos = add_1_or_2_items(
                goal_key, all_pos, accessible_wall_mask, occupied_obj_mask)
            occupied_obj_mask = occupied_obj_mask + goal_pos

            plate_pile_pos = add_1_or_2_items(
                plate_pile_key, all_pos, accessible_wall_mask, occupied_obj_mask)
            occupied_obj_mask = occupied_obj_mask + plate_pile_pos

            onion_pile_pos = add_1_or_2_items(
                onion_pile_key, all_pos, accessible_wall_mask, occupied_obj_mask)
            occupied_obj_mask = occupied_obj_mask + onion_pile_pos

            pot_pos = add_1_or_2_items(
                pot_key, all_pos, accessible_wall_mask, occupied_obj_mask)
            occupied_obj_mask = occupied_obj_mask + pot_pos

            # Agents start on distinct cells that their masks leave free.
            agpos1_key, agpos2_key = jax.random.split(agpos_key)
            agent_idx1 = jax.random.choice(agpos1_key, all_pos, shape=(), replace=False, p=(
                ~agent1_mask.reshape(-1).astype(jnp.bool_)).astype(jnp.uint8))
            agent2_mask = agent2_mask.reshape(-1).at[agent_idx1].set(1)
            agent_idx2 = jax.random.choice(agpos2_key, all_pos, shape=(), replace=False, p=(
                ~agent2_mask.astype(jnp.bool_)).astype(jnp.uint8))

            agent_idx = jnp.array([agent_idx1, agent_idx2])
            # Flat index -> (x, y) = (column, row).
            agent_pos = jnp.array([agent_idx % w, agent_idx // w],
                                  dtype=jnp.uint32).transpose()
            agent_dir_idx = jax.random.choice(agdir_key, jnp.arange(
                len(DIR_TO_VEC), dtype=jnp.int32), shape=(2,))

            return Level(
                height=h,
                width=w,
                wall_map=walls.reshape(h, w).astype(jnp.bool_),
                empty_table_idx=(walls - occupied_obj_mask).reshape(h, w),
                agent_idx=agent_idx,
                agent_pos=agent_pos,
                goal_pos=goal_pos.reshape(h, w),
                plate_pile_pos=plate_pile_pos.reshape(h, w),
                onion_pile_pos=onion_pile_pos.reshape(h, w),
                pot_pos=pot_pos.reshape(h, w),
                agent_dir_idx=agent_dir_idx,
            )

        key, subkey1, subkey2 = jax.random.split(key, 3)
        level = generate_level(subkey1)

        if not check_held_out:
            # Nothing to compare against: avoid tracing a second
            # generate_level that the cond below could never select.
            return level

        return jax.lax.cond(
            check_match(level),
            generate_level,
            lambda _: level,
            subkey2,
        )
    return sample


if __name__ == '__main__':

    # Number of environments stepped in parallel; used for the batch size and
    # for the steps-per-second estimate.
    n_envs = 512

    kwargs = dict(
        # max_episode_steps=400,
        height=5,
        width=5,
        n_walls=10,
        agent_view_size=5,
        fix_to_single_layout="coord_ring_6_9"
    )
    env = Overcooked(**kwargs)
    params = env.default_params

    # heldout_set is a required argument; None disables the held-out check.
    sampler = make_level_generator(5, 5, 5, heldout_set=None)

    key = jax.random.PRNGKey(0)
    keys = jax.random.split(key, n_envs)
    levels = jax.vmap(sampler)(keys)

    reset_fn = env.reset_env_to_level
    step_fn = env.step

    # reset_env_to_level(rng, level, params): batch rng and level, share params.
    key, subkey = jax.random.split(key)
    reset_keys = jax.random.split(subkey, n_envs)
    obs, state = jax.vmap(reset_fn, in_axes=(0, 0, None))(
        reset_keys, levels, params)

    all_sps = []

    for _ in range(5):
        start = time.time()
        # vmap adds a leading batch axis; visualise the first environment only.
        first_obs = obs['agent_0'][0]
        jax.debug.print('obs:\n{a}', a=(first_obs[:, :, 0]
                        * 1 + first_obs[:, :, 1]*2 + first_obs[:, :, 11]*3).T)

        # Random actions, shared by both agents.
        key, subkey = jax.random.split(key)
        ac = jax.random.choice(subkey, jnp.array(
            [0, 1, 2, 3, 4, 5]), shape=(n_envs,))
        key, subkey = jax.random.split(key)
        step_subkey = jax.random.split(subkey, n_envs)

        obs, state, reward, done, info = jax.vmap(step_fn, in_axes=(0, 0, 0, None))(
            step_subkey,
            state,
            {'agent_0': ac, 'agent_1': ac},
            params,
        )
        jax.debug.print("reward r {r} {ir} {isr}", r=reward,
                        ir=info["sparse_reward"], isr=info["shaped_reward"])

        obs['agent_0'].block_until_ready()
        end = time.time()
        all_sps.append(n_envs / (end - start))

    print('mean sps:', np.mean(all_sps))
    print('std sps:', np.std(all_sps))
