{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "02b860a2",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os, sys\n",
    "path_to_this_notebook = os.path.abspath('.')\n",
    "path_to_project = path_to_this_notebook[:path_to_this_notebook.find('note')]\n",
    "sys.path.append(path_to_project)\n",
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1b169e37",
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "import pickle\n",
    "from collections import defaultdict\n",
    "import re\n",
    "import numpy as np\n",
    "%matplotlib inline\n",
    "\n",
    "\n",
    "def find_seed(name):\n",
    "    return re.findall('_\\([0-9]+, [0-9]+\\)', name)[0]\n",
    "\n",
    "def name_no_seed(name):\n",
    "    seed = find_seed(name)\n",
    "    return name.replace(seed, '')\n",
    "\n",
    "import matplotlib as mpl\n",
    "\n",
    "\n",
    "def filter_name(name):\n",
    "    epochs = re.findall('epochs=[0-9]+_', name)[0]\n",
    "    return name.replace(epochs, '')\n",
    "    name_short = name[:name.find('lr') - 1]\n",
    "    name_short =  name_short.replace('P=0-', '').replace('D=0-', '',).replace('R=0_', '').replace('MSE=0', '')\n",
    "    if name_short[-1] in ['-', '_']:\n",
    "        name_short = name_short[:-1]\n",
    "\n",
    "    return name_short\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f96d5882",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "path_to_problems = path_to_project + 'experiments/vopf/'\n",
    "problem_names = sorted(os.listdir(path_to_problems))\n",
    "print('Problems:')\n",
    "print('\\n'.join(problem_names))\n",
    "problem_name = problem_names[0]\n",
    "\n",
    "print()\n",
    "print('Chosen:', problem_name)\n",
    "print()\n",
    "experiment_names = sorted(os.listdir(path_to_problems + problem_name), key=lambda x: '0' if 'van' in x else x)\n",
    "print('Experiments for this problem:')\n",
    "print('\\n'.join(experiment_names))\n",
    "print()\n",
    "experiment_name = experiment_names[1]\n",
    "print('Chosen:', experiment_name)\n",
    "base_path = path_to_problems + '/%s/%s/' % (problem_name, experiment_name)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "332afe41",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "train_results_per_run = defaultdict(list)\n",
    "val_results_per_run = defaultdict(list)\n",
    "test_results_per_run = defaultdict(list)\n",
    "for exp_name in experiment_names:\n",
    "    #if 'D=0.01' in exp_name or 'P=1e-06' in exp_name or 'P=0.0001' in exp_name or 'P=1e-05' in exp_name or 'P=1e-08' in exp_name :\n",
    "    #    continue\n",
    "    try:\n",
    "        base_path = path_to_problems + '/%s/%s/' % (problem_name, exp_name)\n",
    "        with open(base_path + '/validate_history.pickle', 'rb') as f:\n",
    "            validate_history = pickle.load(f)\n",
    "        with open(base_path + '/training_history.pickle', 'rb') as f:\n",
    "            train_history = pickle.load(f)\n",
    "        with open(base_path + '/training_history.pickle', 'rb') as f:\n",
    "            test_history = pickle.load(f)\n",
    "        exp_name_no_seed = name_no_seed(exp_name)\n",
    "        train_results_per_run[exp_name_no_seed].append(train_history)\n",
    "        val_results_per_run[exp_name_no_seed].append(validate_history)\n",
    "        test_results_per_run[exp_name_no_seed].append(test_history)\n",
    "        print('Loaded %s' % exp_name)\n",
    "    except:\n",
    "        print('Could not load %s' % exp_name)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "acf51d92",
   "metadata": {},
   "outputs": [],
   "source": [
    "keys_to_plot = list(train_results_per_run.keys())\n",
    "keys_to_plot = [k for k in keys_to_plot\n",
    "                #if 'frozen' in k\n",
    "               ]\n",
    "print(keys_to_plot)\n",
    "renaming = {}\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "68118d8a",
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.rc('font', family='serif')\n",
    "colors = ['#A5A5A5', '#9BA711', '#0058A0', '#9E0202']\n",
    "\n",
    "\n",
    "color_dict = {'vanilla_lr=5e-05_bs=1_epochs=250_pred-x_u_frozen-Wsq': colors[0],\n",
    "              'PD=0.0001_lr=5e-05_bs=1_epochs=250_pred-x_u_frozen-Wsq': colors[1],\n",
    "              'LOC=1_lr=5e-05_bs=1_epochs=250_pred-x_u_frozen-Wsq': colors[2],\n",
    "              'R=0_PD=0.0001_LOC=1_lr=5e-05_bs=1_epochs=250_pred-x_u_frozen-Wsq': colors[3],\n",
    "             }\n",
    "name_dict = {'vanilla_lr=5e-05_bs=1_epochs=250_pred-x_u_frozen-Wsq': 'Standard',\n",
    "              'PD=0.0001_lr=5e-05_bs=1_epochs=250_pred-x_u_frozen-Wsq': 'Proj. distance',\n",
    "              'LOC=1_lr=5e-05_bs=1_epochs=250_pred-x_u_frozen-Wsq': '$r-$smoothing',\n",
    "              'R=0_PD=0.0001_LOC=1_lr=5e-05_bs=1_epochs=250_pred-x_u_frozen-Wsq':  '$r-$smoothing + \\nproj.distance',\n",
    "             }\n",
    "\n",
    "\n",
    "#name = 'qp_vs_standard'\n",
    "name = 'opf'\n",
    "fig, ax = plt.subplots(figsize=(10, 8))\n",
    "#fig, ax = plt.figure()\n",
    "for key in color_dict:\n",
    "    res_test = test_results_per_run[key]\n",
    "    mean_val = np.mean([d['regret'] for d in res_test], axis=0)\n",
    "    std = np.std([d['regret'] for d in res_test], axis=0)\n",
    "    x = np.arange(0, len(mean_val))\n",
    "    c = color_dict[key]\n",
    "    #m = choose_marker(key)\n",
    "    m = '-' if 'LOC' in key else '--'\n",
    "    m = '-'\n",
    "    #m = '-' if 'froz' in key else '--'\n",
    "    p = ax.plot(x, mean_val, m, color=c, label=name_dict[key], linewidth=6)\n",
    "    #plt.scatter(x, mean_val,  marker='X', s=256, color=p[0].get_color())\n",
    "    ax.fill_between(x, mean_val - std, mean_val + std, color=p[0].get_color(), alpha=0.3)\n",
    "    \n",
    "_ = ax.legend(loc='upper right', prop={'size': 32})\n",
    "#_ = plt.title('Regret on the test set')\n",
    "_ = ax.set_xlabel('Training epoch', fontsize=32)\n",
    "_ = ax.set_ylabel('Regret', rotation=0, fontsize=32)\n",
    "\n",
    "ax.set_xticklabels(map(int, ax.get_xticks()), fontsize=25)\n",
    "ax.set_yticks(ax.get_yticks()[::2])\n",
    "ax.set_yticks([0.0, 0.02, 0.04, .06, ])\n",
    "ax.set_yticklabels(map(lambda x: np.round(x, 4), ax.get_yticks()), fontsize=25)\n",
    "ax.yaxis.set_label_coords(-.08, 0.5)\n",
    "plt.gca().set_position([0, 0, 1, 1])\n",
    "plt.savefig(name + '.svg')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fc1966a7",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "09c8195e",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "84a28a8a",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e87f60f0",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f228f772",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "99f08d86",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5bc2d62b",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6b06dc0e",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
