{"nbformat":4,"nbformat_minor":0,"metadata":{"colab":{"name":"FID_evaluation_cifar10.ipynb","provenance":[],"collapsed_sections":[],"authorship_tag":"ABX9TyMwGBoG+gLpblB7WnIQ8jSe"},"kernelspec":{"name":"python3","display_name":"Python 3"},"language_info":{"name":"python"},"accelerator":"GPU"},"cells":[{"cell_type":"markdown","source":["To evaluate the model using official implementation of FID score, please upload the two files:\n","\n","'best_model_cifar10.pth' and\n","\n","'fid_stats_cifar10_train.npz' \n","\n","to the content folder :"],"metadata":{"id":"rzIT_XKHJ9M7"}},{"cell_type":"code","source":["MODEL_PATH = '/content/best_model_cifar10.pth'# best performing generator trained with Lv.6 Adam\n","STATS_PATH = '/content/fid_stats_cifar10_train.npz'# training set statistics"],"metadata":{"id":"pk6dPV7dKpUB","executionInfo":{"status":"ok","timestamp":1653531832993,"user_tz":240,"elapsed":132,"user":{"displayName":"Zichu Liu","userId":"15111204870225912773"}}},"execution_count":1,"outputs":[]},{"cell_type":"code","execution_count":2,"metadata":{"id":"fCvLIUp3o0qp","colab":{"base_uri":"https://localhost:8080/"},"executionInfo":{"status":"ok","timestamp":1653531874499,"user_tz":240,"elapsed":40031,"user":{"displayName":"Zichu Liu","userId":"15111204870225912773"}},"outputId":"266dd37e-0ff5-4325-bf72-078b6f223bca"},"outputs":[{"output_type":"stream","name":"stdout","text":["Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/\n","Collecting tensorflow==1.14\n","  Downloading tensorflow-1.14.0-cp37-cp37m-manylinux1_x86_64.whl (109.3 MB)\n","\u001b[K     |████████████████████████████████| 109.3 MB 49 kB/s \n","\u001b[?25hCollecting tensorboard<1.15.0,>=1.14.0\n","  Downloading tensorboard-1.14.0-py3-none-any.whl (3.1 MB)\n","\u001b[K     |████████████████████████████████| 3.1 MB 57.8 MB/s \n","\u001b[?25hRequirement already satisfied: google-pasta>=0.1.6 in /usr/local/lib/python3.7/dist-packages (from tensorflow==1.14) 
(0.2.0)\n","Collecting keras-applications>=1.0.6\n","  Downloading Keras_Applications-1.0.8-py3-none-any.whl (50 kB)\n","\u001b[K     |████████████████████████████████| 50 kB 7.5 MB/s \n","\u001b[?25hRequirement already satisfied: numpy<2.0,>=1.14.5 in /usr/local/lib/python3.7/dist-packages (from tensorflow==1.14) (1.21.6)\n","Collecting tensorflow-estimator<1.15.0rc0,>=1.14.0rc0\n","  Downloading tensorflow_estimator-1.14.0-py2.py3-none-any.whl (488 kB)\n","\u001b[K     |████████████████████████████████| 488 kB 28.7 MB/s \n","\u001b[?25hRequirement already satisfied: grpcio>=1.8.6 in /usr/local/lib/python3.7/dist-packages (from tensorflow==1.14) (1.46.1)\n","Requirement already satisfied: keras-preprocessing>=1.0.5 in /usr/local/lib/python3.7/dist-packages (from tensorflow==1.14) (1.1.2)\n","Requirement already satisfied: astor>=0.6.0 in /usr/local/lib/python3.7/dist-packages (from tensorflow==1.14) (0.8.1)\n","Requirement already satisfied: protobuf>=3.6.1 in /usr/local/lib/python3.7/dist-packages (from tensorflow==1.14) (3.17.3)\n","Requirement already satisfied: gast>=0.2.0 in /usr/local/lib/python3.7/dist-packages (from tensorflow==1.14) (0.5.3)\n","Requirement already satisfied: wheel>=0.26 in /usr/local/lib/python3.7/dist-packages (from tensorflow==1.14) (0.37.1)\n","Requirement already satisfied: six>=1.10.0 in /usr/local/lib/python3.7/dist-packages (from tensorflow==1.14) (1.15.0)\n","Requirement already satisfied: wrapt>=1.11.1 in /usr/local/lib/python3.7/dist-packages (from tensorflow==1.14) (1.14.1)\n","Requirement already satisfied: termcolor>=1.1.0 in /usr/local/lib/python3.7/dist-packages (from tensorflow==1.14) (1.1.0)\n","Requirement already satisfied: absl-py>=0.7.0 in /usr/local/lib/python3.7/dist-packages (from tensorflow==1.14) (1.0.0)\n","Requirement already satisfied: h5py in /usr/local/lib/python3.7/dist-packages (from keras-applications>=1.0.6->tensorflow==1.14) (3.1.0)\n","Requirement already satisfied: werkzeug>=0.11.15 in 
/usr/local/lib/python3.7/dist-packages (from tensorboard<1.15.0,>=1.14.0->tensorflow==1.14) (1.0.1)\n","Requirement already satisfied: markdown>=2.6.8 in /usr/local/lib/python3.7/dist-packages (from tensorboard<1.15.0,>=1.14.0->tensorflow==1.14) (3.3.7)\n","Requirement already satisfied: setuptools>=41.0.0 in /usr/local/lib/python3.7/dist-packages (from tensorboard<1.15.0,>=1.14.0->tensorflow==1.14) (57.4.0)\n","Requirement already satisfied: importlib-metadata>=4.4 in /usr/local/lib/python3.7/dist-packages (from markdown>=2.6.8->tensorboard<1.15.0,>=1.14.0->tensorflow==1.14) (4.11.3)\n","Requirement already satisfied: zipp>=0.5 in /usr/local/lib/python3.7/dist-packages (from importlib-metadata>=4.4->markdown>=2.6.8->tensorboard<1.15.0,>=1.14.0->tensorflow==1.14) (3.8.0)\n","Requirement already satisfied: typing-extensions>=3.6.4 in /usr/local/lib/python3.7/dist-packages (from importlib-metadata>=4.4->markdown>=2.6.8->tensorboard<1.15.0,>=1.14.0->tensorflow==1.14) (4.2.0)\n","Requirement already satisfied: cached-property in /usr/local/lib/python3.7/dist-packages (from h5py->keras-applications>=1.0.6->tensorflow==1.14) (1.5.2)\n","Installing collected packages: tensorflow-estimator, tensorboard, keras-applications, tensorflow\n","  Attempting uninstall: tensorflow-estimator\n","    Found existing installation: tensorflow-estimator 2.8.0\n","    Uninstalling tensorflow-estimator-2.8.0:\n","      Successfully uninstalled tensorflow-estimator-2.8.0\n","  Attempting uninstall: tensorboard\n","    Found existing installation: tensorboard 2.8.0\n","    Uninstalling tensorboard-2.8.0:\n","      Successfully uninstalled tensorboard-2.8.0\n","  Attempting uninstall: tensorflow\n","    Found existing installation: tensorflow 2.8.0+zzzcolab20220506162203\n","    Uninstalling tensorflow-2.8.0+zzzcolab20220506162203:\n","      Successfully uninstalled tensorflow-2.8.0+zzzcolab20220506162203\n","\u001b[31mERROR: pip's dependency resolver does not currently take into account all 
the packages that are installed. This behaviour is the source of the following dependency conflicts.\n","kapre 0.3.7 requires tensorflow>=2.0.0, but you have tensorflow 1.14.0 which is incompatible.\u001b[0m\n","Successfully installed keras-applications-1.0.8 tensorboard-1.14.0 tensorflow-1.14.0 tensorflow-estimator-1.14.0\n"]}],"source":["!pip install tensorflow==1.14"]},{"cell_type":"code","source":["%%writefile parsing.py\n","import argparse\n","import torch\n","from copy import deepcopy\n","import random\n","import numpy as np\n","import os\n","import torch.nn as nn\n","from tqdm import tqdm\n","from imageio import imsave\n","\n","\n","class GenBlock(nn.Module):\n","    def __init__(self, in_channels, out_channels, hidden_channels=None, ksize=3, pad=1,\n","                 activation=nn.ReLU(), upsample=False, n_classes=0):\n","        super(GenBlock, self).__init__()\n","        self.activation = activation\n","        self.upsample = upsample\n","        self.learnable_sc = in_channels != out_channels or upsample\n","        hidden_channels = out_channels if hidden_channels is None else hidden_channels\n","        self.n_classes = n_classes\n","        self.c1 = nn.Conv2d(in_channels, hidden_channels, kernel_size=ksize, padding=pad)\n","        self.c2 = nn.Conv2d(hidden_channels, out_channels, kernel_size=ksize, padding=pad)\n","\n","        self.b1 = nn.BatchNorm2d(in_channels)\n","        self.b2 = nn.BatchNorm2d(hidden_channels)\n","        if self.learnable_sc:\n","            self.c_sc = nn.Conv2d(in_channels, out_channels, kernel_size=1, padding=0)\n","\n","    def upsample_conv(self, x, conv):\n","        return conv(nn.UpsamplingNearest2d(scale_factor=2)(x))\n","\n","    def residual(self, x):\n","        h = x\n","        h = self.b1(h)\n","        h = self.activation(h)\n","        h = self.upsample_conv(h, self.c1) if self.upsample else self.c1(h)\n","        h = self.b2(h)\n","        h = self.activation(h)\n","        h = self.c2(h)\n","        
return h\n","\n","    def shortcut(self, x):\n","        if self.learnable_sc:\n","            x = self.upsample_conv(x, self.c_sc) if self.upsample else self.c_sc(x)\n","            return x\n","        else:\n","            return x\n","\n","    def forward(self, x):\n","        return self.residual(x) + self.shortcut(x)\n","\n","class Generator(nn.Module):\n","    def __init__(self, args, activation=nn.ReLU(), n_classes=0):\n","        super(Generator, self).__init__()\n","        self.bottom_width = args.bottom_width\n","        self.activation = activation\n","        self.n_classes = n_classes\n","        self.ch = args.gf_dim\n","        self.l1 = nn.Linear(args.latent_dim, (self.bottom_width ** 2) * self.ch)\n","        self.block2 = GenBlock(self.ch, self.ch, activation=activation, upsample=True, n_classes=n_classes)\n","        self.block3 = GenBlock(self.ch, self.ch, activation=activation, upsample=True, n_classes=n_classes)\n","        self.block4 = GenBlock(self.ch, self.ch, activation=activation, upsample=True, n_classes=n_classes)\n","        self.b5 = nn.BatchNorm2d(self.ch)\n","        self.c5 = nn.Conv2d(self.ch, 3, kernel_size=3, stride=1, padding=1)\n","\n","    def forward(self, z):\n","\n","        h = z\n","        h = self.l1(h).view(-1, self.ch, self.bottom_width, self.bottom_width)\n","        h = self.block2(h)\n","        h = self.block3(h)\n","        h = self.block4(h)\n","        h = self.b5(h)\n","        h = self.activation(h)\n","        h = nn.Tanh()(self.c5(h))\n","        return h\n","\n","\n","\"\"\"Discriminator\"\"\"\n","\n","\n","def _downsample(x):\n","    # Downsample (Mean Avg Pooling with 2x2 kernel)\n","    return nn.AvgPool2d(kernel_size=2)(x)\n","\n","\n","class OptimizedDisBlock(nn.Module):\n","    def __init__(self, args, in_channels, out_channels, ksize=3, pad=1, activation=nn.ReLU()):\n","        super(OptimizedDisBlock, self).__init__()\n","        self.activation = activation\n","\n","        self.c1 = 
nn.Conv2d(in_channels, out_channels, kernel_size=ksize, padding=pad)\n","        self.c2 = nn.Conv2d(out_channels, out_channels, kernel_size=ksize, padding=pad)\n","        self.c_sc = nn.Conv2d(in_channels, out_channels, kernel_size=1, padding=0)\n","        if args.d_spectral_norm:\n","            self.c1 = nn.utils.spectral_norm(self.c1)\n","            self.c2 = nn.utils.spectral_norm(self.c2)\n","            self.c_sc = nn.utils.spectral_norm(self.c_sc)\n","\n","    def residual(self, x):\n","        h = x\n","        h = self.c1(h)\n","        h = self.activation(h)\n","        h = self.c2(h)\n","        h = _downsample(h)\n","        return h\n","\n","    def shortcut(self, x):\n","        return self.c_sc(_downsample(x))\n","\n","    def forward(self, x):\n","        return self.residual(x) + self.shortcut(x)\n","\n","\n","class DisBlock(nn.Module):\n","    def __init__(self, args, in_channels, out_channels, hidden_channels=None, ksize=3, pad=1,\n","                 activation=nn.ReLU(), downsample=False):\n","        super(DisBlock, self).__init__()\n","        self.activation = activation\n","        self.downsample = downsample\n","        self.learnable_sc = (in_channels != out_channels) or downsample\n","        hidden_channels = in_channels if hidden_channels is None else hidden_channels\n","        self.c1 = nn.Conv2d(in_channels, hidden_channels, kernel_size=ksize, padding=pad)\n","        self.c2 = nn.Conv2d(hidden_channels, out_channels, kernel_size=ksize, padding=pad)\n","        if args.d_spectral_norm:\n","            self.c1 = nn.utils.spectral_norm(self.c1)\n","            self.c2 = nn.utils.spectral_norm(self.c2)\n","\n","        if self.learnable_sc:\n","            self.c_sc = nn.Conv2d(in_channels, out_channels, kernel_size=1, padding=0)\n","            if args.d_spectral_norm:\n","                self.c_sc = nn.utils.spectral_norm(self.c_sc)\n","\n","    def residual(self, x):\n","        h = x\n","        h = self.activation(h)\n","     
   h = self.c1(h)\n","        h = self.activation(h)\n","        h = self.c2(h)\n","        if self.downsample:\n","            h = _downsample(h)\n","        return h\n","\n","    def shortcut(self, x):\n","        if self.learnable_sc:\n","            x = self.c_sc(x)\n","            if self.downsample:\n","                return _downsample(x)\n","            else:\n","                return x\n","        else:\n","            return x\n","\n","    def forward(self, x):\n","        return self.residual(x) + self.shortcut(x)\n","\n","\n","class Discriminator(nn.Module):\n","    def __init__(self, args, activation=nn.ReLU()):\n","        super(Discriminator, self).__init__()\n","        self.ch = args.df_dim\n","        self.activation = activation\n","        self.block1 = OptimizedDisBlock(args, 3, self.ch)\n","        self.block2 = DisBlock(args, self.ch, self.ch, activation=activation, downsample=True)\n","        self.block3 = DisBlock(args, self.ch, self.ch, activation=activation, downsample=False)\n","        self.block4 = DisBlock(args, self.ch, self.ch, activation=activation, downsample=False)\n","        self.l5 = nn.Linear(self.ch, 1, bias=False)\n","        if args.d_spectral_norm:\n","            self.l5 = nn.utils.spectral_norm(self.l5)\n","\n","    def forward(self, x):\n","        h = x\n","        h = self.block1(h)\n","        h = self.block2(h)\n","        h = self.block3(h)\n","        h = self.block4(h)\n","        h = self.activation(h)\n","        # Global average pooling\n","        h = h.sum(2).sum(2)\n","        output = self.l5(h)\n","\n","        return output\n","\n","def str2bool(v):\n","    if v.lower() in ('yes', 'true', 't', 'y', '1'):\n","        return True\n","    elif v.lower() in ('no', 'false', 'f', 'n', '0'):\n","        return False\n","    else:\n","        raise argparse.ArgumentTypeError('Boolean value expected.')\n","\n","\n","def parse_args():\n","    parser = argparse.ArgumentParser()\n","    parser.add_argument(\n"," 
       '--max_epoch',\n","        type=int,\n","        default=1200,\n","        help='number of epochs of training')\n","    parser.add_argument(\n","        '--extra_steps',\n","        type=int,\n","        default=4,\n","        help='number of extrapolations')\n","    parser.add_argument(\n","        '--max_iter',\n","        type=int,\n","        default=50000,\n","        help='set the max iteration number')\n","    parser.add_argument(\n","        '-gen_bs',\n","        '--gen_batch_size',\n","        type=int,\n","        default=128,\n","        help='size of the batches')\n","    parser.add_argument(\n","        '-dis_bs',\n","        '--dis_batch_size',\n","        type=int,\n","        default=128,\n","        help='size of the batches')\n","    parser.add_argument(\n","        '--lr_g',\n","        type=float,\n","        default=0.00004,\n","        help='adam: gen learning rate')\n","    parser.add_argument(\n","        '--lr_d',\n","        type=float,\n","        default=0.0002,\n","        help='adam: disc learning rate')\n","    parser.add_argument(\n","        '--lr_decay',\n","        action='store_true',\n","        help='learning rate decay or not')\n","    parser.add_argument(\n","        '--beta1',\n","        type=float,\n","        default=0.0,\n","        help='adam: decay of first order momentum of gradient')\n","    parser.add_argument(\n","        '--beta2',\n","        type=float,\n","        default=0.90,\n","        help='adam: decay of second order momentum of gradient')\n","    parser.add_argument(\n","        '--ema',\n","        type=float,\n","        default=0.9999,\n","        help='exponential moving average weights')\n","    parser.add_argument(\n","        '--num_workers',\n","        type=int,\n","        default=8,\n","        help='number of cpu threads to use during batch generation')\n","    parser.add_argument(\n","        '--ha_alpha',\n","        type=int,\n","        default=1.5,\n","        help='number of cpu 
threads to use during batch generation')\n","    parser.add_argument(\n","        '--ha_beta',\n","        type=int,\n","        default=0.5,\n","        help='number of cpu threads to use during batch generation')\n","    parser.add_argument(\n","        '--lr_r',\n","        type=int,\n","        default=2e-3,\n","        help='number of cpu threads to use during batch generation')\n","    parser.add_argument(\n","        '--latent_dim',\n","        type=int,\n","        default=128,\n","        help='dimensionality of the latent space')\n","    parser.add_argument(\n","        '--img_size',\n","        type=int,\n","        default=32,\n","        help='size of each image dimension')\n","    parser.add_argument(\n","        '--channels',\n","        type=int,\n","        default=3,\n","        help='number of image channels')\n","    parser.add_argument(\n","        '--n_critic',\n","        type=int,\n","        default=1,\n","        help='number of training steps for discriminator per iter')\n","    parser.add_argument(\n","        '--val_freq',\n","        type=int,\n","        default=1,\n","        help='interval between each validation')\n","    parser.add_argument(\n","        '--print_freq',\n","        type=int,\n","        default=300,\n","        help='interval between each verbose')\n","    parser.add_argument(\n","        '--load_path',\n","        default='/home/',\n","        type=str,\n","        help='The reload model path')\n","    parser.add_argument(\n","        '--alg',\n","        type=str,\n","        default='alt_sppm',\n","        help='The optimization algorithm')\n","    parser.add_argument(\n","        '--exp_name',\n","        type=str,\n","        default='cifar10_adam',\n","        help='The name of exp')\n","    parser.add_argument(\n","        '--d_spectral_norm',\n","        type=str2bool,\n","        default=True,\n","        help='add spectral_norm on discriminator?')\n","    parser.add_argument(\n","        
'--g_spectral_norm',\n","        type=str2bool,\n","        default=False,\n","        help='add spectral_norm on generator?')\n","    parser.add_argument(\n","        '--dataset',\n","        type=str,\n","        default='cifar10',\n","        help='dataset type')\n","    parser.add_argument(\n","        '--data_path',\n","        type=str,\n","        default='./data',\n","        help='The path of data set')\n","    parser.add_argument('--init_type', type=str, default='xavier_uniform',\n","                        choices=['normal', 'orth', 'xavier_uniform', 'false'],\n","                        help='The init type')\n","    parser.add_argument('--gf_dim', type=int, default=256,\n","                        help='The base channel num of gen')\n","    parser.add_argument('--df_dim', type=int, default=128,\n","                        help='The base channel num of disc')\n","    parser.add_argument(\n","        '--model',\n","        type=str,\n","        default='sngan_cifar10',\n","        help='path of model')\n","    parser.add_argument('--eval_batch_size', type=int, default=100)\n","    parser.add_argument('--num_eval_imgs', type=int, default=50000)\n","    parser.add_argument(\n","        '--bottom_width',\n","        type=int,\n","        default=4,\n","        help=\"the base resolution of the GAN\")\n","    parser.add_argument('--random_seed', type=int, default=666666)\n","    parser.add_argument(\"--valid_only\", default=False, action='store_true')\n","    opt = parser.parse_args()\n","    return opt\n","\n","\n","def copy_params(model):\n","    flatten = deepcopy(list(p.data for p in model.parameters()))\n","    return flatten\n","\n","\n","def load_params(model, new_param):\n","    for p, new_p in zip(model.parameters(), new_param):\n","        p.data.copy_(new_p)\n","\n","def load_path(args,G,D):\n","    print(f'=> resuming from {args.load_path}')\n","    checkpoint_file = args.load_path\n","    print(os.path.exists(checkpoint_file))\n","    if 
os.path.exists(checkpoint_file):\n","        print(checkpoint_file)\n","        checkpoint = torch.load(checkpoint_file)\n","        start_epoch = checkpoint['epoch']\n","        best_is = checkpoint['best_fid']\n","        G.load_state_dict(checkpoint['gen_state_dict'])\n","        D.load_state_dict(checkpoint['dis_state_dict'])\n","        avg_gen_net = deepcopy(G)\n","        avg_gen_net.load_state_dict(checkpoint['avg_gen_state_dict'])\n","        gen_avg_param = copy_params(avg_gen_net)\n","        del avg_gen_net\n","\n","        args.path_helper = checkpoint['path_helper']\n","        print(\"path helper\", args.path_helper)\n","    load_params(G, gen_avg_param)\n","    fid_buffer_dir = os.path.join('fid_buffer')\n","    try:\n","        os.makedirs(fid_buffer_dir)\n","    except:\n","        pass\n","\n","    eval_iter = args.num_eval_imgs // args.eval_batch_size\n","    img_list = list()\n","    for iter_idx in tqdm(range(eval_iter), desc='sample images'):\n","        z = torch.cuda.FloatTensor(np.random.normal(0, 1, (args.eval_batch_size, args.latent_dim)))\n","\n","        # Generate a batch of images\n","        gen_imgs = G(z).mul_(127.5).add_(127.5).clamp_(0.0, 255.0).permute(0, 2, 3, 1).to('cpu',\n","                                                                                                torch.uint8).numpy()\n","        for img_idx, img in enumerate(gen_imgs):\n","            file_name = os.path.join(fid_buffer_dir, f'iter{iter_idx}_b{img_idx}.png')\n","            imsave(file_name, img)\n","        img_list.extend(list(gen_imgs))\n","\n","\n","if __name__ == '__main__':\n","    args = parse_args()\n","    random.seed(args.random_seed)\n","    np.random.seed(args.random_seed)\n","    torch.manual_seed(args.random_seed)\n","    torch.cuda.manual_seed_all(args.random_seed)\n","    torch.backends.cudnn.benchmark = False\n","    torch.backends.cudnn.deterministic = True\n","    G = Generator(args).cuda()\n","    D = Discriminator(args).cuda()\n"," 
   load_path(args,G,D)"],"metadata":{"id":"szAUNz32qJ2D","colab":{"base_uri":"https://localhost:8080/"},"executionInfo":{"status":"ok","timestamp":1653532042088,"user_tz":240,"elapsed":122,"user":{"displayName":"Zichu Liu","userId":"15111204870225912773"}},"outputId":"1ce92100-abc3-47de-a8bd-fb51c5bbf591"},"execution_count":7,"outputs":[{"output_type":"stream","name":"stdout","text":["Overwriting parsing.py\n"]}]},{"cell_type":"code","source":["!python3 parsing.py --load_path /content/best_model_cifar10.pth"],"metadata":{"id":"WrRO8ntQsJPR","colab":{"base_uri":"https://localhost:8080/"},"outputId":"47543ad6-f682-4cc5-e57a-562df78c9795"},"execution_count":null,"outputs":[{"output_type":"stream","name":"stdout","text":["=> resuming from /content/best_model_cifar10.pth\n","True\n","/content/best_model_cifar10.pth\n","path helper {'prefix': 'logs/cifar10_adam_2022_03_18_10_30_35', 'ckpt_path': 'logs/cifar10_adam_2022_03_18_10_30_35/Model', 'log_path': 'logs/cifar10_adam_2022_03_18_10_30_35/Log', 'sample_path': 'logs/cifar10_adam_2022_03_18_10_30_35/Samples'}\n","sample images:  67% 333/500 [00:39<00:17,  9.36it/s]"]}]},{"cell_type":"code","source":["from __future__ import absolute_import, division, print_function\n","import numpy as np\n","import os\n","import gzip, pickle\n","import tensorflow as tf\n","from imageio import imread\n","from scipy import linalg\n","import pathlib\n","import urllib\n","import warnings\n","\n","class InvalidFIDException(Exception):\n","    pass\n","\n","\n","def create_inception_graph(pth):\n","    \"\"\"Creates a graph from saved GraphDef file.\"\"\"\n","    # Creates graph from saved graph_def.pb.\n","    with tf.io.gfile.GFile( pth, 'rb') as f:\n","        graph_def = tf.compat.v1.GraphDef()\n","        graph_def.ParseFromString( f.read())\n","        _ = tf.import_graph_def( graph_def, name='FID_Inception_Net')\n","#-------------------------------------------------------------------------------\n","\n","\n","# code for handling 
inception net derived from\n","#   https://github.com/openai/improved-gan/blob/master/inception_score/model.py\n","def _get_inception_layer(sess):\n","    \"\"\"Prepares inception net for batched usage and returns pool_3 layer. \"\"\"\n","    layername = 'FID_Inception_Net/pool_3:0'\n","    pool3 = sess.graph.get_tensor_by_name(layername)\n","    ops = pool3.graph.get_operations()\n","    for op_idx, op in enumerate(ops):\n","        for o in op.outputs:\n","            shape = o.get_shape()\n","            if shape._dims is not None:\n","              #shape = [s.value for s in shape] TF 1.x\n","              shape = [s for s in shape] #TF 2.x\n","              new_shape = []\n","              for j, s in enumerate(shape):\n","                if s == 1 and j == 0:\n","                  new_shape.append(None)\n","                else:\n","                  new_shape.append(s)\n","              o.__dict__['_shape_val'] = tf.TensorShape(new_shape)\n","    return pool3\n","#-------------------------------------------------------------------------------\n","\n","\n","def get_activations(images, sess, batch_size=50, verbose=False):\n","    \"\"\"Calculates the activations of the pool_3 layer for all images.\n","    Params:\n","    -- images      : Numpy array of dimension (n_images, hi, wi, 3). The values\n","                     must lie between 0 and 256.\n","    -- sess        : current session\n","    -- batch_size  : the images numpy array is split into batches with batch size\n","                     batch_size. 
A reasonable batch size depends on the disposable hardware.\n","    -- verbose    : If set to True and parameter out_step is given, the number of calculated\n","                     batches is reported.\n","    Returns:\n","    -- A numpy array of dimension (num images, 2048) that contains the\n","       activations of the given tensor when feeding inception with the query tensor.\n","    \"\"\"\n","    inception_layer = _get_inception_layer(sess)\n","    n_images = images.shape[0]\n","    if batch_size > n_images:\n","        print(\"warning: batch size is bigger than the data size. setting batch size to data size\")\n","        batch_size = n_images\n","    n_batches = n_images//batch_size # drops the last batch if < batch_size\n","    pred_arr = np.empty((n_batches * batch_size,2048))\n","    for i in range(n_batches):\n","        if verbose:\n","            print(\"\\rPropagating batch %d/%d\" % (i+1, n_batches), end=\"\", flush=True)\n","        start = i*batch_size\n","        \n","        if start+batch_size < n_images:\n","            end = start+batch_size\n","        else:\n","            end = n_images\n","        \n","        batch = images[start:end]\n","        pred = sess.run(inception_layer, {'FID_Inception_Net/ExpandDims:0': batch})\n","        pred_arr[start:end] = pred.reshape(batch.shape[0],-1)\n","    if verbose:\n","        print(\" done\")\n","    return pred_arr\n","#-------------------------------------------------------------------------------\n","\n","\n","def calculate_frechet_distance(mu1, sigma1, mu2, sigma2, eps=1e-6):\n","    \"\"\"Numpy implementation of the Frechet Distance.\n","    The Frechet distance between two multivariate Gaussians X_1 ~ N(mu_1, C_1)\n","    and X_2 ~ N(mu_2, C_2) is\n","            d^2 = ||mu_1 - mu_2||^2 + Tr(C_1 + C_2 - 2*sqrt(C_1*C_2)).\n","            \n","    Stable version by Dougal J. 
Sutherland.\n","    Params:\n","    -- mu1   : The sample mean over activations of the pool_3 layer of the\n","               inception net for generated samples (e.g. the mean returned by\n","               'calculate_activation_statistics').\n","    -- mu2   : The sample mean over activations of the pool_3 layer, precalculated\n","               on a representative data set.\n","    -- sigma1: The covariance matrix over activations of the pool_3 layer for\n","               generated samples.\n","    -- sigma2: The covariance matrix over activations of the pool_3 layer,\n","               precalculated on a representative data set.\n","    Returns:\n","    --   : The Frechet Distance.\n","    \"\"\"\n","\n","    mu1 = np.atleast_1d(mu1)\n","    mu2 = np.atleast_1d(mu2)\n","\n","    sigma1 = np.atleast_2d(sigma1)\n","    sigma2 = np.atleast_2d(sigma2)\n","\n","    assert mu1.shape == mu2.shape, \"Training and test mean vectors have different lengths\"\n","    assert sigma1.shape == sigma2.shape, \"Training and test covariances have different dimensions\"\n","\n","    diff = mu1 - mu2\n","\n","    # product might be almost singular\n","    covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False)\n","    if not np.isfinite(covmean).all():\n","        msg = \"fid calculation produces singular product; adding %s to diagonal of cov estimates\" % eps\n","        warnings.warn(msg)\n","        offset = np.eye(sigma1.shape[0]) * eps\n","        covmean = linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset))\n","\n","    # numerical error might give slight imaginary component\n","    if np.iscomplexobj(covmean):\n","        if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3):\n","            m = np.max(np.abs(covmean.imag))\n","            raise ValueError(\"Imaginary component {}\".format(m))\n","        covmean = covmean.real\n","\n","    tr_covmean = np.trace(covmean)\n","\n","    return diff.dot(diff) + np.trace(sigma1) + np.trace(sigma2) - 2 * 
tr_covmean\n","#-------------------------------------------------------------------------------\n","\n","\n","def calculate_activation_statistics(images, sess, batch_size=50, verbose=False):\n","    \"\"\"Calculation of the statistics used by the FID.\n","    Params:\n","    -- images      : Numpy array of dimension (n_images, hi, wi, 3). The values\n","                     must lie between 0 and 255.\n","    -- sess        : current session\n","    -- batch_size  : the images numpy array is split into batches with batch size\n","                     batch_size. A reasonable batch size depends on the available hardware.\n","    -- verbose     : If set to True and parameter out_step is given, the number of calculated\n","                     batches is reported.\n","    Returns:\n","    -- mu    : The mean over samples of the activations of the pool_3 layer of\n","               the incption model.\n","    -- sigma : The covariance matrix of the activations of the pool_3 layer of\n","               the incption model.\n","    \"\"\"\n","    act = get_activations(images, sess, batch_size, verbose)\n","    mu = np.mean(act, axis=0)\n","    sigma = np.cov(act, rowvar=False)\n","    return mu, sigma\n","    \n","\n","#------------------\n","# The following methods are implemented to obtain a batched version of the activations.\n","# This has the advantage to reduce memory requirements, at the cost of slightly reduced efficiency.\n","# - Pyrestone\n","#------------------\n","\n","\n","def load_image_batch(files):\n","    \"\"\"Convenience method for batch-loading images\n","    Params:\n","    -- files    : list of paths to image files. 
def get_activations_from_files(files, sess, batch_size=50, verbose=False):
    """Calculates the activations of the pool_3 layer for all images, loading
    the image files one batch at a time to keep memory usage low.

    Params:
    -- files       : list of paths to image files. Images need to have same
                     dimensions for all files.
    -- sess        : current session
    -- batch_size  : number of files loaded and propagated per step. If it is
                     larger than len(files) it is reduced to len(files).
                     A reasonable batch size depends on the available hardware.
    -- verbose     : If set to True, the index of the batch being propagated
                     is reported.
    Returns:
    -- A numpy array of dimension (num images, 2048) that contains the
       activations of the given tensor when feeding inception with the query
       tensor. For an empty file list the result has shape (0, 2048).
    """
    inception_layer = _get_inception_layer(sess)
    n_imgs = len(files)
    pred_arr = np.empty((n_imgs, 2048))
    if n_imgs == 0:
        # Nothing to propagate; also avoids dividing by a zero batch size below.
        return pred_arr
    if batch_size > n_imgs:
        print("warning: batch size is bigger than the data size. setting batch size to data size")
        batch_size = n_imgs
    # Ceiling division: a trailing partial batch is counted, but no extra
    # empty batch is scheduled when n_imgs is an exact multiple of batch_size.
    n_batches = (n_imgs + batch_size - 1) // batch_size
    for i in range(n_batches):
        if verbose:
            print("\rPropagating batch %d/%d" % (i+1, n_batches), end="", flush=True)
        start = i * batch_size
        end = min(start + batch_size, n_imgs)

        batch = load_image_batch(files[start:end])
        pred = sess.run(inception_layer, {'FID_Inception_Net/ExpandDims:0': batch})
        # Reshape by the actual number of images in this batch; the last
        # batch may hold fewer than batch_size images.
        pred_arr[start:end] = pred.reshape(end - start, -1)
        del batch  # clean up memory
    if verbose:
        print(" done")
    return pred_arr
def check_or_download_inception(inception_path):
    ''' Checks if the path to the inception file is valid, or downloads
        the file if it is not present.

    Params:
    -- inception_path : directory expected to contain
                        'classify_image_graph_def.pb'; '/tmp' is used when
                        this is None.
    Returns:
    -- The path of the graph definition file as a string.
    '''
    INCEPTION_URL = 'http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz'
    if inception_path is None:
        inception_path = '/tmp'
    inception_path = pathlib.Path(inception_path)
    model_file = inception_path / 'classify_image_graph_def.pb'
    if not model_file.exists():
        print("Downloading Inception model")
        from urllib import request
        import tarfile
        try:
            # urlretrieve without a target filename stores the archive in a
            # temporary file; only the graph definition is extracted from it.
            fn, _ = request.urlretrieve(INCEPTION_URL)
            with tarfile.open(fn, mode='r') as f:
                f.extract('classify_image_graph_def.pb', str(model_file.parent))
        finally:
            # Remove the temporary archive left behind by urlretrieve.
            request.urlcleanup()
    return str(model_file)


def _handle_path(path, sess, low_profile=False):
    ''' Returns the statistics (mu, sigma) for one FID input.

    Params:
    -- path        : either a '.npz' file holding the arrays 'mu' and 'sigma',
                     or a directory whose '*.jpg' and '*.png' files are
                     evaluated. A str or a pathlib.Path is accepted.
    -- sess        : current session; unused for '.npz' inputs.
    -- low_profile : if True the images are loaded batch by batch through
                     calculate_activation_statistics_from_files, otherwise all
                     images are loaded into one array first.
    Returns:
    -- m, s : the mean and covariance for the input.
    Raises:
    -- RuntimeError if the directory contains no '.jpg' or '.png' file.
    '''
    path = str(path)  # allow pathlib.Path as well as str
    if path.endswith('.npz'):
        # The context manager closes the archive even if a key is missing.
        with np.load(path) as f:
            m, s = f['mu'][:], f['sigma'][:]
    else:
        path = pathlib.Path(path)
        files = list(path.glob('*.jpg')) + list(path.glob('*.png'))
        if not files:
            # Fail here with a clear message instead of an obscure error
            # from the batching code further down.
            raise RuntimeError("No .jpg or .png images found in: %s" % path)
        if low_profile:
            m, s = calculate_activation_statistics_from_files(files, sess)
        else:
            x = np.array([imread(str(fn)).astype(np.float32) for fn in files])
            m, s = calculate_activation_statistics(x, sess)
            del x  # clean up memory
    return m, s


def calculate_fid_given_paths(paths, inception_path, low_profile=False):
    ''' Calculates the FID of two paths.

    Params:
    -- paths          : sequence whose first two entries are each a directory
                        of images or a '.npz' statistics file.
    -- inception_path : directory holding the Inception graph file, or None
                        to use the default location (downloaded if missing).
    -- low_profile    : forwarded to _handle_path.
    Returns:
    -- The value returned by calculate_frechet_distance for the two inputs.
    Raises:
    -- RuntimeError if one of the given paths does not exist.
    '''
    # Validate the inputs first so a typo in a path is reported before the
    # Inception download is attempted.
    for p in paths:
        if not os.path.exists(p):
            raise RuntimeError("Invalid path: %s" % p)

    inception_path = check_or_download_inception(inception_path)

    create_inception_graph(str(inception_path))
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        m1, s1 = _handle_path(paths[0], sess, low_profile=low_profile)
        m2, s2 = _handle_path(paths[1], sess, low_profile=low_profile)
        fid_value = calculate_frechet_distance(m1, s1, m2, s2)
        return fid_value
calculate_fid_given_paths((image_path,STATS_PATH),None)"],"metadata":{"id":"R2Q9jLxfK1S5"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["print(fid)"],"metadata":{"id":"Ja5mhXiIOjrB"},"execution_count":null,"outputs":[]}]}