# src/models/snip_utils.py
from __future__ import annotations

from copy import deepcopy
from typing import Dict, List, Sequence, Tuple, Union, Optional

import torch
import torch.nn as nn
from ptflops import get_model_complexity_info
import numpy as np
IdxLike = Union[torch.Tensor, Sequence[int], List[int]]

# from __future__ import annotations
# from typing import Union, Optional, Tuple
# import torch
# import torch.nn as nn

import torch
import torch.nn as nn
from typing import Union

@torch.no_grad()
def recalibrate_bn(
    model: nn.Module,
    onebatch_or_loader,
    device: Union[str, torch.device] = "cpu",
    num_batches: int = 200,
    max_per_batch: int = 256,
    reset_running_stats: bool = False,
    log_prefix: str = "",
) -> nn.Module:
    """
    Light BN calibration on *trunk* BNs only.

    Runs up to ``num_batches`` forward passes in train mode so that every trunk
    BatchNorm refreshes its running statistics.

    - Skips BNs whose qualified name starts with ``"exit_heads."``.
    - Forces ``track_running_stats=True`` on calibrated BNs (left enabled afterwards)
      and ensures their running-stat buffers exist on ``device``; can optionally reset them.
    - Restores each BN's original momentum and the model's train/eval mode, even if
      the data iterator or a forward pass raises.

    Args:
        model: network to calibrate (moved to ``device`` in place).
        onebatch_or_loader: a single ``(x, y)`` tuple, or an iterable of ``(x, y)``
            batches; labels are ignored.
        device: device for the model and the inputs.
        num_batches: maximum number of non-empty batches to forward.
        max_per_batch: truncate each batch to this many samples (falsy disables).
        reset_running_stats: zero means, set variances to 1 and reset
            ``num_batches_tracked`` before calibrating.
        log_prefix: when non-empty, print a one-line summary prefixed with it.

    Returns:
        The same ``model`` object.

    NOTE: with ``momentum=None`` PyTorch uses a cumulative average whose update factor is
    ``1 / num_batches_tracked``. If ``reset_running_stats`` is False and the counter is
    already large (e.g. after training), each calibration batch moves the stats only slightly.
    """
    was_training = model.training
    model.to(device)

    touched = []  # (bn_module, original_momentum) for every BN we modified
    used_batches = 0
    used_samples = 0
    try:
        model.train()

        for name, m in model.named_modules():
            if not isinstance(m, nn.modules.batchnorm._BatchNorm):
                continue
            if name.startswith("exit_heads."):
                # Head BNs are per-batch; do not try to calibrate them.
                continue

            # Record before mutating so the finally-block can always restore it.
            touched.append((m, getattr(m, "momentum", None)))

            nf = int(m.num_features)
            m.track_running_stats = True

            # Buffers may be None when the BN was built with track_running_stats=False.
            if getattr(m, "running_mean", None) is None:
                if hasattr(m, "running_mean"):
                    delattr(m, "running_mean")
                m.register_buffer("running_mean", torch.zeros(nf, device=device))
            else:
                m.running_mean = m.running_mean.to(device)

            if getattr(m, "running_var", None) is None:
                if hasattr(m, "running_var"):
                    delattr(m, "running_var")
                m.register_buffer("running_var", torch.ones(nf, device=device))
            else:
                m.running_var = m.running_var.to(device)

            if hasattr(m, "num_batches_tracked"):
                if m.num_batches_tracked is None:
                    delattr(m, "num_batches_tracked")
                    m.register_buffer(
                        "num_batches_tracked",
                        torch.tensor(0, dtype=torch.long, device=device),
                    )
                else:
                    m.num_batches_tracked = m.num_batches_tracked.to(device, dtype=torch.long)

            if reset_running_stats:
                m.running_mean.zero_()
                m.running_var.fill_(1.0)
                if getattr(m, "num_batches_tracked", None) is not None:
                    m.num_batches_tracked.zero_()

            # Use cumulative moving average for fast convergence during calibration
            m.momentum = None

        # A bare (x, y) tuple is treated as a single batch.
        if isinstance(onebatch_or_loader, tuple):
            data_iter = iter([onebatch_or_loader])
        else:
            data_iter = iter(onebatch_or_loader)

        while used_batches < int(num_batches):
            try:
                x, _ = next(data_iter)
            except StopIteration:
                break
            if x.size(0) == 0:
                continue
            if max_per_batch and x.size(0) > max_per_batch:
                x = x[:max_per_batch]
            x = x.to(device, non_blocking=True)
            model(x)  # updates trunk BN running stats
            used_batches += 1
            used_samples += int(x.size(0))
    finally:
        # Restore original momenta and the caller's train/eval mode no matter what.
        for m, orig in touched:
            m.momentum = orig
        model.train(was_training)

    if log_prefix:
        print(f"{log_prefix}BN-updated(trunk-only): batches={used_batches}, samples={used_samples}, reset={reset_running_stats}")

    return model


# --- SNIP CIFAR helpers + channel surgery (drop-ins) ---

# import torch
# import torch.nn as nn
from typing import Iterable

# from typing import Dict, List, Union, Optional, Tuple
# import torch
# import torch.nn as nn

# ---------------------------------------------------------------------------
# 1) Scale sizes → per-layer k (targets)
# ---------------------------------------------------------------------------

def build_scale_targets_for_resnet(
    model: nn.Module,
    scale_cfg: Union[float, Dict[str, float]],
) -> Dict[str, int]:
    """
    Compute per-layer OUT-channel target sizes (k) for a ResNet trunk (with optional early exits),
    based on a Scale policy:

      • If `scale_cfg` is a float r in (0,1], every stage width (and stem) scales to round(r * original).
      • If `scale_cfg` is a dict, you can specify finer control:
            {
                "stem": 0.75,                  # conv1
                "stage0": 0.75,
                "stage1": 0.50,
                "stage2": 0.50,
                "stage3": 0.50,
            }
        Keys are matched case-insensitively; any missing key falls back to 1.0.

    Residual constraints (enforced in the *targets*):
      • Within a stage, the stage width is constant.
      • For a projection block (has downsample): conv2 OUT == downsample.0 OUT == stage_width.
      • For an identity block: conv2 OUT == stage_width.
      • For block.conv1 OUT: also == stage_width.
      • If a stage's FIRST block has no downsample and the incoming residual stream (stem or
        previous stage) has the same unpruned width, the identity shortcut forbids a different
        width: the stage inherits the incoming target width, overriding its own ratio.

    Returns:
      targets: Dict[str, int] mapping module qualified names to desired OUT channels.
               Keys include:
                 - "conv1" (stem)
                 - "layers.{s}.{b}.conv1"
                 - "layers.{s}.{b}.conv2"
                 - "layers.{s}.{b}.downsample.0"  (if present)
    """
    # normalize scale policy
    if isinstance(scale_cfg, (float, int)):
        def _r_for(stage_name: Optional[str]) -> float:
            return float(scale_cfg)
    else:
        scale_dict = {str(k).lower(): float(v) for k, v in dict(scale_cfg).items()}

        def _r_for(stage_name: Optional[str]) -> float:
            if stage_name is None:
                return float(scale_dict.get("stem", 1.0))
            return float(scale_dict.get(stage_name.lower(), 1.0))

    def _scaled(orig: int, ratio: float) -> int:
        """round(ratio * orig) clamped to [1, orig]."""
        return max(1, min(orig, int(round(ratio * orig))))

    targets: Dict[str, int] = {}

    # Unpruned / target width of the residual stream feeding the next stage (None = unknown).
    incoming_orig: Optional[int] = None
    incoming_target: Optional[int] = None

    # --- Stem ---
    if hasattr(model, "conv1") and isinstance(model.conv1, nn.Conv2d):
        stem_out = int(model.conv1.out_channels)
        k_stem = _scaled(stem_out, _r_for(None))
        targets["conv1"] = k_stem
        incoming_orig, incoming_target = stem_out, k_stem

    # --- Stages ---
    if not hasattr(model, "layers"):
        return targets

    for s_idx, stage in enumerate(model.layers):
        if len(stage) == 0:
            continue
        first_block = stage[0]
        # original stage width = first block's conv2.out_channels
        stage_out_orig = int(first_block.conv2.out_channels)
        stage_width = _scaled(stage_out_orig, _r_for(f"stage{s_idx}"))

        first_is_identity = getattr(first_block, "downsample", None) is None
        if first_is_identity and incoming_target is not None and incoming_orig == stage_out_orig:
            # conv2's output is added to the incoming stream; widths must agree.
            stage_width = incoming_target

        for b_idx, block in enumerate(stage):
            prefix = f"layers.{s_idx}.{b_idx}"
            targets[f"{prefix}.conv1"] = stage_width
            targets[f"{prefix}.conv2"] = stage_width

            # projection block: tie downsample.0 to same stage width
            ds = getattr(block, "downsample", None)
            if isinstance(ds, nn.Sequential) and len(ds) >= 1 and isinstance(ds[0], nn.Conv2d):
                targets[f"{prefix}.downsample.0"] = stage_width

        incoming_orig, incoming_target = stage_out_orig, stage_width

    return targets


# ---------------------------------------------------------------------------
# 2) SNIP scores + targets → index masks
# ---------------------------------------------------------------------------

def masks_from_snip_and_targets(
    scores: Dict[str, torch.Tensor],
    targets: Dict[str, int],
    model: nn.Module,
) -> Dict[str, List[int]]:
    """
    Convert SNIP per-layer scores + Scale per-layer target sizes (k) into concrete keep-index masks.

    Behavior:
      • For every layer present in `targets`, take top-k indices from `scores[name]`.
      • If a layer has no score tensor, keep the lowest indices [0..k-1] as a neutral fallback.
      • For projection blocks, force conv2 and downsample.0 to share the SAME index set
        (chosen from conv2 scores); downsample may have no score — we still copy conv2's indices.
      • For identity blocks (no downsample) whose conv2 width equals the unpruned width of the
        incoming residual stream, conv2 COPIES the stream's indices (stem mask or previous
        block's conv2 mask). The residual addition is positional, so this is the only way
        channel i of the shortcut meets channel i of conv2. The conv2 target k is ignored then.
      • For layers whose k == original out_channels, we still return explicit indices (helps the
        downstream surgery ensure downsample/conv2 coupling).

    Returns:
      masks: Dict[str, List[int]] with sorted, unique out-channel indices to keep per layer.
    """
    def _topk_or_range(sc: Optional[torch.Tensor], k: int, outc: int) -> List[int]:
        k = max(1, min(k, outc))
        if sc is None or sc.numel() == 0:
            return list(range(k))
        sc = sc.float().detach().cpu().view(-1)
        k = min(k, sc.numel())
        idx = torch.topk(sc, k=k, largest=True, sorted=True).indices.tolist()
        return [int(i) for i in idx]

    masks: Dict[str, List[int]] = {}

    # Kept indices of the residual stream entering the current block, and its unpruned width.
    stream: Optional[List[int]] = None
    stream_orig: Optional[int] = None

    # Stem
    if "conv1" in targets and hasattr(model, "conv1") and isinstance(model.conv1, nn.Conv2d):
        outc = int(model.conv1.out_channels)
        masks["conv1"] = _topk_or_range(scores.get("conv1"), int(targets["conv1"]), outc)
        stream, stream_orig = masks["conv1"], outc

    # Stages
    if hasattr(model, "layers"):
        for s_idx, stage in enumerate(model.layers):
            for b_idx, block in enumerate(stage):
                prefix = f"layers.{s_idx}.{b_idx}"
                c1_name = f"{prefix}.conv1"
                c2_name = f"{prefix}.conv2"
                has_ds = getattr(block, "downsample", None) is not None
                ds_name = f"{prefix}.downsample.0" if has_ds else None

                # conv1 (internal to the block: free choice)
                if c1_name in targets:
                    masks[c1_name] = _topk_or_range(
                        scores.get(c1_name), int(targets[c1_name]), int(block.conv1.out_channels)
                    )

                if c2_name not in targets:
                    # conv2 left unpruned: the stream's index set is no longer tracked.
                    stream, stream_orig = None, None
                    continue

                outc2 = int(block.conv2.out_channels)
                if not has_ds and stream is not None and stream_orig == outc2:
                    # identity shortcut: align conv2 with the incoming stream
                    keep_c2 = list(stream)
                else:
                    keep_c2 = _topk_or_range(scores.get(c2_name), int(targets[c2_name]), outc2)
                masks[c2_name] = keep_c2

                # projection: force downsample.0 to use the same indices
                # (even if ds scores exist, copying guarantees the addition lines up)
                if ds_name is not None and ds_name in targets:
                    masks[ds_name] = list(keep_c2)

                stream, stream_orig = keep_c2, outc2

    # sort & dedup defensively
    return {name: sorted({int(i) for i in idx}) for name, idx in masks.items()}
