import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm import tqdm

def sample_many(gen, n):
    """Draw ``n`` samples from ``gen`` and return them as a Python list.

    The vectorised call ``gen(size=n)`` is tried first; if that (or the
    conversion of its result) raises ``TypeError``, ``gen()`` is instead
    called ``n`` times with no arguments.

    NOTE(review): a second ``sample_many`` (returning an ndarray) is defined
    later in this file and shadows this one at module level.
    """
    try:
        values = np.asarray(gen(size=n)).tolist()
    except TypeError:
        values = [gen() for _ in range(n)]
    return values

def compute_regret_scaled(all_final_n, stds, T, d, lambda_min_C):
    """Return the mean allocation regret scaled by ``2 d / lambda_min_C``.

    For every run the empirical loss is ``sum_k sigma_k^2 / n_k``.  The oracle
    loss for p = 1 (so q = 2p/(p+1) = 1) is ``(sum_k sigma_k)^2 / T``.  The
    per-run regret is their difference; the result is its mean over runs
    times ``2 d / lambda_min_C``.

    Args:
        all_final_n: final pull counts, shape (runs, K).
        stds: per-arm noise standard deviations, shape (K,).
        T: budget used in the oracle term.
        d: context dimension.
        lambda_min_C: minimum eigenvalue used in the scale factor.

    Returns:
        The scaled mean regret as a Python float.
    """
    counts = np.asarray(all_final_n, dtype=float)
    sigma = np.asarray(stds, dtype=float)
    _n_runs, _n_arms = counts.shape  # enforces a 2-D (runs, K) input

    variances = sigma ** 2
    per_run_loss = (variances[np.newaxis, :] / counts).sum(axis=1)
    oracle_loss = sigma.sum() ** 2 / float(T)

    scale = 2.0 * float(d) / float(lambda_min_C)
    return float((per_run_loss - oracle_loss).mean() * scale)

def make_eps_funcs(sigma_SG, delta):
    """Build the additive confidence-width functions used by the GSG variant.

    Returns ``(epsilon_t_minus, epsilon_t_plus)``.  Both accept a scalar or
    array of sample counts ``t`` (clamped below at 2) and share the term
    ``4 sigma^2 (1 + sqrt(t-1)) / sqrt(t) * sqrt(2 log(1/delta) / (t-1))``.
    They differ only in the second term:
    ``13 sigma^2 log(1/delta) / (3 t)`` for the minus width and
    ``6 sigma^2 log(1/delta) / t`` for the plus width.
    ``delta`` is clipped into ``[tiny, 1 - 1e-12]`` before taking the log.
    """
    lo, hi = np.finfo(float).tiny, 1.0 - 1e-12
    log_term = -np.log(min(max(float(delta), lo), hi))
    var_sg = sigma_SG ** 2

    def _clamp(t):
        return np.maximum(np.asarray(t, dtype=float), 2.0)

    def _shared(t_safe):
        width = 4.0 * var_sg * (1.0 + np.sqrt(t_safe - 1.0)) / np.sqrt(t_safe)
        return width * np.sqrt(2.0 * log_term / (t_safe - 1.0))

    def epsilon_t_minus(t):
        t_safe = _clamp(t)
        return _shared(t_safe) + 13.0 * var_sg * log_term / (3.0 * t_safe)

    def epsilon_t_plus(t):
        t_safe = _clamp(t)
        return _shared(t_safe) + 6.0 * var_sg * log_term / t_safe

    return epsilon_t_minus, epsilon_t_plus


def make_s_funcs(delta):
    """Build the relative (multiplicative) confidence-width functions.

    Returns ``(s_t_minus, s_t_plus)``.  Each maps a scalar or array of sample
    counts ``t`` (clamped below at 2) to
    ``4 (1 + sqrt((t-1)/8)) / sqrt(t) * sqrt(2 log(1/delta) / (t-1)) + extra``,
    where ``extra`` is ``13 log(1/delta) / (3 t)`` for ``s_t_minus`` and
    ``6 log(1/delta) / t`` for ``s_t_plus``.  ``delta`` is first clipped into
    ``[tiny, 1 - 1e-12]``.
    """
    lo, hi = np.finfo(float).tiny, 1.0 - 1e-12
    log_term = -np.log(min(max(float(delta), lo), hi))

    def _clamp(t):
        return np.maximum(np.asarray(t, dtype=float), 2.0)

    def _shared(t_safe):
        head = 4.0 * (1.0 + np.sqrt((t_safe - 1.0) / 8.0)) / np.sqrt(t_safe)
        return head * np.sqrt(2.0 * log_term / (t_safe - 1.0))

    def s_t_minus(t):
        t_safe = _clamp(t)
        return _shared(t_safe) + 13.0 * log_term / (3.0 * t_safe)

    def s_t_plus(t):
        t_safe = _clamp(t)
        return _shared(t_safe) + 6.0 * log_term / t_safe

    return s_t_minus, s_t_plus


# ==============================

# ==============================
def ridge_from_stats(V0, b, s_yy, sx, sy, n, gamma_fn):
    """Fit a ridge regression from sufficient statistics and estimate residual variance.

    Args:
        V0: Gram matrix ``sum_i x_i x_i^T``, shape (d, d).
        b: ``sum_i x_i y_i``, shape (d,).
        s_yy: ``sum_i y_i^2``.
        sx: ``sum_i x_i``, shape (d,).
        sy: ``sum_i y_i``.
        n: number of observations behind the statistics.
        gamma_fn: callable ``n -> gamma``. The regularised matrix ``V`` is
            ``V0`` shifted by just enough identity to lift its smallest
            eigenvalue to ``gamma``. There is no shift if it is already
            at least ``gamma``.

    Returns:
        ``(beta_hat, sigma2_hat, V0, V)``. ``beta_hat`` solves
        ``V beta = b``. ``sigma2_hat`` is the non-negative sample variance
        (``n - 1`` denominator) of the residuals ``y_i - x_i^T beta_hat``.
        That estimate is undefined for ``n < 2``, so ``sigma2_hat`` is
        then returned as ``0.0``.
    """
    d = V0.shape[0]
    gamma = float(gamma_fn(int(n)))

    # Minimal identity shift so that lambda_min(V) reaches gamma.
    lam_min = float(np.linalg.eigvalsh(V0).min()) if d > 0 else 0.0
    lam_min = max(lam_min, -1e-12)
    add = max(0.0, gamma - lam_min)
    V = V0 + add * np.eye(d)

    # Solve V beta = b through a Cholesky factorisation (tiny jitter for
    # numerical positive-definiteness), falling back to a general solve and
    # finally to least squares.
    V_spd = V + 1e-12 * np.eye(d)
    try:
        L = np.linalg.cholesky(V_spd)
        z = np.linalg.solve(L, b)
        beta_hat = np.linalg.solve(L.T, z)
    except np.linalg.LinAlgError:
        try:
            beta_hat = np.linalg.solve(V, b)
        except np.linalg.LinAlgError:
            beta_hat, *_ = np.linalg.lstsq(V, b, rcond=None)

    # The variance estimate needs at least two observations.
    sigma2_hat = 0.0
    if n >= 2:
        # Residual sum of squares sum_i (y_i - x_i^T beta)^2 expanded in the
        # sufficient statistics.
        sse = float(s_yy - 2.0 * beta_hat.dot(b) + beta_hat.dot(V0.dot(beta_hat)))
        y_bar = float(sy) / float(n)
        x_bar = sx / float(n)
        r_bar = y_bar - float(beta_hat.dot(x_bar))  # mean residual
        # (SSE - n * r_bar^2) / (n - 1) is the residuals' sample variance.
        sigma2_hat = max((sse - float(n) * (r_bar ** 2)) / float(n - 1), 0.0)

    return beta_hat, sigma2_hat, V0, V


