from minigrid.core.world_object import Goal, Door, Key, Wall
from minigrid.core.grid import Grid
from minigrid.minigrid_env import MiniGridEnv
from minigrid.core.mission import MissionSpace
import numpy as np, random
from collections import deque

# Door/key colours. Each door gets a distinct colour, so the length of this
# list is the upper bound on ``num_doors``.
VALID_COLORS = [
    "red",
    "green",
    "blue",
    "purple",
    "yellow",
    "grey",
]

# -------------------- BFS utils --------------------

def _passable(env, x, y, keys):
    """Return True if cell (x, y) can be entered while holding ``keys``.

    Empty cells and any object other than a wall or a locked door are
    passable. A locked door is passable only when its colour is in ``keys``.
    """
    obj = env.grid.get(x, y)
    if obj is None:
        return True
    if isinstance(obj, Wall):
        return False
    blocked_by_lock = isinstance(obj, Door) and obj.is_locked and obj.color not in keys
    return not blocked_by_lock

def flood_reachable(env, start, keys):
    """Return the set of cells reachable from ``start`` via 4-neighbour moves.

    Movement is restricted to in-bounds cells accepted by ``_passable`` with
    the given ``keys`` (walls and unmatched locked doors block). ``start`` is
    always part of the result and is not itself checked.
    """
    seen = {start}
    frontier = deque([start])
    while frontier:
        cx, cy = frontier.popleft()
        for nxt in ((cx + 1, cy), (cx - 1, cy), (cx, cy + 1), (cx, cy - 1)):
            nx, ny = nxt
            if nxt in seen or not (0 <= nx < env.width and 0 <= ny < env.height):
                continue
            if _passable(env, nx, ny, keys):
                seen.add(nxt)
                frontier.append(nxt)
    return seen

def shortest_path(env, start, goal, keys):
    """Return a BFS shortest path from ``start`` to ``goal``, or None.

    The path is a list of (x, y) tuples that includes both endpoints.
    Passability of every cell after ``start`` follows ``_passable`` with the
    given ``keys``.
    """
    parent = {start: None}
    frontier = deque([start])
    while frontier:
        node = frontier.popleft()
        if node == goal:
            path = []
            while node is not None:
                path.append(node)
                node = parent[node]
            path.reverse()
            return path
        x, y = node
        for step_x, step_y in ((1, 0), (-1, 0), (0, 1), (0, -1)):
            nxt = (x + step_x, y + step_y)
            if nxt in parent:
                continue
            in_bounds = 0 <= nxt[0] < env.width and 0 <= nxt[1] < env.height
            if in_bounds and _passable(env, nxt[0], nxt[1], keys):
                parent[nxt] = node
                frontier.append(nxt)
    return None

def bfs_ignore_doors(env, start, goal):
    """Return a BFS shortest path from ``start`` to ``goal`` that ignores locks.

    Only walls block movement. Doors (locked or not) and every other object
    are treated as passable. Returns a list of (x, y) tuples including both
    endpoints, or None if ``goal`` cannot be reached.
    """
    came_from = {start: None}
    frontier = deque([start])
    while frontier:
        node = frontier.popleft()
        if node == goal:
            route = []
            while node is not None:
                route.append(node)
                node = came_from[node]
            return route[::-1]
        x, y = node
        for nxt in ((x + 1, y), (x - 1, y), (x, y + 1), (x, y - 1)):
            nx, ny = nxt
            if nxt in came_from or not (0 <= nx < env.width and 0 <= ny < env.height):
                continue
            obj = env.grid.get(nx, ny)
            if obj is None or not isinstance(obj, Wall):
                came_from[nxt] = node
                frontier.append(nxt)
    return None

# -------------------- small helpers --------------------

def find_random_empty(env, exclude=None, rng=None):
    """Pick a uniformly random empty interior cell of ``env.grid``.

    Args:
        env: object exposing ``width``, ``height`` and ``grid.get(x, y)``.
        exclude: optional iterable of (x, y) cells that must not be chosen.
        rng: optional ``random.Random``-like object providing ``choice``.
            Defaults to the module-level ``random`` (global, unseeded state);
            pass a seeded RNG to make the pick reproducible.

    Returns:
        An (x, y) tuple from the interior (border excluded), or None when no
        eligible cell exists.
    """
    excluded = set(exclude or [])
    choices = [(x, y)
               for x in range(1, env.width - 1)
               for y in range(1, env.height - 1)
               if env.grid.get(x, y) is None and (x, y) not in excluded]
    if not choices:
        return None
    return (random if rng is None else rng).choice(choices)

