import logging
from Settings import *
from Util import *
from Optim import VRL as OP1
from Optim import FedProx as OP2
from Optim import FedNova as OP3

import torch
# Gate for verbose compression logging: while this is False,
# Client_Sim._log_compression_stats returns immediately.
COMPRESSION_VERBOSE = False
# cuDNN backend switches. They are set at import time, so they apply to the
# whole process, not just to this module.
torch.backends.cudnn.enabled = True
# NOTE(review): per the PyTorch docs, benchmark=True lets cuDNN try several
# convolution algorithms and keep the fastest one — confirm input shapes are
# stable enough for that to pay off.
torch.backends.cudnn.benchmark = True
# NOTE(review): deterministic=False permits non-deterministic cuDNN kernels,
# so runs are presumably not bit-reproducible — confirm this is intended.
torch.backends.cudnn.deterministic = False

class Client_Sim:
    _compression_logged = False
    def __init__(self, Loader, Model, Lr, wdecay, epoch=2, fixlr=False, optzer="SGD", compression_configs=None, decay_step=10, decay_rate=0.9):
        """Build a simulated client holding a private copy of ``Model``.

        Args:
            Loader: Iterable of ``(inputs, targets)`` batches. It is iterated
                once here to count samples and is reused as calibration data
                for the loss-aware rules.
            Model: Model to deep-copy; the caller's instance is not modified.
            Lr: Learning rate.
            wdecay: Weight decay.
            epoch: Stored as ``self.Epoch``.
            fixlr: Stored as ``self.FixLR``.
            optzer: Optimizer name. Only ``"VRL"`` creates an optimizer here;
                for any other value ``self.optimizer`` stays ``None``.
            compression_configs: Optional dict of compression settings. ``None``
                or an empty dict selects the defaults used below.
            decay_step: Stored as ``self.decay_step``.
            decay_rate: Stored as ``self.decay_rate``.

        Raises:
            ValueError: If ``granularity`` is neither ``'element'`` nor
                ``'low_rank'``.
        """
        # The signature defaults to None, so normalise once instead of
        # crashing on the first ``.get``.
        cfg = compression_configs or {}

        # Training DataLoader. Only batch sizes are needed for the sample
        # count, so nothing is moved to the device here.
        self.TrainData = Loader
        self.DLen = sum(len(inputs) for inputs, _targets in self.TrainData)

        self.BeforeParas = None
        self.Model = cp.deepcopy(Model)
        self.GradParas = None

        self.device = device

        # Loss-aware compression configs
        self.comp_weight = cfg.get("weight", "cin_diag")  # 'none' | 'cin_diag'
        self.calib_batches = cfg.get("calib_batches", 64)  # max batches used for calibration
        self.comp_eps = cfg.get("eps", 1e-8)
        self.tac_random_sampling = cfg.get("random_sampling", True)  # pick calibration batches at random

        # Selection rule: 'magnitude' | 'lossaware' | 'magnitude_svd' | 'lossaware_svd'
        self.selection_rule = cfg.get("rule", "magnitude")
        # Calibration data is only kept for the loss-aware rules.
        if self.selection_rule in ["lossaware", "lossaware_svd"]:
            self.CalibData = Loader
            print("[Client] Running calibration for lossaware statistics")
        else:
            self.CalibData = None

        self.Optzer = optzer
        self.Wdecay = wdecay
        self.Epoch = epoch
        self.Mu = 0.001
        self.Round = 0
        self.LR = Lr
        self.decay_step = decay_step
        self.decay_rate = decay_rate
        self.optimizer = None
        self.local_steps = 1
        if self.Optzer == "VRL":
            self.optimizer = OP1.VRL(
                self.Model.parameters(),
                lr=self.LR,
                momentum=0.9,
                weight_decay=self.Wdecay,
                vrl=True,
                local=True,
            )
        self.loss_fn = nn.CrossEntropyLoss()
        self.FixLR = fixlr
        self.trainloss = 0
        self.difloss = 0

        # Compression configs. The two granularities differ only in the
        # default ratio.
        self.comp_granularity = cfg.get("granularity", "element")
        if self.comp_granularity == 'element':
            default_ratio = 0.20
        elif self.comp_granularity == 'low_rank':
            default_ratio = 0.1
        else:
            raise ValueError(f"Unknown granularity: {self.comp_granularity}")
        self.comp_ratio = cfg.get("ratio", default_ratio)
        self.comp_mode = cfg.get("mode", "global")
        self.comp_include_bias = cfg.get("include_bias", True)

        # Rank selection configs for low-rank SVD compression
        self.rank_mode = cfg.get("rank_mode", "ratio")  # 'ratio' | 'fixed' | 'mixed'
        self.fixed_rank = cfg.get("fixed_rank", 16)
        self.min_rank = cfg.get("min_rank", 1)
        self.max_rank = cfg.get("max_rank", 10**9)

        # Error feedback
        self.comp_ef_enabled = cfg.get("ef_enabled", True)
        self.ef_reset_on_global = cfg.get("ef_reset_on_global", True)
        self.ef_residuals = {}

        # Log frequency
        self.log_layer_stats_every = cfg.get("log_layer_stats_every", 50)

        # cin_diag statistics cache: "<layer_name>.weight" -> 1-D tensor
        self.cin_diag = {}
        # Rank summary cache (for final logging)
        self.last_total_rank = 0
        self.last_avg_rank = 0.0

        # Overlap stats for element-wise and SVD selections
        self.elem_overlap_stats = {
            'total_inconsistent': 0,
            'total_selected': 0,
            'rounds': 0,
        }
        self.svd_overlap_stats = {
            'total_inconsistent_ranks': 0,
            'total_compared_ranks': 0,
        }

        # Estimate the overall compression rate once per process (class-level
        # flag); only meaningful for low-rank granularity.
        if self.comp_granularity == 'low_rank' and not Client_Sim._compression_logged:
            try:
                self.compute_overall_compression()
            except Exception as e:
                logging.warning(f"[Compression] compute_overall_compression failed: {e}")
            Client_Sim._compression_logged = True

    def _should_exclude_param_key(self, ky: str) -> bool:
        """Parameters not to be compressed"""
        k = ky.lower()
        # For ViT
        if ("pos_embed" in k) or ("cls_token" in k) or ("class_token" in k) or ("encoder.pos_embedding" in k):
            return True
        # For LayerNorm（ViT）and BatchNorm（CNN）
        elif ("layernorm" in k) or (".ln" in k) or ("norm" in k and (k.endswith(".weight") or k.endswith(".bias"))) or ("bn" in k and (k.endswith(".weight") or k.endswith(".bias"))):
            return True
        # For bias: only exclude specific bias parameters like BatchNorm
        # CNN convolution bias should NOT be excluded
        else:
            return False
        # return False

    @torch.no_grad()
    def _accumulate_cin_diag(self, loader=None, max_batches=None):
        """
        Compute column energy diag(X^T X) for each layer's input: Din for Linear, Cin*Kh*Kw for Conv2d.
        No intermediate activation storage; streaming accumulation via forward hooks. Float32 for numerical stability.
        Fallback: simple / cnt averaging (assuming equal batch weights), no weighted total_samples.
        """
        import torch.nn.functional as F
        import logging

        if loader is None:
            loader = self.CalibData  

        was_training = self.Model.training
        self.Model.eval()

        # Clear old statistics
        self.cin_diag = {}

        handles = []
        total_sum_diag = {}  # Global: per-layer sum_diag (tensor)

        def reg_linear_hook(layer_name):
            def hook(mod, inp, out):
                x = inp[0].detach().float()            # [B, ...] get input activations
                
                # Handle ViT multi-dimensional input: [B, seq_len, hidden_dim] -> [B*seq_len, hidden_dim]
                if x.dim() == 3:  # ViT sequence input
                    x = x.view(-1, x.shape[-1])        # [B*seq_len, hidden_dim]
                elif x.dim() > 2:  # Other multi-dimensional input
                    x = x.view(x.shape[0], -1)         # [B, flattened_dim]
                
                # Compute input column energy diagonal
                diag = (x * x).sum(dim=0)              # [hidden_dim] or [flattened_dim], sum over B_eff
                key = f"{layer_name}.weight"
                
                # Check dimension matching
                if hasattr(mod, 'weight') and mod.weight is not None:
                    weight_in_dim = mod.weight.shape[1]
                    if diag.numel() != weight_in_dim:
                        logging.warning(f"[Dimension Mismatch] {layer_name}: activation_cols={diag.numel()}, weight_cols={weight_in_dim}")
                
                if key not in total_sum_diag:
                    total_sum_diag[key] = diag.clone()
                else:
                    total_sum_diag[key].add_(diag)  # In-place accumulation

            return hook

        def reg_conv2d_hook(layer_name, mod: nn.Conv2d):
            k, s, p, d = mod.kernel_size, mod.stride, mod.padding, mod.dilation
            def hook(mod, inp, out):
                x = inp[0].detach().float()            # [B, C, H, W]
                B, C, H, W = x.shape
                # Compute L = H' * W' (unfold output positions, considering s/p/d)
                kh, kw = k if isinstance(k, tuple) else (k, k)
                sh, sw = s if isinstance(s, tuple) else (s, s)
                ph, pw = p if isinstance(p, tuple) else (p, p)
                dh, dw = d if isinstance(d, tuple) else (d, d)
                H_out = (H + 2 * ph - dh * (kh - 1) - 1) // sh + 1
                W_out = (W + 2 * pw - dw * (kw - 1) - 1) // sw + 1
                L = H_out * W_out
                
                cols = F.unfold(x, kernel_size=k, dilation=d, padding=p, stride=s)  # [B, Ckk, L]
                cols = cols.transpose(1, 2).reshape(-1, cols.shape[1])              # [B*L, Ckk]
                diag = (cols * cols).sum(dim=0)                                     # [Ckk]，sum over B*L
                key = f"{layer_name}.weight"
                
                # Check dimension matching (symmetric Linear)
                expected_size = C * kh * kw
                if diag.numel() != expected_size:
                    logging.warning(f"[Dimension Mismatch Conv] {layer_name}: activation={diag.numel()}, expected={expected_size}")
                
                if key not in total_sum_diag:
                    total_sum_diag[key] = diag.clone()
                else:
                    total_sum_diag[key].add_(diag)  # In-place accumulation

            return hook

        #  register hook
        for name, module in self.Model.named_modules():
            if isinstance(module, nn.Linear):
                handles.append(module.register_forward_hook(reg_linear_hook(name)))
            elif isinstance(module, nn.Conv2d):
                handles.append(module.register_forward_hook(reg_conv2d_hook(name, module)))

        #  select batch for statistics
        max_batches = max_batches or self.calib_batches  # assume self.calib_batches is defined
        cnt = 0
        
        if self.tac_random_sampling:
            #  random select batch (using Reservoir Sampling algorithm)
            import random
            reservoir = []
            
            for batch_idx, (xb, _) in enumerate(loader):
                if len(reservoir) < max_batches:
                    reservoir.append((xb, batch_idx))
                else:
                    prob = max_batches / (batch_idx + 1)
                    if random.random() < prob:
                        replace_idx = random.randint(0, max_batches - 1)
                        reservoir[replace_idx] = (xb, batch_idx)
            
            #  process selected batch
            for xb, original_idx in reservoir:
                if len(xb) == 0:  # add: skip empty batch
                    continue
                xb = xb.to(self.device)
                _ = self.Model(xb)
                cnt += 1

            if cnt == 0:  # add fallback: if random fails, switch to sequential
                logging.warning("[CinDiag] Random sampling got 0 batches, fallback to sequential")
                self.tac_random_sampling = False  # temporarily disable random sampling

        if not self.tac_random_sampling or cnt == 0:
            # sequential select first max_batches batches (fallback or original sequential)
            logging.info(f"[CinDiag] Sequential sampling: using first {max_batches} batches for statistics")
            for xb, _ in loader:
                if len(xb) == 0:  # skip empty
                    continue
                xb = xb.to(self.device)
                _ = self.Model(xb)
                cnt += 1
                if cnt >= max_batches:
                    break

        # v = total_sum / cnt
        if cnt > 0 and total_sum_diag:
            for k, v in total_sum_diag.items():
                self.cin_diag[k] = (v / cnt).to(device=self.device, dtype=torch.float32).clamp(min=1e-8)
            if self.cin_diag:
                all_diag = torch.cat([v for v in self.cin_diag.values()])
        else:
            logging.warning("[CinDiag] No samples accumulated, all v=0")

        # fill layers not hooked
        for name, mod in self.Model.named_modules():
            if isinstance(mod, (nn.Linear, nn.Conv2d)):
                key = f"{name}.weight"
                if key not in self.cin_diag:
                    if isinstance(mod, nn.Linear):
                        init_size = mod.weight.shape[1]  # Din for Linear
                    elif isinstance(mod, nn.Conv2d):
                        kh, kw = mod.kernel_size if isinstance(mod.kernel_size, tuple) else (mod.kernel_size, mod.kernel_size)
                        init_size = mod.in_channels * kh * kw  # Cin * Kh * Kw
                    else:
                        init_size = mod.weight.shape[1]  # fallback
                    logging.info(f"[CinDiag] {key} never hooked, init zeros of size {init_size}")
                    self.cin_diag[key] = torch.zeros(init_size, device=mod.weight.device, dtype=torch.float32)

        for h in handles:
            h.remove()
        if was_training:
            self.Model.train()




    def _col_energy(self, name):
        """
        Get the √diag(X^T X) of this layer. Return None when no statistics.
        """
        v = self.cin_diag.get(name, None)
        if v is None:
            return None
        vv = v.clamp_min(self.comp_eps)
        p = getattr(self, 'cin_power', 1.0)
        if p != 1.0:
            vv = vv.pow(p)
        return vv.sqrt().to(device)

    def _score_element(self, name, dW_abs):
        """
        get the loss-aware score for element level compression units (for *.weight).
        For Linear layer, w_ij * column-wise broadcast ( sqrt(diag(X X^T) )
        For Conv layer, 
        """
        if self.selection_rule != "lossaware":
            return dW_abs

        v = self._col_energy(name)  # unified get (Conv/Linear shared)

        # unified Miss layer check: None/non-tensor/empty/all-zero/NaN/negative → fallback magnitude
        if v is None or not isinstance(v, torch.Tensor) or v.numel() == 0 or \
        torch.all(v == 0) or torch.any(torch.isnan(v)) or torch.any(v < 0):
            # logging.warning(f"[Score] Invalid v for {name} (None/empty/neg/NaN/zeros), fallback to magnitude")
            return dW_abs

        # numerical stability: clamp to prevent zero
        v = v.clamp(min=1e-8)

        if dW_abs.dim() == 4:  # Conv2d: [Cout, Cin, Kh, Kw]
            Cout, Cin, Kh, Kw = dW_abs.shape
            expected_size = Cin * Kh * Kw

            # shape check: prioritize per-position (paper formula, more accurate)
            if v.numel() == expected_size:
                try:
                    v_reshaped = v.view(Cin, Kh, Kw)  # [Cin, Kh, Kw]
                    scored = dW_abs * v_reshaped.unsqueeze(0)  # [Cout, Cin, Kh, Kw] * [1, Cin, Kh, Kw]
                    return scored
                except Exception as e:  # view failed (e.g., not divisible)
                    logging.warning(f"[Score] per-position view failed for {name}: {e}, try flat fallback")

            # fallback: flatten multiplication, but check matching first
            if v.numel() != expected_size:
                logging.warning(f"[Score] Shape mismatch for {name}: v={v.numel()} vs expected={expected_size}, fallback to magnitude")
                return dW_abs

            # safe flatten * broadcast
            flat = dW_abs.view(Cout, -1)  # [Cout, expected_size]
            scored_flat = flat * v.unsqueeze(0)  # [Cout, expected_size] * [1, expected_size]
            return scored_flat.view_as(dW_abs)  # restore [Cout, Cin, Kh, Kw]

        elif dW_abs.dim() == 2:  # Linear: [Dout, Din]
            Din = dW_abs.shape[1]
            if v.numel() != Din:
                logging.warning(f"[Score] Shape mismatch for {name}: v={v.numel()} vs Din={Din}, fallback to magnitude")
                return dW_abs
            return dW_abs * v.unsqueeze(0)  # [Dout, Din] * [1, Din]

        else:
            return dW_abs 



    def _compute_budget(self, total_params, layer_shapes):
        """Return the compression budget as an integer of at least 1.

        For ``'element'`` granularity the budget is ``comp_ratio`` times
        ``total_params`` (the ratio is a fraction such as 0.1). For
        ``'low_rank'`` it is ``comp_ratio`` itself, truncated to an int (the
        ratio is a count such as 1, 2, 3). ``layer_shapes`` is accepted but
        not used.

        Raises:
            ValueError: For any other granularity.
        """
        mode = self.comp_granularity
        if mode not in ('element', 'low_rank'):
            raise ValueError(f"Unknown granularity: {self.comp_granularity}")

        raw_budget = self.comp_ratio * total_params if mode == 'element' else self.comp_ratio
        return max(1, int(raw_budget))

    def getParas(self):
        return {k: v.detach().clone() for k, v in self.Model.state_dict().items()}

    def getDeltaParas(self):
        """Return the stored parameter delta (``self.GradParas``) without copying it."""
        stored_delta = self.GradParas
        return stored_delta

    def getCompDeltaParas(self):
        """
        # select compression units (element, singular value) based on magnitude or loss-aware metric
        # Return: Compressed ΔW
        """
        assert self.GradParas is not None and len(self.GradParas) > 0, "call selftrain() first"

        # ΔW, ready in "self.train" function, storge in CPU
        delta = {k: v.to(device) if torch.is_tensor(v) else v for k, v in self.GradParas.items()}

        # Error feedback
        effective_delta = cp.deepcopy(delta)
        if self.comp_ef_enabled:
            for ky, d in effective_delta.items():
                if not torch.is_tensor(d):
                    continue
                r = self.ef_residuals.get(ky, None)
                if r is not None:
                    if r.shape == d.shape:
                        effective_delta[ky] = d + r
                    else:
                        raise ValueError(f"Shape Error: Delta Shape {d.shape} and Residual Shape {r.shape} not match!")

        # dispatch based on granularity and strategy
        if self.comp_granularity == 'element':
            if self.selection_rule in ['magnitude', 'lossaware']:
                return self._compress_delta_element_wise(effective_delta)
            else:
                raise ValueError(f"Selection rule {self.selection_rule} not supported for element granularity.")
        elif self.comp_granularity == 'low_rank':
            if self.selection_rule in ['magnitude', 'magnitude_svd']:
                return self._compress_delta_svd_wise(effective_delta, use_lossaware=False)
            elif self.selection_rule == 'lossaware_svd':
                return self._compress_delta_svd_wise(effective_delta, use_lossaware=True)
            else:
                raise ValueError(f"Selection rule {self.selection_rule} not supported for low_rank granularity.")
        else:
            raise ValueError(f"Unknown granularity: {self.comp_granularity}")

    def _compress_delta_element_wise(self, delta):
        """Element-wise compress delta ΔW"""
        layer_items = []
        total_elems = 0
        for ky, d in delta.items():
            if not torch.is_tensor(d):
                continue

            # whitelist check: skip parameters that should not be compressed
            if self._should_exclude_param_key(ky):
                continue

            if ("weight" in ky):
                scored = self._score_element(ky, d.detach().abs())
                flat = scored.flatten()
                layer_items.append((ky, flat))
                total_elems += flat.numel()
            elif ("bias" in ky) and self.comp_include_bias:
                flat = d.detach().abs().flatten()
                layer_items.append((ky, flat))
                total_elems += flat.numel()

        if len(layer_items) == 0:
            # return zero delta
            return {ky: torch.zeros_like(d) for ky, d in delta.items()}

        # compute thresholds
        thrs = {}
        if self.comp_mode == "global":
            mags_cat = torch.cat([x[1] for x in layer_items])
            N = mags_cat.numel()
            k = max(1, int(self.comp_ratio * N))
            topk_vals, _ = torch.topk(mags_cat, k, largest=True, sorted=False)
            thr = topk_vals.min()
            
            # numerical stability check: prevent global threshold too small causing compression failure
            if thr < 1e-12:
                logging.warning(f"[Compression] Global threshold too small: {thr:.2e}, using fallback")
                thr = torch.tensor(1e-12, device=thr.device, dtype=thr.dtype)
            
            for ky, _ in layer_items:
                thrs[ky] = thr
        else:  # per-layer
            for ky, flat in layer_items:
                k_i = max(1, int(self.comp_ratio * flat.numel()))
                vals, _ = torch.topk(flat, k_i, largest=True, sorted=False)
                thr_i = vals.min()
                
                # numerical stability check: prevent threshold too small causing compression failure
                if thr_i < 1e-12:
                    logging.warning(f"[Compression] Layer {ky} threshold too small: {thr_i:.2e}, using fallback")
                    thr_i = torch.tensor(1e-12, device=thr_i.device, dtype=thr_i.dtype)
                
                thrs[ky] = thr_i

        # pre-calculate: magnitude threshold for difference statistics (layer-wise calculation, compatible with any mode)
        mag_thrs = {}
        if self.selection_rule == 'lossaware':
            for ky, d in delta.items():
                if torch.is_tensor(d) and (("weight" in ky) or ("bias" in ky and self.comp_include_bias)) and (not self._should_exclude_param_key(ky)):
                    flat_mag = d.detach().abs().flatten()
                    k_i = max(1, int(self.comp_ratio * max(1, flat_mag.numel())))
                    mv, _ = torch.topk(flat_mag, k_i, largest=True, sorted=False)
                    mag_thrs[ky] = mv.min()

        # apply mask to delta
        sparse_delta = {}
        kept_deltas = {}  # for EF
        
        for ky, d in delta.items():
            # whitelist check: skip parameters that should not be compressed
            if self._should_exclude_param_key(ky):
                sparse_delta[ky] = d
                if self.comp_ef_enabled:
                    kept_deltas[ky] = d
                continue

            if ("weight" in ky) or ("bias" in ky and self.comp_include_bias):
                if ky in thrs:
                    if ("weight" in ky):
                        scored = self._score_element(ky, d.detach().abs())
                        mask = (scored >= thrs[ky]).to(d.dtype)
                        # difference statistics: loss-aware vs magnitude selection difference
                        if self.selection_rule == 'lossaware' and ky in mag_thrs:
                            mag_mask = (d.detach().abs() >= mag_thrs[ky]).to(d.dtype)
                            la_sel = (mask > 0)
                            mg_sel = (mag_mask > 0)
                            diff = (la_sel ^ mg_sel).sum().item()
                            chosen = (la_sel | mg_sel).sum().item()
                            self.elem_overlap_stats['total_inconsistent'] += int(diff)
                            self.elem_overlap_stats['total_selected'] += int(chosen)
                    else:  # bias
                        mask = (d.detach().abs() >= thrs[ky]).to(d.dtype)

                    sparse_delta[ky] = d * mask  # directly apply mask to delta

                    if self.comp_ef_enabled:
                        kept_deltas[ky] = d * mask
                else:
                    sparse_delta[ky] = torch.zeros_like(d)
            else:
                sparse_delta[ky] = torch.zeros_like(d)

        self._update_ef_residuals(delta, kept_deltas)
        if self.selection_rule == 'lossaware':
            self.elem_overlap_stats['rounds'] += 1
        return sparse_delta

    def _compress_delta_svd_wise(self, delta, use_lossaware=False):
        """SVD compress delta ΔW, optional loss-aware"""
        sparse_delta = {}
        compression_stats = {'total_layers': 0, 'compressed_layers': 0, 'total_rel_err': 0.0}
        total_rank = 0
        ranked_layers = 0
        
        for ky, d in delta.items():
            if self._should_exclude_param_key(ky):
                sparse_delta[ky] = d
                continue

            if ("weight" in ky) and d.dim() == 2:
                compression_stats['total_layers'] += 1
                try:
                    # use general rank selection logic
                    rank = self._select_rank((d.shape[0], d.shape[1]))
                    eff_rank_for_log = max(1, min(rank, min(d.shape[0], d.shape[1])))
                    proj_energy = None
                    if use_lossaware:
                        proj_energy = self.svd_proj_energy.get(ky, None)
                        # compatible with the case that .weight is not added in the statistics stage
                        if proj_energy is None and ky.endswith('.weight'):
                            proj_energy = self.svd_proj_energy.get(ky.replace('.weight', ''), None)
                    compressed_delta = self._apply_svd_compression_lossaware(
                        ky, d, rank, proj_energy, use_lossaware
                    )
                    sparse_delta[ky] = compressed_delta
                    compression_stats['compressed_layers'] += 1
                    total_rank += eff_rank_for_log
                    ranked_layers += 1
                    
                    # compute compression error for summary
                    try:
                        rel_err = (d - compressed_delta).norm() / (d.norm() + 1e-12)
                        compression_stats['total_rel_err'] += rel_err.item()
                    except:
                        pass
                        
                except Exception as e:
                    if self.Round % 50 == 0:  # reduce error log frequency
                        logging.warning(f"[SVD] {ky} failed: {e}; fallback to original delta.")
                    sparse_delta[ky] = d
            else:
                sparse_delta[ky] = torch.zeros_like(d)
        
        # only print overall information once at the beginning through compute_overall_compression; here no longer print by round
            
        # record current round rank summary for training end statistics
        self.last_total_rank = int(total_rank)
        self.last_avg_rank = float((total_rank / max(1, ranked_layers)) if ranked_layers > 0 else 0.0)
        return sparse_delta

    def _apply_svd_compression_lossaware(self, ky, Wdelta, rank, proj_energy=None, use_lossaware=False):
        """SVD reconstruction, support loss-aware scoring"""
        device_orig = Wdelta.device
        Wdelta_cpu = Wdelta.detach().to('cpu')
        U, S, Vh = torch.linalg.svd(Wdelta_cpu, full_matrices=False)
        # re-select rank based on configuration, ensure alignment with current layer
        rank = self._select_rank((U.shape[0], Vh.shape[1]))
        eff_rank = max(1, min(rank, S.shape[0]))  # first compute eff_rank
        if use_lossaware and (proj_energy is not None):
            proj_energy = proj_energy.to(S.device)
            common_len = min(S.shape[0], proj_energy.shape[0])
            if common_len <= 0:
                scores = S.abs()
            else:
                scores = (S[:common_len]**2) * proj_energy[:common_len]
            # statistics loss-aware vs magnitude rank selection difference
            try:
                k_debug = min(eff_rank, S.numel())  # use the actual selected rank number
                idx_mag = torch.topk(S.abs(), k_debug).indices
                L = common_len
                scores_la = (S[:L]**2) * proj_energy[:L]
                idx_la = torch.topk(scores_la, min(k_debug, L)).indices
                mag_set = set(idx_mag.tolist())
                la_set = set(idx_la.tolist())
                inconsistent = len(mag_set.symmetric_difference(la_set))
                self.svd_overlap_stats['total_inconsistent_ranks'] += int(inconsistent)
                self.svd_overlap_stats['total_compared_ranks'] += int(k_debug)
            except Exception:
                pass
        else:
            if use_lossaware and proj_energy is None:
                logging.warning(f"[LossAware-SVD] No proj_energy for {ky}, fallback to magnitude.")
            scores = S.abs()
        top_idx = torch.topk(scores, eff_rank, largest=True).indices
        U_r = U[:, top_idx]
        S_r = S[top_idx]
        Vh_r = Vh[top_idx, :]
        W_approx = (U_r * S_r) @ Vh_r
        # [SVD] compression strength (relative error after reconstruction) - only recorded under specific conditions
        try:
            rel = (Wdelta - W_approx.to(Wdelta.device)).norm() / (Wdelta.norm() + 1e-12)
            # only record non-backbone layers, and record once every 10 rounds, and only record main layers
            should_log = (
                self.Round % 10 == 0 and  # record once every 10 rounds
                not ("backbone" in ky.lower()) and  # exclude backbone detailed information
                ("fc" in ky.lower() or "classifier" in ky.lower() or "head" in ky.lower())  # only record key layers
            )
            if should_log:
                logging.info(f"[SVD] {ky} rank={eff_rank}/{min(Wdelta.shape)} rel_err={rel:.3f}")
        except Exception:
            pass
        return W_approx.to(device_orig)

    def _accumulate_svd_proj_energy(self, loader=None, max_batches=None):
        """Accumulate the projection energy ||v_t^T X||^2 of every Linear layer.

        A forward hook is attached to each ``nn.Linear`` module. For every
        forward pass the layer input is flattened to 2-D, projected onto the
        right singular vectors (``Vh``) of the layer weight, and the squared
        projections are summed over rows. After the calibration loop the sums
        are divided by the number of processed batches and stored in
        ``self.svd_proj_energy`` under the key ``"<layer_name>.weight"``.

        Args:
            loader: Iterable yielding ``(inputs, targets)`` batches. Defaults
                to ``self.CalibData``.
            max_batches: Maximum number of batches to run. Falls back to
                ``self.calib_batches`` when ``None`` or 0.
        """
        self.svd_proj_energy = {}
        if loader is None:
            loader = self.CalibData
        max_batches = max_batches or self.calib_batches
        was_training = self.Model.training
        self.Model.eval()
        handles = []
        # Weights do not change during calibration, so each layer's Vh is
        # computed once and reused for all batches.
        vh_cache = {}

        def reg_svd_energy_hook(layer_name):
            def hook(mod, inp, out):
                x = inp[0].detach().float()
                # flatten any high-dimensional input into [N_all, Din], ensure alignment with Linear weight input dimension
                if x.dim() > 2:
                    # reshape (not view) so non-contiguous inputs are accepted
                    x = x.reshape(-1, x.shape[-1])
                elif x.dim() == 1:
                    x = x.view(1, -1)
                Vh = vh_cache.get(layer_name)
                if Vh is None:
                    W = mod.weight.detach().float()
                    _, _, Vh = torch.linalg.svd(W, full_matrices=False)
                    vh_cache[layer_name] = Vh
                if x.shape[1] != Vh.shape[1]:
                    logging.warning(f"[SVD-Hook] Shape mismatch at {layer_name}: x.shape={tuple(x.shape)}, Vh.shape={tuple(Vh.shape)}; skip this batch")
                    return
                proj = x @ Vh.T
                energy = (proj ** 2).sum(dim=0)
                # save key with .weight, consistent with compression stage
                key = f"{layer_name}.weight"
                if key not in self.svd_proj_energy:
                    self.svd_proj_energy[key] = energy
                else:
                    self.svd_proj_energy[key] += energy
            return hook

        try:
            for name, module in self.Model.named_modules():
                if isinstance(module, nn.Linear):
                    handles.append(module.register_forward_hook(reg_svd_energy_hook(name)))

            cnt = 0
            # Outputs are discarded, so no autograd graph is needed.
            with torch.no_grad():
                for xb, _ in loader:
                    xb = xb.to(device)
                    _ = self.Model(xb)
                    cnt += 1
                    if cnt >= max_batches:
                        break

            if cnt > 0:
                for k in self.svd_proj_energy:
                    self.svd_proj_energy[k] = (self.svd_proj_energy[k] / cnt).to(device)
        finally:
            # Always detach the hooks and restore the previous mode, even if a
            # forward pass raised; leaked hooks would fire during training.
            for h in handles:
                h.remove()
            if was_training:
                self.Model.train()

    def _rank_for_shape(self, shape, ratio):
        """calculate adaptive rank for specific shape"""
        if len(shape) == 2:
            dout, din = shape
        elif len(shape) == 4:
            dout, cin, kh, kw = shape
            din = cin * kh * kw
        else:
            return 1
        return max(1, int(ratio * min(dout, din)))

    def _select_rank(self, layer_shape):
        """select rank based on configuration: ratio | fixed | mixed"""
        dout, din = layer_shape
        min_dim = min(dout, din)
        mode = getattr(self, 'rank_mode', 'ratio')
        if mode == 'ratio':
            rank = int(min_dim * float(getattr(self, 'comp_ratio', 0.1)))
        elif mode == 'fixed':
            rank = int(min(getattr(self, 'fixed_rank', 16), min_dim))
        elif mode == 'mixed':
            rank = int(min_dim * float(getattr(self, 'comp_ratio', 0.1)))
            rank = max(int(getattr(self, 'min_rank', 1)), min(rank, int(getattr(self, 'max_rank', min_dim))))
        else:
            raise ValueError(f"Unknown rank_mode: {mode}")
        return max(1, min(rank, min_dim))

    def _update_ef_residuals(self, effective_grad, kept_grads):
        """Error Feedback of Residuals"""
        if not self.comp_ef_enabled:
            return
            
        for ky, g in effective_grad.items():
            if torch.is_tensor(g):
                prev_r = self.ef_residuals.get(ky, torch.zeros_like(g))
                u = g + prev_r  # original delta + previous round residual
                kept = kept_grads.get(ky, torch.zeros_like(g))
                self.ef_residuals[ky] = (u - kept).detach()

    def _log_compression_stats(self, GParas, layer_items, granularity):
        """Log sparsity statistics of the compressed update ``GParas``.

        Only active when the module-level ``COMPRESSION_VERBOSE`` flag is set.
        Logs the overall sparsity of entries whose key contains ``"weight"``
        (and ``"bias"`` when ``self.comp_include_bias`` is truthy), the number
        of ``.weight`` keys missing from ``self.cin_diag`` for the
        ``"lossaware"`` rule, and per-layer sparsity on rounds divisible by
        ``2 * self.log_layer_stats_every``. ``layer_items`` is not used here.
        """
        if not COMPRESSION_VERBOSE:
            return

        def _counts(tensor):
            # (number of non-zero entries, total number of entries)
            return (tensor != 0).sum().item(), tensor.numel()

        nz = tot = 0
        for ky, val in GParas.items():
            if ("weight" in ky) or ("bias" in ky and self.comp_include_bias):
                k_nz, k_tot = _counts(val)
                nz += k_nz
                tot += k_tot
        sparsity = 1 - nz / max(1, tot)
        logging.info(f"[Unified-Compression] granularity={granularity}, rule={self.selection_rule}, ratio={self.comp_ratio}, sparsity={sparsity:.4f}")

        if self.selection_rule == "lossaware":
            miss = sum(
                1 for k in GParas.keys()
                if k.endswith('.weight') and k not in self.cin_diag
            )
            logging.info(f"[Loss-aware] cin_diag layers: {len(self.cin_diag)}, missing: {miss}")

        if self.Round % (self.log_layer_stats_every * 2) != 0:
            return
        for ky, val in GParas.items():
            if "weight" not in ky:
                continue
            layer_nz, layer_tot = _counts(val)
            layer_sparsity = 1 - layer_nz / max(1, layer_tot)
            logging.info(f"[Layer-Stats] {ky}: sparsity={layer_sparsity:.4f}")

    def compute_overall_compression(self):
        """estimate overall parameter compression rate (based on selected rank according to configuration)."""
        try:
            model = getattr(self, 'Model', None)
            if model is None:
                return
            total_orig = 0
            total_after = 0
            total_rank = 0
            for name, p in model.named_parameters():
                if not torch.is_tensor(p):
                    continue
                if p.ndim == 2 and name.endswith('.weight'):
                    dout, din = p.shape
                    total_orig += dout * din
                    r = self._select_rank((dout, din))
                    total_rank += int(max(1, min(r, min(dout, din))))
                    # store as two factors (ignoring r term stored in diagonal S)
                    total_after += dout * r + din * r
                else:
                    # other parameters not compressed, counted as is
                    total_orig += p.numel()
                    total_after += p.numel()
            if total_orig > 0:
                compression_rate = 1.0 - (total_after / total_orig)
                logging.info(f"[Compression] Total orig params: {total_orig}, after: {total_after}, rate: {compression_rate:.4f}, total_rank={total_rank}")
        except Exception as e:
            logging.warning(f"[Compression] compute_overall_compression exception: {e}")

    def total_expected_rank(self):
        """return total expected rank based on current configuration (no logging, used for ending statistics difference)."""
        try:
            model = getattr(self, 'Model', None)
            if model is None:
                return 0
            total_rank = 0
            for name, p in model.named_parameters():
                if torch.is_tensor(p) and p.ndim == 2 and name.endswith('.weight'):
                    dout, din = p.shape
                    r = self._select_rank((dout, din))
                    total_rank += int(max(1, min(r, min(dout, din))))
            return int(total_rank)
        except Exception:
            return 0

    def updateParas(self, global_state_dict):
        self.Model.load_state_dict(global_state_dict)

        # EF residual clear
        if self.comp_ef_enabled and self.ef_reset_on_global:
            self.ef_residuals.clear()

    def updateLR(self, lr):
        self.LR = lr
        self.decay_rate = 1

    def getLR(self):
        return self.LR

    def selftrain(self, control_local=None, control_global=None):
        """Run one round of local training and record the parameter delta.

        Steps:
          1. Snapshot the current parameters into ``self.BeforeParas``.
          2. Multiply ``self.LR`` by ``self.decay_rate`` every
             ``self.decay_step`` rounds.
          3. Build the local optimizer selected by ``self.Optzer``
             (``"VRL"`` reuses the persistent ``self.optimizer``).
          4. Train for ``self.Epoch`` passes over ``self.TrainData`` with
             gradient-norm clipping at 10.0.
          5. Store ``state_dict - BeforeParas`` in ``self.GradParas``.
          6. Refresh the loss-aware calibration statistics if the selection
             rule asks for them.

        Args:
            control_local: Unused; kept for interface compatibility.
            control_global: Unused; kept for interface compatibility.

        Returns:
            ``optimizer.local_normalizing_vec`` for FedNova, otherwise 1.

        Raises:
            ValueError: if ``self.Optzer`` names no supported optimizer.
        """
        self.Round += 1
        self.BeforeParas = self.getParas()

        # Local Learning Rate
        if self.Round % self.decay_step == 0:
            self.LR *= self.decay_rate

        # Client Local Optimizer
        if self.Optzer == "SGD":
            optimizer = torch.optim.SGD(
                self.Model.parameters(), lr=self.LR, momentum=0.9, weight_decay=self.Wdecay)
        elif self.Optzer == "AdamW":
            optimizer = torch.optim.AdamW(
                self.Model.parameters(), lr=self.LR, weight_decay=self.Wdecay,
                betas=(0.9, 0.999), eps=1e-8)
        elif self.Optzer == "FedProx":
            optimizer = OP2.FedProx(
                self.Model.parameters(), lr=self.LR, momentum=0.9,
                weight_decay=self.Wdecay, mu=self.Mu)
        elif self.Optzer == "FedNova":
            optimizer = OP3.FedNova(
                self.Model.parameters(), lr=self.LR, momentum=0.9, weight_decay=self.Wdecay)
        elif self.Optzer == "VRL":
            # VRL keeps its optimizer across rounds; only refresh the LR.
            self.optimizer.param_groups[0]['lr'] = self.LR
            optimizer = self.optimizer
        else:
            # Previously this fell through with optimizer=None and failed
            # later with an opaque AttributeError.
            raise ValueError(f"Unknown client optimizer: {self.Optzer!r}")

        # Loss
        self.trainloss = 0
        self.difloss = 0
        SLoss = []

        # temporarily disable mixed precision training to debug
        use_amp = False  # hasattr(torch.cuda, 'amp') and torch.device(device).type == 'cuda'
        scaler = torch.cuda.amp.GradScaler(enabled=True) if use_amp else None

        self.Model.train()
        Local_Steps = 0

        for _ in range(self.Epoch):
            sum_loss = 0.0
            Local_Steps = 0  # batches seen in the current epoch
            for inputs, targets in self.TrainData:
                Local_Steps += 1
                inputs, targets = inputs.to(device, non_blocking=True), targets.to(device, non_blocking=True)

                # For VRL `optimizer` is `self.optimizer`, so one call suffices.
                optimizer.zero_grad()

                if use_amp:
                    # use complete AMP + GradScaler process
                    with torch.cuda.amp.autocast():
                        outputs = self.Model(inputs)
                        loss = self.loss_fn(outputs, targets)

                    scaler.scale(loss).backward()
                    # Gradients must be unscaled before clipping, otherwise the
                    # threshold is applied to loss-scaled gradients.
                    scaler.unscale_(optimizer)
                    torch.nn.utils.clip_grad_norm_(self.Model.parameters(), 10.0)
                    scaler.step(optimizer)
                    scaler.update()
                else:
                    # regular FP32 training
                    outputs = self.Model(inputs)
                    loss = self.loss_fn(outputs, targets)
                    loss.backward()
                    torch.nn.utils.clip_grad_norm_(self.Model.parameters(), 10.0)
                    optimizer.step()

                self.difloss = 0  # simplify calculation

                sum_loss += loss.item()

            SLoss.append(sum_loss / max(1, Local_Steps))

        self.trainloss = np.mean(SLoss)
        # batches of the last epoch times the number of epochs
        self.local_steps = Local_Steps * self.Epoch

        if self.Optzer == "VRL":
            self.optimizer.update_params()

        NVec = 1
        if self.Optzer == "FedNova":
            NVec = optimizer.local_normalizing_vec

        # Compute the delta parameters
        with torch.no_grad():
            current_params = self.Model.state_dict()
            self.GradParas = {
                k: current_params[k] - self.BeforeParas[k] for k in current_params.keys()
            }

        # compute calibration statistics each round
        if self.selection_rule == "lossaware":
            self._accumulate_cin_diag(loader=self.CalibData, max_batches=self.calib_batches)
        elif self.selection_rule == "lossaware_svd":
            self._accumulate_svd_proj_energy(loader=self.CalibData, max_batches=self.calib_batches)

        # [Check] low-frequency check that statistics were collected
        if self.selection_rule == "lossaware_svd" and self.Round % 50 == 0:
            logging.info(f"[Check] svd_proj_energy layers = {len(getattr(self,'svd_proj_energy',{}))}")

        return NVec

    def evaluate(self, loader=None, max_samples=100000):
        """Evaluate the local model on ``loader``.

        Puts the model in eval mode and iterates until the loader is exhausted
        or at least ``max_samples`` samples have been seen.

        Args:
            loader: Iterable yielding ``(inputs, targets)`` batches.
            max_samples: Stop after this many samples have been evaluated.

        Returns:
            ``(accuracy, mean_batch_loss)``. NOTE: this order is the reverse
            of ``Server_Sim.evaluate``, which returns ``(loss, accuracy)``.
            Both values are 0.0 when the loader yields no batch.
        """
        self.Model.eval()
        loss, correct, samples, iters = 0, 0, 0, 0
        loss_fn = nn.CrossEntropyLoss()
        with torch.no_grad():
            for x, y in loader:
                x, y = x.to(device), y.to(device)
                y_ = self.Model(x)
                _, preds = torch.max(y_.data, 1)
                correct += (preds == y).sum().item()
                loss += loss_fn(y_, y).item()
                samples += y_.shape[0]
                iters += 1
                if samples >= max_samples:
                    break
        # max(1, ...) avoids ZeroDivisionError on an empty loader.
        return correct / max(1, samples), loss / max(1, iters)


