"""
LA-COCO: Learning-Augmented Online Convex Optimization with Adversarial Constraints
Algorithm Implementations — Revised for meaningful experiments

Key insight: The theory's Lyapunov mechanism is very effective, so we need
adversaries that create genuine cost-constraint tension to see differences
between sub-policies A and B.

Design principle: Use time-varying constraint halfspaces that rotate,
creating a scenario where predicting the NEXT constraint direction helps
the algorithm pre-position itself to avoid future violations.
"""

import numpy as np
from typing import Optional, Tuple, List, Dict


def project_l2_ball(x: np.ndarray, D: float) -> np.ndarray:
    """Return the Euclidean projection of ``x`` onto {y : ||y||_2 <= D}.

    A point already inside the ball comes back as a fresh copy (never the
    input object itself). A point outside is rescaled radially onto the
    sphere of radius ``D``.
    """
    length = np.linalg.norm(x)
    return x * (D / length) if length > D else x.copy()


# ============================================================
# Sub-policy A: Baseline (mu=0, no prediction)
# ============================================================

class SubPolicyA:
    """Sub-policy A: OGD on surrogate cost WITHOUT prediction (mu=0).

    Each call to ``step`` for round t:
      1. adds beta * max(g_t(x_t), 0) to the queue ``Q`` (so ``Q`` already
         includes the current round when Phi' is evaluated),
      2. forms the surrogate gradient
             V * (beta * grad f_t) + Phi'(Q) * (beta * grad g_t) * [g_t > 0],
      3. takes a projected gradient step onto the L2 ball of radius ``D``.

    Parameters by ``setting``:
      'strongly_convex': beta = 1, V = 64 G^2 ln(e T) / alpha, Phi'(Q) = 2 Q,
                         step size 1 / (V * alpha * t).
      'convex':          beta = 1 / (2 G D), V = 1,
                         Phi'(Q) = lam * exp(lam * Q), lam = 1 / (2 sqrt(T)),
                         step size D / sqrt(2 * sum_s ||surrogate grad_s||^2).

    Raises:
        ValueError: if ``setting`` is neither 'strongly_convex' nor 'convex'.
    """

    def __init__(self, d: int, T: int, G: float = 1.0, D: float = 1.0,
                 alpha: float = 1.0, setting: str = 'strongly_convex'):
        # Only these two settings are self-consistent. Any other string used
        # to take the convex branch below, but step() only accumulates
        # grad_sq_sum when setting == 'convex'. The step size then stayed at
        # D / sqrt(2e-8) on every round.
        if setting not in ('strongly_convex', 'convex'):
            raise ValueError(
                "setting must be 'strongly_convex' or 'convex', "
                f"got {setting!r}")
        self.d = d
        self.T = T
        self.G = G
        self.D = D
        self.alpha = alpha
        self.setting = setting

        if setting == 'strongly_convex':
            self.beta_scale = 1.0
            self.V = 64 * G**2 * np.log(T * np.e) / alpha
        else:
            self.beta_scale = 1.0 / (2 * G * D)
            self.V = 1.0
            self._lambda = 1.0 / (2 * np.sqrt(T))

        self.x = np.zeros(d)              # current action, starts at the origin
        self.Q = 0.0                      # queue: running sum of beta * max(g_t, 0)
        self.t = 0                        # number of rounds processed so far
        self.grad_sq_sum = 1e-8           # convex step-size accumulator; 1e-8 avoids 1/0
        self.cost_history = []            # f_t(x_t) per round
        self.constraint_violations = []   # max(g_t(x_t), 0) per round

    def _phi_prime(self, q: float) -> float:
        """Derivative of the Lyapunov potential: 2q, or lam * exp(lam * q) (convex)."""
        if self.setting == 'strongly_convex':
            return 2.0 * q
        else:
            return self._lambda * np.exp(self._lambda * q)

    def _get_step_size(self) -> float:
        """Step size for the current round (see class docstring)."""
        if self.setting == 'strongly_convex':
            # max(t, 1) guards the t == 0 case before any step.
            return 1.0 / (self.V * self.alpha * max(self.t, 1))
        else:
            return self.D / np.sqrt(2 * self.grad_sq_sum)

    def step(self, f_val, f_grad, g_val, g_grad):
        """Process one round and return a copy of the next action.

        Args:
            f_val: cost f_t(x_t); only recorded in ``cost_history``.
            f_grad: gradient of f_t at x_t.
            g_val: constraint value g_t(x_t).
            g_grad: gradient of g_t at x_t; used only when g_val > 0.
        """
        self.t += 1
        cv = max(g_val, 0.0)
        self.constraint_violations.append(cv)
        tilde_g = self.beta_scale * cv
        self.Q += tilde_g

        tilde_f_grad = self.beta_scale * f_grad
        # (Sub)gradient of max(g, 0): zero when the constraint is satisfied.
        tilde_g_grad = self.beta_scale * g_grad if g_val > 0 else np.zeros(self.d)
        phi_prime = self._phi_prime(self.Q)
        surrogate_grad = self.V * tilde_f_grad + phi_prime * tilde_g_grad

        if self.setting == 'convex':
            self.grad_sq_sum += np.dot(surrogate_grad, surrogate_grad)
        eta = self._get_step_size()
        self.x = project_l2_ball(self.x - eta * surrogate_grad, self.D)
        self.cost_history.append(f_val)
        return self.x.copy()

    def get_action(self): return self.x.copy()
    def get_ccv(self): return sum(self.constraint_violations)
    def get_Q(self): return self.Q


