"""
Real-time manual control for KeyLockEnv across 5 fixed episodes (init_test configs).

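Episodes are played in order; when one terminates or is truncated the next one
starts automatically, and the program exits after the fifth. In every episode
the agent starts at (1,1) with dir=0: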
  1) no keys, doors closed
  2) has blue key
  3) blue door open (blue key consumed)
  4) has yellow key (blue door open)
  5) yellow door open (both keys consumed)

Keys:
  Arrow Up/Down/Left/Right : move (actions 0/1/2/3)
  Z : pickup key            (action 4)
  X : toggle/open door      (action 5)
  R : reset current episode to its initial config
  Q or ESC : quit

Other keys: the top-row digit keys 0-5 map directly to actions 0-5 (keypad digits are not bound); all other keys are ignored.

Requirements: pygame (pip install pygame)
"""

import sys
import pygame
from key_lock_env import KeyLockEnv
from key_lock_options import reset_env_to_state

# Initial world state for each of the five scripted episodes, written in the
# simplified "key on map" form (is the key still lying on the grid?) rather
# than "does the agent hold the key?".
# Field legend:
#   yk = yellow key on map    bk = blue key on map
#   yd = yellow door open     bd = blue door open
CONFIGS = [
    {"name": name, "yk": yk, "bk": bk, "yd": yd, "bd": bd}
    for name, yk, bk, yd, bd in (
        # Both keys on the map, both doors closed.
        ("0_No_keys", 1, 1, 0, 0),
        # Blue key picked up (off the map), yellow key still on the map, doors closed.
        ("1_Blue_key_picked", 1, 0, 0, 0),
        # Blue door opened (blue key consumed), yellow key still on the map.
        ("2_Blue_door_open_key_consumed", 1, 0, 0, 1),
        # Yellow key picked up after the blue door was opened; no keys on the map.
        ("3_Yellow_key_picked_after_blue_door", 0, 0, 0, 1),
        # Both doors open, both keys consumed.
        ("4_Yellow_door_open_both_keys_consumed", 0, 0, 1, 1),
    )
]


def apply_config(env: KeyLockEnv, cfg_idx: int):
    """Reset *env* to the initial state described by ``CONFIGS[cfg_idx]``.

    The agent is always placed at (1, 1) with direction 0; the door-open and
    key-on-map flags come from the selected config. A banner naming the
    episode and its observation is printed.

    Args:
        env: Environment to reset in place.
        cfg_idx: Zero-based index into ``CONFIGS``.

    Returns:
        The observation returned by ``env._get_obs()`` after the reset.

    Raises:
        IndexError: If ``cfg_idx`` is out of range for ``CONFIGS``.
    """
    cfg = CONFIGS[cfg_idx]
    reset_env_to_state(
        env,
        1, 1, 0,  # fixed start: x, y, dir
        yellow_door_open=cfg["yd"],
        blue_door_open=cfg["bd"],
        yellow_key_on_map=cfg["yk"],
        blue_key_on_map=cfg["bk"],
    )
    # NOTE(review): reads the observation via the env's private _get_obs();
    # presumably reset_env_to_state does not return one — confirm.
    obs = env._get_obs()
    # Derive the episode total from CONFIGS so the banner stays correct if
    # configs are added or removed.
    print(f"\n[Episode {cfg_idx + 1}/{len(CONFIGS)}] {cfg['name']} | obs={obs}")
    return obs


# Env action bound to each dedicated control key (see the module docstring).
_KEY_TO_ACTION = {
    pygame.K_UP: 0,
    pygame.K_DOWN: 1,
    pygame.K_LEFT: 2,
    pygame.K_RIGHT: 3,
    pygame.K_z: 4,  # pickup key
    pygame.K_x: 5,  # toggle/open door
}


def _key_to_action(key):
    """Return the env action (0-5) bound to pygame key code *key*, or None.

    Dedicated control keys are looked up in ``_KEY_TO_ACTION``; the top-row
    digit keys '0'-'5' map directly to actions 0-5. Any other key is unbound.
    """
    action = _KEY_TO_ACTION.get(key)
    if action is None and pygame.K_0 <= key <= pygame.K_5:
        # Top-row digit key codes are consecutive, so the offset from K_0
        # is the digit itself.
        action = key - pygame.K_0
    return action


def _draw_frame(screen, frame):
    """Blit an RGB *frame* (height x width x channels array) onto *screen* and flip."""
    # pygame surfaces are indexed (x, y), i.e. width-first, so swap the
    # frame's first two axes before building the surface.
    surf = pygame.surfarray.make_surface(frame.swapaxes(0, 1))
    screen.blit(surf, (0, 0))
    pygame.display.flip()


def _run_episodes(env):
    """Play every config in ``CONFIGS`` on *env* interactively until done or quit.

    Each bound key press steps the env and prints the transition. When an
    episode terminates or is truncated, the next config is loaded. The loop
    ends when the window is closed, Q/ESC is pressed, or the last config
    finishes. Returns early if the env cannot render a frame.
    """
    cfg_idx = 0
    apply_config(env, cfg_idx)
    print("Controls: Arrow keys move | Z=pickup | X=toggle | R=reset | Q/ESC=quit")

    # Size the window from the first rendered frame.
    frame = env.render()
    if frame is None:
        print("Render returned None. Cannot display.")
        return
    height, width, _ = frame.shape
    screen = pygame.display.set_mode((width, height))

    clock = pygame.time.Clock()
    running = True
    while running:
        for event in pygame.event.get():
            if event.type == pygame.QUIT:
                running = False
            elif event.type == pygame.KEYDOWN:
                key = event.key
                if key in (pygame.K_q, pygame.K_ESCAPE):
                    running = False
                elif key == pygame.K_r:
                    apply_config(env, cfg_idx)
                    frame = env.render()
                else:
                    action = _key_to_action(key)
                    if action is not None:
                        obs, reward, terminated, truncated, _info = env.step(action)
                        print(f"action={action}, obs={obs}, reward={reward:.3f}, term={terminated}, trunc={truncated}")
                        frame = env.render()

                        if terminated or truncated:
                            print("Episode ended. Moving to next config...")
                            cfg_idx += 1
                            if cfg_idx >= len(CONFIGS):
                                print(f"All {len(CONFIGS)} episodes finished. Exiting.")
                                running = False
                            else:
                                apply_config(env, cfg_idx)
                                frame = env.render()

            if not running:
                # Ignore events still queued behind a quit/finish so no further
                # actions are sent to an env that is finished or about to close.
                break

        if frame is not None:
            _draw_frame(screen, frame)

        clock.tick(30)  # limit to 30 FPS


def main():
    """Create the KeyLockEnv, run the manual-control session, and clean up.

    The env and pygame are always shut down, including when rendering fails
    up front or an exception (e.g. KeyboardInterrupt) escapes the loop.
    """
    pygame.init()
    env = None
    try:
        pygame.display.set_caption("KeyLockEnv - Manual Control")
        env = KeyLockEnv(
            size=15,
            agent_start_pos=(1, 1),
            agent_start_dir=0,
            yellow_key_pos=(12, 3),
            yellow_door_pos=(3, 8),
            blue_key_pos=(12, 12),
            blue_door_pos=(9, 3),
            goal_pos=(3, 12),
            render_mode="rgb_array",  # we'll blit with pygame
            highlight=False,
        )
        _run_episodes(env)
    finally:
        if env is not None:
            env.close()
        pygame.quit()


if __name__ == "__main__":
    # Treat Ctrl+C as a normal shutdown: release pygame and exit with
    # status 0 rather than letting the KeyboardInterrupt traceback surface.
    interrupted = False
    try:
        main()
    except KeyboardInterrupt:
        interrupted = True
    if interrupted:
        pygame.quit()
        sys.exit(0)

