{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from IPython.display import display, HTML\n",
    "# Stretch the notebook container to the full window width and load the Roboto web font.\n",
    "display(HTML('<style>.container { width:100% !important; }</style><link href=\"https://fonts.googleapis.com/css2?family=Roboto:wght@300;400;700&display=swap\" rel=\"stylesheet\">'))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pickle\n",
    "import json\n",
    "import copy\n",
    "import numpy as np\n",
    "import yaml\n",
    "import pandas as pd\n",
    "import altair as alt\n",
    "from functools import partial\n",
    "import ipywidgets as widgets\n",
    "from altair_saver import save\n",
    "# Use Altair's default renderer and lift Altair's row limit so whole result frames can be passed to charts.\n",
    "alt.renderers.enable('default')\n",
    "alt.data_transformers.disable_max_rows()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def pd_reduce(dataframe, output_column, fn):\n",
    "    \"\"\"Fold `fn` over the rows of `dataframe`, in their current order.\n",
    "\n",
    "    Returns a copy of `dataframe` whose `output_column` is filled row by row with\n",
    "    `fn(prev, curr, output_column)`, where `curr` is the current row and `prev` is\n",
    "    the previous row (already holding its own `output_column` value) or None for\n",
    "    the first row. This makes running aggregates such as cumulative sums possible.\n",
    "\n",
    "    Rows are addressed by position rather than by index label, so frames with\n",
    "    duplicate index labels (e.g. after `pd.concat`) are handled correctly.\n",
    "    \"\"\"\n",
    "    result = dataframe.copy()\n",
    "    result[output_column] = np.nan\n",
    "    out_pos = result.columns.get_loc(output_column)\n",
    "    for i in range(len(result)):\n",
    "        prev = None if i == 0 else result.iloc[i - 1]\n",
    "        curr = result.iloc[i]\n",
    "        result.iat[i, out_pos] = fn(prev, curr, output_column)\n",
    "    return result\n",
    "\n",
    "def auc_segment(prev, curr, output_column, epsilon=0):\n",
    "    \"\"\"Surplus test loss charged to the block of training targets added at `curr`.\n",
    "\n",
    "    The block size is the growth of `train_targets` since `prev` (all of\n",
    "    `curr.train_targets` for the first row); each target in the block is charged\n",
    "    `max(curr.mean_loss_test - epsilon, 0)`, so only loss above the tolerance\n",
    "    `epsilon` counts.\n",
    "\n",
    "    `output_column` is not used; it is accepted so the function fits the\n",
    "    `fn(prev, curr, output_column)` callback signature of `pd_reduce`.\n",
    "    \"\"\"\n",
    "    last_samp = 0 if prev is None else prev.train_targets\n",
    "    return (curr.train_targets - last_samp) * max(curr.mean_loss_test - epsilon, 0)\n",
    "\n",
    "# def auc_segment(prev, curr, output_column, epsilon=0):\n",
    "#     if prev is None: return 0\n",
    "#     return max(prev.loss_online_portion - epsilon * prev.num_targets_online_portion, 0)\n",
    "\n",
    "\n",
    "\n",
    "def sum_reduction(prev, curr, output_column, input_column='auc_segment'):\n",
    "    \"\"\"`pd_reduce` step for a running sum: the previous `output_column` total plus this row's `input_column`.\"\"\"\n",
    "    total_so_far = 0\n",
    "    if prev is not None:\n",
    "        total_so_far = prev[output_column]\n",
    "    return total_so_far + curr[input_column]\n",
    "\n",
    "def auc(dataframe, carry_column='auc_segment', output_column='auc_agg', epsilon=0):\n",
    "    \"\"\"Add per-row SDL segments and their running total to a copy of `dataframe`.\n",
    "\n",
    "    `carry_column` receives each row's `auc_segment` value at tolerance `epsilon`;\n",
    "    `output_column` receives the cumulative sum of those values in row order, so the\n",
    "    rows are expected to be sorted by dataset size already.\n",
    "    \"\"\"\n",
    "    segment_step = partial(auc_segment, epsilon=epsilon)\n",
    "    running_sum_step = partial(sum_reduction, input_column=carry_column)\n",
    "    with_segments = pd_reduce(dataframe, carry_column, segment_step)\n",
    "    return pd_reduce(with_segments, output_column, running_sum_step)\n",
    "\n",
    "def auc_per_data(df, epsilons):\n",
    "    \"\"\"Add SDL (surplus description length) columns for each tolerance in `epsilons`.\n",
    "\n",
    "    For every `epsilon`, each representation's rows are sorted by `n` and passed\n",
    "    through `auc`, producing:\n",
    "      * `auc_agg@<eps>` (dots in `eps` replaced by `_`): cumulative SDL;\n",
    "      * `str_auc_agg@<eps>`: that value rounded to 2 decimals as text, prefixed\n",
    "        with '> ' on rows whose `mean_loss_test` still exceeds `epsilon`.\n",
    "\n",
    "    `df` is rebuilt on every iteration; not having an outer concat is by design,\n",
    "    so the columns of each tolerance accumulate on the same frame.\n",
    "    \"\"\"\n",
    "    for epsilon in epsilons:\n",
    "        colname = f'auc_agg@{epsilon}'.replace('.', '_')\n",
    "        label_col = f'str_{colname}'\n",
    "        df = pd.concat([\n",
    "            auc(df[df.representation == representation].sort_values('n'),\n",
    "                output_column=colname, epsilon=epsilon)\n",
    "            for representation in df.representation.unique()\n",
    "        ])\n",
    "        df[label_col] = df[colname].round(2).astype(str)\n",
    "        above_tolerance = df['mean_loss_test'] > epsilon\n",
    "        df.loc[above_tolerance, label_col] = '> ' + df.loc[above_tolerance, label_col]\n",
    "    return df\n",
    "\n",
    "\n",
    "\n",
    "def sc_segment(prev, curr, output_column, epsilon=0):\n",
    "    \"\"\"`pd_reduce` step for epsilon sample complexity: a running minimum of `n`.\n",
    "\n",
    "    A row whose `mean_loss_test` is at most `epsilon` contributes its `n`; any other\n",
    "    row contributes the sentinel 1e20. The result is the smaller of that value and\n",
    "    the previous row's `output_column` (1e20 for the first row).\n",
    "    \"\"\"\n",
    "    best_so_far = 1e20 if prev is None else prev[output_column]\n",
    "    candidate = curr.n if curr.mean_loss_test <= epsilon else 1e20\n",
    "    return min(best_so_far, candidate)\n",
    "\n",
    "def sc(dataframe, carry_column='sc_segment', output_column='sc', epsilon=0):\n",
    "    \"\"\"Return a copy of `dataframe` with the running epsilon sample complexity in `output_column`.\n",
    "\n",
    "    Rows are scanned in their current order via `sc_segment`. `carry_column` is\n",
    "    accepted but not used.\n",
    "    \"\"\"\n",
    "    return pd_reduce(dataframe, output_column, partial(sc_segment, epsilon=epsilon))\n",
    "\n",
    "def sc_per_data(df, epsilons):\n",
    "    \"\"\"Add epsilon sample complexity (eSC) columns for each tolerance in `epsilons`.\n",
    "\n",
    "    For every `epsilon`, each representation's rows are sorted by `n` and passed\n",
    "    through `sc`, producing:\n",
    "      * `sc@<eps>` (dots in `eps` replaced by `_`): smallest `n` so far whose test\n",
    "        loss is at most `epsilon`, or the sentinel 1e20 if there is none yet;\n",
    "      * `str_sc@<eps>`: that `n` as text, or '> <n of the row>' while the\n",
    "        sentinel is still in place.\n",
    "\n",
    "    Rows are sorted by `n` before `sc` (as in `auc_per_data`) because `sc` takes a\n",
    "    running minimum in row order; unsorted input would not yield the first `n` at\n",
    "    which the tolerance is reached.\n",
    "    `df` is rebuilt on every iteration; not having an outer concat is by design.\n",
    "    \"\"\"\n",
    "    for epsilon in epsilons:\n",
    "        colname = f'sc@{epsilon}'.replace('.', '_')\n",
    "        str_colname = f'str_{colname}'\n",
    "        results = []\n",
    "        for representation in df.representation.unique():\n",
    "            subset = df[df.representation == representation].sort_values('n')\n",
    "            results.append(sc(subset, output_column=colname, epsilon=epsilon))\n",
    "        df = pd.concat(results)\n",
    "        # Values above 1e10 are treated as the 1e20 'not reached' sentinel. Format those\n",
    "        # rows separately so the sentinel is never cast to int (1e20 overflows int64).\n",
    "        not_reached = df[colname] > 1e10\n",
    "        df[str_colname] = ''\n",
    "        df.loc[~not_reached, str_colname] = df.loc[~not_reached, colname].astype(int).astype(str)\n",
    "        df.loc[not_reached, str_colname] = '> ' + df.loc[not_reached, 'n'].astype(str)\n",
    "    return df\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Root directory holding one sub-directory per experiment (relative to the kernel's working directory).\n",
    "dir_with_experiments = '../..'\n",
    "# Size of the label set: each training target in the first online block is charged log2(num_classes).\n",
    "num_classes = 45  \n",
    "    \n",
    "def get_mdls(experiment_name):\n",
    "    \"\"\"Load an experiment's online-coding report and compute its online codelengths.\n",
    "\n",
    "    Reads `online_coding.pkl` from `{dir_with_experiments}/{experiment_name}`. The\n",
    "    first block is charged `n_0 * log2(num_classes)` (uniform code over the first\n",
    "    block's `n_0` training targets); each later block is charged the\n",
    "    `loss_online_portion` recorded in the preceding report entry.\n",
    "\n",
    "    Returns:\n",
    "        online_codelengths: cumulative codelength after each block (numpy array).\n",
    "        online_ns: number of training targets at each block.\n",
    "        online_report: the loaded report; every entry additionally gets a\n",
    "            `train_targets` key holding its training-target count.\n",
    "    \"\"\"\n",
    "    # NOTE: pickle can execute arbitrary code while loading; only point this at\n",
    "    # trusted, locally produced experiment outputs.\n",
    "    with open(f'{dir_with_experiments}/{experiment_name}/online_coding.pkl', 'rb') as f:\n",
    "        online_report = pickle.load(f)\n",
    "\n",
    "    # Reports store the count under either `num_targets_train` or `train_targets`;\n",
    "    # expose it uniformly as `train_targets` (read by `auc_segment`).\n",
    "    train_n_label = 'num_targets_train' if 'num_targets_train' in online_report[0] else 'train_targets'\n",
    "    for rep in online_report:\n",
    "        rep['train_targets'] = rep[train_n_label]\n",
    "\n",
    "    first_block_cost = online_report[0][train_n_label] * np.log2(num_classes)\n",
    "    online_costs = [first_block_cost] + [elem['loss_online_portion'] for elem in online_report[:-1]]\n",
    "    online_codelengths = np.cumsum(online_costs)\n",
    "    online_ns = [rep[train_n_label] for rep in online_report]\n",
    "    return online_codelengths, online_ns, online_report\n",
    "\n",
    "def load_params(experiment_name):\n",
    "    \"\"\"Load an experiment's YAML config.\n",
    "\n",
    "    Looks in `{dir_with_experiments}/{experiment_name}` for `online_l0.yml`, then\n",
    "    `online_l0_control.yml`, then `{experiment_name}.yml`, and returns the first\n",
    "    that exists; FileNotFoundError is raised if none does.\n",
    "\n",
    "    Uses `yaml.safe_load`: `yaml.load` without an explicit Loader raises TypeError\n",
    "    in PyYAML >= 6 (and is unsafe for untrusted input in older versions). Files are\n",
    "    closed after reading.\n",
    "    \"\"\"\n",
    "    exp_dir = f'{dir_with_experiments}/{experiment_name}'\n",
    "    candidates = [\n",
    "        f'{exp_dir}/online_l0.yml',\n",
    "        f'{exp_dir}/online_l0_control.yml',\n",
    "        f'{exp_dir}/{experiment_name}.yml',\n",
    "    ]\n",
    "    for path in candidates[:-1]:\n",
    "        try:\n",
    "            with open(path) as f:\n",
    "                return yaml.safe_load(f)\n",
    "        except FileNotFoundError:\n",
    "            continue\n",
    "    # Let FileNotFoundError propagate if the last fallback is missing too.\n",
    "    with open(candidates[-1]) as f:\n",
    "        return yaml.safe_load(f)\n",
    "\n",
    "\n",
    "def load_exps(experiments):\n",
    "    \"\"\"Load several experiments into one long-format DataFrame (one row per online-coding block).\n",
    "\n",
    "    Each row combines the experiment `name`; `model_layer` and a `corrupted` flag\n",
    "    (corrupted_token_percent > 0.99) from its YAML config; the block's training-target\n",
    "    count `n` and cumulative `mdl` from `get_mdls`; every key of the matching\n",
    "    online-report entry; and `mean_loss_<split> = loss_<split> / num_targets_<split>`\n",
    "    for each of dev/test/online_portion whose loss is present. Finally a\n",
    "    `representation` column 'ELMo layer <model_layer>' is added.\n",
    "    \"\"\"\n",
    "    rows = []\n",
    "    for exp_name in experiments:\n",
    "        params = load_params(exp_name)\n",
    "        # Key order matters: it fixes the leading column order of the DataFrame.\n",
    "        meta = {\n",
    "            'name': exp_name,\n",
    "            'n': params['dataset']['dataset_size'],\n",
    "            'model_layer': params['model']['model_layer'],\n",
    "            'corrupted': params['probe']['misc']['corrupted_token_percent'] > 0.99,\n",
    "        }\n",
    "        codelengths, sizes, report = get_mdls(exp_name)\n",
    "        for mdl, size, report_entry in zip(codelengths, sizes, report):\n",
    "            row = copy.deepcopy(meta)\n",
    "            row['n'] = size\n",
    "            row['mdl'] = mdl\n",
    "            row.update(report_entry)\n",
    "            for split in ('dev', 'test', 'online_portion'):\n",
    "                if f'loss_{split}' in row:\n",
    "                    row[f'mean_loss_{split}'] = row[f'loss_{split}'] / row[f'num_targets_{split}']\n",
    "            rows.append(row)\n",
    "    df = pd.DataFrame(rows)\n",
    "    df['representation'] = 'ELMo layer ' + df.model_layer.astype(str)\n",
    "    return df\n",
    "\n",
    "def merge_seeds(df, on='representation'):\n",
    "    \"\"\"Put repeated runs (seeds) of the same configuration on a common `n` grid.\n",
    "\n",
    "    Rows are grouped by column `on`; within a group, each distinct `name` is one run.\n",
    "    The first run encountered is the base: its rows are kept, sorted by `n`. Every\n",
    "    other run is sorted by `n` and its `n` column is overwritten, position by\n",
    "    position, with the base run's sorted `n` values, so all runs of a group share\n",
    "    the same `n` values (e.g. for a later groupby on `n`).\n",
    "\n",
    "    Returns the concatenation of the base and realigned runs of every group.\n",
    "    \"\"\"\n",
    "    subset_dfs = []\n",
    "    for on_key in df[on].unique():\n",
    "        repr_df = df[df[on] == on_key]\n",
    "        # One frame per run (`name`), in order of first appearance.\n",
    "        seed_dfs = []\n",
    "        for name in repr_df.name.unique():\n",
    "            seed_dfs.append(repr_df[repr_df.name == name].copy())\n",
    "\n",
    "        # The first run provides the reference grid of dataset sizes.\n",
    "        base_df = seed_dfs[0].sort_values('n')\n",
    "        ns = base_df['n'].values\n",
    "        subset_dfs.append(base_df)\n",
    "        for i in range(1, len(seed_dfs)):\n",
    "            seed_df = seed_dfs[i].sort_values('n')\n",
    "            # With suffixes=('', '_new') columns shared with base_df keep seed_df's values\n",
    "            # (base_df's copies get '_new' and are dropped below); only columns missing\n",
    "            # from seed_df are taken from the base row with the nearest `n`.\n",
    "            merged_df = pd.merge_asof(seed_df, base_df, on='n', direction='nearest', suffixes=('', '_new'))\n",
    "            # NOTE(review): positional overwrite. It assumes each run has exactly as many\n",
    "            # rows as the base run (pandas raises ValueError on a length mismatch) and that\n",
    "            # the k-th smallest `n` of every run matches the base's k-th smallest.\n",
    "            merged_df['n'] = ns\n",
    "            drop_cols = [col for col in merged_df.columns if col.endswith('_new')]\n",
    "            merged_df = merged_df.drop(columns=drop_cols)\n",
    "            subset_dfs.append(merged_df)\n",
    "    rebuilt_df = pd.concat(subset_dfs, sort=False)\n",
    "    return rebuilt_df"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def loss_data_chart(df, title='', xdomain=None, ydomain=None, xrules=None, yrules=None,\n",
    "                    color_title='Representation', final=False):\n",
    "    \"\"\"Plot mean test loss against dataset size (log-log), one line per representation.\n",
    "\n",
    "    Args:\n",
    "        df: long-format frame with `n`, `mean_loss_test` and `representation` columns.\n",
    "            Rows with `n < 10` are skipped; rows sharing `n` and `representation`\n",
    "            are averaged by the `mean(mean_loss_test)` encoding.\n",
    "        title: chart title.\n",
    "        xdomain, ydomain: axis domains; default to [40, 1000000] and [0.05, 3].\n",
    "        xrules, yrules: positions of dashed grey reference lines (default: none).\n",
    "        color_title: legend title for the representation colour/shape encoding.\n",
    "        final: if True, use larger fonts and hide grid lines.\n",
    "\n",
    "    Returns:\n",
    "        The layered, configured Altair chart.\n",
    "    \"\"\"\n",
    "    # None defaults avoid shared mutable default arguments.\n",
    "    xdomain = [40, 1000000] if xdomain is None else list(xdomain)\n",
    "    ydomain = [0.05, 3] if ydomain is None else list(ydomain)\n",
    "    xrules = [] if xrules is None else list(xrules)\n",
    "    yrules = [] if yrules is None else list(yrules)\n",
    "\n",
    "    line_width = 5\n",
    "    if final:\n",
    "        label_size, title_size = 24, 30\n",
    "    else:\n",
    "        label_size, title_size = 14, 20\n",
    "\n",
    "    # x-rules and y-rules share one frame: rows built from `xrules` have NaN in `y`\n",
    "    # and vice versa; each rule layer encodes only its own column.\n",
    "    rules_df = pd.concat([\n",
    "        pd.DataFrame({'x': xrules}),\n",
    "        pd.DataFrame({'y': yrules})\n",
    "    ], sort=False)\n",
    "\n",
    "    colorscheme = 'set1'\n",
    "    # CSS hex colours need the leading '#'; bare '333' / '999' are not valid colours.\n",
    "    stroke_color = '#333'\n",
    "    rule_color = '#999'\n",
    "\n",
    "    plot_df = df[df.n >= 10]\n",
    "    x_enc = alt.X('n', scale=alt.Scale(type='log', domain=xdomain, nice=False), title='Dataset size')\n",
    "    y_enc = alt.Y('mean(mean_loss_test)', scale=alt.Scale(type='log', domain=ydomain, nice=False), title='Test loss')\n",
    "\n",
    "    line = alt.Chart(plot_df, title=title).mark_line(size=line_width, opacity=0.4).encode(\n",
    "        x=x_enc,\n",
    "        y=y_enc,\n",
    "        color=alt.Color('representation:N', title=color_title, scale=alt.Scale(scheme=colorscheme), legend=None),\n",
    "    )\n",
    "    point = alt.Chart(plot_df, title=title).mark_point(size=80, opacity=1).encode(\n",
    "        x=x_enc,\n",
    "        y=y_enc,\n",
    "        color=alt.Color('representation:N', title=color_title, scale=alt.Scale(scheme=colorscheme)),\n",
    "        shape=alt.Shape('representation:N', title=color_title),\n",
    "        tooltip=['n', 'representation']\n",
    "    )\n",
    "\n",
    "    rule_x = alt.Chart(rules_df).mark_rule(size=3, color=rule_color, strokeDash=[4, 4]).encode(x='x')\n",
    "    rule_y = alt.Chart(rules_df).mark_rule(size=3, color=rule_color, strokeDash=[4, 4]).encode(y='y')\n",
    "\n",
    "    chart = alt.layer(rule_x, rule_y, line, point).resolve_scale(\n",
    "        color='independent',\n",
    "        shape='independent'\n",
    "    )\n",
    "    chart = chart.properties(width=600, height=500, background='white')\n",
    "    # All configuration goes in a single configure() call; the legend's labelLimit=0\n",
    "    # is set in LegendConfig below.\n",
    "    chart = chart.configure(\n",
    "        title=alt.TitleConfig(fontSize=title_size, fontWeight='normal'),\n",
    "        axis=alt.AxisConfig(titleFontSize=title_size, labelFontSize=label_size, grid=(not final),\n",
    "                            domainWidth=5, domainColor=stroke_color,\n",
    "                            tickWidth=3, tickSize=9, tickCount=4, tickColor=stroke_color, tickOffset=0),\n",
    "        legend=alt.LegendConfig(titleFontSize=title_size, labelFontSize=label_size, labelLimit=0, titleLimit=0,\n",
    "                                orient='top-right', padding=10,\n",
    "                                titlePadding=10, rowPadding=5,\n",
    "                                fillColor='white', strokeColor='black', cornerRadius=0),\n",
    "        view=alt.ViewConfig(strokeWidth=0, stroke=stroke_color),\n",
    "        font='Roboto',\n",
    "    )\n",
    "    return chart"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def make_latex(df, ns, stack=False, group_n=True, epsilons=(0.5, 0.1)):\n",
    "    \"\"\"Render validation loss, MDL, SDL and eSC at selected dataset sizes as a LaTeX table.\n",
    "\n",
    "    Args:\n",
    "        df: long-format results with (at least) `representation`, `model_layer`, `n`,\n",
    "            `mean_loss_test`, `mdl` and the columns read by `auc_per_data`.\n",
    "        ns: dataset sizes (values of `n`) to report.\n",
    "        stack: if True, rows are (n, metric) pairs with one column per ELMo layer\n",
    "            when `group_n` is True, or (layer, n) pairs with one column per metric\n",
    "            when it is False. If False, rows are metrics and columns are (n, layer).\n",
    "        group_n: see `stack`; ignored when `stack` is False.\n",
    "        epsilons: tolerances for the SDL and eSC columns.\n",
    "\n",
    "    Returns:\n",
    "        An `ipywidgets.Output` whose stdout holds the LaTeX source.\n",
    "    \"\"\"\n",
    "    df = auc_per_data(df, epsilons).reset_index(drop=True)\n",
    "    df = sc_per_data(df, epsilons).reset_index(drop=True)\n",
    "    auc_cols = {f'str_auc_agg@{eps}'.replace('.', '_'):  f'SDL, $\\\\varepsilon$={eps}' for eps in epsilons}\n",
    "    sc_cols = {f'str_sc@{eps}'.replace('.', '_'):  f'$\\\\varepsilon$SC, $\\\\varepsilon$={eps}' for eps in epsilons}\n",
    "    label_cols = [*auc_cols.keys(), *sc_cols.keys()]\n",
    "    # The formatted SDL/eSC strings are used as extra group keys so they survive the\n",
    "    # mean; numeric_only=True skips text columns such as `representation`, which\n",
    "    # pandas >= 2 would otherwise refuse to average (TypeError).\n",
    "    output_df = (df[df.n.isin(ns)]\n",
    "                 .groupby(['model_layer', 'n', *label_cols])\n",
    "                 .mean(numeric_only=True)\n",
    "                 .reset_index())\n",
    "    output_df = output_df[['n', 'model_layer', 'mean_loss_test', 'mdl', *label_cols]]\n",
    "    output_df = output_df.sort_values('n')\n",
    "    output_df['model_layer'] = output_df['model_layer'].astype(int)\n",
    "    output_df = output_df.rename(columns={'model_layer': 'ELMo layer', 'mean_loss_test': 'Val loss', 'mdl': 'MDL', **auc_cols, **sc_cols})\n",
    "    if stack:\n",
    "        if not group_n:\n",
    "            output_df['n'] = '$n=' + output_df['n'].astype(str) + '$'\n",
    "        output_df = output_df.set_index(['ELMo layer', 'n'])\n",
    "        if group_n:\n",
    "            output_df = output_df.transpose().stack()\n",
    "            # Stable sort keeps the metric order within each n.\n",
    "            output_df = output_df.swaplevel().sort_values('n', ascending=True, kind='mergesort')\n",
    "    else:\n",
    "        output_df = output_df.set_index(['n', 'ELMo layer']).transpose()\n",
    "    # One left-aligned column per index level and one right-aligned column per data\n",
    "    # column ('llrrr' for the default stacked layout with three layers).\n",
    "    column_format = 'l' * output_df.index.nlevels + 'r' * output_df.shape[1]\n",
    "    out = widgets.Output(layout={'border': '1px solid black'})\n",
    "    latex_str = output_df.to_latex(multicolumn_format='c', float_format=\"{:0.2f}\".format, escape=False, column_format=column_format)\n",
    "    out.append_stdout(latex_str)\n",
    "    return out"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "# Three ELMo layers x four runs each (unsuffixed, seed1, seed0, seed2), generated in the\n",
    "# same order as the explicit list: merge_seeds uses the first run of each layer as its base grid.\n",
    "run_suffixes = ['', '_seed1', '_seed0', '_seed2']\n",
    "df = load_exps([\n",
    "    f'online_l{layer}_n40000_corruptedFalse{suffix}'\n",
    "    for suffix in run_suffixes\n",
    "    for layer in range(3)\n",
    "])\n",
    "df = merge_seeds(df)\n",
    "\n",
    "# df_add = load_exps([\n",
    "#     '1k_online_l0_n40000_corruptedFalse_seed0',\n",
    "#     '1k_online_l1_n40000_corruptedFalse_seed0',\n",
    "#     '1k_online_l2_n40000_corruptedFalse_seed0',\n",
    "# ])\n",
    "# df_add = merge_seeds(df_add)\n",
    "# df = pd.concat([df, df_add], sort=False)\n",
    "# df_add = load_exps([\n",
    "#     '100k_online_l0_n40000_corruptedFalse_seed0',\n",
    "#     '100k_online_l1_n40000_corruptedFalse_seed0',\n",
    "#     '100k_online_l2_n40000_corruptedFalse_seed0',\n",
    "\n",
    "# ])\n",
    "# df_add = merge_seeds(df_add)\n",
    "# df = pd.concat([df, df_add], sort=False)\n",
    "\n",
    "# Average the runs per (representation, n); numeric_only=True skips text columns such\n",
    "# as `name`, which pandas >= 2 refuses to average (TypeError).\n",
    "df = df.groupby(['representation', 'n']).mean(numeric_only=True).reset_index()\n",
    "epsilons = [0.1, 0.5]\n",
    "ns = [461, 474838]\n",
    "chart = loss_data_chart(df, final=True, xrules=ns, yrules=epsilons)\n",
    "display(chart)\n",
    "save(chart, 'elmo_layers.pdf')\n",
    "make_latex(df, ns=ns, stack=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
