import copy
import gc
import os
import tempfile
import time
from pathlib import Path
from typing import Any, Dict, Optional, Tuple

import numpy as np
import torch
from absl import logging as absl_logging
from dreamplace import NonLinearPlace
from dreamplace import Params
from dreamplace import PlaceDB

# Default CPU thread count written to ``params.num_threads`` by the
# param-building helpers below when the caller does not pass one.
DREAMPLACE_NUM_THREADS: int = 8


# ============== Common Override Function ==============

def _log_dreamplace_config(params: Params.Params, benchmark: str) -> None:
    """Log a boxed summary of the final DREAMPlace configuration at INFO level.

    Stage-level values (bins, iterations, learning rate, optimizer) are read
    from ``params.global_place_stages[0]``; attributes that may be absent
    from ``params`` are shown with a fallback value.
    """
    stage0 = params.global_place_stages[0]
    sep = "=" * 70
    div = "-" * 70

    bins_str = f"{stage0['num_bins_x']}x{stage0['num_bins_y']}"
    gpu_str = "ON" if params.gpu else "OFF"
    opt_str = stage0.get('optimizer', 'nesterov')

    config_lines = [
            sep,
            f"  DREAMPlace Config: {benchmark}",
            sep,
            f"  GPU:            {gpu_str:<10}  Bins:           {bins_str:<10}  Iters:   {stage0['iteration']}   Threads: {params.num_threads}",
            f"  LR:             {stage0['learning_rate']:<10.4f}  Optimizer:      {opt_str}",
            f"  Target Density: {params.target_density:<10.2f}  Stop Overflow:  {params.stop_overflow:.2f}       Density Weight: {params.density_weight:.2e}",
            div,
            f"  legalize:       {str(params.legalize_flag):<10}  detailed:       {str(params.detailed_place_flag):<10}  macro_place:   {params.macro_place_flag}",
            f"  lp_legalize:    {str(getattr(params, 'lp_legalization_flag', True)):<10}",
            f"  routability:    {str(params.routability_opt_flag):<10}  macro_overlap:  {str(getattr(params, 'macro_overlap_flag', False))}",
            f"  random_center:  {str(params.random_center_init_flag):<10}  gift:           {str(params.gift_init_flag):<10}  render:        {params.plot_flag}",
            f"  anchor:         {str(getattr(params, 'anchor_flag', False)):<10}  anchor_weight:  {(getattr(params, 'anchor_weight', None) or 0.0):<10.4f}  seed:          {getattr(params, 'random_seed', 'N/A')}",
            sep,
    ]
    absl_logging.info("\n" + "\n".join(config_lines))


def _apply_dreamplace_overrides(
        params: Params.Params,
        benchmark: str,
        gpu: bool = torch.cuda.is_available(),
        result_dir: Optional[str] = None,
        render: bool = False,
        num_threads: int = DREAMPLACE_NUM_THREADS,
        seed: Optional[int] = None,
        # Placement flags: True = enable for metrics, False = disable
        routability_opt_flag: bool = False,
        macro_overlap_flag: bool = False,
        detailed_place_flag: Optional[bool] = None,
        legalize_flag: Optional[bool] = None,
        macro_place_flag: Optional[bool] = None,
        random_center_init_flag: bool = True,
        gift_init_flag: bool = False,
        # Numeric overrides: None = use JSON config default
        density_weight: Optional[float] = None,
        iteration: Optional[int] = None,
        num_bins_x: Optional[int] = None,
        num_bins_y: Optional[int] = None,
        target_density: Optional[float] = None,
        stop_overflow: Optional[float] = None,
        learning_rate: Optional[float] = None,
        optimizer: Optional[str] = None,
        use_bb: bool = True,
        # Anchor/spring term for soft-fixed macros
        anchor_flag: bool = False,
        anchor_weight: Optional[float] = None,
        # Hard-freeze macros (gradient zeroing)
        freeze_macro_flag: bool = False,
        # Noise ratio for global placement
        gp_noise_ratio: Optional[float] = None,
        # Evaluation frequency
        eval_frequency: int = 100,
        # Log config to console
        log_config: bool = True,
) -> Params.Params:
    """Apply common overrides to DREAMPlace params, mutating ``params`` in place.

    This is the shared logic used by both get_bookshelf_dreamplace_params
    and get_lefdef_dreamplace_params.

    Precedence, lowest to highest:
      1. values already in ``params`` (loaded from the JSON config);
      2. arguments to this function. ``None``-defaulted arguments only
         override when not ``None``; all others (gpu, render, num_threads,
         result_dir, init/anchor/freeze flags, use_bb, eval_frequency) are
         always written. Note that ``routability_opt_flag`` and
         ``macro_overlap_flag`` default to ``False``, so they overwrite the
         JSON value unless the caller passes ``None`` explicitly;
      3. per-benchmark overrides from ``get_dreamplace_overrides(benchmark)``
         for detailed_place_flag, legalize_flag and lp_legalization_flag.

    Stage-level overrides (iteration, num_bins_x/y, learning_rate, optimizer)
    are written into ``params.global_place_stages[0]`` only.

    Args:
        params: Params loaded from a DREAMPlace JSON config.
        benchmark: Benchmark name; used for registry overrides and logging.
        gpu: Run on GPU. The default is evaluated once, at import time.
        seed: Stored as ``params.random_seed`` when given.
        log_config: Log a summary of the final configuration.

    Returns:
        The same ``params`` object, after mutation.

    Raises:
        ValueError: If random_center_init_flag and gift_init_flag, or
            anchor_flag and freeze_macro_flag, are both enabled. This is
            checked before ``params`` is modified.
    """
    # Validate first so a rejected call leaves ``params`` untouched.
    if random_center_init_flag and gift_init_flag:
        raise ValueError(
                "random_center_init_flag and gift_init_flag are mutually exclusive. "
                "Set one to False."
        )
    if anchor_flag and freeze_macro_flag:
        raise ValueError(
                "anchor_flag and freeze_macro_flag are mutually exclusive. "
                "Use anchor_flag for soft-fixed macros (analytical baseline), "
                "freeze_macro_flag for hard-fixed macros (ChipFormer dataset)."
        )

    # Unconditional basic params
    params.gpu = gpu
    params.plot_flag = render
    params.num_threads = num_threads
    params.result_dir = result_dir

    params.random_center_init_flag = random_center_init_flag
    params.gift_init_flag = gift_init_flag
    params.use_bb = use_bb
    params.anchor_flag = anchor_flag
    params.anchor_weight = anchor_weight
    params.freeze_macro_flag = freeze_macro_flag

    # Stage-0 overrides; the stage is only touched when something changes.
    stage_overrides = {
            key: value for key, value in (
                    ('iteration', iteration),
                    ('num_bins_x', num_bins_x),
                    ('num_bins_y', num_bins_y),
                    ('learning_rate', learning_rate),
                    ('optimizer', optimizer),
            ) if value is not None
    }
    if stage_overrides:
        stage0 = params.global_place_stages[0]
        for key, value in stage_overrides.items():
            stage0[key] = value

    # Top-level numeric overrides (None = keep the JSON value)
    for attr, value in (
            ('target_density', target_density),
            ('stop_overflow', stop_overflow),
            ('density_weight', density_weight),
            ('random_seed', seed),
            ('gp_noise_ratio', gp_noise_ratio),
    ):
        if value is not None:
            setattr(params, attr, value)
    params.eval_frequency = eval_frequency

    # Placement flags: override only when not None
    if detailed_place_flag is not None:
        params.detailed_place_flag = detailed_place_flag
    if legalize_flag is not None:
        params.legalize_flag = legalize_flag
        params.abacus_legalize_flag = legalize_flag
    if macro_place_flag is not None:
        params.macro_place_flag = macro_place_flag
    if macro_overlap_flag is not None:
        params.macro_overlap_flag = macro_overlap_flag
    if routability_opt_flag is not None:
        params.routability_opt_flag = routability_opt_flag

    # Per-benchmark registry overrides win over everything above.
    from veoplace.utils.benchmark_registry import get_dreamplace_overrides
    overrides = get_dreamplace_overrides(benchmark)
    if "detailed_place_flag" in overrides:
        params.detailed_place_flag = overrides["detailed_place_flag"]
    if "legalize_flag" in overrides:
        params.legalize_flag = overrides["legalize_flag"]
        params.abacus_legalize_flag = overrides["legalize_flag"]
    if "lp_legalization_flag" in overrides:
        params.lp_legalization_flag = overrides["lp_legalization_flag"]

    if log_config:
        _log_dreamplace_config(params, benchmark)

    return params


