import random
import time

import numpy as np
from absl import logging
from scipy.spatial.distance import pdist
from scipy.stats import rankdata
from sklearn.cluster import MiniBatchKMeans

from veoplace.utils import format_scientific
from veoplace.utils.constants import EPS_GREEDY
from veoplace.utils.constants import KMEANS_BATCH_SIZE
from veoplace.utils.constants import KMEANS_NUM_ITERATIONS
from veoplace.utils.constants import MAX_HISTORY_CLUSTERS
from veoplace.utils.constants import RANK_SOFTMAX_TEMP
from veoplace.utils.dp_array_cache import enrich_episode_for_render
from veoplace.utils.render import render_full_canvas


def _require_dp_render_arrays(context: dict) -> tuple:
    dp_size_x = context.get("dreamplace_node_size_x")
    dp_size_y = context.get("dreamplace_node_size_y")
    dp_num_movable = context.get("dreamplace_num_movable")
    dp_num_terminals = context.get("dreamplace_num_terminals")

    if dp_size_x is None or dp_size_y is None:
        raise ValueError("Prompt rendering requires DREAMPlace size arrays in context")
    if dp_num_movable is None:
        raise ValueError("Prompt rendering requires dreamplace_num_movable in context")
    if dp_num_terminals is None:
        raise ValueError("Prompt rendering requires dreamplace_num_terminals in context")

    return dp_size_x, dp_size_y, dp_num_movable, dp_num_terminals


def limit_history_by_quality(history, history_limit):
    """Cap the history at ``MAX_HISTORY_CLUSTERS * history_limit`` episodes.

    This is a heuristic that keeps later clustering fast.  A history already
    within the cap is returned unchanged (same object); otherwise a new list
    with the episodes that sort first by ``(macro_overlap, hpwl)`` is
    returned, where a missing ``macro_overlap`` counts as 0.
    """
    max_episodes = MAX_HISTORY_CLUSTERS * history_limit
    num_episodes = len(history)

    if num_episodes > max_episodes:
        logging.info(
                "History is too large (%d) to use all episodes. Limiting to %d best episodes.",
                num_episodes, max_episodes)
        ranked = list(history)
        ranked.sort(key=lambda ep: (ep.get("macro_overlap", 0), ep["hpwl"]))
        return ranked[:max_episodes]

    logging.info(
            "History is small enough (%d) to not need limiting. Returning all episodes.",
            num_episodes)
    return history


def history_variable_regions(**kwargs):
    """
    Assemble the "previous episodes" prompt section for the *selected* macros.

    Episodes are produced by the strategy registered under
    ``kwargs["history_mode"]`` in ``HISTORY_STRATEGIES`` and are listed in
    ascending ``(macro_overlap, hpwl)`` order.  For every episode the section
    contains the grid-cell span of each selected macro, a rendered canvas and
    the final metrics.  Only the lightweight context fields unpacked from
    **kwargs are used - no env object required.

    Returns:
        ``(prompt, kwargs)``: ``prompt`` is the list of prompt parts (strings
        plus whatever ``render_full_canvas`` returns for each episode) and
        ``kwargs`` is the incoming context with ``"processed_history"`` set
        to the sorted episodes.
    """
    # --------- context fields, looked up in the original order -------------
    mode = kwargs["history_mode"]
    chosen_macros = kwargs["selected_macros"]  # list[str]

    macro_names = kwargs["node_name_list"]
    num_placed = kwargs["placed_num_macro"]
    index_of = kwargs["node_to_idx"]
    grid_dims = kwargs["node_dims_int"]  # (M,2) ints
    real_dims = kwargs["node_dims_real"]  # (M,2) µm
    ratio_x = kwargs["ratio_x"]
    ratio_y = kwargs["ratio_y"]
    short_of = kwargs["node_name_to_short_name"]
    palette = kwargs["color_config"]

    # --------- static prompt text -------------------------------------------
    intro = (
        "Below are previous episodes with their final results. For each episode, you'll see:\n\n"
        " - **Macro Positions**: Shows where the selected macros you need to place were put on the canvas in previous episodes\n"
        " - **Canvas Image**: Shows the final state of the canvas with:\n"
        "     * The names of each macro you need to place drawn directly on the macro\n"
        "     * These selected macros outlined in red for easy identification\n"
        " - **Final Metrics**: The overall quality metrics of the completed chip design\n\n"
    )
    caption = (
        "\n - The image above shows the final placement with the selected "
        "macros you need to place outlined in red and labeled with their names. "
        "The blue texture represents standard cells placed by the analytical "
        "placer. Dark gray rectangles are fixed I/O terminals.\n\n"
    )

    parts = ["# PREVIOUS PLACEMENT EPISODES:\n\n", intro]

    episodes = HISTORY_STRATEGIES[mode](**kwargs)
    ranked = sorted(
        episodes,
        key=lambda ep: (ep.get("macro_overlap", 0), ep["hpwl"]),
    )
    kwargs["processed_history"] = ranked

    for number, ep in enumerate(ranked, start=1):
        placements = ep["positions"]  # (M,4)
        render_positions = {}

        parts.append(f"## Episode #{number}\n\n")
        parts.append("### Position of Selected Macros:\n")

        for slot, long_name in enumerate(macro_names[:num_placed]):
            # Every placed macro goes to the renderer, with its real size
            # appended to the stored position row.
            macro_idx = index_of[long_name]
            real_w, real_h = real_dims[macro_idx]
            render_positions[long_name] = (*placements[slot], real_w, real_h)

            # Only the selected macros get a human-readable line.
            if long_name not in chosen_macros:
                continue

            label = short_of[long_name]
            x, y, *_ = placements[slot]
            w_cells, h_cells = grid_dims[macro_idx]
            cell_x = int(round(x / ratio_x))
            cell_y = int(round(y / ratio_y))
            far_x = cell_x + w_cells
            far_y = cell_y + h_cells
            parts.append(
                f" - {label}: ({cell_x},{cell_y}) to ({far_x},{far_y})\n"
            )

        parts.append("\n### Canvas Image:\n")

        # Per-episode DREAMPlace node coordinates are loaded on demand from
        # the episode plus output_dir; the size arrays and counts come from
        # the shared context.
        enriched = enrich_episode_for_render(ep, kwargs.get("output_dir"))
        size_x, size_y, num_movable, num_terminals = _require_dp_render_arrays(kwargs)

        canvas = render_full_canvas(
            node_pos=render_positions,
            color_config=palette,
            grid=ep["grid"],
            max_width=ep["max_width"],
            max_height=ep["max_height"],
            return_bytes=False,
            node_name_to_short_name=short_of,
            highlight_nodes=chosen_macros,
            dreamplace_node_x=enriched.get("dreamplace_node_x"),
            dreamplace_node_y=enriched.get("dreamplace_node_y"),
            dreamplace_node_size_x=size_x,
            dreamplace_node_size_y=size_y,
            dreamplace_num_movable=num_movable,
            dreamplace_num_terminals=num_terminals,
        )
        parts.append(canvas)
        parts.append(caption)

        parts.extend([
            f"### Results for Episode #{number}:\n",
            f"- Wirelength: {format_scientific(ep['hpwl'])}\n",
            f"- Macro Overlap: {format_scientific(ep.get('macro_overlap', 0.0))}\n\n",
            "---\n\n",
        ])

    return parts, kwargs


