{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np \n",
    "from azure_blob_storage import get_numpy_from_azure\n",
    "import pandas as pd\n",
    "import logging\n",
    "import json\n",
    "from utils.load_dataset import load_dataset\n",
    "from sklearn.metrics import roc_auc_score, precision_recall_curve, auc, roc_curve\n",
    "from cleanlab.filter import find_label_issues\n",
    "\n",
    "# Read the notebook settings once; later cells look up config['data_dir'].\n",
    "with open(\"./config.json\", 'r') as config_file:\n",
    "    config = json.load(config_file)\n",
    "\n",
    "def calculate_metrics(y_true, y_scores):\n",
    "    \"\"\"Evaluate how well ``y_scores`` ranks the positives in ``y_true``.\n",
    "\n",
    "    Returns a tuple ``(auroc, aupr, fpr95)``:\n",
    "      * auroc -- ``roc_auc_score(y_true, y_scores)``\n",
    "      * aupr  -- area under the precision-recall curve, integrated with\n",
    "                 ``auc`` over (recall, precision)\n",
    "      * fpr95 -- false-positive rate at the first ROC point whose\n",
    "                 true-positive rate is >= 0.95, or ``None`` when no point\n",
    "                 reaches that rate\n",
    "    \"\"\"\n",
    "    # Area under the ROC curve.\n",
    "    auroc = roc_auc_score(y_true, y_scores)\n",
    "\n",
    "    # Area under the precision-recall curve.\n",
    "    pr_precision, pr_recall, _ = precision_recall_curve(y_true, y_scores)\n",
    "    aupr = auc(pr_recall, pr_precision)\n",
    "\n",
    "    # False-positive rate at the first operating point with TPR >= 95%.\n",
    "    false_pos_rate, true_pos_rate, _ = roc_curve(y_true, y_scores)\n",
    "    reaches_target = true_pos_rate >= 0.95\n",
    "    if any(reaches_target):\n",
    "        fpr95 = false_pos_rate[reaches_target][0]\n",
    "    else:\n",
    "        fpr95 = None\n",
    "\n",
    "    return auroc, aupr, fpr95\n",
    "\n",
    "logger = logging.getLogger(f'Analyze the')\n",
    "container_name = 'learning-dynamics-scores'\n",
    "max_epoch = 1 #50\n",
    "result_seeds = []\n",
    "result_noise = []\n",
    "result_auroc = []\n",
    "result_aupr = []\n",
    "result_fpr95 = []\n",
    "result_method = []\n",
    "result_dataset = []\n",
    "\n",
    "# Evaluate every scoring method for each (dataset, label_noise, seed) run and\n",
    "# append one row per method to the result_* lists.\n",
    "#\n",
    "# The dataset list had been emptied (the real list was left in a trailing\n",
    "# comment), which made this loop a no-op and left results_df empty after\n",
    "# Restart & Run All. It is restored here.\n",
    "for dataset in ['cifar100', 'cifar10']:\n",
    "    container_dir = dataset\n",
    "    for label_noise in [0.01, 0.02, 0.05, 0.1]:\n",
    "        for seed in range(2, 4):\n",
    "            print(f'----- Seed {seed} -- Noise {label_noise} ---')\n",
    "            loss_grad = []\n",
    "            loss = []\n",
    "            loss_curvature = []\n",
    "            for epoch in range(0, max_epoch):\n",
    "                try:\n",
    "                    loss_grad_4_eph = get_numpy_from_azure(container_name, container_dir, f\"loss_grad_{dataset}_resnet18_noisy_idx_{seed}_epoch_{epoch}_noise_{label_noise}_tid0.pt\")\n",
    "                    loss_4_eph = get_numpy_from_azure(container_name, container_dir, f\"losses_{dataset}_resnet18_noisy_idx_{seed}_epoch_{epoch}_noise_{label_noise}_tid0.pt\")\n",
    "                    loss_curvature_4_eph = get_numpy_from_azure(container_name, container_dir, f\"loss_curvature_{dataset}_resnet18_noisy_idx_{seed}_epoch_{epoch}_noise_{label_noise}_h0.001_tid0.pt\")\n",
    "                except Exception as exc:\n",
    "                    # Best effort: skip an epoch whose blobs cannot be loaded, but say why.\n",
    "                    # Catching Exception (not a bare except) lets KeyboardInterrupt through.\n",
    "                    print(f\"Not found {seed} {epoch}: {exc!r}\")\n",
    "                    continue\n",
    "                # Append only once all three loads succeeded so the lists stay aligned.\n",
    "                loss_grad.append(loss_grad_4_eph)\n",
    "                loss.append(loss_4_eph)\n",
    "                loss_curvature.append(loss_curvature_4_eph)\n",
    "\n",
    "            if not loss:\n",
    "                # Nothing was loaded, so loss.shape[1] below would raise IndexError.\n",
    "                print(f'No epochs loaded for {dataset} seed {seed} noise {label_noise}; skipping')\n",
    "                continue\n",
    "\n",
    "            # Stack the per-epoch arrays; axis 0 is the epoch axis.\n",
    "            loss_curvature = np.array(loss_curvature)\n",
    "            loss = np.array(loss)\n",
    "            loss_grad = np.array(loss_grad)\n",
    "\n",
    "            dataset_obj = load_dataset(\n",
    "                dataset=dataset,\n",
    "                val_split=0,\n",
    "                root_path=config['data_dir'],\n",
    "                random_seed=seed,\n",
    "                logger=logger,\n",
    "                label_noise=label_noise)\n",
    "\n",
    "            # Binary target: 1 at every index listed in dataset_obj.noisy_idxs, else 0.\n",
    "            is_miss = np.zeros((loss.shape[1]))\n",
    "            is_miss[dataset_obj.noisy_idxs] = 1\n",
    "\n",
    "            gt_labels = np.array(dataset_obj.train_loader.dataset.dataset.targets)\n",
    "\n",
    "            # Inputs for confident learning: stored probabilities and labels.\n",
    "            prob_file_name = f\"conf_learning_prob_noise_idx_{seed}_noise_{label_noise}.pt\"\n",
    "            conf_learning_prob = get_numpy_from_azure(container_name, container_dir, prob_file_name)\n",
    "\n",
    "            # NOTE(review): the labels blob is pinned to noise index 1 while the\n",
    "            # probabilities above use the loop's seed -- confirm this is intended\n",
    "            # and not a leftover from a single-seed run.\n",
    "            conf_learning_labels = get_numpy_from_azure(\n",
    "                container_name, \n",
    "                container_dir,\n",
    "                f\"conf_learning_labels_noise_idx_{1}_noise_{label_noise}.pt\"\n",
    "            ).astype(np.int32)\n",
    "\n",
    "            conf_learning = find_label_issues(\n",
    "                conf_learning_labels,\n",
    "                conf_learning_prob,\n",
    "                return_indices_ranked_by=\"self_confidence\",\n",
    "                filter_by=\"confident_learning\",\n",
    "            )\n",
    "\n",
    "            # Class probabilities stored for epoch 199 of this run; 1 - max\n",
    "            # probability is used below as the 'in_conf' score.\n",
    "            prob_4_eph = get_numpy_from_azure(\n",
    "                container_name, \n",
    "                container_dir, \n",
    "                f\"prob_{dataset}_resnet18_noisy_idx_{seed}_epoch_{199}_noise_{label_noise}_tid0.pt\")\n",
    "\n",
    "            conf = prob_4_eph.max(axis=1)\n",
    "\n",
    "            # Turn the ranked index list into a soft score: the first returned index\n",
    "            # gets 1.0, the last gets 1/n, and samples that were not returned stay at 0.\n",
    "            conf_learning_soft = np.zeros(loss.shape[1])\n",
    "            n_flagged = len(conf_learning)\n",
    "            if n_flagged:\n",
    "                conf_learning_soft[conf_learning[::-1]] = np.arange(1, n_flagged + 1) / n_flagged\n",
    "\n",
    "            # Per-sample score table; the loss-based columns are averaged over epochs.\n",
    "            df = pd.DataFrame(data={\n",
    "                'loss': loss.mean(0), \n",
    "                'loss_curvature': loss_curvature.mean(0), \n",
    "                'loss_grad': loss_grad.mean(0), \n",
    "                'noisy_labels': gt_labels,\n",
    "                'conf_learning': conf_learning_soft.astype(np.float32),\n",
    "                'in_conf': 1 - conf,\n",
    "                'is_mis': is_miss})\n",
    "\n",
    "            methods = ['loss', 'loss_curvature', 'loss_grad', 'conf_learning', 'in_conf']\n",
    "            for metric in methods:\n",
    "                auroc, aupr, fpr95 = calculate_metrics(df['is_mis'], df[metric])\n",
    "                result_seeds.append(seed)\n",
    "                result_noise.append(label_noise)\n",
    "                result_auroc.append(auroc)\n",
    "                result_aupr.append(aupr)\n",
    "                result_fpr95.append(fpr95)\n",
    "                result_method.append(metric)\n",
    "                result_dataset.append(dataset)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 45,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>seed</th>\n",
       "      <th>label_noise</th>\n",
       "      <th>method</th>\n",
       "      <th>auroc</th>\n",
       "      <th>aupr</th>\n",
       "      <th>fpr95</th>\n",
       "      <th>dataset</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>2</td>\n",
       "      <td>0.01</td>\n",
       "      <td>loss</td>\n",
       "      <td>0.484704</td>\n",
       "      <td>0.009905</td>\n",
       "      <td>0.947798</td>\n",
       "      <td>cifar100</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>2</td>\n",
       "      <td>0.01</td>\n",
       "      <td>loss_curvature</td>\n",
       "      <td>0.489576</td>\n",
       "      <td>0.009878</td>\n",
       "      <td>0.961960</td>\n",
       "      <td>cifar100</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>2</td>\n",
       "      <td>0.01</td>\n",
       "      <td>loss_grad</td>\n",
       "      <td>0.496035</td>\n",
       "      <td>0.010024</td>\n",
       "      <td>0.964485</td>\n",
       "      <td>cifar100</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>2</td>\n",
       "      <td>0.01</td>\n",
       "      <td>conf_learning</td>\n",
       "      <td>0.857524</td>\n",
       "      <td>0.206060</td>\n",
       "      <td>1.000000</td>\n",
       "      <td>cifar100</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>2</td>\n",
       "      <td>0.01</td>\n",
       "      <td>in_conf</td>\n",
       "      <td>0.737499</td>\n",
       "      <td>0.027769</td>\n",
       "      <td>0.712202</td>\n",
       "      <td>cifar100</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>...</th>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>75</th>\n",
       "      <td>3</td>\n",
       "      <td>0.10</td>\n",
       "      <td>loss</td>\n",
       "      <td>0.497906</td>\n",
       "      <td>0.100003</td>\n",
       "      <td>0.948689</td>\n",
       "      <td>cifar10</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>76</th>\n",
       "      <td>3</td>\n",
       "      <td>0.10</td>\n",
       "      <td>loss_curvature</td>\n",
       "      <td>0.501485</td>\n",
       "      <td>0.100257</td>\n",
       "      <td>0.946556</td>\n",
       "      <td>cifar10</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>77</th>\n",
       "      <td>3</td>\n",
       "      <td>0.10</td>\n",
       "      <td>loss_grad</td>\n",
       "      <td>0.505833</td>\n",
       "      <td>0.101018</td>\n",
       "      <td>0.948511</td>\n",
       "      <td>cifar10</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>78</th>\n",
       "      <td>3</td>\n",
       "      <td>0.10</td>\n",
       "      <td>conf_learning</td>\n",
       "      <td>0.857250</td>\n",
       "      <td>0.335264</td>\n",
       "      <td>1.000000</td>\n",
       "      <td>cifar10</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>79</th>\n",
       "      <td>3</td>\n",
       "      <td>0.10</td>\n",
       "      <td>in_conf</td>\n",
       "      <td>0.655832</td>\n",
       "      <td>0.179811</td>\n",
       "      <td>0.924667</td>\n",
       "      <td>cifar10</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "<p>80 rows × 7 columns</p>\n",
       "</div>"
      ],
      "text/plain": [
       "    seed  label_noise          method     auroc      aupr     fpr95   dataset\n",
       "0      2         0.01            loss  0.484704  0.009905  0.947798  cifar100\n",
       "1      2         0.01  loss_curvature  0.489576  0.009878  0.961960  cifar100\n",
       "2      2         0.01       loss_grad  0.496035  0.010024  0.964485  cifar100\n",
       "3      2         0.01   conf_learning  0.857524  0.206060  1.000000  cifar100\n",
       "4      2         0.01         in_conf  0.737499  0.027769  0.712202  cifar100\n",
       "..   ...          ...             ...       ...       ...       ...       ...\n",
       "75     3         0.10            loss  0.497906  0.100003  0.948689   cifar10\n",
       "76     3         0.10  loss_curvature  0.501485  0.100257  0.946556   cifar10\n",
       "77     3         0.10       loss_grad  0.505833  0.101018  0.948511   cifar10\n",
       "78     3         0.10   conf_learning  0.857250  0.335264  1.000000   cifar10\n",
       "79     3         0.10         in_conf  0.655832  0.179811  0.924667   cifar10\n",
       "\n",
       "[80 rows x 7 columns]"
      ]
     },
     "execution_count": 45,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Assemble the parallel result_* lists into one table, one row per scored run.\n",
    "result_columns = {\n",
    "    'seed': result_seeds,\n",
    "    'label_noise': result_noise,\n",
    "    'method': result_method,\n",
    "    'auroc': result_auroc,\n",
    "    'aupr': result_aupr,\n",
    "    'fpr95': result_fpr95,\n",
    "    'dataset': result_dataset,\n",
    "}\n",
    "results_df = pd.DataFrame(data=result_columns)\n",
    "\n",
    "results_df"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 46,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Aggregate over the rows sharing a (dataset, label_noise, method) key and\n",
    "# export the per-group mean and standard deviation.\n",
    "grouped_results = results_df.groupby(['dataset', 'label_noise', 'method'])\n",
    "avg_res = grouped_results.mean()\n",
    "std_res = grouped_results.std()\n",
    "\n",
    "avg_res.to_csv('./avg_res2.csv')\n",
    "std_res.to_csv(\"./std_res2.csv\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 43,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Write the un-aggregated metrics table (results_df) to a CSV next to the notebook.\n",
    "results_df.to_csv(\"./res_mis2.csv\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 58,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n"
     ]
    }
   ],
   "source": [
    "\n",
    "# Score the 'lt' method for every (label_noise, dataset, seed) run: average the\n",
    "# stored pred_* arrays over 200 epochs and evaluate 1 - mean as a detector of the\n",
    "# indices listed in dataset_obj.noisy_idxs. Rows are appended to the result_* lists.\n",
    "for label_noise in [0.01, 0.02, 0.05, 0.1]:\n",
    "    for dataset in ['cifar100', 'cifar10']:\n",
    "        container_dir = dataset\n",
    "        for seed in [1,2,3]:\n",
    "            preds = []\n",
    "            for epoch in range(200):\n",
    "                pred = get_numpy_from_azure(\n",
    "                    container_name,\n",
    "                    container_dir,\n",
    "                    f\"pred_{dataset}_resnet18_noisy_idx_{seed}_epoch_{epoch}_noise_{label_noise}_tid0.pt\",\n",
    "                )\n",
    "                preds.append(pred)\n",
    "\n",
    "            # Stack the epochs, then reverse the axes so the epoch axis comes last.\n",
    "            preds = np.transpose(np.array(preds))\n",
    "            lt = preds.mean(axis=1)\n",
    "\n",
    "            dataset_obj = load_dataset(\n",
    "                dataset=dataset,\n",
    "                val_split=0,\n",
    "                root_path=config['data_dir'],\n",
    "                random_seed=seed,\n",
    "                logger=logger,\n",
    "                label_noise=label_noise,\n",
    "            )\n",
    "\n",
    "            # Binary target: 1 at every noisy index, 0 elsewhere.\n",
    "            is_miss = np.zeros(dataset_obj.train_length)\n",
    "            is_miss[dataset_obj.noisy_idxs] = 1\n",
    "\n",
    "            metric = 'lt'\n",
    "            auroc, aupr, fpr95 = calculate_metrics(is_miss, 1 - lt)\n",
    "            result_seeds.append(seed)\n",
    "            result_noise.append(label_noise)\n",
    "            result_auroc.append(auroc)\n",
    "            result_aupr.append(aupr)\n",
    "            result_fpr95.append(fpr95)\n",
    "            result_method.append(metric)\n",
    "            result_dataset.append(dataset)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 61,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th>seed</th>\n",
       "      <th>auroc</th>\n",
       "      <th>aupr</th>\n",
       "      <th>fpr95</th>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>dataset</th>\n",
       "      <th>method</th>\n",
       "      <th>label_noise</th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th rowspan=\"24\" valign=\"top\">cifar10</th>\n",
       "      <th rowspan=\"4\" valign=\"top\">conf_learning</th>\n",
       "      <th>0.01</th>\n",
       "      <td>2.5</td>\n",
       "      <td>0.865128</td>\n",
       "      <td>0.235432</td>\n",
       "      <td>1.000000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>0.02</th>\n",
       "      <td>2.5</td>\n",
       "      <td>0.890472</td>\n",
       "      <td>0.268471</td>\n",
       "      <td>1.000000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>0.05</th>\n",
       "      <td>2.5</td>\n",
       "      <td>0.887356</td>\n",
       "      <td>0.303826</td>\n",
       "      <td>1.000000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>0.10</th>\n",
       "      <td>2.5</td>\n",
       "      <td>0.855121</td>\n",
       "      <td>0.331980</td>\n",
       "      <td>1.000000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th rowspan=\"4\" valign=\"top\">in_conf</th>\n",
       "      <th>0.01</th>\n",
       "      <td>2.5</td>\n",
       "      <td>0.874435</td>\n",
       "      <td>0.094724</td>\n",
       "      <td>0.502141</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>0.02</th>\n",
       "      <td>2.5</td>\n",
       "      <td>0.814045</td>\n",
       "      <td>0.109471</td>\n",
       "      <td>0.716082</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>0.05</th>\n",
       "      <td>2.5</td>\n",
       "      <td>0.733242</td>\n",
       "      <td>0.139594</td>\n",
       "      <td>0.845758</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>0.10</th>\n",
       "      <td>2.5</td>\n",
       "      <td>0.655193</td>\n",
       "      <td>0.180920</td>\n",
       "      <td>0.940767</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th rowspan=\"4\" valign=\"top\">loss</th>\n",
       "      <th>0.01</th>\n",
       "      <td>2.5</td>\n",
       "      <td>0.485553</td>\n",
       "      <td>0.009594</td>\n",
       "      <td>0.953333</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>0.02</th>\n",
       "      <td>2.5</td>\n",
       "      <td>0.491932</td>\n",
       "      <td>0.019334</td>\n",
       "      <td>0.946724</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>0.05</th>\n",
       "      <td>2.5</td>\n",
       "      <td>0.498724</td>\n",
       "      <td>0.049527</td>\n",
       "      <td>0.949958</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>0.10</th>\n",
       "      <td>2.5</td>\n",
       "      <td>0.500474</td>\n",
       "      <td>0.100867</td>\n",
       "      <td>0.950967</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th rowspan=\"4\" valign=\"top\">loss_curvature</th>\n",
       "      <th>0.01</th>\n",
       "      <td>2.5</td>\n",
       "      <td>0.501023</td>\n",
       "      <td>0.009935</td>\n",
       "      <td>0.958939</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>0.02</th>\n",
       "      <td>2.5</td>\n",
       "      <td>0.485957</td>\n",
       "      <td>0.018881</td>\n",
       "      <td>0.944286</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>0.05</th>\n",
       "      <td>2.5</td>\n",
       "      <td>0.503167</td>\n",
       "      <td>0.050107</td>\n",
       "      <td>0.946937</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>0.10</th>\n",
       "      <td>2.5</td>\n",
       "      <td>0.500125</td>\n",
       "      <td>0.100149</td>\n",
       "      <td>0.948033</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th rowspan=\"4\" valign=\"top\">loss_grad</th>\n",
       "      <th>0.01</th>\n",
       "      <td>2.5</td>\n",
       "      <td>0.494204</td>\n",
       "      <td>0.009765</td>\n",
       "      <td>0.959101</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>0.02</th>\n",
       "      <td>2.5</td>\n",
       "      <td>0.494249</td>\n",
       "      <td>0.019500</td>\n",
       "      <td>0.948286</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>0.05</th>\n",
       "      <td>2.5</td>\n",
       "      <td>0.504706</td>\n",
       "      <td>0.050856</td>\n",
       "      <td>0.950105</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>0.10</th>\n",
       "      <td>2.5</td>\n",
       "      <td>0.502819</td>\n",
       "      <td>0.101140</td>\n",
       "      <td>0.951089</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th rowspan=\"4\" valign=\"top\">lt</th>\n",
       "      <th>0.01</th>\n",
       "      <td>2.0</td>\n",
       "      <td>0.988792</td>\n",
       "      <td>0.309281</td>\n",
       "      <td>0.010229</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>0.02</th>\n",
       "      <td>2.0</td>\n",
       "      <td>0.988920</td>\n",
       "      <td>0.456666</td>\n",
       "      <td>0.010075</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>0.05</th>\n",
       "      <td>2.0</td>\n",
       "      <td>0.989658</td>\n",
       "      <td>0.653368</td>\n",
       "      <td>0.010154</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>0.10</th>\n",
       "      <td>2.0</td>\n",
       "      <td>0.989893</td>\n",
       "      <td>0.779442</td>\n",
       "      <td>0.010230</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th rowspan=\"24\" valign=\"top\">cifar100</th>\n",
       "      <th rowspan=\"4\" valign=\"top\">conf_learning</th>\n",
       "      <th>0.01</th>\n",
       "      <td>2.5</td>\n",
       "      <td>0.872252</td>\n",
       "      <td>0.211245</td>\n",
       "      <td>1.000000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>0.02</th>\n",
       "      <td>2.5</td>\n",
       "      <td>0.883835</td>\n",
       "      <td>0.245146</td>\n",
       "      <td>1.000000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>0.05</th>\n",
       "      <td>2.5</td>\n",
       "      <td>0.873299</td>\n",
       "      <td>0.277597</td>\n",
       "      <td>1.000000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>0.10</th>\n",
       "      <td>2.5</td>\n",
       "      <td>0.853590</td>\n",
       "      <td>0.311817</td>\n",
       "      <td>1.000000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th rowspan=\"4\" valign=\"top\">in_conf</th>\n",
       "      <th>0.01</th>\n",
       "      <td>2.5</td>\n",
       "      <td>0.728815</td>\n",
       "      <td>0.026181</td>\n",
       "      <td>0.716657</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>0.02</th>\n",
       "      <td>2.5</td>\n",
       "      <td>0.725245</td>\n",
       "      <td>0.047506</td>\n",
       "      <td>0.707031</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>0.05</th>\n",
       "      <td>2.5</td>\n",
       "      <td>0.710836</td>\n",
       "      <td>0.104555</td>\n",
       "      <td>0.738453</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>0.10</th>\n",
       "      <td>2.5</td>\n",
       "      <td>0.691419</td>\n",
       "      <td>0.186432</td>\n",
       "      <td>0.793233</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th rowspan=\"4\" valign=\"top\">loss</th>\n",
       "      <th>0.01</th>\n",
       "      <td>2.5</td>\n",
       "      <td>0.476275</td>\n",
       "      <td>0.009423</td>\n",
       "      <td>0.954848</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>0.02</th>\n",
       "      <td>2.5</td>\n",
       "      <td>0.483714</td>\n",
       "      <td>0.019222</td>\n",
       "      <td>0.955633</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>0.05</th>\n",
       "      <td>2.5</td>\n",
       "      <td>0.495444</td>\n",
       "      <td>0.049193</td>\n",
       "      <td>0.948505</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>0.10</th>\n",
       "      <td>2.5</td>\n",
       "      <td>0.498975</td>\n",
       "      <td>0.099728</td>\n",
       "      <td>0.947044</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th rowspan=\"4\" valign=\"top\">loss_curvature</th>\n",
       "      <th>0.01</th>\n",
       "      <td>2.5</td>\n",
       "      <td>0.494691</td>\n",
       "      <td>0.010226</td>\n",
       "      <td>0.960515</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>0.02</th>\n",
       "      <td>2.5</td>\n",
       "      <td>0.487300</td>\n",
       "      <td>0.019918</td>\n",
       "      <td>0.949684</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>0.05</th>\n",
       "      <td>2.5</td>\n",
       "      <td>0.498324</td>\n",
       "      <td>0.050192</td>\n",
       "      <td>0.951958</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>0.10</th>\n",
       "      <td>2.5</td>\n",
       "      <td>0.498045</td>\n",
       "      <td>0.100605</td>\n",
       "      <td>0.953189</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th rowspan=\"4\" valign=\"top\">loss_grad</th>\n",
       "      <th>0.01</th>\n",
       "      <td>2.5</td>\n",
       "      <td>0.489182</td>\n",
       "      <td>0.009830</td>\n",
       "      <td>0.955899</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>0.02</th>\n",
       "      <td>2.5</td>\n",
       "      <td>0.489540</td>\n",
       "      <td>0.019920</td>\n",
       "      <td>0.947735</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>0.05</th>\n",
       "      <td>2.5</td>\n",
       "      <td>0.493861</td>\n",
       "      <td>0.049650</td>\n",
       "      <td>0.953021</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>0.10</th>\n",
       "      <td>2.5</td>\n",
       "      <td>0.498231</td>\n",
       "      <td>0.100768</td>\n",
       "      <td>0.952500</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th rowspan=\"4\" valign=\"top\">lt</th>\n",
       "      <th>0.01</th>\n",
       "      <td>2.0</td>\n",
       "      <td>0.989924</td>\n",
       "      <td>0.328575</td>\n",
       "      <td>0.010175</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>0.02</th>\n",
       "      <td>2.0</td>\n",
       "      <td>0.989998</td>\n",
       "      <td>0.478288</td>\n",
       "      <td>0.010170</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>0.05</th>\n",
       "      <td>2.0</td>\n",
       "      <td>0.990051</td>\n",
       "      <td>0.677260</td>\n",
       "      <td>0.010154</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>0.10</th>\n",
       "      <td>2.0</td>\n",
       "      <td>0.990217</td>\n",
       "      <td>0.804914</td>\n",
       "      <td>0.010059</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "                                     seed     auroc      aupr     fpr95\n",
       "dataset  method         label_noise                                    \n",
       "cifar10  conf_learning  0.01          2.5  0.865128  0.235432  1.000000\n",
       "                        0.02          2.5  0.890472  0.268471  1.000000\n",
       "                        0.05          2.5  0.887356  0.303826  1.000000\n",
       "                        0.10          2.5  0.855121  0.331980  1.000000\n",
       "         in_conf        0.01          2.5  0.874435  0.094724  0.502141\n",
       "                        0.02          2.5  0.814045  0.109471  0.716082\n",
       "                        0.05          2.5  0.733242  0.139594  0.845758\n",
       "                        0.10          2.5  0.655193  0.180920  0.940767\n",
       "         loss           0.01          2.5  0.485553  0.009594  0.953333\n",
       "                        0.02          2.5  0.491932  0.019334  0.946724\n",
       "                        0.05          2.5  0.498724  0.049527  0.949958\n",
       "                        0.10          2.5  0.500474  0.100867  0.950967\n",
       "         loss_curvature 0.01          2.5  0.501023  0.009935  0.958939\n",
       "                        0.02          2.5  0.485957  0.018881  0.944286\n",
       "                        0.05          2.5  0.503167  0.050107  0.946937\n",
       "                        0.10          2.5  0.500125  0.100149  0.948033\n",
       "         loss_grad      0.01          2.5  0.494204  0.009765  0.959101\n",
       "                        0.02          2.5  0.494249  0.019500  0.948286\n",
       "                        0.05          2.5  0.504706  0.050856  0.950105\n",
       "                        0.10          2.5  0.502819  0.101140  0.951089\n",
       "         lt             0.01          2.0  0.988792  0.309281  0.010229\n",
       "                        0.02          2.0  0.988920  0.456666  0.010075\n",
       "                        0.05          2.0  0.989658  0.653368  0.010154\n",
       "                        0.10          2.0  0.989893  0.779442  0.010230\n",
       "cifar100 conf_learning  0.01          2.5  0.872252  0.211245  1.000000\n",
       "                        0.02          2.5  0.883835  0.245146  1.000000\n",
       "                        0.05          2.5  0.873299  0.277597  1.000000\n",
       "                        0.10          2.5  0.853590  0.311817  1.000000\n",
       "         in_conf        0.01          2.5  0.728815  0.026181  0.716657\n",
       "                        0.02          2.5  0.725245  0.047506  0.707031\n",
       "                        0.05          2.5  0.710836  0.104555  0.738453\n",
       "                        0.10          2.5  0.691419  0.186432  0.793233\n",
       "         loss           0.01          2.5  0.476275  0.009423  0.954848\n",
       "                        0.02          2.5  0.483714  0.019222  0.955633\n",
       "                        0.05          2.5  0.495444  0.049193  0.948505\n",
       "                        0.10          2.5  0.498975  0.099728  0.947044\n",
       "         loss_curvature 0.01          2.5  0.494691  0.010226  0.960515\n",
       "                        0.02          2.5  0.487300  0.019918  0.949684\n",
       "                        0.05          2.5  0.498324  0.050192  0.951958\n",
       "                        0.10          2.5  0.498045  0.100605  0.953189\n",
       "         loss_grad      0.01          2.5  0.489182  0.009830  0.955899\n",
       "                        0.02          2.5  0.489540  0.019920  0.947735\n",
       "                        0.05          2.5  0.493861  0.049650  0.953021\n",
       "                        0.10          2.5  0.498231  0.100768  0.952500\n",
       "         lt             0.01          2.0  0.989924  0.328575  0.010175\n",
       "                        0.02          2.0  0.989998  0.478288  0.010170\n",
       "                        0.05          2.0  0.990051  0.677260  0.010154\n",
       "                        0.10          2.0  0.990217  0.804914  0.010059"
      ]
     },
     "execution_count": 61,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Gather the per-run evaluation records into a single table. Column order is\n",
    "# seed, label_noise, method, auroc, aupr, fpr95, dataset.\n",
    "results_df = pd.DataFrame(dict(\n",
    "    seed=result_seeds,\n",
    "    label_noise=result_noise,\n",
    "    method=result_method,\n",
    "    auroc=result_auroc,\n",
    "    aupr=result_aupr,\n",
    "    fpr95=result_fpr95,\n",
    "    dataset=result_dataset,\n",
    "))\n",
    "\n",
    "# Show the mean of every remaining column per (dataset, method, label_noise)\n",
    "# group. The seed column is averaged together with the metric columns.\n",
    "(\n",
    "    results_df\n",
    "    .groupby(by=['dataset', 'method', 'label_noise'])\n",
    "    .mean()\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 87,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1 cifar100 0.01\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "2 cifar100 0.01\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "3 cifar100 0.01\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "1 cifar10 0.01\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "2 cifar10 0.01\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "3 cifar10 0.01\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "1 cifar100 0.02\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "2 cifar100 0.02\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "3 cifar100 0.02\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "1 cifar10 0.02\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "2 cifar10 0.02\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "3 cifar10 0.02\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "1 cifar100 0.05\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "2 cifar100 0.05\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "3 cifar100 0.05\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "1 cifar10 0.05\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "2 cifar10 0.05\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "3 cifar10 0.05\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "1 cifar100 0.1\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "2 cifar100 0.1\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "3 cifar100 0.1\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "1 cifar10 0.1\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "2 cifar10 0.1\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "3 cifar10 0.1\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n"
     ]
    }
   ],
   "source": [
    "# Evaluate the 'lt' score (argmax over the per-epoch `pred` arrays) as a detector of\n",
    "# noisy labels for every (label_noise, dataset, seed) combination.\n",
    "# The accumulators are re-created here, so rows collected by earlier cells are dropped.\n",
    "result_seeds = []\n",
    "result_noise = []\n",
    "result_auroc = []\n",
    "result_aupr = []\n",
    "result_fpr95 = []\n",
    "result_method = []\n",
    "result_dataset = []\n",
    "\n",
    "for label_noise in [0.01, 0.02, 0.05, 0.1]:\n",
    "    for dataset in ['cifar100', 'cifar10']:\n",
    "        container_dir = dataset\n",
    "        for seed in [1,2,3]:\n",
    "            print(f\"{seed} {dataset} {label_noise}\")\n",
    "            # Fetch one stored `pred` array per epoch (epochs 0..199) for this run.\n",
    "            preds = []\n",
    "            for epoch in range(200):\n",
    "                pred = get_numpy_from_azure(container_name, container_dir, f\"pred_{dataset}_resnet18_noisy_idx_{seed}_epoch_{epoch}_noise_{label_noise}_tid0.pt\")\n",
    "                preds.append(pred)\n",
    "\n",
    "            # Stack the epochs and transpose so that axis 1 indexes the epoch.\n",
    "            # Presumably each `pred` is 1-D with one entry per training sample -- TODO confirm.\n",
    "            preds = np.array(preds).T\n",
    "            # Index of the first epoch that holds each row's maximum value.\n",
    "            # NOTE(review): if `pred` is a 0/1 correctness flag, a sample that is never\n",
    "            # correct gets 0 here, the same value as a sample correct at epoch 0 -- confirm\n",
    "            # that this is intended.\n",
    "            lt = np.argmax(preds, axis=1)\n",
    "            # print(lt.shape)\n",
    "            # print(lt)\n",
    "            # print(preds[0])\n",
    "\n",
    "            # Load the dataset with the same seed and noise level to read `noisy_idxs`.\n",
    "            # Presumably this reproduces the label noise used for the stored run -- verify.\n",
    "            dataset_obj = load_dataset(\n",
    "                    dataset=dataset,\n",
    "                    val_split=0,\n",
    "                    root_path=config['data_dir'],\n",
    "                    random_seed=seed,\n",
    "                    logger=logger,\n",
    "                    label_noise=label_noise)\n",
    "\n",
    "            # Ground truth for the detector: 1 at every index in `noisy_idxs`, 0 elsewhere.\n",
    "            is_miss = np.zeros((dataset_obj.train_length))\n",
    "            is_miss[dataset_obj.noisy_idxs] = 1\n",
    "\n",
    "            for metric in ['lt']:\n",
    "                # The score passed is 1-lt, so a smaller `lt` ranks a sample as more likely\n",
    "                # to be noisy. NOTE(review): confirm this direction is the intended one.\n",
    "                auroc, aupr, fpr95 = calculate_metrics(is_miss, 1-lt)\n",
    "                result_seeds.append(seed)\n",
    "                result_noise.append(label_noise)\n",
    "                result_auroc.append(auroc)\n",
    "                result_aupr.append(aupr)\n",
    "                result_fpr95.append(fpr95)\n",
    "                result_method.append(metric)\n",
    "                result_dataset.append(dataset)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 91,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Gather the accumulated evaluation records into a single table. Column order is\n",
    "# seed, label_noise, method, auroc, aupr, fpr95, dataset.\n",
    "results_df = pd.DataFrame(dict(\n",
    "    seed=result_seeds,\n",
    "    label_noise=result_noise,\n",
    "    method=result_method,\n",
    "    auroc=result_auroc,\n",
    "    aupr=result_aupr,\n",
    "    fpr95=result_fpr95,\n",
    "    dataset=result_dataset,\n",
    "))\n",
    "\n",
    "# Write the per-group column means (grouped by dataset, method and label_noise)\n",
    "# to ./mean_lt.csv. The seed column is averaged together with the metric columns.\n",
    "(\n",
    "    results_df\n",
    "    .groupby(by=['dataset', 'method', 'label_noise'])\n",
    "    .mean()\n",
    "    .to_csv(\"./mean_lt.csv\")\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 92,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Write the standard deviation of each column within every\n",
    "# (dataset, method, label_noise) group to ./std_lt.csv.\n",
    "# Relies on `results_df` having been built by an earlier cell.\n",
    "(\n",
    "    results_df\n",
    "    .groupby(by=['dataset', 'method', 'label_noise'])\n",
    "    .std()\n",
    "    .to_csv(\"./std_lt.csv\")\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 96,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1 cifar100 0.01\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "2 cifar100 0.01\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "3 cifar100 0.01\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "1 cifar10 0.01\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "2 cifar10 0.01\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "3 cifar10 0.01\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "1 cifar100 0.02\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "2 cifar100 0.02\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "3 cifar100 0.02\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "1 cifar10 0.02\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "2 cifar10 0.02\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "3 cifar10 0.02\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "1 cifar100 0.05\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "2 cifar100 0.05\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "3 cifar100 0.05\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "1 cifar10 0.05\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "2 cifar10 0.05\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "3 cifar10 0.05\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "1 cifar100 0.1\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "2 cifar100 0.1\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "3 cifar100 0.1\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "1 cifar10 0.1\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "2 cifar10 0.1\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "3 cifar10 0.1\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n"
     ]
    }
   ],
   "source": [
    "# Evaluate the '01_loss' score (1 minus the row-wise mean of the per-epoch `pred`\n",
    "# arrays) as a detector of noisy labels for every (label_noise, dataset, seed).\n",
    "# NOTE(review): the result_* lists are appended to without being reset in this cell,\n",
    "# so rows accumulated earlier in the kernel session remain, and re-running this cell\n",
    "# appends duplicate rows.\n",
    "for label_noise in [0.01, 0.02, 0.05, 0.1]:\n",
    "    for dataset in ['cifar100', 'cifar10']:\n",
    "        container_dir = dataset\n",
    "        for seed in [1,2,3]:\n",
    "            print(f\"{seed} {dataset} {label_noise}\")\n",
    "            # Fetch one stored `pred` array per epoch (epochs 0..199) for this run.\n",
    "            preds = []\n",
    "            for epoch in range(200):\n",
    "                pred = get_numpy_from_azure(container_name, container_dir, f\"pred_{dataset}_resnet18_noisy_idx_{seed}_epoch_{epoch}_noise_{label_noise}_tid0.pt\")\n",
    "                preds.append(pred)\n",
    "\n",
    "            # Stack the epochs and transpose so that axis 1 indexes the epoch.\n",
    "            preds = np.array(preds).T\n",
    "            # Disabled variant (kept for reference): calculate the difference between\n",
    "            # adjacent elements to find transitions.\n",
    "            # diff = np.diff(preds, axis=1)\n",
    "            # Active score: mean of each row across the epoch axis. The variable keeps the\n",
    "            # name `lt` although it holds an average rather than a transition index.\n",
    "            lt = preds.mean(1)\n",
    "\n",
    "            # Disabled variant, continued (none of the lines below are executed):\n",
    "            # The transition from 0 to 1 will have a value of 1 in diff\n",
    "            # Add 1 to get the actual index of the transition in the original array\n",
    "            # transitions = np.where(diff == 1)[1] + 1\n",
    "\n",
    "            # # Group the indices by row and find the last transition for each row\n",
    "            # lt = np.array([transitions[transitions // (preds.shape[1] - 1) == i][-1] \n",
    "            #                 if np.any(transitions // (preds.shape[1] - 1) == i) \n",
    "            # #                 else -1 for i in range(preds.shape[0])])\n",
    "\n",
    "            # print(lt)\n",
    "            # print(preds[0].astype(np.int32))\n",
    "            # print(preds[1].astype(np.int32))\n",
    "\n",
    "            # Load the dataset with the same seed and noise level to read `noisy_idxs`.\n",
    "            # Presumably this reproduces the label noise used for the stored run -- verify.\n",
    "            dataset_obj = load_dataset(\n",
    "                    dataset=dataset,\n",
    "                    val_split=0,\n",
    "                    root_path=config['data_dir'],\n",
    "                    random_seed=seed,\n",
    "                    logger=logger,\n",
    "                    label_noise=label_noise)\n",
    "\n",
    "            # Ground truth for the detector: 1 at every index in `noisy_idxs`, 0 elsewhere.\n",
    "            is_miss = np.zeros((dataset_obj.train_length))\n",
    "            is_miss[dataset_obj.noisy_idxs] = 1\n",
    "\n",
    "            for metric in ['01_loss']:\n",
    "                # The score passed is 1-lt, so a lower row mean ranks a sample as more\n",
    "                # likely to be noisy.\n",
    "                auroc, aupr, fpr95 = calculate_metrics(is_miss, 1-lt)\n",
    "                result_seeds.append(seed)\n",
    "                result_noise.append(label_noise)\n",
    "                result_auroc.append(auroc)\n",
    "                result_aupr.append(aupr)\n",
    "                result_fpr95.append(fpr95)\n",
    "                result_method.append(metric)\n",
    "                result_dataset.append(dataset)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 97,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Gather the accumulated evaluation records into a single table. Column order is\n",
    "# seed, label_noise, method, auroc, aupr, fpr95, dataset.\n",
    "# NOTE(review): the table holds whatever the result_* lists contain at this point,\n",
    "# which can include rows for methods other than '01_loss' -- confirm that the\n",
    "# zo_loss CSV files are meant to contain them.\n",
    "results_df = pd.DataFrame(dict(\n",
    "    seed=result_seeds,\n",
    "    label_noise=result_noise,\n",
    "    method=result_method,\n",
    "    auroc=result_auroc,\n",
    "    aupr=result_aupr,\n",
    "    fpr95=result_fpr95,\n",
    "    dataset=result_dataset,\n",
    "))\n",
    "\n",
    "# Per-group column means, grouped by dataset, method and label_noise.\n",
    "(\n",
    "    results_df\n",
    "    .groupby(by=['dataset', 'method', 'label_noise'])\n",
    "    .mean()\n",
    "    .to_csv(\"./mean_zo_loss.csv\")\n",
    ")\n",
    "# Per-group column standard deviations for the same grouping.\n",
    "(\n",
    "    results_df\n",
    "    .groupby(by=['dataset', 'method', 'label_noise'])\n",
    "    .std()\n",
    "    .to_csv(\"./std_zo_loss.csv\")\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 98,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th>seed</th>\n",
       "      <th>auroc</th>\n",
       "      <th>aupr</th>\n",
       "      <th>fpr95</th>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>dataset</th>\n",
       "      <th>method</th>\n",
       "      <th>label_noise</th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th rowspan=\"8\" valign=\"top\">cifar10</th>\n",
       "      <th rowspan=\"4\" valign=\"top\">01_loss</th>\n",
       "      <th>0.01</th>\n",
       "      <td>2.0</td>\n",
       "      <td>0.950145</td>\n",
       "      <td>0.203776</td>\n",
       "      <td>0.183603</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>0.02</th>\n",
       "      <td>2.0</td>\n",
       "      <td>0.952810</td>\n",
       "      <td>0.315056</td>\n",
       "      <td>0.162932</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>0.05</th>\n",
       "      <td>2.0</td>\n",
       "      <td>0.943315</td>\n",
       "      <td>0.473292</td>\n",
       "      <td>0.203751</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>0.10</th>\n",
       "      <td>2.0</td>\n",
       "      <td>0.927446</td>\n",
       "      <td>0.591965</td>\n",
       "      <td>0.267941</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th rowspan=\"4\" valign=\"top\">lt</th>\n",
       "      <th>0.01</th>\n",
       "      <td>2.0</td>\n",
       "      <td>0.495143</td>\n",
       "      <td>0.226429</td>\n",
       "      <td>0.965138</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>0.02</th>\n",
       "      <td>2.0</td>\n",
       "      <td>0.495400</td>\n",
       "      <td>0.221963</td>\n",
       "      <td>0.965259</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>0.05</th>\n",
       "      <td>2.0</td>\n",
       "      <td>0.491135</td>\n",
       "      <td>0.234056</td>\n",
       "      <td>0.962133</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>0.10</th>\n",
       "      <td>2.0</td>\n",
       "      <td>0.494791</td>\n",
       "      <td>0.267576</td>\n",
       "      <td>0.962252</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th rowspan=\"8\" valign=\"top\">cifar100</th>\n",
       "      <th rowspan=\"4\" valign=\"top\">01_loss</th>\n",
       "      <th>0.01</th>\n",
       "      <td>2.0</td>\n",
       "      <td>0.984558</td>\n",
       "      <td>0.297591</td>\n",
       "      <td>0.040256</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>0.02</th>\n",
       "      <td>2.0</td>\n",
       "      <td>0.985654</td>\n",
       "      <td>0.447762</td>\n",
       "      <td>0.033289</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>0.05</th>\n",
       "      <td>2.0</td>\n",
       "      <td>0.985995</td>\n",
       "      <td>0.653129</td>\n",
       "      <td>0.029039</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>0.10</th>\n",
       "      <td>2.0</td>\n",
       "      <td>0.986489</td>\n",
       "      <td>0.789180</td>\n",
       "      <td>0.027096</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th rowspan=\"4\" valign=\"top\">lt</th>\n",
       "      <th>0.01</th>\n",
       "      <td>2.0</td>\n",
       "      <td>0.525569</td>\n",
       "      <td>0.108358</td>\n",
       "      <td>0.978801</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>0.02</th>\n",
       "      <td>2.0</td>\n",
       "      <td>0.522664</td>\n",
       "      <td>0.120862</td>\n",
       "      <td>0.983728</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>0.05</th>\n",
       "      <td>2.0</td>\n",
       "      <td>0.516106</td>\n",
       "      <td>0.145894</td>\n",
       "      <td>0.984688</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>0.10</th>\n",
       "      <td>2.0</td>\n",
       "      <td>0.520254</td>\n",
       "      <td>0.187721</td>\n",
       "      <td>0.981030</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "                              seed     auroc      aupr     fpr95\n",
       "dataset  method  label_noise                                    \n",
       "cifar10  01_loss 0.01          2.0  0.950145  0.203776  0.183603\n",
       "                 0.02          2.0  0.952810  0.315056  0.162932\n",
       "                 0.05          2.0  0.943315  0.473292  0.203751\n",
       "                 0.10          2.0  0.927446  0.591965  0.267941\n",
       "         lt      0.01          2.0  0.495143  0.226429  0.965138\n",
       "                 0.02          2.0  0.495400  0.221963  0.965259\n",
       "                 0.05          2.0  0.491135  0.234056  0.962133\n",
       "                 0.10          2.0  0.494791  0.267576  0.962252\n",
       "cifar100 01_loss 0.01          2.0  0.984558  0.297591  0.040256\n",
       "                 0.02          2.0  0.985654  0.447762  0.033289\n",
       "                 0.05          2.0  0.985995  0.653129  0.029039\n",
       "                 0.10          2.0  0.986489  0.789180  0.027096\n",
       "         lt      0.01          2.0  0.525569  0.108358  0.978801\n",
       "                 0.02          2.0  0.522664  0.120862  0.983728\n",
       "                 0.05          2.0  0.516106  0.145894  0.984688\n",
       "                 0.10          2.0  0.520254  0.187721  0.981030"
      ]
     },
     "execution_count": 98,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Show the mean of each column per (dataset, method, label_noise) group as the\n",
    "# cell output. Relies on `results_df` having been built by an earlier cell.\n",
    "(\n",
    "    results_df\n",
    "    .groupby(by=['dataset', 'method', 'label_noise'])\n",
    "    .mean()\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 61,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "1 cifar100 0.01\n",
      "(50000,)\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "0.8875348080808081 1 0.01 cifar100\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "2 cifar100 0.01\n",
      "(50000,)\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "0.896372808080808 2 0.01 cifar100\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "3 cifar100 0.01\n",
      "(50000,)\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "0.890675595959596 3 0.01 cifar100\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "1 cifar10 0.01\n",
      "(50000,)\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "0.9633156565656567 1 0.01 cifar10\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "2 cifar10 0.01\n",
      "(50000,)\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "0.9604676767676767 2 0.01 cifar10\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "3 cifar10 0.01\n",
      "(50000,)\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "0.9639407878787879 3 0.01 cifar10\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "1 cifar100 0.02\n",
      "(50000,)\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "0.8906332653061224 1 0.02 cifar100\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "2 cifar100 0.02\n",
      "(50000,)\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "0.8893357959183674 2 0.02 cifar100\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "3 cifar100 0.02\n",
      "(50000,)\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "0.8880342448979591 3 0.02 cifar100\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "1 cifar10 0.02\n",
      "(50000,)\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "0.9540723877551021 1 0.02 cifar10\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "2 cifar10 0.02\n",
      "(50000,)\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "0.9574356122448979 2 0.02 cifar10\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "3 cifar10 0.02\n",
      "(50000,)\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "0.953780387755102 3 0.02 cifar10\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "1 cifar100 0.05\n",
      "(50000,)\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "0.8801231157894737 1 0.05 cifar100\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "2 cifar100 0.05\n",
      "(50000,)\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "0.8749439284210526 2 0.05 cifar100\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "3 cifar100 0.05\n",
      "(50000,)\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "0.8800819747368421 3 0.05 cifar100\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "1 cifar10 0.05\n",
      "(50000,)\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "0.9481329347368421 1 0.05 cifar10\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "2 cifar10 0.05\n",
      "(50000,)\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "0.9546395242105263 2 0.05 cifar10\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "3 cifar10 0.05\n",
      "(50000,)\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "0.9466511368421052 3 0.05 cifar10\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "1 cifar100 0.1\n",
      "(50000,)\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "0.8688714577777779 1 0.1 cifar100\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "2 cifar100 0.1\n",
      "(50000,)\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "0.8662577422222222 2 0.1 cifar100\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "3 cifar100 0.1\n",
      "(50000,)\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "0.8641377266666667 3 0.1 cifar100\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "1 cifar10 0.1\n",
      "(50000,)\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "0.9367418577777777 1 0.1 cifar10\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "2 cifar10 0.1\n",
      "(50000,)\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "0.9374978333333333 2 0.1 cifar10\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "3 cifar10 0.1\n",
      "(50000,)\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "0.9337248511111111 3 0.1 cifar10\n"
     ]
    }
   ],
   "source": [
    "from minio_obj_storage import get_numpy_from_cloud\n",
    "import torch\n",
    "\n",
    "# Per-run result columns; the summary cells below build `results_df` from these lists.\n",
    "result_seeds = []\n",
    "result_noise = []\n",
    "result_auroc = []\n",
    "result_aupr = []\n",
    "result_fpr95 = []\n",
    "result_method = []\n",
    "result_dataset = []\n",
    "\n",
    "bucket_name = 'learning-dynamics-scores'\n",
    "epochs_total = 200  # number of rows in the per-sample prediction matrix that gets averaged\n",
    "\n",
    "# Only the length of the stored index is used, so `index` is simply 0..N-1.\n",
    "# NOTE(review): the cifar100 index file is used for cifar10 as well -- confirm both datasets share this length.\n",
    "index = np.arange(len(torch.load(\"./index/data_index_cifar100.pt\")))\n",
    "half = len(index) // 2\n",
    "# Prediction file `part` p only contributes the columns listed in part_indices[p]:\n",
    "# part 0 -> first half of the sample positions, part 1 -> second half.\n",
    "part_indices = {0: index[:half], 1: index[half:]}\n",
    "\n",
    "for label_noise in [0.01, 0.02, 0.05, 0.1]:\n",
    "    for dataset in ['cifar100', 'cifar10']:\n",
    "        container_dir = dataset\n",
    "        for seed in [1, 2, 3]:\n",
    "            # Loaded once per configuration; this object supplies both train_length and noisy_idxs.\n",
    "            # NOTE(review): assumes load_dataset is deterministic for a given random_seed -- confirm.\n",
    "            dataset_obj = load_dataset(\n",
    "                    dataset=dataset,\n",
    "                    val_split=0,\n",
    "                    root_path=config['data_dir'],\n",
    "                    random_seed=seed,\n",
    "                    logger=logger,\n",
    "                    label_noise=label_noise)\n",
    "\n",
    "            print(f\"{seed} {dataset} {label_noise}\")\n",
    "            preds = np.zeros((epochs_total, dataset_obj.train_length))\n",
    "\n",
    "            for part, index1 in part_indices.items():\n",
    "                pred = get_numpy_from_cloud(bucket_name, container_dir, f\"ssft_pred_resnet18_part_{part}_noisy_idx_{seed}_noise_{label_noise}.npy\")\n",
    "                # Keep the rows from `epochs_total` onward, restricted to this part's columns.\n",
    "                # NOTE(review): assumes each file holds 2 * epochs_total rows -- TODO confirm.\n",
    "                preds[:, index1] = pred[epochs_total:, index1]\n",
    "\n",
    "            # Transpose to (samples, epochs) and average over the epoch axis: one score per sample.\n",
    "            preds = preds.T\n",
    "            ft = np.mean(preds, axis=1)\n",
    "            print(ft.shape)\n",
    "\n",
    "            # Binary ground truth: 1 at every position listed in dataset_obj.noisy_idxs, 0 elsewhere.\n",
    "            is_miss = np.zeros(dataset_obj.train_length)\n",
    "            is_miss[dataset_obj.noisy_idxs] = 1\n",
    "\n",
    "            metric = 'ssft'\n",
    "            # The score is inverted so that a lower mean prediction ranks a sample as more likely noisy.\n",
    "            auroc, aupr, fpr95 = calculate_metrics(is_miss, 1 - ft)\n",
    "            print(auroc, seed, label_noise, dataset)\n",
    "            result_seeds.append(seed)\n",
    "            result_noise.append(label_noise)\n",
    "            result_auroc.append(auroc)\n",
    "            result_aupr.append(aupr)\n",
    "            result_fpr95.append(fpr95)\n",
    "            result_method.append(metric)\n",
    "            result_dataset.append(dataset)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 62,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th>seed</th>\n",
       "      <th>auroc</th>\n",
       "      <th>aupr</th>\n",
       "      <th>fpr95</th>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>dataset</th>\n",
       "      <th>method</th>\n",
       "      <th>label_noise</th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th rowspan=\"4\" valign=\"top\">cifar10</th>\n",
       "      <th rowspan=\"4\" valign=\"top\">ssft</th>\n",
       "      <th>0.01</th>\n",
       "      <td>2.0</td>\n",
       "      <td>0.962575</td>\n",
       "      <td>0.480661</td>\n",
       "      <td>0.137313</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>0.02</th>\n",
       "      <td>2.0</td>\n",
       "      <td>0.955096</td>\n",
       "      <td>0.526567</td>\n",
       "      <td>0.171735</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>0.05</th>\n",
       "      <td>2.0</td>\n",
       "      <td>0.949808</td>\n",
       "      <td>0.626042</td>\n",
       "      <td>0.207228</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>0.10</th>\n",
       "      <td>2.0</td>\n",
       "      <td>0.935988</td>\n",
       "      <td>0.689172</td>\n",
       "      <td>0.255696</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th rowspan=\"4\" valign=\"top\">cifar100</th>\n",
       "      <th rowspan=\"4\" valign=\"top\">ssft</th>\n",
       "      <th>0.01</th>\n",
       "      <td>2.0</td>\n",
       "      <td>0.891528</td>\n",
       "      <td>0.476681</td>\n",
       "      <td>0.281152</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>0.02</th>\n",
       "      <td>2.0</td>\n",
       "      <td>0.889334</td>\n",
       "      <td>0.500372</td>\n",
       "      <td>0.272068</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>0.05</th>\n",
       "      <td>2.0</td>\n",
       "      <td>0.878383</td>\n",
       "      <td>0.545858</td>\n",
       "      <td>0.318140</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>0.10</th>\n",
       "      <td>2.0</td>\n",
       "      <td>0.866422</td>\n",
       "      <td>0.608292</td>\n",
       "      <td>0.360326</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "                             seed     auroc      aupr     fpr95\n",
       "dataset  method label_noise                                    \n",
       "cifar10  ssft   0.01          2.0  0.962575  0.480661  0.137313\n",
       "                0.02          2.0  0.955096  0.526567  0.171735\n",
       "                0.05          2.0  0.949808  0.626042  0.207228\n",
       "                0.10          2.0  0.935988  0.689172  0.255696\n",
       "cifar100 ssft   0.01          2.0  0.891528  0.476681  0.281152\n",
       "                0.02          2.0  0.889334  0.500372  0.272068\n",
       "                0.05          2.0  0.878383  0.545858  0.318140\n",
       "                0.10          2.0  0.866422  0.608292  0.360326"
      ]
     },
     "execution_count": 62,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Assemble the per-run result lists into one frame (one row per appended run).\n",
    "result_columns = {\n",
    "    'seed': result_seeds,\n",
    "    'label_noise': result_noise,\n",
    "    'method': result_method,\n",
    "    'auroc': result_auroc,\n",
    "    'aupr': result_aupr,\n",
    "    'fpr95': result_fpr95,\n",
    "    'dataset': result_dataset,\n",
    "}\n",
    "results_df = pd.DataFrame(result_columns)\n",
    "\n",
    "# Mean of every remaining column within each (dataset, method, label_noise) group.\n",
    "group_keys = ['dataset', 'method', 'label_noise']\n",
    "results_df.groupby(group_keys).mean()\n",
    "\n",
    "# Optional CSV export (disabled):\n",
    "# results_df.groupby(group_keys).mean().to_csv(\"./mean_zo_loss.csv\")\n",
    "# results_df.groupby(group_keys).std().to_csv(\"./std_zo_loss.csv\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 64,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th>seed</th>\n",
       "      <th>auroc</th>\n",
       "      <th>aupr</th>\n",
       "      <th>fpr95</th>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>dataset</th>\n",
       "      <th>method</th>\n",
       "      <th>label_noise</th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th rowspan=\"4\" valign=\"top\">cifar10</th>\n",
       "      <th rowspan=\"4\" valign=\"top\">ssft</th>\n",
       "      <th>0.01</th>\n",
       "      <td>1.0</td>\n",
       "      <td>0.001851</td>\n",
       "      <td>0.006744</td>\n",
       "      <td>0.015048</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>0.02</th>\n",
       "      <td>1.0</td>\n",
       "      <td>0.002031</td>\n",
       "      <td>0.015827</td>\n",
       "      <td>0.013989</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>0.05</th>\n",
       "      <td>1.0</td>\n",
       "      <td>0.004249</td>\n",
       "      <td>0.013463</td>\n",
       "      <td>0.022658</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>0.10</th>\n",
       "      <td>1.0</td>\n",
       "      <td>0.001996</td>\n",
       "      <td>0.008184</td>\n",
       "      <td>0.000796</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th rowspan=\"4\" valign=\"top\">cifar100</th>\n",
       "      <th rowspan=\"4\" valign=\"top\">ssft</th>\n",
       "      <th>0.01</th>\n",
       "      <td>1.0</td>\n",
       "      <td>0.004480</td>\n",
       "      <td>0.003406</td>\n",
       "      <td>0.032075</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>0.02</th>\n",
       "      <td>1.0</td>\n",
       "      <td>0.001300</td>\n",
       "      <td>0.004170</td>\n",
       "      <td>0.018522</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>0.05</th>\n",
       "      <td>1.0</td>\n",
       "      <td>0.002978</td>\n",
       "      <td>0.006987</td>\n",
       "      <td>0.002768</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>0.10</th>\n",
       "      <td>1.0</td>\n",
       "      <td>0.002371</td>\n",
       "      <td>0.003180</td>\n",
       "      <td>0.019947</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "                             seed     auroc      aupr     fpr95\n",
       "dataset  method label_noise                                    \n",
       "cifar10  ssft   0.01          1.0  0.001851  0.006744  0.015048\n",
       "                0.02          1.0  0.002031  0.015827  0.013989\n",
       "                0.05          1.0  0.004249  0.013463  0.022658\n",
       "                0.10          1.0  0.001996  0.008184  0.000796\n",
       "cifar100 ssft   0.01          1.0  0.004480  0.003406  0.032075\n",
       "                0.02          1.0  0.001300  0.004170  0.018522\n",
       "                0.05          1.0  0.002978  0.006987  0.002768\n",
       "                0.10          1.0  0.002371  0.003180  0.019947"
      ]
     },
     "execution_count": 64,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Standard deviation of every remaining column within each (dataset, method, label_noise) group.\n",
    "grouped_results = results_df.groupby(['dataset', 'method', 'label_noise'])\n",
    "grouped_results.std()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "py3.11_tf",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.5"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