# ============================================================
# Sub-policy B: Prediction-augmented (mu=1)
# ============================================================

class SubPolicyB:
    """Sub-policy B: OGD on surrogate cost WITH prediction (mu=1).

    Identical to sub-policy A except that the surrogate gradient gains a
    third term built from a caller-supplied prediction of the next
    constraint:
        V * (beta * grad f_t) + Phi'(Q) * (beta * grad g_t) * [g_t > 0]
                              + Phi'(Q) * (beta * grad g_pred) * [g_pred > 0]
    Q is updated with the *observed* violation only, before Phi' is
    evaluated.

    Parameters by ``setting``:
      'strongly_convex': beta = 1, V = 64 G^2 ln(e T) / alpha, Phi'(Q) = 2 Q,
                         step size 1 / (V * alpha * t).
      'convex':          beta = 1 / (2 G D), V = 1,
                         Phi'(Q) = lam * exp(lam * Q), lam = 1 / (2 sqrt(T)),
                         step size D / sqrt(2 * sum_s ||surrogate grad_s||^2).

    Raises:
        ValueError: if ``setting`` is neither 'strongly_convex' nor 'convex'.
    """

    def __init__(self, d: int, T: int, G: float = 1.0, D: float = 1.0,
                 alpha: float = 1.0, setting: str = 'strongly_convex'):
        # Only these two settings are self-consistent. Any other string used
        # to take the convex branch below, but step() only accumulates
        # grad_sq_sum when setting == 'convex'. The step size then stayed at
        # D / sqrt(2e-8) on every round.
        if setting not in ('strongly_convex', 'convex'):
            raise ValueError(
                "setting must be 'strongly_convex' or 'convex', "
                f"got {setting!r}")
        self.d = d
        self.T = T
        self.G = G
        self.D = D
        self.alpha = alpha
        self.setting = setting

        if setting == 'strongly_convex':
            self.beta_scale = 1.0
            self.V = 64 * G**2 * np.log(T * np.e) / alpha
        else:
            self.beta_scale = 1.0 / (2 * G * D)
            self.V = 1.0
            self._lambda = 1.0 / (2 * np.sqrt(T))

        self.x = np.zeros(d)              # current action, starts at the origin
        self.Q = 0.0                      # queue: running sum of beta * max(g_t, 0)
        self.t = 0                        # number of rounds processed so far
        self.grad_sq_sum = 1e-8           # convex step-size accumulator; 1e-8 avoids 1/0
        self.cost_history = []            # f_t(x_t) per round
        self.constraint_violations = []   # max(g_t(x_t), 0) per round
        self.bonus_history = []           # Phi'(Q) * beta * max(g_pred, 0) per round (logged only)

    def _phi_prime(self, q):
        """Derivative of the Lyapunov potential: 2q, or lam * exp(lam * q) (convex)."""
        if self.setting == 'strongly_convex':
            return 2.0 * q
        else:
            return self._lambda * np.exp(self._lambda * q)

    def _get_step_size(self):
        """Step size for the current round (see class docstring)."""
        if self.setting == 'strongly_convex':
            # max(t, 1) guards the t == 0 case before any step.
            return 1.0 / (self.V * self.alpha * max(self.t, 1))
        else:
            return self.D / np.sqrt(2 * self.grad_sq_sum)

    def set_lambda(self, lam):
        """Override lam used by Phi' (only read in the 'convex' setting)."""
        self._lambda = lam

    def step(self, f_val, f_grad, g_val, g_grad, g_pred_val, g_pred_grad):
        """Process one round and return a copy of the next action.

        Args:
            f_val: cost f_t(x_t); only recorded in ``cost_history``.
            f_grad: gradient of f_t at x_t.
            g_val: observed constraint value g_t(x_t).
            g_grad: gradient of g_t; used only when g_val > 0.
            g_pred_val: predicted constraint value supplied by the caller
                (presumably the next round's constraint at x_t — confirm
                against the caller).
            g_pred_grad: predicted constraint gradient; used only when
                g_pred_val > 0.
        """
        self.t += 1
        cv = max(g_val, 0.0)
        self.constraint_violations.append(cv)
        tilde_g = self.beta_scale * cv
        self.Q += tilde_g

        tilde_f_grad = self.beta_scale * f_grad
        # (Sub)gradients of max(., 0): zero when the (predicted) constraint holds.
        tilde_g_grad = self.beta_scale * g_grad if g_val > 0 else np.zeros(self.d)
        pred_cv = max(g_pred_val, 0.0)
        tilde_g_pred_grad = self.beta_scale * g_pred_grad if g_pred_val > 0 else np.zeros(self.d)

        phi_prime = self._phi_prime(self.Q)
        surrogate_grad = (self.V * tilde_f_grad
                          + phi_prime * tilde_g_grad
                          + phi_prime * tilde_g_pred_grad)

        # Logged for analysis only; it does not affect the update.
        self.bonus_history.append(phi_prime * self.beta_scale * pred_cv)

        if self.setting == 'convex':
            self.grad_sq_sum += np.dot(surrogate_grad, surrogate_grad)
        eta = self._get_step_size()
        self.x = project_l2_ball(self.x - eta * surrogate_grad, self.D)
        self.cost_history.append(f_val)
        return self.x.copy()

    def get_action(self): return self.x.copy()
    def get_ccv(self): return sum(self.constraint_violations)
    def get_Q(self): return self.Q
    def get_bonus(self): return sum(self.bonus_history)


