{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "98e52941-bae5-4b55-b90c-7ee915957151",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "# Paths shown reflect the default Jupyter Docker Stacks user directory (/home/jovyan).\n",
    "code_path = '/home/jovyan/code/'\n",
    "\n",
    "# Helper scripts sourced into this notebook's namespace. The simulator and the three\n",
    "# *_tabpfn_generator functions called in later cells are not defined in this notebook,\n",
    "# so they are presumably provided by these files - TODO confirm.\n",
    "helper_files = (\n",
    "    'utility_functions_implementing_tabpfn_generators_iclr.py',\n",
    "    'additional_utility_functions_for_tabpfn_generators_iclr.py',\n",
    ")\n",
    "\n",
    "# Execute each helper in the listed order. Compiling with the real path makes tracebacks\n",
    "# point at the helper file and line instead of '<string>'; UTF-8 matches Python's\n",
    "# default source encoding regardless of the container locale.\n",
    "# SECURITY: exec runs arbitrary code; only point code_path at trusted files.\n",
    "for helper_name in helper_files:\n",
    "    file_path = os.path.expanduser(os.path.join(code_path, helper_name))\n",
    "    with open(file_path, encoding='utf-8') as file:\n",
    "        exec(compile(file.read(), file_path, 'exec'))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bb330e3c-f036-4b28-bad0-9dd4b14475e6",
   "metadata": {},
   "outputs": [],
   "source": [
    "from time import perf_counter\n",
    "from typing import Callable, Dict, Iterable\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "from tqdm import tqdm\n",
    "\n",
    "\n",
    "def generate_beta_parameter_list(\n",
    "    p: int,\n",
    "    rng: np.random.Generator | None = None\n",
    "):\n",
    "    if rng is None:\n",
    "        rng = np.random.default_rng()\n",
    "    ab = rng.uniform(0.0, 10.0, size=(p, 2))\n",
    "    return [tuple(row) for row in ab]\n",
    "\n",
    "\n",
    "def benchmark_generators(\n",
    "    n_repli: int,\n",
    "    num_feat_grid: Iterable[int],\n",
    "    n: int,\n",
    "    rho: float,\n",
    "    rng: np.random.Generator | None = None,\n",
    "    generators: Dict[str, Callable[..., object]] | None = None,\n",
    ") -> Dict[str, pd.DataFrame]:\n",
    "    \"\"\"Time each generator on simulated data over a grid of feature counts.\n",
    "\n",
    "    For every feature count ``p`` in ``num_feat_grid`` and every replicate, one\n",
    "    dataset is simulated and each generator is called on it once; the elapsed\n",
    "    time of each call (seconds, via ``time.perf_counter``) is recorded.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    n_repli : int\n",
    "        Number of replicates per feature count (rows of each output frame).\n",
    "    num_feat_grid : iterable of int\n",
    "        Feature counts to benchmark (columns of each output frame).\n",
    "    n : int\n",
    "        Passed to ``simulate_correlated_beta_data`` as ``n`` (rows per dataset).\n",
    "    rho : float\n",
    "        Passed to ``simulate_correlated_beta_data`` as ``rho``.\n",
    "    rng : numpy.random.Generator, optional\n",
    "        Drives the per-feature parameter draws; a fresh unseeded generator is\n",
    "        used when None. NOTE(review): ``simulate_correlated_beta_data`` is not\n",
    "        given this rng, so the simulated rows themselves may not be\n",
    "        reproducible - confirm.\n",
    "    generators : dict, optional\n",
    "        Maps output key -> callable, timed in insertion order; each is invoked\n",
    "        as ``gen(X=X, show_progress=False)`` and its result is discarded.\n",
    "        Defaults to the miav, joint-factorization and full-conditionals\n",
    "        generators under the keys ``out_miav``, ``out_jf`` and ``out_fc``.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    dict of str -> pandas.DataFrame\n",
    "        One ``(n_repli, len(grid))`` timing frame per generator key, with the\n",
    "        column index named ``num_features``.\n",
    "    \"\"\"\n",
    "    if rng is None:\n",
    "        rng = np.random.default_rng()\n",
    "    if generators is None:\n",
    "        # Looked up at call time: these names come from the helper files sourced earlier.\n",
    "        generators = {\n",
    "            \"out_miav\": miav_tabpfn_generator,\n",
    "            \"out_jf\": joint_factorization_tabpfn_generator,\n",
    "            \"out_fc\": full_conditionals_tabpfn_generator,\n",
    "        }\n",
    "\n",
    "    # Materialise once: the grid is iterated below and reused as column labels.\n",
    "    grid = list(num_feat_grid)\n",
    "    timings = {key: np.full((n_repli, len(grid)), np.nan) for key in generators}\n",
    "\n",
    "    with tqdm(total=n_repli * len(grid), desc=\"Benchmarking\", unit=\"run\") as pbar:\n",
    "        for j, p in enumerate(grid):\n",
    "            for i in range(n_repli):\n",
    "                # --- simulate the data ---\n",
    "                beta_pars_list = generate_beta_parameter_list(p, rng=rng)\n",
    "                X = simulate_correlated_beta_data(n=n, rho=rho, beta_pars_list=beta_pars_list)\n",
    "\n",
    "                # --- time every generator on the same dataset ---\n",
    "                for key, generator in generators.items():\n",
    "                    t0 = perf_counter()\n",
    "                    result = generator(X=X, show_progress=False)\n",
    "                    timings[key][i, j] = perf_counter() - t0\n",
    "                    # Free the output only after the clock is read, so deallocating it\n",
    "                    # is not charged to the next generator's timing.\n",
    "                    del result\n",
    "\n",
    "                pbar.update(1)  # one tick per (feature count, replicate) pair\n",
    "\n",
    "    cols = pd.Index(grid, name=\"num_features\")\n",
    "    return {key: pd.DataFrame(values, columns=cols) for key, values in timings.items()}\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cf6d9983-8120-45b7-98a0-e8a53f501cbd",
   "metadata": {},
   "outputs": [],
   "source": [
    "n_repli = 5\n",
    "num_feat_grid = [10, 20, 30, 40, 50, 60, 70, 80, 90, 100]\n",
    "n = 1000\n",
    "rho = 0.5\n",
    "\n",
    "# Fixed seed so the per-feature parameter draws are reproducible on re-run.\n",
    "results = benchmark_generators(\n",
    "    n_repli=n_repli,\n",
    "    num_feat_grid=num_feat_grid,\n",
    "    n=n,\n",
    "    rho=rho,\n",
    "    rng=np.random.default_rng(123),\n",
    ")\n",
    "\n",
    "out_miav = results['out_miav']\n",
    "out_jf = results['out_jf']\n",
    "out_fc = results['out_fc']\n",
    "\n",
    "# Build each file name from `n` so the label cannot drift from the setting above.\n",
    "for label, timing_frame in (('miav', out_miav), ('jf', out_jf), ('fc', out_fc)):\n",
    "    timing_frame.to_csv(f\"time_bench_{label}_n_{n}.csv\", index=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0d5d3e6f-10ab-4cdd-a79e-765cf988f2b7",
   "metadata": {},
   "outputs": [],
   "source": [
    "n_repli = 5\n",
    "num_feat_grid = [10, 20, 30, 40, 50, 60, 70, 80, 90, 100]\n",
    "n = 2000\n",
    "rho = 0.5\n",
    "\n",
    "# Fixed seed so the per-feature parameter draws are reproducible on re-run.\n",
    "results = benchmark_generators(\n",
    "    n_repli=n_repli,\n",
    "    num_feat_grid=num_feat_grid,\n",
    "    n=n,\n",
    "    rho=rho,\n",
    "    rng=np.random.default_rng(123),\n",
    ")\n",
    "\n",
    "out_miav = results['out_miav']\n",
    "out_jf = results['out_jf']\n",
    "out_fc = results['out_fc']\n",
    "\n",
    "# Build each file name from `n` so the label cannot drift from the setting above.\n",
    "for label, timing_frame in (('miav', out_miav), ('jf', out_jf), ('fc', out_fc)):\n",
    "    timing_frame.to_csv(f\"time_bench_{label}_n_{n}.csv\", index=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cb11b9fc-3e71-48ff-b830-96996adb84a2",
   "metadata": {},
   "outputs": [],
   "source": [
    "n_repli = 5\n",
    "num_feat_grid = [10, 20, 30, 40, 50, 60, 70, 80, 90, 100]\n",
    "n = 3000\n",
    "rho = 0.5\n",
    "\n",
    "# Fixed seed so the per-feature parameter draws are reproducible on re-run.\n",
    "results = benchmark_generators(\n",
    "    n_repli=n_repli,\n",
    "    num_feat_grid=num_feat_grid,\n",
    "    n=n,\n",
    "    rho=rho,\n",
    "    rng=np.random.default_rng(123),\n",
    ")\n",
    "\n",
    "out_miav = results['out_miav']\n",
    "out_jf = results['out_jf']\n",
    "out_fc = results['out_fc']\n",
    "\n",
    "# Build each file name from `n` so the label cannot drift from the setting above.\n",
    "for label, timing_frame in (('miav', out_miav), ('jf', out_jf), ('fc', out_fc)):\n",
    "    timing_frame.to_csv(f\"time_bench_{label}_n_{n}.csv\", index=False)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
