import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim



# Fix the global random seeds so every run of this module draws the same
# environment parameters, contexts, outcomes and network initialisations.
torch.manual_seed(42)  # PyTorch RNG (used by nn.Linear weight init)
np.random.seed(42)  # NumPy legacy global RNG (env, delay and corruption draws)

class ContextualDuelingBanditEnv:
    """
    Contextual Dueling Bandit Environment.

    Simulates a dueling bandit environment with:
    - Contexts (X) and Post-Serving Contexts (Y)
    - Latent preference parameters (theta_star, zeta_star)
    - Mechanisms for Stochastic/Adversarial Delay and Corruption

    The utility of an arm is ``x @ theta_star + y @ zeta_star``. In a duel,
    arm a beats arm b with probability ``1 / (1 + exp(-(u_a - u_b)))``.

    Args:
        d: Dimension of the pre-serving context X.
        e: Dimension of the post-serving context Y.
        k: Number of arms (rows of X) offered each round.
        delay_model: Object with ``get_delay(u_a, u_b)``. Defaults to
            ``StochasticGeometricDelay(0)``, which never delays.
        corruption_model: Object with ``corrupt(outcome, u_a, u_b)``.
            Defaults to ``NoCorruption()``.
        phi_type: Name of the X -> Y map used by ``get_true_post_context``:
            'sinusoidal', 'piecewise', 'linear', 'polynomial', 'interaction',
            'abs' or 'cosine'.
        delay_first: Which attack ``get_feedback`` consults first.
    """
    def __init__(self, d=1, e=2, k=10, delay_model=None, corruption_model=None, phi_type='sinusoidal', delay_first=False):
        self.d = d
        self.e = e
        self.k = k
        self.delay_model = delay_model if delay_model is not None else StochasticGeometricDelay(0)
        self.corruption_model = corruption_model if corruption_model is not None else NoCorruption()
        self.phi_type = phi_type
        self.delay_first = delay_first

        # True parameters
        self.theta_star = np.random.randn(d) * 0.1 # Small dependence on X
        self.zeta_star = np.random.randn(e) * 5.0  # Large dependence on Y
        print(f"True Params: Theta (X)={self.theta_star}, Zeta (Y)={self.zeta_star}")

        # For the linear map, draw a fixed random (d, e) matrix so that
        # Y = X @ M already has exactly e columns.
        if self.phi_type == 'linear':
            self.M = np.random.randn(d, e)

    def get_context(self):
        """Draw k contexts uniformly from [-pi, pi); returns shape (k, d)."""
        X = np.random.uniform(-np.pi, np.pi, size=(self.k, self.d))
        return X

    def get_true_post_context(self, X):
        """
        Generate the post-serving contexts Y from X according to ``phi_type``.

        For every map except 'linear', the raw features are truncated to
        ``e`` columns, or tiled and then truncated when there are fewer than
        ``e``. The result therefore has shape (len(X), e).

        Raises:
            ValueError: unknown ``phi_type``, or the map produced zero
                columns while ``e > 0`` (nothing to tile).
        """
        if self.phi_type == 'sinusoidal':
            raw_y = np.concatenate([np.cos(X), np.sin(X)], axis=1)

        elif self.phi_type == 'piecewise':
            # Leaky-ReLU-like: slope 1 for X > 0, slope 0.5 otherwise.
            raw_y = X * (X > 0) + 0.5 * X * (X <= 0)

        elif self.phi_type == 'linear':
            # X is (K, d), M is (d, e) -> (K, e)
            raw_y = X @ self.M

        elif self.phi_type == 'polynomial':
            # Y = [X^2, sqrt(|X|)]
            raw_y = np.concatenate([X**2, np.sqrt(np.abs(X))], axis=1)

        elif self.phi_type == 'interaction':
            # Pairwise products X_i * X_j for i <= j (upper-triangular terms).
            d_dim = X.shape[1]
            interactions = [X[:, i] * X[:, j] for i in range(d_dim) for j in range(i, d_dim)]
            raw_y = np.stack(interactions, axis=1)

        elif self.phi_type == 'abs':
            # Y = |X|: symmetric, hence uncorrelated with X.
            raw_y = np.abs(X)

        elif self.phi_type == 'cosine':
            # Y = cos(X): symmetric, hence uncorrelated with X.
            raw_y = np.cos(X)

        else:
            raise ValueError(f"Unknown phi_type: {self.phi_type}")

        # The linear map is built with exactly e output columns.
        if self.phi_type == 'linear':
            return raw_y

        # Other maps produce a width that depends on d (e.g. 2d for
        # sinusoidal/polynomial, d for piecewise), so fit it to e columns.
        width = raw_y.shape[1]
        if width >= self.e:
            return raw_y[:, :self.e]
        if width == 0:
            raise ValueError(
                f"phi_type={self.phi_type!r} produced zero features "
                f"(X has {X.shape[1]} columns); cannot fill e={self.e} columns"
            )
        repeats = int(np.ceil(self.e / width))
        filled = np.tile(raw_y, (1, repeats))
        return filled[:, :self.e]

    def get_utility(self, X, Y):
        """Return per-arm utilities ``X @ theta_star + Y @ zeta_star``."""
        u_x = X @ self.theta_star
        u_y = Y @ self.zeta_star
        return u_x + u_y

    @staticmethod
    def _win_probability(u_a, u_b):
        """Return P(a beats b) = 1 / (1 + exp(-(u_a - u_b)))."""
        # Same formula as before, so results stay bit-identical. When
        # u_a - u_b is very negative, exp() overflows to inf and the
        # probability is correctly 0; the overflow warning is just noise.
        with np.errstate(over='ignore'):
            return 1 / (1 + np.exp(-(u_a - u_b)))

    def _sample_true_outcome(self, u_a, u_b):
        """Draw the uncorrupted duel outcome (1 = a wins, 0 = b wins)."""
        prob = self._win_probability(u_a, u_b)
        return 1 if np.random.rand() < prob else 0

    def get_feedback(self, u_a, u_b):
        """
        Return an ``(outcome, delay)`` pair for a duel between arms a and b.

        ``outcome`` is 1 if a wins and 0 if b wins. ``delay`` is the number
        of rounds before the learner sees the outcome.

        If ``delay_first`` is True:
            1. Query the delay model. If delay > 0, return
               ``(true outcome, delay)``; corruption is not applied.
            2. Otherwise query the corruption model and return
               ``(possibly corrupted outcome, 0)``.

        If ``delay_first`` is False (corruption first):
            1. Query the corruption model.
            2. If the outcome was flipped and the delay model is NOT a
               stochastic one (StochasticGeometricDelay /
               StochasticGaussianDelay), return ``(corrupted outcome, 0)``
               without consulting the delay model. Its budget or state is
               therefore untouched.
            3. Otherwise query the delay model and return
               ``(outcome, delay)``. With a stochastic delay model this
               outcome may itself be corrupted.
        """
        true_outcome = self._sample_true_outcome(u_a, u_b)

        if self.delay_first:
            delay = self.delay_model.get_delay(u_a, u_b)
            if delay > 0:
                # The adversary spent its resource on delaying this duel, so
                # corruption does not get a chance.
                return true_outcome, delay

            final_outcome = self.corruption_model.corrupt(true_outcome, u_a, u_b)
            return final_outcome, 0

        final_outcome = self.corruption_model.corrupt(true_outcome, u_a, u_b)
        is_corrupted = (final_outcome != true_outcome)

        # Strategic/adversarial delays are only spent on uncorrupted
        # outcomes. Stochastic delays are exogenous and always drawn.
        # NOTE(review): ImpulseDelay advances an internal round counter only
        # when get_delay is called, so skipping it here shifts its impulse
        # window later. Confirm whether that is intended.
        is_stochastic = isinstance(self.delay_model, (StochasticGeometricDelay, StochasticGaussianDelay))

        if is_corrupted and not is_stochastic:
            return final_outcome, 0

        delay = self.delay_model.get_delay(u_a, u_b)
        return final_outcome, delay

    def compare(self, u_a, u_b):
        """
        Legacy helper: sample one duel outcome and pass it through the
        corruption model, without any delay handling.

        Prefer ``get_feedback``. This method does not coordinate corruption
        with the delay model.
        """
        true_outcome = self._sample_true_outcome(u_a, u_b)
        return self.corruption_model.corrupt(true_outcome, u_a, u_b)

    def get_delay(self, u_a, u_b):
        """Forward to the configured delay model's ``get_delay``."""
        return self.delay_model.get_delay(u_a, u_b)

