{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "6943a46a-79fe-4895-a5a1-65e6c25c7bb0",
   "metadata": {
    "tags": []
   },
   "source": [
    "# Note!\n",
    "\n",
    "To run this notebook, you first need to train models using the bash script we provide. Then in the cell number 5 enter the Path to a valid model."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "2818c06f-8c40-44f5-a0be-7ac7e6c1e6ab",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "import sys\n",
    "sys.path.insert(0, \"../\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "45cff5e0-9bb2-4823-b476-615fbfacd42e",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "import torch\n",
    "from tqdm.auto import tqdm\n",
    "from src.config.models import NatPnModel, LeNet5, ResNet18\n",
    "from src.config.nat_pn.loss import BayesianLoss\n",
    "from src.config.uncertainty_metrics import (\n",
    "    load_dataset,\n",
    "    load_dataloaders,\n",
    "    choose_threshold,\n",
    "    load_model,\n",
    ")\n",
    "from src.config.utils import evaluate_accuracy, evaluate_switch, quantiles\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "from collections import defaultdict\n",
    "import csv\n",
    "import pandas as pd"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "ed8cf6ea-0448-463f-8c43-1801c779276f",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "measure = 'log_prob'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6066b953-c86f-4205-9f82-dc763d22b195",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5dc9061b-466c-4a9d-8821-e4a1e52326b6",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "eab1ed6e330a40e7869ca67b35b33447",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/20 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "44a82cc59e3747a0afe4f428469a6d96",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/20 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "for dataset_name in ['mnist', 'fmnist', 'cifar10', 'svhn', 'medmnistA', 'medmnistC', 'medmnistS']:\n",
    "    prefix = \"None\"\n",
    "    density_model = 'flow'\n",
    "    embedding_dim = 16\n",
    "    \n",
    "    path = f\"../out/FedAvg/{prefix}all_params_stopgrad_logp_{dataset_name}_500_natpn.pt\"\n",
    "    # YOUR PATH TO A VALID MODEL IS HERE\n",
    "    # for example \n",
    "    # path = f\"../out/FedAvg/actual_models/{prefix}all_params_stopgrad_logp_{dataset_name}_100_natpn.pt\"\n",
    "    \n",
    "    \n",
    "\n",
    "    backbone = 'lenet5' #'res18' if dataset_name in ['cifar10', 'svhn'] else 'lenet5'\n",
    "    batch_size = 7000 if dataset_name in ['cifar10', 'svhn'] else 25000\n",
    "\n",
    "    stopgrad = False\n",
    "    device = 'cuda:0'\n",
    "\n",
    "    all_params_dict = torch.load(path)\n",
    "\n",
    "    data_indices, trainset, testset = load_dataset(\n",
    "        dataset_name=dataset_name,\n",
    "        normalization_name=dataset_name,\n",
    "    )\n",
    "\n",
    "    global_model = load_model(\n",
    "        dataset_name=dataset_name,\n",
    "        backbone=backbone,\n",
    "        stopgrad=stopgrad,\n",
    "        density_model=density_model,\n",
    "        embedding_dim=embedding_dim,\n",
    "        index='global',\n",
    "        all_params_dict=all_params_dict,\n",
    "    )\n",
    "    global_model.eval()\n",
    "    global_model = global_model.to(device)\n",
    "\n",
    "\n",
    "    ind_local_accuracies = []\n",
    "    ind_global_accuracies = []\n",
    "    ind_switch_accuracies = []\n",
    "\n",
    "    ood_local_accuracies = []\n",
    "    ood_global_accuracies = []\n",
    "    ood_switch_accuracies = []\n",
    "    \n",
    "    mix_local_accuracies = []\n",
    "    mix_global_accuracies = []\n",
    "    mix_switch_accuracies = []\n",
    "\n",
    "\n",
    "    for index in tqdm(range(len(all_params_dict) - 1)):\n",
    "        local_model = load_model(\n",
    "            dataset_name=dataset_name,\n",
    "            backbone=backbone,\n",
    "            density_model=density_model,\n",
    "            embedding_dim=embedding_dim,\n",
    "            stopgrad=stopgrad,\n",
    "            index=index,\n",
    "            all_params_dict=all_params_dict,\n",
    "        )\n",
    "        local_model.eval()\n",
    "        local_model = local_model.to(device)\n",
    "\n",
    "        in_classes = local_model.labels.cpu().numpy().tolist()\n",
    "\n",
    "        # create in-distribution and out-of-distribution datasets\n",
    "        in_distribution_indices = [i for i, t in enumerate(trainset.dataset.targets) if t in in_classes]\n",
    "        out_distribution_indices = [i for i, t in enumerate(trainset.dataset.targets) if t not in in_classes]\n",
    "\n",
    "        in_distribution_dataset = torch.utils.data.Subset(trainset.dataset, in_distribution_indices)\n",
    "        out_distribution_dataset = torch.utils.data.Subset(trainset.dataset, out_distribution_indices)\n",
    "\n",
    "        ind_loader = torch.utils.data.DataLoader(in_distribution_dataset, batch_size=batch_size, shuffle=False)\n",
    "        ood_loader = torch.utils.data.DataLoader(out_distribution_dataset, batch_size=batch_size, shuffle=False)\n",
    "\n",
    "        _, _, calloader = load_dataloaders(\n",
    "        client_id=index, data_indices=data_indices, trainset=trainset, testset=testset\n",
    "        )\n",
    "        \n",
    "        quantile = quantiles[dataset_name]\n",
    "\n",
    "        threshold, values = choose_threshold(\n",
    "            model=local_model,\n",
    "            calloader=calloader,\n",
    "            device=device,\n",
    "            alpha=quantile,\n",
    "        )\n",
    "        ind_correct_decision, ind_correct_local, ind_correct_global, ind_sample_num = evaluate_switch(\n",
    "            local_model=local_model,\n",
    "            global_model=global_model, \n",
    "            dataloader=ind_loader,\n",
    "            threshold=threshold[measure],\n",
    "            uncertainty_measure=measure,\n",
    "            device=device,\n",
    "            return_predictions=False\n",
    "        )\n",
    "\n",
    "        ood_correct_decision, ood_correct_local, ood_correct_global, ood_sample_num, ood_local_predictions, ood_global_predictions, ood_switch_predictions, ood_true_labels, _, _ = evaluate_switch(\n",
    "            local_model=local_model,\n",
    "            global_model=global_model, \n",
    "            dataloader=ood_loader,\n",
    "            threshold=threshold[measure],\n",
    "            uncertainty_measure=measure,\n",
    "            device=device,\n",
    "            return_predictions=True\n",
    "        )\n",
    "\n",
    "        ind_switch_accuracies.append(ind_correct_decision / ind_sample_num)\n",
    "        ood_switch_accuracies.append(ood_correct_decision / ood_sample_num)\n",
    "\n",
    "        ind_local_accuracies.append(ind_correct_local / ind_sample_num)\n",
    "        ind_global_accuracies.append(ind_correct_global / ind_sample_num)\n",
    "\n",
    "        ood_local_accuracies.append(ood_correct_local / ood_sample_num)\n",
    "        ood_global_accuracies.append(ood_correct_global / ood_sample_num)\n",
    "        \n",
    "        \n",
    "        mix_samples = ind_sample_num\n",
    "        \n",
    "        mix_local_correct = ind_correct_local\n",
    "        mix_global_correct = ind_correct_global\n",
    "        mix_switch_correct = ind_correct_decision\n",
    "        \n",
    "        \n",
    "        true_labels = torch.hstack(ood_true_labels)\n",
    "        all_labels = torch.unique(true_labels)\n",
    "        \n",
    "        mix_local_answers = torch.hstack(ood_local_predictions)\n",
    "        mix_global_answers = torch.hstack(ood_global_predictions)\n",
    "        mix_switch_answers = torch.hstack(ood_switch_predictions)\n",
    "        \n",
    "        sample_to_pick = int(ind_sample_num / len(all_labels))\n",
    "        for l in all_labels:\n",
    "            selected_predictions = (mix_local_answers == true_labels)[true_labels == l][:sample_to_pick]\n",
    "            mix_local_correct += selected_predictions.sum().cpu().item()\n",
    "            \n",
    "            selected_predictions = (mix_global_answers == true_labels)[true_labels == l][:sample_to_pick]\n",
    "            mix_global_correct += selected_predictions.sum().cpu().item()\n",
    "            \n",
    "            selected_predictions = (mix_switch_answers == true_labels)[true_labels == l][:sample_to_pick]\n",
    "            mix_switch_correct += selected_predictions.sum().cpu().item()            \n",
    "            \n",
    "            mix_samples += len(selected_predictions)\n",
    "            \n",
    "        mix_local_accuracies.append(mix_local_correct / mix_samples)\n",
    "        mix_global_accuracies.append(mix_global_correct / mix_samples)\n",
    "        mix_switch_accuracies.append(mix_switch_correct / mix_samples)\n",
    "        \n",
    "        \n",
    "        new_row = [\n",
    "            dataset_name,\n",
    "            index,\n",
    "            ind_local_accuracies[-1],\n",
    "            ind_global_accuracies[-1],\n",
    "            ind_switch_accuracies[-1],\n",
    "            ood_local_accuracies[-1],\n",
    "            ood_global_accuracies[-1],\n",
    "            ood_switch_accuracies[-1],\n",
    "            mix_local_accuracies[-1],\n",
    "            mix_global_accuracies[-1],\n",
    "            mix_switch_accuracies[-1],\n",
    "        ]\n",
    "        # open the csv file in append mode ('a') and write the row\n",
    "        with open(f'mix_threshold_compact_results_{measure}.csv', 'a', newline='') as file:\n",
    "            writer = csv.writer(file)\n",
    "            writer.writerow(new_row)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "612f6598-b4be-441e-bdcc-157bc883bb68",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "51ddd5f2-e863-4ad1-9217-c91a492a31d2",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ba1c1b65-795f-4ea8-80a6-3a372c45511d",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "comp_results = pd.read_csv('mix_threshold_compact_results.csv')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ea7dff18-249d-486f-9014-3dc36ac11759",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "grouped_df = comp_results.groupby('dataset')[[col for col in comp_results.columns if col not in  ['client_id', 'dataset']]].mean()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8db016b7-86ab-41c8-adab-655302d2cf00",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "grouped_df"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6ca46426-eb1f-442e-a397-e98752b79143",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Create a list to hold the new multi-index tuples\n",
    "new_columns = []\n",
    "\n",
    "# Loop through the existing columns\n",
    "for column in grouped_df.columns:\n",
    "    # Split the column name into the group and the rest\n",
    "    group, rest = column.split('_', 1)\n",
    "    # Add the new tuple to the list\n",
    "    new_columns.append((group, rest))\n",
    "\n",
    "# Assign the new multi-index to the DataFrame\n",
    "grouped_df.columns = pd.MultiIndex.from_tuples(new_columns, names=['group', 'metric'])\n",
    "\n",
    "# Now you can select columns by group\n",
    "ind_df = grouped_df['ind']\n",
    "ood_df = grouped_df['ood']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "id": "b7c58b1f-ab0a-4cbb-b8e6-6188d7b89e50",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead tr th {\n",
       "        text-align: left;\n",
       "    }\n",
       "\n",
       "    .dataframe thead tr:last-of-type th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr>\n",
       "      <th>group</th>\n",
       "      <th colspan=\"3\" halign=\"left\">ind</th>\n",
       "      <th colspan=\"3\" halign=\"left\">ood</th>\n",
       "      <th colspan=\"3\" halign=\"left\">mix</th>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>metric</th>\n",
       "      <th>local_acc</th>\n",
       "      <th>global_acc</th>\n",
       "      <th>switch_acc</th>\n",
       "      <th>local_acc</th>\n",
       "      <th>global_acc</th>\n",
       "      <th>switch_acc</th>\n",
       "      <th>local_acc</th>\n",
       "      <th>global_acc</th>\n",
       "      <th>switch_acc</th>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>dataset</th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>cifar10</th>\n",
       "      <td>0.900215</td>\n",
       "      <td>0.307864</td>\n",
       "      <td>0.675893</td>\n",
       "      <td>0.000845</td>\n",
       "      <td>0.294641</td>\n",
       "      <td>0.269158</td>\n",
       "      <td>0.450491</td>\n",
       "      <td>0.300517</td>\n",
       "      <td>0.472019</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>fmnist</th>\n",
       "      <td>0.977021</td>\n",
       "      <td>0.635744</td>\n",
       "      <td>0.855219</td>\n",
       "      <td>0.000010</td>\n",
       "      <td>0.690713</td>\n",
       "      <td>0.618427</td>\n",
       "      <td>0.488514</td>\n",
       "      <td>0.662771</td>\n",
       "      <td>0.736789</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>medmnistA</th>\n",
       "      <td>0.970899</td>\n",
       "      <td>0.655083</td>\n",
       "      <td>0.836334</td>\n",
       "      <td>0.008557</td>\n",
       "      <td>0.671791</td>\n",
       "      <td>0.597737</td>\n",
       "      <td>0.493356</td>\n",
       "      <td>0.638717</td>\n",
       "      <td>0.695090</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>medmnistC</th>\n",
       "      <td>0.815595</td>\n",
       "      <td>0.552472</td>\n",
       "      <td>0.718647</td>\n",
       "      <td>0.021637</td>\n",
       "      <td>0.597626</td>\n",
       "      <td>0.341459</td>\n",
       "      <td>0.417899</td>\n",
       "      <td>0.545564</td>\n",
       "      <td>0.531786</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>medmnistS</th>\n",
       "      <td>0.851344</td>\n",
       "      <td>0.431561</td>\n",
       "      <td>0.645904</td>\n",
       "      <td>0.001495</td>\n",
       "      <td>0.501779</td>\n",
       "      <td>0.294639</td>\n",
       "      <td>0.426700</td>\n",
       "      <td>0.397341</td>\n",
       "      <td>0.436633</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>mnist</th>\n",
       "      <td>0.996584</td>\n",
       "      <td>0.980309</td>\n",
       "      <td>0.984199</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.980538</td>\n",
       "      <td>0.976770</td>\n",
       "      <td>0.498332</td>\n",
       "      <td>0.980417</td>\n",
       "      <td>0.980415</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>svhn</th>\n",
       "      <td>0.958544</td>\n",
       "      <td>0.359524</td>\n",
       "      <td>0.844695</td>\n",
       "      <td>0.009107</td>\n",
       "      <td>0.319849</td>\n",
       "      <td>0.292301</td>\n",
       "      <td>0.485988</td>\n",
       "      <td>0.372051</td>\n",
       "      <td>0.597601</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "group           ind                             ood                         \n",
       "metric    local_acc global_acc switch_acc local_acc global_acc switch_acc   \n",
       "dataset                                                                     \n",
       "cifar10    0.900215   0.307864   0.675893  0.000845   0.294641   0.269158  \\\n",
       "fmnist     0.977021   0.635744   0.855219  0.000010   0.690713   0.618427   \n",
       "medmnistA  0.970899   0.655083   0.836334  0.008557   0.671791   0.597737   \n",
       "medmnistC  0.815595   0.552472   0.718647  0.021637   0.597626   0.341459   \n",
       "medmnistS  0.851344   0.431561   0.645904  0.001495   0.501779   0.294639   \n",
       "mnist      0.996584   0.980309   0.984199  0.000000   0.980538   0.976770   \n",
       "svhn       0.958544   0.359524   0.844695  0.009107   0.319849   0.292301   \n",
       "\n",
       "group           mix                        \n",
       "metric    local_acc global_acc switch_acc  \n",
       "dataset                                    \n",
       "cifar10    0.450491   0.300517   0.472019  \n",
       "fmnist     0.488514   0.662771   0.736789  \n",
       "medmnistA  0.493356   0.638717   0.695090  \n",
       "medmnistC  0.417899   0.545564   0.531786  \n",
       "medmnistS  0.426700   0.397341   0.436633  \n",
       "mnist      0.498332   0.980417   0.980415  \n",
       "svhn       0.485988   0.372051   0.597601  "
      ]
     },
     "execution_count": 24,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "grouped_df"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "id": "abcd1c7d-265b-48d7-8e30-16b6f6f23e70",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead tr th {\n",
       "        text-align: left;\n",
       "    }\n",
       "\n",
       "    .dataframe thead tr:last-of-type th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr>\n",
       "      <th>group</th>\n",
       "      <th colspan=\"3\" halign=\"left\">ind</th>\n",
       "      <th colspan=\"3\" halign=\"left\">ood</th>\n",
       "      <th colspan=\"3\" halign=\"left\">mix</th>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>metric</th>\n",
       "      <th>local_acc</th>\n",
       "      <th>global_acc</th>\n",
       "      <th>switch_acc</th>\n",
       "      <th>local_acc</th>\n",
       "      <th>global_acc</th>\n",
       "      <th>switch_acc</th>\n",
       "      <th>local_acc</th>\n",
       "      <th>global_acc</th>\n",
       "      <th>switch_acc</th>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>dataset</th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>cifar10</th>\n",
       "      <td>90.0</td>\n",
       "      <td>30.8</td>\n",
       "      <td>67.6</td>\n",
       "      <td>0.1</td>\n",
       "      <td>29.5</td>\n",
       "      <td>26.9</td>\n",
       "      <td>45.0</td>\n",
       "      <td>30.1</td>\n",
       "      <td>47.2</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>fmnist</th>\n",
       "      <td>97.7</td>\n",
       "      <td>63.6</td>\n",
       "      <td>85.5</td>\n",
       "      <td>0.0</td>\n",
       "      <td>69.1</td>\n",
       "      <td>61.8</td>\n",
       "      <td>48.9</td>\n",
       "      <td>66.3</td>\n",
       "      <td>73.7</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>medmnistA</th>\n",
       "      <td>97.1</td>\n",
       "      <td>65.5</td>\n",
       "      <td>83.6</td>\n",
       "      <td>0.9</td>\n",
       "      <td>67.2</td>\n",
       "      <td>59.8</td>\n",
       "      <td>49.3</td>\n",
       "      <td>63.9</td>\n",
       "      <td>69.5</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>medmnistC</th>\n",
       "      <td>81.6</td>\n",
       "      <td>55.2</td>\n",
       "      <td>71.9</td>\n",
       "      <td>2.2</td>\n",
       "      <td>59.8</td>\n",
       "      <td>34.1</td>\n",
       "      <td>41.8</td>\n",
       "      <td>54.6</td>\n",
       "      <td>53.2</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>medmnistS</th>\n",
       "      <td>85.1</td>\n",
       "      <td>43.2</td>\n",
       "      <td>64.6</td>\n",
       "      <td>0.1</td>\n",
       "      <td>50.2</td>\n",
       "      <td>29.5</td>\n",
       "      <td>42.7</td>\n",
       "      <td>39.7</td>\n",
       "      <td>43.7</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>mnist</th>\n",
       "      <td>99.7</td>\n",
       "      <td>98.0</td>\n",
       "      <td>98.4</td>\n",
       "      <td>0.0</td>\n",
       "      <td>98.1</td>\n",
       "      <td>97.7</td>\n",
       "      <td>49.8</td>\n",
       "      <td>98.0</td>\n",
       "      <td>98.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>svhn</th>\n",
       "      <td>95.9</td>\n",
       "      <td>36.0</td>\n",
       "      <td>84.5</td>\n",
       "      <td>0.9</td>\n",
       "      <td>32.0</td>\n",
       "      <td>29.2</td>\n",
       "      <td>48.6</td>\n",
       "      <td>37.2</td>\n",
       "      <td>59.8</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "group           ind                             ood                         \n",
       "metric    local_acc global_acc switch_acc local_acc global_acc switch_acc   \n",
       "dataset                                                                     \n",
       "cifar10        90.0       30.8       67.6       0.1       29.5       26.9  \\\n",
       "fmnist         97.7       63.6       85.5       0.0       69.1       61.8   \n",
       "medmnistA      97.1       65.5       83.6       0.9       67.2       59.8   \n",
       "medmnistC      81.6       55.2       71.9       2.2       59.8       34.1   \n",
       "medmnistS      85.1       43.2       64.6       0.1       50.2       29.5   \n",
       "mnist          99.7       98.0       98.4       0.0       98.1       97.7   \n",
       "svhn           95.9       36.0       84.5       0.9       32.0       29.2   \n",
       "\n",
       "group           mix                        \n",
       "metric    local_acc global_acc switch_acc  \n",
       "dataset                                    \n",
       "cifar10        45.0       30.1       47.2  \n",
       "fmnist         48.9       66.3       73.7  \n",
       "medmnistA      49.3       63.9       69.5  \n",
       "medmnistC      41.8       54.6       53.2  \n",
       "medmnistS      42.7       39.7       43.7  \n",
       "mnist          49.8       98.0       98.0  \n",
       "svhn           48.6       37.2       59.8  "
      ]
     },
     "execution_count": 25,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "(grouped_df * 100).round(1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "id": "75285eb3-80ec-42ba-ae0c-62c12b2c5b68",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\\begin{tabular}{lrrrrrrrrr}\n",
      "\\toprule\n",
      "group & \\multicolumn{3}{r}{ind} & \\multicolumn{3}{r}{ood} & \\multicolumn{3}{r}{mix} \\\\\n",
      "metric & local_acc & global_acc & switch_acc & local_acc & global_acc & switch_acc & local_acc & global_acc & switch_acc \\\\\n",
      "dataset &  &  &  &  &  &  &  &  &  \\\\\n",
      "\\midrule\n",
      "cifar10 & 90.0 & 30.8 & 67.6 & 0.1 & 29.5 & 26.9 & 45.0 & 30.1 & 47.2 \\\\\n",
      "fmnist & 97.7 & 63.6 & 85.5 & 0.0 & 69.1 & 61.8 & 48.9 & 66.3 & 73.7 \\\\\n",
      "medmnistA & 97.1 & 65.5 & 83.6 & 0.9 & 67.2 & 59.8 & 49.3 & 63.9 & 69.5 \\\\\n",
      "medmnistC & 81.6 & 55.2 & 71.9 & 2.2 & 59.8 & 34.1 & 41.8 & 54.6 & 53.2 \\\\\n",
      "medmnistS & 85.1 & 43.2 & 64.6 & 0.1 & 50.2 & 29.5 & 42.7 & 39.7 & 43.7 \\\\\n",
      "mnist & 99.7 & 98.0 & 98.4 & 0.0 & 98.1 & 97.7 & 49.8 & 98.0 & 98.0 \\\\\n",
      "svhn & 95.9 & 36.0 & 84.5 & 0.9 & 32.0 & 29.2 & 48.6 & 37.2 & 59.8 \\\\\n",
      "\\bottomrule\n",
      "\\end{tabular}\n",
      "\n"
     ]
    }
   ],
   "source": [
    "print((grouped_df * 100).round(1).to_latex(float_format=\"%.1f\"))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fb1c1623-1a59-4f63-a399-0b5daae2e81f",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
