# import jax
# import jax.numpy as jnp
# from flax import linen as nn
# from Craftax.craftax.craftax_classic.model import CREATE_SBERT_MODEL
from functools import partial
from Craftax.craftax.craftax_classic.constants import *
# class ImageEncoder(nn.Module):
#     @nn.compact
#     def __call__(self, x):
#         # first conv layer
#         x = nn.Conv(features=32, kernel_size=(3, 3))(x)
#         x = nn.relu(x)
#         x = nn.max_pool(x, window_shape=(2, 2), strides=(2, 2))
#
#         # # second conv layer
#         # x = nn.Conv(features=64, kernel_size=(3, 3))(x)
#         # x = nn.relu(x)
#         # x = nn.max_pool(x, window_shape=(2, 2), strides=(2, 2))
#
#         # fully connected layer
#         x = x.reshape((x.shape[0], -1))
#         x = nn.Dense(features=128)(x)
#         x = nn.relu(x)
#
#         # output layer
#         features = nn.Dense(features=512)(x)
#
#         return features

@jax.jit
def state_info(state):
    """Extract an egocentric, array-valued summary of a game state.

    Args:
        state: environment state. Fields read here: ``map``,
            ``player_position``, ``zombies`` / ``cows`` / ``skeletons`` /
            ``arrows`` (each providing per-slot ``position`` and ``mask``),
            ``inventory``, ``player_health``, ``player_food``,
            ``player_drink``, ``player_energy`` and ``light_level``.

    Returns:
        A tuple ``(map_view, mob_map, inventory_values, status_values)``:

        * ``map_view``: an ``OBS_DIM``-shaped window of ``state.map``. The
          player sits at index ``OBS_DIM // 2``. Cells beyond the map edge
          read as ``BlockType.OUT_OF_BOUNDS.value``.
        * ``mob_map``: a ``uint8`` array of shape ``(*OBS_DIM, 4)``. Channel 0
          is zombies, 1 is cows, 2 is skeletons and 3 is arrows. A cell is 1
          when at least one slot of that kind with a truthy ``mask`` lies on
          it inside the window.
        * ``inventory_values``: 12 inventory counts, in this order: wood,
          stone, coal, iron, diamond, sapling, wood/stone/iron pickaxe,
          wood/stone/iron sword.
        * ``status_values``: health, food, drink, energy and
          ``light_level * 9``. The factor 9 presumably rescales light to the
          range of the other stats; TODO confirm.
    """
    obs_dim_array = jnp.array([OBS_DIM[0], OBS_DIM[1]], dtype=jnp.int32)
    half_obs = obs_dim_array // 2
    pad = MAX_OBS_DIM + 2

    # --- Map -------------------------------------------------------------
    # Pad every side with OUT_OF_BOUNDS so a window centred on a player near
    # the map edge shows out-of-bounds cells instead of being clamped by
    # dynamic_slice. This assumes OBS_DIM // 2 <= MAX_OBS_DIM + 2; TODO confirm.
    padded_grid = jnp.pad(
        state.map,
        (pad, pad),
        constant_values=BlockType.OUT_OF_BOUNDS.value,
    )
    # Top-left corner of the view, shifted by `pad` into padded coordinates.
    tl_corner = state.player_position - half_obs + pad
    map_view = jax.lax.dynamic_slice(padded_grid, tl_corner, OBS_DIM)

    # --- Items: no item layer is produced ---------------------------------

    # --- Mobs ------------------------------------------------------------
    def _mark_mobs(mob_map, mobs, channel):
        """Set mob_map[..., channel] to 1 wherever a live, on-screen mob sits."""
        # Per-slot positions relative to the view's top-left corner, shape (slots, 2).
        local_positions = mobs.position - state.player_position + half_obs
        on_screen = jnp.logical_and(
            local_positions >= 0, local_positions < obs_dim_array
        ).all(axis=-1)
        visible = jnp.logical_and(on_screen, mobs.mask).astype(jnp.uint8)
        # Combine with `.max` rather than `.set`. Off-screen or dead slots carry 0,
        # so they leave their target cell unchanged wherever their index lands.
        # Negative indices are wrapped by JAX, and out-of-range ones are dropped.
        # With `.set`, such slots could erase another mob's marker.
        return mob_map.at[
            local_positions[:, 0], local_positions[:, 1], channel
        ].max(visible, mode="drop")

    mob_map = jnp.zeros((*OBS_DIM, 4), dtype=jnp.uint8)  # 4 types of mobs
    mob_groups = (state.zombies, state.cows, state.skeletons, state.arrows)
    for channel, mobs in enumerate(mob_groups):
        mob_map = _mark_mobs(mob_map, mobs, channel)

    # --- Inventory & status ------------------------------------------------
    inventory_values = jnp.array([
        state.inventory.wood,
        state.inventory.stone,
        state.inventory.coal,
        state.inventory.iron,
        state.inventory.diamond,
        state.inventory.sapling,
        state.inventory.wood_pickaxe,
        state.inventory.stone_pickaxe,
        state.inventory.iron_pickaxe,
        state.inventory.wood_sword,
        state.inventory.stone_sword,
        state.inventory.iron_sword,
    ])
    status_values = jnp.array([
        state.player_health,
        state.player_food,
        state.player_drink,
        state.player_energy,
        state.light_level * 9,
    ])
    return (map_view, mob_map, inventory_values, status_values)


