{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3444f833",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import torch\n",
    "import matplotlib.pyplot as plt\n",
    "from torchvision import datasets, transforms\n",
    "import copy\n",
    "from PIL import Image\n",
    "import os\n",
    "import time\n",
    "import scipy\n",
    "import gc\n",
    "\n",
    "import math\n",
    "import torch.nn.functional as F\n",
    "import torchvision\n",
    "import torch.nn as nn\n",
    "from torchvision import models\n",
    "\n",
    "from mpl_toolkits.axes_grid1 import ImageGrid\n",
    "from skimage import exposure\n",
    "from skimage.exposure import match_histograms\n",
    "from datetime import datetime\n",
    "from torchvision import models"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1d1cd217",
   "metadata": {},
   "source": [
    "# Function Definitions"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "68d1ffec",
   "metadata": {},
   "outputs": [],
   "source": [
    "def composition(im_path, IM_SIZE, cropped=True):\n",
    "    cx = 89\n",
    "    cy = 121\n",
    "    \n",
    "    # center crop to 128 x 128, then resize to relevant dimension\n",
    "    with Image.open(im_path) as im1:\n",
    "        if cropped:\n",
    "            im1 = im1.crop((cx-64, cy-64, cx+64, cy+64))\n",
    "        return im1.resize((IM_SIZE, IM_SIZE))"
   ]
  },
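  {
   "cell_type": "markdown",
   "id": "a1f2c3d4",
   "metadata": {},
   "source": [
    "A minimal usage sketch for `composition` (not part of the original pipeline): the path below is hypothetical; point it at any image under your CelebA root. `000001.jpg` is the conventional first filename in `img_align_celeba`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b2e3d4f5",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Usage sketch for composition(); sample_path is hypothetical -- substitute a real file.\n",
    "sample_path = '/path/to/img_align_celeba/000001.jpg'\n",
    "if os.path.exists(sample_path):\n",
    "    im = composition(sample_path, IM_SIZE=64, cropped=True)\n",
    "    plt.imshow(np.asarray(im))\n",
    "    plt.axis('off')\n",
    "    plt.show()"
   ]
  },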
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5169ddd9",
   "metadata": {},
   "outputs": [],
   "source": [
    "# computing FID code, adapted from https://github.com/mseitzer/pytorch-fid\n",
    "class InceptionV3(nn.Module):\n",
    "    \"\"\"Pretrained InceptionV3 network returning feature maps\"\"\"\n",
    "\n",
    "    # Index of default block of inception to return,\n",
    "    # corresponds to output of final average pooling\n",
    "    DEFAULT_BLOCK_INDEX = 3\n",
    "\n",
    "    # Maps feature dimensionality to their output blocks indices\n",
    "    BLOCK_INDEX_BY_DIM = {\n",
    "        64: 0,   # First max pooling features\n",
    "        192: 1,  # Second max pooling featurs\n",
    "        768: 2,  # Pre-aux classifier features\n",
    "        2048: 3  # Final average pooling features\n",
    "    }\n",
    "\n",
    "    def __init__(self,\n",
    "                 output_blocks=[DEFAULT_BLOCK_INDEX],\n",
    "                 resize_input=True,\n",
    "                 normalize_input=True,\n",
    "                 requires_grad=False):\n",
    "        \n",
    "        super(InceptionV3, self).__init__()\n",
    "\n",
    "        self.resize_input = resize_input\n",
    "        self.normalize_input = normalize_input\n",
    "        self.output_blocks = sorted(output_blocks)\n",
    "        self.last_needed_block = max(output_blocks)\n",
    "\n",
    "        assert self.last_needed_block <= 3, \\\n",
    "            'Last possible output block index is 3'\n",
    "\n",
    "        self.blocks = nn.ModuleList()\n",
    "\n",
    "        \n",
    "        inception = models.inception_v3(pretrained=True)\n",
    "\n",
    "        # Block 0: input to maxpool1\n",
    "        block0 = [\n",
    "            inception.Conv2d_1a_3x3,\n",
    "            inception.Conv2d_2a_3x3,\n",
    "            inception.Conv2d_2b_3x3,\n",
    "            nn.MaxPool2d(kernel_size=3, stride=2)\n",
    "        ]\n",
    "        self.blocks.append(nn.Sequential(*block0))\n",
    "\n",
    "        # Block 1: maxpool1 to maxpool2\n",
    "        if self.last_needed_block >= 1:\n",
    "            block1 = [\n",
    "                inception.Conv2d_3b_1x1,\n",
    "                inception.Conv2d_4a_3x3,\n",
    "                nn.MaxPool2d(kernel_size=3, stride=2)\n",
    "            ]\n",
    "            self.blocks.append(nn.Sequential(*block1))\n",
    "\n",
    "        # Block 2: maxpool2 to aux classifier\n",
    "        if self.last_needed_block >= 2:\n",
    "            block2 = [\n",
    "                inception.Mixed_5b,\n",
    "                inception.Mixed_5c,\n",
    "                inception.Mixed_5d,\n",
    "                inception.Mixed_6a,\n",
    "                inception.Mixed_6b,\n",
    "                inception.Mixed_6c,\n",
    "                inception.Mixed_6d,\n",
    "                inception.Mixed_6e,\n",
    "            ]\n",
    "            self.blocks.append(nn.Sequential(*block2))\n",
    "\n",
    "        # Block 3: aux classifier to final avgpool\n",
    "        if self.last_needed_block >= 3:\n",
    "            block3 = [\n",
    "                inception.Mixed_7a,\n",
    "                inception.Mixed_7b,\n",
    "                inception.Mixed_7c,\n",
    "                nn.AdaptiveAvgPool2d(output_size=(1, 1))\n",
    "            ]\n",
    "            self.blocks.append(nn.Sequential(*block3))\n",
    "\n",
    "        for param in self.parameters():\n",
    "            param.requires_grad = requires_grad\n",
    "\n",
    "    def forward(self, inp):\n",
    "        \"\"\"Get Inception feature maps\n",
    "        Parameters\n",
    "        ----------\n",
    "        inp : torch.autograd.Variable\n",
    "            Input tensor of shape Bx3xHxW. Values are expected to be in\n",
    "            range (0, 1)\n",
    "        Returns\n",
    "        -------\n",
    "        List of torch.autograd.Variable, corresponding to the selected output\n",
    "        block, sorted ascending by index\n",
    "        \"\"\"\n",
    "        outp = []\n",
    "        x = inp\n",
    "\n",
    "        if self.resize_input:\n",
    "            x = F.interpolate(x,\n",
    "                              size=(299, 299),\n",
    "                              mode='bilinear',\n",
    "                              align_corners=False)\n",
    "\n",
    "        if self.normalize_input:\n",
    "            x = 2 * x - 1  # Scale from range (0, 1) to range (-1, 1)\n",
    "\n",
    "        for idx, block in enumerate(self.blocks):\n",
    "            x = block(x)\n",
    "            if idx in self.output_blocks:\n",
    "                outp.append(x)\n",
    "\n",
    "            if idx == self.last_needed_block:\n",
    "                break\n",
    "\n",
    "        return outp\n",
    "    \n",
    "def calculate_activation_statistics(images,model,batch_size=100, dims=2048,\n",
    "                    cuda=False):\n",
    "    model.eval()\n",
    "    act=np.empty((len(images), dims))\n",
    "    nBatches = len(images)//batch_size\n",
    "    \n",
    "    for i in range(nBatches):\n",
    "        batch = images[i*batch_size:(i+1)*batch_size]\n",
    "        if cuda:\n",
    "            batch=batch.cuda()\n",
    "\n",
    "        pred = model(batch)[0]\n",
    "\n",
    "            # If model output is not scalar, apply global spatial average pooling.\n",
    "            # This happens if you choose a dimensionality not equal 2048.\n",
    "        if pred.size(2) != 1 or pred.size(3) != 1:\n",
    "            pred = adaptive_avg_pool2d(pred, output_size=(1, 1))\n",
    "\n",
    "        act[i*batch_size:(i+1)*batch_size]= pred.cpu().data.numpy().reshape(pred.size(0), -1)\n",
    "    \n",
    "    mu = np.mean(act, axis=0)\n",
    "    sigma = np.cov(act, rowvar=False)\n",
    "    return mu, sigma\n",
    "\n",
    "def calculate_frechet_distance(mu1, sigma1, mu2, sigma2, eps=1e-6):\n",
    "    \"\"\"Numpy implementation of the Frechet Distance.\n",
    "    The Frechet distance between two multivariate Gaussians X_1 ~ N(mu_1, C_1)\n",
    "    and X_2 ~ N(mu_2, C_2) is\n",
    "            d^2 = ||mu_1 - mu_2||^2 + Tr(C_1 + C_2 - 2*sqrt(C_1*C_2)).\n",
    "    \"\"\"\n",
    "\n",
    "    mu1 = np.atleast_1d(mu1)\n",
    "    mu2 = np.atleast_1d(mu2)\n",
    "\n",
    "    sigma1 = np.atleast_2d(sigma1)\n",
    "    sigma2 = np.atleast_2d(sigma2)\n",
    "\n",
    "    assert mu1.shape == mu2.shape, \\\n",
    "        'Training and test mean vectors have different lengths'\n",
    "    assert sigma1.shape == sigma2.shape, \\\n",
    "        'Training and test covariances have different dimensions'\n",
    "\n",
    "    diff = mu1 - mu2\n",
    "\n",
    "    \n",
    "    covmean, _ = scipy.linalg.sqrtm(sigma1.dot(sigma2), disp=False)\n",
    "    if not np.isfinite(covmean).all():\n",
    "        msg = ('fid calculation produces singular product; '\n",
    "               'adding %s to diagonal of cov estimates') % eps\n",
    "        print(msg)\n",
    "        offset = np.eye(sigma1.shape[0]) * eps\n",
    "        covmean = scipy.linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset))\n",
    "\n",
    "    \n",
    "    if np.iscomplexobj(covmean):\n",
    "        if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3):\n",
    "            m = np.max(np.abs(covmean.imag))\n",
    "            raise ValueError('Imaginary component {}'.format(m))\n",
    "        covmean = covmean.real\n",
    "\n",
    "    tr_covmean = np.trace(covmean)\n",
    "\n",
    "    return (diff.dot(diff) + np.trace(sigma1) +\n",
    "            np.trace(sigma2) - 2 * tr_covmean)\n",
    "\n",
    "\n",
    "def calculate_fretchet(images_real,images_fake, batch_size=10):\n",
    "    block_idx = InceptionV3.BLOCK_INDEX_BY_DIM[2048]\n",
    "    model = InceptionV3([block_idx])\n",
    "    model=model.cuda()\n",
    "    \n",
    "    images_real = torch.from_numpy(images_real).float().permute(0, 3, 1, 2)\n",
    "    images_fake = torch.from_numpy(images_fake).float().permute(0, 3, 1, 2)\n",
    "    \n",
    "    mu_1,std_1=calculate_activation_statistics(images_real,model,batch_size=batch_size,cuda=True)\n",
    "    mu_2 ,std_2=calculate_activation_statistics(images_fake,model,batch_size=batch_size,cuda=True)\n",
    "    \n",
    "    \"\"\"get fretchet distance\"\"\"\n",
    "    fid_value = calculate_frechet_distance(mu_1, std_1, mu_2, std_2)\n",
    "    print('FID score', fid_value)\n",
    "    return fid_value"
   ]
  },
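  {
   "cell_type": "markdown",
   "id": "c3d4e5f6",
   "metadata": {},
   "source": [
    "A quick sanity-check sketch for the FID pipeline above (assumes a CUDA device, since `calculate_frechet` moves the Inception model to the GPU). Random noise stands in for images, so the absolute values are meaningless; identical sets should give a score near zero. With this few samples the covariance estimates are rank-deficient, so the `eps` fallback in `calculate_frechet_distance` may trigger."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d4e5f6a7",
   "metadata": {},
   "outputs": [],
   "source": [
    "# FID sanity check on random noise -- only exercises the code path, not image quality.\n",
    "if torch.cuda.is_available():\n",
    "    rng = np.random.default_rng(0)\n",
    "    imgs_a = rng.random((40, 64, 64, 3))  # NHWC arrays in [0, 1]\n",
    "    imgs_b = rng.random((40, 64, 64, 3))\n",
    "    calculate_frechet(imgs_a, imgs_a, batch_size=10)  # identical sets: FID ~ 0\n",
    "    calculate_frechet(imgs_a, imgs_b, batch_size=10)  # independent noise: small FID"
   ]
  },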
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4cc7ea99",
   "metadata": {},
   "outputs": [],
   "source": [
    "# for clearing gpu memory between runs, if necessary\n",
    "def pretty_size(size):\n",
    "    \"\"\"Pretty prints a torch.Size object\"\"\"\n",
    "    assert(isinstance(size, torch.Size))\n",
    "    return \" × \".join(map(str, size))\n",
    "\n",
    "def dump_tensors(gpu_only=True):\n",
    "    \"\"\"Prints a list of the Tensors being tracked by the garbage collector.\"\"\"\n",
    "    total_size = 0\n",
    "    for obj in gc.get_objects():\n",
    "        try:\n",
    "            if torch.is_tensor(obj):\n",
    "                if not gpu_only or obj.is_cuda:\n",
    "                    print(\"%s:%s%s %s\" % (type(obj).__name__, \n",
    "                                          \" GPU\" if obj.is_cuda else \"\",\n",
    "                                          \" pinned\" if obj.is_pinned else \"\",\n",
    "                                          pretty_size(obj.size())))\n",
    "                    total_size += obj.numel()\n",
    "            elif hasattr(obj, \"data\") and torch.is_tensor(obj.data):\n",
    "                if not gpu_only or obj.is_cuda:\n",
    "                    print(\"%s → %s:%s%s%s%s %s\" % (type(obj).__name__, \n",
    "                                                   type(obj.data).__name__, \n",
    "                                                   \" GPU\" if obj.is_cuda else \"\",\n",
    "                                                   \" pinned\" if obj.data.is_pinned else \"\",\n",
    "                                                   \" grad\" if obj.requires_grad else \"\", \n",
    "                                                   \" volatile\" if obj.volatile else \"\",\n",
    "                                                   pretty_size(obj.data.size())))                    \n",
    "                    \n",
    "                    total_size += obj.data.numel()\n",
    "                    del obj\n",
    "            del obj\n",
    "        except Exception as e:\n",
    "            pass        \n",
    "    print(\"Total size:\", total_size)\n"
   ]
  },
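  {
   "cell_type": "markdown",
   "id": "e5f6a7b8",
   "metadata": {},
   "source": [
    "A small demo of the memory utilities above -- a sketch that also works without a GPU (`gpu_only=False` lists CPU tensors too)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f6a7b8c9",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Demo of dump_tensors(): list tracked tensors, then free them.\n",
    "x = torch.zeros(128, 128)      # a tensor for the garbage collector to find\n",
    "dump_tensors(gpu_only=False)   # prints tracked tensors and total element count\n",
    "del x\n",
    "gc.collect()\n",
    "if torch.cuda.is_available():\n",
    "    torch.cuda.empty_cache()   # release cached GPU blocks between runs"
   ]
  },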
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1029b14a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# layerwise training of procogan \n",
    "def layerwise(truncation_indices, root,\n",
    "              attribute_path=None, attribute_index=20, \n",
    "              init_resolution=4, final_resolution=128, \n",
    "              n_validation=1000, training_samples=50000, \n",
    "              training_gen_samples=50000, d_latent=48, \n",
    "              upsample=True, plot_histogram=False, \n",
    "              bias=True, match_hist=True, \n",
    "              cropped=False, use_attr=False, \n",
    "              not_attr=False, fit_upsampled=False,\n",
    "              avg=False, save_img=True, skip_connection=False, device='cpu'):\n",
    "    \n",
    "    \"\"\"Implementation of layerwise training of ProCoGAN. Requires truncation_indices, which determines how much\n",
    "    to truncate at each stage (i.e. how many singular values of the true data to keep), and root, which represents\n",
    "    the root directory of the CelebA dataset. If one desires, one can filter images by a particular attribute\n",
    "    (e.g. facial hair) if attribute_path, attribute_index, and use_attr are specified. May use GPU or CPU. \n",
    "    \"\"\"\n",
    "    \n",
    "    assert d_latent >= 3/truncation_indices[0]*(init_resolution)**2, \"latent dimension must be greater than truncation dimension\"\n",
    "    start_all = time.time()\n",
    "    \n",
    "    if fit_upsampled:\n",
    "        upsample=False\n",
    "    \n",
    "    print('loading data')\n",
    "    if use_attr and attribute_path is not None:\n",
    "        attr_data = np.loadtxt(attribute_path, skiprows=2, dtype='str')\n",
    "        fnames = attr_data[:, 0]\n",
    "        if not_attr:\n",
    "             attrs = (attr_data[:, attribute_index+1].astype(np.int32) - 1).astype('bool')\n",
    "        else:\n",
    "            attrs = (attr_data[:, attribute_index+1].astype(np.int32) + 1).astype('bool')\n",
    "        relevant_images = set(fnames[attrs])\n",
    "\n",
    "        images_path =np.array([ os.path.join(root, item)  for item in os.listdir(root) if item in relevant_images])\n",
    "        print(len(images_path), 'images with desired attribute')\n",
    "    else:\n",
    "        images_path =np.array([ os.path.join(root, item)  for item in os.listdir(root)])\n",
    "\n",
    "    weights = []\n",
    "    num_blocks = int(np.log2(final_resolution//init_resolution)) + 1\n",
    "    \n",
    "    curr_images_path = images_path[:training_samples]\n",
    "    lazy_arrays = [composition(fn, final_resolution, cropped=cropped) for fn in curr_images_path]\n",
    "    \n",
    "    curr_res = init_resolution\n",
    "    \n",
    "    for i in range(num_blocks):\n",
    "        with torch.no_grad():\n",
    "            print('res', curr_res, 'reshaping data')\n",
    "\n",
    "            if fit_upsampled:\n",
    "                curr_arrs = [np.array(im1.resize((curr_res, curr_res)).resize((final_resolution, final_resolution), resample=0), \n",
    "                                      dtype='uint8') for im1 in lazy_arrays]\n",
    "            else:\n",
    "                curr_arrs = [np.array(im1.resize((curr_res, curr_res)), dtype='uint8') for im1 in lazy_arrays]\n",
    "            image_data = np.stack(curr_arrs)\n",
    "            celebA= image_data/255\n",
    "            \n",
    "            celeb_a_means = np.mean(celebA, axis=0)\n",
    "            celeb_a_centered = celebA - celeb_a_means\n",
    "\n",
    "            celeb_a_centered = celeb_a_centered.reshape((celeb_a_centered.shape[0], -1))\n",
    "\n",
    "            d = celeb_a_centered.shape[1]\n",
    "\n",
    "            truncation_idx = truncation_indices[i] # keep fraction of singular values for this experiment at each block\n",
    "\n",
    "            print('computing SVD of inputs')\n",
    "            Z = np.random.randn(training_gen_samples, d_latent)\n",
    "            Z_hat = Z\n",
    "            tmp_im_size = init_resolution\n",
    "\n",
    "            if bias:\n",
    "                Z_hat = np.concatenate((Z_hat, np.ones((len(Z_hat), 1))), 1)\n",
    "\n",
    "            # forward pass through the network\n",
    "            for j, w in enumerate(weights):\n",
    "                if device=='cuda':\n",
    "                    Z_hat = torch.from_numpy(Z_hat).float().cuda()\n",
    "                    w = torch.from_numpy(w).float().cuda()\n",
    "\n",
    "                if skip_connection and j > 0:\n",
    "                    Z_hat_prev = Z_hat\n",
    "                    Z_hat =Z_hat @ w \n",
    "                    Z_hat += Z_hat_prev[:, :Z_hat.shape[1]]\n",
    "                else: \n",
    "                    Z_hat = Z_hat @ w\n",
    "\n",
    "                if device=='cuda':\n",
    "                    Z_hat = Z_hat.detach().cpu().numpy()\n",
    "                    w = w.detach().cpu().numpy()\n",
    "\n",
    "                # upsample\n",
    "                if upsample:\n",
    "                    Z_hat = Z_hat.reshape((Z_hat.shape[0], tmp_im_size, tmp_im_size, 3))\n",
    "                    Z_hat = Z_hat.repeat(2, axis=1).repeat(2, axis=2)\n",
    "                    Z_hat = Z_hat.reshape((Z_hat.shape[0], -1))\n",
    "                    tmp_im_size *= 2\n",
    "\n",
    "                if bias:\n",
    "                    Z_hat = np.concatenate((Z_hat, np.ones((len(Z_hat), 1))), 1)\n",
    "\n",
    "            start = time.time()\n",
    "            if device=='cuda':\n",
    "                \n",
    "                if curr_res < 64:\n",
    "                    _, D, q = torch.svd(torch.from_numpy(Z_hat).float().cuda())\n",
    "                    qt = q.t()\n",
    "                else:\n",
    "                    _, D, q = torch.svd_lowrank(torch.from_numpy(Z_hat).float().cuda(), q=truncation_idx)\n",
    "                    qt = q.t()\n",
    "                del _\n",
    "            else:\n",
    "                _, D, qt = np.linalg.svd(Z_hat)\n",
    "            print('svd took', time.time()- start)\n",
    "\n",
    "            if skip_connection and i > 0:\n",
    "                celeb_a_centered -= Z_hat[:, :d]\n",
    "            \n",
    "            print('computing svd')\n",
    "            start = time.time()\n",
    "            if device=='cuda':\n",
    "                if curr_res < 64:\n",
    "                    _, s, v = torch.svd(torch.from_numpy(celeb_a_centered).float().cuda())\n",
    "                    vt = v.t()\n",
    "                else:\n",
    "                    _, s, v = torch.svd_lowrank(torch.from_numpy(celeb_a_centered).float().cuda(), q=truncation_idx+1)\n",
    "                    vt = v.t()\n",
    "            else:\n",
    "                _, s, vt = np.linalg.svd(celeb_a_centered)\n",
    "            print('svd took', time.time()- start)\n",
    "            \n",
    "            beta = s[truncation_idx]**2\n",
    "            print('beta', beta)\n",
    "\n",
    "            if device=='cuda':\n",
    "                W_g = qt[:truncation_idx, :].t() @ torch.diag(1/D[:truncation_idx]) @ torch.diag(torch.sqrt(s[:truncation_idx]**2-beta)) @ vt[:truncation_idx, :]\n",
    "                W_g = W_g.detach().cpu().numpy()\n",
    "            else:\n",
    "                W_g = qt[:truncation_idx, :].T @np.diag(1/D[:truncation_idx]) @ np.diag(np.sqrt(s[:truncation_idx]**2-beta)) @ vt[:truncation_idx, :]\n",
    "\n",
    "            weights.append(W_g)\n",
    "\n",
    "            print('computing validation instances')\n",
    "            out = np.random.randn(n_validation, d_latent)\n",
    "            if bias:\n",
    "                out = np.concatenate((out, np.ones((len(out), 1))), 1)\n",
    "\n",
    "            tmp_im_size = init_resolution\n",
    "\n",
    "            # forward pass through the network\n",
    "            for j, w in enumerate(weights):\n",
    "                if device=='cuda':\n",
    "                    out = torch.from_numpy(out).float().cuda()\n",
    "                    w = torch.from_numpy(w).float().cuda()\n",
    "\n",
    "                if skip_connection and j > 0:\n",
    "                    out_prev = out\n",
    "                    out = out @ w \n",
    "                    out = out + out_prev[:, :out.shape[1]]\n",
    "                else: \n",
    "                    out = out @ w\n",
    "\n",
    "                if device=='cuda':\n",
    "                    out = out.detach().cpu().numpy()\n",
    "                    w = w.detach().cpu().numpy()\n",
    "\n",
    "                if upsample and  j < len(weights)-1:\n",
    "                    # upsample\n",
    "                    out = out.reshape((out.shape[0], tmp_im_size, tmp_im_size, 3))\n",
    "                    out = out.repeat(2, axis=1).repeat(2, axis=2)\n",
    "                    out= out.reshape((out.shape[0], -1))\n",
    "                    tmp_im_size *= 2\n",
    "\n",
    "                if bias and j < len(weights) - 1:\n",
    "                    out = np.concatenate((out, np.ones((len(out), 1))), 1)\n",
    "\n",
    "            if fit_upsampled:\n",
    "                validation_generated_data = out.reshape((n_validation, final_resolution, final_resolution,3))\n",
    "            else:\n",
    "                validation_generated_data = out.reshape((n_validation, curr_res, curr_res,3))\n",
    "            \n",
    "            validation_generated_data += celeb_a_means\n",
    "\n",
    "            display_img = validation_generated_data\n",
    "\n",
    "            if avg:\n",
    "                display_img = (display_img[:n_validation//2] + display_img[n_validation//2:])/2\n",
    "\n",
    "            if match_hist:\n",
    "                display_img = match_histograms(display_img, celebA, multichannel=True)\n",
    "\n",
    "\n",
    "            print('real data')\n",
    "            fig = plt.figure(figsize=(10., 10.))\n",
    "            grid = ImageGrid(fig, 111,  # similar to subplot(111)\n",
    "                             nrows_ncols=(3, 3),\n",
    "                             axes_pad=0.1,  # pad between axes in inch.\n",
    "                             )\n",
    "\n",
    "            for ax, im in zip(grid, celebA[:9]):\n",
    "                # Iterating over the grid returns the Axes.\n",
    "                ax.imshow(im)\n",
    "\n",
    "            plt.show()\n",
    "\n",
    "            print('generated data')\n",
    "            fig = plt.figure(figsize=(10., 10.))\n",
    "            grid = ImageGrid(fig, 111,  # similar to subplot(111)\n",
    "                             nrows_ncols=(3, 3),\n",
    "                             axes_pad=0.1,  # pad between axes in inch.\n",
    "                             )\n",
    "\n",
    "            for ax, im in zip(grid, display_img[:9]):\n",
    "                # Iterating over the grid returns the Axes.\n",
    "                ax.imshow(im)\n",
    "\n",
    "\n",
    "            if plot_histogram:\n",
    "                print('histogram of generated data samples')\n",
    "                fig = plt.figure()\n",
    "                fig, ((ax0, ax1, ax2), (ax3, ax4, ax5), (ax6, ax7, ax8)) = plt.subplots(nrows=3, ncols=3,\n",
    "                                figsize=(10., 10.)# pad between axes in inch.\n",
    "                                 )\n",
    "\n",
    "                ax0.hist(validation_generated_data[0].flatten())\n",
    "                ax1.hist(validation_generated_data[1].flatten())\n",
    "                ax2.hist(validation_generated_data[2].flatten())\n",
    "                ax3.hist(validation_generated_data[3].flatten())\n",
    "                ax4.hist(validation_generated_data[4].flatten())\n",
    "                ax5.hist(validation_generated_data[5].flatten())\n",
    "                ax6.hist(validation_generated_data[6].flatten())\n",
    "                ax7.hist(validation_generated_data[7].flatten())\n",
    "                ax8.hist(validation_generated_data[8].flatten())\n",
    "            \n",
    "            if save_img:\n",
    "                print('saving images')\n",
    "                dirname = '../generated_images' + str(datetime.now())\n",
    "                os.mkdir(dirname)\n",
    "                for i, im in enumerate(display_img[:50]):\n",
    "                    plt.imsave(os.path.join(dirname, str(i)+'.png'), im)\n",
    "            \n",
    "            plt.show()\n",
    "            curr_res *= 2\n",
    "        \n",
    "    print('total running time', time.time()-start_all)\n",
    "\n",
    "    print('calculating FID score')\n",
    "    fid = calculate_fretchet(celebA,display_img, batch_size=20)\n",
    "             \n",
    "    return validation_generated_data, weights, fid"
   ]
  },
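  {
   "cell_type": "markdown",
   "id": "a7b8c9d0",
   "metadata": {},
   "source": [
    "Each block of `layerwise` is fit in closed form rather than by gradient descent: if the current generator outputs have SVD $Z_{out} = U\\,\\mathrm{diag}(D)\\,Q^T$ and the centered data has SVD $X = U'\\,\\mathrm{diag}(s)\\,V^T$, the new weights keep the top $r$ components:\n",
    "\n",
    "$$W_g = Q_r\\,\\mathrm{diag}(1/D_r)\\,\\mathrm{diag}\\!\\left(\\sqrt{s_r^2 - \\beta}\\right)V_r^T, \\qquad \\beta = s_{r+1}^2.$$\n",
    "\n",
    "The next cell is a minimal single-block sketch of this update on synthetic data (all sizes in it are hypothetical), mirroring the `device='cpu'` branch of `layerwise`; the spectrum of the generated samples should roughly match the $\\beta$-shrunk target spectrum, up to sampling noise."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b8c9d0e1",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Minimal single-block sketch of the closed-form fit in layerwise(), on synthetic\n",
    "# data. All sizes below are hypothetical; mirrors the NumPy (device='cpu') branch.\n",
    "n, d_lat, d_data, r = 2000, 48, 48, 20  # samples, latent dim, data dim, truncation rank\n",
    "\n",
    "Z = np.random.randn(n, d_lat)                     # latents\n",
    "X = np.random.randn(n, d_data)                    # stand-in for centered data\n",
    "X -= X.mean(axis=0)\n",
    "\n",
    "_, D, Qt = np.linalg.svd(Z, full_matrices=False)  # Z = U diag(D) Qt\n",
    "_, s, Vt = np.linalg.svd(X, full_matrices=False)  # X = U' diag(s) Vt\n",
    "beta = s[r]**2                                    # noise level: first discarded singular value\n",
    "\n",
    "# closed-form weight matrix, as in the device='cpu' branch of layerwise()\n",
    "W_g = Qt[:r, :].T @ np.diag(1/D[:r]) @ np.diag(np.sqrt(s[:r]**2 - beta)) @ Vt[:r, :]\n",
    "\n",
    "# fresh latents through the fitted layer; spectra should roughly agree\n",
    "X_hat = np.random.randn(n, d_lat) @ W_g\n",
    "print('generated spectrum (top 5):', np.linalg.svd(X_hat, compute_uv=False)[:5]**2 / n)\n",
    "print('target spectrum    (top 5):', (s[:5]**2 - beta) / n)"
   ]
  },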
  {
   "cell_type": "markdown",
   "id": "e2c59225",
   "metadata": {},
   "source": [
    "# Example Run"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "dc546a7a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Set CelebA directory here\n",
    "root=\"/mnt/raid3/sahiner/img_align_celeba\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d0532498",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Runs 3 times with the lower beta values reported in Figure 3 of the main paper, and saves 50 \n",
    "# representative images from each resolution. Clears GPU memory between each run. \n",
    "\n",
    "runs_low_beta = 3\n",
    "\n",
    "for i in range(runs_low_beta):\n",
    "\n",
    "    if i==0:\n",
    "        saveimg = False\n",
    "    else:\n",
    "        saveimg=False\n",
    "    res= layerwise(truncation_indices=[20, 30, 40, 100, 175],\n",
    "                   root=root, init_resolution=4, \n",
    "                   final_resolution=64, \n",
    "                   training_samples=50000, \n",
    "                   training_gen_samples=50000, d_latent=48, \n",
    "                   upsample=True, plot_histogram=False, \n",
    "                   bias=True, match_hist=True, \n",
    "                   cropped=False, not_attr=True, use_attr=False, \n",
    "                   fit_upsampled=False, avg=False, save_img=saveimg, \n",
    "                   skip_connection=False, device='cuda')\n",
    "    \n",
    "    torch.cuda.empty_cache()\n",
    "    torch.cuda.ipc_collect()\n",
    "    gc.collect()\n",
    "    dump_tensors()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
