# -*- coding: utf-8 -*-
# Fixed-action-set Hybrid Linear Bandits experiment (Biased vs Unbiased vs Pure-Online)
# Implements lemma-based betas and fair RNG handling.
# Every line has an English comment for clarity.

import math                              # for math.sqrt
import numpy as np                       # numerical computations
import matplotlib.pyplot as plt          # plotting

# ====================== Config ======================

# --- Problem size ---
d = 10          # dimension of every feature / action vector
K = 50          # how many actions the fixed action set contains
T = 3000        # number of online rounds per episode

# --- Noise, regularization, confidence ---
sigma = 0.5     # standard deviation of the Gaussian reward noise
lam = 1e-2      # ridge regularization weight
delta = 0.05    # confidence parameter plugged into the beta formulas

# --- Norm bounds assumed by the lemma ---
S = 1.0         # bound on ||theta||_2
L = 1.0         # bound on action norms (actions are kept inside the unit ball)

# --- Offline data ---
N = 5000                      # offline samples generated per condition
gamma = 0.9 / math.sqrt(d)    # size of the parameter shift used when labelling offline data
sqrt_beta_N = 0.0             # offline-only radius term (replace with a formula if one is available)

# --- Reproducibility ---
seed_master = 2               # master seed from which the other generators are derived

# ====================== RNG helpers ======================

def make_rngs(seed_base=12345):
    """Return a pair of separately seeded generators: (data RNG, evaluation-noise RNG).

    The data generator is seeded with ``seed_base + 1`` and the evaluation
    generator with ``seed_base + 2``, so the two streams never share a seed.
    """
    data_seed, noise_seed = seed_base + 1, seed_base + 2
    return np.random.default_rng(data_seed), np.random.default_rng(noise_seed)

# ====================== Problem setup ======================

# Dedicated generator for the ground-truth parameter and the action set;
# it is seeded directly with seed_master and is distinct from the evaluation generators.
rng0 = np.random.default_rng(seed_master)

# Ground-truth parameter: one Gaussian draw, rescaled to unit Euclidean norm.
theta = rng0.normal(size=d)
theta = theta / np.linalg.norm(theta)

# Fixed action set A (K x d): Gaussian rows. Rows whose norm exceeds 1 are rescaled
# to norm 1; rows already inside the unit ball are divided by 1 and stay as drawn.
A = rng0.normal(size=(K, d))
A = A / np.maximum(1.0, np.linalg.norm(A, axis=1, keepdims=True))

# ====================== Offline dataset generator ======================

def offline_dataset(N, A, theta, gamma, sigma, rng, bias_type="random"):
    """
    Generate an offline dataset (B, y) from the fixed action set A.

    Rows of A are sampled uniformly with replacement and labelled with a
    (possibly shifted) parameter
        theta_off = theta (+/-) gamma * v,   v a random unit vector,
    plus Gaussian noise of standard deviation sigma.

    Parameters
    ----------
    N : int
        Number of offline samples. N == 0 returns empty arrays and theta itself.
    A : ndarray, shape (K, d)
        Fixed action set.
    theta : ndarray, shape (d,)
        True parameter.
    gamma : float
        Magnitude of the parameter shift.
    sigma : float
        Standard deviation of the label noise.
    rng : numpy.random.Generator
        Source of all randomness used here.
    bias_type : {"none", "positive", "negative", "random"}
        "none" uses theta unchanged, "positive"/"negative" shift along +v/-v,
        "random" picks the sign with a fair coin.

    Returns
    -------
    (B, y, theta_off) : sampled actions (N x d), labels (N,), labelling parameter.

    Raises
    ------
    ValueError
        If bias_type is not one of the four supported values. (Previously any
        unknown string was silently treated as "random".)
    """
    valid_bias_types = ("none", "positive", "negative", "random")
    if bias_type not in valid_bias_types:               # fail loudly on typos
        raise ValueError(
            f"bias_type must be one of {valid_bias_types}, got {bias_type!r}"
        )

    if N == 0:                                          # no offline data requested
        return (np.zeros((0, A.shape[1])),              # empty B (0 x d)
                np.zeros(0),                            # empty y (0,)
                theta)                                  # theta_off = theta

    # NOTE: the order of the draws below (indices, direction, coin, noise) fixes the
    # random stream; do not reorder or results for a given seed will change.
    idx = rng.integers(low=0, high=A.shape[0], size=N)  # uniformly sample N indices from A
    B = A[idx]                                          # pick corresponding actions (N x d)

    # The direction is drawn even for "none" so the stream is consumed identically.
    v = rng.normal(size=A.shape[1])                     # random direction in R^d
    v /= np.linalg.norm(v)                              # normalize to unit length

    if bias_type == "none":                             # unbiased case
        theta_off = theta                               # exactly theta (same object)
    elif bias_type == "positive":                       # positive bias along v
        theta_off = theta + gamma * v
    elif bias_type == "negative":                       # negative bias along v
        theta_off = theta - gamma * v
    else:                                               # "random": flip a fair coin
        sgn = 1.0 if rng.random() < 0.5 else -1.0
        theta_off = theta + sgn * gamma * v

    y = B @ theta_off + rng.normal(scale=sigma, size=N) # noisy offline labels
    return B, y, theta_off

