from __future__ import annotations
from collections import defaultdict
from contextlib import nullcontext
import math

import math
import numpy as np
import torch
from torch import nn
from typing import Dict, List, Tuple, TYPE_CHECKING, Literal, Optional, Iterable, Union
from torch.utils.checkpoint import checkpoint  # Import checkpointing utility

from utils.logs import SharedLogger

if TYPE_CHECKING:
    from prune.halpe import HALPE

@torch.inference_mode()
def assert_all_params_finite(model):
    """
    Raise if any parameter of ``model`` contains a NaN or +/-Inf entry.

    Args:
        model: object exposing ``named_parameters()`` that yields (name, tensor) pairs

    Raises:
        RuntimeError: names up to five offending parameters ("..." is appended
            when there are more than five)
    """
    offenders = [
        param_name
        for param_name, param in model.named_parameters()
        if not torch.isfinite(param).all().item()
    ]
    if not offenders:
        return
    raise RuntimeError(f"Non-finite params in: {offenders[:5]}{'...' if len(offenders)>5 else ''}")

@torch.inference_mode()
def report_param_extrema(model, topk=10):
    """
    Print the parameters whose largest absolute entry looks suspicious.

    A parameter is reported when max|w| is not finite (NaN/Inf) or exceeds 1e6.
    Non-finite parameters are listed first (in model order), followed by the
    finite ones in decreasing magnitude; at most ``topk`` lines are printed.

    Args:
        model: object exposing ``named_parameters()`` that yields (name, tensor) pairs
        topk: maximum number of lines to print
    """
    non_finite = []
    huge = []
    for name, p in model.named_parameters():
        if p.numel() == 0:
            # max() of a zero-element tensor raises; there is nothing to report
            continue
        m = p.abs().max().item()
        if not math.isfinite(m):
            non_finite.append((m, name))
        elif m > 1e6:
            huge.append((m, name))

    # NaN does not order against other floats, so non-finite entries are kept
    # out of the sort and always reported ahead of the finite ones.
    huge.sort(reverse=True)
    for m, name in (non_finite + huge)[:topk]:
        print(f"EXTREME {name}: max|w|={m:.3e}")

def report_param_extrema_and_finite(model, topk=10):
    """
    Print extreme-magnitude parameters of ``model``, then verify all are finite.

    The report is printed first so it is visible even when the finiteness check
    raises afterwards.

    Args:
        model: forwarded to both helpers
        topk: maximum number of report lines
    """
    report_param_extrema(model, topk=topk)
    assert_all_params_finite(model=model)

@torch.inference_mode()
def zero_heads_gqa(
    q_proj: nn.Linear,
    k_proj: nn.Linear,
    v_proj: nn.Linear,
    o_proj: nn.Linear,
    head_indices: Iterable[int],
    *,
    head_dim: int | None = None,
) -> None:
    """
    Zero the weights of the given attention heads in place (MHA or GQA).

    For every requested head the matching rows of Q and columns of O are zeroed.
    A K/V head is zeroed only when *every* query head of its group has all-zero
    O columns (after this call), because a K/V head that is still read by a
    live query head must stay intact. Biases are not touched.

    Shapes (HF LLaMA-style):
      q_proj.weight: [num_heads*head_dim, hidden_size]
      k_proj.weight: [num_kv_heads*head_dim, hidden_size]
      v_proj.weight: [num_kv_heads*head_dim, hidden_size]
      o_proj.weight: [hidden_size, num_heads*head_dim]

    Args:
        q_proj, k_proj, v_proj, o_proj: attention projections
        head_indices: query-head indices to zero (duplicates are ignored)
        head_dim: per-head width. When omitted it is guessed as the first of
            (128, 96, 80, 64, 48, 40, 32) that divides the Q, K and V output
            sizes, falling back to the smallest common divisor >= 32. The guess
            is ambiguous for some shapes (e.g. 12 heads of 64 also divide by
            128), so pass it explicitly whenever it is known.

    Raises:
        ValueError: head_dim cannot be inferred, or the given head_dim does not
            divide the projection sizes
        IndexError: a head index is outside [0, num_heads)
        AssertionError: the projection shapes are inconsistent
    """
    device = q_proj.weight.device

    # ---- Shapes ----
    q_out, q_in = q_proj.weight.shape
    k_out, k_in = k_proj.weight.shape
    v_out, v_in = v_proj.weight.shape
    o_out, o_in = o_proj.weight.shape

    assert q_in == k_in == v_in == o_out, "Q/K/V input dims must match O rows (hidden size)."
    assert o_in == q_out, "O in_features must equal Q out_features."

    def _divides_all(d: int) -> bool:
        return q_out % d == 0 and k_out % d == 0 and v_out % d == 0

    # ---- Head width ----
    if head_dim is None:
        # Common widths first, largest to smallest; K/V sizes are part of the
        # test so that GQA/MQA layouts narrower than the first Q divisor resolve.
        head_dim = next((d for d in (128, 96, 80, 64, 48, 40, 32) if _divides_all(d)), None)
        if head_dim is None:
            head_dim = next((d for d in range(32, q_out + 1) if _divides_all(d)), None)
        if head_dim is None:
            raise ValueError(
                f"Cannot infer head_dim from shapes: q_out={q_out}, o_in={o_in}, "
                f"k_out={k_out}, v_out={v_out}"
            )
    else:
        head_dim = int(head_dim)
        if head_dim <= 0 or not _divides_all(head_dim):
            raise ValueError(
                f"head_dim={head_dim} does not divide q_out={q_out}, k_out={k_out}, v_out={v_out}"
            )

    num_heads = q_out // head_dim
    num_kv_heads = k_out // head_dim
    assert num_kv_heads == v_out // head_dim, "K and V must have same num_kv_heads."
    assert num_heads % num_kv_heads == 0, "num_heads must be multiple of num_kv_heads (GQA)."
    groups = num_heads // num_kv_heads  # query heads per K/V head

    # ---- Normalize indices ----
    heads = torch.tensor(sorted({int(h) for h in head_indices}), device=device, dtype=torch.long)
    if heads.numel() == 0:
        return
    if not ((0 <= heads).all() and (heads < num_heads).all()):
        raise IndexError(f"Head index out of range: valid [0, {num_heads-1}]")

    # ---- Q rows / O columns: ranges [h*D, (h+1)*D) for every requested head ----
    D = head_dim
    q_rows = (heads[:, None] * D + torch.arange(D, device=device)).reshape(-1)
    q_proj.weight.data.index_fill_(0, q_rows, 0)
    o_proj.weight.data.index_fill_(1, q_rows, 0)

    # ---- K/V rows: only for groups whose query heads are now all dead ----
    # The query heads of K/V head j occupy the contiguous O columns
    # [j*groups*D, (j+1)*groups*D); the group is dead when all of them are zero.
    touched_kv = torch.div(heads, groups, rounding_mode='floor').unique()
    span = groups * D
    group_cols = (touched_kv[:, None] * span + torch.arange(span, device=device)).reshape(-1)
    cols_zero = (o_proj.weight.data.index_select(1, group_cols) == 0).all(dim=0)
    kv_heads = touched_kv[cols_zero.view(-1, span).all(dim=1)]
    if kv_heads.numel() > 0:
        kv_rows = (kv_heads[:, None] * D + torch.arange(D, device=device)).reshape(-1)
        k_proj.weight.data.index_fill_(0, kv_rows, 0)
        v_proj.weight.data.index_fill_(0, kv_rows, 0)

@torch.inference_mode()
def zero_ffn_units(
    up_proj: nn.Linear,
    down_proj: nn.Linear,
    unit_indices: Iterable[int],
    *,
    gate_proj: nn.Linear | None = None,
) -> None:
    """
    Zero the selected FFN neurons in place.

    For each selected unit the matching row of ``up_proj`` (and of ``gate_proj``
    when supplied) and the matching column of ``down_proj`` are set to zero.

    Expected weight shapes:
      up_proj.weight   : [intermediate_size, hidden_size]
      gate_proj.weight : [intermediate_size, hidden_size]  (optional)
      down_proj.weight : [hidden_size, intermediate_size]

    Raises:
        AssertionError: a projection has an unexpected shape
        IndexError: a unit index lies outside [0, intermediate_size)
    """
    n_units, n_hidden = up_proj.weight.shape
    assert down_proj.weight.shape == (n_hidden, n_units), "down_proj shape mismatch"
    has_gate = gate_proj is not None
    if has_gate:
        assert gate_proj.weight.shape == (n_units, n_hidden), "gate_proj shape mismatch"

    # De-duplicate and order the requested units before building the index tensor
    wanted = sorted({int(u) for u in unit_indices})
    idx = torch.tensor(wanted, device=up_proj.weight.device, dtype=torch.long)
    if idx.numel() == 0:
        return
    if bool((idx < 0).any()) or bool((idx >= n_units).any()):
        raise IndexError(f"Unit index out of range: valid [0, {n_units-1}]")

    # Rows of the input-side projections, then columns of the output projection
    row_side = [up_proj, gate_proj] if has_gate else [up_proj]
    for proj in row_side:
        proj.weight.data.index_fill_(0, idx, 0)
    down_proj.weight.data.index_fill_(1, idx, 0)

def print_memory_usage(name: str, tensor: torch.Tensor):
    """
    Print the storage footprint of ``tensor`` in MB together with its shape,
    dtype and device; prints a placeholder line when ``tensor`` is None.
    """
    if tensor is None:
        print(f"[MEMORY] {name}: None")
        return
    n_bytes = tensor.numel() * tensor.element_size()
    size_mb = n_bytes / (1024 * 1024)
    print(f"[MEMORY] {name}: {size_mb:.2f} MB (shape: {tensor.shape}, dtype: {tensor.dtype}, device: {tensor.device})")

@torch.inference_mode()
def block_l2sq_importance(params: list[torch.Tensor], size_normalize=True):
    """
    Computes L2^2 importance for a block (sum of squared weights).

    Args:
        params: list of tensors that belong to the block
            (e.g., Q/K/V rows + O cols or FFN pieces)
        size_normalize: if True, normalize by number of parameters

    Returns:
        torch.Tensor scalar (float32) with block importance. A block with no
        elements (empty list or only zero-element tensors) scores 0.
    """
    params = list(params)
    n = sum(p.numel() for p in params)
    if n == 0:
        # Nothing to measure: return a proper scalar tensor instead of
        # dividing by zero (or leaking the int 0 that sum() starts from).
        device = params[0].device if params else None
        return torch.zeros((), dtype=torch.float32, device=device)

    s = sum((p.float() ** 2).sum() for p in params)  # L2^2 across the block
    if size_normalize:
        s = s / n
    return s

@torch.inference_mode()
def block_l1_importance(params: list[torch.Tensor], size_normalize: bool = True):
    """
    Computes L1 importance for a block (sum of absolute weights).

    Args:
        params: list of tensors that belong to the block
        size_normalize: if True, normalize by number of parameters

    Returns:
        torch.Tensor scalar (float32) with block importance. A block with no
        elements (empty list or only zero-element tensors) scores 0.
    """
    params = list(params)
    n = sum(p.numel() for p in params)
    if n == 0:
        # Nothing to measure: return a proper scalar tensor instead of
        # dividing by zero (or leaking the int 0 that sum() starts from).
        device = params[0].device if params else None
        return torch.zeros((), dtype=torch.float32, device=device)

    s = sum(p.float().abs().sum() for p in params)  # L1 norm across all params
    if size_normalize:
        s = s / n
    return s

@torch.inference_mode()
def head_l2sq_full_gqa(q_proj: nn.Linear, k_proj: nn.Linear, v_proj: nn.Linear, o_proj: nn.Linear, head_dim: int) -> torch.Tensor:
    """
    Per-head sum of squared weights: [num_heads] float32 tensor.

    Each head's score adds up its Q rows, the rows of the K/V head it reads
    from, and its O columns. ``num_heads`` and ``num_kv_heads`` are derived from
    the weight shapes and the supplied ``head_dim``; with more query heads than
    K/V heads (GQA) the K/V energy is repeated for every head of the group.
    """
    w_q = q_proj.weight.float()  # [n_heads*D, H]
    w_k = k_proj.weight.float()  # [n_kv*D, H]
    w_v = v_proj.weight.float()  # [n_kv*D, H]
    w_o = o_proj.weight.float()  # [H, n_heads*D]

    q_out, _ = w_q.shape
    k_out, _ = w_k.shape
    v_out, _ = w_v.shape
    _, o_in = w_o.shape
    assert o_in == q_out, "O in_features must equal Q out_features"

    n_heads = q_out // head_dim
    assert k_out % head_dim == 0 and v_out % head_dim == 0
    n_kv = k_out // head_dim
    assert n_kv == v_out // head_dim
    assert n_heads % n_kv == 0
    heads_per_kv = n_heads // n_kv

    def _energy(weight: torch.Tensor, reduce_dim: int, n_blocks: int) -> torch.Tensor:
        # Squared mass per row/column, then folded into blocks of head_dim
        return weight.square().sum(dim=reduce_dim).view(n_blocks, head_dim).sum(dim=1)

    q_part = _energy(w_q, 1, n_heads)                                  # [n_heads]
    kv_part = (_energy(w_k, 1, n_kv) + _energy(w_v, 1, n_kv)).repeat_interleave(heads_per_kv)
    o_part = _energy(w_o, 0, n_heads)                                  # [n_heads]
    return q_part + kv_part + o_part

@torch.inference_mode()
def ffn_l2sq_full(up_proj: nn.Linear, down_proj: nn.Linear, gate_proj: nn.Linear | None = None) -> torch.Tensor:
    """
    Per-unit sum of squared weights: [intermediate_size] float32 tensor.

    A unit's score covers its row of ``up_proj``, its row of ``gate_proj`` when
    one is supplied, and its column of ``down_proj``.
    """
    w_up = up_proj.weight.float()      # [inter, hidden]
    w_down = down_proj.weight.float()  # [hidden, inter]
    n_units, n_hidden = w_up.shape
    assert w_down.shape == (n_hidden, n_units)

    total = w_up.square().sum(dim=1)                                   # up rows
    if gate_proj is not None:
        total = total + gate_proj.weight.float().square().sum(dim=1)   # gate rows
    return total + w_down.square().sum(dim=0)                          # down columns

