{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 239,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import numpy as np\n",
    "from error_parity import RelaxedThresholdOptimizer\n",
    "from error_parity.classifiers import RandomizedClassifier\n",
    "\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "\n",
    "from source.constants import RESULTS_PATH, PLOTS_PATH\n",
    "from source.data.face_detection import get_fair_face, get_utk\n",
    "from source.utils.metrics import accuracy, aod, eod, spd\n",
    "\n",
    "os.makedirs(PLOTS_PATH, exist_ok=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 240,
   "metadata": {},
   "outputs": [],
   "source": [
    "# method seeds and data seed as they appear in the result directory names\n",
    "dseed = 42\n",
    "method_seeds = [42, 142, 242, 342, 442]\n",
    "\n",
    "# backbone whose stored results are analysed\n",
    "model = (\"resnet18\", \"resnet34\", \"resnet50\")[2]\n",
    "\n",
    "# set to True for per-seed printouts in the cells below\n",
    "verbose = False\n",
    "\n",
    "# notes from earlier runs:\n",
    "# - predicting race does not give high unfairness (with either pa) for eod and aod\n",
    "# - predicting gender is not ideal either (unfairness only w.r.t. age)\n",
    "targets = [\"age\", \"gender\", \"race(old)\", \"race\"]\n",
    "# `target` / `pa`: row of the datasets' `targets` used as label / protected attribute\n",
    "target = 3  # one of 0, 1, 2, 3\n",
    "pa = 1  # one of 0, 1, 2, 3"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 241,
   "metadata": {},
   "outputs": [],
   "source": [
    "# index of the fairness constraint enforced by the post-processing;\n",
    "# the same index later selects the matching metric (spd / eod / aod)\n",
    "c = 2\n",
    "constraint = (\"demographic_parity\", \"true_positive_rate_parity\", \"average_odds\")[c]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 242,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "10843 10843 10954 23705\n"
     ]
    }
   ],
   "source": [
    "# the datasets are loaded without choosing a label here; the target and the\n",
    "# protected attribute are picked out by row index further down\n",
    "ff_train_ds, ff_test_ds = get_fair_face(binarize=True, augment=False)\n",
    "utk_test_ds = get_utk(binarize=True)\n",
    "\n",
    "# the index splits are read from the run directory of the first method seed\n",
    "run_path = os.path.join(\n",
    "    RESULTS_PATH,\n",
    "    f\"fairface_target{target}_{model}_mseed{method_seeds[0]}_dseed{dseed}\",\n",
    ")\n",
    "fair_inds = torch.load(os.path.join(run_path, \"fair_inds.pt\"))\n",
    "val_inds = torch.load(os.path.join(run_path, \"val_inds.pt\"))\n",
    "\n",
    "print(len(fair_inds), len(val_inds), len(ff_test_ds), len(utk_test_ds))\n",
    "\n",
    "# labels (y_*) and protected attributes (a_*) of every split\n",
    "train_targets = ff_train_ds.targets\n",
    "y_fair_t, a_fair_t = train_targets[target, fair_inds], train_targets[pa, fair_inds]\n",
    "y_val_t, a_val_t = train_targets[target, val_inds], train_targets[pa, val_inds]\n",
    "y_ff_test_t, a_ff_test_t = ff_test_ds.targets[target], ff_test_ds.targets[pa]\n",
    "y_utk_test_t, a_utk_test_t = utk_test_ds.targets[target], utk_test_ds.targets[pa]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 243,
   "metadata": {},
   "outputs": [],
   "source": [
    "# stored member scores (\"probits\") of every method seed, one list entry per seed\n",
    "fair_probits, val_probits, ff_test_probits, utk_test_probits = [], [], [], []\n",
    "for mseed in method_seeds:\n",
    "    path = os.path.join(RESULTS_PATH, f\"fairface_target{target}_{model}_mseed{mseed}_dseed{dseed}\")\n",
    "\n",
    "    fair_file = os.path.join(path, f\"fair_probits_t{target}.pt\")\n",
    "    fair_probits.append(torch.load(fair_file))\n",
    "\n",
    "    val_file = os.path.join(path, f\"val_probits_t{target}.pt\")\n",
    "    val_probits.append(torch.load(val_file))\n",
    "\n",
    "    ff_test_file = os.path.join(path, f\"ff_test_probits_t{target}.pt\")\n",
    "    ff_test_probits.append(torch.load(ff_test_file))\n",
    "\n",
    "    utk_test_file = os.path.join(path, f\"utk_test_probits_t{target}.pt\")\n",
    "    utk_test_probits.append(torch.load(utk_test_file))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 244,
   "metadata": {},
   "outputs": [],
   "source": [
    "# per-member accuracy and fairness (spd / eod / aod); one inner list per method seed\n",
    "val_accs, ff_test_accs, utk_test_accs = [], [], []\n",
    "val_spds, ff_test_spds, utk_test_spds = [], [], []\n",
    "val_eods, ff_test_eods, utk_test_eods = [], [], []\n",
    "val_aods, ff_test_aods, utk_test_aods = [], [], []\n",
    "\n",
    "for seed_idx in range(len(method_seeds)):\n",
    "    # hard predictions of every member, computed once per split\n",
    "    val_preds = [p.argmax(dim=1) for p in val_probits[seed_idx]]\n",
    "    ff_preds = [p.argmax(dim=1) for p in ff_test_probits[seed_idx]]\n",
    "    utk_preds = [p.argmax(dim=1) for p in utk_test_probits[seed_idx]]\n",
    "\n",
    "    # accuracy\n",
    "    val_accs.append([accuracy(pred, y_val_t) for pred in val_preds])\n",
    "    ff_test_accs.append([accuracy(pred, y_ff_test_t) for pred in ff_preds])\n",
    "    utk_test_accs.append([accuracy(pred, y_utk_test_t) for pred in utk_preds])\n",
    "\n",
    "    # spd only needs the protected attribute\n",
    "    val_spds.append([spd(pred, a_val_t) for pred in val_preds])\n",
    "    ff_test_spds.append([spd(pred, a_ff_test_t) for pred in ff_preds])\n",
    "    utk_test_spds.append([spd(pred, a_utk_test_t) for pred in utk_preds])\n",
    "\n",
    "    # eod\n",
    "    val_eods.append([eod(pred, y_val_t, a_val_t) for pred in val_preds])\n",
    "    ff_test_eods.append([eod(pred, y_ff_test_t, a_ff_test_t) for pred in ff_preds])\n",
    "    utk_test_eods.append([eod(pred, y_utk_test_t, a_utk_test_t) for pred in utk_preds])\n",
    "\n",
    "    # aod\n",
    "    val_aods.append([aod(pred, y_val_t, a_val_t) for pred in val_preds])\n",
    "    ff_test_aods.append([aod(pred, y_ff_test_t, a_ff_test_t) for pred in ff_preds])\n",
    "    utk_test_aods.append([aod(pred, y_utk_test_t, a_utk_test_t) for pred in utk_preds])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 245,
   "metadata": {},
   "outputs": [],
   "source": [
    "# method to do the fake predictions\n",
    "class DummyPredictor(nn.Module):\n",
    "    \"\"\"Stand-in predictor that looks up stored scores instead of running a network.\n",
    "\n",
    "    The inputs passed to the predictor are sample indices; the output is the\n",
    "    matching rows of `probits` as a NumPy array.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, probits):\n",
    "        \"\"\"Store the score tensor that `forward` indexes into.\n",
    "\n",
    "        Args:\n",
    "            probits: tensor indexed by sample along its first dimension.\n",
    "        \"\"\"\n",
    "        super().__init__()\n",
    "        # plain attribute: the cells below swap in the scores of another split\n",
    "        # by assigning to `.probits`\n",
    "        self.probits = probits\n",
    "\n",
    "    def forward(self, indices: torch.Tensor):\n",
    "        \"\"\"Return the stored scores at `indices` as a NumPy array.\n",
    "\n",
    "        Args:\n",
    "            indices: index tensor selecting rows of `self.probits`.\n",
    "        \"\"\"\n",
    "        # detach and move to host memory first: `.numpy()` raises on tensors that\n",
    "        # require grad or do not live on the CPU\n",
    "        return self.probits[indices].detach().cpu().numpy()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 246,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_thresholds(fair_clf, verbose=True):\n",
    "    \"\"\"Collect the decision thresholds of a fitted threshold optimizer for groups 0 and 1.\n",
    "\n",
    "    Args:\n",
    "        fair_clf: fitted optimizer exposing `_realized_classifier.group_to_clf`.\n",
    "        verbose: if True, print the group index and every threshold.\n",
    "\n",
    "    Returns:\n",
    "        One list per group: the thresholds of all sub-classifiers for a\n",
    "        `RandomizedClassifier`, otherwise the single threshold repeated twice.\n",
    "    \"\"\"\n",
    "    per_group = []\n",
    "    for i in (0, 1):\n",
    "        if verbose:\n",
    "            print(f\"Class {i}\")\n",
    "        group_clf = fair_clf._realized_classifier.group_to_clf[i]\n",
    "        if not isinstance(group_clf, RandomizedClassifier):\n",
    "            # deterministic classifier: duplicate its threshold so every group yields a list\n",
    "            single = group_clf.threshold\n",
    "            if verbose:\n",
    "                print(single)\n",
    "            per_group.append([single, single])\n",
    "            continue\n",
    "        member_thresholds = []\n",
    "        for member in group_clf.classifiers:\n",
    "            if verbose:\n",
    "                print(member.threshold)\n",
    "            member_thresholds.append(member.threshold)\n",
    "        per_group.append(member_thresholds)\n",
    "    return per_group"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## FairFace"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Post-process the ensemble, using the average fairness violation of its members (on the validation split) as the constraint tolerance."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 247,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "------------------------------\n",
      "$0.873_{\\pm 0.002}$ & $0.013_{\\pm 0.006}$\n",
      "$0.888_{\\pm 0.001}$ & $0.016_{\\pm 0.002}$\n",
      "$0.888_{\\pm 0.001}$ & $0.016_{\\pm 0.005}$\n",
      "------------------------------\n",
      "Group 0\n",
      "$0.439_{\\pm 0.049}$\n",
      "$0.408_{\\pm 0.044}$\n",
      "Group 1\n",
      "$0.345_{\\pm 0.033}$\n",
      "$0.337_{\\pm 0.030}$\n"
     ]
    }
   ],
   "source": [
    "# For every method seed: post-process the ensemble (BMA) with a threshold optimizer whose\n",
    "# tolerance is the mean validation unfairness of the ensemble's members, then compare\n",
    "# members, plain BMA and post-processed BMA (BMA-PP) on the FairFace test split.\n",
    "accs_bma, fairs_bma = list(), list()\n",
    "accs_bma_pp, fairs_bma_pp = list(), list()\n",
    "accs_avg, fairs_avg = list(), list()\n",
    "thresholds_bma_pp = list()\n",
    "\n",
    "# fairness metric matching `constraint` (same ordering as the constraint options)\n",
    "fairness_metric = [lambda preds, y, a: spd(preds, a), eod, aod][c]\n",
    "\n",
    "# sample indices act as the \"features\": the predictor looks scores up by index\n",
    "val_idx = torch.arange(len(y_val_t))\n",
    "test_idx = torch.arange(len(y_ff_test_t))\n",
    "\n",
    "for m in range(len(method_seeds)):\n",
    "\n",
    "    if verbose: print(\"-\"*20 + f\"  seed {m}  \" + \"-\"*20)\n",
    "\n",
    "    val_fairness = [val_spds[m], val_eods[m], val_aods[m]][c]\n",
    "    test_fairness = [ff_test_spds[m], ff_test_eods[m], ff_test_aods[m]][c]\n",
    "\n",
    "    # BMA scores: mean over the ensemble members\n",
    "    val_m_probits = torch.mean(val_probits[m], dim=0)\n",
    "    ff_test_m_probits = torch.mean(ff_test_probits[m], dim=0)\n",
    "\n",
    "    # deliberately not called `model`: that name holds the backbone string used to\n",
    "    # build the result paths and must survive re-running the loading cells\n",
    "    score_lookup = DummyPredictor(val_m_probits)\n",
    "\n",
    "    # Given any trained model that outputs real-valued scores\n",
    "    fair_clf = RelaxedThresholdOptimizer(\n",
    "        predictor=lambda X: score_lookup(X)[:, -1],   # for sklearn API\n",
    "        constraint=constraint,\n",
    "        tolerance=max(np.mean(val_fairness), 0), # fairness constraint tolerance\n",
    "    )\n",
    "\n",
    "    # Fit the fairness adjustment on the validation split\n",
    "    fair_clf.fit(X=val_idx, y=y_val_t.numpy(), group=a_val_t.numpy())\n",
    "    thresholds_bma_pp.append(get_thresholds(fair_clf, verbose=verbose))\n",
    "\n",
    "    # the lambda resolves `score_lookup` at call time, so swapping the stored\n",
    "    # scores switches the fitted classifier over to the test split\n",
    "    score_lookup.probits = ff_test_m_probits\n",
    "    y_pred_test = fair_clf(X=test_idx, group=a_ff_test_t.numpy())\n",
    "    y_pred_test = torch.tensor(y_pred_test, dtype=torch.long)\n",
    "\n",
    "    bma_pred = ff_test_m_probits.argmax(dim=1)\n",
    "\n",
    "    accs_avg.extend(ff_test_accs[m])\n",
    "    fairs_avg.extend(test_fairness)\n",
    "    accs_bma.append(accuracy(bma_pred, y_ff_test_t).item())\n",
    "    fairs_bma.append(fairness_metric(bma_pred, y_ff_test_t, a_ff_test_t).item())\n",
    "    accs_bma_pp.append(accuracy(y_pred_test, y_ff_test_t).item())\n",
    "    fairs_bma_pp.append(fairness_metric(y_pred_test, y_ff_test_t, a_ff_test_t).item())\n",
    "\n",
    "    if verbose:\n",
    "        print(\"Avg Member\")\n",
    "        # first member of the current seed (previously indexed [0][m], i.e. member m of seed 0)\n",
    "        print(f\"  {(ff_test_accs[m][0]):.3f}\")\n",
    "        print(f\"  {test_fairness[0]:.3f} (val: {val_fairness[0]:.3f})\")\n",
    "        print(\"BMA\")\n",
    "        print(f\"  {accs_bma[-1]:.3f}\")\n",
    "        print(f\"  {fairs_bma[-1]:.3f}\")\n",
    "        print(\"BMA-PP\")\n",
    "        print(f\"  {accs_bma_pp[-1]:.3f}\")\n",
    "        print(f\"  {fairs_bma_pp[-1]:.3f}\")\n",
    "\n",
    "thresholds_bma_pp = np.asarray(thresholds_bma_pp)\n",
    "\n",
    "# LaTeX table rows: mean_{\\pm std} of accuracy & fairness for members, BMA, BMA-PP\n",
    "print(\"-\"*30)\n",
    "for accs, fairs in [(accs_avg, fairs_avg), (accs_bma, fairs_bma), (accs_bma_pp, fairs_bma_pp)]:\n",
    "    print(f\"${np.mean(accs):.3f}_{{\\\\pm {np.std(accs):.3f}}}$\", end=\" & \")\n",
    "    print(f\"${np.mean(fairs):.3f}_{{\\\\pm {np.std(fairs):.3f}}}$\")\n",
    "print(\"-\"*30)\n",
    "for i in range(2):\n",
    "    print(f\"Group {i}\")\n",
    "    for k in range(2):\n",
    "        print(f\"${np.mean(thresholds_bma_pp[:, i, k]):.3f}_{{\\\\pm {np.std(thresholds_bma_pp[:, i, k]):.3f}}}$\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Post-process both the ensemble and each single member model with a fixed constraint tolerance of 0.05."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 248,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "------------------------------\n",
      "$0.888_{\\pm 0.002}$ & $0.019_{\\pm 0.004}$\n",
      "$0.873_{\\pm 0.004}$ & $0.029_{\\pm 0.027}$\n",
      "------------------------------\n",
      "Group 0\n",
      "$0.439_{\\pm 0.049}$\n",
      "$0.408_{\\pm 0.044}$\n",
      "Group 1\n",
      "$0.345_{\\pm 0.033}$\n",
      "$0.337_{\\pm 0.030}$\n",
      "------------------------------\n",
      "Group 0\n",
      "$0.646_{\\pm 0.162}$\n",
      "$0.577_{\\pm 0.165}$\n",
      "Group 1\n",
      "$0.628_{\\pm 0.156}$\n",
      "$0.566_{\\pm 0.174}$\n"
     ]
    }
   ],
   "source": [
    "# Post-process with a fixed tolerance of 0.05: once the ensemble (BMA) of each method\n",
    "# seed and once every individual ensemble member, evaluated on the FairFace test split.\n",
    "accs_bma_pp, fairs_bma_pp = list(), list()\n",
    "accs_member_pp, fairs_member_pp = list(), list()\n",
    "thresholds_bma_pp = list()\n",
    "thresholds_member_pp = list()\n",
    "\n",
    "# fairness metric matching `constraint` (same ordering as the constraint options)\n",
    "fairness_metric = [lambda preds, y, a: spd(preds, a), eod, aod][c]\n",
    "\n",
    "# sample indices act as the \"features\": the predictor looks scores up by index\n",
    "val_idx = torch.arange(len(y_val_t))\n",
    "test_idx = torch.arange(len(y_ff_test_t))\n",
    "\n",
    "for m in range(len(method_seeds)):\n",
    "\n",
    "    if verbose: print(\"-\"*20 + f\"  seed {m}  \" + \"-\"*20)\n",
    "\n",
    "    # ---- ensemble (BMA): scores are the mean over the members ----\n",
    "    # deliberately not called `model`: that name holds the backbone string used to\n",
    "    # build the result paths and must survive re-running the loading cells\n",
    "    score_lookup = DummyPredictor(torch.mean(val_probits[m], dim=0))\n",
    "\n",
    "    fair_clf = RelaxedThresholdOptimizer(\n",
    "        predictor=lambda X: score_lookup(X)[:, -1],   # for sklearn API\n",
    "        constraint=constraint,\n",
    "        tolerance=0.05, # fairness constraint tolerance\n",
    "    )\n",
    "    fair_clf.fit(X=val_idx, y=y_val_t.numpy(), group=a_val_t.numpy())\n",
    "    thresholds_bma_pp.append(get_thresholds(fair_clf, verbose=verbose))\n",
    "\n",
    "    # the lambda resolves `score_lookup` at call time, so swapping the stored\n",
    "    # scores switches the fitted classifier over to the test split\n",
    "    score_lookup.probits = torch.mean(ff_test_probits[m], dim=0)\n",
    "    y_pred_test = fair_clf(X=test_idx, group=a_ff_test_t.numpy())\n",
    "    y_pred_test = torch.tensor(y_pred_test, dtype=torch.long)\n",
    "\n",
    "    accs_bma_pp.append(accuracy(y_pred_test, y_ff_test_t).item())\n",
    "    fairs_bma_pp.append(fairness_metric(y_pred_test, y_ff_test_t, a_ff_test_t).item())\n",
    "    if verbose:\n",
    "        print(\"BMA-PP\")\n",
    "        print(f\"  {accs_bma_pp[-1]:.3f}\")\n",
    "        print(f\"  {fairs_bma_pp[-1]:.3f}\")\n",
    "\n",
    "    # ---- individual members ----\n",
    "    for mem in range(len(val_probits[m])):\n",
    "        score_lookup = DummyPredictor(val_probits[m][mem])\n",
    "\n",
    "        fair_clf = RelaxedThresholdOptimizer(\n",
    "            predictor=lambda X: score_lookup(X)[:, -1],   # for sklearn API\n",
    "            constraint=constraint,\n",
    "            tolerance=0.05, # fairness constraint tolerance\n",
    "        )\n",
    "        fair_clf.fit(X=val_idx, y=y_val_t.numpy(), group=a_val_t.numpy())\n",
    "        thresholds_member_pp.append(get_thresholds(fair_clf, verbose=verbose))\n",
    "\n",
    "        # evaluate on the test scores of the SAME member the thresholds were fitted on\n",
    "        # (previously always member 0, which mixed thresholds and scores of different models)\n",
    "        score_lookup.probits = ff_test_probits[m][mem]\n",
    "        y_pred_test = fair_clf(X=test_idx, group=a_ff_test_t.numpy())\n",
    "        y_pred_test = torch.tensor(y_pred_test, dtype=torch.long)\n",
    "\n",
    "        accs_member_pp.append(accuracy(y_pred_test, y_ff_test_t).item())\n",
    "        fairs_member_pp.append(fairness_metric(y_pred_test, y_ff_test_t, a_ff_test_t).item())\n",
    "        # only the first member is printed to keep the verbose output short\n",
    "        if mem == 0 and verbose:\n",
    "            print(\"Member-PP\")\n",
    "            print(f\"  {accs_member_pp[-1]:.3f}\")\n",
    "            print(f\"  {fairs_member_pp[-1]:.3f}\")\n",
    "\n",
    "thresholds_bma_pp = np.asarray(thresholds_bma_pp)\n",
    "thresholds_member_pp = np.asarray(thresholds_member_pp).reshape((-1, 2, 2))\n",
    "\n",
    "# LaTeX table rows: mean_{\\pm std} of accuracy & fairness for BMA-PP and Member-PP\n",
    "print(\"-\"*30)\n",
    "for accs, fairs in [(accs_bma_pp, fairs_bma_pp), (accs_member_pp, fairs_member_pp)]:\n",
    "    print(f\"${np.mean(accs):.3f}_{{\\\\pm {np.std(accs):.3f}}}$\", end=\" & \")\n",
    "    print(f\"${np.mean(fairs):.3f}_{{\\\\pm {np.std(fairs):.3f}}}$\")\n",
    "# thresholds per group: first the ensemble, then the individual members\n",
    "for thresholds in (thresholds_bma_pp, thresholds_member_pp):\n",
    "    print(\"-\"*30)\n",
    "    for i in range(2):\n",
    "        print(f\"Group {i}\")\n",
    "        for k in range(2):\n",
    "            print(f\"${np.mean(thresholds[:, i, k]):.3f}_{{\\\\pm {np.std(thresholds[:, i, k]):.3f}}}$\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## UTKFaces"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Post-process the ensemble, using the average fairness violation of its members (on the validation split) as the constraint tolerance."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 249,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "------------------------------\n",
      "$0.822_{\\pm 0.006}$ & $0.015_{\\pm 0.010}$\n",
      "$0.843_{\\pm 0.002}$ & $0.013_{\\pm 0.002}$\n",
      "$0.859_{\\pm 0.003}$ & $0.016_{\\pm 0.008}$\n"
     ]
    }
   ],
   "source": [
    "# For every method seed: post-process the ensemble (BMA) with a threshold optimizer whose\n",
    "# tolerance is the mean validation unfairness of the ensemble's members, then compare\n",
    "# members, plain BMA and post-processed BMA (BMA-PP) on the UTK test set.\n",
    "accs_bma, fairs_bma = list(), list()\n",
    "accs_bma_pp, fairs_bma_pp = list(), list()\n",
    "accs_avg, fairs_avg = list(), list()\n",
    "\n",
    "# fairness metric matching `constraint` (same ordering as the constraint options)\n",
    "fairness_metric = [lambda preds, y, a: spd(preds, a), eod, aod][c]\n",
    "\n",
    "# sample indices act as the \"features\": the predictor looks scores up by index\n",
    "val_idx = torch.arange(len(y_val_t))\n",
    "test_idx = torch.arange(len(y_utk_test_t))\n",
    "\n",
    "for m in range(len(method_seeds)):\n",
    "\n",
    "    if verbose: print(\"-\"*20 + f\"  seed {m}  \" + \"-\"*20)\n",
    "\n",
    "    val_fairness = [val_spds[m], val_eods[m], val_aods[m]][c]\n",
    "    test_fairness = [utk_test_spds[m], utk_test_eods[m], utk_test_aods[m]][c]\n",
    "\n",
    "    # BMA scores: mean over the ensemble members\n",
    "    val_m_probits = torch.mean(val_probits[m], dim=0)\n",
    "    utk_test_m_probits = torch.mean(utk_test_probits[m], dim=0)\n",
    "\n",
    "    # deliberately not called `model`: that name holds the backbone string used to\n",
    "    # build the result paths and must survive re-running the loading cells\n",
    "    score_lookup = DummyPredictor(val_m_probits)\n",
    "\n",
    "    # Given any trained model that outputs real-valued scores\n",
    "    fair_clf = RelaxedThresholdOptimizer(\n",
    "        predictor=lambda X: score_lookup(X)[:, -1],   # for sklearn API\n",
    "        constraint=constraint,\n",
    "        tolerance=max(np.mean(val_fairness), 0), # fairness constraint tolerance\n",
    "    )\n",
    "\n",
    "    # Fit the fairness adjustment on the validation split\n",
    "    fair_clf.fit(X=val_idx, y=y_val_t.numpy(), group=a_val_t.numpy())\n",
    "\n",
    "    # the lambda resolves `score_lookup` at call time, so swapping the stored\n",
    "    # scores switches the fitted classifier over to the UTK test set\n",
    "    score_lookup.probits = utk_test_m_probits\n",
    "    y_pred_test = fair_clf(X=test_idx, group=a_utk_test_t.numpy())\n",
    "    y_pred_test = torch.tensor(y_pred_test, dtype=torch.long)\n",
    "\n",
    "    bma_pred = utk_test_m_probits.argmax(dim=1)\n",
    "\n",
    "    accs_avg.extend(utk_test_accs[m])\n",
    "    fairs_avg.extend(test_fairness)\n",
    "    accs_bma.append(accuracy(bma_pred, y_utk_test_t).item())\n",
    "    fairs_bma.append(fairness_metric(bma_pred, y_utk_test_t, a_utk_test_t).item())\n",
    "    accs_bma_pp.append(accuracy(y_pred_test, y_utk_test_t).item())\n",
    "    fairs_bma_pp.append(fairness_metric(y_pred_test, y_utk_test_t, a_utk_test_t).item())\n",
    "\n",
    "    if verbose:\n",
    "        print(\"Avg Member\")\n",
    "        # first member of the current seed (previously indexed [0][m], i.e. member m of seed 0)\n",
    "        print(f\"  {(utk_test_accs[m][0]):.3f}\")\n",
    "        print(f\"  {test_fairness[0]:.3f} (val: {val_fairness[0]:.3f})\")\n",
    "        print(\"BMA\")\n",
    "        print(f\"  {accs_bma[-1]:.3f}\")\n",
    "        print(f\"  {fairs_bma[-1]:.3f}\")\n",
    "        print(\"BMA-PP\")\n",
    "        print(f\"  {accs_bma_pp[-1]:.3f}\")\n",
    "        print(f\"  {fairs_bma_pp[-1]:.3f}\")\n",
    "\n",
    "# LaTeX table rows: mean_{\\pm std} of accuracy & fairness for members, BMA, BMA-PP\n",
    "print(\"-\"*30)\n",
    "for accs, fairs in [(accs_avg, fairs_avg), (accs_bma, fairs_bma), (accs_bma_pp, fairs_bma_pp)]:\n",
    "    print(f\"${np.mean(accs):.3f}_{{\\\\pm {np.std(accs):.3f}}}$\", end=\" & \")\n",
    "    print(f\"${np.mean(fairs):.3f}_{{\\\\pm {np.std(fairs):.3f}}}$\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Post-process both the ensemble and each single member model with a fixed constraint tolerance of 0.05."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 250,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "------------------------------\n",
      "$0.859_{\\pm 0.002}$ & $0.019_{\\pm 0.006}$\n",
      "$0.816_{\\pm 0.014}$ & $0.030_{\\pm 0.032}$\n"
     ]
    }
   ],
   "source": [
    "# Post-process with a fixed tolerance of 0.05: once the ensemble (BMA) of each method\n",
    "# seed and once every individual ensemble member, evaluated on the UTK test set.\n",
    "accs_bma_pp, fairs_bma_pp = list(), list()\n",
    "accs_member_pp, fairs_member_pp = list(), list()\n",
    "\n",
    "# fairness metric matching `constraint` (same ordering as the constraint options)\n",
    "fairness_metric = [lambda preds, y, a: spd(preds, a), eod, aod][c]\n",
    "\n",
    "# sample indices act as the \"features\": the predictor looks scores up by index\n",
    "val_idx = torch.arange(len(y_val_t))\n",
    "test_idx = torch.arange(len(y_utk_test_t))\n",
    "\n",
    "for m in range(len(method_seeds)):\n",
    "\n",
    "    if verbose: print(\"-\"*20 + f\"  seed {m}  \" + \"-\"*20)\n",
    "\n",
    "    # ---- ensemble (BMA): scores are the mean over the members ----\n",
    "    # deliberately not called `model`: that name holds the backbone string used to\n",
    "    # build the result paths and must survive re-running the loading cells\n",
    "    score_lookup = DummyPredictor(torch.mean(val_probits[m], dim=0))\n",
    "\n",
    "    fair_clf = RelaxedThresholdOptimizer(\n",
    "        predictor=lambda X: score_lookup(X)[:, -1],   # for sklearn API\n",
    "        constraint=constraint,\n",
    "        tolerance=0.05, # fairness constraint tolerance\n",
    "    )\n",
    "    fair_clf.fit(X=val_idx, y=y_val_t.numpy(), group=a_val_t.numpy())\n",
    "\n",
    "    # the lambda resolves `score_lookup` at call time, so swapping the stored\n",
    "    # scores switches the fitted classifier over to the UTK test set\n",
    "    score_lookup.probits = torch.mean(utk_test_probits[m], dim=0)\n",
    "    y_pred_test = fair_clf(X=test_idx, group=a_utk_test_t.numpy())\n",
    "    y_pred_test = torch.tensor(y_pred_test, dtype=torch.long)\n",
    "\n",
    "    accs_bma_pp.append(accuracy(y_pred_test, y_utk_test_t).item())\n",
    "    fairs_bma_pp.append(fairness_metric(y_pred_test, y_utk_test_t, a_utk_test_t).item())\n",
    "    if verbose:\n",
    "        print(\"BMA-PP\")\n",
    "        print(f\"  {accs_bma_pp[-1]:.3f}\")\n",
    "        print(f\"  {fairs_bma_pp[-1]:.3f}\")\n",
    "\n",
    "    # ---- individual members ----\n",
    "    for mem in range(len(val_probits[m])):\n",
    "        score_lookup = DummyPredictor(val_probits[m][mem])\n",
    "\n",
    "        fair_clf = RelaxedThresholdOptimizer(\n",
    "            predictor=lambda X: score_lookup(X)[:, -1],   # for sklearn API\n",
    "            constraint=constraint,\n",
    "            tolerance=0.05, # fairness constraint tolerance\n",
    "        )\n",
    "        fair_clf.fit(X=val_idx, y=y_val_t.numpy(), group=a_val_t.numpy())\n",
    "\n",
    "        # evaluate on the test scores of the SAME member the thresholds were fitted on\n",
    "        # (previously always member 0, which mixed thresholds and scores of different models)\n",
    "        score_lookup.probits = utk_test_probits[m][mem]\n",
    "        y_pred_test = fair_clf(X=test_idx, group=a_utk_test_t.numpy())\n",
    "        y_pred_test = torch.tensor(y_pred_test, dtype=torch.long)\n",
    "\n",
    "        accs_member_pp.append(accuracy(y_pred_test, y_utk_test_t).item())\n",
    "        fairs_member_pp.append(fairness_metric(y_pred_test, y_utk_test_t, a_utk_test_t).item())\n",
    "        # only the first member is printed to keep the verbose output short\n",
    "        if mem == 0 and verbose:\n",
    "            print(\"Member-PP\")\n",
    "            print(f\"  {accs_member_pp[-1]:.3f}\")\n",
    "            print(f\"  {fairs_member_pp[-1]:.3f}\")\n",
    "\n",
    "# LaTeX table rows: mean_{\\pm std} of accuracy & fairness for BMA-PP and Member-PP\n",
    "print(\"-\"*30)\n",
    "for accs, fairs in [(accs_bma_pp, fairs_bma_pp), (accs_member_pp, fairs_member_pp)]:\n",
    "    print(f\"${np.mean(accs):.3f}_{{\\\\pm {np.std(accs):.3f}}}$\", end=\" & \")\n",
    "    print(f\"${np.mean(fairs):.3f}_{{\\\\pm {np.std(fairs):.3f}}}$\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "quam",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
