# src/apps/clients/scalefl_classification_client.py
import os
import copy
from typing import List, Dict
from collections import OrderedDict, Counter
import numpy as np
import torch
# Resolves the callable named by a config node (used below to build datasets and models).
from src.utils import get_func_from_config

from src.models.model_utils import retarget_exit_preconv_bns

from src.apps.clients.reefl_classification_client import ReeFLClassificationClient
from src.models.snip_utils import (
        compute_snip_channel_scores,
        create_resnet_channel_masks,
        hard_prune_resnet,
        recalibrate_bn,
    )

import ast

def _parse_modes(obj):
    """Normalise a mode spec into a list of lowercase strings.

    Accepts a list/tuple (each element is stringified), or a string. A string
    that starts with ``[`` is parsed as a Python literal; if that parse (or
    iterating its result) fails, the stripped string itself becomes the single
    entry. Any other input type yields an empty list.
    """
    if isinstance(obj, str):
        text = obj.strip()
        if not text.startswith("["):
            return [text.lower()]
        try:
            parsed = ast.literal_eval(text)
            return [str(item).lower() for item in parsed]
        except Exception:
            # Malformed literal: fall back to treating the text as one mode.
            return [text.lower()]
    if isinstance(obj, (list, tuple)):
        return [str(item).lower() for item in obj]
    return []

