import math
import numpy as np
import matplotlib.pyplot as plt
import time

###########################################################

def near_g_optimal_design(X, max_iters=1e2, tol=1e-5):
    """Compute near G-optimal design weights over the rows of ``X`` by Frank–Wolfe.

    The iteration works in the r-dimensional row space of ``X`` (r = numerical
    rank). It repeatedly moves mass toward the row with the largest leverage
    score ``v_i = x_i^T Sigma(phi)^{-1} x_i``, where
    ``Sigma(phi) = sum_i phi_i x_i x_i^T``. The step size is
    ``gamma = (v* - r) / (r (v* - 1))``, clipped to [0, 1].

    Args:
        X: array-like of shape (K, d). A 1-D input of shape (d,) is treated as a
            single row.
        max_iters: maximum number of Frank–Wolfe steps. It is cast to int, and at
            least one step always runs.
        tol: stop when the duality gap ``max_i v_i - r`` or the L1 size of the
            weight update is <= tol.

    Returns:
        np.ndarray of shape (K,): non-negative weights summing to 1. If every
        singular value of X is <= 1e-12, the uniform design is returned.

    Raises:
        ValueError: if X is not 2-D after reshaping, or has zero rows/columns.
        np.linalg.LinAlgError: if Sigma cannot be Cholesky-factored even after
            the escalating jitter retries in ``_cholesky_with_jitter``.
    """
    X = np.asarray(X, dtype=float)

    if X.ndim == 1:
        X = X.reshape(1, -1)             # (d,) -> (1,d)
    if X.ndim != 2 or X.shape[0] == 0 or X.shape[1] == 0:
        raise ValueError(f"design expects 2D (K,d), got shape {X.shape}")

    K, d = X.shape

    # --- rank reduction (work in r-dim subspace; gives same weights) ---
    s_threshold = 1e-12
    _, s, Vt = np.linalg.svd(X, full_matrices=False)
    r = int((s > s_threshold).sum())
    if r == 0:
        # all-zero (numerically) inputs; return uniform
        return np.ones(K) / K
    Vr = Vt[:r, :].T           # (d, r)
    Xr = X @ Vr                # (K, r) projected contexts

    # --- init uniform weights on the simplex ---
    phi = np.ones(K) / K

    # small Tikhonov regularization for numerical stability in ill-conditioned cases
    reg = 1e-12

    T = int(max_iters)

    for _ in range(max(1, T)):
        # Sigma = sum_i phi_i x_i x_i^T (in r-dim), built with a single GEMM
        Sigma = Xr.T @ (phi[:, None] * Xr)
        if reg > 0:
            Sigma = Sigma + reg * np.eye(r)

        # Cholesky instead of an explicit inverse. On failure, retry with
        # diagonal jitter that scales with Sigma and grows 10x per retry.
        L = _cholesky_with_jitter(Sigma, base_jitter=max(1e-12, reg * 10))

        # Scores v_i = x_i^T Sigma^{-1} x_i via two solves with the Cholesky factor
        Y = np.linalg.solve(L, Xr.T)       # (r, K)
        Y = np.linalg.solve(L.T, Y)        # (r, K) = Sigma^{-1} Xr^T
        scores = np.sum(Xr * Y.T, axis=1)  # diag(Xr @ Sigma^{-1} @ Xr^T), shape (K,)

        i_star = int(np.argmax(scores))
        v_star = float(scores[i_star])

        # Frank–Wolfe duality gap g = max_i v_i - r  (r is the effective dimension)
        gap = v_star - r

        # Step size gamma = (v* - r) / (r (v* - 1)); clip guards against roundoff
        denom = max(1e-18, v_star - 1.0)
        gamma = (v_star - r) / (r * denom)
        gamma = float(np.clip(gamma, 0.0, 1.0))

        # Move toward the vertex e_{i*}
        new_phi = (1.0 - gamma) * phi
        new_phi[i_star] += gamma

        # Stop when the FW gap is small OR the step is tiny in L1
        if gap <= tol or np.linalg.norm(new_phi - phi, 1) <= tol:
            phi = new_phi
            break

        phi = new_phi

    return phi


def _cholesky_with_jitter(Sigma, base_jitter, max_tries=8):
    """Return the lower Cholesky factor of Sigma, adding growing diagonal jitter on failure.

    The first attempt uses Sigma unchanged. Each retry adds
    ``base_jitter * max(1, trace(Sigma)/n) * 10**k`` (k = 0, 1, ...) to the
    diagonal of the previous attempt, so the jitter is relative to the scale of
    Sigma. After ``max_tries`` failed attempts the last LinAlgError propagates.
    """
    n = Sigma.shape[0]
    try:
        return np.linalg.cholesky(Sigma)
    except np.linalg.LinAlgError:
        if max_tries <= 1:
            raise
    scale = max(1.0, float(np.trace(Sigma)) / max(n, 1))
    jitter = base_jitter * scale
    eye = np.eye(n)
    for attempt in range(1, max_tries):
        Sigma = Sigma + jitter * eye
        try:
            return np.linalg.cholesky(Sigma)
        except np.linalg.LinAlgError:
            if attempt == max_tries - 1:
                raise
            jitter *= 10.0


def _logdet(A):
    """Return log|det(A)| when det(A) > 0; return -inf when the sign is <= 0."""
    sign, log_abs = np.linalg.slogdet(A)
    if sign <= 0:
        return -np.inf
    return float(log_abs)

def core_identification_gopt(S_sets, lam, gamma, d, gopt_sampler, c=6, max_pass=10):
    """Iteratively drop context sets whose worst-case leverage is too large.

    Each pass builds ``M = lam*I + (1/max(gamma, 1)) * sum_X E_{pi^G}[x x^T]``
    over the currently kept sets, using ``gopt_sampler.expected_xxT``. A set X
    survives if ``max_i x_i^T M^{-1} x_i <= 0.5 * d**c``. Passes stop early when
    nothing is dropped, when no sets remain, or after ``max_pass`` passes.

    Returns:
        list of the surviving context matrices (possibly empty).
    """
    core = list(S_sets)
    cutoff = 0.5 * (d ** c)
    cov_scale = 1.0 / max(gamma, 1.0)

    for _ in range(max_pass):
        if not core:
            break
        cov = np.zeros((d, d))
        for X in core:
            cov += gopt_sampler.expected_xxT(X)  # E_{π^G}[xx^T]
        M_inv = _chol_inv(lam * np.eye(d) + cov_scale * cov)

        survivors = [
            X for X in core
            if float(np.max(np.einsum('id,df,if->i', X, M_inv, X, optimize=True))) <= cutoff
        ]
        if len(survivors) == len(core):
            break
        core = survivors
    return core

