import collections
import io
import json
import multiprocessing
import sys
import time
import traceback
from copy import deepcopy
from pathlib import Path
from typing import Any

from PIL import Image

import gymnasium
import numpy as np
import torch
from absl import flags
from absl import logging
from dreamplace import PlaceDB
from gymnasium import spaces
from gymnasium.core import ActType
from gymnasium.core import ObsType
from gymnasium.error import AlreadyPendingCallError
from gymnasium.spaces.utils import is_space_dtype_shape_equiv
from gymnasium.vector import AsyncVectorEnv
from gymnasium.vector import AutoresetMode
from gymnasium.vector import SyncVectorEnv
from gymnasium.vector.async_vector_env import AsyncState
from gymnasium.vector.utils import concatenate
from gymnasium.vector.utils import iterate
from gymnasium.vector.utils import write_to_shared_memory

from veoplace.environment.utils import fill_mask_regions
from veoplace.environment.utils import get_hard_macro_ordering_from_placedb
from veoplace.utils import generate_short_names
from veoplace.utils.constants import MAX_NUM_NODES
from veoplace.utils.constants import VEOPLACE_GRID_SIZE
from veoplace.utils.dreamplace import get_global_hpwl
from veoplace.utils.render import render
from veoplace.utils.render import render_all_suggestions
from veoplace.utils.render import render_suggestion


