"""Construct a deterministic state-transition table for MiniGrid layouts."""

import os
import numpy as np
from .bottleneck_env import SimpleEnv
from minigrid.core.world_object import Wall  # explicitly needed for type check


def build_state_transition_matrix(env, agent_start_pos=None):
    """
    Build the state-transition matrix for a MiniGrid-based environment.

    Each grid cell ``(x, y)`` is a state with flat index ``y * width + x``.
    Actions are 0 = up (y - 1), 1 = down (y + 1), 2 = left (x - 1) and
    3 = right (x + 1). A move that would leave the grid or enter a wall keeps
    the agent where it is. Walls are treated as absorbing states (all actions
    keep the agent in place).

    Args:
        env: Environment exposing ``width``, ``height`` and ``grid.get(x, y)``.
        agent_start_pos: Optional ``(x, y)`` cell that is checked to lie inside
            the grid and not on a wall. It does not affect the returned table.

    Returns:
        A tuple ``(transition_matrix, wall_mask)``:
        ``transition_matrix`` is an ``int`` array of shape
        ``(width * height, 4)`` whose entry ``[s, a]`` is the next state index;
        ``wall_mask`` is a ``bool`` array of shape ``(height, width)``
        (indexed ``[y, x]``) that is ``True`` on wall cells.

    Raises:
        ValueError: If ``agent_start_pos`` is outside the grid or on a wall.
    """
    width, height = env.width, env.height
    n_states = width * height

    # Validate the start position before any lookup: NumPy would silently wrap
    # negative coordinates around to the opposite edge and check the wrong cell.
    if agent_start_pos is not None:
        start_x, start_y = agent_start_pos
        if not (0 <= start_x < width and 0 <= start_y < height):
            raise ValueError(
                f"Start position {agent_start_pos} is outside the "
                f"{width}x{height} grid."
            )

    # Record wall locations, indexed [x, y] to match grid.get(x, y).
    wall_mask = np.zeros((width, height), dtype=bool)
    for y in range(height):
        for x in range(width):
            if isinstance(env.grid.get(x, y), Wall):
                wall_mask[x, y] = True

    # Check that the start position is not a wall
    if agent_start_pos is not None and wall_mask[start_x, start_y]:
        raise ValueError(f"Start position {agent_start_pos} is a wall—choose another one.")

    # (dx, dy) per action index, in order: up, down, left, right ("up" decreases y).
    moves = ((0, -1), (0, 1), (-1, 0), (1, 0))
    n_actions = len(moves)

    transition_matrix = np.zeros((n_states, n_actions), dtype=int)
    for state in range(n_states):
        x, y = state % width, state // width

        if wall_mask[x, y]:
            transition_matrix[state, :] = state  # wall cell: all actions self-loop
            continue

        for action, (dx, dy) in enumerate(moves):
            nx, ny = x + dx, y + dy
            inside = 0 <= nx < width and 0 <= ny < height
            # Blocked moves (off-grid or into a wall) leave the agent in place.
            if inside and not wall_mask[nx, ny]:
                transition_matrix[state, action] = ny * width + nx
            else:
                transition_matrix[state, action] = state

    # Return transition matrix and a wall mask transposed to (height, width)
    return transition_matrix, wall_mask.T


def main():
    """Build the transition table for ``SimpleEnv``, print a sample and save it.

    Writes ``state_transition_matrix_full.npy`` and ``wall_mask.npy`` into
    ``./state_transition_matrix`` (the directory is created if missing).
    """
    env = SimpleEnv(render_mode=None)
    env.reset()

    agent_start_pos = (1, 1)  # can be changed
    matrix, wall_mask = build_state_transition_matrix(env, agent_start_pos)

    print("Transition matrix shape:", matrix.shape)

    # Spot-check one row. The index is hard-coded, so guard it: on a layout with
    # fewer than 24 cells an unguarded lookup would raise before anything is saved.
    sample_state = 23
    if 0 <= sample_state < matrix.shape[0]:
        print(matrix[sample_state, :])

    # Save results
    save_dir = "./state_transition_matrix"
    os.makedirs(save_dir, exist_ok=True)

    np.save(os.path.join(save_dir, "state_transition_matrix_full.npy"), matrix)
    np.save(os.path.join(save_dir, "wall_mask.npy"), wall_mask)

    print(f"State-transition matrix saved to {save_dir}/")


# Run the build-and-save routine only when executed as a script, not on import.
if __name__ == "__main__":
    main()