# ====================== Ridge warm-start ======================

def ridge(B, y, lam):
    """Return the ridge solution theta0, the design matrix V0, and b0 (= B^T y).

    Parameters
    ----------
    B : ndarray, shape (N, d)
        Feature rows; N may be 0 (e.g. the empty ``(0, d)`` array produced when
        no offline data is requested).
    y : ndarray, shape (N,)
        Targets matching the rows of B.
    lam : float
        Ridge weight; V0 = lam * I + B^T B.

    Returns
    -------
    (theta0, V0, b0) with theta0 = V0^{-1} b0.

    Notes
    -----
    The dimension is always read from ``B.shape[1]``. The previous fallback
    ``d = ... else d`` made ``d`` a local name and raised UnboundLocalError
    whenever B was empty. For a (0, d) array the products below are already
    all-zero, so no special case is needed.
    """
    dim = B.shape[1]                                    # feature dimension, valid for N == 0 too
    V0 = lam * np.eye(dim) + B.T @ B                    # lam*I + B^T B
    b0 = B.T @ y                                        # sum of x * r
    theta0 = np.linalg.solve(V0, b0)                    # theta0 = V0^{-1} b0
    return theta0, V0, b0

# ====================== Lemma-based beta functions ======================

def beta_online_from_lemma(t, d, lam, delta, L=1.0, S=1.0):
    """
    Online confidence radius from the lemma:
        sqrt(beta_t) = sqrt(lam)*S + sqrt( 2*log(1/delta) + d*log((d*lam + t*L^2)/(d*lam)) )

    ``t`` is truncated to an int and clamped to at least 1 before use.
    Returns beta_t (the square of the right-hand side) as a Python float.
    """
    rounds = max(1, int(t))                             # t = 0 is treated like t = 1
    reg_mass = d * lam                                  # the d*lam term shared by numerator and denominator
    confidence_term = 2.0 * np.log(1.0 / delta)
    growth_term = d * np.log((reg_mass + rounds * (L**2)) / reg_mass)
    under_root = confidence_term + growth_term
    # Clamp at zero so rounding can never hand a negative value to sqrt.
    radius = math.sqrt(lam) * S + math.sqrt(max(under_root, 0.0))
    return float(radius ** 2)

def beta_hybrid_from_lemma(t, d, lam, delta, V0, V_bias, L=1.0, S=1.0, n_offline=None):
    """
    Implements: sqrt(beta_{t,N}) = sqrt( lambda_max(V0) ) * V + sqrt(beta_t) + sqrt(beta_N)
    Returns beta_{t,N} (square of the RHS).

    Parameters
    ----------
    t : round index passed to the online radius formula.
    d, lam, delta, L, S : forwarded unchanged to ``beta_online_from_lemma``.
    V0 : symmetric (d x d) matrix whose largest eigenvalue scales the bias term.
    V_bias : the "V" factor of the lemma (bias magnitude).
    n_offline : optional number of offline samples used for the sqrt(beta_N) term.
        When omitted, the module-level ``N`` is used, which is what this function
        always did implicitly; pass it explicitly to avoid the hidden global.

    Notes
    -----
    sqrt(beta_N) is obtained by evaluating the *online* radius formula at
    ``n_offline`` rounds; it is not a separate offline-specific formula.
    """
    if n_offline is None:
        n_offline = N                                   # backward-compatible fallback to the global

    # Online radius at round t.
    sqrt_beta_t = math.sqrt(beta_online_from_lemma(t, d, lam, delta, L=L, S=S))

    # Same formula evaluated at the offline sample count.
    sqrt_beta_offline = math.sqrt(beta_online_from_lemma(n_offline, d, lam, delta, L=L, S=S))

    lam_max_V0 = float(np.linalg.eigvalsh(V0).max())    # largest eigenvalue of V0
    # max(., 0) guards sqrt against a tiny negative eigenvalue from round-off.
    sqrt_beta_tN = math.sqrt(max(lam_max_V0, 0.0)) * V_bias + sqrt_beta_t + sqrt_beta_offline
    return float(sqrt_beta_tN ** 2)                     # return beta_{t,N}

