import numpy as np
import math
import random
import matplotlib.pyplot as plt
from collections import defaultdict
from tqdm import tqdm
import time

###########################################################

def g_optimal_design(X, max_iters=1e1, tol=1e-5):
    """Approximate a G-optimal design over the rows of ``X`` by Frank-Wolfe.

    Starting from uniform weights, each iteration builds the regularised
    information matrix ``M = eps*I + sum_i phi_i x_i x_i^T``, finds the row
    with the largest score ``x_i^T M^{-1} x_i`` and moves weight towards it.

    Args:
        X: array of shape (K, d); one candidate vector per row.
        max_iters: maximum number of iterations (converted with ``int``).
        tol: stop once the L1 change of the weights falls below this value.

    Returns:
        Array of shape (K,) with the design weights.
    """
    X = np.asarray(X, dtype=float)
    K, d = X.shape
    phi = np.ones(K) / K
    if K == 0:
        return phi

    # Small ridge keeps M invertible when the rows do not span R^d.
    ridge = 1e-5 * np.eye(d)

    for _ in range(int(max_iters)):
        # M(phi) = eps*I + sum_i phi_i x_i x_i^T, built with one matrix product.
        M = ridge + X.T @ (phi[:, None] * X)
        M_inv = np.linalg.inv(M)

        # Squared Mahalanobis norm of every row; the maximiser is the FW vertex.
        scores = np.einsum('ij,jk,ik->i', X, M_inv, X)
        i_star = int(np.argmax(scores))
        v_star = scores[i_star]

        # Step size ((v*/d) - 1) / (v* - 1), clipped to [0, 1].
        # When v* == 1 exactly the quotient is 0/0 (d == 1) or -inf (d > 1);
        # take no step instead of propagating NaN into the weights.
        denom = v_star - 1.0
        if denom == 0.0:
            gamma = 0.0
        else:
            gamma = float(np.clip((v_star / d - 1.0) / denom, 0.0, 1.0))

        # phi <- (1 - gamma) * phi + gamma * e_{i*}
        new_phi = (1.0 - gamma) * phi
        new_phi[i_star] += gamma

        shift = np.linalg.norm(new_phi - phi, ord=1)
        phi = new_phi
        if shift < tol:
            break

    return phi

###########################################################

def near_g_optimal_design(X, max_iters=1e2, tol=1e-5):
    """Frank-Wolfe G-optimal design computed inside the row space of ``X``.

    The contexts are projected onto the right singular vectors whose singular
    value exceeds 1e-12, so the iteration runs in ``rank`` dimensions. Weights
    start uniform; each step moves mass to the arm with the largest leverage
    ``x^T Sigma^{-1} x``.

    Args:
        X: array-like of shape (K, d), converted to float.
        max_iters: iteration cap (``int(max_iters)``, at least one step runs).
        tol: stop when the gap ``max leverage - rank`` or the L1 change of the
            weights is at most ``tol``.

    Returns:
        Array of shape (K,) with the design weights (uniform if ``X`` has no
        singular value above the threshold).
    """
    contexts = np.asarray(X, dtype=float)
    n_arms, _ = contexts.shape

    # Rank reduction: keep only directions carried by a non-negligible
    # singular value.
    _, sing_vals, right_vecs = np.linalg.svd(contexts, full_matrices=False)
    rank = int((sing_vals > 1e-12).sum())
    if rank == 0:
        # Every singular value is below the threshold: nothing to optimise.
        return np.ones(n_arms) / n_arms
    basis = right_vecs[:rank, :].T       # (d, rank)
    Z = contexts @ basis                 # (K, rank) projected contexts

    weights = np.ones(n_arms) / n_arms
    ridge = 1e-12                        # tiny Tikhonov term for stability
    n_steps = max(1, int(max_iters))

    step = 0
    while step < n_steps:
        step += 1

        # Weighted second-moment matrix in the reduced space, plus the ridge.
        cov = Z.T @ (weights[:, None] * Z)
        cov = cov + ridge * np.eye(rank)

        try:
            chol = np.linalg.cholesky(cov)
        except np.linalg.LinAlgError:
            # Factorisation failed: add a larger ridge once and retry.
            cov = cov + max(1e-12, ridge * 10) * np.eye(rank)
            chol = np.linalg.cholesky(cov)

        # Leverage of each arm: diag(Z cov^{-1} Z^T) through two solves
        # against the Cholesky factor.
        half = np.linalg.solve(chol, Z.T)            # (rank, K)
        whitened = np.linalg.solve(chol.T, half)     # (rank, K) = cov^{-1} Z^T
        leverage = np.sum(Z * whitened.T, axis=1)    # (K,)

        best = int(np.argmax(leverage))
        peak = float(leverage[best])

        # Gap used for stopping, and the step size
        # (peak - rank) / (rank * (peak - 1)), clipped to [0, 1].
        gap = peak - rank
        gamma = (peak - rank) / (rank * max(1e-18, peak - 1.0))
        gamma = float(np.clip(gamma, 0.0, 1.0))

        # Shift mass towards the selected vertex.
        updated = (1.0 - gamma) * weights
        updated[best] += gamma

        converged = gap <= tol or np.linalg.norm(updated - weights, 1) <= tol
        weights = updated
        if converged:
            break

    return weights

###########################################################

def eliminate_arms(X, Lamdas, thetas, alpha):
    """Successively prune arms whose UCB falls below the best LCB.

    For every ``(Lambda, theta)`` pair, in order, the bounds are

        UCB_i = x_i^T theta + alpha * sqrt(x_i^T Lambda^{-1} x_i)
        LCB_i = x_i^T theta - alpha * sqrt(x_i^T Lambda^{-1} x_i)

    and only arms with ``UCB_i >= max_j LCB_j`` are passed to the next pair.

    Args:
        X: array of shape (K, d); one arm per row.
        Lamdas: sequence of (d, d) matrices.
        thetas: sequence of (d,) estimates, paired with ``Lamdas``.
        alpha: width multiplier of the confidence bounds.

    Returns:
        The surviving rows of ``X`` in their original relative order.
    """
    for Lam, theta in zip(Lamdas, thetas):
        if len(X) == 0:
            # Nothing left to compare; max() over an empty set would raise.
            break

        Lam_inv = np.linalg.inv(Lam)
        variances = np.einsum('ij,jk,ik->i', X, Lam_inv, X)
        means = X.dot(theta)
        widths = alpha * np.sqrt(variances)

        # A boolean mask keeps row order deterministic (iterating a set of
        # indices does not guarantee ascending order).
        keep = (means + widths) >= (means - widths).max()

        # Whole-set elimination (e.g. NaN bounds): keep the first arm so that
        # the caller still has something to play.
        if not keep.any():
            keep[0] = True

        X = X[keep]

    return X

###########################################################

def eliminate_arms_nonUCB(X, Lamdas, thetas, batch, T):
    """Successively prune arms whose estimated gap to the greedy arm is too large.

    For every ``(Lambda, theta)`` pair, in order, an arm ``x`` survives when

        (x_best - x)^T theta <= beta * sqrt(max_i x_i^T Lambda^{-1} x_i)

    where ``x_best`` maximises ``x^T theta`` among the current arms and
    ``beta = min(beta_1, beta_2)`` with

        beta_1 = sqrt(2 log(K (batch-1) T^2)) + 1
        beta_2 = 2 sqrt(log(2^(6d-5) pi d (batch-1)^2 T^2 / 15^(d-1))) + 2

    ``K`` is the number of arms still alive at that round. The logarithms are
    evaluated term by term so that large ``d`` does not overflow, and
    ``batch - 1`` is floored at 1 (as ``_batchclae_beta`` does) so that
    ``batch == 1`` does not produce ``log(0)``.

    Args:
        X: array of shape (K, d); one arm per row.
        Lamdas: sequence of (d, d) matrices.
        thetas: sequence of (d,) estimates, paired with ``Lamdas``.
        batch: total number of batches.
        T: horizon.

    Returns:
        The surviving rows of ``X`` in their original relative order.
    """
    n_prev = max(batch - 1, 1)
    log_prev = np.log(n_prev)
    log_T2 = 2.0 * np.log(T)

    for Lam, theta in zip(Lamdas, thetas):
        K, d = X.shape
        if K == 0:
            break

        beta_1 = np.sqrt(2 * (np.log(K) + log_prev + log_T2)) + 1
        log_arg = (
            (6 * d - 5) * np.log(2.0)
            + np.log(np.pi)
            + np.log(d)
            + 2.0 * log_prev
            + log_T2
            - (d - 1) * np.log(15.0)
        )
        beta_2 = 2 * np.sqrt(log_arg) + 2
        beta = min(beta_1, beta_2)

        Lam_inv = np.linalg.inv(Lam)
        variances = np.einsum('ij,jk,ik->i', X, Lam_inv, X)
        epsilon = np.sqrt(np.max(variances)) * beta

        # Estimated gap of every arm to the greedy arm under theta.
        scores = X @ theta
        gaps = (X[np.argmax(scores)] - X) @ theta
        keep = gaps <= epsilon

        # Whole-set elimination (e.g. NaN radius): keep the first arm.
        if not keep.any():
            keep[0] = True

        X = X[keep]

    return X

###########################################################

# return times when batch ends
def grid(T, k, d):
    """Return the cumulative end times of the batches for horizon ``T``.

    Args:
        T: horizon.
        k: number of arms.
        d: feature dimension.

    Returns:
        1-D ``int64`` array of increasing batch end times whose last entry is
        ``T``. At most 8 batches are produced; a single batch ``[T]`` is
        returned when ``T`` is smaller than ``tilde_d``.
    """
    tilde_d = d * np.log(d * k * (T ** 2)) * np.log((T ** 2) / (10 / T))
    h = min(k, d)
    # M = np.int_(np.log2(np.log2(T))) + 2
    M = 8

    if T < tilde_d:
        # Horizon too short for more than one batch.
        return np.ceil(np.array([T], dtype=float)).astype(np.int64)

    grad_t = np.zeros(M)  # batch lengths before rounding

    if T <= tilde_d * (float(h) ** (2 - (2.0 ** (2 - M)))):
        rate_1 = 1 / (2 - (2.0 ** (2 - M)))
        rate_2 = (1 - (2.0 ** (2 - M))) / (2 - (2.0 ** (2 - M)))
        gamma = (T ** rate_1) * (tilde_d ** rate_2)
        grad_t[0] = gamma
        grad_t[1] = gamma
    else:
        rate_1 = 1 / (2 - (2.0 ** (1 - M)))
        rate_2 = (1 - (2.0 ** (1 - M))) / (2 - (2.0 ** (1 - M)))
        rate_3 = (2.0 ** (1 - M)) / (2 - (2.0 ** (1 - M)))
        gamma = (T ** rate_1) * (tilde_d ** rate_2) * (h ** rate_3)
        grad_t[0] = gamma
        grad_t[1] = gamma * math.sqrt(grad_t[0]) / math.sqrt(tilde_d * h)

    # Both regimes share the same recursion from the third batch on.
    for i in range(1, M - 1):
        grad_t[i + 1] = math.sqrt(grad_t[i]) * gamma / math.sqrt(tilde_d)

    G = np.cumsum(np.ceil(grad_t).astype(np.int64))

    # Cut the grid at the first end time beyond T (or at the preceding entry
    # when that one already equals T). If no entry exceeds T, keep the whole
    # grid instead of collapsing it to a single batch.
    over = np.flatnonzero(G > T)
    if over.size == 0:
        idx = len(G) - 1
    else:
        first = int(over[0])
        idx = first - 1 if first > 0 and G[first - 1] == T else first

    G = G[:idx + 1]
    G[-1] = T

    return G

###########################################################

class ExpPolicy:
    """Split a sequence of context sets into epochs by a determinant-doubling rule.

    Args:
        Z_list: list of (K_i, d) arrays of candidate arms, one per round.
        L: clipping level; the chosen vector z is scaled so that
            ``z^T W^{-1} z`` does not exceed ``L``.
        T: horizon; the initial matrix is ``I / T**2``.
    """

    def __init__(self, Z_list, L, T):
        self.Z_list = Z_list
        self.L = L
        self.m = len(Z_list)
        self.d = Z_list[0].shape[1] # K x d
        self.T = T

    def run(self):
        """Return the epochs as a list of ``{'W': matrix, 'count': rounds}``.

        ``W`` is the matrix that was frozen at the start of the epoch and
        ``count`` is the number of rounds assigned to it (the round that
        triggers the doubling test is counted in the epoch it closes).
        """
        kappa = 1.0 / (self.T ** 2)
        U = kappa * np.eye(self.d)   # running matrix U_i
        W = U.copy()                 # matrix frozen at the start of the epoch

        # W only changes at epoch boundaries, so its inverse and
        # log-determinant are cached between boundaries.
        W_inv = np.linalg.inv(W)
        _, logdet_W = np.linalg.slogdet(W)
        log_two = math.log(2.0)

        epochs = []
        rounds_in_epoch = 0

        for Z_i in self.Z_list:
            rounds_in_epoch += 1

            # Choose z in Z_i maximising z^T W^{-1} z.
            scores = np.einsum('ij,jk,ik->i', Z_i, W_inv, Z_i)
            z_i = Z_i[np.argmax(scores)]

            # Clip so that the clipped vector has W^{-1}-norm^2 at most L.
            # A zero vector has norm 0 and needs no scaling (avoids 1/0).
            norm2 = float(z_i @ W_inv @ z_i)
            clip = min(math.sqrt(self.L / norm2), 1.0) if norm2 > 0.0 else 1.0
            z_tilde = clip * z_i

            U = U + np.outer(z_tilde, z_tilde)

            # det(U) > 2 det(W), compared through log-determinants:
            # kappa**d underflows to 0.0 for moderate d, which would make the
            # plain determinant test always false.
            _, logdet_U = np.linalg.slogdet(U)
            if logdet_U > log_two + logdet_W:
                epochs.append({
                    'W': W.copy(),
                    'count': rounds_in_epoch
                })
                rounds_in_epoch = 0
                W = U.copy()
                W_inv = np.linalg.inv(W)
                logdet_W = logdet_U

        # Rounds after the last doubling form the final epoch.
        if rounds_in_epoch > 0:
            epochs.append({
                'W': W.copy(),
                'count': rounds_in_epoch
            })

        return epochs

