"""Counterfactual Client Selection for FedSSM.

Implements adaptive participation budget and exploration-exploitation tradeoff.
Key equations:
    N_t = N_min + sigmoid(kappa * (S_{t-1} - theta_S)) * (N_max - N_min)
    p_i ~ (1 - rho) * w_explore + rho * w_exploit
    rho = sigmoid(tau * (theta_S - S_t))
"""

import numpy as np
from typing import List, Dict, Tuple
from collections import deque
from dataclasses import dataclass, field


@dataclass
class ClientProfile:
    """Rolling record of one client's training signals, used to rank clients.

    Loss and gradient-norm observations are kept in bounded deques holding at
    most the 50 most recent values; participation is tracked by the round of
    the latest update and a running count.
    """
    client_id: int
    loss_history: deque = field(default_factory=lambda: deque(maxlen=50))
    grad_norm_history: deque = field(default_factory=lambda: deque(maxlen=50))
    last_participated: int = -1
    total_participations: int = 0

    def update(self, loss: float, grad_norm: float, round_num: int):
        """Record one participation: append both observations and stamp the round."""
        self.total_participations += 1
        self.last_participated = round_num
        self.grad_norm_history.append(grad_norm)
        self.loss_history.append(loss)

    @property
    def loss_variance(self) -> float:
        """Variance of the stored losses, or 1.0 while fewer than two are stored."""
        if len(self.loss_history) < 2:
            return 1.0
        return float(np.var(self.loss_history))

    @property
    def mean_loss(self) -> float:
        """Mean of the stored losses, or +inf when none has been recorded."""
        if not self.loss_history:
            return float('inf')
        return float(np.mean(self.loss_history))

    @property
    def mean_grad_norm(self) -> float:
        """Mean of the stored gradient norms, or 1.0 when none has been recorded."""
        if not self.grad_norm_history:
            return 1.0
        return float(np.mean(self.grad_norm_history))

    def staleness(self, current_round: int) -> int:
        """Return the number of rounds since the last participation.

        A client that has never participated (``last_participated`` negative)
        is reported as ``current_round + 1``.
        """
        if self.last_participated < 0:
            return current_round + 1
        return current_round - self.last_participated


