"""
S3: Selective Admission Data Loader

This scenario tests the algorithm's ADMISSION CONTROL capability.
Arrivals have variable quality, and the algorithm must learn
when to accept (high quality) vs reject (low quality).

Design Principles:
- Arrivals have variable reward-to-cost ratios
- Budget is TIGHT - can't accept everything
- Must learn optimal price threshold for admission
- Random accepts everything → runs out of budget on bad arrivals
- SP-UCB-OLP learns to reject low-quality arrivals

Setup (defaults: K=2, d=3; consumption is drawn per resource dimension):
- Config 0: r ~ Uniform(0.5, 1.5), a ~ Uniform([0.3, 0.25, 0.25], [0.5, 0.4, 0.3])
  → Variable efficiency, some arrivals good, some bad
- Config 1: r ~ Uniform(0.3, 1.2), a ~ Uniform([0.2, 0.15, 0.15], [0.6, 0.5, 0.4])
  → More variable, requires selective admission

Budget: Very tight - can only accept ~50% of arrivals

Key insight:
- Optimal policy: set price p > 0 to filter bad arrivals
- Accept only when r >= p*a (positive surplus)
- Random accepts all → wastes budget on r < p*a arrivals
- Greedy may not learn optimal price quickly
- SP-UCB-OLP should learn price from saddle-point optimization
"""

import numpy as np
from typing import Dict, Any, Tuple
from .base_loader import BaseDataLoader


