"""
Multi-Constraint Constrained Online Convex Optimization (MC-COCO)
Algorithm Implementations

Implements:
- Algorithm MC-1: Multi-Constraint Constrained Expert (Theorem 2)
- Algorithm MC-3: Smooth Convex MC-COCO with Adaptive OGD (Theorem 4)
- Algorithm MC-1-Hetero: Heterogeneous Prioritization (Theorem 5)
- Baselines: Naive independent constraint handling, drift-plus-penalty
"""

import numpy as np
from typing import Optional, List, Tuple


# ============================================================
# Algorithm MC-1: Multi-Constraint Constrained Expert
# ============================================================

class MC1_ConstrainedExperts:
    """
    Algorithm MC-1 from Theorem 2.

    Maintains expert probabilities via Adaptive Hedge with exponential
    Lyapunov surrogate cost encoding all K constraints.

    Parameters
    ----------
    N : int
        Number of experts
    K : int
        Number of constraints
    T : int
        Time horizon
    beta : float
        Regret-CCV trade-off parameter in [0, 1]
    c : float
        Universal constant (default 10)

    Notes
    -----
    The Lyapunov parameter is ``lam = T**-(1 - beta) / (2 * c * log(N))``,
    so ``N`` must be at least 2 for ``lam`` to be finite.
    """

    def __init__(self, N: int, K: int, T: int, beta: float = 0.5, c: float = 10.0):
        self.N = N
        self.K = K
        self.T = T
        self.beta = beta
        self.c = c

        # Lyapunov parameter. T is cast to float so that a NumPy integer
        # horizon combined with an integer negative exponent cannot raise.
        self.lam = float(T) ** (-(1 - beta)) / (2 * c * np.log(N))

        # State
        self.Q = np.zeros(K)                  # Cumulative constraint violations
        self.L_cum = np.zeros(N)              # Cumulative surrogate loss per expert
        self.L_tilde = 0.0                    # Algorithm's cumulative surrogate loss
        self.G_prev = 1 + K * self.lam        # G_0 (value of G_t at Q = 0)
        self.gamma = np.exp(self.lam)

        # History
        self.regret_history = []
        self.ccv_history = []                  # K-dimensional CCV at each step
        self.cost_history = []

    def _compute_p(self) -> np.ndarray:
        """Compute the expert probability distribution from current state.

        The learning rate is derived from the previous bound ``G_prev`` and
        the accumulated surrogate loss ``L_tilde``; the distribution is a
        softmax of ``-eta * L_cum``.
        """
        G_prev = self.G_prev
        eta = (1.0 / np.sqrt(G_prev)) * np.sqrt(
            np.log(self.N) / (self.L_tilde + self.gamma * G_prev)
        )
        # Numerical stability: subtract max for softmax
        logits = -eta * self.L_cum
        logits -= logits.max()
        weights = np.exp(logits)
        return weights / weights.sum()

    def step(self, f_t: np.ndarray, g_t: np.ndarray) -> np.ndarray:
        """
        Execute one round of Algorithm MC-1.

        Parameters
        ----------
        f_t : np.ndarray, shape (N,)
            Cost for each expert, values in [0, 1]
        g_t : np.ndarray, shape (K, N)
            Constraint values for each constraint and expert, values in [0, 1]

        Returns
        -------
        p_t : np.ndarray, shape (N,)
            Expert probability distribution

        Notes
        -----
        Inputs are converted to float arrays and never modified, so
        integer-valued or list inputs are accepted.
        """
        f_t = np.asarray(f_t, dtype=float)
        g_t = np.asarray(g_t, dtype=float)

        # Step 2: select expert probability
        p_t = self._compute_p()

        # Step 4: update CCV (expected violation of every constraint under p_t)
        self.Q += g_t @ p_t

        # Step 5: update G_t
        growth = np.exp(self.lam * self.Q)
        self.G_prev = 1.0 + self.lam * np.sum(growth)

        # Step 6: compute surrogate cost vector. Built out-of-place so the
        # result is always float regardless of the dtype of f_t.
        f_hat = f_t + (self.lam * growth) @ g_t

        # Step 7: update cumulative losses
        self.L_tilde += np.dot(f_hat, p_t)
        self.L_cum += f_hat

        # Track cost
        self.cost_history.append(np.dot(f_t, p_t))
        self.ccv_history.append(self.Q.copy())

        return p_t

    def get_regret(self, comparator_costs: np.ndarray) -> float:
        """Return the accumulated expected cost minus ``comparator_costs``."""
        return sum(self.cost_history) - comparator_costs

    def get_per_constraint_ccv(self) -> np.ndarray:
        """Return a copy of the final per-constraint CCV vector."""
        return self.Q.copy()

    def get_max_ccv(self) -> float:
        """Return max per-constraint CCV."""
        return np.max(self.Q)

    def get_total_ccv(self) -> float:
        """Return total CCV (sum over constraints)."""
        return np.sum(self.Q)


