{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "'''\n",
    "Evaluate FID statistic with unbiased data, following [Choi et. al, ICML'20]\n",
    "\n",
    "References:\n",
    "    - Dingfan Chen, GS-WGAN, 2020, https://github.com/DingfanChen/GS-WGAN/blob/main/evaluation/eval_fid.py\n",
    "    - Maximilian Seitzer, Python Package pytroch-fid, 2020, https://github.com/mseitzer/pytorch-fid\n",
    "\n",
    "Config:\n",
    "    dataset: 'mnist' or 'fmnist'\n",
    "    gpu_num: str indicating GPU device to run.\n",
    "    batch_size: To inference InceptionV3 Net.\n",
    "    target_path: Folder containing trained generators.\n",
    "    num_gen_img: Number of generated images to evaluate.\n",
    "    bias_factor: 'z' or 'y' or 'multi'\n",
    "    select_runs: Whether to select top FID runs.\n",
    "'''\n",
    "\n",
    "import torch\n",
    "import os\n",
    "import numpy as np\n",
    "\n",
    "\n",
    "# config ==============\n",
    "dataset = 'fmnist'\n",
    "# dataset = 'fmnist'\n",
    "\n",
    "# target_model = 'datalens'\n",
    "target_model = 'gpate'\n",
    "\n",
    "result_folder_path = '/home/.../pfgan_hub/G-PATE/fashion_mnist_binary_eps10_results'\n",
    "# result_folder_path = '/home/.../pfgan_hub/G-PATE/large_celebA_eps1_results'\n",
    "\n",
    "eval_mode = 'ours'\n",
    "# eval_mode = 'base'\n",
    "\n",
    "\n",
    "gpu_num = '0'\n",
    "only_y = False\n",
    "#========================\n",
    "random_seed = 0\n",
    "batch_size = 100\n",
    "\n",
    "\n",
    "# random seed\n",
    "torch.manual_seed(random_seed)\n",
    "np.random.seed(random_seed)\n",
    "\n",
    "# environment\n",
    "os.environ['CUDA_VISIBLE_DEVICES'] = gpu_num\n",
    "\n",
    "if dataset == 'small_celeba':\n",
    "    dpath = '/home/.../pfgan_hub/dataset/celebA/train_celeba_gender_32x32.npz'\n",
    "    real_data = np.load(dpath)['data_x']\n",
    "    real_label = np.load(dpath)['data_y']\n",
    "\n",
    "    # balance data\n",
    "    major_idx = np.where(real_label == 1)[0]\n",
    "    minor_idx = np.where(real_label == 0)[0]\n",
    "\n",
    "    balanced_num = min(len(major_idx), len(minor_idx))\n",
    "    balanced_id = np.concatenate([major_idx[:balanced_num], minor_idx[:balanced_num]])\n",
    "\n",
    "    real_data = real_data[balanced_id].transpose((0, 3, 1, 2))\n",
    "    real_label = real_label[balanced_id]\n",
    "    real_z = None\n",
    "    # minor, major\n",
    "    digit_list = [0, 1]\n",
    "    img_size = 32\n",
    "\n",
    "elif dataset == 'large_celeba':\n",
    "    dpath = '/home/.../pfgan_hub/dataset/celebA/train_celeba_gender_64x64.npz'\n",
    "    real_data = np.load(dpath)['data_x']\n",
    "    real_label = np.load(dpath)['data_y']\n",
    "\n",
    "    # balance data\n",
    "    major_idx = np.where(real_label == 1)[0]\n",
    "    minor_idx = np.where(real_label == 0)[0]\n",
    "\n",
    "    balanced_num = min(len(major_idx), len(minor_idx))\n",
    "    balanced_id = np.concatenate([major_idx[:balanced_num], minor_idx[:balanced_num]])\n",
    "\n",
    "    real_data = real_data[balanced_id].transpose((0, 3, 1, 2))\n",
    "    real_label = real_label[balanced_id]\n",
    "    real_z = None\n",
    "    # minor, major\n",
    "    digit_list = [0, 1]\n",
    "    img_size = 64\n",
    "\n",
    "# if dataset == 'mnist':\n",
    "#     dpath = '/home/.../PF-GAN/dataset/rotated/mnist_31_Unbiased'\n",
    "#     real_data = torch.load(os.path.join(dpath, 'train_data.pt'))\n",
    "#     real_label = torch.load(os.path.join(dpath, 'train_Y.pt'))\n",
    "#     real_z = torch.load(os.path.join(dpath, 'train_A.pt'))\n",
    "#     # minor, major\n",
    "#     digit_list = [1, 3]\n",
    "\n",
    "elif dataset == 'fmnist':\n",
    "    dpath = '/home/.../nas/PF-GAN/dataset/rotated/fmnist_71_Unbiased'\n",
    "    real_data = torch.load(os.path.join(dpath, 'train_data.pt')).unsqueeze(1).numpy()\n",
    "    real_label = torch.load(os.path.join(dpath, 'train_Y.pt')).numpy()\n",
    "    real_z = torch.load(os.path.join(dpath, 'train_A.pt')).numpy()\n",
    "    # minor, major\n",
    "    digit_list = [1, 7]\n",
    "# else:\n",
    "#     raise NotImplementedError\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "'''\n",
    "prepare InceptionV3 model\n",
    "'''\n",
    "\n",
    "\n",
    "from pytorch_fid.inception import InceptionV3\n",
    "\n",
    "\n",
    "block_idx = InceptionV3.BLOCK_INDEX_BY_DIM[2048]\n",
    "model = InceptionV3([block_idx])\n",
    "load_model = model.cuda()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "'''\n",
    "get statistic of activation\n",
    "'''\n",
    "\n",
    "from tqdm import tqdm\n",
    "from torch.nn.functional import adaptive_avg_pool2d\n",
    "import sys\n",
    "\n",
    "sys.path.insert(0, '../../GS-WGAN/source/')\n",
    "\n",
    "STAT_DIR = './stats'\n",
    "\n",
    "# ========= Functions ====================================\n",
    "def mkdir(dir):\n",
    "    if not os.path.exists(dir):\n",
    "        os.makedirs(dir)\n",
    "\n",
    "\n",
    "def get_act(model, batch_size, gen_data):\n",
    "    '''\n",
    "    Given InceptionV3 model, get statistic of gen_data. \n",
    "    Note gen_data should have size ( * , 28, 28, 1), type ndarray, and normalized from 0 to 1. (for binary)\n",
    "    Returns:\n",
    "        mean, cov\n",
    "    '''\n",
    "    model.eval()\n",
    "    \n",
    "    if gen_data.shape[0] < batch_size:\n",
    "        print(f'Group Size({gen_data.shape[0]}) is smaller than batch size({batch_size})')\n",
    "        n_batches = 1\n",
    "        n_used_imgs = gen_data.shape[0]\n",
    "        smaller_flag = True\n",
    "\n",
    "    else:\n",
    "        n_batches = gen_data.shape[0] // batch_size\n",
    "        n_used_imgs = n_batches * batch_size\n",
    "        smaller_flag = False\n",
    "\n",
    "\n",
    "    pred_arr = np.empty((n_used_imgs, 2048))\n",
    "    for i in tqdm(range(n_batches)):\n",
    "        if smaller_flag:\n",
    "            start = 0\n",
    "            end = batch_size = gen_data.shape[0]\n",
    "            images = gen_data[start:end]\n",
    "        else:\n",
    "            start = i * batch_size\n",
    "            end = start + batch_size\n",
    "            images = gen_data[start:end]\n",
    "\n",
    "        if images.shape[1] != 3:\n",
    "            images = images.transpose((0, 3, 1, 2))\n",
    "            images = np.tile(images, [1, 3, 1, 1])\n",
    "\n",
    "        batch = torch.from_numpy(images).type(torch.FloatTensor).cuda()\n",
    "        pred = model(batch)[0]\n",
    "\n",
    "        if pred.shape[2] != 1 or pred.shape[3] != 1:\n",
    "            pred = adaptive_avg_pool2d(pred, output_size=(1, 1))\n",
    "\n",
    "        pred_arr[start:end] = pred.cpu().data.numpy().reshape(batch_size, -1)\n",
    "\n",
    "    mu = np.mean(pred_arr, axis=0)\n",
    "    sigma = np.cov(pred_arr, rowvar=False)\n",
    "\n",
    "    return mu, sigma\n",
    "# =======================================================\n",
    "\n",
    "\n",
    "# check whether stats are pre-exist\n",
    "# stat_file = os.path.join(STAT_DIR, dataset, bias_factor, 'stat.npz')\n",
    "\n",
    "if not only_y:\n",
    "    stat_file = os.path.join(STAT_DIR, dataset, 'stat.npz')\n",
    "    if not os.path.exists(stat_file):\n",
    "        print('Computing statistic.')\n",
    "\n",
    "        ## Save real statistics\n",
    "        mkdir(os.path.join(STAT_DIR, dataset))\n",
    "\n",
    "        # note real data has shape [bs, 1, 28, 28], while gen data has [bs, 28, 28, 1]\n",
    "        real_data = real_data.view(-1, 28, 28, 1)\n",
    "        real_data = real_data / 255.0\n",
    "        real_data = real_data.numpy()\n",
    "\n",
    "        # get stats of all groups\n",
    "        minor, major = digit_list\n",
    "\n",
    "        idx_cln_3 = (real_label == major) & (real_z == 1)\n",
    "        idx_rot_3 = (real_label == major) & (real_z == 0)\n",
    "        idx_cln_1 = (real_label == minor) & (real_z == 1)\n",
    "        idx_rot_1 = (real_label == minor) & (real_z == 0)\n",
    "\n",
    "        m_real_all, s_real_all = get_act(model, batch_size, real_data)\n",
    "        m_real_cln_3, s_real_cln_3 = get_act(model, batch_size, real_data[idx_cln_3])\n",
    "        m_real_rot_3, s_real_rot_3 = get_act(model, batch_size, real_data[idx_rot_3])\n",
    "        m_real_cln_1, s_real_cln_1 = get_act(model, batch_size, real_data[idx_cln_1])\n",
    "        m_real_rot_1, s_real_rot_1 = get_act(model, batch_size, real_data[idx_rot_1])\n",
    "\n",
    "        np.savez(stat_file, mu_all= m_real_all, sigma_all=s_real_all, \\\n",
    "            mu_cln_3 = m_real_cln_3, sigma_cln_3 = s_real_cln_3, \n",
    "            mu_rot_3 = m_real_rot_3, sigma_rot_3 = s_real_rot_3,\n",
    "            mu_cln_1 = m_real_cln_1, sigma_cln_1 = s_real_cln_1,\n",
    "            mu_rot_1 = m_real_rot_1, sigma_rot_1 = s_real_rot_1)\n",
    "        \n",
    "\n",
    "    else:\n",
    "        ## Load pre-computed statistics\n",
    "        print('Loaded pre-computed statistic.')\n",
    "        f = np.load(stat_file)\n",
    "        \n",
    "\n",
    "        m_real_all, s_real_all = f['mu_all'][:], f['sigma_all'][:]\n",
    "        m_real_cln_3, s_real_cln_3 = f['mu_cln_3'][:], f['sigma_cln_3'][:]\n",
    "        m_real_rot_3, s_real_rot_3 = f['mu_rot_3'][:], f['sigma_rot_3'][:]\n",
    "        m_real_cln_1, s_real_cln_1 = f['mu_cln_1'][:], f['sigma_cln_1'][:]\n",
    "        m_real_rot_1, s_real_rot_1 = f['mu_rot_1'][:], f['sigma_rot_1'][:]\n",
    "\n",
    "else:\n",
    "    stat_file = os.path.join(STAT_DIR, dataset, 'stat.npz')\n",
    "\n",
    "    if not os.path.exists(stat_file):\n",
    "        print('Computing statistic.')\n",
    "\n",
    "        ## Save real statistics\n",
    "        mkdir(os.path.join(STAT_DIR, dataset))\n",
    "\n",
    "        # note real data has shape [bs, 1, 28, 28], while gen data has [bs, 28, 28, 1]\n",
    "\n",
    "        # get stats of all groups\n",
    "        minor, major = digit_list\n",
    "\n",
    "        idx_major = (real_label == major).squeeze()\n",
    "        idx_minor = (real_label == major).squeeze()\n",
    "\n",
    "        m_real_all, s_real_all = get_act(model, batch_size, real_data)\n",
    "        m_real_major, s_real_major = get_act(model, batch_size, real_data[idx_major])\n",
    "        m_real_minor, s_real_minor = get_act(model, batch_size, real_data[idx_minor])\n",
    "\n",
    "        np.savez(stat_file, mu_all= m_real_all, sigma_all=s_real_all, \\\n",
    "            mu_major = m_real_major, sigma_major = s_real_major, \n",
    "            mu_minor = m_real_minor, sigma_minor = s_real_minor)\n",
    "        \n",
    "    else:\n",
    "        ## Load pre-computed statistics\n",
    "        print('Loaded pre-computed statistic.')\n",
    "        f = np.load(stat_file)\n",
    "\n",
    "        m_real_all, s_real_all = f['mu_all'][:], f['sigma_all'][:]\n",
    "        m_real_major, s_real_major = f['mu_major'][:], f['sigma_major'][:]\n",
    "        m_real_minor, s_real_minor = f['mu_minor'][:], f['sigma_minor'][:]\n",
    "\n",
    "    \n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "'''\n",
    "compute fid of gen_data_list\n",
    "'''\n",
    "\n",
    "from pytorch_fid.fid_score import calculate_frechet_distance\n",
    "import numpy as np\n",
    "from collections import defaultdict\n",
    "\n",
    "\n",
    "# get gen data path list\n",
    "gen_data_path_list = []\n",
    "for x in os.listdir(result_folder_path):\n",
    "    if eval_mode in x:\n",
    "        for y in os.listdir(os.path.join(result_folder_path, x)):\n",
    "            if only_y:\n",
    "                if '.data' in y:\n",
    "                    gen_data_path_list.append(os.path.join(result_folder_path, x, y))\n",
    "            else:\n",
    "                if '_labeled.npz' in y:\n",
    "                    gen_data_path_list.append(os.path.join(result_folder_path, x, y))\n",
    "\n",
    "print(f\"Evaluating {eval_mode}: \")\n",
    "for gen_data_path in gen_data_path_list:\n",
    "    print(gen_data_path)\n",
    "\n",
    "\n",
    "fid_values = defaultdict(list)\n",
    "# for each gen_data, evaluate FIDs\n",
    "for gen_data_path in gen_data_path_list:\n",
    "    print(\"Current gen_data: \", gen_data_path)\n",
    "\n",
    "    if not only_y:\n",
    "        # load gen_data, gen_data_y\n",
    "        if target_model == 'gpate' or target_model == 'datalens':\n",
    "            gen_data = np.load(gen_data_path)\n",
    "\n",
    "            gen_data_x = gen_data['data_x'][:60000] / 255.0\n",
    "            gen_data_x = gen_data_x.reshape(-1, 28, 28, 1)\n",
    "            gen_data_y = gen_data['data_y'][:60000]\n",
    "            gen_data_z = gen_data['data_z'][:60000]\n",
    "        else:\n",
    "            raise ValueError\n",
    "\n",
    "        # overall fid\n",
    "        m_gen_all, s_gen_all = get_act(model, batch_size, gen_data_x)\n",
    "        fid_value_all = calculate_frechet_distance(m_real_all, s_real_all, m_gen_all, s_gen_all)\n",
    "        print(\"fid_value_all: \", fid_value_all)\n",
    "        fid_values['overall'].append(np.round(fid_value_all, 3))\n",
    "\n",
    "        minor, major = digit_list\n",
    "\n",
    "        idx_cln_3 = (gen_data_y == major) & (gen_data_z == 1)\n",
    "        idx_rot_3 = (gen_data_y == major) & (gen_data_z == 0)\n",
    "        idx_cln_1 = (gen_data_y == minor) & (gen_data_z == 1)\n",
    "        idx_rot_1 = (gen_data_y == minor) & (gen_data_z == 0)\n",
    "\n",
    "        m_gen_cln_3, s_gen_cln_3 = get_act(model, batch_size, gen_data_x[idx_cln_3])\n",
    "        m_gen_rot_3, s_gen_rot_3 = get_act(model, batch_size, gen_data_x[idx_rot_3])\n",
    "        m_gen_cln_1, s_gen_cln_1 = get_act(model, batch_size, gen_data_x[idx_cln_1])\n",
    "        m_gen_rot_1, s_gen_rot_1 = get_act(model, batch_size, gen_data_x[idx_rot_1])\n",
    "\n",
    "        fid_value_cln_3 = calculate_frechet_distance(m_real_cln_3, s_real_cln_3, m_gen_cln_3, s_gen_cln_3)\n",
    "        fid_value_rot_3 = calculate_frechet_distance(m_real_rot_3, s_real_rot_3, m_gen_rot_3, s_gen_rot_3)\n",
    "        fid_value_cln_1 = calculate_frechet_distance(m_real_cln_1, s_real_cln_1, m_gen_cln_1, s_gen_cln_1)\n",
    "        fid_value_rot_1 = calculate_frechet_distance(m_real_rot_1, s_real_rot_1, m_gen_rot_1, s_gen_rot_1)\n",
    "\n",
    "        print('fid value for major class: {:.3f}(Z=1) {:.3f}(Z=0)'.format(fid_value_cln_3, fid_value_rot_3))\n",
    "        print('fid value for minor class: {:.3f}(Z=1) {:.3f}(Z=0)'.format(fid_value_cln_1, fid_value_rot_1))\n",
    "\n",
    "        fid_values['fid_major_z1'].append(fid_value_cln_3)\n",
    "        fid_values['fid_major_z0'].append(fid_value_rot_3)\n",
    "        fid_values['fid_minor_z1'].append(fid_value_cln_1)\n",
    "        fid_values['fid_minor_z0'].append(fid_value_rot_1)\n",
    "\n",
    "\n",
    "\n",
    "    else:\n",
    "        import joblib\n",
    "\n",
    "        for i in range(1):\n",
    "            gen_data = joblib.load(gen_data_path)\n",
    "            gen_data_x, gen_data_y = np.hsplit(gen_data, [-2])\n",
    "    \n",
    "            random_indices = np.random.choice(len(gen_data_x), 60000, replace=False).tolist()\n",
    "            gen_data_x = gen_data_x[random_indices]\n",
    "            gen_data_y = gen_data_y[random_indices]\n",
    "\n",
    "            gen_data_x = gen_data_x.reshape(-1, 3, img_size, img_size)\n",
    "            gen_data_y = np.argmax(gen_data_y, axis=1)\n",
    "\n",
    "            # overall fid\n",
    "            m_gen_all, s_gen_all = get_act(model, batch_size, gen_data_x)\n",
    "            fid_value_all = calculate_frechet_distance(m_real_all, s_real_all, m_gen_all, s_gen_all)\n",
    "            print(\"fid_value_all: \", fid_value_all)\n",
    "            fid_values['overall'].append(np.round(fid_value_all, 3))\n",
    "            \n",
    "            # get group indices for gen_data\n",
    "            group_indices = {}\n",
    "            for group in [0, 1]:\n",
    "                indices = np.where((gen_data_y == group))[0]\n",
    "                group_indices[group] = indices.tolist()\n",
    "\n",
    "            for group, indices in group_indices.items():\n",
    "                print(\"Y = \", group, \"\\tlen: \", len(indices))\n",
    "                m_gen, s_gen = get_act(model, batch_size, gen_data_x[indices])\n",
    "                fid_value = calculate_frechet_distance(m_real_all, s_real_all, m_gen, s_gen)\n",
    "                fid_values[group].append(np.round(fid_value, 2))\n",
    "                print(\"group: \", group, \"fid_value: \", fid_value)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# result folder\n",
    "result_file_folder = os.path.join('fid', result_folder_path.split('/')[-2])\n",
    "os.makedirs(result_file_folder, exist_ok = True)\n",
    "\n",
    "# save results\n",
    "savename = result_folder_path.split('/')[-1]\n",
    "with open(os.path.join(result_file_folder, f'{eval_mode}_{savename}.txt'), 'w') as f:\n",
    "    for k, v in fid_values.items():\n",
    "        f.write(f'fid value for {k}: {v}\\n')\n",
    "        f.write(f'\\tmean: {np.mean(v):.3f}\\n')\n",
    "        f.write(f'\\tstd: {np.std(v):.3f}\\n')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "pfgan_new",
   "language": "python",
   "name": "pfgan"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.13"
  },
  "orig_nbformat": 4,
  "vscode": {
   "interpreter": {
    "hash": "fbe2ce3c68f626f38b492c34585b74e32b7dd1d8b67edc9d605ebc5753ef1359"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