def learn_distributional_gopt_modified(S_sets, lam, T_total, d, gopt_sampler, N=None):
    """Learn a mixture of matrices ``[(p_j, M_j)]`` for the softmax policy.

    Starting from ``U = lam*N*T_total*I + 0.5*N*sum_X E_{pi^G}[x x^T]``, the
    sets in ``S_sets`` are visited round-robin ``N`` times. At each step:
    1. Leverage scores ``x^T W^{-1} x`` are computed against the current
       reference matrix W.
    2. They are mapped to probabilities by ``_softmax`` with exponent
       ``log(max(|X|, 2))``.
    3. The resulting expected ``x x^T`` is added to U.

    Whenever ``logdet(U)`` exceeds ``logdet(W)`` by at least log 2, W is reset
    to U and a new epoch starts.

    Returns:
        list of ``(p_j, M_j)`` where ``p_j`` is the fraction of steps spent in
        epoch j and ``M_j = N * T_total * W_j^{-1}``. Epochs with no steps are
        dropped. If ``S_sets`` is empty, ``[(1.0, I_d)]`` is returned.
    """
    n_sets = len(S_sets)
    if n_sets == 0:
        return [(1.0, np.eye(d))]

    if N is None:
        N = max(1, int(np.ceil(2 * (d ** 2) * np.log(max(d, 2)))))

    # U0
    cov0 = np.zeros((d, d))
    for X in S_sets:
        cov0 += gopt_sampler.expected_xxT(X)
    U = lam * N * T_total * np.eye(d) + 0.5 * N * cov0

    W_n = U.copy()
    logdet_W = _logdet(W_n)
    W_inv = None  # cached inverse of W_n; only recomputed when W_n changes

    epoch_counts = [0]
    W_list = [W_n.copy()]

    alphas = [float(np.log(max(len(X), 2))) for X in S_sets]

    for t in range(N * n_sets):
        i = t % n_sets
        X = S_sets[i]
        if W_inv is None:
            W_inv = _chol_inv(W_n)
        scores = np.einsum('id,df,if->i', X, W_inv, X, optimize=True)
        p_soft = _softmax(scores, alphas[i])
        U = U + _expected_xxT(X, p_soft)
        epoch_counts[-1] += 1

        if _logdet(U) - logdet_W >= np.log(2.0):
            W_n = U.copy()
            logdet_W = _logdet(W_n)
            W_inv = None
            W_list.append(W_n.copy())
            epoch_counts.append(0)

    # Pair each epoch's step count with the W it used; drop empty epochs
    # (only the epoch opened by the very last step can be empty).
    epochs = [(cnt, W) for cnt, W in zip(epoch_counts, W_list) if cnt > 0]
    if not epochs:
        epochs = [(1, W_list[0])]
    counts = np.array([cnt for cnt, _ in epochs], dtype=float)
    pis = counts / counts.sum()

    # M_i = N T W_i^{-1}
    return [
        (float(p), (N * T_total) * _chol_inv(W))
        for p, (_, W) in zip(pis, epochs)
    ]

# # =========================================================
# # Algorithm 3: CoreLearning wrapper (π^G=g-optimal)
# # =========================================================

def core_learning_gopt(S_sets, lam, T_total, d, gopt_sampler, c=6, rng=None):
    """Algorithm 3 (CoreLearning) with pi^G = near G-optimal design.

    The function runs in three steps:
    1. Prune ``S_sets`` with ``core_identification_gopt``. If nothing
       survives, all sets are kept.
    2. Learn a matrix mixture with ``learn_distributional_gopt_modified``.
    3. Wrap the mixture in ``MixedSoftmaxPolicyGOpt``.

    The number of pruning passes is ``max(1, int(d * log(T_total)))``. At least
    one pass always runs, and T_total <= 1 no longer makes ``log`` return
    0 or -inf, which previously gave zero passes or an OverflowError in ``int``.
    """
    n_sets = len(S_sets)
    gamma = max(1, n_sets)
    max_pass = max(1, int(d * np.log(max(T_total, 1))))
    C = core_identification_gopt(S_sets, lam=lam, gamma=gamma, d=d,
                                 gopt_sampler=gopt_sampler, c=c, max_pass=max_pass)
    if len(C) == 0:
        C = S_sets
    mixture = learn_distributional_gopt_modified(C, lam=lam, T_total=T_total, d=d,
                                                 gopt_sampler=gopt_sampler, N=None)
    return MixedSoftmaxPolicyGOpt(mixture, gopt_sampler=gopt_sampler, alpha_temp=None, rng=rng)

###############################################################

# ---------------------------------------------------------
# Utilities
# ---------------------------------------------------------

def sigmoid(z: np.ndarray) -> np.ndarray:
    """Elementwise logistic function ``1 / (1 + exp(-z))``."""
    denominator = 1.0 + np.exp(-z)
    return 1.0 / denominator

def logistic_link(x_theta: np.ndarray) -> np.ndarray:
    """Bernoulli GLM mean function: mu(x^T theta) = sigmoid(x^T theta)."""
    return sigmoid(z=x_theta)

def logistic_deriv(x_theta: np.ndarray) -> np.ndarray:
    """Derivative of the logistic mean function: mu'(z) = mu(z) * (1 - mu(z))."""
    mu = sigmoid(x_theta)
    one_minus_mu = 1.0 - mu
    return mu * one_minus_mu

def batch_schedule_T(T: int) -> np.ndarray:
    """Return the strictly increasing batch end points for horizon ``T``.

    With ``L = max(1, log2(log2(max(T, 4))))`` and ``s = max(1, ceil(T^{1/3} / L))``,
    the first ends are s, 2s and 3s. After that, batch ``ell >= 4`` grows by
    ``max(1, ceil(T^{1 - 1/(3 * 2^{ell-4})} / L))``. The last end is always T.
    ``T <= 0`` yields an empty int array.
    """
    T = int(T)
    if T <= 0:
        return np.array([], dtype=int)

    scale = max(1.0, math.log2(math.log2(max(T, 4))))
    step = max(1, math.ceil((T ** (1.0 / 3.0)) / scale))

    ends = []
    for multiple in (1, 2, 3):
        candidate = multiple * step
        if candidate >= T:
            ends.append(T)
            return np.array(ends, dtype=int)
        ends.append(candidate)

    ell = 4
    while ends[-1] < T:
        exponent = 1.0 - 1.0 / (3.0 * (2 ** (ell - 4)))
        increment = max(1, math.ceil(T ** exponent / scale))
        candidate = ends[-1] + increment
        if candidate >= T:
            ends.append(T)
            break
        ends.append(candidate)
        ell += 1

    # enforce strict monotonicity and pin the final end to T
    for pos in range(1, len(ends)):
        ends[pos] = max(ends[pos], ends[pos - 1] + 1)
    ends[-1] = T

    return np.array(ends, dtype=int)

def chol_solve(A: np.ndarray, b: np.ndarray) -> np.ndarray:
    """Solve ``A x = b`` for symmetric positive-definite ``A`` using its Cholesky factor.

    Raises ``np.linalg.LinAlgError`` if ``A`` is not positive definite.
    """
    lower = np.linalg.cholesky(A)
    return np.linalg.solve(lower.T, np.linalg.solve(lower, b))

def chol_inv(A: np.ndarray) -> np.ndarray:
    """Invert a symmetric positive-definite matrix as ``L^{-T} L^{-1}`` with ``A = L L^T``.

    Raises ``np.linalg.LinAlgError`` if ``A`` is not positive definite.
    """
    lower_inv = np.linalg.inv(np.linalg.cholesky(A))
    return lower_inv.T @ lower_inv

def compute_beta(d: int, T: int) -> float:
    """Confidence multiplier ``50 * sqrt(d + log(max(T, 3)))``."""
    horizon = max(T, 3)
    return 50 * float(np.sqrt(d + np.log(horizon)))

def eliminate_arms_with_history(X: np.ndarray, V_list: list, theta_list: list, beta: float) -> np.ndarray:
    """Sequentially filter arms against each stored (V, theta) pair.

    For each pair, an arm survives if ``max_y y^T theta - x^T theta <= eps``, where
    ``eps = 2 * beta * sqrt(max(max_x x^T V^{-1} x, 1e-15))``. The width uses the
    largest variance among the current survivors, not a per-arm width. The
    inverse falls back to ``pinv`` when ``V`` is not positive definite. Filtering
    stops as soon as no arms remain.
    """
    survivors = X
    for V, theta in zip(V_list, theta_list):
        if survivors.size == 0:
            break
        try:
            V_inv = chol_inv(V)
        except np.linalg.LinAlgError:
            V_inv = np.linalg.pinv(V)
        variances = np.einsum("ij,jk,ik->i", survivors, V_inv, survivors, optimize=True)
        radius = 2 * beta * float(np.sqrt(max(variances.max(), 1e-15)))
        estimates = survivors @ theta
        survivors = survivors[(float(estimates.max()) - estimates) <= radius]
    return survivors

