from typing import Dict, List, Tuple, Any, Optional
import numpy as np
from flwr.server.client_proxy import ClientProxy
from flwr.common import FitRes, Parameters
import torch
import numpy as np 
from io import BytesIO

from functools import reduce
from typing import List, Tuple, Any, Dict, cast

import numpy as np

from flwr.common import FitRes, Parameters
from flwr.server.client_proxy import ClientProxy
import numpy.typing as npt

NDArray = npt.NDArray[Any]
NDArrays = List[NDArray]

# src/server/strategies/utils/central_snip.py
from collections import Counter
import copy
import torch.nn as nn

from src.utils import get_func_from_config
from src.models.snip_utils import compute_snip_channel_scores, create_resnet_channel_masks

def _inspect_loader(loader, tag="[SERVER][VAL] "):
    """Print a best-effort summary of ``loader`` and of its first batch.

    Two lines are printed, each prefixed with ``tag``:

    * the dataset class name, ``len(dataset)`` and ``len(loader)``
      (``"UnknownDataset"`` / ``"?"`` when unavailable);
    * the shape of the first batch's inputs and a histogram of its labels
      (first 10 labels in sorted order).

    The first batch is expected to unpack as ``(inputs, labels)`` with
    tensor-like labels. Any exception is caught and reported as
    ``"<tag>inspect failed: <error>"`` instead of propagating, so this is
    purely diagnostic.
    """
    try:
        dataset = getattr(loader, "dataset", None)
        if dataset is None:
            ds_name, ds_size = "UnknownDataset", "?"
        else:
            ds_name, ds_size = type(dataset).__name__, len(dataset)
        batch_count = len(loader) if hasattr(loader, "__len__") else "?"
        print(f"{tag}dataset={ds_name} size={ds_size} batches={batch_count}")

        inputs, labels = next(iter(loader))
        label_counts = Counter(labels.detach().cpu().numpy().ravel().tolist())
        shown = sorted(label_counts.items())[:10]
        top = ", ".join(f"{label}:{count}" for label, count in shown)
        print(f"{tag}first_batch: shape={tuple(inputs.shape)} labels({len(label_counts)} uniq)={{ {top} }}")
    except Exception as e:
        print(f"{tag}inspect failed: {e}")

def build_global_val_loader(ckp):
    """Create the server's global validation DataLoader (no client id).

    The dataset class is resolved from ``ckp.config.data`` via
    ``get_func_from_config`` and built as ``cls(ckp, **data_cfg.args)``.
    The loader reads the "val" pool/partition with ``cid=None``, no
    augmentation and ``num_workers=0``. Its batch size is
    ``ckp.config.app.eval_fn.batch_size`` when that attribute exists, else 64.
    A short summary is printed via ``_inspect_loader`` before returning.
    """
    data_cfg = ckp.config.data
    dataset = get_func_from_config(data_cfg)(ckp, **data_cfg.args)
    batch_size = getattr(ckp.config.app.eval_fn, "batch_size", 64)

    val_loader = dataset.get_dataloader(
        data_pool="val",  # global VAL pool
        partition="val",
        cid=None,
        batch_size=batch_size,
        augment=False,
        num_workers=0,
    )
    _inspect_loader(val_loader, tag="[SERVER][VAL] ")
    return val_loader