# ==============================
def ridge_beta_and_resid_var(C_list, y_list, gamma_fn):
    """Fit a ridge regression on raw samples by delegating to ``ridge_from_stats``.

    The context vectors in ``C_list`` are stacked into an (n, d) design
    matrix. It and ``y_list`` are reduced to sufficient statistics, and the
    result of ``ridge_from_stats`` is returned unchanged:
    ``(beta_hat, sigma2_hat, V0, V)``.
    """
    n_obs = len(y_list)
    design = np.stack(C_list, axis=0).astype(float)  # (n, d)
    targets = np.asarray(y_list, dtype=float)        # (n,)
    stats = dict(
        V0=design.T @ design,
        b=design.T @ targets,
        s_yy=float(targets @ targets),
        sx=design.sum(axis=0),
        sy=float(targets.sum()),
    )
    return ridge_from_stats(n=n_obs, gamma_fn=gamma_fn, **stats)


# ------------------------------

# ------------------------------
def sample_many(gen, n):
    """Draw ``n`` samples from ``gen`` and return them as a NumPy array.

    The vectorised call ``gen(size=n)`` is tried first; if it raises
    ``TypeError`` (e.g. ``gen`` takes no ``size`` keyword), ``gen()`` is
    called ``n`` times instead.
    """
    try:
        samples = np.asarray(gen(size=n))
    except TypeError:
        samples = np.array([gen() for _ in range(n)])
    return samples


# ------------------------------
def step2_adaptive_elimination(
    T: int,
    X: list[list[np.ndarray]],
    y: list[list[float]],
    s_minus_fn,                          # Callable[[np.ndarray], np.ndarray]
    s_plus_fn,                           # Callable[[np.ndarray], np.ndarray]
    gamma_fn,                            # gamma_k(n_k)
    q: float = 1.0,
    context_sampler=None,
    reward_generators=None,
):
    """Step 2: adaptive top-up and elimination with multiplicative widths.

    Starting from the samples already stored in ``X``/``y``, each round does
    the following. ``X``/``y`` are extended in place with every new pull.

    1. Top every active arm ``k`` up to ``floor(lam[k] * T)`` pulls, never
       exceeding the remaining budget ``T - sum(counts)``.
    2. Refit ``ridge_from_stats`` for arms that received samples, updating
       their residual-variance estimate ``sigma2_hat``.
    3. Form ``LCB = sigma2_hat / (1 + s_plus(n))`` and
       ``UCB = sigma2_hat / (1 - s_minus(n))`` (denominators floored at
       1e-8) and the target shares
       ``lam_k = LCB_k^{q/2} / (LCB_k^{q/2} + sum_{j != k} UCB_j^{q/2})``.
    4. Eliminate every arm whose count already reaches ``floor(lam_k * T)``
       and record that count in ``tau``.

    The loop ends when no arm is active, the budget is spent, or after
    10000 rounds.

    Args:
        T: total pull budget.
        X, y: per-arm lists of contexts / rewards; mutated in place.
        s_minus_fn, s_plus_fn: vectorised width functions of the pull counts.
        gamma_fn: ridge level ``n -> gamma`` forwarded to ``ridge_from_stats``.
        q: allocation exponent.
        context_sampler: callable ``k -> context`` for arm ``k``.
        reward_generators: per-arm callables ``context -> reward``.

    Returns:
        ``(tau, V0, b, s_yy, sx, sy)``:
        - ``tau``: per-arm pull counts at elimination. Arms still active
          when the loop stopped report their final count instead of the
          ``-1`` sentinel.
        - ``V0, b, s_yy, sx, sy``: per-arm sufficient statistics covering
          every sample in ``X``/``y``.
    """
    K = len(X)
    T_int = int(T)

    # Per-arm sufficient statistics of the samples collected so far.
    t_counts = np.array([len(y[k]) for k in range(K)], dtype=int)
    d = X[0][0].shape[0] if (K > 0 and len(X[0]) > 0) else 0
    V0 = [np.zeros((d, d), dtype=float) for _ in range(K)]
    b = [np.zeros(d, dtype=float) for _ in range(K)]
    s_yy = np.zeros(K, dtype=float)
    sx = [np.zeros(d, dtype=float) for _ in range(K)]
    sy = np.zeros(K, dtype=float)

    for k in range(K):
        if t_counts[k] > 0:
            Ck = np.stack(X[k], axis=0)      # (n_k, d)
            yk = np.asarray(y[k], dtype=float)
            V0[k] = Ck.T @ Ck
            b[k] = Ck.T @ yk
            s_yy[k] = float(yk @ yk)
            sx[k] = Ck.sum(axis=0)
            sy[k] = float(yk.sum())

    lam = t_counts.astype(float) / float(T_int)
    active = np.ones(K, dtype=bool)
    tau = np.full(K, -1, dtype=int)      # -1: not eliminated (yet)

    dirty = np.ones(K, dtype=bool)       # arms whose variance estimate is stale
    sigma2_hats = np.zeros(K, dtype=float)

    max_loops = 10000
    for _round in range(max_loops):
        if not active.any():
            break

        remaining = T_int - int(t_counts.sum())
        if remaining <= 0:
            break

        # Top each active arm up to its current target floor(lam_k * T).
        # remaining > 0 holds at the start of every iteration here, so
        # add_k > 0 whenever to_add > 0.
        for k in np.flatnonzero(active):
            target_nk = int(np.floor(lam[k] * T_int))
            to_add = max(0, target_nk - t_counts[k])
            if to_add <= 0:
                continue
            add_k = min(to_add, remaining)

            for _draw in range(add_k):
                xk = np.asarray(context_sampler(k), dtype=float).reshape(-1)
                rk = float(reward_generators[k](xk))
                X[k].append(xk)
                y[k].append(rk)

                V0[k] += np.outer(xk, xk)
                b[k] += xk * rk
                s_yy[k] += rk * rk
                sx[k] += xk
                sy[k] += rk
                t_counts[k] += 1
            dirty[k] = True

            remaining -= add_k
            if remaining <= 0:
                break

        # Refresh the variance estimates of arms that received samples.
        for k in np.flatnonzero(dirty):
            _, s2_k, _, _ = ridge_from_stats(V0[k], b[k], s_yy[k], sx[k], sy[k], t_counts[k], gamma_fn)
            sigma2_hats[k] = s2_k
            dirty[k] = False

        # Multiplicative confidence bounds on each arm's variance.
        s_minus = s_minus_fn(t_counts)   # for UCB: divide by (1 - s_minus)
        s_plus = s_plus_fn(t_counts)     # for LCB: divide by (1 + s_plus)
        denom_lcb = np.maximum(1.0 + s_plus, 1e-8)
        denom_ucb = np.maximum(1.0 - s_minus, 1e-8)
        lcb_vals = sigma2_hats / denom_lcb
        ucb_vals = sigma2_hats / denom_ucb

        # Pessimistic target share of each arm: its own LCB against the
        # other arms' UCBs.
        pow_ = q / 2.0
        lcb_q = np.power(np.maximum(lcb_vals, 0.0), pow_)
        ucb_q = np.power(np.maximum(ucb_vals, 0.0), pow_)
        denom = lcb_q + (ucb_q.sum() - ucb_q)
        denom = np.where(denom <= 0.0, 1e-12, denom)
        lam_new = lcb_q / denom

        lam_thresh = np.floor(np.maximum(lam_new, 0.0) * T_int + 1e-12).astype(int)
        # NOTE(review): `elim` is recomputed over all arms every round, so a
        # previously eliminated arm becomes active again if its threshold
        # rises above its count. Confirm this is intended.
        elim = t_counts >= lam_thresh
        tau[elim] = t_counts[elim]

        active = ~elim
        lam = lam_new

    # Arms still active when the loop stopped (budget spent or loop limit
    # reached) would otherwise report the -1 sentinel or a stale count from
    # an earlier elimination; report their actual sample counts instead.
    unresolved = active | (tau < 0)
    tau[unresolved] = t_counts[unresolved]

    return tau, V0, b, s_yy, sx, sy

