"""
S4: Complementarity Data Loader

This scenario directly validates the "fixed oracle can be beaten" pathology.
It shows that V^mix > V* when configurations have complementary resource usage.

Design Principles:
- Two configurations with ORTHOGONAL resource consumption
- Config 0 uses only resource 1, Config 1 uses only resource 2
- Budget is equal across both resources
- Fixed oracle can only use one resource → V* = 0.5
- Switching oracle uses BOTH resources → V^mix = 1.0
- Complementarity gap = V^mix / V* = 2.0

Setup (K=2, d=2):
- Config 0: r ~ 1 + noise, a = [1, 0] + noise (resource 1 only)
- Config 1: r ~ 1 + noise, a = [0, 1] + noise (resource 2 only)

Budget: b = [0.5, 0.5] per period

Key insight:
- Best fixed config uses only ONE resource budget
- Alternating policy uses BOTH resource budgets
- SP-UCB-OLP should learn to mix and achieve CR^mix ≈ 1
- CR^* (vs fixed oracle) will EXCEED 1.0
"""

import numpy as np
from typing import Dict, Any, Tuple
from .base_loader import BaseDataLoader


class S4ComplementarityLoader(BaseDataLoader):
    """
    S4: Complementarity with orthogonal resource usage.

    Two configurations earn the same base reward but each consumes a
    different one of the two resources, so a policy that alternates between
    them can spend both budgets while any single fixed configuration can
    only spend one.

    Parameters
    ----------
    K : int
        Number of configs (default: 2, designed for exactly 2)
    T : int
        Time horizon
    seed : int
        Random seed
    noise_std : float
        Small noise for no-ties assumption (default: 0.01). It is used as the
        half-width of the uniform perturbation added to rewards and
        consumptions.
    d : int
        Number of resource dimensions (default: 2, designed for exactly 2)
    """

    # Upper bound of the per-resource dual-price box [0, _P_MAX] searched by
    # both oracle solvers (previously duplicated as a local literal in each).
    _P_MAX = 5.0

    def __init__(
        self,
        K: int = 2,
        T: int = 10000,
        seed: int = 42,
        noise_std: float = 0.01,
        d: int = 2,  # Designed for d=2
    ):
        """
        Build the loader, generate every arrival and set the nominal budget.

        Values of ``K`` or ``d`` other than 2 are overridden to 2 after a
        printed warning, because the profiles and the budget below are
        hard-wired for two configs and two resources.
        """
        # Force K=2, d=2 for this scenario (core design)
        if K != 2:
            print(f"Warning: S4 is designed for K=2, got K={K}. Using K=2.")
            K = 2
        if d != 2:
            print(f"Warning: S4 is designed for d=2, got d={d}. Using d=2.")
            d = 2

        # NOTE(review): the base initializer is presumably what stores
        # K, d, T, seed and creates the `_arrivals` container used below.
        super().__init__(K, d, T, seed)

        self.noise_std = noise_std

        # Config profiles: orthogonal resource consumption
        # Config 0: consumes resource 1 only
        # Config 1: consumes resource 2 only
        self._config_profiles = {
            0: {'reward': 1.0, 'consumption': np.array([1.0, 0.0])},
            1: {'reward': 1.0, 'consumption': np.array([0.0, 1.0])},
        }

        self._generate_arrivals()
        self._compute_nominal_budget()

    def _generate_arrivals(self):
        """
        Generate arrivals with small noise for no-ties assumption.

        For every config a ``(T, d + 1)`` array is stored in
        ``self._arrivals[theta]``; column 0 holds the reward and columns
        ``1:`` hold the consumption vector. Noise is uniform on
        ``[-noise_std, noise_std]``. Rewards are floored at 0.01 and
        consumptions at 0.0.
        """
        rng = np.random.RandomState(self.seed)

        for theta in range(self.K):
            profile = self._config_profiles[theta]
            arrivals = np.zeros((self.T, self.d + 1))

            # The draw order (one reward draw, then d consumption draws, per
            # period) is kept as-is so a given seed keeps producing the same data.
            for t in range(self.T):
                # Reward with small noise
                r = profile['reward'] + rng.uniform(-self.noise_std, self.noise_std)
                r = max(r, 0.01)

                # Consumption with small noise (maintain orthogonality approximately)
                a = profile['consumption'].copy()
                a += rng.uniform(-self.noise_std, self.noise_std, self.d)
                a = np.maximum(a, 0.0)  # Allow 0 for orthogonal dimension

                arrivals[t, 0] = r
                arrivals[t, 1:] = a

            self._arrivals[theta] = arrivals

    def _compute_nominal_budget(self):
        """
        Compute budget to demonstrate complementarity.

        Budget = [0.5, 0.5] per period means:
        - Fixed config 0: can accept 50% of arrivals (limited by resource 1)
        - Fixed config 1: can accept 50% of arrivals (limited by resource 2)
        - Mixing: can accept 50% from each config = 100% total utilization

        ``self._nominal_budget`` stores the total over the whole horizon,
        i.e. the per-period vector multiplied by ``T``.
        """
        # Per-period budget of 0.5 for each resource
        self._nominal_budget = np.array([0.5, 0.5]) * self.T

    def get_arrival(self, theta: int, t: int) -> Tuple[float, np.ndarray]:
        """
        Return the pre-generated arrival of config ``theta`` at period ``t``.

        Returns
        -------
        (float, np.ndarray)
            The reward and a copy of the consumption vector, so callers
            cannot mutate the stored data.

        Raises
        ------
        ValueError
            If ``theta`` is outside ``[0, K-1]`` or ``t`` is outside ``[0, T-1]``.
        """
        if theta < 0 or theta >= self.K:
            raise ValueError(f"theta {theta} out of range [0, {self.K-1}]")
        if t < 0 or t >= self.T:
            raise ValueError(f"t {t} out of range [0, {self.T-1}]")

        arrival = self._arrivals[theta][t]
        return float(arrival[0]), arrival[1:].copy()

    def get_budget(self, rho: float) -> np.ndarray:
        """Return the horizon-total nominal budget multiplied by ``rho``."""
        return rho * self._nominal_budget

    def _compute_surplus(self, theta: int, p: np.ndarray, n_samples: int = 10000) -> float:
        """
        Approximate g_θ(p) = E[(r - <p, a>)_+] analytically (no sampling).

        For near-deterministic arrivals with small noise:
        g_θ(p) ≈ max(r_base - <p, a_base>, 0)

        Outside a band of ±3·noise_std around zero mean surplus the
        deterministic value is returned; inside the band a Gaussian
        expected-positive-part formula is used.

        NOTE: the generated noise is uniform on [-noise_std, noise_std] and
        consumptions are clipped at 0, whereas the transition formula treats
        noise_std as a Gaussian standard deviation, so the value inside the
        band is only an approximation.

        Parameters
        ----------
        theta : int
            Config index into the base profiles.
        p : np.ndarray
            Dual price vector, one entry per resource.
        n_samples : int
            Unused; kept only so the signature stays unchanged.
        """
        profile = self._config_profiles[theta]

        # Mean surplus (ignoring small noise)
        mu_s = profile['reward'] - np.dot(p, profile['consumption'])

        # For small noise, use deterministic approximation
        # Add small adjustment for noise variance
        if mu_s > 3 * self.noise_std:
            return mu_s  # Clearly positive
        elif mu_s < -3 * self.noise_std:
            return 0.0  # Clearly negative
        else:
            # Transition region:
            # E[(s)_+] ≈ μ·Φ(μ/σ) + σ·φ(μ/σ) for small σ
            from scipy.stats import norm
            # Std of r - <p, a> if reward and each consumption entry carried
            # independent noise of std noise_std.
            sigma_s = self.noise_std * np.sqrt(1 + np.sum(p ** 2))
            if sigma_s < 1e-10:
                # Effectively noiseless: avoid dividing by ~0.
                return max(mu_s, 0.0)
            z = mu_s / sigma_s
            return mu_s * norm.cdf(z) + sigma_s * norm.pdf(z)

    def _solve_V_mix(self, b: np.ndarray, n_grid: int = 20) -> Tuple[float, np.ndarray, np.ndarray]:
        """
        Solve V^mix(b) = min_p { <p, b> + max_θ g_θ(p) }.

        For complementary configs with equal budgets, optimal is:
        - p* = [0, 0] (accept everything)
        - w* = [0.5, 0.5] (equal mixture)
        - V^mix = 1.0 (accept all arrivals from both configs)

        The minimisation is a multi-start L-BFGS-B over the box
        ``[0, _P_MAX]^d``. Restart points come from a generator seeded with
        ``self.seed``, so repeated calls return the same values and the
        global NumPy random state is left untouched.

        Parameters
        ----------
        b : np.ndarray
            Per-period budget vector.
        n_grid : int
            Unused; kept only so the signature stays unchanged.

        Returns
        -------
        (float, np.ndarray, np.ndarray)
            Best objective value, the price vector attaining it, and a
            uniform weight vector over the configs whose surplus is within
            1e-8 of the maximum at that price.
        """
        from scipy.optimize import minimize

        p_max = self._P_MAX

        def objective(p):
            """Objective: <p, b> + max_θ g_θ(p)"""
            surpluses = np.array([self._compute_surplus(k, p) for k in range(self.K)])
            envelope = np.max(surpluses)
            return np.dot(p, b) + envelope

        # Local, seeded generator: the original used the global np.random
        # state, which made the oracle values vary from call to call.
        rng = np.random.RandomState(self.seed)

        # Multi-start optimization
        best_val = np.inf
        best_p = np.zeros(self.d)

        n_restarts = 5
        for restart in range(n_restarts):
            if restart == 0:
                # First start at zero prices, the expected optimum.
                p_init = np.zeros(self.d)
            else:
                p_init = rng.uniform(0, p_max / 2, self.d)

            result = minimize(
                objective,
                p_init,
                method='L-BFGS-B',
                bounds=[(0, p_max)] * self.d,
                options={'maxiter': 100, 'ftol': 1e-6}
            )

            if result.fun < best_val:
                best_val = result.fun
                best_p = result.x.copy()

        # Compute optimal mixture at optimal price: equal weight on every
        # config that attains the surplus envelope (up to 1e-8).
        best_surpluses = np.array([self._compute_surplus(k, best_p) for k in range(self.K)])
        max_surplus = np.max(best_surpluses)
        w_star = (best_surpluses >= max_surplus - 1e-8).astype(float)
        w_star /= w_star.sum()

        return best_val, best_p, w_star

    def get_oracle_values(self, rho: float = 1.0) -> Dict[str, Any]:
        """
        Compute theoretical oracle values.

        For S4 with complementary resources:
        - V* (best fixed config) ≈ 0.5 (uses only one resource budget)
        - V^mix (switching) ≈ 1.0 (uses both resource budgets)
        - Complementarity gap ≈ 2.0

        Parameters
        ----------
        rho : float
            Multiplier applied to the nominal budget.

        Returns
        -------
        dict
            Keys: ``V_mix``, ``V_star``, ``gap`` (difference),
            ``complementarity_gap`` (ratio, ``inf`` when ``V_star`` is not
            positive), ``w_star``, ``p_star``, ``best_fixed_config``,
            ``rho`` and ``b_per_period``.
        """
        from scipy.optimize import minimize

        # Per-period budget scaled by rho
        b = rho * self._nominal_budget / self.T

        # Solve switching-aware oracle
        V_mix, p_star, w_star = self._solve_V_mix(b)

        # Compute V* = max_θ min_p { <p, b> + g_θ(p) }
        V_star = 0.0
        best_theta = 0
        p_max = self._P_MAX

        for theta in range(self.K):
            # `th=theta` binds the current config at definition time.
            def obj_theta(p, th=theta):
                g = self._compute_surplus(th, p)
                return np.dot(p, b) + g

            result = minimize(
                obj_theta,
                np.zeros(self.d),
                method='L-BFGS-B',
                bounds=[(0, p_max)] * self.d,
                options={'maxiter': 100}
            )

            if result.fun > V_star:
                V_star = result.fun
                best_theta = theta

        # Compute complementarity gap
        comp_gap = V_mix / V_star if V_star > 0 else float('inf')

        return {
            'V_mix': V_mix,
            'V_star': V_star,
            'gap': V_mix - V_star,
            'complementarity_gap': comp_gap,
            'w_star': w_star,
            'p_star': p_star,
            'best_fixed_config': best_theta,
            'rho': rho,
            'b_per_period': b,
        }

    def get_metadata(self) -> Dict[str, Any]:
        """
        Get metadata about S4.

        Extends the base-class metadata with scenario identifiers, the noise
        level and the oracle values at ``rho = 1.0``. Every call re-runs the
        oracle optimisation.
        """
        base = super().get_metadata()
        oracle = self.get_oracle_values()

        return {
            **base,
            'family': 'S4',
            'name': 'Complementarity',
            'noise_std': self.noise_std,
            'deterministic': False,  # Small noise for no-ties
            'V_mix': oracle['V_mix'],
            'V_star': oracle['V_star'],
            'complementarity_gap': oracle['complementarity_gap'],
            'expected_gap': 2.0,  # Theoretical expected gap
        }


