{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import gzip, json, glob\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "import pandas as pd\n",
    "\n",
    "plt.rcParams.update({\n",
    "    \"text.usetex\": False,\n",
    "    \"font.family\": \"serif\",\n",
    "    #\"font.serif\": [\"Palatino\"],\n",
    "    \n",
    "    \"legend.frameon\": False,\n",
    "    \"legend.fancybox\": False,\n",
    "    \n",
    "    'font.size': 8,\n",
    "    'axes.linewidth': 0.6,\n",
    "\n",
    "    'xtick.major.width': 0.6,\n",
    "    'ytick.major.width': 0.6,\n",
    "    'xtick.minor.width': 0.6,\n",
    "    'ytick.minor.width': 0.6,\n",
    "    \n",
    "    \"lines.linewidth\": 0.9,\n",
    "    \n",
    "    \"axes.grid\": True,\n",
    "    \"grid.color\": \"#EEE\"\n",
    "    })\n",
    "\n",
    "plt.rc(\"text.latex\", preamble=r\"\\usepackage{amsmath}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def load_files(pattern): # data map (example_i[, target_label]) => data\n",
    "    data = {}\n",
    "    max_memory = 0.0\n",
    "    for filename in glob.glob(pattern):\n",
    "        print(\"Reading file:\", filename)\n",
    "        with gzip.open(filename, \"r\") as f:\n",
    "            lines = f.readlines()\n",
    "            for line in lines:\n",
    "                j = json.loads(line)\n",
    "                #print(filename, j.keys())\n",
    "                example_i = j[\"example_i\"]\n",
    "                if \"target_label\" in j:\n",
    "                    key = (example_i, j[\"target_label\"])\n",
    "                else:\n",
    "                    key = example_i\n",
    "\n",
    "                d = data.get(key, {})\n",
    "                data[key] = d\n",
    "                \n",
    "                if \"veritas_deltas\" in j:\n",
    "                    max_memory = max(max_memory, max(x[\"memory\"][-1] for x in j[\"veritas_log\"]))\n",
    "                    try: column_prefix = f\"veritas{j['max_time']:02d}\"\n",
    "                    except:\n",
    "                        s0 = filename.find(\"time\")+4\n",
    "                        s1 = filename.find(\"-\", s0)\n",
    "                        max_time = int(filename[s0:s1])\n",
    "                        #print(\"no max time in\", filename, f\"extracted '{max_time}' from filename\")\n",
    "                        column_prefix = f\"veritas{max_time:02d}\"\n",
    "                    d[f\"{column_prefix}_time\"] = j[\"veritas_time\"]\n",
    "                    d[f\"{column_prefix}_delta\"] = j[\"veritas_deltas\"][-1][0]\n",
    "                    #print(\"deltas\", j[\"veritas_deltas\"])\n",
    "                \n",
    "                if \"merge_ext\" in j and \"max_clique\" in j[\"merge_ext\"]:\n",
    "                    column_prefix = f\"mext_T{j['merge_ext']['max_clique']}_L{j['merge_ext']['max_level']}\"\n",
    "                    d[f\"{column_prefix}_time\"] = j[\"merge_ext\"][\"times\"][-1]\n",
    "                    d[f\"{column_prefix}_delta\"] = j[\"merge_ext\"][\"deltas\"][-1]\n",
    "                    \n",
    "                if \"kantchelian\" in j:\n",
    "                    column_prefix = \"kan\"\n",
    "                    d[f\"{column_prefix}_time\"] = j[\"kantchelian\"][\"time_p\"]\n",
    "                    d[f\"{column_prefix}_delta\"] = j[\"kantchelian_delta\"]\n",
    "                    #print(j[\"kantchelian\"][\"bounds\"])\n",
    "    print(f\"max_memory for {pattern} is: {max_memory/(1024*1024)}\")\n",
    "    return data\n",
    "\n",
    "def get_column_names(data):\n",
    "    columns = set()\n",
    "    for value in data.values():\n",
    "        columns |= value.keys()\n",
    "    return sorted(columns)\n",
    "        \n",
    "\n",
    "def to_df(data):\n",
    "    colnames = get_column_names(data)\n",
    "    columns = {}\n",
    "    index = pd.Series(list(data.keys()))\n",
    "    for c in colnames:\n",
    "        values = {}\n",
    "        for key, value in data.items():\n",
    "            if c in value:\n",
    "                values[key] = value[c]\n",
    "        columns[c] = values\n",
    "    df = pd.DataFrame(columns)\n",
    "    df.set_index(index)\n",
    "    return df\n",
    "\n",
    "def load_to_df(pattern, dropna=True):\n",
    "    data = load_files(pattern)\n",
    "    df = to_df(data)\n",
    "    df.sort_index(inplace=True, axis=0)\n",
    "    if dropna: df = df.dropna()\n",
    "    return df"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "dfs={}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "dfs[\"covtype\"] = load_to_df(\"/home/laurens/repos/veritas/tests/experiments/results/r1-covtype-*\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "dfs[\"f-mnist\"] = load_to_df(\"/home/laurens/repos/veritas/tests/experiments/results/r1-f-mnist-*\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "dfs[\"higgs\"] = load_to_df(\"/home/laurens/repos/veritas/tests/experiments/results/r1-higgs-*\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "dfs[\"ijcnn1\"] = load_to_df(\"/home/laurens/repos/veritas/tests/experiments/results/r1-ijcnn1-*\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "dfs[\"mnist\"] = load_to_df(\"/home/laurens/repos/veritas/tests/experiments/results/r1-mnist-*\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "dfs[\"webspam\"] = load_to_df(\"/home/laurens/repos/veritas/tests/experiments/results/r1-webspam-*\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "dfs[\"mnist2v6\"] = load_to_df(\"/home/laurens/repos/veritas/tests/experiments/results/r1-mnist2v6-*\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "gridcolor=\"#EEEEEE\""
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Bound difference"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#datasets = [\"covtype\", \"f-mnist\", \"higgs\"]\n",
    "#datasets = [\"ijcnn1\", \"mnist\", \"webspam\", \"mnist2v6\"]\n",
    "datasets = [\"covtype\", \"f-mnist\", \"higgs\", \"ijcnn1\", \"mnist\", \"webspam\", \"mnist2v6\"]\n",
    "datasets = [\"webspam\"]\n",
    "\n",
    "fig, axs = plt.subplots(1, len(datasets), figsize=(len(datasets)*4.0, 1.8))\n",
    "fig.subplots_adjust(left=0.15, right=0.9, top=0.85, bottom=0.25)\n",
    "\n",
    "axs=[axs]\n",
    "\n",
    "for d, ax in zip(datasets, axs):\n",
    "    df = dfs[d]\n",
    "    #display(d, df)\n",
    "    time_columns = [c for c in df.columns if c.endswith(\"time\")]\n",
    "    delta_columns = [c for c in df.columns if c.endswith(\"delta\")]\n",
    "    time_mean = df[time_columns].mean()\n",
    "    time_std = df[time_columns].std()\n",
    "    #div_from_opt = df[delta_columns].subtract(df[\"kan_delta\"], axis=0).abs().mean()\n",
    "    div_from_opt = df[delta_columns].mean()\n",
    "    speedup = (1.0/df[time_columns].divide(df[\"kan_time\"], axis=0)).mean()\n",
    "    \n",
    "    scale_ = np.log10(div_from_opt.max().max()).round()\n",
    "    scale = 10**-scale_ * 10\n",
    "    \n",
    "    print(\"scale\", d, scale)\n",
    "    #div_from_opt *= scale\n",
    "    \n",
    "    veritas_time_columns = [c for c in df.columns if c.endswith(\"time\") and c.startswith(\"veritas\")]\n",
    "    veritas_delta_columns = [c for c in df.columns if c.endswith(\"delta\") and c.startswith(\"veritas\")]\n",
    "    mer_time_columns = [c for c in df.columns if c.endswith(\"time\") and c.startswith(\"mext\")]\n",
    "    mer_delta_columns = [c for c in df.columns if c.endswith(\"delta\") and c.startswith(\"mext\")]\n",
    "    \n",
    "    #ax.set_title(f\"{d} (n={len(df)})\")\n",
    "    ax.set_title(f\"{d}\")\n",
    "    ax.set_xlabel(\"Time\")\n",
    "    #ax.set_ylabel(\"Robustness delta value\")\n",
    "    #if scale != 1.0:\n",
    "    #    ax.text(-0.2, 1.1, f'$\\\\delta \\\\times 10^{{{scale_:.0f}}}$', horizontalalignment='left', verticalalignment='center', transform=ax.transAxes)\n",
    "    #else:\n",
    "    ax.text(-0.1, 1.1, f'$\\\\delta$', horizontalalignment='left', verticalalignment='center', transform=ax.transAxes)\n",
    "    \n",
    "    #ax.plot(time_mean[veritas_time_columns], div_from_opt[veritas_delta_columns], marker=\".\", linestyle=\"-\", markersize=4, label=\"Veritas\")\n",
    "    ax.errorbar(time_mean[veritas_time_columns], div_from_opt[veritas_delta_columns], xerr=time_std[veritas_time_columns],\n",
    "               capthick=1.0, elinewidth=None, capsize=2.0, marker=\".\", linestyle=\":\", markersize=4, errorevery=4, label=\"Veritas\")\n",
    "    #for i, (x, y, m) in enumerate(zip(time_mean[veritas_time_columns], div_from_opt[veritas_delta_columns], speedup[veritas_time_columns])):\n",
    "    #    ax.text(x, y-0.1, f\"{m:.0f}×\", horizontalalignment='right', verticalalignment='top', c=\"gray\")\n",
    "    #l, = ax.plot(time_mean[mer_time_columns], div_from_opt[mer_delta_columns], marker=\"8\", markersize=5, linestyle=\":\", label=\"Merge\")\n",
    "    #for i, (x, y, m) in enumerate(zip(time_mean[mer_time_columns], div_from_opt[mer_delta_columns], speedup[mer_time_columns])):\n",
    "    #    ax.text(x, y-0.1, f\"{m:.0f}×\", horizontalalignment='right', verticalalignment='top', c=\"gray\")\n",
    "    #ax.axhline(y=div_from_opt[mer_delta_columns][0], c=l.get_color(), ls=l.get_linestyle())\n",
    "    ax.errorbar(time_mean[mer_time_columns], div_from_opt[mer_delta_columns], xerr=time_std[mer_time_columns],\n",
    "               capthick=1.0, elinewidth=None, capsize=2.0, marker=\"8\", markersize=5, linestyle=\"\", label=\"Merge\")\n",
    "    #l, = ax.plot(time_mean[[\"kan_time\"]], div_from_opt[[\"kan_delta\"]], marker=\"*\", linestyle=\":\", markersize=4, label=\"MILP\")\n",
    "    ax.errorbar(time_mean[[\"kan_time\"]], div_from_opt[[\"kan_delta\"]], marker=\"*\", linestyle=\"\", markersize=4,\n",
    "                xerr=time_std[[\"kan_time\"]],\n",
    "               capthick=1.0, elinewidth=0.0, capsize=2.0, barsabove=True, label=\"MILP\")\n",
    "    ax.axhline(y=div_from_opt[\"kan_delta\"], c=\"gray\", ls=\":\", label=\"Exact\")\n",
    "\n",
    "    ax.set_xscale(\"log\")\n",
    "    \n",
    "    #xlim = (0.0, 1.1*time_mean[\"kan_time\"])\n",
    "    #ax.set_xlim(xlim)\n",
    "    #ax.set_xticks(list(np.arange(0.0,xlim[1], 10.0)))\n",
    "    ax.legend(fontsize=\"large\", bbox_to_anchor=(1.0, 0.8))\n",
    "\n",
    "plt.savefig(f\"/tmp/bound_err_{datasets[0]}.pdf\")\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Counting stats"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "fig, axs = plt.subplots(1, len(dfs), figsize=(len(dfs)*5, 4))\n",
    "\n",
    "better_stats = {}\n",
    "worse_stats = {}\n",
    "same_stats = {}\n",
    "\n",
    "for (d, df), ax in zip(dfs.items(), axs):\n",
    "    time_columns = [c for c in df.columns if c.endswith(\"time\")]\n",
    "    delta_columns = [c for c in df.columns if c.endswith(\"delta\")]\n",
    "    time_mean = df[time_columns].mean()\n",
    "    veritas_time_columns = [c for c in df.columns if c.endswith(\"time\") and c.startswith(\"veritas\")]\n",
    "    veritas_delta_columns = [c for c in df.columns if c.endswith(\"delta\") and c.startswith(\"veritas\")]\n",
    "    mer_delta_columns = [c for c in df.columns if c.endswith(\"delta\") and c.startswith(\"mext\")]\n",
    "    mer_time_columns = [c for c in df.columns if c.endswith(\"time\") and c.startswith(\"mext\")]\n",
    "    \n",
    "    div_from_opt = df[delta_columns].subtract(df[\"kan_delta\"], axis=0).abs()\n",
    "    \n",
    "    same_threshold = (df[\"kan_delta\"].quantile(0.8) - df[\"kan_delta\"].quantile(0.2)) / 100\n",
    "    \n",
    "    print(f\"same threshold {d}: {same_threshold}\")\n",
    "    \n",
    "    ver_better = div_from_opt[delta_columns].subtract(div_from_opt[mer_delta_columns[0]], axis=0) < -same_threshold\n",
    "    ver_worse = div_from_opt[delta_columns].subtract(div_from_opt[mer_delta_columns[0]], axis=0) > same_threshold\n",
    "    ver_same = ~ver_better & ~ver_worse\n",
    "    #display(ver_better.sum(), ver_worse.sum(), ver_same.sum())\n",
    "    \n",
    "    n = len(df)\n",
    "    ax.set_title(f\"{d} (n={n})\")\n",
    "    ax.set_xlabel(\"Time [s]\")\n",
    "    #ax.set_ylabel(\"Robustness delta value\")\n",
    "    ax.text(-0.1, 1.04, '%', horizontalalignment='right', verticalalignment='center', transform=ax.transAxes)\n",
    "    \n",
    "    ax.plot(time_mean[veritas_time_columns], ver_better.sum()[veritas_delta_columns]/n, marker=\"^\", linestyle=\":\", label=\"Better\")\n",
    "    ax.plot(time_mean[veritas_time_columns], ver_worse.sum()[veritas_delta_columns]/n, marker=\"v\", linestyle=\":\", label=\"Worse\")\n",
    "    ax.plot(time_mean[veritas_time_columns], ver_same.sum()[veritas_delta_columns]/n, marker=\".\", linestyle=\":\", label=\"Same\")\n",
    "    ax.axvline(x=time_mean[mer_time_columns[0]], ls=\"--\", color=\"gray\", label=\"Merge time\")\n",
    "    #ax.set_xscale(\"log\")\n",
    "    ax.legend()\n",
    "    \n",
    "    better_stats[d] = ver_better.sum()[veritas_delta_columns]/n\n",
    "    worse_stats[d] = ver_worse.sum()[veritas_delta_columns]/n\n",
    "    same_stats[d] = ver_same.sum()[veritas_delta_columns]/n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "better_df = (pd.DataFrame(better_stats).transpose()*100).round(1)\n",
    "worse_df = (pd.DataFrame(worse_stats).transpose()*100).round(1)\n",
    "same_df = (pd.DataFrame(same_stats).transpose()*100).round(1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "display(better_df, worse_df, same_df)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# How many problems are solved in 1s, 2s, ..."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "datasets = [\"covtype\", \"f-mnist\", \"higgs\", \"ijcnn1\", \"mnist\", \"mnist2v6\"]\n",
    "fig, axs = plt.subplots(1, len(datasets), figsize=(len(datasets)*1.4, 1.8), sharey=True, sharex=True)\n",
    "fig.subplots_adjust(left=0.04, bottom=0.22, right=0.99\n",
    "                    , top=0.7, wspace=0.1, hspace=0.4)\n",
    "axs = axs.flatten()\n",
    "\n",
    "better_stats = {}\n",
    "worse_stats = {}\n",
    "same_stats = {}\n",
    "\n",
    "for d, ax in zip(datasets, axs):\n",
    "    df = dfs[d]\n",
    "    time_columns = [c for c in df.columns if c.endswith(\"time\")]\n",
    "    delta_columns = [c for c in df.columns if c.endswith(\"delta\")]\n",
    "    time_mean = df[time_columns].mean()\n",
    "    veritas_time_columns = [c for c in df.columns if c.endswith(\"time\") and c.startswith(\"veritas\")]\n",
    "    veritas_delta_columns = [c for c in df.columns if c.endswith(\"delta\") and c.startswith(\"veritas\")]\n",
    "    mer_delta_columns = [c for c in df.columns if c.endswith(\"delta\") and c.startswith(\"mext\")]\n",
    "    mer_time_columns = [c for c in df.columns if c.endswith(\"time\") and c.startswith(\"mext\")]\n",
    "    \n",
    "    times = np.linspace(0, 12, 50)\n",
    "    div_from_opt = df[veritas_delta_columns+mer_delta_columns].subtract(df[\"kan_delta\"], axis=0).abs()\n",
    "    #q10, q50, q90 = pd.Series(div_from_opt.to_numpy().flatten()).quantile([0.25, 0.5, 1.0])\n",
    "    #q50 = pd.Series(div_from_opt.to_numpy().flatten()).median()\n",
    "    q50 = pd.Series(df[mer_delta_columns].subtract(df[\"kan_delta\"], axis=0).abs().to_numpy().flatten()).mean()\n",
    "    #print(d, \"quantiles:\", q10, q50, q90)\n",
    "    #print(\"   how often do we see median?  \", (div_from_opt==q50).sum().sum()/len(df) * 100)\n",
    "    #print(\"   how many unique error values?\", len(pd.Series(div_from_opt.to_numpy().flatten()).unique()), np.prod(div_from_opt.shape))\n",
    "    \n",
    "    #in_time_ver = pd.concat([df[veritas_time_columns[0]]]*len(times), axis=1).le(times, axis=1)\n",
    "    #in_time_ver.columns = [f\"in_time{t:.2f}\" for t in times]\n",
    "    #in_time_mer = pd.concat([df[mer_time_columns[0]]]*len(times), axis=1).le(times, axis=1)\n",
    "    #in_time_mer.columns = [f\"in_time{t:.2f}\" for t in times]\n",
    "    in_time_kan = pd.concat([df[\"kan_time\"]]*len(times), axis=1).le(times, axis=1)\n",
    "    in_time_kan.columns = [f\"in_time{t:.2f}\" for t in times]\n",
    "    \n",
    "    in_time_ver10 = None\n",
    "    in_time_ver50 = None\n",
    "    in_time_ver90 = None\n",
    "    for tcol, dcol in zip(veritas_time_columns, veritas_delta_columns):\n",
    "        #x10 = pd.concat([(df[tcol]<=t) & ((df[dcol]-df[\"kan_delta\"]).abs()<=q10) for t in times], axis=1)\n",
    "        x50 = pd.concat([(df[tcol]<=t) & ((df[dcol]-df[\"kan_delta\"]).abs()<=q50) for t in times], axis=1)\n",
    "        #x90 = pd.concat([(df[tcol]<=t) & ((df[dcol]-df[\"kan_delta\"]).abs()<=q90) for t in times], axis=1)\n",
    "        #in_time_ver10 = (in_time_ver10 | x10) if in_time_ver10 is not None else x10\n",
    "        in_time_ver50 = (in_time_ver50 | x50) if in_time_ver50 is not None else x50\n",
    "        #in_time_ver90 = (in_time_ver90 | x90) if in_time_ver90 is not None else x90\n",
    "    \n",
    "    #in_time_mer10 = pd.concat([(df[mer_time_columns[0]]<=t) & ((df[mer_delta_columns[0]]-df[\"kan_delta\"]).abs()<=q10) for t in times], axis=1)\n",
    "    in_time_mer50 = pd.concat([(df[mer_time_columns[0]]<=t) & ((df[mer_delta_columns[0]]-df[\"kan_delta\"]).abs()<=q50) for t in times], axis=1)\n",
    "    #in_time_mer90 = pd.concat([(df[mer_time_columns[0]]<=t) & ((df[mer_delta_columns[0]]-df[\"kan_delta\"]).abs()<=q90) for t in times], axis=1)\n",
    "    \n",
    "    n = len(df)\n",
    "    #ax.set_title(f\"{d} (n={n}, m={q50:.2g})\")\n",
    "    #ax.set_title(f\"{d} (n={n})\")\n",
    "    ax.set_title(f\"{d}\")\n",
    "    ax.set_xlabel(\"Time\")\n",
    "    #ax.set_ylabel(\"Robustness delta value\")\n",
    "    if d==\"covtype\":# or d==\"ijcnn1\":\n",
    "        ax.text(-0.1, 1.09, '%', horizontalalignment='right', verticalalignment='center', transform=ax.transAxes)\n",
    "    #lv, = ax.plot(times, in_time_ver.mean()*100, ls=(0, (2, 4)))\n",
    "    #lm, = ax.plot(times, in_time_mer.mean()*100, ls=(0, (1, 4)))\n",
    "    lv, = ax.plot(times, in_time_ver50.mean()*100, ls=\"-\", label=\"Veritas\")\n",
    "    #ax.fill_between(times, in_time_ver10.mean()*100, in_time_ver90.mean()*100, alpha=0.1, color=lv.get_color())\n",
    "    #ax.plot(times, in_time_ver10.mean()*100, ls=(0, (1, 4)), c=lv.get_color())\n",
    "    #ax.plot(times, in_time_ver90.mean()*100, ls=(0, (1, 4)), c=lv.get_color())\n",
    "    lm, = ax.plot(times, in_time_mer50.mean()*100, ls=\"--\", label=\"Merge\")\n",
    "    #ax.fill_between(times, in_time_mer10.mean()*100, in_time_mer90.mean()*100, alpha=0.1, color=lm.get_color())\n",
    "    #ax.plot(times, in_time_mer10.mean()*100, ls=(0, (1, 4)), c=lm.get_color())\n",
    "    #ax.plot(times, in_time_mer90.mean()*100, ls=(0, (1, 4)), c=lm.get_color())\n",
    "    ax.plot(times, in_time_kan.mean()*100, ls=\"-.\", label=\"MILP\")\n",
    "\n",
    "axs[1].legend(ncol=3, bbox_to_anchor=(3.4, 1.6), fontsize=\"large\")\n",
    "#for ax in axs[3:]: ax.set_xlabel(\"Time\")\n",
    "plt.savefig(\"/tmp/solved_per_time.pdf\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Tables"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "rows = {}\n",
    "\n",
    "def map_name(n):\n",
    "    if \"kan\" in n:\n",
    "        return \"MIPS\"\n",
    "    if \"veritas\" in n:\n",
    "        #return f\"$\\\\ouralg{{}}_{{{int(n[7:9])}}}$\"\n",
    "        return \"\\\\ouralg{}\"\n",
    "    if \"mext\" in n:\n",
    "        return \"\\\\merge{}\"\n",
    "\n",
    "def which_column(d):\n",
    "    if d == \"f-mnist\":\n",
    "        return \"veritas06\"\n",
    "    else:\n",
    "        return \"veritas02\"\n",
    "\n",
    "for i, (d, df) in enumerate(dfs.items()):\n",
    "    time_columns = [c for c in df.columns if c.endswith(\"time\") and (not c.startswith(\"veritas\") or c.startswith(which_column(d)))]\n",
    "    delta_columns = [c for c in df.columns if c.endswith(\"delta\") and (not c.startswith(\"veritas\") or c.startswith(which_column(d)))]\n",
    "    time_mean = df[time_columns].mean()\n",
    "    \n",
    "    r1 = df[delta_columns].mean()\n",
    "    r1[r1.index[1:]] /= r1[r1.index[0]]\n",
    "    r1[r1.index[1:]] *= 100.0\n",
    "    r1[r1.index[1:]] = [f\"\\\\SI{{{x:.3g}}}{{\\percent}}\" for x in r1[r1.index[1:]]]\n",
    "    r1.index = [map_name(n) for n in r1.index]\n",
    "    r2 = df[time_columns].mean()\n",
    "    r2.index = [map_name(n) for n in r2.index]\n",
    "    r3 = df[time_columns].std()\n",
    "    r3.index = [map_name(n) for n in r3.index]\n",
    "    r4 = df[time_columns].mean()\n",
    "    r4[r4.index[1:]] = r4[r4.index[0]] / r4[r4.index[1:]]\n",
    "    r4[r4.index[0]] = \"\"\n",
    "    r4[r4.index[1:]] = [f\"\\\\SI{{{x:.0f}}}{{\\times}}\" for x in r4[r4.index[1:]]]\n",
    "    r4.index = [map_name(n) for n in r4.index]\n",
    "    \n",
    "    \n",
    "    rows[(d, \"$\\\\delta$\")] = r1\n",
    "    rows[(d, \"$t$\")] = r2\n",
    "    rows[(d, \"$\\\\times$\")] = r4\n",
    "    rows[(d, \"$\\\\sigma_t$\")] = r3\n",
    "\n",
    "means_df = pd.DataFrame(rows)\n",
    "means_df = means_df.transpose()\n",
    "means_df"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "print(means_df.to_latex(escape=False))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Counts table\n",
    "rows = {}\n",
    "\n",
    "def map_name(n):\n",
    "    if \"veritas\" in n:\n",
    "        #return f\"$\\\\ouralg{{}}_{{{int(n[7:9])}}}$\"\n",
    "        return f\"Budget {int(n[7:9])}\"\n",
    "\n",
    "for i, (d, df) in enumerate(dfs.items()):\n",
    "    time_columns = [c for c in df.columns if c.endswith(\"time\") and c.startswith(\"ver\")]\n",
    "    delta_columns = [c for c in df.columns if c.endswith(\"delta\") and c.startswith(\"ver\")]\n",
    "    mer_time_column = [c for c in df.columns if c.endswith(\"time\") and c.startswith(\"mext\")][0]\n",
    "    mer_delta_column = [c for c in df.columns if c.endswith(\"delta\") and c.startswith(\"mext\")][0]\n",
    "    \n",
    "    kan_delta = df[\"kan_delta\"]\n",
    "    same_threshold = (kan_delta.quantile(0.6) - kan_delta.quantile(0.4)) / 1000\n",
    "    \n",
    "    mer_abs_diff = (df[mer_delta_column] - kan_delta).abs()\n",
    "    r1 = df[delta_columns].subtract(kan_delta, axis=0).abs().lt(mer_abs_diff, axis=0).mean()\n",
    "    r1 *= 100.0\n",
    "    r1.index = [map_name(n) for n in r1.index]\n",
    "    \n",
    "    r2 = df[delta_columns].subtract(kan_delta, axis=0).subtract(mer_abs_diff, axis=0).le(-same_threshold).mean()\n",
    "    r2 *= 100.0\n",
    "    r2.index = [map_name(n) for n in r2.index]\n",
    "    \n",
    "    r3 = df[delta_columns].subtract(kan_delta, axis=0).subtract(mer_abs_diff, axis=0).ge(same_threshold).mean()\n",
    "    r3 *= 100.0\n",
    "    r3.index = [map_name(n) for n in r3.index]\n",
    "    \n",
    "    r4 = df[delta_columns].subtract(kan_delta, axis=0).subtract(mer_abs_diff, axis=0).abs().lt(same_threshold).mean()\n",
    "    r4 *= 100.0\n",
    "    r4.index = [map_name(n) for n in r4.index]\n",
    "    \n",
    "    r5 = df[time_columns].lt(df[mer_time_column], axis=0).mean()\n",
    "    r5 *= 100.0\n",
    "    r5.index = [map_name(n) for n in r5.index]\n",
    "    \n",
    "    r6_a = df[time_columns].gt(df[mer_time_column], axis=0)\n",
    "    r6_b = df[delta_columns].subtract(kan_delta, axis=0).subtract(mer_abs_diff, axis=0).ge(same_threshold)\n",
    "    r6_a.columns = [map_name(n) for n in r6_a.columns]\n",
    "    r6_b.columns = [map_name(n) for n in r6_b.columns]\n",
    "    r6 = (r6_a & r6_b).mean()\n",
    "    r6 *= 100.0\n",
    "\n",
    "    \n",
    "    #rows[(d, \"r1\")] = r1\n",
    "    rows[(d, \"better\")] = r2\n",
    "    rows[(d, \"worse\")] = r3\n",
    "    rows[(d, \"same\")] = r4\n",
    "    rows[(d, \"faster\")] = r5\n",
    "    rows[(d, \"slower and worse\")] = r6\n",
    "    \n",
    "\n",
    "counts_df = pd.DataFrame(rows)\n",
    "counts_df = counts_df.transpose()\n",
    "counts_df.round(1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "formatter = lambda x: f\"\\\\SI{{{x:.1f}}}{{\\percent}}\"\n",
    "print(counts_df.to_latex(escape=False, formatters=[formatter] * 5))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "time_columns = [c for c in df.columns if c.endswith(\"time\")]\n",
    "delta_columns = [c for c in df.columns if c.endswith(\"delta\")]\n",
    "veritas_time_columns = [c for c in df.columns if c.endswith(\"time\") and c.startswith(\"veritas\")]\n",
    "veritas_delta_columns = [c for c in df.columns if c.endswith(\"delta\") and c.startswith(\"veritas\")]\n",
    "mer_time_columns = [c for c in df.columns if c.endswith(\"time\") and c.startswith(\"mext\")]\n",
    "mer_delta_columns = [c for c in df.columns if c.endswith(\"delta\") and c.startswith(\"mext\")]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "delta_mean = df[delta_columns].mean()\n",
    "time_mean = df[time_columns].mean()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "df[delta_columns]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.errorbar(time_mean, delta_mean, marker=\"x\", ls=\"\", xerr=df[time_columns].std())#, yerr=df[delta_columns].std())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "time_mean[veritas_time_columns]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "div_from_opt = df[delta_columns].subtract(df[\"kan_delta\"], axis=0).abs().mean()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.title(\"Mean absolute difference of delta value\")\n",
    "plt.plot(time_mean[veritas_time_columns], div_from_opt[veritas_delta_columns], marker=\".\", linestyle=\":\", label=\"Veritas\")\n",
    "l, = plt.plot(time_mean[mer_time_columns], div_from_opt[mer_delta_columns], marker=\"o\", linestyle=\":\", label=\"Merge\")\n",
    "plt.axhline(y=div_from_opt[mer_delta_columns][0], c=l.get_color(), ls=l.get_linestyle())\n",
    "l, = plt.plot(time_mean[[\"kan_time\"]], div_from_opt[[\"kan_delta\"]], marker=\"x\", linestyle=\":\", label=\"MILP\")\n",
    "plt.axhline(y=div_from_opt[\"kan_delta\"], c=l.get_color(), ls=l.get_linestyle())\n",
    "plt.legend()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.title(\"Mean delta value\")\n",
    "plt.plot(time_mean[veritas_time_columns], delta_mean[veritas_delta_columns], marker=\".\", linestyle=\":\", label=\"Veritas\")\n",
    "l, = plt.plot(time_mean[mer_time_columns], delta_mean[mer_delta_columns], marker=\"o\", linestyle=\":\", label=\"Merge\")\n",
    "plt.axhline(y=delta_mean[mer_delta_columns][0], c=l.get_color(), ls=l.get_linestyle())\n",
    "l, = plt.plot(time_mean[[\"kan_time\"]], delta_mean[[\"kan_delta\"]], marker=\"x\", linestyle=\":\", label=\"MILP\")\n",
    "plt.axhline(y=delta_mean[\"kan_delta\"], c=l.get_color(), ls=l.get_linestyle())\n",
    "plt.legend()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "time_mean"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "df[delta_columns].subtract(df[\"kan_delta\"], axis=0).describe()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "df[df[\"kan_delta\"]>20]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "dfs[\"f-mnist\"] = load_to_df(\"/home/laurens/repos/veritas/tests/experiments/results/r1-f-mnist-time2*\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "df = dfs[\"mnist\"]\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