# ====================== H-LinUCB (intersected confidence sets) ======================

class HLinUCB:
    """
    H-LinUCB with intersected confidence sets on a fixed action set A.

    Two ridge estimators ("online" and "hybrid") are maintained and each action is
    scored by  Score(a) = min( online_UCB(a), hybrid_UCB(a) ).

    Parameters
    ----------
    d : feature dimension.
    lam : ridge regularization; used for the default lam*I start and in the betas.
    sigma : noise std (stored only; the lemma betas do not read it).
    delta : confidence level used in the betas.
    V0, b0 : optional warm-start design matrix (d x d) and response vector (d,).
        Both are copied, so the caller's arrays are never modified.
    S, L : norm bounds forwarded to the beta formulas.
    V_bias : the "V" term of the lemma (bias magnitude).
    sqrt_beta_N : stored for reference.

    NOTE(review): both estimators are initialised from the same V0/b0 and receive
    identical updates, so their means and widths coincide and the two UCBs differ
    only through beta_t vs beta_{t,N}. Confirm whether the online estimator was
    meant to start from lam*I instead.
    NOTE(review): ``sqrt_beta_N`` is stored but not read by any method of this class.
    """
    def __init__(self, d, lam, sigma, delta=0.05, V0=None, b0=None,
                 S=1.0, L=1.0, V_bias=0.0, sqrt_beta_N=0.0):
        self.d = d                                      # feature dimension
        self.lam = lam                                  # ridge regularization
        self.sigma = sigma                              # noise std (not used explicitly in lemma betas)
        self.delta = delta                              # confidence level
        self.S = S                                      # bound on ||theta||_2 in lemma
        self.L = L                                      # bound on action norm
        self.V_bias = V_bias                            # "V" term in lemma (bias magnitude)
        self.sqrt_beta_N = sqrt_beta_N                  # kept for reference only

        # Snapshot of the starting design matrix; its lambda_max enters beta_{t,N}.
        self.V0_for_beta = self._initial_matrix(V0)

        # Online state (starts from lam*I or the provided V0/b0)
        self.V = self._initial_matrix(V0)
        self.b = self._initial_vector(b0)
        self.V_inv = np.linalg.inv(self.V)
        self.theta_hat = self.V_inv @ self.b

        # Hybrid state (starts from the same V0/b0)
        self.VN = self._initial_matrix(V0)
        self.bN = self._initial_vector(b0)
        self.VN_inv = np.linalg.inv(self.VN)
        self.theta_hat_N = self.VN_inv @ self.bN

        self.t = 0                                      # internal round counter

    def _initial_matrix(self, V0):
        """Return a fresh float copy of V0, or lam*I when no warm start is given."""
        if V0 is None:
            return self.lam * np.eye(self.d)
        return np.array(V0, dtype=float)                # np.array copies by default

    def _initial_vector(self, b0):
        """Return a fresh float copy of b0, or the zero vector when none is given."""
        if b0 is None:
            return np.zeros(self.d)
        return np.array(b0, dtype=float)

    @staticmethod
    def _sherman_morrison(M_inv, x):
        """Return (M + x x^T)^{-1} given M^{-1}, via the Sherman–Morrison identity."""
        Mx = M_inv @ x                                  # M^{-1} x
        return M_inv - np.outer(Mx, Mx) / (1.0 + x @ Mx)

    def _betas(self):
        """Compute beta_t and beta_{t,N} according to the lemma formulas."""
        # The formulas are evaluated for the upcoming round, hence t + 1.
        beta_t = beta_online_from_lemma(self.t + 1, self.d, self.lam, self.delta,
                                        L=self.L, S=self.S)
        beta_tN = beta_hybrid_from_lemma(self.t + 1, self.d, self.lam, self.delta,
                                         V0=self.V0_for_beta, V_bias=self.V_bias,
                                         L=self.L, S=self.S)
        return beta_t, beta_tN

    def _scores(self, A):
        """Compute intersected UCB scores for all actions in A (K x d)."""
        beta_t, beta_tN = self._betas()                                 # get both betas

        mean_online = A @ self.theta_hat                                # online means
        var_online  = np.einsum('ij,jk,ik->i', A, self.V_inv, A)        # diag(A V^{-1} A^T)
        # Variances are floored at 1e-12 so sqrt never sees a negative round-off value.
        ucb_online  = mean_online + math.sqrt(beta_t) * np.sqrt(np.maximum(var_online, 1e-12))

        mean_hybrid = A @ self.theta_hat_N                              # hybrid means
        var_hybrid  = np.einsum('ij,jk,ik->i', A, self.VN_inv, A)       # diag(A V_N^{-1} A^T)
        ucb_hybrid  = mean_hybrid + math.sqrt(beta_tN) * np.sqrt(np.maximum(var_hybrid, 1e-12))

        return np.minimum(ucb_online, ucb_hybrid)                       # intersected bound, shape (K,)

    def select(self, A):
        """Select the action index that maximizes the intersected UCB score."""
        return int(np.argmax(self._scores(A)))                          # first index on ties

    def update(self, x, r):
        """Fold one observation (feature x, reward r) into both estimators.

        The inverses are refreshed with a rank-one Sherman–Morrison step, and the
        design matrices V / V_N are advanced too so they stay consistent with
        V_inv / VN_inv (previously they were left at their initial values).
        """
        self.t += 1                                                     # increment round
        xxT = np.outer(x, x)                                            # rank-one increment x x^T

        # Online update
        self.V += xxT                                                   # V <- V + x x^T
        self.b += x * r                                                 # b <- b + x r
        self.V_inv = self._sherman_morrison(self.V_inv, x)              # rank-1 inverse update
        self.theta_hat = self.V_inv @ self.b                            # refresh estimator

        # Hybrid update
        self.VN += xxT                                                  # V_N <- V_N + x x^T
        self.bN += x * r                                                # bN <- bN + x r
        self.VN_inv = self._sherman_morrison(self.VN_inv, x)            # rank-1 inverse update
        self.theta_hat_N = self.VN_inv @ self.bN                        # refresh hybrid estimator