###########################################################

class BatchLearner:
    """Batched linear-bandit learner with UCB elimination and ExpPolicy mixtures.

    Batch 1 samples arms from ``near_g_optimal_design``. In every batch the
    first half of the plays is used for a ridge estimate ``(Lambda, theta)``
    and the contexts of the second half, after elimination, are fed to
    ``ExpPolicy`` to build the sampling policy of the next batch.

    Args:
        T_schedule: increasing batch end times (1-indexed rounds).
        T: horizon.
        n_features: feature dimension d.
        alpha: confidence width passed to ``eliminate_arms``.
    """

    def __init__(self, T_schedule, T, n_features, alpha):
        self.T_schedule = T_schedule
        self.T = T
        self.M = len(T_schedule)
        self.d = n_features
        self.lamda = 10/T
        self.L = 1/(200 * np.log(self.d * (T ** 2)))
        self.alpha = alpha
        self.Lamdas = []
        self.thetas = []
        self.survivor_counts_per_batch = []

    def _fit_ridge(self, Y_mat, R_vec):
        """Append the ridge estimate (Lambda, theta) computed from the given plays."""
        Lam = self.lamda * np.eye(self.d) + Y_mat.T @ Y_mat
        theta = np.linalg.inv(Lam) @ (Y_mat.T @ R_vec)
        self.Lamdas.append(Lam)
        self.thetas.append(theta)

    def run(self, get_contexts, play):
        """Play all batches.

        Args:
            get_contexts: callable ``t -> (K, d)`` array of arms for round t.
            play: callable taking the chosen arm vector and returning a reward.

        Returns:
            None. Estimates are stored in ``self.Lamdas`` / ``self.thetas`` and
            per-batch arm counts in ``self.survivor_counts_per_batch``.
        """
        # === Batch 1: t = 1..T1 ===
        T1 = self.T_schedule[0]
        half1 = T1 // 2
        contexts, Ys, rewards = [], [], []

        for t in tqdm(range(1, T1 + 1), desc="[Batch 1]", disable=True):
            X_t = get_contexts(t)  # (K, d)

            # Sample the arm from the (near) G-optimal design.
            phi = near_g_optimal_design(X_t)
            idx = np.random.choice(len(phi), p=phi)
            y_t = X_t[idx]
            r_t = play(y_t)

            Ys.append(y_t)
            rewards.append(r_t)
            contexts.append(X_t)

        # First half of the plays -> (Lambda_1, theta_1).
        Y_mat = np.stack(Ys)  # T1 x d
        R_vec = np.array(rewards)
        self._fit_ridge(Y_mat[:half1], R_vec[:half1])

        # Second half of the contexts, after elimination -> policy for batch 2.
        Z_list = [
            eliminate_arms(X_t, self.Lamdas, self.thetas, self.alpha)
            for X_t in contexts[half1:]
        ]
        epochs = ExpPolicy(Z_list, self.L, self.T).run()
        # No elimination is applied while playing batch 1, so the recorded
        # count is the size of the last raw context set.
        self.survivor_counts_per_batch.append(len(contexts[-1]))

        # === Batches k = 2..M ===
        for k in range(2, self.M + 1):
            Tk = self.T_schedule[k - 1]
            prev = self.T_schedule[k - 2]
            contexts, Ys, rewards = [], [], []
            survivor_counts_batch = []

            # The epoch matrices are fixed for the whole batch: invert once.
            W_invs = [np.linalg.inv(e['W']) for e in epochs]
            counts = [e['count'] for e in epochs]

            for t in tqdm(range(prev + 1, Tk + 1), desc=f"[Batch {k}]", disable=True):
                X_t = get_contexts(t)
                contexts.append(X_t)

                # Arm elimination with the estimates of batches 1..k-1.
                Xk_t = eliminate_arms(X_t, self.Lamdas[:k - 1], self.thetas[:k - 1], self.alpha)
                survivor_counts_batch.append(len(Xk_t))

                # Each epoch votes for its max-norm arm, weighted by its length.
                prob_vec = np.zeros(len(Xk_t))
                for W_inv, count in zip(W_invs, counts):
                    scores = np.einsum('ij,jk,ik->i', Xk_t, W_inv, Xk_t)
                    prob_vec[np.argmax(scores)] += count
                prob_vec /= prob_vec.sum()

                # Play the sampled arm and record the reward.
                chosen_idx = np.random.choice(len(Xk_t), p=prob_vec)
                y_t = Xk_t[chosen_idx]
                r_t = play(y_t)
                Ys.append(y_t)
                rewards.append(r_t)

            self.survivor_counts_per_batch.append(np.mean(survivor_counts_batch))

            # halfk is an absolute round; the rows of Y_mat are indexed from
            # the start of the batch, so the first half has halfk - (prev + 1)
            # plays (rounds prev+1 .. halfk-1).
            halfk = (Tk + prev + 1) // 2
            n_first = halfk - (prev + 1)
            Y_mat = np.stack(Ys)        # (# of plays) x d
            R_vec = np.array(rewards)   # (# of plays,)
            self._fit_ridge(Y_mat[:n_first], R_vec[:n_first])

            # Second half (rounds halfk .. Tk): reuse the contexts observed
            # during play, as batch 1 does, instead of querying them again.
            Z_list = [
                eliminate_arms(X_t, self.Lamdas[:k], self.thetas[:k], self.alpha)
                for X_t in contexts[n_first:]
            ]

            # Policy for batch k+1.
            epochs = ExpPolicy(Z_list, self.L, self.T).run()

        return
    
###########################################################

def generate_batch_schedule_new(T):
    """Return batch end times ``ceil(T ** (1 - 2**-i))`` for i = 1..m, then ``T``.

    ``m = 2 + floor(log2(log2(T)))``. End times that would repeat the previous
    one or reach ``T`` are dropped, so the result is strictly increasing and
    no batch has zero length. For ``T <= 1`` (where ``log2(log2(T))`` is
    undefined) the schedule is the single batch ``[T]``.

    Args:
        T: horizon.

    Returns:
        Strictly increasing list of batch end times ending with ``T``.
    """
    batch_ends = []

    if T > 1:
        m = 2 + int(np.log2(np.log2(T)))
        for i in range(1, m + 1):
            Ti = int(np.ceil(T ** (1 - 2 ** (-i))))
            if Ti >= T:
                # Later end times are at least as large; T is appended below.
                break
            if not batch_ends or Ti > batch_ends[-1]:
                batch_ends.append(Ti)

    batch_ends.append(T)
    return batch_ends

###########################################################

def generate_batch_schedule_argmax(T):
    """Return cumulative batch end times built from lengths ``T**(1-2**-i) / m``.

    With ``m = log2(log2(T))``, batch ``i`` (1-indexed) has length
    ``ceil(T ** (1 - 2**-i) / m)`` plus 1 for the first batch and plus 2 for
    the later ones. Lengths are accumulated until the horizon is reached and
    the last end time is ``T``.

    For ``T <= 2`` the factor ``m`` is zero or undefined, so the schedule is
    the single batch ``[T]``.

    Args:
        T: horizon.

    Returns:
        Increasing list of batch end times ending with ``T``.
    """
    batch_ends = []

    if T > 2:
        m = float(np.log2(np.log2(T)))
        t = 0
        i = 1
        while t < T:
            extra = 1 if i == 1 else 2
            t += int(np.ceil(T ** (1 - 2 ** (-i)) / m)) + extra
            if t < T:
                batch_ends.append(t)
            i += 1

    batch_ends.append(T)
    return batch_ends

###########################################################

def _batchclae_beta(K, d, batch, T):
    """Return the elimination radius multiplier ``min(beta_1, beta_2)``.

        beta_1 = sqrt(2 log(K b T^2)) + 1
        beta_2 = 2 sqrt(log(2^(6d-5) pi d b^2 T^2 / 15^(d-1))) + 2

    with ``b = max(batch - 1, 1)``. The logarithms are evaluated term by term
    so that ``2 ** (6d - 5)`` and ``15 ** (d - 1)`` cannot overflow for
    large ``d``.

    NOTE: a second function with the same name is defined further down in
    this file (with slightly different constants); at import time that later
    definition replaces this one.
    """
    log_b = np.log(max(batch - 1, 1))
    log_T2 = 2.0 * np.log(T)

    beta_1 = np.sqrt(2.0 * (np.log(K) + log_b + log_T2)) + 1
    log_arg = (
        (6 * d - 5) * np.log(2.0)
        + np.log(np.pi)
        + np.log(d)
        + 2.0 * log_b
        + log_T2
        - (d - 1) * np.log(15.0)
    )
    beta_2 = 2.0 * np.sqrt(log_arg) + 2
    return min(beta_1, beta_2)

def eliminate_arms_nonUCB_vec(X, chol_list, thetas, beta):
    """Vectorised successive elimination by estimated gap to the greedy arm.

    For each pair ``(L, theta)``, with ``L`` used as the left factor in
    ``solve(L, X^T)`` (callers in this file pass Cholesky factors), an arm
    survives when ``x^T theta >= max_j x_j^T theta - beta * sqrt(max_j v_j)``,
    where ``v_j`` is the squared norm of the j-th column of ``L^{-1} X^T``.

    Args:
        X: array of shape (K, d); one arm per row.
        chol_list: sequence of (d, d) factors.
        thetas: sequence of (d,) estimates, paired with ``chol_list``.
        beta: radius multiplier.

    Returns:
        The surviving rows of ``X`` (returned early once no arm is left).
    """
    survivors = X
    for factor, estimate in zip(chol_list, thetas):
        if survivors.size == 0:
            return survivors

        whitened = np.linalg.solve(factor, survivors.T)   # (d, K_curr)
        variances = np.sum(whitened * whitened, axis=0)   # (K_curr,)
        radius = np.sqrt(variances.max()) * beta

        rewards = survivors @ estimate                    # (K_curr,)
        cutoff = rewards.max() - radius
        survivors = survivors[rewards >= cutoff]
    return survivors

class BatchCLAE:
    """Batched learner: max-variance exploration followed by greedy play.

    The first batch is pure exploration. In later batches the rounds up to a
    cutoff explore and the remaining rounds play the arm that maximises the
    latest estimate. After every batch a ridge estimate is fitted on that
    batch's plays and used for arm elimination in the following batches.

    Args:
        T: horizon.
        K: number of arms (used for the elimination radius).
        d: feature dimension.
        c: fraction controlling the length of the exploration phase.
    """

    def __init__(self, T, K, d, c):
        self.T = T
        self.K = K
        self.d = d
        self.c = c
        self.lamda = np.log(d*T)
        self.schedule = generate_batch_schedule_argmax(T)
        self.batch = len(self.schedule)
        self.Lamdas = []
        self.thetas = []
        # self.survivor_counts_per_batch = []

    @staticmethod
    def _explore(arms, H_inv):
        """Return the arm maximising x^T H_inv x and update ``H_inv`` in place.

        The update is the Sherman-Morrison formula for (H + x x^T)^{-1}.
        """
        variances = np.einsum("ij,jk,ik->i", arms, H_inv, arms, optimize=True)
        x = arms[int(np.argmax(variances))]
        Hx = H_inv @ x
        denom = 1.0 + float(x @ Hx)
        H_inv -= np.outer(Hx, Hx) / max(denom, 1e-15)
        return x

    def run(self, get_contexts, play, tracker, verbose=False):
        """Play every batch of ``self.schedule``.

        Args:
            get_contexts: callable ``t -> (K, d)`` array of arms for round t.
            play: callable ``(x, t, tracker) -> reward``.
            tracker: opaque object forwarded to ``play``.
            verbose: currently unused.
        """
        m = np.log2(np.log2(self.T))

        # Factors of the estimates already stored, and the radius multiplier.
        chol_list = [np.linalg.cholesky(Lam) for Lam in self.Lamdas]
        beta_const = _batchclae_beta(self.K, self.d, self.batch, self.T)

        prev = 0
        for i, end in enumerate(self.schedule):
            Ys, Rs = [], []

            # Inverse design matrix restarts from (lamda * I)^{-1} each batch.
            H_inv = (1.0 / self.lamda) * np.eye(self.d)

            # Last exploration round of this batch: the whole first batch
            # explores; later batches explore up to the cutoff.
            if i == 0:
                explore_until = end
            else:
                explore_until = prev + int(np.ceil(self.c * self.T ** (1 - 2 ** (-(i+1))) / m))

            for t in tqdm(range(prev + 1, end + 1), desc=f"[Batch {i + 1}]", disable=True):
                X_t = get_contexts(t)

                if len(chol_list) > 0:
                    X_t = eliminate_arms_nonUCB_vec(X_t, chol_list, self.thetas, beta_const)
                if len(X_t) == 0:
                    # No arm left for this round: nothing is played.
                    continue

                if t <= explore_until:
                    x = self._explore(X_t, H_inv)
                else:
                    theta_last = self.thetas[-1] if len(self.thetas) > 0 else np.zeros(self.d)
                    x = X_t[int(np.argmax(X_t @ theta_last))]

                r = play(x, t, tracker)
                Ys.append(x)
                Rs.append(r)

            # Ridge estimate from this batch's plays (skipped if none).
            if len(Ys) > 0:
                Y_mat = np.stack(Ys)
                R_vec = np.array(Rs)
                Λk = self.lamda * np.eye(self.d) + Y_mat.T @ Y_mat

                θk = _chol_solve(Λk, Y_mat.T @ R_vec)
                self.Lamdas.append(Λk)
                self.thetas.append(θk)
                chol_list.append(np.linalg.cholesky(Λk))

            prev = end

