import flwr as fl
import math
import copy
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from torch.distributions import Normal
from torch.distributions.kl import kl_divergence
from typing import List
from collections import OrderedDict, defaultdict
from importlib import import_module
import inspect
from ptflops import get_model_complexity_info

import logging
logger = logging.getLogger(__name__)

from typing import Dict, List, Union
ArrayLike = Union[np.ndarray, torch.Tensor]

# ---- debug_helpers.py (you can drop this in model_utils.py or snip_utils.py) ----
# import torch, torch.nn as nn
# from typing import Optional

def dump_trunk_taps(model: nn.Module, prefix: str = "[TRUNK] ") -> None:
    """
    Debug helper: print the channel width found at every early-exit tap point.

    ``model.layers`` is walked as an iterable of stages, each an iterable of
    blocks; blocks are numbered consecutively from 0 and the width of a block
    is taken from ``block.bn2.num_features`` (``None`` when it cannot be read).
    For each entry of ``model.blks_to_exit`` the matching block's width is
    printed, or "out-of-range" if the entry is not a valid block index.

    Everything is written with ``print``; any unexpected error is caught and
    reported as ``failed (...)`` instead of being raised.
    """
    try:
        stages = getattr(model, "layers", None)
        taps = list(getattr(model, "blks_to_exit", []))
        if not (stages and taps):
            print(prefix + "No layers or blks_to_exit on this model.")
            return

        # One (block_index, width) entry per block, in traversal order.
        widths = []
        for stage in stages:
            for block in stage:
                try:
                    width = int(block.bn2.num_features)
                except Exception:
                    width = None
                widths.append((len(widths), width))

        total = len(widths)
        print(prefix + f"depth={total}, taps={taps}")
        for tap in taps:
            pos = int(tap)
            if pos < 0 or pos >= total:
                print(prefix + f"tap@blk={tap}: out-of-range")
                continue
            block_idx, width = widths[pos]
            print(prefix + f"tap@blk={block_idx}: channels={width}")
    except Exception as e:
        print(prefix + f"failed ({e})")


def dump_ee_invariants(model: nn.Module, prefix: str = "[EE] ") -> None:
    """
    Debug helper: print the conv/BN channel bookkeeping of every exit head.

    For each module in ``model.exit_heads`` (must be an ``nn.ModuleList``)
    whose ``features`` attribute is an ``nn.Sequential``, prints:
      - one line per Conv2d with its in/out channels and the BatchNorm2d that
        normalises it (``conv.out_channels`` vs ``bn.num_features``),
      - the head's ``fc`` input/output sizes, when the head has an ``fc``.

    A BatchNorm2d is attributed to a Conv2d only if no other Conv2d sits
    between them; modules in between (Identity / Scaler / ReLU ...) are
    skipped. A conv with no such BN is reported as ``bn: NONE``.

    Output goes to stdout via ``print``; unexpected errors are caught and
    reported as ``failed (...)`` rather than raised.
    """
    try:
        heads = getattr(model, "exit_heads", None)
        if not isinstance(heads, nn.ModuleList):
            print(prefix + "No exit_heads present.")
            return

        for i, head in enumerate(heads):
            if not hasattr(head, "features") or not isinstance(head.features, nn.Sequential):
                print(prefix + f"head[{i}] has no sequential features")
                continue

            print(prefix + f"head[{i}]")
            feats = list(head.features)
            for j, conv in enumerate(feats):
                if not isinstance(conv, nn.Conv2d):
                    continue

                # Look for the BN fed by this conv. Stop at the next conv:
                # any BN after it belongs to that later conv, not this one.
                bn, k = None, None
                for pos in range(j + 1, len(feats)):
                    nxt = feats[pos]
                    if isinstance(nxt, nn.Conv2d):
                        break
                    if isinstance(nxt, nn.BatchNorm2d):
                        bn, k = nxt, pos
                        break

                if bn is None:
                    print(prefix + f"  conv@{j}: in={conv.in_channels} out={conv.out_channels} | bn: NONE")
                else:
                    ok = (int(conv.out_channels) == int(bn.num_features))
                    print(prefix + f"  conv@{j}: in={conv.in_channels} out={conv.out_channels} | "
                                   f"bn@{k}: num_features={bn.num_features}  -> match={ok}")

            # fc input/output sizes; heads without a usable `fc` are skipped silently.
            try:
                fc_in = int(head.fc.in_features)
                fc_out = int(head.fc.out_features)
            except (AttributeError, TypeError, ValueError):
                continue
            print(prefix + f"  fc: in={fc_in} out={fc_out}")
    except Exception as e:
        print(prefix + f"failed ({e})")


def assert_ee_conv_bn_match(model: nn.Module, strict: bool = False, prefix: str = "[EE/CHECK] ") -> bool:
    """
    Check that every conv->BN pair inside the exit heads agrees on channel count.

    Each module in ``model.exit_heads`` whose ``features`` attribute is an
    ``nn.Sequential`` is scanned. A BatchNorm2d is paired with the closest
    preceding Conv2d, provided no other Conv2d lies between them (modules such
    as Identity / Scaler / ReLU in between are skipped). Convs without such a
    BN are not checked.

    Args:
        model: model that may expose ``exit_heads``; a missing or ``None``
            attribute is treated as "no heads".
        strict: if True, raise on the first mismatch instead of printing it.
        prefix: text prepended to every mismatch message.

    Returns:
        True if all checked pairs satisfy ``conv.out_channels == bn.num_features``.

    Raises:
        AssertionError: on the first mismatch when ``strict`` is True.
    """
    ok_all = True
    heads = getattr(model, "exit_heads", None) or []
    for i, head in enumerate(heads):
        feats = getattr(head, "features", None)
        if not isinstance(feats, nn.Sequential):
            continue
        feats = list(feats)
        for j, conv in enumerate(feats):
            if not isinstance(conv, nn.Conv2d):
                continue

            # Find the BN fed by this conv. Stop at the next conv: a BN that
            # follows another conv must be compared against that conv instead.
            bn, k = None, None
            for pos in range(j + 1, len(feats)):
                nxt = feats[pos]
                if isinstance(nxt, nn.Conv2d):
                    break
                if isinstance(nxt, nn.BatchNorm2d):
                    bn, k = nxt, pos
                    break
            if bn is None:
                continue

            if int(conv.out_channels) != int(bn.num_features):
                ok_all = False
                msg = (f"{prefix}head={i} conv_out={conv.out_channels} "
                       f"!= bn_features={bn.num_features} (conv@{j}, bn@{k})")
                if strict:
                    raise AssertionError(msg)
                print(msg)
    return ok_all

