"""Construct a deterministic state-transition table for KeyLockEnv."""

import os
import collections
import numpy as np
from key_lock_env import KeyLockEnv
from minigrid.core.world_object import Wall, Door, Key


def encode_state(
    x: int, y: int, direction: int,
    yellow_door_open: int, blue_door_open: int,
    yellow_key_on_map: int, blue_key_on_map: int,
    size: int,
) -> int:
    """Pack one KeyLockEnv state into a single mixed-radix integer.

    Digits, most significant first: ``x``, then ``y`` (radix ``size``),
    ``direction`` (radix 4), then the four 0/1 flags ``yellow_door_open``,
    ``blue_door_open``, ``yellow_key_on_map`` and ``blue_key_on_map``
    (radix 2 each). ``decode_state`` performs the inverse unpacking.
    """
    trailing_digits = (
        (y, size),
        (direction, 4),
        (yellow_door_open, 2),
        (blue_door_open, 2),
        (yellow_key_on_map, 2),
        (blue_key_on_map, 2),
    )
    code = x
    for digit, radix in trailing_digits:
        code = code * radix + digit
    return int(code)


def decode_state(state_index: int, size: int):
    """Unpack an index built by ``encode_state`` into its seven components.

    Peels digits off least-significant first (four radix-2 flags, the radix-4
    direction, then ``y`` in radix ``size``); whatever remains is ``x``.
    No range check is done, so an index >= ``size * size * 64`` yields
    ``x >= size``.

    Returns:
        ``(x, y, direction, yellow_door_open, blue_door_open,
        yellow_key_on_map, blue_key_on_map)``.
    """
    remaining, blue_key_on_map = divmod(state_index, 2)
    remaining, yellow_key_on_map = divmod(remaining, 2)
    remaining, blue_door_open = divmod(remaining, 2)
    remaining, yellow_door_open = divmod(remaining, 2)
    remaining, direction = divmod(remaining, 4)
    x, y = divmod(remaining, size)

    return (
        x, y, direction,
        yellow_door_open, blue_door_open,
        yellow_key_on_map, blue_key_on_map,
    )


def _wall_positions(env, size: int) -> set:
    """Return the set of ``(x, y)`` cells in ``env.grid`` that hold a ``Wall``."""
    # isinstance(None, Wall) is False, so empty cells need no separate test.
    return {
        (x, y)
        for x in range(size)
        for y in range(size)
        if isinstance(env.grid.get(x, y), Wall)
    }


def _reachable_states(T_next: np.ndarray, initial_state: int, valid_states: set) -> set:
    """Return every state reachable from ``initial_state`` by following ``T_next``.

    Plain breadth-first search: each row of ``T_next`` holds one successor per
    action, and successors outside ``valid_states`` are never entered.
    Done flags are not consulted, so transitions out of goal states are
    followed too; the result is therefore closed under ``T_next`` whenever
    ``initial_state`` itself is valid.
    """
    reachable = {initial_state}
    queue = collections.deque([initial_state])
    while queue:
        s = queue.popleft()
        for sn in T_next[s]:
            sn = int(sn)
            if sn not in reachable and sn in valid_states:
                reachable.add(sn)
                queue.append(sn)
    return reachable


