from heapq import heappop, heappush
from collections import deque

# ==== Actions (integers returned by policy) ====
TURN_LEFT, TURN_RIGHT, MOVE_FORWARD = 0, 1, 2
PICK_UP, DROP, TOGGLE = 3, 4, 5

# ==== Objects (values compared against channel 0 of obs) ====
WALL, DOOR, KEY, GOAL = 2, 4, 5, 8

# ==== Door state (MiniGrid) ====
DOOR_OPEN, DOOR_CLOSED, DOOR_LOCKED = 0, 1, 2

# ==== Directions ====
RIGHT, DOWN, LEFT, UP = 0, 1, 2, 3

# (dx, dy) step for each heading; DOWN increases y.
# Insertion order (RIGHT, DOWN, LEFT, UP) is the order in which the
# search helpers visit neighbors, so it must stay as is.
DIRECTION_OFFSETS = dict(zip(
    (RIGHT, DOWN, LEFT, UP),
    ((1, 0), (0, 1), (-1, 0), (0, -1)),
))

# ---- Persistent state (module-level, survives between policy() calls) ----
# Color of the key the policy believes it holds (None = empty-handed), and
# the last cell a key was dropped into (avoid dropping twice in same front cell).
carrying_key_color = last_drop_front = None
# Steps remaining during which PICK_UP is suppressed after a DROP.
drop_cooldown = 0

# Mission state (what color we are trying to unlock, and where)
desired_key_color = target_door = target_key = None

# Oscillation detection: sliding window over the most recent agent positions
_last_positions = deque(maxlen=4)

# ---- Reset hook (call after env.reset()) ----
def policy_reset():
    """Put all module-level policy state back to its initial values.

    Clears the position history used for oscillation detection, zeroes the
    drop cooldown, and forgets the carried key, the last drop cell and the
    current door/key mission.
    """
    global carrying_key_color, drop_cooldown, last_drop_front
    global desired_key_color, target_door, target_key, _last_positions
    # History first, then the scalar bookkeeping.
    _last_positions.clear()
    drop_cooldown = 0
    carrying_key_color = last_drop_front = None
    desired_key_color = target_door = target_key = None

# ========== POLICY ==========
def policy(obs, agent_pos, agent_dir):
    """
    Choose the next action for the agent.

    obs: N x N x 3 array (obj, color, state)
    agent_pos: (x, y)
    agent_dir: 0:RIGHT,1:DOWN,2:LEFT,3:UP
    returns: action int

    Decision order:
      1. Interact with the cell in front: toggle a closed door, toggle a
         locked door when the carried key matches, or pick up a key when
         empty-handed, off cooldown, and the key is wanted.
      2. Follow an open-cells-only path to the goal if one exists.
      3. Otherwise handle the nearest blocking door. For a locked door:
         drop a wrongly-colored key, fetch the matching key, then walk next
         to the door. For a closed door: walk next to it.
      4. Break an A-B-A oscillation, else move forward if possible, else
         turn right.

    Side effects: updates the module-level carrying/mission state and the
    position history.
    """
    global carrying_key_color, drop_cooldown
    global desired_key_color, target_door, target_key, _last_positions

    if drop_cooldown > 0:
        _dec_drop_cooldown()

    _last_positions.append(agent_pos)
    oscillating = _detect_oscillation()

    front = get_facing(agent_pos, agent_dir)
    goal = find_goal(obs)

    # 1) Immediate front interactions (mission-aware)
    if in_bounds(front, obs):
        fobj, fcol, fstate = tile(obs, front)

        # Doors
        if fobj == DOOR:
            if is_door_closed_unlocked(fstate):
                return TOGGLE
            if is_door_locked(fstate) and carrying_key_color == fcol:
                return TOGGLE

        # Keys: only pick up when empty-handed and not right after a drop.
        # Getting rid of a wrong key is handled in step 3, because a drop
        # needs a free front cell and therefore cannot happen while the
        # agent is facing a key.
        if fobj == KEY and carrying_key_color is None and drop_cooldown == 0:
            if desired_key_color is None or fcol == desired_key_color:
                # Optimistic bookkeeping: assume the pick-up succeeds.
                carrying_key_color = fcol
                return PICK_UP

    # 2) Try path to goal (open-only). If holding any key, avoid stepping on other keys.
    if goal:
        avoid_keys = carrying_key_color is not None
        path = a_star_open_only(obs, agent_pos, goal, avoid_keys=avoid_keys)
        if path:
            return step_to(path[0], agent_pos, agent_dir)

    # 3) Handle blocking doors (locked or closed), nearest first
    blocking = find_blocking_doors(obs, agent_pos, goal)  # (x,y,color,state)
    blocking.sort(key=lambda d: abs(d[0] - agent_pos[0]) + abs(d[1] - agent_pos[1]))

    for (dx, dy, dcol, dstate) in blocking:
        # Locked door: fetch/hold correct key, then approach door
        if is_door_locked(dstate):
            target_door = (dx, dy)
            desired_key_color = dcol

            # Need the right key
            if carrying_key_color != dcol:
                kpos = find_key_of_color(obs, dcol)
                target_key = kpos
                if kpos:
                    if carrying_key_color is not None:
                        # Holding the wrong key while the right one exists:
                        # keep running the drop routine every step until it
                        # actually issues DROP (it clears carrying_key_color).
                        act = drop_key_somewhere(obs, agent_pos, agent_dir, goal)
                        if act is not None:
                            return act
                    return _step_next_to(obs, agent_pos, agent_dir, kpos, True)
                return TURN_RIGHT

            # Have the right key -> go adjacent to door (avoid stepping on keys)
            return _step_next_to(obs, agent_pos, agent_dir, (dx, dy), True)

        # Closed (unlocked) door: go adjacent; step 1 toggles it once faced
        if is_door_closed_unlocked(dstate):
            return _step_next_to(obs, agent_pos, agent_dir, (dx, dy),
                                 carrying_key_color is not None)

    # 4) Oscillation break / mild exploration
    if oscillating:
        # _escape_step is not defined in this part of the file; it is
        # expected to return an action or None.
        side = _escape_step(obs, agent_pos, agent_dir)
        if side is not None:
            return side
        return TURN_RIGHT

    if in_bounds(front, obs):
        fo, fc, fs = tile(obs, front)
        if is_passable_open_only(fo, fs) or fo in (KEY, GOAL):
            return MOVE_FORWARD

    return TURN_RIGHT