def history_all_regions(**kwargs):
    """
    Assemble the "previous episodes" prompt section for the colour-group view.

    Every episode in ``kwargs["history"]`` is listed in ascending
    ``(macro_overlap, hpwl)`` order with the stored position of each macro in
    ``selected_macros``, a rendered canvas that highlights the keys of
    ``first_macro_of_color_group``, and the final metrics.  Uses the
    lightweight context fields passed in **kwargs - no `env` object required.

    NOTE(review): the text lines are filtered by ``selected_macros`` while the
    canvas highlights ``first_macro_of_color_group``; presumably callers pass
    the first macro of each colour group as ``selected_macros`` - confirm.

    Returns:
        ``(prompt, kwargs)``: ``prompt`` is the list of prompt parts (strings
        plus whatever ``render_full_canvas`` returns for each episode) and
        ``kwargs`` is the incoming context with ``"processed_history"`` set
        to the sorted episodes.
    """
    # ---------- context fields, looked up in the original order ------------
    episodes = kwargs["history"]
    chosen_macros = kwargs["selected_macros"]
    num_placed = kwargs["placed_num_macro"]

    macro_names = kwargs["node_name_list"]  # list[str]
    short_of = kwargs["node_name_to_short_name"]  # {long:short}
    index_of = kwargs["node_to_idx"]  # {long:int}
    grid_dims = kwargs["node_dims_int"]  # (M,2) int
    palette = kwargs["color_config"]
    group_leaders = kwargs["first_macro_of_color_group"]

    # ---------- static prompt text ------------------------------------------
    intro = (
        "Below are previous episodes with their final results. For each episode, you'll see:\n\n"
        " - **Position of First Macro for Each Color Group**: Shows where the first macro of each color group was placed\n"
        " - **Canvas Image**: Shows the final state of the canvas with:\n"
        "     * The names of the first macro of each color group drawn directly on the macro\n"
        "     * These first macros outlined in red for easy identification\n"
        " - **Final Metrics**: The overall quality metrics of the completed chip design\n\n"
    )
    caption = (
        "\n - The image above shows the final placement with the first "
        "macro of each color group outlined in red and labeled with its name. "
        "The blue texture represents standard cells placed by the analytical "
        "placer. Dark gray rectangles are fixed I/O terminals.\n\n"
    )

    parts = ["# PREVIOUS PLACEMENT EPISODES:\n\n", intro]

    ranked = sorted(
        episodes,
        key=lambda ep: (ep.get("macro_overlap", 0), ep["hpwl"]),
    )
    kwargs["processed_history"] = ranked

    for number, ep in enumerate(ranked, start=1):
        placements = ep["positions"]  # (M,4) array for that ep
        render_positions = {}

        parts.append(f"## Episode #{number}\n\n")
        parts.append("### Position of First Macro for Each Color Group:\n")

        # Walk the placed macros in their fixed order.
        for slot, long_name in enumerate(macro_names[:num_placed]):
            render_positions[long_name] = placements[slot]

            if long_name not in chosen_macros:
                continue

            label = short_of[long_name]
            x, y, *_ = placements[slot]

            macro_idx = index_of[long_name]
            w_cells, h_cells = grid_dims[macro_idx]  # grid units
            far_x = x + w_cells
            far_y = y + h_cells

            parts.append(
                f" - {label} occupies ({x},{y}) to ({far_x},{far_y})\n"
            )

        parts.append("\n### Canvas Image:\n")

        # Per-episode DREAMPlace node coordinates are loaded on demand from
        # the episode plus output_dir; the size arrays and counts come from
        # the shared context.
        enriched = enrich_episode_for_render(ep, kwargs.get("output_dir"))
        size_x, size_y, num_movable, num_terminals = _require_dp_render_arrays(kwargs)

        canvas = render_full_canvas(
            node_pos=render_positions,
            color_config=palette,
            grid=ep["grid"],
            max_width=ep["max_width"],
            max_height=ep["max_height"],
            return_bytes=False,
            node_name_to_short_name=short_of,
            highlight_nodes=group_leaders.keys(),
            dreamplace_node_x=enriched.get("dreamplace_node_x"),
            dreamplace_node_y=enriched.get("dreamplace_node_y"),
            dreamplace_node_size_x=size_x,
            dreamplace_node_size_y=size_y,
            dreamplace_num_movable=num_movable,
            dreamplace_num_terminals=num_terminals,
        )
        parts.append(canvas)
        parts.append(caption)

        parts.extend([
            f"### Results for Episode #{number}:\n",
            f"- Wirelength: {format_scientific(ep['hpwl'])}\n",
            f"- Macro Overlap: {format_scientific(ep.get('macro_overlap', 0.0))}\n\n",
            "---\n\n",
        ])

    return parts, kwargs


def first_in_first_out_history(**kwargs):
    """Return the last ``history_limit`` entries of ``history``.

    A non-positive ``history_limit`` yields an empty list.
    """
    episodes, limit = kwargs['history'], kwargs['history_limit']
    return episodes[-limit:] if limit > 0 else []


def best_history(**kwargs):
    """Return the first ``history_limit`` episodes in ascending
    ``(macro_overlap, hpwl)`` order.

    A missing ``macro_overlap`` counts as 0.  A non-positive
    ``history_limit`` yields an empty list.
    """
    episodes = kwargs['history']
    limit = kwargs['history_limit']
    if limit > 0:
        ranked = list(episodes)
        ranked.sort(key=lambda ep: (ep.get('macro_overlap', 0), ep['hpwl']))
        return ranked[:limit]
    return []


def worst_history(**kwargs):
    """Return the last ``history_limit`` episodes in ascending
    ``(macro_overlap, hpwl)`` order, so the very worst episode comes last.

    A missing ``macro_overlap`` counts as 0.  A non-positive
    ``history_limit`` yields an empty list.
    """
    episodes = kwargs['history']
    limit = kwargs['history_limit']
    if limit > 0:
        ranked = list(episodes)
        ranked.sort(key=lambda ep: (ep.get('macro_overlap', 0), ep['hpwl']))
        return ranked[-limit:]
    return []


def half_best_half_worst(**kwargs):
    """Return the best and the worst episodes, about half of each, without repeats.

    Episodes are ranked in ascending ``(macro_overlap, hpwl)`` order (the same
    key used by ``best_history`` / ``worst_history``; a missing
    ``macro_overlap`` counts as 0).  The first ``history_limit - history_limit // 2``
    ranked episodes form the best half; the last ``history_limit // 2`` of the
    *remaining* episodes form the worst half.

    Taking the worst half from the remainder matters when the history holds
    fewer than ``history_limit`` episodes: slicing both halves from the full
    ranking would return the overlapping episodes twice.  When the history is
    at least ``history_limit`` long the result is the same as slicing both
    ends of the full ranking.

    Returns:
        ``best_half + worst_half`` as a new list, or ``[]`` when
        ``history_limit`` is not positive.
    """
    history = kwargs['history']
    history_limit = kwargs['history_limit']
    if history_limit <= 0:
        return []

    num_worst = history_limit // 2
    num_best = history_limit - num_worst

    ranked = sorted(history,
                    key=lambda ep: (ep.get('macro_overlap', 0), ep['hpwl']))
    best_half = ranked[:num_best]
    remainder = ranked[num_best:]
    # Guard num_worst == 0 explicitly: remainder[-0:] would be the whole list.
    worst_half = remainder[-num_worst:] if num_worst else []
    return best_half + worst_half