class Server_Sim:
    """Simulated federated-learning server.

    Collects client deltas via ``recvInfo``, averages them weighted by sample
    count (``avgDeltas``) and applies them to the global model in
    ``aggParas``, optionally through an adaptive server update
    (``"Adag"``, ``"Adam"`` or ``"Yogi"`` second-moment rule).
    """

    def __init__(self, Loader, Global_Model, Lr, wdecay=0, Fixlr=False, Dname="cifar10"):
        """Store the data loader, move the model to ``device`` and init state.

        ``Dname`` is accepted for interface compatibility and not used here.
        """
        self.TrainData = Loader  # direct reference, save memory
        # ensure model on the correct device (no-op if it already is)
        self.Global_Model = Global_Model.to(device)
        self.optimizer = torch.optim.SGD(
            self.Global_Model.parameters(), lr=Lr, momentum=0.9, weight_decay=wdecay)
        # Note: Server doesn't need learning rate scheduling in standard FL
        self.loss_fn = nn.CrossEntropyLoss()
        self.FixLr = Fixlr
        # Per-round receive buffers, filled by recvInfo and cleared by aggParas.
        self.RecvParas = []
        self.RecvLens = []
        self.RecvScale = []
        self.RecvAs = []
        self.LStep = 0
        self.CStep = 0
        # Adaptive server-update state and hyper-parameters.
        self.Eta = 0.01
        self.Beta1 = 0.5
        self.Beta2 = 0.9
        self.Tau = 0.001
        self.Vt = None  # second-moment estimate, created lazily in aggParas
        self.Mt = None  # first-moment estimate, created lazily in aggParas
        self.Round = 0

    def getParas(self):
        """Return a detached clone of the global model's state dict."""
        return {k: v.detach().clone() for k, v in self.Global_Model.state_dict().items()}

    def getLR(self):
        """Return the learning rate of the server optimizer's first param group."""
        return self.optimizer.param_groups[0]['lr']

    def updateParas(self, Paras):
        """Load ``Paras`` into the global model."""
        self.Global_Model.load_state_dict(Paras)

    def avgDeltas(self, DeltaList, Lens):
        """
        FedAvg Aggregation, Weighted Aggregation according to DataLens
        DeltaList: [ΔW1, ΔW2, ...] client delta list
        Lens: [n1, n2, ...] sample count for each client
        Return: ΔW_agg = Σ(n_i/N * ΔW_i)

        Only keys that are floating-point tensors in the first client's delta
        and that exist in the server state dict are aggregated. Returns ``{}``
        when there are no deltas or the total sample count is not positive.
        """
        if len(DeltaList) == 0:
            return {}
        total_n = float(sum(Lens))
        if total_n <= 0:
            return {}

        # only aggregate floating point parameters; skip integer/long integer (e.g. BN counter)
        server_state = self.Global_Model.state_dict()
        float_keys = [
            ky for ky, v in DeltaList[0].items()
            if torch.is_tensor(v) and v.is_floating_point() and (ky in server_state)
        ]

        # Accumulators share device and dtype with the server tensors.
        Res = {ky: torch.zeros_like(server_state[ky]) for ky in float_keys}

        for i, delta in enumerate(DeltaList):
            weight = Lens[i] / total_n
            for ky in float_keys:
                if ky in delta and torch.is_tensor(delta[ky]) and delta[ky].is_floating_point():
                    # align to target device and dtype
                    addend = delta[ky].to(device=Res[ky].device, dtype=Res[ky].dtype)
                    Res[ky] += addend * addend.new_tensor(weight)

        return Res

    def Adagrad(self, Grad):
        """Second-moment update: ``Vt += 0.25 * Grad**2``."""
        for ky, g in Grad.items():
            self.Vt[ky] = self.Vt[ky] + 0.25 * g ** 2

    def Yogi(self, Grad):
        """Second-moment update: ``Vt -= (1 - Beta2) * Grad**2 * sign(Vt - Grad**2)``."""
        for ky, g in Grad.items():
            Vt = self.Vt[ky]
            g_sq = g ** 2
            self.Vt[ky] = Vt - (1 - self.Beta2) * g_sq * torch.sign(Vt - g_sq)

    def Adam(self, Grad):
        """Second-moment update: ``Vt = Beta2 * Vt + (1 - Beta2) * Grad**2``."""
        for ky, g in Grad.items():
            self.Vt[ky] = self.Beta2 * self.Vt[ky] + (1 - self.Beta2) * g ** 2

    def aggParas(self, Server_Optim="Yogi"):
        """Aggregate the received client deltas into the global model.

        Computes ``θ^{t+1} = θ^t + ΔW_agg``. When ``Server_Optim`` is not
        ``None``, entries whose key contains ``"weight"`` or ``"bias"`` are
        instead updated with the moment-normalised step
        ``median(sqrt(Vt) + Tau) * Mt / (sqrt(Vt) + Tau)``. The receive buffers
        ``RecvParas``, ``RecvLens`` and ``RecvScale`` are cleared afterwards.

        Args:
            Server_Optim: ``"Adag"``, ``"Adam"``, ``"Yogi"`` or ``None``.
        """
        self.Round += 1
        Disc = 0.9

        def _to_float(v):
            # int64/bool buffers (e.g. BN counters) are promoted to float so
            # the arithmetic below works for every entry.
            if torch.is_tensor(v) and v.dtype in (torch.int64, torch.bool):
                return v.float()
            return v

        # Aggregation: θ^{t+1} = θ^t + ΔW_agg
        agg_delta = self.avgDeltas(self.RecvParas, self.RecvLens)  # ΔW_agg

        BParas = self.getParas()  # snapshot with original dtypes
        current_params = {ky: _to_float(v) for ky, v in BParas.items()}

        GParas = {}
        for ky, base in current_params.items():
            if ky in agg_delta:
                # ensure agg_delta tensor and current_params on the same device
                delta = agg_delta[ky]
                if torch.is_tensor(delta):
                    delta = delta.to(device=base.device)
                GParas[ky] = _to_float(base + delta)
            else:
                GParas[ky] = base

        if Server_Optim is not None:
            if self.Vt is None:
                self.Vt = {
                    ky: (torch.zeros_like(G) + self.Tau ** 2).float()
                    for ky, G in GParas.items()
                }

            # Pseudo-gradient: aggregated change relative to the snapshot.
            GetGrad = {ky: (GParas[ky] - BParas[ky]).float() for ky in BParas.keys()}

            if Server_Optim == "Adag":
                self.Adagrad(GetGrad)
            if Server_Optim == "Adam":
                self.Adam(GetGrad)
            if Server_Optim == "Yogi":
                self.Yogi(GetGrad)

            if self.Mt is None:
                self.Mt = {ky: g.clone() for ky, g in GetGrad.items()}

            for ky in self.Mt.keys():
                self.Mt[ky] = self.Mt[ky] * self.Beta1 + \
                    GetGrad[ky] * (1 - self.Beta1)

            for ky in GetGrad.keys():
                if ("weight" in ky) or ("bias" in ky):
                    denom = torch.sqrt(self.Vt[ky]) + self.Tau
                    NewGrad = self.Mt[ky] / denom
                    Eta = torch.median(denom)
                    GParas[ky] = BParas[ky] + Eta * NewGrad

            self.Eta *= Disc
            self.Eta = max(self.Eta, self.Tau)

        # update
        self.updateParas(GParas)
        self.RecvParas = []
        self.RecvLens = []
        self.RecvScale = []

    def recvInfo(self, Para, Len, Scale):
        """Buffer one client's delta, sample count and scale for aggregation."""
        self.RecvParas.append(Para)
        self.RecvLens.append(Len)
        self.RecvScale.append(Scale)

    def evaluate(self, loader=None, max_samples=100000):
        """Evaluate the global model on ``loader``.

        Iterates until the loader is exhausted or at least ``max_samples``
        samples have been seen.

        Returns:
            ``(mean_batch_loss, accuracy)``. NOTE: this order is the reverse
            of ``Client_Sim.evaluate``. Both values are 0.0 when the loader
            yields no batch.
        """
        self.Global_Model.eval()

        loss, correct, samples, iters = 0, 0, 0, 0
        with torch.no_grad():
            for x, y in loader:
                x, y = x.to(device), y.to(device)
                y_ = self.Global_Model(x)
                _, preds = torch.max(y_.data, 1)
                loss += self.loss_fn(y_, y).item()

                correct += (preds == y).sum().item()
                samples += y_.shape[0]
                iters += 1

                if samples >= max_samples:
                    break

        # max(1, ...) avoids ZeroDivisionError on an empty loader.
        return loss / max(1, iters), correct / max(1, samples)

    def saveModel(self, Path):
        """Serialize the whole global model object to ``Path`` with ``torch.save``."""
        torch.save(self.Global_Model, Path)