import os
import copy
from typing import List, Dict
from collections import OrderedDict

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import ray

from config import AttrDict
from src.utils import get_func_from_config
from src.apps.clients import ClassificationClient, epochs_to_batches, ree_early_exit_test
from src.data import cycle

from src.models.model_utils import prune, disable_bn_tracking

import hashlib

import ast
from flwr.common import parameters_to_weights

import traceback

def _peek_shapes(arrs, k=4):
    try:
        return [tuple(np.asarray(a).shape) for a in arrs[:k]]
    except Exception:
        return ["<n/a>"]

def _sd_shapes(sd, keys, k=4):
    out = []
    for kk in keys[:k]:
        if kk in sd:
            try:
                out.append((kk, tuple(sd[kk].shape)))
            except Exception:
                out.append((kk, "<n/a>"))
        else:
            out.append((kk, "<missing-in-sd>"))
    return out


def _coerce_list_arg(val):
    if isinstance(val, (list, tuple)):
        return list(val)
    if isinstance(val, str):
        s = val.strip()
        if s.startswith('['):
            try:
                return list(ast.literal_eval(s))
            except Exception:
                return [int(x) for x in s.strip('[]').split(',') if x]
    return val


def _parse_modes(obj):
    """Return a per-exit list of lowercase mode strings (robust to string/list)."""
    if isinstance(obj, (list, tuple)):
        return [str(x).lower() for x in obj]
    if isinstance(obj, str):
        s = obj.strip()
        if s.startswith("["):
            try:
                vals = ast.literal_eval(s)
                return [str(x).lower() for x in vals]
            except Exception:
                return [s.lower()]
        return [s.lower()]
    return []


def _sig(keys):
    return hashlib.sha1("|".join(keys).encode()).hexdigest()[:10]

def _safe_num_workers(default: int = 0) -> int:
    try:
        # Ray actor path
        n = len(ray.worker.get_resource_ids().get("CPU", []))
        return int(n) if n > 0 else default
    except Exception:
        pass
    try:
        # Fallback: local machine cores minus one (capped)
        return max(0, min(4, (os.cpu_count() or 2) - 1))
    except Exception:
        return default

# src/apps/clients/reefl_classification_client.py
def enforce_black_box_constraints(model: torch.nn.Module, trainable_keys, stage: str):
    """
    PFL-safe parameter gating (gentle).

    Args:
      model:          module whose parameters are gated in place.
      trainable_keys: iterable of parameter names (as yielded by
                      ``model.named_parameters()``) allowed to train; ``None``
                      is treated as empty.
      stage:          ``"train"`` enables gradients for ``trainable_keys``;
                      any other value freezes every parameter and puts the
                      whole model in eval mode.

    Behaviour:
      - The model-wide mode is set with ``.train()`` / ``.eval()``.
      - ``requires_grad`` is True only for listed keys during stage 'train'.
      - During stage 'train', a module that owns *direct* parameters, none of
        which is trainable, has its own ``training`` flag cleared.  Only that
        module is affected: ``module.eval()`` is deliberately not used because
        it recurses into children and would also switch trainable submodules
        (and parameter-free layers such as dropout) to eval mode.
    """
    allowed = set(trainable_keys or [])
    is_train = stage == "train"

    # Global mode flag
    if is_train:
        model.train()
    else:
        model.eval()

    # Parameter gating
    for name, param in model.named_parameters():
        param.requires_grad = is_train and name in allowed

    if not is_train:
        return

    # Modules whose own parameters are all frozen → eval, without touching children
    for module in model.modules():
        direct_params = list(module.parameters(recurse=False))
        if direct_params and not any(p.requires_grad for p in direct_params):
            module.training = False


def _scalefl_loss(outputs, labels, T: float = 1.0, alpha: float = 0.5):
    """
    ScaleFL paper objective (Eq. 5):
      L = (1/(L*(L+1))) * sum_{i=1..L} i * [ CE(z_i, y) + alpha * KL( z_i || z_L ) ]
    where z_i are exit logits in shallow→deep order and z_L is the last exit (teacher).
    KL term is zero for i==L because student==teacher.

    Args:
      outputs: list[Tensor] of logits per exit (len L), last is teacher
      labels:  Tensor of targets
      T:       temperature for KD
      alpha:   KD weight (λ in paper)
    """
    import torch
    import torch.nn.functional as F

    assert isinstance(outputs, (list, tuple)) and len(outputs) >= 1
    L = len(outputs)
    teacher = outputs[-1].detach()

    norm = float(L * (L + 1))
    ce_total = 0.0
    kd_total = 0.0

    for i, out in enumerate(outputs, start=1):
        ce_total = ce_total + i * F.cross_entropy(out, labels)
        if i < L:
            kd = F.kl_div(
                F.log_softmax(out / T, dim=1),
                F.softmax(teacher / T, dim=1),
                reduction="batchmean",
            ) * (T * T)
            kd_total = kd_total + i * kd

    return (ce_total + alpha * kd_total) / norm