class ScaleFLClassificationClient(ReeFLClassificationClient):
    """
    Scale/Hetero/SNIP client.

    SNIP path:
      • parent builds a *full-width* model first (width_scale forced to 1.0),
      • we structurally prune with the provided `snip_mask`,
      • we recalibrate BN once, locally, on train data before the first fit.

    Scale path:
      • no mask => parent width_scale is used directly (no structural surgery).
    """

    def __init__(self, cid: str, lid: int, width_scale: float, *args, **kwargs):
        """Build the client through the parent class and optionally apply a SNIP mask.

        Args:
            cid: Client id; forwarded to the parent and used in log lines.
            lid: Forwarded to the parent unchanged.
            width_scale: Forwarded to the parent when no mask is supplied;
                replaced by 1.0 when ``snip_mask`` is supplied so the parent
                builds a full-width model.
            *args: Forwarded to the parent. ``args[0]``, when present, is treated
                as the checkpoint/context object whose ``device`` attribute is
                temporarily overridden on the SNIP path.
            **kwargs: Forwarded to the parent after ``snip_mask`` is removed.
                ``snip_mask`` (optional) is handed to ``hard_prune_resnet``.

        The parent constructor runs inside ``try/finally`` so that the caller's
        ``ckp.device`` is put back even when construction raises; ``ckp`` comes
        from the caller, so leaving it on "cpu" would leak out of this object.
        """
        snip_mask = kwargs.pop("snip_mask", None)

        # If we will apply a SNIP mask at init, parent must build full-width.
        initial_width_scale = 1.0 if snip_mask is not None else width_scale

        # ---- TEMPORARILY FORCE CPU FOR PARENT NET CONSTRUCTION (SNIP path only) ----
        ckp = args[0] if len(args) > 0 else None
        orig_device, forced_cpu = None, False
        if snip_mask is not None and ckp is not None:
            try:
                orig_device = getattr(ckp, "device", None)
                if str(orig_device).lower().startswith("cuda"):
                    setattr(ckp, "device", "cpu")
                    forced_cpu = True
            except Exception:
                # Best effort: if ckp refuses the override we simply build on its device.
                pass

        try:
            # Build the per-exit model (full width if SNIP; else requested width)
            super().__init__(cid, lid, initial_width_scale, *args, **kwargs)
        finally:
            # Restore the checkpoint device after CPU override, success or failure.
            if forced_cpu and ckp is not None and orig_device is not None:
                try:
                    setattr(ckp, "device", orig_device)
                except Exception as restore_err:
                    print(f"[client {cid}] WARNING: could not restore ckp.device to {orig_device}: {restore_err}")

        # --- Build and cache a TRUE full-width template for future SNIP masks ---
        try:
            arch_fn = getattr(self, "_eval_arch_fn")
            net_args = copy.deepcopy(getattr(self, "_eval_net_args"))
            net_args["width_scale"] = 1.0  # same depth/exits as this client, width=1.0
            self._fullwidth_template = arch_fn(device="cpu", **net_args).cpu()
            print(f"[client {self.cid}] cached full-width template (width_scale=1.0) for future SNIP masks")
        except Exception as e:
            # Fallback: rebuild from the model section of the config instead.
            from src.utils import get_func_from_config
            data_arch = get_func_from_config(self.ckp.config.model)
            alt_args = copy.deepcopy(self.ckp.config.model.args)
            alt_args["width_scale"] = 1.0
            self._fullwidth_template = data_arch(device="cpu", **alt_args).cpu()
            print(f"[client {self.cid}] fallback full-width template; reason: {e}")

        self._need_bn_recal = False
        self._last_mask_sig = None

        # Target device = original CUDA device if we forced CPU above, else self.device
        target_device = orig_device if (forced_cpu and orig_device is not None) else getattr(self, "device", "cpu")

        # ---- Apply SNIP mask immediately if provided ----
        if snip_mask is not None:
            print(f"[client {self.cid}] Built full-width model. Applying structural SNIP mask (init) on CPU...")

            # never mutate the cached template
            base_cpu = copy.deepcopy(self._fullwidth_template).cpu()

            # Fetch one real batch for patching. Note that _build_loader_from_config
            # only builds "val"; the "train" request returns None after logging.
            patch_loader = self._build_loader_from_config("val") or self._build_loader_from_config("train")
            patch_batch = None
            if patch_loader is not None:
                try:
                    xb, _ = next(iter(patch_loader))
                    patch_batch = xb[:256]  # keep it light and deterministic
                except Exception:
                    patch_batch = None

            # prune with real-batch context so exit taps & FC get patched correctly
            pruned_cpu = hard_prune_resnet(
                base_cpu,
                snip_mask,
                patch_sample=patch_batch,
                patch_loader=patch_loader,
                patch_max_per_batch=256,
            ).cpu()

            self.net = pruned_cpu.to(target_device, non_blocking=False)
            print(f"[client {self.cid}] SNIP mask applied at init. Model structurally pruned.")

            # ensure head preconv BNs are retargeted (extra safety)
            retarget_exit_preconv_bns(self.net, verbose=True)

            # refresh payload/key lists and optimizer wiring after structural surgery
            self._refresh_keylists_after_surgery()
            self._apply_trainability_after_surgery()

            # track sig & request one-time BN recal later
            self._last_mask_sig = self._mask_sig(snip_mask)
            self._need_bn_recal = True
        else:
            # Width-only clients: ensure net is on target device
            try:
                if target_device and str(target_device).lower().startswith("cuda"):
                    self.net = self.net.to(target_device, non_blocking=False)
            except Exception:
                pass

        # (re)build lists & shapes and apply trainability/optimizer wiring.
        # On the SNIP path this repeats the calls made right after pruning.
        self._refresh_keylists_after_surgery()
        self._apply_trainability_after_surgery()
    
    def old__init__(self, cid: str, lid: int, width_scale: float, *args, **kwargs):
        """Legacy constructor body kept alongside ``__init__``.

        Because it is named ``old__init__`` Python never invokes it implicitly,
        and nothing in the visible part of this file calls it. It follows the
        same steps as ``__init__`` with these differences:

          * the template is built twice; the second assignment (a deepcopy of
            ``self.net``) overwrites the first,
          * ``ckp.device`` is restored only after the first template build,
          * the patch batch is capped at 128 samples instead of 256.

        Args:
            cid: Client id; forwarded to the parent and used in log lines.
            lid: Forwarded to the parent unchanged.
            width_scale: Forwarded to the parent unless ``snip_mask`` is given,
                in which case 1.0 is used.
            *args: Forwarded to the parent; ``args[0]`` is treated as ``ckp``.
            **kwargs: Forwarded to the parent after ``snip_mask`` is removed.
        """
        snip_mask = kwargs.pop("snip_mask", None)
        
        # If we will apply a SNIP mask at init, parent must build full-width.
        initial_width_scale = 1.0 if snip_mask is not None else width_scale

        # ---- TEMPORARILY FORCE CPU FOR PARENT NET CONSTRUCTION (SNIP path only) ----
        ckp = args[0] if len(args) > 0 else None
        orig_device = None
        forced_cpu = False
        if snip_mask is not None and ckp is not None:
            try:
                orig_device = getattr(ckp, "device", None)
                if str(orig_device).lower().startswith("cuda"):
                    setattr(ckp, "device", "cpu")
                    forced_cpu = True
            except Exception:
                # Best effort: a ckp that rejects the override is left as-is.
                pass

        # Build the per-exit model (full width if SNIP; else requested width)
        # NOTE(review): if this raises, ckp.device is never restored below.
        super().__init__(cid, lid, initial_width_scale, *args, **kwargs)
        # --- Build and cache a TRUE full-width template for future SNIP masks ---
        # NOTE: this result is overwritten by the deepcopy-based block further down.
        try:
            arch_fn = getattr(self, "_eval_arch_fn")
            net_args = copy.deepcopy(getattr(self, "_eval_net_args"))
            net_args["width_scale"] = 1.0
            # keep same depth/exits as this client
            self._fullwidth_template = arch_fn(device="cpu", **net_args).cpu()
            print(f"[client {self.cid}] cached full-width template (width_scale=1.0) for future SNIP masks")
        except Exception as e:
            # Fallback: rebuild from the model section of the config instead.
            from src.utils import get_func_from_config
            data_arch = get_func_from_config(self.ckp.config.model)
            alt_args = copy.deepcopy(self.ckp.config.model.args)
            alt_args["width_scale"] = 1.0
            self._fullwidth_template = data_arch(device="cpu", **alt_args).cpu()
            print(f"[client {self.cid}] fallback full-width template; reason: {e}")
        # Restore the checkpoint device after CPU override
        if forced_cpu and ckp is not None and orig_device is not None:
            try:
                setattr(ckp, "device", orig_device)
            except Exception:
                pass

        # -------- Cache full-width template (same depth/exits as THIS client) --------
        # NOTE(review): on the no-mask path self.net was built with the requested
        # width_scale, so this copy is presumably not full-width there — confirm
        # before reviving this method.
        try:
            # safest: just deepcopy the net we *just built*, which is full-width in SNIP path
            self._fullwidth_template = copy.deepcopy(self.net).to("cpu")
            print(f"[client {self.cid}] cached full-width template for future SNIP masks")
        except Exception as e:
            try:
                arch_fn = getattr(self, "_eval_arch_fn")
                net_args = copy.deepcopy(getattr(self, "_eval_net_args"))
                net_args["width_scale"] = 1.0
                self._fullwidth_template = arch_fn(device="cpu", **net_args).cpu()
                print(f"[client {self.cid}] fallback template via arch_fn (reason: {e})")
            except Exception as e2:
                # Retries the same deepcopy that failed above; may raise again.
                self._fullwidth_template = copy.deepcopy(self.net).to("cpu")
                print(f"[client {self.cid}] WARNING: second fallback (deepcopy current). Reason: {e2}")

        self._need_bn_recal = False
        self._last_mask_sig = None

        # Target device = original CUDA device if we forced CPU above, else self.device
        target_device = orig_device if (forced_cpu and orig_device is not None) else getattr(self, "device", "cpu")

        # ---- Apply SNIP mask immediately if provided ----
        if snip_mask is not None:
            print(f"[client {self.cid}] Built full-width model. Applying structural SNIP mask (init) on CPU...")

            # never mutate the cached template
            base_cpu = copy.deepcopy(self._fullwidth_template).cpu()

            # Fetch one real batch for patching. _build_loader_from_config only
            # builds "val"; the "train" request logs a failure and returns None.
            patch_loader = self._build_loader_from_config("val") or self._build_loader_from_config("train")
            patch_batch = None
            if patch_loader is not None:
                try:
                    xb, _ = next(iter(patch_loader))
                    patch_batch = xb[:128]  # keep it light and deterministic
                except Exception:
                    patch_batch = None

            # prune with real-batch context so exit taps & FC get patched correctly
            pruned_cpu = hard_prune_resnet(
                base_cpu,
                snip_mask,
                patch_sample=patch_batch,
                patch_loader=patch_loader,
                patch_max_per_batch=256,
            ).cpu()

            self.net = pruned_cpu.to(target_device, non_blocking=False)
            print(f"[client {self.cid}] SNIP mask applied at init. Model structurally pruned.")

            # ensure head preconv BNs are retargeted (extra safety)
            retarget_exit_preconv_bns(self.net, verbose=True)

            # refresh payload/key lists and optimizer wiring after structural surgery
            self._refresh_keylists_after_surgery()
            self._apply_trainability_after_surgery()

            # track sig & request one-time BN recal later
            self._last_mask_sig = self._mask_sig(snip_mask)
            self._need_bn_recal = True
        else:
            # Width-only clients: ensure net is on target device
            try:
                if target_device and str(target_device).lower().startswith("cuda"):
                    self.net = self.net.to(target_device, non_blocking=False)
            except Exception:
                pass

        # (re)build lists & shapes and apply trainability/optimizer wiring
        # (on the SNIP path this repeats the calls made right after pruning)
        self._refresh_keylists_after_surgery()
        self._apply_trainability_after_surgery()
        
    def _inspect_loader(self, loader, tag: str):
        """Print what this loader is (dataset, size, first-batch info)."""
        try:
            ds = getattr(loader, "dataset", None)
            ds_name = type(ds).__name__ if ds is not None else "UnknownDataset"
            n_ds = len(ds) if ds is not None else "?"
            n_batches = len(loader) if hasattr(loader, "__len__") else "?"
            print(f"[client {self.cid}] CALIB[{tag}] dataset={ds_name} size={n_ds} batches={n_batches}")
            it = iter(loader)
            xb, yb = next(it)
            bs = xb.size(0)
            yb_np = yb.detach().cpu().numpy().ravel().tolist()
            from collections import Counter
            hist = Counter(yb_np)
            top = ", ".join([f"{k}:{v}" for k, v in sorted(hist.items())[:10]])
            print(f"[client {self.cid}] CALIB[{tag}] first_batch: shape={tuple(xb.shape)} labels({len(hist)} uniq)={{ {top} }}")
        except Exception as e:
            print(f"[client {self.cid}] CALIB[{tag}] inspect failed: {e}")

    def _build_loader_from_config(self, partition: str):
        """Construct a small, non-augmented loader on demand (VAL ONLY)."""
        try:
            if str(partition).lower() != "val":
                raise RuntimeError(f"[client {self.cid}] Only VAL calibration is allowed (got partition={partition}).")
            data_cfg = self.ckp.config.data
            data_cls = get_func_from_config(data_cfg)
            dataset  = data_cls(self.ckp, **data_cfg.args)
            bs = int(getattr(self.ckp.config.app.eval_fn, "batch_size", 64))
            dl = dataset.get_dataloader(
                data_pool="server",
                partition="val",
                batch_size=bs,
                augment=False,
                num_workers=0,
            )
            return dl
        except Exception as e:
            print(f"[client {self.cid}] CALIB build(val) failed: {e}")
            return None

    def _pick_calibration_loader(self):
        """
        Prefer a VALIDATION loader for calibration (SNIP/BN). Fallback to TRAIN if missing.
        Returns (loader, tag) and logs dataset stats to prove source.
        """
        # Attached validation loader first
        for name in ("valloader", "val_loader", "loader_val", "val_dl", "validloader", "valid_loader"):
            if hasattr(self, name) and getattr(self, name) is not None:
                tag = f"attached:{name}"
                self._inspect_loader(getattr(self, name), tag)
                return getattr(self, name), tag

        # Then attached train loader
        for name in ("trainloader", "train_loader", "loader_train", "train_dl"):
            if hasattr(self, name) and getattr(self, name) is not None:
                tag = f"attached:{name}"
                self._inspect_loader(getattr(self, name), tag)
                return getattr(self, name), tag

        # Build ephemeral VAL, then TRAIN
        val_dl = self._build_loader_from_config("val")
        if val_dl is not None:
            tag = "ephemeral:val"
            self._inspect_loader(val_dl, tag)
            return val_dl, tag

        tr_dl = self._build_loader_from_config("train")
        if tr_dl is not None:
            tag = "ephemeral:train"
            self._inspect_loader(tr_dl, tag)
            return tr_dl, tag

        raise RuntimeError(f"[client {self.cid}] Could not find a val/train DataLoader for calibration.")

    def _mask_sig(self, mask: Dict[str, List[int]]) -> str:
        """Stable signature for a SNIP mask to avoid reapplying identical masks."""
        parts = []
        for k in sorted(mask.keys()):
            v = mask[k]
            if isinstance(v, torch.Tensor):
                v = v.detach().cpu().tolist()
            parts.append(f"{k}:{','.join(str(int(i)) for i in v)}")
        import hashlib
        return hashlib.sha1("|".join(parts).encode("utf-8")).hexdigest()[:12]

    def get_parameters(self, config=None):
        sd = self.net.state_dict()
        keys = getattr(self, "_payload_keys", None) or self.local_sd_keys
        out = []
        for k in keys:
            if k not in sd:
                continue
            arr = sd[k].detach().cpu().numpy()
            if isinstance(getattr(self, "_recv_shapes", None), dict) and (k in self._recv_shapes):
                target = self._recv_shapes[k]
                if tuple(arr.shape) != tuple(target):
                    tmp = np.zeros(target, dtype=arr.dtype)
                    n = min(tmp.size, arr.size)
                    if n > 0:
                        tmp.reshape(-1)[:n] = arr.reshape(-1)[:n]
                    arr = tmp
            out.append(arr)
        return out
    
    def __get_parameters(self, config=None):
        """
        Return tensors in the order decided for THIS round.
        - Pure scale/hetero: if server asked for full local state, include params+buffers
        - SNIP or normal training: trainables only
        """
        sd = self.net.state_dict()
        keys = getattr(self, "_payload_keys", None) or self.local_sd_keys
        return [sd[k].detach().cpu().numpy() for k in keys]

    def set_parameters(self, parameters, keys_override=None):
        """
        Tolerant loader that respects server-provided key order.
        """
        model_sd = self.net.state_dict()
        keys = list(keys_override) if keys_override is not None else list(self.trainable_state_keys)
        exp = len(keys)
        got = len(parameters)

        upto = min(got, exp)
        for k, v in zip(keys[:upto], parameters[:upto]):
            t = torch.from_numpy(np.copy(v))
            ref = model_sd[k]

            if t.dtype != ref.dtype:
                t = t.to(ref.dtype)
            if t.dim() != ref.dim():
                if t.numel() == ref.numel():
                    t = t.view_as(ref)
                else:
                    print(f"[client {self.cid}] WARN set_parameters: rank mismatch for {k}: "
                        f"inc={tuple(t.shape)} ref={tuple(ref.shape)}; skipping key.")
                    continue
            if tuple(t.shape) != tuple(ref.shape):
                slices = tuple(slice(0, min(a, b)) for a, b in zip(t.shape, ref.shape))
                t2 = torch.zeros_like(ref)
                t2[slices] = t[slices]
                t = t2

            model_sd[k] = t

        self.net.load_state_dict(model_sd, strict=False)

    # def _compute_and_apply_snip_locally(self, keep_ratio: float):
    #     """
    #     One-shot SNIP on this client's own loader (VAL preferred).
    #     Then structural prune and BN-recal using the same loader.
    #     """
    #     try:
    #         loader, tag = self._pick_calibration_loader()
    #     except Exception as e:
    #         print(f"[client {self.cid}] Local SNIP skipped (no calib loader: {e}).")
    #         return

    #     device = next(self.net.parameters()).device
    #     self.net.train().to(device)

    #     # How many batches are available/used
    #     try:
    #         len_loader = len(loader)
    #     except TypeError:
    #         len_loader = None
    #     want_batches = 5
    #     used_batches = want_batches if (len_loader is None) else min(want_batches, len_loader)

    #     print(f"[client {self.cid}] SNIP using {tag} | keep_ratio={keep_ratio} | "
    #         f"batches_used={used_batches}/{len_loader if len_loader is not None else '?'}")

    #     # --- SNIP scores from chosen loader ---
    #     scores = compute_snip_channel_scores(
    #         model=self.net,
    #         onebatch_or_loader=loader,
    #         loss_fn=torch.nn.CrossEntropyLoss(),
    #         device=str(device),
    #         num_batches=used_batches,
    #         max_per_batch=256,
    #         log_prefix=f"[client {self.cid}] SNIP[{tag}] "
    #     )
    #     masks = create_resnet_channel_masks(scores, keep_ratio=float(keep_ratio), model=self.net)

    #     # --- Structural prune ---
    #     self.net = hard_prune_resnet(self.net, masks).to(device)

    #     # --- BN recal with the SAME loader ---
    #     print(f"[client {self.cid}] BN recalibration using {tag}")
    #     recalibrate_bn(self.net, loader, device=str(device), num_batches=200, log_prefix=f"[client {self.cid}] BN[{tag}] ")

    #     print(f"[client {self.cid}] Local SNIP+BN completed.")

    # def _pick_train_loader(self):
    #     # Be permissive about attribute naming.
    #     for name in ("trainloader", "train_loader", "loader_train", "train_dl"):
    #         if hasattr(self, name):
    #             return getattr(self, name)
    #     raise RuntimeError(f"[client {self.cid}] Could not find a train DataLoader for BN recalibration.")

    def fit(self, parameters, config):
        """Run one local training round through the parent ``fit``.

        Before delegating, this method:
          * decodes the key order from ``config["keys_prog"]`` (a list/tuple, or
            a JSON / Python-literal string of one) and falls back to
            ``self.local_sd_keys`` when none is usable,
          * picks the outgoing key list: ``self.local_full_keys`` when
            ``config["send_full_local_state"]`` is truthy, else the decoded order,
          * rebuilds ``self.net`` from the cached full-width template when
            ``config["snip_mask"]`` holds a mask whose signature differs from the
            last one applied,
          * loads `parameters` via ``set_parameters`` using the decoded order,
          * when ``config["train_head_only"]`` is truthy, freezes every
            parameter not listed in ``self.head_only_keys`` for this round.

        Args:
            parameters: Sequence of arrays matched positionally to the key order.
            config: Mapping of round options; ``None`` is treated as empty.

        Returns:
            The parent's result. When that is a 3-tuple whose last item is a
            dict, the first item is replaced by arrays rebuilt from
            ``self._payload_keys`` and the dict gains ``width_scale``,
            ``key_sig`` and ``keyset_kind`` unless already present.

        ``depth_training`` and the pre-round ``requires_grad`` flags are both
        restored in a ``finally`` block, so a failing parent ``fit`` does not
        leave the model frozen in head-only mode.
        """
        import ast
        import json
        from hashlib import sha1

        cfg = dict(config or {})

        def _retarget_optimizer():
            """Point the optimizer at the parameters that currently require grad."""
            optimizer = getattr(self, "optimizer", None)
            if optimizer is None:
                return
            live = [p for _, p in self.net.named_parameters() if p.requires_grad]
            if len(optimizer.param_groups) == 0:
                optimizer.add_param_group({"params": live})
            else:
                # Everything goes into the first group; later groups are emptied
                # so they cannot hold stale parameter references.
                optimizer.param_groups[0]["params"] = live
                for group in optimizer.param_groups[1:]:
                    group["params"].clear()

        # ----- Decode server order (params only) -----
        raw = cfg.get("keys_prog", None)
        order = None
        if isinstance(raw, (list, tuple)):
            order = list(raw)
        elif isinstance(raw, str):
            parsed = None
            try:
                parsed = json.loads(raw)
            except Exception:
                try:
                    parsed = ast.literal_eval(raw)
                except Exception:
                    parsed = None
            # Only a real sequence is a key order; a decoded scalar/str/dict is not.
            if isinstance(parsed, (list, tuple)):
                order = list(parsed)

        if not order:
            order = list(self.local_sd_keys)
            print(f"[client {self.cid}] ScaleFL.fit(): keys_prog missing → TRAINABLE payload ({len(order)} keys)")

        # Decide outgoing payload for this round (trainable vs full).
        # NOTE(review): if the SNIP rebuild below drops keys, this list still
        # refers to the pre-surgery key list; the payload rebuild at the end
        # filters out keys missing from the state dict.
        send_full = bool(cfg.get("send_full_local_state", False))
        self._payload_keys = self.local_full_keys if send_full else order

        # ----- SNIP mask handling -----
        snip_mask = cfg.get("snip_mask", None)
        if isinstance(snip_mask, str):
            try:
                snip_mask = ast.literal_eval(snip_mask)
            except Exception:
                snip_mask = None
        have_mask = isinstance(snip_mask, dict) and len(snip_mask) > 0
        # "Pure scale" = no mask this round and none ever applied before.
        pure_scale = (not have_mask) and (getattr(self, "_last_mask_sig", None) is None)
        if have_mask:
            sig = self._mask_sig(snip_mask)
            if sig != getattr(self, "_last_mask_sig", None):
                print(f"[client {self.cid}] Received new SNIP mask (sig={sig}). Rebuilding from full-width template.")
                dev = next(self.net.parameters()).device
                # Always prune a fresh copy so the cached template stays untouched.
                base = copy.deepcopy(self._fullwidth_template).to(dev)
                self.net = hard_prune_resnet(base, snip_mask).to(dev)
                retarget_exit_preconv_bns(self.net, verbose=True)
                self._last_mask_sig = sig
                self._refresh_keylists_after_surgery()
                self._apply_trainability_after_surgery()
                self._need_bn_recal = True
            else:
                print(f"[client {self.cid}] SNIP mask unchanged (sig={sig}); skipping re-prune).")

        # Always load using TRAINING order
        self.set_parameters(parameters, keys_override=order)

        # BN recal policy
        if pure_scale:
            cfg.pop("need_bn_recal", None)
            cfg.pop("bn_calib_source", None)
            self._need_bn_recal = False
        elif cfg.get("need_bn_recal", False):
            # Not pure-scale already implies a mask now or one applied earlier.
            self._need_bn_recal = True

        # Optional head-only block
        train_head_only = bool(cfg.get("train_head_only", False))
        _restore_reqs = None
        if train_head_only:
            print(f"[client {self.cid}] head-only training this round")
            _restore_reqs = {n: p.requires_grad for n, p in self.net.named_parameters()}
            head_set = set(self.head_only_keys)
            for n, p in self.net.named_parameters():
                p.requires_grad = (n in head_set)
            _retarget_optimizer()

        # Force progressive training this round
        prev_depth = getattr(self, "depth_training", False)
        self.depth_training = True
        try:
            self.net.train()
            result = super().fit(parameters, cfg)
        finally:
            self.depth_training = prev_depth
            # Undo the head-only freeze even if training raised.
            if _restore_reqs is not None:
                for n, p in self.net.named_parameters():
                    # A parameter not seen when the snapshot was taken defaults
                    # to trainable (the old code raised KeyError here).
                    want = _restore_reqs.get(n, True)
                    if p.requires_grad != want:
                        p.requires_grad = want
                _retarget_optimizer()

        # Rebuild payload from self._payload_keys (keys absent from the model are skipped)
        if isinstance(result, tuple) and len(result) == 3 and isinstance(result[2], dict):
            sd = self.net.state_dict()
            payload = [sd[k].detach().cpu().numpy() for k in self._payload_keys if k in sd]
            # annotate hints
            ks = self._payload_keys
            key_sig = sha1("|".join(ks).encode()).hexdigest()[:10]
            key_kind = "full" if set(ks) == set(self.local_full_keys) else "train"
            width = float(getattr(self, "width_scale", 1.0))
            result[2].setdefault("width_scale", width)
            result[2].setdefault("key_sig", key_sig)
            result[2].setdefault("keyset_kind", key_kind)
            result = (payload, result[1], result[2])

        return result
    
    def _apply_trainability_after_surgery(self):
        """Freeze/unfreeze according to local_sd_keys and retarget the optimizer."""
        train_set = set(self.local_sd_keys)

        any_train = False
        for n, p in self.net.named_parameters():
            req = (n in train_set)
            if p.requires_grad != req:
                p.requires_grad = req
            any_train = any_train or req

        if not any_train:
            print(f"[client {self.cid}] WARNING: no trainable parameters after surgery.")

        # Retarget optimizer param groups to current live params (best-effort, keeps hyperparams)
        if hasattr(self, "optimizer") and self.optimizer is not None:
            new_params = [p for p in self.net.parameters() if p.requires_grad]
            if len(self.optimizer.param_groups) == 0:
                self.optimizer.add_param_group({"params": new_params})
            else:
                # Put all into the first group to avoid stale references
                self.optimizer.param_groups[0]["params"] = new_params
                # Drop any extra groups that might hold stale params
                for g in self.optimizer.param_groups[1:]:
                    g["params"].clear()

    def _refresh_keylists_after_surgery(self):
        """After structural changes, keep original payload order and drop missing keys."""
        # Live names
        live_sd = self.net.state_dict()
        live_sd_keys = set(live_sd.keys())
        live_param_names = [n for n, _ in self.net.named_parameters()]  # preserves module order
        live_param_set = set(live_param_names)

        # Old lists from ReeFL base (preserve their order)
        old_full = getattr(self, "local_full_keys", []) or []
        old_sd   = getattr(self, "local_sd_keys", []) or []
        old_head = getattr(self, "head_only_keys", []) or []

        # Head prefix from ReeFL (last local exit inside the truncated model)
        head_prefix = getattr(self, "_head_prefix", f"exit_heads.{self.lid}.")

        # 1) Full local state (params + buffers) for this exit: keep old order, drop missing
        if old_full:
            self.local_full_keys = [k for k in old_full if k in live_sd_keys]
        else:
            self.local_full_keys = [
                k for k in live_sd.keys()
                if (not k.startswith("exit_heads.")) or k.startswith(head_prefix)
            ]

        # 2) Trainable payload (progressive subset): keep old order, drop missing
        if old_sd:
            self.local_sd_keys = [k for k in old_sd if k in live_param_set]
        else:
            self.local_sd_keys = [
                k for k in live_param_names
                if (not k.startswith("exit_heads.")) or k.startswith(head_prefix)
            ]

        # 3) Head-only keys: keep old order, drop missing
        if old_head:
            self.head_only_keys = [k for k in old_head if k in live_param_set]
        else:
            self.head_only_keys = [k for k in live_param_names if k.startswith(head_prefix)]

        # 4) Keep trainable_state_keys aligned with payload order
        self.trainable_state_keys = list(self.local_sd_keys)

        # 5) Cache shapes for sanity
        self._expected_shapes = {k: v.shape for k, v in live_sd.items()}

        try:
            n_params = sum(p.numel() for p in self.net.parameters())
            n_train  = sum(p.numel() for n, p in self.net.named_parameters() if n in self.local_sd_keys)
            print(f"[client {self.cid}] after-surgery: params={n_params} payload_params={n_train} "
                f"local_full={len(self.local_full_keys)} local_sd={len(self.local_sd_keys)} head_only={len(self.head_only_keys)}")
        except Exception:
            pass

        print(
            f"[client {self.cid}] lid={self.lid} "
            f"local_keys={len(self.local_sd_keys)} "
            f"(first={self.local_sd_keys[:3] if self.local_sd_keys else []})"
        )