def compute_masks_for_exits(
    base_model: nn.Module,
    keep_ratios: List[float],
    valloader,
    device: torch.device,
    snip_batches: int = 5,
) -> Dict[int, Dict[str, List[int]]]:
    """Build one SNIP channel mask per early exit using the central VAL loader.

    For exit index ``lid`` the mask keeps ``keep_ratios[lid]`` of the channels,
    scored on a fresh deep copy of ``base_model`` (put in train mode on
    ``device``). At most ``snip_batches`` batches are used; if ``valloader``
    has a length, the count is also capped at ``len(valloader)``.

    Returns:
        ``{lid: {layer_name: [kept channel indices as Python ints]}}``, which
        is JSON-serializable.
    """
    masks_by_exit: Dict[int, Dict[str, List[int]]] = {}

    for exit_idx, keep_ratio in enumerate(keep_ratios):
        prefix = f"[SERVER][SNIP][exit {exit_idx}] "
        # IMPORTANT: score a *fresh* full-width copy every time (never a pruned model).
        scoring_model = copy.deepcopy(base_model).to(device)
        scoring_model.train()

        # Number of batches to score on (loader length may be unknown).
        try:
            total_batches = len(valloader)
        except TypeError:
            total_batches = None
        if total_batches is None:
            batches_used = snip_batches
            total_label = "?"
        else:
            batches_used = min(snip_batches, total_batches)
            total_label = total_batches
        print(f"{prefix}VAL scoring | keep_ratio={keep_ratio:.3f} | batches_used={batches_used}/{total_label}")

        channel_scores = compute_snip_channel_scores(
            model=scoring_model,
            onebatch_or_loader=valloader,
            loss_fn=torch.nn.CrossEntropyLoss(),
            device=str(device),
            num_batches=batches_used,
            max_per_batch=256,
            log_prefix=prefix,  # helper logs grad-accumulation stats with this prefix
        )
        raw_masks = create_resnet_channel_masks(
            channel_scores, keep_ratio=float(keep_ratio), model=scoring_model
        )

        kept_total = sum(len(idxs) for idxs in raw_masks.values())
        print(f"{prefix}built mask: layers={len(raw_masks)}, kept_channels_total={kept_total}")
        # Plain Python ints so the result is JSON-serializable.
        masks_by_exit[exit_idx] = {name: list(map(int, idxs)) for name, idxs in raw_masks.items()}

    return masks_by_exit

def softmax(x, T=1.0):
    """Return the temperature-scaled softmax of ``x`` along axis 0.

    Computes ``exp(x / T) / sum(exp(x / T), axis=0)``. The maximum along axis 0
    is subtracted before exponentiating. This cancels out mathematically, but
    keeps ``np.exp`` from overflowing to ``inf`` (which used to turn the result
    into NaN) for large scores or small temperatures.

    Args:
        x: Array-like of scores. Normalisation is over axis 0, so a 2-D input
            yields one distribution per column.
        T: Temperature; larger values flatten the distribution.

    Returns:
        ndarray with the shape of ``x`` whose entries along axis 0 sum to 1.
    """
    scaled = np.asarray(x) / T
    # Stability shift; ``initial`` keeps empty inputs from raising in np.max.
    shifted = scaled - np.max(scaled, axis=0, keepdims=True, initial=-np.inf)
    exps = np.exp(shifted)
    return exps / np.sum(exps, axis=0, keepdims=True)

def get_layer(model, layer_name):
    """Resolve a dotted learnable-parameter name on ``model``.

    Every dotted component, including the trailing ``weight``/``bias``, is
    looked up with ``getattr``. The returned object is therefore the value of
    the final attribute (e.g. the ``weight`` tensor), not the sub-module that
    owns it.

    Args:
        model: Root object to resolve the path against.
        layer_name: Dotted path such as ``"layer1.0.conv1.weight"``.

    Raises:
        ValueError: If ``layer_name`` does not end with ``.weight`` or ``.bias``.
        AttributeError: If any component of the path does not exist.
    """
    # Explicit check instead of ``assert`` so validation survives ``python -O``.
    if not layer_name.endswith((".weight", ".bias")):
        raise ValueError(
            f"layer name must be learnable (end with .weight or .bias), got {layer_name!r}"
        )
    obj = model
    for attrib in layer_name.split("."):
        obj = getattr(obj, attrib)
    return obj

def bytes_to_ndarray(tensor: bytes) -> NDArray:
    """Decode a byte string written by ``np.save`` back into an ndarray.

    Pickled object arrays are refused (``allow_pickle=False``), so a
    payload that requires unpickling raises instead of being loaded.
    """
    # WARNING: NEVER set allow_pickle to true.
    # Reason: loading pickled data can execute arbitrary code
    # Source: https://numpy.org/doc/stable/reference/generated/numpy.load.html
    loaded = np.load(BytesIO(tensor), allow_pickle=False)  # type: ignore
    return cast(NDArray, loaded)

def parameters_to_ndarrays(parameters: Parameters) -> NDArrays:
    """Decode every serialized tensor in ``parameters.tensors``, preserving order."""
    return list(map(bytes_to_ndarray, parameters.tensors))

