{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "82d73e35",
   "metadata": {},
   "source": [
    "# Figure 2 (Illustration Plot)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "da13551b",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "old_rc = plt.rcParams.copy()\n",
    "\n",
    "plt.rcParams.update({\n",
    "    \"font.size\": 16,        \n",
    "    \"axes.titlesize\": 20,   \n",
    "    \"axes.labelsize\": 18,   \n",
    "    \"legend.fontsize\": 13,  \n",
    "})\n",
    "\n",
    "\n",
    "T_min, T_max = 1.0, 100.0\n",
    "T1, T2 = 25.0, 70.0 \n",
    "\n",
    "C = 0.4\n",
    "kappa = np.exp(2*np.log(C)/T1)  # ensure continuity at T1\n",
    "A = C/np.sqrt(T2)               # ensure continuity at T2\n",
    "\n",
    "T = np.linspace(T_min, T_max, 800)\n",
    "decay = kappa**(0.5*T)            \n",
    "plateau = np.full_like(T, C)      \n",
    "rise = A*np.sqrt(T)              \n",
    "envelope = np.maximum.reduce([decay, plateau, rise])\n",
    "\n",
    "plt.figure(figsize=(9, 5.2), dpi=300)\n",
    "plt.plot(T, decay, linestyle=\"--\", linewidth=1.8, label=r\"$\\kappa^{T/2}$\", alpha=0.6)\n",
    "plt.plot(T, plateau, linestyle=\"--\", linewidth=1.8, label=r\"$\\frac{\\sqrt{pq}\\,(\\tau+\\log(pq))}{\\sqrt{n}}$\", alpha=0.6)\n",
    "plt.plot(T, rise, linestyle=\"--\", linewidth=1.8, label=r\"$\\frac{\\sqrt{p}(\\sqrt{q}+\\sqrt{\\tau})^3}{n\\sqrt{\\min\\{\\rho_1,\\rho_2\\}}}\\sqrt{T}$\", alpha=0.6)\n",
    "plt.plot(T, envelope, linewidth=2.5, color='red', zorder=3,\n",
    "         label=r\"$\\|\\beta^{(T)}-\\hat{\\beta}\\|$\")\n",
    "\n",
    "for t, name in [(T1, r\"$T_1$\"), (T2, r\"$T_2$\")]:\n",
    "    plt.axvline(t, linestyle=\":\", linewidth=2.0)\n",
    "    ymax = envelope.min()\n",
    "    plt.text(t, ymax*1.02, name, ha=\"center\", va=\"bottom\", fontsize=18)\n",
    "\n",
    "plt.text(T_min+3, C*1.5, \"contraction\", fontsize=16)\n",
    "plt.text((T1+T2)/2, C*1.5, \"plateau\", fontsize=16, ha=\"center\")\n",
    "plt.text(T2+5, C*1.5, r\"$\\sqrt{T}$-growth\", fontsize=16)\n",
    "\n",
    "plt.xticks([])\n",
    "plt.yticks([])\n",
    "plt.gca().tick_params(axis='both', which='both', length=0)\n",
    "\n",
    "plt.xlabel(r\"$T$\")\n",
    "plt.ylabel(r\"$\\|\\beta^{(T)}-\\hat{\\beta}\\|$\")\n",
    "plt.title(\"Error vs. Iterations\")\n",
    "handles, labels = plt.gca().get_legend_handles_labels()\n",
    "handles = [handles[-1]] + handles[:-1]\n",
    "labels  = [labels[-1]]  + labels[:-1]\n",
    "plt.legend(handles, labels, frameon=False, loc=\"lower right\")\n",
    "plt.tight_layout()\n",
    "plt.savefig(\"Figures/sketch.png\")\n",
    "plt.show()\n",
    "\n",
    "plt.rcParams.update(old_rc)  # restore everything"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "30bfb98a",
   "metadata": {},
   "source": [
    "# Algorithm 1 Performance versus n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f2447922",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "from tqdm import tqdm\n",
    "\n",
    "def simulate_alg1_vs_n(\n",
    "    n_list,\n",
    "    repeats=10,\n",
    "    Theta_true=None,\n",
    "    p=10, q=10,\n",
    "    sigma_z=1.0, sigma_1=1.0, sigma_2=1.0,\n",
    "    C=1.0, C0=1.0, c=1.0, c0=1.0, c1=1.0, c2=1.0,\n",
    "    tau=5.0,\n",
    "    rho_1=0.9, rho_2=0.1,\n",
    "    seed=0,\n",
    "    T_given=None\n",
    "):\n",
    "    \"\"\"\n",
    "    Runs one full sweep over n_list for Algorithm 1 and 2SLS\n",
    "\n",
    "    Returns:\n",
    "        dict with keys:\n",
    "            'n_list', 'errors_alg1', 'stds_alg1', 'T_all', 'errors_2sls', 'stds_2sls'\n",
    "    \"\"\"\n",
    "    np.random.seed(seed)\n",
    "\n",
    "    if Theta_true is None:\n",
    "        Theta_true = 5 * np.eye(q, p) + np.random.randn(q, p)\n",
    "    beta_true = np.random.randn(p)\n",
    "    sigma_min_Theta = np.linalg.svd(Theta_true, compute_uv=False).min()\n",
    "    sigma_max_Theta = np.linalg.norm(Theta_true, 2)\n",
    "    R = (np.sqrt(q) + np.sqrt(tau))**2 * (np.sqrt(p*q) + np.sqrt(p*(np.log(p)+tau)))\n",
    "    errors_alg1, stds_alg1, T_per_n = [], [], []\n",
    "    errors_2sls, stds_2sls = [], []\n",
    "\n",
    "    rho = rho_1 + rho_2\n",
    "    for n in tqdm(n_list, desc=f\"ρ1={rho_1:.2f}, ρ2={rho_2:.2f}\"):\n",
    "        err_list_alg1 = []\n",
    "        err_list_2sls = []\n",
    "\n",
    "        for _ in range(repeats):\n",
    "            Z = np.random.randn(n, q)\n",
    "            delta = (C0 * sigma_z**2 * (np.sqrt(q) + np.sqrt(tau))) / np.sqrt(n)\n",
    "            Psi_norm = c0 * (sigma_z * sigma_2 * np.sqrt(p * q * (tau + np.log(2*p*q)))) / (np.sqrt(n) * (1 - delta)**2)\n",
    "\n",
    "            # if delta >= 1:\n",
    "            #     print(f\"[warn] delta >= 1 for n={n:.0f} (delta={delta:.3f})\")\n",
    "            # if sigma_min_Theta < Psi_norm:\n",
    "            #     print(f\"[warn] sigma_min_Theta < Psi_norm for n={n:.0f} (min={sigma_min_Theta:.3f}, Psi={Psi_norm:.3f})\")\n",
    "\n",
    "            gamma_floor = (1 - delta)**2 * (sigma_min_Theta - Psi_norm)**2\n",
    "            gamma_ceil = (1 + delta)**2 * (sigma_max_Theta + Psi_norm)**2\n",
    "            \n",
    "            U = np.random.randn(n, q)\n",
    "            Phi = np.random.randn(q, p)\n",
    "            phi = np.random.randn(q)\n",
    "            X = Z @ Theta_true + U @ Phi + sigma_1 * np.random.randn(n, p)\n",
    "            y = X @ beta_true + U @ phi + sigma_2 * np.random.randn(n)\n",
    "\n",
    "            Theta_t = np.zeros((q, p))\n",
    "            beta_t = np.zeros(p)\n",
    "\n",
    "            # Stepsizes and T\n",
    "            alpha = 2 / (2*gamma_ceil + gamma_floor)\n",
    "            kappa_beta = max(\n",
    "                np.abs(1 - alpha * gamma_floor / (2 * n)),\n",
    "                np.abs(1 - alpha * (2*gamma_floor + gamma_ceil) / (2 * n))\n",
    "            )\n",
    "            # Ensure log(1/kappa_beta) positive; if not, fallback to small T\n",
    "            if T_given == None:\n",
    "                if kappa_beta >= 1:\n",
    "                    T = 5\n",
    "                else:\n",
    "                    T = int(np.ceil(c * (np.log(n * max(rho_1, 1e-12) / max(R, 1e-12)))))\n",
    "                    if T <= 0 or not np.isfinite(T):\n",
    "                        T = 5\n",
    "            else:\n",
    "                T = T_given\n",
    "\n",
    "            eta = 1 / (1+delta)**2\n",
    "            gamma_1 = c1 * (np.sqrt(q) + np.sqrt(tau + np.log(n*T))) ** 2\n",
    "            gamma_2 = c2 * (np.sqrt(q) + np.sqrt(tau + np.log(n*T))) ** 2\n",
    "\n",
    "            # Noise scales\n",
    "            lambda_1 = (2 * gamma_1 / n) * np.sqrt(T / rho_1)\n",
    "            lambda_2 = (2 * gamma_2 / n) * np.sqrt(T / rho_2)\n",
    "\n",
    "            for _t in range(T):\n",
    "                # Gradient for Theta\n",
    "                grad_Theta_t = Z.T @ (Z @ Theta_t - X)\n",
    "                Xi = np.random.randn(q, p) * lambda_1\n",
    "                Theta_t = Theta_t - eta / n * grad_Theta_t + eta * Xi\n",
    "\n",
    "                # Gradient for beta\n",
    "                grad_beta_t = (Z @ Theta_t).T @ (Z @ Theta_t @ beta_t - y)\n",
    "                nu = np.random.randn(p) * lambda_2\n",
    "                beta_t = beta_t - alpha / n * grad_beta_t + alpha * nu\n",
    "\n",
    "            # Errors\n",
    "            err_alg1 = np.linalg.norm(beta_t - beta_true)\n",
    "            err_list_alg1.append(err_alg1)\n",
    "\n",
    "            # 2SLS baseline\n",
    "            # (Pz = Z(Z^T Z)^{-1} Z^T)\n",
    "            Pz = Z @ np.linalg.inv(Z.T @ Z) @ Z.T\n",
    "            X_hat = Pz @ X\n",
    "            beta_2sls = np.linalg.inv(X_hat.T @ X_hat) @ (X_hat.T @ y)\n",
    "            err_list_2sls.append(np.linalg.norm(beta_2sls - beta_true))\n",
    "\n",
    "        errors_alg1.append(np.mean(err_list_alg1))\n",
    "        stds_alg1.append(np.std(err_list_alg1))\n",
    "        errors_2sls.append(np.mean(err_list_2sls))\n",
    "        stds_2sls.append(np.std(err_list_2sls))\n",
    "        T_per_n.append(T)\n",
    "    \n",
    "    return {\n",
    "        'n_list': list(n_list),\n",
    "        'errors_alg1': np.array(errors_alg1),\n",
    "        'stds_alg1': np.array(stds_alg1),\n",
    "        'T_all': np.array(T_per_n),\n",
    "        'errors_2sls': np.array(errors_2sls),\n",
    "        'stds_2sls': np.array(stds_2sls),\n",
    "    }\n",
    "\n",
    "\n",
    "def plot_alg1_vs_n(\n",
    "    n_list=None,\n",
    "    repeats=10,\n",
    "    regimes=None,\n",
    "    common_kwargs=None,\n",
    "    seed=0,\n",
    "    annotate_T=True,\n",
    "    plt_title=\"Algorithm 1 Error vs n\",\n",
    "    save_path=None\n",
    "):\n",
    "    \"\"\"\n",
    "    regimes: list of tuples [(rho1, rho2), ...]\n",
    "    common_kwargs: dict of parameters to pass through to simulate function (p, q, etc.)\n",
    "    \"\"\"\n",
    "    if n_list is None:\n",
    "        n_list = [1000,1500,2000,2500,3000,3500,4000,4500]\n",
    "    if regimes is None:\n",
    "        regimes = [\n",
    "            (0.2, 0.8),\n",
    "            (0.5, 0.5),\n",
    "            (0.8, 0.2),\n",
    "        ]\n",
    "    if common_kwargs is None:\n",
    "        common_kwargs = {}\n",
    "\n",
    "    results = []\n",
    "\n",
    "    # Run all regimes\n",
    "    for rho1, rho2 in regimes:\n",
    "        res = simulate_alg1_vs_n(\n",
    "            n_list=n_list,\n",
    "            repeats=repeats,\n",
    "            rho_1=rho1,\n",
    "            rho_2=rho2,\n",
    "            seed=seed,\n",
    "            **common_kwargs\n",
    "        )\n",
    "        res['rho1'] = rho1\n",
    "        res['rho2'] = rho2\n",
    "        results.append(res)\n",
    "\n",
    "    colors = plt.rcParams['axes.prop_cycle'].by_key()['color']\n",
    "\n",
    "    for i, res in enumerate(results):\n",
    "        col = colors[i % len(colors)]\n",
    "        ebc = plt.errorbar(\n",
    "            res['n_list'],\n",
    "            res['errors_alg1'],\n",
    "            yerr=res['stds_alg1'],\n",
    "            fmt='-^',\n",
    "            color=col,        \n",
    "            ecolor=col,        \n",
    "            capsize=3,\n",
    "            label=f\"ρ₁={res['rho1']:.2f}, ρ₂={res['rho2']:.2f}\"\n",
    "        )\n",
    "        if annotate_T:\n",
    "            for x, y, Tm in zip(res['n_list'], res['errors_alg1'], res['T_all']):\n",
    "                plt.annotate(\n",
    "                    f\"T={Tm:.0f}\", (x, y),\n",
    "                    textcoords=\"offset points\", xytext=(0, 8), ha='center',\n",
    "                    color=col, fontsize=8\n",
    "            )\n",
    "    base = results[0]\n",
    "    plt.errorbar(\n",
    "        base['n_list'],\n",
    "        base['errors_2sls'],\n",
    "        yerr=base['stds_2sls'],\n",
    "        fmt='--s',\n",
    "        label='2SLS (baseline)',\n",
    "        color=colors[(i+1) % len(colors)]\n",
    "    )\n",
    "    plt.xlabel('Sample size n')\n",
    "    plt.ylabel(r'$\\|\\beta^{(T)}-\\beta_{\\mathrm{true}}\\|$')\n",
    "    plt.title(plt_title)\n",
    "    plt.grid(True)\n",
    "    plt.legend()\n",
    "    plt.tight_layout()\n",
    "    if save_path is not None:\n",
    "        plt.savefig(save_path, dpi=300)\n",
    "    plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2b012b1c",
   "metadata": {},
   "source": [
    "## Figure 3(a)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "39270ab4",
   "metadata": {},
   "outputs": [],
   "source": [
    "np.random.seed(0)\n",
    "n_list = list(range(100, 2101, 200))\n",
    "p, q = 5, 5\n",
    "common_kwargs = dict(\n",
    "    p=p, q=q,\n",
    "    Theta_true=5 * np.eye(q, p) + np.random.randn(q, p),\n",
    "    sigma_z=1.0, sigma_1 = 1.0, sigma_2=1.0,\n",
    "    T_given=20\n",
    ")\n",
    "plt_title = f\"Algorithm 1 Error vs n (p={p}, q={q})\"\n",
    "save_path = f\"Figures/Alg1_vs_n_p={p}_q={q}.png\"\n",
    "plot_alg1_vs_n(\n",
    "    n_list=n_list,\n",
    "    repeats=100,\n",
    "    regimes=[(1,9), (5,5), (9,1)],\n",
    "    common_kwargs=common_kwargs,\n",
    "    seed=0,\n",
    "    annotate_T=False,\n",
    "    plt_title=plt_title,\n",
    "    save_path=save_path,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "69276dec",
   "metadata": {},
   "source": [
    "## Figure 3(b)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "806c288a",
   "metadata": {},
   "outputs": [],
   "source": [
    "n_list = list(range(1001, 5001, 500))\n",
    "p, q = 50, 50\n",
    "common_kwargs = dict(\n",
    "    p=p, q=q,\n",
    "    Theta_true=10 * np.eye(q, p),\n",
    "    sigma_z=1.0, sigma_1=1.0, sigma_2=1.0,\n",
    "    T_given=20\n",
    ")\n",
    "plt_title = f\"Algorithm 1 Error vs n (p={p}, q={q})\"\n",
    "save_path = f\"Figures/Alg1_vs_n_p={p}_q={q}.png\"\n",
    "plot_alg1_vs_n(\n",
    "    n_list=n_list,\n",
    "    repeats=10,\n",
    "    regimes=[(1,9), (5,5), (9,1)],\n",
    "    common_kwargs=common_kwargs,\n",
    "    seed=0,\n",
    "    annotate_T=False,\n",
    "    plt_title=plt_title,\n",
    "    save_path=save_path,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8d0cc5b8",
   "metadata": {},
   "source": [
    "# Algorithm 1 Performance versus T"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e2f3fbf2",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "from tqdm import tqdm\n",
    "\n",
    "def simulate_alg1_vs_T(\n",
    "    n=5000,\n",
    "    T_list=None,\n",
    "    repeats=10,\n",
    "    Theta_true=None,\n",
    "    p=10, q=10,\n",
    "    sigma_z=1.0, sigma_1=1.0, sigma_2=1.0,\n",
    "    C=1.0, C0=1.0, c=1.0, c0=1.0, c1=1.0, c2=1.0,\n",
    "    tau=5.0,\n",
    "    rho_1=0.9, rho_2=0.1,\n",
    "    seed=0,\n",
    "):\n",
    "    \"\"\"\n",
    "    Fix n and sweep T\n",
    "    Return: mean/std error over repeats for Algorithm 1 vs T, 2SLS baseline (independent of T).\n",
    "    Dataset is re-sampled per repeat, and reused across all T within that repeat.\n",
    "    \"\"\"\n",
    "    rng = np.random.default_rng(seed)\n",
    "\n",
    "    if T_list is None:\n",
    "        T_list = list(range(10, 101, 10))\n",
    "\n",
    "    if Theta_true is None:\n",
    "        Theta_true = 5 * np.eye(q, p) + rng.standard_normal((q, p))\n",
    "    beta_true = rng.standard_normal(p)\n",
    "    sigma_min_Theta = np.linalg.svd(Theta_true, compute_uv=False).min()\n",
    "    sigma_max_Theta = np.linalg.norm(Theta_true, 2)\n",
    "    R = (np.sqrt(q) + np.sqrt(tau))**2 * (np.sqrt(p*q) + np.sqrt(p*(np.log(p)+tau)))\n",
    "\n",
    "    delta = (C0 * sigma_z**2 * (np.sqrt(q) + np.sqrt(tau))) / np.sqrt(n)\n",
    "    Psi_norm = (sigma_z * sigma_2 * np.sqrt(c0 * p * q * (tau + np.log(2*p*q)))) / (np.sqrt(n) * (1 - delta)**2)\n",
    "    if delta >= 1:\n",
    "        print(f\"[warn] delta >= 1 for n={n:.0f} (delta={delta:.3f})\")\n",
    "    if sigma_min_Theta < Psi_norm:\n",
    "        print(f\"[warn] sigma_min_Theta < Psi_norm for n={n:.0f} (min={sigma_min_Theta:.3f}, Psi={Psi_norm:.3f})\")\n",
    "\n",
    "    gamma_floor = (1 - delta)**2 * (sigma_min_Theta - Psi_norm)**2\n",
    "    gamma_ceil = (1 + delta)**2 * (sigma_max_Theta + Psi_norm)**2\n",
    "\n",
    "    # Stepsize for beta \n",
    "    alpha = 2 / (2*gamma_ceil + gamma_floor)\n",
    "\n",
    "    errors_alg1_allT = {T: [] for T in T_list}\n",
    "    errors_2sls_allT = {T: [] for T in T_list}\n",
    "\n",
    "    for rep in tqdm(range(repeats), desc=f\"n={n}, sweep T\"):\n",
    "        # --- Sample one dataset for this repeat and reuse it across T ---\n",
    "        Z = rng.standard_normal((n, q))\n",
    "        U = rng.standard_normal((n, q))\n",
    "        Phi = rng.standard_normal((q, p))\n",
    "        phi = rng.standard_normal(q)\n",
    "        X = Z @ Theta_true + U @ Phi + sigma_1 * rng.standard_normal((n, p))\n",
    "        y = X @ beta_true + U @ phi + sigma_2 * rng.standard_normal(n)\n",
    "\n",
    "        # 2SLS baseline for this dataset (same across all T in this repeat)\n",
    "        Pz = Z @ np.linalg.inv(Z.T @ Z) @ Z.T\n",
    "        X_hat = Pz @ X\n",
    "        beta_2sls = np.linalg.inv(X_hat.T @ X_hat) @ (X_hat.T @ y)\n",
    "        err_2sls = np.linalg.norm(beta_2sls - beta_true)\n",
    "\n",
    "        for T in T_list:\n",
    "            delta_prime = (C0 * sigma_z**2 * (np.sqrt(q) + np.sqrt(tau + np.log(max(T,1))))) / np.sqrt(n)\n",
    "            eta = 1 / (1+delta_prime)**2\n",
    "\n",
    "            gamma_1 = c1 * (np.sqrt(q) + np.sqrt(tau + np.log(n*T))) ** 2\n",
    "            gamma_2 = c2 * (np.sqrt(q) + np.sqrt(tau + np.log(n*T))) ** 2\n",
    "\n",
    "            rho = rho_1 + rho_2\n",
    "            lambda_1 = (2 * gamma_1 / n) * np.sqrt(T / rho_1)\n",
    "            lambda_2 = (2 * gamma_2 / n) * np.sqrt(T / rho_2)\n",
    "\n",
    "            Theta_t = np.zeros((q, p))\n",
    "            beta_t  = np.zeros(p)\n",
    "\n",
    "            # Run T iterations\n",
    "            for _t in range(T):\n",
    "                # Theta step\n",
    "                grad_Theta_t = Z.T @ (Z @ Theta_t - X)\n",
    "                Xi = rng.standard_normal((q, p)) * lambda_1\n",
    "                Theta_t = Theta_t - (eta / n) * grad_Theta_t + eta * Xi\n",
    "\n",
    "                # beta step\n",
    "                grad_beta_t = (Z @ Theta_t).T @ (Z @ Theta_t @ beta_t - y)\n",
    "                nu = rng.standard_normal(p) * lambda_2\n",
    "                beta_t = beta_t - (alpha / n) * grad_beta_t + alpha * nu\n",
    "\n",
    "            # Record errors\n",
    "            errors_alg1_allT[T].append(np.linalg.norm(beta_t - beta_true))\n",
    "            errors_2sls_allT[T].append(err_2sls)\n",
    "\n",
    "    # Aggregate\n",
    "    T_arr = np.array(T_list)\n",
    "    err_alg1_mean = np.array([np.mean(errors_alg1_allT[T]) for T in T_list])\n",
    "    err_alg1_std  = np.array([np.std(errors_alg1_allT[T])  for T in T_list])\n",
    "    err_2sls_mean = np.array([np.mean(errors_2sls_allT[T]) for T in T_list])  \n",
    "    err_2sls_std  = np.array([np.std(errors_2sls_allT[T])  for T in T_list])\n",
    "    return {\n",
    "        'n': n,\n",
    "        'T_list': T_arr,\n",
    "        'errors_alg1': err_alg1_mean,\n",
    "        'stds_alg1': err_alg1_std,\n",
    "        'errors_2sls': err_2sls_mean,\n",
    "        'stds_2sls': err_2sls_std,\n",
    "        'rho1': rho_1,\n",
    "        'rho2': rho_2,\n",
    "    }\n",
    "\n",
    "\n",
    "def plot_alg1_vs_T(\n",
    "    n=5000,\n",
    "    T_list=None,\n",
    "    repeats=10,\n",
    "    regimes=None,           # [(rho1, rho2), ...]\n",
    "    common_kwargs=None,     # shared kwargs passed to simulate_alg1_vs_T\n",
    "    seed=0,\n",
    "    plt_title=\"Algorithm 2 Error vs T (n fixed)\",\n",
    "    save_path=None\n",
    "):\n",
    "    if T_list is None:\n",
    "        T_list = list(range(10, 101, 10))\n",
    "    if regimes is None:\n",
    "        regimes = [\n",
    "            (0.2, 0.8),\n",
    "            (0.5, 0.5),\n",
    "            (0.8, 0.2),\n",
    "        ]\n",
    "    if common_kwargs is None:\n",
    "        common_kwargs = {}\n",
    "\n",
    "    results = []\n",
    "    colors = plt.rcParams['axes.prop_cycle'].by_key()['color']\n",
    "\n",
    "    for i, (rho1, rho2) in enumerate(regimes):\n",
    "        res = simulate_alg1_vs_T(\n",
    "            n=n,\n",
    "            T_list=T_list,\n",
    "            repeats=repeats,\n",
    "            rho_1=rho1,\n",
    "            rho_2=rho2,\n",
    "            seed=seed,\n",
    "            **common_kwargs\n",
    "        )\n",
    "        results.append(res)\n",
    "\n",
    "        col = colors[i % len(colors)]\n",
    "\n",
    "        label=f\"ρ₁={rho1}, ρ₂={rho2}\"\n",
    "\n",
    "        plt.errorbar(\n",
    "            res['T_list'],\n",
    "            res['errors_alg1'],\n",
    "            yerr=res['stds_alg1'],\n",
    "            fmt='-^',\n",
    "            color=col,\n",
    "            ecolor=col,\n",
    "            capsize=3,\n",
    "            label=label\n",
    "        )\n",
    "\n",
    "    base = results[0]\n",
    "    plt.errorbar(\n",
    "        base['T_list'],\n",
    "        base['errors_2sls'],\n",
    "        yerr=base['stds_2sls'],\n",
    "        fmt='--s',\n",
    "        label='2SLS (baseline)',\n",
    "        color=colors[(len(regimes)) % len(colors)]\n",
    "    )\n",
    "\n",
    "    plt.xlabel('Iterations T')\n",
    "    plt.ylabel(r'$\\|\\beta^{(T)}-\\beta_{\\mathrm{true}}\\|$')\n",
    "    plt.title(f\"{plt_title}\")\n",
    "    plt.grid(True)\n",
    "    plt.legend()\n",
    "    plt.tight_layout()\n",
    "    if save_path is not None:\n",
    "        plt.savefig(save_path, dpi=300)\n",
    "    plt.show()\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "910d0b2b",
   "metadata": {},
   "source": [
    "## Figure 4(a)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "37477f52",
   "metadata": {},
   "outputs": [],
   "source": [
    "p, q=5, 5\n",
    "n=1000\n",
    "save_path = f\"Figures/Alg1_vs_T_n={n}_p={p}_q={q}_rho1_small.png\"\n",
    "plot_alg1_vs_T(\n",
    "    n=n,\n",
    "    T_list=list(range(10, 101, 10)),\n",
    "    repeats=100,\n",
    "    regimes=[(0.1, 100), (1, 100), (10, 100)],\n",
    "    common_kwargs=dict(p=p, q=q, sigma_z=1.0, sigma_1=1.0, sigma_2=1.0),\n",
    "    seed=0,\n",
    "    plt_title=f\"Algorithm 1 Error vs T (n={n}, p={p}, q={q})\",\n",
    "    save_path=save_path\n",
    ")\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "31136ac4",
   "metadata": {},
   "source": [
    "## Figure 4(b)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "39656a04",
   "metadata": {},
   "outputs": [],
   "source": [
    "p, q=5, 5\n",
    "n=1000\n",
    "save_path = f\"Figures/Alg1_vs_T_n={n}_p={p}_q={q}_rho2_small.png\"\n",
    "plot_alg1_vs_T(\n",
    "    n=n,\n",
    "    T_list=list(range(10, 101, 10)),\n",
    "    repeats=100,\n",
    "    regimes=[(100, 0.1), (100, 1), (100, 10)],\n",
    "    common_kwargs=dict(p=p, q=q, sigma_z=1.0, sigma_1=1.0, sigma_2=1.0),\n",
    "    seed=0,\n",
    "    plt_title=f\"Algorithm 1 Error vs T (n={n}, p={p}, q={q})\",\n",
    "    save_path=save_path\n",
    ")\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8e0b7e90",
   "metadata": {},
   "source": [
    "## Figure 6(a)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a04fd123",
   "metadata": {},
   "outputs": [],
   "source": [
    "p, q=5, 5\n",
    "n=1000\n",
    "save_path = f\"Figures/Alg1_vs_T_n={n}_p={p}_q={q}_rho1_small_largeT.png\"\n",
    "plot_alg1_vs_T(\n",
    "    n=n,\n",
    "    T_list=list(range(10, 2601, 100)),\n",
    "    repeats=10,\n",
    "    regimes=[(1, 0.1), (1, 1), (1, 10)],\n",
    "    common_kwargs=dict(p=p, q=q, sigma_z=1.0, sigma_1=1.0, sigma_2=1.0),\n",
    "    seed=0,\n",
    "    plt_title=f\"Algorithm 1 Error vs T (n={n}, p={p}, q={q})\",\n",
    "    save_path=save_path\n",
    ")\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "720ef623",
   "metadata": {},
   "source": [
    "# Algorithm 3 (DP-2S-GD-beta)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "876c96fb",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "from tqdm import tqdm\n",
    "\n",
    "def simulate_alg3_vs_T(\n",
    "    n=5000,\n",
    "    T_list=None,\n",
    "    repeats=10,\n",
    "    Theta_true=None,\n",
    "    p=10, q=10,\n",
    "    sigma_z=1.0, sigma_1=1.0, sigma_2=1.0,\n",
    "    C=1.0, C0=1.0, c=1.0, c0=1.0, c1=1.0, c2=1.0,\n",
    "    tau=5.0,\n",
    "    rho_2=0.1,\n",
    "    rng_seed=0,\n",
    "):\n",
    "    \"\"\"\n",
    "    Fix n and sweep T. Returns mean/std error over repeats for Algorithm 2 vs T,\n",
    "    plus 2SLS baseline (independent of T, but averaged the same way).\n",
    "    Dataset is re-sampled per repeat, and reused across all T within that repeat.\n",
    "    \"\"\"\n",
    "    if T_list is None:\n",
    "        T_list = list(range(10, 101, 10))\n",
    "\n",
    "    rng = np.random.default_rng(rng_seed)\n",
    "\n",
    "    if Theta_true is None:\n",
    "        Theta_true = 5 * np.eye(q, p) + rng.standard_normal((q, p))\n",
    "    beta_true = rng.standard_normal(p)\n",
    "    sigma_min_Theta = np.linalg.svd(Theta_true, compute_uv=False).min()\n",
    "    sigma_max_Theta = np.linalg.norm(Theta_true, 2)\n",
    "    R = (np.sqrt(q) + np.sqrt(tau))**2 * (np.sqrt(p*q) + np.sqrt(p*(np.log(p)+tau)))\n",
    "\n",
    "    # Precompute constants that do not depend on T\n",
    "    delta = (C0 * sigma_z**2 * (np.sqrt(q) + np.sqrt(tau))) / np.sqrt(n)\n",
    "    Psi_norm = (sigma_z * sigma_2 * np.sqrt(c0 * p * q * (tau + np.log(2*p*q)))) / (np.sqrt(n) * (1 - delta)**2)\n",
    "    if delta >= 1:\n",
    "        print(f\"[warn] delta >= 1 for n={n:.0f} (delta={delta:.3f})\")\n",
    "    if sigma_min_Theta < Psi_norm:\n",
    "        print(f\"[warn] sigma_min_Theta < Psi_norm for n={n:.0f} (min={sigma_min_Theta:.3f}, Psi={Psi_norm:.3f})\")\n",
    "\n",
    "    gamma_floor = (1 - delta)**2 * (sigma_min_Theta - Psi_norm)**2\n",
    "    gamma_ceil = (1 + delta)**2 * (sigma_max_Theta + Psi_norm)**2\n",
    "\n",
    "    alpha = 2 / (2*gamma_ceil + gamma_floor)\n",
    "\n",
    "    errors_alg3_allT = {T: [] for T in T_list}\n",
    "    errors_2sls_allT = {T: [] for T in T_list}\n",
    "\n",
    "    for rep in tqdm(range(repeats), desc=f\"n={n}, sweep T\"):\n",
    "        # --- Sample one dataset for this repeat and reuse it across T ---\n",
    "        Z = np.random.randn(n, q)\n",
    "        U = np.random.randn(n, q)\n",
    "        Phi = np.random.randn(q, p)\n",
    "        phi = np.random.randn(q)\n",
    "        X = Z @ Theta_true + U @ Phi + sigma_1 * np.random.randn(n, p)\n",
    "        y = X @ beta_true + U @ phi + sigma_2 * np.random.randn(n)\n",
    "\n",
    "        # 2SLS baseline for this dataset (same across all T in this repeat)\n",
    "        Pz = Z @ np.linalg.inv(Z.T @ Z) @ Z.T\n",
    "        X_hat = Pz @ X\n",
    "        beta_2sls = np.linalg.inv(X_hat.T @ X_hat) @ (X_hat.T @ y)\n",
    "        err_2sls = np.linalg.norm(beta_2sls - beta_true)\n",
    "\n",
    "        for T in T_list:\n",
    "            delta_prime = (C0 * sigma_z**2 * (np.sqrt(q) + np.sqrt(tau + np.log(T)))) / np.sqrt(n)\n",
    "            eta = 1 / (1+delta_prime)**2\n",
    "\n",
    "            gamma_1 = c1 * (np.sqrt(q) + np.sqrt(tau + np.log(max(n*T, 1)))) ** 2\n",
    "            gamma_2 = c2 * (np.sqrt(q) + np.sqrt(tau + np.log(max(n*T, 1)))) ** 2\n",
    "\n",
    "            lambda_1 = 0\n",
    "            lambda_2 = (2 * gamma_2 / n) * np.sqrt(T / rho_2)\n",
    "\n",
    "            # Initialize Algorithm 2\n",
    "            Theta_t = np.zeros((q, p))\n",
    "            beta_t  = np.zeros(p)\n",
    "\n",
    "            # Run T iterations\n",
    "            for _t in range(T):\n",
    "                # Theta step\n",
    "                grad_Theta = Z.T @ (Z @ Theta_t - X)\n",
    "                Xi = rng.standard_normal((q, p)) * lambda_1\n",
    "                Theta_t = Theta_t - (eta / n) * grad_Theta + eta * Xi\n",
    "\n",
    "                # beta step\n",
    "                grad_beta = (Z @ Theta_t).T @ (Z @ Theta_t @ beta_t - y)\n",
    "                nu = rng.standard_normal(p) * lambda_2\n",
    "                beta_t = beta_t - (alpha / n) * grad_beta + alpha * nu\n",
    "\n",
    "            # Record errors\n",
    "            errors_alg3_allT[T].append(np.linalg.norm(beta_t - beta_true))\n",
    "            errors_2sls_allT[T].append(err_2sls)\n",
    "\n",
    "    T_arr = np.array(T_list)\n",
    "    err_alg3_mean = np.array([np.mean(errors_alg3_allT[T]) for T in T_list])\n",
    "    err_alg3_std  = np.array([np.std(errors_alg3_allT[T])  for T in T_list])\n",
    "    err_2sls_mean = np.array([np.mean(errors_2sls_allT[T]) for T in T_list])  \n",
    "    err_2sls_std  = np.array([np.std(errors_2sls_allT[T])  for T in T_list])\n",
    "    return {\n",
    "        'n': n,\n",
    "        'T_list': T_arr,\n",
    "        'errors_alg3': err_alg3_mean,\n",
    "        'stds_alg3': err_alg3_std,\n",
    "        'errors_2sls': err_2sls_mean,\n",
    "        'stds_2sls': err_2sls_std,\n",
    "        'rho2': rho_2,\n",
    "    }\n",
    "\n",
    "\n",
    "def plot_alg3_error_vs_T(\n",
    "    n=5000,\n",
    "    T_list=None,\n",
    "    repeats=10,\n",
    "    regimes=None,           # [(rho1, rho2, label), ...]\n",
    "    common_kwargs=None,     # shared kwargs passed to simulate_alg1_vs_T\n",
    "    rng_seed=0,\n",
    "    plt_title=\"Algorithm 3 Error vs T (n fixed)\",\n",
    "    save_path=None\n",
    "):\n",
    "    if T_list is None:\n",
    "        T_list = list(range(10, 101, 10))\n",
    "    if regimes is None:\n",
    "        regimes = [\n",
    "            (0.8),\n",
    "            (0.5),\n",
    "            (0.2),\n",
    "        ]\n",
    "    if common_kwargs is None:\n",
    "        common_kwargs = {}\n",
    "\n",
    "    results = []\n",
    "    colors = plt.rcParams['axes.prop_cycle'].by_key()['color']\n",
    "\n",
    "    for i, rho2 in enumerate(regimes):\n",
    "        res = simulate_alg3_vs_T(\n",
    "            n=n,\n",
    "            T_list=T_list,\n",
    "            repeats=repeats,\n",
    "            rho_2=rho2,\n",
    "            rng_seed=rng_seed,\n",
    "            **common_kwargs\n",
    "        )\n",
    "        results.append(res)\n",
    "\n",
    "        col = colors[i % len(colors)]\n",
    "\n",
    "        label=f\"ρ₂={rho2}\"\n",
    "\n",
    "        plt.errorbar(\n",
    "            res['T_list'],\n",
    "            res['errors_alg3'],\n",
    "            yerr=res['stds_alg3'],\n",
    "            fmt='-^',\n",
    "            color=col,\n",
    "            ecolor=col,\n",
    "            capsize=3,\n",
    "            label=label\n",
    "        )\n",
    "\n",
    "    # 2SLS baseline from the first regime (all regimes have same T-list)\n",
    "    base = results[0]\n",
    "    plt.errorbar(\n",
    "        base['T_list'],\n",
    "        base['errors_2sls'],\n",
    "        yerr=base['stds_2sls'],\n",
    "        fmt='--s',\n",
    "        label='2SLS (baseline)',\n",
    "        color=colors[(len(regimes)) % len(colors)]\n",
    "    )\n",
    "\n",
    "    plt.xlabel('Iterations T')\n",
    "    plt.ylabel(r'$\\|\\beta^{(T)}-\\beta_{\\mathrm{true}}\\|$')\n",
    "    plt.title(f\"{plt_title}\")\n",
    "    plt.grid(True)\n",
    "    plt.legend()\n",
    "    plt.tight_layout()\n",
    "    if save_path is not None:\n",
    "        plt.savefig(save_path, dpi=300)\n",
    "    plt.show()\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1d93e2b8",
   "metadata": {},
   "source": [
    "## Figure 6(b)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "20c88165",
   "metadata": {},
   "outputs": [],
   "source": [
    "p, q=5, 5\n",
    "n=1000\n",
    "save_path = f\"Figures/Alg3_vs_T_n={n}_p={p}_q={q}_rho1_large.png\"\n",
    "plot_alg3_error_vs_T(\n",
    "    n=n,\n",
    "    T_list=list(range(10, 3001, 100)),\n",
    "    repeats=100,\n",
    "    regimes=[0.1, 1, 10],\n",
    "    common_kwargs=dict(p=p, q=q, sigma_z=1.0, sigma_1=1.0, sigma_2=1.0),\n",
    "    rng_seed=0,\n",
    "    plt_title=f\"Algorithm 3 Error vs T (n={n}, p={p}, q={q})\",\n",
    "    save_path=save_path\n",
    ")\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4776cf77",
   "metadata": {},
   "source": [
    "# Figure 7"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "da336294",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "from tqdm import tqdm\n",
    "\n",
    "np.random.seed(0)\n",
    "\n",
    "# Parameters\n",
    "p, q = 20, 20  \n",
    "sigma_z = 1\n",
    "sigma_1 = 1.0\n",
    "sigma_2 = 1.0\n",
    "C, C0 = 1.0, 1.0\n",
    "c, c0, c1 = 1, 1, 1\n",
    "tau = 5.0\n",
    "rho = 0.1\n",
    "\n",
    "\n",
    "Theta_true = 5 * np.eye(q, p) + np.random.randn(q, p)\n",
    "beta_true = np.random.randn(p)\n",
    "sigma_min_Theta = np.linalg.svd(Theta_true, compute_uv=False).min()\n",
    "sigma_max_Theta = np.linalg.norm(Theta_true, 2)\n",
    "\n",
    "R = (np.sqrt(q) + np.sqrt(tau))**2 * (np.sqrt(p*q) + np.sqrt(p*(np.log(p)+tau)))\n",
    "\n",
    "repeats = 100\n",
    "n_list = np.arange(500, 5001, 500)\n",
    "\n",
    "errors_agl2_vs_betahat, stds_agl2_vs_betahat = [], []\n",
    "errors_betahat_vs_true, stds_betahat_vs_true = [], []\n",
    "\n",
    "for n in tqdm(n_list):\n",
    "    err_list_agl2_vs_betahat = []\n",
    "    err_list_betahat_vs_true = []\n",
    "\n",
    "    for _ in range(repeats):\n",
    "        Z = np.random.randn(n, q)\n",
    "        delta = (C0 * sigma_z**2 * (np.sqrt(q) + np.sqrt(tau))) / np.sqrt(n)\n",
    "\n",
    "        Psi_norm = (c0 * sigma_z * sigma_2 * np.sqrt(p * q * (tau + np.log(2*p*q)))) / (np.sqrt(n) * (1 - delta)**2)\n",
    "        gamma_floor = n * (1 - delta)**2 * (sigma_min_Theta - Psi_norm)**2\n",
    "        gamma_ceil = n * (1 + delta)**2 * (sigma_max_Theta + Psi_norm)**2\n",
    "\n",
    "        U = np.random.randn(n, q)\n",
    "        Phi = np.random.randn(q, p)\n",
    "        phi = np.random.randn(q)\n",
    "        X = Z @ Theta_true + U @ Phi + sigma_1 * np.random.randn(n, p)\n",
    "        y = X @ beta_true + U @ phi + sigma_2 * np.random.randn(n)\n",
    "\n",
    "        Theta_0 = np.zeros((q, p))\n",
    "        beta_0 = np.zeros(p)\n",
    "\n",
    "        alpha = 2 * n / (2*gamma_ceil + gamma_floor)\n",
    "        kappa_beta = max(\n",
    "            abs(1 - alpha * gamma_floor / (2 * n)),\n",
    "            abs(1 - alpha * (2*gamma_ceil + gamma_floor) / (2 * n))\n",
    "        )\n",
    "        T = 100\n",
    "        delta_prime = (C0 * sigma_z**2 * (np.sqrt(q) + np.sqrt(tau + np.log(T)))) / np.sqrt(n)\n",
    "        eta = 1 / (1 + delta_prime)**2\n",
    "\n",
    "        for _ in range(T):\n",
    "            grad_Theta_0 = Z.T @ (Z @ Theta_0 - X)\n",
    "            Theta_0 = Theta_0 - eta / n * grad_Theta_0\n",
    "\n",
    "            ZT0 = Z @ Theta_0\n",
    "            grad_beta_0 = ZT0.T @ (ZT0 @ beta_0 - y)\n",
    "            beta_0 = beta_0 - alpha / n * grad_beta_0\n",
    "\n",
    "        Pz = Z @ np.linalg.pinv(Z.T @ Z) @ Z.T\n",
    "        X_hat = Pz @ X\n",
    "        beta_hat = np.linalg.pinv(X_hat.T @ X_hat) @ (X_hat.T @ y)\n",
    "\n",
    "        err_agl2_vs_betahat = np.linalg.norm(beta_0 - beta_true)\n",
    "        err_betahat_vs_true = np.linalg.norm(beta_hat - beta_true)\n",
    "\n",
    "        err_list_agl2_vs_betahat.append(err_agl2_vs_betahat)\n",
    "        err_list_betahat_vs_true.append(err_betahat_vs_true)\n",
    "\n",
    "    errors_agl2_vs_betahat.append(np.mean(err_list_agl2_vs_betahat))\n",
    "    stds_agl2_vs_betahat.append(np.std(err_list_agl2_vs_betahat))\n",
    "    errors_betahat_vs_true.append(np.mean(err_list_betahat_vs_true))\n",
    "    stds_betahat_vs_true.append(np.std(err_list_betahat_vs_true))\n",
    "\n",
    "plt.figure()\n",
    "plt.errorbar(n_list, errors_agl2_vs_betahat, yerr=stds_agl2_vs_betahat,\n",
    "             fmt='-o', label=r'$\\|\\beta^{(T)} - \\beta\\|$')\n",
    "plt.errorbar(n_list, errors_betahat_vs_true, yerr=stds_betahat_vs_true,\n",
    "             fmt='--s', label=r'$\\|\\hat{\\beta} - \\beta\\|$')\n",
    "plt.xlabel('Sample size n')\n",
    "plt.ylabel('Error norm')\n",
    "plt.title(r'2S-GD vs 2SLS (p=20, q=20)')\n",
    "plt.grid(True)\n",
    "plt.legend()\n",
    "plt.tight_layout()\n",
    "plt.savefig(\"Figures/alg2_vs_2sls.png\", dpi=300)\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d958938c",
   "metadata": {},
   "source": [
    "# Angrist Dataset Experiments"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9938ba9a",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import pandas as pd\n",
    "import os\n",
    "from typing import Callable, Dict, Any, Optional, Tuple\n",
    "import matplotlib.pyplot as plt\n",
    "from sas7bdat import SAS7BDAT\n",
    "\n",
    "def run_angrist_dp_experiment(\n",
    "    file_path: str = \"AngEv98/m_d_806.sas7bdat\",\n",
    "    *,\n",
    "    n_rows: int = 20000,                  # how many rows to read from SAS\n",
    "    filter_state: Optional[str] = None,   # e.g., '01' if you want STATE=='01'\n",
    "    center: bool = True,                  # center z, x, y\n",
    "    R: int = 1000,                        # number of repeated runs\n",
    "    T: int = 20,                          # iterations passed to dp_gd_2sls\n",
    "    rho_1: float = 1_000_000.0,\n",
    "    rho_2: float = 1.0,\n",
    "    tau: float = 0.0,\n",
    "    c0: float = 1.0,\n",
    "    C0: float = 1.0,\n",
    "    clip_grads: bool = False,\n",
    "    outdir: str = \"Figures\",\n",
    "    show_plots: bool = False,             # whether to plt.show()\n",
    "    base_seed: int = 0                    # seed offset for reproducibility across runs\n",
    ") -> Dict[str, Any]:\n",
    "    \"\"\"\n",
    "    Run DP-GD-2SLS on the Angrist dataset and save two figures:\n",
    "      (a) Boxplot of final DP beta across R runs with OLS/2SLS reference lines.\n",
    "      (b) Mean ± std learning paths for beta^(t) and Theta^(t).\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    file_path : str\n",
    "        Path to the SAS7BDAT file.\n",
    "    n_rows : int\n",
    "        Number of rows to read (including header row).\n",
    "    filter_state : Optional[str]\n",
    "        If provided, keep rows with STATE == filter_state (e.g., '01').\n",
    "    center : bool\n",
    "        If True, center z, x, y.\n",
    "    R : int\n",
    "        Number of independent runs.\n",
    "    T : int\n",
    "        Iterations to pass into dp_gd_2sls.\n",
    "    outdir : str\n",
    "        Directory to save figures.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    dict with keys:\n",
    "        'beta_ols', 'beta_2sls', 'Theta_ols',\n",
    "        'beta_final', 'Theta_final', 'beta_paths', 'Theta_paths',\n",
    "        't_grid', 'fig_boxplot_path', 'fig_paths_path', 'n_used'\n",
    "    \"\"\"\n",
    "    os.makedirs(outdir, exist_ok=True)\n",
    "\n",
    "    # --- Load up to n_rows rows from SAS file ---\n",
    "    with SAS7BDAT(file_path) as file:\n",
    "        chunk = []\n",
    "        for i, row in enumerate(file):\n",
    "            if i == 0:\n",
    "                columns = row\n",
    "            else:\n",
    "                chunk.append(row)\n",
    "            if len(chunk) >= n_rows:\n",
    "                break\n",
    "    data = pd.DataFrame(chunk, columns=columns)\n",
    "    # Keep rows with WEEKSM != '00', SEX2ND not blank\n",
    "    mask = (data['WEEKSM'] != '00') & (data['SEX2ND'] != '')\n",
    "    if filter_state is not None and 'STATE' in data.columns:\n",
    "        mask &= (data['STATE'] == filter_state)\n",
    "    data = data[mask].copy()\n",
    "\n",
    "    for col in ['SEX2ND', 'KIDCOUNT', 'WEEKSM']:\n",
    "        data[col] = pd.to_numeric(data[col], errors='coerce')\n",
    "    data = data.dropna(subset=['SEX2ND', 'KIDCOUNT', 'WEEKSM'])\n",
    "    print(f\"Using {data.shape[0]} samples after filtering.\")\n",
    "    # Extract arrays\n",
    "    z = data['SEX2ND'].to_numpy(dtype=float)\n",
    "    x = data['KIDCOUNT'].to_numpy(dtype=float)\n",
    "    y = data['WEEKSM'].to_numpy(dtype=float) \n",
    "\n",
    "    if center:\n",
    "        z = z - z.mean()\n",
    "        x = x - x.mean()\n",
    "        y = y - y.mean()\n",
    "\n",
    "    Z = z.reshape(-1, 1)   \n",
    "    X = x.reshape(-1, 1)   \n",
    "    n = y.shape[0]\n",
    "\n",
    "    # Baselines: OLS and 2SLS\n",
    "    Pz = Z @ np.linalg.inv(Z.T @ Z) @ Z.T\n",
    "    Xhat = Pz @ X\n",
    "    beta_2sls = np.linalg.lstsq(Xhat, y, rcond=None)[0]   \n",
    "    beta_ols  = np.linalg.lstsq(X, y, rcond=None)[0]      \n",
    "    Theta_ols = np.linalg.lstsq(Z, X, rcond=None)[0]      \n",
    "\n",
    "    beta_paths_list = []\n",
    "    Theta_paths_list = []\n",
    "    beta_final = np.empty(R, dtype=float)\n",
    "    Theta_final = np.empty(R, dtype=float)\n",
    "\n",
    "    for r in range(R):\n",
    "        seed = base_seed + r\n",
    "        beta_path, Theta_path = dp_gd_2sls(\n",
    "            X, y, Z,\n",
    "            T=T,\n",
    "            rho_1=rho_1,\n",
    "            rho_2=rho_2,\n",
    "            tau=tau,\n",
    "            c0=c0,\n",
    "            C0=C0,\n",
    "            clip_grads=clip_grads,\n",
    "            seed=seed\n",
    "        )\n",
    "        bp = np.asarray(beta_path)\n",
    "        Tp = np.asarray(Theta_path)\n",
    "\n",
    "        # Ensure shapes are (L,1) and (L,1,1)\n",
    "        if bp.ndim == 1:\n",
    "            bp = bp.reshape(-1, 1)\n",
    "        if Tp.ndim == 2:\n",
    "            Tp = Tp.reshape(Tp.shape[0], 1, 1)\n",
    "\n",
    "        # Truncate/align length to min(requested T, returned length)\n",
    "        L = min(T, bp.shape[0], Tp.shape[0])\n",
    "        bp = bp[:L]\n",
    "        Tp = Tp[:L]\n",
    "\n",
    "        beta_paths_list.append(bp)   # (L,1)\n",
    "        Theta_paths_list.append(Tp)  # (L,1,1)\n",
    "\n",
    "        beta_final[r] = bp[-1, 0]\n",
    "        Theta_final[r] = Tp[-1, 0, 0]\n",
    "\n",
    "    beta_paths = np.stack(beta_paths_list, axis=0)     # (R, L, 1)\n",
    "    Theta_paths = np.stack(Theta_paths_list, axis=0)   # (R, L, 1, 1)\n",
    "    L = beta_paths.shape[1]\n",
    "    t_grid = np.arange(L)\n",
    "\n",
    "    # --- Plot 1: boxplot of final DP betas with OLS/2SLS reference lines ---\n",
    "    plt.rcParams.update({\n",
    "        \"font.size\": 16,\n",
    "        \"axes.titlesize\": 16,\n",
    "        \"axes.labelsize\": 16,\n",
    "        \"xtick.labelsize\": 14,\n",
    "        \"ytick.labelsize\": 14,\n",
    "        \"legend.fontsize\": 14,\n",
    "    })\n",
    "    plt.figure(figsize=(7, 6), dpi=300)\n",
    "    plt.boxplot(\n",
    "        beta_final,\n",
    "        vert=True,\n",
    "        labels=[r'$\\beta^{(T)}$'],\n",
    "        medianprops=dict(color=\"black\", linewidth=2)\n",
    "    )\n",
    "    plt.axhline(beta_2sls.item(), linestyle='--', color='blue', linewidth=2, label='2SLS', alpha=0.8)\n",
    "    plt.axhline(beta_ols.item(), linestyle='--', color=\"orange\", linewidth=2, label='OLS', alpha=0.8)\n",
    "    plt.ylabel('Coefficient estimate')\n",
    "    plt.title(f\"Distribution of final estimates across {R} runs\")\n",
    "    plt.legend()\n",
    "    plt.tight_layout()\n",
    "    fig_boxplot_path = os.path.join(outdir, f\"Angrist_Boxplot_rho1={rho_1}_rho2={rho_2}_T={T}.png\")\n",
    "    plt.savefig(fig_boxplot_path, dpi=300)\n",
    "    if show_plots: plt.show()\n",
    "    plt.close()\n",
    "\n",
    "    # --- Plot 2: mean ± std paths for beta and Theta ---\n",
    "    beta_path_mean  = beta_paths.mean(axis=0).squeeze(-1)          \n",
    "    beta_path_std   = beta_paths.std(axis=0).squeeze(-1)          \n",
    "    Theta_path_mean = Theta_paths.mean(axis=0).squeeze()           \n",
    "    Theta_path_std  = Theta_paths.std(axis=0).squeeze()            \n",
    "\n",
    "    fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(7, 7), dpi=300)\n",
    "\n",
    "    ax1.plot(t_grid, beta_path_mean, linewidth=2.5, label=r'$\\mathbb{E}[\\beta^{(t)}]$ across runs')\n",
    "    ax1.fill_between(t_grid, beta_path_mean - beta_path_std, beta_path_mean + beta_path_std, alpha=0.2)\n",
    "    ax1.axhline(y=beta_2sls.item(), linestyle=\"--\", linewidth=2, label=r'$\\hat{\\beta}_{2SLS}$')\n",
    "    ax1.set_xlabel('Iteration t')\n",
    "    ax1.set_ylabel(r'Average $\\beta^{(t)}$')\n",
    "    ax1.set_title(r'Average $\\beta^{(t)}$ path')\n",
    "    ax1.legend(loc='best')\n",
    "\n",
    "    ax2.plot(t_grid, Theta_path_mean, linewidth=2.5, label=r'$\\mathbb{E}[\\Theta^{(t)}]$ across runs')\n",
    "    ax2.fill_between(t_grid, Theta_path_mean - Theta_path_std, Theta_path_mean + Theta_path_std, alpha=0.2)\n",
    "    ax2.axhline(y=Theta_ols.item(), linestyle=\"--\", linewidth=2, label=r'$\\hat{\\Theta}_{OLS}$')\n",
    "    ax2.set_xlabel('Iteration t')\n",
    "    ax2.set_ylabel(r'Average $\\Theta^{(t)}$')\n",
    "    ax2.set_title(r'Average $\\Theta^{(t)}$ path')\n",
    "    ax2.legend(loc='best')\n",
    "\n",
    "    plt.tight_layout()\n",
    "    fig_paths_path = os.path.join(outdir, f\"Angrist_Parameter_Path_rho1={rho_1}_rho2={rho_2}_T={T}.png\")\n",
    "    plt.savefig(fig_paths_path, dpi=300)\n",
    "    if show_plots: plt.show()\n",
    "    plt.close(fig)\n",
    "\n",
    "    return {\n",
    "        \"beta_ols\": beta_ols.item(),\n",
    "        \"beta_2sls\": beta_2sls.item(),\n",
    "        \"Theta_ols\": Theta_ols.item(),\n",
    "        \"beta_final\": beta_final,\n",
    "        \"Theta_final\": Theta_final,\n",
    "        \"beta_paths\": beta_paths,          # shape (R, L, 1)\n",
    "        \"Theta_paths\": Theta_paths,        # shape (R, L, 1, 1)\n",
    "        \"t_grid\": t_grid,\n",
    "        \"fig_boxplot_path\": fig_boxplot_path,\n",
    "        \"fig_paths_path\": fig_paths_path,\n",
    "        \"n_used\": n,\n",
    "    }\n",
    "\n",
    "def dp_gd_2sls(\n",
    "    X: np.ndarray,\n",
    "    y: np.ndarray,\n",
    "    Z: np.ndarray,\n",
    "    *,\n",
    "    T: int = 50,\n",
    "    rho_1: float = 0.5,\n",
    "    rho_2: float = 0.5,\n",
    "    tau: float = 0.0,\n",
    "    c0: float = 1.0,\n",
    "    C0: float = 1.0,\n",
    "    clip_grads: bool = True,\n",
    "    Theta_init: np.ndarray = None,\n",
    "    beta_init: np.ndarray = None,\n",
    "    seed: int = None,\n",
    ") -> np.ndarray:\n",
    "    \"\"\"\n",
    "    Differentially-Private GD-2SLS (DP-GD-2SLS).\n",
    "\n",
    "    Minimizes, with noise:\n",
    "        min_Theta  (1/2n)||Z Theta - X||_F^2\n",
    "        min_beta   (1/2n)||Z Theta beta - y||_2^2\n",
    "\n",
    "    Args:\n",
    "        X: (n, p) regressors.\n",
    "        y: (n,) response.\n",
    "        Z: (n, q) instruments.\n",
    "        T: total GD iterations.\n",
    "        rho_1, rho_2: privacy budget splits (positive). Only their ratio matters for scaling.\n",
    "        tau, C0, c1, c2: constants used in step-size / noise scaling (as in your sim code).\n",
    "        clip_grads: clip the Theta gradient (Fro) at γ1 and the beta gradient (ℓ2) at γ2.\n",
    "        Theta_init: optional (q, p) start; default zeros.\n",
    "        beta_init: optional (p,) start; default zeros.\n",
    "        rng: np.random.Generator for reproducibility.\n",
    "\n",
    "    Returns:\n",
    "        beta_hat: (p,) final estimate for β.\n",
    "    \"\"\"\n",
    "    if seed is not None:\n",
    "        np.random.seed(seed)\n",
    "    n, p = X.shape\n",
    "    n_z, q = Z.shape\n",
    "    assert n_z == n, \"X and Z must have same number of rows\"\n",
    "    assert y.shape[0] == n, \"y length must match X rows\"\n",
    "\n",
    "    # Initialize\n",
    "    Theta_t = np.zeros((q, p)) if Theta_init is None else Theta_init.copy()\n",
    "    beta_t  = np.zeros(p)       if beta_init  is None else beta_init.copy()\n",
    "    Theta_path = np.empty((int(T), q, p), dtype=float)\n",
    "    beta_path  = np.empty((int(T), p),    dtype=float)\n",
    "\n",
    "    Theta_hat = np.linalg.pinv(Z) @ X\n",
    "    alpha = 1 / np.linalg.norm(Theta_hat, 2) ** 2\n",
    "\n",
    "    # We use an empirical estimate of σ_z \n",
    "    sigma_z_est = np.linalg.norm(Z, \"fro\") / np.sqrt(n * q)\n",
    "\n",
    "    # γ1, γ2: clipping thresholds \n",
    "    gamma_1 = c0 * (np.sqrt(q) + np.sqrt(tau)) ** 2\n",
    "    gamma_2 = c0 * (np.sqrt(q) + np.sqrt(tau)) ** 2\n",
    "\n",
    "    delta = (C0 * (sigma_z_est**2) * (np.sqrt(q) + np.sqrt(tau))) / np.sqrt(n)\n",
    "    eta = 1.0 / (1.0 + delta) ** 2\n",
    "    eta = 1.0\n",
    "    \n",
    "    # Noise scales \n",
    "    lambda_1 = (2.0 * gamma_1 / n) * np.sqrt(T / rho_1)\n",
    "    lambda_2 = (2.0 * gamma_2 / n) * np.sqrt(T / rho_2)\n",
    "    for t in range(int(T)):\n",
    "        # ---- Theta step ----\n",
    "        grad_Theta = Z.T @ (Z @ Theta_t - X) \n",
    "        if clip_grads:\n",
    "            gnorm = np.linalg.norm(grad_Theta, \"fro\") / n\n",
    "            if gnorm > gamma_1:\n",
    "                grad_Theta = grad_Theta * (gamma_1 / gnorm)\n",
    "\n",
    "        # DP noise for Theta update\n",
    "        Xi = np.random.randn(q, p) * lambda_1\n",
    "        Theta_t = Theta_t - eta / n * grad_Theta + eta * Xi\n",
    "\n",
    "        X_hat = Z @ Theta_t\n",
    "\n",
    "        grad_beta = X_hat.T @ (X_hat @ beta_t - y)\n",
    "\n",
    "        if clip_grads:\n",
    "            bgnorm = np.linalg.norm(grad_beta) / n\n",
    "            if bgnorm > gamma_2 and bgnorm > 0:\n",
    "                grad_beta = grad_beta * (gamma_2 / bgnorm)\n",
    "\n",
    "        # DP noise for beta\n",
    "        nu = np.random.randn(p) * lambda_2\n",
    "        beta_t = beta_t - alpha / n * grad_beta + alpha * nu\n",
    "        Theta_path[t] = Theta_t\n",
    "        beta_path[t] = beta_t\n",
    "    # print(alpha*lambda_2, alpha)\n",
    "    return beta_path, Theta_path\n",
    "\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d007d3d4",
   "metadata": {},
   "source": [
    "## Figure 5"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ed1e42a6",
   "metadata": {},
   "outputs": [],
   "source": [
    "results = run_angrist_dp_experiment(\n",
    "    file_path=\"AngEv98/m_d_806.sas7bdat\",\n",
    "    R=1000,\n",
    "    T=20,\n",
    "    rho_1=1,\n",
    "    rho_2=1,\n",
    "    outdir=\"Figures\",\n",
    "    show_plots=True\n",
    ")\n",
    "\n",
    "print(results[\"fig_boxplot_path\"])\n",
    "print(results[\"fig_paths_path\"])\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a8f01b50",
   "metadata": {},
   "source": [
    "# Figure 8"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b6b72589",
   "metadata": {},
   "outputs": [],
   "source": [
    "results = run_angrist_dp_experiment(\n",
    "    file_path=\"AngEv98/m_d_806.sas7bdat\",\n",
    "    R=1000,\n",
    "    T=20,\n",
    "    rho_1=0.1,\n",
    "    rho_2=0.1,\n",
    "    outdir=\"Figures\",\n",
    "    show_plots=True\n",
    ")\n",
    "\n",
    "print(results[\"fig_boxplot_path\"])\n",
    "print(results[\"fig_paths_path\"])\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "46dbd7b7",
   "metadata": {},
   "source": [
    "# Figure 9"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6768673e",
   "metadata": {},
   "outputs": [],
   "source": [
    "results = run_angrist_dp_experiment(\n",
    "    file_path=\"AngEv98/m_d_806.sas7bdat\",\n",
    "    R=1000,\n",
    "    T=20,\n",
    "    rho_1=10,\n",
    "    rho_2=10,\n",
    "    outdir=\"Figures\",\n",
    "    show_plots=True\n",
    ")\n",
    "\n",
    "print(results[\"fig_boxplot_path\"])\n",
    "print(results[\"fig_paths_path\"])\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8ffb112c",
   "metadata": {},
   "source": [
    "# New experiments"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1a3659c4",
   "metadata": {},
   "source": [
    "# Card Dataset (Multiple IVs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e430aef0",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import os\n",
    "import matplotlib.pyplot as plt\n",
    "from typing import Dict, Any\n",
    "\n",
    "def run_dp_2sls_experiment(\n",
    "    X: np.ndarray,\n",
    "    y: np.ndarray,\n",
    "    Z: np.ndarray,\n",
    "    *,\n",
    "    R: int = 1000,\n",
    "    T: int = 20,\n",
    "    rho_1: float = 1_000_000.0,\n",
    "    rho_2: float = 1.0,\n",
    "    tau: float = 0.0,\n",
    "    c0: float = 1.0,\n",
    "    C0: float = 1.0,\n",
    "    clip_grads: bool = False,\n",
    "    beta_index: int = 0,         \n",
    "    outdir: str = \"Figures\",\n",
    "    out_prefix: str = \"DP2SLS\",\n",
    "    show_plots: bool = False,\n",
    "    base_seed: int = 0,\n",
    ") -> Dict[str, Any]:\n",
    "    \"\"\"\n",
    "    Generic DP-GD-2SLS experiment:\n",
    "\n",
    "      - Only produces boxplot for beta^{(T)}[beta_index]\n",
    "      - Only produces learning path for beta^{(t)}[beta_index]\n",
    "\n",
    "    Args:\n",
    "        X: (n, p) regressors.\n",
    "        y: (n,) outcome.\n",
    "        Z: (n, q) instruments.\n",
    "        beta_index: which coordinate of beta to use for plots (default 0).\n",
    "\n",
    "    Returns:\n",
    "        dict with:\n",
    "          'beta_ols', 'beta_2sls',\n",
    "          'beta_final', 'beta_paths', 't_grid',\n",
    "          'fig_boxplot_path', 'fig_path_path', 'n_used'\n",
    "    \"\"\"\n",
    "    os.makedirs(outdir, exist_ok=True)\n",
    "\n",
    "    n, p = X.shape\n",
    "    assert y.shape[0] == n\n",
    "    assert Z.shape[0] == n\n",
    "    assert 0 <= beta_index < p\n",
    "\n",
    "    # Baselines: OLS and 2SLS\n",
    "    beta_ols = np.linalg.lstsq(X, y, rcond=None)[0]  # (p,)\n",
    "\n",
    "    Pz = Z @ np.linalg.inv(Z.T @ Z) @ Z.T\n",
    "    Xhat = Pz @ X\n",
    "    beta_2sls = np.linalg.lstsq(Xhat, y, rcond=None)[0]  # (p,)\n",
    "\n",
    "    beta_paths_list = []\n",
    "    beta_final = np.empty(R, dtype=float)\n",
    "\n",
    "    for r in range(R):\n",
    "        seed = base_seed + r\n",
    "        beta_path, _ = dp_gd_2sls(\n",
    "            X, y, Z,\n",
    "            T=T,\n",
    "            rho_1=rho_1,\n",
    "            rho_2=rho_2,\n",
    "            tau=tau,\n",
    "            c0=c0,\n",
    "            C0=C0,\n",
    "            clip_grads=clip_grads,\n",
    "            seed=seed,\n",
    "        )\n",
    "        bp = np.asarray(beta_path)          \n",
    "        if bp.ndim == 1:\n",
    "            bp = bp.reshape(-1, p)\n",
    "\n",
    "        L = min(T, bp.shape[0])\n",
    "        bp = bp[:L]                         # (L, p)\n",
    "        beta_paths_list.append(bp)\n",
    "\n",
    "        beta_final[r] = bp[-1, beta_index]  # store final coord\n",
    "\n",
    "    beta_paths = np.stack(beta_paths_list, axis=0)   # (R, L, p)\n",
    "    L = beta_paths.shape[1]\n",
    "    t_grid = np.arange(L)\n",
    "\n",
    "    plt.rcParams.update({\n",
    "        \"font.size\": 16,\n",
    "        \"axes.titlesize\": 16,\n",
    "        \"axes.labelsize\": 16,\n",
    "        \"xtick.labelsize\": 14,\n",
    "        \"ytick.labelsize\": 14,\n",
    "        \"legend.fontsize\": 14,\n",
    "    })\n",
    "    plt.figure(figsize=(7, 6), dpi=300)\n",
    "    plt.boxplot(\n",
    "        beta_final,\n",
    "        vert=True,\n",
    "        labels=[fr'$\\beta^{{(T)}}$'],\n",
    "        medianprops=dict(color=\"black\", linewidth=2)\n",
    "    )\n",
    "    plt.axhline(beta_2sls[beta_index].item(), linestyle='--',\n",
    "                color='blue', linewidth=2, label='2SLS', alpha=0.8)\n",
    "    plt.axhline(beta_ols[beta_index].item(), linestyle='--',\n",
    "                color='orange', linewidth=2, label='OLS', alpha=0.8)\n",
    "    plt.ylabel('Coefficient estimate')\n",
    "    plt.title(f\"Distribution of final estimates across {R} runs\")\n",
    "    plt.ylim(-0.05, 0.20)\n",
    "    plt.legend()\n",
    "    plt.tight_layout()\n",
    "    fig_boxplot_path = os.path.join(\n",
    "        outdir, f\"{out_prefix}_Boxplot_rho1={rho_1}_rho2={rho_2}_T={T}.png\"\n",
    "    )\n",
    "    plt.savefig(fig_boxplot_path, dpi=300)\n",
    "    if show_plots:\n",
    "        plt.show()\n",
    "    plt.close()\n",
    "\n",
    "    beta_coord_paths = beta_paths[:, :, beta_index]      # (R, L)\n",
    "    beta_path_mean = beta_coord_paths.mean(axis=0)       # (L,)\n",
    "    beta_path_std  = beta_coord_paths.std(axis=0)        # (L,)\n",
    "\n",
    "    plt.figure(figsize=(7, 5), dpi=300)\n",
    "    plt.plot(t_grid, beta_path_mean, linewidth=2.5,\n",
    "             label=fr'$\\mathbb{{E}}[\\beta^{{(t)}}]$ across runs')\n",
    "    plt.fill_between(\n",
    "        t_grid,\n",
    "        beta_path_mean - beta_path_std,\n",
    "        beta_path_mean + beta_path_std,\n",
    "        alpha=0.2\n",
    "    )\n",
    "    plt.axhline(\n",
    "        y=beta_2sls[beta_index].item(),\n",
    "        linestyle=\"--\", linewidth=2,\n",
    "        label=fr'$\\hat{{\\beta}}_{{2SLS}}$'\n",
    "    )\n",
    "    plt.xlabel('Iteration t')\n",
    "    plt.ylabel(fr'Average $\\beta^{{(t)}}$')\n",
    "    plt.title(fr'Average $\\beta^{{(t)}}$ path')\n",
    "    plt.ylim(-0.08, 0.24)\n",
    "    plt.legend(loc='best')\n",
    "    plt.tight_layout()\n",
    "    fig_path_path = os.path.join(\n",
    "        outdir, f\"{out_prefix}_BetaPath_rho1={rho_1}_rho2={rho_2}_T={T}.png\"\n",
    "    )\n",
    "    plt.savefig(fig_path_path, dpi=300)\n",
    "    if show_plots:\n",
    "        plt.show()\n",
    "    plt.close()\n",
    "\n",
    "    return {\n",
    "        \"beta_ols\": beta_ols,\n",
    "        \"beta_2sls\": beta_2sls,\n",
    "        \"beta_final\": beta_final,\n",
    "        \"beta_paths\": beta_paths,  \n",
    "        \"t_grid\": t_grid,\n",
    "        \"fig_boxplot_path\": fig_boxplot_path,\n",
    "        \"fig_path_path\": fig_path_path,\n",
    "        \"n_used\": n,\n",
    "    }"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8c246128",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import pandas as pd\n",
    "from linearmodels.datasets import card\n",
    "\n",
    "df = card.load()\n",
    "vars_needed = [\n",
    "        \"wage\",\n",
    "        \"educ\", \"exper\", \"KWW\",\n",
    "        \"nearc4\", \"nearc2\", \"fatheduc\", \"motheduc\",\n",
    "        \"reg661\", \"reg662\", \"reg663\", \"reg664\", \"reg665\",\n",
    "        \"reg666\", \"reg667\", \"reg668\", \"reg669\",\n",
    "        \"south66\", \"smsa66\",\n",
    "    ]\n",
    "data = df[vars_needed].dropna().copy()\n",
    "\n",
    "\n",
    "y = np.log(data[\"wage\"].to_numpy(dtype=float))   \n",
    "X_cols = [\"educ\"]\n",
    "Z_cols = [\"nearc4\", \"nearc2\", \"fatheduc\", \"motheduc\"]\n",
    "\n",
    "X = data[X_cols].to_numpy(dtype=float)          \n",
    "Z = data[Z_cols].to_numpy(dtype=float)         \n",
    "\n",
    "y = y - y.mean()\n",
    "X = X - X.mean(axis=0)\n",
    "Z = Z - Z.mean(axis=0)\n",
    "\n",
    "Z[:, 2:] = Z[:, 2:] / Z[:, 2:].std(axis=0)\n",
    "res_card = run_dp_2sls_experiment(\n",
    "    X, y, Z,\n",
    "    R=1000,\n",
    "    T=15,\n",
    "    rho_1=10,\n",
    "    rho_2=10,\n",
    "    c0=1.0,\n",
    "    C0=1.0,\n",
    "    clip_grads=False,\n",
    "    beta_index=0,                \n",
    "    outdir=\"Figures\",\n",
    "    out_prefix=\"Card1995\",\n",
    "    show_plots=True,\n",
    "    base_seed=0,\n",
    ")\n",
    "\n",
    "print(\"OLS beta (educ):\", res_card[\"beta_ols\"][0])\n",
    "print(\"2SLS beta (educ):\", res_card[\"beta_2sls\"][0])\n",
    "print(\"Saved boxplot to:\", res_card[\"fig_boxplot_path\"])\n",
    "print(\"Saved beta path plot to:\", res_card[\"fig_path_path\"])"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "de28ec0e",
   "metadata": {},
   "source": [
    "# Tuning step size"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fe1b54ef",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "from tqdm import tqdm\n",
    "\n",
    "def simulate_alg1_vs_alpha(\n",
    "    n=5000,\n",
    "    T=50,\n",
    "    alpha_list=None,\n",
    "    repeats=10,\n",
    "    Theta_true=None,\n",
    "    p=10, q=10,\n",
    "    sigma_z=1.0, sigma_1=1.0, sigma_2=1.0,\n",
    "    C=1.0, C0=1.0, c=1.0, c0=1.0, c1=1.0, c2=1.0,\n",
    "    tau=5.0,\n",
    "    rho_1=0.9, rho_2=0.1,\n",
    "    seed=0,\n",
    "):\n",
    "    \"\"\"\n",
    "    Fix n, T and sweep alpha.\n",
    "    Return: mean/std error over repeats for Algorithm 1 vs alpha, plus 2SLS baseline.\n",
    "    Dataset is re-sampled per repeat, and reused across all alpha within that repeat.\n",
    "    \"\"\"\n",
    "    rng = np.random.default_rng(seed)\n",
    "\n",
    "    if Theta_true is None:\n",
    "        Theta_true = 5 * np.eye(q, p) + rng.standard_normal((q, p))\n",
    "    beta_true = rng.standard_normal(p)\n",
    "\n",
    "    sigma_min_Theta = np.linalg.svd(Theta_true, compute_uv=False).min()\n",
    "    sigma_max_Theta = np.linalg.norm(Theta_true, 2)\n",
    "\n",
    "    R = (np.sqrt(q) + np.sqrt(tau))**2 * (np.sqrt(p*q) + np.sqrt(p*(np.log(p)+tau)))\n",
    "\n",
    "    delta = (C0 * sigma_z**2 * (np.sqrt(q) + np.sqrt(tau))) / np.sqrt(n)\n",
    "    Psi_norm = (sigma_z * sigma_2 * np.sqrt(c0 * p * q * (tau + np.log(2*p*q)))) / (np.sqrt(n) * (1 - delta)**2)\n",
    "    if delta >= 1:\n",
    "        print(f\"[warn] delta >= 1 for n={n:.0f} (delta={delta:.3f})\")\n",
    "    if sigma_min_Theta < Psi_norm:\n",
    "        print(f\"[warn] sigma_min_Theta < Psi_norm for n={n:.0f} (min={sigma_min_Theta:.3f}, Psi={Psi_norm:.3f})\")\n",
    "\n",
    "    gamma_floor = (1 - delta)**2 * (sigma_min_Theta - Psi_norm)**2\n",
    "    gamma_ceil = (1 + delta)**2 * (sigma_max_Theta + Psi_norm)**2\n",
    "\n",
    "    # \"theoretical\" alpha* used as center for sweep\n",
    "    alpha_star = 2 / (gamma_ceil + gamma_floor)\n",
    "    alpha_max = 4 / (2*gamma_ceil + gamma_floor)\n",
    "    eta_baseline = 1.8 / ((1 + delta)**2)\n",
    "\n",
    "    if alpha_list is None:\n",
    "        alpha_list = alpha_star * np.linspace(0.1, alpha_max / alpha_star*1.2, int(np.ceil(alpha_max / alpha_star)) * 10)\n",
    "    alpha_list = np.array(alpha_list, dtype=float)\n",
    "\n",
    "    gamma_1 = c1 * (np.sqrt(q) + np.sqrt(tau + np.log(n*T))) ** 2\n",
    "    gamma_2 = c2 * (np.sqrt(q) + np.sqrt(tau + np.log(n*T))) ** 2\n",
    "\n",
    "    rho = rho_1 + rho_2\n",
    "    lambda_1 = (2 * gamma_1 / n) * np.sqrt(T / rho_1)\n",
    "    lambda_2 = (2 * gamma_2 / n) * np.sqrt(T / rho_2)\n",
    "\n",
    "    errors_alg1_all_alpha = {a: [] for a in alpha_list}\n",
    "    errors_2sls_all_alpha = {a: [] for a in alpha_list}\n",
    "\n",
    "    for rep in tqdm(range(repeats), desc=f\"n={n}, T={T}, sweep alpha\"):\n",
    "        Z = rng.standard_normal((n, q))\n",
    "        U = rng.standard_normal((n, q))\n",
    "        Phi = rng.standard_normal((q, p))\n",
    "        phi = rng.standard_normal(q)\n",
    "        X = Z @ Theta_true + U @ Phi + sigma_1 * rng.standard_normal((n, p))\n",
    "        y = X @ beta_true + U @ phi + sigma_2 * rng.standard_normal(n)\n",
    "\n",
    "        # 2SLS baseline \n",
    "        Pz = Z @ np.linalg.inv(Z.T @ Z) @ Z.T\n",
    "        X_hat = Pz @ X\n",
    "        beta_2sls = np.linalg.inv(X_hat.T @ X_hat) @ (X_hat.T @ y)\n",
    "        err_2sls = np.linalg.norm(beta_2sls - beta_true)\n",
    "\n",
    "        for alpha in alpha_list:\n",
    "            Theta_t = np.zeros((q, p))\n",
    "            beta_t  = np.zeros(p)\n",
    "\n",
    "            for _t in range(T):\n",
    "                # Theta step\n",
    "                grad_Theta_t = Z.T @ (Z @ Theta_t - X)\n",
    "                Xi = rng.standard_normal((q, p)) * lambda_1\n",
    "                Theta_t = Theta_t - (eta_baseline / n) * grad_Theta_t + eta_baseline * Xi\n",
    "\n",
    "                # beta step\n",
    "                grad_beta_t = (Z @ Theta_t).T @ (Z @ Theta_t @ beta_t - y)\n",
    "                nu = rng.standard_normal(p) * lambda_2\n",
    "                beta_t = beta_t - (alpha / n) * grad_beta_t + alpha * nu\n",
    "\n",
    "            errors_alg1_all_alpha[alpha].append(np.linalg.norm(beta_t - beta_true))\n",
    "            errors_2sls_all_alpha[alpha].append(err_2sls)\n",
    "\n",
    "    err_alg1_mean = np.array([np.mean(errors_alg1_all_alpha[a]) for a in alpha_list])\n",
    "    err_alg1_std  = np.array([np.std(errors_alg1_all_alpha[a])  for a in alpha_list])\n",
    "    err_2sls_mean = np.array([np.mean(errors_2sls_all_alpha[a]) for a in alpha_list])\n",
    "    err_2sls_std  = np.array([np.std(errors_2sls_all_alpha[a])  for a in alpha_list])\n",
    "\n",
    "    return {\n",
    "        'n': n,\n",
    "        'T': T,\n",
    "        'alpha_list': alpha_list,\n",
    "        'errors_alg1': err_alg1_mean,\n",
    "        'stds_alg1': err_alg1_std,\n",
    "        'errors_2sls': err_2sls_mean,\n",
    "        'stds_2sls': err_2sls_std,\n",
    "        'rho1': rho_1,\n",
    "        'rho2': rho_2,\n",
    "        'alpha_star': alpha_star,\n",
    "        'alpha_max': alpha_max,\n",
    "        'eta_baseline': eta_baseline,\n",
    "    }\n",
    "\n",
    "\n",
    "def plot_alg1_vs_alpha(\n",
    "    n=5000,\n",
    "    T=50,\n",
    "    alpha_list=None,\n",
    "    repeats=10,\n",
    "    rho_1=0.9, rho_2=0.1,\n",
    "    common_kwargs=None,\n",
    "    seed=0,\n",
    "    plt_title=\"Algorithm 1 Error vs alpha (T fixed)\",\n",
    "    save_path=None,\n",
    "    log_x=True,\n",
    "):\n",
    "    \"\"\"\n",
    "    Run the alpha sweep with simulate_alg1_vs_alpha and draw the result.\n",
    "\n",
    "    Draws mean error with std error bars for Algorithm 1 and for the 2SLS\n",
    "    baseline against the stepsize alpha, marks alpha_star and alpha_max with\n",
    "    dashed vertical lines, optionally saves the figure to save_path (dpi=300)\n",
    "    and finally shows it. Entries of common_kwargs are forwarded unchanged to\n",
    "    simulate_alg1_vs_alpha. The x axis is logarithmic when log_x is true.\n",
    "    \"\"\"\n",
    "    extra_sim_kwargs = {} if common_kwargs is None else common_kwargs\n",
    "\n",
    "    sweep = simulate_alg1_vs_alpha(\n",
    "        n=n, T=T, alpha_list=alpha_list, repeats=repeats,\n",
    "        rho_1=rho_1, rho_2=rho_2, seed=seed,\n",
    "        **extra_sim_kwargs\n",
    "    )\n",
    "    stepsizes = sweep['alpha_list']\n",
    "\n",
    "    # (mean errors, std errors, line/marker format, legend label) per curve\n",
    "    curves = [\n",
    "        (sweep['errors_alg1'], sweep['stds_alg1'], '-^', 'Algorithm 1'),\n",
    "        (sweep['errors_2sls'], sweep['stds_2sls'], '--s', '2SLS baseline'),\n",
    "    ]\n",
    "    for means, spreads, line_fmt, legend_label in curves:\n",
    "        plt.errorbar(stepsizes, means, yerr=spreads, fmt=line_fmt, capsize=3, label=legend_label)\n",
    "\n",
    "    if log_x:\n",
    "        plt.xscale('log')\n",
    "\n",
    "    plt.xlabel(r'stepsize $\\alpha$')\n",
    "    plt.ylabel(r'$\\|\\beta^{(T)}-\\beta_{\\mathrm{true}}\\|$')\n",
    "    plt.title(plt_title)\n",
    "    plt.ylim(0, 1)\n",
    "\n",
    "    # (x position, colour, horizontal nudge factor for the label, label text)\n",
    "    reference_lines = [\n",
    "        (sweep['alpha_star'], 'green', 0.9, r'$\\alpha_{approx}^\\star$'),\n",
    "        (sweep['alpha_max'], 'red', 1.07, r'$\\alpha_{max}$'),\n",
    "    ]\n",
    "    for x_ref, colour, x_nudge, caption in reference_lines:\n",
    "        plt.axvline(x_ref, color=colour, linestyle='--', linewidth=2)\n",
    "        # The label sits at 65% of the current upper y-limit, beside the line.\n",
    "        plt.text(\n",
    "            x_ref*x_nudge,\n",
    "            plt.ylim()[1]*0.65,\n",
    "            caption,\n",
    "            color=colour,\n",
    "            ha='center', va='top',\n",
    "            fontsize=12\n",
    "        )\n",
    "\n",
    "    plt.grid(True)\n",
    "    plt.legend()\n",
    "    plt.tight_layout()\n",
    "    if save_path is not None:\n",
    "        plt.savefig(save_path, dpi=300)\n",
    "    plt.show()\n",
    "\n",
    "def simulate_alg1_vs_eta(\n",
    "    n=5000,\n",
    "    T=50,\n",
    "    eta_list=None,\n",
    "    repeats=10,\n",
    "    Theta_true=None,\n",
    "    p=10, q=10,\n",
    "    sigma_z=1.0, sigma_1=1.0, sigma_2=1.0,\n",
    "    C=1.0, C0=1.0, c=1.0, c0=1.0, c1=1.0, c2=1.0,\n",
    "    tau=0.0,\n",
    "    rho_1=0.9, rho_2=0.1,\n",
    "    seed=0,\n",
    "):\n",
    "    \"\"\"\n",
    "    Fix n, T and sweep eta.\n",
    "    Return: mean/std error over repeats for Algorithm 1 vs eta, plus 2SLS baseline.\n",
    "    Dataset is re-sampled per repeat, and reused across all eta within that repeat.\n",
    "\n",
    "    The beta stepsize is held at alpha_baseline = 3.6 / (2*gamma_ceil + gamma_floor).\n",
    "    When eta_list is None, a grid from 0.1*eta_star up to 0.97*eta_max is used.\n",
    "    C and c are accepted but not used by this function.\n",
    "\n",
    "    Returns a dict with keys 'n', 'T', 'eta_list', 'errors_alg1', 'stds_alg1',\n",
    "    'errors_2sls', 'stds_2sls', 'rho1', 'rho2', 'alpha_baseline', 'eta_star',\n",
    "    'eta_max'. The error/std arrays are aligned with 'eta_list' by position.\n",
    "    \"\"\"\n",
    "    rng = np.random.default_rng(seed)\n",
    "\n",
    "    if Theta_true is None:\n",
    "        Theta_true = 5 * np.eye(q, p) + rng.standard_normal((q, p))\n",
    "    beta_true = rng.standard_normal(p)\n",
    "\n",
    "    sigma_min_Theta = np.linalg.svd(Theta_true, compute_uv=False).min()\n",
    "    sigma_max_Theta = np.linalg.norm(Theta_true, 2)\n",
    "\n",
    "    delta = (C0 * sigma_z**2 * (np.sqrt(q) + np.sqrt(tau))) / np.sqrt(n)\n",
    "    Psi_norm = (sigma_z * sigma_2 * np.sqrt(c0 * p * q * (tau + np.log(2*p*q)))) / (np.sqrt(n) * (1 - delta)**2)\n",
    "    if delta >= 1:\n",
    "        print(f\"[warn] delta >= 1 for n={n:.0f} (delta={delta:.3f})\")\n",
    "    if sigma_min_Theta < Psi_norm:\n",
    "        print(f\"[warn] sigma_min_Theta < Psi_norm for n={n:.0f} (min={sigma_min_Theta:.3f}, Psi={Psi_norm:.3f})\")\n",
    "\n",
    "    gamma_floor = (1 - delta)**2 * (sigma_min_Theta - Psi_norm)**2\n",
    "    gamma_ceil = (1 + delta)**2 * (sigma_max_Theta + Psi_norm)**2\n",
    "\n",
    "    # alpha* fixed while sweeping eta\n",
    "    alpha_baseline = 3.6 / (2*gamma_ceil + gamma_floor)\n",
    "\n",
    "    # eta*: used as center of sweep\n",
    "    eta_star = 2 / ((1 + delta)**2 + (1 - delta)**2)\n",
    "    eta_max = 2 / (1 + delta)**2\n",
    "    if eta_list is None:\n",
    "        eta_list = eta_star * np.linspace(0.1, eta_max / eta_star*0.97, int(np.ceil(eta_max / eta_star*10)))\n",
    "    eta_list = np.array(eta_list, dtype=float)\n",
    "\n",
    "    gamma_1 = c1 * (np.sqrt(q) + np.sqrt(tau + np.log(n*T))) ** 2\n",
    "    gamma_2 = c2 * (np.sqrt(q) + np.sqrt(tau + np.log(n*T))) ** 2\n",
    "\n",
    "    # Scale of the Gaussian noise added to each Theta / beta update.\n",
    "    lambda_1 = (2 * gamma_1 / n) * np.sqrt(T / rho_1)\n",
    "    lambda_2 = (2 * gamma_2 / n) * np.sqrt(T / rho_2)\n",
    "\n",
    "    # One accumulator per sweep position (not per float value), so repeated\n",
    "    # stepsizes in a user-supplied eta_list keep separate statistics.\n",
    "    errors_alg1_by_eta = [[] for _ in eta_list]\n",
    "    # The 2SLS baseline does not depend on eta: one error per repeat.\n",
    "    errors_2sls = []\n",
    "\n",
    "    for _rep in tqdm(range(repeats), desc=f\"n={n}, T={T}, sweep eta\"):\n",
    "        Z = rng.standard_normal((n, q))\n",
    "        U = rng.standard_normal((n, q))\n",
    "        Phi = rng.standard_normal((q, p))\n",
    "        phi = rng.standard_normal(q)\n",
    "        X = Z @ Theta_true + U @ Phi + sigma_1 * rng.standard_normal((n, p))\n",
    "        y = X @ beta_true + U @ phi + sigma_2 * rng.standard_normal(n)\n",
    "\n",
    "        # 2SLS baseline: project X onto the column space of Z via a linear\n",
    "        # solve instead of forming the n x n projection matrix or any inverse.\n",
    "        X_hat = Z @ np.linalg.solve(Z.T @ Z, Z.T @ X)\n",
    "        beta_2sls = np.linalg.solve(X_hat.T @ X_hat, X_hat.T @ y)\n",
    "        errors_2sls.append(np.linalg.norm(beta_2sls - beta_true))\n",
    "\n",
    "        for idx, eta in enumerate(eta_list):\n",
    "            Theta_t = np.zeros((q, p))\n",
    "            beta_t  = np.zeros(p)\n",
    "\n",
    "            for _t in range(T):\n",
    "                # Theta step\n",
    "                grad_Theta_t = Z.T @ (Z @ Theta_t - X)\n",
    "                Xi = rng.standard_normal((q, p)) * lambda_1\n",
    "                Theta_t = Theta_t - (eta / n) * grad_Theta_t + eta * Xi\n",
    "\n",
    "                # beta step (uses the freshly updated Theta_t)\n",
    "                Z_Theta = Z @ Theta_t\n",
    "                grad_beta_t = Z_Theta.T @ (Z_Theta @ beta_t - y)\n",
    "                nu = rng.standard_normal(p) * lambda_2\n",
    "                beta_t = beta_t - (alpha_baseline / n) * grad_beta_t + alpha_baseline * nu\n",
    "\n",
    "            errors_alg1_by_eta[idx].append(np.linalg.norm(beta_t - beta_true))\n",
    "\n",
    "    err_alg1_mean = np.array([np.mean(errs) for errs in errors_alg1_by_eta])\n",
    "    err_alg1_std  = np.array([np.std(errs)  for errs in errors_alg1_by_eta])\n",
    "    # Broadcast the eta-independent baseline so it aligns with eta_list.\n",
    "    err_2sls_mean = np.full(len(eta_list), np.mean(errors_2sls))\n",
    "    err_2sls_std  = np.full(len(eta_list), np.std(errors_2sls))\n",
    "\n",
    "    return {\n",
    "        'n': n,\n",
    "        'T': T,\n",
    "        'eta_list': eta_list,\n",
    "        'errors_alg1': err_alg1_mean,\n",
    "        'stds_alg1': err_alg1_std,\n",
    "        'errors_2sls': err_2sls_mean,\n",
    "        'stds_2sls': err_2sls_std,\n",
    "        'rho1': rho_1,\n",
    "        'rho2': rho_2,\n",
    "        'alpha_baseline': alpha_baseline,\n",
    "        'eta_star': eta_star,\n",
    "        'eta_max': eta_max,\n",
    "    }\n",
    "\n",
    "\n",
    "def plot_alg1_vs_eta(\n",
    "    n=5000,\n",
    "    T=50,\n",
    "    eta_list=None,\n",
    "    repeats=10,\n",
    "    rho_1=0.9, rho_2=0.1,\n",
    "    common_kwargs=None,\n",
    "    seed=0,\n",
    "    plt_title=\"Algorithm 1 Error vs eta (T fixed)\",\n",
    "    save_path=None,\n",
    "    log_x=True,\n",
    "):\n",
    "    \"\"\"\n",
    "    Run the eta sweep with simulate_alg1_vs_eta and draw the result.\n",
    "\n",
    "    Draws mean error with std error bars for Algorithm 1 and for the 2SLS\n",
    "    baseline against the stepsize eta, marks eta_star and eta_max with dashed\n",
    "    vertical lines, optionally saves the figure to save_path (dpi=300) and\n",
    "    finally shows it. Entries of common_kwargs are forwarded unchanged to\n",
    "    simulate_alg1_vs_eta. The x axis is logarithmic when log_x is true.\n",
    "    \"\"\"\n",
    "    extra_sim_kwargs = {} if common_kwargs is None else common_kwargs\n",
    "\n",
    "    sweep = simulate_alg1_vs_eta(\n",
    "        n=n, T=T, eta_list=eta_list, repeats=repeats,\n",
    "        rho_1=rho_1, rho_2=rho_2, seed=seed,\n",
    "        **extra_sim_kwargs\n",
    "    )\n",
    "    stepsizes = sweep['eta_list']\n",
    "\n",
    "    # (mean errors, std errors, line/marker format, legend label) per curve\n",
    "    curves = [\n",
    "        (sweep['errors_alg1'], sweep['stds_alg1'], '-^', 'Algorithm 1'),\n",
    "        (sweep['errors_2sls'], sweep['stds_2sls'], '--s', '2SLS baseline'),\n",
    "    ]\n",
    "    for means, spreads, line_fmt, legend_label in curves:\n",
    "        plt.errorbar(stepsizes, means, yerr=spreads, fmt=line_fmt, capsize=3, label=legend_label)\n",
    "\n",
    "    if log_x:\n",
    "        plt.xscale('log')\n",
    "\n",
    "    plt.xlabel(r'stepsize $\\eta$')\n",
    "    plt.ylabel(r'$\\|\\beta^{(T)}-\\beta_{\\mathrm{true}}\\|$')\n",
    "    plt.title(plt_title)\n",
    "    plt.ylim(0, 0.2)\n",
    "\n",
    "    # (x position, colour, horizontal nudge factor for the label, label text)\n",
    "    reference_lines = [\n",
    "        (sweep['eta_star'], 'green', 0.86, r'$\\eta_{approx}^\\star$'),\n",
    "        (sweep['eta_max'], 'red', 1.06, r'$\\eta_{max}$'),\n",
    "    ]\n",
    "    for x_ref, colour, x_nudge, caption in reference_lines:\n",
    "        plt.axvline(x_ref, color=colour, linestyle='--', linewidth=2)\n",
    "        # The label sits at 65% of the current upper y-limit, beside the line.\n",
    "        plt.text(\n",
    "            x_ref*x_nudge,\n",
    "            plt.ylim()[1]*0.65,\n",
    "            caption,\n",
    "            color=colour,\n",
    "            ha='center', va='top',\n",
    "            fontsize=12\n",
    "        )\n",
    "\n",
    "    plt.grid(True)\n",
    "    plt.legend()\n",
    "    plt.tight_layout()\n",
    "    if save_path is not None:\n",
    "        plt.savefig(save_path, dpi=300)\n",
    "    plt.show()\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "699ad0a2",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Fix T and study alpha\n",
    "# p and q reach the simulation through common_kwargs; log_x=False keeps a linear stepsize axis.\n",
    "# NOTE(review): save_path points into Figures/ - presumably that directory must already exist; confirm.\n",
    "plot_alg1_vs_alpha(\n",
    "    n=2000,\n",
    "    T=20,\n",
    "    repeats=100,\n",
    "    rho_1=5,\n",
    "    rho_2=5,\n",
    "    log_x=False, \n",
    "    common_kwargs=dict(\n",
    "        p=5, q=5,\n",
    "    ),\n",
    "    plt_title=\"Algorithm 1 error vs alpha (p=5, q=5)\",\n",
    "    save_path=\"Figures/Alg1_vs_alpha_p=5_q=5_T=20.png\"\n",
    ")\n",
    "\n",
    "# Fix T and study eta\n",
    "# Same n, T, repeats, rho_1, rho_2, p and q as the alpha sweep above; only the swept stepsize differs.\n",
    "plot_alg1_vs_eta(\n",
    "    n=2000,\n",
    "    T=20,\n",
    "    repeats=100,\n",
    "    rho_1=5,\n",
    "    rho_2=5,\n",
    "    log_x=False,\n",
    "    common_kwargs=dict(\n",
    "        p=5, q=5,\n",
    "    ),\n",
    "    plt_title=\"Algorithm 1 error vs eta (p=5, q=5)\",\n",
    "    save_path=\"Figures/Alg1_vs_eta_p=5_q=5_T=20.png\"\n",
    ")\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "cd240588",
   "metadata": {},
   "source": [
    "# Clipping threshold experiment"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bd5efaf8",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "from tqdm import tqdm\n",
    "\n",
    "def clip_by_norm(v, gamma):\n",
    "    \"\"\"\n",
    "    Rescale v so that its norm does not exceed gamma.\n",
    "\n",
    "    The norm is np.linalg.norm(v): the 2-norm for vectors and the Frobenius\n",
    "    norm for matrices. v is returned as-is (same object, no copy) when it is\n",
    "    already within the threshold or when gamma <= 0, which disables clipping.\n",
    "    \"\"\"\n",
    "    magnitude = np.linalg.norm(v)\n",
    "    within_threshold = magnitude <= gamma\n",
    "    clipping_disabled = gamma <= 0\n",
    "    return v if (within_threshold or clipping_disabled) else v * (gamma / magnitude)\n",
    "\n",
    "def _per_sample_clip_scales(norms, gamma):\n",
    "    \"\"\"\n",
    "    Per-sample factors that shrink gradients whose norm exceeds gamma to norm gamma.\n",
    "\n",
    "    Follows the same rule as clip_by_norm: the factor is 1 (no clipping)\n",
    "    wherever norm <= gamma, and everywhere when gamma <= 0.\n",
    "    \"\"\"\n",
    "    scales = np.ones_like(norms)\n",
    "    if gamma > 0:\n",
    "        too_large = norms > gamma\n",
    "        scales[too_large] = gamma / norms[too_large]\n",
    "    return scales\n",
    "\n",
    "\n",
    "def simulate_alg1_vs_gamma(\n",
    "    n=2000,\n",
    "    gamma_list=None,\n",
    "    repeats=10,\n",
    "    Theta_true=None,\n",
    "    p=10, q=10,\n",
    "    sigma_z=1.0, sigma_1=1.0, sigma_2=1.0,\n",
    "    C=1.0, C0=1.0, c=1.0, c0=1.0, c1=1.0, c2=1.0,\n",
    "    tau=5.0,\n",
    "    rho_1=0.9, rho_2=0.1,\n",
    "    seed=0,\n",
    "    T=20,\n",
    "):\n",
    "    \"\"\"\n",
    "    Fix n and T, vary clipping thresholds gamma1=gamma2 over gamma_list.\n",
    "    Implements per-sample clipping as in Algorithm DP-2S-GD.\n",
    "\n",
    "    Each sample's gradient is clipped individually, but the clipping is done\n",
    "    in vectorized form: the norm of the rank-one Theta gradient outer(z_i, d_i)\n",
    "    is ||z_i|| * ||d_i||, and the norm of the beta gradient v_i * r_i is\n",
    "    ||v_i|| * |r_i|, so the clipped sum is a single matrix product.\n",
    "\n",
    "    Seeds NumPy's global random state with `seed`. When gamma_list is None,\n",
    "    np.logspace(0, 3, 6) is used. C, c, c1 and c2 are accepted but not used.\n",
    "\n",
    "    Returns:\n",
    "        dict with keys:\n",
    "            'gamma_list', 'errors_alg1', 'stds_alg1',\n",
    "            'error_2sls_mean', 'error_2sls_std',\n",
    "            'n', 'T', 'rho_1', 'rho_2'\n",
    "    \"\"\"\n",
    "    if gamma_list is None:\n",
    "        gamma_list = np.logspace(0, 3, 6)\n",
    "\n",
    "    gamma_list = np.array(list(gamma_list), dtype=float)\n",
    "\n",
    "    np.random.seed(seed)\n",
    "\n",
    "    # Setup Theta_true and beta_true\n",
    "    if Theta_true is None:\n",
    "        Theta_true = 5 * np.eye(q, p) + np.random.randn(q, p)\n",
    "    beta_true = np.random.randn(p)\n",
    "\n",
    "    sigma_min_Theta = np.linalg.svd(Theta_true, compute_uv=False).min()\n",
    "    sigma_max_Theta = np.linalg.norm(Theta_true, 2)\n",
    "\n",
    "    # These depend only on the arguments, not on the sampled data, so they are\n",
    "    # computed once instead of once per repeat.\n",
    "    delta = (C0 * sigma_z**2 * (np.sqrt(q) + np.sqrt(tau))) / np.sqrt(n)\n",
    "    # NOTE(review): c0 multiplies the square root here, whereas the stepsize\n",
    "    # sweeps in this notebook put c0 under the square root; identical for c0=1.\n",
    "    Psi_norm = c0 * (\n",
    "        sigma_z * sigma_2 * np.sqrt(p * q * (tau + np.log(2 * p * q)))\n",
    "    ) / (np.sqrt(n) * (1 - delta)**2)\n",
    "\n",
    "    gamma_floor = (1 - delta)**2 * (sigma_min_Theta - Psi_norm)**2\n",
    "    gamma_ceil  = (1 + delta)**2 * (sigma_max_Theta + Psi_norm)**2\n",
    "\n",
    "    alpha = 2.0 / (2.0 * gamma_ceil + gamma_floor)\n",
    "    eta   = 1.0 / (1.0 + delta)**2\n",
    "\n",
    "    err_lists_alg1 = [[] for _ in gamma_list]\n",
    "    err_list_2sls = []\n",
    "\n",
    "    for _rep in tqdm(range(repeats), desc=f\"n={n}, T={T}, ρ1={rho_1:.2f}, ρ2={rho_2:.2f}\"):\n",
    "\n",
    "        Z = np.random.randn(n, q)\n",
    "        U = np.random.randn(n, q)\n",
    "        Phi = np.random.randn(q, p)\n",
    "        phi = np.random.randn(q)\n",
    "\n",
    "        X = Z @ Theta_true + U @ Phi + sigma_1 * np.random.randn(n, p)\n",
    "        y = X @ beta_true + U @ phi + sigma_2 * np.random.randn(n)\n",
    "\n",
    "        # 2SLS baseline: project X onto the column space of Z via a linear\n",
    "        # solve instead of forming the n x n projection matrix or any inverse.\n",
    "        X_hat = Z @ np.linalg.solve(Z.T @ Z, Z.T @ X)\n",
    "        beta_2sls = np.linalg.solve(X_hat.T @ X_hat, X_hat.T @ y)\n",
    "        err_list_2sls.append(np.linalg.norm(beta_2sls - beta_true))\n",
    "\n",
    "        # ||z_i|| for every sample; reused by every Theta step of this repeat.\n",
    "        z_norms = np.linalg.norm(Z, axis=1)\n",
    "\n",
    "        for j, gamma_clip in enumerate(gamma_list):\n",
    "\n",
    "            gamma_1 = gamma_clip\n",
    "            gamma_2 = gamma_clip\n",
    "\n",
    "            # Noise scales depend on clipping thresholds\n",
    "            lambda_1 = (2.0 * gamma_1 / n) * np.sqrt(T / max(rho_1, 1e-12))\n",
    "            lambda_2 = (2.0 * gamma_2 / n) * np.sqrt(T / max(rho_2, 1e-12))\n",
    "\n",
    "            Theta_t = np.zeros((q, p))\n",
    "            beta_t  = np.zeros(p)\n",
    "\n",
    "            for _t in range(T):\n",
    "                # Theta step: row i of resid_X is z_i @ Theta_t - x_i, and the\n",
    "                # per-sample gradient is outer(z_i, resid_X[i]).\n",
    "                resid_X = Z @ Theta_t - X\n",
    "                scale_Theta = _per_sample_clip_scales(\n",
    "                    z_norms * np.linalg.norm(resid_X, axis=1), gamma_1\n",
    "                )\n",
    "                # Sum over samples of the clipped outer products.\n",
    "                grad_Theta_sum = Z.T @ (resid_X * scale_Theta[:, None])\n",
    "\n",
    "                Xi = np.random.randn(q, p) * lambda_1\n",
    "\n",
    "                Theta_t = Theta_t - (eta / n) * grad_Theta_sum + eta * Xi\n",
    "\n",
    "                # beta step: row i of V is Theta_t.T @ z_i, and the per-sample\n",
    "                # gradient is V[i] * resid_y[i].\n",
    "                V = Z @ Theta_t\n",
    "                resid_y = Z @ (Theta_t @ beta_t) - y\n",
    "                scale_beta = _per_sample_clip_scales(\n",
    "                    np.linalg.norm(V, axis=1) * np.abs(resid_y), gamma_2\n",
    "                )\n",
    "                grad_beta_sum = V.T @ (resid_y * scale_beta)\n",
    "\n",
    "                nu = np.random.randn(p) * lambda_2\n",
    "\n",
    "                beta_t = beta_t - (alpha / n) * grad_beta_sum + alpha * nu\n",
    "\n",
    "            err_alg1 = np.linalg.norm(beta_t - beta_true)\n",
    "            err_lists_alg1[j].append(err_alg1)\n",
    "\n",
    "    # Aggregate stats over repeats\n",
    "    errors_alg1 = np.array([np.mean(lst) for lst in err_lists_alg1])\n",
    "    stds_alg1   = np.array([np.std(lst)  for lst in err_lists_alg1])\n",
    "\n",
    "    error_2sls_mean = np.mean(err_list_2sls)\n",
    "    error_2sls_std  = np.std(err_list_2sls)\n",
    "\n",
    "    return {\n",
    "        'gamma_list': gamma_list,\n",
    "        'errors_alg1': errors_alg1,\n",
    "        'stds_alg1': stds_alg1,\n",
    "        'error_2sls_mean': error_2sls_mean,\n",
    "        'error_2sls_std': error_2sls_std,\n",
    "        'n': n,\n",
    "        'T': T,\n",
    "        'rho_1': rho_1,\n",
    "        'rho_2': rho_2,\n",
    "    }\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "def plot_alg1_vs_gamma(\n",
    "    n=2000,\n",
    "    gamma_list=None,\n",
    "    repeats=10,\n",
    "    rho_1=0.9,\n",
    "    rho_2=0.1,\n",
    "    common_kwargs=None,\n",
    "    T=20,\n",
    "    plt_title=None,\n",
    "    log_x=True,\n",
    "    save_path=None,\n",
    "):\n",
    "    \"\"\"\n",
    "    Run the clipping-threshold sweep with simulate_alg1_vs_gamma and plot it.\n",
    "\n",
    "    Opens a new figure, draws the mean error of Algorithm 1 with std error\n",
    "    bars against gamma, and draws the mean 2SLS error as a dashed horizontal\n",
    "    line. Entries of common_kwargs are forwarded unchanged to the simulation.\n",
    "    When plt_title is None a title listing n, T, rho_1 and rho_2 is generated.\n",
    "    The x axis is logarithmic when log_x is true. The figure is saved to\n",
    "    save_path (dpi=300) when one is given, then shown.\n",
    "    \"\"\"\n",
    "    extra_sim_kwargs = {} if common_kwargs is None else common_kwargs\n",
    "\n",
    "    title = plt_title\n",
    "    if title is None:\n",
    "        title = rf\"Algorithm 1 Error vs $\\gamma$ (n={n}, T={T}, $\\rho_1={rho_1}$, $\\rho_2={rho_2}$)\"\n",
    "\n",
    "    sweep = simulate_alg1_vs_gamma(\n",
    "        n=n, gamma_list=gamma_list, repeats=repeats,\n",
    "        rho_1=rho_1, rho_2=rho_2, T=T,\n",
    "        **extra_sim_kwargs,\n",
    "    )\n",
    "\n",
    "    plt.figure()\n",
    "    plt.errorbar(\n",
    "        sweep['gamma_list'],\n",
    "        sweep['errors_alg1'],\n",
    "        yerr=sweep['stds_alg1'],\n",
    "        fmt='-o',\n",
    "        capsize=3,\n",
    "        label=r'Algorithm 1',\n",
    "    )\n",
    "    # The baseline does not depend on gamma, so it is a single horizontal line.\n",
    "    plt.axhline(sweep['error_2sls_mean'], linestyle='--', label='2SLS (baseline)')\n",
    "\n",
    "    if log_x:\n",
    "        plt.xscale('log')\n",
    "\n",
    "    plt.xlabel(r'Clipping threshold $\\gamma$ (with $\\gamma_1=\\gamma_2=\\gamma$)')\n",
    "    plt.ylabel(r'$\\|\\beta^{(T)} - \\beta_{\\mathrm{true}}\\|$')\n",
    "    plt.title(title)\n",
    "    plt.grid(True, which='both', linestyle=':')\n",
    "    plt.legend()\n",
    "    plt.tight_layout()\n",
    "\n",
    "    if save_path is not None:\n",
    "        plt.savefig(save_path, dpi=300)\n",
    "    plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4cd68fbc",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sweep the shared clipping threshold (gamma_1 = gamma_2) over 15 log-spaced values with n and T fixed.\n",
    "# NOTE(review): save_path points into Figures/ - presumably that directory must already exist; confirm.\n",
    "plot_alg1_vs_gamma(\n",
    "    n=2000,\n",
    "    gamma_list=np.logspace(0, 3, 15),   # γ = 1 ... 1000\n",
    "    repeats=10,\n",
    "    rho_1=5,\n",
    "    rho_2=5,\n",
    "    T=20,\n",
    "    common_kwargs=dict(\n",
    "        p=5, q=5, tau=5.0,\n",
    "        C0=1.0, c0=1.0, c1=1.0, c2=1.0\n",
    "    ),\n",
    "    save_path=\"Figures/Alg1_vs_gamma_p=5_q=5_T=20.png\"\n",
    ")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
