"""
Oracle Computation Module

Computes the switching-aware fluid oracle V^mix and the fixed-configuration
oracle V* from sample data. Also extracts the optimal mixture w* and price p*.

Key formulas (prices range over p >= 0; the numerics restrict p to [0, P_max]^d):
- V^mix(b) = min_p {<p, b> + max_theta g_theta(p)}
- V* = max_theta min_p {<p, b> + g_theta(p)}
- g_theta(p) = E[(r - <p, a>)_+]
"""

import warnings
from dataclasses import dataclass
from typing import Any, Callable, Dict, List, Optional, Tuple

import numpy as np
from scipy import sparse
from scipy.optimize import linprog, minimize, minimize_scalar


@dataclass
class OracleResult:
    """
    Bundle of quantities produced by an oracle computation.

    Attributes
    ----------
    V_mix : float
        Switching-aware oracle value V^mix.
    V_star : float
        Value of the best single fixed configuration, V*.
    w_star : np.ndarray
        Optimal mixture over configurations, shape (K,).
    p_star : np.ndarray
        Optimal price vector for V^mix, shape (d,).
    p_per_config : np.ndarray
        Optimal price of each fixed configuration, shape (K, d).
    V_per_config : np.ndarray
        Fixed-configuration value of each configuration, shape (K,).
    best_fixed_config : int
        Index of the configuration attaining V*.
    gap : float
        Difference V^mix - V*.
    """

    V_mix: float
    V_star: float
    w_star: np.ndarray
    p_star: np.ndarray
    p_per_config: np.ndarray
    V_per_config: np.ndarray
    best_fixed_config: int
    gap: float


