"""
S2: Deceptive Arms (Efficiency Traps) - Redesigned

This scenario tests the algorithm's ability to discover EFFICIENT configs
when some configs appear attractive but are actually efficiency traps.

Key Design (Why Exploration Matters):
- HIGH-REWARD configs have HIGH COST and HIGH VARIANCE
- EFFICIENT configs have MODERATE REWARD but LOW COST and LOW VARIANCE
- Key: Efficiency differences are SUBTLE - not immediately obvious
- Traps can look like the best choice with limited samples
- Exploration is needed to accurately estimate efficiency

Setup (K=4, d=3):
- Config 0 (Flashy Trap): r=1.8±0.5, total_cost=1.4 → eff ~1.29
  - Looks great initially due to high reward
  - But efficiency is actually poor due to high cost
- Config 1 (Subtle Trap): r=1.4±0.4, total_cost=1.1 → eff ~1.27
  - Another trap with deceptive appearance
- Config 2 (Hidden Efficient): r=1.0±0.1, total_cost=0.6 → eff ~1.67 (OPTIMAL)
  - Lower reward looks unimpressive
  - But very high efficiency due to low cost
- Config 3 (Medium): r=1.1±0.25, total_cost=0.8 → eff ~1.38
  - Mediocre config

Budget: Tight - efficiency differences matter over long horizon.

The insight:
- Greedy attracted to high-reward configs (0, 1)
- But these drain budget quickly due to high costs
- UCB exploration discovers high-efficiency config 2
"""

import numpy as np
from typing import Dict, Any, Tuple
from .base_loader import BaseDataLoader


