import torch
import numpy as np
import logging
from typing import Any, Dict, Iterable, List, Sequence, Optional
from flwr.common import weights_to_parameters

try:
    from sklearn.cluster import KMeans  # noqa: F401
except Exception:  # pragma: no cover
    KMeans = None  # scikit‑learn optional

from src.server.strategies.valuations import ClientValuation

logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)

def _summarize_delta_reason(cid, selector, local_list, current_weights, mapped_upd):
        """Return a short human-readable reason for a zero delta."""
        if selector is None:
            return "no-selector"
        if len(selector) != len(local_list):
            return f"selector/payload-len-mismatch sel={len(selector)} pay={len(local_list)}"
        sel = list(selector)
        nones = sum(1 for t in mapped_upd if t is None)
        nnz   = 0
        maxabs= 0.0
        for j, gidx in enumerate(sel):
            if 0 <= gidx < len(current_weights) and mapped_upd[gidx] is not None:
                a = mapped_upd[gidx].abs()
                nnz += int(a.sum().item() > 0.0)
                maxabs = max(maxabs, float(a.max().item()))
        if nnz == 0:
            return f"all-zero-on-selected (sel={len(sel)} nones={nones} maxabs={maxabs:.3e})"
        return f"nonzero (sel={len(sel)} nnz={nnz} maxabs={maxabs:.3e})"

###############################################################################
# Owen value with two–stage computation                                        #
#   1) **Inter‑group Shapley (coalitional)** – treat each lid (group) as a     #
#      single player and approximate its Shapley value over Monte‑Carlo        #
#      permutations of the *groups* (``n_perm``).                              #
#   2) **Intra‑group allocation** – split the coalition payoff among members   #
#      using the *normalised intra‑group similarity* (defaults to cosine).     #
#                                                                              #
#   The final individual payoff = group‑level Shapley * share_in_group.        #
#                                                                              #
#   After valuation we balance‑recluster the full fleet into ``num_groups``    #
#   equal‑sized buckets (descending by payoff).                                #
###############################################################################