# --- Delay Models ---
class DelayModel:
    """Interface for delay models: map a duel (u_a, u_b) to a feedback delay in rounds."""

    def get_delay(self, u_a=None, u_b=None):
        """Return how many rounds the feedback for this duel is held back.

        The base class defines no policy; concrete models must override this.
        """
        raise NotImplementedError()

class StochasticGeometricDelay(DelayModel):
    """Exogenous delay drawn from a geometric law on {0, 1, 2, ...}.

    The success probability is 1 / (mean_delay + 1), which gives an expected
    delay of ``mean_delay``. A non-positive ``mean_delay`` disables delays
    (always 0).
    """

    def __init__(self, mean_delay):
        self.mean_delay = mean_delay

    def get_delay(self, u_a=None, u_b=None):
        """Sample a delay. The utilities are ignored."""
        if self.mean_delay <= 0:
            return 0
        success_prob = 1.0 / (1.0 + self.mean_delay)
        # numpy counts trials up to and including the first success (support
        # 1, 2, ...). Subtracting one shifts the support to start at 0.
        trials = np.random.geometric(success_prob)
        return trials - 1

class StochasticGaussianDelay(DelayModel):
    """Exogenous delay: a Gaussian(mean_delay, std_delay) sample, rounded and clipped at 0."""

    def __init__(self, mean_delay, std_delay):
        self.mean_delay, self.std_delay = mean_delay, std_delay

    def get_delay(self, u_a=None, u_b=None):
        """Sample a non-negative integer delay. The utilities are ignored."""
        sample = np.random.normal(self.mean_delay, self.std_delay)
        rounded = np.round(sample)
        # Clip on the rounded float first, then convert to a Python int.
        clipped = max(0, rounded)
        return int(clipped)

class AdversarialDelay(DelayModel):
    """Budgeted delay attack.

    Every call delays the feedback by ``fixed_delay`` rounds until the total
    delay handed out reaches ``budget``. The final delay is truncated so the
    running total never exceeds the budget.
    """

    def __init__(self, budget, fixed_delay):
        self.budget = budget
        self.fixed_delay = fixed_delay
        self.total_delayed = 0  # cumulative delay spent so far

    def get_delay(self, u_a=None, u_b=None):
        """Return the next delay, charging it against the budget."""
        if self.total_delayed >= self.budget:
            return 0
        remaining = self.budget - self.total_delayed
        granted = min(self.fixed_delay, remaining)
        if not granted > 0:
            return 0
        self.total_delayed += granted
        return granted

