import flwr as fl
import torch

# import copy
from typing import Dict, Callable, Optional, Tuple 
from flwr.server.server import Server

from src.apps import ReeFLClassificationApp
from src.models.model_utils import set_partial_weights
from src.apps.clients import test
from src.utils import get_func_from_config
import torch.nn as nn
import logging
from src.apps.clients.client_utils import ree_early_exit_test
logger = logging.getLogger(__name__)

class InclusiveFLClassificationApp(ReeFLClassificationApp):
    """InclusiveFL flavour of the ReeFL classification app.

    Server-side evaluation rebuilds the network truncated at a requested early
    exit (depth, number of exits and width scale are derived from that exit)
    and loads a possibly partial and/or differently shaped weight payload into
    it. ``run`` trains through the parent app with ``reefl=False`` and then
    saves the strategy's per-exit personalized weights.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def get_eval_fn(self):
        """Build the centralized evaluation closure used by the server.

        Returns:
            ``evaluate(weights, partition[, exit_i[, blk_to_exit[, local_keys]]])``,
            which returns ``(loss, metrics)``; see its docstring for details.
        """
        import numpy as np
        from copy import deepcopy

        def _normalize_blks_to_exit(full_blks, blk_to_exit, exit_i):
            """Return the exit-block list for a model truncated at ``blk_to_exit``.

            Keeps the reference exits located at or before ``blk_to_exit`` and
            makes ``blk_to_exit`` the last entry. For ``exit_i >= 0`` the list is
            then padded (by repeating ``blk_to_exit``) or truncated so that it has
            exactly ``exit_i + 1`` entries. An empty ``full_blks`` yields
            ``[blk_to_exit]``.
            """
            if not full_blks:
                return [blk_to_exit]
            trimmed = [b for b in full_blks if b <= blk_to_exit] or [blk_to_exit]
            if trimmed[-1] != blk_to_exit:
                trimmed.append(blk_to_exit)
            need = (exit_i + 1) - len(trimmed)
            if need > 0:
                trimmed.extend([blk_to_exit] * need)
            elif need < 0:
                trimmed = trimmed[: exit_i] + [blk_to_exit]
            return trimmed

        def _pick_width(exit_i):
            """Width scale for ``exit_i`` from ``app.args.width_scaling``, else 1.0."""
            try:
                return float(self.ckp.config.app.args.width_scaling[exit_i])
            except Exception:
                # Deliberate best-effort: a missing, short or malformed config
                # simply means "evaluate at full width".
                return 1.0

        def _select_device():
            """Resolve the eval device.

            Order of preference: ``app.eval_fn.device``, then the current CUDA
            device, then CPU.
            """
            eval_cfg = getattr(getattr(self.ckp.config, "app", {}), "eval_fn", None)
            cfg_dev = getattr(eval_cfg, "device", None) if eval_cfg is not None else None
            if cfg_dev is not None:
                device = str(cfg_dev)
            elif torch.cuda.is_available():
                device = f"cuda:{torch.cuda.current_device()}"
            else:
                device = "cpu"
            if device.startswith("cuda") and torch.cuda.is_available():
                torch.backends.cudnn.benchmark = True
            return device

        def _select_payload_keys(weights, all_keys, single_head_keys, local_keys, exit_i):
            """Pick the state-dict keys that ``weights`` is aligned with, positionally.

            Explicit ``local_keys`` win. Otherwise the payload length must match
            either all trainable keys or the backbone + ``exit_i``-head subset.

            Raises:
                AssertionError: if the payload length fits none of the candidates.
            """
            if local_keys is not None:
                if len(weights) != len(local_keys):
                    raise AssertionError(
                        f"[InclusiveFL/eval] payload len={len(weights)} != local_keys len={len(local_keys)}"
                    )
                return local_keys
            if len(weights) == len(all_keys):
                return all_keys
            if len(weights) == len(single_head_keys):
                return single_head_keys
            raise AssertionError(
                f"[InclusiveFL/eval] payload len={len(weights)} not in "
                f"{{all={len(all_keys)}, head{exit_i}-only={len(single_head_keys)}}}"
            )

        def _to_tensor(x):
            """Convert a payload entry (tensor or array-like) to a detached tensor."""
            if torch.is_tensor(x):
                return x.detach()  # device is fixed up in _fit_to_target
            return torch.from_numpy(np.asarray(x))

        def _fit_to_target(key, src, tgt):
            """Cast ``src`` to ``tgt``'s device/dtype and crop/zero-pad it to ``tgt``'s shape.

            Raises:
                ValueError: if the ranks differ. Slicing mismatched ranks would
                    otherwise broadcast silently (e.g. a 1-element payload
                    smeared across a whole matrix).
            """
            src = src.to(device=tgt.device, dtype=tgt.dtype, non_blocking=True)
            if src.shape == tgt.shape:
                return src
            if src.dim() != tgt.dim():
                raise ValueError(
                    f"[InclusiveFL/eval] rank mismatch for '{key}': "
                    f"payload {tuple(src.shape)} vs model {tuple(tgt.shape)}"
                )
            # Copy the overlapping top-left region into zeros shaped like the target.
            overlap = tuple(slice(0, min(s, t)) for s, t in zip(src.shape, tgt.shape))
            out = torch.zeros_like(tgt)
            out[overlap] = src[overlap]
            return out

        def _load_weights(model, keys, weights):
            """Load ``weights`` (aligned with ``keys``) into ``model``, shape-fitting each entry.

            Keys absent from the model's state dict are skipped.
            """
            sd = model.state_dict()
            skipped = 0
            for k, w in zip(keys, weights):
                if k not in sd:
                    skipped += 1
                    continue
                sd[k] = _fit_to_target(k, _to_tensor(w), sd[k])
            if skipped:
                logger.debug("[InclusiveFL/eval] skipped %d payload keys absent from the model", skipped)
            model.load_state_dict(sd, strict=False)

        def _pick_result_key(results, exit_i):
            """Choose the result entry for ``exit_i``.

            Uses ``exit_i`` itself if present, else the largest integer key not
            above it, else the largest integer key.

            Raises:
                RuntimeError: if ``results`` has no integer exit keys to fall back on.
            """
            if exit_i in results:
                return exit_i
            int_keys = sorted(k for k in results.keys() if isinstance(k, int))
            if not int_keys:
                raise RuntimeError(
                    f"[InclusiveFL/eval] no integer exit keys in results: {list(results.keys())}"
                )
            not_greater = [k for k in int_keys if k <= exit_i]
            return not_greater[-1] if not_greater else int_keys[-1]

        def _release_cuda_memory(device):
            """Best-effort: wait for queued CUDA work and return cached blocks to the driver."""
            if not str(device).startswith("cuda"):
                return
            try:
                torch.cuda.synchronize()
                torch.cuda.empty_cache()
            except Exception:
                # Cleanup must never fail an evaluation, but leave a trace.
                logger.debug("[InclusiveFL/eval] CUDA cleanup failed", exc_info=True)

        def evaluate(*args):
            """Evaluate ``weights`` on the server ``partition`` at one early exit.

            Positional args: ``weights``, ``partition`` and optionally ``exit_i``
            (ignored unless an int), ``blk_to_exit`` (ignored if None) and
            ``local_keys`` (ignored unless a list/tuple). A missing ``exit_i``
            defaults to the last reference exit. A missing ``blk_to_exit``
            defaults to that exit's block.

            Returns:
                ``(loss, metrics)``, where ``metrics`` holds
                ``centralized_{partition}_exit{exit_i}_loss`` and ``..._acc``
                (accuracy in percent).

            Raises:
                TypeError: fewer than two positional arguments.
                IndexError: ``exit_i`` outside the reference model's exits.
                AssertionError: payload length does not match any key set.
                ValueError: a payload tensor's rank differs from the model's.
            """
            if len(args) < 2:
                raise TypeError("evaluate expects at least (weights, partition)")

            weights, partition = args[0], args[1]
            exit_i = None
            blk_to_exit = None
            local_keys = None
            if len(args) >= 3 and isinstance(args[2], (int, np.integer)):
                exit_i = int(args[2])
            if len(args) >= 4 and args[3] is not None:
                blk_to_exit = int(args[3])
            if len(args) >= 5 and isinstance(args[4], (list, tuple)):
                local_keys = list(args[4])

            device = _select_device()

            # ----- build net args for the desired exit/depth/width -----
            net_cfg = self.ckp.config.models.net
            arch_fn = get_func_from_config(net_cfg)
            net_args = deepcopy(net_cfg.args)

            # Probe a full reference model on CPU (not GPU) purely to read its
            # exit layout, then drop it.
            ref_net = arch_fn(device="cpu", **net_cfg.args)
            full_blks = list(getattr(ref_net, "blks_to_exit", []))
            if exit_i is None:
                exit_i = len(full_blks) - 1 if full_blks else 0
            if blk_to_exit is None:
                if not full_blks:
                    blk_to_exit = getattr(ref_net, "total_blocks", 1) - 1
                elif 0 <= exit_i < len(full_blks):
                    blk_to_exit = int(full_blks[exit_i])
                else:
                    raise IndexError(
                        f"[InclusiveFL/eval] exit_i={exit_i} out of range for "
                        f"{len(full_blks)} exits"
                    )
            del ref_net

            net_args["depth"] = blk_to_exit + 1
            net_args["no_of_exits"] = exit_i + 1
            net_args["width_scale"] = _pick_width(exit_i)
            chosen = _normalize_blks_to_exit(full_blks, blk_to_exit, exit_i)
            if "blks_to_exit" in net_args:
                net_args["blks_to_exit"] = chosen
            elif "ee_layer_locations" in net_args:
                net_args["ee_layer_locations"] = chosen

            model = arch_fn(device=device, **net_args).eval()
            try:
                # ----- key selection (all progressive vs single-head) -----
                all_keys = list(model.trainable_state_dict_keys)
                head_pref = f"exit_heads.{exit_i}."
                single_head_keys = [
                    k for k in all_keys
                    if not k.startswith("exit_heads.") or k.startswith(head_pref)
                ]
                progressive_keys = _select_payload_keys(
                    weights, all_keys, single_head_keys, local_keys, exit_i
                )
                _load_weights(model, progressive_keys, weights)

                # ----- dataloader (pin_memory is handled by the dataset wrappers) -----
                data_cfg = self.ckp.config.data
                data_cls = get_func_from_config(data_cfg)
                dataset = data_cls(self.ckp, **data_cfg.args)
                testloader = dataset.get_dataloader(
                    data_pool="server",
                    partition=partition,
                    batch_size=int(self.ckp.config.app.eval_fn.batch_size),
                    augment=False,
                    num_workers=(4 if str(device).startswith("cuda") else 0),
                )

                # ----- inference -----
                with torch.inference_mode():
                    results = ree_early_exit_test(
                        model,
                        max_early_exit_layer=exit_i,
                        testloader=testloader,
                        device=str(device),
                        ensemble=False,
                    )
            finally:
                # Drop our reference first so empty_cache() can actually
                # release the model's GPU memory, even if evaluation failed.
                del model
                _release_cuda_memory(device)

            use_key = _pick_result_key(results, exit_i)
            out = results[use_key]
            loss = float(out["loss"])
            acc = float(out["accuracy"]) * 100.0
            metrics = {
                f"centralized_{partition}_exit{exit_i}_loss": loss,
                f"centralized_{partition}_exit{exit_i}_acc": acc,
            }

            print(
                f"[EVAL][{partition}] exit={exit_i} (used={use_key}) depth={blk_to_exit+1} "
                f"width_scale={net_args['width_scale']:.3f} payload_len={len(weights)} "
                f"prog_keys={len(progressive_keys)} dev={device} -> loss={loss:.4f} acc={acc:.2f}"
            )

            return loss, metrics

        return evaluate

    def run(self, server: Server):
        """Run federated training through the parent app, then save personalized weights.

        Training is delegated to ``ReeFLClassificationApp.run`` with
        ``reefl=False``. Afterwards each entry of
        ``server.strategy.exit_personalized_sd_values`` is saved to
        ``models/weights_personalized_exit{exit_i}.pkl`` via the checkpointer.
        """
        super().run(server, reefl=False)

        for exit_i, personalized_sd in server.strategy.exit_personalized_sd_values.items():
            self.ckp.save(f'models/weights_personalized_exit{exit_i}.pkl', personalized_sd)
