"""
S1: Exploration-Critical Scenario (Redesigned)

This scenario tests SP-UCB-OLP's exploration advantage when configs have
SIMILAR expected efficiencies but different variances.

Key Design (Why Exploration Matters):
- The three contending configs (0, 1, 2) have similar expected efficiency
  (~1.4-1.6); the distractor (Config 3) is deliberately lower (~1.1)
- HIGH-VARIANCE configs can look excellent OR terrible by chance
- The TRUE OPTIMAL has slightly higher efficiency but LOWER variance
- Greedy gets "stuck" when a high-variance config gets lucky early samples
- UCB exploration maintains optimism and eventually discovers the true optimal

Setup (K=4, d=3):
- Config 0 (Trap A): High reward, high cost, HIGH VARIANCE
  - Mean efficiency ~1.4, but variance means range ~[0.8, 2.0]
  - If early samples lucky, looks like best config
- Config 1 (Hidden Gem): Moderate reward, low cost, LOW VARIANCE
  - Mean efficiency ~1.6 (truly optimal), low variance confirms it
  - Requires exploration to accumulate evidence
- Config 2 (Trap B): Medium reward, medium cost, HIGH VARIANCE
  - Similar efficiency to Trap A, can also look good by chance
- Config 3 (Distractor): Lower efficiency, medium variance

Budget: Tight - efficiency differences matter over long horizon.

The insight:
- After K rounds, greedy may commit to a "trap" that got lucky
- UCB maintains exploration and eventually samples Hidden Gem enough
- With enough samples, UCB shifts weight to Hidden Gem
"""

import numpy as np
from typing import Dict, Any, Tuple
from .base_loader import BaseDataLoader