class OracleComputer:
    """
    Compute oracle values V^mix and V* from samples.

    The oracle computation is based on empirical estimates of the surplus
    function g_theta(p) = E[(r - <p, a>)_+].

    Both oracles minimize a convex, piecewise-linear function of the price
    vector over the box [0, P_max]^d. Gradient-based local search stalls at
    the kinks of that function. The envelope kinks, where several
    configurations tie, are exactly where V^mix differs from V*. The
    empirical problems are therefore solved exactly as linear programs
    (HiGHS via ``scipy.optimize.linprog``). If the LP cannot be solved, a
    multi-start L-BFGS-B search is used as a fallback.

    Parameters
    ----------
    K : int
        Number of configurations
    d : int
        Number of resource dimensions
    P_max : float
        Upper bound of the price box [0, P_max]^d (default: 10.0)
    n_price_grid : int
        Stored for backward compatibility; not used by any routine
        (default: 50)
    """

    def __init__(
        self,
        K: int,
        d: int,
        P_max: float = 10.0,
        n_price_grid: int = 50
    ):
        self.K = K
        self.d = d
        self.P_max = P_max
        self.n_price_grid = n_price_grid

    def compute_empirical_surplus(
        self,
        samples: List[Tuple[float, np.ndarray]],
        p: np.ndarray
    ) -> float:
        """
        Compute empirical surplus g_hat(p) = (1/n) * sum (r - <p, a>)_+

        Parameters
        ----------
        samples : List[Tuple[float, np.ndarray]]
            List of (reward, consumption) tuples
        p : np.ndarray
            Price vector of shape (d,)

        Returns
        -------
        float
            Empirical surplus value (0.0 for an empty sample list)
        """
        if len(samples) == 0:
            return 0.0
        total = sum((max(0.0, r - np.dot(p, a)) for r, a in samples), 0.0)
        return total / len(samples)

    def compute_empirical_surplus_vectorized(
        self,
        rewards: np.ndarray,
        consumptions: np.ndarray,
        p: np.ndarray
    ) -> float:
        """
        Vectorized computation of empirical surplus.

        Parameters
        ----------
        rewards : np.ndarray
            Rewards of shape (n,)
        consumptions : np.ndarray
            Consumptions of shape (n, d)
        p : np.ndarray
            Price vector of shape (d,)

        Returns
        -------
        float
            Empirical surplus value (0.0 when there are no samples)
        """
        if len(rewards) == 0:
            return 0.0
        margins = rewards - consumptions @ p
        return np.mean(np.maximum(margins, 0))

    def compute_empirical_consumption(
        self,
        rewards: np.ndarray,
        consumptions: np.ndarray,
        p: np.ndarray
    ) -> np.ndarray:
        """
        Compute empirical threshold consumption h_hat(p).

        h_hat(p) = (1/n) * sum a * 1{r > <p, a>}

        Parameters
        ----------
        rewards : np.ndarray
            Rewards of shape (n,)
        consumptions : np.ndarray
            Consumptions of shape (n, d)
        p : np.ndarray
            Price vector of shape (d,)

        Returns
        -------
        np.ndarray
            Empirical consumption of shape (d,)
        """
        if len(rewards) == 0:
            return np.zeros(self.d)
        margins = rewards - consumptions @ p
        accept_mask = (margins > 0).astype(float)
        return np.mean(consumptions * accept_mask[:, np.newaxis], axis=0)

    def _surplus_per_config(
        self,
        p: np.ndarray,
        samples_per_config: Dict[int, Tuple[np.ndarray, np.ndarray]],
        beta: Optional[np.ndarray] = None
    ) -> np.ndarray:
        """
        Return the vector (g_hat_theta(p) + beta_theta) for theta = 0..K-1.

        Configurations absent from ``samples_per_config`` get g = 0. When
        ``beta`` is None no bonus is added.
        """
        surplus = np.zeros(self.K)
        for theta in range(self.K):
            if theta in samples_per_config:
                rewards, consumptions = samples_per_config[theta]
                surplus[theta] = self.compute_empirical_surplus_vectorized(
                    rewards, consumptions, p
                )
            if beta is not None:
                surplus[theta] += beta[theta]
        return surplus

    def _objective_V_mix(
        self,
        p: np.ndarray,
        samples_per_config: Dict[int, Tuple[np.ndarray, np.ndarray]],
        b_safe: np.ndarray,
        beta: Optional[np.ndarray] = None
    ) -> float:
        """
        Objective for V^mix: <p, b_safe> + max_theta (g_theta(p) + beta_theta)

        Parameters
        ----------
        p : np.ndarray
            Price vector (d,)
        samples_per_config : Dict[int, Tuple[np.ndarray, np.ndarray]]
            Mapping from config to (rewards, consumptions)
        b_safe : np.ndarray
            Safe budget per period (d,)
        beta : Optional[np.ndarray]
            Confidence bonuses (K,), default None (no optimism)

        Returns
        -------
        float
            Objective value
        """
        envelope = np.max(self._surplus_per_config(p, samples_per_config, beta))
        return np.dot(p, b_safe) + envelope

    def _objective_V_theta(
        self,
        p: np.ndarray,
        rewards: np.ndarray,
        consumptions: np.ndarray,
        b: np.ndarray
    ) -> float:
        """
        Objective for fixed-config oracle: <p, b> + g_theta(p)
        """
        linear_term = np.dot(p, b)
        g_theta = self.compute_empirical_surplus_vectorized(rewards, consumptions, p)
        return linear_term + g_theta

    def _solve_envelope_lp(
        self,
        samples_per_config: Dict[int, Tuple[np.ndarray, np.ndarray]],
        b_eff: np.ndarray,
        configs: List[int],
        beta: Optional[np.ndarray] = None
    ) -> Optional[Tuple[np.ndarray, Optional[np.ndarray]]]:
        """
        Minimize <p, b_eff> + max_{theta in configs} (g_hat_theta(p) + beta_theta)
        exactly over p in [0, P_max]^d via its epigraph linear program.

        The decision vector is x = [p (d), z (1), s (n_total)], with one
        slack s_i per sample of every config in ``configs`` (in that order):

            minimize    <b_eff, p> + z
            subject to  -<a_i, p> - s_i <= -r_i                   (every sample i)
                        mean_{i in theta} s_i - z <= -beta_theta   (every theta)
                        0 <= p <= P_max,  z free,  s >= 0

        At an optimum s_i = (r_i - <a_i, p>)_+ and z equals the envelope, so
        the LP value is the target objective. The negated dual multipliers of
        the envelope rows are non-negative. They sum to one, which is dual
        feasibility for the free variable z, and they form the optimal mixture
        over ``configs``.

        Returns
        -------
        Optional[Tuple[np.ndarray, Optional[np.ndarray]]]
            ``(p, w)`` on success; ``w`` has length len(configs), or is None
            when the solver reports no dual values. Returns ``None`` when the
            LP could not be solved, so callers can fall back to local search.
        """
        d = self.d
        reward_parts: List[np.ndarray] = []
        consumption_parts: List[np.ndarray] = []
        for theta in configs:
            if theta in samples_per_config:
                rewards, consumptions = samples_per_config[theta]
                rewards = np.asarray(rewards, dtype=float).reshape(-1)
                consumptions = np.asarray(consumptions, dtype=float).reshape(rewards.size, d)
            else:
                # No samples for this config: its surplus is identically 0.
                rewards = np.zeros(0)
                consumptions = np.zeros((0, d))
            reward_parts.append(rewards)
            consumption_parts.append(consumptions)

        sizes = [part.size for part in reward_parts]
        n_total = int(sum(sizes))
        n_cfg = len(configs)
        r_all = np.concatenate(reward_parts)
        a_all = np.vstack(consumption_parts)
        s_col = d + 1  # column of the first sample slack (p occupies 0..d-1, z is d)

        # Sample rows: -<a_i, p> - s_i <= -r_i
        nz_rows, nz_cols = np.nonzero(a_all)
        rows = [nz_rows, np.arange(n_total)]
        cols = [nz_cols, s_col + np.arange(n_total)]
        vals = [-a_all[nz_rows, nz_cols], -np.ones(n_total)]

        # Envelope rows: mean(s over theta's samples) - z <= -beta_theta
        offset = 0
        for k, n_k in enumerate(sizes):
            row = n_total + k
            rows.append(np.array([row]))
            cols.append(np.array([d]))
            vals.append(np.array([-1.0]))
            if n_k > 0:
                rows.append(np.full(n_k, row))
                cols.append(s_col + offset + np.arange(n_k))
                vals.append(np.full(n_k, 1.0 / n_k))
            offset += n_k

        A_ub = sparse.csr_matrix(
            (np.concatenate(vals), (np.concatenate(rows), np.concatenate(cols))),
            shape=(n_total + n_cfg, s_col + n_total)
        )
        if beta is None:
            bonus = np.zeros(n_cfg)
        else:
            bonus = np.array([beta[theta] for theta in configs], dtype=float)
        b_ub = np.concatenate([-r_all, -bonus])
        c = np.concatenate([np.asarray(b_eff, dtype=float).reshape(d), [1.0], np.zeros(n_total)])
        bounds = [(0.0, self.P_max)] * d + [(None, None)] + [(0.0, None)] * n_total

        try:
            res = linprog(c, A_ub=A_ub, b_ub=b_ub, bounds=bounds, method='highs')
        except ValueError:
            # e.g. non-finite inputs or a SciPy build without the HiGHS solver
            return None
        if res.status != 0 or res.x is None:
            return None

        p = np.clip(res.x[:d], 0.0, self.P_max)

        w = None
        marginals = getattr(getattr(res, 'ineqlin', None), 'marginals', None)
        if marginals is not None:
            # For a minimization, marginals of <= rows are <= 0; negate them.
            w = np.maximum(-np.asarray(marginals, dtype=float)[n_total:], 0.0)
            total = w.sum()
            w = w / total if total > 0 else None
        return p, w

    def _multistart_minimize(
        self,
        objective: Callable[[np.ndarray], float],
        n_restarts: int
    ) -> np.ndarray:
        """
        Run the fallback local search and return the best price found.

        Runs L-BFGS-B over [0, P_max]^d from p = 0 plus ``n_restarts - 1``
        random starts drawn from [0, P_max / 2]^d using the global NumPy RNG.
        Returns p = 0 if no restart produced a value below +inf.
        """
        bounds = [(0, self.P_max)] * self.d
        best_val = np.inf
        best_p = np.zeros(self.d)
        for restart in range(n_restarts):
            if restart == 0:
                p_init = np.zeros(self.d)
            else:
                p_init = np.random.uniform(0, self.P_max / 2, self.d)

            result = minimize(
                objective,
                p_init,
                method='L-BFGS-B',
                bounds=bounds,
                options={'maxiter': 200}
            )
            if result.fun < best_val:
                best_val = result.fun
                best_p = result.x
        return best_p

    def compute_V_mix(
        self,
        samples_per_config: Dict[int, Tuple[np.ndarray, np.ndarray]],
        b: np.ndarray,
        slack: float = 0.0,
        beta: Optional[np.ndarray] = None,
        return_full: bool = False
    ) -> Tuple[float, np.ndarray, np.ndarray]:
        """
        Compute V^mix = min_p {<p, b_safe> + max_theta (g_theta(p) + beta_theta)}

        The empirical problem is solved exactly as a linear program. The
        optimal mixture w* is the vector of dual multipliers of the envelope
        constraints. If the LP cannot be solved, a multi-start L-BFGS-B
        search is used instead, and w* is spread uniformly over the configs
        attaining the envelope at p_star (absolute tolerance 1e-8).

        Parameters
        ----------
        samples_per_config : Dict[int, Tuple[np.ndarray, np.ndarray]]
            Mapping from config index to (rewards array, consumptions array)
        b : np.ndarray
            Per-period budget (d,)
        slack : float
            Slack factor epsilon, b_safe = (1 - slack) * b
        beta : Optional[np.ndarray]
            Confidence bonuses for UCB (K,)
        return_full : bool
            Ignored; kept for backward compatibility (w* is always returned)

        Returns
        -------
        V_mix : float
            Switching-aware oracle value, evaluated exactly at p_star
        p_star : np.ndarray
            Optimal price (d,)
        w_star : np.ndarray
            Optimal mixture (K,), supported on configs achieving the envelope
            at p_star
        """
        b_safe = (1 - slack) * np.asarray(b, dtype=float)

        def objective(p: np.ndarray) -> float:
            return self._objective_V_mix(p, samples_per_config, b_safe, beta)

        lp_solution = self._solve_envelope_lp(
            samples_per_config, b_safe, list(range(self.K)), beta
        )
        if lp_solution is not None:
            p_star, w_star = lp_solution
        else:
            p_star, w_star = self._multistart_minimize(objective, n_restarts=5), None

        if w_star is None:
            # No dual information: weight uniformly the configs that
            # (numerically) attain the envelope at p_star.
            surplus = self._surplus_per_config(p_star, samples_per_config, beta)
            achieving_max = surplus >= surplus.max() - 1e-8
            w_star = achieving_max.astype(float)
            w_star /= w_star.sum()

        return objective(p_star), p_star, w_star

    def compute_V_star(
        self,
        samples_per_config: Dict[int, Tuple[np.ndarray, np.ndarray]],
        b: np.ndarray
    ) -> Tuple[float, int, np.ndarray, np.ndarray]:
        """
        Compute V* = max_theta min_p {<p, b> + g_theta(p)}

        Each inner minimization is solved exactly as a linear program. A
        multi-start L-BFGS-B search is used if that fails.

        Parameters
        ----------
        samples_per_config : Dict[int, Tuple[np.ndarray, np.ndarray]]
            Mapping from config index to (rewards array, consumptions array)
        b : np.ndarray
            Per-period budget (d,)

        Returns
        -------
        V_star : float
            Best fixed-config oracle value
        best_theta : int
            Index of best fixed config (first index on ties)
        p_per_config : np.ndarray
            Optimal price per config (K, d)
        V_per_config : np.ndarray
            Value per config (K,)
        """
        b = np.asarray(b, dtype=float)
        V_per_config = np.zeros(self.K)
        p_per_config = np.zeros((self.K, self.d))

        for theta in range(self.K):
            if theta not in samples_per_config:
                # Treated as g_theta == 0; value and price are left at 0.
                continue

            rewards, consumptions = samples_per_config[theta]

            def objective(p, rewards=rewards, consumptions=consumptions):
                return self._objective_V_theta(p, rewards, consumptions, b)

            lp_solution = self._solve_envelope_lp(samples_per_config, b, [theta])
            if lp_solution is not None:
                p_theta = lp_solution[0]
            else:
                p_theta = self._multistart_minimize(objective, n_restarts=3)

            V_per_config[theta] = objective(p_theta)
            p_per_config[theta] = p_theta

        best_theta = int(np.argmax(V_per_config))
        V_star = V_per_config[best_theta]

        return V_star, best_theta, p_per_config, V_per_config

    def compute_full_oracle(
        self,
        samples_per_config: Dict[int, Tuple[np.ndarray, np.ndarray]],
        b: np.ndarray,
        slack: float = 0.0
    ) -> "OracleResult":
        """
        Compute both V^mix and V* and return full result.

        Parameters
        ----------
        samples_per_config : Dict[int, Tuple[np.ndarray, np.ndarray]]
            Mapping from config index to (rewards array, consumptions array)
        b : np.ndarray
            Per-period budget (d,)
        slack : float
            Slack factor for safe budget. NOTE: it is applied to V^mix only;
            V* is computed with the full budget ``b``, so ``gap`` compares the
            two oracles at different budgets when slack > 0.

        Returns
        -------
        OracleResult
            Complete oracle computation result
        """
        V_mix, p_star, w_star = self.compute_V_mix(samples_per_config, b, slack)

        V_star, best_theta, p_per_config, V_per_config = self.compute_V_star(
            samples_per_config, b
        )

        return OracleResult(
            V_mix=V_mix,
            V_star=V_star,
            w_star=w_star,
            p_star=p_star,
            p_per_config=p_per_config,
            V_per_config=V_per_config,
            best_fixed_config=best_theta,
            gap=V_mix - V_star
        )