def aggregate_inplace(results: List[Tuple[ClientProxy, FitRes]]) -> NDArrays:
    """Return the ``num_examples``-weighted average of all client parameters.

    Every client payload must contain the same number of arrays with matching
    shapes. The first client seeds the running sum and the rest are added one
    at a time, so only one extra payload is held in memory at once.
    """
    total_examples = sum(fit_res.num_examples for _, fit_res in results)
    weights = [fit_res.num_examples / total_examples for _, fit_res in results]

    # Seed the running sum with the first client's scaled parameters.
    first_layers = parameters_to_ndarrays(results[0][1].parameters)
    params = [weights[0] * layer for layer in first_layers]

    # Fold in the remaining clients one by one.
    for factor, (_, fit_res) in zip(weights[1:], results[1:]):
        client_layers = parameters_to_ndarrays(fit_res.parameters)
        params = [np.add(acc, factor * layer) for acc, layer in zip(params, client_layers)]

    return params

def aggregate_inplace_early_exit(
    global_sd: Dict[str, NDArray],
    clients_local_sd_keys: Dict[str, List[str]],
    results: List[Tuple[ClientProxy, FitRes]],
    rnd: int = -1,
    debug_first: int = 20,
) -> List[NDArray]:
    """Sample-weighted average of client updates that may cover only part of the model.

    Each client's payload is matched, in order, to ``clients_local_sd_keys[cid]``.
    For every matched key present in ``global_sd`` the client array is added,
    weighted by ``num_examples``, into the leading overlapping region
    (``[:min(local, global)]`` on every axis). Entries no client touched keep
    their value from ``global_sd``.

    A client tensor is skipped (contributes nothing) when:
      * its key is not in ``global_sd``;
      * either the client or the global array is 0-d;
      * the client and global arrays differ in rank. A prefix slice is not well
        defined there, and broadcasting would either raise or silently spread
        values across the wrong axes.

    Args:
        global_sd: Current global arrays keyed by name. Its order defines the
            order of the returned list.
        clients_local_sd_keys: ``cid -> ordered keys`` for each client payload.
            Clients absent from this mapping contribute nothing.
        results: ``(client, fit_res)`` pairs for the round.
        rnd: Round number; when ``>= 0`` a short summary is printed.
        debug_first: Unused; kept for signature compatibility.

    Returns:
        One array per ``global_sd`` key, in ``global_sd`` order. Averages are
        written into a copy of the global array and take its dtype.
    """
    agg: Dict[str, np.ndarray] = {k: np.zeros_like(v, dtype=np.float32) for k, v in global_sd.items()}
    cnt: Dict[str, np.ndarray] = {k: np.zeros_like(v, dtype=np.float32) for k, v in global_sd.items()}

    mismatches = 0
    for client, fit_res in results:
        w = float(fit_res.num_examples)
        local_keys = clients_local_sd_keys.get(client.cid, [])
        arrays = parameters_to_ndarrays(fit_res.parameters)

        # zip() stops at the shorter of keys/arrays, dropping any unmatched tail.
        for k, arr in zip(local_keys, arrays):
            if k not in global_sd:
                continue
            g_arr = agg[k]
            arr = np.asarray(arr, dtype=g_arr.dtype)
            if arr.ndim == 0 or g_arr.ndim == 0:
                continue
            if arr.shape != g_arr.shape:
                mismatches += 1
            if arr.ndim != g_arr.ndim:
                # Rank mismatch: no valid overlap region (likely a key-mapping issue).
                continue
            sl = tuple(slice(0, min(a, g)) for a, g in zip(arr.shape, g_arr.shape))
            g_arr[sl] += w * arr[sl]
            cnt[k][sl] += w

    out: List[NDArray] = []
    for k in global_sd.keys():
        updated = np.array(global_sd[k], copy=True)
        mask = cnt[k] > 0
        if mask.any():
            updated[mask] = agg[k][mask] / cnt[k][mask]
        out.append(updated)

    if rnd >= 0:
        print(f"[aggregate_inplace_early_exit] per-element mismatches handled: {mismatches}")
        upd_frac = [(c > 0).mean() for c in cnt.values()]
        # np.min/np.max raise on an empty list, e.g. when global_sd is empty.
        if upd_frac:
            print(f"[RND {rnd}] avg-updated-frac={np.mean(upd_frac):.3f} "
                  f"min={np.min(upd_frac):.3f} max={np.max(upd_frac):.3f}")
    return out