# class BatchCLAE:
#     def __init__(self, T, K, d, c):
#         self.T = T
#         self.K = K
#         self.d = d
#         self.c = c
#         self.lamda = np.log(d*T)
#         self.schedule = generate_batch_schedule_argmax(T)
#         self.batch = len(self.schedule)
#         self.Lamdas = []
#         self.thetas = []
#         self.survivor_counts_per_batch = [] 

#     def run(self, get_contexts, play, tracker):
#         prev = 0
#         m = np.log2(np.log2(self.T))

#         print(self.schedule)

#         for i, end in enumerate(self.schedule):
#             Ys, Rs = [], []
#             survivor_counts_batch = []
#             H = self.lamda * np.eye(self.d)

#             for t in tqdm(range(prev + 1, end + 1), desc=f"[Batch {i + 1}]", disable=True):
#                 X_t = get_contexts(t)
#                 X_t = eliminate_arms_nonUCB(X_t, self.Lamdas, self.thetas, self.batch, self.T)       
#                 survivor_counts_batch.append(len(X_t))

#                 if(i == 0):
#                     H_inv = np.linalg.inv(H)
#                     scores = np.einsum("ij,jk,ik->i", X_t, H_inv, X_t)
#                     idx = np.argmax(scores)

#                     x = X_t[idx]
#                     H += np.outer(x, x)
#                     r = play(x, t, tracker)
#                     Ys.append(x)
#                     Rs.append(r)
#                 else:
#                     if(t <= prev + int(np.ceil(self.c * self.T ** (1 - 2 ** (-(i+1))) / m))):
#                         H_inv = np.linalg.inv(H)
#                         scores = np.einsum("ij,jk,ik->i", X_t, H_inv, X_t)
#                         idx = np.argmax(scores)

#                         x = X_t[idx]
#                         H += np.outer(x, x)
#                         r = play(x, t, tracker)
#                         Ys.append(x)
#                         Rs.append(r)

#                     else:
#                         dot_products = np.dot(X_t, self.thetas[-1])
#                         idx = np.argmax(dot_products)

#                         x = X_t[idx]
#                         r = play(x, t, tracker)
#                         Ys.append(x)
#                         Rs.append(r)

#             self.survivor_counts_per_batch.append(np.mean(survivor_counts_batch))
#             # update Lambda, theta after batch
#             if len(Ys) == 0:
#                 continue
#             Y_mat = np.stack(Ys)
#             R_vec = np.array(Rs)
#             Λk = self.lamda * np.eye(self.d) + Y_mat.T @ Y_mat
#             θk = np.linalg.inv(Λk) @ (Y_mat.T @ R_vec)
#             self.Lamdas.append(Λk)
#             self.thetas.append(θk)
#             prev = end

###########################################################

def _batchclae_beta(K, d, batch, T):
    """Return the elimination radius multiplier ``min(beta_1, beta_2)``.

        beta_1 = sqrt(2 log(2 K b T^2)) + 1
        beta_2 = 2 sqrt(log(2^(6d-3) pi d b^2 T^2 / 15^(d-1))) + 2

    with ``b = max(batch - 1, 1)``. The logarithms are evaluated term by term
    so that ``2 ** (6d - 3)`` and ``15 ** (d - 1)`` cannot overflow for
    large ``d``.
    """
    log_b = np.log(max(batch - 1, 1))
    log_T2 = 2.0 * np.log(T)

    beta_1 = np.sqrt(2.0 * (np.log(2.0) + np.log(K) + log_b + log_T2)) + 1
    log_arg = (
        (6 * d - 3) * np.log(2.0)
        + np.log(np.pi)
        + np.log(d)
        + 2.0 * log_b
        + log_T2
        - (d - 1) * np.log(15.0)
    )
    beta_2 = 2.0 * np.sqrt(log_arg) + 2
    return min(beta_1, beta_2)

def eliminate_arms_nonUCB_vec(X, chol_list, thetas, beta):
    """Drop, round by round, the arms that trail the greedy arm by too much.

    Each ``(L, theta)`` pair defines one round. With ``Z = solve(L, X^T)``
    the radius is ``beta * sqrt(max_j ||Z[:, j]||^2)`` and an arm stays when
    its score ``x^T theta`` is within that radius of the best score.

    Args:
        X: array of shape (K, d); one arm per row.
        chol_list: sequence of (d, d) factors (callers in this file pass
            Cholesky factors).
        thetas: sequence of (d,) estimates, paired with ``chol_list``.
        beta: radius multiplier.

    Returns:
        The rows of ``X`` that pass every round; the loop exits early once
        the arm set is empty.
    """
    alive = X
    for chol, theta_hat in zip(chol_list, thetas):
        if alive.size == 0:
            break

        Z = np.linalg.solve(chol, alive.T)       # (d, K_curr)
        sq_norms = np.sum(Z * Z, axis=0)         # (K_curr,)
        eps = np.sqrt(sq_norms.max()) * beta

        values = alive @ theta_hat               # (K_curr,)
        mask = values >= (values.max() - eps)
        alive = alive[mask]
    return alive

class BatchCLAE_G:
    """Batched learner mixing G-optimal sampling, max-variance exploration and greedy play.

    Batch 1: G-optimal sampling up to ``cutoff1``, max-variance exploration
    afterwards. Later batches: G-optimal sampling up to ``cutoff1``,
    max-variance exploration up to ``cutoff2``, then greedy play with the
    latest estimate. After every batch a ridge estimate is fitted on that
    batch's plays and used for arm elimination in the following batches.

    Args:
        T: horizon.
        K: number of arms (used for the elimination radius).
        d: feature dimension.
        c: fraction controlling the lengths of the two exploration phases.
    """

    def __init__(self, T, K, d, c):
        self.T = T
        self.K = K
        self.d = d
        self.c = c
        self.lamda = np.log(d*T)
        self.schedule = generate_batch_schedule_argmax(T)
        self.batch = len(self.schedule)
        self.Lamdas = []
        self.thetas = []
        self.survivor_counts_per_batch = []

    @staticmethod
    def _absorb(H_inv, x):
        """In-place Sherman-Morrison update: H_inv <- (H + x x^T)^{-1}."""
        Hx = H_inv @ x
        denom = 1.0 + float(x @ Hx)
        H_inv -= np.outer(Hx, Hx) / max(denom, 1e-15)

    def run(self, get_contexts, play, tracker):
        """Play every batch of ``self.schedule``.

        Args:
            get_contexts: callable ``t -> (K, d)`` array of arms for round t.
            play: callable ``(x, t, tracker) -> reward``.
            tracker: opaque object forwarded to ``play``.
        """
        m = np.log2(np.log2(self.T))

        print(self.schedule)

        # Factors of the estimates already stored, and the radius multiplier.
        chol_list = [np.linalg.cholesky(Lam) for Lam in self.Lamdas]
        beta_const = _batchclae_beta(self.K, self.d, self.batch, self.T)

        prev = 0
        for i, end in enumerate(self.schedule):
            Ys, Rs = [], []
            survivor_counts_batch = []

            # Inverse design matrix restarts from (lamda * I)^{-1} each batch.
            H_inv = (1.0 / self.lamda) * np.eye(self.d)

            # Phase boundaries (absolute rounds) for this batch.
            Tpow = self.T ** (1 - 2 ** (-(i+1)))
            if i == 0:
                cutoff1 = prev + int(np.ceil(self.c * Tpow / m))
                cutoff2 = cutoff1
            else:
                g_len = int(np.ceil(self.c * self.c * Tpow / m))
                exp_len = int(np.ceil(self.c * (1.0 - self.c) * Tpow / m))
                cutoff1 = prev + g_len
                cutoff2 = cutoff1 + exp_len

            for t in tqdm(range(prev + 1, end + 1), desc=f"[Batch {i + 1}]", disable=True):
                X_t = get_contexts(t)

                if len(chol_list) > 0:
                    X_t = eliminate_arms_nonUCB_vec(X_t, chol_list, self.thetas, beta_const)
                survivor_counts_batch.append(len(X_t))
                if len(X_t) == 0:
                    # No arm left for this round: nothing is played.
                    continue

                if t <= cutoff1:
                    # G-optimal sampling phase.
                    probs = near_g_optimal_design(X_t)
                    x = X_t[int(np.random.choice(len(probs), p=probs))]
                    self._absorb(H_inv, x)
                elif i == 0 or t <= cutoff2:
                    # Max-variance phase (all of the first batch after
                    # cutoff1; up to cutoff2 in later batches).
                    scores = np.einsum("ij,jk,ik->i", X_t, H_inv, X_t, optimize=True)
                    x = X_t[int(np.argmax(scores))]
                    self._absorb(H_inv, x)
                else:
                    # Greedy phase with the most recent estimate.
                    theta_last = self.thetas[-1] if len(self.thetas) > 0 else np.zeros(self.d)
                    x = X_t[int(np.argmax(X_t @ theta_last))]

                r = play(x, t, tracker)
                Ys.append(x)
                Rs.append(r)

            self.survivor_counts_per_batch.append(np.mean(survivor_counts_batch))

            # Ridge estimate from this batch's plays (skipped if none).
            if len(Ys) > 0:
                Y_mat = np.stack(Ys)
                R_vec = np.array(Rs)
                Λk = self.lamda * np.eye(self.d) + Y_mat.T @ Y_mat
                θk = _chol_solve(Λk, Y_mat.T @ R_vec)
                self.Lamdas.append(Λk)
                self.thetas.append(θk)
                chol_list.append(np.linalg.cholesky(Λk))

            prev = end

# class BatchCLAE_G:
#     def __init__(self, T, K, d, c):
#         self.T = T
#         self.K = K
#         self.d = d
#         self.c = c
#         self.lamda = np.log(d*T)
#         self.schedule = generate_batch_schedule_argmax(T)
#         self.batch = len(self.schedule)
#         self.Lamdas = []
#         self.thetas = []
#         self.survivor_counts_per_batch = [] 

#     def run(self, get_contexts, play, tracker):
#         prev = 0
#         m = np.log2(np.log2(self.T))

#         print(self.schedule)

#         for i, end in enumerate(self.schedule):
#             Ys, Rs = [], []
#             survivor_counts_batch = []
#             H = self.lamda * np.eye(self.d)

#             for t in tqdm(range(prev + 1, end + 1), desc=f"[Batch {i + 1}]", disable=True):
#                 X_t = get_contexts(t)
#                 X_t = eliminate_arms_nonUCB(X_t, self.Lamdas, self.thetas, self.batch, self.T)       
#                 survivor_counts_batch.append(len(X_t))

#                 if(i == 0):
#                     if(t <= prev + int(np.ceil(self.c * self.T ** (1 - 2 ** (-(i+1))) / m))):
#                         probs = near_g_optimal_design(X_t)
#                         idx = np.random.choice(len(probs), p = probs)

#                         x = X_t[idx]
#                         H += np.outer(x, x)
#                         r = play(x, t, tracker)
#                         Ys.append(x)
#                         Rs.append(r)

#                     else:
#                         H_inv = np.linalg.inv(H)
#                         scores = np.einsum("ij,jk,ik->i", X_t, H_inv, X_t)
#                         idx = np.argmax(scores)

#                         x = X_t[idx]
#                         H += np.outer(x, x)
#                         r = play(x, t, tracker)
#                         Ys.append(x)
#                         Rs.append(r)
#                 else:
#                     if(t <= prev + int(np.ceil(self.c * self.c * self.T ** (1 - 2 ** (-(i+1))) / m))):
#                         probs = near_g_optimal_design(X_t)
#                         idx = np.random.choice(len(probs), p = probs)

#                         x = X_t[idx]
#                         H += np.outer(x, x)
#                         r = play(x, t, tracker)
#                         Ys.append(x)
#                         Rs.append(r)

#                     elif(t <= prev + int(np.ceil(self.c * self.c * self.T ** (1 - 2 ** (-(i+1))) / m)) + int(np.ceil(self.c * (1 - self.c) * self.T ** (1 - 2 ** (-(i+1))) / m))):
#                         H_inv = np.linalg.inv(H)
#                         scores = np.einsum("ij,jk,ik->i", X_t, H_inv, X_t)
#                         idx = np.argmax(scores)

#                         x = X_t[idx]
#                         H += np.outer(x, x)
#                         r = play(x, t, tracker)
#                         Ys.append(x)
#                         Rs.append(r)

#                     else:
#                         dot_products = np.dot(X_t, self.thetas[-1])
#                         idx = np.argmax(dot_products)

#                         x = X_t[idx]
#                         r = play(x, t, tracker)
#                         Ys.append(x)
#                         Rs.append(r)

#             self.survivor_counts_per_batch.append(np.mean(survivor_counts_batch))
#             # update Lambda, theta after batch
#             if len(Ys) == 0:
#                 continue
#             Y_mat = np.stack(Ys)
#             R_vec = np.array(Rs)
#             Λk = self.lamda * np.eye(self.d) + Y_mat.T @ Y_mat
#             θk = np.linalg.inv(Λk) @ (Y_mat.T @ R_vec)
#             self.Lamdas.append(Λk)
#             self.thetas.append(θk)
#             prev = end

###########################################################
# Ruan

# =========================================================
# Helpers
# =========================================================

def _softmax(scores, alpha):
    """Return probabilities proportional to ``scores ** alpha``.

    Despite the name this is a power normalisation, not an exponential
    softmax. The scores are divided by their maximum before the power is
    taken; this leaves the probabilities unchanged but keeps very large or
    very small scores from overflowing or underflowing.

    Falls back to the uniform distribution when the normaliser is not a
    finite positive number (e.g. all scores zero, or NaN entries).

    Args:
        scores: 1-D array-like of scores.
        alpha: exponent applied to the scores.

    Returns:
        Float array of the same length as ``scores``.
    """
    scores = np.asarray(scores, dtype=float)
    n = len(scores)
    if n == 0:
        return np.ones_like(scores)

    # s = alpha * (scores - np.max(scores))
    peak = np.max(scores)
    base = scores / peak if np.isfinite(peak) and peak > 0 else scores
    p = np.power(base, alpha)
    # p = np.exp(s)
    Z = p.sum()
    return p / Z if np.isfinite(Z) and Z > 0 else np.ones_like(scores) / n

