{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 47,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 48,
   "metadata": {},
   "outputs": [],
   "source": [
    "# !scp mygpu:TXL/transformer-xl/pytorch/test_res.csv ./test_res_retrieval.csv"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 57,
   "metadata": {},
   "outputs": [],
   "source": [
    "# !scp gpu6:bulatov/TXL/_git/test_res.csv ./test_res_synthetic_results.csv"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 58,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Read the two result tables that the later cells work from.\n",
    "retrieval_csv = 'test_res_retrieval.csv'\n",
    "synthetic_csv = 'test_res_synthetic_results.csv'\n",
    "\n",
    "df_retrieval = pd.read_csv(retrieval_csv)\n",
    "df = pd.read_csv(synthetic_csv)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 59,
   "metadata": {},
   "outputs": [],
   "source": [
    "def _clean_results(results):\n",
    "    \"\"\"Drop rows lacking `last_val_acc`, the `work_dir` column, and exact duplicate rows.\"\"\"\n",
    "    return (\n",
    "        results\n",
    "        .dropna(subset=['last_val_acc'])\n",
    "        # errors='ignore': re-running this cell must not fail once `work_dir` is already gone\n",
    "        .drop(columns='work_dir', errors='ignore')\n",
    "        .drop_duplicates()\n",
    "    )\n",
    "\n",
    "\n",
    "df = _clean_results(df)\n",
    "df_retrieval = _clean_results(df_retrieval)\n",
    "\n",
    "# Hand-tuned selection of synthetic runs: a row is kept only if it passes every condition below.\n",
    "# (One more exclusion, tgt_len == 12 with max_step == 700000, is currently disabled.)\n",
    "keep = (\n",
    "    ((df.max_step > 450000) | (df.tgt_len >= 24))\n",
    "    & ~((df.tgt_len == 24) & (df.max_step == 300000))\n",
    "    & ~((df.tgt_len == 8) & (df.num_mem_tokens == 12))\n",
    "    & ~((df.dataset == 'reverse') & (df.max_step == 900000))\n",
    ")\n",
    "df = df[keep]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 60,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Split the cleaned tables into one frame per task.\n",
    "reverse = df[df.dataset == 'reverse']\n",
    "copy = df[df.dataset == 'copy']\n",
    "\n",
    "# Literal substring match, so variants such as 'retrieval59_ext' are included.\n",
    "# na=False treats a missing dataset name as a non-match instead of raising.\n",
    "is_retrieval = df_retrieval.dataset.str.contains('retrieval', regex=False, na=False)\n",
    "retrieval = df_retrieval[is_retrieval]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 61,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th>test acc</th>\n",
       "      <th>mean acc</th>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>dataset</th>\n",
       "      <th>tgt_len</th>\n",
       "      <th>lr</th>\n",
       "      <th>max_step</th>\n",
       "      <th>num_mem_tokens</th>\n",
       "      <th>mem_len</th>\n",
       "      <th>mem_backprop_depth</th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th rowspan=\"12\" valign=\"top\">copy</th>\n",
       "      <th rowspan=\"2\" valign=\"top\">8</th>\n",
       "      <th rowspan=\"2\" valign=\"top\">0.0001</th>\n",
       "      <th rowspan=\"2\" valign=\"top\">900000</th>\n",
       "      <th>0</th>\n",
       "      <th>8</th>\n",
       "      <th>0</th>\n",
       "      <td>(0.189, 0.451)</td>\n",
       "      <td>0.320000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>8</th>\n",
       "      <th>0</th>\n",
       "      <th>0</th>\n",
       "      <td>(1.0, 1.0, 1.0, 1.0)</td>\n",
       "      <td>1.000000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th rowspan=\"5\" valign=\"top\">12</th>\n",
       "      <th rowspan=\"5\" valign=\"top\">0.0001</th>\n",
       "      <th rowspan=\"2\" valign=\"top\">700000</th>\n",
       "      <th>0</th>\n",
       "      <th>6</th>\n",
       "      <th>0</th>\n",
       "      <td>(0.657, 0.298, 0.355, 0.856)</td>\n",
       "      <td>0.541500</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>6</th>\n",
       "      <th>0</th>\n",
       "      <th>0</th>\n",
       "      <td>(1.0, 1.0, 1.0, 1.0)</td>\n",
       "      <td>1.000000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th rowspan=\"3\" valign=\"top\">900000</th>\n",
       "      <th rowspan=\"2\" valign=\"top\">0</th>\n",
       "      <th>0</th>\n",
       "      <th>0</th>\n",
       "      <td>(0.102,)</td>\n",
       "      <td>0.102000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>12</th>\n",
       "      <th>0</th>\n",
       "      <td>(0.998, 0.904, 1.0)</td>\n",
       "      <td>0.967333</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>12</th>\n",
       "      <th>0</th>\n",
       "      <th>0</th>\n",
       "      <td>(1.0, 1.0, 1.0)</td>\n",
       "      <td>1.000000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th rowspan=\"4\" valign=\"top\">24</th>\n",
       "      <th rowspan=\"4\" valign=\"top\">0.0001</th>\n",
       "      <th rowspan=\"3\" valign=\"top\">400000</th>\n",
       "      <th rowspan=\"2\" valign=\"top\">0</th>\n",
       "      <th>0</th>\n",
       "      <th>0</th>\n",
       "      <td>(0.12, 0.12, 0.115)</td>\n",
       "      <td>0.118333</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>24</th>\n",
       "      <th>0</th>\n",
       "      <td>(1.0, 1.0, 0.999)</td>\n",
       "      <td>0.999667</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>24</th>\n",
       "      <th>0</th>\n",
       "      <th>0</th>\n",
       "      <td>(1.0, 1.0, 1.0)</td>\n",
       "      <td>1.000000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>500000</th>\n",
       "      <th>12</th>\n",
       "      <th>0</th>\n",
       "      <th>0</th>\n",
       "      <td>(1.0,)</td>\n",
       "      <td>1.000000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>72</th>\n",
       "      <th>0.0001</th>\n",
       "      <th>400000</th>\n",
       "      <th>0</th>\n",
       "      <th>0</th>\n",
       "      <th>0</th>\n",
       "      <td>(1.0, 1.0, 1.0)</td>\n",
       "      <td>1.000000</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "                                                                                               test acc  \\\n",
       "dataset tgt_len lr     max_step num_mem_tokens mem_len mem_backprop_depth                                 \n",
       "copy    8       0.0001 900000   0              8       0                                 (0.189, 0.451)   \n",
       "                                8              0       0                           (1.0, 1.0, 1.0, 1.0)   \n",
       "        12      0.0001 700000   0              6       0                   (0.657, 0.298, 0.355, 0.856)   \n",
       "                                6              0       0                           (1.0, 1.0, 1.0, 1.0)   \n",
       "                       900000   0              0       0                                       (0.102,)   \n",
       "                                               12      0                            (0.998, 0.904, 1.0)   \n",
       "                                12             0       0                                (1.0, 1.0, 1.0)   \n",
       "        24      0.0001 400000   0              0       0                            (0.12, 0.12, 0.115)   \n",
       "                                               24      0                              (1.0, 1.0, 0.999)   \n",
       "                                24             0       0                                (1.0, 1.0, 1.0)   \n",
       "                       500000   12             0       0                                         (1.0,)   \n",
       "        72      0.0001 400000   0              0       0                                (1.0, 1.0, 1.0)   \n",
       "\n",
       "                                                                           mean acc  \n",
       "dataset tgt_len lr     max_step num_mem_tokens mem_len mem_backprop_depth            \n",
       "copy    8       0.0001 900000   0              8       0                   0.320000  \n",
       "                                8              0       0                   1.000000  \n",
       "        12      0.0001 700000   0              6       0                   0.541500  \n",
       "                                6              0       0                   1.000000  \n",
       "                       900000   0              0       0                   0.102000  \n",
       "                                               12      0                   0.967333  \n",
       "                                12             0       0                   1.000000  \n",
       "        24      0.0001 400000   0              0       0                   0.118333  \n",
       "                                               24      0                   0.999667  \n",
       "                                24             0       0                   1.000000  \n",
       "                       500000   12             0       0                   1.000000  \n",
       "        72      0.0001 400000   0              0       0                   1.000000  "
      ]
     },
     "execution_count": 61,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "gb_cols = ['dataset', 'tgt_len', 'lr', 'max_step',  'num_mem_tokens', 'mem_len', 'mem_backprop_depth']#, 'seed']\n",
    "df_ = copy\n",
    "gb = df_.dropna(subset=['test acc']).groupby(gb_cols).agg({'test acc': tuple})\n",
    "gb['mean acc'] = df_.groupby(gb_cols).mean()['test acc']\n",
    "gb"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 62,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th>test acc</th>\n",
       "      <th>mean acc</th>\n",
       "      <th>med acc</th>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>dataset</th>\n",
       "      <th>tgt_len</th>\n",
       "      <th>max_step</th>\n",
       "      <th>num_mem_tokens</th>\n",
       "      <th>mem_len</th>\n",
       "      <th>mem_backprop_depth</th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th rowspan=\"14\" valign=\"top\">reverse</th>\n",
       "      <th rowspan=\"3\" valign=\"top\">8</th>\n",
       "      <th rowspan=\"3\" valign=\"top\">700000</th>\n",
       "      <th rowspan=\"2\" valign=\"top\">0</th>\n",
       "      <th>0</th>\n",
       "      <th>0</th>\n",
       "      <td>(0.102, 0.101)</td>\n",
       "      <td>0.101500</td>\n",
       "      <td>0.1015</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>8</th>\n",
       "      <th>0</th>\n",
       "      <td>(0.627, 0.696, 0.68, 0.66)</td>\n",
       "      <td>0.665750</td>\n",
       "      <td>0.6700</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>8</th>\n",
       "      <th>0</th>\n",
       "      <th>0</th>\n",
       "      <td>(1.0, 0.701, 0.702, 1.0, 1.0, 1.0)</td>\n",
       "      <td>0.900500</td>\n",
       "      <td>1.0000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th rowspan=\"5\" valign=\"top\">12</th>\n",
       "      <th rowspan=\"5\" valign=\"top\">850000</th>\n",
       "      <th rowspan=\"3\" valign=\"top\">0</th>\n",
       "      <th>0</th>\n",
       "      <th>0</th>\n",
       "      <td>(0.102, 0.1)</td>\n",
       "      <td>0.101000</td>\n",
       "      <td>0.1010</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>6</th>\n",
       "      <th>0</th>\n",
       "      <td>(0.832, 0.8, 0.782, 0.784)</td>\n",
       "      <td>0.799500</td>\n",
       "      <td>0.7920</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>12</th>\n",
       "      <th>0</th>\n",
       "      <td>(0.877, 0.872, 0.833)</td>\n",
       "      <td>0.860667</td>\n",
       "      <td>0.8720</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>6</th>\n",
       "      <th>0</th>\n",
       "      <th>0</th>\n",
       "      <td>(1.0, 1.0, 1.0)</td>\n",
       "      <td>1.000000</td>\n",
       "      <td>1.0000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>12</th>\n",
       "      <th>0</th>\n",
       "      <th>0</th>\n",
       "      <td>(1.0, 1.0, 1.0)</td>\n",
       "      <td>1.000000</td>\n",
       "      <td>1.0000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th rowspan=\"5\" valign=\"top\">24</th>\n",
       "      <th rowspan=\"3\" valign=\"top\">400000</th>\n",
       "      <th rowspan=\"2\" valign=\"top\">0</th>\n",
       "      <th>0</th>\n",
       "      <th>0</th>\n",
       "      <td>(0.102, 0.102, 0.102)</td>\n",
       "      <td>0.102000</td>\n",
       "      <td>0.1020</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>24</th>\n",
       "      <th>0</th>\n",
       "      <td>(0.959, 0.899, 0.394, 0.881)</td>\n",
       "      <td>0.783250</td>\n",
       "      <td>0.8900</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>24</th>\n",
       "      <th>0</th>\n",
       "      <th>0</th>\n",
       "      <td>(1.0, 1.0, 1.0, 1.0)</td>\n",
       "      <td>1.000000</td>\n",
       "      <td>1.0000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th rowspan=\"2\" valign=\"top\">500000</th>\n",
       "      <th>0</th>\n",
       "      <th>12</th>\n",
       "      <th>0</th>\n",
       "      <td>(0.996,)</td>\n",
       "      <td>0.996000</td>\n",
       "      <td>0.9960</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>12</th>\n",
       "      <th>0</th>\n",
       "      <th>0</th>\n",
       "      <td>(1.0,)</td>\n",
       "      <td>1.000000</td>\n",
       "      <td>1.0000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>48</th>\n",
       "      <th>250000</th>\n",
       "      <th>0</th>\n",
       "      <th>0</th>\n",
       "      <th>0</th>\n",
       "      <td>(1.0, 1.0, 1.0)</td>\n",
       "      <td>1.000000</td>\n",
       "      <td>1.0000</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "                                                                                              test acc  \\\n",
       "dataset tgt_len max_step num_mem_tokens mem_len mem_backprop_depth                                       \n",
       "reverse 8       700000   0              0       0                                       (0.102, 0.101)   \n",
       "                                        8       0                           (0.627, 0.696, 0.68, 0.66)   \n",
       "                         8              0       0                   (1.0, 0.701, 0.702, 1.0, 1.0, 1.0)   \n",
       "        12      850000   0              0       0                                         (0.102, 0.1)   \n",
       "                                        6       0                           (0.832, 0.8, 0.782, 0.784)   \n",
       "                                        12      0                                (0.877, 0.872, 0.833)   \n",
       "                         6              0       0                                      (1.0, 1.0, 1.0)   \n",
       "                         12             0       0                                      (1.0, 1.0, 1.0)   \n",
       "        24      400000   0              0       0                                (0.102, 0.102, 0.102)   \n",
       "                                        24      0                         (0.959, 0.899, 0.394, 0.881)   \n",
       "                         24             0       0                                 (1.0, 1.0, 1.0, 1.0)   \n",
       "                500000   0              12      0                                             (0.996,)   \n",
       "                         12             0       0                                               (1.0,)   \n",
       "        48      250000   0              0       0                                      (1.0, 1.0, 1.0)   \n",
       "\n",
       "                                                                    mean acc  \\\n",
       "dataset tgt_len max_step num_mem_tokens mem_len mem_backprop_depth             \n",
       "reverse 8       700000   0              0       0                   0.101500   \n",
       "                                        8       0                   0.665750   \n",
       "                         8              0       0                   0.900500   \n",
       "        12      850000   0              0       0                   0.101000   \n",
       "                                        6       0                   0.799500   \n",
       "                                        12      0                   0.860667   \n",
       "                         6              0       0                   1.000000   \n",
       "                         12             0       0                   1.000000   \n",
       "        24      400000   0              0       0                   0.102000   \n",
       "                                        24      0                   0.783250   \n",
       "                         24             0       0                   1.000000   \n",
       "                500000   0              12      0                   0.996000   \n",
       "                         12             0       0                   1.000000   \n",
       "        48      250000   0              0       0                   1.000000   \n",
       "\n",
       "                                                                    med acc  \n",
       "dataset tgt_len max_step num_mem_tokens mem_len mem_backprop_depth           \n",
       "reverse 8       700000   0              0       0                    0.1015  \n",
       "                                        8       0                    0.6700  \n",
       "                         8              0       0                    1.0000  \n",
       "        12      850000   0              0       0                    0.1010  \n",
       "                                        6       0                    0.7920  \n",
       "                                        12      0                    0.8720  \n",
       "                         6              0       0                    1.0000  \n",
       "                         12             0       0                    1.0000  \n",
       "        24      400000   0              0       0                    0.1020  \n",
       "                                        24      0                    0.8900  \n",
       "                         24             0       0                    1.0000  \n",
       "                500000   0              12      0                    0.9960  \n",
       "                         12             0       0                    1.0000  \n",
       "        48      250000   0              0       0                    1.0000  "
      ]
     },
     "execution_count": 62,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "df_ = reverse\n",
    "gb_cols = ['dataset', 'tgt_len', 'max_step',  'num_mem_tokens', 'mem_len', 'mem_backprop_depth']#, 'seed']\n",
    "gb = df_.dropna(subset=['test acc']).groupby(gb_cols).agg({'test acc': tuple})\n",
    "gb['mean acc'] = df_.groupby(gb_cols).mean()['test acc']\n",
    "gb['med acc'] = df_.groupby(gb_cols).median()['test acc']\n",
    "gb['med acc'] = df_.groupby(gb_cols).median()['test acc']\n",
    "gb"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 41,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th>test acc</th>\n",
       "      <th>mean acc</th>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>dataset</th>\n",
       "      <th>tgt_len</th>\n",
       "      <th>max_step</th>\n",
       "      <th>num_mem_tokens</th>\n",
       "      <th>mem_len</th>\n",
       "      <th>mem_backprop_depth</th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th rowspan=\"14\" valign=\"top\">retrieval</th>\n",
       "      <th rowspan=\"2\" valign=\"top\">2</th>\n",
       "      <th rowspan=\"2\" valign=\"top\">1000000</th>\n",
       "      <th>0</th>\n",
       "      <th>2</th>\n",
       "      <th>0</th>\n",
       "      <td>(0.965, 0.999, 0.999)</td>\n",
       "      <td>0.987667</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <th>0</th>\n",
       "      <th>0</th>\n",
       "      <td>(0.989, 0.922, 0.984)</td>\n",
       "      <td>0.965000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th rowspan=\"3\" valign=\"top\">3</th>\n",
       "      <th rowspan=\"3\" valign=\"top\">1000000</th>\n",
       "      <th>0</th>\n",
       "      <th>3</th>\n",
       "      <th>0</th>\n",
       "      <td>(0.669, 0.905, 0.973)</td>\n",
       "      <td>0.849000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th rowspan=\"2\" valign=\"top\">3</th>\n",
       "      <th rowspan=\"2\" valign=\"top\">0</th>\n",
       "      <th>0</th>\n",
       "      <td>(0.986, 0.994, 0.996)</td>\n",
       "      <td>0.992000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>(0.96,)</td>\n",
       "      <td>0.960000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th rowspan=\"3\" valign=\"top\">4</th>\n",
       "      <th rowspan=\"3\" valign=\"top\">1000000</th>\n",
       "      <th>0</th>\n",
       "      <th>4</th>\n",
       "      <th>0</th>\n",
       "      <td>(1.0, 1.0, 1.0)</td>\n",
       "      <td>1.000000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th rowspan=\"2\" valign=\"top\">4</th>\n",
       "      <th rowspan=\"2\" valign=\"top\">0</th>\n",
       "      <th>0</th>\n",
       "      <td>(0.999, 1.0, 0.999)</td>\n",
       "      <td>0.999333</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>(0.999, 0.999)</td>\n",
       "      <td>0.999000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th rowspan=\"4\" valign=\"top\">5</th>\n",
       "      <th>700000</th>\n",
       "      <th>5</th>\n",
       "      <th>0</th>\n",
       "      <th>0</th>\n",
       "      <td>(0.963,)</td>\n",
       "      <td>0.963000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th rowspan=\"3\" valign=\"top\">1000000</th>\n",
       "      <th>0</th>\n",
       "      <th>5</th>\n",
       "      <th>0</th>\n",
       "      <td>(1.0, 1.0, 1.0)</td>\n",
       "      <td>1.000000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th rowspan=\"2\" valign=\"top\">5</th>\n",
       "      <th rowspan=\"2\" valign=\"top\">0</th>\n",
       "      <th>0</th>\n",
       "      <td>(0.999, 0.998, 0.998)</td>\n",
       "      <td>0.998333</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>(1.0, 1.0)</td>\n",
       "      <td>1.000000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th rowspan=\"2\" valign=\"top\">10</th>\n",
       "      <th>250000</th>\n",
       "      <th>0</th>\n",
       "      <th>0</th>\n",
       "      <th>0</th>\n",
       "      <td>(0.228, 0.247, 0.254, 0.237, 0.256, 0.254, 0.22)</td>\n",
       "      <td>0.242286</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>400000</th>\n",
       "      <th>0</th>\n",
       "      <th>0</th>\n",
       "      <th>0</th>\n",
       "      <td>(1.0, 1.0)</td>\n",
       "      <td>1.000000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>retrieval29_ext</th>\n",
       "      <th>60</th>\n",
       "      <th>850000</th>\n",
       "      <th>0</th>\n",
       "      <th>0</th>\n",
       "      <th>0</th>\n",
       "      <td>(0.189, 0.192, 0.189)</td>\n",
       "      <td>0.190000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th rowspan=\"3\" valign=\"top\">retrieval59</th>\n",
       "      <th rowspan=\"3\" valign=\"top\">120</th>\n",
       "      <th>500000</th>\n",
       "      <th>0</th>\n",
       "      <th>0</th>\n",
       "      <th>0</th>\n",
       "      <td>(0.159,)</td>\n",
       "      <td>0.159000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>600000</th>\n",
       "      <th>0</th>\n",
       "      <th>0</th>\n",
       "      <th>0</th>\n",
       "      <td>(0.171,)</td>\n",
       "      <td>0.171000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1000000</th>\n",
       "      <th>0</th>\n",
       "      <th>0</th>\n",
       "      <th>0</th>\n",
       "      <td>(0.166,)</td>\n",
       "      <td>0.166000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>retrieval59_ext</th>\n",
       "      <th>120</th>\n",
       "      <th>600000</th>\n",
       "      <th>0</th>\n",
       "      <th>0</th>\n",
       "      <th>0</th>\n",
       "      <td>(0.152,)</td>\n",
       "      <td>0.152000</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "                                                                                                                    test acc  \\\n",
       "dataset         tgt_len max_step num_mem_tokens mem_len mem_backprop_depth                                                     \n",
       "retrieval       2       1000000  0              2       0                                              (0.965, 0.999, 0.999)   \n",
       "                                 2              0       0                                              (0.989, 0.922, 0.984)   \n",
       "                3       1000000  0              3       0                                              (0.669, 0.905, 0.973)   \n",
       "                                 3              0       0                                              (0.986, 0.994, 0.996)   \n",
       "                                                        1                                                            (0.96,)   \n",
       "                4       1000000  0              4       0                                                    (1.0, 1.0, 1.0)   \n",
       "                                 4              0       0                                                (0.999, 1.0, 0.999)   \n",
       "                                                        1                                                     (0.999, 0.999)   \n",
       "                5       700000   5              0       0                                                           (0.963,)   \n",
       "                        1000000  0              5       0                                                    (1.0, 1.0, 1.0)   \n",
       "                                 5              0       0                                              (0.999, 0.998, 0.998)   \n",
       "                                                        1                                                         (1.0, 1.0)   \n",
       "                10      250000   0              0       0                   (0.228, 0.247, 0.254, 0.237, 0.256, 0.254, 0.22)   \n",
       "                        400000   0              0       0                                                         (1.0, 1.0)   \n",
       "retrieval29_ext 60      850000   0              0       0                                              (0.189, 0.192, 0.189)   \n",
       "retrieval59     120     500000   0              0       0                                                           (0.159,)   \n",
       "                        600000   0              0       0                                                           (0.171,)   \n",
       "                        1000000  0              0       0                                                           (0.166,)   \n",
       "retrieval59_ext 120     600000   0              0       0                                                           (0.152,)   \n",
       "\n",
       "                                                                            mean acc  \n",
       "dataset         tgt_len max_step num_mem_tokens mem_len mem_backprop_depth            \n",
       "retrieval       2       1000000  0              2       0                   0.987667  \n",
       "                                 2              0       0                   0.965000  \n",
       "                3       1000000  0              3       0                   0.849000  \n",
       "                                 3              0       0                   0.992000  \n",
       "                                                        1                   0.960000  \n",
       "                4       1000000  0              4       0                   1.000000  \n",
       "                                 4              0       0                   0.999333  \n",
       "                                                        1                   0.999000  \n",
       "                5       700000   5              0       0                   0.963000  \n",
       "                        1000000  0              5       0                   1.000000  \n",
       "                                 5              0       0                   0.998333  \n",
       "                                                        1                   1.000000  \n",
       "                10      250000   0              0       0                   0.242286  \n",
       "                        400000   0              0       0                   1.000000  \n",
       "retrieval29_ext 60      850000   0              0       0                   0.190000  \n",
       "retrieval59     120     500000   0              0       0                   0.159000  \n",
       "                        600000   0              0       0                   0.171000  \n",
       "                        1000000  0              0       0                   0.166000  \n",
       "retrieval59_ext 120     600000   0              0       0                   0.152000  "
      ]
     },
     "execution_count": 41,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "df_ = retrieval[retrieval.lr == 0.0001]\n",
    "gb_cols = ['dataset', 'tgt_len', 'max_step',  'num_mem_tokens', 'mem_len', 'mem_backprop_depth']#, 'seed']\n",
    "gb = df_.dropna(subset=['test acc']).groupby(gb_cols).agg({'test acc': tuple})\n",
    "gb['mean acc'] = df_.groupby(gb_cols).mean()['test acc']\n",
    "gb"
   ]
  }
 ],
 "metadata": {
  "interpreter": {
   "hash": "3ad594eba56fa4d30e478d0eb2c02077805d7e655fd2c7b71fc86bcad8bf7b09"
  },
  "kernelspec": {
   "display_name": "Python 3.9.0 64-bit ('cudaenv': conda)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.0"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
