import numpy as np
import copy
from flwr.server.strategy import FedAvg as FlowerFedAvg
from flwr.server.client_manager import ClientManager
from src.utils import get_func_from_config
from typing import Dict, Optional, Tuple, List, Any
from flwr.server.client_proxy import ClientProxy
from collections import defaultdict
from flwr.common import (
    EvaluateIns,
    EvaluateRes,
    Parameters,
    Weights,
    Scalar,
    FitRes,
    FitIns,
    parameters_to_weights,
    weights_to_parameters,
)

import logging
logger = logging.getLogger(__name__)

class EarlyExitFedAvg(FlowerFedAvg):
    """
    FedAvg strategy for heterogeneous early-exit client models.

    Every client is statically assigned one exit of a multi-exit network. The
    server keeps a full-depth, full-width reference parameter list (trainable
    parameters only, ordered as ``global_net.trainable_state_dict_keys``),
    sends each client the subset of arrays its exit model owns, and aggregates
    the replies element-wise: each reply is added into the leading slice of the
    matching global tensor, so parameters touched by only some clients (deeper
    blocks, or the leading part of a wider tensor) are averaged over exactly the
    clients that returned them.
    """

    def __init__(self, ckp, client_valuation=None, *args, **kwargs):
        """Build per-exit parameter-key maps and the client -> exit assignment.

        Args:
            ckp: Experiment/checkpoint object; only ``ckp.config`` is read here.
            client_valuation: Accepted for interface compatibility; not used
                by this class.
            *args, **kwargs: Forwarded unchanged to Flower's ``FedAvg``.
        """
        super().__init__(*args, **kwargs)
        self.ckp = ckp
        self.config = ckp.config
        self.net_config = self.config.models.net
        arch_fn = get_func_from_config(self.net_config)

        # Full-depth, full-width reference model: the single source of truth for
        # the order of the flat parameter list exchanged through Flower.
        global_net_args = copy.deepcopy(self.net_config.args)
        for reduced_arg in ('depth', 'width_scale'):
            if reduced_arg in global_net_args:
                del global_net_args[reduced_arg]
        global_net = arch_fn(device='cpu', **global_net_args)

        # Only trainable parameters are communicated through Flower, so the key
        # lists must match that flat list position-for-position.
        self.global_sd_keys = global_net.trainable_state_dict_keys
        self.blks_to_exit = global_net.blks_to_exit
        self.no_of_exits = len(global_net.exit_heads)
        no_of_clients = self.config.simulation.num_clients

        # Keys of each truncated model: depth stops right after the exit's block
        # and only the first exit_i + 1 exit locations are kept.
        # NOTE(review): 'width_scale' is kept for local models, so their tensors
        # may be narrower than the global ones; only the key lists are used here.
        self.exit_local_sd_keys = {}
        for exit_i in range(self.no_of_exits):
            net_args = copy.deepcopy(self.net_config.args)
            net_args['depth'] = self.blks_to_exit[exit_i] + 1
            if 'ee_layer_locations' in net_args:
                net_args['ee_layer_locations'] = net_args['ee_layer_locations'][:exit_i + 1]
            local_net = arch_fn(device='cpu', **net_args)
            self.exit_local_sd_keys[exit_i] = local_net.trainable_state_dict_keys

        # Clients are keyed by string cid '0'..'N-1'. In 'maximum' mode every
        # client gets the deepest exit; otherwise exits are assigned round-robin.
        self.clients_exit = {}
        if no_of_clients > 0:
            use_max_exit = self.config.app.args.mode == 'maximum'
            deepest_exit = self.no_of_exits - 1
            for i in range(no_of_clients):
                self.clients_exit[str(i)] = deepest_exit if use_max_exit else i % self.no_of_exits

    def _global_state_dict(self, parameters: Parameters) -> Dict[str, np.ndarray]:
        """Map the flat Flower parameter list onto ``self.global_sd_keys`` by position.

        A length mismatch is logged rather than raised, because ``zip``
        truncation was the historical behavior; entries beyond the shorter side
        are ignored.
        """
        weights = parameters_to_weights(parameters)
        keys = list(self.global_sd_keys)
        if len(weights) != len(keys):
            logger.warning(
                "Got %d parameter arrays but the reference model has %d trainable keys; "
                "entries beyond the shorter side are ignored.",
                len(weights), len(keys),
            )
        return dict(zip(keys, weights))

    def _local_weights(self, global_sd: Dict[str, np.ndarray], exit_i: int) -> List[np.ndarray]:
        """Return the global arrays for exit ``exit_i``'s model, in that model's key order.

        Raises:
            KeyError: If the exit model needs keys absent from ``global_sd``.
        """
        local_keys = list(self.exit_local_sd_keys[exit_i])
        missing = [k for k in local_keys if k not in global_sd]
        if missing:
            raise KeyError(
                f"Exit {exit_i} model expects keys absent from the global parameters: {missing}"
            )
        return [global_sd[k] for k in local_keys]

    def configure_fit(self, rnd: int, parameters: Parameters, client_manager: ClientManager) -> List[Tuple[ClientProxy, FitIns]]:
        """Sample clients and send each the parameter subset for its assigned exit.

        Args:
            rnd: Current server round; passed to ``on_fit_config_fn`` if set.
            parameters: Global parameters, ordered as ``self.global_sd_keys``.
            client_manager: Flower client manager used for sampling.

        Returns:
            One ``(client, FitIns)`` pair per sampled client.

        Raises:
            KeyError: If a client's cid has no exit assignment, or its exit model
                needs a key missing from ``parameters``.
        """
        config = {} if self.on_fit_config_fn is None else self.on_fit_config_fn(rnd)

        sample_size, min_num_clients = self.num_fit_clients(client_manager.num_available())
        clients = client_manager.sample(num_clients=sample_size, min_num_clients=min_num_clients)

        global_sd = self._global_state_dict(parameters)

        client_instructions = []
        for client in clients:
            exit_i = self.clients_exit[client.cid]
            # NOTE(review): full-width global tensors are sent even when the local
            # model uses a smaller width_scale; aggregate_fit accepts narrower
            # replies, but confirm that clients slice on receipt.
            local_weights = self._local_weights(global_sd, exit_i)
            client_instructions.append((client, FitIns(weights_to_parameters(local_weights), config)))
        return client_instructions

    def aggregate_fit(self, rnd: int, results: List[Tuple[ClientProxy, FitRes]], failures: List[BaseException], current_parameters: Parameters, server=None) -> Tuple[Optional[Parameters], Dict[str, Scalar]]:
        """Element-wise, example-weighted average of heterogeneous client replies.

        Each returned array is added into the leading slice ``[:s0, :s1, ...]``
        of the matching global tensor, weighted by ``fit_res.num_examples``.
        Every element is then divided by the total weight that touched it.
        Elements no client touched keep their value from ``current_parameters``.

        Args:
            rnd: Current server round (unused).
            results: Successful ``(client, FitRes)`` pairs.
            failures: Exceptions from failed clients.
            current_parameters: Global parameters sent this round, ordered as
                ``self.global_sd_keys``.
            server: Unused; accepted for interface compatibility.

        Returns:
            ``(new_parameters, {})``, or ``(None, {})`` when there are no results,
            or when failures occurred and ``accept_failures`` is False.

        Raises:
            ValueError: If a client returns a different number of arrays than its
                exit model has trainable keys.
        """
        if not results:
            return None, {}
        if not self.accept_failures and failures:
            return None, {}

        global_sd = self._global_state_dict(current_parameters)
        # Accumulate in float64: weighted sums and example counts can grow large
        # enough for float32 to lose precision (counts are exact only up to 2**24).
        agg_updates = {key: np.zeros_like(val, dtype=np.float64) for key, val in global_sd.items()}
        agg_counts = {key: np.zeros_like(val, dtype=np.float64) for key, val in global_sd.items()}

        for client, fit_res in results:
            local_keys = list(self.exit_local_sd_keys[self.clients_exit[client.cid]])
            client_weights = parameters_to_weights(fit_res.parameters)
            if len(client_weights) != len(local_keys):
                raise ValueError(
                    f"Client {client.cid} returned {len(client_weights)} arrays, "
                    f"expected {len(local_keys)} for its exit model."
                )
            weight = fit_res.num_examples

            for key, arr in zip(local_keys, client_weights):
                # Narrower client tensors cover only the leading slice of each axis.
                region = tuple(slice(0, dim) for dim in arr.shape)
                agg_updates[key][region] += weight * np.asarray(arr, dtype=np.float64)
                agg_counts[key][region] += weight

        new_weights = []
        for key, old_val in global_sd.items():
            counts = agg_counts[key]
            touched = counts > 0
            new_val = np.copy(old_val)
            # Assignment casts back to the original parameter dtype.
            new_val[touched] = agg_updates[key][touched] / counts[touched]
            new_weights.append(new_val)

        return weights_to_parameters(new_weights), {}

    def evaluate(self, parameters: Parameters, partition: str = 'test') -> Optional[Tuple[float, Dict[str, Scalar]]]:
        """Evaluate every exit model with ``eval_fn`` and average results across exits.

        Args:
            parameters: Global parameters, ordered as ``self.global_sd_keys``.
            partition: Data partition name forwarded to ``eval_fn``.

        Returns:
            ``None`` if no ``eval_fn`` is configured. Otherwise
            ``(mean_loss, metrics)``, where each metric is the mean over the
            exits that reported it.
        """
        if self.eval_fn is None:
            return None

        global_sd = self._global_state_dict(parameters)

        all_metrics = defaultdict(list)
        total_loss = 0.0

        for exit_i in range(self.no_of_exits):
            local_keys = self.exit_local_sd_keys[exit_i]
            # Missing keys are skipped here (configure_fit raises instead). Log
            # them, because dropping entries shifts the positional order that
            # eval_fn receives.
            missing = [k for k in local_keys if k not in global_sd]
            if missing:
                logger.warning("Evaluating exit %d without missing keys: %s", exit_i, missing)
            local_weights = [global_sd[k] for k in local_keys if k in global_sd]

            loss, metrics = self.eval_fn(local_weights, partition, exit_i, self.blks_to_exit[exit_i])
            total_loss += loss
            for name, value in metrics.items():
                all_metrics[name].append(value)

        final_metrics = {name: np.mean(values) for name, values in all_metrics.items()}
        avg_loss = total_loss / self.no_of_exits if self.no_of_exits > 0 else 0

        return avg_loss, final_metrics