def _expected_xxT(X, probs):
    """Return ``sum_i probs[i] * x_i x_i^T`` for the rows ``x_i`` of ``X``.

    ``X`` has shape (K, d) and ``probs`` shape (K,); the result is (d, d).
    """
    weighted_rows = probs[:, None] * X
    return X.T @ weighted_rows

def _chol_inv(A):
    """Invert ``A`` through its Cholesky factor: A^{-1} = L^{-T} L^{-1}.

    ``np.linalg.cholesky`` raises ``LinAlgError`` if the factorisation fails.
    """
    lower_inv = np.linalg.inv(np.linalg.cholesky(A))
    return lower_inv.T @ lower_inv

def _chol_solve(A, b):
    """Solve ``A x = b`` via the Cholesky factor ``A = L L^T``.

    First solves ``L y = b``, then ``L^T x = y``.
    """
    lower = np.linalg.cholesky(A)
    forward = np.linalg.solve(lower, b)
    return np.linalg.solve(lower.T, forward)

def _logdet(A):
    """Return ``log(det(A))`` as a float, or ``-inf`` when the sign from
    ``np.linalg.slogdet`` is not positive."""
    sign, log_abs = np.linalg.slogdet(A)
    if sign <= 0:
        return -np.inf
    return float(log_abs)

def _eliminate_intersection_ucb(X, thetas, Lambdas, alpha):
    """Successively prune arms whose UCB falls below the best LCB.

    For every ``(theta, Lambda)`` pair, in order, an arm of the current set
    survives when ``x^T theta + alpha*sd >= max_j (x_j^T theta - alpha*sd_j)``
    with ``sd = sqrt(max(x^T Lambda^{-1} x, 1e-15))``. The mask is rebuilt in
    every round because the arm set shrinks between rounds.

    Args:
        X: array of shape (K, d); one arm per row.
        thetas: sequence of (d,) estimates.
        Lambdas: sequence of (d, d) matrices, paired with ``thetas``.
        alpha: width multiplier of the confidence bounds.

    Returns:
        The surviving rows of ``X``; a (0, d) array if a round removes every
        arm; ``X`` itself when ``thetas`` is empty.
    """
    if len(thetas) == 0:
        return X

    for theta, Lam in zip(thetas, Lambdas):
        if len(X) == 0:
            # Nothing left to compare; max() over an empty set would raise.
            break

        # x^T Lam^{-1} x = ||L^{-1} x||^2 for the Cholesky factor Lam = L L^T.
        whitened = np.linalg.solve(np.linalg.cholesky(Lam), X.T)
        var = np.sum(whitened * whitened, axis=0)
        sd = np.sqrt(np.maximum(var, 1e-15))

        means = X @ theta
        ucbs = means + alpha * sd
        lcbs = means - alpha * sd
        keep = ucbs >= float(lcbs.max())

        if not np.any(keep):
            return np.zeros((0, X.shape[1]))
        X = X[keep]
    return X

# =========================================================
# π^G = g-optimal design sampler
# =========================================================

class GOptimalSampler:
    """Sampler for pi^G: draws arms according to design weights.

    ``design_fn`` is any callable mapping a context matrix ``X`` (one row per
    arm) to a length-``len(X)`` vector of non-negative weights.
    """

    def __init__(self, design_fn):
        self.design_fn = design_fn

    def _weights(self, X):
        """Return the design weights for ``X`` as a proper probability vector.

        Falls back to the uniform distribution when the design returns
        non-finite values or a non-positive total.
        """
        phi = np.asarray(self.design_fn(X), dtype=float)
        if not np.all(np.isfinite(phi)) or phi.sum() <= 0:
            return np.ones(len(X)) / len(X)
        return phi / phi.sum()

    def sample_index(self, X):
        """Draw one row index of ``X`` with probability given by the design.

        The weights are normalised first (as ``expected_xxT`` always did):
        ``np.random.choice`` rejects probability vectors that do not sum to
        one, so an un-normalised design used to raise ``ValueError`` here.
        """
        return int(np.random.choice(len(X), p=self._weights(X)))

    def expected_xxT(self, X):
        """E_{x~π^G(X)}[xx^T] = X^T diag(φ) X"""
        return _expected_xxT(X, self._weights(X))

# =========================================================
# Definition 3: Mixed-Softmax Policy (with π^G = g-optimal)
# =========================================================

class MixedSoftmaxPolicyGOpt:
    """
    Mixed-softmax policy pi_MS(X) built on a g-optimal design sampler.

    Each call to ``choose`` flips a fair coin:
      - heads: the arm index is drawn from pi^G(X) via ``gopt.sample_index``;
      - tails: a mixture component j is drawn with probability p_j and the
        arm index is drawn from ``_softmax`` applied to the scores
        x^T M_j x with temperature alpha.

    ``mixture`` is a list of ``(p_j, M_j)`` pairs; the p_j are passed to
    ``np.random.choice`` as-is, so they must sum to one.
    """

    def __init__(self, mixture, gopt_sampler, alpha_temp=None):
        self.mixture = mixture
        weights, matrices = [], []
        for p, M in mixture:
            weights.append(p)
            matrices.append(M)
        self.ps = np.array(weights, dtype=float)
        self.Ms = matrices
        self.gopt = gopt_sampler
        # None means "use log(max(K, 2)) for the K arms seen at call time".
        self.alpha_temp = alpha_temp

    def choose(self, X):
        """Return the index of the arm (row of ``X``) to play.

        Raises:
            ValueError: if ``X`` has no rows.
        """
        K = len(X)
        if K == 0:
            raise ValueError("Empty context set.")
        if self.alpha_temp is None:
            alpha = float(np.log(max(K, 2)))
        else:
            alpha = self.alpha_temp

        use_design = np.random.rand() < 0.5
        if use_design:
            return self.gopt.sample_index(X)

        component = int(np.random.choice(len(self.mixture), p=self.ps))
        scores = np.einsum('id,df,if->i', X, self.Ms[component], X, optimize=True)
        return int(np.random.choice(K, p=_softmax(scores, alpha)))

# =========================================================
# Algorithm 4: CoreIdentification (π^G = g-optimal)
# =========================================================

def core_identification_gopt(S_sets, lam, gamma, d, gopt_sampler, c=6, max_pass=10):
    """Iteratively drop context sets containing an over-leveraged arm.

    Each pass builds ``M = lam * I + cov / max(gamma, 1)`` where ``cov`` sums
    ``gopt_sampler.expected_xxT(X)`` over the sets still in the core, then
    keeps only the sets whose largest ``x^T M^{-1} x`` is at most
    ``0.5 * d ** c``.  Stops when a pass removes nothing, when the core is
    empty, or after ``max_pass`` passes.

    Returns:
        A new list with the surviving context sets (same array objects).
    """
    core = list(S_sets)
    limit = 0.5 * (d ** c)
    cov_scale = 1.0 / max(gamma, 1.0)

    for _ in range(max_pass):
        if not core:
            break

        cov = np.zeros((d, d))
        for X in core:
            cov += gopt_sampler.expected_xxT(X)  # E_{π^G}[xx^T]
        M_inv = _chol_inv(lam * np.eye(d) + cov_scale * cov)

        survivors = [
            X for X in core
            if float(np.max(np.einsum('id,df,if->i', X, M_inv, X, optimize=True))) <= limit
        ]
        if len(survivors) == len(core):
            break
        core = survivors
    return core

# =========================================================
# Algorithm 2: Distributional G-opt via softmax expectation
# =========================================================

def learn_distributional_gopt_modified(S_sets, lam, T_total, d, gopt_sampler, N=None):
    """Learn the ``(p_i, M_i)`` mixture used by the mixed-softmax policy.

    The context sets are cycled through ``N`` times.  A running matrix ``U``
    accumulates the softmax-weighted second moments computed against the
    current reference matrix ``W``; whenever ``log det U`` exceeds
    ``log det W`` by at least ``log 2`` a new epoch starts with ``W = U``.
    Each epoch that processed at least one set contributes one component
    with weight proportional to its number of steps and matrix
    ``M_i = N * T_total * W_i^{-1}``.

    Args:
        S_sets: list of (K_i, d) context matrices.
        lam: ridge coefficient of the initial matrix.
        T_total: horizon used to scale the initial matrix and the ``M_i``.
        d: feature dimension.
        gopt_sampler: object exposing ``expected_xxT(X)``.
        N: number of sweeps; defaults to ``ceil(2 d^2 log(max(d, 2)))``.

    Returns:
        List of ``(p_i, M_i)`` tuples; ``[(1.0, I)]`` when ``S_sets`` is empty.
    """
    n_sets = len(S_sets)
    if n_sets == 0:
        return [(1.0, np.eye(d))]

    if N is None:
        N = max(1, int(np.ceil(2 * (d ** 2) * np.log(max(d, 2)))))

    # U0
    cov0 = np.zeros((d, d))
    for X in S_sets:
        cov0 += gopt_sampler.expected_xxT(X)
    U = lam * N * T_total * np.eye(d) + 0.5 * N * cov0
    W_n = U.copy()
    logdet_W = _logdet(W_n)

    epoch_counts = [0]
    W_list = [W_n.copy()]

    alphas = [float(np.log(max(len(X), 2))) for X in S_sets]
    log_two = np.log(2.0)

    # W_n only changes at an epoch switch, so its inverse is computed once per
    # epoch (lazily, on first use) instead of once per step.
    W_inv = None
    for t in range(N * n_sets):
        i = t % n_sets
        X = S_sets[i]
        if W_inv is None:
            W_inv = _chol_inv(W_n)
        scores = np.einsum('id,df,if->i', X, W_inv, X, optimize=True)
        p_soft = _softmax(scores, alphas[i])
        U = U + _expected_xxT(X, p_soft)
        epoch_counts[-1] += 1

        logdet_U = _logdet(U)
        if logdet_U - logdet_W >= log_two:
            # Determinant at least doubled: start a new epoch anchored at U.
            W_n = U.copy()
            logdet_W = logdet_U   # identical to _logdet(W_n), no recomputation
            W_inv = None
            W_list.append(W_n.copy())
            epoch_counts.append(0)

    # Only the last epoch can have zero steps (it is opened right after a
    # switch); drop it so the weights and the matching W's stay aligned.
    counts = np.array(epoch_counts, dtype=float)
    counts = counts[counts > 0]
    if counts.size == 0:
        counts = np.array([1.0])
    pis = counts / counts.sum()

    # M_i = N T W_i^{-1}
    mixture = []
    for p, W in zip(pis, W_list[:len(pis)]):
        M = (N * T_total) * _chol_inv(W)
        mixture.append((float(p), M))
    return mixture

# =========================================================
# Algorithm 3: CoreLearning wrapper (π^G=g-optimal)
# =========================================================

def core_learning_gopt(S_sets, lam, T_total, d, gopt_sampler, c=6):
    """Build a mixed-softmax policy from a collection of context sets.

    First filters ``S_sets`` with ``core_identification_gopt`` (using
    ``gamma = max(1, len(S_sets))`` and ``int(d * log(T_total))`` passes);
    if nothing survives, all of ``S_sets`` is used.  The mixture is then
    learned with ``learn_distributional_gopt_modified`` and wrapped in a
    ``MixedSoftmaxPolicyGOpt`` with call-time temperature.
    """
    n_sets = len(S_sets)
    n_passes = int(d * np.log(T_total))
    core = core_identification_gopt(
        S_sets,
        lam=lam,
        gamma=max(1, n_sets),
        d=d,
        gopt_sampler=gopt_sampler,
        c=c,
        max_pass=n_passes,
    )
    if len(core) == 0:
        core = S_sets
    mixture = learn_distributional_gopt_modified(
        core, lam=lam, T_total=T_total, d=d, gopt_sampler=gopt_sampler, N=None
    )
    return MixedSoftmaxPolicyGOpt(mixture, gopt_sampler=gopt_sampler, alpha_temp=None)

# =========================================================
# Algorithm 5: BatchLinUCB-DG (with π^G = g-optimal)
# =========================================================

def _schedule_batchlinucb_dg(T):
    """Return the list of batch end times for a horizon of ``T`` rounds.

    The endpoints are ``ceil(sqrt(T))``, ``ceil(2 sqrt(T))`` and then
    ``ceil(T ** (1 - 2 ** -(i - 1)))`` for ``i = 3 .. M - 1`` with
    ``M = max(2, floor(log2(log2(max(T, 3)))) + 1)``, each capped at ``T``.
    ``T`` is appended if missing, non-increasing entries are bumped to
    ``previous + 1`` (capped at ``T``), and the last entry is forced to ``T``.
    """
    n_batches = max(2, int(np.floor(np.log2(np.log2(max(T, 3))))) + 1)
    root = np.sqrt(T)

    ends = [min(T, int(np.ceil(root))), min(T, int(np.ceil(2 * root)))]
    ends.extend(
        min(T, int(np.ceil(T ** (1.0 - 2.0 ** (-(i - 1))))))
        for i in range(3, n_batches)
    )
    if ends[-1] != T:
        ends.append(T)

    # Repair any endpoint that does not advance past its predecessor.
    for pos in range(1, len(ends)):
        previous = ends[pos - 1]
        if ends[pos] <= previous:
            ends[pos] = min(T, previous + 1)
    ends[-1] = T
    return ends