def train(
    net,
    max_exit_layer,
    trainloader,
    valloader,
    optimizer,
    finetune_batch,
    device: str,
    round: int,
    mu: float = 0,
    kl_loss: str = '',
    kl_weight=None,
    kl_softmax_temp=1.0,
    aggregation='fedavg',
    prev_grads=None,
    global_params=None,
    feddyn_alpha=0.0,
    clip=1,
    *,
    trainable_keys=None,
    inclusivefl: bool = False,
    inclusive_weights=None,
    inclusive_gamma: float = 0.7,
    inclusive_norm: str = "sum",
):
    """Train the network. Implements paper-true ScaleFL when kl_loss in {'scalefl','scale','forward'}.

    Runs ``finetune_batch`` steps over ``cycle(trainloader)`` and returns the
    per-sample mean of the (regularised) training loss.

    Args:
        net: model to train; its parameters are gated by ``trainable_keys``.
        max_exit_layer: exit id, used only in log messages.
        trainloader: iterable of ``(images, labels)`` batches.
        valloader: unused by this function.
        optimizer: optimizer to step; ``None`` is tolerated (gradients are
            computed but no update is applied).
        finetune_batch: number of steps; ``0`` returns ``0.0`` immediately.
        device: device the batches (and FedDyn tensors) are moved to.
        round: used only in the fatal-error log message.
        mu: FedProx coefficient; ``> 0`` adds ``(mu / 2) * ||w - w_start||^2``.
        kl_loss: loss mode; 'scalefl' / 'scale' / 'forward' select the ScaleFL loss.
        kl_weight: KD weight for ScaleFL (``None`` -> 0.5).
        kl_softmax_temp: KD temperature for ScaleFL (``None`` -> 1.0).
        aggregation: ``'feddyn'`` adds the FedDyn linear and quadratic penalties.
        prev_grads, global_params: per-key tensors required for FedDyn; both
            dicts are moved to ``device`` in place.
        feddyn_alpha: FedDyn quadratic-penalty coefficient.
        clip: if truthy, gradients are clipped element-wise to ``[-clip, clip]``.
        trainable_keys: parameter names allowed to receive gradients.
        inclusivefl: when True (and ScaleFL is off) use a weighted sum of the
            cross-entropy of every exit.
        inclusive_weights: scalar or per-exit weights; ``None`` uses
            ``inclusive_gamma ** (L - 1 - i)``.
        inclusive_gamma: decay used when ``inclusive_weights`` is ``None``.
        inclusive_norm: ``"mean"`` rescales the weights so they sum to L.

    Raises:
        RuntimeError: no parameter is trainable, or the loss is detached.
        ValueError: ``inclusive_weights`` length does not match the exits.
        Any exception raised by data loading / forward / backward is logged
        and re-raised.
    """
    if finetune_batch == 0:
        return 0.0

    # Globally (re-)enables autograd; callers may have left it disabled after eval.
    torch.set_grad_enabled(True)
    net.train()
    enforce_black_box_constraints(net, trainable_keys, stage="train")

    # Guard: must have something trainable
    n_tr = sum(1 for _, p in net.named_parameters() if p.requires_grad)
    if n_tr == 0:
        raise RuntimeError("No trainable params selected for this step (empty keyset).")

    criterion = torch.nn.CrossEntropyLoss()
    trainable_keys_set = set(trainable_keys or [])

    # FedDyn bookkeeping tensors to device
    if aggregation == 'feddyn':
        assert prev_grads is not None and global_params is not None
        for k in list(prev_grads.keys()):
            prev_grads[k] = prev_grads[k].to(device)
        for k in list(global_params.keys()):
            global_params[k] = global_params[k].to(device)

    # Snapshot of the starting weights: the FedProx anchor.
    last_round_model = copy.deepcopy(net) if mu > 0 else None

    # ---- ScaleFL / KD toggles (backward-compatible) ----
    mode = str(kl_loss).lower().strip()
    use_scalefl = mode in ("scalefl", "scale", "forward")
    T = float(kl_softmax_temp) if kl_softmax_temp is not None else 1.0
    alpha = float(kl_weight) if kl_weight is not None else 0.5

    # One-time banner
    try:
        some_param = next(net.parameters())
        pdev = str(some_param.device)
    except StopIteration:
        pdev = "<no-params>"
    print(f"[TRAIN start] lid={max_exit_layer} steps={finetune_batch} device={device} "
          f"param_device={pdev} trainables={n_tr} mode="
          f"{'ScaleFL' if use_scalefl else 'Inclusive' if inclusivefl else 'CE-last'}")

    loss_sum, total = 0.0, 0
    train_iter = iter(cycle(trainloader))

    try:
        for step in range(int(finetune_batch)):
            # --------- fetch batch ---------
            try:
                images, labels = next(train_iter)
            except Exception as e:
                print(f"[TRAIN ERROR] fetching batch at step={step}: {e}")
                traceback.print_exc()
                raise

            images, labels = images.to(device), labels.to(device)

            if optimizer is not None:
                optimizer.zero_grad(set_to_none=True)
            else:
                net.zero_grad(set_to_none=True)

            # 1ch → 3ch safety (only for 4-D image batches)
            if images.dim() == 4 and images.size(1) == 1:
                images = images.expand(-1, 3, images.shape[2], images.shape[3])

            # --------- forward ---------
            try:
                outputs = net(images)
            except Exception as e:
                print(f"[TRAIN ERROR] forward at step={step}: {e}")
                print(f"  images.shape={tuple(images.shape)} images.device={images.device} dtype={images.dtype}")
                traceback.print_exc()
                raise

            if torch.is_tensor(outputs):
                outputs = [outputs]

            # 3-D outputs are reduced to their last index along dim 1.
            outputs = [
                o[:, -1, :] if isinstance(o, torch.Tensor) and o.dim() == 3 else o
                for o in outputs
            ]

            # If labels arrive as sequences, take last token (no-op otherwise)
            if isinstance(labels, torch.Tensor) and labels.dim() > 1:
                labels = labels[:, -1]

            if step == 0:
                try:
                    oshapes = [tuple(o.shape) for o in outputs]
                except Exception:
                    oshapes = ["<n/a>"]
                print(f"[TRAIN step0] outs={len(outputs)} shapes={oshapes} "
                      f"labels={tuple(labels.shape)}")

            # --------- loss ---------
            try:
                if use_scalefl and outputs:
                    loss = _scalefl_loss(outputs, labels, T=T, alpha=alpha)
                elif inclusivefl and outputs:
                    L = len(outputs)
                    if inclusive_weights is None:
                        ws = [inclusive_gamma ** (L - 1 - i) for i in range(L)]
                    else:
                        if isinstance(inclusive_weights, (int, float)):
                            ws = [float(inclusive_weights)] * L
                        else:
                            ws = list(inclusive_weights)
                            if len(ws) != L:
                                if len(ws) == 1:
                                    ws = [float(ws[0])] * L
                                else:
                                    raise ValueError(f"Inclusive weights length {len(ws)} != exits {L}")
                    if inclusive_norm == "mean":
                        s = sum(ws) if sum(ws) > 0 else 1.0
                        ws = [w / s * L for w in ws]
                    loss_terms = [criterion(out, labels) * float(ws[i]) for i, out in enumerate(outputs)]
                    loss = torch.stack(loss_terms).sum()
                else:
                    # Baseline: CE on last exit only
                    final_output = outputs[-1] if outputs else None
                    loss = criterion(final_output, labels)
            except Exception as e:
                print(f"[TRAIN ERROR] loss at step={step}: {e}")
                traceback.print_exc()
                raise

            # --------- regularizers ---------
            try:
                if mu > 0:  # FedProx
                    # Proximal term is (mu / 2) * ||w - w_t||^2: the *squared*
                    # distance to the starting weights, summed over parameters.
                    proximal_term = 0.0
                    for w, w_t in zip(net.parameters(), last_round_model.parameters()):
                        proximal_term = proximal_term + (w - w_t.detach()).pow(2).sum()
                    loss = loss + (mu / 2) * proximal_term

                if aggregation == 'feddyn' and trainable_keys_set:
                    for k, param in net.named_parameters():
                        if k not in trainable_keys_set:
                            continue
                        if (k not in prev_grads) or (k not in global_params):
                            continue
                        curr_param = param.flatten()
                        assert prev_grads[k].numel() == curr_param.numel()
                        lin_penalty = torch.dot(curr_param, prev_grads[k])
                        quad_penalty = feddyn_alpha / 2.0 * torch.sum((curr_param - global_params[k]) ** 2)
                        loss = loss - lin_penalty + quad_penalty
            except Exception as e:
                print(f"[TRAIN ERROR] regularizers at step={step}: {e}")
                traceback.print_exc()
                raise

            # --------- backward/step ---------
            if (not torch.is_grad_enabled()) or (not getattr(loss, "requires_grad", False)):
                first_param = next((p for p in net.parameters()), None)
                fp_rg = None if first_param is None else first_param.requires_grad
                raise RuntimeError(
                    f"Detached training loss: grad_enabled={torch.is_grad_enabled()} "
                    f"loss.requires_grad={getattr(loss,'requires_grad',None)} "
                    f"first_param.requires_grad={fp_rg}"
                )

            try:
                loss.backward()
            except Exception as e:
                print(f"[TRAIN ERROR] backward at step={step}: {e}")
                traceback.print_exc()
                # Dump a tiny grad summary
                try:
                    nz = sum(int((p.grad is not None) and p.grad.isfinite().all().item()) for p in net.parameters())
                    print(f"[TRAIN DEBUG] non-None finite grads after backward (partial) = {nz}")
                except Exception:
                    pass
                raise

            if clip:
                try:
                    torch.nn.utils.clip_grad_value_(net.parameters(), clip)
                except Exception as e:
                    print(f"[TRAIN ERROR] grad clip at step={step}: {e}")
                    traceback.print_exc()
                    raise

            # optimizer=None is accepted above for zero_grad, so honour it here too.
            if optimizer is not None:
                try:
                    optimizer.step()
                except Exception as e:
                    print(f"[TRAIN ERROR] optimizer.step at step={step}: {e}")
                    traceback.print_exc()
                    raise

            # The step loss is a batch mean; weight it by the batch size so that
            # dividing by the sample count yields a true per-sample average.
            batch_size = images.size(0)
            loss_sum += float(loss.detach().cpu()) * batch_size
            total += batch_size

        return loss_sum / max(total, 1)

    except Exception as e:
        print(f"[TRAIN FATAL] lid={max_exit_layer} round={round} aborted: {e}")
        traceback.print_exc()
        raise