def retarget_exit_preconv_bns(model: nn.Module, verbose: bool = False):
    """
    Resize BatchNorm2d layers in exit-head feature stacks to match their conv.

    Every ``nn.Sequential`` whose qualified name ends with ``'features'`` is
    scanned. A BatchNorm2d is paired with the closest preceding Conv2d as long
    as no other Conv2d or BatchNorm2d lies between them; other modules in
    between (e.g. an Identity/scaling layer or an activation) are treated as
    channel-preserving and skipped. When ``bn.num_features`` differs from
    ``conv.out_channels`` (e.g. after structural pruning/SNIP) the BN is
    replaced in place by a new BatchNorm2d of the right width.

    The replacement copies ``eps``, ``momentum``, ``affine``,
    ``track_running_stats`` and the train/eval mode of the old layer, and is
    moved to the conv weight's device and dtype. Its affine parameters and
    running statistics are freshly initialised (the old ones have the wrong
    size and cannot be reused).

    Args:
        model: module tree to patch in place.
        verbose: if True, print how many BN layers were replaced.

    Returns:
        Number of BatchNorm2d layers that were replaced.
    """
    changed = 0
    # Materialise the traversal first: children are replaced while we walk.
    for name, module in list(model.named_modules()):
        if not (isinstance(module, nn.Sequential) and name.endswith('features')):
            continue

        last_conv = None
        for i, m in enumerate(list(module)):
            if isinstance(m, nn.Conv2d):
                last_conv = m
                continue
            if not isinstance(m, nn.BatchNorm2d):
                # Pass-through layer between conv and BN: keep looking.
                continue

            # A BN consumes the pending conv; a later BN needs a new conv first.
            conv, last_conv = last_conv, None
            if conv is None or m.num_features == conv.out_channels:
                continue

            new_bn = nn.BatchNorm2d(
                conv.out_channels,
                eps=m.eps, momentum=m.momentum,
                affine=m.affine, track_running_stats=m.track_running_stats,
            )
            # keep device/dtype and train/eval mode consistent with the stack
            new_bn = new_bn.to(device=conv.weight.device, dtype=conv.weight.dtype)
            new_bn.train(m.training)
            module[i] = new_bn
            changed += 1
    if verbose:
        print(f"[SURGERY] retarget_exit_preconv_bns changed={changed}")
    return changed
                
def disable_bn_tracking(model: nn.Module):
    """
    Switch off running-statistics tracking on every batch-norm layer of *model*.

    For each ``_BatchNorm`` submodule, ``track_running_stats`` is set to False
    and the ``running_mean`` / ``running_var`` / ``num_batches_tracked``
    buffers are set to None. The model is modified in place and returned.
    """
    bn_base = nn.modules.batchnorm._BatchNorm
    norm_layers = [layer for layer in model.modules() if isinstance(layer, bn_base)]
    for layer in norm_layers:
        # No running stats anywhere in ScaleFL
        layer.track_running_stats = False
        for buffer_name in ("running_mean", "running_var", "num_batches_tracked"):
            setattr(layer, buffer_name, None)
    return model

def _as_tensor(arr: ArrayLike) -> torch.Tensor:
    """Return *arr* itself if it is already a tensor, otherwise wrap it with ``torch.from_numpy``."""
    if isinstance(arr, torch.Tensor):
        return arr
    return torch.from_numpy(arr)

def _to_same_type(dst_like: ArrayLike, t: torch.Tensor) -> ArrayLike:
    """
    Convert tensor *t* to the container kind and dtype of *dst_like*.

    If *dst_like* is a tensor the result is ``t`` cast to its dtype; otherwise
    the result is a NumPy copy of ``t`` (detached, moved to CPU) cast to
    ``dst_like.dtype``.
    """
    if not isinstance(dst_like, torch.Tensor):
        host_array = t.detach().cpu().numpy()
        return host_array.astype(dst_like.dtype, copy=True)
    return t.to(dst_like.dtype)

def _normalize_idx(idx, dim_size: int) -> torch.Tensor:
    """
    Turn an index specification into a 1-D ``torch.long`` index tensor.

    ``None`` selects the whole dimension. Lists, tuples and NumPy arrays are
    converted with ``torch.as_tensor``; tensors are used as given; anything
    else is treated as a single integer index. When ``dim_size`` is positive,
    entries outside ``[0, dim_size)`` are dropped. If nothing remains, the
    full range ``arange(dim_size)`` is returned instead.
    """
    def _whole_dim() -> torch.Tensor:
        return torch.arange(dim_size, dtype=torch.long)

    if idx is None:
        return _whole_dim()

    if isinstance(idx, (list, tuple, np.ndarray)):
        picked = torch.as_tensor(idx, dtype=torch.long)
    elif torch.is_tensor(idx):
        picked = idx
    else:
        picked = torch.tensor([int(idx)], dtype=torch.long)
    picked = picked.to(torch.long)

    if dim_size > 0:
        in_range = (picked >= 0) & (picked < dim_size)
        picked = picked[in_range]

    return picked if picked.numel() != 0 else _whole_dim()