class BatchLinUCB_DG_GOPT:
    """Batched LinUCB with a learned mixed-softmax exploration policy.

    The horizon is split by ``_schedule_batchlinucb_dg``.  Within a batch,
    arms are first filtered by the confidence regions of all earlier batches,
    then one arm is drawn with the policy learned at the end of the previous
    batch.  After the batch, the first half of its rounds yields a ridge
    estimate and the second half yields the context sets from which the next
    policy is learned.
    """

    def __init__(self, T, K, d, delta=None, design_fn=None):
        """
        Args:
            T: horizon (last endpoint of the batch schedule).
            K: number of arms; only enters the confidence width ``alpha``.
            d: feature dimension.
            delta: failure probability; defaults to ``1 / T``.
            design_fn: callable mapping a context matrix to design weights;
                defaults to ``near_g_optimal_design``.
        """
        self.T, self.K, self.d = T, K, d
        self.delta = (1.0 / T) if (delta is None) else delta
        self.alpha = 10.0 * np.sqrt(np.log(2.0 * d * K * T / self.delta))  # Alg.5
        self.lam_ridge = 32.0 * np.log(2.0 * d * T / self.delta)           # Alg.5 line 12
        self.schedule = _schedule_batchlinucb_dg(T)
        # π^G = (near_)g_optimal_design
        self.gopt_sampler = GOptimalSampler(design_fn if design_fn is not None else near_g_optimal_design)
        self.Lambdas, self.Thetas = [], []   # one (Λ_k, θ_k) per completed batch
        self.policies = []                   # policy learned after each batch
        self.survivor_counts_per_batch = []  # mean #survivors over each batch's second half

    def run(self, get_contexts, play, tracker):
        """Run the algorithm over the whole schedule.

        Args:
            get_contexts: callable ``t -> X_t`` returning the context matrix
                (one row per arm) for round ``t`` (1-based).
            play: callable ``(x, t, tracker) -> reward`` for the chosen row.
            tracker: opaque object forwarded to ``play``.
        """
        prev = 0
        init_mixture = [(1.0, np.eye(self.d))]
        pi_prev = MixedSoftmaxPolicyGOpt(init_mixture, gopt_sampler=self.gopt_sampler)

        for k, end in enumerate(self.schedule, start=1):
            times = list(range(prev + 1, end + 1))
            chosen_x = {}
            rewards = {}
            survived_X = {}

            for t in tqdm(times, desc=f"[BatchLinUCB-DG {k}/{len(self.schedule)}]", disable=True):
                X_t = get_contexts(t)
                X_surv = _eliminate_intersection_ucb(X_t, self.Thetas, self.Lambdas, self.alpha)
                survived_X[t] = X_surv
                if len(X_surv) == 0:
                    # Nothing left to play this round.
                    continue
                idx = pi_prev.choose(X_surv)
                x = X_surv[idx]
                r = play(x, t, tracker)
                chosen_x[t] = x
                rewards[t] = r

            # First half (A) feeds the estimator, second half (B) the policy.
            mid = len(times) // 2
            A_times = times[:mid]
            B_times = times[mid:]

            played_A = [t for t in A_times if t in chosen_x]
            if played_A:
                Y_A = np.stack([chosen_x[t] for t in played_A], axis=0)
                R_A = np.array([rewards[t] for t in played_A])
            else:
                if len(chosen_x) == 0:
                    # No round of this batch was played at all.
                    self.survivor_counts_per_batch.append(0.0)
                    prev = end
                    continue
                # No A-round was played: fall back to every played round.
                Y_A = np.stack(list(chosen_x.values()), axis=0)
                R_A = np.array(list(rewards.values()))

            Λk = self.lam_ridge * np.eye(self.d) + Y_A.T @ Y_A
            ξk = Y_A.T @ R_A
            θk = _chol_solve(Λk, ξk)
            self.Lambdas.append(Λk)
            self.Thetas.append(θk)

            S_sets = []
            surv_sizes = []
            for t in B_times:
                # survived_X holds an entry for every round of this batch, so
                # it is read directly; the former ``.get(t, get_contexts(t))``
                # re-queried the context source on every B-round for a default
                # that was never used.
                X_t = survived_X[t]
                X_surv_k = _eliminate_intersection_ucb(X_t, [θk], [Λk], self.alpha)
                if len(X_surv_k) > 0:
                    S_sets.append(X_surv_k)
                    surv_sizes.append(len(X_surv_k))
            self.survivor_counts_per_batch.append(float(np.mean(surv_sizes)) if surv_sizes else 0.0)

            lam_core = 1.0 / self.T
            pi_k = core_learning_gopt(S_sets, lam=lam_core, T_total=self.T, d=self.d, gopt_sampler=self.gopt_sampler, c=6)
            self.policies.append(pi_k)
            pi_prev = pi_k
            prev = end
        return
    
###########################################################
# Hanna

# =========================================================
# Helpers for SoftBatch
# =========================================================

# def _argmax_action(A, theta):
#     """Linear optimization oracle: O(A; theta) = argmax_{a in A} <a, theta>"""
#     idx = int(np.argmax(A @ theta))
#     return A[idx]

# def _quantize_theta(theta, q):
#     """[θ]_q = q floor(θ sqrt(d) / q) / sqrt(d)"""
#     d = len(theta)
#     return (q * np.floor(theta * np.sqrt(d) / max(q, 1e-12))) / np.sqrt(d)

# def _sample_theta_grid(d, q, n_random=8):
#    
#     thetas = []
#     # ± basis
#     for i in range(d):
#         e = np.zeros(d); e[i] = 1.0/np.sqrt(d)
#         thetas.append(_quantize_theta(e, q))
#         thetas.append(_quantize_theta(-e, q))
#     # random directions
#     R = max(n_random * d, 2*d)
#     Z = np.random.randn(R, d)
#     Z /= np.linalg.norm(Z, axis=1, keepdims=True) + 1e-12
#     Z = Z / np.sqrt(d)
#     for z in Z:
#         thetas.append(_quantize_theta(z, q))
#     # dedup
#     uniq = []
#     seen = set()
#     for th in thetas:
#         key = tuple(np.round(th, 12))
#         if key not in seen:
#             uniq.append(th)
#             seen.add(key)
#     return np.array(uniq)

# def _ball_points_B_inv_sqrtT(d, T):
#     
#     pts = []
#     r = 1.0 / np.sqrt(T)
#     for i in range(d):
#         e = np.zeros(d); e[i] = r
#         pts.append(e.copy())
#         e[i] = -r
#         pts.append(e.copy())
#     return np.array(pts)

# def _greedy_spanner(phi_actions, d):
#     
#     K = len(phi_actions)
#     chosen = []
#     G = np.zeros((0, d))
#     for _ in range(min(d, K)):
#         best_i, best_val = None, -np.inf
#         for i in range(K):
#             if i in chosen: 
#                 continue
#             C = phi_actions[i:i+1] if G.size == 0 else np.vstack([G, phi_actions[i]])
#             # Gram determinant ~ volume^2
#             val = np.linalg.slogdet(C @ C.T + 1e-12*np.eye(C.shape[0]))[1]
#             if val > best_val:
#                 best_val = val
#                 best_i = i
#         if best_i is None: 
#             break
#         chosen.append(best_i)
#         G = phi_actions[chosen]
#     return chosen

# def _softbatch_endpoints_TM(T, d):
#     # M = ceil(log log T) + 1 
#     M = int(np.ceil(np.log2(np.log2(max(T, 3))))) + 1
#     Tm = []
#     last = 0
#     for m in range(1, M):  # m = 1..M-1
#         Em = max(int(np.floor(T ** (1.0 - 2.0**(-m)))), 2 * d)  # endpoint
#         Em = min(max(Em, last + 1), T - 1) 
#         Tm.append(Em)
#         last = Em
#     Tm.append(T)  # T_M = T
#     return Tm 

# def _build_tau_from_TM_endpoints(T, TM_endpoints):
#     
#     lengths = []
#     prev = 0
#     for Em in TM_endpoints:
#         lengths.append(Em - prev)
#         prev = Em
#     # Algorithm 4: τ_m = floor(L_m/2), (τ_m, L_m - τ_m)
#     taus = []
#     for L in lengths:
#         odd  = L     # τ_m
#         even = L            
#         taus.extend([odd, even]) 
#     
#     total = sum(taus)
#     if total < T:
#         taus[-1] += (T - total)
#     elif total > T:
#         over = total - T
#         for i in range(len(taus)-1, -1, -1):
#             dec = min(taus[i], over)
#             taus[i] -= dec
#             over -= dec
#             if over == 0:
#                 break
#     return taus, np.cumsum(taus)

# # =========================================================
# # SoftBatch
# # =========================================================

# class SoftBatch:
#     def __init__(self, T, K, d, delta=None, q=None, base_schedule=None, theta_grid_size_factor=8):
#         self.T, self.K, self.d = T, K, d
#         self.delta = (1.0 / T) if (delta is None) else float(delta)
#         
#         self.q = (1.0 / (8.0 * np.sqrt(d))) if (q is None) else float(q)

#         
#         if base_schedule is None:
#             TM_endpoints = _softbatch_endpoints_TM(T,d) 
#         else:
#             TM_endpoints = list(base_schedule)
#             if TM_endpoints[-1] != T:
#                 TM_endpoints[-1] = T
        
#         self.taus, self.tau_cum = _build_tau_from_TM_endpoints(T, TM_endpoints)
#         self.M = len(TM_endpoints)                 
#         self.M_iter = len(self.taus)           

#         
#         self.C_L = np.exp(8.0) * d
#         self.gamma = 10.0 * np.sqrt(self.C_L * d * (np.log(8.0 * self.M / self.delta) + 57.0 * d * (np.log(6.0 * T) ** 2)))
#         self.tau_minus1 = 1
#         self.tau0 = 1

#         
#         self.Theta_grid = _sample_theta_grid(d, self.q, n_random=theta_grid_size_factor)
#         # g^{(1)}(θ) = 0
#         self.g_table = { tuple(np.round(th, 12)) : np.zeros(d) for th in self.Theta_grid }

#         # X₁, X'₁
#         self.X_curr = np.stack(list(self.g_table.values()), axis=0)  
#         self.X_curr_prime = np.vstack([ self.X_curr, _ball_points_B_inv_sqrtT(d, T) ])

#         
#         self.Lambdas, self.Thetas = [], []           # (V_m, θ_m)
#         self.avg_survivors = []                      
#         self.time_used = 0

#         
#         self.theta_m = np.zeros(d)
#         if len(self.X_curr) == 0:
#             self.a_star = np.zeros(d)
#         else:
#             self.a_star = self.X_curr[np.random.randint(len(self.X_curr))]

#     
#     def _lws_approx(self, X_prime, eta_m, a_star, theta_m):
#         
#         Kp, d = X_prime.shape
#         if Kp == 0:
#             # fallback
#             A = [np.zeros(d) for _ in range(d)]
#             Th = [np.eye(d)[i] / np.sqrt(d) for i in range(d)]
#             return A, Th

#         
#         Delta = X_prime @ theta_m - float(a_star @ theta_m)
#         Delta = (float(a_star @ theta_m) - X_prime @ theta_m)  # Δ = <a* - a, θ_m>
#         Delta = np.maximum(Delta, 0.0)
#         denom = 1.0 + eta_m * Delta
#         Phi = X_prime / denom[:, None]

#         
#         idxs = _greedy_spanner(Phi, d=min(d, len(Phi)))
#         if len(idxs) < d:
#             
#             all_idx = list(range(len(Phi)))
#             for j in all_idx:
#                 if len(idxs) >= d: 
#                     break
#                 if j not in idxs:
#                     idxs.append(j)

#         A_sel = [ X_prime[i].copy() for i in idxs[:d] ]
#         Phi_sel = [ Phi[i].copy() for i in idxs[:d] ]

#         
#         Theta_sel = []
#         for v in Phi_sel:
#             nv = np.linalg.norm(v) + 1e-12
#             Theta_sel.append( v / nv )

#         # a_i = g^{(m)}(θ^{(i)})
#         A_list = []
#         for th in Theta_sel:
#             key = tuple(np.round(_quantize_theta(th, self.q), 12))
#             a_map = self.g_table.get(key, None)
#             if a_map is None or not np.any(a_map):
#                 
#                 a_map = X_prime[np.argmax(X_prime @ th)]
#             A_list.append(a_map.copy())

#         return A_list, Theta_sel

#     
#     def _run_one_batch(self, m, tau_m, contexts_log, get_contexts, play, tracker):
#         
#         d = self.d
#         # η_m = √(τ_{m-2}) / (8γ)  (τ_{-1}=τ_0=1)
#         tau_prev2 = self.tau_minus1 if (m <= 2) else self.taus[m-3]
#         eta_m = np.sqrt(max(tau_prev2, 1)) / (8.0 * self.gamma)

#         
#         A_list, Theta_list = self._lws_approx(self.X_curr_prime, eta_m, self.a_star, self.theta_m)

#         # (line 8) π(i)=1/d, a0 = a*_m = g^{(m)}([θ_m]_q)
#         pi = np.ones(len(A_list)) / max(1, len(A_list))
#         theta0 = _quantize_theta(self.theta_m, self.q)
#         a0 = self.g_table.get(tuple(np.round(theta0, 12)), None)
#         if a0 is None or not np.any(a0):
#             a0 = self.X_curr_prime[np.argmax(self.X_curr_prime @ theta0)]
#         a_star_m = a0.copy()

#         
#         pulls_total = 0
#         Y_list, R_list, H_m_times = [], [], []
#         i = 0
#         while i < len(A_list) and pulls_total < tau_m and self.time_used < self.T:
#             ai = A_list[i]
#             thetai = Theta_list[i]
#             # Δ_m(ai) = <a* - ai, θ_m>
#             Delta_ai = float((a_star_m - ai) @ self.theta_m)
#             Delta_ai = max(0.0, Delta_ai)
#             # n_m(i) = floor( π(i) τ_m / [ 4 (1 + √(τ_{m-1}) Δ / (8γ) )^2 ] )
#             tau_prev1 = self.tau0 if (m == 1) else (self.taus[m-2])
#             denom = (1.0 + (np.sqrt(max(tau_prev1,1)) * Delta_ai) / (8.0 * self.gamma)) ** 2
#             n_i = int(np.floor( (pi[i] * tau_m) / max(4.0 * denom, 1e-12) ))
#             n_i = max(0, n_i)

