{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# ============================================================================\n",
    "# CHUNK 1: SETUP - Algorithms, Helper Functions, and Definitions\n",
    "# ============================================================================\n",
    "\n",
    "import numpy as np\n",
    "import torch\n",
    "import matplotlib.pyplot as plt\n",
    "from hj_prox import compute_prox_HJ\n",
    "\n",
    "# Global matplotlib defaults: render all figure text at a 20 pt base size.\n",
    "plt.rcParams['font.size'] = 20\n",
    "\n",
    "\n",
    "# ============================================================================\n",
    "# Helper Functions\n",
    "# ============================================================================\n",
    "\n",
    "def soft_threshold(v, tau):\n",
    "    \"\"\"Elementwise soft-thresholding: S_tau(v) = sign(v) * max(|v| - tau, 0).\"\"\"\n",
    "    shrunk_magnitude = np.maximum(np.abs(v) - tau, 0.0)\n",
    "    return np.sign(v) * shrunk_magnitude\n",
    "\n",
    "\n",
    "def objective_nonneg_lasso(X, y, beta, lam):\n",
    "    \"\"\"Evaluate 0.5 * ||y - X beta||^2 + lam * ||beta||_1 and return it as a float.\n",
    "\n",
    "    Only the least-squares and L1 terms are evaluated; feasibility (beta >= 0)\n",
    "    is not checked here.\n",
    "    \"\"\"\n",
    "    residual = y - X @ beta\n",
    "    data_fit = 0.5 * float(residual @ residual)\n",
    "    l1_term = lam * float(np.sum(np.abs(beta)))\n",
    "    return data_fit + l1_term\n",
    "\n",
    "\n",
    "def l1_penalty_batch(beta_batch, lam):\n",
    "    \"\"\"Row-wise L1 penalty lam * ||beta_i||_1 for a batch of coefficient vectors.\n",
    "\n",
    "    Args:\n",
    "        beta_batch: torch tensor of shape (n_samples, n_features), one vector per row\n",
    "        lam: L1 penalty weight\n",
    "    Returns:\n",
    "        Tensor of shape (n_samples,) holding one penalty value per row\n",
    "    \"\"\"\n",
    "    row_l1_norms = beta_batch.abs().sum(dim=1)\n",
    "    return lam * row_l1_norms\n",
    "\n",
    "\n",
    "def estimate_L_xtx(X, iters=25, seed=0):\n",
    "    \"\"\"\n",
    "    Estimate ||X^T X||_2 = sigma_max(X)^2 by power iteration (no SVD).\n",
    "\n",
    "    A seeded Gaussian start vector is repeatedly multiplied by X^T X and\n",
    "    renormalised `iters` times; the Rayleigh quotient v^T X^T X v of the final\n",
    "    unit vector is returned as a Python float.\n",
    "    \"\"\"\n",
    "    def _apply_gram(vec):\n",
    "        # Multiply by X^T X without ever forming the p x p Gram matrix.\n",
    "        return X.T @ (X @ vec)\n",
    "\n",
    "    def _unit(vec):\n",
    "        # The 1e-12 offset guards against division by zero for a null vector.\n",
    "        return vec / (np.linalg.norm(vec) + 1e-12)\n",
    "\n",
    "    direction = _unit(np.random.default_rng(seed).normal(size=X.shape[1]))\n",
    "    for _ in range(iters):\n",
    "        direction = _unit(_apply_gram(direction))\n",
    "    return float(direction @ _apply_gram(direction))\n",
    "\n",
    "\n",
    "def compute_residual_gap(z, X, y, lam, gamma, return_norm=False):\n",
    "    \"\"\"\n",
    "    Analytical Davis-Yin residual gap (y^k - x^k) evaluated at the auxiliary point z^k.\n",
    "\n",
    "    Problem:\n",
    "        min_beta 0.5||y - X beta||^2 + lam ||beta||_1 + I_{beta>=0}(beta)\n",
    "\n",
    "    One pass of the splitting is evaluated in closed form:\n",
    "        x^k = max(z^k, 0)                          (projection onto beta >= 0)\n",
    "        u^k = 2 x^k - z^k - gamma * X^T (X x^k - y)\n",
    "        y^k = soft_threshold(u^k, gamma * lam)     (prox of the L1 term)\n",
    "\n",
    "    Parameters:\n",
    "    -----------\n",
    "    z : array_like, shape (p,)\n",
    "        Auxiliary variable z^k\n",
    "    X : array_like, shape (n, p)\n",
    "        Design matrix\n",
    "    y : array_like, shape (n,)\n",
    "        Response vector\n",
    "    lam : float\n",
    "        L1 regularization weight\n",
    "    gamma : float\n",
    "        Step size\n",
    "    return_norm : bool, optional (default=False)\n",
    "        When True, return ||y^k - x^k|| instead of the vector itself\n",
    "\n",
    "    Returns:\n",
    "    --------\n",
    "    residual_gap : array_like, shape (p,) or float\n",
    "        The gap y^k - x^k, or its Euclidean norm if return_norm=True\n",
    "    \"\"\"\n",
    "    projected = np.maximum(z, 0.0)                          # x^k\n",
    "    smooth_grad = X.T @ (X @ projected - y)                # grad f(x^k)\n",
    "    reflected = 2.0 * projected - z - gamma * smooth_grad  # u^k\n",
    "    gap = soft_threshold(reflected, gamma * lam) - projected\n",
    "    return np.linalg.norm(gap) if return_norm else gap\n",
    "\n",
    "\n",
    "def compute_residual_gap_swapped(z, X, y, lam, gamma, return_norm=False):\n",
    "    \"\"\"\n",
    "    Analytical residual gap (y^k - x^k) for Davis-Yin with the two prox steps SWAPPED.\n",
    "\n",
    "    Compared with compute_residual_gap, the L1 prox is applied first and the\n",
    "    projection last:\n",
    "        x^k = soft_threshold(z^k, gamma*lam)\n",
    "        u^k = 2x^k - z^k - gamma * grad f(x^k)\n",
    "        y^k = max(u^k, 0)\n",
    "        residual_gap = y^k - x^k\n",
    "\n",
    "    Returns the gap vector, or its Euclidean norm when return_norm=True.\n",
    "    \"\"\"\n",
    "    thresholded = soft_threshold(z, gamma * lam)              # x^k\n",
    "    smooth_grad = X.T @ (X @ thresholded - y)                # grad f(x^k)\n",
    "    reflected = 2.0 * thresholded - z - gamma * smooth_grad  # u^k\n",
    "    gap = np.maximum(reflected, 0.0) - thresholded           # y^k - x^k\n",
    "    return np.linalg.norm(gap) if return_norm else gap\n",
    "\n",
    "\n",
    "def compute_prox_projection(z, gamma, delta, int_samples=1000, device='cpu'):\n",
    "    \"\"\"\n",
    "    Monte Carlo (HJ-Prox style) estimate of the projection onto the non-negative orthant.\n",
    "\n",
    "    Draws int_samples points y = z + sqrt(gamma*delta) * N(0, I) and returns, per\n",
    "    coordinate,\n",
    "        (1/N) * sum_i y_i * I(y_i >= 0)\n",
    "    i.e. infeasible draws contribute zero and the sum is divided by the TOTAL\n",
    "    number of draws, not by the number of feasible ones.\n",
    "\n",
    "    The noise tensor is created on `device`.\n",
    "    \"\"\"\n",
    "    std = np.sqrt(gamma * delta)\n",
    "    draws = z.view(1, -1) + std * torch.randn(int_samples, z.shape[0], device=device)\n",
    "\n",
    "    # 1.0 where the draw is feasible (>= 0), 0.0 elsewhere.\n",
    "    feasible_mask = (draws >= 0).float()\n",
    "\n",
    "    # Per-coordinate sum of feasible draws, averaged over all draws.\n",
    "    return torch.sum(draws * feasible_mask, dim=0) / int_samples\n",
    "\n",
    "\n",
    "def compute_fused_prox_ppm(z, gamma, objective_func, delta=1e-1,\n",
    "                          int_samples=2000, alpha=1.0,\n",
    "                          linesearch_iters=0, device='cpu'):\n",
    "    \"\"\"\n",
    "    Estimate prox_{gamma*(f + I_C)}(z) using HJ-Prox with feasibility constraints.\n",
    "\n",
    "    Method (one attempt):\n",
    "    1. Sample y = z + sqrt(gamma*delta/alpha) * N(0, I)\n",
    "    2. PROJECT samples to the feasible set: y_proj = max(y, 0)\n",
    "    3. Evaluate f(y_proj) at the PROJECTED samples\n",
    "    4. Compute softmax weights: w = softmax(-f(y_proj)*alpha/delta)\n",
    "    5. Return sum(w_i * y_proj_i), where sum(w_i) = 1.0\n",
    "\n",
    "    If the weights are not finite, alpha is halved and the attempt is repeated\n",
    "    with fresh samples. The retry is a bounded loop (at most 50 attempts)\n",
    "    instead of unbounded recursion, so an objective that keeps returning\n",
    "    NaN/inf can no longer end in a RecursionError.\n",
    "\n",
    "    Parameters:\n",
    "    -----------\n",
    "    z : torch.Tensor, 1-D\n",
    "        Point at which the prox is evaluated\n",
    "    gamma : float\n",
    "        Prox step size\n",
    "    objective_func : callable\n",
    "        Maps a (int_samples, len(z)) tensor to a tensor of int_samples values\n",
    "    delta : float\n",
    "        HJ-Prox smoothing parameter\n",
    "    int_samples : int\n",
    "        Number of Monte Carlo samples per attempt\n",
    "    alpha : float\n",
    "        Initial rescaling factor, halved after every non-finite attempt\n",
    "    linesearch_iters : int\n",
    "        Attempt counter carried in by the caller (incremented once per attempt)\n",
    "    device : str\n",
    "        Device on which the noise tensor is created\n",
    "\n",
    "    Returns:\n",
    "    --------\n",
    "    (prox_est, linesearch_iters) : (torch.Tensor, int)\n",
    "        prox_est is `z` itself when no finite estimate could be produced.\n",
    "    \"\"\"\n",
    "    max_attempts = 50  # 2**-50 rescaling; beyond this further halving is pointless\n",
    "\n",
    "    for _ in range(max_attempts):\n",
    "        linesearch_iters += 1\n",
    "\n",
    "        # 1-2. Gaussian samples around z, clipped to the non-negative orthant.\n",
    "        sigma = np.sqrt(gamma * delta / alpha)\n",
    "        noise = torch.randn(int_samples, z.shape[0], device=device)\n",
    "        y_projected = torch.clamp(z.view(1, -1) + sigma * noise, min=0.0)\n",
    "\n",
    "        # 3-4. Objective at the projected samples -> softmax weights over samples.\n",
    "        f_vals = objective_func(y_projected)\n",
    "        weights = torch.softmax(-f_vals * (alpha / delta), dim=0)\n",
    "\n",
    "        if bool(torch.isfinite(weights).all()):\n",
    "            break\n",
    "        alpha *= 0.5\n",
    "    else:\n",
    "        # Every attempt produced NaN/inf weights: fall back to the input point.\n",
    "        print(\"Warning: softmax weights stayed non-finite; returning input point.\")\n",
    "        return z, linesearch_iters\n",
    "\n",
    "    # 5. Weighted average of the PROJECTED samples.\n",
    "    prox_est = (weights.unsqueeze(1) * y_projected).sum(dim=0)\n",
    "\n",
    "    # Sanity check on the final estimate.\n",
    "    if not bool(torch.isfinite(prox_est).all()):\n",
    "        print(\"Warning: Prox overflow detected!\")\n",
    "        return z, linesearch_iters\n",
    "\n",
    "    return prox_est, linesearch_iters\n",
    "\n",
    "\n",
    "# ============================================================================\n",
    "# Algorithm 1: Davis-Yin with Analytical Proximal Operators\n",
    "# ============================================================================\n",
    "\n",
    "def davis_yin_nonneg_lasso(X, y, lam, gamma=None, max_iter=1500, tol=1e-7, z0=None, track_every=10):\n",
    "    \"\"\"\n",
    "    Davis-Yin three-operator splitting for:\n",
    "        min_beta 0.5||y - X beta||^2 + lam ||beta||_1 + I_{beta>=0}(beta)\n",
    "\n",
    "    Split:\n",
    "        f(beta) = 0.5||y - X beta||^2   (smooth)\n",
    "        g(beta) = lam ||beta||_1\n",
    "        h(beta) = I_{beta>=0}(beta)\n",
    "\n",
    "    Iteration:\n",
    "        x^k = prox_{gamma h}(z^k) = (z^k)_+\n",
    "        u^k = 2x^k - z^k - gamma * grad f(x^k)\n",
    "        y^k = prox_{gamma g}(u^k) = soft_threshold(u^k, gamma*lam)\n",
    "        z^{k+1} = z^k + (y^k - x^k)\n",
    "\n",
    "    Parameters:\n",
    "    -----------\n",
    "    X : ndarray, shape (n, p)\n",
    "        Design matrix\n",
    "    y : ndarray, shape (n,)\n",
    "        Response vector\n",
    "    lam : float\n",
    "        L1 regularization weight\n",
    "    gamma : float or None\n",
    "        Step size; when None it is set to 1 / estimate_L_xtx(X)\n",
    "    max_iter : int\n",
    "        Maximum number of iterations\n",
    "    tol : float\n",
    "        Stop once ||x^k - x^{k-1}|| / max(1, ||x^{k-1}||) < tol\n",
    "    z0 : array_like or None\n",
    "        Initial auxiliary variable (zeros when None); the input is copied\n",
    "    track_every : int\n",
    "        Diagnostics are recorded every `track_every` iterations and at the last one\n",
    "\n",
    "    Returns:\n",
    "    --------\n",
    "    beta_hat : ndarray, shape (p,)\n",
    "        Feasible estimate (z)_+ = prox_{gamma h}(z) of the final auxiliary variable.\n",
    "        (The raw auxiliary variable can have negative entries; it is kept in\n",
    "        history[\"z\"], e.g. for warm starts.)\n",
    "    history : dict\n",
    "        Iteration count, step size and the tracked diagnostics.\n",
    "    \"\"\"\n",
    "    n, p = X.shape\n",
    "    z = np.zeros(p) if z0 is None else np.array(z0, dtype=float)\n",
    "\n",
    "    if gamma is None:\n",
    "        gamma = 1.0 / estimate_L_xtx(X)\n",
    "    gamma = float(gamma)\n",
    "\n",
    "    obj_hist, rel_hist, nnz_hist, it_hist, fpr_hist = [], [], [], [], []\n",
    "    x_prev = None\n",
    "    n_iters = 0  # stays 0 when max_iter == 0, so the summary below is always defined\n",
    "\n",
    "    for k in range(max_iter):\n",
    "        n_iters = k + 1\n",
    "\n",
    "        x = np.maximum(z, 0.0)\n",
    "        grad = X.T @ (X @ x - y)\n",
    "        u = 2.0 * x - z - gamma * grad\n",
    "        yk = soft_threshold(u, gamma * lam)\n",
    "\n",
    "        # Fixed-point residual y^k - x^k drives the z-update.\n",
    "        residual_gap = yk - x\n",
    "        z_new = z + residual_gap\n",
    "\n",
    "        # Relative change of x; undefined (NaN) on the first iteration.\n",
    "        if x_prev is None:\n",
    "            rel = np.nan\n",
    "        else:\n",
    "            rel = float(np.linalg.norm(x - x_prev) / max(1.0, np.linalg.norm(x_prev)))\n",
    "\n",
    "        if k % track_every == 0 or k == max_iter - 1:\n",
    "            obj_hist.append(objective_nonneg_lasso(X, y, x, lam))\n",
    "            nnz_hist.append(int(np.sum(x > 1e-10)))\n",
    "            it_hist.append(k)\n",
    "            fpr_hist.append(float(np.linalg.norm(residual_gap)))\n",
    "            rel_hist.append(rel)\n",
    "\n",
    "        z = z_new\n",
    "        if rel < tol:  # NaN compares False, so the first iteration never stops here\n",
    "            break\n",
    "        x_prev = x\n",
    "\n",
    "    history = {\n",
    "        \"iters\": n_iters,\n",
    "        \"gamma\": gamma,\n",
    "        \"track_every\": track_every,\n",
    "        \"it_track\": np.array(it_hist),\n",
    "        \"objective\": np.array(obj_hist),\n",
    "        \"rel_change\": np.array(rel_hist),\n",
    "        \"nnz\": np.array(nnz_hist),\n",
    "        \"fixed_point_residual\": np.array(fpr_hist),\n",
    "        \"z\": z,\n",
    "    }\n",
    "    # Return the feasible primal point rather than the auxiliary variable.\n",
    "    return np.maximum(z, 0.0), history\n",
    "\n",
    "\n",
    "# ============================================================================\n",
    "# Algorithm 2: Davis-Yin with HJ-Prox (DYS-HJ-1: Analytical Projection, HJ-Prox L1)\n",
    "# ============================================================================\n",
    "\n",
    "def davis_yin_nonneg_lasso_hjprox_swapped(X, y, lam, gamma=None, max_iter=1500, tol=1e-7, z0=None,\n",
    "                                          track_every=10, int_samples=100, delta=1e-1,\n",
    "                                          adaptive_delta=True, device='cpu'):\n",
    "    \"\"\"\n",
    "    Davis-Yin splitting in which the L1 prox is approximated by HJ-Prox (DYS-HJ-1).\n",
    "\n",
    "    Despite the `_swapped` suffix, the operator order is the same as in\n",
    "    davis_yin_nonneg_lasso:\n",
    "        f(beta) = 0.5||y - X beta||^2   (smooth, gradient step)\n",
    "        h(beta) = I_{beta>=0}(beta)     (exact projection, applied FIRST)\n",
    "        g(beta) = lam ||beta||_1        (compute_prox_HJ, applied LAST)\n",
    "\n",
    "    Iteration:\n",
    "        x^k = prox_{gamma h}(z^k) = (z^k)_+\n",
    "        u^k = 2x^k - z^k - gamma * grad f(x^k)\n",
    "        y^k = HJ-prox_{gamma g}(u^k)\n",
    "        z^{k+1} = z^k + (y^k - x^k)\n",
    "\n",
    "    Parameters:\n",
    "    -----------\n",
    "    X, y, lam : design matrix (n, p), response (n,), L1 weight\n",
    "    gamma : float or None\n",
    "        Step size; when None it is set to 1 / estimate_L_xtx(X)\n",
    "    max_iter, tol : iteration cap and tolerance on the relative change of x^k\n",
    "    z0 : array_like or None\n",
    "        Initial auxiliary variable (zeros when None); the input is copied\n",
    "    track_every : int\n",
    "        Diagnostics are recorded every `track_every` iterations and at the last one\n",
    "    int_samples, delta : sample count and smoothing parameter passed to compute_prox_HJ\n",
    "    adaptive_delta : bool\n",
    "        When True, iteration k uses delta / (1 + 0.01 k)\n",
    "    device : str\n",
    "        Torch device for the tensors created here\n",
    "\n",
    "    Returns:\n",
    "    --------\n",
    "    beta_hat : ndarray, shape (p,)\n",
    "        Feasible estimate (z)_+ of the final auxiliary variable (raw z is kept in\n",
    "        history[\"z\"]).\n",
    "    history : dict\n",
    "        Tracked diagnostics. Objective and nnz are evaluated at y^k (the HJ-Prox\n",
    "        output), rel_change at x^k; \"fixed_point_residual_analytical\" is\n",
    "        compute_residual_gap at z^k.\n",
    "    \"\"\"\n",
    "    n, p = X.shape\n",
    "    z = np.zeros(p) if z0 is None else np.array(z0, dtype=float)\n",
    "\n",
    "    if gamma is None:\n",
    "        gamma = 1.0 / estimate_L_xtx(X)\n",
    "    gamma = float(gamma)\n",
    "\n",
    "    X_torch = torch.tensor(X, dtype=torch.float32, device=device)\n",
    "    y_torch = torch.tensor(y, dtype=torch.float32, device=device)\n",
    "\n",
    "    def l1_func(beta_batch):\n",
    "        return l1_penalty_batch(beta_batch, lam)\n",
    "\n",
    "    obj_hist, rel_hist, nnz_hist, it_hist, ls_hist = [], [], [], [], []\n",
    "    fpr_hist, fpr_analytical_hist = [], []\n",
    "    x_prev = None\n",
    "    n_iters = 0  # stays 0 when max_iter == 0, so the summary below is always defined\n",
    "\n",
    "    for k in range(max_iter):\n",
    "        n_iters = k + 1\n",
    "\n",
    "        # Step 1: prox_h (exact projection onto beta >= 0).\n",
    "        x = np.maximum(z, 0.0)\n",
    "\n",
    "        # Step 2: gradient of the least-squares term at x^k.\n",
    "        x_torch = torch.tensor(x, dtype=torch.float32, device=device)\n",
    "        grad_np = (X_torch.T @ (X_torch @ x_torch - y_torch)).cpu().numpy()\n",
    "\n",
    "        # Step 3: reflection step.\n",
    "        u = 2.0 * x - z - gamma * grad_np\n",
    "\n",
    "        # Step 4: prox_g via HJ-Prox (L1 penalty).\n",
    "        current_delta = delta / (1.0 + 0.01 * k) if adaptive_delta else delta\n",
    "        u_torch = torch.tensor(u, dtype=torch.float32, device=device).view(-1, 1)\n",
    "        yk_torch, ls_iters, _ = compute_prox_HJ(\n",
    "            u_torch,\n",
    "            gamma,\n",
    "            l1_func,\n",
    "            delta=current_delta,\n",
    "            int_samples=int_samples,\n",
    "            alpha=1.0\n",
    "        )\n",
    "        yk = yk_torch.view(-1).cpu().numpy()\n",
    "\n",
    "        # Step 5: fixed-point residual and candidate update of z.\n",
    "        residual_gap = yk - x\n",
    "        z_new = z + residual_gap\n",
    "\n",
    "        # Relative change of x; undefined (NaN) on the first iteration.\n",
    "        if x_prev is None:\n",
    "            rel = np.nan\n",
    "        else:\n",
    "            rel = float(np.linalg.norm(x - x_prev) / max(1.0, np.linalg.norm(x_prev)))\n",
    "\n",
    "        if k % track_every == 0 or k == max_iter - 1:\n",
    "            # Diagnostics are only needed on tracked iterations; the analytical\n",
    "            # residual costs two extra mat-vecs, so it is not computed otherwise.\n",
    "            fpr_norm = np.linalg.norm(residual_gap)\n",
    "            fpr_analytical_norm = compute_residual_gap(z, X, y, lam, gamma, return_norm=True)\n",
    "\n",
    "            obj_hist.append(objective_nonneg_lasso(X, y, yk, lam))\n",
    "            nnz_hist.append(int(np.sum(yk > 1e-10)))\n",
    "            it_hist.append(k)\n",
    "            ls_hist.append(ls_iters)\n",
    "            fpr_hist.append(float(fpr_norm))\n",
    "            fpr_analytical_hist.append(float(fpr_analytical_norm))\n",
    "            rel_hist.append(rel)\n",
    "\n",
    "            if k % (track_every * 10) == 0:\n",
    "                print(f\"Iter {k:4d}: obj={obj_hist[-1]:.6f}, rel_change={rel_hist[-1]:.2e}, \"\n",
    "                      f\"fpr={fpr_norm:.2e}, fpr_analytical={fpr_analytical_norm:.2e}, \"\n",
    "                      f\"nnz={nnz_hist[-1]}, ls_iters={ls_iters}, delta={current_delta:.2e}\")\n",
    "\n",
    "        z = z_new\n",
    "        if rel < tol:  # NaN compares False, so the first iteration never stops here\n",
    "            break\n",
    "        x_prev = x\n",
    "\n",
    "    # Report the feasible primal point rather than the auxiliary variable.\n",
    "    beta_hat = np.maximum(z, 0.0)\n",
    "\n",
    "    history = {\n",
    "        \"iters\": n_iters,\n",
    "        \"gamma\": gamma,\n",
    "        \"track_every\": track_every,\n",
    "        \"it_track\": np.array(it_hist),\n",
    "        \"objective\": np.array(obj_hist),\n",
    "        \"rel_change\": np.array(rel_hist),\n",
    "        \"nnz\": np.array(nnz_hist),\n",
    "        \"linesearch_iters\": np.array(ls_hist),\n",
    "        \"fixed_point_residual\": np.array(fpr_hist),\n",
    "        \"fixed_point_residual_analytical\": np.array(fpr_analytical_hist),\n",
    "        \"z\": z,\n",
    "    }\n",
    "\n",
    "    return beta_hat, history\n",
    "\n",
    "\n",
    "# ============================================================================\n",
    "# Algorithm 3: Davis-Yin with HJ-Prox (Both Operators via HJ-Prox)\n",
    "# ============================================================================\n",
    "\n",
    "def davis_yin_nonneg_lasso_2(X, y, lam, gamma=None, max_iter=1500, tol=1e-7, z0=None,\n",
    "                                  track_every=100,\n",
    "                                  int_samples_l1=100, delta_l1=1e-1,\n",
    "                                  int_samples_proj=100, delta_proj=1e-1,\n",
    "                                  adaptive_delta=True, device='cpu'):\n",
    "    \"\"\"\n",
    "    Davis-Yin with BOTH proximal operators approximated by sampling (DYS-HJ-2).\n",
    "\n",
    "    Iteration:\n",
    "        x^k = compute_prox_projection(z^k)   (sampled estimate of (z^k)_+)\n",
    "        u^k = 2x^k - z^k - gamma * grad f(x^k)\n",
    "        y^k = compute_prox_HJ(u^k)           (HJ-Prox of lam ||.||_1)\n",
    "        z^{k+1} = z^k + (y^k - x^k)\n",
    "\n",
    "    Parameters:\n",
    "    -----------\n",
    "    X, y, lam : design matrix (n, p), response (n,), L1 weight\n",
    "    gamma : float or None\n",
    "        Step size; when None it is set to 0.9 / estimate_L_xtx(X)\n",
    "    max_iter, tol : iteration cap and tolerance on the relative change of x^k\n",
    "    z0 : array_like or None\n",
    "        Initial auxiliary variable (zeros when None); the input is copied\n",
    "    track_every : int\n",
    "        Diagnostics are recorded and printed every `track_every` iterations\n",
    "        and at the last one\n",
    "    int_samples_l1, delta_l1 : sample count / smoothing for the L1 HJ-Prox\n",
    "    int_samples_proj, delta_proj : sample count / smoothing for the sampled projection\n",
    "    adaptive_delta : bool\n",
    "        When True, iteration k uses delta / (1 + 0.01 k) for both operators\n",
    "    device : str\n",
    "        Torch device for the tensors created here; also forwarded to\n",
    "        compute_prox_projection so its noise lives on the same device as z\n",
    "\n",
    "    Returns:\n",
    "    --------\n",
    "    beta_hat : ndarray, shape (p,)\n",
    "        Feasible estimate (z)_+ of the final auxiliary variable (raw z is kept in\n",
    "        history[\"z\"]).\n",
    "    history : dict\n",
    "        Tracked diagnostics; \"fixed_point_residual_analytical\" is\n",
    "        compute_residual_gap at z^k.\n",
    "    \"\"\"\n",
    "    n, p = X.shape\n",
    "    z = np.zeros(p) if z0 is None else np.array(z0, dtype=float)\n",
    "\n",
    "    if gamma is None:\n",
    "        gamma = 0.9 / estimate_L_xtx(X)\n",
    "    gamma = float(gamma)\n",
    "\n",
    "    X_torch = torch.tensor(X, dtype=torch.float32, device=device)\n",
    "    y_torch = torch.tensor(y, dtype=torch.float32, device=device)\n",
    "\n",
    "    def l1_func(beta_batch):\n",
    "        return l1_penalty_batch(beta_batch, lam)\n",
    "\n",
    "    obj_hist, it_hist = [], []\n",
    "    fpr_hist, fpr_analytical_hist = [], []\n",
    "    rel_hist, nnz_hist = [], []\n",
    "    x_prev = None\n",
    "    n_iters = 0  # stays 0 when max_iter == 0, so the summary below is always defined\n",
    "\n",
    "    print(\"Starting Davis-Yin with Honest Rejection Sampling\")\n",
    "    print(f\"Gamma: {gamma:.4e}, Delta_Proj: {delta_proj}, Delta_L1: {delta_l1}\")\n",
    "\n",
    "    for k in range(max_iter):\n",
    "        n_iters = k + 1\n",
    "\n",
    "        decay = 1.0 + 0.01 * k if adaptive_delta else 1.0\n",
    "        curr_d_proj = delta_proj / decay\n",
    "        curr_d_l1 = delta_l1 / decay\n",
    "\n",
    "        # Step 1: prox_h via the sampled projection.\n",
    "        z_torch = torch.tensor(z, dtype=torch.float32, device=device)\n",
    "        x_torch = compute_prox_projection(\n",
    "            z_torch,\n",
    "            gamma,\n",
    "            delta=curr_d_proj,\n",
    "            int_samples=int_samples_proj,\n",
    "            device=device\n",
    "        )\n",
    "        x = x_torch.cpu().numpy()\n",
    "\n",
    "        # Step 2: gradient and reflection step.\n",
    "        grad_np = (X_torch.T @ (X_torch @ x_torch - y_torch)).cpu().numpy()\n",
    "        u = 2.0 * x - z - gamma * grad_np\n",
    "\n",
    "        # Step 3: prox_g (L1 norm via HJ-Prox).\n",
    "        u_torch = torch.tensor(u, dtype=torch.float32, device=device).view(-1, 1)\n",
    "        yk_torch, _, _ = compute_prox_HJ(\n",
    "            u_torch,\n",
    "            gamma,\n",
    "            l1_func,\n",
    "            delta=curr_d_l1,\n",
    "            int_samples=int_samples_l1,\n",
    "            alpha=1.0\n",
    "        )\n",
    "        yk = yk_torch.view(-1).cpu().numpy()\n",
    "\n",
    "        # Step 4: fixed-point residual and candidate update of z.\n",
    "        residual_gap = yk - x\n",
    "        z_new = z + residual_gap\n",
    "\n",
    "        # Relative change of x; undefined (NaN) on the first iteration.\n",
    "        if x_prev is None:\n",
    "            rel = np.nan\n",
    "        else:\n",
    "            rel = float(np.linalg.norm(x - x_prev) / max(1.0, np.linalg.norm(x_prev)))\n",
    "\n",
    "        if k % track_every == 0 or k == max_iter - 1:\n",
    "            # Diagnostics are only needed on tracked iterations; the analytical\n",
    "            # residual costs two extra mat-vecs, so it is not computed otherwise.\n",
    "            fpr_norm = np.linalg.norm(residual_gap)\n",
    "            fpr_analytical_norm = compute_residual_gap(z, X, y, lam, gamma, return_norm=True)\n",
    "\n",
    "            obj = objective_nonneg_lasso(X, y, x, lam)\n",
    "            obj_hist.append(obj)\n",
    "            it_hist.append(k)\n",
    "            fpr_hist.append(float(fpr_norm))\n",
    "            fpr_analytical_hist.append(float(fpr_analytical_norm))\n",
    "            nnz_hist.append(int(np.sum(x > 1e-10)))\n",
    "            rel_hist.append(rel)\n",
    "\n",
    "            n_zeros = np.sum(x == 0.0)\n",
    "\n",
    "            print(f\"Iter {k}: Obj={obj:.5f}, fpr={fpr_norm:.2e}, fpr_analytical={fpr_analytical_norm:.2e}, \"\n",
    "                  f\"rel_change={rel_hist[-1]:.2e}, Min(x)={np.min(x):.2e}, Exact Zeros={n_zeros}\")\n",
    "\n",
    "        z = z_new\n",
    "        if rel < tol:  # NaN compares False, so the first iteration never stops here\n",
    "            print(\"Converged.\")\n",
    "            break\n",
    "        x_prev = x\n",
    "\n",
    "    history = {\n",
    "        \"iters\": n_iters,\n",
    "        \"gamma\": gamma,\n",
    "        \"track_every\": track_every,\n",
    "        \"it_track\": np.array(it_hist),\n",
    "        \"objective\": np.array(obj_hist),\n",
    "        \"rel_change\": np.array(rel_hist),\n",
    "        \"nnz\": np.array(nnz_hist),\n",
    "        \"fixed_point_residual\": np.array(fpr_hist),\n",
    "        \"fixed_point_residual_analytical\": np.array(fpr_analytical_hist),\n",
    "        \"z\": z,\n",
    "    }\n",
    "\n",
    "    # Return the feasible primal point rather than the auxiliary variable.\n",
    "    return np.maximum(z, 0.0), history\n",
    "\n",
    "\n",
    "# ============================================================================\n",
    "# Algorithm 4: Proximal Point Method with HJ-Prox\n",
    "# ============================================================================\n",
    "\n",
    "def proximal_point_nonneg_lasso_hjprox(X, y, lam, t=None, max_iter=1500, tol=1e-7,\n",
    "                                       beta0=None, track_every=10, int_samples=2000,\n",
    "                                       delta=1e-1, adaptive_delta=True, device='cpu',\n",
    "                                       fpr_gamma=0.003):\n",
    "    \"\"\"\n",
    "    Proximal Point Method with HJ-Prox for non-negative LASSO:\n",
    "        min_β  0.5||y - Xβ||² + λ||β||₁  subject to β ≥ 0\n",
    "\n",
    "    Iteration:\n",
    "        β^{k+1} = prox_{t F}(β^k)\n",
    "\n",
    "    where F(β) = 0.5||y - Xβ||² + λ||β||₁ + I_{β≥0}(β) is the FULL objective and\n",
    "    the prox is estimated by compute_fused_prox_ppm.\n",
    "\n",
    "    Parameters:\n",
    "    -----------\n",
    "    X, y, lam : design matrix (n, p), response (n,), L1 weight\n",
    "    t : float or None\n",
    "        Proximal step size; when None it is set to 1 / estimate_L_xtx(X)\n",
    "        (previously the None default crashed when the step size was formatted)\n",
    "    max_iter : int\n",
    "        Maximum number of iterations\n",
    "    tol : float\n",
    "        Stop once ||β^{k+1} - β^k|| / max(1, ||β^k||) < tol\n",
    "    beta0 : array_like or None\n",
    "        Starting point (zeros when None); the input is copied\n",
    "    track_every : int\n",
    "        Diagnostics are recorded every `track_every` iterations and at the last one\n",
    "    int_samples, delta : sample count and smoothing parameter for the HJ-Prox estimate\n",
    "    adaptive_delta : bool\n",
    "        When True, iteration k uses delta / (1 + 0.01 k)\n",
    "    device : str\n",
    "        Torch device for the tensors created here\n",
    "    fpr_gamma : float\n",
    "        Step size used ONLY for the Davis-Yin fixed-point-residual diagnostic\n",
    "        (compute_residual_gap). The default 0.003 is the value that used to be\n",
    "        hard-coded; pass the DYS step size to make the residual comparable with\n",
    "        the Davis-Yin runs.\n",
    "\n",
    "    Returns:\n",
    "    --------\n",
    "    beta_hat : ndarray, shape (p,)\n",
    "        Final iterate clipped to β ≥ 0\n",
    "    history : dict\n",
    "        Iteration count, step size (key \"gamma\") and the tracked diagnostics\n",
    "    \"\"\"\n",
    "    n, p = X.shape\n",
    "    beta = np.zeros(p) if beta0 is None else np.array(beta0, dtype=float)\n",
    "\n",
    "    if t is None:\n",
    "        t = 1.0 / estimate_L_xtx(X)\n",
    "    t = float(t)\n",
    "\n",
    "    X_torch = torch.tensor(X, dtype=torch.float32, device=device)\n",
    "    y_torch = torch.tensor(y, dtype=torch.float32, device=device)\n",
    "\n",
    "    def lasso_objective(beta_batch):\n",
    "        \"\"\"Evaluate 0.5||y - Xβ||² + λ||β||₁ for each row of beta_batch.\"\"\"\n",
    "        residuals = y_torch - beta_batch @ X_torch.T\n",
    "        ls_term = 0.5 * torch.sum(residuals ** 2, dim=1)\n",
    "        l1_term = lam * torch.sum(torch.abs(beta_batch), dim=1)\n",
    "        return ls_term + l1_term\n",
    "\n",
    "    obj_hist, rel_hist, nnz_hist, it_hist, ls_hist, fpr_hist = [], [], [], [], [], []\n",
    "    n_iters = 0  # stays 0 when max_iter == 0, so the summary below is always defined\n",
    "\n",
    "    print(\"=\" * 70)\n",
    "    print(\"Proximal Point Method with HJ-Prox\")\n",
    "    print(\"=\" * 70)\n",
    "    print(f\"Problem: n={n}, p={p}, λ={lam}\")\n",
    "    print(f\"HJ-Prox: {int_samples} samples, initial δ={delta:.2e}\")\n",
    "    print(f\"Proximal step size γ={t:.4e}\")\n",
    "    print(\"-\" * 70)\n",
    "\n",
    "    for k in range(max_iter):\n",
    "        n_iters = k + 1\n",
    "        current_delta = delta / (1.0 + 0.01 * k) if adaptive_delta else delta\n",
    "\n",
    "        beta_torch = torch.tensor(beta, dtype=torch.float32, device=device)\n",
    "        beta_new_torch, ls_iters = compute_fused_prox_ppm(\n",
    "            beta_torch,\n",
    "            t,\n",
    "            lasso_objective,\n",
    "            delta=current_delta,\n",
    "            int_samples=int_samples,\n",
    "            device=device\n",
    "        )\n",
    "        beta_new = beta_new_torch.cpu().numpy()\n",
    "\n",
    "        # One-step relative change between the input and output of this prox step.\n",
    "        # (The earlier version compared against the iterate from two steps back.)\n",
    "        rel = float(np.linalg.norm(beta_new - beta) / max(1.0, np.linalg.norm(beta)))\n",
    "\n",
    "        if k % track_every == 0 or k == max_iter - 1:\n",
    "            # Davis-Yin fixed-point residual at the current iterate (diagnostic only,\n",
    "            # so it is computed on tracked iterations only).\n",
    "            dy_fpr_norm = compute_residual_gap(beta, X, y, lam, fpr_gamma, return_norm=True)\n",
    "\n",
    "            obj_hist.append(objective_nonneg_lasso(X, y, beta_new, lam))\n",
    "            nnz_hist.append(int(np.sum(beta_new > 1e-10)))\n",
    "            it_hist.append(k)\n",
    "            ls_hist.append(ls_iters)\n",
    "            fpr_hist.append(float(dy_fpr_norm))\n",
    "            rel_hist.append(rel)\n",
    "\n",
    "            if k % (track_every * 10) == 0:\n",
    "                n_negative = int(np.sum(beta_new < -1e-6))\n",
    "                min_val = float(np.min(beta_new))\n",
    "                print(f\"Iter {k:4d}: obj={obj_hist[-1]:.6f}, rel_change={rel_hist[-1]:.2e}, \"\n",
    "                      f\"DY_fpr={dy_fpr_norm:.2e}, nnz={nnz_hist[-1]}, min_β={min_val:.2e}, \"\n",
    "                      f\"neg={n_negative}, δ={current_delta:.2e}, ls={ls_iters}\")\n",
    "\n",
    "        beta = beta_new\n",
    "        if rel < tol:\n",
    "            break\n",
    "\n",
    "    beta_hat = np.maximum(beta, 0.0)\n",
    "\n",
    "    print(\"-\" * 70)\n",
    "    print(f\"Converged in {n_iters} iterations\")\n",
    "    if obj_hist:  # empty only when max_iter == 0\n",
    "        print(f\"Final objective: {obj_hist[-1]:.6e}\")\n",
    "        print(f\"Final DY FPR: {fpr_hist[-1]:.2e}\")\n",
    "        print(f\"Non-zero coefficients: {nnz_hist[-1]}\")\n",
    "    print(\"=\" * 70)\n",
    "\n",
    "    history = {\n",
    "        \"iters\": n_iters,\n",
    "        \"gamma\": t,\n",
    "        \"track_every\": track_every,\n",
    "        \"it_track\": np.array(it_hist),\n",
    "        \"objective\": np.array(obj_hist),\n",
    "        \"rel_change\": np.array(rel_hist),\n",
    "        \"nnz\": np.array(nnz_hist),\n",
    "        \"linesearch_iters\": np.array(ls_hist),\n",
    "        \"fixed_point_residual\": np.array(fpr_hist),\n",
    "    }\n",
    "\n",
    "    return beta_hat, history\n",
    "\n",
    "\n",
    "# Printed last so a successful run of this setup cell is visible in the output.\n",
    "print(\"✓ All algorithms and helper functions loaded successfully\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# ============================================================================\n",
    "# CHUNK 2: DATA GENERATION\n",
    "# ============================================================================\n",
    "\n",
    "print(\"\\n\" + \"=\"*70)\n",
    "print(\"Generating constrained LASSO data...\")\n",
    "print(\"=\"*70)\n",
    "\n",
    "# Problem dimensions and a seeded generator for the synthetic data.\n",
    "rng = np.random.default_rng(0)\n",
    "n = 250\n",
    "p = 500\n",
    "k_true = 50\n",
    "\n",
    "# Design matrix: Gaussian entries, columns centred and then scaled to unit norm\n",
    "# (the 1e-12 offset guards against dividing by a zero column norm).\n",
    "X = rng.normal(size=(n, p))\n",
    "X -= X.mean(axis=0, keepdims=True)\n",
    "X /= np.linalg.norm(X, axis=0, keepdims=True) + 1e-12\n",
    "\n",
    "# Ground truth: k_true randomly placed entries drawn from 2.5 * U(1, 2), zeros elsewhere.\n",
    "support = rng.choice(p, size=k_true, replace=False)\n",
    "beta_true = np.zeros(p)\n",
    "beta_true[support] = 2.5 * rng.uniform(1.0, 2.0, size=k_true)\n",
    "\n",
    "# Noisy, centred response.\n",
    "sigma = 0.5\n",
    "y = X @ beta_true + rng.normal(scale=sigma, size=n)\n",
    "y -= y.mean()\n",
    "\n",
    "# Regularization weight and step size shared by the algorithm cells.\n",
    "lam = 0.5\n",
    "gamma = 0.0025\n",
    "\n",
    "print(f\"✓ Data generated: n={n}, p={p}\")\n",
    "print(f\"✓ True non-zeros: {k_true}\")\n",
    "print(f\"✓ Regularization parameter λ = {lam}\")\n",
    "print(f\"✓ Step size γ = {gamma}\")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# ============================================================================\n",
    "# CHUNK 3: RUN ALGORITHM 1 - Analytical Davis-Yin\n",
    "# ============================================================================\n",
    "\n",
    "print(\"\\n\" + \"=\"*70)\n",
    "print(\"Running Algorithm 1: Davis-Yin with Analytical Proximal Operators...\")\n",
    "print(\"=\"*70)\n",
    "\n",
    "# tol=1e-100 makes the relative-change test fire only when consecutive iterates\n",
    "# are essentially identical, so the run normally uses all max_iter iterations.\n",
    "beta_Analytical, hist_Analytical = davis_yin_nonneg_lasso(\n",
    "    X, y, lam, gamma=gamma, max_iter=10000, tol=1e-100, track_every=1\n",
    ")\n",
    "\n",
    "# Clip to the non-negative orthant before reporting, so the objective below is\n",
    "# evaluated at a feasible point (a no-op if the solver output is already feasible).\n",
    "beta_Analytical_feasible = np.maximum(beta_Analytical, 0.0)\n",
    "\n",
    "supp_hat = set(np.flatnonzero(beta_Analytical_feasible > 1e-6))\n",
    "supp_true = set(support)\n",
    "\n",
    "print(\"\\n✓ Analytical DYS completed\")\n",
    "print(f\"  - Converged in {hist_Analytical['iters']} iterations\")\n",
    "print(f\"  - Final objective: {objective_nonneg_lasso(X, y, beta_Analytical_feasible, lam):.6f}\")\n",
    "print(f\"  - Estimated non-zeros: {len(supp_hat)}\")\n",
    "print(f\"  - Support overlap: {len(supp_hat & supp_true)}/{k_true}\")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# ============================================================================\n",
    "# CHUNK 4: RUN ALGORITHM 2 - DYS-HJ-1 (Analytical Projection + HJ-Prox L1)\n",
    "# ============================================================================\n",
    "\n",
    "print(\"\\n\" + \"=\"*70)\n",
    "print(\"Running Algorithm 2: DYS-HJ-1 (One Analytical Projection)...\")\n",
    "print(\"=\"*70)\n",
    "\n",
    "# Seed torch's global RNG so re-running this cell reproduces the same result.\n",
    "# NOTE(review): assumes compute_prox_HJ samples from torch's global generator - confirm.\n",
    "torch.manual_seed(0)\n",
    "\n",
    "beta_HJ_1, hist_HJ_1 = davis_yin_nonneg_lasso_hjprox_swapped(\n",
    "    X, y, lam=lam,\n",
    "    max_iter=10000,\n",
    "    int_samples=1000,\n",
    "    delta=0.1,\n",
    "    gamma=gamma,\n",
    "    adaptive_delta=False,\n",
    "    device='cpu',\n",
    "    track_every=1\n",
    ")\n",
    "\n",
    "print(f\"\\n✓ DYS-HJ-1 completed\")\n",
    "print(f\"  - Converged in {hist_HJ_1['iters']} iterations\")\n",
    "print(f\"  - Final objective: {hist_HJ_1['objective'][-1]:.6f}\")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# ============================================================================\n",
    "# CHUNK 5: RUN ALGORITHM 3 - DYS-HJ-2 (Both via HJ-Prox)\n",
    "# ============================================================================\n",
    "\n",
    "# Section banner, written with a single print call (one output line per argument).\n",
    "print(\n",
    "    \"\\n\" + \"=\" * 70,\n",
    "    \"Running Algorithm 3: DYS-HJ-2 (Both Operators via HJ-Prox)...\",\n",
    "    \"=\" * 70,\n",
    "    sep=\"\\n\",\n",
    ")\n",
    "\n",
    "# Solver settings gathered in one mapping; the *_l1 and *_proj entries are passed\n",
    "# with identical values here.\n",
    "# NOTE(review): presumably *_l1 configures the L1 prox and *_proj the projection\n",
    "# step -- confirm against davis_yin_nonneg_lasso_2.\n",
    "hj2_options = {\n",
    "    'gamma': gamma,\n",
    "    'max_iter': 10000,\n",
    "    'track_every': 1,\n",
    "    'int_samples_l1': 1000,\n",
    "    'delta_l1': 0.1,\n",
    "    'int_samples_proj': 1000,\n",
    "    'delta_proj': 0.1,\n",
    "    'adaptive_delta': False,\n",
    "    'device': 'cpu',\n",
    "}\n",
    "beta_HJ_2, hist_HJ_2 = davis_yin_nonneg_lasso_2(X, y, lam, **hj2_options)\n",
    "\n",
    "# Run summary: iteration count and the last recorded objective value.\n",
    "print(\n",
    "    \"\\n✓ DYS-HJ-2 completed\",\n",
    "    f\"  - Converged in {hist_HJ_2['iters']} iterations\",\n",
    "    f\"  - Final objective: {hist_HJ_2['objective'][-1]:.6f}\",\n",
    "    sep=\"\\n\",\n",
    ")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# ============================================================================\n",
    "# CHUNK 6: RUN ALGORITHM 4 - PPM with HJ-Prox\n",
    "# ============================================================================\n",
    "\n",
    "# Section banner, written with a single print call (one output line per argument).\n",
    "print(\n",
    "    \"\\n\" + \"=\" * 70,\n",
    "    \"Running Algorithm 4: Proximal Point Method with HJ-Prox...\",\n",
    "    \"=\" * 70,\n",
    "    sep=\"\\n\",\n",
    ")\n",
    "\n",
    "# Solver settings gathered in one mapping so they can be read and tuned at a glance.\n",
    "# NOTE(review): t is presumably the proximal step parameter -- confirm against\n",
    "# proximal_point_nonneg_lasso_hjprox.\n",
    "ppm_options = {\n",
    "    't': 0.01,\n",
    "    'delta': 0.1,\n",
    "    'int_samples': 1000,\n",
    "    'max_iter': 10000,\n",
    "    'track_every': 1,\n",
    "    'adaptive_delta': False,\n",
    "    'device': 'cpu',\n",
    "}\n",
    "beta_PPM, hist_PPM = proximal_point_nonneg_lasso_hjprox(X, y, lam, **ppm_options)\n",
    "\n",
    "# Run summary: iteration count, last recorded objective, and the number of\n",
    "# coefficients above the 1e-6 threshold.\n",
    "print(\n",
    "    \"\\n✓ PPM-HJ completed\",\n",
    "    f\"  - Converged in {hist_PPM['iters']} iterations\",\n",
    "    f\"  - Final objective: {hist_PPM['objective'][-1]:.6f}\",\n",
    "    f\"  - Recovered non-zeros: {np.sum(beta_PPM > 1e-6)}\",\n",
    "    sep=\"\\n\",\n",
    ")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# ============================================================================\n",
    "# CHUNK 7: GENERATE FIGURES\n",
    "# ============================================================================\n",
    "\n",
    "# Section banner, written with a single print call (one output line per argument).\n",
    "print(\"\\n\" + \"=\" * 70, \"Generating figures...\", \"=\" * 70, sep=\"\\n\")\n",
    "\n",
    "# --- Figure 1: Fixed Point Residual Convergence ---\n",
    "# One row per curve: (history dict, history key, line style, legend name,\n",
    "# format spec used for the final value shown in the legend).\n",
    "residual_series = [\n",
    "    (hist_Analytical, 'fixed_point_residual', '-', 'DYS', '.3e'),\n",
    "    (hist_HJ_1, 'fixed_point_residual_analytical', '--', 'DYS-HJ-1', '.3f'),\n",
    "    (hist_HJ_2, 'fixed_point_residual_analytical', '-.', 'DYS-HJ-2', '.3f'),\n",
    "    (hist_PPM, 'fixed_point_residual', ':', 'HJ-PPM', '.3f'),\n",
    "]\n",
    "\n",
    "fig_res, ax_res = plt.subplots(figsize=(14, 8))\n",
    "for run_hist, hist_key, line_style, run_name, value_fmt in residual_series:\n",
    "    res_trace = run_hist[hist_key]\n",
    "    ax_res.plot(res_trace, line_style, linewidth=3,\n",
    "                label=f'{run_name}: {res_trace[-1]:{value_fmt}}')\n",
    "\n",
    "ax_res.set_ylabel('Fixed Point Residual', fontsize=40)\n",
    "ax_res.set_xlabel('Iteration', fontsize=40)\n",
    "ax_res.set_title('Fixed Point Residuals', fontsize=40)\n",
    "ax_res.legend(fontsize=35, loc='upper right')\n",
    "ax_res.grid(True, alpha=0.3)\n",
    "ax_res.tick_params(axis='both', which='major', labelsize=40)\n",
    "fig_res.tight_layout()\n",
    "fig_res.savefig('fixed_point_residual_convergence.pdf', format='pdf', dpi=600, bbox_inches='tight')\n",
    "plt.show()\n",
    "plt.close(fig_res)\n",
    "\n",
    "# --- Figure 2: Objective Value Convergence ---\n",
    "# One row per curve: (history dict, line style, legend name). Every curve reads the\n",
    "# 'objective' history and is drawn on a logarithmic y-axis.\n",
    "objective_series = [\n",
    "    (hist_Analytical, '-', 'DYS'),\n",
    "    (hist_HJ_1, '--', 'DYS-HJ-1'),\n",
    "    (hist_HJ_2, '-.', 'DYS-HJ-2'),\n",
    "    (hist_PPM, ':', 'HJ-PPM'),\n",
    "]\n",
    "\n",
    "fig_obj, ax_obj = plt.subplots(figsize=(14, 8))\n",
    "for run_hist, line_style, run_name in objective_series:\n",
    "    obj_trace = run_hist['objective']\n",
    "    ax_obj.semilogy(obj_trace, line_style, linewidth=3,\n",
    "                    label=f'{run_name}: {obj_trace[-1]:.3f}')\n",
    "\n",
    "ax_obj.set_ylabel('Objective Value (log scale)', fontsize=30)\n",
    "ax_obj.set_xlabel('Iteration', fontsize=40)\n",
    "ax_obj.set_title('Objective Values', fontsize=40)\n",
    "ax_obj.legend(fontsize=35, loc='upper right')\n",
    "ax_obj.grid(True, alpha=0.3, which='both')\n",
    "ax_obj.tick_params(axis='both', which='major', labelsize=40)\n",
    "fig_obj.tight_layout()\n",
    "fig_obj.savefig('objective_convergence.pdf', format='pdf', dpi=600, bbox_inches='tight')\n",
    "plt.show()\n",
    "plt.close(fig_obj)\n",
    "\n",
    "# Closing status lines, emitted with one print call.\n",
    "print(\n",
    "    \"✓ All figures saved: fixed_point_residual_convergence.pdf, objective_convergence.pdf\",\n",
    "    \"\\n\" + \"=\" * 70,\n",
    "    \"✓ ALL EXPERIMENTS COMPLETED SUCCESSFULLY\",\n",
    "    \"=\" * 70,\n",
    "    sep=\"\\n\",\n",
    ")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "base",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.10"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