# ============================================================
# Sub-policy C: Standard OGD (convex only)
# ============================================================

class SubPolicyC:
    """Sub-policy C: plain projected OGD on f_t (intended for the convex setting).

    Uses the fixed step size D / (G sqrt(T)). Constraint values are not used
    by the update; max(g_t, 0) is only recorded for CCV reporting.
    """

    def __init__(self, d, T, G=1.0, D=1.0):
        self.d = d
        self.T = T
        self.G = G
        self.D = D
        self.eta = D / (G * np.sqrt(T))
        self.x = np.zeros(d)
        self.cost_history = []
        self.constraint_violations = []

    def step(self, f_val, f_grad, g_val):
        """Record round costs, take one projected step on f_t, return the new action."""
        self.cost_history.append(f_val)
        self.constraint_violations.append(max(g_val, 0.0))
        candidate = self.x - self.eta * f_grad
        self.x = project_l2_ball(candidate, self.D)
        return self.x.copy()

    def get_action(self):
        return self.x.copy()

    def get_ccv(self):
        return sum(self.constraint_violations)


# ============================================================
# Baseline: Naive OGD (ignores constraints)
# ============================================================

class NaiveOGD:
    """OGD that only minimizes f_t, ignoring constraints entirely.

    max(g_t, 0) is recorded purely for CCV reporting. The step size at round t
    is 1 / (alpha t) when alpha > 0, otherwise D / (G sqrt(t)).
    """

    def __init__(self, d, T, G=1.0, D=1.0, alpha=1.0):
        self.d = d
        self.T = T
        self.G = G
        self.D = D
        self.alpha = alpha
        self.x = np.zeros(d)
        self.t = 0
        self.constraint_violations = []
        self.cost_history = []

    def step(self, f_val, f_grad, g_val):
        """Record round t, take a projected gradient step on f_t, return the new action."""
        self.t += 1
        self.cost_history.append(f_val)
        self.constraint_violations.append(max(g_val, 0.0))
        inverse_t_rate = self.alpha > 0
        eta = (1.0 / (self.alpha * self.t) if inverse_t_rate
               else self.D / (self.G * np.sqrt(self.t)))
        self.x = project_l2_ball(self.x - eta * f_grad, self.D)
        return self.x.copy()

    def get_action(self):
        return self.x.copy()

    def get_ccv(self):
        return sum(self.constraint_violations)


