"""
sdm.py - Sparse Distributed Memory with Adaptive Control Theory
For ICLR 2026 Paper: "Adaptive Control Theory for High-Dimensional Associative Memory"

Updated with: binomial lookup tables, reproducible RNGs, tie-break reconstruction,
packed popcount, and stability guardrails
"""

import numpy as np
from scipy import stats
from typing import Optional, Tuple, List
import warnings
# NOTE(review): this installs a process-wide "ignore" filter as soon as the
# module is imported. It silences every warning, including the
# EmaController warning about gains outside the practical envelope. Consider
# scoping it (e.g. `with warnings.catch_warnings():` around noisy SciPy calls).
warnings.filterwarnings('ignore')


# ============================================================================
# Binomial Lookup Table (Major Speed Win)
# ============================================================================

class BinomTable:
    """Precompute Binomial(n, 0.5) PMF/CDF for all integer thresholds T.

    SimpleSdm activates a hard location when its Hamming distance to the query
    is <= T. For uniformly random addresses that distance is Binomial(n, 0.5),
    so CDF(T) is the per-location activation probability. PMF(T) is the
    increase in that probability when the threshold goes from T-1 to T.

    Thresholds outside the support [0, n] follow the distribution exactly:
    CDF is 0 below 0 and 1 above n, and PMF is 0 outside [0, n].
    """

    def __init__(self, n: int):
        self.n = int(n)
        self._pmf = stats.binom.pmf(np.arange(self.n + 1), self.n, 0.5)
        # A running sum of the PMF can overshoot 1 by a few ULPs for large n;
        # clip so the table is always a valid probability.
        self._cdf = np.minimum(np.cumsum(self._pmf), 1.0)

    @property
    def pmf(self) -> np.ndarray:
        """PMF values for T = 0..n."""
        return self._pmf

    @property
    def cdf(self) -> np.ndarray:
        """CDF values for T = 0..n (clipped to at most 1)."""
        return self._cdf

    def k_from_T(self, L: int, T: int) -> float:
        """Expected activations: E[k|T] = L * CDF(T).

        Args:
            L: Number of hard locations.
            T: Threshold (truncated to int). T < 0 gives 0 and T >= n gives L.

        Returns:
            Expected activation count as a float.
        """
        T = int(T)
        if T < 0:
            # No Hamming distance is negative, so nothing can activate.
            return 0.0
        return float(L * self._cdf[min(T, self.n)])

    def slope_at_T(self, L: int, T: int, s_min: float = 1e-3) -> float:
        """Discrete slope: dE[k]/dT = L * PMF(T), floored at s_min.

        Outside the support [0, n] the PMF is zero, so the result is s_min.
        """
        T = int(T)
        pmf_T = self._pmf[T] if 0 <= T <= self.n else 0.0
        return max(float(L * pmf_T), s_min)


# ============================================================================
# Utility Functions
# ============================================================================

def make_clustered_addresses(n_locations: int, n_dims: int,
                             n_prototypes: int = 8,
                             flip_prob: float = 0.1,
                             seed: int = 0) -> np.ndarray:
    """
    Build non-uniform (clustered) addresses for stress tests.

    A handful of random prototype bit vectors are drawn first. Each address
    is then a copy of a randomly chosen prototype with every bit flipped
    independently with probability ``flip_prob``.

    Args:
        n_locations: How many addresses (rows) to produce
        n_dims: Bits per address
        n_prototypes: Size of the prototype pool
        flip_prob: Per-bit flip probability applied to the chosen prototype
        seed: Seed for the internal Generator (same seed -> same output)

    Returns:
        uint8 array of shape (n_locations, n_dims) with entries in {0, 1}
    """
    generator = np.random.default_rng(seed)
    prototypes = generator.integers(0, 2, size=(n_prototypes, n_dims), dtype=np.uint8)
    out = np.empty((n_locations, n_dims), dtype=np.uint8)
    for row in range(n_locations):
        # Per row, draw the prototype index first and then the flip mask; this
        # order fixes the random stream and therefore the output for a seed.
        base = prototypes[generator.integers(0, n_prototypes)]
        flips = generator.random(n_dims) < flip_prob
        out[row] = np.where(flips, 1 - base, base)
    return out