def fit_logistic_ridge(Y: np.ndarray, R: np.ndarray, lam: float, max_newton: int = 50, tol: float = 1e-6) -> np.ndarray:
    """Ridge-regularized logistic MLE by Newton's method, starting from theta = 0.

    Minimizes ``sum_i [log(1+exp(y_i^T theta)) - r_i y_i^T theta] + lam/2 ||theta||^2``.
    Iteration stops after ``max_newton`` steps or once
    ``||Δtheta|| <= tol * (1 + ||theta||)``.

    NOTE: this module defines ``fit_logistic_ridge`` again further down; that
    later definition replaces this one once the module has been imported.
    """
    _, d = Y.shape
    theta = np.zeros(d)
    ridge = lam * np.eye(d)
    for _ in range(max_newton):
        mu = logistic_link(Y @ theta)
        grad = Y.T @ (mu - R) + lam * theta
        hess = (Y.T * (mu * (1.0 - mu))) @ Y + ridge
        try:
            delta = chol_solve(hess, grad)
        except np.linalg.LinAlgError:
            delta = np.linalg.pinv(hess) @ grad
        theta_next = theta - delta
        converged = np.linalg.norm(theta_next - theta) <= tol * (1.0 + np.linalg.norm(theta))
        theta = theta_next
        if converged:
            break
    return theta

# ---------------------------------------------------------
# BGLE for GLM (Bernoulli/logistic)
# ---------------------------------------------------------

class BGLE_GLM:
    """Batched GLM-elimination bandit (BGLE) for Bernoulli/logistic rewards.

    Rounds 1..T are split into batches whose right end points come from
    ``batch_schedule_T(T)``.

    * Batch 1: each round plays the arm maximizing ``x^T H^{-1} x``. H starts at
      ``lam * I`` and accumulates unweighted ``x x^T``. A ridge logistic MLE is
      then fit on the batch.
    * Every later batch splits its rounds in half. The first half plays the arm
      maximizing ``x^T H^{-1} x`` among the surviving arms. The second half
      plays the arm maximizing ``x^T theta_prev``. H accumulates
      ``alpha * max(mu'(x^T theta_prev), 1e-6) * x x^T``, where
      ``alpha = exp(-2 R S)`` in batch 2 and
      ``alpha(x) = exp(-R * min(2S, beta * ||x||_{V_prev^{-1}}))`` afterwards.
    * From batch 3 on, arms are first filtered by ``eliminate_arms_with_history``
      using the (V, theta) of batches 2..l-1. If nothing survives, all arms are
      used.

    Per-batch results are stored in:
    * ``V_list``: design matrices;
    * ``theta_list``: estimates;
    * ``played_X`` / ``played_R``: played arms and rewards;
    * ``survivor_counts``: mean number of surviving arms (batches >= 2 only).
    """

    def __init__(self, T: int, K: int, d: int, lam: float = None, R: float = 1.0, S: float = 1.0):
        self.T, self.K, self.d = int(T), int(K), int(d)
        self.ends = batch_schedule_T(T)
        self.B = len(self.ends)
        self.lam = (d + math.log(max(T,3))) if lam is None else float(lam)
        self.R = float(R)
        self.S = float(S)

        self.V_list = []
        self.theta_list = []
        self.played_X = []
        self.played_R = []
        self.survivor_counts = []

    def _alpha_batch2(self) -> float:
        """Weight used in batch 2 (k=1): exp(-2 R S)."""
        return float(math.exp(-2.0 * self.R * self.S))

    @staticmethod
    def _safe_inv(V: np.ndarray) -> np.ndarray:
        """Cholesky-based inverse of V, falling back to the pseudo-inverse."""
        try:
            return chol_inv(V)
        except np.linalg.LinAlgError:
            return np.linalg.pinv(V)

    def _alpha_subseq(self, x: np.ndarray, V_prev: np.ndarray, beta_const: float,
                      Vinv: np.ndarray = None) -> float:
        """Weight for batches k >= 2: exp(-R * min(2S, ||x||_{V_prev^{-1}} * beta)).

        ``Vinv`` may carry a precomputed inverse of ``V_prev`` so that it is
        not re-inverted every round. Without it, the inverse is computed here.
        """
        if Vinv is None:
            Vinv = self._safe_inv(V_prev)
        q = float(np.sqrt(max(x @ Vinv @ x, 1e-15)))
        return float(math.exp(- self.R * min(2.0 * self.S, q * beta_const)))

    @staticmethod
    def _notify(tracker, x, X_t):
        """Forward (chosen arm, context set) to ``tracker.observe`` if it exists."""
        if tracker is not None and hasattr(tracker, "observe"):
            tracker.observe(x, X_t)

    def _play_and_store_batch(self, times, get_contexts, play, tracker, lam,
                              theta_prev, eliminate, weight):
        """Play one batch l >= 2 (exploration half, then greedy half) and record its fit.

        eliminate: callable X_t -> non-empty array of candidate arms.
        weight: callable x -> alpha multiplier for the x x^T update.
        """
        H = lam * np.eye(self.d)
        X_batch, R_batch, counts = [], [], []
        n_explore = len(times) // 2

        for pos, t in enumerate(times):
            X_t = get_contexts(t)
            X_surv = eliminate(X_t)
            counts.append(len(X_surv))

            if pos < n_explore:
                # exploration half: maximize ||x||_{H^{-1}} among survivors
                Hinv = chol_inv(H)
                idx = int(np.argmax(np.einsum("ij,jk,ik->i", X_surv, Hinv, X_surv, optimize=True)))
            else:
                # greedy half: argmax x^T theta_prev
                idx = int(np.argmax(X_surv @ theta_prev))
            x = X_surv[idx]
            r = play(x, t)
            self._notify(tracker, x, X_t)
            X_batch.append(x); R_batch.append(r)

            s = max(float(logistic_deriv(float(x @ theta_prev))), 1e-6)
            H += (weight(x) * s) * np.outer(x, x)

        if X_batch:
            Y = np.stack(X_batch, axis=0)
            R = np.array(R_batch, dtype=float)
            theta_hat = fit_logistic_ridge(Y, R, lam)
        else:
            Y = np.zeros((0, self.d)); R = np.zeros((0,))
            theta_hat = self.theta_list[-1].copy()

        self.V_list.append(H.copy())
        self.theta_list.append(theta_hat)
        self.played_X.append(Y); self.played_R.append(R)
        self.survivor_counts.append(float(np.mean(counts)) if counts else 0.0)

    def run(self, get_contexts, play, tracker=None, verbose: bool = False):
        """Play rounds 1..T.

        Args:
            get_contexts: callable t -> (K, d) array of arm features for round t (1-based).
            play: callable (x, t) -> observed reward for playing arm x at round t.
            tracker: optional; if it has ``observe(x, X_t)`` that is called after every play.
            verbose: accepted for API compatibility but not used; the batch-wise
                survivor summary is always printed.
        """
        lam = self.lam

        if self.B == 0:
            # T <= 0: empty schedule, nothing to play.
            print("Batch-wise average survivor counts:", self.survivor_counts)
            return

        # -------- Batch 1 (unweighted) --------
        H = lam * np.eye(self.d)
        X_batch, R_batch = [], []
        T1 = int(self.ends[0])
        for t in range(1, T1 + 1):
            X_t = get_contexts(t)
            Hinv = chol_inv(H)
            idx = int(np.argmax(np.einsum("ij,jk,ik->i", X_t, Hinv, X_t, optimize=True)))
            x = X_t[idx]
            r = play(x, t)
            self._notify(tracker, x, X_t)
            X_batch.append(x); R_batch.append(r)
            H += np.outer(x, x)

        Y1 = np.stack(X_batch, axis=0)
        R1 = np.array(R_batch, dtype=float)
        theta1 = fit_logistic_ridge(Y1, R1, lam)
        self.V_list.append(H.copy())
        self.theta_list.append(theta1)
        self.played_X.append(Y1); self.played_R.append(R1)
        t_prev = T1

        # -------- Batch 2 (k=1, alpha = e^{-2RS}, no elimination) --------
        if self.B >= 2:
            alpha_b2 = self._alpha_batch2()
            T2_end = int(self.ends[1])
            self._play_and_store_batch(
                times=list(range(t_prev + 1, T2_end + 1)),
                get_contexts=get_contexts, play=play, tracker=tracker, lam=lam,
                theta_prev=self.theta_list[-1],          # θ̂_1
                eliminate=lambda X_t: X_t,
                weight=lambda x: alpha_b2,
            )
            t_prev = T2_end

        # -------- Subsequent batches (ℓ ≥ 3, k = ℓ-1 ≥ 2) --------
        beta_const = compute_beta(self.d, self.T)
        for ell_idx in range(2, self.B):   # ell_idx: 0→B1,1→B2,2→B3,...
            T_end = int(self.ends[ell_idx])
            theta_prev = self.theta_list[-1]
            V_prev = self.V_list[-1]
            Vinv_prev = self._safe_inv(V_prev)   # hoisted: constant within the batch
            hist_V = self.V_list[1:]             # batch-1 stats are not used for elimination
            hist_theta = self.theta_list[1:]

            def eliminate(X_t, hist_V=hist_V, hist_theta=hist_theta):
                X_surv = eliminate_arms_with_history(X_t, hist_V, hist_theta, beta_const)
                return X_surv if len(X_surv) > 0 else X_t

            def weight(x, V_prev=V_prev, Vinv_prev=Vinv_prev):
                return self._alpha_subseq(x, V_prev, beta_const, Vinv=Vinv_prev)

            self._play_and_store_batch(
                times=list(range(t_prev + 1, T_end + 1)),
                get_contexts=get_contexts, play=play, tracker=tracker, lam=lam,
                theta_prev=theta_prev, eliminate=eliminate, weight=weight,
            )
            t_prev = T_end

        print("Batch-wise average survivor counts:", self.survivor_counts)
        return

