{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "f45d5893-5499-4a4f-8a57-d07445ad0eef",
   "metadata": {},
   "source": [
    "## Brain Captioning\n",
    "\n",
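    "We predict image captions from fMRI brain activity: ridge regressions map NSD fMRI responses (nsdgeneral ROI) to the hidden states of the GIT vision encoder, and the GIT captioning model then decodes those predicted embeddings into text."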
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "8fdd3fc4-0c78-4e1c-a89b-8c05eff841ef",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "from transformers import AutoProcessor\n",
    "from modeling_git import GitForCausalLM, GitModel, GitForCausalLMClipEmb\n",
    "import requests\n",
    "from PIL import Image\n",
    "import numpy as np\n",
    "import os\n",
    "import glob\n",
    "from os.path import join as opj\n",
    "import h5py  \n",
    "import matplotlib.pyplot as plt\n",
    "import pandas as pd\n",
    "import nibabel as nib\n",
    "from scipy.io import loadmat\n",
    "import torch\n",
    "from torch.utils.data import Dataset, Subset, DataLoader\n",
    "import json\n",
    "from torchsummary import summary\n",
    "import torchvision\n",
    "import tqdm\n",
    "from sklearn.linear_model import Ridge\n",
    "import pickle\n",
    "import wandb\n",
    "from diffusers import StableDiffusionPipeline, StableDiffusionImg2ImgPipeline\n",
    "\n",
    "from yellowbrick.cluster import KElbowVisualizer\n",
    "\n",
    "## vdvae\n",
    "import pickle\n",
    "\n",
    "\n",
    "from hps import Hyperparams\n",
    "from vae import VAE\n",
    "\n",
    "from sklearn.decomposition import PCA\n",
    "from sklearn.model_selection import GridSearchCV, RandomizedSearchCV\n",
    "from sklearn.manifold import TSNE\n",
    "import seaborn as sns\n",
    "import nltk\n",
    "from nltk.corpus import stopwords\n",
    "import string\n",
    "from wordcloud import WordCloud\n",
    "import numpy as np\n",
    "from transformers import AutoImageProcessor, UperNetForSemanticSegmentation\n",
    "from PIL import Image\n",
    "import numpy as np\n",
    "import torch\n",
    "from diffusers import StableDiffusionControlNetPipeline, ControlNetModel, UniPCMultistepScheduler\n",
    "from diffusers.utils import load_image\n",
    "from sklearn.cluster import KMeans\n",
    "from sklearn.preprocessing import StandardScaler\n",
    "\n",
    "\n",
    "from torchmetrics import BLEUScore\n",
    "import evaluate\n",
    "from sentence_transformers import SentenceTransformer, util\n",
    "from transformers import AutoProcessor, CLIPModel, AutoTokenizer\n",
    "from PIL import ImageFilter\n",
    "# from controlnet import StableDiffusionControlNetPipeline, ControlNetModel, UniPCMultistepScheduler"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "a11847e4-7778-4988-8914-58ab36ce144a",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.\n",
      "\u001b[34m\u001b[1mwandb\u001b[0m: Currently logged in as: \u001b[33mmatteoferrante\u001b[0m. Use \u001b[1m`wandb login --relogin`\u001b[0m to force relogin\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "True"
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Authenticate with Weights & Biases.\n",
    "# NOTE(review): wandb is not used in any of the visible cells; confirm it is still needed.\n",
    "wandb.login()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "d0f22359-fdff-4bcb-a8c6-18840c0221f1",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "import re\n",
    "\n",
    "# Root of the NSD data tree. Override with the NSD_BASE_PATH environment variable\n",
    "# instead of editing the hard-coded default.\n",
    "base_path=os.environ.get(\"NSD_BASE_PATH\", \"/home/matteo/brain-diffuser/data\")\n",
    "timeseries_path=opj(base_path,\"nsddata_timeseries\")\n",
    "betas_path=opj(base_path,\"nsddata_betas\")\n",
    "\n",
    "stimuli_path=opj(base_path,\"nsddata_stimuli\",\"stimuli\",\"nsd\")\n",
    "stim_file_path=opj(stimuli_path,\"nsd_stimuli.hdf5\")\n",
    "sub=\"subj05\"\n",
    "mod=\"func1pt8mm\"\n",
    "subj_data_path=opj(timeseries_path,\"ppdata\",sub,mod,\"timeseries\")\n",
    "subj_betas_path=opj(betas_path,\"ppdata\",sub,mod,\"betas_assumehrf\")\n",
    "\n",
    "subj_betas_roi_extracted_path=opj(base_path,\"processed_roi\",sub,mod)\n",
    "\n",
    "stim_order_path=opj(base_path,\"nsddata\",\"experiments\",\"nsd\",\"nsd_expdesign.mat\")\n",
    "stim_info_path=opj(base_path,\"nsddata\",\"experiments\",\"nsd\",\"nsd_stim_info_merged.csv\")\n",
    "stim_captions_train_path=opj(base_path,\"nsddata_stimuli\",\"stimuli\",\"nsd\",\"annotations\",\"captions_train2017.json\")\n",
    "stim_captions_val_path=opj(base_path,\"nsddata_stimuli\",\"stimuli\",\"nsd\",\"annotations\",\"captions_val2017.json\")\n",
    "\n",
    "processed_data=opj(base_path,\"processed_data\",sub)\n",
    "# Numeric subject index, e.g. subj05 -> 5. Parsed with a regex: splitting on the digit 0\n",
    "# breaks for names such as subj01_good2 (handled later in this notebook) or subj10.\n",
    "sub_idx=int(re.search(r\"subj0*(\\d+)\", sub).group(1))\n",
    "\n",
    "# Preprocessed per-subject arrays (fMRI averaged over repetitions, stimuli, captions).\n",
    "fmri_train_data=opj(processed_data,f\"nsd_train_fmriavg_nsdgeneral_sub{sub_idx}.npy\")\n",
    "imgs_train_data=opj(processed_data,f\"nsd_train_stim_sub{sub_idx}.npy\")\n",
    "captions_train_data=opj(processed_data, f\"nsd_train_cap_sub{sub_idx}.npy\")\n",
    "\n",
    "fmri_test_data=opj(processed_data,f\"nsd_test_fmriavg_nsdgeneral_sub{sub_idx}.npy\")\n",
    "imgs_test_data=opj(processed_data,f\"nsd_test_stim_sub{sub_idx}.npy\")\n",
    "captions_test_data=opj(processed_data, f\"nsd_test_cap_sub{sub_idx}.npy\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "0376aab9-ec78-45b5-868f-fce35c663bfd",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Pipeline switches:\n",
    "# compute_dataset: True -> extract fMRI, images, captions and GIT vision embeddings and save them\n",
    "#                  under models/{sub}; False -> load the previously saved files.\n",
    "# train: True -> fit and save the ridge decoders and embedding statistics; False -> load them.\n",
    "# adjust: True -> rescale the decoded test embeddings to the training embeddings' mean/std.\n",
    "# NOTE(review): the caption-generation cell reads img_emb_test_adj, which the adjust cell only\n",
    "# defines when adjust is True.\n",
    "compute_dataset=True\n",
    "train=True\n",
    "adjust=True"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "fb184829-50c9-47b9-8a61-991789055d9f",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "class NSDDataset(Dataset):\n",
    "    \"\"\"Paired NSD samples returned as (fmri, image, caption).\n",
    "\n",
    "    Args:\n",
    "        fmri_data: path to a .npy array of fMRI responses, indexed by sample along axis 0.\n",
    "        imgs_data: path to a .npy array of stimulus images (cast to uint8 for PIL).\n",
    "        caption_data: path to a .npy object array; entry idx is indexable and holds the\n",
    "            captions of sample idx.\n",
    "        transforms: optional callable applied to the PIL image.\n",
    "        caption_idx: which of the per-sample captions to return (default 0, the first one).\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, fmri_data,imgs_data,caption_data,transforms=None,caption_idx=0):\n",
    "        self.fmri_data=np.load(fmri_data)\n",
    "        self.imgs_data=np.load(imgs_data).astype(np.uint8)\n",
    "        # allow_pickle=True can execute code from the file: only load trusted, self-produced arrays.\n",
    "        self.caption_data=np.load(caption_data,allow_pickle=True)\n",
    "        self.transforms=transforms\n",
    "        self.caption_idx=caption_idx\n",
    "\n",
    "    def __len__(self):\n",
    "        return len(self.fmri_data)\n",
    "\n",
    "    def __getitem__(self,idx):\n",
    "        fmri=torch.tensor(self.fmri_data[idx])\n",
    "        img=Image.fromarray(self.imgs_data[idx])\n",
    "\n",
    "        if self.transforms:\n",
    "            img=self.transforms(img)\n",
    "\n",
    "        # Select one caption per sample; change caption_idx to use a different one.\n",
    "        caption=self.caption_data[idx][self.caption_idx]\n",
    "\n",
    "        return fmri,img,caption"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "fd76a19d-3cd0-4fc3-88b0-2af7c91bdd85",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# ToTensor turns each PIL image into a float tensor scaled to [0, 1].\n",
    "tr=torchvision.transforms.ToTensor()\n",
    "train_dataset=NSDDataset(fmri_train_data,imgs_train_data,captions_train_data,transforms=tr)\n",
    "test_dataset=NSDDataset(fmri_test_data,imgs_test_data,captions_test_data,transforms=tr)\n",
    "\n",
    "BS=128\n",
    "\n",
    "# Training batches are shuffled. Test order is preserved so that every test-side list built\n",
    "# from test_dataloader stays aligned sample by sample.\n",
    "train_dataloader=DataLoader(train_dataset,BS,shuffle=True)\n",
    "test_dataloader=DataLoader(test_dataset,BS,shuffle=False)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "bfe57284-6323-4d77-9754-ca8937965a93",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Converts image tensors back to PIL images before they are handed to the GIT processor.\n",
    "to_pil=torchvision.transforms.ToPILImage()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "f37a98e9-710f-4dfe-9d83-16eda90b045d",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "device=\"cuda:0\"\n",
    "# GIT processor plus the custom GitForCausalLMClipEmb (from the local modeling_git module);\n",
    "# the caption cell later calls model.generate with decoded embeddings as pixel_values.\n",
    "processor = AutoProcessor.from_pretrained(\"microsoft/git-base-coco\")\n",
    "model = GitForCausalLMClipEmb.from_pretrained(\"microsoft/git-base-coco\")\n",
    "\n",
    "model.to(device)\n",
    "# NOTE(review): this sample image and its pixel_values look like a leftover smoke test, and the\n",
    "# cell fails unless prova.png exists in the working directory. Remove if no later cell uses image.\n",
    "url = \"prova.png\"\n",
    "image = Image.open(url)\n",
    "\n",
    "pixel_values = processor(images=image, return_tensors=\"pt\").pixel_values\n",
    "# GIT image encoder: its hidden states are the targets the ridge decoders learn to predict.\n",
    "vision_encoder=model.git.image_encoder\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "80f0da25-77c7-49ec-92ed-aa89a53b3830",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "GitForCausalLM(\n",
       "  (git): GitModel(\n",
       "    (embeddings): GitEmbeddings(\n",
       "      (word_embeddings): Embedding(30522, 768, padding_idx=0)\n",
       "      (position_embeddings): Embedding(1024, 768)\n",
       "      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
       "      (dropout): Dropout(p=0.1, inplace=False)\n",
       "    )\n",
       "    (image_encoder): GitVisionModel(\n",
       "      (vision_model): GitVisionTransformer(\n",
       "        (embeddings): GitVisionEmbeddings(\n",
       "          (patch_embedding): Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16), bias=False)\n",
       "          (position_embedding): Embedding(197, 768)\n",
       "        )\n",
       "        (pre_layrnorm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n",
       "        (encoder): GitVisionEncoder(\n",
       "          (layers): ModuleList(\n",
       "            (0-11): 12 x GitVisionEncoderLayer(\n",
       "              (self_attn): GitVisionAttention(\n",
       "                (k_proj): Linear(in_features=768, out_features=768, bias=True)\n",
       "                (v_proj): Linear(in_features=768, out_features=768, bias=True)\n",
       "                (q_proj): Linear(in_features=768, out_features=768, bias=True)\n",
       "                (out_proj): Linear(in_features=768, out_features=768, bias=True)\n",
       "              )\n",
       "              (layer_norm1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n",
       "              (mlp): GitVisionMLP(\n",
       "                (activation_fn): QuickGELUActivation()\n",
       "                (fc1): Linear(in_features=768, out_features=3072, bias=True)\n",
       "                (fc2): Linear(in_features=3072, out_features=768, bias=True)\n",
       "              )\n",
       "              (layer_norm2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n",
       "            )\n",
       "          )\n",
       "        )\n",
       "        (post_layernorm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n",
       "      )\n",
       "    )\n",
       "    (encoder): GitEncoder(\n",
       "      (layer): ModuleList(\n",
       "        (0-5): 6 x GitLayer(\n",
       "          (attention): GitAttention(\n",
       "            (self): GitSelfAttention(\n",
       "              (query): Linear(in_features=768, out_features=768, bias=True)\n",
       "              (key): Linear(in_features=768, out_features=768, bias=True)\n",
       "              (value): Linear(in_features=768, out_features=768, bias=True)\n",
       "              (dropout): Dropout(p=0.1, inplace=False)\n",
       "            )\n",
       "            (output): GitSelfOutput(\n",
       "              (dense): Linear(in_features=768, out_features=768, bias=True)\n",
       "              (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
       "              (dropout): Dropout(p=0.1, inplace=False)\n",
       "            )\n",
       "          )\n",
       "          (intermediate): GitIntermediate(\n",
       "            (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
       "            (intermediate_act_fn): GELUActivation()\n",
       "          )\n",
       "          (output): GitOutput(\n",
       "            (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
       "            (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
       "            (dropout): Dropout(p=0.1, inplace=False)\n",
       "          )\n",
       "        )\n",
       "      )\n",
       "    )\n",
       "    (visual_projection): GitProjection(\n",
       "      (visual_projection): Sequential(\n",
       "        (0): Linear(in_features=768, out_features=768, bias=True)\n",
       "        (1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n",
       "      )\n",
       "    )\n",
       "  )\n",
       "  (output): Linear(in_features=768, out_features=30522, bias=True)\n",
       ")"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Stock GIT captioning model, used to caption the real test images (reference captions).\n",
    "model_base = GitForCausalLM.from_pretrained(\"microsoft/git-base-coco\")\n",
    "\n",
    "model_base.to(device)\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "f1a64fd0-0fd5-433e-bbea-7973d7b30e0f",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|████████████████████████████████████████████████████████████| 70/70 [05:58<00:00,  5.12s/it]\n"
     ]
    }
   ],
   "source": [
    "if compute_dataset:\n",
    "    # Materialise the training split: fMRI tensors, images, captions, and the GIT\n",
    "    # vision-encoder hidden states (regression targets for the ridge decoders).\n",
    "    train_fmri=[]\n",
    "    train_imgs=[]\n",
    "    train_captions=[]\n",
    "    train_clip_img_embeds=[]\n",
    "\n",
    "    for x,y,c in tqdm.tqdm(train_dataloader):\n",
    "        train_fmri.append(x)\n",
    "        train_imgs.append(y)\n",
    "        train_captions+=list(c)\n",
    "\n",
    "        with torch.no_grad():\n",
    "            # y holds ToTensor output (floats in [0, 1]). Convert back to PIL, exactly as the\n",
    "            # caption-generation cell does, so the processor applies its 1/255 rescale once.\n",
    "            pixel_values=processor(images=[to_pil(img) for img in y], return_tensors=\"pt\").pixel_values.to(device)\n",
    "            # Keep the full token sequence of the encoder output, moved to CPU.\n",
    "            image_features=vision_encoder(pixel_values).last_hidden_state.cpu()\n",
    "            train_clip_img_embeds.append(image_features)\n",
    "\n",
    "    train_clip_img_embeds = torch.cat(train_clip_img_embeds,axis=0)\n",
    "    train_fmri = torch.cat(train_fmri,axis=0)\n",
    "    train_imgs = torch.cat(train_imgs,axis=0)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "e0c23ad8-b4a1-457f-a177-5db2aa5bd81d",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████████████████████████████████████████████████████████| 8/8 [00:12<00:00,  1.54s/it]\n"
     ]
    }
   ],
   "source": [
    "if compute_dataset:\n",
    "    # Materialise the test split in dataloader order (unshuffled), mirroring the training cell.\n",
    "    test_fmri=[]\n",
    "    test_imgs=[]\n",
    "    test_captions=[]\n",
    "    test_clip_img_embeds=[]\n",
    "\n",
    "    for x,y,c in tqdm.tqdm(test_dataloader):\n",
    "        test_fmri.append(x)\n",
    "        test_imgs.append(y)\n",
    "        test_captions+=list(c)\n",
    "\n",
    "        with torch.no_grad():\n",
    "            # y holds ToTensor output (floats in [0, 1]). Convert back to PIL, exactly as the\n",
    "            # caption-generation cell does, so the processor applies its 1/255 rescale once.\n",
    "            pixel_values=processor(images=[to_pil(img) for img in y], return_tensors=\"pt\").pixel_values.to(device)\n",
    "            # Keep the full token sequence of the encoder output, moved to CPU.\n",
    "            image_features=vision_encoder(pixel_values).last_hidden_state.cpu()\n",
    "            test_clip_img_embeds.append(image_features)\n",
    "\n",
    "    test_clip_img_embeds = torch.cat(test_clip_img_embeds,axis=0)\n",
    "    test_fmri = torch.cat(test_fmri,axis=0)\n",
    "    test_imgs = torch.cat(test_imgs,axis=0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "e0f4e8d6-254d-4958-8492-d4505ab18a12",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "saved training stuff\n",
      "saved testing stuff\n"
     ]
    }
   ],
   "source": [
    "# Save the extracted tensors/captions under models/{sub}, or load them when compute_dataset is False.\n",
    "if compute_dataset:\n",
    "    os.makedirs(f\"models/{sub}\",exist_ok=True)\n",
    "    \n",
    "    ## train\n",
    "    torch.save(train_fmri,f\"models/{sub}/train_fmri.pt\")\n",
    "    torch.save(train_clip_img_embeds,f\"models/{sub}/train_clip_img_embeds.pt\")\n",
    "    torch.save(train_imgs,f\"models/{sub}/train_imgs.pt\")\n",
    "        \n",
    "    with open(f\"models/{sub}/train_captions.sav\",\"wb\") as f:\n",
    "        pickle.dump(train_captions,f)\n",
    "        \n",
    "    print(\"saved training stuff\")\n",
    "    \n",
    "    ## test\n",
    "    torch.save(test_fmri,f\"models/{sub}/test_fmri.pt\")\n",
    "    torch.save(test_clip_img_embeds,f\"models/{sub}/test_clip_img_embeds.pt\")\n",
    "    torch.save(test_imgs,f\"models/{sub}/test_imgs.pt\")\n",
    "        \n",
    "    with open(f\"models/{sub}/test_captions.sav\",\"wb\") as f:\n",
    "        pickle.dump(test_captions,f)\n",
    "    \n",
    "    \n",
    "    print(\"saved testing stuff\")\n",
    "    \n",
    "else:\n",
    "    # NOTE(review): this rebinds the global sub, so every later cell (ridge models, statistics)\n",
    "    # reads and writes models/subj01 rather than models/subj01_good2 -- confirm this is intended.\n",
    "    if sub==\"subj01_good2\":\n",
    "        sub=\"subj01\"\n",
    "    # torch.load and pickle.load can execute arbitrary code: only load files this notebook wrote.\n",
    "    ## train\n",
    "    train_fmri=torch.load(f\"models/{sub}/train_fmri.pt\")\n",
    "    train_clip_img_embeds= torch.load(f\"models/{sub}/train_clip_img_embeds.pt\")\n",
    "    train_imgs=torch.load(f\"models/{sub}/train_imgs.pt\")\n",
    "        \n",
    "    with open(f\"models/{sub}/train_captions.sav\",\"rb\") as f:\n",
    "        train_captions=pickle.load(f)\n",
    "\n",
    "    ## test\n",
    "    test_fmri=torch.load(f\"models/{sub}/test_fmri.pt\")\n",
    "    test_clip_img_embeds= torch.load(f\"models/{sub}/test_clip_img_embeds.pt\")\n",
    "    test_imgs=torch.load(f\"models/{sub}/test_imgs.pt\")\n",
    "    \n",
    "    with open(f\"models/{sub}/test_captions.sav\",\"rb\") as f:\n",
    "        test_captions=pickle.load(f)\n",
    "    "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "4840b0fc-9261-40ce-91b7-8766976b77b9",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# z-score every voxel using training-split statistics only; the test split reuses them.\n",
    "train_fmri_mean=torch.mean(train_fmri,axis=0)\n",
    "train_fmri_std=torch.std(train_fmri,axis=0)\n",
    "# Guard constant voxels: a zero std would produce NaN/inf values, which sklearn's Ridge rejects.\n",
    "train_fmri_std=torch.where(train_fmri_std==0, torch.ones_like(train_fmri_std), train_fmri_std)\n",
    "\n",
    "train_fmri_norm=(train_fmri-train_fmri_mean)/train_fmri_std\n",
    "test_fmri_norm=(test_fmri-train_fmri_mean)/train_fmri_std\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "4c627213-8faa-453b-a9c8-eaaf92a581d9",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████████████████████████████████████████████████████| 197/197 [57:36<00:00, 17.54s/it]\n"
     ]
    }
   ],
   "source": [
    "# Number of token positions in the vision-encoder output (its position embedding has 197 entries).\n",
    "max_len_img=197\n",
    "\n",
    "if train:\n",
    "    # One Ridge regression per token position: normalised fMRI -> that token's embedding vector.\n",
    "    brain_to_img_emb=[]\n",
    "\n",
    "    for i in tqdm.tqdm(range(max_len_img)):\n",
    "        # NOTE(review): alpha=6e4 is a fixed regularisation strength; how it was chosen is not shown.\n",
    "        m=Ridge(alpha=6e4)\n",
    "        m.fit(train_fmri_norm.numpy(),train_clip_img_embeds[:,i,:].numpy())\n",
    "        brain_to_img_emb.append(m)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "f21bdf5d-7f57-4633-a6c1-3a6ca6d98734",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Persist the per-token ridge models to models/{sub}/decoding, or reload them when train is False.\n",
    "if train:\n",
    "    os.makedirs(f\"models/{sub}/decoding\",exist_ok=True)\n",
    "    for i in range(max_len_img):\n",
    "        filename = f'brain_to_img_emb_ridge_{i}.sav'\n",
    "        with open(opj(f\"models/{sub}/decoding\",filename), 'wb') as f:\n",
    "            pickle.dump(brain_to_img_emb[i], f) \n",
    "else:\n",
    "    # pickle.load can execute arbitrary code: only load model files this notebook produced.\n",
    "    brain_to_img_emb=[]\n",
    "    for i in range(max_len_img):\n",
    "        filename = f'brain_to_img_emb_ridge_{i}.sav'\n",
    "        with open(opj(f\"models/{sub}/decoding\",filename), 'rb') as f:\n",
    "            p=pickle.load(f)\n",
    "            brain_to_img_emb.append(p)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "33447c22-f351-4daf-b3ac-616eed118f8a",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████████████████████████████████████████████████████| 197/197 [01:20<00:00,  2.46it/s]\n"
     ]
    }
   ],
   "source": [
    "# Decode the training fMRI with every per-token model; the stack along dim 1 gives\n",
    "# (n_train, max_len_img, embedding_dim). Used below only to compute prediction statistics.\n",
    "if train:\n",
    "    img_emb_train=[]\n",
    "    for i in tqdm.tqdm(range(max_len_img)):\n",
    "        emb=torch.tensor(brain_to_img_emb[i].predict(train_fmri_norm.numpy()))\n",
    "\n",
    "\n",
    "        img_emb_train.append(emb)\n",
    "    img_emb_train=torch.stack(img_emb_train,1)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "id": "5383fbdd-5f4d-4c62-9b98-ec53241b8251",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Per-(token, feature) mean/std of the real training embeddings and of the decoded training\n",
    "# embeddings; saved when training, reloaded otherwise.\n",
    "# NOTE(review): the visible adjust cell uses only the train_clip_img_embeds_* statistics.\n",
    "if train:\n",
    "    \n",
    "    train_clip_img_embeds_mean=train_clip_img_embeds.mean(0)\n",
    "    train_clip_img_embeds_std=train_clip_img_embeds.std(0)\n",
    "    \n",
    "    pred_clip_img_embeds_mean=img_emb_train.mean(0)\n",
    "    pred_clip_img_embeds_std=img_emb_train.std(0)\n",
    "    \n",
    "    torch.save(train_clip_img_embeds_mean, opj(f\"models/{sub}\",\"train_clip_img_embeds_mean.pt\"))\n",
    "    torch.save(train_clip_img_embeds_std, opj(f\"models/{sub}\",\"train_clip_img_embeds_std.pt\"))\n",
    "    torch.save(pred_clip_img_embeds_mean, opj(f\"models/{sub}\",\"pred_clip_img_embeds_mean.pt\"))\n",
    "    torch.save(pred_clip_img_embeds_std, opj(f\"models/{sub}\",\"pred_clip_img_embeds_std.pt\"))\n",
    "    \n",
    "else:\n",
    "    train_clip_img_embeds_mean=torch.load(opj(f\"models/{sub}\",\"train_clip_img_embeds_mean.pt\"))\n",
    "    train_clip_img_embeds_std=torch.load(opj(f\"models/{sub}\",\"train_clip_img_embeds_std.pt\"))\n",
    "   \n",
    "\n",
    "    pred_clip_img_embeds_mean=torch.load(opj(f\"models/{sub}\",\"pred_clip_img_embeds_mean.pt\"))\n",
    "    pred_clip_img_embeds_std=torch.load(opj(f\"models/{sub}\",\"pred_clip_img_embeds_std.pt\"))\n",
    "    \n",
    "    "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "id": "b357bcda-a8e3-45a7-97bc-a7eed335512f",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████████████████████████████████████████████████████| 197/197 [00:36<00:00,  5.35it/s]\n"
     ]
    }
   ],
   "source": [
    "\n",
    "# Decode the test fMRI with every per-token model -> (n_test, max_len_img, embedding_dim).\n",
    "img_emb_test=[]\n",
    "for i in tqdm.tqdm(range(max_len_img)):\n",
    "    emb=torch.tensor(brain_to_img_emb[i].predict(test_fmri_norm.numpy()))\n",
    "\n",
    "\n",
    "    img_emb_test.append(emb)\n",
    "img_emb_test=torch.stack(img_emb_test,1)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "id": "b51e0f95-7f34-431b-b204-7534ed5d04b2",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Optionally match the first two moments of the decoded test embeddings to the real training\n",
    "# embeddings: standardise each (token, feature) entry over the test samples, then rescale it\n",
    "# with the training mean/std.\n",
    "# NOTE(review): standardisation uses statistics of the test predictions themselves; the saved\n",
    "# pred_clip_img_embeds_mean/std (from training predictions) would avoid using test-set statistics.\n",
    "if adjust:\n",
    "    img_emb_test_adj=(img_emb_test-img_emb_test.mean(0))/(img_emb_test.std(0))\n",
    "    img_emb_test_adj=train_clip_img_embeds_std*img_emb_test_adj+train_clip_img_embeds_mean\n",
    "else:\n",
    "    # The caption cell always reads img_emb_test_adj; without adjustment it decodes the raw predictions.\n",
    "    img_emb_test_adj=img_emb_test"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "id": "ea80afe6-3bc5-4d90-9e50-7e3203f5a938",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████████████████████████████████████████████████████████| 8/8 [02:18<00:00, 17.36s/it]\n"
     ]
    }
   ],
   "source": [
    "# Caption every test batch twice:\n",
    "#   - from the real images with the stock GIT model (reference captions),\n",
    "#   - from the decoded embeddings with GitForCausalLMClipEmb, which receives them via pixel_values.\n",
    "# Both generate calls cap the output at max_length=25 tokens.\n",
    "captions_from_images=[]\n",
    "captions_from_brain=[]\n",
    "\n",
    "with torch.no_grad():\n",
    "    for i in tqdm.tqdm(range(0,len(test_imgs),BS)):\n",
    "\n",
    "        #get reference images\n",
    "        imgs=[to_pil(j) for j in test_imgs[i:i+BS]]\n",
    "\n",
    "        #compute captions from images\n",
    "        pixel_values = processor(images=imgs, return_tensors=\"pt\").pixel_values\n",
    "        generated_ids = model_base.generate(pixel_values=pixel_values.to(device), max_length=25)\n",
    "        generated_caption = processor.batch_decode(generated_ids, skip_special_tokens=True)\n",
    "\n",
    "        captions_from_images+=generated_caption\n",
    "\n",
    "        #compute captions from brain\n",
    "        generated_ids = model.generate(pixel_values=img_emb_test_adj[i:i+BS].to(device).float(), max_length=25)\n",
    "        generated_caption = processor.batch_decode(generated_ids, skip_special_tokens=True)\n",
    "        captions_from_brain+=generated_caption"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "431e6104-982b-48d0-80a1-b01646ebbc4a",
   "metadata": {},
   "source": [
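    "## Compute text metrics\n",
    "Captions decoded from brain activity are scored against the captions GIT produces from the actual images, and those image captions are scored against the ground-truth captions; the relative score is the ratio of the two.\n",
    "* BLEU (BLEU@1 and BLEU@4)\n",
    "* METEOR\n",
    "* Sentence-Transformer cosine similarity\n",
    "* CLIP text-embedding cosine similarity"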
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "id": "903adfaf-0595-4e0a-baf0-98bcc3e452f3",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "[nltk_data] Downloading package wordnet to /home/matteo/nltk_data...\n",
      "[nltk_data]   Package wordnet is already up-to-date!\n",
      "[nltk_data] Downloading package punkt to /home/matteo/nltk_data...\n",
      "[nltk_data]   Package punkt is already up-to-date!\n",
      "[nltk_data] Downloading package omw-1.4 to /home/matteo/nltk_data...\n",
      "[nltk_data]   Package omw-1.4 is already up-to-date!\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[GROUND] BLEU@1 GIT from images vs captions: 0.43787420529004933\n",
      "[ABSOLUTE] BLEU@1 GIT from brain vs images: 0.4335066205372695\n",
      "[RELATIVE] BLEU@1  0.9900254806060413\n",
      "\n",
      "[GROUND] BLEU@4 GIT from images vs captions: 0.13830765150649799\n",
      "[ABSOLUTE] BLEU@4 GIT from brain vs images: 0.11177697624503605\n",
      "[RELATIVE] BLEU@4  0.8081763736678338\n"
     ]
    }
   ],
   "source": [
    "# BLEU at n-gram orders 1 and 4 for three pairings: image captions vs ground truth, brain captions\n",
    "# vs ground truth, and brain captions vs image captions. test_captions holds one reference per sample.\n",
    "bleu=evaluate.load(\"bleu\")\n",
    "# METEOR metric, used in the next cell.\n",
    "meteor = evaluate.load('meteor')\n",
    "bleu_img_ref=bleu.compute(predictions=captions_from_images,references=test_captions,max_order=1)\n",
    "# NOTE(review): bleu_brain_ref and bleu_brain_ref_4 are computed but never printed or stored here.\n",
    "bleu_brain_ref=bleu.compute(predictions=captions_from_brain,references=test_captions,max_order=1)\n",
    "bleu_brain_img=bleu.compute(predictions=captions_from_brain,references=captions_from_images,max_order=1)\n",
    "\n",
    "bleu_img_ref_4=bleu.compute(predictions=captions_from_images,references=test_captions,max_order=4)\n",
    "bleu_brain_ref_4=bleu.compute(predictions=captions_from_brain,references=test_captions,max_order=4)\n",
    "bleu_brain_img_4=bleu.compute(predictions=captions_from_brain,references=captions_from_images,max_order=4)\n",
    "\n",
    "\n",
    "\n",
    "# Relative score: brain-vs-image BLEU divided by image-vs-ground-truth BLEU.\n",
    "relative_brain_image_bleu=bleu_brain_img[\"bleu\"]/bleu_img_ref[\"bleu\"]\n",
    "relative_brain_image_bleu_4=bleu_brain_img_4[\"bleu\"]/bleu_img_ref_4[\"bleu\"]\n",
    "\n",
    "\n",
    "print(f\"[GROUND] BLEU@1 GIT from images vs captions: {bleu_img_ref['bleu']}\")\n",
    "print(f\"[ABSOLUTE] BLEU@1 GIT from brain vs images: {bleu_brain_img['bleu']}\")\n",
    "print(f\"[RELATIVE] BLEU@1  {relative_brain_image_bleu}\")\n",
    "\n",
    "print()\n",
    "print(f\"[GROUND] BLEU@4 GIT from images vs captions: {bleu_img_ref_4['bleu']}\")\n",
    "print(f\"[ABSOLUTE] BLEU@4 GIT from brain vs images: {bleu_brain_img_4['bleu']}\")\n",
    "print(f\"[RELATIVE] BLEU@4  {relative_brain_image_bleu_4}\")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "id": "f06f2ef5-d9bb-4089-acfe-a7859d8ada7c",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[GROUND] METEOR GIT from images vs captions: 0.4038153815765378\n",
      "[ABSOLUTE] METEOR GIT from brain vs images: 0.3038103202360226\n",
      "[RELATIVE] METEOR  0.7523495490684756\n"
     ]
    }
   ],
   "source": [
    "# METEOR for the same three pairings as BLEU; relative = brain-vs-image / image-vs-ground-truth.\n",
    "meteor_img_ref=meteor.compute(predictions=captions_from_images,references=test_captions)\n",
    "meteor_brain_ref=meteor.compute(predictions=captions_from_brain,references=test_captions)\n",
    "meteor_brain_img=meteor.compute(predictions=captions_from_brain,references=captions_from_images)\n",
    "\n",
    "\n",
    "relative_brain_image_meteor=meteor_brain_img[\"meteor\"]/meteor_img_ref[\"meteor\"]\n",
    "\n",
    "print(f\"[GROUND] METEOR GIT from images vs captions: {meteor_img_ref['meteor']}\")\n",
    "print(f\"[ABSOLUTE] METEOR GIT from brain vs images: {meteor_brain_img['meteor']}\")\n",
    "print(f\"[RELATIVE] METEOR  {relative_brain_image_meteor}\")\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "id": "9f5ca02d-2202-4110-a259-70eb3ac5a347",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Sentence-embedding model used for the semantic-similarity metric.\n",
    "sentence_model = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2')\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "id": "a5abea6a-8fbf-4648-afc3-4cce67901380",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Compute sentence embeddings for all three caption lists (brain, ground truth, image).\n",
    "with torch.no_grad():\n",
    "    embedding_brain= sentence_model.encode(captions_from_brain, convert_to_tensor=True)\n",
    "    embedding_captions = sentence_model.encode(test_captions, convert_to_tensor=True)\n",
    "    embedding_images = sentence_model.encode(captions_from_images, convert_to_tensor=True)\n",
    "\n",
    "    # pytorch_cos_sim returns the full pairwise matrix; its diagonal pairs sample k with sample k.\n",
    "    ss_sim_brain_img=util.pytorch_cos_sim(embedding_brain, embedding_images).cpu()\n",
    "    ss_sim_brain_cap=util.pytorch_cos_sim(embedding_brain, embedding_captions).cpu()\n",
    "    ss_sim_img_cap=util.pytorch_cos_sim(embedding_images, embedding_captions).cpu()\n",
    "\n",
    "    # Relative score: mean matched brain-vs-image similarity over mean matched image-vs-caption similarity.\n",
    "    relative_brain_image_ss=ss_sim_brain_img.diag().mean()/ss_sim_img_cap.diag().mean()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "id": "c170e500-c729-4e4c-a153-567311919a7c",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[GROUND] Sentence Transformer Similarity GIT from images vs captions: 0.7029159069061279\n",
      "[ABSOLUTE] Sentence Transformer Similarity GIT from brain vs images: 0.443602979183197\n",
      "[RELATIVE] Sentence Transformer Similarity   0.6310896873474121\n"
     ]
    }
   ],
   "source": [
    "# Report mean matched-pair (diagonal) sentence-embedding similarities and their ratio.\n",
    "print(f\"[GROUND] Sentence Transformer Similarity GIT from images vs captions: {ss_sim_img_cap.diag().mean()}\")\n",
    "print(f\"[ABSOLUTE] Sentence Transformer Similarity GIT from brain vs images: {ss_sim_brain_img.diag().mean()}\")\n",
    "print(f\"[RELATIVE] Sentence Transformer Similarity   {relative_brain_image_ss.mean()}\")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "id": "9fc3d7a3-ba99-4f06-b9df-bc68b99c34b4",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[GROUND] CLIP Similarity GIT from images vs captions: 0.8314802050590515\n",
      "[ABSOLUTE] CLIP Similarity GIT from brain vs images: 0.7022090554237366\n",
      "[RELATIVE] CLIP Similarity   0.8445289134979248\n"
     ]
    }
   ],
   "source": [
    "# CLIP text encoder: embed the three caption lists and compare matched pairs by cosine similarity.\n",
    "model_clip = CLIPModel.from_pretrained(\"openai/clip-vit-base-patch32\")\n",
    "# NOTE(review): processor_clip is not used in the visible cells.\n",
    "processor_clip = AutoProcessor.from_pretrained(\"openai/clip-vit-base-patch32\")\n",
    "tokenizer =  AutoTokenizer.from_pretrained(\"openai/clip-vit-base-patch32\")\n",
    "\n",
    "# truncation=True clips inputs to the tokenizer's maximum length: CLIP's text position embedding\n",
    "# is fixed-size, so an over-long caption would otherwise make get_text_features fail.\n",
    "with torch.no_grad():\n",
    "    input_ids=tokenizer(captions_from_brain,return_tensors=\"pt\",padding=True,truncation=True)\n",
    "    embedding_brain= model_clip.get_text_features(**input_ids)\n",
    "\n",
    "    input_ids=tokenizer(test_captions,return_tensors=\"pt\",padding=True,truncation=True)\n",
    "    embedding_captions= model_clip.get_text_features(**input_ids)\n",
    "\n",
    "    input_ids=tokenizer(captions_from_images,return_tensors=\"pt\",padding=True,truncation=True)\n",
    "    embedding_images= model_clip.get_text_features(**input_ids)\n",
    "\n",
    "# Full pairwise similarity matrices; the diagonal holds the matched (sample k, sample k) pairs.\n",
    "clip_sim_brain_img=util.pytorch_cos_sim(embedding_brain, embedding_images).cpu()\n",
    "clip_sim_brain_cap=util.pytorch_cos_sim(embedding_brain, embedding_captions).cpu()\n",
    "clip_sim_img_cap=util.pytorch_cos_sim(embedding_images, embedding_captions).cpu()\n",
    "\n",
    "relative_brain_image_clip=clip_sim_brain_img.diag().mean()/clip_sim_img_cap.diag().mean()\n",
    "\n",
    "print(f\"[GROUND] CLIP Similarity GIT from images vs captions: {clip_sim_img_cap.diag().mean()}\")\n",
    "print(f\"[ABSOLUTE] CLIP Similarity GIT from brain vs images: {clip_sim_brain_img.diag().mean()}\")\n",
    "print(f\"[RELATIVE] CLIP Similarity   {relative_brain_image_clip.mean()}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "id": "be16e73b-d324-4ea4-bad2-2adcdda9cb23",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Summary metrics for this ROI/subject, written to CSV here and logged to wandb further down.\n",
    "# FIXME(review): Bleu_*_brain_img and Bleu_*_relative reuse the image-vs-reference scores\n",
    "# (bleu_img_ref / bleu_img_ref_4), so the three BLEU-1 columns (and the three BLEU-4 columns)\n",
    "# are identical. They should hold the brain-vs-image BLEU results and their ratio, the way the\n",
    "# Meteor/Sentence/CLIP entries use their *_brain_img and relative values.\n",
    "out_data={\"ROI\":\"nsdgeneral\",\n",
    "          \"Bleu_1_img_ref\": bleu_img_ref['bleu'],\n",
    "          \"Bleu_1_brain_img\": bleu_img_ref['bleu'],\n",
    "          \"Bleu_1_relative\": bleu_img_ref['bleu'],\n",
    "          \"Bleu_4_img_ref\": bleu_img_ref_4['bleu'],\n",
    "          \"Bleu_4_brain_img\": bleu_img_ref_4['bleu'],\n",
    "          \"Bleu_4_relative\": bleu_img_ref_4['bleu'],\n",
    "          \"Meteor_img_ref\":meteor_img_ref['meteor'],\n",
    "          \"Meteor_brain_img\":meteor_brain_img['meteor'],\n",
    "          \"Meteor_relative\":relative_brain_image_meteor,\n",
    "          \"Sentence_img_ref\":ss_sim_img_cap.diag().mean().item(),\n",
    "          \"Sentence_brain_img\":ss_sim_brain_img.diag().mean().item(),\n",
    "          \"Sentence_relative\":relative_brain_image_ss.mean().item(),\n",
    "          \"CLIP_img_ref\":clip_sim_img_cap.diag().mean().item(),\n",
    "          \"CLIP_brain_img\":clip_sim_brain_img.diag().mean().item(),\n",
    "          \"CLIP_relative\":relative_brain_image_clip.mean().item()}\n",
    "\n",
    "# One-row table, one column per metric.\n",
    "df=pd.DataFrame.from_dict([out_data])\n",
    "df.to_csv(f\"nsdgeneral_{sub}.csv\")\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "988591b2-e6d3-4de0-ad30-fc09883ff4f9",
   "metadata": {},
   "source": [
    "### Log results to wandb"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "id": "b3ad17e6-3af3-4548-8606-1812f5078675",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...\n",
      "To disable this warning, you can either:\n",
      "\t- Avoid using `tokenizers` before the fork if possible\n",
      "\t- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "wandb version 0.15.2 is available!  To upgrade, please run:\n",
       " $ pip install wandb --upgrade"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "Tracking run with wandb version 0.14.2"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "Run data is saved locally in <code>/home/matteo/fMRI_text_prediction/wandb/run-20230508_184836-fsii8zqb</code>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "Syncing run <strong><a href='https://wandb.ai/matteoferrante/BrainCaptioning/runs/fsii8zqb' target=\"_blank\">fluent-disco-3</a></strong> to <a href='https://wandb.ai/matteoferrante/BrainCaptioning' target=\"_blank\">Weights & Biases</a> (<a href='https://wandb.me/run' target=\"_blank\">docs</a>)<br/>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       " View project at <a href='https://wandb.ai/matteoferrante/BrainCaptioning' target=\"_blank\">https://wandb.ai/matteoferrante/BrainCaptioning</a>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       " View run at <a href='https://wandb.ai/matteoferrante/BrainCaptioning/runs/fsii8zqb' target=\"_blank\">https://wandb.ai/matteoferrante/BrainCaptioning/runs/fsii8zqb</a>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<button onClick=\"this.nextSibling.style.display='block';this.style.display='none';\">Display W&B run</button><iframe src='https://wandb.ai/matteoferrante/BrainCaptioning/runs/fsii8zqb?jupyter=true' style='border:none;width:100%;height:420px;display:none;'></iframe>"
      ],
      "text/plain": [
       "<wandb.sdk.wandb_run.Run at 0x7fd982102670>"
      ]
     },
     "execution_count": 28,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Run configuration recorded alongside the metrics in wandb.\n",
    "config = dict(sub=sub, mod=mod, model=\"Ridge\")\n",
    "wandb.init(project=\"BrainCaptioning\", config=config)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "id": "bdbd8f22-9492-4eda-8bed-eaaff1779469",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Push the summary metrics dict to the active run.\n",
    "wandb.log(data=out_data)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "887bd1e8-d64d-4139-969c-32d9229b1fdf",
   "metadata": {},
   "source": [
    "## Perform Image Decoding\n",
    "\n",
    "* Define styles\n",
    "* Load Stable Diffusion 2.1 (text-to-image and img2img pipelines)\n",
    "* Load Stable Diffusion 1.5 + ControlNet (depth, ip2p, shuffle)\n",
    "* Load VDVAE for image reconstruction\n",
    "* Load VDVAE for depth image reconstruction"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "id": "b1862d31-a365-46c0-ad2b-984b1bb52f3a",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Style suffixes appended to the brain-derived caption for the stylised ControlNet images.\n",
    "styles = [\n",
    "    \"illustration\",\n",
    "    \"drawing\",\n",
    "    \"cartoon\",\n",
    "    \"hyperrealistic\",\n",
    "    \"fantasy\",\n",
    "    \"surrealist\",\n",
    "]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "id": "27610541-f966-4ad6-b888-4fa5cf78c9f6",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "def get_latents(data,brain_to_latent,shapes,adjust=None):\n",
    "    \"\"\"Predict per-layer latents from brain data with one fitted regressor per layer.\n",
    "\n",
    "    Args:\n",
    "        data: array handed to each regressor's ``predict``; ``data.shape[0]`` is taken as the\n",
    "            number of samples (presumably samples x voxels - confirm).\n",
    "        brain_to_latent: mapping layer key -> fitted regressor exposing ``predict``.\n",
    "        shapes: mapping layer key -> per-sample latent shape used to reshape the predictions.\n",
    "        adjust: optional mapping layer key -> {\"mean\": ..., \"std\": ...}. When given and there is\n",
    "            more than one sample, each layer's predictions are standardised across samples and\n",
    "            then rescaled to these statistics.\n",
    "\n",
    "    Returns:\n",
    "        dict mapping layer key -> torch tensor reshaped to (-1, *shapes[key]).\n",
    "    \"\"\"\n",
    "    n_samples = data.shape[0]\n",
    "    rescale = adjust is not None and n_samples > 1\n",
    "    predicted = {}\n",
    "    for layer, regressor in brain_to_latent.items():\n",
    "        target_shape = shapes[layer]\n",
    "        latent = torch.tensor(regressor.predict(data)).reshape(-1, *target_shape)\n",
    "        if rescale:\n",
    "            # Standardise across samples (the epsilon guards zero-variance units), then\n",
    "            # impose the reference mean/std for this layer.\n",
    "            standardized = (latent - latent.mean(0)) / (1e-9 + latent.std(0))\n",
    "            latent = standardized * adjust[layer][\"std\"] + adjust[layer][\"mean\"]\n",
    "        predicted[layer] = latent\n",
    "    return predicted"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 32,
   "id": "1de413d9-57bb-49d6-9b51-c0b148df76bf",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "def decode_with_partial_sampling(decoder,latents,keep=31):\n",
    "    \"\"\"Decode images with a VDVAE decoder from manually supplied per-layer latents.\n",
    "\n",
    "    Args:\n",
    "        decoder: decoder exposing ``forward_manual_latents`` and ``out_net.sample``.\n",
    "        latents: dict of per-layer latent tensors; its values are handed to\n",
    "            ``forward_manual_latents`` in insertion order.\n",
    "        keep: forwarded as the first positional argument of ``forward_manual_latents``.\n",
    "            NOTE(review): callers in this notebook pass ``len(z[0])`` (the number of samples),\n",
    "            so despite its name this appears to be the batch size - confirm against VDVAE.\n",
    "\n",
    "    Returns:\n",
    "        float tensor of the sampled images divided by 255, with the last axis of the sampled\n",
    "        array moved to position 1 (channels-first when ``sample`` returns NHWC).\n",
    "    \"\"\"\n",
    "    # Every step uses the ``decoder`` argument, never a global VAE instance.\n",
    "    hidden = decoder.forward_manual_latents(keep, latents.values(), t=None)\n",
    "    samples = decoder.out_net.sample(hidden)\n",
    "    return torch.tensor(samples).permute(0, 3, 1, 2) / 255"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 33,
   "id": "e4661066-35b8-48f9-9b11-aad5a8f4fa70",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Resize decoded images so the shorter side is 512 px, then convert tensors to PIL images.\n",
    "# antialias=True is passed explicitly, as torchvision's deprecation warning recommends (its\n",
    "# default changes in v0.17). Antialiasing targets downsampling, so this upsampling path\n",
    "# should be unaffected - verify visually if in doubt.\n",
    "upsample=torchvision.transforms.Resize(512,interpolation=torchvision.transforms.InterpolationMode.BILINEAR,antialias=True)\n",
    "to_pil=torchvision.transforms.ToPILImage()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 34,
   "id": "aa18e25a-5920-43d8-9cdf-3bb47850aeec",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# subj01's VDVAE image latents live under a differently named folder. Resolve that folder\n",
    "# into its own variable instead of temporarily overwriting `sub`, so the subject id stays\n",
    "# valid for every other cell.\n",
    "latent_sub = \"subj01_good2\" if sub == \"subj01\" else sub\n",
    "with open(f\"/home/matteo/explore_NSD/models/{latent_sub}/train_z.sav\",\"rb\") as f:\n",
    "    train_z=pickle.load(f)\n",
    "\n",
    "with open(f\"/home/matteo/explore_NSD/models/{latent_sub}/test_z.sav\",\"rb\") as f:\n",
    "    test_z=pickle.load(f)\n",
    "\n",
    "# Per-sample latent shape of every layer: dims 1-3 of each stored tensor (presumably C, H, W).\n",
    "shapes={k: (v.shape[1],v.shape[2],v.shape[3]) for k,v in train_z.items()}\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 35,
   "id": "7a9d6d89-4042-4879-960a-8e33eb1d8827",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "keep=31\n",
    "if sub==\"subj01_good2\":\n",
    "    sub=\"subj01\"\n",
    "\n",
    "# One fitted brain -> depth-latent regressor per layer. Loaded unconditionally, like\n",
    "# brain_to_latent: the depth-decoding cell always uses brain_to_depth, so it must not\n",
    "# depend on `train`. `with` closes each file handle after loading.\n",
    "keys=np.arange(keep)\n",
    "brain_to_depth = {}\n",
    "for k in keys:\n",
    "    filename = f'brain_to_depth_vdvae_latent_ridge_{k}.sav'\n",
    "    with open(opj(f\"/home/matteo/explore_NSD/models/{sub}/depth\",filename), 'rb') as f:\n",
    "        brain_to_depth[k]=pickle.load(f)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 36,
   "id": "41e424ea-4e30-4116-925c-c3a0defde483",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "keep=31\n",
    "keys=np.arange(keep)\n",
    "brain_to_latent = {}\n",
    "\n",
    "# One fitted brain -> VDVAE-latent regressor per layer; `with` closes each file handle.\n",
    "for k in keys:\n",
    "    filename = f'brain_to_vdvae_latent_ridge_{k}.sav'\n",
    "    with open(opj(f\"/home/matteo/explore_NSD/models/{sub}/decoding\",filename), 'rb') as f:\n",
    "        brain_to_latent[k]=pickle.load(f)\n",
    "\n",
    "# Per-layer target statistics for get_latents(adjust=...). This cell runs before the NaN-fix\n",
    "# cell, so sanitise with torch.nan_to_num first: a single NaN in train_z would otherwise make\n",
    "# the mean/std, and therefore every rescaled predicted latent, NaN.\n",
    "latent_adjust_values={}\n",
    "for i in range(keep):\n",
    "    layer_z=torch.nan_to_num(train_z[i])\n",
    "    latent_adjust_values[i]={\"mean\":layer_z.mean(0), \"std\": layer_z.std(0)}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 37,
   "id": "171ef80e-f17f-436a-8582-6775ec49c0c8",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "if sub==\"subj01_good2\":\n",
    "    sub=\"subj01\"\n",
    "\n",
    "# Per-layer statistics passed as `adjust` when predicting depth latents\n",
    "# (expected to hold \"mean\"/\"std\" entries per layer, as get_latents reads them).\n",
    "filename = 'latent_depth_adjust_values.sav'\n",
    "latent_depth_adjust_values={}\n",
    "with open(opj(f\"/home/matteo/explore_NSD/models/{sub}\",filename), 'rb') as f:\n",
    "    latent_depth_adjust_values=pickle.load(f)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 38,
   "id": "b90b7b78-c906-4d94-9422-f6c24c848f8e",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# fix nan: torch.nan_to_num replaces NaN/inf entries of the stored latents with finite values\n",
    "for k in list(train_z.keys()):\n",
    "    for split in (train_z, test_z):\n",
    "        split[k]=torch.nan_to_num(split[k])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 39,
   "id": "ab63d017-5896-4754-a0aa-f99229690ca8",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# VDVAE hyper-parameters are stored as a pickled mapping; copy them into a Hyperparams object.\n",
    "with open('/home/matteo/models/vdvae/H.sav', 'rb') as fp:\n",
    "    d = pickle.load(fp)\n",
    "\n",
    "H=Hyperparams()\n",
    "for k,v in d.items():\n",
    "    H[k]=v\n",
    "\n",
    "vae=VAE(H)\n",
    "\n",
    "# Strip the 'module.' prefix from checkpoint keys (presumably added by a DataParallel wrapper).\n",
    "prefix = 'module.'\n",
    "state_dict = {\n",
    "    (name[len(prefix):] if name.startswith(prefix) else name): weights\n",
    "    for name, weights in torch.load(\"/home/matteo/models/vdvae/vae2.pt\").items()\n",
    "}\n",
    "vae.load_state_dict(state_dict)\n",
    "vae=vae.to(device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 40,
   "id": "00a8f6d0-4cf0-45cb-b8e0-ec3534637bd4",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Extra fp16 ControlNets; depth_pipe combines them with the depth ControlNet.\n",
    "checkpoint_pix = \"lllyasviel/control_v11e_sd15_ip2p\"\n",
    "checkpoint_shuffle = \"lllyasviel/control_v11e_sd15_shuffle\"\n",
    "\n",
    "controlnet_pix = ControlNetModel.from_pretrained(checkpoint_pix, torch_dtype=torch.float16)\n",
    "controlnet_shuffle = ControlNetModel.from_pretrained(checkpoint_shuffle, torch_dtype=torch.float16)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 41,
   "id": "bb7602b0-2148-4e0f-826c-0f65a5c4be38",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/matteo/anaconda3/lib/python3.9/site-packages/transformers/models/clip/feature_extraction_clip.py:28: FutureWarning: The class CLIPFeatureExtractor is deprecated and will be removed in version 5 of Transformers. Please use CLIPImageProcessor instead.\n",
      "  warnings.warn(\n"
     ]
    }
   ],
   "source": [
    "# Stable Diffusion 2.1 (base) text-to-image pipeline in fp16.\n",
    "pipe = StableDiffusionPipeline.from_pretrained(\n",
    "    \"stabilityai/stable-diffusion-2-1-base\", torch_dtype=torch.float16\n",
    ").to(device)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 42,
   "id": "a357d928-2813-44a9-af91-928477b3f7cb",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Build img2img from the already-loaded SD 2.1 components (VAE, text encoder, UNet, ...)\n",
    "# instead of loading a second copy of the same fp16 weights onto the device.\n",
    "pipe_img = StableDiffusionImg2ImgPipeline(\n",
    "    **pipe.components, requires_safety_checker=pipe.config.requires_safety_checker\n",
    ")\n",
    "pipe_img = pipe_img.to(device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 43,
   "id": "2cb94484-9ddb-4928-b3ef-3b0cb69afbe7",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "`text_config_dict` is provided which will be used to initialize `CLIPTextConfig`. The value `text_config[\"id2label\"]` will be overriden.\n"
     ]
    }
   ],
   "source": [
    "checkpoint = \"lllyasviel/control_v11f1p_sd15_depth\"\n",
    "\n",
    "controlnet = ControlNetModel.from_pretrained(checkpoint, torch_dtype=torch.float16)\n",
    "\n",
    "# SD 1.5 driven by three ControlNets. Calls pass one conditioning image and one conditioning\n",
    "# scale per ControlNet, in this order: [depth, ip2p, shuffle].\n",
    "depth_pipe  = StableDiffusionControlNetPipeline.from_pretrained(\n",
    "    \"runwayml/stable-diffusion-v1-5\", controlnet=[controlnet,controlnet_pix,controlnet_shuffle], torch_dtype=torch.float16\n",
    ")\n",
    "\n",
    "depth_pipe.to(device)\n",
    "\n",
    "# Build UniPC from this pipeline's own (SD 1.5) scheduler config, not the SD 2.1 `pipe`'s.\n",
    "depth_pipe.scheduler = UniPCMultistepScheduler.from_config(depth_pipe.scheduler.config)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5c1b8ce9-6de6-4e26-8779-d52c3e89d24e",
   "metadata": {
    "tags": []
   },
   "source": [
    "#### Compute latents and decode initial reconstructions"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 44,
   "id": "46353a5a-1c70-4a22-bf61-787debe7677e",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# GaussianBlur's default sigma=(0.1, 2.0) draws a new random sigma on every call, which makes\n",
    "# the depth maps non-reproducible; use a fixed sigma instead (tune as needed).\n",
    "gaussian_blur=torchvision.transforms.GaussianBlur(kernel_size=5, sigma=1.0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 45,
   "id": "bd8549b3-d18d-45fa-ac65-893abfc58d52",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Predict image latents for the test fMRI data, rescaled to the train_z layer statistics.\n",
    "z = get_latents(\n",
    "    test_fmri_norm.numpy(), brain_to_latent, shapes, adjust=latent_adjust_values\n",
    ")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 46,
   "id": "81203bf0-745b-4fda-8e7f-631783800b55",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/matteo/anaconda3/lib/python3.9/site-packages/torchvision/transforms/functional.py:1603: UserWarning: The default value of the antialias parameter of all the resizing transforms (Resize(), RandomResizedCrop(), etc.) will change from None to True in v0.17, in order to be consistent across the PIL and Tensor backends. To suppress this warning, directly pass antialias=True (recommended, future default), antialias=None (current default, which means False for Tensors and True for PIL), or antialias=False (only works on Tensors - PIL will still use antialiasing). This also applies if you are using the inference transforms from the models weights: update the call to weights.transforms(antialias=True).\n",
      "  warnings.warn(\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor(0.)\n"
     ]
    }
   ],
   "source": [
    "with torch.no_grad():\n",
    "    # Decode initial reconstructions; keep=len(z[0]) is the number of predicted samples.\n",
    "    guess_img=decode_with_partial_sampling(\n",
    "        vae.decoder,\n",
    "        {layer: latent.to(device).float() for layer, latent in z.items()},\n",
    "        keep=len(z[0]),\n",
    "    )\n",
    "    guess_img=upsample(guess_img).clamp(0,1)\n",
    "    # Sanity check: a maximum of 0 means every decoded pixel is black.\n",
    "    print(guess_img.max())\n",
    "    guessed=list(map(to_pil, guess_img))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 47,
   "id": "2a877f62-bd39-448f-9cb2-0cb5e2b674d2",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "ename": "NameError",
     "evalue": "name 'brain_to_depth' is not defined",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mNameError\u001b[0m                                 Traceback (most recent call last)",
      "\u001b[0;32m/tmp/ipykernel_202130/3484587426.py\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0mdepth_z\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mget_latents\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtest_fmri_norm\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mnumpy\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0mbrain_to_depth\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0mshapes\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0madjust\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mlatent_depth_adjust_values\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m      2\u001b[0m \u001b[0;32mwith\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mno_grad\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m      3\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m      4\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m      5\u001b[0m     \u001b[0;31m# guess_img=upsample(autoencoder.decoder.double()(z.to(device)).cpu())\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;31mNameError\u001b[0m: name 'brain_to_depth' is not defined"
     ]
    }
   ],
   "source": [
    "# Predict depth latents for the test fMRI data and decode them with the same VDVAE decoder.\n",
    "depth_z=get_latents(test_fmri_norm.numpy(),brain_to_depth,shapes,adjust=latent_depth_adjust_values)\n",
    "with torch.no_grad():\n",
    "    # Size the decode from depth_z itself rather than the separately computed image latents `z`.\n",
    "    guess_img=decode_with_partial_sampling(vae.decoder,{k:v.to(device).float() for k,v in depth_z.items()},keep=len(depth_z[0]))\n",
    "    print(guess_img.max())\n",
    "    guess_img=upsample(gaussian_blur(guess_img)).clamp(0,1)\n",
    "    # Grayscale, then back to 3 channels, smoothed: the depth ControlNet conditioning images.\n",
    "    depth_guessed=[to_pil(i).convert(\"L\").convert(\"RGB\").filter(ImageFilter.SMOOTH_MORE) for i in guess_img]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a030ae04-a7b0-4745-b6f1-f62b8e347f79",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Depth maps reduced to a 64-colour palette, converted back to RGB and smoothed.\n",
    "depth_images_quantized=[\n",
    "    depth.quantize(64).convert(\"RGB\").filter(ImageFilter.SMOOTH_MORE)\n",
    "    for depth in depth_guessed\n",
    "]\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ef14ed2c-aacc-4152-ad74-a566b6d73978",
   "metadata": {},
   "source": [
    "#### Process the entire test set\n",
    "* Base pipeline\n",
    "* Depth\n",
    "* Depth with styles\n",
    "* Save caption from image\n",
    "* Save caption from brain"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b3e4a7de-b48c-4c9a-be95-3d50c2e03d0d",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# One column per value logged per test sample. The style columns are generated from `styles`,\n",
    "# so they always match the number of stylised images added to each row.\n",
    "cols=[\"Caption from image\",\"Caption from brain\", \"Reference image\",\"Init Reconstruction\",\"Depth Reconstruction\",\n",
    "      \"SD_img_to_img\",\"SD_only_text\",\"SD_controldepth\",\"SD_control\"]\n",
    "cols+=[f\"SD_control_{style}\" for style in styles]\n",
    "table = wandb.Table(columns=cols)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "39370232-95d0-4489-877f-70cc8c032eb8",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "NEGATIVE_PROMPT = \"monochrome, lowres, bad anatomy, worst quality, low quality\"\n",
    "\n",
    "for idx in range(len(test_dataset)):\n",
    "\n",
    "    print(f\"processing {idx+1}/{len(test_dataset)}\")\n",
    "    prompt = captions_from_brain[idx]\n",
    "\n",
    "    # Text-only SD 2.1, and img2img starting from the initial VDVAE reconstruction.\n",
    "    text_image = pipe(prompt, guidance_scale=9, negative_prompt=NEGATIVE_PROMPT).images\n",
    "    img_image = pipe_img(prompt, image=guessed[idx], strength=0.6, guidance_scale=9,\n",
    "                         negative_prompt=NEGATIVE_PROMPT).images\n",
    "\n",
    "    # One conditioning image per ControlNet of depth_pipe, in order [depth, ip2p, shuffle].\n",
    "    multi_cond_images=[depth_guessed[idx].quantize(64).convert(\"RGB\").filter(ImageFilter.SMOOTH_MORE),guessed[idx],guessed[idx]]\n",
    "\n",
    "    # All three ControlNets active.\n",
    "    control_image=depth_pipe(prompt, num_inference_steps=30, image=multi_cond_images,\n",
    "                             guidance_scale=9, negative_prompt=NEGATIVE_PROMPT,\n",
    "                             controlnet_conditioning_scale=[0.7,0.2,0.5]).images\n",
    "\n",
    "    # Depth ControlNet only (ip2p and shuffle weighted 0).\n",
    "    depth_image=depth_pipe(prompt, num_inference_steps=30, image=multi_cond_images,\n",
    "                           guidance_scale=9, negative_prompt=NEGATIVE_PROMPT,\n",
    "                           controlnet_conditioning_scale=[0.7,0,0]).images\n",
    "\n",
    "    # One depth-conditioned image per style. Iterate the style names directly:\n",
    "    # enumerate(styles) would put \"(index, name)\" tuples into the prompt.\n",
    "    style_image=[]\n",
    "    for style in styles:\n",
    "        style_prompt=prompt+f\",{style}.\"\n",
    "        style_image+=depth_pipe(style_prompt, num_inference_steps=30, image=multi_cond_images,\n",
    "                                guidance_scale=9, negative_prompt=NEGATIVE_PROMPT,\n",
    "                                controlnet_conditioning_scale=[0.7,0,0]).images\n",
    "\n",
    "    # Row values in the same order as `cols`.\n",
    "    style_imgs=(wandb.Image(x) for x in style_image)\n",
    "    table.add_data(captions_from_images[idx],captions_from_brain[idx],\n",
    "                   wandb.Image(test_imgs[idx].permute(1,2,0).numpy()), #ref\n",
    "                   wandb.Image(guessed[idx]), #init\n",
    "                   wandb.Image(depth_guessed[idx]), #depth\n",
    "                   wandb.Image(img_image[0]),\n",
    "                   wandb.Image(text_image[0]),\n",
    "                   wandb.Image(depth_image[0]),\n",
    "                   wandb.Image(control_image[0]),\n",
    "                   *style_imgs\n",
    "                  )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f8f00b20-e882-4b90-938b-736436c81ad0",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Upload the per-sample results table to the active run.\n",
    "wandb.log(data={\"table\": table})"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1ade0586-b6d5-4579-98e0-8871ff7b801c",
   "metadata": {},
   "source": [
    "#### Example: every generation variant for one random test sample"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "44fd9b69-4e58-40ee-8e6d-15afa7b27777",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Run every generation variant on one random test sample (plotted in the next cell).\n",
    "idx=np.random.randint(len(test_dataset))\n",
    "\n",
    "print(captions_from_brain[idx])\n",
    "print(captions_from_images[idx])\n",
    "print(test_captions[idx])\n",
    "\n",
    "NEGATIVE_PROMPT = \"monochrome, lowres, bad anatomy, worst quality, low quality\"\n",
    "prompt = captions_from_brain[idx]\n",
    "\n",
    "text_image = pipe(prompt, guidance_scale=9, negative_prompt=NEGATIVE_PROMPT).images\n",
    "img_image = pipe_img(prompt, image=guessed[idx], strength=0.6, guidance_scale=9,\n",
    "                     negative_prompt=NEGATIVE_PROMPT).images\n",
    "\n",
    "# One conditioning image per ControlNet of depth_pipe, in order [depth, ip2p, shuffle].\n",
    "multi_cond_images=[depth_guessed[idx].quantize(64).convert(\"RGB\").filter(ImageFilter.SMOOTH_MORE),guessed[idx],guessed[idx]]\n",
    "\n",
    "control_image=depth_pipe(prompt, num_inference_steps=30, image=multi_cond_images,\n",
    "                         guidance_scale=9, negative_prompt=NEGATIVE_PROMPT,\n",
    "                         controlnet_conditioning_scale=[0.7,0.2,0.5]).images\n",
    "\n",
    "depth_image=depth_pipe(prompt, num_inference_steps=30, image=multi_cond_images,\n",
    "                       guidance_scale=9, negative_prompt=NEGATIVE_PROMPT,\n",
    "                       controlnet_conditioning_scale=[0.7,0,0]).images\n",
    "\n",
    "# Iterate the style names directly; enumerate(styles) would put \"(index, name)\" tuples into the prompt.\n",
    "style_image=[]\n",
    "for style in styles:\n",
    "    style_prompt=prompt+f\",{style} style.\"\n",
    "    style_image+=depth_pipe(style_prompt, num_inference_steps=30, image=multi_cond_images,\n",
    "                            guidance_scale=9, negative_prompt=NEGATIVE_PROMPT,\n",
    "                            controlnet_conditioning_scale=[0.7,0,0]).images\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6e634726-2d84-46df-8736-3dec2c3863ea",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Figure with 3 rows x 5 columns for the sample chosen in the previous cell.\n",
    "fig, ax_row = plt.subplots(nrows=3, ncols=5, figsize=(30,30))\n",
    "\n",
    "# Text panels: first column of rows 0 and 1.\n",
    "caption_panels = [\n",
    "    (ax_row[0,0], \"Caption from image\", captions_from_images[idx]),\n",
    "    (ax_row[1,0], \"Caption from brain\", captions_from_brain[idx]),\n",
    "]\n",
    "for ax, title, caption in caption_panels:\n",
    "    ax.set_title(title)\n",
    "    ax.text(0.5, 0.5, caption, ha='center', va='center', fontsize=15, wrap=True)\n",
    "    ax.axis('off')\n",
    "\n",
    "# Image panels as (axis, title, image).\n",
    "image_panels = [\n",
    "    (ax_row[0,1], \"Reference image\", test_imgs[idx].permute(1,2,0)),\n",
    "    (ax_row[0,2], \"Init Reconstruction\", guessed[idx]),\n",
    "    (ax_row[0,3], \"Depth\", depth_guessed[idx]),\n",
    "    (ax_row[0,4], \"SD_img_to_img\", img_image[0]),\n",
    "    (ax_row[1,1], \"SD_only_text\", text_image[0]),\n",
    "    (ax_row[1,2], \"SD_controldepth\", depth_image[0]),\n",
    "    (ax_row[1,3], \"SD_control\", control_image[0]),\n",
    "]\n",
    "# Stylised images fill the remaining panels. Zipping axis, style name and image together puts\n",
    "# each image under its own style's title, and each axis receives exactly one image.\n",
    "style_axes = [ax_row[1,4], ax_row[2,0], ax_row[2,1], ax_row[2,2], ax_row[2,3], ax_row[2,4]]\n",
    "image_panels += [(ax, f\"SD_control_{style}\", img) for ax, style, img in zip(style_axes, styles, style_image)]\n",
    "\n",
    "for ax, title, img in image_panels:\n",
    "    ax.imshow(img)\n",
    "    ax.set_title(title)\n",
    "    ax.axis('off')\n",
    "# Highlight the ground-truth panel.\n",
    "ax_row[0,1].set_title(\"Reference image\", color=\"red\")\n",
    "\n",
    "# Adjust the spacing between subplots\n",
    "fig.subplots_adjust(hspace=0.5, wspace=0.2)\n",
    "# plt.savefig(\"captions.png\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8bf92ee7-b1a8-4aa9-8b05-9b624ed72234",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Quick check: number of style prompts (one table column / figure panel per style).\n",
    "len(styles)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "ai",
   "language": "python",
   "name": "ai"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
