import flwr as fl
import numpy as np
import torch
import torch.nn as nn
import math
import copy
from typing import Dict, Callable, Optional, Tuple 
from flwr.server.server import Server
from flwr.server.history import History
from flwr.common import weights_to_parameters

from collections import OrderedDict, defaultdict
from src.apps import ClassificationApp
from src.apps.app_utils import cosine_decay_with_warmup
from src.models.model_utils import set_weights, set_partial_weights
from src.apps.clients import ree_early_exit_test
from src.utils import get_func_from_config
import string
from src.models.snip_utils import debug_dump_snip

import logging
# Module-level logger named after this module; used by the app methods below.
logger = logging.getLogger(__name__)


class ReeFLClassificationApp(ClassificationApp):
    """Early-exit (ReeFL) federated classification app.

    Each client is assigned a maximum exit and, optionally, a width scale:

    * ``mode='multi_tier'``: client ``i`` trains up to exit ``i % no_of_exits``
      (and, if ``width_scaling`` is given, uses width
      ``width_scaling[i % len(width_scaling)]``).
    * ``mode='maximum'``: every client trains up to the last exit.

    ``app.args.pruning_mode`` may mark individual exits as ``"snip"``; for those
    exits a structural SNIP mask (keep ratio = that exit's width scale) is
    computed once and cached (see ``exit_shape`` / ``_compute_snip_mask``).
    """

    def __init__(self, *args, mode='multi_tier', save_log_file=True, width_scaling=None, **kwargs):
        """Configure exit/width assignment from the checkpoint config.

        Args:
            mode: ``'multi_tier'`` or ``'maximum'`` (see class docstring).
            save_log_file: forwarded as ``filepath`` to
                ``ckp.save_results_logfile`` after each evaluation; falsy
                disables results logging.
            width_scaling: optional per-exit width scales; must have exactly
                one entry per exit.

        Raises:
            AssertionError: on an unknown ``mode`` or a ``width_scaling`` whose
                length differs from ``no_of_exits``.
        """
        super().__init__(*args, **kwargs)
        assert mode in ['maximum', 'multi_tier']
        self.server_snip_masks = bool(getattr(self.ckp.config.app.args, "server_snip_masks", True))
        self.mode = mode
        self.save_log_file = save_log_file
        self.no_of_exits = self.ckp.config.models.net.args.no_of_exits
        self.no_of_clients = self.ckp.config.simulation.num_clients
        self.width_scaling = width_scaling
        if self.width_scaling:
            assert len(self.width_scaling) == self.no_of_exits, 'no. of width scales must match no. of exits.'

        if self.mode == 'multi_tier':
            # Round-robin assignment of clients to exits (and widths) by id.
            self.client_max_exit_layers = {str(i): i % self.no_of_exits for i in range(self.no_of_clients)}
            if self.width_scaling:
                self.client_width_scales = {str(i): self.width_scaling[i % len(self.width_scaling)] for i in range(self.no_of_clients)}
        # `or []` tolerates an explicit `pruning_mode: null` in the config.
        self.pruning_mode = list(getattr(self.ckp.config.app.args, "pruning_mode", None) or [])
        # exit index -> SNIP mask; an empty dict is a valid cached "no pruning" mask.
        self._snip_masks = {}

    def _force_width_one(self) -> bool:
        """Return True if ReeFL should run at global width=1.0 (strict).

        ``app.args.strict_width_one`` wins when it is set; otherwise strict mode
        is inferred from the strategy class path containing ``"reefl"``.
        """
        # Explicit flag wins
        try:
            app_args = getattr(self.ckp.config.app, "args", None)
            if app_args is not None and getattr(app_args, "strict_width_one", None) is not None:
                return bool(app_args.strict_width_one)
        except Exception:
            pass

        # Fallback: infer from strategy class path safely
        try:
            strat_path = str(getattr(self.ckp.config.server.strategy, "class", "")).lower()
            return "reefl" in strat_path
        except Exception:
            return False

    def _pruning_uses_snip(self) -> bool:
        """True if any exit is configured to use SNIP pruning."""
        try:
            return any(str(m).lower() == "snip" for m in (self.pruning_mode or []))
        except Exception:
            return False

    def get_fit_config_fn(self):
        """Return ``fit_config_fn(rnd)`` producing the per-round client fit config.

        The config holds the round index and the learning rate. When
        ``on_fit.cos_lr_decay`` is set, the rate comes from
        ``cosine_decay_with_warmup`` (base ``start_lr``, minimum ``min_lr``,
        ``run.num_rounds`` total steps, no warmup). Otherwise it is the constant
        ``on_fit.start_lr``.
        """
        def fit_config_fn(rnd: int) -> Dict[str, float]:
            fit_config = self.ckp.config.app.on_fit

            if fit_config.cos_lr_decay:
                current_lr = cosine_decay_with_warmup(rnd,
                                learning_rate_base=fit_config.start_lr,
                                total_steps=self.ckp.config.app.run.num_rounds,
                                minimum_learning_rate=fit_config.min_lr,
                                warmup_learning_rate=0,
                                warmup_steps=0,
                                hold_base_rate_steps=0.)
            else:
                current_lr = fit_config.start_lr

            return {
                "lr": current_lr,
                "current_round": rnd,
            }

        return fit_config_fn

    def get_evaluate_config_fn(self):
        """Return ``evaluate_config_fn(rnd)`` for client-side evaluation.

        Clients evaluate on their own test set. The config carries the
        fine-tuning learning rate, the number of fine-tuning epochs (both from
        ``app.on_evaluate``) and the round index.
        """
        def evaluate_config_fn(rnd: int) -> Dict[str, float]:
            eval_config = self.ckp.config.app.on_evaluate

            return {
                "lr": eval_config.lr,
                "current_round": rnd,
                "finetune_epochs": eval_config.finetune_epochs,
            }

        return evaluate_config_fn

    def get_client_fn(self):
        """Return a Flower ``client_fn(cid)`` that constructs the client for ``cid``.

        Each client gets a maximum exit ``lid`` and a width:

        * ``lid`` is the last exit in ``maximum`` mode. Otherwise it is
          ``client_max_exit_layers[cid]``, falling back to
          ``int(cid) % no_of_exits``.
        * The width is 1.0 under strict width-one. Otherwise it is the
          per-client width, then the per-exit width, then 1.0.

        When server-side SNIP masks are disabled, ``exit_shape(lid)`` may add a
        ``snip_mask`` kwarg or override the width. The client is built as
        ``client_ctor(cid, lid, width, ckp, **client_args)``.
        """
        client_ctor = get_func_from_config(self.app_config.client)
        nof = int(getattr(self, "no_of_exits", 1))
        width_scaling_cfg = getattr(self, "width_scaling", None)

        force_width_one = self._force_width_one()

        def _pick_lid(cid: str) -> int:
            if getattr(self, "mode", "multi_tier") == "maximum":
                return max(0, nof - 1)
            # Look up first; only fall back to int(cid) when the cid is unmapped.
            assigned = getattr(self, "client_max_exit_layers", {}).get(cid)
            if assigned is not None:
                return int(assigned)
            return int(cid) % max(1, nof)

        def _pick_width(cid: str, lid: int) -> float:
            if force_width_one:
                return 1.0
            if hasattr(self, "client_width_scales") and self.client_width_scales:
                return float(self.client_width_scales.get(cid, 1.0))
            if width_scaling_cfg and len(width_scaling_cfg) > lid:
                return float(width_scaling_cfg[lid])
            return 1.0

        def client_fn(cid: str):
            lid = _pick_lid(cid)
            width = _pick_width(cid, lid)

            kwargs = dict(self.app_config.client.args)

            # --- ALWAYS consult exit_shape; SNIP must be single-shot + structural ---
            shape = {}
            if not self.server_snip_masks and hasattr(self, "exit_shape") and callable(self.exit_shape):
                try:
                    shape = self.exit_shape(lid) or {}
                    if "snip_mask" in shape and isinstance(shape["snip_mask"], dict) and shape["snip_mask"]:
                        kwargs["snip_mask"] = shape["snip_mask"]
                    if "width_scale" in shape:
                        # NOTE(review): this overrides the strict width=1.0 chosen by
                        # _pick_width, while allow_shape_mismatch below stays False in
                        # strict mode. Confirm that combination is intended.
                        width = float(shape["width_scale"])
                except Exception as e:
                    logger.warning(f"[get_client_fn] exit_shape failed for lid={lid}: {e}")

            # Avoid clashing kwarg (width is passed positionally)
            kwargs.pop("width_scale", None)

            if force_width_one:
                kwargs.pop("kl_loss", None)
                kwargs.pop("kl_consistency_weight", None)
                kwargs.pop("kl_softmax_temp", None)

            # Shape-mismatch can be allowed if SNIP is in use (structural change)
            use_snip_here = "snip_mask" in kwargs
            kwargs["allow_shape_mismatch"] = (not force_width_one) or use_snip_here

            # Pass width POSITIONALLY
            return client_ctor(cid, lid, width, self.ckp, **kwargs)

        return client_fn

    @staticmethod
    def _probe_last_block(probe) -> int:
        """Return the index of the last trunk block of a full-size ``probe`` model.

        The index is computed from ``sum(len(stage) for stage in probe.layers)``.
        When that is unavailable, it falls back to the largest non-negative entry
        of ``probe.blks_to_exit``, and finally to 0.
        """
        try:
            full_depth = sum(len(s) for s in getattr(probe, "layers", []))
            last_block = full_depth - 1 if full_depth and full_depth > 0 else None
        except Exception:
            last_block = None
        if last_block is None:
            try:
                vals = [int(x) for x in getattr(probe, "blks_to_exit", []) if int(x) >= 0]
                last_block = max(vals) if vals else 0
            except Exception:
                last_block = 0
        return last_block

    def get_eval_fn(self):
        """Return the centralized (server-side) evaluation function.

        The returned ``evaluate`` takes positional arguments::

            evaluate(weights, partition[, exit_i[, blk_to_exit | blks_to_exit
                     [, local_keys[, width_scale]]]])

        It does the following:

        * Builds the sub-model for ``exit_i``: depth up to that exit's block,
          exits ``0..exit_i``, at the given (or configured) width.
        * Loads ``weights`` in order into ``local_keys``. By default these are
          all trainable non-head keys plus ``exit_heads.{exit_i}.*``.
        * Slices or zero-pads each tensor to the model's shape.
        * Runs ``ree_early_exit_test`` on the server ``partition`` split.

        Negative block indices mean "last block".

        Returns:
            ``(loss, metrics)``, where ``metrics`` has the keys
            ``centralized_{partition}_exit{exit_i}_loss`` and ``..._acc``
            (accuracy in percent).

        Raises:
            TypeError: fewer than two positional arguments.
            AssertionError: payload length differs from the number of keys.
            ValueError: a payload tensor has a different rank than its target.
            RuntimeError: the test routine reported no integer exit keys.
        """
        # Architecture defaults depend only on the config: probe the full model
        # once (on first use) instead of rebuilding it on every evaluation.
        probe_cache = {}

        def _probe_defaults():
            if not probe_cache:
                probe_cfg = self.ckp.config.models.net
                probe = get_func_from_config(probe_cfg)(device="cpu", **copy.deepcopy(probe_cfg.args))
                probe_cache["last_block"] = self._probe_last_block(probe)
                try:
                    probe_cache["blks_to_exit"] = list(getattr(probe, "blks_to_exit"))
                except Exception:
                    probe_cache["blks_to_exit"] = None
                del probe
            return probe_cache["last_block"], probe_cache["blks_to_exit"]

        def evaluate(*args):
            if len(args) < 2:
                raise TypeError("evaluate expects at least (weights, partition)")

            weights      = args[0]
            partition    = args[1]
            exit_i       = None
            blk_to_exit  = None
            blks_to_exit = None
            local_keys   = None
            width_scale_override = None

            if len(args) >= 3 and isinstance(args[2], (int, np.integer)):
                exit_i = int(args[2])

            if len(args) >= 4:
                a3 = args[3]
                if isinstance(a3, (list, tuple)):
                    blks_to_exit = list(a3)
                elif a3 is not None:
                    blk_to_exit = int(a3)

            if len(args) >= 5 and isinstance(args[4], (list, tuple)):
                local_keys = list(args[4])

            if len(args) >= 6 and isinstance(args[5], (float, int, np.floating, np.integer)):
                width_scale_override = float(args[5])

            # ---------- Device: prefer CUDA, allow config override ----------
            cfg_dev = getattr(getattr(self.ckp.config, "app", {}), "eval_fn", None)
            cfg_dev = getattr(cfg_dev, "device", None) if cfg_dev is not None else None
            if cfg_dev is not None:
                device = str(cfg_dev)
            else:
                device = f"cuda:{torch.cuda.current_device()}" if torch.cuda.is_available() else "cpu"
            if device.startswith("cuda") and torch.cuda.is_available():
                torch.backends.cudnn.benchmark = True  # perf on fixed shapes

            # ---------- Build submodel for this exit/depth/width ----------
            net_cfg  = self.ckp.config.models.net
            arch_fn  = get_func_from_config(net_cfg)
            net_args = copy.deepcopy(net_cfg.args)
            last_block, probe_bte = _probe_defaults()

            if exit_i is None:
                exit_i = len(probe_bte) - 1 if probe_bte is not None else 0
            if blk_to_exit is None:
                try:
                    blk_to_exit = int(probe_bte[exit_i])
                except Exception:
                    blk_to_exit = last_block
            if int(blk_to_exit) < 0:
                blk_to_exit = int(last_block)

            if width_scale_override is not None:
                width_scale = float(width_scale_override)
            else:
                try:
                    width_scale = float(self.ckp.config.app.args.width_scaling[exit_i])
                except Exception:
                    width_scale = 1.0

            # Truncate the exit layout to exits 0..exit_i; negative entries -> last block.
            if blks_to_exit is not None:
                bte = list(blks_to_exit)[: exit_i + 1]
                bte = [int(last_block) if int(x) < 0 else int(x) for x in bte]
                net_args["blks_to_exit"] = bte
                blk_to_exit = int(bte[-1])
            elif "blks_to_exit" in net_args and net_args["blks_to_exit"] is not None:
                bte = list(net_args["blks_to_exit"])[: exit_i + 1]
                bte = [int(last_block) if int(x) < 0 else int(x) for x in bte]
                net_args["blks_to_exit"] = bte
                if bte:
                    blk_to_exit = int(bte[-1])

            if "ee_layer_locations" in net_args and isinstance(net_args.get("ee_layer_locations"), (list, tuple)):
                net_args["ee_layer_locations"] = list(net_args["ee_layer_locations"])[: exit_i + 1]

            net_args["depth"]          = int(blk_to_exit) + 1
            net_args["no_of_exits"]    = exit_i + 1
            net_args["width_scale"]    = float(width_scale)
            net_args["last_exit_only"] = False

            # build the evaluation model on the chosen device
            model = arch_fn(device=device, **net_args).eval()

            # Progressive keys (trunk <= exit_i + ONLY head_i)
            if local_keys is None:
                all_keys  = list(model.trainable_state_dict_keys)
                head_pref = f"exit_heads.{exit_i}."
                progressive_keys = [k for k in all_keys if (not k.startswith("exit_heads.")) or k.startswith(head_pref)]
            else:
                progressive_keys = list(local_keys)

            if len(weights) != len(progressive_keys):
                raise AssertionError(
                    f"[ReeFL/eval] len(weights)={len(weights)} != expected={len(progressive_keys)} "
                    f"(exit_i={exit_i}, depth={int(blk_to_exit)+1}, width_scale={width_scale})"
                )

            # ---------- Safe load (slice/pad) with device awareness ----------
            sd = model.state_dict()

            def _to_tensor(x):
                if torch.is_tensor(x):
                    return x.detach()               # keep device as-is; moved in _fit_shape
                return torch.from_numpy(np.asarray(x))

            def _fit_shape(src: torch.Tensor, tgt: torch.Tensor) -> torch.Tensor:
                # move to target param's device/dtype first
                src = src.to(device=tgt.device, dtype=tgt.dtype, non_blocking=True)
                if list(src.shape) == list(tgt.shape):
                    return src
                # Copy the overlapping leading block; the rest of `tgt` is zero.
                slices = tuple(slice(0, min(s, t)) for s, t in zip(src.shape, tgt.shape))
                out = torch.zeros_like(tgt)
                out[slices] = src[slices]
                return out

            for k, w in zip(progressive_keys, weights):
                if k not in sd:
                    continue
                src = _to_tensor(w)
                # zip() over dims would silently truncate a rank mismatch and let
                # broadcasting write a wrong-rank payload into the model.
                if src.dim() != sd[k].dim():
                    raise ValueError(
                        f"[ReeFL/eval] rank mismatch for '{k}': payload {tuple(src.shape)} "
                        f"vs model {tuple(sd[k].shape)}"
                    )
                sd[k] = _fit_shape(src, sd[k])

            model.load_state_dict(sd, strict=False)

            # ---------- Dataloader (no duplicate pin_memory) ----------
            data_cfg = self.ckp.config.data
            data_cls = get_func_from_config(data_cfg)
            dataset  = data_cls(self.ckp, **data_cfg.args)

            testloader = dataset.get_dataloader(
                data_pool="server",
                partition=partition,
                batch_size=int(self.ckp.config.app.eval_fn.batch_size),
                augment=False,
                num_workers=(4 if str(device).startswith("cuda") else 0),  # safe: CIFAR100 wrapper sets pin_memory itself
            )

            # ---------- Inference ----------
            with torch.inference_mode():
                results = ree_early_exit_test(
                    model,
                    max_early_exit_layer=exit_i,
                    testloader=testloader,
                    device=str(device),
                    ensemble=False,
                )

            # Select exit: exact match, else the deepest exit <= exit_i, else the deepest.
            int_keys = sorted(k for k in results.keys() if isinstance(k, int))
            if exit_i in results:
                use_key = exit_i
            elif int_keys:
                not_greater = [k for k in int_keys if k <= exit_i]
                use_key = not_greater[-1] if not_greater else int_keys[-1]
            else:
                raise RuntimeError(
                    f"[ReeFL/eval] no integer exit keys in test results (keys={list(results.keys())})"
                )

            out  = results[use_key]
            loss = float(out["loss"])
            acc  = float(out["accuracy"]) * 100.0

            metrics = {
                f"centralized_{partition}_exit{exit_i}_loss": loss,
                f"centralized_{partition}_exit{exit_i}_acc":  acc,
            }

            print(
                f"[EVAL][{partition}] exit={exit_i} (used={use_key}) "
                f"depth={int(blk_to_exit)+1} width_scale={width_scale:.3f} "
                f"payload_len={len(weights)} prog_keys={len(progressive_keys)} dev={device} -> "
                f"loss={loss:.4f} acc={acc:.2f}"
            )

            # cleanup
            del model, sd
            try:
                if str(device).startswith("cuda"):
                    torch.cuda.synchronize()
                    torch.cuda.empty_cache()
            except Exception:
                pass

            return loss, metrics

        return evaluate

    def bind_tolerant_eval(self, server):
        """Force strategy to use the tolerant (slice/pad) centralized evaluator.

        Failures are logged as warnings and not raised.
        """
        try:
            server.strategy.eval_fn = self.get_eval_fn()
            logger.info("[App] Bound tolerant eval_fn from %s", self.__class__.__name__)
        except Exception as e:
            logger.warning("[App] Could not bind eval_fn: %s", e)

    def run(self, server: Server, reefl=True):
        """Run federated training for ``app.run.num_rounds`` rounds.

        Each round runs ``server.fit_round``. Every ``test_every_n`` rounds, and
        always on the last round, the global model is evaluated on the
        ``'test'`` partition and the metrics are logged. Results and weights
        are checkpointed after each round.

        Args:
            server: Flower server whose strategy performs fit and evaluation.
            reefl: unused; kept for signature compatibility.
        """
        history = History()
        data_config = self.ckp.config.data
        data_class = get_func_from_config(data_config)
        dataset = data_class(self.ckp, **data_config.args)

        # When masks are not computed server-side, precompute them here; client_fn
        # then hands them to clients through exit_shape().
        if not self.server_snip_masks:
            self._warm_snip_masks(force=True)
        self.bind_tolerant_eval(server)

        def _centralized_evaluate(rnd, partition, log=True):
            server_metrics = None
            parameters = server.parameters
            res_cen = server.strategy.evaluate(parameters=parameters, partition=partition)
            if res_cen is not None:
                loss_cen, server_metrics = res_cen
                history.add_loss_centralized(rnd=rnd, loss=loss_cen)
                history.add_metrics_centralized(rnd=rnd, metrics=server_metrics)
                if log:
                    self.ckp.log(server_metrics, step=rnd)
            return server_metrics

        # -------- Initialize parameters in the SAME shape/order the strategy expects
        if self.load or self.start_run > 1:
            server.parameters = self.current_weights
            logger.info('[*] Global Parameters Loaded.')
        else:
            net_config = self.ckp.config.models.net
            arch_fn = get_func_from_config(net_config)
            initial_net_args = copy.deepcopy(net_config.args)
            initial_net_args.pop("depth", None)  # ensure full-depth probe
            net = arch_fn(device='cpu', **initial_net_args)

            # Prefer the strategy-provided initializer (trainables in canonical order)
            try:
                init_params = server.strategy.initialize_parameters(None)
            except Exception:
                init_params = None

            if init_params is not None:
                server.parameters = init_params
            else:
                # Fallback: trainables-only in model order (matches strategy.global_sd_keys)
                server.parameters = weights_to_parameters(
                    [net.state_dict()[k].cpu().numpy() for k in net.trainable_state_dict_keys]
                )

        # -------------------- Training loop --------------------
        logger.info("FL starting")
        app_run_config = self.ckp.config.app.run

        for rnd in range(self.start_run, app_run_config.num_rounds + 1):
            server_metrics = None
            clients_metrics = None

            res_fit = server.fit_round(rnd=rnd)
            if res_fit:
                parameters_prime, _, (results, _) = res_fit
                clients_metrics = [res[1].metrics for res in results]
                if parameters_prime:
                    server.parameters = parameters_prime

            # The last round is always evaluated here, so no separate final pass is needed.
            if rnd % app_run_config.test_every_n == 0 or rnd == app_run_config.num_rounds:
                logger.debug(f"[Round {rnd}] Evaluating global model on test set.")
                server_metrics = _centralized_evaluate(rnd, 'test')
                logger.info(f"[Round {rnd}] {server_metrics}")
                if server_metrics is not None:
                    server_metrics["round"] = rnd
                    alpha = None
                    if self.save_log_file:
                        # Same for every metric of this round; computed once.
                        alpha = 0 if dataset.pre_partition else '_'.join(list(dataset.test_alpha.keys()))
                    for k, v in server_metrics.items():
                        self.ckp.log_summary(k, v)
                        if self.save_log_file:
                            self.ckp.save_results_logfile(
                                self.mode, alpha, k, v, ps_type=f'init_{self.ckp.config.name}',
                                filepath=self.save_log_file, reset=False
                            )

            # end-of-round saving
            self.ckp.save(f'results/round_{rnd}.pkl',
                {'round': rnd, 'clients_metrics': clients_metrics, 'server_metrics': server_metrics})
            self.ckp.save('models/latest_weights.pkl', server.parameters)
            if app_run_config.save_every_n is not None and (rnd == self.start_run or rnd % app_run_config.save_every_n == 0):
                self.ckp.save(f'models/weights_round_{rnd}.pkl', server.parameters)
            self.ckp.save('models/last_round_saved.pkl', rnd)

    def _snip_ratio_from_width(self, exit_i: int) -> float:
        """Return ``width_scaling[exit_i]`` clamped to [0, 1] (1.0 if unavailable)."""
        try:
            ws = float(self.width_scaling[exit_i])
        except Exception:
            ws = 1.0
        return max(0.0, min(1.0, ws))

    def _ensure_snip_masks_ready(self, force_refresh: bool = True) -> None:
        """Make sure every SNIP exit has a mask in memory and in ``snip/exit_{i}.pkl``.

        With ``force_refresh`` every mask is recomputed and saved. Otherwise a
        previously saved non-empty mask is loaded, cached in memory and reused.
        """
        snip_exits = []
        for exit_i, mode in enumerate(getattr(self, "pruning_mode", None) or []):
            if str(mode).lower() != "snip":
                continue
            snip_exits.append(exit_i)

            path = f"snip/exit_{exit_i}.pkl"
            if not force_refresh:
                existing = self.ckp.load(path)
                if isinstance(existing, dict) and len(existing) > 0:
                    # Cache the persisted mask so exit_shape() does not recompute it.
                    self._snip_masks[exit_i] = existing
                    continue

            m = self._compute_snip_mask(exit_i)
            self._snip_masks[exit_i] = m
            self.ckp.save(path, m)

        if snip_exits:
            logger.info(
                f"[SNIP] masks ready for exits {snip_exits} (keep_ratio == width_scaling per-exit)"
            )

    def exit_shape(self, exit_i: int) -> dict:
        """Describe how the sub-model for ``exit_i`` is shrunk.

        Returns ``{"snip_mask": mask}`` when ``pruning_mode[exit_i]`` is
        ``"snip"``. The mask is computed once with keep ratio
        ``width_scaling[exit_i]`` and then cached. For any other mode it
        returns ``{"width_scale": width_scaling[exit_i]}``, or 1.0 when that
        entry is unavailable.
        """
        mode = "scale"
        try:
            mode = str(self.pruning_mode[exit_i]).lower()
        except Exception:
            pass

        try:
            width = float(self.width_scaling[exit_i])
        except Exception:
            width = 1.0

        if mode == "snip":
            # An empty mask ({} = nothing to prune) is a valid cached result;
            # only a missing/None entry triggers the (expensive) SNIP pass.
            if self._snip_masks.get(exit_i) is None:
                self._snip_masks[exit_i] = self._compute_snip_mask(exit_i, keep_ratio=width)
            return {"snip_mask": self._snip_masks[exit_i]}

        return {"width_scale": width}

    def _snip_val_loader(self):
        """Return ``(loader, batch_size)`` over the unshuffled train-pool ``val`` split."""
        data_cfg = self.ckp.config.data
        data_cls = get_func_from_config(data_cfg)
        dataset = data_cls(self.ckp, **data_cfg.args)
        bs = int(getattr(self.ckp.config.app.eval_fn, "batch_size", 64))
        loader = dataset.get_dataloader(
            data_pool="train", partition="val",
            batch_size=bs, augment=False, num_workers=0, shuffle=False,
        )
        return loader, bs

    def _compute_snip_mask_gru(self, exit_i: int, keep_ratio: float) -> dict:
        """SNIP unit selection for GRU models: ``{"units": [kept hidden-unit indices]}``.

        The sub-model uses ``num_layers = exit_i + 1``. The top
        ``round(keep_ratio * H)`` units by score are kept, with at least one.
        """
        from src.models.gru_prune_utils import compute_snip_unit_scores_gru

        net_cfg = self.ckp.config.models.net
        arch_fn = get_func_from_config(net_cfg)

        # Build sub-arch up to exit_i (num_layers = exit_i+1), width=1.0
        net_args = copy.deepcopy(net_cfg.args)
        net_args["num_layers"]     = exit_i + 1
        net_args["no_of_exits"]    = exit_i + 1
        net_args["last_exit_only"] = True
        model = arch_fn(device="cpu", **net_args)

        loader, bs = self._snip_val_loader()

        pad_idx = int(getattr(model, "pad_idx", 0))
        ce = torch.nn.CrossEntropyLoss(ignore_index=pad_idx, reduction="mean")
        scores = compute_snip_unit_scores_gru(
            model, loader, ce, exit_id=(exit_i if exit_i >= 0 else -1),
            device="cpu", num_batches=1, max_per_batch=bs, log_prefix=f"[SNIP-GRU:e{exit_i}] "
        )

        H = int(scores.numel())
        k = max(1, min(H, int(round(keep_ratio * H))))
        keep_idx = torch.topk(scores, k=k, largest=True).indices.cpu().tolist()
        return {"units": keep_idx}

    def _compute_snip_mask(self, exit_i: int, keep_ratio: Optional[float] = None) -> dict:
        """Compute a structural SNIP mask for the sub-model ending at ``exit_i``.

        The sub-architecture up to ``exit_i`` is built at width 1.0, so the kept
        indices live in the global width space. SNIP scores are computed on one
        batch of the train-pool ``val`` split. The number of channels kept per
        layer is set by ``keep_ratio``. It defaults to ``width_scaling[exit_i]``
        and is always clamped to [0, 1]. Results are cached by the caller.

        Returns:
            * GRU models (probe has a ``gru`` attribute):
              ``{"units": [kept unit indices]}``.
            * Other models: ``{layer_name: [kept channel indices], ...}`` from
              ``masks_from_snip_and_targets``. Returns ``{}`` when
              ``keep_ratio >= 1.0`` (nothing to prune).

        Raises:
            ValueError: if ``exit_i`` is outside the architecture's exits.
        """
        # keep ratio defaults to width_scaling[exit_i]
        if keep_ratio is None:
            keep_ratio = self._snip_ratio_from_width(exit_i)
        keep_ratio = float(max(0.0, min(1.0, keep_ratio)))

        # ---- one full-size probe: detects GRU and resolves exit depth ----
        net_cfg = self.ckp.config.models.net
        arch_fn = get_func_from_config(net_cfg)
        probe = arch_fn(device="cpu", **copy.deepcopy(net_cfg.args))

        if hasattr(probe, "gru"):
            del probe
            return self._compute_snip_mask_gru(exit_i, keep_ratio)

        exits_cfg = list(getattr(probe, "blks_to_exit", []))
        if not 0 <= exit_i < len(exits_cfg):
            raise ValueError(f"[SNIP] exit {exit_i} out of range for this architecture")
        # Negative entries mean "last block" (same convention as get_eval_fn).
        last_block = self._probe_last_block(probe)
        blks_prefix = [last_block if int(b) < 0 else int(b) for b in exits_cfg[: exit_i + 1]]
        blk_to_exit = blks_prefix[-1]
        del probe

        if keep_ratio >= 1.0:
            # Full width keeps every channel: skip building and scoring the sub-model.
            return {}

        from src.models.snip_utils import (
            compute_snip_channel_scores,
            build_scale_targets_for_resnet,
            masks_from_snip_and_targets,
        )

        # ---- exact sub-arch for this exit (width=1.0 space for indices) ----
        net_args = copy.deepcopy(net_cfg.args)
        net_args["depth"]          = blk_to_exit + 1
        net_args["blks_to_exit"]   = blks_prefix
        net_args["no_of_exits"]    = exit_i + 1
        net_args["width_scale"]    = 1.0          # indices in global width space
        net_args["last_exit_only"] = True         # logits from the last local head

        device = torch.device("cpu")
        model  = arch_fn(device=device, **net_args)
        model.zero_grad(set_to_none=True)

        loader, bs = self._snip_val_loader()

        # ---- SNIP scores on validation ----
        scores = compute_snip_channel_scores(
            model, loader, nn.CrossEntropyLoss(), device=device,
            num_batches=1, max_per_batch=bs, log_prefix=f"[SNIP:e{exit_i}] "
        )

        # ---- Scale targets (k per layer) at this width ratio; then SNIP top-k ----
        # A scalar policy = keep_ratio is passed; the helper applies its residual rules.
        targets = build_scale_targets_for_resnet(model, keep_ratio)
        masks   = masks_from_snip_and_targets(scores, targets, model)

        # compact lists for JSONability
        masks = {name: (idx if isinstance(idx, list) else list(idx)) for name, idx in masks.items()}

        # Minimal summary
        kept = {name: len(idx) for name, idx in masks.items()}
        print(f"[SNIP/exit{exit_i}] src=train/val keep_ratio={keep_ratio:.3f} layers_pruned={len(kept)} "
            f"sample={list(kept.items())[:4]} ...")

        return masks

    def _warm_snip_masks(self, force: bool = True) -> None:
        """Precompute SNIP masks for every exit whose pruning mode is ``"snip"``.

        With ``force`` every mask is recomputed. Otherwise only exits without a
        cached mask are computed. The keep ratio is ``width_scaling[i]``
        clamped to [0, 1], or 1.0 when that entry is unavailable.
        """
        exits = []
        for i, mode in enumerate(getattr(self, "pruning_mode", None) or []):
            if str(mode).lower() != "snip":
                continue
            if force or self._snip_masks.get(i) is None:
                self._snip_masks[i] = self._compute_snip_mask(i, keep_ratio=self._snip_ratio_from_width(i))
            exits.append(i)
        if exits:
            logger.info("[SNIP] warmed masks for exits %s (force=%s)", exits, force)
