from pathlib import Path
import pickle

import numba
import numpy as np
from absl import logging

from veoplace.utils.benchmark_registry import config_dir


@numba.njit(cache=True, fastmath=True)
def fill_mask_regions(mask: np.ndarray,
        start_x: np.ndarray, start_y: np.ndarray,
        end_x: np.ndarray, end_y: np.ndarray) -> np.ndarray:
    """
    Set one rectangular region of ``mask`` to 1 per node (in place).

    For node ``i`` the region covers ``start_x[i]`` through ``end_x[i]`` on
    the first axis and ``start_y[i]`` through ``end_y[i]`` on the second
    axis, both ends inclusive. The four index arrays are read in lockstep,
    so they are expected to have the same length.

    Returns the same ``mask`` array that was passed in.
    """
    for node in range(len(start_x)):
        # End indices are inclusive, hence the +1 when turning them into
        # Python-style half-open slice bounds.
        x_lo, x_hi = start_x[node], end_x[node] + 1
        y_lo, y_hi = start_y[node], end_y[node] + 1
        mask[x_lo:x_hi, y_lo:y_hi] = 1
    return mask


@numba.njit(cache=True, fastmath=True)
def calculate_all_net_costs(net_img: np.ndarray,
        start_x: np.ndarray, end_x: np.ndarray,
        start_y: np.ndarray, end_y: np.ndarray,
        net_weights: np.ndarray, grid: int) -> np.ndarray:
    """
    Accumulate Manhattan fan-out costs for every net (in place).

    For each net, every row outside its bounding box receives
    ``distance_to_box * weight`` and every column outside the box receives
    the same along the other axis. The cost is separable, so the per-row and
    per-column totals are first summed over all nets in two 1-D buffers and
    then added to ``net_img`` in a single pass. This costs
    O(N * grid + rows * cols) instead of rewriting whole rows/columns of the
    image once per net.

    Args
    ----
    net_img : 2-D buffer to update (documented by callers as (grid, grid)
        float32)
    start_x, end_x, start_y, end_y : (N,) integer bounding-box edges,
        expected to be already clipped so that start is in 0..grid and end
        is in -1..grid-1
    net_weights : (N,) net weights
    grid : int, grid dimension (exclusive upper bound of the far-side loops)

    Returns
    -------
    The same ``net_img`` array that was passed in.

    Notes
    -----
    Totals are accumulated in float64 and added to each cell once, so values
    may differ from a per-net accumulation in the last bits of a float32
    buffer when the sums are not exactly representable.
    """
    n_rows, n_cols = net_img.shape
    # Sized from the image (not from `grid`) so every index below addresses
    # the same row/column it would address in `net_img` itself.
    row_cost = np.zeros(n_rows, dtype=np.float64)
    col_cost = np.zeros(n_cols, dtype=np.float64)

    n = start_x.shape[0]
    for k in range(n):
        sxi, exi = start_x[k], end_x[k]  # 0..grid  and  -1..grid-1  by contract
        syi, eyi = start_y[k], end_y[k]
        wi = net_weights[k]

        # X-fan-out (rows)
        for i in range(sxi):  # rows 0 .. sxi-1
            row_cost[i] += (sxi - i) * wi
        for i in range(exi + 1, grid):  # rows exi+1 .. grid-1
            row_cost[i] += (i - exi) * wi

        # Y-fan-out (columns)
        for j in range(syi):  # cols 0 .. syi-1
            col_cost[j] += (syi - j) * wi
        for j in range(eyi + 1, grid):  # cols eyi+1 .. grid-1
            col_cost[j] += (j - eyi) * wi

    # Single pass over the image: each cell gets its row total plus its
    # column total.
    for i in range(n_rows):
        for j in range(n_cols):
            net_img[i, j] += row_cost[i] + col_cost[j]

    return net_img


