"""
Straight-Through Estimator (STE) Simulator

This module implements STE and ODE dynamics for quantized neural networks,
including support for both weight and activation quantization.
"""

import numpy as np
import math
from typing import Optional, Tuple, List, Dict, Any, Union

# Constants
# sqrt(2*pi): denominator of the standard normal density evaluated in phi().
SQRT_2PI = np.sqrt(2.0 * np.pi)
# sqrt(2): rescales arguments before they are passed to erfc (normal CDF / tail probabilities).
SQRT_2 = np.sqrt(2.0)


def phi(u: Union[float, np.ndarray]) -> Union[float, np.ndarray]:
    """
    Evaluate the standard normal density exp(-u^2 / 2) / sqrt(2*pi).

    Args:
        u: Scalar or array of evaluation points (converted to float64)

    Returns:
        Density value(s), element-wise, with the same shape as u
    """
    points = np.asarray(u, dtype=np.float64)
    exponent = -0.5 * points * points
    return np.exp(exponent) / SQRT_2PI


def erfc_np(x: Union[float, np.ndarray]) -> Union[float, np.ndarray]:
    """
    Element-wise complementary error function built on math.erfc.

    Args:
        x: Scalar or array of input values (converted to float64)

    Returns:
        erfc(x) evaluated at every element, with the shape of x
    """
    # math.erfc only accepts scalars, so wrap it for element-wise application.
    elementwise_erfc = np.vectorize(math.erfc, otypes=[float])
    return elementwise_erfc(np.asarray(x, dtype=np.float64))

def _quantize_step_scalar(x: float, omega: float, b: int) -> float:
    """
    Quantize a single value with the uniform step quantizer.

    With L = 2^b - 2 thresholds placed at -ω + (2k - 1)ω/L (k = 1..L), the
    output is -ω + Δ·(number of thresholds that x reaches), where Δ = 2ω/L.
    When L <= 0 there are no thresholds and x is simply clipped to [-ω, ω].

    Args:
        x: Scalar value to quantize
        omega: Quantization range [-omega, omega]
        b: Number of bits for quantization

    Returns:
        Quantized value as a Python float
    """
    num_thresholds = 2**b - 2
    if num_thresholds <= 0:
        return float(np.clip(x, -omega, omega))

    indices = np.arange(1, num_thresholds + 1, dtype=np.float64)
    cut_points = -omega + (2.0 * indices - 1.0) * (omega / num_thresholds)
    step = 2.0 * omega / num_thresholds

    # A threshold counts as reached when x is at or above it.
    reached = float((x - cut_points >= 0.0).sum())
    return float(-omega + step * reached)

def mqr_from_isotropy(m: float, q: float, b: int, omega: float, 
                      eps: float = 1e-12, teacher: str = 'ones', 
                      w0: Optional[np.ndarray] = None, m0: float = 1.0) -> Tuple[float, float, float]:
    """
    Evaluate the quantized order parameters (m_psi, q_psi, r_psi) under isotropy.

    This is a dispatcher: it derives the threshold count L = 2^b - 2 and then
    delegates to the clip-only, ones-teacher, or gaussian-teacher helper.

    Args:
        m: Mean overlap <w, w0>/d
        q: Second moment ||w||^2/d
        b: Number of quantization bits
        omega: Quantization range [-omega, omega]
        eps: Variance threshold below which the degenerate formulas are used
        teacher: 'ones' selects the ones-teacher formulas; any other value is
            handled by the gaussian-teacher formulas
        w0: Teacher vector (used by the gaussian-teacher path)
        m0: Teacher norm ||w0||^2/d (used by the gaussian-teacher path)

    Returns:
        Tuple of (m_psi, q_psi, r_psi) where:
            m_psi = <ψ(w), w0>/d
            q_psi = ||ψ(w)||^2/d
            r_psi = <w, ψ(w)>/d
    """
    num_thresholds = 2**b - 2

    # No thresholds at all: ψ degenerates to plain clipping.
    if num_thresholds <= 0:
        return _handle_clip_only_case(m, teacher, w0, omega, m0)

    if teacher != 'ones':
        return _compute_mqr_gaussian_teacher(m, q, b, omega, num_thresholds, eps, w0, m0)

    return _compute_mqr_ones_teacher(m, q, b, omega, num_thresholds, eps)


def _handle_clip_only_case(m: float, teacher: str, w0: Optional[np.ndarray], 
                          omega: float, m0: float) -> Tuple[float, float, float]:
    """
    Compute (m_psi, q_psi, r_psi) when ψ reduces to clipping to [-omega, omega].

    For teacher == 'ones' the weight is represented by the single value m.
    Otherwise the weights are taken to be (m / m0) * w0 and the statistics are
    averaged over the components of w0.

    Args:
        m: Mean overlap <w, w0>/d
        teacher: 'ones' or anything else (treated as gaussian)
        w0: Teacher vector; required unless teacher == 'ones'
        omega: Clipping range [-omega, omega]
        m0: Teacher norm ||w0||^2/d

    Returns:
        Tuple of (m_psi, q_psi, r_psi)

    Raises:
        ValueError: If the gaussian path is taken and w0 is None
    """
    if teacher == 'ones':
        psi_m = float(np.clip(m, -omega, omega))
        return psi_m, psi_m**2, m*psi_m

    # gaussian teacher: fail with the same message as the non-degenerate path
    # instead of an opaque TypeError from len(None).
    if w0 is None:
        raise ValueError("w0 must be provided for gaussian teacher")

    # Accept any array-like teacher (e.g. a list), not only ndarrays.
    w0 = np.asarray(w0, dtype=np.float64)
    d = len(w0)
    alpha = m / m0
    psi_w0 = np.clip(alpha * w0, -omega, omega)
    overlap = np.dot(psi_w0, w0)
    mpsi = float(overlap / d)
    qpsi = float(np.dot(psi_w0, psi_w0) / d)
    # w = alpha * w0, hence <w, ψ(w)> = alpha * <w0, ψ(w)>.
    rpsi = float(alpha * overlap / d)
    return mpsi, qpsi, rpsi