# ============================================================
# Baseline: Primal-Dual (standard constrained OCO)
# ============================================================

class PrimalDual:
    """
    Standard primal-dual method for constrained OCO.

    x_{t+1}      = Proj_D(x_t - eta_x * (grad f_t + lambda_t * grad g_t))
    lambda_{t+1} = max(0, lambda_t + eta_lambda * g_t(x_t))

    with eta_x = D / (G sqrt(T)) and eta_lambda = 1 / sqrt(T). The primal step
    uses the multiplier from before this round's dual update. ``alpha`` is
    stored but not used.
    """

    def __init__(self, d, T, G=1.0, D=1.0, alpha=1.0):
        self.d = d
        self.T = T
        self.G = G
        self.D = D
        self.alpha = alpha
        self.x = np.zeros(d)
        self.lam = 0.0
        self.t = 0
        self.eta_x = D / (G * np.sqrt(T))
        self.eta_lam = 1.0 / np.sqrt(T)
        self.constraint_violations = []
        self.cost_history = []

    def step(self, f_val, f_grad, g_val, g_grad):
        """Record round t, update primal then dual variables, return the new action."""
        self.t += 1
        self.cost_history.append(f_val)
        self.constraint_violations.append(max(g_val, 0.0))
        lagrangian_grad = f_grad + self.lam * g_grad
        self.x = project_l2_ball(self.x - self.eta_x * lagrangian_grad, self.D)
        self.lam = max(0.0, self.lam + self.eta_lam * g_val)
        return self.x.copy()

    def get_action(self):
        return self.x.copy()

    def get_ccv(self):
        return sum(self.constraint_violations)


# ============================================================
# Hedge Mixer
# ============================================================

class HedgeMixer:
    """Hedge (exponential weights) over N sub-policies.

    Uses per-round constraint violations as losses. Each loss is cv_i / (G D),
    clipped to [0, 1]. The learning rate is eta_H = sqrt(ln N / T). Weights
    start uniform, and every post-update weight vector is appended to the
    history.
    """

    def __init__(self, N, T, G=1.0, D=1.0):
        self.N = N
        self.T = T
        self.G = G
        self.D = D
        self.eta_H = np.sqrt(np.log(N) / T)
        self.weights = np.ones(N) / N
        self.weight_history = [self.weights.copy()]

    def get_weights(self):
        return self.weights.copy()

    def mix_actions(self, actions):
        """Return sum_i w_i * actions[i] (dtype follows actions[0])."""
        combined = np.zeros_like(actions[0])
        for idx, action in enumerate(actions):
            combined += self.weights[idx] * action
        return combined

    def update(self, cvs):
        """Multiplicative-weights step on the normalised, clipped violations ``cvs``."""
        normalizer = self.G * self.D
        scaled = np.array([violation / normalizer for violation in cvs])
        losses = np.clip(scaled, 0.0, 1.0)
        self.weights *= np.exp(-self.eta_H * losses)
        self.weights /= self.weights.sum()
        self.weight_history.append(self.weights.copy())

    def get_weight_history(self):
        return np.array(self.weight_history)


# ============================================================
# Adversary: Rotating constraint halfspaces
# ============================================================

