{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false,
    "jupyter": {
     "outputs_hidden": false
    },
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false,
    "jupyter": {
     "outputs_hidden": false
    },
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "import os\n",
    "import sys\n",
    "import dotenv\n",
    "import matplotlib.pyplot as plt\n",
    "import matplotlib.ticker\n",
    "import numpy as np\n",
    "%matplotlib inline"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false,
    "jupyter": {
     "outputs_hidden": false
    },
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "# Load environment variables (e.g. PLOT_BASE_DIR) from a .env file, if one is found\n",
    "dotenv.load_dotenv()\n",
    "\n",
    "# Enable loading of the project modules, which live in <parent of the current directory>/src.\n",
    "# Guarded so that re-running this cell does not append duplicate sys.path entries.\n",
    "MODULE_DIR = os.path.join(os.path.abspath(os.path.join(os.path.curdir, os.path.pardir)), 'src')\n",
    "if MODULE_DIR not in sys.path:\n",
    "    sys.path.append(MODULE_DIR)\n",
    "\n",
    "# Every figure is written below PLOT_BASE_DIR, which must be provided via the environment.\n",
    "try:\n",
    "    PLOT_BASE_DIR = os.path.abspath(os.environ[\"PLOT_BASE_DIR\"])\n",
    "except KeyError:\n",
    "    # `from None` hides the uninformative KeyError from the traceback.\n",
    "    raise RuntimeError(\"Missing plot output dir. Set the PLOT_BASE_DIR variable accordingly\") from None\n",
    "os.makedirs(PLOT_BASE_DIR, exist_ok=True)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false,
    "jupyter": {
     "outputs_hidden": false
    },
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "# plot_util is a project-local module, presumably found via MODULE_DIR (../src) on sys.path.\n",
    "import plot_util\n",
    "# Presumably applies the project's shared matplotlib styling before any figure is created - TODO confirm in plot_util.\n",
    "plot_util.setup_matplotlib()"
   ]
  },
  {
   "cell_type": "markdown",
   "source": [
    "## Interpolating vs. regularized"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%% md\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "# Shared x-axis label of both figures below: beta is the exponent with filter size in Theta(d^beta).\n",
    "BETA_CAPTION = (\n",
    "    r\"$\\beta$ s.t.\\ filter size \"\n",
    "    r\"$\\in \\Theta\\mathopen{}\\left(d^\\beta\\right)\\mathclose{}$\"\n",
    ")\n"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "import theory\n",
    "\n",
    "# theory.Plots() yields the six quantities plotted below, in this order.\n",
    "# The rate_* arrays are plotted against `betas` and indexed by positions in `betas`,\n",
    "# so they must be aligned with it; the cells below also require beta_star to be an element of `betas`.\n",
    "(\n",
    "    beta_star,\n",
    "    betas,\n",
    "    rate_error_int,\n",
    "    rate_error_reg,\n",
    "    rate_var,\n",
    "    rate_bias,\n",
    ") = theory.Plots()"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "# Error rate exponent of the interpolator vs. the optimally regularized estimator, as a function of beta.\n",
    "fig, ax = plt.subplots(\n",
    "    figsize=(plot_util.FIGURE_WIDTH_FOR_COLUMNS_IN[6], plot_util.FIGURE_WIDTH_FOR_COLUMNS_IN[6] / plot_util.GOLDEN_RATIO)\n",
    ")\n",
    "\n",
    "ax.plot(\n",
    "    betas,\n",
    "    rate_error_int,\n",
    "    c=\"C0\",\n",
    "    label=\"interpolator\",\n",
    "    ls=plot_util.LINESTYLE_MAP[0],\n",
    ")\n",
    "ax.plot(\n",
    "    betas,\n",
    "    rate_error_reg,\n",
    "    c=\"C1\",\n",
    "    label=\"optimal regularized\",\n",
    "    ls=plot_util.LINESTYLE_MAP[1],\n",
    ")\n",
    "\n",
    "# Index of beta_star on the beta grid. Requires exactly one exact match; this avoids\n",
    "# calling int() on the 2-D np.argwhere result, which is deprecated since NumPy 1.25.\n",
    "beta_star_matches = np.flatnonzero(betas == beta_star)\n",
    "assert beta_star_matches.size == 1, \"beta_star must occur exactly once in betas\"\n",
    "beta_star_idx = int(beta_star_matches[0])\n",
    "\n",
    "# Y ticks: minimum of the regularized rate, interpolator rate at beta_star,\n",
    "# maximum of the interpolator rate left of beta_star (labelled -l_sigma/l), and 0.\n",
    "ax.set_ylim(-0.45, 0.0)\n",
    "ax.set_yticks((\n",
    "    rate_error_reg.min(),\n",
    "    rate_error_int[beta_star_idx],\n",
    "    rate_error_int[:beta_star_idx].max(),\n",
    "    0.0\n",
    "))\n",
    "ax.set_yticklabels((\n",
    "    round(rate_error_reg.min(), 2),\n",
    "    round(rate_error_int[beta_star_idx], 2),\n",
    "    r\"$-\\frac{\\ell_{\\sigma}}{\\ell}$\",\n",
    "    \"0\"\n",
    "))\n",
    "ax.yaxis.set_minor_locator(matplotlib.ticker.FixedLocator(\n",
    "    np.linspace(-0.45, 0.0, num=10)\n",
    "))\n",
    "ax.set_ylabel(r\"error rate exponent $\\alpha$ $\\downarrow$\")\n",
    "\n",
    "ax.set_xlim(0.1, 1.0)\n",
    "ax.set_xticks((0.1, beta_star, 1.0))\n",
    "ax.set_xticklabels((\"0.1\", r\"$\\beta^*$\", \"1\"))\n",
    "ax.xaxis.set_minor_locator(matplotlib.ticker.FixedLocator(\n",
    "    np.linspace(0.1, 1.0, num=10)\n",
    "))\n",
    "ax.set_xlabel(BETA_CAPTION)\n",
    "\n",
    "# Save before showing: depending on the backend, plt.show() may close or clear the figure.\n",
    "fig.savefig(os.path.join(PLOT_BASE_DIR, \"rates.pdf\"))\n",
    "plt.show()\n",
    "plt.close(fig)\n",
    "\n",
    "# Export the legend as a standalone figure, reusing the line handles of the main axes.\n",
    "legend_fig = plt.figure(figsize=(\n",
    "    plot_util.FIGURE_WIDTH_FOR_COLUMNS_IN[6],\n",
    "    plot_util.LEGEND_HEIGHT_FOR_ROWS_IN[1],\n",
    "))\n",
    "handles, labels = ax.get_legend_handles_labels()\n",
    "legend_fig.legend(\n",
    "    handles,\n",
    "    labels,\n",
    "    loc='center',\n",
    "    ncol=len(handles),\n",
    "    frameon=False,\n",
    "    borderpad=0.5  # TODO: unexplained in the original - confirm this padding value\n",
    ")\n",
    "\n",
    "legend_fig.savefig(os.path.join(PLOT_BASE_DIR, \"rates_legend.pdf\"))\n",
    "\n",
    "plt.close(legend_fig)"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "markdown",
   "source": [
    "## Bias vs. variance"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%% md\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "# Rate exponents of the variance and bias terms, as a function of beta.\n",
    "fig, ax = plt.subplots(\n",
    "    figsize=(plot_util.FIGURE_WIDTH_FOR_COLUMNS_IN[6], plot_util.FIGURE_WIDTH_FOR_COLUMNS_IN[6] / plot_util.GOLDEN_RATIO)\n",
    ")\n",
    "\n",
    "ax.plot(\n",
    "    betas,\n",
    "    rate_var,\n",
    "    c=\"C3\",\n",
    "    label=\"variance\",\n",
    "    ls=plot_util.LINESTYLE_MAP[0],\n",
    ")\n",
    "ax.plot(\n",
    "    betas,\n",
    "    rate_bias,\n",
    "    c=\"C2\",\n",
    "    label=\"bias\",\n",
    "    ls=plot_util.LINESTYLE_MAP[1],\n",
    ")\n",
    "\n",
    "# Index of beta_star on the beta grid. Requires exactly one exact match; this avoids\n",
    "# calling int() on the 2-D np.argwhere result, which is deprecated since NumPy 1.25.\n",
    "beta_star_matches = np.flatnonzero(betas == beta_star)\n",
    "assert beta_star_matches.size == 1, \"beta_star must occur exactly once in betas\"\n",
    "beta_star_idx = int(beta_star_matches[0])\n",
    "\n",
    "# Y ticks: lower limit, variance rate at beta_star, maximum of the variance rate\n",
    "# (labelled -l_sigma/l), and 0. Each label matches the value of its tick position.\n",
    "ax.set_ylim(-0.9, 0.0)\n",
    "ax.set_yticks((\n",
    "    -0.9,\n",
    "    rate_var[beta_star_idx],\n",
    "    rate_var.max(),\n",
    "    0.0\n",
    "))\n",
    "ax.set_yticklabels((\n",
    "    -0.9,\n",
    "    round(rate_var[beta_star_idx], 2),\n",
    "    r\"$-\\frac{\\ell_{\\sigma}}{\\ell}$\",\n",
    "    \"0\"\n",
    "))\n",
    "ax.yaxis.set_minor_locator(matplotlib.ticker.FixedLocator(\n",
    "    np.linspace(-0.9, 0.0, num=10)\n",
    "))\n",
    "ax.set_ylabel(r\"rate exponent $\\alpha$ $\\downarrow$\")\n",
    "\n",
    "ax.set_xlim(0.1, 1.0)\n",
    "ax.set_xticks((0.1, beta_star, 1.0))\n",
    "ax.set_xticklabels((\"0.1\", r\"$\\beta^*$\", \"1\"))\n",
    "ax.xaxis.set_minor_locator(matplotlib.ticker.FixedLocator(\n",
    "    np.linspace(0.1, 1.0, num=10)\n",
    "))\n",
    "ax.set_xlabel(BETA_CAPTION)\n",
    "\n",
    "# Save before showing: depending on the backend, plt.show() may close or clear the figure.\n",
    "fig.savefig(os.path.join(PLOT_BASE_DIR, \"tradeoff.pdf\"))\n",
    "plt.show()\n",
    "plt.close(fig)\n",
    "\n",
    "# Export the legend as a standalone figure, reusing the line handles of the main axes.\n",
    "legend_fig = plt.figure(figsize=(\n",
    "    plot_util.FIGURE_WIDTH_FOR_COLUMNS_IN[6],\n",
    "    plot_util.LEGEND_HEIGHT_FOR_ROWS_IN[1],\n",
    "))\n",
    "handles, labels = ax.get_legend_handles_labels()\n",
    "legend_fig.legend(\n",
    "    handles,\n",
    "    labels,\n",
    "    loc='center',\n",
    "    ncol=len(handles),\n",
    "    frameon=False,\n",
    "    borderpad=0.5  # TODO: unexplained in the original - confirm this padding value\n",
    ")\n",
    "\n",
    "legend_fig.savefig(os.path.join(PLOT_BASE_DIR, \"tradeoff_legend.pdf\"))\n",
    "\n",
    "plt.close(legend_fig)"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}