{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 118,
   "id": "10baab73",
   "metadata": {},
   "outputs": [],
   "source": [
    "from haven import haven_jupyter as hj\n",
    "from haven import haven_results as hr\n",
    "from haven import haven_utils as hu\n",
    "import pprint\n",
    "import numpy as np\n",
    "from experiments_config.syn_interp_exp import *\n",
    "from experiments_config.kernelize_exp import *\n",
    "from experiments_config.syn_vary_n_exp import *\n",
    "from experiments_config.syn_cyclic_exp import *\n",
    "from experiments_config.syn_check_alpha_beta import *\n",
    "from experiments_config.syn_non_interp_exp import *\n",
    "from exp_configs import get_exp_group"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "896794f0",
   "metadata": {},
   "outputs": [],
   "source": [
    "def form_exp_list(exp_groups, exp_config):\n",
    "    \"\"\"Concatenate the experiment lists of the named groups.\n",
    "\n",
    "    Args:\n",
    "        exp_groups: iterable of group names, each used to subscript the\n",
    "            object returned by ``get_exp_group(**exp_config)``.\n",
    "        exp_config: keyword arguments forwarded to ``get_exp_group``.\n",
    "\n",
    "    Returns:\n",
    "        A new list holding the experiments of every group, in group order.\n",
    "\n",
    "    Raises:\n",
    "        KeyError: if a name in ``exp_groups`` is not a group returned by\n",
    "            ``get_exp_group``.\n",
    "    \"\"\"\n",
    "    # Build the group mapping once; it does not depend on the loop variable.\n",
    "    groups = get_exp_group(**exp_config)\n",
    "    exp_list = []\n",
    "    for group_name in exp_groups:\n",
    "        exp_list.extend(groups[group_name])\n",
    "    return exp_list\n",
    "\n",
    "\n",
    "import sys  # NOTE(review): sys is not referenced anywhere in this notebook"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 688,
   "id": "63c205d8",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Root directory of the saved experiment outputs and the per-study subdirectories\n",
    "base = './outputs/'\n",
    "savedir_ijcnn = f'{base}output_ijcnn'  # NOTE(review): not referenced elsewhere in this notebook\n",
    "savedir_synthetic_main_plot = f'{base}output_synthetic_main_plot'\n",
    "\n",
    "exp_config_fname = './exp_configs.py'  # NOTE(review): not referenced elsewhere in this notebook\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 683,
   "id": "f2e17046",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Shared line settings; each style below varies only label, colour and marker.\n",
    "markersize = 15\n",
    "markevery = 20\n",
    "linewidth = 3\n",
    "shb = 'solid'\n",
    "sgd = 'dotted'  # NOTE(review): not used by any style in this notebook\n",
    "styles = {}\n",
    "\n",
    "\n",
    "def make_style(label, color, marker, linestyle=shb):\n",
    "    \"\"\"Return the keyword dict describing one plotted curve.\n",
    "\n",
    "    markersize, linewidth and markevery are read from the module-level\n",
    "    settings above at call time; linestyle defaults to ``shb`` ('solid').\n",
    "    \"\"\"\n",
    "    return {'label': label, 'color': color,\n",
    "            'marker': marker, 'markersize': markersize,\n",
    "            'linewidth': linewidth, 'linestyle': linestyle,\n",
    "            'markevery': markevery}\n",
    "\n",
    "\n",
    "styles['SGD_CNST'] = make_style('SGD', 'black', 'd')\n",
    "styles['SHB_CNST'] = make_style('SHB', 'blue', 'o')\n",
    "styles['SHB_MUL_0.4_False'] = make_style('Multi-SHB', 'green', '*')\n",
    "styles['SHB_MUL_auto_False'] = make_style('Multi-SHB-adaptive-increase', 'darkcyan', 'v')\n",
    "styles['SHB_MUL_0.4_True'] = make_style('Multi-SHB-CNST', 'orange', 'h')\n",
    "styles['SHB_MUL_auto_True'] = make_style('Multi-SHB-adaptive-constant', 'brown', 'p')\n",
    "styles['SGD_ACC_EXP'] = make_style('Nesterov-EXP', 'purple', 's')\n",
    "styles['SHB_MIX_0.5'] = make_style('2P-SHB', 'red', 'P')\n",
    "\n",
    "# -------------------------------------------------------------------------------------------\n",
    "\n",
    "styles['SHB_EXP'] = make_style('SHB-EXP', 'darkcyan', 'v')\n",
    "styles['SGD_ACC'] = make_style('Nesterov', 'cyan', 'h')\n",
    "styles['SGD_EXP'] = make_style('SGD-EXP', 'black', 'd')\n",
    "\n",
    "# SHB_CNST_<value> family. Both '1.0' and '1' keys exist, presumably so that\n",
    "# keys formatted from either 1.0 or 1 resolve - confirm.\n",
    "styles['SHB_CNST_1.0'] = make_style('SHB-1.0', 'darkcyan', 'd')\n",
    "styles['SHB_CNST_1'] = make_style('SHB-1.0', 'darkcyan', 'd')\n",
    "styles['SHB_CNST_0.9'] = make_style('SHB-0.9', 'm', 'v')\n",
    "styles['SHB_CNST_0.8'] = make_style('SHB-0.8', 'red', 's')\n",
    "styles['SHB_CNST_0.7'] = make_style('SHB-0.7', 'brown', 'p')\n",
    "styles['SHB_CNST_0.6'] = make_style('SHB-0.6', 'orange', '8')\n",
    "styles['SHB_CNST_0.5'] = make_style('SHB-0.5', 'mediumpurple', 'H')\n",
    "styles['SHB_CNST_0.4'] = make_style('SHB-0.4', 'skyblue', 'X')\n",
    "styles['SHB_CNST_0.3'] = make_style('SHB-0.3', 'green', 'x')\n",
    "styles['SHB_CNST_0.25'] = make_style('SHB-0.25', 'cyan', '>')\n",
    "# NOTE(review): key says 0.12 but label says 0.125 - confirm which is intended\n",
    "styles['SHB_CNST_0.12'] = make_style('SHB-0.125', 'pink', 'h')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 704,
   "id": "b2818f4d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Extra Multi-SHB variants; they reuse the shared line settings\n",
    "# (markersize, linewidth, shb, markevery) defined in the previous cell.\n",
    "for style_key, style_label, style_color, style_marker in [\n",
    "    ('SHB_MUL_0.5', 'SHB-MUL-0.5', 'pink', 'h'),\n",
    "    ('SHB_MUL_auto', 'SHB-MUL', 'slategray', '^'),\n",
    "    ('SHB_MUL_PAN_2', 'Multi-SHB-PAN-2', 'mediumpurple', '8'),\n",
    "    ('SHB_MUL_PAN_max', 'Multi-SHB-PAN-T-KAP', 'navy', '3'),\n",
    "]:\n",
    "    styles[style_key] = {'label': style_label, 'color': style_color,\n",
    "                         'marker': style_marker, 'markersize': markersize,\n",
    "                         'linewidth': linewidth, 'linestyle': shb,\n",
    "                         'markevery': markevery}"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f6853fca",
   "metadata": {},
   "source": [
    "# Experiment Lists and Optimizer Filters per Loss"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 808,
   "id": "f5d9dd85",
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "\n",
    "def _best_filters_for_loss(loss_func, batch_size, kappa, variance, runs):\n",
    "    \"\"\"Return ``(filter, {'style': ...})`` pairs selecting the compared optimizers for one loss.\n",
    "\n",
    "    For every run in ``runs`` one entry is emitted per optimizer: 'EXP_SHB'\n",
    "    (alpha_t='CNST'), 'EXP_SGD' (alpha_t='CNST'), 'EXP_ACC_SGD' (alpha_t='DECR'),\n",
    "    'Mix_SHB' (c=0.5) and 'M_ASHB' (c=0.4, beta_const False then True).\n",
    "    \"\"\"\n",
    "    filters = []\n",
    "    for run in runs:\n",
    "        # Keys every optimizer filter of this loss/run shares.\n",
    "        shared = {'loss_func': loss_func, 'kappa': kappa, 'variance': variance,\n",
    "                  'batch_size': batch_size, 'runs': run}\n",
    "        filters.append(({'opt': {'name': 'EXP_SHB', 'alpha_t': 'CNST', 'method': 'WANG21', 'is_sls': False}, **shared},\n",
    "                        {'style': styles['SHB_CNST']}))\n",
    "        filters.append(({'opt': {'name': 'EXP_SGD', 'alpha_t': 'CNST'}, **shared},\n",
    "                        {'style': styles['SGD_CNST']}))\n",
    "        filters.append(({'opt': {'name': 'EXP_ACC_SGD', 'alpha_t': 'DECR', 'is_sls': False}, **shared},\n",
    "                        {'style': styles['SGD_ACC_EXP']}))\n",
    "        for c in [0.5]:\n",
    "            filters.append(({'opt': {'name': 'Mix_SHB', 'c': c}, **shared},\n",
    "                            {'style': styles[f'SHB_MIX_{c}']}))\n",
    "        for c in [0.4]:\n",
    "            for beta_const in [False, True]:\n",
    "                filters.append(({'opt': {'name': 'M_ASHB', 'c': c, 'beta_const': beta_const}, **shared},\n",
    "                                {'style': styles[f'SHB_MUL_{c}_{beta_const}']}))\n",
    "    return filters\n",
    "\n",
    "\n",
    "def get_res_dict(config, datasets, batch_size, kappa, variance, runs=(0, 1, 2), verbose=True):\n",
    "    \"\"\"Collect experiments and optimizer filters for the squared and logistic losses.\n",
    "\n",
    "    Args:\n",
    "        config: keyword arguments for ``get_exp_group`` (passed via ``form_exp_list``).\n",
    "        datasets: dataset names; each selects the experiment group ``'exp_' + name``.\n",
    "        batch_size, kappa, variance: values the optimizer filters require.\n",
    "        runs: run ids to include (one filter entry per optimizer per run).\n",
    "        verbose: if True, print how many experiments each loss's filters match.\n",
    "\n",
    "    Returns:\n",
    "        ``(loss_exp_dict, combined_filter_dict)``, both keyed by ``'squared_loss'``\n",
    "        and ``'logistic_loss'``: the experiments with that ``loss_func``, and the\n",
    "        ``(filter, {'style': ...})`` list passed to ``hr.ResultManager``.\n",
    "    \"\"\"\n",
    "    exp_list = form_exp_list([f'exp_{dataset}' for dataset in datasets], config)\n",
    "\n",
    "    loss_exp_dict = {}\n",
    "    combined_filter_dict = {}\n",
    "    for loss_func in ('squared_loss', 'logistic_loss'):\n",
    "        loss_exp_dict[loss_func] = hr.filter_exp_list(\n",
    "            exp_list, filterby_list=[{'loss_func': loss_func}], verbose=0)\n",
    "        combined_filter_dict[loss_func] = _best_filters_for_loss(\n",
    "            loss_func, batch_size, kappa, variance, runs)\n",
    "\n",
    "        if verbose:\n",
    "            # Sanity check: number of saved experiments the optimizer filters select.\n",
    "            n_matched = len(hr.filter_exp_list(\n",
    "                loss_exp_dict[loss_func], filterby_list=combined_filter_dict[loss_func], verbose=0))\n",
    "            print(f'{loss_func}: {n_matched} matching experiments')\n",
    "    return loss_exp_dict, combined_filter_dict"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7f49c13a",
   "metadata": {},
   "source": [
    "# Squared Loss"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a825945c",
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "# Select the synthetic experiments to compare; these values are matched\n",
    "# against the saved experiment configs (see get_res_dict).\n",
    "datasets = ['synthetic_kappa']\n",
    "batch_size = -10/9  # NOTE(review): negative value presumably encodes a special batch mode in the configs - confirm\n",
    "kappa = 512\n",
    "variance = 1e-2\n",
    "runs = [0, 1, 2]\n",
    "config = EXP_SYN_NON_INTERP_CONFIGS\n",
    "loss_exp_dict, combined_filter_dict = get_res_dict(config, datasets, batch_size, kappa, variance, runs)\n",
    "\n",
    "show_legend_all = None\n",
    "rm = hr.ResultManager(\n",
    "    exp_list=loss_exp_dict['squared_loss'],\n",
    "    savedir_base=savedir_synthetic_main_plot,\n",
    "    filterby_list=combined_filter_dict['squared_loss'],\n",
    "    verbose=0,\n",
    ")\n",
    "\n",
    "# Axis-label and legend overrides passed to get_plot_all\n",
    "squared_ylabel_map = [\n",
    "    {'train_loss': 'Train loss'},\n",
    "    {'val_acc': 'Validation accuracy'},\n",
    "    {'grad_norm': 'Gradient Norm'},\n",
    "    {'alpha_k': 'alpha_k'},\n",
    "    {'beta_k': 'beta_k'},\n",
    "    {'lambda_k': 'lambda_k'},\n",
    "    {'n_grad_evals': 'Number of stochastic gradient evaluations'},\n",
    "]\n",
    "squared_xlabel_map = [{'itr': 'Iteration'}]\n",
    "squared_legend_kwargs = {'bbox_to_anchor': [0.5, -0.22], 'borderaxespad': 0.,\n",
    "                         'ncol': 6, 'loc': 'center'}\n",
    "\n",
    "rm.get_plot_all(\n",
    "    avg_across='runs',\n",
    "    plot_median=True,\n",
    "    order='metrics_by_groups',\n",
    "    show_legend_all=show_legend_all,\n",
    "    legend_last_row_only=True,\n",
    "    y_metric_list=['grad_norm'],\n",
    "    x_metric='itr',\n",
    "    legend_list=['opt.name'],\n",
    "    title_list=['kappa', 'variance'],\n",
    "    groupby_list=['dataset', 'model', 'n_samples', 'variance', 'kappa'],\n",
    "    log_metric_list=['grad_norm', 'alpha_k', 'train_loss'],\n",
    "    legend_fontsize=14,\n",
    "    x_fontsize=27,\n",
    "    y_fontsize=27,\n",
    "    xtick_fontsize=27,\n",
    "    ytick_fontsize=27,\n",
    "    title_fontsize=18,\n",
    "    result_step=50,\n",
    "    ylim_list=[[(10**(-4.5), 10**(0))]],\n",
    "    map_ylabel_list=squared_ylabel_map,\n",
    "    map_xlabel_list=squared_xlabel_map,\n",
    "    figsize=(12, 8),\n",
    "    plot_confidence=False,\n",
    "    legend_kwargs=squared_legend_kwargs,\n",
    "    savedir_plots=f'./plots/squared_synthetic_label_{batch_size}_{kappa}_{variance}',\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d71d625b",
   "metadata": {},
   "source": [
    "# Logistic Loss"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e9490090",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Logistic-loss results for the experiments selected in the squared-loss cell\n",
    "# (reuses loss_exp_dict, combined_filter_dict, batch_size, kappa, variance).\n",
    "show_legend_all = None\n",
    "# NOTE(review): assumes the logistic-loss runs were saved under the same directory\n",
    "# as the squared-loss runs built from the same config - confirm.\n",
    "rm = hr.ResultManager(\n",
    "    exp_list=loss_exp_dict['logistic_loss'],\n",
    "    savedir_base=savedir_synthetic_main_plot,\n",
    "    filterby_list=combined_filter_dict['logistic_loss'],\n",
    "    verbose=0,\n",
    ")\n",
    "\n",
    "rm.get_plot_all(\n",
    "    avg_across='runs',\n",
    "    plot_median=True,\n",
    "    order='metrics_by_groups',\n",
    "    show_legend_all=show_legend_all,\n",
    "    legend_last_row_only=True,\n",
    "    y_metric_list=['grad_norm'],\n",
    "    x_metric='itr',\n",
    "    legend_list=['opt.name'],\n",
    "    title_list=['dataset'],\n",
    "    groupby_list=['dataset', 'model'],\n",
    "    log_metric_list=['grad_norm'],\n",
    "    legend_fontsize=16,\n",
    "    x_fontsize=16,\n",
    "    y_fontsize=16,\n",
    "    xtick_fontsize=16,\n",
    "    ytick_fontsize=16,\n",
    "    title_fontsize=18,\n",
    "    result_step=1,\n",
    "    ylim_list=[[(10**(-6.5), 10**(-1))]],\n",
    "    map_ylabel_list=[{'train_loss': 'Train loss'},\n",
    "                     {'val_acc': 'Validation accuracy'},\n",
    "                     {'grad_norm': 'Gradient Norm'},\n",
    "                     {'alpha_k': 'alpha_k'},\n",
    "                     {'beta_k': 'beta_k'},\n",
    "                     {'n_grad_evals': 'Number of stochastic gradient evaluations'}],\n",
    "    # x_metric is 'itr', labelled the same way as in the squared-loss plot\n",
    "    map_xlabel_list=[{'itr': 'Iteration'}],\n",
    "    figsize=(12, 5),\n",
    "    plot_confidence=True,\n",
    "    legend_kwargs={'bbox_to_anchor': [0.5, -0.25], 'ncol': 7, 'loc': 'center'},\n",
    "    savedir_plots=f'./plots/logistic_synthetic_label_{batch_size}_{kappa}_{variance}',\n",
    ")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