def _compute_mqr_ones_teacher(m: float, q: float, b: int, omega: float, 
                             L: int, eps: float) -> Tuple[float, float, float]:
    """
    Compute (m_psi, q_psi, r_psi) for the ones teacher.

    The residual variance is s^2 = max(q - m^2, 0). If it does not exceed eps,
    the statistics are those of the single quantized value ψ(m). Otherwise they
    are closed-form sums over the L thresholds using erfc and the normal density
    at the standardized thresholds u_k.

    Args:
        m: Mean overlap
        q: Second moment
        b: Number of quantization bits
        omega: Quantization range [-omega, omega]
        L: Number of thresholds (2^b - 2)
        eps: Variance threshold for the degenerate branch

    Returns:
        Tuple of (m_psi, q_psi, r_psi)
    """
    variance = max(q - m*m, 0.0)

    # Degenerate branch: no spread around m, so ψ acts on a single point.
    if variance <= eps:
        point = _quantize_step_scalar(m, omega, b)
        return float(point), float(point*point), float(m*point)

    spread = np.sqrt(variance)
    idx = np.arange(1, L + 1, dtype=np.float64)
    u = (-m - omega * (1.0 - (2.0 * idx - 1.0) / L)) / spread
    tails = erfc_np(u / SQRT_2)
    tail_total = np.sum(tails)

    # m_psi: lowest level plus the erfc-weighted sum over thresholds
    m_psi = -omega + (omega / L) * tail_total

    # q_psi: the pairwise term uses the larger of each pair of thresholds
    pair_max = np.maximum(u[:, None], u[None, :])
    pair_total = np.sum(erfc_np(pair_max / SQRT_2))
    q_psi = (omega**2 
             - (2.0 * omega**2 / L) * tail_total
             + (2.0 * omega**2 / (L**2)) * pair_total)

    # r_psi: mean part plus the density-weighted correction
    r_psi = m * m_psi + (2.0 * omega / L) * spread * np.sum(phi(u))

    return float(m_psi), float(q_psi), float(r_psi)


def _compute_mqr_gaussian_teacher(m: float, q: float, b: int, omega: float, L: int, 
                                 eps: float, w0: np.ndarray, m0: float) -> Tuple[float, float, float]:
    """
    Compute (m_psi, q_psi, r_psi) for a fixed (gaussian) teacher vector.

    Each component is modelled with mean a_i = (m / m0) * w0_i and common
    variance s^2 = max(q - m^2 / m0, 0). The per-component statistics
    M(a_i), Q(a_i), R(a_i) are evaluated for all components at once with a
    broadcast (d, L) array and then averaged over the d components.

    Args:
        m: Mean overlap <w, w0>/d
        q: Second moment ||w||^2/d
        b: Number of quantization bits
        omega: Quantization range [-omega, omega]
        L: Number of thresholds (2^b - 2)
        eps: Variance threshold for the degenerate branch
        w0: Teacher vector (1-D array-like)
        m0: Teacher norm ||w0||^2/d

    Returns:
        Tuple of (m_psi, q_psi, r_psi)

    Raises:
        ValueError: If w0 is None
    """
    if w0 is None:
        raise ValueError("w0 must be provided for gaussian teacher")

    # Accept any array-like teacher and work in float64 throughout.
    w0_arr = np.asarray(w0, dtype=np.float64)
    d = len(w0_arr)
    alpha = m / m0
    s2 = max(q - m*m/m0, 0.0)
    means = alpha * w0_arr  # per-component mean a_i

    # Degenerate branch: no spread, so each component is exactly ψ(a_i).
    if s2 <= eps:
        psi_w = np.array([_quantize_step_scalar(a_i, omega, b) for a_i in means])
        mpsi = float(np.dot(psi_w, w0_arr) / d)
        qpsi = float(np.dot(psi_w, psi_w) / d)
        rpsi = float(alpha * np.dot(psi_w, w0_arr) / d)
        return mpsi, qpsi, rpsi

    s = np.sqrt(s2)
    delta = 2.0 * omega / L

    # Thresholds θ_k (k = 1..L) and output levels v_j (j = 0..L)
    k = np.arange(1, L + 1, dtype=np.float64)
    theta = -omega + (k - 0.5) * delta
    v = -omega + delta * np.arange(0, L + 1, dtype=np.float64)

    # Standardized distance of every component mean to every threshold: (d, L)
    z = (means[:, None] - theta[None, :]) / s
    exceed_prob = std_norm_cdf(z)   # Φ((a_i - θ_k) / s)
    density = phi(z)                # φ((a_i - θ_k) / s)

    # M(a) = v_0 + Δ Σ_k Φ_k
    M = -omega + delta * exceed_prob.sum(axis=1)
    # Q(a) = v_0² + Σ_k (v_k² - v_{k-1}²) Φ_k
    Q = v[0]**2 + exceed_prob @ (v[1:]**2 - v[:-1]**2)
    # R(a) = a M(a) + Δ s Σ_k φ_k
    R = means * M + delta * s * density.sum(axis=1)

    mpsi = float(np.dot(w0_arr, M) / d)
    qpsi = float(np.sum(Q) / d)
    rpsi = float(np.sum(R) / d)
    return mpsi, qpsi, rpsi

def std_norm_cdf(x: Union[float, np.ndarray]) -> Union[float, np.ndarray]:
    """
    Standard normal CDF, evaluated as Φ(x) = erfc(-x / sqrt(2)) / 2.

    Args:
        x: Scalar or array of evaluation points (converted to float64)

    Returns:
        Φ(x) element-wise, with the shape of x
    """
    points = np.asarray(x, dtype=np.float64)
    scaled = -points / SQRT_2
    return 0.5 * erfc_np(scaled)


