{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6173701f-12a9-40f8-9501-5c5fb96b4664",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import json\n",
    "from datetime import datetime\n",
    "\n",
    "import numpy as np\n",
    "from tqdm import tqdm\n",
    "from sklearn.metrics import adjusted_rand_score\n",
    "from sklearn.cluster import KMeans\n",
    "from lightgbm import LGBMRegressor, LGBMClassifier\n",
    "\n",
    "# =====================\n",
    "# Hyperparameters\n",
    "# =====================\n",
    "SEED_BASE   = 1          # base offset of the per-task A/T split seeds (see one_round_core)\n",
    "ROUNDS      = 100        # number of simulation rounds; not referenced in the code shown here, presumably read by the driver\n",
    "M_TASKS     = 20         # number of tasks\n",
    "K_CLUST     = 3          # true latent clusters in DGP\n",
    "DELTAS      = [1/3, 2/3, 1]  # cluster-separation values; presumably swept by the driver -- TODO confirm\n",
    "\n",
    "A_FRAC      = 0.5        # split-A fraction for nuisances\n",
    "\n",
    "# Noise\n",
    "NOISE_Y     = 0.3        # std. dev. of the Gaussian outcome noise in both periods\n",
    "\n",
    "# LightGBM\n",
    "# Shared settings of the regressor and classifier built by _lgbm_reg / _lgbm_clf.\n",
    "N_JOBS = 1\n",
    "LGBM_ESTIMATORS        = 100\n",
    "LGBM_LEARNING_RATE     = 0.03\n",
    "LGBM_NUM_LEAVES        = 20\n",
    "LGBM_MIN_CHILD_SAMPLES = 20\n",
    "LGBM_SUBSAMPLE         = 0.9\n",
    "LGBM_COLSAMPLE         = 0.9\n",
    "LGBM_REG_ALPHA         = 0.0\n",
    "LGBM_REG_LAMBDA        = 0.0\n",
    "LGBM_SEED_OFFSET       = 1000  # task j uses seed LGBM_SEED_OFFSET + j\n",
    "\n",
    "# Graph-fused (Adaptive)\n",
    "C_W               = .1     # scale c_w of the adaptive edge weights c_w * (|diff| + 1e-8) ** (-gamma)\n",
    "ADAPT_EPS         = 1e-12  # penalty given to edges whose weight is <= tau\n",
    "GAMMA_DEFAULT     = 2      # exponent gamma of the adaptive edge weights\n",
    "ADAPT_TAU_DEFAULT = 10     # weight threshold tau passed to build_lambda_adaptive\n",
    "\n",
    "# Fixed-lambda sweep\n",
    "FIXED_LAMBDA_LIST = [1e-5, 1e-4, 1e-3, 1e-2, 1e-1]\n",
    "\n",
    "# ARMUL controls\n",
    "C_LAM            = 100.0   # constant c in lambda_j = c * sqrt(1 / n_j)\n",
    "ARMUL_K_MAX      = M_TASKS  # not referenced in the code shown here\n",
    "ARMUL_N_INIT     = 1       # number of initializations per K\n",
    "ARMUL_N_ITER     = 300     # max alternating iterations per initialization\n",
    "ARMUL_TOL        = 1e-7    # stop when the objective decrease is <= this value\n",
    "ARMUL_SEED       = 123\n",
    "\n",
    "EPS = 1e-12  # generic positive floor used in divisions and square roots\n",
    "\n",
    "# =====================\n",
    "# Utilities\n",
    "# =====================\n",
    "\n",
    "def ensure_dir(path: str):\n",
    "    \"\"\"Create directory ``path`` (with missing parents) if absent and return ``path``.\"\"\"\n",
    "    os.makedirs(name=path, exist_ok=True)\n",
    "    return path\n",
    "\n",
    "\n",
    "def split_two_parts(n, a_frac, seed):\n",
    "    \"\"\"Randomly partition the indices 0..n-1 into two disjoint arrays.\n",
    "\n",
    "    The indices are shuffled with ``np.random.default_rng(seed)``; the first\n",
    "    ``floor(a_frac * n)`` shuffled indices form the first part and the\n",
    "    remainder forms the second part.\n",
    "    \"\"\"\n",
    "    shuffled = np.arange(n)\n",
    "    np.random.default_rng(seed).shuffle(shuffled)\n",
    "    cut = int(np.floor(a_frac * n))\n",
    "    first_part, second_part = shuffled[:cut], shuffled[cut:]\n",
    "    return first_part, second_part\n",
    "\n",
    "\n",
    "def sigmoid(x):\n",
    "    \"\"\"Logistic function 1 / (1 + exp(-x)), evaluated elementwise with NumPy.\"\"\"\n",
    "    denom = 1 + np.exp(-x)\n",
    "    return 1 / denom\n",
    "\n",
    "\n",
    "# =====================\n",
    "# DID Data Generating Process\n",
    "# =====================\n",
    "\n",
    "def make_did_tasks(m: int, K: int, delta: float, seed: int):\n",
    "    \"\"\"Panel DID DGP in the spirit of Sant'Anna & Zhao (2020).\n",
    "\n",
    "    For each task j and unit i:\n",
    "      Y_{ji0} = μ_{j0}(X_{ji}) + ε_{ji0}\n",
    "      Y_{ji1} = μ_{j1}(X_{ji}) + τ_j * D_{ji} + ε_{ji1}\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    m : int\n",
    "        Number of tasks to simulate.\n",
    "    K : int\n",
    "        Number of latent clusters; every task in a cluster shares one effect τ.\n",
    "    delta : float\n",
    "        Spacing between consecutive cluster effects (the centers average to 0).\n",
    "    seed : int\n",
    "        Seed of the single ``np.random.default_rng`` that drives every draw.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    tasks : list of dict\n",
    "        One dict per task with keys ``X``, ``D``, ``T``, ``Y`` in long format:\n",
    "        the first n_j rows are period 0 and the last n_j rows are period 1.\n",
    "    tau_true : np.ndarray of shape (m,)\n",
    "        True effect τ_j of each task.\n",
    "    cluster_ids : np.ndarray of shape (m,)\n",
    "        Cluster index of each task.\n",
    "    \"\"\"\n",
    "    rng = np.random.default_rng(seed)\n",
    "    # Cluster effects k * delta - (K + 1) * delta / 2 for k = 1..K.\n",
    "    tau_centers = np.arange(1, K + 1, dtype=float) * delta - (K + 1) * delta / 2.0\n",
    "    # Blocks of ceil(m / K) labels truncated to m entries, then shuffled so\n",
    "    # that the task-to-cluster assignment is random.\n",
    "    cluster_ids = np.repeat(np.arange(K), repeats=int(np.ceil(m / K)))[:m]\n",
    "    rng.shuffle(cluster_ids)\n",
    "\n",
    "    tasks = []\n",
    "    tau_true = np.zeros(m)\n",
    "\n",
    "    for j in range(m):\n",
    "        n_j = 3200 + 80 * j        # units per task\n",
    "        p_j = 5 + j                # covariates per task\n",
    "\n",
    "        X = rng.normal(0.0, 1.0, size=(n_j, p_j))\n",
    "\n",
    "        # Group assignment D ~ Bern(sigmoid score)\n",
    "        # The score reads covariate columns 0, 1, 3 and 4 only; the\n",
    "        # probability is clipped to [0.05, 0.95].\n",
    "        score_D = X[:, 3] * X[:, 4] - X[:, 0] * X[:, 1]\n",
    "        p = np.clip(sigmoid(score_D), 0.05, 0.95)\n",
    "        D = rng.binomial(1, p).astype(float)\n",
    "\n",
    "        # Baselines μ_{j0}(X), μ_{j1}(X)\n",
    "        # Both are sums of logistic transforms of each covariate with\n",
    "        # geometric weights 0.7 ** (i + 1); μ_{j1} alternates the sign.\n",
    "        mu0 = 0.0\n",
    "        mu1 = 0.0\n",
    "        for i in range(p_j):\n",
    "            sig = np.exp(X[:, i]) / (1.0 + np.exp(X[:, i]))\n",
    "            mu0 += (+0.7) ** (i+1) * sig\n",
    "            mu1 += (-0.7) ** (i+1) * sig\n",
    "\n",
    "        # Clustered effect τ_j\n",
    "        k = cluster_ids[j]\n",
    "        tau = tau_centers[k]\n",
    "        tau_true[j] = tau\n",
    "\n",
    "        # Outcomes for t=0,1 (panel)\n",
    "        # The treatment effect enters only in period 1 and only for D = 1.\n",
    "        eps0 = rng.normal(0.0, NOISE_Y, size=n_j)\n",
    "        eps1 = rng.normal(0.0, NOISE_Y, size=n_j)\n",
    "        Y0 = mu0 + eps0\n",
    "        Y1 = mu1 + tau * D + eps1\n",
    "\n",
    "        # Stack long: (X, D, T, Y) with T∈{0,1},\n",
    "        # first n_j rows: t=0, second n_j rows: t=1.\n",
    "        X_long = np.vstack([X, X])\n",
    "        D_long = np.concatenate([D, D])\n",
    "        T_long = np.concatenate([np.zeros(n_j), np.ones(n_j)])\n",
    "        Y_long = np.concatenate([Y0, Y1])\n",
    "\n",
    "        tasks.append({\"X\": X_long, \"D\": D_long, \"T\": T_long, \"Y\": Y_long})\n",
    "\n",
    "    return tasks, tau_true, cluster_ids\n",
    "\n",
    "\n",
    "# =====================\n",
    "# Models\n",
    "# =====================\n",
    "\n",
    "def _lgbm_reg(seed):\n",
    "    \"\"\"Return an unfitted LGBMRegressor configured from the module-level LGBM_* constants.\n",
    "\n",
    "    ``seed`` is passed as ``random_state`` and ``N_JOBS`` as ``n_jobs``.\n",
    "    \"\"\"\n",
    "    params = dict(\n",
    "        n_estimators=LGBM_ESTIMATORS,\n",
    "        learning_rate=LGBM_LEARNING_RATE,\n",
    "        num_leaves=LGBM_NUM_LEAVES,\n",
    "        min_child_samples=LGBM_MIN_CHILD_SAMPLES,\n",
    "        subsample=LGBM_SUBSAMPLE,\n",
    "        colsample_bytree=LGBM_COLSAMPLE,\n",
    "        reg_alpha=LGBM_REG_ALPHA,\n",
    "        reg_lambda=LGBM_REG_LAMBDA,\n",
    "        objective=\"regression\",\n",
    "        random_state=seed,\n",
    "        n_jobs=N_JOBS,\n",
    "        verbose=-1,\n",
    "    )\n",
    "    regressor = LGBMRegressor(**params)\n",
    "    return regressor\n",
    "\n",
    "\n",
    "def _lgbm_clf(seed):\n",
    "    \"\"\"Return an unfitted binary LGBMClassifier configured from the module-level LGBM_* constants.\n",
    "\n",
    "    ``seed`` is passed as ``random_state`` and ``N_JOBS`` as ``n_jobs``.\n",
    "    \"\"\"\n",
    "    params = dict(\n",
    "        n_estimators=LGBM_ESTIMATORS,\n",
    "        learning_rate=LGBM_LEARNING_RATE,\n",
    "        num_leaves=LGBM_NUM_LEAVES,\n",
    "        min_child_samples=LGBM_MIN_CHILD_SAMPLES,\n",
    "        subsample=LGBM_SUBSAMPLE,\n",
    "        colsample_bytree=LGBM_COLSAMPLE,\n",
    "        reg_alpha=LGBM_REG_ALPHA,\n",
    "        reg_lambda=LGBM_REG_LAMBDA,\n",
    "        objective=\"binary\",\n",
    "        random_state=seed,\n",
    "        n_jobs=N_JOBS,\n",
    "        verbose=-1,\n",
    "    )\n",
    "    classifier = LGBMClassifier(**params)\n",
    "    return classifier\n",
    "\n",
    "\n",
    "# =====================\n",
    "# Sant'Anna & Zhao (2020) panel DR DID stats per task\n",
    "# =====================\n",
    "\n",
    "def panel_dr_did_stats(X_long, D_long, T_long, Y_long, A_units, T_units, seed):\n",
    "    \"\"\"Compute (a_j, b_j, var_b_j) for one task j using panel DR DID.\n",
    "\n",
    "    The propensity model and the control outcome-change regression are fit\n",
    "    on ``A_units``; the estimate and its variance are evaluated on ``T_units``.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    X_long, D_long, Y_long : np.ndarray\n",
    "        Long-format arrays with 2 * n_units rows. The first half is read as\n",
    "        period 0 and the second half as period 1, assuming both halves list\n",
    "        the same units in the same order.\n",
    "    T_long : np.ndarray\n",
    "        Period indicator; not read by this function.\n",
    "    A_units, T_units : index arrays\n",
    "        Unit indices (into the first n_units rows) of the nuisance split and\n",
    "        of the evaluation split.\n",
    "    seed : int\n",
    "        Seed of the propensity model; the outcome model uses ``seed + 1``.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    a_hat : float\n",
    "        Mean of the treated weights w1 on ``T_units`` (floored at 1e-12).\n",
    "    b_hat : float\n",
    "        Ratio mean(A_i) / mean(B_i), the DR DID effect estimate.\n",
    "    var_b : float\n",
    "        mean(psi ** 2) / (n_T * B_bar ** 2) with psi = A_i - B_i * b_hat.\n",
    "    \"\"\"\n",
    "    EPS_LOC = 1e-12  # positive floor for the normalizers below\n",
    "\n",
    "    n_long = len(Y_long)\n",
    "    n_units = n_long // 2\n",
    "    assert n_long % 2 == 0\n",
    "\n",
    "    # panel representation\n",
    "    # Covariates and treatment group are taken from the period-0 rows only.\n",
    "    X_units = X_long[:n_units]\n",
    "    D_units = D_long[:n_units]\n",
    "    Y0 = Y_long[:n_units]\n",
    "    Y1 = Y_long[n_units:]\n",
    "    dY = Y1 - Y0\n",
    "\n",
    "    # ---- Nuisances on A_units ----\n",
    "    X_A = X_units[A_units]\n",
    "    D_A = D_units[A_units].astype(int)\n",
    "    dY_A = dY[A_units]\n",
    "\n",
    "    # Propensity p(X)\n",
    "    # pA are in-sample predictions on the rows the classifier was fit on.\n",
    "    mD = _lgbm_clf(seed)\n",
    "    mD.fit(X_A, D_A)\n",
    "    pA = np.clip(mD.predict_proba(X_A)[:, 1], 0.01, 0.99)\n",
    "\n",
    "    # Control trend m0Δ(X) = E[ΔY | D=0, X]\n",
    "    mask_ctrl_A = (D_A == 0)\n",
    "    if mask_ctrl_A.sum() < 10:\n",
    "        # fallback: use all A_units if few controls\n",
    "        # NOTE(review): this fits the trend on treated units too, whose ΔY\n",
    "        # contains the treatment effect -- confirm this fallback is acceptable.\n",
    "        X_ctrl_A = X_A\n",
    "        dY_ctrl_A = dY_A\n",
    "    else:\n",
    "        X_ctrl_A = X_A[mask_ctrl_A]\n",
    "        dY_ctrl_A = dY_A[mask_ctrl_A]\n",
    "\n",
    "    m0d = _lgbm_reg(seed + 1)\n",
    "    m0d.fit(X_ctrl_A, dY_ctrl_A)\n",
    "\n",
    "    # Normalizing expectations on A_units\n",
    "    # D_bar_hat: treated share; v_bar_hat: mean odds weight p / (1 - p) over controls.\n",
    "    D_bar_hat = float(D_A.mean())\n",
    "    D_bar_hat = max(D_bar_hat, EPS_LOC)\n",
    "    vA = pA * (1.0 - D_A) / (1.0 - pA)\n",
    "    v_bar_hat = float(vA.mean())\n",
    "    v_bar_hat = max(v_bar_hat, EPS_LOC)\n",
    "\n",
    "    # ---- Evaluation on T_units, reusing the models fit on A_units ----\n",
    "    X_T = X_units[T_units]\n",
    "    D_T = D_units[T_units].astype(int)\n",
    "    dY_T = dY[T_units]\n",
    "\n",
    "    pT = np.clip(mD.predict_proba(X_T)[:, 1], 0.01, 0.99)\n",
    "    # w1: treated weights; w0: odds-reweighted control weights. Both are\n",
    "    # normalized by the A-split means computed above, not by T-split means.\n",
    "    w1 = D_T / D_bar_hat\n",
    "    vT = pT * (1.0 - D_T) / (1.0 - pT)\n",
    "    w0 = vT / v_bar_hat\n",
    "    m0d_T = m0d.predict(X_T)\n",
    "\n",
    "    # Per-unit numerator (A_i) and denominator (B_i) of the ratio estimator.\n",
    "    A_i = (w1 - w0) * (dY_T - m0d_T)\n",
    "    B_i = w1\n",
    "\n",
    "    B_bar = float(np.mean(B_i))\n",
    "    B_bar = max(B_bar, EPS_LOC)\n",
    "    A_bar = float(np.mean(A_i))\n",
    "\n",
    "    b_hat = A_bar / B_bar\n",
    "\n",
    "    # psi is the residual of the ratio estimating equation; its second\n",
    "    # moment gives the plug-in variance of b_hat.\n",
    "    psi = A_i - B_i * b_hat\n",
    "    n_T = len(T_units)\n",
    "    var_b = float(np.mean(psi ** 2) / (max(n_T, 1) * B_bar ** 2))\n",
    "\n",
    "    a_hat = B_bar\n",
    "    return a_hat, b_hat, var_b\n",
    "\n",
    "\n",
    "# =====================\n",
    "# Fusion helpers\n",
    "# =====================\n",
    "\n",
    "def build_lambda_fixed(m, lam):\n",
    "    \"\"\"Return the m x m float matrix with ``lam`` off the diagonal and 0 on it.\"\"\"\n",
    "    penalties = np.full((m, m), lam, dtype=float)\n",
    "    np.fill_diagonal(penalties, 0.0)\n",
    "    return penalties\n",
    "\n",
    "\n",
    "def build_lambda_adaptive(theta_init, c_w=C_W, gamma=GAMMA_DEFAULT,\n",
    "                          tau=None, eps=ADAPT_EPS, km_n_init=10, random_state=0):\n",
    "    \"\"\"Build the symmetric matrix of adaptive pairwise fusion penalties.\n",
    "\n",
    "    Each pair i < j gets the weight ``c_w * (|theta_i - theta_j| + 1e-8) ** (-gamma)``.\n",
    "    Weights that do not exceed the threshold ``tau`` are replaced by ``eps``;\n",
    "    larger weights are kept as the penalty. When ``tau`` is None it is set to\n",
    "    the exponential of the midpoint between the two 1-D KMeans centers of the\n",
    "    log-weights (or to the common weight if all log-weights coincide).\n",
    "    The diagonal is zero, and a zero matrix is returned when fewer than two\n",
    "    initial estimates are given.\n",
    "    \"\"\"\n",
    "    m = len(theta_init)\n",
    "    L = np.zeros((m, m), dtype=float)\n",
    "    if m < 2:\n",
    "        return L\n",
    "\n",
    "    # Upper-triangle pairs in row-major order; weights are stored in the same order.\n",
    "    pairs = [(i, j) for i in range(m) for j in range(i + 1, m)]\n",
    "    wvals = np.asarray(\n",
    "        [c_w * (abs(theta_init[i] - theta_init[j]) + 1e-8) ** (-gamma) for (i, j) in pairs],\n",
    "        dtype=float,\n",
    "    )\n",
    "\n",
    "    if tau is None:\n",
    "        z = np.log(np.maximum(wvals, 1e-18)).reshape(-1, 1)\n",
    "        if np.allclose(z, z[0]):\n",
    "            tau = float(np.exp(z[0, 0]))\n",
    "        else:\n",
    "            km = KMeans(n_clusters=2, n_init=km_n_init, random_state=random_state)\n",
    "            km.fit(z)\n",
    "            low_center, high_center = np.sort(km.cluster_centers_.ravel())\n",
    "            tau = float(np.exp(0.5 * (low_center + high_center)))\n",
    "\n",
    "    for (i, j), w in zip(pairs, wvals):\n",
    "        L[i, j] = L[j, i] = eps if w <= tau else w\n",
    "    return L\n",
    "\n",
    "\n",
    "def admm_graph_fused_lasso(a, b, Lambda, rho=1.0, max_iter=2000, tol=1e-6):\n",
    "    \"\"\"Solve a weighted fused-lasso problem on the complete graph with ADMM.\n",
    "\n",
    "    Minimizes over theta\n",
    "        0.5 * sum_j a[j] * (theta[j] - b[j]) ** 2\n",
    "        + sum_{i<j} Lambda[i, j] * |theta[i] - theta[j]|\n",
    "    using the splitting z = B theta, where B is the edge incidence matrix.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    a, b : array-like of shape (m,)\n",
    "        Quadratic weights and targets.\n",
    "    Lambda : array-like of shape (m, m)\n",
    "        Pairwise penalties; only entries with i < j are read.\n",
    "    rho : float\n",
    "        ADMM penalty parameter.\n",
    "    max_iter : int\n",
    "        Maximum number of ADMM iterations.\n",
    "    tol : float\n",
    "        Stop once both the primal and the dual residual norms are below it.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    np.ndarray of shape (m,)\n",
    "        The last theta iterate (a copy of ``b`` when there are no edges).\n",
    "        No warning is issued if ``max_iter`` is reached without convergence.\n",
    "    \"\"\"\n",
    "    a = np.asarray(a, dtype=float)\n",
    "    b = np.asarray(b, dtype=float)\n",
    "    Lambda = np.asarray(Lambda, dtype=float)\n",
    "\n",
    "    m = len(a)\n",
    "    edges = [(i, j) for i in range(m) for j in range(i + 1, m)]\n",
    "    E = len(edges)\n",
    "    if E == 0:\n",
    "        return b.copy()\n",
    "\n",
    "    B_mat = np.zeros((E, m))\n",
    "    for e, (i, j) in enumerate(edges):\n",
    "        B_mat[e, i] = 1.0\n",
    "        B_mat[e, j] = -1.0\n",
    "    lams = np.array([Lambda[i, j] for (i, j) in edges])\n",
    "\n",
    "    # Loop invariants, computed once instead of on every iteration.\n",
    "    M = np.diag(a) + rho * (B_mat.T @ B_mat)\n",
    "    rho_Bt = rho * B_mat.T          # scaled transpose used by both the update and the dual residual\n",
    "    ab = a * b                      # equals diag(a) @ b without the dense product\n",
    "    kappa = lams / rho              # soft-threshold levels\n",
    "\n",
    "    def soft(v, k):\n",
    "        return np.sign(v) * np.maximum(np.abs(v) - k, 0.0)\n",
    "\n",
    "    theta = b.copy()\n",
    "    z = B_mat @ theta\n",
    "    u = np.zeros_like(z)\n",
    "\n",
    "    for _ in range(max_iter):\n",
    "        theta_new = np.linalg.solve(M, ab + rho_Bt @ (z - u))\n",
    "        r = B_mat @ theta_new + u\n",
    "        z_new = soft(r, kappa)\n",
    "        u_new = r - z_new\n",
    "        primal_res = np.linalg.norm(B_mat @ theta_new - z_new)\n",
    "        dual_res = np.linalg.norm(rho_Bt @ (z_new - z))\n",
    "        theta, z, u = theta_new, z_new, u_new\n",
    "        if primal_res < tol and dual_res < tol:\n",
    "            break\n",
    "    return theta\n",
    "\n",
    "\n",
    "def clusters_from_theta(theta_hat, tol=1e-3):\n",
    "    \"\"\"Group 1-D estimates into clusters by chaining sorted neighbours.\n",
    "\n",
    "    The values are sorted; a new cluster starts whenever the gap between two\n",
    "    consecutive sorted values exceeds ``tol``. Labels are 0, 1, 2, ... in\n",
    "    increasing order of the values and are returned in the input order.\n",
    "    Because gaps are chained, two values in one cluster may differ by more\n",
    "    than ``tol`` when intermediate values bridge them.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    theta_hat : array-like of shape (m,)\n",
    "        Estimates to cluster.\n",
    "    tol : float\n",
    "        Maximum gap between sorted neighbours of the same cluster.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    np.ndarray of int, shape (m,)\n",
    "        Cluster label per entry; an empty array for empty input.\n",
    "    \"\"\"\n",
    "    theta_hat = np.asarray(theta_hat)\n",
    "    order = np.argsort(theta_hat)\n",
    "    if order.size == 0:\n",
    "        # Nothing to label; the loop-based version raised IndexError here.\n",
    "        return np.zeros(0, dtype=order.dtype)\n",
    "\n",
    "    gaps = np.abs(np.diff(theta_hat[order]))\n",
    "    # Running count of gaps larger than tol gives the label of each sorted value.\n",
    "    labs_sorted = np.concatenate(([0], np.cumsum(gaps > tol))).astype(order.dtype)\n",
    "\n",
    "    labels = np.empty_like(order)\n",
    "    labels[order] = labs_sorted\n",
    "    return labels\n",
    "\n",
    "\n",
    "# =====================\n",
    "# ARMUL (per-task λ_j) clustered second stage\n",
    "# =====================\n",
    "\n",
    "def _soft(x, t):\n",
    "    \"\"\"Soft-thresholding: shrink ``x`` toward zero by ``t``, returning 0 where |x| <= t.\"\"\"\n",
    "    magnitude = np.maximum(np.abs(x) - t, 0.0)\n",
    "    return np.sign(x) * magnitude\n",
    "\n",
    "\n",
    "def _armul_obj(theta, gamma, c, a, b, lam_vec):\n",
    "    \"\"\"ARMUL objective: weighted squared fit to ``b`` plus penalized distance to the assigned centers.\n",
    "\n",
    "    ``c`` holds the cluster index of each task, so ``gamma[c]`` is the\n",
    "    center each ``theta`` entry is pulled toward.\n",
    "    \"\"\"\n",
    "    fit_term = 0.5 * np.sum(a * (theta - b) ** 2)\n",
    "    penalty_term = np.sum(lam_vec * np.abs(theta - gamma[c]))\n",
    "    return fit_term + penalty_term\n",
    "\n",
    "\n",
    "def _ensure_nonempty(labels, K, rng):\n",
    "    \"\"\"Fill empty clusters by moving one random member out of the currently largest cluster.\n",
    "\n",
    "    ``labels`` is modified in place and also returned. Clusters are visited\n",
    "    in increasing index order; for each empty one, a task drawn with\n",
    "    ``rng.choice`` from the most populated cluster is reassigned to it.\n",
    "    \"\"\"\n",
    "    counts = np.bincount(labels, minlength=K)\n",
    "    for target in range(K):\n",
    "        if counts[target] != 0:\n",
    "            continue\n",
    "        donor = int(np.argmax(counts))\n",
    "        donor_members = np.where(labels == donor)[0]\n",
    "        moved = int(rng.choice(donor_members))\n",
    "        labels[moved] = target\n",
    "        counts[donor] -= 1\n",
    "        counts[target] += 1\n",
    "    return labels\n",
    "\n",
    "\n",
    "def _kmeans_init_labels_1d(b, K, seed, n_init_kmeans=10):\n",
    "    \"\"\"Return integer KMeans labels for the 1-D values ``b`` using ``K`` clusters.\"\"\"\n",
    "    column = b.reshape(-1, 1)\n",
    "    model = KMeans(n_clusters=K, n_init=n_init_kmeans, random_state=seed)\n",
    "    assigned = model.fit_predict(column)\n",
    "    return assigned.astype(int)\n",
    "\n",
    "\n",
    "def armul_clustered_fit(\n",
    "    b, a, lam_vec, K_candidates,\n",
    "    init=\"kmeans\", n_init=1, n_iter=300, tol=1e-7, seed=123, kmeans_n_init=10\n",
    "):\n",
    "    \"\"\"Alternating minimization for the clustered ARMUL second stage.\n",
    "\n",
    "    For every K in ``K_candidates`` and every one of ``n_init`` starts, the\n",
    "    routine alternates theta-, gamma- and label-updates on\n",
    "        0.5 * sum_j a[j] * (theta[j] - b[j]) ** 2\n",
    "        + sum_j lam_vec[j] * |theta[j] - gamma[labels[j]]|\n",
    "    and keeps the run with the smallest final objective.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    b, a : np.ndarray of shape (m,)\n",
    "        Per-task estimates and their quadratic weights.\n",
    "    lam_vec : array-like of shape (m,)\n",
    "        Per-task penalty levels.\n",
    "    K_candidates : int or iterable of int\n",
    "        Number(s) of clusters to try.\n",
    "    init : str\n",
    "        One of ``kmeans``, ``quantiles`` or ``random``.\n",
    "    n_init : int\n",
    "        Number of starts per K.\n",
    "    n_iter : int\n",
    "        Maximum number of alternating iterations per start.\n",
    "    tol : float\n",
    "        A start stops once the objective decreases by at most ``tol``.\n",
    "    seed : int\n",
    "        Seeds the local generator and (offset by the start index) KMeans.\n",
    "    kmeans_n_init : int\n",
    "        ``n_init`` forwarded to KMeans for the ``kmeans`` initialization.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    tuple\n",
    "        ``(theta, labels, obj, K, gamma)`` of the best run. The first, second,\n",
    "        fourth and fifth entries are None if no run produced an objective\n",
    "        below infinity.\n",
    "\n",
    "    Raises\n",
    "    ------\n",
    "    ValueError\n",
    "        If ``K_candidates`` is empty or ``init`` is not recognized.\n",
    "    \"\"\"\n",
    "    rng = np.random.default_rng(seed)\n",
    "    m = len(b)\n",
    "    lam_vec = np.asarray(lam_vec, dtype=float)\n",
    "    assert lam_vec.shape == (m,)\n",
    "\n",
    "    # NOTE(review): a NumPy integer is not an ``int`` instance, so it would be\n",
    "    # passed to list(...) below; callers should pass a Python int or a list.\n",
    "    K_list = [K_candidates] if isinstance(K_candidates, int) else list(K_candidates)\n",
    "    if not K_list:\n",
    "        raise ValueError(\"K_candidates is empty.\")\n",
    "\n",
    "    def theta_update(gamma, labels):\n",
    "        # Closed-form minimizer per task: soft-threshold b[j] around its\n",
    "        # center with threshold lam_vec[j] / a[j].\n",
    "        theta = np.empty_like(b, dtype=float)\n",
    "        for j in range(m):\n",
    "            gk = gamma[labels[j]]\n",
    "            theta[j] = gk + _soft(b[j] - gk, lam_vec[j] / max(a[j], EPS))\n",
    "        return theta\n",
    "\n",
    "    best = dict(obj=np.inf, theta=None, labels=None, gamma=None, K=None)\n",
    "\n",
    "    for K in K_list:\n",
    "        for init_id in range(n_init):\n",
    "            # ---- initial labels ----\n",
    "            if init == \"kmeans\":\n",
    "                labels = _kmeans_init_labels_1d(b, K, seed + init_id, n_init_kmeans=kmeans_n_init)\n",
    "            elif init == \"quantiles\":\n",
    "                # Equal-count bins over the sorted values of b.\n",
    "                order = np.argsort(b)\n",
    "                labels = np.zeros(m, dtype=int)\n",
    "                edges = np.linspace(0, m, K + 1).astype(int)\n",
    "                for k in range(K):\n",
    "                    labels[order[edges[k]:edges[k + 1]]] = k\n",
    "            elif init == \"random\":\n",
    "                labels = rng.integers(0, K, size=m)\n",
    "            else:\n",
    "                raise ValueError(f\"Unknown init='{init}'\")\n",
    "            labels = _ensure_nonempty(labels, K, rng)\n",
    "\n",
    "            # Initial centers: per-cluster median of b.\n",
    "            gamma = np.array([np.median(b[labels == k]) for k in range(K)], dtype=float)\n",
    "            theta = theta_update(gamma, labels)\n",
    "            obj_prev = _armul_obj(theta, gamma, labels, a, b, lam_vec)\n",
    "\n",
    "            for _ in range(n_iter):\n",
    "                # θ-update\n",
    "                theta = theta_update(gamma, labels)\n",
    "                # γ-update\n",
    "                # NOTE(review): this is the plain median of theta within the\n",
    "                # cluster; with unequal lam_vec the exact minimizer of the\n",
    "                # penalty term would be a lam-weighted median -- confirm intent.\n",
    "                for k in range(K):\n",
    "                    idx = np.where(labels == k)[0]\n",
    "                    gamma[k] = float(np.median(theta[idx])) if len(idx) else float(rng.choice(theta))\n",
    "                # c-update\n",
    "                # For each task, evaluate its best theta and cost under every\n",
    "                # center and move it to the cheapest cluster.\n",
    "                theta_new = np.empty_like(theta)\n",
    "                labels_new = labels.copy()\n",
    "                for j in range(m):\n",
    "                    costs = np.empty(K, dtype=float)\n",
    "                    thetas_j = np.empty(K, dtype=float)\n",
    "                    for kk in range(K):\n",
    "                        gk = gamma[kk]\n",
    "                        tj = gk + _soft(b[j] - gk, lam_vec[j] / max(a[j], EPS))\n",
    "                        thetas_j[kk] = tj\n",
    "                        costs[kk] = 0.5 * a[j] * (tj - b[j]) ** 2 + lam_vec[j] * abs(tj - gk)\n",
    "                    kbest = int(np.argmin(costs))\n",
    "                    labels_new[j] = kbest\n",
    "                    theta_new[j] = thetas_j[kbest]\n",
    "\n",
    "                # A task moved here keeps the theta computed for its previous\n",
    "                # cluster until the next θ-update.\n",
    "                labels_new = _ensure_nonempty(labels_new, K, rng)\n",
    "                obj = _armul_obj(theta_new, gamma, labels_new, a, b, lam_vec)\n",
    "                # Stop on a small decrease (or an increase) of the objective;\n",
    "                # the latest iterate is accepted in both cases.\n",
    "                if (obj_prev - obj) <= tol:\n",
    "                    theta, labels, obj_prev = theta_new, labels_new, obj\n",
    "                    break\n",
    "                theta, labels, obj_prev = theta_new, labels_new, obj\n",
    "\n",
    "            if obj_prev < best[\"obj\"]:\n",
    "                best = dict(theta=theta.copy(), labels=labels.copy(),\n",
    "                            gamma=gamma.copy(), obj=float(obj_prev), K=K)\n",
    "    return best[\"theta\"], best[\"labels\"], best[\"obj\"], best[\"K\"], best.get(\"gamma\", None)\n",
    "\n",
    "\n",
    "# =============== Per-task λ_j = C_LAM * sqrt(1 / n_j) ===============\n",
    "\n",
    "def armul_lambda_per_task(splits, c=C_LAM, use_target=True):\n",
    "    \"\"\"Compute the per-task penalties λ_j = c * sqrt(1 / n_j).\n",
    "\n",
    "    ``splits`` is a sequence of ``(A_units, T_units)`` pairs. n_j is the size\n",
    "    of the T part when ``use_target`` is true, otherwise the size of both\n",
    "    parts together; it is floored at 1. Returns the penalty vector and a\n",
    "    dict holding the sizes under the key ``n_list``.\n",
    "    \"\"\"\n",
    "    if use_target:\n",
    "        sizes = [len(T_units) for (_, T_units) in splits]\n",
    "    else:\n",
    "        sizes = [len(A_units) + len(T_units) for (A_units, T_units) in splits]\n",
    "    n_arr = np.asarray([max(1, int(size)) for size in sizes], dtype=float)\n",
    "    lam_vec = c * np.sqrt(1.0 / n_arr)\n",
    "    diagnostics = {\"n_list\": n_arr.copy()}\n",
    "    return lam_vec.astype(float), diagnostics\n",
    "\n",
    "\n",
    "def cluster_sandwich_given_labels(theta_in, labels, var_b_list):\n",
    "    \"\"\"Clusterwise SE\n",
    "\n",
    "    Every task receives the square root of the mean of ``var_b_list`` over\n",
    "    the tasks sharing its label (floored at 1e-12 before the root).\n",
    "    Returns a copy of ``theta_in``, the standard errors, and ``labels``.\n",
    "    \"\"\"\n",
    "    n_groups = int(labels.max() + 1)\n",
    "    variances = np.asarray(var_b_list, dtype=float)\n",
    "    se_out = np.zeros_like(theta_in, dtype=float)\n",
    "\n",
    "    for g in range(n_groups):\n",
    "        members = np.flatnonzero(labels == g)\n",
    "        if members.size == 0:\n",
    "            continue\n",
    "        pooled_var = float(np.mean(variances[members]))\n",
    "        se_out[members] = float(np.sqrt(max(pooled_var, 1e-12)))\n",
    "\n",
    "    return theta_in.copy(), se_out, labels\n",
    "\n",
    "\n",
    "# =====================\n",
    "# One round core\n",
    "# =====================\n",
    "\n",
    "def one_round_core(delta: float, seed: int):\n",
    "    \"\"\"\n",
    "    Do the DID DGP, DR panel stats, individual, adaptive fused, and ARMUL.\n",
    "\n",
    "    Steps: simulate ``M_TASKS`` DID tasks, split the units of every task into\n",
    "    a nuisance part (A) and a target part (T), compute ``(a_j, b_j, var_b_j)``\n",
    "    per task, then evaluate (i) the individual estimator, (ii) the adaptive\n",
    "    graph-fused estimator and (iii) ARMUL with ``K_CLUST - 1``, ``K_CLUST``\n",
    "    and ``K_CLUST + 1`` clusters.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    delta : float\n",
    "        Cluster separation passed to ``make_did_tasks``.\n",
    "    seed : int\n",
    "        Seed of the data-generating process. The A/T split seeds and the\n",
    "        LightGBM seeds depend on the task index only, not on ``seed``.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    dict\n",
    "        Truth, first-stage statistics and, per method, the estimates,\n",
    "        standard errors, ARI, RMSE, cluster-weighted errors and ARMUL\n",
    "        diagnostics.\n",
    "    \"\"\"\n",
    "    # --- DGP: DID panel tasks ---\n",
    "    tasks, tau_true, cluster_ids = make_did_tasks(M_TASKS, K_CLUST, delta, seed)\n",
    "\n",
    "    splits = []\n",
    "    for j, t in enumerate(tasks):\n",
    "        n_long = len(t[\"Y\"])\n",
    "        assert n_long % 2 == 0, \"Long panel must have 2*T rows per unit.\"\n",
    "        n_units = n_long // 2\n",
    "        A_units, T_units = split_two_parts(n_units, A_FRAC, seed=SEED_BASE + 11111 + j)\n",
    "        splits.append((A_units, T_units))\n",
    "\n",
    "    # --- DR DID stats on T-units ---\n",
    "    a_T, b_T, var_b_T = [], [], []\n",
    "    for j, t in enumerate(tasks):\n",
    "        X_long, D_long, T_long, Y_long = t[\"X\"], t[\"D\"], t[\"T\"], t[\"Y\"]\n",
    "        A_units, T_units = splits[j]\n",
    "\n",
    "        a_j, b_j, v_j = panel_dr_did_stats(\n",
    "            X_long, D_long, T_long, Y_long,\n",
    "            A_units=A_units, T_units=T_units,\n",
    "            seed=LGBM_SEED_OFFSET + j,\n",
    "        )\n",
    "        a_T.append(a_j)\n",
    "        b_T.append(b_j)\n",
    "        var_b_T.append(v_j)\n",
    "\n",
    "    a_T = np.asarray(a_T, dtype=float)\n",
    "    b_T = np.asarray(b_T, dtype=float)\n",
    "    var_b_T = np.asarray(var_b_T, dtype=float)\n",
    "    se_T_indiv = np.sqrt(np.maximum(EPS, var_b_T.copy()))\n",
    "\n",
    "    tau_true = np.asarray(tau_true, dtype=float)\n",
    "    cluster_ids = np.asarray(cluster_ids, dtype=int)\n",
    "\n",
    "    n_T = np.array([len(T_units) for (A_units, T_units) in splits], dtype=float)\n",
    "\n",
    "    # Total T-sample size of every true cluster; each task is then weighted\n",
    "    # by the size of the cluster it belongs to.\n",
    "    K_true = int(cluster_ids.max() + 1)\n",
    "    N_k_true = np.zeros(K_true, dtype=float)\n",
    "    for k in range(K_true):\n",
    "        idx_k = (cluster_ids == k)\n",
    "        N_k_true[k] = n_T[idx_k].sum()\n",
    "\n",
    "    cluster_weights = N_k_true[cluster_ids]  # shape (M_TASKS,)\n",
    "\n",
    "    # ============================\n",
    "    # (i) Individual estimator\n",
    "    # ============================\n",
    "    theta_ind = b_T.copy()\n",
    "    cl_ind = clusters_from_theta(theta_ind)\n",
    "    ari_ind = float(adjusted_rand_score(cluster_ids, cl_ind))\n",
    "\n",
    "    err_ind = theta_ind - tau_true\n",
    "    rmse_ind = float(np.sqrt(np.mean(err_ind ** 2)))\n",
    "\n",
    "    cw_rmse_ind = float(np.sqrt(np.sum(cluster_weights * err_ind ** 2)))\n",
    "    diff_cw_ind = np.sqrt(cluster_weights) * err_ind\n",
    "\n",
    "    zI = err_ind / se_T_indiv\n",
    "\n",
    "    # ============================\n",
    "    # (ii) Adaptive fused (from IND)\n",
    "    # ============================\n",
    "    L_ad_T = build_lambda_adaptive(\n",
    "        theta_ind,\n",
    "        c_w=C_W,\n",
    "        gamma=GAMMA_DEFAULT,\n",
    "        tau=ADAPT_TAU_DEFAULT,\n",
    "        eps=ADAPT_EPS,\n",
    "    )\n",
    "    theta_ad_pen = admm_graph_fused_lasso(a_T, b_T, L_ad_T)\n",
    "    cl_ad = clusters_from_theta(theta_ad_pen)\n",
    "    theta_ad, se_ad, cl_ad = cluster_sandwich_given_labels(theta_ad_pen, cl_ad, var_b_T)\n",
    "    ari_ad = float(adjusted_rand_score(cluster_ids, cl_ad))\n",
    "\n",
    "    err_ad = theta_ad - tau_true\n",
    "    rmse_ad = float(np.sqrt(np.mean(err_ad ** 2)))\n",
    "\n",
    "    cw_rmse_ad = float(np.sqrt(np.sum(cluster_weights * err_ad ** 2)))\n",
    "    diff_cw_ad = np.sqrt(cluster_weights) * err_ad\n",
    "\n",
    "    zA_ad = err_ad / se_ad\n",
    "\n",
    "    # ============================\n",
    "    # (iii) ARMUL with per-task λ_j\n",
    "    # ============================\n",
    "    lam_vec, lam_diag = armul_lambda_per_task(splits, c=C_LAM, use_target=True)\n",
    "\n",
    "    def fit_armul_with_K(K_fixed):\n",
    "        # Fit ARMUL with a fixed number of clusters (clipped to [1, M_TASKS])\n",
    "        # and compute the same metrics as for the other estimators.\n",
    "        K_fixed = int(np.clip(K_fixed, 1, M_TASKS))\n",
    "        theta_k, labels_k, obj_k, K_sel_k, _ = armul_clustered_fit(\n",
    "            b=b_T, a=a_T, lam_vec=lam_vec,\n",
    "            K_candidates=[K_fixed],\n",
    "            init=\"kmeans\",\n",
    "            n_init=ARMUL_N_INIT,\n",
    "            n_iter=ARMUL_N_ITER,\n",
    "            tol=ARMUL_TOL,\n",
    "            seed=ARMUL_SEED,\n",
    "            kmeans_n_init=10,\n",
    "        )\n",
    "        theta_k, se_k, labels_k = cluster_sandwich_given_labels(theta_k, labels_k, var_b_T)\n",
    "        ari_k = float(adjusted_rand_score(cluster_ids, labels_k))\n",
    "\n",
    "        err_k = theta_k - tau_true\n",
    "        rmse_k = float(np.sqrt(np.mean(err_k ** 2)))\n",
    "        cw_rmse_k = float(np.sqrt(np.sum(cluster_weights * err_k ** 2)))\n",
    "        diff_cw_k = np.sqrt(cluster_weights) * err_k\n",
    "\n",
    "        return dict(\n",
    "            theta=theta_k, se=se_k, labels=labels_k,\n",
    "            ari=ari_k, rmse=rmse_k,\n",
    "            cw_rmse=cw_rmse_k, diff_cw=diff_cw_k,\n",
    "            K=K_sel_k, obj=obj_k\n",
    "        )\n",
    "\n",
    "    armul_km1 = fit_armul_with_K(K_CLUST - 1)\n",
    "    armul_k   = fit_armul_with_K(K_CLUST)       # oracle\n",
    "    armul_kp1 = fit_armul_with_K(K_CLUST + 1)\n",
    "\n",
    "    return {\n",
    "        \"theta_true\": tau_true,\n",
    "        \"cluster_ids_true\": cluster_ids,\n",
    "        \"a_T\": a_T,\n",
    "        \"b_T\": b_T,\n",
    "        \"var_b_T\": var_b_T,\n",
    "        \"se_T_indiv\": se_T_indiv,\n",
    "\n",
    "        \"cluster_n_T\": n_T,\n",
    "        \"cluster_N_k_true\": N_k_true,\n",
    "        \"cluster_weights\": cluster_weights,\n",
    "\n",
    "        \"theta_ind\": theta_ind,\n",
    "        \"ARI_ind\": ari_ind,\n",
    "        \"rmse_ind\": rmse_ind,\n",
    "        \"zI\": zI,\n",
    "\n",
    "        \"theta_ad\": theta_ad,\n",
    "        \"se_ad\": se_ad,\n",
    "        \"ARI_ad\": ari_ad,\n",
    "        \"rmse_ad\": rmse_ad,\n",
    "        \"zA_ad\": zA_ad,\n",
    "\n",
    "        \"theta_armul_km1\": armul_km1[\"theta\"],\n",
    "        \"se_armul_km1\": armul_km1[\"se\"],\n",
    "        \"ARI_armul_km1\": armul_km1[\"ari\"],\n",
    "        \"rmse_armul_km1\": armul_km1[\"rmse\"],\n",
    "        \"armul_K_sel_km1\": armul_km1[\"K\"],\n",
    "        \"armul_obj_km1\": armul_km1[\"obj\"],\n",
    "\n",
    "        \"theta_armul_k\": armul_k[\"theta\"],\n",
    "        \"se_armul_k\": armul_k[\"se\"],\n",
    "        \"ARI_armul_k\": armul_k[\"ari\"],\n",
    "        \"rmse_armul_k\": armul_k[\"rmse\"],\n",
    "        \"armul_K_sel_k\": armul_k[\"K\"],\n",
    "        \"armul_obj_k\": armul_k[\"obj\"],\n",
    "\n",
    "        \"theta_armul_kp1\": armul_kp1[\"theta\"],\n",
    "        \"se_armul_kp1\": armul_kp1[\"se\"],\n",
    "        \"ARI_armul_kp1\": armul_kp1[\"ari\"],\n",
    "        \"rmse_armul_kp1\": armul_kp1[\"rmse\"],\n",
    "        \"armul_K_sel_kp1\": armul_kp1[\"K\"],\n",
    "        \"armul_obj_kp1\": armul_kp1[\"obj\"],\n",
    "\n",
    "        \"armul_lambda_vec\": lam_vec,\n",
    "        \"armul_lambda_diag\": lam_diag,\n",
    "\n",
    "        # Stored directly: routing these scalars through an ARMUL value\n",
    "        # multiplied by zero would turn them into NaN whenever that ARMUL\n",
    "        # metric is NaN or infinite.\n",
    "        \"cw_rmse_ind\": cw_rmse_ind,\n",
    "        \"cw_rmse_ad\": cw_rmse_ad,\n",
    "        \"cw_rmse_armul_km1\": armul_km1[\"cw_rmse\"],\n",
    "        \"cw_rmse_armul_k\": armul_k[\"cw_rmse\"],\n",
    "        \"cw_rmse_armul_kp1\": armul_kp1[\"cw_rmse\"],\n",
    "\n",
    "        \"diff_cw_ind\": diff_cw_ind,\n",
    "        \"diff_cw_ad\": diff_cw_ad,\n",
    "        \"diff_cw_armul_km1\": armul_km1[\"diff_cw\"],\n",
    "        \"diff_cw_armul_k\": armul_k[\"diff_cw\"],\n",
    "        \"diff_cw_armul_kp1\": armul_kp1[\"diff_cw\"],\n",
    "    }\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "# =====================\n",
    "# Run ROUNDS\n",
    "# =====================\n",
    "\n",
    "def run_for_delta_fixed_grid(rounds: int, delta: float, start_seed: int, fixed_lambda_list):\n",
    "    \"\"\"\n",
    "    For a given delta, run ROUNDS times.\n",
    "    \"\"\"\n",
    "    def init_agg():\n",
    "        return {\n",
    "            \"TH_true\": [],\n",
    "            \"TH_ind\": [],\n",
    "            \"TH_fx\": [],\n",
    "            \"TH_ad\": [],\n",
    "            \"TH_arm_km1\": [],\n",
    "            \"TH_arm_k\": [],\n",
    "            \"TH_arm_kp1\": [],\n",
    "            \"zI\": [],\n",
    "            \"zA_fx\": [],\n",
    "            \"zA_ad\": [],\n",
    "            \"ARI_ind\": [],\n",
    "            \"ARI_fx\": [],\n",
    "            \"ARI_ad\": [],\n",
    "            \"ARI_arm_km1\": [],\n",
    "            \"ARI_arm_k\": [],\n",
    "            \"ARI_arm_kp1\": [],\n",
    "            \"rmse_ind\": [],\n",
    "            \"rmse_fx\": [],\n",
    "            \"rmse_ad\": [],\n",
    "            \"rmse_arm_km1\": [],\n",
    "            \"rmse_arm_k\": [],\n",
    "            \"rmse_arm_kp1\": [],\n",
    "            \"armul_K_sel_km1\": [],\n",
    "            \"armul_K_sel_k\": [],\n",
    "            \"armul_K_sel_kp1\": [],\n",
    "            \"armul_lambda\": [],\n",
    "\n",
    "            \"cw_rmse_ind\": [],\n",
    "            \"cw_rmse_fx\": [],\n",
    "            \"cw_rmse_ad\": [],\n",
    "            \"cw_rmse_arm_km1\": [],\n",
    "            \"cw_rmse_arm_k\": [],\n",
    "            \"cw_rmse_arm_kp1\": [],\n",
    "\n",
    "            \"diff_cw_ind\": [],\n",
    "            \"diff_cw_fx\": [],\n",
    "            \"diff_cw_ad\": [],\n",
    "            \"diff_cw_arm_km1\": [],\n",
    "            \"diff_cw_arm_k\": [],\n",
    "            \"diff_cw_arm_kp1\": [],\n",
    "        }\n",
    "\n",
    "    aggs = {lam: init_agg() for lam in fixed_lambda_list}\n",
    "\n",
    "    for r in tqdm(range(rounds), desc=f\"delta={delta}\"):\n",
    "        core = one_round_core(delta=delta, seed=start_seed + r)\n",
    "\n",
    "        a_T = core[\"a_T\"]\n",
    "        b_T = core[\"b_T\"]\n",
    "        var_b_T = core[\"var_b_T\"]\n",
    "        theta_true = core[\"theta_true\"]\n",
    "        cluster_ids = core[\"cluster_ids_true\"]\n",
    "        cluster_weights = core[\"cluster_weights\"]\n",
    "\n",
    "        for fixed_lambda in fixed_lambda_list:\n",
    "            agg = aggs[fixed_lambda]\n",
    "\n",
    "            L_fx_T = build_lambda_fixed(M_TASKS, fixed_lambda)\n",
    "            theta_fx_pen = admm_graph_fused_lasso(a_T, b_T, L_fx_T)\n",
    "            cl_fx = clusters_from_theta(theta_fx_pen)\n",
    "            theta_fx, se_fx, cl_fx = cluster_sandwich_given_labels(theta_fx_pen, cl_fx, var_b_T)\n",
    "            ari_fx = float(adjusted_rand_score(cluster_ids, cl_fx))\n",
    "            err_fx = theta_fx - theta_true\n",
    "            rmse_fx = float(np.sqrt(np.mean(err_fx ** 2)))\n",
    "            zA_fx = err_fx / se_fx\n",
    "\n",
    "            cw_rmse_fx = float(np.sqrt(np.sum(cluster_weights * err_fx ** 2)))\n",
    "            diff_cw_fx = np.sqrt(cluster_weights) * err_fx\n",
    "\n",
    "            agg[\"TH_true\"].append(theta_true)\n",
    "            agg[\"TH_ind\"].append(core[\"theta_ind\"])\n",
    "            agg[\"TH_fx\"].append(theta_fx)\n",
    "            agg[\"TH_ad\"].append(core[\"theta_ad\"])\n",
    "            agg[\"TH_arm_km1\"].append(core[\"theta_armul_km1\"])\n",
    "            agg[\"TH_arm_k\"].append(core[\"theta_armul_k\"])\n",
    "            agg[\"TH_arm_kp1\"].append(core[\"theta_armul_kp1\"])\n",
    "\n",
    "            agg[\"zI\"].append(core[\"zI\"])\n",
    "            agg[\"zA_fx\"].append(zA_fx)\n",
    "            agg[\"zA_ad\"].append(core[\"zA_ad\"])\n",
    "\n",
    "            agg[\"ARI_ind\"].append(core[\"ARI_ind\"])\n",
    "            agg[\"ARI_fx\"].append(ari_fx)\n",
    "            agg[\"ARI_ad\"].append(core[\"ARI_ad\"])\n",
    "            agg[\"ARI_arm_km1\"].append(core[\"ARI_armul_km1\"])\n",
    "            agg[\"ARI_arm_k\"].append(core[\"ARI_armul_k\"])\n",
    "            agg[\"ARI_arm_kp1\"].append(core[\"ARI_armul_kp1\"])\n",
    "\n",
    "            agg[\"rmse_ind\"].append(core[\"rmse_ind\"])\n",
    "            agg[\"rmse_fx\"].append(rmse_fx)\n",
    "            agg[\"rmse_ad\"].append(core[\"rmse_ad\"])\n",
    "            agg[\"rmse_arm_km1\"].append(core[\"rmse_armul_km1\"])\n",
    "            agg[\"rmse_arm_k\"].append(core[\"rmse_armul_k\"])\n",
    "            agg[\"rmse_arm_kp1\"].append(core[\"rmse_armul_kp1\"])\n",
    "\n",
    "            agg[\"armul_K_sel_km1\"].append(core[\"armul_K_sel_km1\"])\n",
    "            agg[\"armul_K_sel_k\"].append(core[\"armul_K_sel_k\"])\n",
    "            agg[\"armul_K_sel_kp1\"].append(core[\"armul_K_sel_kp1\"])\n",
    "            agg[\"armul_lambda\"].append(core[\"armul_lambda_vec\"])\n",
    "\n",
    "            agg[\"cw_rmse_ind\"].append(core[\"cw_rmse_ind\"])\n",
    "            agg[\"cw_rmse_fx\"].append(cw_rmse_fx)\n",
    "            agg[\"cw_rmse_ad\"].append(core[\"cw_rmse_ad\"])\n",
    "            agg[\"cw_rmse_arm_km1\"].append(core[\"cw_rmse_armul_km1\"])\n",
    "            agg[\"cw_rmse_arm_k\"].append(core[\"cw_rmse_armul_k\"])\n",
    "            agg[\"cw_rmse_arm_kp1\"].append(core[\"cw_rmse_armul_kp1\"])\n",
    "\n",
    "            agg[\"diff_cw_ind\"].append(core[\"diff_cw_ind\"])\n",
    "            agg[\"diff_cw_fx\"].append(diff_cw_fx)\n",
    "            agg[\"diff_cw_ad\"].append(core[\"diff_cw_ad\"])\n",
    "            agg[\"diff_cw_arm_km1\"].append(core[\"diff_cw_armul_km1\"])\n",
    "            agg[\"diff_cw_arm_k\"].append(core[\"diff_cw_armul_k\"])\n",
    "            agg[\"diff_cw_arm_kp1\"].append(core[\"diff_cw_armul_kp1\"])\n",
    "\n",
    "    res_by_lam = {}\n",
    "    for fixed_lambda, agg in aggs.items():\n",
    "        TH_true_all      = np.concatenate(agg[\"TH_true\"])\n",
    "        TH_ind_all       = np.concatenate(agg[\"TH_ind\"])\n",
    "        TH_fx_all        = np.concatenate(agg[\"TH_fx\"])\n",
    "        TH_ad_all        = np.concatenate(agg[\"TH_ad\"])\n",
    "        TH_arm_km1_all   = np.concatenate(agg[\"TH_arm_km1\"])\n",
    "        TH_arm_k_all     = np.concatenate(agg[\"TH_arm_k\"])\n",
    "        TH_arm_kp1_all   = np.concatenate(agg[\"TH_arm_kp1\"])\n",
    "\n",
    "        zI_all    = np.concatenate(agg[\"zI\"])\n",
    "        zA_fx_all = np.concatenate(agg[\"zA_fx\"])\n",
    "        zA_ad_all = np.concatenate(agg[\"zA_ad\"])\n",
    "\n",
    "        ARI_ind_arr      = np.array(agg[\"ARI_ind\"])\n",
    "        ARI_fx_arr       = np.array(agg[\"ARI_fx\"])\n",
    "        ARI_ad_arr       = np.array(agg[\"ARI_ad\"])\n",
    "        ARI_arm_km1_arr  = np.array(agg[\"ARI_arm_km1\"])\n",
    "        ARI_arm_k_arr    = np.array(agg[\"ARI_arm_k\"])\n",
    "        ARI_arm_kp1_arr  = np.array(agg[\"ARI_arm_kp1\"])\n",
    "\n",
    "        rmse_ind_arr     = np.array(agg[\"rmse_ind\"])\n",
    "        rmse_fx_arr      = np.array(agg[\"rmse_fx\"])\n",
    "        rmse_ad_arr      = np.array(agg[\"rmse_ad\"])\n",
    "        rmse_arm_km1_arr = np.array(agg[\"rmse_arm_km1\"])\n",
    "        rmse_arm_k_arr   = np.array(agg[\"rmse_arm_k\"])\n",
    "        rmse_arm_kp1_arr = np.array(agg[\"rmse_arm_kp1\"])\n",
    "\n",
    "        armul_K_sel_km1_all = np.array(agg[\"armul_K_sel_km1\"])\n",
    "        armul_K_sel_k_all   = np.array(agg[\"armul_K_sel_k\"])\n",
    "        armul_K_sel_kp1_all = np.array(agg[\"armul_K_sel_kp1\"])\n",
    "        armul_lambda_all    = np.array(agg[\"armul_lambda\"])\n",
    "\n",
    "        cw_rmse_ind_arr     = np.array(agg[\"cw_rmse_ind\"])\n",
    "        cw_rmse_fx_arr      = np.array(agg[\"cw_rmse_fx\"])\n",
    "        cw_rmse_ad_arr      = np.array(agg[\"cw_rmse_ad\"])\n",
    "        cw_rmse_arm_km1_arr = np.array(agg[\"cw_rmse_arm_km1\"])\n",
    "        cw_rmse_arm_k_arr   = np.array(agg[\"cw_rmse_arm_k\"])\n",
    "        cw_rmse_arm_kp1_arr = np.array(agg[\"cw_rmse_arm_kp1\"])\n",
    "\n",
    "        diff_cw_ind_all     = np.concatenate(agg[\"diff_cw_ind\"])\n",
    "        diff_cw_fx_all      = np.concatenate(agg[\"diff_cw_fx\"])\n",
    "        diff_cw_ad_all      = np.concatenate(agg[\"diff_cw_ad\"])\n",
    "        diff_cw_arm_km1_all = np.concatenate(agg[\"diff_cw_arm_km1\"])\n",
    "        diff_cw_arm_k_all   = np.concatenate(agg[\"diff_cw_arm_k\"])\n",
    "        diff_cw_arm_kp1_all = np.concatenate(agg[\"diff_cw_arm_kp1\"])\n",
    "\n",
    "        res = {\n",
    "            \"theta_true_all\": TH_true_all,\n",
    "            \"theta_ind_all\": TH_ind_all,\n",
    "            \"theta_fx_all\": TH_fx_all,\n",
    "            \"theta_ad_all\": TH_ad_all,\n",
    "            \"theta_armul_km1_all\": TH_arm_km1_all,\n",
    "            \"theta_armul_k_all\": TH_arm_k_all,\n",
    "            \"theta_armul_kp1_all\": TH_arm_kp1_all,\n",
    "\n",
    "            \"zI\": zI_all,\n",
    "            \"zA_fx\": zA_fx_all,\n",
    "            \"zA_ad\": zA_ad_all,\n",
    "\n",
    "            \"ARI_ind\": ARI_ind_arr,\n",
    "            \"ARI_fx\": ARI_fx_arr,\n",
    "            \"ARI_ad\": ARI_ad_arr,\n",
    "            \"ARI_armul_km1\": ARI_arm_km1_arr,\n",
    "            \"ARI_armul_k\":   ARI_arm_k_arr,\n",
    "            \"ARI_armul_kp1\": ARI_arm_kp1_arr,\n",
    "\n",
    "            \"rmse_ind_all\":   rmse_ind_arr,\n",
    "            \"rmse_fx_all\":    rmse_fx_arr,\n",
    "            \"rmse_ad_all\":    rmse_ad_arr,\n",
    "            \"rmse_armul_km1_all\": rmse_arm_km1_arr,\n",
    "            \"rmse_armul_k_all\":   rmse_arm_k_arr,\n",
    "            \"rmse_armul_kp1_all\": rmse_arm_kp1_arr,\n",
    "\n",
    "            \"mean_rmse_ind\": float(np.mean(rmse_ind_arr)),\n",
    "            \"mean_rmse_fx\":  float(np.mean(rmse_fx_arr)),\n",
    "            \"mean_rmse_ad\":  float(np.mean(rmse_ad_arr)),\n",
    "            \"mean_rmse_armul_km1\": float(np.mean(rmse_arm_km1_arr)),\n",
    "            \"mean_rmse_armul_k\":   float(np.mean(rmse_arm_k_arr)),\n",
    "            \"mean_rmse_armul_kp1\": float(np.mean(rmse_arm_kp1_arr)),\n",
    "            \"mean_ari_ind\": float(np.mean(ARI_ind_arr)),\n",
    "            \"mean_ari_fx\":  float(np.mean(ARI_fx_arr)),\n",
    "            \"mean_ari_ad\":  float(np.mean(ARI_ad_arr)),\n",
    "            \"mean_ari_armul_km1\": float(np.mean(ARI_arm_km1_arr)),\n",
    "            \"mean_ari_armul_k\":   float(np.mean(ARI_arm_k_arr)),\n",
    "            \"mean_ari_armul_kp1\": float(np.mean(ARI_arm_kp1_arr)),\n",
    "\n",
    "            \"armul_K_sel_km1_all\": armul_K_sel_km1_all,\n",
    "            \"armul_K_sel_k_all\":   armul_K_sel_k_all,\n",
    "            \"armul_K_sel_kp1_all\": armul_K_sel_kp1_all,\n",
    "            \"armul_lambda_all\": armul_lambda_all,\n",
    "\n",
    "            \"cw_rmse_ind_all\":     cw_rmse_ind_arr,\n",
    "            \"cw_rmse_fx_all\":      cw_rmse_fx_arr,\n",
    "            \"cw_rmse_ad_all\":      cw_rmse_ad_arr,\n",
    "            \"cw_rmse_armul_km1_all\": cw_rmse_arm_km1_arr,\n",
    "            \"cw_rmse_armul_k_all\":   cw_rmse_arm_k_arr,\n",
    "            \"cw_rmse_armul_kp1_all\": cw_rmse_arm_kp1_arr,\n",
    "\n",
    "            \"diff_cw_ind_all\":     diff_cw_ind_all,\n",
    "            \"diff_cw_fx_all\":      diff_cw_fx_all,\n",
    "            \"diff_cw_ad_all\":      diff_cw_ad_all,\n",
    "            \"diff_cw_armul_km1_all\": diff_cw_arm_km1_all,\n",
    "            \"diff_cw_armul_k_all\":   diff_cw_arm_k_all,\n",
    "            \"diff_cw_armul_kp1_all\": diff_cw_arm_kp1_all,\n",
    "\n",
    "            \"mean_cw_rmse_ind\": float(np.mean(cw_rmse_ind_arr)),\n",
    "            \"mean_cw_rmse_fx\":  float(np.mean(cw_rmse_fx_arr)),\n",
    "            \"mean_cw_rmse_ad\":  float(np.mean(cw_rmse_ad_arr)),\n",
    "            \"mean_cw_rmse_armul_km1\": float(np.mean(cw_rmse_arm_km1_arr)),\n",
    "            \"mean_cw_rmse_armul_k\":   float(np.mean(cw_rmse_arm_k_arr)),\n",
    "            \"mean_cw_rmse_armul_kp1\": float(np.mean(cw_rmse_arm_kp1_arr)),\n",
    "        }\n",
    "\n",
    "        res_by_lam[fixed_lambda] = res\n",
    "\n",
    "    return res_by_lam\n",
    "\n",
    "\n",
    "\n",
    "# =====================\n",
    "# Saving helpers\n",
    "# =====================\n",
    "\n",
    "def save_delta_results(outdir, delta, res_by_lam):\n",
    "    \"\"\"Bundle every result for one delta into ``did_delta<delta>.npz`` in outdir.\n",
    "\n",
    "    Args:\n",
    "        outdir: directory that receives the archive.\n",
    "        delta: value formatted with two decimals into the file name.\n",
    "        res_by_lam: dict mapping fixed lambda -> result dict.\n",
    "\n",
    "    Entries without an ``fx`` tag are copied once, from the result stored under\n",
    "    the smallest lambda (presumably they are identical for every lambda --\n",
    "    TODO confirm against run_for_delta_fixed_grid). Fixed-lambda entries are\n",
    "    written per lambda under the prefix ``lam<k>_``, where k is the position of\n",
    "    that lambda in the ascending grid saved as ``fixed_lambda_values``.\n",
    "    \"\"\"\n",
    "    lam_grid = np.array(sorted(res_by_lam.keys()), dtype=float)\n",
    "    target_file = os.path.join(outdir, f\"did_delta{delta:.2f}.npz\")\n",
    "\n",
    "    # Array-valued entries taken once from the reference result.\n",
    "    common_arrays = (\n",
    "        \"theta_true_all\",\n",
    "        \"theta_ind_all\",\n",
    "        \"theta_ad_all\",\n",
    "        \"theta_armul_km1_all\",\n",
    "        \"theta_armul_k_all\",\n",
    "        \"theta_armul_kp1_all\",\n",
    "        \"zI\",\n",
    "        \"zA_ad\",\n",
    "        \"ARI_ind\",\n",
    "        \"ARI_ad\",\n",
    "        \"ARI_armul_km1\",\n",
    "        \"ARI_armul_k\",\n",
    "        \"ARI_armul_kp1\",\n",
    "        \"rmse_ind_all\",\n",
    "        \"rmse_ad_all\",\n",
    "        \"rmse_armul_km1_all\",\n",
    "        \"rmse_armul_k_all\",\n",
    "        \"rmse_armul_kp1_all\",\n",
    "        \"armul_K_sel_km1_all\",\n",
    "        \"armul_K_sel_k_all\",\n",
    "        \"armul_K_sel_kp1_all\",\n",
    "        \"armul_lambda_all\",\n",
    "        \"cw_rmse_ind_all\",\n",
    "        \"cw_rmse_ad_all\",\n",
    "        \"cw_rmse_armul_km1_all\",\n",
    "        \"cw_rmse_armul_k_all\",\n",
    "        \"cw_rmse_armul_kp1_all\",\n",
    "        \"diff_cw_ind_all\",\n",
    "        \"diff_cw_ad_all\",\n",
    "        \"diff_cw_armul_km1_all\",\n",
    "        \"diff_cw_armul_k_all\",\n",
    "        \"diff_cw_armul_kp1_all\",\n",
    "    )\n",
    "    # Scalar summaries taken once from the reference result.\n",
    "    common_scalars = (\n",
    "        \"mean_rmse_ind\",\n",
    "        \"mean_rmse_ad\",\n",
    "        \"mean_rmse_armul_km1\",\n",
    "        \"mean_rmse_armul_k\",\n",
    "        \"mean_rmse_armul_kp1\",\n",
    "        \"mean_ari_ind\",\n",
    "        \"mean_ari_ad\",\n",
    "        \"mean_ari_armul_km1\",\n",
    "        \"mean_ari_armul_k\",\n",
    "        \"mean_ari_armul_kp1\",\n",
    "        \"mean_cw_rmse_ind\",\n",
    "        \"mean_cw_rmse_ad\",\n",
    "        \"mean_cw_rmse_armul_km1\",\n",
    "        \"mean_cw_rmse_armul_k\",\n",
    "        \"mean_cw_rmse_armul_kp1\",\n",
    "    )\n",
    "    # Entries stored separately for each fixed lambda.\n",
    "    fx_arrays = (\n",
    "        \"theta_fx_all\",\n",
    "        \"zA_fx\",\n",
    "        \"ARI_fx\",\n",
    "        \"rmse_fx_all\",\n",
    "        \"cw_rmse_fx_all\",\n",
    "        \"diff_cw_fx_all\",\n",
    "    )\n",
    "    fx_scalars = (\n",
    "        \"mean_rmse_fx\",\n",
    "        \"mean_ari_fx\",\n",
    "        \"mean_cw_rmse_fx\",\n",
    "    )\n",
    "\n",
    "    # Reference result: the entry stored under the smallest lambda.\n",
    "    reference = res_by_lam[lam_grid[0]]\n",
    "\n",
    "    payload = {\"fixed_lambda_values\": lam_grid}\n",
    "    payload.update({name: reference[name] for name in common_arrays})\n",
    "    payload.update(\n",
    "        {name: np.array(reference[name], dtype=float) for name in common_scalars}\n",
    "    )\n",
    "\n",
    "    for idx, lam in enumerate(lam_grid):\n",
    "        lam_res = res_by_lam[lam]\n",
    "        tag = f\"lam{idx}_\"\n",
    "        payload.update({tag + name: lam_res[name] for name in fx_arrays})\n",
    "        payload.update(\n",
    "            {tag + name: np.array(lam_res[name], dtype=float) for name in fx_scalars}\n",
    "        )\n",
    "\n",
    "    np.savez(target_file, **payload)\n",
    "\n",
    "\n",
    "\n",
    "def append_summary_csv(outdir, delta, res_by_lam):\n",
    "    \"\"\"Append one row per fixed lambda to ``summary.csv`` inside outdir.\n",
    "\n",
    "    Args:\n",
    "        outdir: directory holding (or receiving) ``summary.csv``.\n",
    "        delta: delta value recorded in every row written by this call.\n",
    "        res_by_lam: dict mapping fixed lambda -> result dict exposing the\n",
    "            ``mean_rmse_*`` and ``mean_ari_*`` summaries read below.\n",
    "\n",
    "    The file is opened explicitly as UTF-8: the ``notes`` column contains\n",
    "    non-ASCII symbols, so the locale default encoding (e.g. cp1252 on\n",
    "    Windows) can raise UnicodeEncodeError. The header row is emitted when\n",
    "    the file is missing or empty.\n",
    "    \"\"\"\n",
    "    import csv  # local import: csv is not among the file-level imports\n",
    "\n",
    "    csv_path = os.path.join(outdir, \"summary.csv\")\n",
    "    # A zero-byte file (e.g. left by an interrupted run) still needs a header.\n",
    "    write_header = (not os.path.exists(csv_path)) or os.path.getsize(csv_path) == 0\n",
    "\n",
    "    fieldnames = [\n",
    "        \"model\",\n",
    "        \"delta\",\n",
    "        \"fixed_lambda\",\n",
    "        \"rounds\",\n",
    "        \"m_tasks\",\n",
    "        \"mean_rmse_individual\",\n",
    "        \"mean_rmse_fixed\",\n",
    "        \"mean_rmse_adaptive\",\n",
    "        \"mean_rmse_armul_km1\",\n",
    "        \"mean_rmse_armul_k\",\n",
    "        \"mean_rmse_armul_kp1\",\n",
    "        \"mean_ari_individual\",\n",
    "        \"mean_ari_fixed\",\n",
    "        \"mean_ari_adaptive\",\n",
    "        \"mean_ari_armul_km1\",\n",
    "        \"mean_ari_armul_k\",\n",
    "        \"mean_ari_armul_kp1\",\n",
    "        \"notes\",\n",
    "    ]\n",
    "\n",
    "    # newline=\"\" lets the csv module control line endings itself.\n",
    "    with open(csv_path, \"a\", newline=\"\", encoding=\"utf-8\") as f:\n",
    "        w = csv.DictWriter(f, fieldnames=fieldnames)\n",
    "        if write_header:\n",
    "            w.writeheader()\n",
    "\n",
    "        for fixed_lambda, res in res_by_lam.items():\n",
    "            row = {\n",
    "                \"model\": \"did_lgbm_DR_panel_plus_ARMUL\",\n",
    "                \"delta\": float(delta),\n",
    "                \"fixed_lambda\": float(fixed_lambda),\n",
    "                \"rounds\": ROUNDS,\n",
    "                \"m_tasks\": M_TASKS,\n",
    "                \"mean_rmse_individual\": res[\"mean_rmse_ind\"],\n",
    "                \"mean_rmse_fixed\":      res[\"mean_rmse_fx\"],\n",
    "                \"mean_rmse_adaptive\":   res[\"mean_rmse_ad\"],\n",
    "                \"mean_rmse_armul_km1\":  res[\"mean_rmse_armul_km1\"],\n",
    "                \"mean_rmse_armul_k\":    res[\"mean_rmse_armul_k\"],\n",
    "                \"mean_rmse_armul_kp1\":  res[\"mean_rmse_armul_kp1\"],\n",
    "                \"mean_ari_individual\":  res[\"mean_ari_ind\"],\n",
    "                \"mean_ari_fixed\":       res[\"mean_ari_fx\"],\n",
    "                \"mean_ari_adaptive\":    res[\"mean_ari_ad\"],\n",
    "                \"mean_ari_armul_km1\":   res[\"mean_ari_armul_km1\"],\n",
    "                \"mean_ari_armul_k\":     res[\"mean_ari_armul_k\"],\n",
    "                \"mean_ari_armul_kp1\":   res[\"mean_ari_armul_kp1\"],\n",
    "                \"notes\": (\n",
    "                    \"Two-period panel DID in Sant'Anna & Zhao (2020) framework. \"\n",
    "                    \"Per-task DR score ψ = (w1 - w0)*(ΔY - m0Δ(X)) - w1*τ; \"\n",
    "                    \"quadratic loss 0.5*a_j*(τ_j - b_j)^2 with a_j=E[w1], b_j=Ā/B̄. \"\n",
    "                    f\"Fixed fused λ = {fixed_lambda}. \"\n",
    "                    \"ARMUL uses per-task λ_j=C_LAM*sqrt(1/n_j) with n_j=|T_j|.\"\n",
    "                ),\n",
    "            }\n",
    "            w.writerow(row)\n",
    "\n",
    "\n",
    "# =====================\n",
    "# Main\n",
    "# =====================\n",
    "\n",
    "if __name__ == \"__main__\":\n",
    "    # Driver: run the simulation for every delta, print a report, persist results.\n",
    "    # Each run writes into its own timestamped directory.\n",
    "    timestamp = datetime.now().strftime(\"%Y%m%d_%H%M%S\")\n",
    "    # NOTE(review): relies on ensure_dir returning the path it created -- confirm.\n",
    "    base_out = ensure_dir(os.path.join(\"results_did_binary\", timestamp))\n",
    "    model_out = ensure_dir(os.path.join(base_out, \"did_binary\"))\n",
    "\n",
    "    banner = (\n",
    "        f\"Running DID (DR panel) with IND / FIX / ADAPT and ARMUL(K-1,K,K+1). \"\n",
    "        f\"Rounds={ROUNDS}, m={M_TASKS}, true K={K_CLUST}\"\n",
    "    )\n",
    "    print(banner)\n",
    "    print(f\"Fixed λ grid (per delta): {FIXED_LAMBDA_LIST}\")\n",
    "\n",
    "    # (printed label, result key) pairs for the metrics reported once per delta.\n",
    "    shared_report = [\n",
    "        (\"Mean_RMSE_Indiv\", \"mean_rmse_ind\"),\n",
    "        (\"Mean_RMSE_Adapt\", \"mean_rmse_ad\"),\n",
    "        (\"Mean_RMSE_ARMUL(K-1)\", \"mean_rmse_armul_km1\"),\n",
    "        (\"Mean_RMSE_ARMUL(K)\", \"mean_rmse_armul_k\"),\n",
    "        (\"Mean_RMSE_ARMUL(K+1)\", \"mean_rmse_armul_kp1\"),\n",
    "        (\"Mean_ARI_Indiv\", \"mean_ari_ind\"),\n",
    "        (\"Mean_ARI_Adapt\", \"mean_ari_ad\"),\n",
    "        (\"Mean_ARI_ARMUL(K-1)\", \"mean_ari_armul_km1\"),\n",
    "        (\"Mean_ARI_ARMUL(K)\", \"mean_ari_armul_k\"),\n",
    "        (\"Mean_ARI_ARMUL(K+1)\", \"mean_ari_armul_kp1\"),\n",
    "    ]\n",
    "\n",
    "    for delta in DELTAS:\n",
    "        print(f\"\\n=== delta = {delta} ===\")\n",
    "        res_by_lam = run_for_delta_fixed_grid(\n",
    "            rounds=ROUNDS,\n",
    "            delta=delta,\n",
    "            start_seed=SEED_BASE,\n",
    "            fixed_lambda_list=FIXED_LAMBDA_LIST,\n",
    "        )\n",
    "\n",
    "        # Ascending lambda order, reused for the shared report and the sweep.\n",
    "        lam_order = sorted(res_by_lam.keys())\n",
    "        first_res = res_by_lam[lam_order[0]]\n",
    "\n",
    "        print(\"Shared metrics (IND / ADAPT / ARMUL):\")\n",
    "        shared_line = {\"delta\": delta}\n",
    "        for label, key in shared_report:\n",
    "            shared_line[label] = first_res[key]\n",
    "        print(shared_line)\n",
    "\n",
    "        print(\"Fixed-lambda sweep (only FIX metrics):\")\n",
    "        for lam in lam_order:\n",
    "            lam_res = res_by_lam[lam]\n",
    "            sweep_line = {\n",
    "                \"delta\": delta,\n",
    "                \"fixed_lambda\": lam,\n",
    "                \"Mean_RMSE_Fixed\": lam_res[\"mean_rmse_fx\"],\n",
    "                \"Mean_ARI_Fixed\": lam_res[\"mean_ari_fx\"],\n",
    "            }\n",
    "            print(sweep_line)\n",
    "\n",
    "        save_delta_results(model_out, delta, res_by_lam)\n",
    "        append_summary_csv(base_out, delta, res_by_lam)\n",
    "\n",
    "    # Manifest: records the configuration constants used for this run.\n",
    "    adaptive_cfg = {\n",
    "        \"gamma\": GAMMA_DEFAULT,\n",
    "        \"tau\": ADAPT_TAU_DEFAULT,\n",
    "        \"eps\": ADAPT_EPS,\n",
    "        \"c_w\": C_W,\n",
    "    }\n",
    "    armul_cfg = {\n",
    "        \"C_LAM\": C_LAM,\n",
    "        \"variants\": [\"K-1\", \"K (oracle)\", \"K+1\"],\n",
    "        \"n_init\": ARMUL_N_INIT,\n",
    "        \"n_iter\": ARMUL_N_ITER,\n",
    "        \"tol\": ARMUL_TOL,\n",
    "        \"seed\": ARMUL_SEED,\n",
    "        \"lambda_rule\": \"λ_j = C_LAM * sqrt(1 / n_j), with n_j = |T_j|\",\n",
    "    }\n",
    "    lgbm_cfg = {\n",
    "        \"n_estimators\": LGBM_ESTIMATORS,\n",
    "        \"learning_rate\": LGBM_LEARNING_RATE,\n",
    "        \"num_leaves\": LGBM_NUM_LEAVES,\n",
    "        \"min_child_samples\": LGBM_MIN_CHILD_SAMPLES,\n",
    "        \"subsample\": LGBM_SUBSAMPLE,\n",
    "        \"colsample_bytree\": LGBM_COLSAMPLE,\n",
    "        \"reg_alpha\": LGBM_REG_ALPHA,\n",
    "        \"reg_lambda\": LGBM_REG_LAMBDA,\n",
    "        \"n_jobs\": N_JOBS,\n",
    "    }\n",
    "    score_text = (\n",
    "        \"Panel DID (Sant'Anna & Zhao DR): \"\n",
    "        \"ΔY = Y1 - Y0; score ψ = (w1 - w0)*(ΔY - m0Δ(X)) - w1*τ; \"\n",
    "        \"per-task loss 0.5 * a_j * (τ_j - b_j)^2 with a_j = E[w1], \"\n",
    "        \"b_j = E[(w1-w0)(ΔY-m0Δ)]/E[w1].\"\n",
    "    )\n",
    "    inference_text = (\n",
    "        \"Per-task IF-based variance from DR DID score; \"\n",
    "        \"clusterwise SE via pooling var_b across tasks in each cluster.\"\n",
    "    )\n",
    "\n",
    "    manifest = {\n",
    "        \"model\": \"did_lgbm_DR_panel_plus_ARMUL_fixedlam_grid\",\n",
    "        \"deltas\": DELTAS,\n",
    "        \"rounds\": ROUNDS,\n",
    "        \"m_tasks\": M_TASKS,\n",
    "        \"timestamp\": timestamp,\n",
    "        \"true_K_for_eval\": K_CLUST,\n",
    "        \"tau_centers\": \"tau_k = k*delta - (K+1)*delta/2, k=1..K\",\n",
    "        \"nuisance_target_split\": {\"A_frac\": A_FRAC, \"role_swap\": False},\n",
    "        \"orthogonal_score\": score_text,\n",
    "        \"fusion_hyperparams\": {\n",
    "            \"adaptive\": adaptive_cfg,\n",
    "            \"fixed_lambda_list\": FIXED_LAMBDA_LIST,\n",
    "        },\n",
    "        \"armul\": armul_cfg,\n",
    "        \"nuisance\": {\n",
    "            \"Y_model\": \"LGBMRegressor (for m0Δ on controls)\",\n",
    "            \"D_model\": \"LGBMClassifier (propensity)\",\n",
    "            \"shared_params\": lgbm_cfg,\n",
    "        },\n",
    "        \"inference\": inference_text,\n",
    "    }\n",
    "\n",
    "    manifest_path = os.path.join(base_out, \"manifest.json\")\n",
    "    with open(manifest_path, \"w\") as f:\n",
    "        json.dump(manifest, f, indent=2)\n",
    "\n",
    "    print(f\"\\nAll results saved under: ./results_did_binary/{timestamp}/\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "050cc47c-b5d9-4282-8c75-9d26c38625e4",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
