{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "6943a46a-79fe-4895-a5a1-65e6c25c7bb0",
   "metadata": {
    "tags": []
   },
   "source": [
    "# Note!\n",
    "\n",
    "To run this notebook, you first need to train models using the bash script we provide. Then in the cell number 5 enter the Path to a valid model."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "2818c06f-8c40-44f5-a0be-7ac7e6c1e6ab",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "import sys\n",
    "sys.path.insert(0, \"../\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "45cff5e0-9bb2-4823-b476-615fbfacd42e",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "import torch\n",
    "from tqdm.auto import tqdm\n",
    "from src.config.models import NatPnModel, LeNet5, ResNet18\n",
    "from src.config.nat_pn.loss import BayesianLoss\n",
    "from src.config.uncertainty_metrics import (\n",
    "    load_dataset,\n",
    "    load_dataloaders,\n",
    "    choose_threshold,\n",
    "    load_model,\n",
    ")\n",
    "from src.config.utils import evaluate_accuracy, evaluate_switch, quantiles\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "from collections import defaultdict\n",
    "import csv\n",
    "import pandas as pd"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "ed8cf6ea-0448-463f-8c43-1801c779276f",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "measure = 'entropy' # entropy, log_prob, epkl"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "6066b953-c86f-4205-9f82-dc763d22b195",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "header = [\n",
    "    'dataset',\n",
    "    'client_id',\n",
    "    'ind_local_acc',\n",
    "    'ind_global_acc',\n",
    "    'ind_switch_acc',\n",
    "    'ood_local_acc',\n",
    "    'ood_global_acc',\n",
    "    'ood_switch_acc',\n",
    "    'mix_local_acc',\n",
    "    'mix_global_acc',\n",
    "    'mix_switch_acc',\n",
    "]\n",
    "\n",
    "# create and write the header to the csv file\n",
    "with open(f'mix_threshold_compact_results_{measure}.csv', 'w', newline='') as file:\n",
    "    writer = csv.writer(file)\n",
    "    writer.writerow(header)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "5dc9061b-466c-4a9d-8821-e4a1e52326b6",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "be1ed33e80484253b9e9b95ead877828",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/20 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "7508908986be48d7a6561444a4a2cdb6",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/20 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "119e31eca53b445b8292b1329aa2d0f0",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/20 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Using downloaded and verified file: ../data/svhn/raw/train_32x32.mat\n",
      "Using downloaded and verified file: ../data/svhn/raw/test_32x32.mat\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "125502a7a19e4248a19546a46000a99d",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/20 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "ba75a91fb9274477aef29eeddcf9cc4f",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/20 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "929e56c9e4c64ccab1579c0c64735811",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/20 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "c16fbf91fb714d2eac31ba84664ef78d",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/20 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "for dataset_name in ['mnist', 'fmnist', 'cifar10', 'svhn', 'medmnistA', 'medmnistC', 'medmnistS']:\n",
    "    prefix = \"None\"\n",
    "    if dataset_name in ['cifar10', 'svhn']:\n",
    "        prefix = 'lenet_' + prefix\n",
    "    density_model = 'flow'\n",
    "    embedding_dim = 16\n",
    "    \n",
    "    path = f\"../out/FedAvg/{prefix}all_params_stopgrad_logp_{dataset_name}_500_natpn.pt\"\n",
    "    # YOUR PATH TO A VALID MODEL IS HERE\n",
    "    # for example \n",
    "    # path = f\"../out/FedAvg/actual_models/{prefix}all_params_stopgrad_logp_{dataset_name}_100_natpn.pt\"\n",
    "    \n",
    "    \n",
    "\n",
    "    backbone = 'lenet5' #'res18' if dataset_name in ['cifar10', 'svhn'] else 'lenet5'\n",
    "    batch_size = 7000 if dataset_name in ['cifar10', 'svhn'] else 25000\n",
    "\n",
    "    stopgrad = False\n",
    "    device = 'cuda:0'\n",
    "\n",
    "    all_params_dict = torch.load(path)\n",
    "\n",
    "    data_indices, trainset, testset = load_dataset(\n",
    "        dataset_name=dataset_name,\n",
    "        normalization_name=dataset_name,\n",
    "    )\n",
    "\n",
    "    global_model = load_model(\n",
    "        dataset_name=dataset_name,\n",
    "        backbone=backbone,\n",
    "        stopgrad=stopgrad,\n",
    "        density_model=density_model,\n",
    "        embedding_dim=embedding_dim,\n",
    "        index='global',\n",
    "        all_params_dict=all_params_dict,\n",
    "    )\n",
    "    global_model.eval()\n",
    "    global_model = global_model.to(device)\n",
    "\n",
    "\n",
    "    ind_local_accuracies = []\n",
    "    ind_global_accuracies = []\n",
    "    ind_switch_accuracies = []\n",
    "\n",
    "    ood_local_accuracies = []\n",
    "    ood_global_accuracies = []\n",
    "    ood_switch_accuracies = []\n",
    "    \n",
    "    mix_local_accuracies = []\n",
    "    mix_global_accuracies = []\n",
    "    mix_switch_accuracies = []\n",
    "\n",
    "\n",
    "    for index in tqdm(range(len(all_params_dict) - 1)):\n",
    "        local_model = load_model(\n",
    "            dataset_name=dataset_name,\n",
    "            backbone=backbone,\n",
    "            density_model=density_model,\n",
    "            embedding_dim=embedding_dim,\n",
    "            stopgrad=stopgrad,\n",
    "            index=index,\n",
    "            all_params_dict=all_params_dict,\n",
    "        )\n",
    "        local_model.eval()\n",
    "        local_model = local_model.to(device)\n",
    "\n",
    "        in_classes = local_model.labels.cpu().numpy().tolist()\n",
    "\n",
    "        # create in-distribution and out-of-distribution datasets\n",
    "        in_distribution_indices = [i for i, t in enumerate(trainset.dataset.targets) if t in in_classes]\n",
    "        out_distribution_indices = [i for i, t in enumerate(trainset.dataset.targets) if t not in in_classes]\n",
    "\n",
    "        in_distribution_dataset = torch.utils.data.Subset(trainset.dataset, in_distribution_indices)\n",
    "        out_distribution_dataset = torch.utils.data.Subset(trainset.dataset, out_distribution_indices)\n",
    "\n",
    "        ind_loader = torch.utils.data.DataLoader(in_distribution_dataset, batch_size=batch_size, shuffle=False)\n",
    "        ood_loader = torch.utils.data.DataLoader(out_distribution_dataset, batch_size=batch_size, shuffle=False)\n",
    "\n",
    "        _, _, calloader = load_dataloaders(\n",
    "        client_id=index, data_indices=data_indices, trainset=trainset, testset=testset\n",
    "        )\n",
    "        \n",
    "        quantile = quantiles[dataset_name]\n",
    "\n",
    "        threshold, values = choose_threshold(\n",
    "            model=local_model,\n",
    "            calloader=calloader,\n",
    "            device=device,\n",
    "            alpha=quantile,\n",
    "        )\n",
    "        ind_correct_decision, ind_correct_local, ind_correct_global, ind_sample_num = evaluate_switch(\n",
    "            local_model=local_model,\n",
    "            global_model=global_model, \n",
    "            dataloader=ind_loader,\n",
    "            threshold=threshold[measure],\n",
    "            uncertainty_measure=measure,\n",
    "            device=device,\n",
    "            return_predictions=False\n",
    "        )\n",
    "\n",
    "        ood_correct_decision, ood_correct_local, ood_correct_global, ood_sample_num, ood_local_predictions, ood_global_predictions, ood_switch_predictions, ood_true_labels, _, _ = evaluate_switch(\n",
    "            local_model=local_model,\n",
    "            global_model=global_model, \n",
    "            dataloader=ood_loader,\n",
    "            threshold=threshold[measure],\n",
    "            uncertainty_measure=measure,\n",
    "            device=device,\n",
    "            return_predictions=True\n",
    "        )\n",
    "\n",
    "        ind_switch_accuracies.append(ind_correct_decision / ind_sample_num)\n",
    "        ood_switch_accuracies.append(ood_correct_decision / ood_sample_num)\n",
    "\n",
    "        ind_local_accuracies.append(ind_correct_local / ind_sample_num)\n",
    "        ind_global_accuracies.append(ind_correct_global / ind_sample_num)\n",
    "\n",
    "        ood_local_accuracies.append(ood_correct_local / ood_sample_num)\n",
    "        ood_global_accuracies.append(ood_correct_global / ood_sample_num)\n",
    "        \n",
    "        \n",
    "        mix_samples = ind_sample_num\n",
    "        \n",
    "        mix_local_correct = ind_correct_local\n",
    "        mix_global_correct = ind_correct_global\n",
    "        mix_switch_correct = ind_correct_decision\n",
    "        \n",
    "        \n",
    "        true_labels = torch.hstack(ood_true_labels)\n",
    "        all_labels = torch.unique(true_labels)\n",
    "        \n",
    "        mix_local_answers = torch.hstack(ood_local_predictions)\n",
    "        mix_global_answers = torch.hstack(ood_global_predictions)\n",
    "        mix_switch_answers = torch.hstack(ood_switch_predictions)\n",
    "        \n",
    "        sample_to_pick = int(ind_sample_num / len(all_labels))\n",
    "        for l in all_labels:\n",
    "            selected_predictions = (mix_local_answers == true_labels)[true_labels == l][:sample_to_pick]\n",
    "            mix_local_correct += selected_predictions.sum().cpu().item()\n",
    "            \n",
    "            selected_predictions = (mix_global_answers == true_labels)[true_labels == l][:sample_to_pick]\n",
    "            mix_global_correct += selected_predictions.sum().cpu().item()\n",
    "            \n",
    "            selected_predictions = (mix_switch_answers == true_labels)[true_labels == l][:sample_to_pick]\n",
    "            mix_switch_correct += selected_predictions.sum().cpu().item()            \n",
    "            \n",
    "            mix_samples += len(selected_predictions)\n",
    "            \n",
    "        mix_local_accuracies.append(mix_local_correct / mix_samples)\n",
    "        mix_global_accuracies.append(mix_global_correct / mix_samples)\n",
    "        mix_switch_accuracies.append(mix_switch_correct / mix_samples)\n",
    "        \n",
    "        \n",
    "        new_row = [\n",
    "            dataset_name,\n",
    "            index,\n",
    "            ind_local_accuracies[-1],\n",
    "            ind_global_accuracies[-1],\n",
    "            ind_switch_accuracies[-1],\n",
    "            ood_local_accuracies[-1],\n",
    "            ood_global_accuracies[-1],\n",
    "            ood_switch_accuracies[-1],\n",
    "            mix_local_accuracies[-1],\n",
    "            mix_global_accuracies[-1],\n",
    "            mix_switch_accuracies[-1],\n",
    "        ]\n",
    "        # open the csv file in append mode ('a') and write the row\n",
    "        with open(f'mix_threshold_compact_results_{measure}.csv', 'a', newline='') as file:\n",
    "            writer = csv.writer(file)\n",
    "            writer.writerow(new_row)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "612f6598-b4be-441e-bdcc-157bc883bb68",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "51ddd5f2-e863-4ad1-9217-c91a492a31d2",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "ba1c1b65-795f-4ea8-80a6-3a372c45511d",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "comp_results = pd.read_csv(f'mix_threshold_compact_results_{measure}.csv')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "ea7dff18-249d-486f-9014-3dc36ac11759",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "grouped_df = comp_results.groupby('dataset')[[col for col in comp_results.columns if col not in  ['client_id', 'dataset']]].mean()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "8db016b7-86ab-41c8-adab-655302d2cf00",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>ind_local_acc</th>\n",
       "      <th>ind_global_acc</th>\n",
       "      <th>ind_switch_acc</th>\n",
       "      <th>ood_local_acc</th>\n",
       "      <th>ood_global_acc</th>\n",
       "      <th>ood_switch_acc</th>\n",
       "      <th>mix_local_acc</th>\n",
       "      <th>mix_global_acc</th>\n",
       "      <th>mix_switch_acc</th>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>dataset</th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>cifar10</th>\n",
       "      <td>0.751285</td>\n",
       "      <td>0.396237</td>\n",
       "      <td>0.591292</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.388211</td>\n",
       "      <td>0.287977</td>\n",
       "      <td>0.375660</td>\n",
       "      <td>0.393820</td>\n",
       "      <td>0.441049</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>fmnist</th>\n",
       "      <td>0.957070</td>\n",
       "      <td>0.799873</td>\n",
       "      <td>0.842708</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.814732</td>\n",
       "      <td>0.781737</td>\n",
       "      <td>0.478535</td>\n",
       "      <td>0.807163</td>\n",
       "      <td>0.812558</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>medmnistA</th>\n",
       "      <td>0.989735</td>\n",
       "      <td>0.958855</td>\n",
       "      <td>0.961896</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.952013</td>\n",
       "      <td>0.949385</td>\n",
       "      <td>0.496466</td>\n",
       "      <td>0.954885</td>\n",
       "      <td>0.955200</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>medmnistC</th>\n",
       "      <td>0.966257</td>\n",
       "      <td>0.938915</td>\n",
       "      <td>0.944187</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.931231</td>\n",
       "      <td>0.887122</td>\n",
       "      <td>0.483515</td>\n",
       "      <td>0.932255</td>\n",
       "      <td>0.911745</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>medmnistS</th>\n",
       "      <td>0.907246</td>\n",
       "      <td>0.829766</td>\n",
       "      <td>0.868632</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.826807</td>\n",
       "      <td>0.755279</td>\n",
       "      <td>0.454031</td>\n",
       "      <td>0.814063</td>\n",
       "      <td>0.800323</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>mnist</th>\n",
       "      <td>0.994038</td>\n",
       "      <td>0.983275</td>\n",
       "      <td>0.983547</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.983121</td>\n",
       "      <td>0.982914</td>\n",
       "      <td>0.497059</td>\n",
       "      <td>0.983246</td>\n",
       "      <td>0.983290</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>svhn</th>\n",
       "      <td>0.922265</td>\n",
       "      <td>0.777547</td>\n",
       "      <td>0.871364</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.775414</td>\n",
       "      <td>0.621665</td>\n",
       "      <td>0.461163</td>\n",
       "      <td>0.762285</td>\n",
       "      <td>0.733723</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "           ind_local_acc  ind_global_acc  ind_switch_acc  ood_local_acc   \n",
       "dataset                                                                   \n",
       "cifar10         0.751285        0.396237        0.591292            0.0  \\\n",
       "fmnist          0.957070        0.799873        0.842708            0.0   \n",
       "medmnistA       0.989735        0.958855        0.961896            0.0   \n",
       "medmnistC       0.966257        0.938915        0.944187            0.0   \n",
       "medmnistS       0.907246        0.829766        0.868632            0.0   \n",
       "mnist           0.994038        0.983275        0.983547            0.0   \n",
       "svhn            0.922265        0.777547        0.871364            0.0   \n",
       "\n",
       "           ood_global_acc  ood_switch_acc  mix_local_acc  mix_global_acc   \n",
       "dataset                                                                    \n",
       "cifar10          0.388211        0.287977       0.375660        0.393820  \\\n",
       "fmnist           0.814732        0.781737       0.478535        0.807163   \n",
       "medmnistA        0.952013        0.949385       0.496466        0.954885   \n",
       "medmnistC        0.931231        0.887122       0.483515        0.932255   \n",
       "medmnistS        0.826807        0.755279       0.454031        0.814063   \n",
       "mnist            0.983121        0.982914       0.497059        0.983246   \n",
       "svhn             0.775414        0.621665       0.461163        0.762285   \n",
       "\n",
       "           mix_switch_acc  \n",
       "dataset                    \n",
       "cifar10          0.441049  \n",
       "fmnist           0.812558  \n",
       "medmnistA        0.955200  \n",
       "medmnistC        0.911745  \n",
       "medmnistS        0.800323  \n",
       "mnist            0.983290  \n",
       "svhn             0.733723  "
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "grouped_df"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "6ca46426-eb1f-442e-a397-e98752b79143",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Create a list to hold the new multi-index tuples\n",
    "new_columns = []\n",
    "\n",
    "# Loop through the existing columns\n",
    "for column in grouped_df.columns:\n",
    "    # Split the column name into the group and the rest\n",
    "    group, rest = column.split('_', 1)\n",
    "    # Add the new tuple to the list\n",
    "    new_columns.append((group, rest))\n",
    "\n",
    "# Assign the new multi-index to the DataFrame\n",
    "grouped_df.columns = pd.MultiIndex.from_tuples(new_columns, names=['group', 'metric'])\n",
    "\n",
    "# Now you can select columns by group\n",
    "ind_df = grouped_df['ind']\n",
    "ood_df = grouped_df['ood']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "b7c58b1f-ab0a-4cbb-b8e6-6188d7b89e50",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead tr th {\n",
       "        text-align: left;\n",
       "    }\n",
       "\n",
       "    .dataframe thead tr:last-of-type th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr>\n",
       "      <th>group</th>\n",
       "      <th colspan=\"3\" halign=\"left\">ind</th>\n",
       "      <th colspan=\"3\" halign=\"left\">ood</th>\n",
       "      <th colspan=\"3\" halign=\"left\">mix</th>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>metric</th>\n",
       "      <th>local_acc</th>\n",
       "      <th>global_acc</th>\n",
       "      <th>switch_acc</th>\n",
       "      <th>local_acc</th>\n",
       "      <th>global_acc</th>\n",
       "      <th>switch_acc</th>\n",
       "      <th>local_acc</th>\n",
       "      <th>global_acc</th>\n",
       "      <th>switch_acc</th>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>dataset</th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>cifar10</th>\n",
       "      <td>0.751285</td>\n",
       "      <td>0.396237</td>\n",
       "      <td>0.591292</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.388211</td>\n",
       "      <td>0.287977</td>\n",
       "      <td>0.375660</td>\n",
       "      <td>0.393820</td>\n",
       "      <td>0.441049</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>fmnist</th>\n",
       "      <td>0.957070</td>\n",
       "      <td>0.799873</td>\n",
       "      <td>0.842708</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.814732</td>\n",
       "      <td>0.781737</td>\n",
       "      <td>0.478535</td>\n",
       "      <td>0.807163</td>\n",
       "      <td>0.812558</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>medmnistA</th>\n",
       "      <td>0.989735</td>\n",
       "      <td>0.958855</td>\n",
       "      <td>0.961896</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.952013</td>\n",
       "      <td>0.949385</td>\n",
       "      <td>0.496466</td>\n",
       "      <td>0.954885</td>\n",
       "      <td>0.955200</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>medmnistC</th>\n",
       "      <td>0.966257</td>\n",
       "      <td>0.938915</td>\n",
       "      <td>0.944187</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.931231</td>\n",
       "      <td>0.887122</td>\n",
       "      <td>0.483515</td>\n",
       "      <td>0.932255</td>\n",
       "      <td>0.911745</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>medmnistS</th>\n",
       "      <td>0.907246</td>\n",
       "      <td>0.829766</td>\n",
       "      <td>0.868632</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.826807</td>\n",
       "      <td>0.755279</td>\n",
       "      <td>0.454031</td>\n",
       "      <td>0.814063</td>\n",
       "      <td>0.800323</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>mnist</th>\n",
       "      <td>0.994038</td>\n",
       "      <td>0.983275</td>\n",
       "      <td>0.983547</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.983121</td>\n",
       "      <td>0.982914</td>\n",
       "      <td>0.497059</td>\n",
       "      <td>0.983246</td>\n",
       "      <td>0.983290</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>svhn</th>\n",
       "      <td>0.922265</td>\n",
       "      <td>0.777547</td>\n",
       "      <td>0.871364</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.775414</td>\n",
       "      <td>0.621665</td>\n",
       "      <td>0.461163</td>\n",
       "      <td>0.762285</td>\n",
       "      <td>0.733723</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "group           ind                             ood                         \n",
       "metric    local_acc global_acc switch_acc local_acc global_acc switch_acc   \n",
       "dataset                                                                     \n",
       "cifar10    0.751285   0.396237   0.591292       0.0   0.388211   0.287977  \\\n",
       "fmnist     0.957070   0.799873   0.842708       0.0   0.814732   0.781737   \n",
       "medmnistA  0.989735   0.958855   0.961896       0.0   0.952013   0.949385   \n",
       "medmnistC  0.966257   0.938915   0.944187       0.0   0.931231   0.887122   \n",
       "medmnistS  0.907246   0.829766   0.868632       0.0   0.826807   0.755279   \n",
       "mnist      0.994038   0.983275   0.983547       0.0   0.983121   0.982914   \n",
       "svhn       0.922265   0.777547   0.871364       0.0   0.775414   0.621665   \n",
       "\n",
       "group           mix                        \n",
       "metric    local_acc global_acc switch_acc  \n",
       "dataset                                    \n",
       "cifar10    0.375660   0.393820   0.441049  \n",
       "fmnist     0.478535   0.807163   0.812558  \n",
       "medmnistA  0.496466   0.954885   0.955200  \n",
       "medmnistC  0.483515   0.932255   0.911745  \n",
       "medmnistS  0.454031   0.814063   0.800323  \n",
       "mnist      0.497059   0.983246   0.983290  \n",
       "svhn       0.461163   0.762285   0.733723  "
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "grouped_df"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "abcd1c7d-265b-48d7-8e30-16b6f6f23e70",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead tr th {\n",
       "        text-align: left;\n",
       "    }\n",
       "\n",
       "    .dataframe thead tr:last-of-type th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr>\n",
       "      <th>group</th>\n",
       "      <th colspan=\"3\" halign=\"left\">ind</th>\n",
       "      <th colspan=\"3\" halign=\"left\">ood</th>\n",
       "      <th colspan=\"3\" halign=\"left\">mix</th>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>metric</th>\n",
       "      <th>local_acc</th>\n",
       "      <th>global_acc</th>\n",
       "      <th>switch_acc</th>\n",
       "      <th>local_acc</th>\n",
       "      <th>global_acc</th>\n",
       "      <th>switch_acc</th>\n",
       "      <th>local_acc</th>\n",
       "      <th>global_acc</th>\n",
       "      <th>switch_acc</th>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>dataset</th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>cifar10</th>\n",
       "      <td>75.1</td>\n",
       "      <td>39.6</td>\n",
       "      <td>59.1</td>\n",
       "      <td>0.0</td>\n",
       "      <td>38.8</td>\n",
       "      <td>28.8</td>\n",
       "      <td>37.6</td>\n",
       "      <td>39.4</td>\n",
       "      <td>44.1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>fmnist</th>\n",
       "      <td>95.7</td>\n",
       "      <td>80.0</td>\n",
       "      <td>84.3</td>\n",
       "      <td>0.0</td>\n",
       "      <td>81.5</td>\n",
       "      <td>78.2</td>\n",
       "      <td>47.9</td>\n",
       "      <td>80.7</td>\n",
       "      <td>81.3</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>medmnistA</th>\n",
       "      <td>99.0</td>\n",
       "      <td>95.9</td>\n",
       "      <td>96.2</td>\n",
       "      <td>0.0</td>\n",
       "      <td>95.2</td>\n",
       "      <td>94.9</td>\n",
       "      <td>49.6</td>\n",
       "      <td>95.5</td>\n",
       "      <td>95.5</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>medmnistC</th>\n",
       "      <td>96.6</td>\n",
       "      <td>93.9</td>\n",
       "      <td>94.4</td>\n",
       "      <td>0.0</td>\n",
       "      <td>93.1</td>\n",
       "      <td>88.7</td>\n",
       "      <td>48.4</td>\n",
       "      <td>93.2</td>\n",
       "      <td>91.2</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>medmnistS</th>\n",
       "      <td>90.7</td>\n",
       "      <td>83.0</td>\n",
       "      <td>86.9</td>\n",
       "      <td>0.0</td>\n",
       "      <td>82.7</td>\n",
       "      <td>75.5</td>\n",
       "      <td>45.4</td>\n",
       "      <td>81.4</td>\n",
       "      <td>80.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>mnist</th>\n",
       "      <td>99.4</td>\n",
       "      <td>98.3</td>\n",
       "      <td>98.4</td>\n",
       "      <td>0.0</td>\n",
       "      <td>98.3</td>\n",
       "      <td>98.3</td>\n",
       "      <td>49.7</td>\n",
       "      <td>98.3</td>\n",
       "      <td>98.3</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>svhn</th>\n",
       "      <td>92.2</td>\n",
       "      <td>77.8</td>\n",
       "      <td>87.1</td>\n",
       "      <td>0.0</td>\n",
       "      <td>77.5</td>\n",
       "      <td>62.2</td>\n",
       "      <td>46.1</td>\n",
       "      <td>76.2</td>\n",
       "      <td>73.4</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "group           ind                             ood                         \n",
       "metric    local_acc global_acc switch_acc local_acc global_acc switch_acc   \n",
       "dataset                                                                     \n",
       "cifar10        75.1       39.6       59.1       0.0       38.8       28.8  \\\n",
       "fmnist         95.7       80.0       84.3       0.0       81.5       78.2   \n",
       "medmnistA      99.0       95.9       96.2       0.0       95.2       94.9   \n",
       "medmnistC      96.6       93.9       94.4       0.0       93.1       88.7   \n",
       "medmnistS      90.7       83.0       86.9       0.0       82.7       75.5   \n",
       "mnist          99.4       98.3       98.4       0.0       98.3       98.3   \n",
       "svhn           92.2       77.8       87.1       0.0       77.5       62.2   \n",
       "\n",
       "group           mix                        \n",
       "metric    local_acc global_acc switch_acc  \n",
       "dataset                                    \n",
       "cifar10        37.6       39.4       44.1  \n",
       "fmnist         47.9       80.7       81.3  \n",
       "medmnistA      49.6       95.5       95.5  \n",
       "medmnistC      48.4       93.2       91.2  \n",
       "medmnistS      45.4       81.4       80.0  \n",
       "mnist          49.7       98.3       98.3  \n",
       "svhn           46.1       76.2       73.4  "
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Scale the grouped accuracy table by 100 and round to one decimal for display.\n",
    "grouped_df.mul(100).round(1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "75285eb3-80ec-42ba-ae0c-62c12b2c5b68",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\\begin{tabular}{lrrrrrrrrr}\n",
      "\\toprule\n",
      "group & \\multicolumn{3}{r}{ind} & \\multicolumn{3}{r}{ood} & \\multicolumn{3}{r}{mix} \\\\\n",
      "metric & local\\_acc & global\\_acc & switch\\_acc & local\\_acc & global\\_acc & switch\\_acc & local\\_acc & global\\_acc & switch\\_acc \\\\\n",
      "dataset &  &  &  &  &  &  &  &  &  \\\\\n",
      "\\midrule\n",
      "cifar10 & 75.1 & 39.6 & 59.1 & 0.0 & 38.8 & 28.8 & 37.6 & 39.4 & 44.1 \\\\\n",
      "fmnist & 95.7 & 80.0 & 84.3 & 0.0 & 81.5 & 78.2 & 47.9 & 80.7 & 81.3 \\\\\n",
      "medmnistA & 99.0 & 95.9 & 96.2 & 0.0 & 95.2 & 94.9 & 49.6 & 95.5 & 95.5 \\\\\n",
      "medmnistC & 96.6 & 93.9 & 94.4 & 0.0 & 93.1 & 88.7 & 48.4 & 93.2 & 91.2 \\\\\n",
      "medmnistS & 90.7 & 83.0 & 86.9 & 0.0 & 82.7 & 75.5 & 45.4 & 81.4 & 80.0 \\\\\n",
      "mnist & 99.4 & 98.3 & 98.4 & 0.0 & 98.3 & 98.3 & 49.7 & 98.3 & 98.3 \\\\\n",
      "svhn & 92.2 & 77.8 & 87.1 & 0.0 & 77.5 & 62.2 & 46.1 & 76.2 & 73.4 \\\\\n",
      "\\bottomrule\n",
      "\\end{tabular}\n",
      "\n"
     ]
    }
   ],
   "source": [
    "# Export the percentage table as LaTeX. print() is deliberate here: it emits the raw\n",
    "# markup for copy-pasting instead of a quoted string repr.\n",
    "# escape=True is required because the column labels contain underscores (local_acc, ...),\n",
    "# which would otherwise be emitted verbatim and break compilation in LaTeX text mode.\n",
    "print((grouped_df * 100).round(1).to_latex(float_format=\"%.1f\", escape=True))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fb1c1623-1a59-4f63-a399-0b5daae2e81f",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9ca77e84",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