def prune(
    state_dict: Dict[str, ArrayLike],
    param_idx: Dict[str, List[torch.Tensor]],
) -> Dict[str, ArrayLike]:
    """
    Slice every entry of *state_dict* along each dimension using *param_idx*.

    Only keys present in *param_idx* are kept. For each such key the index
    list is truncated or padded with ``None`` (meaning "keep everything") to
    the entry's number of dimensions, then applied one dimension at a time via
    ``index_select``. Each result is returned in the same container kind and
    dtype as the corresponding input entry.
    """
    pruned: Dict[str, ArrayLike] = {}
    for key, source in state_dict.items():
        if key not in param_idx:
            continue

        sliced = _as_tensor(source)
        rank = sliced.ndim

        # Exactly one selector per dimension: drop extras, pad the rest.
        selectors = list(param_idx[key])[:rank]
        selectors.extend([None] * (rank - len(selectors)))

        for axis, selector in enumerate(selectors):
            keep = _normalize_idx(selector, int(sliced.shape[axis]))
            sliced = sliced.index_select(axis, keep)

        pruned[key] = _to_same_type(source, sliced.contiguous())
    return pruned


def torch_clamp(x):
    """Clamp every element of tensor *x* to the closed interval [0, 1]."""
    lower, upper = 0, 1
    return torch.clamp(x, min=lower, max=upper)
  
def set_partial_weights(model, update_keys, weights):
    """
    Overwrite a subset of the model's state-dict entries (params and/or buffers).

    - update_keys: state_dict keys to update, in the same order as `weights`
    - weights:     NumPy arrays holding the new values

    Each array is copied, cast to the dtype of the entry it replaces and, if
    its shape differs, viewed as the target shape. Entries not listed in
    `update_keys` keep their current values.

    Raises ValueError when the two lists differ in length or when an array
    cannot be viewed as the shape of its target entry.
    """
    current = model.state_dict()
    if len(update_keys) != len(weights):
        # Fail early with context instead of silently truncating in zip().
        raise ValueError(
            f"set_partial_weights: length mismatch, keys={len(update_keys)}, weights={len(weights)}"
        )

    merged = OrderedDict(current)  # shallow copy; untouched keys stay as they are
    for key, array in zip(update_keys, weights):
        target = current[key]
        fresh = torch.from_numpy(np.copy(array))    # copy so memory is never shared
        if fresh.dtype != target.dtype:
            fresh = fresh.to(target.dtype)          # e.g. BN buffers / num_batches_tracked
        if fresh.shape != target.shape:
            try:
                fresh = fresh.view_as(target)       # same numel, different shape metadata
            except Exception as e:
                raise ValueError(
                    f"set_partial_weights: shape mismatch for key '{key}': "
                    f"got {tuple(fresh.shape)}, expected {tuple(target.shape)}"
                ) from e
        merged[key] = fresh

    # strict=False so any keys not provided remain as-is
    model.load_state_dict(merged, strict=False)

def set_weights(model: torch.nn.ModuleList, weights: fl.common.Weights) -> None:
    """
    Load a list of NumPy ndarrays into *model*.

    Arrays are matched positionally to ``model.state_dict()`` keys, promoted to
    at least 1-D, wrapped with ``torch.Tensor`` and loaded with ``strict=True``.
    """
    state_dict = OrderedDict()
    for key, value in zip(model.state_dict().keys(), weights):
        state_dict[key] = torch.Tensor(np.atleast_1d(value))
    model.load_state_dict(state_dict, strict=True)

def set_weights_multiple_models(nets: list, parameters: fl.common.Weights) -> None:
    """
    Load one flat parameter list into several networks, in order.

    Each net consumes as many leading entries of the remaining list as it has
    state-dict keys (matched positionally, loaded with ``strict=True``); the
    rest is handed on to the next net.
    """
    remaining = parameters
    for net in nets:
        keys = list(net.state_dict().keys())
        assert len(keys) <= len(remaining), f'Insufficient parameters to load {type(net).__name__}.'
        loaded = OrderedDict(
            (key, torch.from_numpy(np.copy(value))) for key, value in zip(keys, remaining)
        )
        net.load_state_dict(loaded, strict=True)
        # Drop what this net consumed before moving on to the next one.
        remaining = remaining[len(keys):]

def get_updates(original_weights: fl.common.Weights, updated_weights: fl.common.Weights) -> List[torch.Tensor]:
    """Return the per-entry difference ``updated - original`` as a list of tensors (inputs are copied)."""
    deltas = []
    for new, old in zip(updated_weights, original_weights):
        new_t = torch.from_numpy(np.copy(new))
        old_t = torch.from_numpy(np.copy(old))
        deltas.append(new_t - old_t)
    return deltas

def apply_updates(original_weights: fl.common.Weights, updates: List[torch.Tensor]) -> fl.common.Weights:
    """Return ``original + update`` per entry as NumPy arrays; the originals are copied, not modified."""
    patched = []
    for delta, base in zip(updates, original_weights):
        delta_np = delta.cpu().detach().numpy()
        patched.append(np.copy(base) + delta_np)
    return patched

class Step(torch.autograd.Function):
    """
    Step function with a straight-through gradient.

    Forward returns 1.0 where the input is strictly positive and 0.0 elsewhere
    (as float32); backward passes the incoming gradient through unchanged.
    """

    def __init__(self):
        super().__init__()

    @staticmethod
    def forward(ctx, input):
        is_positive = torch.gt(input, 0.)
        return is_positive.long().float()

    @staticmethod
    def backward(ctx, grad_output):
        # Straight-through estimator: behave like identity in the backward pass.
        return grad_output

