from __future__ import annotations

import copy
import logging
from typing import Dict, List, Tuple, Optional

import numpy as np
import torch
from flwr.common import (
    Parameters,
    FitIns,
    FitRes,
    Scalar,
    parameters_to_weights,
    weights_to_parameters,
)
from flwr.server.client_manager import ClientManager
from flwr.server.client_proxy import ClientProxy

from src.utils import get_func_from_config
from src.server.strategies import EarlyExitFedAvgSC

# Module-level logger, named after this module so log records can be filtered per module.
logger = logging.getLogger(__name__)


class ReeFLFedAvg(EarlyExitFedAvgSC):
    """
    FedAvg for ReeFL (exclusive heads):
      - Each client trains a truncated subnet up to its exit_i:
          trunk (blocks <= exit_i) + ONLY head_i (exclusive).
      - On send (`configure_fit`): the global state decoded from the Flower
        `parameters` is sliced to the *exclusive* trainables of the client's exit.
      - On aggregate (`aggregate_fit`): results are sanitised (non-finite payloads
        dropped, non-numeric metrics removed) and the actual aggregation is
        delegated to the parent strategy.
    """

    def __init__(self, *args, **kwargs) -> None:
        """Build, for every exit, the ordered list of state-dict keys sent to its clients.

        All arguments are forwarded unchanged to the parent strategy. The
        attributes read here (`net_config`, `global_sd_keys`, `no_of_exits`,
        `blks_to_exit`, `ckp`) are not defined in this class and are expected
        to be provided by the parent.
        """
        super().__init__(*args, **kwargs)

        arch_fn = get_func_from_config(self.net_config)
        self.exit_local_sd_keys: Dict[int, List[str]] = {}

        # Use the server's canonical GLOBAL trainable order to avoid any ordering drift
        gorder = list(self.global_sd_keys)

        # Best-effort read of the optional `drop_nonfinite` flag; any failure while
        # walking the config chain falls back to the default (True).
        try:
            self.drop_nonfinite: bool = bool(getattr(self.ckp.config.app.args, "drop_nonfinite", True))
        except Exception:
            self.drop_nonfinite = True

        for exit_i in range(self.no_of_exits):
            args_i = copy.deepcopy(self.net_config.args)
            blk_to_exit = int(self.blks_to_exit[exit_i])

            # Configure a model truncated right after this exit's block.
            args_i["depth"]          = blk_to_exit + 1
            args_i["blks_to_exit"]   = list(self.blks_to_exit[: exit_i + 1])
            args_i["no_of_exits"]    = exit_i + 1
            args_i["last_exit_only"] = False

            local_net = arch_fn(device="cpu", **args_i)

            # Local trainables for this truncated model
            local_trainable_set = set(local_net.trainable_state_dict_keys)

            head_prefix = f"exit_heads.{exit_i}."

            # Build in GLOBAL order, filtered by (local presence AND exclusive head):
            # non-head keys are kept, head keys only when they belong to head `exit_i`.
            keys_exclusive = [
                k
                for k in gorder
                if k in local_trainable_set
                and (not k.startswith("exit_heads.") or k.startswith(head_prefix))
            ]
            self.exit_local_sd_keys[exit_i] = keys_exclusive

            logger.info(
                f"[ReeFLFedAvg] exit {exit_i}: exclusive trainables={len(keys_exclusive)} "
                f"tail={keys_exclusive[-3:] if len(keys_exclusive)>=3 else keys_exclusive}"
            )

    @staticmethod
    def _all_finite_list(lst) -> bool:
        """Return False if any array/tensor in lst has NaN/Inf (empty arrays are ok)."""
        for arr in lst:
            a = np.asarray(arr)
            if a.size and not np.isfinite(a).all():
                return False
        return True

    def initialize_parameters(self, client_manager):
        """Return the initial wire parameters built from `self._initial_full_weights`.

        `client_manager` is accepted for interface compatibility and is not used.
        """
        # Keep wire length stable (full: params + BN buffers), same as parent.
        return weights_to_parameters(self._initial_full_weights)

    def _overlap_region(self, a_shape, b_shape):
        """Return a tuple of slices covering the common leading region of two shapes.

        For each pair of dimensions (paired with `zip`, so only up to the shorter
        rank) the slice is `0:min(a, b)`. If either shape is empty (scalar), an
        empty tuple is returned.
        """
        if len(a_shape) == 0 or len(b_shape) == 0:
            return ()
        return tuple(slice(0, min(a, b)) for a, b in zip(a_shape, b_shape))

    def _reefl_last_block_index(self) -> int:
        """Probe a full-depth model and return the index of its last block.

        Used to resolve negative `blks_to_exit` entries. Resolution order:
          1. total number of blocks in `probe.layers` minus one, if positive;
          2. the largest non-negative entry of `probe.blks_to_exit`;
          3. 0 as a last resort.
        Each step is best-effort: failures are logged at debug level and the next
        step is tried.
        """
        arch_fn = get_func_from_config(self.net_config)
        probe_args = copy.deepcopy(self.net_config.args)
        probe_args.pop("depth", None)  # full depth
        probe = arch_fn(device="cpu", **probe_args)

        try:
            full_depth = sum(len(stage) for stage in getattr(probe, "layers", []))
            if full_depth > 0:
                return full_depth - 1
        except Exception:
            logger.debug("[ReeFLFedAvg] could not derive depth from probe.layers", exc_info=True)

        try:
            vals = [int(x) for x in getattr(probe, "blks_to_exit", []) if int(x) >= 0]
            if vals:
                return max(vals)
        except Exception:
            logger.debug("[ReeFLFedAvg] could not derive depth from probe.blks_to_exit", exc_info=True)

        return 0

    def configure_fit(
        self,
        rnd: int,
        parameters: Parameters,
        client_manager: ClientManager,
    ):
        """Schedule clients via the parent, then patch each payload for ReeFL.

        Args:
            rnd: Current round number (forwarded to the parent).
            parameters: Current global parameters in wire format.
            client_manager: Flower client manager (forwarded to the parent).

        Returns:
            A list of `(client, fit_ins)` pairs where each `fit_ins` carries only
            the exclusive trainables of the client's exit and a config extended
            with `keys_prog`, `lid_hint`, `exit_i` and `blk_to_exit`. If the
            parent schedules nothing, its (falsy) result is returned unchanged.
        """
        # Use the parent to SAMPLE clients; then patch payload/config for ReeFL (exclusive heads)
        scheduled = super().configure_fit(rnd, parameters, client_manager)
        if not scheduled:
            return scheduled

        # Current global snapshot, re-keyed in the server's canonical key order.
        global_train_np = self._wire_to_global_sd(parameters, where="configure_fit(reefl)")
        global_sd = {k: global_train_np[k] for k in self.global_sd_keys}

        # The "last block" index is the same for every negative entry, so the
        # (expensive) probe model is built at most once per call, and only if needed.
        last_block_cache: List[int] = []

        def _resolve_blk_to_exit(blk_val: int) -> int:
            blk_val = int(blk_val)
            if blk_val >= 0:
                return blk_val
            if not last_block_cache:
                last_block_cache.append(self._reefl_last_block_index())
            return last_block_cache[0]

        patched = []
        for client, fitins in scheduled:
            # `clients_exit` is expected to be populated by the parent's configure_fit.
            exit_i = self.clients_exit[client.cid]
            local_keys = self.exit_local_sd_keys[exit_i]  # trunk + ONLY head_i (exclusive)
            local_w = [global_sd[k] for k in local_keys]

            blk_resolved = _resolve_blk_to_exit(self.blks_to_exit[exit_i])

            cfg = dict(getattr(fitins, "config", {}) or {})
            # handshake: exact order of trainables (copied so the per-exit key list
            # held by the strategy cannot be mutated through the config)
            cfg["keys_prog"]   = list(local_keys)
            cfg["lid_hint"]    = int(exit_i)
            cfg["exit_i"]      = int(exit_i)
            cfg["blk_to_exit"] = int(blk_resolved)

            # Patch a shallow copy: if the parent hands the same FitIns object to
            # several clients, in-place mutation would make every client receive
            # the payload of the last one processed.
            client_ins = copy.copy(fitins)
            client_ins.parameters = weights_to_parameters(local_w)
            client_ins.config     = cfg
            patched.append((client, client_ins))

        return patched

    def aggregate_fit(
        self,
        rnd: int,
        results: List[Tuple[ClientProxy, FitRes]],
        failures: List[BaseException],
        current_parameters: Parameters,
        server=None,
    ):
        """Sanitise client results, then delegate aggregation to the parent.

        Steps:
          1. Drop results whose payload contains NaN/Inf (when `drop_nonfinite`).
          2. Replace each result's metrics with only its finite numeric entries
             (note: this mutates `fit_res.metrics` in place).
          3. Call the parent's `aggregate_fit` with the surviving results.

        Returns:
            `(parameters, metrics)`. `(None, {})` when there are no results or
            when failures are present and not accepted. Otherwise the parent's
            result with `clients_dropped_nonfinite` added to the metrics; if
            every result was dropped, `None` plus that single metric.
        """
        # Short-circuit on no results/failures policy
        if not results:
            return None, {}
        if not self.accept_failures and failures:
            return None, {}

        dropped_nf = 0
        filtered: List[Tuple[ClientProxy, FitRes]] = []

        for client, fit_res in results:
            # 1) Filter out non-finite payloads
            try:
                local_list = parameters_to_weights(fit_res.parameters)
            except Exception:
                # Best effort: an undecodable payload cannot be checked here; it is
                # still forwarded, but no longer silently.
                logger.warning(
                    f"[ReeFLFedAvg][rnd {rnd}] cid={client.cid} payload could not be decoded "
                    f"for the finiteness check",
                    exc_info=True,
                )
                local_list = []
            if self.drop_nonfinite and not self._all_finite_list(local_list):
                logger.warning(f"[ReeFLFedAvg][rnd {rnd}] cid={client.cid} non-finite payload → skipped")
                dropped_nf += 1
                continue

            # 2) Keep ONLY numeric metrics (prevent np.mean on strings in parent)
            if getattr(fit_res, "metrics", None):
                numeric_metrics = {}
                for k, v in fit_res.metrics.items():
                    try:
                        fv = float(v)
                    except Exception:
                        # drop non-numeric (e.g., 'order_sig', 'mode', GUIDs)
                        continue
                    if np.isfinite(fv):
                        numeric_metrics[k] = fv
                fit_res.metrics = numeric_metrics

            filtered.append((client, fit_res))

        # Every result was rejected: nothing to aggregate, so do not hand an empty
        # list to the parent (mirrors the empty-`results` guard above).
        if not filtered:
            logger.warning(f"[ReeFLFedAvg][rnd {rnd}] all {len(results)} results were dropped; nothing to aggregate")
            return None, {"clients_dropped_nonfinite": float(dropped_nf)}

        # 3) Delegate to parent for actual aggregation
        new_params, metrics = super().aggregate_fit(
            rnd=rnd,
            results=filtered,
            failures=failures,
            current_parameters=current_parameters,
            server=server,
        )

        metrics = metrics or {}
        metrics["clients_dropped_nonfinite"] = float(dropped_nf)
        return new_params, metrics

