"""FedSSM Server - integrates SSM, selector, and aggregator."""

import torch
import numpy as np
from typing import Dict, List, Tuple, Optional

from ..models import HierarchicalSSM
from .selector import ClientSelector
from .aggregator import TrustAggregator


class FedSSMServer:
    """Federated learning server with SSM-based client selection.

    Wires together three components configured from ``config["fedssm"]``:

    * ``HierarchicalSSM`` -- given the previous observation and the previous
      selection action, produces a prediction that is compared against the
      current observation to obtain a "surprise" score;
    * ``ClientSelector`` -- chooses clients from the surprise score and the
      current mean loss / gradient-norm summaries;
    * ``TrustAggregator`` -- combines client state dicts, also given the most
      recent surprise score.
    """

    def __init__(self, model: torch.nn.Module, config: dict):
        """Build the server and its SSM, selector and aggregator.

        Args:
            model: Global model; it is moved to ``self.device``.
            config: Nested configuration. ``config["federated"]["num_clients"]``
                and ``config["federated"]["num_rounds"]`` are required.
                ``config["training"]["device"]`` is read only when CUDA is
                available. The optional ``config["fedssm"]`` section overrides
                the component hyper-parameters (defaults below).
        """
        self.config = config
        self.num_clients = config["federated"]["num_clients"]
        # Use CUDA only when it is both requested and available; otherwise CPU.
        self.device = torch.device("cuda" if torch.cuda.is_available() and config["training"]["device"] == "cuda" else "cpu")
        self.global_model = model.to(self.device)

        ssm_cfg = config.get("fedssm", {})

        self.ssm = HierarchicalSSM(
            state_dim=ssm_cfg.get("state_dim", 32),
            # Matches the 4-element vector produced by ``_compute_obs``.
            obs_dim=4,
            action_dim=ssm_cfg.get("action_dim", 16),
            beta_macro=ssm_cfg.get("beta_macro", 0.9)
        )

        self.selector = ClientSelector(
            num_clients=self.num_clients,
            n_min=ssm_cfg.get("n_min", 2),
            n_max=ssm_cfg.get("n_max", 8),
            kappa=ssm_cfg.get("kappa", 2.0),
            tau=ssm_cfg.get("tau", 2.0),
            beta_variance=ssm_cfg.get("beta_variance", 1.0),
            beta_staleness=ssm_cfg.get("beta_staleness", 0.5),
            beta_loss_align=ssm_cfg.get("beta_loss_align", 1.0),
            beta_grad_align=ssm_cfg.get("beta_grad_align", 0.5)
        )

        self.aggregator = TrustAggregator(
            eta=ssm_cfg.get("trust_eta", 1.0),
            min_trust=ssm_cfg.get("min_trust", 0.1),
            max_trust=ssm_cfg.get("max_trust", 1.0),
            eta_decay=ssm_cfg.get("eta_decay", 0.99),
            outlier_threshold=ssm_cfg.get("outlier_threshold", 3.0)
        )

        self.history = self._empty_history()
        # (num_selected, per-client probability vector) from the last selection.
        self.prev_action = None
        # Observation vector from the last selection.
        self.prev_obs = None
        self.current_round = 0
        self.total_rounds = config["federated"]["num_rounds"]
        self.best_acc = 0.0

    @staticmethod
    def _empty_history() -> Dict[str, list]:
        """Return a fresh metrics history with every tracked key mapped to an empty list."""
        return {"train_loss": [], "train_acc": [], "test_loss": [], "test_acc": [], "surprise": [], "n_selected": []}

    def _compute_obs(self, losses: Dict[int, float], grad_norms: Dict[int, float]) -> np.ndarray:
        """Compute the observation vector ``[loss, grad_norm, grad_var, progress]``.

        Args:
            losses: Per-client loss values.
            grad_norms: Per-client gradient norms.

        Returns:
            A length-4 array holding the following values:

            * the mean loss (1.0 when ``losses`` is empty);
            * the mean gradient norm (1.0 when ``grad_norms`` is empty);
            * the gradient-norm variance (0.0 for fewer than two values);
            * the fraction of rounds completed, ``current_round / total_rounds``.
        """
        loss_vals = list(losses.values())
        grad_vals = list(grad_norms.values())
        return np.array([
            np.mean(loss_vals) if loss_vals else 1.0,
            np.mean(grad_vals) if grad_vals else 1.0,
            np.var(grad_vals) if len(grad_vals) > 1 else 0.0,
            # max(1, ...) guards against a zero-round configuration.
            self.current_round / max(1, self.total_rounds)
        ])

    def select_clients(
        self,
        clients: Dict,
        round_num: int = 0,
        client_losses: Optional[Dict[int, float]] = None,
        client_grad_norms: Optional[Dict[int, float]] = None
    ) -> Tuple[List[int], Dict]:
        """Select clients using SSM prediction and counterfactual selection.

        Args:
            clients: Mapping whose keys are the candidate client ids.
            round_num: Current round; stored as ``self.current_round``.
            client_losses: Per-client losses; defaults to 1.0 for every client.
            client_grad_norms: Per-client gradient norms; defaults to 1.0 for
                every client.

        Returns:
            ``(selected, info)``. ``selected`` and ``info`` come from
            ``ClientSelector.select``. ``info`` is additionally given an
            ``"observation"`` list and an ``"ssm_state_norm"`` float.
        """
        self.current_round = round_num
        all_ids = list(clients.keys())

        if client_losses is None:
            client_losses = {cid: 1.0 for cid in all_ids}
        if client_grad_norms is None:
            client_grad_norms = {cid: 1.0 for cid in all_ids}

        obs = self._compute_obs(client_losses, client_grad_norms)

        if self.prev_action is not None and self.prev_obs is not None:
            # Compare the SSM's prediction (from the previous round's
            # observation and action) with what was actually observed, then
            # let the SSM update itself online.
            N_prev, pi_prev = self.prev_action
            action_embed = self.ssm.embed_action(N_prev, pi_prev)
            _, x_pred = self.ssm.predict(self.prev_obs, action_embed)
            surprise = self.ssm.compute_surprise(obs, x_pred)
            self.ssm.update_online(obs, x_pred)
        else:
            # First round (or after reset): no prediction to compare against.
            surprise = 1.0

        # obs[0] / obs[1] already hold the mean loss / grad norm, with a 1.0
        # fallback for empty inputs. Recomputing with a bare np.mean would
        # return NaN (and warn) when a caller passes an empty dict.
        global_loss = float(obs[0])
        global_grad = float(obs[1])

        selected, info = self.selector.select(all_ids, surprise, global_loss, global_grad, round_num)

        # Remember this round's action and observation for the next prediction.
        pi = np.array([info["probabilities"].get(cid, 0) for cid in all_ids])
        self.prev_action = (len(selected), pi)
        self.prev_obs = obs

        self.history["surprise"].append(surprise)
        self.history["n_selected"].append(len(selected))

        info["observation"] = obs.tolist()
        info["ssm_state_norm"] = float(np.linalg.norm(self.ssm.h))
        return selected, info

    def aggregate(
        self,
        state_dicts: List[Dict[str, torch.Tensor]],
        weights: List[float],
        client_ids: List[int],
        client_losses: Dict[int, float]
    ) -> Dict[str, torch.Tensor]:
        """Aggregate client state dicts with trust weighting.

        Uses the most recent recorded surprise score, or 1.0 if none has been
        recorded yet. The inputs are delegated to ``TrustAggregator.aggregate``.
        """
        surprise = self.history["surprise"][-1] if self.history["surprise"] else 1.0
        return self.aggregator.aggregate(state_dicts, weights, client_ids, client_losses, surprise)

    def update_client_stats(self, client_id: int, loss: float, grad_norm: float, round_num: int):
        """Forward one client's latest loss and gradient norm to the selector."""
        self.selector.update_stats(client_id, loss, grad_norm, round_num)

    def get_state_dict(self) -> Dict[str, torch.Tensor]:
        """Return CPU copies of the global model's parameters and buffers."""
        return {k: v.cpu().clone() for k, v in self.global_model.state_dict().items()}

    def set_state_dict(self, state_dict: Dict[str, torch.Tensor]):
        """Load ``state_dict`` into the global model after moving it to ``self.device``."""
        self.global_model.load_state_dict({k: v.to(self.device) for k, v in state_dict.items()})

    def save(self, path: str):
        """Write a checkpoint to ``path``.

        The checkpoint holds the model weights, metrics history, SSM state
        (``h``, ``h_macro``, ``sigma``) and best accuracy. Selector and
        aggregator state, and the previous action/observation, are not saved.
        """
        torch.save({
            "model": self.global_model.state_dict(),
            "history": self.history,
            "ssm": {"h": self.ssm.h, "h_macro": self.ssm.h_macro, "sigma": self.ssm.sigma},
            "best_acc": self.best_acc
        }, path)

    def load(self, path: str):
        """Restore a checkpoint written by ``save``.

        Missing ``history`` / ``best_acc`` / ``ssm`` entries are tolerated.
        """
        # NOTE(review): PyTorch >= 2.6 defaults torch.load to weights_only=True.
        # If the SSM state or history contains non-tensor objects (e.g. NumPy
        # arrays or scalars), loading a trusted checkpoint may require
        # weights_only=False -- confirm against the SSM implementation.
        ckpt = torch.load(path, map_location=self.device)
        self.global_model.load_state_dict(ckpt["model"])
        if "history" in ckpt:
            # Merge onto a fresh history so a checkpoint missing a tracked key
            # still leaves every list present for later appends.
            self.history = {**self._empty_history(), **ckpt["history"]}
        self.best_acc = ckpt.get("best_acc", 0.0)
        if "ssm" in ckpt:
            self.ssm.h = ckpt["ssm"]["h"]
            self.ssm.h_macro = ckpt["ssm"]["h_macro"]
            self.ssm.sigma = ckpt["ssm"]["sigma"]

    def get_stats(self) -> Dict:
        """Return a nested summary of the round, best accuracy and component stats."""
        return {
            "round": self.current_round,
            "best_acc": self.best_acc,
            "ssm": {"state_norm": float(np.linalg.norm(self.ssm.h)), "recent_surprise": self.ssm.get_recent_surprise()},
            "selector": self.selector.get_stats(),
            "aggregator": self.aggregator.get_stats()
        }

    def reset(self):
        """Reset the components and all per-run server state to initial values."""
        self.ssm.reset()
        self.selector.reset()
        self.aggregator.reset()
        self.history = self._empty_history()
        self.prev_action = None
        self.prev_obs = None
        self.current_round = 0
        self.best_acc = 0.0