class S1ComplementarityLoader(BaseDataLoader):
    """
    S1: Exploration-Critical Scenario with similar efficiencies.

    Four built-in configs are defined (see ``_config_profiles``):

    - 0 (Trap A) and 2 (Trap B): efficiency ~1.4, high variance
    - 1 (Hidden Gem): efficiency ~1.6, low variance
    - 3 (Distractor): efficiency ~1.12, medium variance

    For ``K > 4`` additional configs with randomly drawn parameters are
    appended. Arrivals for every config are pre-generated in the constructor.

    Parameters
    ----------
    K : int
        Number of configs (default: 4)
    T : int
        Time horizon
    seed : int
        Random seed
    d : int
        Number of resource dimensions (default: 3). The built-in profiles
        are hard-coded as 3-dimensional, so only ``d == 3`` is supported.

    Raises
    ------
    ValueError
        If ``d`` does not match the dimensionality of the built-in profiles.
    """

    # Every dual-price coordinate is restricted to [0, _P_MAX] in the oracle
    # optimizations (_solve_V_mix and get_oracle_values).
    _P_MAX = 5.0

    # Length of the hard-coded consumption vectors of configs 0-3.
    _PROFILE_DIM = 3

    def __init__(
        self,
        K: int = 4,
        T: int = 10000,
        seed: int = 42,
        d: int = 3,
    ):
        super().__init__(K, d, T, seed)

        # The built-in consumption vectors below have length 3. With any
        # other d the arrival generation would fail later with an opaque
        # NumPy broadcasting ValueError, so fail early with a clear message.
        if d != self._PROFILE_DIM:
            raise ValueError(
                f"S1 built-in config profiles are {self._PROFILE_DIM}-dimensional; "
                f"got d={d}"
            )

        # Config profiles designed so exploration is NECESSARY:
        # - Configs 0-2 have similar expected efficiency (~1.4-1.6);
        #   config 3 is a lower-efficiency distractor (~1.12)
        # - High-variance configs can look better OR worse by chance
        # - Low-variance optimal is only identifiable with sufficient samples

        self._config_profiles = {
            # Config 0 (Trap A): HIGH VARIANCE, similar efficiency
            # High reward but high cost - efficiency ~1.4
            # High variance means early samples could be anywhere in [0.8, 2.0]
            0: {
                'reward_mean': 1.4, 'reward_std': 0.6,  # HIGH reward variance
                'consumption_mean': np.array([0.35, 0.35, 0.30]),  # total ~1.0
                'consumption_std': np.array([0.15, 0.15, 0.12]),  # HIGH consumption variance
            },

            # Config 1 (Hidden Gem): LOW VARIANCE, slightly better efficiency
            # True optimal with efficiency ~1.6, but LOW variance means
            # it doesn't look amazing early - requires exploration to confirm
            1: {
                'reward_mean': 1.12, 'reward_std': 0.08,  # LOW variance
                'consumption_mean': np.array([0.25, 0.22, 0.23]),  # total ~0.7, eff ~1.6
                'consumption_std': np.array([0.02, 0.02, 0.02]),  # VERY LOW variance
            },

            # Config 2 (Trap B): HIGH VARIANCE, similar efficiency to Trap A
            # Another trap that can look great by chance
            2: {
                'reward_mean': 1.2, 'reward_std': 0.5,  # HIGH variance
                'consumption_mean': np.array([0.30, 0.28, 0.27]),  # total ~0.85, eff ~1.4
                'consumption_std': np.array([0.12, 0.12, 0.10]),  # HIGH variance
            },

            # Config 3 (Distractor): Medium efficiency, medium variance
            # Clearly worse but included for completeness
            3: {
                'reward_mean': 0.9, 'reward_std': 0.3,
                'consumption_mean': np.array([0.28, 0.26, 0.26]),  # total ~0.8, eff ~1.12
                'consumption_std': np.array([0.08, 0.08, 0.08]),
            },
        }

        # Extend for K > 4 with random configs (similar efficiency range).
        # NOTE: the order of the rng draws below determines the generated
        # profiles for a given seed; do not reorder them.
        if K > 4:
            rng = np.random.RandomState(seed)
            for k in range(4, K):
                # Random efficiency between 1.0 and 1.5
                eff = 1.0 + 0.5 * rng.random()
                # Random total consumption between 0.6 and 1.0
                total_consumption = 0.6 + 0.4 * rng.random()
                r_mean = eff * total_consumption

                self._config_profiles[k] = {
                    'reward_mean': r_mean,
                    'reward_std': 0.3 + 0.3 * rng.random(),  # Random variance
                    # Dirichlet weights sum to 1, so the means sum to total_consumption
                    'consumption_mean': rng.dirichlet(np.ones(self.d)) * total_consumption,
                    'consumption_std': 0.1 * np.ones(self.d),
                }

        self._generate_arrivals()
        self._compute_nominal_budget()

    def _generate_arrivals(self):
        """
        Generate i.i.d. arrivals from stationary distributions.

        For each config ``theta`` in ``range(K)`` a ``(T, d + 1)`` array is
        stored in ``self._arrivals[theta]``: column 0 holds the reward and
        columns ``1:`` hold the consumption vector. Both are drawn from
        independent normals and clipped from below at 0.01.
        """
        rng = np.random.RandomState(self.seed)

        for theta in range(self.K):
            profile = self._config_profiles[theta]
            arrivals = np.zeros((self.T, self.d + 1))

            # Samples are drawn one time step at a time (reward first, then
            # consumption) so the random stream for a given seed is fixed.
            for t in range(self.T):
                # i.i.d. samples from stationary distribution
                r = rng.normal(profile['reward_mean'], profile['reward_std'])
                r = max(r, 0.01)  # Ensure positive

                # d-dimensional consumption vector
                a = rng.normal(profile['consumption_mean'], profile['consumption_std'])
                a = np.maximum(a, 0.01)  # Ensure positive

                arrivals[t, 0] = r
                arrivals[t, 1:] = a

            # NOTE(review): self._arrivals is presumably initialised by
            # BaseDataLoader.__init__ -- confirm against the base class.
            self._arrivals[theta] = arrivals

    def _compute_nominal_budget(self):
        """
        Compute budget that makes efficiency differences matter.

        Budget is set tight enough that:
        - Using inefficient configs leads to budget exhaustion
        - Using efficient config allows near-full horizon acceptance

        Sets ``self._nominal_budget`` to ``0.7 * T`` times the mean
        consumption vector of config 1.
        """
        # Budget based on hidden gem (config 1) consumption
        # Allow ~70% acceptance rate with optimal config
        hidden_gem_consumption = self._config_profiles[1]['consumption_mean']
        self._nominal_budget = hidden_gem_consumption * self.T * 0.7

    def get_arrival(self, theta: int, t: int) -> Tuple[float, np.ndarray]:
        """
        Get the pre-generated arrival for config ``theta`` at time ``t``.

        Returns
        -------
        (reward, consumption)
            The reward as a Python float and a copy of the consumption vector.

        Raises
        ------
        ValueError
            If ``theta`` is outside ``[0, K-1]`` or ``t`` is outside ``[0, T-1]``.
        """
        if theta < 0 or theta >= self.K:
            raise ValueError(f"theta {theta} out of range [0, {self.K-1}]")
        if t < 0 or t >= self.T:
            raise ValueError(f"t {t} out of range [0, {self.T-1}]")

        arrival = self._arrivals[theta][t]
        # Copy so callers cannot mutate the stored arrivals through the view.
        return float(arrival[0]), arrival[1:].copy()

    def get_budget(self, rho: float) -> np.ndarray:
        """Get the nominal budget vector scaled by ``rho``."""
        return rho * self._nominal_budget

    def _compute_surplus(self, theta: int, p: np.ndarray, n_samples: int = 10000) -> float:
        """
        Compute g_θ(p) = E[(r - <p, a>)_+] using closed-form for normal distributions.

        For normal r and a, the surplus s = r - <p,a> is also normal with:
        - μ_s = μ_r - Σ_j p_j μ_{a_j}
        - σ_s² = σ_r² + Σ_j p_j² σ_{a_j}²

        The expected positive part is: E[s_+] = μ_s Φ(μ_s/σ_s) + σ_s φ(μ_s/σ_s)

        The formula uses the unclipped profile parameters, so it ignores the
        0.01 lower clipping applied to the sampled arrivals.

        ``n_samples`` is not used by the closed form; it is kept only so the
        signature stays compatible with existing callers.
        """
        from scipy.stats import norm

        profile = self._config_profiles[theta]

        # Mean and variance of surplus s = r - <p, a>
        mu_s = profile['reward_mean'] - np.dot(p, profile['consumption_mean'])
        # Variance: σ_r² + Σ_j p_j² σ_{a_j}²
        var_s = profile['reward_std']**2 + np.sum((p * profile['consumption_std'])**2)
        sigma_s = np.sqrt(var_s)

        # Degenerate (deterministic) surplus: avoid dividing by ~0.
        if sigma_s < 1e-10:
            return max(mu_s, 0.0)

        # E[s_+] = μ Φ(μ/σ) + σ φ(μ/σ) for s ~ N(μ, σ²)
        z = mu_s / sigma_s
        return mu_s * norm.cdf(z) + sigma_s * norm.pdf(z)

    def _solve_V_mix(self, b: np.ndarray, n_grid: int = 20) -> Tuple[float, np.ndarray, np.ndarray]:
        """
        Solve V^mix(b) = min_p { <p, b> + max_θ g_θ(p) } via optimization.

        Runs L-BFGS-B from several starting points inside the box
        ``[0, _P_MAX]^d`` and keeps the best result. The first start is the
        origin; the remaining starts are drawn from a generator seeded with
        ``self.seed`` so repeated calls return the same values and the global
        NumPy random state is left untouched.

        ``n_grid`` is not used; it is kept only for signature compatibility.

        Returns
        -------
        (best_val, best_p, w_star)
            The best objective value found, the corresponding price vector,
            and a mixture that puts equal weight on every config whose
            surplus at ``best_p`` is within 1e-8 of the maximum.
        """
        from scipy.optimize import minimize

        p_max = self._P_MAX

        def objective(p):
            surpluses = np.array([self._compute_surplus(k, p) for k in range(self.K)])
            envelope = np.max(surpluses)
            return np.dot(p, b) + envelope

        best_val = np.inf
        best_p = np.zeros(self.d)

        # Local, seeded generator: the original global np.random call made
        # the oracle values differ from run to run.
        rng = np.random.RandomState(self.seed)

        n_restarts = 5
        for restart in range(n_restarts):
            if restart == 0:
                p_init = np.zeros(self.d)
            else:
                p_init = rng.uniform(0, p_max / 2, self.d)

            result = minimize(
                objective,
                p_init,
                method='L-BFGS-B',
                bounds=[(0, p_max)] * self.d,
                options={'maxiter': 100, 'ftol': 1e-6}
            )

            if result.fun < best_val:
                best_val = result.fun
                best_p = result.x.copy()

        # Compute optimal mixture: uniform over the configs attaining the
        # maximum surplus at best_p (up to a 1e-8 tolerance).
        best_surpluses = np.array([self._compute_surplus(k, best_p) for k in range(self.K)])
        max_surplus = np.max(best_surpluses)
        w_star = (best_surpluses >= max_surplus - 1e-8).astype(float)
        w_star /= w_star.sum()

        return best_val, best_p, w_star

    def get_oracle_values(self, rho: float = 1.0) -> Dict[str, Any]:
        """
        Compute theoretical oracle values.

        Parameters
        ----------
        rho : float
            Budget scaling factor; the per-period budget is
            ``rho * nominal_budget / T``.

        Returns
        -------
        dict
            - ``V_mix``, ``p_star``, ``w_star``: output of ``_solve_V_mix``
            - ``V_star``: largest, over configs in ``range(K)``, of the
              per-config minimum of ``<p, b> + g_θ(p)``
            - ``best_fixed_config``: the config attaining ``V_star``
            - ``gap``: ``V_mix - V_star``
            - ``efficiencies``: mean reward / total mean consumption for
              every entry of ``_config_profiles``
            - ``rho``, ``b_per_period``: the inputs used
            - ``hidden_gem``, ``traps``: fixed labels of the built-in configs
        """
        from scipy.optimize import minimize

        b = rho * self._nominal_budget / self.T

        # Compute efficiencies
        efficiencies = {}
        for k, profile in self._config_profiles.items():
            total_cost = np.sum(profile['consumption_mean'])
            efficiencies[k] = profile['reward_mean'] / total_cost

        V_mix, p_star, w_star = self._solve_V_mix(b)

        # Compute V*
        V_star = 0.0
        best_theta = 0
        p_max = self._P_MAX

        for theta in range(self.K):
            # th=theta binds the current config as a default argument so the
            # closure does not pick up a later loop value.
            def obj_theta(p, th=theta):
                g = self._compute_surplus(th, p)
                return np.dot(p, b) + g

            result = minimize(
                obj_theta,
                np.zeros(self.d),
                method='L-BFGS-B',
                bounds=[(0, p_max)] * self.d,
                options={'maxiter': 100}
            )

            if result.fun > V_star:
                V_star = result.fun
                best_theta = theta

        return {
            'V_mix': V_mix,
            'V_star': V_star,
            'gap': V_mix - V_star,
            'w_star': w_star,
            'p_star': p_star,
            'efficiencies': efficiencies,
            'best_fixed_config': best_theta,
            'rho': rho,
            'b_per_period': b,
            'hidden_gem': 1,  # Config 1 is the true optimal
            'traps': [0, 2],  # High-variance traps
        }

    def get_metadata(self) -> Dict[str, Any]:
        """
        Get metadata about S1.

        Extends the base-class metadata with scenario labels and the oracle
        values computed by ``get_oracle_values()`` at the default ``rho``.
        """
        base = super().get_metadata()
        oracle = self.get_oracle_values()

        return {
            **base,
            'family': 'S1',
            'name': 'Exploration-Critical (Similar Efficiencies)',
            'efficiencies': oracle['efficiencies'],
            'V_mix_approx': oracle['V_mix'],
            'V_star_approx': oracle['V_star'],
            'gap_approx': oracle['gap'],
            'hidden_gem': oracle['hidden_gem'],
            'traps': oracle['traps'],
            'design': 'Similar efficiencies, variance-driven exploration',
        }

    def compute_variance_analysis(self) -> Dict[str, Any]:
        """
        Analyze config variances to show why exploration matters.

        Returns statistics showing that:
        - High-variance configs have wide efficiency ranges
        - Low-variance config (hidden gem) is consistently good

        The result maps ``'config_<theta>'`` to a dict of empirical
        statistics computed from the pre-generated arrivals, where the
        per-sample efficiency is reward divided by summed consumption.
        """
        stats = {}
        for theta in range(self.K):
            arrivals = self._arrivals[theta]
            rewards = arrivals[:, 0]
            consumptions = arrivals[:, 1:]
            total_consumption = np.sum(consumptions, axis=1)
            # Consumptions are clipped at 0.01 during generation, so the
            # denominator is strictly positive.
            efficiencies = rewards / total_consumption

            # Compute statistics
            stats[f'config_{theta}'] = {
                'mean_efficiency': float(np.mean(efficiencies)),
                'std_efficiency': float(np.std(efficiencies)),
                'efficiency_5th_pct': float(np.percentile(efficiencies, 5)),
                'efficiency_95th_pct': float(np.percentile(efficiencies, 95)),
                'mean_reward': float(np.mean(rewards)),
                'std_reward': float(np.std(rewards)),
            }

        return stats