#             
#             for _ in range(n_i):
#                 if pulls_total >= tau_m or self.time_used >= self.T:
#                     break
#                 self.time_used += 1
#                 t = self.time_used
#                 A_t = contexts_log.get(t)
#                 if A_t is None:
#                     A_t = get_contexts(t)
#                     contexts_log[t] = A_t
#                 a_play = _argmax_action(A_t, thetai)
#                 r = play(a_play, t, tracker)
#                 Y_list.append(a_play)
#                 R_list.append(r)
#                 H_m_times.append(t)
#                 pulls_total += 1
#             i += 1

#         
#         remain = max(0, min(tau_m, self.T - self.time_used) - pulls_total)
#         for _ in range(remain):
#             if self.time_used >= self.T:
#                 break
#             self.time_used += 1
#             t = self.time_used
#             A_t = contexts_log.get(t)
#             if A_t is None:
#                 A_t = get_contexts(t)
#                 contexts_log[t] = A_t
#             a_play = _argmax_action(A_t, theta0)
#             r = play(a_play, t, tracker)
#             Y_list.append(a_play)
#             R_list.append(r)
#             H_m_times.append(t)
#             pulls_total += 1

#         
#         if len(Y_list) > 0:
#             Y = np.stack(Y_list, axis=0)
#             R = np.array(R_list)
#             V_m = np.eye(d) + Y.T @ Y
#             xi = Y.T @ R
#             try:
#                 theta_next = _chol_solve(V_m, xi)
#             except np.linalg.LinAlgError:
#                 theta_next = np.linalg.pinv(V_m) @ xi
#         else:
#             V_m = np.eye(d)
#             theta_next = self.theta_m.copy()

#         self.Lambdas.append(V_m)
#         self.Thetas.append(theta_next)

#         
#         dots = self.X_curr_prime @ theta_next
#         a_star_next = self.X_curr_prime[int(np.argmax(dots))].copy()

#         
#         # g^{(m+1)}(θ) = (1/τ_m) Σ_{t∈H_m} O(A_t; θ)
#         
#         g_next = {}
#         for th in self.Theta_grid:
#             acc = np.zeros(d)
#             cnt = 0
#             for t in H_m_times:
#                 A_t = contexts_log[t]
#                 acc += _argmax_action(A_t, th)
#                 cnt += 1
#             g_next[tuple(np.round(th, 12))] = (acc / max(cnt, 1)) if cnt > 0 else self.g_table[tuple(np.round(th,12))]

#         
#         self.theta_m = theta_next
#         self.a_star = a_star_next
#         self.g_table = g_next
#         self.X_curr = np.stack(list(self.g_table.values()), axis=0) if len(self.g_table) > 0 else np.zeros((0, d))
#         self.X_curr_prime = np.vstack([ self.X_curr, _ball_points_B_inv_sqrtT(d, self.T) ])

#         return len(H_m_times)

#     def run(self, get_contexts, play, tracker):
#         contexts_log = {}   
#         prev_used = 0
#         avg_sizes = []
        
#         for m in range(1, self.M_iter + 1):
#             tau_m = self.taus[m-1]
#             used_before = self.time_used
#             pulls = self._run_one_batch(m, tau_m, contexts_log, get_contexts, play, tracker)
#             avg_sizes.append(float(self.K)) 
#             prev_used = self.time_used
#             if self.time_used >= self.T:
#                 break
        
#         theta0 = _quantize_theta(self.theta_m, self.q)
#         while self.time_used < self.T:
#             self.time_used += 1
#             t = self.time_used
#             A_t = contexts_log.get(t) or get_contexts(t)
#             a_play = _argmax_action(A_t, theta0)
#             _ = play(a_play, t, tracker)

#         self.avg_survivors = avg_sizes
#         return


# --------------------------- small utils ---------------------------

def _chol_solve(A, b):
    """Solve ``A x = b`` with a Cholesky factorisation (forward then back
    substitution); if that raises ``LinAlgError`` return ``pinv(A) @ b``."""
    try:
        lower = np.linalg.cholesky(A)
        forward = np.linalg.solve(lower, b)
        return np.linalg.solve(lower.T, forward)
    except np.linalg.LinAlgError:
        pass
    return np.linalg.pinv(A) @ b

def _argmax_action(A, theta):
    """Linear optimisation oracle: return the row of ``A`` with the largest
    inner product with ``theta`` (first row on ties)."""
    inner_products = A @ theta
    best_row = int(np.argmax(inner_products))
    return A[best_row]

def _quantize_theta(theta, q):
    """Snap ``theta`` down onto a grid of step ``q / sqrt(d)``.

    Computes ``q * floor(theta * sqrt(d) / max(q, 1e-12)) / sqrt(d)`` with
    ``d = len(theta)``; the ``max`` only guards the division.
    """
    root_d = np.sqrt(len(theta))
    safe_step = max(q, 1e-12)
    cells = np.floor(theta * root_d / safe_step)
    return (q * cells) / root_d

def _rand_unit_dirs(k, d):
    """Draw ``k`` Gaussian directions in R^d, rescaled to norm 1/sqrt(d).

    Uses the global NumPy RNG (``np.random.randn``).
    """
    raw = np.random.randn(k, d)
    lengths = np.linalg.norm(raw, axis=1, keepdims=True) + 1e-12
    unit = raw / lengths
    return unit / np.sqrt(d)  # ||z|| = 1/√d

def _ball_points_B_inv_sqrtT(d, T):
    """Return the 2d signed axis points ``+r e_i, -r e_i`` (in that order for
    i = 0..d-1) with radius ``r = 1 / sqrt(max(T, 1))``, stacked as rows."""
    radius = 1.0 / np.sqrt(max(T, 1))

    def axis_point(axis, value):
        point = np.zeros(d)
        point[axis] = value
        return point

    rows = [
        axis_point(axis, value)
        for axis in range(d)
        for value in (radius, -radius)
    ]
    return np.array(rows)

def _greedy_spanner_rows(M, d):
    """
    Greedy volume-maximising row selection on M (K x d).

    Repeatedly adds the not-yet-chosen row that maximises the log-determinant
    of the (slightly regularised) Gram matrix of the selection; ties go to
    the lowest index.  Returns at most min(d, K) row indices, in the order
    they were picked.
    """
    n_rows = M.shape[0]
    chosen = []
    G = np.zeros((0, d))
    for _ in range(min(d, n_rows)):
        winner, winner_val = None, -np.inf
        for i in range(n_rows):
            if i in chosen:
                continue
            if G.size == 0:
                candidate = M[i:i + 1]
            else:
                candidate = np.vstack([G, M[i]])
            gram = candidate @ candidate.T + 1e-12 * np.eye(candidate.shape[0])
            _, val = np.linalg.slogdet(gram)
            if val > winner_val:
                winner, winner_val = i, val
        if winner is None:
            break
        chosen.append(winner)
        G = M[chosen]
    return chosen

def _tm_lengths(T, d):
    """Return the integer array ``[T_1, ..., T_M]`` of phase lengths.

    ``M = ceil(log2(log2(max(T, 3)))) + 1``; for ``ell < M`` the length is
    ``max(floor(T ** (1 - 2 ** -ell)), 2 d)`` and the last entry is ``T``.
    """
    T, d = int(T), int(d)
    n_levels = int(np.ceil(np.log2(np.log2(max(T, 3))))) + 1
    min_len = 2 * d
    lens = [
        max(int(np.floor(T ** (1.0 - 2.0 ** (-ell)))), min_len)
        for ell in range(1, n_levels)
    ]
    lens.append(T)  # T_M = T
    return np.array(lens, dtype=int)  # [T1, T2, ..., T_M]

def _repeat_Tm_as_2M(Tm, T):
    """Duplicate every phase length and clip the sequence to sum to ``T``.

    ``[T1, T2, ...]`` becomes ``[T1, T1, T2, T2, ...]``.  If the total
    reaches ``T`` the sequence is cut at the first phase whose cumulative
    sum is >= ``T`` and that phase is shortened by the overshoot; otherwise
    it is returned unchanged.

    Returns:
        ``(lengths, cumulative_end_times)`` as integer arrays.
    """
    doubled = np.repeat(Tm, 2).astype(int)       # [T1, T1, T2, T2, ...]
    running = np.cumsum(doubled)

    if running[-1] >= T:
        last = int(np.searchsorted(running, T, side='left'))
        doubled = doubled[:last + 1]
        overshoot = running[last] - T
        if overshoot > 0:
            doubled[-1] -= overshoot
    ends = np.cumsum(doubled)
    return doubled.astype(int), ends.astype(int)

# --------------------------- Efficient Algorithm 4 ---------------------------

class SoftBatch:
    """Batched linear bandit learner driven by a greedy (argmax) action oracle.

    The horizon is cut into phases whose lengths come from ``_tm_lengths`` and
    ``_repeat_Tm_as_2M``.  In each phase a small set of directions is chosen
    with a volume-greedy spanner, every round plays the arm maximising the
    inner product with one of these directions, and the ridge statistics
    ``V`` / ``xi`` are updated.  After a phase, ``g_table`` maps each active
    (quantised) direction to the phase-average of the arms the oracle would
    have picked for it.
    """

    def __init__(self, T, d, delta=1e-3, q=None, gamma=None, theta_cap_factor=4, random_seed_dirs=1):
        """
        Args:
            T: horizon (cast to int).
            d: feature dimension (cast to int).
            delta: confidence parameter entering the default ``gamma``.
            q: quantisation step; defaults to ``1 / (8 sqrt(d))``.
            gamma: exploration constant; derived from ``T``, ``d`` and
                ``delta`` when ``None``.
            theta_cap_factor: the active direction set is pruned once it
                exceeds ``max(theta_cap_factor * d, d)`` entries.
            random_seed_dirs: number of random directions *per dimension*
                added to the initial direction set (a count, not an RNG
                seed); values <= 0 add none.
        """
        self.T, self.d = int(T), int(d)
        # Use the integer dimension from here on so that a float ``d`` does
        # not break ``range(d)`` / ``np.zeros(d)`` below.
        d = self.d
        self.delta = float(delta)
        self.q = (1.0 / (8.0 * np.sqrt(d))) if (q is None) else float(q)

        self.Tm = _tm_lengths(T, d)                         # [T1,...,T_M]
        self.batch_lengths_2M, self.ends_2M = _repeat_Tm_as_2M(self.Tm, self.T)
        self.M = len(self.Tm)
        self.M2 = len(self.batch_lengths_2M)

        if gamma is None:
            C_L = np.exp(8.0) * d
            self.gamma = 10.0 * np.sqrt(
                C_L * d * (np.log(8.0 * self.M / max(self.delta, 1e-12))
                + 57.0 * d * (np.log(6.0 * max(T, 3)) ** 2))
            )
        else:
            self.gamma = float(gamma)

        # Initial directions: +/- e_i / sqrt(d) plus optional random ones,
        # quantised and de-duplicated.
        thetas = []
        for i in range(d):
            e = np.zeros(d); e[i] = 1.0/np.sqrt(d)
            thetas.append(e.copy()); thetas.append(-e.copy())
        if random_seed_dirs > 0:
            thetas.extend(list(_rand_unit_dirs(random_seed_dirs * d, d)))
        active, seen = [], set()
        for th in thetas:
            key = tuple(np.round(_quantize_theta(th, self.q), 12))
            if key not in seen:
                active.append(np.array(key)); seen.add(key)
        self.theta_active = np.array(active)
        self.theta_cap = max(theta_cap_factor * d, d)

        # g^{(1)}(θ)=0 on ACTIVE Θ
        self.g_table = { tuple(np.round(th, 12)) : np.zeros(d) for th in self.theta_active }

        self.V = np.eye(d)          # I + sum of played x x^T
        self.xi = np.zeros(d)       # sum of played x * reward
        self.theta_m = np.zeros(d)  # current estimate V^{-1} xi

        self._refresh_sets()

        self.H_lengths = []         # length of every completed phase

        self.time_used = 0          # number of rounds played so far

    # ---------------- helper ----------------
    def _refresh_sets(self):
        """Rebuild ``X_m`` (values of ``g_table``) and ``X_m_prime`` (``X_m``
        plus the +/- axis points of radius 1/sqrt(T))."""
        self.X_m = np.stack(list(self.g_table.values()), axis=0) if len(self.g_table) > 0 else np.zeros((0, self.d))
        self.X_m_prime = np.vstack([ self.X_m, _ball_points_B_inv_sqrtT(self.d, self.T) ])

    def _maybe_prune_theta_active(self):
        """Shrink the active direction set to at most ``theta_cap`` entries.

        Keeps the rows picked by the greedy spanner and fills up with the
        lowest-index remaining rows; ``g_table`` is restricted accordingly.
        """
        if len(self.theta_active) <= self.theta_cap:
            return
        idxs = _greedy_spanner_rows(self.theta_active, self.d)
        keep = set(idxs)
        if len(keep) < self.theta_cap:
            extra = [i for i in range(len(self.theta_active)) if i not in keep]
            extra = extra[:(self.theta_cap - len(keep))]
            keep.update(extra)
        self.theta_active = self.theta_active[sorted(keep)]
        new_g = {}
        for th in self.theta_active:
            key = tuple(np.round(th, 12))
            new_g[key] = self.g_table.get(key, np.zeros(self.d))
        self.g_table = new_g

    def _eta_for_ell(self, ell_1based):
        """Return ``sqrt(base) / (8 gamma)`` where ``base`` is the length of
        the phase two steps back (1 for the first two phases or when that
        phase has not been recorded)."""
        # η_m = √(τ_{m-2}) / (8γ),  τ_{-1}=τ_0=1
        if ell_1based <= 2:
            base = 1
        else:
            base = self.H_lengths[ell_1based - 3] if (ell_1based - 3) < len(self.H_lengths) else 1
        return np.sqrt(max(base, 1)) / (8.0 * self.gamma)

    def _play_round(self, theta, get_contexts, play, tracker, contexts_seen):
        """Play one round greedily w.r.t. ``theta``, update ``V``/``xi`` and
        record the contexts shown; returns the round index."""
        self.time_used += 1
        t = self.time_used
        A_t = get_contexts(t)
        contexts_seen[t] = A_t
        a_play = _argmax_action(A_t, theta)
        r = play(a_play, t, tracker)
        self.V += np.outer(a_play, a_play)
        self.xi += a_play * r
        return t

    # ---------------- one H-phase ----------------
    def _phase_H(self, ell_1based, length, get_contexts, play, tracker):
        """Run one phase of ``length`` rounds and return the rounds played.

        Args:
            ell_1based: 1-based level index used to pick the step size.
            length: number of rounds in this phase.
            get_contexts: callable ``t -> A_t`` (context matrix of round t).
            play: callable ``(a, t, tracker) -> reward``.
            tracker: opaque object forwarded to ``play``.
        """
        d = self.d
        eta_m = self._eta_for_ell(ell_1based)

        # a*_m = g^{(m)}([θ_m]_q)
        theta0 = _quantize_theta(self.theta_m, self.q)
        a_star_m = self.g_table.get(tuple(np.round(theta0,12)), np.zeros(d)).copy()

        # Φ_m on X'_m
        Xp = self.X_m_prime
        if Xp.shape[0] == 0:
            return []

        Delta = np.maximum( (a_star_m - Xp) @ self.theta_m, 0.0 )
        Phi = Xp / (1.0 + eta_m * Delta)[:, None]

        # volume-greedy spanner
        idxs = _greedy_spanner_rows(Phi, d=min(d, Phi.shape[0]))
        if len(idxs) < d:
            # Pad with the lowest unused indices so that d rows are used.
            for j in range(Phi.shape[0]):
                if len(idxs) >= d: break
                if j not in idxs: idxs.append(j)

        A_list = [Xp[i].copy() for i in idxs[:d]]
        Theta_list = []
        for v in Phi[idxs[:d]]:
            nv = np.linalg.norm(v) + 1e-12
            Theta_list.append(v / nv)
        Theta_list = np.array(Theta_list)

        # Register the (quantised) new directions as active.
        for th in Theta_list:
            key = tuple(np.round(_quantize_theta(th, self.q), 12))
            if key not in self.g_table:
                self.g_table[key] = np.zeros(d)
                self.theta_active = np.vstack([self.theta_active, np.array(key)])
        self._maybe_prune_theta_active()
        self._refresh_sets()

        pi = np.ones(len(Theta_list)) / max(1, len(Theta_list))
        tau_prev1 = self.H_lengths[-1] if len(self.H_lengths) > 0 else 1

        H_times, pulls = [], 0
        # Contexts actually shown in each played round.  They are stored at
        # play time so the g-table update below reuses them instead of calling
        # ``get_contexts`` a second time for every round.
        contexts_seen = {}
        for i, (ai, thetai) in enumerate(zip(A_list, Theta_list)):
            Delta_ai = max(0.0, float((a_star_m - ai) @ self.theta_m))
            denom = (1.0 + (np.sqrt(max(tau_prev1,1)) * Delta_ai) / (8.0 * self.gamma)) ** 2
            n_i = int(np.floor((pi[i] * length) / max(4.0 * denom, 1e-12)))
            n_i = max(0, n_i)
            for _ in range(n_i):
                if pulls >= length or self.time_used >= self.T: break
                H_times.append(self._play_round(thetai, get_contexts, play, tracker, contexts_seen))
                pulls += 1
            if pulls >= length or self.time_used >= self.T: break

        # Spend whatever is left of the phase on the current quantised estimate.
        while pulls < length and self.time_used < self.T:
            H_times.append(self._play_round(theta0, get_contexts, play, tracker, contexts_seen))
            pulls += 1

        self.theta_m = _chol_solve(self.V, self.xi)

        if length > 0 and len(H_times) > 0:
            g_next = {}
            for th_vec in self.theta_active:
                th = np.array(th_vec, dtype=float)
                acc = np.zeros(d)
                for t in H_times:
                    acc += _argmax_action(contexts_seen[t], th)
                g_next[tuple(np.round(th, 12))] = acc / float(length)
            self.g_table = g_next
            self._refresh_sets()

        self.H_lengths.append(int(length))
        return H_times

    # ---------------- run ----------------
    def run(self, get_contexts, play, tracker=None):
        """Play phases in order until the horizon ``T`` is exhausted."""
        for j, L in enumerate(self.batch_lengths_2M, start=1):
            if self.time_used >= self.T:
                break
            L_eff = min(L, self.T - self.time_used)
            if L_eff <= 0:
                break
            ell = (j + 1) // 2  # 1,1,2,2,3,3,... → 1,2,3,...
            # Both phases of a level run the same procedure.
            self._phase_H(ell, L_eff, get_contexts, play, tracker)
        return

    @property
    def tau_cum(self):
        """Cumulative phase end times (``ends_2M``)."""
        return self.ends_2M