# ============== DEF/LEF Format Support ==============
# For benchmarks with DEF/LEF files (e.g., ariane133), we can bypass the CT format
# and load directly into DREAMPlace. This avoids the floating-point coordinate issues
# that cause legalization to hang.

def get_lefdef_paths(netlist_dir: str, benchmark: str) -> Optional[
    Dict[str, Any]]:
    """
    Check if DEF/LEF files exist for a benchmark and return their paths.

    Looks for files in: {netlist_dir}/{benchmark}/lefdef/
    Expected structure:
        lefdef/
        ├── {benchmark}.def
        ├── *.tech.lef (technology LEF)
        ├── *.macro.*.lef or *.lef (cell library LEFs)
        └── fakeram*.lef (optional SRAM LEFs)

    If ``{benchmark}.def`` is absent, the alphabetically first ``*.def`` in
    the directory is used, so the choice is deterministic.

    Returns:
        Dict with 'def_input' (str) and 'lef_input' (list of str).
        'lef_input' lists the ``*.tech.lef`` files first (sorted), then every
        other ``*.lef`` (sorted). Returns None if the directory, a DEF file,
        or any LEF file is missing.
    """
    lefdef_dir = Path(netlist_dir) / benchmark / "lefdef"
    if not lefdef_dir.is_dir():
        return None

    def_file = lefdef_dir / f"{benchmark}.def"
    if not def_file.exists():
        # Glob order is filesystem-dependent; sort for a reproducible pick.
        def_candidates = sorted(lefdef_dir.glob("*.def"))
        if not def_candidates:
            return None
        def_file = def_candidates[0]

    # Order matters: DREAMPlace needs the technology LEF before cell LEFs.
    tech_lefs = sorted(lefdef_dir.glob("*.tech.lef"))
    tech_set = set(tech_lefs)
    cell_lefs = [lef for lef in sorted(lefdef_dir.glob("*.lef"))
                 if lef not in tech_set]
    lef_files = tech_lefs + cell_lefs

    if not lef_files:
        return None

    return {
            'def_input': str(def_file),
            'lef_input': [str(lef) for lef in lef_files],
    }


def get_lefdef_dreamplace_params(
        benchmark: str,
        config_file: Path,
        def_file: Path,
        lef_files: tuple[Path, ...],
        verilog_file: Optional[Path] = None,
        gpu: bool = torch.cuda.is_available(),
        result_dir: Optional[str] = None,
        render: bool = False,
        num_threads: int = DREAMPLACE_NUM_THREADS,
        seed: Optional[int] = None,
        # Placement flags: True = enable for metrics, False = disable
        routability_opt_flag: bool = False,
        macro_overlap_flag: bool = False,
        detailed_place_flag: Optional[bool] = None,
        legalize_flag: Optional[bool] = None,
        macro_place_flag: Optional[bool] = None,
        random_center_init_flag: bool = True,
        gift_init_flag: bool = False,
        # Numeric overrides: None = use JSON config default
        density_weight: Optional[float] = None,
        iteration: Optional[int] = None,
        num_bins_x: Optional[int] = None,
        num_bins_y: Optional[int] = None,
        target_density: Optional[float] = None,
        stop_overflow: Optional[float] = None,
        learning_rate: Optional[float] = None,
        optimizer: Optional[str] = None,
        use_bb: bool = True,
        anchor_flag: bool = False,
        anchor_weight: Optional[float] = None,
        freeze_macro_flag: bool = False,
        gp_noise_ratio: Optional[float] = None,
        eval_frequency: int = 100,
) -> Optional[Params.Params]:
    """
    Create DREAMPlace params for DEF/LEF format benchmarks.

    Loads the JSON config, points it at the given DEF/LEF (and optional
    Verilog) inputs, then applies the shared overrides via
    ``_apply_dreamplace_overrides``. All remaining keyword arguments are
    forwarded to it unchanged.

    Args:
        benchmark: Benchmark name
        config_file: Path to DREAMPlace JSON config file
        def_file: Path to DEF file
        lef_files: Tuple of paths to LEF files (order is preserved; tech LEF
            should come first)
        verilog_file: Optional path to Verilog netlist (for connectivity)

    Returns:
        The configured Params, or None (rather than raising) when
        ``config_file`` is not an existing regular file.

    Raises:
        ValueError: Propagated from ``_apply_dreamplace_overrides`` for
            mutually exclusive flag combinations.
    """
    json_cfg = Path(config_file)
    # is_file() rather than exists(): a directory would pass exists() and
    # then fail inside params.load().
    if not json_cfg.is_file():
        return None

    params = Params.Params()
    params.load(str(json_cfg))

    params.def_input = str(def_file)
    params.lef_input = [str(lef) for lef in lef_files]

    # Verilog supplies netlist connectivity when the DEF lacks it.
    if verilog_file is not None:
        params.verilog_input = str(verilog_file)

    # Apply common overrides (shared with Bookshelf)
    return _apply_dreamplace_overrides(
            params=params,
            benchmark=benchmark,
            gpu=gpu,
            result_dir=result_dir,
            render=render,
            num_threads=num_threads,
            seed=seed,
            routability_opt_flag=routability_opt_flag,
            macro_overlap_flag=macro_overlap_flag,
            detailed_place_flag=detailed_place_flag,
            legalize_flag=legalize_flag,
            macro_place_flag=macro_place_flag,
            random_center_init_flag=random_center_init_flag,
            gift_init_flag=gift_init_flag,
            density_weight=density_weight,
            iteration=iteration,
            num_bins_x=num_bins_x,
            num_bins_y=num_bins_y,
            target_density=target_density,
            stop_overflow=stop_overflow,
            learning_rate=learning_rate,
            optimizer=optimizer,
            use_bb=use_bb,
            anchor_flag=anchor_flag,
            anchor_weight=anchor_weight,
            freeze_macro_flag=freeze_macro_flag,
            gp_noise_ratio=gp_noise_ratio,
            eval_frequency=eval_frequency,
    )