# ============================================================
# Algorithm MC-3: Smooth Convex MC-COCO (Adaptive OGD)
# ============================================================

class MC3_SmoothConvex:
    """
    Algorithm MC-3 from Theorem 4.

    Uses adaptive OGD on smooth convex surrogate costs with
    exponential Lyapunov encoding of K constraints.

    Parameters
    ----------
    d : int
        Dimension
    K : int
        Number of constraints
    T : int
        Time horizon
    D : float
        Diameter of decision set
    M : float
        Smoothness constant
    beta : float
        Trade-off parameter

    Notes
    -----
    The Lyapunov parameter is ``lam = T**-(1 - beta) / (8 * D**2 * M)``.
    Iterates are kept inside the L2 ball of radius ``D`` centered at the
    origin (see ``_project``).
    """

    def __init__(self, d: int, K: int, T: int, D: float = 1.0,
                 M: float = 1.0, beta: float = 0.5):
        self.d = d
        self.K = K
        self.T = T
        self.D = D
        self.M = M
        self.beta = beta

        # Lyapunov parameter. T is cast to float so that a NumPy integer
        # horizon combined with an integer negative exponent cannot raise.
        self.lam = float(T) ** (-(1 - beta)) / (8 * D**2 * M)

        # State
        self.x = np.zeros(d)          # Current action (initialized at origin)
        self.Q = np.zeros(K)           # Cumulative constraint violations
        self.Sigma_grad = 1e-8         # Gradient accumulator (small init for stability)

        # History
        self.cost_history = []
        self.ccv_history = []

    def _project(self, x: np.ndarray) -> np.ndarray:
        """Project onto L2 ball of radius D centered at origin."""
        norm = np.linalg.norm(x)
        if norm > self.D:
            x = x * (self.D / norm)
        return x

    def step(self, f_val: float, f_grad: np.ndarray,
             g_vals: np.ndarray, g_grads: np.ndarray) -> np.ndarray:
        """
        Execute one round of Algorithm MC-3.

        Parameters
        ----------
        f_val : float
            f_t(x_t)
        f_grad : np.ndarray, shape (d,)
            Gradient of f_t at x_t
        g_vals : np.ndarray, shape (K,)
            g_{k,t}(x_t) for each constraint
        g_grads : np.ndarray, shape (K, d)
            Gradient of g_{k,t} at x_t for each constraint

        Returns
        -------
        x_next : np.ndarray, shape (d,)
            Next action

        Notes
        -----
        Array inputs are converted to float and never modified, so
        integer-valued or list inputs are accepted.
        """
        f_grad = np.asarray(f_grad, dtype=float)
        g_vals = np.asarray(g_vals, dtype=float)
        g_grads = np.asarray(g_grads, dtype=float)

        # Step 3: update CCV
        self.Q += g_vals

        # Step 5: compute surrogate gradient. Built out-of-place so the
        # result is always float regardless of the dtype of f_grad.
        penalty_weights = self.lam * np.exp(self.lam * self.Q)
        grad = f_grad + penalty_weights @ g_grads

        # Step 6: update gradient accumulator and step size
        self.Sigma_grad += np.dot(grad, grad)
        eta = self.D / np.sqrt(2 * self.Sigma_grad)

        # Step 7: OGD update
        self.x = self._project(self.x - eta * grad)

        # Track
        self.cost_history.append(f_val)
        self.ccv_history.append(self.Q.copy())

        return self.x.copy()

    def get_action(self) -> np.ndarray:
        """Return a copy of the current action."""
        return self.x.copy()

    def get_per_constraint_ccv(self) -> np.ndarray:
        """Return a copy of the per-constraint CCV vector."""
        return self.Q.copy()

    def get_max_ccv(self) -> float:
        """Return max per-constraint CCV."""
        return np.max(self.Q)

    def get_total_ccv(self) -> float:
        """Return total CCV (sum over constraints)."""
        return np.sum(self.Q)