def contextual_exploration_free(
    T: int,
    d: int,
    K: int,
    q: float,
    sigma_min: float,
    sigma_SG: float,
    delta: float,
    context_sampler,             # Callable[[int], np.ndarray[d]]
    reward_generators,           # list[Callable[[np.ndarray[d]], float]]
    eigenvalue_minimum,
    gamma_fn,                    # callable: n -> gamma
):
    """Run the three-step allocation with multiplicative (s_minus/s_plus) widths.

    The allocation runs in three steps, then fits a per-arm estimate:

    - Step 1 pulls every arm ``tau_1`` times.
    - Step 2 runs ``step2_adaptive_elimination``.
    - Step 3 spends any leftover budget so that each arm moves towards
      ``floor(w_k * T)`` pulls, with ``w_k`` proportional to
      ``sigma2_hat_k^{q/2}``.
    - Finally, ``ridge_beta_and_resid_var`` fits a per-arm estimate on all
      of that arm's samples.

    ``d`` and ``eigenvalue_minimum`` are accepted but not used in this body.

    Returns:
        ``(beta_hat, n_final, tau)``:
        - ``beta_hat``: list of per-arm coefficient estimates.
        - ``n_final``: final per-arm sample counts.
        - ``tau``: the Step 2 elimination counts, not modified by Step 3.
    """
    T_int = int(T)
    s_minus_fn, s_plus_fn = make_s_funcs(delta)

    X = [[] for _ in range(K)]
    y = [[] for _ in range(K)]

    def _pull(k):
        """Sample a context for arm ``k``, observe its reward and store both."""
        x = np.asarray(context_sampler(k), dtype=float).reshape(-1)
        r = float(reward_generators[k](x))
        X[k].append(x)
        y[k].append(r)

    # ---- Step 1: fixed-length warm-up of tau_1 pulls per arm ----
    # tau_1 = sigma_min^q T / (sigma_min^q + (K-1) sigma_SG^q), then clipped
    # to [2, max(T // K, 2)].
    # NOTE: an adaptive warm-up (pull until s_minus(n) <= 1/m) is not used.
    tau_1 = int((sigma_min**q) * T_int / ((sigma_min**q) + (K - 1) * (sigma_SG**q)))
    tau_1 = max(tau_1, 2)
    tau_1 = min(tau_1, max(T_int // max(K, 1), 2))
    for k in range(K):
        for _ in range(tau_1):
            _pull(k)

    # ---- Step 2: adaptive top-up / elimination ----
    tau, V0, b, s_yy, sx, sy = step2_adaptive_elimination(
        T=T_int, X=X, y=y,
        s_minus_fn=s_minus_fn, s_plus_fn=s_plus_fn,
        gamma_fn=gamma_fn,
        q=q,
        context_sampler=context_sampler,
        reward_generators=reward_generators
    )

    # ---- Step 3: spend the leftover budget ----
    # Work on the true per-arm sample counts (the sufficient statistics
    # cover exactly these samples) and never mutate `tau`, which is returned
    # as the elimination record.
    t_counts = np.array([len(y[k]) for k in range(K)], dtype=int)
    remaining = T_int - int(t_counts.sum())
    if remaining > 0:
        sigma2_hats = np.array(
            [
                ridge_from_stats(V0[k], b[k], s_yy[k], sx[k], sy[k], t_counts[k], gamma_fn)[1]
                for k in range(K)
            ],
            dtype=float,
        )
        weights = np.power(sigma2_hats, q / 2.0)
        total = weights.sum()
        if total > 0:
            weights = weights / total
        else:
            # All variance estimates are zero: split evenly instead of NaN.
            weights = np.full(K, 1.0 / max(K, 1))

        lam_thresh = np.floor(weights * T_int)
        need = np.maximum(lam_thresh - t_counts, 0)
        # Serve the arms with the largest deficit first.
        for k in np.argsort(-need):
            if remaining <= 0:
                break
            add_k = int(min(need[k], remaining))
            for _ in range(add_k):
                _pull(k)
            t_counts[k] += add_k
            remaining -= add_k

    beta_hat = [ridge_beta_and_resid_var(X[k], y[k], gamma_fn=gamma_fn)[0] for k in range(K)]

    return beta_hat, np.array([len(y[k]) for k in range(K)], dtype=int), tau

import numpy as np
import pandas as pd
from tqdm.auto import tqdm



# ---- Simulation configuration (multiplicative-width variant) ----
K_list = [5, 10, 20]                                # numbers of arms to sweep
T_list = np.logspace(3, 6, num=30, dtype=np.int64)  # 30 budgets from 1e3 to 1e6
reps = 10                                           # replications per (K, T)
d = 4                                               # context dimension
q = 1.0                                             # allocation exponent
sigma_min = 1.0                                     # passed as sigma_min (enters tau_1)
base_seed_global = 42                               # base seed for per-replication RNGs


def metric_beta(beta_hat, betas):
    """Return the total squared L2 error ``sum_k ||beta_hat[k] - betas[k]||^2``.

    Iterates over the arms of ``betas``; ``beta_hat`` must have at least as
    many entries.
    """
    total = 0
    for k, truth in enumerate(betas):
        total += np.linalg.norm(beta_hat[k] - truth) ** 2
    return float(total)

# ---- gamma function ----
def make_gamma_fn(eigenvalue_minimum: float):
    """Return ``gamma(n) = eigenvalue_minimum / max(int(n), 1)``."""
    def gamma_fn(n: int) -> float:
        denom = int(n)
        if denom < 1:
            denom = 1
        return float(eigenvalue_minimum) / denom
    return gamma_fn


def make_env_rng(seed, K, d, betas, stds):
    """Build a seeded linear-Gaussian contextual environment.

    Contexts are i.i.d. ``Uniform(-a, a)^d`` with ``a = 3 / sqrt(d)``. Arm
    ``k`` returns ``betas[k] . x + N(0, stds[k]^2)``. A single
    ``np.random.default_rng(seed)`` drives the sampler and every arm.

    Returns:
        ``(rng, context_sampler, reward_generators, eigenvalue_minimum)``.
        ``context_sampler(k)`` ignores ``k``. ``eigenvalue_minimum = a^2 / 3``
        is the variance of ``Uniform(-a, a)``, so a context's covariance is
        ``(a^2 / 3) I``.
    """
    rng = np.random.default_rng(seed)
    half_width = 3.0 / np.sqrt(d)

    def context_sampler(k):
        return rng.uniform(-half_width, half_width, size=d)

    def _arm(theta, noise_sd):
        def gen(x):
            return float(np.dot(theta, x) + rng.normal(scale=noise_sd))
        return gen

    # Coefficients are copied so later edits to `betas` cannot leak in.
    reward_generators = [_arm(betas[k].copy(), float(stds[k])) for k in range(K)]
    return rng, context_sampler, reward_generators, half_width ** 2 / 3.0


def run_metrics_over_T(T_list, reps, K, d, q, sigma_min, betas, stds, base_seed=0):
    """Sweep the budgets in ``T_list`` and collect estimation/regret metrics.

    Each budget ``T`` (index ``iT``) is run for ``reps`` replications
    (index ``rep``). Each replication does the following:

    - Build a fresh environment with seed ``base_seed + 10_000 * iT + rep``.
    - Run ``contextual_exploration_free`` with ``delta = T^{-2.5}``.
    - Record two metrics:
      * ``metric``: ``sum_k ||beta_hat_k - beta_k||^2`` (``metric_beta``).
      * ``regret``: ``compute_regret_scaled`` of the final pull counts.
        The oracle budget is the number of pulls actually made.

    Args:
        T_list: budgets to sweep.
        reps: replications per budget.
        K, d, q, sigma_min: forwarded to the allocator.
        betas: per-arm true coefficients.
        stds: per-arm noise standard deviations (any array-like).
        base_seed: base of the per-replication seeds.

    Returns:
        ``(metric_means, regret_means, all_taus, all_nfinals, records_df)``:
        - ``metric_means``, ``regret_means``: means over replications per ``T``.
        - ``all_taus``, ``all_nfinals``: per-``T`` lists of the
          ``tau``/``n_final`` arrays.
        - ``records_df``: one DataFrame row per replication.
    """
    # Accept lists too: `stds**2` below needs an ndarray.
    stds = np.asarray(stds, dtype=float)
    # Largest noise standard deviation across arms.
    sigma_SG = float(np.sqrt(stds**2).max())

    metric_means = []
    regret_means = []
    all_taus = []
    all_nfinals = []
    records = []

    for iT, T in enumerate(tqdm(T_list, desc="T sweep", unit="T")):
        vals = []
        regrets = []
        taus = []
        nfinals = []

        for rep in tqdm(range(reps), desc=f"reps@T={int(T)}", unit="rep", leave=False):
            seed = base_seed + 10_000 * iT + rep
            _, context_sampler, reward_generators, eigenvalue_minimum = make_env_rng(
                seed, K, d, betas, stds
            )
            delta = 1.0 / (float(T)**2.5)
            gamma_fn = make_gamma_fn(eigenvalue_minimum)

            beta_hat, n_final, tau = contextual_exploration_free(
                T=int(T), d=d, K=K, q=q,
                sigma_min=sigma_min, sigma_SG=sigma_SG, delta=delta,
                context_sampler=context_sampler, reward_generators=reward_generators,
                eigenvalue_minimum=eigenvalue_minimum,
                gamma_fn=gamma_fn
            )

            metric_val = float(metric_beta(beta_hat, betas))
            vals.append(metric_val)

            # The oracle term uses the pulls actually made, not the nominal T.
            n_arr = np.asarray(n_final, dtype=float)
            T_used = float(np.sum(n_arr))
            reg_scaled = compute_regret_scaled(
                all_final_n=np.asarray([n_arr]),  # (runs=1, K)
                stds=stds,
                T=T_used,
                d=d,
                lambda_min_C=float(eigenvalue_minimum)
            )
            reg_scaled = float(reg_scaled)
            regrets.append(reg_scaled)

            tau_arr = np.array(tau, dtype=int)
            nfin_arr = np.array(n_final, dtype=int)
            taus.append(tau_arr)
            nfinals.append(nfin_arr)

            records.append({
                "K": int(K),
                "T": int(T),
                "replication": int(rep + 1),
                "tau": tau_arr.tolist(),
                "n_final": nfin_arr.tolist(),
                "sum_n_final": int(nfin_arr.sum()),
                "regret": reg_scaled,
                "metric": metric_val,
                "variances": (stds**2).tolist(),
                "lambda_min_C": float(eigenvalue_minimum),
                "d": int(d),
                "q": float(q),
            })

        metric_means.append(np.mean(np.asarray(vals, dtype=float)))
        regret_means.append(np.mean(np.asarray(regrets, dtype=float)))
        all_taus.append(taus)
        all_nfinals.append(nfinals)

    records_df = pd.DataFrame(records)

    return (np.asarray(metric_means),
            np.asarray(regret_means),
            all_taus,
            all_nfinals,
            records_df)


# ---- Simulation driver (multiplicative-width allocator) ----
# rng_global only draws the per-K problem instances. Its draw order (betas
# then variances, for each K in K_list) determines those instances, so the
# statements below must stay in this order.
rng_global = np.random.default_rng(123)

all_results = []   # per-K summary dicts (T grid, mean metric, mean regret)
all_records = []   # per-K DataFrames with one row per (T, replication)
agg_rows = []      # per-K DataFrames of means by T

for K in K_list:

    # K-arm instance: coefficients ~ Uniform(-2, 2)^d and noise variances
    # ~ Uniform(1, 4), both rounded to one decimal.
    betas_K = [np.round(rng_global.uniform(-2, 2, size=d), 1) for _ in range(K)]
    variances_K = np.round(rng_global.uniform(1, 4, size=K), 1)
    stds_K = np.sqrt(variances_K)


    metric_mean, regret_mean, taus, nfinals, records_df = run_metrics_over_T(
        T_list=T_list, reps=reps, K=K, d=d, q=q, sigma_min=sigma_min,
        betas=betas_K, stds=stds_K, base_seed=base_seed_global
    )


    # Per-replication records for this K.
    out_file = f"contextual_records_K{K}.csv"
    records_df.to_csv(out_file, index=False)
    print(f"[K={K}] saved: {out_file}")


    # Means over replications, one row per T.
    df_mean_K = pd.DataFrame({
        "K": K,
        "T": T_list.astype(np.int64),
        "metric_mean": metric_mean,
        "regret_mean": regret_mean,
    })
    df_mean_K.to_csv(f"metrics_mean_by_T_K{K}.csv", index=False)
    agg_rows.append(df_mean_K)


    all_results.append({
        "K": K,
        "T_list": T_list.copy(),
        "metric_mean": metric_mean,
        "regret_mean": regret_mean
    })
    all_records.append(records_df)


# Combined outputs across all K; these are the files the plotting cell reads.
metrics_mean_all = pd.concat(agg_rows, ignore_index=True)
metrics_mean_all.to_csv("metrics_mean_by_K_T.csv", index=False)
print("saved: metrics_mean_by_K_T.csv")


records_all = pd.concat(all_records, ignore_index=True)
records_all.to_csv("contextual_records_all.csv", index=False)
print("saved: contextual_records_all.csv")


# np.savez("contextual_results_summary.npz", results=all_results)

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import ast


# ---- Plot setup: load results of the multiplicative-width simulation ----
df_records = pd.read_csv("contextual_records_all.csv")
df_mean = pd.read_csv("metrics_mean_by_K_T.csv")

# Environment constants; keep in sync with make_env_rng
# (contexts ~ Uniform(-a, a)^d with a = R / sqrt(d)).
d, R = 4, 3.0
a = R / np.sqrt(d)
lambda_min_C = a ** 2 / 3.0

# One colour per number of arms.
color_map = dict(zip((5, 10, 20), ("red", "blue", "green")))

plt.rcParams.update(
    {
        "font.size": 18,
        "axes.labelsize": 22,
        "legend.fontsize": 12,
        "xtick.labelsize": 18,
        "ytick.labelsize": 18,
    }
)

def get_stds_for_K(K):
    """Return the per-arm noise standard deviations recorded for ``K`` arms.

    Reads the ``variances`` column, a stringified list, from the first row of
    ``df_records`` whose ``K`` matches. Returns the element-wise square root.
    """
    first = df_records.loc[df_records["K"] == K].iloc[0]
    variances = np.asarray(ast.literal_eval(first["variances"]), dtype=float)
    return np.sqrt(variances)

# Plot log(mean scaled regret) against log(T) for each K, together with the
# reference curve UB(T) = C * log(T) / T^2.
# The plotting order fixes the legend order.
fig, ax = plt.subplots(figsize=(8, 6))
eps = 1e-12  # floor applied before taking logs (guards non-positive values)

for K in sorted(df_mean["K"].unique()):
    sub = df_mean[df_mean["K"] == K].copy().sort_values("T")
    T_vals = sub["T"].to_numpy(dtype=float)
    regret_mean = sub["regret_mean"].to_numpy(dtype=float)

    # UB(T) = C * T^{-2} * log T with C = 5 d K (sum_k sigma_k)^2 / lambda_min_C
    stds_K = get_stds_for_K(K)
    Sigma1 = stds_K.sum()
    C = 5.0 * d * K * (Sigma1**2) / lambda_min_C
    UB = C * (T_vals**(-2.0)) * np.log(np.maximum(T_vals, 1.0))


    x = np.log(np.maximum(T_vals, eps))
    y_reg = np.log(np.maximum(regret_mean, eps))
    y_ub  = np.log(np.maximum(UB, eps))

    color = color_map.get(K, "black")  # K values outside color_map are drawn in black
    ax.plot(x, y_reg, "-o", color=color, linewidth=3, markersize=6, label=f"K={K} Regret")
    ax.plot(x, y_ub,  "--", color=color, linewidth=3, label=f"K={K} Upper Bound")

ax.set_xlabel(r"$\log T$")
ax.set_ylabel(r"$\log \mathrm{Regret}$")
ax.grid(True, which="both", alpha=0.4)
ax.legend(ncol=2)
fig.tight_layout()
plt.show()


def step2_adaptive_elimination_GSG(
    T: int,
    X: list[list[np.ndarray]],
    y: list[list[float]],
    eps_minus_fn,                          # Callable[[np.ndarray], np.ndarray]
    eps_plus_fn,                           # Callable[[np.ndarray], np.ndarray]
    gamma_fn,                            # gamma_k(n_k)
    q: float = 1.0,
    context_sampler=None,
    reward_generators=None,
):
    """Step 2: adaptive top-up and elimination with additive (GSG) widths.

    Starting from the samples already stored in ``X``/``y``, each round does
    the following. ``X``/``y`` are extended in place with every new pull.

    1. Top every active arm ``k`` up to ``floor(lam[k] * T)`` pulls, never
       exceeding the remaining budget ``T - sum(counts)``.
    2. Refit ``ridge_from_stats`` for arms that received samples, updating
       their residual-variance estimate ``sigma2_hat``.
    3. Form ``LCB = max(sigma2_hat - eps_plus(n), 0)`` and
       ``UCB = sigma2_hat + eps_minus(n)`` and the target shares
       ``lam_k = LCB_k^{q/2} / (LCB_k^{q/2} + sum_{j != k} UCB_j^{q/2})``.
    4. Eliminate every arm whose count already reaches ``floor(lam_k * T)``
       and record that count in ``tau``.

    The loop ends when no arm is active, the budget is spent, or after
    10000 rounds.

    Args:
        T: total pull budget.
        X, y: per-arm lists of contexts / rewards; mutated in place.
        eps_minus_fn, eps_plus_fn: vectorised width functions of the counts.
        gamma_fn: ridge level ``n -> gamma`` forwarded to ``ridge_from_stats``.
        q: allocation exponent.
        context_sampler: callable ``k -> context`` for arm ``k``.
        reward_generators: per-arm callables ``context -> reward``.

    Returns:
        ``(tau, V0, b, s_yy, sx, sy)``:
        - ``tau``: per-arm pull counts at elimination. Arms still active
          when the loop stopped report their final count instead of the
          ``-1`` sentinel.
        - ``V0, b, s_yy, sx, sy``: per-arm sufficient statistics covering
          every sample in ``X``/``y``.
    """
    K = len(X)
    T_int = int(T)

    # Per-arm sufficient statistics of the samples collected so far.
    t_counts = np.array([len(y[k]) for k in range(K)], dtype=int)
    d = X[0][0].shape[0] if (K > 0 and len(X[0]) > 0) else 0
    V0 = [np.zeros((d, d), dtype=float) for _ in range(K)]
    b = [np.zeros(d, dtype=float) for _ in range(K)]
    s_yy = np.zeros(K, dtype=float)
    sx = [np.zeros(d, dtype=float) for _ in range(K)]
    sy = np.zeros(K, dtype=float)

    for k in range(K):
        if t_counts[k] > 0:
            Ck = np.stack(X[k], axis=0)      # (n_k, d)
            yk = np.asarray(y[k], dtype=float)
            V0[k] = Ck.T @ Ck
            b[k] = Ck.T @ yk
            s_yy[k] = float(yk @ yk)
            sx[k] = Ck.sum(axis=0)
            sy[k] = float(yk.sum())

    lam = t_counts.astype(float) / float(T_int)
    active = np.ones(K, dtype=bool)
    tau = np.full(K, -1, dtype=int)      # -1: not eliminated (yet)

    dirty = np.ones(K, dtype=bool)       # arms whose variance estimate is stale
    sigma2_hats = np.zeros(K, dtype=float)

    max_loops = 10000
    for _round in range(max_loops):
        if not active.any():
            break

        remaining = T_int - int(t_counts.sum())
        if remaining <= 0:
            break

        # Top each active arm up to its current target floor(lam_k * T).
        # remaining > 0 holds at the start of every iteration here, so
        # add_k > 0 whenever to_add > 0.
        for k in np.flatnonzero(active):
            target_nk = int(np.floor(lam[k] * T_int))
            to_add = max(0, target_nk - t_counts[k])
            if to_add <= 0:
                continue
            add_k = min(to_add, remaining)

            for _draw in range(add_k):
                xk = np.asarray(context_sampler(k), dtype=float).reshape(-1)
                rk = float(reward_generators[k](xk))
                X[k].append(xk)
                y[k].append(rk)

                V0[k] += np.outer(xk, xk)
                b[k] += xk * rk
                s_yy[k] += rk * rk
                sx[k] += xk
                sy[k] += rk
                t_counts[k] += 1
            dirty[k] = True

            remaining -= add_k
            if remaining <= 0:
                break

        # Refresh the variance estimates of arms that received samples.
        for k in np.flatnonzero(dirty):
            _, s2_k, _, _ = ridge_from_stats(V0[k], b[k], s_yy[k], sx[k], sy[k], t_counts[k], gamma_fn)
            sigma2_hats[k] = s2_k
            dirty[k] = False

        # Additive confidence bounds on each arm's variance.
        eps_minus = eps_minus_fn(t_counts)
        eps_plus = eps_plus_fn(t_counts)
        lcb_vals = np.maximum(sigma2_hats - eps_plus, 0.0).astype(float)
        ucb_vals = (sigma2_hats + eps_minus).astype(float)

        # Pessimistic target share of each arm: its own LCB against the
        # other arms' UCBs.
        pow_ = q / 2.0
        lcb_q = np.power(np.maximum(lcb_vals, 0.0), pow_)
        ucb_q = np.power(np.maximum(ucb_vals, 0.0), pow_)
        denom = lcb_q + (ucb_q.sum() - ucb_q)
        denom = np.where(denom <= 0.0, 1e-12, denom)
        lam_new = lcb_q / denom

        lam_thresh = np.floor(np.maximum(lam_new, 0.0) * T_int + 1e-12).astype(int)
        # NOTE(review): `elim` is recomputed over all arms every round, so a
        # previously eliminated arm becomes active again if its threshold
        # rises above its count. Confirm this is intended.
        elim = t_counts >= lam_thresh
        tau[elim] = t_counts[elim]

        active = ~elim
        lam = lam_new

    # Arms still active when the loop stopped (budget spent or loop limit
    # reached) would otherwise report the -1 sentinel or a stale count from
    # an earlier elimination; report their actual sample counts instead.
    unresolved = active | (tau < 0)
    tau[unresolved] = t_counts[unresolved]

    return tau, V0, b, s_yy, sx, sy

def contextual_exploration_free_GSG(
    T: int,
    d: int,
    K: int,
    q: float,
    sigma_min: float,
    sigma_SG: float,
    delta: float,
    context_sampler,             # Callable[[int], np.ndarray[d]]
    reward_generators,           # list[Callable[[np.ndarray[d]], float]]
    eigenvalue_minimum,
    gamma_fn,                    # callable: n -> gamma
):
    """Run the three-step allocation with additive (GSG, eps_minus/eps_plus) widths.

    The allocation runs in three steps, then fits a per-arm estimate:

    - Step 1 pulls every arm ``tau_1`` times.
    - Step 2 runs ``step2_adaptive_elimination_GSG`` with widths from
      ``make_eps_funcs(sigma_SG, delta)``.
    - Step 3 spends any leftover budget so that each arm moves towards
      ``floor(w_k * T)`` pulls, with ``w_k`` proportional to
      ``sigma2_hat_k^{q/2}``.
    - Finally, ``ridge_beta_and_resid_var`` fits a per-arm estimate on all
      of that arm's samples.

    ``d`` and ``eigenvalue_minimum`` are accepted but not used in this body.

    Returns:
        ``(beta_hat, n_final, tau)``:
        - ``beta_hat``: list of per-arm coefficient estimates.
        - ``n_final``: final per-arm sample counts.
        - ``tau``: the Step 2 elimination counts, not modified by Step 3.
    """
    T_int = int(T)
    eps_minus_fn, eps_plus_fn = make_eps_funcs(sigma_SG, delta)

    X = [[] for _ in range(K)]
    y = [[] for _ in range(K)]

    def _pull(k):
        """Sample a context for arm ``k``, observe its reward and store both."""
        x = np.asarray(context_sampler(k), dtype=float).reshape(-1)
        r = float(reward_generators[k](x))
        X[k].append(x)
        y[k].append(r)

    # ---- Step 1: fixed-length warm-up of tau_1 pulls per arm ----
    # tau_1 = sigma_min^q T / (sigma_min^q + (K-1) sigma_SG^q), then clipped
    # to [2, max(T // K, 2)].
    # NOTE: an adaptive warm-up (pull until s_minus(n) <= 1/m) is not used.
    tau_1 = int((sigma_min**q) * T_int / ((sigma_min**q) + (K - 1) * (sigma_SG**q)))
    tau_1 = max(tau_1, 2)
    tau_1 = min(tau_1, max(T_int // max(K, 1), 2))
    for k in range(K):
        for _ in range(tau_1):
            _pull(k)

    # ---- Step 2: adaptive top-up / elimination ----
    tau, V0, b, s_yy, sx, sy = step2_adaptive_elimination_GSG(
        T=T_int, X=X, y=y,
        eps_minus_fn=eps_minus_fn, eps_plus_fn=eps_plus_fn,
        gamma_fn=gamma_fn,
        q=q,
        context_sampler=context_sampler,
        reward_generators=reward_generators
    )

    # ---- Step 3: spend the leftover budget ----
    # Work on the true per-arm sample counts (the sufficient statistics
    # cover exactly these samples) and never mutate `tau`, which is returned
    # as the elimination record.
    t_counts = np.array([len(y[k]) for k in range(K)], dtype=int)
    remaining = T_int - int(t_counts.sum())
    if remaining > 0:
        sigma2_hats = np.array(
            [
                ridge_from_stats(V0[k], b[k], s_yy[k], sx[k], sy[k], t_counts[k], gamma_fn)[1]
                for k in range(K)
            ],
            dtype=float,
        )
        weights = np.power(sigma2_hats, q / 2.0)
        total = weights.sum()
        if total > 0:
            weights = weights / total
        else:
            # All variance estimates are zero: split evenly instead of NaN.
            weights = np.full(K, 1.0 / max(K, 1))

        lam_thresh = np.floor(weights * T_int)
        need = np.maximum(lam_thresh - t_counts, 0)
        # Serve the arms with the largest deficit first.
        for k in np.argsort(-need):
            if remaining <= 0:
                break
            add_k = int(min(need[k], remaining))
            for _ in range(add_k):
                _pull(k)
            t_counts[k] += add_k
            remaining -= add_k

    beta_hat = [ridge_beta_and_resid_var(X[k], y[k], gamma_fn=gamma_fn)[0] for k in range(K)]

    return beta_hat, np.array([len(y[k]) for k in range(K)], dtype=int), tau

import numpy as np
import pandas as pd
from tqdm.auto import tqdm



# ---- Simulation configuration (GSG / additive-width variant) ----
K_list = [5, 10, 20]                                # numbers of arms to sweep
T_list = np.logspace(3, 6, num=30, dtype=np.int64)  # 30 budgets from 1e3 to 1e6
reps = 10                                           # replications per (K, T)
d = 4                                               # context dimension
q = 1.0                                             # allocation exponent
sigma_min = 1.0                                     # passed as sigma_min (enters tau_1)
base_seed_global = 42                               # base seed for per-replication RNGs


def metric_beta(beta_hat, betas):
    """Return the total squared L2 error ``sum_k ||beta_hat[k] - betas[k]||^2``.

    Iterates over the arms of ``betas``; ``beta_hat`` must have at least as
    many entries.
    """
    total = 0
    for k, truth in enumerate(betas):
        total += np.linalg.norm(beta_hat[k] - truth) ** 2
    return float(total)

# ---- gamma function ----
def make_gamma_fn(eigenvalue_minimum: float):
    """Return ``gamma(n) = eigenvalue_minimum / max(int(n), 1)``."""
    def gamma_fn(n: int) -> float:
        denom = int(n)
        if denom < 1:
            denom = 1
        return float(eigenvalue_minimum) / denom
    return gamma_fn


def make_env_rng(seed, K, d, betas, stds):
    """Build a seeded linear-Gaussian contextual environment.

    Contexts are i.i.d. ``Uniform(-a, a)^d`` with ``a = 3 / sqrt(d)``. Arm
    ``k`` returns ``betas[k] . x + N(0, stds[k]^2)``. A single
    ``np.random.default_rng(seed)`` drives the sampler and every arm.

    Returns:
        ``(rng, context_sampler, reward_generators, eigenvalue_minimum)``.
        ``context_sampler(k)`` ignores ``k``. ``eigenvalue_minimum = a^2 / 3``
        is the variance of ``Uniform(-a, a)``, so a context's covariance is
        ``(a^2 / 3) I``.
    """
    rng = np.random.default_rng(seed)
    half_width = 3.0 / np.sqrt(d)

    def context_sampler(k):
        return rng.uniform(-half_width, half_width, size=d)

    def _arm(theta, noise_sd):
        def gen(x):
            return float(np.dot(theta, x) + rng.normal(scale=noise_sd))
        return gen

    # Coefficients are copied so later edits to `betas` cannot leak in.
    reward_generators = [_arm(betas[k].copy(), float(stds[k])) for k in range(K)]
    return rng, context_sampler, reward_generators, half_width ** 2 / 3.0


def run_metrics_over_T(T_list, reps, K, d, q, sigma_min, betas, stds, base_seed=0):
    """Sweep the budgets in ``T_list`` with the GSG allocator and collect metrics.

    Each budget ``T`` (index ``iT``) is run for ``reps`` replications
    (index ``rep``). Each replication does the following:

    - Build a fresh environment with seed ``base_seed + 10_000 * iT + rep``.
    - Run ``contextual_exploration_free_GSG`` with ``delta = T^{-2.5}``.
    - Record two metrics:
      * ``metric``: ``sum_k ||beta_hat_k - beta_k||^2`` (``metric_beta``).
      * ``regret``: ``compute_regret_scaled`` of the final pull counts.
        The oracle budget is the number of pulls actually made.

    Args:
        T_list: budgets to sweep.
        reps: replications per budget.
        K, d, q, sigma_min: forwarded to the allocator.
        betas: per-arm true coefficients.
        stds: per-arm noise standard deviations (any array-like).
        base_seed: base of the per-replication seeds.

    Returns:
        ``(metric_means, regret_means, all_taus, all_nfinals, records_df)``:
        - ``metric_means``, ``regret_means``: means over replications per ``T``.
        - ``all_taus``, ``all_nfinals``: per-``T`` lists of the
          ``tau``/``n_final`` arrays.
        - ``records_df``: one DataFrame row per replication.
    """
    # Accept lists too: `stds**2` below needs an ndarray.
    stds = np.asarray(stds, dtype=float)
    # Largest noise standard deviation across arms.
    sigma_SG = float(np.sqrt(stds**2).max())

    metric_means = []
    regret_means = []
    all_taus = []
    all_nfinals = []
    records = []

    for iT, T in enumerate(tqdm(T_list, desc="T sweep", unit="T")):
        vals = []
        regrets = []
        taus = []
        nfinals = []

        for rep in tqdm(range(reps), desc=f"reps@T={int(T)}", unit="rep", leave=False):
            seed = base_seed + 10_000 * iT + rep
            _, context_sampler, reward_generators, eigenvalue_minimum = make_env_rng(
                seed, K, d, betas, stds
            )
            delta = 1.0 / (float(T)**2.5)
            gamma_fn = make_gamma_fn(eigenvalue_minimum)

            beta_hat, n_final, tau = contextual_exploration_free_GSG(
                T=int(T), d=d, K=K, q=q,
                sigma_min=sigma_min, sigma_SG=sigma_SG, delta=delta,
                context_sampler=context_sampler, reward_generators=reward_generators,
                eigenvalue_minimum=eigenvalue_minimum,
                gamma_fn=gamma_fn
            )

            metric_val = float(metric_beta(beta_hat, betas))
            vals.append(metric_val)

            # The oracle term uses the pulls actually made, not the nominal T.
            n_arr = np.asarray(n_final, dtype=float)
            T_used = float(np.sum(n_arr))
            reg_scaled = compute_regret_scaled(
                all_final_n=np.asarray([n_arr]),  # (runs=1, K)
                stds=stds,
                T=T_used,
                d=d,
                lambda_min_C=float(eigenvalue_minimum)
            )
            reg_scaled = float(reg_scaled)
            regrets.append(reg_scaled)

            tau_arr = np.array(tau, dtype=int)
            nfin_arr = np.array(n_final, dtype=int)
            taus.append(tau_arr)
            nfinals.append(nfin_arr)

            records.append({
                "K": int(K),
                "T": int(T),
                "replication": int(rep + 1),
                "tau": tau_arr.tolist(),
                "n_final": nfin_arr.tolist(),
                "sum_n_final": int(nfin_arr.sum()),
                "regret": reg_scaled,
                "metric": metric_val,
                "variances": (stds**2).tolist(),
                "lambda_min_C": float(eigenvalue_minimum),
                "d": int(d),
                "q": float(q),
            })

        metric_means.append(np.mean(np.asarray(vals, dtype=float)))
        regret_means.append(np.mean(np.asarray(regrets, dtype=float)))
        all_taus.append(taus)
        all_nfinals.append(nfinals)

    records_df = pd.DataFrame(records)

    return (np.asarray(metric_means),
            np.asarray(regret_means),
            all_taus,
            all_nfinals,
            records_df)


# ---- Simulation driver (GSG / additive-width allocator) ----
# rng_global is re-seeded with 123, so with the same d and K_list it draws the
# same per-K problem instances (betas then variances, per K) as the
# multiplicative-width run. Keep the draw order below unchanged.
rng_global = np.random.default_rng(123)

all_results = []   # per-K summary dicts (T grid, mean metric, mean regret)
all_records = []   # per-K DataFrames with one row per (T, replication)
agg_rows = []      # per-K DataFrames of means by T

for K in K_list:

    # K-arm instance: coefficients ~ Uniform(-2, 2)^d and noise variances
    # ~ Uniform(1, 4), both rounded to one decimal.
    betas_K = [np.round(rng_global.uniform(-2, 2, size=d), 1) for _ in range(K)]
    variances_K = np.round(rng_global.uniform(1, 4, size=K), 1)
    stds_K = np.sqrt(variances_K)

    metric_mean, regret_mean, taus, nfinals, records_df = run_metrics_over_T(
        T_list=T_list, reps=reps, K=K, d=d, q=q, sigma_min=sigma_min,
        betas=betas_K, stds=stds_K, base_seed=base_seed_global
    )

    # Per-replication records for this K. The "_GSG" suffix keeps these
    # files from overwriting the multiplicative-width run's per-K outputs,
    # matching the naming of the aggregate files below.
    out_file = f"contextual_records_K{K}_GSG.csv"
    records_df.to_csv(out_file, index=False)
    print(f"[K={K}] saved: {out_file}")

    # Means over replications, one row per T.
    df_mean_K = pd.DataFrame({
        "K": K,
        "T": T_list.astype(np.int64),
        "metric_mean": metric_mean,
        "regret_mean": regret_mean,
    })
    df_mean_K.to_csv(f"metrics_mean_by_T_K{K}_GSG.csv", index=False)
    agg_rows.append(df_mean_K)

    all_results.append({
        "K": K,
        "T_list": T_list.copy(),
        "metric_mean": metric_mean,
        "regret_mean": regret_mean
    })
    all_records.append(records_df)


# Combined outputs across all K; these are the files the GSG plotting cell reads.
metrics_mean_all = pd.concat(agg_rows, ignore_index=True)
metrics_mean_all.to_csv("metrics_mean_by_K_T_GSG.csv", index=False)
print("saved: metrics_mean_by_K_T_GSG.csv")


records_all = pd.concat(all_records, ignore_index=True)
records_all.to_csv("contextual_records_all_GSG.csv", index=False)
print("saved: contextual_records_all_GSG.csv")


# np.savez("contextual_results_summary.npz", results=all_results)

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import ast


# ---- Plot setup: load results of the GSG simulation ----
df_records = pd.read_csv("contextual_records_all_GSG.csv")
df_mean = pd.read_csv("metrics_mean_by_K_T_GSG.csv")

# Environment constants; keep in sync with make_env_rng
# (contexts ~ Uniform(-a, a)^d with a = R / sqrt(d)).
d, R = 4, 3.0
a = R / np.sqrt(d)
lambda_min_C = a ** 2 / 3.0
# NOTE(review): presumably the largest possible noise std, since variances
# are drawn from [1, 4]. Confirm against the simulation settings.
sigma_SG = 2.0

# One colour per number of arms.
color_map = dict(zip((5, 10, 20), ("red", "blue", "green")))

plt.rcParams.update(
    {
        "font.size": 18,
        "axes.labelsize": 22,
        "legend.fontsize": 12,
        "xtick.labelsize": 18,
        "ytick.labelsize": 18,
    }
)

# Line styling and the floor applied before taking logs.
LINEWIDTH, MARKERSIZE, EPS = 3.0, 6, 1e-12


def get_stds_for_K(K):
    """Return the per-arm noise standard deviations recorded for ``K`` arms.

    Reads the ``variances`` column, a stringified list, from the first row of
    ``df_records`` whose ``K`` matches. Returns the element-wise square root.
    """
    first = df_records.loc[df_records["K"] == K].iloc[0]
    variances = np.asarray(ast.literal_eval(first["variances"]), dtype=float)
    return np.sqrt(variances)


def Sigma(variances, alpha):
    """Return ``sum_k sigma_k^alpha`` given ``variances[k] = sigma_k^2``."""
    sig2 = np.asarray(variances, dtype=float)
    return (sig2 ** (0.5 * alpha)).sum()

def F_p3(variances, p):
    """Return ``p^2 / (p + 1) * Sigma(q)^(2/q) * Sigma(-4)`` with ``q = 2p / (p + 1)``.

    ``Sigma(alpha)`` is ``sum_k sigma_k^alpha`` over the supplied variances
    (``variances[k] = sigma_k^2``).
    """
    q_loc = 2.0 * p / (p + 1.0)
    power_term = Sigma(variances, q_loc) ** (2.0 / q_loc)
    inverse_term = Sigma(variances, -4.0)
    return p ** 2 * power_term * inverse_term / (p + 1.0)


# Plot log(mean scaled regret) of the GSG run against log(T) for each K,
# together with the reference curve UB(T) = C * log(T) / T^2.
# The plotting order fixes the legend order.
fig, ax = plt.subplots(figsize=(8, 6))

for K in sorted(df_mean["K"].unique()):
    sub = df_mean[df_mean["K"] == K].copy().sort_values("T")
    T_vals = sub["T"].to_numpy(dtype=float)
    regret_mean = sub["regret_mean"].to_numpy(dtype=float)


    stds_K = get_stds_for_K(K)
    variances_K = stds_K**2



    # C = 80 d F_p3(variances, p=1) sigma_SG^2 / lambda_min_C
    C = 80.0 * d * F_p3(variances_K, p=1) * (sigma_SG**2) / lambda_min_C
    UB = C * (T_vals**(-2.0)) * np.log(np.maximum(T_vals, 1.0))


    # Logs are taken after flooring at EPS (guards non-positive values).
    x = np.log(np.maximum(T_vals, EPS))
    y_reg = np.log(np.maximum(regret_mean, EPS))
    y_ub  = np.log(np.maximum(UB, EPS))

    color = color_map.get(K, "black")  # K values outside color_map are drawn in black


    ax.plot(x, y_reg, "-o", color=color, linewidth=LINEWIDTH, markersize=MARKERSIZE,
            label=f"K={K} Regret")

    ax.plot(x, y_ub,  "--", color=color, linewidth=LINEWIDTH,
            label=f"K={K} Upper Bound")

ax.set_xlabel(r"$\log T$")
ax.set_ylabel(r"$\log \mathrm{Regret}$")
ax.grid(True, which="both", alpha=0.4)


ax.legend(loc="upper right", fontsize=12, frameon=True, framealpha=0.9, ncol=1)

fig.tight_layout()
plt.show()