class PlacementEnv:

    def __init__(self,
            benchmark: str,
            benchmark_path: str | Path,
            dreamplace_params,
            grid: int = 84,
            output_dir: str | Path = "dreamplace_results",
            run_dreamplace: bool = True,
            max_num_nodes: int = MAX_NUM_NODES,
    ):
        """Load the benchmark into a PlaceDB and precompute per-node state.

        Args:
            benchmark: Benchmark name; used to locate ``color_config.json``
                and passed to the hard-macro ordering helper.
            benchmark_path: Location of the benchmark; stored as a ``Path``.
            dreamplace_params: DREAMPlace parameter object; used to populate
                the PlaceDB here and kept for later DREAMPlace runs.
            grid: Number of cells along each side of the square grid.
            output_dir: Default directory handed to DREAMPlace for renders.
            run_dreamplace: When False, ``step()`` skips DREAMPlace at the end
                of an episode and reports zeros for its metrics.
            max_num_nodes: Upper bound on how many hard macros the agent
                places; the macro ordering is truncated to this length.
        """
        self.benchmark = benchmark
        self.benchmark_path = Path(benchmark_path)

        # Load color config from configs/<benchmark>/color_config.json
        from veoplace.utils.benchmark_registry import config_dir
        cfg_path = config_dir(benchmark) / 'color_config.json'
        if cfg_path.exists():
            # Context manager so the file handle is closed deterministically.
            with open(cfg_path, 'r', encoding='utf-8') as cfg_file:
                self.color_config = json.load(cfg_file)
        else:
            logging.warning(
                    f"No color_config.json found for {benchmark}, using empty config")
            self.color_config = {}

        self.run_dreamplace = run_dreamplace
        self.output_dir = Path(output_dir)
        self.dreamplace_params = dreamplace_params

        # Create PlaceDB and get macro ordering
        self.placedb = PlaceDB.PlaceDB()
        self.placedb(self.dreamplace_params)

        hard_macro_names, _ = get_hard_macro_ordering_from_placedb(
                benchmark_name=benchmark,
                placedb=self.placedb
        )
        self.node_name_list = hard_macro_names

        # Generate shorter names for the nodes
        self.node_name_to_short_name = generate_short_names(
                self.node_name_list)

        # Need to add np.string_ for compatibility with DreamPlace
        if not hasattr(np, "string_"):
            np.string_ = np.bytes_

        # Limit the number of macros that will be placed by the agent
        if len(self.node_name_list) > max_num_nodes:
            logging.info('Placing %d macros (limited from %d): %s',
                         max_num_nodes, len(self.node_name_list),
                         self.node_name_list[:max_num_nodes])
            self.node_name_list = self.node_name_list[:max_num_nodes]
        else:
            logging.info('Placing %d macros: %s',
                         len(self.node_name_list), self.node_name_list)

        self.placed_num_macro = len(self.node_name_list)

        self.canvas_width = self.placedb.width
        self.canvas_height = self.placedb.height

        # Canvas units covered by one grid cell along each axis
        self.ratio_x = self.canvas_width / grid
        self.ratio_y = self.canvas_height / grid
        self.grid_size = grid

        # Re-usable vars for each episode
        self.canvas = np.zeros((self.grid_size, self.grid_size), dtype=np.bool_)
        self._mask = np.zeros((self.grid_size, self.grid_size), dtype=np.bool_)
        self.mask = None
        self.last_dreamplace_metrics = None  # Stores full metrics dict at episode end

        # Get all nets and nodes.
        self.nets = sorted(name.decode() for name in self.placedb.net_names)

        # Decode movable_names and split them by whether the name contains
        # "Grp" (groups of standard cells, which DreamPlace places itself)
        num_movable = self.placedb.num_movable_nodes
        grouped_movable_names = []
        ungrouped_movable_names = []
        for raw_name in self.placedb.node_names[:num_movable]:
            decoded = raw_name.decode()
            if "Grp" not in decoded:
                ungrouped_movable_names.append(decoded)
            else:
                grouped_movable_names.append(decoded)

        logging.info(
                'Found %d non-grouped movable nodes',
                len(ungrouped_movable_names),
        )
        logging.info(
                'Found %d grouped movable nodes',
                len(grouped_movable_names),
        )

        fixed_names = [
                name.decode()
                for name in self.placedb.node_names[
                    num_movable:num_movable + self.placedb.num_terminals
                ]
        ]

        logging.info(
                'Found %d fixed nodes',
                len(fixed_names)
        )

        self.nodes = sorted(
                ungrouped_movable_names + fixed_names + grouped_movable_names)
        # str  → int
        self.net_to_idx = dict(self.placedb.net_name2id_map.items())

        # Unescape bracket names to match get_hard_macro_ordering_from_placedb
        self.node_to_idx = {
                name.replace('\\[', '[').replace('\\]', ']'): idx
                for name, idx in
                self.placedb.node_name2id_map.items()
        }

        # Initialize arrays for node properties
        # Two sizes needed:
        # - n_nodes_total: for boolean masks indexed by pin2node_map (can reference filler nodes)
        # - n_physical_nodes: for position/size arrays (node_x, node_y are only physical nodes)
        n_nodes_total = self.placedb.num_nodes
        n_physical_nodes = len(
                self.placedb.node_x)  # physical nodes only (no fillers)
        n_nets = len(self.nets)

        # Initialize node_is_fixed FIRST (needed for filtering)
        self.node_is_fixed = np.zeros(n_nodes_total, dtype=bool)
        self.node_is_fixed[
            self.placedb.num_movable_nodes:self.placedb.num_physical_nodes] = True

        # Filter to relevant nodes (hard macros + fixed terminals only, exclude standard cells)
        relevant_node_mask = np.zeros(n_nodes_total, dtype=bool)
        hard_macro_indices = [self.node_to_idx[name] for name in
                              self.node_name_list]
        relevant_node_mask[hard_macro_indices] = True
        relevant_node_mask[self.node_is_fixed] = True
        relevant_node_indices = np.where(relevant_node_mask)[0]

        # Store for reverse mapping in comp_hpwl
        self.relevant_node_indices = relevant_node_indices

        # Create mapping: original_node_id -> reduced_matrix_column_id
        self.node_map = np.full(n_nodes_total, -1, dtype=np.int32)
        self.node_map[relevant_node_indices] = np.arange(
                len(relevant_node_indices))

        # Filter to nets touching at least one relevant node
        relevant_net_mask = np.zeros(n_nets, dtype=bool)
        for net_id, pin_ids in enumerate(self.placedb.net2pin_map):
            nodes = self.placedb.pin2node_map[pin_ids]
            if np.any(relevant_node_mask[nodes]):
                relevant_net_mask[net_id] = True
        relevant_net_indices = np.where(relevant_net_mask)[0]
        self.relevant_net_indices = relevant_net_indices  # Store for later use

        # Build reduced matrices
        n_rel_nets, n_rel_nodes = len(relevant_net_indices), len(
                relevant_node_indices)
        logging.info(
                f"Filtering: {n_nets} nets → {n_rel_nets}, {n_physical_nodes} physical nodes → {n_rel_nodes}")

        self.net_node_matrix = np.zeros((n_rel_nets, n_rel_nodes), dtype=bool)
        self.net_node_offsets = np.zeros((n_rel_nets, n_rel_nodes, 2),
                                         dtype=np.float32)

        for new_net_idx, original_net_idx in enumerate(relevant_net_indices):
            pin_ids = self.placedb.net2pin_map[original_net_idx]
            nodes = self.placedb.pin2node_map[pin_ids]

            # Filter to only relevant nodes on this net
            rel_mask = relevant_node_mask[nodes]
            rel_nodes = nodes[rel_mask]
            rel_pin_ids = pin_ids[rel_mask]

            # Map to reduced column indices
            col_indices = self.node_map[rel_nodes]

            self.net_node_matrix[new_net_idx, col_indices] = True
            self.net_node_offsets[new_net_idx, col_indices, 0] = \
                self.placedb.pin_offset_x[rel_pin_ids]
            self.net_node_offsets[new_net_idx, col_indices, 1] = \
                self.placedb.pin_offset_y[rel_pin_ids]

        # Filter net weights to match reduced net indices
        self.net_weights = self.placedb.net_weights[
            relevant_net_indices].astype(np.float32)

        # ── orientation-aware width / height ──────────────────────────────
        rotate90 = {"E", "W", "FE", "FW"}  # 90° / 270° rotations

        size_x = self.placedb.node_size_x  # ndarray, len = all nodes
        size_y = self.placedb.node_size_y
        n = len(size_x)

        orient = self.placedb.node_orient  # may be shorter
        orient = np.char.decode(orient)

        rot90 = np.zeros(n, dtype=bool)  # default: no rotation
        rot90[: len(orient)] = np.isin(orient, list(rotate90))

        widths = np.where(rot90, size_y, size_x)  # swap w/h where rotated
        heights = np.where(rot90, size_x, size_y)

        self.node_dimensions = np.stack([widths, heights],
                                        axis=1).astype(np.float32)

        # Also update placedb with the new dimensions
        self.placedb.node_size_x = widths
        self.placedb.node_size_y = heights

        # Create 2D arrays for node properties (physical nodes only - matches node_x/node_y size)
        self.node_raw_positions = np.zeros((n_physical_nodes, 2),
                                           dtype=np.float32)  # [raw_x, raw_y]

        self.node_raw_positions[:, 0] = self.placedb.node_x
        self.node_raw_positions[:, 1] = self.placedb.node_y

        # Divide by ratio to get normalized dimensions, ceiling for grid coordinates
        self.node_normalized_dimensions = np.zeros_like(self.node_dimensions)
        self.node_normalized_dimensions[:, 0] = (
                self.node_dimensions[:, 0] / self.ratio_x)  # widths
        self.node_normalized_dimensions[:, 1] = (
                self.node_dimensions[:, 1] / self.ratio_y)  # heights
        # Every node occupies at least one grid cell
        self.node_normalized_dimensions_int = np.maximum(
                1, np.ceil(self.node_normalized_dimensions)).astype(int)

        # Additional useful arrays (physical nodes only)
        self.node_is_placed = np.zeros(n_physical_nodes, dtype=bool)
        self.node_placed = np.zeros((n_physical_nodes, 2), dtype=np.float32)

        # Finish up color grouping data
        self.color_to_macros = collections.defaultdict(list)
        self.first_macro_of_color_group = {}
        self.color_group_count = collections.defaultdict(int)
        for node_name in self.node_name_list:
            if node_name in self.color_config:
                color = self.color_config[node_name]
                if color not in self.color_group_count:
                    # This is the first macro of this color group
                    self.first_macro_of_color_group[node_name] = color
                self.color_group_count[color] += 1
                self.color_to_macros[color].append(node_name)
        self.color_group_to_first_macro = {
                v: k for k, v in
                self.first_macro_of_color_group.items()}
        self._reset()

    def get_prompt_context(self):
        return dict(
                grid_size=self.grid_size,
                canvas_width=self.canvas_width,
                canvas_height=self.canvas_height,
                ratio_x=self.ratio_x,
                ratio_y=self.ratio_y,
                color_config=self.color_config,

                node_name_list=self.node_name_list,
                placed_num_macro=self.placed_num_macro,
                num_macro_placed=self.num_macro_placed,
                node_to_idx=self.node_to_idx,
                node_name_to_short_name=self.node_name_to_short_name,

                node_dims_int=self.node_normalized_dimensions_int,
                node_dims_real=self.node_dimensions,

                color_group_count=self.color_group_count,
                color_group_to_first_macro=self.color_group_to_first_macro,
                first_macro_of_color_group=self.first_macro_of_color_group,

                node_pos=self.node_pos,  # include only if some prompts use it

                # DREAMPlace static arrays (same across episodes for a benchmark)
                dreamplace_node_size_x=self.placedb.node_size_x,
                dreamplace_node_size_y=self.placedb.node_size_y,
                dreamplace_node_names=self.placedb.node_names,
                dreamplace_num_movable=self.placedb.num_movable_nodes,
                dreamplace_num_terminals=self.placedb.num_terminals,
        )

    def _reset(self):
        self.node_pos = {}
        self.mask = self._mask
        self.canvas.fill(0)
        self.num_macro_placed = 0
        self.node_placed.fill(0)
        self.node_is_placed.fill(0)

    def get_dreamplace_metrics(self, anchor_indices=None,
            output_dir=None) -> dict:
        """
        Run DREAMPlace on the current macro positions and collect its metrics.

        The (x, y) entries of ``self.node_pos`` are passed to
        ``get_global_hpwl``. ``params.freeze_macro_flag`` and
        ``params.anchor_flag`` (both treated as False when absent) select
        whether freeze indices and/or anchor indices are forwarded.

        Side effects: ``self.placedb`` is replaced by the PlaceDB returned by
        DREAMPlace and, when that is not None, ``self.node_pos`` is
        overwritten with the final positions of the agent-placed macros.

        Args:
            anchor_indices: List of PlaceDB node indices to anchor. Only
                forwarded when ``params.anchor_flag`` is set.
            output_dir: Directory for DREAMPlace renders. If None (or empty),
                uses self.output_dir.

        Returns:
            dict with every entry of the DREAMPlace result except 'placedb'
            and 'processed_metrics' (``step()`` reads 'hpwl', 'congestion',
            'macro_overlap' and 'density'). When a PlaceDB was returned the
            dict also carries the 'dreamplace_*' arrays, 'positions_array',
            'node_pos', 'ratio_x', 'ratio_y', 'max_width' and 'max_height'.
        """
        params = self.dreamplace_params

        # Use provided output_dir or fall back to self.output_dir
        render_dir = output_dir if output_dir else str(self.output_dir)

        # Only the (x, y) part of each known node's entry goes to DREAMPlace
        node_pos = {
                name: (float(entry[0]), float(entry[1]))
                for name, entry in self.node_pos.items()
                if name in self.node_to_idx
        }

        # Check for freeze mode vs anchor mode
        freeze_macro_flag = getattr(params, 'freeze_macro_flag', False)
        anchor_flag = getattr(params, 'anchor_flag', False)

        # Build freeze indices if freeze mode is enabled
        # NOTE: Only pass indices, NOT names. Passing both triggers dual PlaceDB mode
        # which requires .pl files (bookshelf only). Indices-only uses gradient zeroing.
        # NOTE(review): the list covers node_name_list[:num_macro_placed];
        # evaluate_with_suggestions() resets that counter to 0 before calling
        # here, so on that path the list is empty - confirm this is intended.
        freeze_indices = None
        if freeze_macro_flag:
            freeze_indices = [
                    self.node_to_idx[name]
                    for name in self.node_name_list[:self.num_macro_placed]
                    if name in self.node_to_idx
            ]
            logging.info("Freeze mode: zeroing gradients for %d macros",
                         len(freeze_indices))

        dp_start = time.perf_counter()
        dp_result = get_global_hpwl(
                params=params,
                node_pos=node_pos,
                placedb=self.placedb,
                output_dir=render_dir,
                render_final=params.plot_flag,
                # Freeze mode - gradient zeroing only (no dual PlaceDB)
                freeze_macro_indices=freeze_indices,
                freeze_macro_names=None,
                # Anchor mode (soft constraint)
                anchor_node_indices=anchor_indices if anchor_flag else None,
                anchor_weight=params.anchor_weight if anchor_flag else None,
        )
        dp_elapsed = time.perf_counter() - dp_start
        logging.info("DREAMPlace completed in %.2fs", dp_elapsed)

        # Update placedb reference
        self.placedb = dp_result['placedb']

        # Build result metrics (exclude placedb and processed_metrics)
        result_metrics = {k: v for k, v in dp_result.items()
                          if k not in ('placedb', 'processed_metrics')}

        # Add full node positions from DREAMPlace for rendering
        if self.placedb is not None:
            result_metrics['dreamplace_node_x'] = self.placedb.node_x.copy()
            result_metrics['dreamplace_node_y'] = self.placedb.node_y.copy()
            result_metrics[
                'dreamplace_node_size_x'] = self.placedb.node_size_x.copy()
            result_metrics[
                'dreamplace_node_size_y'] = self.placedb.node_size_y.copy()
            result_metrics[
                'dreamplace_node_names'] = self.placedb.node_names.copy()
            result_metrics[
                'dreamplace_num_movable'] = self.placedb.num_movable_nodes
            result_metrics[
                'dreamplace_num_terminals'] = self.placedb.num_terminals

            placed_names = self.node_name_list[:self.placed_num_macro]

            # Update node_pos with final DREAMPlace positions (always, for displacement calc)
            for name in placed_names:
                if name in self.node_to_idx:
                    idx = self.node_to_idx[name]
                    # Get DREAMPlace's final position
                    new_x = float(self.placedb.node_x[idx])
                    new_y = float(self.placedb.node_y[idx])
                    # Get size from placedb
                    real_w = float(self.placedb.node_size_x[idx])
                    real_h = float(self.placedb.node_size_y[idx])
                    # Use the same ceil'd, at-least-one-cell footprint that
                    # step() and evaluate_with_suggestions() store, rather
                    # than truncating real_w / ratio (which can reach 0).
                    grid_w, grid_h = (
                            int(cells) for cells in
                            self.node_normalized_dimensions_int[idx])
                    self.node_pos[name] = (new_x, new_y, grid_w, grid_h,
                                           real_w, real_h)

            # Zero-initialised so a name missing from node_pos yields a
            # well-defined row instead of uninitialised memory.
            pos = np.zeros((self.placed_num_macro, 4), np.float32)
            for i, n in enumerate(placed_names):
                if n in self.node_pos:
                    pos[i] = self.node_pos[n][:-2]  # x, y, grid_w, grid_h
            result_metrics['positions_array'] = pos
            result_metrics['node_pos'] = self.node_pos.copy()
            result_metrics['ratio_x'] = self.ratio_x
            result_metrics['ratio_y'] = self.ratio_y
            result_metrics['max_width'] = self.canvas_width
            result_metrics['max_height'] = self.canvas_height

        return result_metrics

    def evaluate_with_suggestions(self, suggestions: dict = None,
            output_dir: str = None) -> dict:
        """
        Apply VLM suggestions and run DREAMPlace evaluation in one step.

        Args:
            suggestions: A dictionary mapping macro names to coordinate lists (corners).
                         e.g., {'Macro1': [[0,0], [10,10], ...]}
                         If None or empty, runs Baseline (all macros movable).
            output_dir: Directory for DREAMPlace renders. If None, uses self.output_dir.

        Returns:
            dict with DREAMPlace metrics (hpwl, congestion, macro_overlap, etc.)
        """
        # 1. Reset environment to clean slate
        self._reset()

        anchor_indices = []
        suggested_positions = {}  # Store suggested positions for displacement calc

        if suggestions:
            for macro_name, coords in suggestions.items():
                if coords is not None and len(coords) > 0:
                    # A. Extract Bottom-Left Position from corners
                    xs = [c[0] for c in coords]
                    ys = [c[1] for c in coords]
                    min_x = min(xs)
                    min_y = min(ys)

                    # B. Convert grid coordinates to canvas/micron coordinates
                    canvas_x = min_x * self.ratio_x
                    canvas_y = min_y * self.ratio_y

                    # C. Update node_pos with the suggested position
                    if macro_name in self.node_to_idx:
                        node_idx = self.node_to_idx[macro_name]
                        size_x, size_y = self.node_normalized_dimensions_int[
                            node_idx]
                        real_w = float(self.node_dimensions[node_idx, 0])
                        real_h = float(self.node_dimensions[node_idx, 1])

                        self.node_pos[macro_name] = (
                                canvas_x, canvas_y, size_x, size_y, real_w,
                                real_h
                        )

                        anchor_indices.append(node_idx)
                        # Store suggested position for displacement calculation
                        suggested_positions[macro_name] = (canvas_x, canvas_y)

        # 2. Run DREAMPlace with anchor mode
        metrics = self.get_dreamplace_metrics(
                anchor_indices=anchor_indices,
                output_dir=output_dir
        )

        # 3. Compute displacement between suggestions and final positions
        if suggested_positions:
            displacements = []
            for macro_name, (sug_x, sug_y) in suggested_positions.items():
                if macro_name in self.node_pos:
                    final_x, final_y = self.node_pos[macro_name][:2]
                    dist = np.sqrt(
                            (final_x - sug_x) ** 2 + (final_y - sug_y) ** 2)
                    displacements.append(dist)

            if displacements:
                displacements = np.array(displacements)
                canvas_diagonal = np.sqrt(
                        self.canvas_width ** 2 + self.canvas_height ** 2)

                metrics['displacement_mean'] = float(np.mean(displacements))
                metrics['displacement_max'] = float(np.max(displacements))
                metrics['displacement_std'] = float(np.std(displacements))
                # Normalized by canvas diagonal for cross-benchmark comparison
                metrics['displacement_mean_normalized'] = float(
                        np.mean(displacements) / canvas_diagonal)
                metrics['displacement_max_normalized'] = float(
                        np.max(displacements) / canvas_diagonal)

        return metrics

    def reset(self):
        self._reset()

        next_macro = self.node_name_list[self.num_macro_placed]
        next_macro_idx = self.node_to_idx[next_macro]
        next_x, next_y = self.node_normalized_dimensions[next_macro_idx]
        next_x_int, next_y_int = self.node_normalized_dimensions_int[
            next_macro_idx]
        self.mask = self.get_mask(next_x_int, next_y_int)

        states = np.stack((self.canvas, self.mask), axis=0)
        done = False
        reward = 0

        return states, reward, done, np.array([next_x, next_y])

    def step(self, action):
        try:
            row, col = divmod(action, self.grid_size)
            reward = 0.0

            next_macro = self.node_name_list[self.num_macro_placed]
            node_idx = self.node_to_idx[next_macro]
            size_x, size_y = self.node_normalized_dimensions_int[node_idx]

            # Update the canvas
            self.canvas[row:row + size_y, col:col + size_x] = 1.0
            self.num_macro_placed += 1

            # Store the node's position and size
            self.node_pos[next_macro] = (
                    col * self.ratio_x, row * self.ratio_y, size_x, size_y,
                    float(self.node_dimensions[node_idx, 0]),
                    float(self.node_dimensions[node_idx, 1]))
            self.node_placed[node_idx] = [row, col]
            self.node_is_placed[node_idx] = True

            # Update the raw position of the node
            raw_x = col * self.ratio_x  # grid  ->  micron
            raw_y = row * self.ratio_y
            self.node_raw_positions[node_idx, 0] = raw_x
            self.node_raw_positions[node_idx, 1] = raw_y

            hpwl = None
            rudy_congestion_score = None
            overlap = None
            density = None

            # Compute the next macro information
            if self.num_macro_placed < self.placed_num_macro:
                done = False
                next_macro = self.node_name_list[self.num_macro_placed]
                next_macro_idx = self.node_to_idx[next_macro]
                next_x, next_y = self.node_normalized_dimensions[next_macro_idx]
                next_x_int, next_y_int = self.node_normalized_dimensions_int[
                    next_macro_idx]

            else:
                done = True
                next_x, next_y = 0, 0
                next_x_int, next_y_int = 0, 0

                if self.run_dreamplace:
                    t0 = time.perf_counter()
                    # DreamPlace resets torch random state, so we need to save it
                    # and restore it
                    with torch.random.fork_rng():
                        metrics = self.get_dreamplace_metrics()

                    hpwl = metrics['hpwl']
                    rudy_congestion_score = metrics['congestion']
                    overlap = metrics['macro_overlap']
                    density = metrics['density']

                    # Store full metrics for wrapper to access (includes dreamplace arrays)
                    self.last_dreamplace_metrics = metrics

                    logging.info(
                            'HPWL: %.2f, RUDY Congestion: %.4f, Overlap EST: %.2f, CT Density: %.4f',
                            hpwl,
                            rudy_congestion_score,
                            overlap,
                            density
                    )
                    logging.info(
                            'Time taken for DREAMPlace + Metrics: %.2f seconds',
                            time.perf_counter() - t0
                    )
                else:
                    hpwl = 0
                    rudy_congestion_score = 0  # Set a default value
                    overlap = 0
                    density = 0
                reward -= hpwl

            self.mask = self.get_mask(next_x_int, next_y_int)
            states = np.stack((self.canvas, self.mask),
                              axis=0)

            return states, reward, done, np.array(
                    [next_x, next_y]
            ), hpwl, rudy_congestion_score, overlap, density
        except Exception:  # noqa
            try:
                macro = self.node_name_list[self.num_macro_placed]
            except Exception:  # noqa
                macro = "<unknown>"

            logging.exception(
                    "step() failed at macro %s (placed %d / %d, action=%s)",
                    macro,
                    self.num_macro_placed,
                    self.placed_num_macro,
                    action,
            )
            raise

    def get_mask(self, next_x, next_y):
        self._mask.fill(0)
        mask = self._mask
        placed_node_indices = np.where(self.node_is_placed)[0]

        if len(placed_node_indices) > 0:
            # stored as [row, col]
            row_positions = self.node_placed[placed_node_indices, 0]
            col_positions = self.node_placed[placed_node_indices, 1]

            # sizes in cells (W,H)
            node_sizes_x = self.node_normalized_dimensions_int[
                placed_node_indices, 0]  # width  (cols)
            node_sizes_y = self.node_normalized_dimensions_int[
                placed_node_indices, 1]  # height (rows)

            # use HEIGHT for rows, WIDTH for cols (these 4 lines are the key fix)
            start_row = np.maximum(0, row_positions - int(next_y) + 1).astype(
                    np.int32)
            start_col = np.maximum(0, col_positions - int(next_x) + 1).astype(
                    np.int32)
            end_row = np.minimum(row_positions + node_sizes_y - 1,
                                 self.grid_size - 1).astype(np.int32)
            end_col = np.minimum(col_positions + node_sizes_x - 1,
                                 self.grid_size - 1).astype(np.int32)

            mask = fill_mask_regions(mask, start_row, start_col, end_row,
                                     end_col)

        # out-of-bounds bands: rows use HEIGHT, cols use WIDTH (swap fixed)
        mask[self.grid_size - int(next_y) + 1:, :] = 1  # rows (height)
        mask[:, self.grid_size - int(next_x) + 1:] = 1  # cols (width)
        return mask

    def render_gemini_suggestion(self, suggestion, hex_color="#FF0000",
            return_bytes=False):
        """Render a single suggestion via ``render_suggestion``.

        Args:
            suggestion: The suggestion to draw; forwarded unchanged.
            hex_color: Color for the suggestion; forwarded unchanged.
            return_bytes: Forwarded unchanged to the renderer.

        Returns:
            Whatever ``render_suggestion`` returns.
        """
        # Environment state the renderer needs besides the suggestion itself
        env_kwargs = {
                'node_pos': self.node_pos,
                'ratio_x': self.ratio_x,
                'ratio_y': self.ratio_y,
                'max_width': self.canvas_width,
                'max_height': self.canvas_height,
                'color_config': self.color_config,
                'grid': self.grid_size,
        }
        return render_suggestion(suggestion=suggestion,
                                 hex_color=hex_color,
                                 return_bytes=return_bytes,
                                 **env_kwargs)

    def render_all_gemini_suggestions(self, suggestions, return_bytes=False,
            highlight_nodes=None):
        """Render all suggestions via ``render_all_suggestions``.

        Args:
            suggestions: Passed to the renderer as ``all_suggestions``.
            return_bytes: Forwarded unchanged to the renderer.
            highlight_nodes: Forwarded unchanged to the renderer.

        Returns:
            Whatever ``render_all_suggestions`` returns.
        """
        # Environment state the renderer needs besides the suggestions
        env_kwargs = {
                'node_pos': self.node_pos,
                'ratio_x': self.ratio_x,
                'ratio_y': self.ratio_y,
                'node_name_to_short_name': self.node_name_to_short_name,
                'max_width': self.canvas_width,
                'max_height': self.canvas_height,
                'color_config': self.color_config,
                'grid': self.grid_size,
        }
        return render_all_suggestions(all_suggestions=suggestions,
                                      highlight_nodes=highlight_nodes,
                                      return_bytes=return_bytes,
                                      **env_kwargs)

    # TODO: add a variant of render() that performs the rendering in a thread
    def render(self, return_bytes=False, highlight_nodes=None):
        """Render the current ``node_pos`` with the module-level ``render``.

        Args:
            return_bytes: Forwarded unchanged to the renderer.
            highlight_nodes: Forwarded unchanged to the renderer.

        Returns:
            Whatever the module-level ``render`` returns.
        """
        # ratio_x / ratio_y are not passed to the renderer here.
        options = {
                'max_width': self.canvas_width,
                'max_height': self.canvas_height,
                'color_config': self.color_config,
                'highlight_nodes': highlight_nodes,
                'grid': self.grid_size,
                'return_bytes': return_bytes,
                'node_name_to_short_name': self.node_name_to_short_name,
        }
        return render(self.node_pos, **options)

    def render_full_canvas(self, return_bytes=False, highlight_nodes=None):
        """Render ``node_pos`` together with the PlaceDB placement arrays.

        When ``self.placedb`` is falsy, every ``dreamplace_*`` argument is
        passed as None.

        Args:
            return_bytes: Forwarded unchanged to the renderer.
            highlight_nodes: Forwarded unchanged to the renderer.

        Returns:
            Whatever the module-level ``render`` returns.
        """
        # Use placedb data for accurate rendering after DREAMPlace
        db = self.placedb
        if db:
            placement = {
                    'dreamplace_node_x': db.node_x,
                    'dreamplace_node_y': db.node_y,
                    'dreamplace_node_size_x': db.node_size_x,
                    'dreamplace_node_size_y': db.node_size_y,
                    'dreamplace_num_movable': db.num_movable_nodes,
                    'dreamplace_num_terminals': db.num_terminals,
            }
        else:
            placement = dict.fromkeys((
                    'dreamplace_node_x',
                    'dreamplace_node_y',
                    'dreamplace_node_size_x',
                    'dreamplace_node_size_y',
                    'dreamplace_num_movable',
                    'dreamplace_num_terminals',
            ))

        return render(self.node_pos,
                      max_width=self.canvas_width,
                      max_height=self.canvas_height,
                      color_config=self.color_config,
                      grid=self.grid_size,
                      return_bytes=return_bytes,
                      highlight_nodes=highlight_nodes,
                      node_name_to_short_name=self.node_name_to_short_name,
                      **placement)

    def render_comparison(self, suggestions=None, return_bytes=False):
        """Compose a side-by-side image: full canvas (left), suggestions (right).

        Args:
            suggestions: Suggestions to draw on the right-hand panel. When
                falsy there is nothing to compare and ``None`` is returned.
            return_bytes: When true, return the combined image encoded as PNG
                bytes instead of an image object.

        Returns:
            ``None`` without suggestions; otherwise the combined RGB image, or
            its PNG bytes when ``return_bytes`` is set.
        """
        if not suggestions:
            return None

        def _drop_alpha(img):
            # Only RGBA inputs are converted; other modes are pasted as-is.
            return img.convert('RGB') if img.mode == 'RGBA' else img

        # Render order is kept: full canvas first, then the suggestion overlay.
        final_panel = self.render_full_canvas(return_bytes=False)
        suggestion_panel = self.render_all_gemini_suggestions(
                suggestions, return_bytes=False)

        suggestion_panel = _drop_alpha(suggestion_panel)
        final_panel = _drop_alpha(final_panel)

        # The output canvas takes its height from the left panel.
        total_size = (final_panel.width + suggestion_panel.width,
                      final_panel.height)
        combined = Image.new('RGB', total_size)
        combined.paste(final_panel, (0, 0))
        combined.paste(suggestion_panel, (final_panel.width, 0))

        if not return_bytes:
            return combined

        png_buffer = io.BytesIO()
        combined.save(png_buffer, format="PNG", dpi=(300, 300))
        return png_buffer.getvalue()

    def comp_hpwl(self):
        """
        Vectorised, weighted HPWL over the relevant nets (Steiner cost skipped).

        Only pins whose node is placed or fixed take part. A net without any
        such pin contributes zero.

        Returns
        -------
        hpwl : float
            Sum over nets of ``weight * (x-span + y-span)`` of the pin
            bounding box; ``0.0`` when no pin is placed or fixed.
        """
        num_nets = len(self.relevant_net_indices)

        # (net, node) incidences, expressed in REDUCED matrix coordinates.
        nets, nodes = np.nonzero(self.net_node_matrix)

        # Reduced node index -> original node id used by the per-node arrays.
        orig_nodes = self.relevant_node_indices[nodes]

        fixed = self.node_is_fixed[orig_nodes]
        usable = self.node_is_placed[orig_nodes] | fixed
        if not usable.any():  # nothing placed or fixed yet
            return 0.0

        nets = nets[usable]
        nodes = nodes[usable]
        orig_nodes = orig_nodes[usable]
        fixed = fixed[usable]

        # Fixed nodes use their raw position; movable nodes use the grid cell
        # scaled by the ratio (node_placed column 1 -> x, column 0 -> y).
        origin_x = np.where(fixed,
                            self.node_raw_positions[orig_nodes, 0],
                            self.node_placed[orig_nodes, 1] * self.ratio_x)
        origin_y = np.where(fixed,
                            self.node_raw_positions[orig_nodes, 1],
                            self.node_placed[orig_nodes, 0] * self.ratio_y)

        # Pin = node origin + half the node size + per-pin offset.
        pin_x = (origin_x
                 + 0.5 * self.node_dimensions[orig_nodes, 0]
                 + self.net_node_offsets[nets, nodes, 0])
        pin_y = (origin_y
                 + 0.5 * self.node_dimensions[orig_nodes, 1]
                 + self.net_node_offsets[nets, nodes, 1])

        # Per-net bounding boxes, accumulated with unbuffered ufunc.at.
        hi_x = np.full(num_nets, -np.inf)
        hi_y = np.full(num_nets, -np.inf)
        lo_x = np.full(num_nets, np.inf)
        lo_y = np.full(num_nets, np.inf)

        np.maximum.at(hi_x, nets, pin_x)
        np.maximum.at(hi_y, nets, pin_y)
        np.minimum.at(lo_x, nets, pin_x)
        np.minimum.at(lo_y, nets, pin_y)

        wirelength = (hi_x - lo_x) + (hi_y - lo_y)

        # Nets that still hold an infinite bound are zeroed out.
        has_box = (np.isfinite(lo_x) & np.isfinite(hi_x)
                   & np.isfinite(lo_y) & np.isfinite(hi_y))
        wirelength[~has_box] = 0.0

        wirelength *= self.net_weights

        return float(wirelength.sum())

    def close(self):
        """No-op: this method currently releases nothing."""
        pass

    def train(self):
        """Mark the environment as being in training mode (``training = True``)."""
        self.training = True

    def eval(self):
        """Mark the environment as being in evaluation mode (``training = False``)."""
        self.training = False