class ImpulseDelay(DelayModel):
    """Delay of ``magnitude`` rounds during calls [start_round, start_round + duration).

    Calls are counted by an internal counter that advances once per
    ``get_delay`` call; outside the window the delay is 0.
    """

    def __init__(self, start_round, duration, magnitude):
        self.start_round = start_round
        self.end_round = start_round + duration
        self.magnitude = magnitude
        self.current_round = 0

    def get_delay(self, u_a=None, u_b=None):
        """Return the delay for the current call index, then advance the counter."""
        round_idx = self.current_round
        self.current_round = round_idx + 1
        in_window = self.start_round <= round_idx < self.end_round
        return self.magnitude if in_window else 0

class StrategicDelay(DelayModel):
    """
    Strategic Delay (Attack on Best Arm).

    When the better of the two dueled arms has utility above ``threshold_val``,
    the feedback is delayed by ``magnitude`` rounds. Otherwise it is delivered
    immediately. This starves the learner of signal about good arms. The total
    delay is capped at ``budget``, and the last delay is truncated to fit.
    """

    def __init__(self, budget, magnitude=100, threshold_val=0.0):
        self.budget = budget
        self.magnitude = magnitude
        self.threshold = threshold_val
        self.total_delayed = 0  # cumulative delay spent so far

    def get_delay(self, u_a, u_b):
        """Return the delay for the duel (u_a, u_b), charging it against the budget."""
        budget_spent = self.total_delayed >= self.budget
        is_good_pair = max(u_a, u_b) > self.threshold
        if budget_spent or not is_good_pair:
            return 0

        granted = min(self.magnitude, self.budget - self.total_delayed)
        if not granted > 0:
            return 0
        self.total_delayed += granted
        return granted

# --- Corruption Models ---
class CorruptionModel:
    """Interface for corruption models: possibly replace a binary duel outcome."""

    def corrupt(self, outcome, u_a=None, u_b=None):
        """Return the outcome the learner will observe in place of ``outcome``.

        The base class defines no policy; concrete models must override this.
        """
        raise NotImplementedError()

class NoCorruption(CorruptionModel):
    """Identity corruption model: every outcome is passed through untouched."""

    def __init__(self):
        # Mirrors the ``count`` attribute of the other corruption models;
        # it is never incremented here.
        self.count = 0

    def corrupt(self, outcome, u_a=None, u_b=None):
        """Return ``outcome`` unchanged; the utilities are ignored."""
        unchanged = outcome
        return unchanged

class AdversarialOutcomeCorruption(CorruptionModel):
    """Flip every outcome until ``budget`` flips have been spent (the utilities are ignored)."""

    def __init__(self, budget):
        self.budget = budget
        self.count = 0  # flips performed so far

    def corrupt(self, outcome, u_a=None, u_b=None):
        """Return the flipped outcome while budget remains, else the outcome as-is."""
        under_budget = self.count < self.budget
        if under_budget:
            self.count += 1
        return 1 - outcome if under_budget else outcome

class StrategicOutcomeCorruption(CorruptionModel):
    """
    Strategic Corruption (Best Arm Attack).

    Flips the outcome only when it names the higher-utility arm as the winner
    (1 = a wins, 0 = b wins), hiding the better arm's superiority. Ties in
    utility are never flipped. Each flip spends one unit of ``budget``.
    """

    def __init__(self, budget):
        self.budget = budget
        self.count = 0  # flips performed so far

    def corrupt(self, outcome, u_a, u_b):
        """Return the possibly flipped outcome."""
        if self.count >= self.budget:
            return outcome

        a_better_and_won = u_a > u_b and outcome == 1
        b_better_and_won = u_b > u_a and outcome == 0
        if not (a_better_and_won or b_better_and_won):
            return outcome

        self.count += 1
        return 1 - outcome

class RCDBCorruption(CorruptionModel):
    """
    Corruption strategies adapted from the RCDB paper (Appendix E).

    Every strategy flips the binary outcome (1 = arm a wins, 0 = arm b wins).
    Each flip spends one unit of ``budget``, and once ``count`` reaches
    ``budget`` outcomes pass through unchanged.

    attack_type:
        'greedy'      -- flip every outcome until the budget is spent.
        'random'      -- flip each outcome with probability ``random_prob``.
        'adversarial' -- flip when the observed outcome agrees with the utility
                         ordering, i.e. hide the truth. Ties are never flipped.
        'misleading'  -- the paper fixes one suboptimal arm and makes it win
                         every duel. No global target is tracked here.
                         Instead the lower-utility arm of each pair is made to
                         win, which flips exactly the same outcomes as
                         'adversarial'.

    Args:
        budget: Maximum number of flips.
        attack_type: One of ``ATTACK_TYPES``.
        random_prob: Flip probability for the 'random' attack.
        d, env_k: Accepted for interface compatibility; unused.

    Raises:
        ValueError: if ``attack_type`` is not one of ``ATTACK_TYPES``.
            Previously a typo silently disabled all corruption.
    """
    ATTACK_TYPES = ('greedy', 'random', 'adversarial', 'misleading')

    def __init__(self, budget, attack_type='adversarial', random_prob=0.5, d=None, env_k=None):
        if attack_type not in self.ATTACK_TYPES:
            raise ValueError(
                f"Unknown attack_type {attack_type!r}; expected one of {self.ATTACK_TYPES}"
            )
        self.budget = budget
        self.attack_type = attack_type
        self.random_prob = random_prob
        self.count = 0  # flips performed so far
        # Reserved for a global 'misleading' target arm; never assigned here.
        self.target_arm_index = None

    @staticmethod
    def _agrees_with_utilities(outcome, u_a, u_b):
        """True when ``outcome`` names the strictly higher-utility arm as the winner."""
        return (u_a > u_b and outcome == 1) or (u_b > u_a and outcome == 0)

    def corrupt(self, outcome, u_a, u_b):
        """Return the possibly flipped outcome, spending budget on each flip."""
        if self.count >= self.budget:
            return outcome

        should_flip = False
        if self.attack_type == 'greedy':
            should_flip = True
        elif self.attack_type == 'random':
            # Random draw happens only for this attack, keeping RNG usage
            # identical for the other strategies.
            should_flip = np.random.rand() < self.random_prob
        elif self.attack_type in ('adversarial', 'misleading'):
            # Both suppress the truthful outcome: making the worse arm win is
            # the same as flipping outcomes that favour the better arm.
            should_flip = self._agrees_with_utilities(outcome, u_a, u_b)

        if should_flip:
            self.count += 1
            return 1 - outcome

        return outcome