# ---------------------------------------------------------
# Environment & tracker
# ---------------------------------------------------------

class BernoulliGLMEnv:
    """Logistic bandit environment: playing x yields Bernoulli(sigmoid(x^T theta)).

    If ``theta`` is omitted it is drawn as a standard normal vector from ``rng``.
    """

    def __init__(self, d: int, theta: np.ndarray = None, rng = None):
        self.d = int(d)
        self.rng = np.random.default_rng() if rng is None else rng
        if theta is None:
            self.theta = self.rng.normal(size=d)
        else:
            self.theta = np.asarray(theta, float)

    def play(self, x: np.ndarray, t: int = None) -> float:
        """Draw a 0.0/1.0 reward for arm ``x``; ``t`` is accepted but unused."""
        success_prob = float(sigmoid(x @ self.theta))
        return float(self.rng.random() < success_prob)

class RegretTrackerGLM:
    """Record per-round regret ``max_y sigmoid(y^T theta*) - sigmoid(x^T theta*)``."""

    def __init__(self, theta_true: np.ndarray):
        self.theta_true = np.asarray(theta_true, float)
        self.regrets = []

    def observe(self, chosen_x: np.ndarray, context_set: np.ndarray):
        """Append the regret of ``chosen_x`` relative to the best arm in ``context_set``."""
        optimal = float(np.max(sigmoid(context_set @ self.theta_true)))
        achieved = float(sigmoid(chosen_x @ self.theta_true))
        self.regrets.append(optimal - achieved)

# ---------------------------------------------------------
# Simulation
# ---------------------------------------------------------

def simulate_glm_bgle(T: int, K: int, d: int, seed: int = 3, R: float = 1.0, S: float = 1.0, contexts = None):
    """Run one seeded BGLE_GLM simulation on a logistic environment.

    Contexts come from ``contexts[t-1]`` when given. Otherwise they are drawn
    lazily, once per round, as standard normal (K, d) matrices from the shared
    seeded generator.

    Returns:
        cum_regret: cumulative regret over the rounds played, in play order.
        batch_counts: length-T int array; entry t-1 is the 1-based batch index of round t.
    """
    rng = np.random.default_rng(seed)
    theta_star = rng.normal(size=d)
    env = BernoulliGLMEnv(d, theta=theta_star, rng=rng)
    tracker = RegretTrackerGLM(theta_true=theta_star)

    if contexts is not None:
        def get_contexts(t):
            return contexts[t - 1]
    else:
        generated = {}

        def get_contexts(t: int) -> np.ndarray:
            X = generated.get(t)
            if X is None:
                X = rng.normal(size=(K, d))
                generated[t] = X
            return X

    def play_and_record(x, t):
        reward = env.play(x, t)
        tracker.observe(x, get_contexts(t))
        return reward

    algo = BGLE_GLM(T=T, K=K, d=d, lam=d + math.log(max(T, 3)), R=R, S=S)
    algo.run(get_contexts, play_and_record, tracker=None, verbose=False)

    cum_regret = np.cumsum(np.asarray(tracker.regrets, dtype=float))

    ends = np.array(algo.ends, dtype=int)
    starts = np.concatenate(([1], ends[:-1] + 1))
    batch_counts = np.searchsorted(starts, np.arange(1, T + 1), side="right")

    return cum_regret, batch_counts

# ##############################################################################

# =========================================================
# Basic utils
# =========================================================

def sigmoid(z):
    """Elementwise logistic function ``1 / (1 + exp(-z))``."""
    denominator = 1.0 + np.exp(-z)
    return 1.0 / denominator
def logistic_link(x_theta):
    """Bernoulli GLM mean function: mu(z) = sigmoid(z)."""
    return sigmoid(z=x_theta)
def logistic_deriv(x_theta):
    """Derivative of the logistic mean function: mu'(z) = mu(z) * (1 - mu(z))."""
    mu = sigmoid(x_theta)
    one_minus_mu = 1.0 - mu
    return mu * one_minus_mu

def _chol_inv(A):
    """Invert a symmetric positive-definite matrix as ``L^{-T} L^{-1}`` with ``A = L L^T``.

    Raises ``np.linalg.LinAlgError`` if ``A`` is not positive definite.
    """
    lower_inv = np.linalg.inv(np.linalg.cholesky(A))
    return lower_inv.T @ lower_inv

def _chol_solve(A, b):
    """Solve ``A x = b`` for symmetric positive-definite ``A`` via its Cholesky factor."""
    lower = np.linalg.cholesky(A)
    return np.linalg.solve(lower.T, np.linalg.solve(lower, b))

# ---------------------------------------------------------
# Batch schedule τ_k  (Fig.2)
# τ1 = ((sqrt(κ) e^{3S} d^2 γ^2 α) / S)^{2/3}, τ2 = α, τk = α sqrt(τ_{k-1})
# α = T^{1 / (2 (1 - 2^{-(M-1)}))} if M <= log2 log T (so that τ_M = T), else α = 2 sqrt(T)
# γ = 30 R S sqrt(d log T)
# ---------------------------------------------------------
def _gamma_RS(R, S, d, T):
    """Confidence-width constant ``gamma = 30 R S sqrt(d log T)`` (T floored at 3)."""
    return 30 * R * S * math.sqrt(d * math.log(max(T, 3)))