def clampSTE_max(input, max_limit=1.):
    """Clamp *input* to ``[0, max_limit]`` through ``ClampSTE`` (identity gradient in backward)."""
    clamped = ClampSTE.apply(input, max_limit)
    return clamped

class ClampSTE(torch.autograd.Function):
    """
    Clamp with a straight-through gradient.

    Forward clamps the input to ``[0, max_limit]``; backward passes the
    incoming gradient through unchanged and returns no gradient for
    ``max_limit``.
    """

    def __init__(self):
        super().__init__()

    @staticmethod
    def forward(ctx, input, max_limit=1.):
        lower = 0
        return torch.clamp(input, min=lower, max=max_limit)

    @staticmethod
    def backward(ctx, grad_output):
        # One gradient per forward argument: identity for input, None for max_limit.
        return grad_output, None

class ReLUSTE(torch.autograd.Function):
    """
    ReLU with a straight-through gradient.

    Forward returns ``relu(input)`` as a new tensor; backward passes the
    incoming gradient through unchanged (no masking of negative inputs).
    """

    def __init__(self):
        super(ReLUSTE, self).__init__()

    @staticmethod
    def forward(ctx, input):
        # Out-of-place on purpose: modifying `input` in place inside a custom
        # Function without ctx.mark_dirty() silently overwrites the caller's
        # tensor and can invalidate values saved for backward upstream.
        return F.relu(input)

    @staticmethod
    def backward(ctx, grad_output):
        # Straight-through estimator: behave like identity in the backward pass.
        return grad_output


class KLDBatchNorm1d(nn.BatchNorm1d):
    """
    BatchNorm1d whose normalisation statistics are blended with a frozen
    "initial" snapshot.

    Besides the usual running mean/var, the layer keeps ``init_running_mean``
    and ``init_running_var`` buffers. The statistics used to normalise are
    ``(1 - beta) * init_stats + beta * current_stats`` where ``current_stats``
    are the batch statistics in training mode and the running statistics in
    eval mode. With ``init_mode`` set, eval mode uses the initial snapshot only.

    Attributes:
        beta: blend weight between initial and current statistics (default 1.,
            i.e. current statistics only).
        init_mode: when True, eval-mode forward normalises with the initial
            snapshot only.

    NOTE: the ``init_running_*`` buffers are registered only when
    ``track_running_stats`` is True, and ``forward`` reads them (and the
    running statistics) unconditionally, so the layer is only usable with
    running-stat tracking enabled. The reductions run over dim 0 only and the
    statistics are broadcast as ``[None, :]``, so the input is assumed to be
    2-D ``(N, C)``.
    """

    def __init__(self, num_features, eps=1e-5, momentum=0.1,
                 affine=True, track_running_stats=True):
        super(KLDBatchNorm1d, self).__init__(
            num_features, eps, momentum, affine, track_running_stats)

        self.beta = 1.
        self.init_mode = False
        self.num_features = num_features

        if self.track_running_stats:
            self.register_buffer('init_running_mean', torch.zeros(num_features))
            self.register_buffer('init_running_var', torch.ones(num_features))

    def register_init_parameters(self):
        """Snapshot the current running statistics as the initial ones, then reset the running statistics."""
        self.init_running_mean = self.running_mean.detach().clone()
        self.init_running_var = self.running_var.detach().clone()
        self.reset_running_stats()

    def reset_running_stats(self):
        """Reset running mean/var/batch counter to 0/1/0 (no-op without tracking)."""
        if self.track_running_stats:
            self.running_mean.zero_()
            self.running_var.fill_(1)
            self.num_batches_tracked.zero_()

    def forward(self, input):
        self._check_input_dim(input)
        exponential_average_factor = 0.0

        if self.training and self.track_running_stats:
            if self.num_batches_tracked is not None:
                self.num_batches_tracked += 1
                if self.momentum is None:
                    # momentum=None means cumulative moving average, as in nn.BatchNorm
                    exponential_average_factor = 1.0 / float(self.num_batches_tracked)
                else:
                    exponential_average_factor = self.momentum

        # calculate running estimates
        if self.training:
            mean = input.mean([0])
            # use biased var in train
            var = input.var([0], unbiased=False)
            n = input.numel() / input.size(1)
            # Bessel correction n / (n - 1); with a single value per channel the
            # correction is undefined, so fall back to the biased estimate
            # instead of writing NaN into running_var.
            bessel_denominator = max(n - 1, 1)
            with torch.no_grad():
                self.running_mean = exponential_average_factor * mean\
                    + (1 - exponential_average_factor) * self.running_mean
                # update running_var with unbiased var
                self.running_var = exponential_average_factor * var * n / bessel_denominator\
                    + (1 - exponential_average_factor) * self.running_var

            mean = (1. - self.beta) * self.init_running_mean + self.beta * mean
            var = (1. - self.beta) * self.init_running_var + self.beta * var

        else:
            if self.init_mode:
                mean = self.init_running_mean
                var = self.init_running_var
            else:
                mean = (1. - self.beta) * self.init_running_mean + self.beta * self.running_mean
                var = (1. - self.beta) * self.init_running_var + self.beta * self.running_var

        input = (input - mean[None, :]) / (torch.sqrt(var[None, :] + self.eps))
        if self.affine:
            input = input * self.weight[None, :] + self.bias[None, :]

        return input