# ============================================================================
# Core SDM Implementation
# ============================================================================

class SimpleSdm:
    """
    Core SDM implementation with adaptive threshold control.
    Implements the fundamental control law: k = L * F_binomial(T)

    A hard location is activated by a pattern/query when its Hamming distance
    to it is <= ``threshold``. Writes add +1/-1 (for bit 1/0) to the counters
    of every active location. Reads sum the active counters and take the sign,
    breaking exact ties with the query's own bits.
    """

    # Set-bit count for every byte value 0..255. The packed distance path uses
    # one table lookup per byte instead of unpacking every bit.
    _POPCOUNT8 = np.array([bin(i).count("1") for i in range(256)], dtype=np.uint8)

    def __init__(self, n_dims: int, n_locations: int,
                 threshold: Optional[int] = None,
                 binom_table: Optional["BinomTable"] = None,
                 rng: Optional[np.random.Generator] = None):
        """
        Initialize SDM with random hard locations.

        Args:
            n_dims: Dimensionality of bit patterns (n)
            n_locations: Number of hard locations (L)
            threshold: Initial Hamming distance threshold (T), defaults to n // 2
            binom_table: Precomputed binomial table (created if None)
            rng: Optional Generator used for address sampling and for the random
                pattern returned when nothing activates. If None, the global
                ``np.random`` state is used (original behavior).
        """
        self.n_dims = n_dims
        self.n_locations = n_locations
        self._rng = rng

        # Decide the distance backend before assigning `addresses`: the property
        # setter keeps the packed copy in sync when this is True.
        self.use_packed = (n_dims % 8 == 0)

        # Random hard location addresses, uniform in {0,1}^n
        self.addresses = self._random_bits((n_locations, n_dims))

        # Signed accumulation counters, one per (location, bit)
        self.counters = np.zeros((n_locations, n_dims), dtype=np.int32)

        self.threshold = threshold if threshold is not None else n_dims // 2

        # Binomial lookup table for expected-activation queries
        self.binom = binom_table if binom_table is not None else BinomTable(n_dims)

        # Statistics
        self.patterns_stored = 0
        self.last_activation_count = 0

    @property
    def addresses(self) -> np.ndarray:
        """Hard-location addresses, one row per location."""
        return self._addresses

    @addresses.setter
    def addresses(self, value: np.ndarray) -> None:
        # Re-assigning addresses (e.g. with make_clustered_addresses) must also
        # refresh the packed copy. Otherwise the packed distance path keeps
        # matching against the old addresses. In-place edits of the array
        # (sdm.addresses[i] = ...) are NOT tracked; reassign the whole array.
        self._addresses = np.asarray(value)
        if getattr(self, "use_packed", False):
            self._addresses_packed = np.packbits(self._addresses, axis=1)

    def _random_bits(self, shape) -> np.ndarray:
        """Draw uniform {0,1} uint8 bits from self._rng, or the global RNG if unset."""
        rng = getattr(self, "_rng", None)
        if rng is None:
            return np.random.randint(0, 2, shape, dtype=np.uint8)
        return rng.integers(0, 2, shape, dtype=np.uint8)

    @staticmethod
    def _reconstruct(counter_sum: np.ndarray, fallback_bits: np.ndarray) -> np.ndarray:
        """
        Reconstruct pattern from counter sum, using fallback for ties.
        This reduces bias when some counters are exactly zero.

        Args:
            counter_sum: Sum of activated counters
            fallback_bits: Tie-breaking bits (typically the query)

        Returns:
            Reconstructed binary pattern (uint8): 1 where the sum is positive,
            0 where negative, fallback bit where exactly zero.
        """
        rec = (counter_sum > 0).astype(np.uint8)
        zeros = (counter_sum == 0)
        if np.any(zeros):
            rec[zeros] = fallback_bits[zeros]  # tie-break toward query bits
        return rec

    def hamming_distance(self, x: np.ndarray, y: np.ndarray) -> int:
        """Compute Hamming distance between two bit vectors (as a Python int)."""
        return int(np.count_nonzero(x != y))

    def _hamming_distances_packed(self, pattern: np.ndarray) -> np.ndarray:
        """Hamming distances via packed bytes: XOR then a per-byte popcount lookup."""
        packed_query = np.packbits(pattern, axis=0)
        xor = np.bitwise_xor(self._addresses_packed, packed_query)
        # Signed result so it behaves like the unpacked path (e.g. against T < 0).
        return self._POPCOUNT8[xor].sum(axis=1, dtype=np.int64)

    def hamming_distances_batch(self, pattern: np.ndarray) -> np.ndarray:
        """Compute Hamming distances from pattern to all addresses."""
        if self.use_packed:
            return self._hamming_distances_packed(pattern)
        return np.sum(self.addresses != pattern, axis=1)

    def store_pattern(self, pattern: np.ndarray) -> int:
        """
        Store a pattern by updating counters at activated locations.

        Args:
            pattern: Binary pattern to store {0,1}^n

        Returns:
            Number of activated locations (k)
        """
        distances = self.hamming_distances_batch(pattern)
        active_indices = np.flatnonzero(distances <= self.threshold)
        if active_indices.size:
            # Map bits {0,1} -> {-1,+1}. Cast to a signed type first: with a
            # uint8 pattern, `2 * pattern - 1` wraps 0 to 255 instead of -1.
            signs = 2 * np.asarray(pattern, dtype=np.int32) - 1
            self.counters[active_indices] += signs  # broadcasts over rows
        self.patterns_stored += 1
        self.last_activation_count = int(active_indices.size)
        return self.last_activation_count

    def retrieve_pattern(self, query: np.ndarray) -> Tuple[np.ndarray, int]:
        """
        Retrieve pattern similar to query with tie-breaking toward the query.

        Args:
            query: Query pattern {0,1}^n

        Returns:
            (reconstructed_pattern, activation_count). With no activations a
            uniformly random pattern and 0 are returned.
        """
        distances = self.hamming_distances_batch(query)
        active_indices = np.flatnonzero(distances <= self.threshold)
        # Always refresh the statistic, including the no-activation case.
        self.last_activation_count = int(active_indices.size)

        if active_indices.size == 0:
            return self._random_bits(self.n_dims), 0

        counter_sum = np.sum(self.counters[active_indices], axis=0)
        reconstruction = self._reconstruct(counter_sum, query)
        return reconstruction, self.last_activation_count

    def count_activations(self, query: np.ndarray) -> int:
        """Count how many locations activate for a query."""
        distances = self.hamming_distances_batch(query)
        return int(np.count_nonzero(distances <= self.threshold))

    def count_activations_expected(self, _ = None) -> float:
        """Return theoretical expected activations E[k|T] = L * F_binomial(T)."""
        return self.binom.k_from_T(self.n_locations, int(self.threshold))

    def compute_snr(self, k: int, P: int) -> float:
        """
        Compute theoretical signal-to-noise ratio.
        SNR = sqrt(k / (P-1)); infinite for P <= 1.
        """
        if P <= 1:
            return float('inf')
        return np.sqrt(k / (P - 1))

    def clear_counters(self):
        """Reset all counters to zero and the stored-pattern count."""
        self.counters.fill(0)
        self.patterns_stored = 0


