{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import copy\n",
    "import io\n",
    "import math\n",
    "import pickle\n",
    "import random\n",
    "import time\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import seaborn as sns\n",
    "from matplotlib.lines import Line2D\n",
    "from matplotlib.ticker import MaxNLocator\n",
    "\n",
    "import torch\n",
    "import torch.mps as mps\n",
    "import torch.nn as nn\n",
    "import torch.optim as optim\n",
    "import torchvision\n",
    "from torch import cuda\n",
    "from torch.utils.data import Dataset, DataLoader, TensorDataset\n",
    "\n",
    "\n",
    "# Prefer the GPU when PyTorch can see one; otherwise fall back to the CPU.\n",
    "if torch.cuda.is_available():\n",
    "    device = 'cuda'\n",
    "else:\n",
    "    device = 'cpu'\n",
    "\n",
    "print(f\"Using device: {device}\")\n",
    "\n",
    "# Make the project-local model zoo importable; presumably it provides VGG,\n",
    "# resnet18/34/50, densenet121 and LeNet5 used by build_model below.\n",
    "import sys\n",
    "sys.path.append('models_scratch/')\n",
    "from models_scratch import *\n",
    "\n",
    "# Seed every RNG the notebook may touch so Restart & Run All is reproducible\n",
    "# (torch.manual_seed also seeds the CUDA generators).\n",
    "SEED = 42\n",
    "torch.manual_seed(SEED)\n",
    "np.random.seed(SEED)\n",
    "random.seed(SEED)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# NOTE: pickle.load can execute arbitrary code; only load files you trust.\n",
    "# The context manager closes the file even if unpickling fails.\n",
    "with open('EuroSATData_bin.pk', 'rb') as infile:\n",
    "    SavedData = pickle.load(infile)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Unpack the pre-split tensors: inputs X, group attribute S and labels y.\n",
    "X_train = SavedData[\"X_train\"]\n",
    "S_train = SavedData[\"S_train\"]\n",
    "y_train = SavedData[\"Y_train\"]\n",
    "X_test = SavedData[\"X_test\"]\n",
    "S_test = SavedData[\"S_test\"]\n",
    "y_test = SavedData[\"Y_test\"]\n",
    "\n",
    "# Share (%) of *all* training images in group S == 0 (about 3%, the \"bluish\" images).\n",
    "# Dividing by the total, not by the S == 1 count, gives a true percentage.\n",
    "pct_s0 = 100 * (S_train == 0).sum().item() / len(S_train)\n",
    "print(f\"S=0 share of training set: {pct_s0:.2f}%\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class ModelWithPredict(nn.Module):\n",
    "    \"\"\"Wrap a base network so that calling it returns sigmoid probabilities.\n",
    "\n",
    "    Defined at module level rather than inside build_model, so every built model\n",
    "    shares one class and can be pickled (e.g. with torch.save(model)).\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, base_model):\n",
    "        super().__init__()\n",
    "        self.base_model = base_model\n",
    "\n",
    "    def forward(self, x):\n",
    "        output = self.base_model(x)\n",
    "        return torch.sigmoid(output)\n",
    "\n",
    "    def predict(self, X):\n",
    "        \"\"\"Hard 0./1. predictions: probabilities thresholded at 0.5.\"\"\"\n",
    "        prediction_probabilities = self(X)\n",
    "        return 1. * (prediction_probabilities > 0.5)\n",
    "\n",
    "    def predict_probabilities(self, X):\n",
    "        \"\"\"Sigmoid probabilities (identical to calling the module).\"\"\"\n",
    "        return self(X)\n",
    "\n",
    "\n",
    "def build_model(network, num_classes, input_channels, input_height, input_width, device, batch_norm=False):\n",
    "    \"\"\"Instantiate ``network`` from the models_scratch zoo and wrap it in ModelWithPredict.\n",
    "\n",
    "    Args:\n",
    "        network: \"vgg11\", \"vgg16\", \"resnet18\", \"resnet34\", \"resnet50\",\n",
    "            \"densenet121\" or \"lenet\".\n",
    "        num_classes: number of outputs of the base network (1 for the binary\n",
    "            task in this notebook).\n",
    "        input_channels, input_height, input_width: input geometry; only LeNet5\n",
    "            receives them.\n",
    "        device: device the model is moved to.\n",
    "        batch_norm: forwarded to VGG as ``batch_norm``; for ResNet/DenseNet it\n",
    "            selects ``norm_layer=nn.BatchNorm2d`` instead of ``norm_layer=None``.\n",
    "\n",
    "    Returns:\n",
    "        ModelWithPredict wrapping the network, already moved to ``device``.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if ``network`` is not one of the names above.\n",
    "    \"\"\"\n",
    "    norm_layer = nn.BatchNorm2d if batch_norm else None\n",
    "\n",
    "    if network == \"vgg11\":\n",
    "        # VGG takes the flag directly rather than a norm_layer; forward it.\n",
    "        net = VGG(\"VGG11\", num_classes=num_classes, batch_norm=batch_norm)\n",
    "    elif network == \"vgg16\":\n",
    "        net = VGG(\"VGG16\", num_classes=num_classes, batch_norm=batch_norm)\n",
    "    elif network == \"resnet18\":\n",
    "        net = resnet18(norm_layer=norm_layer, num_classes=num_classes)\n",
    "    elif network == \"resnet34\":\n",
    "        net = resnet34(norm_layer=norm_layer, num_classes=num_classes)\n",
    "    elif network == \"resnet50\":\n",
    "        net = resnet50(norm_layer=norm_layer, num_classes=num_classes)\n",
    "    elif network == \"densenet121\":\n",
    "        net = densenet121(norm_layer=norm_layer, num_classes=num_classes)\n",
    "        # NOTE(review): replaces the head although num_classes was already passed;\n",
    "        # presumably redundant, kept so the initialisation stays unchanged.\n",
    "        net.head = nn.Linear(net.head.in_features, num_classes)\n",
    "    elif network == \"lenet\":\n",
    "        net = LeNet5(num_classes=num_classes, input_channels=input_channels,\n",
    "                     input_height=input_height, input_width=input_width)\n",
    "    else:\n",
    "        raise ValueError(\"Invalid network name.\")\n",
    "\n",
    "    net = net.to(device)\n",
    "\n",
    "    num_params = sum(p.numel() for p in net.parameters())\n",
    "    print(f\"Total number of parameters in {network}: {num_params:,}\")\n",
    "\n",
    "    return ModelWithPredict(net)\n",
    "\n",
    "# Build the default architecture once to inspect its parameter count;\n",
    "# train_model constructs its own fresh model, so this instance is not trained.\n",
    "input_channels, input_height, input_width = (X_train.shape[axis] for axis in (1, 2, 3))\n",
    "net = build_model(\"resnet18\", 1, input_channels, input_height, input_width, device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class CustomDataset(Dataset):\n",
    "    \"\"\"Map-style dataset yielding ``(X[idx], S[idx], y[idx])`` triples.\n",
    "\n",
    "    Note: train_model below builds TensorDatasets and does not use this class.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, X, S, y):\n",
    "        self.X, self.S, self.y = X, S, y\n",
    "\n",
    "    def __len__(self):\n",
    "        # Length comes from X; S and y are assumed to be aligned with it.\n",
    "        return len(self.X)\n",
    "\n",
    "    def __getitem__(self, idx):\n",
    "        sample, group, target = self.X[idx], self.S[idx], self.y[idx]\n",
    "        return sample, group, target"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def _run_epoch(model, loader, criterion, device, optimizer=None):\n",
    "    \"\"\"Make one pass over ``loader`` and accumulate loss/accuracy sums per group.\n",
    "\n",
    "    When ``optimizer`` is given, each batch also takes an optimisation step\n",
    "    (training pass). Without it only forward passes are made; the caller is\n",
    "    responsible for ``model.eval()`` and ``torch.no_grad()`` in that case.\n",
    "\n",
    "    Returns:\n",
    "        dict mapping \"all\", \"s0\" (S == 0) and \"s1\" (S == 1) to a dict holding the\n",
    "        summed loss (\"loss\"), the number of samples (\"count\") and the number of\n",
    "        correct predictions at threshold 0.5 (\"correct\").\n",
    "    \"\"\"\n",
    "    stats = {group: {\"loss\": 0.0, \"count\": 0, \"correct\": 0} for group in (\"all\", \"s0\", \"s1\")}\n",
    "    for X_batch, S_batch, y_batch in loader:\n",
    "        X_batch, S_batch, y_batch = X_batch.to(device), S_batch.to(device), y_batch.to(device)\n",
    "        if optimizer is not None:\n",
    "            optimizer.zero_grad()\n",
    "        # train_model builds a single-output model, so targets are reshaped to (batch, 1).\n",
    "        y_batch = y_batch.float().view(-1, 1)\n",
    "        outputs = model(X_batch)\n",
    "        loss = criterion(outputs, y_batch)\n",
    "        if optimizer is not None:\n",
    "            loss.backward()\n",
    "            optimizer.step()\n",
    "\n",
    "        preds = (outputs > 0.5).float()\n",
    "        stats[\"all\"][\"loss\"] += loss.item()\n",
    "        stats[\"all\"][\"count\"] += y_batch.size(0)\n",
    "        stats[\"all\"][\"correct\"] += preds.eq(y_batch).sum().item()\n",
    "\n",
    "        for group, value in ((\"s0\", 0), (\"s1\", 1)):\n",
    "            mask = (S_batch == value)\n",
    "            if mask.any():\n",
    "                # Losses are accumulated as sums, which matches\n",
    "                # criterion(reduction='sum') as configured in train_model.\n",
    "                stats[group][\"loss\"] += criterion(outputs[mask], y_batch[mask]).item()\n",
    "                stats[group][\"count\"] += mask.sum().item()\n",
    "                stats[group][\"correct\"] += preds[mask].eq(y_batch[mask]).sum().item()\n",
    "    return stats\n",
    "\n",
    "\n",
    "def _summarize(stats):\n",
    "    \"\"\"Turn ``_run_epoch`` sums into ``{group: (mean loss, accuracy %)}``.\n",
    "\n",
    "    A group with no samples yields ``(nan, 0)`` instead of dividing by zero.\n",
    "    \"\"\"\n",
    "    summary = {}\n",
    "    for group, s in stats.items():\n",
    "        if s[\"count\"] > 0:\n",
    "            summary[group] = (s[\"loss\"] / s[\"count\"], (s[\"correct\"] / s[\"count\"]) * 100)\n",
    "        else:\n",
    "            summary[group] = (float('nan'), 0)\n",
    "    return summary\n",
    "\n",
    "\n",
    "def _record_running_best(history, best, split, summary):\n",
    "    \"\"\"Update the running bests for ``split`` and append them to ``history``.\n",
    "\n",
    "    Losses keep the running minimum and accuracies the running maximum. A NaN\n",
    "    loss (empty group) never replaces the current best because ``nan < x`` is False.\n",
    "    \"\"\"\n",
    "    for group, (avg_loss, acc) in summary.items():\n",
    "        for kind, value, pick in ((\"loss\", avg_loss, min), (\"acc\", acc, max)):\n",
    "            name = f\"{split}_{kind}_{group}\"\n",
    "            best[name] = pick(best[name], value)\n",
    "            history[name].append(best[name])\n",
    "\n",
    "\n",
    "def train_model(network, X_train, S_train, y_train, X_test, S_test, y_test, \n",
    "                epochs=10, batch_size=64, learning_rate=0.05, device='cpu', batch_norm=False, kappa=90, tau=90):\n",
    "    \"\"\"Train ``network`` (SGD, summed MSE on sigmoid outputs) and track per-group metrics.\n",
    "\n",
    "    Groups are defined by S == 0 and S == 1. Training stops once the best train\n",
    "    accuracy on S == 0 exceeds ``kappa`` (fairness threshold), or after ``epochs``\n",
    "    epochs. The first epoch at which the best global train accuracy exceeds\n",
    "    ``tau`` is also recorded; the gap between the two is the debiasing duration.\n",
    "\n",
    "    The returned curves are running bests up to each epoch (min for losses, max\n",
    "    for accuracies), not raw per-epoch values. The log, printed at epoch 1 and\n",
    "    every 50 epochs, shows the raw values.\n",
    "\n",
    "    Returns:\n",
    "        dict with numpy arrays of length n (completed epochs): \"times\" (cumulative\n",
    "        seconds), \"epoch\" (1..n) and \"{train,test}_{loss,acc}_{s0,s1,all}\"; plus the\n",
    "        scalars \"early_stopping_epoch\" (None if tau was never exceeded),\n",
    "        \"final_epoch\", \"fairness_reached\" and \"debiasing_duration\" (None if tau\n",
    "        was never exceeded).\n",
    "    \"\"\"\n",
    "    trainloader = DataLoader(TensorDataset(X_train, S_train, y_train), batch_size=batch_size, shuffle=True)\n",
    "    testloader = DataLoader(TensorDataset(X_test, S_test, y_test), batch_size=100, shuffle=False)\n",
    "\n",
    "    input_channels, input_height, input_width = X_train.shape[1], X_train.shape[2], X_train.shape[3]\n",
    "    model = build_model(network, num_classes=1, input_channels=input_channels,\n",
    "                        input_height=input_height, input_width=input_width, device=device, batch_norm=batch_norm)\n",
    "\n",
    "    optimizer = optim.SGD(model.parameters(), lr=learning_rate, momentum=0)\n",
    "    criterion = nn.MSELoss(reduction='sum')\n",
    "    # NOTE(review): scheduler.step() is never called, so the learning rate stays\n",
    "    # constant at learning_rate. Step it once per epoch if annealing is intended.\n",
    "    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs)\n",
    "\n",
    "    metric_names = [f\"{split}_{kind}_{group}\" for split in (\"train\", \"test\")\n",
    "                    for kind in (\"loss\", \"acc\") for group in (\"s0\", \"s1\", \"all\")]\n",
    "    history = {name: [] for name in metric_names}\n",
    "    # Losses start at +inf (tracked with min), accuracies at 0 (tracked with max).\n",
    "    best = {name: (float('inf') if \"_loss_\" in name else 0) for name in metric_names}\n",
    "    times, cumulative_time = [], 0\n",
    "    early_stopping_epoch = None\n",
    "    final_epoch = None\n",
    "    nb_epochs = 0\n",
    "\n",
    "    for epoch in range(epochs):\n",
    "        start_time = time.time()\n",
    "\n",
    "        model.train()\n",
    "        train_summary = _summarize(_run_epoch(model, trainloader, criterion, device, optimizer))\n",
    "        _record_running_best(history, best, \"train\", train_summary)\n",
    "\n",
    "        if early_stopping_epoch is None and best[\"train_acc_all\"] > tau:\n",
    "            early_stopping_epoch = epoch + 1\n",
    "\n",
    "        model.eval()\n",
    "        with torch.no_grad():\n",
    "            test_summary = _summarize(_run_epoch(model, testloader, criterion, device))\n",
    "        _record_running_best(history, best, \"test\", test_summary)\n",
    "\n",
    "        cumulative_time += time.time() - start_time\n",
    "        times.append(cumulative_time)\n",
    "\n",
    "        if (epoch + 1) % 50 == 0 or epoch == 0:\n",
    "            print(f\"Epoch [{epoch+1}/{epochs}] | Time: {cumulative_time:.2f}s\")\n",
    "            for label, summary in ((\"Train\", train_summary), (\"Test \", test_summary)):\n",
    "                (loss_s0, acc_s0), (loss_s1, acc_s1), (loss_all, acc_all) = (\n",
    "                    summary[\"s0\"], summary[\"s1\"], summary[\"all\"])\n",
    "                print(f\"  {label} -> Loss: S0={loss_s0:.4f}, S1={loss_s1:.4f}, Global={loss_all:.4f} | \"\n",
    "                      f\"Acc: S0={acc_s0:.2f}%, S1={acc_s1:.2f}%, Global={acc_all:.2f}%\")\n",
    "        nb_epochs += 1\n",
    "\n",
    "        if best[\"train_acc_s0\"] > kappa:\n",
    "            final_epoch = epoch + 1\n",
    "            break\n",
    "\n",
    "    fairness_reached = final_epoch is not None\n",
    "    if final_epoch is None:\n",
    "        final_epoch = epochs\n",
    "\n",
    "    debiasing_duration = (final_epoch - early_stopping_epoch) if early_stopping_epoch is not None else None\n",
    "\n",
    "    print(f\"\\nTraining finished in {final_epoch} epochs.\")\n",
    "    if early_stopping_epoch:\n",
    "        print(f\"→ Early stopping threshold (Acc > τ={tau}%) reached at epoch {early_stopping_epoch}.\")\n",
    "    if fairness_reached:\n",
    "        print(f\"→ Fairness threshold (Acc_S=0 > κ={kappa}%) reached at epoch {final_epoch}.\")\n",
    "    else:\n",
    "        print(f\"→ Fairness threshold (Acc_S=0 > κ={kappa}%) not reached within {epochs} epochs.\")\n",
    "    if early_stopping_epoch:\n",
    "        caveat = \"\" if fairness_reached else \" (lower bound: κ not reached)\"\n",
    "        print(f\"→ Debiasing duration: {debiasing_duration} epochs{caveat}.\")\n",
    "\n",
    "    results = {\"times\": np.array(times), \"epoch\": np.arange(1, nb_epochs + 1)}\n",
    "    results.update({name: np.array(values) for name, values in history.items()})\n",
    "    results.update({\n",
    "        \"early_stopping_epoch\": early_stopping_epoch,\n",
    "        \"final_epoch\": final_epoch,\n",
    "        \"fairness_reached\": fairness_reached,\n",
    "        \"debiasing_duration\": debiasing_duration,\n",
    "    })\n",
    "    return results\n",
    "\n",
    "def run_multiple_experiments(network, X_train, S_train, y_train, X_test, S_test, y_test,\n",
    "                             epochs=10, batch_size=64, learning_rate=0.05, device='cpu', num_runs=5, batch_norm=False,\n",
    "                             kappa=90, tau=90):\n",
    "    \"\"\"Call train_model ``num_runs`` times and stack the curves into one long DataFrame.\n",
    "\n",
    "    ``kappa`` and ``tau`` are forwarded to train_model so the stopping thresholds\n",
    "    can be changed per experiment.\n",
    "\n",
    "    Returns:\n",
    "        DataFrame with one row per (run, epoch): \"num_run\" (1-based), \"epoch\",\n",
    "        \"time\" and the twelve \"{train,test}_{loss,acc}_{all,s0,s1}\" metrics. Runs\n",
    "        can differ in length because of early stopping. ``num_runs=0`` gives an\n",
    "        empty frame with these columns.\n",
    "    \"\"\"\n",
    "    metric_columns = [f\"{split}_{kind}_{group}\" for kind in (\"loss\", \"acc\")\n",
    "                      for split in (\"train\", \"test\") for group in (\"all\", \"s0\", \"s1\")]\n",
    "    columns = [\"num_run\", \"epoch\", \"time\"] + metric_columns\n",
    "\n",
    "    run_frames = []\n",
    "    for run in range(num_runs):\n",
    "        print(f\"\\n--- Run {run+1}/{num_runs} ---\")\n",
    "        result = train_model(network, X_train, S_train, y_train, X_test, S_test, y_test,\n",
    "                             epochs=epochs, batch_size=batch_size, learning_rate=learning_rate, device=device,\n",
    "                             batch_norm=batch_norm, kappa=kappa, tau=tau)\n",
    "        # One frame per run; the scalar run id is broadcast over the epoch axis.\n",
    "        run_frames.append(pd.DataFrame({\n",
    "            \"num_run\": run + 1,\n",
    "            \"epoch\": result[\"epoch\"],\n",
    "            \"time\": result[\"times\"],\n",
    "            **{name: result[name] for name in metric_columns},\n",
    "        }))\n",
    "\n",
    "    if not run_frames:\n",
    "        return pd.DataFrame(columns=columns)\n",
    "\n",
    "    df = pd.concat(run_frames, ignore_index=True)[columns]\n",
    "    df[\"epoch\"] = df[\"epoch\"].astype(int)\n",
    "    return df"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Three ResNet-18 runs of up to 1000 epochs each; a run stops early once the\n",
    "# kappa threshold is exceeded. Uses the device detected in the setup cell\n",
    "# instead of assuming CUDA is present.\n",
    "df_results = run_multiple_experiments(\"resnet18\", X_train, S_train, y_train, X_test, S_test, y_test,\n",
    "                                      epochs=1000, batch_size=512, learning_rate=1e-4, device=device,\n",
    "                                      num_runs=3, batch_norm=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def plot_xp(df, network, kappa=90):\n",
    "    \"\"\"Plot mean ± sd over runs of train loss, train accuracy and test accuracy per group.\n",
    "\n",
    "    Args:\n",
    "        df: long DataFrame from run_multiple_experiments (one row per run and epoch).\n",
    "        network: architecture name, used in the output file ``eurosat_{network}.png``.\n",
    "        kappa: fairness threshold, drawn as a dashed red line on the train-accuracy panel.\n",
    "    \"\"\"\n",
    "    group_labels = {\"all\": \"Global\", \"s0\": \"A=0\", \"s1\": \"A=1\"}\n",
    "    palette = {\"A=0\": \"blue\", \"A=1\": \"orange\", \"Global\": \"green\"}\n",
    "    # (metric column prefix, y-axis label, log-scale y?, draw kappa line?)\n",
    "    panels = [\n",
    "        (\"train_loss\", \"Train loss\", True, False),\n",
    "        (\"train_acc\", \"Train accuracy (%)\", False, True),\n",
    "        (\"test_acc\", \"Test accuracy (%)\", False, False),\n",
    "    ]\n",
    "\n",
    "    fig, axs = plt.subplots(1, 3, figsize=(14, 4), sharex=True)\n",
    "\n",
    "    for ax, (metric, ylabel, log_scale, show_kappa) in zip(axs, panels):\n",
    "        column_to_label = {f\"{metric}_{suffix}\": label for suffix, label in group_labels.items()}\n",
    "        long_df = df.melt(id_vars=[\"num_run\", \"epoch\", \"time\"],\n",
    "                          value_vars=list(column_to_label),\n",
    "                          var_name=\"group\", value_name=metric)\n",
    "        long_df[\"group\"] = long_df[\"group\"].map(column_to_label)\n",
    "        # legend=False: one shared figure legend is drawn below instead.\n",
    "        sns.lineplot(ax=ax, data=long_df, x=\"epoch\", y=metric, hue=\"group\",\n",
    "                     estimator='mean', errorbar='sd', palette=palette, legend=False)\n",
    "        if log_scale:\n",
    "            ax.set_yscale(\"log\")\n",
    "        if show_kappa:\n",
    "            ax.axhline(kappa, color='red', linestyle='--', linewidth=1)\n",
    "        ax.set_xlabel(\"Epoch\", fontsize=14)\n",
    "        ax.set_ylabel(ylabel, fontsize=14)\n",
    "        ax.grid(True)\n",
    "        ax.xaxis.set_major_locator(MaxNLocator(integer=True))\n",
    "\n",
    "    handles = [Line2D([0], [0], color=color, lw=2, label=label) for label, color in palette.items()]\n",
    "    handles.append(Line2D([0], [0], color='red', lw=1, linestyle='--', label=r\"$\\kappa$ threshold\"))\n",
    "    fig.legend(handles, [h.get_label() for h in handles], loc='lower center',\n",
    "               ncol=4, bbox_to_anchor=(0.5, -0.15), fontsize=14)\n",
    "\n",
    "    fig.tight_layout()\n",
    "    # Explicit extension rather than relying on rcParams[\"savefig.format\"] (png by default).\n",
    "    fig.savefig(f\"eurosat_{network}.png\", bbox_inches=\"tight\")\n",
    "    plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "plot_xp(df_results, network=\"resnet18\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