def compute_oracle_from_loader(
    loader,
    n_samples: int = 10000,
    seed: int = 42
) -> OracleResult:
    """
    Convenience function to compute oracle from a data loader.

    Draws ``min(n_samples, loader.T)`` arrivals per configuration using
    ``loader.get_arrival(theta, t)``. Divides ``loader.get_base_budget(1.0)``
    by ``loader.T`` to obtain the per-period budget. Then sets the price box
    bound to P_max = max(R_max, 0) / b_min + 1, where R_max is the largest
    sampled reward and b_min the smallest positive budget entry (1.0 if none
    is positive).

    Parameters
    ----------
    loader : BaseDataLoader
        Data loader exposing K, d, T, get_arrival(regime, t) and
        get_base_budget(scale)
    n_samples : int
        Number of samples to use for estimation
    seed : int
        Random seed; seeds the global NumPy RNG

    Returns
    -------
    OracleResult
        Oracle computation result

    Raises
    ------
    ValueError
        If no samples can be drawn (n_samples <= 0 or loader.T <= 0).
    """
    np.random.seed(seed)

    K = loader.K
    d = loader.d
    n_used = min(n_samples, loader.T)
    if n_used <= 0:
        raise ValueError(
            f"need at least one sample per config; got n_samples={n_samples}, "
            f"loader.T={loader.T}"
        )

    # Collect samples (config-major order, same call order as before)
    samples_per_config = {}
    for theta in range(K):
        arrivals = [loader.get_arrival(theta, t) for t in range(n_used)]
        rewards = np.array([r for r, _ in arrivals])
        consumptions = np.array([a for _, a in arrivals])
        samples_per_config[theta] = (rewards, consumptions)

    b = loader.get_base_budget(1.0) / loader.T  # Per-period budget

    # Price box bound. Negative rewards are clamped to 0 so P_max stays >= 1;
    # otherwise the price bounds (0, P_max) would be invalid.
    R_max = max(float(np.max(rewards)) for rewards, _ in samples_per_config.values())
    positive_b = b[b > 0]
    b_min = np.min(positive_b) if positive_b.size else 1.0
    P_max = max(R_max, 0.0) / b_min + 1.0

    oracle = OracleComputer(K, d, P_max)
    return oracle.compute_full_oracle(samples_per_config, b)