# ============================================================================
# EMA Controller with Stability Guardrails
# ============================================================================

class EmaController:
    """
    EMA-based threshold controller with slope-normalized updates.
    Implements noise-aware convergence and optimal batching.

    Control law:
        z_t = (1-α) z_{t-1} + α k_t
        T_{t+1} = T_t - c * (z_t - k*) / s(T_t), capped by ΔT_max

    Stability: for the 2×2 linearization A = [[1, -c/s*], [α s*, 1-α(1+c)]]
    (the model used by predict_steps_linear / c_max_for_alpha), det A = 1-α and
    the Jury conditions give 0 < c < 4/α - 2 for 0 < α <= 1. This reduces to
    0 < c < 2 only when α = 1. Observation noise and integer threshold
    quantization make the practical envelope much tighter (see c_practical_max).
    """

    def __init__(self, alpha: float, target_k: int,
                 c: float = 0.8, delta_t_max: int = 5,
                 tol_sigma: float = 2.0, consec: int = 5, batch_size: int = 1,
                 slope_mode: str = "theory", slope_m: int = 16, slope_h: int = 1):
        """
        Initialize EMA controller with noise-aware convergence.

        Args:
            alpha: EMA smoothing parameter (0 < α < 1), recommended: 0.05-0.2
            target_k: Target activation count k*
            c: Control gain, recommended: 0.4-1.1 (practical envelope)
            delta_t_max: Maximum threshold update per step
            tol_sigma: Convergence tolerance in units of EMA noise standard deviation
            consec: Required consecutive steps within tolerance for convergence
            batch_size: Number of queries to average per measurement (reduces noise);
                values < 1 are treated as 1
            slope_mode: "theory" for binomial PMF, anything else for the empirical estimate
            slope_m: Number of queries for empirical slope estimation
            slope_h: Finite-difference step size for empirical slope
        """
        self.alpha = float(alpha)
        self.target_k = int(target_k)
        self.c = float(c)
        self.delta_t_max = int(delta_t_max)
        self.tol_sigma = float(tol_sigma)
        self.consec = int(consec)
        self.batch_size = int(batch_size)
        self.slope_mode = str(slope_mode)
        self.slope_m = int(slope_m)
        self.slope_h = int(slope_h)

        # State variables
        self.k_smooth: Optional[float] = None
        self.k_history: List[float] = []
        self.threshold_history: List[int] = []
        self.converged = False
        self.convergence_step = None
        self._in_tol_run = 0

        # Stability guardrail, checked once at construction.
        # NOTE: the module-level warnings.filterwarnings('ignore') hides this.
        self.c_practical_max = 1.1  # recommended envelope from noisy experiments
        if self.c > self.c_practical_max:
            warnings.warn(
                f"c={self.c} exceeds practical envelope (~{self.c_practical_max}); "
                f"expect oscillations with noise.",
                stacklevel=2,
            )

    @staticmethod
    def _slope_at_T(n_dims: int, n_locations: int, T: int,
                    s_min: float = 1e-3, table: Optional["BinomTable"] = None) -> float:
        """
        Discrete slope: dE[k]/dT ≈ L * BinomPMF(T; n, 0.5), floored at s_min.
        This matches 1-step threshold increments.
        """
        if table is not None:
            return table.slope_at_T(n_locations, int(T), s_min)
        pmf = stats.binom.pmf(T, n_dims, 0.5)
        s = n_locations * pmf
        return max(float(s), s_min)

    @staticmethod
    def _mean_activations(sdm: "SimpleSdm", T: int, m: int,
                          rng: np.random.Generator) -> float:
        """Set sdm.threshold = T and average activations over m random queries."""
        sdm.threshold = T
        ks = [sdm.count_activations(rng.integers(0, 2, sdm.n_dims, dtype=np.uint8))
              for _ in range(m)]
        return float(np.mean(ks))

    def _slope_empirical(self, sdm: "SimpleSdm", T: int, m: Optional[int] = None,
                        h: Optional[int] = None, rng: Optional[np.random.Generator] = None) -> float:
        """
        Empirical slope via finite difference around T using m queries per side.
        Returns max(value, 1e-3) to avoid division blow-ups.

        Args:
            sdm: SDM instance (its threshold is restored before returning)
            T: Current threshold value
            m: Number of queries per threshold (uses self.slope_m if None)
            h: Step size for finite difference (uses self.slope_h if None)
            rng: Random generator for reproducibility

        Returns:
            Estimated slope dE[k]/dT
        """
        rng = rng or np.random.default_rng()
        m = self.slope_m if m is None else m
        h = self.slope_h if h is None else h
        T_lo = max(1, T - h)
        T_hi = min(sdm.n_dims - 1, T + h)

        T_save = int(sdm.threshold)
        try:
            k_lo = self._mean_activations(sdm, T_lo, m, rng)
            k_hi = self._mean_activations(sdm, T_hi, m, rng)
        finally:
            sdm.threshold = T_save

        # Divide by the actual spacing: near the clip limits [1, n-1] the
        # interval is shorter than 2h and dividing by 2h would bias the slope low.
        span = max(T_hi - T_lo, 1)
        slope = (k_hi - k_lo) / float(span)
        return float(max(slope, 1e-3))

    def _current_slope(self, sdm: "SimpleSdm",
                       rng: Optional[np.random.Generator] = None) -> float:
        """Slope at the current threshold, using the configured slope_mode."""
        if self.slope_mode == "theory":
            return self._slope_at_T(sdm.n_dims, sdm.n_locations,
                                    int(sdm.threshold), table=sdm.binom)
        return self._slope_empirical(sdm, int(sdm.threshold), rng=rng)

    def _apply_control(self, sdm: "SimpleSdm", err: float, s: float) -> None:
        """Apply the slope-normalized, rate-limited step; clip T to [1, n-1]."""
        u = - self.c * err / s
        u = float(np.clip(u, -self.delta_t_max, self.delta_t_max))
        sdm.threshold = int(np.clip(int(round(sdm.threshold + u)), 1, sdm.n_dims - 1))

    def _record_step(self, sdm: "SimpleSdm", in_band: bool) -> None:
        """Append history and update the N-consecutive-steps convergence test."""
        self.k_history.append(self.k_smooth)
        self.threshold_history.append(sdm.threshold)
        if in_band:
            self._in_tol_run += 1
            if (not self.converged) and (self._in_tol_run >= self.consec):
                self.converged = True
                self.convergence_step = len(self.k_history)
        else:
            self._in_tol_run = 0

    def _measure(self, sdm: "SimpleSdm", query: Optional[np.ndarray],
                 fast_synthetic: bool, n_batch: int,
                 rng: np.random.Generator) -> float:
        """Return one (possibly batch-averaged) activation measurement."""
        if fast_synthetic:
            # Binomial draws replace real Hamming-distance queries.
            p = stats.binom.cdf(sdm.threshold, sdm.n_dims, 0.5)
            return float(np.mean([rng.binomial(sdm.n_locations, p)
                                  for _ in range(n_batch)]))
        if n_batch == 1:
            if query is None:
                query = rng.integers(0, 2, sdm.n_dims, dtype=np.uint8)
            return sdm.count_activations(query)
        ks = [sdm.count_activations(rng.integers(0, 2, sdm.n_dims, dtype=np.uint8))
              for _ in range(n_batch)]
        return float(np.mean(ks))

    def update_threshold(self, sdm: "SimpleSdm", query: Optional[np.ndarray],
                        fast_synthetic: bool = False,
                        rng: Optional[np.random.Generator] = None) -> int:
        """
        Update threshold using EMA control with noise-aware convergence.

        Args:
            sdm: SDM instance to control
            query: Query pattern. It is ignored when fast_synthetic is set or
                batch_size > 1. If None, a random query is drawn from rng.
            fast_synthetic: Use fast binomial draws instead of real Hamming distances
            rng: Random generator for reproducibility

        Returns:
            Current activation count (measurement rounded to int)
        """
        rng = rng or np.random.default_rng()
        n_batch = max(1, self.batch_size)

        # 1) Measure activations (optionally batched to reduce noise)
        k_obs = self._measure(sdm, query, fast_synthetic, n_batch, rng)

        # 2) EMA update
        if self.k_smooth is None:
            # Seeded from the binomial prior E[k|T], not from the first sample.
            # NOTE(review): that prior assumes uniformly random addresses; with
            # non-uniform (e.g. clustered) addresses, seeding from k_obs may be
            # more appropriate. Confirm against the experiments.
            self.k_smooth = sdm.count_activations_expected()
        else:
            self.k_smooth = (1 - self.alpha) * self.k_smooth + self.alpha * float(k_obs)

        # 3) Error and slope
        err = self.k_smooth - self.target_k
        s = self._current_slope(sdm, rng)

        # 4-5) Slope-normalized control (bounded), threshold clipped
        self._apply_control(sdm, err, s)

        # 6) Noise-aware convergence band (p clamped to [1/L, 1] so sigma is real)
        p_target = min(1.0, max(1.0 / sdm.n_locations, self.target_k / sdm.n_locations))
        sigma_k = np.sqrt(self.target_k * (1.0 - p_target))
        sigma_ema = np.sqrt(self.alpha / (2.0 - self.alpha)) * sigma_k / np.sqrt(n_batch)
        in_band = (abs(err) <= self.tol_sigma * sigma_ema)

        # 7-8) Bookkeeping and convergence test
        self._record_step(sdm, in_band)

        return int(round(k_obs))

    def update_threshold_noiseless(self, sdm: "SimpleSdm", _ = None) -> float:
        """
        Update threshold using expected activations instead of noisy measurements.
        Useful for testing pure stability without observation noise.

        Returns:
            The expected activation count E[k|T] at the pre-update threshold.
        """
        k_expected = sdm.count_activations_expected()

        if self.k_smooth is None:
            self.k_smooth = float(k_expected)
        else:
            self.k_smooth = (1 - self.alpha) * self.k_smooth + self.alpha * float(k_expected)

        err = self.k_smooth - self.target_k
        s = self._current_slope(sdm)
        self._apply_control(sdm, err, s)

        # Without observation noise, convergence means an (almost) exactly zero error.
        self._record_step(sdm, abs(err) < 0.01)

        return k_expected

    def get_convergence_time(self) -> Optional[int]:
        """Return step at which convergence was achieved (None if not yet)."""
        return self.convergence_step