class NeuralApproximator(nn.Module):
    """MLP mapping pre-serving contexts to predicted post-serving contexts.

    Architecture: Linear(input_dim, 64) -> ReLU -> Linear(64, 64) -> ReLU ->
    Linear(64, output_dim), stored as ``self.net``.
    """

    def __init__(self, input_dim, output_dim):
        super().__init__()
        sizes = (input_dim, 64, 64, output_dim)
        layers = []
        for fan_in, fan_out in zip(sizes[:-1], sizes[1:]):
            layers.append(nn.Linear(fan_in, fan_out))
            layers.append(nn.ReLU())
        layers.pop()  # no activation after the output layer
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Apply the MLP to ``x``."""
        return self.net(x)

class DuelingGLMLearner:
    """
    Dueling GLM Learner (Ours: RCDP-UCB).

    Implements a robust UCB algorithm that handles:
    - Post-serving contexts (via Neural Approximator)
    - Adversarial corruptions and delays (via Adaptive Weighting)

    Arms are represented by z = [x, f_t(x)] of dimension d + e, where f_t is a
    NeuralApproximator regressed on the post-serving contexts observed so far.

    Two regularised Gram matrices are kept, each alongside its inverse. Both
    are updated together by Sherman-Morrison rank-one steps:
    - V / V_inv: every selected pair, added at selection time. Used for the
      adaptive weight omega_t and the challenger's exploration width.
    - W / W_inv: only pairs whose feedback has been delivered. Used for the
      parameter update.

    Args:
        d: Dimension of the pre-serving context X.
        e: Dimension of the post-serving context Y (0 disables f_t).
        lambda_reg: Ridge regulariser; both Gram matrices start at lambda_reg * I.
        alpha: Exploration coefficient for the challenger arm (c_t in paper).
        C, Lambda, mu_tau: Robustness parameters entering ``weight_alpha``.
        kappa: Conservative lower bound on the link derivative; scales every
            rank-one update.
    """
    def __init__(self, d, e, lambda_reg=1.0, alpha=0.1, C=1.0, Lambda=1.0, mu_tau=0.0, kappa=0.1):
        self.d = d
        self.e = e
        self.dim = d + e
        self.lambda_reg = lambda_reg
        self.ucb_alpha = alpha # For arm selection confidence interval (c_t in paper)

        # Robustness parameters
        self.C = C
        self.Lambda = Lambda
        self.mu_tau = mu_tau
        self.kappa = kappa # Conservative lower bound for derivative

        # Weighting scale alpha from the theorem, divided by sqrt(kappa).
        # The 1e-6 keeps the division finite when the denominator is 0.
        denom = C + min(np.sqrt(Lambda), mu_tau) if mu_tau > 0 else C + np.sqrt(Lambda)
        self.weight_alpha = np.sqrt(self.dim) / (np.sqrt(self.kappa) * denom + 1e-6)

        # V: Full history (for weighting)
        self.V = self.lambda_reg * np.eye(self.dim)
        self.V_inv = (1.0 / self.lambda_reg) * np.eye(self.dim)

        # W: Observed history (for estimation/confidence)
        self.W = self.lambda_reg * np.eye(self.dim)
        self.W_inv = (1.0 / self.lambda_reg) * np.eye(self.dim)

        self.theta_hat = np.zeros(self.dim)

        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        self.f_t = NeuralApproximator(d, e).to(self.device)
        self.optimizer_f = optim.Adam(self.f_t.parameters(), lr=1e-3)
        self.loss_fn_f = nn.MSELoss()

        # Observed (x, y) pairs used to train f_t.
        self.data_buffer_x = []
        self.data_buffer_y = []

    @staticmethod
    def _rank_one_update(gram, gram_inv, v_vec):
        """Add v v^T to ``gram`` in place and update ``gram_inv`` to match (Sherman-Morrison)."""
        gram += np.outer(v_vec, v_vec)
        Av = gram_inv @ v_vec
        denom = 1 + np.dot(v_vec, Av)
        gram_inv -= np.outer(Av, Av) / denom

    def train_approximator(self, epochs=5):
        """Run ``epochs`` full-batch Adam steps of MSE regression of Y on X over the buffer."""
        if self.e == 0:
            return # No post-serving context to learn

        if not self.data_buffer_x:
            return

        X_tensor = torch.FloatTensor(np.array(self.data_buffer_x)).to(self.device)
        Y_tensor = torch.FloatTensor(np.array(self.data_buffer_y)).to(self.device)

        self.f_t.train()
        for _ in range(epochs):
            self.optimizer_f.zero_grad()
            outputs = self.f_t(X_tensor)
            loss = self.loss_fn_f(outputs, Y_tensor)
            loss.backward()
            self.optimizer_f.step()

    def predict_y(self, X):
        """Return f_t(X) as a NumPy array (an empty (len(X), 0) array when e == 0)."""
        if self.e == 0:
            return np.zeros((len(X), 0))

        self.f_t.eval()
        with torch.no_grad():
            X_tensor = torch.FloatTensor(X).to(self.device)
            Y_pred = self.f_t(X_tensor).cpu().numpy()
        return Y_pred

    def get_features(self, X):
        """Return the learner's arm features [X, f_t(X)], concatenated column-wise."""
        Y_pred = self.predict_y(X)
        return np.concatenate([X, Y_pred], axis=1)

    def select_arms(self, X, env=None):
        """
        PonLinRUCB Strategy with Adaptive Weighting.

        Picks a greedy champion, then an optimistic challenger using V. It
        computes the weight omega_t = min(1, weight_alpha / ||z_a - z_b||_{V^-1})
        and adds the weighted pair to V immediately. ``env`` is unused.

        Returns: a_t, b_t, omega_t
        """
        Z = self.get_features(X)

        # 1. Champion a_t (Greedy)
        utilities = Z @ self.theta_hat
        a_t = np.argmax(utilities)
        z_a = Z[a_t]

        # 2. Challenger b_t using Full History V
        Delta_Z_a = Z - z_a
        mean_diff = Delta_Z_a @ self.theta_hat
        width_V = np.sqrt(np.sum((Delta_Z_a @ self.V_inv) * Delta_Z_a, axis=1))

        scores_b = mean_diff + self.ucb_alpha * width_V
        b_t = np.argmax(scores_b)

        # 3. Adaptive Weight using Full History V
        delta_z = Z[a_t] - Z[b_t]
        norm_V = np.sqrt(delta_z @ self.V_inv @ delta_z)

        if norm_V < 1e-9:
            omega_t = 1.0  # a_t == b_t (or identical features): nothing to down-weight
        else:
            omega_t = min(1.0, self.weight_alpha / norm_V)

        # 4. Update V (Full History). V and V_inv are kept consistent.
        v_vec = np.sqrt(self.kappa * omega_t) * delta_z
        self._rank_one_update(self.V, self.V_inv, v_vec)

        return a_t, b_t, omega_t

    def update(self, X, a_idx, b_idx, outcome, y_a_obs, y_b_obs, omega_t):
        """
        Incorporate one delivered feedback.

        The observed post-serving contexts are added to the training buffer
        and f_t is refit for 2 epochs. The pair's features are recomputed with
        the refitted f_t, the weighted pair is added to W, and one weighted
        online-Newton step is taken on theta_hat.
        """
        self.data_buffer_x.extend((X[a_idx], X[b_idx]))
        self.data_buffer_y.extend((y_a_obs, y_b_obs))

        self.train_approximator(epochs=2)

        z = self.get_features(X)
        delta_z = z[a_idx] - z[b_idx]

        # Weighted update for Observed History W (W and W_inv kept consistent)
        v_vec = np.sqrt(self.kappa * omega_t) * delta_z
        self._rank_one_update(self.W, self.W_inv, v_vec)

        # Update Theta (Weighted ONS) using the already-updated W_inv
        mu_val = 1 / (1 + np.exp(- np.dot(self.theta_hat, delta_z)))
        step = self.W_inv @ (omega_t * (outcome - mu_val) * delta_z)
        self.theta_hat += step