class PlacementGymEnv(gymnasium.Env):
    """
    Gymnasium wrapper around ``PlacementEnv``, allowing vectorized usage.

    The wrapper exposes a ``Discrete(grid * grid)`` action space and declares
    a ``(2, grid, grid)`` float32 ``Box`` observation space. Attribute lookups
    that fail on the wrapper itself are forwarded to ``base_env``.
    """
    metadata = {"render_modes": ["human"], "render_fps": 4}

    def __init__(
            self,
            benchmark: str,
            benchmark_path,
            dreamplace_params,
            grid: int = 84,
            output_dir=None,
            base_env=None,
            run_dreamplace: bool = True,
            max_num_nodes: int = MAX_NUM_NODES,
    ):
        """
        Args:
            benchmark: Benchmark name, forwarded to ``PlacementEnv``.
            benchmark_path: Benchmark location, forwarded to ``PlacementEnv``.
            dreamplace_params: DREAMPlace parameters, forwarded to
                ``PlacementEnv``.
            grid: Side length of the placement grid.
            output_dir: Forwarded to ``PlacementEnv``.
            base_env: Optional ready-made environment. When given it is used
                as-is and no ``PlacementEnv`` is constructed.
            run_dreamplace: Forwarded to ``PlacementEnv``.
            max_num_nodes: Forwarded to ``PlacementEnv``.
        """
        super().__init__()

        # Store parameters
        self.benchmark = benchmark
        self.grid_size = grid

        # Use the supplied environment, or build a fresh one.
        if base_env is not None:
            self.base_env = base_env
        else:
            self.base_env = PlacementEnv(
                    benchmark=self.benchmark,
                    benchmark_path=benchmark_path,
                    dreamplace_params=dreamplace_params,
                    grid=self.grid_size,
                    output_dir=output_dir,
                    run_dreamplace=run_dreamplace,
                    max_num_nodes=max_num_nodes,
            )
        self.base_env.eval()

        # An action is an integer in [0, grid * grid), i.e. one grid cell.
        self.action_space = spaces.Discrete(self.grid_size * self.grid_size)

        # Observations are declared as a 2 x grid x grid float32 array.
        obs_shape = (2, self.grid_size, self.grid_size)
        self.observation_space = spaces.Box(
                low=0.0,
                high=1.0,
                shape=obs_shape,
                dtype=np.float32
        )

    def __getattr__(self, name):
        """
        Forward attribute lookups that failed on the wrapper to ``base_env``.

        Raises:
            AttributeError: If neither the wrapper nor ``base_env`` has
                ``name``, or if ``base_env`` has not been set yet.
        """
        # ``__getattr__`` only runs after normal lookup has failed. Reading
        # ``self.base_env`` here while it is missing from the instance dict
        # (for example while pickle/copy probes a not-yet-populated instance)
        # would re-enter this method until RecursionError, so fetch it from
        # the instance dict directly.
        base_env = self.__dict__.get("base_env")
        if base_env is not None and hasattr(base_env, name):
            return getattr(base_env, name)
        # Mimic the default ``getattr`` failure.
        raise AttributeError(
                f"'{type(self).__name__}' object has no attribute '{name}'")

    def reset(self, *, seed=None, options=None):
        """
        Reset the base environment and return ``(observation, info)``.

        Args:
            seed: When not ``None``, re-seeds ``self.np_random``.
            options: Accepted for API compatibility; not used.

        Returns:
            The observation from ``base_env.reset()`` and an info dict with
            the single key ``'next_xy'``.
        """
        if seed is not None:
            self.np_random, _ = gymnasium.utils.seeding.np_random(seed)

        # base_env.reset() also returns a reward and a done flag, which the
        # Gymnasium reset contract has no slot for.
        states, _reward, _done, next_xy = self.base_env.reset()
        info = {'next_xy': next_xy}

        return states, info

    def step(self, action):
        """
        Apply one action and return ``(obs, reward, done, truncated, info)``.

        ``action == -1`` is a no-op: the current canvas and mask are returned
        with zero reward and the base environment is not stepped. Any other
        action is passed to ``base_env.step``.

        On the terminal step ``info`` additionally carries the placement
        metrics, the positions array and the node-position dict.
        """
        # We use -1 as a no-op action (no placement).
        if action == -1:
            states = np.stack(
                    (self.base_env.canvas, self.base_env.mask),
                    axis=0
            ).astype(np.float32)

            return states, 0.0, False, False, {
                    'next_xy': np.array([0.0, 0.0], dtype=np.float32)}

        (states, reward, done, next_xy,
         hpwl, congestion, overlap, density) = self.base_env.step(action)
        info = {'next_xy': next_xy}

        # Tensors are converted to numpy. detach()/cpu() are no-ops for plain
        # CPU tensors and keep the conversion from raising for tensors that
        # track gradients or live on an accelerator.
        if isinstance(states, torch.Tensor):
            states = states.detach().cpu().numpy()
        states = states.astype(np.float32)

        # No time-limit logic here, so an episode is never truncated.
        truncated = False

        # These keys must be present on every step.
        info['ratio_x'] = self.base_env.ratio_x
        info['ratio_y'] = self.base_env.ratio_y
        info['max_width'] = self.base_env.canvas_width
        info['max_height'] = self.base_env.canvas_height

        if done:
            info['hard_macro_hpwl'] = self.base_env.comp_hpwl()
            info['hpwl'] = hpwl
            info['congestion'] = congestion
            info['overlap'] = overlap
            info['density'] = density
            # Positions array needed by prepare_output_dict: one row per
            # placed macro, taken from node_pos without its last two entries.
            num_placed = self.base_env.num_macro_placed
            pos = np.empty((num_placed, 4), np.int32)
            for i, n in enumerate(self.base_env.node_name_list[:num_placed]):
                pos[i] = self.base_env.node_pos[n][:-2]
            info['positions_array'] = pos
            # Full node-pos dict (optional but handy for rendering)
            info["node_pos"] = self.base_env.node_pos.copy()

            # Add DREAMPlace full node positions for enhanced rendering
            if self.base_env.last_dreamplace_metrics is not None:
                metrics = self.base_env.last_dreamplace_metrics
                info['dreamplace_node_x'] = metrics['dreamplace_node_x']
                info['dreamplace_node_y'] = metrics['dreamplace_node_y']

        return states, float(reward), done, truncated, info

    def render(self, mode="human", return_bytes=False, highlight_nodes=None):
        """
        Delegate rendering to ``base_env.render``.

        Args:
            mode: Accepted for API compatibility; not used.
            return_bytes: Forwarded to ``base_env.render``.
            highlight_nodes: Forwarded to ``base_env.render``.
        """
        return self.base_env.render(
                return_bytes=return_bytes,
                highlight_nodes=highlight_nodes
        )

    def close(self):
        """
        Close the environment by closing ``base_env``.
        """
        self.base_env.close()