def random_history(**kwargs):
    """Return up to ``history_limit`` episodes drawn without replacement.

    Sampling uses the module-level ``random`` generator.  A non-positive
    ``history_limit`` yields an empty list; a limit larger than the history
    returns every episode in random order.
    """
    episodes = kwargs['history']
    limit = kwargs['history_limit']
    if limit > 0:
        return random.sample(episodes, min(limit, len(episodes)))
    return []


def _cluster_and_select_probs(placement_node_array, history, history_limit,
        seed=0, description="", ):
    """
    Helper function that performs clustering on placement vectors and selects episodes
    using softmax sampling over cluster quality (rank-based).

    Args:
        placement_node_array: Numpy array of shape (num_episodes, feature_dimensions)
        history: List of episode dictionaries
        history_limit: Maximum number of episodes to return
        seed: Random seed for clustering
        description: Description string for logging

    Returns:
        List of selected episode dictionaries
    """

    wirelengths = [episode['hpwl'] for episode in history]
    num_history = len(history)

    # Calculate number of clusters with proper bounds
    num_clusters = min(
            max(1, num_history // history_limit),
            # At least 1, targeting len/limit clusters
            MAX_HISTORY_CLUSTERS,  # Don't exceed maximum allowed clusters
            num_history  # Can't have more clusters than items
    )

    # --- Handle edge case: 0 or 1 cluster ---
    if num_clusters <= 1:
        logging.info(
                f"Only {num_clusters} cluster(s) found. Selecting best {history_limit} episodes.")
        sorted_indices = sorted(range(num_history),
                                key=lambda idx: wirelengths[idx])
        selected_indices = sorted_indices[:min(history_limit, num_history)]

        # Log diversity etc. if applicable (adapted from end of function)
        if len(selected_indices) >= 2 and placement_node_array is not None:
            selected_placements = placement_node_array[selected_indices]
            # Check if selected_placements has enough samples and dimensions for pdist
            if selected_placements.shape[0] >= 2 and selected_placements.shape[
                1] > 0:
                try:
                    pairwise_dists = pdist(selected_placements, 'euclidean')
                    diversity = np.mean(pairwise_dists)
                    logging.info("Selected episodes diversity score: %.2f",
                                 diversity)
                except ValueError:
                    logging.warning(
                            "Could not calculate diversity in 0/1 cluster case: placement data not suitable for pdist.")

        return [history[idx] for idx in selected_indices]

    # --- Proceed with Clustering for num_clusters > 1 ---
    start = time.perf_counter()
    kmeans = MiniBatchKMeans(
            max_iter=KMEANS_NUM_ITERATIONS,
            batch_size=KMEANS_BATCH_SIZE,
            n_clusters=num_clusters,
            random_state=seed,
    )
    clusters = kmeans.fit_predict(placement_node_array)
    end = time.perf_counter()
    logging.info(
            f"Clustering {num_history} episodes into {num_clusters} clusters{description} took {end - start:.2f} seconds")

    # Analyze clusters to find min wirelength for each
    cluster_min_values = []
    cluster_sizes = []
    cluster_indices_by_id = []  # Store indices for each cluster

    for cluster_id in range(num_clusters):
        cluster_indices = np.where(clusters == cluster_id)[0]
        cluster_indices_by_id.append(cluster_indices)
        cluster_sizes.append(len(cluster_indices))
        cluster_wirelengths = [float(wirelengths[idx]) for idx in
                               cluster_indices]
        if len(cluster_wirelengths) > 0:
            cluster_min_values.append(np.min(cluster_wirelengths))
        else:
            cluster_min_values.append(float('inf'))

    cluster_min_values = np.array(cluster_min_values)

    # For logging: identify best cluster by min wirelength (still useful)
    best_cluster_id = np.argmin(cluster_min_values)

    # --- Rank-Based Softmax Sampling ---

    # Handle edge case: all cluster minimums are identical
    if np.all(cluster_min_values == cluster_min_values[0]):
        logging.info(
                "All cluster minimum wirelengths are identical. Sampling uniformly.")
        probs = np.ones(num_clusters) / num_clusters
        selected_cluster_id = np.random.choice(range(num_clusters), p=probs)
        logging.info(
                f"Sampled cluster {selected_cluster_id} with uniform probability {probs[selected_cluster_id]:.4f}")

    else:
        # Use rank-based softmax sampling to select a cluster
        # Lower wirelength is better, so we rank based on values,
        # and assign scores such that better (lower) wirelength rank gets a higher score.
        # rankdata assigns rank 1 to the smallest value.
        # We use 'average' method for ties.
        ranks = rankdata(cluster_min_values, method='average')

        # Create scores from ranks: higher score for lower rank (better wirelength)
        # scores = -ranks => Rank 1 gets score -1, Rank 2 gets score -2, etc.
        scores = -ranks

        # Apply softmax with temperature
        # Use the defined RANK_SOFTBOX_TEMP
        # Subtract max score for numerical stability (max score is min_rank * -1)
        max_score = np.max(scores)
        shifted_scores = (
                                 scores - max_score) / RANK_SOFTMAX_TEMP  # Use the dedicated temperature
        exp_shifted_scores = np.exp(shifted_scores)

        # Handle potential numerical issues if exp_shifted_scores sum is zero
        sum_exp_scores = exp_shifted_scores.sum()
        if sum_exp_scores < 1e-9:
            logging.warning(
                    "Softmax sum is near zero, using uniform probabilities.")
            probs = np.ones(num_clusters) / num_clusters
        else:
            probs = exp_shifted_scores / sum_exp_scores

        # Sample a cluster according to the softmax distribution
        selected_cluster_id = np.random.choice(range(num_clusters), p=probs)

        # Detailed logging (adapted from original, using ranks)
        # best_cluster_id is already calculated

        logging.info("Cluster min values: %s", [f"{v:.2f}" for v in
                                                cluster_min_values.tolist()])  # Format for readability
        logging.info("Cluster ranks (1=best): %s", [f"{r:.1f}" for r in
                                                    ranks.tolist()])  # Format for readability
        logging.info("Cluster sizes: %s", [f"{s}" for s in cluster_sizes])

        logging.info(
                f"Cluster selection probabilities (Rank Softmax, Temp={RANK_SOFTMAX_TEMP}): {probs.round(4).tolist()}")
        logging.info(
                f"Sampled cluster {selected_cluster_id} with probability {probs[selected_cluster_id]:.4f}  and size {cluster_sizes[selected_cluster_id]} " +
                f"(Rank {ranks[selected_cluster_id]:.1f})")

        if selected_cluster_id == best_cluster_id:
            logging.info("Selected the best cluster (by wirelength rank)")
        else:
            # Find rank of sampled cluster (only if not uniform case)
            sampled_rank = ranks[selected_cluster_id]
            best_rank = ranks[best_cluster_id]
            logging.info(
                    f"Selected cluster {selected_cluster_id} (Rank {sampled_rank:.1f}) " +
                    f"with min wirelength {cluster_min_values[selected_cluster_id]:.2f} " +
                    f"(best was cluster {best_cluster_id} Rank {best_rank:.1f} with {cluster_min_values[best_cluster_id]:.2f})")

    # ------------------------------------------------------
    # From selected cluster, take up to history_limit best rollouts.
    # If needed, dip into the most similar clusters based on distance.
    # ------------------------------------------------------
    # This part remains the same, using the selected_cluster_id determined above
    selected_indices = []

    # First, add episodes from the selected cluster
    selected_cluster_indices = cluster_indices_by_id[selected_cluster_id]
    sorted_selected_indices = sorted(selected_cluster_indices,
                                     key=lambda idx: wirelengths[idx])
    selected_indices.extend(
            sorted_selected_indices[
            :min(history_limit, len(sorted_selected_indices))])

    # If we need more episodes, select from the most similar clusters
    if len(selected_indices) < history_limit:
        # Get the center of the selected cluster
        selected_center = kmeans.cluster_centers_[selected_cluster_id]

        # Calculate distances from each cluster center to the selected cluster center
        center_distances = []
        for i in range(num_clusters):
            if i == selected_cluster_id:
                center_distances.append(
                        float('inf'))  # Don't want to select from this again
            else:
                center = kmeans.cluster_centers_[i]
                distance = np.linalg.norm(center - selected_center)
                center_distances.append(distance)

        # Sort clusters by distance
        clusters_by_distance = list(range(num_clusters))
        # Ensure the selected cluster isn't in the list before removing
        if selected_cluster_id in clusters_by_distance:
            clusters_by_distance.remove(selected_cluster_id)
        clusters_by_distance.sort(key=lambda i: center_distances[i])

        # Dip into clusters by increasing distance
        for cluster_id in clusters_by_distance:
            # Get indices from this cluster sorted by wirelength
            cluster_indices = cluster_indices_by_id[cluster_id]
            sorted_indices = sorted(cluster_indices,
                                    key=lambda idx: wirelengths[idx])

            # Add as many episodes as needed
            remaining = history_limit - len(selected_indices)
            if remaining <= 0:
                break

            logging.info(
                    "Dipping into cluster %d for %d additional episodes. Currently have %d/%d episodes.",
                    cluster_id, min(remaining, len(sorted_indices)),
                    len(selected_indices), history_limit)

            selected_indices.extend(sorted_indices[:remaining])

            if len(selected_indices) == history_limit:
                break

    # Selected episodes info
    selected_wirelengths = [wirelengths[idx] for idx in selected_indices]

    logging.info("Selected episode indices: %s", selected_indices)
    logging.info("Selected episode wirelengths: %s", [f"{wl:.2f}" for wl in
                                                      selected_wirelengths])  # Format for readability

    if len(selected_indices) < history_limit:
        logging.info(
                "Note: Selected only %d episodes instead of requested %d",
                len(selected_indices), history_limit)

    # Calculate diversity of selected episodes
    if len(selected_indices) >= 2 and placement_node_array is not None:
        selected_placements = placement_node_array[selected_indices]
        # Check if selected_placements has enough samples and dimensions for pdist
        if selected_placements.shape[0] >= 2 and selected_placements.shape[
            1] > 0:
            try:
                pairwise_dists = pdist(selected_placements, 'euclidean')
                diversity = np.mean(pairwise_dists)
                logging.info("Selected episodes diversity score: %.2f",
                             diversity)
            except ValueError:
                logging.warning(
                        "Could not calculate diversity: placement data not suitable for pdist.")

    # Return up to history_limit episodes
    best_episodes = [history[idx] for idx in selected_indices]

    return best_episodes


def maximum_top_stratified_prob_history(**kwargs):
    """
    Select up to ``history_limit`` episodes from a probabilistically chosen
    placement cluster.

    Steps:
      1. Cap the history with ``limit_history_by_quality`` and build one
         feature row per episode from the first two columns of its
         ``positions`` array.
      2. Bisect on the number of clusters k in
         ``[2, min(len(history) - 1, MAX_HISTORY_CLUSTERS)]``.  A k is
         feasible when the cluster containing the lowest HPWL holds at least
         ``history_limit`` episodes.  For a feasible k the diversity (mean
         pairwise Euclidean distance) of that cluster's ``history_limit``
         lowest-HPWL episodes is measured; the clustering is kept only if the
         diversity is lower than the best so far, otherwise the search stops
         early.
      3. If no k was feasible, delegate to ``_cluster_and_select_probs``.
      4. Sample one non-empty cluster of the kept clustering with a rank
         softmax over each cluster's lowest HPWL (uniform if all are equal).
      5. Return the sampled cluster's lowest-HPWL episodes, topped up from the
         clusters with the nearest centroids.  The stored clustering is
         reused, so no second k-means pass is performed.

    Reads ``history_limit`` and ``history`` from ``kwargs`` (required) and
    ``seed`` (optional, default 0).  ``seed`` is passed to MiniBatchKMeans and
    to the fallback; the cluster itself is sampled with the global
    ``np.random`` state.

    Returns:
        A list of episode dicts; empty when ``history_limit`` is not positive.

    Raises:
        ValueError: If the (capped) history is empty.
    """
    history_limit = kwargs["history_limit"]
    if history_limit <= 0:
        return []

    history = limit_history_by_quality(kwargs["history"], history_limit)

    seed = kwargs.get("seed", 0)

    if not history:
        raise ValueError(
                "History is empty – cannot build top stratified history.")

    # ------------------------------------------------------------------ #
    # Step 1: build placement matrix X, one flattened row per episode     #
    # ------------------------------------------------------------------ #
    # assumes every ep["positions"] has the same (M, 4) shape with x, y in
    # the first two columns - TODO confirm against the producer
    pos_arr = np.stack([ep["positions"] for ep in history])  # N x M x 4
    placement_node_array = pos_arr[:, :, :2].reshape(len(history),
                                                     -1).astype(
            np.float64)  # N x (M*2)
    wirelengths = np.fromiter((ep["hpwl"] for ep in history),
                              dtype=np.float64)

    # ------------------------------------------------------------------ #
    # Step 2: bisect on k                                                 #
    # ------------------------------------------------------------------ #
    # With fewer than three episodes max_k < min_k, the loop below is
    # skipped and the fallback in step 3 is used.
    min_k, max_k = 2, min(len(history) - 1, MAX_HISTORY_CLUSTERS)
    best_diversity = np.inf
    best_cluster_indices = None  # list[np.ndarray] for the chosen k
    best_cluster_mins = None
    best_cluster_sizes = None
    best_cluster_centers = None

    logging.info("Binary-searching k in [%d, %d]", min_k, max_k)
    t0_total = time.perf_counter()

    while min_k <= max_k:
        k = (min_k + max_k) // 2
        kmeans = MiniBatchKMeans(
                max_iter=KMEANS_NUM_ITERATIONS,
                batch_size=KMEANS_BATCH_SIZE,
                n_clusters=k,
                random_state=seed,
        )
        clusters = kmeans.fit_predict(placement_node_array)

        # Per-cluster stats; an empty cluster gets an infinite minimum so it
        # can never be the best cluster.
        cluster_indices_by_id = [np.where(clusters == cid)[0]
                                 for cid in range(k)]
        cluster_sizes = [len(idx) for idx in cluster_indices_by_id]
        cluster_min_values = [np.min(wirelengths[idx]) if len(idx) else np.inf
                              for idx in cluster_indices_by_id]

        best_cluster_id = int(np.argmin(cluster_min_values))
        best_cluster_sz = cluster_sizes[best_cluster_id]

        logging.info(
                "k=%d: best cluster id=%d, size=%d, min_HPWL=%.2f",
                k, best_cluster_id, best_cluster_sz,
                cluster_min_values[best_cluster_id]
        )

        if best_cluster_sz >= history_limit:  # feasible
            # Episode indices of the best cluster, lowest HPWL first.
            idx_sorted = cluster_indices_by_id[best_cluster_id] \
                [np.argsort(wirelengths[
                                cluster_indices_by_id[best_cluster_id]])]
            # pdist needs at least two rows; a single episode has diversity 0.
            diversity = pdist(
                    placement_node_array[idx_sorted[:history_limit]],
                    "euclidean"
            ).mean() if history_limit > 1 else 0.0

            logging.info("  feasible (diversity=%.2f) – moving min_k up",
                         diversity)

            if diversity < best_diversity:
                # Store everything needed later so we do not have to
                # recluster after the search.
                best_cluster_centers = kmeans.cluster_centers_.copy()
                best_diversity = diversity
                best_cluster_indices = cluster_indices_by_id
                best_cluster_mins = cluster_min_values
                best_cluster_sizes = cluster_sizes
            else:
                # Diversity did not decrease: keep the previously stored
                # clustering and stop searching.
                logging.info(
                        "Early-stop: diversity rose from %.3f to %.3f.",
                        best_diversity, diversity
                )
                break

            min_k = k + 1  # search upper half
        else:  # infeasible
            logging.info("  cluster too small – moving max_k down")
            max_k = k - 1  # search lower half

    logging.info("Binary search finished in %.2fs",
                 time.perf_counter() - t0_total)

    # ------------------------------------------------------------------ #
    # Step 3: fallback if no k worked                                     #
    # ------------------------------------------------------------------ #
    if best_cluster_indices is None:
        logging.warning("No k produced ≥%d episodes; using fallback.",
                        history_limit)
        return _cluster_and_select_probs(
                placement_node_array, history, history_limit, seed,
                " (fallback from max_undiverse_prob)"
        )

    # ------------------------------------------------------------------ #
    # Step 4: probabilistic cluster pick (rank-softmax)                   #
    # ------------------------------------------------------------------ #
    # Only non-empty clusters are candidates.
    valid_cids = [cid for cid, sz in enumerate(best_cluster_sizes) if
                  sz > 0]

    cluster_mins = np.array([best_cluster_mins[cid] for cid in valid_cids])
    cluster_szs = [best_cluster_sizes[cid] for cid in valid_cids]

    ranks = rankdata(cluster_mins, method="average")  # 1 = best
    scores = -ranks  # high = good
    if np.all(cluster_mins == cluster_mins[0]):  # all ties
        probs = np.ones(len(valid_cids)) / len(valid_cids)
    else:
        # Softmax over the negated ranks, shifted by the max for stability.
        scores = (scores - scores.max()) / RANK_SOFTMAX_TEMP
        probs = np.exp(scores)
        probs /= probs.sum()

    logging.info("Cluster min HPWLs: %s",
                 [f"{v:.2f}" for v in cluster_mins])
    logging.info("Cluster ranks (1=best): %s", [f"{r:.1f}" for r in ranks])
    logging.info("Cluster sizes: %s", cluster_szs)
    logging.info("Selection probs (T=%.2f): %s",
                 RANK_SOFTMAX_TEMP, probs.round(4).tolist())

    # Sampled with the global np.random state (not derived from `seed`).
    chosen_cid = np.random.choice(valid_cids, p=probs)
    chosen_rank = ranks[valid_cids.index(chosen_cid)]
    logging.info(
            "Sampled cluster %d (p=%.4f, size=%d, rank=%.1f, min hpwl=%.2f)",
            chosen_cid, probs[valid_cids.index(chosen_cid)],
            best_cluster_sizes[chosen_cid], chosen_rank,
            best_cluster_mins[chosen_cid]  # min HPWL
    )
    # NOTE(review): best_diversity was measured on the lowest-HPWL cluster
    # during the search; the sampled cluster may be a different one, so this
    # value is not necessarily the diversity of the episodes returned below.
    logging.info("Selected episodes diversity score: %.2f",
                 best_diversity)

    # ------------------------------------------------------------------ #
    # Step 5: collect roll-outs: first the chosen cluster,                #
    #         then nearest clusters until history_limit is met            #
    # ------------------------------------------------------------------ #
    selected_idx = []

    # (a) best episodes from chosen cluster
    idx_in_chosen = best_cluster_indices[chosen_cid]
    idx_sorted = idx_in_chosen[np.argsort(wirelengths[idx_in_chosen])]
    take = min(history_limit, len(idx_sorted))
    selected_idx.extend(idx_sorted[:take])

    # (b) top-up from clusters closest in centroid space
    if len(selected_idx) < history_limit:
        chosen_center = best_cluster_centers[chosen_cid]
        center_dists = np.linalg.norm(best_cluster_centers - chosen_center,
                                      axis=1)

        for cid in np.argsort(center_dists):  # ascending distance
            if cid == chosen_cid:
                continue
            idxs = best_cluster_indices[cid]
            if idxs.size == 0:
                continue
            idxs_sorted = idxs[np.argsort(wirelengths[idxs])]
            remaining = history_limit - len(selected_idx)
            selected_idx.extend(idxs_sorted[:remaining])

            logging.info(
                    "Topped up with %d episodes from cluster %d (dist=%.4f)",
                    min(remaining, len(idxs_sorted)),
                    cid, center_dists[cid])

            if len(selected_idx) == history_limit:
                break

    # final guard in case we overfilled
    selected_idx = selected_idx[:history_limit]

    logging.info("Returning %d episodes (cluster %d + neighbors).",
                 len(selected_idx), chosen_cid)

    return [history[i] for i in selected_idx]


def _cluster_and_select_best(placement_node_array, history, history_limit,
        seed=0, description="", eps=0.0):
    """
    Helper function that performs clustering on placement vectors and selects the best episodes.
    If the best cluster doesn't have enough samples, additional samples are taken from the
    most similar clusters (closest in feature space).

    Args:
        placement_node_array: Numpy array of shape (num_episodes, feature_dimensions)
        history: List of episode dictionaries
        history_limit: Maximum number of episodes to return
        seed: Random seed for clustering
        description: Description string for logging

    Returns:
        List of selected episode dictionaries
    """

    wirelengths = [episode['hpwl'] for episode in history]

    # Calculate number of clusters with proper bounds
    num_clusters = min(
            max(1, len(history) // history_limit),
            # At least 1, targeting len/limit clusters
            MAX_HISTORY_CLUSTERS,  # Don't exceed maximum allowed clusters
            len(history)  # Can't have more clusters than items
    )

    # Time the clustering operation
    start = time.perf_counter()
    kmeans = MiniBatchKMeans(
            max_iter=KMEANS_NUM_ITERATIONS,
            batch_size=KMEANS_BATCH_SIZE,
            n_clusters=num_clusters,
            random_state=seed,
    )
    clusters = kmeans.fit_predict(placement_node_array)
    end = time.perf_counter()
    logging.info(
            f"Clustering {len(history)} episodes into {num_clusters} clusters{description} took {end - start:.2f} seconds")

    # ------------------------------------------------------
    # 1. Determine which cluster is "best" by looking at the
    #    minimal wirelength in each cluster, and picking the
    #    cluster whose min wirelength is smallest overall.
    # ------------------------------------------------------
    cluster_min_values = []
    cluster_sizes = []
    cluster_wirelengths_by_cluster = []
    cluster_indices_by_id = []  # Store indices for each cluster

    for cluster_id in range(num_clusters):
        cluster_indices = np.where(clusters == cluster_id)[0]
        cluster_indices_by_id.append(cluster_indices)
        cluster_sizes.append(len(cluster_indices))
        cluster_wirelengths = [float(wirelengths[idx]) for idx in
                               cluster_indices]
        cluster_wirelengths_by_cluster.append(cluster_wirelengths)
        if len(cluster_wirelengths) > 0:
            cluster_min_values.append(np.min(cluster_wirelengths))
        else:
            cluster_min_values.append(float('inf'))

    # Get the best cluster ID (one with minimum wirelength)
    best_cluster_id = np.argmin(cluster_min_values)
    selected_cluster_id = best_cluster_id

    if eps > 0 and num_clusters > 1 and np.random.random() < eps:
        # Find valid alternative clusters
        alternatives = [i for i in range(num_clusters)
                        if i != best_cluster_id and cluster_sizes[i] > 0]
        if alternatives:
            selected_cluster_id = np.random.choice(alternatives)
            logging.info(
                    f"Exploring random cluster {selected_cluster_id} instead of best cluster {best_cluster_id}")

    # Detailed logging
    logging.info("Best cluster id: %d", best_cluster_id)
    if selected_cluster_id != best_cluster_id:
        logging.info("Using exploration: selected cluster id: %d",
                     selected_cluster_id)
    logging.info("Cluster min values: %s", cluster_min_values)
    logging.info("Cluster sizes: %s", cluster_sizes)

    logging.info(
            "Selected cluster %d with min wirelength %.2f",
            selected_cluster_id, cluster_min_values[selected_cluster_id])

    # ------------------------------------------------------
    # 2. From selected cluster, take up to history_limit best rollouts.
    # If needed, dip into the most similar clusters based on distance.
    # ------------------------------------------------------
    selected_indices = []

    # First, add episodes from the selected cluster
    selected_cluster_indices = cluster_indices_by_id[selected_cluster_id]
    sorted_selected_indices = sorted(selected_cluster_indices,
                                     key=lambda idx: wirelengths[idx])
    selected_indices.extend(
            sorted_selected_indices[
            :min(history_limit, len(sorted_selected_indices))])

    # If we need more episodes, select from the most similar clusters
    if len(selected_indices) < history_limit:
        # Get the center of the selected cluster
        selected_center = kmeans.cluster_centers_[selected_cluster_id]

        # Calculate distances from each cluster center to the selected cluster center
        center_distances = []
        for i in range(num_clusters):
            if i == selected_cluster_id:
                center_distances.append(0.0)  # Distance to self is 0
            else:
                center = kmeans.cluster_centers_[i]
                distance = np.linalg.norm(center - selected_center)
                center_distances.append(distance)

        # Sort clusters by distance (excluding the selected cluster)
        clusters_by_distance = list(range(num_clusters))
        clusters_by_distance.remove(
                selected_cluster_id)  # Remove selected cluster
        clusters_by_distance.sort(key=lambda i: center_distances[i])

        # Dip into clusters by increasing distance
        for cluster_id in clusters_by_distance:
            # Get indices from this cluster sorted by wirelength
            cluster_indices = cluster_indices_by_id[cluster_id]
            sorted_indices = sorted(cluster_indices,
                                    key=lambda idx: wirelengths[idx])

            # Add as many episodes as needed
            remaining = history_limit - len(selected_indices)
            if remaining <= 0:
                break

            logging.info(
                    "Dipping into cluster %d (distance: %.4f) for %d additional episodes. Currently have %d/%d episodes.",
                    cluster_id, center_distances[cluster_id],
                    min(remaining, len(sorted_indices)),
                    len(selected_indices), history_limit)

            selected_indices.extend(sorted_indices[:remaining])

            if len(selected_indices) == history_limit:
                break

    # Selected episodes info
    selected_wirelengths = [wirelengths[idx] for idx in selected_indices]

    logging.info("Selected episode indices: %s", selected_indices)
    logging.info("Selected episode wirelengths: %s", selected_wirelengths)

    if len(selected_indices) < history_limit:
        logging.info(
                "Note: Selected only %d episodes instead of requested %d",
                len(selected_indices), history_limit)

    # Calculate diversity of selected episodes
    if len(selected_indices) >= 2:
        selected_placements = placement_node_array[selected_indices]
        pairwise_dists = pdist(selected_placements, 'euclidean')
        diversity = np.mean(pairwise_dists)
        logging.info("Selected episodes diversity score: %.2f", diversity)

    # Return up to history_limit episodes
    best_episodes = [history[idx] for idx in selected_indices]

    return best_episodes


def maximum_undiverse_history(**kwargs):
    """
    Binary-search the cluster count k for the tightest (least diverse) set of
    `history_limit` low-HPWL episodes taken from the best cluster.

    For each candidate k the history is clustered on the flattened (x, y) of
    every position row. The "best" cluster is the one containing the lowest
    HPWL. If it holds at least `history_limit` episodes, its `history_limit`
    lowest-HPWL episodes are scored by mean pairwise Euclidean distance:
    a lower score than the best so far is recorded and larger k is tried,
    otherwise the search stops early. If the best cluster is too small,
    smaller k is tried. When no k is viable the selection falls back to
    `_cluster_and_select_best`.

    Keyword Args:
        history: list of episode dicts providing 'positions' and 'hpwl'.
        history_limit: number of episodes wanted; <= 0 returns [].
        seed: random seed for k-means (default 0).

    Returns:
        List of episode dicts in ascending HPWL order (or the fallback's
        selection).

    Raises:
        ValueError: if the (quality-limited) history is empty.
    """
    history_limit = kwargs['history_limit']
    if history_limit <= 0:
        return []

    history = limit_history_by_quality(kwargs["history"], history_limit)
    seed = kwargs.get('seed', 0)

    if not history:
        raise ValueError(
                "History is empty – maximum_undiverse_history needs data.")

    # ------------------------------------------------------------------ #
    # 1. Build placement matrix X: one row per episode                   #
    # ------------------------------------------------------------------ #
    # Stack per-episode positions, keep the first two columns (x, y) and
    # flatten everything after the episode axis.
    pos_arr = np.stack([ep['positions'] for ep in history])
    placement_node_array = pos_arr[:, :, :2].reshape(len(history), -1).astype(
            np.float64)

    wirelengths = np.fromiter((ep['hpwl'] for ep in history),
                              dtype=np.float64)

    # ------------------------------------------------------------------ #
    # 2. Binary-search on k                                              #
    # ------------------------------------------------------------------ #
    min_k = 2
    max_k = min(len(history) - 1, MAX_HISTORY_CLUSTERS)
    best_k, best_diversity, best_episodes = None, np.inf, None

    logging.info("Binary searching k in [%d, %d]", min_k, max_k)
    t0_total = time.perf_counter()

    while min_k <= max_k:
        k = (min_k + max_k) // 2
        logging.info("Trying k=%d clusters (range %d-%d)", k, min_k, max_k)
        kmeans = MiniBatchKMeans(
                max_iter=KMEANS_NUM_ITERATIONS,
                batch_size=KMEANS_BATCH_SIZE,
                n_clusters=k,
                random_state=seed,
        )
        clusters = kmeans.fit_predict(placement_node_array)

        # -- find best (lowest-HPWL) cluster ----------------------------
        best_cluster_id = None
        best_cluster_hpwl = np.inf
        cluster_indices_by_id = []

        for cid in range(k):
            idx = np.where(clusters == cid)[0]
            cluster_indices_by_id.append(idx)
            if idx.size:
                hpwl_min = wirelengths[idx].min()
                if hpwl_min < best_cluster_hpwl:
                    best_cluster_hpwl = hpwl_min
                    best_cluster_id = cid

        if best_cluster_id is None:
            # No cluster minimum compared below +inf (HPWL values are inf or
            # NaN), so there is no "best" cluster to index. Stop searching
            # and let the fallback below pick instead of raising TypeError.
            logging.warning(
                    "k=%d: no cluster has a finite HPWL minimum; "
                    "abandoning the search.", k)
            break

        idx_best = cluster_indices_by_id[best_cluster_id]
        logging.info("k=%d: best cluster size = %d", k, idx_best.size)

        if idx_best.size >= history_limit:
            # --- viable: compute diversity of top history_limit roll-outs
            idx_sorted = idx_best[np.argsort(wirelengths[idx_best])][
                         :history_limit]
            sel = placement_node_array[idx_sorted]
            # pdist of a single row is empty; define its diversity as 0.
            diversity = pdist(sel,
                              "euclidean").mean() if history_limit > 1 else 0.0

            logging.info("Viable (div=%.2f); moving min_k up", diversity)

            if diversity < best_diversity:
                best_diversity = diversity
                best_k = k
                best_episodes = [history[i] for i in idx_sorted]
            else:
                # More clusters no longer tightens the selection.
                logging.info(
                        "Early-stop: diversity rose from %.3f to %.3f.",
                        best_diversity, diversity
                )
                break
            min_k = k + 1  # search upper half
        else:
            logging.info("Cluster too small; moving max_k down")
            max_k = k - 1  # search lower half

    logging.info("Binary search finished in %.2fs",
                 time.perf_counter() - t0_total)

    # ------------------------------------------------------------------ #
    # 3. Fallback if no k worked                                         #
    # ------------------------------------------------------------------ #
    if best_episodes is None:
        logging.warning(
                "No k produced ≥%d episodes; falling back to distance pick.",
                history_limit)
        return _cluster_and_select_best(
                placement_node_array, history, history_limit, seed,
                " (fallback)"
        )

    logging.info("Selected k=%d with diversity %.2f (returned %d episodes)",
                 best_k, best_diversity, len(best_episodes))
    return best_episodes


def top_stratified_eps_greedy(**kwargs):
    """
    Cluster the history on the (x, y) of every position row and return the
    best episodes of one cluster, chosen epsilon-greedily.

    The work is delegated to `_cluster_and_select_best` with
    `eps=EPS_GREEDY`.

    Keyword Args:
        history: list of episode dicts providing 'positions' and 'hpwl'.
        history_limit: max episodes to return; <= 0 returns [].
        seed: random seed for clustering (default 0).

    Raises:
        ValueError: if the (quality-limited) history is empty.
    """
    limit = kwargs['history_limit']
    if limit <= 0:
        return []

    episodes = limit_history_by_quality(kwargs["history"], limit)
    rng_seed = kwargs.get('seed', 0)

    if not episodes:
        raise ValueError(
                "History is empty – cannot use undiverse_history strategy.")

    # One stacked array for all episodes; the first axis indexes episodes.
    stacked = np.stack([ep['positions'] for ep in episodes])

    # Keep only the first two columns (x, y) and flatten each episode into a
    # single float64 feature row.
    xy_only = stacked[:, :, :2]
    features = xy_only.reshape(len(episodes), -1).astype(np.float64)

    return _cluster_and_select_best(
            features,
            episodes,
            limit,
            rng_seed,
            eps=EPS_GREEDY,
            description=" using ALL macros"
    )


def top_stratified_history(**kwargs):
    """
    Build one feature row per episode from the (x, y) of every position row
    and hand the selection to `_cluster_and_select_probs`.

    NOTE(review): the original description says that helper samples a
    cluster via a soft-max over HPWL quality and returns up to
    `history_limit` top rollouts from it; the helper is defined elsewhere in
    this file, so confirm there.

    Keyword Args:
        history: list of episode dicts providing 'positions'.
        history_limit: max episodes to return; <= 0 returns [].
        seed: random seed forwarded to the helper (default 0).

    Raises:
        ValueError: if the (quality-limited) history is empty.
    """
    limit = kwargs['history_limit']
    if limit <= 0:
        return []

    episodes = limit_history_by_quality(kwargs["history"], limit)
    rng_seed = kwargs.get('seed', 0)

    if not episodes:
        raise ValueError(
                "History is empty – cannot use undiverse_history strategy.")

    # One stacked array for all episodes; the first axis indexes episodes.
    stacked = np.stack([ep['positions'] for ep in episodes])

    # Keep only the first two columns (x, y) and flatten each episode into a
    # single float64 feature row.
    xy_only = stacked[:, :, :2]
    features = xy_only.reshape(len(episodes), -1).astype(np.float64)

    return _cluster_and_select_probs(
            features,
            episodes,
            limit,
            rng_seed,
            " using ALL macros (array version)",
    )


def diverse_history(**kwargs):
    """
    Pick one low-HPWL representative per k-means cluster.

    The history is clustered into min(len(history), history_limit) groups on
    the (x, y) of the chosen macro rows; the lowest-HPWL episode of every
    non-empty cluster is returned, in cluster-id order.

    Keyword Args:
        history: list of episode dicts providing 'positions' and 'hpwl'.
        history_limit: max episodes to return; <= 0 returns [].
        selected_macros: list[str] | None. Falsy means use all rows.
        node_name_list: macro names in the order used by the agent.
        placed_num_macro: number of leading names that are placed macros.
        seed: random seed for clustering (default 0).

    Raises:
        ValueError: if `history` is empty, or (from `index`) if a selected
            macro is not among the placed macros.
    """
    history = kwargs['history']
    history_limit = kwargs['history_limit']
    if history_limit <= 0:
        return []

    selected_macros = kwargs['selected_macros']
    seed = kwargs.get('seed', 0)

    if not history:
        raise ValueError("History is empty – cannot use diverse_history.")

    # Macros actually placed by the agent, in their fixed order.
    placed_macros = kwargs['node_name_list'][:kwargs['placed_num_macro']]

    # Rows of the positions array that take part in the clustering.
    if not selected_macros:
        macro_rows = slice(None)
    else:
        macro_rows = [placed_macros.index(name) for name in selected_macros]

    # Feature matrix: one float64 row per episode holding the flattened
    # (x, y) of the chosen rows.
    n_episodes = len(history)
    stacked = np.stack([ep['positions'] for ep in history])
    features = stacked[:, macro_rows, :2].reshape(n_episodes, -1)
    features = features.astype(np.float64)

    hpwl = np.fromiter((ep['hpwl'] for ep in history), dtype=np.float64)

    # One cluster per requested episode, but never more than we have.
    n_groups = min(n_episodes, history_limit)
    started = time.perf_counter()
    estimator = MiniBatchKMeans(
            max_iter=KMEANS_NUM_ITERATIONS,
            batch_size=KMEANS_BATCH_SIZE,
            n_clusters=n_groups,
            random_state=seed,
    )
    labels = estimator.fit_predict(features)
    logging.info("Clustering %d roll-outs into %d clusters took %.2fs",
                 n_episodes, n_groups, time.perf_counter() - started)

    # Lowest-HPWL member of each populated cluster.
    representatives = []
    for group in range(n_groups):
        members = np.where(labels == group)[0]
        if members.size > 0:
            winner = members[np.argmin(hpwl[members])]
            representatives.append(history[winner])

    return representatives


def half_best_half_top_stratified_history(**kwargs):
    """
    Combine `top_stratified_history` and `best_history` selections.

    `history_limit - history_limit // 2` episodes are requested from
    `top_stratified_history` (it gets the extra one when the limit is odd).
    The episodes it returned are removed from the history, and
    `history_limit // 2` episodes are requested from `best_history` on what
    is left.

    Keyword Args:
        history: list of episode dicts.
        history_limit: total number of episodes wanted; <= 0 returns [].
        Other keyword arguments are forwarded unchanged to both selectors.

    Returns:
        The stratified selection followed by the `best_history` selection.
    """
    history = kwargs['history']
    history_limit = kwargs['history_limit']
    if history_limit <= 0:
        return []

    half = history_limit // 2
    undiverse_half = top_stratified_history(
            **(kwargs | {'history': history,
                         'history_limit': history_limit - half}))

    # Exclude already-chosen episodes by identity, not `in`/`==`: comparing
    # episode dicts by value is O(N*K) deep comparisons and can raise
    # "truth value of an array is ambiguous" once it reaches an ndarray field.
    # NOTE(review): assumes the selector returns the same dict objects that
    # are in `history` rather than copies - confirm.
    taken_ids = {id(episode) for episode in undiverse_half}
    remaining_history = [episode for episode in history
                         if id(episode) not in taken_ids]

    best_half = best_history(
            **(kwargs | {'history': remaining_history,
                         'history_limit': half}))
    return undiverse_half + best_half


def half_best_half_diverse_history(**kwargs):
    """
    Combine `diverse_history` and `best_history` selections.

    `history_limit - history_limit // 2` episodes are requested from
    `diverse_history` (it gets the extra one when the limit is odd). The
    episodes it returned are removed from the history, and
    `history_limit // 2` episodes are requested from `best_history` on what
    is left.

    Keyword Args:
        history: list of episode dicts.
        history_limit: total number of episodes wanted; <= 0 returns [].
        Other keyword arguments are forwarded unchanged to both selectors.

    Returns:
        The diverse selection followed by the `best_history` selection.
    """
    history = kwargs['history']
    history_limit = kwargs['history_limit']
    if history_limit <= 0:
        return []

    half = history_limit // 2
    diverse_half = diverse_history(
            **(kwargs | {'history': history,
                         'history_limit': history_limit - half}))

    # Exclude already-chosen episodes by identity, not `in`/`==`: comparing
    # episode dicts by value is O(N*K) deep comparisons and can raise
    # "truth value of an array is ambiguous" once it reaches an ndarray field.
    # `diverse_history` returns elements of `history` itself, so identity is
    # sufficient.
    taken_ids = {id(episode) for episode in diverse_half}
    remaining_history = [episode for episode in history
                         if id(episode) not in taken_ids]

    best_half = best_history(
            **(kwargs | {'history': remaining_history,
                         'history_limit': half}))
    return diverse_half + best_half


def half_fifo_half_diverse_history(**kwargs):
    """
    Combine `diverse_history` and `first_in_first_out_history` selections.

    `history_limit - history_limit // 2` episodes are requested from
    `diverse_history` (it gets the extra one when the limit is odd). The
    episodes it returned are removed from the history, and
    `history_limit // 2` episodes are requested from
    `first_in_first_out_history` on what is left.

    Keyword Args:
        history: list of episode dicts.
        history_limit: total number of episodes wanted; <= 0 returns [].
        Other keyword arguments are forwarded unchanged to both selectors.

    Returns:
        The diverse selection followed by the FIFO selection.
    """
    history = kwargs['history']
    history_limit = kwargs['history_limit']
    if history_limit <= 0:
        return []

    half = history_limit // 2
    diverse_half = diverse_history(
            **(kwargs | {'history': history,
                         'history_limit': history_limit - half}))

    # Exclude already-chosen episodes by identity, not `in`/`==`: comparing
    # episode dicts by value is O(N*K) deep comparisons and can raise
    # "truth value of an array is ambiguous" once it reaches an ndarray field.
    # `diverse_history` returns elements of `history` itself, so identity is
    # sufficient.
    taken_ids = {id(episode) for episode in diverse_half}
    remaining_history = [episode for episode in history
                         if id(episode) not in taken_ids]

    fifo_half = first_in_first_out_history(
            **(kwargs | {'history': remaining_history,
                         'history_limit': half}))
    return diverse_half + fifo_half


# Registry of history-selection strategies: strategy name -> callable taking
# the selection parameters as keyword arguments.
# NOTE(review): `maximum_undiverse_history` is defined in this file but is not
# registered here - confirm whether that is intentional.
HISTORY_STRATEGIES = {
    "fifo": first_in_first_out_history,
    "half_best_half_undiverse": half_best_half_top_stratified_history,
    "best": best_history,
    "undiverse_eps_greedy": top_stratified_eps_greedy,
    "undiverse": top_stratified_history,
    "top_stratified": maximum_top_stratified_prob_history,
    "worst": worst_history,
    "diverse": diverse_history,
    "random": random_history,
    "half_fifo_half_diverse": half_fifo_half_diverse_history,
    "half_best_half_worst": half_best_half_worst,
    "half_best_half_diverse": half_best_half_diverse_history,
}