def _align_rank_to_target(arr: np.ndarray, tgt: np.ndarray, tile_on_expand: bool = True) -> np.ndarray:
    """Return ``arr`` as float32 with the same rank as ``tgt``.

    Steps, in order:

    1. While ``arr`` has more axes than ``tgt``, average away its last axis,
       one axis at a time (e.g. 4-D conv kernel -> 2-D).
    2. If ``arr`` has fewer axes, append trailing axes of size 1 (e.g. 2-D -> 4-D).
    3. If ``tile_on_expand`` is true and the shapes still differ, repeat every
       axis that has size 1 in ``arr`` but a different size in ``tgt`` up to
       the target size. This applies to *any* size-1 axis, not only the ones
       appended in step 2, and also runs when no expansion happened. Axes whose
       sizes differ and are not 1 are left as is, so the result may still
       differ from ``tgt.shape``.

    Only the rank and shape of ``tgt`` are used. The result may share memory
    with ``arr`` when no reduction or tiling was needed.
    """
    out = np.asarray(arr, dtype=np.float32)
    target = np.asarray(tgt)

    # 1) Collapse surplus trailing axes by averaging.
    while out.ndim > target.ndim:
        out = out.mean(axis=-1)

    # 2) Append missing trailing singleton axes (view, like expand_dims).
    missing = target.ndim - out.ndim
    if missing > 0:
        out = out[(Ellipsis,) + (None,) * missing]

    # 3) Repeat size-1 axes that the target wants larger.
    if tile_on_expand and out.shape != target.shape:
        reps = tuple(
            want if (have == 1 and want != have) else 1
            for have, want in zip(out.shape, target.shape)
        )
        out = np.tile(out, reps)

    return out.astype(np.float32, copy=False)

def aggregate_scalefl(global_sd, clients_local_sd_keys, results, is_weight=None):
    """
    FedAvg over heterogeneous (depth/width) trainables.

    * Overlap-slice per tensor: each client array contributes to the leading
      ``[:min(global, local)]`` region on every axis.
    * Sample-weighted average (weights ``num_examples / total_examples``).
    * Keep previous global values where no client updated.
    * Skips non-float client arrays, keys unknown to ``global_sd``, rank
      mismatches, empty overlaps, and clients whose key list and payload
      lengths differ.

    Uses the module-level ``np`` and ``parameters_to_ndarrays``.

    Args:
        global_sd: ``name -> ndarray`` of current global values. Its order
            defines the order of the returned list.
        clients_local_sd_keys: ``cid -> ordered keys`` matching each client's
            payload. A client missing from this mapping raises ``KeyError``.
        results: ``(client, fit_res)`` pairs for the round.
        is_weight: Unused; accepted for signature compatibility.

    Returns:
        List of aggregated arrays, one per ``global_sd`` key, in ``global_sd`` order.
    """
    # Accumulators (float32 for safety)
    agg = {k: np.zeros_like(v, dtype=np.float32) for k, v in global_sd.items()}
    cnt = {k: np.zeros_like(v, dtype=np.float32) for k, v in global_sd.items()}

    # ``or 1.0`` guards an all-zero total. The normalisation cancels in the
    # final agg / cnt division anyway.
    num_examples_total = float(sum(fr.num_examples for _, fr in results)) or 1.0

    for client, fit_res in results:
        w = float(fit_res.num_examples) / num_examples_total
        keys = clients_local_sd_keys[client.cid]
        arrs = parameters_to_ndarrays(fit_res.parameters)

        if len(keys) != len(arrs):
            # Mismatched payload; skip safely
            continue

        for k, arr in zip(keys, arrs):
            if k not in agg:
                continue
            if not np.issubdtype(arr.dtype, np.floating):
                # trainables are floats; buffers (ints) should not be here
                continue

            g = agg[k]
            c = arr.astype(np.float32, copy=False)

            # Only average overlapping prefix; require same rank for a valid match
            if c.ndim != g.ndim:
                # If this happens, it indicates a key mapping bug; skip to be safe.
                continue

            region = tuple(slice(0, min(gs, cs)) for gs, cs in zip(g.shape, c.shape))
            if any(s.stop == 0 for s in region):  # empty overlap
                continue

            g[region] += w * c[region]
            cnt[k][region] += w

    # Finalize: start from previous global, then overwrite only updated positions
    out_list = []
    for k in global_sd.keys():
        out = np.array(global_sd[k], copy=True)
        m = cnt[k] > 0
        if m.any():
            out[m] = agg[k][m] / cnt[k][m]
        out_list.append(out)

    return out_list