@torch.inference_mode()
def head_l1_full_gqa(q_proj: nn.Linear, k_proj: nn.Linear, v_proj: nn.Linear, o_proj: nn.Linear, head_dim: int) -> torch.Tensor:
    """
    Per-head sum of absolute weights: [num_heads] float32 tensor.

    Each head's score adds up its Q rows, the rows of the K/V head it reads
    from, and its O columns. ``num_heads`` and ``num_kv_heads`` are derived from
    the weight shapes and the supplied ``head_dim``; with more query heads than
    K/V heads (GQA) the K/V mass is repeated for every head of the group.
    """
    w_q = q_proj.weight.float()  # [n_heads*D, H]
    w_k = k_proj.weight.float()  # [n_kv*D, H]
    w_v = v_proj.weight.float()  # [n_kv*D, H]
    w_o = o_proj.weight.float()  # [H, n_heads*D]

    q_out, _ = w_q.shape
    k_out, _ = w_k.shape
    v_out, _ = w_v.shape
    _, o_in = w_o.shape
    assert o_in == q_out

    n_heads = q_out // head_dim
    assert k_out % head_dim == 0 and v_out % head_dim == 0
    n_kv = k_out // head_dim
    assert n_kv == v_out // head_dim
    assert n_heads % n_kv == 0
    heads_per_kv = n_heads // n_kv

    # Absolute mass of every Q/K/V row and every O column
    q_lines = w_q.abs().sum(dim=1)
    k_lines = w_k.abs().sum(dim=1)
    v_lines = w_v.abs().sum(dim=1)
    o_lines = w_o.abs().sum(dim=0)

    # Fold consecutive runs of head_dim lines into one value per (kv-)head
    per_head_q = q_lines.view(n_heads, head_dim).sum(dim=1)
    per_kv = k_lines.view(n_kv, head_dim).sum(dim=1) + v_lines.view(n_kv, head_dim).sum(dim=1)
    per_head_o = o_lines.view(n_heads, head_dim).sum(dim=1)

    return per_head_q + per_kv.repeat_interleave(heads_per_kv) + per_head_o

@torch.inference_mode()
def ffn_l1_full(up_proj: nn.Linear, down_proj: nn.Linear, gate_proj: nn.Linear | None = None) -> torch.Tensor:
    """
    Per-unit sum of absolute weights: [intermediate_size] float32 tensor.

    A unit's score covers its row of ``up_proj``, its row of ``gate_proj`` when
    one is supplied, and its column of ``down_proj``.
    """
    w_up = up_proj.weight.float()      # [inter, hidden]
    w_down = down_proj.weight.float()  # [hidden, inter]
    n_units, n_hidden = w_up.shape
    assert w_down.shape == (n_hidden, n_units)

    total = w_up.abs().sum(dim=1)                                   # up rows
    if gate_proj is not None:
        total = total + gate_proj.weight.float().abs().sum(dim=1)   # gate rows
    return total + w_down.abs().sum(dim=0)                          # down columns

@torch.inference_mode()
def compute_layerwise_decay_schedule(
    num_layers: int,
    schedule_type: Literal['linear', 'exponential', 'cosine', 'gaussian', 'sigmoid'] = 'gaussian',
    center_layer: Optional[int] = None,
    scale: float = 1.0,
    steepness: float = 1.0,
    device: str = 'cpu',
):
    """
    Compute decay coefficients for each layer in a model to adjust pruning strength.

    Args:
        num_layers (int): Total number of layers in the model.
        schedule_type (str): Type of decay ('linear', 'exponential', 'cosine', 'gaussian', 'sigmoid').
        center_layer (int, optional): Layer at which the decay peaks. Defaults to num_layers // 2.
        scale (float): Multiplier applied to the whole schedule (default: 1.0).
        steepness (float): Decay rate (exponential), std-dev in layers (gaussian),
            or slope (sigmoid). Unused by 'linear' and 'cosine'.
        device (str): Device to place the tensor on.

    Returns:
        torch.Tensor of shape (num_layers,). Before multiplication by ``scale``
        the values lie in [0, 1]; the gaussian schedule is floored at 0.05.

    Raises:
        ValueError: unknown ``schedule_type``.
    """
    layers = torch.arange(num_layers, device=device).float()

    if center_layer is None:
        center_layer = num_layers // 2
    center = float(center_layer)
    # linear/exponential/cosine measure distance in units of the centre index;
    # fall back to 1 so a centre at layer 0 (e.g. num_layers == 1) does not
    # divide by zero.
    radius = center if center > 0 else 1.0
    offset = layers - center

    if schedule_type == 'linear':
        decay = torch.clamp(1.0 - torch.abs(offset) / radius, min=0.0)

    elif schedule_type == 'exponential':
        decay = torch.exp(-steepness * torch.abs(offset) / radius)

    elif schedule_type == 'cosine':
        # Clamp the phase to [-pi, pi]: beyond one radius the cosine would wrap
        # around and hand far-away layers a high coefficient again.
        phase = torch.clamp(math.pi * offset / radius, min=-math.pi, max=math.pi)
        decay = torch.clamp(0.5 * (1 + torch.cos(phase)), min=0.0)

    elif schedule_type == 'gaussian':
        raw = torch.exp(-0.5 * (offset / steepness) ** 2)
        if raw.numel() > 0:
            raw = raw / raw.max()  # normalize to max=1
        floor_val = 0.05           # every layer keeps at least this much
        decay = raw * (1.0 - floor_val) + floor_val

    elif schedule_type == 'sigmoid':
        decay = 1.0 / (1.0 + torch.exp(steepness * (torch.abs(offset) - center / 2)))

    else:
        raise ValueError(f"Unsupported schedule_type: {schedule_type}")

    return scale * decay

@torch.inference_mode()
def min_max_normalize_tensor_list(tensor_list, column_idx, epsilon: float = 1e-8):
    """
    Min-max normalize one column in place, using a single range shared by all tensors.

    Args:
        tensor_list: 2-D tensors; zero-element tensors are skipped
        column_idx: column to rescale
        epsilon: added to the range so a constant column does not divide by zero

    Returns:
        None. Each non-empty tensor has ``t[:, column_idx]`` overwritten with
        ``(value - global_min) / (global_max - global_min + epsilon)``.
    """
    # min()/max() raise on zero-element tensors, so leave those out entirely
    non_empty = [t for t in tensor_list if t.numel() > 0]
    if not non_empty:
        return

    min_val = min(t[:, column_idx].min().item() for t in non_empty)
    max_val = max(t[:, column_idx].max().item() for t in non_empty)
    range_val = max_val - min_val + epsilon

    for t in non_empty:
        t[:, column_idx] = (t[:, column_idx] - min_val) / range_val

@torch.inference_mode()
def min_max_normalize_dict(dict_values: Dict[int, float], epsilon: float = 1e-8):
    """
    Min-max normalize the values of a dict in place.

    Args:
        dict_values: mapping whose values are rescaled
        epsilon: added to the range so identical values do not divide by zero

    Returns:
        The same dict object, with every value replaced by
        ``(value - min) / (max - min + epsilon)``. An empty dict is returned
        unchanged.
    """
    if not dict_values:
        # min()/max() raise on an empty sequence
        return dict_values

    min_val = min(dict_values.values())
    max_val = max(dict_values.values())
    range_val = max_val - min_val + epsilon
    for key, value in dict_values.items():
        dict_values[key] = (value - min_val) / range_val
    return dict_values

@torch.inference_mode()
def normalize_minmax(x: torch.Tensor, eps: float = 1e-12) -> torch.Tensor:
    """
    Normalize tensor to [0, 1] with min–max scaling.

    If the tensor is constant (range below ``eps``) or has no elements,
    returns zeros of the same shape.
    """
    if x.numel() == 0:
        # min()/max() are undefined for zero-element tensors
        return torch.zeros_like(x)
    x_min, x_max = x.min(), x.max()
    rng = x_max - x_min
    if rng.abs() < eps:
        return torch.zeros_like(x)
    return (x - x_min) / rng

@torch.inference_mode()
def normalize_l2(x: torch.Tensor, eps: float = 1e-12) -> torch.Tensor:
    """
    Normalize tensor to unit L2 norm (taken over all elements).

    If the norm is below ``eps`` (zero vector), returns zeros of the same shape.
    """
    # torch.norm is deprecated; vector_norm with its defaults is the same
    # 2-norm over the flattened tensor.
    norm = torch.linalg.vector_norm(x)
    if norm < eps:
        return torch.zeros_like(x)
    return x / norm

@torch.inference_mode()
def normalize_importance_tensor(blocks_importance: torch.Tensor, column: int = 0, epsilon: float = 1e-8):
    """
    Min-max normalize a specified column (e.g., importance) in-place across all blocks.

    The range is taken over the finite entries only; non-finite entries go
    through the same affine map. If the column has no finite entry the tensor
    is left untouched.

    Args:
        blocks_importance: 2-D tensor, modified in place
        column: index of the column to rescale
        epsilon: added to the range so a constant column does not divide by zero
    """
    column_values = blocks_importance[:, column]
    finite_values = column_values[torch.isfinite(column_values)]
    if finite_values.numel() == 0:
        # min()/max() are undefined without at least one finite value
        return

    min_val = finite_values.min()
    max_val = finite_values.max()
    range_val = max_val - min_val + epsilon

    blocks_importance[:, column] = (blocks_importance[:, column] - min_val) / range_val

@torch.inference_mode()
def meanstd_normalize_tensor_list(tensor_list, column_idx, epsilon: float = 1e-8):
    """
    Z-score normalization (mean=0, std=1) for a specified column across all tensors in a list.

    Mean and population standard deviation are computed once over the pooled
    column values of every non-empty tensor; each tensor's column is then
    replaced in place by ``(value - mean) / (std + epsilon)``.

    Args:
        tensor_list: 2-D tensors; zero-element tensors are skipped
        column_idx: column to normalize
        epsilon: added to the std so a constant column does not divide by zero
    """
    non_empty = [t for t in tensor_list if t.numel() > 0]
    if not non_empty:
        return

    # Pool the column on CPU in float32 so tensors on different devices can be
    # mixed; the statistics become plain Python floats for the same reason.
    pooled = torch.cat([
        t[:, column_idx].to(device="cpu", dtype=torch.float32).reshape(-1)
        for t in non_empty
    ])
    mean_val = float(pooled.mean())
    std_val = float(pooled.std(unbiased=False))

    # Each row is normalized from its own value, exactly once.
    for t in non_empty:
        t[:, column_idx] = (t[:, column_idx] - mean_val) / (std_val + epsilon)

@torch.inference_mode()
def meanstd_normalize_dict(dict_values: Dict[int, float], epsilon: float = 1e-8):
    """
    Z-score the values of a dict in place and return the same dict.

    Uses the population standard deviation (divide by the number of values)
    plus ``epsilon``; an empty/falsy input is returned unchanged.
    """
    if not dict_values:
        return dict_values

    count = len(dict_values)
    mean_val = sum(dict_values.values()) / count
    squared_dev = sum((v - mean_val) ** 2 for v in dict_values.values())
    std_val = (squared_dev / count) ** 0.5 + epsilon

    # Only existing keys are reassigned, so the dict size never changes mid-loop
    for key, raw in dict_values.items():
        dict_values[key] = (raw - mean_val) / std_val

    return dict_values

@torch.inference_mode()
def meanstd_normalize_tensor(blocks_importance: torch.Tensor, column: int = 0, epsilon: float = 1e-8):
    """
    Z-score normalization (mean=0, std=1) for a specified column in a tensor, in place.

    Mean and sample standard deviation are taken over the finite entries only;
    non-finite entries go through the same affine map. If the column has no
    finite entry the tensor is left untouched.

    Args:
        blocks_importance: 2-D tensor, modified in place
        column: index of the column to normalize
        epsilon: added to the std so a constant column does not divide by zero
    """
    column_values = blocks_importance[:, column]
    finite_values = column_values[torch.isfinite(column_values)]

    if finite_values.numel() == 0:
        return

    mean_val = torch.mean(finite_values)
    if finite_values.numel() > 1:
        std_val = torch.std(finite_values)
    else:
        # The sample std of a single value is NaN and would poison the whole
        # column; a lone value has no spread, so use 0 (the result is then 0).
        std_val = torch.zeros_like(mean_val)

    blocks_importance[:, column] = (blocks_importance[:, column] - mean_val) / (std_val + epsilon)

@torch.inference_mode()
def downdate_hessian_postnorm(H, A_raw, gamma, sigma, eps=1e-6):
    """
    Post-Norm downdate of the Hessian H ≈ XᵀX after pruning a block (in place).

    Args:
        H      (d×d): current Hessian matrix, overwritten with the downdated one
        A_raw  (d×k): output projection of the pruned block (e.g. W_O[:, i])
        gamma  (d,): LayerNorm scale parameter (layernorm.weight)
        sigma  (d,): runtime std dev over calibration data
        eps    (float): damping added to the diagonal of the k×k system

    Returns:
        None.
    """
    # Row-wise Post-Norm scaling gamma/sigma applied to the projection
    row_scale = (gamma / sigma).view(-1, 1)             # d × 1
    scaled_proj = row_scale * A_raw                     # d × k

    # Orthonormal basis spanning the scaled projection
    basis = torch.linalg.qr(scaled_proj)[0]             # d × k

    # Damped k × k system:  basisᵀ H basis + eps·I
    h_basis = H @ basis                                 # d × k
    core = basis.T @ h_basis                            # k × k
    core.add_(eps * torch.eye(core.size(0), device=H.device))

    # H <- H - (H·basis) · core⁻¹ · (H·basis)ᵀ
    H.sub_(h_basis @ torch.linalg.solve(core, h_basis.T))

