{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false,
    "jupyter": {
     "outputs_hidden": false
    },
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "# Reload modified modules automatically before executing each cell\n",
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false,
    "jupyter": {
     "outputs_hidden": false
    },
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "import os\n",
    "import sys\n",
    "import mlflow\n",
    "import dotenv\n",
    "import pandas as pd\n",
    "import matplotlib.pyplot as plt\n",
    "import matplotlib.ticker\n",
    "import numpy as np\n",
    "import data\n",
    "import util\n",
    "%matplotlib inline"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false,
    "jupyter": {
     "outputs_hidden": false
    },
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "# Load environment variables\n",
    "dotenv.load_dotenv()\n",
    "\n",
    "# Enable loading of the project module\n",
    "MODULE_DIR = os.path.join(os.path.abspath(os.path.join(os.path.curdir, os.path.pardir)), 'src')\n",
    "# Guard the append so that re-running this cell does not pile up duplicate entries\n",
    "if MODULE_DIR not in sys.path:\n",
    "    sys.path.append(MODULE_DIR)\n",
    "\n",
    "# Figures are saved below PLOT_BASE_DIR, so fail early if it is not configured\n",
    "try:\n",
    "    PLOT_BASE_DIR = os.path.abspath(os.environ[\"PLOT_BASE_DIR\"])\n",
    "except KeyError:\n",
    "    raise RuntimeError(\"Missing plot output dir. Set the PLOT_BASE_DIR variable accordingly\")\n",
    "os.makedirs(PLOT_BASE_DIR, exist_ok=True)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "# Load data\n",
    "EXPERIMENT_NAME = \"paper_rotations_eurosat\"\n",
    "client = mlflow.tracking.MlflowClient()\n",
    "experiment = client.get_experiment_by_name(EXPERIMENT_NAME)\n",
    "# get_experiment_by_name yields None for an unknown name; fail with a clear message\n",
    "# instead of an AttributeError on experiment.experiment_id below\n",
    "if experiment is None:\n",
    "    raise RuntimeError(f\"MLflow experiment {EXPERIMENT_NAME!r} does not exist\")\n",
    "\n",
    "# Fetch all runs page by page and flatten each run into one record.\n",
    "# Params, metrics and prefixed tags share a single namespace; later entries win.\n",
    "token = None\n",
    "run_records = []\n",
    "while True:\n",
    "    # search_runs is documented to take a list of experiment ids\n",
    "    runs = client.search_runs([experiment.experiment_id], page_token=token)\n",
    "    token = runs.token\n",
    "    for run in runs:\n",
    "        run_records.append({\n",
    "            \"id\": run.info.run_id,\n",
    "            \"status\": run.info.status,\n",
    "            **run.data.params,\n",
    "            **run.data.metrics,\n",
    "            **{f\"tags.{tag}\": value for tag, value in run.data.tags.items()},\n",
    "        })\n",
    "\n",
    "    # A missing (None or empty) token marks the last page\n",
    "    if not token:\n",
    "        break\n",
    "\n",
    "df_runs = pd.DataFrame.from_records(run_records)\n",
    "\n",
    "# Sanity checks: all runs must share seed 1 and widen factor 6\n",
    "assert df_runs[\"seed\"].astype(int).eq(1).all()\n",
    "assert df_runs[\"network_builder_widen_factor\"].eq(\"6\").all()\n",
    "\n",
    "# Columns printed for runs that are excluded from the analysis\n",
    "excluded_run_columns = [\n",
    "    \"id\", \"num_train_rotations\", \"noise_fraction\", \"training_seed\", \"train_accuracy\", \"train_loss\"\n",
    "]\n",
    "\n",
    "# Report, then drop, runs which do not reach a training accuracy of exactly 1\n",
    "df_runs_noninterp = df_runs.loc[df_runs[\"train_accuracy\"] < 1]\n",
    "if not df_runs_noninterp.empty:\n",
    "    df_runs_noninterp = df_runs_noninterp[excluded_run_columns]\n",
    "    print(f\"Found non-interpolating runs:\\n{df_runs_noninterp}\")\n",
    "df_runs = df_runs.loc[df_runs[\"train_accuracy\"] == 1.0]\n",
    "\n",
    "# Report, then drop, runs whose MLflow status is anything but FINISHED\n",
    "df_runs_fail = df_runs.loc[df_runs[\"status\"] != \"FINISHED\"]\n",
    "if not df_runs_fail.empty:\n",
    "    df_runs_fail = df_runs_fail[excluded_run_columns]\n",
    "    print(f\"Found failed runs:\\n{df_runs_fail}\")\n",
    "df_runs = df_runs.loc[df_runs[\"status\"] == \"FINISHED\"]\n",
    "\n",
    "# Double-check that only finished, interpolating runs remain\n",
    "assert df_runs[\"status\"].eq(\"FINISHED\").all()\n",
    "assert df_runs[\"train_accuracy\"].eq(1.0).all()\n",
    "\n",
    "print(f\"Have a total of {len(df_runs)} runs\")\n",
    "\n",
    "# Parameters that define an experimental setting\n",
    "COLUMNS_SETTING = [\n",
    "    \"noise_fraction\", \"num_train_rotations\"\n",
    "]\n",
    "# Metrics that are analysed per setting\n",
    "COLUMNS_METRICS = [\n",
    "    \"best_final_test_accuracy_diff\", \"best_test_accuracy\", \"test_accuracy_random\",\n",
    "    \"train_accuracy\", \"train_loss\"\n",
    "]\n",
    "COLUMNS = COLUMNS_SETTING + COLUMNS_METRICS\n",
    "\n",
    "df_runs = df_runs.set_index(\"id\")\n",
    "# Keep a handle on the frame with all params, metrics and tags\n",
    "df_runs_full = df_runs\n",
    "# Explicit copy: columns are converted and added afterwards, which on a plain\n",
    "# selection of another frame triggers pandas' SettingWithCopyWarning\n",
    "df_runs = df_runs[COLUMNS].copy()\n",
    "\n",
    "# MLflow params are strings; the rotation count is used as a numeric x-axis\n",
    "df_runs[\"num_train_rotations\"] = df_runs[\"num_train_rotations\"].astype(int)\n",
    "\n",
    "def _metric_value_at_step(run_id, metric_key, step):\n",
    "    \"\"\"Return the value of `metric_key` which run `run_id` logged at `step`.\n",
    "\n",
    "    Raises a ValueError unless exactly one history entry matches the step.\n",
    "    \"\"\"\n",
    "    (entry,) = (metric for metric in client.get_metric_history(run_id, key=metric_key) if metric.step == step)\n",
    "    return entry.value\n",
    "\n",
    "\n",
    "# Load early-stopped train accuracy for all runs\n",
    "early_stopped_accuracies = {}\n",
    "early_stopped_noise_accuracies = {}\n",
    "early_stopping_epochs = {}\n",
    "total_runs = len(df_runs)\n",
    "for run_number, run_id in enumerate(df_runs.index, start=1):\n",
    "    # Early stopping epoch: the step with the highest logged test accuracy\n",
    "    test_history = client.get_metric_history(run_id, key=\"test_accuracy_random\")\n",
    "    best_epoch = max(test_history, key=lambda metric: metric.value).step\n",
    "    early_stopping_epochs[run_id] = best_epoch\n",
    "    early_stopped_accuracies[run_id] = _metric_value_at_step(run_id, \"train_accuracy\", best_epoch)\n",
    "\n",
    "    # The noisy-subset accuracy is only fetched for runs with noise fraction 0.2;\n",
    "    # every other run gets 0.0\n",
    "    if df_runs.loc[run_id, \"noise_fraction\"] == \"0.2\":\n",
    "        early_stopped_noise_accuracies[run_id] = _metric_value_at_step(\n",
    "            run_id, \"train_noise_accuracy\", best_epoch\n",
    "        )\n",
    "    else:\n",
    "        early_stopped_noise_accuracies[run_id] = 0.0\n",
    "\n",
    "    if run_number % 1000 == 0 or run_number == total_runs:\n",
    "        print(f\"Processed {run_number}/{total_runs} runs\")\n",
    "\n",
    "df_runs[\"es_train_accuracy\"] = pd.Series(early_stopped_accuracies)\n",
    "df_runs[\"es_train_noise_accuracy\"] = pd.Series(early_stopped_noise_accuracies)\n",
    "# Solve overall = noise * noisy + (1 - noise) * clean for the clean-subset accuracy\n",
    "noise_fractions = df_runs[\"noise_fraction\"].astype(float)\n",
    "noisy_contribution = noise_fractions * df_runs[\"es_train_noise_accuracy\"]\n",
    "df_runs[\"es_train_noiseless_accuracy\"] = (df_runs[\"es_train_accuracy\"] - noisy_contribution) / (1.0 - noise_fractions)\n",
    "df_runs[\"es_train_epoch\"] = pd.Series(early_stopping_epochs)\n"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false,
    "jupyter": {
     "outputs_hidden": false
    },
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "# NOTE(review): plot_util is presumably a project module found via the sys.path entry\n",
    "# added earlier in this notebook, so this import must run after that cell - confirm\n",
    "import plot_util\n",
    "plot_util.setup_matplotlib()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "# Line color per noise fraction; keys are the string-valued noise_fraction run parameter\n",
    "NOISE_COLOR_MAP = {\n",
    "    noise_fraction: f\"C{color_idx}\" for color_idx, noise_fraction in enumerate((\"0.0\", \"0.2\"))\n",
    "}\n",
    "# Line style per model variant: \"int\" = interpolating, \"es\" = early-stopped\n",
    "ES_LINESTYLE_MAP = {\n",
    "    variant: plot_util.LINESTYLE_MAP[style_idx] for style_idx, variant in enumerate((\"int\", \"es\"))\n",
    "}\n",
    "# Largest number of training rotations shown on the x-axis\n",
    "MAX_ROT = 12"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "markdown",
   "source": [
    "## Test error curve"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%% md\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "fig, ax = plt.subplots(\n",
    "    figsize=(plot_util.FIGURE_WIDTH_FOR_COLUMNS_IN[6], plot_util.FIGURE_WIDTH_FOR_COLUMNS_IN[6] / plot_util.GOLDEN_RATIO)\n",
    ")\n",
    "\n",
    "# Mean and standard error over runs for every (noise fraction, #rotations) setting\n",
    "df_test_error = df_runs.groupby([\"noise_fraction\", \"num_train_rotations\"]).agg([\"mean\", \"sem\"])\n",
    "\n",
    "# (accuracy column, line style key) of the curves drawn per noise fraction.\n",
    "# The order fixes the order of the legend handles: interpolating first, early-stopped second.\n",
    "test_error_curves = (\n",
    "    (\"test_accuracy_random\", \"int\"),\n",
    "    (\"best_test_accuracy\", \"es\"),\n",
    ")\n",
    "for noise_fraction in df_test_error.index.get_level_values(0).unique():\n",
    "    df_setting_noise = df_test_error.loc[noise_fraction]\n",
    "    xs = df_setting_noise.index\n",
    "    current_color = NOISE_COLOR_MAP[noise_fraction]\n",
    "    for accuracy_column, linestyle_key in test_error_curves:\n",
    "        # error = 1 - accuracy; the standard error is unaffected by this transformation\n",
    "        ys_mean = 1.0 - df_setting_noise[accuracy_column][\"mean\"]\n",
    "        ys_err = df_setting_noise[accuracy_column][\"sem\"]\n",
    "        ax.fill_between(\n",
    "            xs,\n",
    "            ys_mean - ys_err,\n",
    "            ys_mean + ys_err,\n",
    "            color=current_color,\n",
    "            alpha=0.2,\n",
    "        )\n",
    "        ax.plot(\n",
    "            xs,\n",
    "            ys_mean,\n",
    "            c=current_color,\n",
    "            # Raw f-string: \"\\%\" is no valid Python escape sequence and must reach the renderer verbatim\n",
    "            label=fr\"{float(noise_fraction)*100:.0f}\\% noise\",\n",
    "            ls=ES_LINESTYLE_MAP[linestyle_key],\n",
    "        )\n",
    "\n",
    "ax.set_ylim(0.0, 0.175)\n",
    "ax.set_yticks((\n",
    "    0.0,\n",
    "    0.03,\n",
    "    0.08,\n",
    "    0.12,\n",
    "    0.17,\n",
    "))\n",
    "ax.set_ylabel(\"test error\")\n",
    "ax.yaxis.set_major_formatter(matplotlib.ticker.PercentFormatter(xmax=1.0, decimals=0))\n",
    "\n",
    "# The x-axis is inverted, so the number of rotations decreases from left to right\n",
    "ax.set_xlim(1, MAX_ROT)\n",
    "ax.set_xticks((1, 4, MAX_ROT))\n",
    "ax.xaxis.set_minor_locator(matplotlib.ticker.FixedLocator(df_runs.num_train_rotations.unique()))\n",
    "ax.invert_xaxis()\n",
    "ax.set_xlabel(r\"\\# rotations\")\n",
    "\n",
    "plt.show()\n",
    "\n",
    "fig.savefig(os.path.join(PLOT_BASE_DIR, \"rotations_error.pdf\"))\n",
    "\n",
    "plt.close(fig)\n",
    "\n",
    "# Legend, saved as a separate figure\n",
    "handles, labels = ax.get_legend_handles_labels()\n",
    "proxy_patch = plot_util.proxy_patch()\n",
    "# Prepend two proxy entries which carry the row titles of the legend\n",
    "handles = [proxy_patch, proxy_patch] + handles\n",
    "# Raw string: \"\\ \" is no valid Python escape sequence and must reach the renderer verbatim\n",
    "labels = [\"interpolating:\", r\"opt.\\ early-stopped:\"] + labels\n",
    "\n",
    "legend_fig = plt.figure(figsize=(\n",
    "    plot_util.FIGURE_WIDTH_FOR_COLUMNS_IN[6],\n",
    "    plot_util.LEGEND_HEIGHT_FOR_ROWS_IN[2],\n",
    "))\n",
    "\n",
    "legend_fig.legend(\n",
    "    handles,\n",
    "    labels,\n",
    "    loc='center right',\n",
    "    ncol=3,\n",
    "    frameon=False,\n",
    "    borderpad=0.5,\n",
    "    columnspacing=1.0,\n",
    ")\n",
    "\n",
    "legend_fig.savefig(os.path.join(PLOT_BASE_DIR, \"rotations_error_legend.pdf\"))\n",
    "\n",
    "plt.close(legend_fig)\n"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "markdown",
   "source": [
    "## Early-stopped train error curves"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%% md\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "figure_width = plot_util.FIGURE_WIDTH_FOR_COLUMNS_IN[6]\n",
    "fig, ax = plt.subplots(figsize=(figure_width, figure_width / plot_util.GOLDEN_RATIO))\n",
    "\n",
    "# Mean and standard error over runs for every (noise fraction, #rotations) setting\n",
    "df_es_train_error = df_runs.groupby([\"noise_fraction\", \"num_train_rotations\"]).agg([\"mean\", \"sem\"])\n",
    "\n",
    "# (accuracy column, color, legend label) of the curves drawn per noise fraction.\n",
    "# The overall early-stopped accuracy (\"es_train_accuracy\") is not plotted here.\n",
    "subset_curves = (\n",
    "    (\"es_train_noise_accuracy\", \"C3\", \"noisy subset\"),\n",
    "    (\"es_train_noiseless_accuracy\", \"C2\", \"clean subset\"),\n",
    ")\n",
    "for noise_fraction in df_es_train_error.index.get_level_values(0).unique():\n",
    "    # Runs with noise fraction 0.0 are left out of this figure\n",
    "    if noise_fraction == \"0.0\":\n",
    "        continue\n",
    "\n",
    "    df_setting_noise = df_es_train_error.loc[noise_fraction]\n",
    "    xs = df_setting_noise.index\n",
    "    for accuracy_column, current_color, curve_label in subset_curves:\n",
    "        column_stats = df_setting_noise[accuracy_column]\n",
    "        ys_mean = 1.0 - column_stats[\"mean\"]\n",
    "        ys_err = column_stats[\"sem\"]\n",
    "        ax.fill_between(xs, ys_mean - ys_err, ys_mean + ys_err, color=current_color, alpha=0.2)\n",
    "        ax.plot(xs, ys_mean, c=current_color, label=curve_label, ls=ES_LINESTYLE_MAP[\"es\"])\n",
    "\n",
    "# Small margin so that curves at exactly 0% or 100% are not clipped\n",
    "ax.set_ylim(0.0 - 1e-2, 1.0 + 1e-2)\n",
    "ax.set_yticks((0.0, 0.38, 1.0))\n",
    "ax.yaxis.set_major_formatter(matplotlib.ticker.PercentFormatter(xmax=1.0, decimals=0))\n",
    "ax.set_ylabel(r\"subset train error \\\\(optimal early stopping)\")\n",
    "\n",
    "# The x-axis is inverted, so the number of rotations decreases from left to right\n",
    "ax.set_xlim(1, MAX_ROT)\n",
    "ax.set_xticks((1, 4, MAX_ROT))\n",
    "rotation_counts = df_runs.num_train_rotations.unique()\n",
    "ax.xaxis.set_minor_locator(matplotlib.ticker.FixedLocator(rotation_counts))\n",
    "ax.invert_xaxis()\n",
    "ax.set_xlabel(r\"\\# rotations\")\n",
    "\n",
    "plt.show()\n",
    "\n",
    "es_figure_path = os.path.join(PLOT_BASE_DIR, \"rotations_es.pdf\")\n",
    "fig.savefig(es_figure_path)\n",
    "\n",
    "plt.close(fig)\n",
    "\n",
    "# Legend, saved as a separate figure\n",
    "curve_handles, curve_labels = ax.get_legend_handles_labels()\n",
    "proxy_patch = plot_util.proxy_patch()\n",
    "# Put an unlabelled proxy entry in front of each of the two curve entries\n",
    "# (proxy_patch is presumably an invisible placeholder - see plot_util)\n",
    "handles = [\n",
    "    entry for curve_handle in (curve_handles[0], curve_handles[1]) for entry in (proxy_patch, curve_handle)\n",
    "]\n",
    "labels = [\n",
    "    entry for curve_label in (curve_labels[0], curve_labels[1]) for entry in (\"\", curve_label)\n",
    "]\n",
    "\n",
    "legend_size = (plot_util.FIGURE_WIDTH_FOR_COLUMNS_IN[6], plot_util.LEGEND_HEIGHT_FOR_ROWS_IN[2])\n",
    "legend_fig = plt.figure(figsize=legend_size)\n",
    "\n",
    "legend_fig.legend(handles, labels, loc='center', ncol=2, frameon=False, borderpad=0.5)\n",
    "\n",
    "legend_fig.savefig(os.path.join(PLOT_BASE_DIR, \"rotations_es_legend.pdf\"))\n",
    "\n",
    "plt.close(legend_fig)\n"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "markdown",
   "source": [
    "## Training loss"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%% md\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "fig, ax = plt.subplots(\n",
    "    figsize=(plot_util.FIGURE_WIDTH_FOR_COLUMNS_IN[6], plot_util.FIGURE_WIDTH_FOR_COLUMNS_IN[6] / plot_util.GOLDEN_RATIO)\n",
    ")\n",
    "\n",
    "# Mean and standard error over runs for every (noise fraction, #rotations) setting\n",
    "df_loss = df_runs.groupby([\"noise_fraction\", \"num_train_rotations\"]).agg([\"mean\", \"sem\"])\n",
    "# Iterate over this cell's own aggregate so that the cell does not depend on\n",
    "# df_test_error, which only exists after the test error cell has been executed\n",
    "for noise_fraction in df_loss.index.get_level_values(0).unique():\n",
    "    df_setting_noise = df_loss.loc[noise_fraction]\n",
    "    xs = df_setting_noise.index\n",
    "    ys_mean = df_setting_noise[\"train_loss\"][\"mean\"]\n",
    "    ys_err = df_setting_noise[\"train_loss\"][\"sem\"]\n",
    "    current_color = NOISE_COLOR_MAP[noise_fraction]\n",
    "    ax.fill_between(\n",
    "        xs,\n",
    "        ys_mean - ys_err,\n",
    "        ys_mean + ys_err,\n",
    "        color=current_color,\n",
    "        alpha=0.2,\n",
    "    )\n",
    "    ax.plot(\n",
    "        xs,\n",
    "        ys_mean,\n",
    "        c=current_color,\n",
    "        label=fr\"{float(noise_fraction)*100:.0f}\\% noise\",\n",
    "        ls=ES_LINESTYLE_MAP[\"int\"],\n",
    "    )\n",
    "\n",
    "ax.set_ylim(0.0, 2.5e-5)\n",
    "ax.set_yticks((\n",
    "    0.0,\n",
    "    6.5e-6,\n",
    "    9.9e-6,\n",
    "))\n",
    "ax.set_ylabel(\"training loss\")\n",
    "\n",
    "# The x-axis is inverted, so the number of rotations decreases from left to right\n",
    "ax.set_xlim(1, MAX_ROT)\n",
    "ax.set_xticks((1, 4, MAX_ROT))\n",
    "ax.xaxis.set_minor_locator(matplotlib.ticker.FixedLocator(df_runs.num_train_rotations.unique()))\n",
    "ax.invert_xaxis()\n",
    "ax.set_xlabel(r\"\\# rotations\")\n",
    "\n",
    "plt.show()\n",
    "\n",
    "fig.savefig(os.path.join(PLOT_BASE_DIR, \"loss_rotations.pdf\"))\n",
    "\n",
    "plt.close(fig)\n",
    "\n",
    "\n",
    "# Legend, saved as a separate figure\n",
    "legend_size = (plot_util.FIGURE_WIDTH_FOR_COLUMNS_IN[6], plot_util.LEGEND_HEIGHT_FOR_ROWS_IN[2])\n",
    "legend_fig = plt.figure(figsize=legend_size)\n",
    "curve_handles, curve_labels = ax.get_legend_handles_labels()\n",
    "proxy_patch = plot_util.proxy_patch()\n",
    "# Put an unlabelled proxy entry in front of each of the two curve entries\n",
    "# (proxy_patch is presumably an invisible placeholder - see plot_util)\n",
    "handles = [\n",
    "    entry for curve_handle in (curve_handles[0], curve_handles[1]) for entry in (proxy_patch, curve_handle)\n",
    "]\n",
    "labels = [\n",
    "    entry for curve_label in (curve_labels[0], curve_labels[1]) for entry in (\"\", curve_label)\n",
    "]\n",
    "legend_fig.legend(handles, labels, loc='center', ncol=2, frameon=False, borderpad=0.5)  # TODO\n",
    "\n",
    "legend_fig.savefig(os.path.join(PLOT_BASE_DIR, \"loss_rotations_legend.pdf\"))\n",
    "\n",
    "plt.close(legend_fig)\n"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}