{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# R1SMG (Ji–Li) calibration → spherical-privacy API parameters\n",
    "\n",
    "This notebook does the following for the **rank‑1 singular multivariate Gaussian (R1SMG)** calibration:\n",
    "\n",
    "1. Compute \\(\\psi(T,\\delta)\\) and the paper-claimed \\(\\sigma_\\star\\).\n",
    "2. Convert \\(\\sigma_\\star\\) to spherical generalized‑gamma parameters \\((\\alpha,\\beta,p)\\).\n",
    "3. Compute the **true** \\(\\delta(\\varepsilon)\\) by calling the submission's internal routine:\n",
    "\n",
    "```python\n",
    "float(compute_spherical_generalized_gamma_privacy(eps, alpha, beta, p, T, mu_0, mu_1))\n",
    "```\n",
    "\n",
    "**Notes**\n",
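    "- The R1SMG noise radius is half‑normal: \\(R=\\sqrt{\\sigma_\\star}\\,\\lvert Z\\rvert\\) with \\(Z\\sim\\mathcal{N}(0,1)\\), so \\(\\sigma_\\star\\) plays the role of a variance and \\(f_R(r)\\propto \\exp(-r^2/(2\\sigma_\\star))\\) for \\(r\\ge 0\\). This corresponds to \\(\\alpha=0\\), \\(p=2\\), and \\(\\beta=1/(2\\sigma_\\star)\\).\n",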
    "- The Ji–Li formula for \\(\\psi\\) requires \\(T>2\\).\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a74d2b61",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import sys\n",
    "\n",
    "# Navigate to the parent directory of the project structure\n",
    "project_dir = os.path.abspath(os.path.join(os.getcwd(), '../../'))\n",
    "src_dir = os.path.join(project_dir, 'src')\n",
    "fig_dir = os.path.join(project_dir, 'fig')\n",
    "os.makedirs(fig_dir, exist_ok=True)\n",
    "\n",
    "# Add the src directory to sys.path\n",
    "sys.path.append(src_dir)\n",
    "\n",
    "from math import log, exp\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "from scipy.special import gammaln\n",
    "\n",
    "from OAEGN_analysis.privacy_analysis import compute_spherical_generalized_gamma_privacy"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "def ji_li_psi(T: int, delta: float) -> float:\n",
    "    \"\"\"Compute psi(T, delta) in a numerically stable (log-space) way.\n",
    "\n",
    "    psi = ( delta * Gamma((T-1)/2) / (sqrt(pi) * Gamma(T/2)) )^(2/(T-2))\n",
    "    Requires: T > 2, delta in (0,1).\n",
    "    \"\"\"\n",
    "    if T <= 2:\n",
    "        raise ValueError('Ji–Li psi(T,delta) requires T > 2.')\n",
    "    if not (0.0 < delta < 1.0):\n",
    "        raise ValueError('delta must be in (0,1).')\n",
    "\n",
    "    # log( Gamma((T-1)/2) / (sqrt(pi) * Gamma(T/2)) )\n",
    "    log_ratio = gammaln((T - 1) / 2.0) - 0.5 * log(np.pi) - gammaln(T / 2.0)\n",
    "    log_base = log(delta) + log_ratio\n",
    "    log_psi = (2.0 / (T - 2.0)) * log_base\n",
    "    return float(exp(log_psi))\n",
    "\n",
    "\n",
    "def ji_li_sigma_star(eps: float, delta: float, T: int, s: float) -> float:\n",
    "    \"\"\"Paper-claimed R1SMG scale: sigma_star = 2 s^2 / (eps * psi(T, delta)).\n",
    "\n",
    "    eps and s are checked here (both must be > 0); T and delta are\n",
    "    validated inside ji_li_psi. Any violation raises ValueError.\n",
    "    \"\"\"\n",
    "    if eps <= 0:\n",
    "        raise ValueError('eps must be > 0.')\n",
    "    if s <= 0:\n",
    "        raise ValueError('s must be > 0.')\n",
    "    numerator = 2.0 * s * s\n",
    "    denominator = eps * ji_li_psi(T, delta)\n",
    "    return float(numerator / denominator)\n",
    "\n",
    "\n",
    "def r1smg_to_sgg_params(sigma_star: float):\n",
    "    \"\"\"Convert R1SMG radius (half-normal) to (alpha,beta,p).\n",
    "\n",
    "    R = sqrt(sigma_star) * |Z|, Z ~ N(0,1)\n",
    "    => f_R(r) ∝ exp(- r^2/(2 sigma_star)) on r>=0\n",
    "    => alpha=0, p=2, beta=1/(2 sigma_star)\n",
    "    \"\"\"\n",
    "    if sigma_star <= 0:\n",
    "        raise ValueError('sigma_star must be > 0.')\n",
    "    alpha = 0.0\n",
    "    p = 2.0\n",
    "    beta = 1.0 / (2.0 * sigma_star)\n",
    "    return alpha, beta, p\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "def true_delta_r1smg_via_sgg_api(\n",
    "    eps: float,\n",
    "    delta_claimed: float,\n",
    "    T: int,\n",
    "    s: float,\n",
    "    mu_0=None,\n",
    "    mu_1=None,\n",
    ") -> dict:\n",
    "    \"\"\"Calibrate R1SMG with the Ji–Li formula and evaluate its true delta(eps).\n",
    "\n",
    "    Pipeline:\n",
    "      1. psi = ji_li_psi(T, delta_claimed)\n",
    "      2. sigma_star = ji_li_sigma_star(eps, delta_claimed, T, s)   (paper claim)\n",
    "      3. (alpha, beta, p) = r1smg_to_sgg_params(sigma_star)\n",
    "      4. delta_true = compute_spherical_generalized_gamma_privacy(\n",
    "             eps, alpha, beta, p, T, mu_0, mu_1)\n",
    "\n",
    "    Args:\n",
    "        eps: epsilon at which the true delta is evaluated (must be > 0).\n",
    "        delta_claimed: delta targeted by the Ji–Li calibration, in (0, 1).\n",
    "        T: passed to the Ji–Li formula (requires T > 2) and to the API.\n",
    "        s: neighbor shift norm used in the calibration (must be > 0).\n",
    "        mu_0, mu_1: optional neighboring means forwarded to the API as float\n",
    "            arrays. When omitted, mu_0 = [0.0] and mu_1 = [s] (a shift of norm s).\n",
    "\n",
    "    Returns:\n",
    "        dict with the inputs, psi, sigma_star_claim, alpha, beta, p and\n",
    "        delta_true_at_eps, all converted to Python scalars.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: from ji_li_psi / ji_li_sigma_star on out-of-range inputs.\n",
    "    \"\"\"\n",
    "    # NOTE(review): the default means are length-1 vectors even though T is\n",
    "    # also passed; this presumably relies on the API depending only on the\n",
    "    # shift mu_1 - mu_0. Confirm against compute_spherical_generalized_gamma_privacy.\n",
    "    mu_0 = np.array([0.0]) if mu_0 is None else np.asarray(mu_0, dtype=float)\n",
    "    mu_1 = np.array([s]) if mu_1 is None else np.asarray(mu_1, dtype=float)\n",
    "\n",
    "    psi = ji_li_psi(T, delta_claimed)\n",
    "    # ji_li_sigma_star recomputes psi internally; calling it (rather than\n",
    "    # reusing psi) keeps its eps/s validation in one place.\n",
    "    sigma_star = ji_li_sigma_star(eps, delta_claimed, T, s)\n",
    "    alpha, beta, p = r1smg_to_sgg_params(sigma_star)\n",
    "\n",
    "    # The API is imported unconditionally in the first code cell, so a failed\n",
    "    # import already raises there; a None-check here could never trigger.\n",
    "    delta_true = float(\n",
    "        compute_spherical_generalized_gamma_privacy(eps, alpha, beta, p, T, mu_0, mu_1)\n",
    "    )\n",
    "    return {\n",
    "        'eps': float(eps),\n",
    "        'delta_claimed': float(delta_claimed),\n",
    "        'T': int(T),\n",
    "        's': float(s),\n",
    "        'psi': float(psi),\n",
    "        'sigma_star_claim': float(sigma_star),\n",
    "        'alpha': float(alpha),\n",
    "        'beta': float(beta),\n",
    "        'p': float(p),\n",
    "        'delta_true_at_eps': float(delta_true),\n",
    "    }\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Example: single configuration\n",
    "Edit the parameters below and run the cell."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "               eps: 1.0\n",
      "     delta_claimed: 1e-06\n",
      "                 T: 128\n",
      "                 s: 1.0\n",
      "               psi: 0.7700556518693528\n",
      "  sigma_star_claim: 2.59721488329433\n",
      "             alpha: 0.0\n",
      "              beta: 0.1925139129673382\n",
      "                 p: 2.0\n",
      " delta_true_at_eps: 0.9823692057633403\n"
     ]
    }
   ],
   "source": [
    "# ---- user parameters ----\n",
    "eps = 1.0\n",
    "delta_claimed = 1e-6\n",
    "T = 128\n",
    "s = 1.0\n",
    "\n",
    "# Calibrate and evaluate this one configuration, then print one field per\n",
    "# line with the key right-aligned in an 18-character column.\n",
    "out = true_delta_r1smg_via_sgg_api(eps=eps, delta_claimed=delta_claimed, T=T, s=s)\n",
    "print('\\n'.join(f'{field:>18s}: {value}' for field, value in out.items()))\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Batch: evaluate multiple eps values\n",
    "The resulting table compares the claimed \\(\\delta\\) with the computed true \\(\\delta(\\varepsilon)\\) at each \\(\\varepsilon\\); its columns can be plotted directly to make a figure."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>eps</th>\n",
       "      <th>delta_claimed</th>\n",
       "      <th>psi</th>\n",
       "      <th>sigma_star_claim</th>\n",
       "      <th>alpha</th>\n",
       "      <th>beta</th>\n",
       "      <th>p</th>\n",
       "      <th>delta_true_at_eps</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>0.1</td>\n",
       "      <td>0.00001</td>\n",
       "      <td>0.798721</td>\n",
       "      <td>25.040031</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.019968</td>\n",
       "      <td>2.0</td>\n",
       "      <td>0.813284</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>1.0</td>\n",
       "      <td>0.00001</td>\n",
       "      <td>0.798721</td>\n",
       "      <td>2.504003</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.199680</td>\n",
       "      <td>2.0</td>\n",
       "      <td>0.983594</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>2.0</td>\n",
       "      <td>0.00001</td>\n",
       "      <td>0.798721</td>\n",
       "      <td>1.252002</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.399361</td>\n",
       "      <td>2.0</td>\n",
       "      <td>0.995020</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>4.0</td>\n",
       "      <td>0.00001</td>\n",
       "      <td>0.798721</td>\n",
       "      <td>0.626001</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.798721</td>\n",
       "      <td>2.0</td>\n",
       "      <td>0.998804</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>8.0</td>\n",
       "      <td>0.00001</td>\n",
       "      <td>0.798721</td>\n",
       "      <td>0.313000</td>\n",
       "      <td>0.0</td>\n",
       "      <td>1.597442</td>\n",
       "      <td>2.0</td>\n",
       "      <td>0.999755</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "   eps  delta_claimed       psi  sigma_star_claim  alpha      beta    p  \\\n",
       "0  0.1        0.00001  0.798721         25.040031    0.0  0.019968  2.0   \n",
       "1  1.0        0.00001  0.798721          2.504003    0.0  0.199680  2.0   \n",
       "2  2.0        0.00001  0.798721          1.252002    0.0  0.399361  2.0   \n",
       "3  4.0        0.00001  0.798721          0.626001    0.0  0.798721  2.0   \n",
       "4  8.0        0.00001  0.798721          0.313000    0.0  1.597442  2.0   \n",
       "\n",
       "   delta_true_at_eps  \n",
       "0           0.813284  \n",
       "1           0.983594  \n",
       "2           0.995020  \n",
       "3           0.998804  \n",
       "4           0.999755  "
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "eps_list = [0.1, 1, 2, 4, 8]\n",
    "delta_claimed = 1e-5\n",
    "T = 128\n",
    "s = 1.0\n",
    "\n",
    "# Collect one result dict per eps and build the DataFrame once. The\n",
    "# comprehension variable is local to the comprehension, so it does not\n",
    "# overwrite the notebook-level `eps` set in the single-configuration cell.\n",
    "rows = [\n",
    "    true_delta_r1smg_via_sgg_api(eps=eps_value, delta_claimed=delta_claimed, T=T, s=s)\n",
    "    for eps_value in eps_list\n",
    "]\n",
    "\n",
    "# pandas (pd) is imported with the other dependencies in the first code cell.\n",
    "batch_df = pd.DataFrame(rows)\n",
    "batch_df[['eps', 'delta_claimed', 'psi', 'sigma_star_claim', 'alpha', 'beta', 'p', 'delta_true_at_eps']]\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "dp_privl_venv",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