@torch.inference_mode()
def downdate_hessian_prenorm(H, A_raw, eps=1e-6):
    """
    Pre-Norm downdate of the Hessian H ≈ XᵀX after pruning a block (in place).

    Args:
        H      (d×d): current Hessian matrix, overwritten with the downdated one
        A_raw  (d×k): output projection of the pruned block (e.g. W_O[:, i]);
                      used as is, without any scaling
        eps    (float): damping added to the diagonal of the k×k system

    Returns:
        None.
    """
    # Orthonormal basis spanning the (unscaled) projection
    basis = torch.linalg.qr(A_raw)[0]                   # d × k

    # Damped k × k system:  basisᵀ H basis + eps·I
    h_basis = H @ basis                                 # d × k
    core = basis.T @ h_basis                            # k × k
    core.add_(eps * torch.eye(core.size(0), device=H.device))

    # H <- H - (H·basis) · core⁻¹ · (H·basis)ᵀ
    H.sub_(h_basis @ torch.linalg.solve(core, h_basis.T))

@torch.inference_mode()
def select_candidates_option_a(
    sensitivities: torch.Tensor,        # shape [L], already normalized to ~[0,1], where smaller = less sensitive
    block_importance: torch.Tensor,     # shape [N,4]: [importance, layer_idx, local_idx, module_type]; importance already normalized
    total_candidates: int,
    device: str = "cuda",
    p: float = 1.5,                     # convexity: bigger p => stronger bias toward low-sensitivity layers
    sens_floor: float = 1e-3,           # avoid division by 0 after normalization
    max_weight_multiplier: float = 50.0,# cap ratio between largest/smallest per-layer weight
    min_per_layer: int = 0,             # optional floor of candidates per layer
) -> torch.Tensor:
    """
    Option A (robust): allocate candidate quotas per layer proportional to
    (clamped sensitivity)^(-p), with an inter-layer weight cap to prevent
    pathological domination. Then, within each layer, pick the top-importance
    blocks up to that quota.

    Precondition: 'sensitivities' and 'block_importance[:,0]' should already be
    normalized (e.g., min-max).

    Args:
        sensitivities: [L] per-layer sensitivity; smaller values get larger quotas.
        block_importance: [N, 4] rows of [importance, layer_idx, local_idx, module_type].
            Rows whose layer_idx is outside [0, L) are ignored.
        total_candidates: number of candidates to distribute over the layers.
        device: device the computation and the result live on.
        p: exponent of the inverse-sensitivity weight.
        sens_floor: lower clamp applied to the sensitivities.
        max_weight_multiplier: no layer weight may exceed this multiple of the
            smallest layer weight.
        min_per_layer: minimum quota per layer (0 disables the floor).

    Returns:
        [M, 4] tensor of selected rows (a copy, on ``device``), grouped by
        ascending layer. M can be smaller than ``total_candidates`` because a
        layer with fewer rows than its quota does not hand the rest to other
        layers, and can be larger when ``min_per_layer`` cannot be rebalanced.
        Degenerate inputs yield a (0, 4) tensor.
    """
    # NOTE(review): this variant keeps the HIGHEST-importance rows of each
    # layer, while select_candidates_option_a1 keeps the lowest; confirm which
    # one the caller expects.
    logger = SharedLogger.get_logger("SELECT_CANDIDATES_OPTION_A")
    # Move/clone safely
    sens = sensitivities.to(device).float()
    bi   = block_importance.to(device).clone()
    bi[:, 0] = bi[:, 0].float()  # importance column passes through float32
    L = sens.numel()
    N = bi.size(0)

    # Defensive checks
    if total_candidates <= 0 or N == 0 or L == 0:
        logger.debug("[OptionA-Robust] Nothing to select (degenerate inputs).")
        return bi.new_zeros((0, 4))

    # 1) Robust per-layer weights: small sensitivity => large weight
    s_eff = torch.clamp(sens, min=sens_floor)             # avoid zero
    raw_w = s_eff.pow(-p)                                  # inverse sensitivity
    # cap extreme skew: w = min(raw_w, min(raw_w) * cap)
    w = torch.minimum(raw_w, torch.min(raw_w) * max_weight_multiplier)

    # normalize to sum=1 (uniform fallback if the weights are unusable)
    w_sum = torch.sum(w)
    if w_sum <= 0:
        w = torch.ones_like(w) / L
    else:
        w = w / w_sum

    logger.debug(
        "[OptionA-Robust] L=%d, N=%d, target=%d, p=%.3f, sens_floor=%.1e, cap=%.1fx | "
        "sens[min,max]=(%.6f,%.6f) | raw_w[min,max]=(%.6f,%.6f) | w[min,max]=(%.6f,%.6f)",
        L, N, total_candidates, p, sens_floor, max_weight_multiplier,
        sens.min().item(), sens.max().item(),
        raw_w.min().item(), raw_w.max().item(),
        w.min().item(), w.max().item()
    )

    # 2) Integer per-layer quotas: floor the proportional share, then give the
    #    leftovers to the layers with the largest fractional parts.
    ideal = w * float(total_candidates)
    quotas = torch.floor(ideal).to(torch.long)
    remainder = total_candidates - int(quotas.sum().item())
    if remainder > 0:
        frac = ideal - quotas.float()
        # k is bounded by L so float round-off in `ideal` can never make topk fail
        top_idx = torch.topk(frac, k=min(remainder, L), largest=True).indices
        quotas[top_idx] += 1

    if min_per_layer > 0:
        # enforce per-layer minimum, then re-balance
        boost = torch.clamp_min(quotas, min_per_layer) - quotas
        extra = int(boost.sum().item())
        quotas = torch.clamp_min(quotas, min_per_layer)
        if extra > 0:
            # take the extra back from the largest quotas (never below the minimum)
            reducible = quotas - min_per_layer
            if reducible.sum() >= extra:
                _, desc = torch.sort(quotas, descending=True)
                for idx in desc.tolist():
                    can_take = min(int(reducible[idx].item()), extra)
                    if can_take > 0:
                        quotas[idx] -= can_take
                        extra -= can_take
                    if extra == 0:
                        break
            # otherwise the overflow is accepted as is

    # 3) Bucket row indices by layer. One bulk transfer instead of a device
    #    sync per row.
    layer_to_rows = [[] for _ in range(L)]
    for idx, li in enumerate(bi[:, 1].to(torch.long).tolist()):
        if 0 <= li < L:
            layer_to_rows[li].append(idx)

    logger.debug("[OptionA-Robust] importance range: [%.6f, %.6f]",
                 bi[:, 0].min().item(), bi[:, 0].max().item())

    # Host copies for the per-layer loop (values are only read below)
    quota_list = quotas.tolist()
    sens_list = sens.tolist()
    w_list = w.tolist()

    # 4) Within each layer keep the top-k rows by importance
    selected_rows = []
    for li in range(L):
        rows = layer_to_rows[li]
        k = quota_list[li]
        if k <= 0 or len(rows) == 0:
            logger.debug("[OptionA-Robust] Layer %02d: sens=%.6f, weight=%.6f, avail=%d, picked=%d",
                         li, sens_list[li], w_list[li], len(rows), 0)
            continue

        if k >= len(rows):
            pick_idx = rows
        else:
            topk = torch.topk(bi[rows, 0], k=k, largest=True)
            pick_idx = [rows[i] for i in topk.indices.tolist()]

        logger.debug("[OptionA-Robust] Layer %02d: sens=%.6f, weight=%.6f, avail=%d, picked=%d",
                     li, sens_list[li], w_list[li], len(rows), len(pick_idx))
        selected_rows.extend(pick_idx)

    # 5) Gather the selected rows; no global re-sorting is applied
    sel = bi[selected_rows, :] if selected_rows else bi.new_zeros((0, 4))
    logger.debug("[OptionA-Robust] selected %d / %d candidates.", sel.size(0), total_candidates)
    return sel

@torch.inference_mode()
def select_candidates_option_a1(
    sensitivities: torch.Tensor,          # shape [L], float
    block_importance: torch.Tensor,       # shape [N, 4] -> [importance, layer_idx, block_idx, block_type]
    total_candidates: int,                # target m
    device: str = "cuda",     # e.g., "cuda" / "cpu"
    eps = 1e-12,
    p = 1.5
) -> torch.Tensor:
    """
    Option A (1-knob): allocate candidate quota per layer ∝ sensitivity^{-p}, p=1.5,
    then pick the *lowest-importance* blocks within each layer.

    Inputs
    ------
    sensitivities: [L] float tensor; higher = more sensitive (riskier).
    block_importance: [N, 4] tensor:
        col0=float importance (lower means more prune-able),
        col1=int layer_idx in [0..L-1],
        col2=int local block index within its module/layer,
        col3=int block/module type (passed through).
    total_candidates: total number of candidate blocks to return (m).
    device: torch device (str or torch.device).
    eps: small constant to prevent division by zero.
    p: exponent for the sensitivity-based allocation.

    Returns
    -------
    candidates: [min(m, N), 4] float32 tensor on `device`, same column schema
        as block_importance. Rows are grouped by ascending layer (lowest
        importance first inside a layer); rows added by the global fill step
        follow at the end. For total_candidates <= 0 an empty (0, 4) tensor
        with the dtype of block_importance is returned.
    """
    logger = SharedLogger.get_logger("SELECT_CANDIDATES_OPTION_A")

    if total_candidates <= 0:
        return torch.empty((0, 4), dtype=block_importance.dtype, device=device)

    # Move inputs to device (non-blocking if possible)
    sensitivities = sensitivities.to(device=device, dtype=torch.float32, non_blocking=True)
    B = block_importance.to(device=device)

    # Split columns & cast dtypes
    imp = B[:, 0].to(torch.float32)                     # importance
    lyr = B[:, 1].to(torch.long)                        # layer idx
    blk = B[:, 2].to(torch.long)                        # local block idx
    typ = B[:, 3].to(torch.long)                        # block type

    L = int(sensitivities.numel())
    N = int(B.shape[0])
    target = int(min(total_candidates, N))

    # ---- per-layer availability (number of rows per layer)
    avail = torch.zeros(L, device=device, dtype=torch.long)
    avail.index_add_(0, lyr, torch.ones_like(lyr, dtype=torch.long))

    # ---- layer weights (favor low sensitivity); layers without rows get 0
    s_clamped = torch.clamp(sensitivities, min=eps)
    w = (1.0 / s_clamped).pow(p)  # shape [L]
    w = torch.where(avail > 0, w, torch.zeros_like(w))

    w_sum = float(w.sum().item())
    if w_sum <= 0:
        # degenerate: no valid layers; return lowest-importance globally
        logger.debug("[OptionA] Degenerate weights; falling back to global lowest-importance selection.")
        idx_global = torch.argsort(imp, dim=0, descending=False)[:target]
        out = torch.stack([imp[idx_global], lyr[idx_global].to(imp.dtype), blk[idx_global].to(imp.dtype), typ[idx_global].to(imp.dtype)], dim=1)
        return out.to(device)

    # ---- initial allocation: floored proportional share, capped by availability
    raw_alloc = w * (target / w_sum)  # float per layer
    k0 = torch.minimum(torch.floor(raw_alloc).to(torch.long), avail)

    # ---- redistribute leftover quota
    selected_so_far = int(k0.sum().item())
    leftover = target - selected_so_far

    logger.debug(f"[OptionA] L={L}, N={N}, target={target}, "
                 f"sensitivity[min,max]=({float(s_clamped.min()):.6f},{float(s_clamped.max()):.6f}), "
                 f"w_sum={w_sum:.6f}, p={p}, initial_selected={selected_so_far}, leftover={leftover}")

    if leftover > 0:
        # Only layers with remaining capacity compete for leftovers
        cap = (avail - k0).clamp_min(0)
        mask_cap = cap > 0
        if mask_cap.any():
            w_cap = torch.where(mask_cap, w, torch.zeros_like(w))
            norm = float(w_cap.sum().item())
            if norm > 0:
                # floored proportional shares of the leftover, capped by capacity
                share = w_cap * (leftover / norm)
                add = torch.minimum(torch.floor(share).to(torch.long), cap)
                k0 = k0 + add
                leftover2 = target - int(k0.sum().item())

                # Hand out any remaining ones by largest fractional parts
                if leftover2 > 0:
                    frac = share - torch.floor(share)
                    frac = torch.where(mask_cap, frac, torch.full_like(frac, -1.0))  # exclude no-cap layers
                    top_idx = torch.topk(frac, k=min(leftover2, int(mask_cap.sum().item()))).indices
                    # increment by 1 but still respect availability
                    for i in top_idx.tolist():
                        if leftover2 <= 0:
                            break
                        if k0[i] < avail[i]:
                            k0[i] += 1
                            leftover2 -= 1
            # if norm==0 (no weights left), do nothing; the global fill below covers it.

    # Final safety cap
    k0 = torch.minimum(k0, avail)
    selected_so_far = int(k0.sum().item())

    # ---- select within each layer (lowest-importance first)
    # One global ascending sort, then a single pass that fills each layer's
    # bucket up to its quota. Layer ids and quotas are copied to the host once
    # and a running counter replaces re-summing all buckets per row.
    global_sorted = torch.argsort(imp, dim=0, descending=False).tolist()
    layer_of = lyr.tolist()
    quota = k0.tolist()

    layer_lists = [[] for _ in range(L)]
    picked = 0
    for idx in global_sorted:
        if picked >= selected_so_far:
            break  # every quota is met
        l = layer_of[idx]
        if len(layer_lists[l]) < quota[l]:
            layer_lists[l].append(idx)
            picked += 1

    selected_idx = [i for bucket in layer_lists for i in bucket]

    # ---- fill any residual shortfall globally (e.g., no-cap layers, rounding)
    shortfall = target - len(selected_idx)
    if shortfall > 0:
        # Take next lowest-importance blocks that aren't already selected
        chosen = set(selected_idx)
        fill = []
        for idx in global_sorted:
            if idx not in chosen:
                fill.append(idx)
                if len(fill) >= shortfall:
                    break
        selected_idx.extend(fill)

    # Trim if we somehow overshoot
    if len(selected_idx) > target:
        selected_idx = selected_idx[:target]

    # ---- pack output (all four columns as float32)
    sel = torch.tensor(selected_idx, device=device, dtype=torch.long)
    candidates = torch.stack([
        imp.index_select(0, sel),                         # importance (float32)
        lyr.index_select(0, sel).to(imp.dtype),           # layer_idx   (float)
        blk.index_select(0, sel).to(imp.dtype),           # block_idx   (float)
        typ.index_select(0, sel).to(imp.dtype),           # block_type  (float)
    ], dim=1)

    # ---- logging diagnostics
    imp_min = float(imp.min().item()) if imp.numel() else 0.0
    imp_max = float(imp.max().item()) if imp.numel() else 0.0
    logger.debug(f"[OptionA] importance range: [{imp_min:.6f}, {imp_max:.6f}]")
    # Per-layer summary, only for layers that received a quota
    avail_list = avail.tolist()
    w_list = w.tolist()
    s_list = s_clamped.tolist()
    for l in range(L):
        if avail_list[l] == 0:
            continue
        k_l = quota[l]
        if k_l == 0:
            continue
        w_l = w_list[l]
        s_l = s_list[l]
        logger.debug(
            f"[OptionA] Layer {l:02d}: sens={s_l:.6f}, weight={w_l:.6f}, "
            f"avail={avail_list[l]}, picked={k_l}"
        )

    return candidates

