# Evolved from env_8.txt per the same prompt constraints. Constant sizes: 1400x1000 canvas.
import numpy as np
import pygame
import math
from collections import deque

class CustomEnv:
    """Continuous 2D navigation task with axis-aligned rectangular obstacles.

    A circular agent has to be steered until it lies fully inside a
    fixed-size rectangular goal zone on a ``W`` x ``H`` canvas.  Layouts are
    rejection-sampled in :meth:`reset` until a coarse grid search finds a
    route from the start position to the goal.

    Obstacles are stored in ``self.obstacles`` as ``[cx, cy, w, h]`` lists,
    i.e. anchored at their centre.
    """

    def __init__(self, render_mode=False):
        """Set the constant configuration and generate the first layout.

        Args:
            render_mode: when true, a pygame window is created on the first
                successful :meth:`reset` and :meth:`render` draws into it.
        """
        # ---- Constants (no randomness in canvas/zone sizes) ----
        self.W = 1400
        self.H = 1000

        self.agent_radius = 15
        self.zone_margin = 14
        self.zone_size = [2 * self.agent_radius + self.zone_margin,
                          2 * self.agent_radius + self.zone_margin]

        # ---- Layout density ----
        self.n_obstacles_total = 64         # total rectangles (bars + blocks)
        self.n_bars = 8                     # structured long "walls" to create corridors
        self.block_size_range = (60, 150)   # random blocks (upper bound exclusive)
        self.bar_thick_range = (40, 60)     # bar thickness (upper bound exclusive)
        self.bar_long_range = (350, 800)    # bar length (upper bound exclusive)

        # ---- Motion & horizon ----
        self.agent_speed = 20
        self.speed = 20                     # not read inside this class; kept for external users
        self.min_agent_goal_dist = 1200
        self.max_steps = 5000

        # ---- Rendering ----
        self.render_mode = render_mode
        self.screen = None
        self.clock = None
        self.colors = {
            'agent': (50, 140, 255),
            'zone': (50, 220, 50),
            'obstacle': (120, 120, 120),
            'background': (14, 14, 14),
        }

        self.reset()

    # ================= Core API =================

    def reset(self):
        """Sample a fresh solvable layout and return the initial observation.

        Each attempt draws a goal position, places the bars and blocks, then
        samples start positions until one is far enough from the goal, does
        not touch any obstacle and passes :meth:`_path_exists`.

        Returns:
            The observation dict produced by :meth:`_get_obs`.

        Raises:
            RuntimeError: if no solvable layout is found within the retry
                budget.
        """
        self.step_count = 0
        self.done = False
        self._success = False

        max_tries = 5000
        zw, zh = self.zone_size
        radius = self.agent_radius
        edge_buf = radius + 8

        # Admissible range for the goal centre (loop-invariant).
        x0 = edge_buf + zw // 2
        x1 = self.W - edge_buf - zw // 2
        y0 = edge_buf + zh // 2
        y1 = self.H - edge_buf - zh // 2

        for _ in range(max_tries):
            self.obstacles = []

            # ----- Goal zone (constant size, random position within bounds) -----
            self.zone_center = [
                int(np.random.randint(x0, x1 + 1)),
                int(np.random.randint(y0, y1 + 1))
            ]
            zx, zy = self.zone_center

            # If even the farthest reachable corner is closer than the required
            # start distance, no start can ever be accepted for this goal, so
            # skip before paying for obstacle placement.
            farthest = math.hypot(max(zx - radius, self.W - radius - zx),
                                  max(zy - radius, self.H - radius - zy))
            if farthest < self.min_agent_goal_dist:
                continue

            # Window around the goal that obstacles must stay out of.
            approach_rect = pygame.Rect(
                zx - zw // 2 - radius,
                zy - zh // 2 - radius,
                zw + 2 * radius,
                zh + 2 * radius
            )

            # ----- Structured bars to form corridors -----
            if not self._place_bars(approach_rect, max_tries):
                continue

            # ----- Sprinkle remaining random blocks -----
            remaining = max(0, self.n_obstacles_total - len(self.obstacles))
            if not self._place_blocks(remaining, approach_rect, max_tries):
                continue

            # ----- Sample agent start (far, collision-free) -----
            for _ in range(max_tries):
                ax = int(np.random.randint(radius, self.W - radius + 1))
                ay = int(np.random.randint(radius, self.H - radius + 1))

                if math.hypot(ax - zx, ay - zy) < self.min_agent_goal_dist:
                    continue
                # Test the whole agent circle, not just its centre point, so the
                # agent never starts overlapping an obstacle.
                if self._collides_any((ax, ay)):
                    continue

                self.agent_pos = np.array([float(ax), float(ay)])
                if self._path_exists():  # coarse BFS solvability check (inflated)
                    if self.render_mode and self.screen is None:
                        pygame.init()
                        self.screen = pygame.display.set_mode((self.W, self.H))
                        self.clock = pygame.time.Clock()
                    return self._get_obs()
            # no valid start found: retry with a full new layout
        raise RuntimeError("Failed to generate a solvable environment.")

    def step(self, action):
        """Advance the environment by one action.

        Args:
            action: indexable pair ``[dx, dy]``; each component is clipped to
                ``[-1, 1]`` and scaled by ``agent_speed``.

        Returns:
            ``(obs, reward, done)``.  The reward is ``-0.01`` per step and
            ``1.0`` on the step where the goal is reached; once the episode is
            over, further calls return a reward of ``0.0``.
        """
        self._handle_quit()
        if self.done:
            return self._get_obs(), 0.0, True

        # Clip the action, then clip the target so the circle stays on the canvas.
        dx = float(np.clip(action[0], -1.0, 1.0)) * self.agent_speed
        dy = float(np.clip(action[1], -1.0, 1.0)) * self.agent_speed
        lo = self.agent_radius
        target_x = min(max(self.agent_pos[0] + dx, lo), self.W - self.agent_radius)
        target_y = min(max(self.agent_pos[1] + dy, lo), self.H - self.agent_radius)
        new_pos = np.array([float(target_x), float(target_y)])

        # A move that would overlap an obstacle is rejected entirely.
        if not self._collides_any(new_pos):
            self.agent_pos = new_pos

        self.step_count += 1
        reward = -0.01

        # Success is detected as a side effect of _get_obs(); step() itself
        # only enforces the max-steps horizon.
        obs = self._get_obs()
        if self._success:
            reward = 1.0
        if self.step_count >= self.max_steps:
            self.done = True

        return obs, reward, self.done

    def draw_objects(self):
        """Draw background, goal zone, obstacles and agent onto the screen."""
        if not self.render_mode or self.screen is None:
            return
        self.screen.fill(self.colors['background'])
        zx, zy = self.zone_center
        zw, zh = self.zone_size
        pygame.draw.rect(self.screen, self.colors['zone'],
                         pygame.Rect(zx - zw // 2, zy - zh // 2, zw, zh), border_radius=8)
        for ox, oy, w, h in self.obstacles:
            pygame.draw.rect(self.screen, self.colors['obstacle'],
                             pygame.Rect(ox - w // 2, oy - h // 2, w, h), border_radius=6)
        pygame.draw.circle(self.screen, self.colors['agent'],
                           (int(self.agent_pos[0]), int(self.agent_pos[1])), self.agent_radius)

    def render(self, wait=15):
        """Draw one frame; pause ``wait`` ms after drawing once the episode is done."""
        if not self.render_mode:
            return
        self.draw_objects()
        pygame.display.flip()
        if self.clock:
            self.clock.tick(60)
        if wait and self.done:
            pygame.time.wait(wait)

    # ================= Helpers =================

    def _handle_quit(self):
        """Exit the process when the pygame window is closed (render mode only)."""
        if not self.render_mode:
            return
        for event in pygame.event.get():
            if event.type == pygame.QUIT:
                pygame.quit()
                import sys
                sys.exit(0)

    # ---- Geometry ----

    @staticmethod
    def _obstacle_bounds(ox, oy, w, h):
        """Return ``(left, top, right, bottom)`` of a centre-anchored obstacle.

        Uses the same arithmetic as ``pygame.Rect(ox - w // 2, oy - h // 2, w, h)``.
        """
        left = ox - w // 2
        top = oy - h // 2
        return left, top, left + w, top + h

    @staticmethod
    def _circle_box_overlap(cx, cy, radius, left, top, right, bottom):
        """Return True when the circle's centre is closer than ``radius`` to the box."""
        nearest_x = min(max(cx, left), right)
        nearest_y = min(max(cy, top), bottom)
        return math.hypot(nearest_x - cx, nearest_y - cy) < radius

    # ---- Placement ----

    def _try_place(self, w, h, approach_rect):
        """Sample a position for a ``w`` x ``h`` obstacle and add it if legal.

        Returns True when the obstacle was appended to ``self.obstacles``.
        """
        ox = int(np.random.randint(w // 2 + 5, self.W - w // 2 - 5))
        oy = int(np.random.randint(h // 2 + 5, self.H - h // 2 - 5))
        rect = pygame.Rect(ox - w // 2, oy - h // 2, w, h)

        if rect.colliderect(approach_rect) or not self._rect_can_place(rect):
            return False
        self.obstacles.append([ox, oy, w, h])
        return True

    def _place_bars(self, approach_rect, max_tries):
        """Place ``n_bars`` long bars; return True if all of them fit."""
        placed = 0
        for _ in range(max_tries):
            if placed >= self.n_bars:
                break
            horizontal = bool(np.random.randint(0, 2))
            thick = int(np.random.randint(*self.bar_thick_range))
            length = int(np.random.randint(*self.bar_long_range))
            w, h = (length, thick) if horizontal else (thick, length)
            if self._try_place(w, h, approach_rect):
                placed += 1
        return placed == self.n_bars

    def _place_blocks(self, count, approach_rect, max_tries):
        """Place ``count`` random blocks; return True if all of them fit."""
        placed = 0
        for _ in range(max_tries):
            if placed >= count:
                break
            w = int(np.random.randint(*self.block_size_range))
            h = int(np.random.randint(*self.block_size_range))
            if self._try_place(w, h, approach_rect):
                placed += 1
        return placed == count

    def _rect_can_place(self, rect):
        """Return True if ``rect`` keeps clearance to the canvas edge and all obstacles.

        Both the candidate and every existing obstacle are grown by the agent
        radius on each side; the grown rectangles must not overlap, which
        leaves a gap of at least one agent diameter between obstacles.
        """
        grow = 2 * self.agent_radius
        infl_new = rect.inflate(grow, grow)
        if not self._in_bounds(infl_new):
            return False
        for px, py, pw, ph in self.obstacles:
            infl_old = pygame.Rect(px - pw // 2, py - ph // 2, pw, ph).inflate(grow, grow)
            # Overlap of the raw rectangles implies overlap of the grown ones,
            # so testing the grown pair is sufficient.
            if infl_new.colliderect(infl_old):
                return False
        return True

    def _in_bounds(self, rect):
        """Return True if ``rect`` lies entirely within the canvas."""
        return (rect.left >= 0 and rect.top >= 0 and
                rect.right <= self.W and rect.bottom <= self.H)

    def _point_in_any_obstacle(self, p):
        """Return True if point ``p`` is inside an obstacle (right/bottom edge excluded)."""
        x, y = p
        for ox, oy, w, h in self.obstacles:
            left, top, right, bottom = self._obstacle_bounds(ox, oy, w, h)
            if left <= x < right and top <= y < bottom:
                return True
        return False

    def _collides_any(self, pos):
        """Return True if the agent circle centred at ``pos`` overlaps any obstacle."""
        cx, cy = float(pos[0]), float(pos[1])
        radius = self.agent_radius
        return any(
            self._circle_box_overlap(cx, cy, radius, *self._obstacle_bounds(ox, oy, w, h))
            for ox, oy, w, h in self.obstacles
        )

    def _circle_rect_overlap(self, circle_pos, radius, rect):
        """Return True if a circle overlaps ``rect`` (needs left/right/top/bottom attributes)."""
        return self._circle_box_overlap(float(circle_pos[0]), float(circle_pos[1]), radius,
                                        rect.left, rect.top, rect.right, rect.bottom)

    # ---- Solvability gate ----

    def _path_exists(self):
        """Return True if a coarse grid route links ``agent_pos`` to the goal.

        The canvas is sampled every 3 px.  Obstacles are grown by the agent
        radius, so a free cell is a position where the agent centre may be.
        Goal cells are the free cells at which the agent circle lies entirely
        inside the goal zone.  A 4-connected BFS runs from the start cell.
        """
        cell = 3
        grid_w = self.W // cell
        grid_h = self.H // cell
        radius = self.agent_radius
        blocked = np.zeros((grid_h, grid_w), dtype=bool)

        for ox, oy, w, h in self.obstacles:
            left, top, right, bottom = self._obstacle_bounds(int(ox), int(oy), int(w), int(h))
            # Round the near edge down and the far edge up (ceil via negated
            # floor division) so the grown obstacle is never under-covered.
            # Clamp to >= 0 because a negative slice bound would wrap around.
            gx0 = max(0, (left - radius) // cell)
            gx1 = max(0, min(grid_w, -(-(right + radius) // cell)))
            gy0 = max(0, (top - radius) // cell)
            gy1 = max(0, min(grid_h, -(-(bottom + radius) // cell)))
            blocked[gy0:gy1, gx0:gx1] = True

        sx = int(self.agent_pos[0] // cell)
        sy = int(self.agent_pos[1] // cell)
        if not (0 <= sy < grid_h and 0 <= sx < grid_w) or blocked[sy, sx]:
            return False

        zx, zy = int(self.zone_center[0]), int(self.zone_center[1])
        zw, zh = int(self.zone_size[0]), int(self.zone_size[1])
        z_left = zx - zw // 2
        z_top = zy - zh // 2
        z_right = z_left + zw
        z_bottom = z_top + zh

        # Only cells whose sample point keeps the whole circle inside the zone
        # can be goals, so scan just that index range instead of the full grid.
        goal_x_lo = max(0, -(-(z_left + radius) // cell))
        goal_x_hi = min(grid_w - 1, (z_right - radius) // cell)
        goal_y_lo = max(0, -(-(z_top + radius) // cell))
        goal_y_hi = min(grid_h - 1, (z_bottom - radius) // cell)
        goal_cells = {
            (gy, gx)
            for gy in range(goal_y_lo, goal_y_hi + 1)
            for gx in range(goal_x_lo, goal_x_hi + 1)
            if not blocked[gy, gx]
        }
        if not goal_cells:
            return False

        visited = np.zeros_like(blocked)
        visited[sy, sx] = True
        frontier = deque([(sy, sx)])
        while frontier:
            y, x = frontier.popleft()
            if (y, x) in goal_cells:
                return True
            for dy, dx in ((-1, 0), (1, 0), (0, -1), (0, 1)):
                ny, nx = y + dy, x + dx
                if (0 <= ny < grid_h and 0 <= nx < grid_w
                        and not visited[ny, nx] and not blocked[ny, nx]):
                    visited[ny, nx] = True
                    frontier.append((ny, nx))
        return False

    # ---- Termination (outside step) ----

    def _check_done(self, obs):
        """Return 1 if the agent circle lies fully inside the goal zone, else 0.

        Only the first object of type ``'zone'`` in ``obs['objects']`` is
        considered; 0 is returned when there is none.
        """
        agent_x, agent_y = obs['agent_pos']
        radius = self.agent_radius
        for o in obs['objects']:
            if o.get('type') == 'zone':
                zx, zy = o['pos']
                zw, zh = o['size']
                left = zx - zw / 2
                right = zx + zw / 2
                top = zy - zh / 2
                bottom = zy + zh / 2
                # Containment: every extreme point of the circle must be within
                # the rectangle. Merely touching or overlapping does not count.
                inside = (agent_x - radius >= left and agent_x + radius <= right and
                          agent_y - radius >= top and agent_y + radius <= bottom)
                return 1 if inside else 0
        return 0

    def _get_obs(self):
        """Build the observation dict and update the success/done flags.

        Side effect: sets ``self._success`` from :meth:`_check_done` and sets
        ``self.done`` when the goal has been reached.
        """
        objects = [{
            'type': 'zone',
            'pos': [float(self.zone_center[0]), float(self.zone_center[1])],
            'size': [float(self.zone_size[0]), float(self.zone_size[1])],
            'purpose': 'goal'
        }]
        objects.extend(
            {
                'type': 'obstacle',
                'pos': [float(ox), float(oy)],
                'size': [float(w), float(h)]
            }
            for ox, oy, w, h in self.obstacles
        )
        obs = {
            'agent_pos': [float(self.agent_pos[0]), float(self.agent_pos[1])],
            'objects': objects,
            'step_count': int(self.step_count)
        }
        # Side-effect termination when the goal is reached (keeps step() to max-steps only).
        self._success = bool(self._check_done(obs))
        if self._success:
            self.done = True
        return obs

    def task_description(self):
        """Return a natural-language description of the task, built from the current settings."""
        return (
            f"Objective: Move the blue circular agent (radius {self.agent_radius} px) fully inside the green rectangular goal zone "
            f"(size {self.zone_size[0]}x{self.zone_size[1]} px). The map size is {self.W}x{self.H} px (constant). "
            f"There are {len(self.obstacles)} gray axis-aligned rectangular obstacles (mix of long bars and blocks). "
            "Obstacles never overlap each other nor the inflated goal-approach window; corridors have at least agent-radius clearance. "
            f"The agent starts at least {self.min_agent_goal_dist} px from the goal and outside all obstacles. "
            f"Action: 2D vector [dx, dy] each in [-1.0, 1.0]; movement is (dx*{self.agent_speed}, dy*{self.agent_speed}) px, clipped to bounds and blocked by obstacles. "
            f"- Agent: blue circle, radius {self.agent_radius} pixels.\n"
            "Observation: {'agent_pos':[x,y], 'objects':[{'type':'zone','pos':[x,y],'size':[w,h],'purpose':'goal'}, "
            "{'type':'obstacle','pos':[x,y],'size':[w,h]}, ...], 'step_count':N}. "
            "Episode ends when the agent is fully inside the goal zone or after max_steps."
        )