def _step_next_to(obs, agent_pos, agent_dir, target, avoid_keys):
    """Return the action that walks toward the nearest open cell beside *target*.

    Falls back to TURN_RIGHT when no step can be planned: either no adjacent
    cell is reachable, or the planned path is empty. Repeated right turns
    eventually make the agent face *target* so that the front-interaction
    step of policy() can act on it.
    """
    adj = nearest_adjacent_open_only(obs, target, agent_pos, avoid_keys=avoid_keys)
    if adj:
        path = a_star_open_only(obs, agent_pos, adj, avoid_keys=avoid_keys)
        if path:
            return step_to(path[0], agent_pos, agent_dir)
    return TURN_RIGHT


# ================== HELPERS ==================

def get_facing(pos, dir_):
    """Return the (x, y) cell one step from *pos* in heading *dir_*."""
    step = DIRECTION_OFFSETS[dir_]
    return (pos[0] + step[0], pos[1] + step[1])

def in_bounds(pos, obs):
    """Return True when *pos* = (x, y) is a valid index into the grid *obs*.

    *obs* is indexed as obs[x, y, channel], so x is checked against
    obs.shape[0] and y against obs.shape[1]. Checking both axes keeps the
    result correct for rectangular grids; for square grids it is identical
    to comparing both coordinates with obs.shape[0].
    """
    width, height = obs.shape[0], obs.shape[1]
    return 0 <= pos[0] < width and 0 <= pos[1] < height

def tile(obs, pos):
    """Return the three channel values (obj, color, state) stored at *pos*."""
    x, y = pos[0], pos[1]
    return tuple(obs[x, y, channel] for channel in range(3))

def is_door_locked(state):
    """Return whether *state* equals DOOR_LOCKED."""
    return state == DOOR_LOCKED

def is_door_closed_unlocked(state):
    """Return whether *state* equals DOOR_CLOSED (closed but not locked)."""
    return state == DOOR_CLOSED

def is_open_door(state):
    """Return whether *state* equals DOOR_OPEN."""
    return state == DOOR_OPEN

def is_passable_open_only(obj, state):
    """Return whether a cell holding *obj* in *state* counts as walkable.

    Doors are walkable only when open, walls never are, and every other
    object code (GOAL, empty, key, etc.) is treated as walkable.
    """
    if obj == DOOR:
        return is_open_door(state)
    return not (obj == WALL)