def last_good_train(
    net,
    max_exit_layer,
    trainloader,
    valloader,
    optimizer,
    finetune_batch,
    device: str,
    round: int,
    mu: float = 0,
    kl_loss: str = '',
    kl_weight=None,
    kl_softmax_temp=1.0,
    aggregation='fedavg',
    prev_grads=None,
    global_params=None,
    feddyn_alpha=0.0,
    clip=1,
    *,
    trainable_keys=None,
    inclusivefl: bool = False,
    inclusive_weights=None,
    inclusive_gamma: float = 0.7,
    inclusive_norm: str = "sum",
):
    """Train the network. Implements paper-true ScaleFL when kl_loss in {'scalefl','scale'}.

    Quieter variant of the training loop (no per-step diagnostics).  Runs
    ``finetune_batch`` steps over ``cycle(trainloader)`` and returns the
    per-sample mean of the (regularised) training loss.

    Args:
        net: model to train; its parameters are gated by ``trainable_keys``.
        max_exit_layer, valloader, round: accepted for signature parity; not
            used by this function.
        trainloader: iterable of ``(images, labels)`` batches.
        optimizer: optimizer to step; ``None`` is tolerated (gradients are
            computed but no update is applied).
        finetune_batch: number of steps; ``0`` returns ``0.0`` immediately.
        device: device the batches (and FedDyn tensors) are moved to.
        mu: FedProx coefficient; ``> 0`` adds ``(mu / 2) * ||w - w_start||^2``.
        kl_loss: 'scalefl' / 'scale' select the ScaleFL loss.
        kl_weight, kl_softmax_temp: ScaleFL KD weight (``None`` -> 0.5) and
            temperature (``None`` -> 1.0).
        aggregation: ``'feddyn'`` adds the FedDyn penalties.
        prev_grads, global_params: per-key tensors required for FedDyn; moved
            to ``device`` in place.
        feddyn_alpha: FedDyn quadratic-penalty coefficient.
        clip: if truthy, gradients are clipped element-wise to ``[-clip, clip]``.
        trainable_keys: parameter names allowed to receive gradients.
        inclusivefl, inclusive_weights, inclusive_gamma, inclusive_norm:
            configure the weighted all-exit cross-entropy used when ScaleFL
            is off and ``inclusivefl`` is True.

    Raises:
        RuntimeError: no parameter is trainable, or the loss is detached.
        ValueError: ``inclusive_weights`` length does not match the exits.
    """
    if finetune_batch == 0:
        return 0.0

    # Globally (re-)enables autograd; callers may have left it disabled after eval.
    torch.set_grad_enabled(True)
    net.train()
    enforce_black_box_constraints(net, trainable_keys, stage="train")

    # Guard: must have something trainable
    if sum(1 for _, p in net.named_parameters() if p.requires_grad) == 0:
        raise RuntimeError("No trainable params selected for this step (empty keyset).")

    criterion = torch.nn.CrossEntropyLoss()
    trainable_keys_set = set(trainable_keys or [])

    # FedDyn bookkeeping tensors to device
    if aggregation == 'feddyn':
        assert prev_grads is not None and global_params is not None
        for k in list(prev_grads.keys()):
            prev_grads[k] = prev_grads[k].to(device)
        for k in list(global_params.keys()):
            global_params[k] = global_params[k].to(device)

    # Snapshot of the starting weights: the FedProx anchor.
    last_round_model = copy.deepcopy(net) if mu > 0 else None

    # ScaleFL toggles
    use_scalefl = str(kl_loss).lower() in ("scalefl", "scale")
    T = float(kl_softmax_temp) if kl_softmax_temp is not None else 1.0
    alpha = float(kl_weight) if kl_weight is not None else 0.5

    loss_sum, total = 0.0, 0
    batches = iter(cycle(trainloader))

    for _ in range(finetune_batch):
        images, labels = next(batches)
        images, labels = images.to(device), labels.to(device)

        if optimizer is not None:
            optimizer.zero_grad(set_to_none=True)
        else:
            net.zero_grad(set_to_none=True)

        # 1ch → 3ch safety
        if images.size(1) == 1:
            images = images.expand(-1, 3, images.shape[2], images.shape[3])

        outputs = net(images)
        if torch.is_tensor(outputs):
            outputs = [outputs]

        # ------------------------ Loss ------------------------
        if use_scalefl and outputs:
            loss = _scalefl_loss(outputs, labels, T=T, alpha=alpha)

        elif inclusivefl and outputs:
            L = len(outputs)
            if inclusive_weights is None:
                ws = [inclusive_gamma ** (L - 1 - i) for i in range(L)]
            else:
                if isinstance(inclusive_weights, (int, float)):
                    ws = [float(inclusive_weights)] * L
                else:
                    ws = list(inclusive_weights)
                    if len(ws) != L:
                        if len(ws) == 1:
                            ws = [float(ws[0])] * L
                        else:
                            raise ValueError(f"Inclusive weights length {len(ws)} != exits {L}")
            if inclusive_norm == "mean":
                s = sum(ws) if sum(ws) > 0 else 1.0
                ws = [w / s * L for w in ws]
            loss_terms = [criterion(out, labels) * float(ws[i]) for i, out in enumerate(outputs)]
            loss = torch.stack(loss_terms).sum()
        else:
            # Baseline: CE on last exit only
            final_output = outputs[-1] if outputs else None
            loss = criterion(final_output, labels)

        # --------------- Optional regularizers ----------------
        if mu > 0:  # FedProx
            # Proximal term is (mu / 2) * ||w - w_t||^2: the *squared*
            # distance to the starting weights, summed over parameters.
            proximal_term = 0.0
            for w, w_t in zip(net.parameters(), last_round_model.parameters()):
                proximal_term = proximal_term + (w - w_t.detach()).pow(2).sum()
            loss = loss + (mu / 2) * proximal_term

        if aggregation == 'feddyn' and trainable_keys_set:
            for k, param in net.named_parameters():
                if k not in trainable_keys_set:
                    continue
                if (k not in prev_grads) or (k not in global_params):
                    continue
                curr_param = param.flatten()
                assert prev_grads[k].numel() == curr_param.numel()
                lin_penalty = torch.dot(curr_param, prev_grads[k])
                quad_penalty = feddyn_alpha / 2.0 * torch.sum((curr_param - global_params[k]) ** 2)
                loss = loss - lin_penalty + quad_penalty

        # --------------- Backward / step ----------------------
        if (not torch.is_grad_enabled()) or (not getattr(loss, "requires_grad", False)):
            first_param = next((p for p in net.parameters()), None)
            fp_rg = None if first_param is None else first_param.requires_grad
            raise RuntimeError(
                f"Detached training loss: grad_enabled={torch.is_grad_enabled()} "
                f"loss.requires_grad={getattr(loss,'requires_grad',None)} "
                f"first_param.requires_grad={fp_rg}"
            )

        loss.backward()
        if clip:
            torch.nn.utils.clip_grad_value_(net.parameters(), clip)
        # optimizer=None is accepted above for zero_grad, so honour it here too.
        if optimizer is not None:
            optimizer.step()

        # The step loss is a batch mean; weight it by the batch size so that
        # dividing by the sample count yields a true per-sample average.
        batch_size = images.size(0)
        loss_sum += float(loss.detach().cpu()) * batch_size
        total += batch_size

    return loss_sum / max(total, 1)


