{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# ============================================================================\n",
    "# ICML Paper - Experiment: Fused LASSO with Douglas-Rachford\n",
    "# Figures: 1 & 3\n",
    "# Description: Comparison of DRS (analytical), DRS-HJ, and PPM-HJ for fused LASSO\n",
    "# ============================================================================\n",
    "\n",
    "# ============================================================================\n",
    "# CHUNK 1: SETUP - Algorithms, Helper Functions, and Definitions\n",
    "# ============================================================================\n",
    "\n",
    "import time\n",
    "import torch\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "from hj_prox import compute_prox_HJ\n",
    "\n",
    "# Reproducibility: float64 tensors by default and one shared seed for the torch\n",
    "# and NumPy global RNGs, so a fresh Restart & Run All draws the same noise.\n",
    "SEED = 42\n",
    "np.random.seed(SEED)\n",
    "torch.manual_seed(SEED)\n",
    "torch.set_default_dtype(torch.float64)\n",
    "\n",
    "# Device for tensors created at notebook level (difference matrix, data, x0).\n",
    "device = 'cpu'\n",
    "# Exponent offset in the HJ-Prox smoothing schedule delta ∝ 1 / (i+1)^(2 + eps).\n",
    "eps = 1e-5\n",
    "\n",
    "# Plotting configuration: 20pt base font (the figures below set most sizes explicitly).\n",
    "plt.rcParams['font.size'] = 20\n",
    "\n",
    "\n",
    "# ============================================================================\n",
    "# Helper Functions\n",
    "# ============================================================================\n",
    "\n",
    "def kth_order_diff_matrix(n: int, k: int) -> torch.Tensor:\n",
    "    \"\"\"\n",
    "    Construct the k-th order forward-difference matrix D of shape (n - k, n).\n",
    "\n",
    "    Row i applies the k-th forward difference starting at x[i]:\n",
    "        k=1: (D x)_i = x[i+1] - x[i]\n",
    "        k=3: (D x)_i = -x[i] + 3*x[i+1] - 3*x[i+2] + x[i+3]\n",
    "\n",
    "    Args:\n",
    "        n: signal length (number of columns of D).\n",
    "        k: difference order, 0 <= k <= n (k=0 gives the identity).\n",
    "\n",
    "    Returns:\n",
    "        float64 tensor of shape (n - k, n) on the module-level `device`.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if k is outside [0, n].\n",
    "    \"\"\"\n",
    "    if not 0 <= k <= n:\n",
    "        raise ValueError(f\"k must satisfy 0 <= k <= n={n}, got k={k}\")\n",
    "\n",
    "    eye = torch.eye(n, dtype=torch.float64, device=device)\n",
    "    if k == 0:\n",
    "        return eye\n",
    "\n",
    "    # torch.diff along dim 0 maps row r -> row[r+1] - row[r], i.e. left-multiplies\n",
    "    # by a first-order difference operator; applying it k times yields D without\n",
    "    # per-entry Python loops or k dense matrix products.\n",
    "    return torch.diff(eye, n=k, dim=0)\n",
    "\n",
    "def fused_lasso_objective(x: torch.Tensor, b: torch.Tensor, D: torch.Tensor, lambd: float) -> torch.Tensor:\n",
    "    \"\"\"\n",
    "    Compute fused lasso objective: 0.5 * ||b - x||_2^2 + λ * ||D x||_1\n",
    "    \n",
    "    Args:\n",
    "        x: a single signal (n,) or (n, 1), or a batch (batch_size, n)\n",
    "        b: observations holding n values, e.g. (n,) or (n, 1)\n",
    "        D: difference matrix (n-k, n)\n",
    "        lambd: regularization parameter\n",
    "\n",
    "    Returns:\n",
    "        Tensor of shape (batch_size,) with one objective value per signal\n",
    "        (batch_size = 1 for a single signal).\n",
    "    \"\"\"\n",
    "    if x.dim() == 1:\n",
    "        x = x.unsqueeze(0)\n",
    "    elif x.dim() == 2 and x.shape[1] == 1:\n",
    "        x = x.t()  # Convert (n, 1) to (1, n)\n",
    "\n",
    "    # reshape(-1) rather than squeeze(): stays 1-D even when n == 1\n",
    "    b_row = b.reshape(-1).unsqueeze(0)  # (1, n), broadcasts over the batch\n",
    "\n",
    "    # Data fidelity term: sum of squares directly (no sqrt-then-square round trip)\n",
    "    data_fid = 0.5 * torch.sum((x - b_row) ** 2, dim=1)\n",
    "\n",
    "    # Penalty term: λ * ||Dx||_1\n",
    "    Dx = x @ D.t()  # (batch_size, n-k)\n",
    "    penalty = lambd * torch.sum(torch.abs(Dx), dim=1)\n",
    "\n",
    "    return data_fid + penalty\n",
    "\n",
    "\n",
    "def compute_gradient(x: torch.Tensor, b: torch.Tensor) -> torch.Tensor:\n",
    "    \"\"\"\n",
    "    Gradient of the smooth data term f(x) = 0.5 * ||x - b||^2.\n",
    "\n",
    "    Returns the residual x - b (broadcast shape of x and b).\n",
    "    \"\"\"\n",
    "    residual = x - b\n",
    "    return residual\n",
    "\n",
    "\n",
    "def _soft_threshold(z: torch.Tensor, tau: float) -> torch.Tensor:\n",
    "    \"\"\"\n",
    "    Elementwise soft-thresholding, the prox of tau * ||.||_1:\n",
    "        S_tau(z) = sign(z) * max(|z| - tau, 0)\n",
    "    \"\"\"\n",
    "    shrunk = torch.clamp(z.abs() - tau, min=0.0)\n",
    "    return torch.sign(z) * shrunk\n",
    "\n",
    "\n",
    "# ============================================================================\n",
    "# Algorithm 1: Douglas-Rachford with Analytical Proximal Operators\n",
    "# ============================================================================\n",
    "\n",
    "@torch.no_grad()\n",
    "def douglas_rachford_fused(\n",
    "    x0: torch.Tensor,\n",
    "    b: torch.Tensor,\n",
    "    D: torch.Tensor,\n",
    "    lambd: float,\n",
    "    *,\n",
    "    gamma: float = 1.0,\n",
    "    relax: float = 1.0,\n",
    "    max_iters: int = 2000,\n",
    "    tol: float = 1e-6,\n",
    "    verbose: bool = True,\n",
    "    verbose_every: int = 50,\n",
    "    precompute_cholesky: bool = True,\n",
    "):\n",
    "    \"\"\"\n",
    "    Douglas-Rachford Splitting for:\n",
    "        min_x 0.5*||x - b||^2 + λ ||D x||_1\n",
    "    Product-space split:\n",
    "        F(x,w) = 0.5*||x-b||^2 + λ||w||_1,   C = {(x,w): w = D x}\n",
    "    Each iteration with z = (x, w):\n",
    "        (xA, wA) = prox_{γF}(z)                  (closed form + soft threshold)\n",
    "        (xB, wB) = proj_C(2 (xA, wA) - z)         (one SPD solve with I + D^T D)\n",
    "        z <- z + relax * ((xB, wB) - (xA, wA))\n",
    "\n",
    "    Args\n",
    "    ----\n",
    "    x0 : (n,1) or (n,) initial point\n",
    "    b  : (n,1) or (n,) observation\n",
    "    D  : ((n-k), n) k-th order difference (or any linear operator as matrix)\n",
    "    lambd : λ (float)\n",
    "    gamma : DRS stepsize (>0), default 1.0 (any >0 converges)\n",
    "    relax : relaxation (0,2], default 1.0\n",
    "    max_iters : iteration cap, must be >= 1\n",
    "    tol : stopping on ||z_{k+1} - z_k||_2 < tol\n",
    "    verbose : print progress\n",
    "    verbose_every : print every this many iterations\n",
    "    precompute_cholesky : cache Cholesky of (I + D^T D) for fast projections\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    x_star : (n,1) estimated signal (xA from the last iteration)\n",
    "    f_hist : (T,) objective values along the run (evaluated at x_k = prox_F^x(z_k))\n",
    "    diff_hist : (T,) fixed-point residual norms ||z_{k+1}-z_k||\n",
    "\n",
    "    Raises\n",
    "    ------\n",
    "    ValueError : if max_iters < 1 (no iterate would be produced)\n",
    "    \"\"\"\n",
    "    if max_iters < 1:\n",
    "        raise ValueError(f\"max_iters must be >= 1, got {max_iters}\")\n",
    "\n",
    "    # Ensure (n,1) shape\n",
    "    if x0.dim() == 1:\n",
    "        x0 = x0.unsqueeze(1)\n",
    "    if b.dim() == 1:\n",
    "        b = b.unsqueeze(1)\n",
    "\n",
    "    n = b.shape[0]\n",
    "    device, dtype = b.device, b.dtype\n",
    "\n",
    "    # z = (x, w); initialize w as Dx (feasible start helps but not required)\n",
    "    x = x0.clone()\n",
    "    w = D @ x0  # matmul returns a fresh tensor, no clone needed\n",
    "\n",
    "    Dt = D.t()\n",
    "\n",
    "    # Projection onto C: solve (I + D^T D) x = rhs once per iter\n",
    "    A_proj = torch.eye(n, dtype=dtype, device=device) + Dt @ D\n",
    "    if precompute_cholesky:\n",
    "        # A_proj is SPD (I + D^T D), so Cholesky is safe\n",
    "        L = torch.linalg.cholesky(A_proj)\n",
    "\n",
    "        def solve_A(rhs):\n",
    "            # Two triangular solves with the cached lower factor\n",
    "            return torch.cholesky_solve(rhs, L)\n",
    "    else:\n",
    "        def solve_A(rhs):\n",
    "            return torch.linalg.solve(A_proj, rhs)\n",
    "\n",
    "    f_hist = []\n",
    "    diff_hist = []\n",
    "\n",
    "    # prox_{γ * 0.5||·-b||^2}(v) = (v + γ b)/(1 + γ)\n",
    "    inv_one_plus_gamma = 1.0 / (1.0 + gamma)\n",
    "\n",
    "    for k in range(1, max_iters + 1):\n",
    "        t0 = time.time()\n",
    "\n",
    "        # ---- Prox of γF at z=(x,w): separable in x and w ----\n",
    "        xA = (x + gamma * b) * inv_one_plus_gamma  # x part (closed form)\n",
    "        wA = _soft_threshold(w, gamma * lambd)      # w part (soft threshold)\n",
    "\n",
    "        # ---- Reflection r = 2*prox_F(z) - z ----\n",
    "        rx = 2.0 * xA - x\n",
    "        rw = 2.0 * wA - w\n",
    "\n",
    "        # ---- Projection onto C: minimize ||x - rx||^2 + ||Dx - rw||^2 ----\n",
    "        xB = solve_A(rx + Dt @ rw)\n",
    "        wB = D @ xB\n",
    "\n",
    "        # ---- DRS update with relaxation ----\n",
    "        new_x = x + relax * (xB - xA)\n",
    "        new_w = w + relax * (wB - wA)\n",
    "\n",
    "        # ---- Convergence stats: ||z_{k+1} - z_k|| over the product space ----\n",
    "        diff = torch.sqrt(\n",
    "            torch.linalg.norm(new_x - x).pow(2) + torch.linalg.norm(new_w - w).pow(2)\n",
    "        )\n",
    "        diff_hist.append(diff.item())\n",
    "\n",
    "        # Objective evaluated at the current primal candidate x_k = xA\n",
    "        fval = 0.5 * torch.linalg.norm(xA - b).pow(2) + lambd * torch.norm(D @ xA, p=1)\n",
    "        f_hist.append(fval.item())\n",
    "\n",
    "        if verbose and (k == 1 or k % verbose_every == 0):\n",
    "            print(f\"DRS iter {k:4d}: f={fval.item():.6f}, ||Δz||={diff.item():.3e}, time={time.time()-t0:.4f}s\")\n",
    "\n",
    "        x, w = new_x, new_w\n",
    "\n",
    "        if diff.item() < tol:\n",
    "            if verbose:\n",
    "                print(f\"DRS converged at iter {k} (||Δz||={diff.item():.3e})\")\n",
    "            break\n",
    "\n",
    "    # The primal solution associated with a DRS fixed point is prox_F(z*). Use last xA.\n",
    "    x_star = xA.clone()\n",
    "    return x_star, torch.tensor(f_hist), torch.tensor(diff_hist)\n",
    "\n",
    "\n",
    "# ============================================================================\n",
    "# Algorithm 2: Douglas-Rachford with HJ-Prox\n",
    "# ============================================================================\n",
    "\n",
    "@torch.no_grad()\n",
    "def douglas_rachford_fused_HJ_Prox(\n",
    "    x0: torch.Tensor,\n",
    "    b: torch.Tensor,\n",
    "    D: torch.Tensor,\n",
    "    lambd: float,\n",
    "    *,\n",
    "    gamma: float = 1.0,\n",
    "    relax: float = 1.0,\n",
    "    max_iters: int = 5000,\n",
    "    int_samples_penalty: int = 100,\n",
    "    tol: float = 1e-8,\n",
    "    verbose: bool = True,\n",
    "    verbose_every: int = 100,\n",
    "    device: str = \"cpu\",\n",
    "    delta_scale: float = 12_500_000.0,\n",
    ") -> tuple[torch.Tensor, torch.Tensor]:\n",
    "    \"\"\"\n",
    "    DRS for min_x 0.5||x - b||^2 + λ||Dx||_1 with an inexact prox for\n",
    "    g(x) = λ||Dx||_1 (HJ-Prox) and the exact prox for f(x) = 0.5||x - b||^2.\n",
    "\n",
    "    Iteration i (gamma is FIXED):\n",
    "        x_i = compute_prox_HJ(z, gamma, g, delta_i)    # inexact prox of g\n",
    "        y_i = (gamma * b + 2 x_i - z) / (1 + gamma)    # exact prox of f at 2 x_i - z\n",
    "        z  <- z + relax * (y_i - x_i)\n",
    "    Delta schedule: delta_i = delta_scale / (i+1)^(2 + eps), where eps is the\n",
    "    module-level constant; the default delta_scale equals 2_500_000 * 5.\n",
    "    Stops when ||y_i - x_i|| < tol or after max_iters iterations.\n",
    "\n",
    "    Returns (x_hat, f_hist): x_hat = last y_i (prox-f shadow, shape (n, 1)) and\n",
    "    f_hist = objective value at each y_i.\n",
    "\n",
    "    Raises ValueError if max_iters < 1 (no iterate would be produced).\n",
    "    \"\"\"\n",
    "    if max_iters < 1:\n",
    "        raise ValueError(f\"max_iters must be >= 1, got {max_iters}\")\n",
    "\n",
    "    # Work with (n, 1) column vectors on the requested device\n",
    "    x0 = x0.to(device).reshape(-1, 1)\n",
    "    b = b.to(device).reshape(-1, 1)\n",
    "    D = D.to(device)\n",
    "\n",
    "    z = x0.clone()\n",
    "    n = b.shape[0]\n",
    "    inv_one_plus_gamma = 1.0 / (1.0 + gamma)\n",
    "\n",
    "    # g(x) = λ ||Dx||_1, evaluated row-wise on a batch of candidate points\n",
    "    def penalty_func(x_batch: torch.Tensor) -> torch.Tensor:\n",
    "        xb = x_batch.reshape(-1, n)       # (B, n)\n",
    "        Dx = xb @ D.t()                   # (B, n-k)\n",
    "        return lambd * torch.sum(torch.abs(Dx), dim=1)\n",
    "\n",
    "    f_hist: list[float] = []\n",
    "\n",
    "    for i in range(1, max_iters + 1):\n",
    "        # Annealed smoothing parameter for HJ-Prox\n",
    "        delta = delta_scale / ((i + 1)**(2.0 + eps))\n",
    "\n",
    "        # Prox_g (inexact, HJ-Prox)\n",
    "        xk, ls_iters_g, _ = compute_prox_HJ(\n",
    "            z, gamma, f=penalty_func,\n",
    "            delta=delta,\n",
    "            int_samples=int_samples_penalty,\n",
    "            alpha=1.0\n",
    "        )\n",
    "\n",
    "        # Prox_f at the reflected point v = 2 x - z (closed form); b is already (n, 1)\n",
    "        v = 2.0 * xk - z\n",
    "        yk = (gamma * b + v) * inv_one_plus_gamma\n",
    "\n",
    "        # DRS update\n",
    "        step = yk - xk\n",
    "        z = z + relax * step\n",
    "\n",
    "        # Metrics\n",
    "        diff = torch.linalg.norm(step)\n",
    "        fval = fused_lasso_objective(yk.t(), b, D, lambd).item()\n",
    "        f_hist.append(fval)\n",
    "\n",
    "        if verbose and (i == 1 or i % verbose_every == 0):\n",
    "            print(f\"DRS(HJ) it {i:4d}: f={fval:.6f}, ||y-x||={diff.item():.3e}, \"\n",
    "                  f\"delta={delta:.2e}, ls_g={ls_iters_g}\")\n",
    "\n",
    "        if diff.item() < tol:\n",
    "            if verbose:\n",
    "                print(f\"DRS(HJ) converged at it {i} (||y-x||={diff.item():.3e})\")\n",
    "            break\n",
    "\n",
    "    # max_iters >= 1 guarantees yk was assigned at least once\n",
    "    x_hat = yk.clone()\n",
    "    return x_hat, torch.tensor(f_hist)\n",
    "\n",
    "\n",
    "# ============================================================================\n",
    "# Algorithm 3: Proximal Point Method with HJ-Prox\n",
    "# ============================================================================\n",
    "\n",
    "@torch.no_grad()\n",
    "def proximal_point_fused_HJ_Prox(\n",
    "    x0: torch.Tensor,\n",
    "    b: torch.Tensor,\n",
    "    D: torch.Tensor,\n",
    "    lambd: float,\n",
    "    *,\n",
    "    gamma: float = 1.0,\n",
    "    max_iters: int = 5000,\n",
    "    int_samples: int = 1000,\n",
    "    delta_init: float = 1.0,\n",
    "    tol: float = 1e-8,\n",
    "    verbose: bool = True,\n",
    "    verbose_every: int = 100,\n",
    "    device: str = \"cpu\",\n",
    "    delta_scale: float = 12_500_000.0,\n",
    ") -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:\n",
    "    \"\"\"\n",
    "    Proximal Point Method for fused lasso using HJ-Prox.\n",
    "    \n",
    "    Iteratively computes: x^(k+1) ≈ prox_{γF}(x^k)\n",
    "    where F(x) = 0.5||x - b||^2 + λ||Dx||_1, with each prox approximated by\n",
    "    compute_prox_HJ using the annealed smoothing schedule\n",
    "        delta_i = delta_scale / (i+1)^(2 + eps)     (eps: module-level constant)\n",
    "    \n",
    "    Parameters:\n",
    "    -----------\n",
    "    x0 : torch.Tensor\n",
    "        Initial signal (n, 1) or (n,)\n",
    "    b : torch.Tensor\n",
    "        Observations (n, 1) or (n,)\n",
    "    D : torch.Tensor\n",
    "        Difference matrix (n-k, n)\n",
    "    lambd : float\n",
    "        Regularization parameter\n",
    "    gamma : float\n",
    "        Fixed proximal parameter (step size)\n",
    "    max_iters : int\n",
    "        Maximum number of iterations (must be >= 1)\n",
    "    int_samples : int\n",
    "        Number of samples for HJ-Prox approximation\n",
    "    delta_init : float\n",
    "        Kept for backward compatibility only; it does NOT affect the iteration\n",
    "        (the smoothing schedule is controlled by delta_scale).\n",
    "    tol : float\n",
    "        Stop once ||x^(k+1) - x^k|| < tol\n",
    "    verbose : bool\n",
    "        Whether to print progress\n",
    "    verbose_every : int\n",
    "        Print frequency\n",
    "    device : str\n",
    "        Device to use ('cpu' or 'cuda')\n",
    "    delta_scale : float\n",
    "        Numerator of the delta schedule; the default equals 2_500_000 * 5.\n",
    "        \n",
    "    Returns:\n",
    "    --------\n",
    "    x_final : torch.Tensor\n",
    "        Optimized signal (n, 1)\n",
    "    f_hist : torch.Tensor\n",
    "        Objective values at each iteration\n",
    "    diff_hist : torch.Tensor\n",
    "        Iterate differences ||x^(k+1) - x^k|| at each iteration\n",
    "\n",
    "    Raises:\n",
    "    -------\n",
    "    ValueError\n",
    "        If max_iters < 1.\n",
    "    \"\"\"\n",
    "    if max_iters < 1:\n",
    "        raise ValueError(f\"max_iters must be >= 1, got {max_iters}\")\n",
    "\n",
    "    # Work with (n, 1) column vectors on the requested device\n",
    "    x0 = x0.to(device).reshape(-1, 1)\n",
    "    b = b.to(device).reshape(-1, 1)\n",
    "    D = D.to(device)\n",
    "\n",
    "    n = b.shape[0]\n",
    "    b_row = b.reshape(1, n)  # (1, n), broadcasts against (batch_size, n)\n",
    "\n",
    "    # Initialize\n",
    "    x = x0.clone()\n",
    "\n",
    "    f_hist: list[float] = []\n",
    "    diff_hist: list[float] = []\n",
    "\n",
    "    def full_objective_batch(x_batch: torch.Tensor) -> torch.Tensor:\n",
    "        \"\"\"\n",
    "        Full fused lasso objective F for a batch of signals.\n",
    "        Args:\n",
    "            x_batch: a whole number of length-n signals, e.g. (n,), (n, 1)\n",
    "                or (batch_size, n)\n",
    "        Returns:\n",
    "            objectives: shape (batch_size,)\n",
    "        \"\"\"\n",
    "        xb = x_batch.reshape(-1, n)  # (batch_size, n)\n",
    "\n",
    "        # Data fidelity term: 0.5 * ||x - b||^2\n",
    "        data_fid = 0.5 * torch.sum((xb - b_row) ** 2, dim=1)\n",
    "\n",
    "        # Penalty term: λ * ||Dx||_1\n",
    "        penalty = lambd * torch.sum(torch.abs(xb @ D.t()), dim=1)\n",
    "\n",
    "        return data_fid + penalty\n",
    "\n",
    "    def single_objective(x_vec: torch.Tensor) -> float:\n",
    "        \"\"\"Objective F for a single signal vector, as a Python float.\"\"\"\n",
    "        return full_objective_batch(x_vec).item()\n",
    "\n",
    "    if verbose:\n",
    "        initial_obj = single_objective(x)\n",
    "        print(\"Starting Proximal Point Method with HJ-Prox for Fused Lasso...\")\n",
    "        print(f\"λ = {lambd}\")\n",
    "        print(f\"Fixed γ = {gamma}\")\n",
    "        print(f\"HJ-Prox: int_samples = {int_samples}, delta_i = {delta_scale:g} / (i+1)^(2+{eps:g})\")\n",
    "        print(f\"Initial objective: {initial_obj:.6f}\")\n",
    "        print(\"-\" * 70)\n",
    "\n",
    "    for i in range(1, max_iters + 1):\n",
    "        x_old = x.clone()\n",
    "\n",
    "        # Annealed smoothing parameter for HJ-Prox\n",
    "        delta = delta_scale / ((i + 1)**(2.0 + eps))\n",
    "\n",
    "        # Proximal Point step: x^(k+1) ≈ prox_{γF}(x^k)\n",
    "        x_new, ls_iters, _ = compute_prox_HJ(\n",
    "            x, gamma, f=full_objective_batch,\n",
    "            delta=delta,\n",
    "            int_samples=int_samples,\n",
    "            alpha=1.0\n",
    "        )\n",
    "        x = x_new.clone()\n",
    "\n",
    "        current_obj = single_objective(x)\n",
    "        f_hist.append(current_obj)\n",
    "\n",
    "        diff = torch.linalg.norm(x - x_old).item()\n",
    "        diff_hist.append(diff)\n",
    "\n",
    "        if verbose and (i == 1 or i % verbose_every == 0):\n",
    "            print(f\"PPM(HJ) iter {i:4d}: f={current_obj:.6f}, \"\n",
    "                  f\"||Δx||={diff:.3e}, \"\n",
    "                  f\"δ={delta:.3e}, ls_iters={ls_iters}\")\n",
    "\n",
    "        if diff < tol:\n",
    "            if verbose:\n",
    "                print(f\"\\nPPM(HJ) converged at iteration {i}\")\n",
    "                print(f\"Final objective: {current_obj:.6f}\")\n",
    "                print(f\"Final ||Δx||: {diff:.3e}\")\n",
    "            break\n",
    "    else:\n",
    "        # for-else: runs only when the loop ended without reaching tol\n",
    "        if verbose:\n",
    "            print(f\"\\nPPM(HJ) reached maximum iterations ({max_iters})\")\n",
    "            print(f\"Final objective: {current_obj:.6f}\")\n",
    "\n",
    "    x_final = x.clone()\n",
    "    return x_final, torch.tensor(f_hist), torch.tensor(diff_hist)\n",
    "\n",
    "\n",
    "# Reaching this line means every definition above executed without error.\n",
    "_setup_message = \"✓ All algorithms and helper functions loaded successfully\"\n",
    "print(_setup_message)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# ============================================================================\n",
    "# CHUNK 2: DATA GENERATION\n",
    "# ============================================================================\n",
    "\n",
    "def doppler(t: np.ndarray) -> np.ndarray:\n",
    "    \"\"\"\n",
    "    Doppler test signal of Donoho & Johnstone, intended for t in (0, 1):\n",
    "        f(t) = sqrt(t (1 - t)) * sin(2.1 π / (t + 0.05))\n",
    "    \"\"\"\n",
    "    envelope = np.sqrt(t * (1 - t))\n",
    "    phase = (2.1 * np.pi) / (t + 0.05)\n",
    "    return envelope * np.sin(phase)\n",
    "\n",
    "\n",
    "print(\"\\n\" + \"=\" * 70)\n",
    "print(\"Generating Doppler signal data...\")\n",
    "print(\"=\" * 70)\n",
    "\n",
    "# Ground truth: Doppler signal sampled on a uniform grid inside (0, 1)\n",
    "n = 256\n",
    "t_vals = np.linspace(0.01, 0.99, n)\n",
    "y_true = doppler(t_vals)\n",
    "\n",
    "# Observations: truth plus i.i.d. Gaussian noise from the global NumPy RNG (seeded in setup)\n",
    "noise_level = 0.1\n",
    "y_noisy = y_true + noise_level * np.random.randn(n)\n",
    "\n",
    "# Observation column vector b of shape (n, 1), float64 on `device`\n",
    "b = torch.from_numpy(y_noisy).to(dtype=torch.float64, device=device).reshape(n, 1)\n",
    "\n",
    "# k-th order difference operator for the fused LASSO penalty\n",
    "k_order = 3\n",
    "D = kth_order_diff_matrix(n, k_order)\n",
    "\n",
    "print(f\"✓ Data generated: n={n} time points\")\n",
    "print(f\"✓ Noise level: {noise_level}\")\n",
    "print(f\"✓ Difference matrix: {k_order}-th order, shape {D.shape}\")\n",
    "\n",
    "# Zero starting point (used by the HJ runs) and problem parameters\n",
    "x0 = torch.zeros(n, 1, dtype=torch.float64, device=device)\n",
    "lambd = 0.9  # fused LASSO weight λ\n",
    "gamma = 1.0  # default step size; the runs below pass their own gamma values\n",
    "\n",
    "print(f\"✓ Regularization parameter λ = {lambd}\")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# ============================================================================\n",
    "# CHUNK 3: RUN ALGORITHM 2 - Douglas-Rachford with HJ-Prox\n",
    "# ============================================================================\n",
    "# NOTE: this run starts from x0 = 0 with gamma = 1.5e-5, whereas the analytical\n",
    "# DRS run (chunk 4) starts from b with gamma = 1e-4. tol = 1e-16 is most likely\n",
    "# unreachable in float64, so the run is expected to use all max_iters iterations;\n",
    "# hence the summary reports iterations run rather than claiming convergence.\n",
    "\n",
    "print(\"\\n\" + \"=\"*70)\n",
    "print(\"Running Algorithm 2: Douglas-Rachford with HJ-Prox...\")\n",
    "print(\"=\"*70)\n",
    "\n",
    "start_time = time.time()\n",
    "\n",
    "x_HJ, f_hist_HJ = douglas_rachford_fused_HJ_Prox(\n",
    "    x0=x0,\n",
    "    b=b,\n",
    "    D=D,\n",
    "    lambd=lambd,\n",
    "    gamma=0.000015,\n",
    "    relax=1,\n",
    "    max_iters=150000,\n",
    "    int_samples_penalty=1000,\n",
    "    tol=1e-16,\n",
    "    verbose=True\n",
    ")\n",
    "\n",
    "elapsed_time = time.time() - start_time\n",
    "\n",
    "print(f\"\\n✓ DRS-HJ completed\")\n",
    "print(f\"  - Runtime: {elapsed_time:.2f} seconds\")\n",
    "print(f\"  - Iterations run: {len(f_hist_HJ)}\")\n",
    "print(f\"  - Final objective: {f_hist_HJ[-1]:.6f}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# ============================================================================\n",
    "# CHUNK 4: RUN ALGORITHM 1 - Analytical Douglas-Rachford\n",
    "# ============================================================================\n",
    "# NOTE: starts from x0 = b (not the zero vector used by the HJ runs) with\n",
    "# gamma = 1e-4. tol = 1e-16 is most likely unreachable in float64, so the summary\n",
    "# reports iterations run and the final fixed-point residual.\n",
    "\n",
    "print(\"\\n\" + \"=\"*70)\n",
    "print(\"Running Algorithm 1: Douglas-Rachford with Analytical Proximal Operators...\")\n",
    "print(\"=\"*70)\n",
    "\n",
    "start_time = time.time()\n",
    "\n",
    "x_Analytical, f_hist_Analytical, diff_hist_Analytical = douglas_rachford_fused(\n",
    "    x0=b.clone(),\n",
    "    b=b,\n",
    "    D=D,\n",
    "    lambd=lambd,\n",
    "    gamma=0.0001,\n",
    "    relax=1,\n",
    "    max_iters=150000,\n",
    "    tol=1e-16,\n",
    "    verbose=True,\n",
    "    verbose_every=100\n",
    ")\n",
    "\n",
    "elapsed_time = time.time() - start_time\n",
    "\n",
    "print(f\"\\n✓ DRS-Analytical completed\")\n",
    "print(f\"  - Runtime: {elapsed_time:.2f} seconds\")\n",
    "print(f\"  - Iterations run: {len(f_hist_Analytical)}\")\n",
    "print(f\"  - Final ||Δz||: {diff_hist_Analytical[-1]:.3e}\")\n",
    "print(f\"  - Final objective: {f_hist_Analytical[-1]:.6f}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# ============================================================================\n",
    "# CHUNK 5: GENERATE FIGURES - DRS-HJ vs Analytical DRS\n",
    "# ============================================================================\n",
    "# NOTE(review): the x values are t_vals in (0, 1) although the axis is labelled\n",
    "# 'Signal Index'; the label is kept as-is to match the submitted figures.\n",
    "\n",
    "print(\"\\n\" + \"=\"*70)\n",
    "print(\"Generating figures: DRS-HJ vs Analytical DRS comparison...\")\n",
    "print(\"=\"*70)\n",
    "\n",
    "# --- Figure 1: Noisy Observations and Ground Truth ---\n",
    "fig, ax = plt.subplots(figsize=(11, 10))\n",
    "ax.scatter(t_vals, y_noisy, s=30, alpha=0.5, label='Noisy Observations', zorder=1)\n",
    "ax.plot(t_vals, y_true, 'k-', linewidth=3, label='True Signal', zorder=2)\n",
    "ax.set_title('Fused LASSO Ground Truth', fontsize=40)\n",
    "ax.set_xlabel('Signal Index', fontsize=40)\n",
    "ax.set_ylabel('Signal Value', fontsize=40)\n",
    "ax.legend(fontsize=40)\n",
    "ax.grid(True, alpha=0.3)\n",
    "ax.set_xlim(0, 1)\n",
    "ax.tick_params(axis='both', which='major', labelsize=40)\n",
    "fig.tight_layout()\n",
    "fig.savefig('fused_lasso_ground_truth.pdf', format='pdf', dpi=600, bbox_inches='tight')\n",
    "plt.show()\n",
    "plt.close(fig)\n",
    "\n",
    "# --- Figures 2 & 3: each reconstruction overlaid on the data and the truth ---\n",
    "# Columns: estimate, line style, curve label, zorder, title, output file.\n",
    "# NOTE(review): the analytical-DRS panel is written to 'fused_lasso_coordinate_descent.pdf',\n",
    "# and the curve labels are set although these panels draw no legend.\n",
    "reconstruction_panels = [\n",
    "    (x_HJ, 'b-', 'DRS-HJ', 4, 'DRS-HJ', 'fused_lasso_drs_hj.pdf'),\n",
    "    (x_Analytical, 'r--', 'DRS', 3, 'DRS', 'fused_lasso_coordinate_descent.pdf'),\n",
    "]\n",
    "for estimate, line_style, curve_label, curve_z, panel_title, out_file in reconstruction_panels:\n",
    "    fig, ax = plt.subplots(figsize=(11, 10))\n",
    "    ax.scatter(t_vals, y_noisy, s=30, alpha=0.5, zorder=1)\n",
    "    ax.plot(t_vals, y_true, 'k-', linewidth=3, zorder=2)\n",
    "    ax.plot(t_vals, estimate.squeeze().cpu().numpy(), line_style, linewidth=3,\n",
    "            label=curve_label, alpha=0.8, zorder=curve_z)\n",
    "    ax.set_title(panel_title, fontsize=40)\n",
    "    ax.set_xlabel('Signal Index', fontsize=40)\n",
    "    ax.set_ylabel('Signal Value', fontsize=40)\n",
    "    ax.grid(True, alpha=0.3)\n",
    "    ax.set_xlim(0, 1)\n",
    "    ax.tick_params(axis='both', which='major', labelsize=40)\n",
    "    fig.tight_layout()\n",
    "    fig.savefig(out_file, format='pdf', dpi=600, bbox_inches='tight')\n",
    "    plt.show()\n",
    "    plt.close(fig)\n",
    "\n",
    "# --- Figure 4: Objective Function Convergence (DRS-HJ vs Analytical) ---\n",
    "fig, ax = plt.subplots(figsize=(11, 10))\n",
    "ax.semilogy(f_hist_HJ.numpy(), '-', linewidth=3,\n",
    "            label=f'DRS-HJ: {f_hist_HJ[-1].item():.3f}')\n",
    "ax.semilogy(f_hist_Analytical.numpy(), '--', linewidth=3,\n",
    "            label=f'DRS: {f_hist_Analytical[-1].item():.3f}')\n",
    "ax.set_title('DRS Convergence', fontsize=40)\n",
    "ax.set_xlabel('Iteration', fontsize=40)\n",
    "ax.set_ylabel('Objective (log scale)', fontsize=40)\n",
    "ax.legend(fontsize=40, loc='upper left')\n",
    "ax.grid(True, alpha=0.3, which='both')\n",
    "ax.tick_params(axis='both', which='major', labelsize=40)\n",
    "fig.tight_layout()\n",
    "fig.savefig('fused_lasso_convergence.pdf', format='pdf', dpi=600, bbox_inches='tight')\n",
    "plt.show()\n",
    "plt.close(fig)\n",
    "\n",
    "print(\"✓ All figures saved: fused_lasso_ground_truth.pdf, fused_lasso_drs_hj.pdf, fused_lasso_coordinate_descent.pdf, fused_lasso_convergence.pdf\")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# ============================================================================\n",
    "# CHUNK 6: RUN ALGORITHM 3 - Proximal Point Method with HJ-Prox\n",
    "# ============================================================================\n",
    "# NOTE: `delta_init` below does not influence the iteration; the HJ-Prox smoothing\n",
    "# schedule delta_i ∝ 1 / (i+1)^(2 + eps) is set inside proximal_point_fused_HJ_Prox.\n",
    "# tol = 1e-18 is most likely unreachable in float64, so the summary reports\n",
    "# iterations run and the final step size instead of claiming convergence.\n",
    "\n",
    "print(\"\\n\" + \"=\"*70)\n",
    "print(\"Running Algorithm 3: Proximal Point Method with HJ-Prox...\")\n",
    "print(\"=\"*70)\n",
    "\n",
    "start_time = time.time()\n",
    "\n",
    "x_PPM, f_PPM, diff_PPM = proximal_point_fused_HJ_Prox(\n",
    "    x0, b, D, lambd,\n",
    "    gamma=0.000005,\n",
    "    max_iters=150000,\n",
    "    int_samples=1000,\n",
    "    delta_init=1.0,\n",
    "    tol=1e-18,\n",
    "    verbose=True,\n",
    "    verbose_every=100,\n",
    "    device=device\n",
    ")\n",
    "\n",
    "elapsed_time = time.time() - start_time\n",
    "\n",
    "print(f\"\\n✓ PPM-HJ completed\")\n",
    "print(f\"  - Runtime: {elapsed_time:.2f} seconds\")\n",
    "print(f\"  - Iterations run: {len(f_PPM)}\")\n",
    "print(f\"  - Final ||Δx||: {diff_PPM[-1]:.3e}\")\n",
    "print(f\"  - Final objective: {f_PPM[-1]:.6f}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# ============================================================================\n",
    "# CHUNK 7: GENERATE FIGURES - DRS-HJ vs PPM-HJ\n",
    "# ============================================================================\n",
    "# NOTE(review): Figures 5 and 6 re-create (and overwrite with identical content)\n",
    "# the ground-truth and DRS-HJ PDFs from chunk 5. The x values are t_vals in (0, 1)\n",
    "# although the axis is labelled 'Signal Index'.\n",
    "\n",
    "print(\"\\n\" + \"=\"*70)\n",
    "print(\"Generating figures: DRS-HJ vs PPM-HJ comparison...\")\n",
    "print(\"=\"*70)\n",
    "\n",
    "# --- Figure 5: Ground Truth (repeated for this comparison set) ---\n",
    "fig, ax = plt.subplots(figsize=(11, 10))\n",
    "ax.scatter(t_vals, y_noisy, s=30, alpha=0.5, label='Noisy Observations', zorder=1)\n",
    "ax.plot(t_vals, y_true, 'k-', linewidth=3, label='True Signal', zorder=2)\n",
    "ax.set_title('Fused LASSO Ground Truth', fontsize=40)\n",
    "ax.set_xlabel('Signal Index', fontsize=40)\n",
    "ax.set_ylabel('Signal Value', fontsize=40)\n",
    "ax.legend(fontsize=40)\n",
    "ax.grid(True, alpha=0.3)\n",
    "ax.set_xlim(0, 1)\n",
    "ax.tick_params(axis='both', which='major', labelsize=40)\n",
    "fig.tight_layout()\n",
    "fig.savefig('fused_lasso_ground_truth.pdf', format='pdf', dpi=600, bbox_inches='tight')\n",
    "plt.show()\n",
    "plt.close(fig)\n",
    "\n",
    "# --- Figures 6 & 7: DRS-HJ (repeated) and PPM-HJ reconstructions ---\n",
    "# Columns: estimate, line style, curve label, zorder, title, output file.\n",
    "reconstruction_panels = [\n",
    "    (x_HJ, 'b-', 'DRS-HJ', 4, 'DRS-HJ', 'fused_lasso_drs_hj.pdf'),\n",
    "    (x_PPM, 'r--', 'PPM-HJ', 3, 'PPM-HJ', 'fused_lasso_PPM.pdf'),\n",
    "]\n",
    "for estimate, line_style, curve_label, curve_z, panel_title, out_file in reconstruction_panels:\n",
    "    fig, ax = plt.subplots(figsize=(11, 10))\n",
    "    ax.scatter(t_vals, y_noisy, s=30, alpha=0.5, zorder=1)\n",
    "    ax.plot(t_vals, y_true, 'k-', linewidth=3, zorder=2)\n",
    "    ax.plot(t_vals, estimate.squeeze().cpu().numpy(), line_style, linewidth=3,\n",
    "            label=curve_label, alpha=0.8, zorder=curve_z)\n",
    "    ax.set_title(panel_title, fontsize=40)\n",
    "    ax.set_xlabel('Signal Index', fontsize=40)\n",
    "    ax.set_ylabel('Signal Value', fontsize=40)\n",
    "    ax.grid(True, alpha=0.3)\n",
    "    ax.set_xlim(0, 1)\n",
    "    ax.tick_params(axis='both', which='major', labelsize=40)\n",
    "    fig.tight_layout()\n",
    "    fig.savefig(out_file, format='pdf', dpi=600, bbox_inches='tight')\n",
    "    plt.show()\n",
    "    plt.close(fig)\n",
    "\n",
    "# --- Figure 8: Objective Function Convergence (DRS-HJ vs PPM-HJ) ---\n",
    "fig, ax = plt.subplots(figsize=(11, 10))\n",
    "ax.semilogy(f_hist_HJ.numpy(), '-', linewidth=3,\n",
    "            label=f'DRS-HJ: {f_hist_HJ[-1].item():.3f}')\n",
    "ax.semilogy(f_PPM.numpy(), '--', linewidth=3,\n",
    "            label=f'PPM-HJ: {f_PPM[-1].item():.3f}')\n",
    "ax.set_title('Fused LASSO', fontsize=40)\n",
    "ax.set_xlabel('Iteration', fontsize=40)\n",
    "ax.set_ylabel('Objective (log scale)', fontsize=40)\n",
    "ax.legend(fontsize=40, loc='upper left')\n",
    "ax.grid(True, alpha=0.3, which='both')\n",
    "ax.tick_params(axis='both', which='major', labelsize=40)\n",
    "fig.tight_layout()\n",
    "fig.savefig('fused_lasso_convergence_VSPPM.pdf', format='pdf', dpi=600, bbox_inches='tight')\n",
    "plt.show()\n",
    "plt.close(fig)\n",
    "\n",
    "print(\"✓ All figures saved: fused_lasso_PPM.pdf, fused_lasso_convergence_VSPPM.pdf\")\n",
    "print(\"\\n\" + \"=\"*70)\n",
    "print(\"✓ ALL EXPERIMENTS COMPLETED SUCCESSFULLY\")\n",
    "print(\"=\"*70)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "base",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.10"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
