import logging
import copy
import inspect
from collections import defaultdict
from typing import Dict, Optional, Tuple, List

import numpy as np
import torch
from flwr.server.strategy import FedAvg as FlowerFedAvg
from flwr.server.client_manager import ClientManager
from flwr.server.client_proxy import ClientProxy
from flwr.common import (
    EvaluateIns,
    EvaluateRes,
    Parameters,
    Scalar,
    FitRes,
    FitIns,
    parameters_to_weights,
    weights_to_parameters,
)
import hashlib
from src.utils import get_func_from_config
from src.server.strategies.utils import (
    aggregate_inplace_early_exit,
    aggregate_inplace_early_exit_fedsparseadam,
)

logger = logging.getLogger(__name__)

BN_BUFFER_SUFFIXES = ("running_mean", "running_var", "num_batches_tracked")


def _is_bn_buffer(name: str) -> bool:
    return any(name.endswith(sfx) for sfx in BN_BUFFER_SUFFIXES)

def _to_numpy(x):
    if torch.is_tensor(x):
        return x.detach().cpu().numpy()
    if isinstance(x, np.ndarray):
        return x
    return np.asarray(x)

class EarlyExitFedAvgSC(FlowerFedAvg):
    """FedAvg strategy with early-exit (multi-exit) model support.

    - Keeps the wire payload length stable (all state-dict keys: params and
      buffers) to avoid Flower parameter-list size flips.
    - Aggregates only trainable weights; every non-trainable entry (e.g. BN
      running stats) is carried over from the current global state.
    - Maps the global state onto per-exit local key subsets for each client.
    """

    def __init__(
        self,
        ckp,
        client_valuation,
        *args,
        aggregation="fedavg",
        aggregation_args=None,
        **kwargs,
    ):
        """Build the strategy and precompute per-exit key subsets.

        Args:
            ckp: Checkpoint object. ``ckp.config`` supplies model/app/simulation
                settings and ``ckp.log`` is used to record training metrics.
            client_valuation: Stored on the instance as-is.
            *args, **kwargs: Forwarded to Flower's ``FedAvg``.
            aggregation: ``"fedavg"`` or ``"fedadam"``.
            aggregation_args: For ``"fedadam"``, must contain ``beta_1``,
                ``beta_2``, ``tau`` and ``eta``.
        """
        super().__init__(*args, **kwargs)
        self.ckp = ckp
        self.config = ckp.config
        self.client_valuation = client_valuation

        # ---- Discover canonical topology and key ordering from a full model ----
        self.net_config = self.config.models.net
        arch_fn = get_func_from_config(self.net_config)
        global_net = arch_fn(device="cpu", **self.net_config.args)

        # Stable wire format (ALL keys: params + buffers) and trainables (no BN buffers)
        self.all_sd_keys: List[str] = list(global_net.all_state_dict_keys)
        self.global_sd_keys: List[str] = list(global_net.trainable_state_dict_keys)

        # Early-exit topology
        self.blks_to_exit: List[int] = list(global_net.blks_to_exit)
        self.no_of_exits: int = len(self.blks_to_exit)
        logger.info(f"[EE-FedAvg] blks_to_exit = {self.blks_to_exit} (no_of_exits={self.no_of_exits})")

        # Co-located exits, e.g. [-1, -1, -1, -1] -> every exit on the last block
        try:
            self._exits_collocated: bool = len({int(x) for x in self.blks_to_exit}) == 1
        except (TypeError, ValueError):
            self._exits_collocated = False
        logger.info("[EE-FedAvg] exits_collocated=%s", self._exits_collocated)

        # Initial templates used as fallbacks (exact order/shapes). The wire
        # conversion helpers hand out COPIES of these, so in-place consumers
        # downstream can never corrupt them.
        gsd = global_net.state_dict()
        self._initial_full_weights: List[np.ndarray] = [_to_numpy(gsd[k]) for k in self.all_sd_keys]
        self._initial_trainable_weights: List[np.ndarray] = [_to_numpy(gsd[k]) for k in self.global_sd_keys]

        # ---- Training mode: exclusive (Depth/Scale/ReeFL) vs. inclusive (InclusiveFL) ----
        app_args = getattr(self.config.app, "args", None)
        training_loss_mode = (
            str(getattr(app_args, "training_loss", "exclusive")).lower() if app_args else "exclusive"
        )
        self._inclusive_mode = training_loss_mode == "inclusive"

        # Per-exit TRAINABLE subsets (sent to / aggregated from clients) and
        # per-exit FULL subsets (params + BN buffers) used for evaluation only.
        self.exit_local_sd_keys: Dict[int, List[str]] = {}
        self.exit_full_sd_keys: Dict[int, List[str]] = {}
        self._build_exit_keysets(arch_fn)

        # Client -> exit assignment (must mirror the App's scheduling)
        self.no_of_clients = self.config.simulation.num_clients
        self.clients_exit: Dict[str, int] = {
            str(i): self._lid_for_cid(str(i)) for i in range(self.no_of_clients)
        }

        # ---- Aggregation rule over TRAINABLES (FedAvg/FedAdam) ----
        self.aggregation = aggregation
        self.aggregation_args = aggregation_args or {}
        assert self.aggregation in ["fedavg", "fedadam"]
        if self.aggregation == "fedadam":
            for req in ["beta_1", "beta_2", "tau", "eta"]:
                assert req in self.aggregation_args, f"Missing {req} in aggregation_args"
            # FedAdam optimizer state, zero-initialised per trainable key.
            self.m_t: Dict[str, np.ndarray] = {
                k: np.zeros_like(w) for k, w in zip(self.global_sd_keys, self._initial_trainable_weights)
            }
            self.v_t: Dict[str, np.ndarray] = {
                k: np.zeros_like(w) for k, w in zip(self.global_sd_keys, self._initial_trainable_weights)
            }

        # Lazily resolved block index used for negative ("last block") exit positions.
        self._last_blk_cache: Optional[int] = None

        # Hint for downstream code: evaluate should prefer FULL subsets (params+buffers)
        self.eval_prefers_full_exit_subset: bool = True

    # ------------------------------------------------------------------ helpers

    def _head_filter(self, exit_i: int):
        """Return a predicate keeping trunk keys plus the exit heads owned by ``exit_i``.

        Keys outside ``exit_heads.`` are always kept. Inclusive mode keeps
        heads ``0..exit_i``; exclusive mode keeps only head ``exit_i``.
        """
        if self._inclusive_mode:
            owned = tuple(f"exit_heads.{j}." for j in range(exit_i + 1))
        else:
            owned = (f"exit_heads.{exit_i}.",)

        def _keep(key: str) -> bool:
            return not key.startswith("exit_heads.") or key.startswith(owned)

        return _keep

    def _build_exit_keysets(self, arch_fn) -> None:
        """Populate ``exit_local_sd_keys`` and ``exit_full_sd_keys`` for every exit.

        For exit ``i`` a submodel with exits ``0..i`` is built. A non-negative
        block index sets ``depth = blk + 1``. A negative index means "last
        block" (as in ``evaluate``), so the configured depth is kept. Otherwise
        ``-1`` would produce ``depth = 0``.
        """
        # Sets give O(1) membership instead of scanning the key lists per key.
        global_trainable = set(self.global_sd_keys)
        all_keys = set(self.all_sd_keys)

        for exit_i in range(self.no_of_exits):
            args_i = copy.deepcopy(self.net_config.args)
            blk_to_exit = int(self.blks_to_exit[exit_i])
            if blk_to_exit >= 0:
                args_i["depth"] = blk_to_exit + 1
            args_i["blks_to_exit"] = list(self.blks_to_exit[: exit_i + 1])
            args_i["no_of_exits"] = exit_i + 1
            args_i["last_exit_only"] = False  # keep all exits up to i during eval/inference

            local_net = arch_fn(device="cpu", **args_i)
            keep = self._head_filter(exit_i)

            # Trainables intersected with global trainables (server aggregates only these)
            train_keys = [
                k for k in local_net.trainable_state_dict_keys
                if k in global_trainable and keep(k)
            ]
            # FULL subset = params + BN buffers, respecting the same head filter
            full_keys = [
                k for k in local_net.state_dict().keys()
                if k in all_keys and (keep(k) or _is_bn_buffer(k))
            ]
            self.exit_local_sd_keys[exit_i] = train_keys
            self.exit_full_sd_keys[exit_i] = full_keys

            logger.info(
                f"[EE-FedAvg] exit {exit_i}: #trainable={len(train_keys)} #full={len(full_keys)} | "
                f"tail={train_keys[-3:] if len(train_keys)>=3 else train_keys}"
            )

    def _lid_for_cid(self, cid: str) -> int:
        """Return the exit index for client ``cid`` (mirrors the App's scheduling).

        With ``app.args.mode == "maximum"`` (case-insensitive) every client uses
        the deepest exit. Otherwise exits are assigned round-robin by integer
        cid. Raises ValueError if ``cid`` is not an integer string.
        """
        app_args = getattr(self.config.app, "args", None)
        mode = getattr(app_args, "mode", "multi_tier")
        if str(mode).lower() == "maximum":
            return max(0, self.no_of_exits - 1)
        return int(cid) % max(1, self.no_of_exits)

    def _wire_to_full_sd(self, parameters: Parameters, where: str) -> Dict[str, np.ndarray]:
        """Decode ``parameters`` into a dict covering every key in ``all_sd_keys``.

        Accepts either a full payload (params + buffers) or a trainable-only
        payload. In the trainable-only case every non-trainable entry is filled
        from the initial template. Any other length logs a warning and returns
        the initial full template. Template entries are always copies.
        """
        wire = parameters_to_weights(parameters)

        if len(wire) == len(self.all_sd_keys):
            return dict(zip(self.all_sd_keys, wire))

        if len(wire) == len(self.global_sd_keys):
            trainable = set(self.global_sd_keys)
            # Keep every non-trainable entry (BN buffers and any other buffer) from template
            full = {
                k: tmpl.copy()
                for k, tmpl in zip(self.all_sd_keys, self._initial_full_weights)
                if k not in trainable
            }
            # Fill trainables from wire
            full.update(zip(self.global_sd_keys, wire))
            return full

        logger.warning(
            "Len mismatch in %s: expected %d (full) or %d (trainable) but got %d. Using initial full template.",
            where, len(self.all_sd_keys), len(self.global_sd_keys), len(wire),
        )
        return {k: tmpl.copy() for k, tmpl in zip(self.all_sd_keys, self._initial_full_weights)}

    def _wire_to_global_sd(self, parameters: Parameters, where: str) -> Dict[str, np.ndarray]:
        """Decode ``parameters`` into a dict of trainable weights only.

        Accepts a trainable-only or full payload. Any other length logs a
        warning and returns a copy of the initial trainable template.
        """
        wire = parameters_to_weights(parameters)

        if len(wire) == len(self.global_sd_keys):
            return dict(zip(self.global_sd_keys, wire))

        if len(wire) == len(self.all_sd_keys):
            full = dict(zip(self.all_sd_keys, wire))
            return {k: full[k] for k in self.global_sd_keys}

        logger.warning(
            "Len mismatch in %s: expected %d (trainable) or %d (full) but got %d. Using initial trainable template.",
            where, len(self.global_sd_keys), len(self.all_sd_keys), len(wire),
        )
        return {k: tmpl.copy() for k, tmpl in zip(self.global_sd_keys, self._initial_trainable_weights)}

    # --------------------------------------------------------------- fit round

    def configure_fit(
        self, rnd: int, parameters: Parameters, client_manager: ClientManager
    ) -> List[Tuple[ClientProxy, FitIns]]:
        """Sample clients and build per-client payloads.

        Each client receives only the trainable weights of its exit
        (``exit_local_sd_keys[lid]``). Its config carries ``keys_prog`` (the
        exact payload key order) and ``lid_hint``. Exits are NOT rebalanced on
        the server: each client's lid mirrors the App logic, which keeps
        HeteroFL/ScaleFL consistent when exits are co-located or '-1'-based.
        """
        from collections import Counter

        config = {}
        if self.on_fit_config_fn is not None:
            config = self.on_fit_config_fn(rnd)

        # Sample clients
        sample_size, min_num_clients = self.num_fit_clients(client_manager.num_available())
        clients = client_manager.sample(num_clients=sample_size, min_num_clients=min_num_clients)

        # Current full snapshot from the wire; per-exit trainable keys index into it.
        full_sd = self._wire_to_full_sd(parameters, where="configure_fit")

        client_instructions: List[Tuple[ClientProxy, FitIns]] = []
        per_exit = Counter()
        debug_enabled = logger.isEnabledFor(logging.DEBUG)
        for client in clients:
            exit_i = self._lid_for_cid(client.cid)
            # Keep the mirror map up to date (aggregate_fit falls back to it)
            self.clients_exit[client.cid] = exit_i
            per_exit[exit_i] += 1

            local_keys = self.exit_local_sd_keys[exit_i]
            local_weights = [full_sd[k] for k in local_keys]

            cfg = dict(config)
            cfg["keys_prog"] = list(local_keys)  # handshake: payload key order (copy, no aliasing)
            cfg["lid_hint"] = exit_i

            if debug_enabled:
                sig = hashlib.sha1("|".join(local_keys).encode()).hexdigest()[:10]
                logger.debug(
                    "[server] send cid=%s exit=%d len=%d order_sig=%s",
                    client.cid, exit_i, len(local_weights), sig
                )
            client_instructions.append((client, FitIns(weights_to_parameters(local_weights), cfg)))

        # Low-volume summary
        logger.info("[EE-FedAvg][cfg][rnd %d] sampled per-exit: %s", rnd, sorted(per_exit.items()))
        return client_instructions

    def _resolve_client_keys(self, rnd: int, client: ClientProxy, fit_res: FitRes) -> Optional[List[str]]:
        """Return the trainable keyset matching a client's payload, or None to skip it.

        Uses the client-reported ``lid`` metric if present, otherwise the
        server-side ``clients_exit`` map. If that keyset's length does not match
        the payload, the first exit whose keyset has the payload's length is
        used instead.
        """
        # Count tensors without deserializing them.
        n = len(fit_res.parameters.tensors)
        lid_from_client = (fit_res.metrics or {}).get("lid")
        try:
            lid = int(lid_from_client) if lid_from_client is not None else self.clients_exit[client.cid]
        except (KeyError, TypeError, ValueError, OverflowError):
            lid = self.clients_exit.get(client.cid, 0)

        keys = self.exit_local_sd_keys.get(lid, [])
        if n == len(keys):
            return keys
        for cand_keys in self.exit_local_sd_keys.values():
            if len(cand_keys) == n:
                return cand_keys

        logger.error(
            f"[EE-FedAvg][rnd {rnd}] payload len {n} matches no exit keyset "
            f"(cid={client.cid}, reported_lid={lid_from_client}). Skipping client."
        )
        return None

    def _log_mean_train_metrics(self, rnd: int, results: List[Tuple[ClientProxy, FitRes]]) -> None:
        """Log the mean of every numeric client training metric via ``ckp.log``.

        Non-numeric metric values (str/bytes) are ignored instead of making
        ``np.mean`` raise.
        """
        train_summary = defaultdict(list)
        for _, fit_res in results:
            for name, value in (fit_res.metrics or {}).items():
                if isinstance(value, (int, float, np.integer, np.floating)):
                    train_summary[name].append(value)
        for name, values in train_summary.items():
            self.ckp.log({f"mean_{name}": np.mean(values)}, step=rnd, commit=False)

    def aggregate_fit(
        self,
        rnd: int,
        results: List[Tuple[ClientProxy, FitRes]],
        failures: List[BaseException],
        current_parameters: Parameters,
        server=None,
    ) -> Tuple[Optional[Parameters], Dict[str, Scalar]]:
        """Aggregate client updates over trainables and return the full wire payload.

        Only clients whose payload matches an exit keyset are passed to the
        aggregation helper. All other entries (BN buffers, untouched keys) are
        taken from ``current_parameters``. Returns ``(None, {})`` when there
        are no results, or when failures exist and ``accept_failures`` is off.
        """
        if not results:
            return None, {}
        if not self.accept_failures and failures:
            return None, {}

        # Rebuild full current state
        full_sd_cur = self._wire_to_full_sd(current_parameters, where="aggregate_fit")

        # Robust client -> keyset mapping; unmatched clients are excluded entirely
        clients_local_sd_keys: Dict[str, List[str]] = {}
        usable_results: List[Tuple[ClientProxy, FitRes]] = []
        for client, fit_res in results:
            keys = self._resolve_client_keys(rnd, client, fit_res)
            if keys is None:
                continue
            clients_local_sd_keys[client.cid] = keys
            usable_results.append((client, fit_res))

        if not usable_results:
            logger.error(f"[EE-FedAvg][rnd {rnd}] No usable client updates after keyset matching")
            return weights_to_parameters([full_sd_cur[k] for k in self.all_sd_keys]), {}

        # torch.from_numpy shares memory with full_sd_cur; the helpers are named
        # "inplace", so they may mutate these tensors (templates are copies).
        global_trainable = {k: torch.from_numpy(full_sd_cur[k]).cpu() for k in self.global_sd_keys}

        # Aggregate (trainables only); helper output is assumed ordered like global_sd_keys
        if self.aggregation == "fedavg":
            agg_list = aggregate_inplace_early_exit(global_trainable, clients_local_sd_keys, usable_results)
        else:
            a = self.aggregation_args
            agg_list = aggregate_inplace_early_exit_fedsparseadam(
                global_trainable, clients_local_sd_keys, usable_results,
                self.m_t, self.v_t, a["beta_1"], a["beta_2"], a["tau"], a["eta"],
            )
        aggregated = dict(zip(self.global_sd_keys, agg_list))

        # Merge back onto full (preserve BN buffers and any non-aggregated entry)
        out_list = [
            _to_numpy(aggregated[k]) if k in aggregated else full_sd_cur[k]
            for k in self.all_sd_keys
        ]

        # Log mean training metrics (over all reporting clients)
        self._log_mean_train_metrics(rnd, results)

        return weights_to_parameters(out_list), {}

    # ---------------------------------------------------------- evaluate round

    def configure_evaluate(
        self, rnd: int, parameters: Parameters, client_manager: ClientManager
    ) -> List[Tuple[ClientProxy, EvaluateIns]]:
        """Send the full payload plus its key order (``keys_prog``) to evaluation clients.

        For ``rnd >= 0`` clients are sampled via ``num_evaluation_clients``;
        for negative rounds every connected client is used.
        """
        config = {}
        if self.on_evaluate_config_fn is not None:
            config = self.on_evaluate_config_fn(rnd)
        config = dict(config)
        config["keys_prog"] = list(self.all_sd_keys)
        evaluate_ins = EvaluateIns(parameters, config)

        if rnd >= 0:
            sample_size, min_num_clients = self.num_evaluation_clients(client_manager.num_available())
            clients = client_manager.sample(
                num_clients=sample_size, min_num_clients=min_num_clients, all_available=True
            )
        else:
            clients = list(client_manager.all().values())

        return [(client, evaluate_ins) for client in clients]

    def _probe_last_block(self) -> int:
        """Build a probe model without a ``depth`` override and return its last block index.

        Best effort: first the total block count over ``probe.layers``, then
        the largest non-negative ``probe.blks_to_exit`` entry, else 0.
        """
        arch_fn = get_func_from_config(self.net_config)
        probe_args = copy.deepcopy(self.net_config.args)
        probe_args.pop("depth", None)  # full depth
        probe = arch_fn(device="cpu", **probe_args)
        try:
            full_depth = sum(len(stage) for stage in getattr(probe, "layers", []))
            if full_depth and full_depth > 0:
                return full_depth - 1
        except Exception:  # deliberate best effort: fall through to next heuristic
            logger.debug("[EE-FedAvg] could not infer depth from probe.layers", exc_info=True)
        try:
            vals = [int(x) for x in getattr(probe, "blks_to_exit", []) if int(x) >= 0]
            if vals:
                return max(vals)
        except Exception:  # deliberate best effort
            logger.debug("[EE-FedAvg] could not infer depth from probe.blks_to_exit", exc_info=True)
        return 0

    def _resolve_blk_to_exit(self, blk_val: int) -> int:
        """Map a (possibly negative) exit block index to a concrete block index.

        Non-negative values are returned as-is. Negative values resolve to the
        last block via ``_probe_last_block``, cached because it depends only
        on the fixed net config.
        """
        blk_val = int(blk_val)
        if blk_val >= 0:
            return blk_val
        cached = getattr(self, "_last_blk_cache", None)
        if cached is None:
            cached = self._probe_last_block()
            self._last_blk_cache = cached
        return cached

    def evaluate(self, parameters: Parameters, partition: str = "test"):
        """Paper-faithful centralized evaluation.

        - Uses the per-exit FULL subset (params + BN buffers) built in __init__,
          so BN running stats are respected (no re-estimation, no adaptation).
        - Aggregation remains over trainables only (handled in aggregate_fit).
        - Returns ``(exit_all_loss, logs)``: per-exit metrics plus
          ``exit_all_*`` means, or None when no ``eval_fn`` is configured.
        """
        if self.eval_fn is None:
            return None

        # Reconstruct the FULL state (params + buffers) from the wire.
        full_sd = self._wire_to_full_sd(parameters, where="evaluate")

        logs = {}
        per_exit_losses, per_exit_accs = [], []

        for exit_i in range(self.no_of_exits):
            # Use FULL (params + BN buffers) subset for this exit (paper-faithful)
            local_keys = self.exit_full_sd_keys[exit_i]
            blk_to_exit = self._resolve_blk_to_exit(int(self.blks_to_exit[exit_i]))
            local_weights = [full_sd[k] for k in local_keys]

            res = self.eval_fn(local_weights, partition, exit_i, blk_to_exit, local_keys)
            if res is None:
                continue

            loss_i, metrics_i = res

            # Merge metrics (and make sure per-exit loss exists even if eval_fn omitted it)
            if isinstance(metrics_i, dict):
                logs.update(metrics_i)

            lk = f"centralized_{partition}_exit{exit_i}_loss"
            ak = f"centralized_{partition}_exit{exit_i}_acc"

            logs.setdefault(lk, float(loss_i))
            per_exit_losses.append(float(logs[lk]))
            if ak in logs:
                per_exit_accs.append(float(logs[ak]))

        if per_exit_losses:
            logs[f"centralized_{partition}_exit_all_loss"] = float(np.mean(per_exit_losses))
        if per_exit_accs:
            logs[f"centralized_{partition}_exit_all_acc"] = float(np.mean(per_exit_accs))

        return logs.get(f"centralized_{partition}_exit_all_loss", 0.0), logs