def has_lefdef_format(netlist_dir: str, benchmark: str) -> bool:
    """
    Report whether a benchmark has a complete DEF/LEF setup available.

    All of the following must hold:
      1. ``{netlist_dir}/{benchmark}/{benchmark}.json`` exists (top level,
         same location as for Bookshelf benchmarks);
      2. the ``{netlist_dir}/{benchmark}/lefdef/`` subdirectory exists;
      3. ``get_lefdef_paths`` finds a DEF plus LEF files there, and every
         path it returns exists.
    """
    bench_dir = Path(netlist_dir) / benchmark

    if not (bench_dir / f"{benchmark}.json").exists():
        return False
    if not (bench_dir / "lefdef").exists():
        return False

    found = get_lefdef_paths(netlist_dir, benchmark)
    if found is None:
        return False

    # DEF first, then every LEF, short-circuiting on the first missing file.
    required_files = [found['def_input'], *found['lef_input']]
    return all(Path(p).exists() for p in required_files)


def has_bookshelf_format(netlist_dir: str, benchmark: str) -> bool:
    """
    Report whether Bookshelf inputs are available for a benchmark.

    Both ``{netlist_dir}/{benchmark}/{benchmark}.json`` (DREAMPlace config)
    and ``{netlist_dir}/{benchmark}/{benchmark}.aux`` must exist.
    """
    bench_dir = Path(netlist_dir) / benchmark
    required = (
            bench_dir / f"{benchmark}.json",
            bench_dir / f"{benchmark}.aux",
    )
    return all(path.exists() for path in required)


def create_placedb_from_bookshelf(netlist_dir: str,
        benchmark: str) -> PlaceDB.PlaceDB:
    """
    Load an original, ungrouped Bookshelf benchmark into a PlaceDB.

    Handy for extracting ground-truth parameters such as the canvas size,
    because PlaceDB parses the .scl file.

    Args:
        netlist_dir: Directory holding the benchmark subdirectories
                     (e.g., 'data/netlists/ispd2005').
        benchmark: Benchmark name (e.g., 'adaptec1').

    Returns:
        A PlaceDB.PlaceDB that has already been called on the params.

    Raises:
        FileNotFoundError: If ``{benchmark}.json`` or ``{benchmark}.aux`` is
            missing from ``{netlist_dir}/{benchmark}``.
    """
    bench_dir = os.path.join(netlist_dir, benchmark)
    json_cfg = os.path.join(bench_dir, f"{benchmark}.json")
    aux_path = os.path.join(bench_dir, f"{benchmark}.aux")

    if not (os.path.isfile(json_cfg) and os.path.isfile(aux_path)):
        raise FileNotFoundError(
                f"Missing .json or .aux for {benchmark} in {netlist_dir}")

    params = Params.Params()
    params.load(json_cfg)
    params.aux_input = str(aux_path)

    # NumPy 2.0 removed the ``np.string_`` alias. Restore it for code that
    # still uses it (presumably DREAMPlace's parser; verify).
    if not hasattr(np, "string_"):
        np.string_ = np.bytes_

    placedb = PlaceDB.PlaceDB()
    # Calling the PlaceDB parses every Bookshelf file listed in the .aux.
    placedb(params)
    return placedb


def _pl_path_from_aux(aux_file: Path) -> str:
    """Return the .pl path referenced by a Bookshelf .aux file.

    Scans ``RowBasedPlacement`` lines for the first token ending in ``.pl``
    and resolves it relative to the .aux file's directory. Searching by
    extension, rather than by position, keeps this correct when the aux lists
    extra files after the .scl (e.g. ``.shapes`` / ``.route``).

    Raises:
        ValueError: If no ``RowBasedPlacement`` line references a .pl file.
    """
    with open(aux_file, "r") as f:
        for line in f:
            if not line.startswith("RowBasedPlacement"):
                continue
            pl_tokens = [tok for tok in line.split() if tok.endswith(".pl")]
            if pl_tokens:
                return str(aux_file.parent / pl_tokens[0])
    raise ValueError(
            f"No .pl file referenced by a RowBasedPlacement line in {aux_file}")


