{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "# !scp mygpu:TXL/transformer-xl/pytorch/test_res.csv ./test_res_retrieval.csv"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "# !scp gpu6:bulatov/TXL/_git/test_res.csv ./test_res_synthetic_results.csv"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Result tables: one CSV for the retrieval runs and one for the synthetic\n",
    "# (copy / reverse) runs. Older exports (test_res_synthetic.csv,\n",
    "# test_res_mygpu.csv, test_res_synthetic_results_all.csv) are not merged in.\n",
    "df_retrieval, df = (\n",
    "    pd.read_csv(csv_name)\n",
    "    for csv_name in ('test_res_retrieval.csv', 'test_res_synthetic_results.csv')\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "def _clean_results(results):\n",
    "    \"\"\"Return `results` without rows lacking `last_val_acc`, without the\n",
    "    `work_dir` column, and without exact duplicate rows.\n",
    "\n",
    "    `work_dir` is removed before de-duplicating, so rows that differ only in\n",
    "    that column collapse into one.\n",
    "    \"\"\"\n",
    "    return (\n",
    "        results\n",
    "        .dropna(subset=['last_val_acc'])\n",
    "        # errors='ignore' keeps this cell re-runnable: on a second run the\n",
    "        # frame has already lost `work_dir` and a plain drop raises KeyError.\n",
    "        .drop(columns='work_dir', errors='ignore')\n",
    "        .drop_duplicates()\n",
    "    )\n",
    "\n",
    "\n",
    "df = _clean_results(df)\n",
    "\n",
    "# Row filters for the synthetic runs, combined into one mask: keep runs with\n",
    "# max_step > 450000 or tgt_len >= 24, then exclude the listed combinations.\n",
    "keep = (df.max_step > 450000) | (df.tgt_len >= 24)\n",
    "# keep &= ~((df.tgt_len == 12) & (df.max_step == 700000))  # currently disabled\n",
    "keep &= ~((df.tgt_len == 24) & (df.max_step == 300000))\n",
    "keep &= ~((df.tgt_len == 8) & (df.num_mem_tokens == 12))\n",
    "keep &= ~((df.dataset == 'reverse') & (df.max_step == 900000))\n",
    "df = df[keep]\n",
    "\n",
    "df_retrieval = _clean_results(df_retrieval)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Per-task views of the result tables.\n",
    "reverse = df[df['dataset'] == 'reverse']\n",
    "copy = df[df['dataset'] == 'copy']\n",
    "# Vectorised literal substring match. na=False treats a missing dataset name\n",
    "# as a non-match instead of raising, as the per-row `in` test would.\n",
    "is_retrieval = df_retrieval['dataset'].str.contains('retrieval', regex=False, na=False)\n",
    "retrieval = df_retrieval[is_retrieval]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th>test acc</th>\n",
       "      <th>mean acc</th>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>dataset</th>\n",
       "      <th>tgt_len</th>\n",
       "      <th>lr</th>\n",
       "      <th>max_step</th>\n",
       "      <th>num_mem_tokens</th>\n",
       "      <th>mem_len</th>\n",
       "      <th>mem_backprop_depth</th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th rowspan=\"13\" valign=\"top\">copy</th>\n",
       "      <th rowspan=\"3\" valign=\"top\">8</th>\n",
       "      <th rowspan=\"3\" valign=\"top\">0.0001</th>\n",
       "      <th rowspan=\"3\" valign=\"top\">900000</th>\n",
       "      <th rowspan=\"2\" valign=\"top\">0</th>\n",
       "      <th>0</th>\n",
       "      <th>0</th>\n",
       "      <td>(0.102, 0.102, 0.102, 0.101, 0.102)</td>\n",
       "      <td>0.101800</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>8</th>\n",
       "      <th>0</th>\n",
       "      <td>(0.451, 0.189, 0.259, 0.186)</td>\n",
       "      <td>0.271250</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>8</th>\n",
       "      <th>0</th>\n",
       "      <th>0</th>\n",
       "      <td>(1.0, 1.0, 1.0, 1.0, 1.0)</td>\n",
       "      <td>1.000000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th rowspan=\"5\" valign=\"top\">12</th>\n",
       "      <th rowspan=\"5\" valign=\"top\">0.0001</th>\n",
       "      <th rowspan=\"2\" valign=\"top\">700000</th>\n",
       "      <th>0</th>\n",
       "      <th>6</th>\n",
       "      <th>0</th>\n",
       "      <td>(0.657, 0.298, 0.355, 0.856)</td>\n",
       "      <td>0.541500</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>6</th>\n",
       "      <th>0</th>\n",
       "      <th>0</th>\n",
       "      <td>(1.0, 1.0, 1.0, 1.0)</td>\n",
       "      <td>1.000000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th rowspan=\"3\" valign=\"top\">900000</th>\n",
       "      <th rowspan=\"2\" valign=\"top\">0</th>\n",
       "      <th>0</th>\n",
       "      <th>0</th>\n",
       "      <td>(0.102, 0.102, 0.102, 0.102)</td>\n",
       "      <td>0.102000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>12</th>\n",
       "      <th>0</th>\n",
       "      <td>(1.0, 0.904, 0.998, 0.998)</td>\n",
       "      <td>0.975000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>12</th>\n",
       "      <th>0</th>\n",
       "      <th>0</th>\n",
       "      <td>(1.0, 1.0, 1.0)</td>\n",
       "      <td>1.000000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th rowspan=\"4\" valign=\"top\">24</th>\n",
       "      <th rowspan=\"4\" valign=\"top\">0.0001</th>\n",
       "      <th rowspan=\"3\" valign=\"top\">400000</th>\n",
       "      <th rowspan=\"2\" valign=\"top\">0</th>\n",
       "      <th>0</th>\n",
       "      <th>0</th>\n",
       "      <td>(0.12, 0.12, 0.115)</td>\n",
       "      <td>0.118333</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>24</th>\n",
       "      <th>0</th>\n",
       "      <td>(1.0, 1.0, 0.999)</td>\n",
       "      <td>0.999667</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>24</th>\n",
       "      <th>0</th>\n",
       "      <th>0</th>\n",
       "      <td>(1.0, 1.0, 1.0)</td>\n",
       "      <td>1.000000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>500000</th>\n",
       "      <th>12</th>\n",
       "      <th>0</th>\n",
       "      <th>0</th>\n",
       "      <td>(1.0, 1.0)</td>\n",
       "      <td>1.000000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>72</th>\n",
       "      <th>0.0001</th>\n",
       "      <th>400000</th>\n",
       "      <th>0</th>\n",
       "      <th>0</th>\n",
       "      <th>0</th>\n",
       "      <td>(1.0, 1.0, 1.0)</td>\n",
       "      <td>1.000000</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "                                                                                                      test acc  \\\n",
       "dataset tgt_len lr     max_step num_mem_tokens mem_len mem_backprop_depth                                        \n",
       "copy    8       0.0001 900000   0              0       0                   (0.102, 0.102, 0.102, 0.101, 0.102)   \n",
       "                                               8       0                          (0.451, 0.189, 0.259, 0.186)   \n",
       "                                8              0       0                             (1.0, 1.0, 1.0, 1.0, 1.0)   \n",
       "        12      0.0001 700000   0              6       0                          (0.657, 0.298, 0.355, 0.856)   \n",
       "                                6              0       0                                  (1.0, 1.0, 1.0, 1.0)   \n",
       "                       900000   0              0       0                          (0.102, 0.102, 0.102, 0.102)   \n",
       "                                               12      0                            (1.0, 0.904, 0.998, 0.998)   \n",
       "                                12             0       0                                       (1.0, 1.0, 1.0)   \n",
       "        24      0.0001 400000   0              0       0                                   (0.12, 0.12, 0.115)   \n",
       "                                               24      0                                     (1.0, 1.0, 0.999)   \n",
       "                                24             0       0                                       (1.0, 1.0, 1.0)   \n",
       "                       500000   12             0       0                                            (1.0, 1.0)   \n",
       "        72      0.0001 400000   0              0       0                                       (1.0, 1.0, 1.0)   \n",
       "\n",
       "                                                                           mean acc  \n",
       "dataset tgt_len lr     max_step num_mem_tokens mem_len mem_backprop_depth            \n",
       "copy    8       0.0001 900000   0              0       0                   0.101800  \n",
       "                                               8       0                   0.271250  \n",
       "                                8              0       0                   1.000000  \n",
       "        12      0.0001 700000   0              6       0                   0.541500  \n",
       "                                6              0       0                   1.000000  \n",
       "                       900000   0              0       0                   0.102000  \n",
       "                                               12      0                   0.975000  \n",
       "                                12             0       0                   1.000000  \n",
       "        24      0.0001 400000   0              0       0                   0.118333  \n",
       "                                               24      0                   0.999667  \n",
       "                                24             0       0                   1.000000  \n",
       "                       500000   12             0       0                   1.000000  \n",
       "        72      0.0001 400000   0              0       0                   1.000000  "
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Configuration columns that identify one table row ('seed' is not a key).\n",
    "gb_cols = ['dataset', 'tgt_len', 'lr', 'max_step', 'num_mem_tokens', 'mem_len', 'mem_backprop_depth']\n",
    "df_ = copy\n",
    "# Rows without a test accuracy are dropped so the tuples hold only real scores.\n",
    "scored = df_.dropna(subset=['test acc'])\n",
    "gb = scored.groupby(gb_cols).agg({'test acc': tuple})\n",
    "# Select 'test acc' before aggregating: averaging every column first is wasted\n",
    "# work and raises TypeError on pandas >= 2.0 if a non-numeric column remains.\n",
    "gb['mean acc'] = scored.groupby(gb_cols)['test acc'].mean()\n",
    "gb"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th>test acc</th>\n",
       "      <th>mean acc</th>\n",
       "      <th>med acc</th>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>dataset</th>\n",
       "      <th>tgt_len</th>\n",
       "      <th>max_step</th>\n",
       "      <th>num_mem_tokens</th>\n",
       "      <th>mem_len</th>\n",
       "      <th>mem_backprop_depth</th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th rowspan=\"14\" valign=\"top\">reverse</th>\n",
       "      <th rowspan=\"3\" valign=\"top\">8</th>\n",
       "      <th rowspan=\"3\" valign=\"top\">700000</th>\n",
       "      <th rowspan=\"2\" valign=\"top\">0</th>\n",
       "      <th>0</th>\n",
       "      <th>0</th>\n",
       "      <td>(0.101, 0.102, 0.102, 0.102)</td>\n",
       "      <td>0.101750</td>\n",
       "      <td>0.102</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>8</th>\n",
       "      <th>0</th>\n",
       "      <td>(0.627, 0.696, 0.68, 0.66)</td>\n",
       "      <td>0.665750</td>\n",
       "      <td>0.670</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>8</th>\n",
       "      <th>0</th>\n",
       "      <th>0</th>\n",
       "      <td>(1.0, 0.701, 0.702, 1.0, 1.0, 1.0)</td>\n",
       "      <td>0.900500</td>\n",
       "      <td>1.000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th rowspan=\"5\" valign=\"top\">12</th>\n",
       "      <th rowspan=\"5\" valign=\"top\">850000</th>\n",
       "      <th rowspan=\"3\" valign=\"top\">0</th>\n",
       "      <th>0</th>\n",
       "      <th>0</th>\n",
       "      <td>(0.1, 0.1, 0.102, 0.1)</td>\n",
       "      <td>0.100500</td>\n",
       "      <td>0.100</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>6</th>\n",
       "      <th>0</th>\n",
       "      <td>(0.832, 0.8, 0.782, 0.784)</td>\n",
       "      <td>0.799500</td>\n",
       "      <td>0.792</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>12</th>\n",
       "      <th>0</th>\n",
       "      <td>(0.877, 0.872, 0.833)</td>\n",
       "      <td>0.860667</td>\n",
       "      <td>0.872</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>6</th>\n",
       "      <th>0</th>\n",
       "      <th>0</th>\n",
       "      <td>(1.0, 1.0, 1.0)</td>\n",
       "      <td>1.000000</td>\n",
       "      <td>1.000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>12</th>\n",
       "      <th>0</th>\n",
       "      <th>0</th>\n",
       "      <td>(1.0, 1.0, 1.0)</td>\n",
       "      <td>1.000000</td>\n",
       "      <td>1.000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th rowspan=\"5\" valign=\"top\">24</th>\n",
       "      <th rowspan=\"3\" valign=\"top\">400000</th>\n",
       "      <th rowspan=\"2\" valign=\"top\">0</th>\n",
       "      <th>0</th>\n",
       "      <th>0</th>\n",
       "      <td>(0.102, 0.102, 0.102)</td>\n",
       "      <td>0.102000</td>\n",
       "      <td>0.102</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>24</th>\n",
       "      <th>0</th>\n",
       "      <td>(0.959, 0.899, 0.394)</td>\n",
       "      <td>0.750667</td>\n",
       "      <td>0.899</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>24</th>\n",
       "      <th>0</th>\n",
       "      <th>0</th>\n",
       "      <td>(1.0, 1.0, 1.0, 1.0)</td>\n",
       "      <td>1.000000</td>\n",
       "      <td>1.000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th rowspan=\"2\" valign=\"top\">500000</th>\n",
       "      <th>0</th>\n",
       "      <th>12</th>\n",
       "      <th>0</th>\n",
       "      <td>(0.974, 0.996)</td>\n",
       "      <td>0.985000</td>\n",
       "      <td>0.985</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>12</th>\n",
       "      <th>0</th>\n",
       "      <th>0</th>\n",
       "      <td>(1.0, 1.0)</td>\n",
       "      <td>1.000000</td>\n",
       "      <td>1.000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>48</th>\n",
       "      <th>250000</th>\n",
       "      <th>0</th>\n",
       "      <th>0</th>\n",
       "      <th>0</th>\n",
       "      <td>(1.0, 1.0, 1.0)</td>\n",
       "      <td>1.000000</td>\n",
       "      <td>1.000</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "                                                                                              test acc  \\\n",
       "dataset tgt_len max_step num_mem_tokens mem_len mem_backprop_depth                                       \n",
       "reverse 8       700000   0              0       0                         (0.101, 0.102, 0.102, 0.102)   \n",
       "                                        8       0                           (0.627, 0.696, 0.68, 0.66)   \n",
       "                         8              0       0                   (1.0, 0.701, 0.702, 1.0, 1.0, 1.0)   \n",
       "        12      850000   0              0       0                               (0.1, 0.1, 0.102, 0.1)   \n",
       "                                        6       0                           (0.832, 0.8, 0.782, 0.784)   \n",
       "                                        12      0                                (0.877, 0.872, 0.833)   \n",
       "                         6              0       0                                      (1.0, 1.0, 1.0)   \n",
       "                         12             0       0                                      (1.0, 1.0, 1.0)   \n",
       "        24      400000   0              0       0                                (0.102, 0.102, 0.102)   \n",
       "                                        24      0                                (0.959, 0.899, 0.394)   \n",
       "                         24             0       0                                 (1.0, 1.0, 1.0, 1.0)   \n",
       "                500000   0              12      0                                       (0.974, 0.996)   \n",
       "                         12             0       0                                           (1.0, 1.0)   \n",
       "        48      250000   0              0       0                                      (1.0, 1.0, 1.0)   \n",
       "\n",
       "                                                                    mean acc  \\\n",
       "dataset tgt_len max_step num_mem_tokens mem_len mem_backprop_depth             \n",
       "reverse 8       700000   0              0       0                   0.101750   \n",
       "                                        8       0                   0.665750   \n",
       "                         8              0       0                   0.900500   \n",
       "        12      850000   0              0       0                   0.100500   \n",
       "                                        6       0                   0.799500   \n",
       "                                        12      0                   0.860667   \n",
       "                         6              0       0                   1.000000   \n",
       "                         12             0       0                   1.000000   \n",
       "        24      400000   0              0       0                   0.102000   \n",
       "                                        24      0                   0.750667   \n",
       "                         24             0       0                   1.000000   \n",
       "                500000   0              12      0                   0.985000   \n",
       "                         12             0       0                   1.000000   \n",
       "        48      250000   0              0       0                   1.000000   \n",
       "\n",
       "                                                                    med acc  \n",
       "dataset tgt_len max_step num_mem_tokens mem_len mem_backprop_depth           \n",
       "reverse 8       700000   0              0       0                     0.102  \n",
       "                                        8       0                     0.670  \n",
       "                         8              0       0                     1.000  \n",
       "        12      850000   0              0       0                     0.100  \n",
       "                                        6       0                     0.792  \n",
       "                                        12      0                     0.872  \n",
       "                         6              0       0                     1.000  \n",
       "                         12             0       0                     1.000  \n",
       "        24      400000   0              0       0                     0.102  \n",
       "                                        24      0                     0.899  \n",
       "                         24             0       0                     1.000  \n",
       "                500000   0              12      0                     0.985  \n",
       "                         12             0       0                     1.000  \n",
       "        48      250000   0              0       0                     1.000  "
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "df_ = reverse\n",
    "# Configuration columns that identify one table row ('seed' is not a key).\n",
    "gb_cols = ['dataset', 'tgt_len', 'max_step', 'num_mem_tokens', 'mem_len', 'mem_backprop_depth']\n",
    "# Rows without a test accuracy are dropped so the tuples hold only real scores.\n",
    "scored = df_.dropna(subset=['test acc'])\n",
    "gb = scored.groupby(gb_cols).agg({'test acc': tuple})\n",
    "# Select 'test acc' before aggregating: reducing every column first is wasted\n",
    "# work and raises TypeError on pandas >= 2.0 if a non-numeric column remains.\n",
    "acc_by_config = scored.groupby(gb_cols)['test acc']\n",
    "gb['mean acc'] = acc_by_config.mean()\n",
    "gb['med acc'] = acc_by_config.median()\n",
    "gb"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th>last_val_acc</th>\n",
       "      <th>test acc</th>\n",
       "      <th>mean acc</th>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>dataset</th>\n",
       "      <th>tgt_len</th>\n",
       "      <th>max_step</th>\n",
       "      <th>num_mem_tokens</th>\n",
       "      <th>mem_len</th>\n",
       "      <th>mem_backprop_depth</th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th rowspan=\"14\" valign=\"top\">copy</th>\n",
       "      <th rowspan=\"4\" valign=\"top\">8</th>\n",
       "      <th>700000</th>\n",
       "      <th>0</th>\n",
       "      <th>8</th>\n",
       "      <th>0</th>\n",
       "      <td>(0.1,)</td>\n",
       "      <td>(nan,)</td>\n",
       "      <td>0.100000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th rowspan=\"3\" valign=\"top\">900000</th>\n",
       "      <th rowspan=\"2\" valign=\"top\">0</th>\n",
       "      <th>0</th>\n",
       "      <th>0</th>\n",
       "      <td>(0.101, 0.1, 0.099, 0.098, 0.098)</td>\n",
       "      <td>(0.102, 0.102, 0.102, 0.101, 0.102)</td>\n",
       "      <td>0.099200</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>8</th>\n",
       "      <th>0</th>\n",
       "      <td>(0.447, 0.258, 0.188, 0.185, 0.1)</td>\n",
       "      <td>(0.451, 0.259, 0.189, 0.186, nan)</td>\n",
       "      <td>0.235600</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>8</th>\n",
       "      <th>0</th>\n",
       "      <th>0</th>\n",
       "      <td>(1.0, 1.0, 1.0, 1.0, 1.0, 1.0)</td>\n",
       "      <td>(nan, 1.0, 1.0, 1.0, 1.0, 1.0)</td>\n",
       "      <td>1.000000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th rowspan=\"5\" valign=\"top\">12</th>\n",
       "      <th rowspan=\"2\" valign=\"top\">700000</th>\n",
       "      <th>0</th>\n",
       "      <th>6</th>\n",
       "      <th>0</th>\n",
       "      <td>(0.808, 0.658, 0.334, 0.285, 0.103)</td>\n",
       "      <td>(0.856, 0.657, 0.355, 0.298, nan)</td>\n",
       "      <td>0.437600</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>6</th>\n",
       "      <th>0</th>\n",
       "      <th>0</th>\n",
       "      <td>(1.0, 1.0, 1.0, 1.0, 0.1)</td>\n",
       "      <td>(1.0, 1.0, 1.0, 1.0, nan)</td>\n",
       "      <td>0.820000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th rowspan=\"3\" valign=\"top\">900000</th>\n",
       "      <th rowspan=\"2\" valign=\"top\">0</th>\n",
       "      <th>0</th>\n",
       "      <th>0</th>\n",
       "      <td>(0.102, 0.101, 0.1, 0.099)</td>\n",
       "      <td>(0.102, 0.102, 0.102, 0.102)</td>\n",
       "      <td>0.100500</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>12</th>\n",
       "      <th>0</th>\n",
       "      <td>(1.0, 0.998, 0.981, 0.866)</td>\n",
       "      <td>(1.0, 0.998, 0.998, 0.904)</td>\n",
       "      <td>0.961250</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>12</th>\n",
       "      <th>0</th>\n",
       "      <th>0</th>\n",
       "      <td>(1.0, 1.0, 1.0, 0.101)</td>\n",
       "      <td>(1.0, 1.0, 1.0, nan)</td>\n",
       "      <td>0.775250</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th rowspan=\"4\" valign=\"top\">24</th>\n",
       "      <th rowspan=\"3\" valign=\"top\">400000</th>\n",
       "      <th rowspan=\"2\" valign=\"top\">0</th>\n",
       "      <th>0</th>\n",
       "      <th>0</th>\n",
       "      <td>(0.118, 0.116, 0.114)</td>\n",
       "      <td>(0.12, 0.12, 0.115)</td>\n",
       "      <td>0.116000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>24</th>\n",
       "      <th>0</th>\n",
       "      <td>(1.0, 1.0, 0.998)</td>\n",
       "      <td>(1.0, 1.0, 0.999)</td>\n",
       "      <td>0.999333</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>24</th>\n",
       "      <th>0</th>\n",
       "      <th>0</th>\n",
       "      <td>(1.0, 1.0, 1.0)</td>\n",
       "      <td>(1.0, 1.0, 1.0)</td>\n",
       "      <td>1.000000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>500000</th>\n",
       "      <th>12</th>\n",
       "      <th>0</th>\n",
       "      <th>0</th>\n",
       "      <td>(1.0, 1.0)</td>\n",
       "      <td>(1.0, 1.0)</td>\n",
       "      <td>1.000000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>72</th>\n",
       "      <th>400000</th>\n",
       "      <th>0</th>\n",
       "      <th>0</th>\n",
       "      <th>0</th>\n",
       "      <td>(1.0, 1.0, 1.0)</td>\n",
       "      <td>(1.0, 1.0, 1.0)</td>\n",
       "      <td>1.000000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th rowspan=\"15\" valign=\"top\">reverse</th>\n",
       "      <th rowspan=\"3\" valign=\"top\">8</th>\n",
       "      <th rowspan=\"3\" valign=\"top\">700000</th>\n",
       "      <th rowspan=\"2\" valign=\"top\">0</th>\n",
       "      <th>0</th>\n",
       "      <th>0</th>\n",
       "      <td>(0.103, 0.101, 0.1, 0.099)</td>\n",
       "      <td>(0.102, 0.102, 0.102, 0.101)</td>\n",
       "      <td>0.100750</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>8</th>\n",
       "      <th>0</th>\n",
       "      <td>(0.694, 0.678, 0.658, 0.622)</td>\n",
       "      <td>(0.696, 0.68, 0.66, 0.627)</td>\n",
       "      <td>0.663000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>8</th>\n",
       "      <th>0</th>\n",
       "      <th>0</th>\n",
       "      <td>(1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.701, 0.699, 0...</td>\n",
       "      <td>(1.0, 1.0, nan, nan, 1.0, 1.0, 0.702, 0.701, nan)</td>\n",
       "      <td>0.899889</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th rowspan=\"5\" valign=\"top\">12</th>\n",
       "      <th rowspan=\"5\" valign=\"top\">850000</th>\n",
       "      <th rowspan=\"3\" valign=\"top\">0</th>\n",
       "      <th>0</th>\n",
       "      <th>0</th>\n",
       "      <td>(0.104, 0.103, 0.097, 0.094)</td>\n",
       "      <td>(0.1, 0.102, 0.1, 0.1)</td>\n",
       "      <td>0.099500</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>6</th>\n",
       "      <th>0</th>\n",
       "      <td>(0.868, 0.844, 0.804, 0.787, 0.779, 0.66)</td>\n",
       "      <td>(0.832, 0.782, 0.8, 0.784, nan, nan)</td>\n",
       "      <td>0.790333</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>12</th>\n",
       "      <th>0</th>\n",
       "      <td>(0.873, 0.837, 0.786)</td>\n",
       "      <td>(0.877, 0.833, 0.872)</td>\n",
       "      <td>0.832000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>6</th>\n",
       "      <th>0</th>\n",
       "      <th>0</th>\n",
       "      <td>(1.0, 1.0, 1.0)</td>\n",
       "      <td>(1.0, 1.0, 1.0)</td>\n",
       "      <td>1.000000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>12</th>\n",
       "      <th>0</th>\n",
       "      <th>0</th>\n",
       "      <td>(1.0, 1.0, 1.0)</td>\n",
       "      <td>(1.0, 1.0, 1.0)</td>\n",
       "      <td>1.000000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th rowspan=\"6\" valign=\"top\">24</th>\n",
       "      <th>250000</th>\n",
       "      <th>0</th>\n",
       "      <th>24</th>\n",
       "      <th>0</th>\n",
       "      <td>(0.991,)</td>\n",
       "      <td>(nan,)</td>\n",
       "      <td>0.991000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th rowspan=\"3\" valign=\"top\">400000</th>\n",
       "      <th rowspan=\"2\" valign=\"top\">0</th>\n",
       "      <th>0</th>\n",
       "      <th>0</th>\n",
       "      <td>(0.102, 0.101, 0.1)</td>\n",
       "      <td>(0.102, 0.102, 0.102)</td>\n",
       "      <td>0.101000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>24</th>\n",
       "      <th>0</th>\n",
       "      <td>(0.943, 0.883, 0.81, 0.553)</td>\n",
       "      <td>(0.959, nan, 0.899, 0.394)</td>\n",
       "      <td>0.797250</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>24</th>\n",
       "      <th>0</th>\n",
       "      <th>0</th>\n",
       "      <td>(1.0, 1.0, 1.0, 1.0)</td>\n",
       "      <td>(1.0, 1.0, 1.0, 1.0)</td>\n",
       "      <td>1.000000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th rowspan=\"2\" valign=\"top\">500000</th>\n",
       "      <th>0</th>\n",
       "      <th>12</th>\n",
       "      <th>0</th>\n",
       "      <td>(0.993, 0.98)</td>\n",
       "      <td>(0.996, 0.974)</td>\n",
       "      <td>0.986500</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>12</th>\n",
       "      <th>0</th>\n",
       "      <th>0</th>\n",
       "      <td>(1.0, 1.0)</td>\n",
       "      <td>(1.0, 1.0)</td>\n",
       "      <td>1.000000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>48</th>\n",
       "      <th>250000</th>\n",
       "      <th>0</th>\n",
       "      <th>0</th>\n",
       "      <th>0</th>\n",
       "      <td>(1.0, 1.0, 1.0)</td>\n",
       "      <td>(1.0, 1.0, 1.0)</td>\n",
       "      <td>1.000000</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "                                                                                                         last_val_acc  \\\n",
       "dataset tgt_len max_step num_mem_tokens mem_len mem_backprop_depth                                                      \n",
       "copy    8       700000   0              8       0                                                              (0.1,)   \n",
       "                900000   0              0       0                                   (0.101, 0.1, 0.099, 0.098, 0.098)   \n",
       "                                        8       0                                   (0.447, 0.258, 0.188, 0.185, 0.1)   \n",
       "                         8              0       0                                      (1.0, 1.0, 1.0, 1.0, 1.0, 1.0)   \n",
       "        12      700000   0              6       0                                 (0.808, 0.658, 0.334, 0.285, 0.103)   \n",
       "                         6              0       0                                           (1.0, 1.0, 1.0, 1.0, 0.1)   \n",
       "                900000   0              0       0                                          (0.102, 0.101, 0.1, 0.099)   \n",
       "                                        12      0                                          (1.0, 0.998, 0.981, 0.866)   \n",
       "                         12             0       0                                              (1.0, 1.0, 1.0, 0.101)   \n",
       "        24      400000   0              0       0                                               (0.118, 0.116, 0.114)   \n",
       "                                        24      0                                                   (1.0, 1.0, 0.998)   \n",
       "                         24             0       0                                                     (1.0, 1.0, 1.0)   \n",
       "                500000   12             0       0                                                          (1.0, 1.0)   \n",
       "        72      400000   0              0       0                                                     (1.0, 1.0, 1.0)   \n",
       "reverse 8       700000   0              0       0                                          (0.103, 0.101, 0.1, 0.099)   \n",
       "                                        8       0                                        (0.694, 0.678, 0.658, 0.622)   \n",
       "                         8              0       0                   (1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.701, 0.699, 0...   \n",
       "        12      850000   0              0       0                                        (0.104, 0.103, 0.097, 0.094)   \n",
       "                                        6       0                           (0.868, 0.844, 0.804, 0.787, 0.779, 0.66)   \n",
       "                                        12      0                                               (0.873, 0.837, 0.786)   \n",
       "                         6              0       0                                                     (1.0, 1.0, 1.0)   \n",
       "                         12             0       0                                                     (1.0, 1.0, 1.0)   \n",
       "        24      250000   0              24      0                                                            (0.991,)   \n",
       "                400000   0              0       0                                                 (0.102, 0.101, 0.1)   \n",
       "                                        24      0                                         (0.943, 0.883, 0.81, 0.553)   \n",
       "                         24             0       0                                                (1.0, 1.0, 1.0, 1.0)   \n",
       "                500000   0              12      0                                                       (0.993, 0.98)   \n",
       "                         12             0       0                                                          (1.0, 1.0)   \n",
       "        48      250000   0              0       0                                                     (1.0, 1.0, 1.0)   \n",
       "\n",
       "                                                                                                             test acc  \\\n",
       "dataset tgt_len max_step num_mem_tokens mem_len mem_backprop_depth                                                      \n",
       "copy    8       700000   0              8       0                                                              (nan,)   \n",
       "                900000   0              0       0                                 (0.102, 0.102, 0.102, 0.101, 0.102)   \n",
       "                                        8       0                                   (0.451, 0.259, 0.189, 0.186, nan)   \n",
       "                         8              0       0                                      (nan, 1.0, 1.0, 1.0, 1.0, 1.0)   \n",
       "        12      700000   0              6       0                                   (0.856, 0.657, 0.355, 0.298, nan)   \n",
       "                         6              0       0                                           (1.0, 1.0, 1.0, 1.0, nan)   \n",
       "                900000   0              0       0                                        (0.102, 0.102, 0.102, 0.102)   \n",
       "                                        12      0                                          (1.0, 0.998, 0.998, 0.904)   \n",
       "                         12             0       0                                                (1.0, 1.0, 1.0, nan)   \n",
       "        24      400000   0              0       0                                                 (0.12, 0.12, 0.115)   \n",
       "                                        24      0                                                   (1.0, 1.0, 0.999)   \n",
       "                         24             0       0                                                     (1.0, 1.0, 1.0)   \n",
       "                500000   12             0       0                                                          (1.0, 1.0)   \n",
       "        72      400000   0              0       0                                                     (1.0, 1.0, 1.0)   \n",
       "reverse 8       700000   0              0       0                                        (0.102, 0.102, 0.102, 0.101)   \n",
       "                                        8       0                                          (0.696, 0.68, 0.66, 0.627)   \n",
       "                         8              0       0                   (1.0, 1.0, nan, nan, 1.0, 1.0, 0.702, 0.701, nan)   \n",
       "        12      850000   0              0       0                                              (0.1, 0.102, 0.1, 0.1)   \n",
       "                                        6       0                                (0.832, 0.782, 0.8, 0.784, nan, nan)   \n",
       "                                        12      0                                               (0.877, 0.833, 0.872)   \n",
       "                         6              0       0                                                     (1.0, 1.0, 1.0)   \n",
       "                         12             0       0                                                     (1.0, 1.0, 1.0)   \n",
       "        24      250000   0              24      0                                                              (nan,)   \n",
       "                400000   0              0       0                                               (0.102, 0.102, 0.102)   \n",
       "                                        24      0                                          (0.959, nan, 0.899, 0.394)   \n",
       "                         24             0       0                                                (1.0, 1.0, 1.0, 1.0)   \n",
       "                500000   0              12      0                                                      (0.996, 0.974)   \n",
       "                         12             0       0                                                          (1.0, 1.0)   \n",
       "        48      250000   0              0       0                                                     (1.0, 1.0, 1.0)   \n",
       "\n",
       "                                                                    mean acc  \n",
       "dataset tgt_len max_step num_mem_tokens mem_len mem_backprop_depth            \n",
       "copy    8       700000   0              8       0                   0.100000  \n",
       "                900000   0              0       0                   0.099200  \n",
       "                                        8       0                   0.235600  \n",
       "                         8              0       0                   1.000000  \n",
       "        12      700000   0              6       0                   0.437600  \n",
       "                         6              0       0                   0.820000  \n",
       "                900000   0              0       0                   0.100500  \n",
       "                                        12      0                   0.961250  \n",
       "                         12             0       0                   0.775250  \n",
       "        24      400000   0              0       0                   0.116000  \n",
       "                                        24      0                   0.999333  \n",
       "                         24             0       0                   1.000000  \n",
       "                500000   12             0       0                   1.000000  \n",
       "        72      400000   0              0       0                   1.000000  \n",
       "reverse 8       700000   0              0       0                   0.100750  \n",
       "                                        8       0                   0.663000  \n",
       "                         8              0       0                   0.899889  \n",
       "        12      850000   0              0       0                   0.099500  \n",
       "                                        6       0                   0.790333  \n",
       "                                        12      0                   0.832000  \n",
       "                         6              0       0                   1.000000  \n",
       "                         12             0       0                   1.000000  \n",
       "        24      250000   0              24      0                   0.991000  \n",
       "                400000   0              0       0                   0.101000  \n",
       "                                        24      0                   0.797250  \n",
       "                         24             0       0                   1.000000  \n",
       "                500000   0              12      0                   0.986500  \n",
       "                         12             0       0                   1.000000  \n",
       "        48      250000   0              0       0                   1.000000  "
      ]
     },
     "execution_count": 12,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "gb_cols = ['dataset', 'tgt_len', 'max_step',  'num_mem_tokens', 'mem_len', 'mem_backprop_depth']\n",
    "gb = df.dropna(subset=['last_val_acc']).sort_values('last_val_acc', ascending=False).groupby(gb_cols).agg({'last_val_acc': tuple, 'test acc': tuple})\n",
    "\n",
    "gb['mean acc'] = df.groupby(gb_cols).mean()['last_val_acc']\n",
    "gb"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th>test acc</th>\n",
       "      <th>mean acc</th>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>dataset</th>\n",
       "      <th>tgt_len</th>\n",
       "      <th>max_step</th>\n",
       "      <th>num_mem_tokens</th>\n",
       "      <th>mem_len</th>\n",
       "      <th>mem_backprop_depth</th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th rowspan=\"20\" valign=\"top\">retrieval</th>\n",
       "      <th>1</th>\n",
       "      <th>1000000</th>\n",
       "      <th>3</th>\n",
       "      <th>0</th>\n",
       "      <th>1</th>\n",
       "      <td>(0.099,)</td>\n",
       "      <td>0.099000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th rowspan=\"4\" valign=\"top\">2</th>\n",
       "      <th rowspan=\"4\" valign=\"top\">1000000</th>\n",
       "      <th rowspan=\"2\" valign=\"top\">0</th>\n",
       "      <th>0</th>\n",
       "      <th>0</th>\n",
       "      <td>(0.099,)</td>\n",
       "      <td>0.099000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <th>0</th>\n",
       "      <td>(0.965, 0.999, 0.999)</td>\n",
       "      <td>0.987667</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th rowspan=\"2\" valign=\"top\">2</th>\n",
       "      <th rowspan=\"2\" valign=\"top\">0</th>\n",
       "      <th>0</th>\n",
       "      <td>(0.989, 0.922, 0.984)</td>\n",
       "      <td>0.965000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>(0.866,)</td>\n",
       "      <td>0.866000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th rowspan=\"4\" valign=\"top\">3</th>\n",
       "      <th rowspan=\"4\" valign=\"top\">1000000</th>\n",
       "      <th rowspan=\"2\" valign=\"top\">0</th>\n",
       "      <th>0</th>\n",
       "      <th>0</th>\n",
       "      <td>(0.099,)</td>\n",
       "      <td>0.099000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <th>0</th>\n",
       "      <td>(0.669, 0.905, 0.973)</td>\n",
       "      <td>0.849000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th rowspan=\"2\" valign=\"top\">3</th>\n",
       "      <th rowspan=\"2\" valign=\"top\">0</th>\n",
       "      <th>0</th>\n",
       "      <td>(0.986, 0.994, 0.996)</td>\n",
       "      <td>0.992000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>(0.995, 0.96, 0.6)</td>\n",
       "      <td>0.851667</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th rowspan=\"4\" valign=\"top\">4</th>\n",
       "      <th rowspan=\"4\" valign=\"top\">1000000</th>\n",
       "      <th rowspan=\"2\" valign=\"top\">0</th>\n",
       "      <th>0</th>\n",
       "      <th>0</th>\n",
       "      <td>(0.496,)</td>\n",
       "      <td>0.496000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <th>0</th>\n",
       "      <td>(1.0, 1.0, 1.0)</td>\n",
       "      <td>1.000000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th rowspan=\"2\" valign=\"top\">4</th>\n",
       "      <th rowspan=\"2\" valign=\"top\">0</th>\n",
       "      <th>0</th>\n",
       "      <td>(0.999, 1.0, 0.999)</td>\n",
       "      <td>0.999333</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>(1.0, 0.999, 0.999)</td>\n",
       "      <td>0.999333</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th rowspan=\"5\" valign=\"top\">5</th>\n",
       "      <th>700000</th>\n",
       "      <th>5</th>\n",
       "      <th>0</th>\n",
       "      <th>0</th>\n",
       "      <td>(0.963,)</td>\n",
       "      <td>0.963000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th rowspan=\"4\" valign=\"top\">1000000</th>\n",
       "      <th rowspan=\"2\" valign=\"top\">0</th>\n",
       "      <th>0</th>\n",
       "      <th>0</th>\n",
       "      <td>(0.496,)</td>\n",
       "      <td>0.496000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>5</th>\n",
       "      <th>0</th>\n",
       "      <td>(1.0, 1.0, 1.0)</td>\n",
       "      <td>1.000000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th rowspan=\"2\" valign=\"top\">5</th>\n",
       "      <th rowspan=\"2\" valign=\"top\">0</th>\n",
       "      <th>0</th>\n",
       "      <td>(0.999, 0.998, 0.998)</td>\n",
       "      <td>0.998333</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>(0.999, 1.0, 1.0)</td>\n",
       "      <td>0.999667</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th rowspan=\"2\" valign=\"top\">10</th>\n",
       "      <th>250000</th>\n",
       "      <th>0</th>\n",
       "      <th>0</th>\n",
       "      <th>0</th>\n",
       "      <td>(0.228, 0.247, 0.254, 0.237, 0.256, 0.254, 0.22)</td>\n",
       "      <td>0.242286</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>400000</th>\n",
       "      <th>0</th>\n",
       "      <th>0</th>\n",
       "      <th>0</th>\n",
       "      <td>(1.0, 1.0)</td>\n",
       "      <td>1.000000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>retrieval29_ext</th>\n",
       "      <th>60</th>\n",
       "      <th>850000</th>\n",
       "      <th>0</th>\n",
       "      <th>0</th>\n",
       "      <th>0</th>\n",
       "      <td>(0.189, 0.192, 0.189)</td>\n",
       "      <td>0.190000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th rowspan=\"3\" valign=\"top\">retrieval59</th>\n",
       "      <th rowspan=\"3\" valign=\"top\">120</th>\n",
       "      <th>500000</th>\n",
       "      <th>0</th>\n",
       "      <th>0</th>\n",
       "      <th>0</th>\n",
       "      <td>(0.159,)</td>\n",
       "      <td>0.159000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>600000</th>\n",
       "      <th>0</th>\n",
       "      <th>0</th>\n",
       "      <th>0</th>\n",
       "      <td>(0.171,)</td>\n",
       "      <td>0.171000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1000000</th>\n",
       "      <th>0</th>\n",
       "      <th>0</th>\n",
       "      <th>0</th>\n",
       "      <td>(0.166,)</td>\n",
       "      <td>0.166000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>retrieval59_ext</th>\n",
       "      <th>120</th>\n",
       "      <th>600000</th>\n",
       "      <th>0</th>\n",
       "      <th>0</th>\n",
       "      <th>0</th>\n",
       "      <td>(0.152,)</td>\n",
       "      <td>0.152000</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "                                                                                                                    test acc  \\\n",
       "dataset         tgt_len max_step num_mem_tokens mem_len mem_backprop_depth                                                     \n",
       "retrieval       1       1000000  3              0       1                                                           (0.099,)   \n",
       "                2       1000000  0              0       0                                                           (0.099,)   \n",
       "                                                2       0                                              (0.965, 0.999, 0.999)   \n",
       "                                 2              0       0                                              (0.989, 0.922, 0.984)   \n",
       "                                                        1                                                           (0.866,)   \n",
       "                3       1000000  0              0       0                                                           (0.099,)   \n",
       "                                                3       0                                              (0.669, 0.905, 0.973)   \n",
       "                                 3              0       0                                              (0.986, 0.994, 0.996)   \n",
       "                                                        1                                                 (0.995, 0.96, 0.6)   \n",
       "                4       1000000  0              0       0                                                           (0.496,)   \n",
       "                                                4       0                                                    (1.0, 1.0, 1.0)   \n",
       "                                 4              0       0                                                (0.999, 1.0, 0.999)   \n",
       "                                                        1                                                (1.0, 0.999, 0.999)   \n",
       "                5       700000   5              0       0                                                           (0.963,)   \n",
       "                        1000000  0              0       0                                                           (0.496,)   \n",
       "                                                5       0                                                    (1.0, 1.0, 1.0)   \n",
       "                                 5              0       0                                              (0.999, 0.998, 0.998)   \n",
       "                                                        1                                                  (0.999, 1.0, 1.0)   \n",
       "                10      250000   0              0       0                   (0.228, 0.247, 0.254, 0.237, 0.256, 0.254, 0.22)   \n",
       "                        400000   0              0       0                                                         (1.0, 1.0)   \n",
       "retrieval29_ext 60      850000   0              0       0                                              (0.189, 0.192, 0.189)   \n",
       "retrieval59     120     500000   0              0       0                                                           (0.159,)   \n",
       "                        600000   0              0       0                                                           (0.171,)   \n",
       "                        1000000  0              0       0                                                           (0.166,)   \n",
       "retrieval59_ext 120     600000   0              0       0                                                           (0.152,)   \n",
       "\n",
       "                                                                            mean acc  \n",
       "dataset         tgt_len max_step num_mem_tokens mem_len mem_backprop_depth            \n",
       "retrieval       1       1000000  3              0       1                   0.099000  \n",
       "                2       1000000  0              0       0                   0.099000  \n",
       "                                                2       0                   0.987667  \n",
       "                                 2              0       0                   0.965000  \n",
       "                                                        1                   0.866000  \n",
       "                3       1000000  0              0       0                   0.099000  \n",
       "                                                3       0                   0.849000  \n",
       "                                 3              0       0                   0.992000  \n",
       "                                                        1                   0.851667  \n",
       "                4       1000000  0              0       0                   0.496000  \n",
       "                                                4       0                   1.000000  \n",
       "                                 4              0       0                   0.999333  \n",
       "                                                        1                   0.999333  \n",
       "                5       700000   5              0       0                   0.963000  \n",
       "                        1000000  0              0       0                   0.496000  \n",
       "                                                5       0                   1.000000  \n",
       "                                 5              0       0                   0.998333  \n",
       "                                                        1                   0.999667  \n",
       "                10      250000   0              0       0                   0.242286  \n",
       "                        400000   0              0       0                   1.000000  \n",
       "retrieval29_ext 60      850000   0              0       0                   0.190000  \n",
       "retrieval59     120     500000   0              0       0                   0.159000  \n",
       "                        600000   0              0       0                   0.171000  \n",
       "                        1000000  0              0       0                   0.166000  \n",
       "retrieval59_ext 120     600000   0              0       0                   0.152000  "
      ]
     },
     "execution_count": 13,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "df_ = retrieval[retrieval.lr == 0.0001]\n",
    "gb_cols = ['dataset', 'tgt_len', 'max_step',  'num_mem_tokens', 'mem_len', 'mem_backprop_depth']#, 'seed']\n",
    "gb = df_.dropna(subset=['test acc']).groupby(gb_cols).agg({'test acc': tuple})\n",
    "gb['mean acc'] = df_.groupby(gb_cols).mean()['test acc']\n",
    "gb"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 157,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>seed</th>\n",
       "      <th>test loss</th>\n",
       "      <th>test ppl</th>\n",
       "      <th>test acc</th>\n",
       "      <th>last_val_acc</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>1234</td>\n",
       "      <td>0.00000</td>\n",
       "      <td>1.0000</td>\n",
       "      <td>1.000</td>\n",
       "      <td>1.000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>2022</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>0.302</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>34</th>\n",
       "      <td>2022</td>\n",
       "      <td>1.27346</td>\n",
       "      <td>3.5732</td>\n",
       "      <td>0.302</td>\n",
       "      <td>0.300</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>76</th>\n",
       "      <td>2049</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>0.098</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>118</th>\n",
       "      <td>2049</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>1.000</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "     seed  test loss  test ppl  test acc  last_val_acc\n",
       "3    1234    0.00000    1.0000     1.000         1.000\n",
       "4    2022        NaN       NaN       NaN         0.302\n",
       "34   2022    1.27346    3.5732     0.302         0.300\n",
       "76   2049        NaN       NaN       NaN         0.098\n",
       "118  2049        NaN       NaN       NaN         1.000"
      ]
     },
     "execution_count": 157,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "df_ = copy[(copy.tgt_len == 12) & (copy.num_mem_tokens == 12)]\n",
    "df_[[col for col in df_.columns if df_[col].unique().shape[0] > 1]]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 158,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>seed</th>\n",
       "      <th>test loss</th>\n",
       "      <th>test ppl</th>\n",
       "      <th>test acc</th>\n",
       "      <th>last_val_acc</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>12</th>\n",
       "      <td>2049</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>0.103</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>13</th>\n",
       "      <td>1212</td>\n",
       "      <td>0.00000</td>\n",
       "      <td>1.00000</td>\n",
       "      <td>1.000</td>\n",
       "      <td>1.000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>14</th>\n",
       "      <td>2023</td>\n",
       "      <td>1.35207</td>\n",
       "      <td>3.86542</td>\n",
       "      <td>0.273</td>\n",
       "      <td>0.268</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>48</th>\n",
       "      <td>2049</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>1.000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>132</th>\n",
       "      <td>2049</td>\n",
       "      <td>0.00001</td>\n",
       "      <td>1.00001</td>\n",
       "      <td>1.000</td>\n",
       "      <td>1.000</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "     seed  test loss  test ppl  test acc  last_val_acc\n",
       "12   2049        NaN       NaN       NaN         0.103\n",
       "13   1212    0.00000   1.00000     1.000         1.000\n",
       "14   2023    1.35207   3.86542     0.273         0.268\n",
       "48   2049        NaN       NaN       NaN         1.000\n",
       "132  2049    0.00001   1.00001     1.000         1.000"
      ]
     },
     "execution_count": 158,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# df_ = reverse[(reverse.num_mem_tokens == 24) & (reverse.tgt_len == 24)]\n",
    "df_ = reverse[(reverse.num_mem_tokens == 12) & (reverse.tgt_len == 12)]\n",
    "df_[[col for col in df_.columns if df_[col].unique().shape[0] > 1]]#.sort_values('last_val_acc')#.dropna(subset=['test acc'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 98,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "Empty DataFrame\n",
       "Columns: []\n",
       "Index: []"
      ]
     },
     "execution_count": 98,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "df_ = reverse[(reverse.mem_len == ) & (reverse.tgt_len == 8)]\n",
    "df_[[col for col in df_.columns if df_[col].unique().shape[0] > 1]]#.sort_values('last_val_acc')#.dropna(subset=['test acc'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 102,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>lr</th>\n",
       "      <th>lr_min</th>\n",
       "      <th>test loss</th>\n",
       "      <th>test ppl</th>\n",
       "      <th>test acc</th>\n",
       "      <th>last_val_acc</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>115</th>\n",
       "      <td>0.00010</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>2.30265</td>\n",
       "      <td>10.00066</td>\n",
       "      <td>0.102</td>\n",
       "      <td>0.1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>116</th>\n",
       "      <td>0.00025</td>\n",
       "      <td>0.000001</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>1.0</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "          lr    lr_min  test loss  test ppl  test acc  last_val_acc\n",
       "115  0.00010  0.000000    2.30265  10.00066     0.102           0.1\n",
       "116  0.00025  0.000001        NaN       NaN       NaN           1.0"
      ]
     },
     "execution_count": 102,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "df_ = copy[(copy.num_mem_tokens == 12) & (copy.tgt_len == 12)]\n",
    "df_[[col for col in df_.columns if df_[col].unique().shape[0] > 1]]#.sort_values('last_val_acc')#.dropna(subset=['test acc'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 83,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "data              /home/bulatov/bulatov/TXL/data24/data24\n",
       "dataset                                           reverse\n",
       "n_layer                                                 4\n",
       "n_head                                                  4\n",
       "d_head                                                 64\n",
       "                                   ...                   \n",
       "n_nonemb_param                                     921088\n",
       "test loss                                             0.0\n",
       "test ppl                                              1.0\n",
       "test acc                                              1.0\n",
       "last_val_acc                                          1.0\n",
       "Name: 107, Length: 73, dtype: object"
      ]
     },
     "execution_count": 83,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "df_.loc[107]"
   ]
  }
 ],
 "metadata": {
  "interpreter": {
   "hash": "3ad594eba56fa4d30e478d0eb2c02077805d7e655fd2c7b71fc86bcad8bf7b09"
  },
  "kernelspec": {
   "display_name": "Python 3.9.0 64-bit ('cudaenv': conda)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.0"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