class S2NoisyLoader(BaseDataLoader):
    """
    S2: Deceptive Arms with efficiency-reward tradeoff.

    Four hand-tuned config profiles are built in: two "traps" with high mean
    reward but high cost and high variance (configs 0 and 1), one efficient
    config with modest reward and low cost (config 2), and one in-between
    config (config 3). For ``K > 4`` additional trap / mediocre profiles are
    drawn pseudo-randomly from ``seed``.

    Every arrival is stored as a row ``[reward, consumption_1..consumption_d]``;
    both parts are sampled from independent normals and clipped below at 0.01.

    Parameters
    ----------
    K : int
        Number of configs (default: 4)
    T : int
        Time horizon
    seed : int
        Random seed (drives extra profiles, arrivals and oracle restarts)
    noise_std : float
        Base noise level (default: 0.1, but profiles override this). It is
        stored on the instance and not otherwise used by this class.
    d : int
        Number of resource dimensions (default: 3). The built-in profiles are
        3-dimensional, so any other value is rejected.

    Raises
    ------
    ValueError
        If ``d`` does not match the dimensionality of the built-in profiles.
    """

    # Upper bound of the dual-price box [0, _P_MAX]^d searched by the oracles.
    _P_MAX = 5.0
    # Number of L-BFGS-B starting points used when solving V^mix.
    _N_RESTARTS = 5

    def __init__(
        self,
        K: int = 4,
        T: int = 10000,
        seed: int = 42,
        noise_std: float = 0.1,
        d: int = 3,
    ):
        super().__init__(K, d, T, seed)

        self.noise_std = noise_std

        # Config profiles: Deceptive efficiency traps
        # Key: Efficiency differences are subtle (~1.27 vs ~1.67)
        # Variance is higher on traps, making them look better early

        self._config_profiles = {
            # Config 0 (Flashy Trap): HIGH reward, HIGH cost, HIGH variance
            # eff = 1.8 / 1.4 ≈ 1.29 - looks great but inefficient
            0: {
                'reward': 1.8, 'reward_std': 0.5,  # HIGH variance
                'consumption': np.array([0.50, 0.48, 0.42]),  # total ~1.4
                'consumption_std': np.array([0.15, 0.14, 0.12]),  # HIGH variance
            },

            # Config 1 (Subtle Trap): Medium-high reward, medium-high cost
            # eff = 1.4 / 1.1 ≈ 1.27 - another trap
            1: {
                'reward': 1.4, 'reward_std': 0.4,  # HIGH variance
                'consumption': np.array([0.40, 0.38, 0.32]),  # total ~1.1
                'consumption_std': np.array([0.12, 0.11, 0.09]),  # HIGH variance
            },

            # Config 2 (Hidden Efficient): LOW reward but VERY LOW cost
            # eff = 1.0 / 0.6 ≈ 1.67 - TRUE OPTIMAL but doesn't look impressive
            2: {
                'reward': 1.0, 'reward_std': 0.1,  # LOW variance
                'consumption': np.array([0.22, 0.20, 0.18]),  # total ~0.6
                'consumption_std': np.array([0.03, 0.03, 0.02]),  # LOW variance
            },

            # Config 3 (Medium): In-between efficiency
            # eff = 1.1 / 0.8 ≈ 1.38
            3: {
                'reward': 1.1, 'reward_std': 0.25,
                'consumption': np.array([0.28, 0.26, 0.26]),  # total ~0.8
                'consumption_std': np.array([0.07, 0.07, 0.06]),
            },
        }

        # The hand-tuned profiles above fix the resource dimensionality. Fail
        # early with a readable message instead of letting the arrival
        # generator hit a numpy broadcasting error on the first sample.
        base_dim = self._config_profiles[0]['consumption'].shape[0]
        if self.d != base_dim:
            raise ValueError(
                f"S2NoisyLoader built-in profiles are {base_dim}-dimensional; "
                f"got d={self.d}"
            )

        # For K > 4, add more configs
        if K > 4:
            rng = np.random.RandomState(seed)
            for k in range(4, K):
                # Mix of traps and mediocre configs
                if k % 2 == 0:
                    # More traps
                    r = 1.3 + 0.4 * rng.random()
                    total_c = r / (1.1 + 0.3 * rng.random())  # Low efficiency
                    r_std = 0.3 + 0.2 * rng.random()
                else:
                    # Mediocre configs
                    r = 0.8 + 0.3 * rng.random()
                    total_c = r / (1.3 + 0.2 * rng.random())  # Medium efficiency
                    r_std = 0.15 + 0.1 * rng.random()

                # Split the total cost across resources with random proportions.
                c = rng.dirichlet(np.ones(self.d)) * total_c
                self._config_profiles[k] = {
                    'reward': r,
                    'reward_std': r_std,
                    'consumption': c,
                    'consumption_std': 0.1 * c,  # Proportional noise
                }

        self._generate_arrivals()
        self._compute_nominal_budget()

    def _generate_arrivals(self):
        """
        Generate stochastic arrivals with d-dimensional consumption.

        For each config ``theta < K`` a ``(T, d + 1)`` array is stored in
        ``self._arrivals[theta]``: column 0 holds the reward, columns 1..d the
        per-resource consumption. A single generator seeded with ``self.seed``
        is used for all configs, so the draw order below determines the data.
        """
        rng = np.random.RandomState(self.seed)

        for theta in range(self.K):
            profile = self._config_profiles[theta]
            arrivals = np.zeros((self.T, self.d + 1))

            # NOTE: kept as an explicit per-step loop on purpose; the
            # interleaved reward/consumption draws define the seeded stream.
            for t in range(self.T):
                # Reward with config-specific variance
                r = rng.normal(profile['reward'], profile['reward_std'])
                r = max(r, 0.01)  # keep rewards strictly positive

                # d-dimensional consumption with config-specific variance
                a = rng.normal(profile['consumption'], profile['consumption_std'])
                a = np.maximum(a, 0.01)  # keep consumption strictly positive

                arrivals[t, 0] = r
                arrivals[t, 1:] = a

            self._arrivals[theta] = arrivals

    def _compute_nominal_budget(self):
        """
        Compute nominal budget that makes efficiency matter.

        The budget is 70% of what the efficient config (config 2) consumes on
        average over the whole horizon, per resource. The result is stored in
        ``self._nominal_budget``.
        """
        efficient_consumption = self._config_profiles[2]['consumption']
        self._nominal_budget = efficient_consumption * self.T * 0.7

    def get_arrival(self, theta: int, t: int) -> Tuple[float, np.ndarray]:
        """
        Return the pre-generated arrival for config ``theta`` at step ``t``.

        Returns
        -------
        (reward, consumption)
            ``reward`` as a Python float and ``consumption`` as a fresh copy
            of the consumption vector (safe for the caller to mutate).

        Raises
        ------
        ValueError
            If ``theta`` is outside ``[0, K-1]`` or ``t`` outside ``[0, T-1]``.
        """
        if theta < 0 or theta >= self.K:
            raise ValueError(f"theta {theta} out of range [0, {self.K-1}]")
        if t < 0 or t >= self.T:
            raise ValueError(f"t {t} out of range [0, {self.T-1}]")

        arrival = self._arrivals[theta][t]
        return float(arrival[0]), arrival[1:].copy()

    def get_budget(self, rho: float) -> np.ndarray:
        """Return the nominal budget vector scaled by ``rho``."""
        return rho * self._nominal_budget

    def _compute_surplus(self, theta: int, p: np.ndarray, n_samples: int = 10000) -> float:
        """
        Compute g_θ(p) = E[(r - <p, a>)_+] using closed-form for normal distributions.

        The formula treats reward and consumption as the unclipped,
        independent normals described by the profile; the 0.01 floor applied
        when sampling arrivals is ignored here.

        Parameters
        ----------
        theta : int
            Config index into the profile table.
        p : np.ndarray
            Price (dual) vector, one entry per resource.
        n_samples : int
            Unused; kept for interface compatibility.
        """
        from scipy.stats import norm

        profile = self._config_profiles[theta]

        # Mean of surplus s = r - <p, a>
        mu_s = profile['reward'] - np.dot(p, profile['consumption'])

        # Variance: σ_r² + Σ_j p_j² σ_{a_j}²
        var_s = profile['reward_std']**2 + np.sum((p * profile['consumption_std'])**2)
        sigma_s = np.sqrt(var_s)

        # Degenerate (deterministic) surplus: positive part of the mean.
        if sigma_s < 1e-10:
            return max(mu_s, 0.0)

        # E[s_+] = μ Φ(μ/σ) + σ φ(μ/σ) for s ~ N(μ, σ²)
        z = mu_s / sigma_s
        return mu_s * norm.cdf(z) + sigma_s * norm.pdf(z)

    def _solve_V_mix(self, b: np.ndarray, n_grid: int = 20) -> Tuple[float, np.ndarray, np.ndarray]:
        """
        Solve V^mix(b) = min_p { <p, b> + max_θ g_θ(p) } via optimization.

        Runs L-BFGS-B from several starting points inside ``[0, _P_MAX]^d``
        and keeps the best result. The random starting points come from a
        generator seeded with ``self.seed``, so repeated calls return the same
        answer.

        Parameters
        ----------
        b : np.ndarray
            Per-period budget vector.
        n_grid : int
            Unused; kept for interface compatibility.

        Returns
        -------
        (value, p_star, w_star)
            Best objective value found, the price vector attaining it, and a
            uniform weight vector over the configs whose surplus is within
            1e-8 of the maximum at ``p_star``.
        """
        from scipy.optimize import minimize

        p_max = self._P_MAX
        # Local, seeded generator: the restarts must not depend on (or
        # perturb) the process-wide numpy random state.
        rng = np.random.RandomState(self.seed)

        def objective(p):
            surpluses = np.array([self._compute_surplus(k, p) for k in range(self.K)])
            envelope = np.max(surpluses)
            return np.dot(p, b) + envelope

        best_val = np.inf
        best_p = np.zeros(self.d)

        for restart in range(self._N_RESTARTS):
            if restart == 0:
                # Always try the origin first (zero prices).
                p_init = np.zeros(self.d)
            else:
                p_init = rng.uniform(0, p_max / 2, self.d)

            result = minimize(
                objective,
                p_init,
                method='L-BFGS-B',
                bounds=[(0, p_max)] * self.d,
                options={'maxiter': 100, 'ftol': 1e-6}
            )

            if result.fun < best_val:
                best_val = result.fun
                best_p = result.x.copy()

        # Compute mixture achieving envelope at optimal price
        best_surpluses = np.array([self._compute_surplus(k, best_p) for k in range(self.K)])
        max_surplus = np.max(best_surpluses)
        w_star = (best_surpluses >= max_surplus - 1e-8).astype(float)
        w_star /= w_star.sum()

        return best_val, best_p, w_star

    def _efficiency_summary(self) -> Dict[str, Any]:
        """
        Return the static design facts shared by the oracle and the metadata.

        ``efficiencies`` maps every profile index to mean reward divided by
        total mean consumption; the trap / efficient index lists refer to the
        built-in profiles.
        """
        efficiencies = {}
        for k, profile in self._config_profiles.items():
            total_cost = np.sum(profile['consumption'])
            efficiencies[k] = profile['reward'] / total_cost

        return {
            'trap_configs': [0, 1],  # High reward but inefficient
            'efficient_configs': [2],  # True optimal
            'efficiencies': efficiencies,
        }

    def get_oracle_values(self, rho: float = 1.0) -> Dict[str, Any]:
        """
        Compute theoretical oracle values.

        Parameters
        ----------
        rho : float
            Budget scaling factor applied to the nominal budget.

        Returns
        -------
        dict
            ``V_mix`` / ``p_star`` / ``w_star`` from ``_solve_V_mix``;
            ``V_star`` and ``best_fixed_config``, the largest single-config
            value ``min_p { <p, b> + g_θ(p) }`` and the config attaining it;
            ``gap = V_mix - V_star``; the efficiency summary; and the inputs
            ``rho`` and ``b_per_period``.
        """
        from scipy.optimize import minimize

        # Per-period budget.
        b = rho * self._nominal_budget / self.T

        V_mix, p_star, w_star = self._solve_V_mix(b)

        # Compute V*
        V_star = 0.0
        best_theta = 0
        p_max = self._P_MAX

        for theta in range(self.K):
            # `th=theta` binds the current config at definition time.
            def obj_theta(p, th=theta):
                g = self._compute_surplus(th, p)
                return np.dot(p, b) + g

            result = minimize(
                obj_theta,
                np.zeros(self.d),
                method='L-BFGS-B',
                bounds=[(0, p_max)] * self.d,
                options={'maxiter': 100}
            )

            if result.fun > V_star:
                V_star = result.fun
                best_theta = theta

        summary = self._efficiency_summary()

        return {
            'V_mix': V_mix,
            'V_star': V_star,
            'gap': V_mix - V_star,
            'w_star': w_star,
            'p_star': p_star,
            'trap_configs': summary['trap_configs'],
            'efficient_configs': summary['efficient_configs'],
            'efficiencies': summary['efficiencies'],
            'best_fixed_config': best_theta,
            'rho': rho,
            'b_per_period': b,
        }

    def get_metadata(self) -> Dict[str, Any]:
        """
        Get metadata about S2.

        Extends the base-class metadata with the scenario family, name and
        the static efficiency summary. No optimisation is run here.
        """
        base = super().get_metadata()
        summary = self._efficiency_summary()

        return {
            **base,
            'family': 'S2',
            'name': 'Deceptive Arms (Efficiency Traps)',
            'deterministic': False,
            'efficiencies': summary['efficiencies'],
            'trap_configs': summary['trap_configs'],
            'efficient_configs': summary['efficient_configs'],
            'design': 'High-reward traps vs low-cost efficient configs',
        }

    def compute_deceptiveness_analysis(self) -> Dict[str, Any]:
        """
        Analyze how deceptive the traps are.

        Computed from the generated arrivals (not the profile means), so the
        numbers reflect the clipped samples an algorithm actually observes.

        Returns
        -------
        dict
            One entry ``'config_<theta>'`` per config with the mean, standard
            deviation and 95th percentile of the reward, and the mean,
            standard deviation, 5th and 95th percentile of the per-arrival
            efficiency (reward / total consumption).
        """
        stats = {}
        for theta in range(self.K):
            arrivals = self._arrivals[theta]
            rewards = arrivals[:, 0]
            consumptions = arrivals[:, 1:]
            total_consumption = np.sum(consumptions, axis=1)
            efficiencies = rewards / total_consumption

            stats[f'config_{theta}'] = {
                'mean_reward': float(np.mean(rewards)),
                'std_reward': float(np.std(rewards)),
                'reward_95th_pct': float(np.percentile(rewards, 95)),
                'mean_efficiency': float(np.mean(efficiencies)),
                'std_efficiency': float(np.std(efficiencies)),
                'efficiency_5th_pct': float(np.percentile(efficiencies, 5)),
                'efficiency_95th_pct': float(np.percentile(efficiencies, 95)),
            }

        return stats