def find_goal(obs):
    """Return the (x, y) of the first GOAL cell in *obs*, or None if absent.

    Cells are scanned with x as the outer loop and y as the inner loop. Both
    grid dimensions are read from obs.shape, so rectangular grids are scanned
    completely instead of assuming a square grid.
    """
    width, height = obs.shape[0], obs.shape[1]
    for x in range(width):
        for y in range(height):
            if obs[x, y, 0] == GOAL:
                return (x, y)
    return None

def manhattan(a, b):
    """Return the Manhattan (L1) distance between 2-D points *a* and *b*."""
    return sum(abs(a[axis] - b[axis]) for axis in (0, 1))

def step_to(next_pos, agent_pos, agent_dir):
    """Return the action that moves or turns the agent toward *next_pos*.

    When *next_pos* is the neighbor in heading d: MOVE_FORWARD if the agent
    already faces d, TURN_LEFT if d is one left turn away, TURN_RIGHT
    otherwise (including a half turn). If *next_pos* is not one of the four
    neighbors of *agent_pos*, TURN_LEFT is returned.
    """
    delta = (next_pos[0] - agent_pos[0], next_pos[1] - agent_pos[1])
    wanted = next(
        (heading for heading, offset in DIRECTION_OFFSETS.items() if delta == offset),
        None,
    )
    if wanted is None:
        return TURN_LEFT
    if agent_dir == wanted:
        return MOVE_FORWARD
    return TURN_LEFT if (agent_dir - wanted) % 4 == 1 else TURN_RIGHT

# ---------- A* (open doors only), with optional key-avoidance ----------
def a_star_open_only(obs, start, goal, avoid_keys=False):
    """A* search from *start* to *goal* over walkable cells.

    Walkability is decided by is_passable_open_only (so closed and locked
    doors block); with avoid_keys=True, key cells are skipped as well. The
    heuristic is Manhattan distance and every step costs 1.

    Returns the list of cells after *start* up to and including *goal*
    (empty when start == goal), or None when *goal* is None or unreachable.
    """
    if goal is None:
        return None

    def estimate(cell):
        return abs(cell[0] - goal[0]) + abs(cell[1] - goal[1])

    frontier = [(0, start)]
    parent = {}
    cost = {start: 0}
    while frontier:
        _, here = heappop(frontier)
        if here == goal:
            return _reconstruct_path(parent, here)
        step_cost = cost[here] + 1
        for ox, oy in DIRECTION_OFFSETS.values():
            there = (here[0] + ox, here[1] + oy)
            if not in_bounds(there, obs):
                continue
            obj, _color, state = tile(obs, there)
            if not is_passable_open_only(obj, state):
                continue
            if avoid_keys and obj == KEY:
                continue
            # Keep only strict improvements over the best known cost.
            if there in cost and cost[there] <= step_cost:
                continue
            cost[there] = step_cost
            parent[there] = here
            heappush(frontier, (step_cost + estimate(there), there))
    return None

def _reconstruct_path(came_from, cur):
    """Follow *came_from* links back from *cur* and return the forward path.

    The result lists every cell that has a predecessor entry, ordered from
    the first step to *cur*; the cell without a predecessor (the search
    start) is not included, so the list is empty when *cur* has no entry.
    """
    backwards = []
    node = cur
    while node in came_from:
        backwards.append(node)
        node = came_from[node]
    return backwards[::-1]

# ---------- Blocking doors (first frontier of closed/locked) ----------
def find_blocking_doors(obs, start, goal):
    """Collect the closed/locked doors bordering the region reachable from *start*.

    Runs a breadth-first flood over walkable cells (is_passable_open_only).
    Closed or locked doors touching that region are recorded once each, in
    discovery order, as (x, y, color, state) and are not crossed. If *goal*
    is reached during the flood, nothing blocks it and [] is returned.

    Door cells already recorded are tracked in a set, and bounds are checked
    against both grid dimensions.
    """
    width, height = obs.shape[0], obs.shape[1]
    queue = deque([start])
    seen = {start}
    recorded = set()   # (x, y) of doors already in `blockers`
    blockers = []
    while queue:
        x, y = queue.popleft()
        if goal and (x, y) == goal:
            return []
        for ox, oy in DIRECTION_OFFSETS.values():
            nx, ny = x + ox, y + oy
            if not (0 <= nx < width and 0 <= ny < height):
                continue
            o, c, s = obs[nx, ny]
            nxt = (nx, ny)
            if (o == DOOR) and (is_door_locked(s) or is_door_closed_unlocked(s)):
                # Doors are frontier cells: record once, never expand.
                if nxt not in recorded:
                    recorded.add(nxt)
                    blockers.append((nx, ny, c, s))
                continue
            if nxt in seen:
                continue
            if not is_passable_open_only(o, s):
                continue
            seen.add(nxt)
            queue.append(nxt)
    return blockers