class SubsetStepVectorEnv(SyncVectorEnv):
    """Extension of SyncVectorEnv that allows stepping a subset of environments."""

    def _step_env(self, i, action):
        """Step sub-environment ``i``, store its results in the per-env buffers
        and return its info dict."""
        (
                self._env_obs[i],
                self._rewards[i],
                self._terminations[i],
                self._truncations[i],
                env_info,
        ) = self.envs[i].step(action)
        return env_info

    @staticmethod
    def _select_infos(infos, indices):
        """Restrict vectorised info entries to ``indices``.

        ``_add_info`` stores dict-valued info entries as nested dicts, so the
        selection recurses into them instead of indexing the dict itself.
        """
        return {
                key: (SubsetStepVectorEnv._select_infos(value, indices)
                      if isinstance(value, dict) else value[indices])
                for key, value in infos.items()
        }

    def step_subset(
            self, actions: ActType, mask: np.ndarray
    ) -> tuple[list, np.ndarray, np.ndarray, np.ndarray, dict[str, Any]]:
        """Steps through a subset of environments based on the provided mask.

        Args:
            actions: The actions to take in each environment.
            mask: Array of shape (num_envs,) indicating which environments to
                step. Non-boolean masks are interpreted by truthiness.

        Returns:
            The results only for environments that were stepped:
            (observations, rewards, terminations, truncations, infos)
        """
        assert mask.shape == (
                self.num_envs,), f"Mask shape {mask.shape} doesn't match number of environments ({self.num_envs},)"

        # Normalise to bool so the loop and the indexing below agree on which
        # environments are active (an integer mask would otherwise be used as
        # fancy indices).
        mask = np.asarray(mask, dtype=np.bool_)

        # Convert actions to an iterable if needed
        actions = iterate(self.action_space, actions)

        infos = {}

        # Track which environments were actually stepped
        active_indices = np.flatnonzero(mask)

        # Only step the environments where mask is True
        for i, (action, should_step) in enumerate(zip(actions, mask)):
            if not should_step:
                continue

            if self.autoreset_mode == AutoresetMode.NEXT_STEP:
                if self._autoreset_envs[i]:
                    # The env finished on its previous step: reset it now
                    # instead of stepping.
                    self._env_obs[i], env_info = self.envs[i].reset()

                    self._rewards[i] = 0.0
                    self._terminations[i] = False
                    self._truncations[i] = False
                else:
                    env_info = self._step_env(i, action)
            elif self.autoreset_mode == AutoresetMode.DISABLED:
                # assumes that the user has correctly autoreset
                assert not self._autoreset_envs[i], f"{self._autoreset_envs=}"
                env_info = self._step_env(i, action)
            elif self.autoreset_mode == AutoresetMode.SAME_STEP:
                env_info = self._step_env(i, action)

                if self._terminations[i] or self._truncations[i]:
                    # Keep the terminal observation/info, then reset at once.
                    infos = self._add_info(
                            infos,
                            {"final_obs": self._env_obs[i],
                             "final_info": env_info},
                            i,
                    )

                    self._env_obs[i], env_info = self.envs[i].reset()
            else:
                raise ValueError(
                        f"Unexpected autoreset mode, {self.autoreset_mode}")

            infos = self._add_info(infos, env_info, i)

        # Update autoreset flags for the stepped environments
        self._autoreset_envs[active_indices] = np.logical_or(
                self._terminations[active_indices],
                self._truncations[active_indices],
        )

        # Only return results for active environments. Indexing with an index
        # array already yields fresh arrays, so callers never alias the
        # internal buffers.
        return (
                [self._env_obs[i] for i in active_indices],
                self._rewards[active_indices],
                self._terminations[active_indices],
                self._truncations[active_indices],
                self._select_infos(infos, active_indices),
        )

    def reset(
            self,
            *,
            seed: int | list[int] | None = None,
            options: dict[str, Any] | None = None,
    ) -> tuple[ObsType, dict[str, Any]]:
        """Resets each of the sub-environments and concatenate the results together.

        Args:
            seed: Seeds used to reset the sub-environments, either
                * ``None`` - random seeds for all environment
                * ``int`` - ``[seed, seed+1, ..., seed+n]``
                * List of ints - ``[1, 2, 3, ..., n]``
            options: Option information used for each sub-environment. A
                ``"reset_mask"`` entry (boolean array of shape ``(num_envs,)``)
                restricts the reset to the selected environments; it is popped
                from ``options`` before the rest is forwarded.

        Returns:
            Concatenated observations and info from each sub-environment
        """
        if seed is None:
            seed = [None for _ in range(self.num_envs)]
        elif isinstance(seed, int):
            seed = [seed + i for i in range(self.num_envs)]
        assert (
                len(seed) == self.num_envs
        ), f"If seeds are passed as a list the length must match num_envs={self.num_envs} but got length={len(seed)}."

        if options is not None and "reset_mask" in options:
            reset_mask = options.pop("reset_mask")
            assert isinstance(
                    reset_mask, np.ndarray
            ), f"`options['reset_mask': mask]` must be a numpy array, got {type(reset_mask)}"
            assert reset_mask.shape == (
                    self.num_envs,
            ), f"`options['reset_mask': mask]` must have shape `({self.num_envs},)`, got {reset_mask.shape}"
            assert (
                    reset_mask.dtype == np.bool_
            ), f"`options['reset_mask': mask]` must have `dtype=np.bool_`, got {reset_mask.dtype}"
            assert np.any(
                    reset_mask
            ), f"`options['reset_mask': mask]` must contain a boolean array, got reset_mask={reset_mask}"

            # Clear the flags of the selected environments only.
            self._terminations[reset_mask] = False
            self._truncations[reset_mask] = False
            self._autoreset_envs[reset_mask] = False

            infos = {}
            for i, (env, single_seed, env_mask) in enumerate(
                    zip(self.envs, seed, reset_mask)
            ):
                if env_mask:
                    self._env_obs[i], env_info = env.reset(
                            seed=single_seed, options=options
                    )

                    infos = self._add_info(infos, env_info, i)
        else:
            # Full reset: start from fresh flag arrays.
            self._terminations = np.zeros((self.num_envs,), dtype=np.bool_)
            self._truncations = np.zeros((self.num_envs,), dtype=np.bool_)
            self._autoreset_envs = np.zeros((self.num_envs,), dtype=np.bool_)

            infos = {}
            for i, (env, single_seed) in enumerate(zip(self.envs, seed)):
                self._env_obs[i], env_info = env.reset(
                        seed=single_seed, options=options
                )

                infos = self._add_info(infos, env_info, i)

        # Concatenate the observations
        self._observations = concatenate(
                self.single_observation_space, self._env_obs, self._observations
        )

        # Set the rewards to 0 (for every environment, masked or not)
        self._rewards.fill(0.0)

        return deepcopy(
                self._observations) if self.copy else self._observations, infos

    def reset_single(
            self,
            index: int,
            *,
            seed: int | None = None,
            options: dict[str, Any] | None = None,
    ) -> tuple[ObsType, dict[str, Any]]:
        """
        Reset exactly one sub-environment.

        • Updates _env_obs and all bookkeeping arrays in-place.
        • Leaves the other environments completely untouched.
        • Builds the returned info dict with ``_add_info``.

        Args:
            index: Position of the sub-environment to reset.
            seed: Seed forwarded to that environment's ``reset``.
            options: Options forwarded to that environment's ``reset``.

        Returns:
            ``(obs, infos)`` for the reset environment; deep copies when
            ``self.copy`` is set.
        """
        # 1) Hard-reset the chosen env and fetch (obs, info)
        obs, env_info = self.envs[index].reset(seed=seed, options=options)

        # 2) In-place bookkeeping -------------------------------------------------
        self._env_obs[index] = obs  # raw per-env cache
        self._rewards[index] = 0.0
        self._terminations[index] = False
        self._truncations[index] = False
        self._autoreset_envs[index] = False  # "needs reset" flag

        # Keep the batched observation buffer current.
        # NOTE(review): slot assignment assumes the buffer is indexable per
        # environment (e.g. an array for a Box space) - confirm for other spaces.
        if self._observations is not None:
            self._observations[index] = obs
        else:  # first call - build the buffer lazily
            self._observations = concatenate(
                    self.single_observation_space, self._env_obs,
                    self._observations
            )

        # 3) Build a vector-style info dict that has the right slot updated
        infos: dict[str, Any] = {}
        infos = self._add_info(infos, env_info, index)

        # 4) Return - honour the "copy" flag
        if self.copy:
            return deepcopy(obs), {k: deepcopy(v) for k, v in infos.items()}
        return obs, infos


