{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from IPython.core.display import display, HTML\n",
    "display(HTML('<style>.container { width:100% !important; }</style><link href=\"https://fonts.googleapis.com/css2?family=Roboto:wght@300;400;700&display=swap\" rel=\"stylesheet\">'))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import altair as alt\n",
    "import pandas as pd\n",
    "import numpy as np\n",
    "from altair_saver import save\n",
    "import ipywidgets as widgets\n",
    "alt.data_transformers.disable_max_rows()\n",
    "alt.renderers.enable('default')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from functools import partial\n",
    "\n",
    "def pd_reduce(dataframe, output_column, fn):\n",
    "    result = dataframe.copy()\n",
    "    result[output_column] = np.nan\n",
    "    for i, index in enumerate(result.index):\n",
    "        prev = None if i == 0 else result.iloc[i - 1]\n",
    "        curr = result.loc[index]\n",
    "        result.loc[index, output_column] = fn(prev, curr, output_column)\n",
    "    return result\n",
    "\n",
    "def auc_segment(prev, curr, output_column, epsilon=0):\n",
    "    last_samp = 0 if prev is None else prev.samples\n",
    "    last_tot = 0 if prev is None else prev[output_column]\n",
    "    return (curr.samples - last_samp) * max(curr.test_loss - epsilon, 0)\n",
    "\n",
    "def sum_reduction(prev, curr, output_column, input_column='auc_segment'):\n",
    "    last_tot = 0 if prev is None else prev[output_column]\n",
    "    return last_tot + curr[input_column]\n",
    "\n",
    "def auc(dataframe, carry_column='auc_segment', output_column='auc_agg', epsilon=0):\n",
    "    result = pd_reduce(dataframe, carry_column, partial(auc_segment, epsilon=epsilon))\n",
    "    result = pd_reduce(result, output_column, partial(sum_reduction, input_column=carry_column))\n",
    "    return result\n",
    "\n",
    "def auc_per_data(df, epsilons):\n",
    "    for epsilon in epsilons:\n",
    "        colname = f'auc_agg@{epsilon}'.replace('.', '_')\n",
    "        results = []\n",
    "\n",
    "        for label in df.label.unique():\n",
    "            subset = df[(df.label == label)]\n",
    "            results.append(auc(subset, output_column=f'auc_agg@{epsilon}'.replace('.', '_'), epsilon=epsilon))\n",
    "\n",
    "        df = pd.concat(results)\n",
    "        df[f'str_{colname}'] = df[colname].round(2).astype(str)\n",
    "        if epsilon > 0:\n",
    "            df.loc[df['test_loss'] > epsilon, f'str_{colname}'] = \"> \" + df.loc[df['test_loss'] > epsilon, f'str_{colname}']\n",
    "    return df\n",
    "\n",
    "def sc_segment(prev, curr, output_column, epsilon=0):\n",
    "    if prev is not None:\n",
    "        prev_sc = prev[output_column]\n",
    "    else:\n",
    "        prev_sc = 1e20\n",
    "\n",
    "    if curr.test_loss <= epsilon:\n",
    "        curr_sc = curr.samples\n",
    "    else:\n",
    "        curr_sc = 1e20\n",
    "    return min(prev_sc, curr_sc)\n",
    "\n",
    "def sc(dataframe, carry_column='sc_segment', output_column='sc', epsilon=0):\n",
    "    result = pd_reduce(dataframe, output_column, partial(sc_segment, epsilon=epsilon))\n",
    "    return result\n",
    "\n",
    "def sc_per_data(df, epsilons):\n",
    "    for epsilon in epsilons:\n",
    "        colname = f'sc@{epsilon}'.replace('.', '_')\n",
    "        results = []\n",
    "        for label in df.label.unique():\n",
    "            subset = df[(df.label == label)]\n",
    "            results.append(sc(subset, output_column=colname, epsilon=epsilon))\n",
    "\n",
    "\n",
    "        df = pd.concat(results)\n",
    "        df[f'str_{colname}'] = df[colname].astype(int).astype(str)\n",
    "        df.loc[df[colname] > 1e10, f'str_{colname}'] = \"> \" + df.loc[df[colname] > 1e10, 'samples'].astype(str)\n",
    "    # note that this overwrites `df` many times! not having an outer concat is by design\n",
    "    return df"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def loss_data_chart(df, title='', xdomain=[8,60000], ydomain=[0.008,2], xrules=[], yrules=[], \n",
    "                    color_title='Representation', final=False):\n",
    "    if final:\n",
    "        line_width = 5\n",
    "        label_size = 24\n",
    "        title_size = 30\n",
    "    else:\n",
    "        line_width = 5\n",
    "        label_size = 14\n",
    "        title_size = 20\n",
    "        \n",
    "    rules_df = pd.concat([\n",
    "        pd.DataFrame({'x': xrules}),\n",
    "        pd.DataFrame({'y': yrules})\n",
    "    ], sort=False)\n",
    "\n",
    "    colorscheme = 'set1'\n",
    "    stroke_color = '333'\n",
    "    line = alt.Chart(df[df.samples >= 10], title=title).mark_line(size=line_width, opacity=0.4).encode(\n",
    "        x=alt.X('samples', scale=alt.Scale(type='log', domain=xdomain, nice=False),  title='Dataset size'),\n",
    "        y=alt.Y('mean(test_loss)', scale=alt.Scale(type='log', domain=ydomain, nice=False), title='Test loss'),\n",
    "        color=alt.Color('label:N', title=color_title, scale=alt.Scale(scheme=colorscheme,), legend=None),\n",
    "    )\n",
    "\n",
    "    point = alt.Chart(df[df.samples >= 10], title=title).mark_point(size=80, opacity=1).encode(\n",
    "        x=alt.X('samples', scale=alt.Scale(type='log', domain=xdomain, nice=False),  title='Dataset size'),\n",
    "        y=alt.Y('mean(test_loss)', scale=alt.Scale(type='log', domain=ydomain, nice=False), title='Test loss'),\n",
    "        color=alt.Color('label:N', title=color_title, scale=alt.Scale(scheme=colorscheme,)),\n",
    "        shape=alt.Shape('label:N', title=color_title), \n",
    "        tooltip=['samples', 'label']\n",
    "    )\n",
    "    \n",
    "    rule_x = alt.Chart(rules_df).mark_rule(size=3, color='999', strokeDash=[4, 4]).encode(x='x')\n",
    "    rule_y = alt.Chart(rules_df).mark_rule(size=3, color='999', strokeDash=[4, 4]).encode(y='y')\n",
    "\n",
    "    chart = alt.layer(rule_x, rule_y, line, point).resolve_scale(\n",
    "        color='independent',\n",
    "        shape='independent'\n",
    "    )\n",
    "    chart = chart.properties(width=600, height=500, background='white')\n",
    "    chart = chart.configure_legend(labelLimit=0)\n",
    "    chart = chart.configure(\n",
    "        title=alt.TitleConfig(fontSize=title_size, fontWeight='normal'),\n",
    "        axis=alt.AxisConfig(titleFontSize=title_size, labelFontSize=label_size, grid=(not final), \n",
    "                            domainWidth=5, domainColor=stroke_color, \n",
    "                            tickWidth=3, tickSize=9, tickCount=4, tickColor=stroke_color, tickOffset=0),\n",
    "#         axisX=alt.AxisConfig(grid=True),\n",
    "        legend=alt.LegendConfig(titleFontSize=title_size, labelFontSize=label_size, labelLimit=0, titleLimit=0,\n",
    "                                orient='top-right', padding=10, \n",
    "                                titlePadding=10, rowPadding=5,\n",
    "                                fillColor='white', strokeColor='black', cornerRadius=0),\n",
    "        view=alt.ViewConfig(strokeWidth=0, stroke=stroke_color),\n",
    "        font='Roboto',\n",
    "    )    \n",
    "    return chart"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def make_latex(df, ns, stack=False, group_n=True, epsilons=[0.5, 0.1]):\n",
    "    df = auc_per_data(df, epsilons + [0]).reset_index(drop=True)\n",
    "    df = sc_per_data(df, epsilons)\n",
    "    df.reset_index(drop=True, inplace=True)\n",
    "    auc_cols = {f'str_auc_agg@{eps}'.replace('.', '_'):  f'SDL, $\\\\varepsilon$={eps}' for eps in epsilons}\n",
    "    auc_cols['str_auc_agg@0'] = 'MDL'\n",
    "    sc_cols = {f'str_sc@{eps}'.replace('.', '_'):  f'$\\\\varepsilon$SC, $\\\\varepsilon$={eps}' for eps in epsilons}\n",
    "    output_df = df[df.samples.isin(ns)].groupby(['label', 'data', 'samples', *auc_cols.keys(), *sc_cols.keys()]).mean().reset_index()\n",
    "    output_df = output_df[['samples', 'label', 'test_loss', *auc_cols.keys(), *sc_cols.keys()]]\n",
    "    output_df = output_df.sort_values('samples')\n",
    "\n",
    "    output_df = output_df.rename(columns={'label': 'Representation', 'samples': 'n', 'test_loss': 'Val loss', **auc_cols, **sc_cols})\n",
    "    auc_cols.pop('str_auc_agg@0')\n",
    "    output_df = output_df.reindex(['Representation', 'n', 'Val loss', 'MDL', *auc_cols.values(), *sc_cols.values()], axis=1)\n",
    "    output_df['n'] = output_df['n'].astype(int)\n",
    "    if stack:\n",
    "        if not group_n:\n",
    "            output_df['n'] = '$n=' + output_df['n'].astype(str) + '$'\n",
    "        output_df = output_df.set_index(['Representation', 'n'])\n",
    "        if group_n:\n",
    "            output_df = output_df.transpose()\n",
    "            output_df.reindex(['Val loss', 'MDL', *auc_cols.values(), *sc_cols.values()])\n",
    "            display(output_df)\n",
    "            output_df = output_df.stack()\n",
    "            output_df = output_df.swaplevel().sort_values('n', ascending=True)\n",
    "    else:\n",
    "        output_df = output_df.set_index(['n', 'label'])\n",
    "        output_df = output_df.transpose()\n",
    "    out = widgets.Output(layout={'border': '1px solid black'})\n",
    "    latex_str = output_df.to_latex(multicolumn_format='c', float_format=\"{:0.2f}\".format, escape=False, column_format='llrrr')\n",
    "    out.append_stdout(latex_str)\n",
    "    return out"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "df = pd.concat([\n",
    "    *[pd.read_pickle(f'results/realprobe_paper_mnist-repr_raw_dim784_level3-seed{seed}.pkl') \n",
    "      for seed in range(8)],\n",
    "    *[pd.read_pickle(f'results/realprobe_paper_mnist-repr_cifar_supervised_dim784_level3-seed{seed}.pkl') \n",
    "      for seed in range(8)],\n",
    "    *[pd.read_pickle(f'results/realprobe_paper_mnist-repr_mnist_vae_dim8_level3-seed{seed}.pkl') \n",
    "      for seed in range(8)],\n",
    "], sort=False).reset_index(drop=True)\n",
    "\n",
    "if 'name' not in df: df['name'] = ''\n",
    "df['name'].fillna('',  inplace=True)\n",
    "df['name'] = df.name.str.replace('_?seeds?[0-9]*', '')\n",
    "df['test_error'] = 1 - df.test_accuracy\n",
    "df['zero'] = 0\n",
    "df['label'] = df.data + ' ' + df.repr + '-' + df.repr_dim.astype(str) + ' ' + df.name\n",
    "df.loc[df.repr == 'cifar_supervised', 'label'] = \"CIFAR\"\n",
    "df.loc[df.repr == 'raw', 'label'] = \"Pixels\"\n",
    "df.loc[df.repr == 'mnist_vae', 'label'] = \"VAE\"\n",
    "\n",
    "ns = [60, 20398]\n",
    "epsilons = [ 0.1, 0.02]\n",
    "chart = loss_data_chart(df, title=\"\", xrules=ns, yrules=epsilons, final=True, ydomain=[0.005, 1])\n",
    "display(chart)\n",
    "save(chart, 'mnist_reprs.pdf')\n",
    "df = df.groupby(['label', 'samples', 'data']).mean().reset_index()\n",
    "\n",
    "make_latex(df, ns=ns, epsilons=epsilons, stack=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "df = pd.concat([\n",
    "    *[pd.read_pickle(f'results/realprobe_mnist-repr_raw_dim784-ntrain_50000-seed{seed}.pkl') \n",
    "      for seed in [0, 2, 4, 6]],\n",
    "    *[pd.read_pickle(f'results/realprobe_mnist_noisygt-repr_raw_dim784_level3-seed{seed}.pkl') \n",
    "      for seed in [0, 2, 4, 6]],\n",
    "], sort=False).reset_index(drop=True)\n",
    "\n",
    "if 'name' not in df: df['name'] = ''\n",
    "df['name'].fillna('',  inplace=True)\n",
    "df['name'] = df.name.str.replace('_?seeds?[0-9]*', '')\n",
    "df['test_error'] = 1 - df.test_accuracy\n",
    "df['zero'] = 0\n",
    "df['label'] = df.data + ' ' + df.repr + '-' + df.repr_dim.astype(str) + ' ' + df.name\n",
    "df.loc[df.data == 'mnist', 'label'] = \"Raw pixels\"\n",
    "df.loc[df.data == 'mnist_noisygt', 'label'] = \"Noisy labels\"\n",
    "\n",
    "final = True\n",
    "ydomain = [0.01, 1]\n",
    "title=''\n",
    "xdomain=[8,60000]\n",
    "xrules=[]\n",
    "yrules=[]\n",
    "color_title='Representation'\n",
    "\n",
    "if final:\n",
    "    line_width = 8\n",
    "    label_size = 30\n",
    "    title_size = 54\n",
    "else:\n",
    "    line_width = 5\n",
    "    label_size = 14\n",
    "    title_size = 20\n",
    "\n",
    "rules_df = pd.concat([\n",
    "    pd.DataFrame({'x': xrules}),\n",
    "    pd.DataFrame({'y': yrules})\n",
    "], sort=False)\n",
    "\n",
    "colorscheme = 'set1'\n",
    "stroke_color = '333'\n",
    "line = alt.Chart(df[df.samples >= 10], title=title).mark_line(size=line_width, opacity=0.5).encode(\n",
    "    x=alt.X('samples', scale=alt.Scale(type='log', domain=xdomain, nice=False),  title='Dataset size'),\n",
    "    y=alt.Y('mean(test_loss)', scale=alt.Scale(type='log', domain=ydomain, nice=False), title=''),\n",
    "    color=alt.Color('label:N', title=color_title, scale=alt.Scale(scheme=colorscheme,), legend=None),\n",
    ")\n",
    "\n",
    "point = alt.Chart(df[df.samples >= 10], title=title).mark_point(size=120, opacity=1).encode(\n",
    "    x=alt.X('samples', scale=alt.Scale(type='log', domain=xdomain, nice=False),  title='Dataset size'),\n",
    "    y=alt.Y('mean(test_loss)', scale=alt.Scale(type='log', domain=ydomain, nice=False), title=''),\n",
    "    color=alt.Color('label:N', title=color_title, scale=alt.Scale(scheme=colorscheme,)),\n",
    "    shape=alt.Shape('label:N', title=color_title), \n",
    ")\n",
    "\n",
    "rule_x = alt.Chart(rules_df).mark_rule(size=3, color='999', strokeDash=[4, 4]).encode(x='x')\n",
    "rule_y = alt.Chart(rules_df).mark_rule(size=3, color='999', strokeDash=[4, 4]).encode(y='y')\n",
    "\n",
    "chart = alt.layer(rule_x, rule_y, line, point).resolve_scale(\n",
    "    color='independent',\n",
    "    shape='independent'\n",
    ")\n",
    "chart = chart.properties(width=600, height=500, background='white')\n",
    "chart = chart.configure_legend(labelLimit=0)\n",
    "chart = chart.configure(\n",
    "    title=alt.TitleConfig(fontSize=title_size, fontWeight='normal'),\n",
    "    axis=alt.AxisConfig(titleFontSize=title_size, labelFontSize=label_size, grid=(not final), \n",
    "                        domainWidth=5, domainColor=stroke_color, \n",
    "                        tickWidth=3, tickSize=9, tickCount=4, tickColor=stroke_color, tickOffset=0),\n",
    "    axisX=alt.AxisConfig(titlePadding=50),\n",
    "    legend=alt.LegendConfig(titleFontSize=36, labelFontSize=label_size, labelLimit=0, titleLimit=0,\n",
    "                            orient='top-right', padding=10, \n",
    "                            titlePadding=10, rowPadding=5,\n",
    "                            fillColor='white', strokeColor='black', cornerRadius=0),\n",
    "    view=alt.ViewConfig(strokeWidth=0, stroke=stroke_color),\n",
    "    font='Roboto',\n",
    ")    \n",
    "\n",
    "save(chart, \"noisygt_bold.pdf\")\n",
    "chart"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
