import numpy as np
import copy
from flwr.server.strategy import FedAvg as FlowerFedAvg
from flwr.server.client_manager import ClientManager
from src.utils import get_func_from_config
from pprint import pformat
from collections import defaultdict
from src.server.strategies.utils import aggregate_inplace
from src.apps.app_utils import cosine_decay_with_warmup
from typing import Dict, Optional, Tuple, List, Any
from flwr.server.client_proxy import ClientProxy
from flwr.server.strategy.aggregate import aggregate
from flwr.common import (
    EvaluateIns,
    EvaluateRes,
    Parameters,
    Weights,
    Scalar,
    FitRes,
    FitIns,
    parameters_to_weights,
    weights_to_parameters,
)

import logging
logger = logging.getLogger(__name__)  # module-level logger (not referenced in the code visible here)

# NOTE(review): `pdb` is not referenced in the code visible here; likely a leftover debugging import.
import pdb
import numpy.typing as npt
# Type aliases: a numpy array of any dtype, and a list of such arrays (used in method annotations below).
NDArray = npt.NDArray[Any]
NDArrays = List[NDArray]

class InclusiveFL(FlowerFedAvg):
    '''
    Reimplementation of No One Left Behind: Inclusive Federated Learning Over Heterogeneous Devices (KDD'22)

    Each client is assigned to an early-exit sub-network ("exit") of a shared backbone.
    Every exit keeps its own personalized tensors: its exit head plus its last block, or, for
    the largest exit, its last ``depth // no_of_exits`` blocks. All other trainable tensors are
    "shared" and averaged into one global model. Smaller exits additionally blend the update of
    their personalized ``layers.*`` tensors with the update the next-larger exit recorded for
    the same tensors ("momentum distillation", weighted by ``beta``).
    '''

    def __init__(self, ckp, client_valuation, *args, aggregation='fedavg', aggregation_args=None, beta=0.2, **kwargs):
        """Build the per-exit key partitions from the architecture described in ``ckp.config``.

        Args:
            ckp: checkpoint/logging handle; ``ckp.config`` drives model construction and
                ``ckp.log`` receives training metrics.
            client_valuation: accepted for interface compatibility; not used by this strategy.
            *args, **kwargs: forwarded to Flower's ``FedAvg``.
            aggregation: server update rule, ``'fedavg'`` or ``'fedadam'``.
            aggregation_args: hyper-parameters of the update rule; ``'fedadam'`` requires
                ``beta_1``, ``beta_2``, ``tau`` and ``eta``. Defaults to an empty dict.
            beta: weight of the larger exit's update during momentum distillation.

        Raises:
            ValueError: unsupported ``aggregation`` or missing FedAdam hyper-parameters.
        """
        super().__init__(*args, **kwargs)
        self.ckp = ckp
        self.config = ckp.config
        self.beta = beta
        self.aggregation = aggregation
        # Fresh dict per instance; a `{}` default would be shared by every instance.
        self.aggregation_args = {} if aggregation_args is None else aggregation_args
        if self.aggregation not in ('fedavg', 'fedadam'):
            raise ValueError(
                f"[InclusiveFL] Unsupported aggregation '{self.aggregation}'; expected 'fedavg' or 'fedadam'."
            )
        if self.aggregation == 'fedadam':
            missing = [name for name in ('beta_1', 'beta_2', 'tau', 'eta') if name not in self.aggregation_args]
            if missing:
                raise ValueError(f"[InclusiveFL] FedAdam requires aggregation_args {missing}")
            # Per-exit first/second moment estimates, created lazily in aggregate_fit.
            self.m_t = {}
            self.v_t = {}

        self.net_config = self.config.models.net
        arch_fn = get_func_from_config(self.net_config)

        # Full-depth, full-width reference model: the single source of truth for key order and shapes.
        global_net_args = copy.deepcopy(self.net_config.args)
        for arg_name in ('depth', 'width_scale'):
            if arg_name in global_net_args:
                del global_net_args[arg_name]
        global_net = arch_fn(device='cpu', **global_net_args)
        global_state = global_net.state_dict()  # built once; also reused for momentum buffer shapes

        self.blks_to_exit = global_net.blks_to_exit
        # Keep BOTH views so incoming Parameters may be a full state_dict or trainables only.
        self.global_sd_all_keys = list(global_state.keys())                                 # params + buffers
        self.global_sd_keys = [n for n, p in global_net.named_parameters() if p.requires_grad]  # trainables only

        depth = global_net.total_blocks
        self.no_of_exits = len(global_net.exit_heads)
        self.max_exit = self.no_of_exits - 1
        no_of_blocks_per_exit = depth // self.no_of_exits if self.no_of_exits > 0 else depth
        no_of_clients = self.config.simulation.num_clients
        # Must be set before the loop below: get_global_block_id reads it.
        self.blocks_per_stage = [len(stage) for stage in global_net.layers]

        self.exit_local_sd_keys = {}
        self.exit_personalized_sd_values = {}
        self.exit_shared_sd_keys = {}
        self.exit_momentum_sd_block_ids = {}
        self.exit_momentum_sd_values = {}

        for exit_i in range(self.no_of_exits):
            # Truncated backbone ending at this exit's block, keeping only its exit layers.
            net_args = copy.deepcopy(self.net_config.args)
            net_args['depth'] = self.blks_to_exit[exit_i] + 1
            if 'ee_layer_locations' in net_args:
                net_args['ee_layer_locations'] = net_args['ee_layer_locations'][:exit_i + 1]
            local_net = arch_fn(device='cpu', **net_args)
            local_keys = local_net.trainable_state_dict_keys
            self.exit_local_sd_keys[exit_i] = local_keys

            personalized = {}
            shared_sd_keys = []
            for sd_key in local_keys:
                if self._is_personalized_key(sd_key, exit_i, no_of_blocks_per_exit, depth):
                    personalized[sd_key] = None  # None => not trained yet, fall back to the global value
                else:
                    shared_sd_keys.append(sd_key)
            self.exit_personalized_sd_values[exit_i] = personalized
            self.exit_shared_sd_keys[exit_i] = shared_sd_keys

            momentum_values = {}
            momentum_block_ids = set()
            if exit_i != 0:
                # Include the smaller group's last block: (exit_i * B) - 1
                earliest_momentum_block_id = max(0, exit_i * no_of_blocks_per_exit - 1)
                for sd_key in local_keys:
                    block_id = self.get_global_block_id(sd_key)
                    if block_id == -1 or block_id < earliest_momentum_block_id:
                        continue
                    momentum_block_ids.add(block_id)
                    if sd_key not in momentum_values:
                        momentum_values[sd_key] = np.zeros(tuple(global_state[sd_key].shape), dtype=np.float32)
            self.exit_momentum_sd_values[exit_i] = momentum_values
            self.exit_momentum_sd_block_ids[exit_i] = momentum_block_ids

            # Internal invariant: every trainable key is exactly one of personalized / shared.
            assert not set(personalized).intersection(shared_sd_keys)
            assert set(personalized) | set(shared_sd_keys) == set(local_keys)

        # Client id (as str) -> exit index it trains.
        self.clients_exit = {}
        for i in range(no_of_clients):
            if self.config.app.args.mode == 'maximum':
                self.clients_exit[str(i)] = self.no_of_exits - 1
            else:
                self.clients_exit[str(i)] = i % self.no_of_exits

    def _is_personalized_key(self, key: str, exit_i: int, no_of_blocks_per_exit: int, depth: int) -> bool:
        """Return whether ``key`` is kept per exit (personalized) for ``exit_i``.

        The exit's own head (``exit_heads.{exit_i}``) is always personalized. Non-largest exits
        personalize block ``(exit_i + 1) * no_of_blocks_per_exit - 1``; the largest exit personalizes
        every block with global id ``>= depth - no_of_blocks_per_exit``.
        """
        if f'exit_heads.{exit_i}' in key:
            return True
        block_id = self.get_global_block_id(key)
        if exit_i != self.max_exit:
            return block_id == (exit_i + 1) * no_of_blocks_per_exit - 1
        return block_id >= depth - no_of_blocks_per_exit

    def get_global_block_id(self, key: str) -> int:
        """Map ``layers.<stage>.<block>....`` to a flat block index across stages; -1 otherwise."""
        if not key.startswith('layers.'):
            return -1
        try:
            parts = key.split('.')
            stage_idx = int(parts[1])
            block_idx_in_stage = int(parts[2])
            base_index = sum(self.blocks_per_stage[:stage_idx])
            return base_index + block_idx_in_stage
        except (IndexError, ValueError):
            return -1

    def get_personalized_exit_weights(self, exit_i: int, parameters: Parameters) -> NDArrays:
        """Return the weight list sent to / evaluated for ``exit_i``.

        The list is ordered like ``self.exit_local_sd_keys[exit_i]``. Shared keys come from
        ``parameters``; personalized keys use the exit's stored tensor, or the global value if
        the key has not been trained yet.
        """
        return self._exit_weights_from_sd(exit_i, self._aligned_global_sd(parameters))

    def _exit_weights_from_sd(self, exit_i: int, global_sd: Dict[str, np.ndarray]) -> NDArrays:
        """Same as get_personalized_exit_weights, but from an already-decoded ``{key: array}`` map.

        Raises:
            NotImplementedError: a local key is neither shared nor personalized.
        """
        shared = set(self.exit_shared_sd_keys[exit_i])
        personalized = self.exit_personalized_sd_values[exit_i]
        local_weights: List[np.ndarray] = []
        for sd_key in self.exit_local_sd_keys[exit_i]:
            if sd_key in shared:
                local_weights.append(global_sd[sd_key])
            elif sd_key in personalized:
                value = personalized[sd_key]
                local_weights.append(global_sd[sd_key] if value is None else value)
            else:
                raise NotImplementedError(f"Key {sd_key} not classified for exit {exit_i}")
        return local_weights

    def configure_fit(self, rnd: int, parameters: Parameters, client_manager: ClientManager) -> List[Tuple[ClientProxy, FitIns]]:
        """Sample clients and send each the personalized weights of its exit.

        Each client's config also carries ``keys_prog`` (the weight order), ``exit_i`` and ``blk_to_exit``.
        """
        config = {}
        if self.on_fit_config_fn is not None:
            config = self.on_fit_config_fn(rnd)

        sample_size, min_num_clients = self.num_fit_clients(client_manager.num_available())
        clients = client_manager.sample(num_clients=sample_size, min_num_clients=min_num_clients)

        global_sd: Optional[Dict[str, np.ndarray]] = None  # decoded lazily, once per round
        payload_per_exit: Dict[int, Parameters] = {}        # serialized once per exit
        client_instructions: List[Tuple[ClientProxy, FitIns]] = []

        for client in clients:
            exit_i = self.clients_exit[client.cid]
            if exit_i not in payload_per_exit:
                if global_sd is None:
                    global_sd = self._aligned_global_sd(parameters)
                payload_per_exit[exit_i] = weights_to_parameters(self._exit_weights_from_sd(exit_i, global_sd))

            cfg = dict(config)
            cfg["keys_prog"]   = self.exit_local_sd_keys[exit_i]
            cfg["exit_i"]      = int(exit_i)
            cfg["blk_to_exit"] = int(self.blks_to_exit[exit_i])
            client_instructions.append((client, FitIns(payload_per_exit[exit_i], cfg)))

        return client_instructions

    def _align_params_list_to_trainables(self, parameters: Parameters) -> List[np.ndarray]:
        """Return a list of ndarrays aligned to self.global_sd_keys (trainables).

        Works whether ``parameters`` came from a full state_dict or trainables only.

        Raises:
            AssertionError: the number of arrays matches neither layout, or a trainable key is missing.
        """
        arrs = parameters_to_weights(parameters)

        # Case A: already trainables only.
        if len(arrs) == len(self.global_sd_keys):
            return arrs

        # Case B: full state_dict (params + buffers) -> filter to our trainable order.
        if hasattr(self, "global_sd_all_keys") and len(arrs) == len(self.global_sd_all_keys):
            full_map = dict(zip(self.global_sd_all_keys, arrs))
            try:
                return [full_map[k] for k in self.global_sd_keys]
            except KeyError as e:
                raise AssertionError(f"[InclusiveFL] trainable key missing in full map: {e}") from e

        raise AssertionError(
            f"[InclusiveFL] Unexpected parameter length: got {len(arrs)}, "
            f"expected {len(self.global_sd_keys)} (trainables) "
            f"or {len(getattr(self, 'global_sd_all_keys', []))} (full)."
        )

    def _aligned_global_sd(self, parameters: Parameters) -> Dict[str, np.ndarray]:
        """Return {trainable_key: ndarray} aligned to trainables."""
        return dict(zip(self.global_sd_keys, self._align_params_list_to_trainables(parameters)))

    def _overlap_region(self, a_shape, b_shape):
        """Return a tuple of slices covering the elementwise overlap of two shapes (empty for 0-D)."""
        return tuple(slice(0, min(a, b)) for a, b in zip(a_shape, b_shape))

    def _make_grad_fullshape(self, initial_weight: np.ndarray, group_param: np.ndarray) -> np.ndarray:
        """Build a float32 pseudo-gradient ``initial - group`` with the shape of ``group_param``.

        Only the overlapping region is filled; the rest is zero. 0-D inputs yield a 0-D ndarray.
        """
        region = self._overlap_region(group_param.shape, initial_weight.shape)
        iw = np.asarray(initial_weight, dtype=np.float32)
        gp = np.asarray(group_param, dtype=np.float32)
        if len(region) == 0:
            # 0-D scalar case; keep it an ndarray rather than a numpy scalar.
            return np.asarray(iw - gp, dtype=np.float32)
        g = np.zeros(group_param.shape, dtype=np.float32)
        g[region] = iw[region] - gp[region]
        return g

    def _numeric_mean(self, seq) -> Optional[float]:
        """Return the mean over the numeric values in ``seq``; ``None`` if there are none.

        Array-likes contribute their own mean. Strings, dicts, ragged and empty sequences are skipped.
        """
        vals = []
        for x in seq:
            if isinstance(x, (int, float, np.integer, np.floating)):
                vals.append(float(x))
            elif isinstance(x, (list, tuple, np.ndarray)):
                try:
                    arr = np.asarray(x)
                except ValueError:
                    continue  # ragged nested sequence: not averageable
                if arr.size > 0 and np.issubdtype(arr.dtype, np.number):
                    vals.append(float(arr.mean()))
        return float(np.mean(vals)) if vals else None

    def aggregate_fit(
        self,
        rnd: int,
        results: List[Tuple[ClientProxy, FitRes]],
        failures: List[BaseException],
        current_parameters: Parameters,
        server=None,
    ) -> Tuple[Optional[Parameters], Dict[str, Scalar]]:
        """Aggregate one round of client updates.

        Exits that returned results are processed smallest exit first:
          1. Average the clients' float tensors, weighted by ``num_examples``. Entries that no
             client covered keep the value the exit was sent.
          2. Form the pseudo-gradient ``sent - averaged``. For non-largest exits, blend the
             personalized ``layers.*`` tensors with the next-larger exit's stored gradient (``beta``).
             That buffer is from its most recent aggregation, i.e. the previous round when both
             exits report in the same round.
          3. Apply the server update (FedAvg or FedAdam) to get the exit's new tensors.
          4. Fold shared tensors into the global model, weighted by the exit's total examples.
             Then store the personalized tensors and, for exits > 0, refresh this exit's momentum buffers.
        Non-float tensors are never averaged and keep their global value.

        Returns:
            New global trainable parameters and an empty metrics dict. Returns ``(None, {})`` when
            there are no results, or when there are failures and failures are not accepted.
        """
        if not results:
            return None, {}
        if not self.accept_failures and failures:
            return None, {}

        global_sd = self._aligned_global_sd(current_parameters)
        # Per-element weighted sums / weights of shared tensors contributed by each exit.
        global_sd_updates = {k: np.zeros_like(v) for k, v in global_sd.items()}
        global_sd_counts  = {k: np.zeros_like(v, dtype=np.float32) for k, v in global_sd.items()}

        exit_clients = defaultdict(list)
        for client, fit_res in results:
            exit_clients[self.clients_exit[client.cid]].append(fit_res)

        for exit_i in sorted(exit_clients):
            group_results = exit_clients[exit_i]
            local_keys = self.exit_local_sd_keys[exit_i]
            float_keys = [k for k in local_keys if self._is_float_param(global_sd[k])]  # int/bool skipped
            # The weights this exit was sent (computed from global_sd before it is modified below).
            local_sd = dict(zip(local_keys, self._exit_weights_from_sd(exit_i, global_sd)))
            group_total_examples = float(sum(fr.num_examples for fr in group_results)) or 1.0

            group_sd = self._average_group_params(exit_i, group_results, float_keys, global_sd, local_sd)
            group_sd_grad = {k: self._make_grad_fullshape(local_sd[k], group_sd[k]) for k in float_keys}
            if exit_i != self.max_exit:
                self._apply_momentum_guidance(exit_i, group_sd_grad)
            self._apply_server_update(exit_i, float_keys, local_sd, group_sd, group_sd_grad)
            self._fold_shared_into_global(
                exit_i, group_sd, group_total_examples, global_sd, global_sd_updates, global_sd_counts
            )

            # Keep personalized tensors; non-float ones stay as they were (global used at send time).
            personalized = self.exit_personalized_sd_values[exit_i]
            for sd_key in personalized:
                if sd_key in group_sd:
                    personalized[sd_key] = group_sd[sd_key]

            if exit_i != 0:
                self._update_momentum_buffers(exit_i, group_sd_grad)

        # Finalize the global model where updates were received (float shared keys only).
        for sd_key, counts in global_sd_counts.items():
            mask = counts > 0
            if np.any(mask):
                global_sd[sd_key][mask] = global_sd_updates[sd_key][mask] / counts[mask]

        train_summary = defaultdict(list)
        for _, fit_res in results:
            for m, v in fit_res.metrics.items():
                train_summary[m].append(v)
        for k, vlist in train_summary.items():
            mean_val = self._numeric_mean(vlist)
            if mean_val is not None:
                self.ckp.log({f'mean_{k}': mean_val}, step=rnd, commit=False)

        return weights_to_parameters(list(global_sd.values())), {}

    def _average_group_params(
        self,
        exit_i: int,
        group_results: List[FitRes],
        float_keys: List[str],
        global_sd: Dict[str, np.ndarray],
        local_sd: Dict[str, np.ndarray],
    ) -> Dict[str, np.ndarray]:
        """Example-weighted average (float32, global shapes) of the clients' float tensors for one exit.

        Entries no client covered (narrower client tensors, or an all-zero example count) keep the
        value from ``local_sd`` instead of becoming 0. This keeps their pseudo-gradient at zero,
        so they do not drag the global model toward zero.
        """
        keys = self.exit_local_sd_keys[exit_i]
        agg_params = {k: np.zeros(global_sd[k].shape, dtype=np.float32) for k in float_keys}
        cnt_params = {k: np.zeros(global_sd[k].shape, dtype=np.float32) for k in float_keys}

        for fit_res in group_results:
            weight = float(fit_res.num_examples)
            client_sd = dict(zip(keys, parameters_to_weights(fit_res.parameters)))
            for key in float_keys:
                arr = np.asarray(client_sd[key], dtype=np.float32)
                region = self._overlap_region(agg_params[key].shape, arr.shape)
                if len(region) == 0:
                    agg_params[key] += weight * arr
                    cnt_params[key] += weight
                else:
                    agg_params[key][region] += weight * arr[region]
                    cnt_params[key][region] += weight

        group_sd = {}
        for key in float_keys:
            avg_param = np.zeros_like(agg_params[key])
            start = np.asarray(local_sd[key], dtype=np.float32)
            if start.shape == avg_param.shape:
                avg_param[...] = start
            elif start.ndim == avg_param.ndim:
                region = self._overlap_region(avg_param.shape, start.shape)
                avg_param[region] = start[region]
            mask = cnt_params[key] > 0
            avg_param[mask] = agg_params[key][mask] / cnt_params[key][mask]
            group_sd[key] = avg_param
        return group_sd

    def _apply_momentum_guidance(self, exit_i: int, group_sd_grad: Dict[str, np.ndarray]) -> None:
        """Blend personalized ``layers.*`` gradients of ``exit_i`` with exit ``exit_i + 1``'s buffers (in place)."""
        next_momentum = self.exit_momentum_sd_values[exit_i + 1]
        for k in self.exit_personalized_sd_values[exit_i]:
            if k not in group_sd_grad or not k.startswith('layers.') or k not in next_momentum:
                continue
            mom_buf = np.asarray(next_momentum[k], dtype=np.float32)
            g = group_sd_grad[k]
            region = self._overlap_region(g.shape, mom_buf.shape)
            if len(region) == 0:
                group_sd_grad[k] = self.beta * mom_buf + (1.0 - self.beta) * g
            else:
                blended = g.copy()
                blended[region] = self.beta * mom_buf[region] + (1.0 - self.beta) * g[region]
                group_sd_grad[k] = blended

    def _assign_start_minus_step(self, group_sd: Dict[str, np.ndarray], k: str, start, step: np.ndarray) -> None:
        """Set ``group_sd[k] = start - step`` over the overlap of their shapes (in place when non-0-D)."""
        iw = np.asarray(start, dtype=np.float32)
        region = self._overlap_region(group_sd[k].shape, iw.shape)
        if len(region) == 0:
            group_sd[k] = iw - step
        else:
            group_sd[k][region] = iw[region] - step[region]

    def _apply_server_update(
        self,
        exit_i: int,
        float_keys: List[str],
        local_sd: Dict[str, np.ndarray],
        group_sd: Dict[str, np.ndarray],
        group_sd_grad: Dict[str, np.ndarray],
    ) -> None:
        """Overwrite ``group_sd`` with ``sent - step`` for every float key.

        FedAvg uses the (possibly distilled) pseudo-gradient as the step. FedAdam keeps per-exit
        moments and uses ``eta * m / (sqrt(v) + tau)``.

        Raises:
            ValueError: ``self.aggregation`` is neither ``'fedavg'`` nor ``'fedadam'``.
        """
        if self.aggregation == 'fedavg':
            for k in float_keys:
                self._assign_start_minus_step(group_sd, k, local_sd[k], group_sd_grad[k])
            return
        if self.aggregation != 'fedadam':
            raise ValueError(f"[InclusiveFL] Unsupported aggregation '{self.aggregation}'")

        args = self.aggregation_args
        beta_1, beta_2, tau, eta = args['beta_1'], args['beta_2'], args['tau'], args['eta']
        if exit_i not in self.m_t:
            self.m_t[exit_i] = {k: np.zeros_like(v, dtype=np.float32) for k, v in group_sd.items()}
            self.v_t[exit_i] = {k: np.zeros_like(v, dtype=np.float32) for k, v in group_sd.items()}
        m_t, v_t = self.m_t[exit_i], self.v_t[exit_i]
        for k in float_keys:
            g = group_sd_grad[k]
            m_t[k] = beta_1 * m_t[k] + (1.0 - beta_1) * g
            v_t[k] = beta_2 * v_t[k] + (1.0 - beta_2) * (g * g)
            step = eta * (m_t[k] / (np.sqrt(v_t[k]) + tau))
            self._assign_start_minus_step(group_sd, k, local_sd[k], step)

    def _fold_shared_into_global(
        self,
        exit_i: int,
        group_sd: Dict[str, np.ndarray],
        weight: float,
        global_sd: Dict[str, np.ndarray],
        global_sd_updates: Dict[str, np.ndarray],
        global_sd_counts: Dict[str, np.ndarray],
    ) -> None:
        """Accumulate this exit's shared float tensors into the global sums, weighted by ``weight``."""
        for sd_key in self.exit_shared_sd_keys[exit_i]:
            if sd_key not in group_sd:
                continue  # non-float key: never averaged
            value = group_sd[sd_key]
            region = self._overlap_region(global_sd[sd_key].shape, value.shape)
            if len(region) == 0:
                global_sd_updates[sd_key] += value * weight
                global_sd_counts[sd_key]  += weight
            else:
                global_sd_updates[sd_key][region] += value[region] * weight
                global_sd_counts[sd_key][region]  += weight

    def _update_momentum_buffers(self, exit_i: int, group_sd_grad: Dict[str, np.ndarray]) -> None:
        """Overwrite exit ``exit_i``'s momentum buffers with this round's gradients of its momentum blocks.

        Entries that received no gradient are reset to zero.
        """
        buffers = self.exit_momentum_sd_values[exit_i]
        block_ids = self.exit_momentum_sd_block_ids[exit_i]
        agg_momentum = {key: np.zeros_like(val, dtype=np.float32) for key, val in buffers.items()}
        cnt_momentum = {key: np.zeros_like(val, dtype=np.float32) for key, val in buffers.items()}
        for full_key, grad_value in group_sd_grad.items():
            if full_key not in agg_momentum or self.get_global_block_id(full_key) not in block_ids:
                continue
            region = self._overlap_region(agg_momentum[full_key].shape, grad_value.shape)
            if len(region) == 0:
                agg_momentum[full_key] += grad_value.astype(np.float32, copy=False)
                cnt_momentum[full_key] += 1.0
            else:
                agg_momentum[full_key][region] += grad_value[region].astype(np.float32, copy=False)
                cnt_momentum[full_key][region] += 1.0
        for key, buf in buffers.items():
            mask = cnt_momentum[key] > 0
            buf.fill(0)
            buf[mask] = agg_momentum[key][mask] / cnt_momentum[key][mask]

    def _is_float_param(self, x) -> bool:
        """Return True if array/dtype is floating (safe to average)."""
        dt = x.dtype if hasattr(x, "dtype") else np.dtype(x)
        return np.issubdtype(dt, np.floating)

    def evaluate(self, parameters: Parameters, partition: str = 'test') -> Optional[Tuple[float, Dict[str, Scalar]]]:
        """Centrally evaluate every exit on its personalized weights and average over exits.

        ``eval_fn`` is called as ``eval_fn(weights, partition, exit_i, blk_to_exit, keys)`` and its
        result is unpacked as ``(loss, metrics)``. An exit contributes to the mean only if it returns
        both ``centralized_{partition}_exit{i}_loss`` and ``..._acc``. Exits that return ``None`` or
        lack these keys are excluded rather than counted as 0.

        Returns:
            ``None`` without ``eval_fn``; otherwise ``(mean_loss, metrics)``, where ``metrics`` also
            holds ``centralized_{partition}_exit_all_loss`` / ``_acc``.
        """
        if self.eval_fn is None:
            return None

        global_sd = self._aligned_global_sd(parameters) if self.no_of_exits > 0 else {}
        loss_sum, acc_sum, reported = 0.0, 0.0, 0
        metrics = {}
        for exit_i in range(self.no_of_exits):
            # Same key order the server uses to build the fit payload.
            local_weights = self._exit_weights_from_sd(exit_i, global_sd)
            local_keys = self.exit_local_sd_keys[exit_i]
            result = self.eval_fn(local_weights, partition, exit_i, self.blks_to_exit[exit_i], local_keys)
            if result is None:
                continue
            _, exit_metrics = result
            exit_metrics = exit_metrics or {}

            loss_key = f'centralized_{partition}_exit{exit_i}_loss'
            acc_key  = f'centralized_{partition}_exit{exit_i}_acc'
            if loss_key in exit_metrics and acc_key in exit_metrics:
                loss_sum += exit_metrics[loss_key]
                acc_sum  += exit_metrics[acc_key]
                metrics.update(exit_metrics)
                reported += 1

        mean_loss = loss_sum / reported if reported else 0.0
        mean_acc  = acc_sum / reported if reported else 0.0
        metrics[f"centralized_{partition}_exit_all_loss"] = mean_loss
        metrics[f"centralized_{partition}_exit_all_acc"]  = mean_acc
        return mean_loss, metrics

    def configure_evaluate(self, rnd: int, parameters: Parameters, client_manager: ClientManager) -> List[Tuple[ClientProxy, EvaluateIns]]:
        """Send the global parameters to clients for evaluation.

        For ``rnd >= 0``, clients are sampled through the client manager; for negative ``rnd``,
        every registered client is used.
        """
        config = {}
        if self.on_evaluate_config_fn is not None:
            config = self.on_evaluate_config_fn(rnd)
        evaluate_ins = EvaluateIns(parameters, config)

        if rnd >= 0:
            sample_size, min_num_clients = self.num_evaluation_clients(client_manager.num_available())
            clients = client_manager.sample(num_clients=sample_size, min_num_clients=min_num_clients, all_available=True)
        else:
            clients = list(client_manager.all().values())
        return [(client, evaluate_ins) for client in clients]

    def aggregate_evaluate(self, rnd: int, results: List[Tuple[ClientProxy, EvaluateRes]], failures: List[BaseException]) -> Tuple[Optional[float], Dict[str, Scalar]]:
        """Combine client evaluation results via ``weighted_loss_avg``.

        Groups come from ``self.test_alpha`` when that attribute is set.
        """
        if not results:
            return None, {}
        if not self.accept_failures and failures:
            return None, {}

        client_results = {
            client.cid: (evaluate_res.num_examples, evaluate_res.loss, evaluate_res.metrics)
            for client, evaluate_res in results
        }
        loss_aggregated, accuracy_results = weighted_loss_avg(client_results, getattr(self, "test_alpha", None))
        return loss_aggregated, accuracy_results

