{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1a77b37f-932e-4d58-9a92-00950489d80d",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os, json\n",
    "from datetime import datetime\n",
    "\n",
    "import numpy as np\n",
    "from tqdm import tqdm\n",
    "from sklearn.metrics import adjusted_rand_score\n",
    "from sklearn.cluster import KMeans\n",
    "from lightgbm import LGBMRegressor, LGBMClassifier\n",
    "\n",
    "# =====================\n",
    "# Hyperparameters\n",
    "# =====================\n",
    "SEED_BASE   = 1\n",
    "ROUNDS      = 100\n",
    "M_TASKS     = 20         # number of tasks\n",
    "K_CLUST     = 3          # latent clusters\n",
    "DELTAS      = [1/3, 2/3, 1]\n",
    "\n",
    "# Fixed-lambda sweep (inside each delta)\n",
    "FIXED_LAMBDA_LIST = [1e-5, 1e-4, 1e-3, 1e-2, 1e-1]\n",
    "\n",
    "# Split sizes\n",
    "A_FRAC      = 0.5        # fraction to use for nuisance training (split-A)\n",
    "\n",
    "# Noise levels\n",
    "NOISE_Y     = 1\n",
    "\n",
    "# LightGBM (nuisances)\n",
    "N_JOBS = 1\n",
    "LGBM_ESTIMATORS        = 100\n",
    "LGBM_LEARNING_RATE     = 0.03\n",
    "LGBM_NUM_LEAVES        = 20\n",
    "LGBM_MIN_CHILD_SAMPLES = 20\n",
    "LGBM_SUBSAMPLE         = 0.9\n",
    "LGBM_COLSAMPLE         = 0.9\n",
    "LGBM_REG_ALPHA         = 0.0\n",
    "LGBM_REG_LAMBDA        = 0.0\n",
    "LGBM_SEED_OFFSET       = 1000\n",
    "\n",
    "# Adaptive-fusion hyperparameters\n",
    "C_W               = .1\n",
    "ADAPT_EPS         = 1e-12\n",
    "GAMMA_DEFAULT     = 2       # adaptive gamma\n",
    "ADAPT_TAU_DEFAULT = 10       # adaptive tau\n",
    "\n",
    "# ---------------- ARMUL controls ----------------\n",
    "# Per-task penalty: lambda_j = C_LAM * sqrt(1 / n_j)\n",
    "C_LAM            = 10.0\n",
    "ARMUL_K_MAX      = M_TASKS\n",
    "ARMUL_N_INIT     = 1          \n",
    "ARMUL_N_ITER     = 300   \n",
    "ARMUL_TOL        = 1e-7    \n",
    "ARMUL_SEED       = 123  \n",
    "\n",
    "EPS = 1e-12\n",
    "\n",
    "# =====================\n",
    "# Utilities\n",
    "# =====================\n",
    "\n",
    "def ensure_dir(path: str):\n",
    "    os.makedirs(path, exist_ok=True)\n",
    "    return path\n",
    "\n",
    "def split_two_parts(n, a_frac, seed):\n",
    "    \"\"\"Return indices for split-A (nuisance) and split-T (target).\"\"\"\n",
    "    rng = np.random.default_rng(seed)\n",
    "    idx = np.arange(n)\n",
    "    rng.shuffle(idx)\n",
    "    n_a = int(np.floor(a_frac * n))\n",
    "    A = idx[:n_a]\n",
    "    T = idx[n_a:]\n",
    "    return A, T\n",
    "\n",
    "def sigmoid(x):\n",
    "    return 1 / (1 + np.exp(-x))\n",
    "\n",
    "# =====================\n",
    "# ATE Data Generating Process\n",
    "# =====================\n",
    "\n",
    "def make_te_tasks(m: int, K: int, delta: float, seed: int):\n",
    "    \"\"\"\n",
    "    Binary treatment A with propensity e_j(X); outcome:\n",
    "        Y = theta_k * A + g_j(X) + eps,  eps ~ N(0, NOISE_Y^2)\n",
    "    X ~ N(0, I_{p_j}); p_j varies by task.\n",
    "    \"\"\"\n",
    "    rng = np.random.default_rng(seed)\n",
    "    theta_clusters = np.arange(1, K + 1, dtype=float) * delta - (K + 1) * delta / 2.0\n",
    "\n",
    "    cluster_ids = np.repeat(np.arange(K), repeats=int(np.ceil(m / K)))[:m]\n",
    "    rng.shuffle(cluster_ids)\n",
    "\n",
    "    tasks = []\n",
    "    theta_true = np.zeros(m)\n",
    "    for j in range(m):\n",
    "        n_j = 3200 + 80 * j\n",
    "        p_j = 5 + j\n",
    "\n",
    "        X = rng.normal(0.0, 1.0, size=(n_j, p_j))\n",
    "\n",
    "        # Task-specific propensity e_j(X) = sigmoid(h_j(X)), clip away extremes\n",
    "        h = X[:, 3] * X[:, 4] - X[:, 0] * X[:, 1]\n",
    "        e = np.clip(sigmoid(h), 0.05, 0.95)\n",
    "        A = (rng.uniform(size=n_j) < e).astype(float)\n",
    "\n",
    "        # g_j(X): smooth baseline\n",
    "        g = 0.0\n",
    "        for i in range(p_j):\n",
    "            g += (-0.8) ** (i+1) * (np.exp(X[:, i]) / (1.0 + np.exp(X[:, i])))\n",
    "\n",
    "        # cluster & outcome\n",
    "        k = cluster_ids[j]\n",
    "        theta = theta_clusters[k]\n",
    "        Y = theta * A + g + rng.normal(0.0, NOISE_Y, size=n_j)\n",
    "\n",
    "        tasks.append({\"X\": X, \"A\": A, \"Y\": Y})\n",
    "        theta_true[j] = theta\n",
    "\n",
    "    return tasks, theta_true, cluster_ids\n",
    "\n",
    "# =====================\n",
    "# Models\n",
    "# =====================\n",
    "\n",
    "def _lgbm_reg(seed):\n",
    "    return LGBMRegressor(\n",
    "        n_estimators=LGBM_ESTIMATORS,\n",
    "        learning_rate=LGBM_LEARNING_RATE,\n",
    "        num_leaves=LGBM_NUM_LEAVES,\n",
    "        min_child_samples=LGBM_MIN_CHILD_SAMPLES,\n",
    "        subsample=LGBM_SUBSAMPLE,\n",
    "        colsample_bytree=LGBM_COLSAMPLE,\n",
    "        reg_alpha=LGBM_REG_ALPHA,\n",
    "        reg_lambda=LGBM_REG_LAMBDA,\n",
    "        objective=\"regression\",\n",
    "        random_state=seed,\n",
    "        n_jobs=N_JOBS,\n",
    "        verbose=-1,\n",
    "    )\n",
    "\n",
    "def _lgbm_clf(seed):\n",
    "    return LGBMClassifier(\n",
    "        n_estimators=LGBM_ESTIMATORS,\n",
    "        learning_rate=LGBM_LEARNING_RATE,\n",
    "        num_leaves=LGBM_NUM_LEAVES,\n",
    "        min_child_samples=LGBM_MIN_CHILD_SAMPLES,\n",
    "        subsample=LGBM_SUBSAMPLE,\n",
    "        colsample_bytree=LGBM_COLSAMPLE,\n",
    "        reg_alpha=LGBM_REG_ALPHA,\n",
    "        reg_lambda=LGBM_REG_LAMBDA,\n",
    "        objective=\"binary\",\n",
    "        random_state=seed,\n",
    "        n_jobs=N_JOBS,\n",
    "        verbose=-1,\n",
    "    )\n",
    "\n",
    "# =====================\n",
    "# AIPW\n",
    "# =====================\n",
    "\n",
    "def aipw_pseudo_trainA_predict_subset_TE(X, A, Y, A_idx, T_idx, seed):\n",
    "    \"\"\"\n",
    "    Fit nuisances on split-A only:\n",
    "        m_1(x) = E[Y | A=1, X=x]\n",
    "        m_0(x) = E[Y | A=0, X=x]\n",
    "        e(x)   = P(A=1 | X=x)\n",
    "    \"\"\"\n",
    "    reg1 = _lgbm_reg(seed)\n",
    "    reg0 = _lgbm_reg(seed + 1)\n",
    "\n",
    "    X_A = X[A_idx]\n",
    "    A_A = A[A_idx]\n",
    "    Y_A = Y[A_idx]\n",
    "\n",
    "    mask1 = (A_A == 1)\n",
    "    mask0 = (A_A == 0)\n",
    "\n",
    "    if mask1.sum() < 10:\n",
    "        reg1.fit(X_A, Y_A)\n",
    "    else:\n",
    "        reg1.fit(X_A[mask1], Y_A[mask1])\n",
    "\n",
    "    if mask0.sum() < 10:\n",
    "        reg0.fit(X_A, Y_A)\n",
    "    else:\n",
    "        reg0.fit(X_A[mask0], Y_A[mask0])\n",
    "\n",
    "    m1_hat = reg1.predict(X[T_idx])\n",
    "    m0_hat = reg0.predict(X[T_idx])\n",
    "\n",
    "    eclf = _lgbm_clf(seed + 2)\n",
    "    eclf.fit(X_A, A_A.astype(int))\n",
    "    e_hat = eclf.predict_proba(X[T_idx])[:, 1]\n",
    "    e_hat = np.clip(e_hat, 1e-3, 1 - 1e-3)\n",
    "\n",
    "    A_T = A[T_idx]\n",
    "    Y_T = Y[T_idx]\n",
    "\n",
    "    term1 = (A_T / e_hat) * (Y_T - m1_hat)\n",
    "    term0 = ((1.0 - A_T) / (1.0 - e_hat)) * (Y_T - m0_hat)\n",
    "    phi = term1 - term0 + (m1_hat - m0_hat)\n",
    "\n",
    "    return phi\n",
    "\n",
    "def sstats_and_if_se_from_aipw(phi):\n",
    "    phi = np.asarray(phi, dtype=float)\n",
    "    n = phi.size\n",
    "    if n == 0:\n",
    "        return 1.0, 0.0, 1e-12\n",
    "    b = float(phi.mean())\n",
    "    var_phi = float(np.mean((phi - b) ** 2))\n",
    "    var_b = max(var_phi / max(n, 1), 1e-12)\n",
    "    return 1.0, b, var_b\n",
    "\n",
    "# =====================\n",
    "# Graph fused lasso helpers\n",
    "# =====================\n",
    "\n",
    "def build_lambda_adaptive(theta_init, c_w=C_W, gamma=GAMMA_DEFAULT,\n",
    "                          tau=None, eps=ADAPT_EPS,\n",
    "                          km_n_init=10, random_state=0):\n",
    "    \"\"\"\n",
    "    Adaptive pairwise penalties Λ_{ij}\n",
    "    theta_init should be the individual estimator (here, AIPW).\n",
    "    \"\"\"\n",
    "    m = len(theta_init)\n",
    "    L = np.zeros((m, m), dtype=float)\n",
    "    if m < 2:\n",
    "        return L\n",
    "\n",
    "    wvals = []\n",
    "    for i in range(m):\n",
    "        for j in range(i + 1, m):\n",
    "            diff = abs(theta_init[i] - theta_init[j])\n",
    "            w = c_w * (diff + 1e-8) ** (-gamma)\n",
    "            wvals.append(w)\n",
    "    wvals = np.asarray(wvals, dtype=float)\n",
    "\n",
    "    if tau is None:\n",
    "        z = np.log(np.maximum(wvals, 1e-18)).reshape(-1, 1)\n",
    "        if np.allclose(z, z[0]):\n",
    "            tau = float(np.exp(z[0, 0]))\n",
    "        else:\n",
    "            km = KMeans(n_clusters=2, n_init=km_n_init, random_state=random_state)\n",
    "            km.fit(z)\n",
    "            centers = np.sort(km.cluster_centers_.ravel())\n",
    "            z_split = 0.5 * (centers[0] + centers[1])\n",
    "            tau = float(np.exp(z_split))\n",
    "\n",
    "    idx = 0\n",
    "    for i in range(m):\n",
    "        for j in range(i + 1, m):\n",
    "            w = wvals[idx]; idx += 1\n",
    "            lam = eps if w <= tau else w\n",
    "            L[i, j] = L[j, i] = lam\n",
    "\n",
    "    return L\n",
    "\n",
    "def build_lambda_fixed(m, lam):\n",
    "    L = np.full((m, m), 0.0)\n",
    "    for i in range(m):\n",
    "        for j in range(i + 1, m):\n",
    "            L[i, j] = L[j, i] = lam\n",
    "    return L\n",
    "\n",
    "def admm_graph_fused_lasso(a, b, Lambda, rho=1.0, max_iter=2000, tol=1e-6):\n",
    "    \"\"\"\n",
    "    min_θ Σ_j (a_j/2)(θ_j - b_j)^2 + Σ_{i<j} Λ_{ij}|θ_i - θ_j|\n",
    "    \"\"\"\n",
    "    m = len(a)\n",
    "    edges = [(i, j) for i in range(m) for j in range(i + 1, m)]\n",
    "    E = len(edges)\n",
    "    if E == 0:\n",
    "        return b.copy()\n",
    "\n",
    "    B = np.zeros((E, m))\n",
    "    for e, (i, j) in enumerate(edges):\n",
    "        B[e, i] = 1.0\n",
    "        B[e, j] = -1.0\n",
    "    lams = np.array([Lambda[i, j] for (i, j) in edges])\n",
    "    A_mat = np.diag(a)\n",
    "    M = A_mat + rho * (B.T @ B)\n",
    "\n",
    "    theta = b.copy()\n",
    "    z = B @ theta\n",
    "    u = np.zeros_like(z)\n",
    "\n",
    "    def soft(v, kappa):\n",
    "        return np.sign(v) * np.maximum(np.abs(v) - kappa, 0.0)\n",
    "\n",
    "    for _ in range(max_iter):\n",
    "        theta_new = np.linalg.solve(M, A_mat @ b + rho * B.T @ (z - u))\n",
    "        r = B @ theta_new + u\n",
    "        z_new = soft(r, lams / rho)\n",
    "        u_new = r - z_new\n",
    "        if np.linalg.norm(B @ theta_new - z_new) < tol and np.linalg.norm(rho * B.T @ (z_new - z)) < tol:\n",
    "            theta = theta_new\n",
    "            break\n",
    "        theta, z, u = theta_new, z_new, u_new\n",
    "\n",
    "    return theta\n",
    "\n",
    "def clusters_from_theta(theta_hat, tol=1e-3):\n",
    "    \"\"\"Greedy grouping of nearly-equal θ's into cluster labels.\"\"\"\n",
    "    order = np.argsort(theta_hat)\n",
    "    labels_sorted = np.zeros_like(order)\n",
    "    cur = 0\n",
    "    labels_sorted[0] = 0\n",
    "    for k in range(1, len(order)):\n",
    "        if abs(theta_hat[order[k]] - theta_hat[order[k - 1]]) > tol:\n",
    "            cur += 1\n",
    "        labels_sorted[k] = cur\n",
    "    labels = np.zeros_like(labels_sorted)\n",
    "    labels[order] = labels_sorted\n",
    "    return labels\n",
    "\n",
    "# =====================\n",
    "# ARMUL (per-task λ_j) clustered MTL second stage\n",
    "# =====================\n",
    "\n",
    "def _soft(x, t):\n",
    "    return np.sign(x) * np.maximum(np.abs(x) - t, 0.0)\n",
    "\n",
    "def _armul_obj(theta, gamma, c, a, b, lam_vec):\n",
    "    return 0.5 * np.sum(a * (theta - b)**2) + np.sum(lam_vec * np.abs(theta - gamma[c]))\n",
    "\n",
    "def _ensure_nonempty(labels, K, rng):\n",
    "    counts = np.bincount(labels, minlength=K)\n",
    "    for k in range(K):\n",
    "        if counts[k] == 0:\n",
    "            src = int(np.argmax(counts))\n",
    "            idx_src = np.where(labels == src)[0]\n",
    "            j_move = int(rng.choice(idx_src))\n",
    "            labels[j_move] = k\n",
    "            counts[src] -= 1\n",
    "            counts[k] += 1\n",
    "    return labels\n",
    "\n",
    "def _kmeans_init_labels_1d(b, K, seed, n_init_kmeans=10):\n",
    "    km = KMeans(n_clusters=K, n_init=n_init_kmeans, random_state=seed)\n",
    "    labels = km.fit_predict(b.reshape(-1, 1))\n",
    "    return labels.astype(int, copy=False)\n",
    "\n",
    "def armul_clustered_fit(\n",
    "    b, a, lam_vec,\n",
    "    K_candidates,\n",
    "    init=\"kmeans\",\n",
    "    n_init=1,\n",
    "    n_iter=300, tol=1e-7, seed=123,\n",
    "    kmeans_n_init=10\n",
    "):\n",
    "    \"\"\"\n",
    "    Coordinate-descent for clustered MTL with per-task λ_j:\n",
    "      min_{θ,γ,c} sum_j (a_j/2)(θ_j-b_j)^2 + sum_j λ_j |θ_j - γ_{c_j}|\n",
    "    \"\"\"\n",
    "    rng = np.random.default_rng(seed)\n",
    "    m = len(b)\n",
    "    lam_vec = np.asarray(lam_vec, dtype=float)\n",
    "    assert lam_vec.shape == (m,)\n",
    "\n",
    "    if isinstance(K_candidates, int):\n",
    "        K_list = [K_candidates]\n",
    "    else:\n",
    "        K_list = list(K_candidates)\n",
    "        if len(K_list) == 0:\n",
    "            raise ValueError(\"K_candidates is empty.\")\n",
    "\n",
    "    def theta_update(gamma, labels):\n",
    "        theta = np.empty_like(b, dtype=float)\n",
    "        for j in range(m):\n",
    "            gk = gamma[labels[j]]\n",
    "            theta[j] = gk + _soft(b[j] - gk, lam_vec[j] / max(a[j], EPS))\n",
    "        return theta\n",
    "\n",
    "    best = dict(obj=np.inf, theta=None, labels=None, gamma=None, K=None)\n",
    "\n",
    "    for K in K_list:\n",
    "        for init_id in range(n_init):\n",
    "            # initialize labels\n",
    "            if init == \"kmeans\":\n",
    "                labels = _kmeans_init_labels_1d(b, K, seed + init_id, n_init_kmeans=kmeans_n_init)\n",
    "            elif init == \"quantiles\":\n",
    "                order = np.argsort(b)\n",
    "                labels = np.zeros(m, dtype=int)\n",
    "                edges = np.linspace(0, m, K+1).astype(int)\n",
    "                for k in range(K):\n",
    "                    idx = order[edges[k]:edges[k+1]]\n",
    "                    labels[idx] = k\n",
    "            elif init == \"random\":\n",
    "                labels = rng.integers(0, K, size=m)\n",
    "            else:\n",
    "                raise ValueError(f\"Unknown init='{init}'\")\n",
    "\n",
    "            labels = _ensure_nonempty(labels, K, rng)\n",
    "\n",
    "            # γ <- medians of b per cluster (L1 center), then CD\n",
    "            gamma = np.array([np.median(b[labels == k]) for k in range(K)], dtype=float)\n",
    "            theta = theta_update(gamma, labels)\n",
    "            obj_prev = _armul_obj(theta, gamma, labels, a, b, lam_vec)\n",
    "\n",
    "            for _ in range(n_iter):\n",
    "                # (1) θ-update\n",
    "                theta = theta_update(gamma, labels)\n",
    "                # (2) γ-update: median of θ in each cluster\n",
    "                for k in range(K):\n",
    "                    idx = np.where(labels == k)[0]\n",
    "                    if len(idx) == 0:\n",
    "                        gamma[k] = float(rng.choice(theta))\n",
    "                    else:\n",
    "                        gamma[k] = float(np.median(theta[idx]))\n",
    "                # (3) c-update: assign each j to best cluster\n",
    "                theta_new = np.empty_like(theta)\n",
    "                labels_new = labels.copy()\n",
    "                for j in range(m):\n",
    "                    costs = np.empty(K, dtype=float)\n",
    "                    thetas_j = np.empty(K, dtype=float)\n",
    "                    for k in range(K):\n",
    "                        gk = gamma[k]\n",
    "                        tj = gk + _soft(b[j] - gk, lam_vec[j] / max(a[j], EPS))\n",
    "                        thetas_j[k] = tj\n",
    "                        costs[k] = 0.5 * a[j] * (tj - b[j])**2 + lam_vec[j] * abs(tj - gk)\n",
    "                    kbest = int(np.argmin(costs))\n",
    "                    labels_new[j] = kbest\n",
    "                    theta_new[j]  = thetas_j[kbest]\n",
    "\n",
    "                labels_new = _ensure_nonempty(labels_new, K, rng)\n",
    "                obj = _armul_obj(theta_new, gamma, labels_new, a, b, lam_vec)\n",
    "\n",
    "                if (obj_prev - obj) <= tol:\n",
    "                    theta, labels, obj_prev = theta_new, labels_new, obj\n",
    "                    break\n",
    "                theta, labels, obj_prev = theta_new, labels_new, obj\n",
    "\n",
    "            if obj_prev < best[\"obj\"]:\n",
    "                best = dict(theta=theta.copy(), labels=labels.copy(),\n",
    "                            gamma=gamma.copy(), obj=float(obj_prev), K=K)\n",
    "\n",
    "    return best[\"theta\"], best[\"labels\"], best[\"obj\"], best[\"K\"], best.get(\"gamma\", None)\n",
    "\n",
    "# ---- Per-task λ_j = C_LAM * sqrt(1 / n_j) for ATE ----\n",
    "def armul_lambda_per_task_TE(splits, c=C_LAM, use_target=True):\n",
    "    \"\"\"\n",
    "    λ_j = c * sqrt(1 / n_j), with n_j from |T_j| by default.\n",
    "    \"\"\"\n",
    "    n_list = []\n",
    "    for (A_idx, T_idx) in splits:\n",
    "        n_j = len(T_idx) if use_target else (len(A_idx) + len(T_idx))\n",
    "        n_list.append(max(1, int(n_j)))\n",
    "    n_arr = np.asarray(n_list, dtype=float)\n",
    "    lam_vec = c * np.sqrt(1.0 / n_arr)\n",
    "    return lam_vec.astype(float), {\"n_list\": n_arr.copy()}\n",
    "\n",
    "# ---- Clusterwise SEs ----\n",
    "def cluster_sandwich_from_var(theta_in, labels, var_b_list):\n",
    "    \"\"\"\n",
    "    Clusterwise SEs\n",
    "    \"\"\"\n",
    "    L = int(labels.max() + 1)\n",
    "    se_out = np.zeros_like(theta_in, dtype=float)\n",
    "    var_b_arr = np.asarray(var_b_list, dtype=float)\n",
    "\n",
    "    for g in range(L):\n",
    "        idx_tasks = np.where(labels == g)[0]\n",
    "        if len(idx_tasks) == 0:\n",
    "            continue\n",
    "        var_g = float(np.mean(var_b_arr[idx_tasks]))\n",
    "        se_g = float(np.sqrt(max(var_g, 1e-12)))\n",
    "        se_out[idx_tasks] = se_g\n",
    "\n",
    "    return theta_in.copy(), se_out, labels\n",
    "\n",
    "# =====================\n",
    "# One-round core \n",
    "# =====================\n",
    "\n",
    "def one_round_core(delta: float, seed: int):\n",
    "    \"\"\"\n",
    "    Do the ATE DGP, orthogonal stats via AIPW, individual, adaptive fused, and ARMUL.\n",
    "    \"\"\"\n",
    "    # --- DGP ---\n",
    "    tasks, theta_true, cluster_ids = make_te_tasks(M_TASKS, K_CLUST, delta, seed)\n",
    "\n",
    "    splits = []\n",
    "    for j, t in enumerate(tasks):\n",
    "        n = len(t[\"Y\"])\n",
    "        A_idx, T_idx = split_two_parts(n, A_FRAC, seed=SEED_BASE + 11111 + j)\n",
    "        splits.append((A_idx, T_idx))\n",
    "\n",
    "    a_T, b_T, var_b_T = [], [], []\n",
    "    for j, t in enumerate(tasks):\n",
    "        X, A, Y = t[\"X\"], t[\"A\"], t[\"Y\"]\n",
    "        A_idx, T_idx = splits[j]\n",
    "\n",
    "        # AIPW pseudo-outcomes on T\n",
    "        phi_T = aipw_pseudo_trainA_predict_subset_TE(\n",
    "            X, A, Y, A_idx, T_idx, seed=LGBM_SEED_OFFSET + j\n",
    "        )\n",
    "        a_j, b_j, var_b_j = sstats_and_if_se_from_aipw(phi_T)\n",
    "        a_T.append(a_j)\n",
    "        b_T.append(b_j)\n",
    "        var_b_T.append(var_b_j)\n",
    "\n",
    "    a_T = np.array(a_T, dtype=float)\n",
    "    b_T = np.array(b_T, dtype=float)\n",
    "    var_b_T = np.array(var_b_T, dtype=float)\n",
    "    se_T_indiv = np.sqrt(np.maximum(EPS, var_b_T.copy()))\n",
    "\n",
    "    theta_true = np.array(theta_true, dtype=float)\n",
    "    cluster_ids = np.array(cluster_ids, dtype=int)\n",
    "\n",
    "    n_T = np.array([len(T_idx) for (A_idx, T_idx) in splits], dtype=float)\n",
    "\n",
    "    K_true = int(cluster_ids.max() + 1)\n",
    "    N_k_true = np.zeros(K_true, dtype=float)\n",
    "    for k in range(K_true):\n",
    "        idx_k = (cluster_ids == k)\n",
    "        N_k_true[k] = n_T[idx_k].sum()\n",
    "\n",
    "    cluster_weights = N_k_true[cluster_ids]  # shape (M_TASKS,)\n",
    "\n",
    "    # ============================\n",
    "    # (i) Individual estimator\n",
    "    # ============================\n",
    "    theta_ind = b_T.copy()\n",
    "    cl_ind = clusters_from_theta(theta_ind)\n",
    "    ari_ind = float(adjusted_rand_score(cluster_ids, cl_ind))\n",
    "\n",
    "    err_ind = theta_ind - theta_true\n",
    "    rmse_ind = float(np.sqrt(np.mean(err_ind ** 2)))\n",
    "\n",
    "    cw_rmse_ind = float(np.sqrt(np.sum(cluster_weights * err_ind ** 2)))\n",
    "    diff_cw_ind = np.sqrt(cluster_weights) * err_ind\n",
    "\n",
    "    zI = err_ind / se_T_indiv\n",
    "\n",
    "    # ============================\n",
    "    # (ii) Adaptive fused\n",
    "    # ============================\n",
    "    L_ad_T = build_lambda_adaptive(\n",
    "        theta_ind,\n",
    "        c_w=C_W,\n",
    "        gamma=GAMMA_DEFAULT,\n",
    "        tau=ADAPT_TAU_DEFAULT,\n",
    "        eps=ADAPT_EPS\n",
    "    )\n",
    "    theta_ad_pen = admm_graph_fused_lasso(a_T, b_T, L_ad_T)\n",
    "    cl_ad = clusters_from_theta(theta_ad_pen)\n",
    "    theta_ad, se_ad, cl_ad = cluster_sandwich_from_var(theta_ad_pen, cl_ad, var_b_T)\n",
    "    ari_ad = float(adjusted_rand_score(cluster_ids, cl_ad))\n",
    "\n",
    "    err_ad = theta_ad - theta_true\n",
    "    rmse_ad = float(np.sqrt(np.mean(err_ad ** 2)))\n",
    "\n",
    "    cw_rmse_ad = float(np.sqrt(np.sum(cluster_weights * err_ad ** 2)))\n",
    "    diff_cw_ad = np.sqrt(cluster_weights) * err_ad\n",
    "\n",
    "    zA_ad = err_ad / se_ad\n",
    "\n",
    "    # ============================\n",
    "    # (iii) ARMUL with per-task λ_j\n",
    "    # ============================\n",
    "    lam_vec, lam_diag = armul_lambda_per_task_TE(splits, c=C_LAM, use_target=True)\n",
    "\n",
    "    def fit_armul_with_K(K_fixed):\n",
    "        K_fixed = int(np.clip(K_fixed, 1, M_TASKS))\n",
    "        theta_k, labels_k, obj_k, K_sel_k, _ = armul_clustered_fit(\n",
    "            b=b_T, a=a_T, lam_vec=lam_vec,\n",
    "            K_candidates=[K_fixed],\n",
    "            init=\"kmeans\",\n",
    "            n_init=ARMUL_N_INIT,\n",
    "            n_iter=ARMUL_N_ITER,\n",
    "            tol=ARMUL_TOL,\n",
    "            seed=ARMUL_SEED,\n",
    "            kmeans_n_init=10,\n",
    "        )\n",
    "        theta_k, se_k, labels_k = cluster_sandwich_from_var(theta_k, labels_k, var_b_T)\n",
    "        ari_k = float(adjusted_rand_score(cluster_ids, labels_k))\n",
    "\n",
    "        err_k = theta_k - theta_true\n",
    "        rmse_k = float(np.sqrt(np.mean(err_k ** 2)))\n",
    "        cw_rmse_k = float(np.sqrt(np.sum(cluster_weights * err_k ** 2)))\n",
    "        diff_cw_k = np.sqrt(cluster_weights) * err_k\n",
    "\n",
    "        return dict(\n",
    "            theta=theta_k, se=se_k, labels=labels_k,\n",
    "            ari=ari_k, rmse=rmse_k, cw_rmse=cw_rmse_k,\n",
    "            diff_cw=diff_cw_k, K=K_sel_k, obj=obj_k\n",
    "        )\n",
    "\n",
    "    armul_km1 = fit_armul_with_K(K_CLUST - 1)\n",
    "    armul_k   = fit_armul_with_K(K_CLUST)       # oracle\n",
    "    armul_kp1 = fit_armul_with_K(K_CLUST + 1)\n",
    "\n",
    "    # unpack some for convenience\n",
    "    rmse_armul_km1 = armul_km1[\"rmse\"]\n",
    "    rmse_armul_k   = armul_k[\"rmse\"]\n",
    "    rmse_armul_kp1 = armul_kp1[\"rmse\"]\n",
    "\n",
    "    cw_rmse_armul_km1 = armul_km1[\"cw_rmse\"]\n",
    "    cw_rmse_armul_k   = armul_k[\"cw_rmse\"]\n",
    "    cw_rmse_armul_kp1 = armul_kp1[\"cw_rmse\"]\n",
    "\n",
    "    diff_cw_armul_km1 = armul_km1[\"diff_cw\"]\n",
    "    diff_cw_armul_k   = armul_k[\"diff_cw\"]\n",
    "    diff_cw_armul_kp1 = armul_kp1[\"diff_cw\"]\n",
    "\n",
    "    return {\n",
    "        \"theta_true\": theta_true,\n",
    "        \"cluster_ids_true\": cluster_ids,\n",
    "        \"a_T\": a_T,\n",
    "        \"b_T\": b_T,\n",
    "        \"var_b_T\": var_b_T,\n",
    "        \"se_T_indiv\": se_T_indiv,\n",
    "\n",
    "        \"cluster_n_T\": n_T,\n",
    "        \"cluster_N_k_true\": N_k_true,\n",
    "        \"cluster_weights\": cluster_weights,\n",
    "\n",
    "        \"theta_ind\": theta_ind,\n",
    "        \"ARI_ind\": ari_ind,\n",
    "        \"rmse_ind\": rmse_ind,\n",
    "        \"zI\": zI,\n",
    "\n",
    "        \"theta_ad\": theta_ad,\n",
    "        \"se_ad\": se_ad,\n",
    "        \"ARI_ad\": ari_ad,\n",
    "        \"rmse_ad\": rmse_ad,\n",
    "        \"zA_ad\": zA_ad,\n",
    "\n",
    "        \"theta_armul_km1\": armul_km1[\"theta\"],\n",
    "        \"se_armul_km1\": armul_km1[\"se\"],\n",
    "        \"ARI_armul_km1\": armul_km1[\"ari\"],\n",
    "        \"rmse_armul_km1\": rmse_armul_km1,\n",
    "        \"armul_K_sel_km1\": armul_km1[\"K\"],\n",
    "        \"armul_obj_km1\": armul_km1[\"obj\"],\n",
    "\n",
    "        \"theta_armul_k\": armul_k[\"theta\"],\n",
    "        \"se_armul_k\": armul_k[\"se\"],\n",
    "        \"ARI_armul_k\": armul_k[\"ari\"],\n",
    "        \"rmse_armul_k\": rmse_armul_k,\n",
    "        \"armul_K_sel_k\": armul_k[\"K\"],\n",
    "        \"armul_obj_k\": armul_k[\"obj\"],\n",
    "\n",
    "        \"theta_armul_kp1\": armul_kp1[\"theta\"],\n",
    "        \"se_armul_kp1\": armul_kp1[\"se\"],\n",
    "        \"ARI_armul_kp1\": armul_kp1[\"ari\"],\n",
    "        \"rmse_armul_kp1\": rmse_armul_kp1,\n",
    "        \"armul_K_sel_kp1\": armul_kp1[\"K\"],\n",
    "        \"armul_obj_kp1\": armul_kp1[\"obj\"],\n",
    "\n",
    "        \"armul_lambda_vec\": lam_vec,\n",
    "        \"armul_lambda_diag\": lam_diag,\n",
    "\n",
    "        \"cw_rmse_ind\": cw_rmse_ind,\n",
    "        \"cw_rmse_ad\": cw_rmse_ad,\n",
    "        \"cw_rmse_armul_km1\": cw_rmse_armul_km1,\n",
    "        \"cw_rmse_armul_k\": cw_rmse_armul_k,\n",
    "        \"cw_rmse_armul_kp1\": cw_rmse_armul_kp1,\n",
    "\n",
    "        \"diff_cw_ind\": diff_cw_ind,\n",
    "        \"diff_cw_ad\": diff_cw_ad,\n",
    "        \"diff_cw_armul_km1\": diff_cw_armul_km1,\n",
    "        \"diff_cw_armul_k\": diff_cw_armul_k,\n",
    "        \"diff_cw_armul_kp1\": diff_cw_armul_kp1,\n",
    "    }\n",
    "\n",
    "\n",
    "\n",
    "# =====================\n",
    "# Run ROUNDS\n",
    "# =====================\n",
    "\n",
    "def run_for_delta_fixed_grid(rounds: int, delta: float, start_seed: int, fixed_lambda_list):\n",
    "    \"\"\"\n",
    "    For a given delta, run ROUNDS times.\n",
    "    \"\"\"\n",
    "    def init_agg():\n",
    "        return {\n",
    "            \"TH_true\": [],\n",
    "            \"TH_ind\": [],\n",
    "            \"TH_fx\": [],\n",
    "            \"TH_ad\": [],\n",
    "            \"TH_arm_km1\": [],\n",
    "            \"TH_arm_k\": [],\n",
    "            \"TH_arm_kp1\": [],\n",
    "            \"zI\": [],\n",
    "            \"zA_fx\": [],\n",
    "            \"zA_ad\": [],\n",
    "            \"ARI_ind\": [],\n",
    "            \"ARI_fx\": [],\n",
    "            \"ARI_ad\": [],\n",
    "            \"ARI_arm_km1\": [],\n",
    "            \"ARI_arm_k\": [],\n",
    "            \"ARI_arm_kp1\": [],\n",
    "            \"rmse_ind\": [],\n",
    "            \"rmse_fx\": [],\n",
    "            \"rmse_ad\": [],\n",
    "            \"rmse_arm_km1\": [],\n",
    "            \"rmse_arm_k\": [],\n",
    "            \"rmse_arm_kp1\": [],\n",
    "            \"armul_K_sel_km1\": [],\n",
    "            \"armul_K_sel_k\": [],\n",
    "            \"armul_K_sel_kp1\": [],\n",
    "            \"armul_lambda\": [],\n",
    "\n",
    "            \"cw_rmse_ind\": [],\n",
    "            \"cw_rmse_fx\": [],\n",
    "            \"cw_rmse_ad\": [],\n",
    "            \"cw_rmse_arm_km1\": [],\n",
    "            \"cw_rmse_arm_k\": [],\n",
    "            \"cw_rmse_arm_kp1\": [],\n",
    "\n",
    "            \"diff_cw_ind\": [],\n",
    "            \"diff_cw_fx\": [],\n",
    "            \"diff_cw_ad\": [],\n",
    "            \"diff_cw_arm_km1\": [],\n",
    "            \"diff_cw_arm_k\": [],\n",
    "            \"diff_cw_arm_kp1\": [],\n",
    "        }\n",
    "\n",
    "    aggs = {lam: init_agg() for lam in fixed_lambda_list}\n",
    "\n",
    "    for r in tqdm(range(rounds), desc=f\"delta={delta}\"):\n",
    "        core = one_round_core(delta=delta, seed=start_seed + r)\n",
    "\n",
    "        a_T = core[\"a_T\"]\n",
    "        b_T = core[\"b_T\"]\n",
    "        var_b_T = core[\"var_b_T\"]\n",
    "        theta_true = core[\"theta_true\"]\n",
    "        cluster_ids = core[\"cluster_ids_true\"]\n",
    "        cluster_weights = core[\"cluster_weights\"]\n",
    "\n",
    "        for fixed_lambda in fixed_lambda_list:\n",
    "            agg = aggs[fixed_lambda]\n",
    "\n",
    "            L_fx_T = build_lambda_fixed(M_TASKS, fixed_lambda)\n",
    "            theta_fx_pen = admm_graph_fused_lasso(a_T, b_T, L_fx_T)\n",
    "            cl_fx = clusters_from_theta(theta_fx_pen)\n",
    "            theta_fx, se_fx, cl_fx = cluster_sandwich_from_var(theta_fx_pen, cl_fx, var_b_T)\n",
    "            ari_fx = float(adjusted_rand_score(cluster_ids, cl_fx))\n",
    "            err_fx = theta_fx - theta_true\n",
    "            rmse_fx = float(np.sqrt(np.mean(err_fx ** 2)))\n",
    "            zA_fx = err_fx / se_fx\n",
    "\n",
    "            cw_rmse_fx = float(np.sqrt(np.sum(cluster_weights * err_fx ** 2)))\n",
    "            diff_cw_fx = np.sqrt(cluster_weights) * err_fx\n",
    "\n",
    "            agg[\"TH_true\"].append(theta_true)\n",
    "            agg[\"TH_ind\"].append(core[\"theta_ind\"])\n",
    "            agg[\"TH_fx\"].append(theta_fx)\n",
    "            agg[\"TH_ad\"].append(core[\"theta_ad\"])\n",
    "            agg[\"TH_arm_km1\"].append(core[\"theta_armul_km1\"])\n",
    "            agg[\"TH_arm_k\"].append(core[\"theta_armul_k\"])\n",
    "            agg[\"TH_arm_kp1\"].append(core[\"theta_armul_kp1\"])\n",
    "\n",
    "            agg[\"zI\"].append(core[\"zI\"])\n",
    "            agg[\"zA_fx\"].append(zA_fx)\n",
    "            agg[\"zA_ad\"].append(core[\"zA_ad\"])\n",
    "\n",
    "            agg[\"ARI_ind\"].append(core[\"ARI_ind\"])\n",
    "            agg[\"ARI_fx\"].append(ari_fx)\n",
    "            agg[\"ARI_ad\"].append(core[\"ARI_ad\"])\n",
    "            agg[\"ARI_arm_km1\"].append(core[\"ARI_armul_km1\"])\n",
    "            agg[\"ARI_arm_k\"].append(core[\"ARI_armul_k\"])\n",
    "            agg[\"ARI_arm_kp1\"].append(core[\"ARI_armul_kp1\"])\n",
    "\n",
    "            agg[\"rmse_ind\"].append(core[\"rmse_ind\"])\n",
    "            agg[\"rmse_fx\"].append(rmse_fx)\n",
    "            agg[\"rmse_ad\"].append(core[\"rmse_ad\"])\n",
    "            agg[\"rmse_arm_km1\"].append(core[\"rmse_armul_km1\"])\n",
    "            agg[\"rmse_arm_k\"].append(core[\"rmse_armul_k\"])\n",
    "            agg[\"rmse_arm_kp1\"].append(core[\"rmse_armul_kp1\"])\n",
    "\n",
    "            agg[\"armul_K_sel_km1\"].append(core[\"armul_K_sel_km1\"])\n",
    "            agg[\"armul_K_sel_k\"].append(core[\"armul_K_sel_k\"])\n",
    "            agg[\"armul_K_sel_kp1\"].append(core[\"armul_K_sel_kp1\"])\n",
    "            agg[\"armul_lambda\"].append(core[\"armul_lambda_vec\"])\n",
    "\n",
    "            agg[\"cw_rmse_ind\"].append(core[\"cw_rmse_ind\"])\n",
    "            agg[\"cw_rmse_fx\"].append(cw_rmse_fx)\n",
    "            agg[\"cw_rmse_ad\"].append(core[\"cw_rmse_ad\"])\n",
    "            agg[\"cw_rmse_arm_km1\"].append(core[\"cw_rmse_armul_km1\"])\n",
    "            agg[\"cw_rmse_arm_k\"].append(core[\"cw_rmse_armul_k\"])\n",
    "            agg[\"cw_rmse_arm_kp1\"].append(core[\"cw_rmse_armul_kp1\"])\n",
    "\n",
    "            agg[\"diff_cw_ind\"].append(core[\"diff_cw_ind\"])\n",
    "            agg[\"diff_cw_fx\"].append(diff_cw_fx)\n",
    "            agg[\"diff_cw_ad\"].append(core[\"diff_cw_ad\"])\n",
    "            agg[\"diff_cw_arm_km1\"].append(core[\"diff_cw_armul_km1\"])\n",
    "            agg[\"diff_cw_arm_k\"].append(core[\"diff_cw_armul_k\"])\n",
    "            agg[\"diff_cw_arm_kp1\"].append(core[\"diff_cw_armul_kp1\"])\n",
    "\n",
    "    res_by_lam = {}\n",
    "    for fixed_lambda, agg in aggs.items():\n",
    "        TH_true_all      = np.concatenate(agg[\"TH_true\"])\n",
    "        TH_ind_all       = np.concatenate(agg[\"TH_ind\"])\n",
    "        TH_fx_all        = np.concatenate(agg[\"TH_fx\"])\n",
    "        TH_ad_all        = np.concatenate(agg[\"TH_ad\"])\n",
    "        TH_arm_km1_all   = np.concatenate(agg[\"TH_arm_km1\"])\n",
    "        TH_arm_k_all     = np.concatenate(agg[\"TH_arm_k\"])\n",
    "        TH_arm_kp1_all   = np.concatenate(agg[\"TH_arm_kp1\"])\n",
    "\n",
    "        zI_all    = np.concatenate(agg[\"zI\"])\n",
    "        zA_fx_all = np.concatenate(agg[\"zA_fx\"])\n",
    "        zA_ad_all = np.concatenate(agg[\"zA_ad\"])\n",
    "\n",
    "        ARI_ind_arr      = np.array(agg[\"ARI_ind\"])\n",
    "        ARI_fx_arr       = np.array(agg[\"ARI_fx\"])\n",
    "        ARI_ad_arr       = np.array(agg[\"ARI_ad\"])\n",
    "        ARI_arm_km1_arr  = np.array(agg[\"ARI_arm_km1\"])\n",
    "        ARI_arm_k_arr    = np.array(agg[\"ARI_arm_k\"])\n",
    "        ARI_arm_kp1_arr  = np.array(agg[\"ARI_arm_kp1\"])\n",
    "\n",
    "        rmse_ind_arr     = np.array(agg[\"rmse_ind\"])\n",
    "        rmse_fx_arr      = np.array(agg[\"rmse_fx\"])\n",
    "        rmse_ad_arr      = np.array(agg[\"rmse_ad\"])\n",
    "        rmse_arm_km1_arr = np.array(agg[\"rmse_arm_km1\"])\n",
    "        rmse_arm_k_arr   = np.array(agg[\"rmse_arm_k\"])\n",
    "        rmse_arm_kp1_arr = np.array(agg[\"rmse_arm_kp1\"])\n",
    "\n",
    "        armul_K_sel_km1_all = np.array(agg[\"armul_K_sel_km1\"])\n",
    "        armul_K_sel_k_all   = np.array(agg[\"armul_K_sel_k\"])\n",
    "        armul_K_sel_kp1_all = np.array(agg[\"armul_K_sel_kp1\"])\n",
    "        armul_lambda_all    = np.array(agg[\"armul_lambda\"])\n",
    "\n",
    "        cw_rmse_ind_arr     = np.array(agg[\"cw_rmse_ind\"])\n",
    "        cw_rmse_fx_arr      = np.array(agg[\"cw_rmse_fx\"])\n",
    "        cw_rmse_ad_arr      = np.array(agg[\"cw_rmse_ad\"])\n",
    "        cw_rmse_arm_km1_arr = np.array(agg[\"cw_rmse_arm_km1\"])\n",
    "        cw_rmse_arm_k_arr   = np.array(agg[\"cw_rmse_arm_k\"])\n",
    "        cw_rmse_arm_kp1_arr = np.array(agg[\"cw_rmse_arm_kp1\"])\n",
    "\n",
    "        diff_cw_ind_all     = np.concatenate(agg[\"diff_cw_ind\"])\n",
    "        diff_cw_fx_all      = np.concatenate(agg[\"diff_cw_fx\"])\n",
    "        diff_cw_ad_all      = np.concatenate(agg[\"diff_cw_ad\"])\n",
    "        diff_cw_arm_km1_all = np.concatenate(agg[\"diff_cw_arm_km1\"])\n",
    "        diff_cw_arm_k_all   = np.concatenate(agg[\"diff_cw_arm_k\"])\n",
    "        diff_cw_arm_kp1_all = np.concatenate(agg[\"diff_cw_arm_kp1\"])\n",
    "\n",
    "        res = {\n",
    "            # estimators\n",
    "            \"theta_true_all\": TH_true_all,\n",
    "            \"theta_ind_all\": TH_ind_all,\n",
    "            \"theta_fx_all\": TH_fx_all,\n",
    "            \"theta_ad_all\": TH_ad_all,\n",
    "            \"theta_armul_km1_all\": TH_arm_km1_all,\n",
    "            \"theta_armul_k_all\": TH_arm_k_all,\n",
    "            \"theta_armul_kp1_all\": TH_arm_kp1_all,\n",
    "\n",
    "            \"zI\": zI_all,\n",
    "            \"zA_fx\": zA_fx_all,\n",
    "            \"zA_ad\": zA_ad_all,\n",
    "\n",
    "            \"ARI_ind\": ARI_ind_arr,\n",
    "            \"ARI_fx\": ARI_fx_arr,\n",
    "            \"ARI_ad\": ARI_ad_arr,\n",
    "            \"ARI_armul_km1\": ARI_arm_km1_arr,\n",
    "            \"ARI_armul_k\":   ARI_arm_k_arr,\n",
    "            \"ARI_armul_kp1\": ARI_arm_kp1_arr,\n",
    "\n",
    "            \"rmse_ind_all\":   rmse_ind_arr,\n",
    "            \"rmse_fx_all\":    rmse_fx_arr,\n",
    "            \"rmse_ad_all\":    rmse_ad_arr,\n",
    "            \"rmse_armul_km1_all\": rmse_arm_km1_arr,\n",
    "            \"rmse_armul_k_all\":   rmse_arm_k_arr,\n",
    "            \"rmse_armul_kp1_all\": rmse_arm_kp1_arr,\n",
    "\n",
    "            \"mean_rmse_ind\": float(np.mean(rmse_ind_arr)),\n",
    "            \"mean_rmse_fx\":  float(np.mean(rmse_fx_arr)),\n",
    "            \"mean_rmse_ad\":  float(np.mean(rmse_ad_arr)),\n",
    "            \"mean_rmse_armul_km1\": float(np.mean(rmse_arm_km1_arr)),\n",
    "            \"mean_rmse_armul_k\":   float(np.mean(rmse_arm_k_arr)),\n",
    "            \"mean_rmse_armul_kp1\": float(np.mean(rmse_arm_kp1_arr)),\n",
    "            \"mean_ari_ind\": float(np.mean(ARI_ind_arr)),\n",
    "            \"mean_ari_fx\":  float(np.mean(ARI_fx_arr)),\n",
    "            \"mean_ari_ad\":  float(np.mean(ARI_ad_arr)),\n",
    "            \"mean_ari_armul_km1\": float(np.mean(ARI_arm_km1_arr)),\n",
    "            \"mean_ari_armul_k\":   float(np.mean(ARI_arm_k_arr)),\n",
    "            \"mean_ari_armul_kp1\": float(np.mean(ARI_arm_kp1_arr)),\n",
    "\n",
    "            \"armul_K_sel_km1_all\": armul_K_sel_km1_all,\n",
    "            \"armul_K_sel_k_all\":   armul_K_sel_k_all,\n",
    "            \"armul_K_sel_kp1_all\": armul_K_sel_kp1_all,\n",
    "            \"armul_lambda_all\": armul_lambda_all,\n",
    "\n",
    "            \"cw_rmse_ind_all\":     cw_rmse_ind_arr,\n",
    "            \"cw_rmse_fx_all\":      cw_rmse_fx_arr,\n",
    "            \"cw_rmse_ad_all\":      cw_rmse_ad_arr,\n",
    "            \"cw_rmse_armul_km1_all\": cw_rmse_arm_km1_arr,\n",
    "            \"cw_rmse_armul_k_all\":   cw_rmse_arm_k_arr,\n",
    "            \"cw_rmse_armul_kp1_all\": cw_rmse_arm_kp1_arr,\n",
    "\n",
    "            \"diff_cw_ind_all\":     diff_cw_ind_all,\n",
    "            \"diff_cw_fx_all\":      diff_cw_fx_all,\n",
    "            \"diff_cw_ad_all\":      diff_cw_ad_all,\n",
    "            \"diff_cw_armul_km1_all\": diff_cw_arm_km1_all,\n",
    "            \"diff_cw_armul_k_all\":   diff_cw_arm_k_all,\n",
    "            \"diff_cw_armul_kp1_all\": diff_cw_arm_kp1_all,\n",
    "\n",
    "            \"mean_cw_rmse_ind\": float(np.mean(cw_rmse_ind_arr)),\n",
    "            \"mean_cw_rmse_fx\":  float(np.mean(cw_rmse_fx_arr)),\n",
    "            \"mean_cw_rmse_ad\":  float(np.mean(cw_rmse_ad_arr)),\n",
    "            \"mean_cw_rmse_armul_km1\": float(np.mean(cw_rmse_arm_km1_arr)),\n",
    "            \"mean_cw_rmse_armul_k\":   float(np.mean(cw_rmse_arm_k_arr)),\n",
    "            \"mean_cw_rmse_armul_kp1\": float(np.mean(cw_rmse_arm_kp1_arr)),\n",
    "        }\n",
    "\n",
    "        res_by_lam[fixed_lambda] = res\n",
    "\n",
    "    return res_by_lam\n",
    "\n",
    "\n",
    "\n",
    "# =====================\n",
    "# Saving helpers\n",
    "# =====================\n",
    "\n",
    "def save_delta_results(outdir, delta, res_by_lam):\n",
    "    \"\"\"\n",
    "    Write every result for one delta into a single ``.npz`` archive.\n",
    "\n",
    "    Entries listed as shared are read from the result stored under the\n",
    "    smallest fixed lambda and saved under their own names. Fixed-fusion\n",
    "    entries are saved once per lambda under the prefix ``lam<k>_``, where\n",
    "    ``k`` is the position of that lambda in the ascending grid that is stored\n",
    "    as ``fixed_lambda_values``.\n",
    "\n",
    "    outdir     : directory receiving ``teps_delta<delta>.npz`` (delta is\n",
    "                 formatted with two decimals)\n",
    "    delta      : value used to build the file name\n",
    "    res_by_lam : mapping from fixed lambda to its result dict\n",
    "    \"\"\"\n",
    "    lam_grid = np.array(sorted(res_by_lam.keys()), dtype=float)\n",
    "    archive_path = os.path.join(outdir, f\"teps_delta{delta:.2f}.npz\")\n",
    "\n",
    "    # Copied once, from the smallest-lambda result.\n",
    "    shared_arrays = (\n",
    "        \"theta_true_all\", \"theta_ind_all\", \"theta_ad_all\",\n",
    "        \"theta_armul_km1_all\", \"theta_armul_k_all\", \"theta_armul_kp1_all\",\n",
    "        \"zI\", \"zA_ad\",\n",
    "        \"ARI_ind\", \"ARI_ad\",\n",
    "        \"ARI_armul_km1\", \"ARI_armul_k\", \"ARI_armul_kp1\",\n",
    "        \"rmse_ind_all\", \"rmse_ad_all\",\n",
    "        \"rmse_armul_km1_all\", \"rmse_armul_k_all\", \"rmse_armul_kp1_all\",\n",
    "        \"armul_K_sel_km1_all\", \"armul_K_sel_k_all\", \"armul_K_sel_kp1_all\",\n",
    "        \"armul_lambda_all\",\n",
    "        \"cw_rmse_ind_all\", \"cw_rmse_ad_all\",\n",
    "        \"cw_rmse_armul_km1_all\", \"cw_rmse_armul_k_all\", \"cw_rmse_armul_kp1_all\",\n",
    "        \"diff_cw_ind_all\", \"diff_cw_ad_all\",\n",
    "        \"diff_cw_armul_km1_all\", \"diff_cw_armul_k_all\", \"diff_cw_armul_kp1_all\",\n",
    "    )\n",
    "    shared_scalars = (\n",
    "        \"mean_rmse_ind\", \"mean_rmse_ad\",\n",
    "        \"mean_rmse_armul_km1\", \"mean_rmse_armul_k\", \"mean_rmse_armul_kp1\",\n",
    "        \"mean_ari_ind\", \"mean_ari_ad\",\n",
    "        \"mean_ari_armul_km1\", \"mean_ari_armul_k\", \"mean_ari_armul_kp1\",\n",
    "        \"mean_cw_rmse_ind\", \"mean_cw_rmse_ad\",\n",
    "        \"mean_cw_rmse_armul_km1\", \"mean_cw_rmse_armul_k\", \"mean_cw_rmse_armul_kp1\",\n",
    "    )\n",
    "    # Stored separately for every fixed lambda.\n",
    "    fx_arrays = (\n",
    "        \"theta_fx_all\", \"zA_fx\", \"ARI_fx\",\n",
    "        \"rmse_fx_all\", \"cw_rmse_fx_all\", \"diff_cw_fx_all\",\n",
    "    )\n",
    "    fx_scalars = (\"mean_rmse_fx\", \"mean_ari_fx\", \"mean_cw_rmse_fx\")\n",
    "\n",
    "    reference = res_by_lam[lam_grid[0]]\n",
    "\n",
    "    # Insertion order fixes the order of the entries inside the archive.\n",
    "    payload = {\"fixed_lambda_values\": lam_grid}\n",
    "    payload.update((name, reference[name]) for name in shared_arrays)\n",
    "    payload.update(\n",
    "        (name, np.array(reference[name], dtype=float)) for name in shared_scalars\n",
    "    )\n",
    "\n",
    "    for idx, lam_value in enumerate(lam_grid):\n",
    "        lam_res = res_by_lam[lam_value]\n",
    "        tag = f\"lam{idx}_\"\n",
    "        payload.update((tag + name, lam_res[name]) for name in fx_arrays)\n",
    "        payload.update(\n",
    "            (tag + name, np.array(lam_res[name], dtype=float)) for name in fx_scalars\n",
    "        )\n",
    "\n",
    "    np.savez(archive_path, **payload)\n",
    "\n",
    "\n",
    "\n",
    "def append_summary_csv(outdir, delta, res_by_lam):\n",
    "    \"\"\"\n",
    "    One CSV file, with one row per (delta, lambda).\n",
    "\n",
    "    Rows are appended to ``<outdir>/summary.csv``. The header is written when\n",
    "    the file does not exist yet or is still empty. The file is opened as UTF-8\n",
    "    explicitly because the ``notes`` column contains the non-ASCII character\n",
    "    \"λ\", which a platform default encoding such as cp1252 cannot represent.\n",
    "\n",
    "    outdir     : directory holding ``summary.csv``\n",
    "    delta      : value written to the ``delta`` column of every appended row\n",
    "    res_by_lam : mapping from fixed lambda to its result dict; the\n",
    "                 ``mean_rmse_*`` and ``mean_ari_*`` entries are read\n",
    "\n",
    "    The ``rounds`` and ``m_tasks`` columns are filled from the module-level\n",
    "    constants ROUNDS and M_TASKS.\n",
    "    \"\"\"\n",
    "    import csv\n",
    "\n",
    "    csv_path = os.path.join(outdir, \"summary.csv\")\n",
    "    # A zero-byte file is treated like a missing one so it still gets a header.\n",
    "    write_header = (not os.path.exists(csv_path)) or os.path.getsize(csv_path) == 0\n",
    "\n",
    "    # (CSV column, key in the per-lambda result dict), in column order.\n",
    "    metric_columns = [\n",
    "        (\"mean_rmse_individual\", \"mean_rmse_ind\"),\n",
    "        (\"mean_rmse_fixed\",      \"mean_rmse_fx\"),\n",
    "        (\"mean_rmse_adaptive\",   \"mean_rmse_ad\"),\n",
    "        (\"mean_rmse_armul_km1\",  \"mean_rmse_armul_km1\"),\n",
    "        (\"mean_rmse_armul_k\",    \"mean_rmse_armul_k\"),\n",
    "        (\"mean_rmse_armul_kp1\",  \"mean_rmse_armul_kp1\"),\n",
    "        (\"mean_ari_individual\",  \"mean_ari_ind\"),\n",
    "        (\"mean_ari_fixed\",       \"mean_ari_fx\"),\n",
    "        (\"mean_ari_adaptive\",    \"mean_ari_ad\"),\n",
    "        (\"mean_ari_armul_km1\",   \"mean_ari_armul_km1\"),\n",
    "        (\"mean_ari_armul_k\",     \"mean_ari_armul_k\"),\n",
    "        (\"mean_ari_armul_kp1\",   \"mean_ari_armul_kp1\"),\n",
    "    ]\n",
    "    fieldnames = (\n",
    "        [\"model\", \"delta\", \"fixed_lambda\", \"rounds\", \"m_tasks\"]\n",
    "        + [column for column, _ in metric_columns]\n",
    "        + [\"notes\"]\n",
    "    )\n",
    "\n",
    "    with open(csv_path, \"a\", newline=\"\", encoding=\"utf-8\") as f:\n",
    "        w = csv.DictWriter(f, fieldnames=fieldnames)\n",
    "        if write_header:\n",
    "            w.writeheader()\n",
    "\n",
    "        for fixed_lambda, res in res_by_lam.items():\n",
    "            row = {\n",
    "                \"model\": \"ate_aipw_lgbm_split_noCV_pooledSE_plus_ARMUL\",\n",
    "                \"delta\": float(delta),\n",
    "                \"fixed_lambda\": float(fixed_lambda),\n",
    "                \"rounds\": ROUNDS,\n",
    "                \"m_tasks\": M_TASKS,\n",
    "            }\n",
    "            row.update((column, res[key]) for column, key in metric_columns)\n",
    "            row[\"notes\"] = (\n",
    "                \"ATE with AIPW (DR) score; lambda_j = C_LAM * sqrt(1/n_j) with n_j=|T_j|. \"\n",
    "                \"Clusterwise IF-SE from per-task AIPW var_b, no nuisance refit. \"\n",
    "                f\"Fixed fused λ = {fixed_lambda}.\"\n",
    "            )\n",
    "            w.writerow(row)\n",
    "\n",
    "# =====================\n",
    "# Main\n",
    "# =====================\n",
    "\n",
    "# Entry point: sweep every delta in DELTAS, print and save the results, then\n",
    "# write a JSON manifest describing the run configuration.\n",
    "if __name__ == \"__main__\":\n",
    "    # All outputs of this run live under results_teps_aipw/<timestamp>/.\n",
    "    # NOTE(review): ensure_dir is defined outside this section; its return value\n",
    "    # is used below as a directory path - confirm it returns the path it was given.\n",
    "    timestamp = datetime.now().strftime(\"%Y%m%d_%H%M%S\")\n",
    "    base_out = ensure_dir(os.path.join(\"results_teps_aipw\", timestamp))\n",
    "    model_out = ensure_dir(os.path.join(base_out, \"ate_aipw\"))\n",
    "\n",
    "    print(f\"Running ATE (AIPW) with IND / FIX / ADAPT and ARMUL(K-1,K,K+1). \"\n",
    "          f\"Rounds={ROUNDS}, m={M_TASKS}, true K={K_CLUST}\")\n",
    "    print(f\"Fixed λ grid (per delta): {FIXED_LAMBDA_LIST}\")\n",
    "\n",
    "    for delta in DELTAS:\n",
    "        print(f\"\\n=== delta = {delta} ===\")\n",
    "        # One simulation sweep per delta; the result is keyed by fixed lambda.\n",
    "        res_by_lam = run_for_delta_fixed_grid(\n",
    "            rounds=ROUNDS,\n",
    "            delta=delta,\n",
    "            start_seed=SEED_BASE,\n",
    "            fixed_lambda_list=FIXED_LAMBDA_LIST,\n",
    "        )\n",
    "\n",
    "        # The IND / ADAPT / ARMUL summaries are read from the smallest-lambda entry.\n",
    "        lam0 = sorted(res_by_lam.keys())[0]\n",
    "        shared = res_by_lam[lam0]\n",
    "        print(\"Shared metrics (IND / ADAPT / ARMUL):\")\n",
    "        print({\n",
    "            \"delta\": delta,\n",
    "            \"Mean_RMSE_Indiv\":        shared[\"mean_rmse_ind\"],\n",
    "            \"Mean_RMSE_Adapt\":        shared[\"mean_rmse_ad\"],\n",
    "            \"Mean_RMSE_ARMUL(K-1)\":   shared[\"mean_rmse_armul_km1\"],\n",
    "            \"Mean_RMSE_ARMUL(K)\":     shared[\"mean_rmse_armul_k\"],\n",
    "            \"Mean_RMSE_ARMUL(K+1)\":   shared[\"mean_rmse_armul_kp1\"],\n",
    "            \"Mean_ARI_Indiv\":         shared[\"mean_ari_ind\"],\n",
    "            \"Mean_ARI_Adapt\":         shared[\"mean_ari_ad\"],\n",
    "            \"Mean_ARI_ARMUL(K-1)\":    shared[\"mean_ari_armul_km1\"],\n",
    "            \"Mean_ARI_ARMUL(K)\":      shared[\"mean_ari_armul_k\"],\n",
    "            \"Mean_ARI_ARMUL(K+1)\":    shared[\"mean_ari_armul_kp1\"],\n",
    "        })\n",
    "\n",
    "        # Fixed-fusion (FIX) summaries, one printed dict per lambda in ascending order.\n",
    "        print(\"Fixed-lambda sweep (only FIX metrics):\")\n",
    "        for lam in sorted(res_by_lam.keys()):\n",
    "            res = res_by_lam[lam]\n",
    "            print({\n",
    "                \"delta\": delta,\n",
    "                \"fixed_lambda\": lam,\n",
    "                \"Mean_RMSE_Fixed\": res[\"mean_rmse_fx\"],\n",
    "                \"Mean_ARI_Fixed\":  res[\"mean_ari_fx\"],\n",
    "            })\n",
    "\n",
    "        # Arrays go to <model_out>/teps_delta*.npz; summary rows are appended to\n",
    "        # <base_out>/summary.csv.\n",
    "        save_delta_results(model_out, delta, res_by_lam)\n",
    "        append_summary_csv(base_out, delta, res_by_lam)\n",
    "\n",
    "    # Manifest: the hyperparameter constants used plus free-text descriptions of\n",
    "    # the setup.\n",
    "    # NOTE(review): the descriptive strings are hand-written, not derived from the\n",
    "    # code - keep them in sync when the pipeline changes.\n",
    "    manifest = {\n",
    "        \"model\": \"ATE_with_AIPW_plus_ARMUL_fast\",\n",
    "        \"deltas\": DELTAS,\n",
    "        \"rounds\": ROUNDS,\n",
    "        \"m_tasks\": M_TASKS,\n",
    "        \"timestamp\": timestamp,\n",
    "        \"true_K_for_eval\": K_CLUST,\n",
    "        \"centers\": \"theta_k = k*delta - (K+1)*delta/2\",\n",
    "        \"nuisance_target_split\": {\"A_frac\": A_FRAC, \"role_swap\": False},\n",
    "        \"orthogonal_score\": (\n",
    "            \"AIPW / DR EIF: phi = A/e (Y-m1) - (1-A)/(1-e)(Y-m0) + m1-m0; \"\n",
    "            \"per-task loss 0.5 * (theta_j - mean(phi))^2 with a_j=1.0.\"\n",
    "        ),\n",
    "        \"fusion_hyperparams\": {\n",
    "            \"adaptive\": {\n",
    "                \"gamma\": GAMMA_DEFAULT,\n",
    "                \"tau\": ADAPT_TAU_DEFAULT,\n",
    "                \"eps\": ADAPT_EPS,\n",
    "                \"c_w\": C_W\n",
    "            },\n",
    "            \"fixed_lambda_list\": FIXED_LAMBDA_LIST\n",
    "        },\n",
    "        \"armul\": {\n",
    "            \"C_LAM\": C_LAM,\n",
    "            \"variants\": [\"K-1\", \"K (oracle)\", \"K+1\"],\n",
    "            \"n_init\": ARMUL_N_INIT,\n",
    "            \"n_iter\": ARMUL_N_ITER,\n",
    "            \"tol\": ARMUL_TOL,\n",
    "            \"seed\": ARMUL_SEED,\n",
    "            \"lambda_rule\": \"λ_j = C_LAM * sqrt(1 / n_j), with n_j = |T_j| by default\"\n",
    "        },\n",
    "        \"nuisance\": {\n",
    "            \"m0(X), m1(X)\": \"LGBMRegressor (two separate models by A)\",\n",
    "            \"e(X)\": \"LGBMClassifier\",\n",
    "            \"shared_params\": {\n",
    "                \"n_estimators\": LGBM_ESTIMATORS,\n",
    "                \"learning_rate\": LGBM_LEARNING_RATE,\n",
    "                \"num_leaves\": LGBM_NUM_LEAVES,\n",
    "                \"min_child_samples\": LGBM_MIN_CHILD_SAMPLES,\n",
    "                \"subsample\": LGBM_SUBSAMPLE,\n",
    "                \"colsample_bytree\": LGBM_COLSAMPLE,\n",
    "                \"reg_alpha\": LGBM_REG_ALPHA,\n",
    "                \"reg_lambda\": LGBM_REG_LAMBDA,\n",
    "                \"n_jobs\": N_JOBS\n",
    "            }\n",
    "        },\n",
    "        \"inference\": (\n",
    "            \"Per-task IF var from AIPW phi; clusterwise SE = pooled var_b per cluster; \"\n",
    "            \"no nuisance refit in second stage.\"\n",
    "        )\n",
    "    }\n",
    "    # json.dump escapes non-ASCII characters by default, so the λ in lambda_rule\n",
    "    # is stored as an ASCII escape sequence.\n",
    "    with open(os.path.join(base_out, \"manifest.json\"), \"w\") as f:\n",
    "        json.dump(manifest, f, indent=2)\n",
    "\n",
    "    # The printed path repeats the literal directory name used to build base_out.\n",
    "    print(f\"\\nAll results saved under: ./results_teps_aipw/{timestamp}/\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0680b8eb-b410-4f4f-8501-42d31551b918",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