# ============================================================
# Algorithm MC-1-Hetero: Heterogeneous Prioritization (Theorem 5)
# ============================================================

class MC1_Heterogeneous(MC1_ConstrainedExperts):
    """
    Heterogeneous constraint prioritization from Theorem 5.

    Each constraint k gets weight alpha_k, with Lyapunov parameter
    lambda_k = alpha_k * Lambda.

    Parameters
    ----------
    N, K, T, beta, c : same as MC1
    alphas : array-like, shape (K,), optional
        Per-constraint priority weights in (0, 1]. Defaults to all ones.

    Raises
    ------
    ValueError
        If ``alphas`` does not have shape ``(K,)``.
    """

    def __init__(self, N: int, K: int, T: int, beta: float = 0.5,
                 c: float = 10.0, alphas: Optional[np.ndarray] = None):
        # Do not call super().__init__ directly; we override lambda handling
        self.N = N
        self.K = K
        self.T = T
        self.beta = beta
        self.c = c

        # Per-constraint weights. np.array copies the input and accepts
        # lists/tuples as well as arrays.
        if alphas is None:
            alphas = np.ones(K)
        self.alphas = np.array(alphas, dtype=float)
        if self.alphas.shape != (K,):
            raise ValueError(
                f"alphas must have shape ({K},), got {self.alphas.shape}")

        # Base Lyapunov parameter. T is cast to float so that a NumPy integer
        # horizon combined with an integer negative exponent cannot raise.
        self.Lambda = float(T) ** (-(1 - beta)) / (2 * c * np.log(N))
        self.lam_k = self.alphas * self.Lambda  # Per-constraint lambda_k

        # State
        self.Q = np.zeros(K)
        self.L_cum = np.zeros(N)
        self.L_tilde = 0.0
        # G_bar uses Lambda (not individual lambda_k)
        self.G_prev = 1 + self.Lambda * np.sum(np.exp(self.lam_k * self.Q))
        self.gamma = np.exp(self.Lambda)

        # History
        self.regret_history = []
        self.ccv_history = []
        self.cost_history = []

    def step(self, f_t: np.ndarray, g_t: np.ndarray) -> np.ndarray:
        """
        Execute one round with heterogeneous prioritization.
        Uses G_bar_t (with Lambda) as cost upper bound sequence.

        Parameters
        ----------
        f_t : np.ndarray, shape (N,)
            Cost for each expert.
        g_t : np.ndarray, shape (K, N)
            Constraint values for each constraint and expert.

        Returns
        -------
        p_t : np.ndarray, shape (N,)
            Expert probability distribution used this round.

        Notes
        -----
        Inputs are converted to float arrays and never modified, so
        integer-valued or list inputs are accepted.
        """
        f_t = np.asarray(f_t, dtype=float)
        g_t = np.asarray(g_t, dtype=float)

        # Step 2: select expert probability
        p_t = self._compute_p()

        # Step 4: update CCV
        self.Q += g_t @ p_t

        # Step 5: update G_bar_t (using Lambda, not individual lambda_k)
        growth = np.exp(self.lam_k * self.Q)
        self.G_prev = 1.0 + self.Lambda * np.sum(growth)

        # Step 6: compute surrogate cost vector (using individual lambda_k).
        # Built out-of-place so the result is float for any dtype of f_t.
        f_hat = f_t + (self.lam_k * growth) @ g_t

        # Step 7: update cumulative losses
        self.L_tilde += np.dot(f_hat, p_t)
        self.L_cum += f_hat

        # Track
        self.cost_history.append(np.dot(f_t, p_t))
        self.ccv_history.append(self.Q.copy())

        return p_t