@torch.inference_mode()
def compute_candidate_blocks(
    global_sensitivity: torch.Tensor,
    blocks_importance: torch.Tensor,
    *,
    K_tot: int | None = None,
    alpha: float = 1.0,
    beta: float = 1.0,
    epsilon: float = 1e-6,
    device: str = "cuda",
) -> torch.Tensor:
    """Select pruning candidates per layer from sensitivity and importance spread.

    For every layer that owns at least one finite-importance block, a score
    ``R = (1 - normalised_sensitivity) * CV`` is computed separately for head
    and FFN blocks (CV = population std / (|mean| + epsilon)). Each layer's
    share of ``K_tot`` is proportional to ``R_head + R_ffn``; the share is
    split between heads and FFNs in proportion to how many of each the layer
    has. Any blocks still unallocated after flooring are handed out to layers
    in order of decreasing ``R``. Within a layer the *least* important blocks
    are picked.

    Parameters
    ----------
    global_sensitivity : torch.Tensor
        One sensitivity value per layer (indexable by layer index).
    blocks_importance : torch.Tensor
        2-D tensor with columns [importance, layer_idx, block_idx, block_type]
        where block_type 0 = head and 1 = ffn. Rows whose importance is not
        finite are discarded.
    K_tot : int | None
        Total number of candidates to select; defaults to the number of
        finite-importance rows.
    alpha, beta : float
        Accepted for signature parity with the other selectors; this
        implementation does not read them.
    epsilon : float
        Small constant guarding divisions.
    device : str
        Device used for the statistics and for empty results.

    Returns
    -------
    torch.Tensor
        Rows of ``blocks_importance`` that were selected, in the same
        4-column format. When the second pass had to redistribute a
        remainder, rows are grouped by layer in decreasing-``R`` order rather
        than by layer index.
    """
    logger = SharedLogger.get_logger("COMPUTE_CANDIDATE_BLOCKS")

    # 1. Keep only finite importance scores
    finite_mask = torch.isfinite(blocks_importance[:, 0])
    blocks_importance = blocks_importance[finite_mask]

    if blocks_importance.shape[0] == 0:
        logger.warning("No finite importance scores found")
        return torch.empty(0, 4, device=device, dtype=torch.float32)

    num_layers = len(global_sensitivity)
    if num_layers == 0:
        # Without this guard the min()/max() reductions below raise on an empty tensor.
        logger.warning("No valid layers found")
        return torch.empty(0, 4, device=device, dtype=torch.float32)

    # Debug: Overall statistics
    logger.debug(f"Computing candidate blocks for {num_layers} layers")
    logger.debug(f"Total blocks available: {blocks_importance.shape[0]}")
    logger.debug(f"Global sensitivity range: [{global_sensitivity.min():.6f}, {global_sensitivity.max():.6f}]")
    logger.debug(f"Blocks importance range: [{blocks_importance[:, 0].min():.6f}, {blocks_importance[:, 0].max():.6f}]")
    logger.debug("Block types: 0=head, 1=ffn")

    # Loop-invariant pieces of the min-max sensitivity normalisation.
    sens_min = global_sensitivity.min()
    sens_span = global_sensitivity.max() - sens_min + epsilon

    def _coefficient_of_variation(rows: torch.Tensor) -> torch.Tensor:
        """Population std / (|mean| + epsilon) of column 0; zero for fewer than two rows."""
        if len(rows) > 1:
            values = rows[:, 0].to(device=device, dtype=torch.float32)
            return torch.std(values, unbiased=False) / (torch.mean(values).abs() + epsilon)
        return torch.tensor(0.0, device=device)

    # 2. Compute per-layer statistics and CV values
    layer_stats = []
    for layer_idx in range(num_layers):
        layer_mask = blocks_importance[:, 1] == layer_idx
        if not torch.any(layer_mask):
            continue

        layer_blocks = blocks_importance[layer_mask]
        head_blocks = layer_blocks[layer_blocks[:, 3] == 0]
        ffn_blocks = layer_blocks[layer_blocks[:, 3] == 1]

        cv_head = _coefficient_of_variation(head_blocks)
        cv_ffn = _coefficient_of_variation(ffn_blocks)

        # Less sensitive layers (normalised sensitivity near 0) get a larger R.
        sensitivity_norm = (global_sensitivity[layer_idx] - sens_min) / sens_span
        R_head = (1 - sensitivity_norm) * cv_head
        R_ffn = (1 - sensitivity_norm) * cv_ffn

        layer_stats.append({
            'layer_idx': layer_idx,
            'head_blocks': head_blocks,
            'ffn_blocks': ffn_blocks,
            'R_head': R_head,
            'R_ffn': R_ffn,
            'cv_head': cv_head,
            'cv_ffn': cv_ffn
        })

        logger.debug(f"Layer {layer_idx}: R_head={R_head:.6f}, R_ffn={R_ffn:.6f}, CV_head={cv_head:.6f}, CV_ffn={cv_ffn:.6f}")

    if not layer_stats:
        logger.warning("No valid layers found")
        return torch.empty(0, 4, device=device, dtype=torch.float32)

    # 3. Normaliser for the combined (head + FFN) per-layer weights
    total_R_per_layer = sum(stat['R_head'] + stat['R_ffn'] for stat in layer_stats) + epsilon

    # 4. Distribute K_tot across layers and block types
    if K_tot is None:
        K_tot = blocks_importance.shape[0]

    # First pass: floor-based initial distribution
    layer_allocations = []
    total_allocated = 0

    for stat in layer_stats:
        layer_weight = (stat['R_head'] + stat['R_ffn']) / total_R_per_layer

        n_head = len(stat['head_blocks'])
        n_ffn = len(stat['ffn_blocks'])
        total_layer_blocks = n_head + n_ffn

        # Defaults cover a layer whose rows are neither head nor FFN; they also
        # keep the debug line below from reporting a previous layer's layer_k.
        layer_k = 0
        k_head = 0
        k_ffn = 0

        if total_layer_blocks > 0:
            layer_k = max(0, min(int(torch.floor(layer_weight * K_tot).item()), total_layer_blocks))

            if n_head > 0 and n_ffn > 0:
                head_ratio = n_head / total_layer_blocks
                ffn_ratio = n_ffn / total_layer_blocks

                k_head = max(0, min(int(head_ratio * layer_k), n_head))
                k_ffn = max(0, min(int(ffn_ratio * layer_k), n_ffn))

                # Truncation can leave part of layer_k unassigned: top up heads first, then FFNs.
                leftover = layer_k - k_head - k_ffn
                if leftover > 0:
                    if k_head < n_head:
                        additional = min(leftover, n_head - k_head)
                        k_head += additional
                        leftover -= additional
                    if leftover > 0 and k_ffn < n_ffn:
                        k_ffn += min(leftover, n_ffn - k_ffn)
            else:
                # Only one block type available
                k_head = layer_k if n_head > 0 else 0
                k_ffn = layer_k if n_ffn > 0 else 0

        logger.debug(f"Layer {stat['layer_idx']}: R_head={stat['R_head']:.6f}, R_ffn={stat['R_ffn']:.6f}")
        logger.debug(f"  layer_weight={layer_weight:.6f}, layer_k={layer_k}")
        logger.debug(f"  k_head={k_head}, k_ffn={k_ffn}, total={k_head + k_ffn}")

        layer_allocations.append({
            'stat': stat,
            'k_head': k_head,
            'k_ffn': k_ffn,
            'total_k': k_head + k_ffn
        })
        total_allocated += k_head + k_ffn

    # Second pass: hand out what flooring left over, highest-R layers first
    remaining = K_tot - total_allocated
    logger.debug(f"Initial allocation: {total_allocated}, target: {K_tot}, remaining: {remaining}")

    if remaining != 0:
        layer_allocations.sort(key=lambda x: x['stat']['R_head'] + x['stat']['R_ffn'], reverse=True)

        for allocation in layer_allocations:
            if remaining <= 0:
                break

            head_capacity = len(allocation['stat']['head_blocks'])
            if allocation['k_head'] < head_capacity:
                additional_head = min(remaining, head_capacity - allocation['k_head'])
                allocation['k_head'] += additional_head
                remaining -= additional_head
                logger.debug(f"  Added {additional_head} head blocks to layer {allocation['stat']['layer_idx']}, remaining: {remaining}")

            if remaining <= 0:
                break

            ffn_capacity = len(allocation['stat']['ffn_blocks'])
            if allocation['k_ffn'] < ffn_capacity:
                additional_ffn = min(remaining, ffn_capacity - allocation['k_ffn'])
                allocation['k_ffn'] += additional_ffn
                remaining -= additional_ffn
                logger.debug(f"  Added {additional_ffn} FFN blocks to layer {allocation['stat']['layer_idx']}, remaining: {remaining}")

    # Verify final allocation
    final_total = sum(allocation['k_head'] + allocation['k_ffn'] for allocation in layer_allocations)
    logger.debug(f"Final allocation: {final_total}, target: {K_tot}")

    if final_total != K_tot:
        logger.warning(f"Final allocation ({final_total}) does not match target ({K_tot})")

    # Third pass: pick the least-important blocks of each type in each layer
    selected_blocks = []
    for allocation in layer_allocations:
        stat = allocation['stat']
        for rows, quota in ((stat['head_blocks'], allocation['k_head']),
                            (stat['ffn_blocks'], allocation['k_ffn'])):
            if quota > 0:
                ascending = torch.argsort(rows[:, 0], descending=False)
                selected_blocks.append(rows[ascending[:quota]])

    # 5. Combine all selected blocks
    if selected_blocks:
        result = torch.cat(selected_blocks, dim=0)
    else:
        result = torch.empty(0, 4, device=device, dtype=torch.float32)

    logger.debug(f"Selected {len(result)} candidate blocks out of {blocks_importance.shape[0]} available (target: {K_tot})")
    if len(result) > 0:
        logger.debug(f"Selected blocks importance range: [{result[:, 0].min():.6f}, {result[:, 0].max():.6f}]")

    return result

@torch.inference_mode()
def compute_candidate_blocks_no_hyperparameter(global_sensitivity_dict: Dict[int, float], blocks_importance: torch.Tensor, epsilon: float = 1e-6, device: str = 'cuda'):
    """
    Compute how many candidate blocks to consider per layer.

    Each layer with at least one finite-importance block gets a score
    ``R_i = (var_i + 1) / (sensitivity_i + epsilon)``, where ``var_i`` is the
    population variance of that layer's importances (0 for a single block).
    The scores are normalised to weights and the layer's quota is
    ``K_i = min(ceil(weight_i * N_i), N_i)`` with ``N_i`` the layer's block count.

    Args:
        global_sensitivity_dict (dict): Maps layer index to global sensitivity.
        blocks_importance (torch.Tensor): 2-D tensor whose column 0 is the
            importance and column 1 the layer index. Rows with non-finite
            importance are ignored.
        epsilon (float): Small constant to prevent division by zero.
        device (str): Device on which the per-layer statistics are computed.

    Returns:
        dict: layer index -> K_i. Layers without finite blocks are omitted;
        the dict is empty when no layer qualifies.
    """
    # Drop rows whose importance is +/-inf or NaN.
    finite_rows = blocks_importance[torch.isfinite(blocks_importance[:, 0])]
    layer_ids = finite_rows[:, 1]

    layer_scores = []  # (layer_index, N_i, R_i)
    for index, sensitivity in global_sensitivity_dict.items():
        layer_importance = finite_rows[layer_ids == index][:, 0].to(device=device, dtype=torch.float32)
        N_i = layer_importance.numel()
        if N_i == 0:
            continue
        if N_i > 1:
            var_i = torch.var(layer_importance, unbiased=False)
        else:
            var_i = torch.tensor(0.0, device=device)
        R_i = (1.0 / (sensitivity + epsilon)) * (var_i + 1.0)
        layer_scores.append((index, N_i, R_i))

    # torch.stack rejects an empty list, so bail out before normalising.
    if not layer_scores:
        return {}

    total_R = torch.stack([score for _, _, score in layer_scores]).sum() + epsilon

    candidate_blocks_per_layer = {}
    for index, N_i, R_i in layer_scores:
        weight = R_i / total_R
        K_i = torch.ceil(weight * N_i).long().item()
        candidate_blocks_per_layer[index] = min(K_i, N_i)

    return candidate_blocks_per_layer

