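"""Trust-Weighted Aggregation for FedSSM.

Implements uncertainty-calibrated fusion with surprise-based trust coefficients.
Key equations:
    tau_i = exp(-S_t * eta * |L_i - L_bar| / std(L))
    theta_{t+1} = sum(n_i * tau_i * theta_i) / sum(n_i * tau_i)

Here L_i is client i's reported loss, and L_bar and std(L) are the mean and
(population) standard deviation of the losses in the round. S_t is the round's
surprise scalar. eta is a step size that decays after every trust computation.
n_i is the caller-supplied aggregation weight of client i. Each tau_i is clipped
to [min_trust, max_trust]. Clients whose loss z-score exceeds the outlier
threshold are further down-weighted by a constant factor.
"""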

import numpy as np
import torch
from typing import List, Dict, Optional


class TrustAggregator:
    """Trust-weighted model aggregation based on prediction surprise.

    Each round, a client whose reported loss deviates strongly from the round
    mean receives a smaller trust coefficient. The aggregated parameters are
    the average weighted by ``weight_i * trust_i``.

    Args:
        eta: Initial sharpness of the trust fall-off. It is multiplied by
            ``eta_decay`` after every :meth:`compute_trust` call and floored at
            ``eta_min``.
        min_trust: Lower clip bound for a trust coefficient (before any outlier
            penalty is applied).
        max_trust: Upper clip bound for a trust coefficient.
        eta_decay: Multiplicative decay applied to ``eta`` per call.
        eta_min: Floor for ``eta``.
        outlier_threshold: Absolute z-score above which a client's loss is
            treated as an outlier.
        enable_outlier_detection: Whether outlier down-weighting is applied.
        outlier_penalty: Factor by which an outlier's (already clipped) trust is
            multiplied. This may push it below ``min_trust``.
    """

    def __init__(
        self,
        eta: float = 1.0,
        min_trust: float = 0.1,
        max_trust: float = 1.0,
        eta_decay: float = 0.99,
        eta_min: float = 0.1,
        outlier_threshold: float = 3.0,
        enable_outlier_detection: bool = True,
        outlier_penalty: float = 0.1,
    ):
        self.eta = eta
        self.eta_init = eta  # restored by reset()
        self.min_trust = min_trust
        self.max_trust = max_trust
        self.eta_decay = eta_decay
        self.eta_min = eta_min
        self.outlier_threshold = outlier_threshold
        self.enable_outlier_detection = enable_outlier_detection
        self.outlier_penalty = outlier_penalty
        # One summary dict per non-empty compute_trust() call. The list is
        # unbounded; get_stats() only summarizes the most recent entries.
        self.history: List[Dict[str, float]] = []

    def _detect_outliers(self, losses: Dict[int, float]) -> List[int]:
        """Return ids of clients whose loss z-score exceeds ``outlier_threshold``.

        Returns an empty list when detection is disabled, when fewer than three
        clients reported, or when all losses are (numerically) identical.
        """
        if not self.enable_outlier_detection or len(losses) < 3:
            return []

        values = list(losses.values())
        mean_loss = np.mean(values)
        std_loss = np.std(values)

        if std_loss < 1e-10:
            return []

        return [cid for cid, loss in losses.items() if abs(loss - mean_loss) / std_loss > self.outlier_threshold]

    def compute_trust(self, client_losses: Dict[int, float], surprise: float) -> Dict[int, float]:
        """Compute per-client trust coefficients for one aggregation round.

        The base coefficient is
        ``clip(exp(-surprise * eta * |L_i - mean(L)| / std(L)), min_trust, max_trust)``.
        Clients flagged by outlier detection have it multiplied by
        ``outlier_penalty``.

        Side effects: decays ``self.eta`` (floored at ``eta_min``) and appends a
        summary entry to ``self.history``. Neither happens when
        ``client_losses`` is empty.

        Args:
            client_losses: Mapping from client id to that client's loss.
            surprise: Round-level surprise scalar ``S_t``. For positive ``eta``,
                larger values make trust fall off faster with loss deviation.

        Returns:
            Mapping from client id to trust coefficient (empty if no losses).
        """
        if not client_losses:
            return {}

        losses = list(client_losses.values())
        mean_loss = np.mean(losses)
        # Epsilon avoids division by zero when all losses are identical. The
        # deviation is then 0 and every client gets exp(0) = 1 before clipping.
        std_loss = np.std(losses) + 1e-10

        # A set gives O(1) membership tests inside the loop below.
        outliers = set(self._detect_outliers(client_losses))
        trust = {}

        for cid, loss in client_losses.items():
            deviation = abs(loss - mean_loss) / std_loss
            tau = np.exp(-surprise * self.eta * deviation)
            tau = max(self.min_trust, min(self.max_trust, tau))

            if cid in outliers:
                # Applied after clipping so outliers can fall below min_trust.
                tau *= self.outlier_penalty

            trust[cid] = tau

        self.eta = max(self.eta_min, self.eta * self.eta_decay)

        self.history.append({
            "surprise": surprise,
            "mean_trust": np.mean(list(trust.values())),
            "min_trust": min(trust.values()),
            "max_trust": max(trust.values())
        })

        return trust

    def aggregate(
        self,
        state_dicts: List[Dict[str, torch.Tensor]],
        weights: List[float],
        client_ids: List[int],
        client_losses: Dict[int, float],
        surprise: float
    ) -> Dict[str, torch.Tensor]:
        """Aggregate model parameters with trust weighting.

        For every key of ``state_dicts[0]`` this computes
        ``sum_i(w_i * tau_i * theta_i) / sum_i(w_i * tau_i)``. Clients in
        ``client_ids`` that have no entry in ``client_losses`` receive
        ``max_trust``, the best trust a reporting client can get.

        Precision handling:
            - Floating-point/complex entries are accumulated in at least
              float32, and float64/complex precision is kept.
            - Integer buffers are accumulated in float64 and rounded to the
              nearest integer before being cast back, so rounding error does
              not truncate them.
            - Bool buffers become True wherever the weighted sum is non-zero.

        Args:
            state_dicts: One parameter dict per client. Each dict must contain
                every key of ``state_dicts[0]``.
            weights: Base aggregation weight per client, aligned with
                ``state_dicts``.
            client_ids: Client id per state dict, aligned with ``state_dicts``.
            client_losses: Mapping from client id to reported loss. It is
                forwarded to :meth:`compute_trust`, which also decays ``eta``
                and records history.
            surprise: Round-level surprise scalar.

        Returns:
            A new state dict with the same keys and dtypes as ``state_dicts[0]``,
            on the same devices.

        Raises:
            ValueError: If ``state_dicts`` is empty, if ``weights`` or
                ``client_ids`` do not match ``state_dicts`` in length, or if the
                total effective weight is not positive and finite.
        """
        if not state_dicts:
            raise ValueError("No state dicts to aggregate")
        # Without this check zip() would silently drop trailing clients.
        if len(weights) != len(state_dicts) or len(client_ids) != len(state_dicts):
            raise ValueError(
                f"Length mismatch: {len(state_dicts)} state dicts, "
                f"{len(weights)} weights, {len(client_ids)} client ids"
            )

        trust = self.compute_trust(client_losses, surprise)

        effective_weights = [
            float(w) * trust.get(cid, self.max_trust) for w, cid in zip(weights, client_ids)
        ]
        total = float(sum(effective_weights))

        if not (np.isfinite(total) and total > 0):
            raise ValueError(f"Total effective weight must be positive and finite, got {total}")

        norm_weights = [float(w) / total for w in effective_weights]

        return {
            key: self._weighted_average([sd[key] for sd in state_dicts], norm_weights, ref)
            for key, ref in state_dicts[0].items()
        }

    @staticmethod
    def _weighted_average(
        tensors: List[torch.Tensor], norm_weights: List[float], ref: torch.Tensor
    ) -> torch.Tensor:
        """Return ``sum(w * t)``, cast back to ``ref``'s dtype and placed on its device."""
        if ref.is_floating_point() or ref.is_complex():
            # Never accumulate below float32 (half/bfloat16 become float32),
            # but keep float64/complex precision intact.
            acc_dtype = torch.promote_types(ref.dtype, torch.float32)
            round_result = False
        else:
            # float64 represents int64 values exactly up to 2**53, unlike float32.
            acc_dtype = torch.float64
            round_result = ref.dtype != torch.bool

        acc = torch.zeros_like(ref, dtype=acc_dtype)
        for w, t in zip(norm_weights, tensors):
            acc += w * t.to(device=acc.device, dtype=acc_dtype)

        if round_result:
            # A plain cast would truncate e.g. 99.99999 to 99.
            acc = acc.round()
        return acc.to(ref.dtype)

    def get_stats(self) -> Dict:
        """Summarize the 10 most recent :meth:`compute_trust` calls.

        Returns:
            ``{}`` if nothing has been recorded yet. Otherwise a dict with
            ``avg_surprise`` and ``avg_mean_trust`` (averaged over the recent
            window) and ``total_aggregations`` (length of the full history).
        """
        if not self.history:
            return {}

        recent = self.history[-10:]
        return {
            "avg_surprise": np.mean([h["surprise"] for h in recent]),
            "avg_mean_trust": np.mean([h["mean_trust"] for h in recent]),
            "total_aggregations": len(self.history)
        }

    def reset(self):
        """Clear recorded history and restore ``eta`` to its initial value."""
        self.history.clear()
        self.eta = self.eta_init
