{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8b2e74a8-5501-451e-8d7a-03ff9108174c",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os, json\n",
    "from datetime import datetime\n",
    "\n",
    "import numpy as np\n",
    "from tqdm import tqdm\n",
    "from sklearn.metrics import adjusted_rand_score\n",
    "from sklearn.cluster import KMeans\n",
    "from lightgbm import LGBMRegressor\n",
    "\n",
    "# =====================\n",
    "# Hyperparameters\n",
    "# =====================\n",
    "SEED_BASE   = 1\n",
    "ROUNDS      = 100\n",
    "M_TASKS     = 20         # number of tasks\n",
    "K_CLUST     = 3 \n",
    "DELTAS      = [1/3, 2/3, 1]\n",
    "\n",
    "# List of fixed-lambda values\n",
    "FIXED_LAMBDA_LIST = [1e-5, 1e-4, 1e-3, 1e-2, 1e-1]\n",
    "\n",
    "# Split sizes\n",
    "A_FRAC      = 0.5        # fraction to use for nuisance training (split-A)\n",
    "\n",
    "# Noise levels\n",
    "NOISE_Y     = 1\n",
    "NOISE_D     = 1\n",
    "\n",
    "# LightGBM (nuisances)\n",
    "N_JOBS = 1\n",
    "LGBM_ESTIMATORS        = 100\n",
    "LGBM_LEARNING_RATE     = 0.03\n",
    "LGBM_NUM_LEAVES        = 20\n",
    "LGBM_MIN_CHILD_SAMPLES = 20\n",
    "LGBM_SUBSAMPLE         = 0.9\n",
    "LGBM_COLSAMPLE         = 0.9\n",
    "LGBM_REG_ALPHA         = 0.0\n",
    "LGBM_REG_LAMBDA        = 0.0\n",
    "LGBM_SEED_OFFSET       = 1000\n",
    "\n",
    "# Adaptive-fusion hyperparameters\n",
    "C_W               = .1\n",
    "ADAPT_EPS         = 1e-12\n",
    "GAMMA_DEFAULT     = 2       # adaptive gamma\n",
    "ADAPT_TAU_DEFAULT = 10       # adaptive tau\n",
    "\n",
    "# ---------------- ARMUL controls ----------------\n",
    "# Per-task penalty: lambda_j = C_LAM * sqrt(1 / n_j)\n",
    "C_LAM            = 1.0\n",
    "ARMUL_K_MAX      = M_TASKS\n",
    "ARMUL_N_INIT     = 1\n",
    "ARMUL_N_ITER     = 100\n",
    "ARMUL_TOL        = 1e-7\n",
    "ARMUL_SEED       = 123\n",
    "\n",
    "EPS = 1e-12\n",
    "\n",
    "# =====================\n",
    "# Utilities\n",
    "# =====================\n",
    "def ensure_dir(path: str):\n",
    "    os.makedirs(path, exist_ok=True)\n",
    "    return path\n",
    "\n",
    "def split_two_parts(n, a_frac, seed):\n",
    "    \"\"\"Return indices for split-A (nuisance) and split-T (target).\"\"\"\n",
    "    rng = np.random.default_rng(seed)\n",
    "    idx = np.arange(n)\n",
    "    rng.shuffle(idx)\n",
    "    n_a = int(np.floor(a_frac * n))\n",
    "    A = idx[:n_a]\n",
    "    T = idx[n_a:]\n",
    "    return A, T\n",
    "\n",
    "# =====================\n",
    "# PLM DGP\n",
    "# =====================\n",
    "def make_plm_continuous_tasks(m: int, K: int, delta: float, seed: int):\n",
    "    \"\"\"\n",
    "    Y = θ_k * D + g_j(X) + ε\n",
    "    D = μ_Dj(X) + ν\n",
    "    θ_k = k * delta - (K+1)delta/2 (centered); tasks evenly assigned to clusters.\n",
    "    \"\"\"\n",
    "    rng = np.random.default_rng(seed)\n",
    "    theta_clusters = np.arange(1, K + 1, dtype=float) * delta - (K + 1) * delta / 2.0\n",
    "\n",
    "    cluster_ids = np.repeat(np.arange(K), repeats=int(np.ceil(m / K)))[:m]\n",
    "    rng.shuffle(cluster_ids)\n",
    "\n",
    "    tasks = []\n",
    "    theta_true = np.zeros(m)\n",
    "    for j in range(m):\n",
    "        n_j = 3200 + 80 * j\n",
    "        p_j = 5 + j\n",
    "\n",
    "        X = rng.normal(0.0, 1.0, size=(n_j, p_j))\n",
    "\n",
    "        # μ_Dj(X)\n",
    "        mu_D = 0.2 * np.tanh(X[:, :p_j].sum(axis=1))\n",
    "        D = mu_D + rng.normal(0.0, NOISE_D, size=n_j)\n",
    "\n",
    "        # g_j(X)\n",
    "        g = 0.0\n",
    "        for i in range(p_j):\n",
    "            xi = X[:, i]\n",
    "            s = np.exp(xi)\n",
    "            g += (-0.8) ** (i+1) * s / (s + 1.0)\n",
    "\n",
    "        k = cluster_ids[j]\n",
    "        theta = theta_clusters[k]\n",
    "        Y = theta * D + g + rng.normal(0.0, NOISE_Y, size=n_j)\n",
    "\n",
    "        tasks.append({\"X\": X, \"D\": D, \"Y\": Y})\n",
    "        theta_true[j] = theta\n",
    "\n",
    "    return tasks, theta_true, cluster_ids\n",
    "\n",
    "def _lgbm_reg(seed):\n",
    "    return LGBMRegressor(\n",
    "        n_estimators=LGBM_ESTIMATORS,\n",
    "        learning_rate=LGBM_LEARNING_RATE,\n",
    "        num_leaves=LGBM_NUM_LEAVES,\n",
    "        min_child_samples=LGBM_MIN_CHILD_SAMPLES,\n",
    "        subsample=LGBM_SUBSAMPLE,\n",
    "        colsample_bytree=LGBM_COLSAMPLE,\n",
    "        reg_alpha=LGBM_REG_ALPHA,\n",
    "        reg_lambda=LGBM_REG_LAMBDA,\n",
    "        objective=\"regression\",\n",
    "        random_state=seed,\n",
    "        n_jobs=N_JOBS,\n",
    "        verbose=-1,\n",
    "    )\n",
    "\n",
    "def orth_residuals_trainA_predict_subset(X, D, Y, A_idx, subset_idx, seed):\n",
    "    \"\"\"\n",
    "    Fit nuisance regressions on split-A only:\n",
    "        m(X) = E[Y | X]\n",
    "        p(X) = E[D | X]\n",
    "    Return orthogonal residuals on DISJOINT subset (e.g., T):\n",
    "        rY = Y - m(X), rD = D - p(X).\n",
    "    \"\"\"\n",
    "    m = _lgbm_reg(seed)\n",
    "    p = _lgbm_reg(seed + 1)\n",
    "    m.fit(X[A_idx], Y[A_idx])\n",
    "    p.fit(X[A_idx], D[A_idx])\n",
    "    rY = Y[subset_idx] - m.predict(X[subset_idx])\n",
    "    rD = D[subset_idx] - p.predict(X[subset_idx])\n",
    "    return rD, rY\n",
    "\n",
    "def sstats_and_if_se(rD, rY):\n",
    "    \"\"\"\n",
    "    Orthogonal PLM score:\n",
    "        rY = θ rD + error.\n",
    "    a = E[rD^2];  b = (Σ rD*rY)/(Σ rD^2);\n",
    "    Var(b) ≈ E[(rD (rY - b rD))^2] / (n * a^2).\n",
    "    \"\"\"\n",
    "    rD = np.asarray(rD, dtype=float)\n",
    "    rY = np.asarray(rY, dtype=float)\n",
    "    n = len(rD)\n",
    "    if n == 0:\n",
    "        return 1.0, 0.0, 1e-12\n",
    "    s2 = float(np.mean(rD ** 2))\n",
    "    s2 = max(s2, EPS)\n",
    "    denom = max(EPS, np.sum(rD ** 2))\n",
    "    b = float(np.sum(rD * rY) / denom)\n",
    "    psi = rD * (rY - b * rD)\n",
    "    var_b = float(np.mean(psi ** 2) / (n * s2 ** 2))\n",
    "    var_b = max(var_b, 1e-12)\n",
    "    return s2, b, var_b\n",
    "\n",
    "# =====================\n",
    "# Graph-fused lasso helpers\n",
    "# =====================\n",
    "def build_lambda_adaptive(theta_init, c_w=C_W, gamma=GAMMA_DEFAULT,\n",
    "                          tau=None, eps=ADAPT_EPS,\n",
    "                          km_n_init=10, random_state=0):\n",
    "    \"\"\"\n",
    "    Adaptive pairwise penalties Λ_{ij}\n",
    "      w_ij = c_w * (|θ_i - θ_j| + 1e-8)^(-gamma).\n",
    "    \"\"\"\n",
    "    m = len(theta_init)\n",
    "    L = np.zeros((m, m), dtype=float)\n",
    "    if m < 2:\n",
    "        return L\n",
    "\n",
    "    wvals = []\n",
    "    for i in range(m):\n",
    "        for j in range(i + 1, m):\n",
    "            diff = abs(theta_init[i] - theta_init[j])\n",
    "            w = c_w * (diff + 1e-8) ** (-gamma)\n",
    "            wvals.append(w)\n",
    "    wvals = np.asarray(wvals, dtype=float)\n",
    "\n",
    "    if tau is None:\n",
    "        z = np.log(np.maximum(wvals, 1e-18)).reshape(-1, 1)\n",
    "        if np.allclose(z, z[0]):\n",
    "            tau = float(np.exp(z[0, 0]))\n",
    "        else:\n",
    "            km = KMeans(n_clusters=2, n_init=km_n_init, random_state=random_state)\n",
    "            km.fit(z)\n",
    "            centers = np.sort(km.cluster_centers_.ravel())\n",
    "            z_split = 0.5 * (centers[0] + centers[1])\n",
    "            tau = float(np.exp(z_split))\n",
    "\n",
    "    idx = 0\n",
    "    for i in range(m):\n",
    "        for j in range(i + 1, m):\n",
    "            w = wvals[idx]; idx += 1\n",
    "            lam = eps if w <= tau else w\n",
    "            L[i, j] = L[j, i] = lam\n",
    "\n",
    "    return L\n",
    "\n",
    "def build_lambda_fixed(m, lam):\n",
    "    L = np.full((m, m), 0.0)\n",
    "    for i in range(m):\n",
    "        for j in range(i + 1, m):\n",
    "            L[i, j] = L[j, i] = lam\n",
    "    return L\n",
    "\n",
    "def admm_graph_fused_lasso(a, b, Lambda, rho=1.0, max_iter=2000, tol=1e-6):\n",
    "    \"\"\"\n",
    "    min_θ Σ_j (a_j/2)(θ_j - b_j)^2 + Σ_{i<j} Λ_{ij}|θ_i - θ_j|.\n",
    "    \"\"\"\n",
    "    m = len(a)\n",
    "    edges = [(i, j) for i in range(m) for j in range(i + 1, m)]\n",
    "    E = len(edges)\n",
    "    if E == 0:\n",
    "        return b.copy()\n",
    "\n",
    "    B = np.zeros((E, m))\n",
    "    for e, (i, j) in enumerate(edges):\n",
    "        B[e, i] = 1.0\n",
    "        B[e, j] = -1.0\n",
    "    lams = np.array([Lambda[i, j] for (i, j) in edges])\n",
    "\n",
    "    A = np.diag(a)\n",
    "    M = A + rho * (B.T @ B)\n",
    "\n",
    "    theta = b.copy()\n",
    "    z = B @ theta\n",
    "    u = np.zeros_like(z)\n",
    "\n",
    "    def soft(v, kappa):\n",
    "        return np.sign(v) * np.maximum(np.abs(v) - kappa, 0.0)\n",
    "\n",
    "    for _ in range(max_iter):\n",
    "        theta_new = np.linalg.solve(M, A @ b + rho * B.T @ (z - u))\n",
    "        r = B @ theta_new + u\n",
    "        z_new = soft(r, lams / rho)\n",
    "        u_new = r - z_new\n",
    "        if np.linalg.norm(B @ theta_new - z_new) < tol and np.linalg.norm(rho * B.T @ (z_new - z)) < tol:\n",
    "            theta = theta_new\n",
    "            break\n",
    "        theta, z, u = theta_new, z_new, u_new\n",
    "\n",
    "    return theta\n",
    "\n",
    "def clusters_from_theta(theta_hat, tol=1e-3):\n",
    "    \"\"\"Greedy grouping of nearly-equal θ's into cluster labels.\"\"\"\n",
    "    order = np.argsort(theta_hat)\n",
    "    labels_sorted = np.zeros_like(order)\n",
    "    cur = 0\n",
    "    labels_sorted[0] = 0\n",
    "    for k in range(1, len(order)):\n",
    "        if abs(theta_hat[order[k]] - theta_hat[order[k - 1]]) > tol:\n",
    "            cur += 1\n",
    "        labels_sorted[k] = cur\n",
    "    labels = np.zeros_like(labels_sorted)\n",
    "    labels[order] = labels_sorted\n",
    "    return labels\n",
    "\n",
    "# =====================\n",
    "# ARMUL: clustered MTL second stage\n",
    "# =====================\n",
    "def _soft(x, t):\n",
    "    return np.sign(x) * np.maximum(np.abs(x) - t, 0.0)\n",
    "\n",
    "def _armul_obj(theta, gamma, c, a, b, lam_vec):\n",
    "    return 0.5 * np.sum(a * (theta - b)**2) + np.sum(lam_vec * np.abs(theta - gamma[c]))\n",
    "\n",
    "def _ensure_nonempty(labels, K, rng):\n",
    "    counts = np.bincount(labels, minlength=K)\n",
    "    for k in range(K):\n",
    "        if counts[k] == 0:\n",
    "            src = int(np.argmax(counts))\n",
    "            idx_src = np.where(labels == src)[0]\n",
    "            j_move = int(rng.choice(idx_src))\n",
    "            labels[j_move] = k\n",
    "            counts[src] -= 1\n",
    "            counts[k] += 1\n",
    "    return labels\n",
    "\n",
    "def _kmeans_init_labels_1d(b, K, seed, n_init_kmeans=10):\n",
    "    km = KMeans(n_clusters=K, n_init=n_init_kmeans, random_state=seed)\n",
    "    labels = km.fit_predict(b.reshape(-1, 1))\n",
    "    return labels.astype(int, copy=False)\n",
    "\n",
    "def armul_clustered_fit(\n",
    "    b, a, lam_vec,\n",
    "    K_candidates,\n",
    "    init=\"kmeans\", n_init=1, n_iter=300, tol=1e-7, seed=123, kmeans_n_init=10\n",
    "):\n",
    "    rng = np.random.default_rng(seed)\n",
    "    m = len(b)\n",
    "    lam_vec = np.asarray(lam_vec, dtype=float)\n",
    "    assert lam_vec.shape == (m,)\n",
    "\n",
    "    K_list = [K_candidates] if isinstance(K_candidates, int) else list(K_candidates)\n",
    "    if not K_list:\n",
    "        raise ValueError(\"K_candidates is empty.\")\n",
    "\n",
    "    def theta_update(gamma, labels):\n",
    "        theta = np.empty_like(b, dtype=float)\n",
    "        for j in range(m):\n",
    "            gk = gamma[labels[j]]\n",
    "            theta[j] = gk + _soft(b[j] - gk, lam_vec[j] / max(a[j], EPS))\n",
    "        return theta\n",
    "\n",
    "    best = dict(obj=np.inf, theta=None, labels=None, gamma=None, K=None)\n",
    "\n",
    "    for K in K_list:\n",
    "        for init_id in range(n_init):\n",
    "            if init == \"kmeans\":\n",
    "                labels = _kmeans_init_labels_1d(b, K, seed + init_id, n_init_kmeans=kmeans_n_init)\n",
    "            elif init == \"quantiles\":\n",
    "                order = np.argsort(b)\n",
    "                labels = np.zeros(m, dtype=int)\n",
    "                edges = np.linspace(0, m, K + 1).astype(int)\n",
    "                for k in range(K):\n",
    "                    labels[order[edges[k]:edges[k + 1]]] = k\n",
    "            elif init == \"random\":\n",
    "                labels = rng.integers(0, K, size=m)\n",
    "            else:\n",
    "                raise ValueError(f\"Unknown init='{init}'\")\n",
    "\n",
    "            labels = _ensure_nonempty(labels, K, rng)\n",
    "\n",
    "            gamma = np.array([np.median(b[labels == k]) for k in range(K)], dtype=float)\n",
    "            theta = theta_update(gamma, labels)\n",
    "            obj_prev = _armul_obj(theta, gamma, labels, a, b, lam_vec)\n",
    "\n",
    "            for _ in range(n_iter):\n",
    "                # θ-update\n",
    "                theta = theta_update(gamma, labels)\n",
    "                # γ-update\n",
    "                for k in range(K):\n",
    "                    idx = np.where(labels == k)[0]\n",
    "                    if len(idx) == 0:\n",
    "                        gamma[k] = float(rng.choice(theta))\n",
    "                    else:\n",
    "                        gamma[k] = float(np.median(theta[idx]))\n",
    "                # c-update\n",
    "                theta_new = np.empty_like(theta)\n",
    "                labels_new = labels.copy()\n",
    "                for j in range(m):\n",
    "                    costs = np.empty(K, dtype=float)\n",
    "                    thetas_j = np.empty(K, dtype=float)\n",
    "                    for k in range(K):\n",
    "                        gk = gamma[k]\n",
    "                        tj = gk + _soft(b[j] - gk, lam_vec[j] / max(a[j], EPS))\n",
    "                        thetas_j[k] = tj\n",
    "                        costs[k] = 0.5 * a[j] * (tj - b[j])**2 + lam_vec[j] * abs(tj - gk)\n",
    "                    kbest = int(np.argmin(costs))\n",
    "                    labels_new[j] = kbest\n",
    "                    theta_new[j]  = thetas_j[kbest]\n",
    "\n",
    "                labels_new = _ensure_nonempty(labels_new, K, rng)\n",
    "                obj = _armul_obj(theta_new, gamma, labels_new, a, b, lam_vec)\n",
    "\n",
    "                if (obj_prev - obj) <= tol:\n",
    "                    theta, labels, obj_prev = theta_new, labels_new, obj\n",
    "                    break\n",
    "                theta, labels, obj_prev = theta_new, labels_new, obj\n",
    "\n",
    "            if obj_prev < best[\"obj\"]:\n",
    "                best = dict(theta=theta.copy(), labels=labels.copy(),\n",
    "                            gamma=gamma.copy(), obj=float(obj_prev), K=K)\n",
    "\n",
    "    return best[\"theta\"], best[\"labels\"], best[\"obj\"], best[\"K\"], best.get(\"gamma\", None)\n",
    "\n",
    "# =============== Per-task λ_j = C_LAM * sqrt(1 / n_j) ===============\n",
    "def armul_lambda_per_task(splits, c=C_LAM, use_target=True):\n",
    "    n_list = []\n",
    "    for (A_idx, T_idx) in splits:\n",
    "        n_j = len(T_idx) if use_target else (len(A_idx) + len(T_idx))\n",
    "        n_list.append(max(1, int(n_j)))\n",
    "    n_arr = np.asarray(n_list, dtype=float)\n",
    "    lam_vec = c * np.sqrt(1.0 / n_arr)\n",
    "    return lam_vec.astype(float), {\"n_list\": n_arr.copy()}\n",
    "\n",
    "# =====================\n",
    "# Clusterwise SE\n",
    "# =====================\n",
    "def cluster_sandwich_from_var(theta_in, labels, var_b_list):\n",
    "    \"\"\"\n",
    "    var_b_list[j] is per-task IF variance Var(θ̂_j).\n",
    "    Within each cluster g, we pool var_b_j and assign the same SE to all tasks:\n",
    "        Var_cluster ≈ mean_j∈g var_b_j.\n",
    "    \"\"\"\n",
    "    L = int(labels.max() + 1)\n",
    "    se_out = np.zeros_like(theta_in, dtype=float)\n",
    "    var_b_arr = np.asarray(var_b_list, dtype=float)\n",
    "\n",
    "    for g in range(L):\n",
    "        idx_tasks = np.where(labels == g)[0]\n",
    "        if len(idx_tasks) == 0:\n",
    "            continue\n",
    "        var_g = float(np.mean(var_b_arr[idx_tasks]))\n",
    "        se_g = float(np.sqrt(max(var_g, 1e-12)))\n",
    "        se_out[idx_tasks] = se_g\n",
    "\n",
    "    return theta_in.copy(), se_out, labels\n",
    "\n",
    "# =====================\n",
    "# One-round\n",
    "# =====================\n",
    "def one_round_core(delta: float, seed: int):\n",
    "    \"\"\"\n",
    "    Do the DGP, orthogonal stats, individual, adaptive fused, and ARMUL.\n",
    "    \"\"\"\n",
    "    tasks, theta_true, cluster_ids = make_plm_continuous_tasks(M_TASKS, K_CLUST, delta, seed)\n",
    "\n",
    "    # per-task splits\n",
    "    splits = []\n",
    "    for j, t in enumerate(tasks):\n",
    "        n = len(t[\"Y\"])\n",
    "        A_idx, T_idx = split_two_parts(n, A_FRAC, seed=SEED_BASE + 11111 + j)\n",
    "        splits.append((A_idx, T_idx))\n",
    "\n",
    "    # sufficient stats on T (orthogonal score)\n",
    "    a_T, b_T, var_b_T = [], [], []\n",
    "    for j, t in enumerate(tasks):\n",
    "        X, D, Y = t[\"X\"], t[\"D\"], t[\"Y\"]\n",
    "        A_idx, T_idx = splits[j]\n",
    "\n",
    "        rD_T, rY_T = orth_residuals_trainA_predict_subset(\n",
    "            X, D, Y, A_idx, T_idx, seed=LGBM_SEED_OFFSET + j\n",
    "        )\n",
    "        s2, bj, vj = sstats_and_if_se(rD_T, rY_T)\n",
    "        a_T.append(s2)\n",
    "        b_T.append(bj)\n",
    "        var_b_T.append(vj)\n",
    "\n",
    "    a_T = np.array(a_T)\n",
    "    b_T = np.array(b_T)\n",
    "    var_b_T = np.array(var_b_T)\n",
    "    se_T_indiv = np.sqrt(np.maximum(EPS, var_b_T.copy()))\n",
    "\n",
    "    theta_true = np.array(theta_true)\n",
    "    cluster_ids = np.array(cluster_ids, dtype=int)\n",
    "\n",
    "    n_T = np.array([len(T_idx) for (A_idx, T_idx) in splits], dtype=float)\n",
    "\n",
    "    K_true = int(cluster_ids.max() + 1)\n",
    "    N_k_true = np.zeros(K_true, dtype=float)\n",
    "    for k in range(K_true):\n",
    "        idx_k = (cluster_ids == k)\n",
    "        N_k_true[k] = n_T[idx_k].sum()\n",
    "\n",
    "    cluster_weights = N_k_true[cluster_ids]  # shape (M_TASKS,)\n",
    "\n",
    "    # (i) Individual\n",
    "    theta_ind = b_T.copy()\n",
    "    cl_ind = clusters_from_theta(theta_ind)\n",
    "    ari_ind = float(adjusted_rand_score(cluster_ids, cl_ind))\n",
    "\n",
    "    # Cluster-size–weighted metrics for IND\n",
    "    err_ind = theta_ind - theta_true\n",
    "    cw_rmse_ind = float(np.sqrt(np.sum(cluster_weights * err_ind**2)))\n",
    "    diff_cw_ind = np.sqrt(cluster_weights) * err_ind\n",
    "\n",
    "    # (ii) Adaptive fused — use individual PLM estimator as initial θ\n",
    "    L_ad_T = build_lambda_adaptive(\n",
    "        theta_ind,\n",
    "        c_w=C_W,\n",
    "        gamma=GAMMA_DEFAULT,\n",
    "        tau=ADAPT_TAU_DEFAULT,\n",
    "        eps=ADAPT_EPS\n",
    "    )\n",
    "    theta_ad_pen = admm_graph_fused_lasso(a_T, b_T, L_ad_T)\n",
    "    cl_ad = clusters_from_theta(theta_ad_pen)\n",
    "    theta_ad, se_ad, cl_ad = cluster_sandwich_from_var(theta_ad_pen, cl_ad, var_b_T)\n",
    "    ari_ad = float(adjusted_rand_score(cluster_ids, cl_ad))\n",
    "\n",
    "    err_ad = theta_ad - theta_true\n",
    "    cw_rmse_ad = float(np.sqrt(np.sum(cluster_weights * err_ad**2)))\n",
    "    diff_cw_ad = np.sqrt(cluster_weights) * err_ad\n",
    "\n",
    "    # (iii) ARMUL with per-task λ_j — K-1, K (oracle), K+1\n",
    "    lam_vec, lam_diag = armul_lambda_per_task(splits, c=C_LAM, use_target=True)\n",
    "\n",
    "    def fit_armul_with_K(K_fixed):\n",
    "        K_fixed = int(np.clip(K_fixed, 1, M_TASKS))\n",
    "        theta_k, labels_k, obj_k, K_sel_k, _ = armul_clustered_fit(\n",
    "            b=b_T, a=a_T, lam_vec=lam_vec,\n",
    "            K_candidates=[K_fixed],\n",
    "            init=\"kmeans\",\n",
    "            n_init=ARMUL_N_INIT,\n",
    "            n_iter=ARMUL_N_ITER,\n",
    "            tol=ARMUL_TOL,\n",
    "            seed=ARMUL_SEED,\n",
    "            kmeans_n_init=10,\n",
    "        )\n",
    "        theta_k, se_k, labels_k = cluster_sandwich_from_var(theta_k, labels_k, var_b_T)\n",
    "        ari_k = float(adjusted_rand_score(cluster_ids, labels_k))\n",
    "        rmse_k = float(np.sqrt(np.mean((theta_k - theta_true) ** 2)))\n",
    "        return dict(theta=theta_k, se=se_k, labels=labels_k,\n",
    "                    ari=ari_k, rmse=rmse_k, K=K_sel_k, obj=obj_k)\n",
    "\n",
    "    armul_km1 = fit_armul_with_K(K_CLUST - 1)\n",
    "    armul_k   = fit_armul_with_K(K_CLUST)       # oracle\n",
    "    armul_kp1 = fit_armul_with_K(K_CLUST + 1)\n",
    "\n",
    "    # Metrics for ARMUL\n",
    "    err_arm_km1 = armul_km1[\"theta\"] - theta_true\n",
    "    err_arm_k   = armul_k[\"theta\"]   - theta_true\n",
    "    err_arm_kp1 = armul_kp1[\"theta\"] - theta_true\n",
    "\n",
    "    cw_rmse_armul_km1 = float(np.sqrt(np.sum(cluster_weights * err_arm_km1**2)))\n",
    "    cw_rmse_armul_k   = float(np.sqrt(np.sum(cluster_weights * err_arm_k**2)))\n",
    "    cw_rmse_armul_kp1 = float(np.sqrt(np.sum(cluster_weights * err_arm_kp1**2)))\n",
    "\n",
    "    diff_cw_armul_km1 = np.sqrt(cluster_weights) * err_arm_km1\n",
    "    diff_cw_armul_k   = np.sqrt(cluster_weights) * err_arm_k\n",
    "    diff_cw_armul_kp1 = np.sqrt(cluster_weights) * err_arm_kp1\n",
    "\n",
    "    # RMSE (unweighted) for individual & adaptive\n",
    "    rmse_ind   = float(np.sqrt(np.mean((theta_ind - theta_true) ** 2)))\n",
    "    rmse_ad    = float(np.sqrt(np.mean((theta_ad  - theta_true) ** 2)))\n",
    "\n",
    "    zI     = (theta_ind - theta_true) / se_T_indiv\n",
    "    zA_ad  = (theta_ad  - theta_true) / se_ad\n",
    "\n",
    "    return {\n",
    "        # truths and core stats\n",
    "        \"theta_true\": theta_true,\n",
    "        \"cluster_ids_true\": cluster_ids,\n",
    "        \"a_T\": a_T,\n",
    "        \"b_T\": b_T,\n",
    "        \"var_b_T\": var_b_T,\n",
    "        \"se_T_indiv\": se_T_indiv,\n",
    "\n",
    "        # cluster sizes / weights\n",
    "        \"cluster_n_T\": n_T,\n",
    "        \"cluster_N_k_true\": N_k_true,\n",
    "        \"cluster_weights\": cluster_weights,\n",
    "\n",
    "        # individual\n",
    "        \"theta_ind\": theta_ind,\n",
    "        \"ARI_ind\": ari_ind,\n",
    "        \"rmse_ind\": rmse_ind,\n",
    "        \"zI\": zI,\n",
    "\n",
    "        # adaptive fused\n",
    "        \"theta_ad\": theta_ad,\n",
    "        \"se_ad\": se_ad,\n",
    "        \"ARI_ad\": ari_ad,\n",
    "        \"rmse_ad\": rmse_ad,\n",
    "        \"zA_ad\": zA_ad,\n",
    "\n",
    "        # ARMUL variants\n",
    "        \"theta_armul_km1\": armul_km1[\"theta\"],\n",
    "        \"se_armul_km1\": armul_km1[\"se\"],\n",
    "        \"ARI_armul_km1\": armul_km1[\"ari\"],\n",
    "        \"rmse_armul_km1\": armul_km1[\"rmse\"],\n",
    "        \"armul_K_sel_km1\": armul_km1[\"K\"],\n",
    "        \"armul_obj_km1\": armul_km1[\"obj\"],\n",
    "\n",
    "        \"theta_armul_k\": armul_k[\"theta\"],\n",
    "        \"se_armul_k\": armul_k[\"se\"],\n",
    "        \"ARI_armul_k\": armul_k[\"ari\"],\n",
    "        \"rmse_armul_k\": armul_k[\"rmse\"],\n",
    "        \"armul_K_sel_k\": armul_k[\"K\"],\n",
    "        \"armul_obj_k\": armul_k[\"obj\"],\n",
    "\n",
    "        \"theta_armul_kp1\": armul_kp1[\"theta\"],\n",
    "        \"se_armul_kp1\": armul_kp1[\"se\"],\n",
    "        \"ARI_armul_kp1\": armul_kp1[\"ari\"],\n",
    "        \"rmse_armul_kp1\": armul_kp1[\"rmse\"],\n",
    "        \"armul_K_sel_kp1\": armul_kp1[\"K\"],\n",
    "        \"armul_obj_kp1\": armul_kp1[\"obj\"],\n",
    "\n",
    "        # ARMUL diagnostics\n",
    "        \"armul_lambda_vec\": lam_vec,\n",
    "        \"armul_lambda_diag\": lam_diag,\n",
    "\n",
    "        # cluster-size–weighted RMSE (per round)\n",
    "        \"cw_rmse_ind\": cw_rmse_ind,\n",
    "        \"cw_rmse_ad\": cw_rmse_ad,\n",
    "        \"cw_rmse_armul_km1\": cw_rmse_armul_km1,\n",
    "        \"cw_rmse_armul_k\": cw_rmse_armul_k,\n",
    "        \"cw_rmse_armul_kp1\": cw_rmse_armul_kp1,\n",
    "\n",
    "        # cluster-size–weighted differences (per round, per task)\n",
    "        \"diff_cw_ind\": diff_cw_ind,\n",
    "        \"diff_cw_ad\": diff_cw_ad,\n",
    "        \"diff_cw_armul_km1\": diff_cw_armul_km1,\n",
    "        \"diff_cw_armul_k\": diff_cw_armul_k,\n",
    "        \"diff_cw_armul_kp1\": diff_cw_armul_kp1,\n",
    "    }\n",
    "\n",
    "\n",
    "# =====================\n",
    "# Run ROUNDS\n",
    "# =====================\n",
    "def run_for_delta_fixed_grid(rounds: int, delta: float, start_seed: int, fixed_lambda_list):\n",
    "    def init_agg():\n",
    "        return {\n",
    "            \"TH_true\": [],\n",
    "            \"TH_ind\": [],\n",
    "            \"TH_fx\": [],\n",
    "            \"TH_ad\": [],\n",
    "            \"TH_arm_km1\": [],\n",
    "            \"TH_arm_k\": [],\n",
    "            \"TH_arm_kp1\": [],\n",
    "            \"zI\": [],\n",
    "            \"zA_fx\": [],\n",
    "            \"zA_ad\": [],\n",
    "            \"ARI_ind\": [],\n",
    "            \"ARI_fx\": [],\n",
    "            \"ARI_ad\": [],\n",
    "            \"ARI_arm_km1\": [],\n",
    "            \"ARI_arm_k\": [],\n",
    "            \"ARI_arm_kp1\": [],\n",
    "            \"rmse_ind\": [],\n",
    "            \"rmse_fx\": [],\n",
    "            \"rmse_ad\": [],\n",
    "            \"rmse_arm_km1\": [],\n",
    "            \"rmse_arm_k\": [],\n",
    "            \"rmse_arm_kp1\": [],\n",
    "            \"armul_K_sel_km1\": [],\n",
    "            \"armul_K_sel_k\": [],\n",
    "            \"armul_K_sel_kp1\": [],\n",
    "            \"armul_lambda\": [],\n",
    "\n",
    "            \"cw_rmse_ind\": [],\n",
    "            \"cw_rmse_fx\": [],\n",
    "            \"cw_rmse_ad\": [],\n",
    "            \"cw_rmse_arm_km1\": [],\n",
    "            \"cw_rmse_arm_k\": [],\n",
    "            \"cw_rmse_arm_kp1\": [],\n",
    "\n",
    "            \"diff_cw_ind\": [],\n",
    "            \"diff_cw_fx\": [],\n",
    "            \"diff_cw_ad\": [],\n",
    "            \"diff_cw_arm_km1\": [],\n",
    "            \"diff_cw_arm_k\": [],\n",
    "            \"diff_cw_arm_kp1\": [],\n",
    "        }\n",
    "\n",
    "    aggs = {lam: init_agg() for lam in fixed_lambda_list}\n",
    "\n",
    "    for r in tqdm(range(rounds), desc=f\"delta={delta}\"):\n",
    "        core = one_round_core(delta=delta, seed=start_seed + r)\n",
    "\n",
    "        a_T = core[\"a_T\"]\n",
    "        b_T = core[\"b_T\"]\n",
    "        var_b_T = core[\"var_b_T\"]\n",
    "        theta_true = core[\"theta_true\"]\n",
    "        cluster_ids = core[\"cluster_ids_true\"]\n",
    "        cluster_weights = core[\"cluster_weights\"]\n",
    "\n",
    "        for fixed_lambda in fixed_lambda_list:\n",
    "            agg = aggs[fixed_lambda]\n",
    "\n",
    "            L_fx_T = build_lambda_fixed(M_TASKS, fixed_lambda)\n",
    "            theta_fx_pen = admm_graph_fused_lasso(a_T, b_T, L_fx_T)\n",
    "            cl_fx = clusters_from_theta(theta_fx_pen)\n",
    "            theta_fx, se_fx, cl_fx = cluster_sandwich_from_var(theta_fx_pen, cl_fx, var_b_T)\n",
    "            ari_fx = float(adjusted_rand_score(cluster_ids, cl_fx))\n",
    "            rmse_fx = float(np.sqrt(np.mean((theta_fx - theta_true) ** 2)))\n",
    "            zA_fx = (theta_fx - theta_true) / se_fx\n",
    "\n",
    "            err_fx = theta_fx - theta_true\n",
    "            cw_rmse_fx = float(np.sqrt(np.sum(cluster_weights * err_fx**2)))\n",
    "            diff_cw_fx = np.sqrt(cluster_weights) * err_fx\n",
    "\n",
    "            agg[\"TH_true\"].append(theta_true)\n",
    "            agg[\"TH_ind\"].append(core[\"theta_ind\"])\n",
    "            agg[\"TH_fx\"].append(theta_fx)\n",
    "            agg[\"TH_ad\"].append(core[\"theta_ad\"])\n",
    "            agg[\"TH_arm_km1\"].append(core[\"theta_armul_km1\"])\n",
    "            agg[\"TH_arm_k\"].append(core[\"theta_armul_k\"])\n",
    "            agg[\"TH_arm_kp1\"].append(core[\"theta_armul_kp1\"])\n",
    "\n",
    "            agg[\"zI\"].append(core[\"zI\"])\n",
    "            agg[\"zA_fx\"].append(zA_fx)\n",
    "            agg[\"zA_ad\"].append(core[\"zA_ad\"])\n",
    "\n",
    "            agg[\"ARI_ind\"].append(core[\"ARI_ind\"])\n",
    "            agg[\"ARI_fx\"].append(ari_fx)\n",
    "            agg[\"ARI_ad\"].append(core[\"ARI_ad\"])\n",
    "            agg[\"ARI_arm_km1\"].append(core[\"ARI_armul_km1\"])\n",
    "            agg[\"ARI_arm_k\"].append(core[\"ARI_armul_k\"])\n",
    "            agg[\"ARI_arm_kp1\"].append(core[\"ARI_armul_kp1\"])\n",
    "\n",
    "            agg[\"rmse_ind\"].append(core[\"rmse_ind\"])\n",
    "            agg[\"rmse_fx\"].append(rmse_fx)\n",
    "            agg[\"rmse_ad\"].append(core[\"rmse_ad\"])\n",
    "            agg[\"rmse_arm_km1\"].append(core[\"rmse_armul_km1\"])\n",
    "            agg[\"rmse_arm_k\"].append(core[\"rmse_armul_k\"])\n",
    "            agg[\"rmse_arm_kp1\"].append(core[\"rmse_armul_kp1\"])\n",
    "\n",
    "            agg[\"armul_K_sel_km1\"].append(core[\"armul_K_sel_km1\"])\n",
    "            agg[\"armul_K_sel_k\"].append(core[\"armul_K_sel_k\"])\n",
    "            agg[\"armul_K_sel_kp1\"].append(core[\"armul_K_sel_kp1\"])\n",
    "            agg[\"armul_lambda\"].append(core[\"armul_lambda_vec\"])\n",
    "\n",
    "            agg[\"cw_rmse_ind\"].append(core[\"cw_rmse_ind\"])\n",
    "            agg[\"cw_rmse_fx\"].append(cw_rmse_fx)\n",
    "            agg[\"cw_rmse_ad\"].append(core[\"cw_rmse_ad\"])\n",
    "            agg[\"cw_rmse_arm_km1\"].append(core[\"cw_rmse_armul_km1\"])\n",
    "            agg[\"cw_rmse_arm_k\"].append(core[\"cw_rmse_armul_k\"])\n",
    "            agg[\"cw_rmse_arm_kp1\"].append(core[\"cw_rmse_armul_kp1\"])\n",
    "\n",
    "            agg[\"diff_cw_ind\"].append(core[\"diff_cw_ind\"])\n",
    "            agg[\"diff_cw_fx\"].append(diff_cw_fx)\n",
    "            agg[\"diff_cw_ad\"].append(core[\"diff_cw_ad\"])\n",
    "            agg[\"diff_cw_arm_km1\"].append(core[\"diff_cw_armul_km1\"])\n",
    "            agg[\"diff_cw_arm_k\"].append(core[\"diff_cw_armul_k\"])\n",
    "            agg[\"diff_cw_arm_kp1\"].append(core[\"diff_cw_armul_kp1\"])\n",
    "\n",
    "    res_by_lam = {}\n",
    "    for fixed_lambda, agg in aggs.items():\n",
    "        TH_true_all      = np.concatenate(agg[\"TH_true\"])\n",
    "        TH_ind_all       = np.concatenate(agg[\"TH_ind\"])\n",
    "        TH_fx_all        = np.concatenate(agg[\"TH_fx\"])\n",
    "        TH_ad_all        = np.concatenate(agg[\"TH_ad\"])\n",
    "        TH_arm_km1_all   = np.concatenate(agg[\"TH_arm_km1\"])\n",
    "        TH_arm_k_all     = np.concatenate(agg[\"TH_arm_k\"])\n",
    "        TH_arm_kp1_all   = np.concatenate(agg[\"TH_arm_kp1\"])\n",
    "\n",
    "        zI_all    = np.concatenate(agg[\"zI\"])\n",
    "        zA_fx_all = np.concatenate(agg[\"zA_fx\"])\n",
    "        zA_ad_all = np.concatenate(agg[\"zA_ad\"])\n",
    "\n",
    "        ARI_ind_arr      = np.array(agg[\"ARI_ind\"])\n",
    "        ARI_fx_arr       = np.array(agg[\"ARI_fx\"])\n",
    "        ARI_ad_arr       = np.array(agg[\"ARI_ad\"])\n",
    "        ARI_arm_km1_arr  = np.array(agg[\"ARI_arm_km1\"])\n",
    "        ARI_arm_k_arr    = np.array(agg[\"ARI_arm_k\"])\n",
    "        ARI_arm_kp1_arr  = np.array(agg[\"ARI_arm_kp1\"])\n",
    "\n",
    "        rmse_ind_arr     = np.array(agg[\"rmse_ind\"])\n",
    "        rmse_fx_arr      = np.array(agg[\"rmse_fx\"])\n",
    "        rmse_ad_arr      = np.array(agg[\"rmse_ad\"])\n",
    "        rmse_arm_km1_arr = np.array(agg[\"rmse_arm_km1\"])\n",
    "        rmse_arm_k_arr   = np.array(agg[\"rmse_arm_k\"])\n",
    "        rmse_arm_kp1_arr = np.array(agg[\"rmse_arm_kp1\"])\n",
    "\n",
    "        armul_K_sel_km1_all = np.array(agg[\"armul_K_sel_km1\"])\n",
    "        armul_K_sel_k_all   = np.array(agg[\"armul_K_sel_k\"])\n",
    "        armul_K_sel_kp1_all = np.array(agg[\"armul_K_sel_kp1\"])\n",
    "        armul_lambda_all    = np.array(agg[\"armul_lambda\"])\n",
    "\n",
    "        cw_rmse_ind_arr     = np.array(agg[\"cw_rmse_ind\"])\n",
    "        cw_rmse_fx_arr      = np.array(agg[\"cw_rmse_fx\"])\n",
    "        cw_rmse_ad_arr      = np.array(agg[\"cw_rmse_ad\"])\n",
    "        cw_rmse_arm_km1_arr = np.array(agg[\"cw_rmse_arm_km1\"])\n",
    "        cw_rmse_arm_k_arr   = np.array(agg[\"cw_rmse_arm_k\"])\n",
    "        cw_rmse_arm_kp1_arr = np.array(agg[\"cw_rmse_arm_kp1\"])\n",
    "\n",
    "        diff_cw_ind_all     = np.concatenate(agg[\"diff_cw_ind\"])\n",
    "        diff_cw_fx_all      = np.concatenate(agg[\"diff_cw_fx\"])\n",
    "        diff_cw_ad_all      = np.concatenate(agg[\"diff_cw_ad\"])\n",
    "        diff_cw_arm_km1_all = np.concatenate(agg[\"diff_cw_arm_km1\"])\n",
    "        diff_cw_arm_k_all   = np.concatenate(agg[\"diff_cw_arm_k\"])\n",
    "        diff_cw_arm_kp1_all = np.concatenate(agg[\"diff_cw_arm_kp1\"])\n",
    "\n",
    "        res = {\n",
    "            # estimators\n",
    "            \"theta_true_all\": TH_true_all,\n",
    "            \"theta_ind_all\": TH_ind_all,\n",
    "            \"theta_fx_all\": TH_fx_all,\n",
    "            \"theta_ad_all\": TH_ad_all,\n",
    "            \"theta_armul_km1_all\": TH_arm_km1_all,\n",
    "            \"theta_armul_k_all\": TH_arm_k_all,\n",
    "            \"theta_armul_kp1_all\": TH_arm_kp1_all,\n",
    "\n",
    "            # z-scores (base three)\n",
    "            \"zI\": zI_all,\n",
    "            \"zA_fx\": zA_fx_all,\n",
    "            \"zA_ad\": zA_ad_all,\n",
    "\n",
    "            # ARI per round\n",
    "            \"ARI_ind\": ARI_ind_arr,\n",
    "            \"ARI_fx\": ARI_fx_arr,\n",
    "            \"ARI_ad\": ARI_ad_arr,\n",
    "            \"ARI_armul_km1\": ARI_arm_km1_arr,\n",
    "            \"ARI_armul_k\":   ARI_arm_k_arr,\n",
    "            \"ARI_armul_kp1\": ARI_arm_kp1_arr,\n",
    "\n",
    "            # RMSE per round\n",
    "            \"rmse_ind_all\":   rmse_ind_arr,\n",
    "            \"rmse_fx_all\":    rmse_fx_arr,\n",
    "            \"rmse_ad_all\":    rmse_ad_arr,\n",
    "            \"rmse_armul_km1_all\": rmse_arm_km1_arr,\n",
    "            \"rmse_armul_k_all\":   rmse_arm_k_arr,\n",
    "            \"rmse_armul_kp1_all\": rmse_arm_kp1_arr,\n",
    "\n",
    "            # summaries (unweighted)\n",
    "            \"mean_rmse_ind\": float(np.mean(rmse_ind_arr)),\n",
    "            \"mean_rmse_fx\":  float(np.mean(rmse_fx_arr)),\n",
    "            \"mean_rmse_ad\":  float(np.mean(rmse_ad_arr)),\n",
    "            \"mean_rmse_armul_km1\": float(np.mean(rmse_arm_km1_arr)),\n",
    "            \"mean_rmse_armul_k\":   float(np.mean(rmse_arm_k_arr)),\n",
    "            \"mean_rmse_armul_kp1\": float(np.mean(rmse_arm_kp1_arr)),\n",
    "            \"mean_ari_ind\": float(np.mean(ARI_ind_arr)),\n",
    "            \"mean_ari_fx\":  float(np.mean(ARI_fx_arr)),\n",
    "            \"mean_ari_ad\":  float(np.mean(ARI_ad_arr)),\n",
    "            \"mean_ari_armul_km1\": float(np.mean(ARI_arm_km1_arr)),\n",
    "            \"mean_ari_armul_k\":   float(np.mean(ARI_arm_k_arr)),\n",
    "            \"mean_ari_armul_kp1\": float(np.mean(ARI_arm_kp1_arr)),\n",
    "\n",
    "            # selections & diagnostics\n",
    "            \"armul_K_sel_km1_all\": armul_K_sel_km1_all,\n",
    "            \"armul_K_sel_k_all\":   armul_K_sel_k_all,\n",
    "            \"armul_K_sel_kp1_all\": armul_K_sel_kp1_all,\n",
    "            \"armul_lambda_all\": armul_lambda_all,\n",
    "\n",
    "            # cluster-size–weighted RMSE (per round)\n",
    "            \"cw_rmse_ind_all\":     cw_rmse_ind_arr,\n",
    "            \"cw_rmse_fx_all\":      cw_rmse_fx_arr,\n",
    "            \"cw_rmse_ad_all\":      cw_rmse_ad_arr,\n",
    "            \"cw_rmse_armul_km1_all\": cw_rmse_arm_km1_arr,\n",
    "            \"cw_rmse_armul_k_all\":   cw_rmse_arm_k_arr,\n",
    "            \"cw_rmse_armul_kp1_all\": cw_rmse_arm_kp1_arr,\n",
    "\n",
    "            # cluster-size–weighted differences (all rounds, all tasks)\n",
    "            \"diff_cw_ind_all\":     diff_cw_ind_all,\n",
    "            \"diff_cw_fx_all\":      diff_cw_fx_all,\n",
    "            \"diff_cw_ad_all\":      diff_cw_ad_all,\n",
    "            \"diff_cw_armul_km1_all\": diff_cw_arm_km1_all,\n",
    "            \"diff_cw_armul_k_all\":   diff_cw_arm_k_all,\n",
    "            \"diff_cw_armul_kp1_all\": diff_cw_arm_kp1_all,\n",
    "\n",
    "            # mean cluster-size–weighted RMSE\n",
    "            \"mean_cw_rmse_ind\": float(np.mean(cw_rmse_ind_arr)),\n",
    "            \"mean_cw_rmse_fx\":  float(np.mean(cw_rmse_fx_arr)),\n",
    "            \"mean_cw_rmse_ad\":  float(np.mean(cw_rmse_ad_arr)),\n",
    "            \"mean_cw_rmse_armul_km1\": float(np.mean(cw_rmse_arm_km1_arr)),\n",
    "            \"mean_cw_rmse_armul_k\":   float(np.mean(cw_rmse_arm_k_arr)),\n",
    "            \"mean_cw_rmse_armul_kp1\": float(np.mean(cw_rmse_arm_kp1_arr)),\n",
    "        }\n",
    "\n",
    "        res_by_lam[fixed_lambda] = res\n",
    "\n",
    "    return res_by_lam\n",
    "\n",
    "# =====================\n",
    "# Saving helpers\n",
    "# =====================\n",
    "def save_delta_results(outdir, delta, res_by_lam):\n",
    "    fixed_lams = np.array(sorted(res_by_lam.keys()), dtype=float)\n",
    "    npz_path = os.path.join(outdir, f\"continuous_delta{delta:.2f}.npz\")\n",
    "\n",
    "    to_save = {}\n",
    "    to_save[\"fixed_lambda_values\"] = fixed_lams\n",
    "\n",
    "    rep_res = res_by_lam[fixed_lams[0]]\n",
    "\n",
    "    shared_array_keys = [\n",
    "        \"theta_true_all\",\n",
    "        \"theta_ind_all\",\n",
    "        \"theta_ad_all\",\n",
    "        \"theta_armul_km1_all\",\n",
    "        \"theta_armul_k_all\",\n",
    "        \"theta_armul_kp1_all\",\n",
    "        \"zI\",\n",
    "        \"zA_ad\",\n",
    "        \"ARI_ind\",\n",
    "        \"ARI_ad\",\n",
    "        \"ARI_armul_km1\",\n",
    "        \"ARI_armul_k\",\n",
    "        \"ARI_armul_kp1\",\n",
    "        \"rmse_ind_all\",\n",
    "        \"rmse_ad_all\",\n",
    "        \"rmse_armul_km1_all\",\n",
    "        \"rmse_armul_k_all\",\n",
    "        \"rmse_armul_kp1_all\",\n",
    "        \"armul_K_sel_km1_all\",\n",
    "        \"armul_K_sel_k_all\",\n",
    "        \"armul_K_sel_kp1_all\",\n",
    "        \"armul_lambda_all\",\n",
    "\n",
    "        \"cw_rmse_ind_all\",\n",
    "        \"cw_rmse_ad_all\",\n",
    "        \"cw_rmse_armul_km1_all\",\n",
    "        \"cw_rmse_armul_k_all\",\n",
    "        \"cw_rmse_armul_kp1_all\",\n",
    "        \"diff_cw_ind_all\",\n",
    "        \"diff_cw_ad_all\",\n",
    "        \"diff_cw_armul_km1_all\",\n",
    "        \"diff_cw_armul_k_all\",\n",
    "        \"diff_cw_armul_kp1_all\",\n",
    "    ]\n",
    "\n",
    "    shared_summary_keys = [\n",
    "        \"mean_rmse_ind\",\n",
    "        \"mean_rmse_ad\",\n",
    "        \"mean_rmse_armul_km1\",\n",
    "        \"mean_rmse_armul_k\",\n",
    "        \"mean_rmse_armul_kp1\",\n",
    "        \"mean_ari_ind\",\n",
    "        \"mean_ari_ad\",\n",
    "        \"mean_ari_armul_km1\",\n",
    "        \"mean_ari_armul_k\",\n",
    "        \"mean_ari_armul_kp1\",\n",
    "\n",
    "        \"mean_cw_rmse_ind\",\n",
    "        \"mean_cw_rmse_ad\",\n",
    "        \"mean_cw_rmse_armul_km1\",\n",
    "        \"mean_cw_rmse_armul_k\",\n",
    "        \"mean_cw_rmse_armul_kp1\",\n",
    "    ]\n",
    "\n",
    "    for key in shared_array_keys:\n",
    "        to_save[key] = rep_res[key]\n",
    "    for key in shared_summary_keys:\n",
    "        to_save[key] = np.array(rep_res[key], dtype=float)\n",
    "\n",
    "    per_lam_array_keys = [\n",
    "        \"theta_fx_all\",\n",
    "        \"zA_fx\",\n",
    "        \"ARI_fx\",\n",
    "        \"rmse_fx_all\",\n",
    "        \"cw_rmse_fx_all\",\n",
    "        \"diff_cw_fx_all\",\n",
    "    ]\n",
    "    per_lam_summary_keys = [\n",
    "        \"mean_rmse_fx\",\n",
    "        \"mean_ari_fx\",\n",
    "        \"mean_cw_rmse_fx\",\n",
    "    ]\n",
    "\n",
    "    for k, lam in enumerate(fixed_lams):\n",
    "        res = res_by_lam[lam]\n",
    "        prefix = f\"lam{k}_\"\n",
    "        for key in per_lam_array_keys:\n",
    "            to_save[prefix + key] = res[key]\n",
    "        for key in per_lam_summary_keys:\n",
    "            to_save[prefix + key] = np.array(res[key], dtype=float)\n",
    "\n",
    "    np.savez(npz_path, **to_save)\n",
    "\n",
    "\n",
    "def append_summary_csv(outdir, delta, res_by_lam):\n",
    "    \"\"\"\n",
    "    One CSV file, with one row per (delta, lambda).\n",
    "    \"\"\"\n",
    "    csv_path = os.path.join(outdir, \"summary.csv\")\n",
    "    write_header = not os.path.exists(csv_path)\n",
    "\n",
    "    import csv\n",
    "    fieldnames = [\n",
    "        \"model\",\n",
    "        \"delta\",\n",
    "        \"fixed_lambda\",\n",
    "        \"rounds\",\n",
    "        \"m_tasks\",\n",
    "        \"mean_rmse_individual\",\n",
    "        \"mean_rmse_fixed\",\n",
    "        \"mean_rmse_adaptive\",\n",
    "        \"mean_rmse_armul_km1\",\n",
    "        \"mean_rmse_armul_k\",\n",
    "        \"mean_rmse_armul_kp1\",\n",
    "        \"mean_ari_individual\",\n",
    "        \"mean_ari_fixed\",\n",
    "        \"mean_ari_adaptive\",\n",
    "        \"mean_ari_armul_km1\",\n",
    "        \"mean_ari_armul_k\",\n",
    "        \"mean_ari_armul_kp1\",\n",
    "        \"notes\",\n",
    "    ]\n",
    "\n",
    "    with open(csv_path, \"a\", newline=\"\") as f:\n",
    "        w = csv.DictWriter(f, fieldnames=fieldnames)\n",
    "        if write_header:\n",
    "            w.writeheader()\n",
    "\n",
    "        for fixed_lambda, res in res_by_lam.items():\n",
    "            row = {\n",
    "                \"model\": \"continuous_lgbm_pooledSE_plus_ARMUL\",\n",
    "                \"delta\": float(delta),\n",
    "                \"fixed_lambda\": float(fixed_lambda),\n",
    "                \"rounds\": ROUNDS,\n",
    "                \"m_tasks\": M_TASKS,\n",
    "                \"mean_rmse_individual\": res[\"mean_rmse_ind\"],\n",
    "                \"mean_rmse_fixed\":      res[\"mean_rmse_fx\"],\n",
    "                \"mean_rmse_adaptive\":   res[\"mean_rmse_ad\"],\n",
    "                \"mean_rmse_armul_km1\":  res[\"mean_rmse_armul_km1\"],\n",
    "                \"mean_rmse_armul_k\":    res[\"mean_rmse_armul_k\"],\n",
    "                \"mean_rmse_armul_kp1\":  res[\"mean_rmse_armul_kp1\"],\n",
    "                \"mean_ari_individual\":  res[\"mean_ari_ind\"],\n",
    "                \"mean_ari_fixed\":       res[\"mean_ari_fx\"],\n",
    "                \"mean_ari_adaptive\":    res[\"mean_ari_ad\"],\n",
    "                \"mean_ari_armul_km1\":   res[\"mean_ari_armul_km1\"],\n",
    "                \"mean_ari_armul_k\":     res[\"mean_ari_armul_k\"],\n",
    "                \"mean_ari_armul_kp1\":   res[\"mean_ari_armul_kp1\"],\n",
    "                \"notes\": (\n",
    "                    f\"PLM with orthogonal residuals, LightGBM nuisances, \"\n",
    "                    f\"ARMUL per-task λ_j = C_LAM * sqrt(1 / n_j) with n_j=|T_j|, C_LAM={C_LAM}. \"\n",
    "                    f\"Variants K-1/K/K+1 around true K={K_CLUST}. \"\n",
    "                    \"Clusterwise IF SE via pooled per-task var_b; no nuisance refit. \"\n",
    "                    \"Adaptive fused penalties built from individual PLM estimator b_T. \"\n",
    "                    f\"Fixed fused λ = {fixed_lambda}.\"\n",
    "                ),\n",
    "            }\n",
    "            w.writerow(row)\n",
    "\n",
    "# =====================\n",
    "# Main\n",
    "# =====================\n",
    "if __name__ == \"__main__\":\n",
    "    timestamp = datetime.now().strftime(\"%Y%m%d_%H%M%S\")\n",
    "    base_out = ensure_dir(os.path.join(\"results_plm_cont\", timestamp))\n",
    "    model_out = ensure_dir(os.path.join(base_out, \"plm_continuous_treatment\"))\n",
    "\n",
    "    print(f\"Running PLM-continuous with IND / FIX / ADAPT and ARMUL(K-1,K,K+1). \"\n",
    "          f\"Rounds={ROUNDS}, m={M_TASKS}, true K={K_CLUST}\")\n",
    "    print(f\"Fixed λ grid (per delta): {FIXED_LAMBDA_LIST}\")\n",
    "\n",
    "    for delta in DELTAS:\n",
    "        print(f\"\\n=== delta = {delta} ===\")\n",
    "        res_by_lam = run_for_delta_fixed_grid(\n",
    "            rounds=ROUNDS,\n",
    "            delta=delta,\n",
    "            start_seed=SEED_BASE,\n",
    "            fixed_lambda_list=FIXED_LAMBDA_LIST,\n",
    "        )\n",
    "\n",
    "        lam0 = sorted(res_by_lam.keys())[0]\n",
    "        shared = res_by_lam[lam0]\n",
    "        print(\"Shared metrics (IND / ADAPT / ARMUL):\")\n",
    "        print({\n",
    "            \"delta\": delta,\n",
    "            \"Mean_RMSE_Indiv\":        shared[\"mean_rmse_ind\"],\n",
    "            \"Mean_RMSE_Adapt\":        shared[\"mean_rmse_ad\"],\n",
    "            \"Mean_RMSE_ARMUL(K-1)\":   shared[\"mean_rmse_armul_km1\"],\n",
    "            \"Mean_RMSE_ARMUL(K)\":     shared[\"mean_rmse_armul_k\"],\n",
    "            \"Mean_RMSE_ARMUL(K+1)\":   shared[\"mean_rmse_armul_kp1\"],\n",
    "            \"Mean_ARI_Indiv\":         shared[\"mean_ari_ind\"],\n",
    "            \"Mean_ARI_Adapt\":         shared[\"mean_ari_ad\"],\n",
    "            \"Mean_ARI_ARMUL(K-1)\":    shared[\"mean_ari_armul_km1\"],\n",
    "            \"Mean_ARI_ARMUL(K)\":      shared[\"mean_ari_armul_k\"],\n",
    "            \"Mean_ARI_ARMUL(K+1)\":    shared[\"mean_ari_armul_kp1\"],\n",
    "        })\n",
    "\n",
    "        print(\"Fixed-lambda sweep (only FIX metrics):\")\n",
    "        for lam in sorted(res_by_lam.keys()):\n",
    "            res = res_by_lam[lam]\n",
    "            print({\n",
    "                \"delta\": delta,\n",
    "                \"fixed_lambda\": lam,\n",
    "                \"Mean_RMSE_Fixed\": res[\"mean_rmse_fx\"],\n",
    "                \"Mean_ARI_Fixed\":  res[\"mean_ari_fx\"],\n",
    "            })\n",
    "\n",
    "        save_delta_results(model_out, delta, res_by_lam)\n",
    "        append_summary_csv(base_out, delta, res_by_lam)\n",
    "\n",
    "    manifest = {\n",
    "        \"model\": \"continuous_lgbm_pooledSE_plus_ARMUL\",\n",
    "        \"deltas\": DELTAS,\n",
    "        \"rounds\": ROUNDS,\n",
    "        \"m_tasks\": M_TASKS,\n",
    "        \"timestamp\": timestamp,\n",
    "        \"true_K\": K_CLUST,\n",
    "        \"centers\": \"theta_k = k*delta - (K+1)delta/2, k=1..K\",\n",
    "        \"nuisance_target_split\": {\"A_frac\": A_FRAC, \"role_swap\": False},\n",
    "        \"fusion_hyperparams\": {\n",
    "            \"adaptive\": {\n",
    "                \"gamma\": GAMMA_DEFAULT,\n",
    "                \"tau\": ADAPT_TAU_DEFAULT,\n",
    "                \"eps\": ADAPT_EPS,\n",
    "                \"c_w\": C_W,\n",
    "                \"init_from\": \"individual PLM estimator b_T\",\n",
    "            },\n",
    "            \"fixed_lambda_list\": FIXED_LAMBDA_LIST,\n",
    "        },\n",
    "        \"armul\": {\n",
    "            \"C_LAM\": C_LAM,\n",
    "            \"variants\": [\"K-1\", \"K (oracle)\", \"K+1\"],\n",
    "            \"n_init\": ARMUL_N_INIT,\n",
    "            \"n_iter\": ARMUL_N_ITER,\n",
    "            \"tol\": ARMUL_TOL,\n",
    "            \"seed\": ARMUL_SEED,\n",
    "            \"lambda_rule\": \"λ_j = C_LAM * sqrt(1 / n_j), n_j = |T_j|\",\n",
    "        },\n",
    "        \"nuisance\": {\n",
    "            \"type\": \"LGBMRegressor\",\n",
    "            \"params\": {\n",
    "                \"n_estimators\": LGBM_ESTIMATORS,\n",
    "                \"learning_rate\": LGBM_LEARNING_RATE,\n",
    "                \"num_leaves\": LGBM_NUM_LEAVES,\n",
    "                \"min_child_samples\": LGBM_MIN_CHILD_SAMPLES,\n",
    "                \"subsample\": LGBM_SUBSAMPLE,\n",
    "                \"colsample_bytree\": LGBM_COLSAMPLE,\n",
    "                \"reg_alpha\": LGBM_REG_ALPHA,\n",
    "                \"reg_lambda\": LGBM_REG_LAMBDA,\n",
    "                \"n_jobs\": N_JOBS,\n",
    "            },\n",
    "        },\n",
    "        \"inference\": (\n",
    "            \"Orthogonal PLM; per-task IF var from rD*(rY-θrD); \"\n",
    "            \"clusterwise SE via pooled var_b over tasks in each cluster; \"\n",
    "            \"no nuisance refit. Adaptive fused penalties built from individual PLM b_T.\"\n",
    "        ),\n",
    "    }\n",
    "    with open(os.path.join(base_out, \"manifest.json\"), \"w\") as f:\n",
    "        json.dump(manifest, f, indent=2)\n",
    "\n",
    "    print(f\"\\nAll results saved under: ./results_plm_cont/{timestamp}/\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "efb16809-b65d-4d64-b8ff-9998f27d9a7d",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
