{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "'''\n",
    "Evaluate FID statistic with unbiased data, following [Choi et al., ICML'20]\n",
    "\n",
    "References:\n",
    "    - Dingfan Chen, GS-WGAN, 2020, https://github.com/DingfanChen/GS-WGAN/blob/main/evaluation/eval_fid.py\n",
    "    - Maximilian Seitzer, Python Package pytorch-fid, 2020, https://github.com/mseitzer/pytorch-fid\n",
    "\n",
    "Config:\n",
    "    dataset: 'mnist' or 'fmnist'\n",
    "    gpu_num: str indicating GPU device to run (exported as CUDA_VISIBLE_DEVICES).\n",
    "    random_seed: Seed for the torch and numpy RNGs.\n",
    "    batch_size: To inference InceptionV3 Net.\n",
    "    target_path: Folder containing trained generators.\n",
    "    num_gen_img: Number of generated images to evaluate.\n",
    "    num_eval_img: Size of the saved evaluation set (file gen_data_{num_eval_img}.npz).\n",
    "    bias_factor: 'z' or 'y' or 'multi'\n",
    "    select_runs: Whether to select top FID runs.\n",
    "    STAT_DIR: Folder where real-data activation statistics are cached.\n",
    "'''\n",
    "\n",
    "import torch\n",
    "import os\n",
    "import numpy as np\n",
    "\n",
    "\n",
    "# config ==============\n",
    "dataset = 'mnist'\n",
    "gpu_num = '5'\n",
    "random_seed = 0\n",
    "batch_size = 100\n",
    "target_path = '/home/.../PF-GAN/models/GS-WGAN/results/fmnist/main/cond_B_sir_small'\n",
    "num_gen_img = 20000\n",
    "num_eval_img = 400\n",
    "select_runs = False\n",
    "top_k = 3\n",
    "model_iters = None\n",
    "save_eval = True\n",
    "save_gen = True\n",
    "uncond = False\n",
    "# cache folder for real statistics; defined here because stat_file below depends on it\n",
    "STAT_DIR = './stats'\n",
    "#========================\n",
    "\n",
    "\n",
    "# random seed\n",
    "torch.manual_seed(random_seed)\n",
    "np.random.seed(random_seed)\n",
    "\n",
    "# environment\n",
    "# only takes effect if CUDA has not been initialized yet in this kernel\n",
    "os.environ['CUDA_VISIBLE_DEVICES'] = gpu_num\n",
    "stat_file = os.path.join(STAT_DIR, dataset, 'stat_original_multiclass.npz')\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "'''\n",
    "prepare InceptionV3 model\n",
    "'''\n",
    "\n",
    "from pytorch_fid.inception import InceptionV3\n",
    "\n",
    "# request the 2048-dimensional feature block\n",
    "block_idx = InceptionV3.BLOCK_INDEX_BY_DIM[2048]\n",
    "\n",
    "# nn.Module.cuda() moves the parameters in place and returns the same module,\n",
    "# so `model` and `load_model` are one object\n",
    "model = InceptionV3([block_idx]).cuda()\n",
    "load_model = model\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torchvision\n",
    "import torchvision.transforms as transforms\n",
    "\n",
    "# load real training data for the configured dataset ('mnist' or 'fmnist')\n",
    "dataset_classes = {'mnist': torchvision.datasets.MNIST, 'fmnist': torchvision.datasets.FashionMNIST}\n",
    "if dataset not in dataset_classes:\n",
    "    raise ValueError(f\"Unsupported dataset: {dataset!r} (expected 'mnist' or 'fmnist')\")\n",
    "train_data = dataset_classes[dataset](\"./\", train=True, transform=transforms.ToTensor(), download=True)\n",
    "# .data holds the raw images without the transform; add a channel axis -> (N, 1, 28, 28) float\n",
    "real_data = train_data.data.unsqueeze(1).float()\n",
    "real_label = train_data.targets.double()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "'''\n",
    "get statistic of activation\n",
    "'''\n",
    "\n",
    "from tqdm import tqdm\n",
    "from torch.nn.functional import adaptive_avg_pool2d\n",
    "import sys\n",
    "\n",
    "\n",
    "STAT_DIR = './stats'\n",
    "\n",
    "# ========= Functions ====================================\n",
    "def mkdir(dir):\n",
    "    '''Create `dir` (and any missing parents) unless it already exists.'''\n",
    "    os.makedirs(dir, exist_ok=True)\n",
    "\n",
    "\n",
    "def get_act(model, batch_size, gen_data, device='cuda'):\n",
    "    '''\n",
    "    Given an InceptionV3 model, compute the mean and covariance of its activations on gen_data.\n",
    "\n",
    "    Args:\n",
    "        model: network whose output is indexable, with [0] being a (B, 2048, H, W) feature map.\n",
    "        batch_size: Images per forward pass. Only full batches are used, so the trailing\n",
    "            len(gen_data) % batch_size images are dropped; a group smaller than batch_size\n",
    "            is processed whole as a single batch.\n",
    "        gen_data: ndarray of shape (N, 28, 28, 1) (or already (N, 3, H, W)), normalized to [0, 1].\n",
    "        device: device the input batches are moved to (default 'cuda').\n",
    "    Returns:\n",
    "        mu: (2048,) mean activation.\n",
    "        sigma: (2048, 2048) activation covariance.\n",
    "    Raises:\n",
    "        ValueError: if gen_data contains no images.\n",
    "    '''\n",
    "    n_imgs = gen_data.shape[0]\n",
    "    if n_imgs == 0:\n",
    "        raise ValueError('get_act received an empty group; FID is undefined for 0 images.')\n",
    "\n",
    "    if n_imgs < batch_size:\n",
    "        print(f'Group Size({n_imgs}) is smaller than batch size({batch_size})')\n",
    "        step = n_imgs  # one batch containing the whole group\n",
    "    else:\n",
    "        step = batch_size\n",
    "    n_batches = n_imgs // step\n",
    "    n_used_imgs = n_batches * step\n",
    "\n",
    "    model.eval()\n",
    "    pred_arr = np.empty((n_used_imgs, 2048))\n",
    "    with torch.no_grad():  # inference only; no autograd graph needed\n",
    "        for i in tqdm(range(n_batches)):\n",
    "            start = i * step\n",
    "            end = start + step\n",
    "            images = gen_data[start:end]\n",
    "\n",
    "            # NHWC single-channel -> NCHW, then replicate the channel to get 3-channel input\n",
    "            if images.shape[1] != 3:\n",
    "                images = images.transpose((0, 3, 1, 2))\n",
    "                images = np.tile(images, [1, 3, 1, 1])\n",
    "\n",
    "            batch = torch.from_numpy(images).type(torch.FloatTensor).to(device)\n",
    "            pred = model(batch)[0]\n",
    "\n",
    "            # pool to 1x1 if the selected block does not already output 1x1 maps\n",
    "            if pred.shape[2] != 1 or pred.shape[3] != 1:\n",
    "                pred = adaptive_avg_pool2d(pred, output_size=(1, 1))\n",
    "\n",
    "            pred_arr[start:end] = pred.cpu().numpy().reshape(step, -1)\n",
    "\n",
    "    mu = np.mean(pred_arr, axis=0)\n",
    "    sigma = np.cov(pred_arr, rowvar=False)\n",
    "\n",
    "    return mu, sigma\n",
    "# =======================================================\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Compute real-data activation statistics per digit (0--9) and over all data,\n",
    "# unless a cached stat file already exists.\n",
    "# stat_file = os.path.join(STAT_DIR, dataset, bias_factor, 'stat.npz')\n",
    "\n",
    "if not os.path.exists(stat_file):\n",
    "    print('Computing statistic.')\n",
    "\n",
    "    ## Save real statistics\n",
    "    mkdir(os.path.join(STAT_DIR, dataset))\n",
    "\n",
    "    # real data has shape [N, 1, 28, 28] while get_act expects [N, 28, 28, 1]; with a single\n",
    "    # channel this view keeps the memory layout. A new name is used instead of overwriting\n",
    "    # `real_data`, so the cell can be re-run safely.\n",
    "    real_data_np = (real_data.view(-1, 28, 28, 1) / 255.0).numpy()\n",
    "    # numpy labels give numpy boolean masks that index real_data_np directly\n",
    "    real_label_np = real_label.numpy()\n",
    "\n",
    "    mu_dict = {}\n",
    "    sigma_dict = {}\n",
    "\n",
    "    # stats of each digit\n",
    "    for digit in range(0, 10):\n",
    "        idx_digit = (real_label_np == digit)\n",
    "        m_digit, s_digit = get_act(model, batch_size, real_data_np[idx_digit])\n",
    "        mu_dict[digit] = m_digit\n",
    "        sigma_dict[digit] = s_digit\n",
    "\n",
    "    # stats of all data\n",
    "    m_real_all, s_real_all = get_act(model, batch_size, real_data_np)\n",
    "    mu_dict['all'] = m_real_all\n",
    "    sigma_dict['all'] = s_real_all\n",
    "\n",
    "    np.savez(stat_file, mu_dict=mu_dict, sigma_dict=sigma_dict)\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "## Load pre-computed statistics\n",
    "# The dicts were saved as pickled object arrays, hence allow_pickle=True (load trusted files only).\n",
    "# Using the NpzFile as a context manager closes the underlying file handle.\n",
    "with np.load(stat_file, allow_pickle=True) as stat:\n",
    "    mu_dict = stat['mu_dict'].item()\n",
    "    sigma_dict = stat['sigma_dict'].item()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "'''\n",
    "compute fid of gen_data_list\n",
    "'''\n",
    "from pytorch_fid.fid_score import calculate_frechet_distance\n",
    "import numpy as np\n",
    "\n",
    "# config =====\n",
    "target_model = 'gswgan'\n",
    "gen_data_path = \"/home/.../pfgan_hub/GS-WGAN/results/mnist/main/multiclass_downsized_4/gen_data.npz\"\n",
    "savename = f\"mnist_downsized_4_{target_model}_nodp\"\n",
    "# ==========\n",
    "\n",
    "# load gen_data, gen_label\n",
    "if target_model == 'gswgan':\n",
    "    # read both arrays, then close the .npz file handle\n",
    "    with np.load(gen_data_path) as gen_data:\n",
    "        gen_data_x = gen_data['data_x']\n",
    "        gen_data_y = gen_data['data_y']\n",
    "else:\n",
    "    raise ValueError(f\"Unsupported target_model: {target_model!r} (only 'gswgan' is implemented)\")\n",
    "\n",
    "# # compute all fid\n",
    "# m_gen_all, s_gen_all = get_act(model, batch_size, gen_data_x)\n",
    "# fid_value_all = calculate_frechet_distance(mu_dict['all'], sigma_dict['all'], m_gen_all, s_gen_all)\n",
    "# print('fid value for all class: {:.3f}'.format(fid_value_all))\n",
    "\n",
    "# # compute fid for each class\n",
    "# fid_dict = {}\n",
    "# for digit in range(0, 10):\n",
    "#     idx_digit = (gen_data_y == digit)\n",
    "#     m_gen, s_gen = get_act(model, batch_size, gen_data_x[idx_digit])\n",
    "#     fid_value = calculate_frechet_distance(mu_dict[digit], sigma_dict[digit], m_gen, s_gen)\n",
    "#     print('fid value for class {}: {:.3f}'.format(digit, fid_value))\n",
    "#     fid_dict[digit] = fid_value\n",
    "\n",
    "# # save fid dict\n",
    "# fid_dict['all'] = fid_value_all\n",
    "    \n",
    "# fid_save_path = os.path.join(STAT_DIR, f'fid_{savename}.npz')\n",
    "# np.savez(fid_save_path, fid_dict=fid_dict)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Per-class FID dictionary saved by an earlier run (a pickled dict inside the .npz).\n",
    "fid_dict_path = \"/home/.../pfgan_hub/stats/fid_mnist_downsized_4_gswgan.npz\"\n",
    "fid_dict_new = np.load(fid_dict_path, allow_pickle=True)['fid_dict'].item()\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Mean per-class FID over every digit except 8; the 'all' key is the overall FID, not a class.\n",
    "except_8 = fid_dict_new.keys() - [8, 'all']\n",
    "# sum the entries of fid_dict_new itself (fid_dict is not defined by the active code)\n",
    "fid = sum(fid_dict_new[k] for k in except_8)\n",
    "fid / len(except_8)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "\n",
    "import sys\n",
    "import logging\n",
    "from pytorch_fid.fid_score import calculate_frechet_distance\n",
    "import matplotlib.pyplot as plt\n",
    "import copy\n",
    "import torch\n",
    "\n",
    "sys.path.insert(0, './GS-WGAN/source/')\n",
    "from models import GeneratorResNet\n",
    "from utils import mkdir, savefig\n",
    "\n",
    "# functions ====================================================\n",
    "def _generate_data(netG, save_dir, num_gen_img=60000):\n",
    "    '''\n",
    "    Sample class-balanced images from a conditional generator and save them to an .npz file.\n",
    "\n",
    "    Every batch of 100 uses labels 0..9 repeated 10 times, so each class is generated evenly.\n",
    "    Runs on the module-level `device`.\n",
    "\n",
    "    Args:\n",
    "        netG: conditional generator, called as netG(noise, label).\n",
    "        save_dir: path of the output .npz file (keys data_x, data_y).\n",
    "        num_gen_img: number of images; rounded down to a multiple of the batch size (100).\n",
    "    Returns:\n",
    "        data_x: ndarray of shape (N, 28, 28, 1) with the generated images.\n",
    "        data_y: ndarray of shape (N,) with the conditioning labels.\n",
    "    Raises:\n",
    "        ValueError: if num_gen_img is smaller than one batch.\n",
    "    '''\n",
    "    bs = 100\n",
    "    num_classes = 10\n",
    "    z_dim = 10\n",
    "    iter_gen = num_gen_img // bs\n",
    "    if iter_gen == 0:\n",
    "        raise ValueError(f'num_gen_img must be at least {bs}, got {num_gen_img}')\n",
    "\n",
    "    netG.eval()\n",
    "    netG.to(device)\n",
    "    data_x = []\n",
    "    data_y = []\n",
    "\n",
    "    # generate data\n",
    "    with torch.no_grad():\n",
    "        # NOTE(review): latent codes are Bernoulli(0.5) in {0, 1}, not Gaussian -- confirm this matches training\n",
    "        bernoulli = torch.distributions.Bernoulli(torch.tensor([0.5]))\n",
    "        label = torch.arange(num_classes).repeat(bs // num_classes).to(device)\n",
    "        for _ in range(iter_gen):\n",
    "            noise = bernoulli.sample((bs, z_dim)).view(bs, z_dim).to(device)\n",
    "            samples = netG(noise, label)\n",
    "            data_x.append(samples.view(bs, 1, 28, 28).cpu().numpy())\n",
    "            data_y.append(label.cpu().numpy())\n",
    "\n",
    "    data_x = np.concatenate(data_x)\n",
    "    data_x = np.transpose(data_x, [0, 2, 3, 1])  # NCHW -> NHWC\n",
    "    data_y = np.concatenate(data_y)\n",
    "\n",
    "    np.savez_compressed(save_dir, data_x=data_x, data_y=data_y)\n",
    "    del noise, label\n",
    "    torch.cuda.empty_cache()\n",
    "\n",
    "    return data_x, data_y\n",
    "\n",
    "\n",
    "# When CUDA_VISIBLE_DEVICES exposes a single GPU, torch numbers it cuda:0, so a physical\n",
    "# ordinal such as cuda:5 does not exist; use the current CUDA device instead.\n",
    "device = torch.device('cuda')\n",
    "G_dict = torch.load('/home/.../pfgan_hub/GS-WGAN/results/mnist/main/multiclass_downsized_4_nodp/netGS_10000.pth', map_location=device)\n",
    "base_G = GeneratorResNet(z_dim=10, model_dim=64, num_classes=10).to(device)\n",
    "base_G.load_state_dict(G_dict)\n",
    "\n",
    "gen_data_x, gen_data_y = _generate_data(base_G, \"/home/.../pfgan_hub/GS-WGAN/results/mnist/main/multiclass_downsized_4_nodp/gen_data.npz\")\n",
    "\n",
    "\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "# logging\n",
    "# basicConfig only sets up root console output here; filemode is ignored without a filename\n",
    "logging.basicConfig(level=logging.INFO)\n",
    "formatter = logging.Formatter('%(message)s')\n",
    "\n",
    "### Save results\n",
    "save_dir = os.path.join('eval', dataset)\n",
    "os.makedirs(save_dir, exist_ok=True)\n",
    "\n",
    "\n",
    "# compute fids and subgroup fids\n",
    "# normpath strips a trailing '/', which would otherwise give '' (the root logger's name)\n",
    "folder_name = os.path.basename(os.path.normpath(target_path))\n",
    "curr_logger = logging.getLogger(folder_name)\n",
    "# getLogger returns the same logger on every re-run of this cell; drop file handlers added\n",
    "# by earlier runs so each message is written to the log file only once\n",
    "for old_handler in list(curr_logger.handlers):\n",
    "    if isinstance(old_handler, logging.FileHandler):\n",
    "        curr_logger.removeHandler(old_handler)\n",
    "        old_handler.close()\n",
    "file_handler = logging.FileHandler(os.path.join(save_dir, f'FID_{folder_name}.txt'))\n",
    "file_handler.setFormatter(formatter)\n",
    "curr_logger.addHandler(file_handler)\n",
    "\n",
    "fid_all_list = []\n",
    "# if bias_factor == 'multi':\n",
    "fid_major_z1_list = []\n",
    "fid_major_z0_list = []\n",
    "fid_minor_z1_list = []\n",
    "fid_minor_z0_list = []\n",
    "\n",
    "# elif bias_factor == 'z':\n",
    "#     fid_z1_list = []\n",
    "#     fid_z0_list = []\n",
    "\n",
    "seed_list = os.listdir(target_path)\n",
    "seed_list.sort()\n",
    "\n",
    "for seed_name in seed_list:\n",
    "\n",
    "    # 'backup' is not a run folder; FIXME: seed_6 is excluded by hand\n",
    "    if seed_name in ('backup', 'seed_6'):\n",
    "        continue\n",
    "\n",
    "    seed_idx = seed_name.split('_')[-1]\n",
    "    curr_dir = os.path.join(target_path, f'seed_{seed_idx}')\n",
    "\n",
    "    # seed_idx = seed_name.split('_')[-2]\n",
    "    # curr_dir = os.path.join(target_path, f'seed_{seed_idx}_osir')\n",
    "\n",
    "    # select runs with best FID\n",
    "    if select_runs:\n",
    "        # generate samples to compute the overall FID\n",
    "        if not os.path.exists(os.path.join(curr_dir, 'netGS.pth')):\n",
    "            print(f\"Cannot find seed={seed_idx} generator!\")\n",
    "            continue\n",
    "\n",
    "        print(f\"Generating data for seed={seed_idx} generator...\")\n",
    "        # NOTE(review): generate_data and digit_list are not defined in this notebook;\n",
    "        # they must be provided before select_runs=True can work.\n",
    "        gen_data, gen_label = generate_data(digit_list,\n",
    "                                            curr_dir,\n",
    "                                            save_eval=save_eval,\n",
    "                                            num_eval_img=num_eval_img,\n",
    "                                            save_gen=save_gen,\n",
    "                                            num_gen_img=num_gen_img,\n",
    "                                            iters=model_iters)\n",
    "    # evaluate FID w.r.t. bias groups\n",
    "    else:\n",
    "        if uncond or model_iters:\n",
    "            gen_data_dir = os.path.join(curr_dir, f'gen_data_{model_iters}')\n",
    "        else:\n",
    "            gen_data_dir = os.path.join(curr_dir, 'gen_data')\n",
    "        gen_data_file = os.path.join(gen_data_dir, f'gen_data_{num_eval_img}.npz')\n",
    "        if not os.path.exists(gen_data_file):\n",
    "            print(f\"No files in folder {curr_dir}\")\n",
    "            continue\n",
    "\n",
    "        with np.load(gen_data_file) as gen_npz:\n",
    "            gen_data = gen_npz['data_x']\n",
    "        gen_label = torch.load(os.path.join(gen_data_dir, f'gen_data_{num_eval_img}_label.pt'))\n",
    "        gen_z = torch.load(os.path.join(gen_data_dir, f'gen_data_{num_eval_img}_z.pt'))\n",
    "\n",
    "    # if bias_factor == 'multi':\n",
    "    # NOTE(review): digit_list = (minor, major) is not defined in this notebook -- define it before running.\n",
    "    minor, major = digit_list\n",
    "\n",
    "    # Overall FID against the loaded real statistics. m_real_all / s_real_all only exist when\n",
    "    # the statistics were recomputed in this kernel session, so they are not used here.\n",
    "    m_gen_all, s_gen_all = get_act(model, batch_size, gen_data)\n",
    "    fid_value_all = calculate_frechet_distance(mu_dict['all'], sigma_dict['all'], m_gen_all, s_gen_all)\n",
    "    infostr = 'fid value for all class: {:.3f}'.format(fid_value_all)\n",
    "    fid_all_list.append(fid_value_all)\n",
    "\n",
    "    if not select_runs:\n",
    "        # subgroups: class (major/minor) x attribute z (1 -> 'cln', 0 -> 'rot')\n",
    "        idx_cln_3 = (gen_label == major) & (gen_z == 1)\n",
    "        idx_rot_3 = (gen_label == major) & (gen_z == 0)\n",
    "        idx_cln_1 = (gen_label == minor) & (gen_z == 1)\n",
    "        idx_rot_1 = (gen_label == minor) & (gen_z == 0)\n",
    "\n",
    "        m_gen_cln_3, s_gen_cln_3 = get_act(model, batch_size, gen_data[idx_cln_3])\n",
    "        m_gen_rot_3, s_gen_rot_3 = get_act(model, batch_size, gen_data[idx_rot_3])\n",
    "        m_gen_cln_1, s_gen_cln_1 = get_act(model, batch_size, gen_data[idx_cln_1])\n",
    "        m_gen_rot_1, s_gen_rot_1 = get_act(model, batch_size, gen_data[idx_rot_1])\n",
    "\n",
    "        # NOTE(review): the real subgroup statistics m_real_{cln,rot}_{1,3} / s_real_* are not\n",
    "        # computed anywhere in this notebook.\n",
    "        fid_value_cln_3 = calculate_frechet_distance(m_real_cln_3, s_real_cln_3, m_gen_cln_3, s_gen_cln_3)\n",
    "        fid_value_rot_3 = calculate_frechet_distance(m_real_rot_3, s_real_rot_3, m_gen_rot_3, s_gen_rot_3)\n",
    "        fid_value_cln_1 = calculate_frechet_distance(m_real_cln_1, s_real_cln_1, m_gen_cln_1, s_gen_cln_1)\n",
    "        fid_value_rot_1 = calculate_frechet_distance(m_real_rot_1, s_real_rot_1, m_gen_rot_1, s_gen_rot_1)\n",
    "\n",
    "        infostr_3 = 'fid value for major class: {:.3f}(Z=1) {:.3f}(Z=0)'.format(fid_value_cln_3, fid_value_rot_3)\n",
    "        infostr_1 = 'fid value for minor class: {:.3f}(Z=1) {:.3f}(Z=0)'.format(fid_value_cln_1, fid_value_rot_1)\n",
    "\n",
    "        fid_major_z1_list.append(fid_value_cln_3)\n",
    "        fid_major_z0_list.append(fid_value_rot_3)\n",
    "        fid_minor_z1_list.append(fid_value_cln_1)\n",
    "        fid_minor_z0_list.append(fid_value_rot_1)\n",
    "    \n",
    "    # elif bias_factor == 'z':\n",
    "\n",
    "    #     m_gen_all, s_gen_all = get_act(model, batch_size, gen_data)\n",
    "    #     fid_value_all = calculate_frechet_distance(m_real_all, s_real_all, m_gen_all, s_gen_all)\n",
    "    #     infostr = 'fid value for all class: {:.3f}'.format(fid_value_all)\n",
    "    #     fid_all_list.append(fid_value_all)\n",
    "\n",
    "    #     if not select_runs:\n",
    "    #         idx_cln = (gen_z == 1)\n",
    "    #         idx_rot = (gen_z == 0)\n",
    "\n",
    "    #         zero_minor = False\n",
    "    #         if gen_data[idx_rot].shape[0]== 0:\n",
    "    #             zero_minor = True\n",
    "\n",
    "    #         m_gen_cln, s_gen_cln = get_act(model, batch_size, gen_data[idx_cln])\n",
    "    #         if not zero_minor:\n",
    "    #             m_gen_rot, s_gen_rot = get_act(model, batch_size, gen_data[idx_rot])\n",
    "\n",
    "    #         fid_value_cln = calculate_frechet_distance(m_real_cln, s_real_cln, m_gen_cln, s_gen_cln)\n",
    "    #         fid_value_rot = calculate_frechet_distance(m_real_rot, s_real_rot, m_gen_rot, s_gen_rot) if not zero_minor else 0\n",
    "    \n",
    "    #         infostr_3 = 'fid value for Z=1: {:.3f}'.format(fid_value_cln)\n",
    "    #         infostr_1 = 'fid value for Z=0: {:.3f}'.format(fid_value_rot)\n",
    "\n",
    "    #         fid_z1_list.append(fid_value_cln)\n",
    "    #         if not zero_minor:\n",
    "    #             fid_z0_list.append(fid_value_rot)\n",
    "\n",
    "    curr_logger.info(f'[seed {seed_idx}]=========================')\n",
    "    curr_logger.info(infostr)\n",
    "    if not select_runs:\n",
    "        curr_logger.info(infostr_1)\n",
    "        curr_logger.info(infostr_3)\n",
    "\n",
    "\n",
    "def _log_fid_summary(logger, title, fid_values):\n",
    "    '''Log the mean and std of fid_values under a 'fid value for {title}: ' header.\n",
    "\n",
    "    np.mean / np.std of an empty list return nan (with a RuntimeWarning), so an empty\n",
    "    list is reported explicitly instead.\n",
    "    '''\n",
    "    logger.info(f'fid value for {title}: ')\n",
    "    if not fid_values:\n",
    "        logger.info('\\tno runs evaluated')\n",
    "        return\n",
    "    logger.info(f'\\tmean: {np.mean(fid_values):.3f}')\n",
    "    logger.info(f'\\tstd: {np.std(fid_values):.3f}')\n",
    "\n",
    "\n",
    "curr_logger.info('=' * 30)\n",
    "_log_fid_summary(curr_logger, 'all class', fid_all_list)\n",
    "\n",
    "if not select_runs:\n",
    "    # if bias_factor == 'multi':\n",
    "    for title, fid_values in [('Major(Z=1)', fid_major_z1_list),\n",
    "                              ('Major(Z=0)', fid_major_z0_list),\n",
    "                              ('Minor(Z=1)', fid_minor_z1_list),\n",
    "                              ('Minor(Z=0)', fid_minor_z0_list)]:\n",
    "        _log_fid_summary(curr_logger, title, fid_values)\n",
    "\n",
    "    # elif bias_factor == 'z':\n",
    "    #     _log_fid_summary(curr_logger, 'Z=1', fid_z1_list)\n",
    "    #     _log_fid_summary(curr_logger, 'Z=0', fid_z0_list)\n",
    "\n",
    "    #     if len(fid_z0_list) != len(fid_z1_list):\n",
    "    #         fid_z1_list = fid_z1_list[:len(fid_z0_list)]\n",
    "    #     pass\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# NOTE(review): disabled top-k run selection and folder cleanup. Before re-enabling:\n",
    "#   - fid_all_list only gets entries for seeds that were actually evaluated (skipped folders\n",
    "#     are not appended), so its positions may not line up with seed_list indices.\n",
    "#   - the shutil.rmtree / os.remove lines delete files irreversibly.\n",
    "#   - the cleanup loop rebinds `folder_name`, the same name used to build the FID logger.\n",
    "# import shutil\n",
    "\n",
    "# # get top k seeds with lowest FID \n",
    "\n",
    "# if select_runs:\n",
    "#     selected_idx = sorted(range(len(fid_all_list)), key=lambda i: fid_all_list[i])[:top_k]\n",
    "#     selected_seed = [seed_list[i] for i in selected_idx]\n",
    "#     fid_topk_list = [fid_all_list[i] for i in selected_idx]\n",
    "\n",
    "#     # stat of top k seeds\n",
    "#     top_k_str = ', '.join(selected_seed)\n",
    "#     curr_logger.info(f'fid value for top {top_k} = {top_k_str}:')\n",
    "#     curr_logger.info(f'\\tmean: {np.mean(fid_topk_list):.3f}')\n",
    "#     curr_logger.info(f'\\tstd: {np.std(fid_topk_list):.3f}')\n",
    "\n",
    "#     # # clean folders\n",
    "#     for folder_name in seed_list:\n",
    "#         # Remove the folders that are not seed_i where i is in min_indices\n",
    "#         if folder_name not in selected_seed:\n",
    "#             print(\"remove: \", folder_name)\n",
    "#             # shutil.rmtree(os.path.join(target_path, folder_name))\n",
    "#         else:\n",
    "#             for item in os.listdir(os.path.join(target_path, folder_name)):\n",
    "#                 if item not in ['gen_data', 'netGS.pth', 'params.txt', 'samples_20000.png']:\n",
    "#                     pass\n",
    "#                     # os.remove(os.path.join(target_path, folder_name, item))\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# NOTE(review): disabled FID bar plot. `fids` is not defined anywhere in this notebook; presumably\n",
    "# it should hold one FID per entry of `labels` -- TODO confirm. Each bar is drawn as [val, 0],\n",
    "# so only the first dataset in `dset` gets a visible value.\n",
    "# import matplotlib.pyplot as plt\n",
    "# %matplotlib inline\n",
    "\n",
    "# x = np.arange(2)\n",
    "# dset = ['MNIST-Rotate', 'FMNIST-Rotate']\n",
    "# labels = ['Original', 'ImpSIR(perc=1.0)', 'ImpSIR(perc=0.1)']\n",
    "# colors = ['red', 'blue', 'green']\n",
    "# bar_width = 0.2\n",
    "\n",
    "# bar_loc = 0\n",
    "# for idx, val in enumerate(fids):\n",
    "#     plt.bar(x + bar_loc, [val, 0], bar_width, alpha=0.5, color = colors[idx], label = labels[idx])\n",
    "#     bar_loc += bar_width\n",
    "# plt.xticks(x + bar_width, dset)\n",
    "\n",
    "# plt.xlabel('Dataset')\n",
    "# plt.ylabel('FID')\n",
    "# plt.legend()\n",
    "# plt.show()\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "pfgan_new",
   "language": "python",
   "name": "pfgan"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.13"
  },
  "orig_nbformat": 4,
  "vscode": {
   "interpreter": {
    "hash": "fbe2ce3c68f626f38b492c34585b74e32b7dd1d8b67edc9d605ebc5753ef1359"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
