{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "458ac6f6",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os, copy\n",
    "import pandas as pd\n",
    "import numpy as np"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "b6f74a18",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "98"
      ]
     },
     "execution_count": 1,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import json\n",
    "with open(f'/home/SemiTab/dataset_id.json', 'r') as file:\n",
    "    data_info = json.load(file)\n",
    "    \n",
    "datalist = \"4135 1555 4538 23512 40499 40536 40685 454 43044 44129 44131 44157 44158 44159 44160 44161 44162 45060 45062 45548 41162 44089 44090 44091 44122 44123 44124 44125 44126 5 40672 43986 45068 41275 461 31 1549 452 25 470 475 846 934 1043 1063 1067 1113 1169 1459 1462 1464 1466 1467 1471 1475 1479 1486 1487 1489 1492 1493 1494 1497 1504 1509 1510 1531 35 36 37 54 59 150 151 182 185 188 307 313 551 51 40900 40981 40985 41143 41145 41147 41150 41168 41169 41960 42345 42734 338 23 1476 45714 53 337 372 455 458 29 49 466 42665 12 14 16 18 22 32 48 1503 4153 40922 42931\"\n",
    "datalist = datalist.split(\" \")\n",
    "datalist = [eval(i) for i in datalist if eval(i) not in [\n",
    "    5, 41960, 313, 551, ## class imbalance\n",
    "    43986, 372, 41275, 43044, 40672, 1113, 40685, 44159, 1169, 150, 44129, ## too large datasets\n",
    "    25, 461, 466, 42665, ## too small -- stunt fails\n",
    "]]\n",
    "len(datalist)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "badd5c41",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'4135 1555 4538 23512 40499 40536 454 44131 44157 44158 44160 44161 44162 45060 45062 45548 41162 44089 44090 44091 44122 44123 44124 44125 44126 45068 31 1549 452 470 475 846 934 1043 1063 1067 1459 1462 1464 1466 1467 1471 1475 1479 1486 1487 1489 1492 1493 1494 1497 1504 1509 1510 1531 35 36 37 54 59 151 182 185 188 307 51 40900 40981 40985 41143 41145 41147 41150 41168 41169 42345 42734 338 23 1476 45714 53 337 455 458 29 49 12 14 16 18 22 32 48 1503 4153 40922 42931'"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "\" \".join(map(str, datalist))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "d9672674",
   "metadata": {},
   "outputs": [],
   "source": [
    "models = [\"lr\", \"knn\", \"xgboost\", \"catboost\", \"lightgbm\", \"mlp\", \n",
    "          \"ae\", \"ict\", \"meanteacher\", \"sslrecon\", \"sslbinning\", \"sslsubtab\", \n",
    "          \"sslvime\", \"semivime\", \"sslscarf\", \"sslbinshuffling\", \"sslbinsampling\",\n",
    "          \"sslmasking\", \"sslshuffling\", \"sslnoisemasking\", \"sslrq\",\n",
    "          \"pseudolabel-masking\", \"pseudolabel-shuffling\", \"pseudolabel-noisemasking\", \"pseudolabel-rq\",\n",
    "          \"stunt\", \"tabpfn\", \"pseudolabel-binshuffling\", \"pseudolabel-binsampling\"]\n",
    "len(models)\n",
    "\n",
    "clist = []\n",
    "for m in models:\n",
    "    if m.startswith(\"ssl\"):\n",
    "        clist.append(f'{m}-lr')\n",
    "        clist.append(f'{m}-knn')\n",
    "        clist.append(f'{m}-lineareval')\n",
    "        clist.append(f'{m}-finetuning')\n",
    "    else:\n",
    "        clist.append(m)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "614bf7a7",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "2fbba158",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "182280\n"
     ]
    }
   ],
   "source": [
    "all_comb = (11*4 + 18) * len(datalist) * 3 * 10\n",
    "print(all_comb)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "a70d17c9",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "182903\n",
      "(183234, 6)\n",
      "ALL: 182280 .. DONE: 183234 (100.523 %)\n",
      "(61504, 6)\n",
      "(61494, 6)\n",
      "(60236, 6)\n"
     ]
    }
   ],
   "source": [
    "result_file = 'result.csv'\n",
    "if os.path.exists(result_file):\n",
    "    existing_result = pd.read_csv(result_file)\n",
    "    existing_records = set(zip(existing_result['trial'], existing_result['shots'], \n",
    "                               existing_result['data_id'], existing_result['model']))\n",
    "else:\n",
    "    existing_result = pd.DataFrame(columns=(\"trial\", \"shots\", \"data_id\", \"model\", \"acc\", \"auroc\"))\n",
    "    existing_records = set()\n",
    "\n",
    "result = existing_result.copy()\n",
    "\n",
    "i = len(result)\n",
    "print(len(result[(result[\"shots\"] < 100) & (result[\"data_id\"].isin(datalist))]))\n",
    "\n",
    "for root, dirs, files in os.walk(\"results\"):\n",
    "    for file in files:\n",
    "        if file == \"performance.npy\":\n",
    "            fname = os.path.join(root, file)\n",
    "            _, seed, shot, modelname, data = root.split(\"/\")\n",
    "            seed = eval(seed.split(\"=\")[-1])\n",
    "            shot = eval(shot.split(\"=\")[-1])\n",
    "            data = eval(data.split(\"=\")[-1])\n",
    "            modelname = modelname.split(\"=\")[-1]\n",
    "            \n",
    "            if shot < 100:\n",
    "                perf = np.load(fname, allow_pickle=True).item()\n",
    "\n",
    "                if modelname.startswith(\"ssl\"):\n",
    "                    for e in [\"lr\", \"knn\", \"lineareval\", \"finetuning\"]:\n",
    "                        record_key = (seed, shot, data, f'{modelname}-{e}')\n",
    "                        if record_key not in existing_records:\n",
    "                            tmp = perf[\"Test\"].get(e, [None, None])\n",
    "                            result.loc[i] = [seed, shot, data, f'{modelname}-{e}', tmp[0], tmp[1]]\n",
    "                            i += 1\n",
    "                elif modelname == \"stunt\":\n",
    "                    record_key = (seed, shot, data, modelname)\n",
    "                    if record_key not in existing_records:\n",
    "                        result.loc[i] = [seed, shot, data, modelname, perf[\"Test\"], np.nan]\n",
    "                        i += 1\n",
    "                else:\n",
    "                    record_key = (seed, shot, data, modelname)\n",
    "                    if record_key not in existing_records:\n",
    "                        result.loc[i] = [seed, shot, data, modelname, perf[\"Test\"][0], perf[\"Test\"][1]]\n",
    "                        i += 1\n",
    "\n",
    "result = result.drop_duplicates()\n",
    "result.to_csv(result_file, index=False)\n",
    "\n",
    "result = result[result[\"shots\"] < 100]\n",
    "result = result[result[\"data_id\"].isin(datalist)]\n",
    "print(result.shape)\n",
    "print(f'ALL: {all_comb} .. DONE: {len(result)} ({(len(result) * 100 / all_comb):.3f} %)')\n",
    "\n",
    "print(result[result[\"shots\"] == 1].shape) # (60760)\n",
    "print(result[result[\"shots\"] == 5].shape) # (60760)\n",
    "print(result[result[\"shots\"] == 10].shape) # (60760)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 311,
   "id": "d9c8b9f5",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(181031, 6)"
      ]
     },
     "execution_count": 311,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "clist2 = [c for c in clist if not c.startswith(\"pseudolabel\")]\n",
    "result2 = result[result[\"model\"].isin(clist)]\n",
    "result2.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "f39c2da2",
   "metadata": {},
   "outputs": [],
   "source": [
    "p_grouped = result.groupby([\"data_id\", \"shots\", \"model\"]).agg(\n",
    "    acc_mean=('acc', 'mean'),\n",
    "    acc_count=('acc', lambda x: x.notna().sum())\n",
    ").reset_index()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "d716a6da",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "5 1492 70 /// 10 1492 0 /// 10 1493 0 /// 10 41169 62 /// "
     ]
    }
   ],
   "source": [
    "import copy\n",
    "rankdata = copy.deepcopy(p_grouped)\n",
    "\n",
    "for shot in [1, 5, 10]:\n",
    "    for data in datalist:\n",
    "#         p = result[(result[\"trial\"] == 0) & (result[\"shots\"] == shot) & (result[\"data_id\"] == data)]\n",
    "        p = p_grouped[(p_grouped[\"data_id\"] == data) & (p_grouped[\"shots\"] == shot)]\n",
    "        p = p.reset_index(drop=True)\n",
    "    \n",
    "        if len(p[p[\"acc_count\"] == 10]) == len(clist):\n",
    "            p['rank'] = p[\"acc_mean\"].rank(ascending=False, method=\"min\")\n",
    "\n",
    "            for i, row in p.iterrows():\n",
    "                rankdata.loc[(rankdata[\"data_id\"] == row[\"data_id\"]) & \n",
    "                             (rankdata[\"shots\"] == shot) & \n",
    "                             (rankdata[\"acc_mean\"] == row[\"acc_mean\"]), 'rank'] = row['rank']\n",
    "        else:\n",
    "            print(shot, data, len(p), end=\" /// \")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 48,
   "id": "8dff2500",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0 catboost\n",
      "1 \n",
      "2 \n",
      "3 \n",
      "4 catboost\n",
      "5 \n",
      "6 catboost\n",
      "7 catboost\n",
      "8 catboost\n",
      "9 catboost\n"
     ]
    }
   ],
   "source": [
    "for t in range(10):\n",
    "    print(t, \" \".join(map(str, [i for i in clist if i not in result[(result[\"shots\"] == 5) & (result[\"data_id\"] == 1492) & (result[\"trial\"] == t)][\"model\"].unique()])))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "eb3dea9d",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8a973675",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3d8a8138",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 55,
   "id": "5dc46fc1",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([   12,    18,    23,    31,    32,    35,    36,    48,    54,\n",
       "          59,   337,   846,  1067,  1459,  1462,  1467,  1476,  1479,\n",
       "        1549,  4135,  4153, 23512, 40499, 40900, 44089, 44126, 44131])"
      ]
     },
     "execution_count": 55,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# exclude = rankdata[rankdata[\"model\"].str.startswith(\"lr\") & (rankdata[\"rank\"] < 5) & (rankdata[\"shots\"] == 1)][\"data_id\"].unique()\n",
    "dt = rankdata[rankdata[\"shots\"] == 5]\n",
    "dt[dt[\"model\"].str.startswith(\"sslbin\") & (dt[\"rank\"] < 5)][\"data_id\"].unique()\n",
    "# exclude = dt[((~dt[\"model\"].str.startswith(\"sslbin\")) & (dt[\"rank\"] < 10)) & (dt[\"model\"].str.startswith(\"sslbin\") & (dt[\"rank\"] > 50))][\"data_id\"].unique()\n",
    "# dlist = [d for d in datalist if d not in exclude]\n",
    "# len(dlist)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 340,
   "id": "5c5877b1",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "50\n",
      "                    data_id  shots  acc_mean  acc_count   rank\n",
      "model                                                         \n",
      "sslbinshuffling-lr   9735.8    1.0  0.446020       10.0  18.22\n",
      "knn                  9735.8    1.0  0.440147       10.0  18.70\n",
      "lr                   9735.8    1.0  0.434957       10.0  18.88\n",
      "                        data_id  shots  acc_mean  acc_count       rank\n",
      "model                                                                 \n",
      "lr                  9904.040816    5.0  0.534482       10.0  15.836735\n",
      "mlp                 9904.040816    5.0  0.543242       10.0  16.306122\n",
      "sslbinsampling-lr   9904.040816    5.0  0.543098       10.0  16.428571\n",
      "sslnoisemasking-lr  9904.040816    5.0  0.543580       10.0  16.428571\n",
      "sslshuffling-lr     9904.040816    5.0  0.542854       10.0  17.061224\n"
     ]
    }
   ],
   "source": [
    "dlist = [22, 31, 32, 48, 54, 337, 846, 1063, 1067, 1459, 1467, 1479, 1549, 4135, 23512, 40499, 40900,\n",
    "         44125, 44131, 44157, 44160,\n",
    "         1555, 454, 45062, 4153, 455, 188, 1531, 1497, 1492, 1489, 1475, 470, \n",
    "         4135, 23, 12, 307, 1493, 1503, 1509, 41162, 452, 338, 40536, 40985, 35, 59,\n",
    "         1043, 1487, 934, 18]\n",
    "# len(dlist)\n",
    "\n",
    "q = rankdata.loc[(~rankdata[\"rank\"].isna()) & (rankdata[\"data_id\"].isin(dlist))]\n",
    "print(len(q[\"data_id\"].unique()))\n",
    "print(q[q[\"shots\"] == 1].groupby([\"model\"]).mean(\"rank\").sort_values([\"shots\", \"rank\"]).head(3))\n",
    "print(q[q[\"shots\"] == 5].groupby([\"model\"]).mean(\"rank\").sort_values([\"shots\", \"rank\"]).head(5))\n",
    "# q[q[\"shots\"] == 5].groupby([\"model\"]).mean(\"rank\").sort_values([\"shots\", \"rank\"])\n",
    "# q[q[\"shots\"] == 5].groupby([\"model\"]).mean(\"acc\").sort_values(\"acc\", ascending=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 342,
   "id": "358e33be",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([  452,   455,   846,  1479,  1531,  4135, 23512, 44157])"
      ]
     },
     "execution_count": 342,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "q[(q[\"model\"] == \"sslbinsampling-lr\") & (q[\"shots\"] == 5) & (q[\"rank\"] > 30)][\"data_id\"].unique()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 355,
   "id": "7545835a",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>trial</th>\n",
       "      <th>shots</th>\n",
       "      <th>data_id</th>\n",
       "      <th>model</th>\n",
       "      <th>acc</th>\n",
       "      <th>auroc</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>167633</th>\n",
       "      <td>6</td>\n",
       "      <td>5</td>\n",
       "      <td>846</td>\n",
       "      <td>sslbinsampling-lr</td>\n",
       "      <td>0.443373</td>\n",
       "      <td>0.471600</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>166900</th>\n",
       "      <td>1</td>\n",
       "      <td>5</td>\n",
       "      <td>846</td>\n",
       "      <td>sslbinsampling-lr</td>\n",
       "      <td>0.573494</td>\n",
       "      <td>0.733817</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>167468</th>\n",
       "      <td>4</td>\n",
       "      <td>5</td>\n",
       "      <td>846</td>\n",
       "      <td>sslbinsampling-lr</td>\n",
       "      <td>0.616265</td>\n",
       "      <td>0.699780</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>166712</th>\n",
       "      <td>8</td>\n",
       "      <td>5</td>\n",
       "      <td>846</td>\n",
       "      <td>sslbinsampling-lr</td>\n",
       "      <td>0.622289</td>\n",
       "      <td>0.661146</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>153694</th>\n",
       "      <td>7</td>\n",
       "      <td>5</td>\n",
       "      <td>846</td>\n",
       "      <td>sslbinsampling-lr</td>\n",
       "      <td>0.641265</td>\n",
       "      <td>0.616412</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>167179</th>\n",
       "      <td>3</td>\n",
       "      <td>5</td>\n",
       "      <td>846</td>\n",
       "      <td>sslbinsampling-lr</td>\n",
       "      <td>0.656627</td>\n",
       "      <td>0.701194</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>167536</th>\n",
       "      <td>5</td>\n",
       "      <td>5</td>\n",
       "      <td>846</td>\n",
       "      <td>sslbinsampling-lr</td>\n",
       "      <td>0.691867</td>\n",
       "      <td>0.735567</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>166801</th>\n",
       "      <td>9</td>\n",
       "      <td>5</td>\n",
       "      <td>846</td>\n",
       "      <td>sslbinsampling-lr</td>\n",
       "      <td>0.714458</td>\n",
       "      <td>0.674903</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>167070</th>\n",
       "      <td>2</td>\n",
       "      <td>5</td>\n",
       "      <td>846</td>\n",
       "      <td>sslbinsampling-lr</td>\n",
       "      <td>0.731325</td>\n",
       "      <td>0.736130</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>133895</th>\n",
       "      <td>0</td>\n",
       "      <td>5</td>\n",
       "      <td>846</td>\n",
       "      <td>sslbinsampling-lr</td>\n",
       "      <td>0.756627</td>\n",
       "      <td>0.789028</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "        trial  shots  data_id              model       acc     auroc\n",
       "167633      6      5      846  sslbinsampling-lr  0.443373  0.471600\n",
       "166900      1      5      846  sslbinsampling-lr  0.573494  0.733817\n",
       "167468      4      5      846  sslbinsampling-lr  0.616265  0.699780\n",
       "166712      8      5      846  sslbinsampling-lr  0.622289  0.661146\n",
       "153694      7      5      846  sslbinsampling-lr  0.641265  0.616412\n",
       "167179      3      5      846  sslbinsampling-lr  0.656627  0.701194\n",
       "167536      5      5      846  sslbinsampling-lr  0.691867  0.735567\n",
       "166801      9      5      846  sslbinsampling-lr  0.714458  0.674903\n",
       "167070      2      5      846  sslbinsampling-lr  0.731325  0.736130\n",
       "133895      0      5      846  sslbinsampling-lr  0.756627  0.789028"
      ]
     },
     "execution_count": 355,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "result[(result[\"model\"] == \"sslbinsampling-lr\") & (result[\"shots\"] == 5) & (result[\"data_id\"] == 846)].sort_values(\"acc\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 345,
   "id": "ab0bb650",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>trial</th>\n",
       "      <th>shots</th>\n",
       "      <th>data_id</th>\n",
       "      <th>model</th>\n",
       "      <th>acc</th>\n",
       "      <th>auroc</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>129915</th>\n",
       "      <td>0</td>\n",
       "      <td>5</td>\n",
       "      <td>452</td>\n",
       "      <td>sslbinsampling-lr</td>\n",
       "      <td>0.092593</td>\n",
       "      <td>0.514827</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>153682</th>\n",
       "      <td>7</td>\n",
       "      <td>5</td>\n",
       "      <td>452</td>\n",
       "      <td>sslbinsampling-lr</td>\n",
       "      <td>0.203704</td>\n",
       "      <td>0.523021</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>162261</th>\n",
       "      <td>3</td>\n",
       "      <td>5</td>\n",
       "      <td>452</td>\n",
       "      <td>sslbinsampling-lr</td>\n",
       "      <td>0.222222</td>\n",
       "      <td>0.606719</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>162377</th>\n",
       "      <td>5</td>\n",
       "      <td>5</td>\n",
       "      <td>452</td>\n",
       "      <td>sslbinsampling-lr</td>\n",
       "      <td>0.259259</td>\n",
       "      <td>0.611190</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>162425</th>\n",
       "      <td>2</td>\n",
       "      <td>5</td>\n",
       "      <td>452</td>\n",
       "      <td>sslbinsampling-lr</td>\n",
       "      <td>0.203704</td>\n",
       "      <td>0.491740</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>162470</th>\n",
       "      <td>4</td>\n",
       "      <td>5</td>\n",
       "      <td>452</td>\n",
       "      <td>sslbinsampling-lr</td>\n",
       "      <td>0.111111</td>\n",
       "      <td>0.505329</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>162522</th>\n",
       "      <td>8</td>\n",
       "      <td>5</td>\n",
       "      <td>452</td>\n",
       "      <td>sslbinsampling-lr</td>\n",
       "      <td>0.166667</td>\n",
       "      <td>0.531689</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>162590</th>\n",
       "      <td>9</td>\n",
       "      <td>5</td>\n",
       "      <td>452</td>\n",
       "      <td>sslbinsampling-lr</td>\n",
       "      <td>0.037037</td>\n",
       "      <td>0.464545</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>162682</th>\n",
       "      <td>1</td>\n",
       "      <td>5</td>\n",
       "      <td>452</td>\n",
       "      <td>sslbinsampling-lr</td>\n",
       "      <td>0.203704</td>\n",
       "      <td>0.584176</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>162861</th>\n",
       "      <td>6</td>\n",
       "      <td>5</td>\n",
       "      <td>452</td>\n",
       "      <td>sslbinsampling-lr</td>\n",
       "      <td>0.148148</td>\n",
       "      <td>0.490353</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "        trial  shots  data_id              model       acc     auroc\n",
       "129915      0      5      452  sslbinsampling-lr  0.092593  0.514827\n",
       "153682      7      5      452  sslbinsampling-lr  0.203704  0.523021\n",
       "162261      3      5      452  sslbinsampling-lr  0.222222  0.606719\n",
       "162377      5      5      452  sslbinsampling-lr  0.259259  0.611190\n",
       "162425      2      5      452  sslbinsampling-lr  0.203704  0.491740\n",
       "162470      4      5      452  sslbinsampling-lr  0.111111  0.505329\n",
       "162522      8      5      452  sslbinsampling-lr  0.166667  0.531689\n",
       "162590      9      5      452  sslbinsampling-lr  0.037037  0.464545\n",
       "162682      1      5      452  sslbinsampling-lr  0.203704  0.584176\n",
       "162861      6      5      452  sslbinsampling-lr  0.148148  0.490353"
      ]
     },
     "execution_count": 345,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "result[~((result[\"model\"] == \"sslbinsampling-lr\") & (result[\"shots\"] == 5) & (result[\"data_id\"] == 452))]\n",
    "result.to_csv(\"/home/SemiTab/result.csv\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c8813329",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 332,
   "id": "1a5a6f08",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>data_id</th>\n",
       "      <th>shots</th>\n",
       "      <th>model</th>\n",
       "      <th>acc_mean</th>\n",
       "      <th>acc_count</th>\n",
       "      <th>rank</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>625</th>\n",
       "      <td>18</td>\n",
       "      <td>5</td>\n",
       "      <td>lr</td>\n",
       "      <td>0.366000</td>\n",
       "      <td>10</td>\n",
       "      <td>32.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>6019</th>\n",
       "      <td>934</td>\n",
       "      <td>5</td>\n",
       "      <td>lr</td>\n",
       "      <td>0.827155</td>\n",
       "      <td>10</td>\n",
       "      <td>21.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>7321</th>\n",
       "      <td>1466</td>\n",
       "      <td>5</td>\n",
       "      <td>lr</td>\n",
       "      <td>0.428169</td>\n",
       "      <td>10</td>\n",
       "      <td>22.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>13149</th>\n",
       "      <td>41147</td>\n",
       "      <td>5</td>\n",
       "      <td>lr</td>\n",
       "      <td>0.543635</td>\n",
       "      <td>10</td>\n",
       "      <td>22.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>14079</th>\n",
       "      <td>42345</td>\n",
       "      <td>5</td>\n",
       "      <td>lr</td>\n",
       "      <td>0.446631</td>\n",
       "      <td>10</td>\n",
       "      <td>23.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>14265</th>\n",
       "      <td>42734</td>\n",
       "      <td>5</td>\n",
       "      <td>lr</td>\n",
       "      <td>0.420026</td>\n",
       "      <td>10</td>\n",
       "      <td>24.0</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "       data_id  shots model  acc_mean  acc_count  rank\n",
       "625         18      5    lr  0.366000         10  32.0\n",
       "6019       934      5    lr  0.827155         10  21.0\n",
       "7321      1466      5    lr  0.428169         10  22.0\n",
       "13149    41147      5    lr  0.543635         10  22.0\n",
       "14079    42345      5    lr  0.446631         10  23.0\n",
       "14265    42734      5    lr  0.420026         10  24.0"
      ]
     },
     "execution_count": 332,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "rankdata[(rankdata[\"model\"] == \"lr\") & (rankdata[\"rank\"] > 20) & (rankdata[\"shots\"] == 5) & (~rankdata[\"data_id\"].isin(dlist))]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "379da867",
   "metadata": {},
   "outputs": [],
   "source": [
    "(2) 17.12 (2) 14.40"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5d7077e8",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4c2347ea",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3eb84b94",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fe00a86c",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "47f6887a",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5231d81b",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 77,
   "id": "7fc7d498",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'1555 23512 40499 40536 454 44131 44158 44161 44162 45060 45062 45548 41162 44089 44090 44091 44124 44125 44126 45068 846 934 1067 1459 1462 1464 1466 1486 1489 1492 1493 1497 1504 1509 1510 1531 35 36 54 151 182 185 307 41145 41147 41150 41168 41169 42345 42734 23 1476 45714 458 42665 12 14 16 18 22 32 1503 4153 40922 42931'"
      ]
     },
     "execution_count": 77,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "\" \".join(map(str, dlist))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6edbfbb6",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 60,
   "id": "eacb64c7",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([   48,  1549,  1467,    53,   455,  4538, 41143,  1487,  1043,\n",
       "        1494, 40900, 44123,  1475, 44157, 44160,   337,    59,   188,\n",
       "          29, 40981,   470,   452,   338,    49,   475,    37,    51,\n",
       "        1063,  4135, 40985,    31,  1479,  1471, 44122])"
      ]
     },
     "execution_count": 60,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "z = rankdata.dropna().reset_index(drop=True)\n",
    "z[(z[\"model\"] == \"sslbinsampling-lr\") & (z[\"rank\"] > 30)][\"data_id\"].unique()\n",
    "# z[z[\"model\"].str.startswith(\"pseudolabel-binshuffling\") & (z[\"rank\"] < 30)][\"data_id\"].unique()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f0760803",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1076f7f6",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 113,
   "id": "ff867c59",
   "metadata": {},
   "outputs": [],
   "source": [
    "import copy\n",
    "rankdata = copy.deepcopy(p_grouped)\n",
    "\n",
    "for shot in [1, 5]:\n",
    "    for data in datalist:\n",
    "        p = p_grouped[(p_grouped[\"data_id\"] == data) & (p_grouped[\"shots\"] == shot) & (~p_grouped[\"acc_mean\"].isna())]\n",
    "        p = p.reset_index(drop=True)\n",
    "        \n",
    "        if len(p) == 43:\n",
    "            p['rank'] = p[\"acc_mean\"].rank(ascending=False, method=\"min\")\n",
    "\n",
    "            for i, row in p.iterrows():\n",
    "                rankdata.loc[(rankdata[\"data_id\"] == row[\"data_id\"]) & \n",
    "                             (rankdata[\"shots\"] == shot) & \n",
    "                             (rankdata[\"acc_mean\"] == row[\"acc_mean\"]), 'rank'] = row['rank']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 226,
   "id": "dc7196ac",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "84\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>data_id</th>\n",
       "      <th>shots</th>\n",
       "      <th>acc_mean</th>\n",
       "      <th>acc_count</th>\n",
       "      <th>rank</th>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>model</th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>catboost</th>\n",
       "      <td>14785.072727</td>\n",
       "      <td>10.0</td>\n",
       "      <td>0.655651</td>\n",
       "      <td>10.000000</td>\n",
       "      <td>10.472727</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>tabpfn</th>\n",
       "      <td>14785.072727</td>\n",
       "      <td>10.0</td>\n",
       "      <td>0.615378</td>\n",
       "      <td>10.000000</td>\n",
       "      <td>12.381818</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>sslnoisemasking-lr</th>\n",
       "      <td>14785.072727</td>\n",
       "      <td>10.0</td>\n",
       "      <td>0.639171</td>\n",
       "      <td>7.600000</td>\n",
       "      <td>13.254545</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>lr</th>\n",
       "      <td>14785.072727</td>\n",
       "      <td>10.0</td>\n",
       "      <td>0.626370</td>\n",
       "      <td>10.000000</td>\n",
       "      <td>13.727273</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>mlp</th>\n",
       "      <td>14785.072727</td>\n",
       "      <td>10.0</td>\n",
       "      <td>0.638485</td>\n",
       "      <td>10.000000</td>\n",
       "      <td>13.727273</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>sslscarf-lr</th>\n",
       "      <td>14785.072727</td>\n",
       "      <td>10.0</td>\n",
       "      <td>0.636122</td>\n",
       "      <td>9.563636</td>\n",
       "      <td>13.763636</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>xgboost</th>\n",
       "      <td>14785.072727</td>\n",
       "      <td>10.0</td>\n",
       "      <td>0.640601</td>\n",
       "      <td>10.000000</td>\n",
       "      <td>14.218182</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>sslshuffling-lr</th>\n",
       "      <td>14785.072727</td>\n",
       "      <td>10.0</td>\n",
       "      <td>0.635215</td>\n",
       "      <td>7.618182</td>\n",
       "      <td>14.581818</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>ae</th>\n",
       "      <td>14785.072727</td>\n",
       "      <td>10.0</td>\n",
       "      <td>0.632182</td>\n",
       "      <td>10.000000</td>\n",
       "      <td>15.490909</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>pseudolabel-masking</th>\n",
       "      <td>14785.072727</td>\n",
       "      <td>10.0</td>\n",
       "      <td>0.627835</td>\n",
       "      <td>10.000000</td>\n",
       "      <td>16.363636</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "                          data_id  shots  acc_mean  acc_count       rank\n",
       "model                                                                   \n",
       "catboost             14785.072727   10.0  0.655651  10.000000  10.472727\n",
       "tabpfn               14785.072727   10.0  0.615378  10.000000  12.381818\n",
       "sslnoisemasking-lr   14785.072727   10.0  0.639171   7.600000  13.254545\n",
       "lr                   14785.072727   10.0  0.626370  10.000000  13.727273\n",
       "mlp                  14785.072727   10.0  0.638485  10.000000  13.727273\n",
       "sslscarf-lr          14785.072727   10.0  0.636122   9.563636  13.763636\n",
       "xgboost              14785.072727   10.0  0.640601  10.000000  14.218182\n",
       "sslshuffling-lr      14785.072727   10.0  0.635215   7.618182  14.581818\n",
       "ae                   14785.072727   10.0  0.632182  10.000000  15.490909\n",
       "pseudolabel-masking  14785.072727   10.0  0.627835  10.000000  16.363636"
      ]
     },
     "execution_count": 226,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Per-model summary for one shot setting: mean accuracy, mean run count and mean rank; best 10 by rank.\n",
    "# Only the metric columns are averaged (a mean over data_id or shots is meaningless), and GroupBy.mean\n",
    "# is called without positional arguments because its first positional parameter is numeric_only.\n",
    "SUMMARY_SHOTS = 10  # the ranking cell (id ff867c59) must have ranked this shot setting\n",
    "\n",
    "q = rankdata.loc[rankdata[\"rank\"].notna() & rankdata[\"data_id\"].isin(datalist)]\n",
    "print(q[\"data_id\"].nunique())  # datasets with a ranked result in any shot setting\n",
    "# To order by accuracy instead, use .sort_values(\"acc_mean\", ascending=False).\n",
    "(\n",
    "    q[q[\"shots\"] == SUMMARY_SHOTS]\n",
    "    .groupby(\"model\")[[\"acc_mean\", \"acc_count\", \"rank\"]]\n",
    "    .mean()\n",
    "    .sort_values(\"rank\")\n",
    "    .head(10)\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6cd28c74",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e6c052c2",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b1747ff3",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "19c3a8f4",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "83696489",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fe4d0f40",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8ae32285",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "55360a05",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b9ea7a44",
   "metadata": {},
   "outputs": [],
   "source": [
    "# The 10 (dataset, shots) rows where pseudolabel-binshuffling-uniform ranks worst (largest rank number).\n",
    "# Alternative views kept for reference:\n",
    "#   q = pd.merge(q, datasize[[\"data_id\", \"num_prep_data\"]], on=\"data_id\")\n",
    "#   q[(q[\"model\"] == \"pseudolabel-zeromasking\") & (q[\"shots\"] == 1)].sort_values(\"rank\", ascending=True)\n",
    "#   q[(q[\"data_id\"] == 44090) & (q[\"shots\"] == 1)].sort_values(\"acc_mean\", ascending=False)\n",
    "q.loc[q[\"model\"] == \"pseudolabel-binshuffling-uniform\"].sort_values(by=\"rank\", ascending=False).head(n=10)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 263,
   "id": "6ef40cde",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'12 14 16 18 32 35 36 185 188 307 454 470 1063 1459 1464 1466 1475 1487 1492 1497 1503 1509 4153 4538 23512 40499 40536 40981 41145 41147 41169 44157 44160 45548'"
      ]
     },
     "execution_count": 263,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Datasets on which pseudolabel-binshuffling-uniform ranks better than 15th in BOTH the 1-shot and the 5-shot setting.\n",
    "d = q.loc[\n",
    "    (q[\"model\"] == \"pseudolabel-binshuffling-uniform\") & (q[\"rank\"] < 15) & (q[\"shots\"] == 1),\n",
    "    \"data_id\",\n",
    "].values\n",
    "\" \".join(\n",
    "    str(data_id)\n",
    "    for data_id in q.loc[\n",
    "        (q[\"model\"] == \"pseudolabel-binshuffling-uniform\") & q[\"data_id\"].isin(d) & (q[\"shots\"] == 5) & (q[\"rank\"] < 15),\n",
    "        \"data_id\",\n",
    "    ].values\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "00f6c167",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1f1e1aaf",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "id": "a422dd07",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "id": "77c2a930",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Print (space-separated) the data_ids whose tabpfn accuracy is missing although datasize marks them as small:\n",
    "# datasize column 1 < 2000 and column 2 either missing or < 10 (both cases print the same id, so one branch).\n",
    "# NOTE(review): datasize columns are read by position; confirm what columns 1 and 2 hold and switch to names.\n",
    "# .values[0] raises IndexError if a data_id has no datasize row.\n",
    "ids = result[(result[\"model\"] == \"tabpfn\") & (result[\"acc\"].isna())][\"data_id\"].unique()\n",
    "for i in ids:\n",
    "    p = datasize[datasize[\"data_id\"] == i].values[0]\n",
    "    if p[1] < 2000 and (np.isnan(p[2]) or p[2] < 10):\n",
    "        print(i, end=\" \")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "36b91752",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "id": "52db6e27",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Persist the full results table (index included, the to_csv default) to result.csv in the working directory.\n",
    "result.to_csv(path_or_buf=\"result.csv\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "id": "0509d41b",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "=====\n",
      "      shots data_id dataname                             model    acc  auroc\n",
      "2373      5    4538       GE   pseudolabel-binshuffling-random  0.408  0.681\n",
      "2375      5    4538       GE    pseudolabel-binshuffling-bound  0.403  0.681\n",
      "2354      5    4538       GE                                lr  0.388  0.695\n",
      "2374      5    4538       GE  pseudolabel-binshuffling-uniform  0.353  0.671\n",
      "2359      5    4538       GE                               mlp  0.343  0.662\n",
      "2360      5    4538       GE                                ae  0.341  0.679\n",
      "2355      5    4538       GE                               knn  0.329  0.641\n",
      "2356      5    4538       GE                           xgboost  0.326  0.641\n",
      "2370      5    4538       GE               pseudolabel-masking  0.322  0.665\n",
      "2367      5    4538       GE              sslsubtab-finetuning  0.298    0.5\n",
      "2361      5    4538       GE                               ict  0.282  0.611\n",
      "2358      5    4538       GE                          lightgbm  0.278    0.5\n",
      "2364      5    4538       GE                      sslsubtab-lr  0.278  0.498\n",
      "2365      5    4538       GE                     sslsubtab-knn  0.278  0.497\n",
      "2369      5    4538       GE                              vime  0.278   0.51\n",
      "2366      5    4538       GE              sslsubtab-lineareval  0.225    0.5\n",
      "=====\n",
      "      shots data_id dataname model    acc  auroc\n",
      "2950      5   45068       AD   mlp  0.524  0.586\n",
      "=====\n",
      "      shots data_id dataname                             model    acc  auroc\n",
      "2696      5   45062       CH  pseudolabel-binshuffling-uniform  0.795    0.5\n",
      "2691      5   45062       CH                              vime  0.791    0.5\n",
      "2686      5   45062       CH                      sslsubtab-lr  0.791  0.487\n",
      "2687      5   45062       CH                     sslsubtab-knn  0.791    0.5\n",
      "2680      5   45062       CH                          lightgbm  0.791    0.5\n",
      "2676      5   45062       CH                                lr  0.706  0.691\n",
      "2678      5   45062       CH                           xgboost  0.649  0.681\n",
      "2681      5   45062       CH                               mlp  0.625  0.736\n",
      "2682      5   45062       CH                                ae  0.609    0.5\n",
      "2683      5   45062       CH                               ict  0.609    0.5\n",
      "2679      5   45062       CH                          catboost  0.608   0.74\n",
      "2677      5   45062       CH                               knn  0.571   0.52\n",
      "2697      5   45062       CH    pseudolabel-binshuffling-bound  0.235    0.5\n",
      "2695      5   45062       CH   pseudolabel-binshuffling-random  0.231    0.5\n",
      "2692      5   45062       CH               pseudolabel-masking  0.226    0.5\n",
      "2688      5   45062       CH              sslsubtab-lineareval  0.209    0.5\n",
      "2689      5   45062       CH              sslsubtab-finetuning  0.209    0.5\n",
      "=====\n",
      "      shots data_id dataname                             model    acc  auroc\n",
      "2424      5   40536       SD                          lightgbm  0.845    0.5\n",
      "2442      5   40536       SD   pseudolabel-binshuffling-random  0.832    0.5\n",
      "2444      5   40536       SD    pseudolabel-binshuffling-bound  0.823    0.5\n",
      "2428      5   40536       SD                       meanteacher  0.606    0.5\n",
      "2420      5   40536       SD                                lr  0.560  0.551\n",
      "2425      5   40536       SD                               mlp  0.540   0.56\n",
      "2421      5   40536       SD                               knn  0.510  0.537\n",
      "2423      5   40536       SD                          catboost  0.481  0.559\n",
      "2426      5   40536       SD                                ae  0.406    0.5\n",
      "2422      5   40536       SD                           xgboost  0.385  0.506\n",
      "2427      5   40536       SD                               ict  0.314    0.5\n",
      "2430      5   40536       SD                      sslsubtab-lr  0.224  0.501\n",
      "2443      5   40536       SD  pseudolabel-binshuffling-uniform  0.183    0.5\n",
      "2439      5   40536       SD               pseudolabel-masking  0.178    0.5\n",
      "2437      5   40536       SD                sslvime-finetuning  0.155    0.5\n",
      "2438      5   40536       SD                              vime  0.155    0.5\n",
      "2434      5   40536       SD                        sslvime-lr  0.155  0.671\n",
      "2436      5   40536       SD                sslvime-lineareval  0.155    0.5\n",
      "2435      5   40536       SD                       sslvime-knn  0.155    0.5\n",
      "2433      5   40536       SD              sslsubtab-finetuning  0.155    0.5\n",
      "2431      5   40536       SD                     sslsubtab-knn  0.155    0.5\n",
      "2432      5   40536       SD              sslsubtab-lineareval  0.155    0.5\n",
      "=====\n",
      "      shots data_id dataname                             model    acc  auroc\n",
      "4392      5      14       MF                          lightgbm  0.595  0.889\n",
      "4409      5      14       MF    pseudolabel-binshuffling-bound  0.555   0.89\n",
      "4388      5      14       MF                                lr  0.552  0.877\n",
      "4407      5      14       MF   pseudolabel-binshuffling-random  0.535  0.885\n",
      "4404      5      14       MF               pseudolabel-masking  0.532  0.876\n",
      "4408      5      14       MF  pseudolabel-binshuffling-uniform  0.530  0.883\n",
      "4393      5      14       MF                               mlp  0.522  0.873\n",
      "4394      5      14       MF                                ae  0.517  0.836\n",
      "4389      5      14       MF                               knn  0.502  0.845\n",
      "4395      5      14       MF                               ict  0.460  0.796\n",
      "4403      5      14       MF                              vime  0.142  0.476\n",
      "4400      5      14       MF              sslsubtab-lineareval  0.095    0.5\n",
      "4401      5      14       MF              sslsubtab-finetuning  0.090    0.5\n",
      "4398      5      14       MF                      sslsubtab-lr  0.083  0.489\n",
      "4399      5      14       MF                     sslsubtab-knn  0.083  0.498\n",
      "=====\n",
      "      shots data_id dataname                             model    acc  auroc\n",
      "4454      5      22       MZ                                lr  0.635  0.923\n",
      "4460      5      22       MZ                                ae  0.635  0.918\n",
      "4474      5      22       MZ  pseudolabel-binshuffling-uniform  0.632  0.928\n",
      "4470      5      22       MZ               pseudolabel-masking  0.625  0.928\n",
      "4459      5      22       MZ                               mlp  0.618  0.918\n",
      "4473      5      22       MZ   pseudolabel-binshuffling-random  0.613  0.927\n",
      "4475      5      22       MZ    pseudolabel-binshuffling-bound  0.605  0.923\n",
      "4461      5      22       MZ                               ict  0.560  0.866\n",
      "4455      5      22       MZ                               knn  0.517  0.875\n",
      "4458      5      22       MZ                          lightgbm  0.517  0.865\n",
      "4456      5      22       MZ                           xgboost  0.440  0.797\n",
      "4469      5      22       MZ                              vime  0.205   0.65\n",
      "4467      5      22       MZ              sslsubtab-finetuning  0.095    0.5\n",
      "4464      5      22       MZ                      sslsubtab-lr  0.092  0.513\n",
      "4465      5      22       MZ                     sslsubtab-knn  0.092  0.502\n",
      "4466      5      22       MZ              sslsubtab-lineareval  0.083    0.5\n",
      "=====\n",
      "      shots data_id dataname                model    acc auroc\n",
      "3419      5    1486       NM                   ae  0.796   0.5\n",
      "3421      5    1486       NM          meanteacher  0.786   0.5\n",
      "3420      5    1486       NM                  ict  0.774   0.5\n",
      "3426      5    1486       NM  pseudolabel-masking  0.759   0.5\n",
      "3425      5    1486       NM                 vime  0.729   0.5\n",
      "=====\n",
      "      shots data_id dataname                             model    acc  auroc\n",
      "3080      5     470       PB                           xgboost  0.630    0.5\n",
      "3082      5     470       PB                          lightgbm  0.630    0.5\n",
      "3088      5     470       PB                      sslrecon-knn  0.630    0.5\n",
      "3097      5     470       PB               pseudolabel-masking  0.563    0.5\n",
      "3101      5     470       PB  pseudolabel-binshuffling-uniform  0.533    0.5\n",
      "3078      5     470       PB                                lr  0.481  0.478\n",
      "3086      5     470       PB                       meanteacher  0.481    0.5\n",
      "3102      5     470       PB    pseudolabel-binshuffling-bound  0.481    0.5\n",
      "3084      5     470       PB                                ae  0.474    0.5\n",
      "3081      5     470       PB                          catboost  0.467  0.504\n",
      "3100      5     470       PB   pseudolabel-binshuffling-random  0.459    0.5\n",
      "3085      5     470       PB                               ict  0.430    0.5\n",
      "3079      5     470       PB                               knn  0.400  0.412\n",
      "3096      5     470       PB                              vime  0.400    0.5\n",
      "3087      5     470       PB                       sslrecon-lr  0.370    0.5\n",
      "3090      5     470       PB               sslrecon-finetuning  0.370    0.5\n",
      "3092      5     470       PB                        sslvime-lr  0.370  0.443\n",
      "3093      5     470       PB                       sslvime-knn  0.370    0.5\n",
      "3094      5     470       PB                sslvime-lineareval  0.370    0.5\n",
      "3095      5     470       PB                sslvime-finetuning  0.370    0.5\n",
      "3089      5     470       PB               sslrecon-lineareval  0.370    0.5\n",
      "=====\n",
      "      shots data_id dataname                model    acc auroc\n",
      "3457      5    1489       PO                   ae  0.739   0.5\n",
      "3464      5    1489       PO  pseudolabel-masking  0.698   0.5\n",
      "3459      5    1489       PO          meanteacher  0.687   0.5\n",
      "3458      5    1489       PO                  ict  0.638   0.5\n",
      "3463      5    1489       PO                 vime  0.301   0.5\n",
      "=====\n",
      "      shots data_id dataname                             model    acc  auroc\n",
      "2382      5   23512       HI                                ae  0.543    0.5\n",
      "2377      5   23512       HI                               knn  0.543  0.537\n",
      "2383      5   23512       HI                               ict  0.539    0.5\n",
      "2391      5   23512       HI                       sslvime-knn  0.539  0.512\n",
      "2376      5   23512       HI                                lr  0.537  0.549\n",
      "2381      5   23512       HI                               mlp  0.530  0.545\n",
      "2400      5   23512       HI    pseudolabel-binshuffling-bound  0.530    0.5\n",
      "2379      5   23512       HI                          catboost  0.530   0.55\n",
      "2386      5   23512       HI                      sslsubtab-lr  0.530    0.5\n",
      "2387      5   23512       HI                     sslsubtab-knn  0.530    0.5\n",
      "2389      5   23512       HI              sslsubtab-finetuning  0.530    0.5\n",
      "2392      5   23512       HI                sslvime-lineareval  0.530    0.5\n",
      "2393      5   23512       HI                sslvime-finetuning  0.530    0.5\n",
      "2395      5   23512       HI               pseudolabel-masking  0.530    0.5\n",
      "2398      5   23512       HI   pseudolabel-binshuffling-random  0.530    0.5\n",
      "2399      5   23512       HI  pseudolabel-binshuffling-uniform  0.530    0.5\n",
      "2388      5   23512       HI              sslsubtab-lineareval  0.530    0.5\n",
      "2390      5   23512       HI                        sslvime-lr  0.524  0.513\n",
      "2378      5   23512       HI                           xgboost  0.509  0.508\n",
      "2380      5   23512       HI                          lightgbm  0.470    0.5\n",
      "2394      5   23512       HI                              vime  0.470    0.5\n",
      "=====\n",
      "      shots data_id dataname     model    acc  auroc\n",
      "2603      5   44160       RL  catboost  0.560  0.569\n",
      "2602      5   44160       RL   xgboost  0.556  0.595\n",
      "2605      5   44160       RL       mlp  0.551  0.563\n",
      "2600      5   44160       RL        lr  0.543  0.549\n",
      "2606      5   44160       RL        ae  0.541    0.5\n",
      "2601      5   44160       RL       knn  0.534  0.527\n",
      "2604      5   44160       RL  lightgbm  0.516    0.5\n",
      "=====\n",
      "      shots data_id dataname     model    acc  auroc\n",
      "2779      5   44091       WB       mlp  0.683  0.698\n",
      "2775      5   44091       WB       knn  0.681  0.699\n",
      "2774      5   44091       WB        lr  0.679  0.701\n",
      "2777      5   44091       WB  catboost  0.652   0.67\n",
      "2780      5   44091       WB        ae  0.646    0.5\n",
      "2778      5   44091       WB  lightgbm  0.491    0.5\n",
      "2776      5   44091       WB   xgboost  0.466  0.467\n",
      "=====\n",
      "      shots data_id dataname               model    acc  auroc\n",
      "2565      5   44158       KU                  ae  0.627    0.5\n",
      "2562      5   44158       KU            catboost  0.595   0.66\n",
      "2564      5   44158       KU                 mlp  0.577   0.62\n",
      "2561      5   44158       KU             xgboost  0.565   0.58\n",
      "2559      5   44158       KU                  lr  0.548  0.574\n",
      "2560      5   44158       KU                 knn  0.548  0.553\n",
      "2563      5   44158       KU            lightgbm  0.514    0.5\n",
      "2570      5   44158       KU          sslvime-lr  0.514  0.501\n",
      "2571      5   44158       KU         sslvime-knn  0.514    0.5\n",
      "2572      5   44158       KU  sslvime-lineareval  0.486    0.5\n",
      "2573      5   44158       KU  sslvime-finetuning  0.486    0.5\n",
      "=====\n",
      "      shots data_id dataname     model    acc  auroc\n",
      "2834      5   44124       KL  catboost  0.723  0.826\n",
      "2831      5   44124       KL        lr  0.706  0.816\n",
      "2836      5   44124       KL       mlp  0.703  0.711\n",
      "2837      5   44124       KL        ae  0.699    0.5\n",
      "2833      5   44124       KL   xgboost  0.632   0.63\n",
      "2832      5   44124       KL       knn  0.590  0.601\n",
      "2835      5   44124       KL  lightgbm  0.506    0.5\n",
      "=====\n",
      "      shots data_id dataname     model    acc  auroc\n",
      "2546      5   44157       EM        ae  0.535    0.5\n",
      "2540      5   44157       EM        lr  0.526  0.534\n",
      "2541      5   44157       EM       knn  0.525  0.538\n",
      "2545      5   44157       EM       mlp  0.524  0.533\n",
      "2543      5   44157       EM  catboost  0.507   0.52\n",
      "2544      5   44157       EM  lightgbm  0.506    0.5\n",
      "2542      5   44157       EM   xgboost  0.479  0.508\n",
      "=====\n",
      "      shots data_id dataname     model    acc  auroc\n",
      "2799      5   44122      POB        ae  0.548    0.5\n",
      "2798      5   44122      POB       mlp  0.545  0.567\n",
      "2793      5   44122      POB        lr  0.531  0.566\n",
      "2797      5   44122      POB  lightgbm  0.515    0.5\n",
      "2794      5   44122      POB       knn  0.510  0.535\n",
      "2796      5   44122      POB  catboost  0.505  0.517\n",
      "2795      5   44122      POB   xgboost  0.499  0.497\n",
      "=====\n",
      "      shots data_id dataname     model    acc  auroc\n",
      "2874      5   44126       BM       mlp  0.649  0.688\n",
      "2869      5   44126       BM        lr  0.624  0.654\n",
      "2875      5   44126       BM        ae  0.603    0.5\n",
      "2872      5   44126       BM  catboost  0.599  0.607\n",
      "2870      5   44126       BM       knn  0.594  0.609\n",
      "2873      5   44126       BM  lightgbm  0.513    0.5\n",
      "2871      5   44126       BM   xgboost  0.466  0.465\n",
      "=====\n",
      "      shots data_id dataname     model    acc  auroc\n",
      "2505      5   44129      HIL  catboost  0.549  0.569\n",
      "2504      5   44129      HIL   xgboost  0.546  0.546\n",
      "2508      5   44129      HIL        ae  0.517    0.5\n",
      "2507      5   44129      HIL       mlp  0.515   0.52\n",
      "2502      5   44129      HIL        lr  0.513  0.518\n",
      "2503      5   44129      HIL       knn  0.509  0.512\n",
      "2506      5   44129      HIL  lightgbm  0.499    0.5\n",
      "=====\n",
      "      shots data_id dataname     model    acc  auroc\n",
      "2521      5   44131       JN        lr  0.642  0.705\n",
      "2524      5   44131       JN  catboost  0.632   0.67\n",
      "2522      5   44131       JN       knn  0.629  0.648\n",
      "2526      5   44131       JN       mlp  0.628  0.689\n",
      "2527      5   44131       JN        ae  0.547    0.5\n",
      "2525      5   44131       JN  lightgbm  0.493    0.5\n",
      "2523      5   44131       JN   xgboost  0.487  0.486\n",
      "=====\n",
      "      shots data_id dataname     model    acc  auroc\n",
      "2622      5   44161       RS  catboost  0.586  0.621\n",
      "2621      5   44161       RS   xgboost  0.538  0.537\n",
      "2625      5   44161       RS        ae  0.522    0.5\n",
      "2620      5   44161       RS       knn  0.521  0.529\n",
      "2619      5   44161       RS        lr  0.517  0.521\n",
      "2623      5   44161       RS  lightgbm  0.505    0.5\n",
      "2624      5   44161       RS       mlp  0.496    0.5\n",
      "=====\n",
      "      shots data_id dataname     model    acc  auroc\n",
      "2853      5   44125       MT  catboost  0.696  0.776\n",
      "2850      5   44125       MT        lr  0.678   0.76\n",
      "2856      5   44125       MT        ae  0.623    0.5\n",
      "2855      5   44125       MT       mlp  0.599  0.741\n",
      "2851      5   44125       MT       knn  0.567  0.633\n",
      "2852      5   44125       MT   xgboost  0.519  0.636\n",
      "2854      5   44125       MT  lightgbm  0.488    0.5\n",
      "=====\n",
      "      shots data_id dataname     model    acc  auroc\n",
      "2815      5   44123      HOB  catboost  0.709  0.773\n",
      "2814      5   44123      HOB   xgboost  0.696  0.696\n",
      "2817      5   44123      HOB       mlp  0.675  0.742\n",
      "2818      5   44123      HOB        ae  0.672    0.5\n",
      "2812      5   44123      HOB        lr  0.671  0.748\n",
      "2813      5   44123      HOB       knn  0.658  0.721\n",
      "2816      5   44123      HOB  lightgbm  0.480    0.5\n",
      "=====\n",
      "      shots data_id dataname     model    acc  auroc\n",
      "2587      5   44159       CO        ae  0.534    0.5\n",
      "2586      5   44159       CO       mlp  0.525  0.537\n",
      "2583      5   44159       CO   xgboost  0.521  0.521\n",
      "2584      5   44159       CO  catboost  0.520  0.527\n",
      "2581      5   44159       CO        lr  0.520  0.534\n",
      "2585      5   44159       CO  lightgbm  0.500    0.5\n",
      "2582      5   44159       CO       knn  0.497  0.512\n",
      "=====\n",
      "      shots data_id dataname     model    acc  auroc\n",
      "2741      5   44089       CR       mlp  0.647  0.708\n",
      "2739      5   44089       CR  catboost  0.626  0.725\n",
      "2736      5   44089       CR        lr  0.618  0.727\n",
      "2742      5   44089       CR        ae  0.580    0.5\n",
      "2737      5   44089       CR       knn  0.534  0.647\n",
      "2738      5   44089       CR   xgboost  0.492    0.5\n",
      "2740      5   44089       CR  lightgbm  0.492    0.5\n",
      "=====\n",
      "      shots data_id dataname     model    acc  auroc\n",
      "2643      5   44162      CPS       mlp  0.595  0.624\n",
      "2641      5   44162      CPS  catboost  0.570  0.596\n",
      "2644      5   44162      CPS        ae  0.568    0.5\n",
      "2638      5   44162      CPS        lr  0.519  0.561\n",
      "2642      5   44162      CPS  lightgbm  0.494    0.5\n",
      "2640      5   44162      CPS   xgboost  0.486  0.487\n",
      "2639      5   44162      CPS       knn  0.405  0.407\n",
      "=====\n",
      "      shots data_id dataname     model    acc  auroc\n",
      "2755      5   44090      CAB        lr  0.697  0.765\n",
      "2761      5   44090      CAB        ae  0.641    0.5\n",
      "2760      5   44090      CAB       mlp  0.622  0.663\n",
      "2758      5   44090      CAB  catboost  0.619  0.675\n",
      "2757      5   44090      CAB   xgboost  0.554  0.553\n",
      "2756      5   44090      CAB       knn  0.542  0.554\n",
      "2759      5   44090      CAB  lightgbm  0.499    0.5\n",
      "=====\n",
      "      shots data_id dataname     model    acc  auroc\n",
      "2661      5   45060       OS  lightgbm  0.846    0.5\n",
      "2659      5   45060       OS   xgboost  0.698  0.542\n",
      "2658      5   45060       OS       knn  0.615  0.493\n",
      "2660      5   45060       OS  catboost  0.598  0.701\n",
      "2662      5   45060       OS       mlp  0.586  0.707\n",
      "2663      5   45060       OS        ae  0.515    0.5\n",
      "2657      5   45060       OS        lr  0.495  0.546\n",
      "=====\n",
      "      shots data_id dataname     model    acc  auroc\n",
      "2698      5   45548       OT        lr  0.441  0.838\n",
      "2703      5   45548       OT       mlp  0.420  0.834\n",
      "2700      5   45548       OT   xgboost  0.399  0.769\n",
      "2704      5   45548       OT        ae  0.318  0.724\n",
      "2699      5   45548       OT       knn  0.288  0.713\n",
      "2702      5   45548       OT  lightgbm  0.233  0.697\n",
      "=====\n",
      "      shots data_id dataname                             model    acc auroc\n",
      "3309      5    1464       BT   pseudolabel-binshuffling-random  0.667   0.5\n",
      "3310      5    1464       BT  pseudolabel-binshuffling-uniform  0.633   0.5\n",
      "3299      5    1464       BT                                ae  0.620   0.5\n",
      "3311      5    1464       BT    pseudolabel-binshuffling-bound  0.613   0.5\n",
      "3306      5    1464       BT               pseudolabel-masking  0.593   0.5\n",
      "3301      5    1464       BT                       meanteacher  0.533   0.5\n",
      "3305      5    1464       BT                              vime  0.533   0.5\n",
      "3300      5    1464       BT                               ict  0.487   0.5\n",
      "=====\n",
      "      shots data_id dataname                model    acc auroc\n",
      "3444      5    1487       OZ                 vime  0.937   0.5\n",
      "3445      5    1487       OZ  pseudolabel-masking  0.874   0.5\n",
      "3438      5    1487       OZ                   ae  0.633   0.5\n",
      "3439      5    1487       OZ                  ict  0.316   0.5\n",
      "3440      5    1487       OZ          meanteacher  0.243   0.5\n",
      "=====\n",
      "      shots data_id dataname                             model    acc auroc\n",
      "3400      5    1479       HV                                ae  0.510   0.5\n",
      "3401      5    1479       HV                               ict  0.506   0.5\n",
      "3410      5    1479       HV   pseudolabel-binshuffling-random  0.506   0.5\n",
      "3411      5    1479       HV  pseudolabel-binshuffling-uniform  0.506   0.5\n",
      "3406      5    1479       HV                              vime  0.498   0.5\n",
      "3407      5    1479       HV               pseudolabel-masking  0.494   0.5\n",
      "3412      5    1479       HV    pseudolabel-binshuffling-bound  0.494   0.5\n",
      "3402      5    1479       HV                       meanteacher  0.486   0.5\n",
      "=====\n",
      "      shots data_id dataname                             model    acc auroc\n",
      "3361      5    1471      EEG                       meanteacher  0.579   0.5\n",
      "3359      5    1471      EEG                                ae  0.566   0.5\n",
      "3370      5    1471      EEG  pseudolabel-binshuffling-uniform  0.563   0.5\n",
      "3366      5    1471      EEG               pseudolabel-masking  0.557   0.5\n",
      "3360      5    1471      EEG                               ict  0.556   0.5\n",
      "3371      5    1471      EEG    pseudolabel-binshuffling-bound  0.443   0.5\n",
      "3369      5    1471      EEG   pseudolabel-binshuffling-random  0.442   0.5\n",
      "3365      5    1471      EEG                              vime  0.440   0.5\n",
      "=====\n",
      "      shots data_id dataname                             model    acc auroc\n",
      "3346      5    1467      CMC                              vime  0.935   0.5\n",
      "3347      5    1467      CMC               pseudolabel-masking  0.907   0.5\n",
      "3351      5    1467      CMC  pseudolabel-binshuffling-uniform  0.907   0.5\n",
      "3350      5    1467      CMC   pseudolabel-binshuffling-random  0.898   0.5\n",
      "3352      5    1467      CMC    pseudolabel-binshuffling-bound  0.898   0.5\n",
      "3341      5    1467      CMC                               ict  0.648   0.5\n",
      "3340      5    1467      CMC                                ae  0.639   0.5\n",
      "3342      5    1467      CMC                       meanteacher  0.565   0.5\n",
      "=====\n",
      "      shots data_id dataname model    acc  auroc\n",
      "4276      5     458      ATH   mlp  0.929  0.995\n",
      "=====\n",
      "      shots data_id dataname                             model    acc  auroc\n",
      "2351      5    1555      AUL   pseudolabel-binshuffling-random  0.150  0.515\n",
      "2352      5    1555      AUL  pseudolabel-binshuffling-uniform  0.150  0.505\n",
      "2338      5    1555      AUL                                ae  0.145  0.521\n",
      "2348      5    1555      AUL               pseudolabel-masking  0.145  0.522\n",
      "2332      5    1555      AUL                                lr  0.140  0.532\n",
      "2337      5    1555      AUL                               mlp  0.140  0.512\n",
      "2343      5    1555      AUL                     sslsubtab-knn  0.140    0.5\n",
      "2344      5    1555      AUL              sslsubtab-lineareval  0.135    0.5\n",
      "2353      5    1555      AUL    pseudolabel-binshuffling-bound  0.135  0.519\n",
      "2333      5    1555      AUL                               knn  0.120  0.493\n",
      "2334      5    1555      AUL                           xgboost  0.110  0.519\n",
      "2336      5    1555      AUL                          lightgbm  0.110  0.523\n",
      "2342      5    1555      AUL                      sslsubtab-lr  0.105  0.517\n",
      "2345      5    1555      AUL              sslsubtab-finetuning  0.100   0.52\n",
      "2347      5    1555      AUL                              vime  0.100  0.562\n",
      "2339      5    1555      AUL                               ict  0.095  0.518\n",
      "=====\n",
      "      shots data_id dataname                             model    acc  auroc\n",
      "4499      5      48      TAE                               knn  0.452  0.697\n",
      "4500      5      48      TAE                           xgboost  0.452  0.632\n",
      "4503      5      48      TAE                               mlp  0.452  0.593\n",
      "4508      5      48      TAE                      sslsubtab-lr  0.452  0.503\n",
      "4511      5      48      TAE              sslsubtab-finetuning  0.452    0.5\n",
      "4505      5      48      TAE                               ict  0.419  0.622\n",
      "4498      5      48      TAE                                lr  0.387  0.668\n",
      "4504      5      48      TAE                                ae  0.387  0.606\n",
      "4514      5      48      TAE               pseudolabel-masking  0.355  0.593\n",
      "4518      5      48      TAE  pseudolabel-binshuffling-uniform  0.355   0.64\n",
      "4510      5      48      TAE              sslsubtab-lineareval  0.323    0.5\n",
      "4513      5      48      TAE                              vime  0.323  0.456\n",
      "4517      5      48      TAE   pseudolabel-binshuffling-random  0.323  0.661\n",
      "4519      5      48      TAE    pseudolabel-binshuffling-bound  0.323  0.652\n",
      "4502      5      48      TAE                          lightgbm  0.226    0.5\n",
      "4509      5      48      TAE                     sslsubtab-knn  0.226  0.412\n",
      "=====\n",
      "      shots data_id dataname                             model    acc auroc\n",
      "3280      5    1462      BNB                                ae  0.840   0.5\n",
      "3290      5    1462      BNB   pseudolabel-binshuffling-random  0.818   0.5\n",
      "3281      5    1462      BNB                               ict  0.815   0.5\n",
      "3291      5    1462      BNB  pseudolabel-binshuffling-uniform  0.785   0.5\n",
      "3292      5    1462      BNB    pseudolabel-binshuffling-bound  0.785   0.5\n",
      "3282      5    1462      BNB                       meanteacher  0.724   0.5\n",
      "3287      5    1462      BNB               pseudolabel-masking  0.720   0.5\n",
      "3286      5    1462      BNB                              vime  0.625   0.5\n",
      "=====\n",
      "      shots data_id dataname                 model    acc  auroc\n",
      "4137      5      23       CE               xgboost  0.420  0.583\n",
      "4139      5      23       CE              lightgbm  0.403    0.5\n",
      "4145      5      23       CE          sslsubtab-lr  0.403  0.515\n",
      "4146      5      23       CE         sslsubtab-knn  0.403  0.498\n",
      "4148      5      23       CE  sslsubtab-finetuning  0.403    0.5\n",
      "4143      5      23       CE           meanteacher  0.349  0.566\n",
      "4135      5      23       CE                    lr  0.325  0.555\n",
      "4141      5      23       CE                    ae  0.315  0.535\n",
      "4142      5      23       CE                   ict  0.312  0.534\n",
      "4151      5      23       CE   pseudolabel-masking  0.305  0.561\n",
      "4136      5      23       CE                   knn  0.298  0.523\n",
      "4150      5      23       CE                  vime  0.237   0.53\n",
      "4147      5      23       CE  sslsubtab-lineareval  0.220    0.5\n",
      "=====\n",
      "      shots data_id dataname                             model    acc  auroc\n",
      "3391      5    1475       FT   pseudolabel-binshuffling-random  0.261  0.589\n",
      "3378      5    1475       FT                                ae  0.223  0.568\n",
      "3392      5    1475       FT  pseudolabel-binshuffling-uniform  0.206  0.589\n",
      "3388      5    1475       FT               pseudolabel-masking  0.198  0.574\n",
      "3379      5    1475       FT                               ict  0.194  0.555\n",
      "3393      5    1475       FT    pseudolabel-binshuffling-bound  0.194  0.582\n",
      "3385      5    1475       FT              sslsubtab-finetuning  0.190    0.5\n",
      "3380      5    1475       FT                       meanteacher  0.189  0.582\n",
      "3382      5    1475       FT                      sslsubtab-lr  0.171  0.513\n",
      "3383      5    1475       FT                     sslsubtab-knn  0.171  0.505\n",
      "3387      5    1475       FT                              vime  0.143  0.566\n",
      "3384      5    1475       FT              sslsubtab-lineareval  0.097    0.5\n",
      "=====\n",
      "      shots data_id dataname                             model    acc  auroc\n",
      "4382      5      12       MA               pseudolabel-masking  0.863  0.989\n",
      "4366      5      12       MA                                lr  0.853  0.987\n",
      "4385      5      12       MA   pseudolabel-binshuffling-random  0.848  0.984\n",
      "4387      5      12       MA    pseudolabel-binshuffling-bound  0.823  0.983\n",
      "4371      5      12       MA                               mlp  0.818  0.982\n",
      "4372      5      12       MA                                ae  0.812  0.977\n",
      "4386      5      12       MA  pseudolabel-binshuffling-uniform  0.805  0.982\n",
      "4373      5      12       MA                               ict  0.792  0.972\n",
      "4370      5      12       MA                          lightgbm  0.787  0.974\n",
      "4367      5      12       MA                               knn  0.782  0.967\n",
      "4368      5      12       MA                           xgboost  0.685  0.921\n",
      "4381      5      12       MA                              vime  0.305  0.785\n",
      "4378      5      12       MA              sslsubtab-lineareval  0.107    0.5\n",
      "4376      5      12       MA                      sslsubtab-lr  0.095  0.497\n",
      "4379      5      12       MA              sslsubtab-finetuning  0.095    0.5\n",
      "4377      5      12       MA                     sslsubtab-knn  0.085    0.5\n",
      "=====\n",
      "      shots data_id dataname                             model    acc  auroc\n",
      "4439      5      18       MR                               ict  0.670  0.952\n",
      "4448      5      18       MR               pseudolabel-masking  0.657  0.952\n",
      "4438      5      18       MR                                ae  0.637  0.941\n",
      "4437      5      18       MR                               mlp  0.620  0.942\n",
      "4447      5      18       MR                              vime  0.615  0.892\n",
      "4434      5      18       MR                           xgboost  0.578  0.915\n",
      "4451      5      18       MR   pseudolabel-binshuffling-random  0.542  0.931\n",
      "4453      5      18       MR    pseudolabel-binshuffling-bound  0.537   0.93\n",
      "4452      5      18       MR  pseudolabel-binshuffling-uniform  0.507   0.93\n",
      "4433      5      18       MR                               knn  0.490  0.888\n",
      "4436      5      18       MR                          lightgbm  0.380  0.874\n",
      "4432      5      18       MR                                lr  0.315  0.839\n",
      "4444      5      18       MR              sslsubtab-lineareval  0.128    0.5\n",
      "4442      5      18       MR                      sslsubtab-lr  0.105  0.498\n",
      "4445      5      18       MR              sslsubtab-finetuning  0.105    0.5\n",
      "4443      5      18       MR                     sslsubtab-knn  0.095    0.5\n",
      "=====\n",
      "      shots data_id dataname                             model    acc  auroc\n",
      "4431      5      16       MK    pseudolabel-binshuffling-bound  0.807  0.964\n",
      "4430      5      16       MK  pseudolabel-binshuffling-uniform  0.797  0.968\n",
      "4426      5      16       MK               pseudolabel-masking  0.792  0.969\n",
      "4415      5      16       MK                               mlp  0.787  0.963\n",
      "4429      5      16       MK   pseudolabel-binshuffling-random  0.787  0.965\n",
      "4410      5      16       MK                                lr  0.782  0.959\n",
      "4416      5      16       MK                                ae  0.745  0.958\n",
      "4411      5      16       MK                               knn  0.680  0.927\n",
      "4417      5      16       MK                               ict  0.657  0.909\n",
      "4414      5      16       MK                          lightgbm  0.655  0.931\n",
      "4412      5      16       MK                           xgboost  0.635  0.888\n",
      "4420      5      16       MK                      sslsubtab-lr  0.128  0.512\n",
      "4421      5      16       MK                     sslsubtab-knn  0.105  0.498\n",
      "4425      5      16       MK                              vime  0.102  0.514\n",
      "4422      5      16       MK              sslsubtab-lineareval  0.095    0.5\n",
      "4423      5      16       MK              sslsubtab-finetuning  0.095    0.5\n",
      "=====\n",
      "      shots data_id dataname                             model    acc  auroc\n",
      "4496      5      32       PE  pseudolabel-binshuffling-uniform  0.887  0.975\n",
      "4495      5      32       PE   pseudolabel-binshuffling-random  0.884  0.978\n",
      "4497      5      32       PE    pseudolabel-binshuffling-bound  0.874  0.972\n",
      "4492      5      32       PE               pseudolabel-masking  0.861  0.976\n",
      "4482      5      32       PE                                ae  0.860  0.967\n",
      "4481      5      32       PE                               mlp  0.840  0.973\n",
      "4476      5      32       PE                                lr  0.769  0.966\n",
      "4483      5      32       PE                               ict  0.764  0.956\n",
      "4477      5      32       PE                               knn  0.712  0.957\n",
      "4480      5      32       PE                          lightgbm  0.693  0.945\n",
      "4478      5      32       PE                           xgboost  0.650  0.918\n",
      "4491      5      32       PE                              vime  0.327  0.721\n",
      "4487      5      32       PE                     sslsubtab-knn  0.097  0.502\n",
      "4489      5      32       PE              sslsubtab-finetuning  0.095    0.5\n",
      "4486      5      32       PE                      sslsubtab-lr  0.094  0.506\n",
      "4488      5      32       PE              sslsubtab-lineareval  0.094    0.5\n",
      "=====\n",
      "      shots data_id dataname                             model    acc  auroc\n",
      "3319      5    1466      CTC                               ict  0.535  0.892\n",
      "3328      5    1466      CTC               pseudolabel-masking  0.526  0.923\n",
      "3318      5    1466      CTC                                ae  0.495    0.9\n",
      "3333      5    1466      CTC    pseudolabel-binshuffling-bound  0.486  0.917\n",
      "3331      5    1466      CTC   pseudolabel-binshuffling-random  0.481  0.907\n",
      "3332      5    1466      CTC  pseudolabel-binshuffling-uniform  0.474  0.913\n",
      "3320      5    1466      CTC                       meanteacher  0.469    0.9\n",
      "3325      5    1466      CTC              sslsubtab-finetuning  0.195    0.5\n",
      "3327      5    1466      CTC                              vime  0.181   0.67\n",
      "3323      5    1466      CTC                     sslsubtab-knn  0.045  0.504\n",
      "3322      5    1466      CTC                      sslsubtab-lr  0.042  0.486\n",
      "3324      5    1466      CTC              sslsubtab-lineareval  0.040    0.5\n",
      "=====\n",
      "      shots data_id dataname                             model    acc  auroc\n",
      "2414      5   40499       TX               pseudolabel-masking  0.883  0.991\n",
      "2406      5   40499       TX                               mlp  0.872  0.988\n",
      "2418      5   40499       TX  pseudolabel-binshuffling-uniform  0.863  0.988\n",
      "2407      5   40499       TX                                ae  0.861   0.99\n",
      "2417      5   40499       TX   pseudolabel-binshuffling-random  0.857  0.988\n",
      "2419      5   40499       TX    pseudolabel-binshuffling-bound  0.855  0.989\n",
      "2408      5   40499       TX                               ict  0.853  0.981\n",
      "2401      5   40499       TX                                lr  0.800  0.968\n",
      "2405      5   40499       TX                          lightgbm  0.674  0.941\n",
      "2403      5   40499       TX                           xgboost  0.651   0.91\n",
      "2402      5   40499       TX                               knn  0.632  0.926\n",
      "2413      5   40499       TX                              vime  0.570  0.947\n",
      "=====\n",
      "Empty DataFrame\n",
      "Columns: [shots, data_id, dataname, model, acc, auroc]\n",
      "Index: []\n",
      "=====\n",
      "      shots data_id dataname                             model    acc  auroc\n",
      "2447      5   40685       SH                           xgboost  0.951  0.991\n",
      "2458      5   40685       SH               pseudolabel-masking  0.885  0.952\n",
      "2453      5   40685       SH                       meanteacher  0.883  0.903\n",
      "2449      5   40685       SH                          lightgbm  0.792    0.5\n",
      "2463      5   40685       SH    pseudolabel-binshuffling-bound  0.774  0.956\n",
      "2451      5   40685       SH                                ae  0.763  0.942\n",
      "2462      5   40685       SH  pseudolabel-binshuffling-uniform  0.755  0.941\n",
      "2450      5   40685       SH                               mlp  0.733  0.966\n",
      "2461      5   40685       SH   pseudolabel-binshuffling-random  0.685  0.938\n",
      "2445      5   40685       SH                                lr  0.602  0.936\n",
      "2446      5   40685       SH                               knn  0.570   0.85\n",
      "2452      5   40685       SH                               ict  0.485  0.896\n",
      "2457      5   40685       SH                              vime  0.159  0.663\n",
      "=====\n",
      "      shots data_id dataname                             model    acc  auroc\n",
      "4586      5   42931      AUD                          catboost  0.884  0.948\n",
      "4589      5   42931      AUD                                ae  0.706    0.5\n",
      "4599      5   42931      AUD   pseudolabel-binshuffling-random  0.687    0.5\n",
      "4601      5   42931      AUD    pseudolabel-binshuffling-bound  0.661    0.5\n",
      "4585      5   42931      AUD                           xgboost  0.658  0.653\n",
      "4588      5   42931      AUD                               mlp  0.655  0.797\n",
      "4600      5   42931      AUD  pseudolabel-binshuffling-uniform  0.648    0.5\n",
      "4596      5   42931      AUD               pseudolabel-masking  0.635    0.5\n",
      "4583      5   42931      AUD                                lr  0.626  0.798\n",
      "4584      5   42931      AUD                               knn  0.603  0.728\n",
      "4590      5   42931      AUD                               ict  0.574    0.5\n",
      "4595      5   42931      AUD                              vime  0.526    0.5\n",
      "4587      5   42931      AUD                          lightgbm  0.481    0.5\n",
      "=====\n",
      "      shots data_id dataname     model    acc  auroc\n",
      "2721      5   41162       KB  lightgbm  0.907    0.5\n",
      "2719      5   41162       KB   xgboost  0.684  0.549\n",
      "2718      5   41162       KB       knn  0.605   0.55\n",
      "2722      5   41162       KB       mlp  0.589  0.606\n",
      "2720      5   41162       KB  catboost  0.398  0.574\n",
      "2717      5   41162       KB        lr  0.342  0.637\n",
      "2723      5   41162       KB        ae  0.093    0.5\n",
      "=====\n",
      "      shots data_id dataname                             model    acc  auroc\n",
      "2468      5     454       HF                          lightgbm  0.905    0.5\n",
      "2466      5     454       HF                           xgboost  0.902  0.854\n",
      "2464      5     454       HF                                lr  0.803  0.936\n",
      "2482      5     454       HF    pseudolabel-binshuffling-bound  0.780  0.917\n",
      "2480      5     454       HF   pseudolabel-binshuffling-random  0.777  0.904\n",
      "2477      5     454       HF               pseudolabel-masking  0.773  0.907\n",
      "2481      5     454       HF  pseudolabel-binshuffling-uniform  0.761  0.903\n",
      "2465      5     454       HF                               knn  0.750  0.905\n",
      "2469      5     454       HF                               mlp  0.667  0.861\n",
      "2472      5     454       HF                       meanteacher  0.655  0.782\n",
      "2470      5     454       HF                                ae  0.542  0.826\n",
      "2471      5     454       HF                               ict  0.409  0.725\n",
      "2476      5     454       HF                              vime  0.068   0.54\n",
      "=====\n",
      "      shots data_id dataname                             model    acc  auroc\n",
      "4549      5    4153       RH                               ict  0.861  0.977\n",
      "4542      5    4153       RH                                lr  0.833  0.984\n",
      "4561      5    4153       RH   pseudolabel-binshuffling-random  0.833  0.978\n",
      "4558      5    4153       RH               pseudolabel-masking  0.806  0.993\n",
      "4562      5    4153       RH  pseudolabel-binshuffling-uniform  0.806   0.97\n",
      "4563      5    4153       RH    pseudolabel-binshuffling-bound  0.806  0.973\n",
      "4547      5    4153       RH                               mlp  0.778  0.989\n",
      "4548      5    4153       RH                                ae  0.778  0.978\n",
      "4543      5    4153       RH                               knn  0.639  0.922\n",
      "4544      5    4153       RH                           xgboost  0.528  0.842\n",
      "4557      5    4153       RH                              vime  0.361  0.895\n",
      "4546      5    4153       RH                          lightgbm  0.194    0.5\n",
      "4555      5    4153       RH              sslsubtab-finetuning  0.194    0.5\n",
      "4552      5    4153       RH                      sslsubtab-lr  0.167  0.492\n",
      "4553      5    4153       RH                     sslsubtab-knn  0.167   0.48\n",
      "4554      5    4153       RH              sslsubtab-lineareval  0.139    0.5\n",
      "=====\n",
      "Empty DataFrame\n",
      "Columns: [shots, data_id, dataname, model, acc, auroc]\n",
      "Index: []\n",
      "=====\n",
      "      shots data_id dataname                             model    acc  auroc\n",
      "4567      5   40922       RW                          catboost  0.771  0.934\n",
      "4566      5   40922       RW                           xgboost  0.767  0.766\n",
      "4570      5   40922       RW                                ae  0.766    0.5\n",
      "4564      5   40922       RW                                lr  0.749  0.818\n",
      "4569      5   40922       RW                               mlp  0.685  0.799\n",
      "4571      5   40922       RW                               ict  0.629    0.5\n",
      "4565      5   40922       RW                               knn  0.577  0.859\n",
      "4568      5   40922       RW                          lightgbm  0.502    0.5\n",
      "4577      5   40922       RW               pseudolabel-masking  0.502    0.5\n",
      "4581      5   40922       RW  pseudolabel-binshuffling-uniform  0.502    0.5\n",
      "4576      5   40922       RW                              vime  0.498    0.5\n",
      "4580      5   40922       RW   pseudolabel-binshuffling-random  0.498    0.5\n",
      "4582      5   40922       RW    pseudolabel-binshuffling-bound  0.498    0.5\n",
      "=====\n",
      "      shots data_id dataname                             model    acc  auroc\n",
      "4540      5    1503      SAD  pseudolabel-binshuffling-uniform  0.103  0.505\n",
      "4522      5    1503      SAD                           xgboost  0.103  0.503\n",
      "4524      5    1503      SAD                          lightgbm  0.103  0.502\n",
      "4526      5    1503      SAD                                ae  0.102  0.501\n",
      "4520      5    1503      SAD                                lr  0.102  0.502\n",
      "4539      5    1503      SAD   pseudolabel-binshuffling-random  0.101  0.502\n",
      "4532      5    1503      SAD              sslsubtab-lineareval  0.101    0.5\n",
      "4535      5    1503      SAD                              vime  0.101  0.503\n",
      "4531      5    1503      SAD                     sslsubtab-knn  0.101    0.5\n",
      "4536      5    1503      SAD               pseudolabel-masking  0.100    0.5\n",
      "4525      5    1503      SAD                               mlp  0.100  0.501\n",
      "4530      5    1503      SAD                      sslsubtab-lr  0.100    0.5\n",
      "4521      5    1503      SAD                               knn  0.099  0.501\n",
      "4527      5    1503      SAD                               ict  0.099    0.5\n",
      "4533      5    1503      SAD              sslsubtab-finetuning  0.099    0.5\n",
      "4541      5    1503      SAD    pseudolabel-binshuffling-bound  0.098    0.5\n",
      "=====\n",
      "Empty DataFrame\n",
      "Columns: [shots, data_id, dataname, model, acc, auroc]\n",
      "Index: []\n",
      "=====\n",
      "      shots data_id dataname                             model    acc  auroc\n",
      "2319      5    4135       AZ              sslsubtab-lineareval  0.942    0.5\n",
      "2324      5    4135       AZ                sslvime-finetuning  0.942    0.5\n",
      "2323      5    4135       AZ                sslvime-lineareval  0.942    0.5\n",
      "2321      5    4135       AZ                        sslvime-lr  0.942  0.466\n",
      "2313      5    4135       AZ                       sslrecon-lr  0.942    0.5\n",
      "2314      5    4135       AZ                      sslrecon-knn  0.942    0.5\n",
      "2315      5    4135       AZ               sslrecon-lineareval  0.942    0.5\n",
      "2316      5    4135       AZ               sslrecon-finetuning  0.942    0.5\n",
      "2320      5    4135       AZ              sslsubtab-finetuning  0.942    0.5\n",
      "2331      5    4135       AZ    pseudolabel-binshuffling-bound  0.941    0.5\n",
      "2329      5    4135       AZ   pseudolabel-binshuffling-random  0.941    0.5\n",
      "2326      5    4135       AZ               pseudolabel-masking  0.941    0.5\n",
      "2330      5    4135       AZ  pseudolabel-binshuffling-uniform  0.940    0.5\n",
      "2317      5    4135       AZ                      sslsubtab-lr  0.800  0.505\n",
      "2305      5    4135       AZ                               knn  0.709   0.47\n",
      "2306      5    4135       AZ                           xgboost  0.706  0.496\n",
      "2307      5    4135       AZ                          catboost  0.662  0.509\n",
      "2322      5    4135       AZ                       sslvime-knn  0.654  0.543\n",
      "2304      5    4135       AZ                                lr  0.646  0.476\n",
      "2311      5    4135       AZ                               ict  0.639    0.5\n",
      "2310      5    4135       AZ                                ae  0.604    0.5\n",
      "2309      5    4135       AZ                               mlp  0.518  0.509\n",
      "2318      5    4135       AZ                     sslsubtab-knn  0.067  0.501\n",
      "2325      5    4135       AZ                              vime  0.058    0.5\n",
      "2308      5    4135       AZ                          lightgbm  0.058    0.5\n"
     ]
    }
   ],
   "source": [
    "pd.set_option(\"display.precision\", 3)\n",
    "# Per-dataset leaderboard of completed runs at a fixed shot count, sorted by accuracy.\n",
    "SHOTS = 5\n",
    "p = result.dropna()  # keep only runs that produced scores\n",
    "# data_info keys are strings (loaded from JSON) while data_id may be stored as int,\n",
    "# so compare on the string form to match either representation.\n",
    "p_ids = p[\"data_id\"].astype(str)\n",
    "for i in data_info.keys():\n",
    "    # Apply the dataset AND shot filters before the emptiness check; checking only the\n",
    "    # dataset printed empty frames for datasets that have no results at this shot count.\n",
    "    subset = p[(p_ids == i) & (p[\"shots\"] == SHOTS)]\n",
    "    if subset.empty:\n",
    "        continue\n",
    "    print(\"=====\")\n",
    "    print(subset.sort_values(\"acc\", ascending=False)[[\"shots\", \"data_id\", \"dataname\", \"model\", \"acc\", \"auroc\"]])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "a6324948",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 73,
   "id": "58b0b616",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(1800, 7)\n",
      "(1755,)\n"
     ]
    }
   ],
   "source": [
    "# Collect every (shot, dataset, weak, strong) configuration of ours-mlp with strong < weak <= 10.\n",
    "# Configurations without a result file are kept as rows with missing scores so coverage stays visible.\n",
    "# Rows are accumulated in a list and converted to a DataFrame once (growing a frame row by row is slow).\n",
    "records = []\n",
    "for shot in [1, 5]:\n",
    "    for data in [4538, 40499, 458, 14, 16, 22, 32, 182, 1475, 1476, 1492, 1493, 1497, 1509, 1531, 4153, 40685, 41168, 41169, 43986]:\n",
    "        info = data_info[str(data)]\n",
    "        for strong in range(1, 11):\n",
    "            for weak in range(strong + 1, 11):\n",
    "                model = f'ours-mlp/fixed_masking/BinShuffling-False/weak={weak}/strong={strong}'\n",
    "                fname = f'results/shot={shot}/data={data}/model={model}/performance.npy'\n",
    "                acc, auroc = None, None\n",
    "                if os.path.exists(fname):\n",
    "                    # allow_pickle unpickles arbitrary objects: only load trusted, locally produced result files.\n",
    "                    perf = np.load(fname, allow_pickle=True).item()\n",
    "                    acc, auroc = perf[\"Test\"][0], perf[\"Test\"][1]\n",
    "                records.append([shot, data, info[\"name\"], info[\"tasktype\"], model.split(\"BinShuffling-False/\")[-1], acc, auroc])\n",
    "result = pd.DataFrame(records, columns=[\"shots\", \"data_id\", \"dataname\", \"tasktype\", \"model\", \"acc\", \"auroc\"])\n",
    "print(result.shape)\n",
    "print(result[\"acc\"].dropna().shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 74,
   "id": "e48674a7",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "==========\n",
      "4538 [0.438]\n",
      "458 [0.776]\n",
      "40499 [0.555]\n",
      "14 [0.325]\n",
      "16 [0.47]\n",
      "22 [0.41]\n",
      "32 [0.686]\n",
      "182 [0.67]\n",
      "1475 [0.193]\n",
      "1476 [0.462]\n",
      "1492 [0.312]\n",
      "1493 [0.375]\n",
      "1497 [0.267]\n",
      "1509 [0.117]\n",
      "1531 [0.055]\n",
      "4153 [0.611]\n",
      "40685 [0.706]\n",
      "41168 [0.458]\n",
      "41169 [0.058]\n",
      "43986 [0.251]\n",
      "==========\n",
      "4538 [0.396]\n",
      "458 [0.965]\n",
      "40499 [0.88]\n",
      "14 [0.61]\n",
      "16 [0.79]\n",
      "22 [0.665]\n",
      "32 [0.757]\n",
      "182 [0.804]\n",
      "1475 [0.208]\n",
      "1476 [0.733]\n",
      "1492 [0.469]\n",
      "1493 [0.694]\n",
      "1497 [0.582]\n",
      "1509 [0.255]\n",
      "1531 [0.314]\n",
      "4153 [0.833]\n",
      "40685 [0.856]\n",
      "41168 [0.529]\n",
      "41169 [0.126]\n",
      "43986 [nan]\n"
     ]
    }
   ],
   "source": [
    "pd.set_option(\"display.precision\", 3)\n",
    "# Best test accuracy over all weak/strong configurations, per dataset and shot count ([nan] = no result file found).\n",
    "report_order = [4538, 458, 40499, 14, 16, 22, 32, 182, 1475, 1476, 1492, 1493, 1497, 1509, 1531, 4153, 40685, 41168, 41169, 43986]\n",
    "for shot in [1, 5]:\n",
    "    print(\"==========\")\n",
    "    by_shot = result[result[\"shots\"] == shot]\n",
    "    for data in report_order:\n",
    "        top = by_shot[by_shot[\"data_id\"] == data].sort_values(\"acc\", ascending=False).head(1)\n",
    "        print(data, np.round(top[\"acc\"].values, 3))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 66,
   "id": "54649b47",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>shots</th>\n",
       "      <th>data_id</th>\n",
       "      <th>dataname</th>\n",
       "      <th>tasktype</th>\n",
       "      <th>model</th>\n",
       "      <th>acc</th>\n",
       "      <th>auroc</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>810</th>\n",
       "      <td>1</td>\n",
       "      <td>41169</td>\n",
       "      <td>HE</td>\n",
       "      <td>multiclass</td>\n",
       "      <td>weak=2/strong=1</td>\n",
       "      <td>0.044</td>\n",
       "      <td>0.636</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>811</th>\n",
       "      <td>1</td>\n",
       "      <td>41169</td>\n",
       "      <td>HE</td>\n",
       "      <td>multiclass</td>\n",
       "      <td>weak=3/strong=1</td>\n",
       "      <td>0.044</td>\n",
       "      <td>0.634</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>812</th>\n",
       "      <td>1</td>\n",
       "      <td>41169</td>\n",
       "      <td>HE</td>\n",
       "      <td>multiclass</td>\n",
       "      <td>weak=4/strong=1</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>813</th>\n",
       "      <td>1</td>\n",
       "      <td>41169</td>\n",
       "      <td>HE</td>\n",
       "      <td>multiclass</td>\n",
       "      <td>weak=5/strong=1</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>814</th>\n",
       "      <td>1</td>\n",
       "      <td>41169</td>\n",
       "      <td>HE</td>\n",
       "      <td>multiclass</td>\n",
       "      <td>weak=6/strong=1</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>815</th>\n",
       "      <td>1</td>\n",
       "      <td>41169</td>\n",
       "      <td>HE</td>\n",
       "      <td>multiclass</td>\n",
       "      <td>weak=7/strong=1</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>816</th>\n",
       "      <td>1</td>\n",
       "      <td>41169</td>\n",
       "      <td>HE</td>\n",
       "      <td>multiclass</td>\n",
       "      <td>weak=8/strong=1</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>817</th>\n",
       "      <td>1</td>\n",
       "      <td>41169</td>\n",
       "      <td>HE</td>\n",
       "      <td>multiclass</td>\n",
       "      <td>weak=9/strong=1</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>818</th>\n",
       "      <td>1</td>\n",
       "      <td>41169</td>\n",
       "      <td>HE</td>\n",
       "      <td>multiclass</td>\n",
       "      <td>weak=10/strong=1</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>819</th>\n",
       "      <td>1</td>\n",
       "      <td>41169</td>\n",
       "      <td>HE</td>\n",
       "      <td>multiclass</td>\n",
       "      <td>weak=3/strong=2</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>820</th>\n",
       "      <td>1</td>\n",
       "      <td>41169</td>\n",
       "      <td>HE</td>\n",
       "      <td>multiclass</td>\n",
       "      <td>weak=4/strong=2</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>821</th>\n",
       "      <td>1</td>\n",
       "      <td>41169</td>\n",
       "      <td>HE</td>\n",
       "      <td>multiclass</td>\n",
       "      <td>weak=5/strong=2</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>822</th>\n",
       "      <td>1</td>\n",
       "      <td>41169</td>\n",
       "      <td>HE</td>\n",
       "      <td>multiclass</td>\n",
       "      <td>weak=6/strong=2</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>823</th>\n",
       "      <td>1</td>\n",
       "      <td>41169</td>\n",
       "      <td>HE</td>\n",
       "      <td>multiclass</td>\n",
       "      <td>weak=7/strong=2</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>824</th>\n",
       "      <td>1</td>\n",
       "      <td>41169</td>\n",
       "      <td>HE</td>\n",
       "      <td>multiclass</td>\n",
       "      <td>weak=8/strong=2</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>825</th>\n",
       "      <td>1</td>\n",
       "      <td>41169</td>\n",
       "      <td>HE</td>\n",
       "      <td>multiclass</td>\n",
       "      <td>weak=9/strong=2</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>826</th>\n",
       "      <td>1</td>\n",
       "      <td>41169</td>\n",
       "      <td>HE</td>\n",
       "      <td>multiclass</td>\n",
       "      <td>weak=10/strong=2</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>827</th>\n",
       "      <td>1</td>\n",
       "      <td>41169</td>\n",
       "      <td>HE</td>\n",
       "      <td>multiclass</td>\n",
       "      <td>weak=4/strong=3</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>828</th>\n",
       "      <td>1</td>\n",
       "      <td>41169</td>\n",
       "      <td>HE</td>\n",
       "      <td>multiclass</td>\n",
       "      <td>weak=5/strong=3</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>829</th>\n",
       "      <td>1</td>\n",
       "      <td>41169</td>\n",
       "      <td>HE</td>\n",
       "      <td>multiclass</td>\n",
       "      <td>weak=6/strong=3</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>830</th>\n",
       "      <td>1</td>\n",
       "      <td>41169</td>\n",
       "      <td>HE</td>\n",
       "      <td>multiclass</td>\n",
       "      <td>weak=7/strong=3</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>831</th>\n",
       "      <td>1</td>\n",
       "      <td>41169</td>\n",
       "      <td>HE</td>\n",
       "      <td>multiclass</td>\n",
       "      <td>weak=8/strong=3</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>832</th>\n",
       "      <td>1</td>\n",
       "      <td>41169</td>\n",
       "      <td>HE</td>\n",
       "      <td>multiclass</td>\n",
       "      <td>weak=9/strong=3</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>833</th>\n",
       "      <td>1</td>\n",
       "      <td>41169</td>\n",
       "      <td>HE</td>\n",
       "      <td>multiclass</td>\n",
       "      <td>weak=10/strong=3</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>834</th>\n",
       "      <td>1</td>\n",
       "      <td>41169</td>\n",
       "      <td>HE</td>\n",
       "      <td>multiclass</td>\n",
       "      <td>weak=5/strong=4</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>835</th>\n",
       "      <td>1</td>\n",
       "      <td>41169</td>\n",
       "      <td>HE</td>\n",
       "      <td>multiclass</td>\n",
       "      <td>weak=6/strong=4</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>836</th>\n",
       "      <td>1</td>\n",
       "      <td>41169</td>\n",
       "      <td>HE</td>\n",
       "      <td>multiclass</td>\n",
       "      <td>weak=7/strong=4</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>837</th>\n",
       "      <td>1</td>\n",
       "      <td>41169</td>\n",
       "      <td>HE</td>\n",
       "      <td>multiclass</td>\n",
       "      <td>weak=8/strong=4</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>838</th>\n",
       "      <td>1</td>\n",
       "      <td>41169</td>\n",
       "      <td>HE</td>\n",
       "      <td>multiclass</td>\n",
       "      <td>weak=9/strong=4</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>839</th>\n",
       "      <td>1</td>\n",
       "      <td>41169</td>\n",
       "      <td>HE</td>\n",
       "      <td>multiclass</td>\n",
       "      <td>weak=10/strong=4</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>840</th>\n",
       "      <td>1</td>\n",
       "      <td>41169</td>\n",
       "      <td>HE</td>\n",
       "      <td>multiclass</td>\n",
       "      <td>weak=6/strong=5</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>841</th>\n",
       "      <td>1</td>\n",
       "      <td>41169</td>\n",
       "      <td>HE</td>\n",
       "      <td>multiclass</td>\n",
       "      <td>weak=7/strong=5</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>842</th>\n",
       "      <td>1</td>\n",
       "      <td>41169</td>\n",
       "      <td>HE</td>\n",
       "      <td>multiclass</td>\n",
       "      <td>weak=8/strong=5</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>843</th>\n",
       "      <td>1</td>\n",
       "      <td>41169</td>\n",
       "      <td>HE</td>\n",
       "      <td>multiclass</td>\n",
       "      <td>weak=9/strong=5</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>844</th>\n",
       "      <td>1</td>\n",
       "      <td>41169</td>\n",
       "      <td>HE</td>\n",
       "      <td>multiclass</td>\n",
       "      <td>weak=10/strong=5</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>845</th>\n",
       "      <td>1</td>\n",
       "      <td>41169</td>\n",
       "      <td>HE</td>\n",
       "      <td>multiclass</td>\n",
       "      <td>weak=7/strong=6</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>846</th>\n",
       "      <td>1</td>\n",
       "      <td>41169</td>\n",
       "      <td>HE</td>\n",
       "      <td>multiclass</td>\n",
       "      <td>weak=8/strong=6</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>847</th>\n",
       "      <td>1</td>\n",
       "      <td>41169</td>\n",
       "      <td>HE</td>\n",
       "      <td>multiclass</td>\n",
       "      <td>weak=9/strong=6</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>848</th>\n",
       "      <td>1</td>\n",
       "      <td>41169</td>\n",
       "      <td>HE</td>\n",
       "      <td>multiclass</td>\n",
       "      <td>weak=10/strong=6</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>849</th>\n",
       "      <td>1</td>\n",
       "      <td>41169</td>\n",
       "      <td>HE</td>\n",
       "      <td>multiclass</td>\n",
       "      <td>weak=8/strong=7</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>850</th>\n",
       "      <td>1</td>\n",
       "      <td>41169</td>\n",
       "      <td>HE</td>\n",
       "      <td>multiclass</td>\n",
       "      <td>weak=9/strong=7</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>851</th>\n",
       "      <td>1</td>\n",
       "      <td>41169</td>\n",
       "      <td>HE</td>\n",
       "      <td>multiclass</td>\n",
       "      <td>weak=10/strong=7</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>852</th>\n",
       "      <td>1</td>\n",
       "      <td>41169</td>\n",
       "      <td>HE</td>\n",
       "      <td>multiclass</td>\n",
       "      <td>weak=9/strong=8</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>853</th>\n",
       "      <td>1</td>\n",
       "      <td>41169</td>\n",
       "      <td>HE</td>\n",
       "      <td>multiclass</td>\n",
       "      <td>weak=10/strong=8</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>854</th>\n",
       "      <td>1</td>\n",
       "      <td>41169</td>\n",
       "      <td>HE</td>\n",
       "      <td>multiclass</td>\n",
       "      <td>weak=10/strong=9</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "     shots  data_id dataname    tasktype             model    acc  auroc\n",
       "810      1    41169       HE  multiclass   weak=2/strong=1  0.044  0.636\n",
       "811      1    41169       HE  multiclass   weak=3/strong=1  0.044  0.634\n",
       "812      1    41169       HE  multiclass   weak=4/strong=1    NaN    NaN\n",
       "813      1    41169       HE  multiclass   weak=5/strong=1    NaN    NaN\n",
       "814      1    41169       HE  multiclass   weak=6/strong=1    NaN    NaN\n",
       "815      1    41169       HE  multiclass   weak=7/strong=1    NaN    NaN\n",
       "816      1    41169       HE  multiclass   weak=8/strong=1    NaN    NaN\n",
       "817      1    41169       HE  multiclass   weak=9/strong=1    NaN    NaN\n",
       "818      1    41169       HE  multiclass  weak=10/strong=1    NaN    NaN\n",
       "819      1    41169       HE  multiclass   weak=3/strong=2    NaN    NaN\n",
       "820      1    41169       HE  multiclass   weak=4/strong=2    NaN    NaN\n",
       "821      1    41169       HE  multiclass   weak=5/strong=2    NaN    NaN\n",
       "822      1    41169       HE  multiclass   weak=6/strong=2    NaN    NaN\n",
       "823      1    41169       HE  multiclass   weak=7/strong=2    NaN    NaN\n",
       "824      1    41169       HE  multiclass   weak=8/strong=2    NaN    NaN\n",
       "825      1    41169       HE  multiclass   weak=9/strong=2    NaN    NaN\n",
       "826      1    41169       HE  multiclass  weak=10/strong=2    NaN    NaN\n",
       "827      1    41169       HE  multiclass   weak=4/strong=3    NaN    NaN\n",
       "828      1    41169       HE  multiclass   weak=5/strong=3    NaN    NaN\n",
       "829      1    41169       HE  multiclass   weak=6/strong=3    NaN    NaN\n",
       "830      1    41169       HE  multiclass   weak=7/strong=3    NaN    NaN\n",
       "831      1    41169       HE  multiclass   weak=8/strong=3    NaN    NaN\n",
       "832      1    41169       HE  multiclass   weak=9/strong=3    NaN    NaN\n",
       "833      1    41169       HE  multiclass  weak=10/strong=3    NaN    NaN\n",
       "834      1    41169       HE  multiclass   weak=5/strong=4    NaN    NaN\n",
       "835      1    41169       HE  multiclass   weak=6/strong=4    NaN    NaN\n",
       "836      1    41169       HE  multiclass   weak=7/strong=4    NaN    NaN\n",
       "837      1    41169       HE  multiclass   weak=8/strong=4    NaN    NaN\n",
       "838      1    41169       HE  multiclass   weak=9/strong=4    NaN    NaN\n",
       "839      1    41169       HE  multiclass  weak=10/strong=4    NaN    NaN\n",
       "840      1    41169       HE  multiclass   weak=6/strong=5    NaN    NaN\n",
       "841      1    41169       HE  multiclass   weak=7/strong=5    NaN    NaN\n",
       "842      1    41169       HE  multiclass   weak=8/strong=5    NaN    NaN\n",
       "843      1    41169       HE  multiclass   weak=9/strong=5    NaN    NaN\n",
       "844      1    41169       HE  multiclass  weak=10/strong=5    NaN    NaN\n",
       "845      1    41169       HE  multiclass   weak=7/strong=6    NaN    NaN\n",
       "846      1    41169       HE  multiclass   weak=8/strong=6    NaN    NaN\n",
       "847      1    41169       HE  multiclass   weak=9/strong=6    NaN    NaN\n",
       "848      1    41169       HE  multiclass  weak=10/strong=6    NaN    NaN\n",
       "849      1    41169       HE  multiclass   weak=8/strong=7    NaN    NaN\n",
       "850      1    41169       HE  multiclass   weak=9/strong=7    NaN    NaN\n",
       "851      1    41169       HE  multiclass  weak=10/strong=7    NaN    NaN\n",
       "852      1    41169       HE  multiclass   weak=9/strong=8    NaN    NaN\n",
       "853      1    41169       HE  multiclass  weak=10/strong=8    NaN    NaN\n",
       "854      1    41169       HE  multiclass  weak=10/strong=9    NaN    NaN"
      ]
     },
     "execution_count": 66,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Every 1-shot weak/strong configuration for dataset 41169, including those without results (NaN).\n",
    "# To narrow to one strong level, add: & result[\"model\"].str.endswith(\"strong=7\")\n",
    "result.loc[(result[\"shots\"] == 1) & (result[\"data_id\"] == 41169)]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8b1ebc5d",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 91,
   "id": "a154cba9",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(240, 7)\n"
     ]
    }
   ],
   "source": [
    "# Collect test metrics of ours-mlp (fixed masking, BinShuffling on) on a coarse weak/strong grid.\n",
    "# All strong x weak combinations are visited (including weak <= strong); missing runs stay as NaN rows.\n",
    "datasets = [4538, 40499, 458, 14, 16, 22, 32, 182, 1475, 1476, 1492, 1493, 1497, 1509, 1531, 4153, 40685, 41168, 41169, 43986]\n",
    "model_prefix = \"ours-mlp/fixed_masking/BinShuffling-True/\"\n",
    "\n",
    "# Rows are gathered in a list and the frame is built once: growing it via result.loc[i] is quadratic\n",
    "# and can leave every column with object dtype.\n",
    "records = []\n",
    "for shot in [1, 5]:\n",
    "    for data in datasets:\n",
    "        info = data_info[str(data)]\n",
    "        for strong in [2, 4, 10]:\n",
    "            for weak in [4, 10]:\n",
    "                config = f\"weak={weak}/strong={strong}\"\n",
    "                fname = f\"results/shot={shot}/data={data}/model={model_prefix}{config}/performance.npy\"\n",
    "                acc = auroc = None\n",
    "                if os.path.exists(fname):\n",
    "                    # performance.npy holds a pickled dict (hence allow_pickle); only load result files you produced yourself.\n",
    "                    perf = np.load(fname, allow_pickle=True).item()\n",
    "                    acc, auroc = perf[\"Test\"][0], perf[\"Test\"][1]\n",
    "                records.append([shot, data, info[\"name\"], info[\"tasktype\"], config, acc, auroc])\n",
    "\n",
    "result = pd.DataFrame(records, columns=[\"shots\", \"data_id\", \"dataname\", \"tasktype\", \"model\", \"acc\", \"auroc\"])\n",
    "print(result.shape)\n",
    "print(result[\"acc\"].dropna().shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 92,
   "id": "d95353d6",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "==========\n",
      "4538 [0.438]\n",
      "458 [0.765]\n",
      "40499 [0.507]\n",
      "14 [nan]\n",
      "16 [0.435]\n",
      "22 [nan]\n",
      "32 [0.608]\n",
      "182 [0.594]\n",
      "1475 [0.149]\n",
      "1476 [nan]\n",
      "1492 [nan]\n",
      "1493 [nan]\n",
      "1497 [nan]\n",
      "1509 [nan]\n",
      "1531 [nan]\n",
      "4153 [0.667]\n",
      "40685 [nan]\n",
      "41168 [0.349]\n",
      "41169 [0.045]\n",
      "43986 [0.208]\n",
      "==========\n",
      "4538 [0.391]\n",
      "458 [0.953]\n",
      "40499 [0.844]\n",
      "14 [nan]\n",
      "16 [0.775]\n",
      "22 [nan]\n",
      "32 [0.745]\n",
      "182 [0.796]\n",
      "1475 [0.186]\n",
      "1476 [nan]\n",
      "1492 [nan]\n",
      "1493 [nan]\n",
      "1497 [nan]\n",
      "1509 [nan]\n",
      "1531 [nan]\n",
      "4153 [0.833]\n",
      "40685 [nan]\n",
      "41168 [0.443]\n",
      "41169 [0.113]\n",
      "43986 [nan]\n"
     ]
    }
   ],
   "source": [
    "pd.set_option(\"display.precision\", 3)\n",
    "# Best test accuracy over the weak/strong grid, per dataset and shot count ([nan] = no result file found).\n",
    "report_order = [4538, 458, 40499, 14, 16, 22, 32, 182, 1475, 1476, 1492, 1493, 1497, 1509, 1531, 4153, 40685, 41168, 41169, 43986]\n",
    "for shot in [1, 5]:\n",
    "    print(\"==========\")\n",
    "    by_shot = result[result[\"shots\"] == shot]\n",
    "    for data in report_order:\n",
    "        top = by_shot[by_shot[\"data_id\"] == data].sort_values(\"acc\", ascending=False).head(1)\n",
    "        print(data, np.round(top[\"acc\"].values, 3))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 78,
   "id": "d7a529fe",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>shots</th>\n",
       "      <th>data_id</th>\n",
       "      <th>dataname</th>\n",
       "      <th>tasktype</th>\n",
       "      <th>model</th>\n",
       "      <th>acc</th>\n",
       "      <th>auroc</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>12</th>\n",
       "      <td>1</td>\n",
       "      <td>458</td>\n",
       "      <td>ATH</td>\n",
       "      <td>multiclass</td>\n",
       "      <td>weak=4/strong=2</td>\n",
       "      <td>0.741</td>\n",
       "      <td>0.942</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>13</th>\n",
       "      <td>1</td>\n",
       "      <td>458</td>\n",
       "      <td>ATH</td>\n",
       "      <td>multiclass</td>\n",
       "      <td>weak=10/strong=2</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>14</th>\n",
       "      <td>1</td>\n",
       "      <td>458</td>\n",
       "      <td>ATH</td>\n",
       "      <td>multiclass</td>\n",
       "      <td>weak=4/strong=4</td>\n",
       "      <td>0.765</td>\n",
       "      <td>0.950</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>15</th>\n",
       "      <td>1</td>\n",
       "      <td>458</td>\n",
       "      <td>ATH</td>\n",
       "      <td>multiclass</td>\n",
       "      <td>weak=10/strong=4</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>16</th>\n",
       "      <td>1</td>\n",
       "      <td>458</td>\n",
       "      <td>ATH</td>\n",
       "      <td>multiclass</td>\n",
       "      <td>weak=4/strong=10</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>17</th>\n",
       "      <td>1</td>\n",
       "      <td>458</td>\n",
       "      <td>ATH</td>\n",
       "      <td>multiclass</td>\n",
       "      <td>weak=10/strong=10</td>\n",
       "      <td>0.753</td>\n",
       "      <td>0.959</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "    shots  data_id dataname    tasktype              model    acc  auroc\n",
       "12      1      458      ATH  multiclass    weak=4/strong=2  0.741  0.942\n",
       "13      1      458      ATH  multiclass   weak=10/strong=2    NaN    NaN\n",
       "14      1      458      ATH  multiclass    weak=4/strong=4  0.765  0.950\n",
       "15      1      458      ATH  multiclass   weak=10/strong=4    NaN    NaN\n",
       "16      1      458      ATH  multiclass   weak=4/strong=10    NaN    NaN\n",
       "17      1      458      ATH  multiclass  weak=10/strong=10  0.753  0.959"
      ]
     },
     "execution_count": 78,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Every 1-shot weak/strong configuration for dataset 458, including those without results (NaN).\n",
    "# To narrow to one strong level, add: & result[\"model\"].str.endswith(\"strong=7\")\n",
    "result.loc[(result[\"shots\"] == 1) & (result[\"data_id\"] == 458)]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "da5eb435",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3a3cb537",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "045f3905",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 351,
   "id": "c3eaab98",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[]"
      ]
     },
     "execution_count": 351,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "s = 5\n",
    "p = result[result[\"shots\"] == s].groupby(\"data_id\").size().reset_index(name=\"counts\")\n",
    "# Datasets with exactly 22 result rows at this shot level -- presumably the complete model set; TODO confirm the expected count.\n",
    "done = p[p[\"counts\"] == 22][\"data_id\"].tolist()\n",
    "fails = [1467, 42664, 444, 466, 53, 51, 49, 48, 470, 454, 452, 42931, 42665, 42, 4153, 40981, 40900, 338, 337, 334, 29, 23381, 188, 185, 1510, \n",
    "         1462, 11, 1464, 1493, 35, 44124, 455, 475, 1549, 40496, 1489, 37, 44125, 469, 44123, 1497, 1555, 31, 934, 846, 45714, 4534, 45060, 1067,\n",
    "         44161, 44158, 44131, 44157, 44126, 44122, 44091, 44090, 41150, 4135, 42734, 44089, 40685, 40985, 3, 18, 151, 1531, 1466, 1487, 1486, 1479, 1475, 1476,\n",
    "         41169, 45548, 23, 54, 50, 307, 14, 6, 46, 20, 1459, 1169, 40922, 41147, 41162, 41168, 42345\n",
    "        ]\n",
    "okay = [1063, 1492, 458, 1504, 4538, 45545, 40499, 45068, 45062, 12, 41027, 32, 36, 22, 16, 182, 1509, 1494, 1471, 1503]\n",
    "\n",
    "# data_id may be an int or a numeric string depending on how `result` was built. int() accepts both,\n",
    "# whereas eval() raises TypeError on ints and would execute arbitrary text on strings.\n",
    "excluded = set(fails) | set(okay)\n",
    "done = [i for i in done if int(i) not in excluded]\n",
    "done"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 350,
   "id": "98c800e9",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>shots</th>\n",
       "      <th>data_id</th>\n",
       "      <th>dataname</th>\n",
       "      <th>tasktype</th>\n",
       "      <th>model</th>\n",
       "      <th>acc</th>\n",
       "      <th>auroc</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>1517</th>\n",
       "      <td>5</td>\n",
       "      <td>42345</td>\n",
       "      <td>TV</td>\n",
       "      <td>multiclass</td>\n",
       "      <td>xgboost</td>\n",
       "      <td>0.543738</td>\n",
       "      <td>0.720493</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1521</th>\n",
       "      <td>5</td>\n",
       "      <td>42345</td>\n",
       "      <td>TV</td>\n",
       "      <td>multiclass</td>\n",
       "      <td>ssl-vime-mlp/shuffling=0.5</td>\n",
       "      <td>0.523562</td>\n",
       "      <td>0.715359</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1523</th>\n",
       "      <td>5</td>\n",
       "      <td>42345</td>\n",
       "      <td>TV</td>\n",
       "      <td>multiclass</td>\n",
       "      <td>ssl-mse_recon-mlp/shuffling=0.5</td>\n",
       "      <td>0.520680</td>\n",
       "      <td>0.687505</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1518</th>\n",
       "      <td>5</td>\n",
       "      <td>42345</td>\n",
       "      <td>TV</td>\n",
       "      <td>multiclass</td>\n",
       "      <td>catboost</td>\n",
       "      <td>0.512466</td>\n",
       "      <td>0.708093</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1527</th>\n",
       "      <td>5</td>\n",
       "      <td>42345</td>\n",
       "      <td>TV</td>\n",
       "      <td>multiclass</td>\n",
       "      <td>ssl-mse_binning-mlp/shuffling=0.5</td>\n",
       "      <td>0.510016</td>\n",
       "      <td>0.730845</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1532</th>\n",
       "      <td>5</td>\n",
       "      <td>42345</td>\n",
       "      <td>TV</td>\n",
       "      <td>multiclass</td>\n",
       "      <td>ae-mlp/unsup_weight=1/shuffling=0.5</td>\n",
       "      <td>0.504107</td>\n",
       "      <td>0.727344</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1522</th>\n",
       "      <td>5</td>\n",
       "      <td>42345</td>\n",
       "      <td>TV</td>\n",
       "      <td>multiclass</td>\n",
       "      <td>ssl-mse_recon-mlp/masking=0.5</td>\n",
       "      <td>0.498631</td>\n",
       "      <td>0.642196</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1525</th>\n",
       "      <td>5</td>\n",
       "      <td>42345</td>\n",
       "      <td>TV</td>\n",
       "      <td>multiclass</td>\n",
       "      <td>ssl-mse_binning-mlp/None=0.5</td>\n",
       "      <td>0.498343</td>\n",
       "      <td>0.638474</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1524</th>\n",
       "      <td>5</td>\n",
       "      <td>42345</td>\n",
       "      <td>TV</td>\n",
       "      <td>multiclass</td>\n",
       "      <td>ssl-mse_recon-mlp/randquant=0.5</td>\n",
       "      <td>0.491569</td>\n",
       "      <td>0.606321</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1526</th>\n",
       "      <td>5</td>\n",
       "      <td>42345</td>\n",
       "      <td>TV</td>\n",
       "      <td>multiclass</td>\n",
       "      <td>ssl-mse_binning-mlp/masking=0.5</td>\n",
       "      <td>0.475573</td>\n",
       "      <td>0.702536</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1529</th>\n",
       "      <td>5</td>\n",
       "      <td>42345</td>\n",
       "      <td>TV</td>\n",
       "      <td>multiclass</td>\n",
       "      <td>meanteacher-mlp/alpha=0.99</td>\n",
       "      <td>0.474276</td>\n",
       "      <td>0.683802</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1534</th>\n",
       "      <td>5</td>\n",
       "      <td>42345</td>\n",
       "      <td>TV</td>\n",
       "      <td>multiclass</td>\n",
       "      <td>ours-mlp/shuffling-False</td>\n",
       "      <td>0.470241</td>\n",
       "      <td>0.644209</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1528</th>\n",
       "      <td>5</td>\n",
       "      <td>42345</td>\n",
       "      <td>TV</td>\n",
       "      <td>multiclass</td>\n",
       "      <td>vime-mlp/masking=0.4</td>\n",
       "      <td>0.469808</td>\n",
       "      <td>0.671766</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1520</th>\n",
       "      <td>5</td>\n",
       "      <td>42345</td>\n",
       "      <td>TV</td>\n",
       "      <td>multiclass</td>\n",
       "      <td>mlp</td>\n",
       "      <td>0.459720</td>\n",
       "      <td>0.683607</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1519</th>\n",
       "      <td>5</td>\n",
       "      <td>42345</td>\n",
       "      <td>TV</td>\n",
       "      <td>multiclass</td>\n",
       "      <td>lightgbm</td>\n",
       "      <td>0.455109</td>\n",
       "      <td>0.500000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1530</th>\n",
       "      <td>5</td>\n",
       "      <td>42345</td>\n",
       "      <td>TV</td>\n",
       "      <td>multiclass</td>\n",
       "      <td>ICT-mlp/alpha=0.99</td>\n",
       "      <td>0.430033</td>\n",
       "      <td>0.665726</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1536</th>\n",
       "      <td>5</td>\n",
       "      <td>42345</td>\n",
       "      <td>TV</td>\n",
       "      <td>multiclass</td>\n",
       "      <td>ours-mlp/BinShuffling-True</td>\n",
       "      <td>0.428592</td>\n",
       "      <td>0.601323</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1515</th>\n",
       "      <td>5</td>\n",
       "      <td>42345</td>\n",
       "      <td>TV</td>\n",
       "      <td>multiclass</td>\n",
       "      <td>lr</td>\n",
       "      <td>0.416054</td>\n",
       "      <td>0.677512</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1531</th>\n",
       "      <td>5</td>\n",
       "      <td>42345</td>\n",
       "      <td>TV</td>\n",
       "      <td>multiclass</td>\n",
       "      <td>ae-mlp/unsup_weight=1/masking=0.5</td>\n",
       "      <td>0.385358</td>\n",
       "      <td>0.682136</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1533</th>\n",
       "      <td>5</td>\n",
       "      <td>42345</td>\n",
       "      <td>TV</td>\n",
       "      <td>multiclass</td>\n",
       "      <td>ours-mlp/masking-False</td>\n",
       "      <td>0.356247</td>\n",
       "      <td>0.684690</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1535</th>\n",
       "      <td>5</td>\n",
       "      <td>42345</td>\n",
       "      <td>TV</td>\n",
       "      <td>multiclass</td>\n",
       "      <td>ours-mlp/BinShuffling-False</td>\n",
       "      <td>0.354806</td>\n",
       "      <td>0.637069</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1516</th>\n",
       "      <td>5</td>\n",
       "      <td>42345</td>\n",
       "      <td>TV</td>\n",
       "      <td>multiclass</td>\n",
       "      <td>knn</td>\n",
       "      <td>0.230004</td>\n",
       "      <td>0.581703</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "      shots data_id dataname    tasktype                                model  \\\n",
       "1517      5   42345       TV  multiclass                              xgboost   \n",
       "1521      5   42345       TV  multiclass           ssl-vime-mlp/shuffling=0.5   \n",
       "1523      5   42345       TV  multiclass      ssl-mse_recon-mlp/shuffling=0.5   \n",
       "1518      5   42345       TV  multiclass                             catboost   \n",
       "1527      5   42345       TV  multiclass    ssl-mse_binning-mlp/shuffling=0.5   \n",
       "1532      5   42345       TV  multiclass  ae-mlp/unsup_weight=1/shuffling=0.5   \n",
       "1522      5   42345       TV  multiclass        ssl-mse_recon-mlp/masking=0.5   \n",
       "1525      5   42345       TV  multiclass         ssl-mse_binning-mlp/None=0.5   \n",
       "1524      5   42345       TV  multiclass      ssl-mse_recon-mlp/randquant=0.5   \n",
       "1526      5   42345       TV  multiclass      ssl-mse_binning-mlp/masking=0.5   \n",
       "1529      5   42345       TV  multiclass           meanteacher-mlp/alpha=0.99   \n",
       "1534      5   42345       TV  multiclass             ours-mlp/shuffling-False   \n",
       "1528      5   42345       TV  multiclass                 vime-mlp/masking=0.4   \n",
       "1520      5   42345       TV  multiclass                                  mlp   \n",
       "1519      5   42345       TV  multiclass                             lightgbm   \n",
       "1530      5   42345       TV  multiclass                   ICT-mlp/alpha=0.99   \n",
       "1536      5   42345       TV  multiclass           ours-mlp/BinShuffling-True   \n",
       "1515      5   42345       TV  multiclass                                   lr   \n",
       "1531      5   42345       TV  multiclass    ae-mlp/unsup_weight=1/masking=0.5   \n",
       "1533      5   42345       TV  multiclass               ours-mlp/masking-False   \n",
       "1535      5   42345       TV  multiclass          ours-mlp/BinShuffling-False   \n",
       "1516      5   42345       TV  multiclass                                  knn   \n",
       "\n",
       "           acc     auroc  \n",
       "1517  0.543738  0.720493  \n",
       "1521  0.523562  0.715359  \n",
       "1523  0.520680  0.687505  \n",
       "1518  0.512466  0.708093  \n",
       "1527  0.510016  0.730845  \n",
       "1532  0.504107  0.727344  \n",
       "1522  0.498631  0.642196  \n",
       "1525  0.498343  0.638474  \n",
       "1524  0.491569  0.606321  \n",
       "1526  0.475573  0.702536  \n",
       "1529  0.474276  0.683802  \n",
       "1534  0.470241  0.644209  \n",
       "1528  0.469808  0.671766  \n",
       "1520  0.459720  0.683607  \n",
       "1519  0.455109  0.500000  \n",
       "1530  0.430033  0.665726  \n",
       "1536  0.428592  0.601323  \n",
       "1515  0.416054  0.677512  \n",
       "1531  0.385358  0.682136  \n",
       "1533  0.356247  0.684690  \n",
       "1535  0.354806  0.637069  \n",
       "1516  0.230004  0.581703  "
      ]
     },
     "execution_count": 350,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "data_id = \"42345\"\n",
    "result[(result[\"shots\"] == 5) & (result[\"data_id\"] == data_id)].sort_values(\"acc\", ascending=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 353,
   "id": "3d8a0c93",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array({'Train': (1.0, 1.0), 'Val': (0.3556231003039514, 0.6577275728374801), 'Test': (0.3441295546558704, 0.6547103996827062)},\n",
       "      dtype=object)"
      ]
     },
     "execution_count": 353,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "np.load(f'/home/tabsemi/results/shot=1/data=4538/model=test/performance.npy', allow_pickle=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e4d41288",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7b95945b",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "85fd27ed",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "cc0420c3",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>shots</th>\n",
       "      <th>data</th>\n",
       "      <th>model</th>\n",
       "      <th>acc</th>\n",
       "      <th>auroc</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>1</td>\n",
       "      <td>35</td>\n",
       "      <td>lr</td>\n",
       "      <td>0.972222</td>\n",
       "      <td>0.999074</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>1</td>\n",
       "      <td>48</td>\n",
       "      <td>lr</td>\n",
       "      <td>0.437500</td>\n",
       "      <td>0.589989</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>1</td>\n",
       "      <td>49</td>\n",
       "      <td>catboost</td>\n",
       "      <td>0.900000</td>\n",
       "      <td>0.844498</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>1</td>\n",
       "      <td>51</td>\n",
       "      <td>lr</td>\n",
       "      <td>0.740741</td>\n",
       "      <td>0.811765</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>1</td>\n",
       "      <td>53</td>\n",
       "      <td>knn</td>\n",
       "      <td>0.370370</td>\n",
       "      <td>0.500000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>5</th>\n",
       "      <td>1</td>\n",
       "      <td>59</td>\n",
       "      <td>catboost</td>\n",
       "      <td>0.638889</td>\n",
       "      <td>0.389610</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>6</th>\n",
       "      <td>1</td>\n",
       "      <td>337</td>\n",
       "      <td>lr</td>\n",
       "      <td>0.571429</td>\n",
       "      <td>0.799145</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>7</th>\n",
       "      <td>1</td>\n",
       "      <td>338</td>\n",
       "      <td>lr</td>\n",
       "      <td>0.312500</td>\n",
       "      <td>0.606227</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>8</th>\n",
       "      <td>1</td>\n",
       "      <td>444</td>\n",
       "      <td>knn</td>\n",
       "      <td>0.714286</td>\n",
       "      <td>0.500000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>9</th>\n",
       "      <td>1</td>\n",
       "      <td>452</td>\n",
       "      <td>lr</td>\n",
       "      <td>0.370370</td>\n",
       "      <td>0.770002</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>10</th>\n",
       "      <td>1</td>\n",
       "      <td>455</td>\n",
       "      <td>knn</td>\n",
       "      <td>0.700000</td>\n",
       "      <td>0.500000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>11</th>\n",
       "      <td>1</td>\n",
       "      <td>461</td>\n",
       "      <td>lr</td>\n",
       "      <td>1.000000</td>\n",
       "      <td>1.000000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>12</th>\n",
       "      <td>1</td>\n",
       "      <td>466</td>\n",
       "      <td>knn</td>\n",
       "      <td>0.416667</td>\n",
       "      <td>0.500000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>13</th>\n",
       "      <td>1</td>\n",
       "      <td>475</td>\n",
       "      <td>catboost</td>\n",
       "      <td>0.475000</td>\n",
       "      <td>0.578153</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>14</th>\n",
       "      <td>1</td>\n",
       "      <td>1063</td>\n",
       "      <td>lr</td>\n",
       "      <td>0.867925</td>\n",
       "      <td>0.781395</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>15</th>\n",
       "      <td>1</td>\n",
       "      <td>4153</td>\n",
       "      <td>lr</td>\n",
       "      <td>0.555556</td>\n",
       "      <td>NaN</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>16</th>\n",
       "      <td>1</td>\n",
       "      <td>23381</td>\n",
       "      <td>knn</td>\n",
       "      <td>0.777778</td>\n",
       "      <td>0.500000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>17</th>\n",
       "      <td>1</td>\n",
       "      <td>40496</td>\n",
       "      <td>lr</td>\n",
       "      <td>0.580000</td>\n",
       "      <td>0.879573</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>18</th>\n",
       "      <td>1</td>\n",
       "      <td>42665</td>\n",
       "      <td>catboost</td>\n",
       "      <td>0.833333</td>\n",
       "      <td>0.750000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>19</th>\n",
       "      <td>5</td>\n",
       "      <td>35</td>\n",
       "      <td>catboost</td>\n",
       "      <td>1.000000</td>\n",
       "      <td>1.000000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>20</th>\n",
       "      <td>5</td>\n",
       "      <td>48</td>\n",
       "      <td>catboost</td>\n",
       "      <td>0.562500</td>\n",
       "      <td>0.496958</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>21</th>\n",
       "      <td>5</td>\n",
       "      <td>49</td>\n",
       "      <td>xgboost</td>\n",
       "      <td>0.733333</td>\n",
       "      <td>0.732057</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>22</th>\n",
       "      <td>5</td>\n",
       "      <td>51</td>\n",
       "      <td>lr</td>\n",
       "      <td>0.740741</td>\n",
       "      <td>0.682353</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>23</th>\n",
       "      <td>5</td>\n",
       "      <td>53</td>\n",
       "      <td>xgboost</td>\n",
       "      <td>0.888889</td>\n",
       "      <td>0.955882</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>24</th>\n",
       "      <td>5</td>\n",
       "      <td>59</td>\n",
       "      <td>catboost</td>\n",
       "      <td>0.861111</td>\n",
       "      <td>0.931818</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>25</th>\n",
       "      <td>5</td>\n",
       "      <td>337</td>\n",
       "      <td>lr</td>\n",
       "      <td>0.714286</td>\n",
       "      <td>0.794872</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>26</th>\n",
       "      <td>5</td>\n",
       "      <td>338</td>\n",
       "      <td>lr</td>\n",
       "      <td>0.625000</td>\n",
       "      <td>0.741453</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>27</th>\n",
       "      <td>5</td>\n",
       "      <td>444</td>\n",
       "      <td>lr</td>\n",
       "      <td>0.714286</td>\n",
       "      <td>0.875000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>28</th>\n",
       "      <td>5</td>\n",
       "      <td>452</td>\n",
       "      <td>lightgbm</td>\n",
       "      <td>0.296296</td>\n",
       "      <td>0.500000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>29</th>\n",
       "      <td>5</td>\n",
       "      <td>455</td>\n",
       "      <td>lightgbm</td>\n",
       "      <td>0.700000</td>\n",
       "      <td>0.500000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>30</th>\n",
       "      <td>5</td>\n",
       "      <td>461</td>\n",
       "      <td>lr</td>\n",
       "      <td>0.900000</td>\n",
       "      <td>0.888889</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>31</th>\n",
       "      <td>5</td>\n",
       "      <td>466</td>\n",
       "      <td>lr</td>\n",
       "      <td>0.583333</td>\n",
       "      <td>0.685714</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>32</th>\n",
       "      <td>5</td>\n",
       "      <td>475</td>\n",
       "      <td>catboost</td>\n",
       "      <td>0.300000</td>\n",
       "      <td>0.567094</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>33</th>\n",
       "      <td>5</td>\n",
       "      <td>1063</td>\n",
       "      <td>catboost</td>\n",
       "      <td>0.830189</td>\n",
       "      <td>0.716279</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>34</th>\n",
       "      <td>5</td>\n",
       "      <td>4153</td>\n",
       "      <td>knn</td>\n",
       "      <td>0.888889</td>\n",
       "      <td>NaN</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>35</th>\n",
       "      <td>5</td>\n",
       "      <td>23381</td>\n",
       "      <td>lightgbm</td>\n",
       "      <td>0.777778</td>\n",
       "      <td>0.500000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>36</th>\n",
       "      <td>5</td>\n",
       "      <td>40496</td>\n",
       "      <td>knn</td>\n",
       "      <td>0.660000</td>\n",
       "      <td>0.848951</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>37</th>\n",
       "      <td>5</td>\n",
       "      <td>42665</td>\n",
       "      <td>catboost</td>\n",
       "      <td>0.916667</td>\n",
       "      <td>0.968750</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "    shots   data     model       acc     auroc\n",
       "0       1     35        lr  0.972222  0.999074\n",
       "1       1     48        lr  0.437500  0.589989\n",
       "2       1     49  catboost  0.900000  0.844498\n",
       "3       1     51        lr  0.740741  0.811765\n",
       "4       1     53       knn  0.370370  0.500000\n",
       "5       1     59  catboost  0.638889  0.389610\n",
       "6       1    337        lr  0.571429  0.799145\n",
       "7       1    338        lr  0.312500  0.606227\n",
       "8       1    444       knn  0.714286  0.500000\n",
       "9       1    452        lr  0.370370  0.770002\n",
       "10      1    455       knn  0.700000  0.500000\n",
       "11      1    461        lr  1.000000  1.000000\n",
       "12      1    466       knn  0.416667  0.500000\n",
       "13      1    475  catboost  0.475000  0.578153\n",
       "14      1   1063        lr  0.867925  0.781395\n",
       "15      1   4153        lr  0.555556       NaN\n",
       "16      1  23381       knn  0.777778  0.500000\n",
       "17      1  40496        lr  0.580000  0.879573\n",
       "18      1  42665  catboost  0.833333  0.750000\n",
       "19      5     35  catboost  1.000000  1.000000\n",
       "20      5     48  catboost  0.562500  0.496958\n",
       "21      5     49   xgboost  0.733333  0.732057\n",
       "22      5     51        lr  0.740741  0.682353\n",
       "23      5     53   xgboost  0.888889  0.955882\n",
       "24      5     59  catboost  0.861111  0.931818\n",
       "25      5    337        lr  0.714286  0.794872\n",
       "26      5    338        lr  0.625000  0.741453\n",
       "27      5    444        lr  0.714286  0.875000\n",
       "28      5    452  lightgbm  0.296296  0.500000\n",
       "29      5    455  lightgbm  0.700000  0.500000\n",
       "30      5    461        lr  0.900000  0.888889\n",
       "31      5    466        lr  0.583333  0.685714\n",
       "32      5    475  catboost  0.300000  0.567094\n",
       "33      5   1063  catboost  0.830189  0.716279\n",
       "34      5   4153       knn  0.888889       NaN\n",
       "35      5  23381  lightgbm  0.777778  0.500000\n",
       "36      5  40496       knn  0.660000  0.848951\n",
       "37      5  42665  catboost  0.916667  0.968750"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "idx = result.groupby([\"shots\", \"data\"])[\"acc\"].idxmax()\n",
    "result.loc[idx].sort_values(by=[\"shots\", \"data\"]).reset_index(drop=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9a518106",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "16e6decd",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