def build_keylock_transition_matrix(
    env: KeyLockEnv, size: int,
    yellow_key_pos: tuple, yellow_door_pos: tuple,
    blue_key_pos: tuple, blue_door_pos: tuple, goal_pos: tuple,
    agent_start_pos: tuple = (1, 1), agent_start_dir: int = 0,
):
    """Enumerate every encoded state and tabulate its deterministic successor.

    The state space has ``N = size * size * 4 * 2**4`` indices laid out by
    ``encode_state``. Wall cells are read once from ``env.grid``; door, key
    and goal cells come from the arguments.

    Actions (6):
      * 0-3: face north / south / west / east respectively, then step one
        cell unless the target is off-grid, a wall, a closed door, or a key
        that is still on the map. The heading changes even when blocked.
      * 4: pick up the key in the faced cell (it leaves the map) if the
        agent is not already carrying one.
      * 5: toggle the door in the faced cell between open and closed,
        provided that door's key is off the map.
    A transition is flagged done when the agent's next cell is ``goal_pos``.

    Args:
        env: environment whose ``grid`` supplies the wall layout; presumably
            already ``reset()`` so the grid exists (``main`` does this).
        size: grid side length.
        yellow_key_pos, yellow_door_pos, blue_key_pos, blue_door_pos,
        goal_pos: ``(x, y)`` cells; any length-2 sequence is accepted.
        agent_start_pos: start cell for the reachability search.
        agent_start_dir: start heading (0=east, 1=south, 2=west, 3=north).
            The defaults reproduce the previous hard-coded start.

    Returns:
        ``(T_next, T_done, reachable_states)``: two ``(N, 6)`` int32 arrays of
        successor indices and done flags, plus the set of state indices
        reachable from the start (doors closed, both keys on the map).
        States whose cell is a wall self-loop with done = 0.

    Raises:
        ValueError: if the start cell/heading is out of range or on a wall.
    """
    # Normalize to tuples: the cell comparisons below use (x, y) tuples, and a
    # list such as [3, 8] never compares equal to (3, 8).
    yellow_key_pos = tuple(yellow_key_pos)
    yellow_door_pos = tuple(yellow_door_pos)
    blue_key_pos = tuple(blue_key_pos)
    blue_door_pos = tuple(blue_door_pos)
    goal_pos = tuple(goal_pos)

    num_actions = 6
    N = size * size * 4 * 2 * 2 * 2 * 2

    T_next = np.zeros((N, num_actions), dtype=np.int32)
    T_done = np.zeros((N, num_actions), dtype=np.int32)
    valid_states = set()

    print(f"[Transition Matrix] Building matrix for {N} states (size={size})...")

    wall_positions = _wall_positions(env, size)

    orient_dxdy = {
        0: (1, 0),   # east
        1: (0, 1),   # south
        2: (-1, 0),  # west
        3: (0, -1),  # north
    }
    # Movement actions use absolute headings: 0->north, 1->south, 2->west, 3->east.
    move_action_to_dir = {0: 3, 1: 1, 2: 2, 3: 0}

    for s in range(N):
        (
            x, y, direction,
            yellow_door_open, blue_door_open,
            yellow_key_on_map, blue_key_on_map,
        ) = decode_state(s, size)

        if (x, y) in wall_positions:
            # The agent never stands in a wall: make the state an inert self-loop.
            T_next[s, :] = s
            T_done[s, :] = 0
            continue

        valid_states.add(s)

        # Cell the agent currently faces; target of pickup (4) and toggle (5).
        fdx, fdy = orient_dxdy[direction]
        front = (x + fdx, y + fdy)

        # What the agent holds is not part of the state, so it is inferred:
        # the agent counts as carrying a key while some key is off the map and
        # its matching door is still closed. NOTE(review): this assumes a key
        # leaves the agent's hand once its door is opened; confirm against
        # KeyLockEnv.
        agent_carrying = (
            (yellow_key_on_map == 0 and yellow_door_open == 0)
            or (blue_key_on_map == 0 and blue_door_open == 0)
        )

        for a in range(num_actions):
            next_x, next_y, next_dir = x, y, direction
            next_yellow_door_open = yellow_door_open
            next_blue_door_open = blue_door_open
            next_yellow_key_on_map = yellow_key_on_map
            next_blue_key_on_map = blue_key_on_map

            if a < 4:
                next_dir = move_action_to_dir[a]
                dx, dy = orient_dxdy[next_dir]
                nx, ny = x + dx, y + dy
                target = (nx, ny)
                blocked = (
                    not (0 <= nx < size and 0 <= ny < size)
                    or target in wall_positions
                    or (target == blue_door_pos and blue_door_open == 0)
                    or (target == yellow_door_pos and yellow_door_open == 0)
                    or (target == blue_key_pos and blue_key_on_map == 1)
                    or (target == yellow_key_pos and yellow_key_on_map == 1)
                )
                if not blocked:
                    next_x, next_y = nx, ny

            elif a == 4:
                # Pickup: remove the faced key from the map if hands are empty.
                if front == yellow_key_pos and yellow_key_on_map == 1:
                    if not agent_carrying:
                        next_yellow_key_on_map = 0
                elif front == blue_key_pos and blue_key_on_map == 1:
                    if not agent_carrying:
                        next_blue_key_on_map = 0

            else:
                # Toggle: flip the faced door open <-> closed once its key is
                # off the map.
                if front == blue_door_pos and blue_key_on_map == 0:
                    next_blue_door_open = 1 - blue_door_open
                if front == yellow_door_pos and yellow_key_on_map == 0:
                    next_yellow_door_open = 1 - yellow_door_open

            T_next[s, a] = encode_state(
                next_x, next_y, next_dir,
                next_yellow_door_open, next_blue_door_open,
                next_yellow_key_on_map, next_blue_key_on_map,
                size,
            )
            T_done[s, a] = 1 if (next_x, next_y) == goal_pos else 0

    print(f"[Transition Matrix] Built matrix with {len(valid_states)} valid states")

    start_x, start_y = agent_start_pos
    if (
        not (0 <= start_x < size and 0 <= start_y < size and 0 <= agent_start_dir < 4)
        or (start_x, start_y) in wall_positions
    ):
        raise ValueError(
            f"agent start {tuple(agent_start_pos)} (dir {agent_start_dir}) is "
            f"outside the {size}x{size} grid or on a wall"
        )
    # Start state: both doors closed, both keys still on the map.
    initial_state = encode_state(start_x, start_y, agent_start_dir, 0, 0, 1, 1, size)
    print(f"[Transition Matrix] Performing reachability analysis from initial state {initial_state}...")
    reachable_states = _reachable_states(T_next, initial_state, valid_states)
    print(f"[Transition Matrix] Found {len(reachable_states)} reachable states out of {len(valid_states)} valid states")
    # valid_states is non-empty here: it contains initial_state (checked above).
    print(f"[Transition Matrix] Reachability ratio: {len(reachable_states) / len(valid_states) * 100:.2f}%")
    return T_next, T_done, reachable_states