def get_bookshelf_dreamplace_params(
        benchmark: str,
        config_file: Path,
        aux_file: Path,
        gpu: bool = torch.cuda.is_available(),
        result_dir: Optional[str] = None,
        render: bool = False,
        num_threads: int = DREAMPLACE_NUM_THREADS,
        seed: Optional[int] = None,
        # Placement flags: True = enable for metrics, False = disable
        routability_opt_flag: bool = False,
        macro_overlap_flag: bool = False,
        detailed_place_flag: Optional[bool] = None,
        legalize_flag: Optional[bool] = None,
        macro_place_flag: Optional[bool] = None,
        random_center_init_flag: bool = True,
        gift_init_flag: bool = False,
        # Numeric overrides: None = use JSON config default
        density_weight: Optional[float] = None,
        iteration: Optional[int] = None,
        num_bins_x: Optional[int] = None,
        num_bins_y: Optional[int] = None,
        target_density: Optional[float] = None,
        stop_overflow: Optional[float] = None,
        learning_rate: Optional[float] = None,
        optimizer: Optional[str] = None,
        # "adam", "nesterov", "sgd", "sgd_momentum", "sgd_nesterov"
        use_bb: bool = True,
        # Anchor/spring term for soft-fixed macros
        anchor_flag: bool = False,
        anchor_weight: Optional[float] = None,
        # Hard-freeze macros (gradient zeroing)
        freeze_macro_flag: bool = False,
        # Noise ratio for global placement (default 0.025, set to 0 for warm start)
        gp_noise_ratio: Optional[float] = None,
        # Evaluation frequency: only compute full metrics every N iterations
        # (default 100), skipping expensive wirelength/HPWL computations in between.
        eval_frequency: int = 100,
) -> Params.Params:
    """Construct a DreamPlace Params() from benchmark JSON config with optional overrides.

    Besides loading the JSON config and setting ``aux_input``, this records
    ``pl_path`` (the .pl file referenced by the .aux) and
    ``json_config_path`` in ``params.params_dict``. All remaining keyword
    arguments are forwarded unchanged to ``_apply_dreamplace_overrides``.

    Args:
        benchmark: Benchmark name
        config_file: Path to DREAMPlace JSON config file
        aux_file: Path to Bookshelf .aux file (``str`` is also accepted)

    Returns:
        The configured Params.

    Raises:
        FileNotFoundError: If the .aux, the JSON config, or the referenced
            .pl file does not exist.
        ValueError: If the .aux references no .pl file, or for mutually
            exclusive flag combinations.
    """
    aux_file = Path(aux_file)
    json_cfg = str(config_file)
    aux_path = str(aux_file)

    if not os.path.isfile(aux_path):
        raise FileNotFoundError(aux_path)
    if not os.path.isfile(json_cfg):
        raise FileNotFoundError(json_cfg)

    pl_path = _pl_path_from_aux(aux_file)
    if not os.path.isfile(pl_path):
        raise FileNotFoundError(pl_path)

    params = Params.Params()
    params.load(json_cfg)
    params.aux_input = aux_path
    params.params_dict["pl_path"] = pl_path
    params.params_dict["json_config_path"] = json_cfg

    # Apply common overrides (shared with DEF/LEF)
    return _apply_dreamplace_overrides(
            params=params,
            benchmark=benchmark,
            gpu=gpu,
            result_dir=result_dir,
            render=render,
            num_threads=num_threads,
            seed=seed,
            routability_opt_flag=routability_opt_flag,
            macro_overlap_flag=macro_overlap_flag,
            detailed_place_flag=detailed_place_flag,
            legalize_flag=legalize_flag,
            macro_place_flag=macro_place_flag,
            random_center_init_flag=random_center_init_flag,
            gift_init_flag=gift_init_flag,
            density_weight=density_weight,
            iteration=iteration,
            num_bins_x=num_bins_x,
            num_bins_y=num_bins_y,
            target_density=target_density,
            stop_overflow=stop_overflow,
            learning_rate=learning_rate,
            optimizer=optimizer,
            use_bb=use_bb,
            anchor_flag=anchor_flag,
            anchor_weight=anchor_weight,
            freeze_macro_flag=freeze_macro_flag,
            gp_noise_ratio=gp_noise_ratio,
            eval_frequency=eval_frequency,
    )


def rewrite_pl_file(
        template: Path,
        out_pl: Path,
        node_pos: Dict[str, Tuple[float, float, bool]]  # (x, y, fixed)
):
    """Rewrite a .pl file by injecting node positions from a dictionary.

    Every line of ``template`` whose first token is a key of ``node_pos`` has
    its x/y fields replaced. If that entry's ``fixed`` flag is true and the
    line is not already marked fixed, `` /FIXED`` is appended. A ``False``
    flag does not remove an existing ``/FIXED``. Blank lines, ``#`` comments
    and all other lines are copied verbatim.

    Assumes the coordinates in node_pos are already in the final, desired
    micron space; the caller is responsible for coordinate conversion.

    Args:
        template: Source .pl file.
        out_pl: Destination .pl file (overwritten).
        node_pos: ``{name: (x, y, fixed, ...)}``; extra trailing tuple
            elements are ignored.

    Raises:
        AssertionError: If some node in ``node_pos`` never appears in
            ``template``. Raised explicitly, so it also fires under ``-O``.
    """
    absl_logging.debug(
            "↳ _rewrite_pl: %s → %s  (%d nodes)",
            template, out_pl, len(node_pos))

    mapping = {n: (float(x), float(y), fixed) for n, (x, y, fixed, *_) in
               node_pos.items()}
    updated_nodes = set()

    with template.open() as fin, out_pl.open("w") as fout:
        for ln in fin:
            s = ln.lstrip()
            # Only blank lines and comments are skipped up front. Header lines
            # such as "UCLA pl 1.0" pass through because their first token is
            # not a requested node name. (Node names may start with any
            # character, so there is no first-letter filter.)
            if not s or s.startswith("#"):
                fout.write(ln)
                continue

            parts = s.split()
            n = parts[0]
            if n in mapping:
                new_x, new_y, fixed = mapping[n]
                old_pos = (float(parts[1]), float(parts[2]))
                parts[1], parts[2] = str(new_x), str(new_y)

                indent = ln[:len(ln) - len(s)]
                new_ln = indent + " ".join(parts)
                # "/FIXED" is a substring of "/FIXED_NI", so this one check
                # recognizes both markers.
                if fixed and "/FIXED" not in s:
                    new_ln += " /FIXED"
                new_ln += "\n"

                absl_logging.debug("  · %s  (%s → %s)", n, old_pos, mapping[n])
                updated_nodes.add(n)
                ln = new_ln

            fout.write(ln)

    if len(updated_nodes) != len(node_pos):
        missing = sorted(set(mapping) - updated_nodes)
        raise AssertionError(
                f"Expected {len(node_pos)} nodes, but only {len(updated_nodes)} were updated. "
                f"Missing (first 10): {missing[:10]}"
        )
    absl_logging.debug("✓ _rewrite_pl finished; %d / %d nodes updated",
                       len(updated_nodes), len(node_pos))


def rewrite_aux_file(template: Path, out_aux: Path, new_pl: Path) -> None:
    """Copy `template` to `out_aux`, replacing the .pl token with `new_pl.name`.

    The first token ending in ``.pl`` on any non-blank, non-``#`` line is
    replaced with the basename of ``new_pl``. That line is re-joined with
    single spaces; every other line is copied unchanged. Only the basename is
    written, so ``new_pl`` should live in the same directory as ``out_aux``.

    Raises:
        ValueError: If ``template`` contains no ``.pl`` entry. Without this
            check the new aux would silently keep pointing at the original
            .pl file.
    """
    absl_logging.debug("↳ _rewrite_aux: %s → %s  (new .pl = %s)",
                       template, out_aux, new_pl)

    # --- read ----------------------------------------------------------
    with template.open("r") as fin:
        aux_lines = fin.read().splitlines()

    # --- replace the first .pl token -----------------------------------
    replacement = new_pl.name  # basename only
    replaced = False
    for idx, line in enumerate(aux_lines):
        stripped = line.strip()
        if not stripped or stripped.startswith("#"):
            continue
        tokens = stripped.split()
        for i, tok in enumerate(tokens):
            if tok.endswith(".pl"):
                absl_logging.debug("  · Replacing '%s' → '%s'",
                                   tok, replacement)
                tokens[i] = replacement
                aux_lines[idx] = " ".join(tokens)
                replaced = True
                break
        if replaced:
            break

    if not replaced:
        raise ValueError(
                f"No .pl entry found in {template}; cannot redirect it to {replacement}")

    # --- write ---------------------------------------------------------
    with out_aux.open("w") as fout:
        fout.write("\n".join(aux_lines))

    absl_logging.debug("✓ _rewrite_aux finished; .pl path updated")