@torch.inference_mode()
def select_least_important_globally(importances, k):
    """
    Return the k rows with the smallest importance across all layers.

    Args:
        importances (torch.Tensor | list[torch.Tensor]): Either one tensor with
            columns [importance, layer_idx, block_idx, block_type] or a list of
            such tensors (legacy format), which is concatenated row-wise.
        k (int): Number of least important blocks to select.

    Returns:
        torch.Tensor: The selected rows in ascending importance order, in the
        same 4-column format. Rows with non-finite importance are never
        returned; the result is a (0, 4) tensor when none remain.
    """
    # Legacy callers pass one tensor per layer; flatten those into a single table.
    pool = torch.cat(importances, dim=0) if isinstance(importances, list) else importances

    # Discard rows whose importance is inf / -inf / NaN.
    pool = pool[torch.isfinite(pool[:, 0])]
    if pool.shape[0] == 0:
        return torch.empty(0, 4, device=pool.device, dtype=pool.dtype)

    # Ascending sort: smaller importance means less important.
    ascending = torch.argsort(pool[:, 0])
    return pool[ascending[:k]]


@torch.inference_mode()
def select_least_important_percentage(importances, percent):
    """
    Return the least important fraction of blocks.

    The number of rows to pick is ``max(1, int(total_rows * percent))`` where
    ``total_rows`` counts every input row; the actual selection is delegated
    to ``select_least_important_globally``.

    Args:
        importances (torch.Tensor | list[torch.Tensor]): One tensor with columns
            [importance, layer_idx, block_idx, block_type], or a list of such
            tensors (legacy format).
        percent (float): Fraction of blocks to select (0.0 to 1.0).

    Returns:
        torch.Tensor: Selected blocks in the same 4-column format.
    """
    if isinstance(importances, list):
        # Legacy format: add up the row counts of the individual tensors.
        row_count = 0
        for layer_imp in importances:
            row_count += layer_imp.shape[0]
    else:
        row_count = importances.shape[0]

    # Always request at least one block.
    requested = int(row_count * percent)
    if requested < 1:
        requested = 1
    return select_least_important_globally(importances, requested)

@torch.inference_mode()
def compute_candidate_blocks_pct_rank(global_sensitivity: torch.Tensor, blocks_importance: torch.Tensor, K_tot: Optional[int] = None, alpha: float = 1.0, beta: float = 1.0, epsilon: float = 1e-3, device: str = 'cuda') -> torch.Tensor:
    """
    Select candidate blocks per layer using percentile-rank normalisation.

    For every layer with at least two finite-importance blocks, the layer's
    sensitivity and its coefficient of variation (population std / (|mean| +
    epsilon)) are converted to percentile ranks in [0, 1] across those layers.
    Layers with a single block, and the case where only one layer can be
    ranked, use a neutral rank of 0.5. Each layer is scored with
    ``R_i = (sens_rank + epsilon)^(-alpha) * (cv_rank + epsilon)^beta`` and
    receives ``round(R_i / sum(R) * K_tot)`` candidates, clamped to
    ``[1, N_i]``. The per-layer counts are then nudged one at a time (highest
    R first when adding, lowest R first when removing) until they sum to
    ``K_tot`` or no layer can move any further. Within each layer the least
    important blocks are returned.

    Args:
        global_sensitivity (torch.Tensor): Global sensitivity values, one per layer.
        blocks_importance (torch.Tensor): 2-D tensor with columns
            [importance, layer_idx, block_idx, block_type]. Rows with
            non-finite importance are ignored.
        K_tot (Optional[int]): Total number of candidate blocks to select;
            defaults to the number of finite blocks in the covered layers.
        alpha (float): Exponent applied to the inverse sensitivity rank.
        beta (float): Exponent applied to the CV rank.
        epsilon (float): Small constant added to ranks and denominators.
        device (str): Device used for the statistics and for empty results.

    Returns:
        torch.Tensor: Selected rows of ``blocks_importance`` in the same
        4-column format, grouped by ascending layer index. A (0, 4) tensor is
        returned when there is nothing to select.
    """
    logger = SharedLogger.get_logger("COMPUTE_CANDIDATE_BLOCKS_PCT_RANK")

    # 1. Keep only finite importance scores ----------------------------------
    finite_mask = torch.isfinite(blocks_importance[:, 0])
    blocks_importance = blocks_importance[finite_mask]

    num_layers = len(global_sensitivity)
    if blocks_importance.shape[0] == 0 or num_layers == 0:
        # The min()/max() reductions used for logging below raise on empty tensors.
        logger.warning("No finite importance scores or no layers; returning no candidates")
        return torch.empty(0, 4, device=device, dtype=torch.float32)

    # Debug: Overall statistics
    logger.debug(f"Computing candidate blocks for {num_layers} layers using percentile rank normalization")
    logger.debug(f"Total blocks available: {blocks_importance.shape[0]}")
    logger.debug(f"Global sensitivity range: [{global_sensitivity.min():.6f}, {global_sensitivity.max():.6f}]")
    logger.debug(f"Alpha={alpha}, Beta={beta}, Epsilon={epsilon}")
    logger.debug(f"Blocks importance range: [{blocks_importance[:, 0].min():.6f}, {blocks_importance[:, 0].max():.6f}]")

    # 2. Per-layer importance statistics (computed once, in layer order) -----
    layer_info = []
    for layer_idx in range(num_layers):
        mask = blocks_importance[:, 1] == layer_idx
        if not torch.any(mask):
            continue
        imp_values = blocks_importance[mask][:, 0].to(device=device, dtype=torch.float32)
        mean_i = torch.mean(imp_values)
        if imp_values.numel() > 1:
            std_i = torch.std(imp_values, unbiased=False)
        else:
            std_i = torch.tensor(0.0, device=device)
        layer_info.append({
            'idx': layer_idx,
            'imp': imp_values,
            'mean': mean_i,
            'std': std_i,
            'cv': std_i / (mean_i.abs() + epsilon),  # use |mean| to avoid sign issues
        })

    if not layer_info:
        logger.warning("No layer index in range owns any block; returning no candidates")
        return torch.empty(0, 4, device=device, dtype=torch.float32)

    # 3. Percentile ranks over the layers that have a meaningful CV ----------
    # Ranks are keyed by the layer that produced them, so layers with a single
    # block can never shift another layer's rank.
    neutral_rank = torch.tensor(0.5, device=device)
    cv_rank_map: Dict[int, torch.Tensor] = {}
    sensitivity_rank_map: Dict[int, torch.Tensor] = {}

    ranked = [info for info in layer_info if info['imp'].numel() > 1]
    if ranked:
        cv_tensor = torch.stack([info['cv'] for info in ranked])
        sensitivity_tensor = torch.stack([global_sensitivity[info['idx']] for info in ranked])

        if len(ranked) > 1:
            denom = len(ranked) - 1
            cv_ranks = torch.argsort(torch.argsort(cv_tensor)) / denom
            sensitivity_ranks = torch.argsort(torch.argsort(sensitivity_tensor)) / denom
        else:
            # A single layer has no peers to be ranked against (and would divide by zero).
            cv_ranks = torch.full((1,), 0.5)
            sensitivity_ranks = torch.full((1,), 0.5)

        cv_ranks = cv_ranks.to(device)
        sensitivity_ranks = sensitivity_ranks.to(device)

        for pos, info in enumerate(ranked):
            cv_rank_map[info['idx']] = cv_ranks[pos]
            sensitivity_rank_map[info['idx']] = sensitivity_ranks[pos]

        logger.debug(f"CV percentile ranks: min={cv_ranks.min():.6f}, max={cv_ranks.max():.6f}")
        logger.debug(f"Sensitivity percentile ranks: min={sensitivity_ranks.min():.6f}, max={sensitivity_ranks.max():.6f}")
    else:
        logger.debug("No CV values found, using default normalization")

    # 4. Score every non-empty layer -----------------------------------------
    layer_stats: list[tuple[int, int, torch.Tensor]] = []  # (idx, N_i, R_i)

    for info in layer_info:
        layer_idx = info['idx']
        imp_values = info['imp']
        mean_i, std_i, CV_i = info['mean'], info['std'], info['cv']
        N_i = imp_values.numel()

        sensitivity_rank = sensitivity_rank_map.get(layer_idx, neutral_rank)
        cv_rank = cv_rank_map.get(layer_idx, neutral_rank)

        # Power-law score on percentile ranks: low sensitivity and high spread => large R_i.
        R_i = (sensitivity_rank + epsilon) ** (-alpha) * (cv_rank + epsilon) ** beta

        logger.debug(f"Layer {layer_idx}: N_i={N_i}, sensitivity={global_sensitivity[layer_idx]:.6f}")
        logger.debug(f"  imp_values: min={imp_values.min():.6f}, max={imp_values.max():.6f}")
        logger.debug(f"  mean_i={mean_i:.6f}, std_i={std_i:.6f}")
        logger.debug(f"  CV_i={CV_i:.6f} (std/|mean|)")
        logger.debug(f"  sensitivity_rank={sensitivity_rank:.6f}, cv_rank={cv_rank:.6f}")
        logger.debug(f"  R_i={R_i:.6f} = (sensitivity_rank^({-alpha}) * cv_rank^({beta}))")
        logger.debug(f"  sensitivity_rank^(-alpha)={(sensitivity_rank + epsilon) ** (-alpha):.6f}")
        logger.debug(f"  cv_rank^beta={(cv_rank + epsilon) ** beta:.6f}")

        # Additional numerical stability checks
        if torch.isnan(R_i) or torch.isinf(R_i):
            logger.warning(f"Layer {layer_idx}: R_i is {R_i} - numerical issue detected!")
        if CV_i > 1000:  # Very high CV might indicate numerical instability
            logger.warning(f"Layer {layer_idx}: CV_i={CV_i:.6f} is very high - check for numerical issues")
        if mean_i.abs() < 1e-8:  # Very small mean might cause CV explosion
            logger.warning(f"Layer {layer_idx}: mean_i={mean_i:.6e} is very small - CV computation might be unstable")

        layer_stats.append((layer_idx, N_i, R_i))

    # Debug: R_i distribution summary
    R_values = [entry[2] for entry in layer_stats]
    R_tensor = torch.stack(R_values)
    logger.debug(f"R_i distribution: min={R_tensor.min():.6f}, max={R_tensor.max():.6f}, mean={R_tensor.mean():.6f}")
    logger.debug(f"R_i values: {[f'{r:.6f}' for r in R_values]}")

    # 5. Normalise R_i to obtain k_i -----------------------------------------
    total_R = R_tensor.sum() + epsilon
    logger.debug(f"Total R across all layers: {total_R:.6f}")

    if K_tot is None:
        # Default to total number of blocks considered
        K_tot = sum(entry[1] for entry in layer_stats)

    logger.debug(f"K_tot (total candidate blocks): {K_tot}")

    candidate_blocks_per_layer: Dict[int, int] = {}

    for layer_idx, N_i, R_i in layer_stats:
        weight = R_i / total_R
        k_i = int(torch.round(weight * K_tot).item())
        k_i = max(1, min(k_i, N_i))  # at least one, at most N_i
        candidate_blocks_per_layer[layer_idx] = k_i

        logger.debug(f"Layer {layer_idx}: weight={weight:.6f}, k_i={k_i} (N_i={N_i})")

    # Layers without blocks are reported with a zero quota.
    for layer_idx in range(num_layers):
        candidate_blocks_per_layer.setdefault(layer_idx, 0)

    # 6. Adjust to make sum(k_i) == K_tot (rounding/clamping can miss it) ----
    diff = K_tot - sum(candidate_blocks_per_layer.values())
    logger.debug(f"Initial k_i sum: {sum(candidate_blocks_per_layer.values())}, target: {K_tot}, diff: {diff}")

    if diff != 0:
        logger.debug("Adjusting k_i values to match K_tot...")
        capacity = {idx: n for idx, n, _ in layer_stats}
        score = {idx: float(r) for idx, _, r in layer_stats}
        sign = 1 if diff > 0 else -1

        # Only layers that own blocks can be adjusted. Add to the highest-R
        # layers first, remove from the lowest-R layers first.
        adjustable = sorted(capacity, key=score.__getitem__, reverse=(diff > 0))

        # Round-robin one unit at a time; stop once a full pass changes nothing.
        progressed = True
        while diff != 0 and progressed:
            progressed = False
            for idx in adjustable:
                if diff == 0:
                    break
                new_val = candidate_blocks_per_layer[idx] + sign
                if 1 <= new_val <= capacity[idx]:
                    candidate_blocks_per_layer[idx] = new_val
                    diff -= sign
                    progressed = True
                    logger.debug(f"  Adjusted Layer {idx}: {new_val - sign} -> {new_val}")

    logger.debug(f"Final k_i sum: {sum(candidate_blocks_per_layer.values())}")
    logger.debug(f"Final candidate_blocks_per_layer: {candidate_blocks_per_layer}")

    # Final verification
    final_sum = sum(candidate_blocks_per_layer.values())
    if final_sum != K_tot:
        logger.warning(f"Final k_i sum ({final_sum}) does not match target ({K_tot}), difference: {final_sum - K_tot}")
    else:
        logger.debug(f"Successfully distributed all {K_tot} candidate blocks")

    # 7. Materialise the quotas as rows: least important blocks per layer ----
    result_blocks = []
    for layer_idx in range(num_layers):
        k_i = candidate_blocks_per_layer.get(layer_idx, 0)
        if k_i > 0:
            layer_blocks = blocks_importance[blocks_importance[:, 1] == layer_idx]
            sorted_indices = torch.argsort(layer_blocks[:, 0], descending=False)
            result_blocks.append(layer_blocks[sorted_indices[:k_i]])

    if result_blocks:
        result = torch.cat(result_blocks, dim=0)
    else:
        result = torch.empty(0, 4, device=device, dtype=torch.float32)

    logger.debug(f"Returning {result.shape[0]} candidate blocks in tensor format")
    return result

