{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "deadly-rolling",
   "metadata": {},
   "source": [
    "## Probability of existence of a consistent attack under a power-law covariance"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "together-ghost",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "from scipy.special import erf"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "helper-functions",
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_p_star(p):\n",
    "    \"\"\"Return the Hölder conjugate p* of p, defined by 1/p + 1/p* = 1.\n",
    "\n",
    "    The endpoints are special-cased: p = 1 gives np.inf and p = np.inf gives 1.\n",
    "    \"\"\"\n",
    "    # Guard clauses for the two endpoints, where p / (p - 1) is not usable.\n",
    "    if p == 1:\n",
    "        return np.inf\n",
    "    if p == np.inf:\n",
    "        return 1\n",
    "    return p / (p - 1)\n",
    "    \n",
    "def lp_norm(x, p):\n",
    "    \"\"\"Return the l_p norm of x; p = np.inf selects the max-abs norm.\"\"\"\n",
    "    magnitudes = np.abs(x)\n",
    "    if p == np.inf:\n",
    "        return np.max(magnitudes)\n",
    "    return np.sum(magnitudes**p)**(1/p)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "analytic-function",
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_powerlaw_covariance(d, alpha, offset=0):\n",
    "    \"\"\"Build a diagonal covariance whose spectrum follows a power law.\n",
    "\n",
    "    λ_i ~ (i + offset)^(-alpha) for i = 1, ..., d, rescaled so trace(Sigma) = d.\n",
    "    A larger offset makes the early eigenvalues more similar.\n",
    "\n",
    "    Args:\n",
    "        d: dimension\n",
    "        alpha: power law exponent (float or int)\n",
    "        offset: shift added to the 1-based index before exponentiation\n",
    "\n",
    "    Returns:\n",
    "        Sigma: array of shape (d, d), equal to diag(eigenvalues)\n",
    "        eigenvalues: array of shape (d,)\n",
    "    \"\"\"\n",
    "    # Float indices: NumPy raises ValueError when an integer array is raised\n",
    "    # to a negative integer power (e.g. alpha=2 with the default offset).\n",
    "    indices = np.arange(1, d + 1, dtype=float) + offset\n",
    "    eigenvalues = indices ** (-alpha)\n",
    "\n",
    "    # Normalize so trace = d\n",
    "    eigenvalues = eigenvalues / np.sum(eigenvalues) * d\n",
    "\n",
    "    Sigma = np.diag(eigenvalues)\n",
    "    return Sigma, eigenvalues\n",
    "\n",
    "def get_analytic_powerlaw(theta_star, hat_theta, Sigma, eps, p=2):\n",
    "    \"\"\"Closed-form probability for a general covariance Sigma.\n",
    "\n",
    "    Formula: P = erf(eps * d_star / sigma_norm / sqrt(2)), with\n",
    "    d_star = ||theta_perp||_{p*} and sigma_norm = sqrt(hat_theta^T Sigma hat_theta).\n",
    "\n",
    "    theta_perp = hat_theta - m * theta_star with m = <theta_star, hat_theta> / d,\n",
    "    which is the component orthogonal to theta_star when ||theta_star||^2 = d.\n",
    "    \"\"\"\n",
    "    dim = len(theta_star)\n",
    "    overlap = theta_star.dot(hat_theta) / dim\n",
    "    residual = hat_theta - overlap * theta_star\n",
    "\n",
    "    # Dual-norm size of the residual: ||θ_perp||_{p*}\n",
    "    d_star = lp_norm(residual, get_p_star(p))\n",
    "\n",
    "    # Quadratic form under Sigma: sqrt(hat_theta^T Sigma hat_theta)\n",
    "    sigma_norm = np.sqrt(hat_theta @ Sigma @ hat_theta)\n",
    "\n",
    "    # erf(z / sqrt(2)) is the same quantity as 2*Phi(z) - 1\n",
    "    scaled_margin = eps * d_star / sigma_norm / np.sqrt(2)\n",
    "    return erf(scaled_margin)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "create-predictors",
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "\n",
    "def create_power_law_eigenvalues(d, beta):\n",
    "    \"\"\"Create power law eigenvalues with lambda_1 = 1, sum(lambda_2:d) = d-1\n",
    "\n",
    "    Args:\n",
    "        d: dimension\n",
    "        beta: power law exponent (float or int)\n",
    "\n",
    "    Returns:\n",
    "        eigenvalues: array of shape (d,) with lambda_i for i=1,...,d\n",
    "    \"\"\"\n",
    "    eigenvalues = np.zeros(d)\n",
    "    eigenvalues[0] = 1.0  # First eigenvalue is always 1\n",
    "\n",
    "    if d > 1:\n",
    "        # Power law over i = 2, 3, ..., d. The float dtype matters: NumPy raises\n",
    "        # ValueError for an integer array raised to a negative integer power.\n",
    "        indices = np.arange(2, d + 1, dtype=float)\n",
    "        raw_values = indices ** (-beta)\n",
    "\n",
    "        # Normalize so sum(lambda_2:d) = d - 1\n",
    "        eigenvalues[1:] = (d - 1) * raw_values / np.sum(raw_values)\n",
    "\n",
    "    return eigenvalues\n",
    "\n",
    "def create_low_rank_predictor(theta_star, eigenvalues, k, d, m=0.5):\n",
    "    \"\"\"Predictor whose orthogonal mass sits on the leading coordinates.\n",
    "\n",
    "    Coordinate 0 gets m * sqrt(d); 0-indexed coordinates 1 .. min(k, d) - 1 each\n",
    "    get sqrt(d * (1 - m^2)) / sqrt(k - 1); the rest stay zero. The vector is then\n",
    "    rescaled to Euclidean norm sqrt(d). theta_star and eigenvalues are not used.\n",
    "    \"\"\"\n",
    "    hat_theta = np.zeros(d)\n",
    "    hat_theta[0] = m * np.sqrt(d)  # Component along the first coordinate\n",
    "\n",
    "    # Norm budget for everything outside coordinate 0\n",
    "    orth_norm = np.sqrt(d * (1 - m**2))\n",
    "\n",
    "    # Equal weight on the leading block (nothing to fill when k <= 1)\n",
    "    if k > 1:\n",
    "        hat_theta[1:min(k, d)] = orth_norm / np.sqrt(k - 1)\n",
    "\n",
    "    # Rescale to norm sqrt(d)\n",
    "    return hat_theta / np.linalg.norm(hat_theta) * np.sqrt(d)\n",
    "\n",
    "def create_uniform_predictor(theta_star, eigenvalues, d, m=0.5):\n",
    "    \"\"\"Predictor whose orthogonal mass is spread evenly over coordinates 1 .. d-1.\n",
    "\n",
    "    Coordinate 0 gets m * sqrt(d); every other coordinate gets\n",
    "    sqrt(d * (1 - m^2)) / sqrt(d - 1). The vector is then rescaled to\n",
    "    Euclidean norm sqrt(d). theta_star and eigenvalues are not used.\n",
    "    \"\"\"\n",
    "    hat_theta = np.zeros(d)\n",
    "    hat_theta[0] = m * np.sqrt(d)\n",
    "\n",
    "    orth_norm = np.sqrt(d * (1 - m**2))\n",
    "    # Only fill when there is an orthogonal part, so sqrt(d - 1) is never zero here\n",
    "    if d > 1:\n",
    "        hat_theta[1:] = orth_norm / np.sqrt(d - 1)\n",
    "\n",
    "    # Rescale to norm sqrt(d)\n",
    "    return hat_theta / np.linalg.norm(hat_theta) * np.sqrt(d)\n",
    "\n",
    "def create_tail_heavy_predictor(theta_star, eigenvalues, k, d, m=0.5):\n",
    "    \"\"\"Predictor whose orthogonal mass sits on the tail coordinates k .. d-1.\n",
    "\n",
    "    Coordinate 0 gets m * sqrt(d); 0-indexed coordinates k .. d-1 each get\n",
    "    sqrt(d * (1 - m^2)) / sqrt(d - k). The vector is then rescaled to\n",
    "    Euclidean norm sqrt(d). theta_star and eigenvalues are not used.\n",
    "    \"\"\"\n",
    "    hat_theta = np.zeros(d)\n",
    "    hat_theta[0] = m * np.sqrt(d)\n",
    "\n",
    "    orth_norm = np.sqrt(d * (1 - m**2))\n",
    "\n",
    "    # Same positions as range(k, d); skip entirely when the tail is empty\n",
    "    tail_idx = np.arange(k, d)\n",
    "    if tail_idx.size:\n",
    "        hat_theta[tail_idx] = orth_norm / np.sqrt(d - k)\n",
    "\n",
    "    # Rescale to norm sqrt(d)\n",
    "    return hat_theta / np.linalg.norm(hat_theta) * np.sqrt(d)\n",
    "\n",
    "def compute_denominator(hat_theta, eigenvalues):\n",
    "    \"\"\"Compute sqrt(hat_theta^T Sigma hat_theta) for Sigma = diag(eigenvalues)\n",
    "\n",
    "    The quadratic form splits as\n",
    "    hat_theta[0]^2 * lambda_1 + sum_{i=2}^d hat_theta[i]^2 * lambda_i.\n",
    "    With lambda_1 = 1 and hat_theta[0] = m*sqrt(d) the first term is m^2 * d.\n",
    "\n",
    "    Args:\n",
    "        hat_theta: predictor coordinates, 1-D array or plain sequence\n",
    "        eigenvalues: diagonal of Sigma, same length as hat_theta\n",
    "\n",
    "    Returns:\n",
    "        scalar sqrt(hat_theta^T Sigma hat_theta)\n",
    "    \"\"\"\n",
    "    # Coerce so plain Python sequences support the element-wise maths below\n",
    "    hat_theta = np.asarray(hat_theta)\n",
    "    eigenvalues = np.asarray(eigenvalues)\n",
    "\n",
    "    # First term: hat_theta[0]^2 * lambda_1\n",
    "    term1 = (hat_theta[0]**2) * eigenvalues[0]\n",
    "\n",
    "    # Second term: sum of hat_theta[i]^2 * lambda_i for i >= 2\n",
    "    term2 = np.sum(hat_theta[1:]**2 * eigenvalues[1:])\n",
    "\n",
    "    return np.sqrt(term1 + term2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "setup-params",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Parameters\n",
    "d = 75\n",
    "k = 10  # Number of top/tail eigenvectors\n",
    "alpha = 1.5  # Power law exponent\n",
    "m = 0.5  # Correlation with ground truth\n",
    "p = 2  # Attack norm\n",
    "n_samples = 10_000\n",
    "\n",
    "# Ground truth: sqrt(d) times the first standard basis vector\n",
    "theta_star = np.zeros(d)\n",
    "theta_star[0] = np.sqrt(d)\n",
    "\n",
    "# Power law covariance (diagonal in the standard basis)\n",
    "Sigma, eigenvalues = get_powerlaw_covariance(d, alpha)\n",
    "\n",
    "# Short summary of the setup, one line per quantity\n",
    "summary_lines = (\n",
    "    f\"Dimension: {d}\",\n",
    "    f\"Power law exponent α: {alpha}\",\n",
    "    f\"Top-k: {k}\",\n",
    "    f\"Eigenvalue ratio λ_1/λ_d: {eigenvalues[0]/eigenvalues[-1]:.2f}\",\n",
    "    f\"Eigenvalue ratio λ_k/λ_d: {eigenvalues[k-1]/eigenvalues[-1]:.2f}\",\n",
    ")\n",
    "for summary_line in summary_lines:\n",
    "    print(summary_line)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "plot-eigenvalues",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Plot eigenvalue spectrum on a log y-axis, marking the top-k cutoff\n",
    "fig, ax = plt.subplots(figsize=(6, 4))\n",
    "ax.semilogy(range(1, d+1), eigenvalues, 'o-', markersize=3, label='Eigenvalues')\n",
    "ax.axvline(k, color='red', linestyle=':', alpha=0.5, label=f'k={k}')\n",
    "ax.set_xlabel('Eigenvalue index $i$')\n",
    "ax.set_ylabel('Eigenvalue $\\\\lambda_i$')\n",
    "ax.set_title(f'Power law spectrum: $\\\\lambda_i \\\\sim i^{{-{alpha}}}$')\n",
    "ax.legend()\n",
    "ax.grid(True, alpha=0.3)\n",
    "fig.tight_layout()\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "three-regimes",
   "metadata": {},
   "source": [
    "## Test three alignment regimes (deterministic)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "test-alignments",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Build the three deterministic predictors, keyed by the label used in the report\n",
    "hat_theta_lowrank = create_low_rank_predictor(theta_star, eigenvalues, k, d, m)\n",
    "hat_theta_uniform = create_uniform_predictor(theta_star, eigenvalues, d, m)\n",
    "hat_theta_tail = create_tail_heavy_predictor(theta_star, eigenvalues, k, d, m)\n",
    "\n",
    "predictors = {\n",
    "    'Low-rank (top-{})'.format(k): hat_theta_lowrank,\n",
    "    'Uniform spread': hat_theta_uniform,\n",
    "    'Tail-heavy ({}:d)'.format(k): hat_theta_tail\n",
    "}\n",
    "\n",
    "# Determinism check: a second construction must reproduce the first\n",
    "print(\"Verifying predictors are deterministic:\")\n",
    "hat_theta_lowrank2 = create_low_rank_predictor(theta_star, eigenvalues, k, d, m)\n",
    "print(f\"Low-rank reproducible: {np.allclose(hat_theta_lowrank, hat_theta_lowrank2)}\")\n",
    "\n",
    "# The dual exponent depends only on p, so compute it once\n",
    "p_star = get_p_star(p)\n",
    "\n",
    "# Per predictor: dual norm of the orthogonal part, quadratic form, and their ratio\n",
    "print(\"\\nPredictor statistics:\")\n",
    "print(\"-\" * 80)\n",
    "for name, hat_theta in predictors.items():\n",
    "    theta_perp = hat_theta - (hat_theta @ theta_star / d) * theta_star\n",
    "    d_star = lp_norm(theta_perp, p_star)\n",
    "    sigma_norm = np.sqrt(hat_theta @ Sigma @ hat_theta)\n",
    "    ratio = d_star / sigma_norm\n",
    "\n",
    "    # Share of the squared coefficients carried by the first k coordinates\n",
    "    c_squared = np.square(hat_theta)\n",
    "    top_k_weight = c_squared[:k].sum() / c_squared.sum()\n",
    "\n",
    "    print(f\"{name:25s}: d*/σ = {ratio:8.3f}  (larger = easier attack)\")\n",
    "    print(f\"{'':25s}  d* = {d_star:8.3f}, σ = {sigma_norm:8.3f}\")\n",
    "    print(f\"{'':25s}  Top-{k} weight: {top_k_weight:.1%}\")\n",
    "    print()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "plot-comparison",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Plotting setup: base width in points\n",
    "width = 225.84377 * 1.25\n",
    "\n",
    "def set_size(width, fraction=1, subplots=(1, 1)):\n",
    "    \"\"\"Return (width, height) in inches for a figure sized from a width in points.\n",
    "\n",
    "    Args:\n",
    "        width: width in points, or one of the presets \"thesis\" / \"beamer\"\n",
    "        fraction: fraction of that width the figure should occupy\n",
    "        subplots: (rows, cols); the height is width * rows / cols\n",
    "    \"\"\"\n",
    "    presets = {\"thesis\": 426.79135, \"beamer\": 307.28987}\n",
    "    is_preset = isinstance(width, str) and width in presets\n",
    "    width_pt = presets[width] if is_preset else width\n",
    "\n",
    "    # 72.27 points per inch\n",
    "    inches_per_pt = 1 / 72.27\n",
    "    fig_width_in = width_pt * fraction * inches_per_pt\n",
    "\n",
    "    # Height follows the subplot grid's rows/cols ratio (no golden-ratio factor)\n",
    "    rows, cols = subplots[0], subplots[1]\n",
    "    fig_height_in = fig_width_in * (rows / cols)\n",
    "\n",
    "    return (fig_width_in, fig_height_in)\n",
    "\n",
    "tuple_size = set_size(width, fraction=0.33)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "vary-alpha",
   "metadata": {},
   "source": [
    "## Vary power law exponent α"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5c1c6df0",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Fixed parameters\n",
    "eps_fixed = 1.0\n",
    "alphas = np.linspace(0.01, 1.5, 40)\n",
    "d = 30\n",
    "k = 2\n",
    "m = 0.5\n",
    "p = 2\n",
    "n_samples = 100_000  # Number of samples for empirical estimation\n",
    "seed = 0  # Seeds the Monte Carlo check so reruns reproduce the same points\n",
    "\n",
    "# Select subset of alphas for empirical verification\n",
    "alphas_empirical = np.linspace(0.01, 1.5, 15)\n",
    "\n",
    "theta_star = np.zeros(d)\n",
    "theta_star[0] = np.sqrt(d)\n",
    "p_star = get_p_star(p)\n",
    "\n",
    "# Dedicated generator instead of the unseeded global np.random state\n",
    "rng = np.random.default_rng(seed)\n",
    "\n",
    "plt.style.use(\"./latex_ready.mplstyle\")\n",
    "\n",
    "\n",
    "def make_predictors(eigenvalues_alpha):\n",
    "    \"\"\"Return the three deterministic predictors, keyed by regime.\"\"\"\n",
    "    return {\n",
    "        'lowrank': create_low_rank_predictor(theta_star, eigenvalues_alpha, k, d, m),\n",
    "        'uniform': create_uniform_predictor(theta_star, eigenvalues_alpha, d, m),\n",
    "        'tail': create_tail_heavy_predictor(theta_star, eigenvalues_alpha, k, d, m),\n",
    "    }\n",
    "\n",
    "\n",
    "# Storage for results: one list per regime, analytical and empirical\n",
    "probs_lowrank, probs_uniform, probs_tail = [], [], []\n",
    "probs_lowrank_emp, probs_uniform_emp, probs_tail_emp = [], [], []\n",
    "analytic = {'lowrank': probs_lowrank, 'uniform': probs_uniform, 'tail': probs_tail}\n",
    "empirical = {'lowrank': probs_lowrank_emp, 'uniform': probs_uniform_emp, 'tail': probs_tail_emp}\n",
    "\n",
    "# Compute probability for each alpha (analytical)\n",
    "for alpha in alphas:\n",
    "    Sigma_alpha, eigenvalues_alpha = get_powerlaw_covariance(d, alpha)\n",
    "\n",
    "    for regime, hat_theta in make_predictors(eigenvalues_alpha).items():\n",
    "        analytic[regime].append(\n",
    "            get_analytic_powerlaw(theta_star, hat_theta, Sigma_alpha, eps_fixed, p)\n",
    "        )\n",
    "\n",
    "# Compute empirical probabilities for subset of alphas\n",
    "for alpha in alphas_empirical:\n",
    "    Sigma_alpha, eigenvalues_alpha = get_powerlaw_covariance(d, alpha)\n",
    "\n",
    "    # Sample from N(0, Sigma): Sigma is diagonal, so scale each coordinate.\n",
    "    # The same draws are shared by all three predictors.\n",
    "    x = rng.standard_normal((n_samples, d)) * np.sqrt(eigenvalues_alpha[np.newaxis, :])\n",
    "\n",
    "    for regime, hat_theta in make_predictors(eigenvalues_alpha).items():\n",
    "        theta_perp = hat_theta - (hat_theta @ theta_star / d) * theta_star\n",
    "        d_star = lp_norm(theta_perp, p_star)\n",
    "\n",
    "        # Check condition: eps * d_star >= |<hat_theta, x>|\n",
    "        margins = np.abs(x @ hat_theta)\n",
    "        empirical[regime].append(np.mean(eps_fixed * d_star >= margins))\n",
    "\n",
    "# Plot\n",
    "fig, ax = plt.subplots(1, 1, figsize=(tuple_size[0], tuple_size[1]))\n",
    "fig.subplots_adjust(left=0.20, bottom=0.18, top=0.98, right=0.96)\n",
    "\n",
    "# Analytical curves\n",
    "ax.plot(alphas, probs_lowrank, '-', linewidth=1.5, label='Low-rank', color='C0')\n",
    "ax.plot(alphas, probs_tail, '-', linewidth=1.5, label='Heavy-tail', color='C1')\n",
    "ax.plot(alphas, probs_uniform, '-', linewidth=1.5, label='Unif.', color='C2')\n",
    "\n",
    "# Empirical points\n",
    "ax.plot(alphas_empirical, probs_lowrank_emp, '.', markersize=3, color='C0')\n",
    "ax.plot(alphas_empirical, probs_tail_emp, '.', markersize=3, color='C1')\n",
    "ax.plot(alphas_empirical, probs_uniform_emp, '.', markersize=3, color='C2')\n",
    "\n",
    "ax.set_xlabel(r'$\\beta$', labelpad=0.0)\n",
    "ax.set_ylabel(r'prob. of existence', labelpad=1.0)\n",
    "ax.grid(True)\n",
    "\n",
    "# Legend sits above the axes, so it is passed to savefig to stay inside the tight bounding box\n",
    "legend = ax.legend(handlelength=1.5, frameon=False,loc='lower right', bbox_to_anchor=(1.1, 0.95), labelspacing=0.3, columnspacing=0.75)\n",
    "\n",
    "plt.savefig('prob_vs_alpha_eps{}_q{}_d{}_k{}_tiny.pdf'.format(eps_fixed, p, d, k), bbox_extra_artists=[legend], bbox_inches='tight')\n",
    "\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5c938417",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "# -----------------------\n",
    "# CONFIG\n",
    "# -----------------------\n",
    "# Attack budget and predictor settings\n",
    "eps_fixed = 1.0\n",
    "k = 10\n",
    "m = 0.5\n",
    "p = 2\n",
    "\n",
    "# Exponent grids: dense for the closed form, sparse for the Monte Carlo check\n",
    "max_alpha = 3.5\n",
    "alphas = np.linspace(0.01, max_alpha, 100)\n",
    "alphas_empirical = np.linspace(0.01, max_alpha, 10)\n",
    "n_samples = 100_000\n",
    "\n",
    "# Dimensions to compare: 2**4, 2**7, 2**10 (change as you like)\n",
    "dims = [2 ** exponent for exponent in range(4, 12, 3)]\n",
    "\n",
    "# One linestyle / marker per dimension (cycled if there are more dims than styles)\n",
    "linestyles = ['-', '--', ':', '-.']\n",
    "markers = ['o', 's', 'd', '^']  # for empirical points\n",
    "\n",
    "# One color per predictor type, constant across dimensions\n",
    "COLOR_LR, COLOR_TAIL = 'C0', 'C1'\n",
    "\n",
    "plt.style.use(\"./latex_ready.mplstyle\")\n",
    "\n",
    "# -----------------------\n",
    "# HELPERS\n",
    "# -----------------------\n",
    "def theta_star_for_dim(d: int) -> np.ndarray:\n",
    "    \"\"\"Return the ground-truth vector sqrt(d) * e_1 in R^d.\"\"\"\n",
    "    ground_truth = np.zeros(d)\n",
    "    ground_truth[0] = np.sqrt(d)  # all mass on the first coordinate\n",
    "    return ground_truth\n",
    "\n",
    "def compute_curves_for_dim(d: int, seed: int = 0, max_chunk_elems: int = 4_000_000):\n",
    "    \"\"\"\n",
    "    Analytical and Monte Carlo probabilities for one dimension d.\n",
    "\n",
    "    Reads the notebook-level settings alphas, alphas_empirical, n_samples,\n",
    "    k, m, p and eps_fixed.\n",
    "\n",
    "    Args:\n",
    "      d: ambient dimension\n",
    "      seed: seed of the Monte Carlo generator, so reruns give the same points\n",
    "      max_chunk_elems: cap on the Gaussian entries held in memory at once;\n",
    "        one (n_samples, d) array would take about 0.8 GB at d=1024 with\n",
    "        100_000 samples, before the scaled copy\n",
    "\n",
    "    Returns:\n",
    "      probs_lr, probs_tail (analytical over alphas),\n",
    "      probs_lr_emp, probs_tail_emp (empirical over alphas_empirical)\n",
    "    \"\"\"\n",
    "    theta_star = theta_star_for_dim(d)\n",
    "    p_star = get_p_star(p)\n",
    "    rng = np.random.default_rng(seed)\n",
    "\n",
    "    def build_predictors(eigenvalues_alpha):\n",
    "        # Low-rank and tail-heavy predictors for one spectrum\n",
    "        hat_lr = create_low_rank_predictor(theta_star, eigenvalues_alpha, k, d, m)\n",
    "        hat_tail = create_tail_heavy_predictor(theta_star, eigenvalues_alpha, k, d, m)\n",
    "        return hat_lr, hat_tail\n",
    "\n",
    "    # Analytical\n",
    "    probs_lr = []\n",
    "    probs_tail = []\n",
    "    for alpha in alphas:\n",
    "        Sigma_alpha, eigenvalues_alpha = get_powerlaw_covariance(d, alpha)\n",
    "        hat_lr, hat_tail = build_predictors(eigenvalues_alpha)\n",
    "\n",
    "        probs_lr.append(get_analytic_powerlaw(theta_star, hat_lr, Sigma_alpha, eps_fixed, p))\n",
    "        probs_tail.append(get_analytic_powerlaw(theta_star, hat_tail, Sigma_alpha, eps_fixed, p))\n",
    "\n",
    "    # Empirical\n",
    "    probs_lr_emp = []\n",
    "    probs_tail_emp = []\n",
    "    rows_per_chunk = max(1, min(n_samples, max_chunk_elems // d))\n",
    "\n",
    "    for alpha in alphas_empirical:\n",
    "        _, eigenvalues_alpha = get_powerlaw_covariance(d, alpha)\n",
    "        hat_lr, hat_tail = build_predictors(eigenvalues_alpha)\n",
    "        scale = np.sqrt(eigenvalues_alpha)\n",
    "\n",
    "        # Thresholds eps * d_star, with d_star the dual norm of the orthogonal part\n",
    "        theta_perp_lr = hat_lr - (hat_lr @ theta_star / d) * theta_star\n",
    "        theta_perp_tail = hat_tail - (hat_tail @ theta_star / d) * theta_star\n",
    "        threshold_lr = eps_fixed * lp_norm(theta_perp_lr, p_star)\n",
    "        threshold_tail = eps_fixed * lp_norm(theta_perp_tail, p_star)\n",
    "\n",
    "        hits_lr = 0\n",
    "        hits_tail = 0\n",
    "        for start in range(0, n_samples, rows_per_chunk):\n",
    "            rows = min(rows_per_chunk, n_samples - start)\n",
    "\n",
    "            # x ~ N(0, Sigma) for diagonal Sigma: scale standard normals per coordinate.\n",
    "            # Both predictors are evaluated on the same draws.\n",
    "            x = rng.standard_normal((rows, d)) * scale\n",
    "\n",
    "            hits_lr += np.count_nonzero(threshold_lr >= np.abs(x @ hat_lr))\n",
    "            hits_tail += np.count_nonzero(threshold_tail >= np.abs(x @ hat_tail))\n",
    "\n",
    "        probs_lr_emp.append(hits_lr / n_samples)\n",
    "        probs_tail_emp.append(hits_tail / n_samples)\n",
    "\n",
    "    return probs_lr, probs_tail, probs_lr_emp, probs_tail_emp\n",
    "\n",
    "# -----------------------\n",
    "# COMPUTE + PLOT\n",
    "# -----------------------\n",
    "fig, ax = plt.subplots(1, 1, figsize=(tuple_size[0], tuple_size[1]))\n",
    "fig.subplots_adjust(left=0.20, bottom=0.18, top=0.98, right=0.96)\n",
    "\n",
    "for i, d in enumerate(dims):\n",
    "    # Linestyle and marker both encode the dimension\n",
    "    ls = linestyles[i % len(linestyles)]\n",
    "    mk = markers[i % len(markers)]\n",
    "\n",
    "    probs_lr, probs_tail, probs_lr_emp, probs_tail_emp = compute_curves_for_dim(d)\n",
    "\n",
    "    # (color, label, analytical curve, empirical points) for each method\n",
    "    series = [\n",
    "        (COLOR_LR, f'Low-rank (d={d})', probs_lr, probs_lr_emp),\n",
    "        (COLOR_TAIL, f'Heavy-tail (d={d})', probs_tail, probs_tail_emp),\n",
    "    ]\n",
    "\n",
    "    # Analytical curves first, then the empirical points, to keep the draw order\n",
    "    for color, curve_label, curve, points in series:\n",
    "        ax.plot(alphas, curve, linestyle=ls, linewidth=1.5, color=color,\n",
    "                label=curve_label)\n",
    "\n",
    "    for color, curve_label, curve, points in series:\n",
    "        ax.plot(alphas_empirical, points, linestyle='None', marker=mk, markersize=2, color=color, alpha=0.9)\n",
    "\n",
    "ax.set_xlabel(r'$\\beta$', labelpad=0.0)\n",
    "ax.set_ylabel(r'prob. of existence', labelpad=1.0)\n",
    "ax.grid(True)\n",
    "\n",
    "from matplotlib.lines import Line2D\n",
    "\n",
    "# -----------------------\n",
    "# CUSTOM 2-COLUMN LEGEND\n",
    "# -----------------------\n",
    "\n",
    "# Column 1 → dimensions (black proxy lines, linestyle only)\n",
    "dim_handles = []\n",
    "for i, d in enumerate(dims):\n",
    "    dim_handles.append(\n",
    "        Line2D([0], [0],\n",
    "               color='black',\n",
    "               linestyle=linestyles[i % len(linestyles)],\n",
    "               linewidth=1.5,\n",
    "               label=f'$d={d}$')\n",
    "    )\n",
    "\n",
    "# Column 2 → methods (color only, solid proxy lines)\n",
    "method_handles = [\n",
    "    Line2D([0], [0], color=method_color, linestyle='-', linewidth=1.5, label=method_label)\n",
    "    for method_color, method_label in ((COLOR_LR, 'LR'), (COLOR_TAIL, 'Tail'))\n",
    "]\n",
    "\n",
    "handles = dim_handles + method_handles\n",
    "\n",
    "# Legend sits above the axes, so it is passed to savefig to stay inside the tight bounding box\n",
    "legend = ax.legend(\n",
    "    handles=handles,\n",
    "    frameon=False,\n",
    "    ncol=2,\n",
    "    handlelength=1.5,\n",
    "    labelspacing=0.3, \n",
    "    columnspacing=0.25,\n",
    "    loc='lower right',\n",
    "    bbox_to_anchor=(1.1, 0.95)\n",
    ")\n",
    "\n",
    "plt.savefig(f'prob_vs_alpha_eps{eps_fixed}_q{p}_k{k}_dims{\"-\".join(map(str,dims))}.pdf', bbox_extra_artists=[legend], bbox_inches='tight')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "46128d81",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": ".venv",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.13.11"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
