{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "import os"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
     "### Main Experiment Table"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 100,
   "metadata": {},
   "outputs": [],
   "source": [
    "aug_types = [(0, 0.0), (2, 2.0), (4, 2.0), (5, 2.0), (6, 2.0), (7, 2.0), (10, 2.0), (1, 0.1), (1, 0.2)]\n",
    "index_names = ['Single', 'Single+Concat (2x)', 'Single+ConcatRel (2x)', 'w/ Anaphora', 'w/ Ellipsis', 'w/ Mixed', 'w/ ChatGPT', '90%Single+10%Multi', '80%Single+20%Multi']\n",
    "column_names = ['Multi JGA', 'Multi TA', 'Multi CDTA', 'Cross JGA', 'Single JGA']\n",
    "datasets = ['multiwoz21', 'sgd/group0', 'sgd/group1', 'sgd/group2']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 101,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead tr th {\n",
       "        text-align: left;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr>\n",
       "      <th></th>\n",
       "      <th colspan=\"5\" halign=\"left\">multiwoz21</th>\n",
       "      <th colspan=\"5\" halign=\"left\">sgd/group0</th>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th></th>\n",
       "      <th>Multi JGA</th>\n",
       "      <th>Multi TA</th>\n",
       "      <th>Multi CDTA</th>\n",
       "      <th>Cross JGA</th>\n",
       "      <th>Single JGA</th>\n",
       "      <th>Multi JGA</th>\n",
       "      <th>Multi TA</th>\n",
       "      <th>Multi CDTA</th>\n",
       "      <th>Cross JGA</th>\n",
       "      <th>Single JGA</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>Single</th>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>Single+Concat (2x)</th>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>Single+ConcatRel (2x)</th>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>w/ Anaphora</th>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>w/ Ellipsis</th>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>w/ Mixed</th>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>w/ ChatGPT</th>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>90%Single+10%Multi</th>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>80%Single+20%Multi</th>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "                      multiwoz21                                           \\\n",
       "                       Multi JGA Multi TA Multi CDTA Cross JGA Single JGA   \n",
       "Single                       NaN      NaN        NaN       NaN        NaN   \n",
       "Single+Concat (2x)           NaN      NaN        NaN       NaN        NaN   \n",
       "Single+ConcatRel (2x)        NaN      NaN        NaN       NaN        NaN   \n",
       "w/ Anaphora                  NaN      NaN        NaN       NaN        NaN   \n",
       "w/ Ellipsis                  NaN      NaN        NaN       NaN        NaN   \n",
       "w/ Mixed                     NaN      NaN        NaN       NaN        NaN   \n",
       "w/ ChatGPT                   NaN      NaN        NaN       NaN        NaN   \n",
       "90%Single+10%Multi           NaN      NaN        NaN       NaN        NaN   \n",
       "80%Single+20%Multi           NaN      NaN        NaN       NaN        NaN   \n",
       "\n",
       "                      sgd/group0                                           \n",
       "                       Multi JGA Multi TA Multi CDTA Cross JGA Single JGA  \n",
       "Single                       NaN      NaN        NaN       NaN        NaN  \n",
       "Single+Concat (2x)           NaN      NaN        NaN       NaN        NaN  \n",
       "Single+ConcatRel (2x)        NaN      NaN        NaN       NaN        NaN  \n",
       "w/ Anaphora                  NaN      NaN        NaN       NaN        NaN  \n",
       "w/ Ellipsis                  NaN      NaN        NaN       NaN        NaN  \n",
       "w/ Mixed                     NaN      NaN        NaN       NaN        NaN  \n",
       "w/ ChatGPT                   NaN      NaN        NaN       NaN        NaN  \n",
       "90%Single+10%Multi           NaN      NaN        NaN       NaN        NaN  \n",
       "80%Single+20%Multi           NaN      NaN        NaN       NaN        NaN  "
      ]
     },
     "execution_count": 101,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "main_exp_table = pd.DataFrame(index=index_names, columns=pd.MultiIndex.from_product([datasets[:2], column_names]))\n",
    "main_exp_table"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 102,
   "metadata": {},
   "outputs": [],
   "source": [
    "table = main_exp_table\n",
    "for i, (aug_type, aug_times) in enumerate(aug_types):\n",
    "    index_name = index_names[i]\n",
    "    for dataset in datasets[:2]:\n",
    "        res_dir = f'output/{dataset}/t5-large_model0_context100_aug{aug_type}_x{aug_times}'\n",
    "        \n",
    "        res_file = f'{res_dir}/test_multi_domain_result.md'\n",
    "        df = pd.read_table(res_file, sep='|', index_col=1).dropna(axis=1, how='all').iloc[1:]\n",
    "        df = df.apply(lambda x: x.str.strip())\n",
    "        df = df.rename(columns=str.strip).rename(index=str.strip)\n",
    "        table.loc[index_name, (dataset, 'Multi JGA')] = df.loc['all', 'JGA']\n",
    "        table.loc[index_name, (dataset, 'Multi TA')] = df.loc['all', 'TA']\n",
    "        table.loc[index_name, (dataset, 'Multi CDTA')] = df.loc['all', 'CDTA']\n",
    "        table.loc[index_name, (dataset, 'Cross JGA')] = df.loc['cross-domain', 'JGA']\n",
    "\n",
    "        res_file = f'{res_dir}/test_single_domain_result.md'\n",
    "        df = pd.read_table(res_file, sep='|', index_col=1).dropna(axis=1, how='all').iloc[1:]\n",
    "        df = df.apply(lambda x: x.str.strip())\n",
    "        df = df.rename(columns=str.strip).rename(index=str.strip)\n",
    "        table.loc[index_name, (dataset, 'Single JGA')] = df.loc['all', 'JGA']\n",
    "main_exp_table = main_exp_table.applymap(lambda x: round(pd.to_numeric(x, errors='coerce') * 100, 1))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 103,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead tr th {\n",
       "        text-align: left;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr>\n",
       "      <th></th>\n",
       "      <th colspan=\"5\" halign=\"left\">multiwoz21</th>\n",
       "      <th colspan=\"5\" halign=\"left\">sgd/group0</th>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th></th>\n",
       "      <th>Multi JGA</th>\n",
       "      <th>Multi TA</th>\n",
       "      <th>Multi CDTA</th>\n",
       "      <th>Cross JGA</th>\n",
       "      <th>Single JGA</th>\n",
       "      <th>Multi JGA</th>\n",
       "      <th>Multi TA</th>\n",
       "      <th>Multi CDTA</th>\n",
       "      <th>Cross JGA</th>\n",
       "      <th>Single JGA</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>Single</th>\n",
       "      <td>28.4</td>\n",
       "      <td>49.9</td>\n",
       "      <td>0.9</td>\n",
       "      <td>24.7</td>\n",
       "      <td>65.7</td>\n",
       "      <td>34.3</td>\n",
       "      <td>62.4</td>\n",
       "      <td>1.9</td>\n",
       "      <td>34.7</td>\n",
       "      <td>88.9</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>Single+Concat (2x)</th>\n",
       "      <td>44.2</td>\n",
       "      <td>66.6</td>\n",
       "      <td>9.1</td>\n",
       "      <td>32.7</td>\n",
       "      <td>64.2</td>\n",
       "      <td>46.3</td>\n",
       "      <td>69.8</td>\n",
       "      <td>1.8</td>\n",
       "      <td>45.0</td>\n",
       "      <td>89.2</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>Single+ConcatRel (2x)</th>\n",
       "      <td>43.2</td>\n",
       "      <td>65.4</td>\n",
       "      <td>21.2</td>\n",
       "      <td>33.6</td>\n",
       "      <td>66.2</td>\n",
       "      <td>47.3</td>\n",
       "      <td>71.3</td>\n",
       "      <td>2.1</td>\n",
       "      <td>46.8</td>\n",
       "      <td>87.9</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>w/ Anaphora</th>\n",
       "      <td>42.0</td>\n",
       "      <td>63.9</td>\n",
       "      <td>24.0</td>\n",
       "      <td>32.9</td>\n",
       "      <td>66.7</td>\n",
       "      <td>51.7</td>\n",
       "      <td>74.1</td>\n",
       "      <td>23.5</td>\n",
       "      <td>52.5</td>\n",
       "      <td>90.1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>w/ Ellipsis</th>\n",
       "      <td>42.9</td>\n",
       "      <td>64.5</td>\n",
       "      <td>20.3</td>\n",
       "      <td>33.2</td>\n",
       "      <td>66.2</td>\n",
       "      <td>50.9</td>\n",
       "      <td>73.4</td>\n",
       "      <td>12.9</td>\n",
       "      <td>51.8</td>\n",
       "      <td>88.1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>w/ Mixed</th>\n",
       "      <td>42.6</td>\n",
       "      <td>64.1</td>\n",
       "      <td>25.3</td>\n",
       "      <td>34.1</td>\n",
       "      <td>65.5</td>\n",
       "      <td>50.2</td>\n",
       "      <td>73.7</td>\n",
       "      <td>16.1</td>\n",
       "      <td>50.5</td>\n",
       "      <td>88.9</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>w/ ChatGPT</th>\n",
       "      <td>43.1</td>\n",
       "      <td>64.9</td>\n",
       "      <td>25.0</td>\n",
       "      <td>33.4</td>\n",
       "      <td>67.7</td>\n",
       "      <td>51.3</td>\n",
       "      <td>74.3</td>\n",
       "      <td>26.7</td>\n",
       "      <td>52.4</td>\n",
       "      <td>88.1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>90%Single+10%Multi</th>\n",
       "      <td>41.9</td>\n",
       "      <td>63.6</td>\n",
       "      <td>20.8</td>\n",
       "      <td>31.9</td>\n",
       "      <td>64.9</td>\n",
       "      <td>59.2</td>\n",
       "      <td>76.8</td>\n",
       "      <td>53.1</td>\n",
       "      <td>61.3</td>\n",
       "      <td>92.1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>80%Single+20%Multi</th>\n",
       "      <td>44.2</td>\n",
       "      <td>66.8</td>\n",
       "      <td>35.2</td>\n",
       "      <td>35.8</td>\n",
       "      <td>67.1</td>\n",
       "      <td>65.9</td>\n",
       "      <td>82.9</td>\n",
       "      <td>64.5</td>\n",
       "      <td>67.2</td>\n",
       "      <td>89.8</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "                      multiwoz21                                           \\\n",
       "                       Multi JGA Multi TA Multi CDTA Cross JGA Single JGA   \n",
       "Single                      28.4     49.9        0.9      24.7       65.7   \n",
       "Single+Concat (2x)          44.2     66.6        9.1      32.7       64.2   \n",
       "Single+ConcatRel (2x)       43.2     65.4       21.2      33.6       66.2   \n",
       "w/ Anaphora                 42.0     63.9       24.0      32.9       66.7   \n",
       "w/ Ellipsis                 42.9     64.5       20.3      33.2       66.2   \n",
       "w/ Mixed                    42.6     64.1       25.3      34.1       65.5   \n",
       "w/ ChatGPT                  43.1     64.9       25.0      33.4       67.7   \n",
       "90%Single+10%Multi          41.9     63.6       20.8      31.9       64.9   \n",
       "80%Single+20%Multi          44.2     66.8       35.2      35.8       67.1   \n",
       "\n",
       "                      sgd/group0                                           \n",
       "                       Multi JGA Multi TA Multi CDTA Cross JGA Single JGA  \n",
       "Single                      34.3     62.4        1.9      34.7       88.9  \n",
       "Single+Concat (2x)          46.3     69.8        1.8      45.0       89.2  \n",
       "Single+ConcatRel (2x)       47.3     71.3        2.1      46.8       87.9  \n",
       "w/ Anaphora                 51.7     74.1       23.5      52.5       90.1  \n",
       "w/ Ellipsis                 50.9     73.4       12.9      51.8       88.1  \n",
       "w/ Mixed                    50.2     73.7       16.1      50.5       88.9  \n",
       "w/ ChatGPT                  51.3     74.3       26.7      52.4       88.1  \n",
       "90%Single+10%Multi          59.2     76.8       53.1      61.3       92.1  \n",
       "80%Single+20%Multi          65.9     82.9       64.5      67.2       89.8  "
      ]
     },
     "execution_count": 103,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
    "source": [
     "# Metrics in percent, rounded to one decimal place.\n",
     "main_exp_table"
    ]
  },
  {
   "cell_type": "code",
   "execution_count": 104,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\\begin{tabular}{lrrrrrrrrrr}\n",
      "\\toprule\n",
      "{} & \\multicolumn{5}{l}{multiwoz21} & \\multicolumn{5}{l}{sgd/group0} \\\\\n",
      "{} &  Multi JGA & Multi TA & Multi CDTA & Cross JGA & Single JGA &  Multi JGA & Multi TA & Multi CDTA & Cross JGA & Single JGA \\\\\n",
      "\\midrule\n",
      "Single                &       28.4 &     49.9 &        0.9 &      24.7 &       65.7 &       34.3 &     62.4 &        1.9 &      34.7 &       88.9 \\\\\n",
      "Single+Concat (2x)    &       44.2 &     66.6 &        9.1 &      32.7 &       64.2 &       46.3 &     69.8 &        1.8 &      45.0 &       89.2 \\\\\n",
      "Single+ConcatRel (2x) &       43.2 &     65.4 &       21.2 &      33.6 &       66.2 &       47.3 &     71.3 &        2.1 &      46.8 &       87.9 \\\\\n",
      "w/ Anaphora           &       42.0 &     63.9 &       24.0 &      32.9 &       66.7 &       51.7 &     74.1 &       23.5 &      52.5 &       90.1 \\\\\n",
      "w/ Ellipsis           &       42.9 &     64.5 &       20.3 &      33.2 &       66.2 &       50.9 &     73.4 &       12.9 &      51.8 &       88.1 \\\\\n",
      "w/ Mixed              &       42.6 &     64.1 &       25.3 &      34.1 &       65.5 &       50.2 &     73.7 &       16.1 &      50.5 &       88.9 \\\\\n",
      "w/ ChatGPT            &       43.1 &     64.9 &       25.0 &      33.4 &       67.7 &       51.3 &     74.3 &       26.7 &      52.4 &       88.1 \\\\\n",
      "90\\%Single+10\\%Multi    &       41.9 &     63.6 &       20.8 &      31.9 &       64.9 &       59.2 &     76.8 &       53.1 &      61.3 &       92.1 \\\\\n",
      "80\\%Single+20\\%Multi    &       44.2 &     66.8 &       35.2 &      35.8 &       67.1 &       65.9 &     82.9 &       64.5 &      67.2 &       89.8 \\\\\n",
      "\\bottomrule\n",
      "\\end{tabular}\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/tmp/ipykernel_88698/1754652341.py:1: FutureWarning: In future versions `DataFrame.to_latex` is expected to utilise the base implementation of `Styler.to_latex` for formatting and rendering. The arguments signature may therefore change. It is recommended instead to use `DataFrame.style.to_latex` which also contains additional functionality.\n",
      "  print(main_exp_table.to_latex())\n"
     ]
    }
   ],
   "source": [
    "print(main_exp_table.to_latex())"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Turn-level DST"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 105,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead tr th {\n",
       "        text-align: left;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr>\n",
       "      <th></th>\n",
       "      <th colspan=\"5\" halign=\"left\">multiwoz21</th>\n",
       "      <th colspan=\"5\" halign=\"left\">sgd/group0</th>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th></th>\n",
       "      <th>Multi JGA</th>\n",
       "      <th>Multi TA</th>\n",
       "      <th>Multi CDTA</th>\n",
       "      <th>Cross JGA</th>\n",
       "      <th>Single JGA</th>\n",
       "      <th>Multi JGA</th>\n",
       "      <th>Multi TA</th>\n",
       "      <th>Multi CDTA</th>\n",
       "      <th>Cross JGA</th>\n",
       "      <th>Single JGA</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>Single</th>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>Single+Concat (2x)</th>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>Single+ConcatRel (2x)</th>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>w/ Anaphora</th>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>w/ Ellipsis</th>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>w/ Mixed</th>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>w/ ChatGPT</th>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>90%Single+10%Multi</th>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>80%Single+20%Multi</th>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "                      multiwoz21                                           \\\n",
       "                       Multi JGA Multi TA Multi CDTA Cross JGA Single JGA   \n",
       "Single                       NaN      NaN        NaN       NaN        NaN   \n",
       "Single+Concat (2x)           NaN      NaN        NaN       NaN        NaN   \n",
       "Single+ConcatRel (2x)        NaN      NaN        NaN       NaN        NaN   \n",
       "w/ Anaphora                  NaN      NaN        NaN       NaN        NaN   \n",
       "w/ Ellipsis                  NaN      NaN        NaN       NaN        NaN   \n",
       "w/ Mixed                     NaN      NaN        NaN       NaN        NaN   \n",
       "w/ ChatGPT                   NaN      NaN        NaN       NaN        NaN   \n",
       "90%Single+10%Multi           NaN      NaN        NaN       NaN        NaN   \n",
       "80%Single+20%Multi           NaN      NaN        NaN       NaN        NaN   \n",
       "\n",
       "                      sgd/group0                                           \n",
       "                       Multi JGA Multi TA Multi CDTA Cross JGA Single JGA  \n",
       "Single                       NaN      NaN        NaN       NaN        NaN  \n",
       "Single+Concat (2x)           NaN      NaN        NaN       NaN        NaN  \n",
       "Single+ConcatRel (2x)        NaN      NaN        NaN       NaN        NaN  \n",
       "w/ Anaphora                  NaN      NaN        NaN       NaN        NaN  \n",
       "w/ Ellipsis                  NaN      NaN        NaN       NaN        NaN  \n",
       "w/ Mixed                     NaN      NaN        NaN       NaN        NaN  \n",
       "w/ ChatGPT                   NaN      NaN        NaN       NaN        NaN  \n",
       "90%Single+10%Multi           NaN      NaN        NaN       NaN        NaN  \n",
       "80%Single+20%Multi           NaN      NaN        NaN       NaN        NaN  "
      ]
     },
     "execution_count": 105,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "turn_exp_table = pd.DataFrame(index=index_names, columns=pd.MultiIndex.from_product([datasets[:2], column_names]))\n",
    "turn_exp_table"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 106,
   "metadata": {},
   "outputs": [],
   "source": [
    "table = turn_exp_table\n",
    "for i, (aug_type, aug_times) in enumerate(aug_types):\n",
    "    index_name = index_names[i]\n",
    "    for dataset in datasets[:2]:\n",
    "        res_dir = f'output/{dataset}/t5-large_model1_context4_aug{aug_type}_x{aug_times}'\n",
    "\n",
    "        res_file = f'{res_dir}/test_multi_domain_result.md'\n",
    "        df = pd.read_table(res_file, sep='|', index_col=1).dropna(axis=1, how='all').iloc[1:]\n",
    "        df = df.apply(lambda x: x.str.strip())\n",
    "        df = df.rename(columns=str.strip).rename(index=str.strip)\n",
    "        table.loc[index_name, (dataset, 'Multi JGA')] = df.loc['all', 'JGA']\n",
    "        table.loc[index_name, (dataset, 'Multi TA')] = df.loc['all', 'TA']\n",
    "        table.loc[index_name, (dataset, 'Multi CDTA')] = df.loc['all', 'CDTA']\n",
    "        table.loc[index_name, (dataset, 'Cross JGA')] = df.loc['cross-domain', 'JGA']\n",
    "\n",
    "        res_file = f'{res_dir}/test_single_domain_result.md'\n",
    "        df = pd.read_table(res_file, sep='|', index_col=1).dropna(axis=1, how='all').iloc[1:]\n",
    "        df = df.apply(lambda x: x.str.strip())\n",
    "        df = df.rename(columns=str.strip).rename(index=str.strip)\n",
    "        table.loc[index_name, (dataset, 'Single JGA')] = df.loc['all', 'JGA']\n",
    "turn_exp_table = turn_exp_table.applymap(lambda x: round(pd.to_numeric(x, errors='coerce') * 100, 1))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 107,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead tr th {\n",
       "        text-align: left;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr>\n",
       "      <th></th>\n",
       "      <th colspan=\"5\" halign=\"left\">multiwoz21</th>\n",
       "      <th colspan=\"5\" halign=\"left\">sgd/group0</th>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th></th>\n",
       "      <th>Multi JGA</th>\n",
       "      <th>Multi TA</th>\n",
       "      <th>Multi CDTA</th>\n",
       "      <th>Cross JGA</th>\n",
       "      <th>Single JGA</th>\n",
       "      <th>Multi JGA</th>\n",
       "      <th>Multi TA</th>\n",
       "      <th>Multi CDTA</th>\n",
       "      <th>Cross JGA</th>\n",
       "      <th>Single JGA</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>Single</th>\n",
       "      <td>27.7</td>\n",
       "      <td>59.7</td>\n",
       "      <td>1.1</td>\n",
       "      <td>23.8</td>\n",
       "      <td>63.8</td>\n",
       "      <td>33.0</td>\n",
       "      <td>67.6</td>\n",
       "      <td>0.8</td>\n",
       "      <td>33.4</td>\n",
       "      <td>83.1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>Single+Concat (2x)</th>\n",
       "      <td>39.5</td>\n",
       "      <td>74.1</td>\n",
       "      <td>5.6</td>\n",
       "      <td>28.9</td>\n",
       "      <td>63.6</td>\n",
       "      <td>36.3</td>\n",
       "      <td>78.3</td>\n",
       "      <td>1.3</td>\n",
       "      <td>35.2</td>\n",
       "      <td>83.3</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>Single+ConcatRel (2x)</th>\n",
       "      <td>39.1</td>\n",
       "      <td>73.9</td>\n",
       "      <td>19.9</td>\n",
       "      <td>29.5</td>\n",
       "      <td>63.0</td>\n",
       "      <td>35.2</td>\n",
       "      <td>77.9</td>\n",
       "      <td>1.2</td>\n",
       "      <td>34.0</td>\n",
       "      <td>79.5</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>w/ Anaphora</th>\n",
       "      <td>37.5</td>\n",
       "      <td>71.9</td>\n",
       "      <td>24.9</td>\n",
       "      <td>29.3</td>\n",
       "      <td>63.3</td>\n",
       "      <td>38.3</td>\n",
       "      <td>78.6</td>\n",
       "      <td>15.5</td>\n",
       "      <td>38.2</td>\n",
       "      <td>81.6</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>w/ Ellipsis</th>\n",
       "      <td>37.5</td>\n",
       "      <td>72.1</td>\n",
       "      <td>15.1</td>\n",
       "      <td>28.6</td>\n",
       "      <td>64.1</td>\n",
       "      <td>35.7</td>\n",
       "      <td>76.5</td>\n",
       "      <td>3.5</td>\n",
       "      <td>35.4</td>\n",
       "      <td>82.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>w/ Mixed</th>\n",
       "      <td>38.5</td>\n",
       "      <td>73.1</td>\n",
       "      <td>25.5</td>\n",
       "      <td>30.8</td>\n",
       "      <td>61.0</td>\n",
       "      <td>37.9</td>\n",
       "      <td>78.8</td>\n",
       "      <td>13.5</td>\n",
       "      <td>38.1</td>\n",
       "      <td>82.1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>w/ ChatGPT</th>\n",
       "      <td>39.1</td>\n",
       "      <td>73.9</td>\n",
       "      <td>29.1</td>\n",
       "      <td>30.9</td>\n",
       "      <td>62.6</td>\n",
       "      <td>39.8</td>\n",
       "      <td>78.8</td>\n",
       "      <td>19.5</td>\n",
       "      <td>40.7</td>\n",
       "      <td>81.5</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>90%Single+10%Multi</th>\n",
       "      <td>41.1</td>\n",
       "      <td>74.1</td>\n",
       "      <td>25.0</td>\n",
       "      <td>31.8</td>\n",
       "      <td>62.8</td>\n",
       "      <td>47.9</td>\n",
       "      <td>84.8</td>\n",
       "      <td>37.8</td>\n",
       "      <td>48.6</td>\n",
       "      <td>85.6</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>80%Single+20%Multi</th>\n",
       "      <td>42.0</td>\n",
       "      <td>75.1</td>\n",
       "      <td>35.0</td>\n",
       "      <td>33.0</td>\n",
       "      <td>62.2</td>\n",
       "      <td>54.7</td>\n",
       "      <td>87.9</td>\n",
       "      <td>56.3</td>\n",
       "      <td>56.0</td>\n",
       "      <td>81.6</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "                      multiwoz21                                           \\\n",
       "                       Multi JGA Multi TA Multi CDTA Cross JGA Single JGA   \n",
       "Single                      27.7     59.7        1.1      23.8       63.8   \n",
       "Single+Concat (2x)          39.5     74.1        5.6      28.9       63.6   \n",
       "Single+ConcatRel (2x)       39.1     73.9       19.9      29.5       63.0   \n",
       "w/ Anaphora                 37.5     71.9       24.9      29.3       63.3   \n",
       "w/ Ellipsis                 37.5     72.1       15.1      28.6       64.1   \n",
       "w/ Mixed                    38.5     73.1       25.5      30.8       61.0   \n",
       "w/ ChatGPT                  39.1     73.9       29.1      30.9       62.6   \n",
       "90%Single+10%Multi          41.1     74.1       25.0      31.8       62.8   \n",
       "80%Single+20%Multi          42.0     75.1       35.0      33.0       62.2   \n",
       "\n",
       "                      sgd/group0                                           \n",
       "                       Multi JGA Multi TA Multi CDTA Cross JGA Single JGA  \n",
       "Single                      33.0     67.6        0.8      33.4       83.1  \n",
       "Single+Concat (2x)          36.3     78.3        1.3      35.2       83.3  \n",
       "Single+ConcatRel (2x)       35.2     77.9        1.2      34.0       79.5  \n",
       "w/ Anaphora                 38.3     78.6       15.5      38.2       81.6  \n",
       "w/ Ellipsis                 35.7     76.5        3.5      35.4       82.0  \n",
       "w/ Mixed                    37.9     78.8       13.5      38.1       82.1  \n",
       "w/ ChatGPT                  39.8     78.8       19.5      40.7       81.5  \n",
       "90%Single+10%Multi          47.9     84.8       37.8      48.6       85.6  \n",
       "80%Single+20%Multi          54.7     87.9       56.3      56.0       81.6  "
      ]
     },
     "execution_count": 107,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Show turn_exp_table: rows follow index_names, columns are (dataset, metric).\n",
    "turn_exp_table"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 108,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\\begin{tabular}{lrrrrrrrrrr}\n",
      "\\toprule\n",
      "{} & \\multicolumn{5}{l}{multiwoz21} & \\multicolumn{5}{l}{sgd/group0} \\\\\n",
      "{} &  Multi JGA & Multi TA & Multi CDTA & Cross JGA & Single JGA &  Multi JGA & Multi TA & Multi CDTA & Cross JGA & Single JGA \\\\\n",
      "\\midrule\n",
      "Single                &       27.7 &     59.7 &        1.1 &      23.8 &       63.8 &       33.0 &     67.6 &        0.8 &      33.4 &       83.1 \\\\\n",
      "Single+Concat (2x)    &       39.5 &     74.1 &        5.6 &      28.9 &       63.6 &       36.3 &     78.3 &        1.3 &      35.2 &       83.3 \\\\\n",
      "Single+ConcatRel (2x) &       39.1 &     73.9 &       19.9 &      29.5 &       63.0 &       35.2 &     77.9 &        1.2 &      34.0 &       79.5 \\\\\n",
      "w/ Anaphora           &       37.5 &     71.9 &       24.9 &      29.3 &       63.3 &       38.3 &     78.6 &       15.5 &      38.2 &       81.6 \\\\\n",
      "w/ Ellipsis           &       37.5 &     72.1 &       15.1 &      28.6 &       64.1 &       35.7 &     76.5 &        3.5 &      35.4 &       82.0 \\\\\n",
      "w/ Mixed              &       38.5 &     73.1 &       25.5 &      30.8 &       61.0 &       37.9 &     78.8 &       13.5 &      38.1 &       82.1 \\\\\n",
      "w/ ChatGPT            &       39.1 &     73.9 &       29.1 &      30.9 &       62.6 &       39.8 &     78.8 &       19.5 &      40.7 &       81.5 \\\\\n",
      "90\\%Single+10\\%Multi    &       41.1 &     74.1 &       25.0 &      31.8 &       62.8 &       47.9 &     84.8 &       37.8 &      48.6 &       85.6 \\\\\n",
      "80\\%Single+20\\%Multi    &       42.0 &     75.1 &       35.0 &      33.0 &       62.2 &       54.7 &     87.9 &       56.3 &      56.0 &       81.6 \\\\\n",
      "\\bottomrule\n",
      "\\end{tabular}\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/tmp/ipykernel_88698/4272196636.py:1: FutureWarning: In future versions `DataFrame.to_latex` is expected to utilise the base implementation of `Styler.to_latex` for formatting and rendering. The arguments signature may therefore change. It is recommended instead to use `DataFrame.style.to_latex` which also contains additional functionality.\n",
      "  print(turn_exp_table.to_latex())\n"
     ]
    }
   ],
   "source": [
    "# Print turn_exp_table as a LaTeX tabular; to_latex escapes '%' in the row labels.\n",
    "print(turn_exp_table.to_latex())"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### CoQR"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 109,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead tr th {\n",
       "        text-align: left;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th colspan=\"5\" halign=\"left\">multiwoz21</th>\n",
       "      <th colspan=\"5\" halign=\"left\">sgd/group0</th>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th>Multi JGA</th>\n",
       "      <th>Multi TA</th>\n",
       "      <th>Multi CDTA</th>\n",
       "      <th>Cross JGA</th>\n",
       "      <th>Single JGA</th>\n",
       "      <th>Multi JGA</th>\n",
       "      <th>Multi TA</th>\n",
       "      <th>Multi CDTA</th>\n",
       "      <th>Cross JGA</th>\n",
       "      <th>Single JGA</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th rowspan=\"4\" valign=\"top\">dialog</th>\n",
       "      <th>Single+Concat (2x)</th>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>w/ CANARD</th>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>w/ CANARD+MIX</th>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>w/ ChatGPT</th>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th rowspan=\"4\" valign=\"top\">turn</th>\n",
       "      <th>Single+Concat (2x)</th>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>w/ CANARD</th>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>w/ CANARD+MIX</th>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>w/ ChatGPT</th>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "                          multiwoz21                                           \\\n",
       "                           Multi JGA Multi TA Multi CDTA Cross JGA Single JGA   \n",
       "dialog Single+Concat (2x)        NaN      NaN        NaN       NaN        NaN   \n",
       "       w/ CANARD                 NaN      NaN        NaN       NaN        NaN   \n",
       "       w/ CANARD+MIX             NaN      NaN        NaN       NaN        NaN   \n",
       "       w/ ChatGPT                NaN      NaN        NaN       NaN        NaN   \n",
       "turn   Single+Concat (2x)        NaN      NaN        NaN       NaN        NaN   \n",
       "       w/ CANARD                 NaN      NaN        NaN       NaN        NaN   \n",
       "       w/ CANARD+MIX             NaN      NaN        NaN       NaN        NaN   \n",
       "       w/ ChatGPT                NaN      NaN        NaN       NaN        NaN   \n",
       "\n",
       "                          sgd/group0                                           \n",
       "                           Multi JGA Multi TA Multi CDTA Cross JGA Single JGA  \n",
       "dialog Single+Concat (2x)        NaN      NaN        NaN       NaN        NaN  \n",
       "       w/ CANARD                 NaN      NaN        NaN       NaN        NaN  \n",
       "       w/ CANARD+MIX             NaN      NaN        NaN       NaN        NaN  \n",
       "       w/ ChatGPT                NaN      NaN        NaN       NaN        NaN  \n",
       "turn   Single+Concat (2x)        NaN      NaN        NaN       NaN        NaN  \n",
       "       w/ CANARD                 NaN      NaN        NaN       NaN        NaN  \n",
       "       w/ CANARD+MIX             NaN      NaN        NaN       NaN        NaN  \n",
       "       w/ ChatGPT                NaN      NaN        NaN       NaN        NaN  "
      ]
     },
     "execution_count": 109,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# All-NaN placeholder: rows = ('dialog' | 'turn', CoQR test variant), columns = (dataset, metric).\n",
    "coqr_test_table = pd.DataFrame(\n",
    "    columns=pd.MultiIndex.from_product([datasets[:2], column_names]),\n",
    "    index=pd.MultiIndex.from_product([\n",
    "        ['dialog', 'turn'],\n",
    "        ['Single+Concat (2x)', 'w/ CANARD', 'w/ CANARD+MIX', 'w/ ChatGPT'],\n",
    "    ]),\n",
    ")\n",
    "coqr_test_table"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 110,
   "metadata": {},
   "outputs": [],
   "source": [
    "def read_result_md(res_file):\n",
    "    \"\"\"Read a markdown results table into a DataFrame indexed by its row labels.\n",
    "\n",
    "    The file is split on '|', so the row label lands in column 1 and becomes the index.\n",
    "    Columns that are entirely empty (e.g. from the leading/trailing pipes) are dropped,\n",
    "    the first row (presumably the `|---|` separator) is skipped, and all cells and\n",
    "    labels are whitespace-stripped. Values stay strings; they are converted after the loop.\n",
    "    \"\"\"\n",
    "    df = pd.read_table(res_file, sep='|', index_col=1).dropna(axis=1, how='all').iloc[1:]\n",
    "    df = df.apply(lambda x: x.str.strip())\n",
    "    return df.rename(columns=str.strip).rename(index=str.strip)\n",
    "\n",
    "\n",
    "# Result directories for the 'dialog' and 'turn' row groups, in index order.\n",
    "coqr_res_subdirs = ['t5-large_model0_context100_aug2_x2.0', 't5-large_model1_context4_aug2_x2.0']\n",
    "# Test-file suffixes, in the same order as the second index level of coqr_test_table.\n",
    "coqr_suffixes = ['', '_coqr_canard', '_coqr_canard_mix_x2.0', '_coqr_chatgpt']\n",
    "# Table column -> (row label, metric column) in test_multi_domain*_result.md.\n",
    "multi_metric_cells = {\n",
    "    'Multi JGA': ('all', 'JGA'),\n",
    "    'Multi TA': ('all', 'TA'),\n",
    "    'Multi CDTA': ('all', 'CDTA'),\n",
    "    'Cross JGA': ('cross-domain', 'JGA'),\n",
    "}\n",
    "\n",
    "table = coqr_test_table\n",
    "# Rows are addressed by position (j * len(coqr_suffixes) + i), so the layout must match exactly.\n",
    "assert len(table.index) == len(coqr_res_subdirs) * len(coqr_suffixes), 'row layout out of sync with result files'\n",
    "for dataset in datasets[:2]:\n",
    "    for j, res_subdir in enumerate(coqr_res_subdirs):\n",
    "        res_dir = f'output/{dataset}/{res_subdir}'\n",
    "        for i, suffix in enumerate(coqr_suffixes):\n",
    "            index_name = table.index[j * len(coqr_suffixes) + i]\n",
    "\n",
    "            res_file = f'{res_dir}/test_multi_domain{suffix}_result.md'\n",
    "            df = read_result_md(res_file)\n",
    "            for column, (row, metric) in multi_metric_cells.items():\n",
    "                table.loc[index_name, (dataset, column)] = df.loc[row, metric]\n",
    "\n",
    "            res_file = f'{res_dir}/test_single_domain{suffix}_result.md'\n",
    "            df = read_result_md(res_file)\n",
    "            table.loc[index_name, (dataset, 'Single JGA')] = df.loc['all', 'JGA']\n",
    "\n",
    "# Metric strings (presumably fractions) -> percentages with one decimal; unparsable values become NaN.\n",
    "coqr_test_table = coqr_test_table.applymap(lambda x: round(pd.to_numeric(x, errors='coerce') * 100, 1))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 111,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead tr th {\n",
       "        text-align: left;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th colspan=\"5\" halign=\"left\">multiwoz21</th>\n",
       "      <th colspan=\"5\" halign=\"left\">sgd/group0</th>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th>Multi JGA</th>\n",
       "      <th>Multi TA</th>\n",
       "      <th>Multi CDTA</th>\n",
       "      <th>Cross JGA</th>\n",
       "      <th>Single JGA</th>\n",
       "      <th>Multi JGA</th>\n",
       "      <th>Multi TA</th>\n",
       "      <th>Multi CDTA</th>\n",
       "      <th>Cross JGA</th>\n",
       "      <th>Single JGA</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th rowspan=\"4\" valign=\"top\">dialog</th>\n",
       "      <th>Single+Concat (2x)</th>\n",
       "      <td>44.2</td>\n",
       "      <td>66.6</td>\n",
       "      <td>9.1</td>\n",
       "      <td>32.7</td>\n",
       "      <td>64.2</td>\n",
       "      <td>46.3</td>\n",
       "      <td>69.8</td>\n",
       "      <td>1.8</td>\n",
       "      <td>45.0</td>\n",
       "      <td>89.2</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>w/ CANARD</th>\n",
       "      <td>38.4</td>\n",
       "      <td>61.4</td>\n",
       "      <td>8.4</td>\n",
       "      <td>29.8</td>\n",
       "      <td>61.3</td>\n",
       "      <td>36.2</td>\n",
       "      <td>60.9</td>\n",
       "      <td>11.3</td>\n",
       "      <td>35.9</td>\n",
       "      <td>74.8</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>w/ CANARD+MIX</th>\n",
       "      <td>44.0</td>\n",
       "      <td>66.4</td>\n",
       "      <td>16.5</td>\n",
       "      <td>33.8</td>\n",
       "      <td>64.1</td>\n",
       "      <td>48.9</td>\n",
       "      <td>71.6</td>\n",
       "      <td>18.1</td>\n",
       "      <td>48.5</td>\n",
       "      <td>88.8</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>w/ ChatGPT</th>\n",
       "      <td>40.3</td>\n",
       "      <td>62.8</td>\n",
       "      <td>16.5</td>\n",
       "      <td>30.6</td>\n",
       "      <td>62.1</td>\n",
       "      <td>40.3</td>\n",
       "      <td>65.9</td>\n",
       "      <td>21.9</td>\n",
       "      <td>39.9</td>\n",
       "      <td>73.5</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th rowspan=\"4\" valign=\"top\">turn</th>\n",
       "      <th>Single+Concat (2x)</th>\n",
       "      <td>39.5</td>\n",
       "      <td>74.1</td>\n",
       "      <td>5.6</td>\n",
       "      <td>28.9</td>\n",
       "      <td>63.6</td>\n",
       "      <td>36.3</td>\n",
       "      <td>78.3</td>\n",
       "      <td>1.3</td>\n",
       "      <td>35.2</td>\n",
       "      <td>83.3</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>w/ CANARD</th>\n",
       "      <td>34.5</td>\n",
       "      <td>70.3</td>\n",
       "      <td>6.2</td>\n",
       "      <td>26.6</td>\n",
       "      <td>59.2</td>\n",
       "      <td>30.0</td>\n",
       "      <td>71.0</td>\n",
       "      <td>13.7</td>\n",
       "      <td>30.4</td>\n",
       "      <td>64.6</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>w/ CANARD+MIX</th>\n",
       "      <td>39.4</td>\n",
       "      <td>73.7</td>\n",
       "      <td>16.6</td>\n",
       "      <td>30.0</td>\n",
       "      <td>63.4</td>\n",
       "      <td>39.5</td>\n",
       "      <td>78.7</td>\n",
       "      <td>18.1</td>\n",
       "      <td>39.5</td>\n",
       "      <td>82.2</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>w/ ChatGPT</th>\n",
       "      <td>36.3</td>\n",
       "      <td>70.6</td>\n",
       "      <td>17.0</td>\n",
       "      <td>27.4</td>\n",
       "      <td>61.1</td>\n",
       "      <td>32.3</td>\n",
       "      <td>72.7</td>\n",
       "      <td>22.7</td>\n",
       "      <td>32.1</td>\n",
       "      <td>66.2</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "                          multiwoz21                                           \\\n",
       "                           Multi JGA Multi TA Multi CDTA Cross JGA Single JGA   \n",
       "dialog Single+Concat (2x)       44.2     66.6        9.1      32.7       64.2   \n",
       "       w/ CANARD                38.4     61.4        8.4      29.8       61.3   \n",
       "       w/ CANARD+MIX            44.0     66.4       16.5      33.8       64.1   \n",
       "       w/ ChatGPT               40.3     62.8       16.5      30.6       62.1   \n",
       "turn   Single+Concat (2x)       39.5     74.1        5.6      28.9       63.6   \n",
       "       w/ CANARD                34.5     70.3        6.2      26.6       59.2   \n",
       "       w/ CANARD+MIX            39.4     73.7       16.6      30.0       63.4   \n",
       "       w/ ChatGPT               36.3     70.6       17.0      27.4       61.1   \n",
       "\n",
       "                          sgd/group0                                           \n",
       "                           Multi JGA Multi TA Multi CDTA Cross JGA Single JGA  \n",
       "dialog Single+Concat (2x)       46.3     69.8        1.8      45.0       89.2  \n",
       "       w/ CANARD                36.2     60.9       11.3      35.9       74.8  \n",
       "       w/ CANARD+MIX            48.9     71.6       18.1      48.5       88.8  \n",
       "       w/ ChatGPT               40.3     65.9       21.9      39.9       73.5  \n",
       "turn   Single+Concat (2x)       36.3     78.3        1.3      35.2       83.3  \n",
       "       w/ CANARD                30.0     71.0       13.7      30.4       64.6  \n",
       "       w/ CANARD+MIX            39.5     78.7       18.1      39.5       82.2  \n",
       "       w/ ChatGPT               32.3     72.7       22.7      32.1       66.2  "
      ]
     },
     "execution_count": 111,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Show the filled CoQR table (percentages, one decimal) with 'dialog' and 'turn' row groups.\n",
    "coqr_test_table"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 112,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\\begin{tabular}{lrrrrrrrrrr}\n",
      "\\toprule\n",
      "{} & \\multicolumn{5}{l}{multiwoz21} & \\multicolumn{5}{l}{sgd/group0} \\\\\n",
      "{} &  Multi JGA & Multi TA & Multi CDTA & Cross JGA & Single JGA &  Multi JGA & Multi TA & Multi CDTA & Cross JGA & Single JGA \\\\\n",
      "\\midrule\n",
      "Single+Concat (2x) &       44.2 &     66.6 &        9.1 &      32.7 &       64.2 &       46.3 &     69.8 &        1.8 &      45.0 &       89.2 \\\\\n",
      "w/ CANARD          &       38.4 &     61.4 &        8.4 &      29.8 &       61.3 &       36.2 &     60.9 &       11.3 &      35.9 &       74.8 \\\\\n",
      "w/ CANARD+MIX      &       44.0 &     66.4 &       16.5 &      33.8 &       64.1 &       48.9 &     71.6 &       18.1 &      48.5 &       88.8 \\\\\n",
      "w/ ChatGPT         &       40.3 &     62.8 &       16.5 &      30.6 &       62.1 &       40.3 &     65.9 &       21.9 &      39.9 &       73.5 \\\\\n",
      "\\bottomrule\n",
      "\\end{tabular}\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/tmp/ipykernel_88698/4289893829.py:1: FutureWarning: In future versions `DataFrame.to_latex` is expected to utilise the base implementation of `Styler.to_latex` for formatting and rendering. The arguments signature may therefore change. It is recommended instead to use `DataFrame.style.to_latex` which also contains additional functionality.\n",
      "  print(coqr_test_table.loc['dialog'].to_latex())\n"
     ]
    }
   ],
   "source": [
    "# LaTeX tabular for the 'dialog' row group only; .loc drops the outer index level.\n",
    "print(coqr_test_table.loc['dialog'].to_latex())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 113,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\\begin{tabular}{lrrrrrrrrrr}\n",
      "\\toprule\n",
      "{} & \\multicolumn{5}{l}{multiwoz21} & \\multicolumn{5}{l}{sgd/group0} \\\\\n",
      "{} &  Multi JGA & Multi TA & Multi CDTA & Cross JGA & Single JGA &  Multi JGA & Multi TA & Multi CDTA & Cross JGA & Single JGA \\\\\n",
      "\\midrule\n",
      "Single+Concat (2x) &       39.5 &     74.1 &        5.6 &      28.9 &       63.6 &       36.3 &     78.3 &        1.3 &      35.2 &       83.3 \\\\\n",
      "w/ CANARD          &       34.5 &     70.3 &        6.2 &      26.6 &       59.2 &       30.0 &     71.0 &       13.7 &      30.4 &       64.6 \\\\\n",
      "w/ CANARD+MIX      &       39.4 &     73.7 &       16.6 &      30.0 &       63.4 &       39.5 &     78.7 &       18.1 &      39.5 &       82.2 \\\\\n",
      "w/ ChatGPT         &       36.3 &     70.6 &       17.0 &      27.4 &       61.1 &       32.3 &     72.7 &       22.7 &      32.1 &       66.2 \\\\\n",
      "\\bottomrule\n",
      "\\end{tabular}\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/tmp/ipykernel_88698/2481759033.py:1: FutureWarning: In future versions `DataFrame.to_latex` is expected to utilise the base implementation of `Styler.to_latex` for formatting and rendering. The arguments signature may therefore change. It is recommended instead to use `DataFrame.style.to_latex` which also contains additional functionality.\n",
      "  print(coqr_test_table.loc['turn'].to_latex())\n"
     ]
    }
   ],
   "source": [
    "# LaTeX tabular for the 'turn' row group only; .loc drops the outer index level.\n",
    "print(coqr_test_table.loc['turn'].to_latex())"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Performance vs Aug Data Size & Model Size"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 61,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th>Multi JGA</th>\n",
       "      <th>Multi TA</th>\n",
       "      <th>Multi CDTA</th>\n",
       "      <th>Cross JGA</th>\n",
       "      <th>Single JGA</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th rowspan=\"5\" valign=\"top\">T5-small</th>\n",
       "      <th>1x</th>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2x</th>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4x</th>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>8x</th>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>10%Multi</th>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th rowspan=\"5\" valign=\"top\">T5-base</th>\n",
       "      <th>1x</th>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2x</th>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4x</th>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>8x</th>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>10%Multi</th>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th rowspan=\"5\" valign=\"top\">T5-large</th>\n",
       "      <th>1x</th>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2x</th>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4x</th>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>8x</th>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>10%Multi</th>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th rowspan=\"5\" valign=\"top\">aggregate</th>\n",
       "      <th>1x</th>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2x</th>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4x</th>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>8x</th>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>10%Multi</th>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "                   Multi JGA Multi TA Multi CDTA Cross JGA Single JGA\n",
       "T5-small  1x             NaN      NaN        NaN       NaN        NaN\n",
       "          2x             NaN      NaN        NaN       NaN        NaN\n",
       "          4x             NaN      NaN        NaN       NaN        NaN\n",
       "          8x             NaN      NaN        NaN       NaN        NaN\n",
       "          10%Multi       NaN      NaN        NaN       NaN        NaN\n",
       "T5-base   1x             NaN      NaN        NaN       NaN        NaN\n",
       "          2x             NaN      NaN        NaN       NaN        NaN\n",
       "          4x             NaN      NaN        NaN       NaN        NaN\n",
       "          8x             NaN      NaN        NaN       NaN        NaN\n",
       "          10%Multi       NaN      NaN        NaN       NaN        NaN\n",
       "T5-large  1x             NaN      NaN        NaN       NaN        NaN\n",
       "          2x             NaN      NaN        NaN       NaN        NaN\n",
       "          4x             NaN      NaN        NaN       NaN        NaN\n",
       "          8x             NaN      NaN        NaN       NaN        NaN\n",
       "          10%Multi       NaN      NaN        NaN       NaN        NaN\n",
       "aggregate 1x             NaN      NaN        NaN       NaN        NaN\n",
       "          2x             NaN      NaN        NaN       NaN        NaN\n",
       "          4x             NaN      NaN        NaN       NaN        NaN\n",
       "          8x             NaN      NaN        NaN       NaN        NaN\n",
       "          10%Multi       NaN      NaN        NaN       NaN        NaN"
      ]
     },
     "execution_count": 61,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Rows: model size x data-size setting. The 'aggregate' rows are filled later with the\n",
    "# small -> base -> large progression of each setting.\n",
    "aug_data_size_table = pd.DataFrame(\n",
    "    index=pd.MultiIndex.from_product([\n",
    "        ['T5-small', 'T5-base', 'T5-large', 'aggregate'],\n",
    "        ['1x', '2x', '4x', '8x', '10%Multi'],\n",
    "    ]),\n",
    "    columns=column_names,\n",
    ")\n",
    "aug_data_size_table"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 66,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Fill the per-model rows from the markdown result files under output/sgd/group0/,\n",
    "# storing every metric as an integer percentage. A missing result file leaves its\n",
    "# cells NaN instead of crashing the conversion.\n",
    "table = aug_data_size_table\n",
    "model_sizes = ['T5-small', 'T5-base', 'T5-large']\n",
    "# Row label -> run-directory suffix: the 'Nx' rows use aug type 7 ('w/ Mixed') scaled N times;\n",
    "# '10%Multi' uses aug type 1 with x0.1 (the '90%Single+10%Multi' setting).\n",
    "run_suffixes = {\n",
    "    '1x': 'aug7_x1.0',\n",
    "    '2x': 'aug7_x2.0',\n",
    "    '4x': 'aug7_x4.0',\n",
    "    '8x': 'aug7_x8.0',\n",
    "    '10%Multi': 'aug1_x0.1',\n",
    "}\n",
    "# Table column -> (row, column) in the multi-domain result markdown.\n",
    "multi_metrics = {\n",
    "    'Multi JGA': ('all', 'JGA'),\n",
    "    'Multi TA': ('all', 'TA'),\n",
    "    'Multi CDTA': ('all', 'CDTA'),\n",
    "    'Cross JGA': ('cross-domain', 'JGA'),\n",
    "}\n",
    "\n",
    "\n",
    "def read_result_md(path):\n",
    "    \"\"\"Parse a '|'-separated markdown result table into a DataFrame of stripped strings.\n",
    "\n",
    "    Rows are keyed by the first table column; the first data row (presumably the\n",
    "    markdown '---' separator line) is dropped.\n",
    "    \"\"\"\n",
    "    result = pd.read_table(path, sep='|', index_col=1).dropna(axis=1, how='all').iloc[1:]\n",
    "    result = result.apply(lambda col: col.str.strip())\n",
    "    return result.rename(columns=str.strip).rename(index=str.strip)\n",
    "\n",
    "\n",
    "def to_percent(value):\n",
    "    \"\"\"Convert a fraction such as '0.4213' to an integer percentage (42); missing values stay NaN.\"\"\"\n",
    "    number = pd.to_numeric(value, errors='coerce')\n",
    "    return number if pd.isna(number) else int(round(number * 100))\n",
    "\n",
    "\n",
    "for model_size in model_sizes:\n",
    "    for data_size, suffix in run_suffixes.items():\n",
    "        row = (model_size, data_size)\n",
    "        res_dir = f'output/sgd/group0/{model_size.lower()}_model0_context100_{suffix}'\n",
    "        # Start from an empty row so re-running the cell never keeps stale values.\n",
    "        table.loc[row, :] = float('nan')\n",
    "\n",
    "        multi_file = f'{res_dir}/test_multi_domain_result.md'\n",
    "        if os.path.exists(multi_file):\n",
    "            multi = read_result_md(multi_file)\n",
    "            for column, (md_row, md_col) in multi_metrics.items():\n",
    "                table.loc[row, column] = to_percent(multi.loc[md_row, md_col])\n",
    "\n",
    "        # Checked on its own: the multi-domain file existing does not guarantee this one does.\n",
    "        single_file = f'{res_dir}/test_single_domain_result.md'\n",
    "        if os.path.exists(single_file):\n",
    "            single = read_result_md(single_file)\n",
    "            table.loc[row, 'Single JGA'] = to_percent(single.loc['all', 'JGA'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 67,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th>Multi JGA</th>\n",
       "      <th>Multi TA</th>\n",
       "      <th>Multi CDTA</th>\n",
       "      <th>Cross JGA</th>\n",
       "      <th>Single JGA</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th rowspan=\"5\" valign=\"top\">T5-small</th>\n",
       "      <th>1x</th>\n",
       "      <td>42</td>\n",
       "      <td>66</td>\n",
       "      <td>15</td>\n",
       "      <td>43</td>\n",
       "      <td>80</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2x</th>\n",
       "      <td>44</td>\n",
       "      <td>68</td>\n",
       "      <td>15</td>\n",
       "      <td>45</td>\n",
       "      <td>81</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4x</th>\n",
       "      <td>46</td>\n",
       "      <td>70</td>\n",
       "      <td>21</td>\n",
       "      <td>47</td>\n",
       "      <td>84</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>8x</th>\n",
       "      <td>45</td>\n",
       "      <td>71</td>\n",
       "      <td>24</td>\n",
       "      <td>46</td>\n",
       "      <td>83</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>10%Multi</th>\n",
       "      <td>41</td>\n",
       "      <td>62</td>\n",
       "      <td>27</td>\n",
       "      <td>43</td>\n",
       "      <td>78</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th rowspan=\"5\" valign=\"top\">T5-base</th>\n",
       "      <th>1x</th>\n",
       "      <td>49</td>\n",
       "      <td>71</td>\n",
       "      <td>15</td>\n",
       "      <td>49</td>\n",
       "      <td>88</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2x</th>\n",
       "      <td>49</td>\n",
       "      <td>73</td>\n",
       "      <td>17</td>\n",
       "      <td>49</td>\n",
       "      <td>88</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4x</th>\n",
       "      <td>49</td>\n",
       "      <td>74</td>\n",
       "      <td>20</td>\n",
       "      <td>50</td>\n",
       "      <td>87</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>8x</th>\n",
       "      <td>48</td>\n",
       "      <td>73</td>\n",
       "      <td>22</td>\n",
       "      <td>49</td>\n",
       "      <td>86</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>10%Multi</th>\n",
       "      <td>56</td>\n",
       "      <td>75</td>\n",
       "      <td>50</td>\n",
       "      <td>58</td>\n",
       "      <td>89</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th rowspan=\"5\" valign=\"top\">T5-large</th>\n",
       "      <th>1x</th>\n",
       "      <td>50</td>\n",
       "      <td>73</td>\n",
       "      <td>16</td>\n",
       "      <td>50</td>\n",
       "      <td>90</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2x</th>\n",
       "      <td>50</td>\n",
       "      <td>74</td>\n",
       "      <td>16</td>\n",
       "      <td>50</td>\n",
       "      <td>89</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4x</th>\n",
       "      <td>50</td>\n",
       "      <td>74</td>\n",
       "      <td>17</td>\n",
       "      <td>50</td>\n",
       "      <td>87</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>8x</th>\n",
       "      <td>50</td>\n",
       "      <td>73</td>\n",
       "      <td>22</td>\n",
       "      <td>51</td>\n",
       "      <td>85</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>10%Multi</th>\n",
       "      <td>59</td>\n",
       "      <td>77</td>\n",
       "      <td>53</td>\n",
       "      <td>61</td>\n",
       "      <td>92</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th rowspan=\"5\" valign=\"top\">aggregate</th>\n",
       "      <th>1x</th>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2x</th>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4x</th>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>8x</th>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>10%Multi</th>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "                   Multi JGA Multi TA Multi CDTA Cross JGA Single JGA\n",
       "T5-small  1x              42       66         15        43         80\n",
       "          2x              44       68         15        45         81\n",
       "          4x              46       70         21        47         84\n",
       "          8x              45       71         24        46         83\n",
       "          10%Multi        41       62         27        43         78\n",
       "T5-base   1x              49       71         15        49         88\n",
       "          2x              49       73         17        49         88\n",
       "          4x              49       74         20        50         87\n",
       "          8x              48       73         22        49         86\n",
       "          10%Multi        56       75         50        58         89\n",
       "T5-large  1x              50       73         16        50         90\n",
       "          2x              50       74         16        50         89\n",
       "          4x              50       74         17        50         87\n",
       "          8x              50       73         22        51         85\n",
       "          10%Multi        59       77         53        61         92\n",
       "aggregate 1x             NaN      NaN        NaN       NaN        NaN\n",
       "          2x             NaN      NaN        NaN       NaN        NaN\n",
       "          4x             NaN      NaN        NaN       NaN        NaN\n",
       "          8x             NaN      NaN        NaN       NaN        NaN\n",
       "          10%Multi       NaN      NaN        NaN       NaN        NaN"
      ]
     },
     "execution_count": 67,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Show the per-model table; metric values are integer percentages.\n",
    "aug_data_size_table"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 68,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>Multi JGA</th>\n",
       "      <th>Multi TA</th>\n",
       "      <th>Multi CDTA</th>\n",
       "      <th>Cross JGA</th>\n",
       "      <th>Single JGA</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>1x</th>\n",
       "      <td>42$\\rightarrow$49$\\rightarrow$50</td>\n",
       "      <td>66$\\rightarrow$71$\\rightarrow$73</td>\n",
       "      <td>15$\\rightarrow$15$\\rightarrow$16</td>\n",
       "      <td>43$\\rightarrow$49$\\rightarrow$50</td>\n",
       "      <td>80$\\rightarrow$88$\\rightarrow$90</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2x</th>\n",
       "      <td>44$\\rightarrow$49$\\rightarrow$50</td>\n",
       "      <td>68$\\rightarrow$73$\\rightarrow$74</td>\n",
       "      <td>15$\\rightarrow$17$\\rightarrow$16</td>\n",
       "      <td>45$\\rightarrow$49$\\rightarrow$50</td>\n",
       "      <td>81$\\rightarrow$88$\\rightarrow$89</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4x</th>\n",
       "      <td>46$\\rightarrow$49$\\rightarrow$50</td>\n",
       "      <td>70$\\rightarrow$74$\\rightarrow$74</td>\n",
       "      <td>21$\\rightarrow$20$\\rightarrow$17</td>\n",
       "      <td>47$\\rightarrow$50$\\rightarrow$50</td>\n",
       "      <td>84$\\rightarrow$87$\\rightarrow$87</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>8x</th>\n",
       "      <td>45$\\rightarrow$48$\\rightarrow$50</td>\n",
       "      <td>71$\\rightarrow$73$\\rightarrow$73</td>\n",
       "      <td>24$\\rightarrow$22$\\rightarrow$22</td>\n",
       "      <td>46$\\rightarrow$49$\\rightarrow$51</td>\n",
       "      <td>83$\\rightarrow$86$\\rightarrow$85</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>10%Multi</th>\n",
       "      <td>41$\\rightarrow$56$\\rightarrow$59</td>\n",
       "      <td>62$\\rightarrow$75$\\rightarrow$77</td>\n",
       "      <td>27$\\rightarrow$50$\\rightarrow$53</td>\n",
       "      <td>43$\\rightarrow$58$\\rightarrow$61</td>\n",
       "      <td>78$\\rightarrow$89$\\rightarrow$92</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "                                 Multi JGA                          Multi TA  \\\n",
       "1x        42$\\rightarrow$49$\\rightarrow$50  66$\\rightarrow$71$\\rightarrow$73   \n",
       "2x        44$\\rightarrow$49$\\rightarrow$50  68$\\rightarrow$73$\\rightarrow$74   \n",
       "4x        46$\\rightarrow$49$\\rightarrow$50  70$\\rightarrow$74$\\rightarrow$74   \n",
       "8x        45$\\rightarrow$48$\\rightarrow$50  71$\\rightarrow$73$\\rightarrow$73   \n",
       "10%Multi  41$\\rightarrow$56$\\rightarrow$59  62$\\rightarrow$75$\\rightarrow$77   \n",
       "\n",
       "                                Multi CDTA                         Cross JGA  \\\n",
       "1x        15$\\rightarrow$15$\\rightarrow$16  43$\\rightarrow$49$\\rightarrow$50   \n",
       "2x        15$\\rightarrow$17$\\rightarrow$16  45$\\rightarrow$49$\\rightarrow$50   \n",
       "4x        21$\\rightarrow$20$\\rightarrow$17  47$\\rightarrow$50$\\rightarrow$50   \n",
       "8x        24$\\rightarrow$22$\\rightarrow$22  46$\\rightarrow$49$\\rightarrow$51   \n",
       "10%Multi  27$\\rightarrow$50$\\rightarrow$53  43$\\rightarrow$58$\\rightarrow$61   \n",
       "\n",
       "                                Single JGA  \n",
       "1x        80$\\rightarrow$88$\\rightarrow$90  \n",
       "2x        81$\\rightarrow$88$\\rightarrow$89  \n",
       "4x        84$\\rightarrow$87$\\rightarrow$87  \n",
       "8x        83$\\rightarrow$86$\\rightarrow$85  \n",
       "10%Multi  78$\\rightarrow$89$\\rightarrow$92  "
      ]
     },
     "execution_count": 68,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Fill each 'aggregate' row with the T5-small -> T5-base -> T5-large progression of that\n",
    "# setting, joined for LaTeX. The raw string matters: in a plain literal '\\r' is a carriage\n",
    "# return, so the cell would emit '$<CR>ightarrow$' instead of '$\\rightarrow$'.\n",
    "table = aug_data_size_table\n",
    "for data_size in ['1x', '2x', '4x', '8x', '10%Multi']:\n",
    "    for metric in table.columns:\n",
    "        table.loc[('aggregate', data_size), metric] = r'$\\rightarrow$'.join(\n",
    "            str(table.loc[(model_size, data_size), metric])\n",
    "            for model_size in ['T5-small', 'T5-base', 'T5-large']\n",
    "        )\n",
    "table.loc['aggregate']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 69,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\\begin{tabular}{llllll}\n",
      "\\toprule\n",
      "{} &                         Multi JGA &                          Multi TA &                        Multi CDTA &                         Cross JGA &                        Single JGA \\\\\n",
      "\\midrule\n",
      "1x       &  42\\$\\textbackslash rightarrow\\$49\\$\\textbackslash rightarrow\\$50 &  66\\$\\textbackslash rightarrow\\$71\\$\\textbackslash rightarrow\\$73 &  15\\$\\textbackslash rightarrow\\$15\\$\\textbackslash rightarrow\\$16 &  43\\$\\textbackslash rightarrow\\$49\\$\\textbackslash rightarrow\\$50 &  80\\$\\textbackslash rightarrow\\$88\\$\\textbackslash rightarrow\\$90 \\\\\n",
      "2x       &  44\\$\\textbackslash rightarrow\\$49\\$\\textbackslash rightarrow\\$50 &  68\\$\\textbackslash rightarrow\\$73\\$\\textbackslash rightarrow\\$74 &  15\\$\\textbackslash rightarrow\\$17\\$\\textbackslash rightarrow\\$16 &  45\\$\\textbackslash rightarrow\\$49\\$\\textbackslash rightarrow\\$50 &  81\\$\\textbackslash rightarrow\\$88\\$\\textbackslash rightarrow\\$89 \\\\\n",
      "4x       &  46\\$\\textbackslash rightarrow\\$49\\$\\textbackslash rightarrow\\$50 &  70\\$\\textbackslash rightarrow\\$74\\$\\textbackslash rightarrow\\$74 &  21\\$\\textbackslash rightarrow\\$20\\$\\textbackslash rightarrow\\$17 &  47\\$\\textbackslash rightarrow\\$50\\$\\textbackslash rightarrow\\$50 &  84\\$\\textbackslash rightarrow\\$87\\$\\textbackslash rightarrow\\$87 \\\\\n",
      "8x       &  45\\$\\textbackslash rightarrow\\$48\\$\\textbackslash rightarrow\\$50 &  71\\$\\textbackslash rightarrow\\$73\\$\\textbackslash rightarrow\\$73 &  24\\$\\textbackslash rightarrow\\$22\\$\\textbackslash rightarrow\\$22 &  46\\$\\textbackslash rightarrow\\$49\\$\\textbackslash rightarrow\\$51 &  83\\$\\textbackslash rightarrow\\$86\\$\\textbackslash rightarrow\\$85 \\\\\n",
      "10\\%Multi &  41\\$\\textbackslash rightarrow\\$56\\$\\textbackslash rightarrow\\$59 &  62\\$\\textbackslash rightarrow\\$75\\$\\textbackslash rightarrow\\$77 &  27\\$\\textbackslash rightarrow\\$50\\$\\textbackslash rightarrow\\$53 &  43\\$\\textbackslash rightarrow\\$58\\$\\textbackslash rightarrow\\$61 &  78\\$\\textbackslash rightarrow\\$89\\$\\textbackslash rightarrow\\$92 \\\\\n",
      "\\bottomrule\n",
      "\\end{tabular}\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/tmp/ipykernel_67659/3137222915.py:1: FutureWarning: In future versions `DataFrame.to_latex` is expected to utilise the base implementation of `Styler.to_latex` for formatting and rendering. The arguments signature may therefore change. It is recommended instead to use `DataFrame.style.to_latex` which also contains additional functionality.\n",
      "  print(aug_data_size_table.loc['aggregate'].to_latex())\n"
     ]
    }
   ],
   "source": [
    "# escape=False keeps the $\\rightarrow$ markup as LaTeX math; with the default escaping it is\n",
    "# printed as \\$\\textbackslash rightarrow\\$. The '%' in the '10%Multi' label then has to be\n",
    "# escaped by hand, otherwise LaTeX treats the rest of that row as a comment.\n",
    "aggregate_rows = aug_data_size_table.loc['aggregate'].rename(index=lambda label: label.replace('%', r'\\%'))\n",
    "print(aggregate_rows.to_latex(escape=False))"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Domain expansion"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 67,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead tr th {\n",
       "        text-align: left;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr>\n",
       "      <th></th>\n",
       "      <th colspan=\"3\" halign=\"left\">group2</th>\n",
       "      <th colspan=\"3\" halign=\"left\">group1</th>\n",
       "      <th colspan=\"3\" halign=\"left\">group0</th>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th></th>\n",
       "      <th>Multi JGA</th>\n",
       "      <th>Multi TA</th>\n",
       "      <th>Multi CDTA</th>\n",
       "      <th>Multi JGA</th>\n",
       "      <th>Multi TA</th>\n",
       "      <th>Multi CDTA</th>\n",
       "      <th>Multi JGA</th>\n",
       "      <th>Multi TA</th>\n",
       "      <th>Multi CDTA</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>10Multi</th>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>10Multi+syn8x</th>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>15Multi</th>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>15Multi+syn8x</th>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>19Multi</th>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>Mix (2x)</th>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>Mix (8x)</th>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "                 group2                        group1                      \\\n",
       "              Multi JGA Multi TA Multi CDTA Multi JGA Multi TA Multi CDTA   \n",
       "10Multi             NaN      NaN        NaN       NaN      NaN        NaN   \n",
       "10Multi+syn8x       NaN      NaN        NaN       NaN      NaN        NaN   \n",
       "15Multi             NaN      NaN        NaN       NaN      NaN        NaN   \n",
       "15Multi+syn8x       NaN      NaN        NaN       NaN      NaN        NaN   \n",
       "19Multi             NaN      NaN        NaN       NaN      NaN        NaN   \n",
       "Mix (2x)            NaN      NaN        NaN       NaN      NaN        NaN   \n",
       "Mix (8x)            NaN      NaN        NaN       NaN      NaN        NaN   \n",
       "\n",
       "                 group0                      \n",
       "              Multi JGA Multi TA Multi CDTA  \n",
       "10Multi             NaN      NaN        NaN  \n",
       "10Multi+syn8x       NaN      NaN        NaN  \n",
       "15Multi             NaN      NaN        NaN  \n",
       "15Multi+syn8x       NaN      NaN        NaN  \n",
       "19Multi             NaN      NaN        NaN  \n",
       "Mix (2x)            NaN      NaN        NaN  \n",
       "Mix (8x)            NaN      NaN        NaN  "
      ]
     },
     "execution_count": 67,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# One row per run configuration; columns are (SGD group, multi-domain metric).\n",
    "domain_expand_table = pd.DataFrame(\n",
    "    index=['10Multi', '10Multi+syn8x', '15Multi', '15Multi+syn8x', '19Multi', 'Mix (2x)', 'Mix (8x)'],\n",
    "    columns=pd.MultiIndex.from_product([\n",
    "        ['group2', 'group1', 'group0'],\n",
    "        ['Multi JGA', 'Multi TA', 'Multi CDTA'],\n",
    "    ]),\n",
    ")\n",
    "domain_expand_table"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 70,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Fill the domain-expansion table from each run's multi-domain result file.\n",
    "# Each run: (group the run directory lives under, aug type, aug times, run-directory prefix),\n",
    "# listed in the same order as the table rows.\n",
    "runs = [\n",
    "    (2, 1, 0.223, 'all_single_'), (2, 1, 0.223, 'all_single_syn8x_'),\n",
    "    (1, 1, 0.133, 'all_single_'), (1, 1, 0.133, 'all_single_syn8x_'),\n",
    "    (0, 1, 0.1, ''), (0, 7, 2.0, ''), (0, 7, 8.0, ''),\n",
    "]\n",
    "# Table metric column -> column name in the result markdown.\n",
    "metric_names = {'Multi JGA': 'JGA', 'Multi TA': 'TA', 'Multi CDTA': 'CDTA'}\n",
    "table = domain_expand_table\n",
    "for index_name, (group, aug_type, aug_times, prefix) in zip(table.index, runs):\n",
    "    res_dir = f'output/sgd/group{group}/{prefix}t5-large_model0_context100_aug{aug_type}_x{aug_times}'\n",
    "    res_file = f'{res_dir}/test_multi_domain_result.md'\n",
    "    df = (\n",
    "        pd.read_table(res_file, sep='|', index_col=1)\n",
    "        .dropna(axis=1, how='all')\n",
    "        .iloc[1:]\n",
    "        .apply(lambda x: x.str.strip())\n",
    "        .rename(columns=str.strip)\n",
    "        .rename(index=str.strip)\n",
    "    )\n",
    "    for dataset in ['group2', 'group1', 'group0']:\n",
    "        for table_metric, md_metric in metric_names.items():\n",
    "            table.loc[index_name, (dataset, table_metric)] = df.loc[dataset, md_metric]\n",
    "# Convert fractions to percentages rounded to one decimal (non-numeric entries become NaN).\n",
    "domain_expand_table = domain_expand_table.applymap(lambda x: round(pd.to_numeric(x, errors='coerce') * 100, 1))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 71,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead tr th {\n",
       "        text-align: left;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr>\n",
       "      <th></th>\n",
       "      <th colspan=\"3\" halign=\"left\">group2</th>\n",
       "      <th colspan=\"3\" halign=\"left\">group1</th>\n",
       "      <th colspan=\"3\" halign=\"left\">group0</th>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th></th>\n",
       "      <th>Multi JGA</th>\n",
       "      <th>Multi TA</th>\n",
       "      <th>Multi CDTA</th>\n",
       "      <th>Multi JGA</th>\n",
       "      <th>Multi TA</th>\n",
       "      <th>Multi CDTA</th>\n",
       "      <th>Multi JGA</th>\n",
       "      <th>Multi TA</th>\n",
       "      <th>Multi CDTA</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>10Multi</th>\n",
       "      <td>74.0</td>\n",
       "      <td>87.5</td>\n",
       "      <td>75.2</td>\n",
       "      <td>40.6</td>\n",
       "      <td>67.9</td>\n",
       "      <td>6.4</td>\n",
       "      <td>34.9</td>\n",
       "      <td>58.1</td>\n",
       "      <td>3.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>10Multi+syn8x</th>\n",
       "      <td>74.2</td>\n",
       "      <td>89.2</td>\n",
       "      <td>68.6</td>\n",
       "      <td>54.3</td>\n",
       "      <td>79.0</td>\n",
       "      <td>25.0</td>\n",
       "      <td>48.5</td>\n",
       "      <td>73.7</td>\n",
       "      <td>26.2</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>15Multi</th>\n",
       "      <td>65.5</td>\n",
       "      <td>82.7</td>\n",
       "      <td>65.3</td>\n",
       "      <td>64.3</td>\n",
       "      <td>82.4</td>\n",
       "      <td>58.8</td>\n",
       "      <td>38.8</td>\n",
       "      <td>61.2</td>\n",
       "      <td>12.3</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>15Multi+syn8x</th>\n",
       "      <td>70.5</td>\n",
       "      <td>86.6</td>\n",
       "      <td>57.9</td>\n",
       "      <td>61.6</td>\n",
       "      <td>82.5</td>\n",
       "      <td>42.7</td>\n",
       "      <td>47.0</td>\n",
       "      <td>71.5</td>\n",
       "      <td>21.1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>19Multi</th>\n",
       "      <td>56.2</td>\n",
       "      <td>75.3</td>\n",
       "      <td>61.6</td>\n",
       "      <td>65.4</td>\n",
       "      <td>82.4</td>\n",
       "      <td>63.1</td>\n",
       "      <td>56.8</td>\n",
       "      <td>74.4</td>\n",
       "      <td>47.5</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>Mix (2x)</th>\n",
       "      <td>49.6</td>\n",
       "      <td>74.0</td>\n",
       "      <td>28.9</td>\n",
       "      <td>56.0</td>\n",
       "      <td>78.0</td>\n",
       "      <td>13.3</td>\n",
       "      <td>47.4</td>\n",
       "      <td>71.4</td>\n",
       "      <td>14.6</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>Mix (8x)</th>\n",
       "      <td>50.8</td>\n",
       "      <td>75.3</td>\n",
       "      <td>36.8</td>\n",
       "      <td>55.6</td>\n",
       "      <td>77.5</td>\n",
       "      <td>18.4</td>\n",
       "      <td>46.8</td>\n",
       "      <td>70.9</td>\n",
       "      <td>20.3</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "                 group2                        group1                      \\\n",
       "              Multi JGA Multi TA Multi CDTA Multi JGA Multi TA Multi CDTA   \n",
       "10Multi            74.0     87.5       75.2      40.6     67.9        6.4   \n",
       "10Multi+syn8x      74.2     89.2       68.6      54.3     79.0       25.0   \n",
       "15Multi            65.5     82.7       65.3      64.3     82.4       58.8   \n",
       "15Multi+syn8x      70.5     86.6       57.9      61.6     82.5       42.7   \n",
       "19Multi            56.2     75.3       61.6      65.4     82.4       63.1   \n",
       "Mix (2x)           49.6     74.0       28.9      56.0     78.0       13.3   \n",
       "Mix (8x)           50.8     75.3       36.8      55.6     77.5       18.4   \n",
       "\n",
       "                 group0                      \n",
       "              Multi JGA Multi TA Multi CDTA  \n",
       "10Multi            34.9     58.1        3.0  \n",
       "10Multi+syn8x      48.5     73.7       26.2  \n",
       "15Multi            38.8     61.2       12.3  \n",
       "15Multi+syn8x      47.0     71.5       21.1  \n",
       "19Multi            56.8     74.4       47.5  \n",
       "Mix (2x)           47.4     71.4       14.6  \n",
       "Mix (8x)           46.8     70.9       20.3  "
      ]
     },
     "execution_count": 71,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Domain-expansion results: one row per experiment setting, columns are SGD group x metric.\n",
    "domain_expand_table"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 72,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\\begin{tabular}{lrrrrrrrrr}\n",
      "\\toprule\n",
      "{} & \\multicolumn{3}{l}{group2} & \\multicolumn{3}{l}{group1} & \\multicolumn{3}{l}{group0} \\\\\n",
      "{} & Multi JGA & Multi TA & Multi CDTA & Multi JGA & Multi TA & Multi CDTA & Multi JGA & Multi TA & Multi CDTA \\\\\n",
      "\\midrule\n",
      "10Multi       &      74.0 &     87.5 &       75.2 &      40.6 &     67.9 &        6.4 &      34.9 &     58.1 &        3.0 \\\\\n",
      "10Multi+syn8x &      74.2 &     89.2 &       68.6 &      54.3 &     79.0 &       25.0 &      48.5 &     73.7 &       26.2 \\\\\n",
      "15Multi       &      65.5 &     82.7 &       65.3 &      64.3 &     82.4 &       58.8 &      38.8 &     61.2 &       12.3 \\\\\n",
      "15Multi+syn8x &      70.5 &     86.6 &       57.9 &      61.6 &     82.5 &       42.7 &      47.0 &     71.5 &       21.1 \\\\\n",
      "19Multi       &      56.2 &     75.3 &       61.6 &      65.4 &     82.4 &       63.1 &      56.8 &     74.4 &       47.5 \\\\\n",
      "Mix (2x)      &      49.6 &     74.0 &       28.9 &      56.0 &     78.0 &       13.3 &      47.4 &     71.4 &       14.6 \\\\\n",
      "Mix (8x)      &      50.8 &     75.3 &       36.8 &      55.6 &     77.5 &       18.4 &      46.8 &     70.9 &       20.3 \\\\\n",
      "\\bottomrule\n",
      "\\end{tabular}\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/tmp/ipykernel_76944/705736087.py:1: FutureWarning: In future versions `DataFrame.to_latex` is expected to utilise the base implementation of `Styler.to_latex` for formatting and rendering. The arguments signature may therefore change. It is recommended instead to use `DataFrame.style.to_latex` which also contains additional functionality.\n",
      "  print(domain_expand_table.to_latex())\n"
     ]
    }
   ],
   "source": [
    "# Raw LaTeX source of the domain-expansion table, for copying into a LaTeX document.\n",
    "domain_expand_latex = domain_expand_table.to_latex()\n",
    "print(domain_expand_latex)"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Results by number of domains per dialogue"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 88,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Configuration for the per-domain-count analysis.\n",
    "# NOTE: this re-binds aug_types / index_names / column_names / datasets from the\n",
    "# 'Main Exp Table' config cell; re-run that cell before revisiting earlier sections.\n",
    "# aug_types[i] = (aug_type, aug_times) selects the output dir ..._aug{aug_type}_x{aug_times}\n",
    "# and index_names[i] is its row label. (2, 2.0) and (3, 2.0) appear twice because the\n",
    "# CoQR rows reuse those model outputs but are scored from a different result file.\n",
    "aug_types = [(0, 0.0), (2, 2.0), (3, 2.0), (2, 2.0), (3, 2.0), (4, 2.0), (7, 2.0), (9, 2.0), (1, 0.1), (1, 0.2)]\n",
    "index_names = ['Single', 'Single+Concat', 'Single+ConcatN', 'Single+Concat+CoQR-M', 'Single+ConcatN+CoQR-MixN', 'Single+ConcatRel (2x)', 'w/ Mix', 'w/ MixN', '90%Single+10%Multi', '80%Single+20%Multi']\n",
    "column_names = ['Multi JGA', 'Multi TA', 'Multi CDTA']\n",
    "datasets = ['multiwoz21', 'sgd/group0']\n",
    "assert len(aug_types) == len(index_names), 'aug_types and index_names must be aligned row-for-row'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 89,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead tr th {\n",
       "        text-align: left;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr>\n",
       "      <th></th>\n",
       "      <th colspan=\"12\" halign=\"left\">sgd/group0</th>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th></th>\n",
       "      <th colspan=\"4\" halign=\"left\">Multi JGA</th>\n",
       "      <th colspan=\"4\" halign=\"left\">Multi TA</th>\n",
       "      <th colspan=\"4\" halign=\"left\">Multi CDTA</th>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th></th>\n",
       "      <th>domain_1</th>\n",
       "      <th>domain_2</th>\n",
       "      <th>domain_3</th>\n",
       "      <th>domain_4</th>\n",
       "      <th>domain_1</th>\n",
       "      <th>domain_2</th>\n",
       "      <th>domain_3</th>\n",
       "      <th>domain_4</th>\n",
       "      <th>domain_1</th>\n",
       "      <th>domain_2</th>\n",
       "      <th>domain_3</th>\n",
       "      <th>domain_4</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>Single</th>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>Single+Concat</th>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>Single+ConcatN</th>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>Single+Concat+CoQR-M</th>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>Single+ConcatN+CoQR-MixN</th>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>Single+ConcatRel (2x)</th>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>w/ Mix</th>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>w/ MixN</th>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>90%Single+10%Multi</th>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>80%Single+20%Multi</th>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "                         sgd/group0                                      \\\n",
       "                          Multi JGA                            Multi TA   \n",
       "                           domain_1 domain_2 domain_3 domain_4 domain_1   \n",
       "Single                          NaN      NaN      NaN      NaN      NaN   \n",
       "Single+Concat                   NaN      NaN      NaN      NaN      NaN   \n",
       "Single+ConcatN                  NaN      NaN      NaN      NaN      NaN   \n",
       "Single+Concat+CoQR-M            NaN      NaN      NaN      NaN      NaN   \n",
       "Single+ConcatN+CoQR-MixN        NaN      NaN      NaN      NaN      NaN   \n",
       "Single+ConcatRel (2x)           NaN      NaN      NaN      NaN      NaN   \n",
       "w/ Mix                          NaN      NaN      NaN      NaN      NaN   \n",
       "w/ MixN                         NaN      NaN      NaN      NaN      NaN   \n",
       "90%Single+10%Multi              NaN      NaN      NaN      NaN      NaN   \n",
       "80%Single+20%Multi              NaN      NaN      NaN      NaN      NaN   \n",
       "\n",
       "                                                                         \\\n",
       "                                                    Multi CDTA            \n",
       "                         domain_2 domain_3 domain_4   domain_1 domain_2   \n",
       "Single                        NaN      NaN      NaN        NaN      NaN   \n",
       "Single+Concat                 NaN      NaN      NaN        NaN      NaN   \n",
       "Single+ConcatN                NaN      NaN      NaN        NaN      NaN   \n",
       "Single+Concat+CoQR-M          NaN      NaN      NaN        NaN      NaN   \n",
       "Single+ConcatN+CoQR-MixN      NaN      NaN      NaN        NaN      NaN   \n",
       "Single+ConcatRel (2x)         NaN      NaN      NaN        NaN      NaN   \n",
       "w/ Mix                        NaN      NaN      NaN        NaN      NaN   \n",
       "w/ MixN                       NaN      NaN      NaN        NaN      NaN   \n",
       "90%Single+10%Multi            NaN      NaN      NaN        NaN      NaN   \n",
       "80%Single+20%Multi            NaN      NaN      NaN        NaN      NaN   \n",
       "\n",
       "                                            \n",
       "                                            \n",
       "                         domain_3 domain_4  \n",
       "Single                        NaN      NaN  \n",
       "Single+Concat                 NaN      NaN  \n",
       "Single+ConcatN                NaN      NaN  \n",
       "Single+Concat+CoQR-M          NaN      NaN  \n",
       "Single+ConcatN+CoQR-MixN      NaN      NaN  \n",
       "Single+ConcatRel (2x)         NaN      NaN  \n",
       "w/ Mix                        NaN      NaN  \n",
       "w/ MixN                       NaN      NaN  \n",
       "90%Single+10%Multi            NaN      NaN  \n",
       "80%Single+20%Multi            NaN      NaN  "
      ]
     },
     "execution_count": 89,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Empty (all-NaN) frame: rows = index_names, columns = dataset x metric x domain count.\n",
    "domain_cnt_table = pd.DataFrame(\n",
    "    index=index_names,\n",
    "    columns=pd.MultiIndex.from_product(\n",
    "        [datasets[1:2], column_names, [f'domain_{n}' for n in range(1, 5)]]\n",
    "    ),\n",
    ")\n",
    "domain_cnt_table"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 91,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Rebuild the per-domain-count table from the markdown result files.\n",
    "# A fresh frame is filled on every run, so re-running this cell (e.g. after later cells\n",
    "# modified domain_cnt_table) gives the same result instead of mutating stale state.\n",
    "\n",
    "# Rows scored from an alternative (coqr_canard_*) result file instead of the default one.\n",
    "# Keyed by row label, not position, so reordering aug_types / index_names cannot mismatch them.\n",
    "special_result_files = {\n",
    "    'Single+Concat+CoQR-M': 'test_multi_domain_coqr_canard_mix_x2.0_result.md',\n",
    "    'Single+ConcatN+CoQR-MixN': 'test_multi_domain_coqr_canard_aug9_x2.0_result.md',\n",
    "}\n",
    "# Table metric column -> column header in the markdown result file.\n",
    "metric_columns = {'Multi JGA': 'JGA', 'Multi TA': 'TA', 'Multi CDTA': 'CDTA'}\n",
    "domain_keys = [f'domain_{dom_cnt}' for dom_cnt in range(1, 5)]\n",
    "assert len(aug_types) == len(index_names), 'aug_types and index_names must be aligned row-for-row'\n",
    "\n",
    "table = pd.DataFrame(index=index_names, columns=pd.MultiIndex.from_product([datasets[1:2], column_names, domain_keys]))\n",
    "for index_name, (aug_type, aug_times) in zip(index_names, aug_types):\n",
    "    for dataset in datasets[1:2]:\n",
    "        res_dir = f'output/{dataset}/t5-large_model0_context100_aug{aug_type}_x{aug_times}'\n",
    "        res_name = special_result_files.get(index_name, 'test_multi_domain_result.md')\n",
    "        # Parse a '|'-separated markdown table: drop all-NaN edge columns and (presumably)\n",
    "        # the '---' separator row that follows the header.\n",
    "        df = pd.read_table(f'{res_dir}/{res_name}', sep='|', index_col=1).dropna(axis=1, how='all').iloc[1:]\n",
    "        df = df.apply(lambda x: x.str.strip())\n",
    "        df = df.rename(columns=str.strip).rename(index=str.strip)\n",
    "        for domain_key in domain_keys:\n",
    "            for table_col, file_col in metric_columns.items():\n",
    "                table.loc[index_name, (dataset, table_col, domain_key)] = df.loc[domain_key, file_col]\n",
    "\n",
    "# Cells hold fractions as text; convert to percentages rounded to 1 decimal.\n",
    "domain_cnt_table = table.applymap(lambda x: round(pd.to_numeric(x, errors='coerce') * 100, 1))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 92,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Drop the Multi CDTA column for domain_1 (dialogues with a single domain).\n",
    "# NOTE(review): presumably cross-domain accuracy is not meaningful with one domain -- confirm.\n",
    "# Selecting by level values (instead of one hard-coded key) makes the cell idempotent:\n",
    "# a second run finds nothing to drop instead of raising KeyError.\n",
    "cdta_domain1_cols = [col for col in domain_cnt_table.columns if col[1:] == ('Multi CDTA', 'domain_1')]\n",
    "domain_cnt_table = domain_cnt_table.drop(columns=cdta_domain1_cols)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 93,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead tr th {\n",
       "        text-align: left;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr>\n",
       "      <th></th>\n",
       "      <th colspan=\"4\" halign=\"left\">Multi JGA</th>\n",
       "      <th colspan=\"4\" halign=\"left\">Multi TA</th>\n",
       "      <th colspan=\"3\" halign=\"left\">Multi CDTA</th>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th></th>\n",
       "      <th>domain_1</th>\n",
       "      <th>domain_2</th>\n",
       "      <th>domain_3</th>\n",
       "      <th>domain_4</th>\n",
       "      <th>domain_1</th>\n",
       "      <th>domain_2</th>\n",
       "      <th>domain_3</th>\n",
       "      <th>domain_4</th>\n",
       "      <th>domain_2</th>\n",
       "      <th>domain_3</th>\n",
       "      <th>domain_4</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>Single</th>\n",
       "      <td>87.2</td>\n",
       "      <td>0.1</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>92.3</td>\n",
       "      <td>43.6</td>\n",
       "      <td>41.3</td>\n",
       "      <td>39.6</td>\n",
       "      <td>2.3</td>\n",
       "      <td>0.3</td>\n",
       "      <td>0.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>Single+Concat</th>\n",
       "      <td>87.6</td>\n",
       "      <td>24.6</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>92.7</td>\n",
       "      <td>57.6</td>\n",
       "      <td>48.3</td>\n",
       "      <td>18.5</td>\n",
       "      <td>1.9</td>\n",
       "      <td>1.4</td>\n",
       "      <td>1.9</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>Single+ConcatN</th>\n",
       "      <td>87.9</td>\n",
       "      <td>22.0</td>\n",
       "      <td>10.3</td>\n",
       "      <td>0.0</td>\n",
       "      <td>93.3</td>\n",
       "      <td>60.2</td>\n",
       "      <td>53.8</td>\n",
       "      <td>33.8</td>\n",
       "      <td>2.0</td>\n",
       "      <td>1.4</td>\n",
       "      <td>0.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>Single+Concat+CoQR-M</th>\n",
       "      <td>86.2</td>\n",
       "      <td>31.1</td>\n",
       "      <td>0.1</td>\n",
       "      <td>0.0</td>\n",
       "      <td>91.3</td>\n",
       "      <td>62.1</td>\n",
       "      <td>48.8</td>\n",
       "      <td>19.6</td>\n",
       "      <td>21.9</td>\n",
       "      <td>2.0</td>\n",
       "      <td>0.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>Single+ConcatN+CoQR-MixN</th>\n",
       "      <td>86.8</td>\n",
       "      <td>27.4</td>\n",
       "      <td>11.7</td>\n",
       "      <td>3.1</td>\n",
       "      <td>92.4</td>\n",
       "      <td>64.7</td>\n",
       "      <td>54.3</td>\n",
       "      <td>42.7</td>\n",
       "      <td>16.3</td>\n",
       "      <td>3.1</td>\n",
       "      <td>18.5</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>Single+ConcatRel (2x)</th>\n",
       "      <td>87.2</td>\n",
       "      <td>27.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>92.8</td>\n",
       "      <td>59.8</td>\n",
       "      <td>50.0</td>\n",
       "      <td>28.8</td>\n",
       "      <td>2.3</td>\n",
       "      <td>0.3</td>\n",
       "      <td>3.7</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>w/ Mix</th>\n",
       "      <td>86.4</td>\n",
       "      <td>33.6</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>92.0</td>\n",
       "      <td>64.9</td>\n",
       "      <td>51.6</td>\n",
       "      <td>33.5</td>\n",
       "      <td>19.3</td>\n",
       "      <td>2.7</td>\n",
       "      <td>0.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>w/ MixN</th>\n",
       "      <td>86.9</td>\n",
       "      <td>34.2</td>\n",
       "      <td>18.4</td>\n",
       "      <td>3.8</td>\n",
       "      <td>92.5</td>\n",
       "      <td>67.3</td>\n",
       "      <td>59.3</td>\n",
       "      <td>52.3</td>\n",
       "      <td>21.7</td>\n",
       "      <td>6.5</td>\n",
       "      <td>25.9</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>90%Single+10%Multi</th>\n",
       "      <td>88.1</td>\n",
       "      <td>45.6</td>\n",
       "      <td>22.4</td>\n",
       "      <td>2.3</td>\n",
       "      <td>92.3</td>\n",
       "      <td>69.7</td>\n",
       "      <td>57.2</td>\n",
       "      <td>39.6</td>\n",
       "      <td>58.2</td>\n",
       "      <td>34.4</td>\n",
       "      <td>13.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>80%Single+20%Multi</th>\n",
       "      <td>88.5</td>\n",
       "      <td>56.4</td>\n",
       "      <td>34.2</td>\n",
       "      <td>8.8</td>\n",
       "      <td>93.0</td>\n",
       "      <td>79.3</td>\n",
       "      <td>66.8</td>\n",
       "      <td>50.4</td>\n",
       "      <td>68.2</td>\n",
       "      <td>49.3</td>\n",
       "      <td>44.4</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "                         Multi JGA                            Multi TA  \\\n",
       "                          domain_1 domain_2 domain_3 domain_4 domain_1   \n",
       "Single                        87.2      0.1      0.0      0.0     92.3   \n",
       "Single+Concat                 87.6     24.6      0.0      0.0     92.7   \n",
       "Single+ConcatN                87.9     22.0     10.3      0.0     93.3   \n",
       "Single+Concat+CoQR-M          86.2     31.1      0.1      0.0     91.3   \n",
       "Single+ConcatN+CoQR-MixN      86.8     27.4     11.7      3.1     92.4   \n",
       "Single+ConcatRel (2x)         87.2     27.0      0.0      0.0     92.8   \n",
       "w/ Mix                        86.4     33.6      0.0      0.0     92.0   \n",
       "w/ MixN                       86.9     34.2     18.4      3.8     92.5   \n",
       "90%Single+10%Multi            88.1     45.6     22.4      2.3     92.3   \n",
       "80%Single+20%Multi            88.5     56.4     34.2      8.8     93.0   \n",
       "\n",
       "                                                    Multi CDTA           \\\n",
       "                         domain_2 domain_3 domain_4   domain_2 domain_3   \n",
       "Single                       43.6     41.3     39.6        2.3      0.3   \n",
       "Single+Concat                57.6     48.3     18.5        1.9      1.4   \n",
       "Single+ConcatN               60.2     53.8     33.8        2.0      1.4   \n",
       "Single+Concat+CoQR-M         62.1     48.8     19.6       21.9      2.0   \n",
       "Single+ConcatN+CoQR-MixN     64.7     54.3     42.7       16.3      3.1   \n",
       "Single+ConcatRel (2x)        59.8     50.0     28.8        2.3      0.3   \n",
       "w/ Mix                       64.9     51.6     33.5       19.3      2.7   \n",
       "w/ MixN                      67.3     59.3     52.3       21.7      6.5   \n",
       "90%Single+10%Multi           69.7     57.2     39.6       58.2     34.4   \n",
       "80%Single+20%Multi           79.3     66.8     50.4       68.2     49.3   \n",
       "\n",
       "                                   \n",
       "                         domain_4  \n",
       "Single                        0.0  \n",
       "Single+Concat                 1.9  \n",
       "Single+ConcatN                0.0  \n",
       "Single+Concat+CoQR-M          0.0  \n",
       "Single+ConcatN+CoQR-MixN     18.5  \n",
       "Single+ConcatRel (2x)         3.7  \n",
       "w/ Mix                        0.0  \n",
       "w/ MixN                      25.9  \n",
       "90%Single+10%Multi           13.0  \n",
       "80%Single+20%Multi           44.4  "
      ]
     },
     "execution_count": 93,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# sgd/group0 slice, with the dataset level dropped from the column header.\n",
    "domain_cnt_table['sgd/group0']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 94,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\\begin{tabular}{lrrrrrrrrrrr}\n",
      "\\toprule\n",
      "{} & \\multicolumn{4}{l}{Multi JGA} & \\multicolumn{4}{l}{Multi TA} & \\multicolumn{3}{l}{Multi CDTA} \\\\\n",
      "{} &  domain\\_1 & domain\\_2 & domain\\_3 & domain\\_4 & domain\\_1 & domain\\_2 & domain\\_3 & domain\\_4 &   domain\\_2 & domain\\_3 & domain\\_4 \\\\\n",
      "\\midrule\n",
      "Single                   &      87.2 &      0.1 &      0.0 &      0.0 &     92.3 &     43.6 &     41.3 &     39.6 &        2.3 &      0.3 &      0.0 \\\\\n",
      "Single+Concat            &      87.6 &     24.6 &      0.0 &      0.0 &     92.7 &     57.6 &     48.3 &     18.5 &        1.9 &      1.4 &      1.9 \\\\\n",
      "Single+ConcatN           &      87.9 &     22.0 &     10.3 &      0.0 &     93.3 &     60.2 &     53.8 &     33.8 &        2.0 &      1.4 &      0.0 \\\\\n",
      "Single+Concat+CoQR-M     &      86.2 &     31.1 &      0.1 &      0.0 &     91.3 &     62.1 &     48.8 &     19.6 &       21.9 &      2.0 &      0.0 \\\\\n",
      "Single+ConcatN+CoQR-MixN &      86.8 &     27.4 &     11.7 &      3.1 &     92.4 &     64.7 &     54.3 &     42.7 &       16.3 &      3.1 &     18.5 \\\\\n",
      "Single+ConcatRel (2x)    &      87.2 &     27.0 &      0.0 &      0.0 &     92.8 &     59.8 &     50.0 &     28.8 &        2.3 &      0.3 &      3.7 \\\\\n",
      "w/ Mix                   &      86.4 &     33.6 &      0.0 &      0.0 &     92.0 &     64.9 &     51.6 &     33.5 &       19.3 &      2.7 &      0.0 \\\\\n",
      "w/ MixN                  &      86.9 &     34.2 &     18.4 &      3.8 &     92.5 &     67.3 &     59.3 &     52.3 &       21.7 &      6.5 &     25.9 \\\\\n",
      "90\\%Single+10\\%Multi       &      88.1 &     45.6 &     22.4 &      2.3 &     92.3 &     69.7 &     57.2 &     39.6 &       58.2 &     34.4 &     13.0 \\\\\n",
      "80\\%Single+20\\%Multi       &      88.5 &     56.4 &     34.2 &      8.8 &     93.0 &     79.3 &     66.8 &     50.4 &       68.2 &     49.3 &     44.4 \\\\\n",
      "\\bottomrule\n",
      "\\end{tabular}\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/tmp/ipykernel_88698/3465050992.py:1: FutureWarning: In future versions `DataFrame.to_latex` is expected to utilise the base implementation of `Styler.to_latex` for formatting and rendering. The arguments signature may therefore change. It is recommended instead to use `DataFrame.style.to_latex` which also contains additional functionality.\n",
      "  print(domain_cnt_table.loc[:,'sgd/group0'].to_latex())\n"
     ]
    }
   ],
   "source": [
    "# Raw LaTeX source for the sgd/group0 per-domain-count table.\n",
    "print(domain_cnt_table['sgd/group0'].to_latex())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.13"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