def _alpha_T(T, M):
    """Batch-growth factor alpha for an M-batch schedule over horizon T.

    If ``M <= log2(log T)``, then ``alpha = T^{1 / (2 (1 - 2^{-(M-1)}))}``. This
    makes the recursion ``tau_2 = alpha, tau_k = alpha * sqrt(tau_{k-1})`` reach
    T at ``k = M``. Otherwise ``alpha = 2 sqrt(T)``.

    For ``M <= 1`` the exponent is undefined (division by zero at M = 1), so
    alpha is capped at T instead. Every tau_k is clipped to T anyway.
    """
    th = math.log(max(math.log(max(T, 3)), 2.0), 2)  # log2 log T
    if M <= th:
        if M <= 1:
            return float(T)
        return T ** (1/(2.0 * (1.0 - 2.0 ** (-(M - 1)))))
    return 2.0 * math.sqrt(T)

def _tau_schedule(T, d, M, R, S, kappa):
    """Strictly increasing batch end points ``tau`` with ``tau[-1] == T``.

    ``tau_1 = ceil((sqrt(kappa) e^{3S} d^2 gamma^2 alpha / S)^{2/3})``,
    ``tau_2 = ceil(alpha)``, and ``tau_k = ceil(alpha * sqrt(tau_{k-1}))``. Each
    value is clipped to T and forced to exceed its predecessor.

    If tau_1 already reaches T the schedule is just ``[T]``. Previously a
    duplicate ``[T, T]`` was produced, i.e. an empty extra batch.

    Returns:
        (tau, gamma, alpha)
    """
    gamma = _gamma_RS(R, S, d, T)
    alpha = _alpha_T(T, M)
    # τ1
    t1 = ((math.sqrt(kappa) * math.exp(3.0 * S) * (d ** 2) * (gamma ** 2) * alpha) / max(S, 1e-12)) ** (2.0 / 3.0)
    t1 = int(max(1, min(T, math.ceil(t1))))
    tau = [t1]
    # τ2 (only if batch 1 did not already cover the horizon)
    if t1 < T:
        t2 = int(min(T, math.ceil(alpha)))
        tau.append(max(t1 + 1, t2))
    # τk = α√τ_{k-1}; tau[-1] < T inside the loop, so tau[-1] + 1 <= T
    while tau[-1] < T:
        tk = int(min(T, math.ceil(alpha * math.sqrt(tau[-1]))))
        tau.append(max(tk, tau[-1] + 1))
    return tau, gamma, alpha

# ---------------------------------------------------------
# κ upper bound for logistic (Fig.4)
# κ = max 1 / μ'(⟨x, θ*⟩).  |⟨x,θ*⟩| ≤ R S
# μ'(z) ≥ σ(RS)(1-σ(RS))  ⇒  κ ≤ 1 / (σ(RS)(1-σ(RS)))
# ---------------------------------------------------------
def kappa_upper_bound(R, S):
    """Return ``1 / (sigma(RS) (1 - sigma(RS)))``, with the denominator floored at 1e-12."""
    p = sigmoid(R * S)
    slope = p * (1.0 - p)
    return 1.0 / max(slope, 1e-12)

# ---------------------------------------------------------
# β(x) (Fig.3)  with V from Batch1
# β(x) = exp( R * min{ 2S, γ * sqrt(κ) * ||x||_{V^{-1}} } )
# ---------------------------------------------------------
def beta_of_x(x, Vinv, R, S, gamma, kappa):
    """Return ``exp(R * min(2S, gamma * sqrt(kappa) * ||x||_{Vinv}))``.

    The squared norm ``x^T Vinv x`` is floored at 1e-15 before the square root.
    """
    norm_x = float(np.sqrt(max(x @ Vinv @ x, 1e-15)))
    exponent = R * min(2.0 * S, gamma * math.sqrt(kappa) * norm_x)
    return math.exp(exponent)

# ---------------------------------------------------------
# UCB/LCB elimination (Fig.5)
# k=1 : ⟨x, θ̂_w⟩ ± γ√κ ||x||_{V^{-1}}
# k>1 : ⟨x, θ̂_k⟩ ± γ ||x||_{H_k^{-1}}
# ---------------------------------------------------------
def eliminate_ucb_intersection_BGLinCB(X, thetas, mats, gamma, kappa):
    """Keep arms whose UCB reaches the best LCB under every stored estimate.

    For the j-th pair ``(theta_j, M_j)`` (1-based), arm x survives iff
        <x, theta_j> + b_j ||x||_{M_j^{-1}} >= max_y (<y, theta_j> - b_j ||y||_{M_j^{-1}})
    where ``b_1 = gamma * sqrt(kappa)`` and ``b_j = gamma`` for j > 1.

    Pairs whose shapes do not match X are skipped. Filters are applied in
    sequence, so later pairs only see the survivors of earlier ones. An empty
    X (no rows) is returned unchanged; ``lcbs.max()`` would otherwise raise on
    it. An empty (0, d) result only occurs if no arm satisfies the condition,
    e.g. when scores contain NaN.
    """
    if len(thetas) == 0:
        return X
    d = X.shape[1]
    if X.shape[0] == 0:
        return X
    for j, (th, M) in enumerate(zip(thetas, mats), start=1):
        # shape guard
        th = np.asarray(th).reshape(-1)
        if th.size != d or M.shape != (d, d):
            continue

        Minv = _chol_inv(M)
        sd = np.sqrt(np.maximum(np.einsum('id,df,if->i', X, Minv, X, optimize=True), 1e-15))

        bonus = (gamma * np.sqrt(kappa) if j == 1 else gamma) * sd  # (K,)
        ucbs  = X @ th + bonus                                       # (K,)
        lcbs  = X @ th - bonus                                       # (K,)
        thresh = float(lcbs.max())                                   # scalar

        cond = (ucbs >= thresh)                                      # (K,)
        if not np.any(cond):
            return np.zeros((0, d))
        X = X[cond]
    return X

# ---------------------------------------------------------
# ---------------------------------------------------------
def fit_logistic_ridge(Y, Rv, lam, d=None, max_iter=50, tol=1e-6):
    """Ridge-regularized logistic MLE by Newton's method, starting from theta = 0.

    Args:
        Y: (n, d) design matrix, or None/empty; in that case ``d`` is required and
            a zero vector of length d is returned.
        Rv: length-n binary rewards.
        lam: ridge strength.
        d: parameter dimension (defaults to ``Y.shape[1]``).
        max_iter: maximum Newton steps.
        tol: stop once ``||Δtheta|| <= tol * (1 + ||theta||)``.
    """
    if Y is None or len(Y) == 0:
        assert d is not None, "d must be provided when Y is empty"
        return np.zeros(d)

    _, d_local = Y.shape
    dim = d_local if d is None else d
    theta = np.zeros(dim)
    ridge = lam * np.eye(dim)
    for _ in range(max_iter):
        mu = logistic_link(Y @ theta)
        grad = Y.T @ (mu - Rv) + lam * theta
        hess = (Y.T * (mu * (1.0 - mu))) @ Y + ridge
        try:
            delta = _chol_solve(hess, grad)
        except np.linalg.LinAlgError:
            delta = np.linalg.pinv(hess) @ grad
        theta_next = theta - delta
        converged = np.linalg.norm(theta_next - theta) <= tol * (1.0 + np.linalg.norm(theta))
        theta = theta_next
        if converged:
            break
    return theta

# =========================================================
# =========================================================

def _softmax(scores, alpha):
    """Return probabilities proportional to ``max(scores, 1e-15) ** alpha``.

    Despite the name, this is a power normalization, not an exponential
    softmax. It is computed in log space, so large scores or exponents no
    longer overflow to inf and tiny ones no longer underflow to 0. Either of
    those previously forced a silent uniform fallback.

    Non-finite log-weights (e.g. NaN scores or an infinite alpha) still return
    the uniform distribution. Empty input returns an empty array.
    """
    scores = np.asarray(scores, dtype=float)
    if scores.size == 0:
        return np.ones_like(scores)
    log_w = alpha * np.log(np.maximum(scores, 1e-15))
    if not np.all(np.isfinite(log_w)):
        return np.ones_like(scores) / len(scores)
    w = np.exp(log_w - log_w.max())   # max entry is exp(0) = 1, so the sum is >= 1
    return w / w.sum()