# ============================================================================
# Theoretical Calculations
# ============================================================================

def compute_theoretical_k(n_dims: int, n_locations: int, threshold: int) -> float:
    """
    Expected number of activated locations under uniform random addresses.

    k = L * F_binomial(T; n, 0.5), where F is the Binomial(n, 1/2) CDF.
    """
    # Probability that a single location lies within Hamming distance T.
    within_radius = stats.binom.cdf(threshold, n_dims, 0.5)
    return n_locations * within_radius


def compute_theoretical_threshold(n_dims: int, n_locations: int, target_k: int) -> int:
    """
    Compute optimal threshold for target activation count.
    T* = F_binomial^(-1)(k*/L)
    This is the single source of truth for computing T*.

    Following SciPy's discrete ppf, the result is the smallest integer T with
    L * CDF(T) >= k*. A target of 0 yields -1 (no location activates).

    Raises:
        ValueError: if target_k is outside [0, n_locations]. Previously this
            surfaced as an opaque "cannot convert float NaN to integer".
    """
    p_target = target_k / n_locations
    if not 0.0 <= p_target <= 1.0:
        raise ValueError(
            f"target_k={target_k} must lie in [0, n_locations={n_locations}]"
        )
    return int(stats.binom.ppf(p_target, n_dims, 0.5))