def manhattan(a, b):
    """Return the L1 (taxicab) distance between 2-D points ``a`` and ``b``."""
    dx = a[0] - b[0]
    dy = a[1] - b[1]
    return abs(dx) + abs(dy)

# -------------------- Environment --------------------

class CustomEnv(MiniGridEnv):
    """
    Relaxed chokepoints + robust key placement.

    The agent must reach the green goal through a sequence of locked,
    differently coloured doors placed along a scaffold path. The key for each
    door is placed somewhere the agent can reach before that door.

    Knobs:
      - num_doors (int): number of locked doors on the main path
        (0..len(VALID_COLORS); each door gets a distinct colour)
      - fraction_hard (float): fraction of doors as true chokepoints
        (rounded; at least one door is hard whenever num_doors >= 1)
      - soft_wing_len (int): soft door wing length (0 = no wings)
      - complexity (float): random obstacle density [0..1]

    Generation guarantees:
      - Each key i is placed on the AGENT SIDE of door i, i.e. reachable using
        only the keys of doors 0..i-1.
      - We require: path(agent -> key_i) and path(key_i -> door_i_side).
      - We reserve those corridors (with a 1-cell halo) so later obstacles
        cannot block them.
      - All layout randomness is derived from the env's seeded ``np_random``,
        so ``reset(seed=...)`` reproduces the same layout.

    On failure ``failure_feedback`` holds a human-readable reason. It is reset
    to "" at the start of every generation.
    """

    # Random layouts tried per reset before giving up.
    _MAX_TRIES = 800

    def __init__(self,
                 size=30,
                 max_steps=None,
                 complexity=0.55,
                 num_doors=6,
                 fraction_hard=0.3,
                 soft_wing_len=1):
        import gymnasium as gym

        self.size = size
        self.complexity = complexity
        self.num_doors = num_doors
        self.fraction_hard = fraction_hard
        self.soft_wing_len = soft_wing_len
        # Generator RNG; re-seeded from self.np_random on every _gen_grid call.
        self._layout_rng = random.Random()

        if max_steps is None:
            max_steps = 5 * (size ** 2)

        mission_space = MissionSpace(lambda: "Reach the green goal.")
        super().__init__(mission_space=mission_space,
                         grid_size=self.size,
                         see_through_walls=True,
                         max_steps=max_steps)
        # MiniGridEnv.__init__ installs its own Dict observation space, so
        # ours must be declared afterwards. It mirrors gen_obs(): the
        # (width, height, 3) Grid.encode() array (uint8 values) cast to float32.
        self.observation_space = gym.spaces.Box(
            low=0.0, high=255.0, shape=(self.width, self.height, 3), dtype=np.float32
        )
        self.failure_feedback = ""

    # -------------------- generation --------------------

    def _gen_grid(self, width, height):
        """Generate a solvable layout in place (invoked by MiniGridEnv.reset).

        Tries up to ``_MAX_TRIES`` random layouts and keeps the first one that
        passes ``_check_solvable_ordered``. If none does, the last attempt is
        left in place and ``failure_feedback`` explains the failure.
        """
        self.failure_feedback = ""  # drop any message from a previous reset
        if self.num_doors > len(VALID_COLORS):
            self.failure_feedback = f"Requested {self.num_doors} doors > available colors"
            return

        # Derive the generator RNG from the env's seeded np_random so that
        # reset(seed=...) yields the same layout every time.
        self._layout_rng = random.Random(int(self.np_random.integers(0, 2**32)))

        for _ in range(self._MAX_TRIES):
            if self._try_layout(width, height):
                return

        self.failure_feedback = "No solvable layout found. Lower complexity or doors, or reduce fraction_hard."

    def _try_layout(self, width, height):
        """Build one random candidate layout in place; return True if solvable."""
        # Fresh grid: Grid() starts empty, then add the outer walls.
        self.grid = Grid(width, height)
        self.grid.wall_rect(0, 0, width, height)

        # Agent and goal on distinct random empty cells.
        agent_pos = self._random_empty_cell()
        goal_pos = self._random_empty_cell(exclude=[agent_pos])
        if agent_pos is None or goal_pos is None:
            return False
        self.agent_pos = agent_pos
        self.agent_dir = self._layout_rng.randint(0, 3)
        self.put_obj(Goal(), goal_pos[0], goal_pos[1])

        # Scaffold path (locks ignored), long enough to host every door.
        main_path = bfs_ignore_doors(self, agent_pos, goal_pos)
        if not main_path or len(main_path) < (2 * self.num_doors + 5):
            return False

        placed = self._place_doors(main_path)
        if placed is None:
            return False
        door_positions, colors = placed

        # Door walls only test grid.get(...) is None, and the agent is not
        # stored in the grid, so they may have walled over the agent's cell.
        if self.grid.get(*agent_pos) is not None:
            return False

        # --- KEY PLACEMENT with corridor protection ---
        path_set = set(main_path)
        reserve = path_set | {agent_pos, goal_pos} | set(door_positions)
        protected = set()
        key_positions = []
        for i, door_pos in enumerate(door_positions):
            key_pos = self._place_key(i, door_pos, agent_pos, colors,
                                      path_set, reserve, protected)
            if key_pos is None:
                return False
            key_positions.append(key_pos)

        # Scatter obstacles but NEVER on protected/reserved cells.
        self._place_obstacles(reserve | protected, door_positions, key_positions)

        # Final solvability check (collecting keys in order).
        return self._check_solvable_ordered(agent_pos, goal_pos, colors,
                                            door_positions, key_positions)

    def _place_doors(self, main_path):
        """Place ``num_doors`` locked doors plus their walls along ``main_path``.

        Returns ``(door_positions, colors)``, or None if a door slot is already
        occupied by walls from an earlier door.
        """
        rng = self._layout_rng
        door_idxs = self._choose_sequential_door_indices(main_path, self.num_doors)
        colors = rng.sample(VALID_COLORS, k=self.num_doors)

        # At least one hard door when there are any doors; never more than num_doors
        # (clamping in this order keeps num_doors == 0 from requesting a sample of 1).
        n_hard = min(self.num_doors, max(1, int(round(self.fraction_hard * self.num_doors))))
        hard_mask = set(rng.sample(range(self.num_doors), k=n_hard))

        door_positions = []
        for i, idx in enumerate(door_idxs):
            x, y = main_path[idx]
            if self.grid.get(x, y) is not None:
                return None
            self.put_obj(Door(colors[i], is_locked=True), x, y)
            door_positions.append((x, y))
            orient = self._local_path_orient(main_path, idx)
            if i in hard_mask:
                self._add_barrier_line(x, y, orient)  # full chokepoint
            else:
                self._add_soft_wings(x, y, orient, self.soft_wing_len)  # can be bypassed
            self._add_door_jamb(x, y, orient)
        return door_positions, colors

    def _place_key(self, i, door_pos, agent_pos, colors, path_set, reserve, protected):
        """Place the key for door ``i`` on the agent side of that door.

        The key cell must be reachable from the agent using only the keys of
        doors 0..i-1, and must reach a cell next to door ``i``. Both corridors,
        plus a 1-cell halo, are added to ``protected``. The corridors and the
        key cell are added to ``reserve``. Both sets are mutated in place.

        Returns the key cell, or None if no valid cell exists.
        """
        unlocked = set(colors[:i])  # colours of doors before i are usable
        # Region reachable before opening door i.
        region = flood_reachable(self, agent_pos, unlocked)

        # Door-side anchors: in-bounds neighbours of the door inside the region.
        dx, dy = door_pos
        side_neighbors = [(nx, ny) for (nx, ny) in self._neighbors(dx, dy)
                          if (0 <= nx < self.width and 0 <= ny < self.height)
                          and _passable(self, nx, ny, unlocked)
                          and (nx, ny) in region]
        if not side_neighbors:
            return None  # no approach side -> reject layout

        # Candidate cells: region minus reserves, empty, and not cramped.
        candidates = [c for c in region
                      if c not in reserve
                      and self.grid.get(*c) is None
                      and self._free_neighbors_count(*c, keys=unlocked) >= 2]
        if not candidates:
            return None

        # Prefer cells at Manhattan distance >= 2 from the scaffold path.
        far = [c for c in candidates if self._dist_to_set(c, path_set) >= 2]
        pool = far if far else candidates
        self._layout_rng.shuffle(pool)

        for cell in pool:
            p1 = shortest_path(self, agent_pos, cell, unlocked)
            if not p1:
                continue
            # Path from the key to ANY door-side neighbour.
            p2 = None
            for side in side_neighbors:
                p2 = shortest_path(self, cell, side, unlocked)
                if p2:
                    break
            if not p2:
                continue

            self.put_obj(Key(colors[i]), cell[0], cell[1])
            for c in p1 + p2:
                protected.add(c)
                for nb in self._neighbors(*c):
                    if 1 <= nb[0] < self.width - 1 and 1 <= nb[1] < self.height - 1:
                        protected.add(nb)
            reserve.update(p1, p2, [cell])
            return cell

        return None  # no candidate yields both required paths

    def _random_empty_cell(self, exclude=None):
        """Return a random empty interior cell not in ``exclude`` (or None).

        Same selection rule as ``find_random_empty`` but driven by the env's
        seeded layout RNG.
        """
        excluded = set(exclude or [])
        choices = [(x, y)
                   for x in range(1, self.width - 1)
                   for y in range(1, self.height - 1)
                   if self.grid.get(x, y) is None and (x, y) not in excluded]
        return self._layout_rng.choice(choices) if choices else None

    def gen_obs(self):
        """Return ``self.grid.encode()`` cast to float32, shape (width, height, 3).

        NOTE(review): the agent is tracked via ``agent_pos`` and is not put in
        the grid. Presumably Grid.encode() therefore does not show the agent's
        position, direction or carried key; confirm this is intended.
        """
        return self.grid.encode().astype(np.float32)

    # -------------------- internals --------------------

    def _neighbors(self, x, y):
        """Return the four 4-connected neighbours of (x, y), without bounds checks."""
        return [(x + 1, y), (x - 1, y), (x, y + 1), (x, y - 1)]

    def _free_neighbors_count(self, x, y, keys=frozenset()):
        """Count in-bounds neighbours of (x, y) that are passable with ``keys``."""
        cnt = 0
        for nx, ny in self._neighbors(x, y):
            if 0 <= nx < self.width and 0 <= ny < self.height and _passable(self, nx, ny, keys):
                cnt += 1
        return cnt

    def _dist_to_set(self, c, S):
        """Manhattan distance from ``c`` to the closest cell in ``S`` (1e9 if empty)."""
        x, y = c
        return min((abs(px - x) + abs(py - y) for px, py in S), default=1e9)

    def _check_solvable_ordered(self, start, goal, colors, door_positions, key_positions):
        """Simulate collecting keys and opening doors in order, then reaching the goal.

        For each placed door i: walk to key i using the keys gathered so far,
        take it, then walk onto door i (now openable). Finally walk to the goal
        holding every key.

        NOTE(review): this is a relaxed model. Keys are treated as walkable (see
        ``_passable``) and all keys are held at once. MiniGrid's one-object
        carry limit is not simulated.
        """
        pos = start
        have = set()
        for i, dpos in enumerate(door_positions):
            kpos = key_positions[i]
            if not shortest_path(self, pos, kpos, have):
                return False
            have.add(colors[i])  # pick key i
            pos = kpos
            # Door i's colour is now held, so the door cell itself is passable.
            if not shortest_path(self, pos, dpos, have):
                return False
            pos = dpos
        return shortest_path(self, pos, goal, have) is not None

    def _choose_sequential_door_indices(self, path, k):
        """Pick ``k`` strictly increasing door indices into ``path``.

        Whenever [2, len(path) - 3] can hold k slots, indices stay inside that
        range (never on, or one step from, either endpoint). They are spread
        evenly with a +-2 random jitter. Returns [] for k <= 0.
        """
        if k <= 0:
            return []
        n = len(path)
        start, end = 2, n - 3
        if end - start + 1 < k:
            # Not enough room for k distinct slots: best effort, may return < k.
            step = max(1, (end - start + 1) // k)
            idxs = [min(end, start + i * step) for i in range(k)]
            return sorted(set(idxs))[:k]
        base = [start + (i + 1) * (end - start) // (k + 1) for i in range(k)]
        out, last = [], start - 1
        for b in base:
            j = max(last + 1, b + self._layout_rng.randint(-2, 2))
            # Leave room for the doors that still have to be placed.
            j = min(j, end - (k - len(out) - 1))
            out.append(j)
            last = j
        return out

    def _local_path_orient(self, path, idx):
        """Return 'x' if the path runs mostly horizontally through ``idx`` (ties -> 'x'), else 'y'."""
        a = path[max(0, idx - 1)]
        c = path[min(len(path) - 1, idx + 1)]
        dx, dy = c[0] - a[0], c[1] - a[1]
        return 'x' if abs(dx) >= abs(dy) else 'y'

    def _add_barrier_line(self, x, y, orient):
        """Wall every empty interior cell of the door's column ('x') or row ('y').

        The door cell (x, y) is skipped, so it becomes the only opening.
        """
        if orient == 'x':
            for yy in range(1, self.height - 1):
                if yy == y:
                    continue
                if self.grid.get(x, yy) is None:
                    self.put_obj(Wall(), x, yy)
        else:
            for xx in range(1, self.width - 1):
                if xx == x:
                    continue
                if self.grid.get(xx, y) is None:
                    self.put_obj(Wall(), xx, y)

    def _add_soft_wings(self, x, y, orient, wing_len=2):
        """Wall up to ``wing_len`` empty interior cells on each side of the door.

        The wings run perpendicular to the path. wing_len <= 0 adds nothing.
        """
        wing_len = max(0, int(wing_len))
        if wing_len == 0:
            return
        if orient == 'x':
            for d in range(1, wing_len + 1):
                for yy in (y - d, y + d):
                    if 1 <= yy < self.height - 1 and self.grid.get(x, yy) is None:
                        self.put_obj(Wall(), x, yy)
        else:
            for d in range(1, wing_len + 1):
                for xx in (x - d, x + d):
                    if 1 <= xx < self.width - 1 and self.grid.get(xx, y) is None:
                        self.put_obj(Wall(), xx, y)

    def _add_door_jamb(self, x, y, orient):
        """Wall the two cells flanking the door perpendicular to the path, if empty and interior."""
        sides = [(x, y - 1), (x, y + 1)] if orient == 'x' else [(x - 1, y), (x + 1, y)]
        for sx, sy in sides:
            if 1 <= sx < self.width - 1 and 1 <= sy < self.height - 1:
                if self.grid.get(sx, sy) is None:
                    self.put_obj(Wall(), sx, sy)

    def _place_obstacles(self, reserved, doors, keys):
        """Scatter obstacles with density self.complexity, never on reserved or adjacent to doors/keys.

        ``complexity`` is the fraction of eligible empty interior cells to wall.
        A candidate with at most one passable neighbour is skipped, so dead ends
        are not plugged. Global connectivity is not checked here.
        """
        blocked = set(reserved) | set(doors) | set(keys)
        # 1-cell halo around doors/keys.
        for d in list(doors) + list(keys):
            for nb in self._neighbors(*d):
                if 1 <= nb[0] < self.width - 1 and 1 <= nb[1] < self.height - 1:
                    blocked.add(nb)

        free = [(x, y)
                for x in range(1, self.width - 1)
                for y in range(1, self.height - 1)
                if (x, y) not in blocked and self.grid.get(x, y) is None]

        n_obs = int(len(free) * self.complexity)
        self._layout_rng.shuffle(free)
        placed = 0
        for x, y in free:
            if placed >= n_obs:
                break
            if self._free_neighbors_count(x, y, keys=frozenset()) <= 1:
                continue
            self.put_obj(Wall(), x, y)
            placed += 1