# ============================================================
# Baseline: Naive Independent Constraint Handling
# ============================================================

class NaiveIndependent:
    """
    Baseline that handles every constraint in isolation.

    One single-constraint learner is kept per constraint; each round the
    K resulting expert distributions are averaged into the played
    distribution. Per the original author's note, this baseline is
    expected to show linear K-dependence.

    Parameters
    ----------
    N : int
        Number of experts
    K : int
        Number of constraints
    T : int
        Time horizon
    beta : float
        Trade-off parameter
    c : float
        Constant entering the Lyapunov parameter
    """

    def __init__(self, N: int, K: int, T: int, beta: float = 0.5, c: float = 10.0):
        self.N, self.K, self.T = N, K, T
        self.beta, self.c = beta, c

        # Shared Lyapunov parameter for all per-constraint learners
        self.lam = T ** (-(1 - beta)) / (2 * c * np.log(N))
        self.gamma = np.exp(self.lam)

        # Per-constraint state (row k belongs to constraint k)
        self.Q = np.zeros(K)
        self.L_cum = np.zeros((K, N))
        self.L_tilde = np.zeros(K)
        self.G_prev = np.full(K, 1.0 + self.lam)

        # Per-round records
        self.cost_history = []
        self.ccv_history = []

    def _vote(self, k: int) -> np.ndarray:
        """Return the expert distribution proposed by learner ``k``."""
        bound = self.G_prev[k]
        inv_root = 1.0 / np.sqrt(bound)
        ratio = np.log(self.N) / (self.L_tilde[k] + self.gamma * bound)
        eta = inv_root * np.sqrt(ratio)

        scores = -eta * self.L_cum[k]
        # Shift by the maximum before exponentiating (softmax stability)
        unnormalized = np.exp(scores - scores.max())
        return unnormalized / unnormalized.sum()

    def step(self, f_t: np.ndarray, g_t: np.ndarray) -> np.ndarray:
        """Average the K per-constraint votes, then update every learner.

        Parameters
        ----------
        f_t : np.ndarray, shape (N,)
            Cost for each expert.
        g_t : np.ndarray, shape (K, N)
            Constraint values for each constraint and expert.

        Returns
        -------
        np.ndarray, shape (N,)
            The averaged expert distribution played this round.
        """
        vote_sum = np.zeros(self.N)
        for k in range(self.K):
            vote_sum += self._vote(k)
        p_t = vote_sum / self.K

        # Each learner only sees its own constraint in its surrogate loss
        for k in range(self.K):
            g_row = g_t[k]
            self.Q[k] += np.dot(g_row, p_t)

            penalty = self.lam * np.exp(self.lam * self.Q[k])
            self.G_prev[k] = 1.0 + penalty

            surrogate = f_t + penalty * g_row
            self.L_tilde[k] += np.dot(surrogate, p_t)
            self.L_cum[k] += surrogate

        self.cost_history.append(np.dot(f_t, p_t))
        self.ccv_history.append(self.Q.copy())
        return p_t

    def get_per_constraint_ccv(self) -> np.ndarray:
        """Return a copy of the per-constraint CCV vector."""
        return self.Q.copy()

    def get_max_ccv(self) -> float:
        """Return the largest per-constraint CCV."""
        return np.max(self.Q)

    def get_total_ccv(self) -> float:
        """Return the CCV summed over constraints."""
        return np.sum(self.Q)