def _expected_xxT(X, probs):
    """Return ``sum_i probs[i] * x_i x_i^T`` (i.e. ``X^T diag(probs) X``) for rows x_i of X."""
    weighted_rows = probs[:, None] * X
    return X.T @ weighted_rows

class GOptimalSampler:
    """Draw arm indices from a (near) G-optimal design over the rows of a context matrix.

    ``design_fn(X)`` should return one non-negative weight per row of X. Weights
    that are non-finite or sum to <= 0 are replaced by the uniform design.
    """

    def __init__(self, design_fn, rng=None):
        self.design_fn = design_fn
        self.rng = np.random.default_rng() if rng is None else rng

    def _coerce_matrix(self, X):
        """Convert X to a float array, promoting a single (d,) row to shape (1, d)."""
        arr = np.asarray(X, dtype=float)
        return arr.reshape(1, -1) if arr.ndim == 1 else arr

    def _design_probs(self, X):
        """Normalized design weights for the rows of X, with a uniform fallback."""
        weights = self.design_fn(X)
        if not (np.all(np.isfinite(weights)) and weights.sum() > 0):
            weights = np.ones(len(X)) / len(X)
        return weights / weights.sum()

    def sample_index(self, X):
        """Sample one row index of X according to the design weights."""
        X = self._coerce_matrix(X)
        if X.ndim != 2 or X.shape[0] == 0:
            raise ValueError(f"[Sampler] empty/invalid X: shape={X.shape}")
        return int(self.rng.choice(len(X), p=self._design_probs(X)))

    def expected_xxT(self, X):
        """Return E[x x^T] under the design; a (d, d) zero matrix for empty X."""
        X = self._coerce_matrix(X)
        if X.ndim == 2 and X.shape[0] > 0:
            return _expected_xxT(X, self._design_probs(X))
        d = X.shape[1] if X.ndim == 2 else 0
        return np.zeros((d, d))

class MixedSoftmaxPolicyGOpt:
    """Randomized arm-selection policy mixing a G-optimal sampler with softmax components.

    With probability 1/2 an arm is drawn by ``gopt_sampler.sample_index``.
    Otherwise:
    1. A mixture component j is drawn with probability ``ps[j]``.
    2. Arm i is drawn with probability proportional to
       ``(x_i^T M_j x_i) ** alpha``, computed via ``_softmax``.

    ``alpha`` is ``alpha_temp`` if given, else ``log(max(K, 2))``. Every random
    draw made by this class uses ``self.rng``, so a seeded generator makes the
    choices reproducible.
    """

    def __init__(self, mixture, gopt_sampler, alpha_temp=None, rng = None):
        self.mixture = mixture
        self.ps = np.array([p for p,_ in mixture], float)
        self.Ms = [M for _,M in mixture]
        self.gopt = gopt_sampler
        self.alpha_temp = alpha_temp
        self.rng = rng if rng is not None else np.random.default_rng()

    def choose(self, X):
        """Return the index of the chosen row of ``X``.

        Raises:
            ValueError: if X has no rows.
        """
        K = len(X)
        if K == 0:
            raise ValueError("empty X")
        alpha = self.alpha_temp if self.alpha_temp is not None else float(np.log(max(K,2)))
        # Draw from self.rng, not the global np.random state, to keep seeded runs reproducible.
        if self.rng.random() < 0.5:
            return self.gopt.sample_index(X)
        j = int(self.rng.choice(len(self.mixture), p=self.ps))
        M = self.Ms[j]
        scores = np.einsum('id,df,if->i', X, M, X, optimize=True)
        probs = _softmax(scores, alpha)
        return int(self.rng.choice(K, p=probs))

# =========================================================
# B-GLinCB (follows the pseudocode in Figure 1 of the reference paper)
# =========================================================