def test_s4_loader():
    """
    Smoke-test the S4 loader.

    Builds a small loader (T=1000), prints its profiles, a few sampled
    arrivals, the oracle values and the budget, then asserts that the
    complementarity gap lies in [1.8, 2.2].
    """
    rule = "=" * 60
    print(rule)
    print("Testing S4: Complementarity Loader (K=2, d=2)")
    print(rule)

    loader = S4ComplementarityLoader(K=2, T=1000, seed=42, d=2)
    config_ids = range(loader.K)

    print(f"\nDimensions: K={loader.K}, d={loader.d}, T={loader.T}")

    print("\nConfig Profiles (orthogonal resource usage):")
    for cfg in config_ids:
        spec = loader._config_profiles[cfg]
        print(f"  Config {cfg}: r={spec['reward']:.2f}, a={spec['consumption']}")

    print("\nSample arrivals (with noise):")
    # First three periods of each config, config-major order.
    for cfg, step in ((c, s) for c in config_ids for s in range(3)):
        reward, usage = loader.get_arrival(cfg, step)
        print(f"  Config {cfg}, t={step}: r={reward:.4f}, a={usage}")

    print("\nOracle Values:")
    oracle = loader.get_oracle_values()
    comp_gap = oracle['complementarity_gap']
    print(f"  V^mix ≈ {oracle['V_mix']:.4f}")
    print(f"  V* ≈ {oracle['V_star']:.4f}")
    print(f"  Complementarity gap = V^mix/V* ≈ {comp_gap:.2f}")
    print(f"  (Expected gap: 2.0)")
    print(f"  p* = {oracle['p_star']}")
    print(f"  w* = {oracle['w_star']}")

    total_budget = loader.get_budget(1.0)
    print(f"\nBudget (rho=1.0): {total_budget}")
    print(f"Per-period budget: {total_budget / loader.T}")

    print(f"\nValidation: {loader.validate()}")

    # The scenario is built so that mixing doubles the fixed-oracle value.
    assert 1.8 <= comp_gap <= 2.2, \
        f"Complementarity gap {comp_gap:.2f} not close to expected 2.0"
    print("\n✓ Complementarity gap verified (≈ 2.0)")

    print("\n" + rule)

# Script entry point: run the smoke test when this file is executed directly.
# NOTE(review): the module uses a relative import (`from .base_loader import ...`),
# so it presumably has to be launched as a package module (`python -m ...`)
# rather than as a bare script path — confirm.
if __name__ == "__main__":
    test_s4_loader()