#
# def fit_canvas_to_nodes(db, pad: float = 0.0):
#     """
#     Re-compute db.xl/yl/xh/yh so that every PHYSICAL node
#     (lower-left + size) fits inside.
#     Works for torch tensors or numpy arrays.
#     """
#     s = slice(0, db.num_physical_nodes)  # the safe overlap
#
#     xl = float(db.node_x[s].min()) - pad
#     yl = float(db.node_y[s].min()) - pad
#     xh = float((db.node_x[s] + db.node_size_x[s]).max()) + pad
#     yh = float((db.node_y[s] + db.node_size_y[s]).max()) + pad
#
#     db.xl, db.yl, db.xh, db.yh = xl, yl, xh, yh
#     return xh - xl, yh - yl


def debug_placedb_extents(placedb, stage_name: str) -> None:
    """Log the PlaceDB canvas and the nodes at the extreme x/y coordinates.

    Emits five INFO lines tagged ``[DEBUG] {stage_name}``:
      * the canvas box ``(xl, yl) → (xh, yh)``;
      * the nodes with the smallest and largest ``node_x``, the latter with
        its width and right edge (``x + w``);
      * the same for ``node_y``, with height and top edge (``y + h``).

    Extremes are taken over ``node_x``/``node_y`` only, so the reported
    "extent" is not necessarily the largest right/top edge of any node.
    """
    def as_text(raw_name):
        # Node names may be stored as bytes; decode them for readable logs.
        if isinstance(raw_name, (bytes, np.bytes_)):
            return raw_name.decode()
        return raw_name

    xs = placedb.node_x
    ys = placedb.node_y
    names = placedb.node_names

    lo_x = int(np.argmin(xs))
    hi_x = int(np.argmax(xs))
    lo_y = int(np.argmin(ys))
    hi_y = int(np.argmax(ys))

    hi_x_pos = float(xs[hi_x])
    hi_x_width = float(placedb.node_size_x[hi_x])
    hi_y_pos = float(ys[hi_y])
    hi_y_height = float(placedb.node_size_y[hi_y])

    absl_logging.info(
            "[DEBUG] %s - Canvas: (%.2f, %.2f) → (%.2f, %.2f)",
            stage_name, placedb.xl, placedb.yl, placedb.xh, placedb.yh,
    )
    absl_logging.info(
            "[DEBUG] %s - Min X: idx=%d, name=%s, x=%.2f",
            stage_name, lo_x, as_text(names[lo_x]), float(xs[lo_x]),
    )
    absl_logging.info(
            "[DEBUG] %s - Max X: idx=%d, name=%s, x=%.2f, w=%.2f, extent=%.2f",
            stage_name, hi_x, as_text(names[hi_x]),
            hi_x_pos, hi_x_width, hi_x_pos + hi_x_width,
    )
    absl_logging.info(
            "[DEBUG] %s - Min Y: idx=%d, name=%s, y=%.2f",
            stage_name, lo_y, as_text(names[lo_y]), float(ys[lo_y]),
    )
    absl_logging.info(
            "[DEBUG] %s - Max Y: idx=%d, name=%s, y=%.2f, h=%.2f, extent=%.2f",
            stage_name, hi_y, as_text(names[hi_y]),
            hi_y_pos, hi_y_height, hi_y_pos + hi_y_height,
    )


# PlaceDB utility: get the position of a node by name
def get_node_position(placedb: PlaceDB.PlaceDB, node_name: str) -> Tuple[
    float, float]:
    """
    Look up a node by name and return its ``(x, y)`` as Python floats.

    The values are read from ``placedb.node_x`` / ``placedb.node_y`` at the
    id given by ``placedb.node_name2id_map``.

    Raises:
        ValueError: If ``node_name`` is not in ``placedb.node_name2id_map``.
    """
    name_to_id = placedb.node_name2id_map
    if node_name not in name_to_id:
        raise ValueError(f"Node '{node_name}' not found in PlaceDB.")

    idx = name_to_id[node_name]
    return float(placedb.node_x[idx]), float(placedb.node_y[idx])


def set_position(
        placedb: PlaceDB.PlaceDB,
        node_name: str,
        x: float,
        y: float) -> None:
    """
    Directly assign (x, y) to the node’s lower-left corner.

    Writes ``placedb.node_x`` / ``placedb.node_y`` in place. If the PlaceDB
    also has a flattened ``pos`` array (x block followed by y block, each
    ``num_nodes`` long), that array is updated too. No boundary or
    movability checks are done. An unknown ``node_name`` fails with whatever
    ``node_name2id_map`` raises (KeyError for a plain dict).
    """
    idx = placedb.node_name2id_map[node_name]

    placedb.node_x[idx] = x
    placedb.node_y[idx] = y

    if hasattr(placedb, "pos"):
        flat_pos = placedb.pos
        flat_pos[idx] = x
        flat_pos[idx + placedb.num_nodes] = y


def get_node_size(placedb: PlaceDB.PlaceDB, node_name: str) -> Tuple[
    float, float]:
    """
    Look up a node by name and return its ``(width, height)`` as floats.

    The values are read from ``placedb.node_size_x`` / ``placedb.node_size_y``.

    Raises:
        ValueError: If ``node_name`` is not in ``placedb.node_name2id_map``.
    """
    name_to_id = placedb.node_name2id_map
    if node_name not in name_to_id:
        raise ValueError(f"Node '{node_name}' not found in PlaceDB.")

    idx = name_to_id[node_name]
    size = (float(placedb.node_size_x[idx]), float(placedb.node_size_y[idx]))
    return size


def nm_per_internal_unit(placedb: PlaceDB.PlaceDB,
        params: Params.Params) -> float:
    """
    Return how many nanometers one DREAMPlace internal length unit spans.

    Computed as ``(1000 / placedb.def_unit) / params.scale_factor``. This
    formula assumes ``def_unit`` is DEF database units per micron and that
    internal coordinates are DBU multiplied by ``scale_factor``; confirm
    against the DREAMPlace build in use.

    Raises:
        ValueError: If ``placedb.def_unit`` is absent, None or 0, or if
            ``params.scale_factor`` is absent, None or 0.
    """
    # getattr: a PlaceDB built without DEF-unit support may lack the attribute
    # entirely, which should produce the same ValueError, not AttributeError.
    def_unit = getattr(placedb, "def_unit", None)
    if def_unit is None or def_unit == 0:
        raise ValueError(
                "placedb.def_unit is missing; rebuild DreamPlace bindings or parse DEF units.")
    scale_factor = getattr(params, "scale_factor", None)
    if scale_factor is None or scale_factor == 0:
        raise ValueError("params.scale_factor is missing or zero.")
    nm_per_dbu = 1000.0 / float(def_unit)
    return nm_per_dbu / float(scale_factor)


