{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "provenance": [],
      "collapsed_sections": [
        "k7cB6N4fs8p3",
        "dHQju_Z4sme3"
      ],
      "toc_visible": true
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    },
    "language_info": {
      "name": "python"
    }
  },
  "cells": [
    {
      "cell_type": "markdown",
      "source": [
        "# Common Utility Functions"
      ],
      "metadata": {
        "id": "k7cB6N4fs8p3"
      }
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "mXz8bgh2PlbO"
      },
      "outputs": [],
      "source": [
        "import os, time, gc, math\n",
        "from dataclasses import dataclass\n",
        "from typing import Callable, Dict, Any, Tuple, Optional, List\n",
        "\n",
        "import numpy as np\n",
        "import matplotlib.pyplot as plt\n",
        "import seaborn as sns\n",
        "from matplotlib.ticker import StrMethodFormatter, MaxNLocator\n",
        "\n",
        "try:\n",
        "    import pandas as pd\n",
        "except Exception:\n",
        "    pd = None\n",
        "\n",
        "os.environ.setdefault(\"OMP_NUM_THREADS\", \"1\")\n",
        "os.environ.setdefault(\"MKL_NUM_THREADS\", \"1\")\n",
        "os.environ.setdefault(\"OPENBLAS_NUM_THREADS\", \"1\")\n",
        "os.environ.setdefault(\"VECLIB_MAXIMUM_THREADS\", \"1\")\n",
        "os.environ.setdefault(\"NUMEXPR_NUM_THREADS\", \"1\")\n",
        "\n",
        "\n",
        "sns.set_theme(style=\"whitegrid\", context=\"talk\", font_scale=1.15)\n",
        "plt.rcParams.update({\n",
        "    \"figure.dpi\": 180,\n",
        "    \"savefig.dpi\": 300,\n",
        "    \"pdf.fonttype\": 42,\n",
        "    \"ps.fonttype\": 42,\n",
        "    \"axes.titlesize\": 18,\n",
        "    \"axes.labelsize\": 16,\n",
        "    \"legend.fontsize\": 14,\n",
        "    \"xtick.labelsize\": 13,\n",
        "    \"ytick.labelsize\": 13,\n",
        "    \"lines.linewidth\": 2.6,\n",
        "})\n",
        "\n",
        "def _polyfit_predict(x: np.ndarray, y: np.ndarray, deg: int, x_grid: np.ndarray):\n",
        "    \"\"\"Least-squares polynomial fit y ≈ p(x) of degree=deg.\"\"\"\n",
        "    coeffs = np.polyfit(x, y, deg=deg)\n",
        "    yhat = np.polyval(coeffs, x_grid)\n",
        "    return coeffs, yhat\n",
        "\n",
        "def _invfit_predict_with_intercept(x: np.ndarray, y: np.ndarray, power: int, x_grid: np.ndarray):\n",
        "    \"\"\"Least-squares fit y ≈ a*(1/x)^power + b.\"\"\"\n",
        "    phi = (1.0 / x) ** power\n",
        "    A = np.column_stack([phi, np.ones_like(phi)])\n",
        "    (a, b), *_ = np.linalg.lstsq(A, y, rcond=None)\n",
        "    a = float(a); b = float(b)\n",
        "    yhat = a * ((1.0 / x_grid) ** power) + b\n",
        "    return (a, b), yhat\n",
        "\n",
        "def _r2(y: np.ndarray, yhat: np.ndarray) -> float:\n",
        "    y = np.asarray(y, dtype=float)\n",
        "    yhat = np.asarray(yhat, dtype=float)\n",
        "    ss_res = float(np.sum((y - yhat) ** 2))\n",
        "    ss_tot = float(np.sum((y - np.mean(y)) ** 2)) + 1e-12\n",
        "    return 1.0 - ss_res / ss_tot\n",
        "\n",
        "\n",
        "def pretty_lineplot(\n",
        "    df, xcol, ycol, *, title, xlabel, ylabel,\n",
        "    xticks=None, savebase=None,\n",
        "    fits=None,             # e.g. [(\"poly\",2),(\"poly\",3)] or [(\"inv\",1),(\"inv\",2)]\n",
        "    show_fit_stats=True,   # prints params + R^2\n",
        "):\n",
        "    dfp = df.copy()\n",
        "    dfp[ycol] = pd.to_numeric(dfp[ycol], errors=\"coerce\")\n",
        "    dfp[xcol] = pd.to_numeric(dfp[xcol], errors=\"coerce\")\n",
        "    dfp = dfp.dropna(subset=[xcol, ycol]).sort_values(xcol)\n",
        "\n",
        "    fig, ax = plt.subplots(figsize=(9.2, 5.8), dpi=180)\n",
        "    sns.lineplot(data=dfp, x=xcol, y=ycol, marker=\"o\", ax=ax, label=\"Empirical\")\n",
        "\n",
        "    # ---- Fit overlays ----\n",
        "    if fits:\n",
        "        x = dfp[xcol].to_numpy(dtype=float)\n",
        "        y = dfp[ycol].to_numpy(dtype=float)\n",
        "\n",
        "        x_grid = np.linspace(float(np.min(x)), float(np.max(x)), 250)\n",
        "        # avoid division by zero for inverse fits (shouldn't happen for your gamma/Delta anyway)\n",
        "        x_grid_safe = np.clip(x_grid, 1e-12, None)\n",
        "\n",
        "        for kind, k in fits:\n",
        "            if kind == \"poly\":\n",
        "                deg = int(k)\n",
        "                coeffs, yhat_grid = _polyfit_predict(x, y, deg=deg, x_grid=x_grid)\n",
        "                yhat_train = np.polyval(coeffs, x)\n",
        "                r2 = _r2(y, yhat_train)\n",
        "                ax.plot(x_grid, yhat_grid, linestyle=\"--\", linewidth=2.4, label=f\"Poly deg {deg}\")\n",
        "                if show_fit_stats:\n",
        "                    print(f\"[poly deg {deg}] coeffs={coeffs}, R^2={r2:.4f}\")\n",
        "\n",
        "            elif kind == \"inv\":\n",
        "                power = int(k)\n",
        "                (a, b), yhat_grid = _invfit_predict_with_intercept(x, y, power=power, x_grid=x_grid_safe)\n",
        "                yhat_train = a * ((1.0 / x) ** power) + b\n",
        "                r2 = _r2(y, yhat_train)\n",
        "                lab = (rf\"Fit: $a/x + b$\" if power == 1\n",
        "                       else rf\"Fit: $a/x^{power} + b$\")\n",
        "                ax.plot(x_grid, yhat_grid, linestyle=\"--\", linewidth=2.4, label=lab)\n",
        "                if show_fit_stats:\n",
        "                    print(f\"[inv power {power}] a={a:.4g}, b={b:.4g}, R^2={r2:.4f}\")\n",
        "\n",
        "            else:\n",
        "                raise ValueError(f\"Unknown fit kind: {kind}\")\n",
        "\n",
        "    ax.set_title(title, pad=12)\n",
        "    ax.set_xlabel(xlabel)\n",
        "    ax.set_ylabel(ylabel)\n",
        "\n",
        "    if xticks is not None:\n",
        "        ax.set_xticks(xticks)\n",
        "        ax.set_xticklabels([str(t) for t in xticks])\n",
        "\n",
        "    ax.yaxis.set_major_locator(MaxNLocator(integer=True))\n",
        "    ax.yaxis.set_major_formatter(StrMethodFormatter(\"{x:,.0f}\"))\n",
        "    sns.despine(ax=ax)\n",
        "    ax.legend(frameon=True, loc=\"best\")\n",
        "    plt.tight_layout()\n",
        "\n",
        "    if savebase is not None:\n",
        "        fig.savefig(f\"{savebase}.pdf\", bbox_inches=\"tight\")\n",
        "        fig.savefig(f\"{savebase}.png\", bbox_inches=\"tight\")\n",
        "        print(\"Saved:\", f\"{savebase}.pdf and {savebase}.png\")\n",
        "\n",
        "    plt.show()\n",
        "\n",
        "# -------------------------\n",
        "# small utilities\n",
        "# -------------------------\n",
        "def now() -> str:\n",
        "    return time.strftime(\"%Y-%m-%d %H:%M:%S\")\n",
        "\n",
        "def log_print(verbose: bool, msg: str):\n",
        "    if verbose:\n",
        "        print(f\"[{now()}] {msg}\", flush=True)\n",
        "\n",
        "def fmt_bytes(nbytes: int) -> str:\n",
        "    if nbytes < 1024: return f\"{nbytes} B\"\n",
        "    if nbytes < 1024**2: return f\"{nbytes/1024:.2f} KB\"\n",
        "    if nbytes < 1024**3: return f\"{nbytes/1024**2:.2f} MB\"\n",
        "    return f\"{nbytes/1024**3:.2f} GB\"\n",
        "\n",
        "def clamp_int(x: int, lo: int, hi: int) -> int:\n",
        "    return max(lo, min(hi, int(x)))\n",
        "\n",
        "def round_to_step(n: int, step: int, base: int) -> int:\n",
        "    if step <= 0:\n",
        "        return int(n)\n",
        "    k = int(round((n - base) / step))\n",
        "    return int(base + k * step)\n",
        "\n",
        "\n",
        "# =========================\n",
        "# Dataset 1: Spiked identity model\n",
        "#   Σ = (1-γ)I + γ v v^T\n",
        "# =========================\n",
        "def make_sparse_spike(d: int, s: int, seed: int) -> np.ndarray:\n",
        "    rng = np.random.default_rng(seed)\n",
        "    v = np.zeros(d, dtype=np.float32)\n",
        "    support = rng.choice(d, size=s, replace=False)\n",
        "    signs = rng.choice([-1.0, 1.0], size=s).astype(np.float32)\n",
        "    v[support] = signs / np.sqrt(np.float32(s))\n",
        "    return v\n",
        "\n",
        "def sample_spiked_gaussian(n: int, d: int, v: np.ndarray, gamma: float, seed: int) -> np.ndarray:\n",
        "    assert 0.0 < gamma < 1.0\n",
        "    rng = np.random.default_rng(seed)\n",
        "    a = np.float32(np.sqrt(1.0 - gamma))\n",
        "    b = np.float32(np.sqrt(gamma))\n",
        "    Z = rng.standard_normal(size=(n, d)).astype(np.float32)\n",
        "    X = a * Z\n",
        "    y = rng.standard_normal(size=(n, 1)).astype(np.float32)\n",
        "    X += (b * y) * v.reshape(1, -1)\n",
        "    return X\n",
        "\n",
        "\n",
        "# =========================\n",
        "# Dataset 2: Lemma-4 worst-case (embedded into large d)\n",
        "#   Base block dimension d0 = 2s-1\n",
        "#   Σ_block(γ) = v v^T + (1-γ) Σ_{r=1}^{s-1} u_r u_r^T\n",
        "#   Extra coords (d-d0) get isotropic variance τ = 1-γ\n",
        "# =========================\n",
        "def _householder_map(x: np.ndarray, t: np.ndarray) -> np.ndarray:\n",
        "    x = x.astype(np.float64)\n",
        "    t = t.astype(np.float64)\n",
        "    if np.allclose(x, t, atol=1e-12):\n",
        "        return np.eye(x.shape[0], dtype=np.float64)\n",
        "    u = x - t\n",
        "    nu = np.linalg.norm(u)\n",
        "    if nu < 1e-15:\n",
        "        return np.eye(x.shape[0], dtype=np.float64)\n",
        "    u = u / nu\n",
        "    return np.eye(x.shape[0], dtype=np.float64) - 2.0 * np.outer(u, u)\n",
        "\n",
        "def lemma16_basis(dim: int) -> np.ndarray:\n",
        "    # deterministic orthonormal basis with u1 = 1/sqrt(dim) * 1, and other cols have first coord = 1/sqrt(dim)\n",
        "    assert dim >= 2\n",
        "    u1 = np.ones(dim, dtype=np.float64) / math.sqrt(dim)\n",
        "    e1 = np.zeros(dim, dtype=np.float64); e1[0] = 1.0\n",
        "\n",
        "    w = e1 - (1.0 / math.sqrt(dim)) * u1\n",
        "    w_norm = np.linalg.norm(w)\n",
        "\n",
        "    B = np.zeros((dim, dim - 1), dtype=np.float64)\n",
        "    for j in range(dim - 1):\n",
        "        B[0, j] = -1.0\n",
        "        B[j + 1, j] = 1.0\n",
        "\n",
        "    V, _ = np.linalg.qr(B)\n",
        "    x = V.T @ (w / (w_norm + 1e-18))\n",
        "    t = np.ones(dim - 1, dtype=np.float64) / math.sqrt(dim - 1)\n",
        "    Q = _householder_map(x, t)\n",
        "\n",
        "    U_rest = V @ Q.T\n",
        "    U = np.column_stack([u1, U_rest])\n",
        "\n",
        "    if not np.allclose(U.T @ U, np.eye(dim), atol=1e-8):\n",
        "        raise RuntimeError(\"lemma16_basis: orthonormality failed\")\n",
        "    if not np.allclose(U[0, 1:], 1.0 / math.sqrt(dim), atol=1e-6):\n",
        "        raise RuntimeError(\"lemma16_basis: first-coordinate property failed\")\n",
        "\n",
        "    return U.astype(np.float64)\n",
        "\n",
        "def make_lemma4_vectors_embedded(d: int, s: int) -> Tuple[np.ndarray, np.ndarray, int]:\n",
        "    assert s >= 2\n",
        "    d0 = 2 * s - 1\n",
        "    assert d >= d0, f\"Need d >= 2s-1 = {d0}, got d={d}\"\n",
        "\n",
        "    U_small = lemma16_basis(s)         # columns: [v_small, g_1..g_{s-1}]\n",
        "    v_small = U_small[:, 0]\n",
        "    G = U_small[:, 1:]                 # g_r\n",
        "\n",
        "    v = np.zeros(d, dtype=np.float64)\n",
        "    v[:s] = v_small                    # = 1/sqrt(s) * 1_s\n",
        "\n",
        "    U = np.zeros((d, s - 1), dtype=np.float64)\n",
        "    for r in range(s - 1):\n",
        "        u = np.zeros(d, dtype=np.float64)\n",
        "        u[:s] += (1.0 / math.sqrt(2.0)) * G[:, r]\n",
        "        u[s + r] += 1.0 / math.sqrt(2.0)\n",
        "        U[:, r] = u\n",
        "\n",
        "    # sanity\n",
        "    if not np.allclose(U.T @ U, np.eye(s - 1), atol=1e-8):\n",
        "        raise RuntimeError(\"Lemma-4 embed: u_r not orthonormal\")\n",
        "    if not np.allclose(U.T @ v, np.zeros(s - 1), atol=1e-8):\n",
        "        raise RuntimeError(\"Lemma-4 embed: u_r not orthogonal to v\")\n",
        "    if not np.allclose(np.linalg.norm(v), 1.0, atol=1e-10):\n",
        "        raise RuntimeError(\"Lemma-4 embed: v not unit\")\n",
        "\n",
        "    return v, U, d0\n",
        "\n",
        "def sample_lemma4_embedded(n: int, d: int, s: int, gamma: float, seed: int) -> Tuple[np.ndarray, np.ndarray]:\n",
        "    assert 0.0 < gamma < 1.0\n",
        "    v, U, d0 = make_lemma4_vectors_embedded(d=d, s=s)\n",
        "    alpha = 1.0 - float(gamma)   # <- \"second eigenvalue\" knob\n",
        "\n",
        "    rng = np.random.default_rng(seed)\n",
        "\n",
        "    z0 = rng.standard_normal(size=(n, 1)).astype(np.float32)\n",
        "    zr = rng.standard_normal(size=(n, s - 1)).astype(np.float32)\n",
        "    X = z0 * v.reshape(1, -1).astype(np.float32)\n",
        "    X += (math.sqrt(alpha) * (zr @ U.T.astype(np.float32)))\n",
        "\n",
        "    # \"add an identity\" on extra dims at the same scale alpha = 1-γ\n",
        "    if d > d0:\n",
        "        Zextra = rng.standard_normal(size=(n, d - d0)).astype(np.float32)\n",
        "        X[:, d0:] = math.sqrt(alpha) * Zextra\n",
        "\n",
        "    return X, v.astype(np.float64)\n",
        "\n",
        "\n",
        "# =========================\n",
        "# RTPM pieces\n",
        "# =========================\n",
        "def hatSigma_times_U(X: np.ndarray, U: np.ndarray) -> np.ndarray:\n",
        "    n = X.shape[0]\n",
        "    return (X.T @ (X @ U)) / float(n)\n",
        "\n",
        "def normalize_columns(U: np.ndarray, eps: float = 1e-12) -> np.ndarray:\n",
        "    norms = np.linalg.norm(U, axis=0) + eps\n",
        "    return U / norms\n",
        "\n",
        "def top_r_truncate_columns(W: np.ndarray, r: int) -> np.ndarray:\n",
        "    d, m = W.shape\n",
        "    if r >= d:\n",
        "        return W\n",
        "    absW = np.abs(W)\n",
        "    idx = np.argpartition(absW, -r, axis=0)[-r:, :]\n",
        "    out = np.zeros_like(W)\n",
        "    cols = np.arange(m)[None, :]\n",
        "    out[idx, cols] = W[idx, cols]\n",
        "    return out\n",
        "\n",
        "def rtpm_early_stop(\n",
        "    X: np.ndarray,\n",
        "    r: int,\n",
        "    T_max: int,\n",
        "    restart_indices: np.ndarray,\n",
        "    v_true: np.ndarray,\n",
        "    Delta_eff: float,\n",
        "    block_size: int = 512,\n",
        "    check_every: int = 5,\n",
        ") -> dict:\n",
        "    n, d = X.shape\n",
        "    m = int(restart_indices.shape[0])\n",
        "\n",
        "    U = np.zeros((d, m), dtype=np.float32)\n",
        "    U[restart_indices, np.arange(m)] = 1.0\n",
        "\n",
        "    v = v_true.astype(np.float64)\n",
        "    v = v / (np.linalg.norm(v) + 1e-12)\n",
        "\n",
        "    best_sin2 = 1.0\n",
        "    best_idx = 0\n",
        "\n",
        "    for it in range(1, T_max + 1):\n",
        "        W = np.empty_like(U)\n",
        "        for j0 in range(0, m, block_size):\n",
        "            j1 = min(m, j0 + block_size)\n",
        "            W[:, j0:j1] = hatSigma_times_U(X, U[:, j0:j1])\n",
        "\n",
        "        U = top_r_truncate_columns(W, r)\n",
        "        U = normalize_columns(U)\n",
        "        del W\n",
        "\n",
        "        if it == 1 or it % check_every == 0 or it == T_max:\n",
        "            c = (U.astype(np.float64).T @ v)\n",
        "            c2 = c * c\n",
        "            j = int(np.argmax(c2))\n",
        "            sin2 = float(max(0.0, 1.0 - c2[j]))\n",
        "            if sin2 < best_sin2:\n",
        "                best_sin2 = sin2\n",
        "                best_idx = j\n",
        "            if best_sin2 <= Delta_eff:\n",
        "                return {\"sin2_best\": best_sin2, \"t_stop\": it, \"best_restart\": best_idx}\n",
        "\n",
        "    return {\"sin2_best\": best_sin2, \"t_stop\": T_max, \"best_restart\": best_idx}\n",
        "\n",
        "def choose_rtpm_hyperparams(d: int, s: int, Delta: float, r_mult: float, T_min: int) -> Tuple[int, int]:\n",
        "    r = int(min(d, max(s, math.ceil(r_mult * s))))\n",
        "    T_max = int(T_min)   # keep as in your original experiments\n",
        "    return r, T_max\n",
        "\n",
        "\n",
        "# =========================\n",
        "# Warm-start linear grid search over n\n",
        "# =========================\n",
        "def eval_n_one_trial(\n",
        "    n: int,\n",
        "    *,\n",
        "    dataset: str,          # \"spiked\" or \"lemma4\"\n",
        "    d: int,\n",
        "    s: int,\n",
        "    gamma: float,\n",
        "    Delta_eff: float,\n",
        "    seed_setting: int,\n",
        "    seed_data: int,\n",
        "    restart_indices: np.ndarray,\n",
        "    r: int,\n",
        "    T_max: int,\n",
        "    block_size: int,\n",
        "    check_every: int,\n",
        "    verbose: bool,\n",
        ") -> dict:\n",
        "    t0 = time.perf_counter()\n",
        "\n",
        "    if dataset == \"spiked\":\n",
        "        v = make_sparse_spike(d, s, seed_setting + 11_000_000).astype(np.float64)\n",
        "        X = sample_spiked_gaussian(n, d, v.astype(np.float32), gamma, seed_data)\n",
        "        v_true = v\n",
        "    elif dataset == \"lemma4\":\n",
        "        X, v_true = sample_lemma4_embedded(n=n, d=d, s=s, gamma=gamma, seed=seed_data)\n",
        "    else:\n",
        "        raise ValueError(\"dataset must be 'spiked' or 'lemma4'\")\n",
        "\n",
        "    out = rtpm_early_stop(\n",
        "        X=X, r=r, T_max=T_max,\n",
        "        restart_indices=restart_indices,\n",
        "        v_true=v_true,\n",
        "        Delta_eff=Delta_eff,\n",
        "        block_size=block_size,\n",
        "        check_every=check_every,\n",
        "    )\n",
        "    sin2 = float(out[\"sin2_best\"])\n",
        "    elapsed = float(time.perf_counter() - t0)\n",
        "\n",
        "    log_print(verbose, f\"  n={n}: sin^2={sin2:.6f}, good={sin2 <= Delta_eff}, t_stop={out['t_stop']}/{T_max}, time={elapsed:.2f}s\")\n",
        "    del X\n",
        "    gc.collect()\n",
        "\n",
        "    return {\"n\": int(n), \"sin2\": sin2, \"good\": bool(sin2 <= Delta_eff)}\n",
        "\n",
        "def find_n_scale_linear_warm(\n",
        "    *,\n",
        "    dataset: str,\n",
        "    d: int,\n",
        "    s: int,\n",
        "    gamma: float,\n",
        "    Delta: float,\n",
        "    m_restarts: int,\n",
        "    n_min: int,\n",
        "    n_max: int,\n",
        "    n_step: int,\n",
        "    start_n: int,\n",
        "    seed_setting: int,\n",
        "    r_mult: float,\n",
        "    T_min: int,\n",
        "    block_size: int,\n",
        "    check_every: int,\n",
        "    margin_frac: float,\n",
        "    verbose: bool,\n",
        ") -> dict:\n",
        "    # Restarts (coordinate basis) — same idea as before\n",
        "    rng = np.random.default_rng(seed_setting + 22_000_000)\n",
        "    replace = (m_restarts > d)\n",
        "    restart_indices = rng.choice(d, size=m_restarts, replace=replace).astype(np.int64)\n",
        "\n",
        "    r, T_max = choose_rtpm_hyperparams(d=d, s=s, Delta=Delta, r_mult=r_mult, T_min=T_min)\n",
        "    Delta_eff = (1.0 - margin_frac) * float(Delta)\n",
        "    seed_data = int(seed_setting + 31_000_000)  # fixed per setting (same as original)\n",
        "\n",
        "    log_print(verbose, \"------------------------------------------------------------\")\n",
        "    log_print(verbose, f\"{dataset} | d={d}, s={s}, gamma={gamma}, Delta={Delta}\")\n",
        "    log_print(verbose, f\"Grid: n in [{n_min},{n_max}] step={n_step}, warm start {start_n}\")\n",
        "    log_print(verbose, f\"RTPM: r={r}, T_max={T_max}, m_restarts={m_restarts}, Delta_eff={Delta_eff:.4f}\")\n",
        "\n",
        "    cache: Dict[int, dict] = {}\n",
        "\n",
        "    def evaluate(n: int) -> dict:\n",
        "        if n in cache:\n",
        "            rec = cache[n]\n",
        "            log_print(verbose, f\"  n={n}: (cached) sin^2={rec['sin2']:.6f}, good={rec['good']}\")\n",
        "            return rec\n",
        "        rec = eval_n_one_trial(\n",
        "            n=n,\n",
        "            dataset=dataset,\n",
        "            d=d, s=s, gamma=gamma,\n",
        "            Delta_eff=Delta_eff,\n",
        "            seed_setting=seed_setting,\n",
        "            seed_data=seed_data,\n",
        "            restart_indices=restart_indices,\n",
        "            r=r, T_max=T_max,\n",
        "            block_size=block_size,\n",
        "            check_every=check_every,\n",
        "            verbose=verbose,\n",
        "        )\n",
        "        cache[n] = rec\n",
        "        return rec\n",
        "\n",
        "    start_n = clamp_int(start_n, n_min, n_max)\n",
        "    start_n = round_to_step(start_n, n_step, n_min)\n",
        "    start_n = clamp_int(start_n, n_min, n_max)\n",
        "\n",
        "    rec0 = evaluate(start_n)\n",
        "\n",
        "    if rec0[\"good\"]:\n",
        "        # walk down\n",
        "        n_pass = start_n\n",
        "        n = start_n - n_step\n",
        "        while n >= n_min:\n",
        "            rec = evaluate(n)\n",
        "            if rec[\"good\"]:\n",
        "                n_pass = n\n",
        "                n -= n_step\n",
        "            else:\n",
        "                break\n",
        "        return {\"n_scale\": n_pass, \"r\": r, \"T_max\": T_max}\n",
        "    else:\n",
        "        # walk up\n",
        "        n = start_n + n_step\n",
        "        while n <= n_max:\n",
        "            rec = evaluate(n)\n",
        "            if rec[\"good\"]:\n",
        "                return {\"n_scale\": n, \"r\": r, \"T_max\": T_max}\n",
        "            n += n_step\n",
        "        return {\"n_scale\": n_max, \"r\": r, \"T_max\": T_max}"
      ]
    },
    {
      "cell_type": "markdown",
      "source": [
        "# Spike Model Experiments"
      ],
      "metadata": {
        "id": "dHQju_Z4sme3"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "# =========================\n",
        "# Cell 2/7: SPIKED - Experiment A (vary s)\n",
        "# sweep: d=2500, gamma=0.8, Delta=0.1, s in {4,8,16,20}, n in [400,2400] step 100\n",
        "# =========================\n",
        "\n",
        "dataset = \"spiked\"\n",
        "d = 2500\n",
        "gamma0 = 0.8\n",
        "Delta = 0.1\n",
        "s_list = [4, 8, 16, 20]\n",
        "\n",
        "n_min, n_max, n_step = 400, 2400, 100\n",
        "start_n = n_min\n",
        "\n",
        "m_restarts = 1000\n",
        "r_mult = 10.0\n",
        "T_min = 100\n",
        "block_size = 512\n",
        "check_every = 5\n",
        "margin_frac = 0.02\n",
        "verbose = True\n",
        "\n",
        "rows = []\n",
        "prev_n = start_n\n",
        "\n",
        "for s in s_list:\n",
        "    seed_setting = 100_000 + int(s)\n",
        "    out = find_n_scale_linear_warm(\n",
        "        dataset=dataset,\n",
        "        d=d, s=s, gamma=gamma0, Delta=Delta,\n",
        "        m_restarts=m_restarts,\n",
        "        n_min=n_min, n_max=n_max, n_step=n_step,\n",
        "        start_n=prev_n,\n",
        "        seed_setting=seed_setting,\n",
        "        r_mult=r_mult, T_min=T_min,\n",
        "        block_size=block_size, check_every=check_every,\n",
        "        margin_frac=margin_frac, verbose=verbose,\n",
        "    )\n",
        "    rows.append({\"s\": s, \"n_scale\": out[\"n_scale\"], \"r\": out[\"r\"], \"T_max\": out[\"T_max\"]})\n",
        "    if out[\"n_scale\"] is not None:\n",
        "        prev_n = out[\"n_scale\"]\n",
        "\n",
        "df_spA = pd.DataFrame(rows)\n",
        "display(df_spA)\n",
        "\n",
        "pretty_lineplot(\n",
        "    df_spA, \"s\", \"n_scale\",\n",
        "    title=fr\"Spiked: $n_{{scale}}$ vs $s$ (d={d}, $\\gamma$={gamma0}, $\\Delta$={Delta})\",\n",
        "    xlabel=r\"$s$ (sparsity)\",\n",
        "    ylabel=r\"$n_{\\mathrm{scale}}$\",\n",
        "    xticks=s_list,\n",
        "    savebase=\"spiked_expA_vs_s\",\n",
        "    fits=[(\"poly\",2), (\"poly\",3)]\n",
        ")"
      ],
      "metadata": {
        "id": "1uOZzNdtsik2"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# =========================\n",
        "# Cell 3/7: SPIKED - Experiment B (vary gamma)\n",
        "# sweep: d=2500, s=5, Delta=0.1, gamma in {0.7,0.75,0.85,0.9}, n in [200,2000] step 50\n",
        "# =========================\n",
        "\n",
        "dataset = \"spiked\"\n",
        "d = 2500\n",
        "s0 = 5\n",
        "Delta = 0.1\n",
        "gamma_list = [0.7, 0.75, 0.85, 0.9]\n",
        "\n",
        "n_min, n_max, n_step = 200, 2000, 50\n",
        "start_n = n_max\n",
        "\n",
        "m_restarts = 1000\n",
        "r_mult = 10.0\n",
        "T_min = 100\n",
        "block_size = 512\n",
        "check_every = 5\n",
        "margin_frac = 0.02\n",
        "verbose = True\n",
        "\n",
        "rows = []\n",
        "prev_n = start_n\n",
        "\n",
        "for g in gamma_list:\n",
        "    seed_setting = 200_000 + int(10_000 * g)\n",
        "    out = find_n_scale_linear_warm(\n",
        "        dataset=dataset,\n",
        "        d=d, s=s0, gamma=g, Delta=Delta,\n",
        "        m_restarts=m_restarts,\n",
        "        n_min=n_min, n_max=n_max, n_step=n_step,\n",
        "        start_n=prev_n,\n",
        "        seed_setting=seed_setting,\n",
        "        r_mult=r_mult, T_min=T_min,\n",
        "        block_size=block_size, check_every=check_every,\n",
        "        margin_frac=margin_frac, verbose=verbose,\n",
        "    )\n",
        "    rows.append({\"gamma\": g, \"n_scale\": out[\"n_scale\"], \"r\": out[\"r\"], \"T_max\": out[\"T_max\"]})\n",
        "    if out[\"n_scale\"] is not None:\n",
        "        prev_n = out[\"n_scale\"]\n",
        "\n",
        "df_spB = pd.DataFrame(rows)\n",
        "display(df_spB)\n",
        "\n",
        "pretty_lineplot(\n",
        "    df_spB, \"gamma\", \"n_scale\",\n",
        "    title=fr\"Spiked: $n_{{scale}}$ vs $\\gamma$ (d={d}, s={s0}, $\\Delta$={Delta})\",\n",
        "    xlabel=r\"$\\gamma$ (eigengap)\",\n",
        "    ylabel=r\"$n_{\\mathrm{scale}}$\",\n",
        "    xticks=gamma_list,\n",
        "    savebase=\"spiked_expB_vs_gamma\",\n",
        "    fits=[(\"inv\",1), (\"inv\",2)]\n",
        ")"
      ],
      "metadata": {
        "id": "4T7NOskMsq3G"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# =========================\n",
        "# Cell 4/7: SPIKED - Experiment C (vary Delta)\n",
        "# weep: d=2500, s=5, gamma=0.8, Delta in {0.1,0.15,0.2,0.4,0.5}, n in [50,1000] step 50\n",
        "# =========================\n",
        "\n",
        "dataset = \"spiked\"\n",
        "d = 2500\n",
        "s0 = 5\n",
        "gamma0 = 0.8\n",
        "Delta_list = [0.1, 0.15, 0.2, 0.4, 0.5]\n",
        "\n",
        "n_min, n_max, n_step = 50, 1000, 50\n",
        "start_n = n_max\n",
        "\n",
        "m_restarts = 1000\n",
        "r_mult = 10.0\n",
        "T_min = 100\n",
        "block_size = 512\n",
        "check_every = 5\n",
        "margin_frac = 0.02\n",
        "verbose = True\n",
        "\n",
        "rows = []\n",
        "prev_n = start_n\n",
        "\n",
        "for D in Delta_list:\n",
        "    seed_setting = 300_000 + int(10_000 * D)\n",
        "    out = find_n_scale_linear_warm(\n",
        "        dataset=dataset,\n",
        "        d=d, s=s0, gamma=gamma0, Delta=D,\n",
        "        m_restarts=m_restarts,\n",
        "        n_min=n_min, n_max=n_max, n_step=n_step,\n",
        "        start_n=prev_n,\n",
        "        seed_setting=seed_setting,\n",
        "        r_mult=r_mult, T_min=T_min,\n",
        "        block_size=block_size, check_every=check_every,\n",
        "        margin_frac=margin_frac, verbose=verbose,\n",
        "    )\n",
        "    rows.append({\"Delta\": D, \"n_scale\": out[\"n_scale\"], \"r\": out[\"r\"], \"T_max\": out[\"T_max\"]})\n",
        "    if out[\"n_scale\"] is not None:\n",
        "        prev_n = out[\"n_scale\"]\n",
        "\n",
        "df_spC = pd.DataFrame(rows)\n",
        "display(df_spC)\n",
        "\n",
        "pretty_lineplot(\n",
        "    df_spC, \"Delta\", \"n_scale\",\n",
        "    title=fr\"Spiked: $n_{{scale}}$ vs $\\Delta$ (d={d}, s={s0}, $\\gamma$={gamma0})\",\n",
        "    xlabel=r\"$\\Delta$ (target $\\sin^2$ error)\",\n",
        "    ylabel=r\"$n_{\\mathrm{scale}}$\",\n",
        "    xticks=Delta_list,\n",
        "    savebase=\"spiked_expC_vs_Delta\",\n",
        "    fits=[(\"inv\",1), (\"inv\",2)]\n",
        ")"
      ],
      "metadata": {
        "id": "w3tDP1Iosr9Y"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "# Greedy Correlation Counterexample Example Experiments"
      ],
      "metadata": {
        "id": "13f3mvWTsuEi"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "# =========================\n",
        "# Cell 5/7: LEMMA-4 - Experiment A (vary s)\n",
        "# sweep: d=2500, gamma=0.8, Delta=0.1, s in {4,8,16,20}, n in [400,2400] step 100\n",
        "# =========================\n",
        "\n",
        "# --- problem setup: everything is held fixed except s ---\n",
        "dataset = \"lemma4\"\n",
        "d, gamma0, Delta = 2500, 0.8, 0.1\n",
        "s_list = [4, 8, 16, 20]\n",
        "\n",
        "# --- n range / step forwarded to find_n_scale_linear_warm; first run starts at n_min ---\n",
        "n_min, n_max, n_step = 400, 2400, 100\n",
        "start_n = n_min\n",
        "\n",
        "# --- remaining settings, forwarded unchanged to find_n_scale_linear_warm ---\n",
        "m_restarts, r_mult, T_min = 1000, 10.0, 100\n",
        "block_size, check_every = 512, 5\n",
        "margin_frac, verbose = 0.02, True\n",
        "\n",
        "rows = []\n",
        "prev_n = start_n\n",
        "\n",
        "for s in s_list:\n",
        "    # one distinct seed per value of s\n",
        "    seed_setting = 400_000 + int(s)\n",
        "    out = find_n_scale_linear_warm(\n",
        "        dataset=dataset,\n",
        "        d=d,\n",
        "        s=s,\n",
        "        gamma=gamma0,\n",
        "        Delta=Delta,\n",
        "        m_restarts=m_restarts,\n",
        "        n_min=n_min,\n",
        "        n_max=n_max,\n",
        "        n_step=n_step,\n",
        "        start_n=prev_n,\n",
        "        seed_setting=seed_setting,\n",
        "        r_mult=r_mult,\n",
        "        T_min=T_min,\n",
        "        block_size=block_size,\n",
        "        check_every=check_every,\n",
        "        margin_frac=margin_frac,\n",
        "        verbose=verbose,\n",
        "    )\n",
        "    rows.append({\"s\": s, **{key: out[key] for key in (\"n_scale\", \"r\", \"T_max\")}})\n",
        "    # warm start: the next run begins at the n_scale just found; unchanged when it is None\n",
        "    prev_n = prev_n if out[\"n_scale\"] is None else out[\"n_scale\"]\n",
        "\n",
        "df_l4A = pd.DataFrame(rows)\n",
        "display(df_l4A)\n",
        "\n",
        "pretty_lineplot(\n",
        "    df_l4A, \"s\", \"n_scale\",\n",
        "    savebase=\"lemma4_expA_vs_s\",\n",
        "    title=fr\"Greedy Correlation Counterexample: $n_{{scale}}$ vs $s$ (d={d}, $\\gamma$={gamma0}, $\\Delta$={Delta})\",\n",
        "    xlabel=r\"$s$ (sparsity)\",\n",
        "    ylabel=r\"$n_{\\mathrm{scale}}$\",\n",
        "    xticks=s_list,\n",
        "    fits=[(\"poly\", 2), (\"poly\", 3)],\n",
        ")"
      ],
      "metadata": {
        "id": "vTHJ8OmqszTN"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# =========================\n",
        "# Cell 6/7: LEMMA-4 - Experiment B (vary gamma)\n",
        "# sweep: d=2500, s=5, Delta=0.1, gamma in {0.7,0.75,0.85,0.9}, n in [200,2000] step 50\n",
        "# =========================\n",
        "\n",
        "# Fixed settings for this experiment; only gamma changes between runs.\n",
        "dataset = \"lemma4\"\n",
        "d = 2500\n",
        "s0 = 5\n",
        "Delta = 0.1\n",
        "gamma_list = [0.7, 0.75, 0.85, 0.9]\n",
        "\n",
        "# n range and step forwarded to find_n_scale_linear_warm; the first run starts at n_max.\n",
        "n_min, n_max, n_step = 200, 2000, 50\n",
        "start_n = n_max\n",
        "\n",
        "# Remaining settings, forwarded unchanged to find_n_scale_linear_warm.\n",
        "m_restarts = 1000\n",
        "r_mult = 10.0\n",
        "T_min = 100\n",
        "block_size = 512\n",
        "check_every = 5\n",
        "margin_frac = 0.02\n",
        "verbose = True\n",
        "\n",
        "rows = []\n",
        "prev_n = start_n\n",
        "\n",
        "for g in gamma_list:\n",
        "    # One distinct integer seed per gamma. round() is applied before int() because\n",
        "    # 10_000 * g is a float product: if it lands just below an integer, a bare int()\n",
        "    # would truncate to the neighbouring seed. Seeds for the listed gammas are unchanged.\n",
        "    seed_setting = 500_000 + int(round(10_000 * g))\n",
        "    out = find_n_scale_linear_warm(\n",
        "        dataset=dataset,\n",
        "        d=d, s=s0, gamma=g, Delta=Delta,\n",
        "        m_restarts=m_restarts,\n",
        "        n_min=n_min, n_max=n_max, n_step=n_step,\n",
        "        start_n=prev_n,\n",
        "        seed_setting=seed_setting,\n",
        "        r_mult=r_mult, T_min=T_min,\n",
        "        block_size=block_size, check_every=check_every,\n",
        "        margin_frac=margin_frac, verbose=verbose,\n",
        "    )\n",
        "    rows.append({\"gamma\": g, \"n_scale\": out[\"n_scale\"], \"r\": out[\"r\"], \"T_max\": out[\"T_max\"]})\n",
        "    # Warm start: the next gamma begins at the n_scale just found; when it is None\n",
        "    # the previous starting point is kept.\n",
        "    if out[\"n_scale\"] is not None:\n",
        "        prev_n = out[\"n_scale\"]\n",
        "\n",
        "# One row per gamma, shown as a table and then plotted.\n",
        "df_l4B = pd.DataFrame(rows)\n",
        "display(df_l4B)\n",
        "\n",
        "pretty_lineplot(\n",
        "    df_l4B, \"gamma\", \"n_scale\",\n",
        "    title=fr\"Greedy Correlation Counterexample: $n_{{scale}}$ vs $\\gamma$ (d={d}, s={s0}, $\\Delta$={Delta})\",\n",
        "    xlabel=r\"$\\gamma$ (eigengap)\",\n",
        "    ylabel=r\"$n_{\\mathrm{scale}}$\",\n",
        "    xticks=gamma_list,\n",
        "    savebase=\"lemma4_expB_vs_gamma\",\n",
        "    fits=[(\"inv\",1), (\"inv\",2)]\n",
        ")"
      ],
      "metadata": {
        "id": "j1CrKGzKs0LL"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# =========================\n",
        "# Cell 7/7: LEMMA-4 - Experiment C (vary Delta)\n",
        "# sweep: d=2500, s=5, gamma=0.8, Delta in {0.1,0.15,0.2,0.4,0.5}, n in [50,1000] step 50\n",
        "# =========================\n",
        "\n",
        "# Fixed settings for this experiment; only Delta changes between runs.\n",
        "dataset = \"lemma4\"\n",
        "d = 2500\n",
        "s0 = 5\n",
        "gamma0 = 0.8\n",
        "Delta_list = [0.1, 0.15, 0.2, 0.4, 0.5]\n",
        "\n",
        "# n range and step forwarded to find_n_scale_linear_warm; the first run starts at n_max.\n",
        "n_min, n_max, n_step = 50, 1000, 50\n",
        "start_n = n_max\n",
        "\n",
        "# Remaining settings, forwarded unchanged to find_n_scale_linear_warm.\n",
        "m_restarts = 1000\n",
        "r_mult = 10.0\n",
        "T_min = 100\n",
        "block_size = 512\n",
        "check_every = 5\n",
        "margin_frac = 0.02\n",
        "verbose = True\n",
        "\n",
        "rows = []\n",
        "prev_n = start_n\n",
        "\n",
        "for D in Delta_list:\n",
        "    # One distinct integer seed per Delta. round() is applied before int() because\n",
        "    # 10_000 * D is a float product: if it lands just below an integer, a bare int()\n",
        "    # would truncate to the neighbouring seed. Seeds for the listed Deltas are unchanged.\n",
        "    seed_setting = 600_000 + int(round(10_000 * D))\n",
        "    out = find_n_scale_linear_warm(\n",
        "        dataset=dataset,\n",
        "        d=d, s=s0, gamma=gamma0, Delta=D,\n",
        "        m_restarts=m_restarts,\n",
        "        n_min=n_min, n_max=n_max, n_step=n_step,\n",
        "        start_n=prev_n,\n",
        "        seed_setting=seed_setting,\n",
        "        r_mult=r_mult, T_min=T_min,\n",
        "        block_size=block_size, check_every=check_every,\n",
        "        margin_frac=margin_frac, verbose=verbose,\n",
        "    )\n",
        "    rows.append({\"Delta\": D, \"n_scale\": out[\"n_scale\"], \"r\": out[\"r\"], \"T_max\": out[\"T_max\"]})\n",
        "    # Warm start: the next Delta begins at the n_scale just found; when it is None\n",
        "    # the previous starting point is kept.\n",
        "    if out[\"n_scale\"] is not None:\n",
        "        prev_n = out[\"n_scale\"]\n",
        "\n",
        "# One row per Delta, shown as a table and then plotted.\n",
        "df_l4C = pd.DataFrame(rows)\n",
        "display(df_l4C)\n",
        "\n",
        "pretty_lineplot(\n",
        "    df_l4C, \"Delta\", \"n_scale\",\n",
        "    title=fr\"Greedy Correlation Counterexample: $n_{{scale}}$ vs $\\Delta$ (d={d}, s={s0}, $\\gamma$={gamma0})\",\n",
        "    xlabel=r\"$\\Delta$ (target $\\sin^2$ error)\",\n",
        "    ylabel=r\"$n_{\\mathrm{scale}}$\",\n",
        "    xticks=Delta_list,\n",
        "    savebase=\"lemma4_expC_vs_Delta\",\n",
        "    fits=[(\"inv\",1), (\"inv\",2)]\n",
        ")"
      ],
      "metadata": {
        "id": "b7-6w5-is2QQ"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [],
      "metadata": {
        "id": "QTeLZi-W267U"
      },
      "execution_count": null,
      "outputs": []
    }
  ]
}