def _shape_fingerprint(model: torch.nn.Module, limit=6):
    fp = []
    for name, m in model.named_modules():
        if isinstance(m, torch.nn.Conv2d):
            w = getattr(m, "weight", None)
            if w is not None:
                fp.append((name, tuple(w.shape)))
        if isinstance(m, torch.nn.BatchNorm2d):
            w = getattr(m, "weight", None)
            if w is not None:
                fp.append((name, ("bn", w.numel())))
        if len(fp) >= limit:
            break
    n_params = sum(p.numel() for p in model.parameters())
    return fp, n_params


class ReeFLClassificationClient(ClassificationClient):
    
    def __init__(
        self,
        cid: str,
        lid: int,
        width_scale: float,
        ckp: AttrDict,
        *args,
        kl_loss: str = '',
        kl_consistency_weight: int = 300,
        kl_weight=None,
        kl_softmax_temp=1.0,
        aggregation: str = 'fedavg',
        clip=1.0,
        **kwargs
    ):
        """Build the truncated sub-model and payload key sets for one client.

        Args:
            cid: client id, forwarded to the base class and used in logs.
            lid: index of the exit this client trains up to.
            width_scale: width multiplier passed to the architecture factory.
            ckp: checkpoint/config object forwarded to the base class; its
                ``config.app.args`` section is read for optional knobs.
            kl_loss: loss mode; one of '', 'forward', 'all', 'dynamic',
                'dynamic_ma', 'scalefl', 'scale'.
            kl_consistency_weight, kl_weight, kl_softmax_temp, clip: stored
                on the instance unchanged.
            aggregation: 'fedavg' or 'feddyn'.
            *args, **kwargs: forwarded to the base-class constructor.

        Raises:
            AssertionError: unknown ``aggregation`` / ``kl_loss``, or the built
                sub-model does not expose ``lid + 1`` exits.
            ValueError: ``lid`` is out of range for the full model's exits.
            RuntimeError: the built sub-model reports no exits.
        """
        super().__init__(cid, ckp, *args, **kwargs)

        self.lid = lid
        self.width_scale = width_scale
        self.kl_loss = kl_loss
        self.kl_consistency_weight = kl_consistency_weight
        self.kl_weight = kl_weight
        self.kl_softmax_temp = kl_softmax_temp
        self.aggregation = aggregation
        self.clip = clip

        assert self.aggregation in ['fedavg', 'feddyn']
        # 'scalefl' and 'scale' are accepted because the module-level train()
        # functions explicitly implement them; rejecting them here made those
        # modes unreachable through this client.
        assert self.kl_loss in ['', 'forward', 'all', 'dynamic', 'dynamic_ma', 'scalefl', 'scale']
        self.val_set = self.kl_loss == 'dynamic'

        # --- HeteroFL / payload option ---
        app_args = getattr(getattr(self.ckp, "config", AttrDict()), "app", AttrDict())
        app_args = getattr(app_args, "args", AttrDict())
        self.send_full_local_state = bool(getattr(app_args, "send_full_local_state", False))

        # --- InclusiveFL knobs (off unless enabled) ---
        self.inclusivefl = False
        self.inclusive_gamma = 0.7
        self.inclusive_norm  = "sum"
        self.inclusive_weights_cfg = None
        try:
            mode = str(getattr(app_args, "training_loss", "exclusive")).lower()
            self.inclusivefl = bool(getattr(app_args, "inclusivefl", False)) or (mode == "inclusive")
            self.inclusive_gamma = float(getattr(app_args, "inclusive_gamma", 0.7))
            self.inclusive_norm  = str(getattr(app_args, "inclusive_norm", "sum")).lower()
            w = getattr(app_args, "inclusive_weights", None)
            if w is not None:
                if isinstance(w, (list, tuple)):
                    self.inclusive_weights_cfg = [float(x) for x in w]
                else:
                    self.inclusive_weights_cfg = float(w)
        except Exception as e:
            # Best-effort parsing: keep whatever was set so far, but say so
            # instead of silently hiding a malformed config value.
            print(f"[client {cid}] InclusiveFL config ignored (partially applied): {e}")

        # Per-exit mode
        self._modes_per_exit = _parse_modes(getattr(app_args, "pruning_mode", None))
        self.depth_training = False
        try:
            if self._modes_per_exit:
                # Fall back to the first entry when fewer modes than exits are given.
                mode_here = self._modes_per_exit[self.lid] if len(self._modes_per_exit) > self.lid else self._modes_per_exit[0]
                self.depth_training = (str(mode_here).lower() == "depth")
        except Exception:
            self.depth_training = False

        # --- Build sub-architecture for this lid ---
        arch_fn = get_func_from_config(self.net_config)
        net_args = copy.deepcopy(self.net_config.args)
        # NOTE(review): with last_exit_only=True the multi-exit losses
        # (ScaleFL / Inclusive) presumably receive a single output — confirm
        # against the model's forward().
        net_args['last_exit_only'] = True
        if "ee_layer_locations" in net_args and net_args["ee_layer_locations"] is not None:
            net_args["ee_layer_locations"] = _coerce_list_arg(net_args["ee_layer_locations"])[: self.lid + 1]

        # read global exit layout from the full model constructed by the base class
        all_blks_to_exit = self.net.blks_to_exit
        self._total_exits_global = len(all_blks_to_exit)
        if self.lid >= self._total_exits_global:
            raise ValueError(f"Client {self.cid}: lid {self.lid} out of bounds for {self._total_exits_global} exits.")

        blk_to_exit = all_blks_to_exit[self.lid]
        net_args['depth'] = blk_to_exit + 1
        net_args['blks_to_exit'] = all_blks_to_exit[: self.lid + 1]
        net_args['no_of_exits'] = self.lid + 1
        net_args['width_scale'] = self.width_scale

        # Overwrite self.net with the truncated sub-model
        self.net = arch_fn(device=self.device, **net_args)
        # Ensure the attribute exists even if the model ignores kwarg
        setattr(self.net, "last_exit_only", bool(net_args.get("last_exit_only", True)))

        # Sanity on exits/heads
        assert len(getattr(self.net, "blks_to_exit")) == self.lid + 1, f"blks_to_exit mismatch for lid={self.lid}"
        if hasattr(self.net, "exit_heads"):
            assert len(self.net.exit_heads) == self.lid + 1, f"exit_heads mismatch for lid={self.lid}"

        # Effective head index in this truncated model
        local_exit_idx = len(self.net.blks_to_exit) - 1
        if local_exit_idx < 0:
            raise RuntimeError(f"[client {self.cid}] sub-model has no exits (blks_to_exit={self.net.blks_to_exit})")
        self._head_prefix = f"exit_heads.{local_exit_idx}."

        # ---------- Payload/key selection (BEFORE printing) ----------
        keys_all = list(self.net.trainable_state_dict_keys)          # params only (trainables)
        sd_all_keys = list(self.net.state_dict().keys())             # params + buffers

        head_prefixes = [f"exit_heads.{j}." for j in range(local_exit_idx + 1)]

        def _is_allowed_head(k: str) -> bool:
            # Non-head keys always pass; head keys must belong to a head ≤ this exit.
            if not k.startswith("exit_heads."):
                return True
            return any(k.startswith(p) for p in head_prefixes)

        def _allowed_prog(k: str) -> bool:
            # payload filtering for this exit (trunk ≤ exit + this head)
            return (not k.startswith("exit_heads.")) or k.startswith(self._head_prefix)

        # Build both: local_sd_keys (params only) and local_full_keys (params + BN buffers)
        if self.inclusivefl:
            # trunk + all heads up to lid
            self.local_sd_keys   = [k for k in keys_all   if _is_allowed_head(k)]
            # include buffers except the integer counters (num_batches_tracked)
            self.local_full_keys = [k for k in sd_all_keys if _is_allowed_head(k) and (not k.endswith("num_batches_tracked"))]
            self.head_only_keys  = [k for k in keys_all if k.startswith(head_prefixes[-1])]
        else:
            # trunk ≤ exit + ONLY this head in payload
            self.local_sd_keys   = [k for k in keys_all    if _allowed_prog(k)]
            self.local_full_keys = [k for k in sd_all_keys if _allowed_prog(k) and (not k.endswith("num_batches_tracked"))]
            self.head_only_keys  = [k for k in keys_all if k.startswith(self._head_prefix)]

        # The set used to gate requires_grad during training
        self.trainable_state_keys = self.local_sd_keys

        # Flags
        self.is_deepest = (self.lid == self._total_exits_global - 1)

        # ----------------- SAFE TO PRINT NOW -----------------
        fp, n_params = _shape_fingerprint(self.net, limit=8)
        print(
            f"[client {self.cid}] build exit={self.lid} depth={net_args['depth']} "
            f"width={net_args['width_scale']:.3f} params={n_params} "
            f"keycount={len(self.local_sd_keys)} fp={fp}"
        )
        print(
            f"[client {self.cid}] lid={self.lid} local_exit_idx={local_exit_idx} "
            f"local_keys={len(self.local_sd_keys)} (first={self.local_sd_keys[:3]})"
        )
        print(f"[client {self.cid}] last_exit_only={getattr(self.net,'last_exit_only',None)} "
            f"(inclusivefl={self.inclusivefl}, kl_loss={self.kl_loss})")

        # For isolated eval reconstruction
        self._eval_arch_fn = arch_fn
        self._eval_net_args = copy.deepcopy(net_args)

        # FedDyn setup
        self.feddyn_alpha = 0.0
        self.prev_grads = None
        if self.aggregation == 'feddyn':
            self.feddyn_alpha = self.ckp.config.server.strategy.args.alpha
            self.prev_grads_filepath = os.path.join(self.ckp.run_dir, f'prev_grads/{self.cid}')
            self.prev_grads = self.ckp.offline_load(self.prev_grads_filepath)
            if self.prev_grads is None:
                # No saved state: start from flat zero tensors, one per trainable key.
                self.prev_grads = {
                    k: torch.zeros(v.numel()) for (k, v) in self.net.named_parameters()
                    if k in self.trainable_state_keys
                }
        self.allow_shape_mismatch = True  # hetero safe default
        self._recv_shapes = {}            # per-round shape echo map

    def __get_parameters(self, config=None):
        """
        Always return weights in the exact order negotiated with the server (keys_prog),
        falling back to the minimal local payload if needed.
        """
        sd = self.net.state_dict()
        # prefer the server-provided order from this fit/eval call
        if hasattr(self, "_payload_keys") and self._payload_keys:
            keys = self._payload_keys
        else:
            # fallback: use minimal (params-only) local payload
            keys = getattr(self, "local_sd_keys", None) or list(sd.keys())

        arrs = []
        for k in keys:
            if k not in sd:
                # tolerate missing (strict False elsewhere), just skip
                continue
            arrs.append(sd[k].detach().cpu().numpy())
        return arrs

    def get_parameters(self, config=None):
        sd = self.net.state_dict()
        keys = getattr(self, "_payload_keys", None) or self.local_sd_keys
        out = []
        for k in keys:
            if k not in sd:
                continue
            arr = sd[k].detach().cpu().numpy()
            shp = getattr(self, "_recv_shapes", {}).get(k, None)
            if shp and tuple(arr.shape) != tuple(shp):
                tmp = np.zeros(shp, dtype=arr.dtype)
                n = min(tmp.size, arr.size)
                if n > 0:
                    tmp.reshape(-1)[:n] = arr.reshape(-1)[:n]
                arr = tmp
            out.append(arr)
        print(f"[client {self.cid}] GET params: sending {len(out)} arrays (keys_prog_len={len(keys)})")
        return out

    def set_parameters(self, parameters, *, keys_override=None):
        """
        Robustly load incoming weights:
        - Accepts list/tuple of ndarrays OR a Flower Parameters-like object (duck-typed).
        - Tolerant to length mismatch: we load up to min(len(keys), len(arrs)).
        - Tolerant to shape mismatch: copy min(numel) into a zero-like target tensor.
        - Records the server-provided shape per key so get_parameters() can echo.
        """
        sd = self.net.state_dict()
        keys_train = list(keys_override) if keys_override is not None else self.local_sd_keys

        # Convert "parameters" → list of np arrays (without importing flwr)
        arrs = None
        if isinstance(parameters, (list, tuple)):
            arrs = list(parameters)
        else:
            # Best-effort duck-typing for flwr.common.Parameters
            # We try a few common layouts, but DO NOT crash if unknown.
            try:
                # flwr>=1.1 often has .tensors: List[bytes]; we can’t decode safely w/o dtype.
                # So, if it's not a plain list/tuple, just bail out early with a clear message.
                print(f"[client {self.cid}] set_parameters: unsupported params type {type(parameters).__name__}")
                arrs = []
            except Exception:
                arrs = []

        # Record inbound shapes to echo back later (only for what we actually load)
        self._recv_shapes = {}

        def _to_tensor(x):
            if torch.is_tensor(x):
                return x.detach().cpu()
            return torch.from_numpy(np.asarray(x))

        def _adapt(src: torch.Tensor, tgt: torch.Tensor) -> torch.Tensor:
            # exact match → fast path
            if list(src.shape) == list(tgt.shape):
                return src.to(dtype=tgt.dtype)
            # tolerant: flatten copy min(numel)
            out = torch.zeros_like(tgt, dtype=tgt.dtype)
            n = min(out.numel(), src.numel())
            if n > 0:
                out.view(-1)[:n] = src.to(dtype=tgt.dtype).view(-1)[:n]
            return out

        upto = min(len(keys_train), len(arrs))
        if len(keys_train) != len(arrs):
            print(
                f"[client {self.cid}] set_parameters: length mismatch "
                f"(arrays={len(arrs)} vs keys={len(keys_train)}) → loading first {upto}"
            )
            # Also shorten self._payload_keys for this round so we send the same length back
            if hasattr(self, "_payload_keys") and isinstance(self._payload_keys, list):
                self._payload_keys = self._payload_keys[:upto]

        for k, v in zip(keys_train[:upto], arrs[:upto]):
            if k not in sd:
                continue
            t = _to_tensor(v)
            self._recv_shapes[k] = tuple(t.shape)  # remember server’s shape for echo
            sd[k] = _adapt(t, sd[k])

        self.net.load_state_dict(sd, strict=False)
        
    def _____set_parameters(self, parameters, *, keys_override=None):
        """
        Legacy loader: copy server-sent arrays into the local network's state dict.

        Arrays are paired positionally with ``keys_override`` (or with
        ``self.local_sd_keys`` when no override is given). Keys that are not in the
        local state dict are skipped. A shape mismatch raises ``ValueError`` unless
        ``self.allow_shape_mismatch`` is truthy; in that case the overlapping leading
        region of each dimension is copied into a zero-filled tensor of the local shape.
        """
        state = self.net.state_dict()
        if isinstance(parameters, (list, tuple)):
            incoming = parameters
        else:
            incoming = parameters_to_weights(parameters)
        if keys_override is not None:
            names = list(keys_override)
        else:
            names = self.local_sd_keys

        # Debug: early sanity (best effort; a failure while formatting the
        # message must not abort the load)
        try:
            expected = len(getattr(self, 'local_full_keys', [])) or len(self.local_sd_keys)
            summary = (
                f"[FIT recv] cid={self.cid} lid={self.lid} n_arrays={len(incoming)} "
                f"keys_prog_len={len(names)} "
                f"local_expected={expected}"
            )
            print(summary)
        except Exception:
            pass

        tolerant = bool(getattr(self, "allow_shape_mismatch", False))

        # Load only what the server sent, in the server's order.
        for name, raw in zip(names, incoming):
            if name not in state:
                continue
            target = state[name]
            if torch.is_tensor(raw):
                received = raw.detach().cpu()
            else:
                received = torch.from_numpy(np.asarray(raw))

            if list(received.shape) != list(target.shape):
                if not tolerant:
                    raise ValueError(
                        f"[cid={self.cid}] shape mismatch disallowed: recv={tuple(received.shape)} vs model={tuple(target.shape)}"
                    )
                # Tolerant slice/pad: keep the region both shapes share, zeros elsewhere.
                window = tuple(
                    slice(0, min(got, want))
                    for got, want in zip(received.shape, target.shape)
                )
                padded = torch.zeros_like(target)
                padded[window] = received[window].to(dtype=target.dtype)
                state[name] = padded
            else:
                # Exact match: only the dtype is aligned with the local tensor.
                state[name] = received.to(dtype=target.dtype)

        self.net.load_state_dict(state, strict=False)

    def fit(self, parameters, config):
        """
        Train on the server-specified subset and return the matching payload.

        The key order for the round is taken from ``config["keys_prog"]`` (falling back
        to ``self.local_sd_keys``). That same order is used both to load the incoming
        weights and to build the returned payload.

        Default: train trunk≤lid + head_lid (progressive), matching the server's payload.
        To revert to legacy head-only (non-deepest) behavior, set app.args.reefl_train_prog=False.

        Args:
            parameters: incoming weights; passed unchanged to ``self.set_parameters``.
            config: per-round settings from the server, may be ``None``. Read here:
                ``keys_prog``, ``current_round`` and ``lr``.

        Returns:
            ``(payload, num_examples, metrics)``.

            * Stage failure (weight load, optimizer, dataloader): current weights from
              ``self.get_parameters()``, ``0`` examples, NaN loss and an ``"error"`` entry.
            * Nothing to train (no trainable parameters, empty partition, zero steps):
              current weights, loss ``0.0`` and a ``"note"`` entry.
            * Normal path: the state-dict tensors for ``self._payload_keys`` as numpy
              arrays, the dataset length, and the training loss.
        """
        # --- Determine the key order for THIS round ---
        try:
            raw_order = (config or {}).get("keys_prog", None)
            order = list(raw_order) if isinstance(raw_order, (list, tuple)) else None
        except Exception:
            # Malformed config object: fall through to the local default below.
            order = None
        if not order:
            # Previous-compatible default: the client's own payload key list.
            order = list(self.local_sd_keys)
            print(f"[client {self.cid}] fit(): keys_prog missing → TRAINABLE payload ({len(order)} keys)")

        # Visibility on what we received
        try:
            n_in = len(parameters) if isinstance(parameters, (list, tuple)) else f"type={type(parameters).__name__}"
        except Exception:
            n_in = "<unknown>"
        print(f"[client {self.cid}] FIT begin: recv_arrays={n_in} keys_prog_len={len(order)}")

        # Adopt order for this round
        self._payload_keys = order
        key_sig_now = _sig(self._payload_keys)

        # Filled in by the optimizer stage; the defaults give the "nothing trained"
        # values reported when no parameter requires grad.
        name_to_param = {}
        trunk_params = []

        def _base_metrics(mode, loss_value):
            # Fields shared by every return path of this round.
            return {
                "fed_train_loss": loss_value,
                "lid": int(self.lid),
                "train_mode": mode,
                "key_sig": key_sig_now,
                "width_scale": float(self.width_scale),
            }

        def _failed(mode, error):
            # A stage failed: return current weights so the server can proceed,
            # with zero examples and the failure marked in the metrics.
            weights = self.get_parameters()
            metrics = _base_metrics(mode, float("nan"))
            metrics["error"] = error
            return weights, 0, metrics

        def _progress(loss_value, note=None):
            # Metrics for paths that got past key selection (trained or skipped).
            metrics = _base_metrics(train_mode, loss_value)
            if note is not None:
                metrics["note"] = note
            metrics["trained_trunk"] = float(int(bool(trunk_params)))
            metrics["payload_len"] = float(len(self._payload_keys))
            metrics["trained_len"] = float(len(name_to_param))
            return metrics

        # 1) LOAD INCOMING WEIGHTS (robust)
        try:
            self.set_parameters(parameters, keys_override=self._payload_keys)
            print(f"[client {self.cid}] FIT after load: mapped {len(self._payload_keys)} keys into local state")
        except Exception as e:
            print(f"[client {self.cid}] ERROR in set_parameters: {e}")
            traceback.print_exc()
            return _failed("preload_failed", "set_parameters_failed")

        # Same None-guard as the keys_prog lookup above, so a missing config
        # cannot break the wrapper construction.
        round_config = AttrDict(config or {})
        rnd = int(getattr(round_config, "current_round", 0))

        # --- Decide trainable keys (consistent with ReeFL server) ---
        app_args = getattr(getattr(self.ckp, "config", AttrDict()), "app", AttrDict())
        app_args = getattr(app_args, "args", AttrDict())

        # default True → progressive training (trunk≤lid + this head) for all exits
        force_prog = bool(getattr(app_args, "reefl_train_prog", True))

        if self.inclusivefl:
            keys_for_training = self.local_sd_keys
            train_mode = "train_inclusive"
        elif force_prog or self.depth_training or self.is_deepest:
            # Default: always train exactly what we send/receive
            keys_for_training = self.local_sd_keys
            train_mode = "train_prog_exclusive"
        else:
            # Legacy fallback: this branch is only reached when is_deepest is falsy,
            # so it is always the head-only case.
            keys_for_training = self.head_only_keys
            train_mode = "train_head_only(legacy)"

        # Fallbacks if something ended up empty
        if not keys_for_training:
            live_trainables = [n for n, p in self.net.named_parameters() if p.requires_grad]
            head_prefix = getattr(self, "_head_prefix", f"exit_heads.{(self.lid if hasattr(self, 'lid') else 0)}.")
            # Keep non-head parameters plus this client's own exit head.
            cand = [k for k in live_trainables if (not k.startswith("exit_heads.")) or k.startswith(head_prefix)]
            keys_for_training = cand if cand else live_trainables
            train_mode += "+fallback"

        # Enforce black-box gating and set up optimizer/loader
        enforce_black_box_constraints(self.net, keys_for_training, stage="train")
        torch.set_grad_enabled(True)
        self.net.train()

        # 2) OPTIMIZER (robust)
        try:
            optim_func = get_func_from_config(self.net_config.optimizer)

            # Build param groups (trunk at base LR, head scaled by head_lr_scale)
            base_lr = float(getattr(round_config, "lr", 0.001))  # safe fallback
            head_lr_scale = float(getattr(app_args, "head_lr_scale", 1.0))

            # current trainables after gating
            name_to_param = {n: p for n, p in self.net.named_parameters() if p.requires_grad}
            if not name_to_param:
                return self.get_parameters(), 0, _progress(0.0, note="no_trainable_layers_for_this_exit")

            head_names = set(self.head_only_keys)
            trunk_params = [p for n, p in name_to_param.items() if n not in head_names]
            head_params = [p for n, p in name_to_param.items() if n in head_names]

            if head_lr_scale != 1.0 and head_params and trunk_params:
                param_groups = [
                    {"params": trunk_params, "lr": base_lr},
                    {"params": head_params, "lr": base_lr * head_lr_scale},
                ]
                lr_used = base_lr  # base for logging
            else:
                param_groups = [{"params": list(name_to_param.values())}]
                # Head-only training: the head scale applies to the single group.
                lr_used = base_lr * (head_lr_scale if (head_params and not trunk_params) else 1.0)

            opt_args = dict(getattr(self.net_config.optimizer, "args", {}))  # copy to avoid mutating config
            opt_args["lr"] = lr_used  # the computed lr overrides any lr from the config
            optimizer = optim_func(param_groups, **opt_args)
        except Exception as e:
            print(f"[client {self.cid}] ERROR creating optimizer: {e}")
            traceback.print_exc()
            return _failed("optimizer_init_failed", "optimizer_init_failed")

        # 3) DATALOADER (robust)
        try:
            # _safe_num_workers is defined at module level in this file.
            num_workers = _safe_num_workers(0)

            trainloader = self.dataloader(
                data_pool='train', cid=self.cid, partition='train',
                batch_size=int(self.batch_size), num_workers=num_workers,
                shuffle=True, augment=True
            )
            self.net.to(self.device)
            ds_len = int(len(trainloader.dataset))
            print(f"[client {self.cid}] loader ok: workers={num_workers} dataset_len={ds_len} device={self.device}")
            if ds_len == 0:
                return self.get_parameters(), 0, _progress(0.0, note="empty_client_partition")
        except Exception as e:
            print(f"[client {self.cid}] ERROR building dataloader: {e}")
            traceback.print_exc()
            return _failed("dataloader_build_failed", "dataloader_build_failed")

        # epochs → steps (optionally boost deepest)
        deepest_epoch_scale = float(getattr(app_args, "deepest_epoch_scale", 1.0))
        effective_epochs = self.local_epochs * (deepest_epoch_scale if self.is_deepest else 1.0)
        total_fb = epochs_to_batches(effective_epochs, ds_len, self.batch_size)
        if total_fb == 0:
            return self.get_parameters(), ds_len, _progress(0.0, note="zero_steps")

        # ---- train ----
        print(f"[client {self.cid}] about to TRAIN: steps={int(total_fb)} lr={lr_used} mode={train_mode}")
        loss = train(
            self.net, self.lid, trainloader, None,
            optimizer=optimizer,
            finetune_batch=int(total_fb),
            device=self.device,
            round=rnd,
            mu=self.fedprox_mu,
            kl_loss=self.kl_loss,
            kl_weight=(self.kl_weight if self.kl_weight is not None else self.kl_consistency_weight),
            kl_softmax_temp=self.kl_softmax_temp,
            trainable_keys=list(name_to_param.keys()),
            aggregation=self.aggregation,
            prev_grads=self.prev_grads,
            global_params=None,
            feddyn_alpha=getattr(self, "feddyn_alpha", 0.0),
            clip=self.clip,
            inclusivefl=self.inclusivefl,
            inclusive_weights=self.inclusive_weights_cfg,
            inclusive_gamma=self.inclusive_gamma,
            inclusive_norm=self.inclusive_norm,
        )

        # ---- return what the server asked for ----
        sd = self.net.state_dict()
        missing = [k for k in self._payload_keys if k not in sd]
        if missing:
            # Keys absent from the local state cannot be sent; log it so a short
            # payload is visible rather than silent.
            print(
                f"[client {self.cid}] fit(): {len(missing)} payload keys missing from local state "
                f"→ skipped (first: {missing[:4]})"
            )
        payload = [sd[k].detach().cpu().numpy() for k in self._payload_keys if k in sd]

        return payload, ds_len, _progress(float(loss))

    def evaluate(self, parameters, config):
        """
        Load the incoming weights and score this client's exit on its test partition.

        The key order comes from ``config["keys_prog"]`` when it is a non-empty
        list/tuple, otherwise from ``self.local_sd_keys``. The weights are loaded with
        ``self.set_parameters``, copied into a fresh network built by
        ``self._eval_arch_fn``, and that copy is evaluated with ``ree_early_exit_test``.

        Args:
            parameters: incoming weights; passed unchanged to ``self.set_parameters``.
            config: per-round settings from the server, may be ``None``. Only
                ``keys_prog`` is read.

        Returns:
            ``(loss, num_examples, metrics)``. ``loss`` and the accuracy (scaled to a
            percentage) are taken from ``results[self.lid]``; ``num_examples`` is the
            test dataset length.
        """
        # Load using server order if provided (no BN recalibration here)
        try:
            raw_order = (config or {}).get("keys_prog", None)
            order = list(raw_order) if isinstance(raw_order, (list, tuple)) else None
        except Exception:
            # Malformed config object: fall through to the local default below.
            order = None
        if not order:
            order = list(self.local_sd_keys)
            print(f"[client {self.cid}] evaluate(): keys_prog missing → using local_sd_keys ({len(order)} keys)")

        # If the chosen order is still empty, fall back to the client's configured
        # payload key list.
        self._payload_keys = order if order else (
            self.local_full_keys if getattr(self, "send_full_local_state", False) else self.local_sd_keys
        )
        self.set_parameters(parameters, keys_override=self._payload_keys)

        num_workers = len(ray.worker.get_resource_ids().get("CPU", []))

        # Evaluate on a separate copy of the network so gating/freezing below
        # never touches the training model.
        eval_net = self._eval_arch_fn(device="cpu", **self._eval_net_args)
        with torch.no_grad():
            # load_state_dict copies values into eval_net's own tensors, so no
            # extra deep copy of the source state is needed.
            eval_net.load_state_dict(self.net.state_dict(), strict=True)
        eval_net.to(self.device)
        for p in eval_net.parameters():
            p.requires_grad = False
        enforce_black_box_constraints(eval_net, [], stage="eval")

        testloader = self.dataloader(
            data_pool='test', cid=self.cid, partition='test',
            batch_size=50, augment=False, num_workers=num_workers
        )

        with torch.inference_mode():
            results = ree_early_exit_test(eval_net, self.lid, testloader, device=self.device)

        out = results[self.lid]
        loss = float(out['loss'])
        acc = float(out['accuracy'] * 100)

        metrics = {
            f'ps_test_acc_exit{self.lid}': acc,
            f'ps_test_loss_exit{self.lid}': loss,
            "accuracy": acc,
        }
        return loss, len(testloader.dataset), metrics