def get_hard_macro_ordering_from_placedb(
        benchmark_name: str,
        placedb,
        use_cache: bool = True,
) -> tuple[list[str], dict[str, int]]:
    """
    Compute hard macro ordering from PlaceDB (no plc needed).
    Used for DEF/LEF benchmarks that don't have protobuf.

    Macros are ordered by area (descending), then by degree (descending),
    where a macro's degree is the number of nets it appears in. Ties keep
    their PlaceDB order because the sort is stable.

    Args:
        benchmark_name: Name of the benchmark
        placedb: DREAMPlace PlaceDB object. The attributes read here are
            movable_macro_mask, num_movable_nodes, num_nodes, node_names,
            node_size_x, node_size_y, net2pin_map and pin2node_map.
        use_cache: Whether to use/save the cache file
            ``config_dir(benchmark_name) / "macro_ordering.pkl"``. When
            False, the cache directory is neither looked up nor created.

    Returns:
        Tuple of (hard_macro_names, hard_macro_degrees)
        Note: degrees dict is keyed by name (str), not index (int)
    """
    cache_file = None
    if use_cache:
        out_dir = config_dir(benchmark_name)
        out_dir.mkdir(parents=True, exist_ok=True)
        cache_file = out_dir / "macro_ordering.pkl"

        if cache_file.exists():
            try:
                # NOTE(security): pickle can execute arbitrary code while
                # loading. The cache directory must only contain files
                # written by this project.
                with open(cache_file, 'rb') as f:
                    data = pickle.load(f)
                cached_names = data['names']
                cached_degrees = data['degrees']
                num_cached = len(cached_names)
            except (pickle.UnpicklingError, EOFError, AttributeError,
                    ImportError, IndexError, KeyError, TypeError) as err:
                # A truncated or incompatible cache file should not block
                # every later run; fall through and rebuild it.
                logging.warning(
                        f"Ignoring unreadable macro ordering cache {cache_file}: {err!r}")
            else:
                logging.info(
                        f"Loaded cached macro ordering for {benchmark_name}: {num_cached} macros")
                return cached_names, cached_degrees

    logging.info(
            f"Computing macro ordering from PlaceDB for {benchmark_name}")

    # Use PlaceDB's built-in movable_macro_mask (uses >= for height, matching PLC)
    hard_macro_indices = np.where(
            placedb.movable_macro_mask[:placedb.num_movable_nodes])[0]

    logging.info(f"Found {len(hard_macro_indices)} hard macros in PlaceDB")

    # DEF/LEF parser escapes brackets, so we unescape them for consistency with PLC
    hard_macro_names = [
            placedb.node_names[i].decode().replace('\\[', '[').replace('\\]', ']')
            for i in hard_macro_indices
    ]
    areas = (placedb.node_size_x[hard_macro_indices]
             * placedb.node_size_y[hard_macro_indices])

    # Degree = number of nets a node appears in. This is one Python-level
    # pass over the nets; np.unique ensures a node with several pins on the
    # same net is counted once for that net.
    degree_per_node = np.zeros(placedb.num_nodes, dtype=np.int32)
    for pin_ids in placedb.net2pin_map:
        nodes_in_net = np.unique(placedb.pin2node_map[pin_ids])
        degree_per_node[nodes_in_net] += 1

    # Extract degrees for macros only
    degrees = {name: int(degree_per_node[idx])
               for idx, name in zip(hard_macro_indices, hard_macro_names)}

    # Sort by area (desc), then degree (desc)
    name_to_area = dict(zip(hard_macro_names, areas))
    sorted_names = sorted(
            hard_macro_names,
            key=lambda name: (-name_to_area[name], -degrees[name])
    )

    # Cache for next time
    if cache_file is not None:
        cache_data = {'names': sorted_names, 'degrees': degrees}
        with open(cache_file, 'wb') as f:
            pickle.dump(cache_data, f)
        logging.info(
                f"Computed and cached macro ordering for {benchmark_name}: {len(sorted_names)} macros")

    return sorted_names, degrees