class BGLinCB:
    """Batched generalized-linear contextual bandit (B-GLinCB) with logistic rewards.

    The horizon is split at the end points ``self.tau`` from ``_tau_schedule``.

    Batch 1 samples arms with the near G-optimal sampler (``design_fn``). It
    fits ``theta_w`` and keeps the information matrix ``V``.

    Each later batch k proceeds as follows:
    * Arms are filtered by ``eliminate_ucb_intersection_BGLinCB`` against every
      stored (theta, matrix) pair. If no arm survives, all arms are used.
    * The first half of the rounds is played with the previous batch's policy
      on beta-rescaled contexts. ``(H_k, theta_k)`` is fit from those rounds.
    * A new policy ``pi_k`` is learned with ``core_learning_fn`` on the survivor
      sets of the second half, and those rounds are played with ``pi_k``.
    """

    def __init__(self, T, K, d, R=1.0, S=1.0, lam=None,
                 M=None, design_fn=None, core_learning_fn=None, rng=None):
        self.T, self.K, self.d = int(T), int(K), int(d)
        self.R, self.S = float(R), float(S)
        self.lam = (d + math.log(max(T,3))) if lam is None else float(lam)

        self.M = int(M) if M is not None else max(2, int(math.ceil(math.log2(max(math.log2(max(T,4)),2.0)))))
        self.kappa = kappa_upper_bound(self.R, self.S)
        self.tau, self.gamma, self.alpha_param = _tau_schedule(T, d, self.M, R, S, self.kappa)

        self.rng = rng if rng is not None else np.random.default_rng()

        # π_G
        assert design_fn is not None, "design_fn (near g-optimal) must be provided"
        self.gopt = GOptimalSampler(design_fn, rng=self.rng)
        self.gopt_sampler = self.gopt

        assert core_learning_fn is not None, "core_learning_fn (core_learning_gopt) must be provided"
        self.core_learning_fn = core_learning_fn

        self.V = None           # batch-1 information matrix
        self.theta_w = None     # batch-1 logistic ridge estimate
        self.H_list = []        # H_2, H_3, ... (one per batch k >= 2)
        self.theta_list = []    # θ̂_1(=θ_w), θ̂_2, ..., θ̂_{M-1}
        self.policy_list = []   # π_1 (= π_G), π_2, ...

    def _scale_X_with_beta(self, X, Vinv):
        """Divide each row x of X by sqrt(beta_of_x(x)), with beta floored at 1e-15."""
        betas = np.array([beta_of_x(x, Vinv, self.R, self.S, self.gamma, self.kappa) for x in X], float)
        scales = 1.0 / np.sqrt(np.maximum(betas, 1e-15))
        return X * scales[:, None]

    def _eliminate_all(self, X):
        """Apply UCB/LCB elimination with (θ_w, V) followed by every stored (θ̂_k, H_k)."""
        thetas = [self.theta_w] + self.theta_list[1:]
        mats   = [self.V]       + self.H_list
        return eliminate_ucb_intersection_BGLinCB(X, thetas, mats, self.gamma, self.kappa)

    @staticmethod
    def _select(pi, X_scaled):
        """Index chosen by policy ``pi`` (G-optimal sampler or learned mixture policy)."""
        if isinstance(pi, GOptimalSampler):
            return pi.sample_index(X_scaled)
        return pi.choose(X_scaled)

    def run(self, get_contexts, play):
        """
        get_contexts(t) -> (K,d) numpy
        play(x,t) -> reward (0/1)

        Every round 1..T is played exactly once.

        Raises:
            ValueError: if a batch-1 context is not a non-empty (K, d) array.
        """
        # run() rebuilds theta_list below, so H_list must start empty as well;
        # otherwise a second run() would pair new θ̂_k with stale H_k.
        self.H_list = []

        # ---------------- Batch 1 ----------------
        T1 = int(self.tau[0])
        X_obs, R_obs = [], []
        H = self.lam * np.eye(self.d)

        for t in range(1, T1 + 1):
            X_t = np.asarray(get_contexts(t), dtype=float)
            if X_t.ndim == 1:
                X_t = X_t.reshape(1, -1)
            if X_t.ndim != 2 or X_t.shape[0] == 0 or X_t.shape[1] != self.d:
                raise ValueError(f"[BGLinCB] invalid context at t={t}, got shape {X_t.shape}, expected (K,{self.d})")

            # near g-optimal design sampler
            idx = self.gopt.sample_index(X_t)
            x = X_t[idx]
            r = play(x, t)
            X_obs.append(x); R_obs.append(r)
            H += np.outer(x, x)

        Y1 = np.stack(X_obs, 0); R1 = np.array(R_obs, float)
        self.theta_w = fit_logistic_ridge(Y1, R1, self.lam)
        self.V = H.copy()         # information matrix after batch 1
        self.theta_list = [self.theta_w]
        Vinv = _chol_inv(self.V)

        # π_1 = G-optimal policy
        self.policy_list = [self.gopt]

        # ---------------- Batches k = 2..len(tau) ----------------
        prev_end = T1
        for k in range(2, len(self.tau) + 1):
            end = int(self.tau[k - 1])
            times = list(range(prev_end + 1, end + 1))
            mid = len(times) // 2
            # For a single-round batch mid == 0, so that round goes to the
            # B half and is still played.
            A_times, B_times = times[:mid], times[mid:]

            # ===== A half: play with π_{k-1}, collect data for H_k, θ̂_k =====
            XA, RA = [], []
            pi_prev = self.policy_list[k - 2]

            for t in A_times:
                X_t = get_contexts(t)
                X_surv = self._eliminate_all(X_t)
                if len(X_surv) == 0:
                    # never skip a round: fall back to all arms (as in the B half)
                    X_surv = X_t
                idx = self._select(pi_prev, self._scale_X_with_beta(X_surv, Vinv))
                x = X_surv[idx]
                r = play(x, t)
                XA.append(x); RA.append(r)

            Hk = self.lam * np.eye(self.d)
            for x in XA:
                s = max(float(logistic_deriv(float(x @ self.theta_w))), 1e-8)
                beta_x = beta_of_x(x, Vinv, self.R, self.S, self.gamma, self.kappa)
                Hk += (s / beta_x) * np.outer(x, x)

            theta_k = fit_logistic_ridge(np.stack(XA,0) if XA else np.zeros((0,self.d)),
                                        np.array(RA,float) if XA else np.zeros((0,)),
                                        self.lam, d=self.d)
            self.H_list.append(Hk)
            self.theta_list.append(theta_k)

            # ===== B half: collect survivor sets -> learn π_k -> play with π_k =====
            S_sets = []
            B_cache = {}
            for t in B_times:
                X_t = get_contexts(t)
                X_surv = self._eliminate_all(X_t)
                if len(X_surv) > 0:
                    S_sets.append(X_surv)
                    B_cache[t] = X_surv
                else:
                    B_cache[t] = X_t

            lam_core = 1.0 / self.T
            pi_k = self.core_learning_fn(S_sets, lam=lam_core, T_total=self.T,
                                         d=self.d, gopt_sampler=self.gopt_sampler, c=6, rng=self.rng)
            self.policy_list.append(pi_k)

            for t in B_times:
                X_play = B_cache[t]
                idx = self._select(pi_k, self._scale_X_with_beta(X_play, Vinv))
                play(X_play[idx], t)

            prev_end = end
        return
    
def simulate_bglincb(T, K, d, seed=3, R=1.0, S=1.0, design_fn=None, core_learning_fn=None, contexts=None):
    """Simulate B-GLinCB on a synthetic logistic contextual bandit.

    A true parameter ``theta_star`` is drawn from a standard normal. Rewards are
    Bernoulli with mean ``1 / (1 + exp(-x @ theta_star))``. The per-round regret
    is the gap between the best mean in round t's full context set and the mean
    of the arm actually played.

    Args:
        T, K, d: horizon, arms per round, and feature dimension.
        seed: seed of the single ``np.random.Generator``. It is shared by the
            environment, lazily generated contexts, and the algorithm.
        R, S: forwarded unchanged to ``BGLinCB``.
        design_fn: near G-optimal design routine (required).
        core_learning_fn: core policy-learning routine (required).
        contexts: optional pre-generated contexts, looked up as
            ``contexts[t - 1]`` for round t. When ``None``, round t's ``(K, d)``
            contexts are drawn from a standard normal on first request and
            memoized.

    Returns:
        ``(cum_regret, batch_counts)``:
        - ``cum_regret`` is the cumulative regret over the played rounds.
        - ``batch_counts[t-1]`` is the number of batch starts at or before round
          t. This treats ``algo.tau`` as the list of batch end times.
    """
    assert design_fn is not None, "need near g-optimal design function"
    assert core_learning_fn is not None, "need core_learning_gopt function"

    rng = np.random.default_rng(seed)
    theta_star = rng.normal(size=d)
    per_round_regret = []

    if contexts is not None:
        def get_contexts(t):
            # Rounds are 1-based while the supplied container is 0-based.
            return contexts[t - 1]
    else:
        memo = {}

        def get_contexts(t: int) -> np.ndarray:
            if t not in memo:
                memo[t] = rng.normal(size=(K, d))
            return memo[t]

    def play_wrapper(x, t):
        # Environment step: draw a Bernoulli reward for the chosen arm.
        prob = float(1.0 / (1.0 + np.exp(-x @ theta_star)))
        reward = float(rng.random() < prob)

        # Bookkeeping: regret against round t's full (unfiltered) context set.
        arm_set = get_contexts(t)
        best = float(np.max(1.0 / (1.0 + np.exp(-(arm_set @ theta_star)))))
        got = float(1.0 / (1.0 + np.exp(-(x @ theta_star))))
        per_round_regret.append(best - got)
        return reward

    ridge = 20 * d * math.log(max(T, 3))
    algo = BGLinCB(T=T, K=K, d=d, R=R, S=S, lam=ridge,
                   design_fn=design_fn,
                   core_learning_fn=core_learning_fn, rng=rng)
    algo.run(get_contexts, play_wrapper)

    cum_regret = np.cumsum(np.asarray(per_round_regret, dtype=float))

    batch_ends = np.array(algo.tau, dtype=int)
    batch_starts = np.concatenate(([1], batch_ends[:-1] + 1))
    rounds = np.arange(1, T + 1)
    batch_counts = np.searchsorted(batch_starts, rounds, side="right")

    return cum_regret, batch_counts

# ##############################################################################

def collect_runs(simulate_fn, N_RUNS, seed0=3):
    """Run a single-trace simulator over consecutive seeds and stack the results.

    ``simulate_fn(seed)`` must return one 1-D trace, such as a cumulative regret
    curve. It is called with seeds ``seed0, seed0 + 1, ..., seed0 + N_RUNS - 1``.
    Every trace is aligned to the length of the first one:
    - Longer traces are truncated.
    - Shorter traces are right-padded by repeating their last value, or with
      ``0.0`` when a trace is empty.

    Args:
        simulate_fn: callable ``seed -> 1-D array-like``.
        N_RUNS: number of runs; must be at least 1.
        seed0: seed of the first run.

    Returns:
        Array of shape ``(N_RUNS, T_ref)``, where ``T_ref`` is the first trace's
        length.

    Raises:
        ValueError: if ``N_RUNS`` is less than 1.
    """
    runs = []
    T_ref = None
    for i in range(N_RUNS):
        cr = simulate_fn(seed0 + i)
        if T_ref is None:
            T_ref = len(cr)
        elif len(cr) > T_ref:
            cr = cr[:T_ref]
        elif len(cr) < T_ref:
            # mode="edge" cannot extend an empty array, so pad with an explicit
            # constant: the last value, or 0.0 for an empty trace.
            fill = cr[-1] if len(cr) > 0 else 0.0
            cr = np.pad(cr, (0, T_ref - len(cr)), mode="constant", constant_values=fill)
        runs.append(cr)
    if not runs:
        raise ValueError(f"collect_runs needs N_RUNS >= 1, got {N_RUNS}")
    return np.vstack(runs)