def test_s2_loader():
    """Smoke-test the S2 loader by printing its profiles, stats and oracle values."""
    banner = "=" * 60
    print(banner)
    print("Testing S2: Deceptive Arms (Efficiency Traps) (K=4, d=3)")
    print(banner)

    loader = S2NoisyLoader(K=4, T=5000, seed=42, d=3)

    print(f"\nDimensions: K={loader.K}, d={loader.d}, T={loader.T}")

    # Nominal (profile-level) reward, cost and efficiency of every config.
    print("\nConfig Profiles:")
    for idx in range(loader.K):
        cfg = loader._config_profiles[idx]
        mean_r = cfg['reward']
        spread_r = cfg['reward_std']
        cost = np.sum(cfg['consumption'])
        ratio = mean_r / cost
        print(f"  Config {idx}: r={mean_r:.2f}±{spread_r:.2f}, "
              f"cost={cost:.2f}, eff≈{ratio:.2f}")

    # Empirical statistics computed from the sampled arrivals.
    print("\nDeceptiveness Analysis:")
    analysis = loader.compute_deceptiveness_analysis()
    for label in analysis:
        entry = analysis[label]
        print(f"  {label}:")
        print(f"    Reward: {entry['mean_reward']:.2f}±{entry['std_reward']:.2f} (95th: {entry['reward_95th_pct']:.2f})")
        print(f"    Efficiency: {entry['mean_efficiency']:.2f}±{entry['std_efficiency']:.2f}")

    print("\nOracle Values:")
    oracle = loader.get_oracle_values()
    print(f"  V^mix ≈ {oracle['V_mix']:.4f}")
    print(f"  V* ≈ {oracle['V_star']:.4f}")
    print(f"  Efficiencies: {oracle['efficiencies']}")
    print(f"  Trap configs: {oracle['trap_configs']}")
    print(f"  Efficient config: {oracle['efficient_configs']}")

    print(f"\nBudget (rho=1.0): {loader.get_budget(1.0)}")

    print("\nKey Insight:")
    for insight in (
        "  - Traps (0,1) have HIGH rewards but LOW efficiency",
        "  - Efficient config (2) has LOWER reward but HIGH efficiency",
        "  - Greedy attracted to high rewards, drains budget on traps",
        "  - UCB exploration discovers efficient config",
    ):
        print(insight)

    print(f"\nValidation: {loader.validate()}")
    print("\n" + banner)

# Run the smoke test when this file is executed as a script.
# NOTE(review): this module uses a relative import (`.base_loader`), so it
# presumably has to be launched as a package module (`python -m ...`) for the
# import to resolve — confirm against the project layout.
if __name__ == "__main__":
    test_s2_loader()
