{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "32a904f9",
   "metadata": {},
   "source": [
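    "## Visualize benchmark results in a table\n",
    "\n",
    "Loads the `UEManager` results saved by `scripts/polygraph_eval` for several random seeds and shows, for every uncertainty estimator (rows) and generation-metric / UE-metric pair (columns), the mean ± standard deviation across seeds, multiplied by 100."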
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "999822a8",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import numpy as np\n",
    "import functools\n",
    "import pandas as pd\n",
    "from matplotlib import colors\n",
    "import matplotlib.pyplot as plt\n",
    "from collections import defaultdict\n",
    "from lm_polygraph.utils.manager import UEManager\n",
    "\n",
    "def b_g(s, A, cmap='PuBu', low=0.8, high=0):\n",
    "    \"\"\"Return CSS background colors for one column of a pandas Styler.\n",
    "\n",
    "    Meant for ``Styler.apply(..., axis=0)``: ``s`` is the displayed column (its\n",
    "    cells may be formatted strings) and the colors are computed from the column\n",
    "    of ``A`` that has the same label.\n",
    "\n",
    "    Args:\n",
    "        s: Column being styled; ``s.name`` must be a label of ``A.columns``.\n",
    "            For tuple labels, ``s.name[-1]`` is read as the UE metric name and\n",
    "            ``s.name[-2]`` as the generation metric name.\n",
    "        A: DataFrame holding the numeric values that drive the colors.\n",
    "        cmap: Name of a registered matplotlib colormap.\n",
    "        low: Fraction of the column's value range subtracted from its minimum\n",
    "            to obtain the lower normalization bound.\n",
    "        high: Fraction of the value range added to its maximum to obtain the\n",
    "            upper normalization bound.\n",
    "\n",
    "    Returns:\n",
    "        One ``'background-color: #rrggbb'`` string per row; rows whose value\n",
    "        is NaN get an empty string (left unstyled).\n",
    "    \"\"\"\n",
    "    i = A.columns.tolist().index(s.name)\n",
    "    a = np.array(A.values[:, i], dtype=float)\n",
    "    name = s.name if isinstance(s.name, tuple) else (s.name,)\n",
    "    # Reverse the color direction for these columns (presumably metrics for\n",
    "    # which lower is better -- confirm).\n",
    "    if name[-1] == 'rcc-auc':\n",
    "        a = -a\n",
    "    # pretty_plot labels columns (group, generation metric, UE metric), so the\n",
    "    # generation metric is second from the end; index 0 is the group name.\n",
    "    if len(name) >= 2 and name[-2] == 'BARTScoreSeq-rh':\n",
    "        a = -a\n",
    "    valid = ~np.isnan(a)\n",
    "    if not valid.any():\n",
    "        return [''] * len(a)\n",
    "    # NaN-aware bounds: a single missing value must not wipe out the column.\n",
    "    vmin, vmax = a[valid].min(), a[valid].max()\n",
    "    rng = vmax - vmin\n",
    "    norm = colors.Normalize(vmin - rng * low, vmax + rng * high)\n",
    "    rgba = plt.colormaps[cmap](norm(a))\n",
    "    return ['background-color: %s' % colors.rgb2hex(c) if ok else ''\n",
    "            for c, ok in zip(rgba, valid)]\n",
    "\n",
    "def get_array(dfs, row, col):\n",
    "    vals = []\n",
    "    for df in dfs:\n",
    "        if row in df.index and col in df.columns:\n",
    "            vals.append(df.loc[row, col])\n",
    "    return vals\n",
    "\n",
    "def _metrics_frame(man, group_name, level, except_metrics, except_gen):\n",
    "    \"\"\"Tabulate one UEManager's metrics: estimators x (group, gen metric, UE metric).\n",
    "\n",
    "    Only entries recorded at ``level`` are used; UE metrics in ``except_metrics``\n",
    "    and generation metrics in ``except_gen`` are dropped. Columns are sorted.\n",
    "    \"\"\"\n",
    "    estimators = [e for (l, e) in man.estimations.keys() if l == level]\n",
    "    values = defaultdict(dict)\n",
    "    for (l, e_name, gen_name, m_name), value in man.metrics.items():\n",
    "        if l == level and m_name not in except_metrics and gen_name not in except_gen:\n",
    "            values[group_name, gen_name, m_name][e_name] = value\n",
    "    gen_metrics = sorted(values)\n",
    "    data = {k: [values[k][e] for e in estimators] for k in gen_metrics}\n",
    "    return pd.DataFrame(data=data, index=estimators).reindex(columns=gen_metrics)\n",
    "\n",
    "def _column_border_styles(columns):\n",
    "    \"\"\"Left-border CSS for every column that starts a new (group, gen metric) block.\n",
    "\n",
    "    The border is 2px where the group (``label[0]``) changes and 1px where only\n",
    "    the generation metric (``label[1]``) changes.\n",
    "    \"\"\"\n",
    "    styles = {}\n",
    "    for prev, cur in zip(columns, columns[1:]):\n",
    "        # Compare group AND gen metric: a group boundary must get its border even\n",
    "        # when both sides share the same generation-metric name.\n",
    "        if cur[:2] == prev[:2]:\n",
    "            continue\n",
    "        css = 'border-left: {}px solid black'.format(1 if cur[0] == prev[0] else 2)\n",
    "        styles[cur] = [{'selector': 'th', 'props': css},\n",
    "                       {'selector': 'td', 'props': css}]\n",
    "    return styles\n",
    "\n",
    "def pretty_plot(dataset_name, man_files, except_metrics=(), except_gen=('BARTScoreSeq-rh',), level='sequence'):\n",
    "    \"\"\"Render UE benchmark results as a styled ``mean ± std`` table.\n",
    "\n",
    "    Args:\n",
    "        dataset_name: Group (e.g. dataset) label, or a list of labels.\n",
    "        man_files: Paths of saved ``UEManager`` results for the group, one per\n",
    "            seed; a list of such lists when ``dataset_name`` is a list.\n",
    "        except_metrics: UE metric names to leave out.\n",
    "        except_gen: Generation metric names to leave out.\n",
    "        level: Only estimations/metrics recorded at this level are shown.\n",
    "\n",
    "    Returns:\n",
    "        A pandas ``Styler`` whose rows are estimators and whose columns are\n",
    "        (group, generation metric, UE metric); each cell is ``mean ± std``\n",
    "        over the files of its group, multiplied by 100, colored by ``b_g``.\n",
    "    \"\"\"\n",
    "    if isinstance(dataset_name, str):\n",
    "        dataset_name = [dataset_name]\n",
    "        man_files = [man_files]\n",
    "    dfs = []\n",
    "    columns = []\n",
    "    for group_name, group_files in zip(dataset_name, man_files):\n",
    "        group_columns = set()\n",
    "        for f in group_files:\n",
    "            df = _metrics_frame(UEManager.load(f), group_name, level, except_metrics, except_gen)\n",
    "            group_columns.update(df.columns)\n",
    "            dfs.append(df)\n",
    "        print('Will measure variance using', len(group_files), 'seeds')\n",
    "        # Union over all files of the group: a metric absent from the last file\n",
    "        # is still shown, and an empty file list no longer adds None.\n",
    "        columns += sorted(group_columns)\n",
    "    assert len(dfs) > 0\n",
    "    index = dfs[0].index\n",
    "\n",
    "    mean, total = {}, {}\n",
    "    for col in columns:\n",
    "        for row in index:\n",
    "            vals = get_array(dfs, row, col)\n",
    "            avg, std = np.mean(vals), np.std(vals)\n",
    "            # b_g colors each cell by the negated mean.\n",
    "            mean[row, col] = -avg\n",
    "            total[row, col] = '{:.2f} ± {:.2f}'.format(avg.item() * 100, std.item() * 100)\n",
    "\n",
    "    header = pd.MultiIndex.from_tuples(columns)\n",
    "    total_df = pd.DataFrame([[total[row, col] for col in columns] for row in index],\n",
    "                            index=index, columns=header)\n",
    "    mean_df = pd.DataFrame([[mean[row, col] for col in columns] for row in index],\n",
    "                           index=index, columns=header)\n",
    "\n",
    "    s = total_df.style.apply(functools.partial(b_g, A=mean_df, cmap='Reds'), axis=0)\n",
    "    s.set_table_styles([{  # for row hover use <tr> instead of <td>\n",
    "        'selector': 'td:hover',\n",
    "        'props': [('background-color', '#ffffb3')]\n",
    "    }, {\n",
    "        'selector': '.index_name',\n",
    "        'props': 'font-style: italic; color: darkgrey; font-weight:normal;'\n",
    "    }])\n",
    "    s.set_table_styles(_column_border_styles(columns), overwrite=False, axis=0)\n",
    "    return s"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "31c03154",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Render the benchmark table. The inputs are UEManager results written by the\n",
    "# scripts/polygraph_eval benchmark, one per random seed (seeds 1-9); several\n",
    "# seeds are needed to report the spread (± std) of every metric.\n",
    "pretty_plot(\n",
    "    'TriviaQA, Dolly3b',\n",
    "    [f'./workdir/output_seed{seed}' for seed in range(1, 10)],\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2a1bfefa",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