def weighted_loss_avg(results: Dict[str, Tuple[int, float, Optional[Dict[str, float]]]], personalized_fl_groups: Optional[Dict[str, int]] = None) -> Tuple[float, Dict[str, float]]:
    """Aggregate client evaluation results into an example-weighted loss and a metric summary.

    Args:
        results: ``{cid: (num_examples, loss, metrics)}``. Metric keys containing ``'test_acc'`` or
            ``'accuracy'`` are treated as accuracies: weighted by example count and reported in
            percent as ``ps_<key>``. Any other metric is averaged uniformly over the clients that
            reported it, as ``mean_<key>``.
        personalized_fl_groups: optional ``{group_label: client_count}`` (insertion order matters).
            When it has more than one group, consecutive client ids starting at 0 are assigned to
            the groups in order. The same summaries are then also reported per group, with the
            suffix ``_alpha<group>(<count> clients)``.

    Returns:
        ``(weighted_loss, metrics)``. ``(0.0, {})`` when ``results`` is empty or holds no examples.
    """
    if not results:
        return 0.0, {}

    def _is_accuracy(key: str) -> bool:
        # Accuracy-like metrics are example-weighted; everything else is a plain mean.
        return 'test_acc' in key or 'accuracy' in key

    accuracy_results = {}
    if personalized_fl_groups is not None and len(personalized_fl_groups) > 1:
        from_id = 0
        for group, to_id in personalized_fl_groups.items():
            # Normalize once: used for both the client-id range and the running offset.
            group_size = int(to_id)
            group_examples = 0
            metric_sums = defaultdict(float)
            metric_reporters = defaultdict(int)
            for cid in range(from_id, from_id + group_size):
                entry = results.get(str(cid))
                if entry is None:
                    continue
                num_examples, _, metrics = entry
                group_examples += num_examples
                for k, value in (metrics or {}).items():
                    if _is_accuracy(k):
                        metric_sums[k] += num_examples * value
                    else:
                        metric_sums[k] += value
                    metric_reporters[k] += 1
            from_id += group_size

            suffix = f'_alpha{group}({to_id} clients)'
            for k, total in metric_sums.items():
                if _is_accuracy(k):
                    accuracy_results[f'ps_{k}{suffix}'] = total / group_examples * 100 if group_examples > 0 else 0
                else:
                    # Mean over the clients that actually reported, matching the global `mean_` below.
                    accuracy_results[f'mean_{k}{suffix}'] = total / metric_reporters[k]

    num_total_evaluation_examples = sum(num_examples for num_examples, _, _ in results.values())
    if num_total_evaluation_examples == 0:
        return 0.0, {}

    weighted_loss_sum = sum(num_examples * loss for num_examples, loss, _ in results.values())
    per_metric = defaultdict(list)
    for num_examples, _, metrics in results.values():
        for k, value in (metrics or {}).items():
            per_metric[k].append(num_examples * value if _is_accuracy(k) else value)

    for k, values in per_metric.items():
        if _is_accuracy(k):
            accuracy_results[f'ps_{k}'] = sum(values) / num_total_evaluation_examples * 100
        else:
            accuracy_results[f'mean_{k}'] = np.mean(values)

    return weighted_loss_sum / num_total_evaluation_examples, accuracy_results