###########################################################

# =========================================================
# RS-OFUL (Rarely Switching OFUL)
# =========================================================

class RSOFUL:
    """Rarely-switching OFUL learner for a linear bandit.

    The learner maintains the regularised Gram matrix ``V`` and the
    reward-weighted feature sum ``xi``.  The parameter it acts on,
    ``theta_tilde``, is only recomputed when ``log det V`` has grown by more
    than ``log(1 + C)`` since the previous switch.  Until the first switch the
    arm with the largest ``x^T V^{-1} x`` is played.
    """

    def __init__(self, T, K, d, delta=None, lam=None, alpha=None, C=1.0):
        """Set up the statistics for horizon ``T``, ``K`` arms and dimension ``d``.

        ``delta``, ``lam`` and ``alpha`` fall back to formulas in ``T``, ``K``
        and ``d`` when left as ``None``; ``C`` controls how much ``det V`` must
        grow before the acting parameter is refreshed.
        """
        self.T = int(T)
        self.K = int(K)
        self.d = int(d)

        if delta is None:
            self.delta = 1.0 / max(T, 2)
        else:
            self.delta = float(delta)

        if lam is None:
            self.lam = 32.0 * np.log(2.0 * d * max(T, 3) / self.delta)
        else:
            self.lam = float(lam)

        if alpha is None:
            self.alpha = 10.0 * np.sqrt(np.log(2.0 * d * K * max(T, 3) / self.delta))
        else:
            self.alpha = float(alpha)

        self.C = float(C)

        # Sufficient statistics of the ridge regression.
        self.V = self.lam * np.eye(d)      # lam * I + sum of x x^T
        self.xi = np.zeros(d)              # sum of x * r
        self.theta_hat = np.zeros(d)       # V^{-1} xi
        self.theta_tilde = None            # parameter acted on between switches

        # Determinant bookkeeping for the switching rule.
        self.logdet_V = _logdet(self.V)
        self.logdet_at_switch = self.logdet_V
        self.tau = 0                       # round of the most recent switch

        # History / diagnostics.
        self.switch_times = []
        self.theta_tilde_hist = []
        self.survivor_counts_per_batch = []  # never filled in by this class

    def _recompute_stats(self):
        """Refresh ``theta_hat`` and ``logdet_V`` from the current ``V`` and ``xi``."""
        self.theta_hat = _chol_solve(self.V, self.xi)
        self.logdet_V = _logdet(self.V)

    def _var_diag(self, A, V_inv=None):
        """Return ``(diag(A V^{-1} A^T), V^{-1})``, inverting ``self.V`` if no inverse is given."""
        inverse = _chol_inv(self.V) if V_inv is None else V_inv
        quad_forms = np.einsum('ij,jk,ik->i', A, inverse, A, optimize=True)
        return quad_forms, inverse

    def _maybe_switch(self, A_t, t):
        """Refresh ``theta_tilde`` at round ``t`` if the determinant criterion fires.

        Returns ``True`` when a switch happened and ``False`` otherwise.
        """
        self._recompute_stats()

        # No switch while log det V has not exceeded the last switch value by log(1 + C).
        if self.logdet_V <= self.logdet_at_switch + np.log1p(self.C):
            return False

        # Optimistic arm: argmax over the rows of A_t of <theta_hat, x> + alpha * ||x||_{V^{-1}}.
        var, V_inv = self._var_diag(A_t)
        sd = np.sqrt(np.maximum(var, 1e-15))   # floored, so the division below is safe
        ucb = A_t @ self.theta_hat + self.alpha * sd
        best = int(np.argmax(ucb))
        x_star = A_t[best]

        # theta_tilde = theta_hat + alpha * V^{-1} x* / ||x*||_{V^{-1}}
        width = float(sd[best])
        self.theta_tilde = self.theta_hat + (self.alpha / width) * (V_inv @ x_star)

        # Record the switch.
        self.tau = t
        self.logdet_at_switch = self.logdet_V
        self.switch_times.append(t)
        self.theta_tilde_hist.append(self.theta_tilde.copy())
        return True

    def get_update_times(self, as_array: bool = False, include_initial: bool = False):
        """Return a copy of the recorded switch rounds.

        With ``include_initial`` a leading ``1`` is added unless the list
        already starts with ``1``; with ``as_array`` the result is an integer
        NumPy array instead of a list.
        """
        times = list(self.switch_times)
        if include_initial and (not times or times[0] != 1):
            times.insert(0, 1)
        return np.array(times, dtype=int) if as_array else times

    def run(self, get_contexts, play, tracker):
        """Interact for rounds ``1..T``.

        Each round fetches the context matrix with ``get_contexts(t)``, checks
        the switching rule, plays one row through ``play(x, t, tracker)`` and
        folds the observed reward into ``V`` and ``xi``.
        """
        for t in range(1, self.T + 1):
            A_t = get_contexts(t)  # one row per arm

            self._maybe_switch(A_t, t)

            if self.theta_tilde is not None:
                scores = A_t @ self.theta_tilde
            else:
                # No parameter yet: play the arm with the largest x^T V^{-1} x.
                scores, _ = self._var_diag(A_t)

            arm = A_t[int(np.argmax(scores))]
            reward = play(arm, t, tracker)

            self.V += np.outer(arm, arm)
            self.xi += arm * reward

###########################################################

class Environment:
    """Linear reward model whose hidden parameter is drawn once at construction."""

    def __init__(self, d, sigma=1):
        """Draw ``theta`` with ``np.random.randn(d)`` and remember the noise scale."""
        self.sigma = sigma  # sub-Gaussian noise scale
        self.d = d
        self.theta = np.random.randn(d)

    def play(self, x):
        """Return ``x @ theta`` plus one fresh Gaussian noise draw of scale ``sigma``."""
        eps = np.random.normal(scale=self.sigma)
        mean_reward = x @ self.theta
        return mean_reward + eps
    
###########################################################

class RegretTracker:
    """Collects the per-round regret of one learner.

    Subclasses only override ``label``, the legend entry used by ``plot``.
    """

    # Legend label used when the cumulative regret curve is plotted.
    label = "BatchLearning"

    def __init__(self):
        self.regrets = []  # one entry per observed round, in play order

    def observe(self, chosen_arm, theta, context_set=None):
        """Record and return the regret of playing ``chosen_arm``.

        The regret is the best value of ``context_set @ theta`` minus
        ``chosen_arm @ theta``.  When no context set is supplied the round is
        recorded with regret ``0``.
        """
        if context_set is None:
            regret = 0
        else:
            # np.max reduces in native code; the builtin max() would iterate
            # the array element by element in Python.
            expected_rewards = context_set @ theta
            regret = np.max(expected_rewards) - chosen_arm @ theta
        self.regrets.append(regret)
        return regret

    def plot(self):
        """Plot the cumulative regret on the current Matplotlib axes."""
        plt.plot(np.cumsum(self.regrets), label=self.label)


class RegretTrackerArg_5(RegretTracker):
    """Regret tracker whose curve is labelled "BCLE"."""

    label = "BCLE"


class RegretTrackerBat_5(RegretTracker):
    """Regret tracker whose curve is labelled "BCLE-G"."""

    label = "BCLE-G"


class RegretTrackerDG(RegretTracker):
    """Regret tracker whose curve is labelled "BatchLinUCB_DG"."""

    label = "BatchLinUCB_DG"


class RegretTrackerSoft(RegretTracker):
    """Regret tracker whose curve is labelled "SoftBatch"."""

    label = "SoftBatch"


class RegretTrackerOFUL(RegretTracker):
    """Regret tracker whose curve is labelled "RS-OFUL"."""

    label = "RS-OFUL"

###########################################################

