{
  "cells": [
    {
      "cell_type": "code",
      "execution_count": 1,
      "id": "x2YPjbNS9rEF",
      "metadata": {
        "id": "x2YPjbNS9rEF"
      },
      "outputs": [],
      "source": [
        "import random\n",
        "import math\n",
        "import numpy as np\n",
        "import torch\n",
        "from sklearn import datasets as sklearn_datasets\n",
        "import torch.nn.functional as F\n",
        "import matplotlib.pyplot as plt\n",
        "from IPython.display import clear_output\n",
        "import torch.nn as nn\n",
        "import einops\n",
        "\n",
        "\n",
        "import os, sys\n",
        "sys.path.append(\"..\")\n",
        "sys.path.append(\"../ALAE\")"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 24,
      "id": "qFyiuAODBQNm",
      "metadata": {
        "id": "qFyiuAODBQNm"
      },
      "outputs": [],
      "source": [
        "#%cd ../ALAE\n",
        "\n",
        "#! sed -i '/bimpy/d' requirements.txt\n",
        "#! sed -i '/dareblopy/d' requirements.txt\n",
        "#! sed -i 's/^sklearn$/scikit-learn/' requirements.txt\n",
        "\n",
        "#! pip install -r requirements.txt\n",
        "#! python training_artifacts/download_all.py\n",
        "#%cd ../notebooks"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 4,
      "id": "k6Bh0BEc9sXq",
      "metadata": {
        "id": "k6Bh0BEc9sXq"
      },
      "outputs": [],
      "source": [
        "class Sampler:\n",
        "    def __init__(\n",
        "        self, device='cuda',\n",
        "    ):\n",
        "        self.device = device\n",
        "\n",
        "    def sample(self, size=5):\n",
        "        pass\n",
        "\n",
        "\n",
        "class TensorSampler(Sampler):\n",
        "    def __init__(self, tensor, device='cuda'):\n",
        "        super(TensorSampler, self).__init__(device)\n",
        "        self.tensor = torch.clone(tensor).to(device)\n",
        "\n",
        "    def sample(self, size=5):\n",
        "        assert size <= self.tensor.shape[0]\n",
        "\n",
        "        ind = torch.tensor(np.random.choice(np.arange(self.tensor.shape[0]), size=size, replace=False), device=self.device)\n",
        "        return torch.clone(self.tensor[ind]).detach().to(self.device)\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 5,
      "id": "mafwQcEw9s_h",
      "metadata": {
        "id": "mafwQcEw9s_h"
      },
      "outputs": [],
      "source": [
        "class EOTConfig:\n",
        "    def __init__(self,\n",
        "                 eps: float = 0.1,\n",
        "                 batch_size: int = 2048,\n",
        "                 device: str =\"cpu\",\n",
        "                 K: int = 32,\n",
        "                 epoch: int = 100,\n",
        "                 lmc_steps: int = 100,\n",
        "                 lmc_step_size: float = 0.003,\n",
        "                 seed: int = 42,\n",
        "                 grad_clip = 100000.0,\n",
        "                 max_diff_exp_clip=100,\n",
        "                 ema_momentum = 0.999\n",
        "        ):\n",
        "        self.device = device\n",
        "        self.eps = eps\n",
        "        self.batch_size = batch_size\n",
        "        self.K = K\n",
        "        self.epoch = epoch\n",
        "        self.score_clip = 10000000.0\n",
        "        self.grad_clip = grad_clip\n",
        "        self.max_diff_exp_clip = max_diff_exp_clip\n",
        "        self.lmc_steps = lmc_steps\n",
        "        self.lmc_step_size = lmc_step_size\n",
        "        self.seed = seed\n",
        "        self.ema_momentum = ema_momentum\n",
        "\n",
        "class MLP(nn.Module):\n",
        "    def __init__(self, din=2, hidden=128, dout=1):\n",
        "        super().__init__()\n",
        "        self.net = nn.Sequential(\n",
        "            nn.Linear(din, hidden),\n",
        "            nn.SiLU(),\n",
        "            nn.Linear(hidden, hidden),\n",
        "            nn.SiLU(),\n",
        "            nn.Linear(hidden, hidden),\n",
        "            nn.SiLU(),\n",
        "            nn.Linear(hidden, dout),\n",
        "        )\n",
        "\n",
        "    def forward(self, x):\n",
        "        return self.net(x)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 6,
      "id": "Lv7pMsik9tzq",
      "metadata": {
        "id": "Lv7pMsik9tzq"
      },
      "outputs": [],
      "source": [
        "DIM = 512\n",
        "INPUT_DATA = \"MAN\"\n",
        "TARGET_DATA = \"WOMAN\"\n",
        "\n",
        "OUTPUT_SEED = 0xBADBEEF\n",
        "EPSILON = 1.0\n",
        "\n",
        "EXP_NAME = f'VarEOT_ALAE_{INPUT_DATA}_TO_{TARGET_DATA}_EPSILON_{EPSILON}'"
      ]
    },
    {
      "cell_type": "markdown",
      "id": "gM-quzUpNRAQ",
      "metadata": {
        "id": "gM-quzUpNRAQ"
      },
      "source": [
        "\n",
        "## Data loading\n",
        "### TO DOWNLOAD PRE-PROCESSED ALAE DATA, UNCOMMENT THE CODE OF THE NEXT CELL.\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 9,
      "id": "cBSSw9Ji-Ccl",
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "cBSSw9Ji-Ccl",
        "outputId": "a8baa00e-aadd-4a3e-9e8e-f5bad8a04e32"
      },
      "outputs": [
        {
          "name": "stderr",
          "output_type": "stream",
          "text": [
            "Downloading...\n",
            "From: https://drive.google.com/uc?id=1Vi6NzxCsS23GBNq48E-97Z9UuIuNaxPJ\n",
            "To: /content/data/age.npy\n",
            "100%|██████████| 560k/560k [00:00<00:00, 7.93MB/s]\n",
            "Downloading...\n",
            "From: https://drive.google.com/uc?id=1SEdsmQGL3mOok1CPTBEfc_O1750fGRtf\n",
            "To: /content/data/gender.npy\n",
            "100%|██████████| 1.68M/1.68M [00:00<00:00, 14.5MB/s]\n",
            "Downloading...\n",
            "From (original): https://drive.google.com/uc?id=1ENhiTRsHtSjIjoRu1xYprcpNd8M9aVu8\n",
            "From (redirected): https://drive.google.com/uc?id=1ENhiTRsHtSjIjoRu1xYprcpNd8M9aVu8&confirm=t&uuid=037faf7d-4c42-424d-907e-b3fb70ff4a61\n",
            "To: /content/data/latents.npy\n",
            "100%|██████████| 143M/143M [00:03<00:00, 39.7MB/s]\n",
            "Downloading...\n",
            "From (original): https://drive.google.com/uc?id=1SjBWWlPjq-dxX4kxzW-Zn3iUR3po8Z0i\n",
            "From (redirected): https://drive.google.com/uc?id=1SjBWWlPjq-dxX4kxzW-Zn3iUR3po8Z0i&confirm=t&uuid=e42adeb9-43e3-479e-9f73-04ba24a2a9bc\n",
            "To: /content/data/test_images.npy\n",
            "100%|██████████| 944M/944M [00:08<00:00, 107MB/s]\n"
          ]
        }
      ],
      "source": [
        "import gdown\n",
        "import os\n",
        "\n",
        "urls = {\n",
        "    \"../data/age.npy\": \"https://drive.google.com/uc?id=1Vi6NzxCsS23GBNq48E-97Z9UuIuNaxPJ\",\n",
        "    \"../data/gender.npy\": \"https://drive.google.com/uc?id=1SEdsmQGL3mOok1CPTBEfc_O1750fGRtf\",\n",
        "    \"../data/latents.npy\": \"https://drive.google.com/uc?id=1ENhiTRsHtSjIjoRu1xYprcpNd8M9aVu8\",\n",
        "    \"../data/test_images.npy\": \"https://drive.google.com/uc?id=1SjBWWlPjq-dxX4kxzW-Zn3iUR3po8Z0i\",\n",
        "}\n",
        "\n",
        "for name, url in urls.items():\n",
        "    gdown.download(url, os.path.join(f\"{name}\"), quiet=False)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 10,
      "id": "tDktiX-Z-NGk",
      "metadata": {
        "id": "tDktiX-Z-NGk"
      },
      "outputs": [],
      "source": [
        "train_size = 60000\n",
        "test_size = 10000\n",
        "\n",
        "latents = np.load(\"../data/latents.npy\")\n",
        "gender = np.load(\"../data/gender.npy\")\n",
        "age = np.load(\"../data/age.npy\")\n",
        "test_inp_images = np.load(\"../data/test_images.npy\")\n",
        "\n",
        "train_latents, test_latents = latents[:train_size], latents[train_size:]\n",
        "train_gender, test_gender = gender[:train_size], gender[train_size:]\n",
        "train_age, test_age = age[:train_size], age[train_size:]\n",
        "\n",
        "if INPUT_DATA == \"MAN\":\n",
        "    x_inds_train = np.arange(train_size)[(train_gender == \"male\").reshape(-1)]\n",
        "    x_inds_test = np.arange(test_size)[(test_gender == \"male\").reshape(-1)]\n",
        "elif INPUT_DATA == \"WOMAN\":\n",
        "    x_inds_train = np.arange(train_size)[(train_gender == \"female\").reshape(-1)]\n",
        "    x_inds_test = np.arange(test_size)[(test_gender == \"female\").reshape(-1)]\n",
        "elif INPUT_DATA == \"ADULT\":\n",
        "    x_inds_train = np.arange(train_size)[\n",
        "        (train_age >= 18).reshape(-1)*(train_age != -1).reshape(-1)\n",
        "    ]\n",
        "    x_inds_test = np.arange(test_size)[\n",
        "        (test_age >= 18).reshape(-1)*(test_age != -1).reshape(-1)\n",
        "    ]\n",
        "elif INPUT_DATA == \"CHILDREN\":\n",
        "    x_inds_train = np.arange(train_size)[\n",
        "        (train_age < 18).reshape(-1)*(train_age != -1).reshape(-1)\n",
        "    ]\n",
        "    x_inds_test = np.arange(test_size)[\n",
        "        (test_age < 18).reshape(-1)*(test_age != -1).reshape(-1)\n",
        "    ]\n",
        "x_data_train = train_latents[x_inds_train]\n",
        "x_data_test = test_latents[x_inds_test]\n",
        "\n",
        "if TARGET_DATA == \"MAN\":\n",
        "    y_inds_train = np.arange(train_size)[(train_gender == \"male\").reshape(-1)]\n",
        "    y_inds_test = np.arange(test_size)[(test_gender == \"male\").reshape(-1)]\n",
        "elif TARGET_DATA == \"WOMAN\":\n",
        "    y_inds_train = np.arange(train_size)[(train_gender == \"female\").reshape(-1)]\n",
        "    y_inds_test = np.arange(test_size)[(test_gender == \"female\").reshape(-1)]\n",
        "elif TARGET_DATA == \"ADULT\":\n",
        "    y_inds_train = np.arange(train_size)[\n",
        "        (train_age >= 18).reshape(-1)*(train_age != -1).reshape(-1)\n",
        "    ]\n",
        "    y_inds_test = np.arange(test_size)[\n",
        "        (test_age >= 18).reshape(-1)*(test_age != -1).reshape(-1)\n",
        "    ]\n",
        "elif TARGET_DATA == \"CHILDREN\":\n",
        "    y_inds_train = np.arange(train_size)[\n",
        "        (train_age < 18).reshape(-1)*(train_age != -1).reshape(-1)\n",
        "    ]\n",
        "    y_inds_test = np.arange(test_size)[\n",
        "        (test_age < 18).reshape(-1)*(test_age != -1).reshape(-1)\n",
        "    ]\n",
        "y_data_train = train_latents[y_inds_train]\n",
        "y_data_test = test_latents[y_inds_test]\n",
        "\n",
        "X_train = torch.tensor(x_data_train)\n",
        "Y_train = torch.tensor(y_data_train)\n",
        "\n",
        "X_test = torch.tensor(x_data_test)\n",
        "Y_test = torch.tensor(y_data_test)\n",
        "\n",
        "X_sampler = TensorSampler(X_train, device=\"cpu\")\n",
        "Y_sampler = TensorSampler(Y_train, device=\"cpu\")"
      ]
    },
    {
      "cell_type": "markdown",
      "id": "rbplPxi2-oMS",
      "metadata": {
        "id": "rbplPxi2-oMS"
      },
      "source": [
        "# Model initialisation\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 12,
      "id": "4NFSOdYv-pSX",
      "metadata": {
        "id": "4NFSOdYv-pSX"
      },
      "outputs": [],
      "source": [
        "class EOTTrainer:\n",
        "    def __init__(self, config, source_sampler, target_sampler, model_theta, model_phi, name):\n",
        "        self.experiment_name = name\n",
        "\n",
        "        self.config = config\n",
        "        self.source_sampler = source_sampler\n",
        "        self.target_sampler = target_sampler\n",
        "\n",
        "        self.f_theta = model_theta\n",
        "        self.f_phi = model_phi\n",
        "        import copy\n",
        "        self.f_theta_ema = copy.deepcopy(self.f_theta).eval()\n",
        "        for p in self.f_theta_ema.parameters(): p.requires_grad_(False)\n",
        "        self.current_step = 0\n",
        "\n",
        "    def ema_update(self, model, ema):\n",
        "        m = self.config.ema_momentum\n",
        "        with torch.no_grad():\n",
        "            for p, pe in zip(model.parameters(), ema.parameters()):\n",
        "                pe.mul_(m).add_(p, alpha=1-m)\n",
        "\n",
        "\n",
        "    def compute_loss(self, x, y):\n",
        "        cfg = self.config\n",
        "        broad_shape = list(x.shape)\n",
        "        broad_shape.insert(1, cfg.K)\n",
        "        z = torch.randn(size=broad_shape, device=cfg.device)\n",
        "\n",
        "        x_noisy = x[:, None, :] - math.sqrt(cfg.eps(self.current_step)) * z\n",
        "\n",
        "        fphi_x = self.f_phi(x)\n",
        "        ftheta_xnoisy = self.f_theta(x_noisy.reshape(-1, *x_noisy.shape[2:])).view(cfg.batch_size, cfg.K) / cfg.eps(self.current_step)\n",
        "        ftheta_y = self.f_theta(y)\n",
        "\n",
        "        diff_in_exp = ftheta_xnoisy - fphi_x\n",
        "        diff_in_exp_truncated = torch.clamp(diff_in_exp, min=None, max=self.config.max_diff_exp_clip)\n",
        "        exp_term = torch.exp(diff_in_exp_truncated.to(torch.float64))\n",
        "\n",
        "        ftheta_y_mean = ftheta_y.mean()\n",
        "        fphi_x_mean = fphi_x.mean()\n",
        "        Loss_theor = (fphi_x_mean + exp_term.mean()) * cfg.eps(self.current_step) - ftheta_y_mean\n",
        "        L_main = Loss_theor\n",
        "        return L_main\n",
        "\n",
        "    def train_step(self):\n",
        "        x = self.source_sampler.sample(self.config.batch_size).to(self.config.device)\n",
        "        y = self.target_sampler.sample(self.config.batch_size).to(self.config.device)\n",
        "        self.train_theta = True\n",
        "        self.train_phi = True\n",
        "        loss = self.compute_loss(x, y)\n",
        "        self.opt_both.zero_grad(set_to_none=True)\n",
        "        loss.backward()\n",
        "        torch.nn.utils.clip_grad_norm_(self.f_theta.parameters(), max_norm=self.config.grad_clip)\n",
        "        torch.nn.utils.clip_grad_norm_(self.f_phi.parameters(), max_norm=self.config.grad_clip)\n",
        "        self.opt_both.step()\n",
        "        self.ema_update(self.f_theta, self.f_theta_ema)\n",
        "        self.current_step += 1\n",
        "\n",
        "    def train(self, viz_callback=None):\n",
        "        print(f\"Starting training name = {self.experiment_name}\")\n",
        "        print(\"-\" * 60)\n",
        "\n",
        "        while True:\n",
        "            L_main = self.train_step()\n",
        "            viz_callback(self)\n",
        "            if (self.current_step >= self.config.epoch):\n",
        "                break\n",
        "\n",
        "        print(\"Training complete!\")\n",
        "        return 0\n",
        "\n",
        "\n",
        "    def score_y_given_x(self, y, x):\n",
        "        y = y.detach().requires_grad_(True)\n",
        "\n",
        "        ft = self.f_theta_ema(y).sum()\n",
        "        (gy,) = torch.autograd.grad(ft, y, retain_graph=False, create_graph=False)\n",
        "        sc = (gy - (y - x)) / self.config.eps(self.current_step)\n",
        "\n",
        "        return torch.clamp(sc, -self.config.score_clip, self.config.score_clip)\n",
        "\n",
        "    def sample_pi_given_x(self, x, n=200):\n",
        "        y = x.unsqueeze(1) + 0.0 * (self.config.eps(self.current_step) ** 0.5) * torch.randn(x.shape[0], n, x.shape[1], device=self.config.device)\n",
        "        for _ in range(self.config.lmc_steps):\n",
        "            y = y.detach()\n",
        "            sc = self.score_y_given_x(y, x.unsqueeze(1))\n",
        "            with torch.no_grad():\n",
        "                y = y + self.config.lmc_step_size * sc + math.sqrt(2*self.config.lmc_step_size) * torch.randn_like(y)\n",
        "        return y.detach()\n",
        "\n",
        "    def sample_marginal(self, n_x=500):\n",
        "        cfg = self.config\n",
        "        xs = self.source_sampler.sample(n_x).to(cfg.device)\n",
        "        y = xs\n",
        "\n",
        "        for _ in range(cfg.lmc_steps):\n",
        "            y = y.detach()\n",
        "            sc = self.score_y_given_x(y, xs)\n",
        "\n",
        "            with torch.no_grad():\n",
        "                y = y + cfg.lmc_step_size * sc + \\\n",
        "                    math.sqrt(2 * cfg.lmc_step_size) * torch.randn_like(y)\n",
        "\n",
        "        return y.detach()\n",
        "\n",
        "\n",
        "    def save_checkpoint(self, filename):\n",
        "        checkpoint = {\n",
        "            'f_theta_state_dict': self.f_theta_ema.state_dict(),\n",
        "            'f_phi_state_dict': self.f_phi.state_dict(),\n",
        "            'epoch': self.current_step,\n",
        "            'eps': self.config.eps(self.current_step),\n",
        "        }\n",
        "        torch.save(checkpoint, filename)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 13,
      "id": "9gL17n3r-u5L",
      "metadata": {
        "id": "9gL17n3r-u5L"
      },
      "outputs": [],
      "source": [
        "config = EOTConfig(\n",
        "     eps = lambda step: EPSILON,\n",
        "     batch_size = 256,\n",
        "     device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\"),\n",
        "     K = 256,\n",
        "     epoch = 5000,\n",
        "     lmc_steps = 1000,\n",
        "     lmc_step_size = 0.001,\n",
        "     seed = 42,\n",
        "     ema_momentum=0.999,\n",
        "     max_diff_exp_clip=30,\n",
        "     grad_clip=1e20\n",
        ")"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 14,
      "id": "8yTvO9f5-xSI",
      "metadata": {
        "id": "8yTvO9f5-xSI"
      },
      "outputs": [],
      "source": [
        "def visualize_training(trainer):\n",
        "    if (trainer.current_step % 1000 == 0):\n",
        "        print(trainer.current_step)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 15,
      "id": "cFSrP6gR-yNB",
      "metadata": {
        "id": "cFSrP6gR-yNB"
      },
      "outputs": [],
      "source": [
        "trainer = EOTTrainer(\n",
        "    config=config,\n",
        "    source_sampler=X_sampler,\n",
        "    target_sampler=Y_sampler,\n",
        "    model_theta=MLP(din=DIM, hidden=256).to(config.device),\n",
        "    model_phi=MLP(din=DIM, hidden=256).to(config.device),\n",
        "    name = f\"alae_test\"\n",
        ")\n",
        "\n",
        "trainer.opt_both = torch.optim.AdamW(\n",
        "            list(trainer.f_phi.parameters()) + list(trainer.f_theta.parameters()),\n",
        "            lr=3e-4,\n",
        "            weight_decay=1e-4,\n",
        "            betas = (0.7, 0.8)\n",
        "        )"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 16,
      "id": "vsqxP_YY_C0h",
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "vsqxP_YY_C0h",
        "outputId": "c87fe4fe-7ce7-4356-e16e-76bb0bd9daf2"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "Starting training name = alae_test\n",
            "------------------------------------------------------------\n",
            "1000\n",
            "2000\n",
            "3000\n",
            "4000\n",
            "5000\n",
            "Training complete!\n"
          ]
        },
        {
          "data": {
            "text/plain": [
              "0"
            ]
          },
          "execution_count": 16,
          "metadata": {},
          "output_type": "execute_result"
        }
      ],
      "source": [
        "trainer.train(viz_callback=lambda t: visualize_training(t))"
      ]
    },
    {
      "cell_type": "markdown",
      "id": "qwuj55n5OrvF",
      "metadata": {
        "id": "qwuj55n5OrvF"
      },
      "source": [
        "# Results plotting\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 17,
      "id": "51TRam1JAv_p",
      "metadata": {
        "id": "51TRam1JAv_p"
      },
      "outputs": [],
      "source": [
        "from tracker import RunningMeanTorch\n",
        "torch.serialization.add_safe_globals([RunningMeanTorch])\n",
        "from alae_ffhq_inference import load_model, encode, decode\n",
        "\n",
        "model = load_model(\"../ALAE/configs/ffhq.yaml\", training_artifacts_dir=\"../ALAE/training_artifacts/ffhq/\").cuda()"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 20,
      "id": "5bHIXwrZ_Kge",
      "metadata": {
        "id": "5bHIXwrZ_Kge"
      },
      "outputs": [],
      "source": [
        "N=7\n",
        "repeat=3\n",
        "torch.manual_seed(OUTPUT_SEED); np.random.seed(OUTPUT_SEED)\n",
        "inds_to_map = np.random.choice(np.arange((x_inds_test < 300).sum()), size=N, replace=False)\n",
        "mapped_all = []\n",
        "latent_to_map = torch.tensor(test_latents[x_inds_test[inds_to_map]])\n",
        "mapped = trainer.sample_pi_given_x(latent_to_map.to(trainer.config.device), n=repeat).to(trainer.config.device)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 21,
      "id": "RI2CKuqgAf42",
      "metadata": {
        "id": "RI2CKuqgAf42"
      },
      "outputs": [],
      "source": [
        "ref_plus_gen = torch.cat([latent_to_map.to(trainer.config.device).unsqueeze(1), mapped], dim=1)\n",
        "with torch.no_grad():\n",
        "    decoded_img = decode(model, ref_plus_gen.reshape(-1, 512))\n",
        "    decoded_img = ((decoded_img * 0.5 + 0.5) * 255).type(torch.long).clamp(0, 255).cpu().type(torch.uint8).permute(0, 2, 3, 1).numpy()\n",
        "    decoded_img = decoded_img.reshape(N, repeat + 1, 1024, 1024, 3)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 22,
      "id": "WAX5C-TQMEaZ",
      "metadata": {
        "id": "WAX5C-TQMEaZ"
      },
      "outputs": [],
      "source": [
        "def show_image_grid(images, figsize=(15, 6)):\n",
        "    plt.close('all')\n",
        "    %matplotlib inline\n",
        "    if isinstance(images, torch.Tensor):\n",
        "        images = images.cpu().numpy()\n",
        "\n",
        "    grid = einops.rearrange(images, 'cols rows h w c -> (rows h) (cols w) c')\n",
        "\n",
        "    if grid.max() > 1.0:\n",
        "        grid = grid / 255.0\n",
        "\n",
        "    plt.figure(figsize=figsize)\n",
        "    plt.imshow(grid)\n",
        "    plt.axis('off')\n",
        "    plt.tight_layout()\n",
        "    plt.show()"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "NHTN6H4_MMlk",
      "metadata": {
        "id": "NHTN6H4_MMlk"
      },
      "outputs": [],
      "source": [
        "show_image_grid(decoded_img)"
      ]
    }
  ],
  "metadata": {
    "accelerator": "GPU",
    "colab": {
      "gpuType": "T4",
      "provenance": []
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.8.10"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 5
}