class BaselineDuelingGLMLearner:
    """
    Baseline 1: X-only (Standard Dueling Contextual Bandit).

    Works purely on the pre-serving context X with an unweighted logistic
    online-Newton update. It keeps a single Gram matrix W (and its inverse)
    over delivered feedback only.
    """
    def __init__(self, d, lambda_reg=1.0, alpha=0.1):
        self.d = self.dim = d
        self.lambda_reg = lambda_reg
        self.alpha = alpha  # exploration coefficient for the challenger

        self.theta_hat = np.zeros(d)
        identity = np.eye(d)
        self.W = lambda_reg * identity
        self.W_inv = (1.0 / lambda_reg) * identity

    def select_arms(self, X, env=None):
        """Greedy leader plus UCB challenger measured relative to the leader.

        Returns ``(leader, challenger, 1.0)``; the constant weight is a
        placeholder for the shared learner interface. ``env`` is unused.
        """
        leader = np.argmax(X @ self.theta_hat)
        offsets = X - X[leader]
        gains = offsets @ self.theta_hat
        widths = np.sqrt(((offsets @ self.W_inv) * offsets).sum(axis=1))
        challenger = np.argmax(gains + self.alpha * widths)
        return leader, challenger, 1.0

    def update(self, X, a_idx, b_idx, outcome, y_a_obs, y_b_obs, omega_t=1.0):
        """One unweighted logistic ONS step on the pair difference x_a - x_b.

        ``y_a_obs``, ``y_b_obs`` and ``omega_t`` are accepted for interface
        compatibility and ignored.
        """
        diff = X[a_idx] - X[b_idx]
        predicted = 1 / (1 + np.exp(-np.dot(self.theta_hat, diff)))

        self.W += np.outer(diff, diff)
        projected = self.W_inv @ diff
        self.W_inv -= np.outer(projected, projected) / (1 + np.dot(diff, projected))

        self.theta_hat += self.W_inv @ (diff * (outcome - predicted))

