{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "50d76b8d-75c1-4ca9-94bd-47bad003d51f",
   "metadata": {},
   "outputs": [],
   "source": [
    "import learn2learn as l2l\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "import random\n",
    "import torch\n",
    "from torch import nn\n",
    "import pickle\n",
    "import cv2\n",
    "from xmeta.utils.sift import SiftFeature\n",
    "from importlib import reload \n",
    "from xmeta.utils.seed import set_seed\n",
    "from xmeta.utils.visualize import show_task, show_predictions\n",
    "from xmeta.utils.data import degrade_images, get_tasksets\n",
    "from xmeta.utils.experiment import draw_df_row\n",
    "from xmeta.utils.visualize import plot_eigenvalues\n",
    "from xmeta.maml.maml import xfast_adapt, explain_test_performance, explan_adaptation, setup_experiment\n",
    "from xmeta.maml.degraded_task import loop_alpha, loop_ratio\n",
    "import pickle\n",
    "from tqdm import tqdm\n",
    "import pandas as pd\n",
    "import dill\n",
    "import os\n",
    "import pandas as pd\n",
    "from torchsummary import summary\n",
    "import pandas as pd\n",
    "\n",
    "# NOTE(review): `seed` is defined and `set_seed` is imported above, but no visible cell\n",
    "# calls set_seed(seed) - confirm whether setup_experiment seeds internally.\n",
    "seed = 42\n",
    "\n",
    "# Loss passed to explain_test_performance in the cells below.\n",
    "loss = nn.CrossEntropyLoss(reduction='mean')\n",
    "# NOTE(review): not referenced by the visible cells; get_preprocess builds its own device.\n",
    "device = torch.device('cuda')\n",
    "\n",
    "def get_preprocess(feature, cuda=False):\n",
    "    if cuda:\n",
    "        _device = torch.device('cuda')\n",
    "    else:\n",
    "        _device = torch.device('cpu')\n",
    "    def prep(x, cuda=False):\n",
    "        d, lbl = x\n",
    "        x = [feature(d).to(_device), lbl.to(_device)]\n",
    "        return x\n",
    "    return prep"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "73549378-a2d0-488c-9ba7-fb16590b52ec",
   "metadata": {},
   "outputs": [],
   "source": [
    "def _series_correlation(s0, s1, plot_data=False):\n",
    "    v_0 = pd.Series(s0['train_task_score'], index=s0['train_task_idx']).sort_index()\n",
    "    v_1 = pd.Series(s1['train_task_score'], index=s1['train_task_idx']).sort_index()\n",
    "    if plot_data:\n",
    "        plt.scatter(v_0.values, v_1.values)\n",
    "    return v_0.corr(v_1)\n",
    "\n",
    "def df_correlation(df0, df1):\n",
    "    assert len(df0) == len(df1)\n",
    "    df0 = df0.sort_values(['test_task_idx'])\n",
    "    df1 = df1.sort_values(['test_task_idx'])\n",
    "    \n",
    "    correlations = []\n",
    "    for ii in range(len(df0)):\n",
    "       correlations.append(\n",
    "           _series_correlation(df0.iloc[ii], df1.iloc[ii]))\n",
    "\n",
    "    df = pd.DataFrame({\n",
    "           'test_task_idx': df0['test_task_idx'],\n",
    "           'score_0': df0['train_task_score'],\n",
    "           'score_1': df1['train_task_score'],\n",
    "           'corr': correlations})\n",
    "    return df\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e582955a-1626-4a1c-b893-762ad6e5d95f",
   "metadata": {},
   "source": [
    "## 128 tasks"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "id": "593d79b9-6202-4509-86e5-ff4e3ddf3207",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "found ../examples/maml_l2l/cache/2024-0912-191331/\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "loaded ../examples/maml_l2l/cache/2024-0912-191331/expl_maml_k36_layer32_tasks128_mbs32_ways5_shots5_1000.pkl\n",
      "Computing pseudo-inverse of hessian\n",
      "done  (shape (1413, 1413))\n",
      "set src_param_matrix (shape (1413, 128))\n",
      "done\n",
      "src_param_matrix: (1413, 128)\n"
     ]
    }
   ],
   "source": [
    "# Configuration of the cached MAML run; the values mirror the explainer file name\n",
    "# (k36, tasks128, ways5, shots5).\n",
    "# NOTE(review): `k` is not passed to setup_experiment here - confirm it is still needed.\n",
    "k = 36\n",
    "fft_crop_size = 6\n",
    "ways = 5\n",
    "shots = 5\n",
    "num_tasks = 128\n",
    "num_test_tasks = 128\n",
    "experiment_dir = '../examples/maml_l2l/cache/2024-0912-191331/'\n",
    "# Explainer whose results are stored as df_exact in the cells below.\n",
    "explainer_path = experiment_dir + 'expl_maml_k36_layer32_tasks128_mbs32_ways5_shots5_1000.pkl'\n",
    "tasks_train, tasks_test, explainer, maml, feature, impurity_dict =\\\n",
    "    setup_experiment(ways=ways, shots=shots, num_tasks=num_tasks, \n",
    "    experiment_dir=experiment_dir, \n",
    "    explainer_path=explainer_path, \n",
    "    fft_crop_size=fft_crop_size,\n",
    "    dataset='omniglot'\n",
    "    )\n",
    "# Preprocessor that featurizes batches and moves them to CUDA.\n",
    "_preprocess = get_preprocess(feature, cuda=True)\n",
    "# Presumably the source of the 'set src_param_matrix' log line; the OPA setup cell\n",
    "# skips this call and its log lacks that line.\n",
    "explainer.set_src_generalized_matrix(n_positive_ev=128)\n",
    "print('done')\n",
    "print('src_param_matrix:', explainer.src_param_matrix.shape)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5a8f5e62-29ea-461c-bfd3-1628edf57b06",
   "metadata": {},
   "source": [
    "### Test with the training tasks (the train tasks also serve as test tasks)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "dbfba28e-0534-4c6b-99b1-feac51508135",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 128/128 [00:43<00:00,  2.94it/s]\n"
     ]
    }
   ],
   "source": [
    "# Use tasks_train as both the source and the test tasks; the result has one row per\n",
    "# test task with its train_task_idx / train_task_score lists (see the table below).\n",
    "df_exact = explain_test_performance(explainer, tasks_train, tasks_train,\n",
    "                                  preprocess=_preprocess,\n",
    "                                  loss=loss,\n",
    "                                  shots=shots,\n",
    "                                  ways=ways,\n",
    "                                  num_train_task=num_tasks,\n",
    "                                  num_test_task=num_test_tasks\n",
    "                                  )\n",
    "\n",
    "# Cache the result in the experiment directory.\n",
    "pkl_path = os.path.join(experiment_dir, 'df_exact.pkl')\n",
    "with open(pkl_path, 'wb') as f:\n",
    "    pickle.dump(df_exact, f)\n",
    "\n",
    "# NOTE(review): index_dict is not used by any visible cell.\n",
    "index_dict={'train_noise_tasks':[], 'train_shuffle_tasks': []}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "e352c86f-198f-4dd9-b359-1137a985bb89",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>test_task_idx</th>\n",
       "      <th>test_accuracy</th>\n",
       "      <th>test_error</th>\n",
       "      <th>adaptation_accuracy</th>\n",
       "      <th>adaptation_error</th>\n",
       "      <th>train_accuracy</th>\n",
       "      <th>train_error</th>\n",
       "      <th>zeroshot_accuracy</th>\n",
       "      <th>zeroshot_error</th>\n",
       "      <th>train_task_idx</th>\n",
       "      <th>train_task_score</th>\n",
       "      <th>self_rank</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>74</th>\n",
       "      <td>74</td>\n",
       "      <td>1.0</td>\n",
       "      <td>0.00047049933</td>\n",
       "      <td>1.0</td>\n",
       "      <td>0.0065798615</td>\n",
       "      <td>0.12</td>\n",
       "      <td>0.12</td>\n",
       "      <td>0.16</td>\n",
       "      <td>3.2498162</td>\n",
       "      <td>[100, 121, 48, 57, 9, 46, 77, 87, 66, 56, 122,...</td>\n",
       "      <td>[0.0001010524938465096, 8.810789586277679e-05,...</td>\n",
       "      <td>41</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>113</th>\n",
       "      <td>113</td>\n",
       "      <td>1.0</td>\n",
       "      <td>0.0023007998</td>\n",
       "      <td>1.0</td>\n",
       "      <td>0.0011258704</td>\n",
       "      <td>0.12</td>\n",
       "      <td>0.12</td>\n",
       "      <td>0.16</td>\n",
       "      <td>3.604068</td>\n",
       "      <td>[32, 45, 117, 6, 39, 86, 87, 122, 95, 83, 7, 9...</td>\n",
       "      <td>[0.00028579021454788744, 0.0002441526157781482...</td>\n",
       "      <td>38</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>109</th>\n",
       "      <td>109</td>\n",
       "      <td>1.0</td>\n",
       "      <td>0.0037004738</td>\n",
       "      <td>1.0</td>\n",
       "      <td>0.0028735157</td>\n",
       "      <td>0.39999998</td>\n",
       "      <td>0.39999998</td>\n",
       "      <td>0.39999998</td>\n",
       "      <td>2.2986786</td>\n",
       "      <td>[35, 84, 116, 57, 109, 32, 49, 27, 73, 17, 106...</td>\n",
       "      <td>[0.001666567288339138, 0.0012091314420104027, ...</td>\n",
       "      <td>4</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>65</th>\n",
       "      <td>65</td>\n",
       "      <td>1.0</td>\n",
       "      <td>0.0041467366</td>\n",
       "      <td>1.0</td>\n",
       "      <td>0.013692132</td>\n",
       "      <td>0.19999999</td>\n",
       "      <td>0.19999999</td>\n",
       "      <td>0.28</td>\n",
       "      <td>3.281961</td>\n",
       "      <td>[86, 124, 117, 36, 95, 65, 25, 32, 49, 17, 11,...</td>\n",
       "      <td>[0.0008983498555608094, 0.000656173680908978, ...</td>\n",
       "      <td>5</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>69</th>\n",
       "      <td>69</td>\n",
       "      <td>1.0</td>\n",
       "      <td>0.0041564885</td>\n",
       "      <td>0.96</td>\n",
       "      <td>0.080072835</td>\n",
       "      <td>0.08</td>\n",
       "      <td>0.08</td>\n",
       "      <td>0.08</td>\n",
       "      <td>3.1572807</td>\n",
       "      <td>[87, 31, 43, 90, 123, 69, 51, 103, 48, 32, 46,...</td>\n",
       "      <td>[0.0007710354984737933, 0.0006775240763090551,...</td>\n",
       "      <td>5</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>...</th>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>1</td>\n",
       "      <td>0.96</td>\n",
       "      <td>0.15310507</td>\n",
       "      <td>0.91999996</td>\n",
       "      <td>0.20023714</td>\n",
       "      <td>0.19999999</td>\n",
       "      <td>0.19999999</td>\n",
       "      <td>0.28</td>\n",
       "      <td>3.572054</td>\n",
       "      <td>[1, 73, 80, 14, 124, 115, 51, 86, 32, 29, 7, 1...</td>\n",
       "      <td>[0.07570473104715347, 0.011496552266180515, 0....</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>85</th>\n",
       "      <td>85</td>\n",
       "      <td>0.88</td>\n",
       "      <td>0.17959814</td>\n",
       "      <td>0.91999996</td>\n",
       "      <td>0.09479967</td>\n",
       "      <td>0.35999998</td>\n",
       "      <td>0.35999998</td>\n",
       "      <td>0.16</td>\n",
       "      <td>2.6793048</td>\n",
       "      <td>[85, 3, 115, 73, 20, 7, 12, 114, 84, 91, 97, 5...</td>\n",
       "      <td>[0.06142376735806465, 0.00852006021887064, 0.0...</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>25</th>\n",
       "      <td>25</td>\n",
       "      <td>0.96</td>\n",
       "      <td>0.1825858</td>\n",
       "      <td>1.0</td>\n",
       "      <td>0.0013020827</td>\n",
       "      <td>0.08</td>\n",
       "      <td>0.08</td>\n",
       "      <td>0.19999999</td>\n",
       "      <td>2.3373556</td>\n",
       "      <td>[25, 83, 15, 12, 37, 32, 127, 35, 51, 49, 123,...</td>\n",
       "      <td>[0.04428214952349663, 0.013538037426769733, 0....</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>19</th>\n",
       "      <td>19</td>\n",
       "      <td>0.96</td>\n",
       "      <td>0.19646795</td>\n",
       "      <td>0.96</td>\n",
       "      <td>0.10896608</td>\n",
       "      <td>0.08</td>\n",
       "      <td>0.08</td>\n",
       "      <td>0.12</td>\n",
       "      <td>3.1462817</td>\n",
       "      <td>[19, 127, 87, 28, 31, 62, 72, 40, 66, 4, 77, 7...</td>\n",
       "      <td>[0.09000102430582047, 0.015649262815713882, 0....</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>73</th>\n",
       "      <td>73</td>\n",
       "      <td>0.96</td>\n",
       "      <td>0.22557537</td>\n",
       "      <td>0.96</td>\n",
       "      <td>0.23264027</td>\n",
       "      <td>0.16</td>\n",
       "      <td>0.16</td>\n",
       "      <td>0.24</td>\n",
       "      <td>2.5233762</td>\n",
       "      <td>[73, 49, 1, 0, 87, 14, 120, 85, 37, 29, 95, 32...</td>\n",
       "      <td>[0.06637593358755112, 0.011804599314928055, 0....</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "<p>128 rows × 12 columns</p>\n",
       "</div>"
      ],
      "text/plain": [
       "     test_task_idx test_accuracy     test_error adaptation_accuracy  \\\n",
       "74              74           1.0  0.00047049933                 1.0   \n",
       "113            113           1.0   0.0023007998                 1.0   \n",
       "109            109           1.0   0.0037004738                 1.0   \n",
       "65              65           1.0   0.0041467366                 1.0   \n",
       "69              69           1.0   0.0041564885                0.96   \n",
       "..             ...           ...            ...                 ...   \n",
       "1                1          0.96     0.15310507          0.91999996   \n",
       "85              85          0.88     0.17959814          0.91999996   \n",
       "25              25          0.96      0.1825858                 1.0   \n",
       "19              19          0.96     0.19646795                0.96   \n",
       "73              73          0.96     0.22557537                0.96   \n",
       "\n",
       "    adaptation_error train_accuracy train_error zeroshot_accuracy  \\\n",
       "74      0.0065798615           0.12        0.12              0.16   \n",
       "113     0.0011258704           0.12        0.12              0.16   \n",
       "109     0.0028735157     0.39999998  0.39999998        0.39999998   \n",
       "65       0.013692132     0.19999999  0.19999999              0.28   \n",
       "69       0.080072835           0.08        0.08              0.08   \n",
       "..               ...            ...         ...               ...   \n",
       "1         0.20023714     0.19999999  0.19999999              0.28   \n",
       "85        0.09479967     0.35999998  0.35999998              0.16   \n",
       "25      0.0013020827           0.08        0.08        0.19999999   \n",
       "19        0.10896608           0.08        0.08              0.12   \n",
       "73        0.23264027           0.16        0.16              0.24   \n",
       "\n",
       "    zeroshot_error                                     train_task_idx  \\\n",
       "74       3.2498162  [100, 121, 48, 57, 9, 46, 77, 87, 66, 56, 122,...   \n",
       "113       3.604068  [32, 45, 117, 6, 39, 86, 87, 122, 95, 83, 7, 9...   \n",
       "109      2.2986786  [35, 84, 116, 57, 109, 32, 49, 27, 73, 17, 106...   \n",
       "65        3.281961  [86, 124, 117, 36, 95, 65, 25, 32, 49, 17, 11,...   \n",
       "69       3.1572807  [87, 31, 43, 90, 123, 69, 51, 103, 48, 32, 46,...   \n",
       "..             ...                                                ...   \n",
       "1         3.572054  [1, 73, 80, 14, 124, 115, 51, 86, 32, 29, 7, 1...   \n",
       "85       2.6793048  [85, 3, 115, 73, 20, 7, 12, 114, 84, 91, 97, 5...   \n",
       "25       2.3373556  [25, 83, 15, 12, 37, 32, 127, 35, 51, 49, 123,...   \n",
       "19       3.1462817  [19, 127, 87, 28, 31, 62, 72, 40, 66, 4, 77, 7...   \n",
       "73       2.5233762  [73, 49, 1, 0, 87, 14, 120, 85, 37, 29, 95, 32...   \n",
       "\n",
       "                                      train_task_score  self_rank  \n",
       "74   [0.0001010524938465096, 8.810789586277679e-05,...         41  \n",
       "113  [0.00028579021454788744, 0.0002441526157781482...         38  \n",
       "109  [0.001666567288339138, 0.0012091314420104027, ...          4  \n",
       "65   [0.0008983498555608094, 0.000656173680908978, ...          5  \n",
       "69   [0.0007710354984737933, 0.0006775240763090551,...          5  \n",
       "..                                                 ...        ...  \n",
       "1    [0.07570473104715347, 0.011496552266180515, 0....          0  \n",
       "85   [0.06142376735806465, 0.00852006021887064, 0.0...          0  \n",
       "25   [0.04428214952349663, 0.013538037426769733, 0....          0  \n",
       "19   [0.09000102430582047, 0.015649262815713882, 0....          0  \n",
       "73   [0.06637593358755112, 0.011804599314928055, 0....          0  \n",
       "\n",
       "[128 rows x 12 columns]"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Position of each test task inside its own train_task_idx list (0 = listed first).\n",
    "df_exact['self_rank'] = [\n",
    "    ranked_idx.index(test_idx)\n",
    "    for ranked_idx, test_idx in zip(df_exact['train_task_idx'], df_exact['test_task_idx'])]\n",
    "df_exact.sort_values(['test_error'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "995f2ba5-7cb2-41b7-b11d-ea7290e6fd67",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "found ../examples/maml_l2l/cache/2024-0912-191331/\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "loaded ../examples/maml_l2l/cache/2024-0912-191331/opa_nh512_expl_maml_k36_layer32_tasks128_mbs32_ways5_shots5_1000.pkl\n",
      "done\n",
      "src_param_matrix: (1413, 128)\n"
     ]
    }
   ],
   "source": [
    "# Same configuration as the first setup cell, but loading the OPA explainer file.\n",
    "k = 36\n",
    "fft_crop_size = 6\n",
    "ways = 5\n",
    "shots = 5\n",
    "num_tasks = 128\n",
    "num_test_tasks = 128\n",
    "experiment_dir = '../examples/maml_l2l/cache/2024-0912-191331/'\n",
    "explainer_path = experiment_dir + 'opa_nh512_expl_maml_k36_layer32_tasks128_mbs32_ways5_shots5_1000.pkl'\n",
    "# NOTE: this rebinds tasks_train, tasks_test, maml, feature, impurity_dict and _preprocess.\n",
    "tasks_train, tasks_test, explainer_opa, maml, feature, impurity_dict =\\\n",
    "    setup_experiment(ways=ways, shots=shots, num_tasks=num_tasks, \n",
    "    experiment_dir=experiment_dir, \n",
    "    explainer_path=explainer_path, \n",
    "    fft_crop_size=fft_crop_size,\n",
    "    dataset='omniglot'\n",
    "    )\n",
    "_preprocess = get_preprocess(feature, cuda=True)\n",
    "print('done')\n",
    "# Report explainer_opa's own matrix (not the exact `explainer`'s). This cell does not call\n",
    "# set_src_generalized_matrix, so guard against the attribute being missing or None.\n",
    "opa_src_param_matrix = getattr(explainer_opa, 'src_param_matrix', None)\n",
    "print('src_param_matrix:', None if opa_src_param_matrix is None else opa_src_param_matrix.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "df58f498-f401-40bc-ad0a-8c3fe671cb55",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 128/128 [00:42<00:00,  2.99it/s]\n"
     ]
    }
   ],
   "source": [
    "# Same evaluation as df_exact, but with the OPA explainer (tasks_train as both source\n",
    "# and test tasks).\n",
    "df_opa = explain_test_performance(explainer_opa, tasks_train, tasks_train,\n",
    "                                  preprocess=_preprocess,\n",
    "                                  loss=loss,\n",
    "                                  shots=shots,\n",
    "                                  ways=ways,\n",
    "                                  num_train_task=num_tasks,\n",
    "                                  num_test_task=num_test_tasks\n",
    "                                  )\n",
    "\n",
    "# Cache the result in the experiment directory.\n",
    "pkl_path = os.path.join(experiment_dir, 'df_opa.pkl')\n",
    "with open(pkl_path, 'wb') as f:\n",
    "    pickle.dump(df_opa, f)\n",
    "\n",
    "# NOTE(review): index_dict is not used by any visible cell.\n",
    "index_dict={'train_noise_tasks':[], 'train_shuffle_tasks': []}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "7f61347c-e9b9-4e20-bd69-72fd88a96d08",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "<class 'numpy.ndarray'>\n",
      "<class 'xmeta.utils.opa.CrossEntropyHessian'>\n"
     ]
    }
   ],
   "source": [
    "# Compare how each explainer stores src_test_hessian (the output shows a numpy.ndarray\n",
    "# for `explainer` and an xmeta.utils.opa.CrossEntropyHessian for `explainer_opa`).\n",
    "print(type(explainer.src_test_hessian))\n",
    "print(type(explainer_opa.src_test_hessian))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 34,
   "id": "38a1e311-d5d6-407d-8712-82f542d8a4ee",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>test_task_idx</th>\n",
       "      <th>test_accuracy</th>\n",
       "      <th>test_error</th>\n",
       "      <th>adaptation_accuracy</th>\n",
       "      <th>adaptation_error</th>\n",
       "      <th>train_accuracy</th>\n",
       "      <th>train_error</th>\n",
       "      <th>zeroshot_accuracy</th>\n",
       "      <th>zeroshot_error</th>\n",
       "      <th>train_task_idx</th>\n",
       "      <th>train_task_score</th>\n",
       "      <th>self_rank</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>74</th>\n",
       "      <td>74</td>\n",
       "      <td>1.0</td>\n",
       "      <td>0.00047049933</td>\n",
       "      <td>1.0</td>\n",
       "      <td>0.0065798615</td>\n",
       "      <td>0.12</td>\n",
       "      <td>0.12</td>\n",
       "      <td>0.16</td>\n",
       "      <td>3.2498162</td>\n",
       "      <td>[29, 85, 48, 15, 40, 71, 121, 31, 19, 70, 5, 8...</td>\n",
       "      <td>[7.122623628674773e-07, 6.029400765328319e-07,...</td>\n",
       "      <td>25</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>113</th>\n",
       "      <td>113</td>\n",
       "      <td>1.0</td>\n",
       "      <td>0.0023007998</td>\n",
       "      <td>1.0</td>\n",
       "      <td>0.0011258704</td>\n",
       "      <td>0.12</td>\n",
       "      <td>0.12</td>\n",
       "      <td>0.16</td>\n",
       "      <td>3.604068</td>\n",
       "      <td>[117, 35, 56, 15, 3, 113, 87, 45, 31, 80, 83, ...</td>\n",
       "      <td>[4.0740278564044274e-06, 1.916644578159321e-06...</td>\n",
       "      <td>5</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>109</th>\n",
       "      <td>109</td>\n",
       "      <td>1.0</td>\n",
       "      <td>0.0037004738</td>\n",
       "      <td>1.0</td>\n",
       "      <td>0.0028735157</td>\n",
       "      <td>0.39999998</td>\n",
       "      <td>0.39999998</td>\n",
       "      <td>0.39999998</td>\n",
       "      <td>2.2986786</td>\n",
       "      <td>[109, 35, 88, 49, 38, 99, 40, 79, 124, 32, 91,...</td>\n",
       "      <td>[1.413117024640087e-05, 5.671791768691037e-06,...</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>65</th>\n",
       "      <td>65</td>\n",
       "      <td>1.0</td>\n",
       "      <td>0.0041467366</td>\n",
       "      <td>1.0</td>\n",
       "      <td>0.013692132</td>\n",
       "      <td>0.19999999</td>\n",
       "      <td>0.19999999</td>\n",
       "      <td>0.28</td>\n",
       "      <td>3.281961</td>\n",
       "      <td>[65, 87, 43, 120, 83, 84, 75, 1, 30, 117, 86, ...</td>\n",
       "      <td>[1.0569947335170582e-05, 6.356313406286063e-06...</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>69</th>\n",
       "      <td>69</td>\n",
       "      <td>1.0</td>\n",
       "      <td>0.0041564885</td>\n",
       "      <td>0.96</td>\n",
       "      <td>0.080072835</td>\n",
       "      <td>0.08</td>\n",
       "      <td>0.08</td>\n",
       "      <td>0.08</td>\n",
       "      <td>3.1572807</td>\n",
       "      <td>[69, 46, 31, 97, 87, 48, 29, 43, 40, 123, 90, ...</td>\n",
       "      <td>[8.947985406848602e-06, 7.138767614378594e-06,...</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>...</th>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>1</td>\n",
       "      <td>0.96</td>\n",
       "      <td>0.15310507</td>\n",
       "      <td>0.91999996</td>\n",
       "      <td>0.20023714</td>\n",
       "      <td>0.19999999</td>\n",
       "      <td>0.19999999</td>\n",
       "      <td>0.28</td>\n",
       "      <td>3.572054</td>\n",
       "      <td>[1, 70, 78, 46, 36, 83, 88, 8, 12, 49, 121, 66...</td>\n",
       "      <td>[0.0012244611280038953, 4.2393963667564094e-05...</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>85</th>\n",
       "      <td>85</td>\n",
       "      <td>0.88</td>\n",
       "      <td>0.17959814</td>\n",
       "      <td>0.91999996</td>\n",
       "      <td>0.09479967</td>\n",
       "      <td>0.35999998</td>\n",
       "      <td>0.35999998</td>\n",
       "      <td>0.16</td>\n",
       "      <td>2.6793048</td>\n",
       "      <td>[85, 35, 12, 114, 56, 15, 29, 43, 112, 83, 86,...</td>\n",
       "      <td>[0.001092283520847559, 4.8956415412249044e-05,...</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>25</th>\n",
       "      <td>25</td>\n",
       "      <td>0.96</td>\n",
       "      <td>0.1825858</td>\n",
       "      <td>1.0</td>\n",
       "      <td>0.0013020827</td>\n",
       "      <td>0.08</td>\n",
       "      <td>0.08</td>\n",
       "      <td>0.19999999</td>\n",
       "      <td>2.3373556</td>\n",
       "      <td>[25, 32, 80, 49, 86, 63, 124, 37, 71, 39, 53, ...</td>\n",
       "      <td>[0.0010653516510501504, 5.0167749577667564e-05...</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>19</th>\n",
       "      <td>19</td>\n",
       "      <td>0.96</td>\n",
       "      <td>0.19646795</td>\n",
       "      <td>0.96</td>\n",
       "      <td>0.10896608</td>\n",
       "      <td>0.08</td>\n",
       "      <td>0.08</td>\n",
       "      <td>0.12</td>\n",
       "      <td>3.1462817</td>\n",
       "      <td>[19, 83, 75, 88, 99, 39, 56, 70, 77, 49, 105, ...</td>\n",
       "      <td>[0.0013941762736067176, 4.528206409304403e-05,...</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>73</th>\n",
       "      <td>73</td>\n",
       "      <td>0.96</td>\n",
       "      <td>0.22557537</td>\n",
       "      <td>0.96</td>\n",
       "      <td>0.23264027</td>\n",
       "      <td>0.16</td>\n",
       "      <td>0.16</td>\n",
       "      <td>0.24</td>\n",
       "      <td>2.5233762</td>\n",
       "      <td>[73, 87, 36, 0, 78, 120, 97, 35, 84, 105, 15, ...</td>\n",
       "      <td>[0.0009942222386598587, 3.1063063943292946e-05...</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "<p>128 rows × 12 columns</p>\n",
       "</div>"
      ],
      "text/plain": [
       "     test_task_idx test_accuracy     test_error adaptation_accuracy  \\\n",
       "74              74           1.0  0.00047049933                 1.0   \n",
       "113            113           1.0   0.0023007998                 1.0   \n",
       "109            109           1.0   0.0037004738                 1.0   \n",
       "65              65           1.0   0.0041467366                 1.0   \n",
       "69              69           1.0   0.0041564885                0.96   \n",
       "..             ...           ...            ...                 ...   \n",
       "1                1          0.96     0.15310507          0.91999996   \n",
       "85              85          0.88     0.17959814          0.91999996   \n",
       "25              25          0.96      0.1825858                 1.0   \n",
       "19              19          0.96     0.19646795                0.96   \n",
       "73              73          0.96     0.22557537                0.96   \n",
       "\n",
       "    adaptation_error train_accuracy train_error zeroshot_accuracy  \\\n",
       "74      0.0065798615           0.12        0.12              0.16   \n",
       "113     0.0011258704           0.12        0.12              0.16   \n",
       "109     0.0028735157     0.39999998  0.39999998        0.39999998   \n",
       "65       0.013692132     0.19999999  0.19999999              0.28   \n",
       "69       0.080072835           0.08        0.08              0.08   \n",
       "..               ...            ...         ...               ...   \n",
       "1         0.20023714     0.19999999  0.19999999              0.28   \n",
       "85        0.09479967     0.35999998  0.35999998              0.16   \n",
       "25      0.0013020827           0.08        0.08        0.19999999   \n",
       "19        0.10896608           0.08        0.08              0.12   \n",
       "73        0.23264027           0.16        0.16              0.24   \n",
       "\n",
       "    zeroshot_error                                     train_task_idx  \\\n",
       "74       3.2498162  [29, 85, 48, 15, 40, 71, 121, 31, 19, 70, 5, 8...   \n",
       "113       3.604068  [117, 35, 56, 15, 3, 113, 87, 45, 31, 80, 83, ...   \n",
       "109      2.2986786  [109, 35, 88, 49, 38, 99, 40, 79, 124, 32, 91,...   \n",
       "65        3.281961  [65, 87, 43, 120, 83, 84, 75, 1, 30, 117, 86, ...   \n",
       "69       3.1572807  [69, 46, 31, 97, 87, 48, 29, 43, 40, 123, 90, ...   \n",
       "..             ...                                                ...   \n",
       "1         3.572054  [1, 70, 78, 46, 36, 83, 88, 8, 12, 49, 121, 66...   \n",
       "85       2.6793048  [85, 35, 12, 114, 56, 15, 29, 43, 112, 83, 86,...   \n",
       "25       2.3373556  [25, 32, 80, 49, 86, 63, 124, 37, 71, 39, 53, ...   \n",
       "19       3.1462817  [19, 83, 75, 88, 99, 39, 56, 70, 77, 49, 105, ...   \n",
       "73       2.5233762  [73, 87, 36, 0, 78, 120, 97, 35, 84, 105, 15, ...   \n",
       "\n",
       "                                      train_task_score  self_rank  \n",
       "74   [7.122623628674773e-07, 6.029400765328319e-07,...         25  \n",
       "113  [4.0740278564044274e-06, 1.916644578159321e-06...          5  \n",
       "109  [1.413117024640087e-05, 5.671791768691037e-06,...          0  \n",
       "65   [1.0569947335170582e-05, 6.356313406286063e-06...          0  \n",
       "69   [8.947985406848602e-06, 7.138767614378594e-06,...          0  \n",
       "..                                                 ...        ...  \n",
       "1    [0.0012244611280038953, 4.2393963667564094e-05...          0  \n",
       "85   [0.001092283520847559, 4.8956415412249044e-05,...          0  \n",
       "25   [0.0010653516510501504, 5.0167749577667564e-05...          0  \n",
       "19   [0.0013941762736067176, 4.528206409304403e-05,...          0  \n",
       "73   [0.0009942222386598587, 3.1063063943292946e-05...          0  \n",
       "\n",
       "[128 rows x 12 columns]"
      ]
     },
     "execution_count": 34,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "df_opa['self_rank'] = df_opa.apply(\n",
    "    lambda row: row['train_task_idx'].index(row['test_task_idx']), axis=1)\n",
    "df_opa.sort_values(['test_error'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "a50fa5ae-a569-4e9e-a8a5-0861a1264c88",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "pandas.core.series.Series"
      ]
     },
     "execution_count": 16,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "type(df_exact.iloc[0])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "id": "68d4e1d1-6872-4514-b136-6cec2536faee",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.6553646524467815"
      ]
     },
     "execution_count": 17,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "_series_correlation(df_exact.iloc[3], df_opa.iloc[3])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "26ed76c8-3490-4578-aac8-24af8bea09e6",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>test_task_idx</th>\n",
       "      <th>score_0</th>\n",
       "      <th>score_1</th>\n",
       "      <th>corr</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>0</td>\n",
       "      <td>[0.0584370382130146, 0.014456318691372871, 0.0...</td>\n",
       "      <td>[0.0005007253494113684, 2.7094183678855188e-05...</td>\n",
       "      <td>0.817668</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>1</td>\n",
       "      <td>[0.07570473104715347, 0.011496552266180515, 0....</td>\n",
       "      <td>[0.0012244611280038953, 4.2393963667564094e-05...</td>\n",
       "      <td>0.875442</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>2</td>\n",
       "      <td>[0.005113025661557913, 0.002789155812934041, 0...</td>\n",
       "      <td>[8.943345892475918e-05, 1.464942306483863e-05,...</td>\n",
       "      <td>0.414470</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>3</td>\n",
       "      <td>[0.013391907326877117, 0.008520047180354595, 0...</td>\n",
       "      <td>[0.00020589747873600572, 1.581244941917248e-05...</td>\n",
       "      <td>0.655365</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>4</td>\n",
       "      <td>[0.011177701875567436, 0.004326083231717348, 0...</td>\n",
       "      <td>[0.00013046336243860424, 1.6781919839559123e-0...</td>\n",
       "      <td>0.565447</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>...</th>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>123</th>\n",
       "      <td>123</td>\n",
       "      <td>[0.019583316519856453, 0.006185893435031176, 0...</td>\n",
       "      <td>[0.00034789618803188205, 2.4277813281514682e-0...</td>\n",
       "      <td>0.604136</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>124</th>\n",
       "      <td>124</td>\n",
       "      <td>[0.035794053226709366, 0.008923427201807499, 0...</td>\n",
       "      <td>[0.00044188444735482335, 2.444601705064997e-05...</td>\n",
       "      <td>0.801020</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>125</th>\n",
       "      <td>125</td>\n",
       "      <td>[0.00409677205607295, 0.0029801498167216778, 0...</td>\n",
       "      <td>[4.211665873299353e-05, 9.571598639013246e-06,...</td>\n",
       "      <td>0.421983</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>126</th>\n",
       "      <td>126</td>\n",
       "      <td>[0.01933872699737549, 0.006011705379933119, 0....</td>\n",
       "      <td>[0.00021357095101848245, 1.12523357529426e-05,...</td>\n",
       "      <td>0.706834</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>127</th>\n",
       "      <td>127</td>\n",
       "      <td>[0.06096941977739334, 0.01564934104681015, 0.0...</td>\n",
       "      <td>[0.0004481961368583143, 1.654908737691585e-05,...</td>\n",
       "      <td>0.911490</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "<p>128 rows × 4 columns</p>\n",
       "</div>"
      ],
      "text/plain": [
       "     test_task_idx                                            score_0  \\\n",
       "0                0  [0.0584370382130146, 0.014456318691372871, 0.0...   \n",
       "1                1  [0.07570473104715347, 0.011496552266180515, 0....   \n",
       "2                2  [0.005113025661557913, 0.002789155812934041, 0...   \n",
       "3                3  [0.013391907326877117, 0.008520047180354595, 0...   \n",
       "4                4  [0.011177701875567436, 0.004326083231717348, 0...   \n",
       "..             ...                                                ...   \n",
       "123            123  [0.019583316519856453, 0.006185893435031176, 0...   \n",
       "124            124  [0.035794053226709366, 0.008923427201807499, 0...   \n",
       "125            125  [0.00409677205607295, 0.0029801498167216778, 0...   \n",
       "126            126  [0.01933872699737549, 0.006011705379933119, 0....   \n",
       "127            127  [0.06096941977739334, 0.01564934104681015, 0.0...   \n",
       "\n",
       "                                               score_1      corr  \n",
       "0    [0.0005007253494113684, 2.7094183678855188e-05...  0.817668  \n",
       "1    [0.0012244611280038953, 4.2393963667564094e-05...  0.875442  \n",
       "2    [8.943345892475918e-05, 1.464942306483863e-05,...  0.414470  \n",
       "3    [0.00020589747873600572, 1.581244941917248e-05...  0.655365  \n",
       "4    [0.00013046336243860424, 1.6781919839559123e-0...  0.565447  \n",
       "..                                                 ...       ...  \n",
       "123  [0.00034789618803188205, 2.4277813281514682e-0...  0.604136  \n",
       "124  [0.00044188444735482335, 2.444601705064997e-05...  0.801020  \n",
       "125  [4.211665873299353e-05, 9.571598639013246e-06,...  0.421983  \n",
       "126  [0.00021357095101848245, 1.12523357529426e-05,...  0.706834  \n",
       "127  [0.0004481961368583143, 1.654908737691585e-05,...  0.911490  \n",
       "\n",
       "[128 rows x 4 columns]"
      ]
     },
     "execution_count": 16,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "df_correlation(df_exact, df_opa)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ee2797a9-662d-4a10-af2c-0bb05c7cd915",
   "metadata": {},
   "source": [
    "### test with test tasks"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "id": "af1b010d-7c5f-49d3-85a1-99c6828fb282",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 128/128 [00:42<00:00,  3.01it/s]\n"
     ]
    }
   ],
   "source": [
    "df_exact = explain_test_performance(explainer, tasks_train, tasks_test,\n",
    "                                  preprocess=_preprocess,\n",
    "                                  loss=loss,\n",
    "                                  shots=shots,\n",
    "                                  ways=ways,\n",
    "                                  num_train_task=num_tasks,\n",
    "                                  num_test_task=num_test_tasks\n",
    "                                  )\n",
    "\n",
    "pkl_path = os.path.join(experiment_dir, 'df_exact.pkl')\n",
    "with open(pkl_path, 'wb') as f:\n",
    "    pickle.dump(df_exact, f)\n",
    "\n",
    "index_dict={'train_noise_tasks':[], 'train_shuffle_tasks': []}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "id": "58c6f934-08b4-4357-ac86-32a3639cdcaa",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 128/128 [00:42<00:00,  2.98it/s]\n"
     ]
    }
   ],
   "source": [
    "df_opa = explain_test_performance(explainer_opa, tasks_train, tasks_test,\n",
    "                                  preprocess=_preprocess,\n",
    "                                  loss=loss,\n",
    "                                  shots=shots,\n",
    "                                  ways=ways,\n",
    "                                  num_train_task=num_tasks,\n",
    "                                  num_test_task=num_test_tasks\n",
    "                                  )\n",
    "\n",
    "pkl_path = os.path.join(experiment_dir, 'df_opa.pkl')\n",
    "with open(pkl_path, 'wb') as f:\n",
    "    pickle.dump(df_opa, f)\n",
    "\n",
    "index_dict={'train_noise_tasks':[], 'train_shuffle_tasks': []}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 46,
   "id": "1bd46b6b-c887-4d28-8a80-c04b5dd9d9d2",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>test_task_idx</th>\n",
       "      <th>test_accuracy</th>\n",
       "      <th>test_error</th>\n",
       "      <th>adaptation_accuracy</th>\n",
       "      <th>adaptation_error</th>\n",
       "      <th>train_accuracy</th>\n",
       "      <th>train_error</th>\n",
       "      <th>zeroshot_accuracy</th>\n",
       "      <th>zeroshot_error</th>\n",
       "      <th>train_task_idx</th>\n",
       "      <th>train_task_score</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>103</th>\n",
       "      <td>103</td>\n",
       "      <td>1.0</td>\n",
       "      <td>0.034549527</td>\n",
       "      <td>0.96</td>\n",
       "      <td>0.035123352</td>\n",
       "      <td>0.24</td>\n",
       "      <td>0.24</td>\n",
       "      <td>0.32</td>\n",
       "      <td>1.8267751</td>\n",
       "      <td>[57, 17, 49, 11, 0, 20, 37, 117, 39, 126, 66, ...</td>\n",
       "      <td>[0.025129446759819984, 0.02065139450132847, 0....</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>56</th>\n",
       "      <td>56</td>\n",
       "      <td>0.96</td>\n",
       "      <td>0.037129346</td>\n",
       "      <td>0.96</td>\n",
       "      <td>0.2604465</td>\n",
       "      <td>0.19999999</td>\n",
       "      <td>0.19999999</td>\n",
       "      <td>0.19999999</td>\n",
       "      <td>2.172749</td>\n",
       "      <td>[19, 59, 7, 68, 31, 87, 66, 14, 90, 100, 0, 92...</td>\n",
       "      <td>[0.007499896455556154, 0.004401837941259146, 0...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>112</th>\n",
       "      <td>112</td>\n",
       "      <td>0.96</td>\n",
       "      <td>0.042601608</td>\n",
       "      <td>0.96</td>\n",
       "      <td>0.18041918</td>\n",
       "      <td>0.16</td>\n",
       "      <td>0.16</td>\n",
       "      <td>0.08</td>\n",
       "      <td>4.658523</td>\n",
       "      <td>[122, 83, 40, 126, 87, 51, 114, 6, 48, 12, 112...</td>\n",
       "      <td>[0.010954011231660843, 0.0071158092468976974, ...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>127</th>\n",
       "      <td>127</td>\n",
       "      <td>0.96</td>\n",
       "      <td>0.049335122</td>\n",
       "      <td>1.0</td>\n",
       "      <td>0.022823265</td>\n",
       "      <td>0.12</td>\n",
       "      <td>0.12</td>\n",
       "      <td>0.08</td>\n",
       "      <td>3.2418723</td>\n",
       "      <td>[114, 122, 48, 35, 46, 115, 121, 102, 59, 90, ...</td>\n",
       "      <td>[0.009100834839046001, 0.008895132690668106, 0...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>2</td>\n",
       "      <td>1.0</td>\n",
       "      <td>0.05132146</td>\n",
       "      <td>1.0</td>\n",
       "      <td>0.01076157</td>\n",
       "      <td>0.35999998</td>\n",
       "      <td>0.35999998</td>\n",
       "      <td>0.28</td>\n",
       "      <td>2.4885516</td>\n",
       "      <td>[117, 122, 40, 35, 27, 6, 39, 17, 114, 19, 84,...</td>\n",
       "      <td>[0.02721152827143669, 0.023916112259030342, 0....</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>...</th>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>59</th>\n",
       "      <td>59</td>\n",
       "      <td>0.64</td>\n",
       "      <td>1.5289826</td>\n",
       "      <td>0.91999996</td>\n",
       "      <td>0.30159855</td>\n",
       "      <td>0.16</td>\n",
       "      <td>0.16</td>\n",
       "      <td>0.08</td>\n",
       "      <td>3.387437</td>\n",
       "      <td>[87, 1, 23, 29, 57, 122, 28, 115, 31, 56, 75, ...</td>\n",
       "      <td>[0.06670890748500824, 0.06600479781627655, 0.0...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>87</th>\n",
       "      <td>87</td>\n",
       "      <td>0.56</td>\n",
       "      <td>1.7605783</td>\n",
       "      <td>0.88</td>\n",
       "      <td>0.31418607</td>\n",
       "      <td>0.16</td>\n",
       "      <td>0.16</td>\n",
       "      <td>0.28</td>\n",
       "      <td>2.9345956</td>\n",
       "      <td>[80, 87, 73, 20, 102, 85, 70, 77, 100, 43, 121...</td>\n",
       "      <td>[0.08584997802972794, 0.08171568065881729, 0.0...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>23</th>\n",
       "      <td>23</td>\n",
       "      <td>0.64</td>\n",
       "      <td>2.012744</td>\n",
       "      <td>0.76</td>\n",
       "      <td>0.4485861</td>\n",
       "      <td>0.12</td>\n",
       "      <td>0.12</td>\n",
       "      <td>0.12</td>\n",
       "      <td>4.137039</td>\n",
       "      <td>[83, 124, 25, 15, 114, 8, 70, 122, 1, 88, 80, ...</td>\n",
       "      <td>[0.03902658447623253, 0.03546908497810364, 0.0...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>114</th>\n",
       "      <td>114</td>\n",
       "      <td>0.59999996</td>\n",
       "      <td>2.060565</td>\n",
       "      <td>0.88</td>\n",
       "      <td>0.53120136</td>\n",
       "      <td>0.16</td>\n",
       "      <td>0.16</td>\n",
       "      <td>0.16</td>\n",
       "      <td>3.3692026</td>\n",
       "      <td>[46, 31, 120, 73, 123, 19, 94, 43, 25, 117, 10...</td>\n",
       "      <td>[0.07754511386156082, 0.07292743772268295, 0.0...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>24</th>\n",
       "      <td>24</td>\n",
       "      <td>0.35999998</td>\n",
       "      <td>3.2999358</td>\n",
       "      <td>0.59999996</td>\n",
       "      <td>1.4049126</td>\n",
       "      <td>0.28</td>\n",
       "      <td>0.28</td>\n",
       "      <td>0.24</td>\n",
       "      <td>2.4452364</td>\n",
       "      <td>[83, 25, 31, 75, 20, 102, 48, 15, 59, 122, 46,...</td>\n",
       "      <td>[0.25515493750572205, 0.2135259211063385, 0.17...</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "<p>128 rows × 11 columns</p>\n",
       "</div>"
      ],
      "text/plain": [
       "     test_task_idx test_accuracy   test_error adaptation_accuracy  \\\n",
       "103            103           1.0  0.034549527                0.96   \n",
       "56              56          0.96  0.037129346                0.96   \n",
       "112            112          0.96  0.042601608                0.96   \n",
       "127            127          0.96  0.049335122                 1.0   \n",
       "2                2           1.0   0.05132146                 1.0   \n",
       "..             ...           ...          ...                 ...   \n",
       "59              59          0.64    1.5289826          0.91999996   \n",
       "87              87          0.56    1.7605783                0.88   \n",
       "23              23          0.64     2.012744                0.76   \n",
       "114            114    0.59999996     2.060565                0.88   \n",
       "24              24    0.35999998    3.2999358          0.59999996   \n",
       "\n",
       "    adaptation_error train_accuracy train_error zeroshot_accuracy  \\\n",
       "103      0.035123352           0.24        0.24              0.32   \n",
       "56         0.2604465     0.19999999  0.19999999        0.19999999   \n",
       "112       0.18041918           0.16        0.16              0.08   \n",
       "127      0.022823265           0.12        0.12              0.08   \n",
       "2         0.01076157     0.35999998  0.35999998              0.28   \n",
       "..               ...            ...         ...               ...   \n",
       "59        0.30159855           0.16        0.16              0.08   \n",
       "87        0.31418607           0.16        0.16              0.28   \n",
       "23         0.4485861           0.12        0.12              0.12   \n",
       "114       0.53120136           0.16        0.16              0.16   \n",
       "24         1.4049126           0.28        0.28              0.24   \n",
       "\n",
       "    zeroshot_error                                     train_task_idx  \\\n",
       "103      1.8267751  [57, 17, 49, 11, 0, 20, 37, 117, 39, 126, 66, ...   \n",
       "56        2.172749  [19, 59, 7, 68, 31, 87, 66, 14, 90, 100, 0, 92...   \n",
       "112       4.658523  [122, 83, 40, 126, 87, 51, 114, 6, 48, 12, 112...   \n",
       "127      3.2418723  [114, 122, 48, 35, 46, 115, 121, 102, 59, 90, ...   \n",
       "2        2.4885516  [117, 122, 40, 35, 27, 6, 39, 17, 114, 19, 84,...   \n",
       "..             ...                                                ...   \n",
       "59        3.387437  [87, 1, 23, 29, 57, 122, 28, 115, 31, 56, 75, ...   \n",
       "87       2.9345956  [80, 87, 73, 20, 102, 85, 70, 77, 100, 43, 121...   \n",
       "23        4.137039  [83, 124, 25, 15, 114, 8, 70, 122, 1, 88, 80, ...   \n",
       "114      3.3692026  [46, 31, 120, 73, 123, 19, 94, 43, 25, 117, 10...   \n",
       "24       2.4452364  [83, 25, 31, 75, 20, 102, 48, 15, 59, 122, 46,...   \n",
       "\n",
       "                                      train_task_score  \n",
       "103  [0.025129446759819984, 0.02065139450132847, 0....  \n",
       "56   [0.007499896455556154, 0.004401837941259146, 0...  \n",
       "112  [0.010954011231660843, 0.0071158092468976974, ...  \n",
       "127  [0.009100834839046001, 0.008895132690668106, 0...  \n",
       "2    [0.02721152827143669, 0.023916112259030342, 0....  \n",
       "..                                                 ...  \n",
       "59   [0.06670890748500824, 0.06600479781627655, 0.0...  \n",
       "87   [0.08584997802972794, 0.08171568065881729, 0.0...  \n",
       "23   [0.03902658447623253, 0.03546908497810364, 0.0...  \n",
       "114  [0.07754511386156082, 0.07292743772268295, 0.0...  \n",
       "24   [0.25515493750572205, 0.2135259211063385, 0.17...  \n",
       "\n",
       "[128 rows x 11 columns]"
      ]
     },
     "execution_count": 46,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "df_exact.sort_values(['test_error'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 47,
   "id": "18b0085b-8d1d-4540-94e0-abaa464bb610",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>test_task_idx</th>\n",
       "      <th>test_accuracy</th>\n",
       "      <th>test_error</th>\n",
       "      <th>adaptation_accuracy</th>\n",
       "      <th>adaptation_error</th>\n",
       "      <th>train_accuracy</th>\n",
       "      <th>train_error</th>\n",
       "      <th>zeroshot_accuracy</th>\n",
       "      <th>zeroshot_error</th>\n",
       "      <th>train_task_idx</th>\n",
       "      <th>train_task_score</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>103</th>\n",
       "      <td>103</td>\n",
       "      <td>1.0</td>\n",
       "      <td>0.034549527</td>\n",
       "      <td>0.96</td>\n",
       "      <td>0.035123352</td>\n",
       "      <td>0.24</td>\n",
       "      <td>0.24</td>\n",
       "      <td>0.32</td>\n",
       "      <td>1.8267751</td>\n",
       "      <td>[120, 7, 49, 102, 59, 126, 66, 40, 75, 51, 83,...</td>\n",
       "      <td>[0.0001478378108004108, 0.00012239433999639004...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>56</th>\n",
       "      <td>56</td>\n",
       "      <td>0.96</td>\n",
       "      <td>0.037129346</td>\n",
       "      <td>0.96</td>\n",
       "      <td>0.2604465</td>\n",
       "      <td>0.19999999</td>\n",
       "      <td>0.19999999</td>\n",
       "      <td>0.19999999</td>\n",
       "      <td>2.172749</td>\n",
       "      <td>[85, 31, 80, 117, 25, 73, 68, 46, 88, 32, 99, ...</td>\n",
       "      <td>[5.819321813760325e-05, 4.074894968653098e-05,...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>112</th>\n",
       "      <td>112</td>\n",
       "      <td>0.96</td>\n",
       "      <td>0.042601608</td>\n",
       "      <td>0.96</td>\n",
       "      <td>0.18041918</td>\n",
       "      <td>0.16</td>\n",
       "      <td>0.16</td>\n",
       "      <td>0.08</td>\n",
       "      <td>4.658523</td>\n",
       "      <td>[87, 6, 46, 70, 56, 88, 114, 25, 126, 80, 2, 4...</td>\n",
       "      <td>[6.920070882188156e-05, 6.636149191763252e-05,...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>127</th>\n",
       "      <td>127</td>\n",
       "      <td>0.96</td>\n",
       "      <td>0.049335122</td>\n",
       "      <td>1.0</td>\n",
       "      <td>0.022823265</td>\n",
       "      <td>0.12</td>\n",
       "      <td>0.12</td>\n",
       "      <td>0.08</td>\n",
       "      <td>3.2418723</td>\n",
       "      <td>[25, 102, 78, 68, 59, 29, 11, 89, 0, 126, 28, ...</td>\n",
       "      <td>[7.086406549206004e-05, 6.332980410661548e-05,...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>2</td>\n",
       "      <td>1.0</td>\n",
       "      <td>0.05132146</td>\n",
       "      <td>1.0</td>\n",
       "      <td>0.01076157</td>\n",
       "      <td>0.35999998</td>\n",
       "      <td>0.35999998</td>\n",
       "      <td>0.28</td>\n",
       "      <td>2.4885516</td>\n",
       "      <td>[73, 117, 34, 20, 66, 119, 15, 40, 14, 87, 0, ...</td>\n",
       "      <td>[0.0001727769704302773, 0.00010949717398034409...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>...</th>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>59</th>\n",
       "      <td>59</td>\n",
       "      <td>0.64</td>\n",
       "      <td>1.5289826</td>\n",
       "      <td>0.91999996</td>\n",
       "      <td>0.30159855</td>\n",
       "      <td>0.16</td>\n",
       "      <td>0.16</td>\n",
       "      <td>0.08</td>\n",
       "      <td>3.387437</td>\n",
       "      <td>[73, 122, 87, 20, 28, 75, 115, 37, 88, 32, 80,...</td>\n",
       "      <td>[0.0006684588734060526, 0.00047748888027854264...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>87</th>\n",
       "      <td>87</td>\n",
       "      <td>0.56</td>\n",
       "      <td>1.7605783</td>\n",
       "      <td>0.88</td>\n",
       "      <td>0.31418607</td>\n",
       "      <td>0.16</td>\n",
       "      <td>0.16</td>\n",
       "      <td>0.28</td>\n",
       "      <td>2.9345956</td>\n",
       "      <td>[20, 8, 25, 120, 73, 1, 92, 31, 99, 15, 88, 96...</td>\n",
       "      <td>[0.00041064442484639585, 0.0003837125550489872...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>23</th>\n",
       "      <td>23</td>\n",
       "      <td>0.64</td>\n",
       "      <td>2.012744</td>\n",
       "      <td>0.76</td>\n",
       "      <td>0.4485861</td>\n",
       "      <td>0.12</td>\n",
       "      <td>0.12</td>\n",
       "      <td>0.12</td>\n",
       "      <td>4.137039</td>\n",
       "      <td>[73, 36, 124, 116, 46, 114, 87, 35, 43, 63, 62...</td>\n",
       "      <td>[0.00032196263782680035, 0.0003091092221438885...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>114</th>\n",
       "      <td>114</td>\n",
       "      <td>0.59999996</td>\n",
       "      <td>2.060565</td>\n",
       "      <td>0.88</td>\n",
       "      <td>0.53120136</td>\n",
       "      <td>0.16</td>\n",
       "      <td>0.16</td>\n",
       "      <td>0.16</td>\n",
       "      <td>3.3692026</td>\n",
       "      <td>[73, 25, 120, 92, 75, 38, 19, 7, 22, 3, 48, 12...</td>\n",
       "      <td>[0.0009972101543098688, 0.0004889949923381209,...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>24</th>\n",
       "      <td>24</td>\n",
       "      <td>0.35999998</td>\n",
       "      <td>3.2999358</td>\n",
       "      <td>0.59999996</td>\n",
       "      <td>1.4049126</td>\n",
       "      <td>0.28</td>\n",
       "      <td>0.28</td>\n",
       "      <td>0.24</td>\n",
       "      <td>2.4452364</td>\n",
       "      <td>[31, 117, 28, 35, 68, 102, 91, 97, 40, 20, 78,...</td>\n",
       "      <td>[0.001516898861154914, 0.00148588337469846, 0....</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "<p>128 rows × 11 columns</p>\n",
       "</div>"
      ],
      "text/plain": [
       "     test_task_idx test_accuracy   test_error adaptation_accuracy  \\\n",
       "103            103           1.0  0.034549527                0.96   \n",
       "56              56          0.96  0.037129346                0.96   \n",
       "112            112          0.96  0.042601608                0.96   \n",
       "127            127          0.96  0.049335122                 1.0   \n",
       "2                2           1.0   0.05132146                 1.0   \n",
       "..             ...           ...          ...                 ...   \n",
       "59              59          0.64    1.5289826          0.91999996   \n",
       "87              87          0.56    1.7605783                0.88   \n",
       "23              23          0.64     2.012744                0.76   \n",
       "114            114    0.59999996     2.060565                0.88   \n",
       "24              24    0.35999998    3.2999358          0.59999996   \n",
       "\n",
       "    adaptation_error train_accuracy train_error zeroshot_accuracy  \\\n",
       "103      0.035123352           0.24        0.24              0.32   \n",
       "56         0.2604465     0.19999999  0.19999999        0.19999999   \n",
       "112       0.18041918           0.16        0.16              0.08   \n",
       "127      0.022823265           0.12        0.12              0.08   \n",
       "2         0.01076157     0.35999998  0.35999998              0.28   \n",
       "..               ...            ...         ...               ...   \n",
       "59        0.30159855           0.16        0.16              0.08   \n",
       "87        0.31418607           0.16        0.16              0.28   \n",
       "23         0.4485861           0.12        0.12              0.12   \n",
       "114       0.53120136           0.16        0.16              0.16   \n",
       "24         1.4049126           0.28        0.28              0.24   \n",
       "\n",
       "    zeroshot_error                                     train_task_idx  \\\n",
       "103      1.8267751  [120, 7, 49, 102, 59, 126, 66, 40, 75, 51, 83,...   \n",
       "56        2.172749  [85, 31, 80, 117, 25, 73, 68, 46, 88, 32, 99, ...   \n",
       "112       4.658523  [87, 6, 46, 70, 56, 88, 114, 25, 126, 80, 2, 4...   \n",
       "127      3.2418723  [25, 102, 78, 68, 59, 29, 11, 89, 0, 126, 28, ...   \n",
       "2        2.4885516  [73, 117, 34, 20, 66, 119, 15, 40, 14, 87, 0, ...   \n",
       "..             ...                                                ...   \n",
       "59        3.387437  [73, 122, 87, 20, 28, 75, 115, 37, 88, 32, 80,...   \n",
       "87       2.9345956  [20, 8, 25, 120, 73, 1, 92, 31, 99, 15, 88, 96...   \n",
       "23        4.137039  [73, 36, 124, 116, 46, 114, 87, 35, 43, 63, 62...   \n",
       "114      3.3692026  [73, 25, 120, 92, 75, 38, 19, 7, 22, 3, 48, 12...   \n",
       "24       2.4452364  [31, 117, 28, 35, 68, 102, 91, 97, 40, 20, 78,...   \n",
       "\n",
       "                                      train_task_score  \n",
       "103  [0.0001478378108004108, 0.00012239433999639004...  \n",
       "56   [5.819321813760325e-05, 4.074894968653098e-05,...  \n",
       "112  [6.920070882188156e-05, 6.636149191763252e-05,...  \n",
       "127  [7.086406549206004e-05, 6.332980410661548e-05,...  \n",
       "2    [0.0001727769704302773, 0.00010949717398034409...  \n",
       "..                                                 ...  \n",
       "59   [0.0006684588734060526, 0.00047748888027854264...  \n",
       "87   [0.00041064442484639585, 0.0003837125550489872...  \n",
       "23   [0.00032196263782680035, 0.0003091092221438885...  \n",
       "114  [0.0009972101543098688, 0.0004889949923381209,...  \n",
       "24   [0.001516898861154914, 0.00148588337469846, 0....  \n",
       "\n",
       "[128 rows x 11 columns]"
      ]
     },
     "execution_count": 47,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# OPA results ordered by ascending test error\n",
    "df_opa.sort_values(by='test_error')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "id": "788345d2-a0ce-4820-8309-bbf67274881c",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.20134549307694974"
      ]
     },
     "execution_count": 22,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# compare train-task scores of the exact and OPA explainers at positional row 113\n",
    "_series_correlation(*(frame.iloc[113] for frame in (df_exact, df_opa)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "id": "33f6320a-4141-47b4-803e-36c275b2b658",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>test_task_idx</th>\n",
       "      <th>score_0</th>\n",
       "      <th>score_1</th>\n",
       "      <th>corr</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>0</td>\n",
       "      <td>[0.007877628318965435, 0.00387960160151124, 0....</td>\n",
       "      <td>[0.00016143036191351712, 9.59113021963276e-05,...</td>\n",
       "      <td>0.056759</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>1</td>\n",
       "      <td>[0.008765799924731255, 0.007196517661213875, 0...</td>\n",
       "      <td>[0.0003266599087510258, 0.00027208126266486943...</td>\n",
       "      <td>0.062283</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>2</td>\n",
       "      <td>[0.006896039936691523, 0.005766767542809248, 0...</td>\n",
       "      <td>[0.0001727769704302773, 0.00010949717398034409...</td>\n",
       "      <td>0.048711</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>3</td>\n",
       "      <td>[0.04043019935488701, 0.04010601341724396, 0.0...</td>\n",
       "      <td>[0.0010342116001993418, 0.0009590634726919234,...</td>\n",
       "      <td>0.439581</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>4</td>\n",
       "      <td>[0.02433246374130249, 0.018689990043640137, 0....</td>\n",
       "      <td>[0.0004521046648733318, 0.000390937871998176, ...</td>\n",
       "      <td>0.178132</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>...</th>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>123</th>\n",
       "      <td>123</td>\n",
       "      <td>[0.01581629179418087, 0.012596304528415203, 0....</td>\n",
       "      <td>[0.000368670531315729, 0.00025150569854304194,...</td>\n",
       "      <td>0.034516</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>124</th>\n",
       "      <td>124</td>\n",
       "      <td>[0.007310028187930584, 0.006295029539614916, 0...</td>\n",
       "      <td>[0.00015129112580325454, 0.0001399657630827278...</td>\n",
       "      <td>-0.017100</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>125</th>\n",
       "      <td>125</td>\n",
       "      <td>[0.011669972911477089, 0.007566728629171848, 0...</td>\n",
       "      <td>[0.0002187500795116648, 0.00021855198428966105...</td>\n",
       "      <td>0.224292</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>126</th>\n",
       "      <td>126</td>\n",
       "      <td>[0.004523418843746185, 0.004443782847374678, 0...</td>\n",
       "      <td>[0.0002517403045203537, 0.00022956932662054896...</td>\n",
       "      <td>0.285593</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>127</th>\n",
       "      <td>127</td>\n",
       "      <td>[0.006482816766947508, 0.005137724336236715, 0...</td>\n",
       "      <td>[7.086406549206004e-05, 6.332980410661548e-05,...</td>\n",
       "      <td>0.033457</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "<p>128 rows × 4 columns</p>\n",
       "</div>"
      ],
      "text/plain": [
       "     test_task_idx                                            score_0  \\\n",
       "0                0  [0.007877628318965435, 0.00387960160151124, 0....   \n",
       "1                1  [0.008765799924731255, 0.007196517661213875, 0...   \n",
       "2                2  [0.006896039936691523, 0.005766767542809248, 0...   \n",
       "3                3  [0.04043019935488701, 0.04010601341724396, 0.0...   \n",
       "4                4  [0.02433246374130249, 0.018689990043640137, 0....   \n",
       "..             ...                                                ...   \n",
       "123            123  [0.01581629179418087, 0.012596304528415203, 0....   \n",
       "124            124  [0.007310028187930584, 0.006295029539614916, 0...   \n",
       "125            125  [0.011669972911477089, 0.007566728629171848, 0...   \n",
       "126            126  [0.004523418843746185, 0.004443782847374678, 0...   \n",
       "127            127  [0.006482816766947508, 0.005137724336236715, 0...   \n",
       "\n",
       "                                               score_1      corr  \n",
       "0    [0.00016143036191351712, 9.59113021963276e-05,...  0.056759  \n",
       "1    [0.0003266599087510258, 0.00027208126266486943...  0.062283  \n",
       "2    [0.0001727769704302773, 0.00010949717398034409...  0.048711  \n",
       "3    [0.0010342116001993418, 0.0009590634726919234,...  0.439581  \n",
       "4    [0.0004521046648733318, 0.000390937871998176, ...  0.178132  \n",
       "..                                                 ...       ...  \n",
       "123  [0.000368670531315729, 0.00025150569854304194,...  0.034516  \n",
       "124  [0.00015129112580325454, 0.0001399657630827278... -0.017100  \n",
       "125  [0.0002187500795116648, 0.00021855198428966105...  0.224292  \n",
       "126  [0.0002517403045203537, 0.00022956932662054896...  0.285593  \n",
       "127  [7.086406549206004e-05, 6.332980410661548e-05,...  0.033457  \n",
       "\n",
       "[128 rows x 4 columns]"
      ]
     },
     "execution_count": 27,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# per-test-task correlation table; score_0 / score_1 appear to be the df_exact / df_opa train_task_score lists\n",
    "df_correlation(df_exact, df_opa)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "55f44a26-ff39-4c51-b823-a971598eb2f8",
   "metadata": {},
   "source": [
    "## 64 training tasks"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "fd147082-bba0-4704-af22-df8f91873f90",
   "metadata": {},
   "outputs": [],
   "source": [
    "# configuration for the 64-training-task experiment\n",
    "k, fft_crop_size = 36, 6\n",
    "ways = shots = 5\n",
    "num_tasks, num_test_tasks = 64, 128\n",
    "experiment_dir = '../examples/maml_l2l/cache/2024-0913-003110/'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "52924a30-d18b-4951-901a-659f63849d6f",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "found ../examples/maml_l2l/cache/2024-0913-003110/\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "loaded ../examples/maml_l2l/cache/2024-0913-003110/expl_maml_k36_layer32_tasks64_mbs32_ways5_shots5_1000.pkl\n",
      "Computing pseudo-inverse of hessian\n",
      "done  (shape (1413, 1413))\n",
      "set src_param_matrix (shape (1413, 64))\n",
      "done\n",
      "src_param_matrix: (1413, 64)\n"
     ]
    }
   ],
   "source": [
    "# Load the exact explainer (its setup log reports a pseudo-inverse of the Hessian).\n",
    "# The file name is derived from the config cell so it cannot drift from k/num_tasks/ways/shots,\n",
    "# and os.path.join avoids depending on experiment_dir ending with '/'.\n",
    "explainer_name = f'expl_maml_k{k}_layer32_tasks{num_tasks}_mbs32_ways{ways}_shots{shots}_1000.pkl'\n",
    "explainer_path = os.path.join(experiment_dir, explainer_name)\n",
    "tasks_train, tasks_test, explainer_exact, maml, feature, impurity_dict = setup_experiment(\n",
    "    ways=ways, shots=shots, num_tasks=num_tasks,\n",
    "    experiment_dir=experiment_dir,\n",
    "    explainer_path=explainer_path,\n",
    "    fft_crop_size=fft_crop_size,\n",
    "    dataset='omniglot',\n",
    ")\n",
    "_preprocess = get_preprocess(feature, cuda=True)\n",
    "# NOTE(review): 128 appears to match the 'nh128' tag of the OPA explainer file - confirm intent\n",
    "explainer_exact.set_src_generalized_matrix(n_positive_ev=128)\n",
    "print('done')\n",
    "print('src_param_matrix:', explainer_exact.src_param_matrix.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "181c7ba1-56e0-4d40-b6c1-6de616cde125",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "found ../examples/maml_l2l/cache/2024-0913-003110/\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "loaded ../examples/maml_l2l/cache/2024-0913-003110/opa_nh128_expl_maml_k36_layer32_tasks64_mbs32_ways5_shots5_1000.pkl\n",
      "done\n",
      "src_param_matrix: (1413, 64)\n"
     ]
    }
   ],
   "source": [
    "# Load the OPA explainer (file tagged nh128); the name reuses the config-cell parameters\n",
    "# and os.path.join avoids depending on experiment_dir ending with '/'.\n",
    "explainer_name = f'opa_nh128_expl_maml_k{k}_layer32_tasks{num_tasks}_mbs32_ways{ways}_shots{shots}_1000.pkl'\n",
    "explainer_path = os.path.join(experiment_dir, explainer_name)\n",
    "# NOTE(review): this also re-assigns tasks_train, tasks_test, maml, feature and impurity_dict\n",
    "tasks_train, tasks_test, explainer_opa, maml, feature, impurity_dict = setup_experiment(\n",
    "    ways=ways, shots=shots, num_tasks=num_tasks,\n",
    "    experiment_dir=experiment_dir,\n",
    "    explainer_path=explainer_path,\n",
    "    fft_crop_size=fft_crop_size,\n",
    "    dataset='omniglot',\n",
    ")\n",
    "_preprocess = get_preprocess(feature, cuda=True)\n",
    "print('done')\n",
    "print('src_param_matrix:', explainer_opa.src_param_matrix.shape)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f1d30262-b2ac-479a-bdc5-c52396caa68c",
   "metadata": {},
   "source": [
    "### Evaluation on the training tasks (train tasks reused as test tasks)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "502745f6-3545-44ac-8ab6-097a14697268",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 64/64 [00:11<00:00,  5.45it/s]\n"
     ]
    }
   ],
   "source": [
    "# both task arguments are tasks_train: the exact explainer is evaluated on the training tasks\n",
    "df_exact = explain_test_performance(\n",
    "    explainer_exact, tasks_train, tasks_train,\n",
    "    preprocess=_preprocess, loss=loss,\n",
    "    shots=shots, ways=ways,\n",
    "    num_train_task=num_tasks, num_test_task=num_tasks,\n",
    ")\n",
    "\n",
    "# cache the results frame in the experiment directory\n",
    "pkl_path = os.path.join(experiment_dir, 'df_exact.pkl')\n",
    "with open(pkl_path, 'wb') as f:\n",
    "    pickle.dump(df_exact, f)\n",
    "\n",
    "index_dict = {'train_noise_tasks': [], 'train_shuffle_tasks': []}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "d80ed731-fcee-433b-ae7b-e348ff2de664",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>test_task_idx</th>\n",
       "      <th>test_accuracy</th>\n",
       "      <th>test_error</th>\n",
       "      <th>adaptation_accuracy</th>\n",
       "      <th>adaptation_error</th>\n",
       "      <th>train_accuracy</th>\n",
       "      <th>train_error</th>\n",
       "      <th>zeroshot_accuracy</th>\n",
       "      <th>zeroshot_error</th>\n",
       "      <th>train_task_idx</th>\n",
       "      <th>train_task_score</th>\n",
       "      <th>self_rank</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>2</td>\n",
       "      <td>1.0</td>\n",
       "      <td>0.00016091354</td>\n",
       "      <td>0.91999996</td>\n",
       "      <td>0.3539884</td>\n",
       "      <td>0.24</td>\n",
       "      <td>0.24</td>\n",
       "      <td>0.12</td>\n",
       "      <td>4.6844196</td>\n",
       "      <td>[63, 49, 45, 38, 27, 6, 9, 51, 33, 37, 19, 34,...</td>\n",
       "      <td>[1.7385911633027717e-05, 9.95759000943508e-06,...</td>\n",
       "      <td>32</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>41</th>\n",
       "      <td>41</td>\n",
       "      <td>1.0</td>\n",
       "      <td>0.0005200979</td>\n",
       "      <td>1.0</td>\n",
       "      <td>7.1750954e-05</td>\n",
       "      <td>0.19999999</td>\n",
       "      <td>0.19999999</td>\n",
       "      <td>0.16</td>\n",
       "      <td>3.2405376</td>\n",
       "      <td>[18, 3, 26, 51, 40, 39, 43, 52, 7, 46, 20, 62,...</td>\n",
       "      <td>[8.540345879737288e-05, 8.419655205216259e-05,...</td>\n",
       "      <td>19</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>13</th>\n",
       "      <td>13</td>\n",
       "      <td>1.0</td>\n",
       "      <td>0.00081781053</td>\n",
       "      <td>0.91999996</td>\n",
       "      <td>0.52522385</td>\n",
       "      <td>0.19999999</td>\n",
       "      <td>0.19999999</td>\n",
       "      <td>0.16</td>\n",
       "      <td>3.1119604</td>\n",
       "      <td>[43, 28, 52, 16, 46, 29, 3, 58, 38, 19, 50, 11...</td>\n",
       "      <td>[0.00021346076391637325, 0.0001674498926149681...</td>\n",
       "      <td>23</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>61</th>\n",
       "      <td>61</td>\n",
       "      <td>1.0</td>\n",
       "      <td>0.0030066997</td>\n",
       "      <td>1.0</td>\n",
       "      <td>0.010557078</td>\n",
       "      <td>0.19999999</td>\n",
       "      <td>0.19999999</td>\n",
       "      <td>0.24</td>\n",
       "      <td>3.3069751</td>\n",
       "      <td>[28, 43, 29, 59, 49, 38, 48, 46, 19, 61, 57, 2...</td>\n",
       "      <td>[0.0021560993045568466, 0.0012260971125215292,...</td>\n",
       "      <td>9</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>8</th>\n",
       "      <td>8</td>\n",
       "      <td>1.0</td>\n",
       "      <td>0.0030586922</td>\n",
       "      <td>0.96</td>\n",
       "      <td>0.04981514</td>\n",
       "      <td>0.44</td>\n",
       "      <td>0.44</td>\n",
       "      <td>0.68</td>\n",
       "      <td>0.8835869</td>\n",
       "      <td>[43, 49, 37, 8, 16, 50, 29, 33, 46, 59, 63, 34...</td>\n",
       "      <td>[0.0017381587531417608, 0.001256835530512035, ...</td>\n",
       "      <td>3</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>...</th>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>49</th>\n",
       "      <td>49</td>\n",
       "      <td>0.96</td>\n",
       "      <td>0.09321929</td>\n",
       "      <td>0.91999996</td>\n",
       "      <td>0.3437719</td>\n",
       "      <td>0.12</td>\n",
       "      <td>0.12</td>\n",
       "      <td>0.12</td>\n",
       "      <td>3.8926916</td>\n",
       "      <td>[49, 28, 63, 6, 26, 16, 29, 7, 32, 31, 14, 9, ...</td>\n",
       "      <td>[0.029495693743228912, 0.019883235916495323, 0...</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>14</th>\n",
       "      <td>14</td>\n",
       "      <td>0.96</td>\n",
       "      <td>0.14149472</td>\n",
       "      <td>0.91999996</td>\n",
       "      <td>0.22719507</td>\n",
       "      <td>0.12</td>\n",
       "      <td>0.12</td>\n",
       "      <td>0.16</td>\n",
       "      <td>3.7482178</td>\n",
       "      <td>[14, 27, 50, 52, 43, 6, 28, 18, 3, 38, 19, 16,...</td>\n",
       "      <td>[0.0678684338927269, 0.01665140688419342, 0.01...</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>28</th>\n",
       "      <td>28</td>\n",
       "      <td>0.96</td>\n",
       "      <td>0.18673684</td>\n",
       "      <td>0.88</td>\n",
       "      <td>0.21361135</td>\n",
       "      <td>0.08</td>\n",
       "      <td>0.08</td>\n",
       "      <td>0.12</td>\n",
       "      <td>2.9233105</td>\n",
       "      <td>[28, 27, 29, 63, 49, 16, 43, 59, 50, 33, 19, 5...</td>\n",
       "      <td>[0.1546240895986557, 0.03537946566939354, 0.02...</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>27</th>\n",
       "      <td>27</td>\n",
       "      <td>0.96</td>\n",
       "      <td>0.22490385</td>\n",
       "      <td>0.91999996</td>\n",
       "      <td>0.09633201</td>\n",
       "      <td>0.12</td>\n",
       "      <td>0.12</td>\n",
       "      <td>0.19999999</td>\n",
       "      <td>2.776013</td>\n",
       "      <td>[27, 28, 63, 14, 6, 19, 29, 38, 10, 33, 58, 22...</td>\n",
       "      <td>[0.07928380370140076, 0.03537945821881294, 0.0...</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>43</th>\n",
       "      <td>43</td>\n",
       "      <td>0.79999995</td>\n",
       "      <td>0.7116365</td>\n",
       "      <td>0.71999997</td>\n",
       "      <td>0.8948777</td>\n",
       "      <td>0.19999999</td>\n",
       "      <td>0.19999999</td>\n",
       "      <td>0.16</td>\n",
       "      <td>2.30172</td>\n",
       "      <td>[43, 34, 19, 26, 38, 3, 28, 58, 52, 10, 5, 14,...</td>\n",
       "      <td>[0.2586711347103119, 0.04547004774212837, 0.01...</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "<p>64 rows × 12 columns</p>\n",
       "</div>"
      ],
      "text/plain": [
       "    test_task_idx test_accuracy     test_error adaptation_accuracy  \\\n",
       "2               2           1.0  0.00016091354          0.91999996   \n",
       "41             41           1.0   0.0005200979                 1.0   \n",
       "13             13           1.0  0.00081781053          0.91999996   \n",
       "61             61           1.0   0.0030066997                 1.0   \n",
       "8               8           1.0   0.0030586922                0.96   \n",
       "..            ...           ...            ...                 ...   \n",
       "49             49          0.96     0.09321929          0.91999996   \n",
       "14             14          0.96     0.14149472          0.91999996   \n",
       "28             28          0.96     0.18673684                0.88   \n",
       "27             27          0.96     0.22490385          0.91999996   \n",
       "43             43    0.79999995      0.7116365          0.71999997   \n",
       "\n",
       "   adaptation_error train_accuracy train_error zeroshot_accuracy  \\\n",
       "2         0.3539884           0.24        0.24              0.12   \n",
       "41    7.1750954e-05     0.19999999  0.19999999              0.16   \n",
       "13       0.52522385     0.19999999  0.19999999              0.16   \n",
       "61      0.010557078     0.19999999  0.19999999              0.24   \n",
       "8        0.04981514           0.44        0.44              0.68   \n",
       "..              ...            ...         ...               ...   \n",
       "49        0.3437719           0.12        0.12              0.12   \n",
       "14       0.22719507           0.12        0.12              0.16   \n",
       "28       0.21361135           0.08        0.08              0.12   \n",
       "27       0.09633201           0.12        0.12        0.19999999   \n",
       "43        0.8948777     0.19999999  0.19999999              0.16   \n",
       "\n",
       "   zeroshot_error                                     train_task_idx  \\\n",
       "2       4.6844196  [63, 49, 45, 38, 27, 6, 9, 51, 33, 37, 19, 34,...   \n",
       "41      3.2405376  [18, 3, 26, 51, 40, 39, 43, 52, 7, 46, 20, 62,...   \n",
       "13      3.1119604  [43, 28, 52, 16, 46, 29, 3, 58, 38, 19, 50, 11...   \n",
       "61      3.3069751  [28, 43, 29, 59, 49, 38, 48, 46, 19, 61, 57, 2...   \n",
       "8       0.8835869  [43, 49, 37, 8, 16, 50, 29, 33, 46, 59, 63, 34...   \n",
       "..            ...                                                ...   \n",
       "49      3.8926916  [49, 28, 63, 6, 26, 16, 29, 7, 32, 31, 14, 9, ...   \n",
       "14      3.7482178  [14, 27, 50, 52, 43, 6, 28, 18, 3, 38, 19, 16,...   \n",
       "28      2.9233105  [28, 27, 29, 63, 49, 16, 43, 59, 50, 33, 19, 5...   \n",
       "27       2.776013  [27, 28, 63, 14, 6, 19, 29, 38, 10, 33, 58, 22...   \n",
       "43        2.30172  [43, 34, 19, 26, 38, 3, 28, 58, 52, 10, 5, 14,...   \n",
       "\n",
       "                                     train_task_score  self_rank  \n",
       "2   [1.7385911633027717e-05, 9.95759000943508e-06,...         32  \n",
       "41  [8.540345879737288e-05, 8.419655205216259e-05,...         19  \n",
       "13  [0.00021346076391637325, 0.0001674498926149681...         23  \n",
       "61  [0.0021560993045568466, 0.0012260971125215292,...          9  \n",
       "8   [0.0017381587531417608, 0.001256835530512035, ...          3  \n",
       "..                                                ...        ...  \n",
       "49  [0.029495693743228912, 0.019883235916495323, 0...          0  \n",
       "14  [0.0678684338927269, 0.01665140688419342, 0.01...          0  \n",
       "28  [0.1546240895986557, 0.03537946566939354, 0.02...          0  \n",
       "27  [0.07928380370140076, 0.03537945821881294, 0.0...          0  \n",
       "43  [0.2586711347103119, 0.04547004774212837, 0.01...          0  \n",
       "\n",
       "[64 rows x 12 columns]"
      ]
     },
     "execution_count": 16,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# 0-based position of each test task inside its own train_task_idx list\n",
    "# (the list appears ordered by descending train_task_score)\n",
    "df_exact['self_rank'] = [\n",
    "    ranked_ids.index(task_id)\n",
    "    for task_id, ranked_ids in zip(df_exact['test_task_idx'], df_exact['train_task_idx'])\n",
    "]\n",
    "df_exact.sort_values(by='test_error')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "id": "6724802c-abfb-47e1-9f30-16e626746c2a",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 64/64 [00:11<00:00,  5.61it/s]\n"
     ]
    }
   ],
   "source": [
    "# both task arguments are tasks_train: the OPA explainer is evaluated on the training tasks\n",
    "df_opa = explain_test_performance(\n",
    "    explainer_opa, tasks_train, tasks_train,\n",
    "    preprocess=_preprocess, loss=loss,\n",
    "    shots=shots, ways=ways,\n",
    "    num_train_task=num_tasks, num_test_task=num_tasks,\n",
    ")\n",
    "\n",
    "# cache the results frame in the experiment directory\n",
    "pkl_path = os.path.join(experiment_dir, 'df_opa.pkl')\n",
    "with open(pkl_path, 'wb') as f:\n",
    "    pickle.dump(df_opa, f)\n",
    "\n",
    "index_dict = {'train_noise_tasks': [], 'train_shuffle_tasks': []}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "id": "ee02e07c-2c85-44ec-b8eb-fb748d7bc5bf",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>test_task_idx</th>\n",
       "      <th>test_accuracy</th>\n",
       "      <th>test_error</th>\n",
       "      <th>adaptation_accuracy</th>\n",
       "      <th>adaptation_error</th>\n",
       "      <th>train_accuracy</th>\n",
       "      <th>train_error</th>\n",
       "      <th>zeroshot_accuracy</th>\n",
       "      <th>zeroshot_error</th>\n",
       "      <th>train_task_idx</th>\n",
       "      <th>train_task_score</th>\n",
       "      <th>self_rank</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>2</td>\n",
       "      <td>1.0</td>\n",
       "      <td>0.00016091354</td>\n",
       "      <td>0.91999996</td>\n",
       "      <td>0.3539884</td>\n",
       "      <td>0.24</td>\n",
       "      <td>0.24</td>\n",
       "      <td>0.12</td>\n",
       "      <td>4.6844196</td>\n",
       "      <td>[29, 45, 9, 51, 49, 6, 12, 15, 38, 32, 40, 22,...</td>\n",
       "      <td>[2.8031232091052516e-07, 1.1411479761136434e-0...</td>\n",
       "      <td>29</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>41</th>\n",
       "      <td>41</td>\n",
       "      <td>1.0</td>\n",
       "      <td>0.0005200979</td>\n",
       "      <td>1.0</td>\n",
       "      <td>7.1750954e-05</td>\n",
       "      <td>0.19999999</td>\n",
       "      <td>0.19999999</td>\n",
       "      <td>0.16</td>\n",
       "      <td>3.2405376</td>\n",
       "      <td>[6, 18, 55, 26, 14, 49, 33, 39, 59, 46, 57, 9,...</td>\n",
       "      <td>[9.94251081465336e-07, 9.422165021533147e-07, ...</td>\n",
       "      <td>21</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>13</th>\n",
       "      <td>13</td>\n",
       "      <td>1.0</td>\n",
       "      <td>0.00081781053</td>\n",
       "      <td>0.91999996</td>\n",
       "      <td>0.52522385</td>\n",
       "      <td>0.19999999</td>\n",
       "      <td>0.19999999</td>\n",
       "      <td>0.16</td>\n",
       "      <td>3.1119604</td>\n",
       "      <td>[43, 29, 16, 52, 34, 46, 33, 58, 6, 20, 38, 37...</td>\n",
       "      <td>[4.505615834204946e-06, 1.0430504744363134e-06...</td>\n",
       "      <td>21</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>61</th>\n",
       "      <td>61</td>\n",
       "      <td>1.0</td>\n",
       "      <td>0.0030066997</td>\n",
       "      <td>1.0</td>\n",
       "      <td>0.010557078</td>\n",
       "      <td>0.19999999</td>\n",
       "      <td>0.19999999</td>\n",
       "      <td>0.24</td>\n",
       "      <td>3.3069751</td>\n",
       "      <td>[43, 28, 14, 49, 59, 34, 50, 61, 57, 55, 48, 2...</td>\n",
       "      <td>[2.7620422770269215e-05, 1.973065809579566e-05...</td>\n",
       "      <td>7</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>8</th>\n",
       "      <td>8</td>\n",
       "      <td>1.0</td>\n",
       "      <td>0.0030586922</td>\n",
       "      <td>0.96</td>\n",
       "      <td>0.04981514</td>\n",
       "      <td>0.44</td>\n",
       "      <td>0.44</td>\n",
       "      <td>0.68</td>\n",
       "      <td>0.8835869</td>\n",
       "      <td>[49, 8, 59, 15, 33, 37, 46, 34, 3, 29, 60, 9, ...</td>\n",
       "      <td>[1.683866503299214e-05, 1.6817852156236768e-05...</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>...</th>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>49</th>\n",
       "      <td>49</td>\n",
       "      <td>0.96</td>\n",
       "      <td>0.09321929</td>\n",
       "      <td>0.91999996</td>\n",
       "      <td>0.3437719</td>\n",
       "      <td>0.12</td>\n",
       "      <td>0.12</td>\n",
       "      <td>0.12</td>\n",
       "      <td>3.8926916</td>\n",
       "      <td>[49, 27, 14, 18, 19, 16, 6, 55, 56, 38, 9, 20,...</td>\n",
       "      <td>[0.0012683842796832323, 0.00024529057554900646...</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>14</th>\n",
       "      <td>14</td>\n",
       "      <td>0.96</td>\n",
       "      <td>0.14149472</td>\n",
       "      <td>0.91999996</td>\n",
       "      <td>0.22719507</td>\n",
       "      <td>0.12</td>\n",
       "      <td>0.12</td>\n",
       "      <td>0.16</td>\n",
       "      <td>3.7482178</td>\n",
       "      <td>[14, 43, 49, 34, 33, 28, 45, 63, 19, 6, 44, 55...</td>\n",
       "      <td>[0.002260608831420541, 0.0007628632592968643, ...</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>28</th>\n",
       "      <td>28</td>\n",
       "      <td>0.96</td>\n",
       "      <td>0.18673684</td>\n",
       "      <td>0.88</td>\n",
       "      <td>0.21361135</td>\n",
       "      <td>0.08</td>\n",
       "      <td>0.08</td>\n",
       "      <td>0.12</td>\n",
       "      <td>2.9233105</td>\n",
       "      <td>[28, 43, 27, 33, 59, 26, 14, 16, 19, 56, 57, 1...</td>\n",
       "      <td>[0.0027902130968868732, 0.000737540889531374, ...</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>27</th>\n",
       "      <td>27</td>\n",
       "      <td>0.96</td>\n",
       "      <td>0.22490385</td>\n",
       "      <td>0.91999996</td>\n",
       "      <td>0.09633201</td>\n",
       "      <td>0.12</td>\n",
       "      <td>0.12</td>\n",
       "      <td>0.19999999</td>\n",
       "      <td>2.776013</td>\n",
       "      <td>[27, 63, 28, 6, 19, 49, 26, 57, 5, 51, 10, 12,...</td>\n",
       "      <td>[0.003588391002267599, 0.0005208561778999865, ...</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>43</th>\n",
       "      <td>43</td>\n",
       "      <td>0.79999995</td>\n",
       "      <td>0.7116365</td>\n",
       "      <td>0.71999997</td>\n",
       "      <td>0.8948777</td>\n",
       "      <td>0.19999999</td>\n",
       "      <td>0.19999999</td>\n",
       "      <td>0.16</td>\n",
       "      <td>2.30172</td>\n",
       "      <td>[43, 14, 28, 57, 19, 10, 40, 58, 37, 60, 32, 3...</td>\n",
       "      <td>[0.012756180949509144, 0.0007628633175045252, ...</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "<p>64 rows × 12 columns</p>\n",
       "</div>"
      ],
      "text/plain": [
       "    test_task_idx test_accuracy     test_error adaptation_accuracy  \\\n",
       "2               2           1.0  0.00016091354          0.91999996   \n",
       "41             41           1.0   0.0005200979                 1.0   \n",
       "13             13           1.0  0.00081781053          0.91999996   \n",
       "61             61           1.0   0.0030066997                 1.0   \n",
       "8               8           1.0   0.0030586922                0.96   \n",
       "..            ...           ...            ...                 ...   \n",
       "49             49          0.96     0.09321929          0.91999996   \n",
       "14             14          0.96     0.14149472          0.91999996   \n",
       "28             28          0.96     0.18673684                0.88   \n",
       "27             27          0.96     0.22490385          0.91999996   \n",
       "43             43    0.79999995      0.7116365          0.71999997   \n",
       "\n",
       "   adaptation_error train_accuracy train_error zeroshot_accuracy  \\\n",
       "2         0.3539884           0.24        0.24              0.12   \n",
       "41    7.1750954e-05     0.19999999  0.19999999              0.16   \n",
       "13       0.52522385     0.19999999  0.19999999              0.16   \n",
       "61      0.010557078     0.19999999  0.19999999              0.24   \n",
       "8        0.04981514           0.44        0.44              0.68   \n",
       "..              ...            ...         ...               ...   \n",
       "49        0.3437719           0.12        0.12              0.12   \n",
       "14       0.22719507           0.12        0.12              0.16   \n",
       "28       0.21361135           0.08        0.08              0.12   \n",
       "27       0.09633201           0.12        0.12        0.19999999   \n",
       "43        0.8948777     0.19999999  0.19999999              0.16   \n",
       "\n",
       "   zeroshot_error                                     train_task_idx  \\\n",
       "2       4.6844196  [29, 45, 9, 51, 49, 6, 12, 15, 38, 32, 40, 22,...   \n",
       "41      3.2405376  [6, 18, 55, 26, 14, 49, 33, 39, 59, 46, 57, 9,...   \n",
       "13      3.1119604  [43, 29, 16, 52, 34, 46, 33, 58, 6, 20, 38, 37...   \n",
       "61      3.3069751  [43, 28, 14, 49, 59, 34, 50, 61, 57, 55, 48, 2...   \n",
       "8       0.8835869  [49, 8, 59, 15, 33, 37, 46, 34, 3, 29, 60, 9, ...   \n",
       "..            ...                                                ...   \n",
       "49      3.8926916  [49, 27, 14, 18, 19, 16, 6, 55, 56, 38, 9, 20,...   \n",
       "14      3.7482178  [14, 43, 49, 34, 33, 28, 45, 63, 19, 6, 44, 55...   \n",
       "28      2.9233105  [28, 43, 27, 33, 59, 26, 14, 16, 19, 56, 57, 1...   \n",
       "27       2.776013  [27, 63, 28, 6, 19, 49, 26, 57, 5, 51, 10, 12,...   \n",
       "43        2.30172  [43, 14, 28, 57, 19, 10, 40, 58, 37, 60, 32, 3...   \n",
       "\n",
       "                                     train_task_score  self_rank  \n",
       "2   [2.8031232091052516e-07, 1.1411479761136434e-0...         29  \n",
       "41  [9.94251081465336e-07, 9.422165021533147e-07, ...         21  \n",
       "13  [4.505615834204946e-06, 1.0430504744363134e-06...         21  \n",
       "61  [2.7620422770269215e-05, 1.973065809579566e-05...          7  \n",
       "8   [1.683866503299214e-05, 1.6817852156236768e-05...          1  \n",
       "..                                                ...        ...  \n",
       "49  [0.0012683842796832323, 0.00024529057554900646...          0  \n",
       "14  [0.002260608831420541, 0.0007628632592968643, ...          0  \n",
       "28  [0.0027902130968868732, 0.000737540889531374, ...          0  \n",
       "27  [0.003588391002267599, 0.0005208561778999865, ...          0  \n",
       "43  [0.012756180949509144, 0.0007628633175045252, ...          0  \n",
       "\n",
       "[64 rows x 12 columns]"
      ]
     },
     "execution_count": 18,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# self_rank: position of each row's own task id inside its train_task_idx list.\n",
    "# train_task_idx appears ordered by decreasing train_task_score, so 0 = ranked first.\n",
    "# list.index raises ValueError if the task id is missing from its own list.\n",
    "df_opa['self_rank'] = [\n",
    "    ranked_train_tasks.index(task_id)\n",
    "    for ranked_train_tasks, task_id in zip(df_opa['train_task_idx'], df_opa['test_task_idx'])\n",
    "]\n",
    "df_opa.sort_values(['test_error'])"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "cbabbb09-f074-4c48-aabd-ec87c60cb675",
   "metadata": {},
   "source": [
    "### Explaining performance on test tasks (exact vs. OPA explainers)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "id": "6834f43e-8f92-4f39-aed5-001cbaa7c320",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 128/128 [00:23<00:00,  5.51it/s]\n"
     ]
    }
   ],
   "source": [
    "# Explain test-task performance with the exact explainer,\n",
    "# then persist the resulting DataFrame to experiment_dir.\n",
    "df_exact = explain_test_performance(\n",
    "    explainer_exact,\n",
    "    tasks_train,\n",
    "    tasks_test,\n",
    "    preprocess=_preprocess,\n",
    "    loss=loss,\n",
    "    shots=shots,\n",
    "    ways=ways,\n",
    "    num_train_task=num_tasks,\n",
    "    num_test_task=num_test_tasks,\n",
    ")\n",
    "\n",
    "pkl_path = os.path.join(experiment_dir, 'df_exact.pkl')\n",
    "with open(pkl_path, 'wb') as f:\n",
    "    pickle.dump(df_exact, f)\n",
    "\n",
    "# NOTE(review): index_dict is not used in this cell; presumably reset for other cells - confirm.\n",
    "index_dict = dict(train_noise_tasks=[], train_shuffle_tasks=[])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "63e17ce3-6112-4b62-8592-0a91b6670d17",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 128/128 [00:23<00:00,  5.45it/s]\n"
     ]
    }
   ],
   "source": [
    "# Explain test-task performance with the OPA explainer,\n",
    "# then persist the resulting DataFrame to experiment_dir.\n",
    "df_opa = explain_test_performance(\n",
    "    explainer_opa,\n",
    "    tasks_train,\n",
    "    tasks_test,\n",
    "    preprocess=_preprocess,\n",
    "    loss=loss,\n",
    "    shots=shots,\n",
    "    ways=ways,\n",
    "    num_train_task=num_tasks,\n",
    "    num_test_task=num_test_tasks,\n",
    ")\n",
    "\n",
    "pkl_path = os.path.join(experiment_dir, 'df_opa.pkl')\n",
    "with open(pkl_path, 'wb') as f:\n",
    "    pickle.dump(df_opa, f)\n",
    "\n",
    "# NOTE(review): index_dict is not used in this cell; presumably reset for other cells - confirm.\n",
    "index_dict = dict(train_noise_tasks=[], train_shuffle_tasks=[])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "fb98a7c3-458a-4bae-bd3a-c37ebdf2c268",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>test_task_idx</th>\n",
       "      <th>test_accuracy</th>\n",
       "      <th>test_error</th>\n",
       "      <th>adaptation_accuracy</th>\n",
       "      <th>adaptation_error</th>\n",
       "      <th>train_accuracy</th>\n",
       "      <th>train_error</th>\n",
       "      <th>zeroshot_accuracy</th>\n",
       "      <th>zeroshot_error</th>\n",
       "      <th>train_task_idx</th>\n",
       "      <th>train_task_score</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>66</th>\n",
       "      <td>66</td>\n",
       "      <td>1.0</td>\n",
       "      <td>0.00016091354</td>\n",
       "      <td>0.91999996</td>\n",
       "      <td>0.3539884</td>\n",
       "      <td>0.24</td>\n",
       "      <td>0.24</td>\n",
       "      <td>0.12</td>\n",
       "      <td>4.6844196</td>\n",
       "      <td>[27, 34, 28, 43, 63, 6, 59, 55, 44, 19, 33, 37...</td>\n",
       "      <td>[0.00012355569924693555, 6.597881292691454e-05...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>105</th>\n",
       "      <td>105</td>\n",
       "      <td>1.0</td>\n",
       "      <td>0.0005200979</td>\n",
       "      <td>1.0</td>\n",
       "      <td>7.1750954e-05</td>\n",
       "      <td>0.19999999</td>\n",
       "      <td>0.19999999</td>\n",
       "      <td>0.16</td>\n",
       "      <td>3.2405376</td>\n",
       "      <td>[43, 28, 32, 38, 7, 63, 11, 49, 45, 57, 56, 48...</td>\n",
       "      <td>[0.0008775671012699604, 0.0008188745123334229,...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>77</th>\n",
       "      <td>77</td>\n",
       "      <td>1.0</td>\n",
       "      <td>0.00081781053</td>\n",
       "      <td>0.91999996</td>\n",
       "      <td>0.52522385</td>\n",
       "      <td>0.19999999</td>\n",
       "      <td>0.19999999</td>\n",
       "      <td>0.16</td>\n",
       "      <td>3.1119604</td>\n",
       "      <td>[28, 49, 57, 52, 32, 46, 26, 39, 21, 50, 51, 1...</td>\n",
       "      <td>[0.0007840418256819248, 0.0002446244761813432,...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>125</th>\n",
       "      <td>125</td>\n",
       "      <td>1.0</td>\n",
       "      <td>0.0030066997</td>\n",
       "      <td>1.0</td>\n",
       "      <td>0.010557078</td>\n",
       "      <td>0.19999999</td>\n",
       "      <td>0.19999999</td>\n",
       "      <td>0.24</td>\n",
       "      <td>3.3069751</td>\n",
       "      <td>[27, 34, 61, 28, 19, 33, 18, 31, 7, 35, 10, 47...</td>\n",
       "      <td>[0.007663601543754339, 0.002994065871462226, 0...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>72</th>\n",
       "      <td>72</td>\n",
       "      <td>1.0</td>\n",
       "      <td>0.0030586922</td>\n",
       "      <td>0.96</td>\n",
       "      <td>0.04981514</td>\n",
       "      <td>0.44</td>\n",
       "      <td>0.44</td>\n",
       "      <td>0.68</td>\n",
       "      <td>0.8835869</td>\n",
       "      <td>[43, 27, 8, 16, 63, 33, 57, 50, 7, 29, 19, 31,...</td>\n",
       "      <td>[0.017282940447330475, 0.004186374135315418, 0...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>...</th>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>34</th>\n",
       "      <td>34</td>\n",
       "      <td>0.68</td>\n",
       "      <td>1.386445</td>\n",
       "      <td>0.96</td>\n",
       "      <td>0.13553324</td>\n",
       "      <td>0.12</td>\n",
       "      <td>0.12</td>\n",
       "      <td>0.16</td>\n",
       "      <td>3.7007945</td>\n",
       "      <td>[34, 6, 39, 7, 63, 40, 3, 14, 54, 46, 42, 10, ...</td>\n",
       "      <td>[0.31956759095191956, 0.23580755293369293, 0.1...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>45</th>\n",
       "      <td>45</td>\n",
       "      <td>0.64</td>\n",
       "      <td>1.5215017</td>\n",
       "      <td>0.79999995</td>\n",
       "      <td>1.0227759</td>\n",
       "      <td>0.39999998</td>\n",
       "      <td>0.39999998</td>\n",
       "      <td>0.32</td>\n",
       "      <td>3.011552</td>\n",
       "      <td>[27, 34, 35, 29, 49, 43, 23, 33, 31, 21, 51, 5...</td>\n",
       "      <td>[0.7668546438217163, 0.17025083303451538, 0.07...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>10</th>\n",
       "      <td>10</td>\n",
       "      <td>0.64</td>\n",
       "      <td>1.8436437</td>\n",
       "      <td>0.88</td>\n",
       "      <td>0.8060948</td>\n",
       "      <td>0.16</td>\n",
       "      <td>0.16</td>\n",
       "      <td>0.19999999</td>\n",
       "      <td>2.4827554</td>\n",
       "      <td>[27, 43, 29, 63, 45, 28, 51, 26, 62, 5, 10, 11...</td>\n",
       "      <td>[0.8547983765602112, 0.5022902488708496, 0.064...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>22</th>\n",
       "      <td>22</td>\n",
       "      <td>0.64</td>\n",
       "      <td>2.2139666</td>\n",
       "      <td>0.71999997</td>\n",
       "      <td>0.8715461</td>\n",
       "      <td>0.12</td>\n",
       "      <td>0.12</td>\n",
       "      <td>0.28</td>\n",
       "      <td>4.545876</td>\n",
       "      <td>[27, 43, 28, 34, 6, 29, 63, 62, 38, 57, 59, 54...</td>\n",
       "      <td>[0.5951960682868958, 0.38794657588005066, 0.36...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>38</th>\n",
       "      <td>38</td>\n",
       "      <td>0.59999996</td>\n",
       "      <td>2.2241833</td>\n",
       "      <td>0.71999997</td>\n",
       "      <td>2.4568803</td>\n",
       "      <td>0.28</td>\n",
       "      <td>0.28</td>\n",
       "      <td>0.35999998</td>\n",
       "      <td>1.8577391</td>\n",
       "      <td>[27, 38, 48, 14, 16, 31, 26, 21, 49, 46, 57, 4...</td>\n",
       "      <td>[0.2002403885126114, 0.06898617744445801, 0.06...</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "<p>128 rows × 11 columns</p>\n",
       "</div>"
      ],
      "text/plain": [
       "     test_task_idx test_accuracy     test_error adaptation_accuracy  \\\n",
       "66              66           1.0  0.00016091354          0.91999996   \n",
       "105            105           1.0   0.0005200979                 1.0   \n",
       "77              77           1.0  0.00081781053          0.91999996   \n",
       "125            125           1.0   0.0030066997                 1.0   \n",
       "72              72           1.0   0.0030586922                0.96   \n",
       "..             ...           ...            ...                 ...   \n",
       "34              34          0.68       1.386445                0.96   \n",
       "45              45          0.64      1.5215017          0.79999995   \n",
       "10              10          0.64      1.8436437                0.88   \n",
       "22              22          0.64      2.2139666          0.71999997   \n",
       "38              38    0.59999996      2.2241833          0.71999997   \n",
       "\n",
       "    adaptation_error train_accuracy train_error zeroshot_accuracy  \\\n",
       "66         0.3539884           0.24        0.24              0.12   \n",
       "105    7.1750954e-05     0.19999999  0.19999999              0.16   \n",
       "77        0.52522385     0.19999999  0.19999999              0.16   \n",
       "125      0.010557078     0.19999999  0.19999999              0.24   \n",
       "72        0.04981514           0.44        0.44              0.68   \n",
       "..               ...            ...         ...               ...   \n",
       "34        0.13553324           0.12        0.12              0.16   \n",
       "45         1.0227759     0.39999998  0.39999998              0.32   \n",
       "10         0.8060948           0.16        0.16        0.19999999   \n",
       "22         0.8715461           0.12        0.12              0.28   \n",
       "38         2.4568803           0.28        0.28        0.35999998   \n",
       "\n",
       "    zeroshot_error                                     train_task_idx  \\\n",
       "66       4.6844196  [27, 34, 28, 43, 63, 6, 59, 55, 44, 19, 33, 37...   \n",
       "105      3.2405376  [43, 28, 32, 38, 7, 63, 11, 49, 45, 57, 56, 48...   \n",
       "77       3.1119604  [28, 49, 57, 52, 32, 46, 26, 39, 21, 50, 51, 1...   \n",
       "125      3.3069751  [27, 34, 61, 28, 19, 33, 18, 31, 7, 35, 10, 47...   \n",
       "72       0.8835869  [43, 27, 8, 16, 63, 33, 57, 50, 7, 29, 19, 31,...   \n",
       "..             ...                                                ...   \n",
       "34       3.7007945  [34, 6, 39, 7, 63, 40, 3, 14, 54, 46, 42, 10, ...   \n",
       "45        3.011552  [27, 34, 35, 29, 49, 43, 23, 33, 31, 21, 51, 5...   \n",
       "10       2.4827554  [27, 43, 29, 63, 45, 28, 51, 26, 62, 5, 10, 11...   \n",
       "22        4.545876  [27, 43, 28, 34, 6, 29, 63, 62, 38, 57, 59, 54...   \n",
       "38       1.8577391  [27, 38, 48, 14, 16, 31, 26, 21, 49, 46, 57, 4...   \n",
       "\n",
       "                                      train_task_score  \n",
       "66   [0.00012355569924693555, 6.597881292691454e-05...  \n",
       "105  [0.0008775671012699604, 0.0008188745123334229,...  \n",
       "77   [0.0007840418256819248, 0.0002446244761813432,...  \n",
       "125  [0.007663601543754339, 0.002994065871462226, 0...  \n",
       "72   [0.017282940447330475, 0.004186374135315418, 0...  \n",
       "..                                                 ...  \n",
       "34   [0.31956759095191956, 0.23580755293369293, 0.1...  \n",
       "45   [0.7668546438217163, 0.17025083303451538, 0.07...  \n",
       "10   [0.8547983765602112, 0.5022902488708496, 0.064...  \n",
       "22   [0.5951960682868958, 0.38794657588005066, 0.36...  \n",
       "38   [0.2002403885126114, 0.06898617744445801, 0.06...  \n",
       "\n",
       "[128 rows x 11 columns]"
      ]
     },
     "execution_count": 16,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Exact explainer: test tasks ordered from lowest to highest test error\n",
    "df_exact.sort_values(by='test_error')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "id": "7227c2e8-41e1-425b-a0b1-d2a9e813d838",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>test_task_idx</th>\n",
       "      <th>test_accuracy</th>\n",
       "      <th>test_error</th>\n",
       "      <th>adaptation_accuracy</th>\n",
       "      <th>adaptation_error</th>\n",
       "      <th>train_accuracy</th>\n",
       "      <th>train_error</th>\n",
       "      <th>zeroshot_accuracy</th>\n",
       "      <th>zeroshot_error</th>\n",
       "      <th>train_task_idx</th>\n",
       "      <th>train_task_score</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>66</th>\n",
       "      <td>66</td>\n",
       "      <td>1.0</td>\n",
       "      <td>0.00016091354</td>\n",
       "      <td>0.91999996</td>\n",
       "      <td>0.3539884</td>\n",
       "      <td>0.24</td>\n",
       "      <td>0.24</td>\n",
       "      <td>0.12</td>\n",
       "      <td>4.6844196</td>\n",
       "      <td>[27, 28, 36, 45, 30, 19, 21, 49, 52, 1, 53, 2,...</td>\n",
       "      <td>[2.4145695078914287e-06, 1.8882880112869316e-0...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>105</th>\n",
       "      <td>105</td>\n",
       "      <td>1.0</td>\n",
       "      <td>0.0005200979</td>\n",
       "      <td>1.0</td>\n",
       "      <td>7.1750954e-05</td>\n",
       "      <td>0.19999999</td>\n",
       "      <td>0.19999999</td>\n",
       "      <td>0.16</td>\n",
       "      <td>3.2405376</td>\n",
       "      <td>[27, 41, 38, 14, 19, 26, 40, 36, 30, 31, 21, 4...</td>\n",
       "      <td>[8.447365871688817e-06, 4.896727205050411e-06,...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>77</th>\n",
       "      <td>77</td>\n",
       "      <td>1.0</td>\n",
       "      <td>0.00081781053</td>\n",
       "      <td>0.91999996</td>\n",
       "      <td>0.52522385</td>\n",
       "      <td>0.19999999</td>\n",
       "      <td>0.19999999</td>\n",
       "      <td>0.16</td>\n",
       "      <td>3.1119604</td>\n",
       "      <td>[13, 33, 43, 7, 63, 40, 26, 20, 32, 31, 23, 19...</td>\n",
       "      <td>[4.949064532411285e-06, 2.8040387860528426e-06...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>125</th>\n",
       "      <td>125</td>\n",
       "      <td>1.0</td>\n",
       "      <td>0.0030066997</td>\n",
       "      <td>1.0</td>\n",
       "      <td>0.010557078</td>\n",
       "      <td>0.19999999</td>\n",
       "      <td>0.19999999</td>\n",
       "      <td>0.24</td>\n",
       "      <td>3.3069751</td>\n",
       "      <td>[61, 43, 29, 7, 16, 47, 56, 3, 20, 18, 58, 35,...</td>\n",
       "      <td>[3.711994213517755e-05, 8.6338804976549e-06, 2...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>72</th>\n",
       "      <td>72</td>\n",
       "      <td>1.0</td>\n",
       "      <td>0.0030586922</td>\n",
       "      <td>0.96</td>\n",
       "      <td>0.04981514</td>\n",
       "      <td>0.44</td>\n",
       "      <td>0.44</td>\n",
       "      <td>0.68</td>\n",
       "      <td>0.8835869</td>\n",
       "      <td>[8, 27, 43, 51, 56, 33, 28, 59, 40, 7, 29, 10,...</td>\n",
       "      <td>[3.949532037950121e-05, 1.6003821656340733e-05...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>...</th>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>34</th>\n",
       "      <td>34</td>\n",
       "      <td>0.68</td>\n",
       "      <td>1.386445</td>\n",
       "      <td>0.96</td>\n",
       "      <td>0.13553324</td>\n",
       "      <td>0.12</td>\n",
       "      <td>0.12</td>\n",
       "      <td>0.16</td>\n",
       "      <td>3.7007945</td>\n",
       "      <td>[28, 43, 7, 51, 31, 40, 56, 33, 59, 14, 37, 24...</td>\n",
       "      <td>[0.0060884724371135235, 0.006054930854588747, ...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>45</th>\n",
       "      <td>45</td>\n",
       "      <td>0.64</td>\n",
       "      <td>1.5215017</td>\n",
       "      <td>0.79999995</td>\n",
       "      <td>1.0227759</td>\n",
       "      <td>0.39999998</td>\n",
       "      <td>0.39999998</td>\n",
       "      <td>0.32</td>\n",
       "      <td>3.011552</td>\n",
       "      <td>[27, 63, 21, 16, 39, 44, 20, 18, 31, 34, 43, 2...</td>\n",
       "      <td>[0.004352950025349855, 0.0014733317075297236, ...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>10</th>\n",
       "      <td>10</td>\n",
       "      <td>0.64</td>\n",
       "      <td>1.8436437</td>\n",
       "      <td>0.88</td>\n",
       "      <td>0.8060948</td>\n",
       "      <td>0.16</td>\n",
       "      <td>0.16</td>\n",
       "      <td>0.19999999</td>\n",
       "      <td>2.4827554</td>\n",
       "      <td>[27, 43, 40, 63, 48, 10, 30, 50, 38, 7, 20, 11...</td>\n",
       "      <td>[0.017162228003144264, 0.00560997799038887, 0....</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>22</th>\n",
       "      <td>22</td>\n",
       "      <td>0.64</td>\n",
       "      <td>2.2139666</td>\n",
       "      <td>0.71999997</td>\n",
       "      <td>0.8715461</td>\n",
       "      <td>0.12</td>\n",
       "      <td>0.12</td>\n",
       "      <td>0.28</td>\n",
       "      <td>4.545876</td>\n",
       "      <td>[27, 28, 49, 14, 46, 51, 54, 59, 0, 55, 16, 29...</td>\n",
       "      <td>[0.012110484763979912, 0.007450949400663376, 0...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>38</th>\n",
       "      <td>38</td>\n",
       "      <td>0.59999996</td>\n",
       "      <td>2.2241833</td>\n",
       "      <td>0.71999997</td>\n",
       "      <td>2.4568803</td>\n",
       "      <td>0.28</td>\n",
       "      <td>0.28</td>\n",
       "      <td>0.35999998</td>\n",
       "      <td>1.8577391</td>\n",
       "      <td>[11, 38, 47, 6, 18, 31, 29, 37, 15, 25, 34, 57...</td>\n",
       "      <td>[0.0009661400108598173, 0.0008226876379922032,...</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "<p>128 rows × 11 columns</p>\n",
       "</div>"
      ],
      "text/plain": [
       "     test_task_idx test_accuracy     test_error adaptation_accuracy  \\\n",
       "66              66           1.0  0.00016091354          0.91999996   \n",
       "105            105           1.0   0.0005200979                 1.0   \n",
       "77              77           1.0  0.00081781053          0.91999996   \n",
       "125            125           1.0   0.0030066997                 1.0   \n",
       "72              72           1.0   0.0030586922                0.96   \n",
       "..             ...           ...            ...                 ...   \n",
       "34              34          0.68       1.386445                0.96   \n",
       "45              45          0.64      1.5215017          0.79999995   \n",
       "10              10          0.64      1.8436437                0.88   \n",
       "22              22          0.64      2.2139666          0.71999997   \n",
       "38              38    0.59999996      2.2241833          0.71999997   \n",
       "\n",
       "    adaptation_error train_accuracy train_error zeroshot_accuracy  \\\n",
       "66         0.3539884           0.24        0.24              0.12   \n",
       "105    7.1750954e-05     0.19999999  0.19999999              0.16   \n",
       "77        0.52522385     0.19999999  0.19999999              0.16   \n",
       "125      0.010557078     0.19999999  0.19999999              0.24   \n",
       "72        0.04981514           0.44        0.44              0.68   \n",
       "..               ...            ...         ...               ...   \n",
       "34        0.13553324           0.12        0.12              0.16   \n",
       "45         1.0227759     0.39999998  0.39999998              0.32   \n",
       "10         0.8060948           0.16        0.16        0.19999999   \n",
       "22         0.8715461           0.12        0.12              0.28   \n",
       "38         2.4568803           0.28        0.28        0.35999998   \n",
       "\n",
       "    zeroshot_error                                     train_task_idx  \\\n",
       "66       4.6844196  [27, 28, 36, 45, 30, 19, 21, 49, 52, 1, 53, 2,...   \n",
       "105      3.2405376  [27, 41, 38, 14, 19, 26, 40, 36, 30, 31, 21, 4...   \n",
       "77       3.1119604  [13, 33, 43, 7, 63, 40, 26, 20, 32, 31, 23, 19...   \n",
       "125      3.3069751  [61, 43, 29, 7, 16, 47, 56, 3, 20, 18, 58, 35,...   \n",
       "72       0.8835869  [8, 27, 43, 51, 56, 33, 28, 59, 40, 7, 29, 10,...   \n",
       "..             ...                                                ...   \n",
       "34       3.7007945  [28, 43, 7, 51, 31, 40, 56, 33, 59, 14, 37, 24...   \n",
       "45        3.011552  [27, 63, 21, 16, 39, 44, 20, 18, 31, 34, 43, 2...   \n",
       "10       2.4827554  [27, 43, 40, 63, 48, 10, 30, 50, 38, 7, 20, 11...   \n",
       "22        4.545876  [27, 28, 49, 14, 46, 51, 54, 59, 0, 55, 16, 29...   \n",
       "38       1.8577391  [11, 38, 47, 6, 18, 31, 29, 37, 15, 25, 34, 57...   \n",
       "\n",
       "                                      train_task_score  \n",
       "66   [2.4145695078914287e-06, 1.8882880112869316e-0...  \n",
       "105  [8.447365871688817e-06, 4.896727205050411e-06,...  \n",
       "77   [4.949064532411285e-06, 2.8040387860528426e-06...  \n",
       "125  [3.711994213517755e-05, 8.6338804976549e-06, 2...  \n",
       "72   [3.949532037950121e-05, 1.6003821656340733e-05...  \n",
       "..                                                 ...  \n",
       "34   [0.0060884724371135235, 0.006054930854588747, ...  \n",
       "45   [0.004352950025349855, 0.0014733317075297236, ...  \n",
       "10   [0.017162228003144264, 0.00560997799038887, 0....  \n",
       "22   [0.012110484763979912, 0.007450949400663376, 0...  \n",
       "38   [0.0009661400108598173, 0.0008226876379922032,...  \n",
       "\n",
       "[128 rows x 11 columns]"
      ]
     },
     "execution_count": 17,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# OPA explainer: test tasks ordered from lowest to highest test error\n",
    "df_opa.sort_values(by='test_error')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "id": "939a6c70-e16d-4495-a93f-3d6d6f18ac4e",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.3852857798341467"
      ]
     },
     "execution_count": 28,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Compare the exact and OPA train-task scores for one test task via _series_correlation.\n",
    "# Select by index label (which equals test_task_idx in both frames) rather than by\n",
    "# position, so the two rows stay paired even if either frame is re-sorted or filtered.\n",
    "example_task = 105\n",
    "_series_correlation(df_exact.loc[example_task], df_opa.loc[example_task])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "id": "8cbe5224-5098-4935-81f9-697d7d8cdc48",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Correlation between exact and OPA explanations (presumably one row per test task - confirm)\n",
    "corr = df_correlation(df_exact, df_opa)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "id": "006fef99-ffb7-473d-b8a3-7b4ec996ef59",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.363060957322224\n",
      "0.3735469981938026\n"
     ]
    }
   ],
   "source": [
    "# Mean, then standard deviation, of the 'corr' column\n",
    "corr_scores = corr['corr']\n",
    "print(corr_scores.mean())\n",
    "print(corr_scores.std())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "25e940e2-c639-457a-8981-e9546926ad15",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.15"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
