{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "\n",
    "## MNIST------------------\n",
    "data = np.array([ \n",
    "    [107, 207, 457, 727, 1050],\n",
    "    [1.128907265, 1.129903694, 0.95000524, 1.065457432, 1.07842692],  # US-cost-cd\n",
    "    [0.110218761, 0.052139277, 0.038944393, 0.041994556, 0.026442336],\n",
    "    [1.321370877, 1.302779898, 1.133641583, 1.225134875, 1.106242767],  # US-cost-nd\n",
    "    [0.056115027, 0.069383208, 0.051824403, 0.055883342, 0.035187563],\n",
    "    [1.052127953, 1.010772865, 0.932377781, 0.997475716, 1.020153382],  # CS-cost-cd\n",
    "    [0.123135508, 0.047050721, 0.049005097, 0.020344666, 0.030966611],\n",
    "    [1.285456198, 1.095992328, 1.111457288, 1.062367158, 1.078174653],  # CS-cost-nd\n",
    "    [0.163859898, 0.062611723, 0.065212466, 0.027073222, 0.041208143],\n",
    "    [0.016784394, 0.032314512, 0.058098828, 0.090117951, 0.124908273],\n",
    "    [0.00061105\t, 0.001779524, 0.00332018 , 0.009177903, 0.011535964],\n",
    "    [0.020852594, 0.035248069, 0.064630771, 0.097514323, 0.130831342],\n",
    "    [0.001144923, 0.001170404, 0.003033192, 0.006771496, 0.009581683]\n",
    "])\n",
    "\n",
    "\n",
    "data = np.array([ \n",
    "[98\t,            229,\t        518\t   ,     800\t,        1022],\n",
    "[1.385657769,\t1.439847652,\t1.307507055,\t1.332869042,\t1.247963972],\n",
    "[0.182460984,\t0.196996492,\t0.124174898,\t0.145391186,\t0.104980348],\n",
    "[1.329337019,\t1.356253487,\t1.23696181,\t    1.246061124,\t1.21640765],\n",
    "[0.132378931,\t0.122671998,\t0.11955425,\t    0.140853685,\t0.075317043],\n",
    "[1.117119621,\t1.053939527,\t1.081315637,\t0.95361068,\t    1.072596685],\n",
    "[0.042874569,\t0.07600911,     0.075840169,\t0.074246636,\t0.105601742],\n",
    "[1.089705833,\t1.015831175,\t1.034047334,\t0.946062464,\t1.088840971],\n",
    "[0.061325524,\t0.070880353,\t0.10320549,\t    0.087807818,\t0.134093922],\n",
    "[0.063943539,\t0.119506072,\t0.261097259,\t0.40215236,\t    0.434957432],\n",
    "[0.001331551,\t0.004756456,\t0.022773619,\t0.046057504,\t0.001771119],\n",
    "[0.08123109\t,    0.151810029,\t0.273526873,\t0.435773293,\t0.490406661],\n",
    "[0.004508728,\t0.002975685,\t0.004542501,\t0.043295801,\t0.013428183]\n",
    "])\n",
    "\n",
    "\n",
    "## ModelNet------------------\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "# Extract x-coordinates\n",
    "x = data[0]\n",
    "\n",
    "# Extract lines and variances\n",
    "line1, var1 = data[1], data[2]\n",
    "line2, var2 = data[3], data[4]\n",
    "line3, var3 = data[5], data[6]\n",
    "line4, var4 = data[7], data[8]\n",
    "runtime_US, runtime_US_var = data[9], data[10]\n",
    "runtime_CS, runtime_CS_var = data[11], data[12]\n",
    "\n",
    "# Create subplots\n",
    "fig, axes = plt.subplots(1, 3, figsize=(24, 8))\n",
    "\n",
    "# Subplot 1: Cost comparisons (cd)\n",
    "axes[1].plot(x, line1, label=\"US-cost-cd\", linestyle=\"--\", marker=\"o\", markersize=10, linewidth=3)\n",
    "axes[1].fill_between(x, line1 - var1, line1 + var1, alpha=0.2)\n",
    "axes[1].plot(x, line3, label=\"CS-cost-cd\", linestyle=\"-\", marker=\"^\", markersize=10, linewidth=3)\n",
    "axes[1].fill_between(x, line3 - var3, line3 + var3, alpha=0.2)\n",
    "axes[1].set_xlabel(\"Sample size\", fontsize=20)\n",
    "axes[1].set_ylabel(\"cost-cd over baseline\", fontsize=20)\n",
    "# axes[1].set_title(\"Cost Comparison (cd)\", fontsize=22)\n",
    "axes[1].annotate(\"(b) Comparison of cost-cd\", xy=(0.5, -0.15), xycoords=\"axes fraction\", \n",
    "                 ha=\"center\", va=\"center\", fontsize=25)\n",
    "axes[1].legend(fontsize=15)\n",
    "axes[1].grid(True)\n",
    "\n",
    "# Subplot 2: Cost comparisons (nd)\n",
    "axes[2].plot(x, line2, label=\"US-cost-nd\", linestyle=\"--\", marker=\"o\", markersize=10, linewidth=3)\n",
    "axes[2].fill_between(x, line2 - var2, line2 + var2, alpha=0.2)\n",
    "axes[2].plot(x, line4, label=\"CS-cost-nd\", linestyle=\"-\", marker=\"^\", markersize=10, linewidth=3)\n",
    "axes[2].fill_between(x, line4 - var4, line4 + var4, alpha=0.2)\n",
    "axes[2].set_xlabel(\"Sample size\", fontsize=20)\n",
    "axes[2].set_ylabel(\"cost-nd over baseline\", fontsize=20)\n",
    "# axes[2].set_title(\"Cost Comparison (nd)\", fontsize=22)\n",
    "axes[2].annotate(\"(c) Comparison of cost-nd\", xy=(0.5, -0.15), xycoords=\"axes fraction\", \n",
    "                 ha=\"center\", va=\"center\", fontsize=25)\n",
    "axes[2].legend(fontsize=15)\n",
    "axes[2].grid(True)\n",
    "\n",
    "# Subplot 3: Runtime comparisons\n",
    "axes[0].plot(x, runtime_US, label=\"US-runtime\", linestyle=\"--\", marker=\"o\", markersize=10, linewidth=3)\n",
    "axes[0].fill_between(x, runtime_US - runtime_US_var, runtime_US + runtime_US_var, alpha=0.2)\n",
    "axes[0].plot(x, runtime_CS, label=\"CS-runtime\", linestyle=\"-\", marker=\"^\", markersize=10, linewidth=3)\n",
    "axes[0].fill_between(x, runtime_CS - runtime_CS_var, runtime_CS + runtime_CS_var, alpha=0.2)\n",
    "axes[0].set_xlabel(\"Sample size\", fontsize=20)\n",
    "axes[0].set_ylabel(\"runtime over baseline\", fontsize=20)\n",
    "# axes[0].set_title(\"(a) Runtime Comparison\", fontsize=22)\n",
    "axes[0].annotate(\"(a) Comparison of runtime\", xy=(0.5, -0.15), xycoords=\"axes fraction\", \n",
    "                 ha=\"center\", va=\"center\", fontsize=25)\n",
    "axes[0].legend(fontsize=15)\n",
    "axes[0].grid(True)\n",
    "\n",
    "# Adjust layout and show\n",
    "plt.tight_layout()\n",
    "plt.show()\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "base",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
