import torch
import numpy as np 
from io import BytesIO

from functools import reduce
from typing import List, Tuple, Any, Dict, cast

import numpy as np

from flwr.common import FitRes, Parameters
from flwr.server.client_proxy import ClientProxy
import numpy.typing as npt

import pdb

# Type alias: a NumPy array of any shape and dtype.
NDArray = npt.NDArray[Any]
# Type alias: an ordered list of arrays (e.g. one entry per serialized tensor,
# as produced by parameters_to_ndarrays).
NDArrays = List[NDArray]

def softmax(x, T=1.0):
    """Compute a temperature-scaled softmax of ``x`` along axis 0.

    Args:
        x: Array-like of scores. The reduction runs over axis 0, so for a
            2-D input every column is normalised independently.
        T: Temperature; ``T > 1`` flattens the distribution, ``T < 1``
            sharpens it. Must be non-zero.

    Returns:
        ``np.ndarray`` with the same shape as ``x`` whose entries along
        axis 0 are non-negative and sum to 1.
    """
    z = np.asarray(x) / T
    # Shifting by the per-column max leaves the result mathematically
    # unchanged but stops np.exp from overflowing to inf (-> nan) for large
    # scores. `initial` keeps the reduction defined for empty inputs.
    z = z - np.max(z, axis=0, keepdims=True, initial=-np.inf)
    e = np.exp(z)
    return e / np.sum(e, axis=0, keepdims=True)

def get_layer(model, layer_name):
    """Resolve a dotted parameter name (e.g. ``"layers.0.weight"``) on ``model``.

    Each dot-separated component is applied in turn. Purely numeric
    components index into the current object (``obj[int(part)]``). All other
    components are attribute lookups (``getattr(obj, part)``).

    Args:
        model: Root object to walk from.
        layer_name: Dotted path; must end in ``.weight`` or ``.bias``.

    Returns:
        The object found at the end of the path.

    Raises:
        AssertionError: if ``layer_name`` does not end in ``.weight`` or
            ``.bias``.
    """
    if not layer_name.endswith(('.weight', '.bias')):
        # Explicit raise instead of `assert` so the check survives `python -O`;
        # the exception type is unchanged for existing callers.
        raise AssertionError(
            'layer name must be learnable (end with .weight or .bias)'
        )
    node = model
    for part in layer_name.split('.'):
        node = node[int(part)] if part.isdigit() else getattr(node, part)
    return node

def bytes_to_ndarray(tensor: bytes) -> NDArray:
    """Decode one ``.npy``-encoded byte string back into a NumPy array.

    Pickle support is deliberately disabled: loading pickled data can execute
    arbitrary code (see the security note in the ``numpy.load`` docs,
    https://numpy.org/doc/stable/reference/generated/numpy.load.html).
    Object arrays therefore fail to load with a ``ValueError``.
    """
    # WARNING: NEVER set allow_pickle to true.
    decoded = np.load(BytesIO(tensor), allow_pickle=False)  # type: ignore
    return cast(NDArray, decoded)

def parameters_to_ndarrays(parameters: Parameters) -> NDArrays:
    """Decode every serialized tensor carried by ``parameters``.

    Returns:
        A list holding one array per entry of ``parameters.tensors``, in the
        same order, each decoded with ``bytes_to_ndarray``.
    """
    return list(map(bytes_to_ndarray, parameters.tensors))

def aggregate_inplace(results: List[Tuple[ClientProxy, FitRes]]) -> NDArrays:
    """Return the example-weighted average of the clients' parameters.

    Each client's arrays are scaled by ``num_examples / total_examples``.
    They are added layer by layer into a running sum, so clients are
    deserialized one at a time. Clients are assumed to send layers of
    matching shapes; if counts differ, the result is truncated to the
    shortest (``zip`` semantics).

    Raises:
        IndexError: if ``results`` is empty.
        ZeroDivisionError: if every client reports zero examples.
    """
    total = sum(fit_res.num_examples for _, fit_res in results)
    factors = [fit_res.num_examples / total for _, fit_res in results]

    # Seed the running sum with the first client's scaled arrays.
    head_layers = parameters_to_ndarrays(results[0][1].parameters)
    running = [factors[0] * layer for layer in head_layers]

    for factor, (_, fit_res) in zip(factors[1:], results[1:]):
        client_layers = parameters_to_ndarrays(fit_res.parameters)
        running = [
            np.add(acc, factor * layer) for acc, layer in zip(running, client_layers)
        ]

    return running