# ====================== Regret evaluation ======================

def instantaneous_regret(theta, A, chosen_idx):
    """Return the pseudo-regret of one pick: best noiseless reward minus the chosen action's.

    Noiseless rewards are the entries of ``A @ theta``; the result is a Python float.
    """
    expected = A @ theta                                                # one mean reward per action
    best = np.max(expected)                                             # reward of the optimal action
    return float(best - expected[chosen_idx])                           # gap to the optimum

def run_episode(A, theta, algo, T, sigma, rng):
    """Play T rounds against the fixed action set A and return the cumulative pseudo-regret.

    Each round: ``algo.select(A)`` picks an index, a reward ``A[idx] @ theta`` plus
    Gaussian noise (std ``sigma``, drawn from ``rng``) is generated, and the pair is
    passed to ``algo.update``. The returned array has length T; entry t is the summed
    pseudo-regret of rounds 0..t.
    """
    curve = np.zeros(T)                                                 # cumulative regret per round
    total = 0.0                                                         # running regret sum
    for step in range(T):
        pick = algo.select(A)                                           # index chosen by the learner
        action = A[pick]                                                # its feature vector

        total += instantaneous_regret(theta, A, pick)                   # pseudo-regret uses noiseless rewards
        curve[step] = total

        reward = action @ theta + rng.normal(scale=sigma)               # noisy feedback shown to the learner
        algo.update(action, reward)
    return curve

# ====================== Build datasets (Biased & Unbiased) ======================

# Independent RNGs for data vs evaluation noise (to ensure fairness later).
# NOTE(review): rng_eval_master is not referenced again in the file as shown.
rng_data, rng_eval_master = make_rngs(seed_base=seed_master)

# Both datasets below draw from the same rng_data stream, one after the other, so
# the order of these two calls determines the samples each of them receives.