def simulate(T=500, K=10, d=5, alpha=1.0, seed=1):
    """Run all six learners once on one environment and one shared context stream.

    The global NumPy RNG is seeded with ``seed``.  Context sets are drawn
    lazily the first time a round is requested and cached, so every learner
    sees the same contexts.  ``alpha`` is only forwarded to ``BatchLearner``.

    Returns
    -------
    curves : dict
        Algorithm name -> cumulative regret array.
    schedules : dict
        Algorithm name -> integer array of batch / update times, ending in ``T``.
    runtimes : dict
        Algorithm name -> wall-clock seconds spent in the learner's ``run``.
    """
    np.random.seed(seed)

    env = Environment(d)

    # One regret tracker per learner.
    oful_tracker = RegretTrackerOFUL()
    dg_tracker = RegretTrackerDG()
    soft_tracker = RegretTrackerSoft()
    batch_tracker = RegretTracker()
    bcle_g_tracker = RegretTrackerBat_5()
    bcle_tracker = RegretTrackerArg_5()

    T_schedule = grid(T, K, d)

    # Learners are constructed in this fixed order because they share the global RNG.
    oful = RSOFUL(T, K, d, delta=1/T, C=1)
    dg = BatchLinUCB_DG_GOPT(T, K, d, delta=1/T, design_fn=near_g_optimal_design)
    soft = SoftBatch(T, d, delta=1/T, q=None)  # alternative call: SoftBatch(T, K, d, delta=1/T, q=None)
    batch = BatchLearner(T_schedule, T, d, alpha)
    bcle_g = BatchCLAE_G(T, K, d, 0.5)
    bcle = BatchCLAE(T, K, d, 0.5)

    # Round -> context matrix, filled on first request.
    contexts_log = {}

    def get_contexts(t):  # IID sampling
        if t in contexts_log:
            return contexts_log[t]

        fresh = np.random.rand(K, d)  # Uniform distribution
        # Alternative context distributions:
        # fresh = np.random.randn(K, d) # Normal distribution
        # fresh = np.random.laplace(0.0, 1/np.sqrt(2), size=(K, d)) # Laplace distribution
        # fresh = np.random.standard_t(8, size=(K, d)) # student-t distribution
        # fresh = np.random.standard_t(4, size=(K, d)) # student-t distribution
        # fresh = np.random.standard_t(2, size=(K, d)) # student-t distribution
        # fresh = np.random.standard_cauchy(size=(K, d)) # cauchy distribution
        # fresh = np.random.choice([-1,1], size=(K,d)) * ((1 + np.random.pareto(3.0, size=(K,d))) - 3.0/2.0) # symmetric pareto distribution
        # fresh = np.random.choice([-1,1], size=(K,d)) * (1 + np.random.pareto(1.5, size=(K,d))) # symmetric pareto distribution
        #
        # Optional rescaling by the largest row norm:
        # max_norm = np.max(np.linalg.norm(fresh, axis=1))
        # if max_norm == 0:
        #     max_norm = 1e-12
        # fresh = fresh / max_norm

        contexts_log[t] = fresh
        return fresh

    def play1(x):
        # BatchLearner passes no round index; it is taken from the number of
        # regrets recorded so far.
        round_no = len(batch_tracker.regrets) + 1
        reward = env.play(x)
        batch_tracker.observe(x, env.theta, contexts_log.get(round_no))
        return reward

    def play2(x, t, tracker):
        reward = env.play(x)
        tracker.observe(x, env.theta, contexts_log.get(t))
        return reward

    runtimes = {}

    def _timed(label, launch):
        # Store the wall-clock duration of launch() under label.
        started = time.perf_counter()
        launch()
        runtimes[label] = time.perf_counter() - started

    _timed("RS-OFUL", lambda: oful.run(get_contexts, play2, oful_tracker))
    _timed("BatchLinUCB-DG", lambda: dg.run(get_contexts, play2, dg_tracker))
    _timed("SoftBatch", lambda: soft.run(get_contexts, play2, soft_tracker))
    _timed("BatchLearning", lambda: batch.run(get_contexts, play1))
    _timed("BCLE-G", lambda: bcle_g.run(get_contexts, play2, bcle_g_tracker))
    _timed("BCLE", lambda: bcle.run(get_contexts, play2, bcle_tracker))

    # Debugging aid: print each learner's survivor_counts_per_batch here
    # (dg, batch, bcle_g, bcle), e.g.
    # for i, avg in enumerate(dg.survivor_counts_per_batch):
    #     print(f"[BatchLinUCB-DG]        Batch {i+1}: {avg:.2f}")

    curves = {
        "RS-OFUL": np.cumsum(oful_tracker.regrets),
        "BatchLinUCB-DG": np.cumsum(dg_tracker.regrets),
        "SoftBatch": np.cumsum(soft_tracker.regrets),
        "BatchLearning": np.cumsum(batch_tracker.regrets),
        "BCLE-G": np.cumsum(bcle_g_tracker.regrets),
        "BCLE": np.cumsum(bcle_tracker.regrets),
    }

    def _ensure_T(arr, T):
        # Integer copy of arr that is guaranteed to end with T.
        points = np.array(arr, dtype=int)
        if len(points) > 0 and points[-1] == T:
            return points
        return np.concatenate([points, [T]])

    schedules = {
        "RS-OFUL": _ensure_T(oful.get_update_times(as_array=True, include_initial=True), T),
        "BatchLinUCB-DG": _ensure_T(dg.schedule, T),
        "SoftBatch": _ensure_T(soft.tau_cum, T),
        "BatchLearning": _ensure_T(T_schedule, T),
        "BCLE-G": _ensure_T(bcle_g.schedule, T),
        "BCLE": _ensure_T(bcle.schedule, T),
    }

    return curves, schedules, runtimes

###########################################################

# run experiments
# (K,d) = (1000, 5), (5000, 10), (50, 20), (100, 30)
K = 1000
d = 5
T = 10000
alpha = np.sqrt(50 * np.log(K * d * (T ** 2)))

num_runs = 10
# Running sums of the regret curves and of their squares, keyed by algorithm name.
sum_curves, sqsum_curves = {}, {}
# Schedules of the most recent run; used by the batch-count panel.
last_schedules = None

# Plotting / reporting order and colour of each algorithm.
ORDER = ["RS-OFUL", "BatchLinUCB-DG", "SoftBatch", "BatchLearning", "BCLE-G", "BCLE"]
# ORDER = ["RS-OFUL", "BCLE-G", "BCLE"]
COLORS = {
    "RS-OFUL": "#D55E00", #"red",
    "BatchLinUCB-DG": "#E69F00", #"orange",
    "SoftBatch": "#009E73", #"green",
    "BatchLearning": "#0072B2", #"blue",
    "BCLE-G": "#CC79A7", #"purple",
    "BCLE": "#56B4E9", #"light blue",
}

# Running sums of the wall-clock runtimes and of their squares.
runtime_sum = {name: 0.0 for name in ORDER}
runtime_sqsum = {name: 0.0 for name in ORDER}

for r in range(num_runs):
    seed = 1 + r  # run r uses seed r + 1
    curves, schedules, runtimes = simulate(T, K, d, alpha, seed=seed)
    last_schedules = schedules

    for name, arr in curves.items():
        arr = np.array(arr, dtype=float)
        if name in sum_curves:
            sum_curves[name] += arr
            sqsum_curves[name] += arr ** 2
        else:
            # First run: the float copy made above seeds the accumulators.
            sum_curves[name] = arr
            sqsum_curves[name] = arr ** 2

    for name in ORDER:
        t = float(runtimes[name])
        runtime_sum[name] += t
        runtime_sqsum[name] += t**2

means = {name: sum_curves[name] / num_runs for name in sum_curves}
# E[x^2] - E[x]^2 can come out slightly negative through floating-point
# round-off; clamp at zero so the square root never yields NaN (the runtime
# standard deviation below is clamped the same way).
stds = {
    name: np.sqrt(np.maximum(sqsum_curves[name] / num_runs - means[name] ** 2, 0.0))
    for name in sum_curves
}

avg_runtime = {name: runtime_sum[name] / num_runs for name in ORDER}
std_runtime = {
    name: math.sqrt(max(runtime_sqsum[name] / num_runs - avg_runtime[name]**2, 0.0))
    for name in ORDER
}

print("\n=== Average wall-clock runtime over", num_runs, "run(s) ===")
for name in ORDER:
    print(f"{name:>18}: {avg_runtime[name]:8.3f} s  ± {std_runtime[name]:.3f} s")

# Round indices 1..T, shared by the regret panels.
x = np.arange(1, T + 1)

# Layout: a thin row on top that only hosts the shared legend, and a row of
# three panels below it.
# fig, axes = plt.subplots(1, 3, figsize=(18, 5), constrained_layout=False)
fig = plt.figure(figsize=(18, 3))
gs  = fig.add_gridspec(
    nrows=2, ncols=3,
    height_ratios=[0.08, 0.90],
    left=0.06, right=0.995, bottom=0.12, top=0.98, wspace=0.28, hspace = 0.1
)
axes = [fig.add_subplot(gs[1, i]) for i in range(3)]
legend_ax = fig.add_subplot(gs[0, :])
legend_ax.axis("off")

# NOTE(review): this update runs after the figure and axes already exist;
# confirm that is intended, since rcParams are generally consulted when
# artists are created.
plt.rcParams.update({
    "font.size": 13,
    "axes.labelsize": 16,
    # "axes.titlesize": 14,
    "xtick.labelsize": 16,
    "ytick.labelsize": 16,
    "legend.fontsize": 20
})

# Left panel: mean cumulative regret with a +/- one standard deviation band.
ax1 = axes[0]
line_handles = []
for name in ORDER:
    m, s = means[name], stds[name]
    ln, = ax1.plot(x, m, label=name, color=COLORS[name], linewidth=1.8)
    ax1.fill_between(x, m - s, m + s, alpha=0.12, color=COLORS[name])
    line_handles.append(ln)

# ax1.set_title("Cumulative Regret (all time)")
ax1.set_xlabel("Time step", fontsize=20)
ax1.set_ylabel("Regret", fontsize=20)
ax1.grid(True, alpha=0.3)
ax1.tick_params(axis='x', labelsize=15)
ax1.tick_params(axis='y', labelsize=15)
# Follow the configured horizon instead of a hard-coded 10000 so the panel
# stays correct when T is changed (the other panels already use T).
ax1.set_xlim(0, T)
# Fixed upper bound chosen by hand for the current experiment settings.
ax1.set_ylim(0, 20000)

# Shared legend, drawn in the dedicated axes above the panels.
legend_ax.legend(
    handles=line_handles,
    labels=[ln.get_label() for ln in line_handles],
    loc="center",
    ncol=len(ORDER),
    frameon=False,
)

def draw_ybreak_double(ax, y1=0.94, y2=0.975, amp=0.012, waves=26, lw=1.8, color='0.35', alpha=0.6):
    """Draw a double wavy "axis break" marker near the top of ``ax``.

    All positions are in axes coordinates.  Everything from ``y1`` up to the
    top of the axes is covered with the axes' face colour, two sine-shaped
    lines are drawn around ``y1`` and ``y2`` (amplitude ``amp``, ``waves``
    periods across the width), and the top spine is hidden.
    """
    xs = np.linspace(0.0, 1.0, 800)
    ripple = amp * np.sin(2 * np.pi * waves * xs)
    background = ax.get_facecolor()

    # Cover the band between the two waves, then the strip above the upper wave.
    for bottom, height in ((y1, y2 - y1), (y2, 1 - y2)):
        cover = plt.Rectangle(
            (0, bottom), 1, height,
            transform=ax.transAxes, facecolor=background, edgecolor='none',
            zorder=19, clip_on=False,
        )
        ax.add_patch(cover)

    # The wave lines sit above the covering rectangles (higher zorder).
    for centre in (y1, y2):
        ax.plot(xs, centre - ripple, transform=ax.transAxes, color=color, alpha=alpha,
                lw=lw, clip_on=False, zorder=21)

    ax.spines['top'].set_visible(False)

# Middle panel: the same curves, zoomed to the range of the low-regret methods.
ax2 = axes[1]

# The visible range ends 15% above the largest mean regret of these methods.
low_names = ["BCLE-G", "BCLE"]
y_top_data = max(means[n].max() for n in low_names)
top_margin = 0.15 * y_top_data
y_top_axis = y_top_data + top_margin

for name in ORDER:
    mean_curve = means[name]
    band = stds[name]
    shade = COLORS[name]
    ax2.plot(x, mean_curve, color=shade, linewidth=1.8)
    ax2.fill_between(x, mean_curve - band, mean_curve + band, alpha=0.12, color=shade)

ax2.set_xlim(1, T)
ax2.set_ylim(0, y_top_axis)
# ax2.set_title("Zoomed Regret (low range)")
ax2.set_xlabel("Time step", fontsize=20)
ax2.set_ylabel("Regret (zoomed)", fontsize=20)
ax2.grid(True, alpha=0.3)
for which_axis in ('x', 'y'):
    ax2.tick_params(axis=which_axis, labelsize=15)

# Mark the top of the panel as cut off.
# Alternative styling: draw_ybreak_double(ax2, y1=0.962, y2=0.985, amp=0.015, waves=22, lw=2.2, color='k')
draw_ybreak_double(ax2, y1=0.93, y2=0.975, amp=0.013, waves=24, lw=1, color='k', alpha=1.0)

def plot_batch_count_from_starts(ax, endpoints, T, color, label, where='post'):
    """Draw the running batch index over ``[0, T]`` as a step curve on ``ax``.

    ``endpoints`` are converted to integers, de-duplicated and sorted; ``T`` is
    appended when it is not already the last one.  The curve starts at 1 and
    rises by one at every endpoint except the last.  With no endpoints at all
    a flat line at 1 is drawn.
    """
    ends = np.array(sorted({int(e) for e in endpoints}), dtype=int)
    if ends.size == 0:
        ax.step([0, T], [1, 1], where=where, color=color, linewidth=2.0, label=label)
        return

    if ends[-1] != T:
        ends = np.append(ends, T)

    n_batches = len(ends)
    # Batch i starts where batch i - 1 ended; the final point closes the last step at T.
    xs = np.concatenate([[0], ends[:-1], [T]])
    ys = np.concatenate([np.arange(1, n_batches + 1), [n_batches]])
    ax.step(xs, ys, where=where, linewidth=2.0, color=color, label=label)

# Right panel: running batch index of every algorithm, using the schedules of the last run.
ax3 = axes[2]

for name in ORDER:
    plot_batch_count_from_starts(ax3, last_schedules[name], T, color=COLORS[name], label=name)

ax3.set_xlim(0, T)
ax3.set_ylim(0, 50)
# Optional: ax3.set_yscale("log")
ax3.set_xlabel("Time step", fontsize=20)
ax3.set_ylabel("Batch", fontsize=20)
ax3.grid(True, which='both', alpha=0.3)
for which_axis in ('x', 'y'):
    ax3.tick_params(axis=which_axis, labelsize=15)

# Write the figure to disk, then display it.
plt.savefig("bandits.pdf", bbox_inches="tight")
plt.show()