def aggregate_inplace_early_exit_feddyn(global_sd: Dict[str, NDArray],
        clients_local_sd_keys: Dict[str, List[str]],
        results: List[Tuple[ClientProxy, FitRes]],
        h_dict: Dict[str, NDArray], # weights and biases only
        num_clients: int,
        alpha: float,
        ) -> NDArrays:
    """Compute in-place FedDyn with results of varying sizes.

    For every key sent by at least one client (``n_k`` = number of clients
    that sent key ``k``):

        mean    = unweighted mean of the client arrays for ``k``
        h[k]   <- h[k] - alpha * n_k * (mean - global[k]) / num_clients
        new[k]  = mean - h[k] / alpha

    ``h_dict`` is updated in place for those keys. Keys no client sent keep
    ``global_sd[k]``. Client arrays are summed directly into a buffer shaped
    like the global array, so they must be broadcast-compatible with it.

        https://arxiv.org/pdf/2111.04263.pdf
        Modified from: https://github.com/adap/flower/blob/main/baselines/depthfl/depthfl/strategy.py

    Returns:
        Aggregated arrays in ``global_sd`` key order.

    Raises:
        ValueError: If a client's key list and payload differ in length.
        KeyError: If ``client.cid`` is missing from ``clients_local_sd_keys``,
            a client sends a key not in ``global_sd``, or an updated key is
            missing from ``h_dict``.
    """
    aggregated_sd = {k: np.zeros(v.shape) for k, v in global_sd.items()}
    # Number of clients that sent each key (FedDyn uses an unweighted mean).
    aggregated_sd_count = {k: 0 for k in global_sd}

    for client, fit_res in results:
        local_sd_keys = clients_local_sd_keys[client.cid]
        # Deserialize each payload once (it was previously decoded twice).
        arrays = parameters_to_ndarrays(fit_res.parameters)
        if len(local_sd_keys) != len(arrays):
            raise ValueError(
                f"client {client.cid}: {len(local_sd_keys)} keys but {len(arrays)} arrays"
            )
        for k, x in zip(local_sd_keys, arrays):
            aggregated_sd[k] += x  # take sum without weighting
            aggregated_sd_count[k] += 1

    # Update the h variable and apply it; untouched keys keep global values.
    for k, count in aggregated_sd_count.items():
        if count == 0:
            aggregated_sd[k] = global_sd[k]
            continue
        mean = aggregated_sd[k] / count
        # NOTE(review): h_dict is documented as weights/biases only; any other
        # key sent by a client raises KeyError here - confirm global_sd matches.
        h_dict[k] = (
            h_dict[k]
            - alpha
            * count
            * (mean - global_sd[k])
            / num_clients
        )
        aggregated_sd[k] = mean - h_dict[k] / alpha

    return list(aggregated_sd.values())

def aggregate_inplace_early_exit_fedsparseadam(global_sd: Dict[str, NDArray],
                                clients_local_sd_keys: Dict[str, List[str]],
                                results: List[Tuple[ClientProxy, FitRes]],
                                m_t: Dict[str, NDArray],
                                v_t: Dict[str, NDArray],
                                beta_1: float,
                                beta_2: float,
                                tau: float,
                                eta: float) -> NDArrays:
    """Aggregate partial client updates and apply a sparse FedAdam server step.

    1. For each key, average the client arrays weighted by ``num_examples``,
       normalised over the clients that sent that key.
    2. For each key with a positive example total, use
       ``g = global - average`` as the pseudo-gradient and apply an Adam-style
       update (no bias correction):

           m = beta_1 * m + (1 - beta_1) * g
           v = beta_2 * v + (1 - beta_2) * g * g
           new = global - eta * m / (sqrt(v) + tau)

       ``m_t`` and ``v_t`` are updated in place for those keys only.
    3. Keys without contributing examples keep their global value.

    Returns:
        Aggregated arrays in ``global_sd`` key order.

    Raises:
        KeyError: If ``client.cid`` is missing from ``clients_local_sd_keys``,
            or a client sends a key not present in ``global_sd``.
    """
    aggregated_sd = {k: np.zeros(v.shape) for k, v in global_sd.items()}

    # Total examples per key, over the clients that sent that key.
    aggregated_sd_num_examples_total = {k: 0 for k in global_sd}
    for client, fit_res in results:
        for k in clients_local_sd_keys[client.cid]:
            aggregated_sd_num_examples_total[k] += fit_res.num_examples

    # Weighted average per key.
    for client, fit_res in results:
        local_sd_keys = clients_local_sd_keys[client.cid]
        for k, x in zip(local_sd_keys, parameters_to_ndarrays(fit_res.parameters)):
            total = aggregated_sd_num_examples_total[k]
            if total == 0:
                # Every sender of this key reported 0 examples: avoid the
                # ZeroDivisionError and let the key keep its global value below.
                continue
            weight = fit_res.num_examples / total
            aggregated_sd[k] += weight * x

    # Sparse FedAdam step; keys without updates keep the global parameters.
    for k, count in aggregated_sd_num_examples_total.items():
        if count > 0:
            # updated model = initial model - grad
            g = global_sd[k] - aggregated_sd[k]
            m_t[k] = np.multiply(beta_1, m_t[k]) + (1 - beta_1) * g
            v_t[k] = np.multiply(beta_2, v_t[k]) + (1 - beta_2) * np.multiply(g, g)
            aggregated_sd[k] = global_sd[k] - eta * m_t[k] / (np.sqrt(v_t[k]) + tau)
        else:
            aggregated_sd[k] = global_sd[k]

    return list(aggregated_sd.values())