class NonRobustDuelingGLMLearner(DuelingGLMLearner):
    """
    Baseline 2: Non-Robust Full-Info Learner.

    Uses the same [X, predicted Y] features as DuelingGLMLearner but never
    down-weights: omega_t is always 1. The challenger's width comes from W
    (delivered feedback), not V.
    """
    def __init__(self, d, e, lambda_reg=1.0, alpha=0.1):
        # C / Lambda / mu_tau only shape the adaptive-weight scale, which this
        # baseline never consults.
        super().__init__(d, e, lambda_reg, alpha, C=0, Lambda=0, mu_tau=0)

    def select_arms(self, X, env=None):
        """Greedy champion plus W-based UCB challenger, with a fixed weight of 1."""
        features = self.get_features(X)

        champion = np.argmax(features @ self.theta_hat)
        gaps = features - features[champion]
        spread = np.sqrt(np.sum((gaps @ self.W_inv) * gaps, axis=1))
        challenger = np.argmax(gaps @ self.theta_hat + self.ucb_alpha * spread)

        omega_t = 1.0

        # Record the pair in V_inv. Unlike DuelingGLMLearner.select_arms,
        # there is no kappa scaling here.
        pair = features[champion] - features[challenger]
        scaled = np.sqrt(omega_t) * pair
        proj = self.V_inv @ scaled
        self.V_inv -= np.outer(proj, proj) / (1 + np.dot(scaled, proj))

        return champion, challenger, omega_t

class RobustBaselineDuelingGLMLearner(DuelingGLMLearner):
    """
    Baseline 3: Robust X-only.

    Uses only the pre-serving context X but applies the same robust weighting
    strategy as DuelingGLMLearner. This isolates the benefit of Y from the
    benefit of robustness.

    ``kappa`` is forwarded to the parent (same default, 0.1), so this baseline
    can be run at the same kappa as the main learner.
    """
    def __init__(self, d, lambda_reg=1.0, alpha=0.1, C=1.0, Lambda=1.0, mu_tau=0.0, kappa=0.1):
        # e = 0: the feature dimension is just d.
        super().__init__(d, 0, lambda_reg, alpha, C, Lambda, mu_tau, kappa=kappa)

    def get_features(self, X):
        """Return X unchanged; no post-serving features are used."""
        return X

    def train_approximator(self, epochs=5):
        """No-op: there is no post-serving context to learn."""
        pass

class OracleDuelingGLMLearner(DuelingGLMLearner):
    """
    Baseline 4: Oracle (Ideal Skyline).

    Uses the TRUE post-serving context Y from the environment (cheating) and
    applies robust weighting. It serves as an upper bound on performance
    (lower bound on regret).

    ``kappa`` is forwarded to the parent (same default, 0.1). It is appended
    after ``env`` so existing positional calls keep working.
    """
    def __init__(self, d, e, lambda_reg=1.0, alpha=0.1, C=1.0, Lambda=1.0, mu_tau=0.0, env=None, kappa=0.1):
        super().__init__(d, e, lambda_reg, alpha, C, Lambda, mu_tau, kappa=kappa)
        self.env = env # Store env to access True Y

    def train_approximator(self, epochs=5):
        """No-op: the oracle reads Y from the environment instead of learning it."""
        pass

    def get_features(self, X):
        """Return [X, Y_true] using ``env.get_true_post_context``."""
        if self.env is None:
            # NOTE(review): this falls back to the parent's f_t, which this
            # class never trains (train_approximator is a no-op), so the
            # predicted Y comes from an untrained network.
            return super().get_features(X)
        Y_true = self.env.get_true_post_context(X)
        return np.concatenate([X, Y_true], axis=1)

    def select_arms(self, X, env=None):
        """
        Greedy champion plus W-based UCB challenger on oracle features.

        The adaptive weight is computed from V, and the weighted pair is
        recorded in V_inv. ``env`` is unused; ``self.env`` supplies Y.

        Returns: a_t, b_t, omega_t
        """
        Z = self.get_features(X)

        # 1. Champion a_t (Greedy)
        utilities = Z @ self.theta_hat
        a_t = np.argmax(utilities)
        z_a = Z[a_t]

        # 2. Challenger b_t using Observed History W
        Delta_Z_a = Z - z_a
        mean_diff = Delta_Z_a @ self.theta_hat
        width_W = np.sqrt(np.sum((Delta_Z_a @ self.W_inv) * Delta_Z_a, axis=1))

        scores_b = mean_diff + self.ucb_alpha * width_W
        b_t = np.argmax(scores_b)

        # 3. Adaptive Weight using Full History V
        delta_z = Z[a_t] - Z[b_t]
        norm_V = np.sqrt(delta_z @ self.V_inv @ delta_z)

        if norm_V < 1e-9:
            omega_t = 1.0
        else:
            omega_t = min(1.0, self.weight_alpha / norm_V)

        # 4. Update V_inv (Full History) via Sherman-Morrison
        v_vec = np.sqrt(self.kappa * omega_t) * delta_z
        Av = self.V_inv @ v_vec
        denom = 1 + np.dot(v_vec, Av)
        self.V_inv -= np.outer(Av, Av) / denom

        return a_t, b_t, omega_t