class RotatingAdversary:
    """
    Adversary with rotating constraint halfspaces.

    Constraint: g_t(x) = a_t^T x + b_t with ||a_t|| = G (up to a 1e-10
    normalisation guard). Before normalisation, the first two coordinates of
    a_t are (cos theta_t, sin theta_t), with
    theta_t = theta_0 + rotation_speed * t and theta_0 ~ U(0, 2 pi).
    Coordinates 2.. (when d > 2) get small Gaussian components.

    adversary_type == 'ocs' (pure constraint satisfaction):
    - f_t = 0.
    - theta_t is not jittered; extra coordinates ~ 0.05 * N(0, 1).
    - Comparator x* = -0.03 D e_1. b_t is chosen so that g_t(x*) = -margin,
      with margin ~ U(0.01, 0.02); the origin itself may violate g_t.

    Any other adversary_type:
    - f_t(x) = (alpha/2) ||x - v_t||^2, with the gradient clipped to norm G.
    - v_t = s_t * D * a_t / ||a_t|| with s_t ~ U(0.6, 0.95), so the cost
      pulls x along the constraint normal, i.e. toward violation.
    - theta_t is jittered by N(0, 0.3^2); extra coordinates ~ 0.3 * N(0, 1).
    - b_t ~ -U(0, 0.005), so the comparator x* = 0 is always feasible.

    The rotation speed is meant to control difficulty: with slow rotation
    a_{t+1} stays close to a_t (easy to predict), with fast rotation it does not.

    All sequences are pre-generated in __init__ for T + 2 rounds from a
    RandomState seeded with ``seed``; ``get_prediction`` draws its noise from
    the same generator.

    Raises:
        ValueError: if d < 2 (the rotation uses coordinates 0 and 1).
    """

    def __init__(self, d, T, G=1.0, D=1.0, alpha=1.0,
                 adversary_type='stochastic', prediction_noise=0.0,
                 rotation_speed=0.1, seed=42):
        # The rotation plane is spanned by coordinates 0 and 1. With d == 1
        # the constructor previously failed with an IndexError on a[1].
        if d < 2:
            raise ValueError(f"RotatingAdversary requires d >= 2, got d={d}")
        self.d, self.T, self.G, self.D = d, T, G, D
        self.alpha = alpha
        self.adversary_type = adversary_type
        self.prediction_noise = prediction_noise
        self.rotation_speed = rotation_speed
        self.rng = np.random.RandomState(seed)
        self.t = 0
        self._pregenerate()

    def _pregenerate(self):
        """Draw a_t, b_t and v_t for T + 2 rounds.

        The RNG draw order is fixed: base angle, then per round (angle jitter,
        extra coordinates, b/margin), then per round the v scale. Subclasses
        that draw further values rely on this order.
        """
        T = self.T + 2
        d = self.d
        is_ocs = self.adversary_type == 'ocs'

        self.a_seq = []
        self.b_seq = []
        base_angle = self.rng.uniform(0, 2 * np.pi)

        if is_ocs:
            # Fixed comparator slightly off the origin along -e_1. The choice
            # of b_t below keeps it strictly feasible every round.
            self.x_star_ocs = np.zeros(d)
            self.x_star_ocs[0] = -self.D * 0.03

        for i in range(T):
            angle = base_angle + self.rotation_speed * i
            if not is_ocs:
                # Jitter the deterministic rotation so the next constraint
                # direction is less predictable.
                angle += self.rng.normal(0, 0.3)
            a = np.zeros(d)
            a[0] = np.cos(angle)
            a[1] = np.sin(angle)
            if d > 2:
                # Off-plane components: small for OCS, larger otherwise.
                a[2:] = self.rng.randn(d - 2) * (0.05 if is_ocs else 0.3)
            a = a / (np.linalg.norm(a) + 1e-10) * self.G

            if is_ocs:
                # g_t(x*) = a^T x* + b = -margin < 0.
                # At the origin, g_t(0) = b = 0.03 * D * a[0] - margin. This is
                # positive (the origin violates) whenever a[0] > margin / (0.03 * D).
                ax_star = np.dot(a, self.x_star_ocs)
                margin = self.rng.uniform(0.01, 0.02)
                b = -ax_star - margin
            else:
                # b_t in [-0.005, 0]: g_t(0) = b_t <= 0, but the boundary nearly
                # passes through the origin, so moving along a_t soon violates.
                b = -self.rng.uniform(0.0, 0.005)

            self.a_seq.append(a)
            self.b_seq.append(b)

        # Cost targets. Non-OCS: v_t lies along the unit constraint normal,
        # with length s_t * D. The cost therefore pulls x toward violation
        # while the constraint pushes it back. OCS: f_t = 0, so v_t is unused
        # (zeros are kept for a uniform interface).
        self.v_seq = []
        for i in range(T):
            if is_ocs:
                self.v_seq.append(np.zeros(d))
            else:
                a_dir = self.a_seq[i] / (np.linalg.norm(self.a_seq[i]) + 1e-10)
                scale = self.rng.uniform(0.6, 0.95)
                self.v_seq.append(a_dir * self.D * scale)

    def get_cost_and_constraint(self, x):
        """Return (f_t(x), grad f_t(x), g_t(x), grad g_t) for the current round.

        For OCS the cost and its gradient are zero. Otherwise the cost gradient
        is clipped to norm G, while the cost value is not.
        """
        t = self.t
        if self.adversary_type == 'ocs':
            f_val, f_grad = 0.0, np.zeros(self.d)
        else:
            diff = x - self.v_seq[t]
            f_val = 0.5 * self.alpha * np.sum(diff**2)
            f_grad = self.alpha * diff
            gn = np.linalg.norm(f_grad)
            if gn > self.G:
                f_grad = f_grad * (self.G / gn)

        g_val = np.dot(self.a_seq[t], x) + self.b_seq[t]
        g_grad = self.a_seq[t].copy()
        return f_val, f_grad, g_val, g_grad

    def get_prediction(self, x):
        """Predict (g_{t+1}(x), grad g_{t+1}) for the next round.

        With prediction_noise = sigma > 0, N(0, sigma^2 I) noise is added to
        a_{t+1} and N(0, (0.1 sigma)^2) noise to b_{t+1}. The noise is drawn
        from self.rng on every call, so repeated calls within one round
        return different predictions. Past the pre-generated horizon, the
        method returns (0.0, zeros).
        """
        t_next = self.t + 1
        if t_next >= len(self.a_seq):
            return 0.0, np.zeros(self.d)

        a_next = self.a_seq[t_next]
        b_next = self.b_seq[t_next]
        if self.prediction_noise > 0:
            noise_a = self.rng.randn(self.d) * self.prediction_noise
            noise_b = self.rng.normal(0, self.prediction_noise * 0.1)
            pred_val = np.dot(a_next + noise_a, x) + b_next + noise_b
            pred_grad = a_next + noise_a
        else:
            pred_val = np.dot(a_next, x) + b_next
            pred_grad = a_next.copy()

        return pred_val, pred_grad

    def get_prediction_error(self):
        """Heuristic per-round prediction-error scale: sigma * D * sqrt(d) + 0.1 * sigma.

        The first term roughly equals E||noise_a|| * D (for ||x|| <= D), and the
        second is the std of the b-noise. Returns 0 without noise.
        """
        if self.prediction_noise == 0:
            return 0.0
        return self.prediction_noise * self.D * np.sqrt(self.d) + self.prediction_noise * 0.1

    def advance(self):
        """Move to the next round."""
        self.t += 1

    def get_comparator_cost(self):
        """Cost of the comparator at the current round: f_t(0) (0 for OCS, where f = 0)."""
        t = self.t
        if self.adversary_type == 'ocs':
            return 0.0
        return 0.5 * self.alpha * np.sum(self.v_seq[t]**2)

    def get_comparator_point(self):
        """Return the comparator x* used for regret computation."""
        if self.adversary_type == 'ocs':
            return self.x_star_ocs.copy()
        return np.zeros(self.d)

    def reset(self):
        """Rewind to round 0 (the pre-generated sequences and RNG state are kept)."""
        self.t = 0