class OwenMC(ClientValuation):
    def __init__(
        self,
        similarity: str = "cosine",
        n_perm: int = 20,
        num_groups: int = 4,
        jitter: float = 1e-6,
        # intra-group knobs
        alpha: float = 0.30,
        ema_decay: float = 0.55,
        regroup_every: int = 3,
        beta0: float = 0.90,
        beta1: float = 1.40,
        temp0: float = 0.55,
        temp1: float = 0.35,
        gamma: float = 0.40,
        clip_min: float = 1e-3,
        lam: float = 0.10,
        sim_whiten: bool = True,
        log_hparams: bool = True,
        # stability/capacity knobs
        warmup_rounds: int = 0,
        min_stay_rounds: int = 5,
        lam_schedule: str = "linear",
        gamma_schedule: str = "const",
        # NEW: persistence/normalization knobs
        neutral_value: float = 1.0,     # baseline value for never-seen clients
        normalize_round: bool = True,   # rescale this round’s slice so mean≈1
        idle_ema: float = 0.0,          # 0=off; e.g., 0.95 gently drifts idle clients to neutral
        strict_equal_groups: bool = True,   # keep groups equal-sized each regroup
        map_best_to_deepest: bool = True,   # lid=0 (best) → deepest exit
        aggressive_mode: bool = False,
        intra_mix_eps: float = 0.05,
        use_target_in_shapley: bool = True,
        *args, **kwargs
    ) -> None:
        """Store the valuation hyper-parameters after coercing them to plain types.

        Extra positional/keyword arguments are forwarded to the base class.
        ``n_perm``, ``num_groups`` and ``regroup_every`` are clamped to at least
        1; ``warmup_rounds`` and ``min_stay_rounds`` to at least 0.

        Raises:
            ValueError: if ``similarity`` (case-insensitive) is not one of
                "cosine", "euclidean" or "pearson".
        """
        super().__init__(*args, **kwargs)

        # Similarity metric: normalise the name first, then validate it.
        metric = similarity.lower()
        self.similarity = metric
        if metric not in ("cosine", "euclidean", "pearson"):
            raise ValueError(f"Unknown similarity metric: {similarity}")

        # Counts, clamped to their minimum.
        self.n_perm = int(max(1, n_perm))
        self.num_groups = int(max(1, num_groups))
        self.regroup_every = int(max(1, regroup_every))
        self.warmup_rounds = int(max(0, warmup_rounds))
        self.min_stay_rounds = int(max(0, min_stay_rounds))

        # Real-valued knobs.
        self.jitter = float(jitter)
        self.alpha = float(alpha)
        self.ema_decay = float(ema_decay)
        self.beta0 = float(beta0)
        self.beta1 = float(beta1)
        self.temp0 = float(temp0)
        self.temp1 = float(temp1)
        self.gamma = float(gamma)
        self.clip_min = float(clip_min)
        self.lam = float(lam)
        self.neutral_value = float(neutral_value)
        self.idle_ema = float(idle_ema)
        self.intra_mix_eps = float(intra_mix_eps)

        # Schedule names are stored lower-cased.
        self.lam_schedule = str(lam_schedule).lower()
        self.gamma_schedule = str(gamma_schedule).lower()

        # Switches.
        self.aggressive_mode = bool(aggressive_mode)
        self.sim_whiten = bool(sim_whiten)
        self.normalize_round = bool(normalize_round)
        self.strict_equal_groups = bool(strict_equal_groups)
        self.map_best_to_deepest = bool(map_best_to_deepest)
        self.use_target_in_shapley = bool(use_target_in_shapley)

        # Hyper-parameter logging bookkeeping.
        self._log_hparams = bool(log_hparams)
        self._printed_hparams = False

    @staticmethod
    def _cos(a: torch.Tensor, b: torch.Tensor) -> float:
        return torch.nn.functional.cosine_similarity(a.flatten(), b.flatten(), dim=0).item()

    @staticmethod
    def _euc(a: torch.Tensor, b: torch.Tensor) -> float:  # higher == closer
        return -torch.norm(a.flatten() - b.flatten()).item()

    @staticmethod
    def _pear(a: torch.Tensor, b: torch.Tensor) -> float:
        a, b = a.flatten(), b.flatten()
        if a.std() == 0 or b.std() == 0:
            return 0.0
        return float(np.corrcoef(a.cpu().numpy(), b.cpu().numpy())[0, 1])

    # Dispatch table: similarity name -> plain function. Inside the class body
    # these names are bound to staticmethod objects, so `.__func__` is used to
    # store the underlying callables. Looked up by `_pairwise_sim`.
    _SIM_FN = {"cosine": _cos.__func__, "euclidean": _euc.__func__, "pearson": _pear.__func__}

    @staticmethod
    def _as_proxies(it: Iterable[Any]) -> List[Any]:
        it = list(it)
        return it if it and hasattr(it[0], "cid") else []

    # --- Memory / defaults -------------------------------------------------
    def _ensure_memory_defaults(self):
        # Lazy init so you don't have to modify __init__
        if not hasattr(self, "mem"):
            self.mem = {}  # cid -> {"flat": Tensor(cpu), "n": int, "lid": int, "r": int}
        # Defaults; override via class attrs if you want (set before first evaluate)
        for name, val in [
            ("stale_half_life", 20),   # rounds: 50% decay every 20 rounds
            ("mem_max_age",     100),  # ignore older than this
            ("neutral_value",   1.0),  # idle drift target
            ("idle_ema",        0.10), # how fast inactive clients drift to neutral
            ("normalize_round", True), # normalize slice contribs each round
            ("warmup_rounds",   20),   # regroup warmup (if attr existed already, it's kept)
            ("min_stay_rounds", 5),
        ]:
            if not hasattr(self, name):
                setattr(self, name, val)

    def _stale_w(self, age: int) -> float:
        if age < 0:
            return 0.0
        hl = max(1e-6, float(getattr(self, "stale_half_life", 20)))
        return float(2.0 ** (-(float(age) / hl)))

    def _update_memory(self, slice_proxies, flat_deltas_raw, client_samples, client_ids, rnd: int) -> None:
        """Record the latest raw delta of every proxy in the current slice.

        For each proxy the entry ``self.mem[cid]`` is overwritten with the
        detached CPU copy of its delta (passed through ``self._nan_guard``), its
        sample count (at least 1), its ``lid`` (0 if absent), the round number
        and the element count of the first delta in the slice (``"sig"``).
        A proxy whose cid is not in ``client_ids`` raises ``KeyError``.
        """
        self._ensure_memory_defaults()

        signature = 0
        if flat_deltas_raw:
            signature = int(flat_deltas_raw[0].numel())

        position = {cid: pos for pos, cid in enumerate(client_ids)}
        for proxy in slice_proxies:
            pos = position[proxy.cid]
            vec = self._nan_guard(flat_deltas_raw[pos].detach().cpu(), tag=f"mem[{proxy.cid}]")
            samples = int(max(1, client_samples[pos]))
            if self._vec_norm(vec) == 0.0:
                logger.info(f"[OwenMC] storing ZERO vec for cid={proxy.cid} (round={rnd}, n={samples})")
            self.mem[proxy.cid] = dict(
                flat=vec,
                n=samples,
                lid=int(getattr(proxy, "lid", 0)),
                r=int(rnd),
                sig=signature,
            )

    def _count_mem_cov_by_lid(self, proxies_all: List[Any]) -> Dict[int, int]:
        """How many clients per current lid have a memory vector."""
        cov: Dict[int, int] = {}
        for p in proxies_all:
            lid = int(getattr(p, "lid", 0))
            if p.cid in self.mem:
                cov[lid] = cov.get(lid, 0) + 1
            else:
                cov.setdefault(lid, 0)
        return cov

    def _mem_stats_by_lid(self, proxies_all: List[Any], D: int) -> Dict[int, dict]:
        stats: Dict[int, dict] = {}
        for p in proxies_all:
            lid = int(getattr(p, "lid", 0))
            rec = self.mem.get(p.cid, None)
            if lid not in stats:
                stats[lid] = {"cnt":0, "zero":0, "naninf":0, "age_sum":0, "mass_abs":0.0, "mass_signed":0.0}
            if rec is None:
                continue
            v = rec.get("flat", None)
            if not isinstance(v, torch.Tensor):
                continue
            vec = v.float().view(-1)
            if vec.numel() != D and D > 0:
                # still usable (we pad/trunc), but flag it
                pass
            nrm = float(vec.norm(p=2).item())
            naninf = int((torch.isnan(vec) | torch.isinf(vec)).any().item())
            zero = 1 if nrm == 0.0 else 0
            age = 0  # exact age filled in evaluate where we know 'rnd'
            stats[lid]["cnt"] += 1
            stats[lid]["zero"] += zero
            stats[lid]["naninf"] += naninf
            # "mass" proxies: absolute and signed projection onto unit-1 vector (very rough)
            stats[lid]["mass_abs"] += nrm
            stats[lid]["mass_signed"] += nrm  # signed needs a direction; we just track same for now
        return stats
    
    def _rep_audit(
        self,
        rnd: int,
        rep_by_lid: Dict[int, torch.Tensor],
        proxies_all: List[Any],
        D: int,
        stale_w_fn,
        mem: dict,
        lam_eff: float,
        stale_half_life: float,
    ) -> None:
        """Log, per lid, remembered update mass next to the representative's norm.

        The "mass" of a lid is the sum over its members with a tensor in ``mem``
        of ``stale_w_fn(age) * n_i * ||v_i||``. A warning is emitted for every
        lid that has positive mass but a representative of zero norm. This is a
        best-effort diagnostic: failures while formatting or logging are
        reported at debug level instead of propagating. ``D`` is not used.
        """
        # per-lid sums of (w_stale * n_i * ||v||)
        accum: Dict[int, float] = {}
        for p in proxies_all:
            rec = mem.get(p.cid, None)
            if rec is None:
                continue
            age = int(rnd) - int(rec.get("r", rnd))
            w = stale_w_fn(age)
            n_i = float(max(1, int(rec.get("n", 1))))
            # Prefer the proxy's current lid; fall back to the remembered one.
            lid = int(getattr(p, "lid", rec.get("lid", 0)))
            v = rec.get("flat", None)
            if not isinstance(v, torch.Tensor):
                continue
            nrm = float(v.float().view(-1).norm().item())
            accum[lid] = accum.get(lid, 0.0) + (w * n_i * nrm)

        try:
            rep_norms = {lid: float(rep.norm().item()) for lid, rep in rep_by_lid.items()}
            logger.info(f"[OwenMC][rnd {rnd}] rep_audit mass_abs≈{accum} ; rep_norms={rep_norms} "
                        f"(lam_eff={lam_eff:.3f}, hl={stale_half_life})")
            # flag lids with mem mass but zero rep
            for lid, mass in accum.items():
                if mass > 0 and rep_norms.get(lid, 0.0) == 0.0:
                    logger.warning(f"[OwenMC][rnd {rnd}] ALERT: lid={lid} has mem mass {mass:.1f} but rep norm=0")
        except Exception:
            # Diagnostics must never break a training round, but a swallowed
            # failure should still be discoverable when debugging.
            logger.debug("[OwenMC][rnd %s] rep_audit logging failed", rnd, exc_info=True)

    def _inject_group_prior(
        self,
        group_rep: Dict[int, torch.Tensor],
        t_hat: Optional[torch.Tensor],
        proxies_all: List[Any],
        eps_scale: float = 1e-4,
    ) -> None:
        """
        If a group has members but its rep is exactly zero (no memory mass yet),
        add a tiny vector along t_hat so its Shapley isn’t forced to 0.0.
        Operates in-place on group_rep.
        """
        if t_hat is None:
            return
        if not group_rep:
            return

        # members count per lid (current lids BEFORE regroup)
        members_by_lid: Dict[int, int] = {}
        for p in proxies_all:
            lid = int(getattr(p, "lid", 0))
            members_by_lid[lid] = members_by_lid.get(lid, 0) + 1

        # scale prior to be tiny vs existing reps
        nz_norms = [v.norm().item() for v in group_rep.values() if v.numel() > 0 and v.norm().item() > 0]
        base = float(np.median(nz_norms)) if nz_norms else 1.0
        eps = float(eps_scale) * base

        D = next(iter(group_rep.values())).numel()
        if t_hat.numel() != D:
            # shape-mismatch safety: skip
            return

        for lid, rep in list(group_rep.items()):
            if members_by_lid.get(lid, 0) <= 0:
                continue  # empty group (shouldn't happen with equal sizes)
            if rep.numel() == 0 or rep.norm().item() == 0.0:
                group_rep[lid] = rep + eps * t_hat  # in-place replace

    def old_fleet_group_reps_from_memory(
        self,
        server,
        flat_deltas_raw,        # list[Tensor] RAW (unwhitened) for current slice
        client_samples,         # list[int] for current slice
        client_ids,             # list[str] for current slice
        results,                # [(ClientProxy, FitRes), ...] current slice
        rnd: int,
    ) -> dict[int, torch.Tensor]:
        """
        Build group representatives (one RAW vector per lid) from:
        - current slice (fresh, weight=1.0)
        - last-seen memory for others, with staleness decay
        IMPORTANT: returns **S U M** of (stale_weight * n_i * Δ_i) per group (NO normalization).
        """
        self._ensure_memory_defaults()
        device = flat_deltas_raw[0].device if flat_deltas_raw else torch.device("cpu")
        D = int(flat_deltas_raw[0].numel()) if flat_deltas_raw else 1
        zero = torch.zeros(D, dtype=torch.float32, device=device)

        def _fix_dim(v: torch.Tensor, D: int) -> torch.Tensor:
            v = v.float().view(-1)
            n = int(v.numel())
            if n == D:
                return v
            if n < D:
                return torch.nn.functional.pad(v, (0, D - n))
            return v[:D]

        cid2idx = {cid: i for i, cid in enumerate(client_ids)}
        slice_proxies = [p for p, _ in results]
        proxies_all = self._collect_all_proxies(server)

        rep_by_lid: dict[int, torch.Tensor] = {}

        for p in proxies_all:
            cid = p.cid
            if cid in cid2idx:
                # fresh (RAW) from this round
                i = cid2idx[cid]
                v = _fix_dim(flat_deltas_raw[i].to(device), D)
                n_i = float(max(1, int(client_samples[i])))
                w_stale = 1.0
                lid = int(getattr(p, "lid", 0))
            else:
                # memory (RAW) from previous round(s)
                rec = self.mem.get(cid, None)
                if rec is None:
                    continue
                age = int(rnd) - int(rec.get("r", rnd))
                if age > int(getattr(self, "mem_max_age", 100)):
                    continue
                v = _fix_dim(rec["flat"].to(device), D)
                n_i = float(max(1, int(rec.get("n", 1))))
                w_stale = self._stale_w(age)
                if w_stale <= 0.0:
                    continue
                lid = int(rec.get("lid", 0))

            # SUM (no normalization): ∑ (w_stale * n_i) * v
            rep_by_lid[lid] = rep_by_lid.get(lid, zero.clone()) + (w_stale * n_i) * v

        # ensure all lids in fleet exist in the dict
        if not rep_by_lid:
            return {0: zero.clone()}
        return rep_by_lid

    def _whiten_flat(self, flat: torch.Tensor, ref_numels: List[int], scales: List[float]) -> torch.Tensor:
        """Divide each layer segment by its RMS scale; keeps length identical to flat."""
        flat = flat.view(-1).float()
        D = int(sum(int(m) for m in ref_numels))
        if flat.numel() != D or not scales or len(scales) != len(ref_numels):
            return flat  # safety: no-op
        out = []
        off = 0
        for m, s in zip(ref_numels, scales):
            m = int(m)
            seg = flat[off:off+m]
            seg = seg / float(s if s and s > 1e-12 else 1.0)
            out.append(seg)
            off += m
        return torch.cat(out, dim=0)
    
    def _fleet_group_reps_from_memory(
        self,
        server,
        flat_deltas_raw,        # list[Tensor] RAW (unwhitened) for current slice
        client_samples,         # list[int] for current slice
        client_ids,             # list[str] for current slice
        results,                # [(ClientProxy, FitRes), ...] current slice
        rnd: int,
    ) -> dict[int, torch.Tensor]:
        """
        Build group representatives (one RAW vector per CURRENT lid) from:
        - current slice (fresh, weight=1.0)
        - last-seen memory for others, with staleness decay

        IMPORTANT: assign every contribution to the client's *current* lid.

        Each rep is the SUM of ``stale_weight * n_i * Δ_i`` over the group's
        members, with vectors zero-padded/truncated to the length of the first
        fresh delta (length 1 when the slice is empty). Every lid present in the
        fleet gets an entry, zero if nothing contributed.

        The sums are then smoothed with an exponential moving average against
        the reps returned by the previous call, which are kept in
        ``self._rep_ema``. The inertia is read from the optional attribute
        ``rep_ema_mu`` (default 0.70, clamped to [0, 1]). A lid without a
        previous rep of matching length is passed through unsmoothed.
        ``results`` is not used.
        """
        self._ensure_memory_defaults()
        device = flat_deltas_raw[0].device if flat_deltas_raw else torch.device("cpu")
        D = int(flat_deltas_raw[0].numel()) if flat_deltas_raw else 1
        zero = torch.zeros(D, dtype=torch.float32, device=device)
        max_age = int(getattr(self, "mem_max_age", 100))

        def _fix_dim(v: torch.Tensor) -> torch.Tensor:
            """Flatten to float32 and pad/truncate to exactly D elements."""
            v = v.float().view(-1)
            n = int(v.numel())
            if n == D:
                return v
            if n < D:
                return torch.nn.functional.pad(v, (0, D - n))
            return v[:D]

        cid2idx = {cid: i for i, cid in enumerate(client_ids)}
        proxies_all = self._collect_all_proxies(server)

        rep_by_lid: dict[int, torch.Tensor] = {}

        for p in proxies_all:
            cid = p.cid
            lid_now = int(getattr(p, "lid", 0))  # use CURRENT lid for assignment

            if cid in cid2idx:
                # fresh (RAW) from this round
                i = cid2idx[cid]
                v = _fix_dim(flat_deltas_raw[i].to(device))
                n_i = float(max(1, int(client_samples[i])))
                w_stale = 1.0
            else:
                # memory (RAW) from previous rounds
                rec = self.mem.get(cid, None)
                if rec is None:
                    continue
                age = int(rnd) - int(rec.get("r", rnd))
                if age > max_age:
                    continue
                v = _fix_dim(rec["flat"].to(device))
                n_i = float(max(1, int(rec.get("n", 1))))
                w_stale = self._stale_w(age)
                if w_stale <= 0.0:
                    continue

            # `+` allocates a new tensor, so the shared `zero` is never mutated
            # and does not need to be cloned per client.
            rep_by_lid[lid_now] = rep_by_lid.get(lid_now, zero) + (w_stale * n_i) * v

        # Ensure every CURRENT lid has an entry (even if zero)
        lids_present = {int(getattr(p, "lid", 0)) for p in proxies_all}
        for lid in lids_present:
            if lid not in rep_by_lid:
                rep_by_lid[lid] = zero.clone()

        # Optional sanity log; best effort, never allowed to break the round.
        try:
            rep_norms = {lid: float(rep.norm(p=2).item()) for lid, rep in rep_by_lid.items()}
            logger.info(f"[OwenMC][rnd {rnd}] rep_stats_fix [‖rep‖]: {rep_norms}")
        except Exception:
            logger.debug("[OwenMC][rnd %s] rep_stats logging failed", rnd, exc_info=True)

        # EMA smoothing across rounds; inertia is overridable via `rep_ema_mu`.
        mu = min(1.0, max(0.0, float(getattr(self, "rep_ema_mu", 0.70))))
        prev = getattr(self, "_rep_ema", {})
        smoothed = {}
        for lid, v in rep_by_lid.items():
            v_prev = prev.get(lid, None)
            if isinstance(v_prev, torch.Tensor) and v_prev.numel() == v.numel():
                # The stored EMA may live on the device used in an earlier round.
                smoothed[lid] = mu * v_prev.to(v.device) + (1.0 - mu) * v
            else:
                smoothed[lid] = v
        self._rep_ema = {lid: smoothed[lid].detach() for lid in smoothed}  # store for next round

        return smoothed
    
    def _group_shapley_from_rep(
        self,
        group_rep: dict[int, torch.Tensor],
        target_vec: Optional[torch.Tensor],
        lam: float = 0.0,
        seed: Optional[int] = None,
    ) -> dict[int, float]:
        """
        Group-level Shapley on SUM reps with a strictly monotone utility:
        U(S) = sum_{g in S, p_g>0} p_g + lam * sum_{g in S} ||v_g||
        where p_g = <v_g, t_hat>. We (1) use antithetic permutations to cut variance,
        (2) clamp negative marginals to 0 (numeric safety), and (3) enforce Shapley
        efficiency: sum_g φ_g == U(all).
        """
        groups = list(group_rep.keys())
        if not groups:
            return {}

        # unit t̂ if provided (RAW space)
        t = None
        if isinstance(target_vec, torch.Tensor) and target_vec.numel() > 0:
            t = target_vec.view(-1).float()
            t = t / (t.norm(p=2) + 1e-12)

        def U(coal: List[int]) -> float:
            if not coal:
                return 0.0
            proj_pos = 0.0
            if t is not None:
                for g in coal:
                    v = group_rep[g].view(-1).float()
                    p = float(torch.dot(v, t).item())
                    if p > 0.0:
                        # squash extremes: sqrt keeps ordering but reduces spikes
                        proj_pos += p ** 0.5
            mag_pos = 0.0
            for g in coal:
                mag_pos += float(group_rep[g].view(-1).float().norm(p=2).item())
            return proj_pos + float(lam) * mag_pos

        rng = np.random.default_rng(seed)
        n_perm = int(max(1, getattr(self, "n_perm", 20)))
        # antithetic: do pairs (π, reverse(π)) to reduce variance
        k = max(1, n_perm // 2)
        phi = {g: 0.0 for g in groups}

        for _ in range(k):
            order = rng.permutation(groups).tolist()
            for ord_ in (order, list(reversed(order))):
                prefix, u_prev = [], 0.0
                for g in ord_:
                    u_next = U(prefix + [g])
                    delta = u_next - u_prev
                    if delta < 0.0:
                        delta = 0.0
                    phi[g] += delta
                    prefix.append(g)
                    u_prev = u_next

        used = 2 * k
        if used < n_perm:
            order = rng.permutation(groups).tolist()
            prefix, u_prev = [], 0.0
            for g in order:
                u_next = U(prefix + [g])
                delta = u_next - u_prev
                if delta < 0.0:
                    delta = 0.0
                phi[g] += delta
                prefix.append(g)
                u_prev = u_next
            used += 1

        inv = 1.0 / float(used)
        for g in groups:
            phi[g] = max(0.0, phi[g] * inv)

        # Enforce efficiency: sum φ == U(all)
        total_phi = sum(phi.values())
        total_u = U(groups)
        if total_phi > 0 and total_u > 0:
            scale = float(total_u / total_phi)
            for g in groups:
                phi[g] *= scale

        return phi

    def old_group_shapley_from_rep(
        self,
        group_rep: dict[int, torch.Tensor],
        target_vec: Optional[torch.Tensor],
        lam: float = 0.0,
        seed: Optional[int] = None,
    ) -> dict[int, float]:
        """Group-level Shapley using precomputed SUM reps."""
        groups = list(group_rep.keys())
        if not groups:
            return {}
        phi = {g: 0.0 for g in groups}
        rng = np.random.default_rng(seed)
        lam_local = float(lam)

        def _coal_u(coal):
            if not coal:
                return 0.0
            S = torch.stack([group_rep[g] for g in coal], dim=0).sum(dim=0)

            if target_vec is not None and target_vec.numel() > 0:
                t = target_vec
                tnorm = t.norm(p=2)
                if tnorm.item() > 0:
                    t_hat = t / tnorm
                    if getattr(self, "aggressive_mode", False):
                        # AGGRESSIVE: signed alignment drives utility; anti-aligned coalitions get penalized
                        signed_proj = float(torch.dot(S, t_hat).item())               # can be negative
                        mag = float(S.norm(p=2).item())
                        return signed_proj + lam_local * mag
                    else:
                        # Conservative: keep only positive contributions to avoid harsh penalties
                        proj_pos = 0.0
                        for g in coal:
                            v = group_rep[g]
                            if v.numel() == 0:
                                continue
                            proj = float(torch.dot(v, t_hat).item())
                            if proj > 0.0:
                                proj_pos += proj
                        mag = float(S.norm(p=2).item())
                        return proj_pos + lam_local * mag

            # No target: fall back to magnitude
            return float(S.norm(p=2).item())

        n_perm = int(max(1, getattr(self, "n_perm", 20)))
        for _ in range(n_perm):
            order = rng.permutation(groups).tolist()
            prefix, u_prev = [], 0.0
            for g in order:
                u_next = _coal_u(prefix + [g])
                phi[g] += (u_next - u_prev)
                prefix.append(g); u_prev = u_next

        scale = float(n_perm)
        for g in groups:
            phi[g] /= scale
        return phi

    # --- helper: Spearman rank correlation (no SciPy) ---
    def _spearman_rank_corr(self, a: List[float], b: List[float]) -> float:
        if len(a) != len(b) or len(a) == 0:
            return float('nan')
        x = np.asarray(a, dtype=float)
        y = np.asarray(b, dtype=float)
        # rank with average ties
        rx = x.argsort().argsort().astype(float)
        ry = y.argsort().argsort().astype(float)
        # average ranks for ties
        for ranks, vals in ((rx, x), (ry, y)):
            _, inv, counts = np.unique(vals, return_inverse=True, return_counts=True)
            avg = np.bincount(inv, ranks) / counts
            ranks[:] = avg[inv]
        # Pearson on ranks
        sx = (rx - rx.mean()) / (rx.std() + 1e-12)
        sy = (ry - ry.mean()) / (ry.std() + 1e-12)
        return float(np.mean(sx * sy))

    def _flatten_pad(self, upd_list: List[torch.Tensor], ref_numels: List[int]) -> torch.Tensor:
        """
        Concatenate all layer deltas into one vector with a fixed length
        (sum(ref_numels)). For any missing layer (None) or length mismatch,
        pad/truncate so total length is identical across clients.
        """
        chunks: List[torch.Tensor] = []
        L = len(ref_numels)
        for i in range(L):
            m = int(ref_numels[i])
            if i >= len(upd_list) or upd_list[i] is None:
                chunks.append(torch.zeros(m, dtype=torch.float32))
                continue
            v = upd_list[i].flatten().float()
            n = int(v.numel())
            if n == m:
                chunks.append(v)
            elif n < m:
                pad = m - n
                chunks.append(torch.nn.functional.pad(v, (0, pad)))
            else:  # n > m (shouldn't happen if shapes checked, but be defensive)
                chunks.append(v[:m])
        return torch.cat(chunks, dim=0)

    def _collect_all_proxies(self, server) -> List[Any]:
        mgr = server.client_manager()
        proxies = self._as_proxies(mgr.all())
        if proxies:
            return proxies
        # fallback: look inside manager dicts
        for attr in mgr.__dict__.values():
            if isinstance(attr, (list, dict)):
                vals = attr.values() if isinstance(attr, dict) else attr
                proxies = self._as_proxies(vals)
                if proxies:
                    return proxies
        return []

    def _dump(self, proxies: Sequence[Any], tag: str) -> None:
        """Log the cids of ``proxies`` grouped by lid (-1 when a proxy has no lid)."""
        grp: Dict[int, List[str]] = {}
        for proxy in proxies:
            lid = getattr(proxy, "lid", -1)
            if lid not in grp:
                grp[lid] = []
            grp[lid].append(str(proxy.cid))
        logger.info(f"[OwenMC] FULL groups {tag}: {grp}")

    def _ensure_value_defaults(self, server) -> List[Any]:
        """Ensure every proxy has baseline value/ema so ranking is fair with partial participation."""
        proxies = self._collect_all_proxies(server)
        for p in proxies:
            if not hasattr(p, "value"):
                p.value = float(self.neutral_value)
            if not hasattr(p, "value_ema"):
                p.value_ema = float(self.neutral_value)
        return proxies

    def _pairwise_sim(self, updates: List[List[torch.Tensor]]) -> torch.Tensor:
        fn = self._SIM_FN[self.similarity]
        n = len(updates)
        M = torch.zeros((n, n))
        for i in range(n):
            M[i, i] = 1.0
            for j in range(i + 1, n):
                shared = [fn(a, b) for a, b in zip(updates[i], updates[j]) if a is not None and b is not None]
                val = float(np.mean(shared)) if shared else 0.0
                M[i, j] = M[j, i] = val
        return M

    def _flatten(self, upd_list: List[torch.Tensor]) -> torch.Tensor:
        vecs = [t.flatten() for t in upd_list if t is not None]
        if not vecs:
            return torch.zeros(1)
        return torch.cat(vecs).float()

    def _build_group_reps(
        self,
        lid_to_idx: Dict[int, List[int]],
        flat_deltas: List[torch.Tensor],
        client_samples: List[int],
    ) -> Dict[int, torch.Tensor]:
        """
        Build one vector per group as the **sample-weighted SUM** of member deltas.
        This matches FedAvg's additive contribution: sum_i (n_i * Δ_i).
        """
        reps: Dict[int, torch.Tensor] = {}
        D = int(flat_deltas[0].numel()) if flat_deltas else 1
        zero = torch.zeros(D, dtype=torch.float32)
        for lid, idxs in lid_to_idx.items():
            if not idxs:
                reps[lid] = zero.clone()
                continue
            s = torch.tensor([max(1, int(client_samples[i])) for i in idxs], dtype=torch.float32)
            M = torch.stack([flat_deltas[i] for i in idxs], dim=0)   # [m, D]
            g = (s[:, None] * M).sum(dim=0)                          # <-- SUM, not mean
            reps[lid] = g
        return reps

    def _coalition_utility(
        self,
        coalition: List[int],
        group_rep: Dict[int, torch.Tensor],
        target: Optional[torch.Tensor],
    ) -> float:
        if not coalition:
            return 0.0

        if target is not None:
            tnorm = target.norm(p=2)
            if tnorm.item() > 0:
                t_hat = target / tnorm
                proj_pos = 0.0
                for g in coalition:
                    v = group_rep[g]
                    if v.numel() == 0:
                        continue
                    proj = float(torch.dot(v, t_hat).item())
                    if proj > 0.0:
                        proj_pos += proj
                S = torch.stack([group_rep[g] for g in coalition], dim=0).sum(dim=0)
                mag = float(S.norm(p=2).item())
                return proj_pos + self.lam * mag

        S = torch.stack([group_rep[g] for g in coalition], dim=0).sum(dim=0)
        return float(S.norm(p=2).item())

    def _group_shapley(
        self,
        lid_to_idx: Dict[int, List[int]],
        flat_deltas: List[torch.Tensor],
        client_samples: List[int],
        target_vec: Optional[torch.Tensor],
        seed: Optional[int] = None,
    ) -> Dict[int, float]:
        groups = list(lid_to_idx.keys())
        if not groups:
            return {}
        rep = self._build_group_reps(lid_to_idx, flat_deltas, client_samples)
        phi = {g: 0.0 for g in groups}
        rng = np.random.default_rng(seed)

        for _ in range(self.n_perm):
            order = rng.permutation(groups).tolist()
            prefix, u_prev = [], 0.0
            for g in order:
                u_next = self._coalition_utility(prefix + [g], rep, target_vec)
                phi[g] += (u_next - u_prev)
                prefix.append(g); u_prev = u_next

        scale = float(self.n_perm)
        for g in groups:
            phi[g] /= scale
        return phi

    def _layer_rms(self, updates: List[List[torch.Tensor]]) -> List[float]:
        """Per-layer RMS over clients (ignores None); used to whiten deltas."""
        L = len(updates[0]) if updates else 0
        scales: List[float] = []
        eps = 1e-8
        for i in range(L):
            vals = []
            for u in updates:
                t = u[i]
                if t is not None:
                    v = t.float().flatten()
                    if v.numel() > 0:
                        vals.append(v)
            if vals:
                cat = torch.cat(vals, dim=0)
                rms = torch.sqrt((cat * cat).mean()).item()
                scales.append(rms if rms > eps else 1.0)
            else:
                scales.append(1.0)
        return scales

    def _flatten_pad_scaled(
        self,
        upd_list: List[torch.Tensor],
        ref_numels: List[int],
        scales: List[float],
    ) -> torch.Tensor:
        """
        Concatenate all layer deltas into one vector with fixed length sum(ref_numels),
        after per-layer whitening by 'scales'. Pads/truncates to keep identical length.
        """
        chunks: List[torch.Tensor] = []
        L = len(ref_numels)
        for i in range(L):
            m = int(ref_numels[i])
            s = float(scales[i]) if i < len(scales) else 1.0
            if i >= len(upd_list) or upd_list[i] is None:
                chunks.append(torch.zeros(m, dtype=torch.float32))
                continue
            v = upd_list[i].float().flatten()
            if s != 1.0:
                v = v / s
            n = int(v.numel())
            if n == m:
                chunks.append(v)
            elif n < m:
                chunks.append(torch.nn.functional.pad(v, (0, m - n)))
            else:
                chunks.append(v[:m])
        return torch.cat(chunks, dim=0)

    def _regroup(self, server, default_val: float):
        """Balanced contiguous buckets by value (desc), rename so lid=0 is best."""
        proxies = self._collect_all_proxies(server)
        if not proxies:
            logger.info("[OwenMC] No proxies found to regroup – skipping.")
            return

        try:
            exits = int(getattr(server, "no_of_exits", self.num_groups))
        except Exception:
            exits = self.num_groups
        G_cfg = max(1, int(self.num_groups))
        G = max(1, min(G_cfg, exits, len(proxies)))
        if G != G_cfg:
            logger.info(f"[OwenMC] Adjusting num_groups from {G_cfg} to {G} "
                        f"(exits={exits}, clients={len(proxies)})")

        self._dump(proxies, "BEFORE")

        # Use value -> value_ema -> neutral fallback (not 0.0)
        vals = np.array([
            float(getattr(p, "value", getattr(p, "value_ema", self.neutral_value)))
            for p in proxies
        ], dtype=float)

        if self.jitter and self.jitter > 0:
            vals = vals + np.random.uniform(-self.jitter, self.jitter, size=vals.shape)

        N = len(proxies)
        if N <= 1 or G == 1:
            for p in proxies:
                p.lid = 0
            self._dump(proxies, "AFTER")
            return

        order = np.argsort(-vals)
        labels = np.full(N, -1, dtype=int)
        base = N // G
        extra = N % G
        offset = 0
        for g in range(G):
            cnt = base + (1 if g < extra else 0)
            if cnt > 0:
                sl = order[offset:offset + cnt]
                labels[sl] = g
                offset += cnt
        labels[labels < 0] = G - 1

        means = {g: float(vals[labels == g].mean()) if np.any(labels == g) else -np.inf for g in range(G)}
        rename = {old: new for new, old in enumerate(sorted(means, key=means.get, reverse=True))}
        for i, p in enumerate(proxies):
            p.lid = int(rename[int(labels[i])])

        self._dump(proxies, "AFTER")
        # After regroup + apply mapping, refresh mem.lid so future reps are consistent
        try:
            for p in self._collect_all_proxies(server):
                rec = self.mem.get(p.cid, None)
                if rec is not None:
                    rec["lid"] = int(getattr(p, "lid", rec.get("lid", 0)))
        except Exception:
            pass

    def _nan_guard(self, t: torch.Tensor, tag: str) -> torch.Tensor:
        if t is None:
            return t
        bad = torch.isnan(t) | torch.isinf(t)
        if bad.any():
            cnt = int(bad.sum().item())
            logger.warning(f"[OwenMC] {tag}: found {cnt} NaN/Inf entries; zeroing them")
            t = t.clone()
            t[bad] = 0.0
        return t

    def _vec_norm(self, v: Optional[torch.Tensor]) -> float:
        try:
            return float(v.float().view(-1).norm(p=2).item()) if v is not None else 0.0
        except Exception:
            return 0.0

    def _estimate_target(
        self,
        flat_deltas: List[torch.Tensor],
        client_samples: List[int],
        ema_decay: float = 0.9,
    ) -> Optional[torch.Tensor]:
        if not flat_deltas:
            return None
        s = torch.tensor([max(1, int(x)) for x in client_samples], dtype=torch.float32)
        s = s / s.sum()
        T = torch.stack(flat_deltas, dim=0)      # [N,D]
        t = (s[:, None] * T).sum(dim=0)          # [D]
        n = t.norm(p=2)
        if n.item() == 0.0:
            return None
        t = t / n

        # orientation locking: keep sign consistent across rounds
        prev = getattr(self, "_tgt_ema", None)
        if isinstance(prev, torch.Tensor) and prev.numel() == t.numel():
            if float(torch.dot(prev, t).item()) < 0.0:
                t = -t

        # EMA smoothing
        if isinstance(prev, torch.Tensor) and prev.numel() == t.numel():
            t = ema_decay * prev + (1.0 - ema_decay) * t
            n = t.norm(p=2)
            if n.item() > 0:
                t = t / n
        self._tgt_ema = t
        return t
    
    def old_estimate_target(
        self,
        flat_deltas: List[torch.Tensor],
        client_samples: List[int],
        ema_decay: float = 0.9,
    ) -> Optional[torch.Tensor]:
        """
        Earlier variant of the target estimate, without orientation locking.

        Computes the sample-weighted mean of ``flat_deltas`` (sample counts
        floored at 1), normalises it to unit L2 length and, when a previous
        estimate of the same size exists in ``self._tgt_ema``, blends the two
        as ``ema_decay * prev + (1 - ema_decay) * new`` before re-normalising.
        The result is written back to ``self._tgt_ema`` and returned.
        Returns ``None`` when there are no deltas or the mean is exactly zero.
        """
        if not flat_deltas:
            return None

        share = torch.tensor([max(1, int(x)) for x in client_samples], dtype=torch.float32)
        share = share / share.sum()
        # weighted mean update: [N, 1] * [N, D] summed over clients -> [D]
        mean_delta = (share[:, None] * torch.stack(flat_deltas, dim=0)).sum(dim=0)
        mean_norm = mean_delta.norm(p=2)
        if mean_norm.item() == 0.0:
            return None
        unit = mean_delta / mean_norm

        # The EMA state lives on the instance; it is created lazily here.
        last = getattr(self, "_tgt_ema", None)
        if not (isinstance(last, torch.Tensor) and last.numel() == unit.numel()):
            self._tgt_ema = unit
            return unit

        blended = ema_decay * last + (1.0 - ema_decay) * unit
        blended_norm = blended.norm(p=2)
        if blended_norm.item() > 0:
            blended = blended / blended_norm
        self._tgt_ema = blended
        return blended

    def _desired_exit_proportions(self, strategy, no_of_exits: int) -> List[float]:
        """
        Heuristic capacity for each exit.
        Prefer exact trainable counts if available; fallback to width^2; else uniform.
        Deepest exit is last index (no_of_exits-1).
        """
        try:
            # If strategy exposes per-exit trainable counts/list, use it
            counts = getattr(strategy, "_trainables_per_exit", None)
            if counts and len(counts) == no_of_exits:
                v = np.array([max(1.0, float(x)) for x in counts], dtype=float)
                return (v / v.sum()).tolist()
        except Exception:
            pass

        try:
            widths = list(getattr(strategy, "width_scaling", []))
            if widths and len(widths) == no_of_exits:
                w = np.array([max(1e-3, float(x)) for x in widths], dtype=float)
                v = (w * w)  # conv channels roughly scale with width^2
                return (v / v.sum()).tolist()
        except Exception:
            pass

        # uniform
        return [1.0 / float(no_of_exits)] * max(1, int(no_of_exits))

    def _get_client_selector(self, strategy, cid: str, L_full: int):
        """
        Return a list of global-trainable indices for this client (exit-specific),
        or None if the strategy doesn't expose it. Defensive and logs once.
        """
        try:
            # Preferred explicit API (you'll add this in the strategy; see section B)
            if hasattr(strategy, "get_trainable_indices_for_client"):
                idxs = strategy.get_trainable_indices_for_client(cid)
                if idxs is None:
                    return None
                idxs = [int(i) for i in idxs if 0 <= int(i) < L_full]
                return idxs if idxs else None

            # Legacy/heuristic fallbacks (if you already store per-exit maps)
            # Common names to try:
            #   - strategy._indices_per_exit : Dict[int, List[int]]
            #   - strategy.indices_per_exit  : Dict[int, List[int]]
            #   - strategy._exit_to_indices  : Dict[int, List[int]]
            #   - strategy.clients_exit      : Dict[str, int]  (we need both map + per-exit indices)
            exit_map = getattr(strategy, "clients_exit", None)
            if isinstance(exit_map, dict) and cid in exit_map:
                exit_id = int(exit_map[cid])
                for name in ("_indices_per_exit", "indices_per_exit", "_exit_to_indices"):
                    per_exit = getattr(strategy, name, None)
                    if isinstance(per_exit, dict) and exit_id in per_exit:
                        idxs = [int(i) for i in per_exit[exit_id] if 0 <= int(i) < L_full]
                        return idxs if idxs else None
        except Exception:
            pass
        # If we get here, we couldn't find it
        return None

    def _apply_regroup_to_strategy(self, strategy, proxies, round_idx: int):
        """
        Map Owen group ranks (``p.lid``) to exits and write the result onto ``strategy``.

        Side effects: sets ``strategy.clients_exit`` (cid -> exit index) and
        ``strategy._clients_exit_last_change`` (cid -> round index). Returns None.
        Does nothing when ``proxies`` is empty or ``strategy`` is None.

        Group -> exit rule (``no_of_exits`` is read from the strategy, falling back
        to ``len(strategy.blks_to_exit)`` and finally to 1):
        - ``self.map_best_to_deepest`` truthy: lid r -> ``no_of_exits - 1 - (r % no_of_exits)``,
          so lid 0 lands on the last exit index.
        - otherwise: lid r -> ``r % no_of_exits``.

        Modes:
        - ``strict_equal_groups`` (treated as True when the attribute is missing):
          every proxy gets exactly the exit of its group (no hysteresis), the
          previous map is replaced by one containing only ``proxies``, and all
          change timestamps are reset to ``round_idx``.
        - otherwise: hysteretic assignment (a client keeps its previous exit until
          ``min_stay_rounds`` rounds, default 5, have passed since its last recorded
          switch) followed by a rebalance towards per-exit quotas derived from
          ``self._desired_exit_proportions``.
        """
        if not proxies or strategy is None:
            return

        # how many exits? (attribute first, then number of exit blocks, else 1)
        try:
            no_of_exits = int(getattr(strategy, "no_of_exits", 0))
            if no_of_exits <= 0:
                no_of_exits = int(len(getattr(strategy, "blks_to_exit", []))) or 1
        except Exception:
            no_of_exits = 1

        # build group→exit mapping for the lids that actually occur
        ranks_present = sorted({int(getattr(p, "lid", 0)) for p in proxies})
        if self.map_best_to_deepest:
            # lid 0 (best) → deepest exit index
            rank_to_exit = {r: (no_of_exits - 1 - (r % no_of_exits)) for r in ranks_present}
        else:
            # lid r → exit r (round-robin if more groups than exits)
            rank_to_exit = {r: (r % no_of_exits) for r in ranks_present}

        # STRICT mode: exit membership follows group membership exactly,
        # overriding any hysteresis
        if bool(getattr(self, "strict_equal_groups", True)):
            new_map = {}
            for p in proxies:
                lid = int(getattr(p, "lid", 0))
                new_map[p.cid] = int(rank_to_exit.get(lid, 0))

            # commit: the old map is replaced, not merged
            strategy.clients_exit = new_map
            # reset change-timestamps for every proxy (prevents stale effects)
            strategy._clients_exit_last_change = {p.cid: int(round_idx) for p in proxies}

            # log per-exit counts so bucket sizes can be verified
            try:
                from collections import Counter
                cnts = Counter(new_map.values())
                logger.info(f"[OwenMC] exit assignment (strict) (exit:count) = {sorted(cnts.items())}")
            except Exception:
                pass
            return

        # ----------------------------
        # NON-STRICT (previous behavior)
        # ----------------------------
        # capacity proportions → integer quotas (floor, then hand the leftover
        # seats to the exits with the largest fractional parts)
        props = self._desired_exit_proportions(strategy, no_of_exits)[:no_of_exits]
        props = np.array(props, dtype=float)
        props = props / props.sum()
        N = len(proxies)
        quotas = np.floor(props * N).astype(int)
        remainder = N - int(quotas.sum())
        if remainder > 0:
            frac = (props * N) - np.floor(props * N)
            order_frac = np.argsort(-frac)
            for i in order_frac[:remainder]:
                quotas[i] += 1

        # order by value (EMA fallback, then neutral value), high→low
        order = sorted(
            proxies,
            key=lambda p: float(getattr(p, "value", getattr(p, "value_ema", self.neutral_value))),
            reverse=True,
        )

        # desired exit for each proxy based on group
        desired = {p.cid: int(rank_to_exit.get(int(getattr(p, "lid", 0)), 0)) for p in proxies}

        cur_map = getattr(strategy, "clients_exit", {}) or {}
        last_change = getattr(strategy, "_clients_exit_last_change", {}) or {}
        # start from the current map, so clients not in `proxies` keep their entry
        new_map = dict(cur_map)

        # First pass: try to assign respecting hysteresis
        buckets = {e: [] for e in range(no_of_exits)}
        for p in order:
            cid = p.cid
            tgt = desired[cid]
            prev = int(cur_map.get(cid, tgt))
            if prev != tgt:
                # clients without a recorded switch are always allowed to move
                last_sw = int(last_change.get(cid, -10**9))
                if (round_idx - last_sw) < int(getattr(self, "min_stay_rounds", 5)):
                    # switched too recently: stay where we are
                    tgt = prev
                else:
                    last_change[cid] = int(round_idx)
            new_map[cid] = tgt
            # NOTE(review): when the client is held at `prev`, a previous exit id
            # outside range(no_of_exits) would raise KeyError here - confirm the
            # stored map can never contain such ids.
            buckets[tgt].append(cid)

        # Second pass: rebalance to quotas (may override hysteresis if needed)
        # NOTE(review): moves made here do not update `last_change`.
        for exit_i in range(no_of_exits):
            # if over quota, move the most recently appended client out
            while len(buckets[exit_i]) > int(quotas[exit_i]):
                cid = buckets[exit_i].pop()  # last appended; lowest in 'order' unless moved in during this pass
                # find some underfull exit to move into (lowest exit index first)
                moved = False
                for e2 in range(no_of_exits):
                    if len(buckets[e2]) < int(quotas[e2]):
                        buckets[e2].append(cid)
                        new_map[cid] = e2
                        moved = True
                        break
                if not moved:
                    # should not happen, but safety: put the client back and stop
                    buckets[exit_i].append(cid)
                    break

        strategy.clients_exit = new_map
        strategy._clients_exit_last_change = last_change

        # log per-exit counts
        try:
            from collections import Counter
            cnts = Counter(new_map.values())
            logger.info(f"[OwenMC] exit assignment (balanced) (exit:count) = {sorted(cnts.items())}")
        except Exception:
            pass

    def _pairwise_sim_flat(self, flats: List[torch.Tensor]) -> torch.Tensor:
        """Cosine similarity on fixed-length, layer-whitened flattened deltas."""
        if not flats:
            return torch.zeros(0, 0)
        X = torch.stack([f / (f.norm() + 1e-12) for f in flats], dim=0)  # [N,D]
        return (X @ X.T).clamp(min=-1.0, max=1.0)

    def _get_trainable_selector(self, strategy, L: int):
        """
        Try to recover the exact trainables selection used by ScaleFLFedAvg.
        Returns either:
        - indices: List[int] of selected layers, or
        - mask:    np.ndarray[bool] of length L,
        - or None if we can't find it (means 'use all L').
        """
        sel = None
        try:
            # common patterns
            if hasattr(strategy, "_trainable_idx") and strategy._trainable_idx:
                sel = list(map(int, strategy._trainable_idx))
            elif hasattr(strategy, "trainable_idx") and strategy.trainable_idx:
                sel = list(map(int, strategy.trainable_idx))
            elif hasattr(strategy, "_trainable_mask") and strategy._trainable_mask is not None:
                m = np.asarray(strategy._trainable_mask).astype(bool)
                if m.size == L:
                    sel = m
            elif hasattr(strategy, "trainable_mask") and strategy.trainable_mask is not None:
                m = np.asarray(strategy.trainable_mask).astype(bool)
                if m.size == L:
                    sel = m
        except Exception:
            sel = None
        # light sanity
        if isinstance(sel, (list, tuple)):
            sel = [i for i in sel if 0 <= int(i) < L]
            if not sel:
                sel = None
        elif isinstance(sel, np.ndarray):
            if sel.dtype != bool or sel.size != L:
                sel = None
        return sel

    def _slice_by_selector(self, weights_list: List[np.ndarray], sel, expect_len: Optional[int] = None) -> List[np.ndarray]:
        """Apply indices/mask selector to a list of layer arrays."""
        if sel is None:
            return weights_list
        if isinstance(sel, (list, tuple)):
            out = [weights_list[i] for i in sel]
        else:  # boolean mask
            out = [w for w, keep in zip(weights_list, sel) if bool(keep)]
        if expect_len is not None and len(out) != expect_len:
            logger.warning(f"[OwenMC] selector produced {len(out)} layers, expected {expect_len}")
        return out

    def _safe_flatten_pad_list(self, base_list: List[np.ndarray], other_list: List[np.ndarray]) -> List[torch.Tensor]:
        """
        Compute per-layer diffs *positionally* but on already-aligned lists.
        If a shape mismatches at position i, we return None for that layer.
        """
        diffs: List[torch.Tensor] = []
        L = len(base_list)
        for i in range(L):
            b = base_list[i]
            n = other_list[i] if i < len(other_list) else None
            if n is not None and b.shape == n.shape:
                tb = torch.from_numpy(np.copy(b)).float()
                tn = torch.from_numpy(np.copy(n)).float()
                diffs.append(tn - tb)
            else:
                diffs.append(None)
        return diffs

    def update_from_server(self, **kwargs):
        """Apply server-pushed hyper-parameter overrides to this instance.

        Only the keys listed in ``spec`` below are honoured; every other key
        is ignored.  Each accepted value is cast to the option's type, clamped
        to its valid range and stored as an attribute of the same name.
        A value that cannot be converted is skipped, leaving the current
        setting untouched.

        Boolean options accept real booleans/numbers as well as the usual
        textual spellings ("true"/"false", "1"/"0", "yes"/"no", "on"/"off",
        case-insensitive); unrecognised text is skipped like any other
        unconvertible value.
        """
        if not kwargs:
            return

        def _c01(x): return max(0.0, min(1.0, float(x)))
        def _pos(x): return max(0.0, float(x))
        def _pos_eps(x): return max(1e-6, float(x))

        def _flag(x):
            # bool("false") is True, so textual flags must be parsed explicitly.
            if isinstance(x, str):
                text = x.strip().lower()
                if text in ("1", "true", "t", "yes", "y", "on"):
                    return True
                if text in ("0", "false", "f", "no", "n", "off", ""):
                    return False
                raise ValueError(f"not a boolean flag: {x!r}")
            return bool(x)

        # option name -> (caster, clamp)
        spec = {
            "n_perm": (int, lambda x: max(1, int(x))),
            "alpha": (float, _c01),
            "ema_decay": (float, _c01),
            "regroup_every": (int, lambda x: max(1, int(x))),
            "beta0": (float, _pos),
            "beta1": (float, _pos),
            "temp0": (float, _pos_eps),
            "temp1": (float, _pos_eps),
            "gamma": (float, _c01),
            "clip_min": (float, _pos),
            "lam": (float, _pos),
            "lam_schedule": (str, str),
            "gamma_schedule": (str, str),
            "normalize_round": (_flag, bool),
            "idle_ema": (float, _c01),
            "neutral_value": (float, _pos),
            "strict_equal_groups": (_flag, bool),
            "map_best_to_deepest": (_flag, bool),
            "aggressive_mode": (_flag, bool),
            "intra_mix_eps": (float, _c01),
            "warmup_rounds": (int, lambda x: max(0, int(x))),
            "min_stay_rounds": (int, lambda x: max(0, int(x))),
            "use_target_in_shapley": (_flag, bool),
        }
        for k, v in kwargs.items():
            if k not in spec:
                continue
            caster, clamp = spec[k]
            try:
                setattr(self, k, clamp(caster(v)))
            except Exception:
                # deliberate best effort: a bad value must not abort the
                # remaining overrides; the previous setting stays in place
                pass

    def freeze_values(self, server, value=None):
        """Pin ``value`` and ``value_ema`` of every proxy to one constant.

        The constant is ``value`` when given, otherwise ``self.neutral_value``
        (converted to float either way).  Errors while collecting or updating
        the proxies are swallowed, so the call is best-effort and proxies
        updated before a failure keep the new value.
        """
        frozen = float(value) if value is not None else float(self.neutral_value)
        try:
            for proxy in self._collect_all_proxies(server):
                proxy.value = frozen
                proxy.value_ema = frozen
        except Exception:
            pass
        
    def evaluate(
        self,
        *,
        current_weights: List[np.ndarray],
        weights_1: List[List[np.ndarray]],
        client_samples: List[int],
        client_ids: List[str],
        server,
        strategy,
        results,
        **kwargs,
    ) -> Dict[str, List[float]]:
        """
        Owen with correct per-client layer mapping:

        - For each client, compute delta = (returned_tensor - personalized_baseline_tensor)
        ONLY on the tensors that the client actually trained (exit-local payload).
        - Map that delta back into the GLOBAL trainable index space using the per-exit selector.
        - Continue with memory → group reps → group Shapley → intra-group split → EMA → regroup.
        """
        # ---------- setup / schedules ----------
        self._ensure_memory_defaults()
        rnd = int(kwargs.get("round_idx", -1))

        try:
            total_rounds = int(getattr(server.ckp.config.app.args, "global_rounds", 200))
        except Exception:
            total_rounds = 200
        prog = 0.0 if rnd < 0 else max(0.0, min(1.0, rnd / max(1, total_rounds)))

        BETA = float(self.beta0) + (float(self.beta1) - float(self.beta0)) * prog
        TEMP = float(self.temp0) + (float(self.temp1) - float(self.temp0)) * prog

        if self.lam_schedule == "off":
            LAM = 0.0
        elif self.lam_schedule == "linear":
            LAM = float(self.lam) * (1.0 - prog)
        else:
            LAM = float(self.lam)
        LAM = max(float(getattr(self, "lam_min", 0.02)), LAM)

        if self.gamma_schedule == "linear_up":
            GAM = float(self.gamma) * min(1.0, prog)
        elif self.gamma_schedule == "linear_down":
            GAM = float(self.gamma) * (1.0 - prog)
        else:
            GAM = float(self.gamma)

        alpha     = float(self.alpha)
        clip_min  = float(self.clip_min)
        ema_decay = float(self.ema_decay)
        neutral   = float(self.neutral_value)

        # ---------- global shapes / helpers ----------
        L = len(current_weights)  # GLOBAL trainables length
        ref_numels = [int(np.prod(np.asarray(w.shape))) for w in current_weights]
        current_params_obj = weights_to_parameters(current_weights)  # to rebuild per-client baselines

        # Per-exit selector (GLOBAL trainable indices that belong to this exit)
        sel_by_exit: Dict[int, List[int]] = {}
        try:
            for e in range(int(getattr(strategy, "no_of_exits", 1))):
                s = strategy.get_trainable_indices_for_exit(e)
                if s:
                    sel_by_exit[e] = list(map(int, s))
        except Exception:
            pass

        def _client_exit(cid: str) -> int:
            try:
                if hasattr(strategy, "clients_exit") and cid in strategy.clients_exit:
                    return int(strategy.clients_exit[cid])
            except Exception:
                pass
            # fallback heuristic (should rarely be used)
            try:
                return int(cid) % max(1, int(getattr(strategy, "no_of_exits", 1)))
            except Exception:
                return 0

        # ---------- build RAW deltas for the sampled slice (respect personalized mapping) ----------
        updates: List[List[Optional[torch.Tensor]]] = []
        zero_clients = 0
        zero_details: List[str] = []

        for k, local_list in enumerate(weights_1):
            cid = client_ids[k]
            exit_i = _client_exit(cid)

            # personalized baseline that was sent to this client this round
            try:
                base_list = strategy.get_personalized_exit_weights(exit_i, current_params_obj)
            except Exception:
                base_list = None

            # selector for this exit
            sel = sel_by_exit.get(exit_i)
            sel_ok = sel is not None and isinstance(sel, list)

            upd: List[Optional[torch.Tensor]] = [None] * L
            reason = None

            if base_list is None or not sel_ok:
                # no way to align → treat as zero this round (don’t pollute memory with wrong shapes)
                reason = "no-personalized-baseline-or-selector"
            elif len(local_list) != len(base_list):
                reason = f"selector/payload-len-mismatch(sel={len(sel)}, base={len(base_list)}, payload={len(local_list)})"
            elif len(sel) != len(local_list):
                reason = f"selector/payload-len-mismatch(sel={len(sel)}, payload={len(local_list)})"
            else:
                # map each returned tensor to its GLOBAL trainable slot and diff vs personalized baseline
                all_zero = True
                for j, gidx in enumerate(sel):
                    if gidx is None or gidx < 0 or gidx >= L:
                        continue
                    base_np = base_list[j]
                    new_np  = local_list[j]
                    if base_np is None or new_np is None:
                        continue
                    b_shape = np.asarray(base_np).shape
                    n_shape = np.asarray(new_np).shape
                    if b_shape != n_shape:
                        if reason is None:
                            reason = f"shape-mismatch@{j} base={b_shape} new={n_shape}"
                        continue
                    b = torch.from_numpy(np.copy(base_np)).float()
                    n = torch.from_numpy(np.copy(new_np)).float()
                    d = n - b
                    if float(d.abs().sum().item()) > 0.0:
                        all_zero = False
                    upd[gidx] = d
                if all_zero and reason is None:
                    reason = "all-zero-on-selected"

            # zero-delta client?
            is_zero = True
            for t in upd:
                if t is None:
                    continue
                if float(t.abs().sum().item()) > 0.0:
                    is_zero = False
                    break
            if is_zero:
                zero_clients += 1
                zero_details.append(f"cid={cid}:{reason or 'all-none-after-mapping'}")

            updates.append(upd)

        if zero_clients > 0:
            logger.info(
                f"[OwenMC][rnd {rnd}] zero-delta slice clients={zero_clients}/{len(updates)} "
                f"(trainables L={L}, selector={'ok' if sel_by_exit else 'none'})"
            )
            for msg in zero_details[:8]:
                logger.info(f"[OwenMC][rnd {rnd}] zero-delta reason {msg}")

        # Flatten to fixed-length RAW vectors (pad missing tensors with zeros)
        flat_raw_slice = [self._flatten_pad(u, ref_numels) for u in updates]
        D = int(flat_raw_slice[0].numel()) if flat_raw_slice else 0

        # --- layer RMS scales and whitened slice (only once per round) ---
        scales = self._layer_rms(updates)  # per-layer RMS for whitening (length == L)
        flat_white_slice = [
            self._whiten_flat(v, ref_numels, scales) for v in flat_raw_slice
        ] if (flat_raw_slice and getattr(self, "sim_whiten", True)) else flat_raw_slice

        # Log which ones are truly zero before storing
        try:
            for cid, v in zip(client_ids, flat_raw_slice):
                if float(v.abs().sum().item()) == 0.0:
                    n_i = int(max(1, int(client_samples[client_ids.index(cid)])))
                    logger.info(f"[OwenMC] storing ZERO vec for cid={cid} (round={rnd}, n={n_i})")
        except Exception:
            pass

        # Write current slice into memory (RAW, fixed-length)
        slice_proxies = [p for p, _ in results]
        self._update_memory(slice_proxies, flat_raw_slice, client_samples, client_ids, rnd)

        # ---------- target directions (RAW and whitened) ----------
        t_hat_raw = None
        t_hat_white = None

        # try model-level delta (current_weights -> weights_2)
        target_raw = None
        try:
            w2 = kwargs.get("weights_2", None)
            if w2 is not None:
                cur = torch.cat([torch.from_numpy(np.copy(w)).float().flatten() for w in current_weights], dim=0)
                total = int(cur.numel())
                if total == int(sum(int(np.asarray(x).size) for x in w2)):
                    new = torch.cat([torch.from_numpy(np.copy(w)).float().flatten() for w in w2], dim=0)
                    target_raw = (new - cur).view(-1).float()
        except Exception:
            target_raw = None

        def _unit(v: Optional[torch.Tensor]) -> Optional[torch.Tensor]:
            if not isinstance(v, torch.Tensor) or v.numel() == 0:
                return None
            n = v.norm(p=2)
            return (v / (n + 1e-12)) if n.item() > 0 else None

        if target_raw is None:
            # fallback: sample-weighted mean of this slice (RAW)
            s = torch.tensor([max(1, int(x)) for x in client_samples], dtype=torch.float32)
            s = s / (s.sum() + 1e-12)
            T_raw = torch.stack(flat_raw_slice, dim=0) if flat_raw_slice else torch.zeros(0)
            t_raw = (s[:, None] * T_raw).sum(dim=0) if T_raw.numel() > 0 else None
            t_hat_raw = _unit(t_raw)
        else:
            t_hat_raw = _unit(target_raw)

        # whitened target (for alignment/similarity)
        if getattr(self, "sim_whiten", True) and D > 0:
            if target_raw is not None and target_raw.numel() == D:
                t_white = self._whiten_flat(target_raw, ref_numels, scales)
                t_hat_white = _unit(t_white)
            else:
                s = torch.tensor([max(1, int(x)) for x in client_samples], dtype=torch.float32)
                s = s / (s.sum() + 1e-12)
                T_w = torch.stack(flat_white_slice, dim=0) if flat_white_slice else torch.zeros(0)
                t_white = (s[:, None] * T_w).sum(dim=0) if T_w.numel() > 0 else None
                t_hat_white = _unit(t_white)

        # orientation locking (prefer whitened space)
        prev = getattr(self, "_tgt_ema", None)
        base = t_hat_white if t_hat_white is not None else t_hat_raw
        if isinstance(prev, torch.Tensor) and base is not None and prev.numel() == base.numel():
            if float(torch.dot(prev, base).item()) < 0.0:
                base = -base
        if base is not None:
            ema_decay_eff = max(0.5, float(ema_decay))
            if isinstance(prev, torch.Tensor) and prev.numel() == base.numel():
                base = ema_decay_eff * prev + (1.0 - ema_decay_eff) * base
                base = _unit(base)
            self._tgt_ema = base
            if t_hat_white is None and base is not None:
                t_hat_white = base

        # regroup cadence for logging/regroup
        try:
            warmup = int(getattr(self, "warmup_rounds", 0))
            regroup_every = int(getattr(self, "regroup_every", 1))
            do_regroup = (rnd >= warmup) and ((regroup_every <= 1) or (rnd % regroup_every == 0))
        except Exception:
            do_regroup = True
        LOG_THIS_ROUND = bool(getattr(self, "_log_hparams", False)) and bool(do_regroup)

        if LOG_THIS_ROUND:
            mem_cov = sum(1 for _cid, rec in getattr(self, "mem", {}).items() if rec is not None)
            logger.info(
                f"[OwenMC][rnd {rnd}] slice.N={len(client_ids)} fleet.N={len(self._collect_all_proxies(server))} "
                f"mem.cov={mem_cov} D={D} t_hat_raw={'ok' if t_hat_raw is not None else 'none'} "
                f"t_hat_white={'ok' if t_hat_white is not None else 'none'} "
                f"beta={BETA:.3f} temp={TEMP:.3f} gamma={GAM:.3f} lam_eff={LAM:.3f}"
            )

        # ---------- group reps from memory; Shapley ----------
        old_lam = float(self.lam)
        self.lam = float(LAM)

        proxies_all = self._collect_all_proxies(server)

        # memory coverage / stats by lid
        mem_cov_by_lid = self._count_mem_cov_by_lid(proxies_all)
        try:
            ms, zc = {}, {}
            for p in proxies_all:
                lid = int(getattr(p, "lid", 0))
                rec = self.mem.get(p.cid, None)
                if rec is None or "flat" not in rec:
                    continue
                v = rec["flat"].view(-1).float()
                ms[lid] = ms.get(lid, 0.0) + float(v.abs().sum().item())
                zc[lid] = zc.get(lid, 0) + int(v.abs().sum().item() == 0.0)
            memo = {lid: f"{{cnt:{mem_cov_by_lid.get(lid,0)},zero:{zc.get(lid,0)},naninf:0,mass_abs:{ms.get(lid,0.0):.1f}}}"
                    for lid in sorted(set(list(ms.keys()) + list(mem_cov_by_lid.keys())))}
            logger.info(f"[OwenMC][rnd {rnd}] mem_stats_by_lid={{{', '.join(f'{k}:{v}' for k,v in memo.items())}}}")
        except Exception:
            pass

        group_rep = self._fleet_group_reps_from_memory(
            server,
            flat_deltas_raw=flat_raw_slice,
            client_samples=client_samples,
            client_ids=client_ids,
            results=results,
            rnd=rnd,
        )

        try:
            rep_stats_fix = {lid: group_rep[lid].norm().item() for lid in group_rep}
            rep_stats = {
                lid: (
                    group_rep[lid].norm().item(),
                    (float(torch.dot(group_rep[lid], t_hat_raw).item())
                    if (t_hat_raw is not None and group_rep[lid].numel() == (t_hat_raw.numel() if isinstance(t_hat_raw, torch.Tensor) else 0)) else 0.0),
                )
                for lid in group_rep
            }
            logger.info(f"[OwenMC][rnd {rnd}] rep_stats_fix [‖rep‖]: {rep_stats_fix}")
            logger.info(f"[OwenMC][rnd {rnd}] mem_cov_by_lid={mem_cov_by_lid}")
            logger.info(f"[OwenMC][rnd {rnd}] rep_stats [‖rep‖, align_raw]: {rep_stats}")
        except Exception:
            pass

        self._inject_group_prior(group_rep, t_hat_raw, proxies_all, eps_scale=1e-4)

        try:
            abs_mass = {lid: float(group_rep[lid].abs().sum().item()) for lid in group_rep}
            norms    = {lid: float(group_rep[lid].norm().item()) for lid in group_rep}
            logger.info(f"[OwenMC][rnd {rnd}] rep_audit mass_abs≈{abs_mass} ; rep_norms={norms} "
                        f"(lam_eff={LAM:.3f}, hl={float(getattr(self,'stale_half_life',20)):.1f})")
        except Exception:
            pass

        t_for_shap = t_hat_raw if getattr(self, "use_target_in_shapley", True) else None
        g_shap = self._group_shapley_from_rep(group_rep, t_for_shap, lam=LAM, seed=(None if rnd < 0 else rnd))
        try:
            phi_vals = np.array([g_shap[g] for g in sorted(g_shap)])
            neg_cnt = int((phi_vals < 0).sum())
            logger.info(f"[OwenMC][rnd {rnd}] Shapley per-group: {g_shap} "
                        f"(mean={phi_vals.mean():.4f}, std={phi_vals.std():.4f}, neg={neg_cnt})")
        except Exception:
            pass

        self.lam = old_lam  # restore

        # ---------- intra-group allocation ----------
        proxies_all = self._collect_all_proxies(server)
        self._ensure_value_defaults(server)
        prev_vals_all = [float(getattr(p, "value", self.neutral_value)) for p in proxies_all]

        def _mem_vec(cid: str) -> torch.Tensor:
            """Return the remembered flat vector for ``cid`` coerced to length ``D``.

            The stored ``self.mem[cid]["flat"]`` tensor is cast with ``.float()``
            and flattened to 1-D. Vectors shorter than ``D`` are zero-padded on
            the right; longer ones are truncated to their first ``D`` entries.

            A zero vector of length ``max(1, D)`` is returned when there is no
            record for ``cid``, when the record has no ``"flat"`` entry, or when
            ``D`` is not positive.
            """
            rec = self.mem.get(cid, None)
            # A record without "flat" is treated as absent (same rule as the
            # per-lid memory statistics loop) instead of raising KeyError.
            if rec is None or "flat" not in rec or D <= 0:
                return torch.zeros(max(1, D), dtype=torch.float32)
            # reshape rather than view: view(-1) raises on non-contiguous tensors.
            v = rec["flat"].float().reshape(-1)
            n = int(v.numel())
            if n == D:
                return v
            if n < D:
                return torch.nn.functional.pad(v, (0, D - n))
            return v[:D]

        def _mem_n(cid: str) -> int:
            """Return the ``n`` value stored for ``cid``, floored at 1.

            Falls back to 1 when ``self.mem`` has no record for ``cid`` or the
            record carries no ``"n"`` key.
            """
            entry = self.mem.get(cid, None)
            if entry is None:
                return 1
            count = int(entry.get("n", 1))
            return count if count > 1 else 1

        members_by_lid: Dict[int, List[Any]] = {}
        for p in proxies_all:
            members_by_lid.setdefault(int(getattr(p, "lid", 0)), []).append(p)

        # threshold for “flat” similarity; configurable
        flat_thresh = float(getattr(self, "intra_flat_thresh", 1e-4))

        new_values: Dict[str, float] = {}
        for lid, plist in members_by_lid.items():
            if not plist:
                continue

            # whitened, L2-normalized member vectors for similarity/target alignment
            Vs_w, Ns = [], []
            for p in plist:
                v_raw = _mem_vec(p.cid)
                v_w = self._whiten_flat(v_raw, ref_numels, scales) if getattr(self, "sim_whiten", True) else v_raw
                v_w = v_w / (v_w.norm(p=2) + 1e-12)
                Vs_w.append(v_w)
                Ns.append(_mem_n(p.cid))
            X = torch.stack(Vs_w, dim=0) if len(Vs_w) > 0 else torch.zeros(0, max(1, D))
            m_rows = X.shape[0]

            S = (X @ X.T).clamp(min=-1.0, max=1.0) if m_rows > 0 else torch.zeros(0, 0)
            if m_rows > 1:
                off_sum = S.sum(dim=1) - torch.diag(S)
                c_pair = off_sum / float(m_rows - 1)
            else:
                c_pair = S.mean(dim=1) if m_rows > 0 else torch.zeros(0)

            # target alignment in whitened space (nonnegative)
            if t_hat_white is not None and m_rows > 0:
                c_targ = (X @ t_hat_white.view(-1)).clamp(min=0.0)
            else:
                c_targ = torch.zeros_like(c_pair)

            if getattr(self, "aggressive_mode", False):
                # AGGRESSIVE: ignore peer smoothing; pure target alignment (nonnegative)
                c = torch.relu(c_targ)
            else:
                c = torch.relu(c_pair) * (1.0 - float(GAM)) + c_targ * float(GAM)

            # ---- NEW: fallback when the similarity signal is too flat (prevents uniform weights) ----
            c_std = float(c.std().item()) if c.numel() > 0 else 0.0
            if c.numel() > 0 and c_std < flat_thresh:
                mag = torch.tensor([_mem_vec(p.cid).norm(p=2).item() for p in plist], dtype=torch.float32)
                # z-score → keep only >0 to “reward” above-mean mass
                mag = (mag - mag.mean()) / (mag.std() + 1e-12)
                c = torch.clamp(mag, min=0.0)

            s = torch.tensor([max(1, int(n)) for n in Ns], dtype=torch.float32)
            s = s / (s.sum() + 1e-12)

            z = float(BETA) * c + alpha * torch.log(s + 1e-8)
            w = torch.softmax(z / max(1e-6, float(TEMP)), dim=0)
            w = torch.clamp(w, min=clip_min)
            w = w / w.sum()

            # epsilon mixing with uniform to avoid starvation
            eps = getattr(self, "intra_mix_eps", 0.05)
            m_w = w.numel()
            if m_w > 0:
                w = (1.0 - eps) * w + eps * (torch.ones_like(w) / m_w)
                w = w / w.sum()
            gv = float(g_shap.get(lid, 0.0))
            for idx, p in enumerate(plist):
                new_values[p.cid] = gv * float(w[idx])

            if LOG_THIS_ROUND:
                eff_k = 1.0 / float((w * w).sum().item()) if w.numel() > 0 else 0.0
                sim_mean = float(torch.relu(c_pair).mean().item()) if c_pair.numel() > 0 else 0.0
                targ_mean = float(c_targ.mean().item()) if c_targ.numel() > 0 else 0.0
                logger.info(
                    f"[OwenMC][rnd {rnd}] lid={lid} m={len(plist)} gv={g_shap.get(lid,0.0):.4f} "
                    f"sim_mean={sim_mean:.4f} targ_mean={targ_mean:.4f} c_std={c_std:.4e} "
                    f"effK={eff_k:.2f} w[min,max]=({float(w.min().item()):.4f},{float(w.max().item()):.4f})"
                )

        # ---------- normalize & persist ----------
        vals_tensor = torch.tensor(list(new_values.values()), dtype=torch.float32) if new_values else torch.tensor([])
        scale_used = 1.0
        if getattr(self, "aggressive_mode", False):
            self.normalize_round = False
        if bool(self.normalize_round) and vals_tensor.numel() > 0:
            pos = vals_tensor[vals_tensor > 0]
            scale_used = float(pos.mean().item()) if pos.numel() > 0 else 1.0
            if scale_used > 0:
                for k in new_values:
                    new_values[k] /= scale_used

        if LOG_THIS_ROUND and new_values:
            arr = np.array(list(new_values.values()), dtype=float)
            p5, p50, p95 = np.percentile(arr, [5, 50, 95]).tolist()
            logger.info(
                f"[OwenMC][rnd {rnd}] values: mean={arr.mean():.4f} std={arr.std():.4f} "
                f"min={arr.min():.4f} p5={p5:.4f} median={p50:.4f} p95={p95:.4f} max={arr.max():.4f} "
                f"norm_scale={scale_used:.4f}"
            )

        for p in proxies_all:
            nv = float(new_values.get(p.cid, getattr(p, "value", neutral)))
            if ema_decay > 0.0:
                prev = float(getattr(p, "value", neutral))
                ema  = ema_decay * prev + (1.0 - ema_decay) * nv
                p.value = ema
                p.value_ema = ema
            else:
                p.value = nv
                p.value_ema = nv

        if float(getattr(self, "idle_ema", 0.0)) > 0.0:
            seen = set(new_values.keys())
            idle_k = float(self.idle_ema)
            for p in proxies_all:
                if p.cid in seen:
                    continue
                prev = float(getattr(p, "value", neutral))
                ema  = idle_k * prev + (1.0 - idle_k) * neutral
                p.value = ema
                p.value_ema = ema

        if LOG_THIS_ROUND:
            new_vals_all = [float(getattr(p, "value", self.neutral_value)) for p in proxies_all]
            a = np.asarray(prev_vals_all, dtype=float)
            b = np.asarray(new_vals_all, dtype=float)
            if len(a) == len(b) and len(a) > 0 and np.std(a) > 0 and np.std(b) > 0:
                def _ranks(x):
                    """Return 0-based ranks of ``x``; tied values share their mean rank."""
                    sort_idx = x.argsort()
                    # Ordinal position of every element in the sorted order.
                    ordinal = np.empty(sort_idx.shape, dtype=float)
                    ordinal[sort_idx] = np.arange(len(x), dtype=float)
                    # Group equal values, then average the ordinal ranks per group.
                    _, group_of, group_size = np.unique(
                        x, return_inverse=True, return_counts=True
                    )
                    mean_rank = np.bincount(group_of, weights=ordinal) / group_size
                    return mean_rank[group_of]
                rx, ry = _ranks(a), _ranks(b)
                sx = (rx - rx.mean()) / (rx.std() + 1e-12)
                sy = (ry - ry.mean()) / (ry.std() + 1e-12)
                rho = float(np.mean(sx * sy))
            else:
                rho = float('nan')
            logger.info(f"[OwenMC][rnd {rnd}] Spearman(rank(value_prev), rank(value_now)) = {rho:.3f}")

        if LOG_THIS_ROUND:
            gm = {}
            for lid, plist in members_by_lid.items():
                if not plist:
                    continue
                gm[lid] = float(np.mean([getattr(p, "value", self.neutral_value) for p in plist]))
            logger.info(f"[OwenMC][rnd {rnd}] pre-regroup group-mean(values): {gm}")

        # ---------- regroup & map to exits ----------
        if do_regroup:
            self._regroup(server, default_val=neutral)
            self._apply_regroup_to_strategy(strategy, proxies_all, rnd)

        # ---------- export full-fleet Owen values ----------
        proxies_all = self._collect_all_proxies(server)
        full_map = {p.cid: float(getattr(p, "value", neutral)) for p in proxies_all}

        # Keep a stable, server-visible cache for downstream aggregation
        self.last_values_map = dict(full_map)
        try:
            # expose on strategy so the aggregator can read it
            setattr(strategy, "_owen_values", dict(full_map))
            setattr(strategy, "_owen_round", int(rnd))
        except Exception:
            pass

        # Slice-aligned output (compat)
        out_vals = [float(full_map.get(cid, neutral)) for cid in client_ids]

        if LOG_THIS_ROUND:
            try:
                import numpy as _np
                arr = _np.array(list(full_map.values()), dtype=float)
                logger.info(
                    f"[OwenMC][rnd {rnd}] fleet Owen stats: "
                    f"mean={arr.mean():.4f} std={arr.std():.4f} "
                    f"min={arr.min():.4f} p50={_np.median(arr):.4f} max={arr.max():.4f}"
                )
            except Exception:
                pass

        return {
            "Owen_MC": out_vals,       # slice-aligned
            "Owen_MC_full": full_map,  # dict across fleet
        }