# ---------- data helper (one class-balanced batch) ----------
@torch.no_grad()
def make_balanced_batch_cifar(
    loader, classes: int = 10, per_class: int = 13, device: Optional[torch.device] = None
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Build a single class-balanced batch for SNIP scoring (≈ classes*per_class samples).

    Iterates ``loader`` (yielding ``(x, y)`` with x batch-first) until every class in
    ``[0, classes)`` has ``per_class`` samples or the loader is exhausted. The collected
    samples are shuffled with ``torch.randperm`` and moved to ``device`` (default: CUDA
    if available, else CPU).

    Labels outside ``[0, classes)`` (e.g. an ignore index of -1) are skipped; otherwise a
    negative label would silently be filed under a wrong class via negative list indexing.

    Raises:
        ValueError: if the loader yields no sample with a label in ``[0, classes)``.
    """
    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    buckets: List[List[Tuple[torch.Tensor, torch.Tensor]]] = [[] for _ in range(classes)]
    need = [per_class] * classes

    for x, y in loader:
        for i in range(x.size(0)):
            c = int(y[i])
            if not 0 <= c < classes:
                continue
            if need[c] > 0:
                buckets[c].append((x[i:i + 1], y[i:i + 1]))
                need[c] -= 1
        if all(n == 0 for n in need):
            break

    # Class-major order, then shuffled below.
    picked = [pair for bucket in buckets for pair in bucket]
    if not picked:
        raise ValueError(
            f"make_balanced_batch_cifar: loader produced no samples with labels in [0, {classes})"
        )

    xs = torch.cat([xi for xi, _ in picked], dim=0)
    ys = torch.cat([yi for _, yi in picked], dim=0)
    perm = torch.randperm(xs.size(0))
    xs, ys = xs[perm], ys[perm]
    return xs.to(device, non_blocking=True), ys.to(device, non_blocking=True)

def topk_indices_per_layer(
    scores: Dict[str, torch.Tensor],
    keep_ratio: float,
    min_keep: int = 1,
    *,
    min_ratio: float = 0.10,
) -> Dict[str, torch.Tensor]:
    """
    Select, per layer, the indices of the highest-scoring channels.

    For a layer with C scores, k = round(clamp(keep_ratio, min_ratio, 1.0) * C), raised to
    at least ``min_keep`` and capped at C. Scores are flattened first. Indices are
    returned in descending score order. An empty score tensor yields an empty int64 tensor.

    Args:
        scores: layer name -> per-channel scores.
        keep_ratio: fraction of channels to keep.
        min_keep: lower bound on k (capped at C).
        min_ratio: lower clamp on ``keep_ratio`` (default 0.10, the previous hard-coded value).
    """
    kr = max(float(min_ratio), min(float(keep_ratio), 1.0))
    kept: Dict[str, torch.Tensor] = {}
    for name, sc in scores.items():
        flat = sc.reshape(-1)
        C = int(flat.numel())
        if C == 0:
            kept[name] = torch.empty(0, dtype=torch.long, device=sc.device)
            continue
        # cap at C so a large min_keep cannot request more channels than exist
        k = min(C, max(min_keep, int(round(kr * C))))
        kept[name] = torch.topk(flat, k=k, largest=True, sorted=True).indices
    return kept

# ---------- channel surgery (no-BN required; BN is optional & handled) ----------
def _to_idx(idx: torch.Tensor) -> torch.Tensor:
    if not torch.is_tensor(idx):
        idx = torch.tensor(idx)
    return idx.detach().cpu().long()

def _sorted(idx: torch.Tensor) -> torch.Tensor:
    """Normalise ``idx`` with ``_to_idx`` (CPU int64) and return it sorted ascending."""
    ordered, _positions = torch.sort(_to_idx(idx))
    return ordered

def take_out_channels_conv(conv: nn.Conv2d, keep_idx: torch.Tensor) -> None:
    """
    In place: keep only the OUT channels of ``conv`` listed in ``keep_idx``.

    Indices are normalised and sorted ascending via ``_sorted``, then moved to the weight's
    device, because ``index_select`` requires the index on the same device as the tensor.
    Weight and (if present) bias are replaced by new Parameters that keep their
    ``requires_grad`` flags. ``conv.out_channels`` is updated.
    """
    keep_idx = _sorted(keep_idx).to(conv.weight.device)
    W = conv.weight.data.index_select(0, keep_idx)
    conv.weight = nn.Parameter(W, requires_grad=conv.weight.requires_grad)
    if conv.bias is not None:
        conv.bias = nn.Parameter(conv.bias.data.index_select(0, keep_idx),
                                 requires_grad=conv.bias.requires_grad)
    conv.out_channels = int(W.size(0))

def take_in_channels_conv(conv: nn.Conv2d, keep_idx: torch.Tensor) -> None:
    """
    In place: keep only the IN channels of ``conv`` listed in ``keep_idx``.

    Indices are normalised and sorted ascending via ``_sorted``, then moved to the weight's
    device (``index_select`` needs index and tensor on the same device). Bias is per OUT
    channel and left untouched. ``conv.in_channels`` is set from weight dim 1, which
    equals the input width only when ``groups == 1``.
    """
    keep_idx = _sorted(keep_idx).to(conv.weight.device)
    W = conv.weight.data.index_select(1, keep_idx)
    conv.weight = nn.Parameter(W, requires_grad=conv.weight.requires_grad)
    conv.in_channels = int(W.size(1))

def take_bn_channels(bn: Optional[nn.Module], keep_idx: torch.Tensor) -> None:
    """
    In place: keep only the channels of a BatchNorm2d listed in ``keep_idx``.

    No-op if ``bn`` is None or not an ``nn.BatchNorm2d``. Affine parameters and running
    statistics are sliced only when present, so ``affine=False`` and
    ``track_running_stats=False`` layers (whose tensors are None) are handled.
    The index is moved to the BN's device for ``index_select``. ``num_features`` is updated.
    """
    if bn is None or not isinstance(bn, nn.BatchNorm2d):
        return
    keep_idx = _sorted(keep_idx)

    # Pick the device from whichever tensor state this BN actually holds.
    ref = bn.weight if bn.weight is not None else bn.running_mean
    if ref is not None:
        keep_idx = keep_idx.to(ref.device)

    if bn.weight is not None:
        bn.weight = nn.Parameter(bn.weight.data.index_select(0, keep_idx),
                                 requires_grad=bn.weight.requires_grad)
    if bn.bias is not None:
        bn.bias = nn.Parameter(bn.bias.data.index_select(0, keep_idx),
                               requires_grad=bn.bias.requires_grad)
    if bn.running_mean is not None:
        bn.running_mean = bn.running_mean.index_select(0, keep_idx).clone()
    if bn.running_var is not None:
        bn.running_var = bn.running_var.index_select(0, keep_idx).clone()
    # Preserve num_batches_tracked so EMA warmup behavior isn't reset
    if getattr(bn, "num_batches_tracked", None) is not None:
        bn.num_batches_tracked = bn.num_batches_tracked.clone()
    bn.num_features = int(keep_idx.numel())

def apply_block_surgery(
    conv1: nn.Conv2d, bn1: Optional[nn.Module],
    conv2: nn.Conv2d, bn2: Optional[nn.Module],
    downsample_conv: Optional[nn.Conv2d],
    keep_idx1: torch.Tensor,   # keep after conv1
    keep_idx2: torch.Tensor,   # keep after conv2
    prev_keep_idx: Optional[torch.Tensor],
    downsample_bn: Optional[nn.Module] = None,
) -> torch.Tensor:
    """
    Prune one residual block in place and return its output index set.

    - conv1: IN <- ``prev_keep_idx`` (if given), OUT <- ``keep_idx1``; ``bn1`` follows ``keep_idx1``.
    - conv2: IN <- ``keep_idx1``, OUT <- ``keep_idx2``; ``bn2`` follows ``keep_idx2``.
    - downsample_conv (projection shortcut, optional): IN <- ``prev_keep_idx`` (if given),
      OUT <- ``keep_idx2``. ``downsample_bn`` (optional, presumably the BN at
      ``downsample[1]`` — verify against the model) also follows ``keep_idx2``.

    Returns:
        ``keep_idx2`` sorted ascending (CPU int64), suitable as the next block's ``prev_keep_idx``.

    NOTE(review): for an identity shortcut the residual add is positional, so ``keep_idx2``
    presumably has to equal ``prev_keep_idx``; this function does not enforce that.
    """
    keep_idx1 = _sorted(keep_idx1)
    keep_idx2 = _sorted(keep_idx2)

    if prev_keep_idx is not None:
        take_in_channels_conv(conv1, prev_keep_idx)
    take_out_channels_conv(conv1, keep_idx1)
    take_bn_channels(bn1, keep_idx1)

    take_in_channels_conv(conv2, keep_idx1)
    take_out_channels_conv(conv2, keep_idx2)
    take_bn_channels(bn2, keep_idx2)

    if downsample_conv is not None:
        if prev_keep_idx is not None:
            take_in_channels_conv(downsample_conv, prev_keep_idx)
        take_out_channels_conv(downsample_conv, keep_idx2)
    # take_bn_channels is a no-op for None / non-BatchNorm2d modules.
    take_bn_channels(downsample_bn, keep_idx2)

    return keep_idx2

def assert_block_shapes(
    conv_in: nn.Conv2d, bn_in: Optional[nn.Module],
    conv_out: nn.Conv2d, bn_out: Optional[nn.Module],
):
    """
    Check that each BatchNorm2d matches the width of the conv feeding it.

    ``bn_in`` is compared against ``conv_in.out_channels`` and ``bn_out`` against
    ``conv_out.out_channels``. Norms that are None or not ``nn.BatchNorm2d`` are skipped.
    Uses ``assert``, so the checks vanish under ``python -O``.
    """
    pairs = (
        ("bn_in", bn_in, "conv_in", conv_in),
        ("bn_out", bn_out, "conv_out", conv_out),
    )
    for bn_label, norm, conv_label, conv in pairs:
        if not isinstance(norm, nn.BatchNorm2d):
            continue
        assert norm.num_features == conv.out_channels, \
            f"{bn_label} {norm.num_features} vs {conv_label}.out {conv.out_channels}"

# ---------- simple exit head adapter ----------
def make_exit_head_fn(exit_id: int):
    """
    Build ``head_fn(model, x) -> logits`` for a fixed early exit.

    The returned function calls ``model(x, exit=exit_id)``. If the result is a list or
    tuple, its first element is returned; otherwise the result is returned unchanged.
    Adjust if your EE model's forward signature differs.
    """
    def head_fn(model, x: torch.Tensor):
        result = model(x, exit=exit_id)
        return result[0] if isinstance(result, (list, tuple)) else result

    return head_fn
# --- SNIP channel-saliency scoring (BN running stats frozen during the scoring pass) ---
import time

def _bn_freeze_for_grad(model: nn.Module):
    """
    Freeze BN running stats (eval()) but KEEP affine parameters trainable
    and allow gradient to flow through.
    """
    saved = []
    for m in model.modules():
        if isinstance(m, nn.modules.batchnorm._BatchNorm):
            saved.append((m, m.training, m.track_running_stats))
            m.eval()                   # freeze running stats (no updates)
            m.track_running_stats = True  # keep buffers, but won't update in eval()
    return saved

def _bn_restore(saved):
    for m, was_training, was_track in saved:
        if was_training:
            m.train()
        m.track_running_stats = was_track

def compute_snip_channel_scores(
    model: nn.Module,
    onebatch_or_loader,
    loss_fn: nn.Module,
    device: Union[str, torch.device] = "cpu",
    num_batches: int = 1,        # SNIP: single (or very few) batches
    max_per_batch: int = 256,
    log_prefix: str = "[SNIP] ",
) -> Dict[str, torch.Tensor]:
    """
    SNIP channel-saliency: sum_{in,h,w} | W ⊙ dL/dW | per OUT-channel of every Conv2d.

    During scoring the model is in eval mode: no dropout, and BN uses its running stats
    without updating them. Gradients still flow, and every BN has ``track_running_stats=True``.
    Afterwards every module's train/eval flag and every BN's ``track_running_stats``
    are restored to their values at call time, even on error. Parameter gradients are
    cleared (set to None) before and after. No layer-wise normalization is applied.

    Args:
        model: network to score (moved to ``device`` in place).
        onebatch_or_loader: a single ``(x, y)`` tuple or an iterable of ``(x, y)`` batches.
            One class-balanced batch is recommended upstream.
        loss_fn: loss applied to the model output (last element if the model returns
            a list/tuple of exits) and integer labels.
        device: device for the model and data.
        num_batches: number of batches to accumulate gradients over (at least 1).
        max_per_batch: truncate each batch to this many samples.
        log_prefix: when non-empty, print a one-line summary prefixed with it.

    Returns dict: {module_name: Tensor[out_channels]} on CPU, only for convs that got a gradient.
    """
    model.to(device)
    model.zero_grad(set_to_none=True)

    # Snapshot BEFORE switching modes so the caller's exact state can be restored.
    mode_snapshot = [(m, m.training) for m in model.modules()]
    bn_snapshot = [
        (m, m.track_running_stats)
        for m in model.modules()
        if isinstance(m, nn.modules.batchnorm._BatchNorm)
    ]

    scores: Dict[str, torch.Tensor] = {}
    try:
        # Deterministic forward (no dropout, frozen BN stats), gradients still enabled.
        model.train(False)
        for bn, _ in bn_snapshot:
            bn.track_running_stats = True

        # A bare (x, y) tuple is a single batch; must be an *iterator* for next().
        if isinstance(onebatch_or_loader, tuple):
            data_iter = iter([onebatch_or_loader])
        else:
            data_iter = iter(onebatch_or_loader)

        used_batches, used_samples = 0, 0
        running_loss = 0.0

        while used_batches < max(1, num_batches):
            try:
                x, y = next(data_iter)
            except StopIteration:
                break

            if x.size(0) > max_per_batch:
                x, y = x[:max_per_batch], y[:max_per_batch]

            x = x.to(device, non_blocking=True)
            y = y.to(device, non_blocking=True)
            if y.dtype != torch.long:
                y = y.long()

            with torch.enable_grad():
                out = model(x)
                if isinstance(out, (list, tuple)):
                    out = out[-1]  # final exit for loss
                loss = loss_fn(out, y)
                loss.backward()

            running_loss += float(loss.detach().cpu().item())
            used_batches += 1
            used_samples += int(x.size(0))

        # Collect per-layer OUT-channel scores (NO layer-wise normalization)
        for name, layer in model.named_modules():
            if isinstance(layer, nn.Conv2d) and layer.weight.grad is not None:
                W = layer.weight.detach()
                G = layer.weight.grad.detach()
                scores[name] = (W * G).abs().sum(dim=(1, 2, 3)).cpu()  # per-OUT channel

        if log_prefix:
            print(f"{log_prefix}SNIP: batches={used_batches}, samples={used_samples}, mean_loss={running_loss/max(1,used_batches):.4f}")
    finally:
        model.zero_grad(set_to_none=True)
        for m, was_training in mode_snapshot:
            m.training = was_training
        for bn, was_tracking in bn_snapshot:
            bn.track_running_stats = was_tracking

    return scores

def debug_dump_snip(scores: Dict[str, torch.Tensor], k: int = 5, tag: str = ""):
    """
    Print a per-layer summary of SNIP scores, then the total number of scored channels.

    For each non-empty score tensor the line shows the channel count, the indices of the
    ``k`` highest and ``k`` lowest scores, and the min/max score. Empty tensors are skipped.
    """
    header = f"[SNIP{tag}]"
    scored = 0
    for layer_name, layer_scores in scores.items():
        count = layer_scores.numel()
        if count == 0:
            continue
        scored += int(count)
        kk = min(k, count)
        top_ids = torch.topk(layer_scores, kk).indices.tolist()
        low_ids = torch.topk(-layer_scores, kk).indices.tolist()  # bottom-k
        lo = float(layer_scores.min())
        hi = float(layer_scores.max())
        print(f"{header} {layer_name}: ch={count}  top={top_ids}  bottom={low_ids}  "
              f"range=({lo:.3e}, {hi:.3e})")
    print(f"{header} total_channels_scored={scored}")
    
def create_resnet_channel_masks(
    scores: Dict[str, torch.Tensor],
    keep_ratio: float,
    model: nn.Module,
    *,
    stage_consistent_conv1: bool = True,
    apply_flops_weighting: bool = True,
    input_size: Tuple[int, int, int] = (3, 32, 32),   # (C,H,W) for a dummy probe if needed
    per_layer_floor: float = 0.30,                     # floor fraction for any "free" layer
) -> Dict[str, List[int]]:
    """
    Build OUT-channel keep masks for a (CIFAR-style) ResNet trunk with residual safety.

    Global budget ("free" layers included in top-K across channels):
        • stem conv1
        • EVERY block.conv1
        • block.conv2 of projection blocks (and their downsample.0 will COPY this)
      K = round(keep_ratio * total free channels). Every free layer then keeps at least
      round(per_layer_floor * outC) channels (minimum 1).

    Identity blocks:
        • block.conv2 is added positionally to the incoming residual stream, so it COPIES
          the stream's index set (stem mask, or the previous block's conv2 mask).
        • If conv2's width differs from the stream's unpruned width (e.g. a shortcut that is
          not a `downsample` module), it falls back to the top-k of its own scores with
          k = current stage width.

    Extras:
        • If `stage_consistent_conv1=True`, all conv1 in the same stage share the SAME mask,
          sized to the average number kept by the global selection in that stage, and chosen
          from the FIRST block's conv1 ranking.
        • If `apply_flops_weighting=True`, per-channel SNIP saliency is multiplied by a
          rough cost proxy H*W*kH*kW captured via forward hooks on a random probe input
          (model mode is restored afterwards). Falls back to kernel area if the probe fails.
        • A score vector that is missing or whose length differs from the layer's
          out_channels is replaced by a uniform neutral score.

    Returns:
        masks: Dict[layer_name -> sorted unique channel indices]
               Includes: "conv1", "layers.s.b.conv1", "layers.s.b.conv2",
               and "layers.s.b.downsample.0" (when present).
    """
    keep_ratio = float(max(0.0, min(1.0, keep_ratio)))
    masks: Dict[str, List[int]] = {}

    def _layer_score(name: str, out_ch: int) -> torch.Tensor:
        """1-D CPU score vector of length ``out_ch`` (uniform if missing or mis-sized)."""
        s = scores.get(name)
        if s is None or s.numel() != out_ch:
            return torch.ones(out_ch) / max(1, out_ch)  # neutral fallback
        return s.float().detach().cpu().reshape(-1)

    # ---------- capture conv output H,W (optional, cheap) ----------
    conv_hw: Dict[str, Tuple[int, int]] = {}
    if apply_flops_weighting:
        hooks = []
        saved_modes = [(m, m.training) for m in model.modules()]
        try:
            def _hook_factory(fullname: str):
                def _hook(_m, _in, out):
                    # expected out: [N, C, H, W]
                    if isinstance(out, torch.Tensor) and out.dim() == 4:
                        conv_hw[fullname] = (int(out.shape[-2]), int(out.shape[-1]))
                return _hook

            for n, m in model.named_modules():
                if isinstance(m, nn.Conv2d):
                    hooks.append(m.register_forward_hook(_hook_factory(n)))

            model.eval()
            dev = next(model.parameters()).device
            probe = torch.randn(1, *input_size, device=dev)
            with torch.no_grad():
                model(probe)  # cheap single pass
        except Exception:
            # Best effort: without a successful probe, _cost_for uses kernel area only.
            conv_hw = {}
        finally:
            for h in hooks:
                h.remove()
            for m, flag in saved_modes:
                m.training = flag

    def _cost_for(name: str, module: nn.Conv2d) -> float:
        if not apply_flops_weighting:
            return 1.0
        H, W = conv_hw.get(name, (0, 0))
        kh, kw = module.kernel_size if isinstance(module.kernel_size, tuple) else (module.kernel_size, module.kernel_size)
        if H <= 0 or W <= 0:
            return float(kh * kw)  # fallback on kernel area only
        return float(H * W * kh * kw)

    # Built once; used for every name -> module lookup below.
    nm2mod = dict(model.named_modules())

    # ---------- enumerate trunk, collect metadata ----------
    stem_name = "conv1"
    assert hasattr(model, "conv1") and isinstance(model.conv1, nn.Conv2d), "Model must have conv1"
    stem_out = int(model.conv1.out_channels)

    stages = []  # list of dicts: {"s": s_idx, "blocks": [(b_idx, is_proj, names...)]}
    for s_idx, stage in enumerate(getattr(model, "layers")):
        blocks = []
        for b_idx, block in enumerate(stage):
            is_proj = (block.downsample is not None)
            names = dict(
                conv1=f"layers.{s_idx}.{b_idx}.conv1",
                conv2=f"layers.{s_idx}.{b_idx}.conv2",
                down=(f"layers.{s_idx}.{b_idx}.downsample.0" if is_proj else None),
            )
            blocks.append((b_idx, is_proj, names))
        stages.append(dict(s=s_idx, blocks=blocks))

    # ---------- build global-free list with per-layer cost ----------
    def _free_entry(name: str):
        """(name, outC, scores, cost) for a global-budget layer, or None if it has no channels."""
        mod = nm2mod.get(name)
        if isinstance(mod, nn.Conv2d):
            outc, cost = int(mod.out_channels), _cost_for(name, mod)
        else:
            s = scores.get(name)
            outc, cost = (int(s.numel()) if s is not None else 0), 1.0
        if outc <= 0:
            return None
        return (name, outc, _layer_score(name, outc), cost)

    global_free: List[Tuple[str, int, torch.Tensor, float]] = [
        (stem_name, stem_out, _layer_score(stem_name, stem_out), _cost_for(stem_name, model.conv1))
    ]
    # conv1 of EVERY block, then conv2 of PROJECTION blocks (identity conv2 handled later)
    candidates = [names["conv1"] for s in stages for (_b, _p, names) in s["blocks"]]
    candidates += [names["conv2"] for s in stages for (_b, is_proj, names) in s["blocks"] if is_proj]
    for cand in candidates:
        entry = _free_entry(cand)
        if entry is not None:
            global_free.append(entry)

    # ---------- global top-K selection across free channels (FLOPs-weighted) ----------
    total_free = sum(outc for (_n, outc, _sc, _c) in global_free)
    K = max(1, int(round(keep_ratio * max(1, total_free))))

    flat: List[Tuple[float, str, int]] = []
    for name, outc, sc, cost in global_free:
        vals = sc.tolist()
        flat.extend((float(vals[idx]) * float(cost), name, idx) for idx in range(outc))
    flat.sort(reverse=True, key=lambda t: t[0])

    kept_per_layer: Dict[str, List[int]] = {}
    for _score, name, idx in flat[:K]:
        kept_per_layer.setdefault(name, []).append(idx)

    # ---------- per-layer floor to prevent starvation ----------
    FLOOR = float(max(0.0, min(1.0, per_layer_floor)))
    for name, outc, sc, _ in global_free:
        need = max(1, int(round(FLOOR * outc)))
        have = len(kept_per_layer.get(name, []))
        if have < need:
            ranking = torch.topk(sc, k=outc).indices.tolist()
            already = set(kept_per_layer.get(name, []))
            to_add = [i for i in ranking if i not in already][: (need - have)]
            if to_add:
                kept_per_layer.setdefault(name, []).extend(to_add)

    # sort/dedup each
    for k in list(kept_per_layer.keys()):
        kept_per_layer[k] = sorted(set(int(i) for i in kept_per_layer[k]))

    # ---------- apply stage-consistent conv1 (override kept_per_layer for conv1s) ----------
    if stage_consistent_conv1:
        for s in stages:
            conv1_names = [names["conv1"] for (_b, _p, names) in s["blocks"]]
            if not conv1_names:
                continue

            # target k for the stage = rounded mean of kept counts across its conv1s
            outc0 = int(getattr(nm2mod.get(conv1_names[0], None), "out_channels", 0))
            if outc0 <= 0:
                continue
            kept_counts = [len(kept_per_layer.get(n, [])) for n in conv1_names if n in kept_per_layer]
            k_stage = int(round(np.mean(kept_counts))) if kept_counts else outc0  # fallback: keep all

            # choose mask from FIRST block's conv1 ranking
            sc_first = _layer_score(conv1_names[0], outc0)
            k_stage = max(1, min(k_stage, outc0))
            mask_stage = torch.topk(sc_first, k=k_stage, largest=True, sorted=True).indices.tolist()

            for n in conv1_names:
                kept_per_layer[n] = sorted(int(i) for i in mask_stage)

    # ---------- finalize masks; enforce residual constraints ----------
    if kept_per_layer.get(stem_name):
        masks[stem_name] = sorted(int(i) for i in kept_per_layer[stem_name])
    else:
        masks[stem_name] = list(range(stem_out))
    stage_width = len(masks[stem_name])
    # Residual stream entering the next block: kept indices and unpruned width.
    stream_mask: List[int] = list(masks[stem_name])
    stream_orig = stem_out

    for s in stages:
        for _b_idx, is_proj, names in s["blocks"]:
            c1, c2, down = names["conv1"], names["conv2"], names["down"]

            # conv1 mask
            outc1 = int(getattr(nm2mod.get(c1, None), "out_channels", 0)) or scores.get(c1, torch.zeros(0)).numel()
            if outc1 > 0:
                masks[c1] = sorted(int(i) for i in kept_per_layer.get(c1, list(range(outc1))))

            # conv2 handling
            outc2 = int(getattr(nm2mod.get(c2, None), "out_channels", 0)) or scores.get(c2, torch.zeros(0)).numel()
            if outc2 <= 0:
                continue

            if is_proj:
                # projection: conv2 is global-free — use its selected set; downsample.0 COPIES it
                c2_keep = kept_per_layer.get(c2, list(range(outc2)))
                masks[c2] = sorted(int(i) for i in c2_keep)
                if down is not None:
                    masks[down] = list(masks[c2])
                stage_width = len(masks[c2])  # identity width for the rest of the stage
            elif stream_orig == outc2:
                # identity: channel i of conv2 is added to channel i of the shortcut
                masks[c2] = list(stream_mask)
            else:
                # width change without a downsample module: best effort, top-k at stage width
                sc2 = _layer_score(c2, outc2)
                k = max(1, min(stage_width, outc2))
                masks[c2] = sorted(int(i) for i in torch.topk(sc2, k=k, largest=True, sorted=True).indices.tolist())

            stream_mask, stream_orig = masks[c2], outc2

    # final tidy
    for k in list(masks.keys()):
        masks[k] = sorted(set(int(i) for i in masks[k]))

    return masks

def prune_conv_layer(conv: nn.Conv2d, out_idx: List[int], in_idx: List[int]) -> nn.Conv2d:
    """Create a new Conv2d from selected OUT/IN channel indices of ``conv``.

    - Preserves groups/stride/padding/dilation/padding_mode, device and dtype.
    - ``out_idx`` / ``in_idx`` are GLOBAL channel indices (0..out_channels-1 / 0..in_channels-1).
    - groups == 1: rows follow ``out_idx`` order, columns follow ``in_idx`` order.
    - groups > 1: ``len(in_idx)`` and ``len(out_idx)`` must be divisible by ``groups``, and every
      original group must keep the same number of OUT and IN channels. The weight's IN axis
      is only ``in_channels // groups`` wide, so each group's global in-indices are mapped
      to group-local ones. Output rows are ordered group by group.
      (BasicBlock has groups==1; for ResNeXt, ensure masks respect group boundaries.)

    Raises:
        AssertionError: when counts are not divisible by ``groups``.
        ValueError: when a grouped conv's kept channels are unevenly spread across groups.
    """
    dev, dt = conv.weight.device, conv.weight.dtype
    g = int(conv.groups)

    # shape guards for grouped convs
    if g > 1:
        assert len(in_idx) % g == 0, f"in_idx({len(in_idx)}) not divisible by groups({g})"
        assert len(out_idx) % g == 0, f"out_idx({len(out_idx)}) not divisible by groups({g})"

    new_conv = nn.Conv2d(
        in_channels=len(in_idx),
        out_channels=len(out_idx),
        kernel_size=conv.kernel_size,
        stride=conv.stride,
        padding=conv.padding,
        dilation=conv.dilation,
        groups=g,
        bias=(conv.bias is not None),
        padding_mode=getattr(conv, "padding_mode", "zeros"),
        device=dev,
        dtype=dt,
    )

    def _ix(values) -> torch.Tensor:
        return torch.as_tensor([int(v) for v in values], dtype=torch.long, device=dev)

    with torch.no_grad():
        # W: [outC, inC/groups, kh, kw]
        if g == 1:
            out_order = [int(o) for o in out_idx]
            W = conv.weight.index_select(0, _ix(out_order)).index_select(1, _ix(in_idx))
        else:
            out_pg = conv.out_channels // g
            in_pg = conv.in_channels // g
            out_order, parts = [], []
            for j in range(g):
                o_j = [int(o) for o in out_idx if int(o) // out_pg == j]
                i_j = [int(i) - j * in_pg for i in in_idx if int(i) // in_pg == j]
                if len(o_j) != len(out_idx) // g or len(i_j) != len(in_idx) // g:
                    raise ValueError(
                        f"group {j}: kept {len(o_j)} out / {len(i_j)} in channels; each group must keep "
                        f"{len(out_idx) // g} out / {len(in_idx) // g} in"
                    )
                out_order.extend(o_j)
                parts.append(conv.weight.index_select(0, _ix(o_j)).index_select(1, _ix(i_j)))
            W = torch.cat(parts, dim=0)
        new_conv.weight.copy_(W.to(dt))
        if conv.bias is not None:
            # bias follows the same row order as the weight
            new_conv.bias.copy_(conv.bias.index_select(0, _ix(out_order)).to(dt))
    return new_conv

def prune_conv_in_channels(conv: nn.Conv2d, in_indices: List[int]) -> nn.Conv2d:
    """Keep all output channels of ``conv`` and prune only its IN channels.

    Args:
        conv: source convolution; it is not modified.
        in_indices: global input-channel indices to keep.

    Returns:
        A new ``nn.Conv2d`` with ``len(in_indices)`` inputs, the same outputs,
        and the same device/dtype/groups/geometry as ``conv``.
    """
    # Delegate to the general OUT/IN pruner so both paths share one
    # implementation (device/dtype handling, grouped-conv index handling).
    return prune_conv_layer(conv, list(range(conv.out_channels)), in_indices)

def prune_bn_layer(bn: nn.BatchNorm2d, keep_indices: List[int]) -> nn.BatchNorm2d:
    """Return a new BatchNorm2d restricted to the channels in ``keep_indices``.

    eps, momentum, affine and track_running_stats are copied from ``bn``.
    Affine parameters, running statistics and ``num_batches_tracked`` are
    copied for the kept channels when they exist on ``bn``. Works for
    ``affine=False`` and/or ``track_running_stats=False`` layers too.
    """
    # affine=False layers have no weight, so take device/dtype from whichever
    # tensor exists; a stat-less, affine-less BN falls back to defaults.
    ref = bn.weight if bn.weight is not None else bn.running_mean
    if ref is not None:
        dev, dt = ref.device, ref.dtype
    else:
        dev, dt = torch.device("cpu"), torch.get_default_dtype()
    idx = torch.as_tensor(keep_indices, device=dev, dtype=torch.long)

    new_bn = nn.BatchNorm2d(
        num_features=len(keep_indices),
        eps=bn.eps,
        momentum=bn.momentum,
        affine=bn.affine,
        track_running_stats=bn.track_running_stats,
        device=dev,
        dtype=dt,
    )
    with torch.no_grad():
        if bn.affine and bn.weight is not None:
            new_bn.weight.copy_(bn.weight.index_select(0, idx))
            new_bn.bias.copy_(bn.bias.index_select(0, idx))
        if bn.track_running_stats and bn.running_mean is not None and bn.running_var is not None:
            new_bn.running_mean.copy_(bn.running_mean.index_select(0, idx))
            new_bn.running_var.copy_(bn.running_var.index_select(0, idx))
            # the buffer may be registered but set to None, so check the value
            src_nbt = getattr(bn, "num_batches_tracked", None)
            dst_nbt = getattr(new_bn, "num_batches_tracked", None)
            if src_nbt is not None and dst_nbt is not None:
                dst_nbt.copy_(src_nbt)
    return new_bn

def _resize_fc_in(fc: nn.Linear, in_features: int) -> nn.Linear:
    """Rebuild a Linear to new in_features, copy overlapping weights."""
    dev, dt = fc.weight.device, fc.weight.dtype
    new_fc = nn.Linear(in_features, fc.out_features, bias=(fc.bias is not None)).to(device=dev, dtype=dt)
    with torch.no_grad():
        k = min(fc.in_features, new_fc.in_features)
        if k > 0:
            new_fc.weight[:, :k].copy_(fc.weight[:, :k].to(dt).to(dev))
        if fc.bias is not None:
            new_fc.bias.copy_(fc.bias.to(dt).to(dev))
    return new_fc


# ---------------- Exit-head patching utilities ----------------

def _first_conv_in_sequential(seq: nn.Sequential) -> Optional[int]:
    for idx, m in enumerate(seq):
        if isinstance(m, nn.Conv2d):
            return idx
    return None

@torch.no_grad()
def _capture_exit_inputs(model: nn.Module, device: torch.device, *, x_sample: Optional[torch.Tensor] = None) -> dict:
    """
    Capture channel count of the tensor fed into each exit head by swapping
    the head with a recorder that returns the input unchanged.

    IMPORTANT: Runs a forward on a REAL sample batch (x_sample) from validation/train.
    """
    model = model.to(device)
    model.eval()

    if not hasattr(model, "exit_heads"):
        return {}

    class _InputRecorder(nn.Module):
        def __init__(self):
            super().__init__()
            self.cin: Optional[int] = None
        def forward(self, x: torch.Tensor):
            if isinstance(x, torch.Tensor):
                self.cin = int(x.shape[1])
            return x  # pass-through

    original_heads: List[nn.Module] = []
    recorders: List[_InputRecorder] = []
    for i, head in enumerate(model.exit_heads):
        original_heads.append(head)
        rec = _InputRecorder().to(device)
        recorders.append(rec)
        model.exit_heads[i] = rec  # type: ignore[index]

    try:
        if x_sample is None:
            # last-resort fallback; prefer to always pass a real batch
            x_sample = torch.randn(1, 3, 32, 32, device=device)
        _ = model(x_sample.to(device, non_blocking=True))
        cin_by_idx = {i: r.cin for i, r in enumerate(recorders) if r.cin is not None}
    finally:
        for i, head in enumerate(original_heads):
            model.exit_heads[i] = head  # type: ignore[index]

    model.train()
    return cin_by_idx


@torch.no_grad()
def _capture_fc_in_features(model: nn.Module, device: torch.device, *, x_sample: Optional[torch.Tensor] = None) -> Dict[int, int]:
    """
    Capture in_features for each head.fc via a forward pre-hook on a REAL sample batch.
    """
    model.eval()
    fin_by_idx: Dict[int, int] = {}

    def make_hook(i: int):
        def _hook(_mod, inputs):
            x = inputs[0]
            if isinstance(x, torch.Tensor):
                fin_by_idx[i] = int(x.shape[-1])
        return _hook

    handles = []
    for i, head in enumerate(model.exit_heads):
        if hasattr(head, "fc") and isinstance(head.fc, nn.Linear):
            handles.append(head.fc.register_forward_pre_hook(make_hook(i)))

    if x_sample is None:
        x_sample = torch.randn(1, 3, 32, 32, device=device)  # fallback; avoid in real runs
    _ = model(x_sample.to(device, non_blocking=True))

    for h in handles:
        h.remove()
    model.train()
    return fin_by_idx


def patch_exit_heads_to_cin_and_fc(
    model: nn.Module,
    device: torch.device,
    *,
    x_sample: Optional[torch.Tensor] = None,
    sample_loader: Optional[Iterable] = None,
    max_per_batch: int = 256,
) -> None:
    """
    Make every exit head accept the features the (pruned) trunk now produces.

    Two in-place passes over ``model.exit_heads``:
      1. Heads with a ``features`` nn.Sequential: the first Conv2d's input
         channels are shrunk to the channel count actually reaching the head.
      2. Heads with an ``fc`` nn.Linear: the layer is rebuilt for the feature
         size actually reaching it.
    Both sizes are measured with forward passes on a REAL mini-batch. If both
    ``x_sample`` and ``sample_loader`` are given, ``x_sample`` wins. If neither
    yields a tensor, we (reluctantly) fall back to a random (1, 3, 32, 32) batch.

    NOTE(review): pass 1 keeps the FIRST ``target_cin`` input channels of the
    head conv. This function does not know which trunk channels were kept, so
    the retained weights may not line up with the surviving trunk channels.

    Args:
      model: network whose ``exit_heads`` are patched in place; its train/eval
        mode is restored on return.
      device: device used for the probe forwards.
      x_sample: a single batch tensor (N,C,H,W) from your validation set.
      sample_loader: iterable to draw one batch from if x_sample is None; a
        batch may be a tensor or a tuple/list whose first item is the input.
      max_per_batch: keep at most this many samples to keep the pass cheap.
    """
    was_training = model.training

    # ---- materialize a real batch if not directly provided
    if x_sample is None and sample_loader is not None:
        try:
            batch = next(iter(sample_loader))
        except StopIteration:
            batch = None  # empty loader: use the fallback below
        except Exception as e:  # best effort: a broken loader must not abort patching
            print(f"[PATCH] could not draw a sample batch from loader: {e}")
            batch = None
        x = batch[0] if (isinstance(batch, (list, tuple)) and batch) else batch
        if isinstance(x, torch.Tensor):
            x_sample = x[:max_per_batch].to(device, non_blocking=True)
    if x_sample is None:
        # last-resort fallback; prefer to always pass a real batch
        x_sample = torch.randn(1, 3, 32, 32, device=device)

    # ---- pass 1: first conv(in) of each head
    cin_by_head = _capture_exit_inputs(model, device, x_sample=x_sample)
    for i, head in enumerate(getattr(model, "exit_heads", [])):
        features = getattr(head, "features", None)
        if not isinstance(features, nn.Sequential):
            continue
        fcidx = _first_conv_in_sequential(features)
        target_cin = cin_by_head.get(i)
        if fcidx is None or target_cin is None:
            continue

        first_conv: nn.Conv2d = features[fcidx]
        cur_cin = first_conv.in_channels
        if target_cin == cur_cin:
            continue
        if target_cin > cur_cin:
            # Pruning can only drop channels; growing the conv is out of scope.
            print(
                f"[PATCH] exit head {i}: receives {target_cin} channels but its first conv "
                f"expects {cur_cin}; cannot grow by pruning, left unchanged"
            )
            continue
        features[fcidx] = prune_conv_in_channels(first_conv, list(range(target_cin)))

    # ---- probe forward between the passes. A stale head.fc.in_features makes
    # this raise a shape error, which is exactly what pass 2 repairs, so report
    # it instead of aborting.
    model.eval()
    try:
        with torch.no_grad():
            model(x_sample.to(device, non_blocking=True))
    except RuntimeError as e:
        print(f"[PATCH] forward before fc patch failed (stale fc sizes?): {e}")

    # ---- pass 2: fc(in_features) of each head
    fin_by_head = _capture_fc_in_features(model, device, x_sample=x_sample)
    for i, head in enumerate(getattr(model, "exit_heads", [])):
        fc = getattr(head, "fc", None)
        if not isinstance(fc, nn.Linear):
            continue
        fin = fin_by_head.get(i)
        if fin is not None and fin != fc.in_features:
            head.fc = _resize_fc_in(fc, fin)

    model.train(was_training)


def hard_prune_resnet(
    model: nn.Module,
    masks: Dict[str, IdxLike],
    *,
    patch_sample: Optional[torch.Tensor] = None,
    patch_loader: Optional[Iterable] = None,
    patch_max_per_batch: int = 256,
) -> nn.Module:
    """
    Apply structural channel pruning with the given OUT-channel masks.

    Works on a deep copy; the passed-in ``model`` is not modified.

    Channels of the residual stream are coupled across blocks. Every layer that
    reads the stream must drop the same input channels the stream dropped. An
    identity block's conv2 must emit exactly the stream's channels, so that
    ``conv2(...) + x`` adds matching channels. ``stream_idx`` tracks which
    ORIGINAL channel indices the pruned stream currently carries.

    Rules:
      - Stem conv1/bn1 pruned by masks['conv1'] if present (input channels kept).
      - For each block of ``model.layers``:
          conv1:  OUT <- its mask, IN <- stream_idx
          conv2:  IN <- conv1's kept OUT indices
                  • identity block: OUT <- stream_idx (a conv2 mask is ignored
                    because the residual add fixes these channels)
                  • projection block: OUT <- its mask. downsample.0 uses the SAME
                    OUT indices (a separate downsample.0 mask is ignored) and takes
                    IN <- stream_idx. The stream then becomes these indices.
      - Exit heads: first conv input channels pruned to match the tap; FC rebuilt
        to match features (patch_exit_heads_to_cin_and_fc).
      - Trunk BNs are recalibrated on a real batch when one is available (best effort).

    Mask values may be tensors or int sequences. Out-of-range and duplicate
    entries are dropped, and a missing or empty mask keeps every channel.

    All new modules inherit device/dtype from the layers they replace.
    """
    model = deepcopy(model)
    device = next(model.parameters()).device

    def to_idx_list(x: Optional[IdxLike], total: int) -> List[int]:
        """Normalize a mask to unique in-range ints (order kept); None/empty keeps all."""
        if x is None:
            return list(range(total))
        if isinstance(x, torch.Tensor):
            x = x.detach().cpu().tolist()
        out = list(dict.fromkeys(int(i) for i in x if 0 <= int(i) < total))
        return out if out else list(range(total))

    # Stem: its kept OUT channels seed the residual stream.
    if "conv1" in masks:
        stream_idx = to_idx_list(masks["conv1"], model.conv1.out_channels)
        in_c = int(model.conv1.in_channels)
        model.conv1 = prune_conv_layer(model.conv1, stream_idx, list(range(in_c)))  # not hardcoded 3
        model.bn1 = prune_bn_layer(model.bn1, stream_idx)
    else:
        stream_idx = list(range(model.conv1.out_channels))

    # Blocks
    for s_idx, stage in enumerate(model.layers):
        for b_idx, block in enumerate(stage):
            prefix = f"layers.{s_idx}.{b_idx}"

            # conv1 reads the stream: its IN columns are the stream's original indices.
            c1_keep_o = to_idx_list(masks.get(f"{prefix}.conv1"), block.conv1.out_channels)
            block.conv1 = prune_conv_layer(block.conv1, c1_keep_o, stream_idx)
            block.bn1 = prune_bn_layer(block.bn1, c1_keep_o)

            if block.downsample is None:
                # identity: conv2 output is added to x, so it must carry the stream's channels
                c2_keep_o = list(stream_idx)
            else:
                # projection: conv2 & downsample share the OUT mask; the stream is re-based
                c2_keep_o = to_idx_list(masks.get(f"{prefix}.conv2"), block.conv2.out_channels)
                ds_conv: nn.Conv2d = block.downsample[0]
                ds_bn: nn.BatchNorm2d = block.downsample[1]
                block.downsample[0] = prune_conv_layer(ds_conv, c2_keep_o, stream_idx)
                block.downsample[1] = prune_bn_layer(ds_bn, c2_keep_o)

            block.conv2 = prune_conv_layer(block.conv2, c2_keep_o, c1_keep_o)
            block.bn2 = prune_bn_layer(block.bn2, c2_keep_o)
            stream_idx = c2_keep_o

    # Patch exit heads (inputs + fc) and sanity-forward on correct device
    model = model.to(device)
    patch_exit_heads_to_cin_and_fc(
        model, device, x_sample=patch_sample, sample_loader=patch_loader, max_per_batch=patch_max_per_batch
    )

    model.eval()
    with torch.no_grad():
        if patch_sample is not None:
            _ = model(patch_sample.to(device, non_blocking=True))
        else:
            # single sanity pass only if we had no real batch
            _ = model(torch.randn(1, 3, 32, 32, device=device))
    model.train()

    # Refresh convenience attributes so downstream code sees the current graph
    setattr(model, "all_state_dict_keys", list(model.state_dict().keys()))
    setattr(model, "trainable_state_dict_keys", [n for n, p in model.named_parameters() if p.requires_grad])

    # cheap but effective: recalibrate trunk BN on a small real batch
    try:
        if patch_loader is not None or patch_sample is not None:
            if patch_sample is not None:
                # dummy labels sized to the batch; only the inputs matter for BN stats
                dummy_y = torch.zeros(patch_sample.size(0), dtype=torch.long, device=device)
                one = (patch_sample, dummy_y)
            else:
                one = next(iter(patch_loader))
            recalibrate_bn(
                model,
                one,
                device=device,
                num_batches=150,         # 100–300 is enough
                max_per_batch=patch_max_per_batch,
                reset_running_stats=False,
                log_prefix="[PRUNE] ",
            )
    except Exception as e:
        print(f"[PRUNE] BN recalibration skipped: {e}")
    return model.to(device)


# ---------------- FLOPs/params helper (last-exit only) ----------------
def estimate_gru_lm_macs_params(model, seq_len: int, batch_size: int = 1, time_distributed_head: bool = False):
    """
    Analytic (MACs, params) for an Embedding -> stacked GRU -> Linear-head LM.

    Reads ``vocab_size``, ``embed_dim``, ``hidden_dim`` and ``num_layers`` from
    ``model``; a missing attribute raises AttributeError. Only multiply-accumulates
    are counted (non-linearities ignored). GRU layers are counted PyTorch-style:
    W_ih (3H x in), W_hh (3H x H) and two 3H bias vectors. The head is applied to
    the last token only, or to every token when ``time_distributed_head`` is True.

    Returns:
        (total_macs, total_params) as ints.
    """
    vocab, emb, hid, layers = (
        int(getattr(model, attr))
        for attr in ("vocab_size", "embed_dim", "hidden_dim", "num_layers")
    )
    batch, steps = int(batch_size), int(seq_len)
    deeper = max(0, layers - 1)  # layers fed by the previous GRU (input dim = H)

    # Parameters: embedding + first GRU layer + deeper GRU layers + head
    first_layer_params = 3 * hid * emb + 3 * hid * hid + 6 * hid
    deeper_layer_params = 6 * hid * hid + 6 * hid
    n_params = (
        vocab * emb
        + first_layer_params
        + deeper * deeper_layer_params
        + (hid * vocab + vocab)
    )

    # MACs: every time step runs every layer; the head runs on 1 or T positions
    macs_per_step = 3 * (emb * hid + hid * hid) + deeper * (6 * hid * hid)
    gru_macs = batch * steps * macs_per_step
    head_positions = steps if time_distributed_head else 1
    head_macs = batch * head_positions * hid * vocab

    return gru_macs + head_macs, n_params


def compare_flops_and_params(
    model: nn.Module,
    input_size=(3, 32, 32),
    device: str = "cpu",
    batch_size: int = 1,
    default_seq_len: int = 80,
    use_analytic_for_rnn: bool = True,
):
    """
    Return (MACs, params) for ``model``, counting only the last exit.

    GRU language models are detected by an ``nn.Embedding`` submodule together
    with ``vocab_size``/``embed_dim``/``hidden_dim``/``num_layers`` attributes.
    When ``use_analytic_for_rnn`` is True they use the analytic estimate with
    sequence length ``model.seq_len`` (default ``default_seq_len``), and the
    result is returned as formatted strings ("X.XX MMac", "X.XX M").

    Every other model, including ones that merely contain an Embedding (e.g.
    learned position embeddings), is measured with ptflops. For that run,
    ``last_exit_only`` is temporarily set to True and the model is in eval mode;
    both are restored afterwards. The model is moved to ``device`` and is not
    moved back. The ptflops results are returned as ptflops formats them.
    """
    gru_attrs = ("vocab_size", "embed_dim", "hidden_dim", "num_layers")
    is_seq = any(isinstance(m, nn.Embedding) for m in model.modules())
    if is_seq and use_analytic_for_rnn and all(hasattr(model, a) for a in gru_attrs):
        T = int(getattr(model, "seq_len", default_seq_len))
        # If you kept a flag on the head, read it; default False now
        time_dist = any(getattr(h, "last_token_only", True) is False
                        for h in getattr(model, "exit_heads", []))
        macs, params = estimate_gru_lm_macs_params(model, seq_len=T,
                                                   batch_size=batch_size,
                                                   time_distributed_head=time_dist)
        return f"{macs/1e6:.2f} MMac", f"{params/1e6:.2f} M"

    # CNN/ViT path: ptflops on the last exit only
    model_was_train = model.training
    last_exit = getattr(model, "last_exit_only", False)
    model.eval().to(device)
    setattr(model, "last_exit_only", True)
    try:
        def _ctor(shape):  # image tensors
            return torch.zeros((batch_size, *shape), device=device)

        with torch.no_grad():  # counting only; no autograd graph needed
            flops, params = get_model_complexity_info(
                model, input_size, print_per_layer_stat=False, verbose=False, input_constructor=_ctor
            )
        return flops, params
    finally:
        setattr(model, "last_exit_only", last_exit)
        model.train(model_was_train)

def ____compare_flops_and_params(model: nn.Module, input_size=(3, 32, 32), device: str = "cpu"):
    """
    Calculates FLOPs and params for an early-exit model by temporarily enabling
    `last_exit_only` (if present) so ptflops measures only the final exit.
    """
    model.eval()
    model.to(device)

    was_last = getattr(model, "last_exit_only", False)
    setattr(model, "last_exit_only", True)

    try:
        with torch.no_grad():
            flops, params = get_model_complexity_info(
                model, input_size, print_per_layer_stat=False, verbose=False
            )
        return flops, params
    except Exception as e:
        print(f"FLOPs calculation failed: {e}")
        return None, None
    finally:
        setattr(model, "last_exit_only", was_last)
        model.train()

        