{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import csv\n",
    "from datetime import datetime, timedelta\n",
    "from collections import OrderedDict\n",
    "\n",
    "import numpy as np"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import plotly.graph_objects as go\n",
    "\n",
    "import numpy as np"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "def parse_folder(folder):\n",
    "    files = [file for file in os.listdir(folder) if os.path.splitext(file)[1] == '.csv']\n",
    "    expts = dict()\n",
    "    for fname in files:\n",
    "        with open(os.path.join(folder, fname), 'r') as f:\n",
    "            # print(fname)\n",
    "            dialect = csv.Sniffer().sniff(f.read(), delimiters=['\\t', ','])\n",
    "            f.seek(0)\n",
    "            reader = csv.DictReader(f, delimiter=dialect.delimiter)\n",
    "            accs, retrain_accs, timings, val_accs = [], [], [], []\n",
    "            for row in reader:\n",
    "                if row['budget'] != '': # nonempty row\n",
    "                    if float(row['val_acc_reinit']) > 0.5: # discard failed runs\n",
    "                        accs.append(float(row['target_acc_reinit']))\n",
    "                        # retrain_accs.append(float(row['target_acc_rerun']))\n",
    "                        val_accs.append(float(row['val_acc_reinit']))\n",
    "                        try:\n",
    "                            raw_time_string = datetime.strptime(row['test_time'], \"%H:%M:%S.%f\")\n",
    "                        except ValueError:\n",
    "                            raw_time_string = datetime.strptime(row['test_time'], \"%d day %H:%M:%S.%f\")\n",
    "                        timings.append(timedelta(hours=raw_time_string.hour, \n",
    "                                                 minutes=raw_time_string.minute, \n",
    "                                                 seconds=raw_time_string.second))  # cutoff microseconds\n",
    "                    else:\n",
    "                        print(f'Run discarded in file {fname} due to val.acc = {float(row[\"val_acc_reinit\"]):2.2f}')\n",
    "            if len(accs) > 0:\n",
    "                mean_acc = np.mean(accs)\n",
    "                val_acc = np.mean(val_accs)\n",
    "                std_err = np.std(accs) / np.sqrt(len(accs))\n",
    "                time_average = str(timedelta(seconds=np.mean([t.total_seconds() for t in timings])))\n",
    "                # print(f'Samples : {len(accs)}:')\n",
    "                #latex_str = f'{mean_acc:2.2%} ($\\pm {std_err*100:2.2f}$) & {time_average.split(\".\")[0]} \\\\\\\\'\n",
    "                #latex_str = f'{mean_acc:2.2%} ($\\pm {std_err*100:2.2f}$) & {val_acc:2.2%} \\\\\\\\'\n",
    "                latex_str = f'{mean_acc:2.2%} ($\\pm {std_err*100:2.2f}$)  & {val_acc:2.2%} & {time_average.split(\".\")[0]} \\\\\\\\'\n",
    "                # print(latex_str.replace('%', '\\%'))\n",
    "                try:\n",
    "                    name = fname.split('table_ResNet18_single-class_')[1].split('.')[0]\n",
    "                except IndexError:\n",
    "                    name = fname.split('table_ResNet18_single-class')[1].split('.')[0]\n",
    "                expts[name] = (mean_acc, val_acc)\n",
    "    return expts"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Parse the from-scratch result tables: one call for the root folder and one\n",
    "# per sub-folder.\n",
    "scratch_root = 'icml/sorted_data/from-scratch/_plotting'\n",
    "\n",
    "baseline = parse_folder(scratch_root)\n",
    "adversarial_poisoning = parse_folder(f'{scratch_root}/proposed')\n",
    "adversarial_training = parse_folder(f'{scratch_root}/adversarial_training')\n",
    "diff_private = parse_folder(f'{scratch_root}/DPSGD')\n",
    "filters = parse_folder(f'{scratch_root}/Filters')\n",
    "data_augmentations = parse_folder(f'{scratch_root}/data_aug')\n",
    "\n",
    "adversarial_poisoning  # display the parsed entries"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Re-key three of the result dicts in ascending experiment-name order; the\n",
    "# plotting cell attaches its list-valued text labels to points by position.\n",
    "diff_private = {name: diff_private[name] for name in sorted(diff_private)}\n",
    "adversarial_poisoning = {name: adversarial_poisoning[name] for name in sorted(adversarial_poisoning)}\n",
    "adversarial_training = {name: adversarial_training[name] for name in sorted(adversarial_training)}\n",
    "\n",
    "adversarial_poisoning  # display the reordered entries"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Plot Poison Success vs Validation Error"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "# One (results, Scatter options) pair per trace, in drawing order.\n",
    "# NOTE(review): list-valued text labels are matched to the points by position,\n",
    "# so they rely on the order of the entries in each result dict.\n",
    "trace_specs = [\n",
    "    # Baseline\n",
    "    (baseline, dict(\n",
    "        name='Undefended (Baseline)',\n",
    "        mode='lines+markers+text',\n",
    "        marker=dict(size=20, symbol='circle', color='royalblue'),\n",
    "        text='Undefended',\n",
    "        textposition=\"top center\",\n",
    "    )),\n",
    "    # adversarial-poisoning\n",
    "    (adversarial_poisoning, dict(\n",
    "        name='Poison Immunity (proposed)',\n",
    "        mode='lines+markers+text',\n",
    "        marker=dict(size=20, symbol='diamond', color='firebrick'),\n",
    "        text=[\"p=0.25\", \"p=0.5\", \"p=0.625\", \"p=0.75\", \"p=0.875\"],\n",
    "        textposition=[\"bottom left\", \"bottom left\", \"bottom left\", \"bottom center\", \"bottom center\"],\n",
    "    )),\n",
    "    # differential privacy (mode has no '+text', so the labels are not drawn on the plot)\n",
    "    (diff_private, dict(\n",
    "        name='Differentially private SGD',\n",
    "        mode='lines+markers',\n",
    "        marker=dict(size=20, symbol='square', color='rgba(34,139,34,1.0)'),\n",
    "        text=[\"n=0.0001\", \"n=0.0005\", \"n=0.001\", \"n=0.005\", \"n=0.01\"],\n",
    "        textposition=\"top right\",\n",
    "    )),\n",
    "    # adversarial_training\n",
    "    (adversarial_training, dict(\n",
    "        name='Adversarial Training',\n",
    "        mode='lines+markers+text',\n",
    "        marker=dict(size=20, symbol='pentagon', color='darkseagreen'),\n",
    "        text=[\"eps=8\", \"eps=16\"],  #\"eps=2\", \"eps=4\", \"eps=6\",???\n",
    "        textposition=\"bottom center\",\n",
    "    )),\n",
    "    # Filter Defenses\n",
    "    (filters, dict(\n",
    "        name='Filter Defenses',\n",
    "        mode='markers+text',\n",
    "        marker=dict(size=20, symbol='x', color='darkslategray'),\n",
    "        text=[\"Spectral Signatures\", \"deep K-NN\", \"Activation Clustering\"],\n",
    "        textposition=\"middle right\",\n",
    "    )),\n",
    "    # Data Augmentations\n",
    "    (data_augmentations, dict(\n",
    "        name='Data Augmentation',\n",
    "        mode='markers+text',\n",
    "        marker=dict(size=20, symbol='cross', color='olive'),\n",
    "        text=[\"Input Noise\", \"CutMix\", \"Maxup\"],\n",
    "        textposition=\"bottom center\",\n",
    "    )),\n",
    "]\n",
    "\n",
    "fig = go.Figure()\n",
    "for results, options in trace_specs:\n",
    "    # every value is a (mean target accuracy, mean validation accuracy) pair\n",
    "    mean_accs, val_accs = zip(*results.values())\n",
    "    fig.add_trace(go.Scatter(\n",
    "        x=[1 - v for v in val_accs],  # validation error\n",
    "        y=mean_accs,\n",
    "        line=dict(width=5, dash='solid'),\n",
    "        showlegend=True,\n",
    "        **options,\n",
    "    ))\n",
    "\n",
    "fig.update_traces(cliponaxis=False, textfont=dict(color='black'))\n",
    "\n",
    "# Layout: axis titles, fonts, legend box and canvas size in a single update.\n",
    "fig.update_layout(\n",
    "    xaxis_type=\"linear\",\n",
    "    yaxis_type=\"linear\",\n",
    "    xaxis_title=\"<b>Validation Error</b>\",\n",
    "    yaxis_title=\"<b>Avg. Poison Success</b>\",\n",
    "    font=dict(family=\"Computer Modern Bold\", size=16),\n",
    "    legend=dict(\n",
    "        x=1,\n",
    "        xanchor=\"right\",\n",
    "        y=0.99,\n",
    "        yanchor=\"top\",\n",
    "        font=dict(size=18, color=\"black\"),\n",
    "        bgcolor=\"white\",\n",
    "        bordercolor=\"Black\",\n",
    "        borderwidth=1,\n",
    "    ),\n",
    "    uniformtext_minsize=24,\n",
    "    uniformtext_mode='hide',\n",
    "    margin=dict(l=20, r=20, t=20, b=0),\n",
    "    paper_bgcolor='rgba(0,0,0,0)',\n",
    "    plot_bgcolor='rgba(0,0,0,0)',\n",
    "    width=1000,\n",
    "    height=500,\n",
    ")\n",
    "\n",
    "# Both axes share the same line and grid styling and differ only in range.\n",
    "axis_style = dict(\n",
    "    showline=True,\n",
    "    showgrid=True,\n",
    "    gridwidth=0.1,\n",
    "    gridcolor='rgba(1,1,1,0.25)',\n",
    "    linecolor='black',\n",
    "    zerolinecolor='rgba(1,1,1,0.25)',\n",
    ")\n",
    "fig.update_xaxes(range=[0.05, 0.43], **axis_style)\n",
    "fig.update_yaxes(range=[-0.1, 1.1], **axis_style)\n",
    "\n",
    "fig.show()\n",
    "fig.write_image(\"from_Scratch_success_vs_error.pdf\", scale=1)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Directions:\n",
    "['top left', 'top center', 'top right', 'middle left',\n",
    "            'middle center', 'middle right', 'bottom left',\n",
    "            'bottom center', 'bottom right']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "symbols = ['circle', 'square', 'diamond', 'cross', 'x', 'triangle-down', 'pentagon', 'hexagram', 'star', 'diamond',\n",
    "           'hourglass']"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# CSS colors:\n",
    "  aliceblue, antiquewhite, aqua, aquamarine, azure,\n",
    "                beige, bisque, black, blanchedalmond, blue,\n",
    "                blueviolet, brown, burlywood, cadetblue,\n",
    "                chartreuse, chocolate, coral, cornflowerblue,\n",
    "                cornsilk, crimson, cyan, darkblue, darkcyan,\n",
    "                darkgoldenrod, darkgray, darkgrey, darkgreen,\n",
    "                darkkhaki, darkmagenta, darkolivegreen, darkorange,\n",
    "                darkorchid, darkred, darksalmon, darkseagreen,\n",
    "                darkslateblue, darkslategray, darkslategrey,\n",
    "                darkturquoise, darkviolet, deeppink, deepskyblue,\n",
    "                dimgray, dimgrey, dodgerblue, firebrick,\n",
    "                floralwhite, forestgreen, fuchsia, gainsboro,\n",
    "                ghostwhite, gold, goldenrod, gray, grey, green,\n",
    "                greenyellow, honeydew, hotpink, indianred, indigo,\n",
    "                ivory, khaki, lavender, lavenderblush, lawngreen,\n",
    "                lemonchiffon, lightblue, lightcoral, lightcyan,\n",
    "                lightgoldenrodyellow, lightgray, lightgrey,\n",
    "                lightgreen, lightpink, lightsalmon, lightseagreen,\n",
    "                lightskyblue, lightslategray, lightslategrey,\n",
    "                lightsteelblue, lightyellow, lime, limegreen,\n",
    "                linen, magenta, maroon, mediumaquamarine,\n",
    "                mediumblue, mediumorchid, mediumpurple,\n",
    "                mediumseagreen, mediumslateblue, mediumspringgreen,\n",
    "                mediumturquoise, mediumvioletred, midnightblue,\n",
    "                mintcream, mistyrose, moccasin, navajowhite, navy,\n",
    "                oldlace, olive, olivedrab, orange, orangered,\n",
    "                orchid, palegoldenrod, palegreen, paleturquoise,\n",
    "                palevioletred, papayawhip, peachpuff, peru, pink,\n",
    "                plum, powderblue, purple, red, rosybrown,\n",
    "                royalblue, saddlebrown, salmon, sandybrown,\n",
    "                seagreen, seashell, sienna, silver, skyblue,\n",
    "                slateblue, slategray, slategrey, snow, springgreen,\n",
    "                steelblue, tan, teal, thistle, tomato, turquoise,\n",
    "                violet, wheat, white, whitesmoke, yellow,\n",
    "                yellowgreen"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# The same plot for the transfer experiments"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Parse the transfer result tables: one call for the root folder and one per\n",
    "# sub-folder.\n",
    "transfer_root = 'icml/sorted_data/transfer/_plotting'\n",
    "\n",
    "baseline = parse_folder(transfer_root)\n",
    "adversarial_poisoning = parse_folder(f'{transfer_root}/proposed')\n",
    "adversarial_training = parse_folder(f'{transfer_root}/adversarial_training')\n",
    "diff_private = parse_folder(f'{transfer_root}/differential_privacy')\n",
    "filters = parse_folder(f'{transfer_root}/filters')\n",
    "data_augmentations = parse_folder(f'{transfer_root}/data_augmentations')\n",
    "\n",
    "# Re-key by experiment name; adversarial_training uses descending order here.\n",
    "diff_private = {name: diff_private[name] for name in sorted(diff_private)}\n",
    "adversarial_poisoning = {name: adversarial_poisoning[name] for name in sorted(adversarial_poisoning)}\n",
    "adversarial_training = {name: adversarial_training[name]\n",
    "                        for name in sorted(adversarial_training, reverse=True)}\n",
    "\n",
    "adversarial_poisoning  # display the reordered entries"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Display the parsed data-augmentation entries.\n",
    "data_augmentations"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# One (results, Scatter options) pair per trace, in drawing order.\n",
    "# NOTE(review): list-valued text labels are matched to the points by position,\n",
    "# so they rely on the order of the entries in each result dict.\n",
    "trace_specs = [\n",
    "    # Baseline\n",
    "    (baseline, dict(\n",
    "        name='Baseline Attack',\n",
    "        mode='lines+markers+text',\n",
    "        marker=dict(size=20, symbol='circle', color='royalblue'),\n",
    "        text='Baseline',\n",
    "        textposition=\"top center\",\n",
    "    )),\n",
    "    # adversarial-poisoning\n",
    "    (adversarial_poisoning, dict(\n",
    "        name='Adversarial Poisoning',\n",
    "        mode='lines+markers+text',\n",
    "        marker=dict(size=20, symbol='diamond', color='firebrick'),\n",
    "        text=[\"p=0.25\", \"p=0.5\", \"p=0.625\", \"p=0.75\", \"p=0.875\"],\n",
    "        textposition=[\"bottom left\", \"bottom left\", \"bottom left\", \"bottom center\", \"bottom center\"],\n",
    "    )),\n",
    "    # differential privacy (mode has no '+text', so the labels are not drawn on the plot)\n",
    "    (diff_private, dict(\n",
    "        name='Differentially private SGD',\n",
    "        mode='lines+markers',\n",
    "        marker=dict(size=20, symbol='square', color='rgba(34,139,34,1.0)'),\n",
    "        text=[\"n=0.0001\", \"n=0.0005\", \"n=0.001\", \"n=0.005\", \"n=0.01\"],\n",
    "        textposition=\"top right\",\n",
    "    )),\n",
    "    # adversarial_training\n",
    "    (adversarial_training, dict(\n",
    "        name='Adversarial Training',\n",
    "        mode='lines+markers+text',\n",
    "        marker=dict(size=20, symbol='pentagon', color='darkseagreen'),\n",
    "        text=[\"eps=8\", \"eps=16\"],\n",
    "        textposition=\"bottom center\",\n",
    "    )),\n",
    "    # Filter Defenses\n",
    "    (filters, dict(\n",
    "        name='Filter Defenses',\n",
    "        mode='markers+text',\n",
    "        marker=dict(size=20, symbol='x', color='darkslategray'),\n",
    "        text=[\"Madry\", \"deep-KNN\", \"Activation Clustering\"],\n",
    "        textposition=\"middle right\",\n",
    "    )),\n",
    "    # Data Augmentations\n",
    "    (data_augmentations, dict(\n",
    "        name='Data Augmentation',\n",
    "        mode='markers+text',\n",
    "        marker=dict(size=20, symbol='cross', color='olive'),\n",
    "        text=[\"Input Noise\", \"CutMix\", \"Maxup\"],\n",
    "        textposition=\"bottom center\",\n",
    "    )),\n",
    "]\n",
    "\n",
    "fig = go.Figure()\n",
    "for results, options in trace_specs:\n",
    "    # every value is a (mean target accuracy, mean validation accuracy) pair\n",
    "    mean_accs, val_accs = zip(*results.values())\n",
    "    fig.add_trace(go.Scatter(\n",
    "        x=[1 - v for v in val_accs],  # validation error\n",
    "        y=mean_accs,\n",
    "        line=dict(width=5, dash='solid'),\n",
    "        showlegend=True,\n",
    "        **options,\n",
    "    ))\n",
    "\n",
    "fig.update_traces(cliponaxis=False, textfont=dict(color='black'))\n",
    "\n",
    "# Layout: axis titles, fonts, legend box and canvas size in a single update.\n",
    "fig.update_layout(\n",
    "    xaxis_type=\"linear\",\n",
    "    yaxis_type=\"linear\",\n",
    "    xaxis_title=\"Validation Error\",\n",
    "    yaxis_title=\"Avg. Poison Success\",\n",
    "    font=dict(\n",
    "        # family=\"Computer Modern\",\n",
    "        size=14,\n",
    "    ),\n",
    "    legend=dict(\n",
    "        x=1,\n",
    "        xanchor=\"right\",\n",
    "        y=0.99,\n",
    "        yanchor=\"top\",\n",
    "        font=dict(size=18, color=\"black\"),\n",
    "        bgcolor=\"white\",\n",
    "        bordercolor=\"Black\",\n",
    "        borderwidth=1,\n",
    "    ),\n",
    "    uniformtext_minsize=24,\n",
    "    uniformtext_mode='hide',\n",
    "    margin=dict(l=20, r=20, t=20, b=0),\n",
    "    paper_bgcolor='rgba(0,0,0,0)',\n",
    "    plot_bgcolor='rgba(0,0,0,0)',\n",
    "    width=1000,\n",
    "    height=500,\n",
    ")\n",
    "\n",
    "# Both axes share the same line and grid styling and differ only in range.\n",
    "axis_style = dict(\n",
    "    showline=True,\n",
    "    showgrid=True,\n",
    "    gridwidth=0.1,\n",
    "    gridcolor='rgba(1,1,1,0.25)',\n",
    "    linecolor='black',\n",
    "    zerolinecolor='rgba(1,1,1,0.25)',\n",
    ")\n",
    "fig.update_xaxes(range=[0.05, 0.45], **axis_style)\n",
    "fig.update_yaxes(range=[-0.1, 1.1], **axis_style)\n",
    "\n",
    "fig.show()\n",
    "# fig.write_image(\"transfer_success_vs_error.pdf\", scale=1)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