def aggregate_scalefl_with_weights(
    global_sd: Dict[str, NDArray],
    clients_local_sd_keys: Dict[str, List[str]],
    results: List[Tuple[ClientProxy, FitRes]],
    is_weight: Dict[str, bool],
    client_weights: Optional[Dict[str, float]] = None,
) -> NDArrays:
    """Weighted overlap-slice averaging of heterogeneous client updates.

    Each client array is added into the leading ``[:min(local, global)]``
    region (every axis) of its global key, weighted by
    ``client_weights[cid]`` when provided, else by ``num_examples``.
    Entries no client touched keep their ``global_sd`` values.

    A client tensor is skipped when its key is unknown to ``global_sd``, it is
    not floating point, it is 0-d, or its rank differs from the global array.
    With a rank mismatch a prefix slice is undefined, and broadcasting would
    either raise or add values along the wrong axes.

    Args:
        global_sd: Current global arrays. Its order defines the output order.
        clients_local_sd_keys: ``cid -> ordered keys`` for each payload.
            Missing clients contribute nothing.
        results: ``(client, fit_res)`` pairs for the round.
        is_weight: Unused; kept for signature compatibility.
        client_weights: Optional ``cid -> weight`` overrides. Clients not
            listed (or an empty/None mapping) fall back to ``num_examples``.

    Returns:
        One array per ``global_sd`` key, in ``global_sd`` order.
    """
    agg: Dict[str, np.ndarray] = {k: np.zeros_like(v, dtype=np.float32) for k, v in global_sd.items()}
    cnt: Dict[str, np.ndarray] = {k: np.zeros_like(v, dtype=np.float32) for k, v in global_sd.items()}

    for client, fit_res in results:
        base = float(fit_res.num_examples)
        w_client = client_weights.get(client.cid, base) if client_weights else base

        local_keys = clients_local_sd_keys.get(client.cid, [])
        arrays = parameters_to_ndarrays(fit_res.parameters)

        # zip() stops at the shorter of keys/arrays, dropping any unmatched tail.
        for k, arr in zip(local_keys, arrays):
            if k not in global_sd:
                continue
            arr = np.asarray(arr)
            if not np.issubdtype(arr.dtype, np.floating):
                continue

            tgt = agg[k]
            # Equal ranks and arr.ndim > 0 also imply tgt.ndim > 0.
            if arr.ndim == 0 or arr.ndim != tgt.ndim:
                continue

            slices = tuple(slice(0, min(a, t)) for a, t in zip(arr.shape, tgt.shape))
            tgt[slices] += w_client * arr[slices].astype(tgt.dtype, copy=False)
            cnt[k][slices] += w_client

    out: NDArrays = []
    for k in global_sd.keys():
        updated = np.array(global_sd[k], copy=True)
        mask = cnt[k] > 0
        if mask.any():
            updated[mask] = agg[k][mask] / cnt[k][mask]
        out.append(updated)
    return out

