# src/server/strategies/snowfl_fedavg.py
import logging, copy
from typing import Dict, List, Tuple, Optional
import numpy as np
from flwr.common import Parameters, Scalar, FitRes, parameters_to_weights
from flwr.server.client_proxy import ClientProxy
from flwr.common import FitIns, weights_to_parameters
import hashlib
from src.server.strategies.scalefl_fedavg import ScaleFLFedAvg
# --- additional typing / collections / utility imports ---
from typing import DefaultDict
from collections import defaultdict
import random
import torch
import json

logger = logging.getLogger(__name__)

class SNOWFLFedAvg(ScaleFLFedAvg):
    """
    SNOWFL = ScaleFL FedAvg + (optional) BN recalibration + Owen valuation.

    IMPORTANT:
      • Client updates are aggregated by ``ScaleFLFedAvg.aggregate_fit``; this
        class does not change how client weights are averaged.
      • Owen valuation runs AFTER aggregation (analysis/regrouping); the
        aggregated weights are not re-mixed with Owen values here.
      • BN recalibration is requested on a cadence (``bn_sync_every``), at
        round 1 (``bn_bootstrap``), after exit regrouping (``bn_on_regroup``)
        and whenever the SNIP mask shipped to a client changes. A server-side
        sweep only rewrites the broadcast parameters when the payload is the
        full model state (params + buffers); otherwise clients are asked to
        recalibrate BN locally on their next fit round.
    """

    def __init__(
        self,
        *args,
        bn_sync_every: int = 0,          # 0 disables cadence-based BN-recal rounds
        bn_bootstrap: bool = False,      # if True, also recal at round 1
        bn_calib_source: str = "val",    # "val" or "train"
        bn_calib_batches: int = 400,     # max batches per calibration sweep
        bn_on_regroup: bool = False,     # recal clients whose exit changed
        **kwargs
    ):
        """
        Args:
            *args, **kwargs: forwarded unchanged to ``ScaleFLFedAvg``.
            bn_sync_every: run BN recalibration every N rounds (0 = never).
            bn_bootstrap: additionally run BN recalibration at round 1.
            bn_calib_source: data split used for calibration ("val" or "train").
            bn_calib_batches: upper bound on batches per sweep (clamped to >= 1).
            bn_on_regroup: flag clients for BN recal after their exit changes.
        """
        super().__init__(*args, **kwargs)
        self._central_snip_enabled = False
        self._snip_any = False
        self._snip_mask_by_width: dict[float, dict] = {}  # width_scale -> mask

        self.bn_sync_every    = int(max(0, bn_sync_every))
        self.bn_bootstrap     = bool(bn_bootstrap)
        self.bn_calib_source  = str(bn_calib_source).lower()
        self.bn_calib_batches = int(max(1, bn_calib_batches))
        self.bn_on_regroup    = bool(bn_on_regroup)

        # cids to ask for BN recal on the next round; "*" means every sampled client
        self._bn_recal_next: set[str] = set()

        # cid -> signature of the last SNIP mask shipped to that client
        self._client_mask_sig: dict[str, str] = {}

        # Respect pruning_mode from parent; central recompute OFF unless
        # app.args.central_snip_on_server is set (compat)
        try:
            app_args = getattr(self.ckp.config.app, "args", None)
            self._central_snip_enabled = bool(getattr(app_args, "central_snip_on_server", False))
        except Exception:
            self._central_snip_enabled = False

    # --- Helpers: build ephemeral model, load trainables, and recal BN on server ---

    def _bn_force_tracking(self, model):
        """
        Make every non-exit-head BatchNorm layer track running statistics.

        Missing or ``None`` buffers (``running_mean``, ``running_var``,
        ``num_batches_tracked``) are (re)registered with neutral values on the
        device of the model's first parameter; existing buffers are moved there.
        BN layers whose module name starts with ``exit_heads.`` are skipped.

        ``num_batches_tracked`` is always ensured: PyTorch's cumulative-average
        mode (``momentum=None``, used by the server sweep) relies on it.
        """
        import torch.nn as nn

        dev = next(model.parameters()).device

        def _ensure_buffer(mod, attr, factory, dtype=None):
            # (re)register a missing/None buffer, else move the existing one to `dev`
            current = getattr(mod, attr, None)
            if current is None:
                if hasattr(mod, attr):
                    delattr(mod, attr)
                mod.register_buffer(attr, factory())
            elif dtype is None:
                setattr(mod, attr, current.to(dev))
            else:
                setattr(mod, attr, current.to(device=dev, dtype=dtype))

        for name, m in model.named_modules():
            if not isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)):
                continue
            # exit-head BNs are usually per-batch; keep them untouched
            if name.startswith("exit_heads."):
                continue
            m.track_running_stats = True
            nf = int(m.num_features)
            _ensure_buffer(m, "running_mean", lambda: torch.zeros(nf, device=dev))
            _ensure_buffer(m, "running_var", lambda: torch.ones(nf, device=dev))
            _ensure_buffer(
                m,
                "num_batches_tracked",
                lambda: torch.tensor(0, dtype=torch.long, device=dev),
                dtype=torch.long,
            )

    def _build_dataset_for_server(self):
        """Return the server-side dataset, constructing and caching it on first use
        with the same constructor as the eval path (``config.data``)."""
        cached = getattr(self, "_server_dataset", None)
        if cached is not None:
            return cached
        from src.utils import get_func_from_config
        data_cfg = self.ckp.config.data
        data_cls = get_func_from_config(data_cfg)
        self._server_dataset = data_cls(self.ckp, **data_cfg.args)
        return self._server_dataset

    def _load_trainables_by_key(self, model, weights_np_list):
        """Load aggregated trainables into model by the strategy's global key order
        (``self.global_sd_keys``; falls back to ``model.trainable_state_dict_keys``)."""
        try:
            keys = list(self.global_sd_keys)  # from ScaleFLFedAvg
        except Exception:
            keys = list(getattr(model, "trainable_state_dict_keys", []))
        from src.models.model_utils import set_partial_weights
        set_partial_weights(model, keys, weights_np_list)

    def _server_bn_recal_apply(self, new_params):
        """
        Run a server-side BN calibration sweep on a shadow copy of the model.

        Builds a fresh model from ``config.models.net``, loads the aggregated
        weights, resets BN running stats (unless ``self.bn_reset_stats`` is
        False), then forwards up to ``bn_calib_batches`` batches from the
        server's ``bn_calib_source`` split with BN layers in train mode and
        ``momentum=None`` (cumulative average). Exit-head BNs are skipped.

        Returns:
            * full-state payload and >= 1 batch seen → new Parameters holding the
              shadow's full state dict (recalibrated BN buffers included);
            * otherwise (trainables-only / unknown payload, no loader, empty
              loader) → ``new_params`` unchanged, and client-side BN recal is
              scheduled for every client next round (``_bn_recal_next = {"*"}``).
            * ``new_params`` unchanged when ``app.args.server_bn_during_agg`` is False.
        """
        import torch.nn as nn
        from src.utils import get_func_from_config

        # Optional global switch
        args = getattr(self.ckp.config.app, "args", None)
        if args is not None and not bool(getattr(args, "server_bn_during_agg", True)):
            return new_params

        # knobs
        bn_src = str(getattr(self, "bn_calib_source", "val")).lower()
        try:
            bn_batches = int(getattr(self, "bn_calib_batches", 400))
        except Exception:
            bn_batches = 400
        try:
            bs_eval = int(getattr(self.ckp.config.app.eval_fn, "batch_size", 128))
        except Exception:
            bs_eval = 128

        # 1) build shadow model
        net_cfg = self.ckp.config.models.net
        arch_fn = get_func_from_config(net_cfg)
        dev = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        shadow = arch_fn(device=dev, **net_cfg.args)

        try:
            # 2) load aggregated weights and detect payload kind
            agg_list = parameters_to_weights(new_params)
            self._load_agg_weights_len_aware(shadow, agg_list)

            full_keys = list(shadow.state_dict().keys())
            try:
                train_keys = list(self.global_sd_keys)
            except Exception:
                train_keys = [k for k, _ in shadow.named_parameters()]
            payload_len = len(agg_list)
            is_full_payload = payload_len == len(full_keys)
            is_trainables_payload = payload_len == len(train_keys)
            if is_full_payload:
                kind = "full"
            elif is_trainables_payload:
                kind = "trainables"
            else:
                kind = f"unknown(len={payload_len})"
            logger.info("[SERVER][BN] calib_source=%s payload_kind=%s", bn_src, kind)

            # 3) BN-only sweep (CMA momentum=None), skip exit-head BNs
            self._bn_force_tracking(shadow)
            shadow.eval()
            reset_stats = bool(getattr(self, "bn_reset_stats", True))  # opt-in attribute; defaults True
            for name, m in shadow.named_modules():
                if not isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)):
                    continue
                if name.startswith("exit_heads."):
                    continue
                m.train()
                if reset_stats:
                    m.running_mean.zero_()
                    m.running_var.fill_(1.0)
                    if hasattr(m, "num_batches_tracked"):
                        m.num_batches_tracked.zero_()
                m.momentum = None  # cumulative moving average

            # dataloader
            dataset = self._build_dataset_for_server()
            part = "val" if bn_src == "val" else "train"
            try:
                loader = dataset.get_dataloader(
                    data_pool="server", partition=part,
                    batch_size=bs_eval, augment=False, num_workers=0, shuffle=False
                )
            except Exception:
                loader = None

            if loader is None:
                logger.info("[SERVER][BN] No server/%s loader; scheduling client BN recal next round.", part)
                self._bn_recal_next = {"*"}
                return new_params

            seen = 0
            with torch.no_grad():
                for xb, _ in loader:
                    # grayscale → 3 channels; works for any rank >= 4 (N, 1, ...)
                    if xb.dim() >= 4 and xb.size(1) == 1:
                        xb = xb.expand(-1, 3, *xb.shape[2:])
                    if xb.size(0) > 512:
                        xb = xb[:512]
                    shadow(xb.to(dev, non_blocking=True))
                    seen += 1
                    if seen >= bn_batches:
                        break

            shadow.eval()
            logger.info("[SERVER][BN] Applied BN sweep on %d batch(es) from server/%s.", seen, part)

            # 4) write-back / scheduling
            if seen > 0 and is_full_payload:
                sd = shadow.state_dict()
                full_out = [sd[k].detach().cpu().numpy() for k in full_keys]
                logger.info("[SERVER][BN] Broadcast will include recalibrated BN buffers (full-state payload).")
                return weights_to_parameters(full_out)

            self._bn_recal_next = {"*"}
            if is_trainables_payload:
                logger.info("[SERVER][BN] Trainables-only payload: scheduling client BN recal next round.")
            elif not is_full_payload:
                logger.info("[SERVER][BN] Unknown payload shape: keeping weights as-is; scheduling client BN.")
            else:
                logger.info(
                    "[SERVER][BN] server/%s loader yielded no batches; scheduling client BN recal next round.",
                    part,
                )
            return new_params
        finally:
            # always release the shadow model, even if the sweep raised
            del shadow
            try:
                torch.cuda.empty_cache()
            except Exception:
                pass

    def _mask_sig(self, d: dict) -> str:
        """Short key-order-independent signature of a SNIP mask: first 10 hex chars
        of SHA-1 over its JSON dump with sorted keys."""
        payload = json.dumps(d, sort_keys=True).encode()
        return hashlib.sha1(payload).hexdigest()[:10]

    def _exit_is_snip(self, exit_i: int) -> bool:
        """True iff ``app.args.pruning_mode[exit_i]`` is "snip" (case-insensitive)."""
        try:
            pm = getattr(self.ckp.config.app.args, "pruning_mode", [])
            return (exit_i < len(pm)) and (str(pm[exit_i]).lower() == "snip")
        except Exception:
            return False

    def _mask_for_exit(self, exit_i: int):
        """Width-tied SNIP: all exits sharing the same width use ONE mask.
        Width comes from ``self.width_scaling[exit_i]`` (1.0 if unavailable)."""
        try:
            width = float(self.width_scaling[exit_i])
        except Exception:
            width = 1.0
        return self._mask_for_width(width)

    def _mask_for_width(self, width: float):
        """
        Return the SNIP mask dict for ``width``, or None if none is available.

        Lookup order: the per-width cache; ``self.app.width_shape(width)["snip_mask"]``;
        ``self.app.exit_shape(e)["snip_mask"]`` for any exit ``e`` with that width.
        The first non-empty dict found is cached.
        """
        # 1) if cached, return
        m = self._snip_mask_by_width.get(float(width), None)
        if isinstance(m, dict) and m:
            return m

        # 2) try to pull from app metadata (expected to be precomputed over VALIDATION)
        try:
            if hasattr(self.app, "width_shape"):
                shape = self.app.width_shape(width) or {}
                m = shape.get("snip_mask", None)
                if isinstance(m, dict) and m:
                    self._snip_mask_by_width[float(width)] = m
                    return m
        except Exception:
            pass

        # 3) fallback: if the app only exposes per-exit shapes, pick any exit with same width
        try:
            if hasattr(self.app, "exit_shape") and hasattr(self, "width_scaling"):
                for e in range(self.no_of_exits):
                    if float(self.width_scaling[e]) == float(width):
                        shape = self.app.exit_shape(e) or {}
                        m = shape.get("snip_mask", None)
                        if isinstance(m, dict) and m:
                            self._snip_mask_by_width[float(width)] = m
                            return m
        except Exception:
            pass

        return None  # no mask available

    def _mark_bn_recal_clients(self, before_map: dict, after_map: dict, rnd: int) -> None:
        """
        When ``bn_on_regroup`` is set, add every client whose exit index changed
        between ``before_map`` and ``after_map`` to ``_bn_recal_next``.

        Clients absent from ``before_map`` are not counted as changed. Cadence
        scheduling is handled separately by ``_server_bn_recal_apply``.
        """
        changed = set()
        if self.bn_on_regroup and before_map and after_map:
            for cid, new_exit in after_map.items():
                old_exit = before_map.get(cid, None)
                if old_exit is not None and int(old_exit) != int(new_exit):
                    changed.add(cid)

        if changed:
            self._bn_recal_next.update(changed)
            logger.info("[SERVER][BN] round %d → exit changes for %d clients; scheduling BN recal next round", rnd, len(changed))

    def _compute_owen_slice_weights(
        self,
        results,
        *,
        per_exit_norm: bool = True,
        temp: float = 0.35,
        floor: float = 0.50,
        ceil: float = 2.00,
    ) -> dict[str, float]:
        """
        Turn per-client Owen values into aggregation multipliers.

        Each client's value is the proxy's ``value`` attribute (default 1.0,
        floored at 1e-8). Values are log-transformed and passed through a
        temperature softmax (``temp``). They are scaled to mean 1, clipped to
        ``[floor, ceil]``, then rescaled to mean 1 again. With
        ``per_exit_norm=True`` this happens per exit group (from
        ``self.clients_exit``), so each exit's mean weight is 1.

        Note: the final rescale can push individual weights slightly outside
        ``[floor, ceil]``.

        Returns:
            ``{cid: multiplier}``; an empty dict when ``results`` is empty.
        """
        if not results:
            return {}

        cids  = [c.cid for (c, _) in results]
        exits = np.array([int(getattr(self, "clients_exit", {}).get(cid, 0)) for cid in cids], dtype=int)
        vals  = np.array([float(max(1e-8, getattr(c, "value", 1.0))) for (c, _) in results], dtype=float)

        z = np.log(vals + 1e-12)

        def _softmax(x, t):
            x = x - x.max()
            y = np.exp(x / max(1e-6, t))
            s = y.sum()
            return y / s if s > 0 else np.ones_like(y) / len(y)

        w = np.zeros_like(z)

        if per_exit_norm:
            for e in np.unique(exits):
                idx = np.where(exits == e)[0]
                if idx.size == 0:
                    continue
                w_e = _softmax(z[idx], temp) * float(idx.size)   # mean=1 on that exit
                w_e = np.clip(w_e, floor, ceil)                  # clip within the slice
                w_e *= (idx.size / w_e.sum())                    # restore mean=1 per exit
                w[idx] = w_e
        else:
            w = _softmax(z, temp) * float(len(z))
            w = np.clip(w, floor, ceil)
            w *= (len(w) / w.sum())                              # restore mean=1 globally

        return {cid: float(wi) for cid, wi in zip(cids, w)}

    def _owen_mix_factor(self, rnd: int, total_rounds: Optional[int] = None, mode: str = "linear_up") -> float:
        """
        Scalar controlling how much Owen is trusted at round ``rnd``.

        Modes: "const" → ``self.owen_mix_const`` (default 1.0); "linear_up" →
        r/T; "linear_down" → 1 - r/T; anything else → 1.0. ``r`` is ``rnd``
        clamped to [0, T], T = ``total_rounds`` or ``self.global_rounds`` (200).
        """
        if total_rounds is None:
            try:
                total_rounds = int(getattr(self, "global_rounds", 200))
            except Exception:
                total_rounds = 200
        r = max(0, min(rnd, max(1, total_rounds)))
        if mode == "const":
            return float(getattr(self, "owen_mix_const", 1.0))
        if mode == "linear_up":
            return float(r / float(max(1, total_rounds)))
        if mode == "linear_down":
            return float(1.0 - r / float(max(1, total_rounds)))
        return 1.0

    def _owen_value_for(self, cid: str, neutral: float = 1.0) -> float:
        """
        Look up a client's Owen value.

        Order: ``self._owen_values[cid]`` if finite and positive; else the
        ``value`` attribute of the client's proxy from ``client_manager().all()``;
        else ``neutral``.
        """
        try:
            table = getattr(self, "_owen_values", None)
            if isinstance(table, dict) and cid in table:
                v = float(table[cid])
                if v > 0 and np.isfinite(v):
                    return v
        except Exception:
            pass
        # fallbacks: client proxy's EMA value; else neutral
        try:
            mgr = self.client_manager()
            all_fn = getattr(mgr, "all", None)
            proxies = all_fn() if callable(all_fn) else []
            # Flower's ClientManager.all() returns {cid: ClientProxy}; iterating the
            # dict itself yields cid strings, which never carry a `.cid`.
            if isinstance(proxies, dict):
                proxies = proxies.values()
            for p in proxies:
                if getattr(p, "cid", None) == cid:
                    return float(getattr(p, "value", neutral))
        except Exception:
            pass
        return float(neutral)

    def owen_runtime_update(self, **kwargs):
        """Forward ``kwargs`` to ``client_valuation.update_from_server`` if present
        (best effort; failures are logged at debug level)."""
        try:
            cv = getattr(self, "client_valuation", None)
            if cv is not None and hasattr(cv, "update_from_server"):
                cv.update_from_server(**kwargs)
        except Exception as e:
            logger.debug("[SNOWFL] owen_runtime_update ignored: %s", e)

    def owen_freeze(self, value=None):
        """Ask ``client_valuation`` to freeze its values (best effort), passing
        ``self._server`` when set, else this strategy, as ``server``."""
        try:
            cv = getattr(self, "client_valuation", None)
            if cv is not None and hasattr(cv, "freeze_values"):
                cv.freeze_values(server=getattr(self, "_server", self), value=value)
        except Exception as e:
            logger.debug("[SNOWFL] owen_freeze ignored: %s", e)

    def _owen_update(self, rnd: int, server=None):
        """
        Re-tune the Owen valuator's hyper-parameters for round ``rnd``.

        Progress is ``rnd / global_rounds``. ``global_rounds`` is read from
        ``server.ckp`` when available, else from ``self.ckp``, else it is 200.
        Early rounds get higher ``lam`` and lower ``gamma``;
        ``aggressive_mode`` is on for the first 20% of training.
        No-op when ``self.client_valuation`` is missing/None.
        """
        if not hasattr(self, "client_valuation") or self.client_valuation is None:
            return
        cv = self.client_valuation
        ckp = getattr(server, "ckp", None) or getattr(self, "ckp", None)
        try:
            total = int(getattr(ckp.config.app.args, "global_rounds", 200))
        except Exception:
            total = 200
        prog = max(0.0, min(1.0, rnd / max(1, total)))

        cv.lam = max(0.02, 0.35 * (1.0 - prog))
        cv.lam_schedule = "off"

        cv.gamma = min(1.0, 0.90 * (0.25 + 0.75 * prog))
        cv.gamma_schedule = "off"

        cv.temp0, cv.temp1 = 0.50, 0.08
        cv.beta0, cv.beta1 = 1.00, 1.15
        cv.intra_mix_eps = 0.03

        cv.regroup_every = 1
        cv.min_stay_rounds = 1
        cv.strict_equal_groups = True
        cv.map_best_to_deepest = True

        cv.aggressive_mode = (prog <= 0.20)
        cv.normalize_round = False
        cv.use_target_in_shapley = True

        cv.stale_half_life = 6
        cv.ema_decay = 0.0
        cv.idle_ema = 0.0

    def aggregate_fit(
        self,
        rnd: int,
        results: list[tuple[ClientProxy, FitRes]],
        failures: list[BaseException],
        current_parameters: Parameters,
        server=None,
    ):
        """
        Aggregate via ``ScaleFLFedAvg.aggregate_fit``, then run SNOWFL post-processing.

        After the parent aggregation this:
          1. re-tunes the Owen valuator (``_owen_update``) and runs
             ``client_valuation.evaluate`` (best effort; failures are logged);
          2. flags clients whose exit changed for BN recal (``bn_on_regroup``);
          3. on BN-recal rounds runs the server-side BN sweep, which may replace
             ``new_params`` with a full-state payload. Those rounds are every
             ``bn_sync_every`` rounds, plus round 1 when ``bn_bootstrap``.

        Returns:
            ``(new_params, metrics)``; ``(None, metrics)`` when the parent produced
            no parameters.
        """
        new_params, metrics = super().aggregate_fit(rnd, results, failures, current_parameters, server)
        if new_params is None:
            return None, metrics

        before_map = dict(getattr(self, "clients_exit", {}))

        self._owen_update(rnd, server)

        if hasattr(self, "client_valuation") and self.client_valuation is not None:
            try:
                current_w  = parameters_to_weights(current_parameters)
                clients_w  = [parameters_to_weights(fr.parameters) for _, fr in results]
                client_n   = [float(getattr(fr, "num_examples", 1.0)) for _, fr in results]
                client_ids = [c.cid for (c, _) in results]
                lr = self.on_fit_config_fn(rnd).get("lr", 0.01) if self.on_fit_config_fn else 0.01
                self.client_valuation.evaluate(
                    current_weights=current_w, lr=lr,
                    weights_1=clients_w, weights_2=parameters_to_weights(new_params),
                    use_val=False, client_samples=client_n, client_ids=client_ids,
                    server=server, strategy=self, results=results, round_idx=rnd,
                )
            except Exception as e:
                logger.warning(f"[SNOWFL] Owen valuation skipped: {e}")

        try:
            after_map = dict(getattr(self, "clients_exit", {}))
            self._mark_bn_recal_clients(before_map, after_map, rnd)
        except Exception as e:
            logger.debug("[SERVER][BN] regroup BN marking skipped: %s", e)

        try:
            cadence_due = self.bn_sync_every > 0 and rnd % self.bn_sync_every == 0
            bootstrap_due = self.bn_bootstrap and rnd == 1
            if cadence_due or bootstrap_due:
                new_params = self._server_bn_recal_apply(new_params)
        except Exception as e:
            logger.warning("[SERVER][BN] Shadow BN recal skipped: %s", e)

        return new_params, metrics

    def _load_agg_weights_len_aware(self, model, weights_np_list):
        """
        Load aggregated weights into ``model``, picking the key list whose length
        matches the payload.

        Candidates, in order: the model's full ``state_dict`` keys (params +
        buffers), the strategy's trainable keys (``self.global_sd_keys``, falling
        back to ``named_parameters`` order), then ``self.global_full_sd_keys``.

        Raises:
            ValueError: if no candidate key list has the payload's length.
        """
        from src.models.model_utils import set_partial_weights

        # Full state keys (params + buffers), deterministic order
        full_keys = list(model.state_dict().keys())

        # Trainable keys in canonical strategy order
        try:
            train_keys = list(self.global_sd_keys)   # from ScaleFLFedAvg
        except Exception:
            # Fallback: named_parameters order (won't include buffers)
            train_keys = [k for k, _ in model.named_parameters()]

        n_w = len(weights_np_list)
        if n_w == len(full_keys):
            kind, update_keys = "full", full_keys
        elif n_w == len(train_keys):
            kind, update_keys = "trainables", train_keys
        else:
            # Last chance: a custom global full key list, if the strategy has one
            try:
                global_full = list(self.global_full_sd_keys)
            except Exception:
                global_full = None
            if global_full is None or len(global_full) != n_w:
                raise ValueError(
                    f"_load_agg_weights_len_aware: no key list matches payload length "
                    f"(weights={n_w}, train_keys={len(train_keys)}, full_keys={len(full_keys)})"
                )
            kind, update_keys = "global_full", global_full

        logger.debug("[SERVER][BN] Loading %s (%d tensors) into shadow model.", kind, n_w)
        set_partial_weights(model, update_keys, weights_np_list)

    def _default_exit_for(self, cid: str) -> int:
        """
        Deterministic exit for a client with no strategy mapping.

        "maximum" app mode → deepest exit; otherwise ``int(cid) % no_of_exits``.
        Non-numeric cids use a stable SHA-1 digest instead of ``int(cid)``
        (builtin ``hash`` is salted per process).
        """
        try:
            app_mode = getattr(self.ckp.config.app.args, "mode", "multi_tier")
        except Exception:
            app_mode = "multi_tier"
        n_exits = max(1, self.no_of_exits)
        if str(app_mode).lower() == "maximum":
            return n_exits - 1
        try:
            key = int(cid)
        except (TypeError, ValueError):
            key = int(hashlib.sha1(str(cid).encode()).hexdigest(), 16)
        return key % n_exits

    def _resolve_exit(self, cid: str, cfg: dict) -> int:
        """Exit for ``cid``: ``self.clients_exit`` mapping if int-castable, else the
        parent's ``lid_hint`` in ``cfg``, else ``_default_exit_for(cid)``."""
        exits = getattr(self, "clients_exit", None)
        if isinstance(exits, dict):
            mapped = exits.get(cid)
            if mapped is not None:
                try:
                    return int(mapped)
                except (TypeError, ValueError):
                    pass
        if "lid_hint" in cfg:
            return int(cfg["lid_hint"])
        return self._default_exit_for(cid)

    def configure_fit(self, rnd, parameters: Parameters, client_manager):
        """
        Decorate the parent's per-client FitIns with SNOWFL flags.

        For each sampled client this sets ``lid_hint`` (exit index). For SNIP
        exits it ships ``snip_mask`` as a sorted-key JSON string. A change in a
        client's mask signature forces a one-shot BN recal. BN flags
        (``need_bn_recal``, ``bn_calib_source``, ``bn_calib_batches``) are
        injected when due, otherwise scrubbed.

        A fresh FitIns is built per client: Flower's stock ``configure_fit``
        shares one FitIns object across clients, so mutating it in place would
        leak the last client's config to everyone.

        Consumed entries are removed from ``_bn_recal_next``. A "*" entry is
        cleared after one round, even for clients that were not sampled.
        """
        pending = getattr(self, "_bn_recal_next", None)
        if not isinstance(pending, set):
            pending = set()
            self._bn_recal_next = pending
        do_all = "*" in pending
        bn_targets = set() if do_all else set(pending)
        consumed = set()

        base = super().configure_fit(rnd, parameters, client_manager)
        if not base:
            return base

        patched = []
        for client, fitins in base:
            cfg = dict(fitins.config or {})

            # prefer current strategy mapping; then parent hint; then deterministic fallback
            exit_i = self._resolve_exit(client.cid, cfg)
            cfg["lid_hint"] = exit_i

            # BN due if cadence/regroup marked us
            bn_due = do_all or (client.cid in bn_targets)

            # SNIP: width-tied structural mask → JSON + stable signature
            if self._exit_is_snip(exit_i):
                mask = self._mask_for_exit(exit_i)
                if isinstance(mask, dict) and mask:
                    cfg["snip_mask"] = json.dumps(mask, sort_keys=True)
                    sig_new = self._mask_sig(mask)
                    if sig_new != self._client_mask_sig.get(client.cid):
                        bn_due = True  # structure changed → force one BN recal
                        self._client_mask_sig[client.cid] = sig_new
                else:
                    cfg.pop("snip_mask", None)

            # inject or clear BN flags
            if bn_due:
                cfg["need_bn_recal"] = True
                cfg.setdefault("bn_calib_source", self.bn_calib_source)
                cfg.setdefault("bn_calib_batches", self.bn_calib_batches)
                consumed.add(client.cid)
            else:
                for key in ("need_bn_recal", "bn_calib_source", "bn_calib_batches"):
                    cfg.pop(key, None)

            patched.append((client, FitIns(parameters=fitins.parameters, config=cfg)))

        # clear cadence state we consumed
        if do_all:
            self._bn_recal_next.clear()
        else:
            self._bn_recal_next.difference_update(consumed)

        return patched

    def __configure_fit(self, rnd, parameters: Parameters, client_manager):
        """
        LEGACY variant of ``configure_fit``; not called anywhere in this class.

        The double-underscore name is mangled to ``_SNOWFLFedAvg__configure_fit``,
        so Flower never invokes it. Kept for reference. It builds per-client
        FitIns with:
        • per-exit personalized weights / keys (ScaleFL behavior),
        • optional central SNIP mask recompute on server/VAL,
        • SNIP mask shipping (width/exit-tied),
        • BN recal scheduling:
            - cadence/regroup via self._bn_recal_next (supports "*" = all),
            - OR one-shot when the shipped SNIP mask signature changes for that client.

        NOTE(review): relies on helpers not defined in this class
        (``_compute_central_masks_for_exits``, ``_prepare_exit_indices_for_round``,
        ``get_personalized_exit_weights``, ``exit_local_sd_keys``); verify they
        exist before re-enabling.
        """
        # --- defensive init of internal fields ---
        if not hasattr(self, "_central_snip_enabled"): self._central_snip_enabled = False
        if not hasattr(self, "_snip_any"):             self._snip_any = False
        if not hasattr(self, "_snip_masks_by_exit"):   self._snip_masks_by_exit = {}
        if not hasattr(self, "_snip_mask_cache"):      self._snip_mask_cache = None
        if not hasattr(self, "_client_mask_sig"):      self._client_mask_sig = {}  # cid -> sig
        if not hasattr(self, "_bn_recal_next"):        self._bn_recal_next = set()
        if not hasattr(self, "clients_exit") or not isinstance(self.clients_exit, dict):
            self.clients_exit = {}

        # --- optional central SNIP recompute on server VALIDATION split ---
        if self._central_snip_enabled and self._snip_any:
            need_cold = not self._snip_mask_cache
            need_refresh = (self._snip_mask_cache is not None
                            and getattr(self, "_snip_refresh_every", 0) > 0
                            and rnd % self._snip_refresh_every == 0)
            if need_cold or need_refresh:
                msg = "cold start" if need_cold else f"periodic refresh @ round {rnd}"
                logger.info("[SERVER][SNIP] Computing per-exit masks on central VAL (%s)…", msg)
                self._snip_masks_by_exit = self._compute_central_masks_for_exits()
                self._snip_mask_cache = copy.deepcopy(self._snip_masks_by_exit)
                for exit_i in range(self.no_of_exits):
                    self._prepare_exit_indices_for_round(exit_i)

        # --- get base sampling from ScaleFL (it picks clients etc.) ---
        base = super().configure_fit(rnd, getattr(self, "_last_global_params", parameters), client_manager)
        if not base:
            return base
        global_params = getattr(self, "_last_global_params", parameters)

        # --- BN cadence state for this round ---
        do_all = "*" in self._bn_recal_next
        bn_targets = set() if do_all else set(self._bn_recal_next)  # copy
        consumed = set()

        patched = []
        for client, fitins in base:
            # --- decide exit for this client (lazy deterministic fallback) ---
            exit_i = self.clients_exit.get(client.cid)
            if exit_i is None:
                exit_i = self._default_exit_for(client.cid)

            # --- start from parent config and enforce trainables-only payload ---
            c = dict(fitins.config or {})
            c.update({
                "keyset_kind": "train",
                "send_full_local_state": False,
            })

            # --- BN recal need from cadence/regroup set ---
            bn_due = do_all or (client.cid in bn_targets)

            # --- SNIP mask for this exit ({} if not SNIP or keep≈1.0) ---
            try:
                snip = self._get_mask_for_exit(exit_i)
            except Exception:
                try:
                    snip = self._mask_for_exit(exit_i)
                except Exception:
                    snip = {}

            if isinstance(snip, dict) and snip:
                # ship as sorted-key JSON (a valid Flower config scalar)
                c["snip_mask"] = json.dumps(snip, sort_keys=True)
                # One-time BN when the structural mask changes for THIS client
                sig_new = self._mask_sig(snip)
                if sig_new != self._client_mask_sig.get(client.cid):
                    bn_due = True
                    self._client_mask_sig[client.cid] = sig_new
                    logger.info(
                        "[SERVER][SNIP] rnd=%s cid=%s exit=%d -> NEW mask sig=%s (bn_recal=on)",
                        rnd, client.cid, exit_i, sig_new[:8],
                    )
            else:
                # keep payload clean if no effective mask
                c.pop("snip_mask", None)

            # --- inject BN flags if due; else scrub keys to avoid stale state ---
            if bn_due:
                c["need_bn_recal"]    = True
                c["bn_calib_source"]  = getattr(self, "bn_calib_source", "val")
                c["bn_calib_batches"] = int(getattr(self, "bn_calib_batches", 400))
                consumed.add(client.cid)
            else:
                for key in ("need_bn_recal", "bn_calib_source", "bn_calib_batches"):
                    c.pop(key, None)

            # --- personalize payload: keys & weights for this exit (ScaleFL behavior) ---
            local_keys    = self.exit_local_sd_keys[exit_i]
            local_weights = self.get_personalized_exit_weights(exit_i, global_params)
            c["keys_prog"] = json.dumps(local_keys)
            c["lid_hint"]  = exit_i

            # fresh FitIns: the parent may share one instance across clients
            patched.append((client, FitIns(parameters=weights_to_parameters(local_weights), config=c)))

        # --- clear cadence state we consumed ---
        if do_all:
            self._bn_recal_next.clear()
        else:
            self._bn_recal_next.difference_update(consumed)

        return patched
    

    