class KLDBatchNorm2d(nn.BatchNorm2d):
    """
    BatchNorm2d whose normalisation statistics are blended with a frozen
    "initial" snapshot.

    Besides the usual running mean/var, the layer keeps ``init_running_mean``
    and ``init_running_var`` buffers. The statistics used to normalise are
    ``(1 - beta) * init_stats + beta * current_stats`` where ``current_stats``
    are the batch statistics (reduced over dims 0, 2, 3) in training mode and
    the running statistics in eval mode. With ``init_mode`` set, eval mode
    uses the initial snapshot only.

    Attributes:
        beta: blend weight between initial and current statistics (default 1.,
            i.e. current statistics only).
        init_mode: when True, eval-mode forward normalises with the initial
            snapshot only.

    NOTE: the ``init_running_*`` buffers are registered only when
    ``track_running_stats`` is True, and ``forward`` reads them (and the
    running statistics) unconditionally, so the layer is only usable with
    running-stat tracking enabled.
    """

    def __init__(self, num_features, eps=1e-5, momentum=0.1,
                 affine=True, track_running_stats=True):
        super(KLDBatchNorm2d, self).__init__(
            num_features, eps, momentum, affine, track_running_stats)

        self.beta = 1.
        self.init_mode = False
        self.num_features = num_features

        if self.track_running_stats:
            self.register_buffer('init_running_mean', torch.zeros(num_features))
            self.register_buffer('init_running_var', torch.ones(num_features))

    def register_init_parameters(self):
        """Snapshot the current running statistics as the initial ones, then reset the running statistics."""
        self.init_running_mean = self.running_mean.detach().clone()
        self.init_running_var = self.running_var.detach().clone()
        self.reset_running_stats()

    def reset_running_stats(self):
        """Reset running mean/var/batch counter to 0/1/0 (no-op without tracking)."""
        if self.track_running_stats:
            self.running_mean.zero_()
            self.running_var.fill_(1)
            self.num_batches_tracked.zero_()

    def forward(self, input):
        self._check_input_dim(input)
        exponential_average_factor = 0.0

        if self.training and self.track_running_stats:
            if self.num_batches_tracked is not None:
                self.num_batches_tracked += 1
                if self.momentum is None:
                    # momentum=None means cumulative moving average, as in nn.BatchNorm
                    exponential_average_factor = 1.0 / float(self.num_batches_tracked)
                else:
                    exponential_average_factor = self.momentum

        # calculate running estimates
        if self.training:
            mean = input.mean([0, 2, 3])
            # use biased var in train
            var = input.var([0, 2, 3], unbiased=False)
            n = input.numel() / input.size(1)
            # Bessel correction n / (n - 1); with a single value per channel the
            # correction is undefined, so fall back to the biased estimate
            # instead of writing NaN into running_var.
            bessel_denominator = max(n - 1, 1)
            with torch.no_grad():
                self.running_mean = exponential_average_factor * mean\
                    + (1 - exponential_average_factor) * self.running_mean
                # update running_var with unbiased var
                self.running_var = exponential_average_factor * var * n / bessel_denominator\
                    + (1 - exponential_average_factor) * self.running_var

            mean = (1. - self.beta) * self.init_running_mean + self.beta * mean
            var = (1. - self.beta) * self.init_running_var + self.beta * var

        else:
            if self.init_mode:
                mean = self.init_running_mean
                var = self.init_running_var
            else:
                mean = (1. - self.beta) * self.init_running_mean + self.beta * self.running_mean
                var = (1. - self.beta) * self.init_running_var + self.beta * self.running_var

        input = (input - mean[None, :, None, None]) / (torch.sqrt(var[None, :, None, None] + self.eps))
        if self.affine:
            input = input * self.weight[None, :, None, None] + self.bias[None, :, None, None]

        return input

def set_kld_beta(model, betas):
    """
    Assign the ``beta`` blend weight of every KLD batch-norm layer in *model*.

    Args:
        model: module tree containing ``KLDBatchNorm1d`` / ``KLDBatchNorm2d`` layers.
        betas: one of
            - a Python number (``int``/``float`` or a subclass such as
              ``numpy.float64``): used as-is for every layer;
            - a single-element tensor: its squeezed (0-d) value is used for
              every layer;
            - anything else exposing ``squeeze()`` (e.g. a tensor with one
              entry per layer): after squeezing, entry ``i`` is given to the
              i-th KLD layer in ``model.modules()`` order.
    """
    kld_layers = [m for m in model.modules() if isinstance(m, (KLDBatchNorm1d, KLDBatchNorm2d))]
    if not kld_layers:
        return

    if isinstance(betas, (int, float)):
        for layer in kld_layers:
            layer.beta = betas
        return

    # Squeeze once instead of once per layer.
    squeezed = betas.squeeze()
    if torch.is_tensor(betas) and betas.numel() == 1:
        for layer in kld_layers:
            layer.beta = squeezed
        return

    for i, layer in enumerate(kld_layers):
        layer.beta = squeezed[i]

def update_mean_and_var(m_x, v_x, N, m_y, v_y, M):
    """
    Merge the mean/variance of two sample groups into those of their union.

    ``(m_x, v_x, N)`` describe the samples seen so far and ``(m_y, v_y, M)``
    the new group (mean, variance, count). The variances are combined with the
    pooled formula that uses ``count - 1`` weights plus a between-group term.
    When the new group has exactly one element (``M == 1``) the previous
    variance ``v_x`` is returned unchanged.

    Returns:
        (mean, var, N + M) for the combined samples.
    """
    total = N + M
    if M != 1:
        within = ((N - 1) * v_x + (M - 1) * v_y) / (total - 1)
        between = (N * M * ((m_x - m_y) ** 2)) / (total * (total - 1))
        var = within + between
    else:
        var = v_x
    mean = (N * m_x + M * m_y) / total

    return mean, var, total

