"""
S5: Simple 5-Arm Gaussian Environment

A straightforward environment for validating regret scaling with α ≥ 1.

Setup (K=5, d=3):
- 5 arms with different mean rewards and consumption profiles
- Gaussian samples clipped to a bounded positive range (clipping, not true truncation)
- Each arm specializes in different resource dimensions

Purpose:
- Test theory-compliant exploration ratio α = 1.5
- Validate √T regret scaling predicted by Theorem 5
"""

import numpy as np
from typing import Dict, Any, Tuple
from .base_loader import BaseDataLoader


class S5GaussianLoader(BaseDataLoader):
    """
    S5: Simple 5-Arm Gaussian environment for α ≥ 1 validation.

    Rewards and consumption vectors are drawn from per-arm Gaussians and
    clipped to ``[0.01, R_max]`` and ``[0.01, A_max]`` respectively.  All
    arrivals are generated once, in the constructor.

    Parameters
    ----------
    K : int
        Number of configs (default: 5).  Must lie in ``[1, 5]`` because
        exactly five arm profiles are built in; the first ``K`` are used.
    T : int
        Time horizon
    seed : int
        Random seed (drives arrival generation and the optimiser restarts
        used for the oracle values)
    d : int
        Number of resource dimensions (default: 3).  Must equal 3, the
        length of the built-in consumption profiles.

    Raises
    ------
    ValueError
        If ``K`` or ``d`` is incompatible with the built-in arm profiles.
    """

    # Lower clipping bound applied to every sampled reward / consumption.
    _MIN_VALUE = 0.01
    # Upper bound on each dual-price coordinate in the oracle optimisations.
    _P_MAX = 5.0
    # Number of L-BFGS-B starting points tried when solving for V^mix.
    _N_RESTARTS = 5

    def __init__(
        self,
        K: int = 5,
        T: int = 10000,
        seed: int = 42,
        d: int = 3,
    ):
        super().__init__(K, d, T, seed)

        # Bounds matching paper's assumptions (used as upper clipping limits)
        self.R_max = 2.0
        self.A_max = 2.0

        # Number of Monte Carlo samples for surplus estimation
        self._n_mc_samples = 50000
        self._mc_samples = None  # Lazy initialization

        # 5-arm configuration with diverse profiles
        # Each arm has different reward mean and consumption pattern
        self._config_profiles = {
            # Arm 0: High reward, resource-1 heavy
            0: {
                'reward_mean': 1.0, 'reward_std': 0.3,
                'consumption_mean': np.array([0.8, 0.2, 0.2]),
                'consumption_std': np.array([0.2, 0.2, 0.2]),
            },
            # Arm 1: Medium reward, resource-2 heavy
            1: {
                'reward_mean': 0.8, 'reward_std': 0.3,
                'consumption_mean': np.array([0.2, 0.7, 0.2]),
                'consumption_std': np.array([0.2, 0.2, 0.2]),
            },
            # Arm 2: Lower reward, resource-3 heavy
            2: {
                'reward_mean': 0.6, 'reward_std': 0.3,
                'consumption_mean': np.array([0.2, 0.2, 0.6]),
                'consumption_std': np.array([0.2, 0.2, 0.2]),
            },
            # Arm 3: High reward, balanced consumption
            3: {
                'reward_mean': 0.9, 'reward_std': 0.3,
                'consumption_mean': np.array([0.5, 0.5, 0.5]),
                'consumption_std': np.array([0.2, 0.2, 0.2]),
            },
            # Arm 4: Low reward, low consumption (efficient)
            4: {
                'reward_mean': 0.4, 'reward_std': 0.3,
                'consumption_mean': np.array([0.1, 0.1, 0.1]),
                'consumption_std': np.array([0.2, 0.2, 0.2]),
            },
        }

        # Fail early with a clear message instead of a KeyError / broadcast
        # error deep inside arrival generation.
        self._check_dimensions()

        self._generate_arrivals()
        self._compute_nominal_budget()

    def _check_dimensions(self):
        """Verify that ``K`` and ``d`` are compatible with the built-in profiles.

        Raises
        ------
        ValueError
            If ``K`` is outside ``[1, number of profiles]`` or ``d`` differs
            from the length of the profile consumption vectors.
        """
        n_profiles = len(self._config_profiles)
        if not 1 <= self.K <= n_profiles:
            raise ValueError(
                f"K={self.K} is not supported: S5 defines {n_profiles} arm "
                f"profiles, so K must be in [1, {n_profiles}]"
            )

        profile_dim = self._config_profiles[0]['consumption_mean'].shape[0]
        if self.d != profile_dim:
            raise ValueError(
                f"d={self.d} is not supported: S5 consumption profiles have "
                f"{profile_dim} resource dimensions"
            )

    def _generate_arrivals(self):
        """Generate i.i.d. arrivals from clipped Gaussian distributions.

        Values are bounded to [0.01, R_max] for rewards and [0.01, A_max] for
        consumptions.  For each arm a ``(T, d + 1)`` array is stored in
        ``self._arrivals``: column 0 is the reward, columns 1..d the
        consumption vector.
        """
        rng = np.random.RandomState(self.seed)

        # Per-column upper clipping bounds: reward first, then d resources.
        upper = np.concatenate(([self.R_max], np.full(self.d, self.A_max)))

        for theta in range(self.K):
            profile = self._config_profiles[theta]
            loc = np.concatenate(
                ([profile['reward_mean']], profile['consumption_mean'])
            )
            scale = np.concatenate(
                ([profile['reward_std']], profile['consumption_std'])
            )

            # One vectorised draw per arm.  The output is filled in row-major
            # order, i.e. reward then consumption for step 0, then step 1, ...,
            # which is the same order a per-step loop would consume the stream.
            draws = rng.normal(loc, scale, size=(self.T, self.d + 1))
            self._arrivals[theta] = np.clip(draws, self._MIN_VALUE, upper)

    def _compute_nominal_budget(self):
        """Set the nominal budget from the mean consumption of the active arms.

        The budget is half of the average (over the first ``K`` arms) mean
        consumption vector, accumulated over the horizon ``T``.
        """
        avg_consumption = np.mean(
            [self._config_profiles[k]['consumption_mean'] for k in range(self.K)],
            axis=0,
        )
        # Nominal budget allows ~50% acceptance at rho=1.0
        self._nominal_budget = avg_consumption * self.T * 0.5

    def get_arrival(self, theta: int, t: int) -> Tuple[float, np.ndarray]:
        """Get arrival at timestep t under configuration theta.

        Returns
        -------
        tuple
            ``(reward, consumption)`` where ``consumption`` is a fresh copy of
            the length-``d`` consumption vector.

        Raises
        ------
        ValueError
            If ``theta`` is outside ``[0, K-1]`` or ``t`` is outside
            ``[0, T-1]``.
        """
        if theta < 0 or theta >= self.K:
            raise ValueError(f"theta {theta} out of range [0, {self.K-1}]")
        if t < 0 or t >= self.T:
            raise ValueError(f"t {t} out of range [0, {self.T-1}]")

        arrival = self._arrivals[theta][t]
        # Copy so callers cannot mutate the stored arrivals through the slice.
        return float(arrival[0]), arrival[1:].copy()

    def get_budget(self, rho: float) -> np.ndarray:
        """Get budget scaled by rho."""
        return rho * self._nominal_budget

    def _generate_mc_samples(self):
        """Generate Monte Carlo samples for surplus estimation (lazy initialization).

        The samples are drawn once per loader with a fixed seed and cached in
        ``self._mc_samples`` as ``{theta: (rewards, consumptions)}``.
        """
        if self._mc_samples is not None:
            return

        rng = np.random.RandomState(12345)  # Fixed seed for reproducibility
        self._mc_samples = {}

        for theta in range(self.K):
            profile = self._config_profiles[theta]
            n = self._n_mc_samples

            # Sample rewards (clipped)
            r = rng.normal(profile['reward_mean'], profile['reward_std'], n)
            r = np.clip(r, self._MIN_VALUE, self.R_max)

            # Sample consumptions (clipped)
            a = rng.normal(
                profile['consumption_mean'], profile['consumption_std'], (n, self.d)
            )
            a = np.clip(a, self._MIN_VALUE, self.A_max)

            self._mc_samples[theta] = (r, a)

    def _compute_surplus(self, theta: int, p: np.ndarray) -> float:
        """
        Compute g_θ(p) = E[(r - <p, a>)_+] using Monte Carlo with clipped samples.

        The expectation is estimated as the sample mean over the cached Monte
        Carlo draws for arm ``theta``.
        """
        self._generate_mc_samples()

        r, a = self._mc_samples[theta]
        surplus = r - a @ p
        return float(np.mean(np.maximum(surplus, 0)))

    def _solve_V_mix(self, b: np.ndarray) -> Tuple[float, np.ndarray, np.ndarray]:
        """Solve V^mix(b) = min_p { <p, b> + max_θ g_θ(p) }.

        Runs bounded L-BFGS-B from several starting points and keeps the best.

        Returns
        -------
        tuple
            ``(value, p, w)``: the best objective value found, the price
            vector attaining it, and a weight vector that is uniform over the
            arms whose surplus at ``p`` is within 1e-8 of the maximum.
        """
        from scipy.optimize import minimize

        def objective(p):
            envelope = max(self._compute_surplus(k, p) for k in range(self.K))
            return np.dot(p, b) + envelope

        # Local, seeded generator: restart points (and therefore the oracle
        # values) are reproducible and the global NumPy RNG is left untouched.
        rng = np.random.RandomState(self.seed)
        bounds = [(0, self._P_MAX)] * self.d

        best_val = np.inf
        best_p = np.zeros(self.d)

        # Multiple restarts for robustness
        for restart in range(self._N_RESTARTS):
            if restart == 0:
                p_init = np.zeros(self.d)
            else:
                p_init = rng.uniform(0, self._P_MAX / 2, self.d)

            result = minimize(
                objective,
                p_init,
                method='L-BFGS-B',
                bounds=bounds,
                options={'maxiter': 100, 'ftol': 1e-6},
            )

            if result.fun < best_val:
                best_val = result.fun
                best_p = result.x.copy()

        # Mixture weights: uniform over the arms attaining the max surplus
        best_surpluses = np.array(
            [self._compute_surplus(k, best_p) for k in range(self.K)]
        )
        max_surplus = np.max(best_surpluses)
        w_star = (best_surpluses >= max_surplus - 1e-8).astype(float)
        w_star /= w_star.sum()

        return best_val, best_p, w_star

    def get_oracle_values(self, rho: float = 1.0) -> Dict[str, Any]:
        """Compute theoretical oracle values.

        Parameters
        ----------
        rho : float
            Budget scaling factor applied to the nominal budget.

        Returns
        -------
        dict
            Keys ``V_mix``, ``V_star``, ``gap`` (``V_mix - V_star``),
            ``w_star``, ``p_star``, ``efficiencies`` (mean reward divided by
            summed mean consumption, per active arm), ``best_fixed_config``,
            ``rho`` and ``b_per_period``.
        """
        from scipy.optimize import minimize

        b = rho * self._nominal_budget / self.T

        # Compute efficiencies for the arms that are actually in play
        efficiencies = {}
        for k in range(self.K):
            profile = self._config_profiles[k]
            total_cost = np.sum(profile['consumption_mean'])
            efficiencies[k] = profile['reward_mean'] / total_cost

        V_mix, p_star, w_star = self._solve_V_mix(b)

        def fixed_objective(p, theta):
            return np.dot(p, b) + self._compute_surplus(theta, p)

        # Compute V* (best fixed config): largest per-arm dual value
        V_star = 0.0
        best_theta = 0
        bounds = [(0, self._P_MAX)] * self.d

        for theta in range(self.K):
            result = minimize(
                fixed_objective,
                np.zeros(self.d),
                args=(theta,),
                method='L-BFGS-B',
                bounds=bounds,
                options={'maxiter': 100},
            )

            if result.fun > V_star:
                V_star = result.fun
                best_theta = theta

        return {
            'V_mix': V_mix,
            'V_star': V_star,
            'gap': V_mix - V_star,
            'w_star': w_star,
            'p_star': p_star,
            'efficiencies': efficiencies,
            'best_fixed_config': best_theta,
            'rho': rho,
            'b_per_period': b,
        }

    def get_metadata(self) -> Dict[str, Any]:
        """Get metadata about S5.

        Extends the base-class metadata with the S5 family labels and the
        oracle values computed at the default ``rho``.
        """
        base = super().get_metadata()
        oracle = self.get_oracle_values()

        return {
            **base,
            'family': 'S5',
            'name': 'Simple 5-Arm Gaussian',
            'efficiencies': oracle['efficiencies'],
            'V_mix_approx': oracle['V_mix'],
            'V_star_approx': oracle['V_star'],
            'gap_approx': oracle['gap'],
            'design': 'Simple Gaussian for alpha>=1 validation',
        }