def predict_steps_linear(alpha: float, c: float, s_star: float,
                         T0: int, T_star: int, delta_t_max: int,
                         eps_k: float, consec: int = 5, L: Optional[int] = None,
                         n: Optional[int] = None) -> float:
    """
    Predict convergence steps from the 2×2 linear model
    A = [[1, -c/s*], [α s*, 1-α(1+c)]]:

        N ≈ ceil(|T0-T*|/ΔTmax) + ceil(log(ε0/εk)/(-log ρ)) + (consec-1)

    The initial error ε0 uses the local slope s0 at T0 when both L and n are
    given (otherwise s*). The spectral radius ρ uses s*, since the linear stage
    happens near T*. Returns inf when ρ >= 1 (unstable).
    """
    gap = abs(T0 - T_star)
    have_geometry = (L is not None) and (n is not None)
    s0 = L * stats.binom.pmf(T0, n, 0.5) if have_geometry else s_star

    system = np.array([[1.0, -c / s_star],
                       [alpha * s_star, 1.0 - alpha * (1.0 + c)]], dtype=float)
    spectral_radius = np.abs(np.linalg.eigvals(system)).max()
    if spectral_radius >= 1:
        return np.inf

    slew_steps = np.ceil(gap / float(delta_t_max))
    initial_error = max(s0 * gap, 1e-9)
    decay_ratio = max(initial_error / eps_k, 1.0001)
    linear_steps = np.ceil(np.log(decay_ratio) / (-np.log(spectral_radius)))
    return float(slew_steps + linear_steps + (consec - 1))