# ============================================================
# Adversary: Generate cost and constraint functions
# ============================================================

class AdversaryExpert:
    """
    Adversary for the constrained experts setting.

    Generates:
    - Cost vectors f_t in [0,1]^N
    - Constraint matrices g_t in [0,1]^{K x N}
    - One expert i* that is always feasible: g_{k,t}(i*) = 0 for all k,t

    Parameters
    ----------
    N : int
        Number of experts
    K : int
        Number of constraints
    feasible_expert : int
        Index of the always-feasible expert
    adversary_type : str
        'conflicting' - constraints conflict with each other (default)
        'random' - uniformly random costs/constraints
        'asymmetric' - constraints have different inherent difficulty
        'uniform_hard' - every non-feasible expert violates every constraint
    seed : int
        Random seed
    constraint_difficulty : array-like, optional
        Per-constraint difficulty in (0, 1]. Used by the 'asymmetric'
        adversary, where it is the per-expert probability of a high
        violation. Must have length K (or length 1 to share one value
        across all constraints).

    Raises
    ------
    ValueError
        From the constructor if ``adversary_type`` is 'asymmetric' and
        ``constraint_difficulty`` has an incompatible shape; from
        ``generate`` if ``adversary_type`` is not one of the types above.
    """

    def __init__(self, N: int, K: int, feasible_expert: int = 0,
                 adversary_type: str = 'conflicting', seed: int = 42,
                 constraint_difficulty: Optional[np.ndarray] = None):
        self.N = N
        self.K = K
        self.feasible_expert = feasible_expert
        self.adversary_type = adversary_type
        self.rng = np.random.RandomState(seed)
        # Coerce to a float array so lists/tuples work with the broadcasting
        # done in _gen_asymmetric.
        if constraint_difficulty is not None:
            constraint_difficulty = np.atleast_1d(
                np.asarray(constraint_difficulty, dtype=float))
        self.constraint_difficulty = constraint_difficulty
        self.t = 0  # Round counter

        # Pre-compute expert structure where the adversary type needs it
        if adversary_type == 'conflicting':
            self._init_conflicting()
        elif adversary_type == 'asymmetric':
            self._init_asymmetric()
        # 'random' and 'uniform_hard' need no special init

    def _init_conflicting(self):
        """Initialize conflicting adversary.

        Key design: for each constraint k, a different subset of experts
        has LOW constraint values (good for k), and HIGH values for
        other constraints. This creates a fundamental tension —
        the only way to satisfy all constraints is the feasible expert i*.

        Builds ``self.group_mask`` with ``group_mask[k, j] = True`` when
        expert j belongs to group k. Non-feasible experts are split into
        consecutive groups of ``max(1, (number of non-feasible) // K)``
        experts; the last group takes any remainder, and groups may be
        empty when there are fewer non-feasible experts than constraints.
        """
        N, K = self.N, self.K
        non_feas = np.flatnonzero(np.arange(N) != self.feasible_expert)
        self.group_mask = np.zeros((K, N), dtype=bool)
        group_size = max(1, len(non_feas) // K)
        for k in range(K):
            start = k * group_size
            stop = start + group_size if k < K - 1 else len(non_feas)
            # Slicing clamps to the available experts (possibly empty)
            self.group_mask[k, non_feas[start:stop]] = True

    def _init_asymmetric(self):
        """Initialize asymmetric adversary for heterogeneous priority testing.

        Each constraint k has a different difficulty level.
        Higher difficulty = more experts violate this constraint heavily.
        When no difficulty is supplied, defaults to ``1 / 2**k``.
        """
        if self.constraint_difficulty is None:
            # Default: geometric decrease
            self.constraint_difficulty = np.array(
                [1.0 / (2**k) for k in range(self.K)])
        difficulty = self.constraint_difficulty
        if difficulty.ndim != 1 or difficulty.shape[0] not in (1, self.K):
            raise ValueError(
                "constraint_difficulty must have length K "
                f"({self.K}), got shape {difficulty.shape}")

    def generate(self) -> Tuple[np.ndarray, np.ndarray]:
        """Generate one round of cost and constraint vectors.

        Returns
        -------
        (f_t, g_t) : tuple of np.ndarray
            ``f_t`` has shape (N,), ``g_t`` has shape (K, N).

        Raises
        ------
        ValueError
            If ``adversary_type`` is not recognized.
        """
        self.t += 1

        if self.adversary_type == 'conflicting':
            return self._gen_conflicting()
        elif self.adversary_type == 'random':
            return self._gen_random()
        elif self.adversary_type == 'asymmetric':
            return self._gen_asymmetric()
        elif self.adversary_type == 'uniform_hard':
            return self._gen_uniform_hard()
        else:
            raise ValueError(f"Unknown adversary type: {self.adversary_type}")

    def _gen_random(self) -> Tuple[np.ndarray, np.ndarray]:
        """Simple random adversary: uniform costs and violations in [0, 1)."""
        f_t = self.rng.uniform(0, 1, self.N)
        g_t = self.rng.uniform(0, 1, (self.K, self.N))
        g_t[:, self.feasible_expert] = 0.0
        return f_t, g_t

    def _gen_uniform_hard(self) -> Tuple[np.ndarray, np.ndarray]:
        """Uniform-hard adversary: all non-feasible experts violate all constraints.

        No constraint-specific 'good expert groups'. Every non-feasible expert
        has high violation on EVERY constraint (~0.6-0.9). The only safe choice
        is the feasible expert (expensive, cost ~0.5).

        This forces the algorithm to balance ALL constraints simultaneously.
        With heterogeneous alpha_k, higher alpha_k constraints get stronger
        Lyapunov penalty and thus lower CCV. Lower alpha_k constraints
        have weaker penalty and accumulate more violation.

        This is the ideal setting to test Theorem 5 (CCV_k ~ 1/alpha_k).
        """
        N, K = self.N, self.K

        # All non-feasible experts are cheap but violate all constraints
        f_t = self.rng.uniform(0.10, 0.25, N)
        f_t[self.feasible_expert] = 0.5 + self.rng.uniform(-0.05, 0.05)

        # All non-feasible experts have high violation on ALL constraints
        g_t = self.rng.uniform(0.6, 0.9, (K, N))
        g_t[:, self.feasible_expert] = 0.0

        f_t = np.clip(f_t, 0.0, 1.0)
        g_t = np.clip(g_t, 0.0, 1.0)
        return f_t, g_t

    def _gen_conflicting(self) -> Tuple[np.ndarray, np.ndarray]:
        """Conflicting constraints adversary (vectorized).

        - Expert i* has cost in [0.45, 0.55] and g=0 (feasible but expensive)
        - Experts in group k have cost in [0.10, 0.20] (cheap!) but:
            - g_{k,t} in [0, 0.1] (low for their own constraint)
            - g_{j,t} in [0.5, 1.0] for j ≠ k (high for other constraints)
        - Any expert outside every group keeps a cost in [0.25, 0.35]

        The order of random draws is fixed so that a given seed always
        reproduces the same sequence.
        """
        N, K = self.N, self.K
        f_t = np.full(N, 0.3) + self.rng.uniform(-0.05, 0.05, N)

        # Default: high violation for all experts on all constraints
        g_t = self.rng.uniform(0.5, 1.0, (K, N))

        # For each constraint k, experts in group k get low violation
        low_vals = self.rng.uniform(0.0, 0.1, (K, N))
        g_t[self.group_mask] = low_vals[self.group_mask]

        # Feasible expert: expensive but feasible
        f_t[self.feasible_expert] = 0.5 + self.rng.uniform(-0.05, 0.05)
        g_t[:, self.feasible_expert] = 0.0

        # Group experts get lower cost
        any_group = self.group_mask.any(axis=0)  # (N,) mask of experts in any group
        low_costs = self.rng.uniform(0.10, 0.20, N)
        f_t[any_group] = low_costs[any_group]

        f_t = np.clip(f_t, 0.0, 1.0)
        g_t = np.clip(g_t, 0.0, 1.0)
        return f_t, g_t

    def _gen_asymmetric(self) -> Tuple[np.ndarray, np.ndarray]:
        """Asymmetric constraints (vectorized): different constraints have different difficulty.

        For constraint k, each expert independently receives a high
        violation (in [0.5, 1.0]) with probability ``constraint_difficulty[k]``
        and a low violation (in [0, 0.15]) otherwise.
        """
        N, K = self.N, self.K
        d = self.constraint_difficulty

        f_t = self.rng.uniform(0.1, 0.4, N)
        f_t[self.feasible_expert] = 0.5 + self.rng.uniform(-0.05, 0.05)

        # Generate mask: high violation with probability d[k] per expert
        mask = self.rng.uniform(0, 1, (K, N)) < d[:, None]  # (K, N) bool
        high_vals = self.rng.uniform(0.5, 1.0, (K, N))
        low_vals = self.rng.uniform(0.0, 0.15, (K, N))
        g_t = np.where(mask, high_vals, low_vals)
        g_t[:, self.feasible_expert] = 0.0

        f_t = np.clip(f_t, 0.0, 1.0)
        g_t = np.clip(g_t, 0.0, 1.0)
        return f_t, g_t


class AdversarySmooth:
    """
    Adversary for smooth convex setting.

    Each round produces a quadratic cost and K constraint functions that
    pull in opposite directions:
    - the cost is smallest at a random target v_t whose norm lies between
      0.7*D and 0.95*D, i.e. away from the origin;
    - every constraint vanishes at the origin and grows with the
      projection of x onto a fixed direction w_k.

    Cost: f_t(x) = 0.5 * a * ||x - v_t||^2 / normalizer
    Constraint: g_{k,t}(x) = c_k * (w_k^T x)^2 / D^2 + b_k * max(0, w_k^T x) / D
    (reported value clipped to [0, 1]; the reported gradient is that of the
    unclipped expression).

    The directions w_k are drawn once at construction as normalized
    Gaussian vectors and reused for every round.

    Parameters
    ----------
    d : int
        Dimension
    K : int
        Number of constraints
    D : float
        Radius of L2 ball
    M : float
        Smoothness constant (upper end of the range the curvature a is drawn from)
    seed : int
        Random seed
    """

    def __init__(self, d: int, K: int, D: float = 1.0, M: float = 1.0, seed: int = 42):
        self.d, self.K = d, K
        self.D, self.M = D, M
        self.rng = np.random.RandomState(seed)
        self.t = 0
        self.last_f_at_origin = 0.0  # f_t(0) of the latest round

        # One fixed unit-norm direction per constraint
        self.w_fixed = []
        for _ in range(K):
            raw = self.rng.randn(d)
            self.w_fixed.append(raw / (np.linalg.norm(raw) + 1e-10))

    def generate(self, x: np.ndarray) -> Tuple[float, np.ndarray, np.ndarray, np.ndarray]:
        """
        Evaluate this round's cost and constraints at the action ``x``.

        Parameters
        ----------
        x : np.ndarray, shape (d,)
            Current action.

        Returns
        -------
        (f_val, f_grad, g_vals, g_grads)
            Cost value, cost gradient (d,), constraint values (K,) and
            constraint gradients (K, d).
        """
        self.t += 1
        radius = self.D

        # --- cost: scaled squared distance to a random target -----------
        curvature = self.rng.uniform(0.5, self.M)
        target = self.rng.randn(self.d)
        shrink = self.rng.uniform(0.7, 0.95)
        target = target / (np.linalg.norm(target) + 1e-10) * radius * shrink

        offset = x - target
        normalizer = max(2 * curvature * radius**2, 1.0)
        f_val = 0.5 * curvature * np.sum(offset**2) / normalizer
        f_grad = curvature * offset / normalizer

        # Remember f_t(0) so callers can compute regret against the origin
        self.last_f_at_origin = 0.5 * curvature * np.sum(target**2) / normalizer

        # --- constraints: quadratic plus one-sided linear term ----------
        g_vals = np.zeros(self.K)
        g_grads = np.zeros((self.K, self.d))

        for k in range(self.K):
            direction = self.w_fixed[k]
            quad_coef = self.rng.uniform(0.3, 0.8)
            lin_coef = self.rng.uniform(0.3, 0.7)

            proj = np.dot(direction, x)
            pushes_out = proj > 0

            quad_val = quad_coef * proj**2 / radius**2
            quad_grad = 2 * quad_coef * proj * direction / radius**2

            # The linear term is only active for positive projections
            lin_val = lin_coef * proj / radius if pushes_out else 0.0
            lin_grad = (lin_coef * direction / radius if pushes_out
                        else np.zeros(self.d))

            g_grads[k] = quad_grad + lin_grad
            # Reported value is clipped to [0, 1]
            g_vals[k] = min(max(quad_val + lin_val, 0.0), 1.0)

        return f_val, f_grad, g_vals, g_grads

    def get_last_f_at_origin(self) -> float:
        """Return f_t(0) for the most recent round. Call after generate()."""
        return self.last_f_at_origin


# ============================================================
# Theoretical bound computation
# ============================================================

def theoretical_ccv_bound_ce(T: int, K: int, N: int, beta: float, c: float = 10.0) -> float:
    """Evaluate the per-constraint CCV bound of Theorem 2.

    Returns ``2c * log(N) * T**(1 - beta) * log(C0 * (K + T + log(N)))``
    with ``C0 = 8c``.
    """
    log_n = np.log(N)
    leading = 2 * c * log_n * T**(1 - beta)
    inner = (8 * c) * (K + T + log_n)
    return leading * np.log(inner)


def theoretical_regret_bound_ce(T: int, K: int, N: int, beta: float, c: float = 10.0) -> float:
    """Evaluate the regret bound of Theorem 2.

    Returns ``c * sqrt(T * log(N)) + (c / 4) * T**beta + c * log(N) + K``.
    """
    log_n = np.log(N)
    hedge_term = c * np.sqrt(T * log_n)
    tradeoff_term = (c / 4) * T**beta
    return hedge_term + tradeoff_term + c * log_n + K


def theoretical_ccv_bound_smooth(T: int, K: int, D: float, M: float,
                                  beta: float) -> float:
    """Evaluate the per-constraint CCV bound of Theorem 4.

    Returns ``8 * D**2 * M * T**(1 - beta) * log(C1 * (K + T + D*sqrt(M*T) + D**2 * M))``
    with ``C1 = 20``.
    """
    leading = 8 * D**2 * M * T**(1 - beta)
    inner = K + T + D * np.sqrt(M * T) + D**2 * M
    return leading * np.log(20 * inner)


def theoretical_regret_bound_smooth(T: int, K: int, D: float, M: float,
                                     beta: float) -> float:
    """Evaluate the regret bound of Theorem 4.

    Returns ``4 * D * sqrt(M * T) + T**beta + 4 * D**2 * M + K``.
    """
    ogd_term = 4 * D * np.sqrt(M * T)
    tradeoff_term = T**beta
    curvature_term = 4 * D**2 * M
    return ogd_term + tradeoff_term + curvature_term + K