# ----------------------------
# Helper: pick K rows per layer
# ----------------------------
@torch.inference_mode()
def _pick_k_in_layer(rows: torch.Tensor, k: int, least: bool = True) -> torch.Tensor:
    """
    Choose k rows of one layer by their importance (column 0).

    rows:  [m,4] -> [importance, layer_idx, local_idx, block_type]
    k:     how many rows to return
    least: True picks the k smallest importances (pruning candidates),
           False picks the k largest.

    Returns an empty [0, rows.size(1)] tensor when k <= 0 or rows is empty,
    and `rows` itself when k covers every row. Otherwise the k chosen rows
    are returned in no particular order.
    """
    nothing_to_pick = k <= 0 or rows.numel() == 0
    if nothing_to_pick:
        return rows.new_zeros((0, rows.size(1)))

    if rows.size(0) <= k:
        return rows

    # topk with largest=False yields the smallest values.
    want_largest = not least
    picked = torch.topk(rows[:, 0], k=k, largest=want_largest, sorted=False).indices
    return rows[picked]


# ------------------------------------------
# Rank-based per-layer candidate allocations
# ------------------------------------------
@torch.inference_mode()
def rank_based_weights(
    sens: torch.Tensor,
    p: float = 1.5,
    sens_floor: float | None = None,
    cap: float | None = None,
):
    """
    Convert per-layer sensitivities into allocation weights using ranks.
    Lower sensitivity => larger weight (more candidates).
    - sens: shape [L] (any dtype/ device)
    - p: exponent applied to (L - rank + 1). Higher p => more skew to low-sensitivity layers
    - sens_floor, cap: optional (kept for API parity; not needed for rank mode)

    Returns:
      w: shape [L], float, normalized to sum to 1 on sens.device
    """
    L = sens.numel()
    if L == 0:
        return sens.new_zeros(())
    # ranks: 1..L (1 = least sensitive, L = most sensitive)
    order = torch.argsort(sens, dim=0, descending=False)
    ranks_long = torch.empty_like(order, dtype=torch.long, device=sens.device)
    ranks_long[order] = torch.arange(L, device=sens.device, dtype=torch.long)

    # Convert to 1..L
    ranks1 = ranks_long + 1  # long
    # Score: higher for less sensitive layers (small rank)
    # Use float for pow
    scores = (L - ranks1 + 1).to(dtype=torch.float32, device=sens.device).pow(p)
    w = scores / scores.sum()
    return w


# ---------------------------------------------------------
# Option A (ranked): select candidates across all modules
# ---------------------------------------------------------
@torch.inference_mode()
def select_candidates_option_a_rank(
    sensitivities: torch.Tensor,           # [L]
    blocks_info: torch.Tensor,             # [N,4] = [importance, layer_idx, local_idx, block_type]
    total_candidates: int,
    device: torch.device | str = "cpu",
    p: float = 1.5,
    logger=None,
):
    """
    Split `total_candidates` across layers with rank-based sensitivity weights
    (see `rank_based_weights`) and take the *least-important* blocks of each
    layer up to its quota.

    Quotas are the floored shares; the units lost to flooring go to the layers
    with the largest fractional parts. Each quota is then capped by the number
    of blocks the layer actually has, so fewer than `total_candidates` rows may
    come back.

    Returns:
      selected: [M,4] with M <= total_candidates. An empty [0,4] tensor is
      returned when nothing is requested, nothing is available, or no layer
      ends up with a positive quota.
    """
    grad_guard = torch.no_grad() if torch.is_grad_enabled() else nullcontext()
    with grad_guard:
        nothing_requested = total_candidates <= 0
        if nothing_requested or blocks_info.numel() == 0:
            return blocks_info.new_zeros((0, 4))

        blocks_info = blocks_info.to(device)
        sens = sensitivities.to(device)
        L = sens.numel()

        # Layer id of every block and how many blocks each layer offers.
        layer_of_block = blocks_info[:, 1].long()
        avail_per_layer = torch.bincount(layer_of_block, minlength=L)

        # Rank-based weights (sum to 1).
        w = rank_based_weights(sens, p=p)

        # Largest-remainder apportionment: floor first, then hand out what is left.
        ideal = w * float(total_candidates)
        floored = torch.floor(ideal)
        quota = floored.clone().long()
        leftover = int(total_candidates - int(quota.sum().item()))
        if leftover > 0:
            n_bump = min(leftover, L)
            bump = torch.topk(ideal - floored, k=n_bump, largest=True).indices
            quota[bump] += 1
            leftover -= n_bump
            if leftover > 0:
                # Anything still left goes to the first layers, one unit each.
                first_layers = torch.arange(L, device=device)[:leftover]
                quota[first_layers] += 1

        # A layer cannot contribute more blocks than it has.
        quota = torch.minimum(quota, avail_per_layer)

        if logger is not None:
            logger.debug(
                f"[OptionA-Rank] L={L}, N={blocks_info.size(0)}, target={total_candidates}, "
                f"p={p:.3f} | sens[min,max]=({sens.min().item():.6f},{sens.max().item():.6f}) "
                f"| w[min,max]=({w.min().item():.6f},{w.max().item():.6f})"
            )

        per_layer_picks = []
        imp_min = float("inf")
        imp_max = float("-inf")

        for i in range(L):
            ki = int(quota[i].item())
            if ki <= 0:
                continue
            rows = blocks_info[layer_of_block == i]
            if rows.numel() == 0:
                continue

            # Least-important rows of this layer.
            chosen = _pick_k_in_layer(rows, ki, least=True)
            per_layer_picks.append(chosen)

            if logger is None or chosen.numel() == 0:
                continue
            imp_min = min(imp_min, float(chosen[:, 0].min().item()))
            imp_max = max(imp_max, float(chosen[:, 0].max().item()))
            logger.debug(
                f"[OptionA-Rank] Layer {i:02d}: sens={sens[i].item():.6f}, "
                f"weight={w[i].item():.6f}, avail={int(rows.size(0))}, picked={ki}"
            )

        if not per_layer_picks:
            return blocks_info.new_zeros((0, 4))

        selected = torch.cat(per_layer_picks, dim=0)

        if logger is not None:
            logger.debug(f"[OptionA-Rank] importance range among selected: [{imp_min:.6f}, {imp_max:.6f}]")
            logger.debug(f"[OptionA-Rank] selected {selected.size(0)} / {total_candidates} candidates.")

        return selected


# ---------------------------------------------------------
# (Optional) Option A (direct sensitivity^(-p)) with least
# ---------------------------------------------------------
@torch.inference_mode()
def select_candidates_option_a(
    sensitivities: torch.Tensor,           # [L]
    blocks_info: torch.Tensor,             # [N,4]
    total_candidates: int,
    device: torch.device | str = "cpu",
    p: float = 1.5,
    sens_floor: float = 1e-3,
    cap: float = 50.0,
    logger=None,
):
    """
    Direct sensitivity reweighting: w_i ∝ (max(sens, sens_floor))^-p,
    clipped to at most `cap`× the average weight, and then pick the
    *least-important* blocks per layer.

    Args:
      sensitivities: [L] per-layer sensitivity scores.
      blocks_info: [N,4] rows; column 0 is read as importance and column 1 as
        the layer index (the remaining columns are passed through untouched).
      total_candidates: global number of rows to aim for.
      device: device on which all computation and the result live.
      p: exponent of the inverse-sensitivity weighting.
      sens_floor: lower clamp applied to sensitivities before the power.
      cap: clip raw weights to `cap` × their mean; `None` or a value <= 0
        disables clipping.
      logger: optional object exposing `.debug(str)`.

    Returns:
      selected: [M,4] on `device`. M can be smaller than `total_candidates`
      because per-layer quotas are clamped to the number of available blocks
      and the clipped amount is not redistributed.
    """
    # NOTE: inference_mode (decorator) already disables autograd, so no extra
    # no_grad context is needed here.
    if total_candidates <= 0 or blocks_info.numel() == 0:
        return blocks_info.new_zeros((0, 4), device=device)

    blocks_info = blocks_info.to(device)
    sens = sensitivities.to(device)

    L = sens.numel()
    layer_idx_all = blocks_info[:, 1].long()

    avail_per_layer = torch.bincount(layer_idx_all, minlength=L)

    # robust weights from sensitivities
    s = torch.clamp(sens, min=sens_floor).to(dtype=torch.float32)
    raw_w = s.pow(-p)
    if cap is not None and cap > 0:
        avg = raw_w.mean()
        raw_w = torch.clamp(raw_w, max=avg * cap)
    w = raw_w / raw_w.sum()

    # integer allocation: floor, then hand the remainder to the largest
    # fractional parts
    raw = w * float(total_candidates)
    k_floor = torch.floor(raw)
    k = k_floor.clone().long()
    remainder = int(total_candidates - int(k.sum().item()))
    if remainder > 0:
        frac = (raw - k_floor)
        top = torch.topk(frac, k=min(remainder, L), largest=True).indices
        k[top] += 1
        remainder -= min(remainder, L)
        if remainder > 0:
            # round-robin over the leading layers if anything still remains
            rr = torch.arange(L, device=blocks_info.device)
            k[rr[:remainder]] += 1

    # clamp by availability
    k = torch.minimum(k, avail_per_layer)

    if logger is not None:
        # `cap` may legitimately be None (clipping disabled); formatting it
        # with ':.1f' would raise TypeError, so render it separately.
        cap_str = f"{cap:.1f}x" if cap is not None else "None"
        logger.debug(
            "[OptionA-Robust] "
            f"L={L}, N={blocks_info.size(0)}, target={total_candidates}, "
            f"p={p:.3f}, sens_floor={sens_floor:.1e}, cap={cap_str} | "
            f"sens[min,max]=({sens.min().item():.6f},{sens.max().item():.6f}) | "
            f"raw_w[min,max]=({raw_w.min().item():.6f},{raw_w.max().item():.6f}) | "
            f"w[min,max]=({w.min().item():.6f},{w.max().item():.6f})"
        )

    out = []
    imp_min = float("inf")
    imp_max = float("-inf")

    for i in range(L):
        ki = int(k[i].item())
        if ki <= 0:
            continue
        rows = blocks_info[layer_idx_all == i]
        if rows.numel() == 0:
            continue

        # _pick_k_in_layer is defined elsewhere in this file; it is called
        # with least=True to take the `ki` least-important rows.
        chosen = _pick_k_in_layer(rows, ki, least=True)
        out.append(chosen)

        if logger is not None and chosen.numel() > 0:
            imp_min = min(imp_min, float(chosen[:, 0].min().item()))
            imp_max = max(imp_max, float(chosen[:, 0].max().item()))
            logger.debug(
                f"[OptionA-Robust] Layer {i:02d}: sens={sens[i].item():.6f}, "
                f"weight={w[i].item():.6f}, avail={int(rows.size(0))}, picked={ki}"
            )

    if len(out) == 0:
        return blocks_info.new_zeros((0, 4))

    selected = torch.cat(out, dim=0)

    if logger is not None:
        logger.debug(f"[OptionA-Robust] importance range: [{imp_min:.6f}, {imp_max:.6f}]")
        logger.debug(f"[OptionA-Robust] selected {selected.size(0)} / {total_candidates} candidates.")

    return selected

@torch.inference_mode()
def norm_rank01(x: torch.Tensor) -> torch.Tensor:
    """
    Map each entry of `x` to its rank, rescaled into [0,1].

    The smallest value maps to 0 and the largest to 1; equal values keep
    their original relative order (stable sort). A single-element input
    maps to 0.
    """
    n = x.numel()
    perm = torch.argsort(x, stable=True)
    positions = torch.arange(n, device=x.device, dtype=torch.float32)

    # Write position j into the slot of the j-th smallest element.
    rank_of = torch.empty_like(perm, dtype=torch.float32)
    rank_of.index_put_((perm,), positions)

    denom = max(1, n - 1)  # avoid division by zero for n == 1
    return rank_of / denom


@torch.inference_mode()
def norm_minmax_safe(x: torch.Tensor, clip_q=(0.0, 1.0), eps: float = 1e-12) -> torch.Tensor:
    """
    Min-max scale `x` into [0,1], optionally clipping outliers first.

    clip_q = (lo, hi) are quantile levels in [0,1]; values outside the
    corresponding quantiles are clamped before scaling, e.g. (0.01, 0.99).
    The default (0.0, 1.0) skips clipping entirely. `eps` keeps the
    denominator non-zero for constant inputs.
    """
    wants_clip = clip_q != (0.0, 1.0)
    if wants_clip:
        lower = torch.quantile(x, clip_q[0])
        upper = torch.quantile(x, clip_q[1])
        x = torch.clamp(x, min=lower, max=upper)

    floor_val = x.min()
    span = x.max() - floor_val + eps
    return (x - floor_val) / span


@torch.no_grad()
def flatness_cv_per_layer(block_info: torch.Tensor, L: int, eps: float = 1e-12) -> torch.Tensor:
    """
    Coefficient of variation (std / |mean|) of block importances, per layer.

    block_info: [N,4] = [importance, layer_idx, local_idx, module_type]
    Returns: cv[L]. Layers without any block get the neutral value 1.0;
    every computed value is floored at 1e-6.
    """
    importance = block_info[:, 0]
    layer_of = block_info[:, 1].long()

    # Start from the neutral value; only layers that own blocks are overwritten.
    cv = torch.ones(L, device=block_info.device, dtype=importance.dtype)
    for layer in range(L):
        vals = importance[layer_of == layer]
        if vals.numel() == 0:
            continue
        if vals.numel() > 1:
            spread = vals.std(unbiased=False)
        else:
            spread = torch.tensor(0.0, device=importance.device, dtype=importance.dtype)
        center = vals.mean()
        cv[layer] = (spread / (center.abs() + eps)).clamp_min(1e-6)
    return cv