def aggregate_inplace_early_exit(
    global_sd: Dict[str, NDArray],
    clients_local_sd_keys: Dict[str, List[str]],
    results: List[Tuple[ClientProxy, FitRes]],
) -> NDArrays:
    """
    Aggregate heterogeneous client updates (different exits/widths).

    - Float tensors (weights, BN running stats) are averaged elementwise over
      the overlapping region, weighted by num_examples.
    - Non-float tensors (e.g., BN num_batches_tracked) keep prior global values.
    - Elements with no client updates keep their previous global value.

    The overlapping region of a client array is the leading
    ``min(global, client)`` entries along every axis. Client arrays whose rank
    differs from the global one are skipped, as are client slices containing
    NaN/Inf. Weighted sums are accumulated in float64, and each average is cast
    back to its global tensor's dtype.

    Args:
        global_sd: Previous global state, ``key -> np.ndarray`` or torch
            tensor. Its key order defines the output order.
        clients_local_sd_keys: ``cid -> ordered keys`` naming the arrays each
            client sends back.
        results: ``(client, fit_res)`` pairs of the current round.

    Returns:
        One NumPy array per key of ``global_sd``, in the same order.

    Raises:
        AssertionError: if a client's key count differs from its parameter count.
        KeyError: if a client's ``cid`` is missing from ``clients_local_sd_keys``.
    """
    import numpy as np
    try:
        import torch
        _has_torch = True
    except ImportError:
        _has_torch = False

    def _to_numpy(x):
        """Return ``x`` as a NumPy array (torch tensors are detached and moved to CPU)."""
        if isinstance(x, np.ndarray):
            return x
        if _has_torch and torch.is_tensor(x):
            return x.detach().cpu().numpy()
        return np.asarray(x)

    def _is_float(arr):
        return np.issubdtype(arr.dtype, np.floating)

    global_np = {k: _to_numpy(v) for k, v in global_sd.items()}

    # Accumulate in float64 regardless of the model dtype. The weights are raw
    # example counts (not normalised), so weight * value sums can overflow
    # float16 (max 65504) to inf and lose precision in float32.
    acc: Dict[str, np.ndarray] = {
        k: np.zeros(g.shape, dtype=np.float64)
        for k, g in global_np.items()
        if _is_float(g)
    }
    cnt: Dict[str, np.ndarray] = {
        k: np.zeros(a.shape, dtype=np.float64) for k, a in acc.items()
    }

    # Accumulate per client (float tensors only)
    for client, fit_res in results:
        local_keys = clients_local_sd_keys[client.cid]
        client_params = parameters_to_ndarrays(fit_res.parameters)
        # Explicit raise (not `assert`) so the check survives `python -O`.
        if len(local_keys) != len(client_params):
            raise AssertionError(
                f"Key/param length mismatch for client {client.cid}: "
                f"{len(local_keys)} vs {len(client_params)}"
            )
        weight = float(fit_res.num_examples)

        for k, x in zip(local_keys, client_params):
            # Keys unknown to the global model and non-float tensors have no
            # accumulator and are never averaged.
            if k not in acc:
                continue

            a_arr = acc[k]
            c_arr = _to_numpy(x)

            # Rank must match; otherwise skip
            if a_arr.ndim != c_arr.ndim:
                continue

            # Intersection slice over all dims (handles width-scaling)
            min_shape = tuple(min(gs, cs) for gs, cs in zip(a_arr.shape, c_arr.shape))
            if any(m == 0 for m in min_shape):
                continue
            region = tuple(slice(0, m) for m in min_shape)

            # Guard against NaNs/Infs from clients
            cres = c_arr[region]
            if not np.all(np.isfinite(cres)):
                continue

            acc[k][region] += weight * np.asarray(cres, dtype=np.float64)
            cnt[k][region] += weight

    # Build output: averaged where updated; keep global elsewhere
    out: Dict[str, np.ndarray] = {}
    for k, summed in acc.items():
        weights = cnt[k]
        mask = weights > 0
        if not np.any(mask):
            # No client updated this float tensor; keep global
            continue
        # Start from a copy of the previous global values, so a torch-backed
        # array is never written through. Covered elements are overwritten by
        # their weighted mean; the assignment casts back to the global dtype.
        avg = global_np[k].copy()
        avg[mask] = summed[mask] / weights[mask]
        out[k] = avg

    # Return in global key order
    return [out[k] if k in out else global_np[k] for k in global_sd]

def aggregate_scalefl(global_sd, clients_local_sd_keys, results, is_weight=None):
    """Weighted, slice-aware averaging of heterogeneous client models.

    For every floating-point key of ``global_sd``, each client's array is
    aligned with the global one on the leading ``min(global, client)``
    entries of every axis (ranks must match). It is then accumulated with
    weight ``num_examples / total_examples``.

    - Elements covered by at least one client become the weighted mean of the
      covering clients' values.
    - Elements no client covered keep their previous global value.
    - Tensors no client touched, and non-floating tensors, are returned
      unchanged from ``global_sd``.

    Args:
        global_sd: ``key -> np.ndarray``; its key order defines the output order.
        clients_local_sd_keys: ``cid -> ordered keys`` naming the arrays each
            client sends back.
        results: ``(client, fit_res)`` pairs of the current round.
        is_weight: Unused; kept for call-site compatibility.

    Returns:
        A list with one array per key of ``global_sd``, in the same order.
    """
    import numpy as np

    def _is_float(v):
        return np.issubdtype(np.asarray(v).dtype, np.floating)

    # Accumulators exist only for floating tensors. In-place float arithmetic
    # on integer arrays raises a NumPy casting error, so integer tensors are
    # returned unchanged instead.
    agg = {k: np.zeros_like(v) for k, v in global_sd.items() if _is_float(v)}
    cnt = {k: np.zeros_like(v, dtype=np.float32) for k, v in agg.items()}

    num_examples_total = sum(fit_res.num_examples for _, fit_res in results)

    for client, fit_res in results:
        # Normalised weight; max(1, ...) avoids dividing by zero.
        weight = fit_res.num_examples / max(1, num_examples_total)
        keys = clients_local_sd_keys[client.cid]
        client_params = parameters_to_ndarrays(fit_res.parameters)
        for k, c in zip(keys, client_params):
            if k not in agg:
                continue  # unknown key or non-float tensor
            g = agg[k]
            if g.ndim != c.ndim:
                continue
            # Overlap = leading min(global, client) entries on every axis.
            min_shape = tuple(min(gs, cs) for gs, cs in zip(g.shape, c.shape))
            if any(m == 0 for m in min_shape):
                continue
            region = tuple(slice(0, m) for m in min_shape)
            g[region] += weight * c[region].astype(g.dtype, copy=False)
            cnt[k][region] += weight

    out = {}
    for k, summed in agg.items():
        mask = cnt[k] > 0
        if not np.any(mask):
            continue  # no client touched this tensor: keep previous
        summed[mask] /= cnt[k][mask]
        # Elements outside every client's slice keep their previous global
        # value rather than the accumulator's zero.
        summed[~mask] = np.asarray(global_sd[k])[~mask]
        out[k] = summed
    return [out[k] if k in out else global_sd[k] for k in global_sd]

