{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "ed16e2ff",
   "metadata": {},
   "source": [
    "## Part 0: imported packages"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "fafc94bf",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "from itertools import product\n",
    "from scipy.stats import norm\n",
    "from scipy.optimize import brentq\n",
    "import pandas as pd\n",
    "from scipy.stats import truncnorm\n",
    "from tqdm.notebook  import tqdm\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "442a983b",
   "metadata": {},
   "outputs": [],
   "source": [
    "class DataDistribution(object):\n",
    "    def __init__(self, alpha_min, x_max, theta_star, x, rng):\n",
    "        self.alpha_min = alpha_min\n",
    "        self.x_theta = x @ theta_star\n",
    "        self.x_max_theta = x_max @ theta_star\n",
    "        self.b = self.x_max_theta / self.x_theta\n",
    "        self.rng = rng\n",
    "\n",
    "        self.b_tilde = alpha_min / (2 * (1 - alpha_min))\n",
    "        self.a = ((self.b + 3 * self.b_tilde + 1) - np.sqrt((self.b + 3 * self.b_tilde + 1) ** 2 - 4 * (self.b + self.b_tilde + 2 * self.b_tilde * self.b))) / 2\n",
    "        self.c = (2 * self.b_tilde) / (self.a - 1) - 1\n",
    "        self.h = (1 - alpha_min) / (2 * self.x_theta)\n",
    "\n",
    "        self.y_min = - self.x_max_theta\n",
    "        self.y_max = self.x_max_theta\n",
    "\n",
    "        self.pdf_max = max(self.h, self.h * self.c)\n",
    "        self.pdf_min = min(self.h, self.h * self.c)\n",
    "\n",
    "    def get_pdf(self, y):\n",
    "        abs_y = np.abs(y)\n",
    "        if abs_y <= self.x_theta:\n",
    "            return self.h\n",
    "        elif abs_y < self.a * self.x_theta:\n",
    "            frac = (abs_y - self.x_theta) / (self.a * self.x_theta - self.x_theta)\n",
    "            return (1-frac) * self.h + frac * self.h * self.c\n",
    "        elif abs_y <= self.x_max_theta:\n",
    "            return self.h * self.c\n",
    "        else:\n",
    "            return 0\n",
    "\n",
    "    def _get_cdf_pos(self, y):\n",
    "        if y < self.x_theta:\n",
    "            return self.h * y + 1 / 2\n",
    "        elif y < self.a * self.x_theta:\n",
    "            s = 1 / 2 + self.h * self.x_theta\n",
    "            frac = (y - self.x_theta) / (self.a * self.x_theta - self.x_theta)\n",
    "\n",
    "            h1 = self.h\n",
    "            h2 = (1-frac) * self.h + frac * self.h * self.c\n",
    "            area  = (h1 + h2) / 2 * (y - self.x_theta)\n",
    "            return s + area\n",
    "        elif y < self.x_max_theta:\n",
    "            return 1 - (self.x_max_theta - y) * self.c * self.h\n",
    "        else:\n",
    "            return 1\n",
    "    \n",
    "    def get_cdf(self, y):\n",
    "        \"\"\"\n",
    "        Cumulative Distribution Function for the symmetric piecewise linear PDF.\n",
    "        The PDF is symmetric around 0, so we compute CDF for y >= 0 and handle negative y by symmetry.\n",
    "        \"\"\"\n",
    "        if y >= 0:\n",
    "            return self._get_cdf_pos(y)\n",
    "        else:\n",
    "            return 1 - self._get_cdf_pos(-y)\n",
    "\n",
    "    def sample(self):\n",
    "        f_max = self.pdf_max\n",
    "        g_max = 1 / (self.y_max - self.y_min) # uniform distribution\n",
    "        M = f_max / g_max\n",
    "        while True:\n",
    "            u = self.rng.uniform(low=0, high=1, size=1)\n",
    "            x = self.rng.uniform(low=self.y_min, high=self.y_max, size=1)\n",
    "            if u <= self.get_pdf(x) / (g_max * (self.y_max - self.y_min)):\n",
    "                return x\n",
    "\n",
    "    def quantile_interval(self, gamma_lo, gamma_hi):\n",
    "        q1 = brentq(lambda x: self.get_cdf(x) - gamma_lo, -100, 100, xtol=1e-12, rtol=1e-10, maxiter=200)\n",
    "        q2 = brentq(lambda x: self.get_cdf(x) - gamma_hi, -100, 100, xtol=1e-12, rtol=1e-10, maxiter=200)\n",
    "        return q2 - q1\n",
    "\n",
    "\n",
    "x = np.linspace(-20, 20, 500)\n",
    "data = DataDistribution(0.05, np.array([20, 20]), np.array([2, 2]), np.array([1, 1]), rng=np.random.default_rng(42))\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "912a3141",
   "metadata": {},
   "outputs": [],
   "source": [
    "class Dataset(object):\n",
    "    def __init__(self, d, alpha_min, rng):\n",
    "        self.d = d\n",
    "        self.rng = rng\n",
    "        self.alpha_min = alpha_min \n",
    "        self._X_min = 1\n",
    "        self._X_max = 20\n",
    "        self._theta_min = 1\n",
    "        self._theta_max = 2\n",
    "\n",
    "        self.theta = self.rng.uniform(low=self._theta_min, high=self._theta_max, size=(d,))\n",
    "        self.x_max = np.array([self._X_max, self._X_max])\n",
    "\n",
    "    def get_data_distribution(self, x):\n",
    "        return DataDistribution(self.alpha_min, self.x_max, self.theta, x=x, rng=self.rng)\n",
    "\n",
    "    def generate_X_y(self, n):\n",
    "        X = self.rng.uniform(low=self._X_min, high=self._X_max, size=(n, self.d))\n",
    "        distributions = [self.get_data_distribution(x_i) for x_i in X]\n",
    "        y = np.array([distribution.sample() for distribution in distributions]).flatten()\n",
    "        return X, y, distributions\n",
    "\n",
    "dataset = Dataset(2, 0.05, np.random.default_rng(42))\n",
    "X, y, _ = dataset.generate_X_y(100)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "84dc8df0",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "((100, 2), (100,))"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "X.shape, y.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "37c76dff",
   "metadata": {},
   "outputs": [],
   "source": [
    "# -------------------------- Quantile SGD (symmetric) --------------------------\n",
    "def pinball_loss(x, gamma):\n",
    "    return np.where(x >= 0, gamma * x, (gamma - 1) * x)\n",
    "\n",
    "def pinball_grad(x, gamma):\n",
    "    return np.where(x < 0, gamma-1, gamma)\n",
    "\n",
    "def sgd_quantile(X, Y, gamma, epochs=400, batch_size=64, lr0=0.00001, lr_decay=0.99, rng=None):\n",
    "    n, d = X.shape\n",
    "    theta = np.zeros(d) \n",
    "    theta_list = [theta.copy()]\n",
    "    loss_list = [np.mean(pinball_loss(Y - X @ theta, gamma))]\n",
    "    for ep in range(epochs):\n",
    "        idx = rng.permutation(n)\n",
    "        X_ep, Y_ep = X[idx], Y[idx]\n",
    "        lr = lr0 * (lr_decay**ep)\n",
    "        for start in range(0, n, batch_size):\n",
    "            stop = min(start+batch_size, n)\n",
    "            xb, yb = X_ep[start:stop], Y_ep[start:stop]\n",
    "            ghat = -xb.T @ pinball_grad(yb - xb @ theta, gamma) / (stop-start)\n",
    "            theta -= lr * ghat\n",
    "        theta_list.append(theta.copy())\n",
    "        loss_list.append(np.mean(pinball_loss(Y - X @ theta, gamma)))\n",
    "    return theta, theta_list, np.array(loss_list)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "2610a741",
   "metadata": {},
   "outputs": [],
   "source": [
    "# -------------------------- CQR utilities --------------------------\n",
    "def cqr_nc_score(X, Y, theta_lo, theta_hi):\n",
    "    t_low  = X @ theta_lo\n",
    "    t_high = X @ theta_hi\n",
    "    return np.maximum(t_low - Y, Y - t_high)\n",
    "\n",
    "def calibration_threshold(X_cal, y_cal, theta_lo, theta_hi, alpha):\n",
    "    m = len(y_cal)\n",
    "    nc_scores = cqr_nc_score(X_cal, y_cal, theta_lo, theta_hi)\n",
    "    k = int(np.ceil((1 - alpha) * (m + 1)))\n",
    "    cqr_threshold = np.partition(nc_scores, k-1)[k-1]\n",
    "    return cqr_threshold\n",
    "\n",
    "def excess_length(X, dataset, theta_n_lo, theta_n_hi, q_hat):\n",
    "    C_lo = X @ theta_n_lo - q_hat\n",
    "    C_up = X @ theta_n_hi + q_hat\n",
    "    true_interval = X @ dataset.theta_hi - X @ dataset.theta_lo\n",
    "    excess_len = (C_up - C_lo) - true_interval\n",
    "    return  true_interval, excess_len\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "677a7d2e",
   "metadata": {},
   "outputs": [],
   "source": [
    "def successive_halving_lr0_tuning(X_train, y_train, gamma, rng, epochs=1, \n",
    "                                 lr0_candidates=None, n_iterations=3, \n",
    "                                 eta=2, min_budget=1, debug=False):\n",
    "    \"\"\"\n",
    "    Successive Halving algorithm for tuning lr0 hyperparameter.\n",
    "    \n",
    "    Parameters:\n",
    "    -----------\n",
    "    X_train, y_train : training data\n",
    "    gamma : quantile level\n",
    "    rng : random number generator\n",
    "    epochs : number of training epochs\n",
    "    lr0_candidates : list of learning rates to try (if None, will generate automatically)\n",
    "    n_iterations : number of successive halving iterations\n",
    "    eta : elimination factor (how many candidates to eliminate each round)\n",
    "    min_budget : minimum number of candidates to keep\n",
    "    \n",
    "    Returns:\n",
    "    --------\n",
    "    best_lr0 : best learning rate found\n",
    "    best_loss : corresponding loss value\n",
    "    tuning_history : list of (iteration, candidates, losses) for each round\n",
    "    \"\"\"\n",
    "    def pprint(text):\n",
    "        if debug:\n",
    "            print(text)\n",
    "    \n",
    "    # Generate learning rate candidates if not provided\n",
    "    if lr0_candidates is None:\n",
    "        # Use log-uniform distribution for learning rates\n",
    "        lr0_candidates = np.logspace(-5, 0, 20)  # From 1e-5 to 1e-1\n",
    "    \n",
    "    tuning_history = []\n",
    "    current_candidates = lr0_candidates.copy()\n",
    "    \n",
    "    pprint(f\"Starting Successive Halving with {len(current_candidates)} candidates\")\n",
    "    pprint(f\"Learning rate range: [{min(current_candidates):.2e}, {max(current_candidates):.2e}]\")\n",
    "    \n",
    "    for iteration in range(n_iterations):\n",
    "        pprint(f\"\\n--- Iteration {iteration + 1} ---\")\n",
    "        pprint(f\"Evaluating {len(current_candidates)} candidates\")\n",
    "        \n",
    "        # Evaluate all current candidates\n",
    "        candidate_losses = []\n",
    "        \n",
    "        for i, lr0 in enumerate(current_candidates):\n",
    "            # Train model with current lr0\n",
    "            theta, _, loss_list = sgd_quantile(X_train, y_train, gamma=gamma, \n",
    "                                          rng=rng, epochs=epochs, lr0=lr0)\n",
    "            \n",
    "            # Use the last (final) loss value as evaluation metric\n",
    "            final_loss = loss_list[-1]\n",
    "            candidate_losses.append(final_loss)\n",
    "            \n",
    "            pprint(f\"  lr0={lr0:.2e}: final_loss={final_loss:.6f}\")\n",
    "        \n",
    "        # Store history\n",
    "        tuning_history.append({\n",
    "            'iteration': iteration + 1,\n",
    "            'candidates': current_candidates.copy(),\n",
    "            'losses': candidate_losses.copy()\n",
    "        })\n",
    "        \n",
    "        # Select top candidates for next round\n",
    "        if iteration < n_iterations - 1:  # Don't eliminate on last iteration\n",
    "            # Sort by loss (lower is better for quantile regression)\n",
    "            sorted_indices = np.argsort(candidate_losses)\n",
    "            \n",
    "            # Keep top candidates (eliminate worst ones)\n",
    "            n_keep = max(min_budget, len(current_candidates) // eta)\n",
    "            top_indices = sorted_indices[:n_keep]\n",
    "            \n",
    "            current_candidates = current_candidates[top_indices]\n",
    "            pprint(f\"Kept top {len(current_candidates)} candidates for next round\")\n",
    "        else:\n",
    "            best_idx = np.argmin(candidate_losses)\n",
    "            best_lr0 = current_candidates[best_idx]\n",
    "            best_loss = candidate_losses[best_idx]\n",
    "            \n",
    "            pprint(f\"\\nFinal best: lr0={best_lr0:.2e}, loss={best_loss:.6f}\")\n",
    "    \n",
    "    return theta, best_lr0, best_loss, tuning_history\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0099e3d0",
   "metadata": {},
   "source": [
    "## Part 2: Fix m = 5000, vary n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1bdec2d9",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "n: 1000, m: 5000, alpha: 0.075 seed: 3 excess_len_avg: 2.854841222514091\n",
      "n: 200, m: 5000, alpha: 0.075 seed: 6 excess_len_avg: 48.546809736430546\n",
      "n: 1000, m: 5000, alpha: 0.075 seed: 6 excess_len_avg: 4.150124548586544\n",
      "n: 200, m: 5000, alpha: 0.075 seed: 1 excess_len_avg: 19.613843725598212\n",
      "n: 200, m: 5000, alpha: 0.075 seed: 4 excess_len_avg: 34.35296620633268\n",
      "n: 200, m: 5000, alpha: 0.075 seed: 3 excess_len_avg: 14.636042016978921\n",
      "n: 200, m: 5000, alpha: 0.075 seed: 5 excess_len_avg: 25.75499325626206\n",
      "n: 4000, m: 5000, alpha: 0.075 seed: 3 excess_len_avg: 0.6117903333956047\n",
      "n: 200, m: 5000, alpha: 0.075 seed: 2 excess_len_avg: 20.030606184119495\n",
      "n: 1000, m: 5000, alpha: 0.075 seed: 5 excess_len_avg: 4.816972939162153\n",
      "n: 4000, m: 5000, alpha: 0.075 seed: 2 excess_len_avg: 2.165711859941424\n",
      "n: 1000, m: 5000, alpha: 0.075 seed: 2 excess_len_avg: 0.6873941323744596\n",
      "n: 200, m: 5000, alpha: 0.075 seed: 0 excess_len_avg: 41.65827089286385\n",
      "n: 4000, m: 5000, alpha: 0.075 seed: 0 excess_len_avg: 4.652898729599024\n",
      "n: 1000, m: 5000, alpha: 0.075 seed: 0 excess_len_avg: 0.7919961447478708\n",
      "n: 4000, m: 5000, alpha: 0.075 seed: 5 excess_len_avg: 5.360305752908871\n",
      "n: 2000, m: 5000, alpha: 0.075 seed: 3 excess_len_avg: 2.7253724980757554\n",
      "n: 1000, m: 5000, alpha: 0.075 seed: 4 excess_len_avg: 7.766193395145059\n",
      "n: 1000, m: 5000, alpha: 0.075 seed: 1 excess_len_avg: 6.8424610383173405\n",
      "n: 4000, m: 5000, alpha: 0.075 seed: 1 excess_len_avg: 2.6463669518603705\n",
      "n: 500, m: 5000, alpha: 0.075 seed: 3 excess_len_avg: 45.29628579507173\n",
      "n: 500, m: 5000, alpha: 0.075 seed: 6 excess_len_avg: 12.849551138678876\n",
      "n: 8000, m: 5000, alpha: 0.075 seed: 0 excess_len_avg: 0.48173743727387147\n",
      "n: 500, m: 5000, alpha: 0.075 seed: 0 excess_len_avg: 20.816823760318503\n",
      "n: 500, m: 5000, alpha: 0.075 seed: 4 excess_len_avg: 17.970246042145156\n",
      "n: 2000, m: 5000, alpha: 0.075 seed: 6 excess_len_avg: 1.497934868764338\n",
      "n: 500, m: 5000, alpha: 0.075 seed: 5 excess_len_avg: 17.513328411071296\n",
      "n: 500, m: 5000, alpha: 0.075 seed: 2 excess_len_avg: 33.69414354660706\n",
      "n: 8000, m: 5000, alpha: 0.075 seed: 3 excess_len_avg: 1.1433915990797645\n",
      "n: 8000, m: 5000, alpha: 0.075 seed: 5 excess_len_avg: 0.5644701995844453\n",
      "n: 4000, m: 5000, alpha: 0.075 seed: 4 excess_len_avg: 2.3839823811471597\n",
      "n: 8000, m: 5000, alpha: 0.075 seed: 2 excess_len_avg: 1.4611007444614001\n",
      "n: 2000, m: 5000, alpha: 0.075 seed: 5 excess_len_avg: 6.970367505862316\n",
      "n: 2000, m: 5000, alpha: 0.075 seed: 0 excess_len_avg: 3.912342138916387\n",
      "n: 500, m: 5000, alpha: 0.075 seed: 1 excess_len_avg: 30.07721945876236\n",
      "n: 6000, m: 5000, alpha: 0.075 seed: 2 excess_len_avg: 0.8877037473803461\n",
      "n: 2000, m: 5000, alpha: 0.075 seed: 2 excess_len_avg: 2.419808094846222\n",
      "n: 6000, m: 5000, alpha: 0.075 seed: 0 excess_len_avg: 1.4055466733891686\n",
      "n: 6000, m: 5000, alpha: 0.075 seed: 3 excess_len_avg: 0.2159317115574578\n",
      "n: 4000, m: 5000, alpha: 0.075 seed: 6 excess_len_avg: 2.4406112410986514\n",
      "n: 200, m: 5000, alpha: 0.075 seed: 7 excess_len_avg: 84.2993257714666\n",
      "n: 15000, m: 5000, alpha: 0.075 seed: 4 excess_len_avg: 1.3433883315402324\n",
      "n: 8000, m: 5000, alpha: 0.075 seed: 1 excess_len_avg: 4.41279450013958\n",
      "n: 15000, m: 5000, alpha: 0.075 seed: 3 excess_len_avg: 0.3475674606425873\n",
      "n: 8000, m: 5000, alpha: 0.075 seed: 4 excess_len_avg: 0.895588138337209\n",
      "n: 15000, m: 5000, alpha: 0.075 seed: 5 excess_len_avg: 0.43208688823199465\n",
      "n: 1000, m: 5000, alpha: 0.075 seed: 7 excess_len_avg: 6.133566381264021\n",
      "n: 1000, m: 5000, alpha: 0.075 seed: 8 excess_len_avg: 14.605641951370671\n",
      "n: 15000, m: 5000, alpha: 0.075 seed: 2 excess_len_avg: 2.4086772049776344\n",
      "n: 200, m: 5000, alpha: 0.075 seed: 8 excess_len_avg: 45.63543019285262\n",
      "n: 6000, m: 5000, alpha: 0.075 seed: 5 excess_len_avg: 1.1672720530354441\n",
      "n: 2000, m: 5000, alpha: 0.075 seed: 4 excess_len_avg: 13.421214956065983\n",
      "n: 2000, m: 5000, alpha: 0.075 seed: 1 excess_len_avg: 3.4220915605887465\n",
      "n: 4000, m: 5000, alpha: 0.075 seed: 7 excess_len_avg: 2.557665838372578\n",
      "n: 200, m: 5000, alpha: 0.075 seed: 9 excess_len_avg: 7.339494824365353\n",
      "n: 10000, m: 5000, alpha: 0.075 seed: 0 excess_len_avg: 1.3188227365701277\n",
      "n: 15000, m: 5000, alpha: 0.075 seed: 0 excess_len_avg: 0.9566640385835244\n",
      "n: 1000, m: 5000, alpha: 0.075 seed: 9 excess_len_avg: 4.198256757977445\n",
      "n: 500, m: 5000, alpha: 0.075 seed: 7 excess_len_avg: 13.672903268266715\n",
      "n: 8000, m: 5000, alpha: 0.075 seed: 6 excess_len_avg: 1.1828241340433847\n",
      "n: 8000, m: 5000, alpha: 0.075 seed: 7 excess_len_avg: 4.209908355937132\n",
      "n: 6000, m: 5000, alpha: 0.075 seed: 1 excess_len_avg: 0.5328749419657757\n",
      "n: 500, m: 5000, alpha: 0.075 seed: 8 excess_len_avg: 9.211149950916148\n",
      "n: 15000, m: 5000, alpha: 0.075 seed: 1 excess_len_avg: 0.6330411234284207\n",
      "n: 8000, m: 5000, alpha: 0.075 seed: 8 excess_len_avg: 0.9180253538440493\n",
      "n: 2000, m: 5000, alpha: 0.075 seed: 7 excess_len_avg: 2.8286378789667053\n",
      "n: 2000, m: 5000, alpha: 0.075 seed: 8 excess_len_avg: 2.78203525881359\n",
      "n: 10000, m: 5000, alpha: 0.075 seed: 5 excess_len_avg: 1.9622115612429212\n",
      "n: 10000, m: 5000, alpha: 0.075 seed: 3 excess_len_avg: 2.4934512482649134\n",
      "n: 4000, m: 5000, alpha: 0.075 seed: 8 excess_len_avg: 3.4814780712232634\n",
      "n: 500, m: 5000, alpha: 0.075 seed: 9 excess_len_avg: 46.33141349032414\n",
      "n: 15000, m: 5000, alpha: 0.075 seed: 6 excess_len_avg: 0.9186819474702645\n",
      "n: 6000, m: 5000, alpha: 0.075 seed: 4 excess_len_avg: 0.5444040641243654\n",
      "n: 4000, m: 5000, alpha: 0.075 seed: 9 excess_len_avg: 2.5750277062200633\n",
      "n: 1000, m: 5000, alpha: 0.075 seed: 10 excess_len_avg: 5.588788178219989\n",
      "n: 6000, m: 5000, alpha: 0.075 seed: 7 excess_len_avg: 0.6050430890606826\n",
      "n: 10000, m: 5000, alpha: 0.075 seed: 2 excess_len_avg: 0.5619024565809007\n",
      "n: 2000, m: 5000, alpha: 0.075 seed: 9 excess_len_avg: 2.156842133127234\n",
      "n: 6000, m: 5000, alpha: 0.075 seed: 6 excess_len_avg: 1.4131970986484321\n",
      "n: 200, m: 5000, alpha: 0.075 seed: 10 excess_len_avg: 112.15115265014978\n",
      "n: 200, m: 5000, alpha: 0.075 seed: 11 excess_len_avg: 24.81912137069238\n",
      "n: 4000, m: 5000, alpha: 0.075 seed: 11 excess_len_avg: 6.163892279623778\n",
      "n: 200, m: 5000, alpha: 0.075 seed: 12 excess_len_avg: 30.43067541539736\n",
      "n: 1000, m: 5000, alpha: 0.075 seed: 11 excess_len_avg: 4.5303603054739225\n",
      "n: 15000, m: 5000, alpha: 0.075 seed: 8 excess_len_avg: 0.20988710509149208\n",
      "n: 1000, m: 5000, alpha: 0.075 seed: 12 excess_len_avg: 4.328218577484508\n",
      "n: 2000, m: 5000, alpha: 0.075 seed: 10 excess_len_avg: 6.970994437383122\n",
      "n: 500, m: 5000, alpha: 0.075 seed: 11 excess_len_avg: 37.28068129450876\n",
      "n: 15000, m: 5000, alpha: 0.075 seed: 9 excess_len_avg: 1.208130916231974\n",
      "n: 8000, m: 5000, alpha: 0.075 seed: 10 excess_len_avg: 2.8301136859268254\n",
      "n: 8000, m: 5000, alpha: 0.075 seed: 11 excess_len_avg: 2.0793637270173764n: 6000, m: 5000, alpha: 0.075 seed: 9 excess_len_avg: 2.614364941744644\n",
      "\n",
      "n: 8000, m: 5000, alpha: 0.075 seed: 9 excess_len_avg: 1.0267321797179083\n",
      "n: 10000, m: 5000, alpha: 0.075 seed: 6 excess_len_avg: 0.3145689991056404\n",
      "n: 4000, m: 5000, alpha: 0.075 seed: 10 excess_len_avg: 0.6359817672502666\n",
      "n: 4000, m: 5000, alpha: 0.075 seed: 12 excess_len_avg: 2.63034392782755\n",
      "n: 500, m: 5000, alpha: 0.075 seed: 10 excess_len_avg: 1.9049094869530376\n",
      "n: 10000, m: 5000, alpha: 0.075 seed: 7 excess_len_avg: 1.7118446152298903\n",
      "n: 10000, m: 5000, alpha: 0.075 seed: 1 excess_len_avg: 0.7020629411265563\n",
      "n: 10000, m: 5000, alpha: 0.075 seed: 4 excess_len_avg: 2.0626500107379333\n",
      "n: 15000, m: 5000, alpha: 0.075 seed: 7 excess_len_avg: 0.24562551383910244\n",
      "n: 10000, m: 5000, alpha: 0.075 seed: 8 excess_len_avg: 1.527898846369781\n",
      "n: 500, m: 5000, alpha: 0.075 seed: 12 excess_len_avg: 4.304112284202994\n",
      "n: 20000, m: 5000, alpha: 0.075 seed: 3 excess_len_avg: 0.21942935992170207\n",
      "n: 20000, m: 5000, alpha: 0.075 seed: 5 excess_len_avg: 1.6612095774009326\n",
      "n: 2000, m: 5000, alpha: 0.075 seed: 11 excess_len_avg: 2.9580503500933597\n",
      "n: 6000, m: 5000, alpha: 0.075 seed: 11 excess_len_avg: 1.3140203962537353\n",
      "n: 2000, m: 5000, alpha: 0.075 seed: 12 excess_len_avg: 2.4743808798815903\n",
      "n: 6000, m: 5000, alpha: 0.075 seed: 8 excess_len_avg: 5.08593299852671\n",
      "n: 15000, m: 5000, alpha: 0.075 seed: 10 excess_len_avg: 1.6008718023742121\n",
      "n: 200, m: 5000, alpha: 0.075 seed: 13 excess_len_avg: 54.26417062616996\n",
      "n: 1000, m: 5000, alpha: 0.075 seed: 13 excess_len_avg: 8.454840917590111\n",
      "n: 20000, m: 5000, alpha: 0.075 seed: 2 excess_len_avg: 2.069008492783988\n",
      "n: 20000, m: 5000, alpha: 0.075 seed: 4 excess_len_avg: 1.7947145957995934\n",
      "n: 1000, m: 5000, alpha: 0.075 seed: 14 excess_len_avg: 1.9105069185837715\n",
      "n: 15000, m: 5000, alpha: 0.075 seed: 11 excess_len_avg: 0.3360118562151282\n",
      "n: 8000, m: 5000, alpha: 0.075 seed: 12 excess_len_avg: 1.572891525332168\n",
      "n: 200, m: 5000, alpha: 0.075 seed: 15 excess_len_avg: 42.32679319905759\n",
      "n: 200, m: 5000, alpha: 0.075 seed: 14 excess_len_avg: 50.12122615381032\n",
      "n: 4000, m: 5000, alpha: 0.075 seed: 13 excess_len_avg: 0.922404819771731\n",
      "n: 4000, m: 5000, alpha: 0.075 seed: 14 excess_len_avg: 3.4838932200908292\n",
      "n: 6000, m: 5000, alpha: 0.075 seed: 12 excess_len_avg: 1.3384823963021268\n",
      "n: 1000, m: 5000, alpha: 0.075 seed: 16 excess_len_avg: 17.805221399419096\n",
      "n: 6000, m: 5000, alpha: 0.075 seed: 10 excess_len_avg: 1.4300502737450984\n",
      "n: 200, m: 5000, alpha: 0.075 seed: 16 excess_len_avg: 27.955630376198364\n",
      "n: 10000, m: 5000, alpha: 0.075 seed: 10 excess_len_avg: 1.0280122958531124\n",
      "n: 500, m: 5000, alpha: 0.075 seed: 13 excess_len_avg: 5.750879327271544\n",
      "n: 20000, m: 5000, alpha: 0.075 seed: 0 excess_len_avg: 4.058128186132364\n",
      "n: 4000, m: 5000, alpha: 0.075 seed: 15 excess_len_avg: 9.39075122729355n: 8000, m: 5000, alpha: 0.075 seed: 14 excess_len_avg: 1.3289135186712906\n",
      "\n",
      "n: 1000, m: 5000, alpha: 0.075 seed: 15 excess_len_avg: 1.2665858110616395\n",
      "n: 15000, m: 5000, alpha: 0.075 seed: 12 excess_len_avg: 1.0798058051771575\n",
      "n: 10000, m: 5000, alpha: 0.075 seed: 11 excess_len_avg: 1.4050382338690917\n",
      "n: 2000, m: 5000, alpha: 0.075 seed: 14 excess_len_avg: 3.702064960167198\n",
      "n: 2000, m: 5000, alpha: 0.075 seed: 16 excess_len_avg: 4.615362360948502\n",
      "n: 8000, m: 5000, alpha: 0.075 seed: 15 excess_len_avg: 0.3769046893897947\n",
      "n: 500, m: 5000, alpha: 0.075 seed: 16 excess_len_avg: 12.923832349941422\n",
      "n: 15000, m: 5000, alpha: 0.075 seed: 14 excess_len_avg: 1.4688563345371235\n",
      "n: 20000, m: 5000, alpha: 0.075 seed: 6 excess_len_avg: 1.8062584954985756\n",
      "n: 500, m: 5000, alpha: 0.075 seed: 14 excess_len_avg: 6.482253640255607\n",
      "n: 2000, m: 5000, alpha: 0.075 seed: 13 excess_len_avg: 9.445266964957398\n",
      "n: 4000, m: 5000, alpha: 0.075 seed: 16 excess_len_avg: 1.0844762506244068\n",
      "n: 10000, m: 5000, alpha: 0.075 seed: 9 excess_len_avg: 2.445150282526446\n",
      "n: 500, m: 5000, alpha: 0.075 seed: 15 excess_len_avg: 63.24751246549652\n",
      "n: 20000, m: 5000, alpha: 0.075 seed: 9 excess_len_avg: 1.499527694128264\n",
      "n: 200, m: 5000, alpha: 0.075 seed: 17 excess_len_avg: 27.588746861667495\n",
      "n: 1000, m: 5000, alpha: 0.075 seed: 17 excess_len_avg: 2.8650441220942158\n",
      "n: 8000, m: 5000, alpha: 0.075 seed: 13 excess_len_avg: 0.261100591020827\n",
      "n: 6000, m: 5000, alpha: 0.075 seed: 14 excess_len_avg: 1.6872842027351818\n",
      "n: 20000, m: 5000, alpha: 0.075 seed: 1 excess_len_avg: 0.2899531365573761\n",
      "n: 6000, m: 5000, alpha: 0.075 seed: 13 excess_len_avg: 1.635677780847192\n",
      "n: 200, m: 5000, alpha: 0.075 seed: 18 excess_len_avg: 15.345580945146203\n",
      "n: 2000, m: 5000, alpha: 0.075 seed: 15 excess_len_avg: 1.8963162226255\n",
      "n: 6000, m: 5000, alpha: 0.075 seed: 15 excess_len_avg: 0.9784685543952937\n",
      "n: 1000, m: 5000, alpha: 0.075 seed: 19 excess_len_avg: 2.376463725184347\n",
      "n: 20000, m: 5000, alpha: 0.075 seed: 11 excess_len_avg: 1.1255398814023188\n",
      "n: 2000, m: 5000, alpha: 0.075 seed: 17 excess_len_avg: 4.527777008391681\n",
      "n: 4000, m: 5000, alpha: 0.075 seed: 19 excess_len_avg: 4.1311372183627695\n",
      "n: 15000, m: 5000, alpha: 0.075 seed: 16 excess_len_avg: 0.4206515281656193\n",
      "n: 20000, m: 5000, alpha: 0.075 seed: 8 excess_len_avg: 0.9234814568410814\n",
      "n: 500, m: 5000, alpha: 0.075 seed: 17 excess_len_avg: 17.1633266527459\n",
      "n: 6000, m: 5000, alpha: 0.075 seed: 16 excess_len_avg: 3.324540818089985\n",
      "n: 10000, m: 5000, alpha: 0.075 seed: 14 excess_len_avg: 0.7280329326339136\n",
      "n: 1000, m: 5000, alpha: 0.075 seed: 18 excess_len_avg: 10.087841760900616\n",
      "n: 200, m: 5000, alpha: 0.075 seed: 19 excess_len_avg: 72.34108003081826\n",
      "n: 8000, m: 5000, alpha: 0.075 seed: 16 excess_len_avg: 1.928499633727047n: 4000, m: 5000, alpha: 0.075 seed: 17 excess_len_avg: 0.6009664178005724\n",
      "\n",
      "n: 10000, m: 5000, alpha: 0.075 seed: 12 excess_len_avg: 2.922380895900454\n",
      "n: 500, m: 5000, alpha: 0.075 seed: 18 excess_len_avg: 6.9394323974362955\n",
      "n: 10000, m: 5000, alpha: 0.075 seed: 15 excess_len_avg: 1.9395599116926319\n",
      "n: 15000, m: 5000, alpha: 0.075 seed: 13 excess_len_avg: 1.0427762524905961\n",
      "n: 4000, m: 5000, alpha: 0.075 seed: 18 excess_len_avg: 2.54128078465286\n",
      "n: 2000, m: 5000, alpha: 0.075 seed: 19 excess_len_avg: 11.706891215780862\n",
      "n: 10000, m: 5000, alpha: 0.075 seed: 13 excess_len_avg: 1.3162792645123242\n",
      "n: 20000, m: 5000, alpha: 0.075 seed: 10 excess_len_avg: 0.22307748029405888\n",
      "n: 15000, m: 5000, alpha: 0.075 seed: 17 excess_len_avg: 9.663696259942219\n",
      "n: 8000, m: 5000, alpha: 0.075 seed: 17 excess_len_avg: 1.9010620003476986\n",
      "n: 6000, m: 5000, alpha: 0.075 seed: 19 excess_len_avg: 2.4598107709447534\n",
      "n: 15000, m: 5000, alpha: 0.075 seed: 19 excess_len_avg: 0.3980906321745869\n",
      "n: 15000, m: 5000, alpha: 0.075 seed: 15 excess_len_avg: 0.6001450240166489\n",
      "n: 2000, m: 5000, alpha: 0.075 seed: 18 excess_len_avg: 2.0974757599059695\n",
      "n: 500, m: 5000, alpha: 0.075 seed: 19 excess_len_avg: 17.290532969722133\n",
      "n: 20000, m: 5000, alpha: 0.075 seed: 12 excess_len_avg: 1.576792933751031\n",
      "n: 8000, m: 5000, alpha: 0.075 seed: 18 excess_len_avg: 0.6319640887405342\n",
      "n: 8000, m: 5000, alpha: 0.075 seed: 19 excess_len_avg: 0.5162306147209422\n",
      "n: 6000, m: 5000, alpha: 0.075 seed: 18 excess_len_avg: 0.8772640430191252\n",
      "n: 6000, m: 5000, alpha: 0.075 seed: 17 excess_len_avg: 1.563269496835053\n",
      "n: 20000, m: 5000, alpha: 0.075 seed: 7 excess_len_avg: 8.431268933970168\n",
      "n: 20000, m: 5000, alpha: 0.075 seed: 14 excess_len_avg: 1.2765416445784623\n",
      "n: 20000, m: 5000, alpha: 0.075 seed: 16 excess_len_avg: 0.6148250129177273\n",
      "n: 15000, m: 5000, alpha: 0.075 seed: 18 excess_len_avg: 3.475787002482083\n",
      "n: 10000, m: 5000, alpha: 0.075 seed: 16 excess_len_avg: 0.39225160005211934\n",
      "n: 10000, m: 5000, alpha: 0.075 seed: 18 excess_len_avg: 0.8463398381753785\n",
      "n: 20000, m: 5000, alpha: 0.075 seed: 17 excess_len_avg: 3.18143950557126\n",
      "n: 10000, m: 5000, alpha: 0.075 seed: 17 excess_len_avg: 1.9826594350940032\n",
      "n: 10000, m: 5000, alpha: 0.075 seed: 19 excess_len_avg: 0.43497905813477483\n",
      "n: 20000, m: 5000, alpha: 0.075 seed: 13 excess_len_avg: 0.16004742065232624\n",
      "n: 20000, m: 5000, alpha: 0.075 seed: 19 excess_len_avg: 0.22500158985593868\n",
      "n: 20000, m: 5000, alpha: 0.075 seed: 15 excess_len_avg: 1.1294044793855555\n",
      "n: 20000, m: 5000, alpha: 0.075 seed: 18 excess_len_avg: 1.045672714923243\n"
     ]
    }
   ],
   "source": [
    "import json\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "from multiprocessing import Pool\n",
    "from functools import partial\n",
    "import multiprocessing\n",
    "\n",
    "def process_single_run(seed, n, m, alpha, alpha_min=0.005, n_test=2000):\n",
    "    \"\"\"Run one CQR experiment for a single (seed, n, m, alpha) configuration.\n",
    "\n",
    "    A fresh `Dataset` is built from `np.random.default_rng(seed)`; n training,\n",
    "    m calibration and n_test test points are drawn from it. Upper (1 - alpha/2)\n",
    "    and lower (alpha/2) linear quantile models are fitted with\n",
    "    `successive_halving_lr0_tuning`, the conformal offset q_hat is obtained\n",
    "    from `calibration_threshold` on the calibration split, and the band is\n",
    "    evaluated on the test split.\n",
    "\n",
    "    Args:\n",
    "        seed: seed for the run's random generator.\n",
    "        n: number of training points.\n",
    "        m: number of calibration points.\n",
    "        alpha: target miscoverage level; must exceed `alpha_min`.\n",
    "        alpha_min: `alpha_min` passed to `Dataset` (default 0.005).\n",
    "        n_test: number of test points (default 2000).\n",
    "\n",
    "    Returns:\n",
    "        dict with keys n, alpha, seed, q_hat, excess_len_avg, coverage and m,\n",
    "        or None if data generation raised an exception.\n",
    "    \"\"\"\n",
    "    rng = np.random.default_rng(seed)\n",
    "    assert alpha > alpha_min\n",
    "    dataset = Dataset(2, alpha_min, rng)\n",
    "\n",
    "    try:\n",
    "        X_train, y_train, _ = dataset.generate_X_y(n=n)\n",
    "        X_cal, y_cal, _ = dataset.generate_X_y(n=m)\n",
    "        X_test, y_test, test_distributions = dataset.generate_X_y(n=n_test)\n",
    "    except Exception as e:\n",
    "        # Best effort: report and skip this configuration instead of aborting the sweep.\n",
    "        print(f\"Error (seed={seed}, n={n}, m={m}, alpha={alpha}): {e}\")\n",
    "        return None\n",
    "\n",
    "    # Tuning settings shared by both quantile fits. Both fits consume the same\n",
    "    # `rng`; the upper fit runs first, as before.\n",
    "    tuning_kwargs = dict(\n",
    "        X_train=X_train,\n",
    "        y_train=y_train,\n",
    "        rng=rng,\n",
    "        epochs=1,\n",
    "        n_iterations=3,  # number of elimination rounds\n",
    "        eta=2,  # eliminate half the candidates each round\n",
    "        debug=False,\n",
    "    )\n",
    "    theta_n_hi, _, _, _ = successive_halving_lr0_tuning(gamma=1 - alpha / 2, **tuning_kwargs)\n",
    "    theta_n_lo, _, _, _ = successive_halving_lr0_tuning(gamma=alpha / 2, **tuning_kwargs)\n",
    "\n",
    "    # Calibration\n",
    "    q_hat = calibration_threshold(X_cal, y_cal, theta_n_lo, theta_n_hi, alpha=alpha)\n",
    "\n",
    "    # Excess length of the fitted quantile band over the oracle interval.\n",
    "    # NOTE(review): assumes quantile_interval returns one scalar (the interval length)\n",
    "    # per test point; also note this width excludes the 2 * q_hat conformal widening.\n",
    "    true_interval = [dist.quantile_interval(alpha/2, 1-alpha/2) for dist in test_distributions]\n",
    "    excess_len = X_test @ (theta_n_hi - theta_n_lo) - true_interval\n",
    "\n",
    "    # Conformalized band [x @ theta_lo - q_hat, x @ theta_hi + q_hat].\n",
    "    lower = X_test @ theta_n_lo - q_hat\n",
    "    upper = X_test @ theta_n_hi + q_hat\n",
    "\n",
    "    res = {'n': n, 'alpha': alpha, 'seed': seed,\n",
    "           'q_hat': q_hat,\n",
    "           'excess_len_avg': np.mean(np.abs(excess_len)).item(),\n",
    "           'coverage': np.mean((y_test >= lower) & (y_test <= upper)).item(),\n",
    "           'm': m,  # keep runs with different calibration sizes distinguishable\n",
    "          }\n",
    "\n",
    "    print(f'n: {n}, m: {m}, alpha: {alpha} seed: {seed} excess_len_avg: {res[\"excess_len_avg\"]}')\n",
    "    return res\n",
    "\n",
    "def generate_parameter_combinations(n_list, m_list, seeds, alphas):\n",
    "    \"\"\"Return every (seed, n, m, alpha) tuple of the sweep as a list.\n",
    "\n",
    "    The seed varies slowest and alpha fastest; each tuple is in the positional\n",
    "    order expected by process_single_run.\n",
    "    \"\"\"\n",
    "    return [\n",
    "        (seed, n, m, alpha)\n",
    "        for seed in seeds\n",
    "        for n in n_list\n",
    "        for m in m_list\n",
    "        for alpha in alphas\n",
    "    ]\n",
    "\n",
    "if __name__ == '__main__':\n",
    "    # Sweep: every seed x training size n x calibration size m x miscoverage alpha.\n",
    "    n_list = [200, 500, 1000, 2000, 4000, 6000, 8000, 10000, 15000, 20000]\n",
    "    m_list = [5000]\n",
    "    seeds = list(range(0, 20))\n",
    "    alphas = [0.01, 0.025, 0.05, 0.075, 0.1, 0.125, 0.15, 0.175, 0.2]\n",
    "    param_combinations = generate_parameter_combinations(n_list, m_list, seeds, alphas)\n",
    "\n",
    "    # NOTE(review): the workers must be able to locate process_single_run. This works\n",
    "    # with the 'fork' start method (Linux default); under 'spawn' (macOS/Windows\n",
    "    # default) functions defined in a notebook cell cannot be unpickled by workers.\n",
    "    # Confirm the target platform or move the worker function into a .py module.\n",
    "    with Pool(processes=multiprocessing.cpu_count()) as pool:\n",
    "        raw_results = pool.starmap(process_single_run, param_combinations)\n",
    "\n",
    "    # process_single_run returns None when data generation fails; keep only the\n",
    "    # result dicts so the DataFrame constructor does not fail on non-dict rows.\n",
    "    results = [r for r in raw_results if r is not None]\n",
    "    n_failed = len(raw_results) - len(results)\n",
    "    if n_failed:\n",
    "        print(f\"Skipped {n_failed} of {len(raw_results)} failed runs\")\n",
    "    df = pd.DataFrame(results)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e7593795",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Failed runs come back from process_single_run as None; drop them so the\n",
    "# DataFrame is built only from result dicts.\n",
    "df = pd.DataFrame([r for r in results if r is not None])\n",
    "df.to_csv(\"./cqr_vary_n_uniform_1.csv\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6e142f38",
   "metadata": {},
   "outputs": [],
   "source": [
    "class Dataset(object):\n",
    "    \"\"\"Synthetic regression data for the CQR experiments.\n",
    "\n",
    "    Covariates are drawn i.i.d. Uniform[_X_min, _X_max]^d, and one coefficient\n",
    "    vector theta ~ Uniform[_theta_min, _theta_max]^d is drawn at construction.\n",
    "    For each covariate row x, y is sampled from\n",
    "    DataDistribution(alpha_min, x_max, theta, x), which depends on x via x @ theta.\n",
    "\n",
    "    Args:\n",
    "        d: covariate dimension.\n",
    "        alpha_min: forwarded to every DataDistribution.\n",
    "        rng: np.random.Generator used to draw theta, X and y.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, d, alpha_min, rng):\n",
    "        self.d = d\n",
    "        self.rng = rng\n",
    "        self.alpha_min = alpha_min\n",
    "        self._X_min = 1\n",
    "        self._X_max = 20\n",
    "        self._theta_min = 1\n",
    "        self._theta_max = 2\n",
    "\n",
    "        self.theta = self.rng.uniform(low=self._theta_min, high=self._theta_max, size=(d,))\n",
    "        # Upper corner of the covariate box, one entry per dimension. It was previously\n",
    "        # hard-coded to two entries, which broke x_max @ theta whenever d != 2.\n",
    "        self.x_max = np.full(d, self._X_max)\n",
    "\n",
    "    def get_data_distribution(self, x):\n",
    "        \"\"\"Return the conditional distribution of y for covariate vector x.\"\"\"\n",
    "        return DataDistribution(self.alpha_min, self.x_max, self.theta, x=x, rng=self.rng)\n",
    "\n",
    "    def generate_X_y(self, n):\n",
    "        \"\"\"Draw n samples.\n",
    "\n",
    "        Returns:\n",
    "            X: (n, d) covariates; y: the per-row draws of distribution.sample(),\n",
    "            flattened (presumably shape (n,) if sample() yields one value);\n",
    "            and the list of per-row DataDistribution objects.\n",
    "        \"\"\"\n",
    "        X = self.rng.uniform(low=self._X_min, high=self._X_max, size=(n, self.d))\n",
    "        distributions = [self.get_data_distribution(x_i) for x_i in X]\n",
    "        y = np.array([distribution.sample() for distribution in distributions]).flatten()\n",
    "        return X, y, distributions"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "45782be8",
   "metadata": {},
   "source": [
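    "We compute the large constant coefficients in Equations (41) and (42), following the construction described in Appendix D.1.\n",
    "- $d$: dimension of the covariates $X$\n",
    "- $B = \\sup_X \\lVert X \\rVert_2$\n",
    "- $K = \\sup_\\theta \\lVert \\theta \\rVert_2$\n",
    "- $R = BK + 1/f_{\\min}$\n",
    "- $f_{\\max}, f_{\\min}$: largest and smallest conditional density values (`pdf_max` / `pdf_min`) over the sampled covariates\n",
    "- $\\lambda_{\\max}, \\lambda_{\\min}$: largest and smallest eigenvalues of the empirical second-moment matrix $\\frac{1}{n}\\sum_i x_i x_i^\\top$"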
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "id": "14bdd918",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "B: 28.28, K: 2.83, f_max: 0.11, f_min: 2.34e-05 R: 42905.68, lambda_max: 273.57, lambda_min: 35.70\n"
     ]
    }
   ],
   "source": [
    "rng = np.random.default_rng(0)\n",
    "alpha_min = 0.005\n",
    "n = 200\n",
    "\n",
    "dataset = Dataset(2, alpha_min, rng)\n",
    "X, _, distributions = dataset.generate_X_y(n=n)\n",
    "\n",
    "d = dataset.d\n",
    "# sup ||X||_2 and sup ||theta||_2 over the boxes [., _X_max]^d and [., _theta_max]^d.\n",
    "B = dataset._X_max * np.sqrt(d)\n",
    "K = dataset._theta_max * np.sqrt(d)\n",
    "\n",
    "# Largest / smallest conditional density values over the sampled covariates.\n",
    "f_max = max(dist.pdf_max for dist in distributions)\n",
    "f_min = min(dist.pdf_min for dist in distributions)\n",
    "\n",
    "R = B * K + 1 / f_min\n",
    "\n",
    "# Empirical second-moment matrix (1/n) sum_i x_i x_i^T. It is symmetric, so eigvalsh\n",
    "# applies and returns eigenvalues in ascending order (np.linalg.eigvals gives no\n",
    "# ordering guarantee, so unpacking its output as (max, min) was not reliable).\n",
    "second_moment = X.T @ X / n\n",
    "eigenvalues = np.linalg.eigvalsh(second_moment)\n",
    "lambda_min, lambda_max = eigenvalues[0], eigenvalues[-1]\n",
    "\n",
    "print(f\"B: {B:.2f}, K: {K:.2f}, f_max: {f_max:.2f}, f_min: {f_min:.2e} R: {R:.2f}, lambda_max: {lambda_max:.2f}, lambda_min: {lambda_min:.2f}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "fc6bd4fc",
   "metadata": {},
   "source": [
    "### Equation (41)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3e42608c",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "np.float64(102410.24594489107)"
      ]
     },
     "execution_count": 30,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Eq. (41), coefficient of 1/sqrt(n), part 1: 4 lambda_max sqrt(d f_max) / (lambda_min^{3/2} f_min)\n",
    "(4 * lambda_max) * np.sqrt(d * f_max) / (f_min * lambda_min ** 1.5)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d1e305e8",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "np.float64(342804.27746433445)"
      ]
     },
     "execution_count": 31,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Eq. (41), coefficient of 1/sqrt(n), part 2: 2 B lambda_max sqrt(2 f_max d) / (lambda_min^2 f_min)\n",
    "(2 * B * lambda_max) * np.sqrt(2 * f_max * d) / (f_min * lambda_min ** 2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "356f50d6",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "np.float64(26837.012912011665)"
      ]
     },
     "execution_count": 33,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Eq. (41), coefficient of 1/sqrt(m): sqrt(pi) / (2 sqrt(2) f_min)\n",
    "np.sqrt(np.pi) / (f_min * (2 * np.sqrt(2)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 36,
   "id": "f0bb4061",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "16.5895707122441"
      ]
     },
     "execution_count": 36,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Eq. (41): ratio of the 1/sqrt(n) coefficient (parts 1 + 2) to the 1/sqrt(m) coefficient,\n",
    "# recomputed from the constants above rather than pasted literals that go stale on re-run.\n",
    "(\n",
    "    4 * lambda_max * np.sqrt(f_max * d) / (lambda_min ** 1.5 * f_min)\n",
    "    + 2 * B * lambda_max * np.sqrt(2 * f_max * d) / (lambda_min ** 2 * f_min)\n",
    ") / (np.sqrt(np.pi) / (2 * np.sqrt(2) * f_min))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 37,
   "id": "f64519ae",
   "metadata": {},
   "outputs": [],
   "source": [
    "# NOTE(review): this cell is entirely commented out and produces no output.\n",
    "# The \"1 / f_min\" label is ambiguous; presumably the expression below is the\n",
    "# coefficient of a further term of Eq. (41). Confirm before re-enabling.\n",
    "# 1 / f_min\n",
    "# 1056 * lambda_max ** 2 * f_max ** 3 * B**2 * R / (lambda_min ** 4 * f_min**2)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f9f18f5e",
   "metadata": {},
   "source": [
    "### Equation (42)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 39,
   "id": "1749990a",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "np.float64(484798.45842957124)"
      ]
     },
     "execution_count": 39,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Eq. (42), coefficient of 1/sqrt(n), part 1: 4 B lambda_max sqrt(d f_max) / (lambda_min^2 f_min)\n",
    "(4 * B * lambda_max) * np.sqrt(d * f_max) / (f_min * lambda_min ** 2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "35b5551d",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "np.float64(53674.02582402333)"
      ]
     },
     "execution_count": 38,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Eq. (42), coefficient of 1/sqrt(m): sqrt(pi) / (sqrt(2) f_min)\n",
    "np.sqrt(np.pi) / (f_min * np.sqrt(2))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 40,
   "id": "646ca622",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "9.032273077839188"
      ]
     },
     "execution_count": 40,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Eq. (42): ratio of the 1/sqrt(n) coefficient to the 1/sqrt(m) coefficient,\n",
    "# recomputed from the constants above rather than pasted literals that go stale on re-run.\n",
    "(4 * B * lambda_max * np.sqrt(f_max * d) / (lambda_min ** 2 * f_min)) / (np.sqrt(np.pi) / (np.sqrt(2) * f_min))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "base",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.13.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