# ============================================================================
# Analytical Oracle for Deterministic Distributions (S1)
# ============================================================================

def compute_S1_oracle_analytical(b: np.ndarray) -> OracleResult:
    """
    Compute oracle analytically for S1: Pure Complementarity.

    S1 Setup (K=4, d=3); every configuration has a single deterministic outcome:
    - theta=0: r=1,   a=(1, 0, 0)
    - theta=1: r=1,   a=(0, 1, 0)
    - theta=2: r=1,   a=(0, 0, 1)
    - theta=3: r=0.9, a=(0.3, 0.3, 0.4)

    Fixed configuration (primal: max r x  s.t.  a x <= b, 0 <= x <= 1):
        x_theta = min(1, min_{j: a_j > 0} b_j / a_j),   V_theta = r_theta * x_theta.
    A dual-optimal price for min_p {<p, b> + (r - <p, a>)_+} is
    p = (r / a_{j*}) e_{j*} on the binding resource j* when x_theta < 1,
    and p = 0 otherwise.

    Switching-aware oracle: with y_theta = w_theta x_theta the primal becomes
        max sum_theta r_theta y_theta  s.t.  sum_theta a_theta y_theta <= b,
                                             sum_theta y_theta <= 1, y >= 0.
    Configs 0-2 earn 1 per unit of time and per unit of resource, while
    theta=3 earns only 0.9 per unit of either. Hence theta=3 is never used and
        V^mix = min(b_0 + b_1 + b_2, 1).
    A dual certificate for min_p {<p, b> + max_theta g_theta(p)} is
    p = (1, 1, 1) when sum(b) < 1 and p = 0 otherwise.

    Example: b = (1/3, 1/3, 1/3) gives V_theta = 1/3 for theta = 0, 1, 2 and
    V_3 = 0.75. Hence V* = 0.75, V^mix = 1 and gap = 0.25.

    Parameters
    ----------
    b : np.ndarray
        Per-period budget (d=3,)

    Returns
    -------
    OracleResult
        Analytical oracle result
    """
    K, d = 4, 3
    b = np.asarray(b, dtype=float).reshape(d)

    rewards = np.array([1.0, 1.0, 1.0, 0.9])
    consumptions = np.array([
        [1.0, 0.0, 0.0],
        [0.0, 1.0, 0.0],
        [0.0, 0.0, 1.0],
        [0.3, 0.3, 0.4],
    ])

    # Fixed-configuration oracles V_theta and their dual prices.
    V_per_config = np.zeros(K)
    p_per_config = np.zeros((K, d))
    for theta in range(K):
        a = consumptions[theta]
        used = np.flatnonzero(a > 0)
        ratios = b[used] / a[used]          # fraction of arrivals each resource can fund
        j_star = used[np.argmin(ratios)]    # binding resource
        accept_rate = min(1.0, float(ratios.min()))
        # The cap is the config's own reward r (0.9 for theta=3), not 1.
        V_per_config[theta] = rewards[theta] * accept_rate
        if accept_rate < 1.0:
            # Pricing the binding resource at r / a_j makes the outcome break
            # even, so the dual objective equals <p, b> = r * b_j / a_j.
            p_per_config[theta, j_star] = rewards[theta] / a[j_star]

    V_star = float(np.max(V_per_config))
    best_theta = int(np.argmax(V_per_config))

    # Switching-aware oracle: mix the orthogonal configs in proportion to b.
    total_b = b[0] + b[1] + b[2]
    if total_b > 0:
        w_star = np.array([b[0] / total_b, b[1] / total_b, b[2] / total_b, 0.0])
    else:
        w_star = np.array([1/3, 1/3, 1/3, 0.0])

    V_mix = min(total_b, 1.0)  # resource-limited (sum b) or time-limited (1)

    # Dual price attaining V^mix: at p = (1, 1, 1) every surplus is 0, giving
    # <p, b> = sum(b). At p = 0 the envelope is 1. Pick whichever is smaller.
    p_star = np.ones(d) if total_b < 1.0 else np.zeros(d)

    return OracleResult(
        V_mix=V_mix,
        V_star=V_star,
        w_star=w_star,
        p_star=p_star,
        p_per_config=p_per_config,
        V_per_config=V_per_config,
        best_fixed_config=best_theta,
        gap=V_mix - V_star
    )