def main():
    """Parse CLI options, build the KeyLockEnv transition table and save it.

    Writes ``T_next.npy``, ``T_done.npy`` and ``valid_states.npy`` into
    ``--save_dir``. ``valid_states.npy`` holds the state set returned by
    ``build_keylock_transition_matrix`` (the states reachable from the start
    state), sorted ascending so the file is deterministic and can be searched
    with ``np.searchsorted``.
    """
    import argparse

    pa = argparse.ArgumentParser()
    pa.add_argument("--size", type=int, default=15, help="grid size")
    pa.add_argument("--yellow_key_pos", type=int, nargs=2, default=[12, 3], help="yellow key position (x, y)")
    pa.add_argument("--yellow_door_pos", type=int, nargs=2, default=[3, 8], help="yellow door position (x, y)")
    pa.add_argument("--blue_key_pos", type=int, nargs=2, default=[12, 12], help="blue key position (x, y)")
    pa.add_argument("--blue_door_pos", type=int, nargs=2, default=[9, 3], help="blue door position (x, y)")
    pa.add_argument("--goal_pos", type=int, nargs=2, default=[3, 12], help="goal position (x, y)")
    pa.add_argument("--save_dir", default="keylock_transition_matrix")
    args = pa.parse_args()

    # argparse yields lists for nargs=2; convert once and reuse everywhere.
    yellow_key_pos = tuple(args.yellow_key_pos)
    yellow_door_pos = tuple(args.yellow_door_pos)
    blue_key_pos = tuple(args.blue_key_pos)
    blue_door_pos = tuple(args.blue_door_pos)
    goal_pos = tuple(args.goal_pos)

    # Create environment to get grid layout
    env = KeyLockEnv(
        size=args.size,
        agent_start_pos=(1, 1),
        agent_start_dir=0,
        yellow_key_pos=yellow_key_pos,
        yellow_door_pos=yellow_door_pos,
        blue_key_pos=blue_key_pos,
        blue_door_pos=blue_door_pos,
        goal_pos=goal_pos,
        render_mode=None,
    )
    env.reset()

    T_next, T_done, reachable_states = build_keylock_transition_matrix(
        env, args.size,
        yellow_key_pos,
        yellow_door_pos,
        blue_key_pos,
        blue_door_pos,
        goal_pos,
    )

    os.makedirs(args.save_dir, exist_ok=True)
    np.save(os.path.join(args.save_dir, "T_next.npy"), T_next)
    np.save(os.path.join(args.save_dir, "T_done.npy"), T_done)
    # Sort: iterating a set gives no guaranteed order, and downstream lookups
    # are simpler on a sorted index list.
    np.save(
        os.path.join(args.save_dir, "valid_states.npy"),
        np.array(sorted(reachable_states), dtype=np.int32),
    )

    print(f"\n[Save] Transition matrix saved to {args.save_dir}/")
    print(f"  T_next shape: {T_next.shape}")
    print(f"  T_done shape: {T_done.shape}")
    print(f"  Valid (reachable) states: {len(reachable_states)}")


if __name__ == "__main__":
    main()
