{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "from pkg.utils.analyze import summarize_outputs_in_df\n",
    "\n",
    "# Root of all experiment outputs; set the KT_OUTPUTS_DIR env var to point elsewhere.\n",
    "# The default reproduces the original hardcoded location exactly.\n",
    "OUTPUTS_DIR = os.environ.get(\"KT_OUTPUTS_DIR\", \"/home/knowledge-tracing/outputs\")\n",
    "SWEEP_DIR = f\"{OUTPUTS_DIR}/ktst_benchmark_minus_assist2009_sweep\"\n",
    "\n",
    "# train and eval outputs of the sweep live in sibling subdirectories of SWEEP_DIR\n",
    "train_path = f\"{SWEEP_DIR}/train\"\n",
    "test_path = f\"{SWEEP_DIR}/eval\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the training-sweep outputs into a single DataFrame (one row per run, see the run_dir column).\n",
    "df_sweep = summarize_outputs_in_df(outputs_dir=train_path)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "data.dataset  data.val_fold_idx  data.format          model.aggregation  model.attn_variant           \n",
       "algebra2005   0                  combinatorial_dense  q_mean_c           learnable_alibi_monotonic_q_k    100\n",
       "                                 set_dense            q_mean_c           learnable_alibi_monotonic_q_k    100\n",
       "                                                      self_attn_all      learnable_alibi_monotonic_q_k    100\n",
       "              1                  combinatorial_dense  q_mean_c           learnable_alibi_monotonic_q_k    100\n",
       "                                 set_dense            q_mean_c           learnable_alibi_monotonic_q_k    100\n",
       "                                                                                                         ... \n",
       "statics2011   0                  set_dense            q_mean_c           learnable_alibi_monotonic_q_k    100\n",
       "              1                  set_dense            q_mean_c           learnable_alibi_monotonic_q_k    100\n",
       "              2                  set_dense            q_mean_c           learnable_alibi_monotonic_q_k    100\n",
       "              3                  set_dense            q_mean_c           learnable_alibi_monotonic_q_k    100\n",
       "              4                  set_dense            q_mean_c           learnable_alibi_monotonic_q_k    100\n",
       "Name: val_auc, Length: 65, dtype: int64"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Columns that jointly identify one sweep configuration.\n",
    "group_cols = [\n",
    "    \"data.dataset\",\n",
    "    \"data.val_fold_idx\",\n",
    "    \"data.format\",\n",
    "    \"model.aggregation\",\n",
    "    \"model.attn_variant\",\n",
    "]\n",
    "\n",
    "# Number of runs with a non-null val_auc for each configuration.\n",
    "df_sweep_groupby_count = df_sweep.groupby(group_cols)[\"val_auc\"].count()\n",
    "df_sweep_groupby_count"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "val_auc\n",
       "100    65\n",
       "Name: count, dtype: int64"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Sanity check: expect 65 sweep configurations (dataset x fold x format x aggregation x attn variant),\n",
    "# each with 100 runs that reported a non-null val_auc. Fail loudly instead of relying on eyeballing.\n",
    "assert len(df_sweep_groupby_count) == 65, f\"expected 65 configurations, got {len(df_sweep_groupby_count)}\"\n",
    "assert (df_sweep_groupby_count == 100).all(), \"some configurations do not have exactly 100 runs\"\n",
    "df_sweep_groupby_count.value_counts()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>index</th>\n",
       "      <th>val_time</th>\n",
       "      <th>val_metric</th>\n",
       "      <th>val_auc</th>\n",
       "      <th>val_accuracy</th>\n",
       "      <th>val_epoch_mean_loss</th>\n",
       "      <th>epoch</th>\n",
       "      <th>iteration</th>\n",
       "      <th>trainable_params</th>\n",
       "      <th>best_iteration</th>\n",
       "      <th>...</th>\n",
       "      <th>data.batch_size</th>\n",
       "      <th>data.batch_size_val</th>\n",
       "      <th>data.val_fold_idx</th>\n",
       "      <th>optimizer._target_</th>\n",
       "      <th>optimizer._recursive_</th>\n",
       "      <th>optimizer.lr</th>\n",
       "      <th>optimizer.weight_decay</th>\n",
       "      <th>optimizer.theta_lr</th>\n",
       "      <th>run_dir</th>\n",
       "      <th>outputs_dir</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>0</td>\n",
       "      <td>0.994183</td>\n",
       "      <td>0.844278</td>\n",
       "      <td>0.844278</td>\n",
       "      <td>0.823307</td>\n",
       "      <td>0.393108</td>\n",
       "      <td>18</td>\n",
       "      <td>450</td>\n",
       "      <td>50079585</td>\n",
       "      <td>450</td>\n",
       "      <td>...</td>\n",
       "      <td>128</td>\n",
       "      <td>64</td>\n",
       "      <td>2</td>\n",
       "      <td>torch.optim.Adam</td>\n",
       "      <td>False</td>\n",
       "      <td>0.00010</td>\n",
       "      <td>0.00001</td>\n",
       "      <td>0.0010</td>\n",
       "      <td>/home/knowledge-tracing/outputs/ktst_benchmark...</td>\n",
       "      <td>/home/knowledge-tracing/outputs/ktst_benchmark...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>0</td>\n",
       "      <td>0.796662</td>\n",
       "      <td>0.848743</td>\n",
       "      <td>0.848743</td>\n",
       "      <td>0.827729</td>\n",
       "      <td>0.390337</td>\n",
       "      <td>48</td>\n",
       "      <td>1200</td>\n",
       "      <td>50079585</td>\n",
       "      <td>1200</td>\n",
       "      <td>...</td>\n",
       "      <td>128</td>\n",
       "      <td>64</td>\n",
       "      <td>2</td>\n",
       "      <td>torch.optim.Adam</td>\n",
       "      <td>False</td>\n",
       "      <td>0.00005</td>\n",
       "      <td>0.00005</td>\n",
       "      <td>0.0001</td>\n",
       "      <td>/home/knowledge-tracing/outputs/ktst_benchmark...</td>\n",
       "      <td>/home/knowledge-tracing/outputs/ktst_benchmark...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>0</td>\n",
       "      <td>0.586055</td>\n",
       "      <td>0.840709</td>\n",
       "      <td>0.840709</td>\n",
       "      <td>0.822931</td>\n",
       "      <td>0.400883</td>\n",
       "      <td>20</td>\n",
       "      <td>500</td>\n",
       "      <td>23297437</td>\n",
       "      <td>500</td>\n",
       "      <td>...</td>\n",
       "      <td>128</td>\n",
       "      <td>64</td>\n",
       "      <td>2</td>\n",
       "      <td>torch.optim.Adam</td>\n",
       "      <td>False</td>\n",
       "      <td>0.00010</td>\n",
       "      <td>0.00001</td>\n",
       "      <td>0.0001</td>\n",
       "      <td>/home/knowledge-tracing/outputs/ktst_benchmark...</td>\n",
       "      <td>/home/knowledge-tracing/outputs/ktst_benchmark...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>0</td>\n",
       "      <td>0.635283</td>\n",
       "      <td>0.846317</td>\n",
       "      <td>0.846317</td>\n",
       "      <td>0.826301</td>\n",
       "      <td>0.390812</td>\n",
       "      <td>47</td>\n",
       "      <td>1175</td>\n",
       "      <td>23826649</td>\n",
       "      <td>1175</td>\n",
       "      <td>...</td>\n",
       "      <td>128</td>\n",
       "      <td>64</td>\n",
       "      <td>2</td>\n",
       "      <td>torch.optim.Adam</td>\n",
       "      <td>False</td>\n",
       "      <td>0.00005</td>\n",
       "      <td>0.00005</td>\n",
       "      <td>0.0010</td>\n",
       "      <td>/home/knowledge-tracing/outputs/ktst_benchmark...</td>\n",
       "      <td>/home/knowledge-tracing/outputs/ktst_benchmark...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>0</td>\n",
       "      <td>0.717781</td>\n",
       "      <td>0.845477</td>\n",
       "      <td>0.845477</td>\n",
       "      <td>0.825349</td>\n",
       "      <td>0.391653</td>\n",
       "      <td>50</td>\n",
       "      <td>1250</td>\n",
       "      <td>23166785</td>\n",
       "      <td>1250</td>\n",
       "      <td>...</td>\n",
       "      <td>128</td>\n",
       "      <td>64</td>\n",
       "      <td>2</td>\n",
       "      <td>torch.optim.Adam</td>\n",
       "      <td>False</td>\n",
       "      <td>0.00005</td>\n",
       "      <td>0.00005</td>\n",
       "      <td>0.0005</td>\n",
       "      <td>/home/knowledge-tracing/outputs/ktst_benchmark...</td>\n",
       "      <td>/home/knowledge-tracing/outputs/ktst_benchmark...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>...</th>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>6495</th>\n",
       "      <td>0</td>\n",
       "      <td>1.270265</td>\n",
       "      <td>0.736107</td>\n",
       "      <td>0.736107</td>\n",
       "      <td>0.709592</td>\n",
       "      <td>0.566432</td>\n",
       "      <td>64</td>\n",
       "      <td>3328</td>\n",
       "      <td>3730425</td>\n",
       "      <td>3328</td>\n",
       "      <td>...</td>\n",
       "      <td>128</td>\n",
       "      <td>64</td>\n",
       "      <td>4</td>\n",
       "      <td>torch.optim.Adam</td>\n",
       "      <td>False</td>\n",
       "      <td>0.00005</td>\n",
       "      <td>0.00005</td>\n",
       "      <td>0.0001</td>\n",
       "      <td>/home/knowledge-tracing/outputs/ktst_benchmark...</td>\n",
       "      <td>/home/knowledge-tracing/outputs/ktst_benchmark...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>6496</th>\n",
       "      <td>0</td>\n",
       "      <td>1.646861</td>\n",
       "      <td>0.734807</td>\n",
       "      <td>0.734807</td>\n",
       "      <td>0.709019</td>\n",
       "      <td>0.569656</td>\n",
       "      <td>47</td>\n",
       "      <td>2444</td>\n",
       "      <td>4721825</td>\n",
       "      <td>2444</td>\n",
       "      <td>...</td>\n",
       "      <td>128</td>\n",
       "      <td>64</td>\n",
       "      <td>4</td>\n",
       "      <td>torch.optim.Adam</td>\n",
       "      <td>False</td>\n",
       "      <td>0.00005</td>\n",
       "      <td>0.00005</td>\n",
       "      <td>0.0005</td>\n",
       "      <td>/home/knowledge-tracing/outputs/ktst_benchmark...</td>\n",
       "      <td>/home/knowledge-tracing/outputs/ktst_benchmark...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>6497</th>\n",
       "      <td>0</td>\n",
       "      <td>1.359858</td>\n",
       "      <td>0.735501</td>\n",
       "      <td>0.735501</td>\n",
       "      <td>0.706811</td>\n",
       "      <td>0.568942</td>\n",
       "      <td>74</td>\n",
       "      <td>3848</td>\n",
       "      <td>3928705</td>\n",
       "      <td>3848</td>\n",
       "      <td>...</td>\n",
       "      <td>128</td>\n",
       "      <td>64</td>\n",
       "      <td>4</td>\n",
       "      <td>torch.optim.Adam</td>\n",
       "      <td>False</td>\n",
       "      <td>0.00005</td>\n",
       "      <td>0.00005</td>\n",
       "      <td>0.0001</td>\n",
       "      <td>/home/knowledge-tracing/outputs/ktst_benchmark...</td>\n",
       "      <td>/home/knowledge-tracing/outputs/ktst_benchmark...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>6498</th>\n",
       "      <td>0</td>\n",
       "      <td>0.740544</td>\n",
       "      <td>0.735514</td>\n",
       "      <td>0.735514</td>\n",
       "      <td>0.709722</td>\n",
       "      <td>0.567375</td>\n",
       "      <td>71</td>\n",
       "      <td>3692</td>\n",
       "      <td>2672057</td>\n",
       "      <td>3692</td>\n",
       "      <td>...</td>\n",
       "      <td>128</td>\n",
       "      <td>64</td>\n",
       "      <td>4</td>\n",
       "      <td>torch.optim.Adam</td>\n",
       "      <td>False</td>\n",
       "      <td>0.00005</td>\n",
       "      <td>0.00005</td>\n",
       "      <td>0.0001</td>\n",
       "      <td>/home/knowledge-tracing/outputs/ktst_benchmark...</td>\n",
       "      <td>/home/knowledge-tracing/outputs/ktst_benchmark...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>6499</th>\n",
       "      <td>0</td>\n",
       "      <td>0.856531</td>\n",
       "      <td>0.734903</td>\n",
       "      <td>0.734903</td>\n",
       "      <td>0.708326</td>\n",
       "      <td>0.568338</td>\n",
       "      <td>49</td>\n",
       "      <td>2548</td>\n",
       "      <td>3071413</td>\n",
       "      <td>2548</td>\n",
       "      <td>...</td>\n",
       "      <td>128</td>\n",
       "      <td>64</td>\n",
       "      <td>4</td>\n",
       "      <td>torch.optim.Adam</td>\n",
       "      <td>False</td>\n",
       "      <td>0.00010</td>\n",
       "      <td>0.00005</td>\n",
       "      <td>0.0001</td>\n",
       "      <td>/home/knowledge-tracing/outputs/ktst_benchmark...</td>\n",
       "      <td>/home/knowledge-tracing/outputs/ktst_benchmark...</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "<p>6500 rows × 53 columns</p>\n",
       "</div>"
      ],
      "text/plain": [
       "      index  val_time  val_metric   val_auc  val_accuracy  \\\n",
       "0         0  0.994183    0.844278  0.844278      0.823307   \n",
       "1         0  0.796662    0.848743  0.848743      0.827729   \n",
       "2         0  0.586055    0.840709  0.840709      0.822931   \n",
       "3         0  0.635283    0.846317  0.846317      0.826301   \n",
       "4         0  0.717781    0.845477  0.845477      0.825349   \n",
       "...     ...       ...         ...       ...           ...   \n",
       "6495      0  1.270265    0.736107  0.736107      0.709592   \n",
       "6496      0  1.646861    0.734807  0.734807      0.709019   \n",
       "6497      0  1.359858    0.735501  0.735501      0.706811   \n",
       "6498      0  0.740544    0.735514  0.735514      0.709722   \n",
       "6499      0  0.856531    0.734903  0.734903      0.708326   \n",
       "\n",
       "      val_epoch_mean_loss  epoch  iteration  trainable_params  best_iteration  \\\n",
       "0                0.393108     18        450          50079585             450   \n",
       "1                0.390337     48       1200          50079585            1200   \n",
       "2                0.400883     20        500          23297437             500   \n",
       "3                0.390812     47       1175          23826649            1175   \n",
       "4                0.391653     50       1250          23166785            1250   \n",
       "...                   ...    ...        ...               ...             ...   \n",
       "6495             0.566432     64       3328           3730425            3328   \n",
       "6496             0.569656     47       2444           4721825            2444   \n",
       "6497             0.568942     74       3848           3928705            3848   \n",
       "6498             0.567375     71       3692           2672057            3692   \n",
       "6499             0.568338     49       2548           3071413            2548   \n",
       "\n",
       "      ...  data.batch_size data.batch_size_val  data.val_fold_idx  \\\n",
       "0     ...              128                  64                  2   \n",
       "1     ...              128                  64                  2   \n",
       "2     ...              128                  64                  2   \n",
       "3     ...              128                  64                  2   \n",
       "4     ...              128                  64                  2   \n",
       "...   ...              ...                 ...                ...   \n",
       "6495  ...              128                  64                  4   \n",
       "6496  ...              128                  64                  4   \n",
       "6497  ...              128                  64                  4   \n",
       "6498  ...              128                  64                  4   \n",
       "6499  ...              128                  64                  4   \n",
       "\n",
       "      optimizer._target_ optimizer._recursive_ optimizer.lr  \\\n",
       "0       torch.optim.Adam                 False      0.00010   \n",
       "1       torch.optim.Adam                 False      0.00005   \n",
       "2       torch.optim.Adam                 False      0.00010   \n",
       "3       torch.optim.Adam                 False      0.00005   \n",
       "4       torch.optim.Adam                 False      0.00005   \n",
       "...                  ...                   ...          ...   \n",
       "6495    torch.optim.Adam                 False      0.00005   \n",
       "6496    torch.optim.Adam                 False      0.00005   \n",
       "6497    torch.optim.Adam                 False      0.00005   \n",
       "6498    torch.optim.Adam                 False      0.00005   \n",
       "6499    torch.optim.Adam                 False      0.00010   \n",
       "\n",
       "      optimizer.weight_decay  optimizer.theta_lr  \\\n",
       "0                    0.00001              0.0010   \n",
       "1                    0.00005              0.0001   \n",
       "2                    0.00001              0.0001   \n",
       "3                    0.00005              0.0010   \n",
       "4                    0.00005              0.0005   \n",
       "...                      ...                 ...   \n",
       "6495                 0.00005              0.0001   \n",
       "6496                 0.00005              0.0005   \n",
       "6497                 0.00005              0.0001   \n",
       "6498                 0.00005              0.0001   \n",
       "6499                 0.00005              0.0001   \n",
       "\n",
       "                                                run_dir  \\\n",
       "0     /home/knowledge-tracing/outputs/ktst_benchmark...   \n",
       "1     /home/knowledge-tracing/outputs/ktst_benchmark...   \n",
       "2     /home/knowledge-tracing/outputs/ktst_benchmark...   \n",
       "3     /home/knowledge-tracing/outputs/ktst_benchmark...   \n",
       "4     /home/knowledge-tracing/outputs/ktst_benchmark...   \n",
       "...                                                 ...   \n",
       "6495  /home/knowledge-tracing/outputs/ktst_benchmark...   \n",
       "6496  /home/knowledge-tracing/outputs/ktst_benchmark...   \n",
       "6497  /home/knowledge-tracing/outputs/ktst_benchmark...   \n",
       "6498  /home/knowledge-tracing/outputs/ktst_benchmark...   \n",
       "6499  /home/knowledge-tracing/outputs/ktst_benchmark...   \n",
       "\n",
       "                                            outputs_dir  \n",
       "0     /home/knowledge-tracing/outputs/ktst_benchmark...  \n",
       "1     /home/knowledge-tracing/outputs/ktst_benchmark...  \n",
       "2     /home/knowledge-tracing/outputs/ktst_benchmark...  \n",
       "3     /home/knowledge-tracing/outputs/ktst_benchmark...  \n",
       "4     /home/knowledge-tracing/outputs/ktst_benchmark...  \n",
       "...                                                 ...  \n",
       "6495  /home/knowledge-tracing/outputs/ktst_benchmark...  \n",
       "6496  /home/knowledge-tracing/outputs/ktst_benchmark...  \n",
       "6497  /home/knowledge-tracing/outputs/ktst_benchmark...  \n",
       "6498  /home/knowledge-tracing/outputs/ktst_benchmark...  \n",
       "6499  /home/knowledge-tracing/outputs/ktst_benchmark...  \n",
       "\n",
       "[6500 rows x 53 columns]"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Preview the per-run table instead of rendering the whole frame; show its size separately.\n",
    "display(df_sweep.shape)\n",
    "df_sweep.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead tr th {\n",
       "        text-align: left;\n",
       "    }\n",
       "\n",
       "    .dataframe thead tr:last-of-type th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th colspan=\"4\" halign=\"left\">val_auc</th>\n",
       "      <th colspan=\"4\" halign=\"left\">val_accuracy</th>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th>max</th>\n",
       "      <th>mean</th>\n",
       "      <th>std</th>\n",
       "      <th>count</th>\n",
       "      <th>max</th>\n",
       "      <th>mean</th>\n",
       "      <th>std</th>\n",
       "      <th>count</th>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>data.dataset</th>\n",
       "      <th>data.val_fold_idx</th>\n",
       "      <th>data.format</th>\n",
       "      <th>model.aggregation</th>\n",
       "      <th>model.attn_variant</th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th rowspan=\"5\" valign=\"top\">algebra2005</th>\n",
       "      <th rowspan=\"3\" valign=\"top\">0</th>\n",
       "      <th>combinatorial_dense</th>\n",
       "      <th>q_mean_c</th>\n",
       "      <th>learnable_alibi_monotonic_q_k</th>\n",
       "      <td>0.848902</td>\n",
       "      <td>0.846287</td>\n",
       "      <td>0.002080</td>\n",
       "      <td>100</td>\n",
       "      <td>0.822141</td>\n",
       "      <td>0.820101</td>\n",
       "      <td>0.001336</td>\n",
       "      <td>100</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th rowspan=\"2\" valign=\"top\">set_dense</th>\n",
       "      <th>q_mean_c</th>\n",
       "      <th>learnable_alibi_monotonic_q_k</th>\n",
       "      <td>0.847878</td>\n",
       "      <td>0.845973</td>\n",
       "      <td>0.001793</td>\n",
       "      <td>100</td>\n",
       "      <td>0.821990</td>\n",
       "      <td>0.819990</td>\n",
       "      <td>0.001213</td>\n",
       "      <td>100</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>self_attn_all</th>\n",
       "      <th>learnable_alibi_monotonic_q_k</th>\n",
       "      <td>0.825472</td>\n",
       "      <td>0.817858</td>\n",
       "      <td>0.003874</td>\n",
       "      <td>100</td>\n",
       "      <td>0.812764</td>\n",
       "      <td>0.805193</td>\n",
       "      <td>0.003891</td>\n",
       "      <td>100</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th rowspan=\"2\" valign=\"top\">1</th>\n",
       "      <th>combinatorial_dense</th>\n",
       "      <th>q_mean_c</th>\n",
       "      <th>learnable_alibi_monotonic_q_k</th>\n",
       "      <td>0.853817</td>\n",
       "      <td>0.851092</td>\n",
       "      <td>0.002586</td>\n",
       "      <td>100</td>\n",
       "      <td>0.824443</td>\n",
       "      <td>0.821866</td>\n",
       "      <td>0.001836</td>\n",
       "      <td>100</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>set_dense</th>\n",
       "      <th>q_mean_c</th>\n",
       "      <th>learnable_alibi_monotonic_q_k</th>\n",
       "      <td>0.854022</td>\n",
       "      <td>0.851175</td>\n",
       "      <td>0.002402</td>\n",
       "      <td>100</td>\n",
       "      <td>0.824389</td>\n",
       "      <td>0.822133</td>\n",
       "      <td>0.001688</td>\n",
       "      <td>100</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>...</th>\n",
       "      <th>...</th>\n",
       "      <th>...</th>\n",
       "      <th>...</th>\n",
       "      <th>...</th>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th rowspan=\"5\" valign=\"top\">statics2011</th>\n",
       "      <th>0</th>\n",
       "      <th>set_dense</th>\n",
       "      <th>q_mean_c</th>\n",
       "      <th>learnable_alibi_monotonic_q_k</th>\n",
       "      <td>0.831489</td>\n",
       "      <td>0.762583</td>\n",
       "      <td>0.075448</td>\n",
       "      <td>100</td>\n",
       "      <td>0.819969</td>\n",
       "      <td>0.802202</td>\n",
       "      <td>0.017756</td>\n",
       "      <td>100</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <th>set_dense</th>\n",
       "      <th>q_mean_c</th>\n",
       "      <th>learnable_alibi_monotonic_q_k</th>\n",
       "      <td>0.830879</td>\n",
       "      <td>0.750988</td>\n",
       "      <td>0.066099</td>\n",
       "      <td>100</td>\n",
       "      <td>0.807674</td>\n",
       "      <td>0.779541</td>\n",
       "      <td>0.021760</td>\n",
       "      <td>100</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <th>set_dense</th>\n",
       "      <th>q_mean_c</th>\n",
       "      <th>learnable_alibi_monotonic_q_k</th>\n",
       "      <td>0.828428</td>\n",
       "      <td>0.747583</td>\n",
       "      <td>0.069108</td>\n",
       "      <td>100</td>\n",
       "      <td>0.805793</td>\n",
       "      <td>0.779715</td>\n",
       "      <td>0.019553</td>\n",
       "      <td>100</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <th>set_dense</th>\n",
       "      <th>q_mean_c</th>\n",
       "      <th>learnable_alibi_monotonic_q_k</th>\n",
       "      <td>0.838067</td>\n",
       "      <td>0.769400</td>\n",
       "      <td>0.072530</td>\n",
       "      <td>100</td>\n",
       "      <td>0.810541</td>\n",
       "      <td>0.785942</td>\n",
       "      <td>0.022858</td>\n",
       "      <td>100</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <th>set_dense</th>\n",
       "      <th>q_mean_c</th>\n",
       "      <th>learnable_alibi_monotonic_q_k</th>\n",
       "      <td>0.826554</td>\n",
       "      <td>0.760816</td>\n",
       "      <td>0.065914</td>\n",
       "      <td>100</td>\n",
       "      <td>0.806354</td>\n",
       "      <td>0.786535</td>\n",
       "      <td>0.017515</td>\n",
       "      <td>100</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "<p>65 rows × 8 columns</p>\n",
       "</div>"
      ],
      "text/plain": [
       "                                                                                                     val_auc  \\\n",
       "                                                                                                         max   \n",
       "data.dataset data.val_fold_idx data.format         model.aggregation model.attn_variant                        \n",
       "algebra2005  0                 combinatorial_dense q_mean_c          learnable_alibi_monotonic_q_k  0.848902   \n",
       "                               set_dense           q_mean_c          learnable_alibi_monotonic_q_k  0.847878   \n",
       "                                                   self_attn_all     learnable_alibi_monotonic_q_k  0.825472   \n",
       "             1                 combinatorial_dense q_mean_c          learnable_alibi_monotonic_q_k  0.853817   \n",
       "                               set_dense           q_mean_c          learnable_alibi_monotonic_q_k  0.854022   \n",
       "...                                                                                                      ...   \n",
       "statics2011  0                 set_dense           q_mean_c          learnable_alibi_monotonic_q_k  0.831489   \n",
       "             1                 set_dense           q_mean_c          learnable_alibi_monotonic_q_k  0.830879   \n",
       "             2                 set_dense           q_mean_c          learnable_alibi_monotonic_q_k  0.828428   \n",
       "             3                 set_dense           q_mean_c          learnable_alibi_monotonic_q_k  0.838067   \n",
       "             4                 set_dense           q_mean_c          learnable_alibi_monotonic_q_k  0.826554   \n",
       "\n",
       "                                                                                                              \\\n",
       "                                                                                                        mean   \n",
       "data.dataset data.val_fold_idx data.format         model.aggregation model.attn_variant                        \n",
       "algebra2005  0                 combinatorial_dense q_mean_c          learnable_alibi_monotonic_q_k  0.846287   \n",
       "                               set_dense           q_mean_c          learnable_alibi_monotonic_q_k  0.845973   \n",
       "                                                   self_attn_all     learnable_alibi_monotonic_q_k  0.817858   \n",
       "             1                 combinatorial_dense q_mean_c          learnable_alibi_monotonic_q_k  0.851092   \n",
       "                               set_dense           q_mean_c          learnable_alibi_monotonic_q_k  0.851175   \n",
       "...                                                                                                      ...   \n",
       "statics2011  0                 set_dense           q_mean_c          learnable_alibi_monotonic_q_k  0.762583   \n",
       "             1                 set_dense           q_mean_c          learnable_alibi_monotonic_q_k  0.750988   \n",
       "             2                 set_dense           q_mean_c          learnable_alibi_monotonic_q_k  0.747583   \n",
       "             3                 set_dense           q_mean_c          learnable_alibi_monotonic_q_k  0.769400   \n",
       "             4                 set_dense           q_mean_c          learnable_alibi_monotonic_q_k  0.760816   \n",
       "\n",
       "                                                                                                              \\\n",
       "                                                                                                         std   \n",
       "data.dataset data.val_fold_idx data.format         model.aggregation model.attn_variant                        \n",
       "algebra2005  0                 combinatorial_dense q_mean_c          learnable_alibi_monotonic_q_k  0.002080   \n",
       "                               set_dense           q_mean_c          learnable_alibi_monotonic_q_k  0.001793   \n",
       "                                                   self_attn_all     learnable_alibi_monotonic_q_k  0.003874   \n",
       "             1                 combinatorial_dense q_mean_c          learnable_alibi_monotonic_q_k  0.002586   \n",
       "                               set_dense           q_mean_c          learnable_alibi_monotonic_q_k  0.002402   \n",
       "...                                                                                                      ...   \n",
       "statics2011  0                 set_dense           q_mean_c          learnable_alibi_monotonic_q_k  0.075448   \n",
       "             1                 set_dense           q_mean_c          learnable_alibi_monotonic_q_k  0.066099   \n",
       "             2                 set_dense           q_mean_c          learnable_alibi_monotonic_q_k  0.069108   \n",
       "             3                 set_dense           q_mean_c          learnable_alibi_monotonic_q_k  0.072530   \n",
       "             4                 set_dense           q_mean_c          learnable_alibi_monotonic_q_k  0.065914   \n",
       "\n",
       "                                                                                                          \\\n",
       "                                                                                                   count   \n",
       "data.dataset data.val_fold_idx data.format         model.aggregation model.attn_variant                    \n",
       "algebra2005  0                 combinatorial_dense q_mean_c          learnable_alibi_monotonic_q_k   100   \n",
       "                               set_dense           q_mean_c          learnable_alibi_monotonic_q_k   100   \n",
       "                                                   self_attn_all     learnable_alibi_monotonic_q_k   100   \n",
       "             1                 combinatorial_dense q_mean_c          learnable_alibi_monotonic_q_k   100   \n",
       "                               set_dense           q_mean_c          learnable_alibi_monotonic_q_k   100   \n",
       "...                                                                                                  ...   \n",
       "statics2011  0                 set_dense           q_mean_c          learnable_alibi_monotonic_q_k   100   \n",
       "             1                 set_dense           q_mean_c          learnable_alibi_monotonic_q_k   100   \n",
       "             2                 set_dense           q_mean_c          learnable_alibi_monotonic_q_k   100   \n",
       "             3                 set_dense           q_mean_c          learnable_alibi_monotonic_q_k   100   \n",
       "             4                 set_dense           q_mean_c          learnable_alibi_monotonic_q_k   100   \n",
       "\n",
       "                                                                                                   val_accuracy  \\\n",
       "                                                                                                            max   \n",
       "data.dataset data.val_fold_idx data.format         model.aggregation model.attn_variant                           \n",
       "algebra2005  0                 combinatorial_dense q_mean_c          learnable_alibi_monotonic_q_k     0.822141   \n",
       "                               set_dense           q_mean_c          learnable_alibi_monotonic_q_k     0.821990   \n",
       "                                                   self_attn_all     learnable_alibi_monotonic_q_k     0.812764   \n",
       "             1                 combinatorial_dense q_mean_c          learnable_alibi_monotonic_q_k     0.824443   \n",
       "                               set_dense           q_mean_c          learnable_alibi_monotonic_q_k     0.824389   \n",
       "...                                                                                                         ...   \n",
       "statics2011  0                 set_dense           q_mean_c          learnable_alibi_monotonic_q_k     0.819969   \n",
       "             1                 set_dense           q_mean_c          learnable_alibi_monotonic_q_k     0.807674   \n",
       "             2                 set_dense           q_mean_c          learnable_alibi_monotonic_q_k     0.805793   \n",
       "             3                 set_dense           q_mean_c          learnable_alibi_monotonic_q_k     0.810541   \n",
       "             4                 set_dense           q_mean_c          learnable_alibi_monotonic_q_k     0.806354   \n",
       "\n",
       "                                                                                                              \\\n",
       "                                                                                                        mean   \n",
       "data.dataset data.val_fold_idx data.format         model.aggregation model.attn_variant                        \n",
       "algebra2005  0                 combinatorial_dense q_mean_c          learnable_alibi_monotonic_q_k  0.820101   \n",
       "                               set_dense           q_mean_c          learnable_alibi_monotonic_q_k  0.819990   \n",
       "                                                   self_attn_all     learnable_alibi_monotonic_q_k  0.805193   \n",
       "             1                 combinatorial_dense q_mean_c          learnable_alibi_monotonic_q_k  0.821866   \n",
       "                               set_dense           q_mean_c          learnable_alibi_monotonic_q_k  0.822133   \n",
       "...                                                                                                      ...   \n",
       "statics2011  0                 set_dense           q_mean_c          learnable_alibi_monotonic_q_k  0.802202   \n",
       "             1                 set_dense           q_mean_c          learnable_alibi_monotonic_q_k  0.779541   \n",
       "             2                 set_dense           q_mean_c          learnable_alibi_monotonic_q_k  0.779715   \n",
       "             3                 set_dense           q_mean_c          learnable_alibi_monotonic_q_k  0.785942   \n",
       "             4                 set_dense           q_mean_c          learnable_alibi_monotonic_q_k  0.786535   \n",
       "\n",
       "                                                                                                              \\\n",
       "                                                                                                         std   \n",
       "data.dataset data.val_fold_idx data.format         model.aggregation model.attn_variant                        \n",
       "algebra2005  0                 combinatorial_dense q_mean_c          learnable_alibi_monotonic_q_k  0.001336   \n",
       "                               set_dense           q_mean_c          learnable_alibi_monotonic_q_k  0.001213   \n",
       "                                                   self_attn_all     learnable_alibi_monotonic_q_k  0.003891   \n",
       "             1                 combinatorial_dense q_mean_c          learnable_alibi_monotonic_q_k  0.001836   \n",
       "                               set_dense           q_mean_c          learnable_alibi_monotonic_q_k  0.001688   \n",
       "...                                                                                                      ...   \n",
       "statics2011  0                 set_dense           q_mean_c          learnable_alibi_monotonic_q_k  0.017756   \n",
       "             1                 set_dense           q_mean_c          learnable_alibi_monotonic_q_k  0.021760   \n",
       "             2                 set_dense           q_mean_c          learnable_alibi_monotonic_q_k  0.019553   \n",
       "             3                 set_dense           q_mean_c          learnable_alibi_monotonic_q_k  0.022858   \n",
       "             4                 set_dense           q_mean_c          learnable_alibi_monotonic_q_k  0.017515   \n",
       "\n",
       "                                                                                                          \n",
       "                                                                                                   count  \n",
       "data.dataset data.val_fold_idx data.format         model.aggregation model.attn_variant                   \n",
       "algebra2005  0                 combinatorial_dense q_mean_c          learnable_alibi_monotonic_q_k   100  \n",
       "                               set_dense           q_mean_c          learnable_alibi_monotonic_q_k   100  \n",
       "                                                   self_attn_all     learnable_alibi_monotonic_q_k   100  \n",
       "             1                 combinatorial_dense q_mean_c          learnable_alibi_monotonic_q_k   100  \n",
       "                               set_dense           q_mean_c          learnable_alibi_monotonic_q_k   100  \n",
       "...                                                                                                  ...  \n",
       "statics2011  0                 set_dense           q_mean_c          learnable_alibi_monotonic_q_k   100  \n",
       "             1                 set_dense           q_mean_c          learnable_alibi_monotonic_q_k   100  \n",
       "             2                 set_dense           q_mean_c          learnable_alibi_monotonic_q_k   100  \n",
       "             3                 set_dense           q_mean_c          learnable_alibi_monotonic_q_k   100  \n",
       "             4                 set_dense           q_mean_c          learnable_alibi_monotonic_q_k   100  \n",
       "\n",
       "[65 rows x 8 columns]"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# more statistics: max/mean/std/count of the validation metrics for each sweep configuration\n",
    "df_sweep_groupby = (\n",
    "    df_sweep\n",
    "    .groupby([\"data.dataset\", \"data.val_fold_idx\", \"data.format\", \"model.aggregation\", \"model.attn_variant\"])\n",
    "    [[\"val_auc\", \"val_accuracy\"]]\n",
    "    .agg([\"max\", \"mean\", \"std\", \"count\"])\n",
    ")\n",
    "df_sweep_groupby"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "/home/knowledge-tracing/outputs/ktst_benchmark_minus_assist2009_sweep/train/2024-09-06/07-25-30-590680/95,/home/knowledge-tracing/outputs/ktst_benchmark_minus_assist2009_sweep/train/2024-09-04/10-54-30-497305/93,/home/knowledge-tracing/outputs/ktst_benchmark_minus_assist2009_sweep/train/2024-09-09/22-11-24-257901/82,/home/knowledge-tracing/outputs/ktst_benchmark_minus_assist2009_sweep/train/2024-09-06/08-08-04-378054/83,/home/knowledge-tracing/outputs/ktst_benchmark_minus_assist2009_sweep/train/2024-09-04/10-54-30-496901/22,/home/knowledge-tracing/outputs/ktst_benchmark_minus_assist2009_sweep/train/2024-09-10/02-23-02-461252/31,/home/knowledge-tracing/outputs/ktst_benchmark_minus_assist2009_sweep/train/2024-09-06/10-11-46-543684/41,/home/knowledge-tracing/outputs/ktst_benchmark_minus_assist2009_sweep/train/2024-09-04/10-54-30-498205/67,/home/knowledge-tracing/outputs/ktst_benchmark_minus_assist2009_sweep/train/2024-09-10/02-32-40-954955/65,/home/knowledge-tracing/outputs/ktst_benchmark_minus_assist2009_sweep/train/2024-09-07/15-16-42-963481/95,/home/knowledge-tracing/outputs/ktst_benchmark_minus_assist2009_sweep/train/2024-09-04/10-54-30-496077/11,/home/knowledge-tracing/outputs/ktst_benchmark_minus_assist2009_sweep/train/2024-09-10/02-47-04-105536/93,/home/knowledge-tracing/outputs/ktst_benchmark_minus_assist2009_sweep/train/2024-09-07/21-51-10-777065/95,/home/knowledge-tracing/outputs/ktst_benchmark_minus_assist2009_sweep/train/2024-09-04/10-54-30-496663/22,/home/knowledge-tracing/outputs/ktst_benchmark_minus_assist2009_sweep/train/2024-09-10/04-04-32-223657/41,/home/knowledge-tracing/outputs/ktst_benchmark_minus_assist2009_sweep/train/2024-09-05/15-00-56-744959/17,/home/knowledge-tracing/outputs/ktst_benchmark_minus_assist2009_sweep/train/2024-09-09/10-44-02-381439/67,/home/knowledge-tracing/outputs/ktst_benchmark_minus_assist2009_sweep/train/2024-09-05/15-00-56-746376/76,/home/knowledge-tracing/outputs/ktst_benchmark_minus_assist2009_sweep/train/2024-09-0
9/10-44-02-381440/92,/home/knowledge-tracing/outputs/ktst_benchmark_minus_assist2009_sweep/train/2024-09-05/15-00-56-744782/92,/home/knowledge-tracing/outputs/ktst_benchmark_minus_assist2009_sweep/train/2024-09-09/10-44-02-381586/46,/home/knowledge-tracing/outputs/ktst_benchmark_minus_assist2009_sweep/train/2024-09-05/15-00-56-760623/72,/home/knowledge-tracing/outputs/ktst_benchmark_minus_assist2009_sweep/train/2024-09-09/10-44-02-381449/41,/home/knowledge-tracing/outputs/ktst_benchmark_minus_assist2009_sweep/train/2024-09-06/06-50-50-329688/74,/home/knowledge-tracing/outputs/ktst_benchmark_minus_assist2009_sweep/train/2024-09-09/10-44-02-382031/21,/home/knowledge-tracing/outputs/ktst_benchmark_minus_assist2009_sweep/train/2024-09-09/04-18-10-748000/78,/home/knowledge-tracing/outputs/ktst_benchmark_minus_assist2009_sweep/train/2024-09-09/05-00-35-741381/99,/home/knowledge-tracing/outputs/ktst_benchmark_minus_assist2009_sweep/train/2024-09-09/09-22-30-500838/62,/home/knowledge-tracing/outputs/ktst_benchmark_minus_assist2009_sweep/train/2024-09-09/10-02-12-331210/68,/home/knowledge-tracing/outputs/ktst_benchmark_minus_assist2009_sweep/train/2024-09-09/11-32-12-798752/64,/home/knowledge-tracing/outputs/ktst_benchmark_minus_assist2009_sweep/train/2024-09-04/10-54-30-496636/92,/home/knowledge-tracing/outputs/ktst_benchmark_minus_assist2009_sweep/train/2024-09-04/10-54-30-497083/58,/home/knowledge-tracing/outputs/ktst_benchmark_minus_assist2009_sweep/train/2024-09-04/10-54-30-497249/73,/home/knowledge-tracing/outputs/ktst_benchmark_minus_assist2009_sweep/train/2024-09-06/03-19-01-177666/49,/home/knowledge-tracing/outputs/ktst_benchmark_minus_assist2009_sweep/train/2024-09-06/03-39-00-145053/42,/home/knowledge-tracing/outputs/ktst_benchmark_minus_assist2009_sweep/train/2024-09-08/02-27-04-432905/52,/home/knowledge-tracing/outputs/ktst_benchmark_minus_assist2009_sweep/train/2024-09-10/09-22-09-819739/90,/home/knowledge-tracing/outputs/ktst_benchmark_minus_assist2009_sweep/t
rain/2024-09-10/09-05-02-834138/47,/home/knowledge-tracing/outputs/ktst_benchmark_minus_assist2009_sweep/train/2024-09-08/03-29-24-678510/95,/home/knowledge-tracing/outputs/ktst_benchmark_minus_assist2009_sweep/train/2024-09-10/10-03-37-422435/56,/home/knowledge-tracing/outputs/ktst_benchmark_minus_assist2009_sweep/train/2024-09-10/13-58-47-957805/12,/home/knowledge-tracing/outputs/ktst_benchmark_minus_assist2009_sweep/train/2024-09-09/00-09-58-012842/12,/home/knowledge-tracing/outputs/ktst_benchmark_minus_assist2009_sweep/train/2024-09-10/10-16-52-068777/96,/home/knowledge-tracing/outputs/ktst_benchmark_minus_assist2009_sweep/train/2024-09-10/14-21-28-936922/67,/home/knowledge-tracing/outputs/ktst_benchmark_minus_assist2009_sweep/train/2024-09-09/03-13-40-504794/12,/home/knowledge-tracing/outputs/ktst_benchmark_minus_assist2009_sweep/train/2024-09-11/06-41-30-722432/50,/home/knowledge-tracing/outputs/ktst_benchmark_minus_assist2009_sweep/train/2024-09-10/14-22-31-235278/44,/home/knowledge-tracing/outputs/ktst_benchmark_minus_assist2009_sweep/train/2024-09-09/03-51-41-252679/89,/home/knowledge-tracing/outputs/ktst_benchmark_minus_assist2009_sweep/train/2024-09-11/06-43-31-808500/95,/home/knowledge-tracing/outputs/ktst_benchmark_minus_assist2009_sweep/train/2024-09-10/15-38-35-459693/32,/home/knowledge-tracing/outputs/ktst_benchmark_minus_assist2009_sweep/train/2024-09-06/06-13-54-892263/53,/home/knowledge-tracing/outputs/ktst_benchmark_minus_assist2009_sweep/train/2024-09-06/06-46-32-335558/96,/home/knowledge-tracing/outputs/ktst_benchmark_minus_assist2009_sweep/train/2024-09-06/07-33-55-670682/16,/home/knowledge-tracing/outputs/ktst_benchmark_minus_assist2009_sweep/train/2024-09-07/19-28-25-823709/87,/home/knowledge-tracing/outputs/ktst_benchmark_minus_assist2009_sweep/train/2024-09-08/01-24-27-858197/57,/home/knowledge-tracing/outputs/ktst_benchmark_minus_assist2009_sweep/train/2024-09-09/19-45-30-465817/48,/home/knowledge-tracing/outputs/ktst_benchmark_minus_assi
st2009_sweep/train/2024-09-10/03-30-32-222161/94,/home/knowledge-tracing/outputs/ktst_benchmark_minus_assist2009_sweep/train/2024-09-10/08-38-13-350599/61,/home/knowledge-tracing/outputs/ktst_benchmark_minus_assist2009_sweep/train/2024-09-10/08-40-34-736440/93,/home/knowledge-tracing/outputs/ktst_benchmark_minus_assist2009_sweep/train/2024-09-10/09-21-15-756289/56,/home/knowledge-tracing/outputs/ktst_benchmark_minus_assist2009_sweep/train/2024-09-08/04-50-26-105758/57,/home/knowledge-tracing/outputs/ktst_benchmark_minus_assist2009_sweep/train/2024-09-08/16-31-56-240467/0,/home/knowledge-tracing/outputs/ktst_benchmark_minus_assist2009_sweep/train/2024-09-08/20-49-32-025688/83,/home/knowledge-tracing/outputs/ktst_benchmark_minus_assist2009_sweep/train/2024-09-08/23-39-50-586482/57,/home/knowledge-tracing/outputs/ktst_benchmark_minus_assist2009_sweep/train/2024-09-09/03-46-54-771138/21\n"
     ]
    }
   ],
   "source": [
    "# pick the run with the highest val_auc within each configuration\n",
    "max_idx = df_sweep.groupby([\"data.dataset\", \"data.val_fold_idx\", \"data.format\", \"model.aggregation\", \"model.attn_variant\"])[\"val_auc\"].idxmax()\n",
    "# idxmax returns index *labels*, so select with .loc; .iloc would treat them as row positions\n",
    "best_runs = df_sweep.loc[max_idx]\n",
    "print(\",\".join(best_runs[\"run_dir\"].values))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "# test: load the evaluation-run summaries from the eval output directory\n",
    "df_test = summarize_outputs_in_df(\n",
    "    outputs_dir=test_path,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead tr th {\n",
       "        text-align: left;\n",
       "    }\n",
       "\n",
       "    .dataframe thead tr:last-of-type th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th colspan=\"3\" halign=\"left\">best_val_metric</th>\n",
       "      <th colspan=\"3\" halign=\"left\">test_auc</th>\n",
       "      <th colspan=\"3\" halign=\"left\">test_accuracy</th>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th>mean</th>\n",
       "      <th>std</th>\n",
       "      <th>count</th>\n",
       "      <th>mean</th>\n",
       "      <th>std</th>\n",
       "      <th>count</th>\n",
       "      <th>mean</th>\n",
       "      <th>std</th>\n",
       "      <th>count</th>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>data.dataset</th>\n",
       "      <th>data.format</th>\n",
       "      <th>model.aggregation</th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th rowspan=\"3\" valign=\"top\">algebra2005</th>\n",
       "      <th>combinatorial_dense</th>\n",
       "      <th>q_mean_c</th>\n",
       "      <td>0.849403</td>\n",
       "      <td>0.002773</td>\n",
       "      <td>5</td>\n",
       "      <td>0.852903</td>\n",
       "      <td>0.000867</td>\n",
       "      <td>5</td>\n",
       "      <td>0.829075</td>\n",
       "      <td>0.000878</td>\n",
       "      <td>5</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th rowspan=\"2\" valign=\"top\">set_dense</th>\n",
       "      <th>q_mean_c</th>\n",
       "      <td>0.849260</td>\n",
       "      <td>0.002980</td>\n",
       "      <td>5</td>\n",
       "      <td>0.852205</td>\n",
       "      <td>0.000503</td>\n",
       "      <td>5</td>\n",
       "      <td>0.828660</td>\n",
       "      <td>0.000610</td>\n",
       "      <td>5</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>self_attn_all</th>\n",
       "      <td>0.825729</td>\n",
       "      <td>0.001907</td>\n",
       "      <td>5</td>\n",
       "      <td>0.828489</td>\n",
       "      <td>0.000891</td>\n",
       "      <td>5</td>\n",
       "      <td>0.816181</td>\n",
       "      <td>0.001270</td>\n",
       "      <td>5</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th rowspan=\"2\" valign=\"top\">assist2009</th>\n",
       "      <th>combinatorial_dense</th>\n",
       "      <th>q_mean_c</th>\n",
       "      <td>0.790494</td>\n",
       "      <td>0.005015</td>\n",
       "      <td>5</td>\n",
       "      <td>0.799305</td>\n",
       "      <td>0.001651</td>\n",
       "      <td>5</td>\n",
       "      <td>0.749643</td>\n",
       "      <td>0.001373</td>\n",
       "      <td>5</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>set_dense</th>\n",
       "      <th>self_attn_all</th>\n",
       "      <td>0.779177</td>\n",
       "      <td>0.005468</td>\n",
       "      <td>5</td>\n",
       "      <td>0.787790</td>\n",
       "      <td>0.002364</td>\n",
       "      <td>5</td>\n",
       "      <td>0.742304</td>\n",
       "      <td>0.001438</td>\n",
       "      <td>5</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>assist2015</th>\n",
       "      <th>set_dense</th>\n",
       "      <th>q_mean_c</th>\n",
       "      <td>0.736846</td>\n",
       "      <td>0.002746</td>\n",
       "      <td>5</td>\n",
       "      <td>0.731416</td>\n",
       "      <td>0.000346</td>\n",
       "      <td>5</td>\n",
       "      <td>0.752331</td>\n",
       "      <td>0.000194</td>\n",
       "      <td>5</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>bridge2algebra2006</th>\n",
       "      <th>set_dense</th>\n",
       "      <th>q_mean_c</th>\n",
       "      <td>0.823760</td>\n",
       "      <td>0.002862</td>\n",
       "      <td>5</td>\n",
       "      <td>0.826365</td>\n",
       "      <td>0.000401</td>\n",
       "      <td>5</td>\n",
       "      <td>0.860820</td>\n",
       "      <td>0.000593</td>\n",
       "      <td>5</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th rowspan=\"3\" valign=\"top\">ednet</th>\n",
       "      <th>combinatorial_dense</th>\n",
       "      <th>q_mean_c</th>\n",
       "      <td>0.731776</td>\n",
       "      <td>0.003501</td>\n",
       "      <td>5</td>\n",
       "      <td>0.735167</td>\n",
       "      <td>0.001152</td>\n",
       "      <td>5</td>\n",
       "      <td>0.713917</td>\n",
       "      <td>0.001737</td>\n",
       "      <td>5</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th rowspan=\"2\" valign=\"top\">set_dense</th>\n",
       "      <th>q_mean_c</th>\n",
       "      <td>0.736491</td>\n",
       "      <td>0.002852</td>\n",
       "      <td>5</td>\n",
       "      <td>0.739416</td>\n",
       "      <td>0.000262</td>\n",
       "      <td>5</td>\n",
       "      <td>0.715404</td>\n",
       "      <td>0.001243</td>\n",
       "      <td>5</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>self_attn_all</th>\n",
       "      <td>0.735763</td>\n",
       "      <td>0.002849</td>\n",
       "      <td>5</td>\n",
       "      <td>0.738962</td>\n",
       "      <td>0.000899</td>\n",
       "      <td>5</td>\n",
       "      <td>0.715431</td>\n",
       "      <td>0.000909</td>\n",
       "      <td>5</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>nips_task34</th>\n",
       "      <th>set_dense</th>\n",
       "      <th>q_mean_c</th>\n",
       "      <td>0.796326</td>\n",
       "      <td>0.001292</td>\n",
       "      <td>5</td>\n",
       "      <td>0.807055</td>\n",
       "      <td>0.000044</td>\n",
       "      <td>5</td>\n",
       "      <td>0.735552</td>\n",
       "      <td>0.000309</td>\n",
       "      <td>5</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>poj</th>\n",
       "      <th>set_dense</th>\n",
       "      <th>q_mean_c</th>\n",
       "      <td>0.680013</td>\n",
       "      <td>0.017118</td>\n",
       "      <td>5</td>\n",
       "      <td>0.634721</td>\n",
       "      <td>0.001272</td>\n",
       "      <td>5</td>\n",
       "      <td>0.656844</td>\n",
       "      <td>0.000860</td>\n",
       "      <td>5</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>statics2011</th>\n",
       "      <th>set_dense</th>\n",
       "      <th>q_mean_c</th>\n",
       "      <td>0.831083</td>\n",
       "      <td>0.004375</td>\n",
       "      <td>5</td>\n",
       "      <td>0.829128</td>\n",
       "      <td>0.000999</td>\n",
       "      <td>5</td>\n",
       "      <td>0.800194</td>\n",
       "      <td>0.001614</td>\n",
       "      <td>5</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "                                                         best_val_metric  \\\n",
       "                                                                    mean   \n",
       "data.dataset       data.format         model.aggregation                   \n",
       "algebra2005        combinatorial_dense q_mean_c                 0.849403   \n",
       "                   set_dense           q_mean_c                 0.849260   \n",
       "                                       self_attn_all            0.825729   \n",
       "assist2009         combinatorial_dense q_mean_c                 0.790494   \n",
       "                   set_dense           self_attn_all            0.779177   \n",
       "assist2015         set_dense           q_mean_c                 0.736846   \n",
       "bridge2algebra2006 set_dense           q_mean_c                 0.823760   \n",
       "ednet              combinatorial_dense q_mean_c                 0.731776   \n",
       "                   set_dense           q_mean_c                 0.736491   \n",
       "                                       self_attn_all            0.735763   \n",
       "nips_task34        set_dense           q_mean_c                 0.796326   \n",
       "poj                set_dense           q_mean_c                 0.680013   \n",
       "statics2011        set_dense           q_mean_c                 0.831083   \n",
       "\n",
       "                                                                          \\\n",
       "                                                               std count   \n",
       "data.dataset       data.format         model.aggregation                   \n",
       "algebra2005        combinatorial_dense q_mean_c           0.002773     5   \n",
       "                   set_dense           q_mean_c           0.002980     5   \n",
       "                                       self_attn_all      0.001907     5   \n",
       "assist2009         combinatorial_dense q_mean_c           0.005015     5   \n",
       "                   set_dense           self_attn_all      0.005468     5   \n",
       "assist2015         set_dense           q_mean_c           0.002746     5   \n",
       "bridge2algebra2006 set_dense           q_mean_c           0.002862     5   \n",
       "ednet              combinatorial_dense q_mean_c           0.003501     5   \n",
       "                   set_dense           q_mean_c           0.002852     5   \n",
       "                                       self_attn_all      0.002849     5   \n",
       "nips_task34        set_dense           q_mean_c           0.001292     5   \n",
       "poj                set_dense           q_mean_c           0.017118     5   \n",
       "statics2011        set_dense           q_mean_c           0.004375     5   \n",
       "\n",
       "                                                          test_auc            \\\n",
       "                                                              mean       std   \n",
       "data.dataset       data.format         model.aggregation                       \n",
       "algebra2005        combinatorial_dense q_mean_c           0.852903  0.000867   \n",
       "                   set_dense           q_mean_c           0.852205  0.000503   \n",
       "                                       self_attn_all      0.828489  0.000891   \n",
       "assist2009         combinatorial_dense q_mean_c           0.799305  0.001651   \n",
       "                   set_dense           self_attn_all      0.787790  0.002364   \n",
       "assist2015         set_dense           q_mean_c           0.731416  0.000346   \n",
       "bridge2algebra2006 set_dense           q_mean_c           0.826365  0.000401   \n",
       "ednet              combinatorial_dense q_mean_c           0.735167  0.001152   \n",
       "                   set_dense           q_mean_c           0.739416  0.000262   \n",
       "                                       self_attn_all      0.738962  0.000899   \n",
       "nips_task34        set_dense           q_mean_c           0.807055  0.000044   \n",
       "poj                set_dense           q_mean_c           0.634721  0.001272   \n",
       "statics2011        set_dense           q_mean_c           0.829128  0.000999   \n",
       "\n",
       "                                                               test_accuracy  \\\n",
       "                                                         count          mean   \n",
       "data.dataset       data.format         model.aggregation                       \n",
       "algebra2005        combinatorial_dense q_mean_c              5      0.829075   \n",
       "                   set_dense           q_mean_c              5      0.828660   \n",
       "                                       self_attn_all         5      0.816181   \n",
       "assist2009         combinatorial_dense q_mean_c              5      0.749643   \n",
       "                   set_dense           self_attn_all         5      0.742304   \n",
       "assist2015         set_dense           q_mean_c              5      0.752331   \n",
       "bridge2algebra2006 set_dense           q_mean_c              5      0.860820   \n",
       "ednet              combinatorial_dense q_mean_c              5      0.713917   \n",
       "                   set_dense           q_mean_c              5      0.715404   \n",
       "                                       self_attn_all         5      0.715431   \n",
       "nips_task34        set_dense           q_mean_c              5      0.735552   \n",
       "poj                set_dense           q_mean_c              5      0.656844   \n",
       "statics2011        set_dense           q_mean_c              5      0.800194   \n",
       "\n",
       "                                                                          \n",
       "                                                               std count  \n",
       "data.dataset       data.format         model.aggregation                  \n",
       "algebra2005        combinatorial_dense q_mean_c           0.000878     5  \n",
       "                   set_dense           q_mean_c           0.000610     5  \n",
       "                                       self_attn_all      0.001270     5  \n",
       "assist2009         combinatorial_dense q_mean_c           0.001373     5  \n",
       "                   set_dense           self_attn_all      0.001438     5  \n",
       "assist2015         set_dense           q_mean_c           0.000194     5  \n",
       "bridge2algebra2006 set_dense           q_mean_c           0.000593     5  \n",
       "ednet              combinatorial_dense q_mean_c           0.001737     5  \n",
       "                   set_dense           q_mean_c           0.001243     5  \n",
       "                                       self_attn_all      0.000909     5  \n",
       "nips_task34        set_dense           q_mean_c           0.000309     5  \n",
       "poj                set_dense           q_mean_c           0.000860     5  \n",
       "statics2011        set_dense           q_mean_c           0.001614     5  "
      ]
     },
     "execution_count": 12,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# aggregate the selected val metric and the test metrics per (dataset, format, aggregation)\n",
    "df_test_results = (\n",
    "    df_test\n",
    "    .groupby([\"data.dataset\", \"data.format\", \"model.aggregation\"])\n",
    "    [[\"best_val_metric\", \"test_auc\", \"test_accuracy\"]]\n",
    "    .agg([\"mean\", \"std\", \"count\"])\n",
    ")\n",
    "df_test_results"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "dev",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