def hpwl_internal_to_nm(hpwl_internal: float, placedb: PlaceDB.PlaceDB,
        params: Params.Params) -> float:
    """
    Convert an HPWL value from DreamPlace internal units to nanometers.

    Multiplies ``float(hpwl_internal)`` by
    ``nm_per_internal_unit(placedb, params)`` and propagates its ValueError
    when unit information is missing.
    """
    hpwl_value = float(hpwl_internal)
    return hpwl_value * nm_per_internal_unit(placedb, params)


def create_placedb_with_fixed_macros(
        params: Params.Params,
        macro_names: list,
        node_pos: dict = None,
) -> PlaceDB.PlaceDB:
    """
    Create a PlaceDB with specified macros marked as fixed.

    Writes a temporary .pl (macros tagged ``/FIXED``) and a temporary .aux
    pointing at it, both in the benchmark directory. It then loads a PlaceDB
    from them, so the PlaceDB's fixed-macro information is set up at load
    time and later calls can just update positions in memory.

    Side effects:
        * ``params.aux_input`` is left pointing at the temporary .aux.
        * The temporary files are kept for the rest of the process (see the
          note below) and removed by an ``atexit`` hook.

    Args:
        params: DREAMPlace parameters. ``params.aux_input`` and
            ``params.params_dict["pl_path"]`` give the original .aux/.pl.
        macro_names: Macro node names to mark as fixed. Names not present in
            the PlaceDB are skipped with a warning.
        node_pos: Optional dict of {name: (x, y[, ...])} positions in ORIGINAL
                  (.pl) coordinates. Macros listed here are fixed at these
                  positions; others fall back to their benchmark positions.

    Returns:
        PlaceDB loaded from the rewritten files, or the plain PlaceDB when
        ``macro_names`` is empty.
    """
    import atexit

    # Load initial PlaceDB to get current positions
    initial_placedb = PlaceDB.PlaceDB()
    initial_placedb(params)

    # If no macros to fix, return directly (no temp file overhead)
    if not macro_names:
        absl_logging.info('No macros to fix - all macros movable')
        return initial_placedb

    macro_names_set = set(macro_names)
    absl_logging.info('Creating PlaceDB with %d macros marked as fixed',
                      len(macro_names_set))

    orig_aux = Path(params.aux_input)
    bench_dir = orig_aux.parent
    orig_pl_path = Path(params.params_dict["pl_path"])

    # PlaceDB node_x/node_y are shifted by (xl, yl) relative to the .pl file,
    # so fallback positions must add the offset back before being written.
    xl, yl = initial_placedb.xl, initial_placedb.yl

    fixed_macro_positions = {}
    fallback_count = 0
    unknown_count = 0
    for name in macro_names_set:
        idx = initial_placedb.node_name2id_map.get(name)
        if idx is None:
            unknown_count += 1
            continue
        if node_pos and name in node_pos:
            # Provided positions must already be in ORIGINAL coords.
            x, y = node_pos[name][:2]  # accepts (x, y) or (x, y, fixed, ...)
            fixed_macro_positions[name] = (float(x), float(y), True)
            continue
        fallback_count += 1
        absl_logging.warning("Macro '%s' not in node_pos, using fallback!", name)
        fixed_macro_positions[name] = (
                float(initial_placedb.node_x[idx]) + xl,
                float(initial_placedb.node_y[idx]) + yl,
                True,
        )

    if unknown_count > 0:
        absl_logging.warning(
                "%d/%d macros not found in PlaceDB; left unchanged in .pl",
                unknown_count, len(macro_names_set))
    if fallback_count > 0:
        absl_logging.warning("Used fallback for %d/%d macros!",
                             fallback_count, len(macro_names_set))

    # Create the temp files without holding open handles; the rewrite helpers
    # open the paths themselves.
    pl_fd, pl_name = tempfile.mkstemp(suffix=".pl", dir=bench_dir)
    os.close(pl_fd)
    aux_fd, aux_name = tempfile.mkstemp(suffix=".aux", dir=bench_dir)
    os.close(aux_fd)
    temp_pl_path = Path(pl_name)
    temp_aux_path = Path(aux_name)

    # NOTE: Do NOT delete the temp files here. The original implementation
    # states DREAMPlace still needs them after PlaceDB.__call__() returns
    # (lazy loading/mmap; unverified). Files created with mkstemp are never
    # removed automatically, so register best-effort cleanup at interpreter
    # exit to keep them from accumulating in the benchmark directory.
    def _remove_temp_files() -> None:
        for path in (temp_pl_path, temp_aux_path):
            try:
                path.unlink()
            except OSError:
                pass  # already gone or not removable; nothing useful to do at exit

    atexit.register(_remove_temp_files)

    rewrite_pl_file(orig_pl_path, temp_pl_path, fixed_macro_positions)
    rewrite_aux_file(orig_aux, temp_aux_path, temp_pl_path)

    # Reload PlaceDB from the rewritten files so the /FIXED tags take effect.
    params.aux_input = str(temp_aux_path)

    placedb_fixed = PlaceDB.PlaceDB()
    placedb_fixed(params)

    absl_logging.info("PlaceDB initialized with %d macros fixed",
                      len(macro_names_set))
    return placedb_fixed


def convert_actions_to_positions(macro_names: list[str], actions: list[int],
        grid_cols: int, grid_rows: int,
        canvas_width: float, canvas_height: float) -> dict:
    """
    Map ChipFormer grid actions to canvas (x, y) positions for macros.

    Each action is a row-major cell index: ``row = action // grid_cols`` and
    ``col = action % grid_cols``. The returned position is that cell's
    lower-left corner, ``(col * canvas_width / grid_cols,
    row * canvas_height / grid_rows)``. Names and actions are paired with
    ``zip``, so extra entries in the longer list are ignored.

    Args:
        macro_names: Macro names in the order ChipFormer placed them
        actions: Grid actions from ChipFormer
        grid_cols: Number of grid columns
        grid_rows: Number of grid rows
        canvas_width: Canvas width in microns
        canvas_height: Canvas height in microns

    Returns:
        dict of {macro_name: (x, y)} positions
    """
    cell_w = canvas_width / grid_cols
    cell_h = canvas_height / grid_rows

    def cell_origin(action):
        row, col = divmod(action, grid_cols)
        return col * cell_w, row * cell_h

    return {name: cell_origin(action) for name, action in zip(macro_names, actions)}