# ---------- Adjacent passable tile near target ----------
def nearest_adjacent_open_only(obs, target_pos, from_pos, avoid_keys=False):
    """Return the walkable neighbor of *target_pos* closest to *from_pos*.

    Candidates are the four neighbors of *target_pos* that are in bounds and
    walkable (and not key cells when avoid_keys is True). Distance is the
    length of the a_star_open_only path from *from_pos*; ties keep the first
    candidate in DIRECTION_OFFSETS order. Returns None when no candidate is
    reachable.

    a_star_open_only returns an empty list (not None) when *from_pos* already
    is the candidate, so reachability is tested with `is not None`. Testing
    truthiness would discard the cell the agent is standing on and send it to
    a different neighbor of the target.
    """
    best = None
    best_len = None
    for ox, oy in DIRECTION_OFFSETS.values():
        cand = (target_pos[0] + ox, target_pos[1] + oy)
        if not in_bounds(cand, obs):
            continue
        o, _c, s = tile(obs, cand)
        if not is_passable_open_only(o, s) or (avoid_keys and o == KEY):
            continue
        path = a_star_open_only(obs, from_pos, cand, avoid_keys=avoid_keys)
        if path is None:
            continue  # unreachable
        if best_len is None or len(path) < best_len:
            best, best_len = cand, len(path)
    return best

def find_key_of_color(obs, color):
    """Return the (x, y) of the first KEY cell whose color channel equals *color*.

    Cells are scanned with x as the outer loop and y as the inner loop over
    both grid dimensions (so rectangular grids are covered). Returns None
    when no such key is on the grid.
    """
    width, height = obs.shape[0], obs.shape[1]
    for x in range(width):
        for y in range(height):
            if obs[x, y, 0] == KEY and obs[x, y, 1] == color:
                return (x, y)
    return None

# ---------- Important-cell aware DROP logic ----------
def drop_key_somewhere(obs, agent_pos, agent_dir, goal):
    """
    Return the next action toward dropping the carried key on a safe cell.

    Avoid dropping on important squares:
    - GOAL/adjacent, DOOR-adjacent, current shortest path to GOAL, chokepoints
    and never on the cell used for the previous drop.

    Order of preference:
    1. Front cell is safe -> DROP (clears carrying_key_color, starts the
       drop cooldown, remembers the cell in last_drop_front).
    2. Some other neighbor is safe -> turn toward it.
    3. Otherwise walk (avoiding keys) toward the nearest cell that has a safe
       neighbor; steps 1-2 finish the job once the agent stands there.
    Falls back to TURN_RIGHT when none of the above yields an action.
    """
    global carrying_key_color, last_drop_front

    imp = compute_important_cells(obs, agent_pos, goal)

    def droppable(cell):
        return (in_bounds(cell, obs)
                and is_safe_drop_target(obs, cell, imp)
                and cell != last_drop_front)

    # 1) + 2) Current heading first, then one, two and three right turns away.
    for turns in range(4):
        heading = (agent_dir + turns) % 4
        cell = get_facing(agent_pos, heading)
        if not droppable(cell):
            continue
        if turns == 0:
            carrying_key_color = None
            set_drop_cooldown()
            last_drop_front = cell
            return DROP
        # Three right turns away is a single left turn.
        if (agent_dir - heading) % 4 == 1:
            return TURN_LEFT
        return TURN_RIGHT

    # 3) No safe neighbor here, so any platform found is a different cell.
    #    Only a move toward it makes sense; rotating toward the platform's
    #    facing direction from the current cell would be meaningless.
    platform, _face_dir = nearest_drop_platform(obs, agent_pos, imp)
    if platform:
        path = a_star_open_only(obs, agent_pos, platform, avoid_keys=True)
        if path:
            return step_to(path[0], agent_pos, agent_dir)

    return TURN_RIGHT

