{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5331ae1d-2c50-423b-83df-d28354a89fc5",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import sys\n",
    "import json\n",
    "import argparse\n",
    "import numpy as np\n",
    "import math\n",
    "from einops import rearrange\n",
    "import time\n",
    "import random\n",
    "import string\n",
    "import h5py\n",
    "from tqdm import tqdm\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "from torchvision import transforms\n",
    "from torch.utils.data import Dataset, DataLoader\n",
    "\n",
    "# SDXL unCLIP requires code from https://github.com/Stability-AI/generative-models/tree/main\n",
    "sys.path.append('generative_models/')\n",
    "import sgm\n",
    "from generative_models.sgm.modules.encoders.modules import FrozenOpenCLIPImageEmbedder, FrozenOpenCLIPEmbedder2\n",
    "from generative_models.sgm.models.diffusion import DiffusionEngine\n",
    "from generative_models.sgm.util import append_dims\n",
    "from omegaconf import OmegaConf\n",
    "\n",
    "# tf32 data type is faster than standard float32\n",
    "torch.backends.cuda.matmul.allow_tf32 = True\n",
    "\n",
    "# custom functions #\n",
    "import v2_utils as utils\n",
    "# NOTE(review): utils is pinned to cuda:0 even though `device` below falls back to CPU - confirm CPU runs are unsupported\n",
    "utils.set_device('cuda:0')\n",
    "# presumably provides PriorNetwork / BrainDiffusionPrior used further down - confirm\n",
    "from v2_models import *\n",
    "from dataloaders import ImageVoxelDataset\n",
    "device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n",
    "print(\"device:\",device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "adf01525-5832-4676-b677-207c1bc90196",
   "metadata": {},
   "outputs": [],
   "source": [
    "# ---- configuration: everything a reader might tune ----\n",
    "data_path='data/new_dl'\n",
    "cache_dir='checkpoints'\n",
    "model_name='subj1_l_bclip_basictest_wll_5_finetune_s7_250_aamax'\n",
    "checkpoint_name = model_name\n",
    "checkpoint_dir = os.path.abspath(f'train_logs/{checkpoint_name}')\n",
    "checkpoint_tag = 'mid_200'  # loads {checkpoint_dir}/{checkpoint_tag}.pth\n",
    "subj=7\n",
    "hidden_dim=4096\n",
    "n_blocks=4\n",
    "blurry_recon=True\n",
    "seed=42\n",
    "split_id=1\n",
    "batch_size=1\n",
    "# token grid of the OpenCLIP ViT-bigG-14 image embedder built below\n",
    "clip_seq_dim = 256\n",
    "clip_emb_dim = 1664\n",
    "# one entry per subject head of the ridge layer\n",
    "num_voxels_list = [15724,12682]\n",
    "plotting=True\n",
    "use_ridge=True\n",
    "test_gt=False #Set this to true if you want to run recons on ground truth clip embeddings to evaluate baseline unclip performance\n",
    "data_type=torch.float16  # autocast dtype used during inference\n",
    "\n",
    "# seed every RNG the notebook touches so diffusion sampling is reproducible\n",
    "random.seed(seed)\n",
    "np.random.seed(seed)\n",
    "torch.manual_seed(seed)  # also seeds all CUDA devices"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4bb853e4-1115-474e-8c77-2682f0f9f205",
   "metadata": {},
   "outputs": [],
   "source": [
    "# list the checkpoints available for this model (choose checkpoint_tag from these)\n",
    "!ls $checkpoint_dir"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d394d6aa-fa6e-4d69-b974-59aaee4d44f6",
   "metadata": {},
   "outputs": [],
   "source": [
    "# split_id is taken from the config cell (not re-assigned here)\n",
    "# split 1 maps to an empty suffix; any other split id is used as-is\n",
    "# NOTE(review): split_val is not referenced elsewhere in this notebook\n",
    "split_val = \"\" if split_id == 1 else split_id"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1ed1dce8-cdcb-4111-9c06-9855b0bf7273",
   "metadata": {},
   "outputs": [],
   "source": [
    "# NOTE(review): the save cell writes recons/{model_name}_*.pt into the recons/ root, not into this subdirectory\n",
    "os.makedirs(f\"recons/{model_name}\",exist_ok=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "47eef2d6-936d-4cbc-8ab2-c9a561b88716",
   "metadata": {},
   "source": [
    "### Build Dataloader"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1f65093d-08e5-4adb-ae2f-94ec054f0f50",
   "metadata": {},
   "outputs": [],
   "source": [
    "#Test DL\n",
    "# test-split betas and images for this subject, rooted at data_path from the config cell\n",
    "split_dir = f'{data_path}/subj{subj:02d}/train/custom_split/test'\n",
    "feature_file = f'{split_dir}/betas.pt'\n",
    "image_file = f'{split_dir}/images.pt'\n",
    "\n",
    "test_dataset = ImageVoxelDataset(feature_file, image_file)\n",
    "# shuffle=False keeps the order of saved recons/images deterministic across runs\n",
    "test_dl = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=0)\n",
    "print(len(test_dataset))\n",
    "voxel,image = next(iter(test_dl))\n",
    "print(voxel.shape, image.shape)\n",
    "print(image.min(),image.max())"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a5276932-6e92-4973-8587-111817bc1d63",
   "metadata": {},
   "source": [
    "## Build Model"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "cffaef31-7e5f-41c8-95e7-8dd20a5690d8",
   "metadata": {},
   "source": [
    "### Clipper"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "552654ea-9d6c-4793-a5d0-7520e885f469",
   "metadata": {},
   "outputs": [],
   "source": [
    "# OpenCLIP ViT-bigG-14 image embedder (used for the test_gt ground-truth baseline)\n",
    "bigG_kwargs = dict(\n",
    "    arch=\"ViT-bigG-14\",\n",
    "    version=\"laion2b_s39b_b160k\",\n",
    "    output_tokens=True,\n",
    "    only_tokens=True,\n",
    ")\n",
    "clip_img_embedder = FrozenOpenCLIPImageEmbedder(**bigG_kwargs)\n",
    "clip_img_embedder.to(device)\n",
    "# token grid produced by this embedder (same values as the config cell)\n",
    "clip_seq_dim, clip_emb_dim = 256, 1664"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6cad11c5-fb7d-4d6f-b35d-ff0df0d4c5d2",
   "metadata": {},
   "source": [
    "### Low Level"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b9a8536b-31ba-406f-9c56-d372c3260a38",
   "metadata": {},
   "outputs": [],
   "source": [
    "if blurry_recon:\n",
    "    # VAE (weights from sd_image_var_autoenc.pth) used to decode the blurry latents predicted from voxels\n",
    "    from diffusers import AutoencoderKL\n",
    "    autoenc = AutoencoderKL(\n",
    "        down_block_types=['DownEncoderBlock2D', 'DownEncoderBlock2D', 'DownEncoderBlock2D', 'DownEncoderBlock2D'],\n",
    "        up_block_types=['UpDecoderBlock2D', 'UpDecoderBlock2D', 'UpDecoderBlock2D', 'UpDecoderBlock2D'],\n",
    "        block_out_channels=[128, 256, 512, 512],\n",
    "        layers_per_block=2,\n",
    "        sample_size=256,\n",
    "    )\n",
    "    # load on CPU first (like the other checkpoints in this notebook), then move the module to device\n",
    "    ckpt = torch.load(f'{cache_dir}/sd_image_var_autoenc.pth', map_location='cpu')\n",
    "    autoenc.load_state_dict(ckpt)\n",
    "    del ckpt  # free the CPU copy once the weights are in the module\n",
    "    autoenc.eval()\n",
    "    autoenc.requires_grad_(False)\n",
    "    autoenc.to(device)\n",
    "    utils.count_params(autoenc)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "aea06b6f-4106-4a34-9370-cd4d467c3563",
   "metadata": {},
   "source": [
    "### High Level"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "54409f02-f574-4fbf-805a-4868e27f982c",
   "metadata": {},
   "outputs": [],
   "source": [
    "class fMRIModule(nn.Module):\n",
    "    \"\"\"Identity container; ridge, voxel2clip and diffusion_prior are attached as attributes below.\"\"\"\n",
    "\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "\n",
    "    def forward(self, x):\n",
    "        # not called in this notebook; the attached submodules are invoked directly\n",
    "        return x\n",
    "\n",
    "model = fMRIModule()\n",
    "model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "622e999f-5a3e-4355-bb95-5fc4224f0c13",
   "metadata": {},
   "outputs": [],
   "source": [
    "class RidgeRegression(torch.nn.Module):\n",
    "    \"\"\"Per-subject linear maps from voxels into a shared hidden space.\n",
    "\n",
    "    Holds one Linear(input_size -> out_features) per entry of input_sizes; forward applies\n",
    "    the map selected by subj_idx and inserts a singleton axis at dim 1.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, input_sizes, out_features):\n",
    "        super().__init__()\n",
    "        self.out_features = out_features\n",
    "        # each map stays wrapped in nn.Sequential so parameters are named linears.<k>.0.*;\n",
    "        # the strict=True checkpoint load later relies on these names\n",
    "        per_subject = []\n",
    "        for n_in in input_sizes:\n",
    "            per_subject.append(nn.Sequential(torch.nn.Linear(n_in, out_features)))\n",
    "        self.linears = torch.nn.ModuleList(per_subject)\n",
    "\n",
    "    def forward(self, x, subj_idx):\n",
    "        projected = self.linears[subj_idx](x)\n",
    "        return projected.unsqueeze(1)\n",
    "\n",
    "model.ridge = RidgeRegression(num_voxels_list, out_features=hidden_dim).to(device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "532faa3c-eef5-4121-9971-f89db63c2924",
   "metadata": {},
   "outputs": [],
   "source": [
    "#Sanity check: push a random batch of 2 voxel vectors for subject slot 0 through the ridge layer\n",
    "probe = torch.randn((2,num_voxels_list[0])).to(device)\n",
    "probe_shape = probe.shape\n",
    "print(probe_shape)\n",
    "print(probe_shape, model.ridge(probe, 0).shape, probe[:, 0].shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0c49ad78-5a61-4eaf-846f-3d8af11e1154",
   "metadata": {},
   "outputs": [],
   "source": [
    "from v2_models import BrainNetwork\n",
    "# seq_len=1 matches the singleton axis RidgeRegression.forward inserts;\n",
    "# in_dim can be anything if you don't care about the lin0 module\n",
    "voxel2clip_kwargs = dict(in_dim=hidden_dim, seq_len=1, h=hidden_dim, out_dim=clip_emb_dim*clip_seq_dim,\n",
    "                         clip_size=clip_emb_dim, clip_scale=1, blurry_recon=blurry_recon,\n",
    "                         n_blocks=n_blocks)  # depth taken from the config cell\n",
    "voxel2clip = BrainNetwork(**voxel2clip_kwargs)\n",
    "model.voxel2clip = voxel2clip"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "078afa7c-13c4-4aa5-8555-78b2c5066789",
   "metadata": {},
   "outputs": [],
   "source": [
    "# parameter count of the voxel -> CLIP-token backbone\n",
    "utils.count_params(model.voxel2clip)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "555d5418-a861-4b7d-b5ee-340234ecaa85",
   "metadata": {},
   "outputs": [],
   "source": [
    "from v2_models import PriorNetwork, BrainDiffusionPrior\n",
    "\n",
    "out_dim = clip_emb_dim\n",
    "depth = 6\n",
    "dim_head = 52\n",
    "# derive the head count from dim_head so heads * dim_head == clip_emb_dim always holds\n",
    "heads = clip_emb_dim // dim_head\n",
    "assert heads * dim_head == clip_emb_dim, \"clip_emb_dim must be divisible by dim_head\"\n",
    "timesteps = 100\n",
    "\n",
    "prior_network = PriorNetwork(\n",
    "        dim=out_dim,\n",
    "        depth=depth,\n",
    "        dim_head=dim_head,\n",
    "        heads=heads,\n",
    "        causal=False,\n",
    "        num_tokens = clip_seq_dim,\n",
    "        learned_query_mode=\"pos_emb\"\n",
    "    )\n",
    "\n",
    "model.diffusion_prior = BrainDiffusionPrior(\n",
    "    net=prior_network,\n",
    "    image_embed_dim=out_dim,\n",
    "    condition_on_text_encodings=False,\n",
    "    timesteps=timesteps,\n",
    "    cond_drop_prob=0.2,\n",
    "    image_embed_scale=None,\n",
    ")\n",
    "\n",
    "utils.count_params(model.diffusion_prior)\n",
    "utils.count_params(model)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7619d08d-bd64-41ff-9b61-827dc1db5eb2",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# move every attached submodule (ridge, voxel2clip, diffusion_prior) to the target device\n",
    "model.to(device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1ce6ef4a-d053-46f5-8e56-f6edb2a456a3",
   "metadata": {},
   "outputs": [],
   "source": [
    "checkpoint_path = f'{checkpoint_dir}/{checkpoint_tag}.pth'\n",
    "print(f\"Loading {checkpoint_path}\")\n",
    "# map to CPU first; load_state_dict copies the tensors onto the model's device\n",
    "checkpoint = torch.load(checkpoint_path, map_location='cpu')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3b2951e2-f6e7-4233-8da8-cd0560cca4ad",
   "metadata": {},
   "outputs": [],
   "source": [
    "# strict=True: every checkpoint key must match a parameter of `model` and vice versa,\n",
    "# so the ridge / voxel2clip / diffusion_prior architectures above must match the trained ones\n",
    "state_dict = checkpoint['model_state_dict']\n",
    "model.load_state_dict(state_dict, strict=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "cbefa5cc-1a20-4dd3-9d09-769c747d3cba",
   "metadata": {},
   "source": [
    "### CLIP caption generator"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b29a5536-cebc-4ec7-9759-fe538ac27034",
   "metadata": {},
   "outputs": [],
   "source": [
    "# setup text caption networks\n",
    "from transformers import AutoProcessor, AutoModelForCausalLM\n",
    "from modeling_git import GitForCausalLMClipEmb\n",
    "git_ckpt = \"microsoft/git-large-coco\"\n",
    "processor = AutoProcessor.from_pretrained(git_ckpt)\n",
    "clip_text_model = GitForCausalLMClipEmb.from_pretrained(git_ckpt)\n",
    "# NOTE(review): on OOM this model could live on CPU, but its inputs (pred_caption_emb) would then need moving too\n",
    "clip_text_model.to(device)\n",
    "clip_text_model.eval().requires_grad_(False)\n",
    "# token grid the caption model consumes (checkpoint name suggests a bigG -> ViT-L mapping)\n",
    "clip_text_seq_dim = 257\n",
    "clip_text_emb_dim = 1024\n",
    "\n",
    "class CLIPConverter(torch.nn.Module):\n",
    "    \"\"\"Maps (batch, clip_seq_dim, clip_emb_dim) tokens to (batch, clip_text_seq_dim, clip_text_emb_dim).\"\"\"\n",
    "\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "        # resamples along the token axis\n",
    "        self.linear1 = nn.Linear(clip_seq_dim, clip_text_seq_dim)\n",
    "        # projects along the embedding axis\n",
    "        self.linear2 = nn.Linear(clip_emb_dim, clip_text_emb_dim)\n",
    "\n",
    "    def forward(self, x):\n",
    "        tokens_last = x.permute(0, 2, 1)\n",
    "        resampled = self.linear1(tokens_last).permute(0, 2, 1)\n",
    "        return self.linear2(resampled)\n",
    "\n",
    "clip_convert = CLIPConverter()\n",
    "state_dict = torch.load(f\"{cache_dir}/bigG_to_L_epoch8.pth\", map_location='cpu')['model_state_dict']\n",
    "clip_convert.load_state_dict(state_dict, strict=True)\n",
    "clip_convert.to(device)\n",
    "del state_dict"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "146af552-840c-4ac6-9b53-faf59edb7544",
   "metadata": {},
   "source": [
    "### unCLIP"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3b226e43-cead-4dcc-a529-2d85e62cf824",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# prep unCLIP\n",
    "config = OmegaConf.load(\"generative_models/configs/unclip6.yaml\")\n",
    "config = OmegaConf.to_container(config, resolve=True)\n",
    "unclip_params = config[\"model\"][\"params\"]\n",
    "network_config = unclip_params[\"network_config\"]\n",
    "denoiser_config = unclip_params[\"denoiser_config\"]\n",
    "first_stage_config = unclip_params[\"first_stage_config\"]\n",
    "conditioner_config = unclip_params[\"conditioner_config\"]\n",
    "sampler_config = unclip_params[\"sampler_config\"]\n",
    "scale_factor = unclip_params[\"scale_factor\"]\n",
    "disable_first_stage_autocast = unclip_params[\"disable_first_stage_autocast\"]\n",
    "# NOTE(review): offset_noise_level is read but not passed to anything in this notebook\n",
    "offset_noise_level = unclip_params[\"loss_fn_config\"][\"params\"][\"offset_noise_level\"]\n",
    "\n",
    "first_stage_config['target'] = 'sgm.models.autoencoder.AutoencoderKL'\n",
    "# sampler steps used for every reconstruction (overrides any value from the YAML)\n",
    "sampler_config['params']['num_steps'] = 38\n",
    "\n",
    "diffusion_engine = DiffusionEngine(network_config=network_config,\n",
    "                       denoiser_config=denoiser_config,\n",
    "                       first_stage_config=first_stage_config,\n",
    "                       conditioner_config=conditioner_config,\n",
    "                       sampler_config=sampler_config,\n",
    "                       scale_factor=scale_factor,\n",
    "                       disable_first_stage_autocast=disable_first_stage_autocast)\n",
    "# set to inference\n",
    "diffusion_engine.eval().requires_grad_(False)\n",
    "diffusion_engine.to(device)\n",
    "\n",
    "ckpt_path = f'{cache_dir}/unclip6_epoch0_step110000.ckpt'\n",
    "ckpt = torch.load(ckpt_path, map_location='cpu')\n",
    "diffusion_engine.load_state_dict(ckpt['state_dict'])\n",
    "# drop the CPU copy of the checkpoint now that its weights live in the engine\n",
    "del ckpt\n",
    "\n",
    "# the conditioner's 'vector' output is computed once from a placeholder batch and reused as vector_suffix for every recon\n",
    "batch={\"jpg\": torch.randn(1,3,1,1).to(device), # jpg doesnt get used, it's just a placeholder\n",
    "      \"original_size_as_tuple\": torch.ones(1, 2).to(device) * 768,\n",
    "      \"crop_coords_top_left\": torch.zeros(1, 2).to(device)}\n",
    "out = diffusion_engine.conditioner(batch)\n",
    "vector_suffix = out[\"vector\"].to(device)\n",
    "print(\"vector_suffix\", vector_suffix.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c9c25a41-dd14-486a-97ad-c051c85c326b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# parameter count of the unCLIP diffusion engine\n",
    "utils.count_params(diffusion_engine)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "72df7da5-638a-43fe-a3ec-1ea666ca60fa",
   "metadata": {},
   "source": [
    "## Generate Unrefined Reconstructions"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "46a29f3e-9f2e-4298-abd8-ec900f6f1400",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# inference mode for every attached submodule: eval() behavior and no gradient tracking\n",
    "model.eval().requires_grad_(False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2637cb7d-f6d7-4449-8282-8b6b9060165d",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "all_images = []\n",
    "all_blurryrecons = None\n",
    "all_recons = None\n",
    "all_predcaptions = []\n",
    "all_clipvoxels = None\n",
    "num_samples_per_image = 1\n",
    "assert num_samples_per_image == 1\n",
    "# ridge head to use: index into num_voxels_list (slot 1 = 12682 voxels in the config cell)\n",
    "ridge_subj_idx = 1\n",
    "\n",
    "if utils.is_interactive(): plotting=True\n",
    "# torch.autocast(\"cuda\", ...) is the non-deprecated equivalent of torch.cuda.amp.autocast(...)\n",
    "with torch.no_grad(), torch.autocast(\"cuda\", dtype=data_type):\n",
    "    for test_i, (voxel, img) in enumerate(tqdm(test_dl)):\n",
    "        # NOTE(review): averages over axis 1, presumably repeated presentations - confirm against ImageVoxelDataset\n",
    "        voxel = torch.mean(voxel,axis=1).to(device)\n",
    "        all_images.append(img)\n",
    "        if use_ridge:\n",
    "            voxel_ridge = model.ridge(voxel, ridge_subj_idx)\n",
    "        else:\n",
    "            voxel_ridge = voxel\n",
    "        backbone, clip_voxels, blurry_image_enc = model.voxel2clip(voxel_ridge)\n",
    "        blurry_image_enc = blurry_image_enc[0]\n",
    "\n",
    "        # Feed voxels through OpenCLIP-bigG diffusion prior\n",
    "        prior_out = model.diffusion_prior.p_sample_loop(backbone.shape, \n",
    "                        text_cond = dict(text_embed = backbone),\n",
    "                        cond_scale = 1., timesteps = 20)\n",
    "        if test_gt:\n",
    "            clip_target = clip_img_embedder(img.to(device)).to(device)\n",
    "            pred_caption_emb = clip_convert(clip_target)\n",
    "        else:\n",
    "            pred_caption_emb = clip_convert(prior_out)\n",
    "\n",
    "        generated_ids = clip_text_model.generate(pixel_values=pred_caption_emb, max_length=20)\n",
    "        generated_caption = processor.batch_decode(generated_ids, skip_special_tokens=True)\n",
    "        # collect captions in a list; converted to a numpy array once after the loop\n",
    "        all_predcaptions.extend(generated_caption)\n",
    "        print(generated_caption)\n",
    "        # unCLIP conditioning for this batch: ground-truth CLIP tokens or the prior's prediction\n",
    "        recon_source = clip_target if test_gt else prior_out\n",
    "        for i in range(len(voxel)):\n",
    "            # condition on sample i only; [[i]] keeps a batch axis of size 1\n",
    "            samples = utils.unclip_recon(recon_source[[i]],\n",
    "                             diffusion_engine,\n",
    "                             vector_suffix,\n",
    "                             num_samples=num_samples_per_image)\n",
    "\n",
    "            if all_recons is None:\n",
    "                all_recons = samples.cpu()\n",
    "            else:\n",
    "                all_recons = torch.vstack((all_recons, samples.cpu()))\n",
    "            if plotting:\n",
    "                for s in range(num_samples_per_image):\n",
    "                    plt.figure(figsize=(2,2))\n",
    "                    plt.imshow(transforms.ToPILImage()(samples[s]))\n",
    "                    plt.axis('off')\n",
    "                    plt.show()\n",
    "        if blurry_recon:\n",
    "            # 0.18215 undoes the latent scaling; (x / 2 + 0.5) maps decoder output from [-1, 1] to [0, 1]\n",
    "            blurred_image = (autoenc.decode(blurry_image_enc/0.18215).sample/ 2 + 0.5).clamp(0,1)\n",
    "            \n",
    "            for i in range(len(voxel)):\n",
    "                im = torch.Tensor(blurred_image[i])\n",
    "                if all_blurryrecons is None:\n",
    "                    all_blurryrecons = im[None].cpu()\n",
    "                else:\n",
    "                    all_blurryrecons = torch.vstack((all_blurryrecons, im[None].cpu()))\n",
    "                if plotting:\n",
    "                    plt.figure(figsize=(2,2))\n",
    "                    plt.imshow(transforms.ToPILImage()(im))\n",
    "                    plt.axis('off')\n",
    "                    plt.show()\n",
    "\n",
    "all_predcaptions = np.array(all_predcaptions)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a59755ca-e72d-476d-9285-8d763a716cd0",
   "metadata": {},
   "outputs": [],
   "source": [
    "# resize outputs before saving\n",
    "imsize = 256\n",
    "all_recons = transforms.Resize((imsize,imsize))(all_recons).float()\n",
    "print(all_recons.shape)\n",
    "# stack the per-batch image tensors once; the guard keeps this cell safe to re-run\n",
    "# (vstack on an already-stacked tensor would iterate over dim 0 and change its shape)\n",
    "if isinstance(all_images, list):\n",
    "    all_images = torch.vstack(all_images)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1a14d588-f9f8-4363-bc90-c86e52453ef0",
   "metadata": {},
   "outputs": [],
   "source": [
    "print(all_images.shape)\n",
    "# all_blurryrecons stays None when blurry_recon is False\n",
    "if all_blurryrecons is not None:\n",
    "    print(all_blurryrecons.shape)\n",
    "print(all_recons.shape)\n",
    "print(all_predcaptions.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f5d470aa-ce66-49d9-a568-e362d1e4888d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Optional fix-up, disabled by default: uncomment to drop the last image and caption\n",
    "# NOTE(review): presumably for when the lengths printed above disagree - confirm before enabling\n",
    "# all_images = all_images[:-1]\n",
    "# all_predcaptions = all_predcaptions[:-1]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "96c74edd-8121-47a4-8c7a-4b829af821f7",
   "metadata": {},
   "outputs": [],
   "source": [
    "# make sure the output folders exist (evals/ is not created anywhere else in this notebook)\n",
    "os.makedirs(\"recons\", exist_ok=True)\n",
    "os.makedirs(\"evals\", exist_ok=True)\n",
    "if blurry_recon: \n",
    "    all_blurryrecons = transforms.Resize((imsize,imsize))(all_blurryrecons).float()\n",
    "    torch.save(all_blurryrecons,f\"recons/{model_name}_all_blurryrecons.pt\")\n",
    "torch.save(all_images,f\"evals/all_images_subj{subj:02d}.pt\")\n",
    "torch.save(all_recons,f\"recons/{model_name}_all_recons.pt\")\n",
    "torch.save(all_predcaptions,f\"recons/{model_name}_all_predcaptions.pt\")\n",
    "print(f\"saved {model_name} outputs!\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ce76da0e-9628-41f9-8b64-e9ca01a86a8b",
   "metadata": {},
   "source": [
    "### Visualize"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8b119917-0d7f-4349-a4ca-3bbf011e4a72",
   "metadata": {},
   "outputs": [],
   "source": [
    "# test item to compare; clamped so small test sets do not raise IndexError\n",
    "i = min(9, min(len(all_recons), len(all_images)) - 1)\n",
    "\n",
    "def to_np(img):\n",
    "    \"\"\"CHW tensor -> HWC numpy array clipped to [0, 1] for imshow.\"\"\"\n",
    "    img = img.detach().cpu().permute(1, 2, 0)\n",
    "    return img.numpy().clip(0, 1)\n",
    "\n",
    "fig, ax = plt.subplots(1, 2, figsize=(8, 4))\n",
    "ax[0].imshow(to_np(all_recons[i]))\n",
    "ax[0].set_title(\"Reconstruction\")\n",
    "ax[0].axis(\"off\")\n",
    "\n",
    "ax[1].imshow(to_np(all_images[i]))\n",
    "ax[1].set_title(\"Ground Truth\")\n",
    "ax[1].axis(\"off\")\n",
    "\n",
    "plt.tight_layout()\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3fe40754-86f4-4a50-9aac-8eb0cf2b200c",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.16"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