def precompute_kld(model, dataloader, jit_augment, device, eps=1e-5):
    """
    Measure, per batch-norm layer, how far the data's feature statistics are
    from the statistics stored in the layer.

    The model is put in eval mode with KLD batch-norm layers in ``init_mode``
    and run over *dataloader* under ``torch.no_grad()``. Forward hooks on every
    ``KLDBatchNorm1d`` / ``KLDBatchNorm2d`` / ``nn.BatchNorm2d`` accumulate the
    per-channel mean and variance of each layer's input. For each layer the
    symmetric KL divergence (average of both directions) between
    ``Normal(stored_mean, sqrt(stored_var + eps))`` and
    ``Normal(observed_mean, sqrt(observed_var + eps))`` is averaged over
    channels. "Stored" means ``init_running_*`` for KLD layers and
    ``running_*`` for plain BatchNorm2d.

    Args:
        model: network to probe. It is left in eval mode on return.
        dataloader: iterable yielding ``(img, label)`` pairs; labels are ignored.
        jit_augment: optional callable applied to each batch after it is moved
            to *device*; ``None`` to skip.
        device: device the batches are moved to.
        eps: value added to variances before taking the square root.

    Returns:
        Tensor of shape ``(1, num_layers)``, or ``None`` if the model has no
        matching layer. Columns are in REVERSE ``model.modules()`` order (the
        last layer comes first).
        NOTE(review): ``set_kld_beta`` consumes per-layer values in forward
        module order - confirm the reversed order here is intended.

    The hooks are always removed and ``init_mode`` is always switched back
    off, even if the forward pass raises.
    """
    kld_types = (KLDBatchNorm1d, KLDBatchNorm2d)
    bn_types = kld_types + (nn.BatchNorm2d,)

    feat_stats = {}

    def set_hook(name):
        if name not in feat_stats:
            feat_stats[name] = defaultdict(float)

        def hook(m, inp, outp):
            inp = inp[0]
            # 2-D input: reduce over the batch; otherwise over batch + spatial dims.
            reduce_dims = [0] if len(inp.size()) == 2 else [0, 2, 3]
            mean = inp.mean(reduce_dims)
            var = inp.var(reduce_dims, unbiased=True)
            n = inp.numel() / inp.size(1)

            with torch.no_grad():
                stats = feat_stats[name]
                stats['running_mean'], stats['running_var'], stats['total_size'] = \
                    update_mean_and_var(stats['running_mean'],
                                        stats['running_var'],
                                        stats['total_size'],
                                        mean,
                                        var,
                                        n)

        return hook

    bn_layers = [m for m in model.modules() if isinstance(m, bn_types)]

    set_kld_bn_mode(model, True)  # Use initial model for KLDBatchNorm
    model.eval()

    handles = []
    try:
        for i, m in enumerate(bn_layers):
            handles.append(m.register_forward_hook(set_hook(i)))

        with torch.no_grad():
            for img, _ in dataloader:
                img = img.to(device)
                if jit_augment is not None:
                    img = jit_augment(img)
                model(img)
    finally:
        # Never leave hooks or init_mode behind, even when the forward pass fails.
        for h in handles:
            h.remove()
        set_kld_bn_mode(model, False)

    per_layer = []
    for i, m in enumerate(bn_layers):
        if isinstance(m, kld_types):
            init_dist = Normal(m.init_running_mean, torch.sqrt(m.init_running_var + eps))
        else:
            init_dist = Normal(m.running_mean, torch.sqrt(m.running_var + eps))
        client_dist = Normal(feat_stats[i]['running_mean'], torch.sqrt(feat_stats[i]['running_var'] + eps))

        kl = 0.5 * kl_divergence(init_dist, client_dist) + 0.5 * kl_divergence(client_dist, init_dist)
        per_layer.append(torch.mean(kl).view(1, 1))

    if not per_layer:
        return None
    # Each new layer is placed in front of the earlier ones.
    return torch.cat(per_layer[::-1], dim=1)