def c_max_for_alpha(alpha: float, s_star: float, c_hi: float = 1.0) -> float:
    """
    Largest gain c with spectral radius ρ(A) < 1 for
    A = [[1, -c/s*], [α s*, 1-α(1+c)]].

    The upper end of the bracket is doubled (at most 30 times) until A becomes
    unstable; bisection then narrows in on the boundary. If no unstable gain is
    found, the last bracket end is returned as a conservative floor.

    Cross-check: det A = 1 - α, so for 0 < α <= 1 the exact boundary is
    c = 4/α - 2. This is a theoretical upper bound; noise and quantization make
    the usable range tighter in practice.
    """
    def spectral_radius(gain):
        system = np.array([[1.0, -gain / s_star],
                           [alpha * s_star, 1.0 - alpha * (1.0 + gain)]], dtype=float)
        return np.abs(np.linalg.eigvals(system)).max()

    upper = float(c_hi)
    tries = 0
    while tries < 30 and spectral_radius(upper) < 1.0:
        upper *= 2.0
        tries += 1
    if spectral_radius(upper) < 1.0:
        # No instability found inside the search budget.
        return upper

    lower = 0.0
    for _ in range(60):
        middle = 0.5 * (lower + upper)
        if spectral_radius(middle) < 1.0:
            lower = middle
        else:
            upper = middle
    return lower