def test_s5_loader():
    """Smoke-test the S5 loader by printing its profiles, oracle values and budget.

    Builds a loader with K=5, d=3, T=5000 and seed 42, then reports the
    per-arm efficiencies, the oracle quantities at rho=0.7, the scaled
    budget and the loader's own validation result.
    """
    banner = "=" * 60
    print(banner)
    print("Testing S5: Simple 5-Arm Gaussian (K=5, d=3)")
    print(banner)

    loader = S5GaussianLoader(K=5, T=5000, seed=42, d=3)

    print(f"\nDimensions: K={loader.K}, d={loader.d}, T={loader.T}")

    print("\nConfig Profiles:")
    for arm in range(loader.K):
        arm_profile = loader._config_profiles[arm]
        mean_reward = arm_profile['reward_mean']
        # Efficiency = mean reward per unit of summed mean consumption.
        cost_sum = np.sum(arm_profile['consumption_mean'])
        ratio = mean_reward / cost_sum
        print(f"  Arm {arm}: r_mean={mean_reward:.2f}, "
              f"cost={cost_sum:.2f}, eff={ratio:.2f}")

    print("\nOracle Values:")
    values = loader.get_oracle_values(rho=0.7)
    v_mix = values['V_mix']
    v_star = values['V_star']
    gap = values['gap']
    print(f"  V^mix = {v_mix:.4f}")
    print(f"  V* = {v_star:.4f}")
    print(f"  Gap = {gap:.4f}")
    print(f"  p* = {values['p_star']}")
    print(f"  w* = {values['w_star']}")

    scaled_budget = loader.get_budget(0.7)
    print(f"\nBudget (rho=0.7): {scaled_budget}")
    validation = loader.validate()
    print(f"Validation: {validation}")
    print(banner)


# Run the smoke test only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    test_s5_loader()
