{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 59,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import numpy as np\n",
    "\n",
    "import torch\n",
    "\n",
    "from sklearn.calibration import calibration_curve\n",
    "\n",
    "from source.constants import RESULTS_PATH, PLOTS_PATH\n",
    "from source.data.face_detection import get_fair_face, get_utk\n",
    "\n",
    "os.makedirs(PLOTS_PATH, exist_ok=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 60,
   "metadata": {},
   "outputs": [],
   "source": [
    "method_seeds = [42, 142, 242, 342, 442]\n",
    "dseed = 42\n",
    "\n",
    "model = [\"resnet18\", \"resnet34\", \"resnet50\", \"regnet\", \"efficientnet\", \"efficientnet_mcdropout\"][2]\n",
    "\n",
    "targets = [\"age\", \"gender\", \"race (old)\", \"race\"]\n",
    "# predicting race does not give high unfairness (with either pa) for eod and aod\n",
    "# predicting gender also not too nice (only unfairness with age)\n",
    "target = 3 # 0, 1, 2, 3\n",
    "pa = 0 # 0, 1, 2, 3"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# The datasets are loaded without fixing a target / protected attribute;\n",
    "# both are selected below by indexing into `.targets`.\n",
    "ff_train_ds, ff_test_ds = get_fair_face(binarize=True, augment=False)\n",
    "utk_test_ds = get_utk(binarize=True)\n",
    "\n",
    "# The index tensors are read from the run directory of the first method seed.\n",
    "run_path = os.path.join(RESULTS_PATH, f\"fairface_target{target}_{model}_mseed{method_seeds[0]}_dseed{dseed}\")\n",
    "fair_inds, val_inds = (\n",
    "    torch.load(os.path.join(run_path, file_name))\n",
    "    for file_name in (\"fair_inds.pt\", \"val_inds.pt\")\n",
    ")\n",
    "\n",
    "print(len(fair_inds), len(val_inds), len(ff_test_ds), len(utk_test_ds))\n",
    "\n",
    "# Targets (y_*) and protected attributes (a_*) for every split.\n",
    "y_fair_t, a_fair_t = ff_train_ds.targets[target, fair_inds], ff_train_ds.targets[pa, fair_inds]\n",
    "y_val_t, a_val_t = ff_train_ds.targets[target, val_inds], ff_train_ds.targets[pa, val_inds]\n",
    "y_ff_test_t, a_ff_test_t = ff_test_ds.targets[target], ff_test_ds.targets[pa]\n",
    "y_utk_test_t, a_utk_test_t = utk_test_ds.targets[target], utk_test_ds.targets[pa]\n",
    "\n",
    "\n",
    "def _mean_in_percent(values):\n",
    "    \"\"\"Return the mean of `values` (cast to float) multiplied by 100.\"\"\"\n",
    "    return values.float().mean().item() * 100\n",
    "\n",
    "\n",
    "# Share of the protected attribute in the two test sets.\n",
    "p_a_ff_test, p_a_utk_test = _mean_in_percent(a_ff_test_t), _mean_in_percent(a_utk_test_t)\n",
    "print(p_a_ff_test, p_a_utk_test)\n",
    "\n",
    "# Share of the target in the two test sets.\n",
    "p_y_ff_test, p_y_utk_test = _mean_in_percent(y_ff_test_t), _mean_in_percent(y_utk_test_t)\n",
    "print(p_y_ff_test, p_y_utk_test)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 62,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Collect the stored probits of every split, one list entry per method seed.\n",
    "fair_probits, val_probits, ff_test_probits, utk_test_probits = [], [], [], []\n",
    "for mseed in method_seeds:\n",
    "    path = os.path.join(RESULTS_PATH, f\"fairface_target{target}_{model}_mseed{mseed}_dseed{dseed}\")\n",
    "\n",
    "    # Pair each result list with the file that feeds it (load order unchanged).\n",
    "    for collected, file_name in (\n",
    "        (fair_probits, f\"fair_probits_t{target}.pt\"),\n",
    "        (val_probits, f\"val_probits_t{target}.pt\"),\n",
    "        (ff_test_probits, f\"ff_test_probits_t{target}.pt\"),\n",
    "        (utk_test_probits, f\"utk_test_probits_t{target}.pt\"),\n",
    "    ):\n",
    "        collected.append(torch.load(os.path.join(path, file_name)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 63,
   "metadata": {},
   "outputs": [],
   "source": [
    "def ece(y_probs, y_trues, n_bins):\n",
    "    # Compute the calibration curve\n",
    "    fraction_of_positives, mean_predicted_value = calibration_curve(y_trues, y_probs, n_bins=n_bins, strategy='uniform')\n",
    "    \n",
    "    # Define bin edges\n",
    "    bin_edges = np.linspace(0.0, 1.0, n_bins + 1)\n",
    "    \n",
    "    # Assign each probability prediction to a bin\n",
    "    bin_indices = np.digitize(y_probs, bins=bin_edges, right=True) - 1\n",
    "    # Correct any indices that are out of bounds\n",
    "    bin_indices = np.clip(bin_indices, 0, n_bins - 1)\n",
    "    \n",
    "    # Total number of samples\n",
    "    n_samples = len(y_trues)\n",
    "    \n",
    "    # Count the number of samples per bin\n",
    "    bin_counts = np.bincount(bin_indices, minlength=n_bins)\n",
    "    \n",
    "    # Calculate the weight of each bin (proportion of total samples)\n",
    "    bin_weights = bin_counts / n_samples\n",
    "    \n",
    "    # Compute the absolute difference between accuracy and confidence for each bin\n",
    "    bin_errors = np.abs(fraction_of_positives - mean_predicted_value)\n",
    "\n",
    "    # Calculate the Expected Calibration Error\n",
    "    ece = np.sum(bin_weights * bin_errors)\n",
    "    \n",
    "    return ece\n",
    "\n",
    "# y_true = np.asarray([0, 0, 1, 1, 0, 0, 0, 0, 0, 1, 1, 1])\n",
    "# y_prob = np.asarray([0.1, 0.4, 0.35, 0.8, 0.1, 0.4, 0.25, 0.5, 0.1, 0.4, 0.35, 0.9])\n",
    "# ece = ece(y_prob, y_true, n_bins=5)\n",
    "# print(f\"Expected Calibration Error: {ece:.4f}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 64,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.091 $\\pm$ 0.024\n",
      "0.046 $\\pm$ 0.006\n",
      "0.136 $\\pm$ 0.023\n",
      "0.102 $\\pm$ 0.006\n"
     ]
    }
   ],
   "source": [
    "ensemble_members = list(range(1, len(ff_test_probits[0]) + 1))\n",
    "\n",
    "\n",
    "def _member_and_ensemble_eces(probits_per_seed, y_true, n_bins=10):\n",
    "    \"\"\"ECE of every single member and of the member-averaged prediction.\n",
    "\n",
    "    Args:\n",
    "        probits_per_seed: One entry per method seed. Iterating an entry yields\n",
    "            the per-member predictions, of which column 1 is passed to `ece`\n",
    "            (presumably laid out as (members, samples, classes) -- confirm).\n",
    "        y_true: Labels handed to `ece`.\n",
    "        n_bins: Number of bins handed to `ece`.\n",
    "\n",
    "    Returns:\n",
    "        Tuple `(member_eces, ensemble_eces)` of flat numpy arrays: one value\n",
    "        per (seed, member) pair, and one value per seed computed on the mean\n",
    "        over dim 0 of that seed's entry.\n",
    "    \"\"\"\n",
    "    member_eces = [\n",
    "        ece(p[:, 1], y_true, n_bins=n_bins)\n",
    "        for probits in probits_per_seed\n",
    "        for p in probits\n",
    "    ]\n",
    "    ensemble_eces = [\n",
    "        ece(torch.mean(probits, dim=0)[:, 1], y_true, n_bins=n_bins)\n",
    "        for probits in probits_per_seed\n",
    "    ]\n",
    "    return np.asarray(member_eces).reshape(-1), np.asarray(ensemble_eces).reshape(-1)\n",
    "\n",
    "\n",
    "ff_test_eces, ff_test_m_eces = _member_and_ensemble_eces(ff_test_probits, y_ff_test_t)\n",
    "utk_test_eces, utk_test_m_eces = _member_and_ensemble_eces(utk_test_probits, y_utk_test_t)\n",
    "\n",
    "# Printed as `mean $\\\\pm$ std`; the backslash is escaped because `\\\\p` is not a\n",
    "# valid Python escape sequence. The emitted text is unchanged.\n",
    "for eces in (ff_test_eces, ff_test_m_eces, utk_test_eces, utk_test_m_eces):\n",
    "    print(f\"{eces.mean(axis=0):.3f} $\\\\pm$ {eces.std(axis=0):.3f}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "quam",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