@torch.inference_mode()
def flatness_boost_from_cv(cv: torch.Tensor, q: float = 1.0, cap: float = 4.0) -> torch.Tensor:
    """
    Turn per-layer CV into a multiplicative boost: flatter layers (lower CV)
    receive a larger boost.

      boost_i = clamp( (median_cv / cv_i)^q, 1/cap, cap )

    Returns: boost[L]
    """
    ratio = cv.median() / cv
    powered = torch.pow(ratio, q)
    return torch.clamp(powered, min=1.0 / cap, max=cap)

@torch.inference_mode()
def select_candidates_option_a_normalized(
    sensitivities: torch.Tensor,            # [L], higher = more sensitive
    block_info: torch.Tensor,               # [N,4] = [importance, layer_idx, local_idx, module_type]
    total_candidates: int,                  # total M to select
    device=None,
    p_sens: float = 1.5,                    # steeper focus on less sensitive layers
    q_flat: float = 1.0,                    # flatness boost strength
    cap_flat: float = 4.0,                  # flatness boost cap
    logger=None,
):
    """
    Allocate a per-layer candidate quota from rank-inverted sensitivity
    multiplied by a flatness boost, then take the least-important blocks
    inside each layer.

    Args:
      sensitivities: [L] per-layer sensitivity (higher = more sensitive).
      block_info: [N,4] = [importance, layer_idx, local_idx, module_type].
      total_candidates: global number of rows to aim for.
      device: device used for the intermediate computation; defaults to
        `block_info.device`.
      p_sens: exponent applied to the inverted sensitivity rank.
      q_flat: exponent of the flatness boost.
      cap_flat: the flatness boost is clamped to [1/cap_flat, cap_flat].
      logger: optional object exposing `.debug(str)`.

    Returns:
      candidates: [M,4] rows of `block_info` (on `block_info`'s device),
      M <= total_candidates. M can fall short when per-layer availability
      caps cannot be fully rebalanced. An empty [0,4] tensor is returned when
      there is nothing to select.
    """
    if device is None:
        device = block_info.device
    L = sensitivities.numel()
    N = block_info.size(0)

    # Nothing to allocate: avoids median/min/max over empty tensors below.
    if total_candidates <= 0 or N == 0 or L == 0:
        return block_info.new_zeros((0, 4))

    sens = sensitivities.to(device).float()

    # --- Sensitivity weight (invert by rank so less sensitive gets larger weight) ---
    s_rank01 = norm_rank01(sens)            # 0..1 (largest sens → 1)
    inv_rank = 1.0 - s_rank01               # less sensitive → closer to 1
    w_sens = inv_rank.clamp_min(1e-6).pow(p_sens)

    # --- Flatness boost from CV of importances (flatter => more candidates) ---
    cv = flatness_cv_per_layer(block_info.to(device), L)
    boost = flatness_boost_from_cv(cv, q=q_flat, cap=cap_flat)

    # --- Combined layer weights ---
    w = w_sens * boost
    w = (w / (w.sum() + 1e-12)).clamp_min(1e-12)

    # --- Per-layer target counts (float → int with rounding & fixup) ---
    k_float = w * float(total_candidates)
    k = torch.floor(k_float).long()
    leftover = total_candidates - int(k.sum().item())

    if leftover > 0:
        # distribute remainder to layers with largest fractional parts
        frac = (k_float - k.float())
        idx = torch.argsort(frac, descending=True)
        k[idx[:leftover]] += 1
    elif leftover < 0:
        # (rare) trim from smallest fractional parts
        frac = (k_float - k.float())
        idx = torch.argsort(frac, descending=False)
        k[idx[:(-leftover)]] -= 1

    # --- Cap by available blocks per layer, and re-balance if needed ---
    layer_idx = block_info[:, 1].long().to(device)
    avail = torch.zeros(L, device=device, dtype=torch.long)
    for i in range(L):
        avail[i] = (layer_idx == i).sum()

    over = torch.clamp(k - avail, min=0)
    give_back = int(over.sum().item())
    k = torch.minimum(k, avail)

    if give_back > 0:
        # redistribute to layers with headroom proportional to weights
        headroom = (avail - k).clamp_min(0)
        if headroom.sum() > 0:
            h_w = (w * headroom.float())
            h_w = h_w / (h_w.sum() + 1e-12)
            add_float = h_w * float(give_back)
            add = torch.floor(add_float).long()
            leftover2 = give_back - int(add.sum().item())
            if leftover2 > 0:
                frac2 = (add_float - add.float())
                idx2 = torch.argsort(frac2, descending=True)
                add[idx2[:leftover2]] += 1
            # single pass only: anything above a layer's headroom is dropped
            k = k + torch.minimum(add, headroom)

    if logger is not None:
        logger.debug(
            f"[OptionA-Norm] L={L}, N={N}, target={total_candidates}, "
            f"p_sens={p_sens:.2f}, q_flat={q_flat:.2f}, cap_flat={cap_flat:.1f}"
        )
        logger.debug(
            f"[OptionA-Norm] sens[min,max]=({sens.min().item():.6f},{sens.max().item():.6f}), "
            f"CV[med]={cv.median().item():.4f}"
        )
        # Spot-check a fixed set of layers, skipping indices the model does
        # not have (the unguarded list raised IndexError for L < 32).
        probe_layers = [i for i in (0, 1, 8, 16, 24, 27, 30, 31) if i < L]
        for i in probe_layers:
            logger.debug(
                f"[OptionA-Norm] L{i:02d}: inv_rank={inv_rank[i].item():.3f}, "
                f"cv={cv[i].item():.4f}, boost={boost[i].item():.3f}, "
                f"w={w[i].item():.6f}, k={k[i].item()}, avail={avail[i].item()}"
            )

    # --- Within each layer: select the k[i] LEAST important blocks ---
    importance = block_info[:, 0].to(device)
    out = []
    for i in range(L):
        ki = int(k[i].item())
        if ki <= 0:
            continue
        mask = (layer_idx == i)
        if not mask.any():
            continue
        idx_i = torch.nonzero(mask, as_tuple=False).squeeze(1)
        # sort ascending by importance (least important first)
        sort_i = torch.argsort(importance[idx_i], descending=False)
        take = idx_i[sort_i[:ki]]
        # indices were computed on `device`; bring them to block_info's device
        # so the gather is valid when the two differ
        out.append(block_info[take.to(block_info.device)])

    if not out:
        return block_info.new_zeros((0, 4))

    return torch.cat(out, dim=0)

# ---------- stable normalizers ----------
@torch.no_grad()
def normalize_minmax_stable(x: torch.Tensor, eps: float = 1e-12) -> torch.Tensor:
    """Min-max scale to [0,1] in float32; a constant input yields all zeros."""
    values = x.to(dtype=torch.float32)
    lo = values.min()
    span = values.max() - lo

    # Constant input: return zeros rather than amplifying rounding noise.
    if span < eps:
        return torch.zeros_like(values, dtype=torch.float32)

    return (values - lo) / span.clamp_min(eps)

@torch.no_grad()
def inv_rank01_from_sensitivity(sens: torch.Tensor) -> torch.Tensor:
    """
    Inverse-sensitivity score in [0,1] based purely on rank:
      most sensitive layer -> 0, least sensitive layer -> 1.
    Working on ranks keeps the result insensitive to outliers and free of
    divisions by the sensitivity values themselves. With a single layer the
    result is 0.
    """
    n = sens.numel()
    by_desc = torch.argsort(sens, descending=True)         # most sensitive first

    # position[i] = place of layer i in the descending ordering
    position = torch.empty_like(sens, dtype=torch.long)
    position[by_desc] = torch.arange(n, device=sens.device, dtype=torch.long)

    if n <= 1:
        return torch.zeros_like(sens, dtype=torch.float32)
    return position.to(torch.float32) / (n - 1)


# ---------- per-module CV (no extra hyperparams) ----------

@torch.no_grad()
def per_module_cv(block_importance: torch.Tensor, num_layers: int) -> torch.Tensor:
    """
    Coefficient of variation of importances for every (layer, module_type)
    group, min-max normalized over the whole map.

    Input blocks: [N,4] = [importance, layer_idx, local_idx, module_type]
    Returns a [num_layers, 2] float32 map in [0,1] (module_type ∈ {0,1}).
    Groups with fewer than two blocks contribute a raw CV of 0 before the
    normalization.
    """
    scores = block_importance[:, 0]
    layer_col = block_importance[:, 1]
    type_col = block_importance[:, 3]

    raw_cv = torch.zeros((num_layers, 2), device=block_importance.device, dtype=torch.float32)
    for layer in range(num_layers):
        in_layer = (layer_col == layer)
        for module_type in (0, 1):
            members = in_layer & (type_col == module_type)
            if int(members.sum().item()) < 2:
                # not enough samples for a spread estimate; entry stays 0
                continue
            vals = scores[members].to(torch.float32)
            ratio = vals.std(unbiased=False) / (vals.mean().abs() + 1e-12)
            raw_cv[layer, module_type] = ratio.clamp_min(0.0)

    # Global [0,1] scaling, so a caller can use 1 + (1 - cv_norm) as a boost.
    return normalize_minmax_stable(raw_cv)


# ---------- Option A (single hyperparam p) ----------

@torch.no_grad()
def select_candidates_option_a_normalized_final(
    sensitivities: torch.Tensor,          # [L] (raw)
    block_importance: torch.Tensor,       # [N,4] = [importance, layer_idx, local_idx, module_type]
    total_candidates: int,
    *,
    p: float = 1.5,                       # the ONLY hyperparameter
    device: str | torch.device = "cpu",
) -> torch.Tensor:
    """
    Option A candidate selection with a single hyperparameter `p`.

    Steps:
      1) rank-normalize sensitivities (inverse rank 0..1), then w_layer = (inv_rank01 + eps)^p
      2) split each layer's quota across modules using a *flatness boost* derived from per-module CV:
         boost = 1 + (1 - cv_norm)  (i.e., more candidates where scores are flat/close)
      3) within each (layer, module), select *least-important* blocks
      4) obey availability caps; never allocate to empty groups; try to hit global target

    Args:
      sensitivities: [L] raw per-layer sensitivities.
      block_importance: [N,4] rows (importance, layer_idx, local_idx, module_type).
      total_candidates: global number of rows to aim for.
      p: exponent applied to the inverse sensitivity rank.
      device: device on which inputs are moved and the result is produced.

    Returns:
      [M,4] float32 rows (importance, layer_idx, local_idx, module_type) on
      `device`, grouped by layer and then by module type; M≈total_candidates
      (M is smaller when not enough blocks are available). An empty [0,4]
      tensor is returned when N == 0 or total_candidates <= 0.

    Raises:
      AssertionError: if `block_importance` does not have 4 columns.
    """
    logger = SharedLogger.get_logger("SELECT_CANDIDATES_OPTION_A")

    # Move & types
    sens = sensitivities.to(device=device, dtype=torch.float32)
    blocks = block_importance.to(device=device, dtype=torch.float32)
    assert blocks.shape[1] == 4, "block_importance must be [N,4]"

    L = sens.numel()
    N = blocks.size(0)
    T = int(total_candidates)

    # Early exit
    if N == 0 or T <= 0:
        return torch.empty(0, 4, device=device, dtype=torch.float32)

    # Availability per layer (all module types) and per (layer, module)
    # NOTE(review): only module types 0 and 1 are tracked in avail_LM; rows with
    # any other module_type still count towards avail_L but can never be
    # selected below — confirm callers only pass types {0,1}.
    layer_idx = blocks[:, 1].long()
    mod_idx   = blocks[:, 3].long()
    avail_L   = torch.bincount(layer_idx, minlength=L).to(device)
    avail_LM  = torch.zeros((L, 2), device=device, dtype=torch.long)
    for mt in (0, 1):
        mask = (mod_idx == mt)
        if mask.any():
            counts = torch.bincount(layer_idx[mask], minlength=L)
            avail_LM[:, mt] = counts[:L].to(torch.long)

    # 1) Layer weights from inverse rank of sensitivity.
    #    The most sensitive layer has inv01 == 0, so its weight is (1e-12)**p,
    #    i.e. positive but negligible.
    inv01 = inv_rank01_from_sensitivity(sens)              # 0..1
    w_layer = (inv01 + 1e-12).pow(p)
    # Zero-out layers with no blocks
    w_layer = torch.where(avail_L > 0, w_layer, torch.zeros_like(w_layer))
    w_sum = w_layer.sum()
    if w_sum <= 0:
        # Fallback: uniform across non-empty layers
        w_layer = torch.where(avail_L > 0, torch.ones_like(w_layer), torch.zeros_like(w_layer))
        w_sum = w_layer.sum()
    w_layer = w_layer / w_sum

    if logger:
        smin, smax = sens.min().item(), sens.max().item()
        logger.debug(f"[OptionA] L={L}, N={N}, T={T}, p={p} | sens[min,max]=({smin:.6f},{smax:.6f})")
        logger.debug(f"[OptionA] non-empty layers: {(avail_L>0).sum().item()} / {L}")

    # 2) Split layer quota across modules using CV-based flatness boost (no extra hyperparams)
    cv_norm = per_module_cv(blocks, num_layers=L)          # 0..1 per (layer, module)
    boost = 1.0 + (1.0 - cv_norm)                          # ∈ [1,2], bigger when flatter (scores closer)

    # Helper: integer distribution with caps
    def distribute_with_caps(total: int, weights: torch.Tensor, caps: torch.Tensor) -> torch.Tensor:
        """
        Distribute 'total' across dims with fractional rounding and per-dim caps.

        Entries with cap 0 get weight 0. Each entry first receives
        floor(weight share * total), clipped to its cap; the remainder is then
        handed out greedily in order of decreasing fractional part, each entry
        taking as much as its remaining room allows. The result can sum to less
        than `total` when the caps are too small. Returns a long tensor shaped
        like `caps`.
        """
        total = int(total)
        if total <= 0 or weights.sum() <= 0 or caps.sum() <= 0:
            return torch.zeros_like(caps, dtype=torch.long)

        w = weights.clone()
        w = torch.where(caps > 0, w, torch.zeros_like(w))
        s = w.sum()
        if s <= 0:
            return torch.zeros_like(caps, dtype=torch.long)
        w = w / s

        raw = (w * total).to(torch.float32)
        base = torch.floor(raw).to(torch.long)
        # Respect caps
        base = torch.minimum(base, caps)
        rem = int(total - base.sum().item())
        if rem <= 0:
            return base

        # Greedy fill remaining into those with largest fractional part and remaining cap
        # (frac is measured against the capped base, so capped entries rank
        # first but are skipped because they have no room left)
        frac = (raw - base.to(torch.float32))
        order = torch.argsort(frac, descending=True)
        for idx in order.tolist():
            if rem == 0:
                break
            room = int(caps[idx].item() - base[idx].item())
            if room > 0:
                take = min(room, rem)
                base[idx] += take
                rem -= take
        return base

    # First, allocate per-layer counts with availability caps
    caps_layer = avail_L.clone()
    counts_layer = distribute_with_caps(T, w_layer, caps_layer)

    # 3) Within each layer, split to modules by boost and module availability
    counts_LM = torch.zeros((L, 2), device=device, dtype=torch.long)
    for l in range(L):
        cap0, cap1 = int(avail_LM[l, 0].item()), int(avail_LM[l, 1].item())
        need = int(counts_layer[l].item())
        if need == 0 or (cap0 + cap1) == 0:
            continue
        # a module type without blocks gets weight 0 so it receives no quota
        w0 = boost[l, 0] if cap0 > 0 else torch.tensor(0.0, device=device)
        w1 = boost[l, 1] if cap1 > 0 else torch.tensor(0.0, device=device)
        w = torch.stack([w0, w1])
        caps = torch.tensor([cap0, cap1], device=device, dtype=torch.long)
        counts_LM[l] = distribute_with_caps(need, w, caps)

    # If we are short globally (caps inside layers), try a global refill pass on remaining capacity
    used_LM = counts_LM.clone()
    short = T - int(used_LM.sum().item())
    if short > 0:
        # global weights per (l,m): w_layer[l] * boost[l,m]
        wLM = (w_layer[:, None] * boost).reshape(-1)
        capsLM = (avail_LM - used_LM).clamp_min(0).reshape(-1)
        extraLM = distribute_with_caps(short, wLM, capsLM).reshape(L, 2)
        counts_LM += extraLM

    # 4) Select least-important blocks inside each (layer, module)
    selected_rows = []
    for l in range(L):
        for mt in (0, 1):
            need = int(counts_LM[l, mt].item())
            if need <= 0:
                continue
            mask = (layer_idx == l) & (mod_idx == mt)
            if not mask.any():
                continue
            imp = blocks[mask, 0]
            k = min(need, imp.numel())
            idx_local = torch.topk(imp, k, largest=False).indices  # <-- least-important
            chosen = blocks[mask][idx_local]
            selected_rows.append(chosen)

    if not selected_rows:
        return torch.empty(0, 4, device=device, dtype=torch.float32)

    selected = torch.cat(selected_rows, dim=0)

    # Debug
    if logger:
        logger.debug(f"[OptionA] selected {selected.size(0)} / {T} candidates "
                     f"(imp[min,max]=[{selected[:,0].min().item():.3e},{selected[:,0].max().item():.3e}])")
        # per-layer count dump
        for l in range(L):
            c = int((selected[:, 1].long() == l).sum().item())
            logger.debug(f"[OptionA] Layer {l:02d} candidates: {c}")

    return selected


















