"""
Baseline Algorithms for Comparison

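This module implements baseline algorithms for comparison with SP-UCB-OLP:

1. SPGreedyOLP: Greedy selection without exploration bonus
2. OneHotSPUCBOLP: Per-config UCB (wrong abstraction; chooses the best fixed config)
3. OraclePolicy: Uses the oracle mixture and price (upper bound)
4. RandomPolicy: Uniform random selection (lower bound)
5. FixedConfigPolicy: Always uses a single config

It also defines ablation variants of SP-UCB-OLP for isolating individual
algorithm components (e.g. EnvelopeGreedyOLP, MixtureLocalPriceOLP,
NoSlackSPUCBOLP, AcceptedOnlySPUCBOLP).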
"""

import numpy as np
from scipy.optimize import minimize
from typing import Dict, List, Tuple, Optional, Any
from .base import BaseAlgorithm


class SPGreedyOLP(BaseAlgorithm):
    """
    Greedy Saddle-Point OLP without exploration.

    Same as SP-UCB-OLP but with alpha=0 (no confidence bonus).
    This is a pure exploitation baseline that doesn't explore.

    After an optional round-robin warm start, the policy (re)solves

        min_p { <p, b_safe> + max_theta g_hat_theta(p) }

    on the first post-warm-start round and then every ``solve_frequency``
    rounds. It samples a config from the uniform mixture over the envelope
    maximisers and admits arrivals whose reward covers the posted price.

    Expected behavior:
    - May get stuck on suboptimal config early
    - No regret guarantees
    """

    def __init__(
        self,
        K: int,
        d: int,
        T: int,
        B: np.ndarray,
        config: Dict[str, Any] = None
    ):
        # Copy before forcing alpha=0: mutating the caller's dict in place
        # would leak alpha=0 into anything else built from the same dict.
        config = dict(config or {})
        config['alpha'] = 0.0  # No exploration
        super().__init__(K, d, T, B, config)

        self.n_restarts = self.config.get('n_restarts', 2)  # Reduced for speed
        self.warm_start = self.config.get('warm_start', True)

        # Caching for efficiency - recompute every sqrt(T) steps.
        # last_solve_time < 0 means "never solved yet".
        self.last_solve_time = -1
        self.solve_frequency = max(1, int(np.sqrt(T)))
        self._last_optimal_p = None

        # scipy with caching is faster than Gurobi for this problem structure
        self.use_gurobi = self.config.get('use_gurobi', False)
        self._gurobi_available = None

    def _check_gurobi(self) -> bool:
        """Check if Gurobi is available (import probed once, then cached)."""
        if self._gurobi_available is None:
            try:
                import gurobipy  # noqa: F401
                self._gurobi_available = True
            except ImportError:
                self._gurobi_available = False
        return self._gurobi_available

    def _solve_greedy(self, t: int) -> Tuple[np.ndarray, np.ndarray]:
        """
        Solve the greedy problem (no UCB bonus) with warm starting.

            min_p { <p, b_safe> + max_theta g_hat_theta(p) }

        The Gurobi LP is tried only when ``use_gurobi`` is set and gurobipy
        imports. Otherwise, or if the Gurobi path raises, the scipy solver
        is used.

        Parameters
        ----------
        t : int
            Current round (not used by the solver).

        Returns
        -------
        (w, p) : mixture weights over the K configs and the price vector.
        """
        # Zero confidence radii for greedy (no exploration)
        beta = np.zeros(self.K)

        if self.use_gurobi and self._check_gurobi():
            try:
                w, p, _ = self.solve_saddle_point_gurobi(self.b_safe, beta, self.P_max)
                self._last_optimal_p = p.copy()
                return w, p
            except Exception:
                # Deliberate best-effort: any failure on the Gurobi path
                # falls through to the scipy solver below.
                pass

        return self._solve_greedy_scipy()

    def _solve_greedy_scipy(self) -> Tuple[np.ndarray, np.ndarray]:
        """
        Solve the greedy problem with L-BFGS-B over the box [0, P_max]^d.

        Starts from the previous optimum (if any) and from zero. Random
        restarts are added only before the first solve. The returned mixture
        is uniform over every config whose empirical surplus at the optimal
        price is within 1e-8 of the maximum.
        """
        def objective(p):
            linear_term = np.dot(p, self.b_safe)
            surpluses = np.zeros(self.K)
            for theta in range(self.K):
                surpluses[theta] = self.compute_empirical_surplus(theta, p)
            envelope = np.max(surpluses)
            return linear_term + envelope

        best_val = np.inf
        best_p = np.zeros(self.d)

        # Use previous solution as warm start if available
        initial_points = []
        if self._last_optimal_p is not None:
            initial_points.append(self._last_optimal_p)
        initial_points.append(np.zeros(self.d))
        if self._last_optimal_p is None:
            for _ in range(self.n_restarts - 1):
                initial_points.append(np.random.uniform(0, self.P_max / 2, self.d))

        for p_init in initial_points:
            result = minimize(
                objective,
                p_init,
                method='L-BFGS-B',
                bounds=[(0, self.P_max)] * self.d,
                options={'maxiter': 100}
            )

            if result.fun < best_val:
                best_val = result.fun
                best_p = result.x.copy()

        # Store for warm start next time
        self._last_optimal_p = best_p.copy()

        # Compute mixture at optimal price: uniform over (near-)maximisers
        surpluses = np.zeros(self.K)
        for theta in range(self.K):
            surpluses[theta] = self.compute_empirical_surplus(theta, best_p)

        max_surplus = np.max(surpluses)
        achieving_max = (surpluses >= max_surplus - 1e-8)
        w = achieving_max.astype(float)
        w /= w.sum()

        return w, best_p

    def select_config(self, t: int) -> Tuple[int, np.ndarray, np.ndarray]:
        """
        Select a config greedily, re-solving the price problem periodically.

        During warm start (t < K) configs are played round-robin at price
        zero. Afterwards the cached (w, p) is refreshed on the first
        post-warm-start round and then every ``solve_frequency`` rounds.

        Returns
        -------
        (theta, w, p) : sampled config index, the mixture it was drawn from,
        and the price vector used for admission.
        """
        self.t = t

        if self.warm_start and t < self.K:
            theta = t % self.K
            w = np.zeros(self.K)
            w[theta] = 1.0
            p = np.zeros(self.d)
            self.current_theta = theta
            self.current_w = w
            self.current_p = p
            return theta, w, p

        # Force a solve if none has happened yet. The warm-start branch above
        # leaves a one-hot, zero-price placeholder in current_w/current_p
        # that must not be reused for up to solve_frequency rounds.
        needs_solve = (
            self.last_solve_time < 0
            or self.current_w is None
            or t - self.last_solve_time >= self.solve_frequency
        )
        if needs_solve:
            w, p = self._solve_greedy(t)
            self.current_w = w
            self.current_p = p
            self.last_solve_time = t
        else:
            w = self.current_w
            p = self.current_p

        theta = int(np.random.choice(self.K, p=w))

        self.current_theta = theta
        return theta, w, p

    def decide_admission(
        self,
        t: int,
        theta: int,
        r: float,
        a: np.ndarray,
        p: np.ndarray
    ) -> bool:
        """
        Admit an arrival if it fits the remaining budget and has non-negative surplus.

        Rejects when any component of ``a`` exceeds ``B_remaining`` (1e-9
        slack). Otherwise accepts iff ``r >= <p, a>`` up to a 1e-6 tolerance.
        """
        if np.any(a > self.B_remaining + 1e-9):
            return False
        return r >= np.dot(p, a) - 1e-6  # Small tolerance for numerical stability

    def __repr__(self) -> str:
        return f"SPGreedyOLP(K={self.K}, d={self.d}, T={self.T})"


class OneHotSPUCBOLP(BaseAlgorithm):
    """
    One-Hot (Per-Config) UCB Selection baseline.

    Treats every configuration as an independent arm, which is deliberately
    the WRONG abstraction. After a round-robin warm start, each round plays
    the single config with the largest optimistic per-config value

        theta_t = argmax_theta (V_hat_theta + beta_theta)

    where V_hat_theta = min_p {<p, b> + g_hat_theta(p)}. Arrivals are then
    admitted against that config's own minimising price.

    Because the choice is always one-hot, the policy never accounts for the
    switching-aware oracle structure.

    Expected behavior:
    - May exceed V* (fixed-config oracle) but not V^mix
    - Ignores complementarity between configs
    """

    def __init__(
        self,
        K: int,
        d: int,
        T: int,
        B: np.ndarray,
        config: Dict[str, Any] = None
    ):
        super().__init__(K, d, T, B, config)
        cfg = self.config

        # Solver knobs: scipy with restarts by default; Gurobi only on request.
        self.use_gurobi = cfg.get('use_gurobi', False)
        self._gurobi_available = None  # lazily probed by _check_gurobi
        self.n_restarts = cfg.get('n_restarts', 3)

        # Learning knobs.
        self.alpha = cfg.get('alpha', 1.0)
        self.warm_start = cfg.get('warm_start', True)

        # Latest minimising price for each config (one row per config).
        self.p_per_config = np.zeros((self.K, self.d))

    def _check_gurobi(self) -> bool:
        """Return whether gurobipy can be imported, probing only on first use."""
        if self._gurobi_available is not None:
            return self._gurobi_available
        try:
            import gurobipy  # noqa: F401
        except ImportError:
            self._gurobi_available = False
        else:
            self._gurobi_available = True
        return self._gurobi_available

    def _compute_per_config_value(self, theta: int) -> Tuple[float, np.ndarray]:
        """
        Return (V_hat_theta, argmin price) for a single config.

        V_hat_theta = min_p {<p, b> + g_hat_theta(p)}. When Gurobi is
        requested and importable, the saddle-point LP is reused with every
        other config masked out. Any exception on that path drops back to
        the scipy solver.
        """
        if self.use_gurobi and self._check_gurobi():
            try:
                # A huge negative bonus keeps all other configs out of the
                # envelope, leaving only config `theta`.
                mask = np.full(self.K, -1e10)
                mask[theta] = 0.0
                _, p_opt, obj_val = self.solve_saddle_point_gurobi(self.b, mask, self.P_max)
            except Exception:
                pass
            else:
                return obj_val, p_opt

        return self._compute_per_config_value_scipy(theta)

    def _compute_per_config_value_scipy(self, theta: int) -> Tuple[float, np.ndarray]:
        """
        Minimise <p, b> + g_hat_theta(p) over [0, P_max]^d with L-BFGS-B.

        The first attempt starts at zero; each further attempt (up to
        ``n_restarts`` in total) starts uniformly in [0, P_max/2]^d.
        """
        def value_at(price):
            return np.dot(price, self.b) + self.compute_empirical_surplus(theta, price)

        best_value, best_price = np.inf, np.zeros(self.d)
        for attempt in range(self.n_restarts):
            start = (
                np.zeros(self.d) if attempt == 0
                else np.random.uniform(0, self.P_max / 2, self.d)
            )
            res = minimize(
                value_at,
                start,
                method='L-BFGS-B',
                bounds=[(0, self.P_max)] * self.d,
                options={'maxiter': 100}
            )
            if res.fun < best_value:
                best_value, best_price = res.fun, res.x.copy()

        return best_value, best_price

    def select_config(self, t: int) -> Tuple[int, np.ndarray, np.ndarray]:
        """
        Pick the config with the highest per-config UCB value.

        While warming up (t < K) configs are cycled at price zero. After
        that, every config's value and price are recomputed on each call.
        The winner is played one-hot with its own price.
        """
        self.t = t

        if self.warm_start and t < self.K:
            theta = t % self.K
            p = np.zeros(self.d)
        else:
            scores = np.empty(self.K)
            for k in range(self.K):
                v_hat, price_k = self._compute_per_config_value(k)
                scores[k] = v_hat + self.compute_confidence_radius(k, t)
                self.p_per_config[k] = price_k
            theta = int(np.argmax(scores))
            p = self.p_per_config[theta].copy()

        w = np.zeros(self.K)
        w[theta] = 1.0

        self.current_theta = theta
        self.current_w = w
        self.current_p = p
        return theta, w, p

    def decide_admission(
        self,
        t: int,
        theta: int,
        r: float,
        a: np.ndarray,
        p: np.ndarray
    ) -> bool:
        """
        Admit against the chosen config's price when the request fits.

        Reject outright if any resource would overshoot ``B_remaining``
        (1e-9 slack). Otherwise admit iff ``r >= <p, a> - 1e-6``.
        """
        over_budget = np.any(a > self.B_remaining + 1e-9)
        return False if over_budget else r >= np.dot(p, a) - 1e-6

    def __repr__(self) -> str:
        return f"OneHotSPUCBOLP(K={self.K}, d={self.d}, T={self.T}, alpha={self.alpha})"


class OraclePolicy(BaseAlgorithm):
    """
    Oracle Policy with perfect information.

    Uses the true optimal mixture w* and price p* computed from
    the population distribution (not samples).

    This is an UPPER BOUND on achievable performance, not a
    practical algorithm.

    Parameters
    ----------
    w_star : np.ndarray
        Optimal mixture from oracle computation (K,), read from
        ``config['w_star']`` (default: uniform). Any array-like of
        non-negative weights is accepted and is renormalised to sum to 1.
    p_star : np.ndarray
        Optimal price from oracle computation (d,), read from
        ``config['p_star']`` (default: zeros). Any array-like is accepted.

    Raises
    ------
    ValueError
        If ``w_star`` does not have shape (K,), has an entry below -1e-8,
        or does not have a finite positive total.
    """

    def __init__(
        self,
        K: int,
        d: int,
        T: int,
        B: np.ndarray,
        config: Dict[str, Any] = None
    ):
        super().__init__(K, d, T, B, config)

        # Oracle values must be provided in config (use self.config which is
        # never None). np.asarray accepts lists, tuples and arrays alike.
        w_star = np.asarray(self.config.get('w_star', np.ones(K) / K), dtype=float)
        p_star = np.asarray(self.config.get('p_star', np.zeros(d)), dtype=float)

        if w_star.shape != (K,):
            raise ValueError(f"w_star must have shape ({K},), got {w_star.shape}")
        if np.any(w_star < -1e-8):
            raise ValueError(f"w_star has negative weights: {w_star}")

        # Mixtures from an LP/QP solver often carry round-off (e.g. -1e-12
        # entries, or a total of 1 +/- 1e-7). np.random.choice rejects both,
        # so clip the round-off and renormalise once here.
        w_star = np.clip(w_star, 0.0, None)
        total = w_star.sum()
        if not np.isfinite(total) or total <= 0:
            raise ValueError(f"w_star must have a finite positive sum, got {total}")

        self.w_star = w_star / total
        self.p_star = p_star

    def select_config(self, t: int) -> Tuple[int, np.ndarray, np.ndarray]:
        """Sample a config from the oracle mixture and post the oracle price (copies)."""
        self.t = t

        theta = int(np.random.choice(self.K, p=self.w_star))
        w = self.w_star.copy()
        p = self.p_star.copy()

        self.current_theta = theta
        self.current_w = w
        self.current_p = p
        return theta, w, p

    def decide_admission(
        self,
        t: int,
        theta: int,
        r: float,
        a: np.ndarray,
        p: np.ndarray
    ) -> bool:
        """
        Accept using the oracle price.

        Rejects when any component of ``a`` exceeds ``B_remaining`` (1e-9
        slack). Otherwise accepts iff ``r >= <p, a>`` up to a 1e-6 tolerance.
        """
        if np.any(a > self.B_remaining + 1e-9):
            return False
        return r >= np.dot(p, a) - 1e-6  # Small tolerance for numerical stability

    def __repr__(self) -> str:
        return f"OraclePolicy(K={self.K}, d={self.d}, T={self.T})"


class RandomPolicy(BaseAlgorithm):
    """
    Uniform random baseline with no learning.

    Every round draws one configuration uniformly at random and posts a zero
    price. Any arrival whose consumption fits the remaining budget is
    admitted. Intended as a LOWER BOUND reference.
    """

    def __init__(
        self,
        K: int,
        d: int,
        T: int,
        B: np.ndarray,
        config: Dict[str, Any] = None
    ):
        super().__init__(K, d, T, B, config)

    def select_config(self, t: int) -> Tuple[int, np.ndarray, np.ndarray]:
        """Draw a config uniformly; the zero price means anything that fits is admitted."""
        self.t = t
        choice = np.random.randint(0, self.K)

        self.current_theta = choice
        self.current_w = np.full(self.K, 1.0 / self.K)
        self.current_p = np.zeros(self.d)
        return choice, self.current_w, self.current_p

    def decide_admission(
        self,
        t: int,
        theta: int,
        r: float,
        a: np.ndarray,
        p: np.ndarray
    ) -> bool:
        """Admit iff every resource demand fits in B_remaining (1e-9 slack)."""
        headroom = self.B_remaining + 1e-9
        return np.all(a <= headroom)

    def __repr__(self) -> str:
        return f"RandomPolicy(K={self.K}, d={self.d}, T={self.T})"


class FixedConfigPolicy(BaseAlgorithm):
    """
    Fixed configuration policy.

    Always uses a single fixed configuration. Useful for
    measuring V* empirically.

    The admission price starts at zero. It is re-estimated every
    ``max(1, int(sqrt(T)))`` rounds (for t > 0) by minimising
    <p, b> + g_hat_theta(p) over the box [0, P_max]^d.

    Parameters
    ----------
    fixed_theta : int
        The configuration to always use (default: 0), read from
        ``config['fixed_theta']``. Must be an integer in [0, K).

    Raises
    ------
    ValueError
        If ``fixed_theta`` is not an integer in [0, K).
    """

    def __init__(
        self,
        K: int,
        d: int,
        T: int,
        B: np.ndarray,
        config: Dict[str, Any] = None
    ):
        super().__init__(K, d, T, B, config)

        # Validate early: a negative index would silently wrap in the one-hot
        # mixture (w[-1] = 1.0) while the negative value is still reported as
        # the chosen config, and an index >= K would only fail later.
        fixed_theta = (config or {}).get('fixed_theta', 0)
        if not isinstance(fixed_theta, (int, np.integer)) or not 0 <= fixed_theta < K:
            raise ValueError(
                f"fixed_theta must be an integer in [0, {K}), got {fixed_theta!r}"
            )
        self.fixed_theta = int(fixed_theta)
        self.n_restarts = self.config.get('n_restarts', 3)

        # Optimal price for this config (zero until the first re-estimate)
        self.optimal_p = np.zeros(d)

    def _compute_optimal_price(self) -> np.ndarray:
        """
        Compute the price minimising <p, b> + g_hat_theta(p) for the fixed config.

        Runs L-BFGS-B over [0, P_max]^d from zero and from ``n_restarts - 1``
        random starts in [0, P_max/2]^d, and keeps the best.
        """
        theta = self.fixed_theta

        def objective(p):
            linear_term = np.dot(p, self.b)
            g_hat = self.compute_empirical_surplus(theta, p)
            return linear_term + g_hat

        best_val = np.inf
        best_p = np.zeros(self.d)

        for restart in range(self.n_restarts):
            if restart == 0:
                p_init = np.zeros(self.d)
            else:
                p_init = np.random.uniform(0, self.P_max / 2, self.d)

            result = minimize(
                objective,
                p_init,
                method='L-BFGS-B',
                bounds=[(0, self.P_max)] * self.d,
                options={'maxiter': 100}
            )

            if result.fun < best_val:
                best_val = result.fun
                best_p = result.x.copy()

        return best_p

    def select_config(self, t: int) -> Tuple[int, np.ndarray, np.ndarray]:
        """Always return the fixed config, refreshing its price every sqrt(T) rounds."""
        self.t = t

        theta = self.fixed_theta
        w = np.zeros(self.K)
        w[theta] = 1.0

        # Periodically update price estimate
        if t > 0 and t % max(1, int(np.sqrt(self.T))) == 0:
            self.optimal_p = self._compute_optimal_price()

        # Hand out a copy so callers cannot mutate the persisted price.
        p = self.optimal_p.copy()

        self.current_theta = theta
        self.current_w = w
        self.current_p = p
        return theta, w, p

    def decide_admission(
        self,
        t: int,
        theta: int,
        r: float,
        a: np.ndarray,
        p: np.ndarray
    ) -> bool:
        """
        Accept using the fixed config's optimal price.

        Rejects when any component of ``a`` exceeds ``B_remaining`` (1e-9
        slack). Otherwise accepts iff ``r >= <p, a>`` up to a 1e-6 tolerance.
        """
        if np.any(a > self.B_remaining + 1e-9):
            return False
        return r >= np.dot(p, a) - 1e-6  # Small tolerance for numerical stability

    def __repr__(self) -> str:
        return f"FixedConfigPolicy(K={self.K}, d={self.d}, T={self.T}, theta={self.fixed_theta})"


# =============================================================================
# ABLATION VARIANTS - For isolating algorithm components
# =============================================================================

class EnvelopeGreedyOLP(BaseAlgorithm):
    """
    Envelope-Greedy Ablation: Compute (p_t, w_t) but SELECT greedily.

    This ablation removes the MIXTURE SAMPLING component of SP-UCB-OLP.
    Instead of sampling θ ~ w_t, it picks:

        θ_t = argmax_θ { ĝ_θ(p_t) + β_θ(t) }

    This tests whether mixture sampling is essential for complementarity.

    The price p_t and the UCB values are cached. They are refreshed on the
    first round after warm start and then every ``solve_frequency`` rounds,
    so between refreshes the argmax uses the most recently cached values.

    Expected failure mode:
    - Cannot exploit complementary resources (always picks one config)
    - Should match or beat V* but may not reach V^mix
    """

    def __init__(
        self,
        K: int,
        d: int,
        T: int,
        B: np.ndarray,
        config: Dict[str, Any] = None
    ):
        super().__init__(K, d, T, B, config)

        self.alpha = self.config.get('alpha', 0.1)
        # Upper clip applied to every confidence radius.
        self.beta_max = self.config.get('beta_max', 10.0 * self.R_max)
        # NOTE: not read by _solve_envelope, which always starts from the
        # previous optimum (if any) and from zero.
        self.n_restarts = self.config.get('n_restarts', 2)
        self.warm_start = self.config.get('warm_start', True)

        # last_solve_time < 0 means "never solved yet".
        self.last_solve_time = -1
        self.solve_frequency = max(1, int(np.sqrt(T)))
        self._last_optimal_p = None
        self._ucb_values = np.zeros(self.K)  # UCB values from the latest solve

    def _compute_confidence_radii(self, t: int) -> np.ndarray:
        """Compute confidence radii for all configs, each clipped at ``beta_max``."""
        beta = np.zeros(self.K)
        for theta in range(self.K):
            raw_beta = self.compute_confidence_radius(theta, t)
            beta[theta] = min(raw_beta, self.beta_max)
        return beta

    def _solve_envelope(self, t: int) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
        """
        Solve min_p { <p, b_safe> + max_θ (g_hat_θ(p) + β_θ) }.

        Uses L-BFGS-B over [0, P_max]^d from the previous optimum (if any)
        and from zero.

        Returns
        -------
        (w, p, ucb_values) : uniform mixture over the UCB maximisers (within
        1e-8, kept for logging), the optimal price, and ĝ_θ(p) + β_θ per config.
        """
        beta = self._compute_confidence_radii(t)

        def objective(p):
            linear_term = np.dot(p, self.b_safe)
            ucb_values = np.zeros(self.K)
            for theta in range(self.K):
                g_hat = self.compute_empirical_surplus(theta, p)
                ucb_values[theta] = g_hat + beta[theta]
            envelope = np.max(ucb_values)
            return linear_term + envelope

        best_val = np.inf
        best_p = np.zeros(self.d)

        initial_points = []
        if self._last_optimal_p is not None:
            initial_points.append(self._last_optimal_p)
        initial_points.append(np.zeros(self.d))

        for p_init in initial_points:
            result = minimize(
                objective,
                p_init,
                method='L-BFGS-B',
                bounds=[(0, self.P_max)] * self.d,
                options={'maxiter': 100, 'ftol': 1e-6}
            )
            if result.fun < best_val:
                best_val = result.fun
                best_p = result.x.copy()

        self._last_optimal_p = best_p.copy()

        # Compute UCB values at optimal price
        ucb_values = np.zeros(self.K)
        for theta in range(self.K):
            g_hat = self.compute_empirical_surplus(theta, best_p)
            ucb_values[theta] = g_hat + beta[theta]

        # Compute mixture (for logging) - uniform over max achievers
        max_ucb = np.max(ucb_values)
        w = (ucb_values >= max_ucb - 1e-8).astype(float)
        w /= w.sum()

        return w, best_p, ucb_values

    def select_config(self, t: int) -> Tuple[int, np.ndarray, np.ndarray]:
        """
        Select config GREEDILY (no mixture sampling).

        During warm start (t < K) configs are cycled at price zero. After
        that, the envelope is solved on the first post-warm-start round and
        then every ``solve_frequency`` rounds. The config with the largest
        cached UCB value is played.
        """
        self.t = t

        if self.warm_start and t < self.K:
            theta = t % self.K
            w = np.zeros(self.K)
            w[theta] = 1.0
            p = np.zeros(self.d)
            self.current_theta = theta
            self.current_w = w
            self.current_p = p
            return theta, w, p

        # Force a solve if none has happened yet. Otherwise the zero-initialised
        # _ucb_values and the warm-start placeholder (one-hot w, zero p) would
        # be reused for up to solve_frequency rounds.
        needs_solve = (
            self.last_solve_time < 0
            or self.current_w is None
            or t - self.last_solve_time >= self.solve_frequency
        )
        if needs_solve:
            w, p, ucb_values = self._solve_envelope(t)
            self.current_w = w
            self.current_p = p
            self._ucb_values = ucb_values
            self.last_solve_time = t
        else:
            w = self.current_w
            p = self.current_p
            ucb_values = self._ucb_values

        # KEY DIFFERENCE: Select argmax instead of sampling
        theta = int(np.argmax(ucb_values))

        self.current_theta = theta
        return theta, w, p

    def decide_admission(
        self,
        t: int,
        theta: int,
        r: float,
        a: np.ndarray,
        p: np.ndarray
    ) -> bool:
        """
        Accept using the global price.

        Rejects when any component of ``a`` exceeds ``B_remaining`` (1e-9
        slack). Otherwise accepts iff ``r >= <p, a>`` up to a 1e-6 tolerance.
        """
        if np.any(a > self.B_remaining + 1e-9):
            return False
        return r >= np.dot(p, a) - 1e-6

    def __repr__(self) -> str:
        return f"EnvelopeGreedyOLP(K={self.K}, d={self.d}, T={self.T}, alpha={self.alpha})"


class MixtureLocalPriceOLP(BaseAlgorithm):
    """
    Mixture+LocalPrice Ablation: Use mixture w_t but per-config prices.

    This ablation removes the GLOBAL PRICE component of SP-UCB-OLP.
    It samples θ ~ w_t but uses config-specific price p_θ instead of
    the global saddle-point price p_t.

    Per-config prices are refreshed on the first round after warm start and
    then every ``solve_frequency`` rounds. The mixture is recomputed every
    round from the cached prices.

    Expected failure mode:
    - Price mismatch across configs
    - Cannot properly coordinate admission decisions
    - May have inconsistent budget utilization
    """

    def __init__(
        self,
        K: int,
        d: int,
        T: int,
        B: np.ndarray,
        config: Dict[str, Any] = None
    ):
        super().__init__(K, d, T, B, config)

        self.alpha = self.config.get('alpha', 0.1)
        # Upper clip applied to every confidence radius.
        self.beta_max = self.config.get('beta_max', 10.0 * self.R_max)
        self.n_restarts = self.config.get('n_restarts', 2)
        self.warm_start = self.config.get('warm_start', True)

        # last_solve_time < 0 means "prices never solved yet".
        self.last_solve_time = -1
        self.solve_frequency = max(1, int(np.sqrt(T)))

        # Per-config prices (one row per config)
        self.p_per_config = np.zeros((self.K, self.d))

    def _compute_confidence_radii(self, t: int) -> np.ndarray:
        """Compute confidence radii for all configs, each clipped at ``beta_max``."""
        beta = np.zeros(self.K)
        for theta in range(self.K):
            raw_beta = self.compute_confidence_radius(theta, t)
            beta[theta] = min(raw_beta, self.beta_max)
        return beta

    def _solve_per_config(self, theta: int) -> np.ndarray:
        """
        Return the price minimising <p, b_safe> + g_hat_theta(p) over [0, P_max]^d.

        Uses L-BFGS-B from zero plus ``n_restarts - 1`` random starts in
        [0, P_max/2]^d, and keeps the best.
        """
        def objective(p):
            linear_term = np.dot(p, self.b_safe)
            g_hat = self.compute_empirical_surplus(theta, p)
            return linear_term + g_hat

        best_val = np.inf
        best_p = np.zeros(self.d)

        for restart in range(self.n_restarts):
            if restart == 0:
                p_init = np.zeros(self.d)
            else:
                p_init = np.random.uniform(0, self.P_max / 2, self.d)

            result = minimize(
                objective,
                p_init,
                method='L-BFGS-B',
                bounds=[(0, self.P_max)] * self.d,
                options={'maxiter': 100}
            )
            if result.fun < best_val:
                best_val = result.fun
                best_p = result.x.copy()

        return best_p

    def _compute_mixture(self, t: int) -> np.ndarray:
        """
        Compute mixture weights via UCB on per-config values.

        Each config's value is <p_θ, b_safe> + ĝ_θ(p_θ) + β_θ at its cached
        price. The mixture is uniform over configs within 1e-8 of the best.
        """
        beta = self._compute_confidence_radii(t)

        ucb_values = np.zeros(self.K)
        for theta in range(self.K):
            # Per-config value
            p_theta = self.p_per_config[theta]
            g_hat = self.compute_empirical_surplus(theta, p_theta)
            V_hat = np.dot(p_theta, self.b_safe) + g_hat
            ucb_values[theta] = V_hat + beta[theta]

        # Mixture: uniform over max achievers
        max_ucb = np.max(ucb_values)
        w = (ucb_values >= max_ucb - 1e-8).astype(float)
        w /= w.sum()
        return w

    def select_config(self, t: int) -> Tuple[int, np.ndarray, np.ndarray]:
        """Sample a config from the mixture and post that config's own price."""
        self.t = t

        if self.warm_start and t < self.K:
            theta = t % self.K
            w = np.zeros(self.K)
            w[theta] = 1.0
            p = np.zeros(self.d)
            self.current_theta = theta
            self.current_w = w
            self.current_p = p
            return theta, w, p

        # Update per-config prices on the first post-warm-start round and then
        # periodically. Without the last_solve_time < 0 check, prices would
        # stay at their zero initialisation for up to solve_frequency rounds.
        if self.last_solve_time < 0 or t - self.last_solve_time >= self.solve_frequency:
            for theta in range(self.K):
                self.p_per_config[theta] = self._solve_per_config(theta)
            self.last_solve_time = t

        # Compute mixture
        w = self._compute_mixture(t)

        # Sample from mixture
        theta = int(np.random.choice(self.K, p=w))

        # KEY DIFFERENCE: Use per-config price instead of global price
        p = self.p_per_config[theta].copy()

        self.current_theta = theta
        self.current_w = w
        self.current_p = p
        return theta, w, p

    def decide_admission(
        self,
        t: int,
        theta: int,
        r: float,
        a: np.ndarray,
        p: np.ndarray
    ) -> bool:
        """
        Accept using the per-config price.

        Rejects when any component of ``a`` exceeds ``B_remaining`` (1e-9
        slack). Otherwise accepts iff ``r >= <p, a>`` up to a 1e-6 tolerance.
        """
        if np.any(a > self.B_remaining + 1e-9):
            return False
        return r >= np.dot(p, a) - 1e-6

    def __repr__(self) -> str:
        return f"MixtureLocalPriceOLP(K={self.K}, d={self.d}, T={self.T}, alpha={self.alpha})"


class NoSlackSPUCBOLP(BaseAlgorithm):
    """
    No-Slack Ablation: SP-UCB-OLP with ε=0.

    This ablation removes the SLACK parameter from the algorithm.
    Uses b instead of b_safe = (1-ε)b.

    The saddle point is solved on the first round after warm start and then
    every ``solve_frequency`` rounds.

    Expected failure mode:
    - Premature budget exhaustion from variance
    - Higher regret due to feasibility violations
    - Budget may run out before horizon ends
    """

    def __init__(
        self,
        K: int,
        d: int,
        T: int,
        B: np.ndarray,
        config: Dict[str, Any] = None
    ):
        # Copy before forcing epsilon=0 so the caller's dict is not mutated.
        # Otherwise a dict later reused for another algorithm would also
        # silently lose its slack.
        config = dict(config or {})
        config['epsilon'] = 0.0  # No slack
        super().__init__(K, d, T, B, config)

        self.alpha = self.config.get('alpha', 0.1)
        # Upper clip applied to every confidence radius.
        self.beta_max = self.config.get('beta_max', 10.0 * self.R_max)
        self.n_restarts = self.config.get('n_restarts', 2)
        self.warm_start = self.config.get('warm_start', True)

        # last_solve_time < 0 means "never solved yet".
        self.last_solve_time = -1
        self.solve_frequency = max(1, int(np.sqrt(T)))
        self._last_optimal_p = None

    def _compute_confidence_radii(self, t: int) -> np.ndarray:
        """Compute confidence radii for all configs, each clipped at ``beta_max``."""
        beta = np.zeros(self.K)
        for theta in range(self.K):
            raw_beta = self.compute_confidence_radius(theta, t)
            beta[theta] = min(raw_beta, self.beta_max)
        return beta

    def _solve_saddle_point(self, t: int) -> Tuple[np.ndarray, np.ndarray]:
        """
        Solve the saddle point (same as SP-UCB-OLP but with b instead of b_safe).

            min_p { <p, b> + max_θ (g_hat_θ(p) + β_θ) }

        Returns the uniform mixture over the UCB maximisers (within 1e-8) at
        the optimal price, together with that price.
        """
        beta = self._compute_confidence_radii(t)

        def objective(p):
            # KEY: Uses self.b (no slack) instead of self.b_safe
            linear_term = np.dot(p, self.b)
            ucb_values = np.zeros(self.K)
            for theta in range(self.K):
                g_hat = self.compute_empirical_surplus(theta, p)
                ucb_values[theta] = g_hat + beta[theta]
            envelope = np.max(ucb_values)
            return linear_term + envelope

        best_val = np.inf
        best_p = np.zeros(self.d)

        initial_points = []
        if self._last_optimal_p is not None:
            initial_points.append(self._last_optimal_p)
        initial_points.append(np.zeros(self.d))

        for p_init in initial_points:
            result = minimize(
                objective,
                p_init,
                method='L-BFGS-B',
                bounds=[(0, self.P_max)] * self.d,
                options={'maxiter': 100, 'ftol': 1e-6}
            )
            if result.fun < best_val:
                best_val = result.fun
                best_p = result.x.copy()

        self._last_optimal_p = best_p.copy()

        # Compute mixture
        ucb_values = np.zeros(self.K)
        for theta in range(self.K):
            g_hat = self.compute_empirical_surplus(theta, best_p)
            ucb_values[theta] = g_hat + beta[theta]

        max_ucb = np.max(ucb_values)
        w = (ucb_values >= max_ucb - 1e-8).astype(float)
        w /= w.sum()

        return w, best_p

    def select_config(self, t: int) -> Tuple[int, np.ndarray, np.ndarray]:
        """Select config (same as SP-UCB-OLP): sample from the cached mixture."""
        self.t = t

        if self.warm_start and t < self.K:
            theta = t % self.K
            w = np.zeros(self.K)
            w[theta] = 1.0
            p = np.zeros(self.d)
            self.current_theta = theta
            self.current_w = w
            self.current_p = p
            return theta, w, p

        # Force a solve if none has happened yet. The warm-start branch leaves a
        # one-hot, zero-price placeholder in current_w/current_p that must not
        # be reused for up to solve_frequency rounds.
        needs_solve = (
            self.last_solve_time < 0
            or self.current_w is None
            or t - self.last_solve_time >= self.solve_frequency
        )
        if needs_solve:
            w, p = self._solve_saddle_point(t)
            self.current_w = w
            self.current_p = p
            self.last_solve_time = t
        else:
            w = self.current_w
            p = self.current_p

        theta = int(np.random.choice(self.K, p=w))

        self.current_theta = theta
        return theta, w, p

    def decide_admission(
        self,
        t: int,
        theta: int,
        r: float,
        a: np.ndarray,
        p: np.ndarray
    ) -> bool:
        """
        Accept using the global price.

        Rejects when any component of ``a`` exceeds ``B_remaining`` (1e-9
        slack). Otherwise accepts iff ``r >= <p, a>`` up to a 1e-6 tolerance.
        """
        if np.any(a > self.B_remaining + 1e-9):
            return False
        return r >= np.dot(p, a) - 1e-6

    def __repr__(self) -> str:
        return f"NoSlackSPUCBOLP(K={self.K}, d={self.d}, T={self.T}, alpha={self.alpha})"


class AcceptedOnlySPUCBOLP(BaseAlgorithm):
    """
    Accepted-Only Ablation: Estimate ĝ from ACCEPTED samples only.

    This ablation introduces SELECTION BIAS in surplus estimation.
    Instead of using all observed samples, it only uses accepted ones.

    Expected failure mode:
    - Overestimates surplus (accepted samples have r - <p,a> >= 0)
    - May incorrectly rank configs
    - Selection bias corrupts learning

    Confidence radii are likewise computed from accepted-sample counts.

    Config keys read here: ``alpha`` (default 0.1), ``beta_max`` (default
    ``10 * R_max``), ``n_restarts`` (default 2; stored but not referenced by
    this class) and ``warm_start`` (default True).
    """

    def __init__(
        self,
        K: int,
        d: int,
        T: int,
        B: np.ndarray,
        config: Dict[str, Any] = None
    ):
        super().__init__(K, d, T, B, config)

        self.alpha = self.config.get('alpha', 0.1)
        self.beta_max = self.config.get('beta_max', 10.0 * self.R_max)
        self.n_restarts = self.config.get('n_restarts', 2)  # not referenced by this class
        self.warm_start = self.config.get('warm_start', True)

        # Re-solve the saddle point every ~sqrt(T) rounds; -1 means "never solved".
        self.last_solve_time = -1
        self.solve_frequency = max(1, int(np.sqrt(T)))
        # Previous optimum, reused as an L-BFGS-B starting point.
        self._last_optimal_p: Optional[np.ndarray] = None

        # Accepted (r, a) pairs per config: the only data this ablation learns from.
        self.accepted_samples: Dict[int, List[Tuple[float, np.ndarray]]] = {
            theta: [] for theta in range(K)
        }

    def record_outcome(
        self,
        t: int,
        theta: int,
        r: float,
        a: np.ndarray,
        accept: bool
    ):
        """Record an outcome, keeping a separate log of accepted samples.

        The parent class records every sample. This override also stores
        accepted ``(r, a)`` pairs, which are the only data this ablation's
        surplus estimates and confidence radii use.
        """
        super().record_outcome(t, theta, r, a, accept)

        if accept:
            # Store a private float copy so later in-place changes to the
            # caller's array cannot corrupt the recorded sample.
            self.accepted_samples[theta].append((r, np.array(a, dtype=float)))

    def _accepted_arrays(self, theta: int) -> Tuple[np.ndarray, np.ndarray]:
        """Stack the accepted samples of ``theta`` into arrays.

        Returns:
            ``(rewards, consumptions)`` with shapes ``(n,)`` and ``(n, d)``
            (assuming every recorded ``a`` has length ``d``); both are empty
            when ``theta`` has no accepted samples.
        """
        accepted = self.accepted_samples[theta]
        if not accepted:
            return np.empty(0), np.empty((0, self.d))
        rewards = np.array([r for r, _ in accepted], dtype=float)
        consumptions = np.vstack([a for _, a in accepted])
        return rewards, consumptions

    @staticmethod
    def _mean_positive_surplus(
        rewards: np.ndarray,
        consumptions: np.ndarray,
        p: np.ndarray,
        empty_value: float
    ) -> float:
        """Return mean of ``max(r - <p, a>, 0)``, or ``empty_value`` with no samples."""
        if rewards.size == 0:
            return empty_value
        surplus = rewards - consumptions @ p
        return float(np.mean(np.maximum(surplus, 0.0)))

    def _compute_empirical_surplus_accepted(self, theta: int, p: np.ndarray) -> float:
        """Compute surplus using only ACCEPTED samples (biased).

        Returns the mean positive surplus ``max(r - <p, a>, 0)`` over the
        accepted samples of ``theta``. With no accepted samples it returns
        ``R_max`` as an optimistic value.
        """
        rewards, consumptions = self._accepted_arrays(theta)
        return self._mean_positive_surplus(rewards, consumptions, p, self.R_max)

    def _compute_confidence_radii(self, t: int) -> np.ndarray:
        """Compute per-config confidence radii from accepted sample counts.

        ``beta[theta] = min(alpha * R_max * sqrt((d*log(t+1) + log(K*T)) / n), beta_max)``
        where ``n`` is the number of accepted samples for ``theta``.
        ``beta_max`` is used when ``n == 0``.
        """
        beta = np.zeros(self.K)
        for theta in range(self.K):
            n_accepted = len(self.accepted_samples[theta])
            if n_accepted == 0:
                beta[theta] = self.beta_max
            else:
                # Use accepted sample count for confidence
                raw_beta = self.alpha * self.R_max * np.sqrt(
                    (self.d * np.log(t + 1) + np.log(self.K * self.T)) / n_accepted
                )
                beta[theta] = min(raw_beta, self.beta_max)
        return beta

    def _solve_saddle_point(self, t: int) -> Tuple[np.ndarray, np.ndarray]:
        """Solve the saddle point using accepted-only surplus estimates.

        Minimizes ``<p, b_safe> + max_theta (g_hat_theta(p) + beta_theta)`` over
        ``p`` in ``[0, P_max]^d`` with L-BFGS-B. The search starts from the
        previous optimum (if any) and from zero. The returned mixture spreads
        uniform weight over the configs whose UCB is within 1e-8 of the
        maximum at the chosen price.

        Returns:
            ``(w, p)``: mixture weights over the K configs and the price vector.
        """
        beta = self._compute_confidence_radii(t)
        # The accepted samples do not change during the solve, so stack them
        # once instead of re-walking Python lists on every objective evaluation.
        samples = [self._accepted_arrays(theta) for theta in range(self.K)]

        def ucb_values(p: np.ndarray) -> np.ndarray:
            # KEY: accepted-only surplus estimate plus confidence bonus.
            return np.array([
                self._mean_positive_surplus(rewards, consumptions, p, self.R_max)
                + beta[theta]
                for theta, (rewards, consumptions) in enumerate(samples)
            ])

        def objective(p: np.ndarray) -> float:
            return float(np.dot(p, self.b_safe) + np.max(ucb_values(p)))

        best_val = np.inf
        best_p = np.zeros(self.d)

        initial_points = []
        if self._last_optimal_p is not None:
            initial_points.append(self._last_optimal_p)
        initial_points.append(np.zeros(self.d))

        for p_init in initial_points:
            result = minimize(
                objective,
                p_init,
                method='L-BFGS-B',
                bounds=[(0, self.P_max)] * self.d,
                options={'maxiter': 100, 'ftol': 1e-6}
            )
            if result.fun < best_val:
                best_val = result.fun
                best_p = result.x.copy()

        self._last_optimal_p = best_p.copy()

        # Mixture: uniform over configs attaining the UCB envelope at best_p.
        ucb = ucb_values(best_p)
        w = (ucb >= np.max(ucb) - 1e-8).astype(float)
        w /= w.sum()

        return w, best_p

    def select_config(self, t: int) -> Tuple[int, np.ndarray, np.ndarray]:
        """Select config (same schedule as SP-UCB-OLP).

        With ``warm_start`` enabled, rounds ``t < K`` play config ``t`` once
        at a zero price. Afterwards the saddle point is solved on the first
        post-warm-start round and then every ``solve_frequency`` rounds.
        The config is sampled from the resulting mixture.

        Returns:
            ``(theta, w, p)``: sampled config, mixture weights, price vector.
        """
        self.t = t

        if self.warm_start and t < self.K:
            theta = t % self.K
            w = np.zeros(self.K)
            w[theta] = 1.0
            p = np.zeros(self.d)
            self.current_theta = theta
            self.current_w = w
            self.current_p = p
            return theta, w, p

        # Warm start fills current_w with a one-hot placeholder, so also force
        # a solve while no saddle point has been computed yet (last_solve_time < 0).
        needs_solve = (
            self.current_w is None
            or self.last_solve_time < 0
            or t - self.last_solve_time >= self.solve_frequency
        )
        if needs_solve:
            w, p = self._solve_saddle_point(t)
            self.current_w = w
            self.current_p = p
            self.last_solve_time = t
        else:
            w = self.current_w
            p = self.current_p

        theta = np.random.choice(self.K, p=w)

        self.current_theta = theta
        return theta, w, p

    def decide_admission(
        self,
        t: int,
        theta: int,
        r: float,
        a: np.ndarray,
        p: np.ndarray
    ) -> bool:
        """Accept using global price.

        A request is rejected if it exceeds the remaining budget in any
        coordinate. Otherwise it is accepted when ``r >= <p, a>`` up to a
        1e-6 slack.
        """
        if np.any(a > self.B_remaining + 1e-9):
            return False
        return r >= np.dot(p, a) - 1e-6

    def __repr__(self) -> str:
        return f"AcceptedOnlySPUCBOLP(K={self.K}, d={self.d}, T={self.T}, alpha={self.alpha})"