def run_simulation(env, learner, T, name="Learner"):
    """
    Run ``T`` rounds of the delayed / corrupted dueling-bandit protocol.

    Each round:
    1. Draw contexts and let the learner pick a pair.
    2. Reveal the true post-serving contexts and utilities.
    3. Obtain ``(outcome, delay)`` from ``env.get_feedback`` and queue the
       feedback for round ``t + delay``.
    4. Deliver every queued item whose arrival round has been reached, in
       submission order.
    5. Record regret ``(u* - u_a) + (u* - u_b)``.

    Feedback still pending after the last round is never delivered. Every 500
    rounds a progress line is printed with the cumulative regret and the mean
    omega_t over all rounds so far.

    Args:
        env: Environment exposing get_context, get_true_post_context,
            get_utility and get_feedback.
        learner: Learner exposing select_arms(X, env=...) and
            update(X, a_idx, b_idx, outcome, y_a, y_b, omega_t).
        T: Number of rounds.
        name: Label used in progress output.

    Returns:
        List of cumulative regret after each round (length ``T``).
    """
    # No seeding here so callers control the RNG externally.

    cumulative_regret = 0
    regrets = []
    feedback_queue = []
    omega_sum = 0.0  # running sum of omega_t over all selected pairs

    print(f"Starting simulation for {name}...")

    for t in range(T):
        X = env.get_context()

        a_idx, b_idx, omega_t = learner.select_arms(X, env=env)
        omega_sum += omega_t

        Y = env.get_true_post_context(X)
        y_a = Y[a_idx]
        y_b = Y[b_idx]

        u_all = env.get_utility(X, Y)
        if hasattr(u_all, 'flatten'):
            u_all = u_all.flatten()

        u_a = u_all[a_idx]
        u_b = u_all[b_idx]

        # Unified feedback handling (Corruption vs Delay)
        outcome, delay = env.get_feedback(u_a, u_b)

        feedback_queue.append({
            'arrival_time': t + delay,
            'X': X,
            'a_idx': a_idx,
            'b_idx': b_idx,
            'outcome': outcome,
            'y_a': y_a,
            'y_b': y_b,
            'omega_t': omega_t
        })

        # Deliver arrived feedback in submission order; keep the rest queued.
        still_pending = []
        for item in feedback_queue:
            if item['arrival_time'] <= t:
                learner.update(
                    item['X'],
                    item['a_idx'],
                    item['b_idx'],
                    item['outcome'],
                    item['y_a'],
                    item['y_b'],
                    item['omega_t']
                )
            else:
                still_pending.append(item)
        feedback_queue = still_pending

        u_star = u_all[np.argmax(u_all)]
        inst_regret = (u_star - u_a) + (u_star - u_b)

        cumulative_regret += inst_regret
        regrets.append(cumulative_regret)

        if (t+1) % 500 == 0:
            # True average weight over every round so far, not just the
            # items still waiting in the queue.
            avg_omega = omega_sum / (t + 1)
            print(f"[{name}] Round {t+1}/{T}, Regret: {cumulative_regret:.2f}, Avg Omega: {avg_omega:.4f}")

    return regrets


# --- Advanced Robust Baselines ---

class RCDBLearner(BaselineDuelingGLMLearner):
    """
    RCDB (Robust Contextual Dueling Bandits), simplified for this GLM setup.

    Reference: "Nearly Optimal Algorithms for Contextual Dueling Bandits from
    Adversarial Feedback".

    - Selection: greedy leader i_t, then the challenger maximising
      <x_j - x_i, theta> + rcdb_beta * ||x_j - x_i||_{W^-1}.
    - Update: uncertainty weight w_t = min(1, rcdb_alpha / ||x_a - x_b||_{W^-1})
      scales both the Gram update (together with kappa) and the gradient step.
    """
    def __init__(self, d, lambda_reg=1.0, alpha=0.1, rcdb_alpha=1.0, rcdb_beta=1.0, kappa=0.1):
        super().__init__(d, lambda_reg, alpha)
        self.rcdb_alpha, self.rcdb_beta, self.kappa = rcdb_alpha, rcdb_beta, kappa

    def select_arms(self, X, env=None):
        """Return ``(leader, challenger, 1.0)``; the weight is applied in ``update`` instead."""
        leader = np.argmax(X @ self.theta_hat)
        offsets = X - X[leader]
        bonus = np.sqrt(np.sum((offsets @ self.W_inv) * offsets, axis=1))
        scores = offsets @ self.theta_hat + self.rcdb_beta * bonus
        return leader, np.argmax(scores), 1.0

    def update(self, X, a_idx, b_idx, outcome, y_a_obs, y_b_obs, omega_t=1.0):
        """Uncertainty-weighted logistic ONS step. ``y_*_obs`` and ``omega_t`` are ignored."""
        diff = X[a_idx] - X[b_idx]
        uncertainty = np.sqrt(diff @ self.W_inv @ diff)
        weight = 1.0 if uncertainty < 1e-9 else min(1.0, self.rcdb_alpha / uncertainty)

        predicted = 1 / (1 + np.exp(-np.dot(self.theta_hat, diff)))

        # Gram update scaled by kappa and the weight (Sigma_t in the theory).
        scaled = np.sqrt(weight * self.kappa) * diff
        self.W += np.outer(scaled, scaled)
        proj = self.W_inv @ scaled
        self.W_inv -= np.outer(proj, proj) / (1 + np.dot(scaled, proj))

        self.theta_hat += self.W_inv @ (diff * weight * (outcome - predicted))


