import math, heapq
from collections import deque

class Policy:
    """
    A* grid planner + LOS smoothing + reactive avoidance (fast version).
    Speedups (no accuracy loss):
      • Grid cache: reuse occupancy & distance map across steps until the goal,
        the obstacle set, the inflation, or the planning window (agent left it) changes.
      • Distance transform on the grid (O(cells)) -> O(1) proximity penalty per A* expansion.
      • Short replan cooldown (counted in compute_action calls) to avoid thrashing.
      • Relaxed A* goal-accept test (grid-aware) but strict stop via _inside_goal.

    Returns [ax, ay] in [-1, 1] using ONLY `obs`.
    """

    AGENT_RADIUS = 15
    BASE_GRID    = 1
    MAX_STEP_PIX = 3

    # knobs
    REPLAN_COOLDOWN = 5          # min ticks (compute_action calls) between cooldown-gated replans
    WP_POP_MULT     = 1.2        # waypoint pop tolerance vs measured step
    WINDOW_PAD      = 64         # base extra margin around start+goal for the planning window

    def __init__(self, verbose: bool = False):
        self.verbose = verbose
        self._path = []                      # remaining waypoints (world coords), next one first
        self._last_pos = None                # position seen on the previous call
        self._step_est = self.MAX_STEP_PIX   # smoothed measured displacement per call (pixels)
        self._pos_hist = deque(maxlen=16)    # recent positions for progress monitoring
        self._stuck_ticks = 0                # consecutive "pinned" detections
        self._last_side = 0                  # preferred sidestep side: +1 / -1 / 0
        self._last_wp = None                 # waypoint tracked by _pop_waypoints
        self._last_wp_pos = None             # agent position when _last_wp was last checked
        self._tick = 0                       # monotonic call counter (cooldown clock)
        self._last_plan_step = -999          # tick of the last successful plan (for cooldown)
        self._force_replan = False           # set when reactive avoidance cannot find a free move

        # ---- grid cache
        self._occ = None
        self._dmap = None
        self._gw = self._gh = 0
        self._bbox = None     # (xmin, ymin, xmax, ymax)
        self._G = float(self.BASE_GRID)
        self._goal_sig = None # (gx, gy, gw, gh)
        self._inflate_sig = None
        self._obst_sig = None # rounded, sorted obstacle rectangles
        self._c2w = None
        self._w2c = None

    # ------------------------- public API -------------------------

    def compute_action(self, obs):
        """Return the action [ax, ay] (each component clamped to [-1, 1]) for one observation.

        `obs["agent_pos"]` is the agent position (x, y). `obs.get("objects", [])` is a list
        of dicts with "type" ("zone" or "obstacle"), "pos" (rect center) and "size" (w, h);
        zones whose "purpose" is "goal" are preferred, otherwise any zone is a goal.
        Returns [0.0, 0.0] when there is no goal or the agent already fits inside it.
        """
        self._tick += 1
        pos  = (float(obs["agent_pos"][0]), float(obs["agent_pos"][1]))
        objs = obs.get("objects", [])

        goals = [o for o in objs if o.get("type") == "zone" and o.get("purpose") == "goal"]
        if not goals:
            goals = [o for o in objs if o.get("type") == "zone"]
        obst  = [o for o in objs if o.get("type") == "obstacle"]

        self._update_step_est(pos)
        goal = self._select_goal(pos, goals)
        if goal is None or self._inside_goal(pos, goal):
            return [0.0, 0.0]

        # Replan if there is no path, the first segment is blocked, we are stuck, or the
        # reactive layer requested it. The cooldown runs on a monotonic tick counter (a
        # bounded history length would saturate and freeze replanning) and is skipped when
        # there is no path at all, since then there is nothing to follow meanwhile.
        need_plan = (not self._path
                     or self._force_replan
                     or self._stuck_ticks >= 6
                     or not self._segment_ok(pos, self._path[0], obst))
        cooled = (self._tick - self._last_plan_step) >= self.REPLAN_COOLDOWN
        if need_plan and (cooled or not self._path):
            if not self._plan_and_optimize(pos, goal, obst):
                # deterministic escape (no random jitter)
                ax, ay = self._escape_move(pos, obst)
                self._pos_hist.append(pos)
                return [ax, ay]
            self._last_plan_step = self._tick
            self._force_replan = False

        # Drop reached / overshot waypoints
        self._pop_waypoints(pos)
        if not self._path:
            return [0.0, 0.0]

        # Steer to current waypoint; scale down on the last pixels to avoid overshooting it.
        tx, ty = self._path[0]
        dx, dy = tx - pos[0], ty - pos[1]
        dist   = math.hypot(dx, dy)
        if dist < 1e-9:
            return [0.0, 0.0]

        scale  = min(self.MAX_STEP_PIX, dist) / self.MAX_STEP_PIX
        ax, ay = (dx / dist) * scale, (dy / dist) * scale

        # Predictive collision check
        if self._hit(pos, (ax, ay), obst, self._step_est):
            ax, ay = self._sidestep(pos, (ax, ay), obst)
            if self._hit(pos, (ax, ay), obst, self._step_est):
                # Separate flag: _is_pinned below may reset _stuck_ticks this same tick.
                self._force_replan = True

        # Progress monitoring
        self._pos_hist.append(pos)
        if self._is_pinned(pos, (dx, dy)):
            ax, ay = self._sidestep(pos, (ax, ay), obst)

        # Precision finish: full push toward the final waypoint, but never replace the
        # avoidance result above with a move that is predicted to collide.
        if len(self._path) == 1 and dist < max(4.0, 1.2 * self._step_est):
            ux, uy = dx / dist, dy / dist
            if not self._hit(pos, (ux, uy), obst, self._step_est):
                ax, ay = ux, uy

        return [float(max(-1.0, min(1.0, ax))), float(max(-1.0, min(1.0, ay)))]

    # ------------------------- goal selection -------------------------

    def _select_goal(self, pos, goals):
        """Return the goal rect closest to `pos` (ties: larger area first), or None if empty."""
        if not goals:
            return None
        return min(goals, key=lambda g: (self._point_rect_dist(pos, g["pos"], g["size"]),
                                         -g["size"][0] * g["size"][1]))

    # ------------------------- state & detection -------------------------

    def _update_step_est(self, pos):
        """Update the EMA of the per-call displacement (clamped to [0.5, 10] px)."""
        if self._last_pos is not None:
            moved = self._dist(pos, self._last_pos)
            if moved > 1e-3:  # ignore zero moves so a blocked tick does not shrink the estimate
                self._step_est = 0.85 * self._step_est + 0.15 * moved
            self._step_est = max(0.5, min(10.0, self._step_est))
        self._last_pos = pos

    def _is_pinned(self, pos, to_wp_vec):
        """Return True (and count a stuck tick) when recent progress toward the waypoint is too small.

        Progress is the displacement across the recorded history projected onto the
        direction of `to_wp_vec`; below 0.3 * step estimate counts as pinned. A
        non-pinned result resets the stuck counter. Needs at least 6 recorded positions.
        """
        if len(self._pos_hist) < 6:
            return False
        ux, uy = self._unit(to_wp_vec)
        if ux == 0.0 and uy == 0.0:
            return False
        first, last = self._pos_hist[0], self._pos_hist[-1]
        # The sum of per-step projections telescopes to the projection of the net displacement.
        proj_sum = (last[0] - first[0]) * ux + (last[1] - first[1]) * uy
        if proj_sum < 0.3 * self._step_est:
            self._stuck_ticks += 1
            return True
        self._stuck_ticks = 0
        return False

    # ------------------------- geometry & collision -------------------------

    def _inside_goal(self, p, z):
        """True when the agent circle at `p` fits fully inside goal rect `z`."""
        cx, cy = p
        zx, zy = z['pos']
        w, h   = z['size']
        L, R = zx - w/2, zx + w/2
        T, B = zy - h/2, zy + h/2

        margin = self.AGENT_RADIUS
        return (L + margin <= cx <= R - margin) and (T + margin <= cy <= B - margin)

    def _point_in_rect(self, p, center, size, expand=0.0):
        """True when point `p` lies in the rect (center, size) grown by `expand` on every side."""
        x, y = p
        cx, cy = center
        w, h   = size
        l = cx - w/2 - expand
        r = cx + w/2 + expand
        t = cy - h/2 - expand
        b = cy + h/2 + expand
        return (l <= x <= r) and (t <= y <= b)

    def _circle_rect_dist(self, c, r, rect):
        """Signed gap between a circle (center `c`, radius `r`) and `rect`; < 0 means overlap."""
        x, y = c
        cx, cy = rect['pos']; w, h  = rect['size']
        l, rr, t, b = cx - w/2, cx + w/2, cy - h/2, cy + h/2
        px = min(max(x, l), rr)
        py = min(max(y, t), b)
        return math.hypot(x - px, y - py) - r

    def _point_rect_dist(self, p, rect_center, rect_size):
        """Euclidean distance from `p` to the rect (0 when inside)."""
        x, y = p
        cx, cy = rect_center; w, h = rect_size
        l, r, t, b = cx - w/2, cx + w/2, cy - h/2, cy + h/2
        dx = 0.0 if l <= x <= r else (l - x if x < l else x - r)
        dy = 0.0 if t <= y <= b else (t - y if y < t else y - b)
        return math.hypot(dx, dy)

    def _dist(self, a, b):
        return math.hypot(a[0] - b[0], a[1] - b[1])

    def _unit(self, v):
        """Unit vector of `v`, or (0, 0) for a (near-)zero vector."""
        vx, vy = v
        n = math.hypot(vx, vy)
        return ((vx / n, vy / n) if n > 1e-9 else (0.0, 0.0))

    def _hit(self, pos, act, obst, step_len):
        """True when moving `act * step_len` from `pos` would overlap any obstacle."""
        nx = pos[0] + act[0] * step_len
        ny = pos[1] + act[1] * step_len
        return any(self._circle_rect_dist((nx, ny), self.AGENT_RADIUS, o) < 0 for o in obst)

    def _los_clear(self, a, b, obst, inflate=0.0):
        """Sample the segment a->b (excluding a) and check the inflated agent circle stays free."""
        seg_len = self._dist(a, b)
        steps = max(6, int(seg_len / max(1.5, 0.75 * self._G)))  # fewer samples when grid coarser
        R = self.AGENT_RADIUS + inflate
        for i in range(1, steps + 1):
            t = i / steps
            px = a[0] + (b[0] - a[0]) * t
            py = a[1] + (b[1] - a[1]) * t
            if any(self._circle_rect_dist((px, py), R, o) < 0 for o in obst):
                return False
        return True

    def _segment_ok(self, a, b, obst):
        return self._los_clear(a, b, obst, inflate=0.0)

    # ------------------------- planning window & cached grid -------------------------

    def _plan_window(self, pos, goal, obst, pad=None):
        """Return [xmin, ymin, xmax, ymax] around start+goal plus `pad` (mins clamped at 0).

        `obst` is not used: far-away obstacles do not enlarge the window.
        """
        if pad is None:
            pad = self.WINDOW_PAD
        xs = [pos[0], goal['pos'][0]]
        ys = [pos[1], goal['pos'][1]]
        xmin = max(0.0, min(xs) - pad)
        ymin = max(0.0, min(ys) - pad)
        xmax = max(xs) + pad
        ymax = max(ys) + pad
        return [xmin, ymin, xmax, ymax]

    @staticmethod
    def _goal_signature(goal):
        """Rounded (x, y, w, h) of the goal rect, used as a cache key."""
        gx, gy = goal['pos']; gw, gh = goal['size']
        return (round(gx, 1), round(gy, 1), round(gw, 1), round(gh, 1))

    @staticmethod
    def _obst_signature(obst):
        """Order-independent rounded fingerprint of the obstacle rects, used as a cache key."""
        return tuple(sorted(
            (round(o['pos'][0], 1), round(o['pos'][1], 1),
             round(o['size'][0], 1), round(o['size'][1], 1))
            for o in obst))

    def _need_rebuild_grid(self, goal, inflate_px, pos=None, obst=None):
        """Return True when the cached grid cannot be reused.

        Rebuild if there is no cache, the goal or inflation changed, or (when given)
        the obstacle set changed or `pos` lies outside the cached planning window.
        """
        if self._occ is None or self._bbox is None or self._c2w is None:
            return True
        if self._goal_sig != self._goal_signature(goal):
            return True
        if self._inflate_sig != round(inflate_px, 3):
            return True
        if obst is not None and self._obst_sig != self._obst_signature(obst):
            return True
        if pos is not None:
            xmin, ymin, xmax, ymax = self._bbox
            if not (xmin <= pos[0] <= xmax and ymin <= pos[1] <= ymax):
                return True  # sidesteps/escapes can carry the agent out of the window
        return False

    def _grid_from_obs(self, pos, goal, obst, inflate_px=0.0):
        """
        Cached occupancy + distance transform.
        Relaxed A* goal acceptance (grid aware), strict stop remains via _inside_goal.

        Returns (occ, start_cell, cell_in_goal, (gw, gh), c2w, dmap).
        """
        G = float(self.BASE_GRID)
        if self._need_rebuild_grid(goal, inflate_px, pos=pos, obst=obst):
            xmin, ymin, xmax, ymax = self._plan_window(pos, goal, obst)
            gw = max(4, int(math.ceil((xmax - xmin) / G)))
            gh = max(4, int(math.ceil((ymax - ymin) / G)))

            def c2w(cx, cy):  # cell center -> world
                return (xmin + (cx + 0.5) * G, ymin + (cy + 0.5) * G)

            def w2c(x, y):    # world -> cell index
                return (int((x - xmin) // G), int((y - ymin) // G))

            occ = [[0]*gw for _ in range(gh)]

            # Conservative inflation: agent radius + half cell keeps safety on coarse grids
            R = self.AGENT_RADIUS + inflate_px + (G * 0.5)

            # paint obstacles quickly in cell coordinates
            for o in obst:
                cx, cy = o['pos']; w, h = o['size']
                l = cx - w/2 - R; r = cx + w/2 + R
                t = cy - h/2 - R; b = cy + h/2 + R
                x0 = max(0, min(gw - 1, int((l - xmin) // G)))
                x1 = max(0, min(gw - 1, int((r - xmin) // G)))
                y0 = max(0, min(gh - 1, int((t - ymin) // G)))
                y1 = max(0, min(gh - 1, int((b - ymin) // G)))
                if x1 < x0 or y1 < y0:
                    continue
                for yy in range(y0, y1 + 1):
                    row = occ[yy]
                    for xx in range(x0, x1 + 1):
                        row[xx] = 1

            # distance transform (4-neighbor) for proximity penalty (grid cells)
            dmap = self._distance_transform4(occ, gw, gh)

            # cache
            self._occ = occ
            self._dmap = dmap
            self._gw, self._gh = gw, gh
            self._bbox = (xmin, ymin, xmax, ymax)
            self._G = G
            self._c2w, self._w2c = c2w, w2c
            self._goal_sig = self._goal_signature(goal)
            self._inflate_sig = round(inflate_px, 3)
            self._obst_sig = self._obst_signature(obst)

        # relaxed A* goal acceptance (grid-aware): cell center within goal rect + relax px
        relax = max(0.75*self._G, 6.0)
        def cell_in_goal(cx, cy):
            wx, wy = self._c2w(cx, cy)
            return self._point_in_rect((wx, wy), goal['pos'], goal['size'], expand=relax)

        sx, sy = self._w2c(pos[0], pos[1])
        sx = max(0, min(self._gw - 1, sx)); sy = max(0, min(self._gh - 1, sy))
        return (self._occ, (sx, sy), cell_in_goal, (self._gw, self._gh), self._c2w, self._dmap)

    # ------------------------- A* path + smoothing -------------------------

    def _plan_and_optimize(self, start, goal, obst, prefer_clearance=False, extra_clearance=False):
        """Plan a path from `start` to `goal` and store it in self._path.

        Runs A* on the (cached) grid, LOS-smooths the cell path, prunes waypoints closer
        than the step estimate, and moves the last waypoint to the goal center if it is
        not a valid stop point. Returns False (and clears the path) when A* finds nothing.
        """
        inflate = (1.0 if prefer_clearance else 0.0) + (2.0 if extra_clearance else 0.0)
        occ, start_cell, cell_in_goal, shape, c2w, dmap = self._grid_from_obs(start, goal, obst, inflate_px=inflate)
        gw, gh = shape

        path = self._astar(start_cell, occ, cell_in_goal, gw, gh, c2w, goal, inflate, dmap)
        self._last_wp = None
        self._last_wp_pos = None
        if not path:
            self._path = []
            return False

        # Theta*-style smoothing (LOS)
        sm = [path[0]]
        anchor = path[0]
        prev = path[0]
        for p in path[1:]:
            if not self._los_clear(anchor, p, obst, inflate=inflate):
                sm.append(prev)
                anchor = prev
            prev = p
        if sm[-1] != path[-1]:
            sm.append(path[-1])

        # Density pruning
        pr = [sm[0]]
        for q in sm[1:]:
            if self._dist(pr[-1], q) >= max(2.0, 0.8 * self._step_est):
                pr.append(q)
        if pr[-1] != sm[-1]:
            pr.append(sm[-1])

        # Nudge final waypoint into goal (strict stop still via _inside_goal)
        if not self._inside_goal(pr[-1], goal):
            pr[-1] = (goal['pos'][0], goal['pos'][1])

        self._path = pr
        return True

    def _astar(self, start, occ, is_goal, gw, gh, c2w, goal, inflate, dmap):
        """8-connected A* from cell `start` to the first cell accepted by `is_goal`.

        Returns the world-space centers of the path cells after the start cell, ending at
        the goal cell. If the start cell already satisfies `is_goal`, returns its center
        alone (so callers still get a usable path). Returns [] when no path exists.
        """
        def clear(cx, cy):
            return 0 <= cx < gw and 0 <= cy < gh and occ[cy][cx] == 0

        # nudge start to the nearest free cell if discretization blocks it
        if not clear(*start):
            q = deque([start]); seen = {start}
            alt = None
            while q and alt is None:
                x, y = q.popleft()
                for dx, dy in [(-1,0),(1,0),(0,-1),(0,1)]:
                    nx, ny = x + dx, y + dy
                    if (nx, ny) in seen or not (0 <= nx < gw and 0 <= ny < gh):
                        continue
                    if clear(nx, ny):
                        alt = (nx, ny); break
                    seen.add((nx, ny)); q.append((nx, ny))
            if alt is None:
                return []
            start = alt

        dirs = [(-1,0),(1,0),(0,-1),(0,1),(-1,-1),(-1,1),(1,-1),(1,1)]
        def diag_block(c, n):
            # forbid diagonal moves that cut the corner of an occupied cell
            if c[0] == n[0] or c[1] == n[1]:
                return False
            return (not clear(c[0], n[1])) or (not clear(n[0], c[1]))

        # heuristic: distance from cell center to goal rectangle (in grid units)
        def h_cost(cx, cy):
            wx, wy = c2w(cx, cy)
            return self._point_rect_dist((wx, wy), goal['pos'], goal['size']) / self._G

        # entries: (f, cell, g) -- ordering is by f then cell; g identifies stale entries
        frontier = [(0.0, start, 0.0)]
        came = {}
        g = {start: 0.0}
        goal_cell = None

        while frontier:
            _, cur, g_cur = heapq.heappop(frontier)
            if g_cur > g[cur]:
                continue  # superseded by a cheaper entry that was already expanded
            if is_goal(*cur):
                goal_cell = cur
                break
            gcx, gcy = cur
            for dx, dy in dirs:
                nx, ny = gcx + dx, gcy + dy
                if not clear(nx, ny) or diag_block(cur, (nx, ny)):
                    continue
                step = 1.4142 if dx and dy else 1.0

                # proximity penalty via precomputed distance map (cells -> pixels)
                # dmap counts grid steps to nearest obstacle cell (0 at obstacles)
                dpx = dmap[ny][nx] * self._G
                prox_pen = 0.2 if dpx < 6.0 else (0.05 if dpx < 12.0 else 0.0)

                cost = g[cur] + step + prox_pen
                if (nx, ny) not in g or cost < g[(nx, ny)]:
                    g[(nx, ny)] = cost
                    heapq.heappush(frontier, (cost + h_cost(nx, ny), (nx, ny), cost))
                    came[(nx, ny)] = cur

        if goal_cell is None:
            return []

        # reconstruct path (cell->world); keep the start cell when it is the goal itself
        cells = []
        cur = goal_cell
        while cur != start:
            cells.append(cur)
            cur = came[cur]
        if not cells:
            cells.append(goal_cell)
        cells.reverse()
        return [c2w(cx, cy) for (cx, cy) in cells]

    # ------------------------- waypoint management -------------------------

    def _pop_waypoints(self, pos):
        """Drop leading waypoints that are reached (within tolerance) or overshot.

        A waypoint counts as overshot when it was ahead of the agent at the previous check
        and is behind it now (the vectors from the previous and current positions to the
        waypoint point in opposite half-planes), while still within 2.5 step estimates.
        """
        while self._path:
            wp = self._path[0]
            d  = self._dist(pos, wp)
            near = d < max(self.WP_POP_MULT * self._step_est, 1.5)
            crossed = False
            prev = self._last_wp_pos
            if self._last_wp == wp and prev is not None:
                v_prev = (wp[0] - prev[0], wp[1] - prev[1])
                v_now  = (wp[0] - pos[0], wp[1] - pos[1])
                crossed = (v_prev[0] * v_now[0] + v_prev[1] * v_now[1]) < 0.0 and d < 2.5 * self._step_est
            if near or crossed:
                self._path.pop(0)
                self._last_wp = None
                self._last_wp_pos = None
            else:
                self._last_wp = wp
                self._last_wp_pos = pos
                break

    # ------------------------- local avoidance / escape -------------------------

    def _sidestep(self, pos, base, obst):
        """Return a collision-free variant of action `base`, or `base` if none is found.

        Candidates: `base` plus a perpendicular offset on each side (the previously chosen
        side first) and `base` at half magnitude; the one with the best clearance score
        wins (first wins ties). Remembers the chosen side (+1 / -1, or 0 for half speed).
        """
        bx, by = base
        perp = (-by, bx); n = math.hypot(*perp) or 1.0
        mag = 0.7 if self._step_est < self.MAX_STEP_PIX else 0.9
        sx, sy = (perp[0] / n) * mag, (perp[1] / n) * mag
        order = [1, -1] if self._last_side >= 0 else [-1, 1]
        candidates = [(bx + s * sx, by + s * sy, s) for s in order] + [(bx * 0.5, by * 0.5, 0)]

        best, best_score = base, -1.0
        chosen_side = self._last_side
        for ax, ay, side in candidates:
            if not self._hit(pos, (ax, ay), obst, self._step_est):
                score = self._clearance_score(pos, (ax, ay), obst)
                if score > best_score:
                    best_score = score
                    best = (ax, ay)
                    chosen_side = side  # the actual side, not the candidate index
        self._last_side = chosen_side
        return best

    def _escape_move(self, pos, obst):
        """Deterministic escape: of 12 evenly spaced unit directions, pick max clearance.

        Falls back to (0.2, 0.0) when every direction is predicted to collide.
        """
        best = (0.0, 0.0)
        best_score = -1.0
        for k in range(12):
            th = (2*math.pi) * (k / 12.0)
            ax, ay = math.cos(th), math.sin(th)
            if self._hit(pos, (ax, ay), obst, self._step_est):
                continue
            score = self._clearance_score(pos, (ax, ay), obst)
            if score > best_score:
                best_score = score
                best = (ax, ay)
        if best_score < 0:
            best = (0.2, 0.0)
        return [float(max(-1.0, min(1.0, best[0]))), float(max(-1.0, min(1.0, best[1])))]

    def _clearance_score(self, pos, act, obst):
        """Count how many of 5 successive steps along `act` stay collision-free (0..5)."""
        score, steps = 0.0, 5
        x, y = pos
        for i in range(1, steps + 1):
            nx = x + act[0] * self._step_est * i
            ny = y + act[1] * self._step_est * i
            if any(self._circle_rect_dist((nx, ny), self.AGENT_RADIUS, o) < 0 for o in obst):
                break
            score += 1.0
        return score

    # ------------------------- grid utilities -------------------------

    def _distance_transform4(self, occ, gw, gh):
        """
        4-neighbor distance (in grid steps) to the nearest occupied cell.
        Multi-source BFS; O(gw*gh). 0 at obstacles, grows outward.
        Cells never reached (no obstacles at all) are set to max(gw, gh).
        """
        INF = 10**9
        d = [[INF]*gw for _ in range(gh)]
        q = deque()
        for y in range(gh):
            row = occ[y]
            for x in range(gw):
                if row[x] == 1:
                    d[y][x] = 0
                    q.append((x, y))
        nbrs = [(-1,0),(1,0),(0,-1),(0,1)]
        while q:
            x, y = q.popleft()
            for dx, dy in nbrs:
                nx, ny = x + dx, y + dy
                if 0 <= nx < gw and 0 <= ny < gh and d[ny][nx] > d[y][x] + 1:
                    d[ny][nx] = d[y][x] + 1
                    q.append((nx, ny))
        LIM = int(max(gw, gh))
        for y in range(gh):
            for x in range(gw):
                if d[y][x] >= INF:
                    d[y][x] = LIM
        return d
