{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# ============================================================================\n",
    "# CHUNK 1: SETUP - Algorithms, Helper Functions, and Definitions\n",
    "# ============================================================================\n",
    "\n",
    "import numpy as np\n",
    "import torch\n",
    "import matplotlib.pyplot as plt\n",
    "import pandas as pd\n",
    "import time\n",
    "import seaborn as sns\n",
    "from matplotlib.patches import Rectangle\n",
    "from sklearn.preprocessing import StandardScaler\n",
    "import GEOparse\n",
    "import gseapy as gp\n",
    "\n",
    "# Global matplotlib default: enlarge the base font size for every figure\n",
    "plt.rcParams['font.size'] = 20\n",
    "\n",
    "\n",
    "# ============================================================================\n",
    "# Optimized HJ-Prox Implementation\n",
    "# ============================================================================\n",
    "\n",
    "def compute_prox_optimized(x, t, f, delta=1e-1, int_samples=100, alpha=1.0, \n",
    "                                linesearch_iters=0, device='cpu'):\n",
    "    \"\"\"\n",
    "    Optimized version of compute_prox with several improvements:\n",
    "    \n",
    "    1. Use torch.sqrt instead of np.sqrt\n",
    "    2. More efficient overflow checking\n",
    "    3. Better memory layout for y samples\n",
    "    4. Reduced unnecessary operations\n",
    "    \"\"\"\n",
    "    assert x.shape[1] == 1\n",
    "    assert x.shape[0] >= 1\n",
    "    \n",
    "    linesearch_iters += 1\n",
    "    dim = x.shape[0]\n",
    "    input_dtype = x.dtype\n",
    "    \n",
    "    # Use torch operations throughout\n",
    "    standard_dev = torch.sqrt(torch.tensor(delta * t / alpha, dtype=input_dtype, device=device))\n",
    "    \n",
    "    # Sample y points\n",
    "    x_flat = x.squeeze(1)\n",
    "    y = standard_dev * torch.randn(int_samples, dim, device=device, dtype=input_dtype) + x_flat\n",
    "    \n",
    "    # Compute weights\n",
    "    z = -f(y) * (alpha / delta)\n",
    "    z = z.to(dtype=input_dtype)\n",
    "    \n",
    "    # Check for overflow before softmax\n",
    "    z_max = z.max()\n",
    "    if z_max > 88:\n",
    "        alpha *= 0.5\n",
    "        return compute_prox_optimized(x, t, f, delta=delta, int_samples=int_samples, \n",
    "                                         alpha=alpha, linesearch_iters=linesearch_iters, \n",
    "                                         device=device)\n",
    "    \n",
    "    w = torch.softmax(z, dim=0)\n",
    "    \n",
    "    # More efficient overflow check\n",
    "    if torch.isinf(w).any():\n",
    "        alpha *= 0.5\n",
    "        return compute_prox_optimized(x, t, f, delta=delta, int_samples=int_samples, \n",
    "                                         alpha=alpha, linesearch_iters=linesearch_iters, \n",
    "                                         device=device)\n",
    "    \n",
    "    # Compute proximal term - use einsum for efficiency\n",
    "    prox_term = torch.einsum('i,ij->j', w, y).unsqueeze(1)\n",
    "    \n",
    "    # Envelope computation\n",
    "    envelope = (f(prox_term.T) + \n",
    "                (1 / (2 * t)) * torch.linalg.vector_norm(prox_term.squeeze() - x_flat) ** 2)\n",
    "    \n",
    "    return prox_term, envelope, linesearch_iters\n",
    "\n",
    "\n",
    "# ============================================================================\n",
    "# Helper Functions for Davis-Yin\n",
    "# ============================================================================\n",
    "\n",
    "def overlapping_group_penalty_compiled(beta_samples, groups, lambda2, weights):\n",
    "    \"\"\"\n",
    "    Compute overlapping group penalty with JIT optimization.\n",
    "    \"\"\"\n",
    "    @torch.compile(mode=\"reduce-overhead\")\n",
    "    def _compute_group_norm(group_samples):\n",
    "        return torch.linalg.vector_norm(group_samples, dim=0)\n",
    "    \n",
    "    if beta_samples.dim() == 1:\n",
    "        penalty = 0.0\n",
    "        for i, Gi in enumerate(groups):\n",
    "            penalty += lambda2 * weights[i] * _compute_group_norm(beta_samples[Gi].unsqueeze(1)).item()\n",
    "        return torch.as_tensor(penalty, device=beta_samples.device, dtype=beta_samples.dtype)\n",
    "    else:\n",
    "        n_samples = beta_samples.shape[1]\n",
    "        penalties = torch.zeros(n_samples, device=beta_samples.device, dtype=beta_samples.dtype)\n",
    "        \n",
    "        for i, Gi in enumerate(groups):\n",
    "            group_norms = _compute_group_norm(beta_samples[Gi, :])\n",
    "            penalties.add_(group_norms, alpha=lambda2 * weights[i])\n",
    "        \n",
    "        return penalties\n",
    "\n",
    "\n",
    "# ============================================================================\n",
    "# Algorithm 1: Davis-Yin with HJ-Prox (Adaptive)\n",
    "# ============================================================================\n",
    "\n",
    "def davis_yin_adaptive_overlapping_group_lasso(\n",
    "        X, y, groups, lambda1, lambda2,\n",
    "        weights=None,\n",
    "        gamma_init=None,\n",
    "        tau=0.95,\n",
    "        growth_factor=1.25,\n",
    "        delta=0.1,\n",
    "        int_samples=100,\n",
    "        max_iter=500,\n",
    "        tol=1e-6,\n",
    "        z_init=None,\n",
    "        u_init=None,\n",
    "        device='cpu',\n",
    "        verbose=True):\n",
    "    \"\"\"\n",
    "    Adaptive Three Operator Splitting (Algorithm 1, Variant 2).\n",
    "    \n",
    "    Problem: minimize f(β) + g(β) + h(β)\n",
    "    - f(β) = 0.5||Xβ - y||²      (smooth, Lipschitz gradient)\n",
    "    - g(β) = λ1||β||₁            (proximal, CHEAP - soft threshold)\n",
    "    - h(β) = λ2 Σⁱ wⁱ||β_Gⁱ||₂   (proximal, EXPENSIVE - group lasso)\n",
    "    \n",
    "    Args:\n",
    "        X: design matrix, converted to a float64 array of shape (n, p).\n",
    "        y: response, flattened to length n.\n",
    "        groups: index collections, one per (possibly overlapping) group.\n",
    "        lambda1, lambda2: strengths of the ℓ1 and group penalties.\n",
    "        weights: per-group weights; defaults to sqrt(len(group)).\n",
    "        gamma_init: initial step size; defaults to 0.9 / ||X||₂².\n",
    "        tau: factor applied to γ on every backtrack.\n",
    "        growth_factor: factor applied to the previous γ before backtracking,\n",
    "            used from the second iteration on when the previous slack\n",
    "            Q_t - f(x) was positive.\n",
    "        delta: base HJ-Prox smoothing; iteration k uses delta / (k + 1)².\n",
    "        int_samples: Monte-Carlo samples per HJ-Prox call.\n",
    "        max_iter: maximum number of iterations.\n",
    "        tol: threshold on both ||Δz|| and ||Δu||; only checked once k > 10.\n",
    "        z_init, u_init: optional warm starts (length p or shape (p, 1)).\n",
    "        device: torch device used for all tensors.\n",
    "        verbose: print progress information.\n",
    "    \n",
    "    Returns:\n",
    "        (beta, history): beta is the final z iterate as a squeezed NumPy array;\n",
    "        history is a dict of per-iteration lists with keys 'objective_x',\n",
    "        'objective_z', 'gamma', 'backtracks', 'time_prox_h' (milliseconds)\n",
    "        and 'delta_t'.\n",
    "    \n",
    "    Note:\n",
    "        max_iter must be >= 1; with max_iter=0 no iterate is produced and the\n",
    "        final return statement raises NameError.\n",
    "    \"\"\"\n",
    "    \n",
    "    # Setup\n",
    "    X_np = np.asarray(X, dtype=np.float64)\n",
    "    y_np = np.asarray(y, dtype=np.float64).ravel()\n",
    "    n, p = X_np.shape\n",
    "    \n",
    "    X_torch = torch.tensor(X_np, dtype=torch.float64, device=device)\n",
    "    y_torch = torch.tensor(y_np, dtype=torch.float64, device=device)\n",
    "    \n",
    "    if weights is None:\n",
    "        weights_np = np.array([np.sqrt(len(g)) for g in groups], dtype=np.float64)\n",
    "    else:\n",
    "        weights_np = np.asarray(weights, dtype=np.float64)\n",
    "    \n",
    "    # Initial step size\n",
    "    if gamma_init is None:\n",
    "        L_approx = torch.linalg.norm(X_torch, ord=2)**2\n",
    "        gamma = 0.9 / L_approx.item()\n",
    "    else:\n",
    "        gamma = gamma_init\n",
    "    \n",
    "    # Initialize z, u\n",
    "    if z_init is not None:\n",
    "        z = torch.tensor(z_init, dtype=torch.float64, device=device)\n",
    "    else:\n",
    "        z = torch.zeros(p, 1, device=device, dtype=torch.float64)\n",
    "    \n",
    "    if u_init is not None:\n",
    "        u = torch.tensor(u_init, dtype=torch.float64, device=device)\n",
    "    else:\n",
    "        u = torch.zeros(p, 1, device=device, dtype=torch.float64)\n",
    "    \n",
    "    if z.dim() == 1:\n",
    "        z = z.unsqueeze(1)\n",
    "    if u.dim() == 1:\n",
    "        u = u.unsqueeze(1)\n",
    "    \n",
    "    XTy = X_torch.T @ y_torch.unsqueeze(1)\n",
    "    \n",
    "    # Estimate Lipschitz constant of h for Variant 2 growth\n",
    "    # (only reported in the verbose banner below; not used by the updates)\n",
    "    beta_h = lambda2 * np.sqrt(len(groups))\n",
    "    \n",
    "    # Operators\n",
    "    def grad_f(beta):\n",
    "        \"\"\"∇f(β) = X^T(Xβ - y)\"\"\"\n",
    "        b = beta if beta.dim() == 2 else beta.unsqueeze(1)\n",
    "        return X_torch.T @ (X_torch @ b) - XTy\n",
    "    \n",
    "    def eval_f(beta):\n",
    "        \"\"\"f(β) = 0.5||Xβ - y||²\"\"\"\n",
    "        b = beta if beta.dim() == 2 else beta.unsqueeze(1)\n",
    "        resid = X_torch @ b - y_torch.unsqueeze(1)\n",
    "        return 0.5 * torch.sum(resid ** 2)\n",
    "    \n",
    "    def eval_g(beta):\n",
    "        \"\"\"g(β) = λ1||β||₁\"\"\"\n",
    "        b = beta if beta.dim() == 2 else beta.unsqueeze(1)\n",
    "        return lambda1 * torch.sum(torch.abs(b))\n",
    "    \n",
    "    def prox_g(v, step):\n",
    "        \"\"\"prox_{γg}(v) for g(β) = λ1||β||₁ - CHEAP (soft-thresholding)\"\"\"\n",
    "        threshold = step * lambda1\n",
    "        return torch.sign(v) * torch.maximum(torch.abs(v) - threshold, \n",
    "                                              torch.zeros_like(v))\n",
    "    \n",
    "    def prox_h(v, step, delta_k):\n",
    "        \"\"\"prox_{γh}(v) via HJ-Prox - EXPENSIVE but called once per iteration\"\"\"\n",
    "        v_col = v if v.dim() == 2 else v.unsqueeze(1)\n",
    "        \n",
    "        def penalty_func(beta_samples):\n",
    "            # HJ-Prox passes samples as rows; the penalty expects them as columns\n",
    "            if beta_samples.dim() == 2:\n",
    "                beta_samples = beta_samples.T\n",
    "            return overlapping_group_penalty_compiled(\n",
    "                beta_samples, groups, lambda2, weights_np\n",
    "            )\n",
    "        \n",
    "        beta_prox, _, _ = compute_prox_optimized(\n",
    "            x=v_col, t=step, f=penalty_func,\n",
    "            delta=delta_k, int_samples=int_samples, device=device\n",
    "        )\n",
    "        return beta_prox\n",
    "    \n",
    "    # Tracking\n",
    "    history = {\n",
    "        'objective_x': [],\n",
    "        'objective_z': [],\n",
    "        'gamma': [],\n",
    "        'backtracks': [],\n",
    "        'time_prox_h': [],\n",
    "        'delta_t': []\n",
    "    }\n",
    "    \n",
    "    gamma_prev = gamma\n",
    "    delta_prev = 0.0\n",
    "    \n",
    "    if verbose:\n",
    "        print(\"Adaptive Three Operator Splitting (Paper Algorithm 1, Variant 2)\")\n",
    "        print(f\"Initial γ: {gamma:.6e}\")\n",
    "        print(f\"τ (backtrack): {tau}, β_h (Lipschitz): {beta_h:.2f}\")\n",
    "        print(\"=\" * 80 + \"\\n\")\n",
    "    \n",
    "    # Main Loop\n",
    "    for k in range(max_iter):\n",
    "        # Smoothing for the HJ-Prox shrinks quadratically with the iteration count\n",
    "        delta_k = delta / (k + 1)**2\n",
    "        \n",
    "        # STEP 1: Backtracking Line Search\n",
    "        if k > 0 and delta_prev > 0:\n",
    "            gamma_trial = gamma_prev * growth_factor\n",
    "        else:\n",
    "            gamma_trial = gamma_prev\n",
    "        \n",
    "        backtrack_count = 0\n",
    "        \n",
    "        f_z = eval_f(z)\n",
    "        grad_f_z = grad_f(z)\n",
    "        \n",
    "        while True:\n",
    "            v = z - gamma_trial * u - gamma_trial * grad_f_z\n",
    "            x = prox_g(v, gamma_trial)\n",
    "            \n",
    "            f_x = eval_f(x)\n",
    "            \n",
    "            # Quadratic upper model of f around z, evaluated at x\n",
    "            diff = x - z\n",
    "            linear_term = torch.sum(grad_f_z * diff)\n",
    "            quadratic_term = (1.0 / (2.0 * gamma_trial)) * torch.sum(diff ** 2)\n",
    "            Q_t = f_z + linear_term + quadratic_term\n",
    "            \n",
    "            if f_x <= Q_t + 1e-12:\n",
    "                gamma = gamma_trial\n",
    "                delta_t = (Q_t - f_x).item()\n",
    "                break\n",
    "            else:\n",
    "                gamma_trial = tau * gamma_trial\n",
    "                backtrack_count += 1\n",
    "                \n",
    "                # Safety valve: stop shrinking once the step is numerically zero\n",
    "                if gamma_trial < 1e-15:\n",
    "                    if verbose:\n",
    "                        print(f\"Warning: γ too small ({gamma_trial:.2e}), accepting anyway\")\n",
    "                    gamma = gamma_trial\n",
    "                    delta_t = (Q_t - f_x).item()\n",
    "                    break\n",
    "        \n",
    "        # STEP 2: Apply Expensive Operator\n",
    "        t0_h = time.time()\n",
    "        v_h = x + gamma * u\n",
    "        z_new = prox_h(v_h, gamma, delta_k)\n",
    "        t_h = time.time() - t0_h\n",
    "        \n",
    "        # STEP 3: Update Dual Variable\n",
    "        u_new = u + (x - z_new) / gamma\n",
    "        \n",
    "        # Convergence Check\n",
    "        with torch.no_grad():\n",
    "            f_x_val = eval_f(x)\n",
    "            g_x_val = eval_g(x)\n",
    "            h_x_val = overlapping_group_penalty_compiled(x, groups, lambda2, weights_np).sum()\n",
    "            obj_x = (f_x_val + g_x_val + h_x_val).item()\n",
    "            \n",
    "            f_z_val = eval_f(z_new)\n",
    "            g_z_val = eval_g(z_new)\n",
    "            h_z_val = overlapping_group_penalty_compiled(z_new, groups, lambda2, weights_np).sum()\n",
    "            obj_z = (f_z_val + g_z_val + h_z_val).item()\n",
    "        \n",
    "        z_change = torch.norm(z_new - z).item()\n",
    "        u_change = torch.norm(u_new - u).item()\n",
    "        \n",
    "        history['objective_x'].append(obj_x)\n",
    "        history['objective_z'].append(obj_z)\n",
    "        history['gamma'].append(gamma)\n",
    "        history['backtracks'].append(backtrack_count)\n",
    "        history['time_prox_h'].append(t_h * 1000)\n",
    "        history['delta_t'].append(delta_t)\n",
    "        \n",
    "        if verbose and (k % 10 == 0 or k < 5):\n",
    "            print(f\"Iter {k:3d} | Obj(x): {obj_x:.6f} | Obj(z): {obj_z:.6f} | \"\n",
    "                  f\"γ: {gamma:.6e} (bt: {backtrack_count}) | \"\n",
    "                  f\"|Δz|: {z_change:.2e} |Δu|: {u_change:.2e} | \"\n",
    "                  f\"prox_h: {t_h*1000:.1f}ms\")\n",
    "        \n",
    "        if z_change < tol and u_change < tol and k > 10:\n",
    "            if verbose:\n",
    "                print(f\"\\n{'='*80}\")\n",
    "                print(f\"Converged at iteration {k}\")\n",
    "                print(f\"Final objective (z): {obj_z:.6e}\")\n",
    "                print(f\"Final objective (x): {obj_x:.6e}\")\n",
    "            break\n",
    "        \n",
    "        z = z_new\n",
    "        u = u_new\n",
    "        gamma_prev = gamma\n",
    "        delta_prev = delta_t\n",
    "    \n",
    "    return z_new.squeeze().cpu().numpy(), history\n",
    "\n",
    "\n",
    "# ============================================================================\n",
    "# Helper Functions for FoGLasso\n",
    "# ============================================================================\n",
    "\n",
    "def soft_threshold(v, tau):\n",
    "    \"\"\"Soft-thresholding S_tau(v) applied elementwise.\"\"\"\n",
    "    return np.sign(v) * np.maximum(np.abs(v) - tau, 0.0)\n",
    "\n",
    "\n",
    "def estimate_L_xtx(X, iters=25, seed=0):\n",
    "    \"\"\"\n",
    "    Power iteration estimate of ||X^T X||_2 = sigma_max(X)^2 (no SVD needed).\n",
    "    \"\"\"\n",
    "    rng = np.random.default_rng(seed)\n",
    "    p = X.shape[1]\n",
    "    v = rng.normal(size=p)\n",
    "    v /= np.linalg.norm(v) + 1e-12\n",
    "    for _ in range(iters):\n",
    "        v = X.T @ (X @ v)\n",
    "        v /= (np.linalg.norm(v) + 1e-12)\n",
    "    XtXv = X.T @ (X @ v)\n",
    "    return float(v @ XtXv)\n",
    "\n",
    "\n",
    "def preprocess_zero_groups(u, groups, lambda2, weights, max_iter=100, tol=0.0):\n",
    "    \"\"\"\n",
    "    Paper-faithful screening (Lemma 3 iterative procedure):\n",
    "    Cycle through groups; if ||u_Gi|| <= lambda2 * w_i, set u_Gi = 0.\n",
    "    Repeat until u does not change (or max_iter reached).\n",
    "    \"\"\"\n",
    "    u = u.copy()\n",
    "    g = len(groups)\n",
    "    group_is_zero = np.zeros(g, dtype=bool)\n",
    "\n",
    "    changed = True\n",
    "    it = 0\n",
    "    while changed and it < max_iter:\n",
    "        changed = False\n",
    "        for i, Gi in enumerate(groups):\n",
    "            if group_is_zero[i]:\n",
    "                continue\n",
    "            if np.linalg.norm(u[Gi]) <= (lambda2 * weights[i] + tol):\n",
    "                group_is_zero[i] = True\n",
    "                u[Gi] = 0.0\n",
    "                changed = True\n",
    "        it += 1\n",
    "\n",
    "    active_idx = np.flatnonzero(u)\n",
    "    if active_idx.size == 0:\n",
    "        return u[active_idx], [], np.array([], dtype=int), np.array([], dtype=float)\n",
    "\n",
    "    index_map = -np.ones(u.shape[0], dtype=np.int64)\n",
    "    index_map[active_idx] = np.arange(active_idx.size)\n",
    "\n",
    "    active_groups = []\n",
    "    active_weights = []\n",
    "    for i, Gi in enumerate(groups):\n",
    "        if group_is_zero[i]:\n",
    "            continue\n",
    "        Gi_new = index_map[Gi]\n",
    "        Gi_new = Gi_new[Gi_new >= 0]\n",
    "        if Gi_new.size == 0:\n",
    "            continue\n",
    "        active_groups.append(Gi_new)\n",
    "        active_weights.append(weights[i])\n",
    "\n",
    "    return u[active_idx], active_groups, active_idx, np.asarray(active_weights)\n",
    "\n",
    "\n",
    "def _dual_lipschitz_overlap_bound(p, groups):\n",
    "    \"\"\"\n",
    "    A safe Lipschitz bound for the dual gradient based on maximum overlap count:\n",
    "    L_dual <= max_j #{i : j in G_i}\n",
    "    \"\"\"\n",
    "    if p <= 0 or len(groups) == 0:\n",
    "        return 0.0\n",
    "    counts = np.zeros(p, dtype=np.int32)\n",
    "    for Gi in groups:\n",
    "        counts[Gi] += 1\n",
    "    return float(counts.max()) if counts.size else 0.0\n",
    "\n",
    "\n",
    "def solve_dual_problem_sparse(u, groups, lambda2, weights,\n",
    "                              max_iter=2000,\n",
    "                              gap_tol=1e-10,\n",
    "                              check_every=1,\n",
    "                              verbose=False):\n",
    "    \"\"\"\n",
    "    Solve the smooth dual (Eq. (15)) with accelerated projected gradient (AGD),\n",
    "    and recover primal x via x = max(u - Y e, 0) (Eq. (14)).\n",
    "\n",
    "    Terminate when the estimated duality gap < gap_tol (Theorem 2).\n",
    "    \"\"\"\n",
    "    u = np.asarray(u)\n",
    "    p = len(u)\n",
    "    g = len(groups)\n",
    "    dtype = u.dtype\n",
    "\n",
    "    if p == 0 or g == 0:\n",
    "        return u.copy(), [np.zeros(0, dtype=dtype) for _ in range(g)], {'duality_gap': []}\n",
    "\n",
    "    L_dual = _dual_lipschitz_overlap_bound(p, groups)\n",
    "    if L_dual <= 0.0:\n",
    "        return u.copy(), [np.zeros(0, dtype=dtype) for _ in range(g)], {'duality_gap': []}\n",
    "    step = 1.0 / L_dual\n",
    "\n",
    "    # FISTA/AGD sequences on Y\n",
    "    Yk = [np.zeros(len(Gi), dtype=dtype) for Gi in groups]\n",
    "    Zk = [y.copy() for y in Yk]\n",
    "    t = 1.0\n",
    "\n",
    "    hist = {'duality_gap': []}\n",
    "\n",
    "    for it in range(max_iter):\n",
    "        # Ye = Y e (accumulate overlaps)\n",
    "        Ye = np.zeros(p, dtype=dtype)\n",
    "        for i, Gi in enumerate(groups):\n",
    "            Ye[Gi] += Zk[i]\n",
    "\n",
    "        # Primal from current dual iterate\n",
    "        x = np.maximum(u - Ye, 0.0)\n",
    "\n",
    "        # Gradient step on dual + projection onto Omega\n",
    "        Ynew = []\n",
    "        Ye_new = np.zeros(p, dtype=dtype)\n",
    "        for i, Gi in enumerate(groups):\n",
    "            y = Zk[i] + step * x[Gi]\n",
    "            r = lambda2 * weights[i]\n",
    "            nrm = np.linalg.norm(y)\n",
    "            if nrm > r:\n",
    "                y = (r / nrm) * y\n",
    "            Ynew.append(y)\n",
    "            Ye_new[Gi] += y\n",
    "\n",
    "        # Duality gap check\n",
    "        if (it % check_every) == 0:\n",
    "            x_new = np.maximum(u - Ye_new, 0.0)\n",
    "            gap = 0.0\n",
    "            for i, Gi in enumerate(groups):\n",
    "                xi = x_new[Gi]\n",
    "                gap += (lambda2 * weights[i]) * np.linalg.norm(xi) - float(np.dot(xi, Ynew[i]))\n",
    "\n",
    "            hist['duality_gap'].append(gap)\n",
    "\n",
    "            if verbose:\n",
    "                print(f\"[prox-dual] it={it} gap={gap:.3e}\")\n",
    "\n",
    "            if gap < gap_tol:\n",
    "                Yk = Ynew\n",
    "                break\n",
    "\n",
    "        # FISTA extrapolation\n",
    "        t_new = (1.0 + np.sqrt(1.0 + 4.0 * t * t)) / 2.0\n",
    "        beta = (t - 1.0) / t_new\n",
    "        Zk = [Ynew[i] + beta * (Ynew[i] - Yk[i]) for i in range(g)]\n",
    "        Yk = Ynew\n",
    "        t = t_new\n",
    "\n",
    "    # Recover primal solution\n",
    "    Ye = np.zeros(p, dtype=dtype)\n",
    "    for i, Gi in enumerate(groups):\n",
    "        Ye[Gi] += Yk[i]\n",
    "    x_prox = np.maximum(u - Ye, 0.0)\n",
    "\n",
    "    return x_prox, Yk, hist\n",
    "\n",
    "\n",
    "def compute_proximal_operator_foglasso(v, groups, lambda1, lambda2, weights,\n",
    "                                       prox_max_iter=2000,\n",
    "                                       prox_gap_tol=1e-5,\n",
    "                                       screen_max_iter=100,\n",
    "                                       verbose=False):\n",
    "    \"\"\"\n",
    "    Proximal operator π_{λ1,λ2}(v) with screening and dual solving.\n",
    "\n",
    "    Implements:\n",
    "      1) Theorem 1: reduce λ1>0 via soft-thresholding and handle sign\n",
    "      2) Lemma 3: iterative screening of zero groups\n",
    "      3) Solve smooth dual (Eq. (15)) via AGD and recover primal (Eq. (14))\n",
    "      4) Restore sign\n",
    "    \"\"\"\n",
    "    v = np.asarray(v)\n",
    "    groups = [np.asarray(g, dtype=np.int64) for g in groups]\n",
    "    weights = np.asarray(weights, dtype=float)\n",
    "\n",
    "    if lambda1 < 0 or lambda2 < 0:\n",
    "        raise ValueError(\"lambda1 and lambda2 must be >= 0.\")\n",
    "    if len(groups) != len(weights):\n",
    "        raise ValueError(\"weights must have the same length as groups.\")\n",
    "\n",
    "    # Theorem 1 reduction + sign handling\n",
    "    sgn = np.sign(v)\n",
    "    u = np.maximum(np.abs(v) - lambda1, 0.0)\n",
    "\n",
    "    if not np.any(u):\n",
    "        return np.zeros_like(v), {'duality_gap': []}\n",
    "\n",
    "    if lambda2 == 0.0 or len(groups) == 0:\n",
    "        return sgn * u, {'duality_gap': []}\n",
    "\n",
    "    # Lemma 3 screening\n",
    "    u_red, active_groups, active_idx, active_w = preprocess_zero_groups(\n",
    "        u, groups, lambda2, weights, max_iter=screen_max_iter, tol=0.0\n",
    "    )\n",
    "\n",
    "    if active_idx.size == 0:\n",
    "        return np.zeros_like(v), {'duality_gap': []}\n",
    "\n",
    "    # Dual solve on reduced problem\n",
    "    x_red, _, hist = solve_dual_problem_sparse(\n",
    "        u_red, active_groups, lambda2, active_w,\n",
    "        max_iter=prox_max_iter, gap_tol=prox_gap_tol, check_every=1, verbose=verbose\n",
    "    )\n",
    "\n",
    "    # Reconstruct full solution and restore sign\n",
    "    x_abs = np.zeros_like(v, dtype=u.dtype)\n",
    "    x_abs[active_idx] = x_red\n",
    "\n",
    "    return sgn * x_abs, hist\n",
    "\n",
    "\n",
    "# ============================================================================\n",
    "# Algorithm 2: FoGLasso (Dual Solver with Analytical Proximal Operators) (Implemented exactly as in the 2011 NeurIPS paper)\n",
    "# ============================================================================\n",
    "\n",
    "def foglasso(X, y, groups, lambda1, lambda2, weights=None,\n",
    "             L0=1.0,\n",
    "             max_iter=1000,\n",
    "             obj_tol=1e-5,\n",
    "             prox_gap_tol=1e-5,\n",
    "             prox_max_iter=1000,\n",
    "             verbose=True):\n",
    "    \"\"\"\n",
    "    Paper-faithful FoGLasso for least squares:\n",
    "\n",
    "        min_x  (1/2)||Xx - y||^2 + λ1||x||_1 + λ2 Σ_i w_i ||x_{G_i}||_2\n",
    "\n",
    "    Uses backtracking line search and dual solver for proximal operator.\n",
    "    Terminates when adjacent objective change <= obj_tol.\n",
    "    \"\"\"\n",
    "    X = np.asarray(X)\n",
    "    y = np.asarray(y)\n",
    "    n, p = X.shape\n",
    "\n",
    "    groups = [np.asarray(g, dtype=np.int64) for g in groups]\n",
    "    if weights is None:\n",
    "        weights = np.array([np.sqrt(len(g)) for g in groups], dtype=float)\n",
    "    else:\n",
    "        weights = np.asarray(weights, dtype=float)\n",
    "\n",
    "    # Initialize\n",
    "    x_prev = np.zeros(p, dtype=float)\n",
    "    x = x_prev.copy()\n",
    "\n",
    "    alpha_old = 0.0\n",
    "    alpha = 1.0\n",
    "\n",
    "    L_prev = float(L0)\n",
    "    if L_prev <= 0:\n",
    "        raise ValueError(\"L0 must be > 0.\")\n",
    "\n",
    "    hist = {\n",
    "        \"objective\": [],\n",
    "        \"loss\": [],\n",
    "        \"penalty\": [],\n",
    "        \"L\": [],\n",
    "    }\n",
    "\n",
    "    def penalty_val(z):\n",
    "        val = lambda1 * np.sum(np.abs(z))\n",
    "        for i, Gi in enumerate(groups):\n",
    "            val += lambda2 * weights[i] * np.linalg.norm(z[Gi])\n",
    "        return float(val)\n",
    "\n",
    "    def loss_and_grad(z):\n",
    "        r = X @ z - y\n",
    "        loss = 0.5 * float(np.dot(r, r))\n",
    "        grad = X.T @ r\n",
    "        return loss, grad\n",
    "\n",
    "    # Record initial objective\n",
    "    loss0, _ = loss_and_grad(x)\n",
    "    obj0 = loss0 + penalty_val(x)\n",
    "    hist[\"objective\"].append(obj0)\n",
    "    hist[\"loss\"].append(loss0)\n",
    "    hist[\"penalty\"].append(obj0 - loss0)\n",
    "    hist[\"L\"].append(L_prev)\n",
    "\n",
    "    for it in range(1, max_iter + 1):\n",
    "        # Step 3: s_i = x_i + β_i (x_i - x_{i-1})\n",
    "        beta = (alpha_old - 1.0) / alpha\n",
    "        s = x + beta * (x - x_prev)\n",
    "\n",
    "        # Gradient at s_i\n",
    "        loss_s, grad_s = loss_and_grad(s)\n",
    "\n",
    "        # Step 4: backtracking line search\n",
    "        L_trial = L_prev\n",
    "\n",
    "        while True:\n",
    "            v = s - (1.0 / L_trial) * grad_s\n",
    "\n",
    "            x_new, prox_hist = compute_proximal_operator_foglasso(\n",
    "                v,\n",
    "                groups,\n",
    "                lambda1=lambda1 / L_trial,\n",
    "                lambda2=lambda2 / L_trial,\n",
    "                weights=weights,\n",
    "                prox_max_iter=prox_max_iter,\n",
    "                prox_gap_tol=prox_gap_tol,\n",
    "                verbose=False\n",
    "            )\n",
    "\n",
    "            # Check majorization\n",
    "            loss_new, _ = loss_and_grad(x_new)\n",
    "            dx = x_new - s\n",
    "            rhs = loss_s + float(np.dot(grad_s, dx)) + 0.5 * L_trial * float(np.dot(dx, dx))\n",
    "\n",
    "            if loss_new <= rhs:\n",
    "                break\n",
    "\n",
    "            L_trial *= 2.0\n",
    "\n",
    "        L_prev = L_trial\n",
    "\n",
    "        # Step 5: α_{i+1} = (1 + sqrt(1 + 4 α_i^2))/2\n",
    "        alpha_new = (1.0 + np.sqrt(1.0 + 4.0 * alpha * alpha)) / 2.0\n",
    "\n",
    "        # Objective bookkeeping\n",
    "        pen = penalty_val(x_new)\n",
    "        obj = loss_new + pen\n",
    "\n",
    "        hist[\"objective\"].append(obj)\n",
    "        hist[\"loss\"].append(loss_new)\n",
    "        hist[\"penalty\"].append(pen)\n",
    "        hist[\"L\"].append(L_prev)\n",
    "\n",
    "        if verbose and (it % 10 == 0 or it == 1):\n",
    "            obj_change = abs(hist[\"objective\"][-1] - hist[\"objective\"][-2])\n",
    "            print(f\"Iter {it:4d}: Obj={obj:.6f}, Loss={loss_new:.6e}, \"\n",
    "                  f\"Pen={pen:.6f}, |ΔObj|={obj_change:.2e}, L={L_prev:.2e}\")\n",
    "\n",
    "        # Stopping criterion\n",
    "        if abs(hist[\"objective\"][-1] - hist[\"objective\"][-2]) <= obj_tol:\n",
    "            if verbose:\n",
    "                print(f\"Converged at iter {it} (|ΔObj| <= {obj_tol:g}).\")\n",
    "            x = x_new\n",
    "            break\n",
    "\n",
    "        x_prev = x\n",
    "        x = x_new\n",
    "        alpha_old = alpha\n",
    "        alpha = alpha_new\n",
    "\n",
    "    return x, hist\n",
    "\n",
    "\n",
    "# Printed last, so seeing this line confirms every definition above executed\n",
    "print(\"✓ All algorithms and helper functions loaded successfully\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# ============================================================================\n",
    "# CHUNK 2: DATA LOADING\n",
    "# ============================================================================\n",
    "\n",
    "def get_pathway_groups(gene_names, min_size=5, max_size=200, library='KEGG_2021_Human'):\n",
    "    \"\"\"\n",
    "    Build gene-index groups from a gseapy gene-set library (KEGG by default).\n",
    "\n",
    "    Every pathway is mapped to the sorted, de-duplicated positions of its genes\n",
    "    within `gene_names`; pathways whose mapped size falls outside\n",
    "    [min_size, max_size] are dropped.\n",
    "\n",
    "    Args:\n",
    "        gene_names: sequence of gene symbols; position i becomes feature index i.\n",
    "        min_size: smallest number of matched genes a pathway must have.\n",
    "        max_size: largest number of matched genes a pathway may have.\n",
    "        library: library name passed to `gp.get_library`.\n",
    "\n",
    "    Returns:\n",
    "        (groups, pathway_names): parallel lists holding one integer index array\n",
    "        and one pathway name per retained pathway.\n",
    "    \"\"\"\n",
    "    pathways = gp.get_library(name=library)\n",
    "    gene_to_idx = {gene: i for i, gene in enumerate(gene_names)}\n",
    "\n",
    "    groups = []\n",
    "    pathway_names = []\n",
    "\n",
    "    for pathway_name, pathway_genes in pathways.items():\n",
    "        matched = [gene_to_idx[g] for g in pathway_genes if g in gene_to_idx]\n",
    "        # np.unique de-duplicates and sorts in one pass; the explicit dtype keeps\n",
    "        # an empty match list from turning into a float array.\n",
    "        idx = np.unique(np.array(matched, dtype=int))\n",
    "\n",
    "        if min_size <= idx.size <= max_size:\n",
    "            groups.append(idx)\n",
    "            pathway_names.append(pathway_name)\n",
    "\n",
    "    return groups, pathway_names\n",
    "\n",
    "\n",
    "def load_real_gse2034_data_FIXED():\n",
    "    \"\"\"\n",
    "    Download GSE2034 and build a standardized gene-level expression matrix.\n",
    "\n",
    "    Steps: read the binary relapse label (0/1) from each sample's\n",
    "    'characteristics_ch1' metadata, pivot the 'VALUE' column into a\n",
    "    probe x sample table, map probes to gene symbols using the first platform\n",
    "    attached to the series, average probes that share a symbol, apply\n",
    "    log2(x + 1) when the maximum value exceeds 100, and z-score every gene.\n",
    "\n",
    "    Returns:\n",
    "        (X, y, gene_names): X is a samples x genes array, y holds the integer\n",
    "        labels in the same sample order, gene_names lists the gene symbol of\n",
    "        every column of X.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if no labels were found or only one class is present.\n",
    "    \"\"\"\n",
    "    print(\"Downloading GSE2034 (Breast Cancer)...\")\n",
    "    gse = GEOparse.get_GEO(geo=\"GSE2034\", destdir=\"./data\")\n",
    "\n",
    "    print(\"Extracting clinical labels from metadata...\")\n",
    "    y_labels = []\n",
    "    sample_names = []\n",
    "\n",
    "    for gsm_name, gsm in gse.gsms.items():\n",
    "        relapse_status = None\n",
    "        for char in gsm.metadata.get('characteristics_ch1', []):\n",
    "            char_lower = str(char).lower()\n",
    "            if 'relapse' in char_lower and ':' in char_lower:\n",
    "                value = char_lower.split(':')[-1].strip()\n",
    "                if value in ['0', '1']:\n",
    "                    relapse_status = int(value)\n",
    "                    break\n",
    "\n",
    "        # Samples without a parsable label are left out of both y and X\n",
    "        if relapse_status is not None:\n",
    "            y_labels.append(relapse_status)\n",
    "            sample_names.append(gsm_name)\n",
    "\n",
    "    # Explicit integer dtype: an empty list would otherwise become a float\n",
    "    # array, which np.bincount rejects with a TypeError.\n",
    "    y = np.array(y_labels, dtype=int)\n",
    "\n",
    "    # Validate before reporting statistics so a failed extraction surfaces as\n",
    "    # the intended ValueError.\n",
    "    n_classes = len(np.unique(y))\n",
    "    if len(y) == 0 or n_classes < 2:\n",
    "        raise ValueError(f\"Label extraction failed! Got {len(y)} samples with {n_classes} classes\")\n",
    "\n",
    "    print(f\"Extracted labels for {len(y)} samples\")\n",
    "    print(f\"Class distribution: {np.bincount(y)} (0=No Relapse, 1=Relapse)\")\n",
    "\n",
    "    # Build expression matrix (columns restricted to the labelled samples, in label order)\n",
    "    print(\"Building expression matrix...\")\n",
    "    pivoted_data = gse.pivot_samples('VALUE')\n",
    "    pivoted_data = pivoted_data[sample_names]\n",
    "\n",
    "    # Map probes to genes\n",
    "    print(\"Mapping probes to gene symbols...\")\n",
    "    platform_name = list(gse.gpls.keys())[0]\n",
    "    gpl = gse.gpls[platform_name]\n",
    "\n",
    "    probe_to_gene = {}\n",
    "    for _, row in gpl.table.iterrows():\n",
    "        sym = None\n",
    "        for col in ['Gene Symbol', 'GENE_SYMBOL', 'Symbol']:\n",
    "            if col in row and pd.notna(row[col]) and row[col] != '':\n",
    "                # Keep only the first symbol of a '///'-separated list\n",
    "                sym = str(row[col]).split('///')[0].strip()\n",
    "                break\n",
    "        if sym and sym != '---':\n",
    "            probe_to_gene[row['ID']] = sym\n",
    "\n",
    "    # Attach symbols, drop unmapped probes, and average probes sharing a gene\n",
    "    gene_expression_df = (\n",
    "        pivoted_data\n",
    "        .assign(Gene=pivoted_data.index.map(probe_to_gene))\n",
    "        .dropna(subset=['Gene'])\n",
    "        .groupby('Gene')\n",
    "        .mean()\n",
    "    )\n",
    "\n",
    "    X = gene_expression_df.T.values\n",
    "    gene_names = gene_expression_df.index.tolist()\n",
    "\n",
    "    # Log transform if needed\n",
    "    if np.max(X) > 100:\n",
    "        print(\"Applying Log2(x+1) transform...\")\n",
    "        X = np.log2(X + 1)\n",
    "\n",
    "    X = StandardScaler().fit_transform(X)\n",
    "\n",
    "    print(f\"Final: {X.shape[0]} samples x {X.shape[1]} genes\")\n",
    "    print(f\"Verification: y has {np.sum(y==0)} non-relapse, {np.sum(y==1)} relapse\")\n",
    "\n",
    "    return X, y, gene_names\n",
    "\n",
    "\n",
    "def load_gse2034_filtered(top_k=1000, method='ttest'):\n",
    "    \"\"\"Load GSE2034 with feature selection.\"\"\"\n",
    "    X, y, gene_names = load_real_gse2034_data_FIXED()\n",
    "    \n",
    "    if method == 'variance':\n",
    "        variances = np.var(X, axis=0)\n",
    "        top_indices = np.argsort(variances)[-top_k:]\n",
    "    elif method == 'ttest':\n",
    "        from scipy.stats import ttest_ind\n",
    "        t_stats = []\n",
    "        for i in range(X.shape[1]):\n",
    "            t_stat, _ = ttest_ind(X[y==0, i], X[y==1, i])\n",
    "            t_stats.append(abs(t_stat))\n",
    "        top_indices = np.argsort(t_stats)[-top_k:]\n",
    "    else:\n",
    "        raise ValueError(f\"Unknown method: {method}\")\n",
    "    \n",
    "    X_filtered = X[:, top_indices]\n",
    "    gene_names_filtered = [gene_names[i] for i in top_indices]\n",
    "    \n",
    "    print(f\"Reduced using {method}: {X.shape[1]} → {X_filtered.shape[1]} genes\")\n",
    "    return X_filtered, y, gene_names_filtered\n",
    "\n",
    "\n",
    "print(\"\\n\" + \"=\"*80)\n",
    "print(\"Loading GSE2034 Breast Cancer Data...\")\n",
    "print(\"=\"*80)\n",
    "\n",
    "X, y, gene_names = load_gse2034_filtered(13237)\n",
    "groups, pathway_names = get_pathway_groups(gene_names)\n",
    "\n",
    "# Transform Y for numerical stability\n",
    "y = y * 2 - 1\n",
    "\n",
    "print(f\"\\n✓ Data loaded: {X.shape[0]} samples, {X.shape[1]} genes\")\n",
    "print(f\"✓ Pathways: {len(groups)} overlapping groups\")\n",
    "print(f\"✓ Response transformed: y ∈ {{-1, +1}}\")\n",
    "\n",
    "# Set parameters\n",
    "lam_l1 = 0.005\n",
    "lam_g = 0.0001\n",
    "\n",
    "print(f\"✓ Penalty parameters: λ1={lam_l1}, λ2={lam_g}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# ============================================================================\n",
    "# CHUNK 3: RUN ALGORITHM 1 - FoGLasso\n",
    "# ============================================================================\n",
    "\n",
    "print(\"\\n\" + \"=\"*80)\n",
    "print(\"Running Algorithm 1: FoGLasso (Dual Solver)...\")\n",
    "print(\"=\"*80)\n",
    "\n",
    "start_time = time.time()\n",
    "\n",
    "beta_FoGLasso, hist_FoGLasso = foglasso(\n",
    "    X, y, groups, lambda1=lam_l1, lambda2=lam_g, L0=1.0,\n",
    "    max_iter=10000, verbose=True, obj_tol=1e-5\n",
    ")\n",
    "\n",
    "elapsed_time = time.time() - start_time\n",
    "\n",
    "print(f\"\\n✓ FoGLasso completed\")\n",
    "print(f\"  - Runtime: {elapsed_time:.2f} seconds\")\n",
    "print(f\"  - Converged in {len(hist_FoGLasso['objective'])-1} iterations\")\n",
    "print(f\"  - Final objective: {hist_FoGLasso['objective'][-1]:.6f}\")\n",
    "print(f\"  - Final penalty: {hist_FoGLasso['penalty'][-1]:.6f}\")\n",
    "print(f\"  - Final loss: {hist_FoGLasso['loss'][-1]:.6f}\")\n",
    "print(f\"  - ||β||: {np.linalg.norm(beta_FoGLasso):.4f}\")\n",
    "print(f\"  - Non-zero coefficients: {np.count_nonzero(beta_FoGLasso)}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# ============================================================================\n",
    "# CHUNK 4: RUN ALGORITHM 2 - Davis-Yin with HJ-Prox\n",
    "# ============================================================================\n",
    "\n",
    "print(\"\\n\" + \"=\"*80)\n",
    "print(\"Running Algorithm 2: Davis-Yin with HJ-Prox...\")\n",
    "print(\"=\"*80)\n",
    "\n",
    "start_time = time.time()\n",
    "\n",
    "beta_HJ, hist_HJ = davis_yin_adaptive_overlapping_group_lasso(\n",
    "    X, y, groups,\n",
    "    lambda1=lam_l1,\n",
    "    lambda2=lam_g,\n",
    "    delta=10,\n",
    "    int_samples=1000,\n",
    "    max_iter=100000,\n",
    "    verbose=True,\n",
    "    tol=1e-1000\n",
    ")\n",
    "\n",
    "elapsed_time = time.time() - start_time\n",
    "\n",
    "print(f\"\\n✓ Davis-Yin HJ completed\")\n",
    "print(f\"  - Runtime: {elapsed_time:.2f} seconds\")\n",
    "print(f\"  - Converged in {len(hist_HJ['objective_z'])} iterations\")\n",
    "print(f\"  - Final objective: {hist_HJ['objective_z'][-1]:.6f}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# ============================================================================\n",
    "# CHUNK 5: ANALYSIS - Active Pathway Identification\n",
    "# ============================================================================\n",
    "\n",
    "def analyze_active_pathways(beta_fog, groups, pathway_names, gene_names, top_n=20, active_tol=1e-6):\n",
    "    \"\"\"\n",
    "    Identify and analyze the most active pathways in results.\n",
    "    \n",
    "    Parameters:\n",
    "    -----------\n",
    "    beta_fog : array - Solution coefficients\n",
    "    groups : list of arrays - group indices for each pathway\n",
    "    pathway_names : list of str - pathway names corresponding to groups\n",
    "    gene_names : list of str - gene symbols\n",
    "    top_n : int - number of top pathways to display\n",
    "    active_tol : float - tolerance for considering a group active\n",
    "    \n",
    "    Returns:\n",
    "    --------\n",
    "    results_df : DataFrame with pathway analysis\n",
    "    \"\"\"\n",
    "    print(\"=\"*80)\n",
    "    print(\"ANALYZING ACTIVE PATHWAYS\")\n",
    "    print(\"=\"*80)\n",
    "    \n",
    "    # Compute group norms\n",
    "    group_norms = np.array([np.linalg.norm(beta_fog[np.asarray(Gi, dtype=int)]) for Gi in groups])\n",
    "    \n",
    "    # Compute pathway sizes\n",
    "    pathway_sizes = np.array([len(Gi) for Gi in groups])\n",
    "    \n",
    "    # Compute normalized group norms (SIZE-ADJUSTED)\n",
    "    normalized_group_norms = group_norms / np.sqrt(pathway_sizes)\n",
    "    \n",
    "    # Overall statistics\n",
    "    n_active = np.sum(group_norms > active_tol)\n",
    "    print(f\"\\nOverall Statistics:\")\n",
    "    print(f\"  Total pathways: {len(groups)}\")\n",
    "    print(f\"  Active pathways: {n_active} ({100*n_active/len(groups):.1f}%)\")\n",
    "    print(f\"  Inactive pathways: {len(groups) - n_active}\")\n",
    "    \n",
    "    print(f\"\\nGroup norm distribution:\")\n",
    "    print(f\"  Min: {group_norms.min():.3e}\")\n",
    "    print(f\"  Q1: {np.percentile(group_norms, 25):.3e}\")\n",
    "    print(f\"  Median: {np.median(group_norms):.3e}\")\n",
    "    print(f\"  Q3: {np.percentile(group_norms, 75):.3e}\")\n",
    "    print(f\"  Max: {group_norms.max():.3e}\")\n",
    "    \n",
    "    # Build results table\n",
    "    results = []\n",
    "    for i, Gi in enumerate(groups):\n",
    "        pathway_name = pathway_names[i]\n",
    "        group_norm = group_norms[i]\n",
    "        normalized_norm = normalized_group_norms[i]\n",
    "        pathway_size = pathway_sizes[i]\n",
    "        \n",
    "        genes_in_pathway = [gene_names[idx] for idx in Gi]\n",
    "        coeffs_in_pathway = beta_fog[Gi]\n",
    "        \n",
    "        n_nonzero = np.sum(np.abs(coeffs_in_pathway) > 1e-10)\n",
    "        avg_abs_coeff = np.mean(np.abs(coeffs_in_pathway))\n",
    "        max_abs_coeff = np.max(np.abs(coeffs_in_pathway))\n",
    "        \n",
    "        results.append({\n",
    "            'pathway_index': i,\n",
    "            'pathway_name': pathway_name,\n",
    "            'group_norm': group_norm,\n",
    "            'normalized_group_norm': normalized_norm,\n",
    "            'pathway_size': pathway_size,\n",
    "            'n_nonzero_genes': n_nonzero,\n",
    "            'pct_genes_active': 100 * n_nonzero / pathway_size,\n",
    "            'avg_abs_coeff': avg_abs_coeff,\n",
    "            'max_abs_coeff': max_abs_coeff,\n",
    "            'genes': genes_in_pathway,\n",
    "            'coefficients': coeffs_in_pathway\n",
    "        })\n",
    "    \n",
    "    results_df = pd.DataFrame(results)\n",
    "    \n",
    "    # Display rankings\n",
    "    print(f\"\\n{'='*80}\")\n",
    "    print(f\"TOP {top_n} PATHWAYS BY RAW GROUP NORM (||β_G||)\")\n",
    "    print(f\"{'='*80}\")\n",
    "    print(\"⚠️  WARNING: This ranking is BIASED by pathway size!\")\n",
    "    print(\"-\"*80)\n",
    "    \n",
    "    results_by_raw = results_df.sort_values('group_norm', ascending=False)\n",
    "    print(f\"\\n{'Rank':<6} {'Pathway':<60} {'||β_G||':<12}\")\n",
    "    print(\"-\"*80)\n",
    "    for rank, (idx, row) in enumerate(results_by_raw.head(top_n).iterrows(), 1):\n",
    "        pathway_short = row['pathway_name'][:58] + \"..\" if len(row['pathway_name']) > 60 else row['pathway_name']\n",
    "        print(f\"{rank:<6} {pathway_short:<60} {row['group_norm']:<12.4e}\")\n",
    "    \n",
    "    print(f\"\\n{'='*80}\")\n",
    "    print(f\"TOP {top_n} PATHWAYS BY NORMALIZED GROUP NORM (||β_G|| / √|G|)\")\n",
    "    print(f\"{'='*80}\")\n",
    "    print(\"✅ RECOMMENDED: This ranking is SIZE-ADJUSTED and biologically meaningful!\")\n",
    "    print(\"-\"*80)\n",
    "    \n",
    "    results_by_normalized = results_df.sort_values('normalized_group_norm', ascending=False)\n",
    "    print(f\"\\n{'Rank':<6} {'Pathway':<60} {'Norm/√Size':<12}\")\n",
    "    print(\"-\"*80)\n",
    "    for rank, (idx, row) in enumerate(results_by_normalized.head(top_n).iterrows(), 1):\n",
    "        pathway_short = row['pathway_name'][:58] + \"..\" if len(row['pathway_name']) > 60 else row['pathway_name']\n",
    "        print(f\"{rank:<6} {pathway_short:<60} {row['normalized_group_norm']:<12.4e}\")\n",
    "    \n",
    "    return results_by_normalized\n",
    "\n",
    "\n",
    "print(\"\\n\" + \"=\"*80)\n",
    "print(\"Analyzing FoGLasso Results...\")\n",
    "print(\"=\"*80)\n",
    "\n",
    "results_FoGLasso = analyze_active_pathways(\n",
    "    beta_fog=beta_FoGLasso,\n",
    "    groups=groups,\n",
    "    pathway_names=pathway_names,\n",
    "    gene_names=gene_names,\n",
    "    top_n=3\n",
    ")\n",
    "\n",
    "print(\"\\n\" + \"=\"*80)\n",
    "print(\"Analyzing Davis-Yin HJ Results...\")\n",
    "print(\"=\"*80)\n",
    "\n",
    "results_HJ = analyze_active_pathways(\n",
    "    beta_fog=beta_HJ,\n",
    "    groups=groups,\n",
    "    pathway_names=pathway_names,\n",
    "    gene_names=gene_names,\n",
    "    top_n=3\n",
    ")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# ============================================================================\n",
    "# CHUNK 6: GENERATE FIGURES\n",
    "# ============================================================================\n",
    "\n",
    "print(\"\\n\" + \"=\"*80)\n",
    "print(\"Generating figures...\")\n",
    "print(\"=\"*80)\n",
    "\n",
    "\n",
    "def correlation_scatter(\n",
    "    beta1,\n",
    "    beta2,\n",
    "    labels=('HJ-Prox', 'FoGLasso'),\n",
    "    figsize=(8, 8),\n",
    "    diagonal_color='#B71C1C',\n",
    "    diagonal_style='--',\n",
    "    diagonal_width=2.5,\n",
    "    point_size=20,\n",
    "    point_alpha=0.5,\n",
    "    show_legend=True,\n",
    "    fontsize_axes=14,\n",
    "    fontsize_title=16,\n",
    "    fontsize_corr=12\n",
    "):\n",
    "    \"\"\"\n",
    "    Draw coefficient correlation scatter plot.\n",
    "\n",
    "    Returns:\n",
    "        fig: matplotlib.figure.Figure\n",
    "    \"\"\"\n",
    "    b1 = np.asarray(beta1).ravel()\n",
    "    b2 = np.asarray(beta2).ravel()\n",
    "    assert b1.shape == b2.shape, \"beta1 and beta2 must have same shape\"\n",
    "\n",
    "    pearson_corr = np.corrcoef(b1, b2)[0, 1]\n",
    "\n",
    "    fig, ax = plt.subplots(figsize=figsize)\n",
    "\n",
    "    ax.scatter(b1, b2, s=point_size, alpha=point_alpha, edgecolors='none')\n",
    "\n",
    "    # Diagonal (perfect agreement)\n",
    "    lim_min = min(b1.min(), b2.min())\n",
    "    lim_max = max(b1.max(), b2.max())\n",
    "    ax.plot(\n",
    "        [lim_min, lim_max],\n",
    "        [lim_min, lim_max],\n",
    "        linestyle=diagonal_style,\n",
    "        linewidth=diagonal_width,\n",
    "        color=diagonal_color,\n",
    "        alpha=0.9,\n",
    "        label='Perfect agreement' if show_legend else None\n",
    "    )\n",
    "\n",
    "    ax.set_xlabel(f'{labels[0]} coefficients', fontsize=fontsize_axes, fontweight='bold')\n",
    "    ax.set_ylabel(f'{labels[1]} coefficients', fontsize=fontsize_axes, fontweight='bold')\n",
    "    ax.set_title('Coefficient Comparison', fontsize=fontsize_title, fontweight='bold')\n",
    "\n",
    "    # Pearson correlation text\n",
    "    ax.text(\n",
    "        0.02, 0.98,\n",
    "        f'Pearson r = {pearson_corr:.5f}',\n",
    "        transform=ax.transAxes,\n",
    "        ha='left', va='top',\n",
    "        fontsize=fontsize_corr, fontweight='bold',\n",
    "        bbox=dict(boxstyle='round,pad=0.3', alpha=0.15)\n",
    "    )\n",
    "\n",
    "    ax.axhline(0, color='k', lw=0.8, alpha=0.3)\n",
    "    ax.axvline(0, color='k', lw=0.8, alpha=0.3)\n",
    "    ax.grid(True, alpha=0.3)\n",
    "    ax.set_aspect('equal', adjustable='box')\n",
    "\n",
    "    if show_legend:\n",
    "        ax.legend()\n",
    "\n",
    "    fig.tight_layout()\n",
    "    return fig\n",
    "\n",
    "\n",
    "# --- Figure 1: Coefficient Correlation ---\n",
    "fig = correlation_scatter(beta_HJ, beta_FoGLasso,\n",
    "                          labels=('Davis-Yin (HJ-Prox)', 'FoGLasso (Dual)'),\n",
    "                          figsize=(8, 8))\n",
    "fig.savefig('coeff_correlation.pdf', bbox_inches='tight')\n",
    "plt.show()\n",
    "plt.close()\n",
    "\n",
    "# --- Figure 2: Objective Convergence ---\n",
    "limit = 1000\n",
    "obj_FoGLasso = hist_FoGLasso[\"objective\"][:limit]\n",
    "obj_HJ = hist_HJ[\"objective_z\"][:limit]\n",
    "\n",
    "plt.figure(figsize=(10, 8))\n",
    "\n",
    "plt.semilogy(obj_FoGLasso, '-', linewidth=3,\n",
    "             label=f'FOGLASSO: {obj_FoGLasso[-1]:.3f}')\n",
    "plt.semilogy(obj_HJ, '--', linewidth=3,\n",
    "             label=f'DYS-HJ: {obj_HJ[-1]:.3f}')\n",
    "\n",
    "plt.ylabel('Objective Value (log scale)', fontsize=30)\n",
    "plt.xlabel('Iteration', fontsize=40)\n",
    "plt.title('Overlapping Group LASSO', fontsize=40)\n",
    "plt.legend(fontsize=35, loc='upper right')\n",
    "plt.grid(True, alpha=0.3, which='both')\n",
    "plt.tick_params(axis='both', which='major', labelsize=40)\n",
    "plt.tight_layout()\n",
    "plt.savefig('OGLASSO_objective_convergence.pdf', format='pdf', dpi=600, bbox_inches='tight')\n",
    "plt.show()\n",
    "plt.close()\n",
    "\n",
    "print(\"✓ All figures saved: coeff_correlation.pdf, OGLASSO_objective_convergence.pdf\")\n",
    "print(\"\\n\" + \"=\"*80)\n",
    "print(\"✓ ALL EXPERIMENTS COMPLETED SUCCESSFULLY\")\n",
    "print(\"=\"*80)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "base",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.10"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