def precompute_feat_stats(model, learnable_modules, dataloader, jit_augment, device, mode, no_of_samples=None, eps=1e-5):
    """
    Collect input-feature statistics of selected layers over a dataloader.

    The model is put in eval mode with KLD batch-norm layers in ``init_mode``
    and run over *dataloader* under ``torch.no_grad()`` while forward hooks
    accumulate the mean and variance of each hooked layer's input.

    ``mode`` controls what is hooked and how results are packed:
      - ``'layerwise'``: hook every ``learnable_modules`` instance; one scalar
        mean and variance per layer; returns a ``(1, -1)`` tensor of
        ``[mean, sqrt(var + eps)]`` pairs concatenated over layers.
      - ``'layerwise_samples'``: as ``'layerwise'`` with ``log(no_of_samples)``
        prepended.
      - ``'layerwise_last'``: hook only ``model.net.fc``; per-channel
        statistics; returns a ``(1, -1)`` tensor ``[mean, var]`` (variance,
        not its square root).
      - ``'layerwise_local'`` (and any other value that hooks all layers):
        hook every ``learnable_modules`` instance; per-channel statistics;
        returns a dict ``{layer_index: (1, -1) tensor of [mean, sqrt(var + eps)]}``.

    Args:
        model: network to probe. It is left in eval mode on return.
        learnable_modules: class or tuple of classes passed to ``isinstance``
            to select the layers to hook.
        dataloader: iterable yielding ``(img, label)`` pairs; labels are ignored.
        jit_augment: optional callable applied to each batch after it is moved
            to *device*; ``None`` to skip.
        device: device the batches are moved to.
        mode: see above.
        no_of_samples: required (asserted) for ``'layerwise_samples'``.
        eps: value added to variances before taking the square root.

    The hooks are always removed and ``init_mode`` is always switched back
    off, even if the forward pass raises.
    """
    global_feats = mode in ['layerwise', 'layerwise_samples']
    hook_all_layers = mode in ['layerwise', 'layerwise_samples', 'layerwise_local']
    global_model = mode in ['layerwise', 'layerwise_samples', 'layerwise_last']

    feat_stats = {}

    def set_hook(name):
        if name not in feat_stats:
            feat_stats[name] = defaultdict(float)

        def hook(m, inp, outp):
            inp = inp[0]

            if global_feats:
                # One scalar mean/var over the whole input tensor.
                mean = inp.mean()
                var = inp.var(unbiased=True)
                n = inp.numel()
            else:
                # Per-channel statistics: reduce over batch (+ spatial dims).
                reduce_dims = [0] if len(inp.size()) == 2 else [0, 2, 3]
                mean = inp.mean(reduce_dims)
                var = inp.var(reduce_dims, unbiased=True)
                n = inp.numel() / inp.size(1)

            with torch.no_grad():
                stats = feat_stats[name]
                stats['running_mean'], stats['running_var'], stats['total_size'] = \
                    update_mean_and_var(stats['running_mean'],
                                        stats['running_var'],
                                        stats['total_size'],
                                        mean,
                                        var,
                                        n)

        return hook

    set_kld_bn_mode(model, True)
    model.eval()

    handles = []
    try:
        if hook_all_layers:
            targets = [m for m in model.modules() if isinstance(m, learnable_modules)]
        else:  # 'layerwise_last'
            targets = [model.net.fc]
        for i, m in enumerate(targets):
            handles.append(m.register_forward_hook(set_hook(i)))

        with torch.no_grad():
            for img, _ in dataloader:
                img = img.to(device)
                if jit_augment is not None:
                    img = jit_augment(img)
                model(img)
    finally:
        # Never leave hooks or init_mode behind, even when the forward pass fails.
        for h in handles:
            h.remove()
        set_kld_bn_mode(model, False)

    # process feat_stats
    if global_model:
        pieces = []
        for stats in feat_stats.values():
            if global_feats:
                pieces.append(torch.cat([stats['running_mean'].view(1),
                                         torch.sqrt(stats['running_var'].view(1) + eps)]))
            else:
                pieces.append(torch.cat([stats['running_mean'], stats['running_var']]))
        mask_net_input = torch.cat(pieces)

        if mode == 'layerwise_samples':
            assert no_of_samples is not None
            mask_net_input = torch.cat([torch.log(torch.tensor([no_of_samples])).to(device), mask_net_input])

        return mask_net_input.view(1, -1)

    # layerwise_local for local models
    mask_net_input = {}
    for idx, stats in feat_stats.items():
        mask_net_input[idx] = torch.cat([stats['running_mean'], torch.sqrt(stats['running_var'] + eps)]).view(1, -1)
    return mask_net_input

def set_kld_bn_mode(m, v):
    """
    Recursively set ``init_mode = v`` on every descendant of *m* whose type is
    exactly ``KLDBatchNorm1d`` or ``KLDBatchNorm2d`` (*m* itself is not checked).
    """
    kld_types = (KLDBatchNorm1d, KLDBatchNorm2d)
    for child in m.children():
        if type(child) in kld_types:
            child.init_mode = v
        # Descend regardless of the child's type.
        set_kld_bn_mode(child, v)

def copy_pretrain_to_kld(pretrain_weights, model):
    """
    Load pretrained weights into a KLD-batch-norm model and snapshot its BN stats.

    Three input layouts are handled:
      - a dict containing ``'classifier.bias'`` (baseline model): keys with
        ``classifier`` are renamed to ``net.fc`` (shape asserted against the
        model), all others have ``base`` replaced by ``net``;
      - any other dict (e.g. a torchvision ImageNet checkpoint): each key is
        prefixed with ``net.`` and kept only if the model has an entry of that
        name with the same size;
      - a non-dict sequence of arrays: matched positionally to the model's
        state-dict keys that do not contain ``init_running``.

    The result is loaded with ``strict=False`` and ``reinit_bn(model)`` is
    called afterwards.
    """
    target_sd = model.state_dict()
    to_load = OrderedDict({})

    is_dict = dict in inspect.getmro(type(pretrain_weights))
    if not is_dict:
        plain_keys = [key for key in target_sd.keys() if 'init_running' not in key]
        assert len(plain_keys) == len(pretrain_weights)

        for key, array in zip(plain_keys, pretrain_weights):
            to_load[key] = torch.Tensor(np.atleast_1d(array))
    elif 'classifier.bias' in pretrain_weights:
        # baseline model
        for src_key, value in pretrain_weights.items():
            if 'classifier' in src_key:
                dst_key = src_key.replace('classifier', 'net.fc')
                assert value.shape == target_sd[dst_key].shape
            else:
                dst_key = src_key.replace('base', 'net')
            to_load[dst_key] = value
    else:
        # pytorch's pretrained imagenet model
        for src_key, value in pretrain_weights.items():
            dst_key = f'net.{src_key}'
            if dst_key in target_sd and target_sd[dst_key].size() == value.size():
                to_load[dst_key] = value

    model.load_state_dict(to_load, strict=False)
    reinit_bn(model)

def reinit_bn(m):
    """
    Recursively call ``register_init_parameters()`` on every descendant of *m*
    whose type is exactly ``KLDBatchNorm1d`` or ``KLDBatchNorm2d`` (*m* itself
    is not checked).
    """
    kld_types = (KLDBatchNorm1d, KLDBatchNorm2d)
    for child in m.children():
        if type(child) in kld_types:
            child.register_init_parameters()
        # Descend regardless of the child's type.
        reinit_bn(child)