from torch import Tensor
# ---------- helpers ----------

@torch.inference_mode()
def _rank01_asc(x: Tensor) -> Tensor:
    """Ascending rank of each entry scaled into (0,1]: smallest -> 1/L, largest -> 1."""
    count = x.numel()
    ascending = torch.argsort(x, descending=False)
    one_based = torch.arange(1, count + 1, device=x.device, dtype=torch.float32)

    # The j-th smallest element receives rank j (1-based).
    ranks = torch.empty_like(x, dtype=torch.float32)
    ranks.index_put_((ascending,), one_based)
    return ranks / count


@torch.inference_mode()
def _layer_weights_from_sens(sens: Tensor, p: float = 1.5, eps: float = 1e-12) -> Tensor:
    """
    Scale-free layer weights derived from sensitivity ranks.

    Sensitivities are ranked ascending into (0,1] (least sensitive layer has
    the smallest rank); each rank is inverted and raised to `p`, and the
    result is normalized so that the weights sum to 1. Less sensitive layers
    therefore receive larger weights.
    """
    rank01 = _rank01_asc(sens.float())
    inverted = torch.pow(1.0 / (rank01 + eps), p)
    total = inverted.sum() + eps
    return inverted / total


@torch.inference_mode()
def _module_flatness_from_importances(
    importances: Tensor,  # [N,4]: [imp, layer, local, module]
    L: int,
    n_module_types: int = 2,
    eps: float = 1e-12,
) -> Tensor:
    """
    Per-(layer,module) flatness f_{ℓ,t} ∈ [0,1].
      - Compute CV = std / (|mean| + eps) for each (ℓ,t) that owns blocks.
      - Within each layer ℓ, inverse-rank CV across the modules that exist:
            lowest CV -> flatness 1, highest CV -> flatness 0
        (ties are broken by module index).
      - If a layer has only one existing module, its flatness is 0 (neutral).
      - If a (ℓ,t) has 0 blocks, f=0.

    Args:
      importances: [N,4] rows [importance, layer_idx, local_idx, module_type].
      L: number of layers.
      n_module_types: number of module types (columns of the result).
      eps: guards the CV denominator against a zero mean.

    Returns:
      float32 tensor of shape [L, n_module_types] on `importances.device`.
    """
    device = importances.device
    f = torch.zeros(L, n_module_types, device=device, dtype=torch.float32)
    if importances.numel() == 0:
        return f

    lay = importances[:, 1].long()
    mod = importances[:, 3].long()
    imp = importances[:, 0].float()

    # First pass: CV per bucket, with existence tracked explicitly. Existence
    # must come from block presence, not from the CV value: a CV of exactly 0
    # is a legitimate (perfectly flat) bucket, not a missing one.
    cv = torch.zeros_like(f)
    exists = torch.zeros(L, n_module_types, device=device, dtype=torch.bool)
    for ell in range(L):
        for t in range(n_module_types):
            mask = (lay == ell) & (mod == t)
            if not mask.any():
                continue
            x = imp[mask]
            cv[ell, t] = x.std(unbiased=False) / (x.mean().abs() + eps)
            exists[ell, t] = True

    # Second pass: per-layer inverse rank of CV among existing modules.
    for ell in range(L):
        present = exists[ell].nonzero(as_tuple=False).squeeze(1)
        n = present.numel()
        if n < 2:
            # zero or one module present -> neutral (stays 0)
            continue
        # Highest CV first: it gets rank 0 (flatness 0); lowest CV gets n-1.
        order = torch.argsort(cv[ell, present], descending=True, stable=True)
        ranks = torch.empty(n, device=device, dtype=torch.float32)
        ranks[order] = torch.arange(n, device=device, dtype=torch.float32)
        f[ell, present] = ranks / (n - 1)

    return f

@torch.inference_mode()
def _safe_layer_allocation(weights: torch.Tensor,
                           avail_per_bucket: torch.Tensor,
                           total_candidates: int,
                           logger,
                           tag: str) -> torch.Tensor:
    """
    Integer allocation of `total_candidates` over buckets, proportional to
    `weights` and never exceeding a bucket's availability.

    weights: (B,) floats >= 0
    avail_per_bucket: (B,) ints >= 0
    logger: object exposing `.debug(str)`, or a falsy value to stay silent
    tag: prefix used in the debug message
    returns k: (B,) ints; sums to `total_candidates` unless the total
    availability is too small, in which case every bucket is filled to its cap
    """
    tiny = 1e-12

    def _cycle_fill(alloc, visit_order, wanted, capacity_left):
        """Walk `visit_order` cyclically, adding one unit to each bucket that still has room."""
        n = len(visit_order)
        step = 0
        while wanted > 0 and capacity_left > 0 and n > 0:
            b = visit_order[step % n]
            if alloc[b] < avail_per_bucket[b]:
                alloc[b] += 1
                wanted -= 1
                capacity_left -= 1
            step += 1

    # Normalized shares; the small offset keeps the division well-defined.
    share = weights.float() + tiny
    share = share / (share.sum() + tiny)

    ideal = share * float(total_candidates)
    alloc = torch.minimum(torch.floor(ideal).to(torch.long), avail_per_bucket)

    # ---- Largest remainders first, restricted to buckets with room ----
    missing = int(total_candidates - int(alloc.sum().item()))
    if missing > 0:
        room = (avail_per_bucket - alloc).clamp_min(0)
        remainders = ideal - alloc.to(ideal.dtype)
        remainders = torch.where(room > 0, remainders, torch.full_like(remainders, -1.0))
        visit = torch.argsort(remainders, descending=True).tolist()
        _cycle_fill(alloc, visit, missing, int(room.sum().item()))

    # ---- Still short: fill by largest spare capacity ----
    placed = int(alloc.sum().item())
    if placed < total_candidates:
        shortfall = total_candidates - placed
        room = (avail_per_bucket - alloc).clamp_min(0)
        if logger:
            logger.debug(f"[{tag}] Under by {shortfall}; emergency fill.")
        visit = torch.argsort(room, descending=True).tolist()
        _cycle_fill(alloc, visit, shortfall, int(room.sum().item()))

    # ---- Over target (not expected): trim smallest fractional parts first ----
    placed = int(alloc.sum().item())
    if placed > total_candidates:
        surplus = placed - total_candidates
        remainders = ideal - torch.floor(ideal)
        for b in torch.argsort(remainders).tolist():
            if surplus <= 0:
                break
            if alloc[b] > 0:
                alloc[b] -= 1
                surplus -= 1

    return alloc

# ---------- Option B ----------

@torch.inference_mode()
def select_candidates_option_b(
    sensitivities: Tensor,        # [L], higher = more sensitive
    block_importance: Tensor,     # [N,4]: [importance, layer_idx, local_idx, module_type]
    total_candidates: int,
    device: str = "cuda",
    p: float = 1.5,               # sensitivity aggressiveness
    beta: float = 0.5,            # flatness mixing (0..1)
    module_types: int = 2,
) -> Tensor:
    """
    Option B (two knobs, `p` and `beta`).

    Every (layer ℓ, module t) bucket gets the weight
        w_{ℓ,t} ∝  (rank-inv(sens_ℓ))^p  ×  ((1-β) + β f_{ℓ,t})  ×  n_{ℓ,t}
    where f is the per-bucket flatness and n the number of blocks in the
    bucket. Integer quotas are then allocated over all buckets at once
    (availability caps + largest remainders), and inside each bucket the
    **least-important** blocks are taken.

    Returns: [M,4] = [importance, layer_idx, local_idx, module_type] on
    `device`, with at most `total_candidates` rows.
    Raises AssertionError when `total_candidates` is not positive.
    """
    logger = SharedLogger.get_logger("SELECT_CANDIDATES_OPTION_B")  # assume your logger
    sens = sensitivities.to(device).float()
    blocks = block_importance.to(device).float()

    n_layers = sens.numel()
    assert total_candidates > 0, "total_candidates must be > 0"

    # Layer-level weights (sum to 1) and bucket-level flatness in [0,1]
    layer_share = _layer_weights_from_sens(sens, p=p)
    flatness = _module_flatness_from_importances(blocks, n_layers, n_module_types=module_types)

    # Number of blocks in every (layer, module) bucket
    layer_ids = blocks[:, 1].long()
    module_ids = blocks[:, 3].long()
    bucket_size = torch.zeros(n_layers, module_types, device=device, dtype=torch.long)
    bucket_size.index_put_(
        (layer_ids, module_ids),
        torch.ones_like(layer_ids, dtype=torch.long),
        accumulate=True,
    )

    # Bucket weights: layer share × flatness mix × bucket size (empty buckets -> 0)
    blend = (1.0 - beta) + beta * flatness
    per_bucket = layer_share.view(n_layers, 1) * blend
    bucket_weight = per_bucket * bucket_size.to(per_bucket.dtype)

    # One global integer allocation across all buckets
    n_buckets = n_layers * module_types
    quota_flat = _safe_layer_allocation(
        bucket_weight.reshape(n_buckets),
        bucket_size.reshape(n_buckets),
        total_candidates,
        logger,
        "OptionB",
    )
    quota = quota_flat.view(n_layers, module_types)

    # Take the smallest-importance rows of each bucket
    picks = []
    for layer in range(n_layers):
        for mtype in range(module_types):
            want = int(quota[layer, mtype].item())
            if want <= 0:
                continue
            in_bucket = (layer_ids == layer) & (module_ids == mtype)
            if not in_bucket.any():
                continue
            rows = blocks[in_bucket]  # [m,4]
            n_take = min(want, rows.shape[0])
            lowest = torch.topk(rows[:, 0], k=n_take, largest=False).indices
            picks.append(rows.index_select(0, lowest))

    if not picks:
        if logger:
            logger.debug("[OptionB] Allocation produced no candidates.")
        return blocks.new_empty((0, 4))

    out = torch.cat(picks, dim=0)

    # Safety net: never return more rows than requested
    if out.shape[0] > total_candidates:
        keep = torch.topk(out[:, 0], k=total_candidates, largest=False).indices
        out = out.index_select(0, keep)

    # Debug summary per layer
    if logger:
        picked_layers = out[:, 1].long()
        counts = [int((picked_layers == layer).sum().item()) for layer in range(n_layers)]
        logger.debug(f"[OptionB] Alloc summary per layer: {counts} (sum={sum(counts)})")

    return out