import numpy as np
import jax.numpy as jnp
from scipy.ndimage import label

from flax.core.frozen_dict import FrozenDict
from jaxmarl.environments.overcooked.layouts import overcooked_layouts


# 5x5 layout modelled on "cramped_room". Cells are numbered row-major
# (index = row * width + col). The only walkable cells are 6-8 and 11-13;
# every other cell (including the whole of rows 3 and 4) is listed in wall_idx.
# The goal (18), plate pile (16), onion piles (5, 9) and pot (2) all sit on
# wall cells.
cramped_room_5_5 = dict(
    height=5,
    width=5,
    wall_idx=jnp.array([*range(0, 5), 5, 9, 10, 14, *range(15, 25)]),
    agent_idx=jnp.array([6, 8]),
    goal_idx=jnp.array([18]),
    plate_pile_idx=jnp.array([16]),
    onion_pile_idx=jnp.array([5, 9]),
    pot_idx=jnp.array([2]),
)

# Extend the imported layout registry with the local 5x5 layout. The imported
# name is rebound to a fresh dict copy, so the jaxmarl module's own registry
# object is left unmodified; later code in this file sees the extended copy.
overcooked_layouts = dict(overcooked_layouts)
overcooked_layouts["cramped_room_5_5"] = FrozenDict(cramped_room_5_5)


def get_augmented_layouts(layouts=None):
    """Augment each layout with its wall map and free-space connected components.

    For every layout, the wall indices are rasterised into a boolean
    ``(height, width)`` grid, and the non-wall cells are grouped into connected
    components with ``scipy.ndimage.label``. ``label`` uses its default
    structuring element, which in 2-D is 4-connectivity, so cells that touch
    only diagonally fall into different components. This information is
    precomputed so that agents can be placed in separate components at
    initialization.

    Args:
        layouts: Mapping from layout name to layout mapping. Each layout must
            provide ``"height"`` and ``"width"``. It may provide
            ``"wall_idx"``, given as flat row-major cell indices; a missing
            entry means the layout has no walls. Defaults to the module-level
            ``overcooked_layouts`` registry.

    Returns:
        A dict mapping each layout name to a ``FrozenDict``. Each value holds
        every original layout entry plus:
          * ``"wall_map"``: boolean ``(height, width)`` array, True on walls.
          * ``"free_space_map"``: integer ``(height, width)`` array of
            component labels (not a boolean mask). Walls are 0 and free
            cells are labelled 1..num_components.
          * ``"num_components"``: number of connected free-space components.
    """
    if layouts is None:
        layouts = overcooked_layouts

    augmented_layouts = {}
    for layout_name, layout in layouts.items():
        h = layout["height"]
        w = layout["width"]

        wall_mask = jnp.zeros(h * w, dtype=jnp.bool_)
        wall_idx = layout.get("wall_idx")
        # A missing "wall_idx" must mean "no walls". Indexing with ``.at[None]``
        # would add a new axis and mark every cell as a wall. The int cast
        # also lets an empty list act as a no-op instead of a float indexer.
        if wall_idx is not None:
            wall_idx = jnp.asarray(wall_idx, dtype=jnp.int32)
            wall_mask = wall_mask.at[wall_idx].set(True)
        wall_map = wall_mask.reshape(h, w)  # flat indices are row-major
        free_space_map = ~wall_map

        labelled_free_space, num_components = label(free_space_map)

        aug_layout = dict(layout)
        aug_layout["wall_map"] = wall_map
        aug_layout["free_space_map"] = jnp.array(labelled_free_space)
        aug_layout["num_components"] = num_components

        augmented_layouts[layout_name] = FrozenDict(aug_layout)
    return augmented_layouts


# Precompute wall maps and free-space connected components once, at import
# time, for every layout in the (locally extended) overcooked_layouts registry.
augmented_layouts = get_augmented_layouts()