def get_global_hpwl(
        params: Params.Params,
        node_pos: dict,
        placedb: PlaceDB.PlaceDB,
        output_dir: str = None,
        render_final: bool = False,
        freeze_macro_indices: list = None,
        freeze_macro_names: list = None,
        # Anchor/spring term params (alternative to hard gradient zeroing)
        anchor_node_indices: list = None,  # list of node indices to anchor
        anchor_weight: float = None,  # spring weight (higher = stiffer)
) -> dict:
    """
    Run DREAMPlace starting from the given macro positions and collect metrics.

    The macro constraint mode is chosen in this priority order:

    1. Anchor mode -- ``params.anchor_flag`` is truthy AND
       ``anchor_node_indices`` is non-empty:
       - Targets for the anchored nodes are read from ``node_pos`` and stored
         as ``params.anchor_positions`` (float32 tensor on 'cuda' if
         ``params.gpu`` else 'cpu'), together with
         ``params.anchor_node_indices`` and ``params.anchor_weight``
         (defaults to 1.0).
       - ``params.fix_node_indices`` is cleared, so macro gradients flow.
       - The spring penalty itself lives in the imported DREAMPlace build (not
         visible here); presumably
         ``anchor_weight * sum((x - x0)^2 + (y - y0)^2)`` -- confirm there.
       - ``node_pos`` is NOT written into ``placer.pos`` after construction;
         DREAMPlace's own init is kept and the anchor term pulls macros in.

    2. Freeze mode -- ``freeze_macro_indices`` is non-empty and anchor mode
       is not active:
       - ``params.fix_node_indices = freeze_macro_indices``.
       - ``node_pos`` is written into ``placer.pos`` after construction, so the
         frozen macros start (and stay) at the requested positions.
       - Legalization follows ``params.legalize_flag`` /
         ``params.detailed_place_flag`` unchanged. NOTE(review): whether the
         legalizer can still nudge frozen macros is up to DREAMPlace.

    3. All movable -- neither of the above:
       - ``params.fix_node_indices = []``; ``node_pos`` is written into
         ``placer.pos`` as the macros' starting positions.

    Args:
        params: DREAMPlace parameters (deep-copied; the caller's object is
            not mutated).
        node_pos: Dict mapping node names to (x, y) positions.
        placedb: PlaceDB instance (macros should be in movable range).
        output_dir: If given, used as ``params.result_dir`` and created.
        render_final: Whether to render the final placement (needs
            ``output_dir``).
        freeze_macro_indices: Node indices to freeze (mode 2).
        freeze_macro_names: Accepted for backward compatibility; not used.
        anchor_node_indices: Node indices to anchor (mode 1); targets come
            from ``node_pos``.
        anchor_weight: Spring constant for mode 1 (higher = stiffer).

    Returns:
        dict with keys:
            hpwl, congestion, macro_overlap, rmst_wl: from the placer's
                returned tuple (``inf`` when ``None``).
            density: from the final metric object when it has one, else from
                the placer tuple (``inf`` when neither).
            overflow, wirelength, max_density, objective: from the final
                metric object (``inf`` when missing).
            placedb: the PlaceDB passed in.
            processed_metrics: raw metrics returned by the placer.
            is_legal: DREAMPlace legality-check result when
                ``params.legalize_flag`` is set and the op exists; ``True``
                otherwise (i.e. "not checked" is reported as legal).
    """
    debug_placedb_extents(placedb, "Template")

    t0 = time.perf_counter()

    # Compatibility shim: NumPy 2.x removed ``np.string_``; restore the alias
    # for code paths that still reference it.
    if not hasattr(np, "string_"):
        np.string_ = np.bytes_

    # CRITICAL: Deep copy params to avoid state leakage between calls
    params = copy.deepcopy(params)

    # Update result_dir if output_dir provided (for dynamic per-run directories)
    if output_dir:
        params.result_dir = output_dir
        Path(output_dir).mkdir(parents=True, exist_ok=True)

    # ---- Determine the macro constraint mode ----
    has_freeze_indices = (freeze_macro_indices is not None
                          and len(freeze_macro_indices) > 0)
    # Anchor mode requires BOTH the params flag AND a non-empty index list.
    anchor_flag = getattr(params, 'anchor_flag', False)
    use_anchor = (anchor_flag and anchor_node_indices is not None
                  and len(anchor_node_indices) > 0)

    if use_anchor:
        # Case 1: build anchor targets from node_pos for the anchored nodes.
        positions_list = []
        for idx in anchor_node_indices:
            name = placedb.node_names[idx]
            if isinstance(name, bytes):
                name = name.decode()
            pos = node_pos[name]
            positions_list.append([pos[0], pos[1]])
        params.anchor_positions = torch.tensor(
                positions_list,
                dtype=torch.float32,
                device='cuda' if params.gpu else 'cpu'
        )
        params.anchor_node_indices = anchor_node_indices
        params.anchor_weight = anchor_weight if anchor_weight is not None else 1.0
        # With anchor term, we DON'T use fix_node_indices - let gradients flow
        params.fix_node_indices = []
        params.skip_macro_legalization = False
        params.skip_macro_stage = False
        absl_logging.info(
                "[DREAMPlace] Anchor mode: %d macros with spring weight %.2e (gradients enabled)",
                len(anchor_node_indices), params.anchor_weight)
    elif has_freeze_indices:
        # Case 2: Freeze macros but run normal legalization (no Phase 2)
        params.fix_node_indices = freeze_macro_indices
        params.skip_macro_legalization = False  # Snap macros to grid for legal placement
        params.skip_macro_stage = False  # Run macro stage for halo inflation (macros still frozen via fix_node_indices)
        # Keep legalize_flag/detailed_place_flag from params - don't override!
        absl_logging.info(
                "[DREAMPlace] Freezing %d macros with halo + grid snap, normal legalization will run",
                len(freeze_macro_indices))
    else:
        # Case 3: All movable
        params.fix_node_indices = []
        params.skip_macro_legalization = False
        params.skip_macro_stage = False
        absl_logging.info(
                "[DREAMPlace] All macros movable")

    # Directly set positions on placedb (no file I/O needed)
    for name, pos in node_pos.items():
        set_position(placedb, name, pos[0], pos[1])

    debug_placedb_extents(placedb, "After setting positions")

    placer = NonLinearPlace.NonLinearPlace(params, placedb, None)

    # NonLinearPlace reinitializes positions during construction, so node_pos
    # must be written into placer.pos afterwards (stdcells keep DREAMPlace's
    # init). Skip this ONLY when anchor mode is really active: then the
    # anchor loss is meant to pull macros toward their targets. Gating on
    # ``use_anchor`` (not the raw params flag) matters: with the flag set but
    # no anchor indices, this function configures no anchor targets, and
    # skipping the restore would leave frozen macros (fix_node_indices) stuck
    # at DREAMPlace's init positions instead of node_pos.
    if not use_anchor:
        num_nodes = placedb.num_nodes
        with torch.no_grad():
            for name, pos in node_pos.items():
                node_id = placedb.node_name2id_map[name]
                # placer.pos[0] stores all x coordinates, then all y coordinates.
                placer.pos[0][node_id] = pos[0]
                placer.pos[0][num_nodes + node_id] = pos[1]
    else:
        absl_logging.info(
                "[DREAMPlace] anchor_flag=True: keeping DREAMPlace default init, anchor loss pulls toward VLM positions")

    learning_rate = params.global_place_stages[0]['learning_rate']
    metrics = placer(params, placedb, learning_rate)

    # Unpack tuple: (rmst_wl, hpwl, congestion, density, macro_overlap, processed_metrics)
    rmst_wl, hpwl, congestion, density, overlap, processed_metrics = metrics
    # NOTE: placedb.node_x/node_y are already updated by placedb.apply() inside NonLinearPlace

    # hpwl may be None (the result dict handles that); pre-format it so the
    # log calls below never hit "%.0f" with None.
    hpwl_str = "%.0f" % hpwl if hpwl is not None else "N/A"

    # processed_metrics may be a flat list or nested lists of metric objects
    # (extract_metrics() follows nesting the same way): walk the last element
    # down to a metric object; an empty list yields None.
    last_metric = processed_metrics
    while isinstance(last_metric, (list, tuple)):
        last_metric = last_metric[-1] if last_metric else None

    # ============== LEGALITY CHECK ==============
    # NonLinearPlace can fail legalization silently (logs error but returns illegal placement)
    # Check legality explicitly so we know if the HPWL is from a legal placement
    is_legal = True
    if params.legalize_flag:
        if placer.op_collections.legality_check_op is not None:
            is_legal = bool(
                    placer.op_collections.legality_check_op(placer.pos[0]))
            if not is_legal:
                absl_logging.warning(
                        "[DREAMPlace] LEGALITY CHECK FAILED - placement is ILLEGAL! "
                        "HPWL=%s is from an illegal placement.", hpwl_str)
            else:
                absl_logging.info("[DREAMPlace] Legality check PASSED")
        else:
            absl_logging.warning(
                    "[DREAMPlace] legality_check_op is None, cannot verify legality")

    # Standard flow: render final placement if requested
    if render_final and output_dir:
        final_iteration = getattr(last_metric, 'iteration', None)
        if final_iteration is None:
            final_iteration = 9999  # Fallback if no metrics

        plot_dir = Path(output_dir) / params.design_name() / "plot"
        plot_dir.mkdir(parents=True, exist_ok=True)
        figname = plot_dir / f"iter{final_iteration:04d}.png"

        final_pos = placer.pos[0].data.clone().cpu()

        tt = time.time()
        placer.op_collections.draw_place_op(final_pos, str(figname))
        absl_logging.info(
                "Final placement plot (iter%04d) saved to %s (%.3f s)",
                final_iteration, figname, time.time() - tt)

    # Clean up placer object to free GPU memory
    del placer
    gc.collect()  # Force Python GC to actually free the placer before clearing cache
    torch.cuda.empty_cache()

    # Format None values before logging
    cong_str = "%.6f" % congestion if congestion is not None else "N/A"
    dens_str = "%.6f" % density if density is not None else "N/A"
    ovlp_str = "%.6f" % overlap if overlap is not None else "N/A"
    absl_logging.info(
            "[DREAMPlace] HPWL = %s, congestion = %s, density = %s, overlap = %s",
            hpwl_str, cong_str, dens_str, ovlp_str)

    elapsed = time.perf_counter() - t0
    absl_logging.info("[DREAMPlace] Total placement time: %.2f seconds",
                      elapsed)

    def _to_float(value):
        # None -> +inf so "missing" never looks like a good (small) score.
        return float(value) if value is not None else float('inf')

    # Build result dict with all metrics
    result = {
            'hpwl': _to_float(hpwl),
            'congestion': _to_float(congestion),
            'density': _to_float(density),
            'macro_overlap': _to_float(overlap),
            'rmst_wl': _to_float(rmst_wl),
            'overflow': float('inf'),
            'wirelength': float('inf'),
            'max_density': float('inf'),
            'objective': float('inf'),
            'placedb': placedb,
            'processed_metrics': processed_metrics,
            'is_legal': is_legal,
    }

    # Extract additional metrics from final iteration
    if last_metric is not None:
        def _metric(attr_name, fallback=float('inf')):
            val = getattr(last_metric, attr_name, None)
            return float(val) if val is not None else fallback

        result.update({
                'overflow': _metric('overflow'),
                'wirelength': _metric('wirelength'),
                # Prefer the final metric's density, but keep the
                # placer-reported value instead of clobbering it with inf.
                'density': _metric('density', result['density']),
                'max_density': _metric('max_density'),
                'objective': _metric('objective'),
        })

    return result