# Biased offline dataset (random ±gamma direction)
B_b, y_b, theta_off_b = offline_dataset(N=N, A=A, theta=theta, gamma=gamma,
                                        sigma=sigma, rng=rng_data, bias_type="random")
theta0_b, V0_b, b0_b = ridge(B_b, y_b, lam)                            # warm-start statistics from the biased data

# Unbiased offline dataset (theta_off = theta); gamma is passed but has no effect
# on the labels when bias_type="none".
B_u, y_u, theta_off_u = offline_dataset(N=N, A=A, theta=theta, gamma=gamma,
                                        sigma=sigma, rng=rng_data, bias_type="none")
theta0_u, V0_u, b0_u = ridge(B_u, y_u, lam)                            # warm-start statistics from the unbiased data

# ====================== Fair evaluation: same noise trajectory ======================

def run_condition(A, theta, V0, b0, T, sigma, lam, delta, V_bias, sqrt_beta_N, eval_seed):
    """Simulate one condition with reward noise drawn from a generator seeded by ``eval_seed``.

    A fresh HLinUCB learner is built from the given warm start (V0, b0) and settings;
    the norm bounds S and L are taken from the module-level constants. Conditions run
    with the same ``eval_seed`` receive the same sequence of noise draws.
    Returns the cumulative pseudo-regret curve of length T.
    """
    learner = HLinUCB(
        d=A.shape[1],
        lam=lam,
        sigma=sigma,
        delta=delta,
        V0=V0,
        b0=b0,
        S=S,
        L=L,
        V_bias=V_bias,
        sqrt_beta_N=sqrt_beta_N,
    )
    noise_rng = np.random.default_rng(eval_seed)                        # generator private to this call
    return run_episode(A, theta, learner, T=T, sigma=sigma, rng=noise_rng)

# Seed intended to be shared by every run_condition call so all curves see the same
# noise. In the file as shown, the only run_condition calls are commented out.
eval_seed = 20250925

# Per lemma, set V term for biased; keep 0 for unbiased and pure-online
V_term_biased   = gamma * math.sqrt(d)                                  # V = gamma * sqrt(d)
V_term_unbiased = 0.0                                                   # no bias inflation

# # Run three conditions on the SAME noise trajectory
# regret_hybrid_biased   = run_condition(A, theta, V0_b, b0_b, T, sigma, lam, delta,
#                                        V_bias=V_term_biased,   sqrt_beta_N=sqrt_beta_N,
#                                        eval_seed=eval_seed)

# regret_hybrid_unbiased = run_condition(A, theta, V0_u, b0_u, T, sigma, lam, delta,
#                                        V_bias=V_term_unbiased, sqrt_beta_N=sqrt_beta_N,
#                                        eval_seed=eval_seed)

# regret_online          = run_condition(A, theta, None,  None,  T, sigma, lam, delta,
#                                        V_bias=0.0,             sqrt_beta_N=0.0,
#                                        eval_seed=eval_seed)

# ====================== Save data to disk ======================

# # Save offline datasets and parameters for record
# np.savez("offline_biased.npz",
#          A=A, B=B_b, y=y_b, theta=theta, theta_off=theta_off_b,
#          N=N, gamma=gamma, sigma=sigma, lam=lam)

# np.savez("offline_unbiased.npz",
#          A=A, B=B_u, y=y_u, theta=theta, theta_off=theta_off_u,
#          N=N, gamma=0.0, sigma=sigma, lam=lam)

# # Save regret curves
# np.savez("regrets.npz",
#          regret_hybrid_biased=regret_hybrid_biased,
#          regret_hybrid_unbiased=regret_hybrid_unbiased,
#          regret_online=regret_online,
#          N=N, gamma=gamma, d=d, T=T)

# # ====================== Plotting ======================

# plt.figure(figsize=(7,5))                                               # create a figure
# plt.plot(regret_hybrid_biased,   label="Hybrid (Biased)",   linewidth=2)  # biased hybrid curve
# plt.plot(regret_hybrid_unbiased, label="Hybrid (Unbiased)", linewidth=2)  # unbiased hybrid curve
# plt.plot(regret_online,          label="Pure Online",       linewidth=2, linestyle="--")  # online curve