def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.):
    """
    Fill *tensor* in place with samples from a normal distribution truncated to ``[a, b]``.

    Args:
        tensor: tensor to fill (modified in place, without recording gradients).
        mean: mean of the underlying normal distribution.
        std: standard deviation of the underlying normal distribution.
        a: lower truncation bound.
        b: upper truncation bound.

    Returns:
        The same *tensor*, for chaining.

    A warning is emitted when ``mean`` lies more than two standard deviations
    outside ``[a, b]``, since the sampled values may then be poorly distributed.
    """
    # Cut & paste from PyTorch official master until it's in a few official releases - RW
    # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
    import warnings  # local import: keeps this block self-contained

    def norm_cdf(x):
        # Computes standard normal cumulative distribution function
        return (1. + math.erf(x / math.sqrt(2.))) / 2.

    if (mean < a - 2 * std) or (mean > b + 2 * std):
        # `stacklevel` is a warnings.warn() argument; print() rejects it.
        warnings.warn("mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
                      "The distribution of values may be incorrect.",
                      stacklevel=2)

    with torch.no_grad():
        # Values are generated by using a truncated uniform distribution and
        # then using the inverse CDF for the normal distribution.
        # Get upper and lower cdf values
        l = norm_cdf((a - mean) / std)
        u = norm_cdf((b - mean) / std)

        # Uniformly fill tensor with values from [l, u], then translate to
        # [2l-1, 2u-1].
        tensor.uniform_(2 * l - 1, 2 * u - 1)

        # Use inverse cdf transform for normal distribution to get truncated
        # standard normal
        tensor.erfinv_()

        # Transform to proper mean, std
        tensor.mul_(std * math.sqrt(2.))
        tensor.add_(mean)

        # Clamp to ensure it's in the proper range
        tensor.clamp_(min=a, max=b)
        return tensor

class Scaler(nn.Module):
    """Parameter-free layer that divides its input by a fixed ``rate``."""

    def __init__(self, rate):
        super().__init__()
        self.rate = rate

    def forward(self, input):
        return input / self.rate

class KDLoss(nn.Module):
    """
    Cross-entropy plus an optional knowledge-distillation term.

    ``args.KD_T`` is the softmax temperature and ``args.KD_gamma`` the weight
    of the distillation term. The distillation term is
    ``KLDivLoss(log_softmax(pred / T), softmax(soft_target / T)) * gamma * T * T``
    using ``nn.KLDivLoss`` with its default reduction.
    """

    def __init__(self, args):
        super().__init__()

        self.kld_loss = nn.KLDivLoss()
        self.ce_loss = nn.CrossEntropyLoss()
        self.log_softmax = nn.LogSoftmax(dim=1)
        self.softmax = nn.Softmax(dim=1)

        self.T = args.KD_T
        self.gamma = args.KD_gamma

    def loss_fn_kd(self, pred, target, soft_target, gamma_active=True):
        """Return CE(pred, target), plus the distillation term when ``gamma`` is truthy and ``gamma_active``."""
        hard_loss = self.ce_loss(pred, target)
        if not (self.gamma and gamma_active):
            return hard_loss + 0

        temperature = self.T
        student_log_probs = self.log_softmax(pred / temperature)
        teacher_probs = self.softmax(soft_target / temperature)
        soft_loss = self.kld_loss(student_log_probs, teacher_probs) * self.gamma * temperature * temperature
        return hard_loss + soft_loss

def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
    """Build a bias-free 3x3 ``nn.Conv2d`` whose padding equals its dilation."""
    conv_kwargs = dict(
        kernel_size=3,
        stride=stride,
        padding=dilation,
        groups=groups,
        bias=False,
        dilation=dilation,
    )
    return nn.Conv2d(in_planes, out_planes, **conv_kwargs)

def conv1x1(in_planes, out_planes, stride=1):
    """Build a bias-free 1x1 ``nn.Conv2d``."""
    conv_kwargs = dict(kernel_size=1, stride=stride, bias=False)
    return nn.Conv2d(in_planes, out_planes, **conv_kwargs)
   
class Classifier(nn.Module):
    """
    Early-exit classification head.

    ``features`` is a stack of ``num_conv_layers`` groups, each made of
    ``conv3x3 -> scaler -> BatchNorm2d -> ReLU``, followed by global average
    pooling, flattening and a linear layer producing ``num_classes`` logits.

    Args:
        in_planes: channels of the incoming feature map.
        num_classes: size of the output logits.
        num_conv_layers: number of conv groups in ``features``.
        reduction: when > 1, the convs output ``in_planes // reduction``
            channels instead of ``in_planes``.
        scale: when < 1.0 a ``Scaler(scale)`` follows each conv; otherwise an
            ``nn.Identity``. The same module instance is reused in every group.
        track_running_stats: forwarded to every ``BatchNorm2d``.
    """

    def __init__(
        self,
        in_planes: int,
        num_classes: int,
        num_conv_layers: int = 3,
        reduction: int = 1,
        scale: float = 1.0,
        track_running_stats: bool = False,   # ScaleFL often sets this False
    ):
        super().__init__()

        # Width used by every conv output (optionally reduced).
        width = in_planes // reduction if reduction > 1 else in_planes

        # Rescaling layer placed after each conv when width scaling is active.
        rescale = Scaler(scale) if scale < 1.0 else nn.Identity()

        stack = []
        fan_in = in_planes
        for _ in range(num_conv_layers):
            stack += [
                conv3x3(fan_in, width),
                rescale,
                nn.BatchNorm2d(width, track_running_stats=track_running_stats),
                nn.ReLU(inplace=True),
            ]
            fan_in = width  # every conv after the first maps width -> width

        self.features = nn.Sequential(*stack)
        self.avgpool  = nn.AdaptiveAvgPool2d(1)
        self.flatten  = nn.Flatten()
        self.fc       = nn.Linear(width, num_classes)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        pooled = self.avgpool(self.features(x))
        flat = self.flatten(pooled)   # shape → (B, C)
        return self.fc(flat)