def _async_subset_worker(
        index, env_fn, pipe, parent_pipe, shared_memory, error_q, autoreset_mode
):
    """
    Worker loop used by ``AsyncSubsetVectorEnv``.

    Builds the environment with ``env_fn`` and then serves commands received
    on ``pipe`` until ``"close"`` arrives. Every reply is ``(result, True)``.
    Besides ``reset``, ``step``, ``close``, ``_call``, ``_setattr`` and
    ``_check_spaces`` it understands ``("step-noop", None)``, which answers
    with the cached observation and reward = 0, terminated = truncated = False,
    info = {} without touching the environment.

    On any exception the error is put on ``error_q`` and ``(None, False)`` is
    sent back. The environment is closed when the loop ends either way.
    ``autoreset_mode`` is accepted but not used here.
    """

    env = env_fn()
    obs_space = env.observation_space
    action_space = env.action_space
    observation = None  # last observation cache
    parent_pipe.close()

    # ``_call`` keyword that carries a per-worker list -> keyword the target
    # actually receives.
    scattered_kwargs = (
        ("scatter_suggestions", "suggestions"),
        ("scatter_output_dir", "output_dir"),
    )

    try:
        while True:
            command, payload = pipe.recv()

            if command == "close":
                pipe.send((None, True))
                break

            if command == "step-noop":
                reply = (observation, 0.0, False, False, {})

            elif command == "reset":
                observation, info = env.reset(**payload)
                if shared_memory:
                    # The observation travels through shared memory, so
                    # nothing is sent (or cached) through the pipe.
                    write_to_shared_memory(obs_space, index, observation,
                                           shared_memory)
                    observation = None
                reply = (observation, info)

            elif command == "step":
                observation, reward, terminated, truncated, info = env.step(
                        payload)
                if shared_memory:
                    write_to_shared_memory(obs_space, index, observation,
                                           shared_memory)
                    observation = None
                reply = (observation, reward, terminated, truncated, info)

            elif command == "_call":
                name, args, kwargs = payload

                # Replace each scattered list by this worker's own element,
                # or None when the list is missing or too short.
                for scatter_key, target_key in scattered_kwargs:
                    if scatter_key not in kwargs:
                        continue
                    per_worker = kwargs.pop(scatter_key)
                    in_range = (per_worker is not None
                                and index < len(per_worker))
                    kwargs[target_key] = per_worker[index] if in_range else None

                target = env.get_wrapper_attr(name)
                if callable(target):
                    reply = target(*args, **kwargs)
                else:
                    reply = target

            elif command == "_setattr":
                attr_name, attr_value = payload
                env.set_wrapper_attr(attr_name, attr_value)
                reply = None

            elif command == "_check_spaces":
                obs_mode, single_obs, single_act = payload
                if obs_mode == "same":
                    obs_matches = single_obs == obs_space
                else:
                    obs_matches = is_space_dtype_shape_equiv(
                            single_obs, obs_space)
                reply = (obs_matches, single_act == action_space)

            else:
                raise RuntimeError(f"Unknown worker command {command!r}")

            pipe.send((reply, True))

    except (KeyboardInterrupt, Exception) as exc:
        error_q.put((index, type(exc), exc, traceback.format_exc()))
        pipe.send((None, False))
    finally:
        env.close()


