{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/min/a/dravikum/Programs/miniconda3/envs/py3.11_tf/lib/python3.11/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
      "  from .autonotebook import tqdm as notebook_tqdm\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "----- Seed 1 -- Noise 0.2 ---\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2025-05-19 13:55:24.143964: I tensorflow/tsl/cuda/cudart_stub.cc:28] Could not find cuda drivers on your machine, GPU will not be used.\n",
      "2025-05-19 13:55:24.186874: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.\n",
      "To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.\n",
      "2025-05-19 13:55:25.591832: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "----- Seed 2 -- Noise 0.2 ---\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "----- Seed 3 -- Noise 0.2 ---\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "----- Seed 1 -- Noise 0.25 ---\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "----- Seed 2 -- Noise 0.25 ---\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "----- Seed 3 -- Noise 0.25 ---\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "----- Seed 1 -- Noise 0.3 ---\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "----- Seed 2 -- Noise 0.3 ---\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "----- Seed 3 -- Noise 0.3 ---\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "----- Seed 1 -- Noise 0.2 ---\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "----- Seed 2 -- Noise 0.2 ---\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "----- Seed 3 -- Noise 0.2 ---\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "----- Seed 1 -- Noise 0.25 ---\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "----- Seed 2 -- Noise 0.25 ---\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "----- Seed 3 -- Noise 0.25 ---\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "----- Seed 1 -- Noise 0.3 ---\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "----- Seed 2 -- Noise 0.3 ---\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "----- Seed 3 -- Noise 0.3 ---\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n"
     ]
    }
   ],
   "source": [
    "import numpy as np \n",
    "from minio_obj_storage import get_numpy_from_cloud\n",
    "import pandas as pd\n",
    "import logging\n",
    "import torch\n",
    "import json\n",
    "from utils.load_dataset import load_dataset\n",
    "from sklearn.metrics import roc_auc_score, precision_recall_curve, auc, roc_curve\n",
    "from cleanlab.filter import find_label_issues\n",
    "\n",
    "# Load notebook settings from ./config.json (relative to the kernel's working\n",
    "# directory); config['data_dir'] is later passed to load_dataset as root_path.\n",
    "with open(\"./config.json\", 'r') as f:\n",
    "    config = json.load(f)\n",
    "\n",
    "def calculate_metrics(y_true, y_scores):\n",
    "    \"\"\"Compute AUROC, AUPR and FPR@95%TPR for a set of detection scores.\n",
    "\n",
    "    Args:\n",
    "        y_true: Ground-truth labels, forwarded unchanged to the sklearn metrics.\n",
    "        y_scores: Scores for the same samples, forwarded unchanged as well.\n",
    "\n",
    "    Returns:\n",
    "        A tuple ``(auroc, aupr, fpr95)``. ``fpr95`` is the false-positive rate at\n",
    "        the first ROC point whose true-positive rate is at least 0.95, or\n",
    "        ``None`` if no ROC point reaches that rate.\n",
    "    \"\"\"\n",
    "    area_under_roc = roc_auc_score(y_true, y_scores)\n",
    "\n",
    "    # Area under the precision-recall curve; both arrays are reversed before\n",
    "    # being integrated with auc().\n",
    "    pr_precision, pr_recall, _ = precision_recall_curve(y_true, y_scores)\n",
    "    area_under_pr = auc(pr_recall[::-1], pr_precision[::-1])\n",
    "\n",
    "    # False-positive rate at the first operating point with TPR >= 0.95.\n",
    "    false_pos_rate, true_pos_rate, _ = roc_curve(y_true, y_scores)\n",
    "    reaches_95 = true_pos_rate >= 0.95\n",
    "    if reaches_95.any():\n",
    "        fpr_at_95 = false_pos_rate[reaches_95][0]\n",
    "    else:\n",
    "        fpr_at_95 = None\n",
    "\n",
    "    return area_under_roc, area_under_pr, fpr_at_95\n",
    "\n",
    "# ---------------------------------------------------------------------------\n",
    "# Evaluate mislabeled-sample detectors for every (dataset, noise, seed) run.\n",
    "# Each run appends one entry per method to the parallel result_* lists below,\n",
    "# which a later cell turns into a DataFrame.\n",
    "# ---------------------------------------------------------------------------\n",
    "logger = logging.getLogger('Analyze Mislabeled')\n",
    "container_name = 'learning-dynamics-scores'\n",
    "max_epoch = 200     # number of per-epoch score files requested for each run\n",
    "epochs_total = 200  # row count of the ft/lt matrices and split point inside `pred`\n",
    "result_seeds = []\n",
    "result_noise = []\n",
    "result_auroc = []\n",
    "result_aupr = []\n",
    "result_fpr95 = []\n",
    "result_method = []\n",
    "result_dataset = []\n",
    "\n",
    "# Sample indices that are split in half for the two SSFT parts below.\n",
    "index = np.arange(50000)\n",
    "\n",
    "for dataset in ['cifar10', 'cifar100']:\n",
    "    container_dir = dataset\n",
    "    for label_noise in [0.2, 0.25, 0.3]:\n",
    "        for seed in range(1, 4):\n",
    "            print(f'----- Seed {seed} -- Noise {label_noise} ---')\n",
    "            loss_grad = []\n",
    "            loss = []\n",
    "            loss_curvature = []\n",
    "\n",
    "            dataset_obj = load_dataset(\n",
    "                dataset=dataset,\n",
    "                val_split=0,\n",
    "                root_path=config['data_dir'],\n",
    "                random_seed=seed,\n",
    "                logger=logger,\n",
    "                label_noise=label_noise)\n",
    "\n",
    "            ft = np.zeros((epochs_total, dataset_obj.train_length))\n",
    "            lt = np.zeros((epochs_total, dataset_obj.train_length))\n",
    "\n",
    "            # Per-epoch scores. All three files are fetched before anything is\n",
    "            # appended, so the three lists always stay the same length.\n",
    "            for epoch in range(0, max_epoch):\n",
    "                try:\n",
    "                    loss_grad_4_eph = get_numpy_from_cloud(container_name, container_dir, f\"loss_grad_{dataset}_resnet18_noisy_idx_{seed}_epoch_{epoch}_noise_{label_noise}_tid0.pt\")\n",
    "                    loss_4_eph = get_numpy_from_cloud(container_name, container_dir, f\"losses_{dataset}_resnet18_noisy_idx_{seed}_epoch_{epoch}_noise_{label_noise}_tid0.pt\")\n",
    "                    loss_curvature_4_eph = get_numpy_from_cloud(container_name, container_dir, f\"loss_curvature_{dataset}_resnet18_noisy_idx_{seed}_epoch_{epoch}_noise_{label_noise}_h0.001_tid0.pt\")\n",
    "\n",
    "                    loss_grad.append(loss_grad_4_eph)\n",
    "                    loss.append(loss_4_eph)\n",
    "                    loss_curvature.append(loss_curvature_4_eph)\n",
    "                except Exception as exc:\n",
    "                    # Best-effort: skip epochs whose files cannot be loaded, but\n",
    "                    # report why, and let KeyboardInterrupt/SystemExit propagate.\n",
    "                    print(f\"Not found {seed} {epoch}: {exc!r}\")\n",
    "\n",
    "            # SSFT predictions come in two parts; the halves of `index` swap roles\n",
    "            # between part 0 and part 1. Rows from epochs_total onward fill `ft`,\n",
    "            # the first epochs_total rows fill `lt`.\n",
    "            # NOTE(review): presumably `pred` has 2 * epochs_total rows - confirm.\n",
    "            for part in [0, 1]:\n",
    "                if part == 0:\n",
    "                    index1 = index[:len(index) // 2]\n",
    "                    index2 = index[len(index) // 2:]\n",
    "                else:\n",
    "                    index2 = index[:len(index) // 2]\n",
    "                    index1 = index[len(index) // 2:]\n",
    "                pred = get_numpy_from_cloud(container_name, container_dir, f\"ssft_pred_resnet18_part_{part}_noisy_idx_{seed}_noise_{label_noise}.npy\")\n",
    "                ft[:, index1] = pred[epochs_total:, index1]\n",
    "                lt[:, index2] = pred[:epochs_total, index2]\n",
    "\n",
    "            ft = np.mean(ft.T, axis=1)\n",
    "            loss_curvature = np.array(loss_curvature)\n",
    "            loss = np.array(loss)\n",
    "            loss_grad = np.array(loss_grad)\n",
    "            # NOTE(review): this SSFT-derived `lt` is overwritten a few lines below\n",
    "            # before it is ever read - confirm which definition is intended.\n",
    "            lt = np.argmax(lt.T, axis=1)\n",
    "\n",
    "            # Threshold: mean loss over every loaded epoch and sample.\n",
    "            tau = loss.mean()\n",
    "            below_threshold = (loss < tau)\n",
    "            # argmax over axis 0 is the first epoch whose loss is below tau\n",
    "            # (0 when none is); the score stored in `lt` is 1 minus that index.\n",
    "            lt = 1 - np.argmax(below_threshold, axis=0)\n",
    "\n",
    "            # Ground truth for evaluation: 1 where the label was corrupted.\n",
    "            is_miss = np.zeros((loss.shape[1]))\n",
    "            is_miss[dataset_obj.noisy_idxs] = 1\n",
    "\n",
    "            gt_labels = np.array(dataset_obj.train_loader.dataset.dataset.targets)\n",
    "\n",
    "            prob_file_name = f\"conf_learning_prob_noise_idx_{seed}_noise_{label_noise}.pt\"\n",
    "            prob_4_eph = get_numpy_from_cloud(container_name, container_dir, prob_file_name)\n",
    "\n",
    "            # The labels must belong to the same seed as the probabilities above.\n",
    "            # This filename was previously pinned to seed 1, so seeds 2 and 3 were\n",
    "            # paired with seed 1's labels.\n",
    "            # NOTE(review): confirm the per-seed label files exist in the bucket.\n",
    "            conf_learning_labels = get_numpy_from_cloud(\n",
    "                container_name,\n",
    "                container_dir,\n",
    "                f\"conf_learning_labels_noise_idx_{seed}_noise_{label_noise}.pt\"\n",
    "            ).astype(np.int32)\n",
    "\n",
    "            conf_learning = find_label_issues(\n",
    "                conf_learning_labels,\n",
    "                prob_4_eph,\n",
    "                return_indices_ranked_by=\"self_confidence\",\n",
    "                filter_by=\"confident_learning\",\n",
    "            )\n",
    "\n",
    "            # Probabilities stored for epoch 199, used for the confidence baseline.\n",
    "            prob_4_eph = get_numpy_from_cloud(\n",
    "                container_name,\n",
    "                container_dir,\n",
    "                f\"prob_{dataset}_resnet18_noisy_idx_{seed}_epoch_{199}_noise_{label_noise}_tid0.pt\")\n",
    "\n",
    "            conf = prob_4_eph.max(axis=1)\n",
    "\n",
    "            # Turn the ranked index list into a soft score in (0, 1]: the first\n",
    "            # returned index gets 1.0, the last gets 1/len, and every sample that\n",
    "            # was not returned keeps 0.\n",
    "            conf_learning_soft = np.zeros(loss.shape[1])\n",
    "            for rank, sample_idx in enumerate(conf_learning[::-1]):\n",
    "                conf_learning_soft[sample_idx] = (rank + 1) / len(conf_learning)\n",
    "\n",
    "            df = pd.DataFrame(data={\n",
    "                'loss_curvature': loss_curvature.mean(0),\n",
    "                'loss_grad': loss_grad.mean(0),\n",
    "                'noisy_labels': gt_labels,\n",
    "                'conf_learning': conf_learning_soft.astype(np.float32),\n",
    "                'in_conf': 1 - conf,\n",
    "                'lt': lt,\n",
    "                'ssft': 1 - ft,\n",
    "                'is_mis': is_miss})\n",
    "\n",
    "            # Score every method against the mislabel indicator and record it.\n",
    "            methods = ['loss_curvature', 'loss_grad', 'conf_learning', 'in_conf', 'ssft', 'lt']\n",
    "            for metric in methods:\n",
    "                auroc, aupr, fpr95 = calculate_metrics(df['is_mis'], df[metric])\n",
    "                result_seeds.append(seed)\n",
    "                result_noise.append(label_noise)\n",
    "                result_auroc.append(auroc)\n",
    "                result_aupr.append(aupr)\n",
    "                result_fpr95.append(fpr95)\n",
    "                result_method.append(metric)\n",
    "                result_dataset.append(dataset)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th>seed</th>\n",
       "      <th>auroc</th>\n",
       "      <th>aupr</th>\n",
       "      <th>fpr95</th>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>dataset</th>\n",
       "      <th>label_noise</th>\n",
       "      <th>method</th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th rowspan=\"18\" valign=\"top\">cifar10</th>\n",
       "      <th rowspan=\"6\" valign=\"top\">0.2000</th>\n",
       "      <th>conf_learning</th>\n",
       "      <td>2.0000</td>\n",
       "      <td>0.7169</td>\n",
       "      <td>0.4427</td>\n",
       "      <td>1.0000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>in_conf</th>\n",
       "      <td>2.0000</td>\n",
       "      <td>0.5978</td>\n",
       "      <td>0.2696</td>\n",
       "      <td>0.9644</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>loss_curvature</th>\n",
       "      <td>2.0000</td>\n",
       "      <td>0.9934</td>\n",
       "      <td>0.9501</td>\n",
       "      <td>0.0111</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>loss_grad</th>\n",
       "      <td>2.0000</td>\n",
       "      <td>0.9936</td>\n",
       "      <td>0.9539</td>\n",
       "      <td>0.0109</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>lt</th>\n",
       "      <td>2.0000</td>\n",
       "      <td>0.4981</td>\n",
       "      <td>0.3362</td>\n",
       "      <td>0.9604</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>ssft</th>\n",
       "      <td>2.0000</td>\n",
       "      <td>0.9077</td>\n",
       "      <td>0.7461</td>\n",
       "      <td>0.3662</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th rowspan=\"6\" valign=\"top\">0.2500</th>\n",
       "      <th>conf_learning</th>\n",
       "      <td>2.0000</td>\n",
       "      <td>0.6960</td>\n",
       "      <td>0.4626</td>\n",
       "      <td>1.0000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>in_conf</th>\n",
       "      <td>2.0000</td>\n",
       "      <td>0.5800</td>\n",
       "      <td>0.3085</td>\n",
       "      <td>0.9620</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>loss_curvature</th>\n",
       "      <td>2.0000</td>\n",
       "      <td>0.9932</td>\n",
       "      <td>0.9609</td>\n",
       "      <td>0.0110</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>loss_grad</th>\n",
       "      <td>2.0000</td>\n",
       "      <td>0.9934</td>\n",
       "      <td>0.9635</td>\n",
       "      <td>0.0108</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>lt</th>\n",
       "      <td>2.0000</td>\n",
       "      <td>0.4977</td>\n",
       "      <td>0.4083</td>\n",
       "      <td>0.9598</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>ssft</th>\n",
       "      <td>2.0000</td>\n",
       "      <td>0.8910</td>\n",
       "      <td>0.7551</td>\n",
       "      <td>0.4174</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th rowspan=\"6\" valign=\"top\">0.3000</th>\n",
       "      <th>conf_learning</th>\n",
       "      <td>2.0000</td>\n",
       "      <td>0.6794</td>\n",
       "      <td>0.4903</td>\n",
       "      <td>1.0000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>in_conf</th>\n",
       "      <td>2.0000</td>\n",
       "      <td>0.5669</td>\n",
       "      <td>0.3558</td>\n",
       "      <td>0.9653</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>loss_curvature</th>\n",
       "      <td>2.0000</td>\n",
       "      <td>0.9932</td>\n",
       "      <td>0.9700</td>\n",
       "      <td>0.0114</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>loss_grad</th>\n",
       "      <td>2.0000</td>\n",
       "      <td>0.9935</td>\n",
       "      <td>0.9724</td>\n",
       "      <td>0.0111</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>lt</th>\n",
       "      <td>2.0000</td>\n",
       "      <td>0.4988</td>\n",
       "      <td>0.5124</td>\n",
       "      <td>0.9559</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>ssft</th>\n",
       "      <td>2.0000</td>\n",
       "      <td>0.8710</td>\n",
       "      <td>0.7636</td>\n",
       "      <td>0.4776</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th rowspan=\"18\" valign=\"top\">cifar100</th>\n",
       "      <th rowspan=\"6\" valign=\"top\">0.2000</th>\n",
       "      <th>conf_learning</th>\n",
       "      <td>2.0000</td>\n",
       "      <td>0.7030</td>\n",
       "      <td>0.3815</td>\n",
       "      <td>1.0000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>in_conf</th>\n",
       "      <td>2.0000</td>\n",
       "      <td>0.6493</td>\n",
       "      <td>0.2978</td>\n",
       "      <td>0.8515</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>loss_curvature</th>\n",
       "      <td>2.0000</td>\n",
       "      <td>0.9931</td>\n",
       "      <td>0.9482</td>\n",
       "      <td>0.0127</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>loss_grad</th>\n",
       "      <td>2.0000</td>\n",
       "      <td>0.9934</td>\n",
       "      <td>0.9490</td>\n",
       "      <td>0.0111</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>lt</th>\n",
       "      <td>2.0000</td>\n",
       "      <td>0.5119</td>\n",
       "      <td>0.2612</td>\n",
       "      <td>0.9913</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>ssft</th>\n",
       "      <td>2.0000</td>\n",
       "      <td>0.8358</td>\n",
       "      <td>0.6783</td>\n",
       "      <td>0.4898</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th rowspan=\"6\" valign=\"top\">0.2500</th>\n",
       "      <th>conf_learning</th>\n",
       "      <td>2.0000</td>\n",
       "      <td>0.6833</td>\n",
       "      <td>0.4094</td>\n",
       "      <td>1.0000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>in_conf</th>\n",
       "      <td>2.0000</td>\n",
       "      <td>0.6324</td>\n",
       "      <td>0.3439</td>\n",
       "      <td>0.8649</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>loss_curvature</th>\n",
       "      <td>2.0000</td>\n",
       "      <td>0.9931</td>\n",
       "      <td>0.9612</td>\n",
       "      <td>0.0131</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>loss_grad</th>\n",
       "      <td>2.0000</td>\n",
       "      <td>0.9936</td>\n",
       "      <td>0.9625</td>\n",
       "      <td>0.0111</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>lt</th>\n",
       "      <td>2.0000</td>\n",
       "      <td>0.5069</td>\n",
       "      <td>0.3086</td>\n",
       "      <td>0.9832</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>ssft</th>\n",
       "      <td>2.0000</td>\n",
       "      <td>0.8203</td>\n",
       "      <td>0.7020</td>\n",
       "      <td>0.5566</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th rowspan=\"6\" valign=\"top\">0.3000</th>\n",
       "      <th>conf_learning</th>\n",
       "      <td>2.0000</td>\n",
       "      <td>0.6662</td>\n",
       "      <td>0.4439</td>\n",
       "      <td>1.0000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>in_conf</th>\n",
       "      <td>2.0000</td>\n",
       "      <td>0.6257</td>\n",
       "      <td>0.3955</td>\n",
       "      <td>0.8724</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>loss_curvature</th>\n",
       "      <td>2.0000</td>\n",
       "      <td>0.9932</td>\n",
       "      <td>0.9697</td>\n",
       "      <td>0.0126</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>loss_grad</th>\n",
       "      <td>2.0000</td>\n",
       "      <td>0.9936</td>\n",
       "      <td>0.9709</td>\n",
       "      <td>0.0110</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>lt</th>\n",
       "      <td>2.0000</td>\n",
       "      <td>0.5041</td>\n",
       "      <td>0.3626</td>\n",
       "      <td>0.9772</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>ssft</th>\n",
       "      <td>2.0000</td>\n",
       "      <td>0.8043</td>\n",
       "      <td>0.7223</td>\n",
       "      <td>0.6019</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "                                      seed  auroc   aupr  fpr95\n",
       "dataset  label_noise method                                    \n",
       "cifar10  0.2000      conf_learning  2.0000 0.7169 0.4427 1.0000\n",
       "                     in_conf        2.0000 0.5978 0.2696 0.9644\n",
       "                     loss_curvature 2.0000 0.9934 0.9501 0.0111\n",
       "                     loss_grad      2.0000 0.9936 0.9539 0.0109\n",
       "                     lt             2.0000 0.4981 0.3362 0.9604\n",
       "                     ssft           2.0000 0.9077 0.7461 0.3662\n",
       "         0.2500      conf_learning  2.0000 0.6960 0.4626 1.0000\n",
       "                     in_conf        2.0000 0.5800 0.3085 0.9620\n",
       "                     loss_curvature 2.0000 0.9932 0.9609 0.0110\n",
       "                     loss_grad      2.0000 0.9934 0.9635 0.0108\n",
       "                     lt             2.0000 0.4977 0.4083 0.9598\n",
       "                     ssft           2.0000 0.8910 0.7551 0.4174\n",
       "         0.3000      conf_learning  2.0000 0.6794 0.4903 1.0000\n",
       "                     in_conf        2.0000 0.5669 0.3558 0.9653\n",
       "                     loss_curvature 2.0000 0.9932 0.9700 0.0114\n",
       "                     loss_grad      2.0000 0.9935 0.9724 0.0111\n",
       "                     lt             2.0000 0.4988 0.5124 0.9559\n",
       "                     ssft           2.0000 0.8710 0.7636 0.4776\n",
       "cifar100 0.2000      conf_learning  2.0000 0.7030 0.3815 1.0000\n",
       "                     in_conf        2.0000 0.6493 0.2978 0.8515\n",
       "                     loss_curvature 2.0000 0.9931 0.9482 0.0127\n",
       "                     loss_grad      2.0000 0.9934 0.9490 0.0111\n",
       "                     lt             2.0000 0.5119 0.2612 0.9913\n",
       "                     ssft           2.0000 0.8358 0.6783 0.4898\n",
       "         0.2500      conf_learning  2.0000 0.6833 0.4094 1.0000\n",
       "                     in_conf        2.0000 0.6324 0.3439 0.8649\n",
       "                     loss_curvature 2.0000 0.9931 0.9612 0.0131\n",
       "                     loss_grad      2.0000 0.9936 0.9625 0.0111\n",
       "                     lt             2.0000 0.5069 0.3086 0.9832\n",
       "                     ssft           2.0000 0.8203 0.7020 0.5566\n",
       "         0.3000      conf_learning  2.0000 0.6662 0.4439 1.0000\n",
       "                     in_conf        2.0000 0.6257 0.3955 0.8724\n",
       "                     loss_curvature 2.0000 0.9932 0.9697 0.0126\n",
       "                     loss_grad      2.0000 0.9936 0.9709 0.0110\n",
       "                     lt             2.0000 0.5041 0.3626 0.9772\n",
       "                     ssft           2.0000 0.8043 0.7223 0.6019"
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Assemble the parallel per-run result lists into one long-format table.\n",
    "group_keys = ['dataset', 'label_noise', 'method']\n",
    "results_df = pd.DataFrame(dict(\n",
    "    seed=result_seeds,\n",
    "    label_noise=result_noise,\n",
    "    method=result_method,\n",
    "    auroc=result_auroc,\n",
    "    aupr=result_aupr,\n",
    "    fpr95=result_fpr95,\n",
    "    dataset=result_dataset,\n",
    "))\n",
    "\n",
    "# Mean and standard deviation over the seeds of each (dataset, noise, method)\n",
    "# group. The 'seed' column is not a grouping key, so it is aggregated as well.\n",
    "per_setting = results_df.groupby(group_keys)\n",
    "avg_res = per_setting.mean()\n",
    "std_res = per_setting.std()\n",
    "\n",
    "# Show floats with four decimals in every DataFrame rendered from here on.\n",
    "pd.set_option('display.float_format', '{:.4f}'.format)\n",
    "avg_res"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th>seed</th>\n",
       "      <th>auroc</th>\n",
       "      <th>aupr</th>\n",
       "      <th>fpr95</th>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>dataset</th>\n",
       "      <th>label_noise</th>\n",
       "      <th>method</th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th rowspan=\"18\" valign=\"top\">cifar10</th>\n",
       "      <th rowspan=\"6\" valign=\"top\">0.2000</th>\n",
       "      <th>conf_learning</th>\n",
       "      <td>1.0000</td>\n",
       "      <td>0.1539</td>\n",
       "      <td>0.0488</td>\n",
       "      <td>0.0000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>in_conf</th>\n",
       "      <td>1.0000</td>\n",
       "      <td>0.0131</td>\n",
       "      <td>0.0092</td>\n",
       "      <td>0.0310</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>loss_curvature</th>\n",
       "      <td>1.0000</td>\n",
       "      <td>0.0002</td>\n",
       "      <td>0.0012</td>\n",
       "      <td>0.0004</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>loss_grad</th>\n",
       "      <td>1.0000</td>\n",
       "      <td>0.0002</td>\n",
       "      <td>0.0010</td>\n",
       "      <td>0.0002</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>lt</th>\n",
       "      <td>1.0000</td>\n",
       "      <td>0.0004</td>\n",
       "      <td>0.0070</td>\n",
       "      <td>0.0034</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>ssft</th>\n",
       "      <td>1.0000</td>\n",
       "      <td>0.0023</td>\n",
       "      <td>0.0063</td>\n",
       "      <td>0.0106</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th rowspan=\"6\" valign=\"top\">0.2500</th>\n",
       "      <th>conf_learning</th>\n",
       "      <td>1.0000</td>\n",
       "      <td>0.1387</td>\n",
       "      <td>0.0554</td>\n",
       "      <td>0.0000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>in_conf</th>\n",
       "      <td>1.0000</td>\n",
       "      <td>0.0051</td>\n",
       "      <td>0.0059</td>\n",
       "      <td>0.0336</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>loss_curvature</th>\n",
       "      <td>1.0000</td>\n",
       "      <td>0.0001</td>\n",
       "      <td>0.0004</td>\n",
       "      <td>0.0004</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>loss_grad</th>\n",
       "      <td>1.0000</td>\n",
       "      <td>0.0001</td>\n",
       "      <td>0.0002</td>\n",
       "      <td>0.0003</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>lt</th>\n",
       "      <td>1.0000</td>\n",
       "      <td>0.0020</td>\n",
       "      <td>0.0227</td>\n",
       "      <td>0.0035</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>ssft</th>\n",
       "      <td>1.0000</td>\n",
       "      <td>0.0050</td>\n",
       "      <td>0.0070</td>\n",
       "      <td>0.0178</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th rowspan=\"6\" valign=\"top\">0.3000</th>\n",
       "      <th>conf_learning</th>\n",
       "      <td>1.0000</td>\n",
       "      <td>0.1264</td>\n",
       "      <td>0.0488</td>\n",
       "      <td>0.0000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>in_conf</th>\n",
       "      <td>1.0000</td>\n",
       "      <td>0.0106</td>\n",
       "      <td>0.0089</td>\n",
       "      <td>0.0314</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>loss_curvature</th>\n",
       "      <td>1.0000</td>\n",
       "      <td>0.0006</td>\n",
       "      <td>0.0035</td>\n",
       "      <td>0.0002</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>loss_grad</th>\n",
       "      <td>1.0000</td>\n",
       "      <td>0.0006</td>\n",
       "      <td>0.0030</td>\n",
       "      <td>0.0003</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>lt</th>\n",
       "      <td>1.0000</td>\n",
       "      <td>0.0028</td>\n",
       "      <td>0.0237</td>\n",
       "      <td>0.0077</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>ssft</th>\n",
       "      <td>1.0000</td>\n",
       "      <td>0.0071</td>\n",
       "      <td>0.0118</td>\n",
       "      <td>0.0223</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th rowspan=\"18\" valign=\"top\">cifar100</th>\n",
       "      <th rowspan=\"6\" valign=\"top\">0.2000</th>\n",
       "      <th>conf_learning</th>\n",
       "      <td>1.0000</td>\n",
       "      <td>0.1565</td>\n",
       "      <td>0.0300</td>\n",
       "      <td>0.0000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>in_conf</th>\n",
       "      <td>1.0000</td>\n",
       "      <td>0.0075</td>\n",
       "      <td>0.0097</td>\n",
       "      <td>0.0138</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>loss_curvature</th>\n",
       "      <td>1.0000</td>\n",
       "      <td>0.0004</td>\n",
       "      <td>0.0013</td>\n",
       "      <td>0.0009</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>loss_grad</th>\n",
       "      <td>1.0000</td>\n",
       "      <td>0.0003</td>\n",
       "      <td>0.0014</td>\n",
       "      <td>0.0004</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>lt</th>\n",
       "      <td>1.0000</td>\n",
       "      <td>0.0059</td>\n",
       "      <td>0.0063</td>\n",
       "      <td>0.0009</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>ssft</th>\n",
       "      <td>1.0000</td>\n",
       "      <td>0.0008</td>\n",
       "      <td>0.0005</td>\n",
       "      <td>0.0308</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th rowspan=\"6\" valign=\"top\">0.2500</th>\n",
       "      <th>conf_learning</th>\n",
       "      <td>1.0000</td>\n",
       "      <td>0.1427</td>\n",
       "      <td>0.0273</td>\n",
       "      <td>0.0000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>in_conf</th>\n",
       "      <td>1.0000</td>\n",
       "      <td>0.0051</td>\n",
       "      <td>0.0054</td>\n",
       "      <td>0.0037</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>loss_curvature</th>\n",
       "      <td>1.0000</td>\n",
       "      <td>0.0004</td>\n",
       "      <td>0.0013</td>\n",
       "      <td>0.0013</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>loss_grad</th>\n",
       "      <td>1.0000</td>\n",
       "      <td>0.0002</td>\n",
       "      <td>0.0006</td>\n",
       "      <td>0.0001</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>lt</th>\n",
       "      <td>1.0000</td>\n",
       "      <td>0.0050</td>\n",
       "      <td>0.0019</td>\n",
       "      <td>0.0025</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>ssft</th>\n",
       "      <td>1.0000</td>\n",
       "      <td>0.0016</td>\n",
       "      <td>0.0013</td>\n",
       "      <td>0.0075</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th rowspan=\"6\" valign=\"top\">0.3000</th>\n",
       "      <th>conf_learning</th>\n",
       "      <td>1.0000</td>\n",
       "      <td>0.1289</td>\n",
       "      <td>0.0337</td>\n",
       "      <td>0.0000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>in_conf</th>\n",
       "      <td>1.0000</td>\n",
       "      <td>0.0044</td>\n",
       "      <td>0.0035</td>\n",
       "      <td>0.0059</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>loss_curvature</th>\n",
       "      <td>1.0000</td>\n",
       "      <td>0.0002</td>\n",
       "      <td>0.0010</td>\n",
       "      <td>0.0005</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>loss_grad</th>\n",
       "      <td>1.0000</td>\n",
       "      <td>0.0001</td>\n",
       "      <td>0.0008</td>\n",
       "      <td>0.0004</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>lt</th>\n",
       "      <td>1.0000</td>\n",
       "      <td>0.0002</td>\n",
       "      <td>0.0052</td>\n",
       "      <td>0.0013</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>ssft</th>\n",
       "      <td>1.0000</td>\n",
       "      <td>0.0061</td>\n",
       "      <td>0.0056</td>\n",
       "      <td>0.0102</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "                                      seed  auroc   aupr  fpr95\n",
       "dataset  label_noise method                                    \n",
       "cifar10  0.2000      conf_learning  1.0000 0.1539 0.0488 0.0000\n",
       "                     in_conf        1.0000 0.0131 0.0092 0.0310\n",
       "                     loss_curvature 1.0000 0.0002 0.0012 0.0004\n",
       "                     loss_grad      1.0000 0.0002 0.0010 0.0002\n",
       "                     lt             1.0000 0.0004 0.0070 0.0034\n",
       "                     ssft           1.0000 0.0023 0.0063 0.0106\n",
       "         0.2500      conf_learning  1.0000 0.1387 0.0554 0.0000\n",
       "                     in_conf        1.0000 0.0051 0.0059 0.0336\n",
       "                     loss_curvature 1.0000 0.0001 0.0004 0.0004\n",
       "                     loss_grad      1.0000 0.0001 0.0002 0.0003\n",
       "                     lt             1.0000 0.0020 0.0227 0.0035\n",
       "                     ssft           1.0000 0.0050 0.0070 0.0178\n",
       "         0.3000      conf_learning  1.0000 0.1264 0.0488 0.0000\n",
       "                     in_conf        1.0000 0.0106 0.0089 0.0314\n",
       "                     loss_curvature 1.0000 0.0006 0.0035 0.0002\n",
       "                     loss_grad      1.0000 0.0006 0.0030 0.0003\n",
       "                     lt             1.0000 0.0028 0.0237 0.0077\n",
       "                     ssft           1.0000 0.0071 0.0118 0.0223\n",
       "cifar100 0.2000      conf_learning  1.0000 0.1565 0.0300 0.0000\n",
       "                     in_conf        1.0000 0.0075 0.0097 0.0138\n",
       "                     loss_curvature 1.0000 0.0004 0.0013 0.0009\n",
       "                     loss_grad      1.0000 0.0003 0.0014 0.0004\n",
       "                     lt             1.0000 0.0059 0.0063 0.0009\n",
       "                     ssft           1.0000 0.0008 0.0005 0.0308\n",
       "         0.2500      conf_learning  1.0000 0.1427 0.0273 0.0000\n",
       "                     in_conf        1.0000 0.0051 0.0054 0.0037\n",
       "                     loss_curvature 1.0000 0.0004 0.0013 0.0013\n",
       "                     loss_grad      1.0000 0.0002 0.0006 0.0001\n",
       "                     lt             1.0000 0.0050 0.0019 0.0025\n",
       "                     ssft           1.0000 0.0016 0.0013 0.0075\n",
       "         0.3000      conf_learning  1.0000 0.1289 0.0337 0.0000\n",
       "                     in_conf        1.0000 0.0044 0.0035 0.0059\n",
       "                     loss_curvature 1.0000 0.0002 0.0010 0.0005\n",
       "                     loss_grad      1.0000 0.0001 0.0008 0.0004\n",
       "                     lt             1.0000 0.0002 0.0052 0.0013\n",
       "                     ssft           1.0000 0.0061 0.0056 0.0102"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "std_res"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "py3.11_tf",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.5"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