def aggregate_inplace_early_exit_feddyn(global_sd: Dict[str, NDArray],
        clients_local_sd_keys: Dict[str, List[str]],
        results: List[Tuple[ClientProxy, FitRes]],
        h_dict: Dict[str, NDArray], # weights and biases only
        num_clients: int,
        alpha: float,
        ) -> NDArrays:
    """Compute in-place FedDyn with results of varying sizes.

    For every key sent by at least one client, the unweighted mean ``avg`` of
    those clients' arrays is computed. If ``h_dict`` has an entry for the key,
    the FedDyn server correction is applied, updating ``h_dict`` in place::

        h   <- h - alpha * n_k * (avg - global) / num_clients
        new <- avg - h / alpha

    Here ``n_k`` is the number of clients that sent the key. ``num_clients``
    is presumably the total client population; confirm with the caller.

    Keys without an ``h_dict`` entry get the plain mean. Keys no client sent
    keep their global value. Client keys unknown to ``global_sd`` are
    ignored. Updated keys are returned as float64 arrays, because the
    accumulators use ``np.zeros``'s default dtype. Each client array is
    added onto a zeros array of the global shape, so it must match (or
    broadcast to) that shape. TODO: confirm for width-scaled clients.

    Raises:
        AssertionError: if a client's key count differs from its parameter count.

    https://arxiv.org/pdf/2111.04263.pdf
    Modified from: https://github.com/adap/flower/blob/main/baselines/depthfl/depthfl/strategy.py
    """
    # Unweighted sums and the number of clients that sent each key.
    aggregated_sd = {k: np.zeros(v.shape) for k, v in global_sd.items()}
    aggregated_sd_count = {k: 0 for k in global_sd}

    for client, fit_res in results:
        local_sd_keys = clients_local_sd_keys[client.cid]
        # Deserialize the payload once (it used to be decoded twice).
        client_params = parameters_to_ndarrays(fit_res.parameters)
        if len(local_sd_keys) != len(client_params):
            raise AssertionError(
                f"Key/param length mismatch for client {client.cid}: "
                f"{len(local_sd_keys)} vs {len(client_params)}"
            )
        for k, x in zip(local_sd_keys, client_params):
            if k not in aggregated_sd:
                continue  # not part of the global model
            aggregated_sd[k] += x  # take sum without weighting
            aggregated_sd_count[k] += 1

    for k, count in aggregated_sd_count.items():
        if count == 0:
            # No client sent this key: use the global parameters.
            aggregated_sd[k] = global_sd[k]
            continue

        avg = aggregated_sd[k] / count
        # h_dict holds weights and biases only; other keys (e.g. buffers)
        # would raise KeyError here, so they get the plain mean instead.
        if k in h_dict:
            h_dict[k] = (
                h_dict[k]
                - alpha
                * count
                * (avg - global_sd[k])
                / num_clients
            )
            avg = avg - h_dict[k] / alpha
        aggregated_sd[k] = avg

    return list(aggregated_sd.values())