class MaxInPLearner(BaselineDuelingGLMLearner):
    """
    MaxInP (Maximum Informative Pair), regret-minimising variant.

    Pure MaxInP maximises ||x_i - x_j||_{W^-1} over all pairs and keeps
    exploring clearly bad arms. Here the challenger is restricted to arms whose
    marginal UCB (mean + alpha * ||x||_{W^-1}) reaches the greedy leader's
    mean. Among those, the one most uncertain relative to the leader is
    chosen. If no such arm exists, the arm with the highest marginal UCB other
    than the leader is used.
    """
    def select_arms(self, X, env=None):
        """Return ``(leader, challenger, 1.0)``."""
        means = X @ self.theta_hat
        leader = np.argmax(means)

        marginal_width = np.sqrt(np.sum((X @ self.W_inv) * X, axis=1))
        optimistic = means + self.alpha * marginal_width

        plausible = np.flatnonzero(optimistic >= means[leader])
        plausible = plausible[plausible != leader]  # keep the pair distinct

        if plausible.size:
            offsets = X[plausible] - X[leader]
            spread = np.sqrt(np.sum((offsets @ self.W_inv) * offsets, axis=1))
            return leader, plausible[np.argmax(spread)], 1.0

        # Confident everywhere: verify against the next-best optimistic arm.
        optimistic[leader] = -np.inf
        return leader, np.argmax(optimistic), 1.0


class MaxPairUCBLearner(BaselineDuelingGLMLearner):
    """
    MaxPairUCB.

    Picks the greedy leader i, then the challenger j maximising
    <x_j - x_i, theta> + alpha * ||x_j - x_i||_{W^-1}. The selection matches
    RCDBLearner's when rcdb_beta == alpha, but the inherited update is the
    standard unweighted one.
    """
    def select_arms(self, X, env=None):
        """Return ``(leader, challenger, 1.0)``."""
        estimates = X @ self.theta_hat
        first = np.argmax(estimates)

        pair_diffs = X - X[first]
        pair_widths = np.sqrt(np.sum((pair_diffs @ self.W_inv) * pair_diffs, axis=1))
        second = np.argmax(pair_diffs @ self.theta_hat + self.alpha * pair_widths)

        return first, second, 1.0


class ColSTIMLearner(BaselineDuelingGLMLearner):
    """
    ColSTIM-style learner: perturbed-greedy (Thompson-sampling-like) selection.

    Each arm's score is its estimated utility plus Gaussian noise scaled by
    alpha * ||x||_{W^-1}. The two highest-scoring arms are dueled.
    """
    def select_arms(self, X, env=None):
        """Return the top-2 arms by perturbed score, plus a weight of 1.0."""
        means = X @ self.theta_hat
        spread = np.sqrt(np.sum((X @ self.W_inv) * X, axis=1))
        noise = np.random.randn(len(X))
        sampled = means + self.alpha * spread * noise

        ranking = np.argsort(sampled)  # ascending; best arms at the end
        return ranking[-1], ranking[-2], 1.0

class RCDBPostServingLearner(DuelingGLMLearner):
    """
    RCDB equipped with post-serving contexts.

    Features are [X, f_t(X)] from the same NeuralApproximator as
    DuelingGLMLearner. Selection and update follow RCDB: a beta-scaled UCB
    challenger, and an uncertainty weight w_t = min(1, rcdb_alpha / ||dz||_{V^-1})
    applied at update time. Here the single Gram matrix V (and V_inv) is
    updated only when feedback is delivered.

    Hypothesis under test: even with access to Y, RCDB's weighting is inferior
    to ours under Strategic Delay + Corruption.
    """
    def __init__(self, d, e, lambda_reg=1.0, alpha=0.1, rcdb_alpha=1.0, rcdb_beta=1.0):
        super().__init__(d, e, lambda_reg, alpha, C=0.0, Lambda=0.0, mu_tau=0.0)
        self.rcdb_alpha = rcdb_alpha
        self.rcdb_beta = rcdb_beta

    def select_arms(self, X, env=None):
        """Return ``(leader, challenger, 1.0)``; V is not modified here."""
        Z = self.get_features(X)
        leader = np.argmax(Z @ self.theta_hat)
        offsets = Z - Z[leader]
        optimism = np.sqrt(np.sum((offsets @ self.V_inv) * offsets, axis=1))
        scores = offsets @ self.theta_hat + self.rcdb_beta * optimism
        return leader, np.argmax(scores), 1.0

    def update(self, X, a_idx, b_idx, outcome, y_a_obs, y_b_obs, omega_t_dummy=1.0):
        """Refit f_t on the new observations, then take an uncertainty-weighted ONS step."""
        for x_obs, y_obs in ((X[a_idx], y_a_obs), (X[b_idx], y_b_obs)):
            self.data_buffer_x.append(x_obs)
            self.data_buffer_y.append(y_obs)
        self.train_approximator(epochs=2)

        Z = self.get_features(X)
        diff = Z[a_idx] - Z[b_idx]

        uncertainty = np.sqrt(diff @ self.V_inv @ diff)
        weight = 1.0 if uncertainty < 1e-9 else min(1.0, self.rcdb_alpha / uncertainty)

        predicted = 1 / (1 + np.exp(-np.dot(self.theta_hat, diff)))

        scaled = np.sqrt(weight) * diff
        self.V += np.outer(scaled, scaled)
        proj = self.V_inv @ scaled
        self.V_inv -= np.outer(proj, proj) / (1 + np.dot(scaled, proj))

        self.theta_hat += self.V_inv @ (diff * weight * (outcome - predicted))
