"""Decision-Aware State Space Model for Federated Learning.

Implements action-conditioned state transitions with HiPPO-style initialization.
Reference: Gu et al., "Efficiently Modeling Long Sequences with Structured State Spaces", ICLR 2022.
"""

import numpy as np
from typing import Tuple, Optional, Dict
from collections import deque


class StateSpaceModel:
    """Action-conditioned SSM for modeling FL training dynamics.

    State transition:
        h_t = A(a) @ h_{t-1} + B(a) @ x_{t-1}
        x_hat_t = C @ h_t

    where A, B are modulated by action a = (N, pi).

    Attributes:
        h: Current hidden state, shape ``(state_dim,)``.
        h_macro: Exponential moving average of ``h`` with coefficient
            ``beta_macro``.
        sigma: Running estimate of the prediction-error second moment,
            shape ``(obs_dim, obs_dim)``.
        obs_history, pred_history, surprise_history: Bounded histories
            (last 100 entries) filled by ``compute_surprise``.
    """

    # Diagonal value of the error covariance at construction and after reset.
    _SIGMA_INIT = 0.1

    def __init__(
        self,
        state_dim: int = 32,
        obs_dim: int = 4,
        action_dim: int = 16,
        beta_macro: float = 0.9,
        forgetting_factor: float = 0.95
    ):
        """Create a model with randomly initialised parameters.

        Args:
            state_dim: Size of the hidden state ``h``.
            obs_dim: Size of an observation vector ``x``.
            action_dim: Size of the action embedding.
            beta_macro: Smoothing coefficient of the ``h_macro`` average.
            forgetting_factor: Stored on the instance; no method of this
                class reads it.
        """
        self.state_dim = state_dim
        self.obs_dim = obs_dim
        self.action_dim = action_dim
        self.beta_macro = beta_macro
        self.forgetting_factor = forgetting_factor

        self._init_params()
        self.h = np.zeros(state_dim)
        self.h_macro = np.zeros(state_dim)
        self.sigma = np.eye(obs_dim) * self._SIGMA_INIT
        self.sigma_ema_alpha = 0.1

        self.obs_history = deque(maxlen=100)
        self.pred_history = deque(maxlen=100)
        self.surprise_history = deque(maxlen=100)

    def _init_params(self):
        """Initialize SSM parameters with HiPPO-style diagonal structure.

        ``A_base`` is upper-bidiagonal: the diagonal is ``-k / state_dim``
        for ``k = 1..state_dim`` and the first superdiagonal is ``0.1``.
        ``B_base``, ``C``, ``W_A`` and ``W_B`` are small Gaussian draws from
        NumPy's global random state (drawn in that order).
        """
        n = self.state_dim
        decay = -np.arange(1, n + 1) / n
        self.A_base = np.diag(decay) + 0.1 * np.eye(n, k=1)

        self.B_base = np.random.randn(self.state_dim, self.obs_dim) * 0.1
        self.C = np.random.randn(self.obs_dim, self.state_dim) * 0.1
        self.W_A = np.random.randn(self.state_dim, self.action_dim) * 0.01
        self.W_B = np.random.randn(self.state_dim, self.action_dim) * 0.01

    def embed_action(self, N: int, pi: np.ndarray, N_max: int = 10) -> np.ndarray:
        """Embed action (N, pi) into continuous space.

        The first ``action_dim // 2`` entries are a sinusoidal encoding of
        ``N``. The remaining entries hold, in order, the mean, standard
        deviation, maximum and entropy of ``pi`` followed by the leading
        values of ``pi`` itself, truncated or zero-padded to fit.

        Args:
            N: Scalar part of the action.
            pi: Vector part of the action; any array-like (flattened).
                ``None`` or an empty sequence yields an all-zero ``pi`` half.
            N_max: Unused; retained for interface compatibility.

        Returns:
            Array of shape ``(action_dim,)``.
        """
        n_dim = self.action_dim // 2
        # Give the odd slot to the pi half so the result always has exactly
        # action_dim entries and matches the shapes of W_A / W_B.
        pi_dim = self.action_dim - n_dim

        idx = np.arange(n_dim)
        if n_dim:
            phase = N / (10000.0 ** (idx / n_dim))
        else:
            phase = np.zeros(0)
        n_embed = np.where(idx % 2 == 0, np.sin(phase), np.cos(phase))

        if pi is None:
            pi_arr = np.zeros(0)
        else:
            pi_arr = np.asarray(pi, dtype=float).ravel()

        pi_embed = np.zeros(pi_dim)
        if pi_arr.size > 0:
            stats = np.array([
                pi_arr.mean(),
                pi_arr.std() if pi_arr.size > 1 else 0.0,
                pi_arr.max(),
                # Epsilon keeps log() finite when an entry of pi is zero.
                -np.sum(pi_arr * np.log(pi_arr + 1e-10)),
            ])
            # Truncate rather than index fixed slots so that small
            # action_dim values do not run past the end of pi_embed.
            features = np.concatenate([stats, pi_arr])[:pi_dim]
            pi_embed[:features.size] = features

        return np.concatenate([n_embed, pi_embed])

    def modulate_matrices(self, action_embed: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
        """Modulate transition matrices based on action embedding.

        Each row of ``A_base`` is scaled by a sigmoid gate in (0, 1) computed
        from ``W_A @ action_embed``; each row of ``B_base`` is shifted by
        ``0.1 * (W_B @ action_embed)``.

        Args:
            action_embed: Action embedding of shape ``(action_dim,)``.

        Returns:
            ``(A_mod, B_mod)`` with the shapes of ``A_base`` and ``B_base``.
        """
        gate = 1.0 / (1.0 + np.exp(-(self.W_A @ action_embed)))
        A_mod = self.A_base * gate.reshape(-1, 1)
        shift = self.W_B @ action_embed
        B_mod = self.B_base + shift.reshape(-1, 1) * 0.1
        return A_mod, B_mod

    def predict(self, x_prev: np.ndarray, action_embed: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
        """Predict next observation given previous observation and action.

        Advances ``self.h`` by one step and folds the new state into
        ``self.h_macro``.

        Args:
            x_prev: Previous observation, shape ``(obs_dim,)``.
            action_embed: Action embedding, shape ``(action_dim,)``.

        Returns:
            ``(h, x_pred)``: a copy of the new hidden state and the predicted
            observation ``C @ h``.
        """
        A, B = self.modulate_matrices(action_embed)
        self.h = A @ self.h + B @ x_prev
        self.h_macro = self.beta_macro * self.h_macro + (1 - self.beta_macro) * self.h
        x_pred = self.C @ self.h
        return self.h.copy(), x_pred

    def compute_surprise(self, x_actual: np.ndarray, x_pred: np.ndarray) -> float:
        """Compute Mahalanobis distance as prediction surprise.

        Updates the running error covariance ``sigma``, scores the error
        against it, and appends the observation, prediction and score to the
        histories.

        Args:
            x_actual: Observed vector, shape ``(obs_dim,)``.
            x_pred: Predicted vector, shape ``(obs_dim,)``.

        Returns:
            The squared Mahalanobis distance of the error, or the squared
            Euclidean norm if the regularised covariance is singular.
        """
        actual = np.array(x_actual, dtype=float)
        predicted = np.array(x_pred, dtype=float)
        error = actual - predicted

        # NOTE(review): sigma absorbs the current error *before* the distance
        # is measured, which bounds the score below 1 / sigma_ema_alpha.
        # Kept as-is because callers may rely on this scale.
        alpha = self.sigma_ema_alpha
        self.sigma = (1 - alpha) * self.sigma + alpha * np.outer(error, error)

        regularised = self.sigma + np.eye(self.obs_dim) * 1e-6
        try:
            # solve() avoids forming the explicit inverse.
            surprise = float(error @ np.linalg.solve(regularised, error))
        except np.linalg.LinAlgError:
            surprise = float(np.sum(error ** 2))

        # Store private copies so later mutation of the caller's arrays
        # cannot rewrite the recorded history.
        self.obs_history.append(actual)
        self.pred_history.append(predicted)
        self.surprise_history.append(surprise)
        return surprise

    def counterfactual_rollout(
        self,
        candidate_actions: list,
        x_current: np.ndarray,
        goal_state: Optional[np.ndarray] = None
    ) -> Dict[int, float]:
        """Simulate future states under different candidate actions.

        Performs a one-step look-ahead from the current hidden state for each
        candidate without modifying the model.

        Args:
            candidate_actions: Sequence of ``(N, pi)`` pairs.
            x_current: Current observation, shape ``(obs_dim,)``.
            goal_state: Target observation with ``obs_dim`` elements.
                Defaults to ``[0, 0, 0, 1]``, which is only defined when
                ``obs_dim == 4``.

        Returns:
            Mapping from candidate index to the squared distance between the
            simulated observation and the goal.

        Raises:
            ValueError: If the goal does not have ``obs_dim`` elements, or no
                goal is given and ``obs_dim != 4``.
        """
        if goal_state is None:
            if self.obs_dim != 4:
                raise ValueError(
                    "default goal_state is only defined for obs_dim == 4; "
                    f"pass goal_state explicitly (obs_dim={self.obs_dim})"
                )
            goal = np.array([0.0, 0.0, 0.0, 1.0])
        else:
            goal = np.asarray(goal_state, dtype=float).ravel()
            if goal.size != self.obs_dim:
                raise ValueError(
                    f"goal_state has {goal.size} elements, expected {self.obs_dim}"
                )

        expected_surprises = {}
        for idx, (N, pi) in enumerate(candidate_actions):
            A, B = self.modulate_matrices(self.embed_action(N, pi))
            # self.h is only read here, so no backup copy is required.
            x_sim = self.C @ (A @ self.h + B @ x_current)
            expected_surprises[idx] = float(np.sum((x_sim - goal) ** 2))

        return expected_surprises

    def get_recent_surprise(self, window: int = 5) -> float:
        """Return the mean of the last ``window`` surprise values.

        Returns ``1.0`` when no surprise has been recorded yet.
        """
        if not self.surprise_history:
            return 1.0
        recent = list(self.surprise_history)[-window:]
        return float(np.mean(recent))

    def get_surprise_threshold(self, quantile: float = 0.5) -> float:
        """Return the given quantile of the recorded surprise values.

        Returns ``1.0`` until at least five values have been recorded.
        """
        if len(self.surprise_history) < 5:
            return 1.0
        return float(np.quantile(list(self.surprise_history), quantile))

    def update_online(self, x_actual: np.ndarray, x_pred: np.ndarray, lr: float = 0.01):
        """Online update of observation matrix C.

        Takes one gradient step on ``0.5 * ||x_actual - C @ h||^2`` using the
        supplied prediction and the current hidden state ``self.h``.

        Args:
            x_actual: Observed vector, shape ``(obs_dim,)``.
            x_pred: Prediction previously produced for this step.
            lr: Step size.
        """
        error = x_actual - x_pred
        # The gradient w.r.t. C is -outer(error, h); descending adds it back.
        self.C += lr * np.outer(error, self.h)

    def reset(self):
        """Clear the hidden states, error covariance and histories.

        Learned parameters (``A_base``, ``B_base``, ``C``, ``W_A``, ``W_B``)
        are left untouched.
        """
        self.h = np.zeros(self.state_dim)
        self.h_macro = np.zeros(self.state_dim)
        self.sigma = np.eye(self.obs_dim) * self._SIGMA_INIT
        self.obs_history.clear()
        self.pred_history.clear()
        self.surprise_history.clear()


class HierarchicalSSM(StateSpaceModel):
    """Multi-scale SSM with hierarchical state representations.

    In addition to the base hidden state, keeps ``num_scales`` exponential
    moving averages of it. Scale ``i`` uses smoothing coefficient
    ``0.5 ** (i + 1)``.
    """

    def __init__(self, *args, num_scales: int = 3, **kwargs):
        """Create the base model and the per-scale averaged states.

        Args:
            *args: Positional arguments forwarded to ``StateSpaceModel``.
            num_scales: Number of averaged copies of the hidden state.
            **kwargs: Keyword arguments forwarded to ``StateSpaceModel``.
        """
        super().__init__(*args, **kwargs)
        self.num_scales = num_scales
        self.h_scales = [np.zeros(self.state_dim) for _ in range(num_scales)]
        self.betas = [0.5 ** (i + 1) for i in range(num_scales)]

    def predict(self, x_prev: np.ndarray, action_embed: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
        """Advance the base model one step and update every scale's average.

        Args:
            x_prev: Previous observation.
            action_embed: Action embedding.

        Returns:
            ``(h_new, x_pred)`` exactly as returned by the base ``predict``.
        """
        h_new, x_pred = super().predict(x_prev, action_embed)
        self.h_scales = [
            beta * h_scale + (1 - beta) * h_new
            for beta, h_scale in zip(self.betas, self.h_scales)
        ]
        return h_new, x_pred

    def reset(self):
        """Reset the base model and zero every per-scale state.

        Without this override the inherited ``reset`` would leave
        ``h_scales`` holding values from before the reset.
        """
        super().reset()
        self.h_scales = [np.zeros(self.state_dim) for _ in range(self.num_scales)]

    def get_hierarchical_state(self) -> np.ndarray:
        """Return ``h`` followed by all per-scale states as one flat vector.

        The result has ``state_dim * (num_scales + 1)`` entries.
        """
        return np.concatenate([self.h, *self.h_scales])
