# src/apps/scalefl_classification_app.py

import torch
import numpy as np
from copy import deepcopy
from src.apps import ReeFLClassificationApp
from src.utils import get_func_from_config
from src.apps.clients import ree_early_exit_test

class ScaleFLClassificationApp(ReeFLClassificationApp):
    """ReeFL classification app with a ScaleFL-aware centralized evaluator.

    Only ``get_eval_fn`` is specialised here; everything else is inherited
    from ``ReeFLClassificationApp``.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    # ------------------------------------------------------------------
    # Private helpers for the centralized evaluator
    # ------------------------------------------------------------------
    def _resolve_eval_device(self):
        """Return the device string to evaluate on.

        Uses ``config.app.eval_fn.device`` when it is set; otherwise the current
        CUDA device if CUDA is available, else ``"cpu"``. A configured CUDA
        device falls back to ``"cpu"`` (with a warning) when CUDA is unavailable,
        instead of crashing later when the model is built.
        """
        eval_cfg = getattr(getattr(self.ckp.config, "app", {}), "eval_fn", None)
        cfg_dev = getattr(eval_cfg, "device", None) if eval_cfg is not None else None
        if cfg_dev is None:
            return f"cuda:{torch.cuda.current_device()}" if torch.cuda.is_available() else "cpu"
        device = str(cfg_dev)
        if device.startswith("cuda") and not torch.cuda.is_available():
            print(f"[ScaleFL/eval] configured device '{device}' but CUDA is unavailable; using cpu")
            return "cpu"
        return device

    @staticmethod
    def _to_tensor(x):
        """Convert one payload entry (tensor or array-like) to a detached tensor.

        The tensor keeps its current device; ``_fit_shape`` moves it.
        """
        if torch.is_tensor(x):
            return x.detach()
        return torch.from_numpy(np.asarray(x))

    @staticmethod
    def _fit_shape(src, tgt):
        """Cast ``src`` to ``tgt``'s device/dtype and slice/zero-pad it to ``tgt``'s shape.

        Along every dimension the overlapping leading region is copied; entries
        of ``tgt``'s shape not covered by ``src`` are zero.

        Raises:
            RuntimeError: if ``src`` and ``tgt`` have different ranks (otherwise
                the slice assignment could silently broadcast wrong values).
        """
        # Only copies INTO CUDA memory are made non-blocking: a non-blocking
        # CUDA->CPU copy may still be in flight when the CPU reads the result.
        src = src.to(device=tgt.device, dtype=tgt.dtype, non_blocking=tgt.is_cuda)
        if src.shape == tgt.shape:
            return src
        if src.dim() != tgt.dim():
            raise RuntimeError(
                f"rank mismatch: payload shape {tuple(src.shape)} vs model shape {tuple(tgt.shape)}"
            )
        region = tuple(slice(0, min(s, t)) for s, t in zip(src.shape, tgt.shape))
        out = torch.zeros_like(tgt)
        out[region] = src[region]
        return out

    @staticmethod
    def _select_payload_keys(model, weights, exit_i, local_keys, depth, width_scale):
        """Decide which state-dict keys the flat ``weights`` list is aligned with.

        Without ``local_keys``, the payload is matched by length against the
        "progressive" key lists: all non-``exit_heads.*`` keys plus the keys of
        ``exit_heads.{exit_i}``, taken either from ``model.trainable_state_dict_keys``
        ("train") or from the full ``model.state_dict()`` ("full").

        Returns:
            (keys, payload_kind) where payload_kind is "train", "full" or
            "provided(N)".

        Raises:
            AssertionError: if the payload length matches neither key list.
        """
        if local_keys is not None:
            keys = list(local_keys)
            if len(keys) != len(weights):
                # zip() below would silently drop the surplus entries.
                print(
                    f"[ScaleFL/eval] warning: {len(keys)} local_keys vs {len(weights)} weights; "
                    f"only the first {min(len(keys), len(weights))} pairs are loaded"
                )
            return keys, f"provided({len(keys)})"

        head_pref = f"exit_heads.{exit_i}."

        def _progressive(keys):
            # Shared trunk keys plus only this exit's head.
            return [k for k in keys if (not k.startswith("exit_heads.")) or k.startswith(head_pref)]

        train_keys = _progressive(model.trainable_state_dict_keys)
        full_keys = _progressive(model.state_dict().keys())

        if len(weights) == len(train_keys):
            return train_keys, "train"
        if len(weights) == len(full_keys):
            return full_keys, "full"
        raise AssertionError(
            f"[ScaleFL/eval] len(weights)={len(weights)} doesn't match "
            f"train={len(train_keys)} or full={len(full_keys)} "
            f"(exit_i={exit_i}, depth={depth}, width_scale={width_scale})"
        )

    @staticmethod
    def _pick_exit_result(results, exit_i):
        """Return ``(key, result)`` for ``exit_i``.

        Falls back to the largest integer key <= ``exit_i``, else the largest
        integer key overall.

        Raises:
            KeyError: if ``results`` holds neither ``exit_i`` nor any int key.
        """
        if exit_i in results:
            return exit_i, results[exit_i]
        int_keys = sorted(k for k in results if isinstance(k, int))
        if not int_keys:
            raise KeyError(
                f"[ScaleFL/eval] no per-exit results for exit_i={exit_i}; keys={list(results.keys())}"
            )
        not_greater = [k for k in int_keys if k <= exit_i]
        use_key = not_greater[-1] if not_greater else int_keys[-1]
        return use_key, results[use_key]

    # ------------------------------------------------------------------
    # Public API
    # ------------------------------------------------------------------
    def get_eval_fn(self):
        """
        Centralized evaluator for ScaleFL:
        - Rebuilds a neutral submodel for the requested exit/depth/width.
        - Accepts either TRAINABLE (progressive) payloads OR FULL per-exit state
          (all state_dict entries, i.e. params + buffers).
        - Safe slice/zero-pad on same-rank shape mismatches (width scaling).
        - Runs ree_early_exit_test on GPU if available (or CPU), reports per-exit metrics.
        """
        def evaluate(*args):
            """Evaluate one ScaleFL exit on the server-side data pool.

            Positional args:
                weights: sequence of tensors/array-likes, ordered like the
                    selected state-dict key list.
                partition: partition passed to the server dataloader; also part
                    of the metric names.
                exit_i (int, optional): exit index; defaults to the probe
                    model's last exit (0 if none can be determined).
                blk_to_exit (int) or blks_to_exit (list/tuple), optional: block
                    index(es) hosting the exit(s); negative means "last block".
                local_keys (list/tuple, optional): explicit keys for ``weights``.
                width_scale (number, optional): overrides
                    ``config.app.args.width_scaling[exit_i]`` (default 1.0).

            Returns:
                (loss, metrics) with metrics keys
                ``centralized_{partition}_exit{exit_i}_loss`` / ``_acc`` (acc in %).

            Raises:
                TypeError: if fewer than two positional args are given.
                AssertionError: if the payload length matches no key set.
            """
            if len(args) < 2:
                raise TypeError("evaluate expects at least (weights, partition)")

            # ---------- Parse args ----------
            weights, partition = args[0], args[1]
            exit_i = None
            blk_to_exit = None
            blks_to_exit = None
            local_keys = None
            width_scale_override = None

            if len(args) >= 3 and isinstance(args[2], (int, np.integer)):
                exit_i = int(args[2])
            if len(args) >= 4:
                a3 = args[3]
                if isinstance(a3, (list, tuple)):
                    blks_to_exit = list(a3)
                elif a3 is not None:
                    blk_to_exit = int(a3)
            if len(args) >= 5 and isinstance(args[4], (list, tuple)):
                local_keys = list(args[4])
            if len(args) >= 6 and isinstance(args[5], (float, int, np.floating, np.integer)):
                width_scale_override = float(args[5])

            # ---------- Device ----------
            device = self._resolve_eval_device()
            if device.startswith("cuda") and torch.cuda.is_available():
                torch.backends.cudnn.benchmark = True  # faster convs on fixed shapes

            # ---------- Structural defaults from a CPU probe model ----------
            net_cfg = self.ckp.config.models.net
            arch_fn = get_func_from_config(net_cfg)
            net_args = deepcopy(net_cfg.args)

            probe = arch_fn(device="cpu", **deepcopy(net_cfg.args))
            # Best-effort probing: the model may not expose these attributes.
            try:
                full_depth = sum(len(s) for s in getattr(probe, "layers", []))
                last_block = full_depth - 1 if full_depth > 0 else None
            except Exception:
                last_block = None
            if last_block is None:
                try:
                    vals = [int(x) for x in getattr(probe, "blks_to_exit", []) if int(x) >= 0]
                    last_block = max(vals) if vals else 0
                except Exception:
                    last_block = 0

            probe_exits = getattr(probe, "blks_to_exit", None)
            del probe  # the full CPU model is only needed for the defaults above

            if exit_i is None:
                try:
                    # Clamp so an empty exit list does not yield exit_i == -1.
                    exit_i = max(len(probe_exits) - 1, 0)
                except Exception:
                    exit_i = 0
            if blk_to_exit is None:
                try:
                    blk_to_exit = int(probe_exits[exit_i])
                except Exception:
                    blk_to_exit = last_block
            if int(blk_to_exit) < 0:
                blk_to_exit = int(last_block)

            # Width scale (override > config > 1.0)
            if width_scale_override is not None:
                width_scale = width_scale_override
            else:
                try:
                    width_scale = float(self.ckp.config.app.args.width_scaling[exit_i])
                except Exception:
                    width_scale = 1.0

            # ---------- Depth / exits / locations for the submodel ----------
            def _normalise_exits(seq):
                # Keep exits 0..exit_i; negative entries mean "last block".
                return [int(last_block) if int(x) < 0 else int(x) for x in list(seq)[: exit_i + 1]]

            if blks_to_exit is not None:
                bte = _normalise_exits(blks_to_exit)
                net_args["blks_to_exit"] = bte
                if bte:  # guard: an empty list must not crash on bte[-1]
                    blk_to_exit = bte[-1]
            elif net_args.get("blks_to_exit") is not None:
                bte = _normalise_exits(net_args["blks_to_exit"])
                net_args["blks_to_exit"] = bte
                if bte:
                    blk_to_exit = bte[-1]

            if isinstance(net_args.get("ee_layer_locations"), (list, tuple)):
                net_args["ee_layer_locations"] = list(net_args["ee_layer_locations"])[: exit_i + 1]

            depth = int(blk_to_exit) + 1
            net_args["depth"] = depth
            net_args["no_of_exits"] = exit_i + 1
            net_args["width_scale"] = float(width_scale)
            net_args["last_exit_only"] = False

            model = sd = testloader = None
            try:
                # ---------- Real submodel on the eval device ----------
                model = arch_fn(device=device, **net_args).eval()

                keys_in, payload_kind = self._select_payload_keys(
                    model, weights, exit_i, local_keys, depth, width_scale
                )

                # ---------- Safe load (slice/pad) ----------
                sd = model.state_dict()  # tensors live on the model's device
                for k, w in zip(keys_in, weights):
                    if k not in sd:
                        continue
                    try:
                        sd[k] = self._fit_shape(self._to_tensor(w), sd[k])
                    except RuntimeError as err:
                        raise RuntimeError(f"[ScaleFL/eval] cannot load '{k}': {err}") from err
                model.load_state_dict(sd, strict=False)

                # ---------- Dataloader ----------
                # NOTE: pin_memory is not passed; unverified whether
                # get_dataloader forwards extra kwargs to torch's DataLoader.
                data_cfg = self.ckp.config.data
                data_cls = get_func_from_config(data_cfg)
                dataset = data_cls(self.ckp, **data_cfg.args)
                testloader = dataset.get_dataloader(
                    data_pool="server",
                    partition=partition,
                    batch_size=int(self.ckp.config.app.eval_fn.batch_size),
                    augment=False,
                    num_workers=(4 if device.startswith("cuda") else 0),
                )

                # ---------- Pure inference ----------
                with torch.inference_mode():
                    results = ree_early_exit_test(
                        model,
                        max_early_exit_layer=exit_i,
                        testloader=testloader,
                        device=device,  # batches go to the same device as the model
                        ensemble=False,
                    )

                use_key, out = self._pick_exit_result(results, exit_i)
                loss = float(out["loss"])
                acc = float(out["accuracy"]) * 100.0

                metrics = {
                    f"centralized_{partition}_exit{exit_i}_loss": loss,
                    f"centralized_{partition}_exit{exit_i}_acc": acc,
                }

                print(
                    f"[EVAL][{partition}] exit={exit_i} (used={use_key}) "
                    f"depth={depth} width_scale={width_scale:.3f} "
                    f"payload_len={len(weights)} kind={payload_kind} dev={device} -> "
                    f"loss={loss:.4f} acc={acc:.2f}"
                )
                return loss, metrics
            finally:
                # Drop references so GPU memory can be reclaimed even on failure.
                model = sd = testloader = None
                if device.startswith("cuda"):
                    try:
                        torch.cuda.synchronize(torch.device(device))
                        torch.cuda.empty_cache()
                    except RuntimeError:
                        pass  # cleanup is best-effort

        return evaluate
    