def aggregate_inplace_early_exit_fedsparseadam(
    global_sd: Dict[str, NDArray],
    clients_local_sd_keys: Dict[str, List[str]],
    results: List[Tuple[ClientProxy, FitRes]],
    m_t: Dict[str, NDArray],
    v_t: Dict[str, NDArray],
    beta_1: float,
    beta_2: float,
    tau:   float,   # use as Adam epsilon
    eta:   float,   # server LR
    weight_decay: float = 0.0,                   # AdamW decoupled decay
    exclude_from_wd: Tuple[str, ...] = (".bias", "bn.", "running_mean", "running_var"),
) -> NDArrays:
    """Server-side AdamW with heterogeneous, slice-aware aggregation (float-only).

    For every floating-point key of ``global_sd``:

    1. Pseudo-gradient. Each client's delta ``local - global`` is taken on
       the overlapping region: the leading ``min(global, local)`` entries
       along every axis, with matching ranks. Deltas containing NaN/Inf are
       dropped. The rest are averaged per element over the clients covering
       it, weighted by ``num_examples``; uncovered elements get a zero delta.
    2. Moments. ``m_t``/``v_t`` are updated over the whole tensor, so
       uncovered elements decay but are still moved by existing momentum.
       Missing entries are created as zeros; both dicts are mutated in place.
    3. Step: ``p = g + eta * m / (sqrt(v) + tau)``. Because ``m`` tracks
       ``local - global``, adding it moves the model towards the clients
       (as in FedAdam). No bias correction is applied.
    4. Decoupled weight decay ``p -= eta * weight_decay * g``, skipped when
       the key contains any substring listed in ``exclude_from_wd``.

    Non-float tensors, and float tensors no client covered, are returned as
    NumPy versions of the ``global_sd`` values.

    Returns:
        One array per key of ``global_sd``, in the same order.

    Raises:
        AssertionError: if a client's key count differs from its parameter count.
    """
    import numpy as np

    def _np(x):  # ensure numpy array
        return x if isinstance(x, np.ndarray) else np.asarray(x)

    def _is_float_arr(x):
        return isinstance(x, np.ndarray) and np.issubdtype(x.dtype, np.floating)

    # Init moments for any missing float key; build accumulators for float keys.
    delta_sum: Dict[str, np.ndarray] = {}
    cnt_sum: Dict[str, np.ndarray] = {}
    for k, v in global_sd.items():
        g = _np(v)
        if not _is_float_arr(g):
            continue
        if k not in m_t:
            m_t[k] = np.zeros_like(g)
        if k not in v_t:
            v_t[k] = np.zeros_like(g)
        delta_sum[k] = np.zeros_like(g)
        cnt_sum[k] = np.zeros_like(g, dtype=np.float32)

    tot_examples = float(sum(fr.num_examples for _, fr in results)) or 1.0

    # Accumulate weighted deltas on overlapping regions only.
    for client, fit_res in results:
        keys = clients_local_sd_keys[client.cid]
        local_params = parameters_to_ndarrays(fit_res.parameters)
        # Explicit raise (not `assert`) so the check survives `python -O`.
        if len(keys) != len(local_params):
            raise AssertionError(f"Key/param mismatch for {client.cid}")

        w = float(fit_res.num_examples) / tot_examples

        for k, lp in zip(keys, local_params):
            if k not in delta_sum:
                continue  # unknown key or non-float tensor
            g = _np(global_sd[k])
            l = _np(lp)
            if g.ndim != l.ndim:
                continue
            min_shape = tuple(min(ga, la) for ga, la in zip(g.shape, l.shape))
            if any(m == 0 for m in min_shape):
                continue
            region = tuple(slice(0, m) for m in min_shape)
            d = l[region].astype(g.dtype, copy=False) - g[region]
            if not np.all(np.isfinite(d)):
                continue
            delta_sum[k][region] += w * d
            cnt_sum[k][region] += w

    # AdamW update
    new_sd: Dict[str, np.ndarray] = {}
    for k, v in global_sd.items():
        g = _np(v)
        if k not in delta_sum:
            new_sd[k] = g
            continue

        mask = cnt_sum[k] > 0
        if not np.any(mask):
            new_sd[k] = g
            continue

        avg_delta = np.zeros_like(g)
        avg_delta[mask] = delta_sum[k][mask] / cnt_sum[k][mask]

        # moments
        m_t[k] = beta_1 * m_t[k] + (1.0 - beta_1) * avg_delta
        v_t[k] = beta_2 * v_t[k] + (1.0 - beta_2) * (avg_delta * avg_delta)

        # (bias correction omitted unless you track per-key steps; optional)
        upd = eta * (m_t[k] / (np.sqrt(v_t[k]) + float(tau)))

        # m_t follows (local - global), the direction the clients moved, so
        # the server step must ADD it. Subtracting moves away from the clients.
        p = g + upd

        # decoupled weight decay
        if weight_decay > 0.0 and not any(tag in k for tag in exclude_from_wd):
            p = p - eta * weight_decay * g

        new_sd[k] = p

    return [new_sd[k] for k in global_sd]
