{
  "cells": [
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "g8yoP4NQQVK5"
      },
      "outputs": [],
      "source": [
        "!nvidia-smi"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 2,
      "metadata": {
        "id": "LJfz-3FrQc9h"
      },
      "outputs": [],
      "source": [
        "import zipfile\n",
        "with zipfile.ZipFile(\"/content/edm.zip\", 'r') as zip_ref:\n",
        "    zip_ref.extractall(\"/content/\")"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "7Prin8FTQdnq"
      },
      "outputs": [],
      "source": [
        "# THE FOLLOWING COPYRIGHT IS LEFT IN BECAUSE WE BUILD ON THE WORKS OF Karras et al. (2022)\n",
        "\n",
        "# Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.\n",
        "#\n",
        "# This work is licensed under a Creative Commons\n",
        "# Attribution-NonCommercial-ShareAlike 4.0 International License.\n",
        "# You should have received a copy of the license along with this\n",
        "# work. If not, see http://creativecommons.org/licenses/by-nc-sa/4.0/\n",
        "\n",
        "\"\"\"Minimal standalone example to reproduce the main results from the paper\n",
        "\"Elucidating the Design Space of Diffusion-Based Generative Models\".\"\"\"\n",
        "\n",
        "# appending a path\n",
        "import sys\n",
        "sys.path.append('/content/edm/')\n",
        "\n",
        "import tqdm\n",
        "import pickle\n",
        "import numpy as np\n",
        "import torch\n",
        "import PIL.Image\n",
        "import dnnlib\n",
        "import torchvision\n",
        "import torchvision.transforms as transforms\n",
        "import matplotlib.pyplot as plt\n",
        "import torch.nn.functional as F\n",
        "import time\n",
        "\n",
        "\n",
        "SET_NAME = \"C10\"\n",
        "#SET_NAME = \"C100\"\n",
        "\n",
        "# change these as you like\n",
        "NORM_TAG = \"L2\" #L2/LINF\n",
        "KEY=\"LEAST\" #\"LEAST/MOST\", i.e. least/most sensitive\n",
        "\n",
        "#change this if you like\n",
        "SEED_LIST = [] #default\n",
        "#SEED_LIST = [k for k in range(NR_SAMPLES)]\n",
        "\n",
        "# change/select these as you like\n",
        "VERSION, NR_SAMPLES, NR_SAMPLES_SQRT, NR_CLONES, BATCHING = \"V1\", 16, 4, 1, -1\n",
        "#VERSION, NR_SAMPLES, NR_SAMPLES_SQRT, NR_CLONES, BATCHING = \"V2\", 16, -1, 10, -1\n",
        "#VERSION, NR_SAMPLES, NR_SAMPLES_SQRT, NR_CLONES, BATCHING = \"V3\", 10000, -1, 1, 1024\n",
        "#VERSION, NR_SAMPLES, NR_SAMPLES_SQRT, NR_CLONES, BATCHING = \"V4\", 1024, -1, 10, -1\n",
        "\n",
        "if VERSION == \"V1\":\n",
        "    #if you want to see the original most/least sensitive elements, set NR_CLONES = 0\n",
        "    assert NR_SAMPLES == int(NR_SAMPLES_SQRT**2)\n",
        "\n",
        "#taken from https://colab.research.google.com/drive/1_kbRZPTjnFgViPrmGcUsaszEdYa8XTpq?usp=sharing#scrollTo=Mb-u1yDAJ23R\n",
        "def perturb_latents(latents, scale):\n",
        "    noise = torch.randn_like(latents)\n",
        "    new_latents = (1 - scale) * latents + scale * noise\n",
        "    return (new_latents - new_latents.mean()) / new_latents.std()\n",
        "\n",
        "#----------------------------------------------------------------------------\n",
        "\n",
        "# this function is an adpation from the example at ( https://github.com/NVlabs/edm/blob/main/example.py ) by Karras et al. (2022)\n",
        "def generate_image_grid(\n",
        "    net,\n",
        "    image,\n",
        "    label,\n",
        "    seed=0,\n",
        "    gridw=25,\n",
        "    gridh=1,\n",
        "    device=torch.device('cuda'),\n",
        "    c_version=0,\n",
        "    set_name=\"C10\",\n",
        "    num_steps=20,\n",
        "    sigma_min=0.002,\n",
        "    sigma_max=1.0,\n",
        "    rho=7,\n",
        "    S_churn=0,\n",
        "    S_min=0,\n",
        "    S_max=float('inf'),\n",
        "    S_noise=1,\n",
        "    scale=0.0\n",
        "):\n",
        "    batch_size = gridw * gridh\n",
        "    torch.manual_seed(seed)\n",
        "\n",
        "    if c_version == 0:\n",
        "      image = image.repeat(gridw,1,1,1)\n",
        "      label = label.repeat(gridw,1)\n",
        "\n",
        "    image *= 255.0\n",
        "    image = image.to(device).to(torch.float32) / 127.5 - 1\n",
        "    latents = perturb_latents(image, scale)\n",
        "\n",
        "    if set_name == \"C10\":\n",
        "      class_labels = F.one_hot(label, num_classes=10).to(device)\n",
        "    else:\n",
        "      class_labels = F.one_hot(label, num_classes=100).to(device)\n",
        "\n",
        "\n",
        "    # Adjust noise levels based on what's supported by the network.\n",
        "    sigma_min = max(sigma_min, net.sigma_min)\n",
        "    sigma_max = min(sigma_max, net.sigma_max)\n",
        "\n",
        "    # Time step discretization.\n",
        "    step_indices = torch.arange(num_steps, dtype=torch.float64, device=device)\n",
        "    t_steps = (sigma_max ** (1 / rho) + step_indices / (num_steps - 1) * (sigma_min ** (1 / rho) - sigma_max ** (1 / rho))) ** rho\n",
        "    t_steps = torch.cat([net.round_sigma(t_steps), torch.zeros_like(t_steps[:1])]) # t_N = 0\n",
        "\n",
        "    # Main sampling loop.\n",
        "    x_next = latents.to(torch.float64) * t_steps[0]\n",
        "    for i, (t_cur, t_next) in tqdm.tqdm(list(enumerate(zip(t_steps[:-1], t_steps[1:]))), unit='step'): # 0, ..., N-1\n",
        "        x_cur = x_next\n",
        "\n",
        "        # Increase noise temporarily.\n",
        "        gamma = min(S_churn / num_steps, np.sqrt(2) - 1) if S_min <= t_cur <= S_max else 0\n",
        "        t_hat = net.round_sigma(t_cur + gamma * t_cur)\n",
        "        x_hat = x_cur + (t_hat ** 2 - t_cur ** 2).sqrt() * S_noise * torch.randn_like(x_cur)\n",
        "\n",
        "        # Euler step.\n",
        "        denoised = net(x_hat, t_hat, class_labels).to(torch.float64)\n",
        "        d_cur = (x_hat - denoised) / t_hat\n",
        "        x_next = x_hat + (t_next - t_hat) * d_cur\n",
        "\n",
        "        # Apply 2nd order correction.\n",
        "        if i < num_steps - 1:\n",
        "            denoised = net(x_next, t_next, class_labels).to(torch.float64)\n",
        "            d_prime = (x_next - denoised) / t_next\n",
        "            x_next = x_hat + (t_next - t_hat) * (0.5 * d_cur + 0.5 * d_prime)\n",
        "\n",
        "    image = (x_next * 127.5 + 128).clip(0, 255).to(torch.uint8)\n",
        "    image = image.permute(0, 2, 3, 1)\n",
        "    image = image.cpu().numpy()\n",
        "\n",
        "    return image\n",
        "\n",
        "#----------------------------------------------------------------------------\n",
        "\n",
        "def main():\n",
        "    start = time.time()\n",
        "\n",
        "    #define origins of trained diffusion models;\n",
        "    #the CIFAR-10 model is from Karras et al. (2022) and\n",
        "    #the CIFAR-100 model is from Wang et al. (2023) [see the paper for both references]\n",
        "    if SET_NAME == \"C10\":\n",
        "      model_root = 'https://nvlabs-fi-cdn.nvidia.com/edm/pretrained/edm-cifar10-32x32-cond-vp.pkl'\n",
        "    else:\n",
        "      model_root = 'https://huggingface.co/wzekai99/DM-Improves-AT/resolve/main/others/edm-cifar100-32x32-cond-vp.pkl'\n",
        "\n",
        "    data_root_in = '/content/'\n",
        "    data_root_out = '/content/generated_data.npz'\n",
        "    device = torch.device('cuda')\n",
        "\n",
        "    #FIDs over three random seeds w.r.t. the most sensitive 10K images (w.r.t. the linf norm) from the original CIFAR10 set are 5.83/5.76/5.78\n",
        "    ADD_PARAMS = {\"scale\" : 0.55, \"sigma_min\" : 0.004, \"sigma_max\" : 1.4, \"rho\" : 1.7, \"S_churn\" : 2.4, \"S_min\" : 0.5, \"S_max\" : float('inf'), \"S_noise\" : 1.015}\n",
        "\n",
        "    # Load network.\n",
        "    with dnnlib.util.open_url(model_root) as f:\n",
        "        net = pickle.load(f)['ema'].to(device)\n",
        "\n",
        "    transform = transforms.Compose([transforms.ToTensor()])\n",
        "    batch_size = 50_000\n",
        "\n",
        "    #load data\n",
        "    if SET_NAME == \"C10\":\n",
        "      trainset = torchvision.datasets.CIFAR10(root=data_root_in, train=True,\n",
        "                                            download=True, transform=transform)\n",
        "    else:\n",
        "      trainset = torchvision.datasets.CIFAR100(root=data_root_in, train=True,\n",
        "                                              download=True, transform=transform)\n",
        "\n",
        "    trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size,\n",
        "                                              shuffle=False, num_workers=2)\n",
        "    images, labels = next(iter(trainloader))\n",
        "\n",
        "    output_imgs = []\n",
        "    output_labels = []\n",
        "\n",
        "    #load ordered sensitivity indices\n",
        "    if SET_NAME == \"C10\":\n",
        "      with open(\"/content/edm/cifar10_sens_indices_ordered_\" + NORM_TAG + \".pickle\", 'rb') as f:\n",
        "        SENS_INDICES_ORDERED = np.array(pickle.load(f))\n",
        "    else:\n",
        "      with open(\"/content/edm/cifar100_sens_indices_ordered_\" + NORM_TAG + \".pickle\", 'rb') as f:\n",
        "        SENS_INDICES_ORDERED = np.array(pickle.load(f))\n",
        "\n",
        "    if KEY == \"MOST\": #i.e. most sensitive / least robust\n",
        "      if SET_NAME == \"C10\":\n",
        "        RS_LIST = SENS_INDICES_ORDERED[-NR_SAMPLES:]\n",
        "      else:\n",
        "        RS_LIST = SENS_INDICES_ORDERED[-(NR_SAMPLES+40):-40] #disregard the 40 most sensitive elements due to too many near-duplicates\n",
        "    else: #i.e. least sensitive / most robust\n",
        "      RS_LIST = SENS_INDICES_ORDERED[:NR_SAMPLES]\n",
        "\n",
        "\n",
        "    if not VERSION == \"V3\":\n",
        "\n",
        "      #if NR_CLONES==0, we can simply ignore this and continue with the original most/least sensitive images\n",
        "      if NR_CLONES > 0:\n",
        "\n",
        "        #iterate over the indices selected above for the most/least sensitive elements\n",
        "        for ix, rs in enumerate(RS_LIST):\n",
        "\n",
        "          #generator function: generates \"NR_CLONES\"-images for a given original image and \"random\" random seed\n",
        "          gen_imgs = generate_image_grid(net,\n",
        "                              image=images[rs],\n",
        "                              label=labels[rs],\n",
        "                              seed=RS_LIST[ix] if len(SEED_LIST)==0 else SEED_LIST[ix],\n",
        "                              gridw=NR_CLONES,\n",
        "                              c_version=0,\n",
        "                              set_name=SET_NAME,\n",
        "                              num_steps= 20 if SET_NAME==\"C10\" else 25,\n",
        "                              **ADD_PARAMS\n",
        "                              )\n",
        "\n",
        "          #collect images and labels (which we may want to work with in the future)\n",
        "          output_imgs.append(gen_imgs)\n",
        "          output_labels = output_labels + [labels[rs].item() for k in range(NR_CLONES)]\n",
        "\n",
        "        #produce one unified tensor\n",
        "        output_imgs = np.concatenate([o_i for o_i in output_imgs])\n",
        "\n",
        "      else:\n",
        "        #this merely selects the most/least sensitive images from the original dataset\n",
        "        output_imgs = images[RS_LIST].permute(0, 2, 3, 1)\n",
        "        output_labels = labels[RS_LIST]\n",
        "\n",
        "    else:\n",
        "      # select/order images here because of the different argument structure of the generator function\n",
        "      images, labels = images[RS_LIST], labels[RS_LIST]\n",
        "\n",
        "      #consistency check to allow compatibility for any batch size\n",
        "      if NR_SAMPLES % BATCHING == 0:\n",
        "        limiter = NR_SAMPLES // BATCHING\n",
        "      else:\n",
        "        limiter = NR_SAMPLES // BATCHING + 1\n",
        "\n",
        "      for ix, j in enumerate(range(limiter)):\n",
        "\n",
        "        #\"manually\" adjusting the batch indices\n",
        "        if ix == NR_SAMPLES // BATCHING:\n",
        "          ending_on = NR_SAMPLES\n",
        "        else:\n",
        "          ending_on = (j+1)*BATCHING\n",
        "\n",
        "        #sanity check\n",
        "        print(\"ending_on\", ending_on)\n",
        "\n",
        "        #generator function: generates batches of images for a \"random\" random seed\n",
        "        gen_imgs = generate_image_grid(net,\n",
        "                            image=images[j*BATCHING : ending_on],\n",
        "                            label=labels[j*BATCHING : ending_on],\n",
        "                            seed=np.random.randint(1000000000),\n",
        "                            gridw=BATCHING,\n",
        "                            c_version=1,\n",
        "                            set_name=SET_NAME,\n",
        "                            num_steps= 20 if SET_NAME==\"C10\" else 25,\n",
        "                            **ADD_PARAMS\n",
        "                            )\n",
        "        #collect images and labels (which we may want to work with in the future)\n",
        "        output_imgs.append(gen_imgs)\n",
        "        output_labels = output_labels + [-1 for l in range(len(gen_imgs))]\n",
        "\n",
        "      #produce one unified tensor\n",
        "      output_imgs = np.concatenate([o_i for o_i in output_imgs])\n",
        "\n",
        "\n",
        "    print()\n",
        "    print(\"Final output shape of images:\", output_imgs.shape)\n",
        "    print(\"Final output shape of labels:\", len(output_labels))\n",
        "    print()\n",
        "\n",
        "    #save the generated data as a npz file\n",
        "    np.savez(data_root_out, **{\"data\":output_imgs, \"labels\":output_labels})\n",
        "\n",
        "    end = time.time()\n",
        "    print(\"On this gpu it took me\", np.round(end - start, 4), \"secs to generate\", len(RS_LIST) * NR_CLONES, \"images.\")\n",
        "\n",
        "\n",
        "#----------------------------------------------------------------------------\n",
        "\n",
        "if __name__ == \"__main__\":\n",
        "\n",
        "\n",
        "    main()\n",
        "    print()\n",
        "    print(\"Thx 4 generating :)\")\n",
        "\n",
        "#----------------------------------------------------------------------------\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "QPw9yTyeQhnH"
      },
      "outputs": [],
      "source": [
        "from torch import linalg as LA\n",
        "import os\n",
        "import PIL.Image\n",
        "\n",
        "#load previously generated data\n",
        "npzfile = np.load(\"/content/generated_data.npz\")\n",
        "print(npzfile.files)\n",
        "x_gen, _ =  npzfile['data'], npzfile['labels']\n",
        "print(x_gen.shape)\n",
        "\n",
        "\n",
        "if VERSION == \"V1\": #plotting a (quadratic) collection of the most/least sens. images or their clones\n",
        "    plt.figure(figsize=(NR_SAMPLES_SQRT,NR_SAMPLES_SQRT))\n",
        "    for i in range(NR_SAMPLES_SQRT ** 2):\n",
        "        plt.subplot(NR_SAMPLES_SQRT,NR_SAMPLES_SQRT,i+1)\n",
        "        plt.imshow(x_gen[i])\n",
        "        plt.rcParams['axes.grid'] = False\n",
        "        plt.axis('off')\n",
        "    plt.tight_layout()\n",
        "    if SET_NAME == \"C10\":\n",
        "      plt.savefig(\"collection_\"+NORM_TAG+\"_\"+KEY+\"_\"+str(NR_CLONES)+\".png\", dpi=200)\n",
        "    else:\n",
        "      plt.savefig(\"c100_collection_\"+NORM_TAG+\"_\"+KEY+\"_\"+str(NR_CLONES)+\".png\", dpi=200)\n",
        "\n",
        "elif VERSION == \"V2\": #plotting a collection of clones of the most/least sens. images\n",
        "    plt.figure(figsize=(NR_CLONES, NR_SAMPLES))\n",
        "    for i in range(NR_SAMPLES * NR_CLONES):\n",
        "        plt.subplot(NR_SAMPLES, NR_CLONES,i+1)\n",
        "        plt.imshow(x_gen[i])\n",
        "        plt.rcParams['axes.grid'] = False\n",
        "        plt.axis('off')\n",
        "    plt.tight_layout()\n",
        "    if SET_NAME == \"C10\":\n",
        "      plt.savefig(\"collection_\"+NORM_TAG+\"_\"+KEY+\"_\"+str(NR_CLONES)+\".png\", dpi=200)\n",
        "    else:\n",
        "      plt.savefig(\"c100_collection_\"+NORM_TAG+\"_\"+KEY+\"_\"+str(NR_CLONES)+\".png\", dpi=200)\n",
        "\n",
        "\n",
        "elif VERSION == \"V3\": #calculating the average l2/linf distances between the 10K most or least images and their clones, respectively; also saves data for the FID calculations below\n",
        "\n",
        "  data_root_in = '/content/'\n",
        "  transform = transforms.Compose([transforms.ToTensor()])\n",
        "  batch_size = 50_000\n",
        "\n",
        "  #load data\n",
        "  if SET_NAME == \"C10\":\n",
        "    trainset = torchvision.datasets.CIFAR10(root=data_root_in, train=True,\n",
        "                                          download=True, transform=transform)\n",
        "  else:\n",
        "    trainset = torchvision.datasets.CIFAR100(root=data_root_in, train=True,\n",
        "                                            download=True, transform=transform)\n",
        "  trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size,\n",
        "                                            shuffle=False, num_workers=2)\n",
        "  ref_images, _ = next(iter(trainloader))\n",
        "\n",
        "  #load ordered sensitivity indices\n",
        "  if SET_NAME == \"C10\":\n",
        "    with open(\"/content/edm/cifar10_sens_indices_ordered_\" + NORM_TAG + \".pickle\", 'rb') as f:\n",
        "      sens_indices_ordered = pickle.load(f)\n",
        "  else:\n",
        "    with open(\"/content/edm/cifar100_sens_indices_ordered_\" + NORM_TAG + \".pickle\", 'rb') as f:\n",
        "      sens_indices_ordered = pickle.load(f)\n",
        "\n",
        "  #select the indices of interest\n",
        "  if KEY == \"MOST\":\n",
        "    if SET_NAME == \"C10\":\n",
        "      RS_LIST = sens_indices_ordered[-NR_SAMPLES:]\n",
        "    else:\n",
        "      RS_LIST = sens_indices_ordered[-(NR_SAMPLES+40):-40] #disregard the 40 most sensitive elements due to too many near-duplicates\n",
        "  else:\n",
        "    RS_LIST = sens_indices_ordered[:NR_SAMPLES]\n",
        "\n",
        "  #order reference images, unify shape and normalise images\n",
        "  ref_images = ref_images[RS_LIST]\n",
        "  ref_images = ref_images.permute(0, 2, 3, 1)\n",
        "  x_gen = x_gen / 255.0\n",
        "\n",
        "  #calculate l2/linf distances\n",
        "  l_inf_average = torch.norm(input=ref_images - x_gen, p=float('inf'), dim=(1,2,3))\n",
        "  l_2_average = LA.vector_norm(input=ref_images - x_gen, ord=2, dim=(1,2,3))\n",
        "\n",
        "  #print average distances with tags\n",
        "  print(SET_NAME, NORM_TAG, KEY)\n",
        "  print(\"l_inf average:\", torch.mean(l_inf_average).item())\n",
        "  print(\"l_2 average:\", torch.mean(l_2_average).item())\n",
        "\n",
        "  \"\"\"<<<##-#-##>>>\"\"\"\n",
        "\n",
        "  #process images to save them as pillow images\n",
        "  ref_images = ref_images * 255.0\n",
        "  REFERENCE_SET = ref_images.clip(0, 255).to(torch.uint8).numpy()\n",
        "\n",
        "  for k, b_s in enumerate(range(NR_SAMPLES)):\n",
        "      image_np = REFERENCE_SET[k]\n",
        "      # Save image.\n",
        "      image_dir = os.path.join('/content/data_ref/', f'{b_s-b_s%1000:06d}')\n",
        "      os.makedirs(image_dir, exist_ok=True)\n",
        "      image_path = os.path.join(image_dir, f'{b_s:06d}.png')\n",
        "      if image_np.shape[2] == 1:\n",
        "          PIL.Image.fromarray(image_np[:, :, 0], 'L').save(image_path)\n",
        "      else:\n",
        "          PIL.Image.fromarray(image_np, 'RGB').save(image_path)\n",
        "\n",
        "  #load data again to save as pillow images\n",
        "  npzfile = np.load(\"/content/generated_data.npz\")\n",
        "  REFERENCE_SET =  npzfile['data']\n",
        "\n",
        "  for k, b_s in enumerate(range(NR_SAMPLES)):\n",
        "      image_np = REFERENCE_SET[k]\n",
        "      # Save image.\n",
        "      image_dir = os.path.join('/content/data/', f'{b_s-b_s%1000:06d}')\n",
        "      os.makedirs(image_dir, exist_ok=True)\n",
        "      image_path = os.path.join(image_dir, f'{b_s:06d}.png')\n",
        "      if image_np.shape[2] == 1:\n",
        "          PIL.Image.fromarray(image_np[:, :, 0], 'L').save(image_path)\n",
        "      else:\n",
        "          PIL.Image.fromarray(image_np, 'RGB').save(image_path)\n",
        "\n",
        "\n",
        "elif VERSION == \"V4\": #calculates the average l2/linf distances between each clone and its follow-up clone, either for the most or least sensitive elements\n",
        "\n",
        "  x_gen_copy = x_gen.copy()\n",
        "  x_gen_copy = torch.Tensor(x_gen_copy / 255.0)\n",
        "\n",
        "  SUM_linf = 0.0\n",
        "  SUM_l2 = 0.0\n",
        "\n",
        "  #use vector iteration: the first \"NR_CLONES\"-images are cloned from the same original image,\n",
        "  #hence we need to skip \"1*NR_CLONES\" elements to get from the first to the second row of clones,\n",
        "  #\"2*NR_CLONES\" elements to get from the second to the third row of clones and so on.\n",
        "  for ax_1 in range(NR_SAMPLES):\n",
        "    for ax_2 in range(NR_CLONES-1):\n",
        "      SUM_linf += torch.norm(input=x_gen_copy[ax_1*NR_CLONES+ax_2] - x_gen_copy[ax_1*NR_CLONES+ax_2+1], p=float('inf'), dim=(0,1,2))\n",
        "      SUM_l2 += LA.vector_norm(input=x_gen_copy[ax_1*NR_CLONES+ax_2] - x_gen_copy[ax_1*NR_CLONES+ax_2+1], ord=2, dim=(0,1,2))\n",
        "\n",
        "  print(SET_NAME, NORM_TAG, KEY)\n",
        "  print(\"l_inf average:\", SUM_linf.item() / (NR_SAMPLES * (NR_CLONES-1)))\n",
        "  print(\"l_2 average:\", SUM_l2.item() / (NR_SAMPLES * (NR_CLONES-1)))\n",
        "\n"
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "##### Restart runtime before the last two cells to free the GPU memory for the Inception model (if needed or if the FID calculations bug out)"
      ],
      "metadata": {
        "id": "jwzFOQHBpVaU"
      },
      "execution_count": 5,
      "outputs": []
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "rS0Hnx_9UBG3"
      },
      "outputs": [],
      "source": [
        "!python edm/fid.py ref --data=\"/content/data_ref/\" --dest=\"/content/data_ref_stats.npz\""
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "!torchrun --standalone --nproc_per_node=1 edm/fid.py calc --images=\"data\" --num=10000 --ref=\"/content/data_ref_stats.npz\""
      ],
      "metadata": {
        "id": "uRJ3RDC72ufM"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [],
      "metadata": {
        "id": "UlLeYb2x57LC"
      },
      "execution_count": 7,
      "outputs": []
    }
  ],
  "metadata": {
    "accelerator": "GPU",
    "colab": {
      "provenance": [],
      "gpuType": "V100",
      "machine_shape": "hm"
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    },
    "language_info": {
      "name": "python"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}