{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "from minio_obj_storage import get_numpy_from_cloud\n",
    "import pandas as pd\n",
    "import logging\n",
    "import json\n",
    "from utils.load_dataset import load_dataset\n",
    "from sklearn.metrics import roc_auc_score, precision_recall_curve, auc, roc_curve\n",
    "import numpy as np\n",
    "\n",
    "# Notebook-wide logger, looked up under the name 'Default'.\n",
    "logger = logging.getLogger('Default')\n",
    "# JSON text is UTF-8, so decode it explicitly instead of relying on the locale default.\n",
    "with open(\"./config.json\", 'r', encoding='utf-8') as f:\n",
    "    config = json.load(f)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Flat result columns shared with the following cells: position i of every list\n",
    "# describes one (seed, dataset, metric) AUROC measurement.\n",
    "results_seeds = []\n",
    "results_dataset = []\n",
    "results_metric = []\n",
    "results_auroc = []\n",
    "\n",
    "n_epochs = 200  # one score file is fetched per epoch index 0..n_epochs-1\n",
    "# (score name, file-name suffix) of the per-epoch arrays fetched from the score bucket.\n",
    "score_files = [\n",
    "    ('loss', \"\"),\n",
    "    ('loss_grad', \"\"),\n",
    "    ('loss_curvature', \"_h_0.001_n_10\"),\n",
    "    ('correct', \"\"),\n",
    "]\n",
    "\n",
    "for seed in [1,2,3]:\n",
    "    for dataset_name in ['cifar10_duplicate', 'cifar100_duplicate']:\n",
    "        indices_map = get_numpy_from_cloud('learning-dynamics-models', dataset_name, 'duplicate_index_map.npy').item()\n",
    "        # Every map entry has the form idx1 -> (idx2, _); both indices are labelled as duplicates.\n",
    "        duplicate_idxs = []\n",
    "        for idx1, (idx2, _) in indices_map.items():\n",
    "            duplicate_idxs.extend((idx1, idx2))\n",
    "\n",
    "        # Stack the per-epoch arrays of each score so that axis 0 runs over the epochs.\n",
    "        scores = {}\n",
    "        for metric, suffix in score_files:\n",
    "            scores[metric] = np.array([\n",
    "                get_numpy_from_cloud(\n",
    "                    'learning-dynamics-scores',\n",
    "                    dataset_name,\n",
    "                    f\"{metric}_{dataset_name}_resnet18_seed_{seed}_epoch_{epoch}{suffix}.npy\")\n",
    "                for epoch in range(n_epochs)\n",
    "            ])\n",
    "        losses = scores['loss']\n",
    "        loss_grads = scores['loss_grad']\n",
    "        loss_curvature = scores['loss_curvature']\n",
    "        lt = scores['correct']\n",
    "\n",
    "        # Binary target with one entry per column of the stacked score arrays.\n",
    "        is_duplicate = np.zeros((losses.shape[1]))\n",
    "        is_duplicate[duplicate_idxs] = 1\n",
    "\n",
    "        # np.argmax returns the first epoch holding the maximal 'correct' value, but it also\n",
    "        # returns 0 for a sample that is never correct. Such samples get the epoch count\n",
    "        # instead, so they rank as learned last rather than first.\n",
    "        first_correct_epoch = np.where(lt.any(0), np.argmax(lt, 0), lt.shape[0])\n",
    "\n",
    "        results_df = pd.DataFrame(\n",
    "            data={\n",
    "                'loss': losses.mean(0),\n",
    "                'loss_grad': loss_grads.mean(0),\n",
    "                'loss_curvature': loss_curvature.mean(0),\n",
    "                'clt': 1 - lt.mean(0),\n",
    "                'lt': first_correct_epoch,\n",
    "                'is_duplicate': is_duplicate\n",
    "            }\n",
    "        )\n",
    "\n",
    "        # One AUROC per score column, using the duplicate flag as the positive class.\n",
    "        for metric in ['loss', 'loss_grad', 'loss_curvature', 'clt', 'lt']:\n",
    "            y_true = results_df['is_duplicate']\n",
    "            y_scores = results_df[metric]\n",
    "            auroc = roc_auc_score(y_true, y_scores)\n",
    "            results_seeds.append(seed)\n",
    "            results_dataset.append(dataset_name)\n",
    "            results_metric.append(metric)\n",
    "            results_auroc.append(auroc)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "from cleanlab.filter import find_label_issues\n",
    "\n",
    "# Appends two AUROC rows per (seed, dataset) to the shared result lists:\n",
    "# 'cl' (rank-based score from find_label_issues) and 'in conf.' (1 - max probability).\n",
    "for seed in [1,2,3]:\n",
    "    for dataset_name in ['cifar10_duplicate', 'cifar100_duplicate']:\n",
    "        container_name = 'learning-dynamics-scores'\n",
    "        noisy_dir = f\"{dataset_name}_noisy\"\n",
    "        indices_map = get_numpy_from_cloud('learning-dynamics-models', dataset_name, 'duplicate_index_map.npy').item()\n",
    "        # Every map entry has the form idx1 -> (idx2, _); both indices are labelled as duplicates.\n",
    "        duplicate_idxs = []\n",
    "        for idx1, (idx2, _) in indices_map.items():\n",
    "            duplicate_idxs.extend((idx1, idx2))\n",
    "\n",
    "        conf_learning_labels = get_numpy_from_cloud(\n",
    "            container_name,\n",
    "            noisy_dir,\n",
    "            f\"duplicate_conf_learning_labels_noise_idx_{seed}_noise_0.0.pt\"\n",
    "        ).astype(np.int32)\n",
    "\n",
    "        is_duplicate = np.zeros((len(conf_learning_labels)))\n",
    "        is_duplicate[duplicate_idxs] = 1\n",
    "\n",
    "        prob_file_name = f\"duplicate_conf_learning_prob_noise_idx_{seed}_noise_0.0.pt\"\n",
    "        prob_4_eph = get_numpy_from_cloud(container_name, noisy_dir, prob_file_name)\n",
    "\n",
    "        conf_learning = find_label_issues(\n",
    "            conf_learning_labels,\n",
    "            prob_4_eph,\n",
    "            return_indices_ranked_by=\"self_confidence\",\n",
    "            filter_by=\"confident_learning\",\n",
    "        )\n",
    "\n",
    "        # Highest probability of each row; reused below for the 'in conf.' score.\n",
    "        conf = prob_4_eph.max(axis=1)\n",
    "\n",
    "        # Turn the returned ranking into a score: the first returned index gets 1.0, the last\n",
    "        # one gets 1/len(conf_learning), and every index that was not returned keeps 0.\n",
    "        conf_learning_soft = np.zeros(len(conf_learning_labels))\n",
    "        n_flagged = len(conf_learning)\n",
    "        if n_flagged:\n",
    "            conf_learning_soft[conf_learning[::-1]] = np.arange(1, n_flagged + 1) / n_flagged\n",
    "\n",
    "        y_true = is_duplicate\n",
    "        y_scores = conf_learning_soft.astype(np.float32)\n",
    "        auroc = roc_auc_score(y_true, y_scores)\n",
    "        results_seeds.append(seed)\n",
    "        results_dataset.append(dataset_name)\n",
    "        results_metric.append('cl')\n",
    "        results_auroc.append(auroc)\n",
    "\n",
    "        # A low maximum probability is scored as more likely to be a duplicate.\n",
    "        y_scores = 1 - conf\n",
    "        auroc = roc_auc_score(y_true, y_scores)\n",
    "        results_seeds.append(seed)\n",
    "        results_dataset.append(dataset_name)\n",
    "        results_metric.append('in conf.')\n",
    "        results_auroc.append(auroc)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "from minio_obj_storage import get_numpy_from_cloud\n",
    "import torch\n",
    "\n",
    "bucket_name = 'learning-dynamics-scores'\n",
    "# Only the length of the stored index is used; samples are addressed by position 0..n-1.\n",
    "# NOTE(review): the CIFAR-100 index length is reused for the CIFAR-10 run as well -\n",
    "# confirm that both datasets have the same number of samples.\n",
    "index = torch.load(\"./index/data_index_cifar100.pt\")\n",
    "index = np.arange(len(index))\n",
    "label_noise = 0.0\n",
    "epochs_total = 200\n",
    "half = len(index) // 2\n",
    "\n",
    "# Appends one 'ssft' AUROC row per (dataset, seed) to the shared result lists.\n",
    "for dataset in ['cifar100', 'cifar10']:\n",
    "    container_dir = dataset + '_duplicate'\n",
    "\n",
    "    indices_map = get_numpy_from_cloud('learning-dynamics-models', container_dir, 'duplicate_index_map.npy').item()\n",
    "    # Every map entry has the form idx1 -> (idx2, _); both indices are labelled as duplicates.\n",
    "    duplicate_idxs = []\n",
    "    for idx1, (idx2, _) in indices_map.items():\n",
    "        duplicate_idxs.extend((idx1, idx2))\n",
    "\n",
    "    is_duplicate = np.zeros((len(index)))\n",
    "    is_duplicate[duplicate_idxs] = 1\n",
    "    for seed in [1,2,3]:\n",
    "        preds = np.zeros((epochs_total, len(index)))\n",
    "        # Part 0 supplies the columns of the first half of the samples, part 1 the second half.\n",
    "        for part, part_index in enumerate((index[:half], index[half:])):\n",
    "            pred = get_numpy_from_cloud(bucket_name, container_dir, f\"ssft_pred_resnet18_part_{part}_noisy_idx_{seed}_noise_{label_noise}.npy\")\n",
    "            # Only the rows from epochs_total onwards are kept; the assignment requires that\n",
    "            # exactly epochs_total such rows exist in the stored array.\n",
    "            preds[:, part_index] = pred[epochs_total:, part_index]\n",
    "\n",
    "        # One row per sample after the transpose; ft is the per-sample mean over the kept rows.\n",
    "        preds = preds.T\n",
    "        ft = np.mean(preds, axis=1)\n",
    "\n",
    "        auroc = roc_auc_score(is_duplicate, 1-ft)\n",
    "        results_seeds.append(seed)\n",
    "        results_auroc.append(auroc)\n",
    "        results_metric.append('ssft')\n",
    "        results_dataset.append(container_dir)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Long-format table: one row per (seed, metric, dataset) AUROC measurement collected above.\n",
    "result_columns = {\n",
    "    'seed': results_seeds,\n",
    "    'metric': results_metric,\n",
    "    'dataset': results_dataset,\n",
    "    'auroc': results_auroc,\n",
    "}\n",
    "all_results = pd.DataFrame(result_columns)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th>seed</th>\n",
       "      <th>auroc</th>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>dataset</th>\n",
       "      <th>metric</th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th rowspan=\"8\" valign=\"top\">cifar100_duplicate</th>\n",
       "      <th>cl</th>\n",
       "      <td>2.0</td>\n",
       "      <td>0.587337</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>clt</th>\n",
       "      <td>2.0</td>\n",
       "      <td>0.986984</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>in conf.</th>\n",
       "      <td>2.0</td>\n",
       "      <td>0.862300</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>loss</th>\n",
       "      <td>2.0</td>\n",
       "      <td>0.988604</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>loss_curvature</th>\n",
       "      <td>2.0</td>\n",
       "      <td>0.963915</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>loss_grad</th>\n",
       "      <td>2.0</td>\n",
       "      <td>0.971474</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>lt</th>\n",
       "      <td>2.0</td>\n",
       "      <td>0.741868</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>ssft</th>\n",
       "      <td>2.0</td>\n",
       "      <td>0.793816</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th rowspan=\"8\" valign=\"top\">cifar10_duplicate</th>\n",
       "      <th>cl</th>\n",
       "      <td>2.0</td>\n",
       "      <td>0.553295</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>clt</th>\n",
       "      <td>2.0</td>\n",
       "      <td>0.968022</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>in conf.</th>\n",
       "      <td>2.0</td>\n",
       "      <td>0.923722</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>loss</th>\n",
       "      <td>2.0</td>\n",
       "      <td>0.982120</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>loss_curvature</th>\n",
       "      <td>2.0</td>\n",
       "      <td>0.953584</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>loss_grad</th>\n",
       "      <td>2.0</td>\n",
       "      <td>0.949611</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>lt</th>\n",
       "      <td>2.0</td>\n",
       "      <td>0.702908</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>ssft</th>\n",
       "      <td>2.0</td>\n",
       "      <td>0.848983</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "                                   seed     auroc\n",
       "dataset            metric                        \n",
       "cifar100_duplicate cl               2.0  0.587337\n",
       "                   clt              2.0  0.986984\n",
       "                   in conf.         2.0  0.862300\n",
       "                   loss             2.0  0.988604\n",
       "                   loss_curvature   2.0  0.963915\n",
       "                   loss_grad        2.0  0.971474\n",
       "                   lt               2.0  0.741868\n",
       "                   ssft             2.0  0.793816\n",
       "cifar10_duplicate  cl               2.0  0.553295\n",
       "                   clt              2.0  0.968022\n",
       "                   in conf.         2.0  0.923722\n",
       "                   loss             2.0  0.982120\n",
       "                   loss_curvature   2.0  0.953584\n",
       "                   loss_grad        2.0  0.949611\n",
       "                   lt               2.0  0.702908\n",
       "                   ssft             2.0  0.848983"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Average only the AUROC over the seeds; the seed id is an identifier, not a measurement.\n",
    "all_results.groupby(['dataset', 'metric'])[['auroc']].mean()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "# to_csv does not create missing directories, so make sure the target folder exists.\n",
    "os.makedirs('./output', exist_ok=True)\n",
    "# Only the AUROC column is aggregated; the seed id is an identifier, not a measurement.\n",
    "all_results.groupby(['dataset', 'metric'])[['auroc']].mean().to_csv('./output/dupl_mean.csv')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "# to_csv does not create missing directories, so make sure the target folder exists.\n",
    "os.makedirs('./output', exist_ok=True)\n",
    "# Only the AUROC column is aggregated; the seed id is an identifier, not a measurement.\n",
    "all_results.groupby(['dataset', 'metric'])[['auroc']].std().to_csv('./output/dupl_std.csv')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th>seed</th>\n",
       "      <th>auroc</th>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>dataset</th>\n",
       "      <th>metric</th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th rowspan=\"8\" valign=\"top\">cifar100_duplicate</th>\n",
       "      <th>cl</th>\n",
       "      <td>2.0</td>\n",
       "      <td>0.587337</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>clt</th>\n",
       "      <td>2.0</td>\n",
       "      <td>0.986984</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>in conf.</th>\n",
       "      <td>2.0</td>\n",
       "      <td>0.862300</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>loss</th>\n",
       "      <td>2.0</td>\n",
       "      <td>0.988604</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>loss_curvature</th>\n",
       "      <td>2.0</td>\n",
       "      <td>0.963915</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>loss_grad</th>\n",
       "      <td>2.0</td>\n",
       "      <td>0.971474</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>lt</th>\n",
       "      <td>2.0</td>\n",
       "      <td>0.741868</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>ssft</th>\n",
       "      <td>2.0</td>\n",
       "      <td>0.793816</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th rowspan=\"8\" valign=\"top\">cifar10_duplicate</th>\n",
       "      <th>cl</th>\n",
       "      <td>2.0</td>\n",
       "      <td>0.553295</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>clt</th>\n",
       "      <td>2.0</td>\n",
       "      <td>0.968022</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>in conf.</th>\n",
       "      <td>2.0</td>\n",
       "      <td>0.923722</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>loss</th>\n",
       "      <td>2.0</td>\n",
       "      <td>0.982120</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>loss_curvature</th>\n",
       "      <td>2.0</td>\n",
       "      <td>0.953584</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>loss_grad</th>\n",
       "      <td>2.0</td>\n",
       "      <td>0.949611</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>lt</th>\n",
       "      <td>2.0</td>\n",
       "      <td>0.702908</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>ssft</th>\n",
       "      <td>2.0</td>\n",
       "      <td>0.848983</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "                                   seed     auroc\n",
       "dataset            metric                        \n",
       "cifar100_duplicate cl               2.0  0.587337\n",
       "                   clt              2.0  0.986984\n",
       "                   in conf.         2.0  0.862300\n",
       "                   loss             2.0  0.988604\n",
       "                   loss_curvature   2.0  0.963915\n",
       "                   loss_grad        2.0  0.971474\n",
       "                   lt               2.0  0.741868\n",
       "                   ssft             2.0  0.793816\n",
       "cifar10_duplicate  cl               2.0  0.553295\n",
       "                   clt              2.0  0.968022\n",
       "                   in conf.         2.0  0.923722\n",
       "                   loss             2.0  0.982120\n",
       "                   loss_curvature   2.0  0.953584\n",
       "                   loss_grad        2.0  0.949611\n",
       "                   lt               2.0  0.702908\n",
       "                   ssft             2.0  0.848983"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Average only the AUROC over the seeds; the seed id is an identifier, not a measurement.\n",
    "all_results.groupby(['dataset', 'metric'])[['auroc']].mean()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th>seed</th>\n",
       "      <th>auroc</th>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>dataset</th>\n",
       "      <th>metric</th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th rowspan=\"8\" valign=\"top\">cifar100_duplicate</th>\n",
       "      <th>cl</th>\n",
       "      <td>1.0</td>\n",
       "      <td>0.009021</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>clt</th>\n",
       "      <td>1.0</td>\n",
       "      <td>0.000490</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>in conf.</th>\n",
       "      <td>1.0</td>\n",
       "      <td>0.013077</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>loss</th>\n",
       "      <td>1.0</td>\n",
       "      <td>0.000758</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>loss_curvature</th>\n",
       "      <td>1.0</td>\n",
       "      <td>0.003042</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>loss_grad</th>\n",
       "      <td>1.0</td>\n",
       "      <td>0.002801</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>lt</th>\n",
       "      <td>1.0</td>\n",
       "      <td>0.005899</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>ssft</th>\n",
       "      <td>1.0</td>\n",
       "      <td>0.004508</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th rowspan=\"8\" valign=\"top\">cifar10_duplicate</th>\n",
       "      <th>cl</th>\n",
       "      <td>1.0</td>\n",
       "      <td>0.003081</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>clt</th>\n",
       "      <td>1.0</td>\n",
       "      <td>0.003355</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>in conf.</th>\n",
       "      <td>1.0</td>\n",
       "      <td>0.011387</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>loss</th>\n",
       "      <td>1.0</td>\n",
       "      <td>0.000609</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>loss_curvature</th>\n",
       "      <td>1.0</td>\n",
       "      <td>0.003026</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>loss_grad</th>\n",
       "      <td>1.0</td>\n",
       "      <td>0.002227</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>lt</th>\n",
       "      <td>1.0</td>\n",
       "      <td>0.000588</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>ssft</th>\n",
       "      <td>1.0</td>\n",
       "      <td>0.003446</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "                                   seed     auroc\n",
       "dataset            metric                        \n",
       "cifar100_duplicate cl               1.0  0.009021\n",
       "                   clt              1.0  0.000490\n",
       "                   in conf.         1.0  0.013077\n",
       "                   loss             1.0  0.000758\n",
       "                   loss_curvature   1.0  0.003042\n",
       "                   loss_grad        1.0  0.002801\n",
       "                   lt               1.0  0.005899\n",
       "                   ssft             1.0  0.004508\n",
       "cifar10_duplicate  cl               1.0  0.003081\n",
       "                   clt              1.0  0.003355\n",
       "                   in conf.         1.0  0.011387\n",
       "                   loss             1.0  0.000609\n",
       "                   loss_curvature   1.0  0.003026\n",
       "                   loss_grad        1.0  0.002227\n",
       "                   lt               1.0  0.000588\n",
       "                   ssft             1.0  0.003446"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Spread of the AUROC across the seeds; the seed id itself is excluded from the aggregation.\n",
    "all_results.groupby(['dataset', 'metric'])[['auroc']].std()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "py3.11_tf",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.5"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
