from src.server.strategies import EarlyExitFedAvgSC
import torch
import numpy as np
import copy
import logging
import hashlib
from typing import Dict, Optional, Tuple, List
from flwr.server.client_proxy import ClientProxy
from src.utils import get_func_from_config
from flwr.common import (
    EvaluateRes, Parameters, Scalar, FitRes,
    parameters_to_weights, weights_to_parameters,
)

import torch.nn as nn  # add this
from src.models.snip_utils import (  # add this
    compute_snip_channel_scores,
    create_resnet_channel_masks,
)

from src.models.model_utils import prune
from collections import defaultdict

logger = logging.getLogger(__name__)
# Type aliases for a NumPy weight array and a list of them; used in type annotations below.
NDArray = np.ndarray
NDArrays = List[NDArray]

# --- robust parser for config lists that sometimes come as strings ---
import ast
def _parse_modes(obj):
    if isinstance(obj, (list, tuple)):
        return [str(x).lower() for x in obj]
    if isinstance(obj, str):
        s = obj.strip()
        if s.startswith("["):
            try:
                vals = ast.literal_eval(s)
                return [str(x).lower() for x in vals]
            except Exception:
                return [s.lower()]
        return [s.lower()]
    return []


def _keywise_mask(name: str, shape, scale: float):
    import math

    # Classifier: keep outputs (num_classes), scale inputs only
    if name.endswith(".fc.weight") or name.endswith("linear.weight"):
        O, I = shape
        i = max(1, math.floor(I * scale))
        m = torch.zeros((O, I), dtype=torch.bool)
        m[:, :i] = True
        return m
    if name.endswith(".fc.bias") or name.endswith("linear.bias"):
        return torch.ones(shape, dtype=torch.bool)

    # Stem conv: scale OC, keep all IC (=3) intact
    if len(shape) == 4 and name == "conv1.weight":
        OC, IC, KH, KW = shape
        oc = max(1, math.floor(OC * scale))
        m = torch.zeros((OC, IC, KH, KW), dtype=torch.bool)
        m[:oc, :, :, :] = True
        return m

    # Other convs: scale both OC and IC
    if len(shape) == 4:
        OC, IC, KH, KW = shape
        oc = max(1, math.floor(OC * scale))
        ic = max(1, math.floor(IC * scale))
        m = torch.zeros((OC, IC, KH, KW), dtype=torch.bool)
        m[:oc, :ic, :, :] = True
        return m

    # BN/affine vectors
    if len(shape) == 1:
        C = shape[0]
        c = max(1, math.floor(C * scale))
        m = torch.zeros((C,), dtype=torch.bool)
        m[:c] = True
        return m

    # Generic linear
    if len(shape) == 2:
        O, I = shape
        o = max(1, math.floor(O * scale))
        i = max(1, math.floor(I * scale))
        m = torch.zeros((O, I), dtype=torch.bool)
        m[:o, :i] = True
        return m

    return torch.ones(shape, dtype=torch.bool)


def _sig_for_keys(keys: List[str]) -> str:
    return hashlib.sha1("|".join(keys).encode("utf-8")).hexdigest()[:10]

def _shape(t): 
    import torch, numpy as np
    if isinstance(t, np.ndarray): return tuple(t.shape)
    if isinstance(t, torch.Tensor): return tuple(t.shape)
    return None

def validate_pruned(update_keys, ref_sd, pruned_map, logger):
    """Check that every key in ``update_keys`` was pruned to its reference shape.

    Args:
        update_keys: keys expected in ``pruned_map``.
        ref_sd: mapping key -> reference array/tensor giving the expected shape.
        pruned_map: mapping key -> pruned array/tensor.
        logger: logger that receives a single error summary on failure.

    Returns:
        ``True`` when every key is present in ``pruned_map`` and its shape equals
        the reference shape. Otherwise ``False``, after logging up to 6 missing
        keys and 12 mismatches. A key absent from ``ref_sd`` counts as a
        mismatch instead of raising ``KeyError``.
    """
    missing = []
    mismatched = []
    for k in update_keys:
        if k not in pruned_map:
            missing.append(k)
            continue
        got = _shape(pruned_map[k])
        if k not in ref_sd:
            mismatched.append(f"{k}: pruned{got} has no reference entry")
            continue
        expected = _shape(ref_sd[k])
        if got != expected:
            mismatched.append(f"{k}: pruned{got} != ref{expected}")

    if not missing and not mismatched:
        return True

    lines = ["[SNIP] Index-aware prune mismatch."]
    if missing:
        more = "..." if len(missing) > 6 else ""
        lines.append(f"Missing keys ({len(missing)}): {missing[:6]}{more}")
    if mismatched:
        lines.append("Mismatched: " + "\n".join(mismatched[:12]))
    logger.error("\n".join(lines))
    return False

def _build_global_val_loader(ckp):
    """
    Build the central **validation** DataLoader used for server-side SNIP.

    - Uses ONLY ``data_pool='server', partition='val'`` (no train fallback),
      with ``augment=False`` and ``num_workers=0``.
    - Batch size comes from ``ckp.config.app.eval_fn.batch_size`` (default 64).
    - Prints proof-of-use logs: dataset type, size, batch count, and the label
      histogram of one batch drawn from a fresh iterator. Inspection failures
      are printed and never fatal.

    Raises:
        RuntimeError: if the server/val loader cannot be built. The underlying
            error is chained as ``__cause__``.
    """
    from collections import Counter

    data_cfg = ckp.config.data
    data_cls = get_func_from_config(data_cfg)
    dataset  = data_cls(ckp, **data_cfg.args)
    bs = int(getattr(ckp.config.app.eval_fn, "batch_size", 64))

    try:
        dl = dataset.get_dataloader(
            data_pool="server",
            partition="val",
            batch_size=bs,
            augment=False,
            num_workers=0,
        )
    except Exception as e:
        # Hard fail: validation only. Chain the cause so the real error stays visible.
        raise RuntimeError(f"[SERVER] No central validation split available (server/val): {e}") from e

    # Log proof-of-use (dataset, size, first-batch label histogram)
    try:
        ds = getattr(dl, "dataset", None)
        ds_name = type(ds).__name__ if ds is not None else "UnknownDataset"
        n_ds = len(ds) if ds is not None else "?"
        n_batches = len(dl) if hasattr(dl, "__len__") else "?"
        print(f"[SERVER] Using VALIDATION split (server/val): dataset={ds_name} size={n_ds} batches={n_batches}")
        xb, yb = next(iter(dl))
        hist = Counter(yb.detach().cpu().numpy().ravel().tolist())
        top = ", ".join([f"{k}:{v}" for k, v in sorted(hist.items())[:10]])
        print(f"[SERVER] VAL first_batch: shape={tuple(xb.shape)} labels({len(hist)} uniq)={{ {top} }}")
    except Exception as e:
        # Inspection is best-effort diagnostics only.
        print(f"[SERVER] Validation loader inspection failed: {e}")

    return dl