class AdaptiveAdversary(RotatingAdversary):
    """Adaptive adversary: creates strongly negative regret.

    Strategy: the cost target is the same point v = -0.6 D e_{d-1} (the last
    coordinate) in every round. An algorithm that converges toward v drives
    f_t(x_t) toward 0. Meanwhile the comparator x* = 0 pays
    f(0) = (alpha/2) ||v||^2 = 0.18 alpha D^2 in every round, so over T
    rounds Regret = sum f(x_t) - T * f(0) << 0.

    With very negative regret, V*Regret + Bonus < 0, so Event E fails.

    Constraints are the inherited rotating halfspaces (non-OCS variant).
    Their offsets b_t are re-drawn from -U(0, 0.005) and then clamped to
    min(b_t, -a_t^T v). This makes v feasible in every round
    (g_t(v) <= 0), while x* = 0 stays feasible (b_t <= 0).
    """

    def __init__(self, d, T, G=1.0, D=1.0, alpha=1.0,
                 prediction_noise=0.0, seed=42):
        super().__init__(d, T, G, D, alpha, 'adaptive', prediction_noise,
                         rotation_speed=0.1, seed=seed)
        # Fixed target along the last coordinate. For d >= 3 this axis lies
        # outside the rotation plane (coordinates 0 and 1), but a_t still has
        # a random component along it. For d == 2 it lies inside the plane.
        v_fixed = np.zeros(d)
        v_fixed[-1] = -D * 0.6
        for i in range(len(self.v_seq)):
            self.v_seq[i] = v_fixed.copy()

        # Re-draw the offsets, then clamp them so v_fixed is feasible.
        # Without the clamp, a_t^T v_fixed = -0.6 * D * a_t[-1]. It exceeds |b_t|
        # whenever a_t[-1] < -|b_t| / (0.6 * D), so v_fixed violated g_t on
        # roughly half the rounds. The min keeps b_t <= 0 (origin feasible)
        # and consumes the RNG exactly as before.
        for i, a in enumerate(self.a_seq):
            b = -self.rng.uniform(0.0, 0.005)
            self.b_seq[i] = min(b, -float(np.dot(a, v_fixed)))