# ============================================================================
# Utility Functions
# ============================================================================

def measure_recall_accuracy(original: np.ndarray, reconstructed: np.ndarray) -> float:
    """Fraction of positions where the reconstruction matches the original."""
    matches = np.equal(original, reconstructed)
    return float(matches.mean())


def add_noise_to_pattern(pattern: np.ndarray, noise_fraction: float,
                         rng: Optional[np.random.Generator] = None) -> np.ndarray:
    """
    Return a copy of ``pattern`` with int(len * noise_fraction) distinct
    positions flipped (x -> 1 - x). The input array is not modified.
    """
    generator = rng or np.random.default_rng()
    length = len(pattern)
    flip_count = int(length * noise_fraction)
    targets = generator.choice(length, flip_count, replace=False)
    corrupted = pattern.copy()
    corrupted[targets] = 1 - corrupted[targets]
    return corrupted


# ============================================================================
# Baseline Controllers (for comparison)
# ============================================================================

class BaselineController:
    """Common state for the baseline threshold controllers used in comparisons.

    Subclasses implement ``update`` and are expected to append one
    performance value per call to ``performance_history``.
    """

    def __init__(self, target_k: int):
        self.target_k = target_k
        # Step at which the baseline counts as converged; None until then.
        self.convergence_time = None
        # Per-update performance values recorded by subclasses.
        self.performance_history = []

    def update(self, sdm: SimpleSdm, step: int) -> float:
        """Advance the controller by one step and return its performance metric."""
        raise NotImplementedError


