from __future__ import annotations

import logging
from typing import Dict, List, Optional, Tuple

import numpy as np
from flwr.common import FitIns, FitRes, Parameters, Scalar, weights_to_parameters
from flwr.server.client_proxy import ClientProxy

from src.server.strategies import EarlyExitFedAvgSC
from .scalefl_fedavg import ScaleFLFedAvg  # reuse BN-aware prefix aggregation etc.

logger = logging.getLogger(__name__)

class HeteroFLFedAvg(ScaleFLFedAvg):
    """
    HeteroFL = ScaleFL without SNIP.

    - Pure width/prefix aggregation (params + BN buffers), inherited from
      ``ScaleFLFedAvg``.
    - No SNIP masks and no BN-recalibration triggers are sent from the server.
    """

    # Config keys that may be attached to a client's fit config by the
    # ScaleFL/SNIP path; HeteroFL removes them so clients never see them.
    _STRIPPED_CONFIG_KEYS = ("snip_mask", "need_bn_recal", "bn_calib_source")

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        # Force pure width ("scale") mode on every exit and reset any SNIP
        # state the ScaleFL constructor may have initialised.
        self.pruning_mode = ["scale"] * self.no_of_exits
        self._snip_any = False
        self._central_snip_enabled = False
        self._snip_masks_by_exit = {i: {} for i in range(self.no_of_exits)}

    def configure_fit(self, rnd: int, parameters: Parameters, client_manager):
        """Build per-client fit instructions carrying width-sliced weights.

        Client sampling and the base config come from ``EarlyExitFedAvgSC``.
        The ScaleFL override is bypassed on purpose so no SNIP work runs.
        Each sampled client then gets its own ``FitIns``. Its parameters are
        the prefix/width slice of the global model for that client's exit,
        and its config has SNIP/BN-recal keys removed.

        Args:
            rnd: Current server round.
            parameters: Parameters supplied by the Flower server. Used only as
                a fallback when ``self._last_global_params`` is missing or
                ``None``; otherwise the strategy's stored global model is used.
            client_manager: Flower client manager used for sampling.

        Returns:
            A list of ``(client, FitIns)`` pairs. If the base call returns a
            falsy value (e.g. no clients sampled), that value is returned as-is.
        """
        # Slice from the strategy's stored global model, as before. Fall back
        # to the server-supplied ``parameters`` only when nothing is stored;
        # otherwise ``None`` would be handed to the base call and the slicer.
        global_params = getattr(self, "_last_global_params", None)
        if global_params is None:
            global_params = parameters

        # Call the SC base directly so we don't do any SNIP work.
        base = EarlyExitFedAvgSC.configure_fit(self, rnd, global_params, client_manager)
        if not base:
            return base

        patched = []
        for client, fitins in base:
            exit_i = self.clients_exit[client.cid]

            # Copy the config, dropping any stray SNIP/BN flags.
            config = {
                key: value
                for key, value in (fitins.config or {}).items()
                if key not in self._STRIPPED_CONFIG_KEYS
            }

            # Personalize the payload to this exit (prefix/width indexing).
            local_weights = self.get_personalized_exit_weights(exit_i, global_params)

            # Build a fresh FitIns per client instead of mutating the one from
            # ``base``. The base strategy may hand the *same* FitIns object to
            # every client (Flower's FedAvg does). In that case, in-place edits
            # would leave all clients holding the last client's width slice.
            patched.append(
                (
                    client,
                    FitIns(parameters=weights_to_parameters(local_weights), config=config),
                )
            )

            logger.info(
                "[HeteroFL][cfg][rnd %s] cid=%s exit=%s width_scale=%.3f",
                rnd,
                client.cid,
                exit_i,
                float(self.width_scaling[exit_i]),
            )

        logger.info(
            "[HeteroFL][rnd %s] SNIP disabled; BN-aware prefix aggregation active.", rnd
        )
        return patched
