{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "from collections import defaultdict\n",
    "\n",
    "import numpy as np\n",
    "import torch\n",
    "import torch.nn.functional as F\n",
    "\n",
    "# config\n",
    "result_folder_path = '/home/.../pfgan_hub/G-PATE/fashion_mnist_binary_eps10_results'\n",
    "eval_mode = 'ours'  # only run sub-folders whose name contains this substring are evaluated\n",
    "bias_factor = 'z'  # 'both' or 'z' -- NOTE(review): currently not used below\n",
    "\n",
    "\n",
    "def group_counts(data_y, data_z):\n",
    "    \"\"\"Count samples in each (y, z) group of binary labels.\n",
    "\n",
    "    Returns a 2x2 int array ``counts`` with ``counts[y, z]`` = number of samples\n",
    "    with label y and attribute z. Empty groups stay 0 instead of being dropped\n",
    "    (grouping with np.unique drops them and breaks the 2x2 layout).\n",
    "\n",
    "    NOTE(review): assumes y and z are per-sample 0/1 values (not one-hot).\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if the inputs are empty, differ in length, or contain\n",
    "            values other than 0/1.\n",
    "    \"\"\"\n",
    "    y = np.asarray(data_y).reshape(-1)\n",
    "    z = np.asarray(data_z).reshape(-1)\n",
    "    if y.size == 0 or y.size != z.size or y.size != len(data_y):\n",
    "        raise ValueError('data_y and data_z must hold one value per sample and have equal, non-zero length')\n",
    "    if not (np.isin(y, (0, 1)).all() and np.isin(z, (0, 1)).all()):\n",
    "        raise ValueError('data_y and data_z must contain only binary 0/1 values')\n",
    "    counts = np.zeros((2, 2), dtype=int)\n",
    "    np.add.at(counts, (y.astype(int), z.astype(int)), 1)\n",
    "    return counts\n",
    "\n",
    "\n",
    "def kl_divergence(p, q):\n",
    "    \"\"\"Return KL(p || q) in nats; entries with p == 0 contribute 0.\"\"\"\n",
    "    p = np.asarray(p, dtype=float).reshape(-1)\n",
    "    q = np.asarray(q, dtype=float).reshape(-1)\n",
    "    support = p > 0\n",
    "    return float(np.sum(p[support] * np.log(p[support] / q[support])))\n",
    "\n",
    "\n",
    "# get gen data path list (sorted, since os.listdir order is arbitrary)\n",
    "gen_data_path_list = []\n",
    "for run_name in sorted(os.listdir(result_folder_path)):\n",
    "    run_dir = os.path.join(result_folder_path, run_name)\n",
    "    if eval_mode not in run_name or not os.path.isdir(run_dir):\n",
    "        continue\n",
    "    for file_name in sorted(os.listdir(run_dir)):\n",
    "        if '_labeled.npz' in file_name:\n",
    "            gen_data_path_list.append(os.path.join(run_dir, file_name))\n",
    "\n",
    "print(f'Evaluating {eval_mode}: ')\n",
    "for gen_data_path in gen_data_path_list:\n",
    "    print(gen_data_path)\n",
    "if not gen_data_path_list:\n",
    "    print(f'WARNING: no *_labeled.npz files found for {eval_mode!r} in {result_folder_path}')\n",
    "\n",
    "result_dict = defaultdict(list)\n",
    "uniform_joint = np.full((2, 2), 0.25)  # every (y, z) group equally likely\n",
    "uniform_z = np.array([0.5, 0.5])  # balanced attribute z\n",
    "\n",
    "# for each gen_data, evaluate group-balance metrics\n",
    "for data_path in gen_data_path_list:\n",
    "    print('Current gen_data: ', data_path)\n",
    "\n",
    "    # load data once; the context manager closes the underlying .npz file\n",
    "    with np.load(data_path) as npz:\n",
    "        data_y = npz['data_y']\n",
    "        data_z = npz['data_z']\n",
    "\n",
    "    # categorize groups: counts[y, z] -> empirical joint distribution\n",
    "    counts = group_counts(data_y, data_z)\n",
    "    group_dict = {f'{y}{z}': int(counts[y, z]) for y in (0, 1) for z in (0, 1)}\n",
    "    data_distrib = counts / counts.sum()\n",
    "    print(group_dict)\n",
    "\n",
    "    # KL(data || uniform) over the four (y, z) groups\n",
    "    kl_base = kl_divergence(data_distrib, uniform_joint)\n",
    "    result_dict['kl_to_uniform'].append(round(kl_base, 3))\n",
    "    print(f'kl_to_uniform: {kl_base:.3f}')\n",
    "\n",
    "    # L2 distance between the marginal P(z) (sum over y, axis 0) and uniform\n",
    "    fd_base = float(np.linalg.norm(uniform_z - data_distrib.sum(axis=0)))\n",
    "    result_dict['fairness_discrepancy'].append(round(fd_base, 3))\n",
    "    print(f'fairness discrepancy: {fd_base:.3f}')\n",
    "\n",
    "\n",
    "# result folder: diversity/<name of the parent directory of result_folder_path>\n",
    "norm_result_path = os.path.normpath(result_folder_path)\n",
    "result_file_folder = os.path.join('diversity', os.path.basename(os.path.dirname(norm_result_path)))\n",
    "os.makedirs(result_file_folder, exist_ok=True)\n",
    "\n",
    "# save results (np.std is the population std, ddof=0)\n",
    "savename = os.path.basename(norm_result_path)\n",
    "with open(os.path.join(result_file_folder, f'{eval_mode}_{savename}.txt'), 'w') as f:\n",
    "    for k, v in result_dict.items():\n",
    "        f.write(f'Result for {k}: {v}\\n')\n",
    "        f.write(f'\\tmean: {np.mean(v):.3f}\\n')\n",
    "        f.write(f'\\tstd: {np.std(v):.3f}\\n')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "# Per-run metric dicts, pasted in from three separate evaluation runs.\n",
    "d1 = {\"Acc\": 0.9168999791145325, \"Acc_Z_0\": 0.7007556675062973, \"Acc_Z_1\": 0.9704304429195258, \"DP_diff\": 0.17957304785894207, \"EO_Y0_diff\": 0.040503840839788446, \"EO_Y1_diff\": 0.3945021533299998, \"EqOdds_diff\": 0.3945021533299998}\n",
    "d2 = {\"Acc\": 0.8946999907493591, \"Acc_Z_0\": 0.5863979848866498, \"Acc_Z_1\": 0.9710542732376793, \"DP_diff\": 0.25562493702770783, \"EO_Y0_diff\": 0.05655399934743992, \"EO_Y1_diff\": 0.5638740584221816, \"EqOdds_diff\": 0.5638740584221816}\n",
    "d3 = {\"Acc\": 0.9176999926567078, \"Acc_Z_0\": 0.7032745591939547, \"Acc_Z_1\": 0.9708047411104179, \"DP_diff\": 0.17440125944584384, \"EO_Y0_diff\": 0.04391044619803185, \"EO_Y1_diff\": 0.38753568721067144, \"EqOdds_diff\": 0.38753568721067144}\n",
    "\n",
    "# Aggregate the 'DP_diff' metric (presumably the demographic-parity gap) across runs.\n",
    "# NOTE: np.std defaults to ddof=0, i.e. the population std, not the sample std.\n",
    "runs = (d1, d2, d3)\n",
    "dp_diffs = [run['DP_diff'] for run in runs]\n",
    "\n",
    "dp_mean, dp_std = np.mean(dp_diffs), np.std(dp_diffs)\n",
    "print(dp_mean)\n",
    "print(dp_std)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "openai",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.17"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
