{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false,
    "jupyter": {
     "outputs_hidden": false
    },
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "# Enable IPython's autoreload extension; mode 2 reloads imported modules before each cell execution\n",
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false,
    "jupyter": {
     "outputs_hidden": false
    },
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "import os\n",
    "import sys\n",
    "import mlflow\n",
    "import dotenv\n",
    "import pandas as pd\n",
    "import matplotlib.pyplot as plt\n",
    "import matplotlib.ticker\n",
    "import numpy as np\n",
    "import data\n",
    "import jax.random\n",
    "%matplotlib inline"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false,
    "jupyter": {
     "outputs_hidden": false
    },
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "# Load environment variables\n",
    "dotenv.load_dotenv()\n",
    "\n",
    "# Enable loading of the project module (the src directory next to the working directory's parent)\n",
    "# NOTE(review): `import data` in the imports cell runs before this path is added -- confirm it resolves on a fresh kernel\n",
    "MODULE_DIR = os.path.join(os.path.abspath(os.path.join(os.path.curdir, os.path.pardir)), 'src')\n",
    "if MODULE_DIR not in sys.path:\n",
    "    # Guard keeps the cell idempotent: re-running it must not pile up duplicate path entries\n",
    "    sys.path.append(MODULE_DIR)\n",
    "\n",
    "# Resolve the directory all figures are written to; it has to be configured via the environment\n",
    "try:\n",
    "    PLOT_BASE_DIR = os.path.abspath(os.environ[\"PLOT_BASE_DIR\"])\n",
    "except KeyError as err:\n",
    "    raise RuntimeError(\"Missing plot output dir. Set the PLOT_BASE_DIR variable accordingly\") from err\n",
    "os.makedirs(PLOT_BASE_DIR, exist_ok=True)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "# Load data\n",
    "EXPERIMENT_NAME = \"paper_filter_size_synthetic\"\n",
    "client = mlflow.tracking.MlflowClient()\n",
    "experiment = client.get_experiment_by_name(EXPERIMENT_NAME)\n",
    "if experiment is None:\n",
    "    # Unknown experiment names yield None; fail here instead of with an AttributeError below\n",
    "    raise RuntimeError(f\"MLflow experiment '{EXPERIMENT_NAME}' not found\")\n",
    "\n",
    "# Collect one flat record (id, status, params, metrics, prefixed tags) per run, page by page\n",
    "token = None\n",
    "run_records = []\n",
    "while True:\n",
    "    runs = client.search_runs(experiment.experiment_id, page_token=token)\n",
    "    token = runs.token\n",
    "    for run in runs:\n",
    "        run_dict = {\n",
    "            \"id\": run.info.run_id,\n",
    "            \"status\": run.info.status,\n",
    "        }\n",
    "        run_dict.update(run.data.params)\n",
    "        run_dict.update(run.data.metrics)\n",
    "        run_dict.update({\n",
    "            f\"tags.{tag}\": value for tag, value in run.data.tags.items()\n",
    "        })\n",
    "        run_records.append(run_dict)\n",
    "\n",
    "    # An unset (None or empty) token means there are no further pages\n",
    "    if not token:\n",
    "        break\n",
    "\n",
    "if not run_records:\n",
    "    # Without any run the metric columns used below would not exist\n",
    "    raise RuntimeError(f\"MLflow experiment '{EXPERIMENT_NAME}' contains no runs\")\n",
    "\n",
    "df_runs = pd.DataFrame.from_records(run_records)\n",
    "\n",
    "# Columns printed when reporting runs that are excluded from the analysis\n",
    "COLUMNS_REPORT = [\n",
    "    \"id\", \"network_filter_size\", \"network_width\", \"noise_fraction\", \"seed\", \"training_seed\",\n",
    "    \"train_accuracy\", \"train_loss\"\n",
    "]\n",
    "\n",
    "# Keep only runs with a final train accuracy of exactly 1; report those below 1\n",
    "df_runs_noninterp = df_runs[df_runs.train_accuracy < 1]\n",
    "if len(df_runs_noninterp) > 0:\n",
    "    df_runs_noninterp = df_runs_noninterp[COLUMNS_REPORT]\n",
    "    print(f\"Found non-interpolating runs:\\n{df_runs_noninterp}\")\n",
    "df_runs = df_runs[df_runs.train_accuracy == 1.0]\n",
    "\n",
    "# Of the remaining runs, keep only those with status FINISHED; report the others\n",
    "df_runs_fail = df_runs[(df_runs.status != \"FINISHED\")]\n",
    "if len(df_runs_fail) > 0:\n",
    "    df_runs_fail = df_runs_fail[COLUMNS_REPORT]\n",
    "    print(f\"Found failed runs:\\n{df_runs_fail}\")\n",
    "df_runs = df_runs[df_runs.status == \"FINISHED\"]\n",
    "\n",
    "assert (df_runs.status == \"FINISHED\").all()\n",
    "assert (df_runs.train_accuracy == 1.0).all()\n",
    "\n",
    "print(f\"Have a total of {len(df_runs)} runs\")\n",
    "\n",
    "COLUMNS_SETTING = [\n",
    "    \"noise_fraction\", \"network_width\", \"network_filter_size\"\n",
    "]\n",
    "COLUMNS_SETTING_FULL = [\"seed\"] + COLUMNS_SETTING\n",
    "COLUMNS_METRICS = [\n",
    "    \"best_final_test_accuracy_diff\", \"best_test_accuracy\", \"test_accuracy\",\n",
    "    \"train_accuracy\", \"train_loss\"\n",
    "]\n",
    "COLUMNS = COLUMNS_SETTING_FULL + COLUMNS_METRICS\n",
    "\n",
    "df_runs = df_runs.set_index(\"id\")\n",
    "df_runs_full = df_runs\n",
    "# Select the analysis columns and cast the integer-valued parameters in one step; astype yields a\n",
    "# fresh frame, so the column assignments further down do not trigger SettingWithCopyWarning\n",
    "df_runs = df_runs[COLUMNS].astype({\"network_filter_size\": int, \"network_width\": int})\n",
    "\n",
    "# Load early-stopped train accuracy for all runs\n",
    "early_stopped_accuracies = {}\n",
    "early_stopped_noise_accuracies = {}\n",
    "early_stopping_epochs = {}\n",
    "num_runs = len(df_runs)\n",
    "for idx, run_id in enumerate(df_runs.index):\n",
    "    # Early-stopping point: the step with the highest logged test accuracy (first returned entry on ties)\n",
    "    best_test_accuracy = max(client.get_metric_history(run_id, key=\"test_accuracy\"), key=lambda metric: metric.value)\n",
    "    best_epoch = best_test_accuracy.step\n",
    "    early_stopping_epochs[run_id] = best_epoch\n",
    "    train_accuracies = client.get_metric_history(run_id, key=\"train_accuracy\")\n",
    "    # The single-element unpacking fails unless exactly one value was logged at that step\n",
    "    early_stopped_train_accuracy, = filter(lambda metric: metric.step == best_epoch, train_accuracies)\n",
    "    early_stopped_accuracies[run_id] = early_stopped_train_accuracy.value\n",
    "\n",
    "    # Only runs with noise fraction 0.2 are queried for the noisy-subset accuracy; all others get 0.0\n",
    "    if df_runs.loc[run_id, \"noise_fraction\"] == \"0.2\":\n",
    "        train_noise_accuracies = client.get_metric_history(run_id, key=\"train_noise_accuracy\")\n",
    "        early_stopped_noise_accuracy, = filter(lambda metric: metric.step == best_epoch, train_noise_accuracies)\n",
    "        early_stopped_noise_accuracies[run_id] = early_stopped_noise_accuracy.value\n",
    "    else:\n",
    "        early_stopped_noise_accuracies[run_id] = 0.0\n",
    "\n",
    "    if (idx + 1) % 1000 == 0 or (idx + 1) == num_runs:\n",
    "        print(f\"Processed {idx + 1}/{num_runs} runs\")\n",
    "\n",
    "df_runs[\"es_train_accuracy\"] = pd.Series(early_stopped_accuracies)\n",
    "df_runs[\"es_train_noise_accuracy\"] = pd.Series(early_stopped_noise_accuracies)\n",
    "# Solve  total = p * noisy + (1 - p) * clean  for the clean-subset accuracy (p = noise fraction)\n",
    "df_runs[\"es_train_noiseless_accuracy\"] = (\n",
    "    df_runs[\"es_train_accuracy\"] - df_runs[\"noise_fraction\"].astype(float) * df_runs[\"es_train_noise_accuracy\"]\n",
    ") / (1.0 - df_runs[\"noise_fraction\"].astype(float))\n",
    "df_runs[\"es_train_epoch\"] = pd.Series(early_stopping_epochs)\n",
    "\n",
    "# Aggregate over inner seeds; treat as one quantity\n",
    "df_agg = df_runs.groupby(COLUMNS_SETTING_FULL).agg(\"mean\").reset_index()"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false,
    "jupyter": {
     "outputs_hidden": false
    },
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "# Imported here rather than in the imports cell: presumably a project module that needs MODULE_DIR on sys.path -- confirm\n",
    "import plot_util\n",
    "# Apply the project's matplotlib configuration before any figure is created\n",
    "plot_util.setup_matplotlib()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "# Line color per noise fraction; keys are the noise fractions as strings, values matplotlib cycle colors\n",
    "NOISE_COLOR_MAP = {\n",
    "    noise_fraction: f\"C{color_idx}\" for color_idx, noise_fraction in enumerate((\"0.0\", \"0.2\"))\n",
    "}\n",
    "\n",
    "# Line style per curve kind: \"int\" takes the first entry of plot_util.LINESTYLE_MAP, \"es\" the second\n",
    "ES_LINESTYLE_MAP = {\n",
    "    curve_kind: plot_util.LINESTYLE_MAP[style_idx] for style_idx, curve_kind in enumerate((\"int\", \"es\"))\n",
    "}\n"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "markdown",
   "source": [
    "## Test error curve"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%% md\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "fig, ax = plt.subplots(\n",
    "    figsize=(plot_util.FIGURE_WIDTH_FOR_COLUMNS_IN[6], plot_util.FIGURE_WIDTH_FOR_COLUMNS_IN[6] / plot_util.GOLDEN_RATIO)\n",
    ")\n",
    "\n",
    "# Mean and standard error across seeds per (noise fraction, filter size) for networks of width 128.\n",
    "# The seed column is dropped explicitly: it is a run parameter that was never cast to a number, and\n",
    "# recent pandas versions raise on non-numeric columns in mean/sem instead of silently skipping them.\n",
    "df_test_error = df_agg.loc[df_agg.network_width == 128].drop(\n",
    "    [\"network_width\", \"seed\"], axis=1\n",
    ").groupby([\"noise_fraction\", \"network_filter_size\"]).agg([\"mean\", \"sem\"])\n",
    "for noise_fraction in df_test_error.index.get_level_values(0).unique():\n",
    "    df_setting_noise = df_test_error.loc[noise_fraction]\n",
    "    xs = df_setting_noise.index  # filter sizes\n",
    "    current_color = NOISE_COLOR_MAP[noise_fraction]\n",
    "    # One curve for the final test accuracy (\"int\" style) and one for the best test accuracy (\"es\" style);\n",
    "    # this order determines the handle order the legend below relies on\n",
    "    for metric_name, style_key in ((\"test_accuracy\", \"int\"), (\"best_test_accuracy\", \"es\")):\n",
    "        ys_mean = 1.0 - df_setting_noise[metric_name][\"mean\"]\n",
    "        ys_err = df_setting_noise[metric_name][\"sem\"]\n",
    "        ax.fill_between(\n",
    "            xs,\n",
    "            ys_mean - ys_err,\n",
    "            ys_mean + ys_err,\n",
    "            color=current_color,\n",
    "            alpha=0.2,\n",
    "        )\n",
    "        ax.plot(\n",
    "            xs,\n",
    "            ys_mean,\n",
    "            c=current_color,\n",
    "            # Raw string: the backslash escapes the percent sign for the text renderer, not for Python\n",
    "            label=fr\"{float(noise_fraction)*100:.0f}\\% noise\",\n",
    "            ls=ES_LINESTYLE_MAP[style_key],\n",
    "        )\n",
    "\n",
    "ax.set_ylim(0.0 - 1e-2, 0.45)\n",
    "ax.set_yticks((\n",
    "    0.0,\n",
    "    0.06,\n",
    "    0.1,\n",
    "    0.45\n",
    "))\n",
    "ax.yaxis.set_major_formatter(matplotlib.ticker.PercentFormatter(xmax=1.0, decimals=0))\n",
    "\n",
    "ax.set_xlim(5, 31)\n",
    "ax.set_xticks((5, 13, 22, 31))\n",
    "ax.set_xlabel(r\"filter size\")\n",
    "ax.set_ylabel(\"test error\")\n",
    "ax.xaxis.set_minor_locator(matplotlib.ticker.FixedLocator(tuple(range(5, 32))))\n",
    "\n",
    "plt.show()\n",
    "\n",
    "fig.savefig(os.path.join(PLOT_BASE_DIR, \"filters_error_synthetic.pdf\"))\n",
    "\n",
    "plt.close(fig)\n",
    "\n",
    "# Legend, saved as its own figure; the axes object still provides its handles after the figure is closed\n",
    "handles, labels = ax.get_legend_handles_labels()\n",
    "proxy_patch = plot_util.proxy_patch()\n",
    "# Two proxy entries fill the first legend column and carry the row titles\n",
    "handles = [proxy_patch, proxy_patch] + handles\n",
    "labels = [\"interpolating:\", r\"opt.\\ early-stopped:\"] + labels\n",
    "\n",
    "legend_fig = plt.figure(figsize=(\n",
    "    plot_util.FIGURE_WIDTH_FOR_COLUMNS_IN[6],\n",
    "    plot_util.LEGEND_HEIGHT_FOR_ROWS_IN[2],\n",
    "))\n",
    "\n",
    "legend_fig.legend(\n",
    "    handles,\n",
    "    labels,\n",
    "    loc='center right',\n",
    "    ncol=3,\n",
    "    frameon=False,\n",
    "    borderpad=0.5,\n",
    "    # handletextpad=0.3,\n",
    "    columnspacing=1.0,\n",
    ")\n",
    "\n",
    "legend_fig.savefig(os.path.join(PLOT_BASE_DIR, \"filters_error_synthetic_legend.pdf\"))\n",
    "\n",
    "plt.close(legend_fig)\n"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "markdown",
   "source": [
    "## Early-stopped train error curves"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%% md\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "fig, ax = plt.subplots(\n",
    "    figsize=(\n",
    "        plot_util.FIGURE_WIDTH_FOR_COLUMNS_IN[6], plot_util.FIGURE_WIDTH_FOR_COLUMNS_IN[6] / plot_util.GOLDEN_RATIO\n",
    "    )\n",
    ")\n",
    "\n",
    "# Mean and standard error across seeds per (noise fraction, filter size) for networks of width 128.\n",
    "# The seed column is dropped explicitly: it is a run parameter that was never cast to a number, and\n",
    "# recent pandas versions raise on non-numeric columns in mean/sem instead of silently skipping them.\n",
    "df_es_train_error = df_agg.loc[df_agg.network_width == 128].drop(\n",
    "    [\"network_width\", \"seed\"], axis=1\n",
    ").groupby([\"noise_fraction\", \"network_filter_size\"]).agg([\"mean\", \"sem\"])\n",
    "for noise_fraction in df_es_train_error.index.get_level_values(0).unique():\n",
    "    # Runs without label noise are not plotted here\n",
    "    if noise_fraction == \"0.0\":\n",
    "        continue\n",
    "\n",
    "    df_setting_noise = df_es_train_error.loc[noise_fraction]\n",
    "    xs = df_setting_noise.index  # filter sizes\n",
    "\n",
    "    # Noisy subset first, clean subset second; the legend below indexes the handles in this order\n",
    "    for metric_name, current_color, curve_label in (\n",
    "        (\"es_train_noise_accuracy\", \"C3\", \"noisy subset\"),\n",
    "        (\"es_train_noiseless_accuracy\", \"C2\", \"clean subset\"),\n",
    "    ):\n",
    "        ys_mean = 1.0 - df_setting_noise[metric_name][\"mean\"]\n",
    "        ys_err = df_setting_noise[metric_name][\"sem\"]\n",
    "        ax.fill_between(xs, ys_mean - ys_err, ys_mean + ys_err, color=current_color, alpha=0.2)\n",
    "        ax.plot(\n",
    "            xs,\n",
    "            ys_mean,\n",
    "            c=current_color,\n",
    "            label=curve_label,\n",
    "            ls=ES_LINESTYLE_MAP[\"es\"],\n",
    "        )\n",
    "\n",
    "\n",
    "ax.set_ylim(0.0 - 1e-2, 1.0 + 1e-2)\n",
    "ax.set_yticks((\n",
    "    0.0,\n",
    "    1.0\n",
    "))\n",
    "ax.yaxis.set_major_formatter(matplotlib.ticker.PercentFormatter(xmax=1.0, decimals=0))\n",
    "\n",
    "ax.set_xlim(5, 31)\n",
    "ax.set_xticks((5, 13, 22, 31))\n",
    "ax.set_xlabel(r\"filter size\")\n",
    "ax.set_ylabel(r\"subset train error \\\\(optimal early stopping)\")\n",
    "ax.xaxis.set_minor_locator(matplotlib.ticker.FixedLocator(tuple(range(5, 32))))\n",
    "\n",
    "plt.show()\n",
    "\n",
    "fig.savefig(os.path.join(PLOT_BASE_DIR, \"filters_error_es.pdf\"))\n",
    "\n",
    "plt.close(fig)\n",
    "\n",
    "# Legend, saved as its own figure; each curve handle is preceded by an unlabeled proxy entry\n",
    "handles, labels = ax.get_legend_handles_labels()\n",
    "proxy_patch = plot_util.proxy_patch()\n",
    "handles = [proxy_patch, handles[0], proxy_patch, handles[1]]\n",
    "labels = [\"\", labels[0], \"\", labels[1]]\n",
    "\n",
    "legend_fig = plt.figure(figsize=(\n",
    "    plot_util.FIGURE_WIDTH_FOR_COLUMNS_IN[6],\n",
    "    plot_util.LEGEND_HEIGHT_FOR_ROWS_IN[2],\n",
    "))\n",
    "\n",
    "legend_fig.legend(\n",
    "    handles,\n",
    "    labels,\n",
    "    loc='center',\n",
    "    ncol=2,\n",
    "    frameon=False,\n",
    "    borderpad=0.5,\n",
    ")\n",
    "\n",
    "legend_fig.savefig(os.path.join(PLOT_BASE_DIR, \"filters_error_es_legend.pdf\"))\n",
    "\n",
    "plt.close(legend_fig)\n"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "markdown",
   "source": [
    "## Example data samples"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%% md\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "images_per_class = 2\n",
    "\n",
    "# Settings of the dataset the example images are drawn from\n",
    "dataset_config = dict(\n",
    "    image_size=32,\n",
    "    shape_size=5,\n",
    "    min_shape_size=3,\n",
    "    num_shapes_per_sample=10,\n",
    "    force_inside=True,\n",
    "    use_squares=False,\n",
    "    use_background=False,\n",
    "    cache_dir=None,\n",
    ")\n",
    "\n",
    "# Build the dataset and draw the samples with the first CPU device as JAX default device\n",
    "cpu_device = jax.devices(\"cpu\")[0]\n",
    "with jax.default_device(cpu_device):\n",
    "    dataset = data.ShapeDataset(**dataset_config)\n",
    "    # Fixed PRNG key, so the same samples are drawn on every run\n",
    "    sample_key = jax.random.PRNGKey(42)\n",
    "    xs, ys = dataset.generate_samples(2 * images_per_class, key=sample_key)\n"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "# Sanity-check the generated batch: 2 * images_per_class single-channel 32x32 images with one label each\n",
    "assert xs.shape == (2 * images_per_class, 32, 32, 1)\n",
    "assert ys.shape == (2 * images_per_class, 1)\n",
    "\n",
    "# One figure per name, each showing images_per_class consecutive samples side by side\n",
    "# NOTE(review): assumes the first images_per_class samples are the positive ones -- confirm against generate_samples\n",
    "for idx, name in ((0, \"positive\"), (1, \"negative\")):\n",
    "    fig, axes = plt.subplots(\n",
    "        1, images_per_class,\n",
    "        # Always get a 2D array of axes, so indexing also works for images_per_class == 1\n",
    "        squeeze=False,\n",
    "        figsize=(\n",
    "            plot_util.FIGURE_WIDTH_FOR_COLUMNS_IN[6], plot_util.FIGURE_WIDTH_FOR_COLUMNS_IN[6] / plot_util.GOLDEN_RATIO\n",
    "        )\n",
    "    )\n",
    "    for img_idx in range(images_per_class):\n",
    "        ax = axes[0, img_idx]\n",
    "        ax.imshow(\n",
    "            xs[idx * images_per_class + img_idx, :, :, 0], cmap=\"gray\"\n",
    "        )\n",
    "        ax.grid(visible=False)\n",
    "        ax.set_xticks((0, 32))\n",
    "        ax.set_yticks((0, 32))\n",
    "        ax.set_xlim(0, 32)\n",
    "        # NOTE(review): with imshow's default origin this ascending y-range shows the image vertically flipped -- confirm intended\n",
    "        ax.set_ylim(0, 32)\n",
    "\n",
    "    plt.show()\n",
    "\n",
    "    fig.savefig(os.path.join(PLOT_BASE_DIR, f\"filters_data_{name}.pdf\"))\n",
    "\n",
    "    plt.close(fig)\n"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "markdown",
   "source": [
    "## Differentiation from Deep Double Descent"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%% md\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "diff_xrange = (5, 27)\n",
    "diff_xticks = (5, 13, 23, 27)\n",
    "diff_xticks_minor = (5, 7, 9, 11, 13, 15, 17, 19, 21, 23, 27)"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "# One figure per noise fraction; the mapped value is the suffix inserted into the output file name\n",
    "for noise_fraction, noise_label in {\"0.0\": \"_noiseless\", \"0.2\": \"\"}.items():\n",
    "    fig, ax = plt.subplots(\n",
    "        figsize=(plot_util.FIGURE_WIDTH_FOR_COLUMNS_IN[4], plot_util.FIGURE_WIDTH_FOR_COLUMNS_IN[6] / plot_util.GOLDEN_RATIO)\n",
    "    )\n",
    "\n",
    "    # Mean and standard error across seeds per (width, filter size) for the current noise fraction.\n",
    "    # The seed column is dropped explicitly: it is a run parameter that was never cast to a number, and\n",
    "    # recent pandas versions raise on non-numeric columns in mean/sem instead of silently skipping them.\n",
    "    df_ablation_error = df_agg[df_agg.noise_fraction == noise_fraction].drop(\n",
    "        [\"noise_fraction\", \"seed\"], axis=1\n",
    "    ).groupby(\n",
    "        [\"network_width\", \"network_filter_size\"]\n",
    "    ).agg([\"mean\", \"sem\"])\n",
    "    for width_idx, width in enumerate(df_ablation_error.index.get_level_values(0).unique()):\n",
    "        df_setting_width = df_ablation_error.loc[width]\n",
    "        xs = df_setting_width.index  # filter sizes\n",
    "        current_color = f\"C{width_idx}\"  # TODO\n",
    "        # Per width: final test accuracy (\"int\" style) first, best test accuracy (\"es\" style) second;\n",
    "        # this order determines the handle order the legend below relies on\n",
    "        for metric_name, style_key in ((\"test_accuracy\", \"int\"), (\"best_test_accuracy\", \"es\")):\n",
    "            ys_mean = 1.0 - df_setting_width[metric_name][\"mean\"]\n",
    "            ys_err = df_setting_width[metric_name][\"sem\"]\n",
    "            ax.fill_between(\n",
    "                xs,\n",
    "                ys_mean - ys_err,\n",
    "                ys_mean + ys_err,\n",
    "                color=current_color,\n",
    "                alpha=0.2,\n",
    "            )\n",
    "            ax.plot(\n",
    "                xs,\n",
    "                ys_mean,\n",
    "                c=current_color,\n",
    "                label=str(width),\n",
    "                ls=ES_LINESTYLE_MAP[style_key],\n",
    "            )\n",
    "\n",
    "    ax.set_ylim(0.0 - 1e-2, 0.25)\n",
    "    ax.set_yticks((\n",
    "        0.0,\n",
    "        0.06,\n",
    "        0.1,\n",
    "        0.25,\n",
    "    ))\n",
    "    ax.yaxis.set_major_formatter(matplotlib.ticker.PercentFormatter(xmax=1.0, decimals=0))\n",
    "    ax.set_ylabel(\"test error\")\n",
    "\n",
    "    ax.set_xlim(*diff_xrange)\n",
    "    ax.set_xticks(diff_xticks)\n",
    "    ax.xaxis.set_minor_locator(matplotlib.ticker.FixedLocator(diff_xticks_minor))\n",
    "    ax.set_xlabel(r\"filter size\")\n",
    "\n",
    "    plt.show()\n",
    "\n",
    "    fig.savefig(os.path.join(PLOT_BASE_DIR, f\"ablation_filters{noise_label}_error.pdf\"))\n",
    "\n",
    "    plt.close(fig)\n",
    "\n",
    "# Legend, built once from the axes of the last figure drawn by the loop above\n",
    "handles, labels = ax.get_legend_handles_labels()\n",
    "proxy_patch = plot_util.proxy_patch()\n",
    "# Two proxy entries fill the first legend column and carry the row titles\n",
    "handles = [proxy_patch, proxy_patch] + handles\n",
    "labels = [\"interpolating, width:\", \"optimal early-stopped, width:\"] + labels\n",
    "\n",
    "legend_fig = plt.figure(figsize=(\n",
    "    2 * plot_util.FIGURE_WIDTH_FOR_COLUMNS_IN[4],\n",
    "    plot_util.LEGEND_HEIGHT_FOR_ROWS_IN[2],\n",
    "))\n",
    "\n",
    "legend_fig.legend(\n",
    "    handles,\n",
    "    labels,\n",
    "    loc='center left',\n",
    "    ncol=4,\n",
    "    frameon=False,\n",
    "    borderpad=0.5,\n",
    "    handletextpad=0.5,\n",
    "    columnspacing=0.8,\n",
    ")\n",
    "\n",
    "legend_fig.savefig(os.path.join(PLOT_BASE_DIR, \"ablation_filters_error_legend.pdf\"))\n",
    "\n",
    "plt.close(legend_fig)\n"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "fig, ax = plt.subplots(\n",
    "    figsize=(\n",
    "        plot_util.FIGURE_WIDTH_FOR_COLUMNS_IN[4], plot_util.FIGURE_WIDTH_FOR_COLUMNS_IN[6] / plot_util.GOLDEN_RATIO\n",
    "    )\n",
    ")\n",
    "\n",
    "# Mean and standard error across seeds per (width, filter size) for the runs with noise fraction 0.2.\n",
    "# The seed column is dropped explicitly: it is a run parameter that was never cast to a number, and\n",
    "# recent pandas versions raise on non-numeric columns in mean/sem instead of silently skipping them.\n",
    "df_ablation_es = df_agg[df_agg.noise_fraction == \"0.2\"].drop(\n",
    "    [\"noise_fraction\", \"seed\"], axis=1\n",
    ").groupby(\n",
    "    [\"network_width\", \"network_filter_size\"]\n",
    ").agg([\"mean\", \"sem\"])\n",
    "for width_idx, width in enumerate(df_ablation_es.index.get_level_values(0).unique()):\n",
    "    df_setting_width = df_ablation_es.loc[width]\n",
    "    xs = df_setting_width.index  # filter sizes\n",
    "    current_color = f\"C{width_idx}\"\n",
    "\n",
    "    # Per width: noisy-subset curve (\"int\" line style) first, clean-subset curve (\"es\" line style) second;\n",
    "    # this order determines the handle order the legend below relies on\n",
    "    for metric_name, style_key in ((\"es_train_noise_accuracy\", \"int\"), (\"es_train_noiseless_accuracy\", \"es\")):\n",
    "        ys_mean = 1.0 - df_setting_width[metric_name][\"mean\"]\n",
    "        ys_err = df_setting_width[metric_name][\"sem\"]\n",
    "        ax.fill_between(xs, ys_mean - ys_err, ys_mean + ys_err, color=current_color, alpha=0.2)\n",
    "        ax.plot(\n",
    "            xs,\n",
    "            ys_mean,\n",
    "            c=current_color,\n",
    "            label=str(width),\n",
    "            ls=ES_LINESTYLE_MAP[style_key],\n",
    "        )\n",
    "\n",
    "ax.set_ylim(0.0 - 1e-2, 1.0 + 1e-2)\n",
    "ax.set_yticks((\n",
    "    0.0,\n",
    "    1.0\n",
    "))\n",
    "ax.yaxis.set_major_formatter(matplotlib.ticker.PercentFormatter(xmax=1.0, decimals=0))\n",
    "ax.set_ylabel(r\"subset train error \\\\(optimal early stopping)\")\n",
    "\n",
    "ax.set_xlim(*diff_xrange)\n",
    "ax.set_xticks(diff_xticks)\n",
    "ax.set_xlabel(r\"filter size\")\n",
    "ax.xaxis.set_minor_locator(matplotlib.ticker.FixedLocator(diff_xticks_minor))\n",
    "\n",
    "plt.show()\n",
    "\n",
    "fig.savefig(os.path.join(PLOT_BASE_DIR, \"ablation_filters_es.pdf\"))\n",
    "\n",
    "plt.close(fig)\n",
    "\n",
    "# Legend, saved as its own figure; the axes object still provides its handles after the figure is closed\n",
    "handles, labels = ax.get_legend_handles_labels()\n",
    "proxy_patch = plot_util.proxy_patch()\n",
    "# Two proxy entries fill the first legend column and carry the row titles\n",
    "handles = [proxy_patch, proxy_patch] + handles\n",
    "labels = [\"noisy:\", \"clean:\"] + labels\n",
    "\n",
    "legend_fig = plt.figure(figsize=(\n",
    "    plot_util.FIGURE_WIDTH_FOR_COLUMNS_IN[4],\n",
    "    plot_util.LEGEND_HEIGHT_FOR_ROWS_IN[2],\n",
    "))\n",
    "\n",
    "legend_fig.legend(\n",
    "    handles,\n",
    "    labels,\n",
    "    loc='center right',\n",
    "    ncol=4,\n",
    "    frameon=False,\n",
    "    borderpad=0.5,\n",
    "    handletextpad=0.5,\n",
    "    columnspacing=0.8,\n",
    ")\n",
    "\n",
    "legend_fig.savefig(os.path.join(PLOT_BASE_DIR, \"ablation_filters_es_legend.pdf\"))\n",
    "\n",
    "plt.close(legend_fig)\n"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "markdown",
   "source": [
    "## Training losses"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%% md\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "fig, ax = plt.subplots(\n",
    "    figsize=(plot_util.FIGURE_WIDTH_FOR_COLUMNS_IN[6], plot_util.FIGURE_WIDTH_FOR_COLUMNS_IN[6] / plot_util.GOLDEN_RATIO)\n",
    ")\n",
    "\n",
    "# Widths form the outer loop and noise fractions the inner one; the legend below relies on this handle order\n",
    "for width_idx, network_width in enumerate((128, 256, 512)):\n",
    "    # Mean and standard error across seeds per (noise fraction, filter size) for the current width.\n",
    "    # The seed column is dropped explicitly: it is a run parameter that was never cast to a number, and\n",
    "    # recent pandas versions raise on non-numeric columns in mean/sem instead of silently skipping them.\n",
    "    # A dedicated name keeps the frame of the test-error cell from being overwritten.\n",
    "    df_train_loss = df_agg.loc[df_agg.network_width == network_width].drop(\n",
    "        [\"network_width\", \"seed\"], axis=1\n",
    "    ).groupby([\"noise_fraction\", \"network_filter_size\"]).agg([\"mean\", \"sem\"])\n",
    "    current_linestyle = plot_util.LINESTYLE_MAP[width_idx]\n",
    "    for noise_fraction in df_train_loss.index.get_level_values(0).unique():\n",
    "        df_setting_noise = df_train_loss.loc[noise_fraction]\n",
    "        xs = df_setting_noise.index  # filter sizes\n",
    "        ys_mean = df_setting_noise[\"train_loss\"][\"mean\"]\n",
    "        ys_err = df_setting_noise[\"train_loss\"][\"sem\"]\n",
    "        current_color = NOISE_COLOR_MAP[noise_fraction]\n",
    "        ax.fill_between(\n",
    "            xs,\n",
    "            ys_mean - ys_err,\n",
    "            ys_mean + ys_err,\n",
    "            color=current_color,\n",
    "            alpha=0.2,\n",
    "        )\n",
    "        ax.plot(\n",
    "            xs,\n",
    "            ys_mean,\n",
    "            c=current_color,\n",
    "            label=str(network_width),\n",
    "            ls=current_linestyle,\n",
    "        )\n",
    "\n",
    "ax.set_ylim(0.0, 3e-3)\n",
    "ax.set_yticks((\n",
    "    0.0,\n",
    "    1.5e-3,\n",
    "    2.4e-3,\n",
    "))\n",
    "ax.ticklabel_format(style=\"sci\", scilimits=(0, 0), axis=\"y\")\n",
    "ax.set_xlim(5, 31)\n",
    "ax.set_xticks((5, 13, 23, 31))\n",
    "ax.xaxis.set_minor_locator(matplotlib.ticker.FixedLocator((5, 7, 9, 11, 13, 15, 17, 19, 21, 23, 27, 31)))\n",
    "\n",
    "ax.set_xlabel(r\"filter size\")\n",
    "ax.set_ylabel(\"training loss\")\n",
    "plt.show()\n",
    "\n",
    "fig.savefig(os.path.join(PLOT_BASE_DIR, \"loss_filters.pdf\"))\n",
    "\n",
    "plt.close(fig)\n",
    "\n",
    "# Legend, saved as its own figure; the axes object still provides its handles after the figure is closed\n",
    "legend_fig = plt.figure(figsize=(\n",
    "    plot_util.FIGURE_WIDTH_FOR_COLUMNS_IN[6],\n",
    "    plot_util.LEGEND_HEIGHT_FOR_ROWS_IN[2],\n",
    "))\n",
    "handles, labels = ax.get_legend_handles_labels()\n",
    "proxy_patch = plot_util.proxy_patch()\n",
    "# Two proxy entries fill the first legend column and carry the row titles\n",
    "handles = [proxy_patch, proxy_patch] + handles\n",
    "labels = [r\"$0\\%$ noise, width:\", r\"$20\\%$ noise, width:\"] + labels\n",
    "legend_fig.legend(\n",
    "    handles,\n",
    "    labels,\n",
    "    loc='center right',\n",
    "    ncol=4,\n",
    "    frameon=False,\n",
    "    borderpad=0.5,  # TODO\n",
    "    handletextpad=0.3,\n",
    "    columnspacing=0.7,\n",
    ")\n",
    "\n",
    "legend_fig.savefig(os.path.join(PLOT_BASE_DIR, \"loss_filters_legend.pdf\"))\n",
    "\n",
    "plt.close(legend_fig)\n"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}