def test_s1_loader():
    """
    Smoke-test the S1 data loader by printing a human-readable report.

    Builds a K=4, d=3, T=5000 loader and prints its dimensions, the
    per-config profile summaries, the empirical variance analysis, the
    oracle values, the rho=1.0 budget, and the result of ``validate()``.
    Nothing is asserted; the output is meant for manual inspection.
    """
    banner = "=" * 60

    print(banner)
    print("Testing S1: Exploration-Critical Scenario (K=4, d=3)")
    print(banner)

    loader = S1ComplementarityLoader(K=4, T=5000, seed=42, d=3)

    print(f"\nDimensions: K={loader.K}, d={loader.d}, T={loader.T}")

    # Nominal (profile-based) efficiency of each config.
    print("\nConfig Profiles (efficiency and variance):")
    for theta in range(loader.K):
        profile = loader._config_profiles[theta]
        total_cost = np.sum(profile['consumption_mean'])
        eff = profile['reward_mean'] / total_cost
        summary = (
            f"  Config {theta}: r_mean={profile['reward_mean']:.2f}±{profile['reward_std']:.2f}, "
            f"cost={total_cost:.2f}, eff≈{eff:.2f}"
        )
        print(summary)

    # Empirical efficiency spread measured on the generated arrivals.
    print("\nVariance Analysis (why exploration matters):")
    for config, stats in loader.compute_variance_analysis().items():
        print(
            f"  {config}:",
            f"    Efficiency: {stats['mean_efficiency']:.2f} ± {stats['std_efficiency']:.2f}",
            f"    Range (5%-95%): [{stats['efficiency_5th_pct']:.2f}, {stats['efficiency_95th_pct']:.2f}]",
            sep="\n",
        )

    print("\nOracle Values:")
    oracle = loader.get_oracle_values()
    print(
        f"  V^mix ≈ {oracle['V_mix']:.4f}",
        f"  V* ≈ {oracle['V_star']:.4f}",
        f"  Hidden gem: config {oracle['hidden_gem']}",
        f"  Traps: configs {oracle['traps']}",
        f"  p* = {oracle['p_star']}",
        f"  w* = {oracle['w_star']}",
        sep="\n",
    )

    print(f"\nBudget (rho=1.0): {loader.get_budget(1.0)}")

    print("\nKey Insight:")
    insights = (
        "  - Traps (0,2) have similar mean efficiency but HIGH variance",
        "  - Hidden gem (1) has slightly better efficiency but LOW variance",
        "  - Greedy may commit to trap if early samples lucky",
        "  - UCB exploration discovers hidden gem through optimism",
    )
    for line in insights:
        print(line)

    print(f"\nValidation: {loader.validate()}")
    print("\n" + banner)


# Run the print-based smoke test when this module is executed as a script.
if __name__ == "__main__":
    test_s1_loader()