class FixedOracleController(BaselineController):
    """Oracle baseline that always uses the analytically optimal threshold.

    T* comes from compute_theoretical_threshold and is re-applied on every
    update, so the oracle is treated as converged from step 0.
    """

    def __init__(self, target_k: int, n_dims: int, n_locations: int):
        super().__init__(target_k)
        self.optimal_threshold = compute_theoretical_threshold(n_dims, n_locations, target_k)
        self.convergence_time = 0  # Instant "convergence"

    def update(self, sdm: SimpleSdm, step: int) -> float:
        """Pin T*, score one random query, record and return the score.

        The score is 1 - |k - target_k| / target_k.
        """
        sdm.threshold = self.optimal_threshold
        probe = np.random.randint(0, 2, sdm.n_dims, dtype=np.uint8)
        k_measured = sdm.count_activations(probe)
        relative_error = abs(k_measured - self.target_k) / self.target_k
        score = 1.0 - relative_error
        self.performance_history.append(score)
        return score


class GridSearchController(BaselineController):
    """Periodic grid search optimization.

    On every step that is a multiple of ``search_interval`` (including step 0)
    it scans thresholds in [T-50, T+50] with stride 2 (clipped to
    [1, n_dims-1]). Each candidate is scored by its mean activation count over
    10 random queries, and the candidate closest to ``target_k`` is adopted.
    """

    def __init__(self, target_k: int, search_interval: int = 500):
        super().__init__(target_k)
        self.search_interval = search_interval
        # Threshold chosen by the most recent search (None before the first one).
        self.best_threshold = None

    def update(self, sdm: SimpleSdm, step: int) -> float:
        """Optionally re-run the grid search, then score one random query.

        Returns:
            1 - |k - target_k| / target_k for a fresh random query.
        """
        if step % self.search_interval == 0:
            self._grid_search(sdm, step)

        query = np.random.randint(0, 2, sdm.n_dims, dtype=np.uint8)
        k = sdm.count_activations(query)
        performance = 1.0 - abs(k - self.target_k) / self.target_k
        self.performance_history.append(performance)
        return performance

    def _grid_search(self, sdm: SimpleSdm, step: int) -> None:
        """Scan candidate thresholds around sdm.threshold and apply the best one."""
        start = sdm.threshold
        best_error = float('inf')
        best_T = None
        for T in range(max(1, start - 50), min(sdm.n_dims, start + 51), 2):
            sdm.threshold = T
            k_samples = []
            for _ in range(10):
                query = np.random.randint(0, 2, sdm.n_dims, dtype=np.uint8)
                k_samples.append(sdm.count_activations(query))
            error = abs(np.mean(k_samples) - self.target_k)
            if error < best_error:
                best_error = error
                best_T = T

        if best_T is None:
            # Empty candidate range (e.g. n_dims <= 1): the loop never ran, so
            # keep the current threshold rather than assigning None.
            return

        self.best_threshold = best_T
        sdm.threshold = best_T

        # Compare with None: a convergence at step 0 is stored as 0, which is
        # falsy. A truthiness test would let a later search overwrite it.
        if self.convergence_time is None and best_error / self.target_k < 0.05:
            self.convergence_time = step