class AsyncSubsetVectorEnv(AsyncVectorEnv):
    """
    AsyncVectorEnv extended with  step_subset(actions, mask).
    The inherited API is not overridden.
    """

    def __init__(self, env_fns, **kwargs):
        """
        Args:
            env_fns: Environment factories, forwarded to ``AsyncVectorEnv``.
            **kwargs: Forwarded to ``AsyncVectorEnv``. ``worker`` defaults to
                ``_async_subset_worker``, which understands ``"step-noop"``.
        """
        # use our customised worker unless the caller overrides it
        kwargs.setdefault("worker", _async_subset_worker)
        super().__init__(env_fns, **kwargs)

    def step_subset(
            self,
            actions,
            mask: np.ndarray,
            timeout: int | float | None = None,
    ):
        """
        Step only the envs where mask[i] is True.

        Every worker receives a command: ``"step"`` for active envs and
        ``"step-noop"`` for the others, so all pipes are drained each call.

        Args:
            actions: One action per environment (entries of inactive
                environments are ignored).
            mask: Array of shape ``(num_envs,)``; truthy entries are stepped.
            timeout: Passed to ``_poll_pipe_envs`` while waiting for replies.

        Returns:
            ``(observations list, rewards, terminations, truncations, infos)``.
            Inactive envs keep their state and are **omitted** from the
            observations, rewards, terminations and truncations. ``infos`` is
            built with ``_add_info`` from the active envs only.

        Raises:
            AlreadyPendingCallError: If another async call is still pending.
            multiprocessing.TimeoutError: If the workers do not answer in time.
        """
        assert mask.shape == (self.num_envs,), (
                f"mask must have shape ({self.num_envs},)"
        )
        self._assert_is_running()
        if self._state != AsyncState.DEFAULT:
            raise AlreadyPendingCallError(
                    f"step_subset while waiting for {self._state.value}",
                    str(self._state.value),
            )

        # 1) send appropriate command to every pipe
        for pipe, act, active in zip(
                self.parent_pipes, iterate(self.action_space, actions), mask
        ):
            pipe.send(("step", act) if active else ("step-noop", None))
        self._state = AsyncState.WAITING_STEP

        # 2) wait for every worker
        if not self._poll_pipe_envs(timeout):
            self._state = AsyncState.DEFAULT
            raise multiprocessing.TimeoutError("step_subset timed out")

        # 3) collect replies - results are kept for active envs only, but the
        #    observation slot of every reply is remembered for the
        #    non-shared-memory path below.
        active_idxs, rew, term, trunc, infos, succ = [], [], [], [], {}, []
        worker_obs = []
        for env_idx, (pipe, active) in enumerate(zip(self.parent_pipes, mask)):
            payload, ok = pipe.recv()
            succ.append(ok)
            # On failure the payload is None. With shared memory the
            # observation slot is None as well.
            worker_obs.append(payload[0] if ok else None)
            if active and ok:
                _, r, t, c, info = payload
                active_idxs.append(env_idx)
                rew.append(r)
                term.append(t)
                trunc.append(c)
                infos = self._add_info(infos, info, env_idx)

        # Every reply has been consumed, so nothing is pending any more. Reset
        # the state before raising so a worker failure does not leave the
        # vector env stuck in WAITING_STEP.
        self._state = AsyncState.DEFAULT
        self._raise_if_errors(succ)

        # ------------------------------------------------------------------
        # build observations for the active envs
        # ------------------------------------------------------------------
        if self.shared_memory:
            # Workers wrote into the shared buffer; read the active slots.
            obs_list = [deepcopy(self.observations[i]) if self.copy else
                        self.observations[i]
                        for i in active_idxs]
        else:
            # Observations came back through the pipes: active workers sent
            # their new observation, inactive ones their cached one.
            obs_list = [worker_obs[i] for i in active_idxs]
            # Refresh the local batched copy only when every env supplied an
            # observation (a worker that has never been reset has none
            # cached). concatenate() needs one item per environment.
            if all(obs is not None for obs in worker_obs):
                self.observations = concatenate(
                        self.single_observation_space, worker_obs,
                        self.observations
                )

        return (
                obs_list,
                np.asarray(rew, dtype=np.float64),
                np.asarray(term, dtype=np.bool_),
                np.asarray(trunc, dtype=np.bool_),
                infos,
        )