def extract_metrics(metrics: Any) -> Tuple[float, float, float]:
    """Pull ``(hpwl, route_utilization, overflow)`` out of a placer result.

    Accepted inputs:
      * the 6-tuple ``(rmst_wl, hpwl, congestion, density, macro_overlap,
        processed_metrics)``; its last element is inspected;
      * a (possibly nested) list of metric objects; at each level the last
        element is followed, and an empty list means "no metric object";
      * a single metric object.

    A field that is missing or ``None`` on the metric object falls back to a
    default: hpwl -> ``-inf``, route_utilization -> ``inf``, overflow ->
    ``inf``. For 6-tuple input, the tuple's own hpwl (index 1) is used when
    the metric object provides none.

    Returns:
        ``(hpwl, route_utilization, overflow)`` as floats.
    """
    m = metrics
    tuple_hpwl = None

    # Handle new tuple format: (rmst_wl, hpwl, congestion, density, macro_overlap, processed_metrics)
    if isinstance(m, tuple) and len(m) == 6:
        tuple_hpwl = m[1]
        m = m[-1]

    route_utilization = float('inf')
    overflow = float('inf')
    # NOTE(review): a missing hpwl defaults to -inf (unlike the +inf used for
    # the other fields); a caller minimizing HPWL would rank that as best.
    # Kept for backward compatibility -- confirm intent.
    hpwl = -float('inf')

    # Follow the last element down nested lists; an empty list previously
    # raised IndexError and now yields the defaults instead.
    while isinstance(m, list):
        m = m[-1] if m else None

    m_hpwl = getattr(m, 'hpwl', None)
    if m_hpwl is not None:
        hpwl = float(m_hpwl)
    elif tuple_hpwl is not None:
        hpwl = float(tuple_hpwl)

    m_overflow = getattr(m, 'overflow', None)
    if m_overflow is not None:
        overflow = float(m_overflow)

    m_route = getattr(m, 'route_utilization', None)
    if m_route is not None:
        route_utilization = float(m_route)

    return hpwl, route_utilization, overflow
