{
  "cells": [
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "# ============================================================================\n",
        "# CHUNK 1: SETUP - Algorithms, Helper Functions, and Definitions\n",
        "# ============================================================================\n",
        "\n",
        "import time\n",
        "import torch\n",
        "import numpy as np\n",
        "from hj_prox import compute_prox_HJ\n",
        "import matplotlib.pyplot as plt\n",
        "from matplotlib.ticker import MaxNLocator\n",
        "\n",
        "# Exponent offset used by the annealing schedules below, e.g. k ** (2 + eps).\n",
        "eps: float = 1e-5\n",
        "# Device on which the problem tensors are created.\n",
        "device: str = 'cpu'\n",
        "\n",
        "\n",
        "# ============================================================================\n",
        "# Helper Functions\n",
        "# ============================================================================\n",
        "\n",
        "def lasso_objective(x, A, b, lambd):\n",
        "    \"\"\"\n",
        "    Compute LASSO objective: 0.5 * ||Ax - b||_2^2 + λ * ||x||_1\n",
        "    \n",
        "    Args:\n",
        "        x: a single signal of shape (n,), (n, 1) or (1, n), or a batch of\n",
        "           signals of shape (batch_size, n) with one signal per row\n",
        "        A: measurement matrix (m, n)\n",
        "        b: observations, shape (m,) or (m, 1)\n",
        "        lambd: regularization parameter\n",
        "    \n",
        "    Returns:\n",
        "        0-d tensor for a single signal, or a (batch_size,) tensor holding\n",
        "        one objective value per row for a batch.\n",
        "    \n",
        "    Raises:\n",
        "        ValueError: if x cannot be read as signal(s) of length n = A.shape[1].\n",
        "    \"\"\"\n",
        "    n = A.shape[1]\n",
        "    # Flatten b so an (m,) vector can never broadcast against an (m, 1)\n",
        "    # residual into an (m, m) matrix, which would corrupt the data term.\n",
        "    b_flat = b.reshape(-1)\n",
        "    \n",
        "    if x.numel() == n:\n",
        "        # Single signal: (n,), (n, 1) and (1, n) are all accepted.\n",
        "        x_flat = x.reshape(-1)\n",
        "        residual = A @ x_flat - b_flat\n",
        "        data_fid = 0.5 * torch.linalg.vector_norm(residual, ord=2) ** 2\n",
        "        penalty = lambd * torch.linalg.vector_norm(x_flat, ord=1)\n",
        "        return data_fid + penalty\n",
        "    \n",
        "    if x.dim() == 2 and x.shape[1] == n:\n",
        "        # Batch with one signal per row: evaluate the objective row-wise.\n",
        "        residual = x @ A.t() - b_flat\n",
        "        data_fid = 0.5 * torch.linalg.vector_norm(residual, ord=2, dim=1) ** 2\n",
        "        penalty = lambd * torch.linalg.vector_norm(x, ord=1, dim=1)\n",
        "        return data_fid + penalty\n",
        "    \n",
        "    raise ValueError(\n",
        "        f\"x must hold one signal of length {n} or a (batch_size, {n}) batch; \"\n",
        "        f\"got shape {tuple(x.shape)}\"\n",
        "    )\n",
        "\n",
        "\n",
        "def compute_gradient(x, A, b):\n",
        "    \"\"\"Return A^T (A x - b), the gradient of the smooth term 0.5 * ||Ax - b||^2.\"\"\"\n",
        "    misfit = A @ x - b\n",
        "    return A.t() @ misfit\n",
        "\n",
        "\n",
        "def soft_thresholding(x, threshold):\n",
        "    \"\"\"\n",
        "    Elementwise soft-thresholding, the proximal operator of threshold * ||.||_1:\n",
        "    \n",
        "        prox(x) = sign(x) * max(|x| - threshold, 0)\n",
        "    \n",
        "    Args:\n",
        "        x: input tensor\n",
        "        threshold: shrinkage amount (λ in the formula above)\n",
        "    \n",
        "    Returns:\n",
        "        Tensor whose entries are those of x pulled toward zero by threshold\n",
        "        (entries with |x| <= threshold become zero).\n",
        "    \"\"\"\n",
        "    magnitude = torch.clamp(torch.abs(x) - threshold, min=0)\n",
        "    return torch.sign(x) * magnitude\n",
        "\n",
        "\n",
        "# ============================================================================\n",
        "# Algorithm 1: Proximal Gradient Descent (Analytical Soft-Thresholding)\n",
        "# ============================================================================\n",
        "\n",
        "def proximal_gradient_descent_analytical(\n",
        "    x0, A, b, lambd, step_size, max_iters=1000, tol=1e-6, verbose=True\n",
        "):\n",
        "    \"\"\"\n",
        "    Proximal gradient descent for solving the LASSO problem using analytical soft-thresholding:\n",
        "      minimize 0.5 * ||Ax - b||^2 + lambda * ||x||_1\n",
        "    \n",
        "    This implementation uses the closed-form soft-thresholding operator.\n",
        "    \n",
        "    Args:\n",
        "        x0: initial iterate (cloned, never modified)\n",
        "        A: measurement matrix\n",
        "        b: observations\n",
        "        lambd: L1 regularization weight\n",
        "        step_size: gradient step size; the prox threshold is step_size * lambd\n",
        "        max_iters: maximum number of iterations\n",
        "        tol: stop as soon as ||x_{k+1} - x_k|| < tol\n",
        "        verbose: print progress every 10 iterations\n",
        "    \n",
        "    Returns:\n",
        "        (xk, f_hist, diff_hist): final iterate, objective value after each\n",
        "        iteration, and ||x_{k+1} - x_k|| after each iteration.\n",
        "    \"\"\"\n",
        "    xk = x0.clone()\n",
        "    f_hist = []\n",
        "    diff_hist = []\n",
        "    \n",
        "    for i in range(max_iters):\n",
        "        t0 = time.time()\n",
        "        \n",
        "        # Gradient step on the smooth term\n",
        "        grad = compute_gradient(xk, A, b)\n",
        "        x_grad = xk - step_size * grad\n",
        "        \n",
        "        # Proximal step using soft-thresholding\n",
        "        x_prox = soft_thresholding(x_grad, step_size * lambd)\n",
        "        \n",
        "        # Compute metrics\n",
        "        fk = lasso_objective(x_prox, A, b, lambd)\n",
        "        diff = torch.norm(x_prox - xk)\n",
        "        \n",
        "        f_hist.append(fk.item())\n",
        "        diff_hist.append(diff.item())\n",
        "        xk = x_prox.clone()\n",
        "        \n",
        "        if verbose and i % 10 == 0:\n",
        "            print(f\"Analytical PGD iter {i+1:4d}: f={fk.item():.6f}, \"\n",
        "                  f\"||Δx||={diff.item():.6e}, time={time.time() - t0:.4f}s\")\n",
        "        \n",
        "        # Honor `tol` with the same stopping rule as the HJ-Prox variants.\n",
        "        if diff < tol:\n",
        "            break\n",
        "    \n",
        "    return xk, torch.tensor(f_hist), torch.tensor(diff_hist)\n",
        "\n",
        "\n",
        "# ============================================================================\n",
        "# Algorithm 2: Proximal Gradient Descent with HJ-Prox\n",
        "# ============================================================================\n",
        "\n",
        "def proximal_gradient_descent_lasso(\n",
        "    x0, A, b, lambd, step_size, max_iters=1000, \n",
        "    int_samples=100, delta=None, tol=1e-6, verbose=True,\n",
        "    delta_scale=125000\n",
        "):\n",
        "    \"\"\"\n",
        "    Proximal gradient descent for solving the LASSO problem:\n",
        "      minimize 0.5 * ||Ax - b||^2 + lambda * ||x||_1\n",
        "    \n",
        "    The L1 prox is approximated with HJ-Prox using the annealing schedule\n",
        "        delta_k = delta_scale / k^(2+eps),   k = 1, 2, ...\n",
        "    \n",
        "    Args:\n",
        "        x0: initial iterate (cloned, never modified)\n",
        "        A: measurement matrix\n",
        "        b: observations\n",
        "        lambd: L1 regularization weight\n",
        "        step_size: gradient step size, also passed to HJ-Prox as t\n",
        "        max_iters: maximum number of iterations\n",
        "        int_samples: number of samples passed to compute_prox_HJ\n",
        "        delta: ignored -- delta_k always follows the schedule above; the\n",
        "            parameter is kept only so existing calls that pass it still work\n",
        "        tol: stop as soon as ||x_{k+1} - x_k|| < tol\n",
        "        verbose: print progress every 10 iterations\n",
        "        delta_scale: numerator of the delta annealing schedule\n",
        "    \n",
        "    Returns:\n",
        "        (xk, f_hist, diff_hist, delta_k): final iterate, objective after each\n",
        "        iteration, ||x_{k+1} - x_k|| after each iteration, and the last delta\n",
        "        used (None when max_iters == 0).\n",
        "    \"\"\"\n",
        "    xk = x0.clone()\n",
        "    f_hist = []\n",
        "    diff_hist = []\n",
        "    # Bound before the loop so the return works even if no iteration runs.\n",
        "    delta_k = None\n",
        "    \n",
        "    # L1 penalty for HJ-Prox: one value per row of a batch of points.\n",
        "    def l1_penalty(x_batch):\n",
        "        if x_batch.dim() == 1:\n",
        "            x_batch = x_batch.unsqueeze(0)\n",
        "        return lambd * torch.sum(torch.abs(x_batch), dim=1)\n",
        "    \n",
        "    for i in range(max_iters):\n",
        "        t0 = time.time()\n",
        "        \n",
        "        # Annealed smoothing parameter (k is 1-based)\n",
        "        k = i + 1\n",
        "        delta_k = delta_scale / k ** (2 + eps)\n",
        "        \n",
        "        # Gradient step\n",
        "        grad = compute_gradient(xk, A, b)\n",
        "        x_grad = xk - step_size * grad\n",
        "        \n",
        "        # Proximal step using HJ-Prox for the L1 term\n",
        "        x_prox, _, _ = compute_prox_HJ(\n",
        "            x_grad,\n",
        "            t=step_size,\n",
        "            f=l1_penalty,\n",
        "            delta=delta_k,\n",
        "            int_samples=int_samples,\n",
        "            alpha=1\n",
        "        )\n",
        "        \n",
        "        # Compute metrics\n",
        "        fk = lasso_objective(x_prox, A, b, lambd)\n",
        "        diff = torch.norm(x_prox - xk)\n",
        "        f_hist.append(fk.item())\n",
        "        diff_hist.append(diff.item())\n",
        "        xk = x_prox.clone()\n",
        "        \n",
        "        if verbose and i % 10 == 0:\n",
        "            print(f\"PGD-HJ iter {i+1:4d}: f={fk.item():.6f}, ||Δx||={diff.item():.6e}, \"\n",
        "                  f\"delta={delta_k:.6e}, time={time.time() - t0:.4f}s\")\n",
        "        \n",
        "        if diff < tol:\n",
        "            break\n",
        "    \n",
        "    return xk, torch.tensor(f_hist), torch.tensor(diff_hist), delta_k\n",
        "\n",
        "\n",
        "# ============================================================================\n",
        "# Algorithm 3: Proximal Point Method with HJ-Prox\n",
        "# ============================================================================\n",
        "\n",
        "def proximal_point_lasso(\n",
        "    x0, A, b, lambd, gamma, max_iters=1000, \n",
        "    int_samples=100, tol=1e-6, verbose=True,\n",
        "    delta_scale=125000\n",
        "):\n",
        "    \"\"\"\n",
        "    Proximal Point Method for solving the LASSO problem:\n",
        "      minimize 0.5 * ||Ax - b||^2 + lambda * ||x||_1\n",
        "    \n",
        "    Iteratively computes: x^(k+1) = prox_{γF}(x^k)\n",
        "    where F(x) = 0.5||Ax - b||^2 + λ||x||_1 is the full objective.\n",
        "    \n",
        "    The prox is approximated with HJ-Prox using the annealing schedule\n",
        "        delta_k = delta_scale / k^(2+eps),   k = 1, 2, ...\n",
        "    \n",
        "    Parameters:\n",
        "    -----------\n",
        "    x0 : torch.Tensor\n",
        "        Initial point\n",
        "    A : torch.Tensor\n",
        "        Design matrix\n",
        "    b : torch.Tensor\n",
        "        Observation vector\n",
        "    lambd : float\n",
        "        L1 regularization parameter\n",
        "    gamma : float\n",
        "        Fixed proximal parameter (step size)\n",
        "    max_iters : int\n",
        "        Maximum number of iterations\n",
        "    int_samples : int\n",
        "        Number of samples for HJ-Prox\n",
        "    tol : float\n",
        "        Convergence tolerance on ||x^(k+1) - x^k||\n",
        "    verbose : bool\n",
        "        Whether to print progress\n",
        "    delta_scale : float\n",
        "        Numerator of the delta annealing schedule (default 125000)\n",
        "        \n",
        "    Returns:\n",
        "    --------\n",
        "    xk : torch.Tensor\n",
        "        Final iterate\n",
        "    f_hist : torch.Tensor\n",
        "        Objective values\n",
        "    diff_hist : torch.Tensor\n",
        "        Iterate differences\n",
        "    delta_k : float or None\n",
        "        Final delta value (None when max_iters == 0)\n",
        "    \"\"\"\n",
        "    xk = x0.clone()\n",
        "    f_hist = []\n",
        "    diff_hist = []\n",
        "    delta_k = None\n",
        "    # Flattened observations, shared by both objective helpers below.\n",
        "    b_flat = b.reshape(-1)\n",
        "    \n",
        "    # Full LASSO objective for HJ-Prox, evaluated row-wise on a batch.\n",
        "    def lasso_objective_batch(x_batch):\n",
        "        \"\"\"\n",
        "        Compute full LASSO objective for a batch of vectors.\n",
        "        Args:\n",
        "            x_batch: shape (batch_size, n), or (n,) for a single vector\n",
        "        Returns:\n",
        "            objectives: shape (batch_size,)\n",
        "        \"\"\"\n",
        "        if x_batch.dim() == 1:\n",
        "            x_batch = x_batch.unsqueeze(0)\n",
        "        \n",
        "        # Smooth term: 0.5 * ||Ax - b||^2 (b broadcasts across rows)\n",
        "        residuals = torch.matmul(x_batch, A.t()) - b_flat\n",
        "        smooth_term = 0.5 * torch.sum(residuals ** 2, dim=1)\n",
        "        \n",
        "        # Non-smooth term: λ * ||x||_1\n",
        "        l1_term = lambd * torch.sum(torch.abs(x_batch), dim=1)\n",
        "        \n",
        "        return smooth_term + l1_term\n",
        "    \n",
        "    # Objective of a single iterate, as a Python float (for tracking).\n",
        "    def lasso_objective_single(x_vec):\n",
        "        \"\"\"Compute objective for a single vector of any shape holding n entries.\"\"\"\n",
        "        # reshape (not squeeze) so an n == 1 problem still yields a 1-D vector\n",
        "        x_flat = x_vec.reshape(-1)\n",
        "        residual = A @ x_flat - b_flat\n",
        "        smooth = 0.5 * torch.sum(residual ** 2)\n",
        "        l1 = lambd * torch.sum(torch.abs(x_flat))\n",
        "        return (smooth + l1).item()\n",
        "    \n",
        "    if verbose:\n",
        "        initial_obj = lasso_objective_single(xk)\n",
        "        print(\"Starting Proximal Point Method with HJ-Prox for LASSO...\")\n",
        "        print(f\"λ = {lambd}, γ = {gamma}\")\n",
        "        print(f\"Initial objective: {initial_obj:.6f}\")\n",
        "        print(\"-\" * 70)\n",
        "    \n",
        "    for i in range(max_iters):\n",
        "        t0 = time.time()\n",
        "        \n",
        "        # Annealed smoothing parameter (k is 1-based)\n",
        "        k = i + 1\n",
        "        delta_k = delta_scale / (k ** (2 + eps))\n",
        "        \n",
        "        # Proximal Point step: x^(k+1) = prox_{γF}(x^k)\n",
        "        x_prox, ls_iters, _ = compute_prox_HJ(\n",
        "            xk,\n",
        "            t=gamma,\n",
        "            f=lasso_objective_batch,\n",
        "            delta=delta_k,\n",
        "            int_samples=int_samples,\n",
        "            alpha=1\n",
        "        )\n",
        "        \n",
        "        # Compute metrics\n",
        "        fk = lasso_objective_single(x_prox)\n",
        "        diff = torch.norm(x_prox - xk)\n",
        "        f_hist.append(fk)\n",
        "        diff_hist.append(diff.item())\n",
        "        \n",
        "        if verbose and i % 10 == 0:\n",
        "            print(f\"PPM-HJ iter {i+1:4d}: f={fk:.6f}, ||Δx||={diff.item():.6e}, \"\n",
        "                  f\"delta={delta_k:.6e}, ls_iters={ls_iters}, \"\n",
        "                  f\"time={time.time() - t0:.4f}s\")\n",
        "        \n",
        "        # Update iterate\n",
        "        xk = x_prox.clone()\n",
        "        \n",
        "        # Check convergence\n",
        "        if diff < tol:\n",
        "            if verbose:\n",
        "                print(f\"\\nPPM-HJ converged at iteration {i+1}\")\n",
        "                print(f\"Final objective: {fk:.6f}\")\n",
        "            break\n",
        "    \n",
        "    return xk, torch.tensor(f_hist), torch.tensor(diff_hist), delta_k\n",
        "\n",
        "\n",
        "print(\"✓ All algorithms and helper functions loaded successfully\")"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "# ============================================================================\n",
        "# CHUNK 2: DATA GENERATION\n",
        "# ============================================================================\n",
        "# Synthetic sparse-recovery problem: b = A @ x_true + noise, where A has\n",
        "# half as many rows as columns.\n",
        "\n",
        "# --- Reproducibility ---\n",
        "seed = 100\n",
        "torch.manual_seed(seed)\n",
        "np.random.seed(seed)\n",
        "\n",
        "# --- Problem dimensions: dim // 2 measurements of a length-dim signal ---\n",
        "dim = 500\n",
        "n_measurements = dim // 2\n",
        "A = torch.randn(n_measurements, dim, device=device)\n",
        "\n",
        "# --- Ground truth: ones at indices 400-409, zeros elsewhere ---\n",
        "x_true = torch.zeros(dim, 1, device=device)\n",
        "x_true[400:410] = 1\n",
        "\n",
        "# --- Noisy observations ---\n",
        "noise_level = 0.1\n",
        "noise = noise_level * torch.randn(n_measurements, 1, device=device)\n",
        "b = A @ x_true + noise\n",
        "\n",
        "# --- Solver inputs ---\n",
        "x0 = torch.zeros((dim, 1), dtype=torch.float32, device=device)  # initial point\n",
        "lambd = 1  # L1 regularization weight\n",
        "\n",
        "# Step size 1/L, where L = sigma_max(A)^2 is the Lipschitz constant of the\n",
        "# gradient of 0.5 * ||Ax - b||^2.\n",
        "sigma_max = torch.linalg.norm(A, ord=2)\n",
        "L = sigma_max**2\n",
        "step_size = 1.0 / L\n",
        "\n",
        "banner = \"=\" * 60\n",
        "print(banner)\n",
        "print(\"Running LASSO Optimization Comparison\")\n",
        "print(banner)\n",
        "print(f\"Problem size: A is {A.shape}, x is {x_true.shape}\")\n",
        "print(f\"Regularization parameter λ = {lambd}\")\n",
        "print(f\"Lipschitz constant L = {L:.4f}\")\n",
        "print(f\"Step size for PGD = {step_size:.6f}\")\n",
        "print(banner)\n",
        "\n",
        "# Container for each algorithm's (solution, objective history, diff history)\n",
        "results = {}"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "# ============================================================================\n",
        "# CHUNK 3: RUN ALGORITHM 1 - Analytical PGD\n",
        "# ============================================================================\n",
        "\n",
        "print(\"\\n\" + \"=\"*60)\n",
        "print(\"Running Algorithm 1: Analytical PGD...\")\n",
        "print(\"=\"*60)\n",
        "\n",
        "# Pass the λ configured in CHUNK 2 (not a literal) so the printed setup and\n",
        "# the problem actually solved cannot drift apart. The 0.085 step scale is the\n",
        "# same one used by the HJ-Prox runs.\n",
        "x_analytical, f_hist_analytical, diff_hist_analytical = proximal_gradient_descent_analytical(\n",
        "    x0=x0, A=A, b=b, lambd=lambd,\n",
        "    step_size=step_size*0.085,\n",
        "    max_iters=10000,\n",
        "    tol=1e-30,\n",
        "    verbose=True\n",
        ")\n",
        "results['Analytical'] = (x_analytical, f_hist_analytical, diff_hist_analytical)\n",
        "\n",
        "print(f\"\\n✓ Analytical PGD completed\")\n",
        "print(f\"  - Converged in {len(f_hist_analytical)} iterations\")\n",
        "print(f\"  - Final objective: {f_hist_analytical[-1]:.6f}\")\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "# ============================================================================\n",
        "# CHUNK 4: RUN ALGORITHM 2 - PGD with HJ-Prox\n",
        "# ============================================================================\n",
        "\n",
        "print(\"\\n\" + \"=\"*60)\n",
        "print(\"Running Algorithm 2: PGD with HJ-Prox...\")\n",
        "print(\"=\"*60)\n",
        "\n",
        "# Use the λ configured in CHUNK 2 instead of a hard-coded literal.\n",
        "# No `delta=` argument: proximal_gradient_descent_lasso ignores it and always\n",
        "# anneals delta_k internally, so passing a value here would only mislead.\n",
        "x_hjprox, f_hist_hjprox, diff_hist_hjprox, final_delta = proximal_gradient_descent_lasso(\n",
        "    x0=x0, A=A, b=b, lambd=lambd,\n",
        "    step_size=step_size*0.085,  # As mentioned in the paper, small step sizes help control errors\n",
        "    max_iters=10000,\n",
        "    int_samples=1000,\n",
        "    tol=1e-16,\n",
        "    verbose=True\n",
        ")\n",
        "results['HJ-Prox'] = (x_hjprox, f_hist_hjprox, diff_hist_hjprox)\n",
        "\n",
        "print(f\"\\n✓ PGD-HJ completed\")\n",
        "print(f\"  - Converged in {len(f_hist_hjprox)} iterations\")\n",
        "print(f\"  - Final objective: {f_hist_hjprox[-1]:.6f}\")\n",
        "print(f\"  - Final delta: {final_delta:.6e}\")"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "# ============================================================================\n",
        "# CHUNK 5: GENERATE FIGURES - PGD vs PGD-HJ Comparison\n",
        "# ============================================================================\n",
        "\n",
        "print(\"\\n\" + \"=\"*60)\n",
        "print(\"Generating figures: PGD vs PGD-HJ comparison...\")\n",
        "print(\"=\"*60)\n",
        "\n",
        "# Solutions as flat NumPy arrays, plus objective histories\n",
        "true_signal = x_true.detach().cpu().numpy().flatten()\n",
        "analytical_signal = results['Analytical'][0].detach().cpu().numpy().flatten()\n",
        "hjprox_signal = results['HJ-Prox'][0].detach().cpu().numpy().flatten()\n",
        "analytical_hist = results['Analytical'][1]\n",
        "hjprox_hist = results['HJ-Prox'][1]\n",
        "\n",
        "# Zoom window: indices 395..414, around the nonzero entries of x_true\n",
        "start, end = 395, 415\n",
        "idx = np.arange(start, end)\n",
        "\n",
        "\n",
        "def _plot_coefficients(positions, values, fmt, title, xlabel, filename, **marker_style):\n",
        "    \"\"\"Plot values[positions] as markers, style the axes, save as PDF, show, close.\"\"\"\n",
        "    plt.figure(figsize=(11, 10))\n",
        "    plt.plot(positions, values[positions], fmt, **marker_style)\n",
        "    plt.ylabel('Coefficient Value', fontsize=40)\n",
        "    plt.xlabel(xlabel, fontsize=40)\n",
        "    plt.title(title, fontsize=40)\n",
        "    plt.grid(True, alpha=0.3)\n",
        "    plt.gca().xaxis.set_major_locator(MaxNLocator(integer=True))\n",
        "    plt.gca().tick_params(axis='both', which='major', labelsize=40)\n",
        "    plt.tight_layout()\n",
        "    plt.savefig(filename, format='pdf', dpi=600, bbox_inches='tight')\n",
        "    plt.show()\n",
        "    plt.close()\n",
        "\n",
        "\n",
        "# --- Figure 1: Ground Truth ---\n",
        "_plot_coefficients(\n",
        "    idx, true_signal, 'o', 'Ground Truth (395-415)', 'Coefficients',\n",
        "    'lasso_ground_truth.pdf',\n",
        "    markersize=10, markerfacecolor='none', markeredgecolor='black',\n",
        "    markeredgewidth=3,\n",
        ")\n",
        "\n",
        "# --- Figure 2: HJ-Prox Solution ---\n",
        "_plot_coefficients(\n",
        "    idx, hjprox_signal, 's', 'PGD-HJ', 'Coefficient Index',\n",
        "    'lasso_hjprox.pdf', markersize=8, color='blue',\n",
        ")\n",
        "\n",
        "# --- Figure 3: Analytical PGD Solution ---\n",
        "_plot_coefficients(\n",
        "    idx, analytical_signal, '^', 'PGD', 'Coefficients',\n",
        "    'lasso_pgd.pdf', markersize=8, color='red',\n",
        ")\n",
        "\n",
        "# --- Figure 4: Objective Function Convergence (PGD vs PGD-HJ) ---\n",
        "plt.figure(figsize=(11, 10))\n",
        "plt.semilogy(hjprox_hist, '-', linewidth=3,\n",
        "             label=f'PGD-HJ: {hjprox_hist[-1].item():.3f}')\n",
        "plt.semilogy(analytical_hist, '--', linewidth=3,\n",
        "             label=f'PGD: {analytical_hist[-1].item():.3f}')\n",
        "plt.ylabel('Objective (log scale)', fontsize=40)\n",
        "plt.xlabel('Iteration', fontsize=40)\n",
        "plt.title('PGD Convergence', fontsize=40)\n",
        "plt.legend(fontsize=40, loc='upper left')\n",
        "plt.grid(True, which='both', alpha=0.3)\n",
        "ax = plt.gca()\n",
        "ax.xaxis.set_major_locator(MaxNLocator(integer=True))\n",
        "# Explicit ticks every 2500 iterations (these replace the locator set above)\n",
        "max_iter = len(hjprox_hist)\n",
        "tick_positions = np.arange(0, max_iter + 1, 2500)\n",
        "ax.set_xticks(tick_positions)\n",
        "ax.tick_params(axis='both', which='major', labelsize=40)\n",
        "plt.tight_layout()\n",
        "plt.savefig('lasso_objectives.pdf', format='pdf', dpi=600, bbox_inches='tight')\n",
        "plt.show()\n",
        "plt.close()\n",
        "\n",
        "print(\"✓ All figures saved: lasso_ground_truth.pdf, lasso_hjprox.pdf, lasso_pgd.pdf, lasso_objectives.pdf\")\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "# ============================================================================\n",
        "# CHUNK 6: RUN ALGORITHM 3 - PPM with HJ-Prox (Lower Sample size to show difference)\n",
        "# ============================================================================\n",
        "\n",
        "print(\"\\n\" + \"=\"*60)\n",
        "print(\"Running Algorithm 3: Proximal Point Method with HJ-Prox...\")\n",
        "print(\"=\"*60)\n",
        "\n",
        "# Use the λ configured in CHUNK 2 (not a literal) and the same 0.085\n",
        "# step-size scale as the PGD runs.\n",
        "x_ppm, f_ppm, diff_ppm, _ = proximal_point_lasso(\n",
        "    x0, A, b, lambd=lambd, gamma=step_size*0.085,\n",
        "    max_iters=10000,\n",
        "    int_samples=500,\n",
        "    tol=1e-6,\n",
        "    verbose=True\n",
        ")\n",
        "# Stored alongside the other runs for consistency.\n",
        "results['PPM-HJ'] = (x_ppm, f_ppm, diff_ppm)\n",
        "\n",
        "print(f\"\\n✓ PPM-HJ completed\")\n",
        "print(f\"  - Converged in {len(f_ppm)} iterations\")\n",
        "print(f\"  - Final objective: {f_ppm[-1]:.6f}\")"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "# ============================================================================\n",
        "# CHUNK 7: RUN ALGORITHM 2 - PGD with HJ-Prox (Lower Sample size to show difference)\n",
        "# ============================================================================\n",
        "\n",
        "print(\"\\n\" + \"=\"*60)\n",
        "print(\"Running Algorithm 2: PGD with HJ-Prox...\")\n",
        "print(\"=\"*60)\n",
        "\n",
        "# Re-run PGD-HJ with int_samples=500, matching the PPM-HJ run above.\n",
        "# NOTE: this overwrites x_hjprox, f_hist_hjprox and results['HJ-Prox'] from\n",
        "# CHUNK 4; re-running CHUNK 5 after this cell would plot the 500-sample run.\n",
        "# No `delta=` argument: proximal_gradient_descent_lasso ignores it and always\n",
        "# anneals delta_k internally.\n",
        "x_hjprox, f_hist_hjprox, diff_hist_hjprox, final_delta = proximal_gradient_descent_lasso(\n",
        "    x0=x0, A=A, b=b, lambd=lambd,\n",
        "    step_size=step_size*0.085,  # As mentioned in the paper, small step sizes help control errors\n",
        "    max_iters=10000,\n",
        "    int_samples=500,\n",
        "    tol=1e-16,\n",
        "    verbose=True\n",
        ")\n",
        "results['HJ-Prox'] = (x_hjprox, f_hist_hjprox, diff_hist_hjprox)\n",
        "\n",
        "print(f\"\\n✓ PGD-HJ completed\")\n",
        "print(f\"  - Converged in {len(f_hist_hjprox)} iterations\")\n",
        "print(f\"  - Final objective: {f_hist_hjprox[-1]:.6f}\")\n",
        "print(f\"  - Final delta: {final_delta:.6e}\")"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "# ============================================================================\n",
        "# CHUNK 8: GENERATE FIGURES - PGD-HJ vs PPM-HJ Comparison\n",
        "# ============================================================================\n",
        "\n",
        "print(\"\\n\" + \"=\"*60)\n",
        "print(\"Generating figures: PGD-HJ vs PPM-HJ comparison...\")\n",
        "print(\"=\"*60)\n",
        "\n",
        "# Prepare data.\n",
        "# Re-read the PGD-HJ history from `results`: the previous cell re-ran PGD-HJ\n",
        "# with int_samples=500, so the `hjprox_hist` left over from CHUNK 5 (the\n",
        "# int_samples=1000 run) would be stale here.\n",
        "hjprox_hist = results['HJ-Prox'][1]\n",
        "ppm_signal = x_ppm.detach().cpu().numpy().flatten()\n",
        "ppm_hist = f_ppm\n",
        "# Same zoom window as CHUNK 5, defined here so this cell does not depend on it.\n",
        "idx = np.arange(395, 415)\n",
        "\n",
        "# --- Figure 5: PPM-HJ Solution ---\n",
        "plt.figure(figsize=(11, 10))\n",
        "plt.plot(idx, ppm_signal[idx], '^', markersize=8, color='red')\n",
        "plt.ylabel('Coefficient Value', fontsize=40)\n",
        "plt.xlabel('Coefficients', fontsize=40)\n",
        "plt.title('PPM-HJ', fontsize=40)\n",
        "plt.grid(True, alpha=0.3)\n",
        "plt.gca().xaxis.set_major_locator(MaxNLocator(integer=True))\n",
        "plt.gca().tick_params(axis='both', which='major', labelsize=40)\n",
        "plt.tight_layout()\n",
        "plt.savefig('lasso_ppm.pdf', format='pdf', dpi=600, bbox_inches='tight')\n",
        "plt.show()\n",
        "plt.close()\n",
        "\n",
        "# --- Figure 6: Objective Function Convergence (PGD-HJ vs PPM-HJ) ---\n",
        "plt.figure(figsize=(11, 10))\n",
        "plt.semilogy(hjprox_hist, '-', linewidth=3,\n",
        "             label=f'PGD-HJ: {hjprox_hist[-1].item():.3f}')\n",
        "plt.semilogy(ppm_hist, '--', linewidth=3,\n",
        "             label=f'PPM-HJ: {ppm_hist[-1].item():.3f}')\n",
        "plt.ylabel('Objective (log scale)', fontsize=40)\n",
        "plt.xlabel('Iteration', fontsize=40)\n",
        "plt.title('LASSO', fontsize=40)\n",
        "plt.legend(fontsize=40, loc='upper left')\n",
        "plt.grid(True, which='both', alpha=0.3)\n",
        "plt.gca().xaxis.set_major_locator(MaxNLocator(integer=True))\n",
        "max_iter = len(hjprox_hist)\n",
        "tick_positions = np.arange(0, max_iter + 1, 2500)\n",
        "plt.gca().set_xticks(tick_positions)\n",
        "plt.gca().tick_params(axis='both', which='major', labelsize=40)\n",
        "plt.tight_layout()\n",
        "plt.savefig('lasso_objectives_VS_PPM.pdf', format='pdf', dpi=600, bbox_inches='tight')\n",
        "plt.show()\n",
        "plt.close()\n",
        "\n",
        "print(\"✓ All figures saved: lasso_ppm.pdf, lasso_objectives_VS_PPM.pdf\")\n",
        "print(\"\\n\" + \"=\"*60)\n",
        "print(\"✓ ALL EXPERIMENTS COMPLETED SUCCESSFULLY\")\n",
        "print(\"=\"*60)"
      ]
    }
  ],
  "metadata": {
    "accelerator": "GPU",
    "colab": {
      "provenance": []
    },
    "gpuClass": "premium",
    "kernelspec": {
      "display_name": "base",
      "language": "python",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.11.10"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