def make_env_fn(
        benchmark: str,
        benchmark_path: str | Path,
        dreamplace_params,
        grid: int,
        output_dir: str | Path | None = None,
        run_dreamplace: bool = True,
        absl_cfg: dict | None = None,
        gpu_id: int = 0,
        worker_id: int = 0,
        base_seed: int | None = None,
):
    """
    Returns a picklable callable that BUILDS the PlacementEnv inside the
    subprocess. dreamplace_params is passed in (it's picklable).
    """

    def _init():
        import random
        import torch
        from copy import deepcopy

        # Set per-worker RNG seeds for divergent placements
        if base_seed is not None:
            worker_seed = base_seed + worker_id
            random.seed(worker_seed)
            np.random.seed(worker_seed)
            torch.manual_seed(worker_seed)
            if torch.cuda.is_available():
                torch.cuda.manual_seed(worker_seed)
            logging.info("Worker %d using seed %d", worker_id, worker_seed)

        # Deep copy dreamplace_params and set per-worker seed
        worker_params = deepcopy(dreamplace_params)
        if base_seed is not None:
            worker_params.random_seed = base_seed + worker_id

        # Set GPU for this worker before any CUDA operations
        if torch.cuda.is_available():
            torch.cuda.set_device(gpu_id)
            logging.info("Worker %d using GPU %d", worker_id, gpu_id)

        # -------- re-apply flags inside the worker ------------------
        from absl import flags
        f = flags.FLAGS
        if not f.is_parsed():  # True in a fresh subprocess
            # create a "fake argv" so Abseil thinks it has parsed
            f([__file__])
        if absl_cfg:
            for k, v in absl_cfg.items():
                if hasattr(f, k):
                    setattr(f, k, v)
        # ------------------------------------------------------------

        # Configure absl logging for this worker process
        logging.set_verbosity(logging.INFO)
        logging.use_absl_handler()

        env = PlacementGymEnv(
                benchmark=benchmark,
                benchmark_path=benchmark_path,
                dreamplace_params=worker_params,
                grid=grid,
                output_dir=output_dir,
                run_dreamplace=run_dreamplace,
        )
        return env

    return _init