class S3DominantLoader(BaseDataLoader):
    """
    S3: Selective Admission with variable arrival quality.

    For every config ``theta`` the loader pre-generates ``T`` arrivals whose
    reward and per-resource consumption are drawn uniformly from that
    config's profile, derives a tight nominal budget from the expected
    consumption, and caches Monte Carlo samples that the oracle
    computations reuse.

    Parameters
    ----------
    K : int
        Number of configs (default: 2). Configs 0 and 1 use fixed built-in
        profiles; configs 2..K-1 get profiles drawn from ``seed``.
    T : int
        Time horizon
    seed : int
        Random seed
    budget_tightness : float
        How tight the budget is (lower = tighter). Default 0.5 means
        can only afford ~50% of arrivals if accepting all.
    d : int
        Number of resource dimensions (default: 3). The built-in profiles
        are specified for d=3; for any other ``d`` their consumption bounds
        are cyclically repeated / truncated to length ``d``.
    """

    # Upper bound on every price coordinate searched by the oracle solvers.
    _P_MAX = 5.0

    def __init__(
        self,
        K: int = 2,
        T: int = 10000,
        seed: int = 42,
        budget_tightness: float = 0.5,
        d: int = 3,  # Number of resource dimensions
    ):
        # NOTE(review): the rest of this class relies on the base initializer
        # providing self.K, self.d, self.T, self.seed and the self._arrivals
        # container — confirm against BaseDataLoader.
        super().__init__(K, d, T, seed)

        self.budget_tightness = budget_tightness

        def fit_to_d(bounds):
            # The reference bounds are 3-dimensional. np.resize repeats them
            # cyclically (or truncates) so any d works; for d == 3 the values
            # are returned unchanged.
            return np.resize(np.asarray(bounds, dtype=float), self.d)

        # Config profiles define the distribution of arrivals.
        # Consumption is a vector, uniform over a range per dimension.
        # Higher variance means more value from selective admission.
        self._config_profiles = {
            0: {
                'reward_low': 0.5, 'reward_high': 1.5,
                'consumption_low': fit_to_d([0.3, 0.25, 0.25]),
                'consumption_high': fit_to_d([0.5, 0.4, 0.3]),
            },
            1: {
                'reward_low': 0.3, 'reward_high': 1.2,
                'consumption_low': fit_to_d([0.2, 0.15, 0.15]),
                'consumption_high': fit_to_d([0.6, 0.5, 0.4]),
            },
        }

        # Extend for K > 2 with seeded random profiles. The draw order below
        # (reward_low, reward_high, consumption_low, consumption_high) fixes
        # the random stream and must not be reordered.
        if K > 2:
            rng = np.random.RandomState(seed)
            for k in range(2, K):
                self._config_profiles[k] = {
                    'reward_low': 0.2 + 0.3 * rng.random(),
                    'reward_high': 0.8 + 0.5 * rng.random(),
                    'consumption_low': 0.2 + 0.3 * rng.random(self.d),
                    'consumption_high': 0.6 + 0.4 * rng.random(self.d),
                }

        self._generate_arrivals()
        self._compute_nominal_budget()

        # Pre-generate Monte Carlo samples for surplus computation (cache for reuse)
        self._mc_samples = {}
        self._generate_mc_samples()

    def _generate_arrivals(self):
        """
        Generate arrivals with variable quality and d-dimensional consumption.

        Fills ``self._arrivals[theta]`` with a ``(T, d + 1)`` array per config:
        column 0 holds the reward, columns 1..d the consumption vector.
        """
        rng = np.random.RandomState(self.seed)

        for theta in range(self.K):
            profile = self._config_profiles[theta]
            arrivals = np.zeros((self.T, self.d + 1))

            # One reward draw followed by one consumption draw per time step.
            # This interleaving determines the seeded stream, so it is kept
            # as-is rather than vectorised.
            for t in range(self.T):
                arrivals[t, 0] = rng.uniform(
                    profile['reward_low'], profile['reward_high']
                )
                arrivals[t, 1:] = rng.uniform(
                    profile['consumption_low'], profile['consumption_high']
                )

            self._arrivals[theta] = arrivals

    def _compute_nominal_budget(self):
        """
        Compute tight budget that forces selective admission.

        Sets ``self._nominal_budget`` (length-d vector) to
        ``budget_tightness * T * avg_consumption`` where ``avg_consumption``
        is the mean over configs of each profile's expected consumption.
        """
        # Expected consumption of a uniform draw is the midpoint of its range.
        total_expected_consumption = np.zeros(self.d)
        for theta in range(self.K):
            profile = self._config_profiles[theta]
            total_expected_consumption += (
                profile['consumption_low'] + profile['consumption_high']
            ) / 2
        avg_consumption = total_expected_consumption / self.K

        # tightness=0.5 means budget for ~50% of arrivals
        self._nominal_budget = self.budget_tightness * avg_consumption * self.T

    def _generate_mc_samples(self, n_samples: int = 10000):
        """
        Pre-generate Monte Carlo samples for all configs (cached for reuse).

        Stores ``(r_samples, a_samples)`` per config in ``self._mc_samples``
        with shapes ``(n_samples,)`` and ``(n_samples, d)``.
        """
        rng = np.random.RandomState(42)  # Fixed seed for reproducibility

        for theta in range(self.K):
            profile = self._config_profiles[theta]

            r_samples = rng.uniform(
                profile['reward_low'], profile['reward_high'], n_samples
            )
            a_samples = rng.uniform(
                profile['consumption_low'],
                profile['consumption_high'],
                (n_samples, self.d),
            )

            self._mc_samples[theta] = (r_samples, a_samples)

    def get_arrival(self, theta: int, t: int) -> Tuple[float, np.ndarray]:
        """
        Return the pre-generated arrival of config ``theta`` at time ``t``.

        Returns
        -------
        (reward, consumption) where consumption is a fresh length-d array.

        Raises
        ------
        ValueError
            If ``theta`` is outside ``[0, K-1]`` or ``t`` outside ``[0, T-1]``.
        """
        if theta < 0 or theta >= self.K:
            raise ValueError(f"theta {theta} out of range [0, {self.K-1}]")
        if t < 0 or t >= self.T:
            raise ValueError(f"t {t} out of range [0, {self.T-1}]")

        arrival = self._arrivals[theta][t]
        # Copy so callers cannot mutate the stored arrivals through the view.
        return float(arrival[0]), arrival[1:].copy()

    def get_budget(self, rho: float) -> np.ndarray:
        """Get budget scaled by rho."""
        return rho * self._nominal_budget

    def _compute_surplus(self, theta: int, p: np.ndarray, n_samples: int = 10000) -> float:
        """
        Compute g_θ(p) = E[(r - <p, a>)_+] using cached Monte Carlo samples.

        Uses pre-generated samples for efficiency (avoids regenerating on each
        call). ``n_samples`` is not used: the estimate always averages over
        the whole cache built at init. The parameter is retained only for
        backward compatibility.
        """
        r_samples, a_samples = self._mc_samples[theta]

        # Safeguard: clip prices to valid range and handle numerical issues
        p_safe = np.clip(
            np.nan_to_num(p, nan=0.0, posinf=100.0, neginf=0.0), 0, 100.0
        )

        # Surplus: (r - <p, a>)_+
        cost = a_samples @ p_safe  # (n_samples,)
        surplus = np.maximum(r_samples - cost, 0)
        return float(np.mean(surplus))

    def _solve_V_mix(self, b: np.ndarray, n_grid: int = 20) -> Tuple[float, np.ndarray, np.ndarray]:
        """
        Solve V^mix(b) = min_p { <p, b> + max_θ g_θ(p) } via optimization.

        Runs a multi-start L-BFGS-B search over ``p`` in ``[0, _P_MAX]^d``.
        Restart points come from a generator seeded with ``self.seed`` so
        repeated calls give the same result. ``n_grid`` is unused and kept
        only for backward compatibility.

        Returns
        -------
        (best_val, best_p, w_star) where ``w_star`` puts uniform weight on the
        configs whose surplus attains the envelope at ``best_p``.
        """
        from scipy.optimize import minimize

        p_max = self._P_MAX

        def objective(p):
            """Objective: <p, b> + max_θ g_θ(p)"""
            surpluses = np.array([self._compute_surplus(k, p) for k in range(self.K)])
            return np.dot(p, b) + np.max(surpluses)

        # Local generator: keeps the oracle deterministic and leaves the
        # global numpy random state untouched.
        restart_rng = np.random.RandomState(self.seed)

        # Multi-start optimization
        best_val = np.inf
        best_p = np.zeros(self.d)

        n_restarts = 5
        for restart in range(n_restarts):
            if restart == 0:
                # Always try the zero-price (accept everything) start first.
                p_init = np.zeros(self.d)
            else:
                p_init = restart_rng.uniform(0, p_max / 2, self.d)

            result = minimize(
                objective,
                p_init,
                method='L-BFGS-B',
                bounds=[(0, p_max)] * self.d,
                options={'maxiter': 100, 'ftol': 1e-6},
            )

            if result.fun < best_val:
                best_val = float(result.fun)
                best_p = result.x.copy()

        # Compute mixture achieving envelope at optimal price
        best_surpluses = np.array([self._compute_surplus(k, best_p) for k in range(self.K)])
        max_surplus = np.max(best_surpluses)
        w_star = (best_surpluses >= max_surplus - 1e-8).astype(float)
        w_star /= w_star.sum()

        return best_val, best_p, w_star

    def get_oracle_values(self, rho: float = 1.0) -> Dict[str, Any]:
        """
        Compute theoretical oracle values.

        Parameters
        ----------
        rho : float
            Budget scaling factor (default: 1.0)

        Returns
        -------
        dict with keys ``V_mix``, ``V_star``, ``gap`` (= V_mix - V_star),
        ``w_star``, ``p_star``, ``best_fixed_config``, ``rho`` and
        ``b_per_period``.
        """
        from scipy.optimize import minimize

        # Per-period budget scaled by rho (d-dimensional)
        b = rho * self._nominal_budget / self.T

        # Solve switching-aware oracle
        V_mix, p_star, w_star = self._solve_V_mix(b)

        # Compute V* = max_θ min_p { <p, b> + g_θ(p) }.
        # Start from -inf so the running max is correct for any objective sign.
        V_star = -np.inf
        best_theta = 0
        p_max = self._P_MAX

        for theta in range(self.K):
            # Bind theta as a default to avoid the late-binding closure pitfall.
            def obj_theta(p, th=theta):
                return np.dot(p, b) + self._compute_surplus(th, p)

            result = minimize(
                obj_theta,
                np.zeros(self.d),
                method='L-BFGS-B',
                bounds=[(0, p_max)] * self.d,
                options={'maxiter': 100},
            )

            if result.fun > V_star:
                V_star = float(result.fun)
                best_theta = theta

        return {
            'V_mix': V_mix,
            'V_star': V_star,
            'gap': V_mix - V_star,
            'w_star': w_star,
            'p_star': p_star,
            'best_fixed_config': best_theta,
            'rho': rho,
            'b_per_period': b,
        }

    def compute_arrival_statistics(self) -> Dict[str, Any]:
        """
        Compute statistics about arrival quality distribution.

        Efficiency of an arrival is its reward divided by its total
        consumption summed over all resource dimensions. Returns one entry
        per config under the key ``config_<theta>``.
        """
        stats = {}

        for theta in range(self.K):
            arrivals = self._arrivals[theta]
            rewards = arrivals[:, 0]
            consumptions = arrivals[:, 1:]  # (T, d)
            total_consumption = np.sum(consumptions, axis=1)  # (T,)
            efficiencies = rewards / total_consumption

            stats[f'config_{theta}'] = {
                'mean_reward': float(np.mean(rewards)),
                # tolist() yields plain Python floats rather than numpy scalars.
                'mean_consumption': np.mean(consumptions, axis=0).tolist(),
                'mean_total_consumption': float(np.mean(total_consumption)),
                'mean_efficiency': float(np.mean(efficiencies)),
                'std_efficiency': float(np.std(efficiencies)),
                'min_efficiency': float(np.min(efficiencies)),
                'max_efficiency': float(np.max(efficiencies)),
                'pct_high_quality': float(np.mean(efficiencies > 1.0)),
            }

        return stats

    def get_metadata(self) -> Dict[str, Any]:
        """Get metadata about S3: the base metadata plus S3-specific fields."""
        base = super().get_metadata()

        return {
            **base,
            'family': 'S3',
            'name': 'Selective Admission',
            'budget_tightness': self.budget_tightness,
            'deterministic': False,
        }


