# src/apps/clients/snowfl_classification_client.py

import copy
from typing import Any, Dict, Iterable, Optional

from src.apps.clients.scalefl_classification_client import ScaleFLClassificationClient
from src.apps.clients.reefl_classification_client import ReeFLClassificationClient
from src.models.model_utils import retarget_exit_preconv_bns
from src.models.model_utils import dump_trunk_taps, dump_ee_invariants, assert_ee_conv_bn_match
from src.models.snip_utils import recalibrate_bn

class SNOWFLClassificationClient(ScaleFLClassificationClient):
    """
    SnowFL client with backwards compatibility + optional SNIP.
    - Honors server 'keys_prog' order.
    - Can apply SNIP masks received at fit-time (structural surgery).
    - Retargets exit-head preconv BNs after surgery to avoid channel mismatches.
    - Optional BN recalibration using val/train loader (no structural change).
    - Returns either full local state (params+buffers) or just trainables.
    """
    def _exit_is_snip(self, exit_i: int) -> bool:
        try:
            pm = getattr(self.ckp.config.app, "args", None)
            modes = getattr(pm, "pruning_mode", []) if pm is not None else []
            return (exit_i < len(modes)) and (str(modes[exit_i]).lower() == "snip")
        except Exception:
            return False

    def _mask_sig(self, d: dict) -> str:
        import json, hashlib
        return hashlib.sha1(json.dumps(d, sort_keys=True).encode()).hexdigest()[:10]

    def _prepare_bn_for_recal(self, reset_running_stats: bool = False) -> None:
        import torch, torch.nn as nn
        dev = next(self.net.parameters()).device
        for name, m in self.net.named_modules():
            if not isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)):
                continue
            if name.startswith("exit_heads."):
                # keep head BN behavior (per-batch, no running stats)
                continue

            # from here unchanged: ensure buffers exist, optional reset…
            m.track_running_stats = True
            nf = int(m.num_features)
            if getattr(m, "running_mean", None) is None:
                if hasattr(m, "running_mean"): delattr(m, "running_mean")
                m.register_buffer("running_mean", torch.zeros(nf, device=dev))
            else:
                m.running_mean = m.running_mean.to(dev)

            if getattr(m, "running_var", None) is None:
                if hasattr(m, "running_var"): delattr(m, "running_var")
                m.register_buffer("running_var", torch.ones(nf, device=dev))
            else:
                m.running_var = m.running_var.to(dev)

            if hasattr(m, "num_batches_tracked"):
                if m.num_batches_tracked is None:
                    if hasattr(m, "num_batches_tracked"): delattr(m, "num_batches_tracked")
                    m.register_buffer("num_batches_tracked", torch.tensor(0, dtype=torch.long, device=dev))
                else:
                    m.num_batches_tracked = m.num_batches_tracked.to(dev, dtype=torch.long)

            if reset_running_stats:
                m.running_mean.zero_()
                m.running_var.fill_(1.0)
                if hasattr(m, "num_batches_tracked") and m.num_batches_tracked is not None:
                    m.num_batches_tracked.zero_()

    def _apply_mask_and_retarget(self, base_model, snip_mask, target_device):
        """
        Apply structural SNIP mask to a fresh full-width model (CPU → prune → move once),
        then retarget exit-head preconv BNs and head FCs using a REAL validation batch.
        """
        base_cpu = copy.deepcopy(base_model).cpu()

        from src.models.snip_utils import hard_prune_resnet

        # fetch one real batch for patching (prefer val, else train)
        patch_batch = None
        patch_loader = self._build_loader_from_config("val")
        if patch_loader is not None:
            try:
                xb, _ = next(iter(patch_loader))
                patch_batch = xb[:256]  # keep it light
            except Exception:
                patch_batch = None

        # --- NEW: GRU path ---
        if hasattr(base_model, "gru"):
            from src.models.gru_prune_utils import hard_prune_gru_uniform
            units = snip_mask.get("units", None)
            if units is None:
                # no mask → nothing to prune
                return copy.deepcopy(base_model).to(target_device)
            pruned = hard_prune_gru_uniform(copy.deepcopy(base_model).cpu(), units).to(target_device)
            return pruned

        pruned_cpu = hard_prune_resnet(
            base_cpu, snip_mask, patch_sample=patch_batch, patch_loader=patch_loader, patch_max_per_batch=64
        ).cpu()
        pruned = pruned_cpu.to(target_device, non_blocking=False)

        retarget_exit_preconv_bns(pruned, verbose=True)
        # self._need_bn_recal = False
        dump_trunk_taps(pruned, prefix=f"[client {self.cid}] ")
        dump_ee_invariants(pruned, prefix=f"[client {self.cid}] ")
        assert_ee_conv_bn_match(pruned, strict=False, prefix=f"[client {self.cid}] ")
        return pruned
    
    # src/apps/clients/snowfl_classification_client.py

    def _maybe_bn_recal(self, cfg):
        need_flag = bool(cfg.get("need_bn_recal", False)) or bool(getattr(self, "_need_bn_recal", False))
        print("need_flag", need_flag)
        if not need_flag:
            return
        try:
            # Prefer attached loaders; do NOT build server VAL here
            loader, tag = self._pick_calibration_loader()
        except Exception as e:
            print(f"[client {self.cid}] BN recalibration skipped (no local calib loader: {e})")
            self._need_bn_recal = False
            return

        device = next(self.net.parameters()).device
        bn_batches = max(1, int(cfg.get("bn_calib_batches", 200)))
        try:
            n_avail = len(loader)
            if isinstance(n_avail, int) and n_avail > 0:
                bn_batches = min(bn_batches, n_avail)
        except Exception:
            pass

        first_after_surgery = bool(getattr(self, "_need_bn_recal", False))
        self._prepare_bn_for_recal(reset_running_stats=first_after_surgery)

        
        recalibrate_bn(
            self.net, loader, device=device,
            num_batches=bn_batches, max_per_batch=500,
            reset_running_stats=first_after_surgery,
            log_prefix=f"[client {self.cid}] "
        )
        print(f"[client {self.cid}] BN recalibration done ({tag}, batches={bn_batches}).")
        self._need_bn_recal = False
    
    def fit(self, parameters, config):
        """
        Run one training round, honouring the server's key order and optional SNIP mask.

        Steps:
          1. Decode ``config["keys_prog"]`` (list/tuple, or a string holding a JSON /
             Python-literal list). Anything that does not decode to a list falls back
             to ``self.local_sd_keys``.
          2. If ``config["snip_mask"]`` carries a non-empty dict whose signature differs
             from the last one seen, rebuild ``self.net`` from the full-width template,
             refresh the key lists / trainability, and set ``_need_bn_recal``.
          3. Load ``parameters`` with the decoded order and delegate training to
             ``ReeFLClassificationClient.fit`` with ``depth_training`` forced to True
             for the duration of the call.
          4. When the delegate returns a 3-tuple ending in a dict, replace its first
             element with the arrays for ``self._payload_keys`` and add
             ``width_scale`` / ``key_sig`` / ``keyset_kind`` hints to the dict.

        Args:
            parameters: weights sent by the server, passed to ``set_parameters``.
            config: round configuration mapping (may be None/empty).

        Returns:
            Whatever ``ReeFLClassificationClient.fit`` returned, with the payload
            rebuilt as described in step 4 when the shape allows it.
        """
        import ast
        import json
        from hashlib import sha1

        cfg: Dict[str, Any] = dict(config or {})

        # --- server-provided training order (params only) ---
        raw = cfg.get("keys_prog", None)
        order: Optional[Iterable[str]] = None
        if isinstance(raw, (list, tuple)):
            order = list(raw)
        elif isinstance(raw, str):
            # Try JSON first, then a Python literal; accept only list-like results so a
            # decoded scalar / dict / string is never mistaken for a key order.
            for decode in (json.loads, ast.literal_eval):
                try:
                    parsed = decode(raw)
                except Exception:
                    continue
                if isinstance(parsed, (list, tuple)):
                    order = list(parsed)
                    break

        order_is_local = not order
        if order_is_local:
            order = list(self.local_sd_keys)
            print(f"[client {self.cid}] SNOWFL.fit(): keys_prog missing → TRAINABLE payload ({len(order)} keys)")

        # what we RETURN (trainable or full)
        send_full = bool(cfg.get("send_full_local_state", False))
        self._payload_keys = (self.local_full_keys if send_full else order)

        # --- SNIP support ---
        snip_mask = cfg.get("snip_mask", None)
        if isinstance(snip_mask, str):
            try:
                snip_mask = json.loads(snip_mask)
            except Exception as e:
                print(f"[client {self.cid}] Ignoring undecodable snip_mask ({e}).")
                snip_mask = None

        have_mask = isinstance(snip_mask, dict) and len(snip_mask) > 0
        if have_mask:
            sig = self._mask_sig(snip_mask)
            if sig != getattr(self, "_last_mask_sig", None):
                print(f"[client {self.cid}] Received new SNIP mask (sig={sig}). Rebuilding from full-width template.")
                dev = next(self.net.parameters()).device
                base = copy.deepcopy(self._fullwidth_template).cpu()
                self.net = self._apply_mask_and_retarget(base, snip_mask, dev)
                self._last_mask_sig = sig
                self._refresh_keylists_after_surgery()
                self._apply_trainability_after_surgery()
                self._need_bn_recal = True

                # The key lists may have been rebuilt by the refresh above; re-derive
                # anything that was captured from them before the surgery.
                if order_is_local:
                    order = list(self.local_sd_keys)
                self._payload_keys = (self.local_full_keys if send_full else order)
            else:
                print(f"[client {self.cid}] SNIP mask unchanged (sig={sig}); skipping re-prune).")

        # --- load server weights for TRAINING using the decoded order ---
        self.set_parameters(parameters, keys_override=order)

        # --- progressive training switch for this round ---
        prev_depth = getattr(self, "depth_training", False)
        self.depth_training = True
        try:
            self.net.train()
            # Call the base ReeFL fit (not ScaleFL.fit)
            result = ReeFLClassificationClient.fit(self, parameters, cfg)
        finally:
            # Always restore the previous mode, even if training raised.
            self.depth_training = prev_depth

        # --- rebuild payload to exactly match self._payload_keys ---
        if isinstance(result, tuple) and len(result) == 3 and isinstance(result[2], dict):
            sd = self.net.state_dict()
            ks = self._payload_keys
            missing = [k for k in ks if k not in sd]
            if missing:
                # Surface the mismatch: the payload is shorter than the advertised key set.
                print(f"[client {self.cid}] SNOWFL.fit(): {len(missing)} payload key(s) not in state_dict; omitted from payload.")
            payload = [sd[k].detach().cpu().numpy() for k in ks if k in sd]
            key_sig = sha1("|".join(ks).encode()).hexdigest()[:10]
            key_kind = "full" if set(ks) == set(self.local_full_keys) else "train"
            width = float(getattr(self, "width_scale", 1.0))
            # setdefault: hints already set by the delegate are left untouched.
            result[2].setdefault("width_scale", width)
            result[2].setdefault("key_sig", key_sig)
            result[2].setdefault("keyset_kind", key_kind)
            result = (payload, result[1], result[2])

        return result
    
    
    def _fit(self, parameters, config):
        """
        Variant of the training round that recalibrates BN inline after SNIP surgery.

        Differences from ``fit`` visible in this method: BatchNorm statistics are
        recalibrated right after the server weights are loaded whenever
        ``_need_bn_recal`` is set (always from reset running stats), and the payload
        returned by ``ReeFLClassificationClient.fit`` is not rebuilt, only annotated.

        Args:
            parameters: weights sent by the server, passed to ``set_parameters``.
            config: round configuration mapping (may be None/empty). Keys read here:
                ``keys_prog``, ``send_full_local_state``, ``snip_mask``,
                ``bn_calib_source``, ``bn_calib_batches``.

        Returns:
            The result of ``ReeFLClassificationClient.fit``, with ``width_scale`` /
            ``key_sig`` / ``keyset_kind`` hints added to its metrics dict when one
            can be located.
        """
        import ast
        import json

        cfg: Dict[str, Any] = dict(config or {})

        # server-provided training order (params only)
        raw_order = cfg.get("keys_prog", None)
        order: Optional[Iterable[str]] = None
        if isinstance(raw_order, str):
            # Decode an encoded list; list("...") would split the string into characters.
            for decode in (json.loads, ast.literal_eval):
                try:
                    parsed = decode(raw_order)
                except Exception:
                    continue
                if isinstance(parsed, (list, tuple)):
                    order = list(parsed)
                    break
        elif raw_order is not None:
            try:
                order = list(raw_order)
            except Exception:
                order = None

        # what we RETURN (trainable or full)
        send_full = bool(cfg.get("send_full_local_state", False))
        self._payload_keys = (self.local_full_keys if send_full else (order or self.local_sd_keys))

        snip_mask = cfg.get("snip_mask", None)
        if isinstance(snip_mask, str):
            try:
                snip_mask = json.loads(snip_mask)
            except Exception:
                snip_mask = None

        have_mask = isinstance(snip_mask, dict) and len(snip_mask) > 0
        if have_mask:
            sig = self._mask_sig(snip_mask)
            if sig != getattr(self, "_last_mask_sig", None):
                print(f"[client {self.cid}] Received new SNIP mask (sig={sig}). Rebuilding from full-width template.")
                dev = next(self.net.parameters()).device
                base = copy.deepcopy(self._fullwidth_template).cpu()
                self.net = self._apply_mask_and_retarget(base, snip_mask, dev)
                self._last_mask_sig = sig

                # refresh payload/key lists and optimizer targets after structural surgery
                self._refresh_keylists_after_surgery()
                self._apply_trainability_after_surgery()
                # Re-derive the payload keys from the refreshed lists (they were
                # captured above, before the surgery).
                self._payload_keys = (self.local_full_keys if send_full else (order or self.local_sd_keys))

                # request one-time BN recal after structural change
                self._need_bn_recal = True
            else:
                print(f"[client {self.cid}] SNIP mask unchanged (sig={sig}); skipping re-prune).")

        # ---- load server weights for TRAINING using the order ----
        self.set_parameters(parameters, keys_override=order or self.local_sd_keys)

        # ---- one-time BN recalibration after surgery ----
        # getattr: the flag does not exist until a mask has been applied at least once.
        pending_recal = getattr(self, "_need_bn_recal", False)
        print("self._need_bn_recal", pending_recal)
        if pending_recal:
            try:
                src = str(cfg.get("bn_calib_source", "val")).lower()
                # "train" prefers the train loader; every other value prefers val.
                first, second = ("train", "val") if src == "train" else ("val", "train")
                loader = self._build_loader_from_config(first) or self._build_loader_from_config(second)

                if loader is not None:
                    device = next(self.net.parameters()).device   # keep as torch.device
                    bn_batches = max(1, int(cfg.get("bn_calib_batches", 400)))
                    try:
                        n_avail = len(loader)
                        if isinstance(n_avail, int) and n_avail > 0:
                            bn_batches = min(bn_batches, n_avail)
                    except Exception:
                        pass  # loaders without a length keep the configured budget

                    # This branch only runs after surgery, so stats always start fresh.
                    self._prepare_bn_for_recal(reset_running_stats=True)
                    recalibrate_bn(
                        self.net,
                        loader,
                        device=device,                 # <-- pass torch.device, not str
                        num_batches=bn_batches,
                        max_per_batch=500,
                        reset_running_stats=True,
                        log_prefix=f"[client {self.cid}] "
                    )
                    print(f"[client {self.cid}] BN recalibration done (src={src}, batches={bn_batches}).")
                else:
                    print(f"[client {self.cid}] BN recalibration skipped (no calibration loader).")
            except Exception as e:
                print(f"[client {self.cid}] BN recalibration skipped ({e}).")
            finally:
                # Cleared whether or not recalibration succeeded, so it is tried once.
                self._need_bn_recal = False

        # ---- progressive training switch for this round ----
        prev_depth = getattr(self, "depth_training", False)
        self.depth_training = True
        try:
            self.net.train()
            # Call the *grandparent* fit (ReeFL) directly, not ScaleFL.fit
            result = ReeFLClassificationClient.fit(self, parameters, cfg)
        finally:
            self.depth_training = prev_depth

        # ---- annotate payload hints (useful for server-side debugging) ----
        try:
            from hashlib import sha1
            ks = getattr(self, "_payload_keys", self.local_sd_keys)
            key_sig = sha1("|".join(ks).encode()).hexdigest()[:10]
            key_kind = "full" if set(ks) == set(self.local_full_keys) else "train"
            width = float(getattr(self, "width_scale", 1.0))

            # Locate the metrics dict: either the 3rd tuple element or a `.metrics` attribute.
            metrics = None
            if isinstance(result, tuple) and len(result) == 3 and isinstance(result[2], dict):
                metrics = result[2]
            elif isinstance(getattr(result, "metrics", None), dict):
                metrics = result.metrics

            if metrics is not None:
                metrics.setdefault("width_scale", width)
                metrics.setdefault("key_sig", key_sig)
                metrics.setdefault("keyset_kind", key_kind)
        except Exception as e:
            # Hints are diagnostics only; never fail the round because of them.
            print(f"[client {self.cid}] could not annotate fit result ({e}).")

        return result
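    # ------------------------------------
    # Non-SNIP variant of a single training round.
    # NOTE: the double-underscore name `__fit` is name-mangled by Python; `fit` is the public entrypoint.
    # ------------------------------------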
    def __fit(self, parameters, config):
        """
        Non-SNIP variant of a training round: no structural rebuilds are performed.

        NOTE: because of the double leading underscore, Python name-mangles this
        method to ``_SNOWFLClassificationClient__fit``; from outside the class it is
        only reachable under that name.

        Steps: decode ``keys_prog``, load ``parameters`` in that order, optionally
        recalibrate BatchNorm statistics (without resetting them) when
        ``config["need_bn_recal"]`` or ``self._need_bn_recal`` is set, then delegate
        to ``super().fit`` with ``depth_training`` forced to True for the call.

        Args:
            parameters: weights sent by the server, passed to ``set_parameters``.
            config: round configuration mapping (may be None/empty). A ``snip_mask``
                entry, if present, is ignored.

        Returns:
            The result of ``super().fit``, with ``width_scale`` / ``key_sig`` /
            ``keyset_kind`` hints added to its metrics dict when one can be located.
        """
        import ast
        import json

        cfg: Dict[str, Any] = dict(config or {})

        # ---- server-provided order for TRAINING (params only) ----
        # Compatible with "previous" behavior: if keys_prog is present, use it.
        raw_order = cfg.get("keys_prog", None)
        order: Optional[Iterable[str]] = None
        if isinstance(raw_order, str):
            # Decode an encoded list; list("...") would split the string into characters.
            for decode in (json.loads, ast.literal_eval):
                try:
                    parsed = decode(raw_order)
                except Exception:
                    continue
                if isinstance(parsed, (list, tuple)):
                    order = list(parsed)
                    break
        elif raw_order is not None:
            try:
                order = list(raw_order)
            except Exception:
                order = None

        # what we RETURN this round (trainable or full)
        send_full = bool(cfg.get("send_full_local_state", False))
        self._payload_keys = (self.local_full_keys if send_full else (order or self.local_sd_keys))

        # ---- SNIP-related fields are ignored in this non-SNIP variant ----
        # `snip_mask` is deliberately never read: no structural rebuilds here on purpose.

        # ---- load server weights using TRAINING order ----
        # If `order` is None/empty we fall back to local trainable keys.
        self.set_parameters(parameters, keys_override=order or self.local_sd_keys)

        # ---- optional BN recalibration (no structural change) ----
        need_flag = bool(cfg.get("need_bn_recal", False)) or bool(getattr(self, "_need_bn_recal", False))
        if need_flag:
            try:
                src = str(cfg.get("bn_calib_source", "val")).lower()
                # "train" prefers the train loader; every other value prefers val.
                first, second = ("train", "val") if src == "train" else ("val", "train")
                loader = self._build_loader_from_config(first) or self._build_loader_from_config(second)

                if loader is not None:
                    device = next(self.net.parameters()).device
                    bn_batches = max(1, int(cfg.get("bn_calib_batches", 20)))
                    try:
                        n_avail = len(loader)
                        if isinstance(n_avail, int) and n_avail > 0:
                            bn_batches = min(bn_batches, n_avail)
                    except Exception:
                        pass  # loaders without a length keep the configured budget

                    recalibrate_bn(
                        self.net,
                        loader,
                        device=device,
                        num_batches=bn_batches,
                        max_per_batch=256,
                        # Keep defaults liberal (no resetting) to match “previous” behavior
                        reset_running_stats=False,
                        log_prefix=f"[client {self.cid}] "
                    )
                    print(f"[client {self.cid}] BN recalibration done (src={src}, batches={bn_batches}).")
                else:
                    print(f"[client {self.cid}] BN recalibration skipped (no calibration loader).")
            except Exception as e:
                print(f"[client {self.cid}] BN recalibration skipped ({e}).")
            finally:
                # Clear the flag so we don't recal again unless requested
                self._need_bn_recal = False

        # ---- progressive training switch for this round ----
        prev_depth = getattr(self, "depth_training", False)
        self.depth_training = True
        try:
            self.net.train()
            result = super().fit(parameters, cfg)
        finally:
            # Always restore the previous mode, even if training raised.
            self.depth_training = prev_depth

        # ---- annotate payload hints (useful for debugging) ----
        try:
            from hashlib import sha1
            ks = getattr(self, "_payload_keys", self.local_sd_keys)
            key_sig = sha1("|".join(ks).encode()).hexdigest()[:10]
            key_kind = "full" if set(ks) == set(self.local_full_keys) else "train"
            width = float(getattr(self, "width_scale", 1.0))

            # Locate the metrics dict: either the 3rd tuple element or a `.metrics` attribute.
            metrics = None
            if isinstance(result, tuple) and len(result) == 3 and isinstance(result[2], dict):
                metrics = result[2]
            elif isinstance(getattr(result, "metrics", None), dict):
                metrics = result.metrics

            if metrics is not None:
                metrics.setdefault("width_scale", width)
                metrics.setdefault("key_sig", key_sig)
                metrics.setdefault("keyset_kind", key_kind)
        except Exception as e:
            # Hints are diagnostics only; never fail the round because of them.
            print(f"[client {self.cid}] could not annotate fit result ({e}).")

        return result