# ============================================================
# Theoretical bounds
# ============================================================

def theoretical_ccv_B_strongly_convex(T, G, D, alpha, eta_T):
    """CCV bound encoded for sub-policy B (strongly convex setting).

    Returns 16 G^2 ln(e T) / alpha + 4 eta_T + 6 G D. ``eta_T`` enters
    linearly; it is presumably the accumulated prediction error — confirm
    against the caller.
    """
    log_eT = np.log(T * np.e)
    return 16 * G**2 * log_eT / alpha + 4 * eta_T + 6 * G * D

def theoretical_ccv_A_strongly_convex(T, G, D, alpha):
    """CCV bound encoded for sub-policy A (strongly convex setting).

    Returns sqrt(16/15 * (64 G^4 L^2 / alpha^2 + 64 G^3 D L T / alpha)),
    with L = ln(e T).
    """
    L = np.log(T * np.e)
    squared_log_term = 64 * G**4 * L**2 / alpha**2
    linear_T_term = 64 * G**3 * D * L / alpha * T
    return np.sqrt(16/15 * (squared_log_term + linear_T_term))

def theoretical_ccv_hedge_strongly_convex(T, G, D, alpha):
    """CCV bound encoded for the Hedge combination (strongly convex setting).

    Sub-policy A's bound plus a Hedge overhead of 2 G D sqrt(T ln 2). The
    ln 2 term corresponds to two experts.
    """
    hedge_overhead = 2 * G * D * np.sqrt(T * np.log(2))
    return theoretical_ccv_A_strongly_convex(T, G, D, alpha) + hedge_overhead

def theoretical_ccv_B_convex(T, G, D, eta_T):
    """CCV bound encoded for sub-policy B (convex setting).

    Returns (4 G D sqrt(T) + 2 eta_T) * ln(2 (T + 1)).
    """
    prefactor = 4 * G * D * np.sqrt(T) + 2 * eta_T
    return prefactor * np.log(2 * (T + 1))

def theoretical_ccv_A_convex(T, G, D):
    """CCV bound encoded for sub-policy A (convex setting): 4 G D sqrt(T) ln(2 (T + 1))."""
    sqrt_T_term = 4 * G * D * np.sqrt(T)
    return sqrt_T_term * np.log(2 * (T + 1))

def theoretical_crossover(T, G, D, alpha):
    """Crossover scale sqrt(G^3 D T ln(e T) / alpha).

    Presumably the prediction-error level at which B's bound stops improving
    on A's — confirm against the derivation.
    """
    radicand = G**3 * D * T * np.log(T * np.e) / alpha
    return np.sqrt(radicand)


def theoretical_ccv_pd(T, G, D):
    """Primal-Dual CCV = O(sqrt(T)) — standard bound, returned as G * D * sqrt(T)."""
    root_T = np.sqrt(T)
    return G * D * root_T


def theoretical_ccv_naive(T, G, D):
    """Naive OGD ignoring constraints — CCV can be O(T); returned as G * D * T."""
    per_round = G * D
    return per_round * T
