{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# ============================================================================\n",
    "# CHUNK 1: SETUP - Algorithms, Helper Functions, and Definitions\n",
    "# ============================================================================\n",
    "\n",
    "import time\n",
    "import torch\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "from matplotlib.ticker import MaxNLocator\n",
    "\n",
    "# Use double precision globally: tensors created without an explicit dtype\n",
    "# (e.g. torch.zeros(n) or torch.tensor(1.0)) are float64 from here on.\n",
    "torch.set_default_dtype(torch.float64)\n",
    "# Default device name. NOTE(review): the functions defined in this cell take their\n",
    "# device from arguments or from their input tensors, so this global looks unused - confirm.\n",
    "device = 'cpu'\n",
    "# Added to the exponent of the delta annealing schedule in\n",
    "# douglas_rachford_efficient. In floating point 2 + 1e-30 == 2.0, so it has\n",
    "# no numerical effect on that schedule.\n",
    "eps = 1e-30\n",
    "\n",
    "# Plotting configuration: default font size for matplotlib figures\n",
    "plt.rcParams.update({'font.size': 16})\n",
    "\n",
    "\n",
    "# ============================================================================\n",
    "# Helper Functions\n",
    "# ============================================================================\n",
    "\n",
    "def compute_penalties(B: torch.Tensor, lambda1: float, lambda2: float, lambda3: float) -> torch.Tensor:\n",
    "    \"\"\"\n",
    "    Evaluate the composite regulariser of a single coefficient matrix.\n",
    "\n",
    "    Args:\n",
    "        B: 2-D matrix to penalise.\n",
    "        lambda1: weight of the nuclear norm ||B||_*.\n",
    "        lambda2: weight of the sum of row-wise ℓ2 norms.\n",
    "        lambda3: weight of the sum of column-wise ℓ2 norms.\n",
    "\n",
    "    Returns:\n",
    "        Scalar tensor λ1||B||_* + λ2∑||B_i,:||_2 + λ3∑||B_:,j||_2\n",
    "    \"\"\"\n",
    "    squared = B**2\n",
    "\n",
    "    # Nuclear norm: the singular values added up\n",
    "    nuclear_term = lambda1 * torch.linalg.svdvals(B).sum()\n",
    "\n",
    "    # Group terms: ℓ2 norm of every row (reduce over dim 1), then of every column (dim 0)\n",
    "    row_term = lambda2 * squared.sum(dim=1).sqrt().sum()\n",
    "    col_term = lambda3 * squared.sum(dim=0).sqrt().sum()\n",
    "\n",
    "    return nuclear_term + row_term + col_term\n",
    "\n",
    "\n",
    "def nuclear_penalty_batch_efficient(B_batch: torch.Tensor, p: int, q: int,\n",
    "                                  lambda1: float) -> torch.Tensor:\n",
    "    \"\"\"\n",
    "    Batched nuclear-norm penalty λ1||B||_* (used on its own inside Douglas-Rachford).\n",
    "\n",
    "    Args:\n",
    "        B_batch: tensor whose elements can be regrouped into (batch, p, q),\n",
    "            e.g. flattened samples of shape (batch, p*q) or a single (p, q) matrix.\n",
    "        p: number of rows of each matrix.\n",
    "        q: number of columns of each matrix.\n",
    "        lambda1: weight of the nuclear norm.\n",
    "\n",
    "    Returns:\n",
    "        1-D tensor with one penalty value per (p, q) matrix.\n",
    "    \"\"\"\n",
    "    matrices = B_batch.reshape(-1, p, q)\n",
    "\n",
    "    if lambda1 == 0:\n",
    "        # Skip the SVD entirely. The zeros are sized from the reshaped batch and use\n",
    "        # the input dtype, so both branches return the same shape and dtype.\n",
    "        return torch.zeros(matrices.shape[0], dtype=B_batch.dtype, device=B_batch.device)\n",
    "\n",
    "    # svdvals is batched over the leading dimension: result is (batch, min(p, q))\n",
    "    singular_values = torch.linalg.svdvals(matrices)\n",
    "    return lambda1 * singular_values.sum(dim=1)\n",
    "\n",
    "\n",
    "def row_col_penalty_batch_efficient(B_batch: torch.Tensor, p: int, q: int,\n",
    "                                  lambda2: float, lambda3: float) -> torch.Tensor:\n",
    "    \"\"\"\n",
    "    Batched group penalty λ2∑||B_i,:||_2 + λ3∑||B_:,j||_2.\n",
    "\n",
    "    Args:\n",
    "        B_batch: tensor whose elements can be regrouped into (batch, p, q),\n",
    "            e.g. flattened samples of shape (batch, p*q) or a single (p, q) matrix.\n",
    "        p: number of rows of each matrix.\n",
    "        q: number of columns of each matrix.\n",
    "        lambda2: weight of the row norms; the term is skipped unless lambda2 > 0.\n",
    "        lambda3: weight of the column norms; the term is skipped unless lambda3 > 0.\n",
    "\n",
    "    Returns:\n",
    "        1-D tensor with one penalty value per (p, q) matrix.\n",
    "    \"\"\"\n",
    "    matrices = B_batch.reshape(-1, p, q)\n",
    "\n",
    "    # Size the accumulator from the reshaped batch (not B_batch.shape[0]) and keep the\n",
    "    # input dtype, so a single (p, q) input yields one value and nothing is down-cast.\n",
    "    penalties = torch.zeros(matrices.shape[0], dtype=B_batch.dtype, device=B_batch.device)\n",
    "\n",
    "    if lambda2 > 0:\n",
    "        # Reduce over columns (dim=2) -> one ℓ2 norm per row, shape (batch, p)\n",
    "        row_norms = torch.linalg.vector_norm(matrices, ord=2, dim=2)\n",
    "        penalties += lambda2 * row_norms.sum(dim=1)\n",
    "\n",
    "    if lambda3 > 0:\n",
    "        # Reduce over rows (dim=1) -> one ℓ2 norm per column, shape (batch, q)\n",
    "        col_norms = torch.linalg.vector_norm(matrices, ord=2, dim=1)\n",
    "        penalties += lambda3 * col_norms.sum(dim=1)\n",
    "\n",
    "    return penalties\n",
    "\n",
    "\n",
    "def total_objective(B: torch.Tensor, X: torch.Tensor, Y: torch.Tensor,\n",
    "                   lambda1: float, lambda2: float, lambda3: float) -> torch.Tensor:\n",
    "    \"\"\"\n",
    "    Full objective value at B: least-squares misfit plus the three penalties.\n",
    "\n",
    "    Returns 0.5||Y - XB||_F^2 + compute_penalties(B, lambda1, lambda2, lambda3).\n",
    "    \"\"\"\n",
    "    misfit = Y - X @ B\n",
    "    data_term = 0.5 * (misfit**2).sum()\n",
    "    return data_term + compute_penalties(B, lambda1, lambda2, lambda3)\n",
    "\n",
    "\n",
    "def total_objective_efficient(B: torch.Tensor, X: torch.Tensor, Y: torch.Tensor,\n",
    "                            lambda1: float, lambda2: float, lambda3: float) -> torch.Tensor:\n",
    "    \"\"\"\n",
    "    Objective 0.5||Y - XB||_F^2 + λ1||B||_* + λ2∑||B_i,:||_2 + λ3∑||B_:,j||_2.\n",
    "\n",
    "    A penalty term is only evaluated when its weight is strictly positive;\n",
    "    otherwise it contributes 0.\n",
    "    \"\"\"\n",
    "    # Least-squares data fidelity\n",
    "    misfit = Y - X @ B\n",
    "    data_term = 0.5 * (misfit**2).sum()\n",
    "\n",
    "    # Each penalty is computed lazily, so a zero weight skips the SVD / norm reductions\n",
    "    nuclear_term = lambda1 * torch.linalg.svdvals(B).sum() if lambda1 > 0 else 0\n",
    "    row_term = lambda2 * torch.norm(B, p=2, dim=1).sum() if lambda2 > 0 else 0\n",
    "    col_term = lambda3 * torch.norm(B, p=2, dim=0).sum() if lambda3 > 0 else 0\n",
    "\n",
    "    return data_term + nuclear_term + row_term + col_term\n",
    "\n",
    "\n",
    "def compute_prox(x, t, f, delta=1e-1, int_samples=100, alpha=1.0, linesearch_iters=0, device='cpu'):\n",
    "    r\"\"\" Estimate proximals from function value sampling via HJ-Prox Algorithm.\n",
    "\n",
    "        The output estimates the proximal:\n",
    "        \n",
    "        $$\n",
    "            \\mathsf{prox_{tf}(x) = argmin_y \\ f(y) + \\dfrac{1}{2t} \\| y - x \\|^2,}\n",
    "        $$\n",
    "            \n",
    "        where $\\mathsf{x}$ = `x` is the input, $\\mathsf{t}$=`t` is the time parameter, \n",
    "        and $\\mathsf{f}$=`f` is the function of interest.\n",
    "\n",
    "        Gaussian samples are drawn around `x`, weighted by a softmax of\n",
    "        `-f(y) * alpha / delta`, and averaged. If the softmax weights are not\n",
    "        finite, `alpha` is halved (which widens the sampling distribution) and\n",
    "        the estimate is retried with fresh samples.\n",
    "\n",
    "        Args:\n",
    "            x (tensor): Input vector\n",
    "            t (tensor): Time > 0\n",
    "            f (Callable): Function to minimize; called with a `(int_samples, n)` batch\n",
    "                and with a `(1, n)` batch\n",
    "            delta (float, optional): Smoothing parameter\n",
    "            int_samples (int, optional): Number of samples in Monte Carlo sampling for integral\n",
    "            alpha (float, optional): Scaling parameter for sampling variance\n",
    "            linesearch_iters (int, optional): Number of attempts already made (used for numerical stability)\n",
    "            device (string, optional): Device on which to store variables\n",
    "\n",
    "        Shape:\n",
    "            - Input `x` is of size `(n, 1)` where `n` is the dimension of the space of interest\n",
    "            - The output `prox_term` also has size `(n, 1)`\n",
    "\n",
    "        Returns:\n",
    "            prox_term (tensor): Estimate of the proximal of f at x\n",
    "            linesearch_iters (int): Number of attempts used (input value plus attempts made here)\n",
    "            envelope (tensor): Value of envelope function (i.e. infimal convolution) at proximal\n",
    "\n",
    "        Raises:\n",
    "            RuntimeError: if the softmax weights are still not finite after the retry budget\n",
    "            AssertionError: if `x` is not `(n, 1)` or the proximal estimate is not finite\n",
    "    \"\"\"\n",
    "    assert x.shape[1] == 1\n",
    "    assert x.shape[0] >= 1\n",
    "    dim = x.shape[0]\n",
    "    x_row = x.permute(1, 0)  # (1, n): broadcasts against the (int_samples, n) noise\n",
    "\n",
    "    # Bounded retry loop (previously unbounded recursion): a function that keeps\n",
    "    # returning NaN/inf now fails with a clear error instead of a RecursionError.\n",
    "    max_attempts = 100\n",
    "    for _ in range(max_attempts):\n",
    "        linesearch_iters += 1\n",
    "        standard_dev = torch.sqrt(torch.as_tensor(delta * t / alpha, device=device))\n",
    "\n",
    "        # Noise is drawn in float32 as before; adding x_row promotes y to the wider dtype.\n",
    "        y = standard_dev * torch.randn(int_samples, dim, device=device, dtype=torch.float32) + x_row\n",
    "        z = -f(y) * (alpha / delta)\n",
    "        w = torch.softmax(z, dim=0)\n",
    "\n",
    "        # Softmax weights lie in [0, 1] unless z contained NaN/inf, in which case they are NaN.\n",
    "        if torch.isfinite(w).all():\n",
    "            break\n",
    "        alpha *= 0.5\n",
    "    else:\n",
    "        raise RuntimeError(\n",
    "            f\"compute_prox: softmax weights not finite after {max_attempts} attempts\"\n",
    "        )\n",
    "\n",
    "    # Weighted average of the samples, returned as a column vector\n",
    "    prox_term = torch.matmul(w.t(), y)\n",
    "    prox_term = prox_term.view(-1, 1)\n",
    "\n",
    "    assert torch.isfinite(prox_term).all(), \"Prox Overflowed\"\n",
    "\n",
    "    # prox_term and x are both (n, 1). Subtracting the (1, n) row view here would\n",
    "    # broadcast to an (n, n) matrix and inflate the quadratic term.\n",
    "    envelope = f(prox_term.view(1, -1)) + (1 / (2 * t)) * torch.norm(prox_term - x, p=2)**2\n",
    "    return prox_term, linesearch_iters, envelope\n",
    "\n",
    "\n",
    "# ============================================================================\n",
    "# Algorithm 1: Douglas-Rachford with Analytical Proximal Operators\n",
    "# ============================================================================\n",
    "\n",
    "def douglas_rachford_analytical(\n",
    "    B0: torch.Tensor,\n",
    "    X: torch.Tensor,\n",
    "    Y: torch.Tensor,\n",
    "    lambda1: float,\n",
    "    lambda2: float,\n",
    "    lambda3: float,\n",
    "    gamma: float = 1.0,\n",
    "    max_iters: int = 1000,\n",
    "    prox_iters: int = 100,\n",
    "    tol: float = 1e-10,\n",
    "    verbose: bool = True,\n",
    ") -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:\n",
    "    \"\"\"\n",
    "    Douglas-Rachford algorithm with analytical proximal operators.\n",
    "    \n",
    "    Splitting:\n",
    "    - f(B) = 0.5||Y - XB||_F^2 + λ1||B||_* (data fidelity + nuclear norm)\n",
    "    - g(B) = λ2∑||B_i,:||_2 + λ3∑||B_:,j||_2 (row and column penalties)\n",
    "    \n",
    "    Both proximal operators are computed iteratively: prox of f with FISTA,\n",
    "    prox of g with a Dykstra-like alternation of row and column group shrinkage.\n",
    "\n",
    "    Args:\n",
    "        B0: initial iterate (a 2-D coefficient matrix).\n",
    "        X: design matrix.\n",
    "        Y: response matrix.\n",
    "        lambda1: nuclear-norm weight.\n",
    "        lambda2: row-norm weight.\n",
    "        lambda3: column-norm weight.\n",
    "        gamma: Douglas-Rachford step size.\n",
    "        max_iters: maximum number of outer iterations.\n",
    "        prox_iters: maximum inner iterations for each proximal sub-solver.\n",
    "        tol: stop when ||x^{k+1} - x^k|| falls below this value.\n",
    "        verbose: print progress every 10 iterations.\n",
    "\n",
    "    Returns:\n",
    "        (yk, f_hist, diff_hist): the last prox_g output (B0 itself if no iteration\n",
    "        ran), the objective at yk per iteration, and ||x^{k+1} - x^k|| per iteration.\n",
    "        The histories are truncated to the iterations run when tol is reached.\n",
    "    \"\"\"\n",
    "    import math\n",
    "    \n",
    "    xk = B0.clone()\n",
    "    yk = xk.clone()  # keeps the return value defined when max_iters == 0\n",
    "    f_hist = torch.zeros(max_iters)\n",
    "    diff_hist = torch.zeros(max_iters)\n",
    "    \n",
    "    # Precompute everything that depends only on the data\n",
    "    XtX = X.T @ X\n",
    "    XtY = X.T @ Y\n",
    "    # Largest eigenvalue of X'X: loop-invariant, so it is computed once here\n",
    "    # instead of on every prox_f call.\n",
    "    XtX_max_eig = torch.linalg.eigvalsh(XtX).max()\n",
    "    \n",
    "    def _group_shrink(M: torch.Tensor, threshold: float, dim: int) -> torch.Tensor:\n",
    "        \"\"\"Group soft-thresholding: shrink each slice along `dim` by `threshold`, zeroing small ones.\"\"\"\n",
    "        norms = torch.linalg.vector_norm(M, dim=dim, keepdim=True)\n",
    "        # Slices with norm <= threshold are selected as exact zeros, so the\n",
    "        # division by a zero norm in the other branch is never used.\n",
    "        return torch.where(norms > threshold, M * (1 - threshold / norms), torch.zeros_like(M))\n",
    "    \n",
    "    def prox_f(V: torch.Tensor, gamma: float, max_iters: int = 50) -> torch.Tensor:\n",
    "        \"\"\"\n",
    "        Proximal operator of f(B) = 0.5||Y - XB||_F^2 + λ1||B||_*\n",
    "        \n",
    "        This requires solving:\n",
    "        min_B { 0.5||Y - XB||_F^2 + λ1||B||_* + (1/2γ)||B - V||_F^2 }\n",
    "        \n",
    "        Using FISTA (Fast Iterative Shrinkage-Thresholding Algorithm)\n",
    "        \"\"\"\n",
    "        # Lipschitz constant for the smooth part\n",
    "        L_smooth = XtX_max_eig + 1.0 / gamma\n",
    "        \n",
    "        # Initialize; B is the extrapolated point, B_prev/B_new the actual iterates\n",
    "        B = V.clone()\n",
    "        B_prev = B.clone()\n",
    "        B_new = B_prev  # returned as-is when max_iters == 0\n",
    "        t = 1.0\n",
    "        \n",
    "        for _ in range(max_iters):\n",
    "            # Gradient of smooth part: X'(XB - Y) + (1/γ)(B - V)\n",
    "            grad = XtX @ B - XtY + (1.0 / gamma) * (B - V)\n",
    "            \n",
    "            # Gradient step\n",
    "            B_temp = B - (1.0 / L_smooth) * grad\n",
    "            \n",
    "            # Nuclear norm prox (singular value shrinkage)\n",
    "            U, s, Vt = torch.linalg.svd(B_temp, full_matrices=False)\n",
    "            s_thresh = torch.clamp(s - lambda1 / L_smooth, min=0.0)\n",
    "            B_new = (U * s_thresh) @ Vt  # same as U @ diag(s_thresh) @ Vt\n",
    "            \n",
    "            # FISTA momentum\n",
    "            t_new = (1 + math.sqrt(1 + 4 * t**2)) / 2\n",
    "            B = B_new + ((t - 1) / t_new) * (B_new - B_prev)\n",
    "            \n",
    "            # Check convergence\n",
    "            if (B_new - B_prev).norm() < 1e-8:\n",
    "                break\n",
    "                \n",
    "            B_prev = B_new\n",
    "            t = t_new\n",
    "            \n",
    "        return B_new\n",
    "    \n",
    "    def prox_g_dykstra(V: torch.Tensor, gamma: float, max_iters: int = 50) -> torch.Tensor:\n",
    "        \"\"\"\n",
    "        Proximal operator using Dykstra's algorithm for the sum of row and column penalties.\n",
    "        \"\"\"\n",
    "        # Initialize\n",
    "        B = V.clone()\n",
    "        p_rows = torch.zeros_like(V)\n",
    "        p_cols = torch.zeros_like(V)\n",
    "        \n",
    "        for _ in range(max_iters):\n",
    "            B_old = B  # B is rebound below, never mutated in place\n",
    "            \n",
    "            # Prox for row penalties (norm taken over dim=1, one per row)\n",
    "            row_input = B + p_rows\n",
    "            B_rows = _group_shrink(row_input, gamma * lambda2, dim=1)\n",
    "            p_rows = row_input - B_rows\n",
    "            \n",
    "            # Prox for column penalties (norm taken over dim=0, one per column)\n",
    "            col_input = B_rows + p_cols\n",
    "            B = _group_shrink(col_input, gamma * lambda3, dim=0)\n",
    "            p_cols = col_input - B\n",
    "            \n",
    "            # Check convergence\n",
    "            if (B - B_old).norm() < 1e-8:\n",
    "                break\n",
    "                \n",
    "        return B\n",
    "    \n",
    "    # Main Douglas-Rachford loop\n",
    "    for i in range(max_iters):\n",
    "        t0 = time.time()\n",
    "        \n",
    "        # Step 1: y^k = prox_{γg}(x^k)\n",
    "        yk = prox_g_dykstra(xk, gamma, prox_iters)\n",
    "        \n",
    "        # Step 2: z^k = prox_{γf}(2y^k - x^k)\n",
    "        v = 2 * yk - xk\n",
    "        zk = prox_f(v, gamma, prox_iters)\n",
    "        \n",
    "        # Step 3: x^{k+1} = x^k + (z^k - y^k)\n",
    "        xk_new = xk + (zk - yk)\n",
    "        \n",
    "        # Compute metrics (objective is evaluated at y^k)\n",
    "        fk = total_objective(yk, X, Y, lambda1, lambda2, lambda3)\n",
    "        diff = (xk_new - xk).norm(p=2)\n",
    "        primal_res = (zk - yk).norm(p=2)\n",
    "        \n",
    "        f_hist[i] = fk.item()\n",
    "        diff_hist[i] = diff.item()\n",
    "        xk = xk_new\n",
    "        \n",
    "        if verbose and i % 10 == 0:\n",
    "            print(f\"DRS-Analytical iter {i+1:4d}: f={fk:.6f}, ||Δx||={diff:.6f}, \"\n",
    "                  f\"||z-y||={primal_res:.6f}, time={time.time() - t0:.4f}s\")\n",
    "        \n",
    "        if diff < tol:\n",
    "            f_hist = f_hist[:i+1]\n",
    "            diff_hist = diff_hist[:i+1]\n",
    "            break\n",
    "    \n",
    "    return yk, f_hist, diff_hist\n",
    "\n",
    "\n",
    "# ============================================================================\n",
    "# Algorithm 2: Douglas-Rachford with HJ-Prox\n",
    "# ============================================================================\n",
    "\n",
    "def douglas_rachford_efficient(\n",
    "    B0: torch.Tensor,\n",
    "    X: torch.Tensor,\n",
    "    Y: torch.Tensor,\n",
    "    lambda1: float,\n",
    "    lambda2: float,\n",
    "    lambda3: float,\n",
    "    gamma: float = 1.0,\n",
    "    max_iters: int = 1000,\n",
    "    int_samples: int = 1000,\n",
    "    use_adaptive_delta: bool = True,\n",
    "    tol: float = 1e-6,\n",
    "    verbose: bool = True,\n",
    ") -> tuple:\n",
    "    \"\"\"\n",
    "    Douglas-Rachford with efficient penalty computations using HJ-Prox\n",
    "\n",
    "    Uses the same splitting as the analytical solver, but both proximal\n",
    "    operators are estimated by compute_prox from sampled function values:\n",
    "    - f(B) = 0.5||Y - XB||_F^2 + λ1||B||_*\n",
    "    - g(B) = λ2∑||B_i,:||_2 + λ3∑||B_:,j||_2\n",
    "\n",
    "    Args:\n",
    "        B0: initial iterate, a (p, q) matrix.\n",
    "        X: design matrix.\n",
    "        Y: response matrix.\n",
    "        lambda1: nuclear-norm weight.\n",
    "        lambda2: row-norm weight.\n",
    "        lambda3: column-norm weight.\n",
    "        gamma: Douglas-Rachford step size, passed to compute_prox as the time parameter.\n",
    "        max_iters: maximum number of outer iterations.\n",
    "        int_samples: number of Monte Carlo samples per proximal estimate.\n",
    "        use_adaptive_delta: accepted for backward compatibility but not consulted;\n",
    "            the annealed delta schedule below is always used.\n",
    "        tol: stop when ||x^{k+1} - x^k||_F / (||x^k||_F + 1e-10) falls below this value.\n",
    "        verbose: print a header and progress every 10 iterations.\n",
    "\n",
    "    Returns:\n",
    "        (yk, f_hist, time_hist): the last prox_g estimate (B0 itself if no iteration\n",
    "        ran), the objective at yk per iteration, and the seconds spent per iteration.\n",
    "    \"\"\"\n",
    "    xk = B0.clone()\n",
    "    yk = xk.clone()  # keeps the return value defined when max_iters == 0\n",
    "    p, q = xk.shape\n",
    "    device = xk.device\n",
    "    \n",
    "    f_hist = []\n",
    "    time_hist = []\n",
    "    \n",
    "    def objective_f(b_batch):\n",
    "        \"\"\"f(B) = 0.5||Y - XB||_F^2 + λ1||B||_*\"\"\"\n",
    "        B_reshaped = b_batch.reshape(-1, p, q)\n",
    "        n_samples = B_reshaped.shape[0]\n",
    "        \n",
    "        # Vectorized data fidelity: X is expanded (no copy) to one matrix per sample\n",
    "        residuals = Y.unsqueeze(0) - torch.bmm(\n",
    "            X.unsqueeze(0).expand(n_samples, -1, -1),\n",
    "            B_reshaped\n",
    "        )\n",
    "        data_fid = 0.5 * (residuals**2).sum(dim=(1, 2))\n",
    "        \n",
    "        # Exact nuclear norm (efficient batch computation)\n",
    "        nuclear = nuclear_penalty_batch_efficient(b_batch, p, q, lambda1)\n",
    "        \n",
    "        return data_fid + nuclear\n",
    "    \n",
    "    def objective_g(b_batch):\n",
    "        \"\"\"g(B) = λ2∑||B_i,:||_2 + λ3∑||B_:,j||_2 (efficient)\"\"\"\n",
    "        return row_col_penalty_batch_efficient(b_batch, p, q, lambda2, lambda3)\n",
    "    \n",
    "    if verbose:\n",
    "        print(f\"Starting Douglas-Rachford with HJ-Prox...\")\n",
    "        print(f\"λ₁={lambda1}, λ₂={lambda2}, λ₃={lambda3}, γ={gamma}\")\n",
    "        print(\"-\" * 70)\n",
    "    \n",
    "    for i in range(max_iters):\n",
    "        t_start = time.time()\n",
    "        \n",
    "        # Compute delta with annealing schedule (eps is the module-level constant)\n",
    "        delta = 200000 / (i + 1)**(2 + eps)\n",
    "        \n",
    "        # Step 1: y^k = prox_{γg}(x^k)\n",
    "        # reshape (not view): clone() preserves strides, so a non-contiguous B0\n",
    "        # would make .view raise.\n",
    "        x_flat = xk.reshape(-1, 1)\n",
    "        y_flat, _, _ = compute_prox(\n",
    "            x_flat, gamma, objective_g,\n",
    "            delta=delta, int_samples=int_samples,\n",
    "            device=device\n",
    "        )\n",
    "        yk = y_flat.reshape(p, q)\n",
    "        \n",
    "        # Step 2: z^k = prox_{γf}(2y^k - x^k)\n",
    "        v = 2 * yk - xk\n",
    "        v_flat = v.reshape(-1, 1)\n",
    "        z_flat, _, _ = compute_prox(\n",
    "            v_flat, gamma, objective_f,\n",
    "            delta=delta, int_samples=int_samples,\n",
    "            device=device\n",
    "        )\n",
    "        zk = z_flat.reshape(p, q)\n",
    "        \n",
    "        # Step 3: x^{k+1} = x^k + (z^k - y^k)\n",
    "        xk_new = xk + (zk - yk)\n",
    "        \n",
    "        # Compute metrics efficiently (objective is evaluated at y^k)\n",
    "        fk = total_objective_efficient(yk, X, Y, lambda1, lambda2, lambda3)\n",
    "        diff = (xk_new - xk).norm(p='fro')\n",
    "        rel_diff = diff / (xk.norm(p='fro') + 1e-10)\n",
    "        \n",
    "        t_iter = time.time() - t_start\n",
    "        f_hist.append(fk.item())\n",
    "        time_hist.append(t_iter)\n",
    "        \n",
    "        xk = xk_new\n",
    "        \n",
    "        if verbose and i % 10 == 0:\n",
    "            print(f\"Iter {i+1:4d}: f={fk:.6f}, ||Δx||={rel_diff:.2e}, \"\n",
    "                  f\"δ={delta:.2e}, time={t_iter:.3f}s\")\n",
    "        \n",
    "        if rel_diff < tol:\n",
    "            if verbose:\n",
    "                print(f\"\\nConverged at iteration {i+1}\")\n",
    "            break\n",
    "    \n",
    "    return yk, torch.tensor(f_hist), torch.tensor(time_hist)\n",
    "\n",
    "\n",
    "print(\"✓ All algorithms and helper functions loaded successfully\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# ============================================================================\n",
    "# CHUNK 2: DATA GENERATION\n",
    "# ============================================================================\n",
    "\n",
    "print(\"\\n\" + \"=\"*60)\n",
    "print(\"Generating multitask learning data...\")\n",
    "print(\"=\"*60)\n",
    "\n",
    "# Seed both random number generators so that Restart & Run All reproduces the\n",
    "# same synthetic data (NumPy) and the same Monte Carlo samples in HJ-Prox (torch).\n",
    "SEED = 0\n",
    "np.random.seed(SEED)\n",
    "torch.manual_seed(SEED)\n",
    "\n",
    "# Define true coefficient patterns (the 'a' vectors index features, the 'b' vectors index tasks)\n",
    "beta_1a = np.array([1, 0, 0, 1, 1, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0])\n",
    "beta_1b = np.array([0, 0, 0, 1, 1, 0, 0, 1, 0])\n",
    "beta_2a = np.array([0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0])\n",
    "beta_2b = np.array([0, 1, 1, 0, 0, 0, 0, 1, 0])\n",
    "\n",
    "# Replicate the feature patterns twice: 15 -> 30 entries\n",
    "beta_1a = np.tile(beta_1a, 2)\n",
    "beta_2a = np.tile(beta_2a, 2)\n",
    "\n",
    "# True coefficient matrix: sum of two outer products, so its rank is at most 2\n",
    "B1 = np.outer(beta_1a, beta_1b)\n",
    "B2 = np.outer(beta_2a, beta_2b)\n",
    "B_true = B1 + B2\n",
    "\n",
    "# Generate data: Gaussian design and Gaussian noise with the same standard deviation\n",
    "n, sigma = 50, 1.1\n",
    "X_np = np.random.normal(0, sigma, size=(n, len(beta_1a)))\n",
    "E = np.random.normal(0, sigma, size=(n, len(beta_1b)))\n",
    "Y_np = X_np @ B_true + E\n",
    "\n",
    "# Convert to torch\n",
    "X = torch.tensor(X_np, dtype=torch.float64)\n",
    "Y = torch.tensor(Y_np, dtype=torch.float64)\n",
    "\n",
    "# Compute Lipschitz constant of the least-squares gradient: the squared\n",
    "# largest singular value of X\n",
    "U, s, Vh = np.linalg.svd(X_np, full_matrices=False)\n",
    "L = np.max(s) ** 2\n",
    "\n",
    "print(f\"✓ Data generated: X shape {X.shape}, Y shape {Y.shape}\")\n",
    "print(f\"✓ True coefficient matrix B shape: {B_true.shape}\")\n",
    "print(f\"✓ Lipschitz constant L = {L:.4f}\")\n",
    "\n",
    "# Initial point\n",
    "B0 = torch.zeros((len(beta_1a), len(beta_1b)), dtype=torch.float64)\n",
    "\n",
    "# Penalty parameters (same for all algorithms)\n",
    "lambda1 = 11.0  # Nuclear norm\n",
    "lambda2 = 11.0  # Row norms\n",
    "lambda3 = 11.0  # Column norms\n",
    "\n",
    "print(f\"✓ Penalty parameters: λ1={lambda1}, λ2={lambda2}, λ3={lambda3}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# ============================================================================\n",
    "# CHUNK 3: RUN ALGORITHM 1 - Analytical Douglas-Rachford\n",
    "# ============================================================================\n",
    "# Runs douglas_rachford_analytical on (X, Y) from the initial point B0 and\n",
    "# reports wall-clock time, iteration count and final objective value.\n",
    "\n",
    "print(\"\\n\" + \"=\"*60)\n",
    "print(\"Running Algorithm 1: Douglas-Rachford with Analytical Proximal Operators...\")\n",
    "print(\"=\"*60)\n",
    "\n",
    "# Wall-clock timing of the whole solve (includes the verbose progress printing)\n",
    "start_time = time.time()\n",
    "\n",
    "B_Analytical, f_hist_Analytical, diff_hist_Analytical = douglas_rachford_analytical(\n",
    "    B0=B0,\n",
    "    X=X,\n",
    "    Y=Y,\n",
    "    lambda1=lambda1,\n",
    "    lambda2=lambda2,\n",
    "    lambda3=lambda3,\n",
    "    # Step size 0.005 / L, where L is the Lipschitz constant computed in the data-generation cell\n",
    "    gamma=1.0 / L * 0.005,\n",
    "    max_iters=10000,\n",
    "    # Maximum inner iterations for each proximal sub-solver\n",
    "    prox_iters=100,\n",
    "    tol=1e-10,\n",
    "    verbose=True\n",
    ")\n",
    "\n",
    "elapsed_time = time.time() - start_time\n",
    "\n",
    "print(f\"\\n✓ DRS-Analytical completed\")\n",
    "print(f\"  - Runtime: {elapsed_time:.2f} seconds\")\n",
    "# len(f_hist_Analytical) is the number of outer iterations executed; it equals\n",
    "# max_iters when the tolerance was never reached.\n",
    "print(f\"  - Converged in {len(f_hist_Analytical)} iterations\")\n",
    "print(f\"  - Final objective: {f_hist_Analytical[-1]:.6f}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# ============================================================================\n",
    "# CHUNK 4: RUN ALGORITHM 2 - Douglas-Rachford with HJ-Prox\n",
    "# ============================================================================\n",
    "# Runs douglas_rachford_efficient on the same data, penalties and step size as\n",
    "# Algorithm 1; tol and use_adaptive_delta are left at the function defaults.\n",
    "\n",
    "print(\"\\n\" + \"=\"*60)\n",
    "print(\"Running Algorithm 2: Douglas-Rachford with HJ-Prox...\")\n",
    "print(\"=\"*60)\n",
    "\n",
    "# Wall-clock timing of the whole solve (includes the verbose progress printing)\n",
    "start_time = time.time()\n",
    "\n",
    "B_HJ, f_hist_HJ, time_hist_HJ = douglas_rachford_efficient(\n",
    "    B0=B0,\n",
    "    X=X,\n",
    "    Y=Y,\n",
    "    lambda1=lambda1,\n",
    "    lambda2=lambda2,\n",
    "    lambda3=lambda3,\n",
    "    # Step size 0.005 / L, where L is the Lipschitz constant computed in the data-generation cell\n",
    "    gamma=1.0 / L * 0.005,\n",
    "    max_iters=10000,\n",
    "    # Number of Monte Carlo samples drawn for each HJ-Prox estimate\n",
    "    int_samples=1000,\n",
    "    verbose=True\n",
    ")\n",
    "\n",
    "elapsed_time = time.time() - start_time\n",
    "\n",
    "print(f\"\\n✓ DRS-HJ completed\")\n",
    "print(f\"  - Runtime: {elapsed_time:.2f} seconds\")\n",
    "# len(f_hist_HJ) is the number of outer iterations executed; it equals\n",
    "# max_iters when the relative-change tolerance was never reached.\n",
    "print(f\"  - Converged in {len(f_hist_HJ)} iterations\")\n",
    "print(f\"  - Final objective: {f_hist_HJ[-1]:.6f}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# ============================================================================\n",
    "# CHUNK 5: GENERATE FIGURES\n",
    "# ============================================================================\n",
    "\n",
    "print(\"\\n\" + \"=\"*60)\n",
    "print(\"Generating figures...\")\n",
    "print(\"=\"*60)\n",
    "\n",
    "# Compute shared vmin/vmax for consistent coloring across plots\n",
    "all_vals = np.concatenate([\n",
    "    B_true.flatten(),\n",
    "    B_Analytical.numpy().flatten(),\n",
    "    B_HJ.numpy().flatten()\n",
    "])\n",
    "vmin, vmax = all_vals.min(), all_vals.max()\n",
    "if vmin < 0 < vmax:\n",
    "    # Symmetric limits keep zero at the centre of the diverging colormap\n",
    "    m = max(abs(vmin), abs(vmax))\n",
    "    vmin, vmax = -m, m\n",
    "\n",
    "\n",
    "def plot_coefficient_heatmap(B, title, filename, vmin, vmax):\n",
    "    \"\"\"\n",
    "    Draw one coefficient matrix as a heatmap, save it as a PDF and show it.\n",
    "\n",
    "    Args:\n",
    "        B: 2-D NumPy array (rows = features, columns = tasks).\n",
    "        title: figure title.\n",
    "        filename: path of the PDF file to write.\n",
    "        vmin, vmax: colour limits, shared across figures for comparability.\n",
    "    \"\"\"\n",
    "    fig, ax = plt.subplots(figsize=(11, 10))\n",
    "    im = ax.imshow(B, aspect='auto', interpolation='nearest',\n",
    "                   cmap='RdBu_r', vmin=vmin, vmax=vmax)\n",
    "    ax.set_title(title, fontsize=40)\n",
    "    # Axis labels report the actual matrix dimensions instead of hard-coded counts\n",
    "    ax.set_xlabel(f'Tasks ({B.shape[1]})', fontsize=40)\n",
    "    ax.set_ylabel(f'Features ({B.shape[0]})', fontsize=40)\n",
    "    cbar = fig.colorbar(im, ax=ax, fraction=0.046, pad=0.04)\n",
    "    cbar.ax.tick_params(labelsize=40)\n",
    "    ax.tick_params(axis='both', which='major', labelsize=40)\n",
    "    ax.xaxis.set_major_locator(MaxNLocator(integer=True))\n",
    "    fig.tight_layout()\n",
    "    fig.savefig(filename, format='pdf', dpi=600, bbox_inches='tight')\n",
    "    plt.show()\n",
    "    plt.close(fig)\n",
    "\n",
    "\n",
    "# --- Figure 1: True Coefficient Matrix ---\n",
    "plot_coefficient_heatmap(B_true, 'True Coefficient Matrix', 'multitask_ground_truth.pdf', vmin, vmax)\n",
    "\n",
    "# --- Figure 2: DRS Solution ---\n",
    "plot_coefficient_heatmap(B_Analytical.numpy(), 'DRS', 'multitask_drs.pdf', vmin, vmax)\n",
    "\n",
    "# --- Figure 3: DRS-HJ Solution ---\n",
    "plot_coefficient_heatmap(B_HJ.numpy(), 'DRS-HJ', 'multitask_drs_hj.pdf', vmin, vmax)\n",
    "\n",
    "# --- Figure 4: Objective Convergence ---\n",
    "fig, ax = plt.subplots(figsize=(11, 10))\n",
    "ax.semilogy(f_hist_HJ.numpy(), '-', linewidth=2,\n",
    "            label=f'DRS-HJ: {f_hist_HJ[-1].item():.3f}')\n",
    "ax.semilogy(f_hist_Analytical.numpy(), '--', linewidth=2,\n",
    "            label=f'DRS: {f_hist_Analytical[-1].item():.3f}')\n",
    "ax.set_title('DRS Convergence', fontsize=40)\n",
    "ax.set_xlabel('Iteration', fontsize=40)\n",
    "ax.set_ylabel('Objective (log scale)', fontsize=40)\n",
    "ax.legend(fontsize=40, loc=\"upper left\")\n",
    "ax.grid(True, alpha=0.3, which='both')\n",
    "ax.tick_params(axis='both', which='major', labelsize=40)\n",
    "# Ticks every 2500 iterations, spanning the longer of the two histories so that\n",
    "# neither curve runs past the last tick.\n",
    "max_iter = max(len(f_hist_HJ), len(f_hist_Analytical))\n",
    "tick_positions = np.arange(0, max_iter + 1, 2500)\n",
    "ax.set_xticks(tick_positions)\n",
    "fig.tight_layout()\n",
    "fig.savefig('multitask_objectives.pdf', format='pdf', dpi=600, bbox_inches='tight')\n",
    "plt.show()\n",
    "plt.close(fig)\n",
    "\n",
    "print(\"✓ All figures saved: multitask_ground_truth.pdf, multitask_drs.pdf, multitask_drs_hj.pdf, multitask_objectives.pdf\")\n",
    "print(\"\\n\" + \"=\"*60)\n",
    "print(\"✓ ALL EXPERIMENTS COMPLETED SUCCESSFULLY\")\n",
    "print(\"=\"*60)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "base",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.10"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
