{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import numpy as np\n",
    "\n",
    "import torch\n",
    "\n",
    "from sklearn.calibration import calibration_curve\n",
    "\n",
    "from source.constants import RESULTS_PATH, PLOTS_PATH\n",
    "from source.data.medical_imaging import get_chexpert\n",
    "\n",
    "os.makedirs(PLOTS_PATH, exist_ok=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Run selection. method_seeds, model and dseed are interpolated into the run directory\n",
    "# names \"chexpert_{model}_mseed{mseed}_dseed{dseed}\" that the loading cells read from.\n",
    "method_seeds = [42, 142, 242, 342, 442]\n",
    "dseed = 42\n",
    "\n",
    "# The trailing [2] picks one entry of the candidate list (currently \"resnet50\").\n",
    "model = [\"resnet18\", \"resnet34\", \"resnet50\", \"regnet\", \"efficientnet\", \"efficientnet_mcdropout\"][2]\n",
    "\n",
    "# pa is used as the first index into full_ds.protected_attributes when the data is loaded;\n",
    "# pas presumably holds the display name of the attribute at each index - TODO confirm.\n",
    "pas = [\"old\", \"woman\", \"white\"]\n",
    "pa = 0 # 0, 1, 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "# patients general 65401\n",
      "# patients with race 58010\n",
      "24638 24638\n"
     ]
    }
   ],
   "source": [
    "full_ds, _, _ = get_chexpert(load_to_ram=False)\n",
    "\n",
    "# The split indices are read from the run directory of the first method seed.\n",
    "run_path = os.path.join(RESULTS_PATH, f\"chexpert_{model}_mseed{method_seeds[0]}_dseed{dseed}\")\n",
    "fair_inds, val_inds = (\n",
    "    torch.load(os.path.join(run_path, fname)) for fname in (\"fair_inds.pt\", \"val_inds.pt\")\n",
    ")\n",
    "\n",
    "print(len(fair_inds), len(val_inds))\n",
    "\n",
    "# Targets of both splits.\n",
    "y_fair_t, y_val_t = full_ds.targets[fair_inds], full_ds.targets[val_inds]\n",
    "\n",
    "# Protected attribute number `pa` of both splits with 0 / 1 switched;\n",
    "# `pas` is replaced by the names that go with the switched values.\n",
    "pas = [\"young\", \"man\", \"non-white\"]\n",
    "a_fair_t = 1 - full_ds.protected_attributes[pa, fair_inds]\n",
    "a_val_t = 1 - full_ds.protected_attributes[pa, val_inds]\n",
    "\n",
    "# The labels stay as stored. To switch them as well, uncomment:\n",
    "# y_fair_t = 1 - y_fair_t\n",
    "# y_val_t = 1 - y_val_t\n",
    "\n",
    "# Mean of the (switched) attribute and of the target on the fair split, times 100.\n",
    "p_a_fair = 100 * a_fair_t.float().mean().item()\n",
    "p_y_fair = 100 * y_fair_t.float().mean().item()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the stored probits of every method seed, one list entry per seed.\n",
    "# The \"fair\" split is used as the test dataset here: no fairness ensemble is done on medical imaging.\n",
    "# Loaded objects are kept exactly as stored (no 1 - p flip is applied).\n",
    "fair_probits, val_probits = [], []\n",
    "probit_files = ((fair_probits, \"fair_probits.pt\"), (val_probits, \"val_probits.pt\"))\n",
    "\n",
    "for mseed in method_seeds:\n",
    "    run_dir = os.path.join(RESULTS_PATH, f\"chexpert_{model}_mseed{mseed}_dseed{dseed}\")\n",
    "    for collected, fname in probit_files:\n",
    "        collected.append(torch.load(os.path.join(run_dir, fname)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "def ece(y_probs, y_trues, n_bins):\n",
    "    \"\"\"Expected Calibration Error (ECE) over n_bins equal-width bins on [0, 1].\n",
    "\n",
    "    Args:\n",
    "        y_probs: 1-D array-like of predicted probabilities (passed to calibration_curve as y_prob).\n",
    "        y_trues: 1-D array-like of true labels of the same length (passed as y_true).\n",
    "        n_bins: number of equal-width bins.\n",
    "\n",
    "    Returns:\n",
    "        Sum over the non-empty bins of (share of samples in the bin) *\n",
    "        |fraction of positives - mean predicted probability|.\n",
    "    \"\"\"\n",
    "    # Per-bin fraction of positives and mean prediction. calibration_curve leaves out\n",
    "    # bins that contain no samples, so both arrays may hold fewer than n_bins entries.\n",
    "    fraction_of_positives, mean_predicted_value = calibration_curve(y_trues, y_probs, n_bins=n_bins, strategy='uniform')\n",
    "\n",
    "    # calibration_curve does not return per-bin sample counts, so redo the binning:\n",
    "    # right=True makes every bin (lower, upper]; the - 1 turns digitize's result into a 0-based index.\n",
    "    bin_edges = np.linspace(0.0, 1.0, n_bins + 1)\n",
    "    bin_indices = np.digitize(y_probs, bins=bin_edges, right=True) - 1\n",
    "    # A probability of exactly 0.0 comes out as -1; clip it into the first bin.\n",
    "    bin_indices = np.clip(bin_indices, 0, n_bins - 1)\n",
    "\n",
    "    bin_counts = np.bincount(bin_indices, minlength=n_bins)\n",
    "\n",
    "    # Weight = share of all samples that fall into the bin. Only non-empty bins are kept so that\n",
    "    # the weights line up one-to-one with calibration_curve's output; with all n_bins weights an\n",
    "    # empty bin causes a shape mismatch in the product below.\n",
    "    # NOTE(review): relies on calibration_curve's uniform binning matching the digitize call above.\n",
    "    bin_weights = bin_counts[bin_counts > 0] / len(y_trues)\n",
    "\n",
    "    # Per-bin gap between observed frequency and mean prediction.\n",
    "    bin_errors = np.abs(fraction_of_positives - mean_predicted_value)\n",
    "\n",
    "    return np.sum(bin_weights * bin_errors)\n",
    "\n",
    "# Example (assign to a new name so that the function `ece` is not overwritten):\n",
    "# y_true = np.asarray([0, 0, 1, 1, 0, 0, 0, 0, 0, 1, 1, 1])\n",
    "# y_prob = np.asarray([0.1, 0.4, 0.35, 0.8, 0.1, 0.4, 0.25, 0.5, 0.1, 0.4, 0.35, 0.9])\n",
    "# example_ece = ece(y_prob, y_true, n_bins=5)\n",
    "# print(f\"Expected Calibration Error: {example_ece:.4f}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(5,) (50,)\n",
      "0.011 $\\pm$ 0.001\n",
      "0.012 $\\pm$ 0.003\n"
     ]
    }
   ],
   "source": [
    "# 1-based indices along the first axis of the first run's probits (not used further in this cell).\n",
    "ensemble_members = list(range(1, len(fair_probits[0]) + 1))\n",
    "\n",
    "fair_eces, fair_m_eces = [], []\n",
    "\n",
    "# Iterate over the loaded runs themselves so the loop cannot get out of step with fair_probits.\n",
    "for run_probits in fair_probits:\n",
    "    # One ECE per slice along the first axis of the run, computed from column 1 of each slice.\n",
    "    fair_eces.append([ece(p[:, 1], y_fair_t, n_bins=10) for p in run_probits])\n",
    "\n",
    "    # ECE of the run's averaged prediction: mean over the first axis, then column 1.\n",
    "    probs = torch.mean(run_probits, dim=0)[:, 1]\n",
    "    fair_m_eces.append(ece(probs, y_fair_t, n_bins=10))\n",
    "\n",
    "# Flatten to 1-D: every per-slice ECE of every run, and one averaged-prediction ECE per run.\n",
    "fair_m_eces = np.asarray(fair_m_eces).reshape(-1, )\n",
    "fair_eces = np.asarray(fair_eces).reshape(-1, )\n",
    "\n",
    "# Raw f-strings: in a normal string literal the backslash-p of the LaTeX plus-minus sign\n",
    "# is an invalid escape sequence. The printed text is unchanged.\n",
    "print(rf\"{fair_eces.mean(axis=0):.3f} $\\pm$ {fair_eces.std(axis=0):.3f}\")\n",
    "print(rf\"{fair_m_eces.mean(axis=0):.3f} $\\pm$ {fair_m_eces.std(axis=0):.3f}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "quam",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