class ScaleFLFedAvg(EarlyExitFedAvgSC):
    def __init__(self, *args, **kwargs):
        """Set up per-exit key lists, prune indices, width masks and keyset signatures.

        All arguments are forwarded to the parent strategy. This constructor then
        reads attributes the parent is expected to have set: ``aggregation``,
        ``ckp``, ``all_sd_keys``, ``global_sd_keys``, ``_initial_full_weights``,
        ``net_config``, ``no_of_exits`` and ``blks_to_exit``.

        Options read from ``ckp.config.app.args``:
          * ``training_loss``: ``"inclusive"`` trains every exit head up to an
            exit; any other value (default ``"exclusive"``) trains only that
            exit's own head.
          * ``pruning_mode``: per-exit modes; ``"snip"`` enables SNIP for that exit.
          * ``snip_batches``: max central-VAL batches for SNIP scoring (default 540).
          * ``snip_refresh_every``: recompute SNIP masks every N rounds
            (0 = never after the cold start; default 10).
          * ``width_scaling``: one width multiplier per exit (required).
        """
        super().__init__(*args, **kwargs)
        assert self.aggregation == "fedavg", "ScaleFLFedAvg requires aggregation == 'fedavg'"
        self.clients_exit: Dict[str, int] = {}
        # (lid, rounded width) -> {full key -> bool keep-mask}; filled by _masks_for().
        self._mask_cache = {}
        # exit -> SNIP channel mask {module name -> kept channel indices}.
        self._snip_mask_cache = {}

        self._inclusive_mode = str(
            getattr(self.ckp.config.app.args, "training_loss", "exclusive")
        ).lower() == "inclusive"

        # --- SNIP toggles ---
        self.pruning_mode = _parse_modes(getattr(self.ckp.config.app.args, "pruning_mode", []))
        self._snip_any = any(m == "snip" for m in self.pruning_mode)

        self._central_snip_enabled: bool = True
        # Max central-VAL batches for SNIP scoring; default 540,
        # overridable via YAML: app.args.snip_batches
        self._central_snip_batches: int = int(getattr(self.ckp.config.app.args, "snip_batches", 540))
        # Optional periodic refresh (0 = never after cold start)
        self._snip_refresh_every: int = int(getattr(self.ckp.config.app.args, "snip_refresh_every", 10))
        self._server_val_loader = None
        self._snip_masks_by_exit: Dict[int, Dict[str, List[int]]] = {}

        # Use parent's discovered orders/snapshots
        self._all_full_keys = list(self.all_sd_keys)  # full (params + buffers)
        self._full_template_params = weights_to_parameters(self._initial_full_weights)
        self._last_global_params = self._full_template_params
        self._prefix_param_idxs: Dict[int, Dict[str, List[torch.Tensor]]] = {}

        self.trainable_sd_keys = list(self.global_sd_keys)
        self.width_scaling = self.ckp.config.app.args.width_scaling

        assert len(self.width_scaling) == self.no_of_exits, (
            "width_scaling must list exactly one width per exit"
        )

        arch_fn = get_func_from_config(self.net_config)

        # Per-exit lists and prune indices
        self.exit_local_sd_keys: Dict[int, List[str]] = {}
        self.exit_full_sd_keys: Dict[int, List[str]] = {}
        self.param_idxs: Dict[int, Dict[str, List[torch.Tensor]]] = {}

        # index maps / caches
        self._exit_index_map: Dict[int, Dict[str, Dict[str, List[int]]]] = {}
        self._round_index_map: Dict[str, Dict[str, Dict[str, List[int]]]] = {}

        for exit_i, width_scale in enumerate(self.width_scaling):
            net_args = copy.deepcopy(self.net_config.args)
            blk_to_exit = self.blks_to_exit[exit_i]
            net_args["depth"]        = blk_to_exit + 1
            net_args["blks_to_exit"] = self.blks_to_exit[: exit_i + 1]
            net_args["no_of_exits"]  = exit_i + 1
            net_args["width_scale"]  = width_scale

            local_net = arch_fn(device="cpu", **net_args)
            sd_local = local_net.state_dict()

            model_trainables = [k for k in local_net.trainable_state_dict_keys if k in self.global_sd_keys]

            # _ok(k): does key k belong to this exit's trained set? Backbone keys
            # always do; head keys depend on inclusive/exclusive training.
            # (_ok is only used within this iteration, so no late-binding issue.)
            if self._inclusive_mode:
                prefixes = [f"exit_heads.{j}." for j in range(exit_i + 1)]
                def _ok(k: str) -> bool:
                    if not k.startswith("exit_heads."):
                        return True
                    return any(k.startswith(p) for p in prefixes)
            else:
                head_prefix = f"exit_heads.{exit_i}."
                def _ok(k: str) -> bool:
                    return (not k.startswith("exit_heads.")) or k.startswith(head_prefix)

            local_keys = [k for k in model_trainables if _ok(k)]
            self.exit_local_sd_keys[exit_i] = local_keys

            # Full keyset = trained keys plus every BN running-stat buffer.
            sd_full_keys = [k for k in sd_local.keys() if _ok(k) or ('.running_' in k) or ('num_batches_tracked' in k)]
            self.exit_full_sd_keys[exit_i] = [k for k in sd_full_keys if k in self._all_full_keys]

            # Default prune indices: the leading (prefix) slice of each local shape.
            idx = {}
            for k, v in sd_local.items():
                if k in model_trainables:
                    idx[k] = [torch.arange(s) for s in v.shape]
            self.param_idxs[exit_i] = idx
            self._prefix_param_idxs[exit_i] = copy.deepcopy(idx)
            logger.info(f"[ScaleFL] exit {exit_i}: width={width_scale} trainables={len(local_keys)}")

        # ---- Global index caches for per-client selectors (used by Owen) ----
        # Map each global *trainable* key to its index in self.global_sd_keys
        self._global_index_of: Dict[str, int] = {k: i for i, k in enumerate(self.global_sd_keys)}

        # For each exit, precompute the indices (into the GLOBAL trainables order)
        # that this exit's client actually trains/sends.
        self._exit_to_indices: Dict[int, List[int]] = {}
        for exit_i, keys in self.exit_local_sd_keys.items():
            idxs = [self._global_index_of[k] for k in keys if k in self._global_index_of]
            self._exit_to_indices[exit_i] = idxs

        # Useful for Owen (and for sanity logs)
        self._global_trainables_len: int = len(self.global_sd_keys)
        # Precompute masks per level for full global shapes
        full_params = self._ensure_full_global()
        full_list = parameters_to_weights(full_params)
        full_shapes = {k: np.asarray(w).shape for k, w in zip(self._all_full_keys, full_list)}
        self._full_shapes = full_shapes

        self._level_idx = []
        for s in self.width_scaling:
            masks = {k: _keywise_mask(k, full_shapes[k], s) for k in self._all_full_keys}
            self._level_idx.append(masks)

        # Keyset signatures (sig -> (lid, kind, keys))
        self._sig2keys: Dict[str, Tuple[int, str, List[str]]] = {}
        for lid, keys in self.exit_local_sd_keys.items():
            self._sig2keys[_sig_for_keys(keys)] = (lid, "train", keys)
        for lid, keys in self.exit_full_sd_keys.items():
            self._sig2keys[_sig_for_keys(keys)] = (lid, "full", keys)
        self._sig2keys[_sig_for_keys(self.global_sd_keys)] = (-1, "global_train", self.global_sd_keys)
        self._sig2keys[_sig_for_keys(self._all_full_keys)] = (-1, "full_all", self._all_full_keys)

    def _ensure_server_val(self):
        """Build the central server/val DataLoader on first use and cache it on the instance."""
        if self._server_val_loader is not None:
            return
        print("[SERVER] Building global validation loader for SNIP...")
        self._server_val_loader = _build_global_val_loader(self.ckp)

    # -------- Public API consumed by Owen --------

    def get_trainable_indices_for_exit(self, exit_i: int) -> Optional[List[int]]:
        """Return a copy of the indices (into the GLOBAL trainables order) owned by ``exit_i``.

        Returns ``None`` if ``exit_i`` is unknown or cannot be converted to ``int``.
        """
        try:
            indices = self._exit_to_indices[int(exit_i)]
        except Exception:
            return None
        return list(indices)

    def get_trainable_indices_for_client(self, cid: str) -> Optional[List[int]]:
        """
        Return the GLOBAL-trainables indices matching this client's payload this round.

        The exit is taken from ``self.clients_exit[cid]`` when recorded. Otherwise
        it falls back to ``int(cid) % no_of_exits``. Any failure (e.g. a
        non-numeric cid with no recorded exit) yields ``None``.
        """
        try:
            exit_map = getattr(self, "clients_exit", None)
            if isinstance(exit_map, dict) and cid in exit_map:
                exit_i = int(exit_map[cid])
            else:
                exit_i = int(cid) % max(1, self.no_of_exits)
            return self.get_trainable_indices_for_exit(exit_i)
        except Exception:
            return None

    def get_global_trainables_len(self) -> Optional[int]:
        """Return ``len(self.global_sd_keys)`` as cached at construction, or ``None`` if unavailable."""
        try:
            total = self._global_trainables_len
            return int(total)
        except Exception:
            return None

    # (Optional) handy for diagnostics
    def get_clients_exit_map(self) -> Dict[str, int]:
        """Return a shallow copy of the cid -> exit mapping (``{}`` if missing or not dict-convertible)."""
        try:
            current = getattr(self, "clients_exit", {})
            return dict(current)
        except Exception:
            return {}
        
    def _compute_central_masks_for_exits(self) -> Dict[int, Dict[str, List[int]]]:
        """Compute one full-width SNIP channel mask per exit on the central VAL split.

        For every exit ``i`` in ``width_scaling`` order:
          * exits whose pruning mode is not ``"snip"`` map to ``{}``. This includes
            exits beyond the end of ``pruning_mode``.
          * exits with keep ratio >= 0.999 also map to ``{}`` (no pruning). This
            uses the same threshold as ``_get_mask_for_exit``.
          * all other exits are scored as follows. A ``width_scale=1.0``
            sub-network up to exit ``i`` is built in train mode.
            ``compute_snip_channel_scores`` scores it on at most
            ``self._central_snip_batches`` VAL batches.
            ``create_resnet_channel_masks`` then turns the scores into a mask
            with ``keep_ratio=width_scaling[i]``.

        Returns:
            ``{exit_i: {module_name: [kept channel indices]}}``, one entry per exit.

        Raises:
            RuntimeError: from ``_build_global_val_loader`` if no server/val split exists.
        """
        self._ensure_server_val()
        valloader = self._server_val_loader
        device = torch.device("cpu")
        arch_fn = get_func_from_config(self.net_config)
        masks_by_exit: Dict[int, Dict[str, List[int]]] = {}

        # The loader does not change between exits, so size the batch budget once.
        try:
            n_batches: Optional[int] = len(valloader)
        except TypeError:
            n_batches = None  # loader without __len__
        used = self._central_snip_batches if n_batches is None else min(self._central_snip_batches, n_batches)

        for exit_i, keep in enumerate(self.width_scaling):
            mode = self.pruning_mode[exit_i] if exit_i < len(self.pruning_mode) else ""
            if str(mode).lower() != "snip":
                masks_by_exit[exit_i] = {}
                continue
            if float(keep) >= 0.999:
                # No pruning at (effectively) full width; don't score
                logger.info(f"[SERVER][SNIP] exit {exit_i}: keep≈1.0 → skip SNIP scoring")
                masks_by_exit[exit_i] = {}
                continue
            net_args = copy.deepcopy(self.net_config.args)
            blk_to   = self.blks_to_exit[exit_i]
            net_args["depth"]        = blk_to + 1
            net_args["blks_to_exit"] = self.blks_to_exit[: exit_i + 1]
            net_args["no_of_exits"]  = exit_i + 1
            net_args["width_scale"]  = 1.0  # masks live in global channel space

            model = arch_fn(device="cpu", **net_args).train()

            tag = f"[SERVER][SNIP][exit {exit_i}] "
            print(f"{tag}VAL scoring | keep_ratio={float(keep):.3f} | batches_used={used}/{n_batches if n_batches is not None else '?'}")
            scores = compute_snip_channel_scores(
                model=model,
                onebatch_or_loader=valloader,
                loss_fn=nn.CrossEntropyLoss(),
                device=str(device),
                # Honour the configured (and logged) batch budget.
                num_batches=max(1, used),
                max_per_batch=540,
                log_prefix=tag,
            )
            mask = create_resnet_channel_masks(scores, keep_ratio=float(keep), model=model)
            masks_by_exit[exit_i] = {k: [int(i) for i in v] for k, v in mask.items()}
            print(f"{tag}built mask: layers={len(mask)}, kept_channels_total={sum(len(v) for v in mask.values())}")
        return masks_by_exit

    def _get_mask_for_exit(self, exit_i: int) -> Dict[str, List[int]]:
        """Return the cached SNIP channel mask for ``exit_i``.

        Non-SNIP exits and exits with keep ratio >= 0.999 get (and cache) ``{}``.
        If a SNIP exit is not cached yet, masks for ALL exits are recomputed
        on the central VALIDATION loader. That result replaces the whole cache,
        and the requested exit's entry is returned.
        """
        is_snip = exit_i < len(self.pruning_mode) and str(self.pruning_mode[exit_i]).lower() == "snip"
        if is_snip:
            keep_ratio = float(self.width_scaling[exit_i]) if exit_i < len(self.width_scaling) else 1.0
            trivial = keep_ratio >= 0.999
        else:
            trivial = True

        if trivial:
            # Nothing to prune for this exit.
            return self._snip_mask_cache.setdefault(exit_i, {})

        if exit_i not in self._snip_mask_cache:
            # Cache miss: compute every exit's mask on central VAL in one pass.
            self._snip_masks_by_exit = self._compute_central_masks_for_exits()
            self._snip_mask_cache = copy.deepcopy(self._snip_masks_by_exit)
        return self._snip_mask_cache.get(exit_i, {})

    def _get_or_compute_snip_mask(self, exit_i: int, keep_ratio: Optional[float] = None) -> Dict[str, List[int]]:
        """
        Compute a SNIP channel mask for ``exit_i`` in global (width 1.0) channel space and cache it.

        A non-empty cached mask is returned as-is. ``keep_ratio`` defaults to
        ``self.width_scaling[exit_i]`` and is clamped to [0, 1]. A ratio of 1.0
        (also used when the ratio cannot be read) caches and returns ``{}``.

        Scoring uses the first loader that can be built from server/train, then
        train/train. This uses training data, unlike the central-VAL path in
        ``_compute_central_masks_for_exits``.

        Returns:
            ``{module_name: [kept_out_channel_indices]}``.

        Raises:
            RuntimeError: if neither data loader can be built. The last loader
                error is chained as ``__cause__``.
        """
        cached = self._snip_mask_cache.get(exit_i)
        if cached:
            return cached

        try:
            kr = float(self.width_scaling[exit_i]) if keep_ratio is None else float(keep_ratio)
        except Exception:
            kr = 1.0
        kr = max(0.0, min(1.0, kr))
        if kr >= 1.0:
            self._snip_mask_cache[exit_i] = {}
            return self._snip_mask_cache[exit_i]

        # ----- Build the *global-width* subnetwork up to this exit -----
        arch_fn  = get_func_from_config(self.net_config)
        net_args = copy.deepcopy(self.net_config.args)
        blk_to   = self.blks_to_exit[exit_i]
        net_args["depth"]          = blk_to + 1
        net_args["blks_to_exit"]   = self.blks_to_exit[: exit_i + 1]
        net_args["no_of_exits"]    = exit_i + 1
        net_args["width_scale"]    = 1.0          # indices live in global channel space
        net_args["last_exit_only"] = True

        model = arch_fn(device="cpu", **net_args)

        # ----- Scoring data: server/train, then train/train -----
        data_cfg = self.ckp.config.data
        data_cls = get_func_from_config(data_cfg)
        dataset  = data_cls(self.ckp, **data_cfg.args)

        bs = int(getattr(self.ckp.config.app.eval_fn, "batch_size", 32))
        loader = None
        last_err: Optional[BaseException] = None
        for pool, part in (("server", "train"), ("train", "train")):
            try:
                loader = dataset.get_dataloader(
                    data_pool=pool, partition=part,
                    batch_size=bs, augment=False, num_workers=0,
                )
                break
            except Exception as err:  # try the next source
                loader = None
                last_err = err
        if loader is None:
            # Fail loudly rather than scoring SNIP on synthetic data.
            raise RuntimeError(
                f"[SNIP] exit {exit_i}: no data loader available for SNIP scoring "
                "(tried server/train and train/train)"
            ) from last_err

        scores = compute_snip_channel_scores(model, loader, nn.CrossEntropyLoss(), device="cpu")
        mask   = create_resnet_channel_masks(scores, keep_ratio=kr, model=model)
        mask   = {k: (v if isinstance(v, list) else list(v)) for k, v in mask.items()}  # JSONable

        self._snip_mask_cache[exit_i] = mask
        logger.info(f"[SNIP] exit {exit_i}: computed in-memory mask with keep_ratio={kr:.3f}, layers={len(mask)}")
        return mask

    def _ensure_full_global(self) -> Parameters:
        """Return the last global parameters, resetting them to the full template if incomplete.

        "Full" means one array per key in ``self._all_full_keys`` (params + buffers).
        On a length mismatch a warning is logged and ``_last_global_params`` is
        replaced by ``_full_template_params``.
        """
        expected = len(self._all_full_keys)
        current = parameters_to_weights(self._last_global_params)
        if len(current) == expected:
            return self._last_global_params
        logger.warning(
            "last_global_params not full (len=%d, expected %d) – resetting to template",
            len(current), expected,
        )
        self._last_global_params = self._full_template_params
        return self._last_global_params
    
    def _build_index_map_for_exit(self, exit_i: int) -> Dict[str, Dict[str, List[int]]]:
        """Translate exit ``exit_i``'s VAL-based SNIP mask into per-tensor index selections.

        The mask (``{module_name: [kept channel indices]}`` from
        ``_get_mask_for_exit``) is in global channel space, so a
        ``width_scale=1.0`` sub-network up to this exit is instantiated to read
        full tensor shapes.

        * Residual-block convs (``layers.S.B.conv1`` / ``conv2``) are chained in
          forward order. Each one's ``in_idx`` is the previous chained conv's
          ``out_idx``; the first chained conv keeps all input channels.
        * Other 4-D conv weights get ``out_idx`` from the mask and keep all inputs.
        * Classifier weights (``*.fc.weight`` / ``*linear.weight``) get ``in_idx``
          from the last chained conv's mask.
        * Every 1-D tensor (BN weight, bias, running stats, ...) gets ``out_idx``
          from its paired conv; classifier biases are left out.

        Modules absent from the mask, or with an empty selection, keep every
        channel. All indices are clipped to the tensor's bounds.

        Returns:
            ``{key: {"out_idx": [...], "in_idx": [...]}}`` with either entry
            possibly absent. Returns ``{}`` when the exit has no SNIP mask.
        """
        mask = self._get_mask_for_exit(exit_i)
        if not mask:
            return {}

        arch_fn  = get_func_from_config(self.net_config)
        net_args = copy.deepcopy(self.net_config.args)
        blk_to   = self.blks_to_exit[exit_i]
        net_args["depth"]        = blk_to + 1
        net_args["blks_to_exit"] = self.blks_to_exit[: exit_i + 1]
        net_args["no_of_exits"]  = exit_i + 1
        net_args["width_scale"]  = 1.0  # indices in global channel space

        net = arch_fn(device="cpu", **net_args)
        sd  = net.state_dict()
        nmods = dict(net.named_modules())

        # Residual-block convs in forward order (stage by stage, block by block).
        conv_names: List[str] = []
        for s_idx, stage in enumerate(net.layers):
            for b_idx in range(len(stage)):
                base = f"layers.{s_idx}.{b_idx}"
                for conv in (f"{base}.conv1", f"{base}.conv2"):
                    if conv in nmods:
                        conv_names.append(conv)

        def _keep_for(layer_base: str, dim0: int) -> List[int]:
            """Kept indices for ``layer_base`` within ``[0, dim0)``; all indices if unmasked/empty."""
            sel = mask.get(layer_base)
            if sel is None:
                return list(range(dim0))
            out = [int(i) for i in sel if 0 <= int(i) < dim0]
            return out if out else list(range(dim0))

        def _pair_conv_for_tensor_key(k: str) -> str:
            """Map a BN/downsample tensor key to the conv module whose outputs it follows."""
            base = k.rsplit(".", 1)[0]
            if base == "bn1":
                return "conv1"
            if ".bn1" in base:
                return base.replace(".bn1", ".conv1")
            if ".bn2" in base:
                return base.replace(".bn2", ".conv2")
            if ".downsample.1" in base:
                return base.replace(".downsample.1", ".downsample.0")
            return base

        # NOTE(review): the stem `conv1` selection is not propagated to the first
        # chained conv's in_idx; confirm the SNIP masks never prune the stem.
        prev_out = None
        prev_in_for = {}
        for name in conv_names:
            W = sd.get(f"{name}.weight")
            if W is None:
                continue
            O, I = int(W.shape[0]), int(W.shape[1])
            out_idx = _keep_for(name, O)
            in_idx  = prev_out if prev_out is not None else list(range(I))
            prev_in_for[name] = in_idx
            prev_out = out_idx

        # Every chained conv is in nmods by construction, so the classifier feeder
        # is simply the last one.
        feeder = conv_names[-1] if conv_names else None

        idxmap: Dict[str, Dict[str, List[int]]] = {}
        for k, arr in sd.items():
            if arr.ndim == 4 and k.endswith(".weight") and k[:-7] in nmods:  # Conv
                base = k[:-7]
                O, I = int(arr.shape[0]), int(arr.shape[1])
                in_idx  = prev_in_for.get(base, list(range(I)))
                out_idx = _keep_for(base, O)
                idxmap[k] = {"out_idx": out_idx, "in_idx": in_idx}
            elif arr.ndim == 2 and (k.endswith(".fc.weight") or k.endswith("linear.weight")):
                I = int(arr.shape[1])
                in_idx = _keep_for(feeder, I) if feeder else list(range(I))
                idxmap[k] = {"in_idx": in_idx}
            elif arr.ndim == 1:
                # Includes BN ``.weight``: it must be sliced exactly like its bias
                # and running stats, or the pruned BN tensors disagree in length.
                if k.endswith(".fc.bias") or k.endswith("linear.bias"):
                    continue
                C = int(arr.shape[0])
                idxmap[k] = {"out_idx": _keep_for(_pair_conv_for_tensor_key(k), C)}

        # Clip every selection to the tensor's actual bounds.
        safe: Dict[str, Dict[str, List[int]]] = {}
        for k, spec in idxmap.items():
            W = sd[k]
            o = spec.get("out_idx")
            i = spec.get("in_idx")
            if W.ndim in (4, 2):
                OC, IC = int(W.shape[0]), int(W.shape[1])
                if o is not None:
                    o = [x for x in o if 0 <= x < OC]
                if i is not None:
                    i = [x for x in i if 0 <= x < IC]
            elif W.ndim == 1 and o is not None:
                C = int(W.shape[0])
                o = [x for x in o if 0 <= x < C]
            spec2 = {}
            if o is not None:
                spec2["out_idx"] = o
            if i is not None:
                spec2["in_idx"] = i
            if spec2:
                safe[k] = spec2

        return safe

    def _param_idxs_from_idxmap(self, exit_i: int, idxmap: Dict[str, Dict[str, List[int]]]) -> Dict[str, List[torch.Tensor]]:
        """Convert ``{key: {out_idx, in_idx}}`` into the per-axis index-tensor lists used by ``prune()``.

        Only this exit's local trainable keys that are also global trainables
        are emitted. Per axis, a missing selection means "keep the full global
        axis":
          * 4-D: [out, in, all KH, all KW]
          * 2-D: [all rows, in]
          * 1-D: [out], except classifier biases, which always keep everything
          * other ranks: every axis in full
        """
        def _select(key: str, field: str, full_len: int) -> torch.Tensor:
            chosen = idxmap.get(key, {}).get(field, list(range(full_len)))
            return torch.as_tensor(chosen, dtype=torch.long)

        global_keys = set(self.global_sd_keys)
        result: Dict[str, List[torch.Tensor]] = {}
        for key in self.exit_local_sd_keys[exit_i]:
            if key not in global_keys:
                continue
            full_shape = self._full_shapes[key]  # slice global → local
            rank = len(full_shape)
            if rank == 4:
                n_out, n_in, kh, kw = full_shape
                result[key] = [
                    _select(key, "out_idx", n_out),
                    _select(key, "in_idx", n_in),
                    torch.arange(kh),
                    torch.arange(kw),
                ]
            elif rank == 2:
                n_out, n_in = full_shape
                result[key] = [torch.arange(n_out), _select(key, "in_idx", n_in)]
            elif rank == 1:
                (n_ch,) = full_shape
                is_head_bias = key.endswith(".fc.bias") or key.endswith("linear.bias")
                result[key] = [torch.arange(n_ch)] if is_head_bias else [_select(key, "out_idx", n_ch)]
            else:
                result[key] = [torch.arange(dim) for dim in full_shape]
        return result

    def _prepare_exit_indices_for_round(self, exit_i: int) -> None:
        """Refresh ``self.param_idxs[exit_i]`` from the exit's SNIP index map.

        Does nothing when SNIP is off or not selected for ``exit_i``, or when the
        exit has no index map. Index maps are cached in ``self._exit_index_map``
        and built only on a cache miss.
        """
        if not getattr(self, "_snip_any", False):
            return
        if exit_i >= len(self.pruning_mode) or str(self.pruning_mode[exit_i]).lower() != "snip":
            return

        if exit_i not in self._exit_index_map:
            fresh = self._build_index_map_for_exit(exit_i)
            if not fresh:
                return
            self._exit_index_map[exit_i] = fresh

        self.param_idxs[exit_i] = self._param_idxs_from_idxmap(exit_i, self._exit_index_map[exit_i])

    def _masks_for(self, lid: int, width_scale: float):
        """Return ``{full_key: torch.bool mask}`` for filling FULL global tensors from a prefix update.

        Results are cached per ``(lid, round(width_scale, 6))``. When
        ``width_scale`` equals the configured width of exit ``lid``, the masks
        precomputed in ``__init__`` (``self._level_idx[lid]``) are reused.
        Otherwise masks are built from ``_keywise_mask`` for every full key.
        A key whose mask cannot be built falls back to keep-all.
        """
        s = float(width_scale)
        cache_key = (int(lid), round(s, 6))
        if cache_key in self._mask_cache:
            return self._mask_cache[cache_key]

        # Fast path: the configured width for this exit (any lookup failure falls through).
        try:
            if abs(s - float(self.width_scaling[lid])) < 1e-8:
                precomputed = self._level_idx[lid]
                self._mask_cache[cache_key] = precomputed
                return precomputed
        except Exception:
            pass

        def _mask_or_keep_all(key):
            full_shape = self._full_shapes[key]
            try:
                return _keywise_mask(key, full_shape, s)
            except Exception:
                return torch.ones(full_shape, dtype=torch.bool)

        built = {key: _mask_or_keep_all(key) for key in self._all_full_keys}
        self._mask_cache[cache_key] = built
        return built

    def configure_fit(self, rnd, parameters: Parameters, client_manager):
        """Sample clients via the parent strategy and tailor each FitIns to the client's exit.

        SNIP: when enabled (``_central_snip_enabled`` and some exit mode is
        ``"snip"``), per-exit masks are computed on central VAL. This happens on
        the first call (empty cache) and every ``_snip_refresh_every`` rounds.
        Derived index maps are then rebuilt so ``param_idxs`` follow the new masks.

        Note: the ``parameters`` argument is ignored; the parent is always given
        ``self._last_global_params``.

        Exit resolution per client: ``clients_exit[cid]`` if recorded. Otherwise
        the last exit when ``app.args.mode == "maximum"``, else ``int(cid) % no_of_exits``.

        Each client's config gets:
          * ``keyset_kind="train"`` and ``send_full_local_state=False``;
          * ``keys_prog`` (its exit's trainable keys) and ``lid_hint`` (its exit);
          * for SNIP exits, ``snip_mask`` plus BN-recalibration hints.
        Its parameters are replaced by ``get_personalized_exit_weights(exit_i, ...)``.

        Returns:
            A list of ``(client, fitins)`` pairs, or the parent's falsy result unchanged.
        """
        if self._central_snip_enabled and self._snip_any:
            need_cold = not getattr(self, "_snip_mask_cache", None)
            need_refresh = (getattr(self, "_snip_mask_cache", None) is not None
                            and self._snip_refresh_every > 0
                            and rnd % self._snip_refresh_every == 0)
            if need_cold or need_refresh:
                msg = "cold start" if need_cold else f"periodic refresh @ round {rnd}"
                print(f"[SERVER][SNIP] Computing per-exit masks on central VAL ({msg})…")
                self._snip_masks_by_exit = self._compute_central_masks_for_exits()
                self._snip_mask_cache = copy.deepcopy(self._snip_masks_by_exit)
                # Drop index maps derived from the previous masks. Otherwise
                # _prepare_exit_indices_for_round reuses them and the refresh
                # never reaches param_idxs.
                self._exit_index_map = {}
                for exit_i in range(self.no_of_exits):
                    self._prepare_exit_indices_for_round(exit_i)

        base = super().configure_fit(rnd, self._last_global_params, client_manager)
        if not base:
            return base

        self._round_index_map = {}  # not populated by this strategy; reset each round
        patched = []

        def _client_lid(cid: str) -> int:
            try:
                app_mode = getattr(self.ckp.config.app.args, "mode", "multi_tier")
            except Exception:
                app_mode = "multi_tier"
            if str(app_mode).lower() == "maximum":
                return max(0, self.no_of_exits - 1)
            return int(cid) % max(1, self.no_of_exits)

        if not hasattr(self, "clients_exit") or not isinstance(self.clients_exit, dict):
            self.clients_exit = {}

        for client, fitins in base:
            # Only fall back to the heuristic when no exit is recorded. Evaluating it
            # eagerly would crash on non-integer cids even when an exit is known.
            if client.cid in self.clients_exit:
                exit_i = self.clients_exit[client.cid]
            else:
                exit_i = _client_lid(client.cid)

            # pure ScaleFL → trainables only; do NOT request BN/full state
            c = dict(fitins.config or {})
            c.update({
                "keyset_kind": "train",
                "send_full_local_state": False,
            })

            # If SNIP is enabled for this exit, push mask + one-time BN recal on VAL
            snip = self._get_mask_for_exit(exit_i)  # {} if non-SNIP or keep_ratio≈1
            if snip:
                c["snip_mask"] = snip
                c["need_bn_recal"] = True
                c["bn_calib_source"] = "val"
                c["bn_calib_batches"] = 200

            local_keys    = self.exit_local_sd_keys[exit_i]
            local_weights = self.get_personalized_exit_weights(exit_i, self._last_global_params)
            fitins.parameters = weights_to_parameters(local_weights)
            c["keys_prog"] = local_keys
            c["lid_hint"] = exit_i
            fitins.config = c

            patched.append((client, fitins))

        return patched

    def _score_keyset(self, keys: List[str], local_list: List[NDArray]) -> Tuple[int, int, int]:
        """Score how well ``keys`` describes a received weight list.

        Returns ``(len_match, exact_shape_hits, rank_hits)``:
          * ``len_match``: 1 if both lists have the same length, else 0;
          * over the first ``min(len(keys), len(local_list), 24)`` positions,
            ``exact_shape_hits`` counts arrays whose shape equals the full
            global shape of the key, and ``rank_hits`` counts arrays with the
            same rank but a different shape.
        """
        len_match = int(len(keys) == len(local_list))
        exact = 0
        rank = 0
        for key, tensor in zip(keys[:24], local_list[:24]):  # check first N only
            expected = self._full_shapes[key]
            got = np.asarray(tensor).shape
            if got == expected:
                exact += 1
            elif len(got) == len(expected):
                rank += 1
        return (len_match, exact, rank)

    def _choose_keyset(
        self,
        n: int,
        lid_default: int,
        local_list: List[NDArray],
        key_sig: Optional[str],
        key_kind: Optional[str],
    ) -> Tuple[int, List[str], str]:
        """
        Decide which key order to use for a client's payload.

        NEW: If the wire length `n` matches the server-hinted exit `lid_default`
        (either FULL or TRAIN), we *always* honor that hint. This prevents the
        resolver from collapsing to exit 0 when multiple exits have identical
        lengths/shapes (co-located exits, e.g. all -1).

        Otherwise, fall back to the robust heuristic.
        """
        # 0) Prefer the hinted exit when length matches its FULL/TRAIN set
        ef = self.exit_full_sd_keys.get(lid_default, [])
        et = self.exit_local_sd_keys.get(lid_default, [])
        if len(ef) == n:
            return lid_default, ef, "hinted:exit_full(len ok)"
        if len(et) == n:
            return lid_default, et, "hinted:exit_train(len ok)"

        # 1) If client told us the kind ('full'/'train'), use that at hinted lid when len matches
        if key_kind in ("full", "train"):
            src = self.exit_full_sd_keys if key_kind == "full" else self.exit_local_sd_keys
            ks = src.get(lid_default)
            if ks is not None and len(ks) == n:
                return lid_default, ks, f"kind:{key_kind}@lid (len ok)"
            # any exit with that kind and exact length
            for lvl, ks in src.items():
                if len(ks) == n:
                    return lvl, ks, f"kind:{key_kind} (len match)"

        # 2) Any exit (either kind) with exact length — pick best shape score
        def _score(keys: List[str]) -> Tuple[int, int, int]:
            # (len_match, exact_shape_hits, rank_hits) over first N keys
            len_match = int(len(keys) == n)
            exact = 0
            rank = 0
            L = min(len(keys), n, 24)
            for i in range(L):
                exp = self._full_shapes[keys[i]]
                got = np.asarray(local_list[i]).shape
                if got == exp:
                    exact += 1
                elif len(got) == len(exp):
                    rank += 1
            return (len_match, exact, rank)

        candidates: List[Tuple[int, str, List[str], Tuple[int, int, int]]] = []
        for lvl, ks in self.exit_full_sd_keys.items():
            if len(ks) == n:
                candidates.append((lvl, "exit_full", ks, _score(ks)))
        for lvl, ks in self.exit_local_sd_keys.items():
            if len(ks) == n:
                candidates.append((lvl, "exit_train", ks, _score(ks)))
        if candidates:
            # prefer the hinted lid on ties
            best = max(candidates, key=lambda t: (t[3], -int(t[0] == lid_default)))
            lvl, kind, ks, sc = best
            return lvl, ks, f"{kind} (best len+shape score={sc})"

        # 3) Signature match **only** if lengths match
        if key_sig and key_sig in self._sig2keys:
            lid_s, kind_s, keys_s = self._sig2keys[key_sig]
            if len(keys_s) == n:
                lid_final = lid_s if lid_s >= 0 else lid_default
                return lid_final, keys_s, f"sig:{kind_s} (len ok)"

        # 4) Fallback to global sets (pick best score)
        fallback: List[Tuple[int, str, List[str], Tuple[int, int, int]]] = []
        fallback.append((lid_default, "global_train", self.global_sd_keys, _score(self.global_sd_keys)))
        fallback.append((lid_default, "full_all", self._all_full_keys, _score(self._all_full_keys)))
        lvl, kind, ks, sc = max(fallback, key=lambda t: t[3])
        return lvl, ks, f"{kind} (fallback, score={sc})"

    def aggregate_fit(
        self,
        rnd: int,
        results: List[Tuple[ClientProxy, FitRes]],
        failures: List[BaseException],
        current_parameters: Parameters,
        server=None,
    ) -> Tuple[Optional[Parameters], Dict[str, Scalar]]:
        """
        Aggregate client updates into the full global model (ScaleFL style).

        Each client's payload is mapped back to global key names via
        ``self._choose_keyset``. Every non-BN tensor is scattered into the
        positions selected by the client's level mask (``self._masks_for``) and
        averaged, weighted by ``num_examples``. Positions no client wrote keep
        their previous global value. BN buffers (running_mean / running_var /
        num_batches_tracked) are never aggregated and are carried over from the
        previous global state.

        Args:
            rnd: current round number (not used here).
            results: ``(client, FitRes)`` pairs from successful clients.
            failures: exceptions from failed clients.
            current_parameters: not used; the full global is read from
                ``self._ensure_full_global()``.
            server: not used; accepted for interface compatibility.

        Returns:
            ``(new_params, metrics)``: ``new_params`` is the full global ordered
            by ``self._all_full_keys`` (also stored in
            ``self._last_global_params``); ``metrics`` holds the mean of every
            numeric client metric under ``mean_<name>``. Returns ``(None, {})``
            when there are no results, or when failures occurred and
            ``self.accept_failures`` is false.
        """
        if not results:
            return None, {}
        if not self.accept_failures and failures:
            return None, {}

        log = logging.getLogger(__name__)

        def _is_bn_buffer(name: str) -> bool:
            return name.endswith(("running_mean", "running_var", "num_batches_tracked"))

        def _fallback_lid(cid: str) -> int:
            # Mirror the default exit assignment used when fit instructions are
            # built for clients missing from `clients_exit` (last exit in
            # "maximum" mode, else int(cid) % no_of_exits). Defaulting to 0
            # instead would mislead _choose_keyset's hint for co-located exits.
            try:
                app_mode = getattr(self.ckp.config.app.args, "mode", "multi_tier")
            except Exception:
                app_mode = "multi_tier"
            if str(app_mode).lower() == "maximum":
                return max(0, self.no_of_exits - 1)
            try:
                return int(cid) % max(1, self.no_of_exits)
            except (TypeError, ValueError):
                return 0

        clients_exit = getattr(self, "clients_exit", None)
        if not isinstance(clients_exit, dict):
            clients_exit = {}

        # pull current full global (we will update only *parameters*, keep BN as-is)
        full_params = self._ensure_full_global()
        full_list = parameters_to_weights(full_params)
        full_sd = {k: torch.from_numpy(v).cpu() for k, v in zip(self._all_full_keys, full_list)}

        # Non-float tensors accumulate in float32 so weighted sums are not truncated.
        acc = {
            k: torch.zeros_like(v, dtype=(v.dtype if v.is_floating_point() else torch.float32))
            for k, v in full_sd.items()
        }
        count = {k: torch.zeros_like(v, dtype=torch.float32) for k, v in full_sd.items()}

        used_prefix = 0

        for client, fit_res in results:
            local_list = parameters_to_weights(fit_res.parameters)
            n = len(local_list)
            metrics = fit_res.metrics or {}

            cid = client.cid
            lid_default = clients_exit[cid] if cid in clients_exit else _fallback_lid(cid)
            lid, local_keys, _ = self._choose_keyset(
                n, lid_default, local_list, metrics.get("key_sig"), metrics.get("keyset_kind")
            )

            w_i = float(getattr(fit_res, "num_examples", 1.0))
            # Only index width_scaling when the client did not report its own scale.
            if "width_scale" in metrics:
                width_scale = float(metrics["width_scale"])
            else:
                width_scale = float(self.width_scaling[lid])
            level_mask = self._masks_for(lid, width_scale)

            local_sd = {k: torch.from_numpy(v).cpu() for k, v in zip(local_keys, local_list)}

            for k, l in local_sd.items():
                # ScaleFL baseline: never aggregate BN buffers
                if _is_bn_buffer(k) or k not in full_sd:
                    continue
                m = level_mask.get(k, None)
                if m is None:
                    continue

                g_full = full_sd[k]
                if m.numel() != g_full.numel():
                    # The mask cannot be laid over this global tensor (the old
                    # code tried to reshape it here, which always raised).
                    log.warning(
                        "aggregate_fit: mask for %s has %d elements, global has %d; skipping",
                        k, m.numel(), g_full.numel(),
                    )
                    continue

                m_flat, l_flat = m.reshape(-1), l.reshape(-1)
                # Only scatter when the client tensor fills exactly the masked slots.
                if int(m_flat.sum().item()) == int(l_flat.numel()):
                    acc[k].view(-1)[m_flat] += (l_flat * w_i).to(acc[k].dtype)
                    count[k].view(-1)[m_flat] += w_i
                    used_prefix += 1
        print(f"[AGG] used_prefix={used_prefix} / keys={len(self._all_full_keys)}")

        # reduce (only where we wrote)
        for k in self._all_full_keys:
            if _is_bn_buffer(k):
                continue  # keep BN from previous global
            written = count[k] > 0
            if written.any():
                full_sd[k][written] = (acc[k][written] / count[k][written]).to(full_sd[k].dtype)

        new_full_list = [full_sd[k].cpu().numpy() for k in self._all_full_keys]
        new_params = weights_to_parameters(new_full_list)
        self._last_global_params = new_params

        # simple metric mean; non-numeric metrics (e.g. key signatures) are ignored
        train_summary: Dict[str, List[float]] = {}
        for _, fit_res in results:
            if not fit_res.metrics:
                continue
            for name, value in fit_res.metrics.items():
                try:
                    train_summary.setdefault(name, []).append(float(value))
                except (ValueError, TypeError):
                    pass
        metrics = {f"mean_{k}": float(np.mean(v)) for k, v in train_summary.items() if v}

        return new_params, metrics

    def evaluate(self, parameters, partition="test"):
        """
        Centrally evaluate the global model at every exit.

        The global wire ``parameters`` is decoded once with
        ``self._wire_to_full_sd``. For each exit, the keys in
        ``self.exit_full_sd_keys[exit_i]`` (params + BN buffers for that exit)
        are sliced out and passed to ``self.eval_fn`` together with the exit
        index, the resolved block index, the key list and the exit's width
        scale.

        Args:
            parameters: global parameters in wire form.
            partition: split name used in metric keys (default ``"test"``).

        Returns:
            ``None`` if ``self.eval_fn`` is ``None``; otherwise
            ``(mean_loss, logs)`` where ``logs`` holds per-exit
            ``centralized_<partition>_exit<i>_loss`` / ``_acc`` entries, any
            metrics returned by ``eval_fn``, and the ``..._exit_all_loss`` /
            ``..._exit_all_acc`` means. ``mean_loss`` is 0.0 if no exit
            produced a result.
        """
        if self.eval_fn is None:
            return None

        probed_last_blk = None  # filled lazily; the probe model is built at most once per call

        def _probe_last_block() -> int:
            """Build a full-depth model probe and return its last block index (0 if unknown)."""
            arch_fn = get_func_from_config(self.net_config)
            probe_args = copy.deepcopy(self.net_config.args)
            probe_args.pop("depth", None)
            probe = arch_fn(device="cpu", **probe_args)
            # Best-effort introspection: try `layers` first, then `blks_to_exit`.
            try:
                full_depth = sum(len(s) for s in getattr(probe, "layers", []))
                if full_depth and full_depth > 0:
                    return full_depth - 1
            except Exception:
                pass
            try:
                vals = [int(x) for x in getattr(probe, "blks_to_exit", []) if int(x) >= 0]
                if vals:
                    return max(vals)
            except Exception:
                pass
            return 0

        def _resolve_blk_to_exit(blk_val: int) -> int:
            """Return blk_val if >= 0, else the last block index of a full-depth probe."""
            nonlocal probed_last_blk
            if int(blk_val) >= 0:
                return int(blk_val)
            if probed_last_blk is None:
                probed_last_blk = _probe_last_block()
            return probed_last_blk

        def _width_scale_for(exit_i: int) -> float:
            """First readable per-exit width from known attributes; config value if that is 1.0."""
            scale = 1.0
            for attr in ("width_scaling", "tier_widths", "tier_width_scaling", "exit_width_scaling"):
                widths = getattr(self, attr, None)
                if widths is None:
                    continue
                try:
                    scale = float(widths[exit_i])
                except Exception:
                    continue
                break
            if scale == 1.0:
                try:
                    scale = float(self.ckp.config.app.args.width_scaling[exit_i])
                except Exception:
                    scale = 1.0
            return scale

        # Decode the global wire once; it does not depend on the exit.
        # NOTE(review): arrays are now shared across exits — assumes eval_fn
        # does not mutate them in place.
        full_sd = self._wire_to_full_sd(parameters, where="evaluate(full)")

        logs, losses, accs = {}, [], []

        for exit_i in range(self.no_of_exits):
            keys_full = self.exit_full_sd_keys[exit_i]  # params + BN buffers for this exit (progressive)
            blk_to_exit = _resolve_blk_to_exit(int(self.blks_to_exit[exit_i]))
            local_w_full = [full_sd[k] for k in keys_full]
            width_scale = _width_scale_for(exit_i)

            # Evaluate with FULL payload (the App now detects 'full' vs 'train' by length)
            res = self.eval_fn(local_w_full, partition, exit_i, blk_to_exit, keys_full, width_scale)
            if res is None:
                continue

            loss_i, m = res
            if isinstance(m, dict):
                logs.update(m)

            lk = f"centralized_{partition}_exit{exit_i}_loss"
            ak = f"centralized_{partition}_exit{exit_i}_acc"
            # eval_fn's own loss metric (if any) wins over the returned loss value
            logs.setdefault(lk, float(loss_i))
            losses.append(float(logs[lk]))
            if ak in logs:
                accs.append(float(logs[ak]))

        if losses:
            logs[f"centralized_{partition}_exit_all_loss"] = float(np.mean(losses))
        if accs:
            logs[f"centralized_{partition}_exit_all_acc"] = float(np.mean(accs))

        return logs.get(f"centralized_{partition}_exit_all_loss", 0.0), logs

    def get_personalized_exit_weights(self, exit_i: int, parameters: Parameters):
        """
        Build the payload for exit `exit_i` that MATCHES the client model.

        - If the payload is already exit-sized (same length as
          ``self.exit_local_sd_keys[exit_i]``) it is returned unchanged.
        - Otherwise it must be either the full global (``self._all_full_keys``)
          or the trainable global (``self.global_sd_keys``); it is then pruned:
          * with SNIP indices (packed/contiguous, ``self.param_idxs``) when the
            exit uses SNIP and has an index map, so the client's narrow tensors
            align with SNIP;
          * with prefix/width indices (``self._prefix_param_idxs``) otherwise.
          Keys without an index entry are passed through unpruned.

        Returns:
            List of arrays ordered by ``self.exit_local_sd_keys[exit_i]``.

        Raises:
            ValueError: if the payload length matches none of the known layouts.
        """
        arrs = parameters_to_weights(parameters)
        keys_exit = self.exit_local_sd_keys[exit_i]
        n = len(arrs)

        # passthrough if it's already exit-sized
        if n == len(keys_exit):
            return arrs

        # map incoming weights to a dict we can prune from
        if n == len(self._all_full_keys):
            src_map = dict(zip(self._all_full_keys, arrs))
        elif n == len(self.global_sd_keys):
            src_map = dict(zip(self.global_sd_keys, arrs))
        else:
            raise ValueError(
                f"get_personalized_exit_weights: unexpected length {n} "
                f"(exit {exit_i} expects {len(keys_exit)} or full {len(self._all_full_keys)} "
                f"or trainable {len(self.global_sd_keys)})"
            )

        # Decide which index map to use
        use_snip = (
            self._snip_any
            and exit_i < len(self.pruning_mode)
            and str(self.pruning_mode[exit_i]).lower() == "snip"
            and exit_i in self._exit_index_map
        )

        # NOTE(review): SNIP availability is checked in self._exit_index_map but the
        # indices are read from self.param_idxs — confirm both are populated together.
        per_exit_idxs = self.param_idxs[exit_i] if use_snip else self._prefix_param_idxs[exit_i]

        # Set membership keeps the filter linear in the number of keys.
        wanted = set(keys_exit)
        idx_map = {k: v for k, v in per_exit_idxs.items() if k in wanted}

        pruned = prune(src_map, idx_map)  # narrow tensors that fit the client net
        # Fall back to the unpruned source only when prune produced no entry
        # (avoids eagerly evaluating src_map[k] for keys prune did handle).
        return [pruned[k] if k in pruned else src_map[k] for k in keys_exit]