# plt.xlabel("Round t")                                                   # x-axis label
# plt.ylabel("Cumulative Pseudo-Regret")                                  # y-axis label
# plt.title(f"Cumulative Regret (N={N}, V={gamma*math.sqrt(d)})")                      # title with N and V=gamma
# plt.legend()                                                            # show legend
# plt.grid(True)                                                          # add grid
# plt.tight_layout()                                                      # adjust layout
# plt.savefig("regret_curve_with_unbiased.png", dpi=300)                  # save figure
# plt.show()                                                              # display figure

def run_many(A, theta, V0, b0, T, sigma, lam, delta,
             V_bias, sqrt_beta_N, n_runs=20, seed_base=1234):
    """
    Repeat one condition ``n_runs`` times and summarise the regret curves.

    Run i uses reward noise from a generator seeded with ``seed_base + i`` and a
    freshly constructed HLinUCB learner (S and L are left at the class defaults).
    Returns (mean, std), each an array of shape (T,), taken across runs.
    """
    def _single_run(run_idx):
        # One independent episode: its own noise stream and its own learner.
        noise_rng = np.random.default_rng(seed_base + run_idx)
        learner = HLinUCB(d=A.shape[1], lam=lam, sigma=sigma, delta=delta,
                          V0=V0, b0=b0, V_bias=V_bias, sqrt_beta_N=sqrt_beta_N)
        return run_episode(A, theta, learner, T=T, sigma=sigma, rng=noise_rng)

    curves = np.stack([_single_run(i) for i in range(n_runs)], axis=0)  # shape (n_runs, T)
    return curves.mean(axis=0), curves.std(axis=0)

n_runs = 50                                   # independent repetitions per condition

# Each condition below uses its own seed_base (1000 / 2000 / 3000), so the three
# conditions are averaged over different noise seeds rather than a shared trajectory.

# Hybrid learner warm-started from the biased offline data, with the bias term V enabled.
mean_biased, std_biased = run_many(A, theta, V0_b, b0_b, T, sigma, lam, delta,
                                   V_bias=V_term_biased, sqrt_beta_N=sqrt_beta_N,
                                   n_runs=n_runs, seed_base=1000)

# Hybrid learner warm-started from the unbiased offline data (V term is 0).
mean_unbiased, std_unbiased = run_many(A, theta, V0_u, b0_u, T, sigma, lam, delta,
                                       V_bias=V_term_unbiased, sqrt_beta_N=sqrt_beta_N,
                                       n_runs=n_runs, seed_base=2000)

# Pure-online baseline: no warm start (V0 = b0 = None), so the learner starts from lam*I.
mean_online, std_online = run_many(A, theta, None, None, T, sigma, lam, delta,
                                   V_bias=0.0, sqrt_beta_N=0.0,
                                   n_runs=n_runs, seed_base=3000)

plt.figure(figsize=(7, 5))

# One entry per curve, drawn in this order: (mean, std, legend label, extra line options).
for _mean, _std, _label, _line_opts in (
    (mean_biased, std_biased, "Hybrid (Biased)", {}),
    (mean_unbiased, std_unbiased, "Hybrid (Unbiased)", {}),
    (mean_online, std_online, "Pure Online", {"linestyle": "--"}),
):
    _half_width = _std / np.sqrt(n_runs)      # across-run std scaled by 1/sqrt(n_runs)
    plt.plot(_mean, label=_label, linewidth=2, **_line_opts)
    plt.fill_between(np.arange(T), _mean - _half_width, _mean + _half_width, alpha=0.2)

# Axis labels, title, legend and layout, then show the figure.
plt.xlabel("Round t")
plt.ylabel("Cumulative Regret")
plt.title(f"Cumulative Regret over {n_runs} runs (N={N}, V={gamma*math.sqrt(d):.2f})")
plt.legend()
plt.grid(True)
plt.tight_layout()
plt.show()

# Save the averaged regret curves together with the experiment parameters in a
# single .npz archive. Each keyword below becomes the name of one stored array;
# scalar parameters are stored as arrays by np.savez.
np.savez("regrets_avg.npz",
         mean_biased=mean_biased,
         std_biased=std_biased,
         mean_unbiased=mean_unbiased,
         std_unbiased=std_unbiased,
         mean_online=mean_online,
         std_online=std_online,
         d=d, K=K, T=T,
         N=N, gamma=gamma, lam=lam,
         sigma=sigma, delta=delta,
         n_runs=n_runs,
         V_term_biased=V_term_biased,
         V_term_unbiased=V_term_unbiased)
