{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import json\n",
    "import torch \n",
    "import random\n",
    "import traceback\n",
    "import numpy as np\n",
    "import torch.nn as nn\n",
    "import torch.optim as optim\n",
    "import matplotlib.pyplot as plt\n",
    "import torch.nn.functional as F\n",
    "\n",
    "from typing import List, Tuple\n",
    "from utils import  judge_multi_choice\n",
    "from confseq.betting import betting_lower_cs, betting_cs\n",
    "from confseq.conjmix_bounded import conjmix_empbern_lower_cs\n",
    "from confseq.predmix import predmix_empbern_lower_cs, predmix_empbern_twosided_cs\n",
    "from confseq.boundaries import normal_mixture_bound, gamma_exponential_mixture_bound\n",
    "from confseq.conjmix_bounded import conjmix_empbern_lower_cs, conjmix_empbern_twosided_cs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "plt.rcParams['font.family'] = 'Times New Roman'\n",
    "def set_seed(seed: int):\n",
    "    import os\n",
    "    import random\n",
    "    import numpy as np\n",
    "    import torch\n",
    "    random.seed(seed)\n",
    "    np.random.seed(seed)\n",
    "    torch.manual_seed(seed)\n",
    "    torch.cuda.manual_seed(seed)\n",
    "    torch.cuda.manual_seed_all(seed)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "Mod = 16  # 16QAM\n",
    "\n",
    "def complex_normal(mean, cov):\n",
    "    \"\"\"\n",
    "    Circularly-symmetric complex normal CN(mean, cov) sampler.\n",
    "    \"\"\"\n",
    "    m = mean.shape[0]\n",
    "    x = np.random.randn(m) + 1j*np.random.randn(m)\n",
    "    jitter = 0.0\n",
    "    for _ in range(3):\n",
    "        try:\n",
    "            L = np.linalg.cholesky(cov + jitter*np.eye(cov.shape[0]))\n",
    "            break\n",
    "        except np.linalg.LinAlgError:\n",
    "            jitter = max(1e-12, 10*(jitter if jitter > 0 else 1e-12))\n",
    "    else:\n",
    "        # 仍失败时用特征分解兜底\n",
    "        w, V = np.linalg.eigh((cov + cov.conj().T)/2)\n",
    "        w = np.clip(w, 0, None)\n",
    "        L = V @ np.diag(np.sqrt(w)) @ V.conj().T\n",
    "    return mean + L @ (x/np.sqrt(2.0))\n",
    "\n",
    "def kron(a,b):\n",
    "    return np.kron(a,b)\n",
    "\n",
    "def vec(A):\n",
    "    return A.reshape(-1, order='F')  # column-stacking\n",
    "\n",
    "# =============================\n",
    "# I-A: Channel Model pieces\n",
    "# =============================\n",
    "\n",
    "def ula_response(M:int, theta:float) -> np.ndarray:\n",
    "    \"\"\"\n",
    "    ULA response vector a(theta) for half-wavelength spacing.\n",
    "    a(theta) in C^{M}\n",
    "    \"\"\"\n",
    "    m = np.arange(M)\n",
    "    return np.exp(1j*np.pi*m*np.sin(theta))\n",
    "\n",
    "def wrap_dist(delta, theta):\n",
    "    \"\"\"\n",
    "    Wrap-around angular distance on [-pi, pi].\n",
    "    \"\"\"\n",
    "    d = delta - theta\n",
    "    d = (d + np.pi) % (2*np.pi) - np.pi\n",
    "    return np.abs(d)\n",
    "\n",
    "def laplace_density(theta_grid, delta, sigma_asd):\n",
    "    \"\"\"\n",
    "    Laplace-like power density g(theta; delta) ∝ exp( - d2π(delta, theta) / sigma_asd ).\n",
    "    \"\"\"\n",
    "    d = np.array([wrap_dist(delta, th) for th in theta_grid])\n",
    "    g = np.exp(-d / max(sigma_asd, 1e-12))\n",
    "    g /= (np.trapz(g, theta_grid) + 1e-12)\n",
    "    return g\n",
    "\n",
    "def spatial_covariance_ula(M:int, delta:float, sigma_asd:float, num_theta:int=1024) -> np.ndarray:\n",
    "    \"\"\"\n",
    "    Numerical integral Ct = ∫ g(θ;δ) a(θ)a(θ)^H dθ\n",
    "    \"\"\"\n",
    "    theta_grid = np.linspace(-np.pi, np.pi, num_theta)\n",
    "    g = laplace_density(theta_grid, delta, sigma_asd)\n",
    "    A = np.stack([ula_response(M, th) for th in theta_grid], axis=1)  # M x T\n",
    "    dth = (theta_grid[-1]-theta_grid[0])/(num_theta-1)\n",
    "    Ct = (A * g) @ A.conj().T * dth\n",
    "    return (Ct + Ct.conj().T)/2\n",
    "\n",
    "def estimate_cov_noise_psd_ls_from_blocks(Yc_list: List[np.ndarray],\n",
    "                                          xc: np.ndarray,\n",
    "                                          M: int, nc: int) -> Tuple[np.ndarray, float]:\n",
    "    L = len(Yc_list)\n",
    "    Mn = M*nc\n",
    "    Sigma_tilde = np.zeros((Mn, Mn), dtype=complex)\n",
    "    for Yc in Yc_list:\n",
    "        y_stack = vec(Yc)[:, None]              # (Mn, 1)\n",
    "        Sigma_tilde += y_stack @ y_stack.conj().T\n",
    "    Sigma_tilde /= max(L, 1)\n",
    "    Sigma_tilde = (Sigma_tilde + Sigma_tilde.conj().T)/2  # \n",
    "\n",
    "    # Step 3: P_perp = I - Xc (Xc^H Xc)^{-1} Xc^H\n",
    "    Xc = kron(xc.reshape(-1,1), np.eye(M))      # (Mn, M)\n",
    "    xc_norm2 = float((xc.conj().T @ xc).real) + 1e-12\n",
    "    Ginv = np.eye(M) / xc_norm2                 # (Xc^H Xc)^{-1} = (1/||xc||^2) I\n",
    "    Proj = Xc @ Ginv @ Xc.conj().T              # (Mn, Mn)\n",
    "    P_perp = np.eye(Mn) - Proj\n",
    "\n",
    "    # σ̂^2_LS = 1/(M (nc-1)) * trace(P_perp Σ̃c P_perp)\n",
    "    tmp = P_perp @ Sigma_tilde @ P_perp\n",
    "    sigma2_ls = np.trace(tmp).real / (M * max(nc-1, 1))\n",
    "\n",
    "    # Step 4: Whitening Xc' = Xc (Xc^H Xc)^(-1/2) = Xc * (1/||xc||) I\n",
    "    Xc_prime = Xc / (np.sqrt(xc_norm2) + 1e-12)\n",
    "\n",
    "    # Step 5: EVD of B = Xc'^H Σ̃c Xc'\n",
    "    B = Xc_prime.conj().T @ Sigma_tilde @ Xc_prime\n",
    "    B = (B + B.conj().T)/2\n",
    "    lam, V = np.linalg.eigh(B)\n",
    "    idx = np.argsort(lam)[::-1]\n",
    "    lam = lam[idx].real\n",
    "    V = V[:, idx]\n",
    "\n",
    "    sigma2 = float(sigma2_ls)\n",
    "    d = np.zeros(M, dtype=float)\n",
    "    Z = []  \n",
    "    for i in range(M-1, -1, -1):  # \n",
    "        if (lam[i] - sigma2) < 0:\n",
    "            d[i] = 0.0\n",
    "            Z.append(i)\n",
    "            sigma2 = (M*(nc-1)*sigma2_ls + float(np.sum(lam[Z]))) / (M*(nc-1) + len(Z))\n",
    "        else:\n",
    "            d[i] = lam[i] - sigma2\n",
    "\n",
    "    D = np.diag(d)\n",
    "\n",
    "    C_hat = (V @ D @ V.conj().T) / (xc_norm2 + 1e-24)\n",
    "    C_hat = (C_hat + C_hat.conj().T)/2\n",
    "    sigma2_hat = float(sigma2)\n",
    "    return C_hat, sigma2_hat\n",
    "\n",
    "def mmse_channel_estimate(C_hat: np.ndarray, sigma2_hat: float, xh: np.ndarray, yh: np.ndarray, M:int) -> np.ndarray:\n",
    "    nh = xh.shape[0]\n",
    "    Xh = kron(xh.reshape(-1,1), np.eye(M))  # (M*nh) x M\n",
    "    A = Xh @ C_hat @ Xh.conj().T + sigma2_hat * np.eye(M*nh)\n",
    "    # Solve A^{-1} y\n",
    "    z = np.linalg.solve(A, yh)\n",
    "    h_hat = C_hat @ Xh.conj().T @ z\n",
    "    return h_hat\n",
    "\n",
    "def mmse_equalize(h_hat: np.ndarray, yk: np.ndarray, sigma2: float, rho: float) -> complex:\n",
    "    # y = h s + z, E|s|^2 = rho  =>  ŝ = (h^H y) / (||h||^2 + sigma2/rho)\n",
    "    denom = (np.vdot(h_hat, h_hat) + (sigma2 / (rho + 1e-12))).real + 1e-12\n",
    "    return np.vdot(h_hat, yk) / denom\n",
    "\n",
    "def qam16_constellation():\n",
    "    \"\"\" Gray-coded 16QAM constellation, normalized to unit average power. \"\"\"\n",
    "    re = np.array([-3, -1, 1, 3])\n",
    "    im = np.array([-3, -1, 1, 3])\n",
    "    const = np.array([x + 1j*y for x in re for y in im])\n",
    "    const /= np.sqrt((np.mean(np.abs(const)**2)))  # normalize average power=1\n",
    "    return const\n",
    "\n",
    "    \n",
    "def qam64_constellation():\n",
    "    \"\"\" Gray-coded 64QAM constellation, normalized to unit average power. \"\"\"\n",
    "    re = np.array([-7, -5, -3, -1, 1, 3, 5, 7])\n",
    "    im = np.array([-7, -5, -3, -1, 1, 3, 5, 7])\n",
    "    const = np.array([x + 1j*y for x in re for y in im])\n",
    "    const /= np.sqrt(np.mean(np.abs(const)**2))  # normalize average power=1\n",
    "    return const\n",
    "\n",
    "\n",
    "def qam16_mod(bits):\n",
    "    \"\"\"Map random bits (len multiple of 4) to 16QAM symbols.\"\"\"\n",
    "    bits = np.array(bits).reshape(-1,4)\n",
    "    idx = bits[:,0]*8 + bits[:,1]*4 + bits[:,2]*2 + bits[:,3]\n",
    "    return CONST[idx]\n",
    "\n",
    "def qam16_demod(symbols):\n",
    "    \"\"\"Nearest neighbor hard decision for 16QAM.\"\"\"\n",
    "    decisions = []\n",
    "    for s in symbols:\n",
    "        idx = np.argmin(np.abs(s - CONST))\n",
    "        decisions.append(CONST[idx])\n",
    "    return np.array(decisions)\n",
    "\n",
    "class DecoderNN(nn.Module):\n",
    "    def __init__(self, n_classes=16):\n",
    "        super().__init__()\n",
    "        self.net = nn.Sequential(\n",
    "            nn.Linear(2, 64),\n",
    "            nn.ReLU(),\n",
    "            nn.Linear(64, 64),\n",
    "            nn.ReLU(),\n",
    "            nn.Linear(64, n_classes)\n",
    "        )\n",
    "\n",
    "    def forward(self, x):\n",
    "        # x: (..., 2) real features [Re, Im]\n",
    "        return self.net(x)\n",
    "\n",
    "\n",
    "def train_decoder_nn(model: nn.Module, C0: np.ndarray, C_hat: np.ndarray, sigma2_hat: float,\n",
    "                     rho: float, M: int, nh: int, xh: np.ndarray,\n",
    "                     n_train: int = 20000, batch_size: int = 256, epochs: int = 10,\n",
    "                     device: str = 'cpu') -> Tuple[DecoderNN, dict]:\n",
    "    # np.random.seed(0)\n",
    "    X = np.zeros((n_train, 2), dtype=np.float32)\n",
    "    Y = np.zeros((n_train,), dtype=np.int64)\n",
    "\n",
    "    for i in range(n_train):\n",
    "        # draw channel\n",
    "        h = complex_normal(np.zeros(M, dtype=complex), C0)\n",
    "        Zh = (np.random.randn(M, nh) + 1j*np.random.randn(M, nh)) * np.sqrt(sigma2_hat/2)\n",
    "        Yh = np.outer(h, xh) + Zh\n",
    "        yh_stack = vec(Yh)\n",
    "        h_hat = mmse_channel_estimate(C_hat, sigma2_hat, xh, yh_stack, M)\n",
    "\n",
    "        # random symbol\n",
    "        idx = np.random.randint(0, M)\n",
    "        s = CONST[idx] * np.sqrt(rho)\n",
    "        z = (np.random.randn(M) + 1j*np.random.randn(M)) * np.sqrt(sigma2_hat/2)\n",
    "        yk = h * s + z\n",
    "        s_eq = mmse_equalize(h_hat, yk, sigma2_hat, rho)\n",
    "        X[i,0] = s_eq.real\n",
    "        X[i,1] = s_eq.imag\n",
    "        Y[i] = idx\n",
    "    device = device\n",
    "    X_t = torch.from_numpy(X).to(device)\n",
    "    Y_t = torch.from_numpy(Y).to(device)\n",
    "\n",
    "    optimizer = optim.Adam(model.parameters(), lr=1e-3)\n",
    "    loss_fn = nn.CrossEntropyLoss(label_smoothing=0.3)\n",
    "\n",
    "    n_batches = int(np.ceil(n_train / batch_size))\n",
    "    for epoch in range(epochs):\n",
    "        perm = np.random.permutation(n_train)\n",
    "        X_t = X_t[perm]\n",
    "        Y_t = Y_t[perm]\n",
    "        epoch_loss = 0.0\n",
    "        correct = 0\n",
    "        for b in range(n_batches):\n",
    "            st = b*batch_size\n",
    "            en = min((b+1)*batch_size, n_train)\n",
    "            xb = X_t[st:en]\n",
    "            yb = Y_t[st:en]\n",
    "            logits = model(xb)\n",
    "            loss = loss_fn(logits, yb)\n",
    "            optimizer.zero_grad()\n",
    "            loss.backward()\n",
    "            optimizer.step()\n",
    "            epoch_loss += loss.item()*(en-st)\n",
    "            preds = logits.argmax(dim=1)\n",
    "            correct += (preds == yb).sum().item()\n",
    "        acc = correct / n_train\n",
    "        print(f\"[Decoder Train] Epoch {epoch+1}/{epochs}, Loss={epoch_loss/n_train:.4f}, Acc={acc:.4f}\")\n",
    "\n",
    "    return model, {'train_samples': n_train}\n",
    "\n",
    "def compute_v_opt(x, t_opt):\n",
    "    x = np.array(x)\n",
    "    t = np.arange(1, len(x) + 1)\n",
    "    S_t = np.cumsum(x)\n",
    "    mu_hat_t = S_t / t\n",
    "    mu_hat_tminus1 = np.append(1 / 2, mu_hat_t[0 : (len(mu_hat_t) - 1)])\n",
    "    V_t = np.cumsum(np.power(x - mu_hat_tminus1, 2))\n",
    "    v_opt = V_t[t_opt] * t_opt\n",
    "    return v_opt\n",
    "\n",
    "def running_average_cumulative(x): \n",
    "    return np.cumsum(x) / (np.arange(len(x)) + 1)\n",
    "\n",
    "if Mod==16:\n",
    "    CONST = qam16_constellation()\n",
    "elif Mod==64:\n",
    "    CONST = qam64_constellation()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def shuffle_by_severity_multi(trajs, severities):\n",
    "    shuffled_results = [[] for _ in trajs]\n",
    "    start = 0\n",
    "    n_sev = len(severities)\n",
    "    length_per_sev = len(trajs[0]) // n_sev  \n",
    "    for sev in severities:\n",
    "        idx = np.arange(length_per_sev)\n",
    "        np.random.shuffle(idx)\n",
    "        # print(idx)\n",
    "        for i, traj in enumerate(trajs):\n",
    "            part = traj[start:start + length_per_sev]\n",
    "            shuffled_results[i].append(part[idx])  #\n",
    "\n",
    "        start += length_per_sev\n",
    "    return [np.concatenate(parts) for parts in shuffled_results]\n",
    "\n",
    "para = 1\n",
    "def conjmix_empbern_cs_flexible(x, v_opt, alpha=0.05, c=1,  running_intersection=False):\n",
    "    x = np.array(x)\n",
    "    t = np.arange(1, len(x) + 1)\n",
    "    S_t = np.cumsum(x)\n",
    "    mu_hat_t = S_t / t\n",
    "    mu_hat_tminus1 = np.append(1/2., mu_hat_t[0:(len(mu_hat_t) - 1)])\n",
    "    V_t = np.cumsum(np.power(x - mu_hat_tminus1, 2))\n",
    "    bdry = (gamma_exponential_mixture_bound(\n",
    "        V_t, alpha=alpha / 2, v_opt=v_opt, c=c, alpha_opt=alpha / 2) / t)\n",
    "    l, u = mu_hat_t - bdry, mu_hat_t + bdry\n",
    "    l = np.maximum(l, -1 * para)\n",
    "    u = np.minimum(u, 2 * para)\n",
    "    if running_intersection:\n",
    "        l = np.maximum.accumulate(l)\n",
    "        u = np.minimum.accumulate(u)\n",
    "    return l, u"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "np.random.seed(5)\n",
    "M = 16\n",
    "nc = 3   \n",
    "nh = 3   \n",
    "SNR_dB = 6.\n",
    "sigma2_true = 1.0\n",
    "rho = sigma2_true * 10**(SNR_dB/10.0)\n",
    "\n",
    "delta0 = 0.2\n",
    "sigma_asd = 0.1\n",
    "C0 = spatial_covariance_ula(M, delta0, sigma_asd)\n",
    "\n",
    "xc = (np.random.randn(nc) + 1j*np.random.randn(nc)) / np.sqrt(2) * np.sqrt(rho)\n",
    "\n",
    "L = 1500  \n",
    "Yc_list = []\n",
    "for _ in range(L):\n",
    "    h0 = complex_normal(np.zeros(M, dtype=complex), C0)\n",
    "    Zc = (np.random.randn(M, nc) + 1j*np.random.randn(M, nc)) * np.sqrt(sigma2_true/2)\n",
    "    Yc = np.outer(h0, xc) + Zc\n",
    "    Yc_list.append(Yc)\n",
    "\n",
    "C_hat, sigma2_hat = estimate_cov_noise_psd_ls_from_blocks(Yc_list, xc, M, nc)\n",
    "print(\"Estimated sigma2_hat:\", sigma2_hat)\n",
    "\n",
    "xh = ((np.random.randn(nh) + 1j*np.random.randn(nh)) / np.sqrt(2)) * np.sqrt(rho)\n",
    "\n",
    "device = 'cuda' if torch.cuda.is_available() else 'cpu'\n",
    "model = DecoderNN(n_classes=Mod).to(device)\n",
    "decoder, info = train_decoder_nn(model=model, C0=C0, C_hat=C_hat, sigma2_hat=sigma2_hat,\n",
    "                                rho=rho, M=M, nh=nh, xh=xh,\n",
    "                                n_train=10000, batch_size=128, epochs=10,\n",
    "                                device=device)\n",
    "\n",
    "h1 = complex_normal(np.zeros(M, dtype=complex), C0)\n",
    "Zh = (np.random.randn(M, nh) + 1j*np.random.randn(M, nh)) * np.sqrt(sigma2_true/2)\n",
    "Yh = np.outer(h1, xh) + Zh\n",
    "yh_stack = vec(Yh)\n",
    "\n",
    "h1_hat = mmse_channel_estimate(C_hat, sigma2_hat, xh, yh_stack, M)\n",
    "\n",
    "num_data = 20000\n",
    "idx_tx = np.random.randint(0, M, size=(num_data,))\n",
    "s_tx = CONST[idx_tx] * np.sqrt(rho)\n",
    "Z_data = (np.random.randn(M, num_data) + 1j*np.random.randn(M, num_data)) * np.sqrt(sigma2_true/2)\n",
    "Y_data = h1.reshape(-1,1) @ s_tx.reshape(1,-1) + Z_data\n",
    "\n",
    "X_test = np.zeros((num_data, 2), dtype=np.float32)\n",
    "for k in range(num_data):\n",
    "    yk = Y_data[:, k]\n",
    "    s_eq = mmse_equalize(h1_hat, yk, sigma2_hat, rho)\n",
    "    X_test[k,0] = s_eq.real\n",
    "    X_test[k,1] = s_eq.imag\n",
    "decoder.eval()\n",
    "with torch.no_grad():\n",
    "    X_t = torch.from_numpy(X_test).to(device)\n",
    "    logits = decoder(X_t)\n",
    "    preds = logits.argmax(dim=1).cpu().numpy()\n",
    "ser_nn = np.mean(preds != idx_tx)\n",
    "print(f\"SER using NN decoder: {ser_nn:.6f}\")\n",
    "const_scaled = CONST * np.sqrt(rho)\n",
    "s_hat_nn = preds  # indices\n",
    "# classical demod\n",
    "decided_idxs = []\n",
    "for k in range(num_data):\n",
    "    # hard decision nearest neighbor\n",
    "    distances = np.abs(X_test[k,0] + 1j*X_test[k,1] - const_scaled)\n",
    "    decided_idxs.append(int(np.argmin(distances)))\n",
    "decided_idxs = np.array(decided_idxs)\n",
    "ser_classic = np.mean(decided_idxs != idx_tx)\n",
    "print(f\"SER using nearest-neighbor (classical) demod on equalized symbols: {ser_classic:.6f}\")\n",
    "print(\"Done.\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "temperature = 1.0\n",
    "n_test = 100\n",
    "saved_files = []\n",
    "delta = 0.2\n",
    "print(f\"\\n=== Running experiment for delta = {delta:.3f} ===\")\n",
    "C = spatial_covariance_ula(M, delta, sigma_asd)\n",
    "pilot_bits = np.random.randint(0, 16, size=(nh,))\n",
    "pilot_syms = CONST[pilot_bits] * np.sqrt(rho)\n",
    "xhhh = ((np.random.randn(nh) + 1j*np.random.randn(nh)) / np.sqrt(2)) * np.sqrt(rho)\n",
    "pilot_labels = []\n",
    "for s in xhhh: \n",
    "    idx = np.argmin(np.abs(s/np.sqrt(rho) - CONST))  # \n",
    "    pilot_labels.append(idx)\n",
    "pilot_labels = np.array(pilot_labels)\n",
    "sum_risk = []\n",
    "num_count = 0\n",
    "for trial in range(2000):  \n",
    "    h = complex_normal(np.zeros(M, dtype=complex), C)\n",
    "    Zh = (np.random.randn(M, nh) + 1j*np.random.randn(M, nh)) * np.sqrt(sigma2_true/2)\n",
    "    Yh = np.outer(h, xh) + Zh\n",
    "    yh_stack = vec(Yh)\n",
    "    h_hat = mmse_channel_estimate(C_hat, sigma2_hat, xh, yh_stack, M)\n",
    "\n",
    "    X_pilot = np.zeros((nh, 2), dtype=np.float32)\n",
    "    for k in range(nh):\n",
    "        s_eq = mmse_equalize(h_hat, Yh[:,k], sigma2_hat, rho)\n",
    "        X_pilot[k,0] = s_eq.real\n",
    "        X_pilot[k,1] = s_eq.imag\n",
    "\n",
    "    decoder.eval()\n",
    "    with torch.no_grad():\n",
    "        X_t = torch.from_numpy(X_pilot).to(device)\n",
    "        logits_pilot = decoder(X_t).cpu().numpy()\n",
    "\n",
    "    idx_tx = np.random.randint(0, Mod, size=(n_test,))\n",
    "    s_tx = CONST[idx_tx] * np.sqrt(rho)\n",
    "    Z_data = (np.random.randn(M, n_test) + 1j*np.random.randn(M, n_test)) * np.sqrt(sigma2_true/2)\n",
    "    Y_data = h.reshape(-1,1) @ s_tx.reshape(1,-1) + Z_data\n",
    "\n",
    "    X_test = np.zeros((n_test, 2), dtype=np.float32)\n",
    "    for k in range(n_test):\n",
    "        yk = Y_data[:, k]\n",
    "        s_eq = mmse_equalize(h_hat, yk, sigma2_hat, rho)\n",
    "        X_test[k,0] = s_eq.real\n",
    "        X_test[k,1] = s_eq.imag\n",
    "\n",
    "    decoder.eval()\n",
    "    with torch.no_grad():\n",
    "        X_t = torch.from_numpy(X_test).to(device)\n",
    "        logits = decoder(X_t).cpu().numpy()\n",
    "        # softmax \n",
    "        temperature = 3.2\n",
    "        probs_data = np.exp(logits/temperature)\n",
    "        probs_data /= np.sum(probs_data, axis=-1, keepdims=True)\n",
    "        Y_real_data = np.zeros_like(probs_data)\n",
    "        Y_real_data[np.arange(len(idx_tx)), idx_tx] = 1\n",
    "        risk_item = np.mean(np.sum((probs_data - Y_real_data)**2, axis=1)) * 0.5\n",
    "        sum_risk.append(risk_item)\n",
    "print(f\"=== The risk on the setting is {np.mean(np.array(sum_risk)):.3f} ===\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "n_test = 200   \n",
    "n_trails = 200  \n",
    "temperature = 1.\n",
    "delta_list = [0.2, 0.2,0.2, 0.8, 0.8,0.8, 0.9]\n",
    "n_split, n_start = 1, 0\n",
    "risk_traj_supervised_true = []\n",
    "risk_traj_supervised_pred = []\n",
    "risk_traj_unsupervised_pred = []\n",
    "risk_traj_ppi_ideal = []\n",
    "\n",
    "for i, delta in enumerate(delta_list):\n",
    "    print(f\"=== Running experiment for delta = {delta:.2f} ===\")\n",
    "    C = spatial_covariance_ula(M, delta, sigma_asd)\n",
    "    pilot_bits = np.random.randint(0, 16, size=(nh,))\n",
    "    pilot_syms = CONST[pilot_bits] * np.sqrt(rho)\n",
    "\n",
    "    xhhh = ((np.random.randn(nh) + 1j*np.random.randn(nh)) / np.sqrt(2)) * np.sqrt(rho)\n",
    "\n",
    "    pilot_labels = np.array([\n",
    "        np.argmin(np.abs(s / np.sqrt(rho) - CONST)) for s in xhhh\n",
    "    ])\n",
    "\n",
    "    logits_trials, labels_trials = [], []\n",
    "    logits_pilot_trials, labels_pilot_trials = [], []\n",
    "\n",
    "    for trial in range(n_trails):\n",
    "        h = complex_normal(np.zeros(M, dtype=complex), C)\n",
    "        Zh = (np.random.randn(M, nh) + 1j*np.random.randn(M, nh)) * np.sqrt(sigma2_true / 2)\n",
    "        Yh = np.outer(h, xh) + Zh\n",
    "        yh_stack = vec(Yh)\n",
    "        h_hat = mmse_channel_estimate(C_hat, sigma2_hat, xh, yh_stack, M)\n",
    "\n",
    "        X_pilot = np.column_stack([\n",
    "            [mmse_equalize(h_hat, Yh[:, k], sigma2_hat, rho).real for k in range(nh)],\n",
    "            [mmse_equalize(h_hat, Yh[:, k], sigma2_hat, rho).imag for k in range(nh)]\n",
    "        ]).astype(np.float32)\n",
    "\n",
    "        decoder.eval()\n",
    "        with torch.no_grad():\n",
    "            X_t = torch.from_numpy(X_pilot).to(device)\n",
    "            logits_pilot = decoder(X_t).cpu().numpy()\n",
    "\n",
    "        logits_pilot_trials.append(logits_pilot)\n",
    "        labels_pilot_trials.append(pilot_labels)\n",
    "\n",
    "        idx_tx = np.random.randint(0, Mod, size=(n_test,))\n",
    "        s_tx = CONST[idx_tx] * np.sqrt(rho)\n",
    "        Z_data = (np.random.randn(M, n_test) + 1j*np.random.randn(M, n_test)) * np.sqrt(sigma2_true / 2)\n",
    "        Y_data = h.reshape(-1, 1) @ s_tx.reshape(1, -1) + Z_data\n",
    "\n",
    "        X_test = np.column_stack([\n",
    "            [mmse_equalize(h_hat, Y_data[:, k], sigma2_hat, rho).real for k in range(n_test)],\n",
    "            [mmse_equalize(h_hat, Y_data[:, k], sigma2_hat, rho).imag for k in range(n_test)]\n",
    "        ]).astype(np.float32)\n",
    "\n",
    "        decoder.eval()\n",
    "        with torch.no_grad():\n",
    "            X_t = torch.from_numpy(X_test).to(device)\n",
    "            logits = decoder(X_t).cpu().numpy()\n",
    "\n",
    "        logits_trials.append(logits)\n",
    "        labels_trials.append(idx_tx)\n",
    "\n",
    "        probs_data = np.exp(logits / temperature)\n",
    "        probs_data /= np.sum(probs_data, axis=-1, keepdims=True)\n",
    "\n",
    "        labels_data = idx_tx\n",
    "        Y_true = np.zeros_like(probs_data)\n",
    "        Y_true[np.arange(len(labels_data)), labels_data] = 1\n",
    "\n",
    "        probs_segment = probs_data[n_start:n_start + n_split]\n",
    "        Y_segment_true = Y_true[n_start:n_start + n_split]\n",
    "        risk_supervised_true = 0.5 * np.mean(np.sum((probs_segment - Y_segment_true) ** 2, axis=1))\n",
    "        risk_traj_supervised_true.append(risk_supervised_true)\n",
    "\n",
    "        pseudo_labels_segment = np.argmax(probs_segment, axis=1)\n",
    "        Y_segment_pseudo = np.zeros_like(probs_segment)\n",
    "        Y_segment_pseudo[np.arange(len(pseudo_labels_segment)), pseudo_labels_segment] = 1\n",
    "        risk_supervised_pred = 0.5 * np.mean(np.sum((probs_segment - Y_segment_pseudo) ** 2, axis=1))\n",
    "        risk_traj_supervised_pred.append(risk_supervised_pred)\n",
    "\n",
    "        probs_rest = probs_data[n_start + n_split:]\n",
    "        pseudo_labels_rest = np.argmax(probs_rest, axis=1)\n",
    "        Y_rest_pseudo = np.zeros_like(probs_rest)\n",
    "        Y_rest_pseudo[np.arange(len(pseudo_labels_rest)), pseudo_labels_rest] = 1\n",
    "        risk_unsupervised_pred = 0.5 * np.mean(np.sum((probs_rest - Y_rest_pseudo) ** 2, axis=1))\n",
    "        risk_traj_unsupervised_pred.append(risk_unsupervised_pred)\n",
    "\n",
    "        Y_rest_true = Y_true[n_start + n_split:]\n",
    "        risk_ppi_ideal = 0.5 * np.mean(np.sum((probs_rest - Y_rest_true) ** 2, axis=1))\n",
    "        risk_traj_ppi_ideal.append(risk_ppi_ideal)\n",
    "    \n",
    "    print(f\"=== The risk on the setting is {np.mean(np.array(risk_traj_ppi_ideal)):.3f} ===\")\n",
    "    print(f\"=== The risk on the pseudo is {np.mean(np.array(risk_traj_unsupervised_pred)):.3f} ===\")\n",
    "\n",
    "risk_traj_supervised_true =   np.array(risk_traj_supervised_true)\n",
    "risk_traj_supervised_pred =   np.array(risk_traj_supervised_pred)\n",
    "risk_traj_unsupervised_pred = np.array(risk_traj_unsupervised_pred)\n",
    "risk_traj_ppi_ideal =  np.array(risk_traj_ppi_ideal)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "n_test = 12   # \n",
    "n_trails = 250   \n",
    "temperature = 1.0\n",
    "delta_list = [0.2, 0.8, 0.8, 0.8,  1.0, 1.1]\n",
    "n_split, n_start = 1, 0\n",
    "risk_traj_supervised_true = []\n",
    "risk_traj_supervised_pred = []\n",
    "risk_traj_unsupervised_pred = []\n",
    "risk_traj_ppi_ideal = []\n",
    "\n",
    "for i, delta in enumerate(delta_list):\n",
    "    print(f\"=== Running experiment for delta = {delta:.2f} ===\")\n",
    "    C = spatial_covariance_ula(M, delta, sigma_asd)\n",
    "\n",
    "    pilot_bits = np.random.randint(0, 16, size=(nh,))\n",
    "    pilot_syms = CONST[pilot_bits] * np.sqrt(rho)\n",
    "\n",
    "    xhhh = ((np.random.randn(nh) + 1j*np.random.randn(nh)) / np.sqrt(2)) * np.sqrt(rho)\n",
    "\n",
    "    pilot_labels = np.array([\n",
    "        np.argmin(np.abs(s / np.sqrt(rho) - CONST)) for s in xhhh\n",
    "    ])\n",
    "\n",
    "    logits_trials, labels_trials = [], []\n",
    "    logits_pilot_trials, labels_pilot_trials = [], []\n",
    "\n",
    "    for trial in range(n_trails):\n",
    "        h = complex_normal(np.zeros(M, dtype=complex), C)\n",
    "        Zh = (np.random.randn(M, nh) + 1j*np.random.randn(M, nh)) * np.sqrt(sigma2_true / 2)\n",
    "        Yh = np.outer(h, xh) + Zh\n",
    "        yh_stack = vec(Yh)\n",
    "        h_hat = mmse_channel_estimate(C_hat, sigma2_hat, xh, yh_stack, M)\n",
    "\n",
    "        X_pilot = np.column_stack([\n",
    "            [mmse_equalize(h_hat, Yh[:, k], sigma2_hat, rho).real for k in range(nh)],\n",
    "            [mmse_equalize(h_hat, Yh[:, k], sigma2_hat, rho).imag for k in range(nh)]\n",
    "        ]).astype(np.float32)\n",
    "\n",
    "        decoder.eval()\n",
    "        with torch.no_grad():\n",
    "            X_t = torch.from_numpy(X_pilot).to(device)\n",
    "            logits_pilot = decoder(X_t).cpu().numpy()\n",
    "\n",
    "        logits_pilot_trials.append(logits_pilot)\n",
    "        labels_pilot_trials.append(pilot_labels)\n",
    "        idx_tx = np.random.randint(0, Mod, size=(n_test,))\n",
    "        s_tx = CONST[idx_tx] * np.sqrt(rho)\n",
    "        Z_data = (np.random.randn(M, n_test) + 1j*np.random.randn(M, n_test)) * np.sqrt(sigma2_true / 2)\n",
    "        Y_data = h.reshape(-1, 1) @ s_tx.reshape(1, -1) + Z_data\n",
    "\n",
    "        X_test = np.column_stack([\n",
    "            [mmse_equalize(h_hat, Y_data[:, k], sigma2_hat, rho).real for k in range(n_test)],\n",
    "            [mmse_equalize(h_hat, Y_data[:, k], sigma2_hat, rho).imag for k in range(n_test)]\n",
    "        ]).astype(np.float32)\n",
    "\n",
    "        decoder.eval()\n",
    "        with torch.no_grad():\n",
    "            X_t = torch.from_numpy(X_test).to(device)\n",
    "            logits = decoder(X_t).cpu().numpy()\n",
    "\n",
    "        logits_trials.append(logits)\n",
    "        labels_trials.append(idx_tx)\n",
    "\n",
    "        probs_data = np.exp(logits / temperature)\n",
    "        probs_data /= np.sum(probs_data, axis=-1, keepdims=True)\n",
    "\n",
    "        labels_data = idx_tx\n",
    "        Y_true = np.zeros_like(probs_data)\n",
    "        Y_true[np.arange(len(labels_data)), labels_data] = 1\n",
    "\n",
    "        probs_segment = probs_data[n_start:n_start + n_split]\n",
    "        Y_segment_true = Y_true[n_start:n_start + n_split]\n",
    "        risk_supervised_true = 0.5 * np.mean(np.sum((probs_segment - Y_segment_true) ** 2, axis=1))\n",
    "        risk_traj_supervised_true.append(risk_supervised_true)\n",
    "\n",
    "        pseudo_labels_segment = np.argmax(probs_segment, axis=1)\n",
    "        Y_segment_pseudo = np.zeros_like(probs_segment)\n",
    "        Y_segment_pseudo[np.arange(len(pseudo_labels_segment)), pseudo_labels_segment] = 1\n",
    "        risk_supervised_pred = 0.5 * np.mean(np.sum((probs_segment - Y_segment_pseudo) ** 2, axis=1))\n",
    "        risk_traj_supervised_pred.append(risk_supervised_pred)\n",
    "\n",
    "        probs_rest = probs_data[n_start + n_split:]\n",
    "        pseudo_labels_rest = np.argmax(probs_rest, axis=1)\n",
    "        Y_rest_pseudo = np.zeros_like(probs_rest)\n",
    "        Y_rest_pseudo[np.arange(len(pseudo_labels_rest)), pseudo_labels_rest] = 1\n",
    "        risk_unsupervised_pred = 0.5 * np.mean(np.sum((probs_rest - Y_rest_pseudo) ** 2, axis=1))\n",
    "        risk_traj_unsupervised_pred.append(risk_unsupervised_pred)\n",
    "\n",
    "        Y_rest_true = Y_true[n_start + n_split:]\n",
    "        risk_ppi_ideal = 0.5 * np.mean(np.sum((probs_rest - Y_rest_true) ** 2, axis=1))\n",
    "        risk_traj_ppi_ideal.append(risk_ppi_ideal)\n",
    "    \n",
    "    print(f\"=== The risk on the setting is {np.mean(np.array(risk_traj_ppi_ideal)):.3f} ===\")\n",
    "    print(f\"=== The risk on the pseudo is {np.mean(np.array(risk_traj_unsupervised_pred)):.3f} ===\")\n",
    "\n",
    "risk_traj_supervised_true =   np.array(risk_traj_supervised_true)\n",
    "risk_traj_supervised_pred =   np.array(risk_traj_supervised_pred)\n",
    "risk_traj_unsupervised_pred = np.array(risk_traj_unsupervised_pred)\n",
    "risk_traj_ppi_ideal =         np.array(risk_traj_ppi_ideal)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "def compute_bounds_cmeb(traj_sup_true, traj_unsup_proxy, L1_proxy, alpha, t_opt_ratio):\n",
    "    \"\"\"Return CM-EB lower bounds for the supervised and the proxy trajectory.\n",
    "\n",
    "    For each trajectory the lower endpoint returned by\n",
    "    ``conjmix_empbern_twosided_cs`` at level ``2 * alpha`` is kept. The proxy\n",
    "    bound is additionally shifted down by ``L1_proxy`` and floored at zero.\n",
    "\n",
    "    Returns:\n",
    "        ``(supervised_bounds, unsupervised_bounds)``: two dicts keyed by \"CM-EB\".\n",
    "    \"\"\"\n",
    "    # Accept anything exposing .numpy() (e.g. a tensor) as well as plain numbers/arrays.\n",
    "    if hasattr(L1_proxy, \"numpy\"):\n",
    "        proxy_gap = L1_proxy.numpy()\n",
    "    else:\n",
    "        proxy_gap = np.array(L1_proxy)\n",
    "\n",
    "    # Both v_opt values are computed for the same step index.\n",
    "    tuning_step = int(len(traj_sup_true) * t_opt_ratio)\n",
    "    v_sup = compute_v_opt(traj_sup_true, tuning_step)\n",
    "    v_unsup = compute_v_opt(traj_unsup_proxy, tuning_step)\n",
    "    level = alpha * 2\n",
    "\n",
    "    lower_sup, _upper_sup = conjmix_empbern_twosided_cs(x=traj_sup_true, alpha=level, v_opt=v_sup)\n",
    "    lower_unsup, _upper_unsup = conjmix_empbern_twosided_cs(x=traj_unsup_proxy, alpha=level, v_opt=v_unsup)\n",
    "    lower_unsup = np.maximum(lower_unsup - proxy_gap, np.zeros_like(proxy_gap))\n",
    "\n",
    "    return {\"CM-EB\": lower_sup}, {\"CM-EB\": lower_unsup}\n",
    "\n",
    "def compute_bounds_pmeb_betting(traj_sup_true, traj_unsup_proxy, L1_proxy, alpha):\n",
    "    \"\"\"Return PM-EB and betting lower bounds for both trajectories.\n",
    "\n",
    "    The lower endpoints of ``predmix_empbern_twosided_cs`` and ``betting_cs``\n",
    "    (both called at level ``alpha``) are collected. The proxy-trajectory bounds\n",
    "    are shifted down by ``L1_proxy`` and floored at zero.\n",
    "\n",
    "    Returns:\n",
    "        ``(supervised_bounds, unsupervised_bounds)``: two dicts keyed by\n",
    "        \"PM-EB\" and \"Betting\".\n",
    "    \"\"\"\n",
    "    if hasattr(L1_proxy, \"numpy\"):\n",
    "        proxy_gap = L1_proxy.numpy()\n",
    "    else:\n",
    "        proxy_gap = np.array(L1_proxy)\n",
    "\n",
    "    def _lower_endpoint(cs_fn, traj):\n",
    "        # Each confidence-sequence routine returns (lower, upper); keep the lower one.\n",
    "        lower, _upper = cs_fn(x=traj, alpha=alpha)\n",
    "        return lower\n",
    "\n",
    "    # predictable-mixture empirical Bernstein\n",
    "    pm_sup = _lower_endpoint(predmix_empbern_twosided_cs, traj_sup_true)\n",
    "    pm_unsup = _lower_endpoint(predmix_empbern_twosided_cs, traj_unsup_proxy)\n",
    "    pm_unsup = np.maximum(pm_unsup - proxy_gap, np.zeros_like(proxy_gap))\n",
    "\n",
    "    # betting\n",
    "    bet_sup = _lower_endpoint(betting_cs, traj_sup_true)\n",
    "    bet_unsup = _lower_endpoint(betting_cs, traj_unsup_proxy)\n",
    "    bet_unsup = np.maximum(bet_unsup - proxy_gap, np.zeros_like(bet_unsup))\n",
    "\n",
    "    return (\n",
    "        {\"PM-EB\": pm_sup, \"Betting\": bet_sup},\n",
    "        {\"PM-EB\": pm_unsup, \"Betting\": bet_unsup},\n",
    "    )\n",
    "\n",
    "\n",
    "def compute_bounds_cmeb_ppi(traj_ppi_any, alpha, t_opt_ratio, eta_max):\n",
    "    \"\"\"Return the CM-EB lower bound of a PPI trajectory on its original scale.\n",
    "\n",
    "    The trajectory is mapped affinely from ``[-eta_max, 1 + eta_max]`` onto\n",
    "    ``[0, 1]``, the lower endpoint of ``conjmix_empbern_twosided_cs`` at level\n",
    "    ``2 * alpha`` is computed on the rescaled values, and the result is mapped\n",
    "    back.\n",
    "\n",
    "    Returns:\n",
    "        dict with the single key \"CM-EB\".\n",
    "    \"\"\"\n",
    "    low = - eta_max\n",
    "    high = 1 + eta_max\n",
    "    width = high - low\n",
    "\n",
    "    scaled = (traj_ppi_any - low) / width\n",
    "    tuning_step = int(len(scaled) * t_opt_ratio)\n",
    "    v_opt = compute_v_opt(scaled, tuning_step)\n",
    "    lower_scaled, _upper = conjmix_empbern_twosided_cs(x=scaled, alpha=alpha * 2, v_opt=v_opt)\n",
    "\n",
    "    # Undo the affine rescaling.\n",
    "    return {\"CM-EB\": lower_scaled * width + low}\n",
    "\n",
    "def compute_bounds_pmeb_betting_ppi(traj_ppi_any, alpha, eta_max):\n",
    "    \"\"\"Return PM-EB and betting lower bounds of a PPI trajectory.\n",
    "\n",
    "    The trajectory is mapped affinely from ``[-eta_max, 1 + eta_max]`` onto\n",
    "    ``[0, 1]``, the lower endpoints of ``predmix_empbern_twosided_cs`` and\n",
    "    ``betting_cs`` (both at level ``alpha``) are computed on the rescaled\n",
    "    values, and the results are mapped back to the original scale.\n",
    "\n",
    "    Args:\n",
    "        traj_ppi_any: array of per-step PPI risk values.\n",
    "        alpha: level passed to both confidence-sequence routines.\n",
    "        eta_max: largest PPI weight used when the trajectory was built.\n",
    "\n",
    "    Returns:\n",
    "        dict with keys \"PM-EB\" and \"Betting\".\n",
    "    \"\"\"\n",
    "    # Use the same support as compute_bounds_cmeb_ppi. The PPI trajectory has\n",
    "    # the form  sup_true + eta * (unsup_pred - sup_pred)  with eta <= eta_max,\n",
    "    # so (assuming each loss lies in [0, 1]) it can reach -eta_max. A lower\n",
    "    # end of 0 would let rescaled samples fall below 0.\n",
    "    a = -eta_max\n",
    "    b = 1 + eta_max\n",
    "    span = b - a\n",
    "    traj_scaled = (traj_ppi_any - a) / span\n",
    "\n",
    "    L_ppi_predmix, _ = predmix_empbern_twosided_cs(x=traj_scaled, alpha=alpha)\n",
    "    L_ppi_betting, _ = betting_cs(x=traj_scaled, alpha=alpha)\n",
    "\n",
    "    # Map both bounds back to the original scale.\n",
    "    ppi_bounds = {\"PM-EB\": L_ppi_predmix * span + a, \"Betting\": L_ppi_betting * span + a}\n",
    "    return ppi_bounds\n",
    "\n",
    "def compute_eta_t(labeled_losses_hist, agent_labeled_losses_hist, agent_unlabeled_losses_hist, eta_max,  eps=1e-8):\n",
    "    \"\"\"Estimate the PPI weight eta from a history of per-sample losses.\n",
    "\n",
    "    Computes ``cov(u, u_tilde_l) / ((1 + n_l / n_u) * var(u_tilde_u))`` and\n",
    "    clips it to ``[0, eta_max]``.\n",
    "\n",
    "    Args:\n",
    "        labeled_losses_hist: sequence of 1-D arrays of true losses on labeled samples.\n",
    "        agent_labeled_losses_hist: sequence of 1-D arrays of proxy losses on the\n",
    "            same labeled samples (paired element-wise with the true losses).\n",
    "        agent_unlabeled_losses_hist: sequence of 1-D arrays of proxy losses on\n",
    "            unlabeled samples.\n",
    "        eta_max: upper clip value for eta.\n",
    "        eps: variance threshold below which 0.0 is returned.\n",
    "\n",
    "    Returns:\n",
    "        float in ``[0, eta_max]``; 0.0 when any history is empty or the\n",
    "        unlabeled proxy losses have (near-)zero variance.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if true and proxy labeled losses differ in total length.\n",
    "    \"\"\"\n",
    "    # np.concatenate raises on an empty sequence, so empty histories must be\n",
    "    # caught before concatenating.\n",
    "    if (len(labeled_losses_hist) == 0\n",
    "            or len(agent_labeled_losses_hist) == 0\n",
    "            or len(agent_unlabeled_losses_hist) == 0):\n",
    "        return 0.0\n",
    "\n",
    "    u = np.concatenate(labeled_losses_hist, axis=0)\n",
    "    u_tilde_l = np.concatenate(agent_labeled_losses_hist, axis=0)\n",
    "    u_tilde_u = np.concatenate(agent_unlabeled_losses_hist, axis=0)\n",
    "    n_l = len(u)\n",
    "    n_u = len(u_tilde_u)\n",
    "\n",
    "    if n_l == 0 or n_u == 0:\n",
    "        return 0.0\n",
    "    if len(u_tilde_l) != n_l:\n",
    "        # The covariance below pairs entries element-wise; a length mismatch\n",
    "        # could otherwise broadcast silently (e.g. length 1 against length n).\n",
    "        raise ValueError(\n",
    "            f\"labeled true/proxy losses must have equal length, got {n_l} and {len(u_tilde_l)}\"\n",
    "        )\n",
    "\n",
    "    # Covariance is estimated on the labeled pairs ...\n",
    "    cov = np.mean((u - u.mean()) * (u_tilde_l - u_tilde_l.mean()))\n",
    "    # ... while the proxy variance is estimated on the unlabeled samples.\n",
    "    var = np.mean((u_tilde_u - u_tilde_u.mean()) ** 2)\n",
    "    if var < eps:\n",
    "        return 0.0\n",
    "\n",
    "    eta = cov / ((1.0 + n_l / n_u) * var)\n",
    "    eta = np.clip(eta, 0.0, eta_max)\n",
    "    return float(eta)\n",
    "\n",
    "def softmax_brier_loss_vector(logits, targets, temperature=1.0):\n",
    "    \"\"\"\n",
    "    Per-sample Brier loss of temperature-scaled softmax probabilities.\n",
    "\n",
    "    logits: Tensor (N, C), softmax is taken over dim 1 after dividing by temperature\n",
    "    targets: Tensor (N,) of class indices (one-hot encoded here) or (N, C) used as given\n",
    "    return: Tensor (N,), dtype=torch.float32, equal to 0.5 * sum_c (p_c - y_c)^2\n",
    "    \"\"\"\n",
    "    scaled_logits = logits / temperature\n",
    "    probs = F.softmax(scaled_logits, dim=1)\n",
    "\n",
    "    # 1-D targets are class indices; expand them to one-hot rows matching probs.\n",
    "    if targets.ndim == 1:\n",
    "        n_classes = probs.shape[1]\n",
    "        index_targets = targets.to(torch.int64)\n",
    "        targets = F.one_hot(index_targets, num_classes=n_classes).to(dtype=probs.dtype)\n",
    "\n",
    "    squared_err = (probs - targets) ** 2\n",
    "    loss_vec = 0.5 * torch.sum(squared_err, dim=1)\n",
    "    return loss_vec.to(torch.float32)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sequential risk-monitoring experiment: for every repetition, run the\n",
    "# decoder over the delta schedule in `delta_list`, build the supervised and\n",
    "# PPI risk trajectories, and record the CM-EB lower bounds together with the\n",
    "# first step at which each bound exceeds THRESH.\n",
    "N_EXPERIMENTS = 15\n",
    "ALPHA = 0.2\n",
    "THRESH = 0.11\n",
    "T_OPT_RATIO = 0.25\n",
    "VERBOSE = True\n",
    "ETA_MAX =  1.0\n",
    "WINDOW_SIZE = 60\n",
    "eta_fixed = 1.0\n",
    "set_seed(0)\n",
    "methods = [\"CM-EB\"]\n",
    "results = {m: {\"supervised\": [], \"unsupervised\": [], \"ppi\": [], \"adaptive_ppi\": [], \"ideal_ppi\": []} for m in methods}\n",
    "results_box = {m: {\"supervised\": [], \"unsupervised\": [], \"ppi\": [], \"adaptive_ppi\": [], \"ideal_ppi\": []} for m in methods}\n",
    "n_split, n_start = 1, 0\n",
    "all_traj_list = []\n",
    "for exp_id in range(N_EXPERIMENTS):\n",
    "    n_test = 32\n",
    "    n_trails = 250\n",
    "    temperature = 1.0\n",
    "    delta_list = [0.2, 0.8, 0.8, 0.8, 1.0, 1.1]\n",
    "    n_split, n_start = 1, 0\n",
    "    risk_traj_supervised_true = []\n",
    "    risk_traj_supervised_pred = []\n",
    "    risk_traj_unsupervised_pred = []\n",
    "    risk_traj_ppi_ideal = []\n",
    "\n",
    "    eta_seq = []\n",
    "    labeled_losses_hist = []\n",
    "    agent_labeled_losses_hist = []\n",
    "    agent_unlabeled_losses_hist = []\n",
    "    for i, delta in enumerate(delta_list):\n",
    "        print(f\"=== Running experiment for delta = {delta:.2f} ===\")\n",
    "        C = spatial_covariance_ula(M, delta, sigma_asd)\n",
    "        # NOTE(review): the four pilot_* / xhhh values below only feed\n",
    "        # labels_pilot_trials in this cell; they are kept because they also\n",
    "        # advance the global NumPy RNG state.\n",
    "        pilot_bits = np.random.randint(0, 16, size=(nh,))\n",
    "        pilot_syms = CONST[pilot_bits] * np.sqrt(rho)\n",
    "        xhhh = ((np.random.randn(nh) + 1j*np.random.randn(nh)) / np.sqrt(2)) * np.sqrt(rho)\n",
    "        pilot_labels = np.array([np.argmin(np.abs(s / np.sqrt(rho) - CONST)) for s in xhhh])\n",
    "        logits_trials, labels_trials = [], []\n",
    "        logits_pilot_trials, labels_pilot_trials = [], []\n",
    "        for trial in range(n_trails):\n",
    "            # ---- channel draw, pilot transmission and channel estimate ----\n",
    "            h = complex_normal(np.zeros(M, dtype=complex), C)\n",
    "            Zh = (np.random.randn(M, nh) + 1j*np.random.randn(M, nh)) * np.sqrt(sigma2_true / 2)\n",
    "            Yh = np.outer(h, xh) + Zh\n",
    "            yh_stack = vec(Yh)\n",
    "            h_hat = mmse_channel_estimate(C_hat, sigma2_hat, xh, yh_stack, M)\n",
    "            X_pilot = np.column_stack([\n",
    "                [mmse_equalize(h_hat, Yh[:, k], sigma2_hat, rho).real for k in range(nh)],\n",
    "                [mmse_equalize(h_hat, Yh[:, k], sigma2_hat, rho).imag for k in range(nh)]\n",
    "            ]).astype(np.float32)\n",
    "            decoder.eval()\n",
    "            with torch.no_grad():\n",
    "                X_t = torch.from_numpy(X_pilot).to(device)\n",
    "                logits_pilot = decoder(X_t).cpu().numpy()\n",
    "            logits_pilot_trials.append(logits_pilot)\n",
    "            labels_pilot_trials.append(pilot_labels)\n",
    "            # ---- data transmission and decoding ----\n",
    "            idx_tx = np.random.randint(0, Mod, size=(n_test,))\n",
    "            s_tx = CONST[idx_tx] * np.sqrt(rho)\n",
    "            Z_data = (np.random.randn(M, n_test) + 1j*np.random.randn(M, n_test)) * np.sqrt(sigma2_true / 2)\n",
    "            Y_data = h.reshape(-1, 1) @ s_tx.reshape(1, -1) + Z_data\n",
    "            X_test = np.column_stack([\n",
    "                [mmse_equalize(h_hat, Y_data[:, k], sigma2_hat, rho).real for k in range(n_test)],\n",
    "                [mmse_equalize(h_hat, Y_data[:, k], sigma2_hat, rho).imag for k in range(n_test)]\n",
    "            ]).astype(np.float32)\n",
    "            decoder.eval()\n",
    "            with torch.no_grad():\n",
    "                X_t = torch.from_numpy(X_test).to(device)\n",
    "                logits = decoder(X_t).cpu().numpy()\n",
    "            logits_trials.append(logits)\n",
    "            labels_trials.append(idx_tx)\n",
    "            # Temperature-scaled softmax of the decoder logits.\n",
    "            probs_data = np.exp(logits / temperature)\n",
    "            probs_data /= np.sum(probs_data, axis=-1, keepdims=True)\n",
    "            labels_data = idx_tx\n",
    "            Y_true = np.zeros_like(probs_data)\n",
    "            Y_true[np.arange(len(labels_data)), labels_data] = 1\n",
    "\n",
    "            # ---- labeled segment: Brier risk against true and pseudo labels ----\n",
    "            probs_segment = probs_data[n_start:n_start + n_split]\n",
    "            Y_segment_true = Y_true[n_start:n_start + n_split]\n",
    "            risk_supervised_true = 0.5 * np.mean(np.sum((probs_segment - Y_segment_true) ** 2, axis=1))\n",
    "            risk_traj_supervised_true.append(risk_supervised_true)\n",
    "\n",
    "            pseudo_labels_segment = np.argmax(probs_segment, axis=1)\n",
    "            Y_segment_pseudo = np.zeros_like(probs_segment)\n",
    "            Y_segment_pseudo[np.arange(len(pseudo_labels_segment)), pseudo_labels_segment] = 1\n",
    "            risk_supervised_pred = 0.5 * np.mean(np.sum((probs_segment - Y_segment_pseudo) ** 2, axis=1))\n",
    "            risk_traj_supervised_pred.append(risk_supervised_pred)\n",
    "\n",
    "            # ---- remaining samples: Brier risk against pseudo and true labels ----\n",
    "            probs_rest = probs_data[n_start + n_split:]\n",
    "            pseudo_labels_rest = np.argmax(probs_rest, axis=1)\n",
    "            Y_rest_pseudo = np.zeros_like(probs_rest)\n",
    "            Y_rest_pseudo[np.arange(len(pseudo_labels_rest)), pseudo_labels_rest] = 1\n",
    "            risk_unsupervised_pred = 0.5 * np.mean(np.sum((probs_rest - Y_rest_pseudo) ** 2, axis=1))\n",
    "            risk_traj_unsupervised_pred.append(risk_unsupervised_pred)\n",
    "\n",
    "\n",
    "            Y_rest_true = Y_true[n_start + n_split:]\n",
    "            risk_ppi_ideal = 0.5 * np.mean(np.sum((probs_rest - Y_rest_true) ** 2, axis=1))\n",
    "            risk_traj_ppi_ideal.append(risk_ppi_ideal)\n",
    "\n",
    "            # Per-sample Brier losses that drive the adaptive eta.\n",
    "            # softmax_brier_loss_vector applies the temperature-scaled softmax\n",
    "            # itself, so it must receive the raw logits; passing probs_* would\n",
    "            # apply softmax a second time and distort the losses.\n",
    "            logits_segment = logits[n_start:n_start + n_split]\n",
    "            logits_rest = logits[n_start + n_split:]\n",
    "            u_t = softmax_brier_loss_vector(torch.from_numpy(logits_segment), torch.from_numpy(Y_segment_true), temperature).numpy()\n",
    "            u_tilde_l_t = softmax_brier_loss_vector(torch.from_numpy(logits_segment), torch.from_numpy(Y_segment_pseudo), temperature).numpy()\n",
    "            u_tilde_u_t = softmax_brier_loss_vector(torch.from_numpy(logits_rest), torch.from_numpy(Y_rest_pseudo), temperature).numpy()\n",
    "            # ---- compute eta_t using ONLY history ----\n",
    "            # Until more than 20 steps are stored, fall back to ETA_MAX / 2.\n",
    "            if len(labeled_losses_hist) > 20:\n",
    "                eta_t = compute_eta_t(labeled_losses_hist, agent_labeled_losses_hist, agent_unlabeled_losses_hist, ETA_MAX)\n",
    "            else:\n",
    "                eta_t = ETA_MAX / 2\n",
    "            eta_seq.append(eta_t)\n",
    "            # ---- update history ----\n",
    "            labeled_losses_hist.append(u_t)\n",
    "            agent_labeled_losses_hist.append(u_tilde_l_t)\n",
    "            agent_unlabeled_losses_hist.append(u_tilde_u_t)\n",
    "            # ---- sliding window ----\n",
    "            if len(labeled_losses_hist) > WINDOW_SIZE:\n",
    "                labeled_losses_hist.pop(0)\n",
    "                agent_labeled_losses_hist.pop(0)\n",
    "                agent_unlabeled_losses_hist.pop(0)\n",
    "    risk_traj_supervised_true =   np.array(risk_traj_supervised_true)\n",
    "    risk_traj_supervised_pred =   np.array(risk_traj_supervised_pred)\n",
    "    risk_traj_unsupervised_pred = np.array(risk_traj_unsupervised_pred)\n",
    "    risk_traj_ppi_ideal =         np.array(risk_traj_ppi_ideal)\n",
    "    # Explicit array so the per-step weights multiply element-wise by construction.\n",
    "    eta_arr = np.asarray(eta_seq, dtype=float)\n",
    "\n",
    "    traj_sup_true = risk_traj_supervised_true\n",
    "    traj_ppi_ideal = risk_traj_ppi_ideal\n",
    "    # NOTE(review): the proxy trajectory is the supervised one here and\n",
    "    # L1_proxy=1.0 is passed below, so the \"unsupervised\" bound looks like a\n",
    "    # placeholder -- confirm this is intended.\n",
    "    traj_unsup_proxy = risk_traj_supervised_true\n",
    "    traj_ppi_pred = eta_fixed * risk_traj_unsupervised_pred + risk_traj_supervised_true - eta_fixed * risk_traj_supervised_pred\n",
    "    traj_ppi_pred_adaptive = eta_arr * risk_traj_unsupervised_pred + risk_traj_supervised_true - eta_arr * risk_traj_supervised_pred\n",
    "    all_traj_list.append(traj_ppi_ideal)\n",
    "    if VERBOSE:\n",
    "        print(f\"=== Experiment {exp_id+1}/{N_EXPERIMENTS} ===\")\n",
    "    try:\n",
    "        sup_bounds, unsup_bounds = compute_bounds_cmeb(traj_sup_true, traj_unsup_proxy, L1_proxy=1.0, alpha=ALPHA, t_opt_ratio=T_OPT_RATIO)\n",
    "        ppi_bounds =               compute_bounds_cmeb_ppi(traj_ppi_pred, alpha=ALPHA, t_opt_ratio=T_OPT_RATIO, eta_max=eta_fixed)\n",
    "        adaptive_ppi_bounds =      compute_bounds_cmeb_ppi(traj_ppi_pred_adaptive,  alpha=ALPHA, t_opt_ratio=T_OPT_RATIO, eta_max=ETA_MAX)\n",
    "        ideal_ppi_bounds =         compute_bounds_cmeb_ppi(traj_ppi_ideal,  alpha=ALPHA, t_opt_ratio=T_OPT_RATIO, eta_max=eta_fixed)\n",
    "    except Exception as e:\n",
    "        print(f\"[Warning] Skipping exp {exp_id}: compute_bounds() failed — {e}\")\n",
    "        continue\n",
    "    for m in methods:\n",
    "        results[m][\"supervised\"].append(sup_bounds[m])\n",
    "        results[m][\"unsupervised\"].append(unsup_bounds[m])\n",
    "        results[m][\"ppi\"].append(ppi_bounds[m])\n",
    "        results[m][\"ideal_ppi\"].append(ideal_ppi_bounds[m])\n",
    "        results[m][\"adaptive_ppi\"].append(adaptive_ppi_bounds[m])\n",
    "    for m in methods:\n",
    "        sup_arr = np.asarray(sup_bounds.get(m, []))\n",
    "        unsup_arr = np.asarray(unsup_bounds.get(m, []))\n",
    "        adaptive_ppi_arr = np.asarray(adaptive_ppi_bounds.get(m, []))\n",
    "        ppi_arr = np.asarray(ppi_bounds.get(m, []))\n",
    "        ideal_ppi_arr = np.asarray(ideal_ppi_bounds.get(m, []))\n",
    "\n",
    "        # Every bound must have one entry per monitored step.\n",
    "        assert len(sup_arr) == len(ppi_arr) == len(adaptive_ppi_arr) == len(ideal_ppi_arr) == len(unsup_arr) == len(traj_sup_true)\n",
    "\n",
    "        # find threshold crossings (1-based step of the first exceedance, NaN if none)\n",
    "        t_sup = int(np.argmax(sup_arr > THRESH) + 1) if np.any(sup_arr > THRESH) else np.nan\n",
    "        t_unsup = int(np.argmax(unsup_arr > THRESH) + 1) if np.any(unsup_arr > THRESH) else np.nan\n",
    "        t_ppi = int(np.argmax(ppi_arr > THRESH) + 1) if np.any(ppi_arr > THRESH) else np.nan\n",
    "        t_adaptive = int(np.argmax(adaptive_ppi_arr > THRESH) + 1) if np.any(adaptive_ppi_arr > THRESH) else np.nan\n",
    "        t_ideal = int(np.argmax(ideal_ppi_arr > THRESH) + 1) if np.any(ideal_ppi_arr > THRESH) else np.nan\n",
    "\n",
    "        results_box[m][\"supervised\"].append(t_sup)\n",
    "        results_box[m][\"unsupervised\"].append(t_unsup)\n",
    "        results_box[m][\"ppi\"].append(t_ppi)\n",
    "        results_box[m][\"adaptive_ppi\"].append(t_adaptive)\n",
    "        results_box[m][\"ideal_ppi\"].append(t_ideal)\n",
    "\n",
    "# Line colour and legend text for every bound family.\n",
    "colors = {\n",
    "    \"supervised\": \"#1f77b4\",\n",
    "    \"unsupervised\": \"#7f7f7f\",\n",
    "    \"ppi\": \"#ff7f0e\",\n",
    "    \"adaptive_ppi\": \"#2ca02c\",\n",
    "    \"ideal_ppi\": \"#9467bd\",\n",
    "}\n",
    "labels = {\n",
    "    \"supervised\": \"SRM\",\n",
    "    \"unsupervised\": \"URM\",\n",
    "    \"adaptive_ppi\": r\"PPRM\",\n",
    "    \"ppi\": \"PPRM\",\n",
    "    \"ideal_ppi\": \"Ideal PPRM\",\n",
    "}\n",
    "\n",
    "fig, ax = plt.subplots(figsize=(6, 6.5))\n",
    "m = methods[0]\n",
    "\n",
    "# Across-experiment mean of each lower-bound curve with a +/- one-std band.\n",
    "for key in (\"supervised\", \"adaptive_ppi\", \"ideal_ppi\"):\n",
    "    arrs = np.array(results[m][key])\n",
    "    if len(arrs) == 0:\n",
    "        continue\n",
    "    mean_curve = np.nanmean(arrs, axis=0)\n",
    "    std_curve = np.nanstd(arrs, axis=0)\n",
    "    steps = np.arange(len(mean_curve))\n",
    "    ax.plot(steps, mean_curve, label=labels[key], color=colors[key])\n",
    "    ax.fill_between(steps, mean_curve - std_curve, mean_curve + std_curve, color=colors[key], alpha=0.2)\n",
    "\n",
    "# Per-experiment curves from running_average_cumulative (defined elsewhere in\n",
    "# the notebook), averaged across experiments; a failure only skips this overlay.\n",
    "try:\n",
    "    traj_ppi_ideal_all = np.array([running_average_cumulative(traj) for traj in all_traj_list])\n",
    "    traj_ppi_ideal_mean = np.nanmean(traj_ppi_ideal_all, axis=0)\n",
    "    traj_ppi_ideal_std = np.nanstd(traj_ppi_ideal_all, axis=0)\n",
    "    steps = np.arange(len(traj_ppi_ideal_mean))\n",
    "    ax.plot(steps, traj_ppi_ideal_mean, linestyle=\"--\", color=\"#9467bd\", linewidth=2, label=\"Running Risk\")\n",
    "    ax.fill_between(\n",
    "        steps,\n",
    "        traj_ppi_ideal_mean - traj_ppi_ideal_std,\n",
    "        traj_ppi_ideal_mean + traj_ppi_ideal_std,\n",
    "        color=\"#9467bd\",\n",
    "        alpha=0.15,\n",
    "    )\n",
    "except Exception as e:\n",
    "    print(f\"[Warning] Failed to plot averaged PPI_Ideal risk trajectory: {e}\")\n",
    "\n",
    "# Axes decoration.\n",
    "font_size = 22\n",
    "ax.axhline(y=THRESH, color=\"red\", linestyle=\"--\", linewidth=2, label=\"Risk Threshold\")\n",
    "ax.set_xlabel(r\"Time Step $t$\", fontsize=font_size)\n",
    "ax.grid(True, linestyle=\":\", linewidth=2, alpha=0.4)\n",
    "ax.set_ylabel(\"Running risk lower bound\", fontsize=font_size+1)\n",
    "ax.legend(fontsize=font_size, loc='upper left', ncol=1)\n",
    "ax.tick_params(axis='both', labelsize=font_size+1)\n",
    "\n",
    "plt.ylim(0.06, 0.16)\n",
    "plt.tight_layout()\n",
    "\n",
    "# Save as PDF next to the other channel results, then display.\n",
    "save_dir = \"Simulations/Results_Channel\"\n",
    "os.makedirs(save_dir, exist_ok=True)\n",
    "save_path_pdf = os.path.join(save_dir, f\"sim_fig_ceq_increase_delta{delta_list[:]}_lowerbound_Brier_self.pdf\")\n",
    "plt.savefig(save_path_pdf, format='pdf', bbox_inches='tight')\n",
    "plt.show()\n",
    "\n",
    "# Box plot of the first-alarm step of each method across experiments.\n",
    "font_size = 22.5\n",
    "plot_params = dict(\n",
    "    title_fontsize=font_size,\n",
    "    xlabel_fontsize=font_size,\n",
    "    ylabel_fontsize=font_size,\n",
    "    xtick_fontsize=font_size,\n",
    "    ytick_fontsize=font_size,\n",
    "    legend_fontsize=font_size,\n",
    "    suptitle_fontsize=font_size,\n",
    "    title_fontweight=\"bold\",\n",
    "    label_fontweight=\"normal\",\n",
    ")\n",
    "\n",
    "colors = [\"#1f77b4\", \"#2ca02c\", \"#9467bd\"]\n",
    "\n",
    "title_map = {\"srm\": \"SRM\", \"pprm\": \"PPRM\", \"Adaptive pprm\": \"PPRM\", \"ideal pprm\": \"Ideal PPRM\", \"urm\": \"URM\"}\n",
    "fig, ax = plt.subplots(figsize=(6, 6.5))\n",
    "m = methods[0]\n",
    "\n",
    "# One box per method: drop NaN entries (runs that never alarmed) and fall back\n",
    "# to a single NaN when nothing is left.\n",
    "data_for_box = []\n",
    "labels_for_box = []\n",
    "for box_key, box_label in ((\"supervised\", \"SRM\"), (\"adaptive_ppi\", \"PPRM\"), (\"ideal_ppi\", \"Ideal PPRM\")):\n",
    "    alarm_times = np.array(results_box[m][box_key], dtype=float)\n",
    "    alarm_times = alarm_times[~np.isnan(alarm_times)]\n",
    "    if alarm_times.size == 0:\n",
    "        alarm_times = np.array([np.nan])\n",
    "    data_for_box.append(alarm_times)\n",
    "    labels_for_box.append(box_label)\n",
    "sup_list, adaptive_ppi_list, ideal_list = data_for_box\n",
    "\n",
    "line_style = dict(linewidth=2)\n",
    "bp = ax.boxplot(\n",
    "    data_for_box,\n",
    "    labels=labels_for_box,\n",
    "    showmeans=True,\n",
    "    patch_artist=True,\n",
    "    boxprops=dict(line_style),\n",
    "    whiskerprops=dict(line_style),\n",
    "    capprops=dict(line_style),\n",
    "    medianprops=dict(line_style),\n",
    "    showfliers=False)\n",
    "\n",
    "for patch, color in zip(bp['boxes'], colors):\n",
    "    patch.set_facecolor(color)\n",
    "    patch.set_alpha(0.6)\n",
    "\n",
    "for mean in bp['means']:\n",
    "    mean.set_markerfacecolor(\"red\")\n",
    "    mean.set_markeredgecolor(\"black\")\n",
    "\n",
    "for axis_name in (\"x\", \"y\"):\n",
    "    ax.tick_params(axis=axis_name, labelsize=plot_params[axis_name + \"tick_fontsize\"])\n",
    "ax.grid(True, linestyle=\":\", linewidth=2, alpha=0.4)\n",
    "ax.set_ylabel(\n",
    "    \"Average time to alarm\",\n",
    "    fontsize=plot_params[\"ylabel_fontsize\"],\n",
    "    fontweight=plot_params[\"label_fontweight\"])\n",
    "\n",
    "plt.tight_layout(rect=[0, 0, 1, 0.96])\n",
    "\n",
    "# Save as PDF next to the other channel results, then display.\n",
    "save_dir = \"Simulations/Results_Channel\"\n",
    "os.makedirs(save_dir, exist_ok=True)\n",
    "save_path_pdf = os.path.join(save_dir, f\"sim_fig_ceq_increase_delta{delta_list[:]}_alarm_time_brier_self.pdf\")\n",
    "plt.savefig(save_path_pdf, format='pdf', bbox_inches='tight')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sequential risk-monitoring experiment: for every repetition, run the\n",
    "# decoder over the delta schedule in `delta_list`, build the supervised and\n",
    "# PPI risk trajectories, and record the CM-EB lower bounds together with the\n",
    "# first step at which each bound exceeds THRESH.\n",
    "N_EXPERIMENTS = 15\n",
    "ALPHA = 0.2\n",
    "THRESH = 0.11\n",
    "T_OPT_RATIO = 0.25\n",
    "VERBOSE = True\n",
    "ETA_MAX =  1.0\n",
    "WINDOW_SIZE = 60\n",
    "eta_fixed = 1.0\n",
    "set_seed(0)\n",
    "methods = [\"CM-EB\"]\n",
    "results = {m: {\"supervised\": [], \"unsupervised\": [], \"ppi\": [], \"adaptive_ppi\": [], \"ideal_ppi\": []} for m in methods}\n",
    "results_box = {m: {\"supervised\": [], \"unsupervised\": [], \"ppi\": [], \"adaptive_ppi\": [], \"ideal_ppi\": []} for m in methods}\n",
    "n_split, n_start = 1, 0\n",
    "all_traj_list = []\n",
    "for exp_id in range(N_EXPERIMENTS):\n",
    "    n_test = 32\n",
    "    n_trails = 200\n",
    "    temperature = 1.0\n",
    "    delta_list = [0.2, 0.8, 0.8, 1.1, 0.9, 0.5]\n",
    "    n_split, n_start = 1, 0\n",
    "    risk_traj_supervised_true = []\n",
    "    risk_traj_supervised_pred = []\n",
    "    risk_traj_unsupervised_pred = []\n",
    "    risk_traj_ppi_ideal = []\n",
    "    eta_seq = []\n",
    "    labeled_losses_hist = []\n",
    "    agent_labeled_losses_hist = []\n",
    "    agent_unlabeled_losses_hist = []\n",
    "    for i, delta in enumerate(delta_list):\n",
    "        print(f\"=== Running experiment for delta = {delta:.2f} ===\")\n",
    "        C = spatial_covariance_ula(M, delta, sigma_asd)\n",
    "\n",
    "        # NOTE(review): the four pilot_* / xhhh values below only feed\n",
    "        # labels_pilot_trials in this cell; they are kept because they also\n",
    "        # advance the global NumPy RNG state.\n",
    "        pilot_bits = np.random.randint(0, 16, size=(nh,))\n",
    "        pilot_syms = CONST[pilot_bits] * np.sqrt(rho)\n",
    "        xhhh = ((np.random.randn(nh) + 1j*np.random.randn(nh)) / np.sqrt(2)) * np.sqrt(rho)\n",
    "        pilot_labels = np.array([np.argmin(np.abs(s / np.sqrt(rho) - CONST)) for s in xhhh])\n",
    "\n",
    "        logits_trials, labels_trials = [], []\n",
    "        logits_pilot_trials, labels_pilot_trials = [], []\n",
    "        for trial in range(n_trails):\n",
    "            # ---- channel draw, pilot transmission and channel estimate ----\n",
    "            h = complex_normal(np.zeros(M, dtype=complex), C)\n",
    "            Zh = (np.random.randn(M, nh) + 1j*np.random.randn(M, nh)) * np.sqrt(sigma2_true / 2)\n",
    "            Yh = np.outer(h, xh) + Zh\n",
    "            yh_stack = vec(Yh)\n",
    "            h_hat = mmse_channel_estimate(C_hat, sigma2_hat, xh, yh_stack, M)\n",
    "            X_pilot = np.column_stack([\n",
    "                [mmse_equalize(h_hat, Yh[:, k], sigma2_hat, rho).real for k in range(nh)],\n",
    "                [mmse_equalize(h_hat, Yh[:, k], sigma2_hat, rho).imag for k in range(nh)]\n",
    "            ]).astype(np.float32)\n",
    "            decoder.eval()\n",
    "            with torch.no_grad():\n",
    "                X_t = torch.from_numpy(X_pilot).to(device)\n",
    "                logits_pilot = decoder(X_t).cpu().numpy()\n",
    "            logits_pilot_trials.append(logits_pilot)\n",
    "            labels_pilot_trials.append(pilot_labels)\n",
    "            # ---- data transmission and decoding ----\n",
    "            idx_tx = np.random.randint(0, Mod, size=(n_test,))\n",
    "            s_tx = CONST[idx_tx] * np.sqrt(rho)\n",
    "            Z_data = (np.random.randn(M, n_test) + 1j*np.random.randn(M, n_test)) * np.sqrt(sigma2_true / 2)\n",
    "            Y_data = h.reshape(-1, 1) @ s_tx.reshape(1, -1) + Z_data\n",
    "            X_test = np.column_stack([\n",
    "                [mmse_equalize(h_hat, Y_data[:, k], sigma2_hat, rho).real for k in range(n_test)],\n",
    "                [mmse_equalize(h_hat, Y_data[:, k], sigma2_hat, rho).imag for k in range(n_test)]\n",
    "            ]).astype(np.float32)\n",
    "            decoder.eval()\n",
    "            with torch.no_grad():\n",
    "                X_t = torch.from_numpy(X_test).to(device)\n",
    "                logits = decoder(X_t).cpu().numpy()\n",
    "            logits_trials.append(logits)\n",
    "            labels_trials.append(idx_tx)\n",
    "            # Temperature-scaled softmax of the decoder logits.\n",
    "            probs_data = np.exp(logits / temperature)\n",
    "            probs_data /= np.sum(probs_data, axis=-1, keepdims=True)\n",
    "            labels_data = idx_tx\n",
    "            Y_true = np.zeros_like(probs_data)\n",
    "            Y_true[np.arange(len(labels_data)), labels_data] = 1\n",
    "\n",
    "            # ---- labeled segment: Brier risk against true and pseudo labels ----\n",
    "            probs_segment = probs_data[n_start:n_start + n_split]\n",
    "            Y_segment_true = Y_true[n_start:n_start + n_split]\n",
    "            risk_supervised_true = 0.5 * np.mean(np.sum((probs_segment - Y_segment_true) ** 2, axis=1))\n",
    "            risk_traj_supervised_true.append(risk_supervised_true)\n",
    "\n",
    "            pseudo_labels_segment = np.argmax(probs_segment, axis=1)\n",
    "            Y_segment_pseudo = np.zeros_like(probs_segment)\n",
    "            Y_segment_pseudo[np.arange(len(pseudo_labels_segment)), pseudo_labels_segment] = 1\n",
    "            risk_supervised_pred = 0.5 * np.mean(np.sum((probs_segment - Y_segment_pseudo) ** 2, axis=1))\n",
    "            risk_traj_supervised_pred.append(risk_supervised_pred)\n",
    "\n",
    "            # ---- remaining samples: Brier risk against pseudo and true labels ----\n",
    "            probs_rest = probs_data[n_start + n_split:]\n",
    "            pseudo_labels_rest = np.argmax(probs_rest, axis=1)\n",
    "            Y_rest_pseudo = np.zeros_like(probs_rest)\n",
    "            Y_rest_pseudo[np.arange(len(pseudo_labels_rest)), pseudo_labels_rest] = 1\n",
    "            risk_unsupervised_pred = 0.5 * np.mean(np.sum((probs_rest - Y_rest_pseudo) ** 2, axis=1))\n",
    "            risk_traj_unsupervised_pred.append(risk_unsupervised_pred)\n",
    "\n",
    "            Y_rest_true = Y_true[n_start + n_split:]\n",
    "            risk_ppi_ideal = 0.5 * np.mean(np.sum((probs_rest - Y_rest_true) ** 2, axis=1))\n",
    "            risk_traj_ppi_ideal.append(risk_ppi_ideal)\n",
    "\n",
    "            # Per-sample Brier losses that drive the adaptive eta.\n",
    "            # softmax_brier_loss_vector applies the temperature-scaled softmax\n",
    "            # itself, so it must receive the raw logits; passing probs_* would\n",
    "            # apply softmax a second time and distort the losses.\n",
    "            logits_segment = logits[n_start:n_start + n_split]\n",
    "            logits_rest = logits[n_start + n_split:]\n",
    "            u_t = softmax_brier_loss_vector(torch.from_numpy(logits_segment), torch.from_numpy(Y_segment_true), temperature).numpy()\n",
    "            u_tilde_l_t = softmax_brier_loss_vector(torch.from_numpy(logits_segment), torch.from_numpy(Y_segment_pseudo), temperature).numpy()\n",
    "            u_tilde_u_t = softmax_brier_loss_vector(torch.from_numpy(logits_rest), torch.from_numpy(Y_rest_pseudo), temperature).numpy()\n",
    "            # ---- compute eta_t using ONLY history ----\n",
    "            # Until more than 5 steps are stored, fall back to ETA_MAX / 2.\n",
    "            if len(labeled_losses_hist) > 5:\n",
    "                eta_t = compute_eta_t(labeled_losses_hist, agent_labeled_losses_hist, agent_unlabeled_losses_hist, ETA_MAX)\n",
    "            else:\n",
    "                eta_t = ETA_MAX / 2\n",
    "            eta_seq.append(eta_t)\n",
    "            # ---- update history ----\n",
    "            labeled_losses_hist.append(u_t)\n",
    "            agent_labeled_losses_hist.append(u_tilde_l_t)\n",
    "            agent_unlabeled_losses_hist.append(u_tilde_u_t)\n",
    "            # ---- sliding window ----\n",
    "            if len(labeled_losses_hist) > WINDOW_SIZE:\n",
    "                labeled_losses_hist.pop(0)\n",
    "                agent_labeled_losses_hist.pop(0)\n",
    "                agent_unlabeled_losses_hist.pop(0)\n",
    "    risk_traj_supervised_true =   np.array(risk_traj_supervised_true)\n",
    "    risk_traj_supervised_pred =   np.array(risk_traj_supervised_pred)\n",
    "    risk_traj_unsupervised_pred = np.array(risk_traj_unsupervised_pred)\n",
    "    risk_traj_ppi_ideal =         np.array(risk_traj_ppi_ideal)\n",
    "    # Explicit array so the per-step weights multiply element-wise by construction.\n",
    "    eta_arr = np.asarray(eta_seq, dtype=float)\n",
    "\n",
    "    traj_sup_true = risk_traj_supervised_true\n",
    "    traj_ppi_ideal = risk_traj_ppi_ideal\n",
    "    # NOTE(review): the proxy trajectory is the supervised one here and\n",
    "    # L1_proxy=1.0 is passed below, so the \"unsupervised\" bound looks like a\n",
    "    # placeholder -- confirm this is intended.\n",
    "    traj_unsup_proxy = risk_traj_supervised_true\n",
    "    traj_ppi_pred = eta_fixed * risk_traj_unsupervised_pred + risk_traj_supervised_true - eta_fixed * risk_traj_supervised_pred\n",
    "    traj_ppi_pred_adaptive = eta_arr * risk_traj_unsupervised_pred + risk_traj_supervised_true - eta_arr * risk_traj_supervised_pred\n",
    "    all_traj_list.append(traj_ppi_ideal)\n",
    "    if VERBOSE:\n",
    "        print(f\"=== Experiment {exp_id+1}/{N_EXPERIMENTS} ===\")\n",
    "    try:\n",
    "        sup_bounds, unsup_bounds = compute_bounds_cmeb(traj_sup_true, traj_unsup_proxy, L1_proxy=1.0, alpha=ALPHA, t_opt_ratio=T_OPT_RATIO)\n",
    "        ppi_bounds =               compute_bounds_cmeb_ppi(traj_ppi_pred, alpha=ALPHA, t_opt_ratio=T_OPT_RATIO, eta_max=eta_fixed)\n",
    "        adaptive_ppi_bounds =      compute_bounds_cmeb_ppi(traj_ppi_pred_adaptive,  alpha=ALPHA, t_opt_ratio=T_OPT_RATIO, eta_max=ETA_MAX)\n",
    "        ideal_ppi_bounds =         compute_bounds_cmeb_ppi(traj_ppi_ideal,  alpha=ALPHA, t_opt_ratio=T_OPT_RATIO, eta_max=eta_fixed)\n",
    "    except Exception as e:\n",
    "        print(f\"[Warning] Skipping exp {exp_id}: compute_bounds() failed — {e}\")\n",
    "        continue\n",
    "    for m in methods:\n",
    "        results[m][\"supervised\"].append(sup_bounds[m])\n",
    "        results[m][\"unsupervised\"].append(unsup_bounds[m])\n",
    "        results[m][\"ppi\"].append(ppi_bounds[m])\n",
    "        results[m][\"ideal_ppi\"].append(ideal_ppi_bounds[m])\n",
    "        results[m][\"adaptive_ppi\"].append(adaptive_ppi_bounds[m])\n",
    "\n",
    "    for m in methods:\n",
    "        sup_arr = np.asarray(sup_bounds.get(m, []))\n",
    "        unsup_arr = np.asarray(unsup_bounds.get(m, []))\n",
    "        adaptive_ppi_arr = np.asarray(adaptive_ppi_bounds.get(m, []))\n",
    "        ppi_arr = np.asarray(ppi_bounds.get(m, []))\n",
    "        ideal_ppi_arr = np.asarray(ideal_ppi_bounds.get(m, []))\n",
    "\n",
    "        # Every bound must have one entry per monitored step.\n",
    "        assert len(sup_arr) == len(ppi_arr) == len(adaptive_ppi_arr) == len(ideal_ppi_arr) == len(unsup_arr) == len(traj_sup_true)\n",
    "\n",
    "        # find threshold crossings (1-based step of the first exceedance, NaN if none)\n",
    "        t_sup = int(np.argmax(sup_arr > THRESH) + 1) if np.any(sup_arr > THRESH) else np.nan\n",
    "        t_unsup = int(np.argmax(unsup_arr > THRESH) + 1) if np.any(unsup_arr > THRESH) else np.nan\n",
    "        t_ppi = int(np.argmax(ppi_arr > THRESH) + 1) if np.any(ppi_arr > THRESH) else np.nan\n",
    "        t_adaptive = int(np.argmax(adaptive_ppi_arr > THRESH) + 1) if np.any(adaptive_ppi_arr > THRESH) else np.nan\n",
    "        t_ideal = int(np.argmax(ideal_ppi_arr > THRESH) + 1) if np.any(ideal_ppi_arr > THRESH) else np.nan\n",
    "\n",
    "        results_box[m][\"supervised\"].append(t_sup)\n",
    "        results_box[m][\"unsupervised\"].append(t_unsup)\n",
    "        results_box[m][\"ppi\"].append(t_ppi)\n",
    "        results_box[m][\"adaptive_ppi\"].append(t_adaptive)\n",
    "        results_box[m][\"ideal_ppi\"].append(t_ideal)\n",
    "\n",
    "# ---- Figure 1: running-risk lower bounds, mean +/- 1 std across experiments ----\n",
    "# Line colors and legend names, keyed like the per-scheme entries of results[m].\n",
    "colors = {\n",
    "    \"supervised\": \"#1f77b4\",\n",
    "    \"unsupervised\": \"#7f7f7f\",\n",
    "    \"ppi\": \"#ff7f0e\",\n",
    "    \"adaptive_ppi\": \"#2ca02c\",\n",
    "    \"ideal_ppi\": \"#9467bd\",\n",
    "}\n",
    "labels = {\n",
    "    \"supervised\": \"SRM\",\n",
    "    \"unsupervised\": \"URM\",\n",
    "    \"adaptive_ppi\": \"PPRM\",\n",
    "    \"ppi\": \"PPRM\",\n",
    "    \"ideal_ppi\": \"Ideal PPRM\",\n",
    "}\n",
    "font_size = 22\n",
    "\n",
    "fig, ax = plt.subplots(figsize=(6, 6.5))\n",
    "m = methods[0]  # only the first entry of `methods` is drawn\n",
    "\n",
    "for key in [\"supervised\", \"adaptive_ppi\", \"ideal_ppi\"]:\n",
    "    arrs = np.array(results[m][key])\n",
    "    if len(arrs) == 0:\n",
    "        # nothing was collected for this scheme, so there is no curve to draw\n",
    "        continue\n",
    "    # NaN-aware statistics over axis 0 (one row per appended experiment)\n",
    "    mean_curve = np.nanmean(arrs, axis=0)\n",
    "    std_curve = np.nanstd(arrs, axis=0)\n",
    "    # NOTE(review): the x-axis is 0-based here, whereas the alarm times stored in\n",
    "    # results_box are 1-based (argmax + 1) -- confirm this offset is intended.\n",
    "    steps = np.arange(len(mean_curve))\n",
    "    ax.plot(steps, mean_curve, label=labels[key], color=colors[key])\n",
    "    ax.fill_between(\n",
    "        steps,\n",
    "        mean_curve - std_curve,\n",
    "        mean_curve + std_curve,\n",
    "        color=colors[key],\n",
    "        alpha=0.2)\n",
    "\n",
    "# Overlay the mean of running_average_cumulative(traj) over all trajectories\n",
    "# (legend entry \"Running Risk\"). This is best effort: a failure is reported and\n",
    "# the rest of the figure is still drawn and saved.\n",
    "try:\n",
    "    traj_ppi_ideal_all = np.array([\n",
    "        running_average_cumulative(traj)\n",
    "        for traj in all_traj_list])\n",
    "    traj_ppi_ideal_mean = np.nanmean(traj_ppi_ideal_all, axis=0)\n",
    "    traj_ppi_ideal_std = np.nanstd(traj_ppi_ideal_all, axis=0)\n",
    "    steps = np.arange(len(traj_ppi_ideal_mean))\n",
    "    ax.plot(\n",
    "        steps,\n",
    "        traj_ppi_ideal_mean,\n",
    "        linestyle=\"--\",\n",
    "        color=\"#9467bd\",\n",
    "        linewidth=2,\n",
    "        label=\"Running Risk\")\n",
    "    ax.fill_between(\n",
    "        steps,\n",
    "        traj_ppi_ideal_mean - traj_ppi_ideal_std,\n",
    "        traj_ppi_ideal_mean + traj_ppi_ideal_std,\n",
    "        color=\"#9467bd\",\n",
    "        alpha=0.15)\n",
    "except Exception as e:\n",
    "    print(f\"[Warning] Failed to plot averaged PPI_Ideal risk trajectory: {e}\")\n",
    "\n",
    "# Axis decoration. Everything goes through `ax` / `fig` explicitly instead of the\n",
    "# pyplot \"current figure\" state, so the calls cannot land on another figure.\n",
    "ax.axhline(y=THRESH, color=\"red\", linestyle=\"--\", linewidth=2, label=\"Risk Threshold\")\n",
    "ax.set_xlabel(r\"Time Step $t$\", fontsize=font_size)\n",
    "ax.set_ylabel(\"Running risk lower bound\", fontsize=font_size+1)\n",
    "ax.grid(True, linestyle=\":\", linewidth=2, alpha=0.4)\n",
    "ax.legend(fontsize=font_size, loc='upper left', ncol=1)\n",
    "ax.tick_params(axis='both', labelsize=font_size+1)\n",
    "ax.set_ylim(0.06, 0.18)\n",
    "fig.tight_layout()\n",
    "\n",
    "# Save as PDF (the output directory is created on demand), then display.\n",
    "save_dir = \"Simulations/Results_Channel\"\n",
    "os.makedirs(save_dir, exist_ok=True)\n",
    "save_path_pdf = os.path.join(save_dir, f\"sim_fig_ceq_increase_delta{delta_list[:]}_lowerbound_Brier_self.pdf\")\n",
    "fig.savefig(save_path_pdf, format='pdf', bbox_inches='tight')\n",
    "plt.show()\n",
    "\n",
    "# ---- Figure 2: box plot of alarm times (first step at which a bound exceeds THRESH) ----\n",
    "font_size = 22.5\n",
    "plot_params = {\n",
    "    \"title_fontsize\": font_size,\n",
    "    \"xlabel_fontsize\": font_size,\n",
    "    \"ylabel_fontsize\": font_size,\n",
    "    \"xtick_fontsize\": font_size,\n",
    "    \"ytick_fontsize\": font_size,\n",
    "    \"legend_fontsize\": font_size,\n",
    "    \"suptitle_fontsize\": font_size,\n",
    "    \"title_fontweight\": \"bold\",\n",
    "    \"label_fontweight\": \"normal\"\n",
    "}\n",
    "\n",
    "# One entry per box: (key in results_box[m], x tick label, face color).\n",
    "box_series = [\n",
    "    (\"supervised\", \"SRM\", \"#1f77b4\"),\n",
    "    (\"adaptive_ppi\", \"PPRM\", \"#2ca02c\"),\n",
    "    (\"ideal_ppi\", \"Ideal PPRM\", \"#9467bd\"),\n",
    "]\n",
    "\n",
    "fig, ax = plt.subplots(figsize=(6, 6.5))\n",
    "m = methods[0]  # only the first entry of `methods` is drawn\n",
    "\n",
    "data_for_box = []\n",
    "labels_for_box = []\n",
    "for key, tick_label, _ in box_series:\n",
    "    alarm_times = np.array(results_box[m][key], dtype=float)\n",
    "    # NaN entries (stored when no crossing was found) are left out of the box statistics\n",
    "    alarm_times = alarm_times[~np.isnan(alarm_times)]\n",
    "    if alarm_times.size == 0:\n",
    "        # keep a placeholder so this scheme still occupies its slot on the x-axis\n",
    "        alarm_times = np.array([np.nan])\n",
    "    data_for_box.append(alarm_times)\n",
    "    labels_for_box.append(tick_label)\n",
    "\n",
    "bp = ax.boxplot(\n",
    "    data_for_box,\n",
    "    showmeans=True,\n",
    "    patch_artist=True,\n",
    "    boxprops=dict(linewidth=2),\n",
    "    whiskerprops=dict(linewidth=2),\n",
    "    capprops=dict(linewidth=2),\n",
    "    medianprops=dict(linewidth=2),\n",
    "    showfliers=False)\n",
    "# Tick labels are set on the axis instead of via boxplot(labels=...): that keyword is\n",
    "# deprecated since Matplotlib 3.9 in favor of tick_labels, which older releases reject.\n",
    "ax.set_xticklabels(labels_for_box)\n",
    "\n",
    "for patch, (_, _, box_color) in zip(bp['boxes'], box_series):\n",
    "    patch.set_facecolor(box_color)\n",
    "    patch.set_alpha(0.6)\n",
    "\n",
    "for mean_marker in bp['means']:\n",
    "    mean_marker.set_markerfacecolor(\"red\")\n",
    "    mean_marker.set_markeredgecolor(\"black\")\n",
    "\n",
    "ax.tick_params(axis=\"x\", labelsize=plot_params[\"xtick_fontsize\"])\n",
    "ax.tick_params(axis=\"y\", labelsize=plot_params[\"ytick_fontsize\"])\n",
    "ax.grid(True, linestyle=\":\", linewidth=2, alpha=0.4)\n",
    "\n",
    "ax.set_ylabel(\"Average time to alarm\",\n",
    "              fontsize=plot_params[\"ylabel_fontsize\"],\n",
    "              fontweight=plot_params[\"label_fontweight\"])\n",
    "\n",
    "fig.tight_layout(rect=[0, 0, 1, 0.96])\n",
    "\n",
    "# Save through `fig` so the file always contains the figure built above,\n",
    "# not whichever figure pyplot currently considers active.\n",
    "save_dir = \"Simulations/Results_Channel\"\n",
    "os.makedirs(save_dir, exist_ok=True)\n",
    "\n",
    "save_path_pdf = os.path.join(save_dir, f\"sim_fig_ceq_increase_delta{delta_list[:]}_alarm_time_brier_self.pdf\")\n",
    "fig.savefig(save_path_pdf, format='pdf', bbox_inches='tight')\n",
    "plt.show()\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "base",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