def activation_moments_kappa_sigma(omega: float, b: int) -> Tuple[float, float]:
    """
    Moments of the step quantizer ψ applied to a standard normal input.

    For x ~ N(0,1) this returns
    - kappa = E[x ψ(x)], computed as Δ · Σ_k φ(θ_k)
    - sigma_psi^2 = E[ψ(x)^2], computed as Σ_j v_j² · P(cell j)

    where Δ = 2ω/L, L = 2^b - 2, θ_k = -ω + (k - 0.5)Δ are the thresholds and
    v_j = -ω + Δj are the output levels.

    Args:
        omega: Quantization range [-omega, omega]
        b: Number of quantization bits

    Returns:
        Tuple of (kappa, sigma_psi^2)

    Raises:
        ValueError: If 2^b - 2 <= 0 (i.e. b < 2)
    """
    num_thresholds = 2**b - 2
    if num_thresholds <= 0:
        raise ValueError("Activation quantization requires b >= 2")

    step = 2.0 * omega / num_thresholds

    # Output levels v_j (j = 0..L) and thresholds θ_k (k = 1..L)
    out_levels = -omega + step * np.arange(0, num_thresholds + 1, dtype=np.float64)
    k_idx = np.arange(1, num_thresholds + 1, dtype=np.float64)
    cuts = -omega + (k_idx - 0.5) * step

    # Cell probabilities from the CDF at the thresholds: the two outer cells
    # are the tails, the inner cells are differences of consecutive CDF values.
    cdf_at_cuts = std_norm_cdf(cuts)
    cell_prob = np.empty(num_thresholds + 1, dtype=np.float64)
    cell_prob[0] = cdf_at_cuts[0]
    if num_thresholds > 1:
        cell_prob[1:num_thresholds] = cdf_at_cuts[1:] - cdf_at_cuts[:-1]
    cell_prob[num_thresholds] = 1.0 - cdf_at_cuts[-1]

    second_moment = float(np.sum(out_levels**2 * cell_prob))
    kappa = float(step * np.sum(phi(cuts)))

    return kappa, second_moment