def is_safe_drop_target(obs, pos, important_set):
    """Return True when a key may be dropped onto *pos*.

    A cell is rejected if it holds a WALL, DOOR, KEY or GOAL, if it is listed
    in *important_set*, or if is_chokepoint flags it. *pos* must be in bounds.
    """
    obj = tile(obs, pos)[0]
    blocked = obj in (WALL, DOOR, KEY, GOAL) or pos in important_set
    return not (blocked or is_chokepoint(obs, pos))

def compute_important_cells(obs, agent_pos, goal):
    """Return the set of cells that should stay free of dropped keys.

    The set contains: the goal and its in-bounds neighbors (when *goal* is
    given), the in-bounds neighbors of every DOOR cell, and every cell on
    the current open-only A* path from *agent_pos* to the goal (if any).

    The door scan covers both grid dimensions, so rectangular grids are
    handled as well as square ones.
    """
    important = set()
    width, height = obs.shape[0], obs.shape[1]

    def add_neighbors(x, y):
        for ox, oy in DIRECTION_OFFSETS.values():
            cell = (x + ox, y + oy)
            if in_bounds(cell, obs):
                important.add(cell)

    if goal:
        important.add(goal)
        add_neighbors(goal[0], goal[1])

    for x in range(width):
        for y in range(height):
            if obs[x, y, 0] == DOOR:
                add_neighbors(x, y)

    # a_star_open_only returns None for a missing or unreachable goal.
    path = a_star_open_only(obs, agent_pos, goal, avoid_keys=False)
    if path:
        important.update(path)

    return important

def is_chokepoint(obs, pos):
    """Return True when *pos* has at most two walkable in-bounds neighbors."""
    def walkable(cell):
        if not in_bounds(cell, obs):
            return False
        obj, _color, state = tile(obs, cell)
        return is_passable_open_only(obj, state)

    neighbors = ((pos[0] + ox, pos[1] + oy) for ox, oy in DIRECTION_OFFSETS.values())
    return sum(1 for cell in neighbors if walkable(cell)) <= 2

def nearest_drop_platform(obs, from_pos, important_set):
    """Find the closest standing cell that has a safe drop cell next to it.

    Breadth-first search from *from_pos*. For each visited cell the four
    neighbors are tested in DIRECTION_OFFSETS order; the first one that is in
    bounds, passes is_safe_drop_target and differs from last_drop_front makes
    the visited cell the platform.

    Returns (platform_cell, face_dir), where face_dir is the heading from the
    platform toward the drop cell, or (None, None) if nothing qualifies.

    The search does not expand through key cells. drop_key_somewhere plans
    its walk to the platform with avoid_keys=True, so a platform that can
    only be reached across a key cell would have no path to it.
    """
    queue = deque([from_pos])
    seen = {from_pos}
    while queue:
        stand = queue.popleft()
        for face_dir, (ox, oy) in DIRECTION_OFFSETS.items():
            f = (stand[0] + ox, stand[1] + oy)
            if in_bounds(f, obs) and is_safe_drop_target(obs, f, important_set) and f != last_drop_front:
                return stand, face_dir
        for ox, oy in DIRECTION_OFFSETS.values():
            nxt = (stand[0] + ox, stand[1] + oy)
            if not in_bounds(nxt, obs) or nxt in seen:
                continue
            o, _c, s = tile(obs, nxt)
            if o == KEY or not is_passable_open_only(o, s):
                continue
            seen.add(nxt)
            queue.append(nxt)
    return None, None

def set_drop_cooldown(steps=2):
    """Start the post-drop cooldown by setting drop_cooldown to *steps*.

    steps: value to store in the module-level counter; defaults to 2, the
        value that used to be hard-coded. Negative values are clamped to 0.
    """
    global drop_cooldown
    drop_cooldown = max(0, steps)

def _dec_drop_cooldown():
    """Lower the module-level drop_cooldown by one, never going below zero."""
    global drop_cooldown
    remaining = drop_cooldown - 1
    drop_cooldown = remaining if remaining > 0 else 0

# ---- Oscillation detection: A-B-A pattern ----
def _detect_oscillation():
    """Return True when the last three recorded positions form an A-B-A pattern.

    Needs at least three entries in _last_positions. Standing still (A-A-A)
    does not count, because the middle position must differ.
    """
    if len(_last_positions) < 3:
        return False
    oldest, middle, newest = list(_last_positions)[-3:]
    return oldest == newest and oldest != middle
