{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1e5d559d",
   "metadata": {},
   "outputs": [],
   "source": [
    "import pickle\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import seaborn as sns\n",
    "import random\n",
    "import math\n",
    "import copy\n",
    "import io\n",
    "import time\n",
    "import matplotlib.pyplot as plt\n",
    "import os\n",
    "\n",
    "\n",
    "import torch\n",
    "import torchvision\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "from torch import cuda\n",
    "import torch.optim as optim\n",
    "from torch.utils.data import Dataset, DataLoader, TensorDataset\n",
    "import torch.mps as mps\n",
    "\n",
    "from sklearn.datasets import fetch_openml\n",
    "from sklearn.model_selection import train_test_split\n",
    "from sklearn.preprocessing import StandardScaler\n",
    "\n",
    "\n",
    "if torch.cuda.is_available():\n",
    "    device = 'cuda'\n",
    "else:\n",
    "    device = 'cpu'\n",
    "\n",
    "print(f\"Using device: {device}\")\n",
    "\n",
    "import sys\n",
    "sys.path.append('models_scratch/')\n",
    "sys.path.append('data/')\n",
    "from models_scratch import *\n",
    "\n",
    "%matplotlib inline\n",
    "sns.set(style=\"whitegrid\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a2a450ce",
   "metadata": {},
   "outputs": [],
   "source": [
    "adult = fetch_openml('adult', version=2, as_frame=True, parser='auto')\n",
    "df = adult.frame\n",
    "df"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0480516b",
   "metadata": {},
   "outputs": [],
   "source": [
    "df['class'] = df['class'].str.strip().str.lower()\n",
    "df['class'] = (df['class'] == '>50k').astype(int)\n",
    "\n",
    "df['sex'] = df['sex'].astype('category')\n",
    "df['sex'] = df['sex'].str.strip().str.lower()\n",
    "df['sex_race'] = df['sex'].str.strip() + \"_\" + df['race'].str.strip()\n",
    "\n",
    "categorical_cols = df.select_dtypes(include=['category', 'object']).columns.tolist()\n",
    "continuous_cols = df.select_dtypes(include=['int64', 'float64']).columns.tolist()\n",
    "continuous_cols.remove('class')\n",
    "\n",
    "for col in categorical_cols:\n",
    "    df[col] = df[col].astype('category').cat.codes\n",
    "    \n",
    "df_train, df_test = train_test_split(df, test_size=0.3, random_state=0)\n",
    "print(len(df_train.columns))\n",
    "print(continuous_cols)\n",
    "print(categorical_cols)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bd80a918",
   "metadata": {},
   "outputs": [],
   "source": [
    "def filter_adult(df, minority_fraction=1.0, random_state=42):\n",
    "    df = df.copy()\n",
    "\n",
    "    condition_femmes_riches = (df['sex'] == 0) & (df['class'] == 1)  \n",
    "\n",
    "    df['S'] = 1  \n",
    "    df.loc[condition_femmes_riches, 'S'] = 0  \n",
    "\n",
    "    minority_indices = df[df['S'] == 0].index  \n",
    "    majority_indices = df[df['S'] == 1].index \n",
    "\n",
    "    n_minority = int(minority_fraction * len(majority_indices)) \n",
    "    selected_minority_indices = minority_indices.to_series().sample(\n",
    "        n=min(n_minority, len(minority_indices)), random_state=random_state\n",
    "    ).index\n",
    "\n",
    "    final_indices = pd.Index(selected_minority_indices.tolist() + majority_indices.tolist())\n",
    "    final_indices = final_indices.to_series().sample(frac=1, random_state=random_state).index  \n",
    "\n",
    "    df_filtered = df.loc[final_indices]\n",
    "\n",
    "    X_filtered = df_filtered.drop(columns=['class', 'S'])\n",
    "    y_filtered = torch.tensor(df_filtered['class'].values, dtype=torch.long)\n",
    "    S_filtered = torch.tensor(df_filtered['S'].values, dtype=torch.long)\n",
    "\n",
    "    return X_filtered, y_filtered, S_filtered\n",
    "\n",
    "X_train_filtered, y_train, S_train = filter_adult(df_train)\n",
    "X_test_filtered, y_test, S_test = filter_adult(df_test)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1dac2c13",
   "metadata": {},
   "outputs": [],
   "source": [
    "scaler = StandardScaler()\n",
    "X_train_filtered[continuous_cols] = scaler.fit_transform(X_train_filtered[continuous_cols])\n",
    "X_test_filtered[continuous_cols] = scaler.transform(X_test_filtered[continuous_cols])\n",
    "\n",
    "X_categ_train = torch.tensor(X_train_filtered[categorical_cols].values, dtype=torch.long)\n",
    "X_cont_train  = torch.tensor(X_train_filtered[continuous_cols].values, dtype=torch.float32)\n",
    "X_train = (X_categ_train, X_cont_train)\n",
    "\n",
    "X_categ_test = torch.tensor(X_test_filtered[categorical_cols].values, dtype=torch.long)\n",
    "X_cont_test  = torch.tensor(X_test_filtered[continuous_cols].values, dtype=torch.float32)\n",
    "X_test = (X_categ_test, X_cont_test)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6974c5af",
   "metadata": {},
   "outputs": [],
   "source": [
    "def build_model(network, cat_sizes, num_continuous, device, num_classes=2):\n",
    "    if network == \"ft_transformer\":\n",
    "        net = FTTransformer(categories=cat_sizes, num_continuous=num_continuous, num_classes=num_classes)\n",
    "        \n",
    "    elif network == \"mlp\":\n",
    "        net = MLP(cat_dims=cat_sizes, num_continuous=num_continuous, num_classes=num_classes)\n",
    "        \n",
    "    elif network == \"tabnet_v1\":\n",
    "        net = TabNet_v1(cat_dims=cat_sizes, num_continuous=num_continuous, num_classes=num_classes)\n",
    "        \n",
    "    elif network == \"tabnet_v2\":\n",
    "        net = TabNet_v2(cat_dims=cat_sizes, num_continuous=num_continuous, num_classes=num_classes, n_steps=3, hidden_dim=32, gamma=1.5, dropout=0.1)\n",
    "        \n",
    "    elif network == \"saint\":\n",
    "        net = SAINT(cat_dims=cat_sizes, num_continuous=num_continuous, num_classes=num_classes)\n",
    "        \n",
    "    elif network == \"resmlp\":\n",
    "        net = ResMLP(cat_dims=cat_sizes, num_continuous=num_continuous, num_classes=num_classes)\n",
    "        \n",
    "    else:\n",
    "        raise ValueError(f\"Invalid network name: {network}\")\n",
    "\n",
    "    net.to(device)\n",
    "\n",
    "    num_params = sum(p.numel() for p in net.parameters())\n",
    "    print(f\"Total number of parameters in {network}: {num_params:,}\")\n",
    "\n",
    "    return net"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bad58f70",
   "metadata": {},
   "outputs": [],
   "source": [
    "def train_model(network, X_train, S_train, y_train,\n",
    "                X_test, S_test, y_test,\n",
    "                epochs=10, batch_size=64, learning_rate=0.05,\n",
    "                device='cpu', kappa=90, tau=90):\n",
    "\n",
    "    X_categ_train, X_cont_train = X_train\n",
    "    X_categ_test, X_cont_test = X_test\n",
    "\n",
    "    trainset = TensorDataset(X_categ_train, X_cont_train, S_train, y_train)\n",
    "    testset = TensorDataset(X_categ_test, X_cont_test, S_test, y_test)\n",
    "    trainloader = DataLoader(trainset, batch_size=batch_size, shuffle=True)\n",
    "    testloader = DataLoader(testset, batch_size=1024, shuffle=False)\n",
    "\n",
    "    cat_sizes = [int(X_categ_train[:, i].max().item() + 1) for i in range(X_categ_train.size(1))]\n",
    "    num_continuous = X_cont_train.size(1)\n",
    "\n",
    "    model = build_model(network, cat_sizes, num_continuous, device, num_classes=2)\n",
    "\n",
    "    optimizer = optim.SGD(model.parameters(), lr=learning_rate, momentum=0, weight_decay=1e-4) \n",
    "    criterion = nn.CrossEntropyLoss()\n",
    "    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs)\n",
    "\n",
    "    train_loss_s0, train_loss_s1, train_loss_all = [], [], []\n",
    "    test_loss_s0, test_loss_s1, test_loss_all = [], [], []\n",
    "    train_acc_s0, train_acc_s1, train_acc_all = [], [], []\n",
    "    test_acc_s0, test_acc_s1, test_acc_all = [], [], []\n",
    "    times, cumulative_time = [], 0\n",
    "\n",
    "    early_stopping_epoch = None\n",
    "    final_epoch = None\n",
    "\n",
    "    best_train_loss_s0 = float('inf')\n",
    "    best_train_loss_s1 = float('inf')\n",
    "    best_train_loss = float('inf')\n",
    "    best_test_loss_s0 = float('inf')\n",
    "    best_test_loss_s1 = float('inf')\n",
    "    best_test_loss = float('inf')\n",
    "    best_train_acc_s0 = 0\n",
    "    best_train_acc_s1 = 0\n",
    "    best_train_acc = 0\n",
    "    best_test_acc_s0 = 0\n",
    "    best_test_acc_s1 = 0\n",
    "    best_test_acc = 0\n",
    "\n",
    "    nb_epochs = 0\n",
    "\n",
    "    for epoch in range(epochs):\n",
    "        start_time = time.time()\n",
    "        model.train()\n",
    "\n",
    "        total_loss, total_samples = 0.0, 0\n",
    "        loss_s0_sum, count_s0 = 0.0, 0\n",
    "        loss_s1_sum, count_s1 = 0.0, 0\n",
    "        correct_total, correct_s0, correct_s1 = 0, 0, 0\n",
    "\n",
    "        for X_cat, X_cont, S_batch, y_batch in trainloader:\n",
    "            X_cat, X_cont, S_batch, y_batch = X_cat.to(device), X_cont.to(device), S_batch.to(device), y_batch.to(device)\n",
    "            optimizer.zero_grad()\n",
    "            outputs = model(X_cat, X_cont)\n",
    "            loss = criterion(outputs, y_batch)\n",
    "            loss.backward()\n",
    "            optimizer.step()\n",
    "\n",
    "            bsize = y_batch.size(0)\n",
    "            total_loss += loss.item() * bsize\n",
    "            total_samples += bsize\n",
    "\n",
    "            _, preds = outputs.max(1)\n",
    "            correct_total += preds.eq(y_batch).sum().item()\n",
    "\n",
    "            mask_s0 = (S_batch == 0)\n",
    "            if mask_s0.any():\n",
    "                n_s0 = mask_s0.sum().item()\n",
    "                loss_s0 = criterion(outputs[mask_s0], y_batch[mask_s0]).item()\n",
    "                loss_s0_sum += loss_s0 * n_s0\n",
    "                count_s0 += n_s0\n",
    "                correct_s0 += preds[mask_s0].eq(y_batch[mask_s0]).sum().item()\n",
    "\n",
    "            mask_s1 = (S_batch == 1)\n",
    "            if mask_s1.any():\n",
    "                n_s1 = mask_s1.sum().item()\n",
    "                loss_s1 = criterion(outputs[mask_s1], y_batch[mask_s1]).item()\n",
    "                loss_s1_sum += loss_s1 * n_s1\n",
    "                count_s1 += n_s1\n",
    "                correct_s1 += preds[mask_s1].eq(y_batch[mask_s1]).sum().item()\n",
    "\n",
    "        avg_loss = total_loss / total_samples\n",
    "        avg_loss_s0 = loss_s0_sum / count_s0\n",
    "        avg_loss_s1 = loss_s1_sum / count_s1\n",
    "        acc_total = (correct_total / total_samples) * 100\n",
    "        acc_s0 = (correct_s0 / count_s0) * 100\n",
    "        acc_s1 = (correct_s1 / count_s1) * 100\n",
    "\n",
    "        best_train_loss_s0 = min(best_train_loss_s0, avg_loss_s0)\n",
    "        best_train_loss_s1 = min(best_train_loss_s1, avg_loss_s1)\n",
    "        best_train_loss = min(best_train_loss, avg_loss)\n",
    "        best_train_acc_s0 = max(best_train_acc_s0, acc_s0)\n",
    "        best_train_acc_s1 = max(best_train_acc_s1, acc_s1)\n",
    "        best_train_acc = max(best_train_acc, acc_total)\n",
    "\n",
    "        train_loss_s0.append(best_train_loss_s0)\n",
    "        train_loss_s1.append(best_train_loss_s1)\n",
    "        train_loss_all.append(best_train_loss)\n",
    "        train_acc_s0.append(best_train_acc_s0)\n",
    "        train_acc_s1.append(best_train_acc_s1)\n",
    "        train_acc_all.append(best_train_acc)\n",
    "\n",
    "        if early_stopping_epoch is None and best_train_acc > tau:\n",
    "            early_stopping_epoch = epoch + 1\n",
    "\n",
    "        model.eval()\n",
    "        test_loss_sum, total_test = 0.0, 0\n",
    "        loss_s0_test_sum, count_s0_test = 0.0, 0\n",
    "        loss_s1_test_sum, count_s1_test = 0.0, 0\n",
    "        correct_test_total, correct_s0_test, correct_s1_test = 0, 0, 0\n",
    "\n",
    "        with torch.no_grad():\n",
    "            for X_cat, X_cont, S_batch, y_batch in testloader:\n",
    "                X_cat, X_cont, S_batch, y_batch = X_cat.to(device), X_cont.to(device), S_batch.to(device), y_batch.to(device)\n",
    "                outputs = model(X_cat, X_cont)\n",
    "                loss = criterion(outputs, y_batch)\n",
    "                bsize = y_batch.size(0)\n",
    "                test_loss_sum += loss.item() * bsize\n",
    "                total_test += bsize\n",
    "\n",
    "                _, preds = outputs.max(1)\n",
    "                correct_test_total += preds.eq(y_batch).sum().item()\n",
    "\n",
    "                mask_s0 = (S_batch == 0)\n",
    "                if mask_s0.any():\n",
    "                    n_s0 = mask_s0.sum().item()\n",
    "                    loss_s0 = criterion(outputs[mask_s0], y_batch[mask_s0]).item()\n",
    "                    loss_s0_test_sum += loss_s0 * n_s0\n",
    "                    count_s0_test += n_s0\n",
    "                    correct_s0_test += preds[mask_s0].eq(y_batch[mask_s0]).sum().item()\n",
    "\n",
    "                mask_s1 = (S_batch == 1)\n",
    "                if mask_s1.any():\n",
    "                    n_s1 = mask_s1.sum().item()\n",
    "                    loss_s1 = criterion(outputs[mask_s1], y_batch[mask_s1]).item()\n",
    "                    loss_s1_test_sum += loss_s1 * n_s1\n",
    "                    count_s1_test += n_s1\n",
    "                    correct_s1_test += preds[mask_s1].eq(y_batch[mask_s1]).sum().item()\n",
    "\n",
    "        avg_loss_test = test_loss_sum / total_test\n",
    "        avg_loss_s0_test = loss_s0_test_sum / count_s0_test\n",
    "        avg_loss_s1_test = loss_s1_test_sum / count_s1_test\n",
    "        acc_test = (correct_test_total / total_test) * 100\n",
    "        acc_s0_test = (correct_s0_test / count_s0_test) * 100 if count_s0_test > 0 else 0\n",
    "        acc_s1_test = (correct_s1_test / count_s1_test) * 100 if count_s1_test > 0 else 0\n",
    "\n",
    "        best_test_loss_s0 = min(best_test_loss_s0, avg_loss_s0_test)\n",
    "        best_test_loss_s1 = min(best_test_loss_s1, avg_loss_s1_test)\n",
    "        best_test_loss = min(best_test_loss, avg_loss_test)\n",
    "        best_test_acc_s0 = max(best_test_acc_s0, acc_s0_test)\n",
    "        best_test_acc_s1 = max(best_test_acc_s1, acc_s1_test)\n",
    "        best_test_acc = max(best_test_acc, acc_test)\n",
    "\n",
    "        test_loss_s0.append(best_test_loss_s0)\n",
    "        test_loss_s1.append(best_test_loss_s1)\n",
    "        test_loss_all.append(best_test_loss)\n",
    "        test_acc_s0.append(best_test_acc_s0)\n",
    "        test_acc_s1.append(best_test_acc_s1)\n",
    "        test_acc_all.append(best_test_acc)\n",
    "\n",
    "        cumulative_time += time.time() - start_time\n",
    "        times.append(cumulative_time)\n",
    "\n",
    "        if (epoch + 1) % 10 == 0 or epoch == 0:\n",
    "            print(f\"Epoch [{epoch+1}/{epochs}] | Time: {cumulative_time:.2f}s\")\n",
    "            print(f\"  Train -> Loss: S0={avg_loss_s0:.4f}, S1={avg_loss_s1:.4f}, Global={avg_loss:.4f} | \"\n",
    "                  f\"Acc: S0={acc_s0:.2f}%, S1={acc_s1:.2f}%, Global={acc_total:.2f}%\")\n",
    "            print(f\"  Test  -> Loss: S0={avg_loss_s0_test:.4f}, S1={avg_loss_s1_test:.4f}, Global={avg_loss_test:.4f} | \"\n",
    "                  f\"Acc: S0={acc_s0_test:.2f}%, S1={acc_s1_test:.2f}%, Global={acc_test:.2f}%\")\n",
    "        scheduler.step()\n",
    "        nb_epochs += 1\n",
    "        if best_train_acc_s0 > kappa:\n",
    "            final_epoch = epoch + 1\n",
    "            break\n",
    "\n",
    "    if final_epoch is None:\n",
    "        final_epoch = epochs\n",
    "\n",
    "    debiasing_duration = (final_epoch - early_stopping_epoch) if early_stopping_epoch is not None else None\n",
    "\n",
    "    print(f\"\\nTraining finished in {final_epoch} epochs.\")\n",
    "    if early_stopping_epoch:\n",
    "        print(f\"→ Early stopping threshold (Acc > τ={tau}%) reached at epoch {early_stopping_epoch}.\")\n",
    "    print(f\"→ Fairness threshold (Acc_S=0 > κ={kappa}%) reached at epoch {final_epoch}.\")\n",
    "    if early_stopping_epoch:\n",
    "        print(f\"→ Debiasing duration: {debiasing_duration} epochs.\")\n",
    "\n",
    "    return {\n",
    "         \"times\": np.array(times),\n",
    "         \"epoch\": np.arange(1, nb_epochs + 1),\n",
    "         \"train_loss_s0\": np.array(train_loss_s0),\n",
    "         \"train_loss_s1\": np.array(train_loss_s1),\n",
    "         \"train_loss_all\": np.array(train_loss_all),\n",
    "         \"test_loss_s0\": np.array(test_loss_s0),\n",
    "         \"test_loss_s1\": np.array(test_loss_s1),\n",
    "         \"test_loss_all\": np.array(test_loss_all),\n",
    "         \"train_acc_s0\": np.array(train_acc_s0),\n",
    "         \"train_acc_s1\": np.array(train_acc_s1),\n",
    "         \"train_acc_all\": np.array(train_acc_all),\n",
    "         \"test_acc_s0\": np.array(test_acc_s0),\n",
    "         \"test_acc_s1\": np.array(test_acc_s1),\n",
    "         \"test_acc_all\": np.array(test_acc_all)\n",
    "    }, model\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f85aba7c",
   "metadata": {},
   "outputs": [],
   "source": [
    "def run_multiple_experiments(network, X_train, S_train, y_train,\n",
    "                X_test, S_test, y_test,\n",
    "                epochs=10, batch_size=64, learning_rate=0.05, device='cpu', num_runs=5):\n",
    "    records = []\n",
    "    \n",
    "    for run in range(num_runs):\n",
    "        print(f\"\\n--- Run {run+1}/{num_runs} ---\")\n",
    "        result, model = train_model(network, X_train, S_train, y_train,\n",
    "                X_test, S_test, y_test,\n",
    "                epochs=epochs, batch_size=batch_size, learning_rate=learning_rate,\n",
    "                device=device)\n",
    "        \n",
    "        for i, epoch in enumerate(result[\"epoch\"]):\n",
    "            record = {\n",
    "                \"num_run\": run+1,\n",
    "                \"epoch\": result[\"epoch\"][i],\n",
    "                \"time\": result[\"times\"][i],\n",
    "                \"train_loss_all\": result[\"train_loss_all\"][i],\n",
    "                \"train_loss_s0\": result[\"train_loss_s0\"][i],\n",
    "                \"train_loss_s1\": result[\"train_loss_s1\"][i],\n",
    "                \"test_loss_all\": result[\"test_loss_all\"][i],\n",
    "                \"test_loss_s0\": result[\"test_loss_s0\"][i],\n",
    "                \"test_loss_s1\": result[\"test_loss_s1\"][i],\n",
    "                \"train_acc_all\": result[\"train_acc_all\"][i],\n",
    "                \"train_acc_s0\": result[\"train_acc_s0\"][i],\n",
    "                \"train_acc_s1\": result[\"train_acc_s1\"][i],\n",
    "                \"test_acc_all\": result[\"test_acc_all\"][i],\n",
    "                \"test_acc_s0\": result[\"test_acc_s0\"][i],\n",
    "                \"test_acc_s1\": result[\"test_acc_s1\"][i],\n",
    "            }\n",
    "            records.append(record)\n",
    "    \n",
    "    df = pd.DataFrame(records)\n",
    "    df[\"epoch\"] = df[\"epoch\"].astype(int) \n",
    "    \n",
    "    return df, model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8c2faf2f",
   "metadata": {},
   "outputs": [],
   "source": [
    "df_results, model = run_multiple_experiments(\"tabnet_v1\", X_train, S_train, y_train, X_test, S_test, y_test,\n",
    "                             epochs=1000, batch_size=1024, learning_rate=2e-2, device='cuda', num_runs=2) #5e-2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7fb67617",
   "metadata": {},
   "outputs": [],
   "source": [
    "def plot_xp(df, network, kappa=90):\n",
    "    import matplotlib.pyplot as plt\n",
    "    import seaborn as sns\n",
    "    import pandas as pd\n",
    "    from matplotlib.ticker import MaxNLocator\n",
    "    from matplotlib.lines import Line2D\n",
    "\n",
    "    df_loss_train = df.melt(id_vars=[\"num_run\", \"epoch\", \"time\"], \n",
    "                            value_vars=[\"train_loss_all\", \"train_loss_s0\", \"train_loss_s1\"],\n",
    "                            var_name=\"group\", value_name=\"train_loss\")\n",
    "    df_acc_train = df.melt(id_vars=[\"num_run\", \"epoch\", \"time\"], \n",
    "                           value_vars=[\"train_acc_all\", \"train_acc_s0\", \"train_acc_s1\"],\n",
    "                           var_name=\"group\", value_name=\"train_acc\")\n",
    "    df_acc_test = df.melt(id_vars=[\"num_run\", \"epoch\", \"time\"], \n",
    "                          value_vars=[\"test_acc_all\", \"test_acc_s0\", \"test_acc_s1\"],\n",
    "                          var_name=\"group\", value_name=\"test_acc\")\n",
    "\n",
    "    mapping = {\n",
    "        \"train_loss_all\": \"Global\", \"train_loss_s0\": \"A=0\", \"train_loss_s1\": \"A=1\",\n",
    "        \"train_acc_all\": \"Global\", \"train_acc_s0\": \"A=0\", \"train_acc_s1\": \"A=1\",\n",
    "        \"test_acc_all\": \"Global\", \"test_acc_s0\": \"A=0\", \"test_acc_s1\": \"A=1\"\n",
    "    }\n",
    "\n",
    "    for df_, name in [(df_loss_train, \"group\"), (df_acc_train, \"group\"), (df_acc_test, \"group\")]:\n",
    "        df_[name] = df_[name].map(mapping)\n",
    "\n",
    "    palette = {\"A=0\": \"blue\", \"A=1\": \"orange\", \"Global\": \"green\"}\n",
    "\n",
    "    fig, axs = plt.subplots(1, 3, figsize=(14, 4), sharex=True)\n",
    "\n",
    "    sns.lineplot(ax=axs[0], data=df_loss_train, x=\"epoch\", y=\"train_loss\", hue=\"group\",\n",
    "                 estimator='mean', errorbar='sd', palette=palette)\n",
    "    axs[0].set_yscale(\"log\")\n",
    "    axs[0].set_xlabel(\"Epoch\", fontsize=14)\n",
    "    axs[0].set_ylabel(\"Train loss\", fontsize=14)\n",
    "    axs[0].grid(True)\n",
    "    axs[0].get_legend().remove()\n",
    "\n",
    "    sns.lineplot(ax=axs[1], data=df_acc_train, x=\"epoch\", y=\"train_acc\", hue=\"group\",\n",
    "                 estimator='mean', errorbar='sd', palette=palette)\n",
    "    axs[1].set_xlabel(\"Epoch\", fontsize=14)\n",
    "    axs[1].set_ylabel(\"Train accuracy (%)\", fontsize=14)\n",
    "    axs[1].axhline(kappa, color='red', linestyle='--', linewidth=1)\n",
    "    axs[1].grid(True)\n",
    "    axs[1].get_legend().remove()\n",
    "\n",
    "    sns.lineplot(ax=axs[2], data=df_acc_test, x=\"epoch\", y=\"test_acc\", hue=\"group\",\n",
    "                 estimator='mean', errorbar='sd', palette=palette)\n",
    "    axs[2].set_xlabel(\"Epoch\", fontsize=14)\n",
    "    axs[2].set_ylabel(\"Test accuracy (%)\", fontsize=14)\n",
    "    axs[2].grid(True)\n",
    "    axs[2].get_legend().remove()\n",
    "\n",
    "    handles = [Line2D([0], [0], color=color, lw=2, label=label) for label, color in palette.items()]\n",
    "    kappa_line = Line2D([0], [0], color='red', lw=1, linestyle='--', label=r\"$\\kappa$ threshold\")\n",
    "    handles.append(kappa_line)\n",
    "\n",
    "    fig.legend(handles, [h.get_label() for h in handles], loc='lower center',\n",
    "               ncol=4, bbox_to_anchor=(0.5, -0.15), fontsize=14)\n",
    "\n",
    "    for ax in axs:\n",
    "        ax.xaxis.set_major_locator(MaxNLocator(integer=True))\n",
    "\n",
    "    plt.tight_layout()\n",
    "    plt.savefig(f\"results/ADULT/adult_{network}_metrics_with_kappa.pdf\", bbox_inches=\"tight\")\n",
    "    plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "13facf4f",
   "metadata": {},
   "outputs": [],
   "source": [
    "plot_xp(df_results,\"tabnet_v1\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