def _align_to_T(arr, T_ref):
    if len(arr) == T_ref:
        return arr
    if len(arr) > T_ref:
        return arr[:T_ref]
    # pad
    pad_val = arr[-1] if len(arr) > 0 else 0.0
    return np.pad(arr, (0, T_ref - len(arr)), mode="constant", constant_values=pad_val)

def collect_runs_with_batches(simulate_fn, N_RUNS, seed0=3):
    """Run ``simulate_fn`` over consecutive seeds and stack regret/batch traces.

    ``simulate_fn(seed)`` must return ``(cum_regret, batch_counts)``. It is
    called with seeds ``seed0, ..., seed0 + N_RUNS - 1``. Both traces of every
    run are aligned with ``_align_to_T`` to the length of the first run's regret
    trace. Each call's wall-clock duration is measured with
    ``time.perf_counter``.

    Returns:
        (regrets, batches_all, run_times):
        - ``regrets`` and ``batches_all`` have shape ``(N_RUNS, T_ref)``.
        - ``run_times`` is a float array of per-run durations in seconds.
    """
    regret_rows, batch_rows, durations = [], [], []
    target_len = None

    for i in range(N_RUNS):
        started = time.perf_counter()
        cum_regret, batch_counts = simulate_fn(seed0 + i)
        durations.append(time.perf_counter() - started)

        # The first run fixes the common length every later run is aligned to.
        if target_len is None:
            target_len = len(cum_regret)
        regret_rows.append(_align_to_T(cum_regret, target_len))
        batch_rows.append(_align_to_T(batch_counts, target_len))

    return (
        np.vstack(regret_rows),
        np.vstack(batch_rows),
        np.array(durations, float),
    )

def plot_regret_and_batches(regret_stats, batch_stats,
                            title=None, out_pdf="regret_batches.pdf",
                            axis_fs=20, legend_fs=20, title_fs=18, T_fixed=10000, ymax_regret=3000):
    """Draw cumulative regret (left) and cumulative batch count (right) side by side.

    Args:
        regret_stats: mapping ``name -> {"mean": array, "std": array}``. Each
            entry is drawn as a mean curve with a +/- one-std shaded band and
            contributes one entry to the shared figure legend.
        batch_stats: mapping ``name -> {"batches": array}``. Each entry is drawn
            as a post-step curve.
        title, title_fs: accepted for API compatibility but currently unused.
            The figure is saved without a suptitle.
        out_pdf: path of the PDF written via ``plt.savefig``.
        axis_fs, legend_fs: font sizes for axis labels and the shared legend.
        T_fixed: upper x-limit of both panels (lower limit 0).
        ymax_regret: upper y-limit of the regret panel (lower limit 0).

    After saving, the figure is displayed with ``plt.show()`` and then closed.
    """
    fig, (ax_regret, ax_batch) = plt.subplots(1, 2, figsize=(12, 4))

    # Left panel: mean cumulative regret with a one-standard-deviation band.
    legend_lines = []
    legend_names = []
    for label, stats in regret_stats.items():
        mean, spread = stats["mean"], stats["std"]
        steps = np.arange(1, len(mean) + 1)
        line, = ax_regret.plot(steps, mean, linewidth=2, label=label)
        ax_regret.fill_between(steps, mean - spread, mean + spread, alpha=0.18)
        legend_lines.append(line)
        legend_names.append(label)
    ax_regret.set_xlabel("Time step", fontsize=axis_fs)
    ax_regret.set_ylabel("Regret", fontsize=axis_fs)
    ax_regret.grid(True, alpha=0.4)
    ax_regret.set_xlim(0, T_fixed)
    ax_regret.set_ylim(0, ymax_regret)

    # Right panel: cumulative number of batches. The default color cycle is
    # consumed in the same order as on the left, so colors line up.
    for stats in batch_stats.values():
        counts = stats["batches"]
        steps = np.arange(1, len(counts) + 1)
        ax_batch.step(steps, counts, where="post", linewidth=2)
    ax_batch.set_xlabel("Time step", fontsize=axis_fs)
    ax_batch.set_ylabel("Batch", fontsize=axis_fs)
    ax_batch.grid(True, alpha=0.4)
    ax_batch.set_xlim(0, T_fixed)

    # One legend for both panels, centered above the subplots.
    fig.legend(legend_lines, legend_names,
               loc="upper center",
               bbox_to_anchor=(0.5, 1),
               bbox_transform=fig.transFigure,
               ncol=max(1, len(legend_names)),
               frameon=False, fontsize=legend_fs)

    # Reserve the top 10% of the figure for the legend.
    fig.tight_layout(rect=(0, 0, 1, 0.9))
    plt.savefig(out_pdf, format="pdf", bbox_inches="tight")
    plt.show()
    plt.close()

# -------------------------------
if __name__ == "__main__":
    T, K, D = 10000, 50, 3
    R, S = 1.0, 1.0
    N_RUNS = 20

    def make_shared_contexts(seed, T, K, D):
        """Draw a (T, K, D) standard-normal context tensor for one seed.

        Both algorithms call this with the same seed for a given run, so they
        face identical contexts. A uniform alternative would be
        ``r.random((T, K, D))``.
        """
        r = np.random.default_rng(seed)
        return r.normal(size=(T, K, D))

    alg_registry = {
        "B-GLinCB": lambda seed: simulate_bglincb(
            T=T, K=K, d=D, seed=seed, R=R, S=S,
            design_fn=near_g_optimal_design,
            core_learning_fn=core_learning_gopt,
            contexts=make_shared_contexts(seed, T, K, D)
        ),
        "BGLE": lambda seed: simulate_glm_bgle(T=T, K=K, d=D, seed=seed, R=R, S=S, contexts=make_shared_contexts(seed, T, K, D)),
    }

    # The sample std (ddof=1) is undefined for a single run, so fall back to ddof=0.
    ddof = 1 if N_RUNS > 1 else 0

    regret_stats, batch_stats = {}, {}
    for name, sim in alg_registry.items():
        runs, batches_mat, times = collect_runs_with_batches(sim, N_RUNS=N_RUNS, seed0=3)
        mean_cr = runs.mean(axis=0)
        std_cr = runs.std(axis=0, ddof=ddof)
        mean_batches = batches_mat.mean(axis=0)
        t = np.arange(1, len(mean_cr) + 1)
        regret_stats[name] = {"t": t, "mean": mean_cr, "std": std_cr}
        batch_stats[name] = {"t": t, "batches": mean_batches}

        print(f"[{name}] avg runtime per run: {times.mean():.3f}s ", f"(std: {times.std(ddof=ddof):.3f}s, N={len(times)})")

    # Tie the x-axis limit to the simulated horizon instead of the function's
    # hard-coded default, so changing T above keeps the plot consistent.
    plot_regret_and_batches(
        regret_stats, batch_stats,
        title=f"Cumulative Regret & Batches (N={N_RUNS})",
        out_pdf="glm_regret_and_batches.pdf",
        axis_fs=15,
        legend_fs=15,
        title_fs=18,
        T_fixed=T,
    )