{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 81,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import numpy as np\n",
    "from error_parity import RelaxedThresholdOptimizer\n",
    "from error_parity.classifiers import RandomizedClassifier\n",
    "\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "\n",
    "from source.constants import RESULTS_PATH, PLOTS_PATH\n",
    "from source.data.medical_imaging import get_chexpert\n",
    "from source.utils.metrics import balanced_accuracy, aod, eod, spd\n",
    "\n",
    "os.makedirs(PLOTS_PATH, exist_ok=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 82,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Seeds identifying the stored runs (`mseed` in the result paths) and the data seed (`dseed`).\n",
    "method_seeds = [42, 142, 242, 342, 442]\n",
    "dseed = 42\n",
    "\n",
    "# Backbone whose results are evaluated: index 0 = resnet18, 1 = resnet34, 2 = resnet50.\n",
    "model = (\"resnet18\", \"resnet34\", \"resnet50\")[2]\n",
    "\n",
    "# Print per-seed details in the evaluation cells.\n",
    "verbose = False\n",
    "\n",
    "# Names of the protected attributes and the index of the one analysed.\n",
    "pas = [\"old\", \"woman\", \"white\"]\n",
    "pa = 2  # 0, 1, 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 83,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Index of the fairness constraint handed to the threshold optimizer.\n",
    "# The same index selects the matching metric lists (spd / eod / aod) in the evaluation cells.\n",
    "c = 2\n",
    "constraint = (\"demographic_parity\", \"true_positive_rate_parity\", \"average_odds\")[c]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 84,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "# patients general 65401\n",
      "# patients with race 58010\n",
      "24638 24638\n"
     ]
    }
   ],
   "source": [
    "full_ds, _, _ = get_chexpert(load_to_ram=False)\n",
    "\n",
    "# The split indices are read from the run directory of the first method seed.\n",
    "run_path = os.path.join(RESULTS_PATH, f\"chexpert_{model}_mseed{method_seeds[0]}_dseed{dseed}\")\n",
    "fair_inds, val_inds = (\n",
    "    torch.load(os.path.join(run_path, file_name))\n",
    "    for file_name in (\"fair_inds.pt\", \"val_inds.pt\")\n",
    ")\n",
    "\n",
    "print(len(fair_inds), len(val_inds))\n",
    "\n",
    "# Labels of both splits, kept in the coding stored in the dataset\n",
    "# (apply `1 - y` to both if the label coding should be switched as well).\n",
    "y_fair_t, y_val_t = full_ds.targets[fair_inds], full_ds.targets[val_inds]\n",
    "\n",
    "# Protected attribute `pa` with its 0 / 1 coding switched, so the group names switch too.\n",
    "pas = [\"young\", \"man\", \"non-white\"]\n",
    "a_fair_t = 1 - full_ds.protected_attributes[pa, fair_inds]\n",
    "a_val_t = 1 - full_ds.protected_attributes[pa, val_inds]\n",
    "\n",
    "# Mean of the switched protected attribute on the fair split.\n",
    "p_a_fair = a_fair_t.float().mean().item()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 85,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the stored probits of every method seed for both splits.\n",
    "# No fairness ensemble is built on medical imaging, so the fair split is used as the test set.\n",
    "fair_probits, val_probits = [], []\n",
    "for mseed in method_seeds:\n",
    "    path = os.path.join(RESULTS_PATH, f\"chexpert_{model}_mseed{mseed}_dseed{dseed}\")\n",
    "    for file_name, collected in ((\"fair_probits.pt\", fair_probits), (\"val_probits.pt\", val_probits)):\n",
    "        collected.append(torch.load(os.path.join(path, file_name)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 86,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Balanced accuracy and fairness metrics of every member; one inner list per method seed.\n",
    "fair_balanced_accuracys, val_balanced_accuracys = [], []\n",
    "fair_spds, val_spds = [], []\n",
    "fair_eods, val_eods = [], []\n",
    "fair_aods, val_aods = [], []\n",
    "\n",
    "for seed_fair_probits, seed_val_probits in zip(fair_probits, val_probits):\n",
    "    # hard predictions of each member, shared by all four metrics\n",
    "    fair_preds = [p.argmax(dim=1) for p in seed_fair_probits]\n",
    "    val_preds = [p.argmax(dim=1) for p in seed_val_probits]\n",
    "\n",
    "    fair_balanced_accuracys.append([balanced_accuracy(pred, y_fair_t) for pred in fair_preds])\n",
    "    val_balanced_accuracys.append([balanced_accuracy(pred, y_val_t) for pred in val_preds])\n",
    "\n",
    "    fair_spds.append([spd(pred, a_fair_t) for pred in fair_preds])\n",
    "    val_spds.append([spd(pred, a_val_t) for pred in val_preds])\n",
    "\n",
    "    fair_eods.append([eod(pred, y_fair_t, a_fair_t) for pred in fair_preds])\n",
    "    val_eods.append([eod(pred, y_val_t, a_val_t) for pred in val_preds])\n",
    "\n",
    "    fair_aods.append([aod(pred, y_fair_t, a_fair_t) for pred in fair_preds])\n",
    "    val_aods.append([aod(pred, y_val_t, a_val_t) for pred in val_preds])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 87,
   "metadata": {},
   "outputs": [],
   "source": [
    "class DummyPredictor(nn.Module):\n",
    "    \"\"\"Stand-in model that returns precomputed probits instead of running a network.\n",
    "\n",
    "    The inputs passed to it are sample indices; the output is the matching rows of\n",
    "    `probits` as a NumPy array. `probits` can be reassigned to switch to another split.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, probits):\n",
    "        super().__init__()\n",
    "        self.probits = probits\n",
    "\n",
    "    def forward(self, indices: torch.Tensor):\n",
    "        \"\"\"Return `self.probits[indices]` as a NumPy array.\n",
    "\n",
    "        The selected rows are detached and moved to the CPU first, so probits that\n",
    "        require grad or live on another device can still be converted.\n",
    "        \"\"\"\n",
    "        return self.probits[indices].detach().cpu().numpy()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 88,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_thresholds(fair_clf, verbose=True):\n",
    "    \"\"\"Collect the decision thresholds of a fitted classifier for groups 0 and 1.\n",
    "\n",
    "    Reads `fair_clf._realized_classifier.group_to_clf`. For a `RandomizedClassifier`\n",
    "    the thresholds of all its `classifiers` are returned; for any other classifier\n",
    "    its single threshold is repeated twice.\n",
    "\n",
    "    Returns a list with one list of thresholds per group. With `verbose` the\n",
    "    thresholds are printed as well.\n",
    "    \"\"\"\n",
    "    per_group = []\n",
    "    for group in (0, 1):\n",
    "        if verbose:\n",
    "            print(f\"Class {group}\")\n",
    "        group_clf = fair_clf._realized_classifier.group_to_clf[group]\n",
    "        if isinstance(group_clf, RandomizedClassifier):\n",
    "            group_thresholds = [member.threshold for member in group_clf.classifiers]\n",
    "            shown = group_thresholds\n",
    "        else:\n",
    "            # a deterministic classifier has one threshold; it is stored twice\n",
    "            group_thresholds = [group_clf.threshold] * 2\n",
    "            shown = group_thresholds[:1]\n",
    "        if verbose:\n",
    "            for value in shown:\n",
    "                print(value)\n",
    "        per_group.append(group_thresholds)\n",
    "    return per_group"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Optimize the ensemble, using the average member's fairness violation as the constraint tolerance"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 89,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "------------------------------\n",
      "$0.783_{\\pm 0.008}$ & $0.049_{\\pm 0.004}$\n",
      "$0.786_{\\pm 0.004}$ & $0.047_{\\pm 0.002}$\n",
      "$0.802_{\\pm 0.007}$ & $0.044_{\\pm 0.006}$\n",
      "------------------------------\n",
      "Group 0\n",
      "$0.454_{\\pm 0.014}$\n",
      "$0.450_{\\pm 0.018}$\n",
      "Group 1\n",
      "$0.481_{\\pm 0.029}$\n",
      "$0.445_{\\pm 0.017}$\n"
     ]
    }
   ],
   "source": [
    "def latex_mean_std(values):\n",
    "    \"\"\"Format mean and standard deviation of `values` as a LaTeX table entry.\"\"\"\n",
    "    return f\"${np.mean(values):.3f}_{{\\\\pm {np.std(values):.3f}}}$\"\n",
    "\n",
    "\n",
    "def fairness_violation(y_pred):\n",
    "    \"\"\"Evaluate the metric matching constraint index `c` for predictions on the fair split.\n",
    "\n",
    "    Raises ValueError if `c` has no associated metric.\n",
    "    \"\"\"\n",
    "    if c == 0:\n",
    "        return spd(y_pred, a_fair_t).item()\n",
    "    if c == 1:\n",
    "        return eod(y_pred, y_fair_t, a_fair_t).item()\n",
    "    if c == 2:\n",
    "        return aod(y_pred, y_fair_t, a_fair_t).item()\n",
    "    raise ValueError(f\"no fairness metric for constraint index c={c}\")\n",
    "\n",
    "\n",
    "balanced_accuracys_bma, fairs_bma = [], []\n",
    "balanced_accuracys_bma_pp, fairs_bma_pp = [], []\n",
    "balanced_accuracys_avg, fairs_avg = [], []\n",
    "thresholds_fairs_bma_pp = []\n",
    "\n",
    "# The predictor is called with sample indices and looks the scores up.\n",
    "val_indices = torch.arange(len(y_val_t))\n",
    "test_indices = torch.arange(len(y_fair_t))\n",
    "\n",
    "for m in range(len(method_seeds)):\n",
    "\n",
    "    if verbose: print(\"-\"*20 + f\"  seed {m}  \" + \"-\"*20)\n",
    "\n",
    "    # Per-member fairness values of this seed for the selected constraint.\n",
    "    val_fairness = [val_spds[m], val_eods[m], val_aods[m]][c]\n",
    "    test_fairness = [fair_spds[m], fair_eods[m], fair_aods[m]][c]\n",
    "\n",
    "    # Ensemble scores: mean over the members' validation probits.\n",
    "    # Bound to `score_model` rather than `model`, so the backbone name used to\n",
    "    # build the result paths is not overwritten.\n",
    "    val_m_probits = torch.mean(val_probits[m], dim=0)\n",
    "    score_model = DummyPredictor(val_m_probits)\n",
    "\n",
    "    # The tolerance is the members' mean validation fairness value, floored at 0.\n",
    "    fair_clf = RelaxedThresholdOptimizer(\n",
    "        predictor=lambda X: score_model(X)[:, -1],  # last column of the probits\n",
    "        constraint=constraint,\n",
    "        tolerance=max(np.mean(val_fairness), 0),\n",
    "    )\n",
    "\n",
    "    # Fit the group-wise thresholds on the validation split.\n",
    "    fair_clf.fit(X=val_indices, y=y_val_t.numpy(), group=a_val_t.numpy())\n",
    "    thresholds_fairs_bma_pp.append(get_thresholds(fair_clf, verbose=verbose))\n",
    "\n",
    "    # Swap in the ensemble scores of the fair split; the lambda reads `score_model` at call time.\n",
    "    ff_test_m_probits = torch.mean(fair_probits[m], dim=0)\n",
    "    score_model.probits = ff_test_m_probits\n",
    "\n",
    "    # Group information is required to compute the post-processed predictions.\n",
    "    y_pred_test = fair_clf(X=test_indices, group=a_fair_t.numpy())\n",
    "    y_pred_test = torch.tensor(y_pred_test, dtype=torch.long)\n",
    "\n",
    "    # Individual members without post-processing.\n",
    "    balanced_accuracys_avg.extend(fair_balanced_accuracys[m])\n",
    "    fairs_avg.extend(test_fairness)\n",
    "    if verbose:\n",
    "        print(\"Avg Member\")\n",
    "        # values of the first member of the current seed\n",
    "        print(f\"  {(fair_balanced_accuracys[m][0]):.3f} \")\n",
    "        print(f\"  {test_fairness[0]:.3f} (val: {val_fairness[0]:.3f})\")\n",
    "\n",
    "    # Ensemble with the plain argmax decision.\n",
    "    bma_pred = ff_test_m_probits.argmax(dim=1)\n",
    "    balanced_accuracys_bma.append(balanced_accuracy(bma_pred, y_fair_t).item())\n",
    "    fairs_bma.append(fairness_violation(bma_pred))\n",
    "    if verbose:\n",
    "        print(\"BMA\")\n",
    "        print(f\"  {balanced_accuracys_bma[-1]:.3f}\")\n",
    "        print(f\"  {fairs_bma[-1]:.3f}\")\n",
    "\n",
    "    # Ensemble with the fitted group-wise thresholds.\n",
    "    balanced_accuracys_bma_pp.append(balanced_accuracy(y_pred_test, y_fair_t).item())\n",
    "    fairs_bma_pp.append(fairness_violation(y_pred_test))\n",
    "    if verbose:\n",
    "        print(\"BMA-PP\")\n",
    "        print(f\"  {balanced_accuracys_bma_pp[-1]:.3f} \")\n",
    "        print(f\"  {fairs_bma_pp[-1]:.3f}\")\n",
    "\n",
    "thresholds_fairs_bma_pp = np.asarray(thresholds_fairs_bma_pp)\n",
    "\n",
    "# Table rows: average member, ensemble, post-processed ensemble.\n",
    "print(\"-\"*30)\n",
    "print(latex_mean_std(balanced_accuracys_avg), end=\" & \")\n",
    "print(latex_mean_std(fairs_avg))\n",
    "print(latex_mean_std(balanced_accuracys_bma), end=\" & \")\n",
    "print(latex_mean_std(fairs_bma))\n",
    "print(latex_mean_std(balanced_accuracys_bma_pp), end=\" & \")\n",
    "print(latex_mean_std(fairs_bma_pp))\n",
    "print(\"-\"*30)\n",
    "for i in range(2):\n",
    "    print(f\"Group {i}\")\n",
    "    print(latex_mean_std(thresholds_fairs_bma_pp[:, i, 0]))\n",
    "    print(latex_mean_std(thresholds_fairs_bma_pp[:, i, 1]))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def latex_mean_std(values):\n",
    "    \"\"\"Format mean and standard deviation of `values` as a LaTeX table entry.\"\"\"\n",
    "    return f\"${np.mean(values):.3f}_{{\\\\pm {np.std(values):.3f}}}$\"\n",
    "\n",
    "\n",
    "def fairness_violation(y_pred):\n",
    "    \"\"\"Evaluate the metric matching constraint index `c` for predictions on the fair split.\n",
    "\n",
    "    Raises ValueError if `c` has no associated metric.\n",
    "    \"\"\"\n",
    "    if c == 0:\n",
    "        return spd(y_pred, a_fair_t).item()\n",
    "    if c == 1:\n",
    "        return eod(y_pred, y_fair_t, a_fair_t).item()\n",
    "    if c == 2:\n",
    "        return aod(y_pred, y_fair_t, a_fair_t).item()\n",
    "    raise ValueError(f\"no fairness metric for constraint index c={c}\")\n",
    "\n",
    "\n",
    "def postprocess_and_evaluate(val_scores, test_scores, tolerance):\n",
    "    \"\"\"Fit group-wise thresholds on validation scores and evaluate them on the fair split.\n",
    "\n",
    "    Returns a tuple of the thresholds from `get_thresholds`, the balanced accuracy\n",
    "    and the fairness value of the post-processed predictions.\n",
    "    \"\"\"\n",
    "    score_model = DummyPredictor(val_scores)\n",
    "    fair_clf = RelaxedThresholdOptimizer(\n",
    "        predictor=lambda X: score_model(X)[:, -1],  # last column of the probits\n",
    "        constraint=constraint,\n",
    "        tolerance=tolerance,\n",
    "    )\n",
    "\n",
    "    # The predictor is called with sample indices and looks the scores up.\n",
    "    fair_clf.fit(X=torch.arange(len(y_val_t)), y=y_val_t.numpy(), group=a_val_t.numpy())\n",
    "    thresholds = get_thresholds(fair_clf, verbose=verbose)\n",
    "\n",
    "    # Swap in the scores of the fair split before predicting.\n",
    "    score_model.probits = test_scores\n",
    "    y_pred = fair_clf(X=torch.arange(len(y_fair_t)), group=a_fair_t.numpy())\n",
    "    y_pred = torch.tensor(y_pred, dtype=torch.long)\n",
    "\n",
    "    return thresholds, balanced_accuracy(y_pred, y_fair_t).item(), fairness_violation(y_pred)\n",
    "\n",
    "\n",
    "pp_tolerance = 0.05  # fairness constraint tolerance of the threshold optimizer\n",
    "\n",
    "balanced_accuracys_bma_pp, fairs_bma_pp = [], []\n",
    "balanced_accuracys_member_pp, fairs_member_pp = [], []\n",
    "thresholds_fairs_bma_pp = []\n",
    "thresholds_fairs_member_pp = []\n",
    "\n",
    "for m in range(len(method_seeds)):\n",
    "\n",
    "    if verbose: print(\"-\"*20 + f\"  seed {m}  \" + \"-\"*20)\n",
    "\n",
    "    # Post-process the ensemble (mean over the members' probits).\n",
    "    thresholds, acc, fairness = postprocess_and_evaluate(\n",
    "        torch.mean(val_probits[m], dim=0), torch.mean(fair_probits[m], dim=0), pp_tolerance\n",
    "    )\n",
    "    thresholds_fairs_bma_pp.append(thresholds)\n",
    "    balanced_accuracys_bma_pp.append(acc)\n",
    "    fairs_bma_pp.append(fairness)\n",
    "    if verbose:\n",
    "        print(\"BMA-PP\")\n",
    "        print(f\"  {acc:.3f}\")\n",
    "        print(f\"  {fairness:.3f}\")\n",
    "\n",
    "    # Post-process every member on its own scores.\n",
    "    for mem in range(len(val_probits[m])):\n",
    "        # The evaluated scores must belong to the member the thresholds were\n",
    "        # fitted on, hence index `mem` on both splits (not a fixed member 0).\n",
    "        thresholds, acc, fairness = postprocess_and_evaluate(\n",
    "            val_probits[m][mem], fair_probits[m][mem], pp_tolerance\n",
    "        )\n",
    "        thresholds_fairs_member_pp.append(thresholds)\n",
    "        balanced_accuracys_member_pp.append(acc)\n",
    "        fairs_member_pp.append(fairness)\n",
    "        if mem == 0 and verbose:\n",
    "            print(\"Member-PP\")\n",
    "            print(f\"  {acc:.3f} \")\n",
    "            print(f\"  {fairness:.3f}\")\n",
    "\n",
    "thresholds_fairs_bma_pp = np.asarray(thresholds_fairs_bma_pp)\n",
    "thresholds_fairs_member_pp = np.asarray(thresholds_fairs_member_pp).reshape((-1, 2, 2))\n",
    "\n",
    "# Table rows: post-processed ensemble, post-processed members.\n",
    "print(\"-\"*30)\n",
    "print(latex_mean_std(balanced_accuracys_bma_pp), end=\" & \")\n",
    "print(latex_mean_std(fairs_bma_pp))\n",
    "print(latex_mean_std(balanced_accuracys_member_pp), end=\" & \")\n",
    "print(latex_mean_std(fairs_member_pp))\n",
    "print(\"-\"*30)\n",
    "for i in range(2):\n",
    "    print(f\"Group {i}\")\n",
    "    print(latex_mean_std(thresholds_fairs_bma_pp[:, i, 0]))\n",
    "    print(latex_mean_std(thresholds_fairs_bma_pp[:, i, 1]))\n",
    "print(\"-\"*30)\n",
    "for i in range(2):\n",
    "    print(f\"Group {i}\")\n",
    "    print(latex_mean_std(thresholds_fairs_member_pp[:, i, 0]))\n",
    "    print(latex_mean_std(thresholds_fairs_member_pp[:, i, 1]))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "quam",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