def test_s3_loader():
    """Print a human-readable smoke report for a small S3 loader (K=2, d=3)."""
    rule = "=" * 60
    print(rule)
    print("Testing S3: Selective Admission Loader (d=3)")
    print(rule)

    loader = S3DominantLoader(K=2, T=1000, seed=42, budget_tightness=0.5, d=3)

    print(f"\nDimensions: K={loader.K}, d={loader.d}, T={loader.T}")

    # Sampling ranges each config draws its arrivals from.
    print("\nConfig Profiles:")
    for theta in range(loader.K):
        profile = loader._config_profiles[theta]
        print(f"  Config {theta}:")
        print(f"    Reward: [{profile['reward_low']:.2f}, {profile['reward_high']:.2f}]")
        print(f"    Consumption low: {profile['consumption_low']}")
        print(f"    Consumption high: {profile['consumption_high']}")

    # Empirical statistics; floats get three decimals, everything else is
    # printed as-is.
    print("\nArrival Statistics:")
    for config, config_stats in loader.compute_arrival_statistics().items():
        print(f"\n  {config}:")
        for key, value in config_stats.items():
            if not isinstance(value, float):
                print(f"    {key}: {value}")
                continue
            print(f"    {key}: {value:.3f}")

    # First three arrivals of every config with reward per unit of total
    # consumption.
    print("\nSample arrivals:")
    for theta in range(loader.K):
        print(f"  Config {theta}:")
        for t in range(3):
            r, a = loader.get_arrival(theta, t)
            eff = r / np.sum(a)
            print(f"    t={t}: r={r:.3f}, a={a}, eff={eff:.3f}")

    print("\nOracle Values:")
    oracle = loader.get_oracle_values()
    print(f"  V^mix ≈ {oracle['V_mix']:.4f}")
    print(f"  V* ≈ {oracle['V_star']:.4f}")
    print(f"  p* = {oracle['p_star']}")
    print(f"  w* = {oracle['w_star']}")

    full_budget = loader.get_budget(1.0)
    print(f"\nBudget (rho=1.0): {full_budget}")
    print(f"Per-period budget: {full_budget / loader.T}")

    print(f"\nValidation: {loader.validate()}")
    print("\n" + rule)


# Run the smoke report when the module is executed as a script.
# NOTE(review): the module uses a relative import (`.base_loader`), so it
# presumably has to be launched as part of its package (`python -m ...`)
# rather than by file path — confirm.
if __name__ == "__main__":
    test_s3_loader()