class STESimulator:
    """
    Straight-Through Estimator (STE) simulator for quantized neural networks.
    
    This class implements both stochastic gradient descent with quantization
    and the corresponding ODE dynamics in the mean-field limit.
    
    Attributes:
        d: Dimension of the weight vector
        T: Total number of training steps
        eta: Learning rate
        lam: L2 regularization parameter
        sigma: Noise standard deviation
        omega: Quantization range [-omega, omega]
        b: Number of quantization bits (None for clip only)
        seed: Random seed for reproducibility
        log_every: Frequency of logging
        quant_method: Quantization method ('step' or 'round')
        snapshot_every: Frequency of saving snapshots
        teacher: Teacher type ('ones' or 'gaussian')
    """
    
    def __init__(self,
                 d: int = 800,
                 T: int = 8000,
                 eta: float = 0.08,
                 lam: float = 0.0,
                 sigma: float = 0.1,
                 omega: float = 1.0,
                 b: Optional[int] = 3,
                 seed: int = 0,
                 log_every: int = 100,
                 quant_method: str = "step",
                 snapshot_every: int = 2000,
                 teacher: str = "ones"):
        """
        Initialize the STE simulator.
        
        Args:
            d: Dimension of the weight vector
            T: Total number of training steps
            eta: Learning rate
            lam: L2 regularization parameter
            sigma: Noise standard deviation
            omega: Quantization range [-omega, omega]
            b: Number of quantization bits (None for clip only)
            seed: Random seed for reproducibility
            log_every: Frequency of logging
            quant_method: Quantization method ('step' or 'round')
            snapshot_every: Frequency of saving snapshots
            teacher: Teacher type ('ones' or 'gaussian')
            
        Raises:
            ValueError: If parameters are invalid
        """
        # Store parameters
        self.d = int(d)
        self.T = int(T)
        self.eta = float(eta)
        self.lam = float(lam)
        self.sigma = float(sigma)
        self.omega = float(omega)
        self.b = None if b is None else int(b)
        self.seed = int(seed)
        self.log_every = int(log_every)
        self.quant_method = quant_method.lower()
        self.snapshot_every = int(snapshot_every)
        self.teacher = teacher.lower()
        
        # Validate parameters
        self._validate_parameters()
        
        # Initialize random number generator
        self.rng = np.random.default_rng(self.seed)
        
        # Precompute constants
        self._sqrt_d_inv = 1.0 / np.sqrt(self.d)
        self._lam_over_d = self.lam / self.d
        
        # Teacher-related attributes (set during simulation)
        self._w0: Optional[np.ndarray] = None
        self._m0: float = 1.0
        
        # Initialize quantization parameters
        self._init_quantization_params()
    
    def _validate_parameters(self):
        """Validate input parameters."""
        if self.b is not None and self.b < 2:
            raise ValueError("b must be at least 2 (uniform quantization). Use b=None for clip only")
        if self.quant_method not in ("step", "round"):
            raise ValueError("quant_method must be 'step' or 'round'")
        if self.teacher not in ("ones", "gaussian"):
            raise ValueError("teacher must be 'ones' or 'gaussian'")
    
    def _init_quantization_params(self):
        """Initialize quantization-related parameters."""
        if self.b is None:
            self._L = None
            self._Delta = None
            self._thresholds = None
        else:
            self._L = 2 ** self.b - 2
            if self._L <= 0:
                raise ValueError("L is less than 0. b must be at least 2")
            self._Delta = 2.0 * self.omega / self._L
            k_values = np.arange(1, self._L + 1, dtype=np.float64)
            self._thresholds = -self.omega + (2.0 * k_values - 1.0) * (self.omega / self._L)


    def quantize_round(self, x: np.ndarray) -> np.ndarray:
        """
        Apply round-based uniform quantization.
        
        Args:
            x: Input array to quantize
            
        Returns:
            Quantized array with values in [-omega, omega]
        """
        if self.b is None:
            return np.clip(x, -self.omega, self.omega)
        
        # Clip input to valid range
        clipped = np.clip(x, -self.omega, self.omega)
        
        # Map to quantization indices
        k = np.rint((clipped + self.omega) / self._Delta + 1e-12)
        k = np.clip(k, 0, self._L)
        
        # Map back to quantized values
        return -self.omega + k * self._Delta

    def quantize_step(self, x: np.ndarray) -> np.ndarray:
        """
        Apply step function-based uniform quantization.
        
        Args:
            x: Input array to quantize
            
        Returns:
            Quantized array with values in [-omega, omega]
        """
        if self.b is None:
            return np.clip(x, -self.omega, self.omega)
        
        # Count how many thresholds each input exceeds
        count = (x[..., None] >= self._thresholds[None, ...]).sum(axis=-1, dtype=np.float64)
        
        # Map to quantized values
        return -self.omega + self._Delta * count

    def psi(self, v: np.ndarray) -> np.ndarray:
        """
        Apply quantization function ψ to input.
        
        Args:
            v: Input array
            
        Returns:
            Quantized array
        """
        if self.quant_method == "step":
            return self.quantize_step(v)
        else:
            return self.quantize_round(v)


    def epsilon_g_from_stats(self, m_psi: np.ndarray, q_psi: np.ndarray) -> np.ndarray:
        """
        Compute generalization error from macroscopic statistics.
        
        Args:
            m_psi: Mean overlap <ψ(w), w0>/d
            q_psi: Second moment ||ψ(w)||^2/d
            
        Returns:
            Generalization error ε_g
        """
        return self._m0 - 2.0 * m_psi + q_psi + (self.sigma ** 2)

    def compute_metrics(self, w: np.ndarray, w0: np.ndarray) -> Tuple[float, float, float, float, float, float, np.ndarray]:
        """
        Compute all macroscopic quantities for current weights.
        
        Args:
            w: Current weight vector
            w0: Teacher weight vector
            
        Returns:
            Tuple of (m, q, m_psi, q_psi, r_psi, epsilon_g, psi_w) where:
                m: <w, w0>/d
                q: ||w||^2/d
                m_psi: <ψ(w), w0>/d
                q_psi: ||ψ(w)||^2/d
                r_psi: <w, ψ(w)>/d
                epsilon_g: Generalization error
                psi_w: Quantized weights ψ(w)
        """
        psi_w = self.psi(w)
        
        # Compute overlaps and norms
        m = float(np.dot(w, w0) / self.d)
        q = float(np.dot(w, w) / self.d)
        m_psi = float(np.dot(psi_w, w0) / self.d)
        q_psi = float(np.dot(psi_w, psi_w) / self.d)
        r_psi = float(np.dot(psi_w, w) / self.d)
        
        # Compute generalization error
        epsilon_g = float(self.epsilon_g_from_stats(m_psi, q_psi))
        
        return m, q, m_psi, q_psi, r_psi, epsilon_g, psi_w

    def should_snapshot(self, step: int) -> bool:
        """
        Determine if a snapshot should be saved at this step.
        
        Args:
            step: Current training step
            
        Returns:
            True if snapshot should be saved
        """
        return (step == 0) or (step == self.T) or ((step % self.snapshot_every) == 0)

    def ste_update(self, w: np.ndarray, w0: np.ndarray) -> np.ndarray:
        """
        Perform one step of STE (Straight-Through Estimator) update.
        
        The update follows: w ← w + η[(1/√d)(w0 - ψ(w))^T x + ξ](x/√d) - η(λ/d)ψ(w)
        
        Args:
            w: Current weight vector
            w0: Teacher weight vector
            
        Returns:
            Updated weight vector
        """
        # Sample input and noise
        x = self.rng.normal(0.0, 1.0, size=self.d)
        xi = self.rng.normal(0.0, self.sigma)
        
        # Compute quantized weights
        psi_w = self.psi(w)
        
        # Compute gradient factor
        gradient_factor = self._sqrt_d_inv * np.dot((w0 - psi_w), x) + xi
        
        # Apply STE update
        w_new = w + self.eta * (gradient_factor * self._sqrt_d_inv) * x - self.eta * self._lam_over_d * psi_w
        
        return w_new


    def simulate(self, init_w: Optional[np.ndarray] = None, 
                 w0_fixed: Optional[np.ndarray] = None) -> Dict[str, Any]:
        """
        Run the STE simulation.
        
        Args:
            init_w: Initial weight vector (if None, randomly initialized)
            w0_fixed: Fixed teacher vector (if None, generated based on teacher type)
            
        Returns:
            Dictionary containing training history with keys:
                - steps: Training steps where metrics were logged
                - m, q, m_psi, q_psi, r_psi, epsilon_g: Macroscopic quantities
                - snapshot_steps: Steps where snapshots were saved
                - w_snapshots, psi_snapshots: Weight snapshots
                - w0: Teacher vector
                - params: Simulation parameters
        """
        # Initialize weights and teacher
        w = self._initialize_weights(init_w)
        w0 = self._initialize_teacher(w0_fixed)
        
        # Create history dictionary
        history = self._create_history_dict(w0)
        
        # Record initial state
        self._record_metrics(history, w, w0, step=0)
        
        # Main training loop
        for t in range(1, self.T + 1):
            w = self.ste_update(w, w0)
            
            # Log metrics at specified intervals
            if (t % self.log_every == 0) or (t == self.T):
                self._record_metrics(history, w, w0, step=t, compute_snapshot=False)
            
            # Save snapshots at specified intervals
            if self.should_snapshot(t):
                self._save_snapshot(history, w, step=t)
        
        # Convert lists to arrays
        self._finalize_history(history)
        
        return history
    
    def _initialize_weights(self, init_w: Optional[np.ndarray]) -> np.ndarray:
        """Initialize weight vector."""
        if init_w is None:
            return self.rng.normal(0.0, 1.0, size=self.d)
        else:
            w = np.asarray(init_w, dtype=np.float64).copy()
            if w.shape != (self.d,):
                raise ValueError(f"init_w must be of shape ({self.d},) (actual shape: {w.shape})")
            return w
    
    def _initialize_teacher(self, w0_fixed: Optional[np.ndarray]) -> np.ndarray:
        """Initialize teacher vector."""
        if w0_fixed is not None:
            w0 = w0_fixed
        else:
            if self.teacher == "ones":
                w0 = np.ones(self.d, dtype=np.float64)
            else:  # gaussian
                rng_w0 = np.random.default_rng(self.seed + 99991)
                w0 = rng_w0.normal(0.0, 1.0, size=self.d)
        
        # Store teacher properties
        self._m0 = float(np.dot(w0, w0) / self.d)
        self._w0 = w0
        
        return w0
    
    def _create_history_dict(self, w0: np.ndarray) -> Dict[str, Any]:
        """Create empty history dictionary."""
        return {
            "steps": [], "m": [], "q": [], "m_psi": [], "q_psi": [], "r_psi": [], "epsilon_g": [],
            "snapshot_steps": [], "w_snapshots": {}, "psi_snapshots": {},
            "w0": w0.copy(),
            "params": {
                "d": self.d, "T": self.T, "eta": self.eta, "lam": self.lam, "sigma": self.sigma,
                "omega": self.omega, "b": self.b, "seed": self.seed, "log_every": self.log_every,
                "quant_method": self.quant_method, "snapshot_every": self.snapshot_every,
                "teacher": self.teacher, "m0": self._m0
            }
        }
    
    def _record_metrics(self, history: Dict[str, Any], w: np.ndarray, w0: np.ndarray, 
                       step: int, compute_snapshot: bool = True):
        """Record metrics at current step."""
        m, q, m_psi, q_psi, r_psi, eps, psi_w = self.compute_metrics(w, w0)
        
        history["steps"].append(step)
        history["m"].append(m)
        history["q"].append(q)
        history["m_psi"].append(m_psi)
        history["q_psi"].append(q_psi)
        history["r_psi"].append(r_psi)
        history["epsilon_g"].append(eps)
        
        if compute_snapshot and self.should_snapshot(step):
            history["snapshot_steps"].append(step)
            history["w_snapshots"][step] = w.copy()
            history["psi_snapshots"][step] = psi_w.copy()
    
    def _save_snapshot(self, history: Dict[str, Any], w: np.ndarray, step: int):
        """Save weight snapshot."""
        psi_w = self.psi(w)
        history["snapshot_steps"].append(step)
        history["w_snapshots"][step] = w.copy()
        history["psi_snapshots"][step] = psi_w.copy()
    
    def _finalize_history(self, history: Dict[str, Any]):
        """Convert history lists to numpy arrays."""
        for k in ("steps", "snapshot_steps"):
            history[k] = np.array(history[k], dtype=int)
        for k in ("m", "q", "m_psi", "q_psi", "r_psi", "epsilon_g"):
            history[k] = np.array(history[k], dtype=np.float64)


    def ode_rhs(self, tau: float, y: np.ndarray) -> Tuple[np.ndarray, Tuple[float, float, float]]:
        """
        Evaluate the right-hand side of the (m, q) ODE system.

        With the starred statistics taken from mqr_from_isotropy:
        dm/dτ = -η((1+λ)m_ψ★ - m0)
        dq/dτ = -2η((1+λ)r_ψ★ - m) + η²(m0 - 2m_ψ★ + q_ψ★ + σ²)

        The system is autonomous: tau is accepted for the solver interface but
        does not enter the formulas.

        Args:
            tau: Time variable (τ = t/d)
            y: State vector [m, q]

        Returns:
            Tuple of (dy/dτ, (m_psi, q_psi, r_psi))

        Raises:
            ValueError: If b is None (clip only not supported for ODE)
        """
        if self.b is None:
            raise ValueError("For b=None (clip only), star statistics mqr_from_isotropy is undefined. Specify b>=2.")

        m = float(y[0])
        q = float(y[1])

        # Quantized statistics implied by the current (m, q)
        stats = mqr_from_isotropy(
            m, q, self.b, self.omega,
            teacher=self.teacher, w0=self._w0, m0=self._m0
        )
        mpsi, qpsi, rpsi = stats

        reg = 1.0 + self.lam
        gen_error = self._m0 - 2.0 * mpsi + qpsi + (self.sigma ** 2)

        dm = -self.eta * (reg * mpsi - self._m0)
        dq = -2.0 * self.eta * (reg * rpsi - m) + (self.eta ** 2) * gen_error

        return np.array([dm, dq], dtype=np.float64), (mpsi, qpsi, rpsi)


    def solve_ode(self,
                  history: Optional[Dict[str, Any]] = None,
                  steps: Optional[np.ndarray] = None,
                  y0: Optional[Tuple[float, float]] = None,
                  substeps: int = 12) -> Dict[str, np.ndarray]:
        """
        Integrate the (m, q) ODE over the requested steps with classical RK4.

        The output grid is τ = steps / d. Between consecutive grid points the
        integrator takes `substeps` equal RK4 steps, and the quantized
        statistics are re-evaluated at every grid point.

        Args:
            history: Training history (if provided, extracts steps and initial condition)
            steps: Time steps to evaluate at (if history not provided)
            y0: Initial condition [m(0), q(0)] (if history not provided)
            substeps: Number of RK4 substeps between output points

        Returns:
            Dictionary with ODE solution containing:
                - tau: Time points (τ = t/d)
                - m, q: Macroscopic quantities
                - m_psi, q_psi, r_psi: Quantized statistics
                - epsilon_g: Generalization error

        Raises:
            ValueError: If neither history nor (steps, y0) is provided
        """
        steps, y0 = self._extract_ode_params(history, steps, y0)

        tau = steps / float(self.d)
        n_points = len(tau)
        solution = self._initialize_ode_solution(n_points)

        def store(index: int, state: np.ndarray) -> None:
            # Record the state and the quantized statistics it implies.
            solution["m"][index], solution["q"][index] = state
            _, stats = self.ode_rhs(tau[index], state)
            solution["m_psi"][index], solution["q_psi"][index], solution["r_psi"][index] = stats

        state = np.array(y0, dtype=np.float64)
        store(0, state)

        # March from one output point to the next
        for i in range(1, n_points):
            state = self._rk4_step(state, tau[i-1], tau[i], substeps)
            store(i, state)

        solution["tau"] = tau
        solution["epsilon_g"] = self.epsilon_g_from_stats(solution["m_psi"], solution["q_psi"])

        return solution
    
    def _extract_ode_params(self, history: Optional[Dict[str, Any]], 
                           steps: Optional[np.ndarray], 
                           y0: Optional[Tuple[float, float]]) -> Tuple[np.ndarray, Tuple[float, float]]:
        """Extract ODE parameters from inputs."""
        if history is not None:
            steps = np.asarray(history["steps"], dtype=int)
            y0 = (float(history["m"][0]), float(history["q"][0]))
        elif (steps is not None) and (y0 is not None):
            steps = np.asarray(steps, dtype=int)
            y0 = (float(y0[0]), float(y0[1]))
        else:
            raise ValueError("Either history or (steps, y0) must be specified")
        return steps, y0
    
    def _initialize_ode_solution(self, N: int) -> Dict[str, np.ndarray]:
        """Initialize arrays for ODE solution."""
        return {
            "m": np.zeros(N, dtype=np.float64),
            "q": np.zeros(N, dtype=np.float64),
            "m_psi": np.zeros(N, dtype=np.float64),
            "q_psi": np.zeros(N, dtype=np.float64),
            "r_psi": np.zeros(N, dtype=np.float64)
        }
    
    def _rk4_step(self, y: np.ndarray, t0: float, t1: float, substeps: int) -> np.ndarray:
        """
        Advance the state from t0 to t1 with `substeps` classical RK4 steps.

        Args:
            y: State at t0
            t0: Start time
            t1: End time
            substeps: Number of equal-sized RK4 steps to take

        Returns:
            State at t1
        """
        h = (t1 - t0) / float(substeps)
        half = 0.5 * h
        state = y
        now = t0

        for _ in range(substeps):
            # Four slope evaluations: start, two midpoints, end
            slope_a, _ = self.ode_rhs(now, state)
            slope_b, _ = self.ode_rhs(now + half, state + half * slope_a)
            slope_c, _ = self.ode_rhs(now + half, state + half * slope_b)
            slope_d, _ = self.ode_rhs(now + h, state + h * slope_c)

            state = state + (h / 6.0) * (slope_a + 2.0 * slope_b + 2.0 * slope_c + slope_d)
            now += h

        return state


    def run_histories(self,
                      num_runs: int = 5,
                      substeps: int = 12,
                      shared_init: bool = True,
                      ode_once: bool = True,
                      init_w: Optional[np.ndarray] = None
                      ) -> Tuple[List[Dict[str, Any]], Union[Dict[str, np.ndarray], List[Dict[str, np.ndarray]]]]:
        """
        Run several STE simulations against one teacher and solve the matching ODE(s).

        Run r uses seed + r. All runs share the same teacher vector. As a side
        effect, this simulator's cached teacher (_w0, _m0) is overwritten before
        each ODE solve.

        Args:
            num_runs: Number of simulation runs
            substeps: Number of RK4 substeps for ODE solver
            shared_init: If True, all runs use same initial weights
            ode_once: If True (and shared_init is also True), solve the ODE
                only once, from the first run's initial condition
            init_w: Initial weight vector for the shared initialization
                (if None, randomly initialized; only used when shared_init is True)

        Returns:
            Tuple of (histories, odes) where:
                - histories: List of training histories
                - odes: a single ODE solution dict when both ode_once and
                  shared_init are True, otherwise a list with one solution
                  per history
        """
        shared_start = self._prepare_shared_init(init_w) if shared_init else None

        # One teacher for every run
        teacher_vec = self._create_teacher()
        teacher_norm = float(np.dot(teacher_vec, teacher_vec) / self.d)

        histories = self._run_simulations(num_runs, shared_start, teacher_vec, shared_init)

        if not (ode_once and shared_init):
            # One ODE per run, each using that run's teacher and m0
            solutions = []
            for run in histories:
                self._w0 = run["w0"]
                self._m0 = run["params"]["m0"]
                solutions.append(self.solve_ode(history=run, substeps=substeps))
            return histories, solutions

        # Identical starting point everywhere: a single ODE solve suffices
        self._w0 = teacher_vec
        self._m0 = teacher_norm
        single = self.solve_ode(history=histories[0], substeps=substeps)
        return histories, single
    
    def _prepare_shared_init(self, init_w: Optional[np.ndarray]) -> np.ndarray:
        """Prepare shared initial weights."""
        if init_w is not None:
            init_w0 = np.asarray(init_w, dtype=np.float64)
            if init_w0.shape != (self.d,):
                raise ValueError(f"init_w must be of shape ({self.d},) (actual shape: {init_w0.shape})")
            return init_w0
        else:
            rng_init = np.random.default_rng(self.seed)
            return rng_init.normal(0.0, 1.0, size=self.d)
    
    def _create_teacher(self) -> np.ndarray:
        """Create teacher vector based on teacher type."""
        if self.teacher == "ones":
            return np.ones(self.d, dtype=np.float64)
        else:  # gaussian
            rng_w0 = np.random.default_rng(self.seed + 99991)
            return rng_w0.normal(0.0, 1.0, size=self.d)
    
    def _run_simulations(self, num_runs: int, init_w0: Optional[np.ndarray], 
                        w0_shared: np.ndarray, shared_init: bool) -> List[Dict[str, Any]]:
        """Run multiple simulations."""
        histories = []
        for r in range(num_runs):
            sim = STESimulator(
                d=self.d, T=self.T, eta=self.eta, lam=self.lam,
                sigma=self.sigma, omega=self.omega, b=self.b,
                seed=self.seed + r, log_every=self.log_every,
                quant_method=self.quant_method, snapshot_every=self.snapshot_every,
                teacher=self.teacher
            )
            h = sim.simulate(init_w=init_w0 if shared_init else None, w0_fixed=w0_shared)
            histories.append(h)
        return histories


    @staticmethod
    def aggregate_histories(histories: List[Dict[str, Any]],
                            keys: Tuple[str, ...]) -> Dict[str, np.ndarray]:
        """
        Aggregate multiple training histories into arrays.
        
        Args:
            histories: List of history dictionaries
            keys: Keys to aggregate
            
        Returns:
            Dictionary with aggregated data of shape [T_points, num_runs]
        """
        T_points = len(histories[0]["steps"])
        num_runs = len(histories)
        
        aggregated = {}
        for key in keys:
            aggregated[key] = np.zeros((T_points, num_runs), dtype=np.float64)
            for r, h in enumerate(histories):
                aggregated[key][:, r] = h[key]
                
        return aggregated

    @staticmethod
    def aggregate_odes(odes: Union[List[Dict[str, np.ndarray]], Dict[str, np.ndarray]],
                       keys: Tuple[str, ...]) -> Dict[str, np.ndarray]:
        """
        Aggregate ODE solutions into arrays.
        
        Args:
            odes: Single ODE solution or list of ODE solutions
            keys: Keys to aggregate
            
        Returns:
            Dictionary with aggregated data of shape [T_points, num_runs]
        """
        # Convert single dict to list
        if isinstance(odes, dict):
            odes = [odes]
            
        T_points = len(odes[0]["tau"])
        num_runs = len(odes)
        
        aggregated = {}
        for key in keys:
            aggregated[key] = np.zeros((T_points, num_runs), dtype=np.float64)
            for r, o in enumerate(odes):
                aggregated[key][:, r] = o[key]
                
        return aggregated

    @staticmethod
    def select_every_k_steps(steps: np.ndarray, k: int = 100) -> np.ndarray:
        """
        Select indices for every k-th step (always including 0).
        
        Args:
            steps: Array of step numbers
            k: Step interval
            
        Returns:
            Array of indices
        """
        steps = np.asarray(steps, dtype=int)
        idx = np.nonzero((steps % k) == 0)[0]
        
        # Ensure index 0 is included
        if idx.size == 0 or idx[0] != 0:
            idx = np.insert(idx, 0, 0)
            
        return idx

    def check_quantizers(self, n: int = 200_000) -> Tuple[bool, float]:
        """
        Verify that round and step quantization methods produce identical results.
        
        Args:
            n: Number of test samples
            
        Returns:
            Tuple of (exact_match, max_difference)
        """
        # Generate test data
        rng = np.random.default_rng(self.seed + 1234)
        x = np.concatenate([
            rng.normal(0, 1, n // 2),
            rng.uniform(-3 * self.omega, 3 * self.omega, n - n // 2)
        ])
        
        # Compare quantizers
        q_round = self.quantize_round(x)
        q_step = self.quantize_step(x)
        
        # Check for differences
        diff = np.max(np.abs(q_round - q_step))
        exact_match = bool(np.array_equal(q_round, q_step))
        
        return exact_match, float(diff)


class STESimulatorActQuant(STESimulator):
    """
    Extended STE simulator with activation quantization.

    This class adds input quantization to the base STE simulator, modeling
    scenarios where both weights and activations are quantized. The theory
    assumes m0=1 (teacher='ones' case).

    Additional features:
        - Quantizes both weights ψ(w) and inputs ψ(x)
        - Modified generalization error formula
        - Modified ODE dynamics accounting for input quantization

    Attributes set here (in addition to those of STESimulator):
        kappa, sigma_psi2: The pair returned by
            activation_moments_kappa_sigma(omega, b), computed once at
            construction and used by the error formula and the ODE.
    """

    def __init__(self, *args, **kwargs) -> None:
        """
        Initialize activation-quantized STE simulator.

        Inherits all parameters from STESimulator with additional constraints:
        - Requires b >= 2 (uniform quantization)
        - Theory assumes teacher='ones' (m0=1); any other teacher only
          triggers a warning

        Raises:
            ValueError: If b is None or smaller than 2
        """
        # Local import so this class does not depend on a file-level import.
        import warnings

        super().__init__(*args, **kwargs)

        # Enforce the documented requirement: the previous check only caught
        # b=None and let b<2 through despite the "b>=2" contract.
        if self.b is None or self.b < 2:
            raise ValueError("STESimulatorActQuant requires b>=2 (uniform quantization)")

        # Warn (rather than print) if using a non-standard teacher, so that
        # callers can filter it and repeated instantiation in run_histories
        # does not flood stdout.
        if self.teacher != "ones":
            warnings.warn(
                "STESimulatorActQuant's theory assumes teacher='ones' (m0=1)",
                stacklevel=2,
            )

        # Precompute activation quantization moments
        self.kappa, self.sigma_psi2 = activation_moments_kappa_sigma(self.omega, self.b)

    def epsilon_g_from_stats(self, m_psi: np.ndarray, q_psi: np.ndarray) -> np.ndarray:
        """
        Compute generalization error for activation-quantized case.

        Formula: ε_g = 1 + σ_ψ²q_ψ - 2κm_ψ + σ²

        The leading 1 corresponds to the m0=1 assumption stated on the class.

        Args:
            m_psi: Mean overlap <ψ(w), w0>/d
            q_psi: Second moment ||ψ(w)||^2/d

        Returns:
            Generalization error
        """
        return (1.0 + self.sigma_psi2 * q_psi) - (2.0 * self.kappa * m_psi) + (self.sigma ** 2)

    def ste_update(self, w: np.ndarray, w0: np.ndarray) -> np.ndarray:
        """
        Perform STE update with activation quantization.

        The update uses quantized inputs: w ← w + η[(1/√d)w0^Tx + ξ - (1/√d)ψ(w)^Tψ(x)](ψ(x)/√d) - η(λ/d)ψ(w)

        The teacher sees the raw input x while the student sees ψ(x). Each
        call draws a fresh input and noise sample from self.rng.

        Args:
            w: Current weight vector
            w0: Teacher weight vector

        Returns:
            Updated weight vector (w itself is not modified in place)
        """
        # Sample input and noise
        x = self.rng.normal(0.0, 1.0, size=self.d)
        xi = self.rng.normal(0.0, self.sigma)

        # Quantize weights and inputs with the same quantizer ψ
        psi_w = self.psi(w)
        psi_x = self.psi(x)

        # Compute prediction error (teacher output + noise - student output)
        error = self._sqrt_d_inv * (np.dot(w0, x) - np.dot(psi_w, psi_x)) + xi

        # Apply STE update with quantized gradient and weight decay on ψ(w)
        w_new = w + self.eta * (error * self._sqrt_d_inv) * psi_x - self.eta * self._lam_over_d * psi_w

        return w_new

    def ode_rhs(self, tau: float, y: np.ndarray) -> Tuple[np.ndarray, Tuple[float, float, float]]:
        """
        Compute ODE right-hand side for activation-quantized case.

        Modified dynamics:
        dm/dτ = η(κ - (σ_ψ² + λ)m_ψ★)
        dq/dτ = 2η(κm - (σ_ψ² + λ)r_ψ★) + η²σ_ψ²(1 + σ_ψ²q_ψ★ - 2κm_ψ★ + σ²)

        Args:
            tau: Time variable (τ = t/d); not used by the computation itself
            y: State vector [m, q]

        Returns:
            Tuple of (dy/dτ, (m_psi, q_psi, r_psi))

        Raises:
            ValueError: If b is None (possible only if b was changed after
                construction, since __init__ already rejects it)
        """
        if self.b is None:
            raise ValueError("For b=None (clip only), star statistics mqr_from_isotropy is undefined. Specify b>=2.")

        # Extract state
        m, q = float(y[0]), float(y[1])

        # Compute quantized statistics (teacher argument left at its default)
        mpsi, qpsi, rpsi = mqr_from_isotropy(m, q, self.b, self.omega)

        # Compute derivatives with activation quantization effects
        dm_dt = self.eta * (self.kappa - (self.sigma_psi2 + self.lam) * mpsi)
        dq_dt = (2.0 * self.eta * (self.kappa * m - (self.sigma_psi2 + self.lam) * rpsi) +
                 (self.eta ** 2) * self.sigma_psi2 *
                 (1.0 + self.sigma_psi2 * qpsi - 2.0 * self.kappa * mpsi + (self.sigma ** 2)))

        return np.array([dm_dt, dq_dt], dtype=np.float64), (mpsi, qpsi, rpsi)

    def run_histories(self,
                      num_runs: int = 5,
                      substeps: int = 12,
                      shared_init: bool = True,
                      ode_once: bool = True,
                      init_w: Optional[np.ndarray] = None
                      ) -> Tuple[List[Dict[str, Any]], Union[Dict[str, np.ndarray], List[Dict[str, np.ndarray]]]]:
        """
        Run multiple simulations with activation quantization.

        Overrides parent method to ensure proper instantiation of derived class.

        Every run uses a fresh simulator of this instance's own class, seeded
        with self.seed + r, and the same teacher vector. As a side effect,
        self._w0 and self._m0 are overwritten before each ODE solve.

        Args:
            num_runs: Number of independent simulations. Must be at least 1
                when ode_once and shared_init are both True, because the ODE
                is then solved from the first history.
            substeps: Forwarded to solve_ode
            shared_init: If True, all runs start from the same initial
                weights obtained from _prepare_shared_init(init_w)
            ode_once: If True (and shared_init is True), solve the ODE a
                single time instead of once per history
            init_w: Optional initial weights; only consulted when
                shared_init is True

        Returns:
            Tuple (histories, ode) where ode is a single ODE solution when
            ode_once and shared_init are both True, and otherwise a list
            with one ODE solution per history.
        """
        # Prepare shared initialization if requested (None means each run
        # falls back to whatever simulate() does without init_w)
        init_w0 = self._prepare_shared_init(init_w) if shared_init else None

        # Create shared teacher
        w0_shared = self._create_teacher()
        m0_shared = float(np.dot(w0_shared, w0_shared) / self.d)

        # Run simulations using derived class
        sim_cls = type(self)
        histories = []
        for r in range(num_runs):
            sim = sim_cls(
                d=self.d, T=self.T, eta=self.eta, lam=self.lam,
                sigma=self.sigma, omega=self.omega, b=self.b,
                seed=self.seed + r, log_every=self.log_every,
                quant_method=self.quant_method, snapshot_every=self.snapshot_every,
                teacher=self.teacher
            )
            histories.append(sim.simulate(init_w=init_w0, w0_fixed=w0_shared))

        # Solve ODEs
        if ode_once and shared_init:
            # All runs share teacher and initialization, so one solve suffices
            self._w0 = w0_shared
            self._m0 = m0_shared
            ode = self.solve_ode(history=histories[0], substeps=substeps)
            return histories, ode

        odes = []
        for h in histories:
            # Use the teacher and m0 recorded in each history
            self._w0 = h["w0"]
            self._m0 = h["params"]["m0"]
            odes.append(self.solve_ode(history=h, substeps=substeps))
        return histories, odes