def create_vector_env(n_envs, benchmark, benchmark_path,
        dreamplace_params,
        grid=VEOPLACE_GRID_SIZE,
        run_dreamplace=True,
        output_dir=None,
        base_seed=None):
    """
    Create an ``AsyncSubsetVectorEnv`` with ``n_envs`` workers.

    Workers are assigned to GPUs round-robin (``worker index % GPU count``,
    with the count treated as 1 when no GPU is reported), and autoreset is
    disabled on the vector env.

    Args:
        n_envs: Number of worker environments.
        benchmark, benchmark_path, grid, run_dreamplace, output_dir:
            Forwarded to ``make_env_fn`` for every worker.
        dreamplace_params: Pre-built DREAMPlace params, shared by all factories.
        base_seed: Base seed for per-worker RNG. Each worker uses base_seed + worker_id.

    Returns:
        The constructed vector environment.
    """
    # Snapshot the parent's flag values once; every worker re-applies them.
    flag_snapshot = flags.FLAGS.flag_values_dict()
    gpu_count = torch.cuda.device_count() or 1

    def _factory_for(worker_idx):
        return make_env_fn(
                benchmark=benchmark,
                benchmark_path=benchmark_path,
                dreamplace_params=dreamplace_params,
                grid=grid,
                output_dir=output_dir,
                absl_cfg=flag_snapshot,
                run_dreamplace=run_dreamplace,
                gpu_id=worker_idx % gpu_count,
                worker_id=worker_idx,
                base_seed=base_seed,
        )

    factories = [_factory_for(worker_idx) for worker_idx in range(n_envs)]

    return AsyncSubsetVectorEnv(
            factories,
            autoreset_mode=AutoresetMode.DISABLED,
    )
