{
 "cells": [
  {
   "cell_type": "code",
   "id": "initial_id",
   "metadata": {
    "collapsed": true
   },
   "source": [
    "from Plotting import *\n",
    "import pickle\n",
    "import numpy as np\n",
    "from matplotlib import pyplot as plt"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "metadata": {},
   "cell_type": "code",
   "source": [
    "# Load data\n",
    "\n",
    "with open('out/example_GD.pkl', 'rb') as f:\n",
    "    res_GD = pickle.load(f)\n",
    "    \n",
    "with open('out/example_OpGD.pkl', 'rb') as f:\n",
    "    res_OpGD = pickle.load(f)\n"
   ],
   "id": "a5ed654c22480001",
   "outputs": [],
   "execution_count": null
  },
  {
   "metadata": {},
   "cell_type": "code",
   "source": "res_GD",
   "id": "a778703f58886681",
   "outputs": [],
   "execution_count": null
  },
  {
   "metadata": {},
   "cell_type": "code",
   "source": [
    "def compareMinGenErrsN(expr_res_list,names, export_filename=None):\n",
    "    plt.figure()\n",
    "    plt.xlabel(\"n (logarithmic)\")\n",
    "    plt.ylabel(\"$\\\\log_{10}$ generalization error \")\n",
    "    p = expr_res_list[0].meta[\"f_decay\"]\n",
    "    ns = expr_res_list[0].ns\n",
    "    ns = np.array(ns)\n",
    "    plt.title(f\"Generalization Errors; optimal rate={(2 * p - 1) / (2 * p):.3f}\")\n",
    "    # rep = ress[0][0].shape[0]\n",
    "    lines = []\n",
    "    for expr_res,name in zip(expr_res_list,names):\n",
    "        rep = expr_res.meta[\"repeats\"]\n",
    "        means = []\n",
    "        log_stds = []\n",
    "        for res, T, n in zip(expr_res.gen_err_list, expr_res.T_ada_list, expr_res.ns):\n",
    "            min_gen_errs = np.min(res, axis=1)\n",
    "            means.append(np.mean(min_gen_errs))\n",
    "            if rep > 1:\n",
    "                min_gen_errs_log = np.log10(min_gen_errs)\n",
    "                log_stds.append(np.nanstd(min_gen_errs_log))\n",
    "        means = np.array(means)\n",
    "        log_means = np.log10(means)\n",
    "        if rep > 1:\n",
    "            l = plt.errorbar(np.log10(ns), log_means, yerr=log_stds, fmt=\"-o\", label=name)\n",
    "        else:\n",
    "            l = plt.plot(np.log10(ns), log_means, \"-o\", label=name)[0]\n",
    "        lines.append(l)\n",
    "        # make a linear fit\n",
    "        k, m = solveAB(np.log10(ns), log_means)\n",
    "        l = plt.plot(np.log10(ns), k * np.log10(ns) + m, color='gray', linestyle=\"--\",\n",
    "                 label=f\"$\\\\log Err\\\\approx {k:.3f} \\\\log n + {m:.2f}$\")[0]\n",
    "        lines.append(l)\n",
    "        # plt.xticks(np.log10(ns), [n // 100 for n in ns])\n",
    "    xlabels = [ns[0]] + [\"\"] * (len(ns) - 2) + [ns[-1]]\n",
    "    plt.xticks(np.log10(ns), xlabels)\n",
    "    plt.legend(lines, [l.get_label() for l in lines])\n",
    "    if export_filename is not None:\n",
    "        plt.savefig(export_filename)\n",
    "    plt.show()\n",
    "compareMinGenErrsN([res_GD, res_OpGD],[\"Vanilla GD\",\"OpGD\"], \"out/CompareMinGenN_p1q2L15.pdf\")"
   ],
   "id": "61a6960ed7483daa",
   "outputs": [],
   "execution_count": null
  },
  {
   "metadata": {},
   "cell_type": "code",
   "source": [
    "def compareMinGenStepsN(expr_res_list,names, export_filename=None):\n",
    "    plt.figure()\n",
    "    plt.xlabel(\"n (logarithmic)\")\n",
    "    plt.ylabel(\"log10 min gen step\")\n",
    "    p = expr_res_list[0].meta[\"f_decay\"]\n",
    "    ns = expr_res_list[0].ns\n",
    "    ns = np.array(ns)\n",
    "    plt.title(f\"Oracle stopping time\")\n",
    "    # rep = ress[0][0].shape[0]\n",
    "    lines = []\n",
    "    for expr_res,name in zip(expr_res_list,names):\n",
    "        rep = expr_res.meta[\"repeats\"]\n",
    "        means = []\n",
    "        stds = []\n",
    "        ticks = expr_res.meta[\"ticks\"]\n",
    "        for res, T, n in zip(expr_res.gen_err_list, expr_res.T_ada_list, expr_res.ns):\n",
    "            min_gen_errs = np.log10(np.argmin(res, axis=1) * T / ticks)\n",
    "            means.append(np.mean(min_gen_errs))\n",
    "            if rep > 1:\n",
    "                stds.append(np.nanstd(min_gen_errs))\n",
    "        ns = expr_res.ns\n",
    "        if rep > 1:\n",
    "            l = plt.errorbar(np.log10(ns), means, yerr=stds, fmt=\"-o\", label=name)\n",
    "        else:\n",
    "            l = plt.plot(np.log10(ns), means, \"-o\", label=name)[0]\n",
    "        lines.append(l)\n",
    "        # make a linear fit\n",
    "        ns = np.array(ns)\n",
    "        means = np.array(means)\n",
    "        k, m = solveAB(np.log10(ns), means)\n",
    "        l = plt.plot(np.log10(ns), k * np.log10(ns) + m, color='gray', linestyle=\"--\",\n",
    "                 label=f\"$\\\\log t\\\\approx {k:.2f} \\\\log n + {m:.2f}$\")[0]\n",
    "        lines.append(l)\n",
    "    xlabels = [ns[0]] + [\"\"] * (len(ns) - 2) + [ns[-1]]\n",
    "    plt.xticks(np.log10(ns), xlabels)\n",
    "    plt.legend(lines, [l.get_label() for l in lines])\n",
    "    if export_filename is not None:\n",
    "        plt.savefig(export_filename)\n",
    "    plt.show()\n",
    "compareMinGenStepsN([res_GD, res_OpGD],[\"Vanilla GD\",\"OpGD\"], \"out/CompareMinGenStepN_p1q2L3.pdf\")"
   ],
   "id": "bc40fab1b57a4a7a",
   "outputs": [],
   "execution_count": null
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 2
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython2",
   "version": "2.7.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