class ClientSelector:
    """Counterfactual client selector with adaptive budget and explore-exploit balance.

    Each call to :meth:`select` does two things:

    1. Sizes the cohort between ``n_min`` and ``n_max`` from the current
       surprise value relative to a rolling quantile of past surprises.
    2. Samples that many clients without replacement from a mixture of an
       exploration weight (loss variance and staleness) and an exploitation
       weight (closeness of the client's mean loss / gradient norm to the
       supplied global values).

    Args:
        num_clients: Stored on the instance; not otherwise used by this class.
        n_min: Lower bound of the participation budget.
        n_max: Upper bound of the participation budget.
        kappa: Slope of the sigmoid that maps surprise to the budget.
        tau: Slope of the sigmoid that maps surprise to the exploit ratio.
        beta_variance: Coefficient on normalized loss variance (exploration).
        beta_staleness: Coefficient on normalized staleness (exploration).
        beta_loss_align: Coefficient on relative loss deviation (exploitation).
        beta_grad_align: Coefficient on relative grad-norm deviation (exploitation).
        threshold_quantile: Quantile of the surprise history used as threshold.
    """

    def __init__(
        self,
        num_clients: int,
        n_min: int = 2,
        n_max: int = 8,
        kappa: float = 2.0,
        tau: float = 2.0,
        beta_variance: float = 1.0,
        beta_staleness: float = 0.5,
        beta_loss_align: float = 1.0,
        beta_grad_align: float = 0.5,
        threshold_quantile: float = 0.5
    ):
        self.num_clients = num_clients
        self.n_min = n_min
        self.n_max = n_max
        self.kappa = kappa
        self.tau = tau
        self.beta_variance = beta_variance
        self.beta_staleness = beta_staleness
        self.beta_loss_align = beta_loss_align
        self.beta_grad_align = beta_grad_align
        self.threshold_quantile = threshold_quantile

        self.profiles: Dict[int, ClientProfile] = {}
        self.surprise_history = deque(maxlen=100)
        self.surprise_threshold = 1.0
        self.current_round = 0

    @staticmethod
    def _sigmoid(x: float) -> float:
        """Logistic function evaluated without overflowing ``exp`` for large |x|."""
        if x >= 0:
            return float(1.0 / (1.0 + np.exp(-x)))
        # For negative x use exp(x), which underflows to 0 instead of overflowing.
        e = np.exp(x)
        return float(e / (1.0 + e))

    def _ensure_profile(self, client_id: int):
        """Create an empty profile for ``client_id`` if none exists yet."""
        if client_id not in self.profiles:
            self.profiles[client_id] = ClientProfile(client_id=client_id)

    def _compute_budget(self, surprise: float) -> int:
        """Compute adaptive participation budget based on surprise.

        Appends ``surprise`` to the history, refreshes the threshold once at
        least five values are stored, and returns
        ``n_min + int(sigmoid(kappa * (surprise - threshold)) * (n_max - n_min))``
        clamped to ``[n_min, n_max]``.

        Raises:
            ValueError: If ``surprise`` is NaN. The check happens before any
                state is touched so a bad value cannot poison the history.
        """
        if np.isnan(surprise):
            raise ValueError("surprise must not be NaN")

        self.surprise_history.append(surprise)
        if len(self.surprise_history) >= 5:
            self.surprise_threshold = float(
                np.quantile(list(self.surprise_history), self.threshold_quantile)
            )

        gate = self._sigmoid(self.kappa * (surprise - self.surprise_threshold))
        # int() truncates toward zero, so n_max is reached only when the gate saturates.
        budget = self.n_min + int(gate * (self.n_max - self.n_min))
        return max(self.n_min, min(self.n_max, budget))

    def _population_means(self) -> Tuple[float, float]:
        """Return (mean loss variance, mean staleness) over every known profile."""
        known = list(self.profiles.values())
        mean_var = np.mean([p.loss_variance for p in known])
        mean_stale = np.mean([p.staleness(self.current_round) for p in known])
        return mean_var, mean_stale

    def _exploration_exponent(self, client_id: int, mean_var: float, mean_stale: float) -> float:
        """Log of the exploration weight, given precomputed population means."""
        profile = self.profiles[client_id]
        norm_var = profile.loss_variance / (mean_var + 1e-10)
        norm_stale = profile.staleness(self.current_round) / (mean_stale + 1e-10)
        return self.beta_variance * norm_var + self.beta_staleness * norm_stale

    def _exploration_weight(self, client_id: int) -> float:
        """Compute exploration weight based on variance and staleness.

        The client's loss variance and staleness are each divided by the mean
        over all known profiles, combined linearly and exponentiated.
        """
        mean_var, mean_stale = self._population_means()
        return np.exp(self._exploration_exponent(client_id, mean_var, mean_stale))

    def _exploitation_exponent(self, client_id: int, global_loss: float, global_grad: float) -> float:
        """Log of the exploitation weight (negated weighted relative deviations)."""
        profile = self.profiles[client_id]
        mean_loss = profile.mean_loss

        # No loss recorded yet (mean is inf) or a near-zero global loss: fall
        # back to a deviation of 1.0 instead of dividing.
        if np.isinf(mean_loss) or global_loss < 1e-10:
            loss_dev = 1.0
        else:
            loss_dev = abs(mean_loss - global_loss) / (global_loss + 1e-10)

        if global_grad > 1e-10:
            grad_dev = abs(profile.mean_grad_norm - global_grad) / (global_grad + 1e-10)
        else:
            grad_dev = 1.0

        return -self.beta_loss_align * loss_dev - self.beta_grad_align * grad_dev

    def _exploitation_weight(self, client_id: int, global_loss: float, global_grad: float) -> float:
        """Compute exploitation weight based on alignment with global statistics."""
        return np.exp(self._exploitation_exponent(client_id, global_loss, global_grad))

    def _compute_probabilities(
        self,
        client_ids: List[int],
        surprise: float,
        global_loss: float,
        global_grad: float
    ) -> Dict[int, float]:
        """Compute sampling probabilities with explore-exploit tradeoff.

        Each client's unnormalized weight is
        ``(1 - rho) * w_explore + rho * w_exploit`` with
        ``rho = sigmoid(tau * (threshold - surprise))``. The mixture is
        evaluated in log-space so large exponents cannot overflow to inf and a
        saturated ``rho`` does not erase the other component.

        Returns:
            Mapping from client id to probability. Falls back to a uniform
            distribution when the weights are NaN/inf or all zero, and returns
            an empty dict for an empty ``client_ids``.
        """
        if not client_ids:
            return {}

        # Create any missing profiles first so the population means below are
        # the same for every client, independent of iteration order.
        for cid in client_ids:
            self._ensure_profile(cid)
        mean_var, mean_stale = self._population_means()

        z = self.tau * (self.surprise_threshold - surprise)
        log_rho = -np.logaddexp(0.0, -z)            # log(sigmoid(z))
        log_one_minus_rho = -np.logaddexp(0.0, z)   # log(1 - sigmoid(z))

        log_w = np.array(
            [
                np.logaddexp(
                    log_one_minus_rho + self._exploration_exponent(cid, mean_var, mean_stale),
                    log_rho + self._exploitation_exponent(cid, global_loss, global_grad),
                )
                for cid in client_ids
            ],
            dtype=float,
        )

        peak = log_w.max()
        if np.isnan(log_w).any() or not np.isfinite(peak):
            uniform = 1.0 / len(client_ids)
            return {cid: uniform for cid in client_ids}

        # Subtracting the peak rescales every weight by the same factor, which
        # cancels in the normalization below.
        weights = np.exp(log_w - peak)
        total = weights.sum()
        return {cid: float(w / total) for cid, w in zip(client_ids, weights)}

    def _sample_clients(self, client_ids: List[int], probs: Dict[int, float], size: int) -> List[int]:
        """Draw ``size`` distinct entries of ``client_ids`` according to ``probs``.

        ``np.random.choice`` refuses to sample without replacement when fewer
        than ``size`` entries have positive probability. In that case every
        positive-probability client is taken and the remaining slots are filled
        uniformly at random from the zero-probability clients.
        """
        if size <= 0 or not client_ids:
            return []

        weights = np.array([probs[cid] for cid in client_ids], dtype=float)
        weights = weights / weights.sum()

        positive = np.flatnonzero(weights > 0)
        if positive.size >= size:
            picked = np.random.choice(len(client_ids), size=size, replace=False, p=weights)
        else:
            zero_mass = np.flatnonzero(~(weights > 0))
            filler = np.random.choice(zero_mass, size=size - positive.size, replace=False)
            picked = np.concatenate([positive, filler])

        return [client_ids[i] for i in picked]

    def select(
        self,
        client_ids: List[int],
        surprise: float,
        global_loss: float = 1.0,
        global_grad: float = 1.0,
        round_num: int = 0
    ) -> Tuple[List[int], Dict]:
        """Select clients based on surprise-driven adaptive strategy.

        Args:
            client_ids: Candidate clients for this round; may be empty.
            surprise: Current surprise value; also appended to the history.
            global_loss: Reference loss for the exploitation weight.
            global_grad: Reference gradient norm for the exploitation weight.
            round_num: Round index, stored as ``current_round``.

        Returns:
            ``(selected, info)`` where ``selected`` holds distinct client ids
            (at most the budget, at most ``len(client_ids)``) and ``info``
            reports the budget, surprise, threshold, exploration ratio and the
            per-client probabilities.

        Raises:
            ValueError: If ``surprise`` is NaN.
        """
        self.current_round = round_num
        for cid in client_ids:
            self._ensure_profile(cid)

        N = min(self._compute_budget(surprise), len(client_ids))
        probs = self._compute_probabilities(client_ids, surprise, global_loss, global_grad)
        selected = self._sample_clients(client_ids, probs, N)

        rho = self._sigmoid(self.tau * (self.surprise_threshold - surprise))

        info = {
            "n_selected": N,
            "surprise": surprise,
            "surprise_threshold": self.surprise_threshold,
            "exploration_ratio": 1 - rho,
            "probabilities": probs
        }
        return selected, info

    def update_stats(self, client_id: int, loss: float, grad_norm: float, round_num: int):
        """Record a client's loss and gradient norm for ``round_num``."""
        self._ensure_profile(client_id)
        self.profiles[client_id].update(loss, grad_norm, round_num)

    def get_stats(self) -> Dict:
        """Return a summary: round, threshold, profile count, participation counts."""
        return {
            "current_round": self.current_round,
            "surprise_threshold": self.surprise_threshold,
            "num_profiles": len(self.profiles),
            "participation_counts": {cid: p.total_participations for cid, p in self.profiles.items()}
        }

    def reset(self):
        """Drop all profiles and surprise history and restore initial state."""
        self.profiles.clear()
        self.surprise_history.clear()
        self.surprise_threshold = 1.0
        self.current_round = 0
