{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Using TensorFlow backend.\n",
      "2019-10-02 23:46:15 WARNING  From /home/arbiter/.pyenv/versions/3.7.3/envs/ap-submission/lib/python3.7/site-packages/cleverhans/utils_tf.py:341: The name tf.GraphKeys is deprecated. Please use tf.compat.v1.GraphKeys instead.\n",
      "\n"
     ]
    }
   ],
   "source": [
    "import json\n",
    "import os\n",
    "import math\n",
    "import re\n",
    "import logging\n",
    "from functools import reduce, partial\n",
    "from collections import OrderedDict\n",
    "from typing import Dict, List, Tuple, Union, Callable\n",
    "import pprint\n",
    "from mkdir_p import mkdir_p\n",
    "    \n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import matplotlib\n",
    "import matplotlib.pyplot as plt\n",
    "import plotly\n",
    "import chart_studio.plotly as py\n",
    "import plotly.graph_objs as go\n",
    "from sklearn.model_selection import ParameterGrid\n",
    "from flatten_dict import flatten\n",
    "\n",
    "from nnattack.variables import auto_var\n",
    "from params import (\n",
    "    compare_attacks,\n",
    "    compare_defense,\n",
    "    #parametric_defense,\n",
    "    \n",
    "    #compare_nns,\n",
    "    \n",
    "    nn_k1_robustness,\n",
    "    nn_k3_robustness,\n",
    "    rf_robustness,\n",
    "    dt_robustness,\n",
    "    lr_ap_robustness,\n",
    "    lr_at_robustness,\n",
    "    mlp_ap_robustness,\n",
    "    mlp_at_robustness,\n",
    "    \n",
    "    tst_scores,\n",
    "    \n",
    "    dt_robustness_figs,\n",
    "    nn_k1_robustness_figs,\n",
    "    nn_k3_robustness_figs,\n",
    "    rf_robustness_figs,\n",
    "    \n",
    "    nn1_def,\n",
    "    nn3_def,\n",
    "    dt_def,\n",
    "    rf_def,\n",
    "    lr_def,\n",
    "    mlp_def,\n",
    ")\n",
    "import params\n",
    "#import params_l2\n",
    "from utils import set_plot, get_result, write_to_tex, union_param_key, params_to_dataframe, table_wrapper\n",
    "\n",
    "# Experiment-wide settings: random seed 0, 'inf' for the `ord` variable, logging level 0.\n",
    "auto_var.set_variable_value('random_seed', 0)\n",
    "auto_var.set_variable_value('ord', 'inf')\n",
    "auto_var.set_logging_level(0)\n",
    "\n",
    "# Every name imported from `params` is called once here and rebound to the\n",
    "# returned object.  Later cells call these objects again and unpack the\n",
    "# result as `_, exp_name, grid_param, _`.\n",
    "compare_attacks = compare_attacks()\n",
    "compare_defense = compare_defense()\n",
    "tst_scores = tst_scores()\n",
    "\n",
    "#compare_nns = compare_nns()\n",
    "mlp_ap_robustness = mlp_ap_robustness()\n",
    "mlp_at_robustness = mlp_at_robustness()\n",
    "lr_ap_robustness = lr_ap_robustness()\n",
    "lr_at_robustness = lr_at_robustness()\n",
    "nn_k1_robustness = nn_k1_robustness()\n",
    "nn_k3_robustness = nn_k3_robustness()\n",
    "rf_robustness = rf_robustness()\n",
    "dt_robustness = dt_robustness()\n",
    "dt_robustness_figs = dt_robustness_figs()\n",
    "nn_k1_robustness_figs = nn_k1_robustness_figs()\n",
    "nn_k3_robustness_figs = nn_k3_robustness_figs()\n",
    "rf_robustness_figs = rf_robustness_figs()\n",
    "\n",
    "nn1_def = nn1_def()\n",
    "nn3_def = nn3_def()\n",
    "dt_def = dt_def()\n",
    "rf_def = rf_def()\n",
    "lr_def = lr_def()\n",
    "mlp_def = mlp_def()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [],
   "source": [
    "def result_latex_figs(exp_name, control_var, caption):\n",
    "    \"\"\"Return the LaTeX source of a figure holding one subfloat per control setting.\n",
    "\n",
    "    Args:\n",
    "        exp_name: Experiment name; part of every image path and of the\n",
    "            ``fig:`` label.\n",
    "        control_var: Parameter grid whose combinations each provide\n",
    "            ``'dataset'`` and ``'ord'``.\n",
    "        caption: Caption text of the figure.\n",
    "\n",
    "    Returns:\n",
    "        A string with the complete ``figure`` environment.  Every subfloat is\n",
    "        titled with the dataset name (underscores shown as spaces) and a line\n",
    "        break is inserted after every second subfloat.\n",
    "    \"\"\"\n",
    "    pieces = [\"\\n\\\\begin{figure}[ht!]\\n\\\\centering\"]\n",
    "    for position, setting in enumerate(ParameterGrid(control_var)):\n",
    "        dataset, norm = setting['dataset'], setting['ord']\n",
    "        eps_path = f'./figs/{exp_name}_{dataset}_{norm}.eps'\n",
    "        subfloat_title = dataset.replace(\"_\", \" \")\n",
    "        pieces.append(\"\\n\\\\subfloat[%s]{\\n    \\\\includegraphics[width=.45\\\\textwidth]{%s}}\"\n",
    "                      % (subfloat_title, eps_path))\n",
    "        if position % 2:\n",
    "            pieces.append(\"\\n\")\n",
    "    pieces.append(\"\\n\\\\caption{%s}\\n\\\\label{fig:%s}\\n\\\\end{figure} \\n\" % (caption, exp_name))\n",
    "    return \"\".join(pieces)\n",
    "                      \n",
    "def plot_result(df, exp_name, control_var, variables,\n",
    "                get_title_fn: Union[Callable[[Dict], str], None]=None,\n",
    "                get_label_name_fn: Union[Callable[[Dict], str], None]=None,\n",
    "                get_label_color_fn: Union[Callable[[Dict], str], None]=None, show_plot=True):\n",
    "    \"\"\"Plot one figure per control setting and save each one as an EPS file.\n",
    "\n",
    "    For every combination in ``ParameterGrid(control_var)`` the rows of ``df``\n",
    "    matching that combination are grouped by ``variables`` and the column-wise\n",
    "    mean of each group is drawn as one line.  The x position of a point is the\n",
    "    first number found in its column name; the last entry of the mean is not\n",
    "    plotted and NaN means are skipped.\n",
    "\n",
    "    Args:\n",
    "        df: DataFrame holding the control columns, the ``variables`` columns\n",
    "            and the result columns.\n",
    "        exp_name: Experiment name; used in the output file name and in the\n",
    "            default title.\n",
    "        control_var: Parameter grid of the controlled variables; every\n",
    "            combination must provide ``'dataset'`` and ``'ord'``.\n",
    "        variables: Column names to group by (one line per group).\n",
    "        get_title_fn: Optional callable mapping a control combination to the\n",
    "            figure title.\n",
    "        get_label_name_fn: Optional callable mapping a group name to its\n",
    "            legend label.\n",
    "        get_label_color_fn: Optional callable mapping a group name to its\n",
    "            line colour.\n",
    "        show_plot: Show each figure when true, otherwise close it.\n",
    "\n",
    "    Returns:\n",
    "        A list of ``(control combination, saved figure path)`` tuples.\n",
    "    \"\"\"\n",
    "    ret = []\n",
    "    for g in ParameterGrid(control_var):\n",
    "        if get_title_fn is None:\n",
    "            title = exp_name\n",
    "            for k, v in g.items():\n",
    "                title = title + f\"_{get_var_name(k, v)}\"\n",
    "        else:\n",
    "            title = get_title_fn(g)\n",
    "\n",
    "        temp_df = df\n",
    "        for k, v in g.items():\n",
    "            # Build the mask from the already-filtered frame so that it always\n",
    "            # lines up with the rows being selected.\n",
    "            temp_df = temp_df.loc[temp_df[k] == v]\n",
    "\n",
    "        fig, ax = plt.subplots()\n",
    "        ax.set_title(title, fontsize=20)\n",
    "        for name, group in temp_df.groupby(variables):\n",
    "            # Compute the group mean once; its last entry is not plotted.\n",
    "            group_mean = group.mean()\n",
    "            col_names = group_mean.index.tolist()[:-1]\n",
    "            col_means = group_mean.tolist()[:-1]\n",
    "            eps_list = [re.findall(r'[-+]?\\d*\\.\\d+|\\d+', t)[0] for t in col_names]\n",
    "            # Skip NaN means while keeping x and y aligned.\n",
    "            points = [(float(eps), r) for eps, r in zip(eps_list, col_means) if not np.isnan(r)]\n",
    "            x = [p[0] for p in points]\n",
    "            s = [p[1] for p in points]\n",
    "\n",
    "            if get_label_name_fn is not None:\n",
    "                # Use the callback that was passed in, not a notebook global.\n",
    "                label = get_label_name_fn(name)\n",
    "            elif isinstance(name, str):\n",
    "                label = get_var_name(variables[0], name)\n",
    "            else:\n",
    "                # Group names are sequences when grouping by several variables.\n",
    "                label = \"_\".join(get_var_name(variables[i], n) for i, n in enumerate(name))\n",
    "\n",
    "            plot_kwargs = {'label': label, 'linewidth': 3.5}\n",
    "            if get_label_color_fn is not None:\n",
    "                plot_kwargs['color'] = get_label_color_fn(name)\n",
    "            ax.plot(x, s, **plot_kwargs)\n",
    "\n",
    "        dataset, norm = g['dataset'], g['ord']\n",
    "        fig_path = f'./figs/{exp_name}_{dataset}_{norm}.eps'\n",
    "        set_plot(fig, ax)\n",
    "        plt.savefig(fig_path, format='eps')\n",
    "        ret.append((g, fig_path))\n",
    "        if show_plot:\n",
    "            plt.show()\n",
    "        else:\n",
    "            plt.close()\n",
    "    return ret"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "def parbox(width, content):\n",
    "    \"\"\"Wrap ``content`` in a centred LaTeX parbox that is ``width`` millimetres wide.\"\"\"\n",
    "    template = \"\\\\parbox{%dmm}{\\\\centering %s}\"\n",
    "    return template % (width, content)\n",
    "\n",
    "def knn_attack_plots(exp_name, grid_param, caption='', show_plot=True):\n",
    "    \"\"\"Plot the results grouped by attack for every dataset and return the LaTeX figure.\n",
    "\n",
    "    Args:\n",
    "        exp_name: Experiment name used for file names and the figure label.\n",
    "        grid_param: A parameter grid (dict) or a list of parameter grids.  For\n",
    "            a list, the datasets of all grids are combined and ``'ord'`` is\n",
    "            taken from the first grid.\n",
    "        caption: Caption of the generated LaTeX figure.\n",
    "        show_plot: Forwarded to ``plot_result``.\n",
    "\n",
    "    Returns:\n",
    "        The LaTeX source produced by ``result_latex_figs``.\n",
    "    \"\"\"\n",
    "    df = params_to_dataframe(grid_param)\n",
    "    if isinstance(grid_param, list):\n",
    "        # Sorted so that the figure order does not depend on set iteration order.\n",
    "        datasets = sorted(set.union(*[set(g['dataset']) for g in grid_param]))\n",
    "        ord_values = grid_param[0]['ord']\n",
    "    else:\n",
    "        datasets = grid_param['dataset']\n",
    "        ord_values = grid_param['ord']\n",
    "\n",
    "    control = {\n",
    "        'dataset': datasets,\n",
    "        'ord': ord_values,\n",
    "    }\n",
    "    variables = ['attack']\n",
    "    # show_plot has to be passed by keyword: the fifth positional parameter of\n",
    "    # plot_result is get_title_fn.\n",
    "    plot_result(df, exp_name, control, variables, show_plot=show_plot)\n",
    "    return result_latex_figs(exp_name, control, caption)\n",
    "\n",
    "def get_var_name(var, arg):\n",
    "    \"\"\"Return the display name of argument ``arg`` of variable ``var``.\n",
    "\n",
    "    Dataset arguments are resolved through ``auto_var.get_var_shown_name``;\n",
    "    for every other variable the underscores of ``arg`` become hyphens.\n",
    "    \"\"\"\n",
    "    if var != 'dataset':\n",
    "        return arg.replace('_', '-')\n",
    "    return auto_var.get_var_shown_name(var, arg)\n",
    "\n",
    "def avg_pert_table(exp_name, grid_param, columns, rows, objs:list=None, obj_formats:list=None):\n",
    "    \"\"\"Tabulate the mean of each objective for every (column, row) combination.\n",
    "\n",
    "    Args:\n",
    "        exp_name: Experiment name (not used by this function).\n",
    "        grid_param: Parameter grid(s) handed to ``params_to_dataframe`` and\n",
    "            ``union_param_key``.\n",
    "        columns: Variables spanning the table columns.  The dataset-statistic\n",
    "            names ``n_features``, ``n_samples``, ``n_classes``, ``n_train``\n",
    "            and ``n_test`` are ignored here.\n",
    "        rows: Variables spanning the table rows.\n",
    "        objs: Result names to average; defaults to ``['avg_pert']``.\n",
    "        obj_formats: Optional %-format string per objective; ``\"$%.3f$\"`` is\n",
    "            used when omitted.\n",
    "\n",
    "    Returns:\n",
    "        A DataFrame of formatted strings in which ``-1`` marks combinations\n",
    "        without a result.  The frame is empty when no column or row variable\n",
    "        is left.\n",
    "    \"\"\"\n",
    "    if objs is None:\n",
    "        objs = ['avg_pert']\n",
    "    stat_columns = ('n_features', 'n_samples', 'n_classes', 'n_train', 'n_test')\n",
    "    columns = [c for c in columns if c not in stat_columns]\n",
    "    if len(columns) == 0 or len(rows) == 0:\n",
    "        return pd.DataFrame({})\n",
    "    df = params_to_dataframe(grid_param, objs)\n",
    "\n",
    "    d = OrderedDict()\n",
    "    col_grid = OrderedDict([(c, union_param_key(grid_param, c)) for c in columns])\n",
    "    row_grid = OrderedDict([(r, union_param_key(grid_param, r)) for r in rows])\n",
    "    for i, obj in enumerate(objs):\n",
    "        means = df.groupby(columns + rows)[obj].mean()\n",
    "        str_format = \"$%.3f$\" if obj_formats is None else obj_formats[i]\n",
    "\n",
    "        if obj == 'tst_score':\n",
    "            assert columns[0] == 'model'\n",
    "        for col in ParameterGrid(col_grid):\n",
    "            col_k = tuple(col[c] for c in columns)\n",
    "            col_name = tuple([get_var_name(c, col[c]) for c in columns[:-1]]\n",
    "                             + [\"%s-%s\" % (get_var_name(columns[-1], col[columns[-1]]), obj.replace(\"_\", \"-\"))])\n",
    "            d[col_name] = {}\n",
    "            for row in ParameterGrid(row_grid):\n",
    "                row_k = tuple(row[r] for r in rows)\n",
    "                row_name = tuple(get_var_name(r, row[r]) for r in rows)\n",
    "                key = col_k + row_k\n",
    "                if key not in means:\n",
    "                    d[col_name][row_name] = -1\n",
    "                    continue\n",
    "                value = means[key]\n",
    "                cell = str_format % value\n",
    "                if value < 1:\n",
    "                    # Drop only a zero integer part (\"0.250\" -> \".250\"); a plain\n",
    "                    # str.replace would also turn \"-10.5\" into \"-1.5\".\n",
    "                    cell = re.sub(r\"(?<![\\d.])0\\.(?=\\d)\", \".\", cell)\n",
    "                d[col_name][row_name] = cell\n",
    "\n",
    "    return pd.DataFrame(d)\n",
    "\n",
    "def dataset_stat_column(df, grid_param, columns, rows):\n",
    "    \"\"\"Add dataset statistics in front of the columns of a per-dataset table.\n",
    "\n",
    "    Args:\n",
    "        df: Table whose rows are keyed by a 1-tuple holding the dataset name.\n",
    "        grid_param: Parameter grid(s) from which the dataset list is read via\n",
    "            ``union_param_key``.\n",
    "        columns: Requested columns; only ``n_train``, ``n_test``,\n",
    "            ``n_features``, ``n_samples`` and ``n_classes`` are handled here.\n",
    "        rows: Not used by this function.\n",
    "\n",
    "    Returns:\n",
    "        ``df`` itself when no statistic was requested, otherwise a new\n",
    "        DataFrame with the statistic columns first and the original columns\n",
    "        after them.\n",
    "    \"\"\"\n",
    "    column_names = {\n",
    "        'n_train': '\\\\# training',\n",
    "        'n_test': '\\\\# testing',\n",
    "        'n_features': '\\\\# features',\n",
    "        'n_samples': '\\\\# examples',\n",
    "        'n_classes': '\\\\# classes',\n",
    "    }\n",
    "    if not any(col in column_names for col in columns):\n",
    "        return df\n",
    "\n",
    "    d = df.to_dict(into=OrderedDict)\n",
    "    datasets = union_param_key(grid_param, \"dataset\")\n",
    "    ori_cols = list(d.keys())\n",
    "    # Existing column keys may be tuples; new keys are padded with '-' to the\n",
    "    # same length.\n",
    "    if ori_cols and not isinstance(ori_cols[0], str):\n",
    "        col_len = len(ori_cols[0])\n",
    "    else:\n",
    "        col_len = 1\n",
    "\n",
    "    for dataset in datasets:\n",
    "        X, y, _ = auto_var.get_var_with_argument(\"dataset\", dataset)\n",
    "        row_name = (get_var_name(\"dataset\", dataset), )\n",
    "        for col in columns:\n",
    "            if col not in column_names:\n",
    "                continue\n",
    "            column_name = tuple(['-'] * (col_len - 1) + [column_names[col]])\n",
    "            if col == \"n_features\":\n",
    "                value = X.shape[1]\n",
    "            elif col == \"n_samples\":\n",
    "                value = X.shape[0]\n",
    "            elif col == \"n_train\":\n",
    "                # NOTE(review): 200 is hard-coded; presumably the number of\n",
    "                # held-out examples -- confirm against the experiment setup.\n",
    "                value = X.shape[0] - 200\n",
    "            elif col == \"n_test\":\n",
    "                # NOTE(review): hard-coded test-set size -- confirm.\n",
    "                value = 100\n",
    "            else:  # \"n_classes\"\n",
    "                value = len(np.unique(y))\n",
    "            d.setdefault(column_name, {})[row_name] = value\n",
    "\n",
    "    # Move the original columns behind the newly added statistics.\n",
    "    for col in ori_cols:\n",
    "        d.move_to_end(col)\n",
    "\n",
    "    return pd.DataFrame(d)\n",
    "    \n",
    "def cmp_ratio(df):\n",
    "    \"\"\"Insert an ``imp.`` ratio column after every non-baseline ``avg-pert`` column.\n",
    "\n",
    "    Columns are visited in order and all of them are copied to the result.\n",
    "    The first two columns whose second-level name contains ``'avg-pert'`` are\n",
    "    the baselines.  Every later ``avg-pert`` column is divided, row by row, by\n",
    "    the baseline selected by the parity of its position among the ``avg-pert``\n",
    "    columns.  Rows where either value is ``-1`` yield ``-1``.\n",
    "    \"\"\"\n",
    "    table = df.to_dict(into=OrderedDict)\n",
    "    out = OrderedDict()\n",
    "    baselines = []\n",
    "    seen = 0  # number of 'avg-pert' columns handled so far\n",
    "\n",
    "    for col, cells in table.items():\n",
    "        out[col] = cells\n",
    "        if 'avg-pert' not in col[1]:\n",
    "            continue\n",
    "        if seen < 2:\n",
    "            baselines.append(cells)\n",
    "            seen += 1\n",
    "            continue\n",
    "\n",
    "        base = baselines[seen % 2]\n",
    "        ratios = {}\n",
    "        for row, cell in cells.items():\n",
    "            if cell == -1 or base[row] == -1:\n",
    "                ratios[row] = -1\n",
    "            else:\n",
    "                numerator = cell.replace(\"$\", \"\")\n",
    "                denominator = base[row].replace(\"$\", \"\")\n",
    "                ratios[row] = \"$%.2f$\" % (float(numerator) / float(denominator))\n",
    "\n",
    "        out[tuple(col[:-1]) + (\"%s imp.\" % col[-1],)] = ratios\n",
    "        seen += 1\n",
    "\n",
    "    return pd.DataFrame(out)\n",
    "\n",
    "def max_imp(df):\n",
    "    \"\"\"Reduce groups of related columns to the per-row best one.\n",
    "\n",
    "    The first two columns of ``df`` are copied unchanged.  The remaining\n",
    "    columns are read in pairs (even position, then odd position); pairs are\n",
    "    collected into one group until the leading hyphen-separated components of\n",
    "    the first-level column name change.  Each group is reduced by\n",
    "    ``add_new_col`` to three columns per second-level name: the selected name\n",
    "    component scaled by 0.01 (epsilon column), the selected value and the\n",
    "    matching entry of the paired column.\n",
    "\n",
    "    NOTE(review): the pairing presumably mirrors the (value, ``imp.``) column\n",
    "    layout produced by ``cmp_ratio`` -- confirm against the caller.\n",
    "    \"\"\"\n",
    "    ret = OrderedDict()\n",
    "    d = df.to_dict(into=OrderedDict)\n",
    "    \n",
    "    def add_new_col(col_list, ret):\n",
    "        # col_list holds (column name, even-column dict, odd-column dict) triples.\n",
    "        new_col = {}\n",
    "        \n",
    "        # Only the second-level names of the first two triples are considered.\n",
    "        for attack_name in [col_list[0][0][1], col_list[1][0][1]]:\n",
    "            temp = list(filter(lambda t: t[0][1] == attack_name, col_list))\n",
    "            imps = []\n",
    "            for c in temp:\n",
    "                imps.append([float(v.replace(\"$\", \"\")) if v != -1 else -1 for _, v in c[1].items()])\n",
    "            # For every row: index of the triple whose even-column value is largest.\n",
    "            imps = (np.array(imps).T).argmax(axis=1)\n",
    "\n",
    "            new_col = {}\n",
    "            new_col_imp = {}\n",
    "            new_col_eps = {}\n",
    "            pcol = temp[0][0]\n",
    "            \n",
    "            # Build the merged first-level name: drop the second-to-last component\n",
    "            # when the last one contains 'd', otherwise drop the last component.\n",
    "            if 'd' in pcol[0].split(\"-\")[-1]:\n",
    "                tt = pcol[0].split(\"-\")\n",
    "                tt.pop(-2)\n",
    "            else:\n",
    "                tt = pcol[0].split(\"-\")[:-1]\n",
    "            new_col_name = (\"-\".join(tt), pcol[1])\n",
    "            new_col_imp_name = (\"-\".join(tt), (\"%s imp.\" % pcol[1]))\n",
    "            new_col_eps_name = (\"-\".join(tt), (\"%s $\\\\epsilon$\" % pcol[1]))\n",
    "            for i, idx in enumerate(imps):\n",
    "                k, v = list(temp[idx][1].items())[i]\n",
    "                new_col[k] = v\n",
    "                k, v = list(temp[idx][2].items())[i]\n",
    "                new_col_imp[k] = v \n",
    "\n",
    "                # The dropped name component times 0.01, formatted with one decimal;\n",
    "                # [1:] removes the first character of the formatted number\n",
    "                # (e.g. component '30' -> '0.3$' -> '$.3$').\n",
    "                if 'd' in temp[idx][0][0].split(\"-\")[-1]:\n",
    "                    new_col_eps[k] = \"$\" + (\"%.1f$\" % (float(temp[idx][0][0].split(\"-\")[-2]) * 0.01))[1:]\n",
    "                else:\n",
    "                    new_col_eps[k] = \"$\" + (\"%.1f$\" % (float(temp[idx][0][0].split(\"-\")[-1]) * 0.01))[1:]\n",
    "\n",
    "            ret[new_col_eps_name] = new_col_eps\n",
    "            ret[new_col_name] = new_col\n",
    "            ret[new_col_imp_name] = new_col_imp\n",
    "    \n",
    "    prev_col = None\n",
    "    temp = []\n",
    "    for i, (col, col_dict) in enumerate(d.items()):\n",
    "        # Number of trailing name components ignored when comparing group names.\n",
    "        if 'd' in col[0].split(\"-\")[-1]:\n",
    "            check_idx = -2\n",
    "        else:\n",
    "            check_idx = -1\n",
    "            \n",
    "        # The first two columns are passed through untouched.\n",
    "        if i == 0 or i == 1:\n",
    "            ret[col] = col_dict\n",
    "            continue\n",
    "            \n",
    "        if len(temp) == 0:\n",
    "            temp.append(col_dict)\n",
    "        elif i % 2 == 1:\n",
    "            # Odd position: complete the pair started by the previous column.\n",
    "            temp[-1] = (prev_col, temp[-1], col_dict)\n",
    "            # Flush the last group once the final column has been consumed.\n",
    "            if i == (len(d.items())-1):\n",
    "                add_new_col(temp, ret)\n",
    "        else:\n",
    "            # Even position: a changed name prefix closes the current group.\n",
    "            if col[0].split(\"-\")[:check_idx] != prev_col[0].split(\"-\")[:check_idx]:\n",
    "                add_new_col(temp, ret)\n",
    "                temp = [col_dict]\n",
    "            else:\n",
    "                temp.append(col_dict)\n",
    "                \n",
    "        prev_col = col\n",
    "        \n",
    "    return pd.DataFrame(ret)\n",
    "\n",
    "def bold_best(df, reverse=False):\n",
    "    \"\"\"Set the best entry of every row in bold (LaTeX ``mathbf``).\n",
    "\n",
    "    Args:\n",
    "        df: Table whose string cells look like ``\"$<number>$\"``; cells that are\n",
    "            not strings (for example the ``-1`` placeholder) never win and are\n",
    "            copied unchanged.\n",
    "        reverse: When true the smallest value of a row is the best one,\n",
    "            otherwise the largest.\n",
    "\n",
    "    Returns:\n",
    "        A new DataFrame in which every string cell equal to the best value of\n",
    "        its row is wrapped in ``mathbf``; ties are all set in bold.  A table\n",
    "        without columns yields an empty DataFrame.\n",
    "    \"\"\"\n",
    "    d = df.to_dict(into=OrderedDict)\n",
    "    ret = OrderedDict()\n",
    "    if not d:\n",
    "        # Without columns there is nothing to compare, and argmax/argmin over\n",
    "        # axis 1 would fail on the resulting 1-d array.\n",
    "        return pd.DataFrame(ret)\n",
    "\n",
    "    def as_number(cell):\n",
    "        return float(cell.replace(\"$\", ''))\n",
    "\n",
    "    # Non-string cells get the worst possible score so they are never selected.\n",
    "    missing = np.inf if reverse else -np.inf\n",
    "    scores = np.array([\n",
    "        [as_number(v) if isinstance(v, str) else missing for v in col_dict.values()]\n",
    "        for col_dict in d.values()\n",
    "    ]).T\n",
    "    best_idx = scores.argmin(axis=1) if reverse else scores.argmax(axis=1)\n",
    "\n",
    "    for col, col_dict in d.items():\n",
    "        ret[col] = {}\n",
    "        for j, (row, row_value) in enumerate(col_dict.items()):\n",
    "            if isinstance(row_value, str) and as_number(row_value) == scores[j][best_idx[j]]:\n",
    "                ret[col][row] = \"$\\\\mathbf{\" + row_value[1:-1] + \"}$\"\n",
    "            else:\n",
    "                ret[col][row] = row_value\n",
    "\n",
    "    return pd.DataFrame(ret)\n",
    "\n",
    "def gen_table(exp_name, grid_params, columns, rows, objs=None,\n",
    "              combine_method=None, additionals=None, obj_formats=None):\n",
    "    \"\"\"Build the result table for one parameter grid or for several combined grids.\n",
    "\n",
    "    Args:\n",
    "        exp_name: Experiment name, forwarded to ``avg_pert_table``.\n",
    "        grid_params: A single parameter grid when ``combine_method`` is None,\n",
    "            otherwise an iterable of grids that are tabulated one by one.\n",
    "        columns: Column variables of the table.\n",
    "        rows: Row variables of the table.\n",
    "        objs: Result names to tabulate; defaults to ``['avg_pert']``.\n",
    "        combine_method: ``axis`` argument of ``pd.concat`` used to join the\n",
    "            per-grid tables, or None for a single table.\n",
    "        additionals: Optional functions applied in order to each table\n",
    "            (DataFrame in, DataFrame out).\n",
    "        obj_formats: Optional %-format string per objective.\n",
    "\n",
    "    Returns:\n",
    "        The resulting DataFrame; dataset statistics are added through\n",
    "        ``dataset_stat_column`` when ``'dataset'`` is one of ``rows``.\n",
    "    \"\"\"\n",
    "    if objs is None:\n",
    "        objs = ['avg_pert']\n",
    "\n",
    "    def build(grid):\n",
    "        table = avg_pert_table(exp_name, grid, columns, rows, objs, obj_formats)\n",
    "        for fn in additionals or []:\n",
    "            table = fn(table)\n",
    "        return table\n",
    "\n",
    "    if combine_method is None:\n",
    "        df = build(grid_params)\n",
    "    else:\n",
    "        df = pd.concat([build(g) for g in grid_params], axis=combine_method)\n",
    "\n",
    "    if 'dataset' in rows:\n",
    "        # Use this function's own argument rather than a notebook-level\n",
    "        # `grid_param` that merely happens to be defined.\n",
    "        df = dataset_stat_column(df, grid_params, columns, rows)\n",
    "    return df\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "def model_acc(df, grid_param):\n",
    "    \"\"\"Insert test-accuracy columns into a table with (model, attack) column keys.\n",
    "\n",
    "    A ``tst acc.`` column is placed in front of the first column (looked up\n",
    "    for the ``'blackbox'`` attack) and in front of every column whose\n",
    "    second-level name contains an epsilon marker.  For the latter the model\n",
    "    name is rebuilt from the first-level name and the epsilon shown in each\n",
    "    row (times 100).\n",
    "\n",
    "    Args:\n",
    "        df: Table whose columns are ``(model, attack ...)`` tuples and whose\n",
    "            rows are keyed by tuples starting with the dataset name.\n",
    "        grid_param: Parameter grid(s) passed to ``params_to_dataframe`` to\n",
    "            load the ``tst_score`` results.\n",
    "\n",
    "    Returns:\n",
    "        A new DataFrame holding the original columns plus the accuracy columns.\n",
    "    \"\"\"\n",
    "    tst_df = params_to_dataframe(grid_param, ['tst_score'])\n",
    "\n",
    "    def mean_tst_score(model, attack, dataset):\n",
    "        picked = tst_df[(tst_df['model'] == model)\n",
    "                        & (tst_df['attack'] == attack)\n",
    "                        & (tst_df['dataset'] == dataset)]\n",
    "        return \"$%.2f$\" % picked['tst_score'].mean()\n",
    "\n",
    "    ret = OrderedDict()\n",
    "    for i, (col, col_dict) in enumerate(df.to_dict(into=OrderedDict).items()):\n",
    "        if i == 0:\n",
    "            model_name = col[0].replace(\"-\", \"_\")\n",
    "            acc = OrderedDict()\n",
    "            for row in col_dict:\n",
    "                acc[row] = mean_tst_score(model_name, 'blackbox', row[0].replace(\"-\", \"_\"))\n",
    "            ret[(col[0], col[1].replace('-avg-pert', ' tst acc.'))] = acc\n",
    "\n",
    "        elif '\\\\epsilon' in col[1]:\n",
    "            m = re.match(r\"(?P<attack>[a-zA-Z_0-9'-]+) \\$\\\\epsilon\\$\", col[1])\n",
    "            # The captured name still ends with '-avg-pert'; cut that suffix off.\n",
    "            attack_name = m.group(\"attack\")[:-len('-avg-pert')].replace(\"-\", \"_\")\n",
    "            name_parts = col[0].split('-')\n",
    "            acc = OrderedDict()\n",
    "            for row, row_val in col_dict.items():\n",
    "                # round() before int(): a product such as 0.29 * 100 is slightly\n",
    "                # below 29 in floating point and would be truncated to 28.\n",
    "                eps = int(round(float(row_val.replace(\"$\", \"\")) * 100))\n",
    "                if 'd' in name_parts[-1]:\n",
    "                    model_name = '%s-%d-%s' % ('-'.join(name_parts[:-1]), eps, name_parts[-1])\n",
    "                else:\n",
    "                    model_name = \"%s-%d\" % (col[0], eps)\n",
    "                acc[row] = mean_tst_score(model_name.replace(\"-\", \"_\"), attack_name,\n",
    "                                          row[0].replace(\"-\", \"_\"))\n",
    "            ret[(col[0], col[1].replace('-avg-pert $\\\\epsilon$', ' tst acc.'))] = acc\n",
    "\n",
    "        ret[col] = col_dict\n",
    "    return pd.DataFrame(ret)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "./results/fashion-mnist35f-pca25-adv-nn-k1-30-RBA-Exact-KNN-k1-rs0-linf.json doesn't exist\n",
      "./results/fashion-mnist06f-pca25-adv-nn-k1-30-RBA-Exact-KNN-k1-rs0-linf.json doesn't exist\n",
      "./results/mnist17f-pca25-adv-nn-k1-30-RBA-Exact-KNN-k1-rs0-linf.json doesn't exist\n",
      "./results/covtypebin-2200-adv-nn-k1-30-RBA-Approx-KNN-k1-50-rs0-linf.json doesn't exist\n"
     ]
    }
   ],
   "source": [
    "def improvement(df):\n",
    "    \"\"\"Express every column as a ratio to the first column.\n",
    "\n",
    "    The first column is the reference and is left out of the result.  Each\n",
    "    remaining cell becomes ``\"$%.2f$\"`` of ``cell / reference``; when either\n",
    "    value is the ``-1`` placeholder the ratio is ``-1`` (shown as\n",
    "    ``\"$-1.00$\"``).\n",
    "    \"\"\"\n",
    "    columns = iter(df.to_dict(into=OrderedDict).items())\n",
    "    # The first column only serves as the reference.\n",
    "    _, ref = next(columns, (None, None))\n",
    "\n",
    "    ret = OrderedDict()\n",
    "    for col, cells in columns:\n",
    "        ratios = {}\n",
    "        for row, cell in cells.items():\n",
    "            if ref[row] == -1 or cell == -1:\n",
    "                ratio = -1.\n",
    "            else:\n",
    "                ratio = float(cell.replace(\"$\", '')) / float(ref[row].replace(\"$\", ''))\n",
    "            ratios[row] = \"$%.2f$\" % ratio\n",
    "        ret[col] = ratios\n",
    "\n",
    "    return pd.DataFrame(ret)\n",
    "\n",
    "# Fetch the experiment name and parameter grids of the defense comparison.\n",
    "_, exp_name, grid_param, _ = compare_defense()\n",
    "# Raw string: the text holds LaTeX macros, and a backslash followed by 'd' is\n",
    "# not a valid escape sequence in a normal string literal.\n",
    "avg_caption = r\"\"\"\n",
    "The \\defenderscore across four nonparametric classifiers and corresponding competitors.\n",
    "A number greater than one indicates that the defense yields a more robust model, \n",
    "while less than one indicates less robustness (higher is better; best is in bold).\n",
    "The \\defenderscore for undefended classifiers are always one.\n",
    "\"\"\"\n",
    "#del grid_param[1]\n",
    "\n",
    "# One sub-table per grid in grid_param, joined along axis 1.  `improvement`\n",
    "# turns each sub-table into ratios against its first column and `bold_best`\n",
    "# sets the largest entry of every row in bold.\n",
    "table_str = gen_table(\n",
    "                exp_name, grid_param, ['model', 'attack'], ['dataset'], combine_method=1,\n",
    "                objs=['avg_pert'], additionals=[improvement, bold_best]\n",
    "            ).to_latex(escape=False)\n",
    "# Header row 1: collapse each run of model names into one multicolumn cell.\n",
    "table_str = re.sub(r\"[\\s]*adv-nn-k1-30 & [\\s]*robustv2-nn-k1-30 & [\\s]*advPruning-nn-k1-30\",\n",
    "                   r\"\\\\multicolumn{3}{c}{1-NN}\", table_str)\n",
    "table_str = re.sub(r\"[\\s]*adv-nn-k3-30 & [\\s]*advPruning-nn-k3-30\",\n",
    "                   r\"\\\\multicolumn{2}{c}{3-NN}\", table_str)\n",
    "table_str = re.sub(r\"[\\s]*adv-decision-tree-d5-30 & [\\s]*robust-decision-tree-d5-30 & [\\s]*advPruning-decision-tree-d5-30\",\n",
    "                   r\"\\\\multicolumn{3}{c}{DT}\", table_str)\n",
    "table_str = re.sub(r\"[\\s]*adv-rf-100-30-d5 & [\\s]*robust-rf-100-30-d5 & [\\s]*advPruning-rf-100-30-d5\",\n",
    "                   r\"\\\\multicolumn{3}{c}{RF}\", table_str)\n",
    "#table_str = re.sub(r\"[\\s]*adv-mlp-30 & [\\s]*advPruning-mlp-30\",\n",
    "#                   r\" \\\\multicolumn{2}{c}{MLP}\", table_str)\n",
    "#table_str = re.sub(r\"[\\s]*adv-logistic-regression-30 & [\\s]*advPruning-logistic-regression-30\",\n",
    "#                   r\" \\\\multicolumn{2}{c}{LR}\", table_str)\n",
    "# Header row 2: replace the repeated attack names with short defense labels.\n",
    "table_str = table_str.replace(\"RBA-Exact-KNN-k1-avg-pert & RBA-Exact-KNN-k1-avg-pert & RBA-Exact-KNN-k1-avg-pert\",\n",
    "                              \"AT & Wang's & AP\")\n",
    "table_str = table_str.replace(\"RBA-Approx-KNN-k1-50-avg-pert & RBA-Approx-KNN-k1-50-avg-pert & RBA-Approx-KNN-k1-50-avg-pert\",\n",
    "                              \"AT & Wang's & AP\")\n",
    "table_str = table_str.replace(\"RBA-Approx-KNN-k3-50-avg-pert & RBA-Approx-KNN-k3-50-avg-pert\",\n",
    "                              \"AT & AP\")\n",
    "table_str = re.sub(r\"[\\s]*RBA-Exact-DT-avg-pert & [\\s]*RBA-Exact-DT-avg-pert & [\\s]*RBA-Exact-DT-avg-pert\",\n",
    "                    \" AT & RS & AP\", table_str)\n",
    "table_str = table_str.replace(\"RBA-Approx-RF-100-avg-pert & RBA-Approx-RF-100-avg-pert & RBA-Approx-RF-100-avg-pert\",\n",
    "                              \"AT & RS & AP\")\n",
    "table_str = re.sub(r\"[\\s]*pgd-avg-pert & [\\s]*pgd-avg-pert\",\n",
    "                    \" AT & AP\", table_str)\n",
    "table_str = table_str.replace(\"adv-nnopt-k1-all-avg-pert\", \"AT\")\n",
    "table_str = table_str.replace(\"advPruning-decision-tree-d5-30\", \"AP\")\n",
    "#table_str = table_str.replace(\"llllllllllllllll\", \"l|ccc|cc|ccc|ccc||cc|cc\")\n",
    "# Swap the twelve-'l' column specification for one with vertical rules\n",
    "# between the model groups.\n",
    "table_str = table_str.replace(\"llllllllllll\", \"l|ccc|cc|ccc|ccc\")\n",
    "# Hand the finished table to write_to_tex under '<exp_name>_table.tex'.\n",
    "write_to_tex(table_str, exp_name + '_table.tex')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "def attack_table(exp_name, grid_param):\n",
    "    \"\"\"Return the LaTeX source of the attack-comparison table.\n",
    "\n",
    "    The table is built with ``gen_table`` (one sub-table per grid, joined\n",
    "    along axis 1, smallest entry of every row in bold) and its header cells\n",
    "    are then rewritten into short display names.\n",
    "\n",
    "    Args:\n",
    "        exp_name: Experiment name, forwarded to ``gen_table``.\n",
    "        grid_param: Iterable of parameter grids, forwarded to ``gen_table``.\n",
    "\n",
    "    Returns:\n",
    "        The table as a LaTeX string.\n",
    "    \"\"\"\n",
    "    table_str = gen_table(\n",
    "                    exp_name, grid_param, ['model', 'attack'], ['dataset'], combine_method=1,\n",
    "                    objs=['avg_pert'], additionals=[partial(bold_best, reverse=True)]\n",
    "                ).to_latex(escape=False)\n",
    "\n",
    "    regex_rules = [\n",
    "        (r\"([a-zA-Z_0-9'-]+) imp\\.\", \"imp.\"),\n",
    "        (r\"([a-zA-Z_0-9'-]+) tst acc\\.\", \"tst acc.\"),\n",
    "        (r\"([a-zA-Z_0-9'-]+)-avg-pert\", r\"\\1\"),\n",
    "        (r\"([a-zA-Z_0-9'-]+) \\$\\\\epsilon\\$\", r\"$\\\\epsilon$\"),\n",
    "    ]\n",
    "    for pattern, repl in regex_rules:\n",
    "        table_str = re.sub(pattern, repl, table_str)\n",
    "\n",
    "    # Applied strictly in this order: the long attack names have to be\n",
    "    # shortened to RBA-Exact / RBA-Approx before those two are wrapped in\n",
    "    # makecell at the end.\n",
    "    literal_rules = [\n",
    "        (\"lllllllllllllll\", \"l|ccccc|cccc|ccc|cc\"),\n",
    "        (\"-avg-pert\", \"\"),\n",
    "        (\"direct-k1\", \"Direct\"),\n",
    "        (\"direct-k3\", \"Direct\"),\n",
    "        (\"kernelsub-c1000-pgd\", \"Kernel\"),\n",
    "        (\"RBA-Exact-KNN-k1\", \"RBA-Exact\"),\n",
    "        (\"RBA-Exact-DT\", \"RBA-Exact\"),\n",
    "        (\"RBA-Approx-KNN-k1-50\", \"RBA-Approx\"),\n",
    "        (\"RBA-Approx-KNN-k3-50\", \"RBA-Approx\"),\n",
    "        (\"RBA-Approx-RF-100\", \"RBA-Approx\"),\n",
    "        (\"dt-papernots\", \"Papernot's\"),\n",
    "        (\"rev-nnopt-k3-50-region\", \"RBA-Approx\"),\n",
    "        (\"rf-attack-rev-100\", \"RBA-Approx\"),\n",
    "        (\"dt-attack-opt\", \"RBA-Exact\"),\n",
    "        (\"decision-tree-d5\", \"DT\"),\n",
    "        (\"random-forest-100-d5\", \"RF\"),\n",
    "        (\"knn1\", \"1-NN\"),\n",
    "        (\"knn3\", \"3-NN\"),\n",
    "        # Centre the multi-column header cells.  The backslash is escaped\n",
    "        # explicitly because a backslash followed by 'm' is not a valid\n",
    "        # escape sequence.\n",
    "        (\"\\\\multicolumn{5}{l}\", \"\\\\multicolumn{5}{c}\"),\n",
    "        (\"\\\\multicolumn{4}{l}\", \"\\\\multicolumn{4}{c}\"),\n",
    "        (\"\\\\multicolumn{3}{l}\", \"\\\\multicolumn{3}{c}\"),\n",
    "        (\"\\\\multicolumn{2}{l}\", \"\\\\multicolumn{2}{c}\"),\n",
    "        (\"blackbox\", \"BBox\"),\n",
    "        (\"RBA-Exact\", \"\\\\makecell{RBA\\\\\\\\Exact}\"),\n",
    "        (\"RBA-Approx\", \"\\\\makecell{RBA\\\\\\\\Approx}\"),\n",
    "    ]\n",
    "    for old, new in literal_rules:\n",
    "        table_str = table_str.replace(old, new)\n",
    "    return table_str\n",
    "\n",
    "# Build the attack-comparison table and pass it to write_to_tex under\n",
    "# '<exp_name>_table.tex'.\n",
    "_, exp_name, grid_param, _ = compare_attacks()\n",
    "write_to_tex(attack_table(exp_name, grid_param), exp_name + '_table.tex')\n",
    "#_, exp_name, grid_param, _ = params_l2.compare_attacks()()\n",
    "#write_to_tex(attack_table(exp_name, grid_param), exp_name + '_table.tex')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Experiments to plot and, in the same order, the label used for the\n",
    "# sub-figures of each experiment.\n",
    "exp_fns = [nn_k1_robustness_figs, nn_k3_robustness_figs, dt_robustness_figs, rf_robustness_figs]\n",
    "model_names = [\"1-NN\", \"3-NN\", \"Decision tree\", \"Random forest\"]\n",
    "def get_label_name(name):\n",
    "    \"\"\"Map a model name to the short label shown in the figure legend.\n",
    "\n",
    "    Names containing ``advPruning`` give ``\"AP\"``, names containing ``robust``\n",
    "    give ``\"RS\"`` and the plain decision-tree, k-NN and random-forest models\n",
    "    give ``\"Reg.\"``.  Any other name is returned unchanged.\n",
    "    \"\"\"\n",
    "    if 'advPruning' in name:\n",
    "        return \"AP\"\n",
    "    if 'robust' in name:\n",
    "        return \"RS\"\n",
    "    regular_markers = ('decision_tree', 'knn1', 'knn3', 'random_forest')\n",
    "    if any(marker in name for marker in regular_markers):\n",
    "        return \"Reg.\"\n",
    "    return name\n",
    "\n",
    "def get_label_color(name):\n",
    "    \"\"\"Map a model name to the line colour used in the figures.\n",
    "\n",
    "    Names containing ``advPruning`` are orange, names containing ``robust``\n",
    "    are blue and the plain decision-tree, k-NN and random-forest models are\n",
    "    grey.\n",
    "\n",
    "    Returns:\n",
    "        A hex colour string, or None for an unrecognised name.  The model\n",
    "        name itself is not a colour specification; None is meant to let\n",
    "        matplotlib pick the next colour of its property cycle.\n",
    "    \"\"\"\n",
    "    if 'advPruning' in name:\n",
    "        return \"#ff7f0e\"\n",
    "    if 'robust' in name:\n",
    "        return \"#1f77b4\"\n",
    "    regular_markers = ('decision_tree', 'knn1', 'knn3', 'random_forest')\n",
    "    if any(marker in name for marker in regular_markers):\n",
    "        return \"#7f7f7f\"\n",
    "    return None\n",
    "\n",
    "def compare_nn_plots(exp_name, grid_param, caption='', show_plot=False):\n",
    "    \"\"\"Plot the results grouped by model for every dataset of an experiment.\n",
    "\n",
    "    Args:\n",
    "        exp_name: Experiment name used in the figure file names.\n",
    "        grid_param: A parameter grid (dict) or a list of parameter grids.  For\n",
    "            a list, the datasets of all grids are combined and ``'ord'`` is\n",
    "            taken from the first grid.\n",
    "        caption: Not used by this function.\n",
    "        show_plot: Forwarded to ``plot_result``.\n",
    "\n",
    "    Returns:\n",
    "        The list of ``(control combination, figure path)`` tuples returned by\n",
    "        ``plot_result``.\n",
    "    \"\"\"\n",
    "    df = params_to_dataframe(grid_param)\n",
    "    if isinstance(grid_param, list):\n",
    "        # Sorted so that the figure order does not depend on set iteration order.\n",
    "        datasets = sorted(set.union(*[set(g['dataset']) for g in grid_param]))\n",
    "        ord_values = grid_param[0]['ord']\n",
    "    else:\n",
    "        datasets = grid_param['dataset']\n",
    "        ord_values = grid_param['ord']\n",
    "\n",
    "    control = {\n",
    "        'dataset': datasets,\n",
    "        'ord': ord_values,\n",
    "    }\n",
    "    variables = ['model']\n",
    "\n",
    "    return plot_result(df, exp_name, control, variables,\n",
    "                       get_title_fn=lambda g: get_var_name(\"dataset\", g['dataset']),\n",
    "                       get_label_name_fn=get_label_name,\n",
    "                       get_label_color_fn=get_label_color,\n",
    "                       show_plot=show_plot)\n",
    "\n",
    "\n",
    "def fig_paths_latex(fig_paths: List[List[Tuple[Dict, str]]], fig_label, caption):\n",
    "    \"\"\"Return the LaTeX source of a figure that arranges sub-figures in rows.\n",
    "\n",
    "    Args:\n",
    "        fig_paths: Rows of ``(info, image path)`` entries; ``info['subfig_label']``\n",
    "            is the sub-figure title.\n",
    "        fig_label: Suffix of the ``fig:`` label.\n",
    "        caption: Caption text of the figure.\n",
    "\n",
    "    Returns:\n",
    "        A string with the complete ``figure`` environment.  Every image gets\n",
    "        the width ``1 / len(fig_paths[0])`` of the text width and each row is\n",
    "        followed by a line break.\n",
    "    \"\"\"\n",
    "    chunks = [\"\\n\\\\begin{figure}[ht!]\\n\\\\centering\"]\n",
    "    for row in fig_paths:\n",
    "        for g, img_path in row:\n",
    "            chunks.append(\"\\n\\\\subfloat[%s]{\\n    \\\\includegraphics[width=%.2f\\\\textwidth]{%s}}\"\n",
    "                          % (g['subfig_label'], 1/len(fig_paths[0]), img_path))\n",
    "        chunks.append(\"\\n\")\n",
    "    chunks.append(\"\\n\\\\caption{%s}\\n\\\\label{fig:%s}\\n\\\\end{figure} \\n\" % (caption, fig_label))\n",
    "    return \"\".join(chunks)\n",
    "\n",
    "fig_paths = []\n",
    "for i, fn in enumerate(exp_fns):\n",
    "    _, exp_name, grid_param, _ = fn()\n",
    "    fig_path = compare_nn_plots(exp_name, grid_param, show_plot=False)\n",
    "    for g, _ in fig_path:\n",
    "        g['subfig_label'] = model_names[i]\n",
    "        g['subfig_label'] = get_var_name(\"dataset\", g['subfig_label'])\n",
    "    fig_paths.append(fig_path)\n",
    "transpose = [list() for c in fig_paths[0]]\n",
    "for i, col in enumerate(fig_paths):\n",
    "    for j, r in enumerate(col):\n",
    "        transpose[j].append(r)\n",
    "        \n",
    "caption = \"The maximum perturbation distance allowed versus accuracy.\"\n",
    "fig_str = fig_paths_latex(transpose[:5], \"defense-cmp\", caption)\n",
    "write_to_tex(fig_str, 'defense_cmp_fig.tex')\n",
    "\n",
    "#fig_str = fig_paths_latex(transpose[5:], \"defense-cmp2\", caption)\n",
    "#write_to_tex(fig_str, 'defense_cmp2_fig.tex')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "#_, exp_name, grid_param, _ = compare_nns()\n",
    "#\n",
    "#def get_title_fn(g):\n",
    "#    ret = get_var_name(\"dataset\", g['dataset'])\n",
    "#    return ret\n",
    "#\n",
    "#def compare_nn_plots(exp_name, grid_param, caption='', show_plot=True):\n",
    "#    df = params_to_dataframe(grid_param)\n",
    "#    datasets = set.union(*[set(g['dataset']) for g in grid_param]) if isinstance(grid_param, list) else grid_param['dataset']\n",
    "#\n",
    "#    control = {\n",
    "#        'dataset': datasets,\n",
    "#        'ord': grid_param[0]['ord'],\n",
    "#    }\n",
    "#    variables = ['model']\n",
    "#    figs = plot_result(df, exp_name, control, variables, get_title_fn=get_title_fn, show_plot=show_plot)\n",
    "#    fig_paths = []\n",
    "#    for i, f in enumerate(figs):\n",
    "#        if i % 3 == 0:\n",
    "#            fig_paths.append([])\n",
    "#        f[0]['subfig_label'] = f[0]['dataset']\n",
    "#        f[0]['subfig_label'] = get_var_name(\"dataset\", f[0]['subfig_label'])\n",
    "#        fig_paths[-1].append(f)\n",
    "#    \n",
    "#    return fig_paths_latex(fig_paths, exp_name, caption=caption)\n",
    "#caption = \"\"\"\n",
    "#The maximum perturbation distance allowed versus accuracy with different $k$ of $k$-NN classifier\n",
    "#using RBA-Approx attack searching 50 regions.\"\n",
    "#\"\"\"\n",
    "#fig_str = compare_nn_plots(exp_name, grid_param, caption=caption, show_plot=False)\n",
    "#write_to_tex(fig_str, exp_name + '_fig.tex')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "ename": "ValueError",
     "evalue": "too many values to unpack (expected 3)",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mValueError\u001b[0m                                Traceback (most recent call last)",
      "\u001b[0;32m<ipython-input-9-468a222865c9>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m     12\u001b[0m \u001b[0mret\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mOrderedDict\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     13\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mi\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mds\u001b[0m \u001b[0;32min\u001b[0m \u001b[0menumerate\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdatasets\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 14\u001b[0;31m     \u001b[0mX\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0my\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0m_\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mauto_var\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mget_var_with_argument\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\"dataset\"\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mds\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m     15\u001b[0m     \u001b[0mtX\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0m_\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0m_\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mauto_var\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mget_var_with_argument\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\"dataset\"\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtree_datasets\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mi\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     16\u001b[0m     ret[auto_var.get_var_shown_name(\"dataset\", ds)] = OrderedDict([\n",
      "\u001b[0;31mValueError\u001b[0m: too many values to unpack (expected 3)"
     ]
    }
   ],
   "source": [
    "from params import datasets, tree_datasets\n",
    "_, _, grid_param, _ = tst_scores()\n",
    "\n",
    "col_names = [\n",
    "    \"\\\\parbox{15mm}{\\\\centering \\# training \\\\\\\\ (1-NN, 3-NN)}\",\n",
    "    \"\\\\parbox{15mm}{\\\\centering \\# training \\\\\\\\ (DT, RF, MLP)}\",\n",
    "    \"\\\\parbox{15mm}{\\\\centering \\# testing \\\\\\\\ (perturbation)}\",\n",
    "    \"\\\\parbox{15mm}{\\\\centering \\# testing \\\\\\\\ (test accuracy)}\",\n",
    "    \"\\# features\",\n",
    "    \"\\# classes\",\n",
    "]\n",
    "ret = OrderedDict()\n",
    "for i, ds in enumerate(datasets):\n",
    "    X, y, _ = auto_var.get_var_with_argument(\"dataset\", ds)\n",
    "    tX, _, _ = auto_var.get_var_with_argument(\"dataset\", tree_datasets[i])\n",
    "    ret[auto_var.get_var_shown_name(\"dataset\", ds)] = OrderedDict([\n",
    "        (col_names[0], X.shape[0]-200),\n",
    "        (col_names[1], tX.shape[0]-200),\n",
    "        (col_names[2], 100),\n",
    "        (col_names[3], 200),\n",
    "        (col_names[4], X.shape[1]),\n",
    "        (col_names[5], 2),\n",
    "    ])\n",
    "df = pd.DataFrame(ret).T\n",
    "df = df[[c for c in col_names]]\n",
    "\n",
    "exp_name = \"dataset-stats\"\n",
    "caption = \"Dataset statistics.\"\n",
    "table_str = table_wrapper(df, table_name=exp_name, caption=caption)\n",
    "table_str = table_str.replace(\"lrrrrrr\", \"lcccccc\")\n",
    "write_to_tex(table_str, exp_name + '_table.tex')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [],
   "source": [
    "def structure_fonts(series):\n",
    "    if '\\\\# train' in series.name[1]:\n",
    "        return series.apply(lambda x: \"$%d$\" % x if not np.isnan(x) else \"-1\")\n",
    "    else:\n",
    "        return series.apply(lambda x: (\"$%.3f$\" % x) if x >= 1 else (\"$%.3f$\" % x).replace(\"0.\", \".\"))\n",
    "    \n",
    "def preprocess(grid_param):\n",
    "    models = [i.replace(\"_\", \"-\") for i in union_param_key(grid_param, 'model')]\n",
    "    attacks = []\n",
    "    for i in union_param_key(grid_param, 'attack'):\n",
    "        for s in ['-avg-pert', '-tst-score', '-aug-len', '-imp']:\n",
    "            attacks.append(i.replace(\"_\", \"-\") + s)\n",
    "    col_names = []\n",
    "    for model in models:\n",
    "        for attack in attacks:\n",
    "            if '-imp' in attack and model == models[0]:\n",
    "                continue\n",
    "            col_names.append((model, attack))\n",
    "    return col_names\n",
    "\n",
    "def process(task_fn):\n",
    "    _, exp_name, grid_param, _ = task_fn()\n",
    "    df = gen_table(exp_name, grid_param, ['model', 'attack'], ['dataset'],\n",
    "                   combine_method=1, objs=['tst_score', 'avg_pert', 'aug_len'],\n",
    "                   additionals=[])\n",
    "\n",
    "    models = [i.replace(\"_\", \"-\") for i in union_param_key(grid_param, 'model')]\n",
    "    attack = grid_param[0]['attack'][0].replace(\"_\", '-')\n",
    "    col_names = preprocess(grid_param)\n",
    "\n",
    "    df = df.apply(lambda a: a.apply(lambda b: float(str(b).replace(\"$\", \"\")) if b else b))\n",
    "    for model in models[1:]:\n",
    "        df[(model, attack + '-imp')] = df[(model, attack + '-avg-pert')] / df[(models[0], attack + '-avg-pert')]\n",
    "    df = df[col_names]\n",
    "    df = df.rename(index=str, columns={\n",
    "        attack + \"-aug-len\": \"\\# train\",\n",
    "        attack + \"-tst-score\": parbox(8, \"test \\\\\\\\ accuracy\"),\n",
    "        attack + \"-avg-pert\": parbox(9, \"ER\"),\n",
    "        attack + \"-imp\": \"\\\\defenderscore\",\n",
    "    })\n",
    "    return df\n",
    "\n",
    "def postprocess(task_fn, df, rename_columns=None, caption=None):\n",
    "    _, exp_name, grid_param, _ = task_fn()\n",
    "    if rename_columns:\n",
    "        df = df.rename(index=str, columns=rename_columns)\n",
    "    df = df.apply(structure_fonts)\n",
    "    table_str = table_wrapper(df, table_name=exp_name, caption=caption)\n",
    "    #table_str = df.to_latex(escape=False)\n",
    "    table_str = table_str.replace(\"{l}\", \"{c}\")\n",
    "    table_str = table_str.replace(\"llllllllllllllll\", \"lccc|cccc|cccc|cccc\")\n",
    "    table_str = table_str.replace(\"$.000$\", \"-\")\n",
    "    table_str = table_str.replace(\"$nan$\", \"-\")\n",
    "    table_str = table_str.replace(\"begin{table}\", \"begin{table*}\")\n",
    "    table_str = table_str.replace(\"end{table}\", \"end{table*}\")\n",
    "    return table_str"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [],
   "source": [
    "caption_template = \"\"\"\n",
    "The number of training data left after adversarial pruning (AP), testing accuracy, empirical robustness,\n",
    "and \\\\defenderscore with different separation parameter of AP for {}.\n",
    "\"\"\"\n",
    "\"\"\"\n",
    "Testing accuracy is a sanity check that we are not giving away all accuracy for robustness.\n",
    "The higher the empirical robustness is means the classifier is more robust to the given attack.\n",
    "When considering the strength of the attack, empirical robustness is lower the better.\n",
    "When considering the strength of the defense, \\\\defenderscore is higher the better.\n",
    "For \\\\defenderscore higher mean that after defense (AP), the classifier become more robust, thus higher the better.\n",
    "\"\"\"\n",
    "\n",
    "fn = nn_k1_robustness\n",
    "df = process(fn)\n",
    "caption = caption_template.format(\"1-NN\")\n",
    "rename_columns = {\n",
    "    \"knn1\": \"1-NN\",\n",
    "    \"advPruning-nn-k1-10\": \"AP (separation parameter $r$=.1)\",\n",
    "    \"advPruning-nn-k1-30\": \"AP (separation parameter $r$=.3)\",\n",
    "    \"advPruning-nn-k1-50\": \"AP (separation parameter $r$=.5)\",\n",
    "}\n",
    "_, exp_name, _, _ = fn()\n",
    "table_str = postprocess(fn, df, rename_columns, caption)\n",
    "write_to_tex(table_str, exp_name + '_table.tex')\n",
    "\n",
    "fn = nn_k3_robustness\n",
    "df = process(fn)\n",
    "caption = caption_template.format(\"3-NN\")\n",
    "rename_columns = {\n",
    "    \"knn3\": \"3-NN\",\n",
    "    \"advPruning-nn-k3-10\": \"AP (separation parameter $r$=.1)\",\n",
    "    \"advPruning-nn-k3-30\": \"AP (separation parameter $r$=.3)\",\n",
    "    \"advPruning-nn-k3-50\": \"AP (separation parameter $r$=.5)\",\n",
    "}\n",
    "_, exp_name, _, _ = fn()\n",
    "table_str = postprocess(fn, df, rename_columns, caption)\n",
    "write_to_tex(table_str, exp_name + '_table.tex')\n",
    "\n",
    "fn = dt_robustness\n",
    "df = process(fn)\n",
    "caption = caption_template.format(\"DT\")\n",
    "rename_columns = {\n",
    "    \"decision-tree-d5\": \"DT\",\n",
    "    \"advPruning-decision-tree-d5-10\": \"AP (separation parameter $r$=.1)\",\n",
    "    \"advPruning-decision-tree-d5-30\": \"AP (separation parameter $r$=.3)\",\n",
    "    \"advPruning-decision-tree-d5-50\": \"AP (separation parameter $r$=.5)\",\n",
    "}\n",
    "_, exp_name, _, _ = fn()\n",
    "table_str = postprocess(fn, df, rename_columns, caption)\n",
    "write_to_tex(table_str, exp_name + '_table.tex')\n",
    "\n",
    "fn = rf_robustness\n",
    "df = process(fn)\n",
    "caption = caption_template.format(\"RF\")\n",
    "rename_columns = {\n",
    "    \"random-forest-100-d5\": \"RF\",\n",
    "    \"advPruning-rf-100-10-d5\": \"AP (separation parameter $r$=.1)\",\n",
    "    \"advPruning-rf-100-30-d5\": \"AP (separation parameter $r$=.3)\",\n",
    "    \"advPruning-rf-100-50-d5\": \"AP (separation parameter $r$=.5)\",\n",
    "}\n",
    "_, exp_name, _, _ = fn()\n",
    "table_str = postprocess(fn, df, rename_columns, caption)\n",
    "write_to_tex(table_str, exp_name + '_table.tex')\n",
    "\n",
    "#fn = lr_ap_robustness\n",
    "#df = process(fn)\n",
    "#caption = caption_template.format(\"LR\")\n",
    "#rename_columns = {\n",
    "#    \"logistic-regression\": \"LR\",\n",
    "#    \"advPruning-logistic-regression-10\": \"AP (separation parameter $r$=.1)\",\n",
    "#    \"advPruning-logistic-regression-30\": \"AP (separation parameter $r$=.3)\",\n",
    "#    \"advPruning-logistic-regression-50\": \"AP (separation parameter $r$=.5)\",\n",
    "#}\n",
    "#_, exp_name, _, _ = fn()\n",
    "#table_str = postprocess(fn, df, rename_columns, caption)\n",
    "#write_to_tex(table_str, exp_name + '_table.tex')\n",
    "#\n",
    "#fn = mlp_ap_robustness\n",
    "#df = process(fn)\n",
    "#caption = caption_template.format(\"MLP\")\n",
    "#rename_columns = {\n",
    "#    \"mlp\": \"MLP\",\n",
    "#    \"advPruning-mlp-10\": \"AP (separation parameter $r$=.1)\",\n",
    "#    \"advPruning-mlp-30\": \"AP (separation parameter $r$=.3)\",\n",
    "#    \"advPruning-mlp-50\": \"AP (separation parameter $r$=.5)\",\n",
    "#}\n",
    "#_, exp_name, _, _ = fn()\n",
    "#table_str = postprocess(fn, df, rename_columns, caption)\n",
    "#write_to_tex(table_str, exp_name + '_table.tex')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 50,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "./results/fashion-mnist35f-pca25-adv-nn-k1-30-RBA-Exact-KNN-k1-rs0-linf.json doesn't exist\n",
      "./results/fashion-mnist06f-pca25-adv-nn-k1-30-RBA-Exact-KNN-k1-rs0-linf.json doesn't exist\n",
      "./results/mnist17f-pca25-adv-nn-k1-30-RBA-Exact-KNN-k1-rs0-linf.json doesn't exist\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead tr th {\n",
       "        text-align: left;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr>\n",
       "      <th></th>\n",
       "      <th colspan=\"4\" halign=\"left\">$1$-NN</th>\n",
       "      <th colspan=\"3\" halign=\"left\">$3$-NN</th>\n",
       "      <th colspan=\"4\" halign=\"left\">DT</th>\n",
       "      <th colspan=\"4\" halign=\"left\">RF</th>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th></th>\n",
       "      <th>AP</th>\n",
       "      <th>AT</th>\n",
       "      <th>nature</th>\n",
       "      <th>Wang's</th>\n",
       "      <th>AP</th>\n",
       "      <th>AT</th>\n",
       "      <th>nature</th>\n",
       "      <th>AP</th>\n",
       "      <th>AT</th>\n",
       "      <th>nature</th>\n",
       "      <th>RS</th>\n",
       "      <th>AP</th>\n",
       "      <th>AT</th>\n",
       "      <th>nature</th>\n",
       "      <th>RS</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>australian</th>\n",
       "      <td>0.36</td>\n",
       "      <td>0.00</td>\n",
       "      <td>0.06</td>\n",
       "      <td>0.36</td>\n",
       "      <td>0.75</td>\n",
       "      <td>0.01</td>\n",
       "      <td>0.48</td>\n",
       "      <td>0.08</td>\n",
       "      <td>0.10</td>\n",
       "      <td>0.02</td>\n",
       "      <td>0.81</td>\n",
       "      <td>0.89</td>\n",
       "      <td>0.93</td>\n",
       "      <td>0.87</td>\n",
       "      <td>1.00</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>cancer</th>\n",
       "      <td>0.06</td>\n",
       "      <td>0.03</td>\n",
       "      <td>0.05</td>\n",
       "      <td>0.06</td>\n",
       "      <td>0.38</td>\n",
       "      <td>0.09</td>\n",
       "      <td>0.17</td>\n",
       "      <td>0.47</td>\n",
       "      <td>0.19</td>\n",
       "      <td>0.23</td>\n",
       "      <td>0.24</td>\n",
       "      <td>0.90</td>\n",
       "      <td>0.42</td>\n",
       "      <td>0.54</td>\n",
       "      <td>0.93</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>covtypebin_2200</th>\n",
       "      <td>0.59</td>\n",
       "      <td>0.00</td>\n",
       "      <td>0.03</td>\n",
       "      <td>0.59</td>\n",
       "      <td>0.77</td>\n",
       "      <td>0.09</td>\n",
       "      <td>0.10</td>\n",
       "      <td>0.30</td>\n",
       "      <td>0.02</td>\n",
       "      <td>0.00</td>\n",
       "      <td>0.11</td>\n",
       "      <td>0.90</td>\n",
       "      <td>0.26</td>\n",
       "      <td>0.23</td>\n",
       "      <td>0.18</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>diabetes</th>\n",
       "      <td>0.00</td>\n",
       "      <td>0.00</td>\n",
       "      <td>0.00</td>\n",
       "      <td>0.00</td>\n",
       "      <td>0.18</td>\n",
       "      <td>0.00</td>\n",
       "      <td>0.00</td>\n",
       "      <td>0.10</td>\n",
       "      <td>0.00</td>\n",
       "      <td>0.00</td>\n",
       "      <td>0.00</td>\n",
       "      <td>0.92</td>\n",
       "      <td>0.13</td>\n",
       "      <td>0.05</td>\n",
       "      <td>0.29</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>fashion_mnist06f_pca25</th>\n",
       "      <td>0.00</td>\n",
       "      <td>NaN</td>\n",
       "      <td>0.00</td>\n",
       "      <td>0.00</td>\n",
       "      <td>0.00</td>\n",
       "      <td>0.00</td>\n",
       "      <td>0.00</td>\n",
       "      <td>0.19</td>\n",
       "      <td>0.00</td>\n",
       "      <td>0.00</td>\n",
       "      <td>0.21</td>\n",
       "      <td>0.61</td>\n",
       "      <td>0.14</td>\n",
       "      <td>0.16</td>\n",
       "      <td>0.16</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>fashion_mnist35f_pca25</th>\n",
       "      <td>0.00</td>\n",
       "      <td>NaN</td>\n",
       "      <td>0.00</td>\n",
       "      <td>0.00</td>\n",
       "      <td>0.00</td>\n",
       "      <td>0.00</td>\n",
       "      <td>0.00</td>\n",
       "      <td>0.27</td>\n",
       "      <td>0.00</td>\n",
       "      <td>0.00</td>\n",
       "      <td>0.61</td>\n",
       "      <td>0.76</td>\n",
       "      <td>0.27</td>\n",
       "      <td>0.36</td>\n",
       "      <td>0.60</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>fourclass</th>\n",
       "      <td>0.37</td>\n",
       "      <td>0.00</td>\n",
       "      <td>0.00</td>\n",
       "      <td>0.37</td>\n",
       "      <td>0.38</td>\n",
       "      <td>0.01</td>\n",
       "      <td>0.02</td>\n",
       "      <td>0.67</td>\n",
       "      <td>0.10</td>\n",
       "      <td>0.07</td>\n",
       "      <td>0.12</td>\n",
       "      <td>0.77</td>\n",
       "      <td>0.04</td>\n",
       "      <td>0.05</td>\n",
       "      <td>0.32</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>halfmoon_2200</th>\n",
       "      <td>0.02</td>\n",
       "      <td>0.00</td>\n",
       "      <td>0.00</td>\n",
       "      <td>0.02</td>\n",
       "      <td>0.06</td>\n",
       "      <td>0.01</td>\n",
       "      <td>0.00</td>\n",
       "      <td>0.27</td>\n",
       "      <td>0.04</td>\n",
       "      <td>0.03</td>\n",
       "      <td>0.08</td>\n",
       "      <td>0.47</td>\n",
       "      <td>0.07</td>\n",
       "      <td>0.06</td>\n",
       "      <td>0.02</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>mnist17f_pca25</th>\n",
       "      <td>0.00</td>\n",
       "      <td>NaN</td>\n",
       "      <td>0.00</td>\n",
       "      <td>0.00</td>\n",
       "      <td>0.00</td>\n",
       "      <td>0.00</td>\n",
       "      <td>0.00</td>\n",
       "      <td>0.06</td>\n",
       "      <td>0.07</td>\n",
       "      <td>0.01</td>\n",
       "      <td>0.85</td>\n",
       "      <td>0.52</td>\n",
       "      <td>0.12</td>\n",
       "      <td>0.27</td>\n",
       "      <td>0.23</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "                       $1$-NN                     $3$-NN                 DT  \\\n",
       "                           AP    AT nature Wang's     AP    AT nature    AP   \n",
       "australian               0.36  0.00   0.06   0.36   0.75  0.01   0.48  0.08   \n",
       "cancer                   0.06  0.03   0.05   0.06   0.38  0.09   0.17  0.47   \n",
       "covtypebin_2200          0.59  0.00   0.03   0.59   0.77  0.09   0.10  0.30   \n",
       "diabetes                 0.00  0.00   0.00   0.00   0.18  0.00   0.00  0.10   \n",
       "fashion_mnist06f_pca25   0.00   NaN   0.00   0.00   0.00  0.00   0.00  0.19   \n",
       "fashion_mnist35f_pca25   0.00   NaN   0.00   0.00   0.00  0.00   0.00  0.27   \n",
       "fourclass                0.37  0.00   0.00   0.37   0.38  0.01   0.02  0.67   \n",
       "halfmoon_2200            0.02  0.00   0.00   0.02   0.06  0.01   0.00  0.27   \n",
       "mnist17f_pca25           0.00   NaN   0.00   0.00   0.00  0.00   0.00  0.06   \n",
       "\n",
       "                                             RF                     \n",
       "                          AT nature    RS    AP    AT nature    RS  \n",
       "australian              0.10   0.02  0.81  0.89  0.93   0.87  1.00  \n",
       "cancer                  0.19   0.23  0.24  0.90  0.42   0.54  0.93  \n",
       "covtypebin_2200         0.02   0.00  0.11  0.90  0.26   0.23  0.18  \n",
       "diabetes                0.00   0.00  0.00  0.92  0.13   0.05  0.29  \n",
       "fashion_mnist06f_pca25  0.00   0.00  0.21  0.61  0.14   0.16  0.16  \n",
       "fashion_mnist35f_pca25  0.00   0.00  0.61  0.76  0.27   0.36  0.60  \n",
       "fourclass               0.10   0.07  0.12  0.77  0.04   0.05  0.32  \n",
       "halfmoon_2200           0.04   0.03  0.08  0.47  0.07   0.06  0.02  \n",
       "mnist17f_pca25          0.07   0.01  0.85  0.52  0.12   0.27  0.23  "
      ]
     },
     "execution_count": 50,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "_, exp_name, grid_params, _ = compare_defense()\n",
    "del grid_params[1]\n",
    "\n",
    "rename_columns = {\n",
    "    \"knn1\": (\"$1$-NN\", \"nature\"),\n",
    "    \"adv_nn_k1_30\": (\"$1$-NN\", \"AT\"),\n",
    "    \"robustv2_nn_k1_30\": (\"$1$-NN\", \"Wang's\"),\n",
    "    \"advPruning_nn_k1_30\": (\"$1$-NN\", \"AP\"),\n",
    "    \n",
    "    \"knn3\": (\"$3$-NN\", \"nature\"),\n",
    "    \"adv_nn_k3_30\": (\"$3$-NN\", \"AT\"),\n",
    "    \"advPruning_nn_k3_30\": (\"$3$-NN\", \"AP\"),\n",
    "    \n",
    "    \"decision_tree_d5\": (\"DT\", \"nature\"),\n",
    "    \"adv_decision_tree_d5_30\": (\"DT\", \"AT\"),\n",
    "    \"robust_decision_tree_d5_30\": (\"DT\", \"RS\"),\n",
    "    \"advPruning_decision_tree_d5_30\": (\"DT\", \"AP\"),\n",
    "    \n",
    "    \"random_forest_100_d5\": (\"RF\", \"nature\"),\n",
    "    \"adv_rf_100_30_d5\": (\"RF\", \"AT\"),\n",
    "    \"robust_rf_100_30_d5\": (\"RF\", \"RS\"),\n",
    "    \"advPruning_rf_100_30_d5\": (\"RF\", \"AP\"),\n",
    "}\n",
    "\n",
    "def fun(elt):\n",
    "    if not elt:\n",
    "        return elt\n",
    "    for e in elt:\n",
    "        if e['eps'] == 0.3:\n",
    "            return e['tst_acc']\n",
    "    return elt\n",
    "\n",
    "table = []\n",
    "for grid_param in grid_params:\n",
    "    df = params_to_dataframe(grid_param, columns=[(\"tst_score\"), 'results'])\n",
    "    df['results'] = df['results'].map(fun)\n",
    "    ret = {}\n",
    "    for dataset, d in df.groupby(\"dataset\"):\n",
    "        ret.setdefault(dataset, {})\n",
    "        for model_name, md in d.groupby(\"model\"):\n",
    "            ret[dataset][rename_columns[model_name]] = md['results'].values[0]\n",
    "    table.append(pd.DataFrame.from_dict(ret, orient='index'))\n",
    "pd.concat(table, axis=1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 49,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead tr th {\n",
       "        text-align: left;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr>\n",
       "      <th></th>\n",
       "      <th colspan=\"4\" halign=\"left\">RF</th>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th></th>\n",
       "      <th>AP</th>\n",
       "      <th>AT</th>\n",
       "      <th>nature</th>\n",
       "      <th>RS</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>australian</th>\n",
       "      <td>0.89</td>\n",
       "      <td>0.93</td>\n",
       "      <td>0.87</td>\n",
       "      <td>1.00</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>cancer</th>\n",
       "      <td>0.90</td>\n",
       "      <td>0.42</td>\n",
       "      <td>0.54</td>\n",
       "      <td>0.93</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>covtypebin_2200</th>\n",
       "      <td>0.90</td>\n",
       "      <td>0.26</td>\n",
       "      <td>0.23</td>\n",
       "      <td>0.18</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>diabetes</th>\n",
       "      <td>0.92</td>\n",
       "      <td>0.13</td>\n",
       "      <td>0.05</td>\n",
       "      <td>0.29</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>fashion_mnist06f_pca25</th>\n",
       "      <td>0.61</td>\n",
       "      <td>0.14</td>\n",
       "      <td>0.16</td>\n",
       "      <td>0.16</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>fashion_mnist35f_pca25</th>\n",
       "      <td>0.76</td>\n",
       "      <td>0.27</td>\n",
       "      <td>0.36</td>\n",
       "      <td>0.60</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>fourclass</th>\n",
       "      <td>0.77</td>\n",
       "      <td>0.04</td>\n",
       "      <td>0.05</td>\n",
       "      <td>0.32</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>halfmoon_2200</th>\n",
       "      <td>0.47</td>\n",
       "      <td>0.07</td>\n",
       "      <td>0.06</td>\n",
       "      <td>0.02</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>mnist17f_pca25</th>\n",
       "      <td>0.52</td>\n",
       "      <td>0.12</td>\n",
       "      <td>0.27</td>\n",
       "      <td>0.23</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "                          RF                   \n",
       "                          AP    AT nature    RS\n",
       "australian              0.89  0.93   0.87  1.00\n",
       "cancer                  0.90  0.42   0.54  0.93\n",
       "covtypebin_2200         0.90  0.26   0.23  0.18\n",
       "diabetes                0.92  0.13   0.05  0.29\n",
       "fashion_mnist06f_pca25  0.61  0.14   0.16  0.16\n",
       "fashion_mnist35f_pca25  0.76  0.27   0.36  0.60\n",
       "fourclass               0.77  0.04   0.05  0.32\n",
       "halfmoon_2200           0.47  0.07   0.06  0.02\n",
       "mnist17f_pca25          0.52  0.12   0.27  0.23"
      ]
     },
     "execution_count": 49,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "table[4]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "def process2(task_fn):\n",
    "    _, exp_name, grid_param, _ = task_fn()\n",
    "    df = gen_table(exp_name, grid_param, ['model', 'attack'], ['dataset'],\n",
    "                   combine_method=1, objs=['tst_score', 'avg_pert'], additionals=[])\n",
    "\n",
    "    models = [i.replace(\"_\", \"-\") for i in union_param_key(grid_param, 'model')]\n",
    "    attack = grid_param[0]['attack'][0].replace(\"_\", '-')\n",
    "    col_names = [a for a in preprocess(grid_param) if '-aug-len' not in a[1]]\n",
    "\n",
    "    df = df.apply(lambda a: a.apply(lambda b: float(str(b).replace(\"$\", \"\")) if b else b))\n",
    "    for model in models[1:]:\n",
    "        df[(model, attack + '-imp')] = df[(model, attack + '-avg-pert')] / df[(models[0], attack + '-avg-pert')]\n",
    "    df = df[col_names]\n",
    "    df = df.rename(index=str, columns={\n",
    "        attack + \"-tst-score\": parbox(8, \"testing \\\\\\\\ accuracy\"),\n",
    "        attack + \"-avg-pert\": parbox(9, \"empirical \\\\\\\\ robustness\"),\n",
    "        attack + \"-imp\": \"\\\\defenderscore\",\n",
    "    })\n",
    "    return df\n",
    "\n",
    "def postprocess2(task_fn, df, rename_columns, caption):\n",
    "    _, exp_name, grid_param, _ = task_fn()\n",
    "    df = df.rename(index=str, columns=rename_columns)\n",
    "    df = df.apply(structure_fonts)\n",
    "    table_str = table_wrapper(df, table_name=exp_name, caption=caption,)\n",
    "    #table_str = df.to_latex(escape=False)\n",
    "    table_str = table_str.replace(\"{l}\", \"{c}\")\n",
    "    table_str = table_str.replace(\"llllllllllll\", \"lcc|ccc|ccc|ccc\")\n",
    "    table_str = table_str.replace(\"$.000$\", \"-\")\n",
    "    table_str = table_str.replace(\"$nan$\", \"-\")\n",
    "    return table_str"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "caption_template = \"\"\"\n",
    "The testing accuracy, empirical robustness,\n",
    "and \\\\defenderscore with differen attack distance of adversarial training (AT) for {}.\n",
    "\"\"\"\n",
    "\"\"\"\n",
    "Testing accuracy is a sanity check that we are not giving away all accuracy for robustness.\n",
    "The higher the empirical robustness is means the classifier is more robust to the given attack.\n",
    "For \\\\defenderscore higher mean that after defense (AP), the classifier become more robust, thus higher the better.\n",
    "\"\"\"\n",
    "\n",
    "#fn = mlp_at_robustness\n",
    "#df = process2(fn)\n",
    "#caption = caption_template.format(\"MLP\")\n",
    "#rename_columns = {\n",
    "#    \"mlp\": \"MLP\",\n",
    "#    \"adv-mlp-10\": \"AT (attack distance $r$=.1)\",\n",
    "#    \"adv-mlp-30\": \"AT (attack distance $r$=.3)\",\n",
    "#    \"adv-mlp-50\": \"AT (attack distance $r$=.5)\",\n",
    "#}\n",
    "#_, exp_name, _, _ = fn()\n",
    "#table_str = postprocess2(fn, df, rename_columns, caption)\n",
    "#write_to_tex(table_str, exp_name + '_table.tex')\n",
    "#\n",
    "#fn = lr_at_robustness\n",
    "#df = process2(fn)\n",
    "#caption = caption_template.format(\"LR\")\n",
    "#rename_columns = {\n",
    "#    \"logistic-regression\": \"LR\",\n",
    "#    \"adv-logistic-regression-10\": \"AT (attack distance $r$=.1)\",\n",
    "#    \"adv-logistic-regression-30\": \"AT (attack distance $r$=.3)\",\n",
    "#    \"adv-logistic-regression-50\": \"AT (attack distance $r$=.5)\",\n",
    "#}\n",
    "#_, exp_name, _, _ = fn()\n",
    "#table_str = postprocess2(fn, df, rename_columns, caption)\n",
    "#write_to_tex(table_str, exp_name + '_table.tex')\n",
    "#\n",
    "#\n",
    "#fn = params_l2.mlp_at_robustness()\n",
    "#df = process2(fn)\n",
    "#caption = caption_template.format(\"MLP\")\n",
    "#rename_columns = {\n",
    "#    \"mlp\": \"MLP\",\n",
    "#    \"adv-mlp-25\": \"AT (attack distance $r$=.25)\",\n",
    "#    \"adv-mlp-50\": \"AT (attack distance $r$=.50)\",\n",
    "#    \"adv-mlp-75\": \"AT (attack distance $r$=.75)\",\n",
    "#}\n",
    "#_, exp_name, _, _ = fn()\n",
    "#table_str = postprocess2(fn, df, rename_columns, caption)\n",
    "#write_to_tex(table_str, exp_name + '_table.tex')\n",
    "#\n",
    "#fn = params_l2.lr_at_robustness()\n",
    "#df = process2(fn)\n",
    "#caption = caption_template.format(\"LR\")\n",
    "#rename_columns = {\n",
    "#    \"logistic-regression\": \"LR\",\n",
    "#    \"adv-logistic-regression-25\": \"AT (attack distance $r$=.25)\",\n",
    "#    \"adv-logistic-regression-50\": \"AT (attack distance $r$=.50)\",\n",
    "#    \"adv-logistic-regression-75\": \"AT (attack distance $r$=.75)\",\n",
    "#}\n",
    "#_, exp_name, _, _ = fn()\n",
    "#table_str = postprocess2(fn, df, rename_columns, caption)\n",
    "#write_to_tex(table_str, exp_name + '_table.tex')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "./results/mnist17-2200-pca25-adv-nn-k1-75-RBA-Exact-KNN-k1-rs0-l2.json doesn't exist\n",
      "./results/australian-adv-nn-k3-50-RBA-Approx-KNN-k3-50-rs0-l2.json doesn't exist\n",
      "./results/diabetes-adv-nn-k3-25-RBA-Approx-KNN-k3-50-rs0-l2.json doesn't exist\n",
      "./results/diabetes-robustv2-nn-k3-25-RBA-Approx-KNN-k3-50-rs0-l2.json doesn't exist\n",
      "./results/fashion-mnist35-2200-pca25-adv-nn-k3-75-RBA-Approx-KNN-k3-50-rs0-l2.json doesn't exist\n",
      "./results/fashion-mnist06-2200-pca25-adv-nn-k3-75-RBA-Approx-KNN-k3-50-rs0-l2.json doesn't exist\n",
      "./results/fashion-mnist06-2200-pca25-robustv2-nn-k3-75-RBA-Approx-KNN-k3-50-rs0-l2.json doesn't exist\n",
      "./results/mnist17-2200-pca25-robustv2-nn-k3-75-RBA-Approx-KNN-k3-50-rs0-l2.json doesn't exist\n",
      "./results/fourclass-robust-rf-100-25-d5-RBA-Approx-RF-100-rs0-l2.json doesn't exist\n",
      "./results/diabetes-robust-rf-100-25-d5-RBA-Approx-RF-100-rs0-l2.json doesn't exist\n",
      "./results/covtypebin-10200-robust-rf-100-50-d5-RBA-Approx-RF-100-rs0-l2.json doesn't exist\n",
      "./results/diabetes-adv-decision-tree-d5-25-RBA-Exact-DT-rs0-l2.json doesn't exist\n",
      "./results/diabetes-robust-decision-tree-d5-25-RBA-Exact-DT-rs0-l2.json doesn't exist\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead tr th {\n",
       "        text-align: left;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr>\n",
       "      <th></th>\n",
       "      <th colspan=\"3\" halign=\"left\">1-NN</th>\n",
       "      <th colspan=\"2\" halign=\"left\">3-NN</th>\n",
       "      <th colspan=\"3\" halign=\"left\">RF</th>\n",
       "      <th colspan=\"3\" halign=\"left\">DT</th>\n",
       "      <th colspan=\"2\" halign=\"left\">LR</th>\n",
       "      <th colspan=\"2\" halign=\"left\">MLP</th>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th></th>\n",
       "      <th>AT</th>\n",
       "      <th>Wang's</th>\n",
       "      <th>AP</th>\n",
       "      <th>AT</th>\n",
       "      <th>AP</th>\n",
       "      <th>AT</th>\n",
       "      <th>RS</th>\n",
       "      <th>AP</th>\n",
       "      <th>AT</th>\n",
       "      <th>RS</th>\n",
       "      <th>AP</th>\n",
       "      <th>AT</th>\n",
       "      <th>AP</th>\n",
       "      <th>AT</th>\n",
       "      <th>AP</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>australian</th>\n",
       "      <td>1.190647</td>\n",
       "      <td>1.607914</td>\n",
       "      <td>1.607914</td>\n",
       "      <td>-2.252252</td>\n",
       "      <td>1.200450</td>\n",
       "      <td>1.041441</td>\n",
       "      <td>0.963964</td>\n",
       "      <td>1.050450</td>\n",
       "      <td>2.120000</td>\n",
       "      <td>5.840000</td>\n",
       "      <td>3.373333</td>\n",
       "      <td>3.063745</td>\n",
       "      <td>1.382470</td>\n",
       "      <td>9.397436</td>\n",
       "      <td>5.371795</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>cancer</th>\n",
       "      <td>1.067692</td>\n",
       "      <td>1.040000</td>\n",
       "      <td>1.403077</td>\n",
       "      <td>0.979339</td>\n",
       "      <td>1.351240</td>\n",
       "      <td>1.008837</td>\n",
       "      <td>1.023564</td>\n",
       "      <td>1.235641</td>\n",
       "      <td>0.912052</td>\n",
       "      <td>1.130293</td>\n",
       "      <td>1.100977</td>\n",
       "      <td>1.448517</td>\n",
       "      <td>1.024433</td>\n",
       "      <td>1.847495</td>\n",
       "      <td>1.418301</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>covtype</th>\n",
       "      <td>1.280899</td>\n",
       "      <td>2.808989</td>\n",
       "      <td>2.808989</td>\n",
       "      <td>1.056391</td>\n",
       "      <td>2.218045</td>\n",
       "      <td>1.173010</td>\n",
       "      <td>-3.460208</td>\n",
       "      <td>2.117647</td>\n",
       "      <td>1.000000</td>\n",
       "      <td>5.028169</td>\n",
       "      <td>4.802817</td>\n",
       "      <td>4.207547</td>\n",
       "      <td>2.141509</td>\n",
       "      <td>3.728571</td>\n",
       "      <td>3.871429</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>diabetes</th>\n",
       "      <td>1.000000</td>\n",
       "      <td>1.855263</td>\n",
       "      <td>1.855263</td>\n",
       "      <td>-6.451613</td>\n",
       "      <td>1.496774</td>\n",
       "      <td>1.077181</td>\n",
       "      <td>-3.355705</td>\n",
       "      <td>1.197987</td>\n",
       "      <td>-10.204082</td>\n",
       "      <td>-10.204082</td>\n",
       "      <td>1.336735</td>\n",
       "      <td>1.820388</td>\n",
       "      <td>2.873786</td>\n",
       "      <td>3.846774</td>\n",
       "      <td>1.548387</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>f-mnist06</th>\n",
       "      <td>1.067164</td>\n",
       "      <td>2.305970</td>\n",
       "      <td>2.305970</td>\n",
       "      <td>-4.149378</td>\n",
       "      <td>1.597510</td>\n",
       "      <td>0.804813</td>\n",
       "      <td>0.831551</td>\n",
       "      <td>1.566845</td>\n",
       "      <td>0.840909</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>1.795455</td>\n",
       "      <td>1.184783</td>\n",
       "      <td>1.713768</td>\n",
       "      <td>2.077922</td>\n",
       "      <td>3.168831</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>f-mnist35</th>\n",
       "      <td>0.776163</td>\n",
       "      <td>1.078488</td>\n",
       "      <td>1.078488</td>\n",
       "      <td>-2.469136</td>\n",
       "      <td>1.034568</td>\n",
       "      <td>0.634660</td>\n",
       "      <td>1.550351</td>\n",
       "      <td>1.063232</td>\n",
       "      <td>1.328000</td>\n",
       "      <td>2.624000</td>\n",
       "      <td>1.856000</td>\n",
       "      <td>1.266055</td>\n",
       "      <td>1.103976</td>\n",
       "      <td>0.862385</td>\n",
       "      <td>1.018349</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>fourclass</th>\n",
       "      <td>0.971963</td>\n",
       "      <td>3.261682</td>\n",
       "      <td>3.261682</td>\n",
       "      <td>0.823009</td>\n",
       "      <td>2.991150</td>\n",
       "      <td>0.866667</td>\n",
       "      <td>-6.666667</td>\n",
       "      <td>3.420000</td>\n",
       "      <td>1.437500</td>\n",
       "      <td>1.194444</td>\n",
       "      <td>2.909722</td>\n",
       "      <td>1.348921</td>\n",
       "      <td>1.366906</td>\n",
       "      <td>2.142857</td>\n",
       "      <td>2.590909</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>halfmoon</th>\n",
       "      <td>1.075758</td>\n",
       "      <td>3.106061</td>\n",
       "      <td>3.106061</td>\n",
       "      <td>0.916667</td>\n",
       "      <td>2.120370</td>\n",
       "      <td>0.950920</td>\n",
       "      <td>1.226994</td>\n",
       "      <td>1.834356</td>\n",
       "      <td>1.033708</td>\n",
       "      <td>1.022472</td>\n",
       "      <td>1.719101</td>\n",
       "      <td>0.823256</td>\n",
       "      <td>1.088372</td>\n",
       "      <td>1.447761</td>\n",
       "      <td>1.701493</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>mnist17</th>\n",
       "      <td>-3.225806</td>\n",
       "      <td>1.158065</td>\n",
       "      <td>1.158065</td>\n",
       "      <td>0.729947</td>\n",
       "      <td>1.082888</td>\n",
       "      <td>0.644295</td>\n",
       "      <td>0.288591</td>\n",
       "      <td>1.020134</td>\n",
       "      <td>1.204545</td>\n",
       "      <td>2.469697</td>\n",
       "      <td>1.356061</td>\n",
       "      <td>1.166667</td>\n",
       "      <td>1.056667</td>\n",
       "      <td>1.056022</td>\n",
       "      <td>0.913165</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "                1-NN                          3-NN                  RF  \\\n",
       "                  AT    Wang's        AP        AT        AP        AT   \n",
       "australian  1.190647  1.607914  1.607914 -2.252252  1.200450  1.041441   \n",
       "cancer      1.067692  1.040000  1.403077  0.979339  1.351240  1.008837   \n",
       "covtype     1.280899  2.808989  2.808989  1.056391  2.218045  1.173010   \n",
       "diabetes    1.000000  1.855263  1.855263 -6.451613  1.496774  1.077181   \n",
       "f-mnist06   1.067164  2.305970  2.305970 -4.149378  1.597510  0.804813   \n",
       "f-mnist35   0.776163  1.078488  1.078488 -2.469136  1.034568  0.634660   \n",
       "fourclass   0.971963  3.261682  3.261682  0.823009  2.991150  0.866667   \n",
       "halfmoon    1.075758  3.106061  3.106061  0.916667  2.120370  0.950920   \n",
       "mnist17    -3.225806  1.158065  1.158065  0.729947  1.082888  0.644295   \n",
       "\n",
       "                                       DT                             LR  \\\n",
       "                  RS        AP         AT         RS        AP        AT   \n",
       "australian  0.963964  1.050450   2.120000   5.840000  3.373333  3.063745   \n",
       "cancer      1.023564  1.235641   0.912052   1.130293  1.100977  1.448517   \n",
       "covtype    -3.460208  2.117647   1.000000   5.028169  4.802817  4.207547   \n",
       "diabetes   -3.355705  1.197987 -10.204082 -10.204082  1.336735  1.820388   \n",
       "f-mnist06   0.831551  1.566845   0.840909   0.000000  1.795455  1.184783   \n",
       "f-mnist35   1.550351  1.063232   1.328000   2.624000  1.856000  1.266055   \n",
       "fourclass  -6.666667  3.420000   1.437500   1.194444  2.909722  1.348921   \n",
       "halfmoon    1.226994  1.834356   1.033708   1.022472  1.719101  0.823256   \n",
       "mnist17     0.288591  1.020134   1.204545   2.469697  1.356061  1.166667   \n",
       "\n",
       "                           MLP            \n",
       "                  AP        AT        AP  \n",
       "australian  1.382470  9.397436  5.371795  \n",
       "cancer      1.024433  1.847495  1.418301  \n",
       "covtype     2.141509  3.728571  3.871429  \n",
       "diabetes    2.873786  3.846774  1.548387  \n",
       "f-mnist06   1.713768  2.077922  3.168831  \n",
       "f-mnist35   1.103976  0.862385  1.018349  \n",
       "fourclass   1.366906  2.142857  2.590909  \n",
       "halfmoon    1.088372  1.447761  1.701493  \n",
       "mnist17     1.056667  1.056022  0.913165  "
      ]
     },
     "execution_count": 17,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "def def_process(task_fn):\n",
    "    _, exp_name, grid_param, _ = task_fn()\n",
    "    df = gen_table(exp_name, grid_param, ['model', 'attack'], ['dataset'],\n",
    "                   combine_method=0, objs=['tst_score', 'avg_pert', 'aug_len'],\n",
    "                   additionals=[])\n",
    "\n",
    "    models = [i.replace(\"_\", \"-\") for i in union_param_key(grid_param, 'model')]\n",
    "    attack = grid_param[0]['attack'][0].replace(\"_\", '-')\n",
    "    col_names = preprocess(grid_param)\n",
    "\n",
    "    df = df.apply(lambda a: a.apply(lambda b: float(str(b).replace(\"$\", \"\")) if b else b))\n",
    "    for model in models[1:]:\n",
    "        df[(model, attack + '-imp')] = df[(model, attack + '-avg-pert')] / df[(models[0], attack + '-avg-pert')]\n",
    "    df = df[col_names]\n",
    "    df = df.rename(index=str, columns={\n",
    "        attack + \"-aug-len\": \"\\# train\",\n",
    "        attack + \"-tst-score\": parbox(8, \"test \\\\\\\\ accuracy\"),\n",
    "        attack + \"-avg-pert\": parbox(9, \"ER\"),\n",
    "        attack + \"-imp\": \"\\\\defenderscore\",\n",
    "    })\n",
    "    return df\n",
    "\n",
    "\n",
    "#ds = ['australian', 'cancer', 'covtype', 'diabetes', 'f-mnist06', 'f-mnist35', 'fourclass', 'halfmoon', 'mnist17']\n",
    "#table = OrderedDict()\n",
    "#ds_eps = params_l2.ds_eps\n",
    "#\n",
    "#df = def_process(params_l2.nn1_def())\n",
    "#ds_def = {'AT': np.zeros(len(ds)), \"Wang's\": np.zeros(len(ds)), 'AP': np.zeros(len(ds))}\n",
    "#for idx in df['knn1'].index:\n",
    "#    ds_name = idx[0]\n",
    "#    ds_idx = ds.index(idx[0])\n",
    "#    ds_def['AT'][ds_idx] = df[f'adv-nn-k1-{ds_eps[ds_name]}']['\\defenderscore'][ds_name][0]\n",
    "#    ds_def[\"Wang's\"][ds_idx] = df[f'robustv2-nn-k1-{ds_eps[ds_name]}']['\\defenderscore'][ds_name][0]\n",
    "#    ds_def['AP'][ds_idx] = df[f'advPruning-nn-k1-{ds_eps[ds_name]}']['\\defenderscore'][ds_name][0]\n",
    "#table['1-NN'] = ds_def\n",
    "#\n",
    "#df = def_process(params_l2.nn3_def())\n",
    "#ds_def = {'AT': np.zeros(len(ds)), 'AP': np.zeros(len(ds))}\n",
    "#for idx in df['knn3'].index:\n",
    "#    ds_name = idx[0]\n",
    "#    ds_idx = ds.index(idx[0])\n",
    "#    ds_def['AT'][ds_idx] = df[f'adv-nn-k3-{ds_eps[ds_name]}']['\\defenderscore'][ds_name][0]\n",
    "#    ds_def['AP'][ds_idx] = df[f'advPruning-nn-k3-{ds_eps[ds_name]}']['\\defenderscore'][ds_name][0]\n",
    "#table['3-NN'] = ds_def\n",
    "#\n",
    "#df = def_process(params_l2.rf_def())\n",
    "#ds_def = OrderedDict([('AT', np.zeros(len(ds))), ('RS', np.zeros(len(ds))), ('AP', np.zeros(len(ds)))])\n",
    "#for idx in df['random-forest-100-d5'].index:\n",
    "#    ds_name = idx[0]\n",
    "#    ds_idx = ds.index(idx[0])\n",
    "#    ds_def['AT'][ds_idx] = df[f'adv-rf-100-{ds_eps[ds_name]}-d5']['\\defenderscore'][ds_name][0]\n",
    "#    ds_def['RS'][ds_idx] = df[f'robust-rf-100-{ds_eps[ds_name]}-d5']['\\defenderscore'][ds_name][0]\n",
    "#    ds_def['AP'][ds_idx] = df[f'advPruning-rf-100-{ds_eps[ds_name]}-d5']['\\defenderscore'][ds_name][0]\n",
    "#table['RF'] = ds_def\n",
    "#\n",
    "#df = def_process(params_l2.dt_def())\n",
    "#ds_def = OrderedDict([('AT', np.zeros(len(ds))), ('RS', np.zeros(len(ds))), ('AP', np.zeros(len(ds)))])\n",
    "#for idx in df['decision-tree-d5'].index:\n",
    "#    ds_name = idx[0]\n",
    "#    ds_idx = ds.index(idx[0])\n",
    "#    ds_def['AT'][ds_idx] = df[f'adv-decision-tree-d5-{ds_eps[ds_name]}']['\\defenderscore'][ds_name][0]\n",
    "#    ds_def['RS'][ds_idx] = df[f'robust-decision-tree-d5-{ds_eps[ds_name]}']['\\defenderscore'][ds_name][0]\n",
    "#    ds_def['AP'][ds_idx] = df[f'advPruning-decision-tree-d5-{ds_eps[ds_name]}']['\\defenderscore'][ds_name][0]\n",
    "#table['DT'] = ds_def\n",
    "#\n",
    "#df = def_process(params_l2.lr_def())\n",
    "#ds_def = OrderedDict([('AT', np.zeros(len(ds))), ('AP', np.zeros(len(ds)))])\n",
    "#for idx in df['logistic-regression'].index:\n",
    "#    ds_name = idx[0]\n",
    "#    ds_idx = ds.index(idx[0])\n",
    "#    ds_def['AT'][ds_idx] = df[f'adv-logistic-regression-{ds_eps[ds_name]}']['\\defenderscore'][ds_name][0]\n",
    "#    ds_def['AP'][ds_idx] = df[f'advPruning-logistic-regression-{ds_eps[ds_name]}']['\\defenderscore'][ds_name][0]\n",
    "#table['LR'] = ds_def\n",
    "#\n",
    "#df = def_process(params_l2.mlp_def())\n",
    "#ds_def = OrderedDict([('AT', np.zeros(len(ds))), ('AP', np.zeros(len(ds)))])\n",
    "#for idx in df['mlp'].index:\n",
    "#    ds_name = idx[0]\n",
    "#    ds_idx = ds.index(idx[0])\n",
    "#    ds_def['AT'][ds_idx] = df[f'adv-mlp-{ds_eps[ds_name]}']['\\defenderscore'][ds_name][0]\n",
    "#    ds_def['AP'][ds_idx] = df[f'advPruning-mlp-{ds_eps[ds_name]}']['\\defenderscore'][ds_name][0]\n",
    "#table['MLP'] = ds_def\n",
    "#\n",
    "#df = pd.DataFrame(flatten(table), index=ds)\n",
    "#df"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [],
   "source": [
    "def to_str(s):\n",
    "    max_value = s.max()\n",
    "    ret = s.apply(lambda x: '$%.2f$' % x if x != max_value else '$\\\\mathbf{%.2f}$' % x)\n",
    "    return ret\n",
    "df = df.fillna(-1)\n",
    "for model_name in ['1-NN', '3-NN', 'DT', 'LR', 'MLP', 'RF']:\n",
    "    df[model_name] = df[model_name].apply(to_str, axis=1)\n",
    "table_str = df.to_latex(escape=False)\n",
    "table_str = table_str.replace(\"{l}\", \"{c}\")\n",
    "table_str = table_str.replace(\"llllllllllllllll\", \"lccc|cc|ccc|ccc|cc|cc\")\n",
    "table_str = table_str.replace(\"$.000$\", \"-\")\n",
    "table_str = table_str.replace(\"$nan$\", \"-\")\n",
    "\n",
    "write_to_tex(table_str, 'compare_defense_table_l2.tex')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "./results/fullmnist-pca100-xgb-xgb-models-prune-fullmnist-pca100-20-rf-linf.0200.model-blackbox-rs0-linf.json doesn't exist\n",
      "./results/fullmnist-pca100-xgb-xgb-models-fullmnist-pca100.0200.model-blackbox-rs0-linf.json doesn't exist\n",
      "./results/fullfashion-pca100-xgb-xgb-models-prune-fullfashion-pca100-20-rf-linf.0200.model-blackbox-rs0-linf.json doesn't exist\n",
      "./results/fullfashion-pca100-xgb-xgb-models-fullfashion-pca100.0200.model-blackbox-rs0-linf.json doesn't exist\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead tr th {\n",
       "        text-align: left;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr>\n",
       "      <th>model</th>\n",
       "      <th colspan=\"5\" halign=\"left\">RF</th>\n",
       "      <th colspan=\"5\" halign=\"left\">$1$-NN</th>\n",
       "      <th colspan=\"5\" halign=\"left\">$3$-NN</th>\n",
       "      <th colspan=\"5\" halign=\"left\">GBM</th>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>method</th>\n",
       "      <th colspan=\"2\" halign=\"left\">natural</th>\n",
       "      <th colspan=\"3\" halign=\"left\">AP</th>\n",
       "      <th colspan=\"2\" halign=\"left\">natural</th>\n",
       "      <th colspan=\"3\" halign=\"left\">AP</th>\n",
       "      <th colspan=\"2\" halign=\"left\">natural</th>\n",
       "      <th colspan=\"3\" halign=\"left\">AP</th>\n",
       "      <th colspan=\"2\" halign=\"left\">natural</th>\n",
       "      <th colspan=\"3\" halign=\"left\">AP</th>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th></th>\n",
       "      <th>acc.</th>\n",
       "      <th>ER</th>\n",
       "      <th>acc.</th>\n",
       "      <th>ER</th>\n",
       "      <th>\\defscore</th>\n",
       "      <th>acc.</th>\n",
       "      <th>ER</th>\n",
       "      <th>acc.</th>\n",
       "      <th>ER</th>\n",
       "      <th>\\defscore</th>\n",
       "      <th>acc.</th>\n",
       "      <th>ER</th>\n",
       "      <th>acc.</th>\n",
       "      <th>ER</th>\n",
       "      <th>\\defscore</th>\n",
       "      <th>acc.</th>\n",
       "      <th>ER</th>\n",
       "      <th>acc.</th>\n",
       "      <th>ER</th>\n",
       "      <th>\\defscore</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>fullfashion</th>\n",
       "      <td>0.825</td>\n",
       "      <td>0.182</td>\n",
       "      <td>0.790</td>\n",
       "      <td>0.171</td>\n",
       "      <td>0.936</td>\n",
       "      <td>0.870</td>\n",
       "      <td>0.196</td>\n",
       "      <td>0.815</td>\n",
       "      <td>0.225</td>\n",
       "      <td>1.148</td>\n",
       "      <td>0.865</td>\n",
       "      <td>0.217</td>\n",
       "      <td>0.835</td>\n",
       "      <td>0.242</td>\n",
       "      <td>1.119</td>\n",
       "      <td>0.885</td>\n",
       "      <td>0.161</td>\n",
       "      <td>0.860</td>\n",
       "      <td>0.164</td>\n",
       "      <td>1.013</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>fullmnist</th>\n",
       "      <td>0.950</td>\n",
       "      <td>0.141</td>\n",
       "      <td>0.935</td>\n",
       "      <td>0.152</td>\n",
       "      <td>1.080</td>\n",
       "      <td>0.935</td>\n",
       "      <td>0.221</td>\n",
       "      <td>0.940</td>\n",
       "      <td>0.222</td>\n",
       "      <td>1.007</td>\n",
       "      <td>0.950</td>\n",
       "      <td>0.233</td>\n",
       "      <td>0.940</td>\n",
       "      <td>0.269</td>\n",
       "      <td>1.155</td>\n",
       "      <td>0.980</td>\n",
       "      <td>0.185</td>\n",
       "      <td>0.965</td>\n",
       "      <td>0.182</td>\n",
       "      <td>0.981</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "model            RF                                 $1$-NN                \\\n",
       "method      natural            AP                  natural            AP   \n",
       "               acc.     ER   acc.     ER \\defscore    acc.     ER   acc.   \n",
       "fullfashion   0.825  0.182  0.790  0.171     0.936   0.870  0.196  0.815   \n",
       "fullmnist     0.950  0.141  0.935  0.152     1.080   0.935  0.221  0.940   \n",
       "\n",
       "model                         $3$-NN                                    GBM  \\\n",
       "method                       natural            AP                  natural   \n",
       "                ER \\defscore    acc.     ER   acc.     ER \\defscore    acc.   \n",
       "fullfashion  0.225     1.148   0.865  0.217  0.835  0.242     1.119   0.885   \n",
       "fullmnist    0.222     1.007   0.950  0.233  0.940  0.269     1.155   0.980   \n",
       "\n",
       "model                                       \n",
       "method                 AP                   \n",
       "                ER   acc.     ER \\defscore  \n",
       "fullfashion  0.161  0.860  0.164     1.013  \n",
       "fullmnist    0.185  0.965  0.182     0.981  "
      ]
     },
     "execution_count": 23,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "_, exp_name, grid_params, _ = params.fullds()()\n",
    "df = params_to_dataframe(grid_params, columns=[(\"tst_score\"), (\"avg_pert\"), \"aug_len\"])\n",
    "table = {}\n",
    "mnist_model_names = ['random_forest_500_d10', 'approxAP_rf_500_20_d10',\n",
    "                     'knn1', 'approxAP_nn_k1_20',\n",
    "                     'knn3', 'approxAP_nn_k3_20',\n",
    "                     'xgb_xgb_models/fullmnist_pca100.unrob.0200.model',\n",
    "                     'xgb_xgb_models/prune_fullmnist_pca100_20_linf.unrob.0200.model',\n",
    "                    ]\n",
    "fashion_model_names = ['random_forest_500_d10', 'approxAP_rf_500_20_d10',\n",
    "                       'knn1', 'approxAP_nn_k1_20',\n",
    "                       'knn3', 'approxAP_nn_k3_20',\n",
    "                       'xgb_xgb_models/fullfashion_pca100.unrob.0200.model',\n",
    "                       'xgb_xgb_models/prune_fullfashion_pca100_20_linf.unrob.0200.model',\n",
    "                      ]\n",
    "shown_names = ['RF', 'AP RF', '$1$-NN', 'AP $1$-NN', '$3$-NN', 'AP $3$-NN']\n",
    "for ds_name_t, d in df.groupby(\"dataset\"):\n",
    "    ds_name = ds_name_t.split(\"_\")[0]\n",
    "    if 'mnist' in ds_name:\n",
    "        model_names = mnist_model_names\n",
    "    else:\n",
    "        model_names = fashion_model_names\n",
    " \n",
    "    table[ds_name] = []\n",
    "    for i, model_name in enumerate(model_names):\n",
    "        model_df = d[d['model'] == model_name]\n",
    "        if len(model_df['tst_score'].values) == 0:\n",
    "            table[ds_name].append(\"%.3f\" % -1)\n",
    "            table[ds_name].append(\"%.3f\" % -1)\n",
    "            if i % 2:\n",
    "                table[ds_name].append(\"%.3f\" % -1)\n",
    "        else:\n",
    "            table[ds_name].append(\"%.3f\" % model_df['tst_score'].values[0])\n",
    "            table[ds_name].append(\"%.3f\" % model_df['avg_pert'].values[0])\n",
    "            if i % 2:\n",
    "                try:\n",
    "                    table[ds_name].append(\n",
    "                        \"%.3f\" % (model_df['avg_pert'].values[0] / d[d['model'] == model_names[i-1]]['avg_pert'].values[0]))\n",
    "                except:\n",
    "                    table[ds_name].append(\"%.3f\" % -1)\n",
    "    #for i, model_name in enumerate(model_names):\n",
    "    #    model_df = d[d['model'] == model_name]\n",
    "    #    table[ds_name].append(\"%.3f\" % model_df['tst_score'].values[0])\n",
    "    #    table[ds_name].append(\"%.3f\" % model_df['avg_pert'].values[0])\n",
    "    #    if i % 2:\n",
    "    #        table[ds_name].append(\n",
    "    #            \"%.3f\" % (model_df['avg_pert'].values[0] / d[d['model'] == model_names[i-1]]['avg_pert'].values[0]))\n",
    "table = pd.DataFrame.from_dict(table, orient='index', columns=[\n",
    "    ('RF', 'natural', 'acc.'), ('RF', 'natural', 'ER'),\n",
    "    ('RF', 'AP', 'acc.'), ('RF', 'AP', 'ER'), ('RF', 'AP', '\\\\defscore'),\n",
    "    ('$1$-NN', 'natural', 'acc.'), ('$1$-NN', 'natural', 'ER'),\n",
    "    ('$1$-NN', 'AP', 'acc.'), ('$1$-NN', 'AP', 'ER'), ('$1$-NN', 'AP', '\\\\defscore'),\n",
    "    ('$3$-NN', 'natural', 'acc.'), ('$3$-NN', 'natural', 'ER'),\n",
    "    ('$3$-NN', 'AP', 'acc.'), ('$3$-NN', 'AP', 'ER'), ('$3$-NN', 'AP', '\\\\defscore'),\n",
    "    ('GBM', 'natural', 'acc.'), ('GBM', 'natural', 'ER'),\n",
    "    ('GBM', 'AP', 'acc.'), ('GBM', 'AP', 'ER'), ('GBM', 'AP', '\\\\defscore'),\n",
    "])\n",
    "table.columns = pd.MultiIndex.from_tuples(table.columns, names=['model', 'method', ''])\n",
    "table"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [],
   "source": [
    "%%bash\n",
    "bash ./sync_report.sh"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "ename": "AssertionError",
     "evalue": "",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mAssertionError\u001b[0m                            Traceback (most recent call last)",
      "\u001b[0;32m<ipython-input-1-788ee363e0a2>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0;32massert\u001b[0m \u001b[0;36m1\u001b[0m\u001b[0;34m==\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m",
      "\u001b[0;31mAssertionError\u001b[0m: "
     ]
    }
   ],
   "source": [
    "assert 1==0"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "ename": "ModuleNotFoundError",
     "evalue": "No module named 'pvn'",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mModuleNotFoundError\u001b[0m                       Traceback (most recent call last)",
      "\u001b[0;32m<ipython-input-8-7bbcb64bc98d>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m      3\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0mmatplotlib\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcolors\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mListedColormap\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m      4\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0msklearn\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mneighbors\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdatasets\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 5\u001b[0;31m \u001b[0;32mfrom\u001b[0m \u001b[0mpvn\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdefender\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdefense\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0madversarial_pruning\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m      6\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0msklearn\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtree\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mDecisionTreeClassifier\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m      7\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0msklearn\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mensemble\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mRandomForestClassifier\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;31mModuleNotFoundError\u001b[0m: No module named 'pvn'"
     ]
    }
   ],
   "source": [
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "from matplotlib.colors import ListedColormap\n",
    "from sklearn import neighbors, datasets\n",
    "from sklearn.tree import DecisionTreeClassifier\n",
    "from sklearn.ensemble import RandomForestClassifier\n",
    "from scipy.spatial import Voronoi, voronoi_plot_2d\n",
    "\n",
    "\n",
    "def draw_boundary(clf, X, y, file_name=None):\n",
    "    h = .01  # step size in the mesh\n",
    "    \n",
    "    # Create color maps\n",
    "    cmap_light = ListedColormap(['#FFAAAA', '#AAFFAA', '#AAAAFF'])\n",
    "    cmap_bold = ListedColormap(['#FF0000', '#00FF00', '#0000FF'])\n",
    "    #cmap_light = ListedColormap(['#ffffff', '#43a2ca'])\n",
    "    #cmap_bold = ListedColormap(['#e0f3db', '#a8ddb5'])\n",
    "\n",
    "    # Plot the decision boundary. For that, we will assign a color to each\n",
    "    # point in the mesh [x_min, x_max]x[y_min, y_max].\n",
    "    x_min, x_max = X[:, 0].min() - 0.25, X[:, 0].max() + 0.25\n",
    "    y_min, y_max = X[:, 1].min() - 0.25, X[:, 1].max() + 0.25\n",
    "    xx, yy = np.meshgrid(np.arange(x_min, x_max, h),\n",
    "                         np.arange(y_min, y_max, h))\n",
    "    Z = clf.predict(np.c_[xx.ravel(), yy.ravel()])\n",
    "\n",
    "    # Put the result into a color plot\n",
    "    Z = Z.reshape(xx.shape)\n",
    "    fig = plt.figure(figsize=(8, 8))\n",
    "    plt.pcolormesh(xx, yy, Z, cmap=cmap_light)\n",
    "\n",
    "    # Plot also the training points\n",
    "    plt.scatter(X[:, 0], X[:, 1], c=y, cmap=cmap_bold,\n",
    "                edgecolor='k', s=20)\n",
    "    plt.xlim(xx.min(), xx.max())\n",
    "    plt.ylim(yy.min(), yy.max())\n",
    "    plt.axis('off')\n",
    "    plt.title(\"\")\n",
    "    \n",
    "    if file_name is not None:\n",
    "        plt.savefig(file_name, transparent=True)\n",
    "    plt.show()\n",
    "    return fig"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "ename": "NameError",
     "evalue": "name 'draw_boundary' is not defined",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mNameError\u001b[0m                                 Traceback (most recent call last)",
      "\u001b[0;32m<ipython-input-9-970ec1048f5e>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m      1\u001b[0m \u001b[0mX\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0my\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mdatasets\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmake_moons\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mn_samples\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m1000\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mnoise\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m0.20\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mrandom_state\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m      2\u001b[0m \u001b[0mclf\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mneighbors\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mKNeighborsClassifier\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfit\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mX\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0my\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 3\u001b[0;31m \u001b[0m_\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mdraw_boundary\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mclf\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mX\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0my\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m\"figs/moon_1nn.png\"\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m",
      "\u001b[0;31mNameError\u001b[0m: name 'draw_boundary' is not defined"
     ]
    }
   ],
   "source": [
    "X, y = datasets.make_moons(n_samples=1000, noise=0.20, random_state=0)\n",
    "clf = neighbors.KNeighborsClassifier(1).fit(X, y)\n",
    "_ = draw_boundary(clf, X, y, \"figs/moon_1nn.png\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "2.923543828315131 2.92527918559083 0.9994067720837564\n"
     ]
    }
   ],
   "source": [
    "a = json.load(open(\"./results/fullmnist-approxAP-faisslshknn-3-500-500-blackbox-rs0-l2.json\", \"r\"))\n",
    "b = json.load(open(\"./results/fullmnist-faisslshknn-3-500-blackbox-rs0-l2.json\", \"r\"))\n",
    "print(a['avg_pert']['avg'], b['avg_pert']['avg'], a['avg_pert']['avg'] / b['avg_pert']['avg'])"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
